From b47a39af9cbbdf9cdc0ebe2a9ed3607c9d260596 Mon Sep 17 00:00:00 2001 From: winterfx Date: Tue, 24 Feb 2026 21:35:15 +0800 Subject: [PATCH] fix: handle multi-tool-call orphan detection in sanitizeHistoryForProvider Walk backwards over preceding tool messages to find the nearest assistant with ToolCalls, instead of only checking the immediate predecessor. Add unit tests for sanitizeHistoryForProvider covering key edge cases. --- pkg/agent/context.go | 15 ++- pkg/agent/context_test.go | 209 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 222 insertions(+), 2 deletions(-) create mode 100644 pkg/agent/context_test.go diff --git a/pkg/agent/context.go b/pkg/agent/context.go index a9db5afdd..7bd55d4ab 100644 --- a/pkg/agent/context.go +++ b/pkg/agent/context.go @@ -229,8 +229,19 @@ func sanitizeHistoryForProvider(history []providers.Message) []providers.Message logger.DebugCF("agent", "Dropping orphaned leading tool message", map[string]any{}) continue } - last := sanitized[len(sanitized)-1] - if last.Role != "assistant" || len(last.ToolCalls) == 0 { + // Walk backwards to find the nearest assistant message, + // skipping over any preceding tool messages (multi-tool-call case). + foundAssistant := false + for i := len(sanitized) - 1; i >= 0; i-- { + if sanitized[i].Role == "tool" { + continue + } + if sanitized[i].Role == "assistant" && len(sanitized[i].ToolCalls) > 0 { + foundAssistant = true + } + break + } + if !foundAssistant { logger.DebugCF("agent", "Dropping orphaned tool message", map[string]any{}) continue } diff --git a/pkg/agent/context_test.go b/pkg/agent/context_test.go new file mode 100644 index 000000000..e023c9c30 --- /dev/null +++ b/pkg/agent/context_test.go @@ -0,0 +1,209 @@ +package agent + +import ( + "testing" + + "github.com/sipeed/picoclaw/pkg/providers" +) + +func msg(role, content string) providers.Message { + return providers.Message{Role: role, Content: content} +} + +func assistantWithTools(toolIDs ...string) providers.Message { + calls := make([]providers.ToolCall, len(toolIDs)) + for i, id := range toolIDs { + calls[i] = providers.ToolCall{ID: id, Type: "function"} + } + return providers.Message{Role: "assistant", ToolCalls: calls} +} + +func toolResult(id string) providers.Message { + return providers.Message{Role: "tool", Content: "result", ToolCallID: id} +} + +func TestSanitizeHistoryForProvider_EmptyHistory(t *testing.T) { + result := sanitizeHistoryForProvider(nil) + if len(result) != 0 { + t.Fatalf("expected empty, got %d messages", len(result)) + } + + result = sanitizeHistoryForProvider([]providers.Message{}) + if len(result) != 0 { + t.Fatalf("expected empty, got %d messages", len(result)) + } +} + +func TestSanitizeHistoryForProvider_SingleToolCall(t *testing.T) { + history := []providers.Message{ + msg("user", "hello"), + assistantWithTools("A"), + toolResult("A"), + msg("assistant", "done"), + } + + result := sanitizeHistoryForProvider(history) + if len(result) != 4 { + t.Fatalf("expected 4 messages, got %d", len(result)) + } + assertRoles(t, result, "user", "assistant", "tool", "assistant") +} + +func TestSanitizeHistoryForProvider_MultiToolCalls(t *testing.T) { + history := []providers.Message{ + msg("user", "do two things"), + assistantWithTools("A", "B"), + toolResult("A"), + toolResult("B"), + msg("assistant", "both done"), + } + + result := sanitizeHistoryForProvider(history) + if len(result) != 5 { + t.Fatalf("expected 5 messages, got %d: %+v", len(result), roles(result)) + } + assertRoles(t, result, "user", "assistant", "tool", "tool", "assistant") +} + +func TestSanitizeHistoryForProvider_AssistantToolCallAfterPlainAssistant(t *testing.T) { + history := []providers.Message{ + msg("user", "hi"), + msg("assistant", "thinking"), + assistantWithTools("A"), + toolResult("A"), + } + + result := sanitizeHistoryForProvider(history) + if len(result) != 2 { + t.Fatalf("expected 2 messages, got %d: %+v", len(result), roles(result)) + } + assertRoles(t, result, "user", "assistant") +} + +func TestSanitizeHistoryForProvider_OrphanedLeadingTool(t *testing.T) { + history := []providers.Message{ + toolResult("A"), + msg("user", "hello"), + } + + result := sanitizeHistoryForProvider(history) + if len(result) != 1 { + t.Fatalf("expected 1 message, got %d: %+v", len(result), roles(result)) + } + assertRoles(t, result, "user") +} + +func TestSanitizeHistoryForProvider_ToolAfterUserDropped(t *testing.T) { + history := []providers.Message{ + msg("user", "hello"), + toolResult("A"), + } + + result := sanitizeHistoryForProvider(history) + if len(result) != 1 { + t.Fatalf("expected 1 message, got %d: %+v", len(result), roles(result)) + } + assertRoles(t, result, "user") +} + +func TestSanitizeHistoryForProvider_ToolAfterAssistantNoToolCalls(t *testing.T) { + history := []providers.Message{ + msg("user", "hello"), + msg("assistant", "hi"), + toolResult("A"), + } + + result := sanitizeHistoryForProvider(history) + if len(result) != 2 { + t.Fatalf("expected 2 messages, got %d: %+v", len(result), roles(result)) + } + assertRoles(t, result, "user", "assistant") +} + +func TestSanitizeHistoryForProvider_AssistantToolCallAtStart(t *testing.T) { + history := []providers.Message{ + assistantWithTools("A"), + toolResult("A"), + msg("user", "hello"), + } + + result := sanitizeHistoryForProvider(history) + if len(result) != 1 { + t.Fatalf("expected 1 message, got %d: %+v", len(result), roles(result)) + } + assertRoles(t, result, "user") +} + +func TestSanitizeHistoryForProvider_MultiToolCallsThenNewRound(t *testing.T) { + history := []providers.Message{ + msg("user", "do two things"), + assistantWithTools("A", "B"), + toolResult("A"), + toolResult("B"), + msg("assistant", "done"), + msg("user", "hi"), + assistantWithTools("C"), + toolResult("C"), + msg("assistant", "done again"), + } + + result := sanitizeHistoryForProvider(history) + if len(result) != 9 { + t.Fatalf("expected 9 messages, got %d: %+v", len(result), roles(result)) + } + assertRoles(t, result, "user", "assistant", "tool", "tool", "assistant", "user", "assistant", "tool", "assistant") +} + +func TestSanitizeHistoryForProvider_ConsecutiveMultiToolRounds(t *testing.T) { + history := []providers.Message{ + msg("user", "start"), + assistantWithTools("A", "B"), + toolResult("A"), + toolResult("B"), + assistantWithTools("C", "D"), + toolResult("C"), + toolResult("D"), + msg("assistant", "all done"), + } + + result := sanitizeHistoryForProvider(history) + if len(result) != 8 { + t.Fatalf("expected 8 messages, got %d: %+v", len(result), roles(result)) + } + assertRoles(t, result, "user", "assistant", "tool", "tool", "assistant", "tool", "tool", "assistant") +} + +func TestSanitizeHistoryForProvider_PlainConversation(t *testing.T) { + history := []providers.Message{ + msg("user", "hello"), + msg("assistant", "hi"), + msg("user", "how are you"), + msg("assistant", "fine"), + } + + result := sanitizeHistoryForProvider(history) + if len(result) != 4 { + t.Fatalf("expected 4 messages, got %d", len(result)) + } + assertRoles(t, result, "user", "assistant", "user", "assistant") +} + +func roles(msgs []providers.Message) []string { + r := make([]string, len(msgs)) + for i, m := range msgs { + r[i] = m.Role + } + return r +} + +func assertRoles(t *testing.T, msgs []providers.Message, expected ...string) { + t.Helper() + if len(msgs) != len(expected) { + t.Fatalf("role count mismatch: got %v, want %v", roles(msgs), expected) + } + for i, exp := range expected { + if msgs[i].Role != exp { + t.Errorf("message[%d]: got role %q, want %q", i, msgs[i].Role, exp) + } + } +}