mirror of
https://github.com/sipeed/picoclaw.git
synced 2026-06-12 18:08:54 +00:00
Merge branch 'main' into fix-formatting
# Conflicts: # pkg/agent/loop.go # pkg/agent/loop_test.go # pkg/channels/discord.go # pkg/channels/onebot.go # pkg/config/config.go # pkg/tools/subagent_tool_test.go
This commit is contained in:
+55
-17
@@ -23,15 +23,19 @@ type SubagentTask struct {
|
||||
}
|
||||
|
||||
type SubagentManager struct {
|
||||
tasks map[string]*SubagentTask
|
||||
mu sync.RWMutex
|
||||
provider providers.LLMProvider
|
||||
defaultModel string
|
||||
bus *bus.MessageBus
|
||||
workspace string
|
||||
tools *ToolRegistry
|
||||
maxIterations int
|
||||
nextID int
|
||||
tasks map[string]*SubagentTask
|
||||
mu sync.RWMutex
|
||||
provider providers.LLMProvider
|
||||
defaultModel string
|
||||
bus *bus.MessageBus
|
||||
workspace string
|
||||
tools *ToolRegistry
|
||||
maxIterations int
|
||||
maxTokens int
|
||||
temperature float64
|
||||
hasMaxTokens bool
|
||||
hasTemperature bool
|
||||
nextID int
|
||||
}
|
||||
|
||||
func NewSubagentManager(
|
||||
@@ -51,6 +55,16 @@ func NewSubagentManager(
|
||||
}
|
||||
}
|
||||
|
||||
// SetLLMOptions sets max tokens and temperature for subagent LLM calls.
|
||||
func (sm *SubagentManager) SetLLMOptions(maxTokens int, temperature float64) {
|
||||
sm.mu.Lock()
|
||||
defer sm.mu.Unlock()
|
||||
sm.maxTokens = maxTokens
|
||||
sm.hasMaxTokens = true
|
||||
sm.temperature = temperature
|
||||
sm.hasTemperature = true
|
||||
}
|
||||
|
||||
// SetTools sets the tool registry for subagent execution.
|
||||
// If not set, subagent will have access to the provided tools.
|
||||
func (sm *SubagentManager) SetTools(tools *ToolRegistry) {
|
||||
@@ -133,17 +147,29 @@ After completing the task, provide a clear summary of what was done.`
|
||||
sm.mu.RLock()
|
||||
tools := sm.tools
|
||||
maxIter := sm.maxIterations
|
||||
maxTokens := sm.maxTokens
|
||||
temperature := sm.temperature
|
||||
hasMaxTokens := sm.hasMaxTokens
|
||||
hasTemperature := sm.hasTemperature
|
||||
sm.mu.RUnlock()
|
||||
|
||||
var llmOptions map[string]any
|
||||
if hasMaxTokens || hasTemperature {
|
||||
llmOptions = map[string]any{}
|
||||
if hasMaxTokens {
|
||||
llmOptions["max_tokens"] = maxTokens
|
||||
}
|
||||
if hasTemperature {
|
||||
llmOptions["temperature"] = temperature
|
||||
}
|
||||
}
|
||||
|
||||
loopResult, err := RunToolLoop(ctx, ToolLoopConfig{
|
||||
Provider: sm.provider,
|
||||
Model: sm.defaultModel,
|
||||
Tools: tools,
|
||||
MaxIterations: maxIter,
|
||||
LLMOptions: map[string]any{
|
||||
"max_tokens": 4096,
|
||||
"temperature": 0.7,
|
||||
},
|
||||
LLMOptions: llmOptions,
|
||||
}, messages, task.OriginChannel, task.OriginChatID)
|
||||
|
||||
sm.mu.Lock()
|
||||
@@ -296,17 +322,29 @@ func (t *SubagentTool) Execute(ctx context.Context, args map[string]any) *ToolRe
|
||||
sm.mu.RLock()
|
||||
tools := sm.tools
|
||||
maxIter := sm.maxIterations
|
||||
maxTokens := sm.maxTokens
|
||||
temperature := sm.temperature
|
||||
hasMaxTokens := sm.hasMaxTokens
|
||||
hasTemperature := sm.hasTemperature
|
||||
sm.mu.RUnlock()
|
||||
|
||||
var llmOptions map[string]any
|
||||
if hasMaxTokens || hasTemperature {
|
||||
llmOptions = map[string]any{}
|
||||
if hasMaxTokens {
|
||||
llmOptions["max_tokens"] = maxTokens
|
||||
}
|
||||
if hasTemperature {
|
||||
llmOptions["temperature"] = temperature
|
||||
}
|
||||
}
|
||||
|
||||
loopResult, err := RunToolLoop(ctx, ToolLoopConfig{
|
||||
Provider: sm.provider,
|
||||
Model: sm.defaultModel,
|
||||
Tools: tools,
|
||||
MaxIterations: maxIter,
|
||||
LLMOptions: map[string]any{
|
||||
"max_tokens": 4096,
|
||||
"temperature": 0.7,
|
||||
},
|
||||
LLMOptions: llmOptions,
|
||||
}, messages, t.originChannel, t.originChatID)
|
||||
if err != nil {
|
||||
return ErrorResult(fmt.Sprintf("Subagent execution failed: %v", err)).WithError(err)
|
||||
|
||||
@@ -10,15 +10,12 @@ import (
|
||||
)
|
||||
|
||||
// MockLLMProvider is a test implementation of LLMProvider
|
||||
type MockLLMProvider struct{}
|
||||
type MockLLMProvider struct {
|
||||
lastOptions map[string]interface{}
|
||||
}
|
||||
|
||||
func (m *MockLLMProvider) Chat(
|
||||
ctx context.Context,
|
||||
messages []providers.Message,
|
||||
tools []providers.ToolDefinition,
|
||||
model string,
|
||||
options map[string]any,
|
||||
) (*providers.LLMResponse, error) {
|
||||
func (m *MockLLMProvider) Chat(ctx context.Context, messages []providers.Message, tools []providers.ToolDefinition, model string, options map[string]interface{}) (*providers.LLMResponse, error) {
|
||||
m.lastOptions = options
|
||||
// Find the last user message to generate a response
|
||||
for i := len(messages) - 1; i >= 0; i-- {
|
||||
if messages[i].Role == "user" {
|
||||
@@ -42,6 +39,32 @@ func (m *MockLLMProvider) GetContextWindow() int {
|
||||
return 4096
|
||||
}
|
||||
|
||||
func TestSubagentManager_SetLLMOptions_AppliesToRunToolLoop(t *testing.T) {
|
||||
provider := &MockLLMProvider{}
|
||||
manager := NewSubagentManager(provider, "test-model", "/tmp/test", nil)
|
||||
manager.SetLLMOptions(2048, 0.6)
|
||||
tool := NewSubagentTool(manager)
|
||||
tool.SetContext("cli", "direct")
|
||||
|
||||
ctx := context.Background()
|
||||
args := map[string]interface{}{"task": "Do something"}
|
||||
result := tool.Execute(ctx, args)
|
||||
|
||||
if result == nil || result.IsError {
|
||||
t.Fatalf("Expected successful result, got: %+v", result)
|
||||
}
|
||||
|
||||
if provider.lastOptions == nil {
|
||||
t.Fatal("Expected LLM options to be passed, got nil")
|
||||
}
|
||||
if provider.lastOptions["max_tokens"] != 2048 {
|
||||
t.Fatalf("max_tokens = %v, want %d", provider.lastOptions["max_tokens"], 2048)
|
||||
}
|
||||
if provider.lastOptions["temperature"] != 0.6 {
|
||||
t.Fatalf("temperature = %v, want %v", provider.lastOptions["temperature"], 0.6)
|
||||
}
|
||||
}
|
||||
|
||||
// TestSubagentTool_Name verifies tool name
|
||||
func TestSubagentTool_Name(t *testing.T) {
|
||||
provider := &MockLLMProvider{}
|
||||
@@ -85,13 +108,13 @@ func TestSubagentTool_Parameters(t *testing.T) {
|
||||
}
|
||||
|
||||
// Check properties
|
||||
props, ok := params["properties"].(map[string]any)
|
||||
props, ok := params["properties"].(map[string]interface{})
|
||||
if !ok {
|
||||
t.Fatal("Properties should be a map")
|
||||
}
|
||||
|
||||
// Verify task parameter
|
||||
task, ok := props["task"].(map[string]any)
|
||||
task, ok := props["task"].(map[string]interface{})
|
||||
if !ok {
|
||||
t.Fatal("Task parameter should exist")
|
||||
}
|
||||
@@ -100,7 +123,7 @@ func TestSubagentTool_Parameters(t *testing.T) {
|
||||
}
|
||||
|
||||
// Verify label parameter
|
||||
label, ok := props["label"].(map[string]any)
|
||||
label, ok := props["label"].(map[string]interface{})
|
||||
if !ok {
|
||||
t.Fatal("Label parameter should exist")
|
||||
}
|
||||
@@ -140,7 +163,7 @@ func TestSubagentTool_Execute_Success(t *testing.T) {
|
||||
tool.SetContext("telegram", "chat-123")
|
||||
|
||||
ctx := context.Background()
|
||||
args := map[string]any{
|
||||
args := map[string]interface{}{
|
||||
"task": "Write a haiku about coding",
|
||||
"label": "haiku-task",
|
||||
}
|
||||
@@ -195,7 +218,7 @@ func TestSubagentTool_Execute_NoLabel(t *testing.T) {
|
||||
tool := NewSubagentTool(manager)
|
||||
|
||||
ctx := context.Background()
|
||||
args := map[string]any{
|
||||
args := map[string]interface{}{
|
||||
"task": "Test task without label",
|
||||
}
|
||||
|
||||
@@ -218,7 +241,7 @@ func TestSubagentTool_Execute_MissingTask(t *testing.T) {
|
||||
tool := NewSubagentTool(manager)
|
||||
|
||||
ctx := context.Background()
|
||||
args := map[string]any{
|
||||
args := map[string]interface{}{
|
||||
"label": "test",
|
||||
}
|
||||
|
||||
@@ -245,7 +268,7 @@ func TestSubagentTool_Execute_NilManager(t *testing.T) {
|
||||
tool := NewSubagentTool(nil)
|
||||
|
||||
ctx := context.Background()
|
||||
args := map[string]any{
|
||||
args := map[string]interface{}{
|
||||
"task": "test task",
|
||||
}
|
||||
|
||||
@@ -274,7 +297,7 @@ func TestSubagentTool_Execute_ContextPassing(t *testing.T) {
|
||||
tool.SetContext(channel, chatID)
|
||||
|
||||
ctx := context.Background()
|
||||
args := map[string]any{
|
||||
args := map[string]interface{}{
|
||||
"task": "Test context passing",
|
||||
}
|
||||
|
||||
@@ -301,7 +324,7 @@ func TestSubagentTool_ForUserTruncation(t *testing.T) {
|
||||
|
||||
// Create a task that will generate long response
|
||||
longTask := strings.Repeat("This is a very long task description. ", 100)
|
||||
args := map[string]any{
|
||||
args := map[string]interface{}{
|
||||
"task": longTask,
|
||||
"label": "long-test",
|
||||
}
|
||||
|
||||
@@ -60,12 +60,8 @@ func RunToolLoop(
|
||||
// 2. Set default LLM options
|
||||
llmOpts := config.LLMOptions
|
||||
if llmOpts == nil {
|
||||
llmOpts = map[string]any{
|
||||
"max_tokens": 4096,
|
||||
"temperature": 0.7,
|
||||
}
|
||||
llmOpts = map[string]any{}
|
||||
}
|
||||
|
||||
// 3. Call LLM
|
||||
response, err := config.Provider.Chat(ctx, messages, providerToolDefs, config.Model, llmOpts)
|
||||
if err != nil {
|
||||
@@ -114,6 +110,7 @@ func RunToolLoop(
|
||||
Name: tc.Name,
|
||||
Arguments: string(argumentsJSON),
|
||||
},
|
||||
Name: tc.Name,
|
||||
})
|
||||
}
|
||||
messages = append(messages, assistantMsg)
|
||||
|
||||
+4
-2
@@ -504,8 +504,10 @@ func (t *WebFetchTool) extractText(htmlContent string) string {
|
||||
|
||||
result = strings.TrimSpace(result)
|
||||
|
||||
re = regexp.MustCompile(`\s+`)
|
||||
result = re.ReplaceAllLiteralString(result, " ")
|
||||
re = regexp.MustCompile(`[^\S\n]+`)
|
||||
result = re.ReplaceAllString(result, " ")
|
||||
re = regexp.MustCompile(`\n{3,}`)
|
||||
result = re.ReplaceAllString(result, "\n\n")
|
||||
|
||||
lines := strings.Split(result, "\n")
|
||||
var cleanLines []string
|
||||
|
||||
@@ -238,6 +238,80 @@ func TestWebTool_WebFetch_HTMLExtraction(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// TestWebFetchTool_extractText verifies text extraction preserves newlines
|
||||
func TestWebFetchTool_extractText(t *testing.T) {
|
||||
tool := &WebFetchTool{}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
wantFunc func(t *testing.T, got string)
|
||||
}{
|
||||
{
|
||||
name: "preserves newlines between block elements",
|
||||
input: "<html><body><h1>Title</h1>\n<p>Paragraph 1</p>\n<p>Paragraph 2</p></body></html>",
|
||||
wantFunc: func(t *testing.T, got string) {
|
||||
lines := strings.Split(got, "\n")
|
||||
if len(lines) < 2 {
|
||||
t.Errorf("Expected multiple lines, got %d: %q", len(lines), got)
|
||||
}
|
||||
if !strings.Contains(got, "Title") || !strings.Contains(got, "Paragraph 1") || !strings.Contains(got, "Paragraph 2") {
|
||||
t.Errorf("Missing expected text: %q", got)
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "removes script and style tags",
|
||||
input: "<script>alert('x');</script><style>body{}</style><p>Keep this</p>",
|
||||
wantFunc: func(t *testing.T, got string) {
|
||||
if strings.Contains(got, "alert") || strings.Contains(got, "body{}") {
|
||||
t.Errorf("Expected script/style content removed, got: %q", got)
|
||||
}
|
||||
if !strings.Contains(got, "Keep this") {
|
||||
t.Errorf("Expected 'Keep this' to remain, got: %q", got)
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "collapses excessive blank lines",
|
||||
input: "<p>A</p>\n\n\n\n\n<p>B</p>",
|
||||
wantFunc: func(t *testing.T, got string) {
|
||||
if strings.Contains(got, "\n\n\n") {
|
||||
t.Errorf("Expected excessive blank lines collapsed, got: %q", got)
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "collapses horizontal whitespace",
|
||||
input: "<p>hello world</p>",
|
||||
wantFunc: func(t *testing.T, got string) {
|
||||
if strings.Contains(got, " ") {
|
||||
t.Errorf("Expected spaces collapsed, got: %q", got)
|
||||
}
|
||||
if !strings.Contains(got, "hello world") {
|
||||
t.Errorf("Expected 'hello world', got: %q", got)
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "empty input",
|
||||
input: "",
|
||||
wantFunc: func(t *testing.T, got string) {
|
||||
if got != "" {
|
||||
t.Errorf("Expected empty string, got: %q", got)
|
||||
}
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := tool.extractText(tt.input)
|
||||
tt.wantFunc(t, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestWebTool_WebFetch_MissingDomain verifies error handling for URL without domain
|
||||
func TestWebTool_WebFetch_MissingDomain(t *testing.T) {
|
||||
tool := NewWebFetchTool(50000)
|
||||
|
||||
Reference in New Issue
Block a user