diff --git a/pkg/agent/instance.go b/pkg/agent/instance.go index 54a5396e7..37b253685 100644 --- a/pkg/agent/instance.go +++ b/pkg/agent/instance.go @@ -21,6 +21,8 @@ type AgentInstance struct { Fallbacks []string Workspace string MaxIterations int + MaxTokens int + Temperature float64 ContextWindow int Provider providers.LLMProvider Sessions *session.SessionManager @@ -76,6 +78,16 @@ func NewAgentInstance( maxIter = 20 } + maxTokens := defaults.MaxTokens + if maxTokens == 0 { + maxTokens = 8192 + } + + temperature := 0.7 + if defaults.Temperature != nil { + temperature = *defaults.Temperature + } + // Resolve fallback candidates modelCfg := providers.ModelConfig{ Primary: model, @@ -90,7 +102,9 @@ func NewAgentInstance( Fallbacks: fallbacks, Workspace: workspace, MaxIterations: maxIter, - ContextWindow: defaults.MaxTokens, + MaxTokens: maxTokens, + Temperature: temperature, + ContextWindow: maxTokens, Provider: provider, Sessions: sessionsManager, ContextBuilder: contextBuilder, diff --git a/pkg/agent/instance_test.go b/pkg/agent/instance_test.go new file mode 100644 index 000000000..fcc8e9bea --- /dev/null +++ b/pkg/agent/instance_test.go @@ -0,0 +1,95 @@ +package agent + +import ( + "os" + "testing" + + "github.com/sipeed/picoclaw/pkg/config" +) + +func TestNewAgentInstance_UsesDefaultsTemperatureAndMaxTokens(t *testing.T) { + tmpDir, err := os.MkdirTemp("", "agent-instance-test-*") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + defer os.RemoveAll(tmpDir) + + cfg := &config.Config{ + Agents: config.AgentsConfig{ + Defaults: config.AgentDefaults{ + Workspace: tmpDir, + Model: "test-model", + MaxTokens: 1234, + MaxToolIterations: 5, + }, + }, + } + + configuredTemp := 1.0 + cfg.Agents.Defaults.Temperature = &configuredTemp + + provider := &mockProvider{} + agent := NewAgentInstance(nil, &cfg.Agents.Defaults, cfg, provider) + + if agent.MaxTokens != 1234 { + t.Fatalf("MaxTokens = %d, want %d", agent.MaxTokens, 1234) + } + if agent.Temperature != 1.0 { + t.Fatalf("Temperature = %f, want %f", agent.Temperature, 1.0) + } +} + +func TestNewAgentInstance_DefaultsTemperatureWhenZero(t *testing.T) { + tmpDir, err := os.MkdirTemp("", "agent-instance-test-*") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + defer os.RemoveAll(tmpDir) + + cfg := &config.Config{ + Agents: config.AgentsConfig{ + Defaults: config.AgentDefaults{ + Workspace: tmpDir, + Model: "test-model", + MaxTokens: 1234, + MaxToolIterations: 5, + }, + }, + } + + configuredTemp := 0.0 + cfg.Agents.Defaults.Temperature = &configuredTemp + + provider := &mockProvider{} + agent := NewAgentInstance(nil, &cfg.Agents.Defaults, cfg, provider) + + if agent.Temperature != 0.0 { + t.Fatalf("Temperature = %f, want %f", agent.Temperature, 0.0) + } +} + +func TestNewAgentInstance_DefaultsTemperatureWhenUnset(t *testing.T) { + tmpDir, err := os.MkdirTemp("", "agent-instance-test-*") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + defer os.RemoveAll(tmpDir) + + cfg := &config.Config{ + Agents: config.AgentsConfig{ + Defaults: config.AgentDefaults{ + Workspace: tmpDir, + Model: "test-model", + MaxTokens: 1234, + MaxToolIterations: 5, + }, + }, + } + + provider := &mockProvider{} + agent := NewAgentInstance(nil, &cfg.Agents.Defaults, cfg, provider) + + if agent.Temperature != 0.7 { + t.Fatalf("Temperature = %f, want %f", agent.Temperature, 0.7) + } +} diff --git a/pkg/agent/loop.go b/pkg/agent/loop.go index 6d0a61375..0f1b26c5c 100644 --- a/pkg/agent/loop.go +++ b/pkg/agent/loop.go @@ -119,6 +119,7 @@ func registerSharedTools(cfg *config.Config, msgBus *bus.MessageBus, registry *A // Spawn tool with allowlist checker subagentManager := tools.NewSubagentManager(provider, agent.Model, agent.Workspace, msgBus) + subagentManager.SetLLMOptions(agent.MaxTokens, agent.Temperature) spawnTool := tools.NewSpawnTool(subagentManager) currentAgentID := agentID spawnTool.SetAllowlistChecker(func(targetAgentID string) bool { @@ -470,8 +471,8 @@ func (al *AgentLoop) runLLMIteration(ctx context.Context, agent *AgentInstance, "model": agent.Model, "messages_count": len(messages), "tools_count": len(providerToolDefs), - "max_tokens": 8192, - "temperature": 0.7, + "max_tokens": agent.MaxTokens, + "temperature": agent.Temperature, "system_prompt_len": len(messages[0].Content), }) @@ -492,8 +493,8 @@ func (al *AgentLoop) runLLMIteration(ctx context.Context, agent *AgentInstance, fbResult, fbErr := al.fallback.Execute(ctx, agent.Candidates, func(ctx context.Context, provider, model string) (*providers.LLMResponse, error) { return agent.Provider.Chat(ctx, messages, providerToolDefs, model, map[string]interface{}{ - "max_tokens": 8192, - "temperature": 0.7, + "max_tokens": agent.MaxTokens, + "temperature": agent.Temperature, }) }, ) @@ -508,8 +509,8 @@ func (al *AgentLoop) runLLMIteration(ctx context.Context, agent *AgentInstance, return fbResult.Response, nil } return agent.Provider.Chat(ctx, messages, providerToolDefs, agent.Model, map[string]interface{}{ - "max_tokens": 8192, - "temperature": 0.7, + "max_tokens": agent.MaxTokens, + "temperature": agent.Temperature, }) } diff --git a/pkg/agent/loop_test.go b/pkg/agent/loop_test.go index f2257973c..360685eca 100644 --- a/pkg/agent/loop_test.go +++ b/pkg/agent/loop_test.go @@ -14,20 +14,6 @@ import ( "github.com/sipeed/picoclaw/pkg/tools" ) -// mockProvider is a simple mock LLM provider for testing -type mockProvider struct{} - -func (m *mockProvider) Chat(ctx context.Context, messages []providers.Message, tools []providers.ToolDefinition, model string, opts map[string]interface{}) (*providers.LLMResponse, error) { - return &providers.LLMResponse{ - Content: "Mock response", - ToolCalls: []providers.ToolCall{}, - }, nil -} - -func (m *mockProvider) GetDefaultModel() string { - return "mock-model" -} - func TestRecordLastChannel(t *testing.T) { // Create temp workspace tmpDir, err := os.MkdirTemp("", "agent-test-*") @@ -603,7 +589,6 @@ func TestAgentLoop_ContextExhaustionRetry(t *testing.T) { // Call ProcessDirectWithChannel // Note: ProcessDirectWithChannel calls processMessage which will execute runLLMIteration response, err := al.ProcessDirectWithChannel(context.Background(), "Trigger message", sessionKey, "test", "test-chat") - if err != nil { t.Fatalf("Expected success after retry, got error: %v", err) } diff --git a/pkg/agent/mock_provider_test.go b/pkg/agent/mock_provider_test.go new file mode 100644 index 000000000..ccbecbafe --- /dev/null +++ b/pkg/agent/mock_provider_test.go @@ -0,0 +1,20 @@ +package agent + +import ( + "context" + + "github.com/sipeed/picoclaw/pkg/providers" +) + +type mockProvider struct{} + +func (m *mockProvider) Chat(ctx context.Context, messages []providers.Message, tools []providers.ToolDefinition, model string, opts map[string]interface{}) (*providers.LLMResponse, error) { + return &providers.LLMResponse{ + Content: "Mock response", + ToolCalls: []providers.ToolCall{}, + }, nil +} + +func (m *mockProvider) GetDefaultModel() string { + return "mock-model" +} diff --git a/pkg/config/config.go b/pkg/config/config.go index 682996bd6..3bdb6f030 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -147,7 +147,7 @@ type AgentDefaults struct { ImageModel string `json:"image_model,omitempty" env:"PICOCLAW_AGENTS_DEFAULTS_IMAGE_MODEL"` ImageModelFallbacks []string `json:"image_model_fallbacks,omitempty"` MaxTokens int `json:"max_tokens" env:"PICOCLAW_AGENTS_DEFAULTS_MAX_TOKENS"` - Temperature float64 `json:"temperature" env:"PICOCLAW_AGENTS_DEFAULTS_TEMPERATURE"` + Temperature *float64 `json:"temperature,omitempty" env:"PICOCLAW_AGENTS_DEFAULTS_TEMPERATURE"` MaxToolIterations int `json:"max_tool_iterations" env:"PICOCLAW_AGENTS_DEFAULTS_MAX_TOOL_ITERATIONS"` } @@ -330,7 +330,6 @@ func DefaultConfig() *Config { Provider: "", Model: "glm-4.7", MaxTokens: 8192, - Temperature: 0.7, MaxToolIterations: 20, }, }, diff --git a/pkg/config/config_test.go b/pkg/config/config_test.go index 47916d155..7e706d8ce 100644 --- a/pkg/config/config_test.go +++ b/pkg/config/config_test.go @@ -237,8 +237,8 @@ func TestDefaultConfig_MaxToolIterations(t *testing.T) { func TestDefaultConfig_Temperature(t *testing.T) { cfg := DefaultConfig() - if cfg.Agents.Defaults.Temperature == 0 { - t.Error("Temperature should not be zero") + if cfg.Agents.Defaults.Temperature != nil { + t.Error("Temperature should be nil when not provided") } } @@ -334,8 +334,8 @@ func TestConfig_Complete(t *testing.T) { if cfg.Agents.Defaults.Model == "" { t.Error("Model should not be empty") } - if cfg.Agents.Defaults.Temperature == 0 { - t.Error("Temperature should have default value") + if cfg.Agents.Defaults.Temperature != nil { + t.Error("Temperature should be nil when not provided") } if cfg.Agents.Defaults.MaxTokens == 0 { t.Error("MaxTokens should not be zero") diff --git a/pkg/migrate/config.go b/pkg/migrate/config.go index 57032e566..665719f2a 100644 --- a/pkg/migrate/config.go +++ b/pkg/migrate/config.go @@ -76,7 +76,7 @@ func ConvertConfig(data map[string]interface{}) (*config.Config, []string, error cfg.Agents.Defaults.MaxTokens = int(v) } if v, ok := getFloat(defaults, "temperature"); ok { - cfg.Agents.Defaults.Temperature = v + cfg.Agents.Defaults.Temperature = &v } if v, ok := getFloat(defaults, "max_tool_iterations"); ok { cfg.Agents.Defaults.MaxToolIterations = int(v) diff --git a/pkg/migrate/migrate_test.go b/pkg/migrate/migrate_test.go index e930d45f4..f6f8b7908 100644 --- a/pkg/migrate/migrate_test.go +++ b/pkg/migrate/migrate_test.go @@ -275,8 +275,11 @@ func TestConvertConfig(t *testing.T) { if cfg.Agents.Defaults.MaxTokens != 4096 { t.Errorf("MaxTokens = %d, want %d", cfg.Agents.Defaults.MaxTokens, 4096) } - if cfg.Agents.Defaults.Temperature != 0.5 { - t.Errorf("Temperature = %f, want %f", cfg.Agents.Defaults.Temperature, 0.5) + if cfg.Agents.Defaults.Temperature == nil { + t.Fatalf("Temperature is nil, want %f", 0.5) + } + if *cfg.Agents.Defaults.Temperature != 0.5 { + t.Errorf("Temperature = %f, want %f", *cfg.Agents.Defaults.Temperature, 0.5) } if cfg.Agents.Defaults.Workspace != "~/.picoclaw/workspace" { t.Errorf("Workspace = %q, want %q", cfg.Agents.Defaults.Workspace, "~/.picoclaw/workspace") diff --git a/pkg/tools/subagent.go b/pkg/tools/subagent.go index 2fc7162d0..294ba6ea8 100644 --- a/pkg/tools/subagent.go +++ b/pkg/tools/subagent.go @@ -23,15 +23,19 @@ type SubagentTask struct { } type SubagentManager struct { - tasks map[string]*SubagentTask - mu sync.RWMutex - provider providers.LLMProvider - defaultModel string - bus *bus.MessageBus - workspace string - tools *ToolRegistry - maxIterations int - nextID int + tasks map[string]*SubagentTask + mu sync.RWMutex + provider providers.LLMProvider + defaultModel string + bus *bus.MessageBus + workspace string + tools *ToolRegistry + maxIterations int + maxTokens int + temperature float64 + hasMaxTokens bool + hasTemperature bool + nextID int } func NewSubagentManager(provider providers.LLMProvider, defaultModel, workspace string, bus *bus.MessageBus) *SubagentManager { @@ -47,6 +51,16 @@ func NewSubagentManager(provider providers.LLMProvider, defaultModel, workspace } } +// SetLLMOptions sets max tokens and temperature for subagent LLM calls. +func (sm *SubagentManager) SetLLMOptions(maxTokens int, temperature float64) { + sm.mu.Lock() + defer sm.mu.Unlock() + sm.maxTokens = maxTokens + sm.hasMaxTokens = true + sm.temperature = temperature + sm.hasTemperature = true +} + // SetTools sets the tool registry for subagent execution. // If not set, subagent will have access to the provided tools. func (sm *SubagentManager) SetTools(tools *ToolRegistry) { @@ -125,17 +139,29 @@ After completing the task, provide a clear summary of what was done.` sm.mu.RLock() tools := sm.tools maxIter := sm.maxIterations + maxTokens := sm.maxTokens + temperature := sm.temperature + hasMaxTokens := sm.hasMaxTokens + hasTemperature := sm.hasTemperature sm.mu.RUnlock() + var llmOptions map[string]any + if hasMaxTokens || hasTemperature { + llmOptions = map[string]any{} + if hasMaxTokens { + llmOptions["max_tokens"] = maxTokens + } + if hasTemperature { + llmOptions["temperature"] = temperature + } + } + loopResult, err := RunToolLoop(ctx, ToolLoopConfig{ Provider: sm.provider, Model: sm.defaultModel, Tools: tools, MaxIterations: maxIter, - LLMOptions: map[string]any{ - "max_tokens": 4096, - "temperature": 0.7, - }, + LLMOptions: llmOptions, }, messages, task.OriginChannel, task.OriginChatID) sm.mu.Lock() @@ -283,19 +309,30 @@ func (t *SubagentTool) Execute(ctx context.Context, args map[string]interface{}) sm.mu.RLock() tools := sm.tools maxIter := sm.maxIterations + maxTokens := sm.maxTokens + temperature := sm.temperature + hasMaxTokens := sm.hasMaxTokens + hasTemperature := sm.hasTemperature sm.mu.RUnlock() + var llmOptions map[string]any + if hasMaxTokens || hasTemperature { + llmOptions = map[string]any{} + if hasMaxTokens { + llmOptions["max_tokens"] = maxTokens + } + if hasTemperature { + llmOptions["temperature"] = temperature + } + } + loopResult, err := RunToolLoop(ctx, ToolLoopConfig{ Provider: sm.provider, Model: sm.defaultModel, Tools: tools, MaxIterations: maxIter, - LLMOptions: map[string]any{ - "max_tokens": 4096, - "temperature": 0.7, - }, + LLMOptions: llmOptions, }, messages, t.originChannel, t.originChatID) - if err != nil { return ErrorResult(fmt.Sprintf("Subagent execution failed: %v", err)).WithError(err) } diff --git a/pkg/tools/subagent_tool_test.go b/pkg/tools/subagent_tool_test.go index 8a7d22f24..f960a7fda 100644 --- a/pkg/tools/subagent_tool_test.go +++ b/pkg/tools/subagent_tool_test.go @@ -10,9 +10,12 @@ import ( ) // MockLLMProvider is a test implementation of LLMProvider -type MockLLMProvider struct{} +type MockLLMProvider struct { + lastOptions map[string]interface{} +} func (m *MockLLMProvider) Chat(ctx context.Context, messages []providers.Message, tools []providers.ToolDefinition, model string, options map[string]interface{}) (*providers.LLMResponse, error) { + m.lastOptions = options // Find the last user message to generate a response for i := len(messages) - 1; i >= 0; i-- { if messages[i].Role == "user" { @@ -36,6 +39,32 @@ func (m *MockLLMProvider) GetContextWindow() int { return 4096 } +func TestSubagentManager_SetLLMOptions_AppliesToRunToolLoop(t *testing.T) { + provider := &MockLLMProvider{} + manager := NewSubagentManager(provider, "test-model", "/tmp/test", nil) + manager.SetLLMOptions(2048, 0.6) + tool := NewSubagentTool(manager) + tool.SetContext("cli", "direct") + + ctx := context.Background() + args := map[string]interface{}{"task": "Do something"} + result := tool.Execute(ctx, args) + + if result == nil || result.IsError { + t.Fatalf("Expected successful result, got: %+v", result) + } + + if provider.lastOptions == nil { + t.Fatal("Expected LLM options to be passed, got nil") + } + if provider.lastOptions["max_tokens"] != 2048 { + t.Fatalf("max_tokens = %v, want %d", provider.lastOptions["max_tokens"], 2048) + } + if provider.lastOptions["temperature"] != 0.6 { + t.Fatalf("temperature = %v, want %v", provider.lastOptions["temperature"], 0.6) + } +} + // TestSubagentTool_Name verifies tool name func TestSubagentTool_Name(t *testing.T) { provider := &MockLLMProvider{} diff --git a/pkg/tools/toolloop.go b/pkg/tools/toolloop.go index b07b14adb..e893217d3 100644 --- a/pkg/tools/toolloop.go +++ b/pkg/tools/toolloop.go @@ -55,12 +55,8 @@ func RunToolLoop(ctx context.Context, config ToolLoopConfig, messages []provider // 2. Set default LLM options llmOpts := config.LLMOptions if llmOpts == nil { - llmOpts = map[string]any{ - "max_tokens": 4096, - "temperature": 0.7, - } + llmOpts = map[string]any{} } - // 3. Call LLM response, err := config.Provider.Chat(ctx, messages, providerToolDefs, config.Model, llmOpts) if err != nil {