mirror of
https://github.com/sipeed/picoclaw.git
synced 2026-05-25 16:00:35 +00:00
This reverts commit e556a816e4.
This commit is contained in:
@@ -3,7 +3,6 @@ package auth
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
@@ -20,19 +19,6 @@ import (
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
)
|
||||
|
||||
func newIPv4TestServer(t *testing.T, handler http.Handler) *httptest.Server {
|
||||
t.Helper()
|
||||
|
||||
server := httptest.NewUnstartedServer(handler)
|
||||
listener, err := net.Listen("tcp4", "127.0.0.1:0")
|
||||
require.NoError(t, err)
|
||||
|
||||
server.Listener = listener
|
||||
server.Start()
|
||||
t.Cleanup(server.Close)
|
||||
return server
|
||||
}
|
||||
|
||||
func TestNewWeComCommand(t *testing.T) {
|
||||
cmd := newWeComCommand()
|
||||
|
||||
@@ -67,7 +53,7 @@ func TestBuildWeComQRCodePageURL(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestFetchWeComQRCode(t *testing.T) {
|
||||
server := newIPv4TestServer(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
assert.Equal(t, "/generate", r.URL.Path)
|
||||
assert.Equal(t, wecomQRSourceID, r.URL.Query().Get("source"))
|
||||
assert.Equal(t, wecomQRSourceID, r.URL.Query().Get("sourceID"))
|
||||
@@ -75,6 +61,7 @@ func TestFetchWeComQRCode(t *testing.T) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = w.Write([]byte(`{"data":{"scode":"scode-1","auth_url":"https://example.com/qr"}}`))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
opts := normalizeWeComQRFlowOptions(wecomQRFlowOptions{
|
||||
HTTPClient: server.Client(),
|
||||
@@ -91,7 +78,7 @@ func TestFetchWeComQRCode(t *testing.T) {
|
||||
func TestPollWeComQRCodeResult(t *testing.T) {
|
||||
var calls atomic.Int32
|
||||
|
||||
server := newIPv4TestServer(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
call := calls.Add(1)
|
||||
assert.Equal(t, "/query", r.URL.Path)
|
||||
assert.Equal(t, "scode-1", r.URL.Query().Get("scode"))
|
||||
@@ -105,6 +92,7 @@ func TestPollWeComQRCodeResult(t *testing.T) {
|
||||
_, _ = w.Write([]byte(`{"data":{"status":"success","bot_info":{"botid":"bot-1","secret":"secret-1"}}}`))
|
||||
}
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
var output bytes.Buffer
|
||||
opts := normalizeWeComQRFlowOptions(wecomQRFlowOptions{
|
||||
|
||||
@@ -8,56 +8,26 @@ Discord is a free voice, video, and text chat application designed for communiti
|
||||
|
||||
```json
|
||||
{
|
||||
"agents": {
|
||||
"defaults": {
|
||||
"tool_feedback": {
|
||||
"enabled": true,
|
||||
"max_args_length": 300
|
||||
}
|
||||
}
|
||||
},
|
||||
"channel_list": {
|
||||
"discord": {
|
||||
"enabled": true,
|
||||
"type": "discord",
|
||||
"token": "YOUR_BOT_TOKEN",
|
||||
"allow_from": ["YOUR_USER_ID"],
|
||||
"placeholder": {
|
||||
"enabled": true,
|
||||
"text": ["Thinking... 💭"]
|
||||
},
|
||||
"group_trigger": {
|
||||
"mention_only": false
|
||||
},
|
||||
"reasoning_channel_id": ""
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
| Field | Type | Required | Description |
|
||||
| -------------------- | ------ | -------- | --------------------------------------------------------------------------- |
|
||||
| enabled | bool | Yes | Whether to enable the Discord channel |
|
||||
| token | string | Yes | Discord Bot Token |
|
||||
| allow_from | array | No | Allowlist of user IDs; empty means all users are allowed |
|
||||
| placeholder | object | No | Placeholder message config shown while the agent is working |
|
||||
| group_trigger | object | No | Group trigger settings (example: { "mention_only": false }) |
|
||||
| reasoning_channel_id | string | No | Optional target channel ID for reasoning/thinking output |
|
||||
|
||||
## Visible Execution Feedback
|
||||
|
||||
Discord can show three different kinds of "working" feedback:
|
||||
|
||||
1. Typing indicator: automatic, no extra config needed.
|
||||
2. Placeholder message: enable `channel_list.discord.placeholder.enabled` to send a visible `Thinking...` message that is later edited into the final reply.
|
||||
3. Tool execution feedback: enable `agents.defaults.tool_feedback.enabled` to send a short message before each tool call, for example:
|
||||
|
||||
```text
|
||||
🔧 `web_search`
|
||||
Checking the latest PicoClaw release notes before I answer.
|
||||
```
|
||||
|
||||
If you only see `Bot is typing`, check that `placeholder.enabled` or `tool_feedback.enabled` is actually set in your runtime config.
|
||||
| Field | Type | Required | Description |
|
||||
| ------------- | ------ | -------- | --------------------------------------------------------------------------- |
|
||||
| enabled | bool | Yes | Whether to enable the Discord channel |
|
||||
| token | string | Yes | Discord Bot Token |
|
||||
| allow_from | array | No | Allowlist of user IDs; empty means all users are allowed |
|
||||
| group_trigger | object | No | Group trigger settings (example: { "mention_only": false }) |
|
||||
|
||||
## Setup
|
||||
|
||||
|
||||
@@ -112,7 +112,6 @@ const (
|
||||
pendingTurnPrefix = "pending-"
|
||||
metadataKeyMessageKind = "message_kind"
|
||||
messageKindThought = "thought"
|
||||
messageKindToolFeedback = "tool_feedback"
|
||||
metadataKeyAccountID = "account_id"
|
||||
metadataKeyGuildID = "guild_id"
|
||||
metadataKeyTeamID = "team_id"
|
||||
|
||||
+1
-365
@@ -24,7 +24,6 @@ import (
|
||||
"github.com/sipeed/picoclaw/pkg/routing"
|
||||
"github.com/sipeed/picoclaw/pkg/session"
|
||||
"github.com/sipeed/picoclaw/pkg/tools"
|
||||
"github.com/sipeed/picoclaw/pkg/utils"
|
||||
)
|
||||
|
||||
type fakeChannel struct{ id string }
|
||||
@@ -1759,157 +1758,6 @@ func (m *toolFeedbackProvider) GetDefaultModel() string {
|
||||
return "heartbeat-tool-feedback-model"
|
||||
}
|
||||
|
||||
type toolFeedbackReasoningProvider struct {
|
||||
filePath string
|
||||
calls int
|
||||
}
|
||||
|
||||
func (m *toolFeedbackReasoningProvider) Chat(
|
||||
ctx context.Context,
|
||||
messages []providers.Message,
|
||||
tools []providers.ToolDefinition,
|
||||
model string,
|
||||
opts map[string]any,
|
||||
) (*providers.LLMResponse, error) {
|
||||
m.calls++
|
||||
if m.calls == 1 {
|
||||
return &providers.LLMResponse{
|
||||
ReasoningContent: "Read README.md first to confirm the context that needs to be changed.",
|
||||
ToolCalls: []providers.ToolCall{{
|
||||
ID: "call_reasoning_read_file",
|
||||
Type: "function",
|
||||
Name: "read_file",
|
||||
Arguments: map[string]any{"path": m.filePath},
|
||||
}},
|
||||
}, nil
|
||||
}
|
||||
|
||||
return &providers.LLMResponse{
|
||||
Content: "DONE",
|
||||
ToolCalls: []providers.ToolCall{},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (m *toolFeedbackReasoningProvider) GetDefaultModel() string {
|
||||
return "tool-feedback-reasoning-model"
|
||||
}
|
||||
|
||||
func TestToolFeedbackExplanationFromResponse_UsesCurrentContentFirst(t *testing.T) {
|
||||
response := &providers.LLMResponse{
|
||||
Content: "Read README.md first",
|
||||
ReasoningContent: "current reasoning fallback",
|
||||
}
|
||||
messages := []providers.Message{
|
||||
{Role: "user", Content: "check file"},
|
||||
{Role: "assistant", Content: "Previous turn explanation"},
|
||||
{Role: "tool", Content: "tool output", ToolCallID: "call_1"},
|
||||
}
|
||||
|
||||
got := toolFeedbackExplanationFromResponse(response, messages, 300)
|
||||
if got != "Read README.md first" {
|
||||
t.Fatalf("toolFeedbackExplanationFromResponse() = %q, want current content", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestToolFeedbackExplanationFromResponse_UsesExplicitToolCallExtraContent(t *testing.T) {
|
||||
response := &providers.LLMResponse{
|
||||
ToolCalls: []providers.ToolCall{{
|
||||
ID: "call_1",
|
||||
Name: "read_file",
|
||||
ExtraContent: &providers.ExtraContent{
|
||||
ToolFeedbackExplanation: "Read README.md first to confirm the current project structure.",
|
||||
},
|
||||
}},
|
||||
}
|
||||
messages := []providers.Message{
|
||||
{Role: "user", Content: "check file"},
|
||||
{Role: "assistant", Content: ""},
|
||||
{Role: "tool", Content: "tool output", ToolCallID: "call_1"},
|
||||
}
|
||||
|
||||
got := toolFeedbackExplanationFromResponse(response, messages, 300)
|
||||
if got != "Read README.md first to confirm the current project structure." {
|
||||
t.Fatalf("toolFeedbackExplanationFromResponse() = %q, want explicit tool feedback explanation", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestToolFeedbackExplanationForToolCall_PrefersToolSpecificExtraContent(t *testing.T) {
|
||||
response := &providers.LLMResponse{
|
||||
Content: "Shared explanation",
|
||||
ToolCalls: []providers.ToolCall{
|
||||
{
|
||||
ID: "call_1",
|
||||
Name: "read_file",
|
||||
ExtraContent: &providers.ExtraContent{
|
||||
ToolFeedbackExplanation: "Read README.md first.",
|
||||
},
|
||||
},
|
||||
{
|
||||
ID: "call_2",
|
||||
Name: "edit_file",
|
||||
ExtraContent: &providers.ExtraContent{
|
||||
ToolFeedbackExplanation: "Update config example after reading it.",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
got1 := toolFeedbackExplanationForToolCall(response, response.ToolCalls[0], nil, 300)
|
||||
got2 := toolFeedbackExplanationForToolCall(response, response.ToolCalls[1], nil, 300)
|
||||
if got1 != "Read README.md first." {
|
||||
t.Fatalf("toolFeedbackExplanationForToolCall() first = %q, want tool-specific explanation", got1)
|
||||
}
|
||||
if got2 != "Update config example after reading it." {
|
||||
t.Fatalf("toolFeedbackExplanationForToolCall() second = %q, want tool-specific explanation", got2)
|
||||
}
|
||||
}
|
||||
|
||||
func TestToolFeedbackExplanationForToolCall_DoesNotReuseAnotherToolCallExplanation(t *testing.T) {
|
||||
response := &providers.LLMResponse{
|
||||
ToolCalls: []providers.ToolCall{
|
||||
{
|
||||
ID: "call_1",
|
||||
Name: "read_file",
|
||||
},
|
||||
{
|
||||
ID: "call_2",
|
||||
Name: "edit_file",
|
||||
ExtraContent: &providers.ExtraContent{
|
||||
ToolFeedbackExplanation: "Update config example after reading it.",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
messages := []providers.Message{
|
||||
{Role: "user", Content: "inspect the config and update the example"},
|
||||
}
|
||||
|
||||
got := toolFeedbackExplanationForToolCall(response, response.ToolCalls[0], messages, 300)
|
||||
want := utils.ToolFeedbackContinuationHint + ": inspect the config and update the example"
|
||||
if got != want {
|
||||
t.Fatalf("toolFeedbackExplanationForToolCall() = %q, want %q", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestToolFeedbackExplanationFromResponse_DoesNotUseReasoningContent(t *testing.T) {
|
||||
response := &providers.LLMResponse{
|
||||
Content: "",
|
||||
ReasoningContent: "hidden reasoning should not be shown",
|
||||
}
|
||||
messages := []providers.Message{
|
||||
{Role: "user", Content: "check file"},
|
||||
{Role: "assistant", Content: "Previous turn explanation"},
|
||||
{Role: "user", Content: "Inspect README.md and update the config example."},
|
||||
{Role: "tool", Content: "tool output", ToolCallID: "call_1"},
|
||||
}
|
||||
|
||||
got := toolFeedbackExplanationFromResponse(response, messages, 300)
|
||||
want := utils.ToolFeedbackContinuationHint + ": Inspect README.md and update the config example."
|
||||
if got != want {
|
||||
t.Fatalf("toolFeedbackExplanationFromResponse() = %q, want latest user content fallback", got)
|
||||
}
|
||||
}
|
||||
|
||||
type picoInterleavedContentProvider struct {
|
||||
calls int
|
||||
}
|
||||
@@ -3808,16 +3656,7 @@ func TestProcessMessage_PublishesToolFeedbackWhenEnabled(t *testing.T) {
|
||||
t.Fatalf("unexpected tool feedback context: %+v", outbound.Context)
|
||||
}
|
||||
if !strings.Contains(outbound.Content, "`read_file`") {
|
||||
t.Fatalf("tool feedback content = %q, want read_file summary", outbound.Content)
|
||||
}
|
||||
if !strings.Contains(outbound.Content, utils.ToolFeedbackContinuationHint) {
|
||||
t.Fatalf("tool feedback content = %q, want continuation hint fallback", outbound.Content)
|
||||
}
|
||||
if !strings.Contains(outbound.Content, "check tool feedback") {
|
||||
t.Fatalf("tool feedback content = %q, want current user intent fallback", outbound.Content)
|
||||
}
|
||||
if strings.Contains(outbound.Content, "Previous turn explanation") {
|
||||
t.Fatalf("tool feedback content = %q, want no previous assistant fallback", outbound.Content)
|
||||
t.Fatalf("tool feedback content = %q, want read_file preview", outbound.Content)
|
||||
}
|
||||
if outbound.AgentID != "main" {
|
||||
t.Fatalf("tool feedback agent_id = %q, want main", outbound.AgentID)
|
||||
@@ -3833,130 +3672,6 @@ func TestProcessMessage_PublishesToolFeedbackWhenEnabled(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestProcessMessage_DoesNotLeakReasoningContentInToolFeedback(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
heartbeatFile := filepath.Join(tmpDir, "tool-feedback-reasoning.txt")
|
||||
if err := os.WriteFile(heartbeatFile, []byte("tool feedback task"), 0o644); err != nil {
|
||||
t.Fatalf("WriteFile() error = %v", err)
|
||||
}
|
||||
|
||||
cfg := &config.Config{
|
||||
Agents: config.AgentsConfig{
|
||||
Defaults: config.AgentDefaults{
|
||||
Workspace: tmpDir,
|
||||
ModelName: "test-model",
|
||||
MaxTokens: 4096,
|
||||
MaxToolIterations: 10,
|
||||
ToolFeedback: config.ToolFeedbackConfig{
|
||||
Enabled: true,
|
||||
MaxArgsLength: 300,
|
||||
},
|
||||
},
|
||||
},
|
||||
Tools: config.ToolsConfig{
|
||||
ReadFile: config.ReadFileToolConfig{
|
||||
Enabled: true,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
msgBus := bus.NewMessageBus()
|
||||
provider := &toolFeedbackReasoningProvider{filePath: heartbeatFile}
|
||||
al := NewAgentLoop(cfg, msgBus, provider)
|
||||
|
||||
response, err := al.processMessage(context.Background(), testInboundMessage(bus.InboundMessage{
|
||||
Channel: "telegram",
|
||||
SenderID: "user-1",
|
||||
ChatID: "chat-1",
|
||||
Content: "check reasoning fallback",
|
||||
}))
|
||||
if err != nil {
|
||||
t.Fatalf("processMessage() error = %v", err)
|
||||
}
|
||||
if response != "DONE" {
|
||||
t.Fatalf("processMessage() response = %q, want %q", response, "DONE")
|
||||
}
|
||||
|
||||
select {
|
||||
case outbound := <-msgBus.OutboundChan():
|
||||
if !strings.Contains(outbound.Content, "`read_file`") {
|
||||
t.Fatalf("tool feedback content = %q, want read_file summary", outbound.Content)
|
||||
}
|
||||
if !strings.Contains(outbound.Content, utils.ToolFeedbackContinuationHint) {
|
||||
t.Fatalf("tool feedback content = %q, want continuation hint fallback", outbound.Content)
|
||||
}
|
||||
if !strings.Contains(outbound.Content, "check reasoning fallback") {
|
||||
t.Fatalf("tool feedback content = %q, want current user intent fallback", outbound.Content)
|
||||
}
|
||||
if strings.Contains(outbound.Content, "Read README.md first") {
|
||||
t.Fatalf("tool feedback content = %q, should not leak hidden reasoning", outbound.Content)
|
||||
}
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("expected outbound tool feedback without leaking reasoning")
|
||||
}
|
||||
}
|
||||
|
||||
func TestProcessMessage_DoesNotPublishToolFeedbackForDiscordWhenDisabled(t *testing.T) {
|
||||
assertToolFeedbackNotPublishedWhenDisabled(t, "discord")
|
||||
}
|
||||
|
||||
func assertToolFeedbackNotPublishedWhenDisabled(t *testing.T, channel string) {
|
||||
t.Helper()
|
||||
|
||||
tmpDir := t.TempDir()
|
||||
heartbeatFile := filepath.Join(tmpDir, "tool-feedback-"+channel+".txt")
|
||||
if err := os.WriteFile(heartbeatFile, []byte("tool feedback task"), 0o644); err != nil {
|
||||
t.Fatalf("WriteFile() error = %v", err)
|
||||
}
|
||||
|
||||
cfg := &config.Config{
|
||||
Agents: config.AgentsConfig{
|
||||
Defaults: config.AgentDefaults{
|
||||
Workspace: tmpDir,
|
||||
ModelName: "test-model",
|
||||
MaxTokens: 4096,
|
||||
MaxToolIterations: 10,
|
||||
},
|
||||
},
|
||||
Tools: config.ToolsConfig{
|
||||
ReadFile: config.ReadFileToolConfig{
|
||||
Enabled: true,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
msgBus := bus.NewMessageBus()
|
||||
provider := &toolFeedbackProvider{filePath: heartbeatFile}
|
||||
al := NewAgentLoop(cfg, msgBus, provider)
|
||||
|
||||
response, err := al.processMessage(context.Background(), testInboundMessage(bus.InboundMessage{
|
||||
Channel: channel,
|
||||
SenderID: "user-1",
|
||||
ChatID: "chat-1",
|
||||
Content: "check tool feedback",
|
||||
}))
|
||||
if err != nil {
|
||||
t.Fatalf("processMessage() error = %v", err)
|
||||
}
|
||||
if response != "HEARTBEAT_OK" {
|
||||
t.Fatalf("processMessage() response = %q, want %q", response, "HEARTBEAT_OK")
|
||||
}
|
||||
|
||||
select {
|
||||
case outbound := <-msgBus.OutboundChan():
|
||||
t.Fatalf("expected no outbound tool feedback for %s when disabled, got %+v", channel, outbound)
|
||||
case <-time.After(200 * time.Millisecond):
|
||||
}
|
||||
}
|
||||
|
||||
func TestProcessMessage_DoesNotPublishToolFeedbackForTelegramWhenDisabled(t *testing.T) {
|
||||
assertToolFeedbackNotPublishedWhenDisabled(t, "telegram")
|
||||
}
|
||||
|
||||
func TestProcessMessage_DoesNotPublishToolFeedbackForFeishuWhenDisabled(t *testing.T) {
|
||||
assertToolFeedbackNotPublishedWhenDisabled(t, "feishu")
|
||||
}
|
||||
|
||||
func TestProcessMessage_MessageToolPublishesOutboundWithTurnMetadata(t *testing.T) {
|
||||
cfg := config.DefaultConfig()
|
||||
cfg.Agents.Defaults.Workspace = t.TempDir()
|
||||
@@ -4131,85 +3846,6 @@ func TestRunAgentLoop_PicoSkipsInterimPublishWhenNotAllowed(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestRun_PicoToolFeedbackSuppressesDuplicateInterimAssistantContent(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
|
||||
cfg := &config.Config{
|
||||
Agents: config.AgentsConfig{
|
||||
Defaults: config.AgentDefaults{
|
||||
Workspace: tmpDir,
|
||||
ModelName: "test-model",
|
||||
MaxTokens: 4096,
|
||||
MaxToolIterations: 10,
|
||||
ToolFeedback: config.ToolFeedbackConfig{
|
||||
Enabled: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
msgBus := bus.NewMessageBus()
|
||||
provider := &picoInterleavedContentProvider{}
|
||||
al := NewAgentLoop(cfg, msgBus, provider)
|
||||
|
||||
agent := al.GetRegistry().GetDefaultAgent()
|
||||
if agent == nil {
|
||||
t.Fatal("expected default agent")
|
||||
}
|
||||
agent.Tools.Register(&toolLimitTestTool{})
|
||||
|
||||
runCtx, runCancel := context.WithCancel(context.Background())
|
||||
defer runCancel()
|
||||
|
||||
runDone := make(chan error, 1)
|
||||
go func() {
|
||||
runDone <- al.Run(runCtx)
|
||||
}()
|
||||
|
||||
if err := msgBus.PublishInbound(context.Background(), bus.InboundMessage{
|
||||
Channel: "pico",
|
||||
SenderID: "user-1",
|
||||
ChatID: "session-1",
|
||||
Content: "run with tools",
|
||||
}); err != nil {
|
||||
t.Fatalf("PublishInbound() error = %v", err)
|
||||
}
|
||||
|
||||
outputs := make([]string, 0, 2)
|
||||
deadline := time.After(2 * time.Second)
|
||||
for len(outputs) < 2 {
|
||||
select {
|
||||
case outbound := <-msgBus.OutboundChan():
|
||||
outputs = append(outputs, outbound.Content)
|
||||
case <-deadline:
|
||||
t.Fatalf("timed out waiting for pico outputs, got %v", outputs)
|
||||
}
|
||||
}
|
||||
|
||||
if outputs[0] != "🔧 `tool_limit_test_tool`\nintermediate model text" {
|
||||
t.Fatalf("first outbound content = %q, want tool feedback summary", outputs[0])
|
||||
}
|
||||
if outputs[1] != "final model text" {
|
||||
t.Fatalf("second outbound content = %q, want %q", outputs[1], "final model text")
|
||||
}
|
||||
|
||||
runCancel()
|
||||
select {
|
||||
case err := <-runDone:
|
||||
if err != nil {
|
||||
t.Fatalf("Run() error = %v", err)
|
||||
}
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("timed out waiting for Run() to exit")
|
||||
}
|
||||
|
||||
select {
|
||||
case outbound := <-msgBus.OutboundChan():
|
||||
t.Fatalf("unexpected extra pico output after tool feedback + final reply: %+v", outbound)
|
||||
case <-time.After(200 * time.Millisecond):
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveMediaRefs_ResolvesToBase64(t *testing.T) {
|
||||
store := media.NewFileMediaStore()
|
||||
dir := t.TempDir()
|
||||
|
||||
+20
-31
@@ -635,11 +635,7 @@ turnLoop:
|
||||
}
|
||||
logger.DebugCF("agent", "LLM response", llmResponseFields)
|
||||
|
||||
if al.bus != nil &&
|
||||
ts.channel == "pico" &&
|
||||
len(response.ToolCalls) > 0 &&
|
||||
ts.opts.AllowInterimPicoPublish &&
|
||||
!shouldPublishToolFeedback(al.cfg, ts) {
|
||||
if al.bus != nil && ts.channel == "pico" && len(response.ToolCalls) > 0 && ts.opts.AllowInterimPicoPublish {
|
||||
if strings.TrimSpace(response.Content) != "" {
|
||||
outCtx, outCancel := context.WithTimeout(turnCtx, 3*time.Second)
|
||||
err := al.bus.PublishOutbound(outCtx, bus.OutboundMessage{
|
||||
@@ -709,19 +705,7 @@ turnLoop:
|
||||
}
|
||||
for _, tc := range normalizedToolCalls {
|
||||
argumentsJSON, _ := json.Marshal(tc.Arguments)
|
||||
toolFeedbackExplanation := toolFeedbackExplanationForToolCall(
|
||||
response,
|
||||
tc,
|
||||
messages,
|
||||
al.cfg.Agents.Defaults.GetToolFeedbackMaxArgsLength(),
|
||||
)
|
||||
extraContent := tc.ExtraContent
|
||||
if strings.TrimSpace(toolFeedbackExplanation) != "" {
|
||||
if extraContent == nil {
|
||||
extraContent = &providers.ExtraContent{}
|
||||
}
|
||||
extraContent.ToolFeedbackExplanation = toolFeedbackExplanation
|
||||
}
|
||||
thoughtSignature := ""
|
||||
if tc.Function != nil {
|
||||
thoughtSignature = tc.Function.ThoughtSignature
|
||||
@@ -799,16 +783,21 @@ turnLoop:
|
||||
)
|
||||
|
||||
// Send tool feedback to chat channel if enabled (same as normal tool execution)
|
||||
if shouldPublishToolFeedback(al.cfg, ts) {
|
||||
toolFeedbackExplanation := toolFeedbackExplanationForToolCall(
|
||||
response,
|
||||
tc,
|
||||
messages,
|
||||
if al.cfg.Agents.Defaults.IsToolFeedbackEnabled() &&
|
||||
ts.channel != "" &&
|
||||
!ts.opts.SuppressToolFeedback {
|
||||
argsJSON, _ := json.Marshal(toolArgs)
|
||||
feedbackPreview := utils.Truncate(
|
||||
string(argsJSON),
|
||||
al.cfg.Agents.Defaults.GetToolFeedbackMaxArgsLength(),
|
||||
)
|
||||
feedbackMsg := utils.FormatToolFeedbackMessage(toolName, toolFeedbackExplanation)
|
||||
feedbackMsg := utils.FormatToolFeedbackMessage(toolName, feedbackPreview)
|
||||
fbCtx, fbCancel := context.WithTimeout(turnCtx, 3*time.Second)
|
||||
_ = al.bus.PublishOutbound(fbCtx, outboundMessageForTurnWithKind(ts, feedbackMsg, messageKindToolFeedback))
|
||||
_ = al.bus.PublishOutbound(fbCtx, bus.OutboundMessage{
|
||||
Channel: ts.channel,
|
||||
ChatID: ts.chatID,
|
||||
Content: feedbackMsg,
|
||||
})
|
||||
fbCancel()
|
||||
}
|
||||
|
||||
@@ -1078,16 +1067,16 @@ turnLoop:
|
||||
)
|
||||
|
||||
// Send tool feedback to chat channel if enabled (from HEAD)
|
||||
if shouldPublishToolFeedback(al.cfg, ts) {
|
||||
toolFeedbackExplanation := toolFeedbackExplanationForToolCall(
|
||||
response,
|
||||
tc,
|
||||
messages,
|
||||
if al.cfg.Agents.Defaults.IsToolFeedbackEnabled() &&
|
||||
ts.channel != "" &&
|
||||
!ts.opts.SuppressToolFeedback {
|
||||
feedbackPreview := utils.Truncate(
|
||||
string(argsJSON),
|
||||
al.cfg.Agents.Defaults.GetToolFeedbackMaxArgsLength(),
|
||||
)
|
||||
feedbackMsg := utils.FormatToolFeedbackMessage(tc.Name, toolFeedbackExplanation)
|
||||
feedbackMsg := utils.FormatToolFeedbackMessage(tc.Name, feedbackPreview)
|
||||
fbCtx, fbCancel := context.WithTimeout(turnCtx, 3*time.Second)
|
||||
_ = al.bus.PublishOutbound(fbCtx, outboundMessageForTurnWithKind(ts, feedbackMsg, messageKindToolFeedback))
|
||||
_ = al.bus.PublishOutbound(fbCtx, outboundMessageForTurn(ts, feedbackMsg))
|
||||
fbCancel()
|
||||
}
|
||||
|
||||
|
||||
@@ -11,7 +11,6 @@ import (
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/bus"
|
||||
"github.com/sipeed/picoclaw/pkg/commands"
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
"github.com/sipeed/picoclaw/pkg/providers"
|
||||
"github.com/sipeed/picoclaw/pkg/session"
|
||||
"github.com/sipeed/picoclaw/pkg/utils"
|
||||
@@ -85,98 +84,6 @@ func outboundMessageForTurn(ts *turnState, content string) bus.OutboundMessage {
|
||||
}
|
||||
}
|
||||
|
||||
func outboundMessageForTurnWithKind(ts *turnState, content, kind string) bus.OutboundMessage {
|
||||
msg := outboundMessageForTurn(ts, content)
|
||||
if strings.TrimSpace(kind) == "" {
|
||||
return msg
|
||||
}
|
||||
if msg.Context.Raw == nil {
|
||||
msg.Context.Raw = make(map[string]string, 1)
|
||||
}
|
||||
msg.Context.Raw[metadataKeyMessageKind] = kind
|
||||
return msg
|
||||
}
|
||||
|
||||
func latestUserContent(messages []providers.Message) string {
|
||||
for i := len(messages) - 1; i >= 0; i-- {
|
||||
msg := messages[i]
|
||||
if msg.Role != "user" {
|
||||
continue
|
||||
}
|
||||
if content := strings.TrimSpace(msg.Content); content != "" {
|
||||
return content
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func toolFeedbackExplanationFromResponse(
|
||||
response *providers.LLMResponse,
|
||||
messages []providers.Message,
|
||||
maxLen int,
|
||||
) string {
|
||||
if response == nil {
|
||||
return ""
|
||||
}
|
||||
explanation := strings.TrimSpace(response.Content)
|
||||
if explanation == "" {
|
||||
explanation = toolFeedbackExplanationFromToolCalls(response.ToolCalls)
|
||||
}
|
||||
if explanation == "" {
|
||||
explanation = toolFeedbackExplanationFromMessages(messages)
|
||||
}
|
||||
return utils.Truncate(explanation, maxLen)
|
||||
}
|
||||
|
||||
func toolFeedbackExplanationFromToolCalls(toolCalls []providers.ToolCall) string {
|
||||
for _, tc := range toolCalls {
|
||||
if tc.ExtraContent == nil {
|
||||
continue
|
||||
}
|
||||
if explanation := strings.TrimSpace(tc.ExtraContent.ToolFeedbackExplanation); explanation != "" {
|
||||
return explanation
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func toolFeedbackExplanationForToolCall(
|
||||
response *providers.LLMResponse,
|
||||
toolCall providers.ToolCall,
|
||||
messages []providers.Message,
|
||||
maxLen int,
|
||||
) string {
|
||||
if toolCall.ExtraContent != nil {
|
||||
if explanation := strings.TrimSpace(toolCall.ExtraContent.ToolFeedbackExplanation); explanation != "" {
|
||||
return utils.Truncate(explanation, maxLen)
|
||||
}
|
||||
}
|
||||
if response == nil {
|
||||
return utils.Truncate(toolFeedbackExplanationFromMessages(messages), maxLen)
|
||||
}
|
||||
|
||||
explanation := strings.TrimSpace(response.Content)
|
||||
if explanation == "" {
|
||||
explanation = toolFeedbackExplanationFromMessages(messages)
|
||||
}
|
||||
return utils.Truncate(explanation, maxLen)
|
||||
}
|
||||
|
||||
func toolFeedbackExplanationFromMessages(messages []providers.Message) string {
|
||||
explanation := latestUserContent(messages)
|
||||
if explanation != "" {
|
||||
return utils.ToolFeedbackContinuationHint + ": " + explanation
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func shouldPublishToolFeedback(cfg *config.Config, ts *turnState) bool {
|
||||
if ts == nil || ts.channel == "" || ts.opts.SuppressToolFeedback {
|
||||
return false
|
||||
}
|
||||
return cfg != nil && cfg.Agents.Defaults.IsToolFeedbackEnabled()
|
||||
}
|
||||
|
||||
func cloneEventArguments(args map[string]any) map[string]any {
|
||||
if len(args) == 0 {
|
||||
return nil
|
||||
|
||||
+20
-170
@@ -45,12 +45,9 @@ type DiscordChannel struct {
|
||||
cancel context.CancelFunc
|
||||
typingMu sync.Mutex
|
||||
typingStop map[string]chan struct{} // chatID → stop signal
|
||||
progress *channels.ToolFeedbackAnimator
|
||||
botUserID string // stored for mention checking
|
||||
botUserID string // stored for mention checking
|
||||
bus *bus.MessageBus
|
||||
tts tts.TTSProvider
|
||||
playTTSFn func(context.Context, *discordgo.VoiceConnection, string, uint64)
|
||||
ttsVoiceFn func(string) (*discordgo.VoiceConnection, bool)
|
||||
voiceMu sync.RWMutex
|
||||
voiceSSRC map[string]map[uint32]string // guildID -> ssrc -> userID
|
||||
|
||||
@@ -87,7 +84,7 @@ func NewDiscordChannel(
|
||||
channels.WithReasoningChannelID(bc.ReasoningChannelID),
|
||||
)
|
||||
|
||||
ch := &DiscordChannel{
|
||||
return &DiscordChannel{
|
||||
BaseChannel: base,
|
||||
bc: bc,
|
||||
session: session,
|
||||
@@ -96,11 +93,7 @@ func NewDiscordChannel(
|
||||
typingStop: make(map[string]chan struct{}),
|
||||
bus: bus,
|
||||
voiceSSRC: make(map[string]map[uint32]string),
|
||||
}
|
||||
ch.playTTSFn = ch.playTTS
|
||||
ch.ttsVoiceFn = ch.voiceConnectionForTTS
|
||||
ch.progress = channels.NewToolFeedbackAnimator(ch.EditMessage)
|
||||
return ch, nil
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (c *DiscordChannel) Start(ctx context.Context) error {
|
||||
@@ -149,9 +142,6 @@ func (c *DiscordChannel) Stop(ctx context.Context) error {
|
||||
if c.cancel != nil {
|
||||
c.cancel()
|
||||
}
|
||||
if c.progress != nil {
|
||||
c.progress.StopAll()
|
||||
}
|
||||
|
||||
if err := c.session.Close(); err != nil {
|
||||
return fmt.Errorf("failed to close discord session: %w", err)
|
||||
@@ -174,88 +164,32 @@ func (c *DiscordChannel) Send(ctx context.Context, msg bus.OutboundMessage) ([]s
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
isToolFeedback := outboundMessageIsToolFeedback(msg)
|
||||
if isToolFeedback {
|
||||
if msgID, handled, err := c.progress.Update(ctx, channelID, msg.Content); handled {
|
||||
if err != nil {
|
||||
return nil, err
|
||||
if c.tts != nil {
|
||||
if ch, err := c.session.State.Channel(channelID); err == nil && ch.GuildID != "" {
|
||||
if vc, ok := c.session.VoiceConnections[ch.GuildID]; ok && vc != nil {
|
||||
// Cancel any previous TTS playback
|
||||
c.ttsMu.Lock()
|
||||
if c.cancelTTS != nil {
|
||||
c.cancelTTS()
|
||||
}
|
||||
ttsCtx, ttsCancel := context.WithCancel(c.ctx)
|
||||
c.ttsPlayID++
|
||||
playID := c.ttsPlayID
|
||||
c.cancelTTS = ttsCancel
|
||||
c.ttsMu.Unlock()
|
||||
|
||||
go c.playTTS(ttsCtx, vc, msg.Content, playID)
|
||||
}
|
||||
return []string{msgID}, nil
|
||||
}
|
||||
}
|
||||
trackedMsgID, hasTrackedMsg := c.currentToolFeedbackMessage(channelID)
|
||||
c.maybeStartTTS(channelID, msg.Content, isToolFeedback)
|
||||
if !isToolFeedback {
|
||||
if msgIDs, handled := c.FinalizeToolFeedbackMessage(ctx, msg); handled {
|
||||
return msgIDs, nil
|
||||
}
|
||||
}
|
||||
|
||||
content := msg.Content
|
||||
if isToolFeedback {
|
||||
content = channels.InitialAnimatedToolFeedbackContent(msg.Content)
|
||||
}
|
||||
msgID, err := c.sendChunk(ctx, channelID, content, msg.ReplyToMessageID)
|
||||
msgID, err := c.sendChunk(ctx, channelID, msg.Content, msg.ReplyToMessageID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if isToolFeedback {
|
||||
c.RecordToolFeedbackMessage(channelID, msgID, msg.Content)
|
||||
} else if hasTrackedMsg {
|
||||
c.dismissTrackedToolFeedbackMessage(ctx, channelID, trackedMsgID)
|
||||
}
|
||||
return []string{msgID}, nil
|
||||
}
|
||||
|
||||
func (c *DiscordChannel) maybeStartTTS(channelID, content string, isToolFeedback bool) {
|
||||
if c.tts == nil || isToolFeedback {
|
||||
return
|
||||
}
|
||||
|
||||
voiceFn := c.ttsVoiceFn
|
||||
if voiceFn == nil {
|
||||
voiceFn = c.voiceConnectionForTTS
|
||||
}
|
||||
vc, ok := voiceFn(channelID)
|
||||
if !ok || vc == nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Cancel any previous TTS playback.
|
||||
c.ttsMu.Lock()
|
||||
if c.cancelTTS != nil {
|
||||
c.cancelTTS()
|
||||
}
|
||||
ttsCtx, ttsCancel := context.WithCancel(c.ctx)
|
||||
c.ttsPlayID++
|
||||
playID := c.ttsPlayID
|
||||
c.cancelTTS = ttsCancel
|
||||
playFn := c.playTTSFn
|
||||
c.ttsMu.Unlock()
|
||||
|
||||
if playFn == nil {
|
||||
playFn = c.playTTS
|
||||
}
|
||||
go playFn(ttsCtx, vc, content, playID)
|
||||
}
|
||||
|
||||
func (c *DiscordChannel) voiceConnectionForTTS(channelID string) (*discordgo.VoiceConnection, bool) {
|
||||
if c.session == nil || c.session.State == nil {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
ch, err := c.session.State.Channel(channelID)
|
||||
if err != nil || ch == nil || ch.GuildID == "" {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
vc, ok := c.session.VoiceConnections[ch.GuildID]
|
||||
if !ok || vc == nil {
|
||||
return nil, false
|
||||
}
|
||||
return vc, true
|
||||
}
|
||||
|
||||
// SendMedia implements the channels.MediaSender interface.
|
||||
func (c *DiscordChannel) SendMedia(ctx context.Context, msg bus.OutboundMediaMessage) ([]string, error) {
|
||||
if !c.IsRunning() {
|
||||
@@ -266,7 +200,6 @@ func (c *DiscordChannel) SendMedia(ctx context.Context, msg bus.OutboundMediaMes
|
||||
if channelID == "" {
|
||||
return nil, fmt.Errorf("channel ID is empty")
|
||||
}
|
||||
trackedMsgID, hasTrackedMsg := c.currentToolFeedbackMessage(channelID)
|
||||
|
||||
store := c.GetMediaStore()
|
||||
if store == nil {
|
||||
@@ -348,9 +281,6 @@ func (c *DiscordChannel) SendMedia(ctx context.Context, msg bus.OutboundMediaMes
|
||||
if r.err != nil {
|
||||
return nil, fmt.Errorf("discord send media: %w", channels.ErrTemporary)
|
||||
}
|
||||
if hasTrackedMsg {
|
||||
c.dismissTrackedToolFeedbackMessage(ctx, channelID, trackedMsgID)
|
||||
}
|
||||
return []string{r.id}, nil
|
||||
case <-sendCtx.Done():
|
||||
// Close all file readers
|
||||
@@ -365,15 +295,10 @@ func (c *DiscordChannel) SendMedia(ctx context.Context, msg bus.OutboundMediaMes
|
||||
|
||||
// EditMessage implements channels.MessageEditor.
|
||||
func (c *DiscordChannel) EditMessage(ctx context.Context, chatID string, messageID string, content string) error {
|
||||
_, err := c.session.ChannelMessageEdit(chatID, messageID, content, discordgo.WithContext(ctx))
|
||||
_, err := c.session.ChannelMessageEdit(chatID, messageID, content)
|
||||
return err
|
||||
}
|
||||
|
||||
// DeleteMessage implements channels.MessageDeleter.
|
||||
func (c *DiscordChannel) DeleteMessage(ctx context.Context, chatID string, messageID string) error {
|
||||
return c.session.ChannelMessageDelete(chatID, messageID, discordgo.WithContext(ctx))
|
||||
}
|
||||
|
||||
// SendPlaceholder implements channels.PlaceholderCapable.
|
||||
// It sends a placeholder message that will later be edited to the actual
|
||||
// response via EditMessage (channels.MessageEditor).
|
||||
@@ -392,81 +317,6 @@ func (c *DiscordChannel) SendPlaceholder(ctx context.Context, chatID string) (st
|
||||
return msg.ID, nil
|
||||
}
|
||||
|
||||
func outboundMessageIsToolFeedback(msg bus.OutboundMessage) bool {
|
||||
if len(msg.Context.Raw) == 0 {
|
||||
return false
|
||||
}
|
||||
return strings.EqualFold(strings.TrimSpace(msg.Context.Raw["message_kind"]), "tool_feedback")
|
||||
}
|
||||
|
||||
func (c *DiscordChannel) currentToolFeedbackMessage(chatID string) (string, bool) {
|
||||
if c.progress == nil {
|
||||
return "", false
|
||||
}
|
||||
return c.progress.Current(chatID)
|
||||
}
|
||||
|
||||
func (c *DiscordChannel) takeToolFeedbackMessage(chatID string) (string, string, bool) {
|
||||
if c.progress == nil {
|
||||
return "", "", false
|
||||
}
|
||||
return c.progress.Take(chatID)
|
||||
}
|
||||
|
||||
func (c *DiscordChannel) RecordToolFeedbackMessage(chatID, messageID, content string) {
|
||||
if c.progress == nil {
|
||||
return
|
||||
}
|
||||
c.progress.Record(chatID, messageID, content)
|
||||
}
|
||||
|
||||
func (c *DiscordChannel) ClearToolFeedbackMessage(chatID string) {
|
||||
if c.progress == nil {
|
||||
return
|
||||
}
|
||||
c.progress.Clear(chatID)
|
||||
}
|
||||
|
||||
func (c *DiscordChannel) DismissToolFeedbackMessage(ctx context.Context, chatID string) {
|
||||
msgID, ok := c.currentToolFeedbackMessage(chatID)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
c.dismissTrackedToolFeedbackMessage(ctx, chatID, msgID)
|
||||
}
|
||||
|
||||
func (c *DiscordChannel) dismissTrackedToolFeedbackMessage(ctx context.Context, chatID, messageID string) {
|
||||
if strings.TrimSpace(chatID) == "" || strings.TrimSpace(messageID) == "" {
|
||||
return
|
||||
}
|
||||
c.ClearToolFeedbackMessage(chatID)
|
||||
_ = c.DeleteMessage(ctx, chatID, messageID)
|
||||
}
|
||||
|
||||
func (c *DiscordChannel) finalizeTrackedToolFeedbackMessage(
|
||||
ctx context.Context,
|
||||
chatID string,
|
||||
content string,
|
||||
editFn func(context.Context, string, string, string) error,
|
||||
) ([]string, bool) {
|
||||
msgID, baseContent, ok := c.takeToolFeedbackMessage(chatID)
|
||||
if !ok || editFn == nil {
|
||||
return nil, false
|
||||
}
|
||||
if err := editFn(ctx, chatID, msgID, content); err != nil {
|
||||
c.RecordToolFeedbackMessage(chatID, msgID, baseContent)
|
||||
return nil, false
|
||||
}
|
||||
return []string{msgID}, true
|
||||
}
|
||||
|
||||
func (c *DiscordChannel) FinalizeToolFeedbackMessage(ctx context.Context, msg bus.OutboundMessage) ([]string, bool) {
|
||||
if outboundMessageIsToolFeedback(msg) {
|
||||
return nil, false
|
||||
}
|
||||
return c.finalizeTrackedToolFeedbackMessage(ctx, msg.ChatID, msg.Content, c.EditMessage)
|
||||
}
|
||||
|
||||
func (c *DiscordChannel) sendChunk(ctx context.Context, channelID, content, replyToID string) (string, error) {
|
||||
// Use the passed ctx for timeout control
|
||||
sendCtx, cancel := context.WithTimeout(ctx, sendTimeout)
|
||||
|
||||
@@ -1,37 +1,13 @@
|
||||
package discord
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"reflect"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/bwmarrin/discordgo"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/audio/tts"
|
||||
"github.com/sipeed/picoclaw/pkg/bus"
|
||||
"github.com/sipeed/picoclaw/pkg/channels"
|
||||
)
|
||||
|
||||
type stubTTSProvider struct{}
|
||||
|
||||
func (stubTTSProvider) Name() string { return "stub-tts" }
|
||||
|
||||
func (stubTTSProvider) Synthesize(context.Context, string) (io.ReadCloser, error) {
|
||||
return io.NopCloser(&noopReader{}), nil
|
||||
}
|
||||
|
||||
type noopReader struct{}
|
||||
|
||||
func (*noopReader) Read(p []byte) (int, error) {
|
||||
return 0, io.EOF
|
||||
}
|
||||
|
||||
func TestApplyDiscordProxy_CustomProxy(t *testing.T) {
|
||||
session, err := discordgo.New("Bot test-token")
|
||||
if err != nil {
|
||||
@@ -113,224 +89,3 @@ func TestApplyDiscordProxy_InvalidProxyURL(t *testing.T) {
|
||||
t.Fatal("applyDiscordProxy() expected error for invalid proxy URL, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSend_NonToolFeedbackDeletesTrackedProgressMessage(t *testing.T) {
|
||||
var (
|
||||
mu sync.Mutex
|
||||
requests []string
|
||||
)
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
mu.Lock()
|
||||
requests = append(requests, r.Method+" "+r.URL.Path)
|
||||
mu.Unlock()
|
||||
|
||||
switch {
|
||||
case r.Method == http.MethodPatch && r.URL.Path == "/channels/chat-1/messages/prog-1":
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = io.WriteString(w, `{"id":"prog-1"}`)
|
||||
default:
|
||||
t.Fatalf("unexpected request: %s %s", r.Method, r.URL.Path)
|
||||
}
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
origChannels := discordgo.EndpointChannels
|
||||
discordgo.EndpointChannels = server.URL + "/channels/"
|
||||
defer func() {
|
||||
discordgo.EndpointChannels = origChannels
|
||||
}()
|
||||
|
||||
session, err := discordgo.New("Bot test-token")
|
||||
if err != nil {
|
||||
t.Fatalf("discordgo.New() error: %v", err)
|
||||
}
|
||||
session.Client = server.Client()
|
||||
|
||||
ch := &DiscordChannel{
|
||||
BaseChannel: channels.NewBaseChannel("discord", nil, bus.NewMessageBus(), nil),
|
||||
session: session,
|
||||
ctx: context.Background(),
|
||||
typingStop: make(map[string]chan struct{}),
|
||||
voiceSSRC: make(map[string]map[uint32]string),
|
||||
}
|
||||
ch.progress = channels.NewToolFeedbackAnimator(ch.EditMessage)
|
||||
ch.SetRunning(true)
|
||||
ch.RecordToolFeedbackMessage("chat-1", "prog-1", "🔧 `read_file`")
|
||||
|
||||
ids, err := ch.Send(context.Background(), bus.OutboundMessage{
|
||||
ChatID: "chat-1",
|
||||
Content: "final reply",
|
||||
Context: bus.InboundContext{
|
||||
Channel: "discord",
|
||||
ChatID: "chat-1",
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Send() error = %v", err)
|
||||
}
|
||||
if got, want := ids, []string{"prog-1"}; !reflect.DeepEqual(got, want) {
|
||||
t.Fatalf("Send() ids = %v, want %v", got, want)
|
||||
}
|
||||
if _, ok := ch.currentToolFeedbackMessage("chat-1"); ok {
|
||||
t.Fatal("expected tracked tool feedback message to be cleared")
|
||||
}
|
||||
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
wantRequests := []string{
|
||||
"PATCH /channels/chat-1/messages/prog-1",
|
||||
}
|
||||
if !reflect.DeepEqual(requests, wantRequests) {
|
||||
t.Fatalf("requests = %v, want %v", requests, wantRequests)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEditMessage_UsesContextCancellation(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
select {
|
||||
case <-r.Context().Done():
|
||||
return
|
||||
case <-time.After(time.Second):
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = io.WriteString(w, `{"id":"msg-1"}`)
|
||||
}
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
origChannels := discordgo.EndpointChannels
|
||||
discordgo.EndpointChannels = server.URL + "/channels/"
|
||||
defer func() {
|
||||
discordgo.EndpointChannels = origChannels
|
||||
}()
|
||||
|
||||
session, err := discordgo.New("Bot test-token")
|
||||
if err != nil {
|
||||
t.Fatalf("discordgo.New() error: %v", err)
|
||||
}
|
||||
session.Client = server.Client()
|
||||
|
||||
ch := &DiscordChannel{
|
||||
BaseChannel: channels.NewBaseChannel("discord", nil, bus.NewMessageBus(), nil),
|
||||
session: session,
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
|
||||
defer cancel()
|
||||
|
||||
start := time.Now()
|
||||
err = ch.EditMessage(ctx, "chat-1", "msg-1", "still running")
|
||||
elapsed := time.Since(start)
|
||||
|
||||
if err == nil {
|
||||
t.Fatal("expected EditMessage() to fail when context times out")
|
||||
}
|
||||
if elapsed >= 500*time.Millisecond {
|
||||
t.Fatalf("EditMessage() ignored context timeout, elapsed=%v", elapsed)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFinalizeTrackedToolFeedbackMessage_StopsTrackingBeforeEdit(t *testing.T) {
|
||||
ch := &DiscordChannel{
|
||||
progress: channels.NewToolFeedbackAnimator(nil),
|
||||
}
|
||||
ch.RecordToolFeedbackMessage("chat-1", "msg-1", "🔧 `read_file`")
|
||||
|
||||
msgIDs, handled := ch.finalizeTrackedToolFeedbackMessage(
|
||||
context.Background(),
|
||||
"chat-1",
|
||||
"final reply",
|
||||
func(_ context.Context, chatID, messageID, content string) error {
|
||||
if _, ok := ch.currentToolFeedbackMessage(chatID); ok {
|
||||
t.Fatal("expected tracked tool feedback to be stopped before edit")
|
||||
}
|
||||
if chatID != "chat-1" || messageID != "msg-1" || content != "final reply" {
|
||||
t.Fatalf("unexpected edit args: %s %s %s", chatID, messageID, content)
|
||||
}
|
||||
return nil
|
||||
},
|
||||
)
|
||||
if !handled {
|
||||
t.Fatal("expected finalizeTrackedToolFeedbackMessage to handle tracked message")
|
||||
}
|
||||
if got, want := msgIDs, []string{"msg-1"}; !reflect.DeepEqual(got, want) {
|
||||
t.Fatalf("finalizeTrackedToolFeedbackMessage() ids = %v, want %v", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSend_NonToolFeedbackFinalizerStillStartsTTS(t *testing.T) {
|
||||
var (
|
||||
mu sync.Mutex
|
||||
requests []string
|
||||
)
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
mu.Lock()
|
||||
requests = append(requests, r.Method+" "+r.URL.Path)
|
||||
mu.Unlock()
|
||||
|
||||
switch {
|
||||
case r.Method == http.MethodPatch && r.URL.Path == "/channels/chat-1/messages/prog-1":
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = io.WriteString(w, `{"id":"prog-1"}`)
|
||||
default:
|
||||
t.Fatalf("unexpected request: %s %s", r.Method, r.URL.Path)
|
||||
}
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
origChannels := discordgo.EndpointChannels
|
||||
discordgo.EndpointChannels = server.URL + "/channels/"
|
||||
defer func() {
|
||||
discordgo.EndpointChannels = origChannels
|
||||
}()
|
||||
|
||||
session, err := discordgo.New("Bot test-token")
|
||||
if err != nil {
|
||||
t.Fatalf("discordgo.New() error: %v", err)
|
||||
}
|
||||
session.Client = server.Client()
|
||||
|
||||
ttsStarted := make(chan string, 1)
|
||||
ch := &DiscordChannel{
|
||||
BaseChannel: channels.NewBaseChannel("discord", nil, bus.NewMessageBus(), nil),
|
||||
session: session,
|
||||
ctx: context.Background(),
|
||||
typingStop: make(map[string]chan struct{}),
|
||||
voiceSSRC: make(map[string]map[uint32]string),
|
||||
tts: tts.TTSProvider(stubTTSProvider{}),
|
||||
}
|
||||
ch.ttsVoiceFn = func(string) (*discordgo.VoiceConnection, bool) {
|
||||
return &discordgo.VoiceConnection{}, true
|
||||
}
|
||||
ch.playTTSFn = func(_ context.Context, _ *discordgo.VoiceConnection, text string, _ uint64) {
|
||||
ttsStarted <- text
|
||||
}
|
||||
ch.progress = channels.NewToolFeedbackAnimator(ch.EditMessage)
|
||||
ch.SetRunning(true)
|
||||
ch.RecordToolFeedbackMessage("chat-1", "prog-1", "🔧 `read_file`")
|
||||
|
||||
ids, err := ch.Send(context.Background(), bus.OutboundMessage{
|
||||
ChatID: "chat-1",
|
||||
Content: "final reply",
|
||||
Context: bus.InboundContext{
|
||||
Channel: "discord",
|
||||
ChatID: "chat-1",
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Send() error = %v", err)
|
||||
}
|
||||
if got, want := ids, []string{"prog-1"}; !reflect.DeepEqual(got, want) {
|
||||
t.Fatalf("Send() ids = %v, want %v", got, want)
|
||||
}
|
||||
|
||||
select {
|
||||
case got := <-ttsStarted:
|
||||
if got != "final reply" {
|
||||
t.Fatalf("TTS content = %q, want final reply", got)
|
||||
}
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("expected TTS to start for finalized tracked tool feedback reply")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -49,8 +49,6 @@ type FeishuChannel struct {
|
||||
|
||||
mu sync.Mutex
|
||||
cancel context.CancelFunc
|
||||
|
||||
progress *channels.ToolFeedbackAnimator
|
||||
}
|
||||
|
||||
type cachedMessage struct {
|
||||
@@ -76,7 +74,6 @@ func NewFeishuChannel(bc *config.Channel, cfg *config.FeishuSettings, bus *bus.M
|
||||
tokenCache: tc,
|
||||
client: lark.NewClient(cfg.AppID, cfg.AppSecret.String(), opts...),
|
||||
}
|
||||
ch.progress = channels.NewToolFeedbackAnimator(ch.EditMessage)
|
||||
ch.SetOwner(ch)
|
||||
return ch, nil
|
||||
}
|
||||
@@ -135,9 +132,6 @@ func (c *FeishuChannel) Stop(ctx context.Context) error {
|
||||
}
|
||||
c.wsClient = nil
|
||||
c.mu.Unlock()
|
||||
if c.progress != nil {
|
||||
c.progress.StopAll()
|
||||
}
|
||||
|
||||
c.SetRunning(false)
|
||||
logger.InfoC("feishu", "Feishu channel stopped")
|
||||
@@ -155,50 +149,17 @@ func (c *FeishuChannel) Send(ctx context.Context, msg bus.OutboundMessage) ([]st
|
||||
return nil, fmt.Errorf("chat ID is empty: %w", channels.ErrSendFailed)
|
||||
}
|
||||
|
||||
isToolFeedback := outboundMessageIsToolFeedback(msg)
|
||||
trackedMsgID, hasTrackedMsg := c.currentToolFeedbackMessage(msg.ChatID)
|
||||
if isToolFeedback {
|
||||
if msgID, handled, err := c.progress.Update(ctx, msg.ChatID, msg.Content); handled {
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return []string{msgID}, nil
|
||||
}
|
||||
} else {
|
||||
if msgIDs, handled := c.FinalizeToolFeedbackMessage(ctx, msg); handled {
|
||||
return msgIDs, nil
|
||||
}
|
||||
}
|
||||
|
||||
// Build interactive card with markdown content
|
||||
sendContent := msg.Content
|
||||
if isToolFeedback {
|
||||
sendContent = channels.InitialAnimatedToolFeedbackContent(msg.Content)
|
||||
}
|
||||
cardContent, err := buildMarkdownCard(sendContent)
|
||||
cardContent, err := buildMarkdownCard(msg.Content)
|
||||
if err != nil {
|
||||
// If card build fails, fall back to plain text
|
||||
msgID, sendErr := c.sendText(ctx, msg.ChatID, sendContent)
|
||||
if sendErr != nil {
|
||||
return nil, sendErr
|
||||
}
|
||||
if isToolFeedback {
|
||||
c.RecordToolFeedbackMessage(msg.ChatID, msgID, msg.Content)
|
||||
} else if hasTrackedMsg {
|
||||
c.dismissTrackedToolFeedbackMessage(ctx, msg.ChatID, trackedMsgID)
|
||||
}
|
||||
return []string{msgID}, nil
|
||||
return nil, c.sendText(ctx, msg.ChatID, msg.Content)
|
||||
}
|
||||
|
||||
// First attempt: try sending as interactive card
|
||||
msgID, err := c.sendCard(ctx, msg.ChatID, cardContent)
|
||||
err = c.sendCard(ctx, msg.ChatID, cardContent)
|
||||
if err == nil {
|
||||
if isToolFeedback {
|
||||
c.RecordToolFeedbackMessage(msg.ChatID, msgID, msg.Content)
|
||||
} else if hasTrackedMsg {
|
||||
c.dismissTrackedToolFeedbackMessage(ctx, msg.ChatID, trackedMsgID)
|
||||
}
|
||||
return []string{msgID}, nil
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// Check if error is due to card table limit (error code 11310)
|
||||
@@ -213,14 +174,9 @@ func (c *FeishuChannel) Send(ctx context.Context, msg bus.OutboundMessage) ([]st
|
||||
})
|
||||
|
||||
// Second attempt: fall back to plain text message
|
||||
msgID, textErr := c.sendText(ctx, msg.ChatID, sendContent)
|
||||
textErr := c.sendText(ctx, msg.ChatID, msg.Content)
|
||||
if textErr == nil {
|
||||
if isToolFeedback {
|
||||
c.RecordToolFeedbackMessage(msg.ChatID, msgID, msg.Content)
|
||||
} else if hasTrackedMsg {
|
||||
c.dismissTrackedToolFeedbackMessage(ctx, msg.ChatID, trackedMsgID)
|
||||
}
|
||||
return []string{msgID}, nil
|
||||
return nil, nil
|
||||
}
|
||||
// If text also fails, return the text error
|
||||
return nil, textErr
|
||||
@@ -254,23 +210,6 @@ func (c *FeishuChannel) EditMessage(ctx context.Context, chatID, messageID, cont
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeleteMessage implements channels.MessageDeleter.
|
||||
func (c *FeishuChannel) DeleteMessage(ctx context.Context, chatID, messageID string) error {
|
||||
req := larkim.NewDeleteMessageReqBuilder().
|
||||
MessageId(messageID).
|
||||
Build()
|
||||
|
||||
resp, err := c.client.Im.V1.Message.Delete(ctx, req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("feishu delete: %w", err)
|
||||
}
|
||||
if !resp.Success() {
|
||||
c.invalidateTokenOnAuthError(resp.Code)
|
||||
return fmt.Errorf("feishu delete api error (code=%d msg=%s)", resp.Code, resp.Msg)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// SendPlaceholder implements channels.PlaceholderCapable.
|
||||
// Sends an interactive card with placeholder text and returns its message ID.
|
||||
func (c *FeishuChannel) SendPlaceholder(ctx context.Context, chatID string) (string, error) {
|
||||
@@ -312,81 +251,6 @@ func (c *FeishuChannel) SendPlaceholder(ctx context.Context, chatID string) (str
|
||||
return "", nil
|
||||
}
|
||||
|
||||
func outboundMessageIsToolFeedback(msg bus.OutboundMessage) bool {
|
||||
if len(msg.Context.Raw) == 0 {
|
||||
return false
|
||||
}
|
||||
return strings.EqualFold(strings.TrimSpace(msg.Context.Raw["message_kind"]), "tool_feedback")
|
||||
}
|
||||
|
||||
func (c *FeishuChannel) currentToolFeedbackMessage(chatID string) (string, bool) {
|
||||
if c.progress == nil {
|
||||
return "", false
|
||||
}
|
||||
return c.progress.Current(chatID)
|
||||
}
|
||||
|
||||
func (c *FeishuChannel) takeToolFeedbackMessage(chatID string) (string, string, bool) {
|
||||
if c.progress == nil {
|
||||
return "", "", false
|
||||
}
|
||||
return c.progress.Take(chatID)
|
||||
}
|
||||
|
||||
func (c *FeishuChannel) RecordToolFeedbackMessage(chatID, messageID, content string) {
|
||||
if c.progress == nil {
|
||||
return
|
||||
}
|
||||
c.progress.Record(chatID, messageID, content)
|
||||
}
|
||||
|
||||
func (c *FeishuChannel) ClearToolFeedbackMessage(chatID string) {
|
||||
if c.progress == nil {
|
||||
return
|
||||
}
|
||||
c.progress.Clear(chatID)
|
||||
}
|
||||
|
||||
func (c *FeishuChannel) DismissToolFeedbackMessage(ctx context.Context, chatID string) {
|
||||
msgID, ok := c.currentToolFeedbackMessage(chatID)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
c.dismissTrackedToolFeedbackMessage(ctx, chatID, msgID)
|
||||
}
|
||||
|
||||
func (c *FeishuChannel) dismissTrackedToolFeedbackMessage(ctx context.Context, chatID, messageID string) {
|
||||
if strings.TrimSpace(chatID) == "" || strings.TrimSpace(messageID) == "" {
|
||||
return
|
||||
}
|
||||
c.ClearToolFeedbackMessage(chatID)
|
||||
_ = c.DeleteMessage(ctx, chatID, messageID)
|
||||
}
|
||||
|
||||
func (c *FeishuChannel) finalizeTrackedToolFeedbackMessage(
|
||||
ctx context.Context,
|
||||
chatID string,
|
||||
content string,
|
||||
editFn func(context.Context, string, string, string) error,
|
||||
) ([]string, bool) {
|
||||
msgID, baseContent, ok := c.takeToolFeedbackMessage(chatID)
|
||||
if !ok || editFn == nil {
|
||||
return nil, false
|
||||
}
|
||||
if err := editFn(ctx, chatID, msgID, content); err != nil {
|
||||
c.RecordToolFeedbackMessage(chatID, msgID, baseContent)
|
||||
return nil, false
|
||||
}
|
||||
return []string{msgID}, true
|
||||
}
|
||||
|
||||
func (c *FeishuChannel) FinalizeToolFeedbackMessage(ctx context.Context, msg bus.OutboundMessage) ([]string, bool) {
|
||||
if outboundMessageIsToolFeedback(msg) {
|
||||
return nil, false
|
||||
}
|
||||
return c.finalizeTrackedToolFeedbackMessage(ctx, msg.ChatID, msg.Content, c.EditMessage)
|
||||
}
|
||||
|
||||
// ReactToMessage implements channels.ReactionCapable.
|
||||
// Adds a reaction (randomly chosen from config) and returns an undo function to remove it.
|
||||
func (c *FeishuChannel) ReactToMessage(ctx context.Context, chatID, messageID string) (func(), error) {
|
||||
@@ -459,7 +323,6 @@ func (c *FeishuChannel) SendMedia(ctx context.Context, msg bus.OutboundMediaMess
|
||||
if !c.IsRunning() {
|
||||
return nil, channels.ErrNotRunning
|
||||
}
|
||||
trackedMsgID, hasTrackedMsg := c.currentToolFeedbackMessage(msg.ChatID)
|
||||
|
||||
if msg.ChatID == "" {
|
||||
return nil, fmt.Errorf("chat ID is empty: %w", channels.ErrSendFailed)
|
||||
@@ -476,10 +339,6 @@ func (c *FeishuChannel) SendMedia(ctx context.Context, msg bus.OutboundMediaMess
|
||||
}
|
||||
}
|
||||
|
||||
if hasTrackedMsg {
|
||||
c.dismissTrackedToolFeedbackMessage(ctx, msg.ChatID, trackedMsgID)
|
||||
}
|
||||
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
@@ -942,7 +801,7 @@ func appendMediaTags(content, messageType string, mediaRefs []string) string {
|
||||
}
|
||||
|
||||
// sendCard sends an interactive card message to a chat.
|
||||
func (c *FeishuChannel) sendCard(ctx context.Context, chatID, cardContent string) (string, error) {
|
||||
func (c *FeishuChannel) sendCard(ctx context.Context, chatID, cardContent string) error {
|
||||
req := larkim.NewCreateMessageReqBuilder().
|
||||
ReceiveIdType(larkim.ReceiveIdTypeChatId).
|
||||
Body(larkim.NewCreateMessageReqBodyBuilder().
|
||||
@@ -954,26 +813,23 @@ func (c *FeishuChannel) sendCard(ctx context.Context, chatID, cardContent string
|
||||
|
||||
resp, err := c.client.Im.V1.Message.Create(ctx, req)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("feishu send card: %w", channels.ErrTemporary)
|
||||
return fmt.Errorf("feishu send card: %w", channels.ErrTemporary)
|
||||
}
|
||||
|
||||
if !resp.Success() {
|
||||
c.invalidateTokenOnAuthError(resp.Code)
|
||||
return "", fmt.Errorf("feishu api error (code=%d msg=%s): %w", resp.Code, resp.Msg, channels.ErrTemporary)
|
||||
return fmt.Errorf("feishu api error (code=%d msg=%s): %w", resp.Code, resp.Msg, channels.ErrTemporary)
|
||||
}
|
||||
|
||||
logger.DebugCF("feishu", "Feishu card message sent", map[string]any{
|
||||
"chat_id": chatID,
|
||||
})
|
||||
|
||||
if resp.Data != nil && resp.Data.MessageId != nil {
|
||||
return *resp.Data.MessageId, nil
|
||||
}
|
||||
return "", nil
|
||||
return nil
|
||||
}
|
||||
|
||||
// sendText sends a plain text message to a chat (fallback when card fails).
|
||||
func (c *FeishuChannel) sendText(ctx context.Context, chatID, text string) (string, error) {
|
||||
func (c *FeishuChannel) sendText(ctx context.Context, chatID, text string) error {
|
||||
content, _ := json.Marshal(map[string]string{"text": text})
|
||||
|
||||
req := larkim.NewCreateMessageReqBuilder().
|
||||
@@ -987,21 +843,18 @@ func (c *FeishuChannel) sendText(ctx context.Context, chatID, text string) (stri
|
||||
|
||||
resp, err := c.client.Im.V1.Message.Create(ctx, req)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("feishu send text: %w", channels.ErrTemporary)
|
||||
return fmt.Errorf("feishu send text: %w", channels.ErrTemporary)
|
||||
}
|
||||
|
||||
if !resp.Success() {
|
||||
return "", fmt.Errorf("feishu text api error (code=%d msg=%s): %w", resp.Code, resp.Msg, channels.ErrTemporary)
|
||||
return fmt.Errorf("feishu text api error (code=%d msg=%s): %w", resp.Code, resp.Msg, channels.ErrTemporary)
|
||||
}
|
||||
|
||||
logger.DebugCF("feishu", "Feishu text message sent (fallback)", map[string]any{
|
||||
"chat_id": chatID,
|
||||
})
|
||||
|
||||
if resp.Data != nil && resp.Data.MessageId != nil {
|
||||
return *resp.Data.MessageId, nil
|
||||
}
|
||||
return "", nil
|
||||
return nil
|
||||
}
|
||||
|
||||
// sendImage uploads an image and sends it as a message.
|
||||
|
||||
@@ -3,13 +3,9 @@
|
||||
package feishu
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
larkim "github.com/larksuite/oapi-sdk-go/v3/service/im/v1"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/channels"
|
||||
)
|
||||
|
||||
func TestExtractContent(t *testing.T) {
|
||||
@@ -283,84 +279,3 @@ func TestExtractFeishuSenderID(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestFinalizeTrackedToolFeedbackMessage_ClearAfterSuccessfulEdit(t *testing.T) {
|
||||
ch := &FeishuChannel{
|
||||
progress: channels.NewToolFeedbackAnimator(nil),
|
||||
}
|
||||
ch.RecordToolFeedbackMessage("chat-1", "msg-1", "🔧 `read_file`")
|
||||
|
||||
msgIDs, handled := ch.finalizeTrackedToolFeedbackMessage(
|
||||
context.Background(),
|
||||
"chat-1",
|
||||
"final reply",
|
||||
func(_ context.Context, chatID, messageID, content string) error {
|
||||
if chatID != "chat-1" || messageID != "msg-1" || content != "final reply" {
|
||||
t.Fatalf("unexpected edit args: %s %s %s", chatID, messageID, content)
|
||||
}
|
||||
return nil
|
||||
},
|
||||
)
|
||||
if !handled {
|
||||
t.Fatal("expected finalizeTrackedToolFeedbackMessage to handle tracked message")
|
||||
}
|
||||
if len(msgIDs) != 1 || msgIDs[0] != "msg-1" {
|
||||
t.Fatalf("unexpected msgIDs: %v", msgIDs)
|
||||
}
|
||||
if _, ok := ch.currentToolFeedbackMessage("chat-1"); ok {
|
||||
t.Fatal("expected tracked tool feedback to be cleared after successful edit")
|
||||
}
|
||||
}
|
||||
|
||||
func TestFinalizeTrackedToolFeedbackMessage_StopsTrackingBeforeEdit(t *testing.T) {
|
||||
ch := &FeishuChannel{
|
||||
progress: channels.NewToolFeedbackAnimator(nil),
|
||||
}
|
||||
ch.RecordToolFeedbackMessage("chat-1", "msg-1", "🔧 `read_file`")
|
||||
|
||||
msgIDs, handled := ch.finalizeTrackedToolFeedbackMessage(
|
||||
context.Background(),
|
||||
"chat-1",
|
||||
"final reply",
|
||||
func(_ context.Context, chatID, messageID, content string) error {
|
||||
if _, ok := ch.currentToolFeedbackMessage(chatID); ok {
|
||||
t.Fatal("expected tracked tool feedback to be stopped before edit")
|
||||
}
|
||||
if chatID != "chat-1" || messageID != "msg-1" || content != "final reply" {
|
||||
t.Fatalf("unexpected edit args: %s %s %s", chatID, messageID, content)
|
||||
}
|
||||
return nil
|
||||
},
|
||||
)
|
||||
if !handled {
|
||||
t.Fatal("expected finalizeTrackedToolFeedbackMessage to handle tracked message")
|
||||
}
|
||||
if len(msgIDs) != 1 || msgIDs[0] != "msg-1" {
|
||||
t.Fatalf("unexpected msgIDs: %v", msgIDs)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFinalizeTrackedToolFeedbackMessage_EditFailureKeepsTrackedMessage(t *testing.T) {
|
||||
ch := &FeishuChannel{
|
||||
progress: channels.NewToolFeedbackAnimator(nil),
|
||||
}
|
||||
ch.RecordToolFeedbackMessage("chat-1", "msg-1", "🔧 `read_file`")
|
||||
|
||||
msgIDs, handled := ch.finalizeTrackedToolFeedbackMessage(
|
||||
context.Background(),
|
||||
"chat-1",
|
||||
"final reply",
|
||||
func(context.Context, string, string, string) error {
|
||||
return errors.New("edit failed")
|
||||
},
|
||||
)
|
||||
if handled {
|
||||
t.Fatal("expected finalizeTrackedToolFeedbackMessage to report unhandled on edit failure")
|
||||
}
|
||||
if len(msgIDs) != 0 {
|
||||
t.Fatalf("unexpected msgIDs: %v", msgIDs)
|
||||
}
|
||||
if msgID, ok := ch.currentToolFeedbackMessage("chat-1"); !ok || msgID != "msg-1" {
|
||||
t.Fatalf("expected tracked tool feedback to remain after failed edit, got (%q, %v)", msgID, ok)
|
||||
}
|
||||
}
|
||||
|
||||
+18
-94
@@ -14,7 +14,6 @@ import (
|
||||
"net"
|
||||
"net/http"
|
||||
"sort"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
@@ -26,7 +25,6 @@ import (
|
||||
"github.com/sipeed/picoclaw/pkg/health"
|
||||
"github.com/sipeed/picoclaw/pkg/logger"
|
||||
"github.com/sipeed/picoclaw/pkg/media"
|
||||
"github.com/sipeed/picoclaw/pkg/utils"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -98,15 +96,6 @@ type Manager struct {
|
||||
channelHashes map[string]string // channel name → config hash
|
||||
}
|
||||
|
||||
type toolFeedbackMessageTracker interface {
|
||||
RecordToolFeedbackMessage(chatID, messageID, content string)
|
||||
ClearToolFeedbackMessage(chatID string)
|
||||
}
|
||||
|
||||
type toolFeedbackMessageCleaner interface {
|
||||
DismissToolFeedbackMessage(ctx context.Context, chatID string)
|
||||
}
|
||||
|
||||
type asyncTask struct {
|
||||
cancel context.CancelFunc
|
||||
}
|
||||
@@ -119,13 +108,6 @@ func outboundMessageChatID(msg bus.OutboundMessage) string {
|
||||
return msg.ChatID
|
||||
}
|
||||
|
||||
func outboundMessageIsToolFeedback(msg bus.OutboundMessage) bool {
|
||||
if len(msg.Context.Raw) == 0 {
|
||||
return false
|
||||
}
|
||||
return strings.EqualFold(strings.TrimSpace(msg.Context.Raw["message_kind"]), "tool_feedback")
|
||||
}
|
||||
|
||||
func outboundMediaChannel(msg bus.OutboundMediaMessage) string {
|
||||
return msg.Context.Channel
|
||||
}
|
||||
@@ -134,16 +116,6 @@ func outboundMediaChatID(msg bus.OutboundMediaMessage) string {
|
||||
return msg.ChatID
|
||||
}
|
||||
|
||||
func dismissTrackedToolFeedbackMessage(ctx context.Context, ch Channel, chatID string) {
|
||||
if cleaner, ok := ch.(toolFeedbackMessageCleaner); ok {
|
||||
cleaner.DismissToolFeedbackMessage(ctx, chatID)
|
||||
return
|
||||
}
|
||||
if tracker, ok := ch.(toolFeedbackMessageTracker); ok {
|
||||
tracker.ClearToolFeedbackMessage(chatID)
|
||||
}
|
||||
}
|
||||
|
||||
// RecordPlaceholder registers a placeholder message for later editing.
|
||||
// Implements PlaceholderRecorder.
|
||||
func (m *Manager) RecordPlaceholder(channel, chatID, placeholderID string) {
|
||||
@@ -224,19 +196,7 @@ func (m *Manager) preSend(ctx context.Context, name string, msg bus.OutboundMess
|
||||
}
|
||||
}
|
||||
|
||||
isToolFeedback := outboundMessageIsToolFeedback(msg)
|
||||
|
||||
// 3. If a stream already finalized this chat, stale tool feedback must be
|
||||
// dropped without consuming the final-response marker. Streaming finalization
|
||||
// bypasses the worker queue, so older queued feedback can arrive before the
|
||||
// normal final outbound message that cleans up the marker and placeholder.
|
||||
if isToolFeedback {
|
||||
if _, loaded := m.streamActive.Load(key); loaded {
|
||||
return nil, true
|
||||
}
|
||||
}
|
||||
|
||||
// 4. If a stream already finalized this message, delete the placeholder and skip send
|
||||
// 3. If a stream already finalized this message, delete the placeholder and skip send
|
||||
if _, loaded := m.streamActive.LoadAndDelete(key); loaded {
|
||||
if v, loaded := m.placeholders.LoadAndDelete(key); loaded {
|
||||
if entry, ok := v.(placeholderEntry); ok && entry.id != "" {
|
||||
@@ -248,26 +208,14 @@ func (m *Manager) preSend(ctx context.Context, name string, msg bus.OutboundMess
|
||||
}
|
||||
}
|
||||
}
|
||||
if !isToolFeedback {
|
||||
dismissTrackedToolFeedbackMessage(ctx, ch, chatID)
|
||||
}
|
||||
return nil, true
|
||||
}
|
||||
|
||||
// 5. Try editing placeholder
|
||||
// 4. Try editing placeholder
|
||||
if v, loaded := m.placeholders.LoadAndDelete(key); loaded {
|
||||
if entry, ok := v.(placeholderEntry); ok && entry.id != "" {
|
||||
if editor, ok := ch.(MessageEditor); ok {
|
||||
content := msg.Content
|
||||
if isToolFeedback {
|
||||
content = InitialAnimatedToolFeedbackContent(msg.Content)
|
||||
}
|
||||
if err := editor.EditMessage(ctx, chatID, entry.id, content); err == nil {
|
||||
if tracker, ok := ch.(toolFeedbackMessageTracker); ok && isToolFeedback {
|
||||
tracker.RecordToolFeedbackMessage(chatID, entry.id, msg.Content)
|
||||
} else if !isToolFeedback {
|
||||
dismissTrackedToolFeedbackMessage(ctx, ch, chatID)
|
||||
}
|
||||
if err := editor.EditMessage(ctx, chatID, entry.id, msg.Content); err == nil {
|
||||
return []string{entry.id}, true
|
||||
}
|
||||
// edit failed → fall through to normal Send
|
||||
@@ -364,27 +312,22 @@ func (m *Manager) GetStreamer(ctx context.Context, channelName, chatID string) (
|
||||
// Mark streamActive on Finalize so preSend knows to clean up the placeholder
|
||||
key := channelName + ":" + chatID
|
||||
return &finalizeHookStreamer{
|
||||
Streamer: streamer,
|
||||
onFinalize: func(finalizeCtx context.Context) {
|
||||
dismissTrackedToolFeedbackMessage(finalizeCtx, ch, chatID)
|
||||
m.streamActive.Store(key, true)
|
||||
},
|
||||
Streamer: streamer,
|
||||
onFinalize: func() { m.streamActive.Store(key, true) },
|
||||
}, true
|
||||
}
|
||||
|
||||
// finalizeHookStreamer wraps a Streamer to run a hook on Finalize.
|
||||
type finalizeHookStreamer struct {
|
||||
Streamer
|
||||
onFinalize func(context.Context)
|
||||
onFinalize func()
|
||||
}
|
||||
|
||||
func (s *finalizeHookStreamer) Finalize(ctx context.Context, content string) error {
|
||||
if err := s.Streamer.Finalize(ctx, content); err != nil {
|
||||
return err
|
||||
}
|
||||
if s.onFinalize != nil {
|
||||
s.onFinalize(ctx)
|
||||
}
|
||||
s.onFinalize()
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -826,21 +769,18 @@ func (m *Manager) runWorker(ctx context.Context, name string, w *channelWorker)
|
||||
// Collect all message chunks to send
|
||||
var chunks []string
|
||||
|
||||
// Step 1: Try marker-based splitting if enabled.
|
||||
// Tool feedback must stay a single message, so it skips marker splitting.
|
||||
if m.config != nil && m.config.Agents.Defaults.SplitOnMarker && !outboundMessageIsToolFeedback(msg) {
|
||||
// Step 1: Try marker-based splitting if enabled
|
||||
if m.config != nil && m.config.Agents.Defaults.SplitOnMarker {
|
||||
if markerChunks := SplitByMarker(msg.Content); len(markerChunks) > 1 {
|
||||
for _, chunk := range markerChunks {
|
||||
chunkMsg := msg
|
||||
chunkMsg.Content = chunk
|
||||
chunks = append(chunks, splitOutboundMessageContent(chunkMsg, maxLen)...)
|
||||
chunks = append(chunks, splitByLength(chunk, maxLen)...)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Step 2: Fallback to length-based splitting if no chunks from marker
|
||||
if len(chunks) == 0 {
|
||||
chunks = splitOutboundMessageContent(msg, maxLen)
|
||||
chunks = splitByLength(msg.Content, maxLen)
|
||||
}
|
||||
|
||||
// Step 3: Send all chunks
|
||||
@@ -855,25 +795,12 @@ func (m *Manager) runWorker(ctx context.Context, name string, w *channelWorker)
|
||||
}
|
||||
}
|
||||
|
||||
// splitOutboundMessageContent splits regular outbound content by maxLen, but
|
||||
// keeps tool feedback in a single message by truncating the explanation body.
|
||||
func splitOutboundMessageContent(msg bus.OutboundMessage, maxLen int) []string {
|
||||
if maxLen > 0 {
|
||||
if outboundMessageIsToolFeedback(msg) {
|
||||
animationSafeLen := maxLen - MaxToolFeedbackAnimationFrameLength()
|
||||
if animationSafeLen <= 0 {
|
||||
animationSafeLen = maxLen
|
||||
}
|
||||
if len([]rune(msg.Content)) > animationSafeLen {
|
||||
return []string{utils.FitToolFeedbackMessage(msg.Content, animationSafeLen)}
|
||||
}
|
||||
return []string{msg.Content}
|
||||
}
|
||||
if len([]rune(msg.Content)) > maxLen {
|
||||
return SplitMessage(msg.Content, maxLen)
|
||||
}
|
||||
// splitByLength splits content by maxLen if needed, otherwise returns single chunk.
|
||||
func splitByLength(content string, maxLen int) []string {
|
||||
if maxLen > 0 && len([]rune(content)) > maxLen {
|
||||
return SplitMessage(content, maxLen)
|
||||
}
|
||||
return []string{msg.Content}
|
||||
return []string{content}
|
||||
}
|
||||
|
||||
// sendWithRetry sends a message through the channel with rate limiting and
|
||||
@@ -1337,16 +1264,13 @@ func (m *Manager) SendMessage(ctx context.Context, msg bus.OutboundMessage) erro
|
||||
if mlp, ok := w.ch.(MessageLengthProvider); ok {
|
||||
maxLen = mlp.MaxMessageLength()
|
||||
}
|
||||
if chunks := splitOutboundMessageContent(msg, maxLen); len(chunks) > 1 {
|
||||
for _, chunk := range chunks {
|
||||
if maxLen > 0 && len([]rune(msg.Content)) > maxLen {
|
||||
for _, chunk := range SplitMessage(msg.Content, maxLen) {
|
||||
chunkMsg := msg
|
||||
chunkMsg.Content = chunk
|
||||
m.sendWithRetry(ctx, channelName, w, chunkMsg)
|
||||
}
|
||||
} else {
|
||||
if len(chunks) == 1 {
|
||||
msg.Content = chunks[0]
|
||||
}
|
||||
m.sendWithRetry(ctx, channelName, w, msg)
|
||||
}
|
||||
return nil
|
||||
|
||||
@@ -13,8 +13,6 @@ import (
|
||||
"golang.org/x/time/rate"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/bus"
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
"github.com/sipeed/picoclaw/pkg/utils"
|
||||
)
|
||||
|
||||
// mockChannel is a test double that delegates Send to a configurable function.
|
||||
@@ -78,9 +76,8 @@ func (m *mockMediaChannel) SendMedia(ctx context.Context, msg bus.OutboundMediaM
|
||||
|
||||
type mockDeletingMediaChannel struct {
|
||||
mockMediaChannel
|
||||
deleteCalls int
|
||||
dismissedChatID string
|
||||
lastDeleted struct {
|
||||
deleteCalls int
|
||||
lastDeleted struct {
|
||||
chatID string
|
||||
messageID string
|
||||
}
|
||||
@@ -97,37 +94,6 @@ func (m *mockDeletingMediaChannel) DeleteMessage(
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockDeletingMediaChannel) DismissToolFeedbackMessage(_ context.Context, chatID string) {
|
||||
m.dismissedChatID = chatID
|
||||
}
|
||||
|
||||
type mockStreamer struct {
|
||||
finalizeFn func(context.Context, string) error
|
||||
}
|
||||
|
||||
func (m *mockStreamer) Update(context.Context, string) error { return nil }
|
||||
|
||||
func (m *mockStreamer) Finalize(ctx context.Context, content string) error {
|
||||
if m.finalizeFn != nil {
|
||||
return m.finalizeFn(ctx, content)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockStreamer) Cancel(context.Context) {}
|
||||
|
||||
type mockStreamingChannel struct {
|
||||
mockMessageEditor
|
||||
streamer Streamer
|
||||
}
|
||||
|
||||
func (m *mockStreamingChannel) BeginStream(context.Context, string) (Streamer, error) {
|
||||
if m.streamer == nil {
|
||||
return nil, errors.New("missing streamer")
|
||||
}
|
||||
return m.streamer, nil
|
||||
}
|
||||
|
||||
// newTestManager creates a minimal Manager suitable for unit tests.
|
||||
func newTestManager() *Manager {
|
||||
return &Manager{
|
||||
@@ -749,43 +715,13 @@ func TestSendWithRetry_ExponentialBackoff(t *testing.T) {
|
||||
// mockMessageEditor is a channel that supports MessageEditor.
|
||||
type mockMessageEditor struct {
|
||||
mockChannel
|
||||
editFn func(ctx context.Context, chatID, messageID, content string) error
|
||||
finalizeFn func(ctx context.Context, msg bus.OutboundMessage) ([]string, bool)
|
||||
finalizeCalled bool
|
||||
recordedChatID string
|
||||
recordedMessageID string
|
||||
clearedChatID string
|
||||
dismissedChatID string
|
||||
editFn func(ctx context.Context, chatID, messageID, content string) error
|
||||
}
|
||||
|
||||
func (m *mockMessageEditor) EditMessage(ctx context.Context, chatID, messageID, content string) error {
|
||||
return m.editFn(ctx, chatID, messageID, content)
|
||||
}
|
||||
|
||||
func (m *mockMessageEditor) RecordToolFeedbackMessage(chatID, messageID, _ string) {
|
||||
m.recordedChatID = chatID
|
||||
m.recordedMessageID = messageID
|
||||
}
|
||||
|
||||
func (m *mockMessageEditor) ClearToolFeedbackMessage(chatID string) {
|
||||
m.clearedChatID = chatID
|
||||
}
|
||||
|
||||
func (m *mockMessageEditor) DismissToolFeedbackMessage(_ context.Context, chatID string) {
|
||||
m.dismissedChatID = chatID
|
||||
}
|
||||
|
||||
func (m *mockMessageEditor) FinalizeToolFeedbackMessage(
|
||||
ctx context.Context,
|
||||
msg bus.OutboundMessage,
|
||||
) ([]string, bool) {
|
||||
m.finalizeCalled = true
|
||||
if m.finalizeFn == nil {
|
||||
return nil, false
|
||||
}
|
||||
return m.finalizeFn(ctx, msg)
|
||||
}
|
||||
|
||||
func TestPreSend_PlaceholderEditSuccess(t *testing.T) {
|
||||
m := newTestManager()
|
||||
var sendCalled bool
|
||||
@@ -830,360 +766,6 @@ func TestPreSend_PlaceholderEditSuccess(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestPreSend_ToolFeedbackPlaceholderEditRecordsTrackedMessage(t *testing.T) {
|
||||
m := newTestManager()
|
||||
|
||||
ch := &mockMessageEditor{
|
||||
editFn: func(_ context.Context, chatID, messageID, content string) error {
|
||||
if chatID != "123" || messageID != "456" || content != "hello" {
|
||||
t.Fatalf("unexpected edit args: %s %s %s", chatID, messageID, content)
|
||||
}
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
m.RecordPlaceholder("test", "123", "456")
|
||||
|
||||
msg := testOutboundMessage(bus.OutboundMessage{
|
||||
Channel: "test",
|
||||
ChatID: "123",
|
||||
Content: "hello",
|
||||
Context: bus.InboundContext{
|
||||
Channel: "test",
|
||||
ChatID: "123",
|
||||
Raw: map[string]string{
|
||||
"message_kind": "tool_feedback",
|
||||
},
|
||||
},
|
||||
})
|
||||
_, edited := m.preSend(context.Background(), "test", msg, ch)
|
||||
if !edited {
|
||||
t.Fatal("expected preSend to edit placeholder")
|
||||
}
|
||||
if ch.recordedChatID != "123" || ch.recordedMessageID != "456" {
|
||||
t.Fatalf("expected tracked message 123/456, got %q/%q", ch.recordedChatID, ch.recordedMessageID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPreSend_NonToolFeedbackLeavesTrackedMessageForChannelSend(t *testing.T) {
|
||||
m := newTestManager()
|
||||
ch := &mockMessageEditor{}
|
||||
|
||||
msg := testOutboundMessage(bus.OutboundMessage{
|
||||
Channel: "test",
|
||||
ChatID: "123",
|
||||
Content: "final reply",
|
||||
Context: bus.InboundContext{
|
||||
Channel: "test",
|
||||
ChatID: "123",
|
||||
},
|
||||
})
|
||||
|
||||
_, edited := m.preSend(context.Background(), "test", msg, ch)
|
||||
if edited {
|
||||
t.Fatal("expected preSend to fall through when no placeholder exists")
|
||||
}
|
||||
if ch.dismissedChatID != "" {
|
||||
t.Fatalf("expected tracked tool feedback cleanup to be deferred to channel send, got %q", ch.dismissedChatID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPreSend_NonToolFeedbackDefersTrackedMessageFinalizationToChannelSend(t *testing.T) {
|
||||
m := newTestManager()
|
||||
ch := &mockMessageEditor{
|
||||
finalizeFn: func(_ context.Context, msg bus.OutboundMessage) ([]string, bool) {
|
||||
if msg.ChatID != "123" || msg.Content != "final reply" {
|
||||
t.Fatalf("unexpected finalize msg: %+v", msg)
|
||||
}
|
||||
return []string{"tool-msg-1"}, true
|
||||
},
|
||||
}
|
||||
|
||||
msg := testOutboundMessage(bus.OutboundMessage{
|
||||
Channel: "test",
|
||||
ChatID: "123",
|
||||
Content: "final reply",
|
||||
Context: bus.InboundContext{
|
||||
Channel: "test",
|
||||
ChatID: "123",
|
||||
},
|
||||
})
|
||||
|
||||
msgIDs, handled := m.preSend(context.Background(), "test", msg, ch)
|
||||
if handled {
|
||||
t.Fatalf("expected preSend to defer to channel Send, got msgIDs=%v", msgIDs)
|
||||
}
|
||||
if len(msgIDs) != 0 {
|
||||
t.Fatalf("expected no msgIDs from preSend, got %v", msgIDs)
|
||||
}
|
||||
if ch.dismissedChatID != "" {
|
||||
t.Fatalf("expected tracked cleanup to remain in channel Send, got %q", ch.dismissedChatID)
|
||||
}
|
||||
if ch.finalizeCalled {
|
||||
t.Fatal("expected preSend to skip channel tool feedback finalization")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPreSend_StaleToolFeedbackDoesNotConsumeStreamActiveMarker(t *testing.T) {
|
||||
m := newTestManager()
|
||||
m.streamActive.Store("test:123", true)
|
||||
m.RecordPlaceholder("test", "123", "placeholder-1")
|
||||
|
||||
var editedContent string
|
||||
ch := &mockMessageEditor{
|
||||
editFn: func(_ context.Context, chatID, messageID, content string) error {
|
||||
if chatID != "123" || messageID != "placeholder-1" {
|
||||
t.Fatalf("unexpected edit target: %s/%s", chatID, messageID)
|
||||
}
|
||||
editedContent = content
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
toolFeedback := testOutboundMessage(bus.OutboundMessage{
|
||||
Channel: "test",
|
||||
ChatID: "123",
|
||||
Content: "🔧 `read_file`\nReading config",
|
||||
Context: bus.InboundContext{
|
||||
Channel: "test",
|
||||
ChatID: "123",
|
||||
Raw: map[string]string{
|
||||
"message_kind": "tool_feedback",
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
msgIDs, handled := m.preSend(context.Background(), "test", toolFeedback, ch)
|
||||
if !handled {
|
||||
t.Fatal("expected stale tool feedback to be dropped after stream finalize")
|
||||
}
|
||||
if len(msgIDs) != 0 {
|
||||
t.Fatalf("expected no delivered message IDs for stale feedback, got %v", msgIDs)
|
||||
}
|
||||
if _, ok := m.streamActive.Load("test:123"); !ok {
|
||||
t.Fatal("expected streamActive marker to remain for the final outbound message")
|
||||
}
|
||||
if _, ok := m.placeholders.Load("test:123"); !ok {
|
||||
t.Fatal("expected placeholder cleanup to remain deferred to the final outbound message")
|
||||
}
|
||||
if ch.editedMessages != 0 {
|
||||
t.Fatalf("expected no placeholder edit for stale feedback, got %d edits", ch.editedMessages)
|
||||
}
|
||||
|
||||
finalMsg := testOutboundMessage(bus.OutboundMessage{
|
||||
Channel: "test",
|
||||
ChatID: "123",
|
||||
Content: "final streamed reply",
|
||||
Context: bus.InboundContext{
|
||||
Channel: "test",
|
||||
ChatID: "123",
|
||||
},
|
||||
})
|
||||
|
||||
_, handled = m.preSend(context.Background(), "test", finalMsg, ch)
|
||||
if !handled {
|
||||
t.Fatal("expected final outbound message to consume streamActive marker")
|
||||
}
|
||||
if _, ok := m.streamActive.Load("test:123"); ok {
|
||||
t.Fatal("expected streamActive marker to be cleared by final outbound message")
|
||||
}
|
||||
if _, ok := m.placeholders.Load("test:123"); ok {
|
||||
t.Fatal("expected placeholder to be cleaned up by final outbound message")
|
||||
}
|
||||
if editedContent != "final streamed reply" {
|
||||
t.Fatalf("editedContent = %q, want final streamed reply", editedContent)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPreSendMedia_LeavesTrackedMessageForChannelSend(t *testing.T) {
|
||||
m := newTestManager()
|
||||
ch := &mockDeletingMediaChannel{}
|
||||
|
||||
m.preSendMedia(context.Background(), "test", bus.OutboundMediaMessage{
|
||||
ChatID: "123",
|
||||
Context: bus.InboundContext{
|
||||
Channel: "test",
|
||||
ChatID: "123",
|
||||
},
|
||||
}, ch)
|
||||
|
||||
if ch.dismissedChatID != "" {
|
||||
t.Fatalf(
|
||||
"expected tracked tool feedback cleanup to be deferred to channel media send, got %q",
|
||||
ch.dismissedChatID,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSplitOutboundMessageContent_ToolFeedbackTruncatesInsteadOfSplitting(t *testing.T) {
|
||||
msg := testOutboundMessage(bus.OutboundMessage{
|
||||
Channel: "test",
|
||||
ChatID: "123",
|
||||
Content: "\U0001f527 `read_file`\nRead README.md first to confirm the current project structure before editing the config example.",
|
||||
Context: bus.InboundContext{
|
||||
Channel: "test",
|
||||
ChatID: "123",
|
||||
Raw: map[string]string{
|
||||
"message_kind": "tool_feedback",
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
chunks := splitOutboundMessageContent(msg, 40)
|
||||
if len(chunks) != 1 {
|
||||
t.Fatalf("len(chunks) = %d, want 1", len(chunks))
|
||||
}
|
||||
want := utils.FitToolFeedbackMessage(msg.Content, 40-MaxToolFeedbackAnimationFrameLength())
|
||||
if chunks[0] != want {
|
||||
t.Fatalf("chunk = %q, want %q", chunks[0], want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSplitOutboundMessageContent_ToolFeedbackReservesAnimationFrame(t *testing.T) {
|
||||
msg := testOutboundMessage(bus.OutboundMessage{
|
||||
Channel: "test",
|
||||
ChatID: "123",
|
||||
Content: "🔧 `read_file`\n1234567890",
|
||||
Context: bus.InboundContext{
|
||||
Channel: "test",
|
||||
ChatID: "123",
|
||||
Raw: map[string]string{
|
||||
"message_kind": "tool_feedback",
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
chunks := splitOutboundMessageContent(msg, len([]rune(msg.Content)))
|
||||
if len(chunks) != 1 {
|
||||
t.Fatalf("len(chunks) = %d, want 1", len(chunks))
|
||||
}
|
||||
|
||||
animated := formatAnimatedToolFeedbackContent(chunks[0], strings.Repeat(".", MaxToolFeedbackAnimationFrameLength()))
|
||||
if got, maxLen := len([]rune(animated)), len([]rune(msg.Content)); got > maxLen {
|
||||
t.Fatalf("animated len = %d, want <= %d; content=%q", got, maxLen, animated)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetStreamer_FinalizeDismissesTrackedToolFeedback(t *testing.T) {
|
||||
m := newTestManager()
|
||||
ch := &mockStreamingChannel{
|
||||
mockMessageEditor: mockMessageEditor{},
|
||||
streamer: &mockStreamer{
|
||||
finalizeFn: func(_ context.Context, content string) error {
|
||||
if content != "final reply" {
|
||||
t.Fatalf("unexpected finalize content: %q", content)
|
||||
}
|
||||
return nil
|
||||
},
|
||||
},
|
||||
}
|
||||
m.channels["test"] = ch
|
||||
|
||||
streamer, ok := m.GetStreamer(context.Background(), "test", "123")
|
||||
if !ok {
|
||||
t.Fatal("expected streamer to be available")
|
||||
}
|
||||
if err := streamer.Finalize(context.Background(), "final reply"); err != nil {
|
||||
t.Fatalf("Finalize() error = %v", err)
|
||||
}
|
||||
if ch.dismissedChatID != "123" {
|
||||
t.Fatalf("expected tracked tool feedback to be dismissed for chat 123, got %q", ch.dismissedChatID)
|
||||
}
|
||||
if _, ok := m.streamActive.Load("test:123"); !ok {
|
||||
t.Fatal("expected streamActive marker to be recorded after finalize")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetStreamer_FinalizeFailureDoesNotDismissTrackedToolFeedback(t *testing.T) {
|
||||
m := newTestManager()
|
||||
ch := &mockStreamingChannel{
|
||||
mockMessageEditor: mockMessageEditor{},
|
||||
streamer: &mockStreamer{
|
||||
finalizeFn: func(context.Context, string) error {
|
||||
return errors.New("finalize failed")
|
||||
},
|
||||
},
|
||||
}
|
||||
m.channels["test"] = ch
|
||||
|
||||
streamer, ok := m.GetStreamer(context.Background(), "test", "123")
|
||||
if !ok {
|
||||
t.Fatal("expected streamer to be available")
|
||||
}
|
||||
if err := streamer.Finalize(context.Background(), "final reply"); err == nil {
|
||||
t.Fatal("expected Finalize() to fail")
|
||||
}
|
||||
if ch.dismissedChatID != "" {
|
||||
t.Fatalf("expected no tool feedback dismissal on finalize failure, got %q", ch.dismissedChatID)
|
||||
}
|
||||
if _, ok := m.streamActive.Load("test:123"); ok {
|
||||
t.Fatal("expected no streamActive marker after finalize failure")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRunWorker_ToolFeedbackSkipsMarkerSplitting(t *testing.T) {
|
||||
m := newTestManager()
|
||||
m.config = &config.Config{
|
||||
Agents: config.AgentsConfig{
|
||||
Defaults: config.AgentDefaults{
|
||||
SplitOnMarker: true,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
var (
|
||||
mu sync.Mutex
|
||||
received []string
|
||||
)
|
||||
ch := &mockChannelWithLength{
|
||||
mockChannel: mockChannel{
|
||||
sendFn: func(_ context.Context, msg bus.OutboundMessage) error {
|
||||
mu.Lock()
|
||||
received = append(received, msg.Content)
|
||||
mu.Unlock()
|
||||
return nil
|
||||
},
|
||||
},
|
||||
maxLen: 200,
|
||||
}
|
||||
|
||||
w := &channelWorker{
|
||||
ch: ch,
|
||||
queue: make(chan bus.OutboundMessage, 1),
|
||||
done: make(chan struct{}),
|
||||
limiter: rate.NewLimiter(rate.Inf, 1),
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
go m.runWorker(ctx, "test", w)
|
||||
|
||||
content := "🔧 `read_file`\nRead current config first.<|[SPLIT]|>Then update the example."
|
||||
w.queue <- testOutboundMessage(bus.OutboundMessage{
|
||||
Channel: "test",
|
||||
ChatID: "123",
|
||||
Content: content,
|
||||
Context: bus.InboundContext{
|
||||
Channel: "test",
|
||||
ChatID: "123",
|
||||
Raw: map[string]string{
|
||||
"message_kind": "tool_feedback",
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
if len(received) != 1 {
|
||||
t.Fatalf("len(received) = %d, want 1", len(received))
|
||||
}
|
||||
if received[0] != content {
|
||||
t.Fatalf("received[0] = %q, want %q", received[0], content)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPreSend_PlaceholderEditFails_FallsThrough(t *testing.T) {
|
||||
m := newTestManager()
|
||||
|
||||
|
||||
@@ -46,13 +46,6 @@ const (
|
||||
|
||||
var matrixMentionHrefRegexp = regexp.MustCompile(`(?i)<a[^>]+href=["']([^"']+)["']`)
|
||||
|
||||
func outboundMessageIsToolFeedback(msg bus.OutboundMessage) bool {
|
||||
if len(msg.Context.Raw) == 0 {
|
||||
return false
|
||||
}
|
||||
return strings.EqualFold(strings.TrimSpace(msg.Context.Raw["message_kind"]), "tool_feedback")
|
||||
}
|
||||
|
||||
type roomKindCacheEntry struct {
|
||||
isGroup bool
|
||||
expiresAt time.Time
|
||||
@@ -199,7 +192,6 @@ type MatrixChannel struct {
|
||||
|
||||
cryptoHelper *cryptohelper.CryptoHelper
|
||||
cryptoDbPath string
|
||||
progress *channels.ToolFeedbackAnimator
|
||||
}
|
||||
|
||||
func NewMatrixChannel(
|
||||
@@ -244,7 +236,7 @@ func NewMatrixChannel(
|
||||
channels.WithReasoningChannelID(bc.ReasoningChannelID),
|
||||
)
|
||||
|
||||
ch := &MatrixChannel{
|
||||
return &MatrixChannel{
|
||||
BaseChannel: base,
|
||||
bc: bc,
|
||||
client: client,
|
||||
@@ -256,9 +248,7 @@ func NewMatrixChannel(
|
||||
localpartMentionR: localpartMentionRegexp(matrixLocalpart(client.UserID)),
|
||||
typingMu: sync.Mutex{},
|
||||
cryptoDbPath: cryptoDatabasePath,
|
||||
}
|
||||
ch.progress = channels.NewToolFeedbackAnimator(ch.EditMessage)
|
||||
return ch, nil
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (c *MatrixChannel) Start(ctx context.Context) error {
|
||||
@@ -307,9 +297,6 @@ func (c *MatrixChannel) Stop(ctx context.Context) error {
|
||||
c.cancel()
|
||||
}
|
||||
c.stopTypingSessions(ctx)
|
||||
if c.progress != nil {
|
||||
c.progress.StopAll()
|
||||
}
|
||||
|
||||
// Close crypto helper if initialized
|
||||
if c.cryptoHelper != nil {
|
||||
@@ -411,36 +398,11 @@ func (c *MatrixChannel) Send(ctx context.Context, msg bus.OutboundMessage) ([]st
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
isToolFeedback := outboundMessageIsToolFeedback(msg)
|
||||
if isToolFeedback {
|
||||
if msgID, handled, err := c.progress.Update(ctx, msg.ChatID, content); handled {
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return []string{msgID}, nil
|
||||
}
|
||||
}
|
||||
trackedMsgID, hasTrackedMsg := c.currentToolFeedbackMessage(msg.ChatID)
|
||||
if !isToolFeedback {
|
||||
if msgIDs, handled := c.FinalizeToolFeedbackMessage(ctx, msg); handled {
|
||||
return msgIDs, nil
|
||||
}
|
||||
}
|
||||
if isToolFeedback {
|
||||
content = channels.InitialAnimatedToolFeedbackContent(content)
|
||||
}
|
||||
|
||||
resp, err := c.client.SendMessageEvent(ctx, roomID, event.EventMessage, c.messageContent(content))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("matrix send: %w", channels.ErrTemporary)
|
||||
}
|
||||
msgID := resp.EventID.String()
|
||||
if isToolFeedback {
|
||||
c.RecordToolFeedbackMessage(msg.ChatID, msgID, msg.Content)
|
||||
} else if hasTrackedMsg {
|
||||
c.dismissTrackedToolFeedbackMessage(ctx, msg.ChatID, trackedMsgID)
|
||||
}
|
||||
return []string{msgID}, nil
|
||||
return []string{resp.EventID.String()}, nil
|
||||
}
|
||||
|
||||
func (c *MatrixChannel) messageContent(text string) *event.MessageEventContent {
|
||||
@@ -457,8 +419,6 @@ func (c *MatrixChannel) SendMedia(ctx context.Context, msg bus.OutboundMediaMess
|
||||
if !c.IsRunning() {
|
||||
return nil, channels.ErrNotRunning
|
||||
}
|
||||
trackedMsgID, hasTrackedMsg := c.currentToolFeedbackMessage(msg.ChatID)
|
||||
|
||||
sendCtx := ctx
|
||||
if sendCtx == nil {
|
||||
sendCtx = context.Background()
|
||||
@@ -569,10 +529,6 @@ func (c *MatrixChannel) SendMedia(ctx context.Context, msg bus.OutboundMediaMess
|
||||
}
|
||||
}
|
||||
|
||||
if hasTrackedMsg {
|
||||
c.dismissTrackedToolFeedbackMessage(ctx, msg.ChatID, trackedMsgID)
|
||||
}
|
||||
|
||||
return eventIDs, nil
|
||||
}
|
||||
|
||||
@@ -656,89 +612,6 @@ func (c *MatrixChannel) EditMessage(ctx context.Context, chatID string, messageI
|
||||
return err
|
||||
}
|
||||
|
||||
// DeleteMessage implements channels.MessageDeleter.
|
||||
func (c *MatrixChannel) DeleteMessage(ctx context.Context, chatID string, messageID string) error {
|
||||
roomID := id.RoomID(strings.TrimSpace(chatID))
|
||||
if roomID == "" {
|
||||
return fmt.Errorf("matrix room ID is empty")
|
||||
}
|
||||
eventID := id.EventID(strings.TrimSpace(messageID))
|
||||
if eventID == "" {
|
||||
return fmt.Errorf("matrix message ID is empty")
|
||||
}
|
||||
|
||||
_, err := c.client.RedactEvent(ctx, roomID, eventID)
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *MatrixChannel) currentToolFeedbackMessage(chatID string) (string, bool) {
|
||||
if c.progress == nil {
|
||||
return "", false
|
||||
}
|
||||
return c.progress.Current(chatID)
|
||||
}
|
||||
|
||||
func (c *MatrixChannel) takeToolFeedbackMessage(chatID string) (string, string, bool) {
|
||||
if c.progress == nil {
|
||||
return "", "", false
|
||||
}
|
||||
return c.progress.Take(chatID)
|
||||
}
|
||||
|
||||
func (c *MatrixChannel) RecordToolFeedbackMessage(chatID, messageID, content string) {
|
||||
if c.progress == nil {
|
||||
return
|
||||
}
|
||||
c.progress.Record(chatID, messageID, content)
|
||||
}
|
||||
|
||||
func (c *MatrixChannel) ClearToolFeedbackMessage(chatID string) {
|
||||
if c.progress == nil {
|
||||
return
|
||||
}
|
||||
c.progress.Clear(chatID)
|
||||
}
|
||||
|
||||
func (c *MatrixChannel) DismissToolFeedbackMessage(ctx context.Context, chatID string) {
|
||||
msgID, ok := c.currentToolFeedbackMessage(chatID)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
c.dismissTrackedToolFeedbackMessage(ctx, chatID, msgID)
|
||||
}
|
||||
|
||||
func (c *MatrixChannel) dismissTrackedToolFeedbackMessage(ctx context.Context, chatID, messageID string) {
|
||||
if strings.TrimSpace(chatID) == "" || strings.TrimSpace(messageID) == "" {
|
||||
return
|
||||
}
|
||||
c.ClearToolFeedbackMessage(chatID)
|
||||
_ = c.DeleteMessage(ctx, chatID, messageID)
|
||||
}
|
||||
|
||||
func (c *MatrixChannel) finalizeTrackedToolFeedbackMessage(
|
||||
ctx context.Context,
|
||||
chatID string,
|
||||
content string,
|
||||
editFn func(context.Context, string, string, string) error,
|
||||
) ([]string, bool) {
|
||||
msgID, baseContent, ok := c.takeToolFeedbackMessage(chatID)
|
||||
if !ok || editFn == nil {
|
||||
return nil, false
|
||||
}
|
||||
if err := editFn(ctx, chatID, msgID, content); err != nil {
|
||||
c.RecordToolFeedbackMessage(chatID, msgID, baseContent)
|
||||
return nil, false
|
||||
}
|
||||
return []string{msgID}, true
|
||||
}
|
||||
|
||||
func (c *MatrixChannel) FinalizeToolFeedbackMessage(ctx context.Context, msg bus.OutboundMessage) ([]string, bool) {
|
||||
if outboundMessageIsToolFeedback(msg) {
|
||||
return nil, false
|
||||
}
|
||||
return c.finalizeTrackedToolFeedbackMessage(ctx, msg.ChatID, msg.Content, c.EditMessage)
|
||||
}
|
||||
|
||||
func (c *MatrixChannel) handleMemberEvent(ctx context.Context, evt *event.Event) {
|
||||
if !c.config.JoinOnInvite {
|
||||
return
|
||||
|
||||
@@ -14,7 +14,6 @@ import (
|
||||
"maunium.net/go/mautrix/event"
|
||||
"maunium.net/go/mautrix/id"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/channels"
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
"github.com/sipeed/picoclaw/pkg/media"
|
||||
)
|
||||
@@ -42,34 +41,6 @@ func TestMatrixLocalpartMentionRegexp(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestFinalizeTrackedToolFeedbackMessage_StopsTrackingBeforeEdit(t *testing.T) {
|
||||
ch := &MatrixChannel{
|
||||
progress: channels.NewToolFeedbackAnimator(nil),
|
||||
}
|
||||
ch.RecordToolFeedbackMessage("!room:matrix.org", "$event1", "🔧 `read_file`")
|
||||
|
||||
msgIDs, handled := ch.finalizeTrackedToolFeedbackMessage(
|
||||
context.Background(),
|
||||
"!room:matrix.org",
|
||||
"final reply",
|
||||
func(_ context.Context, chatID, messageID, content string) error {
|
||||
if _, ok := ch.currentToolFeedbackMessage(chatID); ok {
|
||||
t.Fatal("expected tracked tool feedback to be stopped before edit")
|
||||
}
|
||||
if chatID != "!room:matrix.org" || messageID != "$event1" || content != "final reply" {
|
||||
t.Fatalf("unexpected edit args: %s %s %s", chatID, messageID, content)
|
||||
}
|
||||
return nil
|
||||
},
|
||||
)
|
||||
if !handled {
|
||||
t.Fatal("expected finalizeTrackedToolFeedbackMessage to handle tracked message")
|
||||
}
|
||||
if len(msgIDs) != 1 || msgIDs[0] != "$event1" {
|
||||
t.Fatalf("finalizeTrackedToolFeedbackMessage() ids = %v, want [$event1]", msgIDs)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStripUserMention(t *testing.T) {
|
||||
userID := id.UserID("@picoclaw:matrix.org")
|
||||
|
||||
|
||||
+4
-114
@@ -46,13 +46,6 @@ func outboundMessageIsThought(msg bus.OutboundMessage) bool {
|
||||
return strings.EqualFold(strings.TrimSpace(msg.Context.Raw["message_kind"]), MessageKindThought)
|
||||
}
|
||||
|
||||
func outboundMessageIsToolFeedback(msg bus.OutboundMessage) bool {
|
||||
if len(msg.Context.Raw) == 0 {
|
||||
return false
|
||||
}
|
||||
return strings.EqualFold(strings.TrimSpace(msg.Context.Raw["message_kind"]), "tool_feedback")
|
||||
}
|
||||
|
||||
// writeJSON sends a JSON message to the connection with write locking.
|
||||
func (pc *picoConn) writeJSON(v any) error {
|
||||
if pc.closed.Load() {
|
||||
@@ -85,7 +78,6 @@ type PicoChannel struct {
|
||||
connsMu sync.RWMutex
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
progress *channels.ToolFeedbackAnimator
|
||||
}
|
||||
|
||||
// NewPicoChannel creates a new Pico Protocol channel.
|
||||
@@ -114,7 +106,7 @@ func NewPicoChannel(
|
||||
return false
|
||||
}
|
||||
|
||||
ch := &PicoChannel{
|
||||
return &PicoChannel{
|
||||
BaseChannel: base,
|
||||
bc: bc,
|
||||
config: cfg,
|
||||
@@ -125,9 +117,7 @@ func NewPicoChannel(
|
||||
},
|
||||
connections: make(map[string]*picoConn),
|
||||
sessionConnections: make(map[string]map[string]*picoConn),
|
||||
}
|
||||
ch.progress = channels.NewToolFeedbackAnimator(ch.EditMessage)
|
||||
return ch, nil
|
||||
}, nil
|
||||
}
|
||||
|
||||
// createAndAddConnection checks MaxConnections and registers a connection atomically.
|
||||
@@ -245,9 +235,6 @@ func (c *PicoChannel) Stop(ctx context.Context) error {
|
||||
if c.cancel != nil {
|
||||
c.cancel()
|
||||
}
|
||||
if c.progress != nil {
|
||||
c.progress.StopAll()
|
||||
}
|
||||
|
||||
logger.InfoC("pico", "Pico Protocol channel stopped")
|
||||
return nil
|
||||
@@ -274,43 +261,13 @@ func (c *PicoChannel) Send(ctx context.Context, msg bus.OutboundMessage) ([]stri
|
||||
return nil, channels.ErrNotRunning
|
||||
}
|
||||
isThought := outboundMessageIsThought(msg)
|
||||
isToolFeedback := outboundMessageIsToolFeedback(msg)
|
||||
if isToolFeedback {
|
||||
if msgID, handled, err := c.progress.Update(ctx, msg.ChatID, msg.Content); handled {
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return []string{msgID}, nil
|
||||
}
|
||||
}
|
||||
trackedMsgID, hasTrackedMsg := c.currentToolFeedbackMessage(msg.ChatID)
|
||||
if !isToolFeedback {
|
||||
if msgIDs, handled := c.FinalizeToolFeedbackMessage(ctx, msg); handled {
|
||||
return msgIDs, nil
|
||||
}
|
||||
}
|
||||
|
||||
content := msg.Content
|
||||
if isToolFeedback {
|
||||
content = channels.InitialAnimatedToolFeedbackContent(msg.Content)
|
||||
}
|
||||
msgID := uuid.New().String()
|
||||
|
||||
outMsg := newMessage(TypeMessageCreate, map[string]any{
|
||||
PayloadKeyContent: content,
|
||||
PayloadKeyContent: msg.Content,
|
||||
PayloadKeyThought: isThought,
|
||||
"message_id": msgID,
|
||||
})
|
||||
|
||||
if err := c.broadcastToSession(msg.ChatID, outMsg); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if isToolFeedback {
|
||||
c.RecordToolFeedbackMessage(msg.ChatID, msgID, msg.Content)
|
||||
} else if hasTrackedMsg {
|
||||
c.dismissTrackedToolFeedbackMessage(ctx, msg.ChatID, trackedMsgID)
|
||||
}
|
||||
return []string{msgID}, nil
|
||||
return nil, c.broadcastToSession(msg.ChatID, outMsg)
|
||||
}
|
||||
|
||||
// EditMessage implements channels.MessageEditor.
|
||||
@@ -322,73 +279,6 @@ func (c *PicoChannel) EditMessage(ctx context.Context, chatID string, messageID
|
||||
return c.broadcastToSession(chatID, outMsg)
|
||||
}
|
||||
|
||||
func (c *PicoChannel) currentToolFeedbackMessage(chatID string) (string, bool) {
|
||||
if c.progress == nil {
|
||||
return "", false
|
||||
}
|
||||
return c.progress.Current(chatID)
|
||||
}
|
||||
|
||||
func (c *PicoChannel) takeToolFeedbackMessage(chatID string) (string, string, bool) {
|
||||
if c.progress == nil {
|
||||
return "", "", false
|
||||
}
|
||||
return c.progress.Take(chatID)
|
||||
}
|
||||
|
||||
func (c *PicoChannel) RecordToolFeedbackMessage(chatID, messageID, content string) {
|
||||
if c.progress == nil {
|
||||
return
|
||||
}
|
||||
c.progress.Record(chatID, messageID, content)
|
||||
}
|
||||
|
||||
func (c *PicoChannel) ClearToolFeedbackMessage(chatID string) {
|
||||
if c.progress == nil {
|
||||
return
|
||||
}
|
||||
c.progress.Clear(chatID)
|
||||
}
|
||||
|
||||
func (c *PicoChannel) DismissToolFeedbackMessage(ctx context.Context, chatID string) {
|
||||
msgID, ok := c.currentToolFeedbackMessage(chatID)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
c.dismissTrackedToolFeedbackMessage(ctx, chatID, msgID)
|
||||
}
|
||||
|
||||
func (c *PicoChannel) dismissTrackedToolFeedbackMessage(ctx context.Context, chatID, messageID string) {
|
||||
if strings.TrimSpace(chatID) == "" || strings.TrimSpace(messageID) == "" {
|
||||
return
|
||||
}
|
||||
c.ClearToolFeedbackMessage(chatID)
|
||||
}
|
||||
|
||||
func (c *PicoChannel) finalizeTrackedToolFeedbackMessage(
|
||||
ctx context.Context,
|
||||
chatID string,
|
||||
content string,
|
||||
editFn func(context.Context, string, string, string) error,
|
||||
) ([]string, bool) {
|
||||
msgID, baseContent, ok := c.takeToolFeedbackMessage(chatID)
|
||||
if !ok || editFn == nil {
|
||||
return nil, false
|
||||
}
|
||||
if err := editFn(ctx, chatID, msgID, content); err != nil {
|
||||
c.RecordToolFeedbackMessage(chatID, msgID, baseContent)
|
||||
return nil, false
|
||||
}
|
||||
return []string{msgID}, true
|
||||
}
|
||||
|
||||
func (c *PicoChannel) FinalizeToolFeedbackMessage(ctx context.Context, msg bus.OutboundMessage) ([]string, bool) {
|
||||
if outboundMessageIsToolFeedback(msg) {
|
||||
return nil, false
|
||||
}
|
||||
return c.finalizeTrackedToolFeedbackMessage(ctx, msg.ChatID, msg.Content, c.EditMessage)
|
||||
}
|
||||
|
||||
// StartTyping implements channels.TypingCapable.
|
||||
func (c *PicoChannel) StartTyping(ctx context.Context, chatID string) (func(), error) {
|
||||
startMsg := newMessage(TypeTypingStart, nil)
|
||||
|
||||
@@ -27,34 +27,6 @@ func newTestPicoChannel(t *testing.T) *PicoChannel {
|
||||
return ch
|
||||
}
|
||||
|
||||
func TestFinalizeTrackedToolFeedbackMessage_StopsTrackingBeforeEdit(t *testing.T) {
|
||||
ch := &PicoChannel{
|
||||
progress: channels.NewToolFeedbackAnimator(nil),
|
||||
}
|
||||
ch.RecordToolFeedbackMessage("pico:chat-1", "msg-1", "🔧 `read_file`")
|
||||
|
||||
msgIDs, handled := ch.finalizeTrackedToolFeedbackMessage(
|
||||
context.Background(),
|
||||
"pico:chat-1",
|
||||
"final reply",
|
||||
func(_ context.Context, chatID, messageID, content string) error {
|
||||
if _, ok := ch.currentToolFeedbackMessage(chatID); ok {
|
||||
t.Fatal("expected tracked tool feedback to be stopped before edit")
|
||||
}
|
||||
if chatID != "pico:chat-1" || messageID != "msg-1" || content != "final reply" {
|
||||
t.Fatalf("unexpected edit args: %s %s %s", chatID, messageID, content)
|
||||
}
|
||||
return nil
|
||||
},
|
||||
)
|
||||
if !handled {
|
||||
t.Fatal("expected finalizeTrackedToolFeedbackMessage to handle tracked message")
|
||||
}
|
||||
if len(msgIDs) != 1 || msgIDs[0] != "msg-1" {
|
||||
t.Fatalf("finalizeTrackedToolFeedbackMessage() ids = %v, want [msg-1]", msgIDs)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateAndAddConnection_RespectsMaxConnectionsConcurrently(t *testing.T) {
|
||||
ch := newTestPicoChannel(t)
|
||||
|
||||
|
||||
@@ -66,10 +66,6 @@ func (c *TelegramChannel) startCommandRegistration(ctx context.Context, defs []c
|
||||
if register == nil {
|
||||
register = c.RegisterCommands
|
||||
}
|
||||
delayFn := c.commandRegDelayFn
|
||||
if delayFn == nil {
|
||||
delayFn = commandRegistrationDelay
|
||||
}
|
||||
|
||||
regCtx, cancel := context.WithCancel(ctx)
|
||||
c.commandRegCancel = cancel
|
||||
@@ -95,7 +91,7 @@ func (c *TelegramChannel) startCommandRegistration(ctx context.Context, defs []c
|
||||
return
|
||||
}
|
||||
|
||||
delay := delayFn(attempt)
|
||||
delay := commandRegistrationDelay(attempt)
|
||||
logger.WarnCF("telegram", "Telegram command registration failed; will retry", map[string]any{
|
||||
"error": err.Error(),
|
||||
"retry_after": delay.String(),
|
||||
|
||||
@@ -31,12 +31,14 @@ func TestStartCommandRegistration_DoesNotBlock(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestStartCommandRegistration_RetriesUntilSuccessThenStops(t *testing.T) {
|
||||
ch := &TelegramChannel{
|
||||
commandRegDelayFn: func(int) time.Duration { return 5 * time.Millisecond },
|
||||
}
|
||||
ch := &TelegramChannel{}
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
origBackoff := commandRegistrationBackoff
|
||||
commandRegistrationBackoff = []time.Duration{5 * time.Millisecond}
|
||||
defer func() { commandRegistrationBackoff = origBackoff }()
|
||||
|
||||
var attempts atomic.Int32
|
||||
ch.registerFunc = func(context.Context, []commands.Definition) error {
|
||||
n := attempts.Add(1)
|
||||
@@ -67,10 +69,12 @@ func TestStartCommandRegistration_RetriesUntilSuccessThenStops(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestStartCommandRegistration_StopsAfterCancel(t *testing.T) {
|
||||
ch := &TelegramChannel{
|
||||
commandRegDelayFn: func(int) time.Duration { return 5 * time.Millisecond },
|
||||
}
|
||||
ch := &TelegramChannel{}
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
origBackoff := commandRegistrationBackoff
|
||||
commandRegistrationBackoff = []time.Duration{5 * time.Millisecond}
|
||||
defer func() { commandRegistrationBackoff = origBackoff }()
|
||||
defer cancel()
|
||||
|
||||
var attempts atomic.Int32
|
||||
|
||||
@@ -45,18 +45,16 @@ var (
|
||||
|
||||
type TelegramChannel struct {
|
||||
*channels.BaseChannel
|
||||
bot *telego.Bot
|
||||
bh *th.BotHandler
|
||||
bc *config.Channel
|
||||
chatIDs map[string]int64
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
tgCfg *config.TelegramSettings
|
||||
progress *channels.ToolFeedbackAnimator
|
||||
bot *telego.Bot
|
||||
bh *th.BotHandler
|
||||
bc *config.Channel
|
||||
chatIDs map[string]int64
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
tgCfg *config.TelegramSettings
|
||||
|
||||
registerFunc func(context.Context, []commands.Definition) error
|
||||
commandRegDelayFn func(int) time.Duration
|
||||
commandRegCancel context.CancelFunc
|
||||
registerFunc func(context.Context, []commands.Definition) error
|
||||
commandRegCancel context.CancelFunc
|
||||
}
|
||||
|
||||
func NewTelegramChannel(
|
||||
@@ -106,15 +104,13 @@ func NewTelegramChannel(
|
||||
channels.WithReasoningChannelID(bc.ReasoningChannelID),
|
||||
)
|
||||
|
||||
ch := &TelegramChannel{
|
||||
return &TelegramChannel{
|
||||
BaseChannel: base,
|
||||
bot: bot,
|
||||
bc: bc,
|
||||
chatIDs: make(map[string]int64),
|
||||
tgCfg: telegramCfg,
|
||||
}
|
||||
ch.progress = channels.NewToolFeedbackAnimator(ch.EditMessage)
|
||||
return ch, nil
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (c *TelegramChannel) Start(ctx context.Context) error {
|
||||
@@ -172,9 +168,6 @@ func (c *TelegramChannel) Stop(ctx context.Context) error {
|
||||
if c.cancel != nil {
|
||||
c.cancel()
|
||||
}
|
||||
if c.progress != nil {
|
||||
c.progress.StopAll()
|
||||
}
|
||||
if c.commandRegCancel != nil {
|
||||
c.commandRegCancel()
|
||||
}
|
||||
@@ -198,35 +191,12 @@ func (c *TelegramChannel) Send(ctx context.Context, msg bus.OutboundMessage) ([]
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
isToolFeedback := outboundMessageIsToolFeedback(msg)
|
||||
toolFeedbackContent := msg.Content
|
||||
if isToolFeedback {
|
||||
toolFeedbackContent = fitToolFeedbackForTelegram(msg.Content, useMarkdownV2, 4096)
|
||||
}
|
||||
if isToolFeedback {
|
||||
if msgID, handled, err := c.progress.Update(ctx, msg.ChatID, toolFeedbackContent); handled {
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return []string{msgID}, nil
|
||||
}
|
||||
}
|
||||
trackedMsgID, hasTrackedMsg := c.currentToolFeedbackMessage(msg.ChatID)
|
||||
if !isToolFeedback {
|
||||
if msgIDs, handled := c.FinalizeToolFeedbackMessage(ctx, msg); handled {
|
||||
return msgIDs, nil
|
||||
}
|
||||
}
|
||||
|
||||
// The Manager already splits messages to ≤4000 chars (WithMaxMessageLength),
|
||||
// so msg.Content is guaranteed to be within that limit. We still need to
|
||||
// check if HTML expansion pushes it beyond Telegram's 4096-char API limit.
|
||||
replyToID := msg.ReplyToMessageID
|
||||
var messageIDs []string
|
||||
queue := []string{msg.Content}
|
||||
if isToolFeedback {
|
||||
queue = []string{channels.InitialAnimatedToolFeedbackContent(toolFeedbackContent)}
|
||||
}
|
||||
for len(queue) > 0 {
|
||||
chunk := queue[0]
|
||||
queue = queue[1:]
|
||||
@@ -234,13 +204,6 @@ func (c *TelegramChannel) Send(ctx context.Context, msg bus.OutboundMessage) ([]
|
||||
content := parseContent(chunk, useMarkdownV2)
|
||||
|
||||
if len([]rune(content)) > 4096 {
|
||||
if isToolFeedback {
|
||||
fittedChunk := fitToolFeedbackForTelegram(chunk, useMarkdownV2, 4096)
|
||||
if fittedChunk != "" && fittedChunk != chunk {
|
||||
queue = append([]string{fittedChunk}, queue...)
|
||||
continue
|
||||
}
|
||||
}
|
||||
runeChunk := []rune(chunk)
|
||||
ratio := float64(len(runeChunk)) / float64(len([]rune(content)))
|
||||
smallerLen := int(float64(4096) * ratio * 0.95) // 5% safety margin
|
||||
@@ -307,12 +270,6 @@ func (c *TelegramChannel) Send(ctx context.Context, msg bus.OutboundMessage) ([]
|
||||
replyToID = ""
|
||||
}
|
||||
|
||||
if isToolFeedback && len(messageIDs) > 0 {
|
||||
c.RecordToolFeedbackMessage(msg.ChatID, messageIDs[0], toolFeedbackContent)
|
||||
} else if !isToolFeedback && hasTrackedMsg {
|
||||
c.dismissTrackedToolFeedbackMessage(ctx, msg.ChatID, trackedMsgID)
|
||||
}
|
||||
|
||||
return messageIDs, nil
|
||||
}
|
||||
|
||||
@@ -480,81 +437,6 @@ func (c *TelegramChannel) DeleteMessage(ctx context.Context, chatID string, mess
|
||||
})
|
||||
}
|
||||
|
||||
func outboundMessageIsToolFeedback(msg bus.OutboundMessage) bool {
|
||||
if len(msg.Context.Raw) == 0 {
|
||||
return false
|
||||
}
|
||||
return strings.EqualFold(strings.TrimSpace(msg.Context.Raw["message_kind"]), "tool_feedback")
|
||||
}
|
||||
|
||||
func (c *TelegramChannel) currentToolFeedbackMessage(chatID string) (string, bool) {
|
||||
if c.progress == nil {
|
||||
return "", false
|
||||
}
|
||||
return c.progress.Current(chatID)
|
||||
}
|
||||
|
||||
func (c *TelegramChannel) takeToolFeedbackMessage(chatID string) (string, string, bool) {
|
||||
if c.progress == nil {
|
||||
return "", "", false
|
||||
}
|
||||
return c.progress.Take(chatID)
|
||||
}
|
||||
|
||||
func (c *TelegramChannel) RecordToolFeedbackMessage(chatID, messageID, content string) {
|
||||
if c.progress == nil {
|
||||
return
|
||||
}
|
||||
c.progress.Record(chatID, messageID, content)
|
||||
}
|
||||
|
||||
func (c *TelegramChannel) ClearToolFeedbackMessage(chatID string) {
|
||||
if c.progress == nil {
|
||||
return
|
||||
}
|
||||
c.progress.Clear(chatID)
|
||||
}
|
||||
|
||||
func (c *TelegramChannel) DismissToolFeedbackMessage(ctx context.Context, chatID string) {
|
||||
msgID, ok := c.currentToolFeedbackMessage(chatID)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
c.dismissTrackedToolFeedbackMessage(ctx, chatID, msgID)
|
||||
}
|
||||
|
||||
func (c *TelegramChannel) dismissTrackedToolFeedbackMessage(ctx context.Context, chatID, messageID string) {
|
||||
if strings.TrimSpace(chatID) == "" || strings.TrimSpace(messageID) == "" {
|
||||
return
|
||||
}
|
||||
c.ClearToolFeedbackMessage(chatID)
|
||||
_ = c.DeleteMessage(ctx, chatID, messageID)
|
||||
}
|
||||
|
||||
func (c *TelegramChannel) finalizeTrackedToolFeedbackMessage(
|
||||
ctx context.Context,
|
||||
chatID string,
|
||||
content string,
|
||||
editFn func(context.Context, string, string, string) error,
|
||||
) ([]string, bool) {
|
||||
msgID, baseContent, ok := c.takeToolFeedbackMessage(chatID)
|
||||
if !ok || editFn == nil {
|
||||
return nil, false
|
||||
}
|
||||
if err := editFn(ctx, chatID, msgID, content); err != nil {
|
||||
c.RecordToolFeedbackMessage(chatID, msgID, baseContent)
|
||||
return nil, false
|
||||
}
|
||||
return []string{msgID}, true
|
||||
}
|
||||
|
||||
func (c *TelegramChannel) FinalizeToolFeedbackMessage(ctx context.Context, msg bus.OutboundMessage) ([]string, bool) {
|
||||
if outboundMessageIsToolFeedback(msg) {
|
||||
return nil, false
|
||||
}
|
||||
return c.finalizeTrackedToolFeedbackMessage(ctx, msg.ChatID, msg.Content, c.EditMessage)
|
||||
}
|
||||
|
||||
// SendPlaceholder implements channels.PlaceholderCapable.
|
||||
// It sends a placeholder message (e.g. "Thinking... 💭") that will later be
|
||||
// edited to the actual response via EditMessage (channels.MessageEditor).
|
||||
@@ -586,7 +468,6 @@ func (c *TelegramChannel) SendMedia(ctx context.Context, msg bus.OutboundMediaMe
|
||||
if !c.IsRunning() {
|
||||
return nil, channels.ErrNotRunning
|
||||
}
|
||||
trackedMsgID, hasTrackedMsg := c.currentToolFeedbackMessage(msg.ChatID)
|
||||
|
||||
chatID, threadID, err := resolveTelegramOutboundTarget(msg.ChatID, &msg.Context)
|
||||
if err != nil {
|
||||
@@ -695,10 +576,6 @@ func (c *TelegramChannel) SendMedia(ctx context.Context, msg bus.OutboundMediaMe
|
||||
}
|
||||
}
|
||||
|
||||
if hasTrackedMsg {
|
||||
c.dismissTrackedToolFeedbackMessage(ctx, msg.ChatID, trackedMsgID)
|
||||
}
|
||||
|
||||
return messageIDs, nil
|
||||
}
|
||||
|
||||
@@ -1070,41 +947,6 @@ func parseContent(text string, useMarkdownV2 bool) string {
|
||||
return markdownToTelegramHTML(text)
|
||||
}
|
||||
|
||||
func fitToolFeedbackForTelegram(content string, useMarkdownV2 bool, maxParsedLen int) string {
|
||||
content = strings.TrimSpace(content)
|
||||
if content == "" || maxParsedLen <= 0 {
|
||||
return ""
|
||||
}
|
||||
animationSafeLen := maxParsedLen - channels.MaxToolFeedbackAnimationFrameLength()
|
||||
if animationSafeLen <= 0 {
|
||||
animationSafeLen = maxParsedLen
|
||||
}
|
||||
if len([]rune(parseContent(content, useMarkdownV2))) <= animationSafeLen {
|
||||
return content
|
||||
}
|
||||
|
||||
low := 1
|
||||
high := len([]rune(content))
|
||||
best := utils.Truncate(content, 1)
|
||||
|
||||
for low <= high {
|
||||
mid := (low + high) / 2
|
||||
candidate := utils.FitToolFeedbackMessage(content, mid)
|
||||
if candidate == "" {
|
||||
high = mid - 1
|
||||
continue
|
||||
}
|
||||
if len([]rune(parseContent(candidate, useMarkdownV2))) <= animationSafeLen {
|
||||
best = candidate
|
||||
low = mid + 1
|
||||
continue
|
||||
}
|
||||
high = mid - 1
|
||||
}
|
||||
|
||||
return best
|
||||
}
|
||||
|
||||
// parseTelegramChatID splits "chatID/threadID" into its components.
|
||||
// Returns threadID=0 when no "/" is present (non-forum messages).
|
||||
func parseTelegramChatID(chatID string) (int64, int, error) {
|
||||
|
||||
@@ -108,7 +108,7 @@ func TestHandleMessage_GroupMentionOnly_BotCommandEntity(t *testing.T) {
|
||||
t.Fatalf("handleMessage error: %v", err)
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 200*time.Microsecond)
|
||||
defer cancel()
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
|
||||
@@ -98,12 +98,8 @@ func (s *multipartRecordingConstructor) MultipartRequest(
|
||||
|
||||
// successResponse returns a ta.Response that telego will treat as a successful SendMessage.
|
||||
func successResponse(t *testing.T) *ta.Response {
|
||||
return successResponseWithMessageID(t, 1)
|
||||
}
|
||||
|
||||
func successResponseWithMessageID(t *testing.T, messageID int) *ta.Response {
|
||||
t.Helper()
|
||||
msg := &telego.Message{MessageID: messageID}
|
||||
msg := &telego.Message{MessageID: 1}
|
||||
b, err := json.Marshal(msg)
|
||||
require.NoError(t, err)
|
||||
return &ta.Response{Ok: true, Result: b}
|
||||
@@ -146,7 +142,6 @@ func newTestChannelWithConstructor(
|
||||
chatIDs: make(map[string]int64),
|
||||
bc: &config.Channel{Type: config.ChannelTelegram, Enabled: true},
|
||||
tgCfg: &config.TelegramSettings{},
|
||||
progress: channels.NewToolFeedbackAnimator(nil),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -271,101 +266,6 @@ func TestSend_ShortMessage_SingleCall(t *testing.T) {
|
||||
assert.Len(t, caller.calls, 1, "short message should result in exactly one SendMessage call")
|
||||
}
|
||||
|
||||
func TestSend_NonToolFeedbackDeletesTrackedProgressMessage(t *testing.T) {
|
||||
caller := &stubCaller{
|
||||
callFn: func(ctx context.Context, url string, data *ta.RequestData) (*ta.Response, error) {
|
||||
switch {
|
||||
case strings.Contains(url, "editMessageText"):
|
||||
return successResponseWithMessageID(t, 1), nil
|
||||
default:
|
||||
t.Fatalf("unexpected API call: %s", url)
|
||||
return nil, nil
|
||||
}
|
||||
},
|
||||
}
|
||||
ch := newTestChannel(t, caller)
|
||||
ch.RecordToolFeedbackMessage("12345", "1", "🔧 `read_file`")
|
||||
|
||||
ids, err := ch.Send(context.Background(), bus.OutboundMessage{
|
||||
ChatID: "12345",
|
||||
Content: "final reply",
|
||||
})
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, []string{"1"}, ids)
|
||||
require.Len(t, caller.calls, 1)
|
||||
assert.Contains(t, caller.calls[0].URL, "editMessageText")
|
||||
_, ok := ch.currentToolFeedbackMessage("12345")
|
||||
assert.False(t, ok, "tracked tool feedback should be cleared after final reply")
|
||||
}
|
||||
|
||||
func TestFinalizeTrackedToolFeedbackMessage_StopsTrackingBeforeEdit(t *testing.T) {
|
||||
ch := newTestChannel(t, &stubCaller{
|
||||
callFn: func(context.Context, string, *ta.RequestData) (*ta.Response, error) {
|
||||
t.Fatal("unexpected API call")
|
||||
return nil, nil
|
||||
},
|
||||
})
|
||||
ch.RecordToolFeedbackMessage("12345", "1", "🔧 `read_file`")
|
||||
|
||||
msgIDs, handled := ch.finalizeTrackedToolFeedbackMessage(
|
||||
context.Background(),
|
||||
"12345",
|
||||
"final reply",
|
||||
func(_ context.Context, chatID, messageID, content string) error {
|
||||
_, ok := ch.currentToolFeedbackMessage(chatID)
|
||||
assert.False(t, ok, "tracked tool feedback should be stopped before edit")
|
||||
assert.Equal(t, "12345", chatID)
|
||||
assert.Equal(t, "1", messageID)
|
||||
assert.Equal(t, "final reply", content)
|
||||
return nil
|
||||
},
|
||||
)
|
||||
|
||||
assert.True(t, handled)
|
||||
assert.Equal(t, []string{"1"}, msgIDs)
|
||||
}
|
||||
|
||||
func TestSend_ToolFeedbackStaysSingleMessageAfterHTMLExpansion(t *testing.T) {
|
||||
caller := &stubCaller{
|
||||
callFn: func(ctx context.Context, url string, data *ta.RequestData) (*ta.Response, error) {
|
||||
return successResponse(t), nil
|
||||
},
|
||||
}
|
||||
ch := newTestChannel(t, caller)
|
||||
|
||||
_, err := ch.Send(context.Background(), bus.OutboundMessage{
|
||||
ChatID: "12345",
|
||||
Content: "🔧 `read_file`\n" + strings.Repeat("<", 2000),
|
||||
Context: bus.InboundContext{
|
||||
Channel: "telegram",
|
||||
ChatID: "12345",
|
||||
Raw: map[string]string{
|
||||
"message_kind": "tool_feedback",
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, caller.calls, 1, "tool feedback should stay a single Telegram message after HTML escaping")
|
||||
}
|
||||
|
||||
func TestFitToolFeedbackForTelegram_ReservesAnimationFrame(t *testing.T) {
|
||||
content := "🔧 `read_file`\n" + strings.Repeat("a", 4096)
|
||||
|
||||
fitted := fitToolFeedbackForTelegram(content, false, 4096)
|
||||
animated := strings.Replace(
|
||||
fitted,
|
||||
"`\n",
|
||||
strings.Repeat(".", channels.MaxToolFeedbackAnimationFrameLength())+"`\n",
|
||||
1,
|
||||
)
|
||||
|
||||
if got := len([]rune(parseContent(animated, false))); got > 4096 {
|
||||
t.Fatalf("animated parsed length = %d, want <= 4096", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSend_LongMessage_SingleCall(t *testing.T) {
|
||||
// With WithMaxMessageLength(4000), the Manager pre-splits messages before
|
||||
// they reach Send(). A message at exactly 4000 chars should go through
|
||||
|
||||
@@ -1,240 +0,0 @@
|
||||
package channels
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
const toolFeedbackAnimationInterval = 3 * time.Second
|
||||
|
||||
const initialToolFeedbackAnimationFrame = ""
|
||||
|
||||
var toolFeedbackAnimationFrames = []string{"..", "."}
|
||||
|
||||
// MaxToolFeedbackAnimationFrameLength returns the largest frame suffix length
|
||||
// so callers can reserve room before sending messages to length-limited APIs.
|
||||
func MaxToolFeedbackAnimationFrameLength() int {
|
||||
maxLen := len([]rune(initialToolFeedbackAnimationFrame))
|
||||
for _, frame := range toolFeedbackAnimationFrames {
|
||||
if frameLen := len([]rune(frame)); frameLen > maxLen {
|
||||
maxLen = frameLen
|
||||
}
|
||||
}
|
||||
return maxLen
|
||||
}
|
||||
|
||||
type toolFeedbackAnimationState struct {
|
||||
messageID string
|
||||
baseContent string
|
||||
stop chan struct{}
|
||||
done chan struct{}
|
||||
}
|
||||
|
||||
type ToolFeedbackAnimator struct {
|
||||
mu sync.Mutex
|
||||
editFn func(ctx context.Context, chatID, messageID, content string) error
|
||||
entries map[string]*toolFeedbackAnimationState
|
||||
}
|
||||
|
||||
func NewToolFeedbackAnimator(
|
||||
editFn func(ctx context.Context, chatID, messageID, content string) error,
|
||||
) *ToolFeedbackAnimator {
|
||||
return &ToolFeedbackAnimator{
|
||||
editFn: editFn,
|
||||
entries: make(map[string]*toolFeedbackAnimationState),
|
||||
}
|
||||
}
|
||||
|
||||
func (a *ToolFeedbackAnimator) Current(chatID string) (string, bool) {
|
||||
if a == nil || strings.TrimSpace(chatID) == "" {
|
||||
return "", false
|
||||
}
|
||||
a.mu.Lock()
|
||||
defer a.mu.Unlock()
|
||||
entry, ok := a.entries[chatID]
|
||||
if !ok || strings.TrimSpace(entry.messageID) == "" {
|
||||
return "", false
|
||||
}
|
||||
return entry.messageID, true
|
||||
}
|
||||
|
||||
func (a *ToolFeedbackAnimator) Record(chatID, messageID, content string) {
|
||||
if a == nil {
|
||||
return
|
||||
}
|
||||
chatID = strings.TrimSpace(chatID)
|
||||
messageID = strings.TrimSpace(messageID)
|
||||
content = strings.TrimSpace(content)
|
||||
if chatID == "" || messageID == "" || content == "" {
|
||||
return
|
||||
}
|
||||
|
||||
entry := &toolFeedbackAnimationState{
|
||||
messageID: messageID,
|
||||
baseContent: content,
|
||||
stop: make(chan struct{}),
|
||||
done: make(chan struct{}),
|
||||
}
|
||||
|
||||
var previous *toolFeedbackAnimationState
|
||||
a.mu.Lock()
|
||||
if old, ok := a.entries[chatID]; ok {
|
||||
previous = old
|
||||
}
|
||||
a.entries[chatID] = entry
|
||||
a.mu.Unlock()
|
||||
|
||||
stopToolFeedbackAnimation(previous)
|
||||
go a.run(chatID, entry)
|
||||
}
|
||||
|
||||
func (a *ToolFeedbackAnimator) Clear(chatID string) {
|
||||
if a == nil || strings.TrimSpace(chatID) == "" {
|
||||
return
|
||||
}
|
||||
entry := a.detach(chatID)
|
||||
stopToolFeedbackAnimation(entry)
|
||||
}
|
||||
|
||||
func (a *ToolFeedbackAnimator) Take(chatID string) (string, string, bool) {
|
||||
if a == nil || strings.TrimSpace(chatID) == "" {
|
||||
return "", "", false
|
||||
}
|
||||
entry := a.detach(chatID)
|
||||
if entry == nil || strings.TrimSpace(entry.messageID) == "" {
|
||||
return "", "", false
|
||||
}
|
||||
stopToolFeedbackAnimation(entry)
|
||||
return entry.messageID, entry.baseContent, true
|
||||
}
|
||||
|
||||
// Update edits an existing tracked feedback message. If the edit fails, the
|
||||
// previous feedback state is restored so callers can retry without orphaning
|
||||
// the old progress message.
|
||||
func (a *ToolFeedbackAnimator) Update(ctx context.Context, chatID, content string) (string, bool, error) {
|
||||
if a == nil || a.editFn == nil {
|
||||
return "", false, nil
|
||||
}
|
||||
msgID, baseContent, ok := a.Take(chatID)
|
||||
if !ok {
|
||||
return "", false, nil
|
||||
}
|
||||
|
||||
animatedContent := InitialAnimatedToolFeedbackContent(content)
|
||||
if err := a.editFn(ctx, strings.TrimSpace(chatID), msgID, animatedContent); err != nil {
|
||||
a.Record(chatID, msgID, baseContent)
|
||||
return "", true, err
|
||||
}
|
||||
|
||||
a.Record(chatID, msgID, content)
|
||||
return msgID, true, nil
|
||||
}
|
||||
|
||||
func (a *ToolFeedbackAnimator) StopAll() {
|
||||
if a == nil {
|
||||
return
|
||||
}
|
||||
a.mu.Lock()
|
||||
entries := make([]*toolFeedbackAnimationState, 0, len(a.entries))
|
||||
for chatID, entry := range a.entries {
|
||||
entries = append(entries, entry)
|
||||
delete(a.entries, chatID)
|
||||
}
|
||||
a.mu.Unlock()
|
||||
|
||||
for _, entry := range entries {
|
||||
stopToolFeedbackAnimation(entry)
|
||||
}
|
||||
}
|
||||
|
||||
func (a *ToolFeedbackAnimator) detach(chatID string) *toolFeedbackAnimationState {
|
||||
if a == nil || strings.TrimSpace(chatID) == "" {
|
||||
return nil
|
||||
}
|
||||
a.mu.Lock()
|
||||
defer a.mu.Unlock()
|
||||
entry := a.entries[chatID]
|
||||
delete(a.entries, chatID)
|
||||
return entry
|
||||
}
|
||||
|
||||
func (a *ToolFeedbackAnimator) run(chatID string, entry *toolFeedbackAnimationState) {
|
||||
defer close(entry.done)
|
||||
|
||||
ticker := time.NewTicker(toolFeedbackAnimationInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
frameIdx := 1
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-entry.stop:
|
||||
return
|
||||
case <-ticker.C:
|
||||
if a.editFn == nil {
|
||||
continue
|
||||
}
|
||||
frame := toolFeedbackAnimationFrames[frameIdx%len(toolFeedbackAnimationFrames)]
|
||||
content := formatAnimatedToolFeedbackContent(entry.baseContent, frame)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
_ = a.editFn(ctx, chatID, entry.messageID, content)
|
||||
cancel()
|
||||
frameIdx++
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func InitialAnimatedToolFeedbackContent(baseContent string) string {
|
||||
return formatAnimatedToolFeedbackContent(baseContent, initialToolFeedbackAnimationFrame)
|
||||
}
|
||||
|
||||
func formatAnimatedToolFeedbackContent(baseContent, frame string) string {
|
||||
baseContent = strings.TrimSpace(baseContent)
|
||||
frame = strings.TrimSpace(frame)
|
||||
if baseContent == "" {
|
||||
return ""
|
||||
}
|
||||
if frame == "" {
|
||||
return baseContent
|
||||
}
|
||||
lineBreak := strings.IndexByte(baseContent, '\n')
|
||||
if lineBreak < 0 {
|
||||
return appendToolFeedbackFrame(baseContent, frame)
|
||||
}
|
||||
return appendToolFeedbackFrame(baseContent[:lineBreak], frame) + baseContent[lineBreak:]
|
||||
}
|
||||
|
||||
func appendToolFeedbackFrame(firstLine, frame string) string {
|
||||
firstLine = strings.TrimSpace(firstLine)
|
||||
frame = strings.TrimSpace(frame)
|
||||
if firstLine == "" {
|
||||
return ""
|
||||
}
|
||||
if frame == "" {
|
||||
return firstLine
|
||||
}
|
||||
|
||||
openTick := strings.IndexByte(firstLine, '`')
|
||||
if openTick >= 0 {
|
||||
if closeOffset := strings.IndexByte(firstLine[openTick+1:], '`'); closeOffset >= 0 {
|
||||
closeTick := openTick + 1 + closeOffset
|
||||
return firstLine[:closeTick] + frame + firstLine[closeTick:]
|
||||
}
|
||||
}
|
||||
|
||||
return firstLine + frame
|
||||
}
|
||||
|
||||
func stopToolFeedbackAnimation(entry *toolFeedbackAnimationState) {
|
||||
if entry == nil {
|
||||
return
|
||||
}
|
||||
select {
|
||||
case <-entry.stop:
|
||||
default:
|
||||
close(entry.stop)
|
||||
}
|
||||
<-entry.done
|
||||
}
|
||||
@@ -1,121 +0,0 @@
|
||||
package channels
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestFormatAnimatedToolFeedbackContent(t *testing.T) {
|
||||
got := formatAnimatedToolFeedbackContent("🔧 `read_file`\nReading config file", "running..")
|
||||
want := "🔧 `read_filerunning..`\nReading config file"
|
||||
if got != want {
|
||||
t.Fatalf("formatAnimatedToolFeedbackContent() = %q, want %q", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestInitialAnimatedToolFeedbackContent(t *testing.T) {
|
||||
got := InitialAnimatedToolFeedbackContent("🔧 `exec`\nRunning command")
|
||||
want := "🔧 `exec`\nRunning command"
|
||||
if got != want {
|
||||
t.Fatalf("InitialAnimatedToolFeedbackContent() = %q, want %q", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFormatAnimatedToolFeedbackContent_WithoutCodeSpan(t *testing.T) {
|
||||
got := formatAnimatedToolFeedbackContent("hello", "running..")
|
||||
want := "hellorunning.."
|
||||
if got != want {
|
||||
t.Fatalf("formatAnimatedToolFeedbackContent() without code span = %q, want %q", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestToolFeedbackAnimator_RecordCurrentAndClear(t *testing.T) {
|
||||
animator := NewToolFeedbackAnimator(nil)
|
||||
animator.Record("chat-1", "msg-1", "🔧 `read_file`")
|
||||
|
||||
msgID, ok := animator.Current("chat-1")
|
||||
if !ok || msgID != "msg-1" {
|
||||
t.Fatalf("Current() = (%q, %v), want (msg-1, true)", msgID, ok)
|
||||
}
|
||||
|
||||
animator.Clear("chat-1")
|
||||
|
||||
msgID, ok = animator.Current("chat-1")
|
||||
if ok || msgID != "" {
|
||||
t.Fatalf("Current() after Clear = (%q, %v), want (\"\", false)", msgID, ok)
|
||||
}
|
||||
}
|
||||
|
||||
func TestToolFeedbackAnimator_TakeStopsTrackingAndReturnsState(t *testing.T) {
|
||||
animator := NewToolFeedbackAnimator(nil)
|
||||
animator.Record("chat-1", "msg-1", "🔧 `read_file`\nChecking config")
|
||||
|
||||
msgID, baseContent, ok := animator.Take("chat-1")
|
||||
if !ok {
|
||||
t.Fatal("Take() = not found, want tracked message")
|
||||
}
|
||||
if msgID != "msg-1" {
|
||||
t.Fatalf("Take() msgID = %q, want msg-1", msgID)
|
||||
}
|
||||
if baseContent != "🔧 `read_file`\nChecking config" {
|
||||
t.Fatalf("Take() baseContent = %q", baseContent)
|
||||
}
|
||||
if _, ok := animator.Current("chat-1"); ok {
|
||||
t.Fatal("expected tracked message to be removed after Take()")
|
||||
}
|
||||
}
|
||||
|
||||
func TestToolFeedbackAnimator_UpdateStopsTrackingBeforeEdit(t *testing.T) {
|
||||
var animator *ToolFeedbackAnimator
|
||||
animator = NewToolFeedbackAnimator(func(_ context.Context, chatID, messageID, content string) error {
|
||||
if _, ok := animator.Current(chatID); ok {
|
||||
t.Fatal("expected tracked tool feedback to be stopped before edit")
|
||||
}
|
||||
if messageID != "msg-1" {
|
||||
t.Fatalf("messageID = %q, want msg-1", messageID)
|
||||
}
|
||||
if content != "🔧 `write_file`\nUpdating config" {
|
||||
t.Fatalf("content = %q, want updated animated content", content)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
defer animator.StopAll()
|
||||
|
||||
animator.Record("chat-1", "msg-1", "🔧 `read_file`\nChecking config")
|
||||
|
||||
msgID, handled, err := animator.Update(context.Background(), "chat-1", "🔧 `write_file`\nUpdating config")
|
||||
if err != nil {
|
||||
t.Fatalf("Update() error = %v", err)
|
||||
}
|
||||
if !handled {
|
||||
t.Fatal("Update() handled = false, want true")
|
||||
}
|
||||
if msgID != "msg-1" {
|
||||
t.Fatalf("Update() msgID = %q, want msg-1", msgID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestToolFeedbackAnimator_UpdateFailureRestoresTracking(t *testing.T) {
|
||||
editErr := errors.New("edit failed")
|
||||
animator := NewToolFeedbackAnimator(func(context.Context, string, string, string) error {
|
||||
return editErr
|
||||
})
|
||||
defer animator.StopAll()
|
||||
|
||||
animator.Record("chat-1", "msg-1", "🔧 `read_file`\nChecking config")
|
||||
|
||||
msgID, handled, err := animator.Update(context.Background(), "chat-1", "🔧 `write_file`\nUpdating config")
|
||||
if !handled {
|
||||
t.Fatal("Update() handled = false, want true")
|
||||
}
|
||||
if !errors.Is(err, editErr) {
|
||||
t.Fatalf("Update() error = %v, want editErr", err)
|
||||
}
|
||||
if msgID != "" {
|
||||
t.Fatalf("Update() msgID = %q, want empty on failed edit", msgID)
|
||||
}
|
||||
if currentID, ok := animator.Current("chat-1"); !ok || currentID != "msg-1" {
|
||||
t.Fatalf("Current() after failed Update = (%q, %v), want (msg-1, true)", currentID, ok)
|
||||
}
|
||||
}
|
||||
@@ -286,7 +286,7 @@ func (d *AgentDefaults) GetMaxMediaSize() int {
|
||||
return DefaultMaxMediaSize
|
||||
}
|
||||
|
||||
// GetToolFeedbackMaxArgsLength returns the max visible text length for tool feedback messages.
|
||||
// GetToolFeedbackMaxArgsLength returns the max args preview length for tool feedback messages.
|
||||
func (d *AgentDefaults) GetToolFeedbackMaxArgsLength() int {
|
||||
if d.ToolFeedback.MaxArgsLength > 0 {
|
||||
return d.ToolFeedback.MaxArgsLength
|
||||
|
||||
@@ -55,12 +55,6 @@ func buildCLIToolsPrompt(tools []ToolDefinition) string {
|
||||
func NormalizeToolCall(tc ToolCall) ToolCall {
|
||||
normalized := tc
|
||||
|
||||
if normalized.ThoughtSignature == "" &&
|
||||
normalized.ExtraContent != nil &&
|
||||
normalized.ExtraContent.Google != nil {
|
||||
normalized.ThoughtSignature = normalized.ExtraContent.Google.ThoughtSignature
|
||||
}
|
||||
|
||||
// Ensure Name is populated from Function if not set
|
||||
if normalized.Name == "" && normalized.Function != nil {
|
||||
normalized.Name = normalized.Function.Name
|
||||
@@ -83,9 +77,8 @@ func NormalizeToolCall(tc ToolCall) ToolCall {
|
||||
argsJSON, _ := json.Marshal(normalized.Arguments)
|
||||
if normalized.Function == nil {
|
||||
normalized.Function = &FunctionCall{
|
||||
Name: normalized.Name,
|
||||
Arguments: string(argsJSON),
|
||||
ThoughtSignature: normalized.ThoughtSignature,
|
||||
Name: normalized.Name,
|
||||
Arguments: string(argsJSON),
|
||||
}
|
||||
} else {
|
||||
if normalized.Function.Name == "" {
|
||||
@@ -97,12 +90,6 @@ func NormalizeToolCall(tc ToolCall) ToolCall {
|
||||
if normalized.Function.Arguments == "" {
|
||||
normalized.Function.Arguments = string(argsJSON)
|
||||
}
|
||||
if normalized.Function.ThoughtSignature == "" {
|
||||
normalized.Function.ThoughtSignature = normalized.ThoughtSignature
|
||||
}
|
||||
if normalized.ThoughtSignature == "" {
|
||||
normalized.ThoughtSignature = normalized.Function.ThoughtSignature
|
||||
}
|
||||
}
|
||||
|
||||
return normalized
|
||||
|
||||
@@ -70,23 +70,11 @@ func NewHTTPClient(proxy string) *http.Client {
|
||||
// It mirrors protocoltypes.Message but omits SystemParts, which is an
|
||||
// internal field that would be unknown to third-party endpoints.
|
||||
type openaiMessage struct {
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content"`
|
||||
ReasoningContent string `json:"reasoning_content,omitempty"`
|
||||
ToolCalls []openaiToolCall `json:"tool_calls,omitempty"`
|
||||
ToolCallID string `json:"tool_call_id,omitempty"`
|
||||
}
|
||||
|
||||
type openaiToolCall struct {
|
||||
ID string `json:"id"`
|
||||
Type string `json:"type,omitempty"`
|
||||
Function *openaiFunctionCall `json:"function,omitempty"`
|
||||
}
|
||||
|
||||
type openaiFunctionCall struct {
|
||||
Name string `json:"name"`
|
||||
Arguments string `json:"arguments"`
|
||||
ThoughtSignature string `json:"thought_signature,omitempty"`
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content"`
|
||||
ReasoningContent string `json:"reasoning_content,omitempty"`
|
||||
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
|
||||
ToolCallID string `json:"tool_call_id,omitempty"`
|
||||
}
|
||||
|
||||
// SerializeMessages converts internal Message structs to the OpenAI wire format.
|
||||
@@ -96,13 +84,12 @@ type openaiFunctionCall struct {
|
||||
func SerializeMessages(messages []Message) []any {
|
||||
out := make([]any, 0, len(messages))
|
||||
for _, m := range messages {
|
||||
toolCalls := serializeToolCalls(m.ToolCalls)
|
||||
if len(m.Media) == 0 {
|
||||
out = append(out, openaiMessage{
|
||||
Role: m.Role,
|
||||
Content: m.Content,
|
||||
ReasoningContent: m.ReasoningContent,
|
||||
ToolCalls: toolCalls,
|
||||
ToolCalls: m.ToolCalls,
|
||||
ToolCallID: m.ToolCallID,
|
||||
})
|
||||
continue
|
||||
@@ -145,8 +132,8 @@ func SerializeMessages(messages []Message) []any {
|
||||
if m.ToolCallID != "" {
|
||||
msg["tool_call_id"] = m.ToolCallID
|
||||
}
|
||||
if len(toolCalls) > 0 {
|
||||
msg["tool_calls"] = toolCalls
|
||||
if len(m.ToolCalls) > 0 {
|
||||
msg["tool_calls"] = m.ToolCalls
|
||||
}
|
||||
if m.ReasoningContent != "" {
|
||||
msg["reasoning_content"] = m.ReasoningContent
|
||||
@@ -156,55 +143,6 @@ func SerializeMessages(messages []Message) []any {
|
||||
return out
|
||||
}
|
||||
|
||||
func serializeToolCalls(toolCalls []ToolCall) []openaiToolCall {
|
||||
if len(toolCalls) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
out := make([]openaiToolCall, 0, len(toolCalls))
|
||||
for _, tc := range toolCalls {
|
||||
wireCall := openaiToolCall{
|
||||
ID: tc.ID,
|
||||
Type: tc.Type,
|
||||
}
|
||||
|
||||
if tc.Function != nil {
|
||||
thoughtSignature := tc.Function.ThoughtSignature
|
||||
if thoughtSignature == "" {
|
||||
thoughtSignature = tc.ThoughtSignature
|
||||
}
|
||||
if thoughtSignature == "" && tc.ExtraContent != nil && tc.ExtraContent.Google != nil {
|
||||
thoughtSignature = tc.ExtraContent.Google.ThoughtSignature
|
||||
}
|
||||
wireCall.Function = &openaiFunctionCall{
|
||||
Name: tc.Function.Name,
|
||||
Arguments: tc.Function.Arguments,
|
||||
ThoughtSignature: thoughtSignature,
|
||||
}
|
||||
} else if tc.Name != "" || len(tc.Arguments) > 0 || tc.ThoughtSignature != "" {
|
||||
thoughtSignature := tc.ThoughtSignature
|
||||
if thoughtSignature == "" && tc.ExtraContent != nil && tc.ExtraContent.Google != nil {
|
||||
thoughtSignature = tc.ExtraContent.Google.ThoughtSignature
|
||||
}
|
||||
argsJSON := "{}"
|
||||
if len(tc.Arguments) > 0 {
|
||||
if encoded, err := json.Marshal(tc.Arguments); err == nil {
|
||||
argsJSON = string(encoded)
|
||||
}
|
||||
}
|
||||
wireCall.Function = &openaiFunctionCall{
|
||||
Name: tc.Name,
|
||||
Arguments: argsJSON,
|
||||
ThoughtSignature: thoughtSignature,
|
||||
}
|
||||
}
|
||||
|
||||
out = append(out, wireCall)
|
||||
}
|
||||
|
||||
return out
|
||||
}
|
||||
|
||||
func parseDataAudioURL(mediaURL string) (format, data string, ok bool) {
|
||||
if !strings.HasPrefix(mediaURL, "data:audio/") {
|
||||
return "", "", false
|
||||
@@ -247,7 +185,6 @@ func ParseResponse(body io.Reader) (*LLMResponse, error) {
|
||||
Google *struct {
|
||||
ThoughtSignature string `json:"thought_signature"`
|
||||
} `json:"google"`
|
||||
ToolFeedbackExplanation string `json:"tool_feedback_explanation"`
|
||||
} `json:"extra_content"`
|
||||
} `json:"tool_calls"`
|
||||
} `json:"message"`
|
||||
@@ -291,17 +228,11 @@ func ParseResponse(body io.Reader) (*LLMResponse, error) {
|
||||
ThoughtSignature: thoughtSignature,
|
||||
}
|
||||
|
||||
if tc.ExtraContent != nil {
|
||||
extraContent := &ExtraContent{
|
||||
ToolFeedbackExplanation: tc.ExtraContent.ToolFeedbackExplanation,
|
||||
}
|
||||
if thoughtSignature != "" {
|
||||
extraContent.Google = &GoogleExtra{
|
||||
if thoughtSignature != "" {
|
||||
toolCall.ExtraContent = &ExtraContent{
|
||||
Google: &GoogleExtra{
|
||||
ThoughtSignature: thoughtSignature,
|
||||
}
|
||||
}
|
||||
if extraContent.Google != nil || strings.TrimSpace(extraContent.ToolFeedbackExplanation) != "" {
|
||||
toolCall.ExtraContent = extraContent
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -162,104 +162,6 @@ func TestSerializeMessages_StripsSystemParts(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestSerializeMessages_StripsInternalToolCallExtraContent(t *testing.T) {
|
||||
messages := []Message{
|
||||
{
|
||||
Role: "assistant",
|
||||
ToolCalls: []ToolCall{{
|
||||
ID: "call_1",
|
||||
Type: "function",
|
||||
Function: &FunctionCall{
|
||||
Name: "read_file",
|
||||
Arguments: `{"path":"README.md"}`,
|
||||
ThoughtSignature: "sig-1",
|
||||
},
|
||||
ExtraContent: &ExtraContent{
|
||||
Google: &GoogleExtra{
|
||||
ThoughtSignature: "sig-ignored-here",
|
||||
},
|
||||
ToolFeedbackExplanation: "Read README.md first.",
|
||||
},
|
||||
}},
|
||||
},
|
||||
}
|
||||
|
||||
result := SerializeMessages(messages)
|
||||
|
||||
data, err := json.Marshal(result)
|
||||
if err != nil {
|
||||
t.Fatalf("json.Marshal() error = %v", err)
|
||||
}
|
||||
payload := string(data)
|
||||
if strings.Contains(payload, "extra_content") {
|
||||
t.Fatalf("serialized payload should not include internal extra_content: %s", payload)
|
||||
}
|
||||
if !strings.Contains(payload, "thought_signature") {
|
||||
t.Fatalf("serialized payload should preserve function thought_signature: %s", payload)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSerializeMessages_PreservesTopLevelThoughtSignature(t *testing.T) {
|
||||
messages := []Message{
|
||||
{
|
||||
Role: "assistant",
|
||||
ToolCalls: []ToolCall{{
|
||||
ID: "call_1",
|
||||
Type: "function",
|
||||
ThoughtSignature: "sig-1",
|
||||
Function: &FunctionCall{
|
||||
Name: "read_file",
|
||||
Arguments: `{"path":"README.md"}`,
|
||||
},
|
||||
}},
|
||||
},
|
||||
}
|
||||
|
||||
result := SerializeMessages(messages)
|
||||
|
||||
data, err := json.Marshal(result)
|
||||
if err != nil {
|
||||
t.Fatalf("json.Marshal() error = %v", err)
|
||||
}
|
||||
payload := string(data)
|
||||
if !strings.Contains(payload, `"thought_signature":"sig-1"`) {
|
||||
t.Fatalf("serialized payload should preserve top-level thought signature: %s", payload)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSerializeMessages_PreservesGoogleExtraThoughtSignature(t *testing.T) {
|
||||
messages := []Message{
|
||||
{
|
||||
Role: "assistant",
|
||||
ToolCalls: []ToolCall{{
|
||||
ID: "call_1",
|
||||
Type: "function",
|
||||
Function: &FunctionCall{
|
||||
Name: "read_file",
|
||||
Arguments: `{"path":"README.md"}`,
|
||||
},
|
||||
ExtraContent: &ExtraContent{
|
||||
Google: &GoogleExtra{ThoughtSignature: "sig-1"},
|
||||
},
|
||||
}},
|
||||
},
|
||||
}
|
||||
|
||||
result := SerializeMessages(messages)
|
||||
|
||||
data, err := json.Marshal(result)
|
||||
if err != nil {
|
||||
t.Fatalf("json.Marshal() error = %v", err)
|
||||
}
|
||||
payload := string(data)
|
||||
if strings.Contains(payload, "extra_content") {
|
||||
t.Fatalf("serialized payload should not include extra_content: %s", payload)
|
||||
}
|
||||
if !strings.Contains(payload, `"thought_signature":"sig-1"`) {
|
||||
t.Fatalf("serialized payload should preserve google thought signature: %s", payload)
|
||||
}
|
||||
}
|
||||
|
||||
// --- ParseResponse tests ---
|
||||
|
||||
func TestParseResponse_BasicContent(t *testing.T) {
|
||||
@@ -332,27 +234,6 @@ func TestParseResponse_WithReasoningContent(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseResponse_WithToolFeedbackExplanationExtraContent(t *testing.T) {
|
||||
body := `{"choices":[{"message":{"content":"","tool_calls":[{"id":"call_1","type":"function","function":{"name":"test_tool","arguments":"{}"},"extra_content":{"tool_feedback_explanation":"Check the current config before editing."}}]},"finish_reason":"tool_calls"}]}`
|
||||
out, err := ParseResponse(strings.NewReader(body))
|
||||
if err != nil {
|
||||
t.Fatalf("ParseResponse() error = %v", err)
|
||||
}
|
||||
if len(out.ToolCalls) != 1 {
|
||||
t.Fatalf("len(ToolCalls) = %d, want 1", len(out.ToolCalls))
|
||||
}
|
||||
if out.ToolCalls[0].ExtraContent == nil {
|
||||
t.Fatal("ExtraContent is nil")
|
||||
}
|
||||
if out.ToolCalls[0].ExtraContent.ToolFeedbackExplanation != "Check the current config before editing." {
|
||||
t.Fatalf(
|
||||
"ToolFeedbackExplanation = %q, want %q",
|
||||
out.ToolCalls[0].ExtraContent.ToolFeedbackExplanation,
|
||||
"Check the current config before editing.",
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseResponse_InvalidJSON(t *testing.T) {
|
||||
_, err := ParseResponse(strings.NewReader("not json"))
|
||||
if err == nil {
|
||||
|
||||
@@ -11,8 +11,7 @@ type ToolCall struct {
|
||||
}
|
||||
|
||||
type ExtraContent struct {
|
||||
Google *GoogleExtra `json:"google,omitempty"`
|
||||
ToolFeedbackExplanation string `json:"tool_feedback_explanation,omitempty"`
|
||||
Google *GoogleExtra `json:"google,omitempty"`
|
||||
}
|
||||
|
||||
type GoogleExtra struct {
|
||||
|
||||
@@ -1,24 +0,0 @@
|
||||
package providers
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestNormalizeToolCall_PreservesExtraContentGoogleThoughtSignature(t *testing.T) {
|
||||
tc := NormalizeToolCall(ToolCall{
|
||||
ID: "call_1",
|
||||
Name: "search",
|
||||
Arguments: map[string]any{"q": "pico"},
|
||||
ExtraContent: &ExtraContent{
|
||||
Google: &GoogleExtra{ThoughtSignature: "sig-1"},
|
||||
},
|
||||
})
|
||||
|
||||
if tc.ThoughtSignature != "sig-1" {
|
||||
t.Fatalf("ThoughtSignature = %q, want sig-1", tc.ThoughtSignature)
|
||||
}
|
||||
if tc.Function == nil {
|
||||
t.Fatal("Function is nil")
|
||||
}
|
||||
if tc.Function.ThoughtSignature != "sig-1" {
|
||||
t.Fatalf("Function.ThoughtSignature = %q, want sig-1", tc.Function.ThoughtSignature)
|
||||
}
|
||||
}
|
||||
@@ -1,57 +1,9 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
import "fmt"
|
||||
|
||||
const ToolFeedbackContinuationHint = "Continuing the current task."
|
||||
|
||||
// FormatToolFeedbackMessage renders the model-provided explanation for why a
|
||||
// tool is being executed. When the model does not provide one, it keeps only
|
||||
// the tool line and does not expose raw arguments or fallback text.
|
||||
func FormatToolFeedbackMessage(toolName, explanation string) string {
|
||||
toolName = strings.TrimSpace(toolName)
|
||||
explanation = strings.TrimSpace(explanation)
|
||||
|
||||
if toolName == "" {
|
||||
return explanation
|
||||
}
|
||||
if explanation == "" {
|
||||
return fmt.Sprintf("\U0001f527 `%s`", toolName)
|
||||
}
|
||||
|
||||
return fmt.Sprintf("\U0001f527 `%s`\n%s", toolName, explanation)
|
||||
}
|
||||
|
||||
// FitToolFeedbackMessage keeps tool feedback within a single outbound message.
|
||||
// It preserves the first line when possible and truncates the explanation body
|
||||
// instead of letting the message be split into multiple chunks.
|
||||
func FitToolFeedbackMessage(content string, maxLen int) string {
|
||||
content = strings.TrimSpace(content)
|
||||
if content == "" || maxLen <= 0 {
|
||||
return ""
|
||||
}
|
||||
if len([]rune(content)) <= maxLen {
|
||||
return content
|
||||
}
|
||||
|
||||
firstLine, rest, hasRest := strings.Cut(content, "\n")
|
||||
firstLine = strings.TrimSpace(firstLine)
|
||||
rest = strings.TrimSpace(rest)
|
||||
|
||||
if !hasRest || rest == "" {
|
||||
return Truncate(firstLine, maxLen)
|
||||
}
|
||||
|
||||
if len([]rune(firstLine)) >= maxLen {
|
||||
return Truncate(firstLine, maxLen)
|
||||
}
|
||||
|
||||
remaining := maxLen - len([]rune(firstLine)) - 1
|
||||
if remaining <= 0 {
|
||||
return Truncate(firstLine, maxLen)
|
||||
}
|
||||
|
||||
return firstLine + "\n" + Truncate(rest, remaining)
|
||||
// FormatToolFeedbackMessage renders the tool name and arguments preview in the
|
||||
// same markdown shape used by live tool feedback and session reconstruction.
|
||||
func FormatToolFeedbackMessage(toolName, argsPreview string) string {
|
||||
return fmt.Sprintf("\U0001f527 `%s`\n```\n%s\n```", toolName, argsPreview)
|
||||
}
|
||||
|
||||
@@ -3,47 +3,9 @@ package utils
|
||||
import "testing"
|
||||
|
||||
func TestFormatToolFeedbackMessage(t *testing.T) {
|
||||
got := FormatToolFeedbackMessage(
|
||||
"read_file",
|
||||
"I will read README.md first to confirm the current project structure.",
|
||||
)
|
||||
want := "\U0001f527 `read_file`\nI will read README.md first to confirm the current project structure."
|
||||
got := FormatToolFeedbackMessage("read_file", "{\"path\":\"README.md\"}")
|
||||
want := "\U0001f527 `read_file`\n```\n{\"path\":\"README.md\"}\n```"
|
||||
if got != want {
|
||||
t.Fatalf("FormatToolFeedbackMessage() = %q, want %q", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFormatToolFeedbackMessage_EmptyExplanationKeepsOnlyToolLine(t *testing.T) {
|
||||
got := FormatToolFeedbackMessage("read_file", "")
|
||||
want := "\U0001f527 `read_file`"
|
||||
if got != want {
|
||||
t.Fatalf("FormatToolFeedbackMessage() = %q, want %q", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFormatToolFeedbackMessage_EmptyToolNameOmitsToolLine(t *testing.T) {
|
||||
got := FormatToolFeedbackMessage("", "Continue drafting the final response.")
|
||||
want := "Continue drafting the final response."
|
||||
if got != want {
|
||||
t.Fatalf("FormatToolFeedbackMessage() = %q, want %q", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFitToolFeedbackMessage_TruncatesBodyWithinSingleMessage(t *testing.T) {
|
||||
got := FitToolFeedbackMessage(
|
||||
"\U0001f527 `read_file`\nRead README.md first to confirm the current project structure.",
|
||||
40,
|
||||
)
|
||||
want := "\U0001f527 `read_file`\nRead README.md first to..."
|
||||
if got != want {
|
||||
t.Fatalf("FitToolFeedbackMessage() = %q, want %q", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFitToolFeedbackMessage_TruncatesSingleLineMessage(t *testing.T) {
|
||||
got := FitToolFeedbackMessage("\U0001f527 `read_file`", 10)
|
||||
want := "\U0001f527 `read..."
|
||||
if got != want {
|
||||
t.Fatalf("FitToolFeedbackMessage() = %q, want %q", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
+15
-74
@@ -486,15 +486,6 @@ func visibleSessionMessages(messages []providers.Message, toolFeedbackMaxArgsLen
|
||||
transcript = append(transcript, visibleToolMessages...)
|
||||
}
|
||||
|
||||
// When assistant content exactly matches the rendered tool summary or
|
||||
// tool-delivered message, skip it to avoid duplicates. Distinct content
|
||||
// must remain visible in restored session history.
|
||||
if len(msg.ToolCalls) > 0 &&
|
||||
len(msg.Media) == 0 &&
|
||||
assistantToolCallContentDuplicated(msg.Content, toolSummaryMessages, visibleToolMessages) {
|
||||
continue
|
||||
}
|
||||
|
||||
// Pico web chat can persist both visible `message` tool output and a
|
||||
// later plain assistant reply in the same turn. Hide only the fixed
|
||||
// internal summary that marks handled tool delivery.
|
||||
@@ -513,43 +504,6 @@ func visibleSessionMessages(messages []providers.Message, toolFeedbackMaxArgsLen
|
||||
return transcript
|
||||
}
|
||||
|
||||
func assistantToolCallContentDuplicated(
|
||||
content string,
|
||||
toolSummaryMessages []sessionChatMessage,
|
||||
visibleToolMessages []sessionChatMessage,
|
||||
) bool {
|
||||
content = strings.TrimSpace(content)
|
||||
if content == "" {
|
||||
return false
|
||||
}
|
||||
|
||||
for _, msg := range toolSummaryMessages {
|
||||
if toolSummaryContainsContent(msg.Content, content) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
for _, msg := range visibleToolMessages {
|
||||
if strings.TrimSpace(msg.Content) == content {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func toolSummaryContainsContent(summary, content string) bool {
|
||||
summary = strings.TrimSpace(summary)
|
||||
content = strings.TrimSpace(content)
|
||||
if summary == "" || content == "" {
|
||||
return false
|
||||
}
|
||||
if summary == content {
|
||||
return true
|
||||
}
|
||||
|
||||
_, body, hasBody := strings.Cut(summary, "\n")
|
||||
return hasBody && strings.TrimSpace(body) == content
|
||||
}
|
||||
|
||||
func assistantMessageTransientThought(msg providers.Message) bool {
|
||||
return strings.TrimSpace(msg.Content) == "" &&
|
||||
strings.TrimSpace(msg.ReasoningContent) != "" &&
|
||||
@@ -575,51 +529,38 @@ func visibleAssistantToolSummaryMessages(
|
||||
messages := make([]sessionChatMessage, 0, len(toolCalls))
|
||||
for _, tc := range toolCalls {
|
||||
name := tc.Name
|
||||
argsJSON := ""
|
||||
if tc.Function != nil {
|
||||
if name == "" {
|
||||
name = tc.Function.Name
|
||||
}
|
||||
argsJSON = tc.Function.Arguments
|
||||
}
|
||||
|
||||
if strings.TrimSpace(name) == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
if strings.TrimSpace(argsJSON) == "" && len(tc.Arguments) > 0 {
|
||||
if encodedArgs, err := json.Marshal(tc.Arguments); err == nil {
|
||||
argsJSON = string(encodedArgs)
|
||||
}
|
||||
}
|
||||
|
||||
argsPreview := strings.TrimSpace(argsJSON)
|
||||
if argsPreview == "" {
|
||||
argsPreview = "{}"
|
||||
}
|
||||
|
||||
messages = append(messages, sessionChatMessage{
|
||||
Role: "assistant",
|
||||
Content: utils.FormatToolFeedbackMessage(
|
||||
name,
|
||||
visibleAssistantToolSummaryText(tc, toolFeedbackMaxArgsLength),
|
||||
),
|
||||
Role: "assistant",
|
||||
Content: utils.FormatToolFeedbackMessage(name, utils.Truncate(argsPreview, toolFeedbackMaxArgsLength)),
|
||||
})
|
||||
}
|
||||
|
||||
return messages
|
||||
}
|
||||
|
||||
func visibleAssistantToolSummaryText(
|
||||
tc providers.ToolCall,
|
||||
toolFeedbackMaxArgsLength int,
|
||||
) string {
|
||||
if tc.ExtraContent != nil {
|
||||
if explanation := strings.TrimSpace(tc.ExtraContent.ToolFeedbackExplanation); explanation != "" {
|
||||
return utils.Truncate(explanation, toolFeedbackMaxArgsLength)
|
||||
}
|
||||
}
|
||||
|
||||
argsJSON := ""
|
||||
if tc.Function != nil {
|
||||
argsJSON = tc.Function.Arguments
|
||||
}
|
||||
if strings.TrimSpace(argsJSON) == "" && len(tc.Arguments) > 0 {
|
||||
if encodedArgs, err := json.Marshal(tc.Arguments); err == nil {
|
||||
argsJSON = string(encodedArgs)
|
||||
}
|
||||
}
|
||||
|
||||
return utils.Truncate(strings.TrimSpace(argsJSON), toolFeedbackMaxArgsLength)
|
||||
}
|
||||
|
||||
func visibleAssistantToolMessages(toolCalls []providers.ToolCall) []sessionChatMessage {
|
||||
if len(toolCalls) == 0 {
|
||||
return nil
|
||||
|
||||
@@ -540,7 +540,7 @@ func TestHandleListSessions_MessageCountUsesVisibleTranscript(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleGetSession_DoesNotDuplicateAssistantToolCallContent(t *testing.T) {
|
||||
func TestHandleGetSession_PreservesToolSummaryAndAssistantContent(t *testing.T) {
|
||||
configPath, cleanup := setupOAuthTestEnv(t)
|
||||
defer cleanup()
|
||||
|
||||
@@ -555,7 +555,7 @@ func TestHandleGetSession_DoesNotDuplicateAssistantToolCallContent(t *testing.T)
|
||||
{Role: "user", Content: "check file"},
|
||||
{
|
||||
Role: "assistant",
|
||||
Content: "Read the file before replying.",
|
||||
Content: "model final reply",
|
||||
ToolCalls: []providers.ToolCall{
|
||||
{
|
||||
ID: "call_1",
|
||||
@@ -564,9 +564,6 @@ func TestHandleGetSession_DoesNotDuplicateAssistantToolCallContent(t *testing.T)
|
||||
Name: "read_file",
|
||||
Arguments: `{"path":"README.md","start_line":1,"end_line":10}`,
|
||||
},
|
||||
ExtraContent: &providers.ExtraContent{
|
||||
ToolFeedbackExplanation: "Read the file before replying.",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -597,8 +594,8 @@ func TestHandleGetSession_DoesNotDuplicateAssistantToolCallContent(t *testing.T)
|
||||
if err := json.Unmarshal(rec.Body.Bytes(), &resp); err != nil {
|
||||
t.Fatalf("Unmarshal() error = %v", err)
|
||||
}
|
||||
if len(resp.Messages) != 2 {
|
||||
t.Fatalf("len(resp.Messages) = %d, want 2", len(resp.Messages))
|
||||
if len(resp.Messages) != 3 {
|
||||
t.Fatalf("len(resp.Messages) = %d, want 3", len(resp.Messages))
|
||||
}
|
||||
if resp.Messages[0].Role != "user" || resp.Messages[0].Content != "check file" {
|
||||
t.Fatalf("first message = %#v, want user/check file", resp.Messages[0])
|
||||
@@ -606,153 +603,8 @@ func TestHandleGetSession_DoesNotDuplicateAssistantToolCallContent(t *testing.T)
|
||||
if !strings.Contains(resp.Messages[1].Content, "`read_file`") {
|
||||
t.Fatalf("tool summary message = %#v, want read_file summary", resp.Messages[1])
|
||||
}
|
||||
if !strings.Contains(resp.Messages[1].Content, "Read the file before replying.") {
|
||||
t.Fatalf("tool summary message = %#v, want tool explanation", resp.Messages[1])
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleGetSession_PreservesDistinctAssistantToolCallContent(t *testing.T) {
|
||||
configPath, cleanup := setupOAuthTestEnv(t)
|
||||
defer cleanup()
|
||||
|
||||
dir := sessionsTestDir(t, configPath)
|
||||
store, err := memory.NewJSONLStore(dir)
|
||||
if err != nil {
|
||||
t.Fatalf("NewJSONLStore() error = %v", err)
|
||||
}
|
||||
|
||||
sessionKey := picoSessionPrefix + "detail-tool-summary-distinct-content"
|
||||
for _, msg := range []providers.Message{
|
||||
{Role: "user", Content: "check file"},
|
||||
{
|
||||
Role: "assistant",
|
||||
Content: "I will summarize the findings after reading the file.",
|
||||
ToolCalls: []providers.ToolCall{
|
||||
{
|
||||
ID: "call_1",
|
||||
Type: "function",
|
||||
Function: &providers.FunctionCall{
|
||||
Name: "read_file",
|
||||
Arguments: `{"path":"README.md","start_line":1,"end_line":10}`,
|
||||
},
|
||||
ExtraContent: &providers.ExtraContent{
|
||||
ToolFeedbackExplanation: "Read the file before replying.",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
} {
|
||||
if err := store.AddFullMessage(nil, sessionKey, msg); err != nil {
|
||||
t.Fatalf("AddFullMessage() error = %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
h := NewHandler(configPath)
|
||||
mux := http.NewServeMux()
|
||||
h.RegisterRoutes(mux)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/sessions/detail-tool-summary-distinct-content", nil)
|
||||
mux.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("status = %d, want %d, body=%s", rec.Code, http.StatusOK, rec.Body.String())
|
||||
}
|
||||
|
||||
var resp struct {
|
||||
Messages []struct {
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content"`
|
||||
} `json:"messages"`
|
||||
}
|
||||
if err := json.Unmarshal(rec.Body.Bytes(), &resp); err != nil {
|
||||
t.Fatalf("Unmarshal() error = %v", err)
|
||||
}
|
||||
if len(resp.Messages) != 3 {
|
||||
t.Fatalf("len(resp.Messages) = %d, want 3", len(resp.Messages))
|
||||
}
|
||||
if !strings.Contains(resp.Messages[1].Content, "`read_file`") {
|
||||
t.Fatalf("tool summary message = %#v, want read_file summary", resp.Messages[1])
|
||||
}
|
||||
if resp.Messages[2].Role != "assistant" ||
|
||||
resp.Messages[2].Content != "I will summarize the findings after reading the file." {
|
||||
t.Fatalf("assistant content = %#v, want preserved distinct content", resp.Messages[2])
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleGetSession_PreservesMediaWhenAssistantToolCallContentDuplicatesSummary(t *testing.T) {
|
||||
configPath, cleanup := setupOAuthTestEnv(t)
|
||||
defer cleanup()
|
||||
|
||||
dir := sessionsTestDir(t, configPath)
|
||||
store, err := memory.NewJSONLStore(dir)
|
||||
if err != nil {
|
||||
t.Fatalf("NewJSONLStore() error = %v", err)
|
||||
}
|
||||
|
||||
sessionKey := picoSessionPrefix + "detail-tool-summary-duplicate-content-with-media"
|
||||
for _, msg := range []providers.Message{
|
||||
{Role: "user", Content: "check screenshot"},
|
||||
{
|
||||
Role: "assistant",
|
||||
Content: "Reviewing the generated screenshot.",
|
||||
Media: []string{"data:image/png;base64,abc123"},
|
||||
ToolCalls: []providers.ToolCall{
|
||||
{
|
||||
ID: "call_1",
|
||||
Type: "function",
|
||||
Function: &providers.FunctionCall{
|
||||
Name: "view_image",
|
||||
Arguments: `{"path":"artifact.png"}`,
|
||||
},
|
||||
ExtraContent: &providers.ExtraContent{
|
||||
ToolFeedbackExplanation: "Reviewing the generated screenshot.",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
} {
|
||||
if err := store.AddFullMessage(nil, sessionKey, msg); err != nil {
|
||||
t.Fatalf("AddFullMessage() error = %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
h := NewHandler(configPath)
|
||||
mux := http.NewServeMux()
|
||||
h.RegisterRoutes(mux)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/sessions/detail-tool-summary-duplicate-content-with-media", nil)
|
||||
mux.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("status = %d, want %d, body=%s", rec.Code, http.StatusOK, rec.Body.String())
|
||||
}
|
||||
|
||||
var resp struct {
|
||||
Messages []struct {
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content"`
|
||||
Media []string `json:"media"`
|
||||
} `json:"messages"`
|
||||
}
|
||||
if err := json.Unmarshal(rec.Body.Bytes(), &resp); err != nil {
|
||||
t.Fatalf("Unmarshal() error = %v", err)
|
||||
}
|
||||
if len(resp.Messages) != 3 {
|
||||
t.Fatalf("len(resp.Messages) = %d, want 3", len(resp.Messages))
|
||||
}
|
||||
if !strings.Contains(resp.Messages[1].Content, "`view_image`") {
|
||||
t.Fatalf("tool summary message = %#v, want view_image summary", resp.Messages[1])
|
||||
}
|
||||
if resp.Messages[2].Role != "assistant" {
|
||||
t.Fatalf("assistant message role = %q, want assistant", resp.Messages[2].Role)
|
||||
}
|
||||
if resp.Messages[2].Content != "Reviewing the generated screenshot." {
|
||||
t.Fatalf("assistant content = %q, want preserved duplicated content with media", resp.Messages[2].Content)
|
||||
}
|
||||
if len(resp.Messages[2].Media) != 1 || resp.Messages[2].Media[0] != "data:image/png;base64,abc123" {
|
||||
t.Fatalf("assistant media = %#v, want preserved media", resp.Messages[2].Media)
|
||||
if resp.Messages[2].Role != "assistant" || resp.Messages[2].Content != "model final reply" {
|
||||
t.Fatalf("assistant message = %#v, want model final reply", resp.Messages[2])
|
||||
}
|
||||
}
|
||||
|
||||
@@ -777,7 +629,6 @@ func TestHandleGetSession_UsesConfiguredToolFeedbackMaxArgsLength(t *testing.T)
|
||||
}
|
||||
|
||||
argsJSON := `{"path":"README.md","start_line":1,"end_line":10,"extra":"abcdefghijklmnopqrstuvwxyz"}`
|
||||
explanation := "Read README.md first to confirm the current project structure before editing the config example."
|
||||
sessionKey := picoSessionPrefix + "detail-tool-summary-max-args"
|
||||
err = store.AddFullMessage(nil, sessionKey, providers.Message{Role: "user", Content: "check file"})
|
||||
if err != nil {
|
||||
@@ -792,9 +643,6 @@ func TestHandleGetSession_UsesConfiguredToolFeedbackMaxArgsLength(t *testing.T)
|
||||
Name: "read_file",
|
||||
Arguments: argsJSON,
|
||||
},
|
||||
ExtraContent: &providers.ExtraContent{
|
||||
ToolFeedbackExplanation: explanation,
|
||||
},
|
||||
}},
|
||||
})
|
||||
if err != nil {
|
||||
@@ -827,93 +675,13 @@ func TestHandleGetSession_UsesConfiguredToolFeedbackMaxArgsLength(t *testing.T)
|
||||
t.Fatalf("len(resp.Messages) = %d, want at least 2", len(resp.Messages))
|
||||
}
|
||||
|
||||
wantPreview := utils.Truncate(explanation, 20)
|
||||
wantPreview := utils.Truncate(argsJSON, 20)
|
||||
if !strings.Contains(resp.Messages[1].Content, wantPreview) {
|
||||
t.Fatalf("tool summary = %q, want preview %q", resp.Messages[1].Content, wantPreview)
|
||||
}
|
||||
if strings.Contains(resp.Messages[1].Content, argsJSON) {
|
||||
t.Fatalf("tool summary = %q, expected configured truncation", resp.Messages[1].Content)
|
||||
}
|
||||
if !strings.Contains(resp.Messages[1].Content, "`read_file`") {
|
||||
t.Fatalf("tool summary = %q, want read_file summary", resp.Messages[1].Content)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleGetSession_FallsBackToLegacyToolArgumentsWhenExplanationMissing(t *testing.T) {
|
||||
configPath, cleanup := setupOAuthTestEnv(t)
|
||||
defer cleanup()
|
||||
|
||||
cfg, err := config.LoadConfig(configPath)
|
||||
if err != nil {
|
||||
t.Fatalf("LoadConfig() error = %v", err)
|
||||
}
|
||||
cfg.Agents.Defaults.ToolFeedback.MaxArgsLength = 20
|
||||
err = config.SaveConfig(configPath, cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("SaveConfig() error = %v", err)
|
||||
}
|
||||
|
||||
dir := sessionsTestDir(t, configPath)
|
||||
store, err := memory.NewJSONLStore(dir)
|
||||
if err != nil {
|
||||
t.Fatalf("NewJSONLStore() error = %v", err)
|
||||
}
|
||||
|
||||
argsJSON := `{"path":"README.md","start_line":1,"end_line":10,"extra":"abcdefghijklmnopqrstuvwxyz"}`
|
||||
sessionKey := picoSessionPrefix + "detail-tool-summary-legacy-args"
|
||||
if err := store.AddFullMessage(
|
||||
nil,
|
||||
sessionKey,
|
||||
providers.Message{Role: "user", Content: "check file"},
|
||||
); err != nil {
|
||||
t.Fatalf("AddFullMessage(user) error = %v", err)
|
||||
}
|
||||
if err := store.AddFullMessage(nil, sessionKey, providers.Message{
|
||||
Role: "assistant",
|
||||
ToolCalls: []providers.ToolCall{{
|
||||
ID: "call_1",
|
||||
Type: "function",
|
||||
Function: &providers.FunctionCall{
|
||||
Name: "read_file",
|
||||
Arguments: argsJSON,
|
||||
},
|
||||
}},
|
||||
}); err != nil {
|
||||
t.Fatalf("AddFullMessage(assistant) error = %v", err)
|
||||
}
|
||||
|
||||
h := NewHandler(configPath)
|
||||
mux := http.NewServeMux()
|
||||
h.RegisterRoutes(mux)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/sessions/detail-tool-summary-legacy-args", nil)
|
||||
mux.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("status = %d, want %d, body=%s", rec.Code, http.StatusOK, rec.Body.String())
|
||||
}
|
||||
|
||||
var resp struct {
|
||||
Messages []struct {
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content"`
|
||||
} `json:"messages"`
|
||||
}
|
||||
if err := json.Unmarshal(rec.Body.Bytes(), &resp); err != nil {
|
||||
t.Fatalf("Unmarshal() error = %v", err)
|
||||
}
|
||||
if len(resp.Messages) < 2 {
|
||||
t.Fatalf("len(resp.Messages) = %d, want at least 2", len(resp.Messages))
|
||||
}
|
||||
|
||||
wantPreview := utils.Truncate(argsJSON, 20)
|
||||
if !strings.Contains(resp.Messages[1].Content, "`read_file`") {
|
||||
t.Fatalf("tool summary = %q, want read_file summary", resp.Messages[1].Content)
|
||||
}
|
||||
if !strings.Contains(resp.Messages[1].Content, wantPreview) {
|
||||
t.Fatalf("tool summary = %q, want legacy args preview %q", resp.Messages[1].Content, wantPreview)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleGetSession_IncludesMediaOnlyMessages(t *testing.T) {
|
||||
|
||||
@@ -592,9 +592,9 @@
|
||||
"split_on_marker": "Chatty Mode",
|
||||
"split_on_marker_hint": "Split long messages into short ones like real human chatting.",
|
||||
"tool_feedback_enabled": "Tool Feedback",
|
||||
"tool_feedback_enabled_hint": "Send a short execution note into the current chat before each tool runs.",
|
||||
"tool_feedback_max_args_length": "Tool Feedback Length",
|
||||
"tool_feedback_max_args_length_hint": "Maximum number of characters shown in each tool feedback message. Set to 0 to use the default.",
|
||||
"tool_feedback_enabled_hint": "Send a short tool-call preview into the current chat before each tool execution.",
|
||||
"tool_feedback_max_args_length": "Tool Feedback Args Preview Length",
|
||||
"tool_feedback_max_args_length_hint": "Maximum number of argument characters shown in each tool feedback message. Set to 0 to use the default.",
|
||||
"exec_enabled": "Allow Commands",
|
||||
"exec_enabled_hint": "Enable or disable command execution for the app. When disabled, no command requests will run.",
|
||||
"allow_remote": "Allow Remote Commands",
|
||||
|
||||
@@ -592,9 +592,9 @@
|
||||
"split_on_marker": "连续短消息",
|
||||
"split_on_marker_hint": "像真人聊天一样,把长难句拆成多条短消息快速发出",
|
||||
"tool_feedback_enabled": "工具反馈",
|
||||
"tool_feedback_enabled_hint": "在每次执行工具前,先向当前会话发送一条简短的执行说明",
|
||||
"tool_feedback_max_args_length": "工具反馈长度",
|
||||
"tool_feedback_max_args_length_hint": "每条工具反馈消息中展示的字符上限。设为 0 时使用默认值",
|
||||
"tool_feedback_enabled_hint": "在每次执行工具前,先向当前会话发送一条简短的工具调用预览",
|
||||
"tool_feedback_max_args_length": "工具反馈参数预览长度",
|
||||
"tool_feedback_max_args_length_hint": "每条工具反馈消息中展示的参数字符上限。设为 0 时使用默认值",
|
||||
"exec_enabled": "允许命令执行",
|
||||
"exec_enabled_hint": "控制应用是否允许执行命令。关闭后,所有命令请求都不会执行",
|
||||
"allow_remote": "允许远程命令执行",
|
||||
|
||||
Reference in New Issue
Block a user