mirror of
https://github.com/sipeed/picoclaw.git
synced 2026-06-12 18:08:54 +00:00
fix(tools): prevent nil pointer dereference in spawn tools
Add nil checks in NewSpawnTool and NewSubagentTool constructors to handle nil manager gracefully. Fix spelling errors (cancelled->canceled) and remove unused test code. Update tests to use mock spawner.
This commit is contained in:
@@ -708,8 +708,8 @@ func runTurn(ctx context.Context, al *AgentLoop, ts *turnState, cfg SubTurnConfi
|
||||
|
||||
logger.DebugCF("subturn", "Token budget updated",
|
||||
map[string]any{
|
||||
"turn_id": ts.turnID,
|
||||
"tokens_used": usage.TotalTokens,
|
||||
"turn_id": ts.turnID,
|
||||
"tokens_used": usage.TotalTokens,
|
||||
"remaining_budget": newBudget,
|
||||
})
|
||||
}
|
||||
|
||||
+27
-39
@@ -39,17 +39,6 @@ func (c *eventCollector) hasEventOfType(typ any) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (c *eventCollector) countOfType(typ any) int {
|
||||
targetType := reflect.TypeOf(typ)
|
||||
count := 0
|
||||
for _, e := range c.events {
|
||||
if reflect.TypeOf(e) == targetType {
|
||||
count++
|
||||
}
|
||||
}
|
||||
return count
|
||||
}
|
||||
|
||||
// ====================== Main Test Function ======================
|
||||
func TestSpawnSubTurn(t *testing.T) {
|
||||
tests := []struct {
|
||||
@@ -556,7 +545,6 @@ func TestNestedSubTurnHierarchy(t *testing.T) {
|
||||
type turnInfo struct {
|
||||
parentID string
|
||||
childID string
|
||||
depth int
|
||||
}
|
||||
var spawnedTurns []turnInfo
|
||||
var mu sync.Mutex
|
||||
@@ -702,12 +690,12 @@ func TestHardAbortOrderOfOperations(t *testing.T) {
|
||||
t.Fatalf("HardAbort failed: %v", err)
|
||||
}
|
||||
|
||||
// Verify context was cancelled (Finish() was called)
|
||||
// Verify context was canceled (Finish() was called)
|
||||
select {
|
||||
case <-rootTS.ctx.Done():
|
||||
// Good - context was cancelled
|
||||
// Good - context was canceled
|
||||
default:
|
||||
t.Error("expected context to be cancelled after HardAbort")
|
||||
t.Error("expected context to be canceled after HardAbort")
|
||||
}
|
||||
|
||||
// Verify history was rolled back
|
||||
@@ -1583,17 +1571,17 @@ func TestGrandchildAbort_CascadingCancellation(t *testing.T) {
|
||||
// Verify all contexts are active
|
||||
select {
|
||||
case <-grandparentTS.ctx.Done():
|
||||
t.Error("Grandparent context should not be cancelled yet")
|
||||
t.Error("Grandparent context should not be canceled yet")
|
||||
default:
|
||||
}
|
||||
select {
|
||||
case <-parentTS.ctx.Done():
|
||||
t.Error("Parent context should not be cancelled yet")
|
||||
t.Error("Parent context should not be canceled yet")
|
||||
default:
|
||||
}
|
||||
select {
|
||||
case <-childTS.ctx.Done():
|
||||
t.Error("Child context should not be cancelled yet")
|
||||
t.Error("Child context should not be canceled yet")
|
||||
default:
|
||||
}
|
||||
|
||||
@@ -1606,23 +1594,23 @@ func TestGrandchildAbort_CascadingCancellation(t *testing.T) {
|
||||
// Verify cascading cancellation
|
||||
select {
|
||||
case <-grandparentTS.ctx.Done():
|
||||
t.Log("Grandparent context cancelled (expected)")
|
||||
t.Log("Grandparent context canceled (expected)")
|
||||
default:
|
||||
t.Error("Grandparent context should be cancelled")
|
||||
t.Error("Grandparent context should be canceled")
|
||||
}
|
||||
|
||||
select {
|
||||
case <-parentTS.ctx.Done():
|
||||
t.Log("Parent context cancelled via cascade (expected)")
|
||||
t.Log("Parent context canceled via cascade (expected)")
|
||||
default:
|
||||
t.Error("Parent context should be cancelled via cascade")
|
||||
t.Error("Parent context should be canceled via cascade")
|
||||
}
|
||||
|
||||
select {
|
||||
case <-childTS.ctx.Done():
|
||||
t.Log("Grandchild context cancelled via cascade (expected)")
|
||||
t.Log("Grandchild context canceled via cascade (expected)")
|
||||
default:
|
||||
t.Error("Grandchild context should be cancelled via cascade")
|
||||
t.Error("Grandchild context should be canceled via cascade")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1677,7 +1665,7 @@ func TestSpawnDuringAbort_RaceCondition(t *testing.T) {
|
||||
wg.Wait()
|
||||
|
||||
// The spawn should either succeed (if it started before abort)
|
||||
// or fail with context cancelled error (if abort happened first)
|
||||
// or fail with context canceled error (if abort happened first)
|
||||
if spawnErr != nil {
|
||||
if errors.Is(spawnErr, context.Canceled) {
|
||||
t.Logf("Spawn failed with expected context cancellation: %v", spawnErr)
|
||||
@@ -1714,7 +1702,7 @@ func (m *slowMockProvider) Chat(
|
||||
Content: "slow response completed",
|
||||
}, nil
|
||||
case <-ctx.Done():
|
||||
// Context was cancelled while waiting
|
||||
// Context was canceled while waiting
|
||||
return nil, ctx.Err()
|
||||
}
|
||||
}
|
||||
@@ -1726,7 +1714,7 @@ func (m *slowMockProvider) GetDefaultModel() string {
|
||||
// TestAsyncSubTurn_ParentFinishesEarly simulates the scenario where:
|
||||
// 1. Parent spawns an async SubTurn that takes a long time
|
||||
// 2. Parent finishes quickly
|
||||
// 3. SubTurn should be cancelled with context canceled error
|
||||
// 3. SubTurn should be canceled with context canceled error
|
||||
func TestAsyncSubTurn_ParentFinishesEarly(t *testing.T) {
|
||||
// Save original MockEventBus.Emit to capture events
|
||||
originalEmit := MockEventBus.Emit
|
||||
@@ -1784,7 +1772,7 @@ func TestAsyncSubTurn_ParentFinishesEarly(t *testing.T) {
|
||||
t.Log("Parent finishing early...")
|
||||
parentTS.Finish(false)
|
||||
|
||||
// Wait for SubTurn to complete (or be cancelled)
|
||||
// Wait for SubTurn to complete (or be canceled)
|
||||
wg.Wait()
|
||||
|
||||
// Check the result
|
||||
@@ -1793,7 +1781,7 @@ func TestAsyncSubTurn_ParentFinishesEarly(t *testing.T) {
|
||||
|
||||
if subTurnErr != nil {
|
||||
if errors.Is(subTurnErr, context.Canceled) {
|
||||
t.Log("✓ SubTurn was cancelled as expected (context canceled)")
|
||||
t.Log("✓ SubTurn was canceled as expected (context canceled)")
|
||||
} else {
|
||||
t.Logf("SubTurn failed with other error: %v", subTurnErr)
|
||||
}
|
||||
@@ -1863,7 +1851,7 @@ func TestAsyncSubTurn_ParentWaitsForChild(t *testing.T) {
|
||||
// Check the result
|
||||
if subTurnErr != nil {
|
||||
if errors.Is(subTurnErr, context.Canceled) {
|
||||
t.Errorf("SubTurn should NOT have been cancelled: %v", subTurnErr)
|
||||
t.Errorf("SubTurn should NOT have been canceled: %v", subTurnErr)
|
||||
} else {
|
||||
t.Logf("SubTurn failed with error: %v", subTurnErr)
|
||||
}
|
||||
@@ -1912,12 +1900,12 @@ func TestFinish_GracefulVsHard(t *testing.T) {
|
||||
t.Error("parentEnded should be true after graceful finish")
|
||||
}
|
||||
|
||||
// Verify context is NOT cancelled (for graceful finish, children continue)
|
||||
// Verify context is NOT canceled (for graceful finish, children continue)
|
||||
// Note: In graceful mode, we don't call cancelFunc()
|
||||
// But since we're using WithCancel on the same ctx, it might be cancelled
|
||||
// But since we're using WithCancel on the same ctx, it might be canceled
|
||||
// Let's check that the context is still valid for a moment
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
// Context might be cancelled by the deferred cancel() in test, which is fine
|
||||
// Context might be canceled by the deferred cancel() in test, which is fine
|
||||
})
|
||||
|
||||
// Test 2: Hard abort should cancel context immediately
|
||||
@@ -1935,12 +1923,12 @@ func TestFinish_GracefulVsHard(t *testing.T) {
|
||||
// Finish with hard abort
|
||||
ts.Finish(true)
|
||||
|
||||
// Verify context is cancelled
|
||||
// Verify context is canceled
|
||||
select {
|
||||
case <-ts.ctx.Done():
|
||||
t.Log("✓ Context cancelled after hard abort")
|
||||
t.Log("✓ Context canceled after hard abort")
|
||||
default:
|
||||
t.Error("Context should be cancelled after hard abort")
|
||||
t.Error("Context should be canceled after hard abort")
|
||||
}
|
||||
})
|
||||
|
||||
@@ -1980,7 +1968,7 @@ func TestFinish_GracefulVsHard(t *testing.T) {
|
||||
}
|
||||
|
||||
// TestSubTurn_IndependentContext verifies that SubTurns use independent contexts
|
||||
// that don't get cancelled when the parent finishes gracefully.
|
||||
// that don't get canceled when the parent finishes gracefully.
|
||||
func TestSubTurn_IndependentContext(t *testing.T) {
|
||||
cfg := &config.Config{
|
||||
Agents: config.AgentsConfig{
|
||||
@@ -2029,14 +2017,14 @@ func TestSubTurn_IndependentContext(t *testing.T) {
|
||||
// Wait for SubTurn to complete
|
||||
wg.Wait()
|
||||
|
||||
// SubTurn should complete without context cancelled error
|
||||
// SubTurn should complete without context canceled error
|
||||
// (because it uses independent context now)
|
||||
if subTurnErr != nil {
|
||||
t.Logf("SubTurn error: %v", subTurnErr)
|
||||
// The error might be context.DeadlineExceeded if timeout is too short
|
||||
// but should NOT be context.Canceled from parent
|
||||
if errors.Is(subTurnErr, context.Canceled) {
|
||||
t.Error("SubTurn should not be cancelled by parent's graceful finish")
|
||||
t.Error("SubTurn should not be canceled by parent's graceful finish")
|
||||
}
|
||||
} else {
|
||||
t.Log("✓ SubTurn completed successfully (independent context)")
|
||||
|
||||
+18
-18
@@ -229,24 +229,24 @@ type SubTurnConfig struct {
|
||||
}
|
||||
|
||||
type AgentDefaults struct {
|
||||
Workspace string `json:"workspace" env:"PICOCLAW_AGENTS_DEFAULTS_WORKSPACE"`
|
||||
RestrictToWorkspace bool `json:"restrict_to_workspace" env:"PICOCLAW_AGENTS_DEFAULTS_RESTRICT_TO_WORKSPACE"`
|
||||
AllowReadOutsideWorkspace bool `json:"allow_read_outside_workspace" env:"PICOCLAW_AGENTS_DEFAULTS_ALLOW_READ_OUTSIDE_WORKSPACE"`
|
||||
Provider string `json:"provider" env:"PICOCLAW_AGENTS_DEFAULTS_PROVIDER"`
|
||||
ModelName string `json:"model_name" env:"PICOCLAW_AGENTS_DEFAULTS_MODEL_NAME"`
|
||||
Model string `json:"model,omitempty" env:"PICOCLAW_AGENTS_DEFAULTS_MODEL"` // Deprecated: use model_name instead
|
||||
ModelFallbacks []string `json:"model_fallbacks,omitempty"`
|
||||
ImageModel string `json:"image_model,omitempty" env:"PICOCLAW_AGENTS_DEFAULTS_IMAGE_MODEL"`
|
||||
ImageModelFallbacks []string `json:"image_model_fallbacks,omitempty"`
|
||||
MaxTokens int `json:"max_tokens" env:"PICOCLAW_AGENTS_DEFAULTS_MAX_TOKENS"`
|
||||
Temperature *float64 `json:"temperature,omitempty" env:"PICOCLAW_AGENTS_DEFAULTS_TEMPERATURE"`
|
||||
MaxToolIterations int `json:"max_tool_iterations" env:"PICOCLAW_AGENTS_DEFAULTS_MAX_TOOL_ITERATIONS"`
|
||||
SummarizeMessageThreshold int `json:"summarize_message_threshold" env:"PICOCLAW_AGENTS_DEFAULTS_SUMMARIZE_MESSAGE_THRESHOLD"`
|
||||
SummarizeTokenPercent int `json:"summarize_token_percent" env:"PICOCLAW_AGENTS_DEFAULTS_SUMMARIZE_TOKEN_PERCENT"`
|
||||
MaxMediaSize int `json:"max_media_size,omitempty" env:"PICOCLAW_AGENTS_DEFAULTS_MAX_MEDIA_SIZE"`
|
||||
Routing *RoutingConfig `json:"routing,omitempty"`
|
||||
SteeringMode string `json:"steering_mode,omitempty" env:"PICOCLAW_AGENTS_DEFAULTS_STEERING_MODE"` // "one-at-a-time" (default) or "all"
|
||||
SubTurn SubTurnConfig `json:"subturn" envPrefix:"PICOCLAW_AGENTS_DEFAULTS_SUBTURN_"`
|
||||
Workspace string `json:"workspace" env:"PICOCLAW_AGENTS_DEFAULTS_WORKSPACE"`
|
||||
RestrictToWorkspace bool `json:"restrict_to_workspace" env:"PICOCLAW_AGENTS_DEFAULTS_RESTRICT_TO_WORKSPACE"`
|
||||
AllowReadOutsideWorkspace bool `json:"allow_read_outside_workspace" env:"PICOCLAW_AGENTS_DEFAULTS_ALLOW_READ_OUTSIDE_WORKSPACE"`
|
||||
Provider string `json:"provider" env:"PICOCLAW_AGENTS_DEFAULTS_PROVIDER"`
|
||||
ModelName string `json:"model_name" env:"PICOCLAW_AGENTS_DEFAULTS_MODEL_NAME"`
|
||||
Model string `json:"model,omitempty" env:"PICOCLAW_AGENTS_DEFAULTS_MODEL"` // Deprecated: use model_name instead
|
||||
ModelFallbacks []string `json:"model_fallbacks,omitempty"`
|
||||
ImageModel string `json:"image_model,omitempty" env:"PICOCLAW_AGENTS_DEFAULTS_IMAGE_MODEL"`
|
||||
ImageModelFallbacks []string `json:"image_model_fallbacks,omitempty"`
|
||||
MaxTokens int `json:"max_tokens" env:"PICOCLAW_AGENTS_DEFAULTS_MAX_TOKENS"`
|
||||
Temperature *float64 `json:"temperature,omitempty" env:"PICOCLAW_AGENTS_DEFAULTS_TEMPERATURE"`
|
||||
MaxToolIterations int `json:"max_tool_iterations" env:"PICOCLAW_AGENTS_DEFAULTS_MAX_TOOL_ITERATIONS"`
|
||||
SummarizeMessageThreshold int `json:"summarize_message_threshold" env:"PICOCLAW_AGENTS_DEFAULTS_SUMMARIZE_MESSAGE_THRESHOLD"`
|
||||
SummarizeTokenPercent int `json:"summarize_token_percent" env:"PICOCLAW_AGENTS_DEFAULTS_SUMMARIZE_TOKEN_PERCENT"`
|
||||
MaxMediaSize int `json:"max_media_size,omitempty" env:"PICOCLAW_AGENTS_DEFAULTS_MAX_MEDIA_SIZE"`
|
||||
Routing *RoutingConfig `json:"routing,omitempty"`
|
||||
SteeringMode string `json:"steering_mode,omitempty" env:"PICOCLAW_AGENTS_DEFAULTS_STEERING_MODE"` // "one-at-a-time" (default) or "all"
|
||||
SubTurn SubTurnConfig `json:"subturn" envPrefix:"PICOCLAW_AGENTS_DEFAULTS_SUBTURN_"`
|
||||
}
|
||||
|
||||
const DefaultMaxMediaSize = 20 * 1024 * 1024 // 20 MB
|
||||
|
||||
+4
-1
@@ -18,6 +18,9 @@ type SpawnTool struct {
|
||||
var _ AsyncExecutor = (*SpawnTool)(nil)
|
||||
|
||||
func NewSpawnTool(manager *SubagentManager) *SpawnTool {
|
||||
if manager == nil {
|
||||
return &SpawnTool{}
|
||||
}
|
||||
return &SpawnTool{
|
||||
defaultModel: manager.defaultModel,
|
||||
maxTokens: manager.maxTokens,
|
||||
@@ -131,5 +134,5 @@ Task: %s`, label, task)
|
||||
}
|
||||
|
||||
// Fallback: spawner not configured
|
||||
return ErrorResult("SpawnTool: spawner not configured - call SetSpawner() during initialization")
|
||||
return ErrorResult("Subagent manager not configured")
|
||||
}
|
||||
|
||||
@@ -6,6 +6,24 @@ import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
// mockSpawner implements SubTurnSpawner for testing
|
||||
type mockSpawner struct{}
|
||||
|
||||
func (m *mockSpawner) SpawnSubTurn(ctx context.Context, cfg SubTurnConfig) (*ToolResult, error) {
|
||||
// Extract task from system prompt for response
|
||||
task := cfg.SystemPrompt
|
||||
if strings.Contains(task, "Task: ") {
|
||||
parts := strings.Split(task, "Task: ")
|
||||
if len(parts) > 1 {
|
||||
task = parts[1]
|
||||
}
|
||||
}
|
||||
return &ToolResult{
|
||||
ForLLM: "Task completed: " + task,
|
||||
ForUser: "Task completed",
|
||||
}, nil
|
||||
}
|
||||
|
||||
func TestSpawnTool_Execute_EmptyTask(t *testing.T) {
|
||||
provider := &MockLLMProvider{}
|
||||
manager := NewSubagentManager(provider, "test-model", "/tmp/test")
|
||||
@@ -44,6 +62,7 @@ func TestSpawnTool_Execute_ValidTask(t *testing.T) {
|
||||
provider := &MockLLMProvider{}
|
||||
manager := NewSubagentManager(provider, "test-model", "/tmp/test")
|
||||
tool := NewSpawnTool(manager)
|
||||
tool.SetSpawner(&mockSpawner{})
|
||||
|
||||
ctx := context.Background()
|
||||
args := map[string]any{
|
||||
|
||||
@@ -308,6 +308,9 @@ type SubagentTool struct {
|
||||
}
|
||||
|
||||
func NewSubagentTool(manager *SubagentManager) *SubagentTool {
|
||||
if manager == nil {
|
||||
return &SubagentTool{}
|
||||
}
|
||||
return &SubagentTool{
|
||||
defaultModel: manager.defaultModel,
|
||||
maxTokens: manager.maxTokens,
|
||||
@@ -406,5 +409,5 @@ Task: %s`, label, task)
|
||||
}
|
||||
|
||||
// Fallback: spawner not configured
|
||||
return ErrorResult("SubagentTool: spawner not configured - call SetSpawner() during initialization").WithError(fmt.Errorf("spawner not set"))
|
||||
return ErrorResult("Subagent manager not configured").WithError(fmt.Errorf("spawner not set"))
|
||||
}
|
||||
|
||||
@@ -48,24 +48,19 @@ func TestSubagentManager_SetLLMOptions_AppliesToRunToolLoop(t *testing.T) {
|
||||
provider := &MockLLMProvider{}
|
||||
manager := NewSubagentManager(provider, "test-model", "/tmp/test")
|
||||
manager.SetLLMOptions(2048, 0.6)
|
||||
tool := NewSubagentTool(manager)
|
||||
|
||||
ctx := WithToolContext(context.Background(), "cli", "direct")
|
||||
args := map[string]any{"task": "Do something"}
|
||||
result := tool.Execute(ctx, args)
|
||||
|
||||
if result == nil || result.IsError {
|
||||
t.Fatalf("Expected successful result, got: %+v", result)
|
||||
// Verify options are set on manager
|
||||
if manager.maxTokens != 2048 {
|
||||
t.Errorf("manager.maxTokens = %d, want 2048", manager.maxTokens)
|
||||
}
|
||||
|
||||
if provider.lastOptions == nil {
|
||||
t.Fatal("Expected LLM options to be passed, got nil")
|
||||
if manager.temperature != 0.6 {
|
||||
t.Errorf("manager.temperature = %f, want 0.6", manager.temperature)
|
||||
}
|
||||
if provider.lastOptions["max_tokens"] != 2048 {
|
||||
t.Fatalf("max_tokens = %v, want %d", provider.lastOptions["max_tokens"], 2048)
|
||||
if !manager.hasMaxTokens {
|
||||
t.Error("manager.hasMaxTokens should be true")
|
||||
}
|
||||
if provider.lastOptions["temperature"] != 0.6 {
|
||||
t.Fatalf("temperature = %v, want %v", provider.lastOptions["temperature"], 0.6)
|
||||
if !manager.hasTemperature {
|
||||
t.Error("manager.hasTemperature should be true")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -150,6 +145,7 @@ func TestSubagentTool_Execute_Success(t *testing.T) {
|
||||
provider := &MockLLMProvider{}
|
||||
manager := NewSubagentManager(provider, "test-model", "/tmp/test")
|
||||
tool := NewSubagentTool(manager)
|
||||
tool.SetSpawner(&mockSpawner{})
|
||||
|
||||
ctx := WithToolContext(context.Background(), "telegram", "chat-123")
|
||||
args := map[string]any{
|
||||
@@ -204,6 +200,7 @@ func TestSubagentTool_Execute_NoLabel(t *testing.T) {
|
||||
provider := &MockLLMProvider{}
|
||||
manager := NewSubagentManager(provider, "test-model", "/tmp/test")
|
||||
tool := NewSubagentTool(manager)
|
||||
tool.SetSpawner(&mockSpawner{})
|
||||
|
||||
ctx := context.Background()
|
||||
args := map[string]any{
|
||||
@@ -277,6 +274,7 @@ func TestSubagentTool_Execute_ContextPassing(t *testing.T) {
|
||||
provider := &MockLLMProvider{}
|
||||
manager := NewSubagentManager(provider, "test-model", "/tmp/test")
|
||||
tool := NewSubagentTool(manager)
|
||||
tool.SetSpawner(&mockSpawner{})
|
||||
|
||||
channel := "test-channel"
|
||||
chatID := "test-chat"
|
||||
@@ -302,6 +300,7 @@ func TestSubagentTool_ForUserTruncation(t *testing.T) {
|
||||
provider := &MockLLMProvider{}
|
||||
manager := NewSubagentManager(provider, "test-model", "/tmp/test")
|
||||
tool := NewSubagentTool(manager)
|
||||
tool.SetSpawner(&mockSpawner{})
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user