mirror of
https://github.com/sipeed/picoclaw.git
synced 2026-06-12 18:08:54 +00:00
fix: Templates update (#485)
* fix: add MaxTokens and Temperature fields to AgentInstance and update related logic * feat: add MaxTokens and Temperature options to SubagentManager and update tool loop logic * feat: add default temperature handling and update related tests * feat: allow temperature 0 and distinguish unset * fix: format MockLLMProvider struct in subagent_tool_test.go
This commit is contained in:
+15
-1
@@ -21,6 +21,8 @@ type AgentInstance struct {
|
||||
Fallbacks []string
|
||||
Workspace string
|
||||
MaxIterations int
|
||||
MaxTokens int
|
||||
Temperature float64
|
||||
ContextWindow int
|
||||
Provider providers.LLMProvider
|
||||
Sessions *session.SessionManager
|
||||
@@ -76,6 +78,16 @@ func NewAgentInstance(
|
||||
maxIter = 20
|
||||
}
|
||||
|
||||
maxTokens := defaults.MaxTokens
|
||||
if maxTokens == 0 {
|
||||
maxTokens = 8192
|
||||
}
|
||||
|
||||
temperature := 0.7
|
||||
if defaults.Temperature != nil {
|
||||
temperature = *defaults.Temperature
|
||||
}
|
||||
|
||||
// Resolve fallback candidates
|
||||
modelCfg := providers.ModelConfig{
|
||||
Primary: model,
|
||||
@@ -90,7 +102,9 @@ func NewAgentInstance(
|
||||
Fallbacks: fallbacks,
|
||||
Workspace: workspace,
|
||||
MaxIterations: maxIter,
|
||||
ContextWindow: defaults.MaxTokens,
|
||||
MaxTokens: maxTokens,
|
||||
Temperature: temperature,
|
||||
ContextWindow: maxTokens,
|
||||
Provider: provider,
|
||||
Sessions: sessionsManager,
|
||||
ContextBuilder: contextBuilder,
|
||||
|
||||
@@ -0,0 +1,95 @@
|
||||
package agent
|
||||
|
||||
import (
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
)
|
||||
|
||||
func TestNewAgentInstance_UsesDefaultsTemperatureAndMaxTokens(t *testing.T) {
|
||||
tmpDir, err := os.MkdirTemp("", "agent-instance-test-*")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temp dir: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
cfg := &config.Config{
|
||||
Agents: config.AgentsConfig{
|
||||
Defaults: config.AgentDefaults{
|
||||
Workspace: tmpDir,
|
||||
Model: "test-model",
|
||||
MaxTokens: 1234,
|
||||
MaxToolIterations: 5,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
configuredTemp := 1.0
|
||||
cfg.Agents.Defaults.Temperature = &configuredTemp
|
||||
|
||||
provider := &mockProvider{}
|
||||
agent := NewAgentInstance(nil, &cfg.Agents.Defaults, cfg, provider)
|
||||
|
||||
if agent.MaxTokens != 1234 {
|
||||
t.Fatalf("MaxTokens = %d, want %d", agent.MaxTokens, 1234)
|
||||
}
|
||||
if agent.Temperature != 1.0 {
|
||||
t.Fatalf("Temperature = %f, want %f", agent.Temperature, 1.0)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewAgentInstance_DefaultsTemperatureWhenZero(t *testing.T) {
|
||||
tmpDir, err := os.MkdirTemp("", "agent-instance-test-*")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temp dir: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
cfg := &config.Config{
|
||||
Agents: config.AgentsConfig{
|
||||
Defaults: config.AgentDefaults{
|
||||
Workspace: tmpDir,
|
||||
Model: "test-model",
|
||||
MaxTokens: 1234,
|
||||
MaxToolIterations: 5,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
configuredTemp := 0.0
|
||||
cfg.Agents.Defaults.Temperature = &configuredTemp
|
||||
|
||||
provider := &mockProvider{}
|
||||
agent := NewAgentInstance(nil, &cfg.Agents.Defaults, cfg, provider)
|
||||
|
||||
if agent.Temperature != 0.0 {
|
||||
t.Fatalf("Temperature = %f, want %f", agent.Temperature, 0.0)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewAgentInstance_DefaultsTemperatureWhenUnset(t *testing.T) {
|
||||
tmpDir, err := os.MkdirTemp("", "agent-instance-test-*")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temp dir: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
cfg := &config.Config{
|
||||
Agents: config.AgentsConfig{
|
||||
Defaults: config.AgentDefaults{
|
||||
Workspace: tmpDir,
|
||||
Model: "test-model",
|
||||
MaxTokens: 1234,
|
||||
MaxToolIterations: 5,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
provider := &mockProvider{}
|
||||
agent := NewAgentInstance(nil, &cfg.Agents.Defaults, cfg, provider)
|
||||
|
||||
if agent.Temperature != 0.7 {
|
||||
t.Fatalf("Temperature = %f, want %f", agent.Temperature, 0.7)
|
||||
}
|
||||
}
|
||||
+7
-6
@@ -119,6 +119,7 @@ func registerSharedTools(cfg *config.Config, msgBus *bus.MessageBus, registry *A
|
||||
|
||||
// Spawn tool with allowlist checker
|
||||
subagentManager := tools.NewSubagentManager(provider, agent.Model, agent.Workspace, msgBus)
|
||||
subagentManager.SetLLMOptions(agent.MaxTokens, agent.Temperature)
|
||||
spawnTool := tools.NewSpawnTool(subagentManager)
|
||||
currentAgentID := agentID
|
||||
spawnTool.SetAllowlistChecker(func(targetAgentID string) bool {
|
||||
@@ -470,8 +471,8 @@ func (al *AgentLoop) runLLMIteration(ctx context.Context, agent *AgentInstance,
|
||||
"model": agent.Model,
|
||||
"messages_count": len(messages),
|
||||
"tools_count": len(providerToolDefs),
|
||||
"max_tokens": 8192,
|
||||
"temperature": 0.7,
|
||||
"max_tokens": agent.MaxTokens,
|
||||
"temperature": agent.Temperature,
|
||||
"system_prompt_len": len(messages[0].Content),
|
||||
})
|
||||
|
||||
@@ -492,8 +493,8 @@ func (al *AgentLoop) runLLMIteration(ctx context.Context, agent *AgentInstance,
|
||||
fbResult, fbErr := al.fallback.Execute(ctx, agent.Candidates,
|
||||
func(ctx context.Context, provider, model string) (*providers.LLMResponse, error) {
|
||||
return agent.Provider.Chat(ctx, messages, providerToolDefs, model, map[string]interface{}{
|
||||
"max_tokens": 8192,
|
||||
"temperature": 0.7,
|
||||
"max_tokens": agent.MaxTokens,
|
||||
"temperature": agent.Temperature,
|
||||
})
|
||||
},
|
||||
)
|
||||
@@ -508,8 +509,8 @@ func (al *AgentLoop) runLLMIteration(ctx context.Context, agent *AgentInstance,
|
||||
return fbResult.Response, nil
|
||||
}
|
||||
return agent.Provider.Chat(ctx, messages, providerToolDefs, agent.Model, map[string]interface{}{
|
||||
"max_tokens": 8192,
|
||||
"temperature": 0.7,
|
||||
"max_tokens": agent.MaxTokens,
|
||||
"temperature": agent.Temperature,
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
@@ -14,20 +14,6 @@ import (
|
||||
"github.com/sipeed/picoclaw/pkg/tools"
|
||||
)
|
||||
|
||||
// mockProvider is a simple mock LLM provider for testing
|
||||
type mockProvider struct{}
|
||||
|
||||
func (m *mockProvider) Chat(ctx context.Context, messages []providers.Message, tools []providers.ToolDefinition, model string, opts map[string]interface{}) (*providers.LLMResponse, error) {
|
||||
return &providers.LLMResponse{
|
||||
Content: "Mock response",
|
||||
ToolCalls: []providers.ToolCall{},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (m *mockProvider) GetDefaultModel() string {
|
||||
return "mock-model"
|
||||
}
|
||||
|
||||
func TestRecordLastChannel(t *testing.T) {
|
||||
// Create temp workspace
|
||||
tmpDir, err := os.MkdirTemp("", "agent-test-*")
|
||||
@@ -603,7 +589,6 @@ func TestAgentLoop_ContextExhaustionRetry(t *testing.T) {
|
||||
// Call ProcessDirectWithChannel
|
||||
// Note: ProcessDirectWithChannel calls processMessage which will execute runLLMIteration
|
||||
response, err := al.ProcessDirectWithChannel(context.Background(), "Trigger message", sessionKey, "test", "test-chat")
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("Expected success after retry, got error: %v", err)
|
||||
}
|
||||
|
||||
@@ -0,0 +1,20 @@
|
||||
package agent
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/providers"
|
||||
)
|
||||
|
||||
type mockProvider struct{}
|
||||
|
||||
func (m *mockProvider) Chat(ctx context.Context, messages []providers.Message, tools []providers.ToolDefinition, model string, opts map[string]interface{}) (*providers.LLMResponse, error) {
|
||||
return &providers.LLMResponse{
|
||||
Content: "Mock response",
|
||||
ToolCalls: []providers.ToolCall{},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (m *mockProvider) GetDefaultModel() string {
|
||||
return "mock-model"
|
||||
}
|
||||
@@ -147,7 +147,7 @@ type AgentDefaults struct {
|
||||
ImageModel string `json:"image_model,omitempty" env:"PICOCLAW_AGENTS_DEFAULTS_IMAGE_MODEL"`
|
||||
ImageModelFallbacks []string `json:"image_model_fallbacks,omitempty"`
|
||||
MaxTokens int `json:"max_tokens" env:"PICOCLAW_AGENTS_DEFAULTS_MAX_TOKENS"`
|
||||
Temperature float64 `json:"temperature" env:"PICOCLAW_AGENTS_DEFAULTS_TEMPERATURE"`
|
||||
Temperature *float64 `json:"temperature,omitempty" env:"PICOCLAW_AGENTS_DEFAULTS_TEMPERATURE"`
|
||||
MaxToolIterations int `json:"max_tool_iterations" env:"PICOCLAW_AGENTS_DEFAULTS_MAX_TOOL_ITERATIONS"`
|
||||
}
|
||||
|
||||
@@ -330,7 +330,6 @@ func DefaultConfig() *Config {
|
||||
Provider: "",
|
||||
Model: "glm-4.7",
|
||||
MaxTokens: 8192,
|
||||
Temperature: 0.7,
|
||||
MaxToolIterations: 20,
|
||||
},
|
||||
},
|
||||
|
||||
@@ -237,8 +237,8 @@ func TestDefaultConfig_MaxToolIterations(t *testing.T) {
|
||||
func TestDefaultConfig_Temperature(t *testing.T) {
|
||||
cfg := DefaultConfig()
|
||||
|
||||
if cfg.Agents.Defaults.Temperature == 0 {
|
||||
t.Error("Temperature should not be zero")
|
||||
if cfg.Agents.Defaults.Temperature != nil {
|
||||
t.Error("Temperature should be nil when not provided")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -334,8 +334,8 @@ func TestConfig_Complete(t *testing.T) {
|
||||
if cfg.Agents.Defaults.Model == "" {
|
||||
t.Error("Model should not be empty")
|
||||
}
|
||||
if cfg.Agents.Defaults.Temperature == 0 {
|
||||
t.Error("Temperature should have default value")
|
||||
if cfg.Agents.Defaults.Temperature != nil {
|
||||
t.Error("Temperature should be nil when not provided")
|
||||
}
|
||||
if cfg.Agents.Defaults.MaxTokens == 0 {
|
||||
t.Error("MaxTokens should not be zero")
|
||||
|
||||
@@ -76,7 +76,7 @@ func ConvertConfig(data map[string]interface{}) (*config.Config, []string, error
|
||||
cfg.Agents.Defaults.MaxTokens = int(v)
|
||||
}
|
||||
if v, ok := getFloat(defaults, "temperature"); ok {
|
||||
cfg.Agents.Defaults.Temperature = v
|
||||
cfg.Agents.Defaults.Temperature = &v
|
||||
}
|
||||
if v, ok := getFloat(defaults, "max_tool_iterations"); ok {
|
||||
cfg.Agents.Defaults.MaxToolIterations = int(v)
|
||||
|
||||
@@ -275,8 +275,11 @@ func TestConvertConfig(t *testing.T) {
|
||||
if cfg.Agents.Defaults.MaxTokens != 4096 {
|
||||
t.Errorf("MaxTokens = %d, want %d", cfg.Agents.Defaults.MaxTokens, 4096)
|
||||
}
|
||||
if cfg.Agents.Defaults.Temperature != 0.5 {
|
||||
t.Errorf("Temperature = %f, want %f", cfg.Agents.Defaults.Temperature, 0.5)
|
||||
if cfg.Agents.Defaults.Temperature == nil {
|
||||
t.Fatalf("Temperature is nil, want %f", 0.5)
|
||||
}
|
||||
if *cfg.Agents.Defaults.Temperature != 0.5 {
|
||||
t.Errorf("Temperature = %f, want %f", *cfg.Agents.Defaults.Temperature, 0.5)
|
||||
}
|
||||
if cfg.Agents.Defaults.Workspace != "~/.picoclaw/workspace" {
|
||||
t.Errorf("Workspace = %q, want %q", cfg.Agents.Defaults.Workspace, "~/.picoclaw/workspace")
|
||||
|
||||
+55
-18
@@ -23,15 +23,19 @@ type SubagentTask struct {
|
||||
}
|
||||
|
||||
type SubagentManager struct {
|
||||
tasks map[string]*SubagentTask
|
||||
mu sync.RWMutex
|
||||
provider providers.LLMProvider
|
||||
defaultModel string
|
||||
bus *bus.MessageBus
|
||||
workspace string
|
||||
tools *ToolRegistry
|
||||
maxIterations int
|
||||
nextID int
|
||||
tasks map[string]*SubagentTask
|
||||
mu sync.RWMutex
|
||||
provider providers.LLMProvider
|
||||
defaultModel string
|
||||
bus *bus.MessageBus
|
||||
workspace string
|
||||
tools *ToolRegistry
|
||||
maxIterations int
|
||||
maxTokens int
|
||||
temperature float64
|
||||
hasMaxTokens bool
|
||||
hasTemperature bool
|
||||
nextID int
|
||||
}
|
||||
|
||||
func NewSubagentManager(provider providers.LLMProvider, defaultModel, workspace string, bus *bus.MessageBus) *SubagentManager {
|
||||
@@ -47,6 +51,16 @@ func NewSubagentManager(provider providers.LLMProvider, defaultModel, workspace
|
||||
}
|
||||
}
|
||||
|
||||
// SetLLMOptions sets max tokens and temperature for subagent LLM calls.
|
||||
func (sm *SubagentManager) SetLLMOptions(maxTokens int, temperature float64) {
|
||||
sm.mu.Lock()
|
||||
defer sm.mu.Unlock()
|
||||
sm.maxTokens = maxTokens
|
||||
sm.hasMaxTokens = true
|
||||
sm.temperature = temperature
|
||||
sm.hasTemperature = true
|
||||
}
|
||||
|
||||
// SetTools sets the tool registry for subagent execution.
|
||||
// If not set, subagent will have access to the provided tools.
|
||||
func (sm *SubagentManager) SetTools(tools *ToolRegistry) {
|
||||
@@ -125,17 +139,29 @@ After completing the task, provide a clear summary of what was done.`
|
||||
sm.mu.RLock()
|
||||
tools := sm.tools
|
||||
maxIter := sm.maxIterations
|
||||
maxTokens := sm.maxTokens
|
||||
temperature := sm.temperature
|
||||
hasMaxTokens := sm.hasMaxTokens
|
||||
hasTemperature := sm.hasTemperature
|
||||
sm.mu.RUnlock()
|
||||
|
||||
var llmOptions map[string]any
|
||||
if hasMaxTokens || hasTemperature {
|
||||
llmOptions = map[string]any{}
|
||||
if hasMaxTokens {
|
||||
llmOptions["max_tokens"] = maxTokens
|
||||
}
|
||||
if hasTemperature {
|
||||
llmOptions["temperature"] = temperature
|
||||
}
|
||||
}
|
||||
|
||||
loopResult, err := RunToolLoop(ctx, ToolLoopConfig{
|
||||
Provider: sm.provider,
|
||||
Model: sm.defaultModel,
|
||||
Tools: tools,
|
||||
MaxIterations: maxIter,
|
||||
LLMOptions: map[string]any{
|
||||
"max_tokens": 4096,
|
||||
"temperature": 0.7,
|
||||
},
|
||||
LLMOptions: llmOptions,
|
||||
}, messages, task.OriginChannel, task.OriginChatID)
|
||||
|
||||
sm.mu.Lock()
|
||||
@@ -283,19 +309,30 @@ func (t *SubagentTool) Execute(ctx context.Context, args map[string]interface{})
|
||||
sm.mu.RLock()
|
||||
tools := sm.tools
|
||||
maxIter := sm.maxIterations
|
||||
maxTokens := sm.maxTokens
|
||||
temperature := sm.temperature
|
||||
hasMaxTokens := sm.hasMaxTokens
|
||||
hasTemperature := sm.hasTemperature
|
||||
sm.mu.RUnlock()
|
||||
|
||||
var llmOptions map[string]any
|
||||
if hasMaxTokens || hasTemperature {
|
||||
llmOptions = map[string]any{}
|
||||
if hasMaxTokens {
|
||||
llmOptions["max_tokens"] = maxTokens
|
||||
}
|
||||
if hasTemperature {
|
||||
llmOptions["temperature"] = temperature
|
||||
}
|
||||
}
|
||||
|
||||
loopResult, err := RunToolLoop(ctx, ToolLoopConfig{
|
||||
Provider: sm.provider,
|
||||
Model: sm.defaultModel,
|
||||
Tools: tools,
|
||||
MaxIterations: maxIter,
|
||||
LLMOptions: map[string]any{
|
||||
"max_tokens": 4096,
|
||||
"temperature": 0.7,
|
||||
},
|
||||
LLMOptions: llmOptions,
|
||||
}, messages, t.originChannel, t.originChatID)
|
||||
|
||||
if err != nil {
|
||||
return ErrorResult(fmt.Sprintf("Subagent execution failed: %v", err)).WithError(err)
|
||||
}
|
||||
|
||||
@@ -10,9 +10,12 @@ import (
|
||||
)
|
||||
|
||||
// MockLLMProvider is a test implementation of LLMProvider
|
||||
type MockLLMProvider struct{}
|
||||
type MockLLMProvider struct {
|
||||
lastOptions map[string]interface{}
|
||||
}
|
||||
|
||||
func (m *MockLLMProvider) Chat(ctx context.Context, messages []providers.Message, tools []providers.ToolDefinition, model string, options map[string]interface{}) (*providers.LLMResponse, error) {
|
||||
m.lastOptions = options
|
||||
// Find the last user message to generate a response
|
||||
for i := len(messages) - 1; i >= 0; i-- {
|
||||
if messages[i].Role == "user" {
|
||||
@@ -36,6 +39,32 @@ func (m *MockLLMProvider) GetContextWindow() int {
|
||||
return 4096
|
||||
}
|
||||
|
||||
func TestSubagentManager_SetLLMOptions_AppliesToRunToolLoop(t *testing.T) {
|
||||
provider := &MockLLMProvider{}
|
||||
manager := NewSubagentManager(provider, "test-model", "/tmp/test", nil)
|
||||
manager.SetLLMOptions(2048, 0.6)
|
||||
tool := NewSubagentTool(manager)
|
||||
tool.SetContext("cli", "direct")
|
||||
|
||||
ctx := context.Background()
|
||||
args := map[string]interface{}{"task": "Do something"}
|
||||
result := tool.Execute(ctx, args)
|
||||
|
||||
if result == nil || result.IsError {
|
||||
t.Fatalf("Expected successful result, got: %+v", result)
|
||||
}
|
||||
|
||||
if provider.lastOptions == nil {
|
||||
t.Fatal("Expected LLM options to be passed, got nil")
|
||||
}
|
||||
if provider.lastOptions["max_tokens"] != 2048 {
|
||||
t.Fatalf("max_tokens = %v, want %d", provider.lastOptions["max_tokens"], 2048)
|
||||
}
|
||||
if provider.lastOptions["temperature"] != 0.6 {
|
||||
t.Fatalf("temperature = %v, want %v", provider.lastOptions["temperature"], 0.6)
|
||||
}
|
||||
}
|
||||
|
||||
// TestSubagentTool_Name verifies tool name
|
||||
func TestSubagentTool_Name(t *testing.T) {
|
||||
provider := &MockLLMProvider{}
|
||||
|
||||
@@ -55,12 +55,8 @@ func RunToolLoop(ctx context.Context, config ToolLoopConfig, messages []provider
|
||||
// 2. Set default LLM options
|
||||
llmOpts := config.LLMOptions
|
||||
if llmOpts == nil {
|
||||
llmOpts = map[string]any{
|
||||
"max_tokens": 4096,
|
||||
"temperature": 0.7,
|
||||
}
|
||||
llmOpts = map[string]any{}
|
||||
}
|
||||
|
||||
// 3. Call LLM
|
||||
response, err := config.Provider.Chat(ctx, messages, providerToolDefs, config.Model, llmOpts)
|
||||
if err != nil {
|
||||
|
||||
Reference in New Issue
Block a user