fix: handle multi-tool-call orphan detection in sanitizeHistoryForProvider

Walk backwards over preceding tool messages to find the nearest assistant
with ToolCalls, instead of only checking the immediate predecessor. Add
unit tests for sanitizeHistoryForProvider covering key edge cases.
This commit is contained in:
winterfx
2026-02-24 21:35:15 +08:00
parent 18ba88869a
commit b47a39af9c
2 changed files with 222 additions and 2 deletions
+13 -2
View File
@@ -229,8 +229,19 @@ func sanitizeHistoryForProvider(history []providers.Message) []providers.Message
logger.DebugCF("agent", "Dropping orphaned leading tool message", map[string]any{})
continue
}
last := sanitized[len(sanitized)-1]
if last.Role != "assistant" || len(last.ToolCalls) == 0 {
// Walk backwards to find the nearest assistant message,
// skipping over any preceding tool messages (multi-tool-call case).
foundAssistant := false
for i := len(sanitized) - 1; i >= 0; i-- {
if sanitized[i].Role == "tool" {
continue
}
if sanitized[i].Role == "assistant" && len(sanitized[i].ToolCalls) > 0 {
foundAssistant = true
}
break
}
if !foundAssistant {
logger.DebugCF("agent", "Dropping orphaned tool message", map[string]any{})
continue
}
+209
View File
@@ -0,0 +1,209 @@
package agent
import (
"testing"
"github.com/sipeed/picoclaw/pkg/providers"
)
func msg(role, content string) providers.Message {
return providers.Message{Role: role, Content: content}
}
func assistantWithTools(toolIDs ...string) providers.Message {
calls := make([]providers.ToolCall, len(toolIDs))
for i, id := range toolIDs {
calls[i] = providers.ToolCall{ID: id, Type: "function"}
}
return providers.Message{Role: "assistant", ToolCalls: calls}
}
func toolResult(id string) providers.Message {
return providers.Message{Role: "tool", Content: "result", ToolCallID: id}
}
func TestSanitizeHistoryForProvider_EmptyHistory(t *testing.T) {
result := sanitizeHistoryForProvider(nil)
if len(result) != 0 {
t.Fatalf("expected empty, got %d messages", len(result))
}
result = sanitizeHistoryForProvider([]providers.Message{})
if len(result) != 0 {
t.Fatalf("expected empty, got %d messages", len(result))
}
}
func TestSanitizeHistoryForProvider_SingleToolCall(t *testing.T) {
history := []providers.Message{
msg("user", "hello"),
assistantWithTools("A"),
toolResult("A"),
msg("assistant", "done"),
}
result := sanitizeHistoryForProvider(history)
if len(result) != 4 {
t.Fatalf("expected 4 messages, got %d", len(result))
}
assertRoles(t, result, "user", "assistant", "tool", "assistant")
}
func TestSanitizeHistoryForProvider_MultiToolCalls(t *testing.T) {
history := []providers.Message{
msg("user", "do two things"),
assistantWithTools("A", "B"),
toolResult("A"),
toolResult("B"),
msg("assistant", "both done"),
}
result := sanitizeHistoryForProvider(history)
if len(result) != 5 {
t.Fatalf("expected 5 messages, got %d: %+v", len(result), roles(result))
}
assertRoles(t, result, "user", "assistant", "tool", "tool", "assistant")
}
func TestSanitizeHistoryForProvider_AssistantToolCallAfterPlainAssistant(t *testing.T) {
history := []providers.Message{
msg("user", "hi"),
msg("assistant", "thinking"),
assistantWithTools("A"),
toolResult("A"),
}
result := sanitizeHistoryForProvider(history)
if len(result) != 2 {
t.Fatalf("expected 2 messages, got %d: %+v", len(result), roles(result))
}
assertRoles(t, result, "user", "assistant")
}
func TestSanitizeHistoryForProvider_OrphanedLeadingTool(t *testing.T) {
history := []providers.Message{
toolResult("A"),
msg("user", "hello"),
}
result := sanitizeHistoryForProvider(history)
if len(result) != 1 {
t.Fatalf("expected 1 message, got %d: %+v", len(result), roles(result))
}
assertRoles(t, result, "user")
}
func TestSanitizeHistoryForProvider_ToolAfterUserDropped(t *testing.T) {
history := []providers.Message{
msg("user", "hello"),
toolResult("A"),
}
result := sanitizeHistoryForProvider(history)
if len(result) != 1 {
t.Fatalf("expected 1 message, got %d: %+v", len(result), roles(result))
}
assertRoles(t, result, "user")
}
func TestSanitizeHistoryForProvider_ToolAfterAssistantNoToolCalls(t *testing.T) {
history := []providers.Message{
msg("user", "hello"),
msg("assistant", "hi"),
toolResult("A"),
}
result := sanitizeHistoryForProvider(history)
if len(result) != 2 {
t.Fatalf("expected 2 messages, got %d: %+v", len(result), roles(result))
}
assertRoles(t, result, "user", "assistant")
}
func TestSanitizeHistoryForProvider_AssistantToolCallAtStart(t *testing.T) {
history := []providers.Message{
assistantWithTools("A"),
toolResult("A"),
msg("user", "hello"),
}
result := sanitizeHistoryForProvider(history)
if len(result) != 1 {
t.Fatalf("expected 1 message, got %d: %+v", len(result), roles(result))
}
assertRoles(t, result, "user")
}
func TestSanitizeHistoryForProvider_MultiToolCallsThenNewRound(t *testing.T) {
history := []providers.Message{
msg("user", "do two things"),
assistantWithTools("A", "B"),
toolResult("A"),
toolResult("B"),
msg("assistant", "done"),
msg("user", "hi"),
assistantWithTools("C"),
toolResult("C"),
msg("assistant", "done again"),
}
result := sanitizeHistoryForProvider(history)
if len(result) != 9 {
t.Fatalf("expected 9 messages, got %d: %+v", len(result), roles(result))
}
assertRoles(t, result, "user", "assistant", "tool", "tool", "assistant", "user", "assistant", "tool", "assistant")
}
func TestSanitizeHistoryForProvider_ConsecutiveMultiToolRounds(t *testing.T) {
history := []providers.Message{
msg("user", "start"),
assistantWithTools("A", "B"),
toolResult("A"),
toolResult("B"),
assistantWithTools("C", "D"),
toolResult("C"),
toolResult("D"),
msg("assistant", "all done"),
}
result := sanitizeHistoryForProvider(history)
if len(result) != 8 {
t.Fatalf("expected 8 messages, got %d: %+v", len(result), roles(result))
}
assertRoles(t, result, "user", "assistant", "tool", "tool", "assistant", "tool", "tool", "assistant")
}
func TestSanitizeHistoryForProvider_PlainConversation(t *testing.T) {
history := []providers.Message{
msg("user", "hello"),
msg("assistant", "hi"),
msg("user", "how are you"),
msg("assistant", "fine"),
}
result := sanitizeHistoryForProvider(history)
if len(result) != 4 {
t.Fatalf("expected 4 messages, got %d", len(result))
}
assertRoles(t, result, "user", "assistant", "user", "assistant")
}
func roles(msgs []providers.Message) []string {
r := make([]string, len(msgs))
for i, m := range msgs {
r[i] = m.Role
}
return r
}
func assertRoles(t *testing.T, msgs []providers.Message, expected ...string) {
t.Helper()
if len(msgs) != len(expected) {
t.Fatalf("role count mismatch: got %v, want %v", roles(msgs), expected)
}
for i, exp := range expected {
if msgs[i].Role != exp {
t.Errorf("message[%d]: got role %q, want %q", i, msgs[i].Role, exp)
}
}
}