diff --git a/cmd/picoclaw/internal/auth/wecom_test.go b/cmd/picoclaw/internal/auth/wecom_test.go index c152481be..aafd39e69 100644 --- a/cmd/picoclaw/internal/auth/wecom_test.go +++ b/cmd/picoclaw/internal/auth/wecom_test.go @@ -3,6 +3,7 @@ package auth import ( "bytes" "context" + "net" "net/http" "net/http/httptest" "net/url" @@ -19,6 +20,19 @@ import ( "github.com/sipeed/picoclaw/pkg/config" ) +func newIPv4TestServer(t *testing.T, handler http.Handler) *httptest.Server { + t.Helper() + + server := httptest.NewUnstartedServer(handler) + listener, err := net.Listen("tcp4", "127.0.0.1:0") + require.NoError(t, err) + + server.Listener = listener + server.Start() + t.Cleanup(server.Close) + return server +} + func TestNewWeComCommand(t *testing.T) { cmd := newWeComCommand() @@ -53,7 +67,7 @@ func TestBuildWeComQRCodePageURL(t *testing.T) { } func TestFetchWeComQRCode(t *testing.T) { - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + server := newIPv4TestServer(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { assert.Equal(t, "/generate", r.URL.Path) assert.Equal(t, wecomQRSourceID, r.URL.Query().Get("source")) assert.Equal(t, wecomQRSourceID, r.URL.Query().Get("sourceID")) @@ -61,7 +75,6 @@ func TestFetchWeComQRCode(t *testing.T) { w.Header().Set("Content-Type", "application/json") _, _ = w.Write([]byte(`{"data":{"scode":"scode-1","auth_url":"https://example.com/qr"}}`)) })) - defer server.Close() opts := normalizeWeComQRFlowOptions(wecomQRFlowOptions{ HTTPClient: server.Client(), @@ -78,7 +91,7 @@ func TestFetchWeComQRCode(t *testing.T) { func TestPollWeComQRCodeResult(t *testing.T) { var calls atomic.Int32 - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + server := newIPv4TestServer(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { call := calls.Add(1) assert.Equal(t, "/query", r.URL.Path) assert.Equal(t, "scode-1", r.URL.Query().Get("scode")) @@ -92,7 +105,6 @@ func TestPollWeComQRCodeResult(t *testing.T) { _, _ = w.Write([]byte(`{"data":{"status":"success","bot_info":{"botid":"bot-1","secret":"secret-1"}}}`)) } })) - defer server.Close() var output bytes.Buffer opts := normalizeWeComQRFlowOptions(wecomQRFlowOptions{ diff --git a/docs/channels/discord/README.md b/docs/channels/discord/README.md index 771289d28..741bc64a1 100644 --- a/docs/channels/discord/README.md +++ b/docs/channels/discord/README.md @@ -8,26 +8,56 @@ Discord is a free voice, video, and text chat application designed for communiti ```json { + "agents": { + "defaults": { + "tool_feedback": { + "enabled": true, + "max_args_length": 300 + } + } + }, "channel_list": { "discord": { "enabled": true, "type": "discord", "token": "YOUR_BOT_TOKEN", "allow_from": ["YOUR_USER_ID"], + "placeholder": { + "enabled": true, + "text": ["Thinking... 💭"] + }, "group_trigger": { "mention_only": false - } + }, + "reasoning_channel_id": "" } } } ``` -| Field | Type | Required | Description | -| ------------- | ------ | -------- | --------------------------------------------------------------------------- | -| enabled | bool | Yes | Whether to enable the Discord channel | -| token | string | Yes | Discord Bot Token | -| allow_from | array | No | Allowlist of user IDs; empty means all users are allowed | -| group_trigger | object | No | Group trigger settings (example: { "mention_only": false }) | +| Field | Type | Required | Description | +| -------------------- | ------ | -------- | --------------------------------------------------------------------------- | +| enabled | bool | Yes | Whether to enable the Discord channel | +| token | string | Yes | Discord Bot Token | +| allow_from | array | No | Allowlist of user IDs; empty means all users are allowed | +| placeholder | object | No | Placeholder message config shown while the agent is working | +| group_trigger | object | No | Group trigger settings (example: { "mention_only": false }) | +| reasoning_channel_id | string | No | Optional target channel ID for reasoning/thinking output | + +## Visible Execution Feedback + +Discord can show three different kinds of "working" feedback: + +1. Typing indicator: automatic, no extra config needed. +2. Placeholder message: enable `channel_list.discord.placeholder.enabled` to send a visible `Thinking...` message that is later edited into the final reply. +3. Tool execution feedback: enable `agents.defaults.tool_feedback.enabled` to send a short message before each tool call, for example: + +```text +🔧 `web_search` +Checking the latest PicoClaw release notes before I answer. +``` + +If you only see `Bot is typing`, check that `placeholder.enabled` or `tool_feedback.enabled` is actually set in your runtime config. ## Setup diff --git a/pkg/agent/agent.go b/pkg/agent/agent.go index 3c242eecb..3e9bd845e 100644 --- a/pkg/agent/agent.go +++ b/pkg/agent/agent.go @@ -112,6 +112,7 @@ const ( pendingTurnPrefix = "pending-" metadataKeyMessageKind = "message_kind" messageKindThought = "thought" + messageKindToolFeedback = "tool_feedback" metadataKeyAccountID = "account_id" metadataKeyGuildID = "guild_id" metadataKeyTeamID = "team_id" diff --git a/pkg/agent/agent_test.go b/pkg/agent/agent_test.go index 313569153..2addc0535 100644 --- a/pkg/agent/agent_test.go +++ b/pkg/agent/agent_test.go @@ -24,6 +24,7 @@ import ( "github.com/sipeed/picoclaw/pkg/routing" "github.com/sipeed/picoclaw/pkg/session" "github.com/sipeed/picoclaw/pkg/tools" + "github.com/sipeed/picoclaw/pkg/utils" ) type fakeChannel struct{ id string } @@ -1761,6 +1762,157 @@ func (m *toolFeedbackProvider) GetDefaultModel() string { return "heartbeat-tool-feedback-model" } +type toolFeedbackReasoningProvider struct { + filePath string + calls int +} + +func (m *toolFeedbackReasoningProvider) Chat( + ctx context.Context, + messages []providers.Message, + tools []providers.ToolDefinition, + model string, + opts map[string]any, +) (*providers.LLMResponse, error) { + m.calls++ + if m.calls == 1 { + return &providers.LLMResponse{ + ReasoningContent: "Read README.md first to confirm the context that needs to be changed.", + ToolCalls: []providers.ToolCall{{ + ID: "call_reasoning_read_file", + Type: "function", + Name: "read_file", + Arguments: map[string]any{"path": m.filePath}, + }}, + }, nil + } + + return &providers.LLMResponse{ + Content: "DONE", + ToolCalls: []providers.ToolCall{}, + }, nil +} + +func (m *toolFeedbackReasoningProvider) GetDefaultModel() string { + return "tool-feedback-reasoning-model" +} + +func TestToolFeedbackExplanationFromResponse_UsesCurrentContentFirst(t *testing.T) { + response := &providers.LLMResponse{ + Content: "Read README.md first", + ReasoningContent: "current reasoning fallback", + } + messages := []providers.Message{ + {Role: "user", Content: "check file"}, + {Role: "assistant", Content: "Previous turn explanation"}, + {Role: "tool", Content: "tool output", ToolCallID: "call_1"}, + } + + got := toolFeedbackExplanationFromResponse(response, messages, 300) + if got != "Read README.md first" { + t.Fatalf("toolFeedbackExplanationFromResponse() = %q, want current content", got) + } +} + +func TestToolFeedbackExplanationFromResponse_UsesExplicitToolCallExtraContent(t *testing.T) { + response := &providers.LLMResponse{ + ToolCalls: []providers.ToolCall{{ + ID: "call_1", + Name: "read_file", + ExtraContent: &providers.ExtraContent{ + ToolFeedbackExplanation: "Read README.md first to confirm the current project structure.", + }, + }}, + } + messages := []providers.Message{ + {Role: "user", Content: "check file"}, + {Role: "assistant", Content: ""}, + {Role: "tool", Content: "tool output", ToolCallID: "call_1"}, + } + + got := toolFeedbackExplanationFromResponse(response, messages, 300) + if got != "Read README.md first to confirm the current project structure." { + t.Fatalf("toolFeedbackExplanationFromResponse() = %q, want explicit tool feedback explanation", got) + } +} + +func TestToolFeedbackExplanationForToolCall_PrefersToolSpecificExtraContent(t *testing.T) { + response := &providers.LLMResponse{ + Content: "Shared explanation", + ToolCalls: []providers.ToolCall{ + { + ID: "call_1", + Name: "read_file", + ExtraContent: &providers.ExtraContent{ + ToolFeedbackExplanation: "Read README.md first.", + }, + }, + { + ID: "call_2", + Name: "edit_file", + ExtraContent: &providers.ExtraContent{ + ToolFeedbackExplanation: "Update config example after reading it.", + }, + }, + }, + } + + got1 := toolFeedbackExplanationForToolCall(response, response.ToolCalls[0], nil, 300) + got2 := toolFeedbackExplanationForToolCall(response, response.ToolCalls[1], nil, 300) + if got1 != "Read README.md first." { + t.Fatalf("toolFeedbackExplanationForToolCall() first = %q, want tool-specific explanation", got1) + } + if got2 != "Update config example after reading it." { + t.Fatalf("toolFeedbackExplanationForToolCall() second = %q, want tool-specific explanation", got2) + } +} + +func TestToolFeedbackExplanationForToolCall_DoesNotReuseAnotherToolCallExplanation(t *testing.T) { + response := &providers.LLMResponse{ + ToolCalls: []providers.ToolCall{ + { + ID: "call_1", + Name: "read_file", + }, + { + ID: "call_2", + Name: "edit_file", + ExtraContent: &providers.ExtraContent{ + ToolFeedbackExplanation: "Update config example after reading it.", + }, + }, + }, + } + messages := []providers.Message{ + {Role: "user", Content: "inspect the config and update the example"}, + } + + got := toolFeedbackExplanationForToolCall(response, response.ToolCalls[0], messages, 300) + want := utils.ToolFeedbackContinuationHint + ": inspect the config and update the example" + if got != want { + t.Fatalf("toolFeedbackExplanationForToolCall() = %q, want %q", got, want) + } +} + +func TestToolFeedbackExplanationFromResponse_DoesNotUseReasoningContent(t *testing.T) { + response := &providers.LLMResponse{ + Content: "", + ReasoningContent: "hidden reasoning should not be shown", + } + messages := []providers.Message{ + {Role: "user", Content: "check file"}, + {Role: "assistant", Content: "Previous turn explanation"}, + {Role: "user", Content: "Inspect README.md and update the config example."}, + {Role: "tool", Content: "tool output", ToolCallID: "call_1"}, + } + + got := toolFeedbackExplanationFromResponse(response, messages, 300) + want := utils.ToolFeedbackContinuationHint + ": Inspect README.md and update the config example." + if got != want { + t.Fatalf("toolFeedbackExplanationFromResponse() = %q, want latest user content fallback", got) + } +} + type picoInterleavedContentProvider struct { calls int } @@ -3728,7 +3880,16 @@ func TestProcessMessage_PublishesToolFeedbackWhenEnabled(t *testing.T) { t.Fatalf("unexpected tool feedback context: %+v", outbound.Context) } if !strings.Contains(outbound.Content, "`read_file`") { - t.Fatalf("tool feedback content = %q, want read_file preview", outbound.Content) + t.Fatalf("tool feedback content = %q, want read_file summary", outbound.Content) + } + if !strings.Contains(outbound.Content, utils.ToolFeedbackContinuationHint) { + t.Fatalf("tool feedback content = %q, want continuation hint fallback", outbound.Content) + } + if !strings.Contains(outbound.Content, "check tool feedback") { + t.Fatalf("tool feedback content = %q, want current user intent fallback", outbound.Content) + } + if strings.Contains(outbound.Content, "Previous turn explanation") { + t.Fatalf("tool feedback content = %q, want no previous assistant fallback", outbound.Content) } if outbound.AgentID != "main" { t.Fatalf("tool feedback agent_id = %q, want main", outbound.AgentID) @@ -3744,6 +3905,130 @@ func TestProcessMessage_PublishesToolFeedbackWhenEnabled(t *testing.T) { } } +func TestProcessMessage_DoesNotLeakReasoningContentInToolFeedback(t *testing.T) { + tmpDir := t.TempDir() + heartbeatFile := filepath.Join(tmpDir, "tool-feedback-reasoning.txt") + if err := os.WriteFile(heartbeatFile, []byte("tool feedback task"), 0o644); err != nil { + t.Fatalf("WriteFile() error = %v", err) + } + + cfg := &config.Config{ + Agents: config.AgentsConfig{ + Defaults: config.AgentDefaults{ + Workspace: tmpDir, + ModelName: "test-model", + MaxTokens: 4096, + MaxToolIterations: 10, + ToolFeedback: config.ToolFeedbackConfig{ + Enabled: true, + MaxArgsLength: 300, + }, + }, + }, + Tools: config.ToolsConfig{ + ReadFile: config.ReadFileToolConfig{ + Enabled: true, + }, + }, + } + + msgBus := bus.NewMessageBus() + provider := &toolFeedbackReasoningProvider{filePath: heartbeatFile} + al := NewAgentLoop(cfg, msgBus, provider) + + response, err := al.processMessage(context.Background(), testInboundMessage(bus.InboundMessage{ + Channel: "telegram", + SenderID: "user-1", + ChatID: "chat-1", + Content: "check reasoning fallback", + })) + if err != nil { + t.Fatalf("processMessage() error = %v", err) + } + if response != "DONE" { + t.Fatalf("processMessage() response = %q, want %q", response, "DONE") + } + + select { + case outbound := <-msgBus.OutboundChan(): + if !strings.Contains(outbound.Content, "`read_file`") { + t.Fatalf("tool feedback content = %q, want read_file summary", outbound.Content) + } + if !strings.Contains(outbound.Content, utils.ToolFeedbackContinuationHint) { + t.Fatalf("tool feedback content = %q, want continuation hint fallback", outbound.Content) + } + if !strings.Contains(outbound.Content, "check reasoning fallback") { + t.Fatalf("tool feedback content = %q, want current user intent fallback", outbound.Content) + } + if strings.Contains(outbound.Content, "Read README.md first") { + t.Fatalf("tool feedback content = %q, should not leak hidden reasoning", outbound.Content) + } + case <-time.After(2 * time.Second): + t.Fatal("expected outbound tool feedback without leaking reasoning") + } +} + +func TestProcessMessage_DoesNotPublishToolFeedbackForDiscordWhenDisabled(t *testing.T) { + assertToolFeedbackNotPublishedWhenDisabled(t, "discord") +} + +func assertToolFeedbackNotPublishedWhenDisabled(t *testing.T, channel string) { + t.Helper() + + tmpDir := t.TempDir() + heartbeatFile := filepath.Join(tmpDir, "tool-feedback-"+channel+".txt") + if err := os.WriteFile(heartbeatFile, []byte("tool feedback task"), 0o644); err != nil { + t.Fatalf("WriteFile() error = %v", err) + } + + cfg := &config.Config{ + Agents: config.AgentsConfig{ + Defaults: config.AgentDefaults{ + Workspace: tmpDir, + ModelName: "test-model", + MaxTokens: 4096, + MaxToolIterations: 10, + }, + }, + Tools: config.ToolsConfig{ + ReadFile: config.ReadFileToolConfig{ + Enabled: true, + }, + }, + } + + msgBus := bus.NewMessageBus() + provider := &toolFeedbackProvider{filePath: heartbeatFile} + al := NewAgentLoop(cfg, msgBus, provider) + + response, err := al.processMessage(context.Background(), testInboundMessage(bus.InboundMessage{ + Channel: channel, + SenderID: "user-1", + ChatID: "chat-1", + Content: "check tool feedback", + })) + if err != nil { + t.Fatalf("processMessage() error = %v", err) + } + if response != "HEARTBEAT_OK" { + t.Fatalf("processMessage() response = %q, want %q", response, "HEARTBEAT_OK") + } + + select { + case outbound := <-msgBus.OutboundChan(): + t.Fatalf("expected no outbound tool feedback for %s when disabled, got %+v", channel, outbound) + case <-time.After(200 * time.Millisecond): + } +} + +func TestProcessMessage_DoesNotPublishToolFeedbackForTelegramWhenDisabled(t *testing.T) { + assertToolFeedbackNotPublishedWhenDisabled(t, "telegram") +} + +func TestProcessMessage_DoesNotPublishToolFeedbackForFeishuWhenDisabled(t *testing.T) { + assertToolFeedbackNotPublishedWhenDisabled(t, "feishu") +} + func TestProcessMessage_MessageToolPublishesOutboundWithTurnMetadata(t *testing.T) { cfg := config.DefaultConfig() cfg.Agents.Defaults.Workspace = t.TempDir() @@ -3918,6 +4203,85 @@ func TestRunAgentLoop_PicoSkipsInterimPublishWhenNotAllowed(t *testing.T) { } } +func TestRun_PicoToolFeedbackSuppressesDuplicateInterimAssistantContent(t *testing.T) { + tmpDir := t.TempDir() + + cfg := &config.Config{ + Agents: config.AgentsConfig{ + Defaults: config.AgentDefaults{ + Workspace: tmpDir, + ModelName: "test-model", + MaxTokens: 4096, + MaxToolIterations: 10, + ToolFeedback: config.ToolFeedbackConfig{ + Enabled: true, + }, + }, + }, + } + + msgBus := bus.NewMessageBus() + provider := &picoInterleavedContentProvider{} + al := NewAgentLoop(cfg, msgBus, provider) + + agent := al.GetRegistry().GetDefaultAgent() + if agent == nil { + t.Fatal("expected default agent") + } + agent.Tools.Register(&toolLimitTestTool{}) + + runCtx, runCancel := context.WithCancel(context.Background()) + defer runCancel() + + runDone := make(chan error, 1) + go func() { + runDone <- al.Run(runCtx) + }() + + if err := msgBus.PublishInbound(context.Background(), bus.InboundMessage{ + Channel: "pico", + SenderID: "user-1", + ChatID: "session-1", + Content: "run with tools", + }); err != nil { + t.Fatalf("PublishInbound() error = %v", err) + } + + outputs := make([]string, 0, 2) + deadline := time.After(2 * time.Second) + for len(outputs) < 2 { + select { + case outbound := <-msgBus.OutboundChan(): + outputs = append(outputs, outbound.Content) + case <-deadline: + t.Fatalf("timed out waiting for pico outputs, got %v", outputs) + } + } + + if outputs[0] != "🔧 `tool_limit_test_tool`\nintermediate model text" { + t.Fatalf("first outbound content = %q, want tool feedback summary", outputs[0]) + } + if outputs[1] != "final model text" { + t.Fatalf("second outbound content = %q, want %q", outputs[1], "final model text") + } + + runCancel() + select { + case err := <-runDone: + if err != nil { + t.Fatalf("Run() error = %v", err) + } + case <-time.After(2 * time.Second): + t.Fatal("timed out waiting for Run() to exit") + } + + select { + case outbound := <-msgBus.OutboundChan(): + t.Fatalf("unexpected extra pico output after tool feedback + final reply: %+v", outbound) + case <-time.After(200 * time.Millisecond): + } +} + func TestResolveMediaRefs_ResolvesToBase64(t *testing.T) { store := media.NewFileMediaStore() dir := t.TempDir() diff --git a/pkg/agent/agent_utils.go b/pkg/agent/agent_utils.go index 2574f0222..ff98dad68 100644 --- a/pkg/agent/agent_utils.go +++ b/pkg/agent/agent_utils.go @@ -11,6 +11,7 @@ import ( "github.com/sipeed/picoclaw/pkg/bus" "github.com/sipeed/picoclaw/pkg/commands" + "github.com/sipeed/picoclaw/pkg/config" "github.com/sipeed/picoclaw/pkg/providers" "github.com/sipeed/picoclaw/pkg/session" "github.com/sipeed/picoclaw/pkg/utils" @@ -84,6 +85,98 @@ func outboundMessageForTurn(ts *turnState, content string) bus.OutboundMessage { } } +func outboundMessageForTurnWithKind(ts *turnState, content, kind string) bus.OutboundMessage { + msg := outboundMessageForTurn(ts, content) + if strings.TrimSpace(kind) == "" { + return msg + } + if msg.Context.Raw == nil { + msg.Context.Raw = make(map[string]string, 1) + } + msg.Context.Raw[metadataKeyMessageKind] = kind + return msg +} + +func latestUserContent(messages []providers.Message) string { + for i := len(messages) - 1; i >= 0; i-- { + msg := messages[i] + if msg.Role != "user" { + continue + } + if content := strings.TrimSpace(msg.Content); content != "" { + return content + } + } + return "" +} + +func toolFeedbackExplanationFromResponse( + response *providers.LLMResponse, + messages []providers.Message, + maxLen int, +) string { + if response == nil { + return "" + } + explanation := strings.TrimSpace(response.Content) + if explanation == "" { + explanation = toolFeedbackExplanationFromToolCalls(response.ToolCalls) + } + if explanation == "" { + explanation = toolFeedbackExplanationFromMessages(messages) + } + return utils.Truncate(explanation, maxLen) +} + +func toolFeedbackExplanationFromToolCalls(toolCalls []providers.ToolCall) string { + for _, tc := range toolCalls { + if tc.ExtraContent == nil { + continue + } + if explanation := strings.TrimSpace(tc.ExtraContent.ToolFeedbackExplanation); explanation != "" { + return explanation + } + } + return "" +} + +func toolFeedbackExplanationForToolCall( + response *providers.LLMResponse, + toolCall providers.ToolCall, + messages []providers.Message, + maxLen int, +) string { + if toolCall.ExtraContent != nil { + if explanation := strings.TrimSpace(toolCall.ExtraContent.ToolFeedbackExplanation); explanation != "" { + return utils.Truncate(explanation, maxLen) + } + } + if response == nil { + return utils.Truncate(toolFeedbackExplanationFromMessages(messages), maxLen) + } + + explanation := strings.TrimSpace(response.Content) + if explanation == "" { + explanation = toolFeedbackExplanationFromMessages(messages) + } + return utils.Truncate(explanation, maxLen) +} + +func toolFeedbackExplanationFromMessages(messages []providers.Message) string { + explanation := latestUserContent(messages) + if explanation != "" { + return utils.ToolFeedbackContinuationHint + ": " + explanation + } + return "" +} + +func shouldPublishToolFeedback(cfg *config.Config, ts *turnState) bool { + if ts == nil || ts.channel == "" || ts.opts.SuppressToolFeedback { + return false + } + return cfg != nil && cfg.Agents.Defaults.IsToolFeedbackEnabled() +} + func cloneEventArguments(args map[string]any) map[string]any { if len(args) == 0 { return nil diff --git a/pkg/agent/hooks_test.go b/pkg/agent/hooks_test.go index f024cba04..1cfa341a7 100644 --- a/pkg/agent/hooks_test.go +++ b/pkg/agent/hooks_test.go @@ -4,6 +4,7 @@ import ( "context" "errors" "os" + "strings" "sync" "testing" "time" @@ -403,6 +404,24 @@ func (h *toolRewriteHook) AfterTool( return next, HookDecision{Action: HookActionModify}, nil } +type toolRenameHook struct{} + +func (h *toolRenameHook) BeforeTool( + ctx context.Context, + call *ToolCallHookRequest, +) (*ToolCallHookRequest, HookDecision, error) { + next := call.Clone() + next.Tool = "echo_text_rewritten" + return next, HookDecision{Action: HookActionModify}, nil +} + +func (h *toolRenameHook) AfterTool( + ctx context.Context, + result *ToolResultHookResponse, +) (*ToolResultHookResponse, HookDecision, error) { + return result.Clone(), HookDecision{Action: HookActionContinue}, nil +} + func TestAgentLoop_Hooks_ToolInterceptorCanRewrite(t *testing.T) { provider := &toolHookProvider{} al, agent, cleanup := newHookTestLoop(t, provider) @@ -430,6 +449,75 @@ func TestAgentLoop_Hooks_ToolInterceptorCanRewrite(t *testing.T) { } } +type echoTextRewrittenTool struct{} + +func (t *echoTextRewrittenTool) Name() string { + return "echo_text_rewritten" +} + +func (t *echoTextRewrittenTool) Description() string { + return "echo a rewritten text argument" +} + +func (t *echoTextRewrittenTool) Parameters() map[string]any { + return map[string]any{ + "type": "object", + "properties": map[string]any{ + "text": map[string]any{ + "type": "string", + }, + }, + } +} + +func (t *echoTextRewrittenTool) Execute(ctx context.Context, args map[string]any) *tools.ToolResult { + text, _ := args["text"].(string) + return tools.SilentResult("rewritten:" + text) +} + +func TestAgentLoop_Hooks_ToolFeedbackUsesRewrittenToolName(t *testing.T) { + provider := &toolHookProvider{} + al, agent, cleanup := newHookTestLoop(t, provider) + defer cleanup() + + al.cfg.Agents.Defaults.ToolFeedback.Enabled = true + al.RegisterTool(&echoTextTool{}) + al.RegisterTool(&echoTextRewrittenTool{}) + if err := al.MountHook(NamedHook("tool-rename", &toolRenameHook{})); err != nil { + t.Fatalf("MountHook failed: %v", err) + } + + _, err := al.runAgentLoop(context.Background(), agent, processOptions{ + SessionKey: "session-1", + Channel: "cli", + ChatID: "direct", + UserMessage: "run tool", + DefaultResponse: defaultResponse, + EnableSummary: false, + SendResponse: false, + }) + if err != nil { + t.Fatalf("runAgentLoop failed: %v", err) + } + + msgBus, ok := al.bus.(*bus.MessageBus) + if !ok { + t.Fatalf("expected concrete MessageBus, got %T", al.bus) + } + + select { + case outbound := <-msgBus.OutboundChan(): + if !strings.Contains(outbound.Content, "`echo_text_rewritten`") { + t.Fatalf("tool feedback content = %q, want rewritten tool name", outbound.Content) + } + if strings.Contains(outbound.Content, "`echo_text`") { + t.Fatalf("tool feedback content = %q, want no original tool name", outbound.Content) + } + case <-time.After(2 * time.Second): + t.Fatal("expected outbound tool feedback") + } +} + type denyApprovalHook struct{} func (h *denyApprovalHook) ApproveTool(ctx context.Context, req *ToolApprovalRequest) (ApprovalDecision, error) { @@ -804,6 +892,77 @@ func TestAgentLoop_HookRespond_BusFallback(t *testing.T) { } } +func TestAgentLoop_HookRespond_ResponseHandledMediaPreservesOutboundContext(t *testing.T) { + provider := &multiToolProvider{ + toolCalls: []providers.ToolCall{ + {ID: "call-1", Name: "media_tool", Arguments: map[string]any{}}, + }, + finalContent: "done", + } + al, agent, cleanup := newHookTestLoop(t, provider) + defer cleanup() + + hook := &respondWithMediaHook{ + respondTools: map[string]bool{"media_tool": true}, + media: []string{"media://test/image.png"}, + responseHandled: true, + forLLM: "media sent successfully", + } + if err := al.MountHook(NamedHook("media-hook", hook)); err != nil { + t.Fatalf("MountHook failed: %v", err) + } + + telegramChannel := &fakeMediaChannel{fakeChannel: fakeChannel{id: "rid-telegram"}} + al.channelManager = newStartedTestChannelManager(t, + al.bus.(*bus.MessageBus), al.mediaStore, "telegram", telegramChannel) + + _, err := al.runAgentLoop(context.Background(), agent, processOptions{ + Dispatch: DispatchRequest{ + SessionKey: "session-topic-media", + SessionScope: &session.SessionScope{ + Version: session.ScopeVersionV1, + AgentID: agent.ID, + Channel: "telegram", + Dimensions: []string{"chat"}, + Values: map[string]string{ + "chat": "forum:-100123/42", + }, + }, + InboundContext: &bus.InboundContext{ + Channel: "telegram", + ChatID: "-100123", + TopicID: "42", + ChatType: "group", + SenderID: "user1", + }, + UserMessage: "send media", + }, + DefaultResponse: defaultResponse, + EnableSummary: false, + SendResponse: false, + }) + if err != nil { + t.Fatalf("runAgentLoop failed: %v", err) + } + + if len(telegramChannel.sentMedia) != 1 { + t.Fatalf("expected exactly 1 sent media message, got %d", len(telegramChannel.sentMedia)) + } + sent := telegramChannel.sentMedia[0] + if sent.Context.Channel != "telegram" || sent.Context.ChatID != "-100123" || sent.Context.TopicID != "42" { + t.Fatalf("unexpected media context: %+v", sent.Context) + } + if sent.AgentID != agent.ID { + t.Fatalf("sent media agent_id = %q, want %q", sent.AgentID, agent.ID) + } + if sent.SessionKey != "session-topic-media" { + t.Fatalf("sent media session_key = %q, want session-topic-media", sent.SessionKey) + } + if sent.Scope == nil || sent.Scope.Values["chat"] != "forum:-100123/42" { + t.Fatalf("unexpected sent media scope: %+v", sent.Scope) + } +} + type multiToolProvider struct { mu sync.Mutex callCount int diff --git a/pkg/agent/pipeline_execute.go b/pkg/agent/pipeline_execute.go index 87254619c..48e72e096 100644 --- a/pkg/agent/pipeline_execute.go +++ b/pkg/agent/pipeline_execute.go @@ -80,21 +80,16 @@ toolLoop: }, ) - if al.cfg.Agents.Defaults.IsToolFeedbackEnabled() && - ts.channel != "" && - !ts.opts.SuppressToolFeedback { - argsJSON, _ := json.Marshal(toolArgs) - feedbackPreview := utils.Truncate( - string(argsJSON), + if shouldPublishToolFeedback(al.cfg, ts) { + toolFeedbackExplanation := toolFeedbackExplanationForToolCall( + exec.response, + tc, + messages, al.cfg.Agents.Defaults.GetToolFeedbackMaxArgsLength(), ) - feedbackMsg := utils.FormatToolFeedbackMessage(toolName, feedbackPreview) + feedbackMsg := utils.FormatToolFeedbackMessage(toolName, toolFeedbackExplanation) fbCtx, fbCancel := context.WithTimeout(turnCtx, 3*time.Second) - _ = al.bus.PublishOutbound(fbCtx, bus.OutboundMessage{ - Channel: ts.channel, - ChatID: ts.chatID, - Content: feedbackMsg, - }) + _ = al.bus.PublishOutbound(fbCtx, outboundMessageForTurnWithKind(ts, feedbackMsg, messageKindToolFeedback)) fbCancel() } @@ -131,7 +126,16 @@ toolLoop: outboundMedia := bus.OutboundMediaMessage{ Channel: ts.channel, ChatID: ts.chatID, - Parts: parts, + Context: outboundContextFromInbound( + ts.opts.Dispatch.InboundContext, + ts.channel, + ts.chatID, + ts.opts.Dispatch.ReplyToMessageID(), + ), + AgentID: ts.agent.ID, + SessionKey: ts.sessionKey, + Scope: outboundScopeFromSessionScope(ts.opts.Dispatch.SessionScope), + Parts: parts, } if al.channelManager != nil && ts.channel != "" && !constants.IsInternalChannel(ts.channel) { if err := al.channelManager.SendMedia(ctx, outboundMedia); err != nil { @@ -353,16 +357,16 @@ toolLoop: }, ) - if al.cfg.Agents.Defaults.IsToolFeedbackEnabled() && - ts.channel != "" && - !ts.opts.SuppressToolFeedback { - feedbackPreview := utils.Truncate( - string(argsJSON), + if shouldPublishToolFeedback(al.cfg, ts) { + toolFeedbackExplanation := toolFeedbackExplanationForToolCall( + exec.response, + tc, + messages, al.cfg.Agents.Defaults.GetToolFeedbackMaxArgsLength(), ) - feedbackMsg := utils.FormatToolFeedbackMessage(tc.Name, feedbackPreview) + feedbackMsg := utils.FormatToolFeedbackMessage(toolName, toolFeedbackExplanation) fbCtx, fbCancel := context.WithTimeout(turnCtx, 3*time.Second) - _ = al.bus.PublishOutbound(fbCtx, outboundMessageForTurn(ts, feedbackMsg)) + _ = al.bus.PublishOutbound(fbCtx, outboundMessageForTurnWithKind(ts, feedbackMsg, messageKindToolFeedback)) fbCancel() } diff --git a/pkg/agent/pipeline_llm.go b/pkg/agent/pipeline_llm.go index c426c25c9..29940bc01 100644 --- a/pkg/agent/pipeline_llm.go +++ b/pkg/agent/pipeline_llm.go @@ -424,7 +424,11 @@ func (p *Pipeline) CallLLM( } logger.DebugCF("agent", "LLM response", llmResponseFields) - if al.bus != nil && ts.channel == "pico" && len(exec.response.ToolCalls) > 0 && ts.opts.AllowInterimPicoPublish { + if al.bus != nil && + ts.channel == "pico" && + len(exec.response.ToolCalls) > 0 && + ts.opts.AllowInterimPicoPublish && + !shouldPublishToolFeedback(al.cfg, ts) { if strings.TrimSpace(exec.response.Content) != "" { outCtx, outCancel := context.WithTimeout(turnCtx, 3*time.Second) publishErr := al.bus.PublishOutbound(outCtx, bus.OutboundMessage{ @@ -496,7 +500,19 @@ func (p *Pipeline) CallLLM( } for _, tc := range exec.normalizedToolCalls { argumentsJSON, _ := json.Marshal(tc.Arguments) + toolFeedbackExplanation := toolFeedbackExplanationForToolCall( + exec.response, + tc, + exec.messages, + al.cfg.Agents.Defaults.GetToolFeedbackMaxArgsLength(), + ) extraContent := tc.ExtraContent + if strings.TrimSpace(toolFeedbackExplanation) != "" { + if extraContent == nil { + extraContent = &providers.ExtraContent{} + } + extraContent.ToolFeedbackExplanation = toolFeedbackExplanation + } thoughtSignature := "" if tc.Function != nil { thoughtSignature = tc.Function.ThoughtSignature diff --git a/pkg/agent/subturn_test.go b/pkg/agent/subturn_test.go index 6a2ba835d..040063249 100644 --- a/pkg/agent/subturn_test.go +++ b/pkg/agent/subturn_test.go @@ -1650,6 +1650,38 @@ func TestGrandchildAbort_CascadingCancellation(t *testing.T) { } } +func TestNestedSubTurn_GracefulFinishSignalsDirectChildren(t *testing.T) { + parentCtx := context.Background() + parentTS := &turnState{ + ctx: parentCtx, + turnID: "parent-graceful", + depth: 1, + pendingResults: make(chan *tools.ToolResult, 16), + } + parentTS.ctx, parentTS.cancelFunc = context.WithCancel(parentCtx) + + childTS := &turnState{ + ctx: context.Background(), + turnID: "child-graceful", + depth: 2, + parentTurnState: parentTS, + pendingResults: make(chan *tools.ToolResult, 16), + } + + if childTS.IsParentEnded() { + t.Fatal("IsParentEnded should be false before parent finishes") + } + + parentTS.Finish(false) + + if !parentTS.parentEnded.Load() { + t.Fatal("parentEnded should be true after graceful finish") + } + if !childTS.IsParentEnded() { + t.Fatal("nested child should observe parent graceful finish") + } +} + // TestSpawnDuringAbort_RaceCondition verifies behavior when trying to spawn // a sub-turn while the parent is being aborted. func TestSpawnDuringAbort_RaceCondition(t *testing.T) { diff --git a/pkg/agent/turn_state.go b/pkg/agent/turn_state.go index edf8654b5..8b5fd4e2c 100644 --- a/pkg/agent/turn_state.go +++ b/pkg/agent/turn_state.go @@ -554,9 +554,9 @@ func (ts *turnState) Finish(isHardAbort bool) { ts.mu.Unlock() }) - // If this is a graceful finish (not hard abort), signal to children - if !isHardAbort && ts.parentTurnState == nil { - // This is a root turn finishing gracefully + // Any graceful finish must signal direct children so nested SubTurns can + // observe parent completion and decide whether to stop or continue. + if !isHardAbort { ts.parentEnded.Store(true) } diff --git a/pkg/channels/discord/discord.go b/pkg/channels/discord/discord.go index 28f7277d3..514b9b3b1 100644 --- a/pkg/channels/discord/discord.go +++ b/pkg/channels/discord/discord.go @@ -45,9 +45,12 @@ type DiscordChannel struct { cancel context.CancelFunc typingMu sync.Mutex typingStop map[string]chan struct{} // chatID → stop signal - botUserID string // stored for mention checking + progress *channels.ToolFeedbackAnimator + botUserID string // stored for mention checking bus *bus.MessageBus tts tts.TTSProvider + playTTSFn func(context.Context, *discordgo.VoiceConnection, string, uint64) + ttsVoiceFn func(string) (*discordgo.VoiceConnection, bool) voiceMu sync.RWMutex voiceSSRC map[string]map[uint32]string // guildID -> ssrc -> userID @@ -84,7 +87,7 @@ func NewDiscordChannel( channels.WithReasoningChannelID(bc.ReasoningChannelID), ) - return &DiscordChannel{ + ch := &DiscordChannel{ BaseChannel: base, bc: bc, session: session, @@ -93,7 +96,11 @@ func NewDiscordChannel( typingStop: make(map[string]chan struct{}), bus: bus, voiceSSRC: make(map[string]map[uint32]string), - }, nil + } + ch.playTTSFn = ch.playTTS + ch.ttsVoiceFn = ch.voiceConnectionForTTS + ch.progress = channels.NewToolFeedbackAnimator(ch.EditMessage) + return ch, nil } func (c *DiscordChannel) Start(ctx context.Context) error { @@ -142,6 +149,9 @@ func (c *DiscordChannel) Stop(ctx context.Context) error { if c.cancel != nil { c.cancel() } + if c.progress != nil { + c.progress.StopAll() + } if err := c.session.Close(); err != nil { return fmt.Errorf("failed to close discord session: %w", err) @@ -164,32 +174,88 @@ func (c *DiscordChannel) Send(ctx context.Context, msg bus.OutboundMessage) ([]s return nil, nil } - if c.tts != nil { - if ch, err := c.session.State.Channel(channelID); err == nil && ch.GuildID != "" { - if vc, ok := c.session.VoiceConnections[ch.GuildID]; ok && vc != nil { - // Cancel any previous TTS playback - c.ttsMu.Lock() - if c.cancelTTS != nil { - c.cancelTTS() - } - ttsCtx, ttsCancel := context.WithCancel(c.ctx) - c.ttsPlayID++ - playID := c.ttsPlayID - c.cancelTTS = ttsCancel - c.ttsMu.Unlock() - - go c.playTTS(ttsCtx, vc, msg.Content, playID) + isToolFeedback := outboundMessageIsToolFeedback(msg) + if isToolFeedback { + if msgID, handled, err := c.progress.Update(ctx, channelID, msg.Content); handled { + if err != nil { + return nil, err } + return []string{msgID}, nil + } + } + trackedMsgID, hasTrackedMsg := c.currentToolFeedbackMessage(channelID) + c.maybeStartTTS(channelID, msg.Content, isToolFeedback) + if !isToolFeedback { + if msgIDs, handled := c.FinalizeToolFeedbackMessage(ctx, msg); handled { + return msgIDs, nil } } - msgID, err := c.sendChunk(ctx, channelID, msg.Content, msg.ReplyToMessageID) + content := msg.Content + if isToolFeedback { + content = channels.InitialAnimatedToolFeedbackContent(msg.Content) + } + msgID, err := c.sendChunk(ctx, channelID, content, msg.ReplyToMessageID) if err != nil { return nil, err } + if isToolFeedback { + c.RecordToolFeedbackMessage(channelID, msgID, msg.Content) + } else if hasTrackedMsg { + c.dismissTrackedToolFeedbackMessage(ctx, channelID, trackedMsgID) + } return []string{msgID}, nil } +func (c *DiscordChannel) maybeStartTTS(channelID, content string, isToolFeedback bool) { + if c.tts == nil || isToolFeedback { + return + } + + voiceFn := c.ttsVoiceFn + if voiceFn == nil { + voiceFn = c.voiceConnectionForTTS + } + vc, ok := voiceFn(channelID) + if !ok || vc == nil { + return + } + + // Cancel any previous TTS playback. + c.ttsMu.Lock() + if c.cancelTTS != nil { + c.cancelTTS() + } + ttsCtx, ttsCancel := context.WithCancel(c.ctx) + c.ttsPlayID++ + playID := c.ttsPlayID + c.cancelTTS = ttsCancel + playFn := c.playTTSFn + c.ttsMu.Unlock() + + if playFn == nil { + playFn = c.playTTS + } + go playFn(ttsCtx, vc, content, playID) +} + +func (c *DiscordChannel) voiceConnectionForTTS(channelID string) (*discordgo.VoiceConnection, bool) { + if c.session == nil || c.session.State == nil { + return nil, false + } + + ch, err := c.session.State.Channel(channelID) + if err != nil || ch == nil || ch.GuildID == "" { + return nil, false + } + + vc, ok := c.session.VoiceConnections[ch.GuildID] + if !ok || vc == nil { + return nil, false + } + return vc, true +} + // SendMedia implements the channels.MediaSender interface. func (c *DiscordChannel) SendMedia(ctx context.Context, msg bus.OutboundMediaMessage) ([]string, error) { if !c.IsRunning() { @@ -200,6 +266,7 @@ func (c *DiscordChannel) SendMedia(ctx context.Context, msg bus.OutboundMediaMes if channelID == "" { return nil, fmt.Errorf("channel ID is empty") } + trackedMsgID, hasTrackedMsg := c.currentToolFeedbackMessage(channelID) store := c.GetMediaStore() if store == nil { @@ -281,6 +348,9 @@ func (c *DiscordChannel) SendMedia(ctx context.Context, msg bus.OutboundMediaMes if r.err != nil { return nil, fmt.Errorf("discord send media: %w", channels.ErrTemporary) } + if hasTrackedMsg { + c.dismissTrackedToolFeedbackMessage(ctx, channelID, trackedMsgID) + } return []string{r.id}, nil case <-sendCtx.Done(): // Close all file readers @@ -295,10 +365,15 @@ func (c *DiscordChannel) SendMedia(ctx context.Context, msg bus.OutboundMediaMes // EditMessage implements channels.MessageEditor. func (c *DiscordChannel) EditMessage(ctx context.Context, chatID string, messageID string, content string) error { - _, err := c.session.ChannelMessageEdit(chatID, messageID, content) + _, err := c.session.ChannelMessageEdit(chatID, messageID, content, discordgo.WithContext(ctx)) return err } +// DeleteMessage implements channels.MessageDeleter. +func (c *DiscordChannel) DeleteMessage(ctx context.Context, chatID string, messageID string) error { + return c.session.ChannelMessageDelete(chatID, messageID, discordgo.WithContext(ctx)) +} + // SendPlaceholder implements channels.PlaceholderCapable. // It sends a placeholder message that will later be edited to the actual // response via EditMessage (channels.MessageEditor). @@ -317,6 +392,81 @@ func (c *DiscordChannel) SendPlaceholder(ctx context.Context, chatID string) (st return msg.ID, nil } +func outboundMessageIsToolFeedback(msg bus.OutboundMessage) bool { + if len(msg.Context.Raw) == 0 { + return false + } + return strings.EqualFold(strings.TrimSpace(msg.Context.Raw["message_kind"]), "tool_feedback") +} + +func (c *DiscordChannel) currentToolFeedbackMessage(chatID string) (string, bool) { + if c.progress == nil { + return "", false + } + return c.progress.Current(chatID) +} + +func (c *DiscordChannel) takeToolFeedbackMessage(chatID string) (string, string, bool) { + if c.progress == nil { + return "", "", false + } + return c.progress.Take(chatID) +} + +func (c *DiscordChannel) RecordToolFeedbackMessage(chatID, messageID, content string) { + if c.progress == nil { + return + } + c.progress.Record(chatID, messageID, content) +} + +func (c *DiscordChannel) ClearToolFeedbackMessage(chatID string) { + if c.progress == nil { + return + } + c.progress.Clear(chatID) +} + +func (c *DiscordChannel) DismissToolFeedbackMessage(ctx context.Context, chatID string) { + msgID, ok := c.currentToolFeedbackMessage(chatID) + if !ok { + return + } + c.dismissTrackedToolFeedbackMessage(ctx, chatID, msgID) +} + +func (c *DiscordChannel) dismissTrackedToolFeedbackMessage(ctx context.Context, chatID, messageID string) { + if strings.TrimSpace(chatID) == "" || strings.TrimSpace(messageID) == "" { + return + } + c.ClearToolFeedbackMessage(chatID) + _ = c.DeleteMessage(ctx, chatID, messageID) +} + +func (c *DiscordChannel) finalizeTrackedToolFeedbackMessage( + ctx context.Context, + chatID string, + content string, + editFn func(context.Context, string, string, string) error, +) ([]string, bool) { + msgID, baseContent, ok := c.takeToolFeedbackMessage(chatID) + if !ok || editFn == nil { + return nil, false + } + if err := editFn(ctx, chatID, msgID, content); err != nil { + c.RecordToolFeedbackMessage(chatID, msgID, baseContent) + return nil, false + } + return []string{msgID}, true +} + +func (c *DiscordChannel) FinalizeToolFeedbackMessage(ctx context.Context, msg bus.OutboundMessage) ([]string, bool) { + if outboundMessageIsToolFeedback(msg) { + return nil, false + } + return c.finalizeTrackedToolFeedbackMessage(ctx, msg.ChatID, msg.Content, c.EditMessage) +} + func (c *DiscordChannel) sendChunk(ctx context.Context, channelID, content, replyToID string) (string, error) { // Use the passed ctx for timeout control sendCtx, cancel := context.WithTimeout(ctx, sendTimeout) diff --git a/pkg/channels/discord/discord_test.go b/pkg/channels/discord/discord_test.go index 0cd5328f4..d42b0bc52 100644 --- a/pkg/channels/discord/discord_test.go +++ b/pkg/channels/discord/discord_test.go @@ -1,13 +1,37 @@ package discord import ( + "context" + "io" "net/http" + "net/http/httptest" "net/url" + "reflect" + "sync" "testing" + "time" "github.com/bwmarrin/discordgo" + + "github.com/sipeed/picoclaw/pkg/audio/tts" + "github.com/sipeed/picoclaw/pkg/bus" + "github.com/sipeed/picoclaw/pkg/channels" ) +type stubTTSProvider struct{} + +func (stubTTSProvider) Name() string { return "stub-tts" } + +func (stubTTSProvider) Synthesize(context.Context, string) (io.ReadCloser, error) { + return io.NopCloser(&noopReader{}), nil +} + +type noopReader struct{} + +func (*noopReader) Read(p []byte) (int, error) { + return 0, io.EOF +} + func TestApplyDiscordProxy_CustomProxy(t *testing.T) { session, err := discordgo.New("Bot test-token") if err != nil { @@ -89,3 +113,224 @@ func TestApplyDiscordProxy_InvalidProxyURL(t *testing.T) { t.Fatal("applyDiscordProxy() expected error for invalid proxy URL, got nil") } } + +func TestSend_NonToolFeedbackDeletesTrackedProgressMessage(t *testing.T) { + var ( + mu sync.Mutex + requests []string + ) + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + mu.Lock() + requests = append(requests, r.Method+" "+r.URL.Path) + mu.Unlock() + + switch { + case r.Method == http.MethodPatch && r.URL.Path == "/channels/chat-1/messages/prog-1": + w.Header().Set("Content-Type", "application/json") + _, _ = io.WriteString(w, `{"id":"prog-1"}`) + default: + t.Fatalf("unexpected request: %s %s", r.Method, r.URL.Path) + } + })) + defer server.Close() + + origChannels := discordgo.EndpointChannels + discordgo.EndpointChannels = server.URL + "/channels/" + defer func() { + discordgo.EndpointChannels = origChannels + }() + + session, err := discordgo.New("Bot test-token") + if err != nil { + t.Fatalf("discordgo.New() error: %v", err) + } + session.Client = server.Client() + + ch := &DiscordChannel{ + BaseChannel: channels.NewBaseChannel("discord", nil, bus.NewMessageBus(), nil), + session: session, + ctx: context.Background(), + typingStop: make(map[string]chan struct{}), + voiceSSRC: make(map[string]map[uint32]string), + } + ch.progress = channels.NewToolFeedbackAnimator(ch.EditMessage) + ch.SetRunning(true) + ch.RecordToolFeedbackMessage("chat-1", "prog-1", "🔧 `read_file`") + + ids, err := ch.Send(context.Background(), bus.OutboundMessage{ + ChatID: "chat-1", + Content: "final reply", + Context: bus.InboundContext{ + Channel: "discord", + ChatID: "chat-1", + }, + }) + if err != nil { + t.Fatalf("Send() error = %v", err) + } + if got, want := ids, []string{"prog-1"}; !reflect.DeepEqual(got, want) { + t.Fatalf("Send() ids = %v, want %v", got, want) + } + if _, ok := ch.currentToolFeedbackMessage("chat-1"); ok { + t.Fatal("expected tracked tool feedback message to be cleared") + } + + mu.Lock() + defer mu.Unlock() + wantRequests := []string{ + "PATCH /channels/chat-1/messages/prog-1", + } + if !reflect.DeepEqual(requests, wantRequests) { + t.Fatalf("requests = %v, want %v", requests, wantRequests) + } +} + +func TestEditMessage_UsesContextCancellation(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + select { + case <-r.Context().Done(): + return + case <-time.After(time.Second): + w.Header().Set("Content-Type", "application/json") + _, _ = io.WriteString(w, `{"id":"msg-1"}`) + } + })) + defer server.Close() + + origChannels := discordgo.EndpointChannels + discordgo.EndpointChannels = server.URL + "/channels/" + defer func() { + discordgo.EndpointChannels = origChannels + }() + + session, err := discordgo.New("Bot test-token") + if err != nil { + t.Fatalf("discordgo.New() error: %v", err) + } + session.Client = server.Client() + + ch := &DiscordChannel{ + BaseChannel: channels.NewBaseChannel("discord", nil, bus.NewMessageBus(), nil), + session: session, + } + + ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond) + defer cancel() + + start := time.Now() + err = ch.EditMessage(ctx, "chat-1", "msg-1", "still running") + elapsed := time.Since(start) + + if err == nil { + t.Fatal("expected EditMessage() to fail when context times out") + } + if elapsed >= 500*time.Millisecond { + t.Fatalf("EditMessage() ignored context timeout, elapsed=%v", elapsed) + } +} + +func TestFinalizeTrackedToolFeedbackMessage_StopsTrackingBeforeEdit(t *testing.T) { + ch := &DiscordChannel{ + progress: channels.NewToolFeedbackAnimator(nil), + } + ch.RecordToolFeedbackMessage("chat-1", "msg-1", "🔧 `read_file`") + + msgIDs, handled := ch.finalizeTrackedToolFeedbackMessage( + context.Background(), + "chat-1", + "final reply", + func(_ context.Context, chatID, messageID, content string) error { + if _, ok := ch.currentToolFeedbackMessage(chatID); ok { + t.Fatal("expected tracked tool feedback to be stopped before edit") + } + if chatID != "chat-1" || messageID != "msg-1" || content != "final reply" { + t.Fatalf("unexpected edit args: %s %s %s", chatID, messageID, content) + } + return nil + }, + ) + if !handled { + t.Fatal("expected finalizeTrackedToolFeedbackMessage to handle tracked message") + } + if got, want := msgIDs, []string{"msg-1"}; !reflect.DeepEqual(got, want) { + t.Fatalf("finalizeTrackedToolFeedbackMessage() ids = %v, want %v", got, want) + } +} + +func TestSend_NonToolFeedbackFinalizerStillStartsTTS(t *testing.T) { + var ( + mu sync.Mutex + requests []string + ) + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + mu.Lock() + requests = append(requests, r.Method+" "+r.URL.Path) + mu.Unlock() + + switch { + case r.Method == http.MethodPatch && r.URL.Path == "/channels/chat-1/messages/prog-1": + w.Header().Set("Content-Type", "application/json") + _, _ = io.WriteString(w, `{"id":"prog-1"}`) + default: + t.Fatalf("unexpected request: %s %s", r.Method, r.URL.Path) + } + })) + defer server.Close() + + origChannels := discordgo.EndpointChannels + discordgo.EndpointChannels = server.URL + "/channels/" + defer func() { + discordgo.EndpointChannels = origChannels + }() + + session, err := discordgo.New("Bot test-token") + if err != nil { + t.Fatalf("discordgo.New() error: %v", err) + } + session.Client = server.Client() + + ttsStarted := make(chan string, 1) + ch := &DiscordChannel{ + BaseChannel: channels.NewBaseChannel("discord", nil, bus.NewMessageBus(), nil), + session: session, + ctx: context.Background(), + typingStop: make(map[string]chan struct{}), + voiceSSRC: make(map[string]map[uint32]string), + tts: tts.TTSProvider(stubTTSProvider{}), + } + ch.ttsVoiceFn = func(string) (*discordgo.VoiceConnection, bool) { + return &discordgo.VoiceConnection{}, true + } + ch.playTTSFn = func(_ context.Context, _ *discordgo.VoiceConnection, text string, _ uint64) { + ttsStarted <- text + } + ch.progress = channels.NewToolFeedbackAnimator(ch.EditMessage) + ch.SetRunning(true) + ch.RecordToolFeedbackMessage("chat-1", "prog-1", "🔧 `read_file`") + + ids, err := ch.Send(context.Background(), bus.OutboundMessage{ + ChatID: "chat-1", + Content: "final reply", + Context: bus.InboundContext{ + Channel: "discord", + ChatID: "chat-1", + }, + }) + if err != nil { + t.Fatalf("Send() error = %v", err) + } + if got, want := ids, []string{"prog-1"}; !reflect.DeepEqual(got, want) { + t.Fatalf("Send() ids = %v, want %v", got, want) + } + + select { + case got := <-ttsStarted: + if got != "final reply" { + t.Fatalf("TTS content = %q, want final reply", got) + } + case <-time.After(2 * time.Second): + t.Fatal("expected TTS to start for finalized tracked tool feedback reply") + } +} diff --git a/pkg/channels/feishu/feishu_64.go b/pkg/channels/feishu/feishu_64.go index 02ee47d69..8f3ae39d9 100644 --- a/pkg/channels/feishu/feishu_64.go +++ b/pkg/channels/feishu/feishu_64.go @@ -49,6 +49,9 @@ type FeishuChannel struct { mu sync.Mutex cancel context.CancelFunc + + progress *channels.ToolFeedbackAnimator + deleteMessageFn func(context.Context, string, string) error } type cachedMessage struct { @@ -74,6 +77,8 @@ func NewFeishuChannel(bc *config.Channel, cfg *config.FeishuSettings, bus *bus.M tokenCache: tc, client: lark.NewClient(cfg.AppID, cfg.AppSecret.String(), opts...), } + ch.deleteMessageFn = ch.deleteMessageAPI + ch.progress = channels.NewToolFeedbackAnimator(ch.EditMessage) ch.SetOwner(ch) return ch, nil } @@ -132,6 +137,9 @@ func (c *FeishuChannel) Stop(ctx context.Context) error { } c.wsClient = nil c.mu.Unlock() + if c.progress != nil { + c.progress.StopAll() + } c.SetRunning(false) logger.InfoC("feishu", "Feishu channel stopped") @@ -149,17 +157,55 @@ func (c *FeishuChannel) Send(ctx context.Context, msg bus.OutboundMessage) ([]st return nil, fmt.Errorf("chat ID is empty: %w", channels.ErrSendFailed) } + isToolFeedback := outboundMessageIsToolFeedback(msg) + if isToolFeedback { + if msgID, handled, err := c.progress.Update(ctx, msg.ChatID, msg.Content); handled { + if err != nil { + // Feishu can fall back to plain text for a previous progress + // message, and those messages cannot be patched through the card + // edit API. Drop the stale tracker and recreate the progress + // message so later tool feedback is not blocked. + c.resetTrackedToolFeedbackAfterEditFailure(ctx, msg.ChatID) + } else { + return []string{msgID}, nil + } + } + } else { + if msgIDs, handled := c.FinalizeToolFeedbackMessage(ctx, msg); handled { + return msgIDs, nil + } + } + trackedMsgID, hasTrackedMsg := c.currentToolFeedbackMessage(msg.ChatID) + // Build interactive card with markdown content - cardContent, err := buildMarkdownCard(msg.Content) + sendContent := msg.Content + if isToolFeedback { + sendContent = channels.InitialAnimatedToolFeedbackContent(msg.Content) + } + cardContent, err := buildMarkdownCard(sendContent) if err != nil { // If card build fails, fall back to plain text - return nil, c.sendText(ctx, msg.ChatID, msg.Content) + msgID, sendErr := c.sendText(ctx, msg.ChatID, sendContent) + if sendErr != nil { + return nil, sendErr + } + if isToolFeedback { + c.RecordToolFeedbackMessage(msg.ChatID, msgID, msg.Content) + } else if hasTrackedMsg { + c.dismissTrackedToolFeedbackMessage(ctx, msg.ChatID, trackedMsgID) + } + return []string{msgID}, nil } // First attempt: try sending as interactive card - err = c.sendCard(ctx, msg.ChatID, cardContent) + msgID, err := c.sendCard(ctx, msg.ChatID, cardContent) if err == nil { - return nil, nil + if isToolFeedback { + c.RecordToolFeedbackMessage(msg.ChatID, msgID, msg.Content) + } else if hasTrackedMsg { + c.dismissTrackedToolFeedbackMessage(ctx, msg.ChatID, trackedMsgID) + } + return []string{msgID}, nil } // Check if error is due to card table limit (error code 11310) @@ -174,9 +220,14 @@ func (c *FeishuChannel) Send(ctx context.Context, msg bus.OutboundMessage) ([]st }) // Second attempt: fall back to plain text message - textErr := c.sendText(ctx, msg.ChatID, msg.Content) + msgID, textErr := c.sendText(ctx, msg.ChatID, sendContent) if textErr == nil { - return nil, nil + if isToolFeedback { + c.RecordToolFeedbackMessage(msg.ChatID, msgID, msg.Content) + } else if hasTrackedMsg { + c.dismissTrackedToolFeedbackMessage(ctx, msg.ChatID, trackedMsgID) + } + return []string{msgID}, nil } // If text also fails, return the text error return nil, textErr @@ -210,6 +261,31 @@ func (c *FeishuChannel) EditMessage(ctx context.Context, chatID, messageID, cont return nil } +// DeleteMessage implements channels.MessageDeleter. +func (c *FeishuChannel) DeleteMessage(ctx context.Context, chatID, messageID string) error { + deleteFn := c.deleteMessageFn + if deleteFn == nil { + deleteFn = c.deleteMessageAPI + } + return deleteFn(ctx, chatID, messageID) +} + +func (c *FeishuChannel) deleteMessageAPI(ctx context.Context, chatID, messageID string) error { + req := larkim.NewDeleteMessageReqBuilder(). + MessageId(messageID). + Build() + + resp, err := c.client.Im.V1.Message.Delete(ctx, req) + if err != nil { + return fmt.Errorf("feishu delete: %w", err) + } + if !resp.Success() { + c.invalidateTokenOnAuthError(resp.Code) + return fmt.Errorf("feishu delete api error (code=%d msg=%s)", resp.Code, resp.Msg) + } + return nil +} + // SendPlaceholder implements channels.PlaceholderCapable. // Sends an interactive card with placeholder text and returns its message ID. func (c *FeishuChannel) SendPlaceholder(ctx context.Context, chatID string) (string, error) { @@ -251,6 +327,93 @@ func (c *FeishuChannel) SendPlaceholder(ctx context.Context, chatID string) (str return "", nil } +func outboundMessageIsToolFeedback(msg bus.OutboundMessage) bool { + if len(msg.Context.Raw) == 0 { + return false + } + return strings.EqualFold(strings.TrimSpace(msg.Context.Raw["message_kind"]), "tool_feedback") +} + +func (c *FeishuChannel) currentToolFeedbackMessage(chatID string) (string, bool) { + if c.progress == nil { + return "", false + } + return c.progress.Current(chatID) +} + +func (c *FeishuChannel) takeToolFeedbackMessage(chatID string) (string, string, bool) { + if c.progress == nil { + return "", "", false + } + return c.progress.Take(chatID) +} + +func (c *FeishuChannel) RecordToolFeedbackMessage(chatID, messageID, content string) { + if c.progress == nil { + return + } + c.progress.Record(chatID, messageID, content) +} + +func (c *FeishuChannel) ClearToolFeedbackMessage(chatID string) { + if c.progress == nil { + return + } + c.progress.Clear(chatID) +} + +func (c *FeishuChannel) DismissToolFeedbackMessage(ctx context.Context, chatID string) { + msgID, ok := c.currentToolFeedbackMessage(chatID) + if !ok { + return + } + c.dismissTrackedToolFeedbackMessage(ctx, chatID, msgID) +} + +func (c *FeishuChannel) resetTrackedToolFeedbackAfterEditFailure(ctx context.Context, chatID string) { + msgID, ok := c.currentToolFeedbackMessage(chatID) + if !ok { + return + } + c.dismissTrackedToolFeedbackMessage(ctx, chatID, msgID) +} + +func (c *FeishuChannel) dismissTrackedToolFeedbackMessage(ctx context.Context, chatID, messageID string) { + if strings.TrimSpace(chatID) == "" || strings.TrimSpace(messageID) == "" { + return + } + c.ClearToolFeedbackMessage(chatID) + deleteFn := c.deleteMessageFn + if deleteFn == nil { + deleteFn = c.deleteMessageAPI + } + _ = deleteFn(ctx, chatID, messageID) +} + +func (c *FeishuChannel) finalizeTrackedToolFeedbackMessage( + ctx context.Context, + chatID string, + content string, + editFn func(context.Context, string, string, string) error, +) ([]string, bool) { + msgID, baseContent, ok := c.takeToolFeedbackMessage(chatID) + if !ok || editFn == nil { + return nil, false + } + if err := editFn(ctx, chatID, msgID, content); err != nil { + c.RecordToolFeedbackMessage(chatID, msgID, baseContent) + return nil, false + } + return []string{msgID}, true +} + +func (c *FeishuChannel) FinalizeToolFeedbackMessage(ctx context.Context, msg bus.OutboundMessage) ([]string, bool) { + if outboundMessageIsToolFeedback(msg) { + return nil, false + } + return c.finalizeTrackedToolFeedbackMessage(ctx, msg.ChatID, msg.Content, c.EditMessage) +} + // ReactToMessage implements channels.ReactionCapable. // Adds a reaction (randomly chosen from config) and returns an undo function to remove it. func (c *FeishuChannel) ReactToMessage(ctx context.Context, chatID, messageID string) (func(), error) { @@ -323,6 +486,7 @@ func (c *FeishuChannel) SendMedia(ctx context.Context, msg bus.OutboundMediaMess if !c.IsRunning() { return nil, channels.ErrNotRunning } + trackedMsgID, hasTrackedMsg := c.currentToolFeedbackMessage(msg.ChatID) if msg.ChatID == "" { return nil, fmt.Errorf("chat ID is empty: %w", channels.ErrSendFailed) @@ -339,6 +503,10 @@ func (c *FeishuChannel) SendMedia(ctx context.Context, msg bus.OutboundMediaMess } } + if hasTrackedMsg { + c.dismissTrackedToolFeedbackMessage(ctx, msg.ChatID, trackedMsgID) + } + return nil, nil } @@ -801,7 +969,7 @@ func appendMediaTags(content, messageType string, mediaRefs []string) string { } // sendCard sends an interactive card message to a chat. -func (c *FeishuChannel) sendCard(ctx context.Context, chatID, cardContent string) error { +func (c *FeishuChannel) sendCard(ctx context.Context, chatID, cardContent string) (string, error) { req := larkim.NewCreateMessageReqBuilder(). ReceiveIdType(larkim.ReceiveIdTypeChatId). Body(larkim.NewCreateMessageReqBodyBuilder(). @@ -813,23 +981,26 @@ func (c *FeishuChannel) sendCard(ctx context.Context, chatID, cardContent string resp, err := c.client.Im.V1.Message.Create(ctx, req) if err != nil { - return fmt.Errorf("feishu send card: %w", channels.ErrTemporary) + return "", fmt.Errorf("feishu send card: %w", channels.ErrTemporary) } if !resp.Success() { c.invalidateTokenOnAuthError(resp.Code) - return fmt.Errorf("feishu api error (code=%d msg=%s): %w", resp.Code, resp.Msg, channels.ErrTemporary) + return "", fmt.Errorf("feishu api error (code=%d msg=%s): %w", resp.Code, resp.Msg, channels.ErrTemporary) } logger.DebugCF("feishu", "Feishu card message sent", map[string]any{ "chat_id": chatID, }) - return nil + if resp.Data != nil && resp.Data.MessageId != nil { + return *resp.Data.MessageId, nil + } + return "", nil } // sendText sends a plain text message to a chat (fallback when card fails). -func (c *FeishuChannel) sendText(ctx context.Context, chatID, text string) error { +func (c *FeishuChannel) sendText(ctx context.Context, chatID, text string) (string, error) { content, _ := json.Marshal(map[string]string{"text": text}) req := larkim.NewCreateMessageReqBuilder(). @@ -843,18 +1014,21 @@ func (c *FeishuChannel) sendText(ctx context.Context, chatID, text string) error resp, err := c.client.Im.V1.Message.Create(ctx, req) if err != nil { - return fmt.Errorf("feishu send text: %w", channels.ErrTemporary) + return "", fmt.Errorf("feishu send text: %w", channels.ErrTemporary) } if !resp.Success() { - return fmt.Errorf("feishu text api error (code=%d msg=%s): %w", resp.Code, resp.Msg, channels.ErrTemporary) + return "", fmt.Errorf("feishu text api error (code=%d msg=%s): %w", resp.Code, resp.Msg, channels.ErrTemporary) } logger.DebugCF("feishu", "Feishu text message sent (fallback)", map[string]any{ "chat_id": chatID, }) - return nil + if resp.Data != nil && resp.Data.MessageId != nil { + return *resp.Data.MessageId, nil + } + return "", nil } // sendImage uploads an image and sends it as a message. diff --git a/pkg/channels/feishu/feishu_64_test.go b/pkg/channels/feishu/feishu_64_test.go index 9010abf69..48fdf0f74 100644 --- a/pkg/channels/feishu/feishu_64_test.go +++ b/pkg/channels/feishu/feishu_64_test.go @@ -3,9 +3,13 @@ package feishu import ( + "context" + "errors" "testing" larkim "github.com/larksuite/oapi-sdk-go/v3/service/im/v1" + + "github.com/sipeed/picoclaw/pkg/channels" ) func TestExtractContent(t *testing.T) { @@ -279,3 +283,110 @@ func TestExtractFeishuSenderID(t *testing.T) { }) } } + +func TestFinalizeTrackedToolFeedbackMessage_ClearAfterSuccessfulEdit(t *testing.T) { + ch := &FeishuChannel{ + progress: channels.NewToolFeedbackAnimator(nil), + } + ch.RecordToolFeedbackMessage("chat-1", "msg-1", "🔧 `read_file`") + + msgIDs, handled := ch.finalizeTrackedToolFeedbackMessage( + context.Background(), + "chat-1", + "final reply", + func(_ context.Context, chatID, messageID, content string) error { + if chatID != "chat-1" || messageID != "msg-1" || content != "final reply" { + t.Fatalf("unexpected edit args: %s %s %s", chatID, messageID, content) + } + return nil + }, + ) + if !handled { + t.Fatal("expected finalizeTrackedToolFeedbackMessage to handle tracked message") + } + if len(msgIDs) != 1 || msgIDs[0] != "msg-1" { + t.Fatalf("unexpected msgIDs: %v", msgIDs) + } + if _, ok := ch.currentToolFeedbackMessage("chat-1"); ok { + t.Fatal("expected tracked tool feedback to be cleared after successful edit") + } +} + +func TestFinalizeTrackedToolFeedbackMessage_StopsTrackingBeforeEdit(t *testing.T) { + ch := &FeishuChannel{ + progress: channels.NewToolFeedbackAnimator(nil), + } + ch.RecordToolFeedbackMessage("chat-1", "msg-1", "🔧 `read_file`") + + msgIDs, handled := ch.finalizeTrackedToolFeedbackMessage( + context.Background(), + "chat-1", + "final reply", + func(_ context.Context, chatID, messageID, content string) error { + if _, ok := ch.currentToolFeedbackMessage(chatID); ok { + t.Fatal("expected tracked tool feedback to be stopped before edit") + } + if chatID != "chat-1" || messageID != "msg-1" || content != "final reply" { + t.Fatalf("unexpected edit args: %s %s %s", chatID, messageID, content) + } + return nil + }, + ) + if !handled { + t.Fatal("expected finalizeTrackedToolFeedbackMessage to handle tracked message") + } + if len(msgIDs) != 1 || msgIDs[0] != "msg-1" { + t.Fatalf("unexpected msgIDs: %v", msgIDs) + } +} + +func TestFinalizeTrackedToolFeedbackMessage_EditFailureKeepsTrackedMessage(t *testing.T) { + ch := &FeishuChannel{ + progress: channels.NewToolFeedbackAnimator(nil), + } + ch.RecordToolFeedbackMessage("chat-1", "msg-1", "🔧 `read_file`") + + msgIDs, handled := ch.finalizeTrackedToolFeedbackMessage( + context.Background(), + "chat-1", + "final reply", + func(context.Context, string, string, string) error { + return errors.New("edit failed") + }, + ) + if handled { + t.Fatal("expected finalizeTrackedToolFeedbackMessage to report unhandled on edit failure") + } + if len(msgIDs) != 0 { + t.Fatalf("unexpected msgIDs: %v", msgIDs) + } + if msgID, ok := ch.currentToolFeedbackMessage("chat-1"); !ok || msgID != "msg-1" { + t.Fatalf("expected tracked tool feedback to remain after failed edit, got (%q, %v)", msgID, ok) + } +} + +func TestResetTrackedToolFeedbackAfterEditFailure_DismissesTrackedMessage(t *testing.T) { + var ( + deletedChatID string + deletedMsgID string + ) + + ch := &FeishuChannel{ + progress: channels.NewToolFeedbackAnimator(nil), + deleteMessageFn: func(_ context.Context, chatID, messageID string) error { + deletedChatID = chatID + deletedMsgID = messageID + return nil + }, + } + ch.RecordToolFeedbackMessage("chat-1", "msg-1", "🔧 `read_file`") + + ch.resetTrackedToolFeedbackAfterEditFailure(context.Background(), "chat-1") + + if deletedChatID != "chat-1" || deletedMsgID != "msg-1" { + t.Fatalf("unexpected delete target: chat=%q msg=%q", deletedChatID, deletedMsgID) + } + if _, ok := ch.currentToolFeedbackMessage("chat-1"); ok { + t.Fatal("expected tracked tool feedback to be cleared after edit failure reset") + } +} diff --git a/pkg/channels/manager.go b/pkg/channels/manager.go index 928676cbc..2ffb1bb10 100644 --- a/pkg/channels/manager.go +++ b/pkg/channels/manager.go @@ -14,6 +14,7 @@ import ( "net" "net/http" "sort" + "strings" "sync" "time" @@ -25,6 +26,7 @@ import ( "github.com/sipeed/picoclaw/pkg/health" "github.com/sipeed/picoclaw/pkg/logger" "github.com/sipeed/picoclaw/pkg/media" + "github.com/sipeed/picoclaw/pkg/utils" ) const ( @@ -96,6 +98,23 @@ type Manager struct { channelHashes map[string]string // channel name → config hash } +type toolFeedbackMessageTracker interface { + RecordToolFeedbackMessage(chatID, messageID, content string) + ClearToolFeedbackMessage(chatID string) +} + +type toolFeedbackMessageCleaner interface { + DismissToolFeedbackMessage(ctx context.Context, chatID string) +} + +type toolFeedbackMessageTargetResolver interface { + ToolFeedbackMessageChatID(chatID string, outboundCtx *bus.InboundContext) string +} + +type toolFeedbackMessageContentPreparer interface { + PrepareToolFeedbackMessageContent(content string) string +} + type asyncTask struct { cancel context.CancelFunc } @@ -108,6 +127,13 @@ func outboundMessageChatID(msg bus.OutboundMessage) string { return msg.ChatID } +func outboundMessageIsToolFeedback(msg bus.OutboundMessage) bool { + if len(msg.Context.Raw) == 0 { + return false + } + return strings.EqualFold(strings.TrimSpace(msg.Context.Raw["message_kind"]), "tool_feedback") +} + func outboundMediaChannel(msg bus.OutboundMediaMessage) string { return msg.Context.Channel } @@ -116,6 +142,47 @@ func outboundMediaChatID(msg bus.OutboundMediaMessage) string { return msg.ChatID } +func trackedToolFeedbackMessageChatID(ch Channel, chatID string, outboundCtx *bus.InboundContext) string { + if resolver, ok := ch.(toolFeedbackMessageTargetResolver); ok { + if resolved := strings.TrimSpace(resolver.ToolFeedbackMessageChatID(chatID, outboundCtx)); resolved != "" { + return resolved + } + } + return strings.TrimSpace(chatID) +} + +func dismissTrackedToolFeedbackMessage( + ctx context.Context, + ch Channel, + chatID string, + outboundCtx *bus.InboundContext, +) { + trackedChatID := trackedToolFeedbackMessageChatID(ch, chatID, outboundCtx) + if trackedChatID == "" { + return + } + if cleaner, ok := ch.(toolFeedbackMessageCleaner); ok { + cleaner.DismissToolFeedbackMessage(ctx, trackedChatID) + return + } + if tracker, ok := ch.(toolFeedbackMessageTracker); ok { + tracker.ClearToolFeedbackMessage(trackedChatID) + } +} + +func prepareToolFeedbackMessageContent(ch Channel, content string) string { + prepared := strings.TrimSpace(content) + if prepared == "" { + return "" + } + if preparer, ok := ch.(toolFeedbackMessageContentPreparer); ok { + if candidate := strings.TrimSpace(preparer.PrepareToolFeedbackMessageContent(prepared)); candidate != "" { + return candidate + } + } + return prepared +} + // RecordPlaceholder registers a placeholder message for later editing. // Implements PlaceholderRecorder. func (m *Manager) RecordPlaceholder(channel, chatID, placeholderID string) { @@ -196,7 +263,19 @@ func (m *Manager) preSend(ctx context.Context, name string, msg bus.OutboundMess } } - // 3. If a stream already finalized this message, delete the placeholder and skip send + isToolFeedback := outboundMessageIsToolFeedback(msg) + + // 3. If a stream already finalized this chat, stale tool feedback must be + // dropped without consuming the final-response marker. Streaming finalization + // bypasses the worker queue, so older queued feedback can arrive before the + // normal final outbound message that cleans up the marker and placeholder. + if isToolFeedback { + if _, loaded := m.streamActive.Load(key); loaded { + return nil, true + } + } + + // 4. If a stream already finalized this message, delete the placeholder and skip send if _, loaded := m.streamActive.LoadAndDelete(key); loaded { if v, loaded := m.placeholders.LoadAndDelete(key); loaded { if entry, ok := v.(placeholderEntry); ok && entry.id != "" { @@ -208,14 +287,29 @@ func (m *Manager) preSend(ctx context.Context, name string, msg bus.OutboundMess } } } + if !isToolFeedback { + dismissTrackedToolFeedbackMessage(ctx, ch, chatID, &msg.Context) + } return nil, true } - // 4. Try editing placeholder + // 5. Try editing placeholder if v, loaded := m.placeholders.LoadAndDelete(key); loaded { if entry, ok := v.(placeholderEntry); ok && entry.id != "" { if editor, ok := ch.(MessageEditor); ok { - if err := editor.EditMessage(ctx, chatID, entry.id, msg.Content); err == nil { + content := msg.Content + trackedContent := msg.Content + if isToolFeedback { + trackedContent = prepareToolFeedbackMessageContent(ch, msg.Content) + content = InitialAnimatedToolFeedbackContent(trackedContent) + } + if err := editor.EditMessage(ctx, chatID, entry.id, content); err == nil { + trackedChatID := trackedToolFeedbackMessageChatID(ch, chatID, &msg.Context) + if tracker, ok := ch.(toolFeedbackMessageTracker); ok && isToolFeedback { + tracker.RecordToolFeedbackMessage(trackedChatID, entry.id, trackedContent) + } else if !isToolFeedback { + dismissTrackedToolFeedbackMessage(ctx, ch, chatID, &msg.Context) + } return []string{entry.id}, true } // edit failed → fall through to normal Send @@ -312,22 +406,35 @@ func (m *Manager) GetStreamer(ctx context.Context, channelName, chatID string) ( // Mark streamActive on Finalize so preSend knows to clean up the placeholder key := channelName + ":" + chatID return &finalizeHookStreamer{ - Streamer: streamer, - onFinalize: func() { m.streamActive.Store(key, true) }, + Streamer: streamer, + onFinalize: func(finalizeCtx context.Context) { + dismissTrackedToolFeedbackMessage( + finalizeCtx, + ch, + chatID, + &bus.InboundContext{ + Channel: channelName, + ChatID: chatID, + }, + ) + m.streamActive.Store(key, true) + }, }, true } // finalizeHookStreamer wraps a Streamer to run a hook on Finalize. type finalizeHookStreamer struct { Streamer - onFinalize func() + onFinalize func(context.Context) } func (s *finalizeHookStreamer) Finalize(ctx context.Context, content string) error { if err := s.Streamer.Finalize(ctx, content); err != nil { return err } - s.onFinalize() + if s.onFinalize != nil { + s.onFinalize(ctx) + } return nil } @@ -769,18 +876,21 @@ func (m *Manager) runWorker(ctx context.Context, name string, w *channelWorker) // Collect all message chunks to send var chunks []string - // Step 1: Try marker-based splitting if enabled - if m.config != nil && m.config.Agents.Defaults.SplitOnMarker { + // Step 1: Try marker-based splitting if enabled. + // Tool feedback must stay a single message, so it skips marker splitting. + if m.config != nil && m.config.Agents.Defaults.SplitOnMarker && !outboundMessageIsToolFeedback(msg) { if markerChunks := SplitByMarker(msg.Content); len(markerChunks) > 1 { for _, chunk := range markerChunks { - chunks = append(chunks, splitByLength(chunk, maxLen)...) + chunkMsg := msg + chunkMsg.Content = chunk + chunks = append(chunks, splitOutboundMessageContent(chunkMsg, maxLen)...) } } } // Step 2: Fallback to length-based splitting if no chunks from marker if len(chunks) == 0 { - chunks = splitByLength(msg.Content, maxLen) + chunks = splitOutboundMessageContent(msg, maxLen) } // Step 3: Send all chunks @@ -795,12 +905,25 @@ func (m *Manager) runWorker(ctx context.Context, name string, w *channelWorker) } } -// splitByLength splits content by maxLen if needed, otherwise returns single chunk. -func splitByLength(content string, maxLen int) []string { - if maxLen > 0 && len([]rune(content)) > maxLen { - return SplitMessage(content, maxLen) +// splitOutboundMessageContent splits regular outbound content by maxLen, but +// keeps tool feedback in a single message by truncating the explanation body. +func splitOutboundMessageContent(msg bus.OutboundMessage, maxLen int) []string { + if maxLen > 0 { + if outboundMessageIsToolFeedback(msg) { + animationSafeLen := maxLen - MaxToolFeedbackAnimationFrameLength() + if animationSafeLen <= 0 { + animationSafeLen = maxLen + } + if len([]rune(msg.Content)) > animationSafeLen { + return []string{utils.FitToolFeedbackMessage(msg.Content, animationSafeLen)} + } + return []string{msg.Content} + } + if len([]rune(msg.Content)) > maxLen { + return SplitMessage(msg.Content, maxLen) + } } - return []string{content} + return []string{msg.Content} } // sendWithRetry sends a message through the channel with rate limiting and @@ -1264,13 +1387,16 @@ func (m *Manager) SendMessage(ctx context.Context, msg bus.OutboundMessage) erro if mlp, ok := w.ch.(MessageLengthProvider); ok { maxLen = mlp.MaxMessageLength() } - if maxLen > 0 && len([]rune(msg.Content)) > maxLen { - for _, chunk := range SplitMessage(msg.Content, maxLen) { + if chunks := splitOutboundMessageContent(msg, maxLen); len(chunks) > 1 { + for _, chunk := range chunks { chunkMsg := msg chunkMsg.Content = chunk m.sendWithRetry(ctx, channelName, w, chunkMsg) } } else { + if len(chunks) == 1 { + msg.Content = chunks[0] + } m.sendWithRetry(ctx, channelName, w, msg) } return nil diff --git a/pkg/channels/manager_test.go b/pkg/channels/manager_test.go index 881993d9c..273c90468 100644 --- a/pkg/channels/manager_test.go +++ b/pkg/channels/manager_test.go @@ -13,6 +13,8 @@ import ( "golang.org/x/time/rate" "github.com/sipeed/picoclaw/pkg/bus" + "github.com/sipeed/picoclaw/pkg/config" + "github.com/sipeed/picoclaw/pkg/utils" ) // mockChannel is a test double that delegates Send to a configurable function. @@ -76,8 +78,9 @@ func (m *mockMediaChannel) SendMedia(ctx context.Context, msg bus.OutboundMediaM type mockDeletingMediaChannel struct { mockMediaChannel - deleteCalls int - lastDeleted struct { + deleteCalls int + dismissedChatID string + lastDeleted struct { chatID string messageID string } @@ -94,6 +97,48 @@ func (m *mockDeletingMediaChannel) DeleteMessage( return nil } +func (m *mockDeletingMediaChannel) DismissToolFeedbackMessage(_ context.Context, chatID string) { + m.dismissedChatID = chatID +} + +type mockStreamer struct { + finalizeFn func(context.Context, string) error +} + +func (m *mockStreamer) Update(context.Context, string) error { return nil } + +func (m *mockStreamer) Finalize(ctx context.Context, content string) error { + if m.finalizeFn != nil { + return m.finalizeFn(ctx, content) + } + return nil +} + +func (m *mockStreamer) Cancel(context.Context) {} + +type mockStreamingChannel struct { + mockMessageEditor + streamer Streamer + resolveChatIDFn func(chatID string, outboundCtx *bus.InboundContext) string +} + +func (m *mockStreamingChannel) BeginStream(context.Context, string) (Streamer, error) { + if m.streamer == nil { + return nil, errors.New("missing streamer") + } + return m.streamer, nil +} + +func (m *mockStreamingChannel) ToolFeedbackMessageChatID( + chatID string, + outboundCtx *bus.InboundContext, +) string { + if m.resolveChatIDFn != nil { + return m.resolveChatIDFn(chatID, outboundCtx) + } + return chatID +} + // newTestManager creates a minimal Manager suitable for unit tests. func newTestManager() *Manager { return &Manager{ @@ -715,13 +760,72 @@ func TestSendWithRetry_ExponentialBackoff(t *testing.T) { // mockMessageEditor is a channel that supports MessageEditor. type mockMessageEditor struct { mockChannel - editFn func(ctx context.Context, chatID, messageID, content string) error + editFn func(ctx context.Context, chatID, messageID, content string) error + finalizeFn func(ctx context.Context, msg bus.OutboundMessage) ([]string, bool) + finalizeCalled bool + recordedChatID string + recordedMessageID string + recordedContent string + clearedChatID string + dismissedChatID string } func (m *mockMessageEditor) EditMessage(ctx context.Context, chatID, messageID, content string) error { return m.editFn(ctx, chatID, messageID, content) } +func (m *mockMessageEditor) RecordToolFeedbackMessage(chatID, messageID, content string) { + m.recordedChatID = chatID + m.recordedMessageID = messageID + m.recordedContent = content +} + +func (m *mockMessageEditor) ClearToolFeedbackMessage(chatID string) { + m.clearedChatID = chatID +} + +func (m *mockMessageEditor) DismissToolFeedbackMessage(_ context.Context, chatID string) { + m.dismissedChatID = chatID +} + +func (m *mockMessageEditor) FinalizeToolFeedbackMessage( + ctx context.Context, + msg bus.OutboundMessage, +) ([]string, bool) { + m.finalizeCalled = true + if m.finalizeFn == nil { + return nil, false + } + return m.finalizeFn(ctx, msg) +} + +type mockResolvedToolFeedbackEditor struct { + mockMessageEditor + resolveChatIDFn func(chatID string, outboundCtx *bus.InboundContext) string +} + +func (m *mockResolvedToolFeedbackEditor) ToolFeedbackMessageChatID( + chatID string, + outboundCtx *bus.InboundContext, +) string { + if m.resolveChatIDFn != nil { + return m.resolveChatIDFn(chatID, outboundCtx) + } + return chatID +} + +type mockPreparedToolFeedbackEditor struct { + mockMessageEditor + prepareFn func(content string) string +} + +func (m *mockPreparedToolFeedbackEditor) PrepareToolFeedbackMessageContent(content string) string { + if m.prepareFn != nil { + return m.prepareFn(content) + } + return content +} + func TestPreSend_PlaceholderEditSuccess(t *testing.T) { m := newTestManager() var sendCalled bool @@ -766,6 +870,539 @@ func TestPreSend_PlaceholderEditSuccess(t *testing.T) { } } +func TestPreSend_ToolFeedbackPlaceholderEditRecordsTrackedMessage(t *testing.T) { + m := newTestManager() + + ch := &mockMessageEditor{ + editFn: func(_ context.Context, chatID, messageID, content string) error { + if chatID != "123" || messageID != "456" || content != "hello" { + t.Fatalf("unexpected edit args: %s %s %s", chatID, messageID, content) + } + return nil + }, + } + + m.RecordPlaceholder("test", "123", "456") + + msg := testOutboundMessage(bus.OutboundMessage{ + Channel: "test", + ChatID: "123", + Content: "hello", + Context: bus.InboundContext{ + Channel: "test", + ChatID: "123", + Raw: map[string]string{ + "message_kind": "tool_feedback", + }, + }, + }) + _, edited := m.preSend(context.Background(), "test", msg, ch) + if !edited { + t.Fatal("expected preSend to edit placeholder") + } + if ch.recordedChatID != "123" || ch.recordedMessageID != "456" { + t.Fatalf("expected tracked message 123/456, got %q/%q", ch.recordedChatID, ch.recordedMessageID) + } +} + +func TestPreSend_ToolFeedbackPlaceholderEditUsesResolvedTrackedChatID(t *testing.T) { + m := newTestManager() + + ch := &mockResolvedToolFeedbackEditor{ + mockMessageEditor: mockMessageEditor{ + editFn: func(_ context.Context, chatID, messageID, content string) error { + if chatID != "-100123" || messageID != "456" || content != "hello" { + t.Fatalf("unexpected edit args: %s %s %s", chatID, messageID, content) + } + return nil + }, + }, + resolveChatIDFn: func(chatID string, outboundCtx *bus.InboundContext) string { + if chatID != "-100123" { + t.Fatalf("expected raw chat ID, got %q", chatID) + } + if outboundCtx == nil || outboundCtx.TopicID != "42" { + t.Fatalf("expected topic-aware outbound context, got %+v", outboundCtx) + } + return chatID + "/" + outboundCtx.TopicID + }, + } + + m.RecordPlaceholder("test", "-100123", "456") + + msg := testOutboundMessage(bus.OutboundMessage{ + Channel: "test", + ChatID: "-100123", + Content: "hello", + Context: bus.InboundContext{ + Channel: "test", + ChatID: "-100123", + TopicID: "42", + Raw: map[string]string{ + "message_kind": "tool_feedback", + }, + }, + }) + _, edited := m.preSend(context.Background(), "test", msg, ch) + if !edited { + t.Fatal("expected preSend to edit placeholder") + } + if ch.recordedChatID != "-100123/42" || ch.recordedMessageID != "456" { + t.Fatalf("expected resolved tracked message -100123/42/456, got %q/%q", + ch.recordedChatID, ch.recordedMessageID) + } +} + +func TestPreSend_ToolFeedbackPlaceholderEditUsesPreparedContent(t *testing.T) { + m := newTestManager() + + const rawContent = "🔧 `read_file`\n" + "" + const preparedContent = "🔧 `read_file`\n<raw>" + + ch := &mockPreparedToolFeedbackEditor{ + mockMessageEditor: mockMessageEditor{ + editFn: func(_ context.Context, chatID, messageID, content string) error { + if chatID != "123" || messageID != "456" { + t.Fatalf("unexpected edit target: %s/%s", chatID, messageID) + } + if content != InitialAnimatedToolFeedbackContent(preparedContent) { + t.Fatalf("unexpected prepared content: %q", content) + } + return nil + }, + }, + prepareFn: func(content string) string { + if content != rawContent { + t.Fatalf("unexpected raw tool feedback: %q", content) + } + return preparedContent + }, + } + + m.RecordPlaceholder("test", "123", "456") + + msg := testOutboundMessage(bus.OutboundMessage{ + Channel: "test", + ChatID: "123", + Content: rawContent, + Context: bus.InboundContext{ + Channel: "test", + ChatID: "123", + Raw: map[string]string{ + "message_kind": "tool_feedback", + }, + }, + }) + + _, edited := m.preSend(context.Background(), "test", msg, ch) + if !edited { + t.Fatal("expected preSend to edit placeholder") + } + if ch.recordedContent != preparedContent { + t.Fatalf("expected tracked content %q, got %q", preparedContent, ch.recordedContent) + } +} + +func TestPreSend_NonToolFeedbackLeavesTrackedMessageForChannelSend(t *testing.T) { + m := newTestManager() + ch := &mockMessageEditor{} + + msg := testOutboundMessage(bus.OutboundMessage{ + Channel: "test", + ChatID: "123", + Content: "final reply", + Context: bus.InboundContext{ + Channel: "test", + ChatID: "123", + }, + }) + + _, edited := m.preSend(context.Background(), "test", msg, ch) + if edited { + t.Fatal("expected preSend to fall through when no placeholder exists") + } + if ch.dismissedChatID != "" { + t.Fatalf("expected tracked tool feedback cleanup to be deferred to channel send, got %q", ch.dismissedChatID) + } +} + +func TestPreSend_NonToolFeedbackDefersTrackedMessageFinalizationToChannelSend(t *testing.T) { + m := newTestManager() + ch := &mockMessageEditor{ + finalizeFn: func(_ context.Context, msg bus.OutboundMessage) ([]string, bool) { + if msg.ChatID != "123" || msg.Content != "final reply" { + t.Fatalf("unexpected finalize msg: %+v", msg) + } + return []string{"tool-msg-1"}, true + }, + } + + msg := testOutboundMessage(bus.OutboundMessage{ + Channel: "test", + ChatID: "123", + Content: "final reply", + Context: bus.InboundContext{ + Channel: "test", + ChatID: "123", + }, + }) + + msgIDs, handled := m.preSend(context.Background(), "test", msg, ch) + if handled { + t.Fatalf("expected preSend to defer to channel Send, got msgIDs=%v", msgIDs) + } + if len(msgIDs) != 0 { + t.Fatalf("expected no msgIDs from preSend, got %v", msgIDs) + } + if ch.dismissedChatID != "" { + t.Fatalf("expected tracked cleanup to remain in channel Send, got %q", ch.dismissedChatID) + } + if ch.finalizeCalled { + t.Fatal("expected preSend to skip channel tool feedback finalization") + } +} + +func TestPreSend_StaleToolFeedbackDoesNotConsumeStreamActiveMarker(t *testing.T) { + m := newTestManager() + m.streamActive.Store("test:123", true) + m.RecordPlaceholder("test", "123", "placeholder-1") + + var editedContent string + ch := &mockMessageEditor{ + editFn: func(_ context.Context, chatID, messageID, content string) error { + if chatID != "123" || messageID != "placeholder-1" { + t.Fatalf("unexpected edit target: %s/%s", chatID, messageID) + } + editedContent = content + return nil + }, + } + + toolFeedback := testOutboundMessage(bus.OutboundMessage{ + Channel: "test", + ChatID: "123", + Content: "🔧 `read_file`\nReading config", + Context: bus.InboundContext{ + Channel: "test", + ChatID: "123", + Raw: map[string]string{ + "message_kind": "tool_feedback", + }, + }, + }) + + msgIDs, handled := m.preSend(context.Background(), "test", toolFeedback, ch) + if !handled { + t.Fatal("expected stale tool feedback to be dropped after stream finalize") + } + if len(msgIDs) != 0 { + t.Fatalf("expected no delivered message IDs for stale feedback, got %v", msgIDs) + } + if _, ok := m.streamActive.Load("test:123"); !ok { + t.Fatal("expected streamActive marker to remain for the final outbound message") + } + if _, ok := m.placeholders.Load("test:123"); !ok { + t.Fatal("expected placeholder cleanup to remain deferred to the final outbound message") + } + if ch.editedMessages != 0 { + t.Fatalf("expected no placeholder edit for stale feedback, got %d edits", ch.editedMessages) + } + + finalMsg := testOutboundMessage(bus.OutboundMessage{ + Channel: "test", + ChatID: "123", + Content: "final streamed reply", + Context: bus.InboundContext{ + Channel: "test", + ChatID: "123", + }, + }) + + _, handled = m.preSend(context.Background(), "test", finalMsg, ch) + if !handled { + t.Fatal("expected final outbound message to consume streamActive marker") + } + if _, ok := m.streamActive.Load("test:123"); ok { + t.Fatal("expected streamActive marker to be cleared by final outbound message") + } + if _, ok := m.placeholders.Load("test:123"); ok { + t.Fatal("expected placeholder to be cleaned up by final outbound message") + } + if editedContent != "final streamed reply" { + t.Fatalf("editedContent = %q, want final streamed reply", editedContent) + } +} + +func TestPreSendMedia_LeavesTrackedMessageForChannelSend(t *testing.T) { + m := newTestManager() + ch := &mockDeletingMediaChannel{} + + m.preSendMedia(context.Background(), "test", bus.OutboundMediaMessage{ + ChatID: "123", + Context: bus.InboundContext{ + Channel: "test", + ChatID: "123", + }, + }, ch) + + if ch.dismissedChatID != "" { + t.Fatalf( + "expected tracked tool feedback cleanup to be deferred to channel media send, got %q", + ch.dismissedChatID, + ) + } +} + +func TestSplitOutboundMessageContent_ToolFeedbackTruncatesInsteadOfSplitting(t *testing.T) { + msg := testOutboundMessage(bus.OutboundMessage{ + Channel: "test", + ChatID: "123", + Content: "\U0001f527 `read_file`\nRead README.md first to confirm the current project structure before editing the config example.", + Context: bus.InboundContext{ + Channel: "test", + ChatID: "123", + Raw: map[string]string{ + "message_kind": "tool_feedback", + }, + }, + }) + + chunks := splitOutboundMessageContent(msg, 40) + if len(chunks) != 1 { + t.Fatalf("len(chunks) = %d, want 1", len(chunks)) + } + want := utils.FitToolFeedbackMessage(msg.Content, 40-MaxToolFeedbackAnimationFrameLength()) + if chunks[0] != want { + t.Fatalf("chunk = %q, want %q", chunks[0], want) + } +} + +func TestSplitOutboundMessageContent_ToolFeedbackReservesAnimationFrame(t *testing.T) { + msg := testOutboundMessage(bus.OutboundMessage{ + Channel: "test", + ChatID: "123", + Content: "🔧 `read_file`\n1234567890", + Context: bus.InboundContext{ + Channel: "test", + ChatID: "123", + Raw: map[string]string{ + "message_kind": "tool_feedback", + }, + }, + }) + + chunks := splitOutboundMessageContent(msg, len([]rune(msg.Content))) + if len(chunks) != 1 { + t.Fatalf("len(chunks) = %d, want 1", len(chunks)) + } + + animated := formatAnimatedToolFeedbackContent(chunks[0], strings.Repeat(".", MaxToolFeedbackAnimationFrameLength())) + if got, maxLen := len([]rune(animated)), len([]rune(msg.Content)); got > maxLen { + t.Fatalf("animated len = %d, want <= %d; content=%q", got, maxLen, animated) + } +} + +func TestGetStreamer_FinalizeDismissesTrackedToolFeedback(t *testing.T) { + m := newTestManager() + ch := &mockStreamingChannel{ + mockMessageEditor: mockMessageEditor{}, + streamer: &mockStreamer{ + finalizeFn: func(_ context.Context, content string) error { + if content != "final reply" { + t.Fatalf("unexpected finalize content: %q", content) + } + return nil + }, + }, + } + m.channels["test"] = ch + + streamer, ok := m.GetStreamer(context.Background(), "test", "123") + if !ok { + t.Fatal("expected streamer to be available") + } + if err := streamer.Finalize(context.Background(), "final reply"); err != nil { + t.Fatalf("Finalize() error = %v", err) + } + if ch.dismissedChatID != "123" { + t.Fatalf("expected tracked tool feedback to be dismissed for chat 123, got %q", ch.dismissedChatID) + } + if _, ok := m.streamActive.Load("test:123"); !ok { + t.Fatal("expected streamActive marker to be recorded after finalize") + } +} + +func TestGetStreamer_FinalizeDismissesResolvedTrackedToolFeedback(t *testing.T) { + m := newTestManager() + ch := &mockStreamingChannel{ + mockMessageEditor: mockMessageEditor{}, + streamer: &mockStreamer{ + finalizeFn: func(_ context.Context, content string) error { + if content != "final reply" { + t.Fatalf("unexpected finalize content: %q", content) + } + return nil + }, + }, + resolveChatIDFn: func(chatID string, outboundCtx *bus.InboundContext) string { + if outboundCtx == nil { + t.Fatal("expected outbound context during stream finalize") + } + if outboundCtx.ChatID != "-100123/42" { + t.Fatalf("unexpected outbound context: %+v", outboundCtx) + } + return outboundCtx.ChatID + }, + } + m.channels["test"] = ch + + streamer, ok := m.GetStreamer(context.Background(), "test", "-100123/42") + if !ok { + t.Fatal("expected streamer to be available") + } + if err := streamer.Finalize(context.Background(), "final reply"); err != nil { + t.Fatalf("Finalize() error = %v", err) + } + if ch.dismissedChatID != "-100123/42" { + t.Fatalf("expected resolved tracked tool feedback dismissal, got %q", ch.dismissedChatID) + } + if _, ok := m.streamActive.Load("test:-100123/42"); !ok { + t.Fatal("expected streamActive marker to be recorded after finalize") + } +} + +func TestPreSend_PlaceholderEditSuccessDismissesResolvedTrackedToolFeedback(t *testing.T) { + m := newTestManager() + + ch := &mockResolvedToolFeedbackEditor{ + mockMessageEditor: mockMessageEditor{ + editFn: func(_ context.Context, chatID, messageID, content string) error { + if chatID != "-100123" || messageID != "456" || content != "done" { + t.Fatalf("unexpected edit args: %s %s %s", chatID, messageID, content) + } + return nil + }, + }, + resolveChatIDFn: func(chatID string, outboundCtx *bus.InboundContext) string { + if outboundCtx == nil || outboundCtx.TopicID != "42" { + t.Fatalf("expected topic-aware outbound context, got %+v", outboundCtx) + } + return chatID + "/" + outboundCtx.TopicID + }, + } + + m.RecordPlaceholder("test", "-100123", "456") + + msg := testOutboundMessage(bus.OutboundMessage{ + Channel: "test", + ChatID: "-100123", + Content: "done", + Context: bus.InboundContext{ + Channel: "test", + ChatID: "-100123", + TopicID: "42", + }, + }) + + _, edited := m.preSend(context.Background(), "test", msg, ch) + if !edited { + t.Fatal("expected preSend to edit placeholder") + } + if ch.dismissedChatID != "-100123/42" { + t.Fatalf("expected resolved tracked dismissal, got %q", ch.dismissedChatID) + } +} + +func TestGetStreamer_FinalizeFailureDoesNotDismissTrackedToolFeedback(t *testing.T) { + m := newTestManager() + ch := &mockStreamingChannel{ + mockMessageEditor: mockMessageEditor{}, + streamer: &mockStreamer{ + finalizeFn: func(context.Context, string) error { + return errors.New("finalize failed") + }, + }, + } + m.channels["test"] = ch + + streamer, ok := m.GetStreamer(context.Background(), "test", "123") + if !ok { + t.Fatal("expected streamer to be available") + } + if err := streamer.Finalize(context.Background(), "final reply"); err == nil { + t.Fatal("expected Finalize() to fail") + } + if ch.dismissedChatID != "" { + t.Fatalf("expected no tool feedback dismissal on finalize failure, got %q", ch.dismissedChatID) + } + if _, ok := m.streamActive.Load("test:123"); ok { + t.Fatal("expected no streamActive marker after finalize failure") + } +} + +func TestRunWorker_ToolFeedbackSkipsMarkerSplitting(t *testing.T) { + m := newTestManager() + m.config = &config.Config{ + Agents: config.AgentsConfig{ + Defaults: config.AgentDefaults{ + SplitOnMarker: true, + }, + }, + } + + var ( + mu sync.Mutex + received []string + ) + ch := &mockChannelWithLength{ + mockChannel: mockChannel{ + sendFn: func(_ context.Context, msg bus.OutboundMessage) error { + mu.Lock() + received = append(received, msg.Content) + mu.Unlock() + return nil + }, + }, + maxLen: 200, + } + + w := &channelWorker{ + ch: ch, + queue: make(chan bus.OutboundMessage, 1), + done: make(chan struct{}), + limiter: rate.NewLimiter(rate.Inf, 1), + } + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + go m.runWorker(ctx, "test", w) + + content := "🔧 `read_file`\nRead current config first.<|[SPLIT]|>Then update the example." + w.queue <- testOutboundMessage(bus.OutboundMessage{ + Channel: "test", + ChatID: "123", + Content: content, + Context: bus.InboundContext{ + Channel: "test", + ChatID: "123", + Raw: map[string]string{ + "message_kind": "tool_feedback", + }, + }, + }) + + time.Sleep(100 * time.Millisecond) + + mu.Lock() + defer mu.Unlock() + if len(received) != 1 { + t.Fatalf("len(received) = %d, want 1", len(received)) + } + if received[0] != content { + t.Fatalf("received[0] = %q, want %q", received[0], content) + } +} + func TestPreSend_PlaceholderEditFails_FallsThrough(t *testing.T) { m := newTestManager() diff --git a/pkg/channels/matrix/matrix.go b/pkg/channels/matrix/matrix.go index 40e1b0a36..04599d6d2 100644 --- a/pkg/channels/matrix/matrix.go +++ b/pkg/channels/matrix/matrix.go @@ -46,6 +46,13 @@ const ( var matrixMentionHrefRegexp = regexp.MustCompile(`(?i)]+href=["']([^"']+)["']`) +func outboundMessageIsToolFeedback(msg bus.OutboundMessage) bool { + if len(msg.Context.Raw) == 0 { + return false + } + return strings.EqualFold(strings.TrimSpace(msg.Context.Raw["message_kind"]), "tool_feedback") +} + type roomKindCacheEntry struct { isGroup bool expiresAt time.Time @@ -192,6 +199,7 @@ type MatrixChannel struct { cryptoHelper *cryptohelper.CryptoHelper cryptoDbPath string + progress *channels.ToolFeedbackAnimator } func NewMatrixChannel( @@ -236,7 +244,7 @@ func NewMatrixChannel( channels.WithReasoningChannelID(bc.ReasoningChannelID), ) - return &MatrixChannel{ + ch := &MatrixChannel{ BaseChannel: base, bc: bc, client: client, @@ -248,7 +256,9 @@ func NewMatrixChannel( localpartMentionR: localpartMentionRegexp(matrixLocalpart(client.UserID)), typingMu: sync.Mutex{}, cryptoDbPath: cryptoDatabasePath, - }, nil + } + ch.progress = channels.NewToolFeedbackAnimator(ch.EditMessage) + return ch, nil } func (c *MatrixChannel) Start(ctx context.Context) error { @@ -297,6 +307,9 @@ func (c *MatrixChannel) Stop(ctx context.Context) error { c.cancel() } c.stopTypingSessions(ctx) + if c.progress != nil { + c.progress.StopAll() + } // Close crypto helper if initialized if c.cryptoHelper != nil { @@ -398,11 +411,36 @@ func (c *MatrixChannel) Send(ctx context.Context, msg bus.OutboundMessage) ([]st return nil, nil } + isToolFeedback := outboundMessageIsToolFeedback(msg) + if isToolFeedback { + if msgID, handled, err := c.progress.Update(ctx, msg.ChatID, content); handled { + if err != nil { + return nil, err + } + return []string{msgID}, nil + } + } + trackedMsgID, hasTrackedMsg := c.currentToolFeedbackMessage(msg.ChatID) + if !isToolFeedback { + if msgIDs, handled := c.FinalizeToolFeedbackMessage(ctx, msg); handled { + return msgIDs, nil + } + } + if isToolFeedback { + content = channels.InitialAnimatedToolFeedbackContent(content) + } + resp, err := c.client.SendMessageEvent(ctx, roomID, event.EventMessage, c.messageContent(content)) if err != nil { return nil, fmt.Errorf("matrix send: %w", channels.ErrTemporary) } - return []string{resp.EventID.String()}, nil + msgID := resp.EventID.String() + if isToolFeedback { + c.RecordToolFeedbackMessage(msg.ChatID, msgID, msg.Content) + } else if hasTrackedMsg { + c.dismissTrackedToolFeedbackMessage(ctx, msg.ChatID, trackedMsgID) + } + return []string{msgID}, nil } func (c *MatrixChannel) messageContent(text string) *event.MessageEventContent { @@ -419,6 +457,8 @@ func (c *MatrixChannel) SendMedia(ctx context.Context, msg bus.OutboundMediaMess if !c.IsRunning() { return nil, channels.ErrNotRunning } + trackedMsgID, hasTrackedMsg := c.currentToolFeedbackMessage(msg.ChatID) + sendCtx := ctx if sendCtx == nil { sendCtx = context.Background() @@ -529,6 +569,10 @@ func (c *MatrixChannel) SendMedia(ctx context.Context, msg bus.OutboundMediaMess } } + if hasTrackedMsg { + c.dismissTrackedToolFeedbackMessage(ctx, msg.ChatID, trackedMsgID) + } + return eventIDs, nil } @@ -612,6 +656,89 @@ func (c *MatrixChannel) EditMessage(ctx context.Context, chatID string, messageI return err } +// DeleteMessage implements channels.MessageDeleter. +func (c *MatrixChannel) DeleteMessage(ctx context.Context, chatID string, messageID string) error { + roomID := id.RoomID(strings.TrimSpace(chatID)) + if roomID == "" { + return fmt.Errorf("matrix room ID is empty") + } + eventID := id.EventID(strings.TrimSpace(messageID)) + if eventID == "" { + return fmt.Errorf("matrix message ID is empty") + } + + _, err := c.client.RedactEvent(ctx, roomID, eventID) + return err +} + +func (c *MatrixChannel) currentToolFeedbackMessage(chatID string) (string, bool) { + if c.progress == nil { + return "", false + } + return c.progress.Current(chatID) +} + +func (c *MatrixChannel) takeToolFeedbackMessage(chatID string) (string, string, bool) { + if c.progress == nil { + return "", "", false + } + return c.progress.Take(chatID) +} + +func (c *MatrixChannel) RecordToolFeedbackMessage(chatID, messageID, content string) { + if c.progress == nil { + return + } + c.progress.Record(chatID, messageID, content) +} + +func (c *MatrixChannel) ClearToolFeedbackMessage(chatID string) { + if c.progress == nil { + return + } + c.progress.Clear(chatID) +} + +func (c *MatrixChannel) DismissToolFeedbackMessage(ctx context.Context, chatID string) { + msgID, ok := c.currentToolFeedbackMessage(chatID) + if !ok { + return + } + c.dismissTrackedToolFeedbackMessage(ctx, chatID, msgID) +} + +func (c *MatrixChannel) dismissTrackedToolFeedbackMessage(ctx context.Context, chatID, messageID string) { + if strings.TrimSpace(chatID) == "" || strings.TrimSpace(messageID) == "" { + return + } + c.ClearToolFeedbackMessage(chatID) + _ = c.DeleteMessage(ctx, chatID, messageID) +} + +func (c *MatrixChannel) finalizeTrackedToolFeedbackMessage( + ctx context.Context, + chatID string, + content string, + editFn func(context.Context, string, string, string) error, +) ([]string, bool) { + msgID, baseContent, ok := c.takeToolFeedbackMessage(chatID) + if !ok || editFn == nil { + return nil, false + } + if err := editFn(ctx, chatID, msgID, content); err != nil { + c.RecordToolFeedbackMessage(chatID, msgID, baseContent) + return nil, false + } + return []string{msgID}, true +} + +func (c *MatrixChannel) FinalizeToolFeedbackMessage(ctx context.Context, msg bus.OutboundMessage) ([]string, bool) { + if outboundMessageIsToolFeedback(msg) { + return nil, false + } + return c.finalizeTrackedToolFeedbackMessage(ctx, msg.ChatID, msg.Content, c.EditMessage) +} + func (c *MatrixChannel) handleMemberEvent(ctx context.Context, evt *event.Event) { if !c.config.JoinOnInvite { return diff --git a/pkg/channels/matrix/matrix_test.go b/pkg/channels/matrix/matrix_test.go index 07f08f32b..066f08059 100644 --- a/pkg/channels/matrix/matrix_test.go +++ b/pkg/channels/matrix/matrix_test.go @@ -14,6 +14,7 @@ import ( "maunium.net/go/mautrix/event" "maunium.net/go/mautrix/id" + "github.com/sipeed/picoclaw/pkg/channels" "github.com/sipeed/picoclaw/pkg/config" "github.com/sipeed/picoclaw/pkg/media" ) @@ -41,6 +42,34 @@ func TestMatrixLocalpartMentionRegexp(t *testing.T) { } } +func TestFinalizeTrackedToolFeedbackMessage_StopsTrackingBeforeEdit(t *testing.T) { + ch := &MatrixChannel{ + progress: channels.NewToolFeedbackAnimator(nil), + } + ch.RecordToolFeedbackMessage("!room:matrix.org", "$event1", "🔧 `read_file`") + + msgIDs, handled := ch.finalizeTrackedToolFeedbackMessage( + context.Background(), + "!room:matrix.org", + "final reply", + func(_ context.Context, chatID, messageID, content string) error { + if _, ok := ch.currentToolFeedbackMessage(chatID); ok { + t.Fatal("expected tracked tool feedback to be stopped before edit") + } + if chatID != "!room:matrix.org" || messageID != "$event1" || content != "final reply" { + t.Fatalf("unexpected edit args: %s %s %s", chatID, messageID, content) + } + return nil + }, + ) + if !handled { + t.Fatal("expected finalizeTrackedToolFeedbackMessage to handle tracked message") + } + if len(msgIDs) != 1 || msgIDs[0] != "$event1" { + t.Fatalf("finalizeTrackedToolFeedbackMessage() ids = %v, want [$event1]", msgIDs) + } +} + func TestStripUserMention(t *testing.T) { userID := id.UserID("@picoclaw:matrix.org") diff --git a/pkg/channels/pico/pico.go b/pkg/channels/pico/pico.go index 4d1fad1ed..31360b3de 100644 --- a/pkg/channels/pico/pico.go +++ b/pkg/channels/pico/pico.go @@ -50,6 +50,17 @@ func outboundMessageIsThought(msg bus.OutboundMessage) bool { return strings.EqualFold(strings.TrimSpace(msg.Context.Raw["message_kind"]), MessageKindThought) } +func outboundMessageIsToolFeedback(msg bus.OutboundMessage) bool { + if len(msg.Context.Raw) == 0 { + return false + } + return strings.EqualFold(strings.TrimSpace(msg.Context.Raw["message_kind"]), "tool_feedback") +} + +func outboundMessageFinalizesTrackedToolFeedback(msg bus.OutboundMessage) bool { + return !outboundMessageIsToolFeedback(msg) && !outboundMessageIsThought(msg) +} + // writeJSON sends a JSON message to the connection with write locking. func (pc *picoConn) writeJSON(v any) error { if pc.closed.Load() { @@ -82,6 +93,8 @@ type PicoChannel struct { connsMu sync.RWMutex ctx context.Context cancel context.CancelFunc + progress *channels.ToolFeedbackAnimator + deleteMessageFn func(context.Context, string, string) error } // NewPicoChannel creates a new Pico Protocol channel. @@ -110,7 +123,7 @@ func NewPicoChannel( return false } - return &PicoChannel{ + ch := &PicoChannel{ BaseChannel: base, bc: bc, config: cfg, @@ -121,7 +134,10 @@ func NewPicoChannel( }, connections: make(map[string]*picoConn), sessionConnections: make(map[string]map[string]*picoConn), - }, nil + } + ch.progress = channels.NewToolFeedbackAnimator(ch.EditMessage) + ch.deleteMessageFn = ch.DeleteMessage + return ch, nil } // createAndAddConnection checks MaxConnections and registers a connection atomically. @@ -239,6 +255,9 @@ func (c *PicoChannel) Stop(ctx context.Context) error { if c.cancel != nil { c.cancel() } + if c.progress != nil { + c.progress.StopAll() + } logger.InfoC("pico", "Pico Protocol channel stopped") return nil @@ -269,26 +288,133 @@ func (c *PicoChannel) Send(ctx context.Context, msg bus.OutboundMessage) ([]stri return nil, channels.ErrNotRunning } isThought := outboundMessageIsThought(msg) + isToolFeedback := outboundMessageIsToolFeedback(msg) + if isToolFeedback { + if msgID, handled, err := c.progress.Update(ctx, msg.ChatID, msg.Content); handled { + if err != nil { + return nil, err + } + return []string{msgID}, nil + } + } + trackedMsgID, hasTrackedMsg := c.currentToolFeedbackMessage(msg.ChatID) + if outboundMessageFinalizesTrackedToolFeedback(msg) { + if msgIDs, handled := c.FinalizeToolFeedbackMessage(ctx, msg); handled { + return msgIDs, nil + } + } + + content := msg.Content + if isToolFeedback { + content = channels.InitialAnimatedToolFeedbackContent(msg.Content) + } + msgID := uuid.New().String() payload := map[string]any{ - PayloadKeyContent: msg.Content, + PayloadKeyContent: content, PayloadKeyThought: isThought, + "message_id": msgID, } setContextUsagePayload(payload, msg.ContextUsage) outMsg := newMessage(TypeMessageCreate, payload) - return nil, c.broadcastToSession(msg.ChatID, outMsg) + if err := c.broadcastToSession(msg.ChatID, outMsg); err != nil { + return nil, err + } + if isToolFeedback { + c.RecordToolFeedbackMessage(msg.ChatID, msgID, msg.Content) + } else if hasTrackedMsg && outboundMessageFinalizesTrackedToolFeedback(msg) { + c.dismissTrackedToolFeedbackMessage(ctx, msg.ChatID, trackedMsgID) + } + return []string{msgID}, nil } // EditMessage implements channels.MessageEditor. func (c *PicoChannel) EditMessage(ctx context.Context, chatID string, messageID string, content string) error { - outMsg := newMessage(TypeMessageUpdate, map[string]any{ + return c.editMessage(ctx, chatID, messageID, content, nil) +} + +// DeleteMessage implements channels.MessageDeleter. +func (c *PicoChannel) DeleteMessage(ctx context.Context, chatID string, messageID string) error { + outMsg := newMessage(TypeMessageDelete, map[string]any{ "message_id": messageID, - "content": content, }) return c.broadcastToSession(chatID, outMsg) } +func (c *PicoChannel) currentToolFeedbackMessage(chatID string) (string, bool) { + if c.progress == nil { + return "", false + } + return c.progress.Current(chatID) +} + +func (c *PicoChannel) takeToolFeedbackMessage(chatID string) (string, string, bool) { + if c.progress == nil { + return "", "", false + } + return c.progress.Take(chatID) +} + +func (c *PicoChannel) RecordToolFeedbackMessage(chatID, messageID, content string) { + if c.progress == nil { + return + } + c.progress.Record(chatID, messageID, content) +} + +func (c *PicoChannel) ClearToolFeedbackMessage(chatID string) { + if c.progress == nil { + return + } + c.progress.Clear(chatID) +} + +func (c *PicoChannel) DismissToolFeedbackMessage(ctx context.Context, chatID string) { + msgID, ok := c.currentToolFeedbackMessage(chatID) + if !ok { + return + } + c.dismissTrackedToolFeedbackMessage(ctx, chatID, msgID) +} + +func (c *PicoChannel) dismissTrackedToolFeedbackMessage(ctx context.Context, chatID, messageID string) { + if strings.TrimSpace(chatID) == "" || strings.TrimSpace(messageID) == "" { + return + } + c.ClearToolFeedbackMessage(chatID) + deleteFn := c.deleteMessageFn + if deleteFn == nil { + deleteFn = c.DeleteMessage + } + _ = deleteFn(ctx, chatID, messageID) +} + +func (c *PicoChannel) finalizeTrackedToolFeedbackMessage( + ctx context.Context, + chatID string, + content string, + editFn func(context.Context, string, string, string, *bus.ContextUsage) error, + contextUsage *bus.ContextUsage, +) ([]string, bool) { + msgID, baseContent, ok := c.takeToolFeedbackMessage(chatID) + if !ok || editFn == nil { + return nil, false + } + if err := editFn(ctx, chatID, msgID, content, contextUsage); err != nil { + c.RecordToolFeedbackMessage(chatID, msgID, baseContent) + return nil, false + } + return []string{msgID}, true +} + +func (c *PicoChannel) FinalizeToolFeedbackMessage(ctx context.Context, msg bus.OutboundMessage) ([]string, bool) { + if !outboundMessageFinalizesTrackedToolFeedback(msg) { + return nil, false + } + return c.finalizeTrackedToolFeedbackMessage(ctx, msg.ChatID, msg.Content, c.editMessage, msg.ContextUsage) +} + // StartTyping implements channels.TypingCapable. func (c *PicoChannel) StartTyping(ctx context.Context, chatID string) (func(), error) { startMsg := newMessage(TypeTypingStart, nil) @@ -332,6 +458,7 @@ func (c *PicoChannel) SendMedia(ctx context.Context, msg bus.OutboundMediaMessag if !c.IsRunning() { return nil, channels.ErrNotRunning } + trackedMsgID, hasTrackedMsg := c.currentToolFeedbackMessage(msg.ChatID) store := c.GetMediaStore() if store == nil { @@ -407,6 +534,9 @@ func (c *PicoChannel) SendMedia(ctx context.Context, msg bus.OutboundMediaMessag if err := c.broadcastToSession(msg.ChatID, outMsg); err != nil { return nil, err } + if hasTrackedMsg { + c.dismissTrackedToolFeedbackMessage(ctx, msg.ChatID, trackedMsgID) + } return []string{msgID}, nil } @@ -939,3 +1069,19 @@ func setContextUsagePayload(payload map[string]any, u *bus.ContextUsage) { "used_percent": u.UsedPercent, } } + +func (c *PicoChannel) editMessage( + ctx context.Context, + chatID string, + messageID string, + content string, + contextUsage *bus.ContextUsage, +) error { + payload := map[string]any{ + "message_id": messageID, + "content": content, + } + setContextUsagePayload(payload, contextUsage) + outMsg := newMessage(TypeMessageUpdate, payload) + return c.broadcastToSession(chatID, outMsg) +} diff --git a/pkg/channels/pico/pico_test.go b/pkg/channels/pico/pico_test.go index f0d179527..22ed5451a 100644 --- a/pkg/channels/pico/pico_test.go +++ b/pkg/channels/pico/pico_test.go @@ -4,12 +4,16 @@ import ( "context" "errors" "fmt" + "net/http" "net/http/httptest" "os" "path/filepath" "strings" "sync" "testing" + "time" + + "github.com/gorilla/websocket" "github.com/sipeed/picoclaw/pkg/bus" "github.com/sipeed/picoclaw/pkg/channels" @@ -32,6 +36,163 @@ func newTestPicoChannel(t *testing.T) *PicoChannel { return ch } +func TestFinalizeTrackedToolFeedbackMessage_StopsTrackingBeforeEdit(t *testing.T) { + ch := &PicoChannel{ + progress: channels.NewToolFeedbackAnimator(nil), + } + ch.RecordToolFeedbackMessage("pico:chat-1", "msg-1", "🔧 `read_file`") + + msgIDs, handled := ch.finalizeTrackedToolFeedbackMessage( + context.Background(), + "pico:chat-1", + "final reply", + func(_ context.Context, chatID, messageID, content string, contextUsage *bus.ContextUsage) error { + if _, ok := ch.currentToolFeedbackMessage(chatID); ok { + t.Fatal("expected tracked tool feedback to be stopped before edit") + } + if chatID != "pico:chat-1" || messageID != "msg-1" || content != "final reply" { + t.Fatalf("unexpected edit args: %s %s %s", chatID, messageID, content) + } + if contextUsage != nil { + t.Fatalf("unexpected context usage: %+v", contextUsage) + } + return nil + }, + nil, + ) + if !handled { + t.Fatal("expected finalizeTrackedToolFeedbackMessage to handle tracked message") + } + if len(msgIDs) != 1 || msgIDs[0] != "msg-1" { + t.Fatalf("finalizeTrackedToolFeedbackMessage() ids = %v, want [msg-1]", msgIDs) + } +} + +func TestDismissTrackedToolFeedbackMessage_DeletesProgressMessage(t *testing.T) { + ch := &PicoChannel{ + progress: channels.NewToolFeedbackAnimator(nil), + } + ch.RecordToolFeedbackMessage("pico:chat-1", "msg-1", "🔧 `read_file`") + + var deleted struct { + chatID string + messageID string + } + ch.deleteMessageFn = func(_ context.Context, chatID string, messageID string) error { + deleted.chatID = chatID + deleted.messageID = messageID + return nil + } + + ch.DismissToolFeedbackMessage(context.Background(), "pico:chat-1") + + if deleted.chatID != "pico:chat-1" || deleted.messageID != "msg-1" { + t.Fatalf("unexpected delete target: %+v", deleted) + } + if _, ok := ch.currentToolFeedbackMessage("pico:chat-1"); ok { + t.Fatal("expected tracked tool feedback to be cleared after dismissal") + } +} + +func TestSend_ThoughtMessageDoesNotFinalizeTrackedToolFeedback(t *testing.T) { + ch := newTestPicoChannel(t) + + if err := ch.Start(context.Background()); err != nil { + t.Fatalf("Start() error = %v", err) + } + defer ch.Stop(context.Background()) + + clientConn, received, cleanup := newTestPicoWebSocket(t) + defer cleanup() + ch.addConnForTest(&picoConn{id: "conn-1", conn: clientConn, sessionID: "sess-1"}) + + ch.RecordToolFeedbackMessage("pico:sess-1", "msg-progress", "🔧 `read_file`\nReading config") + + if _, err := ch.Send(context.Background(), bus.OutboundMessage{ + ChatID: "pico:sess-1", + Content: "thinking trace", + Context: bus.InboundContext{ + Channel: "pico", + ChatID: "pico:sess-1", + Raw: map[string]string{ + "message_kind": MessageKindThought, + }, + }, + }); err != nil { + t.Fatalf("Send(thought) error = %v", err) + } + + select { + case msg := <-received: + if msg.Type != TypeMessageCreate { + t.Fatalf("thought message type = %q, want %q", msg.Type, TypeMessageCreate) + } + payload := msg.Payload + if got := payload[PayloadKeyContent]; got != "thinking trace" { + t.Fatalf("thought content = %#v, want %q", got, "thinking trace") + } + if got := payload[PayloadKeyThought]; got != true { + t.Fatalf("thought flag = %#v, want true", got) + } + if got := payload["message_id"]; got == "msg-progress" || got == nil || got == "" { + t.Fatalf("thought message_id = %#v, want new non-progress id", got) + } + case <-time.After(time.Second): + t.Fatal("expected thought message to be delivered") + } + + if msgID, ok := ch.currentToolFeedbackMessage("pico:sess-1"); !ok || msgID != "msg-progress" { + t.Fatalf("tracked tool feedback = (%q, %v), want (msg-progress, true)", msgID, ok) + } + + if _, err := ch.Send(context.Background(), bus.OutboundMessage{ + ChatID: "pico:sess-1", + Content: "final reply", + Context: bus.InboundContext{ + Channel: "pico", + ChatID: "pico:sess-1", + }, + ContextUsage: &bus.ContextUsage{ + UsedTokens: 321, + TotalTokens: 4096, + CompressAtTokens: 3072, + UsedPercent: 8, + }, + }); err != nil { + t.Fatalf("Send(final) error = %v", err) + } + + select { + case msg := <-received: + if msg.Type != TypeMessageUpdate { + t.Fatalf("final message type = %q, want %q", msg.Type, TypeMessageUpdate) + } + payload := msg.Payload + if got := payload["message_id"]; got != "msg-progress" { + t.Fatalf("final message_id = %#v, want %q", got, "msg-progress") + } + if got := payload[PayloadKeyContent]; got != "final reply" { + t.Fatalf("final content = %#v, want %q", got, "final reply") + } + rawUsage, ok := payload["context_usage"].(map[string]any) + if !ok { + t.Fatalf("final context_usage = %#v, want map payload", payload["context_usage"]) + } + if got, ok := rawUsage["used_tokens"].(float64); !ok || got != 321 { + t.Fatalf("used_tokens = %#v, want 321", rawUsage["used_tokens"]) + } + if got, ok := rawUsage["total_tokens"].(float64); !ok || got != 4096 { + t.Fatalf("total_tokens = %#v, want 4096", rawUsage["total_tokens"]) + } + case <-time.After(time.Second): + t.Fatal("expected final reply to finalize tracked tool feedback") + } + + if _, ok := ch.currentToolFeedbackMessage("pico:sess-1"); ok { + t.Fatal("expected tracked tool feedback to be cleared after final reply") + } +} + func TestCreateAndAddConnection_RespectsMaxConnectionsConcurrently(t *testing.T) { ch := newTestPicoChannel(t) @@ -169,6 +330,75 @@ func TestSendMedia_ResolvesMediaBeforeDelivery(t *testing.T) { } } +func TestSendMedia_DismissesTrackedToolFeedbackMessage(t *testing.T) { + ch := newTestPicoChannel(t) + store := media.NewFileMediaStore() + ch.SetMediaStore(store) + + if err := ch.Start(context.Background()); err != nil { + t.Fatalf("Start() error = %v", err) + } + defer ch.Stop(context.Background()) + + clientConn, received, cleanup := newTestPicoWebSocket(t) + defer cleanup() + ch.addConnForTest(&picoConn{id: "conn-1", conn: clientConn, sessionID: "sess-1"}) + + localPath := filepath.Join(t.TempDir(), "report.txt") + if err := os.WriteFile(localPath, []byte("attachment body"), 0o600); err != nil { + t.Fatalf("WriteFile() error = %v", err) + } + + ref, err := store.Store(localPath, media.MediaMeta{ + Filename: "report.txt", + ContentType: "text/plain", + }, "test-scope") + if err != nil { + t.Fatalf("Store() error = %v", err) + } + + ch.RecordToolFeedbackMessage("pico:sess-1", "msg-progress", "🔧 `read_file`") + + var deleted struct { + chatID string + messageID string + } + ch.deleteMessageFn = func(_ context.Context, chatID string, messageID string) error { + deleted.chatID = chatID + deleted.messageID = messageID + return nil + } + + _, err = ch.SendMedia(context.Background(), bus.OutboundMediaMessage{ + ChatID: "pico:sess-1", + Parts: []bus.MediaPart{{ + Ref: ref, + Type: "file", + Filename: "report.txt", + ContentType: "text/plain", + }}, + }) + if err != nil { + t.Fatalf("SendMedia() error = %v", err) + } + + select { + case msg := <-received: + if msg.Type != TypeMessageCreate { + t.Fatalf("message type = %q, want %q", msg.Type, TypeMessageCreate) + } + case <-time.After(time.Second): + t.Fatal("expected media message to be delivered") + } + + if deleted.chatID != "pico:sess-1" || deleted.messageID != "msg-progress" { + t.Fatalf("unexpected delete target: %+v", deleted) + } + if _, ok := ch.currentToolFeedbackMessage("pico:sess-1"); ok { + t.Fatal("expected tracked tool feedback to be cleared after media delivery") + } +} + func TestPicoDownloadURLForRef(t *testing.T) { got, err := picoDownloadURLForRef("media://attachment-1") if err != nil { @@ -240,3 +470,39 @@ func (c *PicoChannel) addConnForTest(pc *picoConn) { } bySession[pc.id] = pc } + +func newTestPicoWebSocket(t *testing.T) (*websocket.Conn, <-chan PicoMessage, func()) { + t.Helper() + + received := make(chan PicoMessage, 4) + upgrader := websocket.Upgrader{} + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + conn, err := upgrader.Upgrade(w, r, nil) + if err != nil { + t.Errorf("Upgrade() error = %v", err) + return + } + defer conn.Close() + for { + var msg PicoMessage + if err := conn.ReadJSON(&msg); err != nil { + return + } + received <- msg + } + })) + + wsURL := "ws" + strings.TrimPrefix(server.URL, "http") + clientConn, resp, err := websocket.DefaultDialer.Dial(wsURL, nil) + if err != nil { + server.Close() + t.Fatalf("Dial() error = %v", err) + } + + cleanup := func() { + clientConn.Close() + server.Close() + } + defer resp.Body.Close() + return clientConn, received, cleanup +} diff --git a/pkg/channels/pico/protocol.go b/pkg/channels/pico/protocol.go index 051beed1b..8a27b8c93 100644 --- a/pkg/channels/pico/protocol.go +++ b/pkg/channels/pico/protocol.go @@ -12,6 +12,7 @@ const ( // TypeMessageCreate is sent from server to client. TypeMessageCreate = "message.create" TypeMessageUpdate = "message.update" + TypeMessageDelete = "message.delete" TypeMediaCreate = "media.create" TypeTypingStart = "typing.start" TypeTypingStop = "typing.stop" diff --git a/pkg/channels/telegram/command_registration.go b/pkg/channels/telegram/command_registration.go index d3152ec3d..c6b362601 100644 --- a/pkg/channels/telegram/command_registration.go +++ b/pkg/channels/telegram/command_registration.go @@ -66,6 +66,10 @@ func (c *TelegramChannel) startCommandRegistration(ctx context.Context, defs []c if register == nil { register = c.RegisterCommands } + delayFn := c.commandRegDelayFn + if delayFn == nil { + delayFn = commandRegistrationDelay + } regCtx, cancel := context.WithCancel(ctx) c.commandRegCancel = cancel @@ -91,7 +95,7 @@ func (c *TelegramChannel) startCommandRegistration(ctx context.Context, defs []c return } - delay := commandRegistrationDelay(attempt) + delay := delayFn(attempt) logger.WarnCF("telegram", "Telegram command registration failed; will retry", map[string]any{ "error": err.Error(), "retry_after": delay.String(), diff --git a/pkg/channels/telegram/command_registration_test.go b/pkg/channels/telegram/command_registration_test.go index 26f891b2e..c30c6f68d 100644 --- a/pkg/channels/telegram/command_registration_test.go +++ b/pkg/channels/telegram/command_registration_test.go @@ -31,14 +31,12 @@ func TestStartCommandRegistration_DoesNotBlock(t *testing.T) { } func TestStartCommandRegistration_RetriesUntilSuccessThenStops(t *testing.T) { - ch := &TelegramChannel{} + ch := &TelegramChannel{ + commandRegDelayFn: func(int) time.Duration { return 5 * time.Millisecond }, + } ctx, cancel := context.WithCancel(context.Background()) defer cancel() - origBackoff := commandRegistrationBackoff - commandRegistrationBackoff = []time.Duration{5 * time.Millisecond} - defer func() { commandRegistrationBackoff = origBackoff }() - var attempts atomic.Int32 ch.registerFunc = func(context.Context, []commands.Definition) error { n := attempts.Add(1) @@ -69,12 +67,10 @@ func TestStartCommandRegistration_RetriesUntilSuccessThenStops(t *testing.T) { } func TestStartCommandRegistration_StopsAfterCancel(t *testing.T) { - ch := &TelegramChannel{} + ch := &TelegramChannel{ + commandRegDelayFn: func(int) time.Duration { return 5 * time.Millisecond }, + } ctx, cancel := context.WithCancel(context.Background()) - - origBackoff := commandRegistrationBackoff - commandRegistrationBackoff = []time.Duration{5 * time.Millisecond} - defer func() { commandRegistrationBackoff = origBackoff }() defer cancel() var attempts atomic.Int32 diff --git a/pkg/channels/telegram/telegram.go b/pkg/channels/telegram/telegram.go index 2a9cfe4ae..cebebfed6 100644 --- a/pkg/channels/telegram/telegram.go +++ b/pkg/channels/telegram/telegram.go @@ -45,16 +45,18 @@ var ( type TelegramChannel struct { *channels.BaseChannel - bot *telego.Bot - bh *th.BotHandler - bc *config.Channel - chatIDs map[string]int64 - ctx context.Context - cancel context.CancelFunc - tgCfg *config.TelegramSettings + bot *telego.Bot + bh *th.BotHandler + bc *config.Channel + chatIDs map[string]int64 + ctx context.Context + cancel context.CancelFunc + tgCfg *config.TelegramSettings + progress *channels.ToolFeedbackAnimator - registerFunc func(context.Context, []commands.Definition) error - commandRegCancel context.CancelFunc + registerFunc func(context.Context, []commands.Definition) error + commandRegDelayFn func(int) time.Duration + commandRegCancel context.CancelFunc } func NewTelegramChannel( @@ -104,13 +106,15 @@ func NewTelegramChannel( channels.WithReasoningChannelID(bc.ReasoningChannelID), ) - return &TelegramChannel{ + ch := &TelegramChannel{ BaseChannel: base, bot: bot, bc: bc, chatIDs: make(map[string]int64), tgCfg: telegramCfg, - }, nil + } + ch.progress = channels.NewToolFeedbackAnimator(ch.EditMessage) + return ch, nil } func (c *TelegramChannel) Start(ctx context.Context) error { @@ -168,6 +172,9 @@ func (c *TelegramChannel) Stop(ctx context.Context) error { if c.cancel != nil { c.cancel() } + if c.progress != nil { + c.progress.StopAll() + } if c.commandRegCancel != nil { c.commandRegCancel() } @@ -191,12 +198,36 @@ func (c *TelegramChannel) Send(ctx context.Context, msg bus.OutboundMessage) ([] return nil, nil } + isToolFeedback := outboundMessageIsToolFeedback(msg) + toolFeedbackContent := msg.Content + if isToolFeedback { + toolFeedbackContent = fitToolFeedbackForTelegram(msg.Content, useMarkdownV2, 4096) + } + trackedChatID := telegramToolFeedbackChatKey(msg.ChatID, &msg.Context) + if isToolFeedback { + if msgID, handled, err := c.progress.Update(ctx, trackedChatID, toolFeedbackContent); handled { + if err != nil { + return nil, err + } + return []string{msgID}, nil + } + } + trackedMsgID, hasTrackedMsg := c.currentToolFeedbackMessage(trackedChatID) + if !isToolFeedback { + if msgIDs, handled := c.finalizeToolFeedbackMessageForChat(ctx, trackedChatID, msg); handled { + return msgIDs, nil + } + } + // The Manager already splits messages to ≤4000 chars (WithMaxMessageLength), // so msg.Content is guaranteed to be within that limit. We still need to // check if HTML expansion pushes it beyond Telegram's 4096-char API limit. replyToID := msg.ReplyToMessageID var messageIDs []string queue := []string{msg.Content} + if isToolFeedback { + queue = []string{channels.InitialAnimatedToolFeedbackContent(toolFeedbackContent)} + } for len(queue) > 0 { chunk := queue[0] queue = queue[1:] @@ -204,6 +235,13 @@ func (c *TelegramChannel) Send(ctx context.Context, msg bus.OutboundMessage) ([] content := parseContent(chunk, useMarkdownV2) if len([]rune(content)) > 4096 { + if isToolFeedback { + fittedChunk := fitToolFeedbackForTelegram(chunk, useMarkdownV2, 4096) + if fittedChunk != "" && fittedChunk != chunk { + queue = append([]string{fittedChunk}, queue...) + continue + } + } runeChunk := []rune(chunk) ratio := float64(len(runeChunk)) / float64(len([]rune(content))) smallerLen := int(float64(4096) * ratio * 0.95) // 5% safety margin @@ -270,6 +308,12 @@ func (c *TelegramChannel) Send(ctx context.Context, msg bus.OutboundMessage) ([] replyToID = "" } + if isToolFeedback && len(messageIDs) > 0 { + c.RecordToolFeedbackMessage(trackedChatID, messageIDs[0], toolFeedbackContent) + } else if !isToolFeedback && hasTrackedMsg { + c.dismissTrackedToolFeedbackMessage(ctx, trackedChatID, trackedMsgID) + } + return messageIDs, nil } @@ -437,6 +481,89 @@ func (c *TelegramChannel) DeleteMessage(ctx context.Context, chatID string, mess }) } +func outboundMessageIsToolFeedback(msg bus.OutboundMessage) bool { + if len(msg.Context.Raw) == 0 { + return false + } + return strings.EqualFold(strings.TrimSpace(msg.Context.Raw["message_kind"]), "tool_feedback") +} + +func (c *TelegramChannel) currentToolFeedbackMessage(chatID string) (string, bool) { + if c.progress == nil { + return "", false + } + return c.progress.Current(chatID) +} + +func (c *TelegramChannel) takeToolFeedbackMessage(chatID string) (string, string, bool) { + if c.progress == nil { + return "", "", false + } + return c.progress.Take(chatID) +} + +func (c *TelegramChannel) RecordToolFeedbackMessage(chatID, messageID, content string) { + if c.progress == nil { + return + } + c.progress.Record(chatID, messageID, content) +} + +func (c *TelegramChannel) ClearToolFeedbackMessage(chatID string) { + if c.progress == nil { + return + } + c.progress.Clear(chatID) +} + +func (c *TelegramChannel) DismissToolFeedbackMessage(ctx context.Context, chatID string) { + msgID, ok := c.currentToolFeedbackMessage(chatID) + if !ok { + return + } + c.dismissTrackedToolFeedbackMessage(ctx, chatID, msgID) +} + +func (c *TelegramChannel) dismissTrackedToolFeedbackMessage(ctx context.Context, chatID, messageID string) { + if strings.TrimSpace(chatID) == "" || strings.TrimSpace(messageID) == "" { + return + } + c.ClearToolFeedbackMessage(chatID) + _ = c.DeleteMessage(ctx, chatID, messageID) +} + +func (c *TelegramChannel) finalizeTrackedToolFeedbackMessage( + ctx context.Context, + chatID string, + content string, + editFn func(context.Context, string, string, string) error, +) ([]string, bool) { + msgID, baseContent, ok := c.takeToolFeedbackMessage(chatID) + if !ok || editFn == nil { + return nil, false + } + if err := editFn(ctx, chatID, msgID, content); err != nil { + c.RecordToolFeedbackMessage(chatID, msgID, baseContent) + return nil, false + } + return []string{msgID}, true +} + +func (c *TelegramChannel) FinalizeToolFeedbackMessage(ctx context.Context, msg bus.OutboundMessage) ([]string, bool) { + if outboundMessageIsToolFeedback(msg) { + return nil, false + } + return c.finalizeToolFeedbackMessageForChat(ctx, telegramToolFeedbackChatKey(msg.ChatID, &msg.Context), msg) +} + +func (c *TelegramChannel) finalizeToolFeedbackMessageForChat( + ctx context.Context, + chatID string, + msg bus.OutboundMessage, +) ([]string, bool) { + return c.finalizeTrackedToolFeedbackMessage(ctx, chatID, msg.Content, c.EditMessage) +} + // SendPlaceholder implements channels.PlaceholderCapable. // It sends a placeholder message (e.g. "Thinking... 💭") that will later be // edited to the actual response via EditMessage (channels.MessageEditor). @@ -468,6 +595,8 @@ func (c *TelegramChannel) SendMedia(ctx context.Context, msg bus.OutboundMediaMe if !c.IsRunning() { return nil, channels.ErrNotRunning } + trackedChatID := telegramToolFeedbackChatKey(msg.ChatID, &msg.Context) + trackedMsgID, hasTrackedMsg := c.currentToolFeedbackMessage(trackedChatID) chatID, threadID, err := resolveTelegramOutboundTarget(msg.ChatID, &msg.Context) if err != nil { @@ -576,6 +705,10 @@ func (c *TelegramChannel) SendMedia(ctx context.Context, msg bus.OutboundMediaMe } } + if hasTrackedMsg { + c.dismissTrackedToolFeedbackMessage(ctx, trackedChatID, trackedMsgID) + } + return messageIDs, nil } @@ -947,6 +1080,60 @@ func parseContent(text string, useMarkdownV2 bool) string { return markdownToTelegramHTML(text) } +func fitToolFeedbackForTelegram(content string, useMarkdownV2 bool, maxParsedLen int) string { + content = strings.TrimSpace(content) + if content == "" || maxParsedLen <= 0 { + return "" + } + animationSafeLen := maxParsedLen - channels.MaxToolFeedbackAnimationFrameLength() + if animationSafeLen <= 0 { + animationSafeLen = maxParsedLen + } + if len([]rune(parseContent(content, useMarkdownV2))) <= animationSafeLen { + return content + } + + low := 1 + high := len([]rune(content)) + best := utils.Truncate(content, 1) + + for low <= high { + mid := (low + high) / 2 + candidate := utils.FitToolFeedbackMessage(content, mid) + if candidate == "" { + high = mid - 1 + continue + } + if len([]rune(parseContent(candidate, useMarkdownV2))) <= animationSafeLen { + best = candidate + low = mid + 1 + continue + } + high = mid - 1 + } + + return best +} + +func (c *TelegramChannel) PrepareToolFeedbackMessageContent(content string) string { + if c == nil || c.tgCfg == nil { + return strings.TrimSpace(content) + } + return fitToolFeedbackForTelegram(content, c.tgCfg.UseMarkdownV2, 4096) +} + +func telegramToolFeedbackChatKey(chatID string, outboundCtx *bus.InboundContext) string { + resolvedChatID, threadID, err := resolveTelegramOutboundTarget(chatID, outboundCtx) + if err != nil || threadID == 0 { + return strings.TrimSpace(chatID) + } + return fmt.Sprintf("%d/%d", resolvedChatID, threadID) +} + +func (c *TelegramChannel) ToolFeedbackMessageChatID(chatID string, outboundCtx *bus.InboundContext) string { + return telegramToolFeedbackChatKey(chatID, outboundCtx) +} + // parseTelegramChatID splits "chatID/threadID" into its components. // Returns threadID=0 when no "/" is present (non-forum messages). func parseTelegramChatID(chatID string) (int64, int, error) { @@ -1097,7 +1284,7 @@ func (c *TelegramChannel) BeginStream(ctx context.Context, chatID string) (chann return nil, fmt.Errorf("streaming disabled in config") } - cid, _, err := parseTelegramChatID(chatID) + cid, threadID, err := parseTelegramChatID(chatID) if err != nil { return nil, err } @@ -1106,6 +1293,7 @@ func (c *TelegramChannel) BeginStream(ctx context.Context, chatID string) (chann return &telegramStreamer{ bot: c.bot, chatID: cid, + threadID: threadID, draftID: cryptoRandInt(), throttleInterval: time.Duration(streamCfg.ThrottleSeconds) * time.Second, minGrowth: streamCfg.MinGrowthChars, @@ -1118,6 +1306,7 @@ func (c *TelegramChannel) BeginStream(ctx context.Context, chatID string) (chann type telegramStreamer struct { bot *telego.Bot chatID int64 + threadID int draftID int throttleInterval time.Duration minGrowth int @@ -1145,10 +1334,11 @@ func (s *telegramStreamer) Update(ctx context.Context, content string) error { htmlContent := markdownToTelegramHTML(content) err := s.bot.SendMessageDraft(ctx, &telego.SendMessageDraftParams{ - ChatID: s.chatID, - DraftID: s.draftID, - Text: htmlContent, - ParseMode: telego.ModeHTML, + ChatID: s.chatID, + MessageThreadID: s.threadID, + DraftID: s.draftID, + Text: htmlContent, + ParseMode: telego.ModeHTML, }) if err != nil { // First error → degrade silently (e.g. no forum mode) @@ -1167,6 +1357,7 @@ func (s *telegramStreamer) Update(ctx context.Context, content string) error { func (s *telegramStreamer) Finalize(ctx context.Context, content string) error { htmlContent := markdownToTelegramHTML(content) tgMsg := tu.Message(tu.ID(s.chatID), htmlContent) + tgMsg.MessageThreadID = s.threadID tgMsg.ParseMode = telego.ModeHTML if _, err := s.bot.SendMessage(ctx, tgMsg); err != nil { diff --git a/pkg/channels/telegram/telegram_group_command_filter_test.go b/pkg/channels/telegram/telegram_group_command_filter_test.go index 614b2ca7f..20b2004a9 100644 --- a/pkg/channels/telegram/telegram_group_command_filter_test.go +++ b/pkg/channels/telegram/telegram_group_command_filter_test.go @@ -108,7 +108,7 @@ func TestHandleMessage_GroupMentionOnly_BotCommandEntity(t *testing.T) { t.Fatalf("handleMessage error: %v", err) } - ctx, cancel := context.WithTimeout(context.Background(), 200*time.Microsecond) + ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) defer cancel() select { case <-ctx.Done(): diff --git a/pkg/channels/telegram/telegram_test.go b/pkg/channels/telegram/telegram_test.go index 3d147b337..69c76b430 100644 --- a/pkg/channels/telegram/telegram_test.go +++ b/pkg/channels/telegram/telegram_test.go @@ -98,8 +98,12 @@ func (s *multipartRecordingConstructor) MultipartRequest( // successResponse returns a ta.Response that telego will treat as a successful SendMessage. func successResponse(t *testing.T) *ta.Response { + return successResponseWithMessageID(t, 1) +} + +func successResponseWithMessageID(t *testing.T, messageID int) *ta.Response { t.Helper() - msg := &telego.Message{MessageID: 1} + msg := &telego.Message{MessageID: messageID} b, err := json.Marshal(msg) require.NoError(t, err) return &ta.Response{Ok: true, Result: b} @@ -142,6 +146,7 @@ func newTestChannelWithConstructor( chatIDs: make(map[string]int64), bc: &config.Channel{Type: config.ChannelTelegram, Enabled: true}, tgCfg: &config.TelegramSettings{}, + progress: channels.NewToolFeedbackAnimator(nil), } } @@ -266,6 +271,176 @@ func TestSend_ShortMessage_SingleCall(t *testing.T) { assert.Len(t, caller.calls, 1, "short message should result in exactly one SendMessage call") } +func TestSend_NonToolFeedbackDeletesTrackedProgressMessage(t *testing.T) { + caller := &stubCaller{ + callFn: func(ctx context.Context, url string, data *ta.RequestData) (*ta.Response, error) { + switch { + case strings.Contains(url, "editMessageText"): + return successResponseWithMessageID(t, 1), nil + default: + t.Fatalf("unexpected API call: %s", url) + return nil, nil + } + }, + } + ch := newTestChannel(t, caller) + ch.RecordToolFeedbackMessage("12345", "1", "🔧 `read_file`") + + ids, err := ch.Send(context.Background(), bus.OutboundMessage{ + ChatID: "12345", + Content: "final reply", + }) + + assert.NoError(t, err) + assert.Equal(t, []string{"1"}, ids) + require.Len(t, caller.calls, 1) + assert.Contains(t, caller.calls[0].URL, "editMessageText") + _, ok := ch.currentToolFeedbackMessage("12345") + assert.False(t, ok, "tracked tool feedback should be cleared after final reply") +} + +func TestSend_ToolFeedbackTrackingIsTopicScoped(t *testing.T) { + nextMessageID := 0 + caller := &stubCaller{ + callFn: func(ctx context.Context, url string, data *ta.RequestData) (*ta.Response, error) { + nextMessageID++ + return successResponseWithMessageID(t, nextMessageID), nil + }, + } + ch := newTestChannel(t, caller) + + _, err := ch.Send(context.Background(), bus.OutboundMessage{ + ChatID: "-1001234567890", + Content: "🔧 `read_file`", + Context: bus.InboundContext{ + Channel: "telegram", + ChatID: "-1001234567890", + TopicID: "42", + Raw: map[string]string{ + "message_kind": "tool_feedback", + }, + }, + }) + require.NoError(t, err) + + _, ok := ch.currentToolFeedbackMessage("-1001234567890") + assert.False(t, ok, "base chat should not track topic-specific tool feedback") + + msgID, ok := ch.currentToolFeedbackMessage("-1001234567890/42") + require.True(t, ok, "topic chat should track tool feedback") + assert.Equal(t, "1", msgID) +} + +func TestSend_TopicReplyDoesNotFinalizeDifferentTopicToolFeedback(t *testing.T) { + nextMessageID := 0 + caller := &stubCaller{ + callFn: func(ctx context.Context, url string, data *ta.RequestData) (*ta.Response, error) { + nextMessageID++ + return successResponseWithMessageID(t, nextMessageID), nil + }, + } + ch := newTestChannel(t, caller) + + _, err := ch.Send(context.Background(), bus.OutboundMessage{ + ChatID: "-1001234567890", + Content: "🔧 `read_file`", + Context: bus.InboundContext{ + Channel: "telegram", + ChatID: "-1001234567890", + TopicID: "42", + Raw: map[string]string{ + "message_kind": "tool_feedback", + }, + }, + }) + require.NoError(t, err) + + ids, err := ch.Send(context.Background(), bus.OutboundMessage{ + ChatID: "-1001234567890", + Content: "final reply in another topic", + Context: bus.InboundContext{ + Channel: "telegram", + ChatID: "-1001234567890", + TopicID: "43", + }, + }) + require.NoError(t, err) + require.Len(t, caller.calls, 2) + assert.Equal(t, []string{"2"}, ids) + assert.Contains(t, caller.calls[1].URL, "sendMessage") + assert.NotContains(t, caller.calls[1].URL, "editMessageText") + + _, ok := ch.currentToolFeedbackMessage("-1001234567890/42") + assert.True(t, ok, "tool feedback in the original topic should remain tracked") +} + +func TestFinalizeTrackedToolFeedbackMessage_StopsTrackingBeforeEdit(t *testing.T) { + ch := newTestChannel(t, &stubCaller{ + callFn: func(context.Context, string, *ta.RequestData) (*ta.Response, error) { + t.Fatal("unexpected API call") + return nil, nil + }, + }) + ch.RecordToolFeedbackMessage("12345", "1", "🔧 `read_file`") + + msgIDs, handled := ch.finalizeTrackedToolFeedbackMessage( + context.Background(), + "12345", + "final reply", + func(_ context.Context, chatID, messageID, content string) error { + _, ok := ch.currentToolFeedbackMessage(chatID) + assert.False(t, ok, "tracked tool feedback should be stopped before edit") + assert.Equal(t, "12345", chatID) + assert.Equal(t, "1", messageID) + assert.Equal(t, "final reply", content) + return nil + }, + ) + + assert.True(t, handled) + assert.Equal(t, []string{"1"}, msgIDs) +} + +func TestSend_ToolFeedbackStaysSingleMessageAfterHTMLExpansion(t *testing.T) { + caller := &stubCaller{ + callFn: func(ctx context.Context, url string, data *ta.RequestData) (*ta.Response, error) { + return successResponse(t), nil + }, + } + ch := newTestChannel(t, caller) + + _, err := ch.Send(context.Background(), bus.OutboundMessage{ + ChatID: "12345", + Content: "🔧 `read_file`\n" + strings.Repeat("<", 2000), + Context: bus.InboundContext{ + Channel: "telegram", + ChatID: "12345", + Raw: map[string]string{ + "message_kind": "tool_feedback", + }, + }, + }) + + assert.NoError(t, err) + assert.Len(t, caller.calls, 1, "tool feedback should stay a single Telegram message after HTML escaping") +} + +func TestFitToolFeedbackForTelegram_ReservesAnimationFrame(t *testing.T) { + content := "🔧 `read_file`\n" + strings.Repeat("a", 4096) + + fitted := fitToolFeedbackForTelegram(content, false, 4096) + animated := strings.Replace( + fitted, + "`\n", + strings.Repeat(".", channels.MaxToolFeedbackAnimationFrameLength())+"`\n", + 1, + ) + + if got := len([]rune(parseContent(animated, false))); got > 4096 { + t.Fatalf("animated parsed length = %d, want <= 4096", got) + } +} + func TestSend_LongMessage_SingleCall(t *testing.T) { // With WithMaxMessageLength(4000), the Manager pre-splits messages before // they reach Send(). A message at exactly 4000 chars should go through @@ -560,6 +735,58 @@ func TestSend_UsesContextTopicIDWhenChatIDDoesNotIncludeThread(t *testing.T) { assert.Equal(t, "Hello from topic context", params.Text) } +func TestBeginStream_UpdateUsesForumThreadID(t *testing.T) { + caller := &stubCaller{ + callFn: func(ctx context.Context, url string, data *ta.RequestData) (*ta.Response, error) { + return &ta.Response{Ok: true, Result: []byte("true")}, nil + }, + } + ch := newTestChannel(t, caller) + ch.tgCfg.Streaming.Enabled = true + + streamer, err := ch.BeginStream(context.Background(), "-1001234567890/42") + require.NoError(t, err) + require.NoError(t, streamer.Update(context.Background(), "partial")) + require.Len(t, caller.calls, 1) + assert.Contains(t, caller.calls[0].URL, "sendMessageDraft") + + var params struct { + ChatID int64 `json:"chat_id"` + MessageThreadID int `json:"message_thread_id"` + Text string `json:"text"` + } + require.NoError(t, json.Unmarshal(caller.calls[0].Data.BodyRaw, ¶ms)) + assert.Equal(t, int64(-1001234567890), params.ChatID) + assert.Equal(t, 42, params.MessageThreadID) + assert.Equal(t, "partial", params.Text) +} + +func TestBeginStream_FinalizeUsesForumThreadID(t *testing.T) { + caller := &stubCaller{ + callFn: func(ctx context.Context, url string, data *ta.RequestData) (*ta.Response, error) { + return successResponse(t), nil + }, + } + ch := newTestChannel(t, caller) + ch.tgCfg.Streaming.Enabled = true + + streamer, err := ch.BeginStream(context.Background(), "-1001234567890/42") + require.NoError(t, err) + require.NoError(t, streamer.Finalize(context.Background(), "final")) + require.Len(t, caller.calls, 1) + assert.Contains(t, caller.calls[0].URL, "sendMessage") + + var params struct { + ChatID int64 `json:"chat_id"` + MessageThreadID int `json:"message_thread_id"` + Text string `json:"text"` + } + require.NoError(t, json.Unmarshal(caller.calls[0].Data.BodyRaw, ¶ms)) + assert.Equal(t, int64(-1001234567890), params.ChatID) + assert.Equal(t, 42, params.MessageThreadID) + assert.Equal(t, "final", params.Text) +} + func TestHandleMessage_ForumTopic_SetsMetadata(t *testing.T) { messageBus := bus.NewMessageBus() ch := &TelegramChannel{ diff --git a/pkg/channels/tool_feedback_animator.go b/pkg/channels/tool_feedback_animator.go new file mode 100644 index 000000000..b424612bf --- /dev/null +++ b/pkg/channels/tool_feedback_animator.go @@ -0,0 +1,240 @@ +package channels + +import ( + "context" + "strings" + "sync" + "time" +) + +const toolFeedbackAnimationInterval = 3 * time.Second + +const initialToolFeedbackAnimationFrame = "" + +var toolFeedbackAnimationFrames = []string{"..", "."} + +// MaxToolFeedbackAnimationFrameLength returns the largest frame suffix length +// so callers can reserve room before sending messages to length-limited APIs. +func MaxToolFeedbackAnimationFrameLength() int { + maxLen := len([]rune(initialToolFeedbackAnimationFrame)) + for _, frame := range toolFeedbackAnimationFrames { + if frameLen := len([]rune(frame)); frameLen > maxLen { + maxLen = frameLen + } + } + return maxLen +} + +type toolFeedbackAnimationState struct { + messageID string + baseContent string + stop chan struct{} + done chan struct{} +} + +type ToolFeedbackAnimator struct { + mu sync.Mutex + editFn func(ctx context.Context, chatID, messageID, content string) error + entries map[string]*toolFeedbackAnimationState +} + +func NewToolFeedbackAnimator( + editFn func(ctx context.Context, chatID, messageID, content string) error, +) *ToolFeedbackAnimator { + return &ToolFeedbackAnimator{ + editFn: editFn, + entries: make(map[string]*toolFeedbackAnimationState), + } +} + +func (a *ToolFeedbackAnimator) Current(chatID string) (string, bool) { + if a == nil || strings.TrimSpace(chatID) == "" { + return "", false + } + a.mu.Lock() + defer a.mu.Unlock() + entry, ok := a.entries[chatID] + if !ok || strings.TrimSpace(entry.messageID) == "" { + return "", false + } + return entry.messageID, true +} + +func (a *ToolFeedbackAnimator) Record(chatID, messageID, content string) { + if a == nil { + return + } + chatID = strings.TrimSpace(chatID) + messageID = strings.TrimSpace(messageID) + content = strings.TrimSpace(content) + if chatID == "" || messageID == "" || content == "" { + return + } + + entry := &toolFeedbackAnimationState{ + messageID: messageID, + baseContent: content, + stop: make(chan struct{}), + done: make(chan struct{}), + } + + var previous *toolFeedbackAnimationState + a.mu.Lock() + if old, ok := a.entries[chatID]; ok { + previous = old + } + a.entries[chatID] = entry + a.mu.Unlock() + + stopToolFeedbackAnimation(previous) + go a.run(chatID, entry) +} + +func (a *ToolFeedbackAnimator) Clear(chatID string) { + if a == nil || strings.TrimSpace(chatID) == "" { + return + } + entry := a.detach(chatID) + stopToolFeedbackAnimation(entry) +} + +func (a *ToolFeedbackAnimator) Take(chatID string) (string, string, bool) { + if a == nil || strings.TrimSpace(chatID) == "" { + return "", "", false + } + entry := a.detach(chatID) + if entry == nil || strings.TrimSpace(entry.messageID) == "" { + return "", "", false + } + stopToolFeedbackAnimation(entry) + return entry.messageID, entry.baseContent, true +} + +// Update edits an existing tracked feedback message. If the edit fails, the +// previous feedback state is restored so callers can retry without orphaning +// the old progress message. +func (a *ToolFeedbackAnimator) Update(ctx context.Context, chatID, content string) (string, bool, error) { + if a == nil || a.editFn == nil { + return "", false, nil + } + msgID, baseContent, ok := a.Take(chatID) + if !ok { + return "", false, nil + } + + animatedContent := InitialAnimatedToolFeedbackContent(content) + if err := a.editFn(ctx, strings.TrimSpace(chatID), msgID, animatedContent); err != nil { + a.Record(chatID, msgID, baseContent) + return "", true, err + } + + a.Record(chatID, msgID, content) + return msgID, true, nil +} + +func (a *ToolFeedbackAnimator) StopAll() { + if a == nil { + return + } + a.mu.Lock() + entries := make([]*toolFeedbackAnimationState, 0, len(a.entries)) + for chatID, entry := range a.entries { + entries = append(entries, entry) + delete(a.entries, chatID) + } + a.mu.Unlock() + + for _, entry := range entries { + stopToolFeedbackAnimation(entry) + } +} + +func (a *ToolFeedbackAnimator) detach(chatID string) *toolFeedbackAnimationState { + if a == nil || strings.TrimSpace(chatID) == "" { + return nil + } + a.mu.Lock() + defer a.mu.Unlock() + entry := a.entries[chatID] + delete(a.entries, chatID) + return entry +} + +func (a *ToolFeedbackAnimator) run(chatID string, entry *toolFeedbackAnimationState) { + defer close(entry.done) + + ticker := time.NewTicker(toolFeedbackAnimationInterval) + defer ticker.Stop() + + frameIdx := 1 + + for { + select { + case <-entry.stop: + return + case <-ticker.C: + if a.editFn == nil { + continue + } + frame := toolFeedbackAnimationFrames[frameIdx%len(toolFeedbackAnimationFrames)] + content := formatAnimatedToolFeedbackContent(entry.baseContent, frame) + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + _ = a.editFn(ctx, chatID, entry.messageID, content) + cancel() + frameIdx++ + } + } +} + +func InitialAnimatedToolFeedbackContent(baseContent string) string { + return formatAnimatedToolFeedbackContent(baseContent, initialToolFeedbackAnimationFrame) +} + +func formatAnimatedToolFeedbackContent(baseContent, frame string) string { + baseContent = strings.TrimSpace(baseContent) + frame = strings.TrimSpace(frame) + if baseContent == "" { + return "" + } + if frame == "" { + return baseContent + } + lineBreak := strings.IndexByte(baseContent, '\n') + if lineBreak < 0 { + return appendToolFeedbackFrame(baseContent, frame) + } + return appendToolFeedbackFrame(baseContent[:lineBreak], frame) + baseContent[lineBreak:] +} + +func appendToolFeedbackFrame(firstLine, frame string) string { + firstLine = strings.TrimSpace(firstLine) + frame = strings.TrimSpace(frame) + if firstLine == "" { + return "" + } + if frame == "" { + return firstLine + } + + openTick := strings.IndexByte(firstLine, '`') + if openTick >= 0 { + if closeOffset := strings.IndexByte(firstLine[openTick+1:], '`'); closeOffset >= 0 { + closeTick := openTick + 1 + closeOffset + return firstLine[:closeTick] + frame + firstLine[closeTick:] + } + } + + return firstLine + frame +} + +func stopToolFeedbackAnimation(entry *toolFeedbackAnimationState) { + if entry == nil { + return + } + select { + case <-entry.stop: + default: + close(entry.stop) + } + <-entry.done +} diff --git a/pkg/channels/tool_feedback_animator_test.go b/pkg/channels/tool_feedback_animator_test.go new file mode 100644 index 000000000..a23284548 --- /dev/null +++ b/pkg/channels/tool_feedback_animator_test.go @@ -0,0 +1,121 @@ +package channels + +import ( + "context" + "errors" + "testing" +) + +func TestFormatAnimatedToolFeedbackContent(t *testing.T) { + got := formatAnimatedToolFeedbackContent("🔧 `read_file`\nReading config file", "running..") + want := "🔧 `read_filerunning..`\nReading config file" + if got != want { + t.Fatalf("formatAnimatedToolFeedbackContent() = %q, want %q", got, want) + } +} + +func TestInitialAnimatedToolFeedbackContent(t *testing.T) { + got := InitialAnimatedToolFeedbackContent("🔧 `exec`\nRunning command") + want := "🔧 `exec`\nRunning command" + if got != want { + t.Fatalf("InitialAnimatedToolFeedbackContent() = %q, want %q", got, want) + } +} + +func TestFormatAnimatedToolFeedbackContent_WithoutCodeSpan(t *testing.T) { + got := formatAnimatedToolFeedbackContent("hello", "running..") + want := "hellorunning.." + if got != want { + t.Fatalf("formatAnimatedToolFeedbackContent() without code span = %q, want %q", got, want) + } +} + +func TestToolFeedbackAnimator_RecordCurrentAndClear(t *testing.T) { + animator := NewToolFeedbackAnimator(nil) + animator.Record("chat-1", "msg-1", "🔧 `read_file`") + + msgID, ok := animator.Current("chat-1") + if !ok || msgID != "msg-1" { + t.Fatalf("Current() = (%q, %v), want (msg-1, true)", msgID, ok) + } + + animator.Clear("chat-1") + + msgID, ok = animator.Current("chat-1") + if ok || msgID != "" { + t.Fatalf("Current() after Clear = (%q, %v), want (\"\", false)", msgID, ok) + } +} + +func TestToolFeedbackAnimator_TakeStopsTrackingAndReturnsState(t *testing.T) { + animator := NewToolFeedbackAnimator(nil) + animator.Record("chat-1", "msg-1", "🔧 `read_file`\nChecking config") + + msgID, baseContent, ok := animator.Take("chat-1") + if !ok { + t.Fatal("Take() = not found, want tracked message") + } + if msgID != "msg-1" { + t.Fatalf("Take() msgID = %q, want msg-1", msgID) + } + if baseContent != "🔧 `read_file`\nChecking config" { + t.Fatalf("Take() baseContent = %q", baseContent) + } + if _, ok := animator.Current("chat-1"); ok { + t.Fatal("expected tracked message to be removed after Take()") + } +} + +func TestToolFeedbackAnimator_UpdateStopsTrackingBeforeEdit(t *testing.T) { + var animator *ToolFeedbackAnimator + animator = NewToolFeedbackAnimator(func(_ context.Context, chatID, messageID, content string) error { + if _, ok := animator.Current(chatID); ok { + t.Fatal("expected tracked tool feedback to be stopped before edit") + } + if messageID != "msg-1" { + t.Fatalf("messageID = %q, want msg-1", messageID) + } + if content != "🔧 `write_file`\nUpdating config" { + t.Fatalf("content = %q, want updated animated content", content) + } + return nil + }) + defer animator.StopAll() + + animator.Record("chat-1", "msg-1", "🔧 `read_file`\nChecking config") + + msgID, handled, err := animator.Update(context.Background(), "chat-1", "🔧 `write_file`\nUpdating config") + if err != nil { + t.Fatalf("Update() error = %v", err) + } + if !handled { + t.Fatal("Update() handled = false, want true") + } + if msgID != "msg-1" { + t.Fatalf("Update() msgID = %q, want msg-1", msgID) + } +} + +func TestToolFeedbackAnimator_UpdateFailureRestoresTracking(t *testing.T) { + editErr := errors.New("edit failed") + animator := NewToolFeedbackAnimator(func(context.Context, string, string, string) error { + return editErr + }) + defer animator.StopAll() + + animator.Record("chat-1", "msg-1", "🔧 `read_file`\nChecking config") + + msgID, handled, err := animator.Update(context.Background(), "chat-1", "🔧 `write_file`\nUpdating config") + if !handled { + t.Fatal("Update() handled = false, want true") + } + if !errors.Is(err, editErr) { + t.Fatalf("Update() error = %v, want editErr", err) + } + if msgID != "" { + t.Fatalf("Update() msgID = %q, want empty on failed edit", msgID) + } + if currentID, ok := animator.Current("chat-1"); !ok || currentID != "msg-1" { + t.Fatalf("Current() after failed Update = (%q, %v), want (msg-1, true)", currentID, ok) + } +} diff --git a/pkg/config/config.go b/pkg/config/config.go index a39cb55ae..161108638 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -286,7 +286,7 @@ func (d *AgentDefaults) GetMaxMediaSize() int { return DefaultMaxMediaSize } -// GetToolFeedbackMaxArgsLength returns the max args preview length for tool feedback messages. +// GetToolFeedbackMaxArgsLength returns the max visible text length for tool feedback messages. func (d *AgentDefaults) GetToolFeedbackMaxArgsLength() int { if d.ToolFeedback.MaxArgsLength > 0 { return d.ToolFeedback.MaxArgsLength diff --git a/pkg/providers/cli/toolcall_utils.go b/pkg/providers/cli/toolcall_utils.go index b480082eb..1f58c9a26 100644 --- a/pkg/providers/cli/toolcall_utils.go +++ b/pkg/providers/cli/toolcall_utils.go @@ -55,6 +55,12 @@ func buildCLIToolsPrompt(tools []ToolDefinition) string { func NormalizeToolCall(tc ToolCall) ToolCall { normalized := tc + if normalized.ThoughtSignature == "" && + normalized.ExtraContent != nil && + normalized.ExtraContent.Google != nil { + normalized.ThoughtSignature = normalized.ExtraContent.Google.ThoughtSignature + } + // Ensure Name is populated from Function if not set if normalized.Name == "" && normalized.Function != nil { normalized.Name = normalized.Function.Name @@ -77,8 +83,9 @@ func NormalizeToolCall(tc ToolCall) ToolCall { argsJSON, _ := json.Marshal(normalized.Arguments) if normalized.Function == nil { normalized.Function = &FunctionCall{ - Name: normalized.Name, - Arguments: string(argsJSON), + Name: normalized.Name, + Arguments: string(argsJSON), + ThoughtSignature: normalized.ThoughtSignature, } } else { if normalized.Function.Name == "" { @@ -90,6 +97,12 @@ func NormalizeToolCall(tc ToolCall) ToolCall { if normalized.Function.Arguments == "" { normalized.Function.Arguments = string(argsJSON) } + if normalized.Function.ThoughtSignature == "" { + normalized.Function.ThoughtSignature = normalized.ThoughtSignature + } + if normalized.ThoughtSignature == "" { + normalized.ThoughtSignature = normalized.Function.ThoughtSignature + } } return normalized diff --git a/pkg/providers/common/common.go b/pkg/providers/common/common.go index 90142fb8b..0a702e85e 100644 --- a/pkg/providers/common/common.go +++ b/pkg/providers/common/common.go @@ -70,11 +70,23 @@ func NewHTTPClient(proxy string) *http.Client { // It mirrors protocoltypes.Message but omits SystemParts, which is an // internal field that would be unknown to third-party endpoints. type openaiMessage struct { - Role string `json:"role"` - Content string `json:"content"` - ReasoningContent string `json:"reasoning_content,omitempty"` - ToolCalls []ToolCall `json:"tool_calls,omitempty"` - ToolCallID string `json:"tool_call_id,omitempty"` + Role string `json:"role"` + Content string `json:"content"` + ReasoningContent string `json:"reasoning_content,omitempty"` + ToolCalls []openaiToolCall `json:"tool_calls,omitempty"` + ToolCallID string `json:"tool_call_id,omitempty"` +} + +type openaiToolCall struct { + ID string `json:"id"` + Type string `json:"type,omitempty"` + Function *openaiFunctionCall `json:"function,omitempty"` +} + +type openaiFunctionCall struct { + Name string `json:"name"` + Arguments string `json:"arguments"` + ThoughtSignature string `json:"thought_signature,omitempty"` } // SerializeMessages converts internal Message structs to the OpenAI wire format. @@ -84,12 +96,13 @@ type openaiMessage struct { func SerializeMessages(messages []Message) []any { out := make([]any, 0, len(messages)) for _, m := range messages { + toolCalls := serializeToolCalls(m.ToolCalls) if len(m.Media) == 0 { out = append(out, openaiMessage{ Role: m.Role, Content: m.Content, ReasoningContent: m.ReasoningContent, - ToolCalls: m.ToolCalls, + ToolCalls: toolCalls, ToolCallID: m.ToolCallID, }) continue @@ -132,8 +145,8 @@ func SerializeMessages(messages []Message) []any { if m.ToolCallID != "" { msg["tool_call_id"] = m.ToolCallID } - if len(m.ToolCalls) > 0 { - msg["tool_calls"] = m.ToolCalls + if len(toolCalls) > 0 { + msg["tool_calls"] = toolCalls } if m.ReasoningContent != "" { msg["reasoning_content"] = m.ReasoningContent @@ -143,6 +156,55 @@ func SerializeMessages(messages []Message) []any { return out } +func serializeToolCalls(toolCalls []ToolCall) []openaiToolCall { + if len(toolCalls) == 0 { + return nil + } + + out := make([]openaiToolCall, 0, len(toolCalls)) + for _, tc := range toolCalls { + wireCall := openaiToolCall{ + ID: tc.ID, + Type: tc.Type, + } + + if tc.Function != nil { + thoughtSignature := tc.Function.ThoughtSignature + if thoughtSignature == "" { + thoughtSignature = tc.ThoughtSignature + } + if thoughtSignature == "" && tc.ExtraContent != nil && tc.ExtraContent.Google != nil { + thoughtSignature = tc.ExtraContent.Google.ThoughtSignature + } + wireCall.Function = &openaiFunctionCall{ + Name: tc.Function.Name, + Arguments: tc.Function.Arguments, + ThoughtSignature: thoughtSignature, + } + } else if tc.Name != "" || len(tc.Arguments) > 0 || tc.ThoughtSignature != "" { + thoughtSignature := tc.ThoughtSignature + if thoughtSignature == "" && tc.ExtraContent != nil && tc.ExtraContent.Google != nil { + thoughtSignature = tc.ExtraContent.Google.ThoughtSignature + } + argsJSON := "{}" + if len(tc.Arguments) > 0 { + if encoded, err := json.Marshal(tc.Arguments); err == nil { + argsJSON = string(encoded) + } + } + wireCall.Function = &openaiFunctionCall{ + Name: tc.Name, + Arguments: argsJSON, + ThoughtSignature: thoughtSignature, + } + } + + out = append(out, wireCall) + } + + return out +} + func parseDataAudioURL(mediaURL string) (format, data string, ok bool) { if !strings.HasPrefix(mediaURL, "data:audio/") { return "", "", false @@ -178,13 +240,15 @@ func ParseResponse(body io.Reader) (*LLMResponse, error) { ID string `json:"id"` Type string `json:"type"` Function *struct { - Name string `json:"name"` - Arguments json.RawMessage `json:"arguments"` + Name string `json:"name"` + Arguments json.RawMessage `json:"arguments"` + ThoughtSignature string `json:"thought_signature"` } `json:"function"` ExtraContent *struct { Google *struct { ThoughtSignature string `json:"thought_signature"` } `json:"google"` + ToolFeedbackExplanation string `json:"tool_feedback_explanation"` } `json:"extra_content"` } `json:"tool_calls"` } `json:"message"` @@ -210,9 +274,11 @@ func ParseResponse(body io.Reader) (*LLMResponse, error) { arguments := make(map[string]any) name := "" - // Extract thought_signature from Gemini/Google-specific extra content thoughtSignature := "" - if tc.ExtraContent != nil && tc.ExtraContent.Google != nil { + if tc.Function != nil { + thoughtSignature = tc.Function.ThoughtSignature + } + if thoughtSignature == "" && tc.ExtraContent != nil && tc.ExtraContent.Google != nil { thoughtSignature = tc.ExtraContent.Google.ThoughtSignature } @@ -228,11 +294,20 @@ func ParseResponse(body io.Reader) (*LLMResponse, error) { ThoughtSignature: thoughtSignature, } - if thoughtSignature != "" { - toolCall.ExtraContent = &ExtraContent{ - Google: &GoogleExtra{ + if thoughtSignature != "" || tc.ExtraContent != nil { + extraContent := &ExtraContent{ + ToolFeedbackExplanation: "", + } + if tc.ExtraContent != nil { + extraContent.ToolFeedbackExplanation = tc.ExtraContent.ToolFeedbackExplanation + } + if thoughtSignature != "" { + extraContent.Google = &GoogleExtra{ ThoughtSignature: thoughtSignature, - }, + } + } + if extraContent.Google != nil || strings.TrimSpace(extraContent.ToolFeedbackExplanation) != "" { + toolCall.ExtraContent = extraContent } } diff --git a/pkg/providers/common/common_test.go b/pkg/providers/common/common_test.go index c107bb665..a42d778f1 100644 --- a/pkg/providers/common/common_test.go +++ b/pkg/providers/common/common_test.go @@ -162,6 +162,104 @@ func TestSerializeMessages_StripsSystemParts(t *testing.T) { } } +func TestSerializeMessages_StripsInternalToolCallExtraContent(t *testing.T) { + messages := []Message{ + { + Role: "assistant", + ToolCalls: []ToolCall{{ + ID: "call_1", + Type: "function", + Function: &FunctionCall{ + Name: "read_file", + Arguments: `{"path":"README.md"}`, + ThoughtSignature: "sig-1", + }, + ExtraContent: &ExtraContent{ + Google: &GoogleExtra{ + ThoughtSignature: "sig-ignored-here", + }, + ToolFeedbackExplanation: "Read README.md first.", + }, + }}, + }, + } + + result := SerializeMessages(messages) + + data, err := json.Marshal(result) + if err != nil { + t.Fatalf("json.Marshal() error = %v", err) + } + payload := string(data) + if strings.Contains(payload, "extra_content") { + t.Fatalf("serialized payload should not include internal extra_content: %s", payload) + } + if !strings.Contains(payload, "thought_signature") { + t.Fatalf("serialized payload should preserve function thought_signature: %s", payload) + } +} + +func TestSerializeMessages_PreservesTopLevelThoughtSignature(t *testing.T) { + messages := []Message{ + { + Role: "assistant", + ToolCalls: []ToolCall{{ + ID: "call_1", + Type: "function", + ThoughtSignature: "sig-1", + Function: &FunctionCall{ + Name: "read_file", + Arguments: `{"path":"README.md"}`, + }, + }}, + }, + } + + result := SerializeMessages(messages) + + data, err := json.Marshal(result) + if err != nil { + t.Fatalf("json.Marshal() error = %v", err) + } + payload := string(data) + if !strings.Contains(payload, `"thought_signature":"sig-1"`) { + t.Fatalf("serialized payload should preserve top-level thought signature: %s", payload) + } +} + +func TestSerializeMessages_PreservesGoogleExtraThoughtSignature(t *testing.T) { + messages := []Message{ + { + Role: "assistant", + ToolCalls: []ToolCall{{ + ID: "call_1", + Type: "function", + Function: &FunctionCall{ + Name: "read_file", + Arguments: `{"path":"README.md"}`, + }, + ExtraContent: &ExtraContent{ + Google: &GoogleExtra{ThoughtSignature: "sig-1"}, + }, + }}, + }, + } + + result := SerializeMessages(messages) + + data, err := json.Marshal(result) + if err != nil { + t.Fatalf("json.Marshal() error = %v", err) + } + payload := string(data) + if strings.Contains(payload, "extra_content") { + t.Fatalf("serialized payload should not include extra_content: %s", payload) + } + if !strings.Contains(payload, `"thought_signature":"sig-1"`) { + t.Fatalf("serialized payload should preserve google thought signature: %s", payload) + } +} + // --- ParseResponse tests --- func TestParseResponse_BasicContent(t *testing.T) { @@ -234,6 +332,27 @@ func TestParseResponse_WithReasoningContent(t *testing.T) { } } +func TestParseResponse_WithToolFeedbackExplanationExtraContent(t *testing.T) { + body := `{"choices":[{"message":{"content":"","tool_calls":[{"id":"call_1","type":"function","function":{"name":"test_tool","arguments":"{}"},"extra_content":{"tool_feedback_explanation":"Check the current config before editing."}}]},"finish_reason":"tool_calls"}]}` + out, err := ParseResponse(strings.NewReader(body)) + if err != nil { + t.Fatalf("ParseResponse() error = %v", err) + } + if len(out.ToolCalls) != 1 { + t.Fatalf("len(ToolCalls) = %d, want 1", len(out.ToolCalls)) + } + if out.ToolCalls[0].ExtraContent == nil { + t.Fatal("ExtraContent is nil") + } + if out.ToolCalls[0].ExtraContent.ToolFeedbackExplanation != "Check the current config before editing." { + t.Fatalf( + "ToolFeedbackExplanation = %q, want %q", + out.ToolCalls[0].ExtraContent.ToolFeedbackExplanation, + "Check the current config before editing.", + ) + } +} + func TestParseResponse_InvalidJSON(t *testing.T) { _, err := ParseResponse(strings.NewReader("not json")) if err == nil { @@ -626,3 +745,27 @@ func TestParseResponse_WithThoughtSignature(t *testing.T) { out.ToolCalls[0].ExtraContent.Google.ThoughtSignature, "sig123") } } + +func TestParseResponse_WithFunctionThoughtSignature(t *testing.T) { + body := `{"choices":[{"message":{"content":"","tool_calls":[{"id":"call_1","type":"function","function":{"name":"test_tool","arguments":"{}","thought_signature":"sig456"}}]},"finish_reason":"tool_calls"}]}` + out, err := ParseResponse(strings.NewReader(body)) + if err != nil { + t.Fatalf("ParseResponse() error = %v", err) + } + if len(out.ToolCalls) != 1 { + t.Fatalf("len(ToolCalls) = %d, want 1", len(out.ToolCalls)) + } + if out.ToolCalls[0].ThoughtSignature != "sig456" { + t.Fatalf("ThoughtSignature = %q, want %q", out.ToolCalls[0].ThoughtSignature, "sig456") + } + if out.ToolCalls[0].ExtraContent == nil || out.ToolCalls[0].ExtraContent.Google == nil { + t.Fatal("ExtraContent.Google is nil") + } + if out.ToolCalls[0].ExtraContent.Google.ThoughtSignature != "sig456" { + t.Fatalf( + "ExtraContent.Google.ThoughtSignature = %q, want %q", + out.ToolCalls[0].ExtraContent.Google.ThoughtSignature, + "sig456", + ) + } +} diff --git a/pkg/providers/protocoltypes/types.go b/pkg/providers/protocoltypes/types.go index 89f68928a..f3553f8b0 100644 --- a/pkg/providers/protocoltypes/types.go +++ b/pkg/providers/protocoltypes/types.go @@ -11,7 +11,8 @@ type ToolCall struct { } type ExtraContent struct { - Google *GoogleExtra `json:"google,omitempty"` + Google *GoogleExtra `json:"google,omitempty"` + ToolFeedbackExplanation string `json:"tool_feedback_explanation,omitempty"` } type GoogleExtra struct { diff --git a/pkg/providers/toolcall_utils_test.go b/pkg/providers/toolcall_utils_test.go new file mode 100644 index 000000000..a4bb03c2e --- /dev/null +++ b/pkg/providers/toolcall_utils_test.go @@ -0,0 +1,24 @@ +package providers + +import "testing" + +func TestNormalizeToolCall_PreservesExtraContentGoogleThoughtSignature(t *testing.T) { + tc := NormalizeToolCall(ToolCall{ + ID: "call_1", + Name: "search", + Arguments: map[string]any{"q": "pico"}, + ExtraContent: &ExtraContent{ + Google: &GoogleExtra{ThoughtSignature: "sig-1"}, + }, + }) + + if tc.ThoughtSignature != "sig-1" { + t.Fatalf("ThoughtSignature = %q, want sig-1", tc.ThoughtSignature) + } + if tc.Function == nil { + t.Fatal("Function is nil") + } + if tc.Function.ThoughtSignature != "sig-1" { + t.Fatalf("Function.ThoughtSignature = %q, want sig-1", tc.Function.ThoughtSignature) + } +} diff --git a/pkg/utils/tool_feedback.go b/pkg/utils/tool_feedback.go index a6c8895b8..1a8b6c747 100644 --- a/pkg/utils/tool_feedback.go +++ b/pkg/utils/tool_feedback.go @@ -1,9 +1,57 @@ package utils -import "fmt" +import ( + "fmt" + "strings" +) -// FormatToolFeedbackMessage renders the tool name and arguments preview in the -// same markdown shape used by live tool feedback and session reconstruction. -func FormatToolFeedbackMessage(toolName, argsPreview string) string { - return fmt.Sprintf("\U0001f527 `%s`\n```\n%s\n```", toolName, argsPreview) +const ToolFeedbackContinuationHint = "Continuing the current task." + +// FormatToolFeedbackMessage renders the model-provided explanation for why a +// tool is being executed. When the model does not provide one, it keeps only +// the tool line and does not expose raw arguments or fallback text. +func FormatToolFeedbackMessage(toolName, explanation string) string { + toolName = strings.TrimSpace(toolName) + explanation = strings.TrimSpace(explanation) + + if toolName == "" { + return explanation + } + if explanation == "" { + return fmt.Sprintf("\U0001f527 `%s`", toolName) + } + + return fmt.Sprintf("\U0001f527 `%s`\n%s", toolName, explanation) +} + +// FitToolFeedbackMessage keeps tool feedback within a single outbound message. +// It preserves the first line when possible and truncates the explanation body +// instead of letting the message be split into multiple chunks. +func FitToolFeedbackMessage(content string, maxLen int) string { + content = strings.TrimSpace(content) + if content == "" || maxLen <= 0 { + return "" + } + if len([]rune(content)) <= maxLen { + return content + } + + firstLine, rest, hasRest := strings.Cut(content, "\n") + firstLine = strings.TrimSpace(firstLine) + rest = strings.TrimSpace(rest) + + if !hasRest || rest == "" { + return Truncate(firstLine, maxLen) + } + + if len([]rune(firstLine)) >= maxLen { + return Truncate(firstLine, maxLen) + } + + remaining := maxLen - len([]rune(firstLine)) - 1 + if remaining <= 0 { + return Truncate(firstLine, maxLen) + } + + return firstLine + "\n" + Truncate(rest, remaining) } diff --git a/pkg/utils/tool_feedback_test.go b/pkg/utils/tool_feedback_test.go index d7a55ce6b..316ce2408 100644 --- a/pkg/utils/tool_feedback_test.go +++ b/pkg/utils/tool_feedback_test.go @@ -3,9 +3,47 @@ package utils import "testing" func TestFormatToolFeedbackMessage(t *testing.T) { - got := FormatToolFeedbackMessage("read_file", "{\"path\":\"README.md\"}") - want := "\U0001f527 `read_file`\n```\n{\"path\":\"README.md\"}\n```" + got := FormatToolFeedbackMessage( + "read_file", + "I will read README.md first to confirm the current project structure.", + ) + want := "\U0001f527 `read_file`\nI will read README.md first to confirm the current project structure." if got != want { t.Fatalf("FormatToolFeedbackMessage() = %q, want %q", got, want) } } + +func TestFormatToolFeedbackMessage_EmptyExplanationKeepsOnlyToolLine(t *testing.T) { + got := FormatToolFeedbackMessage("read_file", "") + want := "\U0001f527 `read_file`" + if got != want { + t.Fatalf("FormatToolFeedbackMessage() = %q, want %q", got, want) + } +} + +func TestFormatToolFeedbackMessage_EmptyToolNameOmitsToolLine(t *testing.T) { + got := FormatToolFeedbackMessage("", "Continue drafting the final response.") + want := "Continue drafting the final response." + if got != want { + t.Fatalf("FormatToolFeedbackMessage() = %q, want %q", got, want) + } +} + +func TestFitToolFeedbackMessage_TruncatesBodyWithinSingleMessage(t *testing.T) { + got := FitToolFeedbackMessage( + "\U0001f527 `read_file`\nRead README.md first to confirm the current project structure.", + 40, + ) + want := "\U0001f527 `read_file`\nRead README.md first to..." + if got != want { + t.Fatalf("FitToolFeedbackMessage() = %q, want %q", got, want) + } +} + +func TestFitToolFeedbackMessage_TruncatesSingleLineMessage(t *testing.T) { + got := FitToolFeedbackMessage("\U0001f527 `read_file`", 10) + want := "\U0001f527 `read..." + if got != want { + t.Fatalf("FitToolFeedbackMessage() = %q, want %q", got, want) + } +} diff --git a/web/backend/api/pico.go b/web/backend/api/pico.go index 439a41a1c..8eeff4041 100644 --- a/web/backend/api/pico.go +++ b/web/backend/api/pico.go @@ -233,6 +233,10 @@ func (h *Handler) handleRegenPicoToken(w http.ResponseWriter, r *http.Request) { return } + gateway.mu.Lock() + gateway.picoToken = token + gateway.mu.Unlock() + h.writePicoInfoResponse(w, r, cfg, nil) } diff --git a/web/backend/api/pico_test.go b/web/backend/api/pico_test.go index 77fe1039b..6f7cefd4d 100644 --- a/web/backend/api/pico_test.go +++ b/web/backend/api/pico_test.go @@ -393,6 +393,58 @@ func TestHandleGetPicoInfo_OmitsToken(t *testing.T) { } } +func TestHandleRegenPicoToken_RefreshesGatewayTokenCache(t *testing.T) { + configPath := filepath.Join(t.TempDir(), "config.json") + h := NewHandler(configPath) + + if _, err := h.EnsurePicoChannel(); err != nil { + t.Fatalf("EnsurePicoChannel() error = %v", err) + } + + origPicoToken := gateway.picoToken + t.Cleanup(func() { + gateway.mu.Lock() + gateway.picoToken = origPicoToken + gateway.mu.Unlock() + }) + + gateway.mu.Lock() + gateway.picoToken = "stale-token" + gateway.mu.Unlock() + + req := httptest.NewRequest(http.MethodPost, "http://launcher.local/api/pico/token", nil) + rec := httptest.NewRecorder() + h.handleRegenPicoToken(rec, req) + + if rec.Code != http.StatusOK { + t.Fatalf("status = %d, want %d", rec.Code, http.StatusOK) + } + + cfg, err := config.LoadConfig(configPath) + if err != nil { + t.Fatalf("LoadConfig() error = %v", err) + } + + bc := cfg.Channels["pico"] + decoded, err := bc.GetDecoded() + if err != nil { + t.Fatalf("GetDecoded() error = %v", err) + } + token := decoded.(*config.PicoSettings).Token.String() + if token == "" { + t.Fatal("expected regenerated pico token to be persisted") + } + if token == "stale-token" { + t.Fatal("expected regenerated pico token to differ from stale cache") + } + + gateway.mu.Lock() + defer gateway.mu.Unlock() + if gateway.picoToken != token { + t.Fatalf("gateway.picoToken = %q, want %q", gateway.picoToken, token) + } +} + func TestHandleWebSocketProxyReloadsGatewayTargetFromConfig(t *testing.T) { origMatcher := gatewayProcessMatcher gatewayProcessMatcher = func(int) (bool, bool) { return true, true } diff --git a/web/backend/api/session.go b/web/backend/api/session.go index 0483b57cc..6ac1eb988 100644 --- a/web/backend/api/session.go +++ b/web/backend/api/session.go @@ -510,6 +510,16 @@ func visibleSessionMessages(messages []providers.Message, toolFeedbackMaxArgsLen transcript = append(transcript, visibleToolMessages...) } + // When assistant content exactly matches the rendered tool summary or + // tool-delivered message, skip it to avoid duplicates. Distinct content + // must remain visible in restored session history. + if len(msg.ToolCalls) > 0 && + len(msg.Media) == 0 && + len(attachments) == 0 && + assistantToolCallContentDuplicated(msg.Content, toolSummaryMessages, visibleToolMessages) { + continue + } + // Pico web chat can persist both visible `message` tool output and a // later plain assistant reply in the same turn. Hide only the fixed // internal summary that marks handled tool delivery. @@ -549,6 +559,43 @@ func filterSessionChatMessages(messages []sessionChatMessage) []sessionChatMessa return filtered } +func assistantToolCallContentDuplicated( + content string, + toolSummaryMessages []sessionChatMessage, + visibleToolMessages []sessionChatMessage, +) bool { + content = strings.TrimSpace(content) + if content == "" { + return false + } + + for _, msg := range toolSummaryMessages { + if toolSummaryContainsContent(msg.Content, content) { + return true + } + } + for _, msg := range visibleToolMessages { + if strings.TrimSpace(msg.Content) == content { + return true + } + } + return false +} + +func toolSummaryContainsContent(summary, content string) bool { + summary = strings.TrimSpace(summary) + content = strings.TrimSpace(content) + if summary == "" || content == "" { + return false + } + if summary == content { + return true + } + + _, body, hasBody := strings.Cut(summary, "\n") + return hasBody && strings.TrimSpace(body) == content +} + func sessionAttachments(msg providers.Message) []sessionChatAttachment { if len(msg.Attachments) == 0 { return nil @@ -663,20 +710,41 @@ func visibleAssistantToolSummaryMessages( } } - argsPreview := strings.TrimSpace(argsJSON) - if argsPreview == "" { - argsPreview = "{}" - } - messages = append(messages, sessionChatMessage{ - Role: "assistant", - Content: utils.FormatToolFeedbackMessage(name, utils.Truncate(argsPreview, toolFeedbackMaxArgsLength)), + Role: "assistant", + Content: utils.FormatToolFeedbackMessage( + name, + visibleAssistantToolSummaryText(tc, toolFeedbackMaxArgsLength), + ), }) } return messages } +func visibleAssistantToolSummaryText( + tc providers.ToolCall, + toolFeedbackMaxArgsLength int, +) string { + if tc.ExtraContent != nil { + if explanation := strings.TrimSpace(tc.ExtraContent.ToolFeedbackExplanation); explanation != "" { + return utils.Truncate(explanation, toolFeedbackMaxArgsLength) + } + } + + argsJSON := "" + if tc.Function != nil { + argsJSON = tc.Function.Arguments + } + if strings.TrimSpace(argsJSON) == "" && len(tc.Arguments) > 0 { + if encodedArgs, err := json.Marshal(tc.Arguments); err == nil { + argsJSON = string(encodedArgs) + } + } + + return utils.Truncate(strings.TrimSpace(argsJSON), toolFeedbackMaxArgsLength) +} + func visibleAssistantToolMessages(toolCalls []providers.ToolCall) []sessionChatMessage { if len(toolCalls) == 0 { return nil diff --git a/web/backend/api/session_test.go b/web/backend/api/session_test.go index d2efb3879..6afb8a94f 100644 --- a/web/backend/api/session_test.go +++ b/web/backend/api/session_test.go @@ -675,7 +675,7 @@ func TestHandleListSessions_MessageCountUsesVisibleTranscript(t *testing.T) { } } -func TestHandleGetSession_PreservesToolSummaryAndAssistantContent(t *testing.T) { +func TestHandleGetSession_DoesNotDuplicateAssistantToolCallContent(t *testing.T) { configPath, cleanup := setupOAuthTestEnv(t) defer cleanup() @@ -690,7 +690,7 @@ func TestHandleGetSession_PreservesToolSummaryAndAssistantContent(t *testing.T) {Role: "user", Content: "check file"}, { Role: "assistant", - Content: "model final reply", + Content: "Read the file before replying.", ToolCalls: []providers.ToolCall{ { ID: "call_1", @@ -699,6 +699,9 @@ func TestHandleGetSession_PreservesToolSummaryAndAssistantContent(t *testing.T) Name: "read_file", Arguments: `{"path":"README.md","start_line":1,"end_line":10}`, }, + ExtraContent: &providers.ExtraContent{ + ToolFeedbackExplanation: "Read the file before replying.", + }, }, }, }, @@ -730,8 +733,8 @@ func TestHandleGetSession_PreservesToolSummaryAndAssistantContent(t *testing.T) if err := json.Unmarshal(rec.Body.Bytes(), &resp); err != nil { t.Fatalf("Unmarshal() error = %v", err) } - if len(resp.Messages) != 3 { - t.Fatalf("len(resp.Messages) = %d, want 3", len(resp.Messages)) + if len(resp.Messages) != 2 { + t.Fatalf("len(resp.Messages) = %d, want 2", len(resp.Messages)) } if resp.Messages[0].Role != "user" || resp.Messages[0].Content != "check file" { t.Fatalf("first message = %#v, want user/check file", resp.Messages[0]) @@ -739,8 +742,153 @@ func TestHandleGetSession_PreservesToolSummaryAndAssistantContent(t *testing.T) if !strings.Contains(resp.Messages[1].Content, "`read_file`") { t.Fatalf("tool summary message = %#v, want read_file summary", resp.Messages[1]) } - if resp.Messages[2].Role != "assistant" || resp.Messages[2].Content != "model final reply" { - t.Fatalf("assistant message = %#v, want model final reply", resp.Messages[2]) + if !strings.Contains(resp.Messages[1].Content, "Read the file before replying.") { + t.Fatalf("tool summary message = %#v, want tool explanation", resp.Messages[1]) + } +} + +func TestHandleGetSession_PreservesDistinctAssistantToolCallContent(t *testing.T) { + configPath, cleanup := setupOAuthTestEnv(t) + defer cleanup() + + dir := sessionsTestDir(t, configPath) + store, err := memory.NewJSONLStore(dir) + if err != nil { + t.Fatalf("NewJSONLStore() error = %v", err) + } + + sessionKey := picoSessionPrefix + "detail-tool-summary-distinct-content" + for _, msg := range []providers.Message{ + {Role: "user", Content: "check file"}, + { + Role: "assistant", + Content: "I will summarize the findings after reading the file.", + ToolCalls: []providers.ToolCall{ + { + ID: "call_1", + Type: "function", + Function: &providers.FunctionCall{ + Name: "read_file", + Arguments: `{"path":"README.md","start_line":1,"end_line":10}`, + }, + ExtraContent: &providers.ExtraContent{ + ToolFeedbackExplanation: "Read the file before replying.", + }, + }, + }, + }, + } { + if err := store.AddFullMessage(nil, sessionKey, msg); err != nil { + t.Fatalf("AddFullMessage() error = %v", err) + } + } + + h := NewHandler(configPath) + mux := http.NewServeMux() + h.RegisterRoutes(mux) + + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/api/sessions/detail-tool-summary-distinct-content", nil) + mux.ServeHTTP(rec, req) + + if rec.Code != http.StatusOK { + t.Fatalf("status = %d, want %d, body=%s", rec.Code, http.StatusOK, rec.Body.String()) + } + + var resp struct { + Messages []struct { + Role string `json:"role"` + Content string `json:"content"` + } `json:"messages"` + } + if err := json.Unmarshal(rec.Body.Bytes(), &resp); err != nil { + t.Fatalf("Unmarshal() error = %v", err) + } + if len(resp.Messages) != 3 { + t.Fatalf("len(resp.Messages) = %d, want 3", len(resp.Messages)) + } + if !strings.Contains(resp.Messages[1].Content, "`read_file`") { + t.Fatalf("tool summary message = %#v, want read_file summary", resp.Messages[1]) + } + if resp.Messages[2].Role != "assistant" || + resp.Messages[2].Content != "I will summarize the findings after reading the file." { + t.Fatalf("assistant content = %#v, want preserved distinct content", resp.Messages[2]) + } +} + +func TestHandleGetSession_PreservesMediaWhenAssistantToolCallContentDuplicatesSummary(t *testing.T) { + configPath, cleanup := setupOAuthTestEnv(t) + defer cleanup() + + dir := sessionsTestDir(t, configPath) + store, err := memory.NewJSONLStore(dir) + if err != nil { + t.Fatalf("NewJSONLStore() error = %v", err) + } + + sessionKey := picoSessionPrefix + "detail-tool-summary-duplicate-content-with-media" + for _, msg := range []providers.Message{ + {Role: "user", Content: "check screenshot"}, + { + Role: "assistant", + Content: "Reviewing the generated screenshot.", + Media: []string{"data:image/png;base64,abc123"}, + ToolCalls: []providers.ToolCall{ + { + ID: "call_1", + Type: "function", + Function: &providers.FunctionCall{ + Name: "view_image", + Arguments: `{"path":"artifact.png"}`, + }, + ExtraContent: &providers.ExtraContent{ + ToolFeedbackExplanation: "Reviewing the generated screenshot.", + }, + }, + }, + }, + } { + if err := store.AddFullMessage(nil, sessionKey, msg); err != nil { + t.Fatalf("AddFullMessage() error = %v", err) + } + } + + h := NewHandler(configPath) + mux := http.NewServeMux() + h.RegisterRoutes(mux) + + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/api/sessions/detail-tool-summary-duplicate-content-with-media", nil) + mux.ServeHTTP(rec, req) + + if rec.Code != http.StatusOK { + t.Fatalf("status = %d, want %d, body=%s", rec.Code, http.StatusOK, rec.Body.String()) + } + + var resp struct { + Messages []struct { + Role string `json:"role"` + Content string `json:"content"` + Media []string `json:"media"` + } `json:"messages"` + } + if err := json.Unmarshal(rec.Body.Bytes(), &resp); err != nil { + t.Fatalf("Unmarshal() error = %v", err) + } + if len(resp.Messages) != 3 { + t.Fatalf("len(resp.Messages) = %d, want 3", len(resp.Messages)) + } + if !strings.Contains(resp.Messages[1].Content, "`view_image`") { + t.Fatalf("tool summary message = %#v, want view_image summary", resp.Messages[1]) + } + if resp.Messages[2].Role != "assistant" { + t.Fatalf("assistant message role = %q, want assistant", resp.Messages[2].Role) + } + if resp.Messages[2].Content != "Reviewing the generated screenshot." { + t.Fatalf("assistant content = %q, want preserved duplicated content with media", resp.Messages[2].Content) + } + if len(resp.Messages[2].Media) != 1 || resp.Messages[2].Media[0] != "data:image/png;base64,abc123" { + t.Fatalf("assistant media = %#v, want preserved media", resp.Messages[2].Media) } for _, msg := range resp.Messages { if msg.Role == "tool" || strings.Contains(msg.Content, "raw read_file result") { @@ -749,6 +897,90 @@ func TestHandleGetSession_PreservesToolSummaryAndAssistantContent(t *testing.T) } } +func TestHandleGetSession_PreservesAttachmentsWhenAssistantToolCallContentDuplicatesSummary(t *testing.T) { + configPath, cleanup := setupOAuthTestEnv(t) + defer cleanup() + + dir := sessionsTestDir(t, configPath) + store, err := memory.NewJSONLStore(dir) + if err != nil { + t.Fatalf("NewJSONLStore() error = %v", err) + } + + sessionKey := picoSessionPrefix + "detail-tool-summary-duplicate-content-with-attachments" + for _, msg := range []providers.Message{ + {Role: "user", Content: "check report"}, + { + Role: "assistant", + Content: "Reviewing the generated report.", + Attachments: []providers.Attachment{{ + Type: "file", + URL: "https://example.com/report.txt", + Filename: "report.txt", + ContentType: "text/plain", + }}, + ToolCalls: []providers.ToolCall{ + { + ID: "call_1", + Type: "function", + Function: &providers.FunctionCall{ + Name: "read_file", + Arguments: `{"path":"report.txt"}`, + }, + ExtraContent: &providers.ExtraContent{ + ToolFeedbackExplanation: "Reviewing the generated report.", + }, + }, + }, + }, + } { + if err := store.AddFullMessage(nil, sessionKey, msg); err != nil { + t.Fatalf("AddFullMessage() error = %v", err) + } + } + + h := NewHandler(configPath) + mux := http.NewServeMux() + h.RegisterRoutes(mux) + + rec := httptest.NewRecorder() + req := httptest.NewRequest( + http.MethodGet, + "/api/sessions/detail-tool-summary-duplicate-content-with-attachments", + nil, + ) + mux.ServeHTTP(rec, req) + + if rec.Code != http.StatusOK { + t.Fatalf("status = %d, want %d, body=%s", rec.Code, http.StatusOK, rec.Body.String()) + } + + var resp struct { + Messages []sessionChatMessage `json:"messages"` + } + if err := json.Unmarshal(rec.Body.Bytes(), &resp); err != nil { + t.Fatalf("Unmarshal() error = %v", err) + } + if len(resp.Messages) != 3 { + t.Fatalf("len(resp.Messages) = %d, want 3", len(resp.Messages)) + } + if !strings.Contains(resp.Messages[1].Content, "`read_file`") { + t.Fatalf("tool summary message = %#v, want read_file summary", resp.Messages[1]) + } + if resp.Messages[2].Role != "assistant" { + t.Fatalf("assistant message role = %q, want assistant", resp.Messages[2].Role) + } + if resp.Messages[2].Content != "Reviewing the generated report." { + t.Fatalf("assistant content = %q, want preserved duplicated content", resp.Messages[2].Content) + } + if len(resp.Messages[2].Attachments) != 1 { + t.Fatalf("len(assistant.Attachments) = %d, want 1", len(resp.Messages[2].Attachments)) + } + if resp.Messages[2].Attachments[0].URL != "https://example.com/report.txt" { + t.Fatalf("attachment url = %q, want report URL", resp.Messages[2].Attachments[0].URL) + } +} + func TestHandleGetSession_UsesConfiguredToolFeedbackMaxArgsLength(t *testing.T) { configPath, cleanup := setupOAuthTestEnv(t) defer cleanup() @@ -770,6 +1002,7 @@ func TestHandleGetSession_UsesConfiguredToolFeedbackMaxArgsLength(t *testing.T) } argsJSON := `{"path":"README.md","start_line":1,"end_line":10,"extra":"abcdefghijklmnopqrstuvwxyz"}` + explanation := "Read README.md first to confirm the current project structure before editing the config example." sessionKey := picoSessionPrefix + "detail-tool-summary-max-args" err = store.AddFullMessage(nil, sessionKey, providers.Message{Role: "user", Content: "check file"}) if err != nil { @@ -784,6 +1017,9 @@ func TestHandleGetSession_UsesConfiguredToolFeedbackMaxArgsLength(t *testing.T) Name: "read_file", Arguments: argsJSON, }, + ExtraContent: &providers.ExtraContent{ + ToolFeedbackExplanation: explanation, + }, }}, }) if err != nil { @@ -816,13 +1052,93 @@ func TestHandleGetSession_UsesConfiguredToolFeedbackMaxArgsLength(t *testing.T) t.Fatalf("len(resp.Messages) = %d, want at least 2", len(resp.Messages)) } - wantPreview := utils.Truncate(argsJSON, 20) + wantPreview := utils.Truncate(explanation, 20) if !strings.Contains(resp.Messages[1].Content, wantPreview) { t.Fatalf("tool summary = %q, want preview %q", resp.Messages[1].Content, wantPreview) } if strings.Contains(resp.Messages[1].Content, argsJSON) { t.Fatalf("tool summary = %q, expected configured truncation", resp.Messages[1].Content) } + if !strings.Contains(resp.Messages[1].Content, "`read_file`") { + t.Fatalf("tool summary = %q, want read_file summary", resp.Messages[1].Content) + } +} + +func TestHandleGetSession_FallsBackToLegacyToolArgumentsWhenExplanationMissing(t *testing.T) { + configPath, cleanup := setupOAuthTestEnv(t) + defer cleanup() + + cfg, err := config.LoadConfig(configPath) + if err != nil { + t.Fatalf("LoadConfig() error = %v", err) + } + cfg.Agents.Defaults.ToolFeedback.MaxArgsLength = 20 + err = config.SaveConfig(configPath, cfg) + if err != nil { + t.Fatalf("SaveConfig() error = %v", err) + } + + dir := sessionsTestDir(t, configPath) + store, err := memory.NewJSONLStore(dir) + if err != nil { + t.Fatalf("NewJSONLStore() error = %v", err) + } + + argsJSON := `{"path":"README.md","start_line":1,"end_line":10,"extra":"abcdefghijklmnopqrstuvwxyz"}` + sessionKey := picoSessionPrefix + "detail-tool-summary-legacy-args" + if err := store.AddFullMessage( + nil, + sessionKey, + providers.Message{Role: "user", Content: "check file"}, + ); err != nil { + t.Fatalf("AddFullMessage(user) error = %v", err) + } + if err := store.AddFullMessage(nil, sessionKey, providers.Message{ + Role: "assistant", + ToolCalls: []providers.ToolCall{{ + ID: "call_1", + Type: "function", + Function: &providers.FunctionCall{ + Name: "read_file", + Arguments: argsJSON, + }, + }}, + }); err != nil { + t.Fatalf("AddFullMessage(assistant) error = %v", err) + } + + h := NewHandler(configPath) + mux := http.NewServeMux() + h.RegisterRoutes(mux) + + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/api/sessions/detail-tool-summary-legacy-args", nil) + mux.ServeHTTP(rec, req) + + if rec.Code != http.StatusOK { + t.Fatalf("status = %d, want %d, body=%s", rec.Code, http.StatusOK, rec.Body.String()) + } + + var resp struct { + Messages []struct { + Role string `json:"role"` + Content string `json:"content"` + } `json:"messages"` + } + if err := json.Unmarshal(rec.Body.Bytes(), &resp); err != nil { + t.Fatalf("Unmarshal() error = %v", err) + } + if len(resp.Messages) < 2 { + t.Fatalf("len(resp.Messages) = %d, want at least 2", len(resp.Messages)) + } + + wantPreview := utils.Truncate(argsJSON, 20) + if !strings.Contains(resp.Messages[1].Content, "`read_file`") { + t.Fatalf("tool summary = %q, want read_file summary", resp.Messages[1].Content) + } + if !strings.Contains(resp.Messages[1].Content, wantPreview) { + t.Fatalf("tool summary = %q, want legacy args preview %q", resp.Messages[1].Content, wantPreview) + } } func TestHandleGetSession_IncludesMediaOnlyMessages(t *testing.T) { diff --git a/web/frontend/src/components/chat/assistant-message.tsx b/web/frontend/src/components/chat/assistant-message.tsx index 4dfc261c2..c09f5a06d 100644 --- a/web/frontend/src/components/chat/assistant-message.tsx +++ b/web/frontend/src/components/chat/assistant-message.tsx @@ -100,8 +100,8 @@ export function AssistantMessage({ className={cn( "prose dark:prose-invert prose-pre:my-2 prose-pre:overflow-x-auto prose-pre:rounded-lg prose-pre:border prose-pre:bg-zinc-100 prose-pre:p-0 prose-pre:text-zinc-900 dark:prose-pre:bg-zinc-950 dark:prose-pre:text-zinc-100 max-w-none [overflow-wrap:anywhere] break-words", isThought - ? "prose-p:my-1.5 px-3 pt-0 pb-3 text-[13px] leading-relaxed opacity-70" - : "prose-p:my-2 p-4 text-[15px] leading-relaxed", + ? "prose-p:my-1.5 prose-p:whitespace-pre-wrap px-3 pt-0 pb-3 text-[13px] leading-relaxed opacity-70" + : "prose-p:my-2 prose-p:whitespace-pre-wrap p-4 text-[15px] leading-relaxed", )} > = 0; i -= 1) { + if (messages[i].role === "user") { + lastUserIndex = i + break + } + } + + for (let i = messages.length - 1; i >= 0; i -= 1) { + if (i <= lastUserIndex) { + break + } + if (isToolFeedbackMessage(messages[i])) { + return i + } + } + return -1 +} + export function handlePicoMessage( message: PicoMessage, expectedSessionId: string, @@ -138,21 +168,88 @@ export function handlePicoMessage( const hasKind = hasAssistantKindPayload(payload) const kind = parseAssistantMessageKind(payload) const attachments = parseAttachments(payload) + const contextUsage = parseContextUsage(payload) + const timestamp = + message.timestamp !== undefined && + Number.isFinite(Number(message.timestamp)) + ? normalizeUnixTimestamp(Number(message.timestamp)) + : Date.now() if (!messageId) { break } updateChatStore((prev) => ({ - messages: prev.messages.map((msg) => - msg.id === messageId - ? { - ...msg, - content, - ...(hasKind ? { kind } : {}), - ...(attachments ? { attachments } : {}), - } - : msg, - ), + messages: (() => { + let found = false + const messages = prev.messages.map((msg) => { + if (msg.id !== messageId) { + return msg + } + found = true + return { + ...msg, + id: messageId, + content, + ...(hasKind ? { kind } : {}), + ...(attachments ? { attachments } : {}), + } + }) + if (found) { + return messages + } + + const fallbackIndex = findToolFeedbackMessageIndex(messages) + if (fallbackIndex >= 0) { + return messages.map((msg, index) => + index === fallbackIndex + ? { + ...msg, + id: messageId, + content, + ...(hasKind ? { kind } : {}), + ...(attachments ? { attachments } : {}), + } + : msg, + ) + } + + return [ + ...messages, + { + id: messageId, + role: "assistant" as const, + content, + ...(hasKind ? { kind } : {}), + ...(attachments ? { attachments } : {}), + timestamp, + }, + ] + })(), + ...(contextUsage ? { contextUsage } : {}), + })) + break + } + + case "message.delete": { + const messageId = payload.message_id as string + if (!messageId) { + break + } + + updateChatStore((prev) => ({ + messages: (() => { + const exactMessages = prev.messages.filter((msg) => msg.id !== messageId) + if (exactMessages.length !== prev.messages.length) { + return exactMessages + } + + const fallbackIndex = findToolFeedbackMessageIndex(prev.messages) + if (fallbackIndex < 0) { + return prev.messages + } + + return prev.messages.filter((_, index) => index !== fallbackIndex) + })(), })) break } diff --git a/web/frontend/src/i18n/locales/en.json b/web/frontend/src/i18n/locales/en.json index 7ded188c1..cf8d91f0c 100644 --- a/web/frontend/src/i18n/locales/en.json +++ b/web/frontend/src/i18n/locales/en.json @@ -605,9 +605,9 @@ "split_on_marker": "Chatty Mode", "split_on_marker_hint": "Split long messages into short ones like real human chatting.", "tool_feedback_enabled": "Tool Feedback", - "tool_feedback_enabled_hint": "Send a short tool-call preview into the current chat before each tool execution.", - "tool_feedback_max_args_length": "Tool Feedback Args Preview Length", - "tool_feedback_max_args_length_hint": "Maximum number of argument characters shown in each tool feedback message. Set to 0 to use the default.", + "tool_feedback_enabled_hint": "Send a short execution note into the current chat before each tool runs.", + "tool_feedback_max_args_length": "Tool Feedback Length", + "tool_feedback_max_args_length_hint": "Maximum number of characters shown in each tool feedback message. Set to 0 to use the default.", "exec_enabled": "Allow Commands", "exec_enabled_hint": "Enable or disable command execution for the app. When disabled, no command requests will run.", "allow_remote": "Allow Remote Commands", diff --git a/web/frontend/src/i18n/locales/zh.json b/web/frontend/src/i18n/locales/zh.json index ca71d7ef8..1d3c571c3 100644 --- a/web/frontend/src/i18n/locales/zh.json +++ b/web/frontend/src/i18n/locales/zh.json @@ -605,9 +605,9 @@ "split_on_marker": "连续短消息", "split_on_marker_hint": "像真人聊天一样,把长难句拆成多条短消息快速发出", "tool_feedback_enabled": "工具反馈", - "tool_feedback_enabled_hint": "在每次执行工具前,先向当前会话发送一条简短的工具调用预览", - "tool_feedback_max_args_length": "工具反馈参数预览长度", - "tool_feedback_max_args_length_hint": "每条工具反馈消息中展示的参数字符上限。设为 0 时使用默认值", + "tool_feedback_enabled_hint": "在每次执行工具前,先向当前会话发送一条简短的执行说明", + "tool_feedback_max_args_length": "工具反馈长度", + "tool_feedback_max_args_length_hint": "每条工具反馈消息中展示的字符上限。设为 0 时使用默认值", "exec_enabled": "允许命令执行", "exec_enabled_hint": "控制应用是否允许执行命令。关闭后,所有命令请求都不会执行", "allow_remote": "允许远程命令执行",