From df4f322f09239595bb9ddf163c885f936efff045 Mon Sep 17 00:00:00 2001 From: afjcjsbx Date: Sun, 22 Mar 2026 12:05:28 +0100 Subject: [PATCH 01/39] fix(tool): route binary outputs through the media pipeline. --- pkg/agent/loop.go | 145 +++++++++++++---- pkg/agent/loop_test.go | 296 +++++++++++++++++++++++++++++++++++ pkg/channels/manager.go | 42 ++++- pkg/channels/manager_test.go | 70 +++++++++ pkg/tools/mcp_tool.go | 280 +++++++++++++++++++++++++++++++-- pkg/tools/mcp_tool_test.go | 144 +++++++++++++++++ pkg/tools/normalization.go | 292 ++++++++++++++++++++++++++++++++++ pkg/tools/registry.go | 39 ++++- pkg/tools/registry_test.go | 111 +++++++++++++ pkg/tools/result.go | 58 ++++++- pkg/tools/result_test.go | 39 +++++ pkg/tools/send_file.go | 2 +- pkg/tools/send_file_test.go | 3 + pkg/tools/toolloop.go | 5 +- 14 files changed, 1462 insertions(+), 64 deletions(-) create mode 100644 pkg/tools/normalization.go diff --git a/pkg/agent/loop.go b/pkg/agent/loop.go index ed5c73afc..2b6672f9c 100644 --- a/pkg/agent/loop.go +++ b/pkg/agent/loop.go @@ -492,13 +492,13 @@ func (al *AgentLoop) GetConfig() *config.Config { func (al *AgentLoop) SetMediaStore(s media.MediaStore) { al.mediaStore = s - // Propagate store to send_file tools in all agents. + // Propagate store to all registered tools that can emit media. registry := al.GetRegistry() - registry.ForEachTool("send_file", func(t tools.Tool) { - if sf, ok := t.(*tools.SendFileTool); ok { - sf.SetMediaStore(s) + for _, agentID := range registry.ListAgentIDs() { + if agent, ok := registry.GetAgent(agentID); ok { + agent.Tools.SetMediaStore(s) } - }) + } } // SetTranscriber injects a voice transcriber for agent-level audio transcription. @@ -926,13 +926,26 @@ func (al *AgentLoop) runAgentLoop( agent.Sessions.AddMessage(opts.SessionKey, "user", opts.UserMessage) // 3. 
Run LLM iteration loop - finalContent, iteration, err := al.runLLMIteration(ctx, agent, messages, opts) + finalContent, iteration, responseHandled, err := al.runLLMIteration(ctx, agent, messages, opts) if err != nil { return "", err } - // If last tool had ForUser content and we already sent it, we might not need to send final response - // This is controlled by the tool's Silent flag and ForUser content + if responseHandled { + agent.Sessions.Save(opts.SessionKey) + + if opts.EnableSummary { + al.maybeSummarize(agent, opts.SessionKey, opts.Channel, opts.ChatID) + } + + logger.InfoCF("agent", "Response already handled by tool output", + map[string]any{ + "agent_id": agent.ID, + "session_key": opts.SessionKey, + "iterations": iteration, + }) + return "", nil + } // 4. Handle empty response if finalContent == "" { @@ -1030,14 +1043,57 @@ func (al *AgentLoop) handleReasoning( } } +const handledToolResponseSummary = "Requested output delivered via tool attachment." + +func (al *AgentLoop) buildOutboundMediaMessage( + channel string, + chatID string, + refs []string, +) bus.OutboundMediaMessage { + parts := make([]bus.MediaPart, 0, len(refs)) + for _, ref := range refs { + part := bus.MediaPart{Ref: ref} + if al.mediaStore != nil { + if _, meta, err := al.mediaStore.ResolveWithMeta(ref); err == nil { + part.Filename = meta.Filename + part.ContentType = meta.ContentType + part.Type = inferMediaType(meta.Filename, meta.ContentType) + } + } + parts = append(parts, part) + } + return bus.OutboundMediaMessage{ + Channel: channel, + ChatID: chatID, + Parts: parts, + } +} + +func (al *AgentLoop) buildArtifactTags(refs []string) []string { + if al.mediaStore == nil || len(refs) == 0 { + return nil + } + + tags := make([]string, 0, len(refs)) + for _, ref := range refs { + localPath, meta, err := al.mediaStore.ResolveWithMeta(ref) + if err != nil { + continue + } + mime := detectMIME(localPath, meta) + tags = append(tags, buildPathTag(mime, localPath)) + } + return tags +} + // 
runLLMIteration executes the LLM call loop with tool handling. -// Returns (finalContent, iteration, error). +// Returns (finalContent, iteration, responseHandled, error). func (al *AgentLoop) runLLMIteration( ctx context.Context, agent *AgentInstance, messages []providers.Message, opts processOptions, -) (string, int, error) { +) (string, int, bool, error) { iteration := 0 var finalContent string @@ -1240,7 +1296,7 @@ func (al *AgentLoop) runLLMIteration( "model": activeModel, "error": err.Error(), }) - return "", iteration, fmt.Errorf("LLM call failed after retries: %w", err) + return "", iteration, false, fmt.Errorf("LLM call failed after retries: %w", err) } go al.handleReasoning( @@ -1401,10 +1457,7 @@ func (al *AgentLoop) runLLMIteration( } // Determine content for the agent loop (ForLLM or error). - content := result.ForLLM - if content == "" && result.Err != nil { - content = result.Err.Error() - } + content := result.ContentForLLM() if content == "" { return } @@ -1439,8 +1492,14 @@ func (al *AgentLoop) runLLMIteration( } wg.Wait() + allResponsesHandled := len(agentResults) > 0 + // Process results in original order (send to user, save to session) for _, r := range agentResults { + if !r.result.ResponseHandled { + allResponsesHandled = false + } + // Send ForUser content to user immediately if not Silent if !r.result.Silent && r.result.ForUser != "" && opts.SendResponse { al.bus.PublishOutbound(ctx, bus.OutboundMessage{ @@ -1455,32 +1514,33 @@ func (al *AgentLoop) runLLMIteration( }) } - // If tool returned media refs, publish them as outbound media + // If tool returned media refs, publish them as outbound media only when the + // tool explicitly marked the user-visible delivery as already handled. 
if len(r.result.Media) > 0 { - parts := make([]bus.MediaPart, 0, len(r.result.Media)) - for _, ref := range r.result.Media { - part := bus.MediaPart{Ref: ref} - if al.mediaStore != nil { - if _, meta, err := al.mediaStore.ResolveWithMeta(ref); err == nil { - part.Filename = meta.Filename - part.ContentType = meta.ContentType - part.Type = inferMediaType(meta.Filename, meta.ContentType) + outboundMedia := al.buildOutboundMediaMessage(opts.Channel, opts.ChatID, r.result.Media) + if r.result.ResponseHandled { + if al.channelManager != nil { + if err := al.channelManager.SendMedia(ctx, outboundMedia); err != nil { + allResponsesHandled = false + logger.WarnCF("agent", "Synchronous media send failed, falling back to bus delivery", + map[string]any{ + "agent_id": agent.ID, + "tool": r.tc.Name, + "error": err.Error(), + }) + al.bus.PublishOutboundMedia(ctx, outboundMedia) } + } else { + al.bus.PublishOutboundMedia(ctx, outboundMedia) } - parts = append(parts, part) } - al.bus.PublishOutboundMedia(ctx, bus.OutboundMediaMessage{ - Channel: opts.Channel, - ChatID: opts.ChatID, - Parts: parts, - }) } // Determine content for LLM based on tool result - contentForLLM := r.result.ForLLM - if contentForLLM == "" && r.result.Err != nil { - contentForLLM = r.result.Err.Error() + if len(r.result.Media) > 0 && !r.result.ResponseHandled { + r.result.ArtifactTags = al.buildArtifactTags(r.result.Media) } + contentForLLM := r.result.ContentForLLM() toolResultMsg := providers.Message{ Role: "tool", @@ -1493,6 +1553,23 @@ func (al *AgentLoop) runLLMIteration( agent.Sessions.AddFullMessage(opts.SessionKey, toolResultMsg) } + if allResponsesHandled { + summaryMsg := providers.Message{ + Role: "assistant", + Content: handledToolResponseSummary, + } + messages = append(messages, summaryMsg) + agent.Sessions.AddFullMessage(opts.SessionKey, summaryMsg) + + logger.InfoCF("agent", "Tool output satisfied delivery; ending turn without follow-up LLM", + map[string]any{ + "agent_id": agent.ID, + 
"iteration": iteration, + "tool_count": len(agentResults), + }) + return "", iteration, true, nil + } + // Tick down TTL of discovered tools after processing tool results. // Only reached when tool calls were made (the loop continues); // the break on no-tool-call responses skips this. @@ -1505,7 +1582,7 @@ func (al *AgentLoop) runLLMIteration( }) } - return finalContent, iteration, nil + return finalContent, iteration, false, nil } // selectCandidates returns the model candidates and resolved model name to use diff --git a/pkg/agent/loop_test.go b/pkg/agent/loop_test.go index 28eab03db..ea63d9634 100644 --- a/pkg/agent/loop_test.go +++ b/pkg/agent/loop_test.go @@ -298,6 +298,152 @@ func TestToolRegistry_GetDefinitions(t *testing.T) { } } +func TestProcessMessage_MediaToolHandledSkipsFollowUpLLMAndFinalText(t *testing.T) { + tmpDir := t.TempDir() + cfg := &config.Config{ + Agents: config.AgentsConfig{ + Defaults: config.AgentDefaults{ + Workspace: tmpDir, + Model: "test-model", + MaxTokens: 4096, + MaxToolIterations: 10, + }, + }, + } + + msgBus := bus.NewMessageBus() + provider := &handledMediaProvider{} + al := NewAgentLoop(cfg, msgBus, provider) + + store := media.NewFileMediaStore() + al.SetMediaStore(store) + + imagePath := filepath.Join(tmpDir, "screen.png") + if err := os.WriteFile(imagePath, []byte("fake screenshot"), 0o644); err != nil { + t.Fatalf("WriteFile(imagePath) error = %v", err) + } + + al.RegisterTool(&handledMediaTool{ + store: store, + path: imagePath, + }) + + response, err := al.processMessage(context.Background(), bus.InboundMessage{ + Channel: "telegram", + ChatID: "chat1", + SenderID: "user1", + Content: "take a screenshot of the screen and send it to me", + }) + if err != nil { + t.Fatalf("processMessage() error = %v", err) + } + if response != "" { + t.Fatalf("expected no final response when media tool already handled delivery, got %q", response) + } + if provider.calls != 1 { + t.Fatalf("expected exactly 1 LLM call, got %d", 
provider.calls) + } + if len(provider.toolCounts) != 1 { + t.Fatalf("expected tool counts for 1 provider call, got %d", len(provider.toolCounts)) + } + if provider.toolCounts[0] == 0 { + t.Fatal("expected tools to be available on the first LLM call") + } + + select { + case mediaMsg := <-msgBus.OutboundMediaChan(): + if mediaMsg.Channel != "telegram" || mediaMsg.ChatID != "chat1" { + t.Fatalf("unexpected outbound media target: %+v", mediaMsg) + } + if len(mediaMsg.Parts) != 1 { + t.Fatalf("expected exactly 1 outbound media part, got %d", len(mediaMsg.Parts)) + } + default: + t.Fatal("expected outbound media message to be published") + } + + defaultAgent := al.GetRegistry().GetDefaultAgent() + if defaultAgent == nil { + t.Fatal("expected default agent") + } + route, _, err := al.resolveMessageRoute(bus.InboundMessage{ + Channel: "telegram", + ChatID: "chat1", + SenderID: "user1", + Content: "take a screenshot of the screen and send it to me", + }) + if err != nil { + t.Fatalf("resolveMessageRoute() error = %v", err) + } + sessionKey := resolveScopeKey(route, "") + history := defaultAgent.Sessions.GetHistory(sessionKey) + if len(history) == 0 { + t.Fatal("expected session history to be saved") + } + last := history[len(history)-1] + if last.Role != "assistant" || last.Content != handledToolResponseSummary { + t.Fatalf("expected handled assistant summary in history, got %+v", last) + } +} + +func TestProcessMessage_MediaArtifactCanBeForwardedBySendFile(t *testing.T) { + tmpDir := t.TempDir() + cfg := config.DefaultConfig() + cfg.Agents.Defaults.Workspace = tmpDir + cfg.Agents.Defaults.Model = "test-model" + cfg.Agents.Defaults.MaxTokens = 4096 + cfg.Agents.Defaults.MaxToolIterations = 10 + + msgBus := bus.NewMessageBus() + provider := &artifactThenSendProvider{} + al := NewAgentLoop(cfg, msgBus, provider) + + store := media.NewFileMediaStore() + al.SetMediaStore(store) + + mediaDir := media.TempDir() + if err := os.MkdirAll(mediaDir, 0o700); err != nil { + 
t.Fatalf("MkdirAll(mediaDir) error = %v", err) + } + imagePath := filepath.Join(mediaDir, "artifact-screen.png") + if err := os.WriteFile(imagePath, []byte("fake screenshot"), 0o644); err != nil { + t.Fatalf("WriteFile(imagePath) error = %v", err) + } + + al.RegisterTool(&mediaArtifactTool{ + store: store, + path: imagePath, + }) + + response, err := al.processMessage(context.Background(), bus.InboundMessage{ + Channel: "telegram", + ChatID: "chat1", + SenderID: "user1", + Content: "take a screenshot of the screen and send it to me", + }) + if err != nil { + t.Fatalf("processMessage() error = %v", err) + } + if response != "" { + t.Fatalf("expected no final response after send_file handled delivery, got %q", response) + } + if provider.calls != 2 { + t.Fatalf("expected 2 LLM calls (artifact + send_file), got %d", provider.calls) + } + + select { + case mediaMsg := <-msgBus.OutboundMediaChan(): + if mediaMsg.Channel != "telegram" || mediaMsg.ChatID != "chat1" { + t.Fatalf("unexpected outbound media target: %+v", mediaMsg) + } + if len(mediaMsg.Parts) != 1 { + t.Fatalf("expected exactly 1 outbound media part, got %d", len(mediaMsg.Parts)) + } + default: + t.Fatal("expected outbound media from send_file") + } +} + // TestAgentLoop_GetStartupInfo verifies startup info contains tools func TestAgentLoop_GetStartupInfo(t *testing.T) { tmpDir, err := os.MkdirTemp("", "agent-test-*") @@ -420,6 +566,98 @@ func (m *countingMockProvider) GetDefaultModel() string { return "counting-mock-model" } +type handledMediaProvider struct { + calls int + toolCounts []int +} + +func (m *handledMediaProvider) Chat( + ctx context.Context, + messages []providers.Message, + tools []providers.ToolDefinition, + model string, + opts map[string]any, +) (*providers.LLMResponse, error) { + m.calls++ + m.toolCounts = append(m.toolCounts, len(tools)) + if m.calls == 1 { + return &providers.LLMResponse{ + Content: "Taking the screenshot now.", + ToolCalls: []providers.ToolCall{{ + ID: 
"call_handled_media", + Type: "function", + Name: "handled_media_tool", + Arguments: map[string]any{}, + }}, + }, nil + } + return &providers.LLMResponse{}, nil +} + +func (m *handledMediaProvider) GetDefaultModel() string { + return "handled-media-model" +} + +type artifactThenSendProvider struct { + calls int +} + +func (m *artifactThenSendProvider) Chat( + ctx context.Context, + messages []providers.Message, + tools []providers.ToolDefinition, + model string, + opts map[string]any, +) (*providers.LLMResponse, error) { + m.calls++ + if m.calls == 1 { + return &providers.LLMResponse{ + Content: "Taking the screenshot now.", + ToolCalls: []providers.ToolCall{{ + ID: "call_artifact_media", + Type: "function", + Name: "media_artifact_tool", + Arguments: map[string]any{}, + }}, + }, nil + } + + var artifactPath string + for i := len(messages) - 1; i >= 0; i-- { + if messages[i].Role != "tool" { + continue + } + start := strings.Index(messages[i].Content, "[file:") + if start < 0 { + continue + } + rest := messages[i].Content[start+len("[file:"):] + end := strings.Index(rest, "]") + if end < 0 { + continue + } + artifactPath = rest[:end] + break + } + if artifactPath == "" { + return nil, fmt.Errorf("provider did not receive artifact path in tool result") + } + + return &providers.LLMResponse{ + Content: "", + ToolCalls: []providers.ToolCall{{ + ID: "call_send_file", + Type: "function", + Name: "send_file", + Arguments: map[string]any{"path": artifactPath}, + }}, + }, nil +} + +func (m *artifactThenSendProvider) GetDefaultModel() string { + return "artifact-then-send-model" +} + type toolLimitOnlyProvider struct{} func (m *toolLimitOnlyProvider) Chat( @@ -465,6 +703,64 @@ func (m *mockCustomTool) Execute(ctx context.Context, args map[string]any) *tool return tools.SilentResult("Custom tool executed") } +type handledMediaTool struct { + store media.MediaStore + path string +} + +func (m *handledMediaTool) Name() string { return "handled_media_tool" } +func (m 
*handledMediaTool) Description() string { + return "Returns a media attachment and fully handles the user response" +} + +func (m *handledMediaTool) Parameters() map[string]any { + return map[string]any{ + "type": "object", + "properties": map[string]any{}, + } +} + +func (m *handledMediaTool) Execute(ctx context.Context, args map[string]any) *tools.ToolResult { + ref, err := m.store.Store(m.path, media.MediaMeta{ + Filename: filepath.Base(m.path), + ContentType: "image/png", + Source: "test:handled_media_tool", + }, "test:handled_media") + if err != nil { + return tools.ErrorResult(err.Error()).WithError(err) + } + return tools.MediaResult("Attachment delivered by tool.", []string{ref}).WithResponseHandled() +} + +type mediaArtifactTool struct { + store media.MediaStore + path string +} + +func (m *mediaArtifactTool) Name() string { return "media_artifact_tool" } +func (m *mediaArtifactTool) Description() string { + return "Returns a media artifact that the agent can forward or save later" +} + +func (m *mediaArtifactTool) Parameters() map[string]any { + return map[string]any{ + "type": "object", + "properties": map[string]any{}, + } +} + +func (m *mediaArtifactTool) Execute(ctx context.Context, args map[string]any) *tools.ToolResult { + ref, err := m.store.Store(m.path, media.MediaMeta{ + Filename: filepath.Base(m.path), + ContentType: "image/png", + Source: "test:media_artifact_tool", + }, "test:media_artifact") + if err != nil { + return tools.ErrorResult(err.Error()).WithError(err) + } + return tools.MediaResult("Artifact created.", []string{ref}) +} + type toolLimitTestTool struct{} func (m *toolLimitTestTool) Name() string { diff --git a/pkg/channels/manager.go b/pkg/channels/manager.go index ff3fa399c..f4a64807e 100644 --- a/pkg/channels/manager.go +++ b/pkg/channels/manager.go @@ -771,7 +771,7 @@ func (m *Manager) runMediaWorker(ctx context.Context, name string, w *channelWor if !ok { return } - m.sendMediaWithRetry(ctx, name, w, msg) + _ = 
m.sendMediaWithRetry(ctx, name, w, msg) case <-ctx.Done(): return } @@ -779,26 +779,31 @@ func (m *Manager) runMediaWorker(ctx context.Context, name string, w *channelWor } // sendMediaWithRetry sends a media message through the channel with rate limiting and -// retry logic. If the channel does not implement MediaSender, it silently skips. -func (m *Manager) sendMediaWithRetry(ctx context.Context, name string, w *channelWorker, msg bus.OutboundMediaMessage) { +// retry logic. It returns nil on success, or the last error after retries. +func (m *Manager) sendMediaWithRetry( + ctx context.Context, + name string, + w *channelWorker, + msg bus.OutboundMediaMessage, +) error { ms, ok := w.ch.(MediaSender) if !ok { logger.DebugCF("channels", "Channel does not support MediaSender, skipping media", map[string]any{ "channel": name, }) - return + return nil } // Rate limit: wait for token if err := w.limiter.Wait(ctx); err != nil { - return + return err } var lastErr error for attempt := 0; attempt <= maxRetries; attempt++ { lastErr = ms.SendMedia(ctx, msg) if lastErr == nil { - return + return nil } // Permanent failures — don't retry @@ -817,7 +822,7 @@ func (m *Manager) sendMediaWithRetry(ctx context.Context, name string, w *channe case <-time.After(rateLimitDelay): continue case <-ctx.Done(): - return + return ctx.Err() } } @@ -826,7 +831,7 @@ func (m *Manager) sendMediaWithRetry(ctx context.Context, name string, w *channe select { case <-time.After(backoff): case <-ctx.Done(): - return + return ctx.Err() } } @@ -837,6 +842,7 @@ func (m *Manager) sendMediaWithRetry(ctx context.Context, name string, w *channe "error": lastErr.Error(), "retries": maxRetries, }) + return lastErr } // runTTLJanitor periodically scans the typingStops and placeholders maps @@ -1029,6 +1035,26 @@ func (m *Manager) SendMessage(ctx context.Context, msg bus.OutboundMessage) erro return nil } +// SendMedia sends outbound media synchronously through the channel worker's +// rate limiter and retry 
logic. It blocks until the media is delivered (or all +// retries are exhausted), which preserves ordering when later agent behavior +// depends on actual media delivery. +func (m *Manager) SendMedia(ctx context.Context, msg bus.OutboundMediaMessage) error { + m.mu.RLock() + _, exists := m.channels[msg.Channel] + w, wExists := m.workers[msg.Channel] + m.mu.RUnlock() + + if !exists { + return fmt.Errorf("channel %s not found", msg.Channel) + } + if !wExists || w == nil { + return fmt.Errorf("channel %s has no active worker", msg.Channel) + } + + return m.sendMediaWithRetry(ctx, msg.Channel, w, msg) +} + func (m *Manager) SendToChannel(ctx context.Context, channelName, chatID, content string) error { m.mu.RLock() _, exists := m.channels[channelName] diff --git a/pkg/channels/manager_test.go b/pkg/channels/manager_test.go index 7dfec9ebf..6a5dd7e30 100644 --- a/pkg/channels/manager_test.go +++ b/pkg/channels/manager_test.go @@ -43,6 +43,20 @@ func (m *mockChannel) EditMessage(ctx context.Context, chatID, messageID, conten return nil } +type mockMediaChannel struct { + mockChannel + sendMediaFn func(ctx context.Context, msg bus.OutboundMediaMessage) error + sentMediaMessages []bus.OutboundMediaMessage +} + +func (m *mockMediaChannel) SendMedia(ctx context.Context, msg bus.OutboundMediaMessage) error { + m.sentMediaMessages = append(m.sentMediaMessages, msg) + if m.sendMediaFn != nil { + return m.sendMediaFn(ctx, msg) + } + return nil +} + // newTestManager creates a minimal Manager suitable for unit tests. 
func newTestManager() *Manager { return &Manager{ @@ -208,6 +222,62 @@ func TestSendWithRetry_MaxRetriesExhausted(t *testing.T) { } } +func TestSendMedia_Success(t *testing.T) { + m := newTestManager() + var callCount int + ch := &mockMediaChannel{ + sendMediaFn: func(_ context.Context, _ bus.OutboundMediaMessage) error { + callCount++ + return nil + }, + } + w := &channelWorker{ + ch: ch, + limiter: rate.NewLimiter(rate.Inf, 1), + } + m.channels["test"] = ch + m.workers["test"] = w + + err := m.SendMedia(context.Background(), bus.OutboundMediaMessage{ + Channel: "test", + ChatID: "chat1", + Parts: []bus.MediaPart{{Ref: "media://abc"}}, + }) + if err != nil { + t.Fatalf("SendMedia() error = %v", err) + } + if callCount != 1 { + t.Fatalf("expected 1 SendMedia call, got %d", callCount) + } +} + +func TestSendMedia_PropagatesFailure(t *testing.T) { + m := newTestManager() + ch := &mockMediaChannel{ + sendMediaFn: func(_ context.Context, _ bus.OutboundMediaMessage) error { + return fmt.Errorf("bad upload: %w", ErrSendFailed) + }, + } + w := &channelWorker{ + ch: ch, + limiter: rate.NewLimiter(rate.Inf, 1), + } + m.channels["test"] = ch + m.workers["test"] = w + + err := m.SendMedia(context.Background(), bus.OutboundMediaMessage{ + Channel: "test", + ChatID: "chat1", + Parts: []bus.MediaPart{{Ref: "media://abc"}}, + }) + if err == nil { + t.Fatal("expected SendMedia to return error") + } + if !errors.Is(err, ErrSendFailed) { + t.Fatalf("expected ErrSendFailed, got %v", err) + } +} + func TestSendWithRetry_UnknownError(t *testing.T) { m := newTestManager() var callCount int diff --git a/pkg/tools/mcp_tool.go b/pkg/tools/mcp_tool.go index 6e53cf354..5bffb4e89 100644 --- a/pkg/tools/mcp_tool.go +++ b/pkg/tools/mcp_tool.go @@ -5,9 +5,13 @@ import ( "encoding/json" "fmt" "hash/fnv" + "os" "strings" + "time" "github.com/modelcontextprotocol/go-sdk/mcp" + + "github.com/sipeed/picoclaw/pkg/media" ) // MCPManager defines the interface for MCP manager operations @@ -25,6 +29,7 @@ 
type MCPTool struct { manager MCPManager serverName string tool *mcp.Tool + mediaStore media.MediaStore } // NewMCPTool creates a new MCP tool wrapper @@ -36,6 +41,10 @@ func NewMCPTool(manager MCPManager, serverName string, tool *mcp.Tool) *MCPTool } } +func (t *MCPTool) SetMediaStore(store media.MediaStore) { + t.mediaStore = store +} + // sanitizeIdentifierComponent normalizes a string so it can be safely used // as part of a tool/function identifier for downstream providers. // It: @@ -218,13 +227,7 @@ func (t *MCPTool) Execute(ctx context.Context, args map[string]any) *ToolResult WithError(fmt.Errorf("MCP tool error: %s", errMsg)) } - // Extract text content from result - output := extractContentText(result.Content) - - return &ToolResult{ - ForLLM: output, - IsError: false, - } + return t.normalizeResultContent(ctx, result.Content) } // extractContentText extracts text from MCP content array @@ -233,14 +236,269 @@ func extractContentText(content []mcp.Content) string { for _, c := range content { switch v := c.(type) { case *mcp.TextContent: - parts = append(parts, v.Text) + parts = append(parts, sanitizeToolLLMContent(v.Text)) case *mcp.ImageContent: - // For images, just indicate that an image was returned - parts = append(parts, fmt.Sprintf("[Image: %s]", v.MIMEType)) + parts = append(parts, fmt.Sprintf("[Image: %s]", normalizedMIMEType(v.MIMEType))) + case *mcp.AudioContent: + parts = append(parts, fmt.Sprintf("[Audio: %s]", normalizedMIMEType(v.MIMEType))) + case *mcp.ResourceLink: + parts = append(parts, summarizeResourceLink(v)) + case *mcp.EmbeddedResource: + parts = append(parts, summarizeEmbeddedResource(v)) default: // For other content types, use string representation parts = append(parts, fmt.Sprintf("[Content: %T]", v)) } } - return strings.Join(parts, "\n") + return sanitizeToolLLMContent(strings.Join(parts, "\n")) +} + +func (t *MCPTool) normalizeResultContent(ctx context.Context, content []mcp.Content) *ToolResult { + llmParts := 
make([]string, 0, len(content)) + mediaRefs := make([]string, 0, len(content)) + + for _, c := range content { + switch v := c.(type) { + case *mcp.TextContent: + text := strings.TrimSpace(sanitizeToolLLMContent(v.Text)) + if text != "" { + llmParts = append(llmParts, text) + } + case *mcp.ImageContent: + ref, note := t.storeBinaryContent( + ctx, + "image", + normalizedMIMEType(v.MIMEType), + v.Data, + v.Annotations, + ) + if ref != "" { + mediaRefs = append(mediaRefs, ref) + } + if note != "" { + llmParts = append(llmParts, note) + } + case *mcp.AudioContent: + ref, note := t.storeBinaryContent( + ctx, + "audio", + normalizedMIMEType(v.MIMEType), + v.Data, + v.Annotations, + ) + if ref != "" { + mediaRefs = append(mediaRefs, ref) + } + if note != "" { + llmParts = append(llmParts, note) + } + case *mcp.ResourceLink: + llmParts = append(llmParts, summarizeResourceLink(v)) + case *mcp.EmbeddedResource: + ref, note := t.storeEmbeddedResource(ctx, v) + if ref != "" { + mediaRefs = append(mediaRefs, ref) + } + if note != "" { + llmParts = append(llmParts, note) + } + default: + llmParts = append(llmParts, fmt.Sprintf("[MCP returned unsupported content type %T]", v)) + } + } + + result := &ToolResult{ + ForLLM: strings.Join(compactStrings(llmParts), "\n"), + Media: mediaRefs, + } + return result +} + +func (t *MCPTool) storeEmbeddedResource(ctx context.Context, content *mcp.EmbeddedResource) (string, string) { + if content == nil || content.Resource == nil { + return "", "[MCP returned an embedded resource without data.]" + } + + resource := content.Resource + if len(resource.Blob) > 0 { + return t.storeBinaryContent( + ctx, + "resource", + normalizedMIMEType(resource.MIMEType), + resource.Blob, + content.Annotations, + ) + } + + if strings.TrimSpace(resource.Text) != "" { + return "", sanitizeToolLLMContent(resource.Text) + } + + return "", summarizeEmbeddedResource(content) +} + +func (t *MCPTool) storeBinaryContent( + ctx context.Context, + kind string, + mimeType 
string, + data []byte, + annotations *mcp.Annotations, +) (string, string) { + if len(data) == 0 { + return "", fmt.Sprintf("[MCP returned %s content (%s) but it was empty.]", kind, mimeType) + } + if !annotationsAllowUser(annotations) { + return "", fmt.Sprintf( + "[MCP returned %s content (%s) for non-user audience; omitted from model context.]", + kind, + mimeType, + ) + } + if t.mediaStore == nil { + return "", fmt.Sprintf( + "[MCP returned %s content (%s); omitted from model context because media delivery is unavailable.]", + kind, + mimeType, + ) + } + + channel := ToolChannel(ctx) + chatID := ToolChatID(ctx) + if channel == "" || chatID == "" { + return "", fmt.Sprintf( + "[MCP returned %s content (%s); omitted from model context because no target chat was available.]", + kind, + mimeType, + ) + } + + dir := media.TempDir() + if err := os.MkdirAll(dir, 0o700); err != nil { + return "", fmt.Sprintf("[MCP returned %s content (%s) but it could not be stored.]", kind, mimeType) + } + + ext := extensionForMIMEType(mimeType) + tmpFile, err := os.CreateTemp(dir, "mcp-*"+ext) + if err != nil { + return "", fmt.Sprintf("[MCP returned %s content (%s) but it could not be stored.]", kind, mimeType) + } + tmpPath := tmpFile.Name() + if _, err = tmpFile.Write(data); err != nil { + _ = tmpFile.Close() + _ = os.Remove(tmpPath) + return "", fmt.Sprintf("[MCP returned %s content (%s) but it could not be stored.]", kind, mimeType) + } + if err = tmpFile.Close(); err != nil { + _ = os.Remove(tmpPath) + return "", fmt.Sprintf("[MCP returned %s content (%s) but it could not be stored.]", kind, mimeType) + } + + scope := fmt.Sprintf( + "tool:mcp:%s:%s:%s:%d", + sanitizeIdentifierComponent(t.serverName), + channel, + chatID, + time.Now().UnixNano(), + ) + filename := fmt.Sprintf( + "%s_%s%s", + sanitizeIdentifierComponent(t.serverName), + sanitizeIdentifierComponent(t.tool.Name), + ext, + ) + + ref, err := t.mediaStore.Store(tmpPath, media.MediaMeta{ + Filename: filename, + 
ContentType: mimeType, + Source: fmt.Sprintf( + "tool:mcp:%s:%s", + sanitizeIdentifierComponent(t.serverName), + sanitizeIdentifierComponent(t.tool.Name), + ), + }, scope) + if err != nil { + _ = os.Remove(tmpPath) + return "", fmt.Sprintf( + "[MCP returned %s content (%s) but it could not be registered as media.]", + kind, + mimeType, + ) + } + + return ref, fmt.Sprintf( + "[MCP returned %s content (%s); omitted from model context and stored as a local media artifact.]", + kind, + mimeType, + ) +} + +func summarizeResourceLink(content *mcp.ResourceLink) string { + if content == nil { + return "[MCP returned an empty resource link.]" + } + + parts := []string{"[MCP returned resource link"} + if content.Name != "" { + parts = append(parts, fmt.Sprintf("name=%q", content.Name)) + } + if content.URI != "" { + parts = append(parts, fmt.Sprintf("uri=%q", content.URI)) + } + if content.MIMEType != "" { + parts = append(parts, fmt.Sprintf("mime=%q", content.MIMEType)) + } + if content.Description != "" { + desc := strings.TrimSpace(content.Description) + if len(desc) > 200 { + desc = desc[:200] + "..." 
+ } + parts = append(parts, fmt.Sprintf("description=%q", desc)) + } + return strings.Join(parts, ", ") + "]" +} + +func summarizeEmbeddedResource(content *mcp.EmbeddedResource) string { + if content == nil || content.Resource == nil { + return "[MCP returned an embedded resource.]" + } + + resource := content.Resource + if resource.URI != "" { + return fmt.Sprintf( + "[MCP returned embedded resource %q (%s).]", + resource.URI, + normalizedMIMEType(resource.MIMEType), + ) + } + return fmt.Sprintf("[MCP returned embedded resource (%s).]", normalizedMIMEType(resource.MIMEType)) +} + +func annotationsAllowUser(annotations *mcp.Annotations) bool { + if annotations == nil || len(annotations.Audience) == 0 { + return true + } + for _, audience := range annotations.Audience { + if strings.EqualFold(string(audience), "user") { + return true + } + } + return false +} + +func normalizedMIMEType(mimeType string) string { + if strings.TrimSpace(mimeType) == "" { + return "application/octet-stream" + } + return mimeType +} + +func compactStrings(parts []string) []string { + compact := make([]string, 0, len(parts)) + for _, part := range parts { + if strings.TrimSpace(part) == "" { + continue + } + compact = append(compact, part) + } + return compact } diff --git a/pkg/tools/mcp_tool_test.go b/pkg/tools/mcp_tool_test.go index 95bb0f992..8bbac3bc7 100644 --- a/pkg/tools/mcp_tool_test.go +++ b/pkg/tools/mcp_tool_test.go @@ -3,10 +3,14 @@ package tools import ( "context" "fmt" + "os" + "path/filepath" "strings" "testing" "github.com/modelcontextprotocol/go-sdk/mcp" + + "github.com/sipeed/picoclaw/pkg/media" ) // MockMCPManager is a mock implementation of MCPManager interface for testing @@ -490,3 +494,143 @@ func TestMCPTool_Parameters_MapSchema(t *testing.T) { t.Errorf("Name type should be 'string', got '%v'", nameParam["type"]) } } + +func TestMCPTool_Execute_ImageContentStoredAsMedia(t *testing.T) { + store := media.NewFileMediaStore() + manager := &MockMCPManager{ + 
callToolFunc: func(ctx context.Context, serverName, toolName string, arguments map[string]any) (*mcp.CallToolResult, error) { + return &mcp.CallToolResult{ + Content: []mcp.Content{ + &mcp.ImageContent{ + Data: []byte("fake-image-bytes"), + MIMEType: "image/png", + }, + }, + }, nil + }, + } + + mcpTool := NewMCPTool(manager, "screenshoto", &mcp.Tool{Name: "take_screenshot"}) + mcpTool.SetMediaStore(store) + + result := mcpTool.Execute(WithToolContext(context.Background(), "telegram", "chat-42"), nil) + + if result.IsError { + t.Fatalf("expected success, got %q", result.ForLLM) + } + if len(result.Media) != 1 { + t.Fatalf("expected 1 media ref, got %d", len(result.Media)) + } + if result.ResponseHandled { + t.Fatal("expected MCP image artifact not to mark response as handled") + } + if !strings.Contains(result.ForLLM, "stored as a local media artifact") { + t.Fatalf("expected local media artifact note, got %q", result.ForLLM) + } + + path, meta, err := store.ResolveWithMeta(result.Media[0]) + if err != nil { + t.Fatalf("expected stored media ref to resolve: %v", err) + } + if meta.ContentType != "image/png" { + t.Fatalf("expected image/png content type, got %q", meta.ContentType) + } + if filepath.Ext(path) != ".png" { + t.Fatalf("expected png temp file, got %q", path) + } + data, err := os.ReadFile(path) + if err != nil { + t.Fatalf("expected stored media file to be readable: %v", err) + } + if string(data) != "fake-image-bytes" { + t.Fatalf("expected stored media bytes to match input, got %q", string(data)) + } +} + +func TestMCPTool_Execute_EmbeddedResourceBlobStoredAsMedia(t *testing.T) { + store := media.NewFileMediaStore() + manager := &MockMCPManager{ + callToolFunc: func(ctx context.Context, serverName, toolName string, arguments map[string]any) (*mcp.CallToolResult, error) { + return &mcp.CallToolResult{ + Content: []mcp.Content{ + &mcp.EmbeddedResource{ + Resource: &mcp.ResourceContents{ + URI: "file:///tmp/report.png", + MIMEType: "image/png", + Blob: 
[]byte("blob-bytes"), + }, + }, + }, + }, nil + }, + } + + mcpTool := NewMCPTool(manager, "grafana", &mcp.Tool{Name: "get_dashboard_image"}) + mcpTool.SetMediaStore(store) + + result := mcpTool.Execute(WithToolContext(context.Background(), "telegram", "chat-42"), nil) + + if len(result.Media) != 1 { + t.Fatalf("expected embedded resource blob to be stored as media, got %d refs", len(result.Media)) + } + path, _, err := store.ResolveWithMeta(result.Media[0]) + if err != nil { + t.Fatalf("expected stored media ref to resolve: %v", err) + } + data, err := os.ReadFile(path) + if err != nil { + t.Fatalf("expected stored media file to be readable: %v", err) + } + if string(data) != "blob-bytes" { + t.Fatalf("expected stored blob bytes to match input, got %q", string(data)) + } +} + +func TestMCPTool_Execute_RespectsUserAudienceForBinaryContent(t *testing.T) { + store := media.NewFileMediaStore() + manager := &MockMCPManager{ + callToolFunc: func(ctx context.Context, serverName, toolName string, arguments map[string]any) (*mcp.CallToolResult, error) { + return &mcp.CallToolResult{ + Content: []mcp.Content{ + &mcp.ImageContent{ + Data: []byte("assistant-only"), + MIMEType: "image/png", + Annotations: &mcp.Annotations{Audience: []mcp.Role{"assistant"}}, + }, + }, + }, nil + }, + } + + mcpTool := NewMCPTool(manager, "screenshoto", &mcp.Tool{Name: "take_screenshot"}) + mcpTool.SetMediaStore(store) + + result := mcpTool.Execute(WithToolContext(context.Background(), "telegram", "chat-42"), nil) + + if len(result.Media) != 0 { + t.Fatalf("expected no media ref for non-user audience, got %d", len(result.Media)) + } + if !strings.Contains(result.ForLLM, "non-user audience") { + t.Fatalf("expected audience note, got %q", result.ForLLM) + } +} + +func TestMCPTool_Execute_LargeBase64TextIsOmittedFromContext(t *testing.T) { + manager := &MockMCPManager{ + callToolFunc: func(ctx context.Context, serverName, toolName string, arguments map[string]any) (*mcp.CallToolResult, error) { + 
return &mcp.CallToolResult{ + Content: []mcp.Content{ + &mcp.TextContent{Text: strings.Repeat("QUJD", 400)}, + }, + }, nil + }, + } + + mcpTool := NewMCPTool(manager, "test_server", &mcp.Tool{Name: "dump_payload"}) + + result := mcpTool.Execute(context.Background(), nil) + + if result.ForLLM != largeBase64OmittedMessage { + t.Fatalf("expected sanitized large base64 note, got %q", result.ForLLM) + } +} diff --git a/pkg/tools/normalization.go b/pkg/tools/normalization.go new file mode 100644 index 000000000..3a76c5d92 --- /dev/null +++ b/pkg/tools/normalization.go @@ -0,0 +1,292 @@ +package tools + +import ( + "encoding/base64" + "fmt" + "mime" + "os" + "path/filepath" + "regexp" + "strings" + "time" + "unicode" + + "github.com/sipeed/picoclaw/pkg/media" +) + +const ( + largeBase64OmittedMessage = "[Tool returned a large base64-like payload; omitted from model context.]" + inlineMediaOmittedMessage = "[Tool returned inline media content; omitted from model context.]" + inlineMediaStoredMessage = "[Tool returned inline media content (%s); omitted from model context and registered as a media attachment.]" +) + +var ( + inlineMarkdownDataURLRe = regexp.MustCompile(`!\[[^\]]*\]\((data:[^)]+)\)`) + inlineRawDataURLRe = regexp.MustCompile(`data:[^;\s]+;base64,[A-Za-z0-9+/=\r\n]+`) +) + +func normalizeToolResult( + result *ToolResult, + toolName string, + store media.MediaStore, + channel string, + chatID string, +) *ToolResult { + if result == nil { + return nil + } + + notes := make([]string, 0, 2) + seen := make(map[string]struct{}) + + if store != nil && channel != "" && chatID != "" { + var refs []string + var extractedNotes []string + + result.ForLLM, refs, extractedNotes = extractInlineMediaRefs( + result.ForLLM, + toolName, + store, + channel, + chatID, + seen, + ) + result.Media = append(result.Media, refs...) + notes = append(notes, extractedNotes...) 
// looksLikeLargeBase64Payload is a heuristic that reports whether text looks
// like a big base64 blob rather than ordinary prose.
//
// After trimming surrounding whitespace the text must be at least 1024 bytes
// long, at least 97% of its non-whitespace runes must belong to the base64
// alphabet (A-Z, a-z, 0-9, '+', '/', '='), and it may contain at most one
// whitespace rune per 128 bytes.
func looksLikeLargeBase64Payload(text string) bool {
	const (
		minPayloadBytes    = 1024
		minAlphabetRatio   = 0.97
		bytesPerWhitespace = 128
	)

	body := strings.TrimSpace(text)
	if len(body) < minPayloadBytes {
		return false
	}

	var whitespace, visible, inAlphabet int
	for _, ch := range body {
		if unicode.IsSpace(ch) {
			whitespace++
			continue
		}
		visible++
		switch {
		case 'A' <= ch && ch <= 'Z',
			'a' <= ch && ch <= 'z',
			'0' <= ch && ch <= '9',
			ch == '+', ch == '/', ch == '=':
			inAlphabet++
		}
	}

	if visible == 0 {
		return false
	}
	if whitespace > len(body)/bytesPerWhitespace {
		return false
	}
	return float64(inAlphabet)/float64(visible) >= minAlphabetRatio
}
// extensionForMIMEType returns the file extension, including the leading dot,
// to use when storing content of the given MIME type.
//
// Lookup order:
//  1. An empty or whitespace-only type maps to ".bin".
//  2. A small table of preferred extensions for common media types. This is
//     consulted first because mime.ExtensionsByType returns every registered
//     extension and its first entry is not necessarily the conventional one
//     (image/jpeg can yield ".jpeg", ".jfif" or ".jpe" depending on the host's
//     MIME database).
//  3. The first extension reported by mime.ExtensionsByType.
//  4. filepath.Ext applied to the type string, kept as the legacy last resort;
//     it returns "" when the type contains no dot.
//
// Matching is case-insensitive and ignores media-type parameters such as
// "; charset=binary".
func extensionForMIMEType(mimeType string) string {
	normalized := strings.ToLower(strings.TrimSpace(mimeType))
	// ParseMediaType lowercases the type and strips any parameters; when the
	// input does not parse, fall back to the trimmed, lowercased string.
	if base, _, err := mime.ParseMediaType(mimeType); err == nil {
		normalized = base
	}
	if normalized == "" {
		return ".bin"
	}

	switch normalized {
	case "image/jpeg":
		return ".jpg"
	case "image/png":
		return ".png"
	case "image/gif":
		return ".gif"
	case "image/webp":
		return ".webp"
	case "audio/wav", "audio/x-wav":
		return ".wav"
	case "audio/mpeg":
		return ".mp3"
	case "audio/ogg":
		return ".ogg"
	case "video/mp4":
		return ".mp4"
	}

	if exts, err := mime.ExtensionsByType(mimeType); err == nil && len(exts) > 0 {
		return exts[0]
	}

	return filepath.Ext(mimeType)
}
+func (r *ToolRegistry) SetMediaStore(store media.MediaStore) { + r.mu.Lock() + defer r.mu.Unlock() + + r.mediaStore = store + for _, entry := range r.tools { + if aware, ok := entry.Tool.(mediaStoreAware); ok { + aware.SetMediaStore(store) + } + } +} + // PromoteTools atomically sets the TTL for multiple non-core tools. // This prevents a concurrent TickTTL from decrementing between promotions. func (r *ToolRegistry) PromoteTools(names []string, ttl int) { @@ -230,6 +256,8 @@ func (r *ToolRegistry) ExecuteWithContext( } } + result = normalizeToolResult(result, name, r.mediaStore, channel, chatID) + duration := time.Since(start) // Log based on result type @@ -251,7 +279,7 @@ func (r *ToolRegistry) ExecuteWithContext( map[string]any{ "tool": name, "duration_ms": duration.Milliseconds(), - "result_length": len(result.ForLLM), + "result_length": len(result.ContentForLLM()), }) } @@ -346,7 +374,8 @@ func (r *ToolRegistry) Clone() *ToolRegistry { r.mu.RLock() defer r.mu.RUnlock() clone := &ToolRegistry{ - tools: make(map[string]*ToolEntry, len(r.tools)), + tools: make(map[string]*ToolEntry, len(r.tools)), + mediaStore: r.mediaStore, } for name, entry := range r.tools { clone.tools[name] = &ToolEntry{ diff --git a/pkg/tools/registry_test.go b/pkg/tools/registry_test.go index 967758dfa..db52749f6 100644 --- a/pkg/tools/registry_test.go +++ b/pkg/tools/registry_test.go @@ -3,10 +3,13 @@ package tools import ( "context" "errors" + "os" + "path/filepath" "strings" "sync" "testing" + "github.com/sipeed/picoclaw/pkg/media" "github.com/sipeed/picoclaw/pkg/providers" ) @@ -46,6 +49,15 @@ func (m *mockAsyncRegistryTool) ExecuteAsync(_ context.Context, args map[string] return m.result } +type mockMediaStoreAwareTool struct { + mockRegistryTool + store media.MediaStore +} + +func (m *mockMediaStoreAwareTool) SetMediaStore(store media.MediaStore) { + m.store = store +} + // --- helpers --- func newMockTool(name, desc string) *mockRegistryTool { @@ -621,3 +633,102 @@ func 
TestToolRegistry_Execute_PanicDoesNotAffectOtherTools(t *testing.T) { t.Errorf("expected 'success', got %q", result2.ForLLM) } } + +func TestToolRegistry_SetMediaStore_PropagatesToExistingAndNewTools(t *testing.T) { + r := NewToolRegistry() + store := media.NewFileMediaStore() + + existing := &mockMediaStoreAwareTool{ + mockRegistryTool: *newMockTool("existing", "existing tool"), + } + r.Register(existing) + + r.SetMediaStore(store) + if existing.store != store { + t.Fatal("expected existing tool to receive media store") + } + + later := &mockMediaStoreAwareTool{ + mockRegistryTool: *newMockTool("later", "later tool"), + } + r.Register(later) + + if later.store != store { + t.Fatal("expected newly registered tool to inherit media store") + } +} + +func TestToolRegistry_ExecuteWithContext_SanitizesLargeBase64Payload(t *testing.T) { + r := NewToolRegistry() + payload := strings.Repeat("QUJD", 400) + r.Register(&mockRegistryTool{ + name: "base64_tool", + desc: "returns huge base64", + params: map[string]any{}, + result: SilentResult(payload), + }) + + result := r.ExecuteWithContext(context.Background(), "base64_tool", nil, "telegram", "chat-1", nil) + + if result.ForLLM != largeBase64OmittedMessage { + t.Fatalf("expected sanitized payload, got %q", result.ForLLM) + } +} + +func TestToolRegistry_ExecuteWithContext_ExtractsInlineMediaDataURL(t *testing.T) { + r := NewToolRegistry() + store := media.NewFileMediaStore() + r.SetMediaStore(store) + + payload := "![screenshot](data:image/png;base64,aGVsbG8=)" + r.Register(&mockRegistryTool{ + name: "inline_media_tool", + desc: "returns inline data url", + params: map[string]any{}, + result: SilentResult(payload), + }) + + result := r.ExecuteWithContext(context.Background(), "inline_media_tool", nil, "telegram", "chat-42", nil) + + if len(result.Media) != 1 { + t.Fatalf("expected 1 media ref, got %d", len(result.Media)) + } + if strings.Contains(result.ForLLM, "data:image/png;base64") { + t.Fatalf("expected inline data URL to 
be stripped from ForLLM, got %q", result.ForLLM) + } + if !strings.Contains(result.ForLLM, "registered as a media attachment") { + t.Fatalf("expected delivery note in ForLLM, got %q", result.ForLLM) + } + + path, err := store.Resolve(result.Media[0]) + if err != nil { + t.Fatalf("expected stored media ref to resolve: %v", err) + } + if _, err := os.Stat(path); err != nil { + t.Fatalf("expected stored media file to exist: %v", err) + } + if filepath.Ext(path) != ".png" { + t.Fatalf("expected stored inline media to use png extension, got %q", path) + } +} + +func TestToolRegistry_ExecuteWithContext_SanitizesInlineMediaWithoutStore(t *testing.T) { + r := NewToolRegistry() + + payload := "before ![img](data:image/png;base64,aGVsbG8=) after" + r.Register(&mockRegistryTool{ + name: "inline_media_no_store", + desc: "returns inline data url without store", + params: map[string]any{}, + result: SilentResult(payload), + }) + + result := r.ExecuteWithContext(context.Background(), "inline_media_no_store", nil, "telegram", "chat-42", nil) + + if strings.Contains(result.ForLLM, "data:image/png;base64") { + t.Fatalf("expected inline data URL to be removed from ForLLM, got %q", result.ForLLM) + } + if !strings.Contains(result.ForLLM, inlineMediaOmittedMessage) { + t.Fatalf("expected inline media omission note, got %q", result.ForLLM) + } +} diff --git a/pkg/tools/result.go b/pkg/tools/result.go index cab833284..f75ce6d40 100644 --- a/pkg/tools/result.go +++ b/pkg/tools/result.go @@ -1,6 +1,14 @@ package tools -import "encoding/json" +import ( + "encoding/json" + "strings" +) + +const ( + handledToolLLMNote = "The requested output has already been delivered to the user in the current chat. Do not call send_file or any other delivery tool again. If you reply, provide only a brief confirmation." + artifactPathsLLMNote = "Use `send_file` with one of these paths to send it to the user, or use file/exec tools to save it inside the workspace if requested." 
+) // ToolResult represents the structured return value from tool execution. // It provides clear semantics for different types of results and supports @@ -34,6 +42,48 @@ type ToolResult struct { // Media contains media store refs produced by this tool. // When non-empty, the agent will publish these as OutboundMediaMessage. Media []string `json:"media,omitempty"` + + // ArtifactTags exposes local artifact paths back to the LLM in a structured + // form, e.g. "[file:/tmp/example.png]". This is used when a tool produced a + // reusable local artifact but did not deliver it to the user yet. + ArtifactTags []string `json:"artifact_tags,omitempty"` + + // ResponseHandled indicates that this tool execution already satisfied the + // user's request at the channel/output level, so the agent loop can stop + // without a follow-up assistant response. + ResponseHandled bool `json:"response_handled,omitempty"` +} + +// ContentForLLM returns the normalized textual content to append to the +// conversation after a tool call. Errors fall back to Err when ForLLM is empty. +func (tr *ToolResult) ContentForLLM() string { + if tr == nil { + return "" + } + content := tr.ForLLM + if content == "" && tr.Err != nil { + content = tr.Err.Error() + } + if tr.ResponseHandled { + if content == "" { + return handledToolLLMNote + } + if !strings.Contains(content, handledToolLLMNote) { + content += "\n" + handledToolLLMNote + } + } + if len(tr.ArtifactTags) > 0 { + artifactNote := "Local artifact paths: " + strings.Join(tr.ArtifactTags, " ") + "\n" + artifactPathsLLMNote + if content == "" { + content = artifactNote + } else if !strings.Contains(content, artifactNote) { + content += "\n" + artifactNote + } + } + if content != "" { + return content + } + return "" } // NewToolResult creates a basic ToolResult with content for the LLM. 
@@ -158,3 +208,9 @@ func (tr *ToolResult) WithError(err error) *ToolResult { tr.Err = err return tr } + +// WithResponseHandled marks the tool result as already delivered to the user. +func (tr *ToolResult) WithResponseHandled() *ToolResult { + tr.ResponseHandled = true + return tr +} diff --git a/pkg/tools/result_test.go b/pkg/tools/result_test.go index a234e33f3..5f08cb4fa 100644 --- a/pkg/tools/result_test.go +++ b/pkg/tools/result_test.go @@ -3,6 +3,7 @@ package tools import ( "encoding/json" "errors" + "strings" "testing" ) @@ -227,3 +228,41 @@ func TestToolResultJSONStructure(t *testing.T) { t.Errorf("Expected silent false, got %v", parsed["silent"]) } } + +func TestToolResultContentForLLM_AppendsHandledDeliveryNote(t *testing.T) { + result := MediaResult("Screenshot attached.", []string{"media://example"}).WithResponseHandled() + + content := result.ContentForLLM() + if !strings.Contains(content, "Screenshot attached.") { + t.Fatalf("expected original content in ContentForLLM, got %q", content) + } + if !strings.Contains(content, handledToolLLMNote) { + t.Fatalf("expected handled delivery note in ContentForLLM, got %q", content) + } +} + +func TestToolResultContentForLLM_UsesHandledDeliveryNoteWhenEmpty(t *testing.T) { + result := (&ToolResult{}).WithResponseHandled() + + if got := result.ContentForLLM(); got != handledToolLLMNote { + t.Fatalf("ContentForLLM() = %q, want %q", got, handledToolLLMNote) + } +} + +func TestToolResultContentForLLM_AppendsArtifactPaths(t *testing.T) { + result := &ToolResult{ + ForLLM: "Artifact created.", + ArtifactTags: []string{"[file:/tmp/example.png]"}, + } + + content := result.ContentForLLM() + if !strings.Contains(content, "Artifact created.") { + t.Fatalf("expected original content in ContentForLLM, got %q", content) + } + if !strings.Contains(content, "Local artifact paths: [file:/tmp/example.png]") { + t.Fatalf("expected artifact path note in ContentForLLM, got %q", content) + } + if !strings.Contains(content, 
artifactPathsLLMNote) { + t.Fatalf("expected artifact guidance note in ContentForLLM, got %q", content) + } +} diff --git a/pkg/tools/send_file.go b/pkg/tools/send_file.go index a67bd4210..a59ad56aa 100644 --- a/pkg/tools/send_file.go +++ b/pkg/tools/send_file.go @@ -141,7 +141,7 @@ func (t *SendFileTool) Execute(ctx context.Context, args map[string]any) *ToolRe return ErrorResult(fmt.Sprintf("failed to register media: %v", err)) } - return MediaResult(fmt.Sprintf("File %q sent to user", filename), []string{ref}) + return MediaResult(fmt.Sprintf("File %q sent to user", filename), []string{ref}).WithResponseHandled() } // detectMediaType determines the MIME type of a file. diff --git a/pkg/tools/send_file_test.go b/pkg/tools/send_file_test.go index 6daaab31c..cfe5b43e1 100644 --- a/pkg/tools/send_file_test.go +++ b/pkg/tools/send_file_test.go @@ -104,6 +104,9 @@ func TestSendFileTool_Success(t *testing.T) { if result.Media[0][:8] != "media://" { t.Errorf("expected media:// ref, got %q", result.Media[0]) } + if !result.ResponseHandled { + t.Fatal("expected send_file success to mark response handled") + } } func TestSendFileTool_CustomFilename(t *testing.T) { diff --git a/pkg/tools/toolloop.go b/pkg/tools/toolloop.go index 244f0d4a2..387813e94 100644 --- a/pkg/tools/toolloop.go +++ b/pkg/tools/toolloop.go @@ -159,10 +159,7 @@ func RunToolLoop( // Append results in original order for _, r := range results { - contentForLLM := r.result.ForLLM - if contentForLLM == "" && r.result.Err != nil { - contentForLLM = r.result.Err.Error() - } + contentForLLM := r.result.ContentForLLM() messages = append(messages, providers.Message{ Role: "tool", From 930dd028f16742f52f33f7d8cc5ee780aa022ff2 Mon Sep 17 00:00:00 2001 From: afjcjsbx Date: Sun, 22 Mar 2026 13:47:23 +0100 Subject: [PATCH 02/39] fix err and placeholder --- pkg/channels/manager.go | 46 ++++++++++++++++++-- pkg/channels/manager_test.go | 84 ++++++++++++++++++++++++++++++++++++ 2 files changed, 127 insertions(+), 3 
deletions(-) diff --git a/pkg/channels/manager.go b/pkg/channels/manager.go index f4a64807e..fec4922c3 100644 --- a/pkg/channels/manager.go +++ b/pkg/channels/manager.go @@ -206,6 +206,40 @@ func (m *Manager) preSend(ctx context.Context, name string, msg bus.OutboundMess return false } +// preSendMedia handles typing stop, reaction undo, and placeholder cleanup +// before sending media attachments. Unlike preSend for text messages, media +// delivery never edits the placeholder because there is no text payload to +// replace it with; it only attempts to delete the placeholder when possible. +func (m *Manager) preSendMedia(ctx context.Context, name string, msg bus.OutboundMediaMessage, ch Channel) { + key := name + ":" + msg.ChatID + + // 1. Stop typing + if v, loaded := m.typingStops.LoadAndDelete(key); loaded { + if entry, ok := v.(typingEntry); ok { + entry.stop() // idempotent, safe + } + } + + // 2. Undo reaction + if v, loaded := m.reactionUndos.LoadAndDelete(key); loaded { + if entry, ok := v.(reactionEntry); ok { + entry.undo() // idempotent, safe + } + } + + // 3. Clear any finalized stream marker for this chat before media delivery. + m.streamActive.LoadAndDelete(key) + + // 4. Delete placeholder if present. + if v, loaded := m.placeholders.LoadAndDelete(key); loaded { + if entry, ok := v.(placeholderEntry); ok && entry.id != "" { + if deleter, ok := ch.(MessageDeleter); ok { + deleter.DeleteMessage(ctx, msg.ChatID, entry.id) // best effort + } + } + } +} + func NewManager(cfg *config.Config, messageBus *bus.MessageBus, store media.MediaStore) (*Manager, error) { m := &Manager{ channels: make(map[string]Channel), @@ -779,7 +813,8 @@ func (m *Manager) runMediaWorker(ctx context.Context, name string, w *channelWor } // sendMediaWithRetry sends a media message through the channel with rate limiting and -// retry logic. It returns nil on success, or the last error after retries. +// retry logic. 
It returns nil on success, or the last error after retries, +// including when the channel does not support MediaSender. func (m *Manager) sendMediaWithRetry( ctx context.Context, name string, @@ -788,10 +823,12 @@ func (m *Manager) sendMediaWithRetry( ) error { ms, ok := w.ch.(MediaSender) if !ok { - logger.DebugCF("channels", "Channel does not support MediaSender, skipping media", map[string]any{ + err := fmt.Errorf("channel %q does not support media sending", name) + logger.WarnCF("channels", "Channel does not support MediaSender", map[string]any{ "channel": name, + "error": err.Error(), }) - return nil + return err } // Rate limit: wait for token @@ -799,6 +836,9 @@ func (m *Manager) sendMediaWithRetry( return err } + // Pre-send: stop typing and clean up any placeholder before sending media. + m.preSendMedia(ctx, name, msg, w.ch) + var lastErr error for attempt := 0; attempt <= maxRetries; attempt++ { lastErr = ms.SendMedia(ctx, msg) diff --git a/pkg/channels/manager_test.go b/pkg/channels/manager_test.go index 6a5dd7e30..b4fd2ba3d 100644 --- a/pkg/channels/manager_test.go +++ b/pkg/channels/manager_test.go @@ -4,6 +4,7 @@ import ( "context" "errors" "fmt" + "strings" "sync" "sync/atomic" "testing" @@ -57,6 +58,26 @@ func (m *mockMediaChannel) SendMedia(ctx context.Context, msg bus.OutboundMediaM return nil } +type mockDeletingMediaChannel struct { + mockMediaChannel + deleteCalls int + lastDeleted struct { + chatID string + messageID string + } +} + +func (m *mockDeletingMediaChannel) DeleteMessage( + _ context.Context, + chatID string, + messageID string, +) error { + m.deleteCalls++ + m.lastDeleted.chatID = chatID + m.lastDeleted.messageID = messageID + return nil +} + // newTestManager creates a minimal Manager suitable for unit tests. 
func newTestManager() *Manager { return &Manager{ @@ -278,6 +299,69 @@ func TestSendMedia_PropagatesFailure(t *testing.T) { } } +func TestSendMedia_UnsupportedChannelReturnsError(t *testing.T) { + m := newTestManager() + ch := &mockChannel{ + sendFn: func(_ context.Context, _ bus.OutboundMessage) error { + return nil + }, + } + w := &channelWorker{ + ch: ch, + limiter: rate.NewLimiter(rate.Inf, 1), + } + m.channels["test"] = ch + m.workers["test"] = w + + err := m.SendMedia(context.Background(), bus.OutboundMediaMessage{ + Channel: "test", + ChatID: "chat1", + Parts: []bus.MediaPart{{Ref: "media://abc"}}, + }) + if err == nil { + t.Fatal("expected SendMedia to return error for unsupported channel") + } + if !strings.Contains(err.Error(), "does not support media sending") { + t.Fatalf("unexpected error: %v", err) + } +} + +func TestSendMedia_DeletesPlaceholderBeforeSending(t *testing.T) { + m := newTestManager() + ch := &mockDeletingMediaChannel{ + mockMediaChannel: mockMediaChannel{ + sendMediaFn: func(_ context.Context, _ bus.OutboundMediaMessage) error { + return nil + }, + }, + } + w := &channelWorker{ + ch: ch, + limiter: rate.NewLimiter(rate.Inf, 1), + } + m.channels["test"] = ch + m.workers["test"] = w + m.RecordPlaceholder("test", "chat1", "placeholder-1") + + err := m.SendMedia(context.Background(), bus.OutboundMediaMessage{ + Channel: "test", + ChatID: "chat1", + Parts: []bus.MediaPart{{Ref: "media://abc"}}, + }) + if err != nil { + t.Fatalf("SendMedia() error = %v", err) + } + if ch.deleteCalls != 1 { + t.Fatalf("expected placeholder delete to be called once, got %d", ch.deleteCalls) + } + if ch.lastDeleted.chatID != "chat1" || ch.lastDeleted.messageID != "placeholder-1" { + t.Fatalf("unexpected placeholder deletion target: %+v", ch.lastDeleted) + } + if len(ch.sentMediaMessages) != 1 { + t.Fatalf("expected media to be sent once, got %d", len(ch.sentMediaMessages)) + } +} + func TestSendWithRetry_UnknownError(t *testing.T) { m := newTestManager() var 
callCount int From b90c5007f69b13fa633b47b7b9aeb9ddd2b73cc9 Mon Sep 17 00:00:00 2001 From: afjcjsbx Date: Sun, 22 Mar 2026 23:36:25 +0100 Subject: [PATCH 03/39] resolve conflicts --- pkg/agent/context.go | 68 +++++++ pkg/agent/loop.go | 342 ++++++++++++++++++++++------------- pkg/agent/loop_media.go | 18 ++ pkg/agent/loop_test.go | 82 ++++++++- pkg/commands/builtin.go | 1 + pkg/commands/builtin_test.go | 5 +- pkg/commands/cmd_list.go | 17 ++ pkg/commands/cmd_use.go | 9 + pkg/commands/request.go | 5 + pkg/commands/runtime.go | 1 + 10 files changed, 418 insertions(+), 130 deletions(-) create mode 100644 pkg/commands/cmd_use.go diff --git a/pkg/agent/context.go b/pkg/agent/context.go index 022230d41..d905674f3 100644 --- a/pkg/agent/context.go +++ b/pkg/agent/context.go @@ -508,6 +508,7 @@ func (cb *ContextBuilder) BuildMessages( currentMessage string, media []string, channel, chatID, senderID, senderDisplayName string, + activeSkills ...string, ) []providers.Message { messages := []providers.Message{} @@ -541,6 +542,11 @@ func (cb *ContextBuilder) BuildMessages( {Type: "text", Text: dynamicCtx}, } + if skillsText := cb.buildActiveSkillsContext(activeSkills); skillsText != "" { + stringParts = append(stringParts, skillsText) + contentBlocks = append(contentBlocks, providers.ContentBlock{Type: "text", Text: skillsText}) + } + if summary != "" { summaryText := fmt.Sprintf( "CONTEXT_SUMMARY: The following is an approximate summary of prior conversation "+ @@ -748,6 +754,68 @@ func (cb *ContextBuilder) AddAssistantMessage( return messages } +func (cb *ContextBuilder) buildActiveSkillsContext(skillNames []string) string { + if cb.skillsLoader == nil || len(skillNames) == 0 { + return "" + } + + var ordered []string + seen := make(map[string]struct{}, len(skillNames)) + for _, name := range skillNames { + canonical, ok := cb.ResolveSkillName(name) + if !ok { + continue + } + if _, exists := seen[canonical]; exists { + continue + } + seen[canonical] = struct{}{} + ordered = 
append(ordered, canonical) + } + if len(ordered) == 0 { + return "" + } + + content := cb.skillsLoader.LoadSkillsForContext(ordered) + if strings.TrimSpace(content) == "" { + return "" + } + + return fmt.Sprintf(`# Active Skills + +The following skills are active for this request. Follow them when relevant. + +%s`, content) +} + +func (cb *ContextBuilder) ListSkillNames() []string { + if cb.skillsLoader == nil { + return nil + } + + allSkills := cb.skillsLoader.ListSkills() + names := make([]string, 0, len(allSkills)) + for _, skill := range allSkills { + names = append(names, skill.Name) + } + return names +} + +func (cb *ContextBuilder) ResolveSkillName(name string) (string, bool) { + name = strings.TrimSpace(name) + if name == "" || cb.skillsLoader == nil { + return "", false + } + + for _, skill := range cb.skillsLoader.ListSkills() { + if strings.EqualFold(skill.Name, name) { + return skill.Name, true + } + } + + return "", false +} + // GetSkillsInfo returns information about loaded skills. 
func (cb *ContextBuilder) GetSkillsInfo() map[string]any { allSkills := cb.skillsLoader.ListSkills() diff --git a/pkg/agent/loop.go b/pkg/agent/loop.go index 2a8cb883b..861be59db 100644 --- a/pkg/agent/loop.go +++ b/pkg/agent/loop.go @@ -56,6 +56,7 @@ type AgentLoop struct { mcp mcpRuntime hookRuntime hookRuntime steering *steeringQueue + pendingSkills sync.Map mu sync.RWMutex // Concurrent turn management (from HEAD) @@ -77,6 +78,7 @@ type processOptions struct { SenderID string // Current sender ID for dynamic context SenderDisplayName string // Current sender display name for dynamic context UserMessage string // User message content (may include prefix) + ForcedSkills []string // Skills explicitly requested for this message SystemPromptOverride string // Override the default system prompt (Used by SubTurns) Media []string // media:// refs from inbound message InitialSteeringMessages []providers.Message // Steering messages from refactor/agent @@ -1310,6 +1312,15 @@ func (al *AgentLoop) processMessage(ctx context.Context, msg bus.InboundMessage) return response, nil } + if pending := al.takePendingSkills(opts.SessionKey); len(pending) > 0 { + opts.ForcedSkills = append(opts.ForcedSkills, pending...) + logger.InfoCF("agent", "Applying pending skill override", + map[string]any{ + "session_key": opts.SessionKey, + "skills": strings.Join(pending, ","), + }) + } + return al.runAgentLoop(ctx, agent, opts) } @@ -1454,16 +1465,6 @@ func (al *AgentLoop) runAgentLoop( ts := newTurnState(agent, opts, al.newTurnEventScope(agent.ID, opts.SessionKey)) result, err := al.runTurn(ctx, ts) - // Resolve media:// refs: images→base64 data URLs, non-images→local paths in content - cfg := al.GetConfig() - maxMediaSize := cfg.Agents.Defaults.GetMaxMediaSize() - messages = resolveMediaRefs(messages, al.mediaStore, maxMediaSize) - - // 2. Save user message to session - agent.Sessions.AddMessage(opts.SessionKey, "user", opts.UserMessage) - - // 3. 
Run LLM iteration loop - finalContent, iteration, responseHandled, err := al.runLLMIteration(ctx, agent, messages, opts) if err != nil { return "", err } @@ -1471,22 +1472,6 @@ func (al *AgentLoop) runAgentLoop( return "", nil } - if responseHandled { - agent.Sessions.Save(opts.SessionKey) - - if opts.EnableSummary { - al.maybeSummarize(agent, opts.SessionKey, opts.Channel, opts.ChatID) - } - - logger.InfoCF("agent", "Response already handled by tool output", - map[string]any{ - "agent_id": agent.ID, - "session_key": opts.SessionKey, - "iterations": iteration, - }) - return "", nil - } - for _, followUp := range result.followUps { if pubErr := al.bus.PublishInbound(ctx, followUp); pubErr != nil { logger.WarnCF("agent", "Failed to publish follow-up after turn", @@ -1575,8 +1560,6 @@ func (al *AgentLoop) handleReasoning( } } -const handledToolResponseSummary = "Requested output delivered via tool attachment." - func (al *AgentLoop) runTurn(ctx context.Context, ts *turnState) (turnResult, error) { turnCtx, turnCancel := context.WithCancel(ctx) defer turnCancel() @@ -1631,6 +1614,7 @@ func (al *AgentLoop) runTurn(ctx context.Context, ts *turnState) (turnResult, er ts.chatID, ts.opts.SenderID, ts.opts.SenderDisplayName, + activeSkillNames(ts.agent, ts.opts)..., ) cfg := al.GetConfig() @@ -1660,6 +1644,7 @@ func (al *AgentLoop) runTurn(ctx context.Context, ts *turnState) (turnResult, er newHistory, newSummary, ts.userMessage, ts.media, ts.channel, ts.chatID, ts.opts.SenderID, ts.opts.SenderDisplayName, + activeSkillNames(ts.agent, ts.opts)..., ) messages = resolveMediaRefs(messages, al.mediaStore, maxMediaSize) } @@ -1682,59 +1667,8 @@ func (al *AgentLoop) runTurn(ctx context.Context, ts *turnState) (turnResult, er activeCandidates, activeModel := al.selectCandidates(ts.agent, ts.userMessage, messages) pendingMessages := append([]providers.Message(nil), ts.opts.InitialSteeringMessages...) 
-const handledToolResponseSummary = "Requested output delivered via tool attachment." - -func (al *AgentLoop) buildOutboundMediaMessage( - channel string, - chatID string, - refs []string, -) bus.OutboundMediaMessage { - parts := make([]bus.MediaPart, 0, len(refs)) - for _, ref := range refs { - part := bus.MediaPart{Ref: ref} - if al.mediaStore != nil { - if _, meta, err := al.mediaStore.ResolveWithMeta(ref); err == nil { - part.Filename = meta.Filename - part.ContentType = meta.ContentType - part.Type = inferMediaType(meta.Filename, meta.ContentType) - } - } - parts = append(parts, part) - } - return bus.OutboundMediaMessage{ - Channel: channel, - ChatID: chatID, - Parts: parts, - } -} - -func (al *AgentLoop) buildArtifactTags(refs []string) []string { - if al.mediaStore == nil || len(refs) == 0 { - return nil - } - - tags := make([]string, 0, len(refs)) - for _, ref := range refs { - localPath, meta, err := al.mediaStore.ResolveWithMeta(ref) - if err != nil { - continue - } - mime := detectMIME(localPath, meta) - tags = append(tags, buildPathTag(mime, localPath)) - } - return tags -} - -// runLLMIteration executes the LLM call loop with tool handling. -// Returns (finalContent, iteration, responseHandled, error). -func (al *AgentLoop) runLLMIteration( - ctx context.Context, - agent *AgentInstance, - messages []providers.Message, - opts processOptions, -) (string, int, bool, error) { - iteration := 0 var finalContent string + const handledToolResponseSummary = "Requested output delivered via tool attachment." 
turnLoop: for ts.currentIteration() < ts.agent.MaxIterations || len(pendingMessages) > 0 || func() bool { @@ -2078,6 +2012,7 @@ turnLoop: newHistory, newSummary, "", nil, ts.channel, ts.chatID, "", "", // Empty SenderID and SenderDisplayName for retry + activeSkillNames(ts.agent, ts.opts)..., ) callMessages = messages if gracefulTerminal { @@ -2138,7 +2073,6 @@ turnLoop: if response.Usage != nil { innerTS.SetLastUsage(response.Usage) } - return "", iteration, false, fmt.Errorf("LLM call failed after retries: %w", err) } go al.handleReasoning( @@ -2189,7 +2123,6 @@ turnLoop: "agent_id": ts.agent.ID, "iteration": iteration, "content_chars": len(finalContent), - "streamed": streamer != nil, }) break } @@ -2211,6 +2144,7 @@ turnLoop: "iteration": iteration, }) + allResponsesHandled := len(normalizedToolCalls) > 0 assistantMsg := providers.Message{ Role: "assistant", Content: response.Content, @@ -2460,18 +2394,11 @@ turnLoop: if toolResult == nil { toolResult = tools.ErrorResult("hook returned nil tool result") } - - if !toolResult.Silent && toolResult.ForUser != "" && ts.opts.SendResponse { - allResponsesHandled := len(agentResults) > 0 - - // Process results in original order (send to user, save to session) - for _, r := range agentResults { - if !r.result.ResponseHandled { + if !toolResult.ResponseHandled { allResponsesHandled = false } - // Send ForUser content to user immediately if not Silent - if !r.result.Silent && r.result.ForUser != "" && opts.SendResponse { + if !toolResult.Silent && toolResult.ForUser != "" && ts.opts.SendResponse { al.bus.PublishOutbound(ctx, bus.OutboundMessage{ Channel: ts.channel, ChatID: ts.chatID, @@ -2493,24 +2420,7 @@ turnLoop: part.Filename = meta.Filename part.ContentType = meta.ContentType part.Type = inferMediaType(meta.Filename, meta.ContentType) - // If tool returned media refs, publish them as outbound media only when the - // tool explicitly marked the user-visible delivery as already handled. 
- if len(r.result.Media) > 0 { - outboundMedia := al.buildOutboundMediaMessage(opts.Channel, opts.ChatID, r.result.Media) - if r.result.ResponseHandled { - if al.channelManager != nil { - if err := al.channelManager.SendMedia(ctx, outboundMedia); err != nil { - allResponsesHandled = false - logger.WarnCF("agent", "Synchronous media send failed, falling back to bus delivery", - map[string]any{ - "agent_id": agent.ID, - "tool": r.tc.Name, - "error": err.Error(), - }) - al.bus.PublishOutboundMedia(ctx, outboundMedia) } - } else { - al.bus.PublishOutboundMedia(ctx, outboundMedia) } parts = append(parts, part) } @@ -2521,14 +2431,10 @@ turnLoop: }) } - contentForLLM := toolResult.ForLLM - if contentForLLM == "" && toolResult.Err != nil { - contentForLLM = toolResult.Err.Error() - // Determine content for LLM based on tool result - if len(r.result.Media) > 0 && !r.result.ResponseHandled { - r.result.ArtifactTags = al.buildArtifactTags(r.result.Media) + if len(toolResult.Media) > 0 && !toolResult.ResponseHandled { + toolResult.ArtifactTags = buildArtifactTags(al.mediaStore, toolResult.Media) } - contentForLLM := r.result.ContentForLLM() + contentForLLM := toolResult.ContentForLLM() toolResultMsg := providers.Message{ Role: "tool", @@ -2617,31 +2523,48 @@ turnLoop: } } - ts.agent.Tools.TickTTL() if allResponsesHandled { summaryMsg := providers.Message{ Role: "assistant", Content: handledToolResponseSummary, } messages = append(messages, summaryMsg) - agent.Sessions.AddFullMessage(opts.SessionKey, summaryMsg) + if !ts.opts.NoHistory { + ts.agent.Sessions.AddMessage(ts.sessionKey, summaryMsg.Role, summaryMsg.Content) + ts.recordPersistedMessage(summaryMsg) + if err := ts.agent.Sessions.Save(ts.sessionKey); err != nil { + turnStatus = TurnEndStatusError + al.emitEvent( + EventKindError, + ts.eventMeta("runTurn", "turn.error"), + ErrorPayload{ + Stage: "session_save", + Message: err.Error(), + }, + ) + return turnResult{}, err + } + } + if ts.opts.EnableSummary { + 
al.maybeSummarize(ts.agent, ts.sessionKey, ts.scope) + } + ts.setPhase(TurnPhaseCompleted) + ts.setFinalContent("") logger.InfoCF("agent", "Tool output satisfied delivery; ending turn without follow-up LLM", map[string]any{ - "agent_id": agent.ID, + "agent_id": ts.agent.ID, "iteration": iteration, - "tool_count": len(agentResults), + "tool_count": len(normalizedToolCalls), }) - return "", iteration, true, nil + return turnResult{ + finalContent: "", + status: turnStatus, + followUps: append([]bus.InboundMessage(nil), ts.followUps...), + }, nil } - // Tick down TTL of discovered tools after processing tool results. - // Only reached when tool calls were made (the loop continues); - // the break on no-tool-call responses skips this. - // NOTE: This is safe because processMessage is sequential per agent. - // If per-agent concurrency is added, TTL consistency between - // ToProviderDefs and Get must be re-evaluated. - agent.Tools.TickTTL() + ts.agent.Tools.TickTTL() logger.DebugCF("agent", "TTL tick after tool execution", map[string]any{ "agent_id": ts.agent.ID, "iteration": iteration, }) @@ -2664,7 +2587,6 @@ turnLoop: return al.abortTurn(ts) } - return finalContent, iteration, false, nil if finalContent == "" { if ts.currentIteration() >= ts.agent.MaxIterations && ts.agent.MaxIterations > 0 { finalContent = toolLimitResponse @@ -3212,6 +3134,10 @@ func (al *AgentLoop) handleCommand( return "", false } + if matched, handled, reply := al.applyExplicitSkillCommand(msg.Content, agent, opts); matched { + return reply, handled + } + if al.cmdRegistry == nil { return "", false } @@ -3245,6 +3171,97 @@ func (al *AgentLoop) handleCommand( } } +func activeSkillNames(agent *AgentInstance, opts processOptions) []string { + if agent == nil { + return nil + } + + combined := make([]string, 0, len(agent.SkillsFilter)+len(opts.ForcedSkills)) + combined = append(combined, agent.SkillsFilter...) + combined = append(combined, opts.ForcedSkills...) 
+ if len(combined) == 0 { + return nil + } + + var resolved []string + seen := make(map[string]struct{}, len(combined)) + for _, name := range combined { + name = strings.TrimSpace(name) + if name == "" { + continue + } + if agent.ContextBuilder != nil { + if canonical, ok := agent.ContextBuilder.ResolveSkillName(name); ok { + name = canonical + } + } + key := strings.ToLower(name) + if _, ok := seen[key]; ok { + continue + } + seen[key] = struct{}{} + resolved = append(resolved, name) + } + + return resolved +} + +func (al *AgentLoop) applyExplicitSkillCommand( + raw string, + agent *AgentInstance, + opts *processOptions, +) (matched bool, handled bool, reply string) { + cmdName, ok := commands.CommandName(raw) + if !ok || cmdName != "use" { + return false, false, "" + } + + if agent == nil || agent.ContextBuilder == nil { + return true, true, commandsUnavailableSkillMessage() + } + + parts := strings.Fields(strings.TrimSpace(raw)) + if len(parts) < 2 { + return true, true, buildUseCommandHelp(agent) + } + + arg := strings.TrimSpace(parts[1]) + if strings.EqualFold(arg, "clear") || strings.EqualFold(arg, "off") { + if opts != nil { + al.clearPendingSkills(opts.SessionKey) + } + return true, true, "Cleared pending skill override." + } + + skillName, ok := agent.ContextBuilder.ResolveSkillName(arg) + if !ok { + return true, true, fmt.Sprintf("Unknown skill %q.\n\n%s", arg, buildUseCommandHelp(agent)) + } + + if len(parts) < 3 { + if opts == nil || strings.TrimSpace(opts.SessionKey) == "" { + return true, true, commandsUnavailableSkillMessage() + } + al.setPendingSkills(opts.SessionKey, []string{skillName}) + return true, true, fmt.Sprintf( + "Skill %q is armed for your next message. 
Send your next prompt normally, or use /use clear to cancel.", + skillName, + ) + } + + message := strings.TrimSpace(strings.Join(parts[2:], " ")) + if message == "" { + return true, true, buildUseCommandHelp(agent) + } + + if opts != nil { + opts.ForcedSkills = append(opts.ForcedSkills, skillName) + opts.UserMessage = message + } + + return true, false, "" +} + func (al *AgentLoop) buildCommandsRuntime(agent *AgentInstance, opts *processOptions) *commands.Runtime { registry := al.GetRegistry() cfg := al.GetConfig() @@ -3282,6 +3299,7 @@ func (al *AgentLoop) buildCommandsRuntime(agent *AgentInstance, opts *processOpt return al.reloadFunc() } if agent != nil { + rt.ListSkillNames = agent.ContextBuilder.ListSkillNames rt.GetModelInfo = func() (string, string) { return agent.Model, resolvedCandidateProvider(agent.Candidates, cfg.Agents.Defaults.Provider) } @@ -3334,6 +3352,74 @@ func (al *AgentLoop) buildCommandsRuntime(agent *AgentInstance, opts *processOpt return rt } +func commandsUnavailableSkillMessage() string { + return "Skill commands are unavailable in the current context." 
+} + +func buildUseCommandHelp(agent *AgentInstance) string { + usage := "Usage:\n/use \n/use \n/use clear" + if agent == nil || agent.ContextBuilder == nil { + return usage + } + + names := agent.ContextBuilder.ListSkillNames() + if len(names) == 0 { + return "No installed skills.\n\n" + usage + } + + return fmt.Sprintf("%s\n\nInstalled Skills:\n- %s", usage, strings.Join(names, "\n- ")) +} + +func (al *AgentLoop) setPendingSkills(sessionKey string, skillNames []string) { + sessionKey = strings.TrimSpace(sessionKey) + if sessionKey == "" || len(skillNames) == 0 { + return + } + + values := make([]string, 0, len(skillNames)) + for _, name := range skillNames { + name = strings.TrimSpace(name) + if name == "" { + continue + } + values = append(values, name) + } + if len(values) == 0 { + return + } + + al.pendingSkills.Store(sessionKey, values) +} + +func (al *AgentLoop) takePendingSkills(sessionKey string) []string { + sessionKey = strings.TrimSpace(sessionKey) + if sessionKey == "" { + return nil + } + + value, ok := al.pendingSkills.LoadAndDelete(sessionKey) + if !ok { + return nil + } + + skills, ok := value.([]string) + if !ok || len(skills) == 0 { + return nil + } + + out := make([]string, len(skills)) + copy(out, skills) + return out +} + +func (al *AgentLoop) clearPendingSkills(sessionKey string) { + sessionKey = strings.TrimSpace(sessionKey) + if sessionKey == "" { + return + } + al.pendingSkills.Delete(sessionKey) +} + func mapCommandError(result commands.ExecuteResult) string { if result.Command == "" { return fmt.Sprintf("Failed to execute command: %v", result.Err) diff --git a/pkg/agent/loop_media.go b/pkg/agent/loop_media.go index 1380f0214..e8314c10d 100644 --- a/pkg/agent/loop_media.go +++ b/pkg/agent/loop_media.go @@ -87,6 +87,24 @@ func resolveMediaRefs(messages []providers.Message, store media.MediaStore, maxS return result } +func buildArtifactTags(store media.MediaStore, refs []string) []string { + if store == nil || len(refs) == 0 { + return nil 
+ } + + tags := make([]string, 0, len(refs)) + for _, ref := range refs { + localPath, meta, err := store.ResolveWithMeta(ref) + if err != nil { + continue + } + mime := detectMIME(localPath, meta) + tags = append(tags, buildPathTag(mime, localPath)) + } + + return tags +} + // detectMIME determines the MIME type from metadata or magic-bytes detection. // Returns empty string if detection fails. func detectMIME(localPath string, meta media.MediaMeta) string { diff --git a/pkg/agent/loop_test.go b/pkg/agent/loop_test.go index 8dbc5fae1..bca96934c 100644 --- a/pkg/agent/loop_test.go +++ b/pkg/agent/loop_test.go @@ -132,6 +132,86 @@ func TestProcessMessage_IncludesCurrentSenderInDynamicContext(t *testing.T) { } } +func TestApplyExplicitSkillCommand_ArmsSkillForNextMessage(t *testing.T) { + al, cfg, _, _, cleanup := newTestAgentLoop(t) + defer cleanup() + + if err := os.MkdirAll(filepath.Join(cfg.Agents.Defaults.Workspace, "skills", "finance-news"), 0o755); err != nil { + t.Fatalf("MkdirAll(skill) error = %v", err) + } + if err := os.WriteFile( + filepath.Join(cfg.Agents.Defaults.Workspace, "skills", "finance-news", "SKILL.md"), + []byte("# Finance News\n\nUse web tools for current finance updates.\n"), + 0o644, + ); err != nil { + t.Fatalf("WriteFile(SKILL.md) error = %v", err) + } + + agent := al.GetRegistry().GetDefaultAgent() + if agent == nil { + t.Fatal("expected default agent") + } + + opts := &processOptions{SessionKey: "agent:main:test"} + matched, handled, reply := al.applyExplicitSkillCommand("/use finance-news", agent, opts) + if !matched { + t.Fatal("expected /use command to match") + } + if !handled { + t.Fatal("expected /use without inline message to be handled immediately") + } + if !strings.Contains(reply, `Skill "finance-news" is armed for your next message`) { + t.Fatalf("unexpected reply: %q", reply) + } + + pending := al.takePendingSkills(opts.SessionKey) + if len(pending) != 1 || pending[0] != "finance-news" { + t.Fatalf("pending skills = %#v, 
want [finance-news]", pending) + } +} + +func TestApplyExplicitSkillCommand_InlineMessageMutatesOptions(t *testing.T) { + al, cfg, _, _, cleanup := newTestAgentLoop(t) + defer cleanup() + + if err := os.MkdirAll(filepath.Join(cfg.Agents.Defaults.Workspace, "skills", "finance-news"), 0o755); err != nil { + t.Fatalf("MkdirAll(skill) error = %v", err) + } + if err := os.WriteFile( + filepath.Join(cfg.Agents.Defaults.Workspace, "skills", "finance-news", "SKILL.md"), + []byte("# Finance News\n\nUse web tools for current finance updates.\n"), + 0o644, + ); err != nil { + t.Fatalf("WriteFile(SKILL.md) error = %v", err) + } + + agent := al.GetRegistry().GetDefaultAgent() + if agent == nil { + t.Fatal("expected default agent") + } + + opts := &processOptions{ + SessionKey: "agent:main:test", + UserMessage: "/use finance-news dammi le ultime news", + } + matched, handled, reply := al.applyExplicitSkillCommand(opts.UserMessage, agent, opts) + if !matched { + t.Fatal("expected /use command to match") + } + if handled { + t.Fatal("expected /use with inline message to fall through into normal agent execution") + } + if reply != "" { + t.Fatalf("unexpected reply: %q", reply) + } + if opts.UserMessage != "dammi le ultime news" { + t.Fatalf("opts.UserMessage = %q, want %q", opts.UserMessage, "dammi le ultime news") + } + if len(opts.ForcedSkills) != 1 || opts.ForcedSkills[0] != "finance-news" { + t.Fatalf("opts.ForcedSkills = %#v, want [finance-news]", opts.ForcedSkills) + } +} + func TestRecordLastChannel(t *testing.T) { al, cfg, msgBus, provider, cleanup := newTestAgentLoop(t) defer cleanup() @@ -381,7 +461,7 @@ func TestProcessMessage_MediaToolHandledSkipsFollowUpLLMAndFinalText(t *testing. t.Fatal("expected session history to be saved") } last := history[len(history)-1] - if last.Role != "assistant" || last.Content != handledToolResponseSummary { + if last.Role != "assistant" || last.Content != "Requested output delivered via tool attachment." 
{ t.Fatalf("expected handled assistant summary in history, got %+v", last) } } diff --git a/pkg/commands/builtin.go b/pkg/commands/builtin.go index 7bd36b653..39e76f752 100644 --- a/pkg/commands/builtin.go +++ b/pkg/commands/builtin.go @@ -10,6 +10,7 @@ func BuiltinDefinitions() []Definition { helpCommand(), showCommand(), listCommand(), + useCommand(), switchCommand(), checkCommand(), clearCommand(), diff --git a/pkg/commands/builtin_test.go b/pkg/commands/builtin_test.go index 66a84825e..9f73b27b6 100644 --- a/pkg/commands/builtin_test.go +++ b/pkg/commands/builtin_test.go @@ -39,9 +39,12 @@ func TestBuiltinHelpHandler_ReturnsFormattedMessage(t *testing.T) { if !strings.Contains(reply, "/show [model|channel|agents]") { t.Fatalf("/help reply missing /show usage, got %q", reply) } - if !strings.Contains(reply, "/list [models|channels|agents]") { + if !strings.Contains(reply, "/list [models|channels|agents|skills]") { t.Fatalf("/help reply missing /list usage, got %q", reply) } + if !strings.Contains(reply, "/use [message]") { + t.Fatalf("/help reply missing /use usage, got %q", reply) + } } func TestBuiltinShowChannel_PreservesUserVisibleBehavior(t *testing.T) { diff --git a/pkg/commands/cmd_list.go b/pkg/commands/cmd_list.go index bf47b6e9c..7186a6c25 100644 --- a/pkg/commands/cmd_list.go +++ b/pkg/commands/cmd_list.go @@ -47,6 +47,23 @@ func listCommand() Definition { Description: "Registered agents", Handler: agentsHandler(), }, + { + Name: "skills", + Description: "Installed skills", + Handler: func(_ context.Context, req Request, rt *Runtime) error { + if rt == nil || rt.ListSkillNames == nil { + return req.Reply(unavailableMsg) + } + names := rt.ListSkillNames() + if len(names) == 0 { + return req.Reply("No installed skills") + } + return req.Reply(fmt.Sprintf( + "Installed Skills:\n- %s\n\nUse /use to force one for a single request, or /use to apply it to your next message.", + strings.Join(names, "\n- "), + )) + }, + }, }, } } diff --git 
a/pkg/commands/cmd_use.go b/pkg/commands/cmd_use.go new file mode 100644 index 000000000..4698f5f5e --- /dev/null +++ b/pkg/commands/cmd_use.go @@ -0,0 +1,9 @@ +package commands + +func useCommand() Definition { + return Definition{ + Name: "use", + Description: "Force a specific installed skill for one request", + Usage: "/use [message]", + } +} diff --git a/pkg/commands/request.go b/pkg/commands/request.go index 62ee600f2..233b3ef9c 100644 --- a/pkg/commands/request.go +++ b/pkg/commands/request.go @@ -41,6 +41,11 @@ func parseCommandName(input string) (string, bool) { return name, true } +// CommandName returns the normalized command name for an input if present. +func CommandName(input string) (string, bool) { + return parseCommandName(input) +} + func trimCommandPrefix(token string) (string, bool) { for _, prefix := range commandPrefixes { if strings.HasPrefix(token, prefix) { diff --git a/pkg/commands/runtime.go b/pkg/commands/runtime.go index f714e1ca4..5ba6a1bd2 100644 --- a/pkg/commands/runtime.go +++ b/pkg/commands/runtime.go @@ -10,6 +10,7 @@ type Runtime struct { GetModelInfo func() (name, provider string) ListAgentIDs func() []string ListDefinitions func() []Definition + ListSkillNames func() []string GetEnabledChannels func() []string GetActiveTurn func() any // Returning any to avoid circular dependency with agent package SwitchModel func(value string) (oldModel string, err error) From 388505d7e07388826dc9dc5f3cc5d360bbfc3c6c Mon Sep 17 00:00:00 2001 From: afjcjsbx Date: Sun, 22 Mar 2026 23:39:33 +0100 Subject: [PATCH 04/39] fix lint --- pkg/agent/loop.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pkg/agent/loop.go b/pkg/agent/loop.go index 861be59db..5f816ec34 100644 --- a/pkg/agent/loop.go +++ b/pkg/agent/loop.go @@ -2528,7 +2528,7 @@ turnLoop: Role: "assistant", Content: handledToolResponseSummary, } - messages = append(messages, summaryMsg) + if !ts.opts.NoHistory { ts.agent.Sessions.AddMessage(ts.sessionKey, 
summaryMsg.Role, summaryMsg.Content) ts.recordPersistedMessage(summaryMsg) From f735b0551cc80a21e2843f93ddad41fef5b882af Mon Sep 17 00:00:00 2001 From: afjcjsbx Date: Sun, 22 Mar 2026 23:46:10 +0100 Subject: [PATCH 05/39] fix --- pkg/agent/loop.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/pkg/agent/loop.go b/pkg/agent/loop.go index 5f816ec34..debd99418 100644 --- a/pkg/agent/loop.go +++ b/pkg/agent/loop.go @@ -2201,6 +2201,7 @@ turnLoop: toolArgs = toolReq.Arguments } case HookActionDenyTool: + allResponsesHandled = false denyContent := hookDeniedToolContent("Tool execution denied by hook", decision.Reason) al.emitEvent( EventKindToolExecSkipped, @@ -2240,6 +2241,7 @@ turnLoop: ChatID: ts.chatID, }) if !approval.Approved { + allResponsesHandled = false denyContent := hookDeniedToolContent("Tool execution denied by approval hook", approval.Reason) al.emitEvent( EventKindToolExecSkipped, From 1e98f86fa9e203af3f7b3f7710e36e23d5a9da8f Mon Sep 17 00:00:00 2001 From: afjcjsbx Date: Mon, 23 Mar 2026 00:08:43 +0100 Subject: [PATCH 06/39] fix Ooutboundmedia --- pkg/agent/loop.go | 2 +- pkg/agent/loop_test.go | 6 ++++++ 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/pkg/agent/loop.go b/pkg/agent/loop.go index debd99418..c202eff17 100644 --- a/pkg/agent/loop.go +++ b/pkg/agent/loop.go @@ -2413,7 +2413,7 @@ turnLoop: }) } - if len(toolResult.Media) > 0 { + if len(toolResult.Media) > 0 && toolResult.ResponseHandled { parts := make([]bus.MediaPart, 0, len(toolResult.Media)) for _, ref := range toolResult.Media { part := bus.MediaPart{Ref: ref} diff --git a/pkg/agent/loop_test.go b/pkg/agent/loop_test.go index bca96934c..b7442e2fc 100644 --- a/pkg/agent/loop_test.go +++ b/pkg/agent/loop_test.go @@ -522,6 +522,12 @@ func TestProcessMessage_MediaArtifactCanBeForwardedBySendFile(t *testing.T) { default: t.Fatal("expected outbound media from send_file") } + + select { + case extra := <-msgBus.OutboundMediaChan(): + t.Fatalf("expected exactly one outbound media 
delivery, got extra %+v", extra) + default: + } } // TestAgentLoop_GetStartupInfo verifies startup info contains tools From 1f9d390a6414e5dd3fad662c094c8207e058000d Mon Sep 17 00:00:00 2001 From: Kristjan Kruus Date: Mon, 23 Mar 2026 14:26:51 +0200 Subject: [PATCH 07/39] fix: apply security credentials before config validation in web handlers - Move SecurityCopyFrom() before validateConfig() in PUT and PATCH handlers - Make SecurityCopyFrom() call applySecurityConfig() to populate private fields - Add tests for config save with security-only channel tokens Without this fix, saving config via the web UI fails with 'channels.pico.token is required' (and similar for Telegram/Discord) when tokens are stored in .security.yml, because the validation ran before security credentials were copied to the config struct. --- pkg/config/config.go | 5 ++ web/backend/api/config.go | 23 ++++--- web/backend/api/config_test.go | 116 +++++++++++++++++++++++++++++++++ 3 files changed, 135 insertions(+), 9 deletions(-) diff --git a/pkg/config/config.go b/pkg/config/config.go index 33919d9d7..b58069472 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -1942,6 +1942,11 @@ func (c *Config) ValidateModelList() error { func (c *Config) SecurityCopyFrom(cfg *Config) { c.security = cfg.security + if c.security != nil { + if err := applySecurityConfig(c, c.security); err != nil { + logger.Errorf("failed to apply security config in SecurityCopyFrom: %v", err) + } + } } func MergeAPIKeys(apiKey string, apiKeys []string) []string { diff --git a/web/backend/api/config.go b/web/backend/api/config.go index 7cdfde174..fa2e91dec 100644 --- a/web/backend/api/config.go +++ b/web/backend/api/config.go @@ -54,6 +54,15 @@ func (h *Handler) handleUpdateConfig(w http.ResponseWriter, r *http.Request) { cfg.Tools.Exec.AllowRemote = config.DefaultConfig().Tools.Exec.AllowRemote } + // Load existing config and copy security credentials before validation, + // so that security-managed fields (e.g. 
pico token) are available. + oldCfg, err := config.LoadConfig(h.configPath) + if err != nil { + http.Error(w, fmt.Sprintf("Failed to load config: %v", err), http.StatusInternalServerError) + return + } + cfg.SecurityCopyFrom(oldCfg) + if errs := validateConfig(&cfg); len(errs) > 0 { w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusBadRequest) @@ -64,13 +73,7 @@ func (h *Handler) handleUpdateConfig(w http.ResponseWriter, r *http.Request) { return } - logger.Infof("new config: %+v", cfg) - oldCfg, err := config.LoadConfig(h.configPath) - if err != nil { - http.Error(w, fmt.Sprintf("Failed to load config: %v", err), http.StatusInternalServerError) - return - } - cfg.SecurityCopyFrom(oldCfg) + logger.Infof("configuration updated successfully") if err := config.SaveConfig(h.configPath, &cfg); err != nil { http.Error(w, fmt.Sprintf("Failed to save config: %v", err), http.StatusInternalServerError) @@ -149,6 +152,10 @@ func (h *Handler) handlePatchConfig(w http.ResponseWriter, r *http.Request) { return } + // Copy security credentials before validation so security-managed + // fields (e.g. pico token) are available for validation checks. 
+ newCfg.SecurityCopyFrom(cfg) + if errs := validateConfig(&newCfg); len(errs) > 0 { w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusBadRequest) @@ -159,8 +166,6 @@ func (h *Handler) handlePatchConfig(w http.ResponseWriter, r *http.Request) { return } - newCfg.SecurityCopyFrom(cfg) - if err := config.SaveConfig(h.configPath, &newCfg); err != nil { http.Error(w, fmt.Sprintf("Failed to save config: %v", err), http.StatusInternalServerError) return diff --git a/web/backend/api/config_test.go b/web/backend/api/config_test.go index bbf285e14..cf8cd505e 100644 --- a/web/backend/api/config_test.go +++ b/web/backend/api/config_test.go @@ -4,6 +4,8 @@ import ( "bytes" "net/http" "net/http/httptest" + "os" + "path/filepath" "testing" "github.com/sipeed/picoclaw/pkg/config" @@ -141,6 +143,120 @@ func TestHandlePatchConfig_AllowsInvalidExecRegexPatternsWhenExecDisabled(t *tes } } +// setupPicoEnabledEnv creates a test environment with Pico channel enabled and +// its token stored only in .security.yml (not in the JSON payload). 
+func setupPicoEnabledEnv(t *testing.T) (string, func()) { + t.Helper() + + tmp := t.TempDir() + oldHome := os.Getenv("HOME") + oldPicoHome := os.Getenv("PICOCLAW_HOME") + + if err := os.Setenv("HOME", tmp); err != nil { + t.Fatalf("set HOME: %v", err) + } + if err := os.Setenv("PICOCLAW_HOME", filepath.Join(tmp, ".picoclaw")); err != nil { + t.Fatalf("set PICOCLAW_HOME: %v", err) + } + + cfg := config.DefaultConfig() + cfg.ModelList = []*config.ModelConfig{{ + ModelName: "custom-default", + Model: "openai/gpt-4o", + }} + cfg.Agents.Defaults.ModelName = "custom-default" + cfg.Channels.Pico.Enabled = true + cfg.WithSecurity(&config.SecurityConfig{ + ModelList: map[string]config.ModelSecurityEntry{ + "custom-default": {APIKeys: []string{"sk-default"}}, + }, + Channels: config.ChannelsSecurity{ + Pico: &config.PicoSecurity{Token: "test-pico-token"}, + }, + }) + + configPath := filepath.Join(tmp, "config.json") + if err := config.SaveConfig(configPath, cfg); err != nil { + t.Fatalf("SaveConfig error: %v", err) + } + + cleanup := func() { + _ = os.Setenv("HOME", oldHome) + if oldPicoHome == "" { + _ = os.Unsetenv("PICOCLAW_HOME") + } else { + _ = os.Setenv("PICOCLAW_HOME", oldPicoHome) + } + } + return configPath, cleanup +} + +func TestHandleUpdateConfig_SucceedsWhenPicoTokenInSecurityOnly(t *testing.T) { + configPath, cleanup := setupPicoEnabledEnv(t) + defer cleanup() + + h := NewHandler(configPath) + mux := http.NewServeMux() + h.RegisterRoutes(mux) + + // PUT request with pico enabled but no token in JSON — token is in .security.yml + req := httptest.NewRequest(http.MethodPut, "/api/config", bytes.NewBufferString(`{ + "version": 1, + "agents": { + "defaults": { + "workspace": "~/.picoclaw/workspace", + "model_name": "custom-default" + } + }, + "channels": { + "pico": { + "enabled": true, + "ping_interval": 30, + "read_timeout": 60, + "write_timeout": 10, + "max_connections": 100 + } + }, + "model_list": [ + { + "model_name": "custom-default", + "model": 
"openai/gpt-4o", + "api_keys": ["sk-default"] + } + ] + }`)) + req.Header.Set("Content-Type", "application/json") + + rec := httptest.NewRecorder() + mux.ServeHTTP(rec, req) + if rec.Code != http.StatusOK { + t.Fatalf("PUT /api/config status = %d, want %d, body=%s", rec.Code, http.StatusOK, rec.Body.String()) + } +} + +func TestHandlePatchConfig_SucceedsWhenPicoTokenInSecurityOnly(t *testing.T) { + configPath, cleanup := setupPicoEnabledEnv(t) + defer cleanup() + + h := NewHandler(configPath) + mux := http.NewServeMux() + h.RegisterRoutes(mux) + + // PATCH request changing an unrelated field — pico token still in .security.yml + req := httptest.NewRequest(http.MethodPatch, "/api/config", bytes.NewBufferString(`{ + "gateway": { + "log_level": "info" + } + }`)) + req.Header.Set("Content-Type", "application/json") + + rec := httptest.NewRecorder() + mux.ServeHTTP(rec, req) + if rec.Code != http.StatusOK { + t.Fatalf("PATCH /api/config status = %d, want %d, body=%s", rec.Code, http.StatusOK, rec.Body.String()) + } +} + func TestHandlePatchConfig_AllowsInvalidDenyRegexPatternsWhenDenyPatternsDisabled(t *testing.T) { configPath, cleanup := setupOAuthTestEnv(t) defer cleanup() From 8ed171dbe62c8f6da56863a33a4ddd7a95505908 Mon Sep 17 00:00:00 2001 From: afjcjsbx Date: Mon, 23 Mar 2026 13:43:02 +0100 Subject: [PATCH 08/39] resolved conflicts --- pkg/agent/loop.go | 164 ++++------------------------------------- pkg/agent/loop_test.go | 4 +- 2 files changed, 15 insertions(+), 153 deletions(-) diff --git a/pkg/agent/loop.go b/pkg/agent/loop.go index f1fc9a2be..6a188416d 100644 --- a/pkg/agent/loop.go +++ b/pkg/agent/loop.go @@ -96,14 +96,15 @@ type continuationTarget struct { } const ( - defaultResponse = "The model returned an empty response. This may indicate a provider error or token limit." - toolLimitResponse = "I've reached `max_tool_iterations` without a final response. Increase `max_tool_iterations` in config.json if this task needs more tool steps." 
- sessionKeyAgentPrefix = "agent:" - metadataKeyAccountID = "account_id" - metadataKeyGuildID = "guild_id" - metadataKeyTeamID = "team_id" - metadataKeyParentPeerKind = "parent_peer_kind" - metadataKeyParentPeerID = "parent_peer_id" + defaultResponse = "The model returned an empty response. This may indicate a provider error or token limit." + toolLimitResponse = "I've reached `max_tool_iterations` without a final response. Increase `max_tool_iterations` in config.json if this task needs more tool steps." + handledToolResponseSummary = "Requested output delivered via tool attachment." + sessionKeyAgentPrefix = "agent:" + metadataKeyAccountID = "account_id" + metadataKeyGuildID = "guild_id" + metadataKeyTeamID = "team_id" + metadataKeyParentPeerKind = "parent_peer_kind" + metadataKeyParentPeerID = "parent_peer_id" ) func NewAgentLoop( @@ -3253,7 +3254,7 @@ func (al *AgentLoop) applyExplicitSkillCommand( skillName, ok := agent.ContextBuilder.ResolveSkillName(arg) if !ok { - return true, true, fmt.Sprintf("Unknown skill %q.\n\n%s", arg, buildUseCommandHelp(agent)) + return true, true, fmt.Sprintf("Unknown skill: %s\nUse /list skills to see installed skills.", arg) } if len(parts) < 3 { @@ -3320,7 +3321,9 @@ func (al *AgentLoop) buildCommandsRuntime(agent *AgentInstance, opts *processOpt return al.reloadFunc() } if agent != nil { - rt.ListSkillNames = agent.ContextBuilder.ListSkillNames + if agent.ContextBuilder != nil { + rt.ListSkillNames = agent.ContextBuilder.ListSkillNames + } rt.GetModelInfo = func() (string, string) { return agent.Model, resolvedCandidateProvider(agent.Candidates, cfg.Agents.Defaults.Provider) } @@ -3373,79 +3376,6 @@ func (al *AgentLoop) buildCommandsRuntime(agent *AgentInstance, opts *processOpt return rt } -func activeSkillNames(agent *AgentInstance, opts processOptions) []string { - var out []string - seen := make(map[string]struct{}) - - appendNames := func(names []string) { - for _, name := range names { - name = strings.TrimSpace(name) - 
if name == "" { - continue - } - if _, exists := seen[name]; exists { - continue - } - seen[name] = struct{}{} - out = append(out, name) - } - } - - if agent != nil { - appendNames(agent.SkillsFilter) - } - appendNames(opts.ForcedSkills) - - return out -} - -func (al *AgentLoop) applyExplicitSkillCommand( - raw string, - agent *AgentInstance, - opts *processOptions, -) (matched bool, handled bool, reply string) { - commandName, ok := commands.CommandName(raw) - if !ok || commandName != "use" { - return false, false, "" - } - - if agent == nil || agent.ContextBuilder == nil { - return true, true, commandsUnavailableSkillMessage() - } - - fields := strings.Fields(strings.TrimSpace(raw)) - if len(fields) < 2 { - return true, true, buildUseCommandHelp(agent) - } - - if strings.EqualFold(fields[1], "clear") || strings.EqualFold(fields[1], "off") { - al.clearPendingSkills(opts.SessionKey) - return true, true, "Cleared pending skill override." - } - - canonicalSkill, ok := agent.ContextBuilder.ResolveSkillName(fields[1]) - if !ok { - return true, true, fmt.Sprintf("Unknown skill: %s\nUse /list skills to see installed skills.", fields[1]) - } - - if len(fields) == 2 { - al.setPendingSkills(opts.SessionKey, []string{canonicalSkill}) - return true, true, fmt.Sprintf( - "Skill %q is armed for your next message.\nSend your next request normally, or use /use clear to cancel.", - canonicalSkill, - ) - } - - message := strings.TrimSpace(strings.Join(fields[2:], " ")) - if message == "" { - return true, true, buildUseCommandHelp(agent) - } - - opts.UserMessage = message - opts.ForcedSkills = append(opts.ForcedSkills, canonicalSkill) - return true, false, "" -} - func commandsUnavailableSkillMessage() string { return "Skill selection is unavailable in the current context." 
} @@ -3513,74 +3443,6 @@ func (al *AgentLoop) clearPendingSkills(sessionKey string) { al.pendingSkills.Delete(sessionKey) } -func commandsUnavailableSkillMessage() string { - return "Skill commands are unavailable in the current context." -} - -func buildUseCommandHelp(agent *AgentInstance) string { - usage := "Usage:\n/use \n/use \n/use clear" - if agent == nil || agent.ContextBuilder == nil { - return usage - } - - names := agent.ContextBuilder.ListSkillNames() - if len(names) == 0 { - return "No installed skills.\n\n" + usage - } - - return fmt.Sprintf("%s\n\nInstalled Skills:\n- %s", usage, strings.Join(names, "\n- ")) -} - -func (al *AgentLoop) setPendingSkills(sessionKey string, skillNames []string) { - sessionKey = strings.TrimSpace(sessionKey) - if sessionKey == "" || len(skillNames) == 0 { - return - } - - values := make([]string, 0, len(skillNames)) - for _, name := range skillNames { - name = strings.TrimSpace(name) - if name == "" { - continue - } - values = append(values, name) - } - if len(values) == 0 { - return - } - - al.pendingSkills.Store(sessionKey, values) -} - -func (al *AgentLoop) takePendingSkills(sessionKey string) []string { - sessionKey = strings.TrimSpace(sessionKey) - if sessionKey == "" { - return nil - } - - value, ok := al.pendingSkills.LoadAndDelete(sessionKey) - if !ok { - return nil - } - - skills, ok := value.([]string) - if !ok || len(skills) == 0 { - return nil - } - - out := make([]string, len(skills)) - copy(out, skills) - return out -} - -func (al *AgentLoop) clearPendingSkills(sessionKey string) { - sessionKey = strings.TrimSpace(sessionKey) - if sessionKey == "" { - return - } - al.pendingSkills.Delete(sessionKey) -} - func mapCommandError(result commands.ExecuteResult) string { if result.Command == "" { return fmt.Sprintf("Failed to execute command: %v", result.Err) diff --git a/pkg/agent/loop_test.go b/pkg/agent/loop_test.go index 9fd737e12..ffb87d7dd 100644 --- a/pkg/agent/loop_test.go +++ b/pkg/agent/loop_test.go @@ 
-541,7 +541,7 @@ func TestProcessMessage_MediaToolHandledSkipsFollowUpLLMAndFinalText(t *testing. Agents: config.AgentsConfig{ Defaults: config.AgentDefaults{ Workspace: tmpDir, - Model: "test-model", + ModelName: "test-model", MaxTokens: 4096, MaxToolIterations: 10, }, @@ -627,7 +627,7 @@ func TestProcessMessage_MediaArtifactCanBeForwardedBySendFile(t *testing.T) { tmpDir := t.TempDir() cfg := config.DefaultConfig() cfg.Agents.Defaults.Workspace = tmpDir - cfg.Agents.Defaults.Model = "test-model" + cfg.Agents.Defaults.ModelName = "test-model" cfg.Agents.Defaults.MaxTokens = 4096 cfg.Agents.Defaults.MaxToolIterations = 10 From 5d5536a1a64686ac53358d40b66bc5e3e8c238d7 Mon Sep 17 00:00:00 2001 From: afjcjsbx Date: Mon, 23 Mar 2026 14:09:52 +0100 Subject: [PATCH 09/39] fix delivery and steering --- pkg/agent/loop.go | 87 +++++++++++++----- pkg/agent/loop_test.go | 203 +++++++++++++++++++++++++++++++++++++---- 2 files changed, 248 insertions(+), 42 deletions(-) diff --git a/pkg/agent/loop.go b/pkg/agent/loop.go index 6a188416d..899c233cb 100644 --- a/pkg/agent/loop.go +++ b/pkg/agent/loop.go @@ -2413,6 +2413,47 @@ turnLoop: if toolResult == nil { toolResult = tools.ErrorResult("hook returned nil tool result") } + if len(toolResult.Media) > 0 && toolResult.ResponseHandled { + parts := make([]bus.MediaPart, 0, len(toolResult.Media)) + for _, ref := range toolResult.Media { + part := bus.MediaPart{Ref: ref} + if al.mediaStore != nil { + if _, meta, err := al.mediaStore.ResolveWithMeta(ref); err == nil { + part.Filename = meta.Filename + part.ContentType = meta.ContentType + part.Type = inferMediaType(meta.Filename, meta.ContentType) + } + } + parts = append(parts, part) + } + outboundMedia := bus.OutboundMediaMessage{ + Channel: ts.channel, + ChatID: ts.chatID, + Parts: parts, + } + if al.channelManager != nil && ts.channel != "" && !constants.IsInternalChannel(ts.channel) { + if err := al.channelManager.SendMedia(ctx, outboundMedia); err != nil { + logger.WarnCF("agent", 
"Failed to deliver handled tool media", + map[string]any{ + "agent_id": ts.agent.ID, + "tool": toolName, + "channel": ts.channel, + "chat_id": ts.chatID, + "error": err.Error(), + }) + toolResult = tools.ErrorResult(fmt.Sprintf("failed to deliver attachment: %v", err)).WithError(err) + } + } else if al.bus != nil { + al.bus.PublishOutboundMedia(ctx, outboundMedia) + // Queuing media is only best-effort; it has not been delivered yet. + toolResult.ResponseHandled = false + } + } + + if len(toolResult.Media) > 0 && !toolResult.ResponseHandled { + toolResult.ArtifactTags = buildArtifactTags(al.mediaStore, toolResult.Media) + } + if !toolResult.ResponseHandled { allResponsesHandled = false } @@ -2430,29 +2471,6 @@ turnLoop: }) } - if len(toolResult.Media) > 0 && toolResult.ResponseHandled { - parts := make([]bus.MediaPart, 0, len(toolResult.Media)) - for _, ref := range toolResult.Media { - part := bus.MediaPart{Ref: ref} - if al.mediaStore != nil { - if _, meta, err := al.mediaStore.ResolveWithMeta(ref); err == nil { - part.Filename = meta.Filename - part.ContentType = meta.ContentType - part.Type = inferMediaType(meta.Filename, meta.ContentType) - } - } - parts = append(parts, part) - } - al.bus.PublishOutboundMedia(ctx, bus.OutboundMediaMessage{ - Channel: ts.channel, - ChatID: ts.chatID, - Parts: parts, - }) - } - - if len(toolResult.Media) > 0 && !toolResult.ResponseHandled { - toolResult.ArtifactTags = buildArtifactTags(al.mediaStore, toolResult.Media) - } contentForLLM := toolResult.ContentForLLM() toolResultMsg := providers.Message{ @@ -2543,6 +2561,29 @@ turnLoop: } if allResponsesHandled { + if len(pendingMessages) > 0 { + logger.InfoCF("agent", "Pending steering exists after handled tool delivery; continuing turn before finalizing", + map[string]any{ + "agent_id": ts.agent.ID, + "steering_count": len(pendingMessages), + "session_key": ts.sessionKey, + }) + finalContent = "" + goto turnLoop + } + + if steerMsgs := 
al.dequeueSteeringMessagesForScope(ts.sessionKey); len(steerMsgs) > 0 { + logger.InfoCF("agent", "Steering arrived after handled tool delivery; continuing turn before finalizing", + map[string]any{ + "agent_id": ts.agent.ID, + "steering_count": len(steerMsgs), + "session_key": ts.sessionKey, + }) + pendingMessages = append(pendingMessages, steerMsgs...) + finalContent = "" + goto turnLoop + } + summaryMsg := providers.Message{ Role: "assistant", Content: handledToolResponseSummary, diff --git a/pkg/agent/loop_test.go b/pkg/agent/loop_test.go index ffb87d7dd..2bf544595 100644 --- a/pkg/agent/loop_test.go +++ b/pkg/agent/loop_test.go @@ -33,6 +33,41 @@ func (f *fakeChannel) IsAllowed(string) bool { func (f *fakeChannel) IsAllowedSender(sender bus.SenderInfo) bool { return true } func (f *fakeChannel) ReasoningChannelID() string { return f.id } +type fakeMediaChannel struct { + fakeChannel + sentMedia []bus.OutboundMediaMessage +} + +func (f *fakeMediaChannel) SendMedia(ctx context.Context, msg bus.OutboundMediaMessage) error { + f.sentMedia = append(f.sentMedia, msg) + return nil +} + +func newStartedTestChannelManager( + t *testing.T, + msgBus *bus.MessageBus, + store media.MediaStore, + name string, + ch channels.Channel, +) *channels.Manager { + t.Helper() + + cm, err := channels.NewManager(&config.Config{}, msgBus, store) + if err != nil { + t.Fatalf("NewManager() error = %v", err) + } + cm.RegisterChannel(name, ch) + if err := cm.StartAll(context.Background()); err != nil { + t.Fatalf("StartAll() error = %v", err) + } + t.Cleanup(func() { + if err := cm.StopAll(context.Background()); err != nil { + t.Fatalf("StopAll() error = %v", err) + } + }) + return cm +} + type recordingProvider struct { lastMessages []providers.Message } @@ -554,6 +589,8 @@ func TestProcessMessage_MediaToolHandledSkipsFollowUpLLMAndFinalText(t *testing. 
store := media.NewFileMediaStore() al.SetMediaStore(store) + telegramChannel := &fakeMediaChannel{fakeChannel: fakeChannel{id: "rid-telegram"}} + al.SetChannelManager(newStartedTestChannelManager(t, msgBus, store, "telegram", telegramChannel)) imagePath := filepath.Join(tmpDir, "screen.png") if err := os.WriteFile(imagePath, []byte("fake screenshot"), 0o644); err != nil { @@ -587,16 +624,20 @@ func TestProcessMessage_MediaToolHandledSkipsFollowUpLLMAndFinalText(t *testing. t.Fatal("expected tools to be available on the first LLM call") } + if len(telegramChannel.sentMedia) != 1 { + t.Fatalf("expected exactly 1 synchronously sent media message, got %d", len(telegramChannel.sentMedia)) + } + if telegramChannel.sentMedia[0].Channel != "telegram" || telegramChannel.sentMedia[0].ChatID != "chat1" { + t.Fatalf("unexpected sent media target: %+v", telegramChannel.sentMedia[0]) + } + if len(telegramChannel.sentMedia[0].Parts) != 1 { + t.Fatalf("expected exactly 1 sent media part, got %d", len(telegramChannel.sentMedia[0].Parts)) + } + select { - case mediaMsg := <-msgBus.OutboundMediaChan(): - if mediaMsg.Channel != "telegram" || mediaMsg.ChatID != "chat1" { - t.Fatalf("unexpected outbound media target: %+v", mediaMsg) - } - if len(mediaMsg.Parts) != 1 { - t.Fatalf("expected exactly 1 outbound media part, got %d", len(mediaMsg.Parts)) - } + case extra := <-msgBus.OutboundMediaChan(): + t.Fatalf("expected handled media to bypass async queue, got %+v", extra) default: - t.Fatal("expected outbound media message to be published") } defaultAgent := al.GetRegistry().GetDefaultAgent() @@ -623,6 +664,59 @@ func TestProcessMessage_MediaToolHandledSkipsFollowUpLLMAndFinalText(t *testing. 
} } +func TestProcessMessage_HandledToolProcessesQueuedSteeringBeforeReturning(t *testing.T) { + tmpDir := t.TempDir() + cfg := &config.Config{ + Agents: config.AgentsConfig{ + Defaults: config.AgentDefaults{ + Workspace: tmpDir, + ModelName: "test-model", + MaxTokens: 4096, + MaxToolIterations: 10, + }, + }, + } + + msgBus := bus.NewMessageBus() + provider := &handledMediaWithSteeringProvider{} + al := NewAgentLoop(cfg, msgBus, provider) + + store := media.NewFileMediaStore() + al.SetMediaStore(store) + telegramChannel := &fakeMediaChannel{fakeChannel: fakeChannel{id: "rid-telegram"}} + al.SetChannelManager(newStartedTestChannelManager(t, msgBus, store, "telegram", telegramChannel)) + + imagePath := filepath.Join(tmpDir, "screen-steering.png") + if err := os.WriteFile(imagePath, []byte("fake screenshot"), 0o644); err != nil { + t.Fatalf("WriteFile(imagePath) error = %v", err) + } + + al.RegisterTool(&handledMediaWithSteeringTool{ + store: store, + path: imagePath, + loop: al, + }) + + response, err := al.processMessage(context.Background(), bus.InboundMessage{ + Channel: "telegram", + ChatID: "chat1", + SenderID: "user1", + Content: "take a screenshot of the screen and send it to me", + }) + if err != nil { + t.Fatalf("processMessage() error = %v", err) + } + if response != "Handled the queued steering message." 
{ + t.Fatalf("response = %q, want queued steering response", response) + } + if provider.calls != 2 { + t.Fatalf("expected 2 LLM calls after queued steering, got %d", provider.calls) + } + if len(telegramChannel.sentMedia) != 1 { + t.Fatalf("expected exactly 1 synchronously sent media message, got %d", len(telegramChannel.sentMedia)) + } +} + func TestProcessMessage_MediaArtifactCanBeForwardedBySendFile(t *testing.T) { tmpDir := t.TempDir() cfg := config.DefaultConfig() @@ -637,6 +731,8 @@ func TestProcessMessage_MediaArtifactCanBeForwardedBySendFile(t *testing.T) { store := media.NewFileMediaStore() al.SetMediaStore(store) + telegramChannel := &fakeMediaChannel{fakeChannel: fakeChannel{id: "rid-telegram"}} + al.SetChannelManager(newStartedTestChannelManager(t, msgBus, store, "telegram", telegramChannel)) mediaDir := media.TempDir() if err := os.MkdirAll(mediaDir, 0o700); err != nil { @@ -668,21 +764,19 @@ func TestProcessMessage_MediaArtifactCanBeForwardedBySendFile(t *testing.T) { t.Fatalf("expected 2 LLM calls (artifact + send_file), got %d", provider.calls) } - select { - case mediaMsg := <-msgBus.OutboundMediaChan(): - if mediaMsg.Channel != "telegram" || mediaMsg.ChatID != "chat1" { - t.Fatalf("unexpected outbound media target: %+v", mediaMsg) - } - if len(mediaMsg.Parts) != 1 { - t.Fatalf("expected exactly 1 outbound media part, got %d", len(mediaMsg.Parts)) - } - default: - t.Fatal("expected outbound media from send_file") + if len(telegramChannel.sentMedia) != 1 { + t.Fatalf("expected exactly 1 synchronously sent media message, got %d", len(telegramChannel.sentMedia)) + } + if telegramChannel.sentMedia[0].Channel != "telegram" || telegramChannel.sentMedia[0].ChatID != "chat1" { + t.Fatalf("unexpected sent media target: %+v", telegramChannel.sentMedia[0]) + } + if len(telegramChannel.sentMedia[0].Parts) != 1 { + t.Fatalf("expected exactly 1 sent media part, got %d", len(telegramChannel.sentMedia[0].Parts)) } select { case extra := 
<-msgBus.OutboundMediaChan(): - t.Fatalf("expected exactly one outbound media delivery, got extra %+v", extra) + t.Fatalf("expected synchronous send_file delivery to bypass async queue, got %+v", extra) default: } } @@ -975,6 +1069,77 @@ func (m *handledMediaTool) Execute(ctx context.Context, args map[string]any) *to return tools.MediaResult("Attachment delivered by tool.", []string{ref}).WithResponseHandled() } +type handledMediaWithSteeringProvider struct { + calls int +} + +func (m *handledMediaWithSteeringProvider) Chat( + ctx context.Context, + messages []providers.Message, + tools []providers.ToolDefinition, + model string, + opts map[string]any, +) (*providers.LLMResponse, error) { + m.calls++ + if m.calls == 1 { + return &providers.LLMResponse{ + Content: "Taking the screenshot now.", + ToolCalls: []providers.ToolCall{{ + ID: "call_handled_media_steering", + Type: "function", + Name: "handled_media_with_steering_tool", + Arguments: map[string]any{}, + }}, + }, nil + } + + for _, msg := range messages { + if msg.Role == "user" && msg.Content == "what about this instead?" 
{ + return &providers.LLMResponse{Content: "Handled the queued steering message."}, nil + } + } + + return nil, fmt.Errorf("provider did not receive queued steering message") +} + +func (m *handledMediaWithSteeringProvider) GetDefaultModel() string { + return "handled-media-with-steering-model" +} + +type handledMediaWithSteeringTool struct { + store media.MediaStore + path string + loop *AgentLoop +} + +func (m *handledMediaWithSteeringTool) Name() string { return "handled_media_with_steering_tool" } +func (m *handledMediaWithSteeringTool) Description() string { + return "Returns handled media and enqueues a steering message during execution" +} + +func (m *handledMediaWithSteeringTool) Parameters() map[string]any { + return map[string]any{ + "type": "object", + "properties": map[string]any{}, + } +} + +func (m *handledMediaWithSteeringTool) Execute(ctx context.Context, args map[string]any) *tools.ToolResult { + if err := m.loop.Steer(providers.Message{Role: "user", Content: "what about this instead?"}); err != nil { + return tools.ErrorResult(err.Error()).WithError(err) + } + + ref, err := m.store.Store(m.path, media.MediaMeta{ + Filename: filepath.Base(m.path), + ContentType: "image/png", + Source: "test:handled_media_with_steering_tool", + }, "test:handled_media_with_steering") + if err != nil { + return tools.ErrorResult(err.Error()).WithError(err) + } + return tools.MediaResult("Attachment delivered by tool.", []string{ref}).WithResponseHandled() +} + type mediaArtifactTool struct { store media.MediaStore path string From f1ac1a107263cfc1addab7b15bbf07a38c7bf0a6 Mon Sep 17 00:00:00 2001 From: lc6464 <64722907+lc6464@users.noreply.github.com> Date: Tue, 24 Mar 2026 12:20:57 +0800 Subject: [PATCH 10/39] fix(web): ensure at least 40% of the characters are masked for api key - keys longer than 12 chars show prefix + last 4 chars - keys 9-12 chars show prefix + last 2 chars - shorter keys are fully masked --- web/backend/api/models.go | 11 ++++++++++- 1 file 
changed, 10 insertions(+), 1 deletion(-) diff --git a/web/backend/api/models.go b/web/backend/api/models.go index 1e3b5f90a..142363079 100644 --- a/web/backend/api/models.go +++ b/web/backend/api/models.go @@ -307,16 +307,25 @@ func (h *Handler) handleSetDefaultModel(w http.ResponseWriter, r *http.Request) } // maskAPIKey returns a masked version of an API key for safe display. -// Keys longer than 8 chars show prefix + last 4 chars: "sk-****abcd" +// Keys longer than 12 chars show prefix + last 4 chars: "sk-****abcd". +// Keys 9-12 chars show prefix + last 2 chars: "sk-****cd". // Shorter keys are fully masked as "****". // Empty keys return empty string. +// Ensure at least 40% of the key is masked. func maskAPIKey(key string) string { if key == "" { return "" } + if len(key) <= 8 { return "****" } + + // Show first 3 chars and last 2 chars + if len(key) <= 12 { + return key[:3] + "****" + key[len(key)-2:] + } + // Show first 3 chars and last 4 chars return key[:3] + "****" + key[len(key)-4:] } From 66d2efc9d126d25c1ca7fd5926600b574158268e Mon Sep 17 00:00:00 2001 From: lc6464 <64722907+lc6464@users.noreply.github.com> Date: Tue, 24 Mar 2026 12:36:31 +0800 Subject: [PATCH 11/39] test(web): add test for maskAPIKey --- web/backend/api/models_test.go | 53 ++++++++++++++++++++++++++++++++++ 1 file changed, 53 insertions(+) diff --git a/web/backend/api/models_test.go b/web/backend/api/models_test.go index 44d10154e..5378e986e 100644 --- a/web/backend/api/models_test.go +++ b/web/backend/api/models_test.go @@ -315,3 +315,56 @@ func TestHandleListModels_NormalizesWildcardLocalAPIBaseForProbe(t *testing.T) { t.Fatalf("probe api base = %q, want %q", gotProbe, "http://127.0.0.1:8000/v1|custom-model|") } } + +func TestMaskAPIKey(t *testing.T) { + tests := []struct { + name string + key string + want string + }{ + { + name: "empty key", + key: "", + want: "", + }, + { + name: "short key fully masked", + key: "abcd", + want: "****", + }, + { + name: "length 8 boundary fully 
masked", + key: "12345678", + want: "****", + }, + { + name: "length 9 boundary shows last 2", + key: "123456789", + want: "123****89", + }, + { + name: "length 12 boundary shows last 2", + key: "abcdefghijkl", + want: "abc****kl", + }, + { + name: "length 13 boundary shows last 4", + key: "abcdefghijklm", + want: "abc****jklm", + }, + { + name: "typical api key", + key: "sk-1234567890abcd", + want: "sk-****abcd", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + got := maskAPIKey(tc.key) + if got != tc.want { + t.Fatalf("maskAPIKey(%q) = %q, want %q", tc.key, got, tc.want) + } + }) + } +} From b23a6b3f54e3ed0f38ab1d4ee563b938251c82f8 Mon Sep 17 00:00:00 2001 From: Hua Audio Date: Tue, 24 Mar 2026 06:33:35 +0100 Subject: [PATCH 12/39] Feat/move weixin login to auth and update docs (#1945) * move weixin to auth and update docs * fix ci test --- README.fr.md | 2 +- README.id.md | 2 +- README.it.md | 2 +- README.ja.md | 2 +- README.md | 2 +- README.pt-br.md | 2 +- README.vi.md | 2 +- README.zh.md | 2 +- cmd/picoclaw/internal/auth/command.go | 1 + cmd/picoclaw/internal/auth/command_test.go | 1 + cmd/picoclaw/internal/{onboard => auth}/weixin.go | 4 ++-- cmd/picoclaw/internal/onboard/command.go | 5 +---- cmd/picoclaw/internal/onboard/command_test.go | 5 ++--- docs/channels/weixin/README.md | 2 +- docs/channels/weixin/README.zh.md | 2 +- docs/chat-apps.md | 2 +- docs/fr/chat-apps.md | 2 +- docs/ja/chat-apps.md | 2 +- docs/pt-br/chat-apps.md | 2 +- docs/vi/chat-apps.md | 2 +- docs/zh/chat-apps.md | 2 +- 21 files changed, 23 insertions(+), 25 deletions(-) rename cmd/picoclaw/internal/{onboard => auth}/weixin.go (98%) diff --git a/README.fr.md b/README.fr.md index 301456262..a4fa628c9 100644 --- a/README.fr.md +++ b/README.fr.md @@ -524,7 +524,7 @@ Connectez PicoClaw au réseau social des Agents simplement en envoyant un seul m | Commande | Description | | ------------------------- | ---------------------------------------- | | `picoclaw 
onboard` | Initialiser la config & le workspace | -| `picoclaw onboard weixin` | Connecter un compte WeChat via QR | +| `picoclaw auth weixin` | Connecter un compte WeChat via QR | | `picoclaw agent -m "..."` | Chatter avec l'agent | | `picoclaw agent` | Mode chat interactif | | `picoclaw gateway` | Démarrer le gateway | diff --git a/README.id.md b/README.id.md index 6b7025ffd..6d62dcb9b 100644 --- a/README.id.md +++ b/README.id.md @@ -520,7 +520,7 @@ Hubungkan PicoClaw ke Jaringan Sosial Agent hanya dengan mengirim satu pesan mel | Perintah | Deskripsi | | -------------------------- | -------------------------------- | | `picoclaw onboard` | Inisialisasi konfigurasi & workspace | -| `picoclaw onboard weixin` | Hubungkan akun WeChat via QR | +| `picoclaw auth weixin` | Hubungkan akun WeChat via QR | | `picoclaw agent -m "..."` | Chat dengan agent | | `picoclaw agent` | Mode chat interaktif | | `picoclaw gateway` | Mulai gateway | diff --git a/README.it.md b/README.it.md index dae541a17..1ed73ee54 100644 --- a/README.it.md +++ b/README.it.md @@ -520,7 +520,7 @@ Connetti PicoClaw al Social Network degli Agent semplicemente inviando un singol | Comando | Descrizione | | ------------------------- | ---------------------------------- | | `picoclaw onboard` | Inizializza config & workspace | -| `picoclaw onboard weixin` | Connetti account WeChat tramite QR | +| `picoclaw auth weixin` | Connetti account WeChat tramite QR | | `picoclaw agent -m "..."` | Chatta con l'agent | | `picoclaw agent` | Modalità chat interattiva | | `picoclaw gateway` | Avvia il gateway | diff --git a/README.ja.md b/README.ja.md index 3096d4022..9165986ba 100644 --- a/README.ja.md +++ b/README.ja.md @@ -520,7 +520,7 @@ CLI または統合チャットアプリからメッセージを 1 つ送るだ | コマンド | 説明 | | ------------------------- | ------------------------------ | | `picoclaw onboard` | 設定&ワークスペースの初期化 | -| `picoclaw onboard weixin` | WeChat アカウントを QR で接続 | +| `picoclaw auth weixin` | WeChat アカウントを QR で接続 | | `picoclaw agent -m "..."` | 
Agent とチャット | | `picoclaw agent` | インタラクティブチャットモード | | `picoclaw gateway` | Gateway を起動 | diff --git a/README.md b/README.md index 72d38103c..568c87e59 100644 --- a/README.md +++ b/README.md @@ -523,7 +523,7 @@ Connect PicoClaw to the Agent Social Network simply by sending a single message | Command | Description | | ------------------------- | -------------------------------- | | `picoclaw onboard` | Initialize config & workspace | -| `picoclaw onboard weixin` | Connect WeChat account via QR | +| `picoclaw auth weixin` | Connect WeChat account via QR | | `picoclaw agent -m "..."` | Chat with the agent | | `picoclaw agent` | Interactive chat mode | | `picoclaw gateway` | Start the gateway | diff --git a/README.pt-br.md b/README.pt-br.md index 3c039f190..d4b303e24 100644 --- a/README.pt-br.md +++ b/README.pt-br.md @@ -520,7 +520,7 @@ Conecte o PicoClaw à Rede Social de Agents simplesmente enviando uma única men | Comando | Descrição | | ------------------------- | -------------------------------------- | | `picoclaw onboard` | Inicializar config e workspace | -| `picoclaw onboard weixin` | Conectar conta WeChat via QR | +| `picoclaw auth weixin` | Conectar conta WeChat via QR | | `picoclaw agent -m "..."` | Conversar com o agent | | `picoclaw agent` | Modo de chat interativo | | `picoclaw gateway` | Iniciar o gateway | diff --git a/README.vi.md b/README.vi.md index b63fd4ef7..ceeb02b63 100644 --- a/README.vi.md +++ b/README.vi.md @@ -520,7 +520,7 @@ Kết nối PicoClaw với Mạng xã hội Agent chỉ bằng cách gửi một | Lệnh | Mô tả | | ------------------------- | ---------------------------------------- | | `picoclaw onboard` | Khởi tạo cấu hình & workspace | -| `picoclaw onboard weixin` | Kết nối tài khoản WeChat qua QR | +| `picoclaw auth weixin` | Kết nối tài khoản WeChat qua QR | | `picoclaw agent -m "..."` | Trò chuyện với agent | | `picoclaw agent` | Chế độ trò chuyện tương tác | | `picoclaw gateway` | Khởi động gateway | diff --git a/README.zh.md b/README.zh.md index 
de96e5164..93abf89d3 100644 --- a/README.zh.md +++ b/README.zh.md @@ -520,7 +520,7 @@ PicoClaw 原生支持 [MCP](https://modelcontextprotocol.io/) — 连接任意 M | 命令 | 说明 | | ------------------------- | ---------------------- | | `picoclaw onboard` | 初始化配置与工作区 | -| `picoclaw onboard weixin` | 扫码连接微信个人号 | +| `picoclaw auth weixin` | 扫码连接微信个人号 | | `picoclaw agent -m "..."` | 与 Agent 对话 | | `picoclaw agent` | 交互式对话模式 | | `picoclaw gateway` | 启动网关 | diff --git a/cmd/picoclaw/internal/auth/command.go b/cmd/picoclaw/internal/auth/command.go index 12a0a3a8c..149095699 100644 --- a/cmd/picoclaw/internal/auth/command.go +++ b/cmd/picoclaw/internal/auth/command.go @@ -16,6 +16,7 @@ func NewAuthCommand() *cobra.Command { newLogoutCommand(), newStatusCommand(), newModelsCommand(), + newWeixinCommand(), ) return cmd diff --git a/cmd/picoclaw/internal/auth/command_test.go b/cmd/picoclaw/internal/auth/command_test.go index 48dc704dd..12f2bc186 100644 --- a/cmd/picoclaw/internal/auth/command_test.go +++ b/cmd/picoclaw/internal/auth/command_test.go @@ -32,6 +32,7 @@ func TestNewAuthCommand(t *testing.T) { "logout", "status", "models", + "weixin", } subcommands := cmd.Commands() diff --git a/cmd/picoclaw/internal/onboard/weixin.go b/cmd/picoclaw/internal/auth/weixin.go similarity index 98% rename from cmd/picoclaw/internal/onboard/weixin.go rename to cmd/picoclaw/internal/auth/weixin.go index 2e1c2ad75..948a81495 100644 --- a/cmd/picoclaw/internal/onboard/weixin.go +++ b/cmd/picoclaw/internal/auth/weixin.go @@ -1,4 +1,4 @@ -package onboard +package auth import ( "context" @@ -27,7 +27,7 @@ to authorize your account. On success, the bot token is saved to the picoclaw config so you can start the gateway immediately. 
Example: - picoclaw onboard weixin`, + picoclaw auth weixin`, RunE: func(cmd *cobra.Command, _ []string) error { return runWeixinOnboard(baseURL, proxy, time.Duration(timeout)*time.Second) }, diff --git a/cmd/picoclaw/internal/onboard/command.go b/cmd/picoclaw/internal/onboard/command.go index 1f94c6718..4be19b2a5 100644 --- a/cmd/picoclaw/internal/onboard/command.go +++ b/cmd/picoclaw/internal/onboard/command.go @@ -16,7 +16,7 @@ func NewOnboardCommand() *cobra.Command { cmd := &cobra.Command{ Use: "onboard", Aliases: []string{"o"}, - Short: "Initialize picoclaw configuration, workspace, and channel accounts", + Short: "Initialize picoclaw configuration and workspace", // Run without subcommands → original onboard flow Run: func(cmd *cobra.Command, args []string) { if len(args) == 0 { @@ -30,8 +30,5 @@ func NewOnboardCommand() *cobra.Command { cmd.Flags().BoolVar(&encrypt, "enc", false, "Enable credential encryption (generates SSH key and prompts for passphrase)") - // Channel onboarding subcommands - cmd.AddCommand(newWeixinCommand()) - return cmd } diff --git a/cmd/picoclaw/internal/onboard/command_test.go b/cmd/picoclaw/internal/onboard/command_test.go index 6b9fb6e95..56936190b 100644 --- a/cmd/picoclaw/internal/onboard/command_test.go +++ b/cmd/picoclaw/internal/onboard/command_test.go @@ -13,7 +13,7 @@ func TestNewOnboardCommand(t *testing.T) { require.NotNil(t, cmd) assert.Equal(t, "onboard", cmd.Use) - assert.Equal(t, "Initialize picoclaw configuration, workspace, and channel accounts", cmd.Short) + assert.Equal(t, "Initialize picoclaw configuration and workspace", cmd.Short) assert.Len(t, cmd.Aliases, 1) assert.True(t, cmd.HasAlias("o")) @@ -28,6 +28,5 @@ func TestNewOnboardCommand(t *testing.T) { encFlag := cmd.Flags().Lookup("enc") require.NotNil(t, encFlag, "expected --enc flag to be registered") assert.Equal(t, "false", encFlag.DefValue, "--enc should default to false") - assert.True(t, cmd.HasSubCommands()) - assert.NotNil(t, cmd.Commands()) + 
assert.False(t, cmd.HasSubCommands()) } diff --git a/docs/channels/weixin/README.md b/docs/channels/weixin/README.md index 22687fec4..0c51ff3c5 100644 --- a/docs/channels/weixin/README.md +++ b/docs/channels/weixin/README.md @@ -7,7 +7,7 @@ PicoClaw supports connecting to your personal WeChat account using the official The easiest way to set up the Weixin channel is using the interactive onboarding command: ```bash -picoclaw onboard weixin +picoclaw auth weixin ``` This command will: diff --git a/docs/channels/weixin/README.zh.md b/docs/channels/weixin/README.zh.md index d5e6f0a49..0f1181878 100644 --- a/docs/channels/weixin/README.zh.md +++ b/docs/channels/weixin/README.zh.md @@ -7,7 +7,7 @@ PicoClaw 支持使用腾讯官方 iLink API 连接您的个人微信账号。 最简单的方法是使用交互式 onboarding 命令进行一键激活: ```bash -picoclaw onboard weixin +picoclaw auth weixin ``` 该命令将: diff --git a/docs/chat-apps.md b/docs/chat-apps.md index d300f5544..4a78f465e 100644 --- a/docs/chat-apps.md +++ b/docs/chat-apps.md @@ -190,7 +190,7 @@ PicoClaw supports connecting to your personal WeChat account using the official Run the interactive QR login flow: ```bash -picoclaw onboard weixin +picoclaw auth weixin ``` Scan the printed QR code with your WeChat mobile app. On success, the token is saved to your config. diff --git a/docs/fr/chat-apps.md b/docs/fr/chat-apps.md index daff951f4..c36e002ff 100644 --- a/docs/fr/chat-apps.md +++ b/docs/fr/chat-apps.md @@ -179,7 +179,7 @@ PicoClaw prend en charge la connexion à votre compte WeChat personnel via l'API Lancez le flux de connexion interactif par QR code : ```bash -picoclaw onboard weixin +picoclaw auth weixin ``` Scannez le QR code affiché avec votre application WeChat mobile. Une fois connecté, le token est sauvegardé dans votre configuration. 
diff --git a/docs/ja/chat-apps.md b/docs/ja/chat-apps.md index 789c0125f..341dc4aba 100644 --- a/docs/ja/chat-apps.md +++ b/docs/ja/chat-apps.md @@ -184,7 +184,7 @@ PicoClaw は Tencent iLink 公式 API を使用して WeChat 個人アカウン インタラクティブな QR ログインフローを実行します: ```bash -picoclaw onboard weixin +picoclaw auth weixin ``` WeChat モバイルアプリで表示された QR コードをスキャンしてください。ログイン成功後、トークンが設定ファイルに保存されます。 diff --git a/docs/pt-br/chat-apps.md b/docs/pt-br/chat-apps.md index 4fa59b1b2..92fda329c 100644 --- a/docs/pt-br/chat-apps.md +++ b/docs/pt-br/chat-apps.md @@ -179,7 +179,7 @@ O PicoClaw suporta conexão com sua conta pessoal do WeChat usando a API oficial Execute o fluxo de login interativo por QR code: ```bash -picoclaw onboard weixin +picoclaw auth weixin ``` Escaneie o QR code exibido com seu aplicativo WeChat mobile. Após o login bem-sucedido, o token é salvo na sua configuração. diff --git a/docs/vi/chat-apps.md b/docs/vi/chat-apps.md index d907e5e91..5e2a81ccf 100644 --- a/docs/vi/chat-apps.md +++ b/docs/vi/chat-apps.md @@ -179,7 +179,7 @@ PicoClaw hỗ trợ kết nối với tài khoản WeChat cá nhân của bạn Chạy luồng đăng nhập QR tương tác: ```bash -picoclaw onboard weixin +picoclaw auth weixin ``` Quét mã QR được in ra bằng ứng dụng WeChat trên điện thoại. Sau khi đăng nhập thành công, token sẽ được lưu vào cấu hình. 
diff --git a/docs/zh/chat-apps.md b/docs/zh/chat-apps.md index aeba7d460..4d1451d68 100644 --- a/docs/zh/chat-apps.md +++ b/docs/zh/chat-apps.md @@ -191,7 +191,7 @@ PicoClaw 通过腾讯 iLink 官方 API 支持连接微信个人号。 运行交互式扫码登录流程: ```bash -picoclaw onboard weixin +picoclaw auth weixin ``` 用微信手机端扫描打印出的二维码。登录成功后,token 会自动保存到配置文件。 From 1ef2b6903dbaeb07d026aa0170e398293aa7f83c Mon Sep 17 00:00:00 2001 From: lc6464 <64722907+lc6464@users.noreply.github.com> Date: Tue, 24 Mar 2026 13:54:04 +0800 Subject: [PATCH 13/39] test(web): add percentage checking of characters displaying in APIKey --- web/backend/api/models.go | 2 +- web/backend/api/models_test.go | 19 +++++++++++++++++++ 2 files changed, 20 insertions(+), 1 deletion(-) diff --git a/web/backend/api/models.go b/web/backend/api/models.go index 142363079..64a7b5f1f 100644 --- a/web/backend/api/models.go +++ b/web/backend/api/models.go @@ -311,7 +311,7 @@ func (h *Handler) handleSetDefaultModel(w http.ResponseWriter, r *http.Request) // Keys 9-12 chars show prefix + last 2 chars: "sk-****cd". // Shorter keys are fully masked as "****". // Empty keys return empty string. -// Ensure at least 40% of the key is masked. +// Ensure at least 40% of the key will not be displayed. 
func maskAPIKey(key string) string { if key == "" { return "" diff --git a/web/backend/api/models_test.go b/web/backend/api/models_test.go index 5378e986e..0127ce675 100644 --- a/web/backend/api/models_test.go +++ b/web/backend/api/models_test.go @@ -4,6 +4,7 @@ import ( "encoding/json" "net/http" "net/http/httptest" + "strings" "sync" "testing" "time" @@ -365,6 +366,24 @@ func TestMaskAPIKey(t *testing.T) { if got != tc.want { t.Fatalf("maskAPIKey(%q) = %q, want %q", tc.key, got, tc.want) } + + if tc.key != "" { + displayed := strings.Replace(tc.want, "****", "", 1) + if len(tc.key) <= 8 { + if displayed != "" { + t.Fatalf("maskAPIKey(%q) displayed part = %q, want empty", tc.key, displayed) + } + } else { + if len(displayed)*10 > len(tc.key)*6 { + t.Fatalf( + "maskAPIKey(%q) displayed length = %d, want at most 60%% of %d", + tc.key, + len(displayed), + len(tc.key), + ) + } + } + } }) } } From d921bbb66727519d769df707f67817b6b3579c43 Mon Sep 17 00:00:00 2001 From: Cytown Date: Tue, 24 Mar 2026 16:24:12 +0800 Subject: [PATCH 14/39] bug fix for security initial cause can't save model in launcher (#1952) --- pkg/config/config.go | 1 + pkg/config/security.go | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/pkg/config/config.go b/pkg/config/config.go index f0d9aa580..a943fb2eb 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -1946,6 +1946,7 @@ func SaveConfig(path string, cfg *Config) error { if err != nil { return err } + logger.Infof("saving config to %s", path) return fileutil.WriteFileAtomic(path, data, 0o600) } diff --git a/pkg/config/security.go b/pkg/config/security.go index 816d465c7..1fda89bf0 100644 --- a/pkg/config/security.go +++ b/pkg/config/security.go @@ -31,7 +31,7 @@ type SecurityConfig struct { // Model API keys. Map key is model_name, can include suffix like "abc:0", "abc:1" // for load balancing with same model_name. The suffix ":N" is used to distinguish // multiple configs that share the same base model_name. 
- ModelList map[string]ModelSecurityEntry `yaml:"model_list,omitempty"` + ModelList map[string]ModelSecurityEntry `yaml:"model_list"` // Channel tokens/secrets Channels *ChannelsSecurity `yaml:"channels,omitempty"` From d23c24ce72977f3c87072813bde412ee7e8b9821 Mon Sep 17 00:00:00 2001 From: wenjie Date: Tue, 24 Mar 2026 17:03:28 +0800 Subject: [PATCH 15/39] fix(config): normalize empty security config before save/load (#1956) Normalize missing security sections when attaching, loading, and saving security config so existing config files without `.security.yml` can still be updated safely. This fixes Pico channel setup for legacy/existing configs and adds coverage for the missing security file path and unexported JSON field behavior. --- pkg/config/config.go | 2 ++ pkg/config/security.go | 23 +++++++++++++-- pkg/config/security_integration_test.go | 10 +++---- pkg/config/security_test.go | 3 ++ web/backend/api/config_test.go | 2 +- web/backend/api/pico_test.go | 39 +++++++++++++++++++++++++ 6 files changed, 70 insertions(+), 9 deletions(-) diff --git a/pkg/config/config.go b/pkg/config/config.go index 00f587159..8073dc723 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -106,6 +106,7 @@ func (c *Config) WithSecurity(sec *SecurityConfig) *Config { c.security = sec return c } + sec = normalizeSecurityConfig(sec) err := applySecurityConfig(c, sec) if err != nil { return nil @@ -1768,6 +1769,7 @@ func SaveConfig(path string, cfg *Config) error { logger.ErrorC("config", "security is nil") return fmt.Errorf("security is nil") } + cfg.security = normalizeSecurityConfig(cfg.security) // Ensure version is always set when saving if cfg.Version == 0 { cfg.Version = CurrentVersion diff --git a/pkg/config/security.go b/pkg/config/security.go index 1fda89bf0..5c71bf8c3 100644 --- a/pkg/config/security.go +++ b/pkg/config/security.go @@ -25,6 +25,25 @@ const ( SecurityConfigFile = ".security.yml" ) +func normalizeSecurityConfig(sec *SecurityConfig) *SecurityConfig { 
+ if sec == nil { + sec = &SecurityConfig{} + } + if sec.ModelList == nil { + sec.ModelList = map[string]ModelSecurityEntry{} + } + if sec.Channels == nil { + sec.Channels = &ChannelsSecurity{} + } + if sec.Web == nil { + sec.Web = &WebToolsSecurity{} + } + if sec.Skills == nil { + sec.Skills = &SkillsSecurity{} + } + return sec +} + // SecurityConfig stores all sensitive data (API keys, tokens, secrets, passwords) // This data is loaded from security.yml and kept separate from the main config type SecurityConfig struct { @@ -191,7 +210,7 @@ func loadSecurityConfig(securityPath string) (*SecurityConfig, error) { data, err := os.ReadFile(securityPath) if err != nil { if os.IsNotExist(err) { - return &SecurityConfig{}, nil + return normalizeSecurityConfig(nil), nil } return nil, fmt.Errorf("failed to read security config: %w", err) } @@ -210,7 +229,7 @@ func loadSecurityConfig(securityPath string) (*SecurityConfig, error) { return nil, err } - return &sec, nil + return normalizeSecurityConfig(&sec), nil } // saveSecurityConfig saves the security configuration to security.yml diff --git a/pkg/config/security_integration_test.go b/pkg/config/security_integration_test.go index c1e1a2340..218914590 100644 --- a/pkg/config/security_integration_test.go +++ b/pkg/config/security_integration_test.go @@ -17,13 +17,12 @@ import ( // Test JSON unmarshal of private fields func TestJSONUnmarshalPrivateFields(t *testing.T) { - //nolint: govet type testStruct struct { PublicField string `json:"public"` - privateField string `json:"private"` + privateField string } - data := `{"public": "pub", "private": "priv"}` + data := `{"public": "pub", "privateField": "priv"}` var s testStruct if err := json.Unmarshal([]byte(data), &s); err != nil { t.Fatalf("JSON unmarshal failed: %v", err) @@ -35,9 +34,8 @@ func TestJSONUnmarshalPrivateFields(t *testing.T) { if s.PublicField != "pub" { t.Errorf("PublicField = %q, want 'pub'", s.PublicField) } - // This should fail because privateField is 
unexported - if s.privateField != "priv" { - t.Logf("privateField = %q, want 'priv' - THIS IS EXPECTED TO FAIL", s.privateField) + if s.privateField != "" { + t.Errorf("privateField = %q, want empty because unexported fields are ignored", s.privateField) } } diff --git a/pkg/config/security_test.go b/pkg/config/security_test.go index af08a67db..0f260ed59 100644 --- a/pkg/config/security_test.go +++ b/pkg/config/security_test.go @@ -20,6 +20,9 @@ func TestSecurityConfig(t *testing.T) { require.NoError(t, err) assert.NotNil(t, sec) assert.Empty(t, sec.ModelList) + assert.NotNil(t, sec.Channels) + assert.NotNil(t, sec.Web) + assert.NotNil(t, sec.Skills) }) } diff --git a/web/backend/api/config_test.go b/web/backend/api/config_test.go index cf8cd505e..9b05546f9 100644 --- a/web/backend/api/config_test.go +++ b/web/backend/api/config_test.go @@ -170,7 +170,7 @@ func setupPicoEnabledEnv(t *testing.T) (string, func()) { ModelList: map[string]config.ModelSecurityEntry{ "custom-default": {APIKeys: []string{"sk-default"}}, }, - Channels: config.ChannelsSecurity{ + Channels: &config.ChannelsSecurity{ Pico: &config.PicoSecurity{Token: "test-pico-token"}, }, }) diff --git a/web/backend/api/pico_test.go b/web/backend/api/pico_test.go index 263253cb2..b59878bf3 100644 --- a/web/backend/api/pico_test.go +++ b/web/backend/api/pico_test.go @@ -6,6 +6,7 @@ import ( "net/http" "net/http/httptest" "net/url" + "os" "path/filepath" "strconv" "testing" @@ -154,6 +155,44 @@ func TestEnsurePicoChannel_PreservesUserSettings(t *testing.T) { } } +func TestEnsurePicoChannel_ExistingConfigWithoutSecurityFile(t *testing.T) { + configPath := filepath.Join(t.TempDir(), "config.json") + + cfg := config.DefaultConfig() + raw, err := json.Marshal(cfg) + if err != nil { + t.Fatalf("Marshal() error = %v", err) + } + if err = os.WriteFile(configPath, raw, 0o600); err != nil { + t.Fatalf("WriteFile() error = %v", err) + } + + h := NewHandler(configPath) + + changed, err := h.ensurePicoChannel("") + if err 
!= nil { + t.Fatalf("ensurePicoChannel() error = %v", err) + } + if !changed { + t.Fatal("ensurePicoChannel() should report changed when pico is missing") + } + + cfg, err = config.LoadConfig(configPath) + if err != nil { + t.Fatalf("LoadConfig() error = %v", err) + } + + if !cfg.Channels.Pico.Enabled { + t.Error("expected Pico to be enabled after setup") + } + if cfg.Channels.Pico.Token() == "" { + t.Error("expected a non-empty token after setup") + } + if _, err := os.Stat(filepath.Join(filepath.Dir(configPath), config.SecurityConfigFile)); err != nil { + t.Fatalf("expected .security.yml to be created: %v", err) + } +} + func TestEnsurePicoChannel_Idempotent(t *testing.T) { configPath := filepath.Join(t.TempDir(), "config.json") h := NewHandler(configPath) From ffbcbea4dcd8be86716da2ef2c616a4217435293 Mon Sep 17 00:00:00 2001 From: wenjie Date: Tue, 24 Mar 2026 17:31:28 +0800 Subject: [PATCH 16/39] fix(web): persist api_key when adding models (#1958) Make POST /api/models capture the request's api_key and store it via ModelConfig.SetAPIKey before saving config, so newly added models keep their credentials in the security config. Add a backend API test covering model creation with api_key persistence. 
--- web/backend/api/models.go | 13 ++++++++++-- web/backend/api/models_test.go | 39 ++++++++++++++++++++++++++++++++++ 2 files changed, 50 insertions(+), 2 deletions(-) diff --git a/web/backend/api/models.go b/web/backend/api/models.go index 64a7b5f1f..48babd8cd 100644 --- a/web/backend/api/models.go +++ b/web/backend/api/models.go @@ -108,7 +108,12 @@ func (h *Handler) handleAddModel(w http.ResponseWriter, r *http.Request) { } defer r.Body.Close() - var mc config.ModelConfig + type custom struct { + config.ModelConfig + APIKey string `json:"api_key"` + } + + var mc custom if err = json.Unmarshal(body, &mc); err != nil { http.Error(w, fmt.Sprintf("Invalid JSON: %v", err), http.StatusBadRequest) return @@ -119,13 +124,17 @@ func (h *Handler) handleAddModel(w http.ResponseWriter, r *http.Request) { return } + if mc.APIKey != "" { + mc.ModelConfig.SetAPIKey(mc.APIKey) + } + cfg, err := config.LoadConfig(h.configPath) if err != nil { http.Error(w, fmt.Sprintf("Failed to load config: %v", err), http.StatusInternalServerError) return } - cfg.ModelList = append(cfg.ModelList, &mc) + cfg.ModelList = append(cfg.ModelList, &mc.ModelConfig) if err := config.SaveConfig(h.configPath, cfg); err != nil { http.Error(w, fmt.Sprintf("Failed to save config: %v", err), http.StatusInternalServerError) diff --git a/web/backend/api/models_test.go b/web/backend/api/models_test.go index 0127ce675..9d3e72bd3 100644 --- a/web/backend/api/models_test.go +++ b/web/backend/api/models_test.go @@ -1,6 +1,7 @@ package api import ( + "bytes" "encoding/json" "net/http" "net/http/httptest" @@ -317,6 +318,44 @@ func TestHandleListModels_NormalizesWildcardLocalAPIBaseForProbe(t *testing.T) { } } +func TestHandleAddModel_PersistsAPIKey(t *testing.T) { + configPath, cleanup := setupOAuthTestEnv(t) + defer cleanup() + + h := NewHandler(configPath) + mux := http.NewServeMux() + h.RegisterRoutes(mux) + + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPost, "/api/models", 
bytes.NewBufferString(`{ + "model_name":"new-model", + "model":"openai/gpt-4o-mini", + "api_key":"sk-new-model-key" + }`)) + req.Header.Set("Content-Type", "application/json") + mux.ServeHTTP(rec, req) + + if rec.Code != http.StatusOK { + t.Fatalf("status = %d, want %d, body=%s", rec.Code, http.StatusOK, rec.Body.String()) + } + + cfg, err := config.LoadConfig(configPath) + if err != nil { + t.Fatalf("LoadConfig() error = %v", err) + } + if len(cfg.ModelList) != 2 { + t.Fatalf("len(model_list) = %d, want 2", len(cfg.ModelList)) + } + + added := cfg.ModelList[1] + if added.ModelName != "new-model" { + t.Fatalf("model_name = %q, want %q", added.ModelName, "new-model") + } + if added.APIKey() != "sk-new-model-key" { + t.Fatalf("api_key = %q, want %q", added.APIKey(), "sk-new-model-key") + } +} + func TestMaskAPIKey(t *testing.T) { tests := []struct { name string From dea99da7d92ab9be6babc67b3bd83e59c9a62cad Mon Sep 17 00:00:00 2001 From: wenjie Date: Tue, 24 Mar 2026 18:06:29 +0800 Subject: [PATCH 17/39] fix(web): auto-configure Pico channel on launcher startup Export EnsurePicoChannel and reuse it during launcher and gateway startup so the Pico channel is initialized earlier with a generated token when needed. Also extend backend tests to cover startup-time Pico setup behavior and keep the setup path idempotent. 
--- web/backend/api/gateway.go | 2 +- web/backend/api/pico.go | 6 +-- web/backend/api/pico_test.go | 71 +++++++++++++++++++++++++----------- web/backend/main.go | 3 ++ 4 files changed, 56 insertions(+), 26 deletions(-) diff --git a/web/backend/api/gateway.go b/web/backend/api/gateway.go index 7f72f12b8..4bde5ce82 100644 --- a/web/backend/api/gateway.go +++ b/web/backend/api/gateway.go @@ -407,7 +407,7 @@ func (h *Handler) startGatewayLocked(initialStatus string, existingPid int) (int gateway.logs.Reset() // Ensure Pico Channel is configured before starting gateway - if _, err := h.ensurePicoChannel(""); err != nil { + if _, err := h.EnsurePicoChannel(""); err != nil { logger.ErrorC("gateway", fmt.Sprintf("Warning: failed to ensure pico channel: %v", err)) // Non-fatal: gateway can still start without pico channel } diff --git a/web/backend/api/pico.go b/web/backend/api/pico.go index 8fbb8737f..4faafc2ae 100644 --- a/web/backend/api/pico.go +++ b/web/backend/api/pico.go @@ -90,14 +90,14 @@ func (h *Handler) handleRegenPicoToken(w http.ResponseWriter, r *http.Request) { }) } -// ensurePicoChannel enables the Pico channel with sane defaults if it isn't +// EnsurePicoChannel enables the Pico channel with sane defaults if it isn't // already configured. Returns true when the config was modified. // // callerOrigin is the Origin header from the setup request. If non-empty and // no origins are configured yet, it's written as the allowed origin so the // WebSocket handshake works for whatever host the caller is on (LAN, custom // port, etc.). Pass "" when there's no request context. 
-func (h *Handler) ensurePicoChannel(callerOrigin string) (bool, error) { +func (h *Handler) EnsurePicoChannel(callerOrigin string) (bool, error) { cfg, err := config.LoadConfig(h.configPath) if err != nil { return false, fmt.Errorf("failed to load config: %w", err) @@ -134,7 +134,7 @@ func (h *Handler) ensurePicoChannel(callerOrigin string) (bool, error) { // // POST /api/pico/setup func (h *Handler) handlePicoSetup(w http.ResponseWriter, r *http.Request) { - changed, err := h.ensurePicoChannel(r.Header.Get("Origin")) + changed, err := h.EnsurePicoChannel(r.Header.Get("Origin")) if err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) return diff --git a/web/backend/api/pico_test.go b/web/backend/api/pico_test.go index b59878bf3..051e356cf 100644 --- a/web/backend/api/pico_test.go +++ b/web/backend/api/pico_test.go @@ -18,12 +18,12 @@ func TestEnsurePicoChannel_FreshConfig(t *testing.T) { configPath := filepath.Join(t.TempDir(), "config.json") h := NewHandler(configPath) - changed, err := h.ensurePicoChannel("") + changed, err := h.EnsurePicoChannel("") if err != nil { - t.Fatalf("ensurePicoChannel() error = %v", err) + t.Fatalf("EnsurePicoChannel() error = %v", err) } if !changed { - t.Fatal("ensurePicoChannel() should report changed on a fresh config") + t.Fatal("EnsurePicoChannel() should report changed on a fresh config") } cfg, err := config.LoadConfig(configPath) @@ -43,8 +43,8 @@ func TestEnsurePicoChannel_DoesNotEnableTokenQuery(t *testing.T) { configPath := filepath.Join(t.TempDir(), "config.json") h := NewHandler(configPath) - if _, err := h.ensurePicoChannel(""); err != nil { - t.Fatalf("ensurePicoChannel() error = %v", err) + if _, err := h.EnsurePicoChannel(""); err != nil { + t.Fatalf("EnsurePicoChannel() error = %v", err) } cfg, err := config.LoadConfig(configPath) @@ -61,8 +61,8 @@ func TestEnsurePicoChannel_DoesNotSetWildcardOrigins(t *testing.T) { configPath := filepath.Join(t.TempDir(), "config.json") h := 
NewHandler(configPath) - if _, err := h.ensurePicoChannel("http://localhost:18800"); err != nil { - t.Fatalf("ensurePicoChannel() error = %v", err) + if _, err := h.EnsurePicoChannel("http://localhost:18800"); err != nil { + t.Fatalf("EnsurePicoChannel() error = %v", err) } cfg, err := config.LoadConfig(configPath) @@ -81,8 +81,8 @@ func TestEnsurePicoChannel_NoOriginWithoutCaller(t *testing.T) { configPath := filepath.Join(t.TempDir(), "config.json") h := NewHandler(configPath) - if _, err := h.ensurePicoChannel(""); err != nil { - t.Fatalf("ensurePicoChannel() error = %v", err) + if _, err := h.EnsurePicoChannel(""); err != nil { + t.Fatalf("EnsurePicoChannel() error = %v", err) } cfg, err := config.LoadConfig(configPath) @@ -102,8 +102,8 @@ func TestEnsurePicoChannel_SetsCallerOrigin(t *testing.T) { h := NewHandler(configPath) lanOrigin := "http://192.168.1.9:18800" - if _, err := h.ensurePicoChannel(lanOrigin); err != nil { - t.Fatalf("ensurePicoChannel() error = %v", err) + if _, err := h.EnsurePicoChannel(lanOrigin); err != nil { + t.Fatalf("EnsurePicoChannel() error = %v", err) } cfg, err := config.LoadConfig(configPath) @@ -131,12 +131,12 @@ func TestEnsurePicoChannel_PreservesUserSettings(t *testing.T) { h := NewHandler(configPath) - changed, err := h.ensurePicoChannel("") + changed, err := h.EnsurePicoChannel("") if err != nil { - t.Fatalf("ensurePicoChannel() error = %v", err) + t.Fatalf("EnsurePicoChannel() error = %v", err) } if changed { - t.Error("ensurePicoChannel() should not change a fully configured config") + t.Error("EnsurePicoChannel() should not change a fully configured config") } cfg, err = config.LoadConfig(configPath) @@ -169,12 +169,12 @@ func TestEnsurePicoChannel_ExistingConfigWithoutSecurityFile(t *testing.T) { h := NewHandler(configPath) - changed, err := h.ensurePicoChannel("") + changed, err := h.EnsurePicoChannel("") if err != nil { - t.Fatalf("ensurePicoChannel() error = %v", err) + t.Fatalf("EnsurePicoChannel() error = %v", err) 
} if !changed { - t.Fatal("ensurePicoChannel() should report changed when pico is missing") + t.Fatal("EnsurePicoChannel() should report changed when pico is missing") } cfg, err = config.LoadConfig(configPath) @@ -193,6 +193,33 @@ func TestEnsurePicoChannel_ExistingConfigWithoutSecurityFile(t *testing.T) { } } +func TestEnsurePicoChannel_ConfiguresPicoWithoutGateway(t *testing.T) { + configPath := filepath.Join(t.TempDir(), "config.json") + + cfg := config.DefaultConfig() + cfg.Agents.Defaults.ModelName = "" + if err := config.SaveConfig(configPath, cfg); err != nil { + t.Fatalf("SaveConfig() error = %v", err) + } + + h := NewHandler(configPath) + if _, err := h.EnsurePicoChannel(""); err != nil { + t.Fatalf("EnsurePicoChannel() error = %v", err) + } + + cfg, err := config.LoadConfig(configPath) + if err != nil { + t.Fatalf("LoadConfig() error = %v", err) + } + + if !cfg.Channels.Pico.Enabled { + t.Error("expected Pico to be enabled after launcher startup setup") + } + if cfg.Channels.Pico.Token() == "" { + t.Error("expected a non-empty token after launcher startup setup") + } +} + func TestEnsurePicoChannel_Idempotent(t *testing.T) { configPath := filepath.Join(t.TempDir(), "config.json") h := NewHandler(configPath) @@ -200,20 +227,20 @@ func TestEnsurePicoChannel_Idempotent(t *testing.T) { origin := "http://localhost:18800" // First call sets things up - if _, err := h.ensurePicoChannel(origin); err != nil { - t.Fatalf("first ensurePicoChannel() error = %v", err) + if _, err := h.EnsurePicoChannel(origin); err != nil { + t.Fatalf("first EnsurePicoChannel() error = %v", err) } cfg1, _ := config.LoadConfig(configPath) token1 := cfg1.Channels.Pico.Token() // Second call should be a no-op - changed, err := h.ensurePicoChannel(origin) + changed, err := h.EnsurePicoChannel(origin) if err != nil { - t.Fatalf("second ensurePicoChannel() error = %v", err) + t.Fatalf("second EnsurePicoChannel() error = %v", err) } if changed { - t.Error("second ensurePicoChannel() should 
not report changed") + t.Error("second EnsurePicoChannel() should not report changed") } cfg2, _ := config.LoadConfig(configPath) diff --git a/web/backend/main.go b/web/backend/main.go index 8183731fe..2f181603e 100644 --- a/web/backend/main.go +++ b/web/backend/main.go @@ -169,6 +169,9 @@ func main() { // API Routes (e.g. /api/status) apiHandler = api.NewHandler(absPath) + if _, err = apiHandler.EnsurePicoChannel(""); err != nil { + logger.ErrorC("web", fmt.Sprintf("Warning: failed to ensure pico channel on startup: %v", err)) + } apiHandler.SetServerOptions(portNum, effectivePublic, explicitPublic, launcherCfg.AllowedCIDRs) apiHandler.RegisterRoutes(mux) From fcc20ec72ccc2f9c413aa46239c6ab21ac976e28 Mon Sep 17 00:00:00 2001 From: Sabyasachi Patra Date: Tue, 24 Mar 2026 16:05:56 +0530 Subject: [PATCH 18/39] feat(tools): add tool argument schema validation before execution (#1877) Validate tool call arguments against each tool's Parameters() JSON Schema in ExecuteWithContext() before calling Execute(). This prevents type confusion, argument injection, and missing-field errors from reaching tools. Validates: required fields, type matching (string/integer/number/boolean/ array/object), enum membership, nested objects (recursive), array element types. Rejects unexpected extra properties unless additionalProperties is set to true (for MCP tool compatibility). Returns ToolResult{IsError: true} on failure so the LLM can self-correct. 
Ref: Security Hardening > Tool abuse prevention via strict parameter validation --- .gitignore | 1 + pkg/tools/registry.go | 8 + pkg/tools/validate.go | 209 +++++++++++++++++ pkg/tools/validate_test.go | 465 +++++++++++++++++++++++++++++++++++++ 4 files changed, 683 insertions(+) create mode 100644 pkg/tools/validate.go create mode 100644 pkg/tools/validate_test.go diff --git a/.gitignore b/.gitignore index 8b5f95215..72f3b1761 100644 --- a/.gitignore +++ b/.gitignore @@ -40,6 +40,7 @@ tasks/ # Plans docs/plans/ +docs/superpowers/ # Editors .vscode/ diff --git a/pkg/tools/registry.go b/pkg/tools/registry.go index ed373a28f..2c634e673 100644 --- a/pkg/tools/registry.go +++ b/pkg/tools/registry.go @@ -180,6 +180,14 @@ func (r *ToolRegistry) ExecuteWithContext( return ErrorResult(fmt.Sprintf("tool %q not found", name)).WithError(fmt.Errorf("tool not found")) } + // Validate arguments against the tool's declared schema. + if err := validateToolArgs(tool.Parameters(), args); err != nil { + logger.WarnCF("tool", "Tool argument validation failed", + map[string]any{"tool": name, "error": err.Error()}) + return ErrorResult(fmt.Sprintf("invalid arguments for tool %q: %s", name, err)). + WithError(fmt.Errorf("argument validation failed: %w", err)) + } + // Inject channel/chatID into ctx so tools read them via ToolChannel(ctx)/ToolChatID(ctx). // Always inject — tools validate what they require. ctx = WithToolContext(ctx, channel, chatID) diff --git a/pkg/tools/validate.go b/pkg/tools/validate.go new file mode 100644 index 000000000..940344708 --- /dev/null +++ b/pkg/tools/validate.go @@ -0,0 +1,209 @@ +package tools + +import ( + "fmt" + "math" +) + +// validateToolArgs validates args against a JSON Schema-like map. +// schema is expected to have optional keys: "properties", "required", "additionalProperties". 
+func validateToolArgs(schema map[string]any, args map[string]any) error { + if len(schema) == 0 { + return nil + } + + if args == nil { + args = map[string]any{} + } + + if err := checkRequired(schema, args); err != nil { + return err + } + + propsRaw, ok := schema["properties"] + if !ok { + return nil // no properties defined — accept any args + } + + props, ok := propsRaw.(map[string]any) + if !ok { + return nil + } + + additional := allowsAdditional(schema) + + for key, val := range args { + propSchemaRaw, known := props[key] + if !known { + if !additional { + return fmt.Errorf("unexpected property %q", key) + } + continue + } + propSchema, ok := propSchemaRaw.(map[string]any) + if !ok { + continue // can't validate without a proper schema map + } + if err := checkType(key, val, propSchema); err != nil { + return err + } + } + + return nil +} + +// checkRequired verifies that every field listed in schema["required"] is present in args. +func checkRequired(schema map[string]any, args map[string]any) error { + reqRaw, ok := schema["required"] + if !ok { + return nil + } + + var required []string + + switch r := reqRaw.(type) { + case []string: + required = r + case []any: + for _, v := range r { + s, ok := v.(string) + if ok { + required = append(required, s) + } + } + default: + return nil + } + + for _, field := range required { + if _, present := args[field]; !present { + return fmt.Errorf("missing required property %q", field) + } + } + return nil +} + +// allowsAdditional returns true when the schema explicitly sets +// "additionalProperties" to true, or when the key is absent (default: reject extras). +func allowsAdditional(schema map[string]any) bool { + v, ok := schema["additionalProperties"] + if !ok { + return false + } + b, ok := v.(bool) + return ok && b +} + +// checkType validates that val matches the JSON Schema type declared in propSchema. 
+func checkType(key string, val any, propSchema map[string]any) error { + typeRaw, ok := propSchema["type"] + if !ok { + return nil // no type constraint + } + typeName, ok := typeRaw.(string) + if !ok { + return nil + } + + switch typeName { + case "string": + if _, ok := val.(string); !ok { + return fmt.Errorf("property %q: expected string, got %T", key, val) + } + case "integer": + switch v := val.(type) { + case float64: + if v != math.Trunc(v) { + return fmt.Errorf("property %q: expected integer, got float64 with fractional part", key) + } + case int: + // ok + case int64: + // ok + default: + return fmt.Errorf("property %q: expected integer, got %T", key, val) + } + case "number": + switch val.(type) { + case float64, int, int64: + // ok + default: + return fmt.Errorf("property %q: expected number, got %T", key, val) + } + case "boolean": + if _, ok := val.(bool); !ok { + return fmt.Errorf("property %q: expected boolean, got %T", key, val) + } + case "array": + arr, ok := val.([]any) + if !ok { + return fmt.Errorf("property %q: expected array, got %T", key, val) + } + if err := checkArrayItems(key, arr, propSchema); err != nil { + return err + } + case "object": + obj, ok := val.(map[string]any) + if !ok { + return fmt.Errorf("property %q: expected object, got %T", key, val) + } + if err := validateToolArgs(propSchema, obj); err != nil { + return fmt.Errorf("property %q: %w", key, err) + } + } + + if err := checkEnum(key, val, propSchema); err != nil { + return err + } + + return nil +} + +// checkArrayItems validates each element of arr against the "items" sub-schema. 
+func checkArrayItems(key string, arr []any, propSchema map[string]any) error { + itemsRaw, ok := propSchema["items"] + if !ok { + return nil + } + itemSchema, ok := itemsRaw.(map[string]any) + if !ok { + return nil + } + for i, elem := range arr { + elemKey := fmt.Sprintf("%s[%d]", key, i) + if err := checkType(elemKey, elem, itemSchema); err != nil { + return err + } + } + return nil +} + +// checkEnum validates that val is one of the allowed enum values in propSchema. +func checkEnum(key string, val any, propSchema map[string]any) error { + enumRaw, ok := propSchema["enum"] + if !ok { + return nil + } + + switch ev := enumRaw.(type) { + case []any: + for _, allowed := range ev { + if val == allowed { + return nil + } + } + case []string: + s, ok := val.(string) + if ok { + for _, allowed := range ev { + if s == allowed { + return nil + } + } + } + default: + return nil // unknown enum format, skip + } + + return fmt.Errorf("property %q: value %v is not in enum", key, val) +} diff --git a/pkg/tools/validate_test.go b/pkg/tools/validate_test.go new file mode 100644 index 000000000..e7f4f619a --- /dev/null +++ b/pkg/tools/validate_test.go @@ -0,0 +1,465 @@ +package tools + +import ( + "context" + "strings" + "testing" +) + +// Ensure imports are used. 
+var (
+	_ = context.Background
+	_ = strings.Contains
+)
+
+// TestValidateToolArgs table-tests validateToolArgs directly. The cases cover
+// required properties, primitive type checks, enum membership (both []any and
+// []string enum lists), unexpected properties and additionalProperties,
+// nested objects, array item types, and schemas with no properties at all.
+func TestValidateToolArgs(t *testing.T) {
+	// baseSchema is shared by the cases that do not need a custom schema.
+	baseSchema := map[string]any{
+		"type": "object",
+		"properties": map[string]any{
+			"name": map[string]any{"type": "string"},
+			"age":  map[string]any{"type": "integer"},
+		},
+		"required": []string{"name"},
+	}
+
+	tests := []struct {
+		name    string
+		schema  map[string]any
+		args    map[string]any
+		wantErr string // empty means no error expected
+	}{
+		{
+			name:   "valid args all required present",
+			schema: baseSchema,
+			args:   map[string]any{"name": "alice", "age": float64(30)},
+		},
+		{
+			name:    "missing required field",
+			schema:  baseSchema,
+			args:    map[string]any{"age": float64(30)},
+			wantErr: "missing required property \"name\"",
+		},
+		{
+			name:    "wrong type string field gets number",
+			schema:  baseSchema,
+			args:    map[string]any{"name": float64(42)},
+			wantErr: "expected string",
+		},
+		{
+			name:    "nil args with required fields",
+			schema:  baseSchema,
+			args:    nil,
+			wantErr: "missing required property \"name\"",
+		},
+		{
+			name: "nil args no required fields",
+			schema: map[string]any{
+				"type": "object",
+				"properties": map[string]any{
+					"name": map[string]any{"type": "string"},
+				},
+			},
+			args: nil,
+		},
+		{
+			name: "empty args no required fields",
+			schema: map[string]any{
+				"type": "object",
+				"properties": map[string]any{
+					"name": map[string]any{"type": "string"},
+				},
+			},
+			args: map[string]any{},
+		},
+		{
+			name:   "optional field correct type",
+			schema: baseSchema,
+			args:   map[string]any{"name": "bob", "age": float64(25)},
+		},
+		{
+			name:    "optional field wrong type",
+			schema:  baseSchema,
+			args:    map[string]any{"name": "bob", "age": "twenty"},
+			wantErr: "expected integer",
+		},
+		{
+			name:   "integer as float64 no fractional part",
+			schema: baseSchema,
+			args:   map[string]any{"name": "carol", "age": float64(42)},
+		},
+		{
+			name:    "actual float for integer field",
+			schema:  baseSchema,
+			args:    map[string]any{"name": "dave", "age": float64(42.5)},
+			wantErr: "expected integer, got float64 with fractional part",
+		},
+		{
+			name: "number type accepts float",
+			schema: map[string]any{
+				"type": "object",
+				"properties": map[string]any{
+					"score": map[string]any{"type": "number"},
+				},
+			},
+			args: map[string]any{"score": float64(3.14)},
+		},
+		{
+			name: "number type accepts integer",
+			schema: map[string]any{
+				"type": "object",
+				"properties": map[string]any{
+					"score": map[string]any{"type": "number"},
+				},
+			},
+			args: map[string]any{"score": float64(10)},
+		},
+		{
+			name: "boolean type valid",
+			schema: map[string]any{
+				"type": "object",
+				"properties": map[string]any{
+					"flag": map[string]any{"type": "boolean"},
+				},
+			},
+			args: map[string]any{"flag": true},
+		},
+		{
+			name: "boolean type wrong",
+			schema: map[string]any{
+				"type": "object",
+				"properties": map[string]any{
+					"flag": map[string]any{"type": "boolean"},
+				},
+			},
+			args:    map[string]any{"flag": "true"},
+			wantErr: "expected boolean",
+		},
+		{
+			name: "required as []any from MCP deserialization",
+			schema: map[string]any{
+				"type": "object",
+				"properties": map[string]any{
+					"cmd": map[string]any{"type": "string"},
+				},
+				"required": []any{"cmd"},
+			},
+			args:    map[string]any{},
+			wantErr: "missing required property \"cmd\"",
+		},
+		{
+			name: "enum valid value []any",
+			schema: map[string]any{
+				"type": "object",
+				"properties": map[string]any{
+					"color": map[string]any{"type": "string", "enum": []any{"red", "green", "blue"}},
+				},
+			},
+			args: map[string]any{"color": "red"},
+		},
+		{
+			name: "enum invalid value []any",
+			schema: map[string]any{
+				"type": "object",
+				"properties": map[string]any{
+					"color": map[string]any{"type": "string", "enum": []any{"red", "green", "blue"}},
+				},
+			},
+			args:    map[string]any{"color": "yellow"},
+			wantErr: "not in enum",
+		},
+		{
+			name: "enum valid value []string",
+			schema: map[string]any{
+				"type": "object",
+				"properties": map[string]any{
+					"color": map[string]any{"type": "string", "enum": []string{"red", "green", "blue"}},
+				},
+			},
+			args: map[string]any{"color": "green"},
+		},
+		{
+			name: "enum invalid value []string",
+			schema: map[string]any{
+				"type": "object",
+				"properties": map[string]any{
+					"color": map[string]any{"type": "string", "enum": []string{"red", "green", "blue"}},
+				},
+			},
+			args:    map[string]any{"color": "yellow"},
+			wantErr: "not in enum",
+		},
+		{
+			name:    "extra unexpected property rejected",
+			schema:  baseSchema,
+			args:    map[string]any{"name": "eve", "hobby": "chess"},
+			wantErr: "unexpected property \"hobby\"",
+		},
+		{
+			name: "extra property allowed with additionalProperties true",
+			schema: map[string]any{
+				"type": "object",
+				"properties": map[string]any{
+					"name": map[string]any{"type": "string"},
+				},
+				"additionalProperties": true,
+			},
+			args: map[string]any{"name": "eve", "hobby": "chess"},
+		},
+		{
+			name: "nested object valid",
+			schema: map[string]any{
+				"type": "object",
+				"properties": map[string]any{
+					"address": map[string]any{
+						"type": "object",
+						"properties": map[string]any{
+							"city": map[string]any{"type": "string"},
+						},
+						"required": []string{"city"},
+					},
+				},
+			},
+			args: map[string]any{
+				"address": map[string]any{"city": "Berlin"},
+			},
+		},
+		{
+			name: "nested object wrong type",
+			schema: map[string]any{
+				"type": "object",
+				"properties": map[string]any{
+					"address": map[string]any{
+						"type": "object",
+						"properties": map[string]any{
+							"city": map[string]any{"type": "string"},
+						},
+					},
+				},
+			},
+			args:    map[string]any{"address": "not an object"},
+			wantErr: "expected object",
+		},
+		{
+			name: "array with valid element types",
+			schema: map[string]any{
+				"type": "object",
+				"properties": map[string]any{
+					"tags": map[string]any{
+						"type":  "array",
+						"items": map[string]any{"type": "string"},
+					},
+				},
+			},
+			args: map[string]any{"tags": []any{"a", "b", "c"}},
+		},
+		{
+			name: "array with wrong element types",
+			schema: map[string]any{
+				"type": "object",
+				"properties": map[string]any{
+					"tags": map[string]any{
+						"type":  "array",
+						"items": map[string]any{"type": "string"},
+					},
+				},
+			},
+			args:    map[string]any{"tags": []any{"a", float64(2)}},
+			wantErr: "expected string",
+		},
+		{
+			name: "schema with no properties key accepts any args",
+			schema: map[string]any{
+				"type": "object",
+			},
+			args: map[string]any{"anything": "goes"},
+		},
+		{
+			name:   "empty schema accepts anything",
+			schema: map[string]any{},
+			args:   map[string]any{"foo": "bar"},
+		},
+	}
+
+	for _, tc := range tests {
+		t.Run(tc.name, func(t *testing.T) {
+			err := validateToolArgs(tc.schema, tc.args)
+			// An empty wantErr means the case must validate cleanly.
+			if tc.wantErr == "" {
+				if err != nil {
+					t.Fatalf("expected no error, got: %v", err)
+				}
+				return
+			}
+			// Otherwise the error text must contain wantErr as a substring.
+			if err == nil {
+				t.Fatalf("expected error containing %q, got nil", tc.wantErr)
+			}
+			if !strings.Contains(err.Error(), tc.wantErr) {
+				t.Fatalf("expected error containing %q, got: %v", tc.wantErr, err)
+			}
+		})
+	}
+}
+
+// TestValidateToolArgs_RegistryIntegration checks that ToolRegistry.Execute
+// surfaces validation failures (missing required field, wrong type, extra
+// property) as error results for a registered mock tool, and that valid
+// arguments still succeed.
+func TestValidateToolArgs_RegistryIntegration(t *testing.T) {
+	r := NewToolRegistry()
+	r.Register(&mockRegistryTool{
+		name: "read_file",
+		desc: "reads a file",
+		params: map[string]any{
+			"type": "object",
+			"properties": map[string]any{
+				"path": map[string]any{"type": "string"},
+			},
+			"required": []string{"path"},
+		},
+		result: SilentResult("file contents"),
+	})
+
+	// Valid args — should succeed
+	result := r.Execute(context.Background(), "read_file", map[string]any{"path": "/tmp/x"})
+	if result.IsError {
+		t.Errorf("expected success, got error: %s", result.ForLLM)
+	}
+
+	// Missing required field — should fail with validation error
+	result = r.Execute(context.Background(), "read_file", map[string]any{})
+	if !result.IsError {
+		t.Error("expected validation error for missing required field")
+	}
+	if !strings.Contains(result.ForLLM, "missing required p") {
+		t.Errorf("expected 'missing required p...' in error, got %q", result.ForLLM)
+	}
+	if result.Err == nil {
+		t.Error("expected Err to be set via WithError")
+	}
+
+	// Wrong type — should fail with validation error
+	result = r.Execute(context.Background(), "read_file", map[string]any{"path": 123.0})
+	if !result.IsError {
+		t.Error("expected validation error for wrong type")
+	}
+	if !strings.Contains(result.ForLLM, "expected string") {
+		t.Errorf("expected 'expected string' in error, got %q", result.ForLLM)
+	}
+
+	// Extra property — should fail with validation error
+	result = r.Execute(context.Background(), "read_file", map[string]any{"path": "/x", "__inject": true})
+	if !result.IsError {
+		t.Error("expected validation error for extra property")
+	}
+	if !strings.Contains(result.ForLLM, "unexpected prop") {
+		t.Errorf("expected 'unexpected prop...' in error, got %q", result.ForLLM)
+	}
+}
+
+// TestValidateToolArgs_RealSchemas runs validateToolArgs against three
+// hand-written schemas labelled after the exec, cron and web-search tools.
+// NOTE(review): the schemas are declared inline here; confirm they still
+// mirror the parameters those tools actually publish.
+func TestValidateToolArgs_RealSchemas(t *testing.T) {
+	execSchema := map[string]any{
+		"type": "object",
+		"properties": map[string]any{
+			"command":     map[string]any{"type": "string"},
+			"working_dir": map[string]any{"type": "string"},
+		},
+		"required": []string{"command"},
+	}
+
+	cronSchema := map[string]any{
+		"type": "object",
+		"properties": map[string]any{
+			"action": map[string]any{
+				"type": "string",
+				"enum": []any{"add", "list", "remove", "enable", "disable"},
+			},
+		},
+		"required": []string{"action"},
+	}
+
+	webSearchSchema := map[string]any{
+		"type": "object",
+		"properties": map[string]any{
+			"query": map[string]any{"type": "string"},
+			"count": map[string]any{"type": "integer"},
+		},
+		"required": []string{"query"},
+	}
+
+	tests := []struct {
+		name    string
+		schema  map[string]any
+		args    map[string]any
+		wantErr string
+	}{
+		// ExecTool
+		{
+			name:   "exec valid args",
+			schema: execSchema,
+			args:   map[string]any{"command": "ls -la", "working_dir": "/tmp"},
+		},
+		{
+			name:    "exec missing required command",
+			schema:  execSchema,
+			args:    map[string]any{"working_dir": "/tmp"},
+			wantErr: "missing required property \"command\"",
+		},
+		{
+			name:    "exec wrong type for command",
+			schema:  execSchema,
+			args:    map[string]any{"command": float64(123)},
+			wantErr: "expected string",
+		},
+		{
+			name:    "exec extra injected arg",
+			schema:  execSchema,
+			args:    map[string]any{"command": "ls", "malicious": "payload"},
+			wantErr: "unexpected property \"malicious\"",
+		},
+
+		// CronTool
+		{
+			name:   "cron valid enum value",
+			schema: cronSchema,
+			args:   map[string]any{"action": "add"},
+		},
+		{
+			name:    "cron invalid enum value",
+			schema:  cronSchema,
+			args:    map[string]any{"action": "destroy"},
+			wantErr: "not in enum",
+		},
+
+		// WebSearchTool
+		{
+			name:   "websearch valid args",
+			schema: webSearchSchema,
+			args:   map[string]any{"query": "golang testing", "count": float64(10)},
+		},
+		{
+			name:    "websearch missing required query",
+			schema:  webSearchSchema,
+			args:    map[string]any{"count": float64(5)},
+			wantErr: "missing required property \"query\"",
+		},
+		{
+			name:    "websearch wrong type for count",
+			schema:  webSearchSchema,
+			args:    map[string]any{"query": "test", "count": "ten"},
+			wantErr: "expected integer",
+		},
+	}
+
+	for _, tc := range tests {
+		t.Run(tc.name, func(t *testing.T) {
+			err := validateToolArgs(tc.schema, tc.args)
+			// Same convention as above: empty wantErr means success is expected.
+			if tc.wantErr == "" {
+				if err != nil {
+					t.Fatalf("expected no error, got: %v", err)
+				}
+				return
+			}
+			if err == nil {
+				t.Fatalf("expected error containing %q, got nil", tc.wantErr)
+			}
+			if !strings.Contains(err.Error(), tc.wantErr) {
+				t.Fatalf("expected error containing %q, got: %v", tc.wantErr, err)
+			}
+		})
+	}
+}
From fa5ab720226e5c76e3ee553087c196dcb1302b1a Mon Sep 17 00:00:00 2001
From: hsguo
Date: Tue, 24 Mar 2026 18:37:41 +0800
Subject: [PATCH 19/39] WeChat Web QR Code Integration (#1961)

---
 pkg/config/config.go | 7 +
 web/backend/api/channels.go | 1 +
 web/backend/api/config.go | 8 +-
 web/backend/api/router.go | 14 +-
 web/backend/api/weixin.go | 300 ++++++++++++++++++
 web/frontend/src/api/channels.ts | 18 ++
.../channels/channel-config-page.tsx | 18 +- .../channels/channel-forms/weixin-form.tsx | 270 ++++++++++++++++ web/frontend/src/i18n/locales/en.json | 18 +- web/frontend/src/i18n/locales/zh.json | 18 +- 10 files changed, 661 insertions(+), 11 deletions(-) create mode 100644 web/backend/api/weixin.go create mode 100644 web/frontend/src/components/channels/channel-forms/weixin-form.tsx diff --git a/pkg/config/config.go b/pkg/config/config.go index 8073dc723..b281824ce 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -815,6 +815,7 @@ func (c *WeComAIBotConfig) SetSecret(secret string) { type WeixinConfig struct { Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_WEIXIN_ENABLED"` token string + AccountID string `json:"account_id,omitempty" env:"PICOCLAW_CHANNELS_WEIXIN_ACCOUNT_ID"` BaseURL string `json:"base_url" env:"PICOCLAW_CHANNELS_WEIXIN_BASE_URL"` CDNBaseURL string `json:"cdn_base_url" env:"PICOCLAW_CHANNELS_WEIXIN_CDN_BASE_URL"` Proxy string `json:"proxy" env:"PICOCLAW_CHANNELS_WEIXIN_PROXY"` @@ -2019,6 +2020,12 @@ func (c *Config) SecurityCopyFrom(cfg *Config) { } } +// ApplySecurity re-applies the stored security config to populate private fields (tokens, API keys, etc.). +// Call this after SecurityCopyFrom when you need private fields to be accessible for validation or use. 
+func (c *Config) ApplySecurity() error { + return applySecurityConfig(c, c.security) +} + func MergeAPIKeys(apiKey string, apiKeys []string) []string { seen := make(map[string]struct{}) var all []string diff --git a/web/backend/api/channels.go b/web/backend/api/channels.go index 507882823..21624d3ef 100644 --- a/web/backend/api/channels.go +++ b/web/backend/api/channels.go @@ -12,6 +12,7 @@ type channelCatalogItem struct { } var channelCatalog = []channelCatalogItem{ + {Name: "weixin", ConfigKey: "weixin"}, {Name: "telegram", ConfigKey: "telegram"}, {Name: "discord", ConfigKey: "discord"}, {Name: "slack", ConfigKey: "slack"}, diff --git a/web/backend/api/config.go b/web/backend/api/config.go index fa2e91dec..e67e3e6d7 100644 --- a/web/backend/api/config.go +++ b/web/backend/api/config.go @@ -152,9 +152,13 @@ func (h *Handler) handlePatchConfig(w http.ResponseWriter, r *http.Request) { return } - // Copy security credentials before validation so security-managed - // fields (e.g. pico token) are available for validation checks. + // Restore security fields (tokens/keys) from the loaded config before validation, + // because private fields are lost during JSON round-trip. newCfg.SecurityCopyFrom(cfg) + if err := newCfg.ApplySecurity(); err != nil { + http.Error(w, fmt.Sprintf("Failed to apply security config: %v", err), http.StatusInternalServerError) + return + } if errs := validateConfig(&newCfg); len(errs) > 0 { w.Header().Set("Content-Type", "application/json") diff --git a/web/backend/api/router.go b/web/backend/api/router.go index e4df86ed9..d09f68eac 100644 --- a/web/backend/api/router.go +++ b/web/backend/api/router.go @@ -17,15 +17,18 @@ type Handler struct { oauthMu sync.Mutex oauthFlows map[string]*oauthFlow oauthState map[string]string + weixinMu sync.Mutex + weixinFlows map[string]*weixinFlow } // NewHandler creates an instance of the API handler. 
func NewHandler(configPath string) *Handler {
 	return &Handler{
-		configPath: configPath,
-		serverPort: launcherconfig.DefaultPort,
-		oauthFlows: make(map[string]*oauthFlow),
-		oauthState: make(map[string]string),
+		configPath:  configPath,
+		serverPort:  launcherconfig.DefaultPort,
+		oauthFlows:  make(map[string]*oauthFlow),
+		oauthState:  make(map[string]string),
+		weixinFlows: make(map[string]*weixinFlow),
 	}
 }
 
@@ -69,6 +72,9 @@ func (h *Handler) RegisterRoutes(mux *http.ServeMux) {
 
 	// Launcher service parameters (port/public)
 	h.registerLauncherConfigRoutes(mux)
+
+	// WeChat QR login flow
+	h.registerWeixinRoutes(mux)
 }
 
 // Shutdown gracefully shuts down the handler, stopping the gateway if it was started by this handler.
diff --git a/web/backend/api/weixin.go b/web/backend/api/weixin.go
new file mode 100644
index 000000000..e7e94f39e
--- /dev/null
+++ b/web/backend/api/weixin.go
@@ -0,0 +1,300 @@
+package api
+
+import (
+	"context"
+	"crypto/rand"
+	"encoding/base64"
+	"encoding/hex"
+	"encoding/json"
+	"fmt"
+	"net/http"
+	"strings"
+	"time"
+
+	"rsc.io/qr"
+
+	"github.com/sipeed/picoclaw/pkg/channels/weixin"
+	"github.com/sipeed/picoclaw/pkg/config"
+	"github.com/sipeed/picoclaw/pkg/logger"
+)
+
+// Settings for the QR login flow.
+const (
+	weixinFlowTTL   = 5 * time.Minute  // added to the creation time to form a flow's ExpiresAt
+	weixinFlowGCAge = 30 * time.Minute // age threshold used when garbage-collecting finished flows
+	weixinBaseURL   = "https://ilinkai.weixin.qq.com/"
+	weixinBotType   = "3" // passed to GetQRCode when requesting a login QR code
+)
+
+// Flow status values. The handlers compare these against the Status field of
+// the QR status response and also store them on weixinFlow.
+const (
+	weixinStatusWait      = "wait"
+	weixinStatusScanned   = "scaned"
+	weixinStatusConfirmed = "confirmed"
+	weixinStatusExpired   = "expired"
+	weixinStatusError     = "error"
+)
+
+// weixinFlow is the server-side record of one QR login attempt, stored in
+// Handler.weixinFlows under its ID.
+type weixinFlow struct {
+	ID        string
+	Qrcode    string // qrcode token from WeChat API (used for status polling)
+	QRDataURI string // base64 PNG data URI for display
+	AccountID string // IlinkBotID returned on confirmed
+	Status    string // wait / scaned / confirmed / expired / error
+	Error     string
+	CreatedAt time.Time
+	UpdatedAt time.Time
+	ExpiresAt time.Time
+}
+
+// weixinFlowResponse is the JSON body returned by the flow endpoints.
+// Every field except flow_id and status is omitted when empty.
+type weixinFlowResponse struct {
+	FlowID    string `json:"flow_id"`
+	Status    string `json:"status"`
+	QRDataURI string `json:"qr_data_uri,omitempty"`
+	AccountID string `json:"account_id,omitempty"`
+	Error     string `json:"error,omitempty"`
+}
+
+// registerWeixinRoutes binds WeChat QR login endpoints to the ServeMux.
+func (h *Handler) registerWeixinRoutes(mux *http.ServeMux) {
+	mux.HandleFunc("POST /api/weixin/flows", h.handleStartWeixinFlow)
+	mux.HandleFunc("GET /api/weixin/flows/{id}", h.handlePollWeixinFlow)
+}
+
+// handleStartWeixinFlow starts a new WeChat QR login flow.
+//
+// POST /api/weixin/flows
+//
+// It requests a QR code from the WeChat API, renders it as a PNG data URI,
+// stores a new flow in the "wait" state and returns the flow ID, status and
+// data URI as JSON. Any failure along the way is reported as HTTP 500.
+func (h *Handler) handleStartWeixinFlow(w http.ResponseWriter, r *http.Request) {
+	// Bound the upstream QR request; the context is also canceled with the request.
+	ctx, cancel := context.WithTimeout(r.Context(), 15*time.Second)
+	defer cancel()
+
+	api, err := weixin.NewApiClient(weixinBaseURL, "", "")
+	if err != nil {
+		http.Error(w, fmt.Sprintf("failed to create weixin client: %v", err), http.StatusInternalServerError)
+		return
+	}
+
+	qrResp, err := api.GetQRCode(ctx, weixinBotType)
+	if err != nil {
+		http.Error(w, fmt.Sprintf("failed to get QR code: %v", err), http.StatusInternalServerError)
+		return
+	}
+
+	dataURI, err := generateQRDataURI(qrResp.QrcodeImgContent)
+	if err != nil {
+		http.Error(w, fmt.Sprintf("failed to generate QR image: %v", err), http.StatusInternalServerError)
+		return
+	}
+
+	now := time.Now()
+	flow := &weixinFlow{
+		ID:        newWeixinFlowID(),
+		Qrcode:    qrResp.Qrcode,
+		QRDataURI: dataURI,
+		Status:    weixinStatusWait,
+		CreatedAt: now,
+		UpdatedAt: now,
+		ExpiresAt: now.Add(weixinFlowTTL),
+	}
+	h.storeWeixinFlow(flow)
+
+	logger.InfoCF("weixin", "QR flow started", map[string]any{"flow_id": flow.ID})
+
+	w.Header().Set("Content-Type", "application/json")
+	_ = json.NewEncoder(w).Encode(weixinFlowResponse{
+		FlowID:    flow.ID,
+		Status:    flow.Status,
+		QRDataURI: flow.QRDataURI,
+	})
+}
+
+// handlePollWeixinFlow polls the WeChat API for QR code status and updates the flow.
+//
+// GET /api/weixin/flows/{id}
+//
+// Every reply carries the flow ID and status. account_id and error are
+// included whenever they are set on the flow, and the QR data URI is included
+// only while the flow is still in the wait or scaned state. Flows that have
+// already reached confirmed, expired or error are answered from the stored
+// state without contacting WeChat again.
+func (h *Handler) handlePollWeixinFlow(w http.ResponseWriter, r *http.Request) {
+	flowID := strings.TrimSpace(r.PathValue("id"))
+	if flowID == "" {
+		http.Error(w, "missing flow id", http.StatusBadRequest)
+		return
+	}
+
+	flow, ok := h.getWeixinFlow(flowID)
+	if !ok {
+		http.Error(w, "flow not found", http.StatusNotFound)
+		return
+	}
+
+	// respond writes the JSON view of f. Building every reply in one place
+	// keeps the terminal-state reply consistent with the reply sent at the
+	// moment of confirmation, so a later poll still reports account_id.
+	respond := func(f *weixinFlow) {
+		resp := weixinFlowResponse{
+			FlowID:    f.ID,
+			Status:    f.Status,
+			AccountID: f.AccountID,
+			Error:     f.Error,
+		}
+		if f.Status == weixinStatusWait || f.Status == weixinStatusScanned {
+			resp.QRDataURI = f.QRDataURI
+		}
+		w.Header().Set("Content-Type", "application/json")
+		// Best effort: an encode failure here means the client went away.
+		_ = json.NewEncoder(w).Encode(resp)
+	}
+
+	// respondCurrent re-reads the flow so the reply reflects the state that
+	// was just stored, and answers 404 instead of dereferencing a flow that
+	// is no longer present.
+	respondCurrent := func() {
+		current, found := h.getWeixinFlow(flowID)
+		if !found {
+			http.Error(w, "flow not found", http.StatusNotFound)
+			return
+		}
+		respond(current)
+	}
+
+	// Return terminal states directly without polling WeChat again
+	switch flow.Status {
+	case weixinStatusConfirmed, weixinStatusExpired, weixinStatusError:
+		respond(flow)
+		return
+	}
+
+	ctx, cancel := context.WithTimeout(r.Context(), 10*time.Second)
+	defer cancel()
+
+	api, err := weixin.NewApiClient(weixinBaseURL, "", "")
+	if err != nil {
+		h.setWeixinFlowError(flowID, fmt.Sprintf("client error: %v", err))
+		respondCurrent()
+		return
+	}
+
+	statusResp, err := api.GetQRCodeStatus(ctx, flow.Qrcode)
+	if err != nil {
+		// Transient error — keep current status, return it
+		respond(flow)
+		return
+	}
+
+	switch statusResp.Status {
+	case weixinStatusWait:
+		// no change
+
+	case weixinStatusScanned:
+		h.updateWeixinFlowStatus(flowID, weixinStatusScanned)
+
+	case weixinStatusConfirmed:
+		if statusResp.BotToken == "" {
+			h.setWeixinFlowError(flowID, "login confirmed but missing bot_token")
+			break
+		}
+		// Persist the token before marking the flow confirmed so a failed
+		// save is reported to the client as an error, not as success.
+		if saveErr := h.saveWeixinToken(statusResp.BotToken, statusResp.IlinkBotID); saveErr != nil {
+			h.setWeixinFlowError(flowID, fmt.Sprintf("failed to save token: %v", saveErr))
+			logger.ErrorCF("weixin", "failed to save token", map[string]any{"error": saveErr.Error()})
+			break
+		}
+		h.setWeixinFlowConfirmed(flowID, statusResp.IlinkBotID)
+		logger.InfoCF("weixin", "QR login confirmed, token saved", map[string]any{
+			"flow_id":    flowID,
+			"account_id": statusResp.IlinkBotID,
+		})
+
+	case weixinStatusExpired:
+		h.updateWeixinFlowStatus(flowID, weixinStatusExpired)
+
+	default:
+		// unknown status, keep as-is
+	}
+
+	respondCurrent()
+}
+
+// saveWeixinToken writes the token and account ID into the config file.
+// The config is loaded from h.configPath, updated and saved back; an empty
+// accountID leaves the stored account ID untouched.
+func (h *Handler) saveWeixinToken(token, accountID string) error {
+	cfg, err := config.LoadConfig(h.configPath)
+	if err != nil {
+		return fmt.Errorf("load config: %w", err)
+	}
+	cfg.Channels.Weixin.SetToken(token)
+	if accountID != "" {
+		cfg.Channels.Weixin.AccountID = accountID
+	}
+	return config.SaveConfig(h.configPath, cfg)
+}
+
+// generateQRDataURI encodes content as a QR code PNG and returns a data URI.
+func generateQRDataURI(content string) (string, error) { + code, err := qr.Encode(content, qr.L) + if err != nil { + return "", fmt.Errorf("qr encode: %w", err) + } + pngBytes := code.PNG() + encoded := base64.StdEncoding.EncodeToString(pngBytes) + return "data:image/png;base64," + encoded, nil +} + +func newWeixinFlowID() string { + buf := make([]byte, 12) + if _, err := rand.Read(buf); err != nil { + return fmt.Sprintf("wx_%d", time.Now().UnixNano()) + } + return "wx_" + hex.EncodeToString(buf) +} + +func (h *Handler) storeWeixinFlow(flow *weixinFlow) { + h.weixinMu.Lock() + defer h.weixinMu.Unlock() + h.gcWeixinFlowsLocked(time.Now()) + h.weixinFlows[flow.ID] = flow +} + +func (h *Handler) getWeixinFlow(flowID string) (*weixinFlow, bool) { + h.weixinMu.Lock() + defer h.weixinMu.Unlock() + h.gcWeixinFlowsLocked(time.Now()) + flow, ok := h.weixinFlows[flowID] + if !ok { + return nil, false + } + cp := *flow + return &cp, true +} + +func (h *Handler) updateWeixinFlowStatus(flowID, status string) { + h.weixinMu.Lock() + defer h.weixinMu.Unlock() + if flow, ok := h.weixinFlows[flowID]; ok { + flow.Status = status + flow.UpdatedAt = time.Now() + } +} + +func (h *Handler) setWeixinFlowConfirmed(flowID, accountID string) { + h.weixinMu.Lock() + defer h.weixinMu.Unlock() + if flow, ok := h.weixinFlows[flowID]; ok { + flow.Status = weixinStatusConfirmed + flow.AccountID = accountID + flow.UpdatedAt = time.Now() + } +} + +func (h *Handler) setWeixinFlowError(flowID, errMsg string) { + h.weixinMu.Lock() + defer h.weixinMu.Unlock() + if flow, ok := h.weixinFlows[flowID]; ok { + flow.Status = weixinStatusError + flow.Error = errMsg + flow.UpdatedAt = time.Now() + } +} + +func (h *Handler) gcWeixinFlowsLocked(now time.Time) { + for id, flow := range h.weixinFlows { + if flow.Status == weixinStatusWait || flow.Status == weixinStatusScanned { + if !flow.ExpiresAt.IsZero() && now.After(flow.ExpiresAt) { + flow.Status = weixinStatusExpired + flow.UpdatedAt = now + } + } + if 
flow.Status != weixinStatusWait && + flow.Status != weixinStatusScanned && + now.Sub(flow.UpdatedAt) > weixinFlowGCAge { + delete(h.weixinFlows, id) + } + } +} diff --git a/web/frontend/src/api/channels.ts b/web/frontend/src/api/channels.ts index ecd77632c..c3d3a65f3 100644 --- a/web/frontend/src/api/channels.ts +++ b/web/frontend/src/api/channels.ts @@ -62,4 +62,22 @@ export async function patchAppConfig( }) } +// WeChat QR login flow API + +export interface WeixinFlowResponse { + flow_id: string + status: "wait" | "scaned" | "confirmed" | "expired" | "error" + qr_data_uri?: string + account_id?: string + error?: string +} + +export async function startWeixinFlow(): Promise { + return request("/api/weixin/flows", { method: "POST" }) +} + +export async function pollWeixinFlow(flowID: string): Promise { + return request(`/api/weixin/flows/${encodeURIComponent(flowID)}`) +} + export type { ChannelsCatalogResponse, ConfigActionResponse } diff --git a/web/frontend/src/components/channels/channel-config-page.tsx b/web/frontend/src/components/channels/channel-config-page.tsx index b19d11e6a..4996a6314 100644 --- a/web/frontend/src/components/channels/channel-config-page.tsx +++ b/web/frontend/src/components/channels/channel-config-page.tsx @@ -17,6 +17,7 @@ import { FeishuForm } from "@/components/channels/channel-forms/feishu-form" import { GenericForm } from "@/components/channels/channel-forms/generic-form" import { SlackForm } from "@/components/channels/channel-forms/slack-form" import { TelegramForm } from "@/components/channels/channel-forms/telegram-form" +import { WeixinForm } from "@/components/channels/channel-forms/weixin-form" import { PageHeader } from "@/components/page-header" import { Button } from "@/components/ui/button" import { Switch } from "@/components/ui/switch" @@ -142,6 +143,8 @@ function isConfigured( ) case "onebot": return asString(config.ws_url) !== "" + case "weixin": + return asString(config.account_id) !== "" case "wecom": return 
asString(config.token) !== "" case "wecom_app": @@ -251,8 +254,8 @@ export function ChannelConfigPage({ channelName }: ChannelConfigPageProps) { const [editConfig, setEditConfig] = useState({}) const [enabled, setEnabled] = useState(false) - const loadData = useCallback(async () => { - setLoading(true) + const loadData = useCallback(async (silent = false) => { + if (!silent) setLoading(true) try { const [catalog, appConfig] = await Promise.all([ getChannelsCatalog(), @@ -285,7 +288,7 @@ export function ChannelConfigPage({ channelName }: ChannelConfigPageProps) { } catch (e) { setFetchError(e instanceof Error ? e.message : t("channels.loadError")) } finally { - setLoading(false) + if (!silent) setLoading(false) } }, [channelName, t]) @@ -446,6 +449,15 @@ export function ChannelConfigPage({ channelName }: ChannelConfigPageProps) { fieldErrors={fieldErrors} /> ) + case "weixin": + return ( + void loadData(true)} + /> + ) default: return ( void + isEdit: boolean + onBindSuccess?: () => void +} + +function asString(value: unknown): string { + return typeof value === "string" ? 
value : "" +} + +function asStringArray(value: unknown): string[] { + if (!Array.isArray(value)) return [] + return value.filter((item): item is string => typeof item === "string") +} + +export function WeixinForm({ config, onChange, isEdit, onBindSuccess }: WeixinFormProps) { + const { t } = useTranslation() + + const [bindState, setBindState] = useState("idle") + const [qrDataURI, setQrDataURI] = useState(null) + const [accountID, setAccountID] = useState(null) + const [errorMsg, setErrorMsg] = useState("") + + const pollTimerRef = useRef | null>(null) + const isBound = isEdit && asString(config.account_id) !== "" + const existingAccountID = asString(config.account_id) + + const stopPolling = useCallback(() => { + if (pollTimerRef.current !== null) { + clearInterval(pollTimerRef.current) + pollTimerRef.current = null + } + }, []) + + useEffect(() => () => stopPolling(), [stopPolling]) + + const startPolling = useCallback( + (id: string) => { + stopPolling() + pollTimerRef.current = setInterval(async () => { + try { + const resp = await pollWeixinFlow(id) + if (resp.status === "scaned") { + setBindState("scaned") + } else if (resp.status === "confirmed") { + stopPolling() + setAccountID(resp.account_id ?? null) + setBindState("confirmed") + onBindSuccess?.() + } else if (resp.status === "expired") { + stopPolling() + setBindState("expired") + } else if (resp.status === "error") { + stopPolling() + setBindState("error") + setErrorMsg(resp.error ?? t("channels.weixin.errorGeneric")) + } + } catch { + // transient network error — keep polling + } + }, 2000) + }, + [stopPolling, onBindSuccess, t], + ) + + const handleBind = async () => { + setBindState("loading") + setErrorMsg("") + setQrDataURI(null) + stopPolling() + try { + const resp = await startWeixinFlow() + setQrDataURI(resp.qr_data_uri ?? null) + setBindState("waiting") + startPolling(resp.flow_id) + } catch (e) { + setBindState("error") + setErrorMsg(e instanceof Error ? 
e.message : t("channels.weixin.errorGeneric")) + } + } + + const handleRebind = () => { + stopPolling() + setBindState("idle") + setQrDataURI(null) + setAccountID(null) + setErrorMsg("") + void handleBind() + } + + const renderBindSection = () => { + if (bindState === "idle") { + if (isBound) { + return ( +
+
+ + {t("channels.weixin.bound")} +
+ {existingAccountID && ( +

{existingAccountID}

+ )} + +
+ ) + } + return ( +
+

{t("channels.weixin.notBound")}

+ +
+ ) + } + + if (bindState === "loading") { + return ( +
+ +

{t("channels.weixin.generating")}

+
+ ) + } + + if (bindState === "waiting" || bindState === "scaned") { + return ( +
+ {qrDataURI ? ( + WeChat QR Code + ) : ( +
+ +
+ )} + {bindState === "scaned" ? ( +
+ + {t("channels.weixin.scanned")} +
+ ) : ( +

{t("channels.weixin.scanHint")}

+ )} + +
+ ) + } + + if (bindState === "confirmed") { + return ( +
+
+ +
+

+ {t("channels.weixin.bound")} +

+ {accountID && ( +

{accountID}

+ )} + +
+ ) + } + + if (bindState === "expired") { + return ( +
+
+ +
+

{t("channels.weixin.expired")}

+ +
+ ) + } + + if (bindState === "error") { + return ( +
+
+ +
+

{errorMsg || t("channels.weixin.errorGeneric")}

+ +
+ ) + } + + return null + } + + return ( +
+ {/* QR Bind Section */} +
+
+

{t("channels.weixin.bindTitle")}

+

{t("channels.weixin.bindDesc")}

+
+ {renderBindSection()} +
+ + {/* allow_from */} + + + onChange( + "allow_from", + e.target.value + .split(",") + .map((s: string) => s.trim()) + .filter(Boolean), + ) + } + placeholder={t("channels.field.allowFromPlaceholder")} + /> + + + {/* proxy */} + + onChange("proxy", e.target.value)} + placeholder="http://localhost:7890" + /> + +
+ ) +} diff --git a/web/frontend/src/i18n/locales/en.json b/web/frontend/src/i18n/locales/en.json index 66e39ad0e..0b0afa39d 100644 --- a/web/frontend/src/i18n/locales/en.json +++ b/web/frontend/src/i18n/locales/en.json @@ -240,7 +240,23 @@ "pico": "Web", "maixcam": "MaixCam", "matrix": "Matrix", - "irc": "IRC" + "irc": "IRC", + "weixin": "WeChat" + }, + "weixin": { + "bindTitle": "WeChat Account Binding", + "bindDesc": "Scan the QR code with WeChat to bind your personal account.", + "bind": "Bind WeChat", + "rebind": "Re-bind", + "bound": "WeChat Bound", + "notBound": "WeChat account not bound yet.", + "generating": "Generating QR code...", + "scanHint": "Open WeChat and scan the QR code", + "scanned": "Scanned — please confirm in WeChat", + "expired": "QR code expired", + "retry": "Try Again", + "refresh": "Refresh QR", + "errorGeneric": "An error occurred. Please try again." }, "field": { "token": "Bot Token", diff --git a/web/frontend/src/i18n/locales/zh.json b/web/frontend/src/i18n/locales/zh.json index 65f2a5548..e85e4dd44 100644 --- a/web/frontend/src/i18n/locales/zh.json +++ b/web/frontend/src/i18n/locales/zh.json @@ -240,7 +240,23 @@ "pico": "Web", "maixcam": "MaixCam", "matrix": "Matrix", - "irc": "IRC" + "irc": "IRC", + "weixin": "微信" + }, + "weixin": { + "bindTitle": "微信账号绑定", + "bindDesc": "使用微信扫描二维码以绑定您的个人微信账号。", + "bind": "绑定微信", + "rebind": "重新绑定", + "bound": "微信已绑定", + "notBound": "尚未绑定微信账号。", + "generating": "正在生成二维码...", + "scanHint": "打开微信,扫描二维码", + "scanned": "已扫码 — 请在微信中确认", + "expired": "二维码已过期", + "retry": "重试", + "refresh": "刷新二维码", + "errorGeneric": "发生错误,请重试。" }, "field": { "token": "Bot Token", From f2f6987f00c57950b7cf2f1a2298f154e176051f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=BE=8E=E9=9B=BB=E7=90=83?= Date: Tue, 24 Mar 2026 19:27:29 +0800 Subject: [PATCH 20/39] test(agent): allow mock custom tool args (#1965) --- pkg/agent/loop_test.go | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/pkg/agent/loop_test.go 
b/pkg/agent/loop_test.go index 976d25c4b..1a4a44edf 100644 --- a/pkg/agent/loop_test.go +++ b/pkg/agent/loop_test.go @@ -636,8 +636,9 @@ func (m *mockCustomTool) Description() string { func (m *mockCustomTool) Parameters() map[string]any { return map[string]any{ - "type": "object", - "properties": map[string]any{}, + "type": "object", + "properties": map[string]any{}, + "additionalProperties": true, } } From 8b6cbd99090908e2ccbd56e18ca06cf9a9283ee5 Mon Sep 17 00:00:00 2001 From: lxowalle <83055338+lxowalle@users.noreply.github.com> Date: Tue, 24 Mar 2026 20:02:58 +0800 Subject: [PATCH 21/39] Fix: Prevent security.yml from being overwritten during config migration (#1966) --- pkg/config/config.go | 12 ++ pkg/config/migration_integration_test.go | 115 +++++++++++++++++++ pkg/config/security.go | 136 +++++++++++++++++++++++ 3 files changed, 263 insertions(+) diff --git a/pkg/config/config.go b/pkg/config/config.go index b281824ce..84e1ab61a 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -1395,6 +1395,18 @@ func LoadConfig(path string) (*Config, error) { if err != nil { return nil, err } + // Load existing security config and merge with migrated one to prevent data loss + existingSec, secErr := loadSecurityConfig(securityPath(path)) + if secErr != nil { + logger.WarnF("failed to load existing security config during migration", map[string]any{"error": secErr}) + } + if existingSec != nil && cfg.security != nil { + cfg.security = mergeSecurityConfig(existingSec, cfg.security) + // Re-apply the merged security config to update all channels and models + if err = applySecurityConfig(cfg, cfg.security); err != nil { + logger.WarnF("failed to re-apply merged security config during migration", map[string]any{"error": err}) + } + } defer func(cfg *Config) { _ = SaveConfig(path, cfg) }(cfg) diff --git a/pkg/config/migration_integration_test.go b/pkg/config/migration_integration_test.go index c884a6b5d..49d2a5831 100644 --- 
a/pkg/config/migration_integration_test.go +++ b/pkg/config/migration_integration_test.go @@ -566,3 +566,118 @@ func TestMigration_Integration_ModelNameField(t *testing.T) { t.Errorf("ModelFallbacks[0] = %q, want %q", cfg.Agents.Defaults.ModelFallbacks[0], "deepseek-chat") } } + +// TestMigration_PreservesExistingSecurityConfig tests that when migrating from v0 to v1, +// existing .security.yml values (e.g., loaded from environment variables) are preserved +// and not overwritten by empty values from the legacy config. +func TestMigration_PreservesExistingSecurityConfig(t *testing.T) { + tmpDir := t.TempDir() + configPath := filepath.Join(tmpDir, "config.json") + securityPath := filepath.Join(tmpDir, ".security.yml") + + // Create a legacy config (version 0) with model_list and channel config + // The model_list doesn't have api_keys, they should come from existing .security.yml + legacyConfig := `{ + "agents": { + "defaults": { + "provider": "openai", + "model": "gpt-4" + } + }, + "model_list": [ + { + "model_name": "openai", + "model": "openai/gpt-4" + } + ], + "channels": { + "telegram": { + "enabled": true + } + }, + "gateway": { + "host": "127.0.0.1", + "port": 18790 + }, + "tools": { + "web": {"enabled": true} + }, + "heartbeat": { + "enabled": true, + "interval": 30 + }, + "devices": { + "enabled": false + } + }` + + // Create an existing .security.yml with values that might come from env vars + existingSecurity := `model_list: + openai:0: + api_keys: + - sk-existing-key-from-env +channels: + telegram: + token: existing-telegram-token-from-env + discord: + token: existing-discord-token-from-env +web: + brave: + api_keys: + - existing-brave-key +` + + if err := os.WriteFile(configPath, []byte(legacyConfig), 0o600); err != nil { + t.Fatalf("Failed to write legacy config: %v", err) + } + + if err := os.WriteFile(securityPath, []byte(existingSecurity), 0o600); err != nil { + t.Fatalf("Failed to write existing security config: %v", err) + } + + // Load the config 
- this should trigger migration + cfg, err := LoadConfig(configPath) + if err != nil { + t.Fatalf("LoadConfig failed: %v", err) + } + + // Verify that the migrated config has the existing security values + // Telegram token should be preserved + if cfg.Channels.Telegram.Token() != "existing-telegram-token-from-env" { + t.Errorf("Telegram token was overwritten: got %q, want %q", + cfg.Channels.Telegram.Token(), "existing-telegram-token-from-env") + } + + // Discord token should be preserved (even though legacy config didn't have it) + if cfg.Channels.Discord.Token() != "existing-discord-token-from-env" { + t.Errorf("Discord token was overwritten: got %q, want %q", + cfg.Channels.Discord.Token(), "existing-discord-token-from-env") + } + + // Model API key should be preserved + if cfg.ModelList[0].APIKey() != "sk-existing-key-from-env" { + t.Errorf("Model API key was overwritten: got %q, want %q", + cfg.ModelList[0].APIKey(), "sk-existing-key-from-env") + } + + // Brave API key should be preserved + if cfg.Tools.Web.Brave.APIKey() != "existing-brave-key" { + t.Errorf("Brave API key was overwritten: got %q, want %q", + cfg.Tools.Web.Brave.APIKey(), "existing-brave-key") + } + + // Reload the security config from disk to verify it wasn't corrupted + reloadedSec, err := loadSecurityConfig(securityPath) + if err != nil { + t.Fatalf("Failed to reload security config: %v", err) + } + + if reloadedSec.Channels.Telegram == nil || + reloadedSec.Channels.Telegram.Token != "existing-telegram-token-from-env" { + t.Error("Telegram token not preserved in .security.yml file") + } + + if reloadedSec.Channels.Discord == nil || reloadedSec.Channels.Discord.Token != "existing-discord-token-from-env" { + t.Error("Discord token not preserved in .security.yml file") + } +} diff --git a/pkg/config/security.go b/pkg/config/security.go index 5c71bf8c3..da989ca88 100644 --- a/pkg/config/security.go +++ b/pkg/config/security.go @@ -244,6 +244,142 @@ func saveSecurityConfig(securityPath string, 
sec *SecurityConfig) error { return fileutil.WriteFileAtomic(securityPath, buf.Bytes(), 0o600) } +// mergeSecurityConfig merges two SecurityConfig instances, preferring non-empty values from 'newer'. +// This is used during config migration to preserve existing security data while adding new entries. +func mergeSecurityConfig(existing, newer *SecurityConfig) *SecurityConfig { + if existing == nil { + return normalizeSecurityConfig(newer) + } + if newer == nil { + return normalizeSecurityConfig(existing) + } + + result := normalizeSecurityConfig(nil) + + // Merge ModelList: prefer newer if it has keys, otherwise use existing + for k, v := range existing.ModelList { + result.ModelList[k] = v + } + for k, v := range newer.ModelList { + if len(v.APIKeys) > 0 { + result.ModelList[k] = v + } + } + + // Merge Channels + if existing.Channels != nil { + result.Channels = existing.Channels + } + if newer.Channels != nil { + if result.Channels == nil { + result.Channels = &ChannelsSecurity{} + } + mergeChannelsSecurity(result.Channels, newer.Channels) + } + + // Merge Web + if existing.Web != nil { + result.Web = existing.Web + } + if newer.Web != nil { + if result.Web == nil { + result.Web = &WebToolsSecurity{} + } + mergeWebToolsSecurity(result.Web, newer.Web) + } + + // Merge Skills + if existing.Skills != nil { + result.Skills = existing.Skills + } + if newer.Skills != nil { + if result.Skills == nil { + result.Skills = &SkillsSecurity{} + } + mergeSkillsSecurity(result.Skills, newer.Skills) + } + + return result +} + +func mergeChannelsSecurity(dst, src *ChannelsSecurity) { + if src.Telegram != nil && src.Telegram.Token != "" { + dst.Telegram = src.Telegram + } + if src.Feishu != nil && + (src.Feishu.AppSecret != "" || src.Feishu.EncryptKey != "" || src.Feishu.VerificationToken != "") { + dst.Feishu = src.Feishu + } + if src.Discord != nil && src.Discord.Token != "" { + dst.Discord = src.Discord + } + if src.Weixin != nil && src.Weixin.Token != "" { + dst.Weixin = 
src.Weixin + } + if src.QQ != nil && src.QQ.AppSecret != "" { + dst.QQ = src.QQ + } + if src.DingTalk != nil && src.DingTalk.ClientSecret != "" { + dst.DingTalk = src.DingTalk + } + if src.Slack != nil && (src.Slack.BotToken != "" || src.Slack.AppToken != "") { + dst.Slack = src.Slack + } + if src.Matrix != nil && src.Matrix.AccessToken != "" { + dst.Matrix = src.Matrix + } + if src.LINE != nil && (src.LINE.ChannelSecret != "" || src.LINE.ChannelAccessToken != "") { + dst.LINE = src.LINE + } + if src.OneBot != nil && src.OneBot.AccessToken != "" { + dst.OneBot = src.OneBot + } + if src.WeCom != nil && (src.WeCom.Token != "" || src.WeCom.EncodingAESKey != "") { + dst.WeCom = src.WeCom + } + if src.WeComApp != nil && + (src.WeComApp.CorpSecret != "" || src.WeComApp.Token != "" || src.WeComApp.EncodingAESKey != "") { + dst.WeComApp = src.WeComApp + } + if src.WeComAIBot != nil && + (src.WeComAIBot.Secret != "" || src.WeComAIBot.Token != "" || src.WeComAIBot.EncodingAESKey != "") { + dst.WeComAIBot = src.WeComAIBot + } + if src.Pico != nil && src.Pico.Token != "" { + dst.Pico = src.Pico + } + if src.IRC != nil && (src.IRC.Password != "" || src.IRC.NickServPassword != "" || src.IRC.SASLPassword != "") { + dst.IRC = src.IRC + } +} + +func mergeWebToolsSecurity(dst, src *WebToolsSecurity) { + if src.Brave != nil && len(src.Brave.APIKeys) > 0 { + dst.Brave = src.Brave + } + if src.Tavily != nil && len(src.Tavily.APIKeys) > 0 { + dst.Tavily = src.Tavily + } + if src.Perplexity != nil && len(src.Perplexity.APIKeys) > 0 { + dst.Perplexity = src.Perplexity + } + if src.GLMSearch != nil && src.GLMSearch.APIKey != "" { + dst.GLMSearch = src.GLMSearch + } + if src.BaiduSearch != nil && src.BaiduSearch.APIKey != "" { + dst.BaiduSearch = src.BaiduSearch + } +} + +func mergeSkillsSecurity(dst, src *SkillsSecurity) { + if src.Github != nil && src.Github.Token != "" { + dst.Github = src.Github + } + if src.ClawHub != nil && src.ClawHub.AuthToken != "" { + dst.ClawHub = src.ClawHub + } 
+} + // SensitiveDataCache caches the compiled regex for filtering sensitive data. // SensitiveDataCache caches the strings.Replacer for filtering sensitive data. // Computed once on first access via sync.Once. From a1f95f02bce27788da2a6f2e2da1a8c8a5250f46 Mon Sep 17 00:00:00 2001 From: Hoshina Date: Tue, 24 Mar 2026 15:03:41 +0800 Subject: [PATCH 22/39] refactor(wecom): rebuild ai bot channel --- pkg/agent/loop_test.go | 24 +- pkg/channels/manager.go | 11 +- pkg/channels/manager_channel.go | 21 +- pkg/channels/wecom/aibot.go | 1099 -------------- pkg/channels/wecom/aibot_test.go | 559 ------- pkg/channels/wecom/aibot_ws.go | 1347 ----------------- pkg/channels/wecom/aibot_ws_test.go | 295 ---- pkg/channels/wecom/app.go | 756 --------- pkg/channels/wecom/app_test.go | 1060 ------------- pkg/channels/wecom/bot.go | 499 ------ pkg/channels/wecom/bot_test.go | 734 --------- pkg/channels/wecom/common.go | 199 --- pkg/channels/wecom/dedupe.go | 54 - pkg/channels/wecom/dedupe_test.go | 83 - pkg/channels/wecom/init.go | 8 +- pkg/channels/wecom/media.go | 291 ++++ pkg/channels/wecom/media_test.go | 180 +++ pkg/channels/wecom/protocol.go | 122 ++ pkg/channels/wecom/reqid_store.go | 113 ++ pkg/channels/wecom/reqid_store_test.go | 24 + pkg/channels/wecom/wecom.go | 777 ++++++++++ pkg/channels/wecom/wecom_test.go | 167 ++ pkg/config/config.go | 203 +-- pkg/config/config_old.go | 216 +-- pkg/config/config_test.go | 3 +- pkg/config/defaults.go | 31 +- pkg/config/security.go | 43 +- pkg/config/security_integration_test.go | 42 +- pkg/migrate/sources/openclaw/common.go | 25 +- web/backend/api/channels.go | 2 - web/backend/api/config.go | 9 + .../channels/channel-config-page.tsx | 14 +- .../channels/channel-forms/generic-form.tsx | 10 + .../src/hooks/use-sidebar-channels.ts | 4 - web/frontend/src/i18n/locales/en.json | 2 - web/frontend/src/i18n/locales/zh.json | 2 - 36 files changed, 1833 insertions(+), 7196 deletions(-) delete mode 100644 pkg/channels/wecom/aibot.go delete mode 
100644 pkg/channels/wecom/aibot_test.go delete mode 100644 pkg/channels/wecom/aibot_ws.go delete mode 100644 pkg/channels/wecom/aibot_ws_test.go delete mode 100644 pkg/channels/wecom/app.go delete mode 100644 pkg/channels/wecom/app_test.go delete mode 100644 pkg/channels/wecom/bot.go delete mode 100644 pkg/channels/wecom/bot_test.go delete mode 100644 pkg/channels/wecom/common.go delete mode 100644 pkg/channels/wecom/dedupe.go delete mode 100644 pkg/channels/wecom/dedupe_test.go create mode 100644 pkg/channels/wecom/media.go create mode 100644 pkg/channels/wecom/media_test.go create mode 100644 pkg/channels/wecom/protocol.go create mode 100644 pkg/channels/wecom/reqid_store.go create mode 100644 pkg/channels/wecom/reqid_store_test.go create mode 100644 pkg/channels/wecom/wecom.go create mode 100644 pkg/channels/wecom/wecom_test.go diff --git a/pkg/agent/loop_test.go b/pkg/agent/loop_test.go index 1a4a44edf..ee3a3c8bd 100644 --- a/pkg/agent/loop_test.go +++ b/pkg/agent/loop_test.go @@ -1495,18 +1495,17 @@ func TestTargetReasoningChannelID_AllChannels(t *testing.T) { t.Fatalf("Failed to create channel manager: %v", err) } for name, id := range map[string]string{ - "whatsapp": "rid-whatsapp", - "telegram": "rid-telegram", - "feishu": "rid-feishu", - "discord": "rid-discord", - "maixcam": "rid-maixcam", - "qq": "rid-qq", - "dingtalk": "rid-dingtalk", - "slack": "rid-slack", - "line": "rid-line", - "onebot": "rid-onebot", - "wecom": "rid-wecom", - "wecom_app": "rid-wecom-app", + "whatsapp": "rid-whatsapp", + "telegram": "rid-telegram", + "feishu": "rid-feishu", + "discord": "rid-discord", + "maixcam": "rid-maixcam", + "qq": "rid-qq", + "dingtalk": "rid-dingtalk", + "slack": "rid-slack", + "line": "rid-line", + "onebot": "rid-onebot", + "wecom": "rid-wecom", } { chManager.RegisterChannel(name, &fakeChannel{id: id}) } @@ -1526,7 +1525,6 @@ func TestTargetReasoningChannelID_AllChannels(t *testing.T) { {channel: "line", wantID: "rid-line"}, {channel: "onebot", wantID: 
"rid-onebot"}, {channel: "wecom", wantID: "rid-wecom"}, - {channel: "wecom_app", wantID: "rid-wecom-app"}, {channel: "unknown", wantID: ""}, } diff --git a/pkg/channels/manager.go b/pkg/channels/manager.go index f04d989a3..5cc15b4d2 100644 --- a/pkg/channels/manager.go +++ b/pkg/channels/manager.go @@ -371,19 +371,10 @@ func (m *Manager) initChannels(channels *config.ChannelsConfig) error { m.initChannel("onebot", "OneBot") } - if channels.WeCom.Enabled && channels.WeCom.Token() != "" { + if channels.WeCom.Enabled && channels.WeCom.BotID != "" && channels.WeCom.Secret() != "" { m.initChannel("wecom", "WeCom") } - if channels.WeComAIBot.Enabled && (channels.WeComAIBot.Token() != "" || - (channels.WeComAIBot.Secret() != "" && channels.WeComAIBot.BotID != "")) { - m.initChannel("wecom_aibot", "WeCom AI Bot") - } - - if channels.WeComApp.Enabled && channels.WeComApp.CorpID != "" { - m.initChannel("wecom_app", "WeCom App") - } - if channels.Weixin.Enabled && channels.Weixin.Token() != "" { m.initChannel("weixin", "Weixin") } diff --git a/pkg/channels/manager_channel.go b/pkg/channels/manager_channel.go index 86572e336..163218b75 100644 --- a/pkg/channels/manager_channel.go +++ b/pkg/channels/manager_channel.go @@ -49,15 +49,7 @@ func hiddenValues(key string, value map[string]any, ch config.ChannelsConfig) { value["token"] = ch.LINE.ChannelAccessToken() value["secret"] = ch.LINE.ChannelSecret() case "wecom": - value["token"] = ch.WeCom.Token() - value["key"] = ch.WeCom.EncodingAESKey() - case "wecom_app": - value["token"] = ch.WeComApp.Token() - value["secret"] = ch.WeComApp.CorpSecret() - case "wecom_aibot": - value["token"] = ch.WeComAIBot.Token() - value["key"] = ch.WeComAIBot.EncodingAESKey() - value["secret"] = ch.WeComAIBot.Secret() + value["secret"] = ch.WeCom.Secret() case "dingtalk": value["secret"] = ch.QQ.AppSecret() case "qq": @@ -156,16 +148,7 @@ func updateKeys(newcfg, old *config.ChannelsConfig) { newcfg.LINE.SetChannelSecret(old.LINE.ChannelSecret()) } if 
newcfg.WeCom.Enabled { - newcfg.WeCom.SetToken(old.WeCom.Token()) - newcfg.WeCom.SetEncodingAESKey(old.WeCom.EncodingAESKey()) - } - if newcfg.WeComApp.Enabled { - newcfg.WeComApp.SetToken(old.WeComApp.Token()) - newcfg.WeComApp.SetCorpSecret(old.WeComApp.CorpSecret()) - } - if newcfg.WeComAIBot.Enabled { - newcfg.WeComAIBot.SetToken(old.WeComAIBot.Token()) - newcfg.WeComAIBot.SetEncodingAESKey(old.WeComAIBot.EncodingAESKey()) + newcfg.WeCom.SetSecret(old.WeCom.Secret()) } if newcfg.DingTalk.Enabled { newcfg.DingTalk.SetClientSecret(old.DingTalk.ClientSecret()) diff --git a/pkg/channels/wecom/aibot.go b/pkg/channels/wecom/aibot.go deleted file mode 100644 index c5e148185..000000000 --- a/pkg/channels/wecom/aibot.go +++ /dev/null @@ -1,1099 +0,0 @@ -package wecom - -import ( - "bytes" - "context" - "crypto/rand" - "encoding/base64" - "encoding/json" - "fmt" - "io" - "math/big" - "net/http" - "strings" - "sync" - "time" - - "github.com/sipeed/picoclaw/pkg/bus" - "github.com/sipeed/picoclaw/pkg/channels" - "github.com/sipeed/picoclaw/pkg/config" - "github.com/sipeed/picoclaw/pkg/identity" - "github.com/sipeed/picoclaw/pkg/logger" - "github.com/sipeed/picoclaw/pkg/utils" -) - -// responseURLHTTPClient is a shared HTTP client for posting to WeCom response_url. -// Reusing it enables connection pooling across replies. -var responseURLHTTPClient = &http.Client{Timeout: 15 * time.Second} - -// WeComAIBotChannel implements the Channel interface for WeCom AI Bot (企业微信智能机器人) -type WeComAIBotChannel struct { - *channels.BaseChannel - config config.WeComAIBotConfig - ctx context.Context - cancel context.CancelFunc - streamTasks map[string]*streamTask // streamID -> task (for poll lookups) - chatTasks map[string][]*streamTask // chatID -> in-flight tasks queue (FIFO) - taskMu sync.RWMutex -} - -// streamTask represents a streaming task for AI Bot. -// -// Mutable fields (Finished, StreamClosed, StreamClosedAt) must be read/written -// while holding WeComAIBotChannel.taskMu. 
Immutable fields (StreamID, ChatID, -// ResponseURL, Question, CreatedTime, Deadline, answerCh, ctx, cancel) are set -// once at creation and never modified, so they are safe to read without a lock. -type streamTask struct { - // immutable after creation - StreamID string - ChatID string // used by Send() to find this task - ResponseURL string // temporary URL for proactive reply (valid 1 hour, use once) - Question string - CreatedTime time.Time - Deadline time.Time // ~30s, we close the stream here and switch to response_url - answerCh chan string // receives agent reply from Send() - ctx context.Context // canceled when task is removed; used to interrupt the agent goroutine - cancel context.CancelFunc // call on task removal to cancel ctx - - // mutable — guarded by WeComAIBotChannel.taskMu - StreamClosed bool // stream returned finish:true; waiting for agent to reply via response_url - StreamClosedAt time.Time // set when StreamClosed becomes true; used for accelerated cleanup - Finished bool // fully done -} - -// WeComAIBotMessage represents the decrypted JSON message from WeCom AI Bot -// Ref: https://developer.work.weixin.qq.com/document/path/100719 -type WeComAIBotMessage struct { - MsgID string `json:"msgid"` - AIBotID string `json:"aibotid"` - ChatID string `json:"chatid"` // only for group chat - ChatType string `json:"chattype"` // "single" or "group" - From struct { - UserID string `json:"userid"` - } `json:"from"` - ResponseURL string `json:"response_url"` // temporary URL for proactive reply - MsgType string `json:"msgtype"` - // text message - Text *struct { - Content string `json:"content"` - } `json:"text,omitempty"` - // stream polling refresh - Stream *struct { - ID string `json:"id"` - } `json:"stream,omitempty"` - // image message - Image *struct { - URL string `json:"url"` - } `json:"image,omitempty"` - // mixed message (text + image) - Mixed *struct { - MsgItem []struct { - MsgType string `json:"msgtype"` - Text *struct { - Content string 
`json:"content"` - } `json:"text,omitempty"` - Image *struct { - URL string `json:"url"` - } `json:"image,omitempty"` - } `json:"msg_item"` - } `json:"mixed,omitempty"` - // event field - Event *struct { - EventType string `json:"eventtype"` - } `json:"event,omitempty"` -} - -// WeComAIBotMsgItemImage holds the image payload inside a stream message item. -type WeComAIBotMsgItemImage struct { - Base64 string `json:"base64"` - MD5 string `json:"md5"` -} - -// WeComAIBotMsgItem is a single item inside a stream's msg_item list. -type WeComAIBotMsgItem struct { - MsgType string `json:"msgtype"` - Image *WeComAIBotMsgItemImage `json:"image,omitempty"` -} - -// WeComAIBotStreamInfo represents the detailed stream content in streaming responses. -type WeComAIBotStreamInfo struct { - ID string `json:"id"` - Finish bool `json:"finish"` - Content string `json:"content,omitempty"` - MsgItem []WeComAIBotMsgItem `json:"msg_item,omitempty"` -} - -// WeComAIBotStreamResponse represents the streaming response format -type WeComAIBotStreamResponse struct { - MsgType string `json:"msgtype"` - Stream WeComAIBotStreamInfo `json:"stream"` -} - -// WeComAIBotEncryptedResponse represents the encrypted response wrapper -// Fields match WXBizJsonMsgCrypt.generate() in Python SDK -type WeComAIBotEncryptedResponse struct { - Encrypt string `json:"encrypt"` - MsgSignature string `json:"msgsignature"` - Timestamp string `json:"timestamp"` - Nonce string `json:"nonce"` -} - -// NewWeComAIBotChannel creates a WeCom AI Bot channel instance. -// If cfg.BotID and cfg.secret are both set, it returns a WeComAIBotWSChannel -// using the WebSocket long-connection API. -// Otherwise it returns the webhook-mode WeComAIBotChannel (requires Token + -// EncodingAESKey). -func NewWeComAIBotChannel( - cfg config.WeComAIBotConfig, - messageBus *bus.MessageBus, -) (channels.Channel, error) { - // WebSocket long-connection mode takes priority when BotID + secret are set. 
- if cfg.BotID != "" && cfg.Secret() != "" { - logger.InfoC("wecom_aibot", "BotID and secret provided, using WebSocket mode") - return newWeComAIBotWSChannel(cfg, messageBus) - } - // Webhook (short-connection) mode. - if cfg.Token() == "" || cfg.EncodingAESKey() == "" { - return nil, fmt.Errorf( - "WeCom AI Bot requires either (bot_id + secret) for WebSocket mode " + - "or (token + encoding_aes_key) for webhook mode") - } - if cfg.ProcessingMessage == "" { - cfg.ProcessingMessage = config.DefaultWeComAIBotProcessingMessage - } - - base := channels.NewBaseChannel("wecom_aibot", cfg, messageBus, cfg.AllowFrom, - channels.WithMaxMessageLength(2048), - channels.WithReasoningChannelID(cfg.ReasoningChannelID), - ) - - return &WeComAIBotChannel{ - BaseChannel: base, - config: cfg, - streamTasks: make(map[string]*streamTask), - chatTasks: make(map[string][]*streamTask), - }, nil -} - -// Name returns the channel name -func (c *WeComAIBotChannel) Name() string { - return "wecom_aibot" -} - -// Start initializes the WeCom AI Bot channel -func (c *WeComAIBotChannel) Start(ctx context.Context) error { - logger.InfoC("wecom_aibot", "Starting WeCom AI Bot channel...") - - c.ctx, c.cancel = context.WithCancel(ctx) - - // Start cleanup goroutine for old tasks - go c.cleanupLoop() - - c.SetRunning(true) - logger.InfoC("wecom_aibot", "WeCom AI Bot channel started") - - return nil -} - -// Stop gracefully stops the WeCom AI Bot channel -func (c *WeComAIBotChannel) Stop(ctx context.Context) error { - logger.InfoC("wecom_aibot", "Stopping WeCom AI Bot channel...") - - if c.cancel != nil { - c.cancel() - } - - c.SetRunning(false) - logger.InfoC("wecom_aibot", "WeCom AI Bot channel stopped") - return nil -} - -// Send delivers the agent reply into the active streamTask for msg.ChatID. -// It writes into the earliest unfinished task in the queue (FIFO per chatID). -// If the stream has already closed (deadline passed), it posts directly to response_url. 
-func (c *WeComAIBotChannel) Send(ctx context.Context, msg bus.OutboundMessage) error { - if !c.IsRunning() { - return channels.ErrNotRunning - } - c.taskMu.Lock() - queue := c.chatTasks[msg.ChatID] - // Only compact Finished tasks at the head of the queue. - // Tasks that are Finished in the middle are NOT removed here: doing a full - // scan on every Send() call would be O(n) and is unnecessary given that - // removeTask() always splices the task out of the queue immediately. - // Any Finished task left stranded in the middle (e.g. due to an unexpected - // code path) will be collected by cleanupOldTasks. - for len(queue) > 0 && queue[0].Finished { - queue = queue[1:] - } - c.chatTasks[msg.ChatID] = queue - var task *streamTask - var streamClosed bool - var responseURL string - if len(queue) > 0 { - task = queue[0] - // Read mutable fields while holding c.taskMu to avoid data races. - streamClosed = task.StreamClosed - responseURL = task.ResponseURL - } - c.taskMu.Unlock() - - if task == nil { - logger.DebugCF( - "wecom_aibot", - "Send: no active task for chat (may have timed out)", - map[string]any{ - "chat_id": msg.ChatID, - }, - ) - return nil - } - - if streamClosed { - // Stream already ended with a "please wait" notice; send the real reply via response_url. - // Note: task.StreamID and task.ChatID are immutable, safe to read without a lock. 
- logger.InfoCF("wecom_aibot", "Sending reply via response_url", map[string]any{ - "stream_id": task.StreamID, - "chat_id": msg.ChatID, - }) - if responseURL != "" { - if err := c.sendViaResponseURL(responseURL, msg.Content); err != nil { - logger.ErrorCF("wecom_aibot", "Failed to send via response_url", map[string]any{ - "error": err, - "stream_id": task.StreamID, - }) - c.removeTask(task) - return fmt.Errorf("response_url delivery failed: %w", channels.ErrSendFailed) - } - } else { - logger.WarnCF("wecom_aibot", "Stream closed but no response_url available", map[string]any{ - "stream_id": task.StreamID, - }) - } - c.removeTask(task) - return nil - } - - // Stream still open: deliver via answerCh for the next poll response. - select { - case task.answerCh <- msg.Content: - case <-task.ctx.Done(): - // Task was canceled (cleanup removed it); silently drop the reply. - return nil - case <-ctx.Done(): - return ctx.Err() - } - return nil -} - -// WebhookPath returns the path for registering on the shared HTTP server -func (c *WeComAIBotChannel) WebhookPath() string { - if c.config.WebhookPath == "" { - return "/webhook/wecom-aibot" - } - return c.config.WebhookPath -} - -// ServeHTTP implements http.Handler for the shared HTTP server -func (c *WeComAIBotChannel) ServeHTTP(w http.ResponseWriter, r *http.Request) { - c.handleWebhook(w, r) -} - -// HealthPath returns the health check endpoint path -func (c *WeComAIBotChannel) HealthPath() string { - return c.WebhookPath() + "/health" -} - -// HealthHandler handles health check requests -func (c *WeComAIBotChannel) HealthHandler(w http.ResponseWriter, r *http.Request) { - c.handleHealth(w, r) -} - -// handleWebhook handles incoming webhook requests from WeCom AI Bot -func (c *WeComAIBotChannel) handleWebhook(w http.ResponseWriter, r *http.Request) { - ctx := r.Context() - - // Log all incoming requests for debugging - logger.DebugCF("wecom_aibot", "Received webhook request", map[string]any{ - "method": r.Method, - "path": 
r.URL.Path, - "query": r.URL.RawQuery, - }) - - switch r.Method { - case http.MethodGet: - // URL verification - c.handleVerification(ctx, w, r) - case http.MethodPost: - // Message callback - c.handleMessageCallback(ctx, w, r) - default: - http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) - } -} - -// handleVerification handles the URL verification request from WeCom -func (c *WeComAIBotChannel) handleVerification( - ctx context.Context, - w http.ResponseWriter, - r *http.Request, -) { - msgSignature := r.URL.Query().Get("msg_signature") - timestamp := r.URL.Query().Get("timestamp") - nonce := r.URL.Query().Get("nonce") - echostr := r.URL.Query().Get("echostr") - - logger.DebugCF("wecom_aibot", "URL verification request", map[string]any{ - "msg_signature": msgSignature, - "timestamp": timestamp, - "nonce": nonce, - }) - - // Verify signature - if !verifySignature(c.config.Token(), msgSignature, timestamp, nonce, echostr) { - logger.ErrorC("wecom_aibot", "Signature verification failed") - http.Error(w, "Signature verification failed", http.StatusUnauthorized) - return - } - - // Decrypt echostr - // For WeCom AI Bot (智能机器人), receiveid should be empty string - decrypted, err := decryptMessageWithVerify(echostr, c.config.EncodingAESKey(), "") - if err != nil { - logger.ErrorCF("wecom_aibot", "Failed to decrypt echostr", map[string]any{ - "error": err, - }) - http.Error(w, "Decryption failed", http.StatusInternalServerError) - return - } - - // Remove BOM and whitespace as per WeCom documentation - decrypted = strings.TrimPrefix(decrypted, "\ufeff") - decrypted = strings.TrimSpace(decrypted) - - logger.InfoC("wecom_aibot", "URL verification successful") - w.Header().Set("Content-Type", "text/plain; charset=utf-8") - w.WriteHeader(http.StatusOK) - w.Write([]byte(decrypted)) -} - -// handleMessageCallback handles incoming messages from WeCom AI Bot -func (c *WeComAIBotChannel) handleMessageCallback( - ctx context.Context, - w http.ResponseWriter, - r 
*http.Request, -) { - msgSignature := r.URL.Query().Get("msg_signature") - timestamp := r.URL.Query().Get("timestamp") - nonce := r.URL.Query().Get("nonce") - - // Read request body (limit to 4 MB to prevent memory exhaustion) - const maxBodySize = 4 << 20 // 4 MB - body, err := io.ReadAll(io.LimitReader(r.Body, maxBodySize+1)) - if err != nil { - logger.ErrorCF("wecom_aibot", "Failed to read request body", map[string]any{ - "error": err, - }) - http.Error(w, "Failed to read body", http.StatusBadRequest) - return - } - if len(body) > maxBodySize { - http.Error(w, "Request body too large", http.StatusRequestEntityTooLarge) - return - } - - // Parse JSON body to get encrypted message - // Format: {"encrypt": "base64_encrypted_string"} - var encryptedMsg struct { - Encrypt string `json:"encrypt"` - } - if unmarshalErr := json.Unmarshal(body, &encryptedMsg); unmarshalErr != nil { - logger.ErrorCF("wecom_aibot", "Failed to parse JSON body", map[string]any{ - "error": unmarshalErr, - "body": string(body), - }) - http.Error(w, "Failed to parse JSON", http.StatusBadRequest) - return - } - - // Verify signature - if !verifySignature(c.config.Token(), msgSignature, timestamp, nonce, encryptedMsg.Encrypt) { - logger.ErrorC("wecom_aibot", "Signature verification failed") - http.Error(w, "Signature verification failed", http.StatusUnauthorized) - return - } - - // Decrypt message - // For WeCom AI Bot (智能机器人), receiveid is empty string - decrypted, err := decryptMessageWithVerify(encryptedMsg.Encrypt, c.config.EncodingAESKey(), "") - if err != nil { - logger.ErrorCF("wecom_aibot", "Failed to decrypt message", map[string]any{ - "error": err, - }) - http.Error(w, "Decryption failed", http.StatusInternalServerError) - return - } - - // Parse decrypted JSON message - var msg WeComAIBotMessage - if unmarshalErr := json.Unmarshal([]byte(decrypted), &msg); unmarshalErr != nil { - logger.ErrorCF("wecom_aibot", "Failed to parse decrypted JSON", map[string]any{ - "error": unmarshalErr, - 
"decrypted": decrypted, - }) - http.Error(w, "Failed to parse message", http.StatusInternalServerError) - return - } - - logger.DebugCF("wecom_aibot", "Decrypted message", map[string]any{ - "msgtype": msg.MsgType, - }) - - // Process the message and get streaming response - response := c.processMessage(ctx, msg, timestamp, nonce) - - // Check if response is empty (e.g. due to unsupported message type) - if response == "" { - response = c.encryptEmptyResponse(timestamp, nonce) - } - - // Return encrypted JSON response - w.Header().Set("Content-Type", "application/json; charset=utf-8") - w.WriteHeader(http.StatusOK) - w.Write([]byte(response)) -} - -// processMessage processes the received message and returns encrypted response -func (c *WeComAIBotChannel) processMessage( - ctx context.Context, - msg WeComAIBotMessage, - timestamp, nonce string, -) string { - logger.DebugCF("wecom_aibot", "Processing message", map[string]any{ - "msgtype": msg.MsgType, - }) - - switch msg.MsgType { - case "text": - return c.handleTextMessage(ctx, msg, timestamp, nonce) - case "stream": - return c.handleStreamMessage(ctx, msg, timestamp, nonce) - case "image": - return c.handleImageMessage(ctx, msg, timestamp, nonce) - case "mixed": - return c.handleMixedMessage(ctx, msg, timestamp, nonce) - case "event": - return c.handleEventMessage(ctx, msg, timestamp, nonce) - default: - logger.WarnCF("wecom_aibot", "Unsupported message type", map[string]any{ - "msgtype": msg.MsgType, - }) - return c.encryptResponse("", timestamp, nonce, WeComAIBotStreamResponse{ - MsgType: "stream", - Stream: WeComAIBotStreamInfo{ - ID: c.generateStreamID(), - Finish: true, - Content: "Unsupported message type: " + msg.MsgType, - }, - }) - } -} - -// handleTextMessage handles text messages by starting a new streaming task -func (c *WeComAIBotChannel) handleTextMessage( - ctx context.Context, - msg WeComAIBotMessage, - timestamp, nonce string, -) string { - if msg.Text == nil { - logger.ErrorC("wecom_aibot", "text 
message missing text field") - return c.encryptEmptyResponse(timestamp, nonce) - } - - content := msg.Text.Content - userID := msg.From.UserID - if userID == "" { - userID = "unknown" - } - - // chatID: group chat uses chatid, single chat uses userid - chatID := msg.ChatID - if chatID == "" { - chatID = userID - } - - streamID := c.generateStreamID() - - // WeCom stops sending stream-refresh callbacks after 6 minutes. - // Set a slightly shorter deadline so we can send a timeout notice before it gives up. - deadline := time.Now().Add(30 * time.Second) - - // Each task gets its own context derived from the channel lifetime context. - // Canceling taskCancel interrupts the agent goroutine when the task is removed. - taskCtx, taskCancel := context.WithCancel(c.ctx) - - task := &streamTask{ - StreamID: streamID, - ChatID: chatID, - ResponseURL: msg.ResponseURL, - Question: content, - CreatedTime: time.Now(), - Deadline: deadline, - Finished: false, - answerCh: make(chan string, 1), - ctx: taskCtx, - cancel: taskCancel, - } - - c.taskMu.Lock() - c.streamTasks[streamID] = task - c.chatTasks[chatID] = append(c.chatTasks[chatID], task) - c.taskMu.Unlock() - - // Publish to agent asynchronously; agent will call Send() with reply. - // Use task.ctx (not c.ctx) so the agent goroutine is canceled when the task is removed. 
- go func() { - sender := bus.SenderInfo{ - Platform: "wecom_aibot", - PlatformID: userID, - CanonicalID: identity.BuildCanonicalID("wecom_aibot", userID), - DisplayName: userID, - } - peerKind := "direct" - if msg.ChatType == "group" { - peerKind = "group" - } - peer := bus.Peer{Kind: peerKind, ID: chatID} - metadata := map[string]string{ - "channel": "wecom_aibot", - "chat_type": msg.ChatType, - "msg_type": "text", - "msgid": msg.MsgID, - "aibotid": msg.AIBotID, - "stream_id": streamID, - "response_url": msg.ResponseURL, - } - c.HandleMessage(task.ctx, peer, msg.MsgID, userID, chatID, - content, nil, metadata, sender) - }() - - // Return first streaming response immediately (finish=false, content empty) - return c.getStreamResponse(task, timestamp, nonce) -} - -// handleStreamMessage handles stream polling requests -func (c *WeComAIBotChannel) handleStreamMessage( - ctx context.Context, - msg WeComAIBotMessage, - timestamp, nonce string, -) string { - if msg.Stream == nil { - logger.ErrorC("wecom_aibot", "Stream message missing stream field") - return c.encryptEmptyResponse(timestamp, nonce) - } - - streamID := msg.Stream.ID - - c.taskMu.RLock() - task, exists := c.streamTasks[streamID] - c.taskMu.RUnlock() - - if !exists { - logger.DebugCF( - "wecom_aibot", - "Stream task not found (may be from previous session)", - map[string]any{ - "stream_id": streamID, - }, - ) - return c.encryptResponse(streamID, timestamp, nonce, WeComAIBotStreamResponse{ - MsgType: "stream", - Stream: WeComAIBotStreamInfo{ - ID: streamID, - Finish: true, - Content: "Task not found or already finished. 
Please resend your message to start a new session.", - }, - }) - } - - // Get next response - return c.getStreamResponse(task, timestamp, nonce) -} - -// handleImageMessage handles image messages -func (c *WeComAIBotChannel) handleImageMessage( - ctx context.Context, - msg WeComAIBotMessage, - timestamp, nonce string, -) string { - logger.WarnC("wecom_aibot", "Image message type not yet fully implemented") - if msg.Image == nil { - logger.ErrorC("wecom_aibot", "Image message missing image field") - return c.encryptEmptyResponse(timestamp, nonce) - } - - imageURL := msg.Image.URL - - // For now, just acknowledge receipt without echoing the image - return c.encryptResponse("", timestamp, nonce, WeComAIBotStreamResponse{ - MsgType: "stream", - Stream: WeComAIBotStreamInfo{ - ID: c.generateStreamID(), - Finish: true, - Content: fmt.Sprintf( - "Image received (URL: %s), but image messages are not yet supported", - imageURL, - ), - }, - }) -} - -// handleMixedMessage handles mixed (text + image) messages -func (c *WeComAIBotChannel) handleMixedMessage( - ctx context.Context, - msg WeComAIBotMessage, - timestamp, nonce string, -) string { - logger.WarnC("wecom_aibot", "Mixed message type not yet fully implemented") - return c.encryptResponse("", timestamp, nonce, WeComAIBotStreamResponse{ - MsgType: "stream", - Stream: WeComAIBotStreamInfo{ - ID: c.generateStreamID(), - Finish: true, - Content: "Mixed message type is not yet supported", - }, - }) -} - -// handleEventMessage handles event messages -func (c *WeComAIBotChannel) handleEventMessage( - ctx context.Context, - msg WeComAIBotMessage, - timestamp, nonce string, -) string { - eventType := "" - if msg.Event != nil { - eventType = msg.Event.EventType - } - logger.DebugCF("wecom_aibot", "Received event", map[string]any{ - "event_type": eventType, - }) - - // Send welcome message when user opens the chat window - if eventType == "enter_chat" && c.config.WelcomeMessage != "" { - streamID := c.generateStreamID() - return 
c.encryptResponse(streamID, timestamp, nonce, WeComAIBotStreamResponse{ - MsgType: "stream", - Stream: WeComAIBotStreamInfo{ - ID: streamID, - Finish: true, - Content: c.config.WelcomeMessage, - }, - }) - } - - return c.encryptEmptyResponse(timestamp, nonce) -} - -// getStreamResponse gets the next streaming response for a task. -// - If agent replied: return finish=true with the real answer. -// - If deadline passed: return finish=true with a "please wait" notice, keep task alive for response_url. -// - Otherwise: return finish=false (empty), client will poll again. -func (c *WeComAIBotChannel) getStreamResponse(task *streamTask, timestamp, nonce string) string { - var content string - var finish bool - var closeStreamOnly bool // close stream but do NOT remove task (response_url still pending) - - select { - case answer := <-task.answerCh: - // Agent replied before deadline — normal finish. - content = answer - finish = true - default: - if time.Now().After(task.Deadline) { - // Deadline reached: close the stream with a notice, then wait for agent via response_url. - content = c.config.ProcessingMessage - finish = true - closeStreamOnly = true - logger.InfoCF( - "wecom_aibot", - "Stream deadline reached, switching to response_url mode", - map[string]any{ - "stream_id": task.StreamID, - "chat_id": task.ChatID, - "response_url": task.ResponseURL != "", - }, - ) - } - // else: still waiting, return finish=false - } - - if finish && !closeStreamOnly { - // Normal finish: remove from all maps. - c.removeTask(task) - } else if closeStreamOnly { - // Mark stream as closed and remove from streamTasks under a single lock - // to keep StreamClosed/StreamClosedAt consistent with map membership. 
- c.taskMu.Lock() - task.StreamClosed = true - task.StreamClosedAt = time.Now() - delete(c.streamTasks, task.StreamID) - c.taskMu.Unlock() - } - - response := WeComAIBotStreamResponse{ - MsgType: "stream", - Stream: WeComAIBotStreamInfo{ - ID: task.StreamID, - Finish: finish, - Content: content, - }, - } - - return c.encryptResponse(task.StreamID, timestamp, nonce, response) -} - -// removeTask removes a task from both streamTasks and chatTasks, marks it finished, -// and cancels its context to interrupt the associated agent goroutine. -func (c *WeComAIBotChannel) removeTask(task *streamTask) { - // Cancel first so the agent goroutine stops as soon as possible, - // before we acquire the write lock. - task.cancel() - - c.taskMu.Lock() - task.Finished = true // written under c.taskMu, consistent with all readers - delete(c.streamTasks, task.StreamID) - queue := c.chatTasks[task.ChatID] - for i, t := range queue { - if t == task { - c.chatTasks[task.ChatID] = append(queue[:i], queue[i+1:]...) - break - } - } - if len(c.chatTasks[task.ChatID]) == 0 { - delete(c.chatTasks, task.ChatID) - } - c.taskMu.Unlock() -} - -// sendViaResponseURL posts a markdown reply to the WeCom response_url. -// response_url is valid for 1 hour and can only be used once per callback. -// Returned errors are wrapped with channels.ErrRateLimit, channels.ErrTemporary, -// or channels.ErrSendFailed so the manager can apply the right retry policy. 
-func (c *WeComAIBotChannel) sendViaResponseURL(responseURL, content string) error { - payload := map[string]any{ - "msgtype": "markdown", - "markdown": map[string]string{ - "content": content, - }, - } - body, err := json.Marshal(payload) - if err != nil { - return fmt.Errorf("failed to marshal payload: %w", err) - } - - ctx, cancel := context.WithTimeout(c.ctx, 15*time.Second) - defer cancel() - - req, err := http.NewRequestWithContext(ctx, http.MethodPost, responseURL, bytes.NewBuffer(body)) - if err != nil { - return fmt.Errorf("failed to create request: %w", err) - } - req.Header.Set("Content-Type", "application/json; charset=utf-8") - - resp, err := responseURLHTTPClient.Do(req) - if err != nil { - return fmt.Errorf("post to response_url failed: %w: %w", channels.ErrTemporary, err) - } - defer resp.Body.Close() - - if resp.StatusCode == http.StatusOK { - return nil - } - - const maxErrBody = 64 << 10 // 64 KB is more than enough for any error response - respBody, err := io.ReadAll(io.LimitReader(resp.Body, maxErrBody)) - if err != nil { - return fmt.Errorf("reading response_url body: %w: %w", channels.ErrTemporary, err) - } - switch { - case resp.StatusCode == http.StatusTooManyRequests: - return fmt.Errorf("response_url rate limited (%d): %s: %w", - resp.StatusCode, respBody, channels.ErrRateLimit) - case resp.StatusCode >= 500: - return fmt.Errorf("response_url server error (%d): %s: %w", - resp.StatusCode, respBody, channels.ErrTemporary) - default: - return fmt.Errorf("response_url returned %d: %s: %w", - resp.StatusCode, respBody, channels.ErrSendFailed) - } -} - -// encryptResponse encrypts a streaming response -func (c *WeComAIBotChannel) encryptResponse( - streamID, timestamp, nonce string, - response WeComAIBotStreamResponse, -) string { - // Marshal response to JSON - plaintext, err := json.Marshal(response) - if err != nil { - logger.ErrorCF("wecom_aibot", "Failed to marshal response", map[string]any{ - "error": err, - }) - return "" - } - - 
logger.DebugCF("wecom_aibot", "Encrypting response", map[string]any{ - "stream_id": streamID, - "finish": response.Stream.Finish, - "preview": utils.Truncate(response.Stream.Content, 100), - }) - - // Encrypt message - encrypted, err := c.encryptMessage(string(plaintext), "") - if err != nil { - logger.ErrorCF("wecom_aibot", "Failed to encrypt message", map[string]any{ - "error": err, - }) - return "" - } - - // Generate signature - signature := computeSignature(c.config.Token(), timestamp, nonce, encrypted) - - // Build encrypted response - encryptedResp := WeComAIBotEncryptedResponse{ - Encrypt: encrypted, - MsgSignature: signature, - Timestamp: timestamp, - Nonce: nonce, - } - - respJSON, err := json.Marshal(encryptedResp) - if err != nil { - logger.ErrorCF("wecom_aibot", "Failed to marshal encrypted response", map[string]any{ - "error": err, - }) - return "" - } - - logger.DebugCF("wecom_aibot", "Response encrypted", map[string]any{ - "stream_id": streamID, - }) - - return string(respJSON) -} - -// encryptEmptyResponse returns a minimal valid encrypted response -func (c *WeComAIBotChannel) encryptEmptyResponse(timestamp, nonce string) string { - // Construct a zero-value stream response and encrypt it so that - // WeCom always receives a syntactically valid encrypted JSON object. 
- emptyResp := WeComAIBotStreamResponse{} - return c.encryptResponse("", timestamp, nonce, emptyResp) -} - -// encryptMessage encrypts a plain text message for WeCom AI Bot -func (c *WeComAIBotChannel) encryptMessage(plaintext, receiveid string) (string, error) { - aesKey, err := decodeWeComAESKey(c.config.EncodingAESKey()) - if err != nil { - return "", err - } - - frame, err := packWeComFrame(plaintext, receiveid) - if err != nil { - return "", err - } - - // PKCS7 padding then AES-CBC encrypt - paddedFrame := pkcs7Pad(frame, blockSize) - ciphertext, err := encryptAESCBC(aesKey, paddedFrame) - if err != nil { - return "", err - } - - return base64.StdEncoding.EncodeToString(ciphertext), nil -} - -// func (c *WeComAIBotChannel) downloadAndDecryptImage( -// ctx context.Context, -// imageURL string, -// ) ([]byte, error) { -// // Download image -// req, err := http.NewRequestWithContext(ctx, http.MethodGet, imageURL, nil) -// if err != nil { -// return nil, fmt.Errorf("failed to create request: %w", err) -// } - -// client := &http.Client{ -// Timeout: 15 * time.Second, -// } - -// resp, err := client.Do(req) -// if err != nil { -// return nil, fmt.Errorf("failed to download image: %w", err) -// } -// defer resp.Body.Close() - -// if resp.StatusCode != http.StatusOK { -// return nil, fmt.Errorf("download failed with status: %d", resp.StatusCode) -// } - -// // Limit image download to 20 MB to prevent memory exhaustion -// const maxImageSize = 20 << 20 // 20 MB -// encryptedData, err := io.ReadAll(io.LimitReader(resp.Body, maxImageSize+1)) -// if err != nil { -// return nil, fmt.Errorf("failed to read image data: %w", err) -// } -// if len(encryptedData) > maxImageSize { -// return nil, fmt.Errorf("image too large (exceeds %d MB)", maxImageSize>>20) -// } - -// logger.DebugCF("wecom_aibot", "Image downloaded", map[string]any{ -// "size": len(encryptedData), -// }) - -// // Decode AES key -// aesKey, err := decodeWeComAESKey(c.config.EncodingAESKey) -// if err != nil 
{ -// return nil, err -// } - -// // Decrypt image (AES-CBC with IV = first 16 bytes of key, PKCS7 padding stripped) -// decryptedData, err := decryptAESCBC(aesKey, encryptedData) -// if err != nil { -// return nil, fmt.Errorf("failed to decrypt image: %w", err) -// } - -// logger.DebugCF("wecom_aibot", "Image decrypted", map[string]any{ -// "size": len(decryptedData), -// }) - -// return decryptedData, nil -// } - -// generateRandomID generates a cryptographically random alphanumeric ID of -// length n. Used for stream IDs and WebSocket request IDs. -func generateRandomID(n int) string { - const letters = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789" - b := make([]byte, n) - for i := range b { - num, _ := rand.Int(rand.Reader, big.NewInt(int64(len(letters)))) - b[i] = letters[num.Int64()] - } - return string(b) -} - -// generateStreamID generates a random 10-character stream ID (webhook mode). -func (c *WeComAIBotChannel) generateStreamID() string { - return generateRandomID(10) -} - -// cleanupLoop periodically cleans up old streaming tasks -func (c *WeComAIBotChannel) cleanupLoop() { - ticker := time.NewTicker(5 * time.Minute) - defer ticker.Stop() - - for { - select { - case <-ticker.C: - c.cleanupOldTasks() - case <-c.ctx.Done(): - return - } - } -} - -// cleanupOldTasks removes tasks that have exceeded their expected lifetime: -// - Active tasks (in streamTasks): cleaned up after 1 hour (response_url validity window). -// - StreamClosed tasks (in chatTasks only): cleaned up after streamClosedGracePeriod. -// These tasks are waiting for the agent to call Send() via response_url. If the agent -// crashes or times out without calling Send(), we must not let them accumulate indefinitely. -// The grace period is generous enough to cover typical LLM latency but far shorter than 1 hour, -// preventing chatTasks from filling up when many requests time out in quick succession. 
-const ( - streamClosedGracePeriod = 10 * time.Minute // max wait for agent after stream closes - taskMaxLifetime = 1 * time.Hour // absolute max (≈ response_url validity) -) - -func (c *WeComAIBotChannel) cleanupOldTasks() { - c.taskMu.Lock() - defer c.taskMu.Unlock() - - now := time.Now() - cutoff := now.Add(-taskMaxLifetime) - for id, task := range c.streamTasks { - if task.CreatedTime.Before(cutoff) { - delete(c.streamTasks, id) - task.cancel() // interrupt agent goroutine still waiting for LLM - queue := c.chatTasks[task.ChatID] - for i, t := range queue { - if t == task { - c.chatTasks[task.ChatID] = append(queue[:i], queue[i+1:]...) - break - } - } - if len(c.chatTasks[task.ChatID]) == 0 { - delete(c.chatTasks, task.ChatID) - } - logger.DebugCF("wecom_aibot", "Cleaned up expired task", map[string]any{ - "stream_id": id, - }) - } - } - // Clean up StreamClosed tasks from chatTasks. - // Two expiry conditions are checked: - // 1. Absolute expiry: task was created more than taskMaxLifetime ago. - // 2. Grace expiry: stream closed more than streamClosedGracePeriod ago - // (agent had enough time to reply; it is not coming back). - for chatID, queue := range c.chatTasks { - filtered := queue[:0] - for i, t := range queue { - absoluteExpired := t.CreatedTime.Before(cutoff) - graceExpired := t.StreamClosed && - !t.StreamClosedAt.IsZero() && - t.StreamClosedAt.Before(now.Add(-streamClosedGracePeriod)) - if t.Finished { - // Finished tasks should have been removed by removeTask(). - // Finding one here (especially not at position 0) means an - // unexpected code path left it stranded, causing the queue to - // grow silently. Log a warning so it is visible, then drop it. 
- if i > 0 { - logger.WarnCF("wecom_aibot", - "Found stranded Finished task in the middle of chatTasks queue; "+ - "this should not happen — removeTask() should have spliced it out", - map[string]any{ - "chat_id": chatID, - "stream_id": t.StreamID, - "position": i, - }) - } - // The task is already finished; its context was already canceled - // by removeTask(), so no further action is required. - continue - } else if !absoluteExpired && !graceExpired { - filtered = append(filtered, t) - } else { - t.cancel() // cancel any lingering agent goroutine - } - } - if len(filtered) == 0 { - delete(c.chatTasks, chatID) - } else { - c.chatTasks[chatID] = filtered - } - } -} - -// handleHealth handles health check requests -func (c *WeComAIBotChannel) handleHealth(w http.ResponseWriter, r *http.Request) { - status := "ok" - if !c.IsRunning() { - status = "not running" - } - - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(http.StatusOK) - json.NewEncoder(w).Encode(map[string]string{ - "status": status, - }) -} diff --git a/pkg/channels/wecom/aibot_test.go b/pkg/channels/wecom/aibot_test.go deleted file mode 100644 index 11c4393d6..000000000 --- a/pkg/channels/wecom/aibot_test.go +++ /dev/null @@ -1,559 +0,0 @@ -package wecom - -import ( - "context" - "encoding/json" - "testing" - "time" - - "github.com/sipeed/picoclaw/pkg/bus" - "github.com/sipeed/picoclaw/pkg/channels" - "github.com/sipeed/picoclaw/pkg/config" -) - -// ---- Webhook mode tests ---- - -func TestNewWeComAIBotChannel_WebhookMode(t *testing.T) { - t.Run("success with valid config", func(t *testing.T) { - cfg := config.WeComAIBotConfig{} - cfg.Enabled = true - cfg.SetToken("test_token") - cfg.SetEncodingAESKey("testkey1234567890123456789012345678901234567") - cfg.WebhookPath = "/webhook/test" - - messageBus := bus.NewMessageBus() - ch, err := NewWeComAIBotChannel(cfg, messageBus) - if err != nil { - t.Fatalf("Expected no error, got %v", err) - } - if ch == nil { - t.Fatal("Expected channel to 
be created") - } - if ch.Name() != "wecom_aibot" { - t.Errorf("Expected name 'wecom_aibot', got '%s'", ch.Name()) - } - // Webhook mode must implement WebhookHandler. - if _, ok := ch.(channels.WebhookHandler); !ok { - t.Error("Webhook mode channel should implement WebhookHandler") - } - }) - - t.Run("error with missing token", func(t *testing.T) { - cfg := config.WeComAIBotConfig{} - cfg.Enabled = true - cfg.SetEncodingAESKey("testkey1234567890123456789012345678901234567") - - messageBus := bus.NewMessageBus() - _, err := NewWeComAIBotChannel(cfg, messageBus) - if err == nil { - t.Fatal("Expected error for missing token, got nil") - } - }) - - t.Run("error with missing encoding key", func(t *testing.T) { - cfg := config.WeComAIBotConfig{} - cfg.Enabled = true - cfg.SetToken("test_token") - - messageBus := bus.NewMessageBus() - _, err := NewWeComAIBotChannel(cfg, messageBus) - if err == nil { - t.Fatal("Expected error for missing encoding key, got nil") - } - }) -} - -func TestWeComAIBotWebhookChannelStartStop(t *testing.T) { - cfg := config.WeComAIBotConfig{ - Enabled: true, - } - cfg.SetToken("test_token") - cfg.SetEncodingAESKey("testkey1234567890123456789012345678901234567") - - messageBus := bus.NewMessageBus() - ch, err := NewWeComAIBotChannel(cfg, messageBus) - if err != nil { - t.Fatalf("Failed to create channel: %v", err) - } - - ctx := context.Background() - - if err := ch.Start(ctx); err != nil { - t.Fatalf("Failed to start channel: %v", err) - } - if !ch.IsRunning() { - t.Error("Expected channel to be running after Start") - } - - if err := ch.Stop(ctx); err != nil { - t.Fatalf("Failed to stop channel: %v", err) - } - if ch.IsRunning() { - t.Error("Expected channel to be stopped after Stop") - } -} - -func TestWeComAIBotChannelWebhookPath(t *testing.T) { - t.Run("default path", func(t *testing.T) { - cfg := config.WeComAIBotConfig{} - cfg.Enabled = true - cfg.SetToken("test_token") - cfg.SetEncodingAESKey("testkey1234567890123456789012345678901234567") 
- - messageBus := bus.NewMessageBus() - ch, _ := NewWeComAIBotChannel(cfg, messageBus) - - wh, ok := ch.(channels.WebhookHandler) - if !ok { - t.Fatal("Expected channel to implement WebhookHandler") - } - expectedPath := "/webhook/wecom-aibot" - if wh.WebhookPath() != expectedPath { - t.Errorf("Expected webhook path '%s', got '%s'", expectedPath, wh.WebhookPath()) - } - }) - - t.Run("custom path", func(t *testing.T) { - customPath := "/custom/webhook" - cfg := config.WeComAIBotConfig{} - cfg.Enabled = true - cfg.SetToken("test_token") - cfg.SetEncodingAESKey("testkey1234567890123456789012345678901234567") - cfg.WebhookPath = customPath - - messageBus := bus.NewMessageBus() - ch, _ := NewWeComAIBotChannel(cfg, messageBus) - - wh, ok := ch.(channels.WebhookHandler) - if !ok { - t.Fatal("Expected channel to implement WebhookHandler") - } - if wh.WebhookPath() != customPath { - t.Errorf("Expected webhook path '%s', got '%s'", customPath, wh.WebhookPath()) - } - }) -} - -func TestWeComAIBotChannelGetStreamResponseProcessingMessage(t *testing.T) { - validAESKey := "abcdefghijklmnopqrstuvwxyz0123456789ABCDEFG" - - t.Run("uses default processing message", func(t *testing.T) { - cfg := config.WeComAIBotConfig{ - Enabled: true, - } - cfg.SetToken("test_token") - cfg.SetEncodingAESKey(validAESKey) - - messageBus := bus.NewMessageBus() - channel, err := NewWeComAIBotChannel(cfg, messageBus) - if err != nil { - t.Fatalf("Failed to create channel: %v", err) - } - ch, ok := channel.(*WeComAIBotChannel) - if !ok { - t.Fatal("Expected webhook mode channel") - } - - task := &streamTask{ - StreamID: "stream-default", - ChatID: "chat-default", - Deadline: time.Now().Add(-time.Second), - } - ch.streamTasks[task.StreamID] = task - ch.chatTasks[task.ChatID] = []*streamTask{task} - - resp := decodeStreamResponse(t, ch, ch.getStreamResponse(task, "1234567890", "nonce")) - - if !resp.Stream.Finish { - t.Fatal("Expected finished stream response after deadline") - } - if resp.Stream.Content 
!= config.DefaultWeComAIBotProcessingMessage { - t.Fatalf("Expected default processing message %q, got %q", - config.DefaultWeComAIBotProcessingMessage, resp.Stream.Content) - } - if !task.StreamClosed { - t.Fatal("Expected task stream to be marked closed") - } - if _, ok := ch.streamTasks[task.StreamID]; ok { - t.Fatal("Expected closed stream task to be removed from streamTasks") - } - if len(ch.chatTasks[task.ChatID]) != 1 { - t.Fatalf("Expected task to remain queued for response_url delivery, got %d entries", - len(ch.chatTasks[task.ChatID])) - } - }) - - t.Run("uses custom processing message", func(t *testing.T) { - cfg := config.WeComAIBotConfig{ - Enabled: true, - ProcessingMessage: "Please wait a moment. The result will be delivered in a follow-up message.", - } - cfg.SetToken("test_token") - cfg.SetEncodingAESKey(validAESKey) - - messageBus := bus.NewMessageBus() - channel, err := NewWeComAIBotChannel(cfg, messageBus) - if err != nil { - t.Fatalf("Failed to create channel: %v", err) - } - ch, ok := channel.(*WeComAIBotChannel) - if !ok { - t.Fatal("Expected webhook mode channel") - } - - task := &streamTask{ - StreamID: "stream-custom", - ChatID: "chat-custom", - Deadline: time.Now().Add(-time.Second), - } - - resp := decodeStreamResponse(t, ch, ch.getStreamResponse(task, "1234567890", "nonce")) - - if resp.Stream.Content != cfg.ProcessingMessage { - t.Fatalf("Expected custom processing message %q, got %q", cfg.ProcessingMessage, resp.Stream.Content) - } - }) -} - -func TestGenerateStreamID(t *testing.T) { - cfg := config.WeComAIBotConfig{} - cfg.Enabled = true - cfg.SetToken("test_token") - cfg.SetEncodingAESKey("testkey1234567890123456789012345678901234567") - - messageBus := bus.NewMessageBus() - ch, _ := NewWeComAIBotChannel(cfg, messageBus) - webhookCh, ok := ch.(*WeComAIBotChannel) - if !ok { - t.Fatal("Expected webhook mode channel") - } - - ids := make(map[string]bool) - for i := 0; i < 100; i++ { - id := webhookCh.generateStreamID() - if len(id) != 
10 { - t.Errorf("Expected stream ID length 10, got %d", len(id)) - } - if ids[id] { - t.Errorf("Duplicate stream ID generated: %s", id) - } - ids[id] = true - } -} - -func TestEncryptDecrypt(t *testing.T) { - // Use a valid 43-character base64 key (企业微信标准格式) - cfg := config.WeComAIBotConfig{} - cfg.Enabled = true - cfg.SetToken("test_token") - cfg.SetEncodingAESKey("abcdefghijklmnopqrstuvwxyz0123456789ABCDEFG") // 43 characters - - messageBus := bus.NewMessageBus() - ch, _ := NewWeComAIBotChannel(cfg, messageBus) - webhookCh, ok := ch.(*WeComAIBotChannel) - if !ok { - t.Fatal("Expected webhook mode channel") - } - - plaintext := "Hello, World!" - receiveid := "" - - encrypted, err := webhookCh.encryptMessage(plaintext, receiveid) - if err != nil { - t.Fatalf("Failed to encrypt message: %v", err) - } - if encrypted == "" { - t.Fatal("Encrypted message is empty") - } - - // Decrypt - decrypted, err := decryptMessageWithVerify(encrypted, cfg.EncodingAESKey(), receiveid) - if err != nil { - t.Fatalf("Failed to decrypt message: %v", err) - } - if decrypted != plaintext { - t.Errorf("Expected decrypted message '%s', got '%s'", plaintext, decrypted) - } -} - -func TestGenerateSignature(t *testing.T) { - token := "test_token" - timestamp := "1234567890" - nonce := "test_nonce" - encrypt := "encrypted_msg" - - signature := computeSignature(token, timestamp, nonce, encrypt) - if signature == "" { - t.Error("Generated signature is empty") - } - if !verifySignature(token, signature, timestamp, nonce, encrypt) { - t.Error("Generated signature does not verify correctly") - } -} - -func decodeStreamResponse(t *testing.T, ch *WeComAIBotChannel, encryptedResponse string) WeComAIBotStreamResponse { - t.Helper() - - var wrapped WeComAIBotEncryptedResponse - if err := json.Unmarshal([]byte(encryptedResponse), &wrapped); err != nil { - t.Fatalf("Failed to unmarshal encrypted response: %v", err) - } - - plaintext, err := decryptMessageWithVerify(wrapped.Encrypt, 
ch.config.EncodingAESKey(), "") - if err != nil { - t.Fatalf("Failed to decrypt response: %v", err) - } - - var resp WeComAIBotStreamResponse - if err := json.Unmarshal([]byte(plaintext), &resp); err != nil { - t.Fatalf("Failed to unmarshal decrypted response: %v", err) - } - - return resp -} - -// ---- WebSocket long-connection mode tests ---- - -func TestNewWeComAIBotChannel_WSMode(t *testing.T) { - t.Run("success with bot_id and secret", func(t *testing.T) { - cfg := config.WeComAIBotConfig{ - Enabled: true, - BotID: "test_bot_id", - } - cfg.SetSecret("test_secret") - messageBus := bus.NewMessageBus() - ch, err := NewWeComAIBotChannel(cfg, messageBus) - if err != nil { - t.Fatalf("Expected no error, got %v", err) - } - if ch == nil { - t.Fatal("Expected channel to be created") - } - if ch.Name() != "wecom_aibot" { - t.Errorf("Expected name 'wecom_aibot', got '%s'", ch.Name()) - } - // WebSocket mode must NOT implement WebhookHandler. - if _, ok := ch.(channels.WebhookHandler); ok { - t.Error("WebSocket mode channel should NOT implement WebhookHandler") - } - }) - - t.Run("ws mode takes priority over webhook fields", func(t *testing.T) { - cfg := config.WeComAIBotConfig{ - Enabled: true, - BotID: "test_bot_id", - } - cfg.SetSecret("test_secret") - cfg.SetToken("also_set") - cfg.SetEncodingAESKey("testkey1234567890123456789012345678901234567") - messageBus := bus.NewMessageBus() - ch, err := NewWeComAIBotChannel(cfg, messageBus) - if err != nil { - t.Fatalf("Expected no error, got %v", err) - } - if _, ok := ch.(*WeComAIBotWSChannel); !ok { - t.Error("Expected WebSocket mode channel when both BotID+secret and Token+Key are set") - } - }) - - t.Run("error with missing bot_id", func(t *testing.T) { - cfg := config.WeComAIBotConfig{ - Enabled: true, - } - cfg.SetSecret("test_secret") - messageBus := bus.NewMessageBus() - _, err := NewWeComAIBotChannel(cfg, messageBus) - // Missing bot_id alone means neither WS mode nor webhook mode is fully configured. 
- if err == nil { - t.Fatal("Expected error for missing bot_id, got nil") - } - }) - - t.Run("error with missing secret", func(t *testing.T) { - cfg := config.WeComAIBotConfig{ - Enabled: true, - BotID: "test_bot_id", - } - messageBus := bus.NewMessageBus() - _, err := NewWeComAIBotChannel(cfg, messageBus) - if err == nil { - t.Fatal("Expected error for missing secret, got nil") - } - }) -} - -func TestWeComAIBotWSChannelStartStop(t *testing.T) { - cfg := config.WeComAIBotConfig{ - Enabled: true, - BotID: "test_bot_id", - } - cfg.SetSecret("test_secret") - messageBus := bus.NewMessageBus() - ch, err := NewWeComAIBotChannel(cfg, messageBus) - if err != nil { - t.Fatalf("Failed to create channel: %v", err) - } - - ctx := context.Background() - - // Start launches a background goroutine; it should not block or return an error. - if err := ch.Start(ctx); err != nil { - t.Fatalf("Failed to start channel: %v", err) - } - if !ch.IsRunning() { - t.Error("Expected channel to be running after Start") - } - - // Stop should work regardless of whether the WebSocket actually connected. - if err := ch.Stop(ctx); err != nil { - t.Fatalf("Failed to stop channel: %v", err) - } - if ch.IsRunning() { - t.Error("Expected channel to be stopped after Stop") - } -} - -func TestGenerateRandomID(t *testing.T) { - ids := make(map[string]bool) - for i := 0; i < 200; i++ { - id := generateRandomID(10) - if len(id) != 10 { - t.Errorf("Expected ID length 10, got %d", len(id)) - } - if ids[id] { - t.Errorf("Duplicate ID generated: %s", id) - } - ids[id] = true - } -} - -func TestWSGenerateID(t *testing.T) { - ids := make(map[string]bool) - for i := 0; i < 200; i++ { - id := wsGenerateID() - if len(id) != 10 { - t.Errorf("Expected ID length 10, got %d", len(id)) - } - if ids[id] { - t.Errorf("Duplicate wsGenerateID result: %s", id) - } - ids[id] = true - } -} - -// ---- Webhook streaming fallback tests ---- - -// makeWebhookChannel creates a started WeComAIBotChannel for testing. 
-func makeWebhookChannel(t *testing.T) *WeComAIBotChannel { - t.Helper() - cfg := config.WeComAIBotConfig{ - Enabled: true, - } - cfg.SetToken("test_token") - cfg.SetEncodingAESKey("abcdefghijklmnopqrstuvwxyz0123456789ABCDEFG") - ch, err := NewWeComAIBotChannel(cfg, bus.NewMessageBus()) - if err != nil { - t.Fatalf("create channel: %v", err) - } - wc := ch.(*WeComAIBotChannel) - wc.ctx, wc.cancel = context.WithCancel(context.Background()) - return wc -} - -// makeStreamTask creates and registers a streamTask for testing. -func makeStreamTask(t *testing.T, ch *WeComAIBotChannel, streamID, chatID string, deadline time.Time) *streamTask { - t.Helper() - task := &streamTask{ - StreamID: streamID, - ChatID: chatID, - Deadline: deadline, - answerCh: make(chan string, 1), - } - task.ctx, task.cancel = context.WithCancel(ch.ctx) - ch.taskMu.Lock() - ch.streamTasks[streamID] = task - ch.chatTasks[chatID] = append(ch.chatTasks[chatID], task) - ch.taskMu.Unlock() - return task -} - -// TestGetStreamResponse_ImmediateAnswer verifies that when the agent has already -// placed its answer in answerCh, getStreamResponse returns a finish=true response -// and fully removes the task. 
-func TestGetStreamResponse_ImmediateAnswer(t *testing.T) { - ch := makeWebhookChannel(t) - defer ch.cancel() - - task := makeStreamTask(t, ch, "stream-1", "chat-1", time.Now().Add(30*time.Second)) - task.answerCh <- "hello from agent" - - result := ch.getStreamResponse(task, "ts123", "nonce123") - if result == "" { - t.Fatal("expected non-empty encrypted response") - } - - ch.taskMu.RLock() - _, exists := ch.streamTasks["stream-1"] - ch.taskMu.RUnlock() - if exists { - t.Error("task should have been removed from streamTasks after normal finish") - } - if !task.Finished { - t.Error("task.Finished should be true after normal finish") - } -} - -// TestGetStreamResponse_DeadlinePassed verifies that when the stream deadline has -// elapsed (no agent reply yet), getStreamResponse closes the stream but keeps the -// task alive so the response_url fallback can still deliver the answer. -func TestGetStreamResponse_DeadlinePassed(t *testing.T) { - ch := makeWebhookChannel(t) - defer ch.cancel() - - task := makeStreamTask(t, ch, "stream-2", "chat-2", time.Now().Add(-time.Millisecond)) - - result := ch.getStreamResponse(task, "ts456", "nonce456") - if result == "" { - t.Fatal("expected non-empty encrypted response") - } - - ch.taskMu.RLock() - _, stillStreaming := ch.streamTasks["stream-2"] - ch.taskMu.RUnlock() - if stillStreaming { - t.Error("task should have been removed from streamTasks after deadline") - } - if !task.StreamClosed { - t.Error("task.StreamClosed should be true after deadline") - } - if task.Finished { - t.Error("task.Finished must remain false: agent reply still expected via response_url") - } -} - -// TestGetStreamResponse_StillPending verifies that when neither the agent has -// replied nor the deadline has passed, getStreamResponse returns without altering -// task state (client should poll again). 
-func TestGetStreamResponse_StillPending(t *testing.T) { - ch := makeWebhookChannel(t) - defer ch.cancel() - - task := makeStreamTask(t, ch, "stream-3", "chat-3", time.Now().Add(30*time.Second)) - - result := ch.getStreamResponse(task, "ts789", "nonce789") - if result == "" { - t.Fatal("expected non-empty encrypted response") - } - - ch.taskMu.RLock() - _, exists := ch.streamTasks["stream-3"] - ch.taskMu.RUnlock() - if !exists { - t.Error("pending task should still be in streamTasks") - } - if task.Finished || task.StreamClosed { - t.Error("pending task should not be finished or stream-closed") - } - // Cleanup. - ch.removeTask(task) -} diff --git a/pkg/channels/wecom/aibot_ws.go b/pkg/channels/wecom/aibot_ws.go deleted file mode 100644 index 53dd7071f..000000000 --- a/pkg/channels/wecom/aibot_ws.go +++ /dev/null @@ -1,1347 +0,0 @@ -package wecom - -import ( - "context" - "encoding/base64" - "encoding/json" - "fmt" - "io" - "net/http" - "os" - "path/filepath" - "strings" - "sync" - "time" - - "github.com/gorilla/websocket" - - "github.com/sipeed/picoclaw/pkg/bus" - "github.com/sipeed/picoclaw/pkg/channels" - "github.com/sipeed/picoclaw/pkg/config" - "github.com/sipeed/picoclaw/pkg/identity" - "github.com/sipeed/picoclaw/pkg/logger" - "github.com/sipeed/picoclaw/pkg/media" - "github.com/sipeed/picoclaw/pkg/utils" -) - -// Long-connection WebSocket endpoint. -// Ref: https://developer.work.weixin.qq.com/document/path/101463 -const ( - wsEndpoint = "wss://openws.work.weixin.qq.com" - wsHeartbeatInterval = 30 * time.Second - wsConnectTimeout = 15 * time.Second - wsSubscribeTimeout = 10 * time.Second - wsSendMsgTimeout = 10 * time.Second - wsRespondMsgTimeout = 10 * time.Second - wsWelcomeMsgTimeout = 5 * time.Second // WeCom requires welcome reply within 5 seconds - wsMaxReconnectWait = 60 * time.Second - wsInitialReconnect = time.Second - - // WeCom requires finish=true within 6 minutes of the first stream frame. 
- // wsStreamTickInterval controls how often we send an in-progress hint. - // wsStreamMaxDuration is a safety margin below the 6-minute hard limit. - wsStreamTickInterval = 30 * time.Second - wsStreamMaxDuration = 5*time.Minute + 30*time.Second - - // wsImageDownloadTimeout caps the time we spend downloading an inbound image. - wsImageDownloadTimeout = 30 * time.Second - - // Keep req_id -> chat route for late fallback pushes after stream window closes. - wsLateReplyRouteTTL = 30 * time.Minute - - // wsStreamMaxContentBytes is the maximum UTF-8 byte length for the content field - // of a single WeCom AI Bot stream / text / markdown frame. - // Ref: https://developer.work.weixin.qq.com/document/path/101463 - wsStreamMaxContentBytes = 20480 -) - -// wsImageHTTPClient is a shared HTTP client for downloading inbound images. -// Reusing it enables connection pooling across multiple image downloads. -var wsImageHTTPClient = &http.Client{Timeout: wsImageDownloadTimeout} - -// WeComAIBotWSChannel implements channels.Channel for WeCom AI Bot using the -// WebSocket long-connection API. -// Unlike the webhook counterpart it does NOT implement WebhookHandler, so the -// HTTP manager will not register any callback URL for it. -type WeComAIBotWSChannel struct { - *channels.BaseChannel - config config.WeComAIBotConfig - ctx context.Context - cancel context.CancelFunc - - // conn is the active WebSocket connection; nil when disconnected. - // All writes are serialized through connMu. - conn *websocket.Conn - connMu sync.Mutex - - // dedupe prevents duplicate message processing (WeCom may re-deliver). - dedupe *MessageDeduplicator - - // reqStates holds per-req_id runtime state. - // It unifies active task state and late-reply fallback routing. - reqStates map[string]*wsReqState - reqStatesMu sync.Mutex - - // reqPending correlates command req_ids with response channels. - // Used only for subscribe/ping command-response pairs. 
- reqPending map[string]chan wsEnvelope - reqPendingMu sync.Mutex -} - -// wsTask tracks one in-progress agent reply for a single chat turn. -type wsTask struct { - ReqID string // req_id echoed in all replies for this turn - ChatID string - ChatType uint32 - StreamID string // our generated stream.id - answerCh chan string // agent delivers its reply here via Send() - ctx context.Context - cancel context.CancelFunc -} - -type wsReqState struct { - Task *wsTask - Route wsLateReplyRoute -} - -type wsLateReplyRoute struct { - ChatID string - ChatType uint32 - ReadyAt time.Time - ExpiresAt time.Time -} - -// ---- WebSocket protocol types ---- - -// wsEnvelope is the generic JSON envelope for all WebSocket messages. -type wsEnvelope struct { - Cmd string `json:"cmd,omitempty"` - Headers wsHeaders `json:"headers"` - Body json.RawMessage `json:"body,omitempty"` - ErrCode int `json:"errcode,omitempty"` - ErrMsg string `json:"errmsg,omitempty"` -} - -type wsHeaders struct { - ReqID string `json:"req_id"` -} - -// wsCommand is an outgoing request sent over the WebSocket. -type wsCommand struct { - Cmd string `json:"cmd"` - Headers wsHeaders `json:"headers"` - Body any `json:"body,omitempty"` -} - -type wsSendMsgBody struct { - ChatID string `json:"chatid"` - ChatType uint32 `json:"chat_type,omitempty"` - MsgType string `json:"msgtype"` - Markdown *wsMarkdownContent `json:"markdown,omitempty"` -} - -// wsRespondMsgBody is the body for aibot_respond_msg / aibot_respond_welcome_msg. -type wsRespondMsgBody struct { - MsgType string `json:"msgtype"` - Stream *wsStreamContent `json:"stream,omitempty"` - Text *wsTextContent `json:"text,omitempty"` - Markdown *wsMarkdownContent `json:"markdown,omitempty"` - Image *wsImageContent `json:"image,omitempty"` -} - -type wsStreamContent struct { - ID string `json:"id"` - Finish bool `json:"finish"` - Content string `json:"content,omitempty"` -} - -// wsImageContent carries a base64-encoded image payload for outbound messages. 
-type wsImageContent struct { - Base64 string `json:"base64"` - MD5 string `json:"md5"` -} - -type wsTextContent struct { - Content string `json:"content"` -} - -type wsMarkdownContent struct { - Content string `json:"content"` -} - -// WeComAIBotWSMessage is the decoded body of aibot_msg_callback / -// aibot_event_callback in WebSocket long-connection mode. -// The structure mirrors WeComAIBotMessage but includes extra fields -// that only appear in long-connection callbacks (Voice, AESKey on Image/File). -type WeComAIBotWSMessage struct { - MsgID string `json:"msgid"` - CreateTime int64 `json:"create_time,omitempty"` - AIBotID string `json:"aibotid"` - ChatID string `json:"chatid,omitempty"` - ChatType string `json:"chattype,omitempty"` // "single" | "group" - From struct { - UserID string `json:"userid"` - } `json:"from"` - MsgType string `json:"msgtype"` - Text *struct { - Content string `json:"content"` - } `json:"text,omitempty"` - Image *struct { - URL string `json:"url"` - AESKey string `json:"aeskey,omitempty"` // long-connection: per-resource decrypt key - } `json:"image,omitempty"` - Voice *struct { - Content string `json:"content"` // WeCom transcribes voice to text in callbacks - } `json:"voice,omitempty"` - Mixed *struct { - MsgItem []struct { - MsgType string `json:"msgtype"` - Text *struct { - Content string `json:"content"` - } `json:"text,omitempty"` - Image *struct { - URL string `json:"url"` - AESKey string `json:"aeskey,omitempty"` - } `json:"image,omitempty"` - } `json:"msg_item"` - } `json:"mixed,omitempty"` - Event *struct { - EventType string `json:"eventtype"` - } `json:"event,omitempty"` - File *struct { - URL string `json:"url"` - AESKey string `json:"aeskey,omitempty"` - } `json:"file,omitempty"` - Video *struct { - URL string `json:"url"` - AESKey string `json:"aeskey,omitempty"` - } `json:"video,omitempty"` -} - -// ---- Constructor ---- - -// newWeComAIBotWSChannel creates a WeComAIBotWSChannel for WebSocket mode. 
-func newWeComAIBotWSChannel( - cfg config.WeComAIBotConfig, - messageBus *bus.MessageBus, -) (*WeComAIBotWSChannel, error) { - if cfg.BotID == "" || cfg.Secret() == "" { - return nil, fmt.Errorf("bot_id and secret are required for WeCom AI Bot WebSocket mode") - } - - base := channels.NewBaseChannel("wecom_aibot", cfg, messageBus, cfg.AllowFrom, - channels.WithReasoningChannelID(cfg.ReasoningChannelID), - ) - - return &WeComAIBotWSChannel{ - BaseChannel: base, - config: cfg, - dedupe: NewMessageDeduplicator(wecomMaxProcessedMessages), - reqStates: make(map[string]*wsReqState), - reqPending: make(map[string]chan wsEnvelope), - }, nil -} - -// ---- Channel interface ---- - -// Name implements channels.Channel. -func (c *WeComAIBotWSChannel) Name() string { return "wecom_aibot" } - -// Start connects to the WeCom WebSocket endpoint and begins message processing. -func (c *WeComAIBotWSChannel) Start(ctx context.Context) error { - logger.InfoC("wecom_aibot", "Starting WeCom AI Bot channel (WebSocket long-connection mode)...") - c.ctx, c.cancel = context.WithCancel(ctx) - c.SetRunning(true) - go c.connectLoop() - logger.InfoC("wecom_aibot", "WeCom AI Bot channel started (WebSocket mode)") - return nil -} - -// Stop shuts down the channel and closes the WebSocket connection. -func (c *WeComAIBotWSChannel) Stop(_ context.Context) error { - logger.InfoC("wecom_aibot", "Stopping WeCom AI Bot channel (WebSocket mode)...") - if c.cancel != nil { - c.cancel() - } - c.connMu.Lock() - if c.conn != nil { - c.conn.Close() - c.conn = nil - } - c.connMu.Unlock() - c.SetRunning(false) - logger.InfoC("wecom_aibot", "WeCom AI Bot channel stopped") - return nil -} - -// Send delivers the agent reply for msg.ChatID. -// The waiting task goroutine picks it up and writes the final stream response. 
-func (c *WeComAIBotWSChannel) Send(ctx context.Context, msg bus.OutboundMessage) error { - if !c.IsRunning() { - return channels.ErrNotRunning - } - - // msg.ChatID carries the inbound req_id (set by dispatchWSAgentTask). - // For cron-triggered messages, msg.ChatID is the real WeCom chat/user ID - // and there will be no matching entry in reqStates; fall through to proactive push. - task, route, ok := c.getReqState(msg.ChatID) - if !ok { - // No req_id record found — this is a cron/scheduler-originated message. - // Send it as a proactive markdown push using the chat ID directly. - logger.InfoCF("wecom_aibot", "Send: no req_id state, delivering via proactive push (cron/scheduler)", - map[string]any{"chat_id": msg.ChatID}) - if err := c.wsSendActivePush(msg.ChatID, 0, msg.Content); err != nil { - logger.WarnCF("wecom_aibot", "Proactive push failed", - map[string]any{"chat_id": msg.ChatID, "error": err.Error()}) - return fmt.Errorf("websocket delivery failed: %w", channels.ErrSendFailed) - } - return nil - } - - if task == nil { - if time.Now().Before(route.ReadyAt) { - // Keep using aibot_respond_msg within stream window; do not proactively - // push unless wsStreamMaxDuration has elapsed. 
- logger.WarnCF("wecom_aibot", "Send: stream window still open, skip proactive push", - map[string]any{"req_id": msg.ChatID, "ready_at": route.ReadyAt.Format(time.RFC3339)}) - return nil - } - - if err := c.wsSendActivePush(route.ChatID, route.ChatType, msg.Content); err != nil { - logger.WarnCF("wecom_aibot", "Late reply proactive push failed", - map[string]any{"req_id": msg.ChatID, "chat_id": route.ChatID, "error": err.Error()}) - return fmt.Errorf("websocket delivery failed: %w", channels.ErrSendFailed) - } - logger.InfoCF("wecom_aibot", "Late reply delivered via proactive push", - map[string]any{"req_id": msg.ChatID, "chat_id": route.ChatID, "chat_type": route.ChatType}) - c.deleteReqState(msg.ChatID) - return nil - } - - // Non-blocking fast path: when answerCh has space, deliver without racing - // against task.ctx.Done() (which fires when the task is canceled by a new - // incoming message, but the response must still be sent). - select { - case task.answerCh <- msg.Content: - return nil - default: - } - // answerCh was full; block with cancellation guards. - select { - case task.answerCh <- msg.Content: - case <-task.ctx.Done(): - return nil - case <-ctx.Done(): - return ctx.Err() - } - return nil -} - -// ---- Connection management ---- - -// wsBackoffResetDuration is the minimum duration a WebSocket connection must -// stay up before we reset the reconnect backoff to its initial value. This -// prevents a short burst of failures from causing long waits after later, -// stable connection periods. -const wsBackoffResetDuration = time.Minute - -// connectLoop maintains the WebSocket connection, reconnecting on failure with -// exponential backoff. 
-func (c *WeComAIBotWSChannel) connectLoop() { - backoff := wsInitialReconnect - for { - select { - case <-c.ctx.Done(): - return - default: - } - - logger.InfoC("wecom_aibot", "Connecting to WeCom WebSocket endpoint...") - start := time.Now() - if err := c.runConnection(); err != nil { - elapsed := time.Since(start) - // If the connection was stable for long enough, reset backoff so that - // a previous burst of failures does not keep us at the maximum delay. - if elapsed >= wsBackoffResetDuration { - backoff = wsInitialReconnect - } - select { - case <-c.ctx.Done(): - return - default: - logger.WarnCF("wecom_aibot", "WebSocket connection lost, reconnecting", - map[string]any{"error": err.Error(), "backoff": backoff.String()}) - select { - case <-time.After(backoff): - case <-c.ctx.Done(): - return - } - if backoff < wsMaxReconnectWait { - backoff *= 2 - if backoff > wsMaxReconnectWait { - backoff = wsMaxReconnectWait - } - } - } - } else { - // Clean exit (context canceled); stop reconnecting. - return - } - } -} - -// runConnection dials, subscribes, and runs the read/heartbeat loops until the -// connection closes or the channel context is canceled. -func (c *WeComAIBotWSChannel) runConnection() error { - dialCtx, dialCancel := context.WithTimeout(c.ctx, wsConnectTimeout) - conn, httpResp, err := websocket.DefaultDialer.DialContext(dialCtx, wsEndpoint, nil) - dialCancel() - if httpResp != nil { - httpResp.Body.Close() - } - if err != nil { - return fmt.Errorf("dial failed: %w", err) - } - - c.connMu.Lock() - c.conn = conn - c.connMu.Unlock() - - defer func() { - c.connMu.Lock() - if c.conn == conn { - c.conn = nil - } - c.connMu.Unlock() - // Cancel any tasks that were started over this connection so their - // agent goroutines do not keep running after the connection is gone. 
- c.cancelAllTasks() - }() - - // ---- Read loop (must start BEFORE subscribing) ---- - // sendAndWait blocks waiting for the subscribe response on reqPending; - // readLoop is the only goroutine that delivers messages to reqPending. - // Starting readLoop first avoids a deadlock where sendAndWait times out - // because no one reads the server's reply. - readErrCh := make(chan error, 1) - go func() { readErrCh <- c.readLoop(conn) }() - - // ---- Subscribe ---- - reqID := wsGenerateID() - resp, err := c.sendAndWait(conn, reqID, wsCommand{ - Cmd: "aibot_subscribe", - Headers: wsHeaders{ReqID: reqID}, - Body: map[string]string{ - "bot_id": c.config.BotID, - "secret": c.config.Secret(), - }, - }, wsSubscribeTimeout) - if err != nil { - conn.Close() // stop readLoop - <-readErrCh - return fmt.Errorf("subscribe failed: %w", err) - } - if resp.ErrCode != 0 { - conn.Close() - <-readErrCh - return fmt.Errorf("subscribe rejected (errcode=%d): %s", resp.ErrCode, resp.ErrMsg) - } - - logger.InfoC("wecom_aibot", "WebSocket subscription successful") - - // ---- Heartbeat goroutine ---- - hbDone := make(chan struct{}) - go func() { - defer close(hbDone) - c.heartbeatLoop(conn) - }() - - // Wait for the read loop to exit, then tear down the heartbeat. - readErr := <-readErrCh - conn.Close() // signal heartbeat to stop (idempotent) - <-hbDone - return readErr -} - -// sendAndWait registers a pending-response slot, sends cmd, and blocks until -// the matching response arrives or the timeout/context fires. 
-func (c *WeComAIBotWSChannel) sendAndWait( - conn *websocket.Conn, - reqID string, - cmd wsCommand, - timeout time.Duration, -) (wsEnvelope, error) { - ch := make(chan wsEnvelope, 1) - c.reqPendingMu.Lock() - c.reqPending[reqID] = ch - c.reqPendingMu.Unlock() - - cleanup := func() { - c.reqPendingMu.Lock() - delete(c.reqPending, reqID) - c.reqPendingMu.Unlock() - } - - data, err := json.Marshal(cmd) - if err != nil { - cleanup() - return wsEnvelope{}, fmt.Errorf("marshal command: %w", err) - } - c.connMu.Lock() - err = conn.WriteMessage(websocket.TextMessage, data) - c.connMu.Unlock() - if err != nil { - cleanup() - return wsEnvelope{}, fmt.Errorf("write command: %w", err) - } - - timer := time.NewTimer(timeout) - defer timer.Stop() - select { - case env := <-ch: - return env, nil - case <-timer.C: - cleanup() - return wsEnvelope{}, fmt.Errorf("timeout waiting for response (req_id=%s)", reqID) - case <-c.ctx.Done(): - cleanup() - return wsEnvelope{}, c.ctx.Err() - } -} - -// heartbeatLoop sends a ping every wsHeartbeatInterval until conn is closed. -// It validates the server's pong response via sendAndWait; a failed pong -// triggers a reconnection by closing the connection. 
-func (c *WeComAIBotWSChannel) heartbeatLoop(conn *websocket.Conn) { - ticker := time.NewTicker(wsHeartbeatInterval) - defer ticker.Stop() - for { - select { - case <-ticker.C: - reqID := wsGenerateID() - resp, err := c.sendAndWait(conn, reqID, wsCommand{ - Cmd: "ping", - Headers: wsHeaders{ReqID: reqID}, - }, wsHeartbeatInterval) - if err != nil { - logger.WarnCF("wecom_aibot", "Heartbeat failed, closing connection", - map[string]any{"error": err.Error()}) - conn.Close() - return - } - if resp.ErrCode != 0 { - logger.WarnCF("wecom_aibot", "Heartbeat rejected", - map[string]any{"errcode": resp.ErrCode, "errmsg": resp.ErrMsg}) - conn.Close() - return - } - logger.DebugCF("wecom_aibot", "Heartbeat pong received", map[string]any{"req_id": reqID}) - case <-c.ctx.Done(): - return - } - } -} - -// readLoop reads WebSocket messages and dispatches them until the connection -// closes or the channel is stopped. -func (c *WeComAIBotWSChannel) readLoop(conn *websocket.Conn) error { - for { - _, raw, err := conn.ReadMessage() - if err != nil { - select { - case <-c.ctx.Done(): - return nil // clean shutdown - default: - return fmt.Errorf("read error: %w", err) - } - } - - var env wsEnvelope - if err := json.Unmarshal(raw, &env); err != nil { - logger.WarnCF("wecom_aibot", "Failed to parse WebSocket message", - map[string]any{"error": err.Error(), "raw": string(raw)}) - continue - } - - // Command responses have an empty Cmd field; forward to any waiting - // sendAndWait() call, or silently drop if no one is waiting (e.g. - // late responses after timeout). - if env.Cmd == "" && env.Headers.ReqID != "" { - c.reqPendingMu.Lock() - ch, ok := c.reqPending[env.Headers.ReqID] - if ok { - delete(c.reqPending, env.Headers.ReqID) - } - c.reqPendingMu.Unlock() - if ok { - ch <- env - } - continue - } - - // Dispatch to appropriate handler in a separate goroutine so the - // read loop is never blocked by a slow agent. 
- go c.handleEnvelope(env) - } -} - -// ---- Message / event handlers ---- - -// handleEnvelope routes a WebSocket envelope to the right handler. -func (c *WeComAIBotWSChannel) handleEnvelope(env wsEnvelope) { - switch env.Cmd { - case "aibot_msg_callback": - c.handleMsgCallback(env) - case "aibot_event_callback": - c.handleEventCallback(env) - default: - logger.DebugCF("wecom_aibot", "Unhandled WebSocket command", - map[string]any{"cmd": env.Cmd}) - } -} - -// handleMsgCallback processes aibot_msg_callback. -func (c *WeComAIBotWSChannel) handleMsgCallback(env wsEnvelope) { - var msg WeComAIBotWSMessage - if err := json.Unmarshal(env.Body, &msg); err != nil { - logger.WarnCF("wecom_aibot", "Failed to parse msg callback body", - map[string]any{"error": err.Error()}) - return - } - - // Deduplicate by msgid (WeCom may re-deliver on network issues). - if msg.MsgID != "" && !c.dedupe.MarkMessageProcessed(msg.MsgID) { - logger.DebugCF("wecom_aibot", "Duplicate message ignored", - map[string]any{"msgid": msg.MsgID}) - return - } - - reqID := env.Headers.ReqID - switch msg.MsgType { - case "text": - c.handleWSTextMessage(reqID, msg) - case "image": - c.handleWSImageMessage(reqID, msg) - case "voice": - c.handleWSVoiceMessage(reqID, msg) - case "mixed": - c.handleWSMixedMessage(reqID, msg) - case "file": - c.handleWSFileMessage(reqID, msg) - case "video": - c.handleWSVideoMessage(reqID, msg) - default: - logger.WarnCF("wecom_aibot", "Unsupported message type", - map[string]any{"msgtype": msg.MsgType}) - c.wsSendStreamFinish(reqID, wsGenerateID(), - "Unsupported message type: "+msg.MsgType) - } -} - -// handleEventCallback processes aibot_event_callback. -func (c *WeComAIBotWSChannel) handleEventCallback(env wsEnvelope) { - var msg WeComAIBotWSMessage - if err := json.Unmarshal(env.Body, &msg); err != nil { - logger.WarnCF("wecom_aibot", "Failed to parse event callback body", - map[string]any{"error": err.Error()}) - return - } - - // Deduplicate by msgid. 
- if msg.MsgID != "" && !c.dedupe.MarkMessageProcessed(msg.MsgID) { - logger.DebugCF("wecom_aibot", "Duplicate event ignored", - map[string]any{"msgid": msg.MsgID}) - return - } - - var eventType string - if msg.Event != nil { - eventType = msg.Event.EventType - } - logger.DebugCF("wecom_aibot", "Received event callback", - map[string]any{"event_type": eventType}) - - switch eventType { - case "enter_chat": - if c.config.WelcomeMessage != "" { - c.wsSendWelcomeMsg(env.Headers.ReqID, c.config.WelcomeMessage) - } - case "disconnected_event": - // The server will close this connection after sending this event. - // connectLoop will detect the closure and reconnect automatically. - logger.WarnC("wecom_aibot", - "Received disconnected_event: this connection is being replaced by a newer one") - default: - logger.DebugCF("wecom_aibot", "Unhandled event type", - map[string]any{"event_type": eventType}) - } -} - -// handleWSTextMessage dispatches a plain-text message to the agent and streams -// the reply back over the WebSocket connection. -func (c *WeComAIBotWSChannel) handleWSTextMessage(reqID string, msg WeComAIBotWSMessage) { - if msg.Text == nil { - logger.ErrorC("wecom_aibot", "text message missing text field") - return - } - c.dispatchWSAgentTask(reqID, msg, msg.Text.Content, nil) -} - -// handleWSImageMessage downloads and stores the inbound image, then dispatches -// it to the agent as a media-tagged message. -func (c *WeComAIBotWSChannel) handleWSImageMessage(reqID string, msg WeComAIBotWSMessage) { - if msg.Image == nil { - logger.WarnC("wecom_aibot", "Image message missing image field") - c.wsSendStreamFinish(reqID, wsGenerateID(), "Image message could not be processed.") - return - } - c.wsHandleMediaMessage(reqID, msg, msg.Image.URL, msg.Image.AESKey, "image") -} - -// wsHandleMediaMessage is a shared helper for image, file and video messages. -// It downloads the resource, stores it in MediaStore, and dispatches to the agent. 
-func (c *WeComAIBotWSChannel) wsHandleMediaMessage( - reqID string, msg WeComAIBotWSMessage, - resourceURL, aesKey, label string, -) { - chatID := wsChatID(msg) - - ctx, cancel := context.WithTimeout(c.ctx, wsImageDownloadTimeout) - defer cancel() - - ref, err := c.storeWSMedia(ctx, chatID, msg.MsgID, resourceURL, aesKey, wsLabelToDefaultExt(label)) - if err != nil { - logger.WarnCF("wecom_aibot", "Failed to download/store WS "+label, - map[string]any{"error": err.Error(), "url": resourceURL}) - c.wsSendStreamFinish(reqID, wsGenerateID(), - strings.ToUpper(label[:1])+label[1:]+" message could not be processed.") - return - } - - c.dispatchWSAgentTask(reqID, msg, "["+label+"]", []string{ref}) -} - -// handleWSMixedMessage handles mixed text+image messages. -// All text parts are collected into the content string; all image parts are -// downloaded and stored in MediaStore before dispatching to the agent. -func (c *WeComAIBotWSChannel) handleWSMixedMessage(reqID string, msg WeComAIBotWSMessage) { - if msg.Mixed == nil { - logger.WarnC("wecom_aibot", "Mixed message has no content") - c.wsSendStreamFinish(reqID, wsGenerateID(), "Mixed message type is not yet fully supported.") - return - } - - chatID := wsChatID(msg) - - ctx, cancel := context.WithTimeout(c.ctx, wsImageDownloadTimeout) - defer cancel() - - var textParts []string - var mediaRefs []string - for _, item := range msg.Mixed.MsgItem { - switch item.MsgType { - case "text": - if item.Text != nil && item.Text.Content != "" { - textParts = append(textParts, item.Text.Content) - } - case "image": - if item.Image != nil { - ref, err := c.storeWSMedia(ctx, chatID, - msg.MsgID+"-"+wsGenerateID(), item.Image.URL, item.Image.AESKey, ".jpg") - if err != nil { - logger.WarnCF("wecom_aibot", "Failed to download/store mixed image", - map[string]any{"error": err.Error()}) - } else { - mediaRefs = append(mediaRefs, ref) - } - } - default: - logger.WarnCF("wecom_aibot", "Unsupported item type in mixed message", - 
map[string]any{"msgtype": item.MsgType}) - } - } - - if len(textParts) == 0 && len(mediaRefs) == 0 { - logger.WarnC("wecom_aibot", "Mixed message has no usable content") - c.wsSendStreamFinish(reqID, wsGenerateID(), "Mixed message type is not yet fully supported.") - return - } - - content := strings.Join(textParts, "\n") - if content == "" { - content = "[images]" - } - c.dispatchWSAgentTask(reqID, msg, content, mediaRefs) -} - -// dispatchWSAgentTask registers a new agent task, sends the opening stream frame, -// and starts a goroutine that runs the agent and streams the reply back. -// content is the text forwarded to the agent; mediaRefs are optional media -// store references attached to the inbound message. -func (c *WeComAIBotWSChannel) dispatchWSAgentTask( - reqID string, - msg WeComAIBotWSMessage, - content string, - mediaRefs []string, -) { - userID := msg.From.UserID - if userID == "" { - userID = "unknown" - } - // actualChatID is the real WeCom chat/user ID used for peer identification. - // reqID is used as the routing chatID so each turn is independently addressable. - actualChatID := wsChatID(msg) - - streamID := wsGenerateID() - chatType := wsChatTypeValue(msg.ChatType) - taskCtx, taskCancel := context.WithCancel(c.ctx) - - task := &wsTask{ - ReqID: reqID, - ChatID: actualChatID, - ChatType: chatType, - StreamID: streamID, - answerCh: make(chan string, 1), - ctx: taskCtx, - cancel: taskCancel, - } - // Each req_id is unique per WeCom turn; tasks run concurrently, no cancellation. - c.setReqState(reqID, &wsReqState{ - Task: task, - Route: wsLateReplyRoute{ - ChatID: actualChatID, - ChatType: chatType, - ReadyAt: time.Now().Add(wsStreamMaxDuration), - ExpiresAt: time.Now().Add(wsLateReplyRouteTTL), - }, - }) - - logger.DebugCF("wecom_aibot", "Registered new agent task", - map[string]any{"chat_id": actualChatID, "req_id": reqID, "stream_id": streamID}) - - // Send an empty stream opening frame (finish=false) immediately. 
- c.wsSendStreamChunk(reqID, streamID, false, "") - - go func() { - defer func() { - taskCancel() - c.clearReqTask(reqID, task) - }() - - sender := bus.SenderInfo{ - Platform: "wecom_aibot", - PlatformID: userID, - CanonicalID: identity.BuildCanonicalID("wecom_aibot", userID), - DisplayName: userID, - } - peerKind := "direct" - if msg.ChatType == "group" { - peerKind = "group" - } - peer := bus.Peer{Kind: peerKind, ID: actualChatID} - metadata := map[string]string{ - "channel": "wecom_aibot", - "chat_id": actualChatID, - "chat_type": msg.ChatType, - "msg_type": msg.MsgType, - "msgid": msg.MsgID, - "aibotid": msg.AIBotID, - "stream_id": streamID, - } - // Pass reqID as chatID: OutboundMessage.ChatID = reqID → Send() finds tasks[reqID]. - c.HandleMessage(taskCtx, peer, reqID, userID, reqID, - content, mediaRefs, metadata, sender) - - // Wait for the agent reply. While waiting, send periodic finish=false - // hints so the user knows processing is still in progress. - // WeCom requires finish=true within 6 minutes of the first stream frame; - // wsStreamMaxDuration enforces that limit with a safety margin. - waitHints := []string{ - "⏳ Processing, please wait...", - "⏳ Still processing, please wait...", - "⏳ Almost there, please wait...", - } - ticker := time.NewTicker(wsStreamTickInterval) - defer ticker.Stop() - deadlineTimer := time.NewTimer(wsStreamMaxDuration) - defer deadlineTimer.Stop() - tickCount := 0 - for { - select { - case answer := <-task.answerCh: - // Split the answer into byte-bounded chunks and send as stream frames. - // All but the last carry finish=false; the final frame closes the stream. 
- chunks := splitWSContent(answer, wsStreamMaxContentBytes) - for i, chunk := range chunks { - c.wsSendStreamChunk(reqID, streamID, i == len(chunks)-1, chunk) - } - c.deleteReqState(reqID) - return - case <-ticker.C: - hint := waitHints[tickCount%len(waitHints)] - tickCount++ - logger.DebugCF("wecom_aibot", "Sending stream progress hint", - map[string]any{"chat_id": actualChatID, "tick": tickCount}) - c.wsSendStreamChunk(reqID, streamID, false, hint) - case <-deadlineTimer.C: - logger.WarnCF("wecom_aibot", - "Stream response deadline reached, closing stream; late reply will be pushed", - map[string]any{"chat_id": actualChatID}) - c.wsSendStreamFinish(reqID, streamID, - "⏳ Processing is taking longer than expected, the response will be sent as a follow-up message.") - return - case <-taskCtx.Done(): - // Give a short grace period so that a response queued in the bus - // just before cancellation can still be delivered. This closes a - // race where a rapid second message cancels this task after the - // agent already published but before Send() wrote to answerCh. - // - // The connection is gone at this point, so we cannot use - // wsSendStreamFinish. Try wsSendActivePush on the (possibly - // already-restored) connection; if that also fails, leave the - // route intact so Send() can push the reply once reconnected. - select { - case answer := <-task.answerCh: - if err := c.wsSendActivePush(task.ChatID, task.ChatType, answer); err != nil { - logger.WarnCF("wecom_aibot", - "Grace-period push failed after task cancellation; reply may be lost", - map[string]any{"req_id": reqID, "chat_id": task.ChatID, "error": err.Error()}) - } else { - c.deleteReqState(reqID) - } - case <-time.After(100 * time.Millisecond): - } - return - } - } - }() -} - -// handleWSVoiceMessage handles voice messages. -// WeCom transcribes voice to text in the callback; if the transcription is -// present it is dispatched as plain text to the agent. 
-func (c *WeComAIBotWSChannel) handleWSVoiceMessage(reqID string, msg WeComAIBotWSMessage) { - if msg.Voice != nil && msg.Voice.Content != "" { - c.dispatchWSAgentTask(reqID, msg, msg.Voice.Content, nil) - return - } - c.wsSendStreamFinish(reqID, wsGenerateID(), "Voice messages are not yet supported.") -} - -// handleWSFileMessage handles file messages. -func (c *WeComAIBotWSChannel) handleWSFileMessage(reqID string, msg WeComAIBotWSMessage) { - if msg.File == nil { - logger.WarnC("wecom_aibot", "File message missing file field") - c.wsSendStreamFinish(reqID, wsGenerateID(), "File message could not be processed.") - return - } - c.wsHandleMediaMessage(reqID, msg, msg.File.URL, msg.File.AESKey, "file") -} - -// handleWSVideoMessage handles video messages. -func (c *WeComAIBotWSChannel) handleWSVideoMessage(reqID string, msg WeComAIBotWSMessage) { - if msg.Video == nil { - logger.WarnC("wecom_aibot", "Video message missing video field") - c.wsSendStreamFinish(reqID, wsGenerateID(), "Video message could not be processed.") - return - } - c.wsHandleMediaMessage(reqID, msg, msg.Video.URL, msg.Video.AESKey, "video") -} - -// ---- WebSocket write helpers ---- - -// wsSendStreamChunk sends an aibot_respond_msg stream frame. 
-func (c *WeComAIBotWSChannel) wsSendStreamChunk(reqID, streamID string, finish bool, content string) { - logger.DebugCF("wecom_aibot", "Sending stream chunk", map[string]any{ - "stream_id": streamID, - "finish": finish, - "preview": utils.Truncate(content, 100), - }) - cmd := wsCommand{ - Cmd: "aibot_respond_msg", - Headers: wsHeaders{ReqID: reqID}, - Body: wsRespondMsgBody{ - MsgType: "stream", - Stream: &wsStreamContent{ - ID: streamID, - Finish: finish, - Content: content, - }, - }, - } - if err := c.writeWSAndWait(cmd, wsRespondMsgTimeout); err != nil { - logger.WarnCF("wecom_aibot", "Stream chunk ack failed", map[string]any{ - "req_id": reqID, - "stream_id": streamID, - "finish": finish, - "error": err, - }) - } -} - -// wsSendStreamFinish sends the final aibot_respond_msg frame (finish=true, no images). -func (c *WeComAIBotWSChannel) wsSendStreamFinish(reqID, streamID, content string) { - c.wsSendStreamChunk(reqID, streamID, true, content) -} - -// wsSendWelcomeMsg sends a text welcome message via aibot_respond_welcome_msg. -func (c *WeComAIBotWSChannel) wsSendWelcomeMsg(reqID, content string) { - logger.DebugCF("wecom_aibot", "Sending welcome message", map[string]any{"req_id": reqID}) - cmd := wsCommand{ - Cmd: "aibot_respond_welcome_msg", - Headers: wsHeaders{ReqID: reqID}, - Body: wsRespondMsgBody{ - MsgType: "text", - Text: &wsTextContent{Content: content}, - }, - } - if err := c.writeWSAndWait(cmd, wsWelcomeMsgTimeout); err != nil { - logger.WarnCF("wecom_aibot", "Welcome message ack failed", - map[string]any{"req_id": reqID, "error": err.Error()}) - } -} - -// wsSendActivePush sends a proactive markdown message using aibot_send_msg. -// Long content is automatically split into byte-bounded chunks (≤ wsStreamMaxContentBytes -// each) and delivered as consecutive messages. -// It is used as a fallback for late replies after stream response window expires. 
-func (c *WeComAIBotWSChannel) wsSendActivePush(chatID string, chatType uint32, content string) error { - if chatID == "" { - return fmt.Errorf("chatid is empty") - } - for _, chunk := range splitWSContent(content, wsStreamMaxContentBytes) { - reqID := wsGenerateID() - if err := c.writeWSAndWait(wsCommand{ - Cmd: "aibot_send_msg", - Headers: wsHeaders{ReqID: reqID}, - Body: wsSendMsgBody{ - ChatID: chatID, - ChatType: chatType, - MsgType: "markdown", - Markdown: &wsMarkdownContent{Content: chunk}, - }, - }, wsSendMsgTimeout); err != nil { - return err - } - } - return nil -} - -// writeWSAndWait writes cmd to the active connection and validates the command response. -func (c *WeComAIBotWSChannel) writeWSAndWait(cmd wsCommand, timeout time.Duration) error { - if cmd.Headers.ReqID == "" { - return fmt.Errorf("req_id is empty") - } - - c.connMu.Lock() - conn := c.conn - c.connMu.Unlock() - if conn == nil { - return fmt.Errorf("websocket not connected") - } - - resp, err := c.sendAndWait(conn, cmd.Headers.ReqID, cmd, timeout) - if err != nil { - return err - } - if resp.ErrCode != 0 { - return fmt.Errorf("%s rejected (errcode=%d): %s", cmd.Cmd, resp.ErrCode, resp.ErrMsg) - } - return nil -} - -// cancelAllTasks cancels every pending agent task; called when the connection drops. -// It also expires each task's stream window (ReadyAt = now) so that when the agent -// eventually delivers its reply via Send(), the message is forwarded via -// wsSendActivePush on the restored connection instead of being silently discarded. -func (c *WeComAIBotWSChannel) cancelAllTasks() { - c.reqStatesMu.Lock() - defer c.reqStatesMu.Unlock() - now := time.Now() - for _, state := range c.reqStates { - if state != nil && state.Task != nil { - state.Task.cancel() - state.Task = nil - // Expire the stream window immediately so Send() uses wsSendActivePush. 
- state.Route.ReadyAt = now - } - } -} - -func (c *WeComAIBotWSChannel) setReqState(reqID string, state *wsReqState) { - c.reqStatesMu.Lock() - defer c.reqStatesMu.Unlock() - now := time.Now() - for k, v := range c.reqStates { - if v == nil || now.After(v.Route.ExpiresAt) { - delete(c.reqStates, k) - } - } - c.reqStates[reqID] = state -} - -func (c *WeComAIBotWSChannel) getReqState(reqID string) (*wsTask, wsLateReplyRoute, bool) { - c.reqStatesMu.Lock() - defer c.reqStatesMu.Unlock() - state, ok := c.reqStates[reqID] - if !ok || state == nil { - return nil, wsLateReplyRoute{}, false - } - if time.Now().After(state.Route.ExpiresAt) { - delete(c.reqStates, reqID) - return nil, wsLateReplyRoute{}, false - } - return state.Task, state.Route, true -} - -func (c *WeComAIBotWSChannel) deleteReqState(reqID string) { - c.reqStatesMu.Lock() - delete(c.reqStates, reqID) - c.reqStatesMu.Unlock() -} - -func (c *WeComAIBotWSChannel) clearReqTask(reqID string, task *wsTask) { - c.reqStatesMu.Lock() - defer c.reqStatesMu.Unlock() - state, ok := c.reqStates[reqID] - if !ok || state == nil { - return - } - if state.Task == task { - state.Task = nil - } -} - -func wsChatTypeValue(chatType string) uint32 { - if chatType == "group" { - return 2 - } - return 1 -} - -// wsChatID returns the effective chat ID from a WS message. -// For group messages it is msg.ChatID; for single chats it falls back to the sender's UserID. -func wsChatID(msg WeComAIBotWSMessage) string { - if msg.ChatID != "" { - return msg.ChatID - } - return msg.From.UserID -} - -// wsGenerateID generates a random 10-character alphanumeric ID. -// It is package-level (not a method) so it can be shared by both channel modes. -func wsGenerateID() string { - return generateRandomID(10) -} - -// ---- Inbound media download helpers ---- - -// storeWSMedia downloads the resource at resourceURL (with optional AES-CBC -// decryption) and stores it in the MediaStore. 
The file extension is inferred -// from the HTTP Content-Type response header; defaultExt is used as a fallback -// when the content type is absent or unrecognized. -func (c *WeComAIBotWSChannel) storeWSMedia( - ctx context.Context, - chatID, msgID, resourceURL, aesKey, defaultExt string, -) (string, error) { - store := c.GetMediaStore() - if store == nil { - return "", fmt.Errorf("no media store available") - } - - const maxSize = 20 << 20 // 20 MB - - req, err := http.NewRequestWithContext(ctx, http.MethodGet, resourceURL, nil) - if err != nil { - return "", fmt.Errorf("create request: %w", err) - } - resp, err := wsImageHTTPClient.Do(req) - if err != nil { - return "", fmt.Errorf("download: %w", err) - } - defer resp.Body.Close() - if resp.StatusCode != http.StatusOK { - return "", fmt.Errorf("download HTTP %d", resp.StatusCode) - } - - // Infer file extension from the Content-Type response header. - ext := wsMediaExtFromContentType(resp.Header.Get("Content-Type")) - if ext == "" { - ext = defaultExt - } - - // Buffer the media in memory, bounded to maxSize. - data, err := io.ReadAll(io.LimitReader(resp.Body, int64(maxSize)+1)) - if err != nil { - return "", fmt.Errorf("read media: %w", err) - } - if len(data) > maxSize { - return "", fmt.Errorf("media too large (> %d MB)", maxSize>>20) - } - - // AES-CBC decryption if a key is present. - if aesKey != "" { - key, decErr := base64.StdEncoding.DecodeString(aesKey) - if decErr != nil || len(key) != 32 { - key, decErr = decodeWeComAESKey(aesKey) - if decErr != nil { - return "", fmt.Errorf("decode media AES key: %w", decErr) - } - } - data, err = decryptAESCBC(key, data) - if err != nil { - return "", fmt.Errorf("decrypt media: %w", err) - } - } - - // Write to a temp file. The file is owned by the MediaStore and deleted by - // store.ReleaseAll — no caller-side cleanup needed. 
- mediaDir := filepath.Join(os.TempDir(), "picoclaw_media") - if err = os.MkdirAll(mediaDir, 0o700); err != nil { - return "", fmt.Errorf("mkdir: %w", err) - } - tmpFile, err := os.CreateTemp(mediaDir, msgID+"-*"+ext) - if err != nil { - return "", fmt.Errorf("create temp file: %w", err) - } - tmpPath := tmpFile.Name() - _, writeErr := tmpFile.Write(data) - closeErr := tmpFile.Close() - if writeErr != nil { - os.Remove(tmpPath) - return "", fmt.Errorf("write media: %w", writeErr) - } - if closeErr != nil { - os.Remove(tmpPath) - return "", fmt.Errorf("close media: %w", closeErr) - } - - scope := channels.BuildMediaScope("wecom_aibot", chatID, msgID) - ref, err := store.Store(tmpPath, media.MediaMeta{ - Filename: msgID + ext, - Source: "wecom_aibot", - CleanupPolicy: media.CleanupPolicyDeleteOnCleanup, - }, scope) - if err != nil { - os.Remove(tmpPath) - return "", fmt.Errorf("store: %w", err) - } - return ref, nil -} - -// wsMediaExtFromContentType returns the lowercase file extension (with leading -// dot) for the given Content-Type value, or "" when the type is unrecognized. -func wsMediaExtFromContentType(contentType string) string { - if contentType == "" { - return "" - } - // Strip parameters (e.g. "image/jpeg; charset=utf-8" → "image/jpeg"). 
- mt := strings.ToLower(strings.TrimSpace(strings.SplitN(contentType, ";", 2)[0])) - switch mt { - case "image/jpeg", "image/jpg": - return ".jpg" - case "image/png": - return ".png" - case "image/gif": - return ".gif" - case "image/webp": - return ".webp" - case "video/mp4": - return ".mp4" - case "video/mpeg", "video/x-mpeg": - return ".mpeg" - case "video/quicktime": - return ".mov" - case "video/webm": - return ".webm" - case "audio/mpeg", "audio/mp3": - return ".mp3" - case "audio/ogg": - return ".ogg" - case "audio/wav": - return ".wav" - case "application/pdf": - return ".pdf" - case "application/zip": - return ".zip" - case "application/x-rar-compressed", "application/vnd.rar": - return ".rar" - case "text/plain": - return ".txt" - case "application/msword": - return ".doc" - case "application/vnd.openxmlformats-officedocument.wordprocessingml.document": - return ".docx" - case "application/vnd.ms-excel": - return ".xls" - case "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet": - return ".xlsx" - case "application/vnd.ms-powerpoint": - return ".ppt" - case "application/vnd.openxmlformats-officedocument.presentationml.presentation": - return ".pptx" - } - return "" -} - -// wsLabelToDefaultExt returns the default file extension for the given media label -// used in wsHandleMediaMessage. It is the fallback when Content-Type detection fails. -func wsLabelToDefaultExt(label string) string { - switch label { - case "image": - return ".jpg" - case "video": - return ".mp4" - default: // "file" and any future labels - return ".bin" - } -} - -// ---- Content length helpers ---- - -// splitWSContent splits content into chunks each fitting within maxBytes UTF-8 -// bytes, preserving code block integrity via channels.SplitMessage. -// When SplitMessage still produces an oversized chunk (e.g. dense CJK content), -// splitAtByteBoundary is applied as a last-resort byte-level fallback. 
-func splitWSContent(content string, maxBytes int) []string { - if len(content) <= maxBytes { - return []string{content} - } - // SplitMessage works in runes. Use maxBytes as the rune limit: for pure ASCII - // this is exact; for multibyte content the byte verification below catches - // any chunk that still overflows. - chunks := channels.SplitMessage(content, maxBytes) - var result []string - for _, chunk := range chunks { - if len(chunk) <= maxBytes { - result = append(result, chunk) - } else { - // Still too large in bytes (e.g. dense CJK); force-split at UTF-8 boundaries. - result = append(result, splitAtByteBoundary(chunk, maxBytes)...) - } - } - return result -} - -// splitAtByteBoundary splits s into parts each ≤ maxBytes bytes by walking back -// from the hard byte limit to find a valid UTF-8 rune start boundary. -// This is a last-resort fallback; it does not try to preserve code blocks. -func splitAtByteBoundary(s string, maxBytes int) []string { - var parts []string - for len(s) > maxBytes { - end := maxBytes - // Walk back past any UTF-8 continuation bytes (high two bits == 10). - for end > 0 && s[end]>>6 == 0b10 { - end-- - } - if end == 0 { - end = maxBytes // shouldn't happen with valid UTF-8 - } - parts = append(parts, s[:end]) - s = strings.TrimLeft(s[end:], " \t\n\r") - } - if s != "" { - parts = append(parts, s) - } - return parts -} diff --git a/pkg/channels/wecom/aibot_ws_test.go b/pkg/channels/wecom/aibot_ws_test.go deleted file mode 100644 index f2f8833a1..000000000 --- a/pkg/channels/wecom/aibot_ws_test.go +++ /dev/null @@ -1,295 +0,0 @@ -package wecom - -import ( - "bytes" - "context" - "net/http" - "net/http/httptest" - "os" - "strings" - "testing" - - "github.com/sipeed/picoclaw/pkg/bus" - "github.com/sipeed/picoclaw/pkg/channels" - "github.com/sipeed/picoclaw/pkg/config" - "github.com/sipeed/picoclaw/pkg/media" -) - -// newTestWSChannel creates a WeComAIBotWSChannel ready for unit testing. 
-func newTestWSChannel(t *testing.T) *WeComAIBotWSChannel { - t.Helper() - cfg := config.WeComAIBotConfig{ - Enabled: true, - BotID: "test_bot_id", - } - cfg.SetSecret("test_secret") - ch, err := newWeComAIBotWSChannel(cfg, bus.NewMessageBus()) - if err != nil { - t.Fatalf("create WS channel: %v", err) - } - return ch -} - -// TestStoreWSMedia_NilStore verifies that storeWSMedia returns an error when no -// MediaStore has been injected. -func TestStoreWSMedia_NilStore(t *testing.T) { - ch := newTestWSChannel(t) - _, err := ch.storeWSMedia(context.Background(), "chat1", "msg1", "http://any", "", ".jpg") - if err == nil { - t.Fatal("expected error when no MediaStore is set") - } -} - -// TestStoreWSMedia_HTTPError verifies that storeWSMedia propagates HTTP errors -// from the media server. -func TestStoreWSMedia_HTTPError(t *testing.T) { - srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { - http.Error(w, "not found", http.StatusNotFound) - })) - defer srv.Close() - - ch := newTestWSChannel(t) - ch.SetMediaStore(media.NewFileMediaStore()) - - _, err := ch.storeWSMedia(context.Background(), "chat1", "msg1", srv.URL, "", ".jpg") - if err == nil { - t.Fatal("expected error for HTTP 404") - } -} - -// TestStoreWSMedia_ServerUnavailable verifies that storeWSMedia returns a clear -// error when the media server cannot be reached. -func TestStoreWSMedia_ServerUnavailable(t *testing.T) { - ch := newTestWSChannel(t) - ch.SetMediaStore(media.NewFileMediaStore()) - - // Port 1 is reserved and will refuse the connection immediately. - _, err := ch.storeWSMedia(context.Background(), "chat1", "msg1", "http://127.0.0.1:1", "", ".jpg") - if err == nil { - t.Fatal("expected error for unreachable server") - } -} - -// TestStoreWSMedia_Success_NoAES verifies the happy path: the media is downloaded, -// a media ref is returned, and the file persists and is readable via Resolve until -// ReleaseAll is called. 
The server returns no Content-Type, so the defaultExt is used. -func TestStoreWSMedia_Success_NoAES(t *testing.T) { - imageData := bytes.Repeat([]byte("x"), 256) - srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { - w.WriteHeader(http.StatusOK) - _, _ = w.Write(imageData) - })) - defer srv.Close() - - ch := newTestWSChannel(t) - store := media.NewFileMediaStore() - ch.SetMediaStore(store) - - ref, err := ch.storeWSMedia(context.Background(), "chat1", "msg1", srv.URL, "", ".jpg") - if err != nil { - t.Fatalf("expected no error, got %v", err) - } - if ref == "" { - t.Fatal("expected non-empty ref") - } - - // File must be accessible after storeWSMedia returns (no premature deletion). - path, err := store.Resolve(ref) - if err != nil { - t.Fatalf("ref should resolve: %v", err) - } - got, err := os.ReadFile(path) - if err != nil { - t.Fatalf("file should exist at %s: %v", path, err) - } - if !bytes.Equal(got, imageData) { - t.Errorf("content mismatch: got len=%d, want len=%d", len(got), len(imageData)) - } - - // ReleaseAll must delete the file (store owns lifecycle). - scope := channels.BuildMediaScope("wecom_aibot", "chat1", "msg1") - if err := store.ReleaseAll(scope); err != nil { - t.Fatalf("ReleaseAll failed: %v", err) - } - if _, err := os.Stat(path); !os.IsNotExist(err) { - t.Errorf("file should have been deleted by ReleaseAll, stat err: %v", err) - } -} - -// TestStoreWSMedia_MultipleMessages verifies that concurrent media messages with -// different msgIDs do not collide and each resolve to distinct files. 
-func TestStoreWSMedia_MultipleMessages(t *testing.T) { - imageA := bytes.Repeat([]byte("a"), 64) - imageB := bytes.Repeat([]byte("b"), 64) - - srvA := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { - w.WriteHeader(http.StatusOK) - _, _ = w.Write(imageA) - })) - defer srvA.Close() - srvB := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { - w.WriteHeader(http.StatusOK) - _, _ = w.Write(imageB) - })) - defer srvB.Close() - - ch := newTestWSChannel(t) - store := media.NewFileMediaStore() - ch.SetMediaStore(store) - - refA, err := ch.storeWSMedia(context.Background(), "chat1", "msgA", srvA.URL, "", ".jpg") - if err != nil { - t.Fatalf("storeWSMedia A: %v", err) - } - refB, err := ch.storeWSMedia(context.Background(), "chat1", "msgB", srvB.URL, "", ".jpg") - if err != nil { - t.Fatalf("storeWSMedia B: %v", err) - } - if refA == refB { - t.Fatal("distinct messages must produce distinct refs") - } - - pathA, _ := store.Resolve(refA) - pathB, _ := store.Resolve(refB) - if pathA == pathB { - t.Fatal("distinct messages must be stored at distinct paths") - } - - gotA, _ := os.ReadFile(pathA) - gotB, _ := os.ReadFile(pathB) - if !bytes.Equal(gotA, imageA) { - t.Errorf("content mismatch for message A") - } - if !bytes.Equal(gotB, imageB) { - t.Errorf("content mismatch for message B") - } -} - -// TestStoreWSMedia_ContentTypeExt verifies that the file extension is inferred -// from the HTTP Content-Type header and the defaultExt fallback is used when the -// type is absent or unrecognized. -func TestStoreWSMedia_ContentTypeExt(t *testing.T) { - tests := []struct { - contentType string - wantExt string - }{ - {"image/jpeg", ".jpg"}, - {"image/png", ".png"}, - {"video/mp4", ".mp4"}, - {"application/pdf", ".pdf"}, - {"application/zip", ".zip"}, - // With parameters stripped. - {"video/mp4; codecs=avc1", ".mp4"}, - // Unknown type → falls back to defaultExt. 
- {"", ""}, - {"application/octet-stream", ""}, - } - for _, tc := range tests { - got := wsMediaExtFromContentType(tc.contentType) - if got != tc.wantExt { - t.Errorf("wsMediaExtFromContentType(%q) = %q, want %q", tc.contentType, got, tc.wantExt) - } - } - - // End-to-end: server returns Content-Type: video/mp4, defaultExt is .bin. - // The stored file should carry the .mp4 extension, not .bin. - payload := bytes.Repeat([]byte("v"), 128) - srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { - w.Header().Set("Content-Type", "video/mp4") - w.WriteHeader(http.StatusOK) - _, _ = w.Write(payload) - })) - defer srv.Close() - - ch := newTestWSChannel(t) - store := media.NewFileMediaStore() - ch.SetMediaStore(store) - - ref, err := ch.storeWSMedia(context.Background(), "chat1", "vid1", srv.URL, "", ".bin") - if err != nil { - t.Fatalf("storeWSMedia: %v", err) - } - path, err := store.Resolve(ref) - if err != nil { - t.Fatalf("resolve: %v", err) - } - if ext := path[len(path)-4:]; ext != ".mp4" { - t.Errorf("expected .mp4 extension from Content-Type, got %q", ext) - } -} - -// TestSplitWSContent verifies byte-aware splitting of stream content. -func TestSplitWSContent(t *testing.T) { - t.Run("short content is not split", func(t *testing.T) { - chunks := splitWSContent("hello", 20480) - if len(chunks) != 1 || chunks[0] != "hello" { - t.Fatalf("unexpected chunks: %v", chunks) - } - }) - - t.Run("ASCII content split at byte boundary", func(t *testing.T) { - // Build a string just over the limit. - content := strings.Repeat("a", 20481) - chunks := splitWSContent(content, 20480) - if len(chunks) < 2 { - t.Fatalf("expected >= 2 chunks, got %d", len(chunks)) - } - for i, c := range chunks { - if len(c) > 20480 { - t.Errorf("chunk %d has %d bytes, want <= 20480", i, len(c)) - } - } - // Reassembled content must equal the original (possibly without leading - // whitespace that splitWSContent trims between chunks). 
- joined := strings.Join(chunks, "") - if len(joined) < len(content)-len(chunks) { - t.Errorf("joined length %d too short (original %d)", len(joined), len(content)) - } - }) - - t.Run("CJK content split within byte limit", func(t *testing.T) { - // Each CJK rune is 3 bytes in UTF-8. - // 7000 CJK chars = 21000 bytes, which exceeds 20480. - content := strings.Repeat("\u4e2d", 7000) - chunks := splitWSContent(content, 20480) - if len(chunks) < 2 { - t.Fatalf("expected >= 2 chunks for 21000-byte CJK content, got %d", len(chunks)) - } - for i, c := range chunks { - if len(c) > 20480 { - t.Errorf("chunk %d has %d bytes, want <= 20480", i, len(c)) - } - // Every chunk must be valid UTF-8. - if !strings.ContainsRune(c, '\u4e2d') && len(c) > 0 { - // quick plausibility check — content was pure CJK - } - } - }) -} - -// TestSplitAtByteBoundary verifies the last-resort byte-boundary splitter. -func TestSplitAtByteBoundary(t *testing.T) { - t.Run("ASCII fits in one chunk", func(t *testing.T) { - parts := splitAtByteBoundary("hello world", 100) - if len(parts) != 1 { - t.Fatalf("expected 1 part, got %d", len(parts)) - } - }) - - t.Run("splits at byte boundary, never mid-rune", func(t *testing.T) { - // 10 CJK characters = 30 bytes; split at 20 bytes. - s := strings.Repeat("\u6587", 10) // 10 × 3 bytes = 30 bytes - parts := splitAtByteBoundary(s, 20) - for i, p := range parts { - if len(p) > 20 { - t.Errorf("part %d has %d bytes, want <= 20", i, len(p)) - } - // Must be valid UTF-8 (no torn multi-byte sequences). 
- for j, r := range p { - if r == '\uFFFD' { - t.Errorf("part %d has replacement rune at position %d: torn UTF-8", i, j) - } - } - } - }) -} diff --git a/pkg/channels/wecom/app.go b/pkg/channels/wecom/app.go deleted file mode 100644 index fccfc60a3..000000000 --- a/pkg/channels/wecom/app.go +++ /dev/null @@ -1,756 +0,0 @@ -package wecom - -import ( - "bytes" - "context" - "encoding/json" - "encoding/xml" - "fmt" - "io" - "mime/multipart" - "net/http" - "net/url" - "os" - "path/filepath" - "strings" - "sync" - "time" - - "github.com/sipeed/picoclaw/pkg/bus" - "github.com/sipeed/picoclaw/pkg/channels" - "github.com/sipeed/picoclaw/pkg/config" - "github.com/sipeed/picoclaw/pkg/identity" - "github.com/sipeed/picoclaw/pkg/logger" - "github.com/sipeed/picoclaw/pkg/utils" -) - -const ( - wecomAPIBase = "https://qyapi.weixin.qq.com" -) - -// WeComAppChannel implements the Channel interface for WeCom App (企业微信自建应用) -type WeComAppChannel struct { - *channels.BaseChannel - config config.WeComAppConfig - client *http.Client - accessToken string - tokenExpiry time.Time - tokenMu sync.RWMutex - ctx context.Context - cancel context.CancelFunc - processedMsgs *MessageDeduplicator -} - -// WeComXMLMessage represents the XML message structure from WeCom -type WeComXMLMessage struct { - XMLName xml.Name `xml:"xml"` - ToUserName string `xml:"ToUserName"` - FromUserName string `xml:"FromUserName"` - CreateTime int64 `xml:"CreateTime"` - MsgType string `xml:"MsgType"` - Content string `xml:"Content"` - MsgId int64 `xml:"MsgId"` - AgentID int64 `xml:"AgentID"` - PicUrl string `xml:"PicUrl"` - MediaId string `xml:"MediaId"` - Format string `xml:"Format"` - ThumbMediaId string `xml:"ThumbMediaId"` - LocationX float64 `xml:"Location_X"` - LocationY float64 `xml:"Location_Y"` - Scale int `xml:"Scale"` - Label string `xml:"Label"` - Title string `xml:"Title"` - Description string `xml:"Description"` - Url string `xml:"Url"` - Event string `xml:"Event"` - EventKey string `xml:"EventKey"` -} - 
-// WeComTextMessage represents text message for sending -type WeComTextMessage struct { - ToUser string `json:"touser"` - MsgType string `json:"msgtype"` - AgentID int64 `json:"agentid"` - Text struct { - Content string `json:"content"` - } `json:"text"` - Safe int `json:"safe,omitempty"` -} - -// WeComMarkdownMessage represents markdown message for sending -type WeComMarkdownMessage struct { - ToUser string `json:"touser"` - MsgType string `json:"msgtype"` - AgentID int64 `json:"agentid"` - Markdown struct { - Content string `json:"content"` - } `json:"markdown"` -} - -// WeComImageMessage represents image message for sending -type WeComImageMessage struct { - ToUser string `json:"touser"` - MsgType string `json:"msgtype"` - AgentID int64 `json:"agentid"` - Image struct { - MediaID string `json:"media_id"` - } `json:"image"` -} - -// WeComAccessTokenResponse represents the access token API response -type WeComAccessTokenResponse struct { - ErrCode int `json:"errcode"` - ErrMsg string `json:"errmsg"` - AccessToken string `json:"access_token"` - ExpiresIn int `json:"expires_in"` -} - -// WeComSendMessageResponse represents the send message API response -type WeComSendMessageResponse struct { - ErrCode int `json:"errcode"` - ErrMsg string `json:"errmsg"` - InvalidUser string `json:"invaliduser"` - InvalidParty string `json:"invalidparty"` - InvalidTag string `json:"invalidtag"` -} - -// PKCS7Padding adds PKCS7 padding -type PKCS7Padding struct{} - -// NewWeComAppChannel creates a new WeCom App channel instance -func NewWeComAppChannel(cfg config.WeComAppConfig, messageBus *bus.MessageBus) (*WeComAppChannel, error) { - if cfg.CorpID == "" || cfg.CorpSecret() == "" || cfg.AgentID == 0 { - return nil, fmt.Errorf("wecom_app corp_id, corp_secret and agent_id are required") - } - - base := channels.NewBaseChannel("wecom_app", cfg, messageBus, cfg.AllowFrom, - channels.WithMaxMessageLength(2048), - channels.WithGroupTrigger(cfg.GroupTrigger), - 
channels.WithReasoningChannelID(cfg.ReasoningChannelID), - ) - - // Client timeout must be >= the configured ReplyTimeout so the - // per-request context deadline is always the effective limit. - clientTimeout := 30 * time.Second - if d := time.Duration(cfg.ReplyTimeout) * time.Second; d > clientTimeout { - clientTimeout = d - } - - ctx, cancel := context.WithCancel(context.Background()) - return &WeComAppChannel{ - BaseChannel: base, - config: cfg, - client: &http.Client{Timeout: clientTimeout}, - ctx: ctx, - cancel: cancel, - processedMsgs: NewMessageDeduplicator(wecomMaxProcessedMessages), - }, nil -} - -// Name returns the channel name -func (c *WeComAppChannel) Name() string { - return "wecom_app" -} - -// Start initializes the WeCom App channel -func (c *WeComAppChannel) Start(ctx context.Context) error { - logger.InfoC("wecom_app", "Starting WeCom App channel...") - - // Cancel the context created in the constructor to avoid a resource leak. - if c.cancel != nil { - c.cancel() - } - c.ctx, c.cancel = context.WithCancel(ctx) - - // Get initial access token - if err := c.refreshAccessToken(); err != nil { - logger.WarnCF("wecom_app", "Failed to get initial access token", map[string]any{ - "error": err.Error(), - }) - } - - // Start token refresh goroutine - go c.tokenRefreshLoop() - - c.SetRunning(true) - logger.InfoC("wecom_app", "WeCom App channel started") - - return nil -} - -// Stop gracefully stops the WeCom App channel -func (c *WeComAppChannel) Stop(ctx context.Context) error { - logger.InfoC("wecom_app", "Stopping WeCom App channel...") - - if c.cancel != nil { - c.cancel() - } - - c.SetRunning(false) - logger.InfoC("wecom_app", "WeCom App channel stopped") - return nil -} - -// Send sends a message to WeCom user proactively using access token -func (c *WeComAppChannel) Send(ctx context.Context, msg bus.OutboundMessage) error { - if !c.IsRunning() { - return channels.ErrNotRunning - } - - accessToken := c.getAccessToken() - if accessToken == "" { - 
return fmt.Errorf("no valid access token available") - } - - logger.DebugCF("wecom_app", "Sending message", map[string]any{ - "chat_id": msg.ChatID, - "preview": utils.Truncate(msg.Content, 100), - }) - - return c.sendTextMessage(ctx, accessToken, msg.ChatID, msg.Content) -} - -// SendMedia implements the channels.MediaSender interface. -func (c *WeComAppChannel) SendMedia(ctx context.Context, msg bus.OutboundMediaMessage) error { - if !c.IsRunning() { - return channels.ErrNotRunning - } - - accessToken := c.getAccessToken() - if accessToken == "" { - return fmt.Errorf("no valid access token available: %w", channels.ErrTemporary) - } - - store := c.GetMediaStore() - if store == nil { - return fmt.Errorf("no media store available: %w", channels.ErrSendFailed) - } - - for _, part := range msg.Parts { - localPath, err := store.Resolve(part.Ref) - if err != nil { - logger.ErrorCF("wecom_app", "Failed to resolve media ref", map[string]any{ - "ref": part.Ref, - "error": err.Error(), - }) - continue - } - - // Map part type to WeCom media type - var mediaType string - switch part.Type { - case "image": - mediaType = "image" - case "audio": - mediaType = "voice" - case "video": - mediaType = "video" - default: - mediaType = "file" - } - - // Upload media to get media_id - mediaID, err := c.uploadMedia(ctx, accessToken, mediaType, localPath) - if err != nil { - logger.ErrorCF("wecom_app", "Failed to upload media", map[string]any{ - "type": mediaType, - "error": err.Error(), - }) - // Fallback: send caption as text - if part.Caption != "" { - _ = c.sendTextMessage(ctx, accessToken, msg.ChatID, part.Caption) - } - continue - } - - // Send media message using the media_id - if mediaType == "image" { - err = c.sendImageMessage(ctx, accessToken, msg.ChatID, mediaID) - } else { - // For non-image types, send as text fallback with caption - caption := part.Caption - if caption == "" { - caption = fmt.Sprintf("[%s: %s]", part.Type, part.Filename) - } - err = c.sendTextMessage(ctx, 
accessToken, msg.ChatID, caption) - } - - if err != nil { - return err - } - } - - return nil -} - -// uploadMedia uploads a local file to WeCom temporary media storage. -func (c *WeComAppChannel) uploadMedia(ctx context.Context, accessToken, mediaType, localPath string) (string, error) { - apiURL := fmt.Sprintf("%s/cgi-bin/media/upload?access_token=%s&type=%s", - wecomAPIBase, url.QueryEscape(accessToken), url.QueryEscape(mediaType)) - - file, err := os.Open(localPath) - if err != nil { - return "", fmt.Errorf("failed to open file: %w", err) - } - defer file.Close() - - body := &bytes.Buffer{} - writer := multipart.NewWriter(body) - - filename := filepath.Base(localPath) - formFile, err := writer.CreateFormFile("media", filename) - if err != nil { - return "", fmt.Errorf("failed to create form file: %w", err) - } - - if _, err = io.Copy(formFile, file); err != nil { - return "", fmt.Errorf("failed to copy file content: %w", err) - } - writer.Close() - - req, err := http.NewRequestWithContext(ctx, http.MethodPost, apiURL, body) - if err != nil { - return "", fmt.Errorf("failed to create request: %w", err) - } - req.Header.Set("Content-Type", writer.FormDataContentType()) - - resp, err := c.client.Do(req) - if err != nil { - return "", channels.ClassifyNetError(err) - } - defer resp.Body.Close() - - if resp.StatusCode != http.StatusOK { - respBody, readErr := io.ReadAll(resp.Body) - if readErr != nil { - return "", channels.ClassifySendError( - resp.StatusCode, - fmt.Errorf("reading wecom upload error response: %w", readErr), - ) - } - return "", channels.ClassifySendError( - resp.StatusCode, - fmt.Errorf("wecom upload error: %s", string(respBody)), - ) - } - - var result struct { - ErrCode int `json:"errcode"` - ErrMsg string `json:"errmsg"` - MediaID string `json:"media_id"` - } - if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { - return "", fmt.Errorf("failed to parse upload response: %w", err) - } - - if result.ErrCode != 0 { - return "", 
fmt.Errorf("upload API error: %s (code: %d)", result.ErrMsg, result.ErrCode) - } - - return result.MediaID, nil -} - -// sendWeComMessage marshals payload and POSTs it to the WeCom message API. -func (c *WeComAppChannel) sendWeComMessage(ctx context.Context, accessToken string, payload any) error { - apiURL := fmt.Sprintf("%s/cgi-bin/message/send?access_token=%s", wecomAPIBase, accessToken) - - jsonData, err := json.Marshal(payload) - if err != nil { - return fmt.Errorf("failed to marshal message: %w", err) - } - - timeout := c.config.ReplyTimeout - if timeout <= 0 { - timeout = 5 - } - - reqCtx, cancel := context.WithTimeout(ctx, time.Duration(timeout)*time.Second) - defer cancel() - - req, err := http.NewRequestWithContext(reqCtx, http.MethodPost, apiURL, bytes.NewBuffer(jsonData)) - if err != nil { - return fmt.Errorf("failed to create request: %w", err) - } - req.Header.Set("Content-Type", "application/json") - - resp, err := c.client.Do(req) - if err != nil { - return channels.ClassifyNetError(err) - } - defer resp.Body.Close() - - if resp.StatusCode != http.StatusOK { - respBody, readErr := io.ReadAll(resp.Body) - if readErr != nil { - return channels.ClassifySendError( - resp.StatusCode, - fmt.Errorf("reading wecom_app error response: %w", readErr), - ) - } - return channels.ClassifySendError( - resp.StatusCode, - fmt.Errorf("wecom_app API error: %s", string(respBody)), - ) - } - - respBody, err := io.ReadAll(resp.Body) - if err != nil { - return fmt.Errorf("failed to read response: %w", err) - } - - var sendResp WeComSendMessageResponse - if err := json.Unmarshal(respBody, &sendResp); err != nil { - return fmt.Errorf("failed to parse response: %w", err) - } - - if sendResp.ErrCode != 0 { - return fmt.Errorf("API error: %s (code: %d)", sendResp.ErrMsg, sendResp.ErrCode) - } - - return nil -} - -// sendImageMessage sends an image message using a media_id. 
-func (c *WeComAppChannel) sendImageMessage(ctx context.Context, accessToken, userID, mediaID string) error { - msg := WeComImageMessage{ - ToUser: userID, - MsgType: "image", - AgentID: c.config.AgentID, - } - msg.Image.MediaID = mediaID - return c.sendWeComMessage(ctx, accessToken, msg) -} - -// WebhookPath returns the path for registering on the shared HTTP server. -func (c *WeComAppChannel) WebhookPath() string { - if c.config.WebhookPath != "" { - return c.config.WebhookPath - } - return "/webhook/wecom-app" -} - -// ServeHTTP implements http.Handler for the shared HTTP server. -func (c *WeComAppChannel) ServeHTTP(w http.ResponseWriter, r *http.Request) { - c.handleWebhook(w, r) -} - -// HealthPath returns the health check endpoint path. -func (c *WeComAppChannel) HealthPath() string { - return "/health/wecom-app" -} - -// HealthHandler handles health check requests. -func (c *WeComAppChannel) HealthHandler(w http.ResponseWriter, r *http.Request) { - c.handleHealth(w, r) -} - -// handleWebhook handles incoming webhook requests from WeCom -func (c *WeComAppChannel) handleWebhook(w http.ResponseWriter, r *http.Request) { - ctx := r.Context() - - // Log all incoming requests for debugging - logger.DebugCF("wecom_app", "Received webhook request", map[string]any{ - "method": r.Method, - "url": r.URL.String(), - "path": r.URL.Path, - "query": r.URL.RawQuery, - }) - - if r.Method == http.MethodGet { - // Handle verification request - c.handleVerification(ctx, w, r) - return - } - - if r.Method == http.MethodPost { - // Handle message callback - c.handleMessageCallback(ctx, w, r) - return - } - - logger.WarnCF("wecom_app", "Method not allowed", map[string]any{ - "method": r.Method, - }) - http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) -} - -// handleVerification handles the URL verification request from WeCom -func (c *WeComAppChannel) handleVerification(ctx context.Context, w http.ResponseWriter, r *http.Request) { - query := r.URL.Query() - 
msgSignature := query.Get("msg_signature") - timestamp := query.Get("timestamp") - nonce := query.Get("nonce") - echostr := query.Get("echostr") - - logger.DebugCF("wecom_app", "Handling verification request", map[string]any{ - "msg_signature": msgSignature, - "timestamp": timestamp, - "nonce": nonce, - "echostr": echostr, - "corp_id": c.config.CorpID, - }) - - if msgSignature == "" || timestamp == "" || nonce == "" || echostr == "" { - logger.ErrorC("wecom_app", "Missing parameters in verification request") - http.Error(w, "Missing parameters", http.StatusBadRequest) - return - } - - // Verify signature - if !verifySignature(c.config.Token(), msgSignature, timestamp, nonce, echostr) { - logger.WarnCF("wecom_app", "Signature verification failed", map[string]any{ - "token": c.config.Token(), - "msg_signature": msgSignature, - "timestamp": timestamp, - "nonce": nonce, - }) - http.Error(w, "Invalid signature", http.StatusForbidden) - return - } - - logger.DebugC("wecom_app", "Signature verification passed") - - // Decrypt echostr with CorpID verification - // For WeCom App (自建应用), receiveid should be corp_id - logger.DebugCF("wecom_app", "Attempting to decrypt echostr", map[string]any{ - "encoding_aes_key": c.config.EncodingAESKey(), - "corp_id": c.config.CorpID, - }) - decryptedEchoStr, err := decryptMessageWithVerify(echostr, c.config.EncodingAESKey(), c.config.CorpID) - if err != nil { - logger.ErrorCF("wecom_app", "Failed to decrypt echostr", map[string]any{ - "error": err.Error(), - "encoding_aes_key": c.config.EncodingAESKey, - "corp_id": c.config.CorpID, - }) - http.Error(w, "Decryption failed", http.StatusInternalServerError) - return - } - - logger.DebugCF("wecom_app", "Successfully decrypted echostr", map[string]any{ - "decrypted": decryptedEchoStr, - }) - - // Remove BOM and whitespace as per WeCom documentation - // The response must be plain text without quotes, BOM, or newlines - decryptedEchoStr = strings.TrimSpace(decryptedEchoStr) - decryptedEchoStr = 
strings.TrimPrefix(decryptedEchoStr, "\xef\xbb\xbf") // Remove UTF-8 BOM - w.Write([]byte(decryptedEchoStr)) -} - -// handleMessageCallback handles incoming messages from WeCom -func (c *WeComAppChannel) handleMessageCallback(ctx context.Context, w http.ResponseWriter, r *http.Request) { - query := r.URL.Query() - msgSignature := query.Get("msg_signature") - timestamp := query.Get("timestamp") - nonce := query.Get("nonce") - - if msgSignature == "" || timestamp == "" || nonce == "" { - http.Error(w, "Missing parameters", http.StatusBadRequest) - return - } - - // Read request body - body, err := io.ReadAll(r.Body) - if err != nil { - http.Error(w, "Failed to read body", http.StatusBadRequest) - return - } - defer r.Body.Close() - - // Parse XML to get encrypted message - var encryptedMsg struct { - XMLName xml.Name `xml:"xml"` - ToUserName string `xml:"ToUserName"` - Encrypt string `xml:"Encrypt"` - AgentID string `xml:"AgentID"` - } - - if err = xml.Unmarshal(body, &encryptedMsg); err != nil { - logger.ErrorCF("wecom_app", "Failed to parse XML", map[string]any{ - "error": err.Error(), - }) - http.Error(w, "Invalid XML", http.StatusBadRequest) - return - } - - // Verify signature - if !verifySignature(c.config.Token(), msgSignature, timestamp, nonce, encryptedMsg.Encrypt) { - logger.WarnC("wecom_app", "Message signature verification failed") - http.Error(w, "Invalid signature", http.StatusForbidden) - return - } - - // Decrypt message with CorpID verification - // For WeCom App (自建应用), receiveid should be corp_id - decryptedMsg, err := decryptMessageWithVerify(encryptedMsg.Encrypt, c.config.EncodingAESKey(), c.config.CorpID) - if err != nil { - logger.ErrorCF("wecom_app", "Failed to decrypt message", map[string]any{ - "error": err.Error(), - }) - http.Error(w, "Decryption failed", http.StatusInternalServerError) - return - } - - // Parse decrypted XML message - var msg WeComXMLMessage - if err := xml.Unmarshal([]byte(decryptedMsg), &msg); err != nil { - 
logger.ErrorCF("wecom_app", "Failed to parse decrypted message", map[string]any{ - "error": err.Error(), - }) - http.Error(w, "Invalid message format", http.StatusBadRequest) - return - } - - // Process the message with the channel's long-lived context (not the HTTP - // request context, which is canceled as soon as we return the response). - go c.processMessage(c.ctx, msg) - - // Return success response immediately - // WeCom App requires response within configured timeout (default 5 seconds) - w.Write([]byte("success")) -} - -// processMessage processes the received message -func (c *WeComAppChannel) processMessage(ctx context.Context, msg WeComXMLMessage) { - // Skip non-text messages for now (can be extended) - if msg.MsgType != "text" && msg.MsgType != "image" && msg.MsgType != "voice" { - logger.DebugCF("wecom_app", "Skipping non-supported message type", map[string]any{ - "msg_type": msg.MsgType, - }) - return - } - - // Message deduplication: Use msg_id to prevent duplicate processing - // As per WeCom documentation, use msg_id for deduplication - msgID := fmt.Sprintf("%d", msg.MsgId) - if !c.processedMsgs.MarkMessageProcessed(msgID) { - logger.DebugCF("wecom_app", "Skipping duplicate message", map[string]any{ - "msg_id": msgID, - }) - return - } - - senderID := msg.FromUserName - chatID := senderID // WeCom App uses user ID as chat ID for direct messages - - // Build metadata - // WeCom App only supports direct messages (private chat) - peer := bus.Peer{Kind: "direct", ID: senderID} - messageID := fmt.Sprintf("%d", msg.MsgId) - - metadata := map[string]string{ - "msg_type": msg.MsgType, - "msg_id": fmt.Sprintf("%d", msg.MsgId), - "agent_id": fmt.Sprintf("%d", msg.AgentID), - "platform": "wecom_app", - "media_id": msg.MediaId, - "create_time": fmt.Sprintf("%d", msg.CreateTime), - } - - content := msg.Content - - logger.DebugCF("wecom_app", "Received message", map[string]any{ - "sender_id": senderID, - "msg_type": msg.MsgType, - "preview": 
utils.Truncate(content, 50), - }) - - // Build sender info - appSender := bus.SenderInfo{ - Platform: "wecom", - PlatformID: senderID, - CanonicalID: identity.BuildCanonicalID("wecom", senderID), - } - - // Handle the message through the base channel - c.HandleMessage(ctx, peer, messageID, senderID, chatID, content, nil, metadata, appSender) -} - -// tokenRefreshLoop periodically refreshes the access token -func (c *WeComAppChannel) tokenRefreshLoop() { - ticker := time.NewTicker(5 * time.Minute) - defer ticker.Stop() - - for { - select { - case <-c.ctx.Done(): - return - case <-ticker.C: - if err := c.refreshAccessToken(); err != nil { - logger.ErrorCF("wecom_app", "Failed to refresh access token", map[string]any{ - "error": err.Error(), - }) - } - } - } -} - -// refreshAccessToken gets a new access token from WeCom API -func (c *WeComAppChannel) refreshAccessToken() error { - apiURL := fmt.Sprintf("%s/cgi-bin/gettoken?corpid=%s&corpsecret=%s", - wecomAPIBase, url.QueryEscape(c.config.CorpID), url.QueryEscape(c.config.CorpSecret())) - - resp, err := http.Get(apiURL) - if err != nil { - return fmt.Errorf("failed to request access token: %w", err) - } - defer resp.Body.Close() - - body, err := io.ReadAll(resp.Body) - if err != nil { - return fmt.Errorf("failed to read response: %w", err) - } - - var tokenResp WeComAccessTokenResponse - if err := json.Unmarshal(body, &tokenResp); err != nil { - return fmt.Errorf("failed to parse response: %w", err) - } - - if tokenResp.ErrCode != 0 { - return fmt.Errorf("API error: %s (code: %d)", tokenResp.ErrMsg, tokenResp.ErrCode) - } - - c.tokenMu.Lock() - c.accessToken = tokenResp.AccessToken - c.tokenExpiry = time.Now().Add(time.Duration(tokenResp.ExpiresIn-300) * time.Second) // Refresh 5 minutes early - c.tokenMu.Unlock() - - logger.DebugC("wecom_app", "Access token refreshed successfully") - return nil -} - -// getAccessToken returns the current valid access token -func (c *WeComAppChannel) getAccessToken() string { - 
c.tokenMu.RLock() - defer c.tokenMu.RUnlock() - - if time.Now().After(c.tokenExpiry) { - return "" - } - - return c.accessToken -} - -// sendTextMessage sends a text message to a user. -func (c *WeComAppChannel) sendTextMessage(ctx context.Context, accessToken, userID, content string) error { - msg := WeComTextMessage{ - ToUser: userID, - MsgType: "text", - AgentID: c.config.AgentID, - } - msg.Text.Content = content - return c.sendWeComMessage(ctx, accessToken, msg) -} - -// handleHealth handles health check requests -func (c *WeComAppChannel) handleHealth(w http.ResponseWriter, r *http.Request) { - status := map[string]any{ - "status": "ok", - "running": c.IsRunning(), - "has_token": c.getAccessToken() != "", - } - - w.Header().Set("Content-Type", "application/json") - json.NewEncoder(w).Encode(status) -} diff --git a/pkg/channels/wecom/app_test.go b/pkg/channels/wecom/app_test.go deleted file mode 100644 index 502544441..000000000 --- a/pkg/channels/wecom/app_test.go +++ /dev/null @@ -1,1060 +0,0 @@ -package wecom - -import ( - "bytes" - "context" - "crypto/aes" - "crypto/cipher" - "crypto/sha1" - "encoding/base64" - "encoding/binary" - "encoding/json" - "encoding/xml" - "fmt" - "net/http" - "net/http/httptest" - "sort" - "strings" - "testing" - "time" - - "github.com/sipeed/picoclaw/pkg/bus" - "github.com/sipeed/picoclaw/pkg/config" -) - -// generateTestAESKeyApp generates a valid test AES key for WeCom App -func generateTestAESKeyApp() string { - // AES key needs to be 32 bytes (256 bits) for AES-256 - key := make([]byte, 32) - for i := range key { - key[i] = byte(i + 1) - } - // Return base64 encoded key without padding - return base64.StdEncoding.EncodeToString(key)[:43] -} - -// encryptTestMessageApp encrypts a message for testing WeCom App -func encryptTestMessageApp(message, aesKey string) (string, error) { - // Decode AES key - key, err := base64.StdEncoding.DecodeString(aesKey + "=") - if err != nil { - return "", err - } - - // Prepare message: 
random(16) + msg_len(4) + msg + corp_id - random := make([]byte, 0, 16) - for i := range 16 { - random = append(random, byte(i+1)) - } - - msgBytes := []byte(message) - corpID := []byte("test_corp_id") - - msgLen := uint32(len(msgBytes)) - lenBytes := make([]byte, 4) - binary.BigEndian.PutUint32(lenBytes, msgLen) - - plainText := append(random, lenBytes...) - plainText = append(plainText, msgBytes...) - plainText = append(plainText, corpID...) - - // PKCS7 padding - blockSize := aes.BlockSize - padding := blockSize - len(plainText)%blockSize - padText := bytes.Repeat([]byte{byte(padding)}, padding) - plainText = append(plainText, padText...) - - // Encrypt - block, err := aes.NewCipher(key) - if err != nil { - return "", err - } - - mode := cipher.NewCBCEncrypter(block, key[:aes.BlockSize]) - cipherText := make([]byte, len(plainText)) - mode.CryptBlocks(cipherText, plainText) - - return base64.StdEncoding.EncodeToString(cipherText), nil -} - -// generateSignatureApp generates a signature for testing WeCom App -func generateSignatureApp(token, timestamp, nonce, msgEncrypt string) string { - params := []string{token, timestamp, nonce, msgEncrypt} - sort.Strings(params) - str := strings.Join(params, "") - hash := sha1.Sum([]byte(str)) - return fmt.Sprintf("%x", hash) -} - -func TestNewWeComAppChannel(t *testing.T) { - msgBus := bus.NewMessageBus() - - t.Run("missing corp_id", func(t *testing.T) { - cfg := config.WeComAppConfig{ - CorpID: "", - AgentID: 1000002, - } - cfg.SetCorpSecret("test_secret") - _, err := NewWeComAppChannel(cfg, msgBus) - if err == nil { - t.Error("expected error for missing corp_id, got nil") - } - }) - - t.Run("missing corp_secret", func(t *testing.T) { - cfg := config.WeComAppConfig{ - CorpID: "test_corp_id", - AgentID: 1000002, - } - _, err := NewWeComAppChannel(cfg, msgBus) - if err == nil { - t.Error("expected error for missing corp_secret, got nil") - } - }) - - t.Run("missing agent_id", func(t *testing.T) { - cfg := 
config.WeComAppConfig{ - CorpID: "test_corp_id", - AgentID: 0, - } - cfg.SetCorpSecret("test_secret") - _, err := NewWeComAppChannel(cfg, msgBus) - if err == nil { - t.Error("expected error for missing agent_id, got nil") - } - }) - - t.Run("valid config", func(t *testing.T) { - cfg := config.WeComAppConfig{ - CorpID: "test_corp_id", - AgentID: 1000002, - AllowFrom: []string{"user1", "user2"}, - } - cfg.SetCorpSecret("test_secret") - ch, err := NewWeComAppChannel(cfg, msgBus) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - if ch.Name() != "wecom_app" { - t.Errorf("Name() = %q, want %q", ch.Name(), "wecom_app") - } - if ch.IsRunning() { - t.Error("new channel should not be running") - } - }) -} - -func TestWeComAppChannelIsAllowed(t *testing.T) { - msgBus := bus.NewMessageBus() - - t.Run("empty allowlist allows all", func(t *testing.T) { - cfg := config.WeComAppConfig{ - CorpID: "test_corp_id", - AgentID: 1000002, - AllowFrom: []string{}, - } - cfg.SetCorpSecret("test_secret") - ch, _ := NewWeComAppChannel(cfg, msgBus) - if !ch.IsAllowed("any_user") { - t.Error("empty allowlist should allow all users") - } - }) - - t.Run("allowlist restricts users", func(t *testing.T) { - cfg := config.WeComAppConfig{ - CorpID: "test_corp_id", - AgentID: 1000002, - AllowFrom: []string{"allowed_user"}, - } - cfg.SetCorpSecret("test_secret") - ch, _ := NewWeComAppChannel(cfg, msgBus) - if !ch.IsAllowed("allowed_user") { - t.Error("allowed user should pass allowlist check") - } - if ch.IsAllowed("blocked_user") { - t.Error("non-allowed user should be blocked") - } - }) -} - -func TestWeComAppVerifySignature(t *testing.T) { - msgBus := bus.NewMessageBus() - cfg := config.WeComAppConfig{} - cfg.CorpID = "test_corp_id" - cfg.SetCorpSecret("test_secret") - cfg.AgentID = 1000002 - cfg.SetToken("test_token") - ch, _ := NewWeComAppChannel(cfg, msgBus) - - t.Run("valid signature", func(t *testing.T) { - timestamp := "1234567890" - nonce := "test_nonce" - msgEncrypt := 
"test_message" - expectedSig := generateSignatureApp("test_token", timestamp, nonce, msgEncrypt) - - if !verifySignature(ch.config.Token(), expectedSig, timestamp, nonce, msgEncrypt) { - t.Error("valid signature should pass verification") - } - }) - - t.Run("invalid signature", func(t *testing.T) { - timestamp := "1234567890" - nonce := "test_nonce" - msgEncrypt := "test_message" - - if verifySignature(ch.config.Token(), "invalid_sig", timestamp, nonce, msgEncrypt) { - t.Error("invalid signature should fail verification") - } - }) - - t.Run("empty token rejects verification (fail-closed)", func(t *testing.T) { - cfgEmpty := config.WeComAppConfig{} - cfgEmpty.CorpID = "test_corp_id" - cfgEmpty.SetCorpSecret("test_secret") - cfgEmpty.AgentID = 1000002 - cfgEmpty.SetToken("") - chEmpty, _ := NewWeComAppChannel(cfgEmpty, msgBus) - - if verifySignature(chEmpty.config.Token(), "any_sig", "any_ts", "any_nonce", "any_msg") { - t.Error("empty token should reject verification (fail-closed)") - } - }) -} - -func TestWeComAppDecryptMessage(t *testing.T) { - msgBus := bus.NewMessageBus() - - t.Run("decrypt without AES key", func(t *testing.T) { - cfg := config.WeComAppConfig{} - cfg.CorpID = "test_corp_id" - cfg.SetCorpSecret("test_secret") - cfg.AgentID = 1000002 - cfg.SetEncodingAESKey("") - ch, _ := NewWeComAppChannel(cfg, msgBus) - - // Without AES key, message should be base64 decoded only - plainText := "hello world" - encoded := base64.StdEncoding.EncodeToString([]byte(plainText)) - - result, err := decryptMessage(encoded, ch.config.EncodingAESKey()) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - if result != plainText { - t.Errorf("decryptMessage() = %q, want %q", result, plainText) - } - }) - - t.Run("decrypt with AES key", func(t *testing.T) { - aesKey := generateTestAESKeyApp() - cfg := config.WeComAppConfig{ - CorpID: "test_corp_id", - AgentID: 1000002, - } - cfg.SetCorpSecret("test_secret") - cfg.SetEncodingAESKey(aesKey) - ch, _ := 
NewWeComAppChannel(cfg, msgBus) - - originalMsg := "Hello" - encrypted, err := encryptTestMessageApp(originalMsg, aesKey) - if err != nil { - t.Fatalf("failed to encrypt test message: %v", err) - } - - result, err := decryptMessage(encrypted, ch.config.EncodingAESKey()) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - if result != originalMsg { - t.Errorf("WeComDecryptMessage() = %q, want %q", result, originalMsg) - } - }) - - t.Run("invalid base64", func(t *testing.T) { - cfg := config.WeComAppConfig{ - CorpID: "test_corp_id", - AgentID: 1000002, - } - cfg.SetCorpSecret("test_secret") - cfg.SetEncodingAESKey("") - ch, _ := NewWeComAppChannel(cfg, msgBus) - - _, err := decryptMessage("invalid_base64!!!", ch.config.EncodingAESKey()) - if err == nil { - t.Error("expected error for invalid base64, got nil") - } - }) - - t.Run("invalid AES key", func(t *testing.T) { - cfg := config.WeComAppConfig{} - cfg.CorpID = "test_corp_id" - cfg.SetCorpSecret("test_secret") - cfg.AgentID = 1000002 - cfg.SetEncodingAESKey("invalid_key") - ch, _ := NewWeComAppChannel(cfg, msgBus) - - _, err := decryptMessage(base64.StdEncoding.EncodeToString([]byte("test")), ch.config.EncodingAESKey()) - if err == nil { - t.Error("expected error for invalid AES key, got nil") - } - }) - - t.Run("ciphertext too short", func(t *testing.T) { - aesKey := generateTestAESKeyApp() - cfg := config.WeComAppConfig{} - cfg.CorpID = "test_corp_id" - cfg.SetCorpSecret("test_secret") - cfg.AgentID = 1000002 - cfg.SetEncodingAESKey(aesKey) - ch, _ := NewWeComAppChannel(cfg, msgBus) - - // Encrypt a very short message that results in ciphertext less than block size - shortData := make([]byte, 8) - _, err := decryptMessage(base64.StdEncoding.EncodeToString(shortData), ch.config.EncodingAESKey()) - if err == nil { - t.Error("expected error for short ciphertext, got nil") - } - }) -} - -func TestWeComAppHandleVerification(t *testing.T) { - msgBus := bus.NewMessageBus() - aesKey := 
generateTestAESKeyApp() - cfg := config.WeComAppConfig{} - cfg.CorpID = "test_corp_id" - cfg.SetCorpSecret("test_secret") - cfg.AgentID = 1000002 - cfg.SetToken("test_token") - cfg.SetEncodingAESKey(aesKey) - ch, _ := NewWeComAppChannel(cfg, msgBus) - - t.Run("valid verification request", func(t *testing.T) { - echostr := "test_echostr_123" - encryptedEchostr, _ := encryptTestMessageApp(echostr, aesKey) - timestamp := "1234567890" - nonce := "test_nonce" - signature := generateSignatureApp("test_token", timestamp, nonce, encryptedEchostr) - - req := httptest.NewRequest( - http.MethodGet, - "/webhook/wecom-app?msg_signature="+signature+"×tamp="+timestamp+"&nonce="+nonce+"&echostr="+encryptedEchostr, - nil, - ) - w := httptest.NewRecorder() - - ch.handleVerification(context.Background(), w, req) - - if w.Code != http.StatusOK { - t.Errorf("status code = %d, want %d", w.Code, http.StatusOK) - } - if w.Body.String() != echostr { - t.Errorf("response body = %q, want %q", w.Body.String(), echostr) - } - }) - - t.Run("missing parameters", func(t *testing.T) { - req := httptest.NewRequest(http.MethodGet, "/webhook/wecom-app?msg_signature=sig×tamp=ts", nil) - w := httptest.NewRecorder() - - ch.handleVerification(context.Background(), w, req) - - if w.Code != http.StatusBadRequest { - t.Errorf("status code = %d, want %d", w.Code, http.StatusBadRequest) - } - }) - - t.Run("invalid signature", func(t *testing.T) { - echostr := "test_echostr" - encryptedEchostr, _ := encryptTestMessageApp(echostr, aesKey) - timestamp := "1234567890" - nonce := "test_nonce" - - req := httptest.NewRequest( - http.MethodGet, - "/webhook/wecom-app?msg_signature=invalid_sig×tamp="+timestamp+"&nonce="+nonce+"&echostr="+encryptedEchostr, - nil, - ) - w := httptest.NewRecorder() - - ch.handleVerification(context.Background(), w, req) - - if w.Code != http.StatusForbidden { - t.Errorf("status code = %d, want %d", w.Code, http.StatusForbidden) - } - }) -} - -func TestWeComAppHandleMessageCallback(t 
*testing.T) { - msgBus := bus.NewMessageBus() - aesKey := generateTestAESKeyApp() - cfg := config.WeComAppConfig{} - cfg.CorpID = "test_corp_id" - cfg.SetCorpSecret("test_secret") - cfg.AgentID = 1000002 - cfg.SetToken("test_token") - cfg.SetEncodingAESKey(aesKey) - ch, _ := NewWeComAppChannel(cfg, msgBus) - - t.Run("valid message callback", func(t *testing.T) { - // Create XML message - xmlMsg := WeComXMLMessage{ - ToUserName: "corp_id", - FromUserName: "user123", - CreateTime: 1234567890, - MsgType: "text", - Content: "Hello World", - MsgId: 123456, - AgentID: 1000002, - } - xmlData, _ := xml.Marshal(xmlMsg) - - // Encrypt message - encrypted, _ := encryptTestMessageApp(string(xmlData), aesKey) - - // Create encrypted XML wrapper - encryptedWrapper := struct { - XMLName xml.Name `xml:"xml"` - Encrypt string `xml:"Encrypt"` - }{ - Encrypt: encrypted, - } - wrapperData, _ := xml.Marshal(encryptedWrapper) - - timestamp := "1234567890" - nonce := "test_nonce" - signature := generateSignatureApp("test_token", timestamp, nonce, encrypted) - - req := httptest.NewRequest( - http.MethodPost, - "/webhook/wecom-app?msg_signature="+signature+"×tamp="+timestamp+"&nonce="+nonce, - bytes.NewReader(wrapperData), - ) - w := httptest.NewRecorder() - - ch.handleMessageCallback(context.Background(), w, req) - - if w.Code != http.StatusOK { - t.Errorf("status code = %d, want %d", w.Code, http.StatusOK) - } - if w.Body.String() != "success" { - t.Errorf("response body = %q, want %q", w.Body.String(), "success") - } - }) - - t.Run("missing parameters", func(t *testing.T) { - req := httptest.NewRequest(http.MethodPost, "/webhook/wecom-app?msg_signature=sig", nil) - w := httptest.NewRecorder() - - ch.handleMessageCallback(context.Background(), w, req) - - if w.Code != http.StatusBadRequest { - t.Errorf("status code = %d, want %d", w.Code, http.StatusBadRequest) - } - }) - - t.Run("invalid XML", func(t *testing.T) { - timestamp := "1234567890" - nonce := "test_nonce" - signature := 
generateSignatureApp("test_token", timestamp, nonce, "") - - req := httptest.NewRequest( - http.MethodPost, - "/webhook/wecom-app?msg_signature="+signature+"×tamp="+timestamp+"&nonce="+nonce, - strings.NewReader("invalid xml"), - ) - w := httptest.NewRecorder() - - ch.handleMessageCallback(context.Background(), w, req) - - if w.Code != http.StatusBadRequest { - t.Errorf("status code = %d, want %d", w.Code, http.StatusBadRequest) - } - }) - - t.Run("invalid signature", func(t *testing.T) { - encryptedWrapper := struct { - XMLName xml.Name `xml:"xml"` - Encrypt string `xml:"Encrypt"` - }{ - Encrypt: "encrypted_data", - } - wrapperData, _ := xml.Marshal(encryptedWrapper) - - timestamp := "1234567890" - nonce := "test_nonce" - - req := httptest.NewRequest( - http.MethodPost, - "/webhook/wecom-app?msg_signature=invalid_sig×tamp="+timestamp+"&nonce="+nonce, - bytes.NewReader(wrapperData), - ) - w := httptest.NewRecorder() - - ch.handleMessageCallback(context.Background(), w, req) - - if w.Code != http.StatusForbidden { - t.Errorf("status code = %d, want %d", w.Code, http.StatusForbidden) - } - }) -} - -func TestWeComAppProcessMessage(t *testing.T) { - msgBus := bus.NewMessageBus() - cfg := config.WeComAppConfig{ - CorpID: "test_corp_id", - AgentID: 1000002, - } - cfg.SetCorpSecret("test_secret") - ch, _ := NewWeComAppChannel(cfg, msgBus) - - t.Run("process text message", func(t *testing.T) { - msg := WeComXMLMessage{ - ToUserName: "corp_id", - FromUserName: "user123", - CreateTime: 1234567890, - MsgType: "text", - Content: "Hello World", - MsgId: 123456, - AgentID: 1000002, - } - - // Should not panic - ch.processMessage(context.Background(), msg) - }) - - t.Run("process image message", func(t *testing.T) { - msg := WeComXMLMessage{ - ToUserName: "corp_id", - FromUserName: "user123", - CreateTime: 1234567890, - MsgType: "image", - PicUrl: "https://example.com/image.jpg", - MediaId: "media_123", - MsgId: 123456, - AgentID: 1000002, - } - - // Should not panic - 
ch.processMessage(context.Background(), msg) - }) - - t.Run("process voice message", func(t *testing.T) { - msg := WeComXMLMessage{ - ToUserName: "corp_id", - FromUserName: "user123", - CreateTime: 1234567890, - MsgType: "voice", - MediaId: "media_123", - Format: "amr", - MsgId: 123456, - AgentID: 1000002, - } - - // Should not panic - ch.processMessage(context.Background(), msg) - }) - - t.Run("skip unsupported message type", func(t *testing.T) { - msg := WeComXMLMessage{ - ToUserName: "corp_id", - FromUserName: "user123", - CreateTime: 1234567890, - MsgType: "video", - MsgId: 123456, - AgentID: 1000002, - } - - // Should not panic - ch.processMessage(context.Background(), msg) - }) - - t.Run("process event message", func(t *testing.T) { - msg := WeComXMLMessage{ - ToUserName: "corp_id", - FromUserName: "user123", - CreateTime: 1234567890, - MsgType: "event", - Event: "subscribe", - MsgId: 123456, - AgentID: 1000002, - } - - // Should not panic - ch.processMessage(context.Background(), msg) - }) -} - -func TestWeComAppHandleWebhook(t *testing.T) { - msgBus := bus.NewMessageBus() - cfg := config.WeComAppConfig{} - cfg.CorpID = "test_corp_id" - cfg.SetCorpSecret("test_secret") - cfg.AgentID = 1000002 - cfg.SetToken("test_token") - ch, _ := NewWeComAppChannel(cfg, msgBus) - - t.Run("GET request calls verification", func(t *testing.T) { - echostr := "test_echostr" - encoded := base64.StdEncoding.EncodeToString([]byte(echostr)) - timestamp := "1234567890" - nonce := "test_nonce" - signature := generateSignatureApp("test_token", timestamp, nonce, encoded) - - req := httptest.NewRequest( - http.MethodGet, - "/webhook/wecom-app?msg_signature="+signature+"×tamp="+timestamp+"&nonce="+nonce+"&echostr="+encoded, - nil, - ) - w := httptest.NewRecorder() - - ch.handleWebhook(w, req) - - if w.Code != http.StatusOK { - t.Errorf("status code = %d, want %d", w.Code, http.StatusOK) - } - }) - - t.Run("POST request calls message callback", func(t *testing.T) { - encryptedWrapper := 
struct { - XMLName xml.Name `xml:"xml"` - Encrypt string `xml:"Encrypt"` - }{ - Encrypt: base64.StdEncoding.EncodeToString([]byte("test")), - } - wrapperData, _ := xml.Marshal(encryptedWrapper) - - timestamp := "1234567890" - nonce := "test_nonce" - signature := generateSignatureApp("test_token", timestamp, nonce, encryptedWrapper.Encrypt) - - req := httptest.NewRequest( - http.MethodPost, - "/webhook/wecom-app?msg_signature="+signature+"×tamp="+timestamp+"&nonce="+nonce, - bytes.NewReader(wrapperData), - ) - w := httptest.NewRecorder() - - ch.handleWebhook(w, req) - - // Should not be method not allowed - if w.Code == http.StatusMethodNotAllowed { - t.Error("POST request should not return Method Not Allowed") - } - }) - - t.Run("unsupported method", func(t *testing.T) { - req := httptest.NewRequest(http.MethodPut, "/webhook/wecom-app", nil) - w := httptest.NewRecorder() - - ch.handleWebhook(w, req) - - if w.Code != http.StatusMethodNotAllowed { - t.Errorf("status code = %d, want %d", w.Code, http.StatusMethodNotAllowed) - } - }) -} - -func TestWeComAppHandleHealth(t *testing.T) { - msgBus := bus.NewMessageBus() - cfg := config.WeComAppConfig{ - CorpID: "test_corp_id", - AgentID: 1000002, - } - cfg.SetCorpSecret("test_secret") - ch, _ := NewWeComAppChannel(cfg, msgBus) - - req := httptest.NewRequest(http.MethodGet, "/health/wecom-app", nil) - w := httptest.NewRecorder() - - ch.handleHealth(w, req) - - if w.Code != http.StatusOK { - t.Errorf("status code = %d, want %d", w.Code, http.StatusOK) - } - - contentType := w.Header().Get("Content-Type") - if contentType != "application/json" { - t.Errorf("Content-Type = %q, want %q", contentType, "application/json") - } - - body := w.Body.String() - if !strings.Contains(body, "status") || !strings.Contains(body, "running") || !strings.Contains(body, "has_token") { - t.Errorf("response body should contain status, running, and has_token fields, got: %s", body) - } -} - -func TestWeComAppAccessToken(t *testing.T) { - msgBus := 
bus.NewMessageBus() - cfg := config.WeComAppConfig{ - CorpID: "test_corp_id", - AgentID: 1000002, - } - cfg.SetCorpSecret("test_secret") - ch, _ := NewWeComAppChannel(cfg, msgBus) - - t.Run("get empty access token initially", func(t *testing.T) { - token := ch.getAccessToken() - if token != "" { - t.Errorf("getAccessToken() = %q, want empty string", token) - } - }) - - t.Run("set and get access token", func(t *testing.T) { - ch.tokenMu.Lock() - ch.accessToken = "test_token_123" - ch.tokenExpiry = time.Now().Add(1 * time.Hour) - ch.tokenMu.Unlock() - - token := ch.getAccessToken() - if token != "test_token_123" { - t.Errorf("getAccessToken() = %q, want %q", token, "test_token_123") - } - }) - - t.Run("expired token returns empty", func(t *testing.T) { - ch.tokenMu.Lock() - ch.accessToken = "expired_token" - ch.tokenExpiry = time.Now().Add(-1 * time.Hour) - ch.tokenMu.Unlock() - - token := ch.getAccessToken() - if token != "" { - t.Errorf("getAccessToken() = %q, want empty string for expired token", token) - } - }) -} - -func TestWeComAppMessageStructures(t *testing.T) { - t.Run("WeComTextMessage structure", func(t *testing.T) { - msg := WeComTextMessage{ - ToUser: "user123", - MsgType: "text", - AgentID: 1000002, - } - msg.Text.Content = "Hello World" - - if msg.ToUser != "user123" { - t.Errorf("ToUser = %q, want %q", msg.ToUser, "user123") - } - if msg.MsgType != "text" { - t.Errorf("MsgType = %q, want %q", msg.MsgType, "text") - } - if msg.AgentID != 1000002 { - t.Errorf("AgentID = %d, want %d", msg.AgentID, 1000002) - } - if msg.Text.Content != "Hello World" { - t.Errorf("Text.Content = %q, want %q", msg.Text.Content, "Hello World") - } - - // Test JSON marshaling - jsonData, err := json.Marshal(msg) - if err != nil { - t.Fatalf("failed to marshal JSON: %v", err) - } - - var unmarshaled WeComTextMessage - err = json.Unmarshal(jsonData, &unmarshaled) - if err != nil { - t.Fatalf("failed to unmarshal JSON: %v", err) - } - - if unmarshaled.ToUser != msg.ToUser { - 
t.Errorf("JSON round-trip failed for ToUser") - } - }) - - t.Run("WeComMarkdownMessage structure", func(t *testing.T) { - msg := WeComMarkdownMessage{ - ToUser: "user123", - MsgType: "markdown", - AgentID: 1000002, - } - msg.Markdown.Content = "# Hello\nWorld" - - if msg.Markdown.Content != "# Hello\nWorld" { - t.Errorf("Markdown.Content = %q, want %q", msg.Markdown.Content, "# Hello\nWorld") - } - - // Test JSON marshaling - jsonData, err := json.Marshal(msg) - if err != nil { - t.Fatalf("failed to marshal JSON: %v", err) - } - - if !bytes.Contains(jsonData, []byte("markdown")) { - t.Error("JSON should contain 'markdown' field") - } - }) - - t.Run("WeComImageMessage structure", func(t *testing.T) { - msg := WeComImageMessage{ - ToUser: "user123", - MsgType: "image", - AgentID: 1000002, - } - msg.Image.MediaID = "media_123456" - - if msg.Image.MediaID != "media_123456" { - t.Errorf("Image.MediaID = %q, want %q", msg.Image.MediaID, "media_123456") - } - if msg.ToUser != "user123" { - t.Errorf("ToUser = %q, want %q", msg.ToUser, "user123") - } - if msg.MsgType != "image" { - t.Errorf("MsgType = %q, want %q", msg.MsgType, "image") - } - if msg.AgentID != 1000002 { - t.Errorf("AgentID = %d, want %d", msg.AgentID, 1000002) - } - }) - - t.Run("WeComAccessTokenResponse structure", func(t *testing.T) { - jsonData := `{ - "errcode": 0, - "errmsg": "ok", - "access_token": "test_access_token", - "expires_in": 7200 - }` - - var resp WeComAccessTokenResponse - err := json.Unmarshal([]byte(jsonData), &resp) - if err != nil { - t.Fatalf("failed to unmarshal JSON: %v", err) - } - - if resp.ErrCode != 0 { - t.Errorf("ErrCode = %d, want %d", resp.ErrCode, 0) - } - if resp.ErrMsg != "ok" { - t.Errorf("ErrMsg = %q, want %q", resp.ErrMsg, "ok") - } - if resp.AccessToken != "test_access_token" { - t.Errorf("AccessToken = %q, want %q", resp.AccessToken, "test_access_token") - } - if resp.ExpiresIn != 7200 { - t.Errorf("ExpiresIn = %d, want %d", resp.ExpiresIn, 7200) - } - }) - - 
t.Run("WeComSendMessageResponse structure", func(t *testing.T) { - jsonData := `{ - "errcode": 0, - "errmsg": "ok", - "invaliduser": "", - "invalidparty": "", - "invalidtag": "" - }` - - var resp WeComSendMessageResponse - err := json.Unmarshal([]byte(jsonData), &resp) - if err != nil { - t.Fatalf("failed to unmarshal JSON: %v", err) - } - - if resp.ErrCode != 0 { - t.Errorf("ErrCode = %d, want %d", resp.ErrCode, 0) - } - if resp.ErrMsg != "ok" { - t.Errorf("ErrMsg = %q, want %q", resp.ErrMsg, "ok") - } - }) -} - -func TestWeComAppXMLMessageStructure(t *testing.T) { - xmlData := ` - - - - 1234567890 - - - 1234567890123456 - 1000002 -` - - var msg WeComXMLMessage - err := xml.Unmarshal([]byte(xmlData), &msg) - if err != nil { - t.Fatalf("failed to unmarshal XML: %v", err) - } - - if msg.ToUserName != "corp_id" { - t.Errorf("ToUserName = %q, want %q", msg.ToUserName, "corp_id") - } - if msg.FromUserName != "user123" { - t.Errorf("FromUserName = %q, want %q", msg.FromUserName, "user123") - } - if msg.CreateTime != 1234567890 { - t.Errorf("CreateTime = %d, want %d", msg.CreateTime, 1234567890) - } - if msg.MsgType != "text" { - t.Errorf("MsgType = %q, want %q", msg.MsgType, "text") - } - if msg.Content != "Hello World" { - t.Errorf("Content = %q, want %q", msg.Content, "Hello World") - } - if msg.MsgId != 1234567890123456 { - t.Errorf("MsgId = %d, want %d", msg.MsgId, 1234567890123456) - } - if msg.AgentID != 1000002 { - t.Errorf("AgentID = %d, want %d", msg.AgentID, 1000002) - } -} - -func TestWeComAppXMLMessageImage(t *testing.T) { - xmlData := ` - - - - 1234567890 - - - - 1234567890123456 - 1000002 -` - - var msg WeComXMLMessage - err := xml.Unmarshal([]byte(xmlData), &msg) - if err != nil { - t.Fatalf("failed to unmarshal XML: %v", err) - } - - if msg.MsgType != "image" { - t.Errorf("MsgType = %q, want %q", msg.MsgType, "image") - } - if msg.PicUrl != "https://example.com/image.jpg" { - t.Errorf("PicUrl = %q, want %q", msg.PicUrl, "https://example.com/image.jpg") - 
} - if msg.MediaId != "media_123" { - t.Errorf("MediaId = %q, want %q", msg.MediaId, "media_123") - } -} - -func TestWeComAppXMLMessageVoice(t *testing.T) { - xmlData := ` - - - - 1234567890 - - - - 1234567890123456 - 1000002 -` - - var msg WeComXMLMessage - err := xml.Unmarshal([]byte(xmlData), &msg) - if err != nil { - t.Fatalf("failed to unmarshal XML: %v", err) - } - - if msg.MsgType != "voice" { - t.Errorf("MsgType = %q, want %q", msg.MsgType, "voice") - } - if msg.Format != "amr" { - t.Errorf("Format = %q, want %q", msg.Format, "amr") - } -} - -func TestWeComAppXMLMessageLocation(t *testing.T) { - xmlData := ` - - - - 1234567890 - - 39.9042 - 116.4074 - 16 - - 1234567890123456 - 1000002 -` - - var msg WeComXMLMessage - err := xml.Unmarshal([]byte(xmlData), &msg) - if err != nil { - t.Fatalf("failed to unmarshal XML: %v", err) - } - - if msg.MsgType != "location" { - t.Errorf("MsgType = %q, want %q", msg.MsgType, "location") - } - if msg.LocationX != 39.9042 { - t.Errorf("LocationX = %f, want %f", msg.LocationX, 39.9042) - } - if msg.LocationY != 116.4074 { - t.Errorf("LocationY = %f, want %f", msg.LocationY, 116.4074) - } - if msg.Scale != 16 { - t.Errorf("Scale = %d, want %d", msg.Scale, 16) - } - if msg.Label != "Beijing" { - t.Errorf("Label = %q, want %q", msg.Label, "Beijing") - } -} - -func TestWeComAppXMLMessageLink(t *testing.T) { - xmlData := ` - - - - 1234567890 - - <![CDATA[Link Title]]> - - - 1234567890123456 - 1000002 -` - - var msg WeComXMLMessage - err := xml.Unmarshal([]byte(xmlData), &msg) - if err != nil { - t.Fatalf("failed to unmarshal XML: %v", err) - } - - if msg.MsgType != "link" { - t.Errorf("MsgType = %q, want %q", msg.MsgType, "link") - } - if msg.Title != "Link Title" { - t.Errorf("Title = %q, want %q", msg.Title, "Link Title") - } - if msg.Description != "Link Description" { - t.Errorf("Description = %q, want %q", msg.Description, "Link Description") - } - if msg.Url != "https://example.com" { - t.Errorf("Url = %q, want %q", 
msg.Url, "https://example.com") - } -} - -func TestWeComAppXMLMessageEvent(t *testing.T) { - xmlData := ` - - - - 1234567890 - - - - 1000002 -` - - var msg WeComXMLMessage - err := xml.Unmarshal([]byte(xmlData), &msg) - if err != nil { - t.Fatalf("failed to unmarshal XML: %v", err) - } - - if msg.MsgType != "event" { - t.Errorf("MsgType = %q, want %q", msg.MsgType, "event") - } - if msg.Event != "subscribe" { - t.Errorf("Event = %q, want %q", msg.Event, "subscribe") - } - if msg.EventKey != "event_key_123" { - t.Errorf("EventKey = %q, want %q", msg.EventKey, "event_key_123") - } -} diff --git a/pkg/channels/wecom/bot.go b/pkg/channels/wecom/bot.go deleted file mode 100644 index 22461b768..000000000 --- a/pkg/channels/wecom/bot.go +++ /dev/null @@ -1,499 +0,0 @@ -package wecom - -import ( - "bytes" - "context" - "encoding/json" - "encoding/xml" - "fmt" - "io" - "net/http" - "strings" - "time" - - "github.com/sipeed/picoclaw/pkg/bus" - "github.com/sipeed/picoclaw/pkg/channels" - "github.com/sipeed/picoclaw/pkg/config" - "github.com/sipeed/picoclaw/pkg/identity" - "github.com/sipeed/picoclaw/pkg/logger" - "github.com/sipeed/picoclaw/pkg/utils" -) - -// WeComBotChannel implements the Channel interface for WeCom Bot (企业微信智能机器人) -// Uses webhook callback mode - simpler than WeCom App but only supports passive replies -type WeComBotChannel struct { - *channels.BaseChannel - config config.WeComConfig - client *http.Client - ctx context.Context - cancel context.CancelFunc - processedMsgs *MessageDeduplicator -} - -// WeComBotMessage represents the JSON message structure from WeCom Bot (AIBOT) -type WeComBotMessage struct { - MsgID string `json:"msgid"` - AIBotID string `json:"aibotid"` - ChatID string `json:"chatid"` // Session ID, only present for group chats - ChatType string `json:"chattype"` // "single" for DM, "group" for group chat - From struct { - UserID string `json:"userid"` - } `json:"from"` - ResponseURL string `json:"response_url"` - MsgType string 
`json:"msgtype"` // text, image, voice, file, mixed - Text struct { - Content string `json:"content"` - } `json:"text"` - Image struct { - URL string `json:"url"` - } `json:"image"` - Voice struct { - Content string `json:"content"` // Voice to text content - } `json:"voice"` - File struct { - URL string `json:"url"` - } `json:"file"` - Mixed struct { - MsgItem []struct { - MsgType string `json:"msgtype"` - Text struct { - Content string `json:"content"` - } `json:"text"` - Image struct { - URL string `json:"url"` - } `json:"image"` - } `json:"msg_item"` - } `json:"mixed"` - Quote struct { - MsgType string `json:"msgtype"` - Text struct { - Content string `json:"content"` - } `json:"text"` - } `json:"quote"` -} - -// WeComBotReplyMessage represents the reply message structure -type WeComBotReplyMessage struct { - MsgType string `json:"msgtype"` - Text struct { - Content string `json:"content"` - } `json:"text,omitempty"` -} - -// NewWeComBotChannel creates a new WeCom Bot channel instance -func NewWeComBotChannel(cfg config.WeComConfig, messageBus *bus.MessageBus) (*WeComBotChannel, error) { - if cfg.Token() == "" || cfg.WebhookURL == "" { - return nil, fmt.Errorf("wecom token and webhook_url are required") - } - - base := channels.NewBaseChannel("wecom", cfg, messageBus, cfg.AllowFrom, - channels.WithMaxMessageLength(2048), - channels.WithGroupTrigger(cfg.GroupTrigger), - channels.WithReasoningChannelID(cfg.ReasoningChannelID), - ) - - // Client timeout must be >= the configured ReplyTimeout so the - // per-request context deadline is always the effective limit. 
- clientTimeout := 30 * time.Second - if d := time.Duration(cfg.ReplyTimeout) * time.Second; d > clientTimeout { - clientTimeout = d - } - - ctx, cancel := context.WithCancel(context.Background()) - return &WeComBotChannel{ - BaseChannel: base, - config: cfg, - client: &http.Client{Timeout: clientTimeout}, - ctx: ctx, - cancel: cancel, - processedMsgs: NewMessageDeduplicator(wecomMaxProcessedMessages), - }, nil -} - -// Name returns the channel name -func (c *WeComBotChannel) Name() string { - return "wecom" -} - -// Start initializes the WeCom Bot channel -func (c *WeComBotChannel) Start(ctx context.Context) error { - logger.InfoC("wecom", "Starting WeCom Bot channel...") - - // Cancel the context created in the constructor to avoid a resource leak. - if c.cancel != nil { - c.cancel() - } - c.ctx, c.cancel = context.WithCancel(ctx) - - c.SetRunning(true) - logger.InfoC("wecom", "WeCom Bot channel started") - - return nil -} - -// Stop gracefully stops the WeCom Bot channel -func (c *WeComBotChannel) Stop(ctx context.Context) error { - logger.InfoC("wecom", "Stopping WeCom Bot channel...") - - if c.cancel != nil { - c.cancel() - } - - c.SetRunning(false) - logger.InfoC("wecom", "WeCom Bot channel stopped") - return nil -} - -// Send sends a message to WeCom user via webhook API -// Note: WeCom Bot can only reply within the configured timeout (default 5 seconds) of receiving a message -// For delayed responses, we use the webhook URL -func (c *WeComBotChannel) Send(ctx context.Context, msg bus.OutboundMessage) error { - if !c.IsRunning() { - return channels.ErrNotRunning - } - - logger.DebugCF("wecom", "Sending message via webhook", map[string]any{ - "chat_id": msg.ChatID, - "preview": utils.Truncate(msg.Content, 100), - }) - - return c.sendWebhookReply(ctx, msg.ChatID, msg.Content) -} - -// WebhookPath returns the path for registering on the shared HTTP server. 
-func (c *WeComBotChannel) WebhookPath() string { - if c.config.WebhookPath != "" { - return c.config.WebhookPath - } - return "/webhook/wecom" -} - -// ServeHTTP implements http.Handler for the shared HTTP server. -func (c *WeComBotChannel) ServeHTTP(w http.ResponseWriter, r *http.Request) { - c.handleWebhook(w, r) -} - -// HealthPath returns the health check endpoint path. -func (c *WeComBotChannel) HealthPath() string { - return "/health/wecom" -} - -// HealthHandler handles health check requests. -func (c *WeComBotChannel) HealthHandler(w http.ResponseWriter, r *http.Request) { - c.handleHealth(w, r) -} - -// handleWebhook handles incoming webhook requests from WeCom -func (c *WeComBotChannel) handleWebhook(w http.ResponseWriter, r *http.Request) { - ctx := r.Context() - - if r.Method == http.MethodGet { - // Handle verification request - c.handleVerification(ctx, w, r) - return - } - - if r.Method == http.MethodPost { - // Handle message callback - c.handleMessageCallback(ctx, w, r) - return - } - - http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) -} - -// handleVerification handles the URL verification request from WeCom -func (c *WeComBotChannel) handleVerification(ctx context.Context, w http.ResponseWriter, r *http.Request) { - query := r.URL.Query() - msgSignature := query.Get("msg_signature") - timestamp := query.Get("timestamp") - nonce := query.Get("nonce") - echostr := query.Get("echostr") - - if msgSignature == "" || timestamp == "" || nonce == "" || echostr == "" { - http.Error(w, "Missing parameters", http.StatusBadRequest) - return - } - - // Verify signature - if !verifySignature(c.config.Token(), msgSignature, timestamp, nonce, echostr) { - logger.WarnC("wecom", "Signature verification failed") - http.Error(w, "Invalid signature", http.StatusForbidden) - return - } - - // Decrypt echostr - // For AIBOT (智能机器人), receiveid should be empty string "" - // Reference: https://developer.work.weixin.qq.com/document/path/101033 - 
decryptedEchoStr, err := decryptMessageWithVerify(echostr, c.config.EncodingAESKey(), "") - if err != nil { - logger.ErrorCF("wecom", "Failed to decrypt echostr", map[string]any{ - "error": err.Error(), - }) - http.Error(w, "Decryption failed", http.StatusInternalServerError) - return - } - - // Remove BOM and whitespace as per WeCom documentation - // The response must be plain text without quotes, BOM, or newlines - decryptedEchoStr = strings.TrimSpace(decryptedEchoStr) - decryptedEchoStr = strings.TrimPrefix(decryptedEchoStr, "\xef\xbb\xbf") // Remove UTF-8 BOM - w.Write([]byte(decryptedEchoStr)) -} - -// handleMessageCallback handles incoming messages from WeCom -func (c *WeComBotChannel) handleMessageCallback(ctx context.Context, w http.ResponseWriter, r *http.Request) { - query := r.URL.Query() - msgSignature := query.Get("msg_signature") - timestamp := query.Get("timestamp") - nonce := query.Get("nonce") - - if msgSignature == "" || timestamp == "" || nonce == "" { - http.Error(w, "Missing parameters", http.StatusBadRequest) - return - } - - // Read request body - body, err := io.ReadAll(r.Body) - if err != nil { - http.Error(w, "Failed to read body", http.StatusBadRequest) - return - } - defer r.Body.Close() - - // Parse XML to get encrypted message - var encryptedMsg struct { - XMLName xml.Name `xml:"xml"` - ToUserName string `xml:"ToUserName"` - Encrypt string `xml:"Encrypt"` - AgentID string `xml:"AgentID"` - } - - if err = xml.Unmarshal(body, &encryptedMsg); err != nil { - logger.ErrorCF("wecom", "Failed to parse XML", map[string]any{ - "error": err.Error(), - }) - http.Error(w, "Invalid XML", http.StatusBadRequest) - return - } - - // Verify signature - if !verifySignature(c.config.Token(), msgSignature, timestamp, nonce, encryptedMsg.Encrypt) { - logger.WarnC("wecom", "Message signature verification failed") - http.Error(w, "Invalid signature", http.StatusForbidden) - return - } - - // Decrypt message - // For AIBOT (智能机器人), receiveid should be empty 
string "" - // Reference: https://developer.work.weixin.qq.com/document/path/101033 - decryptedMsg, err := decryptMessageWithVerify(encryptedMsg.Encrypt, c.config.EncodingAESKey(), "") - if err != nil { - logger.ErrorCF("wecom", "Failed to decrypt message", map[string]any{ - "error": err.Error(), - }) - http.Error(w, "Decryption failed", http.StatusInternalServerError) - return - } - - // Parse decrypted JSON message (AIBOT uses JSON format) - var msg WeComBotMessage - if err := json.Unmarshal([]byte(decryptedMsg), &msg); err != nil { - logger.ErrorCF("wecom", "Failed to parse decrypted message", map[string]any{ - "error": err.Error(), - }) - http.Error(w, "Invalid message format", http.StatusBadRequest) - return - } - - // Process the message with the channel's long-lived context (not the HTTP - // request context, which is canceled as soon as we return the response). - go c.processMessage(c.ctx, msg) - - // Return success response immediately - // WeCom Bot requires response within configured timeout (default 5 seconds) - w.Write([]byte("success")) -} - -// processMessage processes the received message -func (c *WeComBotChannel) processMessage(ctx context.Context, msg WeComBotMessage) { - // Skip unsupported message types - if msg.MsgType != "text" && msg.MsgType != "image" && msg.MsgType != "voice" && msg.MsgType != "file" && - msg.MsgType != "mixed" { - logger.DebugCF("wecom", "Skipping non-supported message type", map[string]any{ - "msg_type": msg.MsgType, - }) - return - } - - // Message deduplication: Use msg_id to prevent duplicate processing - msgID := msg.MsgID - if !c.processedMsgs.MarkMessageProcessed(msgID) { - logger.DebugCF("wecom", "Skipping duplicate message", map[string]any{ - "msg_id": msgID, - }) - return - } - - senderID := msg.From.UserID - - // Determine if this is a group chat or direct message - // ChatType: "single" for DM, "group" for group chat - isGroupChat := msg.ChatType == "group" - - var chatID, peerKind, peerID string - if 
isGroupChat { - // Group chat: use ChatID as chatID and peer_id - chatID = msg.ChatID - peerKind = "group" - peerID = msg.ChatID - } else { - // Direct message: use senderID as chatID and peer_id - chatID = senderID - peerKind = "direct" - peerID = senderID - } - - // Extract content based on message type - var content string - switch msg.MsgType { - case "text": - content = msg.Text.Content - case "voice": - content = msg.Voice.Content // Voice to text content - case "mixed": - // For mixed messages, concatenate text items - for _, item := range msg.Mixed.MsgItem { - if item.MsgType == "text" { - content += item.Text.Content - } - } - case "image", "file": - // For image and file, we don't have text content - content = "" - } - - // Build metadata - peer := bus.Peer{Kind: peerKind, ID: peerID} - - // In group chats, apply unified group trigger filtering - if isGroupChat { - respond, cleaned := c.ShouldRespondInGroup(false, content) - if !respond { - return - } - content = cleaned - } - - metadata := map[string]string{ - "msg_type": msg.MsgType, - "msg_id": msg.MsgID, - "platform": "wecom", - "response_url": msg.ResponseURL, - } - if isGroupChat { - metadata["chat_id"] = msg.ChatID - metadata["sender_id"] = senderID - } - - logger.DebugCF("wecom", "Received message", map[string]any{ - "sender_id": senderID, - "msg_type": msg.MsgType, - "peer_kind": peerKind, - "is_group_chat": isGroupChat, - "preview": utils.Truncate(content, 50), - }) - - // Build sender info - sender := bus.SenderInfo{ - Platform: "wecom", - PlatformID: senderID, - CanonicalID: identity.BuildCanonicalID("wecom", senderID), - } - - if !c.IsAllowedSender(sender) { - return - } - - // Handle the message through the base channel - c.HandleMessage(ctx, peer, msg.MsgID, senderID, chatID, content, nil, metadata, sender) -} - -// sendWebhookReply sends a reply using the webhook URL -func (c *WeComBotChannel) sendWebhookReply(ctx context.Context, userID, content string) error { - reply := 
WeComBotReplyMessage{ - MsgType: "text", - } - reply.Text.Content = content - - jsonData, err := json.Marshal(reply) - if err != nil { - return fmt.Errorf("failed to marshal reply: %w", err) - } - - // Use configurable timeout (default 5 seconds) - timeout := c.config.ReplyTimeout - if timeout <= 0 { - timeout = 5 - } - - reqCtx, cancel := context.WithTimeout(ctx, time.Duration(timeout)*time.Second) - defer cancel() - - req, err := http.NewRequestWithContext(reqCtx, http.MethodPost, c.config.WebhookURL, bytes.NewBuffer(jsonData)) - if err != nil { - return fmt.Errorf("failed to create request: %w", err) - } - req.Header.Set("Content-Type", "application/json") - - resp, err := c.client.Do(req) - if err != nil { - return channels.ClassifyNetError(err) - } - defer resp.Body.Close() - - if resp.StatusCode != http.StatusOK { - body, readErr := io.ReadAll(resp.Body) - if readErr != nil { - return channels.ClassifySendError( - resp.StatusCode, - fmt.Errorf("reading webhook error response: %w", readErr), - ) - } - return channels.ClassifySendError( - resp.StatusCode, - fmt.Errorf("webhook API error: %s", string(body)), - ) - } - - body, err := io.ReadAll(resp.Body) - if err != nil { - return fmt.Errorf("failed to read response: %w", err) - } - - // Check response - var result struct { - ErrCode int `json:"errcode"` - ErrMsg string `json:"errmsg"` - } - if err := json.Unmarshal(body, &result); err != nil { - return fmt.Errorf("failed to parse response: %w", err) - } - - if result.ErrCode != 0 { - return fmt.Errorf("webhook API error: %s (code: %d)", result.ErrMsg, result.ErrCode) - } - - return nil -} - -// handleHealth handles health check requests -func (c *WeComBotChannel) handleHealth(w http.ResponseWriter, r *http.Request) { - status := map[string]any{ - "status": "ok", - "running": c.IsRunning(), - } - - w.Header().Set("Content-Type", "application/json") - json.NewEncoder(w).Encode(status) -} diff --git a/pkg/channels/wecom/bot_test.go b/pkg/channels/wecom/bot_test.go 
deleted file mode 100644 index 7b50a86f7..000000000 --- a/pkg/channels/wecom/bot_test.go +++ /dev/null @@ -1,734 +0,0 @@ -package wecom - -import ( - "bytes" - "context" - "crypto/aes" - "crypto/cipher" - "crypto/sha1" - "encoding/base64" - "encoding/binary" - "encoding/json" - "encoding/xml" - "fmt" - "net/http" - "net/http/httptest" - "sort" - "strings" - "testing" - - "github.com/sipeed/picoclaw/pkg/bus" - "github.com/sipeed/picoclaw/pkg/config" -) - -// generateTestAESKey generates a valid test AES key -func generateTestAESKey() string { - // AES key needs to be 32 bytes (256 bits) for AES-256 - key := make([]byte, 32) - for i := range key { - key[i] = byte(i) - } - // Return base64 encoded key without padding - return base64.StdEncoding.EncodeToString(key)[:43] -} - -// encryptTestMessage encrypts a message for testing (AIBOT JSON format) -func encryptTestMessage(message, aesKey string) (string, error) { - // Decode AES key - key, err := base64.StdEncoding.DecodeString(aesKey + "=") - if err != nil { - return "", err - } - - // Prepare message: random(16) + msg_len(4) + msg + receiveid - random := make([]byte, 0, 16) - for i := range 16 { - random = append(random, byte(i)) - } - - msgBytes := []byte(message) - receiveID := []byte("test_aibot_id") - - msgLen := uint32(len(msgBytes)) - lenBytes := make([]byte, 4) - binary.BigEndian.PutUint32(lenBytes, msgLen) - - plainText := append(random, lenBytes...) - plainText = append(plainText, msgBytes...) - plainText = append(plainText, receiveID...) - - // PKCS7 padding - blockSize := aes.BlockSize - padding := blockSize - len(plainText)%blockSize - padText := bytes.Repeat([]byte{byte(padding)}, padding) - plainText = append(plainText, padText...) 
- - // Encrypt - block, err := aes.NewCipher(key) - if err != nil { - return "", err - } - - mode := cipher.NewCBCEncrypter(block, key[:aes.BlockSize]) - cipherText := make([]byte, len(plainText)) - mode.CryptBlocks(cipherText, plainText) - - return base64.StdEncoding.EncodeToString(cipherText), nil -} - -// generateSignature generates a signature for testing -func generateSignature(token, timestamp, nonce, msgEncrypt string) string { - params := []string{token, timestamp, nonce, msgEncrypt} - sort.Strings(params) - str := strings.Join(params, "") - hash := sha1.Sum([]byte(str)) - return fmt.Sprintf("%x", hash) -} - -func TestNewWeComBotChannel(t *testing.T) { - msgBus := bus.NewMessageBus() - - t.Run("missing token", func(t *testing.T) { - cfg := config.WeComConfig{} - cfg.SetToken("") - cfg.WebhookURL = "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test" - _, err := NewWeComBotChannel(cfg, msgBus) - if err == nil { - t.Error("expected error for missing token, got nil") - } - }) - - t.Run("missing webhook_url", func(t *testing.T) { - cfg := config.WeComConfig{} - cfg.SetToken("test_token") - cfg.WebhookURL = "" - _, err := NewWeComBotChannel(cfg, msgBus) - if err == nil { - t.Error("expected error for missing webhook_url, got nil") - } - }) - - t.Run("valid config", func(t *testing.T) { - cfg := config.WeComConfig{} - cfg.SetToken("test_token") - cfg.WebhookURL = "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test" - cfg.AllowFrom = []string{"user1", "user2"} - ch, err := NewWeComBotChannel(cfg, msgBus) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - if ch.Name() != "wecom" { - t.Errorf("Name() = %q, want %q", ch.Name(), "wecom") - } - if ch.IsRunning() { - t.Error("new channel should not be running") - } - }) -} - -func TestWeComBotChannelIsAllowed(t *testing.T) { - msgBus := bus.NewMessageBus() - - t.Run("empty allowlist allows all", func(t *testing.T) { - cfg := config.WeComConfig{} - cfg.SetToken("test_token") - cfg.WebhookURL = 
"https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test" - cfg.AllowFrom = []string{} - ch, _ := NewWeComBotChannel(cfg, msgBus) - if !ch.IsAllowed("any_user") { - t.Error("empty allowlist should allow all users") - } - }) - - t.Run("allowlist restricts users", func(t *testing.T) { - cfg := config.WeComConfig{} - cfg.SetToken("test_token") - cfg.WebhookURL = "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test" - cfg.AllowFrom = []string{"allowed_user"} - ch, _ := NewWeComBotChannel(cfg, msgBus) - if !ch.IsAllowed("allowed_user") { - t.Error("allowed user should pass allowlist check") - } - if ch.IsAllowed("blocked_user") { - t.Error("non-allowed user should be blocked") - } - }) -} - -func TestWeComBotVerifySignature(t *testing.T) { - msgBus := bus.NewMessageBus() - cfg := config.WeComConfig{} - cfg.SetToken("test_token") - cfg.WebhookURL = "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test" - ch, _ := NewWeComBotChannel(cfg, msgBus) - - t.Run("valid signature", func(t *testing.T) { - timestamp := "1234567890" - nonce := "test_nonce" - msgEncrypt := "test_message" - expectedSig := generateSignature("test_token", timestamp, nonce, msgEncrypt) - - if !verifySignature(ch.config.Token(), expectedSig, timestamp, nonce, msgEncrypt) { - t.Error("valid signature should pass verification") - } - }) - - t.Run("invalid signature", func(t *testing.T) { - timestamp := "1234567890" - nonce := "test_nonce" - msgEncrypt := "test_message" - - if verifySignature(ch.config.Token(), "invalid_sig", timestamp, nonce, msgEncrypt) { - t.Error("invalid signature should fail verification") - } - }) - - t.Run("empty token rejects verification (fail-closed)", func(t *testing.T) { - cfgEmpty := config.WeComConfig{} - cfgEmpty.SetToken("") - cfgEmpty.WebhookURL = "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test" - chEmpty := &WeComBotChannel{ - config: cfgEmpty, - } - - if verifySignature(chEmpty.config.Token(), "any_sig", "any_ts", "any_nonce", "any_msg") { - 
t.Error("empty token should reject verification (fail-closed)") - } - }) -} - -func TestWeComBotDecryptMessage(t *testing.T) { - msgBus := bus.NewMessageBus() - - t.Run("decrypt without AES key", func(t *testing.T) { - cfg := config.WeComConfig{} - cfg.SetToken("test_token") - cfg.WebhookURL = "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test" - cfg.SetEncodingAESKey("") - ch, _ := NewWeComBotChannel(cfg, msgBus) - - // Without AES key, message should be base64 decoded only - plainText := "hello world" - encoded := base64.StdEncoding.EncodeToString([]byte(plainText)) - - result, err := decryptMessage(encoded, ch.config.EncodingAESKey()) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - if result != plainText { - t.Errorf("decryptMessage() = %q, want %q", result, plainText) - } - }) - - t.Run("decrypt with AES key", func(t *testing.T) { - aesKey := generateTestAESKey() - cfg := config.WeComConfig{} - cfg.SetToken("test_token") - cfg.WebhookURL = "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test" - cfg.SetEncodingAESKey(aesKey) - ch, _ := NewWeComBotChannel(cfg, msgBus) - - originalMsg := "Hello" - encrypted, err := encryptTestMessage(originalMsg, aesKey) - if err != nil { - t.Fatalf("failed to encrypt test message: %v", err) - } - - result, err := decryptMessage(encrypted, ch.config.EncodingAESKey()) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - if result != originalMsg { - t.Errorf("WeComDecryptMessage() = %q, want %q", result, originalMsg) - } - }) - - t.Run("invalid base64", func(t *testing.T) { - cfg := config.WeComConfig{} - cfg.SetToken("test_token") - cfg.WebhookURL = "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test" - cfg.SetEncodingAESKey("") - ch, _ := NewWeComBotChannel(cfg, msgBus) - - _, err := decryptMessage("invalid_base64!!!", ch.config.EncodingAESKey()) - if err == nil { - t.Error("expected error for invalid base64, got nil") - } - }) - - t.Run("invalid AES key", func(t *testing.T) { - cfg 
:= config.WeComConfig{} - cfg.SetToken("test_token") - cfg.WebhookURL = "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test" - cfg.SetEncodingAESKey("invalid_key") - ch, _ := NewWeComBotChannel(cfg, msgBus) - - _, err := decryptMessage(base64.StdEncoding.EncodeToString([]byte("test")), ch.config.EncodingAESKey()) - if err == nil { - t.Error("expected error for invalid AES key, got nil") - } - }) -} - -func TestWeComBotPKCS7Unpad(t *testing.T) { - tests := []struct { - name string - input []byte - expected []byte - }{ - { - name: "empty input", - input: []byte{}, - expected: []byte{}, - }, - { - name: "valid padding 3 bytes", - input: append([]byte("hello"), bytes.Repeat([]byte{3}, 3)...), - expected: []byte("hello"), - }, - { - name: "valid padding 16 bytes (full block)", - input: append([]byte("123456789012345"), bytes.Repeat([]byte{16}, 16)...), - expected: []byte("123456789012345"), - }, - { - name: "invalid padding larger than data", - input: []byte{20}, - expected: nil, // should return error - }, - { - name: "invalid padding zero", - input: append([]byte("test"), byte(0)), - expected: nil, // should return error - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - result, err := pkcs7Unpad(tt.input) - if tt.expected == nil { - // This case should return an error - if err == nil { - t.Errorf("pkcs7Unpad() expected error for invalid padding, got result: %v", result) - } - return - } - if err != nil { - t.Errorf("pkcs7Unpad() unexpected error: %v", err) - return - } - if !bytes.Equal(result, tt.expected) { - t.Errorf("pkcs7Unpad() = %v, want %v", result, tt.expected) - } - }) - } -} - -func TestWeComBotHandleVerification(t *testing.T) { - msgBus := bus.NewMessageBus() - aesKey := generateTestAESKey() - cfg := config.WeComConfig{} - cfg.SetToken("test_token") - cfg.SetEncodingAESKey(aesKey) - cfg.WebhookURL = "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test" - ch, _ := NewWeComBotChannel(cfg, msgBus) - - t.Run("valid 
verification request", func(t *testing.T) { - echostr := "test_echostr_123" - encryptedEchostr, _ := encryptTestMessage(echostr, aesKey) - timestamp := "1234567890" - nonce := "test_nonce" - signature := generateSignature("test_token", timestamp, nonce, encryptedEchostr) - - req := httptest.NewRequest( - http.MethodGet, - "/webhook/wecom?msg_signature="+signature+"×tamp="+timestamp+"&nonce="+nonce+"&echostr="+encryptedEchostr, - nil, - ) - w := httptest.NewRecorder() - - ch.handleVerification(context.Background(), w, req) - - if w.Code != http.StatusOK { - t.Errorf("status code = %d, want %d", w.Code, http.StatusOK) - } - if w.Body.String() != echostr { - t.Errorf("response body = %q, want %q", w.Body.String(), echostr) - } - }) - - t.Run("missing parameters", func(t *testing.T) { - req := httptest.NewRequest(http.MethodGet, "/webhook/wecom?msg_signature=sig×tamp=ts", nil) - w := httptest.NewRecorder() - - ch.handleVerification(context.Background(), w, req) - - if w.Code != http.StatusBadRequest { - t.Errorf("status code = %d, want %d", w.Code, http.StatusBadRequest) - } - }) - - t.Run("invalid signature", func(t *testing.T) { - echostr := "test_echostr" - encryptedEchostr, _ := encryptTestMessage(echostr, aesKey) - timestamp := "1234567890" - nonce := "test_nonce" - - req := httptest.NewRequest( - http.MethodGet, - "/webhook/wecom?msg_signature=invalid_sig×tamp="+timestamp+"&nonce="+nonce+"&echostr="+encryptedEchostr, - nil, - ) - w := httptest.NewRecorder() - - ch.handleVerification(context.Background(), w, req) - - if w.Code != http.StatusForbidden { - t.Errorf("status code = %d, want %d", w.Code, http.StatusForbidden) - } - }) -} - -func TestWeComBotHandleMessageCallback(t *testing.T) { - msgBus := bus.NewMessageBus() - aesKey := generateTestAESKey() - cfg := config.WeComConfig{} - cfg.SetToken("test_token") - cfg.SetEncodingAESKey(aesKey) - cfg.WebhookURL = "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test" - ch, _ := NewWeComBotChannel(cfg, msgBus) - 
- runBotMessageCallback := func(t *testing.T, jsonMsg string) *httptest.ResponseRecorder { - t.Helper() - encrypted, _ := encryptTestMessage(jsonMsg, aesKey) - encryptedWrapper := struct { - XMLName xml.Name `xml:"xml"` - Encrypt string `xml:"Encrypt"` - }{ - Encrypt: encrypted, - } - wrapperData, _ := xml.Marshal(encryptedWrapper) - timestamp := "1234567890" - nonce := "test_nonce" - signature := generateSignature("test_token", timestamp, nonce, encrypted) - req := httptest.NewRequest( - http.MethodPost, - "/webhook/wecom?msg_signature="+signature+"×tamp="+timestamp+"&nonce="+nonce, - bytes.NewReader(wrapperData), - ) - w := httptest.NewRecorder() - ch.handleMessageCallback(context.Background(), w, req) - return w - } - - t.Run("valid direct message callback", func(t *testing.T) { - w := runBotMessageCallback(t, `{ - "msgid": "test_msg_id_123", - "aibotid": "test_aibot_id", - "chattype": "single", - "from": {"userid": "user123"}, - "response_url": "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test", - "msgtype": "text", - "text": {"content": "Hello World"} - }`) - if w.Code != http.StatusOK { - t.Errorf("status code = %d, want %d", w.Code, http.StatusOK) - } - if w.Body.String() != "success" { - t.Errorf("response body = %q, want %q", w.Body.String(), "success") - } - }) - - t.Run("valid group message callback", func(t *testing.T) { - w := runBotMessageCallback(t, `{ - "msgid": "test_msg_id_456", - "aibotid": "test_aibot_id", - "chatid": "group_chat_id_123", - "chattype": "group", - "from": {"userid": "user456"}, - "response_url": "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test", - "msgtype": "text", - "text": {"content": "Hello Group"} - }`) - if w.Code != http.StatusOK { - t.Errorf("status code = %d, want %d", w.Code, http.StatusOK) - } - if w.Body.String() != "success" { - t.Errorf("response body = %q, want %q", w.Body.String(), "success") - } - }) - - t.Run("missing parameters", func(t *testing.T) { - req := 
httptest.NewRequest(http.MethodPost, "/webhook/wecom?msg_signature=sig", nil) - w := httptest.NewRecorder() - - ch.handleMessageCallback(context.Background(), w, req) - - if w.Code != http.StatusBadRequest { - t.Errorf("status code = %d, want %d", w.Code, http.StatusBadRequest) - } - }) - - t.Run("invalid XML", func(t *testing.T) { - timestamp := "1234567890" - nonce := "test_nonce" - signature := generateSignature("test_token", timestamp, nonce, "") - - req := httptest.NewRequest( - http.MethodPost, - "/webhook/wecom?msg_signature="+signature+"×tamp="+timestamp+"&nonce="+nonce, - strings.NewReader("invalid xml"), - ) - w := httptest.NewRecorder() - - ch.handleMessageCallback(context.Background(), w, req) - - if w.Code != http.StatusBadRequest { - t.Errorf("status code = %d, want %d", w.Code, http.StatusBadRequest) - } - }) - - t.Run("invalid signature", func(t *testing.T) { - encryptedWrapper := struct { - XMLName xml.Name `xml:"xml"` - Encrypt string `xml:"Encrypt"` - }{ - Encrypt: "encrypted_data", - } - wrapperData, _ := xml.Marshal(encryptedWrapper) - - timestamp := "1234567890" - nonce := "test_nonce" - - req := httptest.NewRequest( - http.MethodPost, - "/webhook/wecom?msg_signature=invalid_sig×tamp="+timestamp+"&nonce="+nonce, - bytes.NewReader(wrapperData), - ) - w := httptest.NewRecorder() - - ch.handleMessageCallback(context.Background(), w, req) - - if w.Code != http.StatusForbidden { - t.Errorf("status code = %d, want %d", w.Code, http.StatusForbidden) - } - }) -} - -func TestWeComBotProcessMessage(t *testing.T) { - msgBus := bus.NewMessageBus() - cfg := config.WeComConfig{} - cfg.SetToken("test_token") - cfg.WebhookURL = "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test" - ch, _ := NewWeComBotChannel(cfg, msgBus) - - t.Run("process direct text message", func(t *testing.T) { - msg := WeComBotMessage{ - MsgID: "test_msg_id_123", - AIBotID: "test_aibot_id", - ChatType: "single", - ResponseURL: 
"https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test", - MsgType: "text", - } - msg.From.UserID = "user123" - msg.Text.Content = "Hello World" - - // Should not panic - ch.processMessage(context.Background(), msg) - }) - - t.Run("process group text message", func(t *testing.T) { - msg := WeComBotMessage{ - MsgID: "test_msg_id_456", - AIBotID: "test_aibot_id", - ChatID: "group_chat_id_123", - ChatType: "group", - ResponseURL: "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test", - MsgType: "text", - } - msg.From.UserID = "user456" - msg.Text.Content = "Hello Group" - - // Should not panic - ch.processMessage(context.Background(), msg) - }) - - t.Run("process voice message", func(t *testing.T) { - msg := WeComBotMessage{ - MsgID: "test_msg_id_789", - AIBotID: "test_aibot_id", - ChatType: "single", - ResponseURL: "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test", - MsgType: "voice", - } - msg.From.UserID = "user123" - msg.Voice.Content = "Voice message text" - - // Should not panic - ch.processMessage(context.Background(), msg) - }) - - t.Run("skip unsupported message type", func(t *testing.T) { - msg := WeComBotMessage{ - MsgID: "test_msg_id_000", - AIBotID: "test_aibot_id", - ChatType: "single", - ResponseURL: "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test", - MsgType: "video", - } - msg.From.UserID = "user123" - - // Should not panic - ch.processMessage(context.Background(), msg) - }) -} - -func TestWeComBotHandleWebhook(t *testing.T) { - msgBus := bus.NewMessageBus() - cfg := config.WeComConfig{} - cfg.SetToken("test_token") - cfg.WebhookURL = "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test" - ch, _ := NewWeComBotChannel(cfg, msgBus) - - t.Run("GET request calls verification", func(t *testing.T) { - echostr := "test_echostr" - encoded := base64.StdEncoding.EncodeToString([]byte(echostr)) - timestamp := "1234567890" - nonce := "test_nonce" - signature := generateSignature("test_token", timestamp, nonce, encoded) - - req 
:= httptest.NewRequest( - http.MethodGet, - "/webhook/wecom?msg_signature="+signature+"×tamp="+timestamp+"&nonce="+nonce+"&echostr="+encoded, - nil, - ) - w := httptest.NewRecorder() - - ch.handleWebhook(w, req) - - if w.Code != http.StatusOK { - t.Errorf("status code = %d, want %d", w.Code, http.StatusOK) - } - }) - - t.Run("POST request calls message callback", func(t *testing.T) { - encryptedWrapper := struct { - XMLName xml.Name `xml:"xml"` - Encrypt string `xml:"Encrypt"` - }{ - Encrypt: base64.StdEncoding.EncodeToString([]byte("test")), - } - wrapperData, _ := xml.Marshal(encryptedWrapper) - - timestamp := "1234567890" - nonce := "test_nonce" - signature := generateSignature("test_token", timestamp, nonce, encryptedWrapper.Encrypt) - - req := httptest.NewRequest( - http.MethodPost, - "/webhook/wecom?msg_signature="+signature+"×tamp="+timestamp+"&nonce="+nonce, - bytes.NewReader(wrapperData), - ) - w := httptest.NewRecorder() - - ch.handleWebhook(w, req) - - // Should not be method not allowed - if w.Code == http.StatusMethodNotAllowed { - t.Error("POST request should not return Method Not Allowed") - } - }) - - t.Run("unsupported method", func(t *testing.T) { - req := httptest.NewRequest(http.MethodPut, "/webhook/wecom", nil) - w := httptest.NewRecorder() - - ch.handleWebhook(w, req) - - if w.Code != http.StatusMethodNotAllowed { - t.Errorf("status code = %d, want %d", w.Code, http.StatusMethodNotAllowed) - } - }) -} - -func TestWeComBotHandleHealth(t *testing.T) { - msgBus := bus.NewMessageBus() - cfg := config.WeComConfig{} - cfg.SetToken("test_token") - cfg.WebhookURL = "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test" - ch, _ := NewWeComBotChannel(cfg, msgBus) - - req := httptest.NewRequest(http.MethodGet, "/health/wecom", nil) - w := httptest.NewRecorder() - - ch.handleHealth(w, req) - - if w.Code != http.StatusOK { - t.Errorf("status code = %d, want %d", w.Code, http.StatusOK) - } - - contentType := w.Header().Get("Content-Type") - if 
contentType != "application/json" { - t.Errorf("Content-Type = %q, want %q", contentType, "application/json") - } - - body := w.Body.String() - if !strings.Contains(body, "status") || !strings.Contains(body, "running") { - t.Errorf("response body should contain status and running fields, got: %s", body) - } -} - -func TestWeComBotReplyMessage(t *testing.T) { - msg := WeComBotReplyMessage{ - MsgType: "text", - } - msg.Text.Content = "Hello World" - - if msg.MsgType != "text" { - t.Errorf("MsgType = %q, want %q", msg.MsgType, "text") - } - if msg.Text.Content != "Hello World" { - t.Errorf("Text.Content = %q, want %q", msg.Text.Content, "Hello World") - } -} - -func TestWeComBotMessageStructure(t *testing.T) { - jsonData := `{ - "msgid": "test_msg_id_123", - "aibotid": "test_aibot_id", - "chatid": "group_chat_id_123", - "chattype": "group", - "from": {"userid": "user123"}, - "response_url": "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test", - "msgtype": "text", - "text": {"content": "Hello World"} - }` - - var msg WeComBotMessage - err := json.Unmarshal([]byte(jsonData), &msg) - if err != nil { - t.Fatalf("failed to unmarshal JSON: %v", err) - } - - if msg.MsgID != "test_msg_id_123" { - t.Errorf("MsgID = %q, want %q", msg.MsgID, "test_msg_id_123") - } - if msg.AIBotID != "test_aibot_id" { - t.Errorf("AIBotID = %q, want %q", msg.AIBotID, "test_aibot_id") - } - if msg.ChatID != "group_chat_id_123" { - t.Errorf("ChatID = %q, want %q", msg.ChatID, "group_chat_id_123") - } - if msg.ChatType != "group" { - t.Errorf("ChatType = %q, want %q", msg.ChatType, "group") - } - if msg.From.UserID != "user123" { - t.Errorf("From.UserID = %q, want %q", msg.From.UserID, "user123") - } - if msg.MsgType != "text" { - t.Errorf("MsgType = %q, want %q", msg.MsgType, "text") - } - if msg.Text.Content != "Hello World" { - t.Errorf("Text.Content = %q, want %q", msg.Text.Content, "Hello World") - } -} diff --git a/pkg/channels/wecom/common.go b/pkg/channels/wecom/common.go deleted 
file mode 100644 index 9a622a2fc..000000000 --- a/pkg/channels/wecom/common.go +++ /dev/null @@ -1,199 +0,0 @@ -package wecom - -import ( - "bytes" - "crypto/aes" - "crypto/cipher" - "crypto/rand" - "crypto/sha1" - "encoding/base64" - "encoding/binary" - "fmt" - "math/big" - "sort" - "strings" -) - -// blockSize is the PKCS7 block size used by WeCom (32) -const blockSize = 32 - -// computeSignature computes the WeCom message signature from the given parameters. -// It sorts [token, timestamp, nonce, encrypt], concatenates them and returns the SHA1 hex digest. -func computeSignature(token, timestamp, nonce, encrypt string) string { - params := []string{token, timestamp, nonce, encrypt} - sort.Strings(params) - str := strings.Join(params, "") - hash := sha1.Sum([]byte(str)) - return fmt.Sprintf("%x", hash) -} - -// verifySignature verifies the message signature for WeCom -// This is a common function used by both WeCom Bot and WeCom App -func verifySignature(token, msgSignature, timestamp, nonce, msgEncrypt string) bool { - if token == "" { - return false - } - return computeSignature(token, timestamp, nonce, msgEncrypt) == msgSignature -} - -// decryptMessage decrypts the encrypted message using AES -// For AIBOT, receiveid should be the aibotid; for other apps, it should be corp_id -func decryptMessage(encryptedMsg, encodingAESKey string) (string, error) { - return decryptMessageWithVerify(encryptedMsg, encodingAESKey, "") -} - -// decryptMessageWithVerify decrypts the encrypted message and optionally verifies receiveid -// receiveid: for AIBOT use aibotid, for WeCom App use corp_id. If empty, skip verification. 
-func decryptMessageWithVerify(encryptedMsg, encodingAESKey, receiveid string) (string, error) { - if encodingAESKey == "" { - // No encryption, return as is (base64 decode) - decoded, err := base64.StdEncoding.DecodeString(encryptedMsg) - if err != nil { - return "", err - } - return string(decoded), nil - } - - aesKey, err := decodeWeComAESKey(encodingAESKey) - if err != nil { - return "", err - } - - cipherText, err := base64.StdEncoding.DecodeString(encryptedMsg) - if err != nil { - return "", fmt.Errorf("failed to decode message: %w", err) - } - - plainText, err := decryptAESCBC(aesKey, cipherText) - if err != nil { - return "", err - } - - return unpackWeComFrame(plainText, receiveid) -} - -// decodeWeComAESKey base64-decodes the 43-character EncodingAESKey (trailing "=" is -// appended automatically) and validates that the result is exactly 32 bytes. -// It is the single place that handles this repeated pattern in both encrypt and decrypt paths. -func decodeWeComAESKey(encodingAESKey string) ([]byte, error) { - aesKey, err := base64.StdEncoding.DecodeString(encodingAESKey + "=") - if err != nil { - return nil, fmt.Errorf("failed to decode AES key: %w", err) - } - if len(aesKey) != 32 { - return nil, fmt.Errorf("invalid AES key length: %d", len(aesKey)) - } - return aesKey, nil -} - -// encryptAESCBC encrypts plaintext using AES-CBC with the given key, mirroring -// decryptAESCBC. IV = aesKey[:aes.BlockSize]. The caller must PKCS7-pad the -// plaintext to a multiple of aes.BlockSize before calling. 
-func encryptAESCBC(aesKey, plaintext []byte) ([]byte, error) { - block, err := aes.NewCipher(aesKey) - if err != nil { - return nil, fmt.Errorf("failed to create cipher: %w", err) - } - iv := aesKey[:aes.BlockSize] - ciphertext := make([]byte, len(plaintext)) - cipher.NewCBCEncrypter(block, iv).CryptBlocks(ciphertext, plaintext) - return ciphertext, nil -} - -// packWeComFrame builds the WeCom wire format: -// -// random(16 ASCII digits) + msg_len(4, big-endian) + msg + receiveid -func packWeComFrame(msg, receiveid string) ([]byte, error) { - randomBytes := make([]byte, 16) - for i := range 16 { - n, err := rand.Int(rand.Reader, big.NewInt(10)) - if err != nil { - return nil, fmt.Errorf("failed to generate random: %w", err) - } - randomBytes[i] = byte('0' + n.Int64()) - } - msgBytes := []byte(msg) - msgLenBytes := make([]byte, 4) - binary.BigEndian.PutUint32(msgLenBytes, uint32(len(msgBytes))) - var buf bytes.Buffer - buf.Write(randomBytes) - buf.Write(msgLenBytes) - buf.Write(msgBytes) - buf.WriteString(receiveid) - return buf.Bytes(), nil -} - -// unpackWeComFrame parses the WeCom wire format produced by packWeComFrame. -// If receiveid is non-empty it verifies the frame's trailing receiveid field. -func unpackWeComFrame(data []byte, receiveid string) (string, error) { - if len(data) < 20 { - return "", fmt.Errorf("decrypted frame too short: %d bytes", len(data)) - } - msgLen := binary.BigEndian.Uint32(data[16:20]) - if int(msgLen) > len(data)-20 { - return "", fmt.Errorf("invalid message length: %d", msgLen) - } - msg := data[20 : 20+msgLen] - if receiveid != "" && len(data) > 20+int(msgLen) { - actualReceiveID := string(data[20+msgLen:]) - if actualReceiveID != receiveid { - return "", fmt.Errorf("receiveid mismatch: expected %s, got %s", receiveid, actualReceiveID) - } - } - return string(msg), nil -} - -// decryptAESCBC decrypts ciphertext using AES-CBC with the given key. -// IV = aesKey[:aes.BlockSize]. 
PKCS7 padding is stripped from the returned plaintext. -func decryptAESCBC(aesKey, ciphertext []byte) ([]byte, error) { - if len(ciphertext) == 0 { - return nil, fmt.Errorf("ciphertext is empty") - } - if len(ciphertext)%aes.BlockSize != 0 { - return nil, fmt.Errorf("ciphertext length %d is not a multiple of block size", len(ciphertext)) - } - block, err := aes.NewCipher(aesKey) - if err != nil { - return nil, fmt.Errorf("failed to create cipher: %w", err) - } - iv := aesKey[:aes.BlockSize] - plaintext := make([]byte, len(ciphertext)) - cipher.NewCBCDecrypter(block, iv).CryptBlocks(plaintext, ciphertext) - plaintext, err = pkcs7Unpad(plaintext) - if err != nil { - return nil, fmt.Errorf("failed to unpad: %w", err) - } - return plaintext, nil -} - -// pkcs7Pad adds PKCS7 padding -func pkcs7Pad(data []byte, blockSize int) []byte { - padding := blockSize - (len(data) % blockSize) - if padding == 0 { - padding = blockSize - } - padText := bytes.Repeat([]byte{byte(padding)}, padding) - return append(data, padText...) 
-} - -// pkcs7Unpad removes PKCS7 padding with validation -func pkcs7Unpad(data []byte) ([]byte, error) { - if len(data) == 0 { - return data, nil - } - padding := int(data[len(data)-1]) - // WeCom uses 32-byte block size for PKCS7 padding - if padding == 0 || padding > blockSize { - return nil, fmt.Errorf("invalid padding size: %d", padding) - } - if padding > len(data) { - return nil, fmt.Errorf("padding size larger than data") - } - // Verify all padding bytes - for i := range padding { - if data[len(data)-1-i] != byte(padding) { - return nil, fmt.Errorf("invalid padding byte at position %d", i) - } - } - return data[:len(data)-padding], nil -} diff --git a/pkg/channels/wecom/dedupe.go b/pkg/channels/wecom/dedupe.go deleted file mode 100644 index 865be668e..000000000 --- a/pkg/channels/wecom/dedupe.go +++ /dev/null @@ -1,54 +0,0 @@ -package wecom - -import "sync" - -const wecomMaxProcessedMessages = 1000 - -// MessageDeduplicator provides thread-safe message deduplication using a circular queue (ring buffer) -// combined with a hash map. This ensures fast O(1) lookups while naturally evicting the oldest -// messages without causing "amnesia cliffs" when the limit is reached. -type MessageDeduplicator struct { - mu sync.Mutex - msgs map[string]bool - ring []string - idx int - max int -} - -// NewMessageDeduplicator creates a new deduplicator with the specified capacity. -func NewMessageDeduplicator(maxEntries int) *MessageDeduplicator { - if maxEntries <= 0 { - maxEntries = wecomMaxProcessedMessages - } - return &MessageDeduplicator{ - msgs: make(map[string]bool, maxEntries), - ring: make([]string, maxEntries), - max: maxEntries, - } -} - -// MarkMessageProcessed marks msgID as processed and returns false for duplicates. -func (d *MessageDeduplicator) MarkMessageProcessed(msgID string) bool { - d.mu.Lock() - defer d.mu.Unlock() - - // 1. Check for duplicate - if d.msgs[msgID] { - return false - } - - // 2. 
Evict the oldest message at our current ring position (if any) - oldestID := d.ring[d.idx] - if oldestID != "" { - delete(d.msgs, oldestID) - } - - // 3. Store the new message - d.msgs[msgID] = true - d.ring[d.idx] = msgID - - // 4. Advance the circle queue index - d.idx = (d.idx + 1) % d.max - - return true -} diff --git a/pkg/channels/wecom/dedupe_test.go b/pkg/channels/wecom/dedupe_test.go deleted file mode 100644 index 10dff4cfe..000000000 --- a/pkg/channels/wecom/dedupe_test.go +++ /dev/null @@ -1,83 +0,0 @@ -package wecom - -import ( - "sync" - "testing" -) - -func TestMessageDeduplicator_DuplicateDetection(t *testing.T) { - d := NewMessageDeduplicator(wecomMaxProcessedMessages) - - if ok := d.MarkMessageProcessed("msg-1"); !ok { - t.Fatalf("first message should be accepted") - } - - if ok := d.MarkMessageProcessed("msg-1"); ok { - t.Fatalf("duplicate message should be rejected") - } -} - -func TestMessageDeduplicator_ConcurrentSameMessage(t *testing.T) { - d := NewMessageDeduplicator(wecomMaxProcessedMessages) - - const goroutines = 64 - var wg sync.WaitGroup - wg.Add(goroutines) - - results := make(chan bool, goroutines) - for i := 0; i < goroutines; i++ { - go func() { - defer wg.Done() - results <- d.MarkMessageProcessed("msg-concurrent") - }() - } - - wg.Wait() - close(results) - - successes := 0 - for ok := range results { - if ok { - successes++ - } - } - - if successes != 1 { - t.Fatalf("expected exactly 1 successful mark, got %d", successes) - } -} - -func TestMessageDeduplicator_CircularQueueEviction(t *testing.T) { - // Create a deduplicator with a very small capacity to test eviction easily. - capacity := 3 - d := NewMessageDeduplicator(capacity) - - // Fill the queue. - d.MarkMessageProcessed("msg-1") - d.MarkMessageProcessed("msg-2") - d.MarkMessageProcessed("msg-3") - - // At this point, the queue is full. msg-1 is the oldest. 
- if len(d.msgs) != 3 { - t.Fatalf("expected map size to be 3, got %d", len(d.msgs)) - } - - // This should evict msg-1 and add msg-4. - if ok := d.MarkMessageProcessed("msg-4"); !ok { - t.Fatalf("msg-4 should be accepted") - } - - if len(d.msgs) != 3 { - t.Fatalf("expected map size to remain at max capacity (3), got %d", len(d.msgs)) - } - - // msg-1 should now be forgotten (evicted). - if ok := d.MarkMessageProcessed("msg-1"); !ok { - t.Fatalf("msg-1 should be accepted again because it was evicted") - } - - // msg-2 should have been evicted when we added msg-1 back. - if ok := d.MarkMessageProcessed("msg-2"); !ok { - t.Fatalf("msg-2 should be accepted again because it was evicted") - } -} diff --git a/pkg/channels/wecom/init.go b/pkg/channels/wecom/init.go index bc5a70fa3..3aad84d42 100644 --- a/pkg/channels/wecom/init.go +++ b/pkg/channels/wecom/init.go @@ -8,12 +8,6 @@ import ( func init() { channels.RegisterFactory("wecom", func(cfg *config.Config, b *bus.MessageBus) (channels.Channel, error) { - return NewWeComBotChannel(cfg.Channels.WeCom, b) - }) - channels.RegisterFactory("wecom_app", func(cfg *config.Config, b *bus.MessageBus) (channels.Channel, error) { - return NewWeComAppChannel(cfg.Channels.WeComApp, b) - }) - channels.RegisterFactory("wecom_aibot", func(cfg *config.Config, b *bus.MessageBus) (channels.Channel, error) { - return NewWeComAIBotChannel(cfg.Channels.WeComAIBot, b) + return NewChannel(cfg.Channels.WeCom, b) }) } diff --git a/pkg/channels/wecom/media.go b/pkg/channels/wecom/media.go new file mode 100644 index 000000000..defe226d4 --- /dev/null +++ b/pkg/channels/wecom/media.go @@ -0,0 +1,291 @@ +package wecom + +import ( + "context" + "crypto/aes" + "crypto/cipher" + "encoding/base64" + "fmt" + "io" + "mime" + "net/http" + "net/url" + "os" + "path/filepath" + "strings" + + "github.com/h2non/filetype" + + "github.com/sipeed/picoclaw/pkg/media" +) + +func decodeMediaAESKey(value string) ([]byte, error) { + if value == "" { + return nil, nil + 
} + key, err := base64.StdEncoding.DecodeString(value) + if err == nil && len(key) == 32 { + return key, nil + } + key, err = base64.StdEncoding.DecodeString(value + "=") + if err != nil { + return nil, fmt.Errorf("decode AES key: %w", err) + } + if len(key) != 32 { + return nil, fmt.Errorf("invalid AES key length %d", len(key)) + } + return key, nil +} + +func decryptAESCBC(key, ciphertext []byte) ([]byte, error) { + if len(ciphertext) == 0 { + return nil, fmt.Errorf("ciphertext is empty") + } + if len(ciphertext)%aes.BlockSize != 0 { + return nil, fmt.Errorf("ciphertext length %d is not a multiple of block size", len(ciphertext)) + } + block, err := aes.NewCipher(key) + if err != nil { + return nil, fmt.Errorf("create cipher: %w", err) + } + plaintext := make([]byte, len(ciphertext)) + iv := key[:aes.BlockSize] + cipher.NewCBCDecrypter(block, iv).CryptBlocks(plaintext, ciphertext) + return pkcs7Unpad(plaintext) +} + +func pkcs7Unpad(data []byte) ([]byte, error) { + if len(data) == 0 { + return nil, fmt.Errorf("empty plaintext") + } + padding := int(data[len(data)-1]) + if padding == 0 || padding > 32 || padding > len(data) { + return nil, fmt.Errorf("invalid padding size %d", padding) + } + for i := 0; i < padding; i++ { + if data[len(data)-1-i] != byte(padding) { + return nil, fmt.Errorf("invalid padding byte") + } + } + return data[:len(data)-padding], nil +} + +func inferMediaExt(contentType, fallback string) string { + contentType = normalizeWeComContentType(contentType) + switch contentType { + case "image/jpeg", "image/jpg": + return ".jpg" + case "image/png": + return ".png" + case "image/gif": + return ".gif" + case "image/webp": + return ".webp" + case "application/pdf": + return ".pdf" + case "video/mp4": + return ".mp4" + default: + return fallback + } +} + +func normalizeWeComContentType(value string) string { + value = strings.ToLower(strings.TrimSpace(value)) + if idx := strings.Index(value, ";"); idx >= 0 { + value = strings.TrimSpace(value[:idx]) 
+ } + return value +} + +func isGenericWeComContentType(value string) bool { + switch normalizeWeComContentType(value) { + case "", "application/octet-stream", "binary/octet-stream", "application/unknown", "application/binary": + return true + default: + return false + } +} + +func sanitizeWeComFilename(name string) string { + name = filepath.Base(strings.TrimSpace(name)) + if name == "." || name == "/" || name == "" { + return "" + } + return name +} + +func candidateWeComFilename(resourceURL, contentDisposition, fallbackName string) string { + if _, params, err := mime.ParseMediaType(contentDisposition); err == nil { + if name := sanitizeWeComFilename(params["filename"]); name != "" { + return name + } + if name := sanitizeWeComFilename(params["filename*"]); name != "" { + return name + } + } + + if parsed, err := url.Parse(resourceURL); err == nil { + query := parsed.Query() + for _, key := range []string{"filename", "file_name", "name"} { + if name := sanitizeWeComFilename(query.Get(key)); name != "" { + return name + } + } + if name := sanitizeWeComFilename(parsed.Path); name != "" { + return name + } + } + + return sanitizeWeComFilename(fallbackName) +} + +func detectWeComFiletype(data []byte) (string, string) { + kind, err := filetype.Match(data) + if err != nil || kind == filetype.Unknown { + return "", "" + } + ext := "" + if kind.Extension != "" { + ext = "." 
+ strings.ToLower(kind.Extension) + } + return normalizeWeComContentType(kind.MIME.Value), ext +} + +func detectWeComMediaMetadata(data []byte, fallbackName, fallbackContentType, resourceURL, contentDisposition string) (string, string) { + filename := candidateWeComFilename(resourceURL, contentDisposition, fallbackName) + if filename == "" { + filename = "media" + } + + ext := strings.ToLower(filepath.Ext(filename)) + contentType := normalizeWeComContentType(fallbackContentType) + detectedType, detectedExt := detectWeComFiletype(data) + + if ext != "" && isGenericWeComContentType(contentType) { + if byExt := normalizeWeComContentType(mime.TypeByExtension(ext)); byExt != "" { + contentType = byExt + } + } + + if detectedType != "" { + switch { + case contentType == "": + contentType = detectedType + case isGenericWeComContentType(contentType): + contentType = detectedType + case strings.HasPrefix(detectedType, "image/") && !strings.HasPrefix(contentType, "image/"): + contentType = detectedType + case strings.HasPrefix(detectedType, "audio/") && !strings.HasPrefix(contentType, "audio/"): + contentType = detectedType + case strings.HasPrefix(detectedType, "video/") && !strings.HasPrefix(contentType, "video/"): + contentType = detectedType + } + } + + if contentType == "" && ext != "" { + contentType = normalizeWeComContentType(mime.TypeByExtension(ext)) + } + if contentType == "" { + contentType = normalizeWeComContentType(http.DetectContentType(data)) + } + + if ext == "" { + ext = detectedExt + } + if ext == "" && contentType != "" { + if exts, err := mime.ExtensionsByType(contentType); err == nil && len(exts) > 0 { + ext = strings.ToLower(exts[0]) + } + } + + if filepath.Ext(filename) == "" && ext != "" { + filename += ext + } + return filename, contentType +} + +func (c *WeComChannel) storeRemoteMedia( + ctx context.Context, + scope, msgID, resourceURL, aesKey, fallbackExt string, +) (string, error) { + store := c.GetMediaStore() + if store == nil { + return "", 
fmt.Errorf("no media store available") + } + + req, err := http.NewRequestWithContext(ctx, http.MethodGet, resourceURL, nil) + if err != nil { + return "", fmt.Errorf("create request: %w", err) + } + resp, err := c.mediaClient.Do(req) + if err != nil { + return "", fmt.Errorf("download media: %w", err) + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + return "", fmt.Errorf("download media returned HTTP %d", resp.StatusCode) + } + + const maxSize = 20 << 20 + data, err := io.ReadAll(io.LimitReader(resp.Body, maxSize+1)) + if err != nil { + return "", fmt.Errorf("read media: %w", err) + } + if len(data) > maxSize { + return "", fmt.Errorf("media too large") + } + + if aesKey != "" { + key, keyErr := decodeMediaAESKey(aesKey) + if keyErr != nil { + return "", keyErr + } + data, err = decryptAESCBC(key, data) + if err != nil { + return "", fmt.Errorf("decrypt media: %w", err) + } + } + + filename, contentType := detectWeComMediaMetadata( + data, + msgID+fallbackExt, + resp.Header.Get("Content-Type"), + resourceURL, + resp.Header.Get("Content-Disposition"), + ) + ext := filepath.Ext(filename) + if ext == "" { + ext = inferMediaExt(contentType, fallbackExt) + } + mediaDir := filepath.Join(os.TempDir(), "picoclaw_media") + if mkdirErr := os.MkdirAll(mediaDir, 0o700); mkdirErr != nil { + return "", fmt.Errorf("mkdir media dir: %w", mkdirErr) + } + tmpFile, err := os.CreateTemp(mediaDir, msgID+"-*"+ext) + if err != nil { + return "", fmt.Errorf("create temp file: %w", err) + } + tmpPath := tmpFile.Name() + if _, writeErr := tmpFile.Write(data); writeErr != nil { + tmpFile.Close() + _ = os.Remove(tmpPath) + return "", fmt.Errorf("write temp file: %w", writeErr) + } + if closeErr := tmpFile.Close(); closeErr != nil { + _ = os.Remove(tmpPath) + return "", fmt.Errorf("close temp file: %w", closeErr) + } + + ref, err := store.Store(tmpPath, media.MediaMeta{ + Filename: filename, + ContentType: contentType, + Source: "wecom", + CleanupPolicy: 
media.CleanupPolicyDeleteOnCleanup, + }, scope) + if err != nil { + _ = os.Remove(tmpPath) + return "", err + } + return ref, nil +} diff --git a/pkg/channels/wecom/media_test.go b/pkg/channels/wecom/media_test.go new file mode 100644 index 000000000..d5307e5d2 --- /dev/null +++ b/pkg/channels/wecom/media_test.go @@ -0,0 +1,180 @@ +package wecom + +import ( + "bytes" + "context" + "encoding/base64" + "io" + "net/http" + "strings" + "testing" + + basechannels "github.com/sipeed/picoclaw/pkg/channels" + "github.com/sipeed/picoclaw/pkg/media" +) + +func TestStoreRemoteMedia_DetectsJPEGContentTypeFromBody(t *testing.T) { + t.Parallel() + + const jpegBase64 = "/9j/4AAQSkZJRgABAQAAAQABAAD/2wBDAP//////////////////////////////////////////////////////////////////////////////////////" + + "//////////////////////////////////////////////////////////////////////////////////////////////2wBDAf//////////////////////////////////////////////////////////////////////////////////////" + + "//////////////////////////////////////////////////////////////////////////////////////////////wAARCAABAAEDASIAAhEBAxEB/8QAFQABAQAAAAAAAAAAAAAAAAAAAAb/xAAVEQEBAAAAAAAAAAAAAAAAAAAABf/aAAwDAQACEAMQAAAB6A//xAAVEAEBAAAAAAAAAAAAAAAAAAAAEf/aAAgBAQABBQJf/8QAFBEBAAAAAAAAAAAAAAAAAAAAEP/aAAgBAwEBPwF//8QAFBEBAAAAAAAAAAAAAAAAAAAAEP/aAAgBAgEBPwF//8QAFBABAAAAAAAAAAAAAAAAAAAAEP/aAAgBAQAGPwJf/8QAFBABAAAAAAAAAAAAAAAAAAAAEP/aAAgBAQABPyFf/9k=" + + jpegData := decodeTestBase64(t, jpegBase64) + store := media.NewFileMediaStore() + ch := &WeComChannel{ + BaseChannel: basechannels.NewBaseChannel("wecom", nil, nil, nil), + mediaClient: &http.Client{ + Transport: roundTripFunc(func(*http.Request) (*http.Response, error) { + return &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"application/octet-stream"}}, + Body: io.NopCloser(bytes.NewReader(jpegData)), + }, nil + }), + }, + } + ch.SetMediaStore(store) + + ref, err := ch.storeRemoteMedia(context.Background(), "test-scope", 
"msg-1", "https://wecom.example/media", "", "") + if err != nil { + t.Fatalf("storeRemoteMedia returned error: %v", err) + } + t.Cleanup(func() { + _ = store.ReleaseAll("test-scope") + }) + + _, meta, err := store.ResolveWithMeta(ref) + if err != nil { + t.Fatalf("resolve media ref: %v", err) + } + if meta.ContentType != "image/jpeg" { + t.Fatalf("expected image/jpeg content type, got %q", meta.ContentType) + } + if !strings.HasSuffix(meta.Filename, ".jpg") && !strings.HasSuffix(meta.Filename, ".jpeg") { + t.Fatalf("expected jpeg filename, got %q", meta.Filename) + } +} + +func TestDetectWeComMediaMetadata_UsesFallbackExtensionWhenBodyUnknown(t *testing.T) { + t.Parallel() + + filename, contentType := detectWeComMediaMetadata([]byte("not a real image"), "msg-2.pdf", "", "", "") + if filename != "msg-2.pdf" { + t.Fatalf("expected fallback filename to be preserved, got %q", filename) + } + if contentType != "application/pdf" { + t.Fatalf("expected application/pdf from fallback extension, got %q", contentType) + } +} + +func TestStoreRemoteMedia_PreservesSuffixFromURL(t *testing.T) { + t.Parallel() + + docxLikeData := []byte("PK\x03\x04fake office payload") + store := media.NewFileMediaStore() + ch := &WeComChannel{ + BaseChannel: basechannels.NewBaseChannel("wecom", nil, nil, nil), + mediaClient: &http.Client{ + Transport: roundTripFunc(func(*http.Request) (*http.Response, error) { + return &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"application/octet-stream"}}, + Body: io.NopCloser(bytes.NewReader(docxLikeData)), + }, nil + }), + }, + } + ch.SetMediaStore(store) + + ref, err := ch.storeRemoteMedia( + context.Background(), + "test-scope", + "msg-docx", + "https://wecom.example/media/report.docx?signature=1", + "", + ".bin", + ) + if err != nil { + t.Fatalf("storeRemoteMedia returned error: %v", err) + } + t.Cleanup(func() { + _ = store.ReleaseAll("test-scope") + }) + + localPath, meta, err := store.ResolveWithMeta(ref) 
// decodeTestBase64 decodes a standard-encoding base64 test fixture and fails
// the calling test immediately when the fixture is malformed.
//
// The fixture is a complete in-memory string, so it is decoded in one call
// rather than through a streaming decoder.
func decodeTestBase64(t *testing.T, value string) []byte {
	t.Helper()

	data, err := base64.StdEncoding.DecodeString(value)
	if err != nil {
		t.Fatalf("decode base64 fixture: %v", err)
	}
	return data
}
// wecomEnvelope is the JSON frame decoded from every message read off the
// WeCom WebSocket connection.
type wecomEnvelope struct {
	// Cmd names the server-initiated command (for example a message or event
	// callback). The read loop treats a frame with an empty Cmd and a non-empty
	// Headers.ReqID as the acknowledgement of a pending outbound command.
	Cmd string `json:"cmd,omitempty"`
	// Headers carries the request ID used to correlate commands and acks.
	Headers wecomHeaders `json:"headers"`
	// Body is kept raw so each handler can decode it into its own type.
	Body json.RawMessage `json:"body,omitempty"`
	// ErrCode is non-zero when the frame reports a failure; ErrMsg is the
	// accompanying description.
	ErrCode int    `json:"errcode,omitempty"`
	ErrMsg  string `json:"errmsg,omitempty"`
}
// wecomRoute records the most recent request ID seen for a chat so that a
// later outbound message can be routed back to it until ExpiresAt.
type wecomRoute struct {
	ReqID     string    `json:"req_id"`
	ChatID    string    `json:"chat_id"`
	ChatType  uint32    `json:"chat_type"`
	ExpiresAt time.Time `json:"expires_at"`
}

// reqIDStore is a small mutex-guarded, JSON-file-backed map from chat ID to
// its latest wecomRoute.
type reqIDStore struct {
	mu     sync.Mutex
	path   string
	routes map[string]wecomRoute
}

// newReqIDStore opens (or lazily creates) the store persisted at path. An empty
// path selects defaultReqIDStorePath.
func newReqIDStore(path string) *reqIDStore {
	if path == "" {
		path = defaultReqIDStorePath()
	}
	s := &reqIDStore{
		path:   path,
		routes: make(map[string]wecomRoute),
	}
	// Best effort: an unreadable or corrupt file must not prevent the channel
	// from starting, so the store simply begins empty in that case.
	_ = s.load()
	return s
}

// defaultReqIDStorePath returns the store location under the user's home
// directory, falling back to the OS temp directory when no home is available.
func defaultReqIDStorePath() string {
	if home, err := os.UserHomeDir(); err == nil && home != "" {
		return filepath.Join(home, ".picoclaw", "wecom", "reqid-store.json")
	}
	return filepath.Join(os.TempDir(), "picoclaw-wecom-reqid-store.json")
}

// Put stores the route for chatID with the given time-to-live and persists the
// whole table. Calls with an empty reqID or chatID are ignored.
func (s *reqIDStore) Put(chatID, reqID string, chatType uint32, ttl time.Duration) error {
	if reqID == "" || chatID == "" {
		return nil
	}
	s.mu.Lock()
	defer s.mu.Unlock()

	now := time.Now()
	s.deleteExpiredLocked(now)
	s.routes[chatID] = wecomRoute{
		ReqID:     reqID,
		ChatID:    chatID,
		ChatType:  chatType,
		ExpiresAt: now.Add(ttl),
	}
	return s.saveLocked()
}

// Get returns the unexpired route for chatID, if any.
func (s *reqIDStore) Get(chatID string) (wecomRoute, bool) {
	s.mu.Lock()
	defer s.mu.Unlock()
	s.deleteExpiredLocked(time.Now())
	route, ok := s.routes[chatID]
	return route, ok
}

// Delete removes the route for chatID and persists the table.
func (s *reqIDStore) Delete(chatID string) error {
	s.mu.Lock()
	defer s.mu.Unlock()
	delete(s.routes, chatID)
	return s.saveLocked()
}

// load replaces the in-memory table with the contents of the backing file.
// A missing file is not an error. On a read or decode error the current table
// is left untouched.
func (s *reqIDStore) load() error {
	s.mu.Lock()
	defer s.mu.Unlock()

	data, err := os.ReadFile(s.path)
	if err != nil {
		if errors.Is(err, os.ErrNotExist) {
			return nil
		}
		return err
	}

	var routes map[string]wecomRoute
	if err := json.Unmarshal(data, &routes); err != nil {
		return err
	}
	// A file containing JSON `null` decodes successfully into a nil map;
	// writing to it later would panic, so normalize to an empty map.
	if routes == nil {
		routes = make(map[string]wecomRoute)
	}
	s.routes = routes
	s.deleteExpiredLocked(time.Now())
	return nil
}

// deleteExpiredLocked drops every route whose expiry is set and lies before
// now. The caller must hold s.mu.
func (s *reqIDStore) deleteExpiredLocked(now time.Time) {
	for chatID, route := range s.routes {
		if !route.ExpiresAt.IsZero() && now.After(route.ExpiresAt) {
			delete(s.routes, chatID)
		}
	}
}

// saveLocked writes the table to s.path. The data is written to a temporary
// file in the same directory and renamed into place so that an interrupted
// write cannot leave a truncated store behind. The caller must hold s.mu.
func (s *reqIDStore) saveLocked() error {
	dir := filepath.Dir(s.path)
	if err := os.MkdirAll(dir, 0o700); err != nil {
		return err
	}
	data, err := json.MarshalIndent(s.routes, "", "  ")
	if err != nil {
		return err
	}

	// os.CreateTemp creates the file with mode 0600, matching the store's
	// intended permissions.
	tmp, err := os.CreateTemp(dir, filepath.Base(s.path)+".*.tmp")
	if err != nil {
		return err
	}
	tmpPath := tmp.Name()
	if _, err := tmp.Write(data); err != nil {
		_ = tmp.Close()         // already failing; the write error is the one to report
		_ = os.Remove(tmpPath) // best-effort cleanup
		return err
	}
	if err := tmp.Close(); err != nil {
		_ = os.Remove(tmpPath) // best-effort cleanup
		return err
	}
	if err := os.Rename(tmpPath, s.path); err != nil {
		_ = os.Remove(tmpPath) // best-effort cleanup
		return err
	}
	return nil
}
a/pkg/channels/wecom/reqid_store_test.go b/pkg/channels/wecom/reqid_store_test.go new file mode 100644 index 000000000..e68e82500 --- /dev/null +++ b/pkg/channels/wecom/reqid_store_test.go @@ -0,0 +1,24 @@ +package wecom + +import ( + "path/filepath" + "testing" + "time" +) + +func TestReqIDStorePersistsRoutes(t *testing.T) { + storePath := filepath.Join(t.TempDir(), "reqids.json") + store := newReqIDStore(storePath) + if err := store.Put("chat-1", "req-1", 2, time.Hour); err != nil { + t.Fatalf("Put() error = %v", err) + } + + reloaded := newReqIDStore(storePath) + route, ok := reloaded.Get("chat-1") + if !ok { + t.Fatal("expected persisted route to be loaded") + } + if route.ChatID != "chat-1" || route.ReqID != "req-1" || route.ChatType != 2 { + t.Fatalf("loaded route = %+v", route) + } +} diff --git a/pkg/channels/wecom/wecom.go b/pkg/channels/wecom/wecom.go new file mode 100644 index 000000000..11959c259 --- /dev/null +++ b/pkg/channels/wecom/wecom.go @@ -0,0 +1,777 @@ +package wecom + +import ( + "context" + "crypto/rand" + "encoding/json" + "fmt" + "math/big" + "net/http" + "strings" + "sync" + "time" + + "github.com/gorilla/websocket" + + "github.com/sipeed/picoclaw/pkg/bus" + "github.com/sipeed/picoclaw/pkg/channels" + "github.com/sipeed/picoclaw/pkg/config" + "github.com/sipeed/picoclaw/pkg/identity" + "github.com/sipeed/picoclaw/pkg/logger" +) + +const ( + wecomConnectTimeout = 15 * time.Second + wecomCommandTimeout = 10 * time.Second + wecomHeartbeatInterval = 30 * time.Second + wecomStreamMaxDuration = 5*time.Minute + 30*time.Second + wecomRouteTTL = 30 * time.Minute + wecomMediaTimeout = 30 * time.Second + wecomRecentMessageMax = 1000 +) + +type WeComChannel struct { + *channels.BaseChannel + config config.WeComConfig + + ctx context.Context + cancel context.CancelFunc + + conn *websocket.Conn + connMu sync.Mutex + + pendingMu sync.Mutex + pending map[string]chan wecomEnvelope + + turnsMu sync.Mutex + turns map[string][]wecomTurn + + recent 
*recentMessageSet + routes *reqIDStore + mediaClient *http.Client + commandSend func(wecomCommand, time.Duration) error +} + +type wecomTurn struct { + ReqID string + ChatID string + ChatType uint32 + StreamID string + CreatedAt time.Time +} + +type recentMessageSet struct { + mu sync.Mutex + seen map[string]struct{} + ring []string + idx int +} + +func newRecentMessageSet(capacity int) *recentMessageSet { + if capacity <= 0 { + capacity = wecomRecentMessageMax + } + return &recentMessageSet{ + seen: make(map[string]struct{}, capacity), + ring: make([]string, capacity), + } +} + +func (s *recentMessageSet) Mark(id string) bool { + if id == "" { + return true + } + s.mu.Lock() + defer s.mu.Unlock() + if _, ok := s.seen[id]; ok { + return false + } + if old := s.ring[s.idx]; old != "" { + delete(s.seen, old) + } + s.ring[s.idx] = id + s.idx = (s.idx + 1) % len(s.ring) + s.seen[id] = struct{}{} + return true +} + +func NewChannel(cfg config.WeComConfig, messageBus *bus.MessageBus) (*WeComChannel, error) { + if cfg.BotID == "" || cfg.Secret() == "" { + return nil, fmt.Errorf("wecom bot_id and secret are required") + } + if cfg.WebSocketURL == "" { + cfg.WebSocketURL = wecomDefaultWebSocketURL + } + + base := channels.NewBaseChannel( + "wecom", + cfg, + messageBus, + cfg.AllowFrom, + channels.WithMaxMessageLength(wecomMaxContentBytes), + channels.WithReasoningChannelID(cfg.ReasoningChannelID), + ) + + ch := &WeComChannel{ + BaseChannel: base, + config: cfg, + pending: make(map[string]chan wecomEnvelope), + turns: make(map[string][]wecomTurn), + recent: newRecentMessageSet(wecomRecentMessageMax), + routes: newReqIDStore(""), + mediaClient: &http.Client{Timeout: wecomMediaTimeout}, + } + ch.SetOwner(ch) + return ch, nil +} + +func (c *WeComChannel) Name() string { return "wecom" } + +func (c *WeComChannel) Start(ctx context.Context) error { + logger.InfoC("wecom", "Starting WeCom channel...") + c.ctx, c.cancel = context.WithCancel(ctx) + c.SetRunning(true) + go 
c.connectLoop() + return nil +} + +func (c *WeComChannel) Stop(_ context.Context) error { + logger.InfoC("wecom", "Stopping WeCom channel...") + if c.cancel != nil { + c.cancel() + } + c.connMu.Lock() + if c.conn != nil { + _ = c.conn.Close() + c.conn = nil + } + c.connMu.Unlock() + c.clearTurns() + c.SetRunning(false) + return nil +} + +func (c *WeComChannel) Send(ctx context.Context, msg bus.OutboundMessage) error { + if !c.IsRunning() { + return channels.ErrNotRunning + } + content := strings.TrimSpace(msg.Content) + if content == "" { + return nil + } + + if turn, ok := c.getTurn(msg.ChatID); ok { + if time.Since(turn.CreatedAt) <= wecomStreamMaxDuration { + if err := c.sendStreamReply(turn, content); err == nil { + c.deleteTurn(msg.ChatID) + return nil + } + } + c.deleteTurn(msg.ChatID) + } + + if route, ok := c.routes.Get(msg.ChatID); ok { + if err := c.sendActivePush(route.ChatID, route.ChatType, content); err != nil { + return err + } + return nil + } + + if err := c.sendActivePush(msg.ChatID, 0, content); err != nil { + return err + } + return nil +} + +func (c *WeComChannel) SendMedia(ctx context.Context, msg bus.OutboundMediaMessage) error { + if !c.IsRunning() { + return channels.ErrNotRunning + } + var parts []string + for _, part := range msg.Parts { + switch { + case part.Caption != "": + parts = append(parts, part.Caption) + case part.Filename != "": + parts = append(parts, fmt.Sprintf("[media: %s]", part.Filename)) + default: + parts = append(parts, "[media attachments are not yet supported]") + } + } + return c.Send(ctx, bus.OutboundMessage{ + Channel: msg.Channel, + ChatID: msg.ChatID, + Content: strings.Join(parts, "\n"), + }) +} + +func (c *WeComChannel) connectLoop() { + backoff := time.Second + for { + select { + case <-c.ctx.Done(): + return + default: + } + + if err := c.runConnection(); err != nil { + logger.WarnCF("wecom", "WeCom connection lost", map[string]any{ + "error": err.Error(), + "backoff": backoff.String(), + }) + select { + 
// runConnection dials the WeCom WebSocket, subscribes the bot, and then blocks
// until the read loop ends. It returns the read loop's error (nil when the
// channel context was cancelled), or the dial/subscribe error wrapped or
// returned as-is.
func (c *WeComChannel) runConnection() error {
	// Bound the dial so a hung handshake cannot stall the reconnect loop.
	dialCtx, cancel := context.WithTimeout(c.ctx, wecomConnectTimeout)
	defer cancel()

	conn, resp, err := websocket.DefaultDialer.DialContext(dialCtx, c.config.WebSocketURL, nil)
	// The handshake response body is not needed; close it regardless of outcome.
	if resp != nil {
		_ = resp.Body.Close()
	}
	if err != nil {
		return fmt.Errorf("%w: %v", channels.ErrTemporary, err)
	}

	// Publish the connection so writeCurrent can use it.
	c.connMu.Lock()
	c.conn = conn
	c.connMu.Unlock()
	defer func() {
		c.connMu.Lock()
		// Only clear the shared pointer if it still refers to this connection.
		if c.conn == conn {
			c.conn = nil
		}
		c.connMu.Unlock()
		_ = conn.Close()
		// Queued turns are dropped together with the connection.
		c.clearTurns()
	}()

	// The read loop must be running before the subscribe command is written,
	// because writeAndWait relies on it to deliver the acknowledgement.
	// The channel is buffered so the goroutine can exit even if nobody reads.
	readErrCh := make(chan error, 1)
	go func() {
		readErrCh <- c.readLoop(conn)
	}()

	if writeErr := c.writeAndWait(conn, wecomCommand{
		Cmd:     wecomCmdSubscribe,
		Headers: wecomHeaders{ReqID: randomID(10)},
		Body: map[string]string{
			"bot_id": c.config.BotID,
			"secret": c.config.Secret(),
		},
	}, wecomCommandTimeout); writeErr != nil {
		// The deferred Close above unblocks the read loop goroutine.
		return writeErr
	}

	heartbeatDone := make(chan struct{})
	go func() {
		defer close(heartbeatDone)
		c.heartbeatLoop(conn)
	}()

	err = <-readErrCh
	_ = conn.Close()
	// NOTE(review): heartbeatLoop only returns after a failed ping or context
	// cancellation, so this wait can apparently last until the next heartbeat
	// tick fails on the closed connection, delaying reconnect — confirm this
	// is acceptable.
	<-heartbeatDone
	return err
}
fmt.Errorf("%w: %v", channels.ErrTemporary, err) + } + } + + var env wecomEnvelope + if err := json.Unmarshal(raw, &env); err != nil { + logger.WarnCF("wecom", "Failed to parse WebSocket message", map[string]any{"error": err.Error()}) + continue + } + + if env.Cmd == "" && env.Headers.ReqID != "" { + c.pendingMu.Lock() + ch, ok := c.pending[env.Headers.ReqID] + if ok { + delete(c.pending, env.Headers.ReqID) + } + c.pendingMu.Unlock() + if ok { + ch <- env + } + continue + } + + go c.handleEnvelope(env) + } +} + +func (c *WeComChannel) handleEnvelope(env wecomEnvelope) { + switch env.Cmd { + case wecomCmdMsgCallback: + c.handleMessageCallback(env) + case wecomCmdEventCallback: + c.handleEventCallback(env) + default: + logger.DebugCF("wecom", "Ignoring unsupported WeCom command", map[string]any{"cmd": env.Cmd}) + } +} + +func (c *WeComChannel) handleEventCallback(env wecomEnvelope) { + var msg wecomIncomingMessage + if err := json.Unmarshal(env.Body, &msg); err != nil { + logger.WarnCF("wecom", "Failed to parse WeCom event callback", map[string]any{"error": err.Error()}) + } +} + +func (c *WeComChannel) handleMessageCallback(env wecomEnvelope) { + var msg wecomIncomingMessage + if err := json.Unmarshal(env.Body, &msg); err != nil { + logger.WarnCF("wecom", "Failed to parse WeCom message callback", map[string]any{"error": err.Error()}) + return + } + if !c.recent.Mark(msg.MsgID) { + return + } + + reqID := env.Headers.ReqID + if reqID == "" { + logger.WarnC("wecom", "WeCom message callback missing req_id") + return + } + if msg.Event != nil && msg.Event.EventType != "" { + return + } + + if err := c.dispatchIncoming(reqID, msg); err != nil { + logger.WarnCF("wecom", "Failed to dispatch WeCom message", map[string]any{ + "req_id": reqID, + "error": err.Error(), + }) + _ = c.respondImmediate(reqID, "The WeCom message could not be processed.") + } +} + +func (c *WeComChannel) dispatchIncoming(reqID string, msg wecomIncomingMessage) error { + senderID := msg.From.UserID + 
if senderID == "" { + senderID = "unknown" + } + actualChatID := incomingChatID(msg) + chatType := incomingChatTypeCode(msg.ChatType) + peerKind := "direct" + if msg.ChatType == "group" { + peerKind = "group" + } + + sender := bus.SenderInfo{ + Platform: "wecom", + PlatformID: senderID, + CanonicalID: identity.BuildCanonicalID("wecom", senderID), + DisplayName: senderID, + } + + var ( + content string + quoteText string + mediaRefs []string + err error + ) + scope := channels.BuildMediaScope("wecom", actualChatID, msg.MsgID) + switch msg.MsgType { + case "text": + if msg.Text != nil { + content = strings.TrimSpace(msg.Text.Content) + } + case "voice": + if msg.Voice != nil { + content = strings.TrimSpace(msg.Voice.Content) + } + case "image": + content = "[image]" + mediaRefs, err = c.collectSingleMedia(c.ctx, scope, msg.MsgID, &mediaPayload{ + url: msg.Image.URL, + aesKey: msg.Image.AESKey, + }, "image", ".jpg") + case "file": + content = "[file]" + mediaRefs, err = c.collectSingleMedia(c.ctx, scope, msg.MsgID, &mediaPayload{ + url: msg.File.URL, + aesKey: msg.File.AESKey, + }, "file", ".bin") + case "video": + content = "[video]" + mediaRefs, err = c.collectSingleMedia(c.ctx, scope, msg.MsgID, &mediaPayload{ + url: msg.Video.URL, + aesKey: msg.Video.AESKey, + }, "video", ".mp4") + case "mixed": + content, mediaRefs, err = c.collectMixedMedia(c.ctx, scope, msg) + default: + return c.respondImmediate(reqID, "Unsupported WeCom message type: "+msg.MsgType) + } + if err != nil { + return err + } + if msg.Quote != nil && msg.Quote.Text != nil { + quoteText = strings.TrimSpace(msg.Quote.Text.Content) + if content == "" { + content = quoteText + } + } + if content == "" && len(mediaRefs) == 0 { + return c.respondImmediate(reqID, "The WeCom message did not contain usable content.") + } + + turn := wecomTurn{ + ReqID: reqID, + ChatID: actualChatID, + ChatType: chatType, + StreamID: randomID(10), + CreatedAt: time.Now(), + } + c.queueTurn(actualChatID, turn) + if err := 
c.routes.Put(actualChatID, reqID, chatType, wecomRouteTTL); err != nil { + logger.WarnCF("wecom", "Failed to persist req_id route", map[string]any{ + "chat_id": actualChatID, + "req_id": reqID, + "error": err.Error(), + }) + } + + opening := "" + if c.config.SendThinkingMessage { + opening = "Processing..." + } + if err := c.sendStreamChunk(turn, false, opening); err != nil { + return err + } + + peer := bus.Peer{Kind: peerKind, ID: actualChatID} + metadata := map[string]string{ + "channel": "wecom", + "req_id": reqID, + "chat_id": actualChatID, + "chat_type": msg.ChatType, + "msg_id": msg.MsgID, + "msg_type": msg.MsgType, + } + if quoteText != "" { + metadata["quote_text"] = quoteText + } + + c.HandleMessage(c.ctx, peer, msg.MsgID, senderID, actualChatID, content, mediaRefs, metadata, sender) + return nil +} + +func (c *WeComChannel) collectSingleMedia( + ctx context.Context, + scope, msgID string, + payload interface { + GetURL() string + GetAESKey() string + }, + label, fallbackExt string, +) ([]string, error) { + if payload == nil || payload.GetURL() == "" { + return nil, fmt.Errorf("%s payload is empty", label) + } + ref, err := c.storeRemoteMedia(ctx, scope, msgID, payload.GetURL(), payload.GetAESKey(), fallbackExt) + if err != nil { + return nil, err + } + return []string{ref}, nil +} + +type mediaPayload struct { + url string + aesKey string +} + +func (p *mediaPayload) GetURL() string { return p.url } +func (p *mediaPayload) GetAESKey() string { return p.aesKey } + +func (c *WeComChannel) collectMixedMedia( + ctx context.Context, + scope string, + msg wecomIncomingMessage, +) (string, []string, error) { + if msg.Mixed == nil { + return "", nil, fmt.Errorf("mixed message is empty") + } + + var textParts []string + var refs []string + for idx, item := range msg.Mixed.MsgItem { + switch item.MsgType { + case "text": + if item.Text != nil && strings.TrimSpace(item.Text.Content) != "" { + textParts = append(textParts, strings.TrimSpace(item.Text.Content)) + } + 
case "image": + if item.Image != nil && item.Image.URL != "" { + ref, err := c.storeRemoteMedia( + ctx, + scope, + fmt.Sprintf("%s-%d", msg.MsgID, idx), + item.Image.URL, + item.Image.AESKey, + ".jpg", + ) + if err != nil { + return "", nil, err + } + refs = append(refs, ref) + } + case "file": + if item.File != nil && item.File.URL != "" { + ref, err := c.storeRemoteMedia( + ctx, + scope, + fmt.Sprintf("%s-%d", msg.MsgID, idx), + item.File.URL, + item.File.AESKey, + ".bin", + ) + if err != nil { + return "", nil, err + } + refs = append(refs, ref) + } + } + } + + content := strings.Join(textParts, "\n") + if content == "" && len(refs) > 0 { + content = "[media]" + } + return content, refs, nil +} + +func (c *WeComChannel) respondImmediate(reqID, content string) error { + turn := wecomTurn{ + ReqID: reqID, + StreamID: randomID(10), + CreatedAt: time.Now(), + } + return c.sendStreamChunk(turn, true, content) +} + +func (c *WeComChannel) sendStreamReply(turn wecomTurn, content string) error { + chunks := splitContent(content, wecomMaxContentBytes) + for idx, chunk := range chunks { + if err := c.sendStreamChunk(turn, idx == len(chunks)-1, chunk); err != nil { + return err + } + } + return nil +} + +func (c *WeComChannel) sendStreamChunk(turn wecomTurn, finish bool, content string) error { + return c.sendCommand(wecomCommand{ + Cmd: wecomCmdRespondMsg, + Headers: wecomHeaders{ReqID: turn.ReqID}, + Body: wecomRespondMsgBody{ + MsgType: "stream", + Stream: &wecomStreamContent{ + ID: turn.StreamID, + Finish: finish, + Content: content, + }, + }, + }, wecomCommandTimeout) +} + +func (c *WeComChannel) sendActivePush(chatID string, chatType uint32, content string) error { + if strings.TrimSpace(chatID) == "" { + return fmt.Errorf("empty chat ID: %w", channels.ErrSendFailed) + } + for _, chunk := range splitContent(content, wecomMaxContentBytes) { + if err := c.sendCommand(wecomCommand{ + Cmd: wecomCmdSendMsg, + Headers: wecomHeaders{ReqID: randomID(10)}, + Body: 
// writeAndWait writes cmd to conn as a JSON text frame and blocks until the
// matching acknowledgement arrives, timeout elapses, or the channel context is
// cancelled.
//
// An empty request ID is filled with a random one. It returns nil on an
// acknowledgement with a zero error code; a marshalling failure is wrapped in
// channels.ErrSendFailed; write failures, non-zero error codes and timeouts
// are wrapped in channels.ErrTemporary; cancellation returns the context error.
func (c *WeComChannel) writeAndWait(conn *websocket.Conn, cmd wecomCommand, timeout time.Duration) error {
	if cmd.Headers.ReqID == "" {
		cmd.Headers.ReqID = randomID(10)
	}
	// Register the waiter before writing so the acknowledgement cannot be
	// missed. The buffer of one lets the read loop hand over the ack without
	// blocking even if this call has already given up.
	// NOTE(review): pending is keyed by request ID, so two in-flight commands
	// sharing an ID would overwrite each other's waiter — verify callers never
	// overlap on the same ID.
	waitCh := make(chan wecomEnvelope, 1)
	c.pendingMu.Lock()
	c.pending[cmd.Headers.ReqID] = waitCh
	c.pendingMu.Unlock()
	defer func() {
		c.pendingMu.Lock()
		delete(c.pending, cmd.Headers.ReqID)
		c.pendingMu.Unlock()
	}()

	data, err := json.Marshal(cmd)
	if err != nil {
		return fmt.Errorf("%w: %v", channels.ErrSendFailed, err)
	}
	// connMu is held for the write so frames from different goroutines are
	// serialized on the connection.
	c.connMu.Lock()
	err = conn.WriteMessage(websocket.TextMessage, data)
	c.connMu.Unlock()
	if err != nil {
		return fmt.Errorf("%w: %v", channels.ErrTemporary, err)
	}

	timer := time.NewTimer(timeout)
	defer timer.Stop()
	select {
	case env := <-waitCh:
		if env.ErrCode != 0 {
			return fmt.Errorf("%w: wecom errcode=%d errmsg=%s", channels.ErrTemporary, env.ErrCode, env.ErrMsg)
		}
		return nil
	case <-timer.C:
		return fmt.Errorf("%w: timeout waiting for WeCom ack", channels.ErrTemporary)
	case <-c.ctx.Done():
		return c.ctx.Err()
	}
}
queue[0], true +} + +func (c *WeComChannel) deleteTurn(chatID string) { + c.turnsMu.Lock() + defer c.turnsMu.Unlock() + queue := c.turns[chatID] + if len(queue) <= 1 { + delete(c.turns, chatID) + return + } + c.turns[chatID] = queue[1:] +} + +func (c *WeComChannel) queueTurn(chatID string, turn wecomTurn) { + c.turnsMu.Lock() + defer c.turnsMu.Unlock() + c.turns[chatID] = append(c.turns[chatID], turn) +} + +func (c *WeComChannel) clearTurns() { + c.turnsMu.Lock() + c.turns = make(map[string][]wecomTurn) + c.turnsMu.Unlock() +} + +func randomID(n int) string { + const alphabet = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789" + if n <= 0 { + n = 10 + } + buf := make([]byte, n) + for i := range buf { + v, _ := rand.Int(rand.Reader, big.NewInt(int64(len(alphabet)))) + buf[i] = alphabet[v.Int64()] + } + return string(buf) +} + +func splitContent(content string, maxBytes int) []string { + if content == "" { + return []string{""} + } + if len(content) <= maxBytes { + return []string{content} + } + chunks := channels.SplitMessage(content, maxBytes) + var result []string + for _, chunk := range chunks { + if len(chunk) <= maxBytes { + result = append(result, chunk) + continue + } + for len(chunk) > maxBytes { + end := maxBytes + for end > 0 && chunk[end]>>6 == 0b10 { + end-- + } + if end == 0 { + end = maxBytes + } + result = append(result, chunk[:end]) + chunk = strings.TrimLeft(chunk[end:], " \t\r\n") + } + if chunk != "" { + result = append(result, chunk) + } + } + return result +} diff --git a/pkg/channels/wecom/wecom_test.go b/pkg/channels/wecom/wecom_test.go new file mode 100644 index 000000000..e0ee2e628 --- /dev/null +++ b/pkg/channels/wecom/wecom_test.go @@ -0,0 +1,167 @@ +package wecom + +import ( + "context" + "errors" + "path/filepath" + "testing" + "time" + + "github.com/sipeed/picoclaw/pkg/bus" + "github.com/sipeed/picoclaw/pkg/config" +) + +func TestDispatchIncoming_UsesActualChatIDAndStoresReqIDRoute(t *testing.T) { + t.Parallel() + + 
messageBus := bus.NewMessageBus() + ch := newTestWeComChannel(t, messageBus) + + var commands []wecomCommand + ch.commandSend = func(cmd wecomCommand, _ time.Duration) error { + commands = append(commands, cmd) + return nil + } + + msg := wecomIncomingMessage{ + MsgID: "msg-1", + ChatID: "chat-1", + ChatType: "direct", + MsgType: "text", + Text: &struct { + Content string `json:"content"` + }{Content: "hello"}, + } + msg.From.UserID = "user-1" + + if err := ch.dispatchIncoming("req-1", msg); err != nil { + t.Fatalf("dispatchIncoming() error = %v", err) + } + + select { + case inbound := <-messageBus.InboundChan(): + if inbound.ChatID != "chat-1" { + t.Fatalf("inbound ChatID = %q, want chat-1", inbound.ChatID) + } + if inbound.MessageID != "msg-1" { + t.Fatalf("inbound MessageID = %q, want msg-1", inbound.MessageID) + } + if inbound.Peer.ID != "chat-1" { + t.Fatalf("inbound Peer.ID = %q, want chat-1", inbound.Peer.ID) + } + if inbound.Metadata["req_id"] != "req-1" { + t.Fatalf("inbound req_id = %q, want req-1", inbound.Metadata["req_id"]) + } + default: + t.Fatal("expected inbound message to be published") + } + + turn, ok := ch.getTurn("chat-1") + if !ok { + t.Fatal("expected queued turn for chat-1") + } + if turn.ReqID != "req-1" { + t.Fatalf("turn.ReqID = %q, want req-1", turn.ReqID) + } + + route, ok := ch.routes.Get("chat-1") + if !ok { + t.Fatal("expected persisted route for chat-1") + } + if route.ReqID != "req-1" || route.ChatType != 1 { + t.Fatalf("route = %+v", route) + } + + if len(commands) != 1 { + t.Fatalf("expected 1 opening command, got %d", len(commands)) + } + if commands[0].Cmd != wecomCmdRespondMsg { + t.Fatalf("opening command = %q, want %q", commands[0].Cmd, wecomCmdRespondMsg) + } + if commands[0].Headers.ReqID != "req-1" { + t.Fatalf("opening req_id = %q, want req-1", commands[0].Headers.ReqID) + } +} + +func TestSend_StreamFailureFallsBackToActualChatID(t *testing.T) { + t.Parallel() + + ch := newTestWeComChannel(t, bus.NewMessageBus()) + 
ch.SetRunning(true) + ch.queueTurn("chat-1", wecomTurn{ + ReqID: "req-1", + ChatID: "chat-1", + ChatType: 1, + StreamID: "stream-1", + CreatedAt: time.Now(), + }) + ch.queueTurn("chat-1", wecomTurn{ + ReqID: "req-2", + ChatID: "chat-1", + ChatType: 1, + StreamID: "stream-2", + CreatedAt: time.Now(), + }) + if err := ch.routes.Put("chat-1", "req-2", 1, time.Hour); err != nil { + t.Fatalf("Put() error = %v", err) + } + + var commands []wecomCommand + ch.commandSend = func(cmd wecomCommand, _ time.Duration) error { + commands = append(commands, cmd) + if len(commands) == 1 && cmd.Cmd == wecomCmdRespondMsg { + return errors.New("stream send failed") + } + return nil + } + + if err := ch.Send(context.Background(), bus.OutboundMessage{ + Channel: "wecom", + ChatID: "chat-1", + Content: "hello", + }); err != nil { + t.Fatalf("Send() error = %v", err) + } + + if len(commands) != 2 { + t.Fatalf("expected 2 commands, got %d", len(commands)) + } + if commands[0].Cmd != wecomCmdRespondMsg || commands[0].Headers.ReqID != "req-1" { + t.Fatalf("first command = %+v", commands[0]) + } + if commands[1].Cmd != wecomCmdSendMsg { + t.Fatalf("second command = %q, want %q", commands[1].Cmd, wecomCmdSendMsg) + } + body, ok := commands[1].Body.(wecomSendMsgBody) + if !ok { + t.Fatalf("unexpected send body type %T", commands[1].Body) + } + if body.ChatID != "chat-1" { + t.Fatalf("send chatid = %q, want chat-1", body.ChatID) + } + if body.ChatType != 1 { + t.Fatalf("send chat_type = %d, want 1", body.ChatType) + } + + nextTurn, ok := ch.getTurn("chat-1") + if !ok { + t.Fatal("expected second turn to remain queued") + } + if nextTurn.ReqID != "req-2" { + t.Fatalf("next queued req_id = %q, want req-2", nextTurn.ReqID) + } +} + +func newTestWeComChannel(t *testing.T, messageBus *bus.MessageBus) *WeComChannel { + t.Helper() + + cfg := config.WeComConfig{BotID: "bot-1"} + cfg.SetSecret("secret-1") + ch, err := NewChannel(cfg, messageBus) + if err != nil { + t.Fatalf("NewChannel() error = %v", 
err) + } + ch.ctx = context.Background() + ch.routes = newReqIDStore(filepath.Join(t.TempDir(), "reqids.json")) + return ch +} diff --git a/pkg/config/config.go b/pkg/config/config.go index 84e1ab61a..c2815c0db 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -321,10 +321,7 @@ type AgentDefaults struct { ToolFeedback ToolFeedbackConfig `json:"tool_feedback,omitempty"` } -const ( - DefaultMaxMediaSize = 20 * 1024 * 1024 // 20 MB - DefaultWeComAIBotProcessingMessage = "⏳ Processing, please wait. The results will be sent shortly." -) +const DefaultMaxMediaSize = 20 * 1024 * 1024 // 20 MB func (d *AgentDefaults) GetMaxMediaSize() int { if d.MaxMediaSize > 0 { @@ -364,9 +361,7 @@ type ChannelsConfig struct { Matrix MatrixConfig `json:"matrix"` LINE LINEConfig `json:"line"` OneBot OneBotConfig `json:"onebot"` - WeCom WeComConfig `json:"wecom"` - WeComApp WeComAppConfig `json:"wecom_app"` - WeComAIBot WeComAIBotConfig `json:"wecom_aibot"` + WeCom WeComConfig `json:"wecom" envPrefix:"PICOCLAW_CHANNELS_WECOM_"` Weixin WeixinConfig `json:"weixin"` Pico PicoConfig `json:"pico"` PicoClient PicoClientConfig `json:"pico_client"` @@ -678,136 +673,28 @@ func (c *OneBotConfig) SetAccessToken(token string) { c.secDirty = true } +type WeComGroupConfig struct { + AllowFrom FlexibleStringSlice `json:"allow_from,omitempty"` +} + type WeComConfig struct { - Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_WECOM_ENABLED"` - token string - encodingAESKey string - WebhookURL string `json:"webhook_url" env:"PICOCLAW_CHANNELS_WECOM_WEBHOOK_URL"` - WebhookHost string `json:"webhook_host" env:"PICOCLAW_CHANNELS_WECOM_WEBHOOK_HOST"` - WebhookPort int `json:"webhook_port" env:"PICOCLAW_CHANNELS_WECOM_WEBHOOK_PORT"` - WebhookPath string `json:"webhook_path" env:"PICOCLAW_CHANNELS_WECOM_WEBHOOK_PATH"` - AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_WECOM_ALLOW_FROM"` - ReplyTimeout int `json:"reply_timeout" env:"PICOCLAW_CHANNELS_WECOM_REPLY_TIMEOUT"` - 
GroupTrigger GroupTriggerConfig `json:"group_trigger,omitempty"` - ReasoningChannelID string `json:"reasoning_channel_id" env:"PICOCLAW_CHANNELS_WECOM_REASONING_CHANNEL_ID"` - secDirty bool + Enabled bool `json:"enabled" env:"ENABLED"` + BotID string `json:"bot_id" env:"BOT_ID"` + secret string + WebSocketURL string `json:"websocket_url,omitempty" env:"WEBSOCKET_URL"` + SendThinkingMessage bool `json:"send_thinking_message" env:"SEND_THINKING_MESSAGE"` + AllowFrom FlexibleStringSlice `json:"allow_from" env:"ALLOW_FROM"` + ReasoningChannelID string `json:"reasoning_channel_id" env:"REASONING_CHANNEL_ID"` + secDirty bool } -// Token returns the WeCom token -func (c *WeComConfig) Token() string { - return c.token -} - -// SetToken sets the WeCom token -func (c *WeComConfig) SetToken(token string) { - c.token = token - c.secDirty = true -} - -// EncodingAESKey returns the WeCom encoding AES key -func (c *WeComConfig) EncodingAESKey() string { - return c.encodingAESKey -} - -// SetEncodingAESKey sets the WeCom encoding AES key -func (c *WeComConfig) SetEncodingAESKey(key string) { - c.encodingAESKey = key - c.secDirty = true -} - -type WeComAppConfig struct { - Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_WECOM_APP_ENABLED"` - CorpID string `json:"corp_id" env:"PICOCLAW_CHANNELS_WECOM_APP_CORP_ID"` - corpSecret string - AgentID int64 `json:"agent_id" env:"PICOCLAW_CHANNELS_WECOM_APP_AGENT_ID"` - token string - encodingAESKey string - WebhookHost string `json:"webhook_host" env:"PICOCLAW_CHANNELS_WECOM_APP_WEBHOOK_HOST"` - WebhookPort int `json:"webhook_port" env:"PICOCLAW_CHANNELS_WECOM_APP_WEBHOOK_PORT"` - WebhookPath string `json:"webhook_path" env:"PICOCLAW_CHANNELS_WECOM_APP_WEBHOOK_PATH"` - AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_WECOM_APP_ALLOW_FROM"` - ReplyTimeout int `json:"reply_timeout" env:"PICOCLAW_CHANNELS_WECOM_APP_REPLY_TIMEOUT"` - GroupTrigger GroupTriggerConfig `json:"group_trigger,omitempty"` - ReasoningChannelID 
string `json:"reasoning_channel_id" env:"PICOCLAW_CHANNELS_WECOM_APP_REASONING_CHANNEL_ID"` - secDirty bool -} - -// CorpSecret returns the corporate secret for WeCom app -func (c *WeComAppConfig) CorpSecret() string { - return c.corpSecret -} - -// SetCorpSecret sets the corporate secret for WeCom app -func (c *WeComAppConfig) SetCorpSecret(secret string) { - c.corpSecret = secret - c.secDirty = true -} - -// Token returns the webhook token for WeCom app -func (c *WeComAppConfig) Token() string { - return c.token -} - -// SetToken sets the webhook token for WeCom app -func (c *WeComAppConfig) SetToken(token string) { - c.token = token - c.secDirty = true -} - -// EncodingAESKey returns the encoding AES key for WeCom app -func (c *WeComAppConfig) EncodingAESKey() string { - return c.encodingAESKey -} - -// SetEncodingAESKey sets the encoding AES key for WeCom app -func (c *WeComAppConfig) SetEncodingAESKey(key string) { - c.encodingAESKey = key - c.secDirty = true -} - -type WeComAIBotConfig struct { - Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_WECOM_AIBOT_ENABLED"` - BotID string `json:"bot_id,omitempty" env:"PICOCLAW_CHANNELS_WECOM_AIBOT_BOT_ID"` - secret string - token string - encodingAESKey string - WebhookPath string `json:"webhook_path,omitempty" env:"PICOCLAW_CHANNELS_WECOM_AIBOT_WEBHOOK_PATH"` - AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_WECOM_AIBOT_ALLOW_FROM"` - ReplyTimeout int `json:"reply_timeout" env:"PICOCLAW_CHANNELS_WECOM_AIBOT_REPLY_TIMEOUT"` - MaxSteps int `json:"max_steps" env:"PICOCLAW_CHANNELS_WECOM_AIBOT_MAX_STEPS"` // Maximum streaming steps - WelcomeMessage string `json:"welcome_message" env:"PICOCLAW_CHANNELS_WECOM_AIBOT_WELCOME_MESSAGE"` // Sent on enter_chat event; empty = no welcome - ProcessingMessage string `json:"processing_message,omitempty" env:"PICOCLAW_CHANNELS_WECOM_AIBOT_PROCESSING_MESSAGE"` - ReasoningChannelID string `json:"reasoning_channel_id" 
env:"PICOCLAW_CHANNELS_WECOM_AIBOT_REASONING_CHANNEL_ID"` - secDirty bool -} - -// Token returns the webhook token for WeCom AI bot -func (c *WeComAIBotConfig) Token() string { - return c.token -} - -// EncodingAESKey returns the encoding AES key for WeCom AI bot -func (c *WeComAIBotConfig) EncodingAESKey() string { - return c.encodingAESKey -} - -// SetToken sets the token for WeCom AI bot -func (c *WeComAIBotConfig) SetToken(token string) { - c.token = token - c.secDirty = true -} - -// SetEncodingAESKey sets the encoding AES key for WeCom AI bot -func (c *WeComAIBotConfig) SetEncodingAESKey(key string) { - c.encodingAESKey = key - c.secDirty = true -} - -func (c *WeComAIBotConfig) Secret() string { +// Secret returns the WeCom bot secret. +func (c *WeComConfig) Secret() string { return c.secret } -func (c *WeComAIBotConfig) SetSecret(secret string) { +// SetSecret sets the WeCom bot secret. +func (c *WeComConfig) SetSecret(secret string) { c.secret = secret c.secDirty = true } @@ -1623,39 +1510,10 @@ func applySecurityConfig(cfg *Config, sec *SecurityConfig) error { cfg.Channels.OneBot.accessToken = sec.Channels.OneBot.AccessToken } - // Handle WeCom token and encoding key + // Handle WeCom bot secret if sec.Channels.WeCom != nil { - if sec.Channels.WeCom.Token != "" { - cfg.Channels.WeCom.token = sec.Channels.WeCom.Token - } - if sec.Channels.WeCom.EncodingAESKey != "" { - cfg.Channels.WeCom.encodingAESKey = sec.Channels.WeCom.EncodingAESKey - } - } - - // Handle WeCom App credentials - if sec.Channels.WeComApp != nil { - if sec.Channels.WeComApp.CorpSecret != "" { - cfg.Channels.WeComApp.corpSecret = sec.Channels.WeComApp.CorpSecret - } - if sec.Channels.WeComApp.Token != "" { - cfg.Channels.WeComApp.token = sec.Channels.WeComApp.Token - } - if sec.Channels.WeComApp.EncodingAESKey != "" { - cfg.Channels.WeComApp.encodingAESKey = sec.Channels.WeComApp.EncodingAESKey - } - } - - // Handle WeCom AI Bot credentials - if sec.Channels.WeComAIBot != nil { - if 
sec.Channels.WeComAIBot.Token != "" { - cfg.Channels.WeComAIBot.token = sec.Channels.WeComAIBot.Token - } - if sec.Channels.WeComAIBot.EncodingAESKey != "" { - cfg.Channels.WeComAIBot.encodingAESKey = sec.Channels.WeComAIBot.EncodingAESKey - } - if sec.Channels.WeComAIBot.Secret != "" { - cfg.Channels.WeComAIBot.secret = sec.Channels.WeComAIBot.Secret + if sec.Channels.WeCom.Secret != "" { + cfg.Channels.WeCom.secret = sec.Channels.WeCom.Secret } } @@ -1879,27 +1737,10 @@ func SaveConfig(path string, cfg *Config) error { } if cfg.Channels.WeCom.secDirty { cfg.security.Channels.WeCom = &WeComSecurity{ - Token: cfg.Channels.WeCom.Token(), - EncodingAESKey: cfg.Channels.WeCom.EncodingAESKey(), + Secret: cfg.Channels.WeCom.Secret(), } cfg.Channels.WeCom.secDirty = false } - if cfg.Channels.WeComApp.secDirty { - cfg.security.Channels.WeComApp = &WeComAppSecurity{ - CorpSecret: cfg.Channels.WeComApp.CorpSecret(), - Token: cfg.Channels.WeComApp.Token(), - EncodingAESKey: cfg.Channels.WeComApp.EncodingAESKey(), - } - cfg.Channels.WeComApp.secDirty = false - } - if cfg.Channels.WeComAIBot.secDirty { - cfg.security.Channels.WeComAIBot = &WeComAIBotSecurity{ - Token: cfg.Channels.WeComAIBot.Token(), - EncodingAESKey: cfg.Channels.WeComAIBot.EncodingAESKey(), - Secret: cfg.Channels.WeComAIBot.Secret(), - } - cfg.Channels.WeComAIBot.secDirty = false - } if cfg.Tools.Web.Brave.secDirty { cfg.security.Web.Brave = &BraveSecurity{ APIKeys: cfg.Tools.Web.Brave.APIKeys(), diff --git a/pkg/config/config_old.go b/pkg/config/config_old.go index 01909f5a9..44c9435d1 100644 --- a/pkg/config/config_old.go +++ b/pkg/config/config_old.go @@ -85,23 +85,21 @@ type toolsConfigV0 struct { } type channelsConfigV0 struct { - WhatsApp WhatsAppConfig `json:"whatsapp"` - Telegram telegramConfigV0 `json:"telegram"` - Feishu feishuConfigV0 `json:"feishu"` - Discord discordConfigV0 `json:"discord"` - MaixCam maixcamConfigV0 `json:"maixcam"` - Weixin weixinConfigV0 `json:"weixin"` - QQ qqConfigV0 
`json:"qq"` - DingTalk dingtalkConfigV0 `json:"dingtalk"` - Slack slackConfigV0 `json:"slack"` - Matrix matrixConfigV0 `json:"matrix"` - LINE lineConfigV0 `json:"line"` - OneBot onebotConfigV0 `json:"onebot"` - WeCom wecomConfigV0 `json:"wecom"` - WeComApp wecomappConfigV0 `json:"wecom_app"` - WeComAIBot wecomaibotConfigV0 `json:"wecom_aibot"` - Pico picoConfigV0 `json:"pico"` - IRC ircConfigV0 `json:"irc"` + WhatsApp WhatsAppConfig `json:"whatsapp"` + Telegram telegramConfigV0 `json:"telegram"` + Feishu feishuConfigV0 `json:"feishu"` + Discord discordConfigV0 `json:"discord"` + MaixCam maixcamConfigV0 `json:"maixcam"` + Weixin weixinConfigV0 `json:"weixin"` + QQ qqConfigV0 `json:"qq"` + DingTalk dingtalkConfigV0 `json:"dingtalk"` + Slack slackConfigV0 `json:"slack"` + Matrix matrixConfigV0 `json:"matrix"` + LINE lineConfigV0 `json:"line"` + OneBot onebotConfigV0 `json:"onebot"` + WeCom wecomConfigV0 `json:"wecom" envPrefix:"PICOCLAW_CHANNELS_WECOM_"` + Pico picoConfigV0 `json:"pico"` + IRC ircConfigV0 `json:"irc"` } func (v *channelsConfigV0) ToChannelsConfig() (ChannelsConfig, ChannelsSecurity) { @@ -117,45 +115,39 @@ func (v *channelsConfigV0) ToChannelsConfig() (ChannelsConfig, ChannelsSecurity) line, lineSecurity := v.LINE.ToLINEConfig() onebot, onebotSecurity := v.OneBot.ToOneBotConfig() wecom, wecomSecurity := v.WeCom.ToWeComConfig() - wecomapp, wecomappSecurity := v.WeComApp.ToWeComAppConfig() - wecomaibot, wecomaibotSecurity := v.WeComAIBot.ToWeComAIBotConfig() pico, picoSecurity := v.Pico.ToPicoConfig() irc, ircSecurity := v.IRC.ToIRCConfig() return ChannelsConfig{ - WhatsApp: v.WhatsApp, - Telegram: telegram, - Feishu: feishu, - Discord: discord, - MaixCam: maixcam, - QQ: qq, - Weixin: weixin, - DingTalk: dingtalk, - Slack: slack, - Matrix: matrix, - LINE: line, - OneBot: onebot, - WeCom: wecom, - WeComApp: wecomapp, - WeComAIBot: wecomaibot, - Pico: pico, - IRC: irc, + WhatsApp: v.WhatsApp, + Telegram: telegram, + Feishu: feishu, + Discord: discord, + 
MaixCam: maixcam, + QQ: qq, + Weixin: weixin, + DingTalk: dingtalk, + Slack: slack, + Matrix: matrix, + LINE: line, + OneBot: onebot, + WeCom: wecom, + Pico: pico, + IRC: irc, }, ChannelsSecurity{ - Telegram: telegramSecurity, - Feishu: feishuSecurity, - Discord: discordSecurity, - QQ: qqSecurity, - Weixin: weixinSecurity, - DingTalk: dingtalkSecurity, - Slack: slackSecurity, - Matrix: matrixSecurity, - LINE: lineSecurity, - OneBot: onebotSecurity, - WeCom: wecomSecurity, - WeComApp: wecomappSecurity, - WeComAIBot: wecomaibotSecurity, - Pico: picoSecurity, - IRC: ircSecurity, + Telegram: telegramSecurity, + Feishu: feishuSecurity, + Discord: discordSecurity, + QQ: qqSecurity, + Weixin: weixinSecurity, + DingTalk: dingtalkSecurity, + Slack: slackSecurity, + Matrix: matrixSecurity, + LINE: lineSecurity, + OneBot: onebotSecurity, + WeCom: wecomSecurity, + Pico: picoSecurity, + IRC: ircSecurity, } } @@ -473,39 +465,32 @@ func (v *onebotConfigV0) ToOneBotConfig() (OneBotConfig, *OneBotSecurity) { } type wecomConfigV0 struct { - Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_WECOM_ENABLED"` - Token string `json:"token" env:"PICOCLAW_CHANNELS_WECOM_TOKEN"` - EncodingAESKey string `json:"encoding_aes_key" env:"PICOCLAW_CHANNELS_WECOM_ENCODING_AES_KEY"` - WebhookURL string `json:"webhook_url" env:"PICOCLAW_CHANNELS_WECOM_WEBHOOK_URL"` - WebhookHost string `json:"webhook_host" env:"PICOCLAW_CHANNELS_WECOM_WEBHOOK_HOST"` - WebhookPort int `json:"webhook_port" env:"PICOCLAW_CHANNELS_WECOM_WEBHOOK_PORT"` - WebhookPath string `json:"webhook_path" env:"PICOCLAW_CHANNELS_WECOM_WEBHOOK_PATH"` - AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_WECOM_ALLOW_FROM"` - ReplyTimeout int `json:"reply_timeout" env:"PICOCLAW_CHANNELS_WECOM_REPLY_TIMEOUT"` - GroupTrigger GroupTriggerConfig `json:"group_trigger,omitempty"` - ReasoningChannelID string `json:"reasoning_channel_id" env:"PICOCLAW_CHANNELS_WECOM_REASONING_CHANNEL_ID"` + Enabled bool `json:"enabled" 
env:"ENABLED"` + BotID string `json:"bot_id" env:"BOT_ID"` + Secret string `json:"secret" env:"SECRET"` + WebSocketURL string `json:"websocket_url,omitempty" env:"WEBSOCKET_URL"` + SendThinkingMessage bool `json:"send_thinking_message" env:"SEND_THINKING_MESSAGE"` + DMPolicy string `json:"dm_policy,omitempty" env:"DM_POLICY"` + AllowFrom FlexibleStringSlice `json:"allow_from" env:"ALLOW_FROM"` + GroupPolicy string `json:"group_policy,omitempty" env:"GROUP_POLICY"` + GroupAllowFrom FlexibleStringSlice `json:"group_allow_from,omitempty" env:"GROUP_ALLOW_FROM"` + Groups map[string]WeComGroupConfig `json:"groups,omitempty"` + ReasoningChannelID string `json:"reasoning_channel_id" env:"REASONING_CHANNEL_ID"` } func (v *wecomConfigV0) ToWeComConfig() (WeComConfig, *WeComSecurity) { var sec *WeComSecurity - if v.Token != "" || v.EncodingAESKey != "" { - sec = &WeComSecurity{ - Token: v.Token, - EncodingAESKey: v.EncodingAESKey, - } + if v.Secret != "" { + sec = &WeComSecurity{Secret: v.Secret} } return WeComConfig{ - Enabled: v.Enabled, - token: v.Token, - encodingAESKey: v.EncodingAESKey, - WebhookURL: v.WebhookURL, - WebhookHost: v.WebhookHost, - WebhookPort: v.WebhookPort, - WebhookPath: v.WebhookPath, - AllowFrom: v.AllowFrom, - ReplyTimeout: v.ReplyTimeout, - GroupTrigger: v.GroupTrigger, - ReasoningChannelID: v.ReasoningChannelID, + Enabled: v.Enabled, + BotID: v.BotID, + secret: v.Secret, + WebSocketURL: v.WebSocketURL, + SendThinkingMessage: v.SendThinkingMessage, + AllowFrom: v.AllowFrom, + ReasoningChannelID: v.ReasoningChannelID, }, sec } @@ -537,81 +522,6 @@ func (v *weixinConfigV0) ToWeiXinConfig() (WeixinConfig, *WeixinSecurity) { }, sec } -type wecomappConfigV0 struct { - Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_WECOM_APP_ENABLED"` - CorpID string `json:"corp_id" env:"PICOCLAW_CHANNELS_WECOM_APP_CORP_ID"` - CorpSecret string `json:"corp_secret" env:"PICOCLAW_CHANNELS_WECOM_APP_CORP_SECRET"` - AgentID int64 `json:"agent_id" 
env:"PICOCLAW_CHANNELS_WECOM_APP_AGENT_ID"` - Token string `json:"token" env:"PICOCLAW_CHANNELS_WECOM_APP_TOKEN"` - EncodingAESKey string `json:"encoding_aes_key" env:"PICOCLAW_CHANNELS_WECOM_APP_ENCODING_AES_KEY"` - WebhookHost string `json:"webhook_host" env:"PICOCLAW_CHANNELS_WECOM_APP_WEBHOOK_HOST"` - WebhookPort int `json:"webhook_port" env:"PICOCLAW_CHANNELS_WECOM_APP_WEBHOOK_PORT"` - WebhookPath string `json:"webhook_path" env:"PICOCLAW_CHANNELS_WECOM_APP_WEBHOOK_PATH"` - AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_WECOM_APP_ALLOW_FROM"` - ReplyTimeout int `json:"reply_timeout" env:"PICOCLAW_CHANNELS_WECOM_APP_REPLY_TIMEOUT"` - GroupTrigger GroupTriggerConfig `json:"group_trigger,omitempty"` - ReasoningChannelID string `json:"reasoning_channel_id" env:"PICOCLAW_CHANNELS_WECOM_APP_REASONING_CHANNEL_ID"` -} - -func (v *wecomappConfigV0) ToWeComAppConfig() (WeComAppConfig, *WeComAppSecurity) { - var sec *WeComAppSecurity - if v.CorpSecret != "" || v.Token != "" || v.EncodingAESKey != "" { - sec = &WeComAppSecurity{ - CorpSecret: v.CorpSecret, - Token: v.Token, - EncodingAESKey: v.EncodingAESKey, - } - } - return WeComAppConfig{ - Enabled: v.Enabled, - CorpID: v.CorpID, - corpSecret: v.CorpSecret, - AgentID: v.AgentID, - token: v.Token, - encodingAESKey: v.EncodingAESKey, - WebhookHost: v.WebhookHost, - WebhookPort: v.WebhookPort, - WebhookPath: v.WebhookPath, - AllowFrom: v.AllowFrom, - ReplyTimeout: v.ReplyTimeout, - GroupTrigger: v.GroupTrigger, - ReasoningChannelID: v.ReasoningChannelID, - }, sec -} - -type wecomaibotConfigV0 struct { - Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_WECOM_AIBOT_ENABLED"` - Token string `json:"token" env:"PICOCLAW_CHANNELS_WECOM_AIBOT_TOKEN"` - Secret string `json:"secret" env:"PICOCLAW_CHANNELS_WECOM_AIBOT_SECRET"` - EncodingAESKey string `json:"encoding_aes_key" env:"PICOCLAW_CHANNELS_WECOM_AIBOT_ENCODING_AES_KEY"` - WebhookPath string `json:"webhook_path" 
env:"PICOCLAW_CHANNELS_WECOM_AIBOT_WEBHOOK_PATH"` - AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_WECOM_AIBOT_ALLOW_FROM"` - ReplyTimeout int `json:"reply_timeout" env:"PICOCLAW_CHANNELS_WECOM_AIBOT_REPLY_TIMEOUT"` - MaxSteps int `json:"max_steps" env:"PICOCLAW_CHANNELS_WECOM_AIBOT_MAX_STEPS"` - WelcomeMessage string `json:"welcome_message" env:"PICOCLAW_CHANNELS_WECOM_AIBOT_WELCOME_MESSAGE"` - ReasoningChannelID string `json:"reasoning_channel_id" env:"PICOCLAW_CHANNELS_WECOM_AIBOT_REASONING_CHANNEL_ID"` -} - -func (v *wecomaibotConfigV0) ToWeComAIBotConfig() (WeComAIBotConfig, *WeComAIBotSecurity) { - var sec *WeComAIBotSecurity - if v.Token != "" || v.Secret != "" || v.EncodingAESKey != "" { - sec = &WeComAIBotSecurity{ - Token: v.Token, - Secret: v.Secret, - EncodingAESKey: v.EncodingAESKey, - } - } - return WeComAIBotConfig{ - Enabled: v.Enabled, - WebhookPath: v.WebhookPath, - AllowFrom: v.AllowFrom, - ReplyTimeout: v.ReplyTimeout, - MaxSteps: v.MaxSteps, - WelcomeMessage: v.WelcomeMessage, - ReasoningChannelID: v.ReasoningChannelID, - }, sec -} - type picoConfigV0 struct { Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_PICO_ENABLED"` Token string `json:"token" env:"PICOCLAW_CHANNELS_PICO_TOKEN"` diff --git a/pkg/config/config_test.go b/pkg/config/config_test.go index b356d474f..2fa5d2195 100644 --- a/pkg/config/config_test.go +++ b/pkg/config/config_test.go @@ -1372,8 +1372,7 @@ func TestFilterSensitiveData_AllTokenTypes(t *testing.T) { Feishu: &FeishuSecurity{AppSecret: "feishu-app-secret-123", EncryptKey: "feishu-encrypt-key"}, DingTalk: &DingTalkSecurity{ClientSecret: "dingtalk-client-secret"}, OneBot: &OneBotSecurity{AccessToken: "onebot-access-token"}, - WeCom: &WeComSecurity{Token: "wecom-token", EncodingAESKey: "wecom-aes-key"}, - WeComApp: &WeComAppSecurity{CorpSecret: "wecom-app-secret", Token: "wecom-app-token"}, + WeCom: &WeComSecurity{Secret: "wecom-secret"}, Pico: &PicoSecurity{Token: "pico-token-abc123"}, IRC: 
&IRCSecurity{ Password: "irc-password", diff --git a/pkg/config/defaults.go b/pkg/config/defaults.go index c1d0ea0f6..b5d73977d 100644 --- a/pkg/config/defaults.go +++ b/pkg/config/defaults.go @@ -129,32 +129,11 @@ func DefaultConfig() *Config { AllowFrom: FlexibleStringSlice{}, }, WeCom: WeComConfig{ - Enabled: false, - WebhookURL: "", - WebhookHost: "0.0.0.0", - WebhookPort: 18793, - WebhookPath: "/webhook/wecom", - AllowFrom: FlexibleStringSlice{}, - ReplyTimeout: 5, - }, - WeComApp: WeComAppConfig{ - Enabled: false, - CorpID: "", - AgentID: 0, - WebhookHost: "0.0.0.0", - WebhookPort: 18792, - WebhookPath: "/webhook/wecom-app", - AllowFrom: FlexibleStringSlice{}, - ReplyTimeout: 5, - }, - WeComAIBot: WeComAIBotConfig{ - Enabled: false, - WebhookPath: "/webhook/wecom-aibot", - AllowFrom: FlexibleStringSlice{}, - ReplyTimeout: 5, - MaxSteps: 10, - WelcomeMessage: "Hello! I'm your AI assistant. How can I help you today?", - ProcessingMessage: DefaultWeComAIBotProcessingMessage, + Enabled: false, + BotID: "", + WebSocketURL: "wss://openws.work.weixin.qq.com", + SendThinkingMessage: true, + AllowFrom: FlexibleStringSlice{}, }, Weixin: WeixinConfig{ Enabled: false, diff --git a/pkg/config/security.go b/pkg/config/security.go index da989ca88..72f0c013f 100644 --- a/pkg/config/security.go +++ b/pkg/config/security.go @@ -69,21 +69,19 @@ type ModelSecurityEntry struct { // ChannelsSecurity stores channel-related security data type ChannelsSecurity struct { - Telegram *TelegramSecurity `yaml:"telegram,omitempty"` - Feishu *FeishuSecurity `yaml:"feishu,omitempty"` - Discord *DiscordSecurity `yaml:"discord,omitempty"` - Weixin *WeixinSecurity `yaml:"weixin,omitempty"` - QQ *QQSecurity `yaml:"qq,omitempty"` - DingTalk *DingTalkSecurity `yaml:"dingtalk,omitempty"` - Slack *SlackSecurity `yaml:"slack,omitempty"` - Matrix *MatrixSecurity `yaml:"matrix,omitempty"` - LINE *LINESecurity `yaml:"line,omitempty"` - OneBot *OneBotSecurity `yaml:"onebot,omitempty"` - WeCom 
*WeComSecurity `yaml:"wecom,omitempty"` - WeComApp *WeComAppSecurity `yaml:"wecom_app,omitempty"` - WeComAIBot *WeComAIBotSecurity `yaml:"wecom_aibot,omitempty"` - Pico *PicoSecurity `yaml:"pico,omitempty"` - IRC *IRCSecurity `yaml:"irc,omitempty"` + Telegram *TelegramSecurity `yaml:"telegram,omitempty"` + Feishu *FeishuSecurity `yaml:"feishu,omitempty"` + Discord *DiscordSecurity `yaml:"discord,omitempty"` + Weixin *WeixinSecurity `yaml:"weixin,omitempty"` + QQ *QQSecurity `yaml:"qq,omitempty"` + DingTalk *DingTalkSecurity `yaml:"dingtalk,omitempty"` + Slack *SlackSecurity `yaml:"slack,omitempty"` + Matrix *MatrixSecurity `yaml:"matrix,omitempty"` + LINE *LINESecurity `yaml:"line,omitempty"` + OneBot *OneBotSecurity `yaml:"onebot,omitempty"` + WeCom *WeComSecurity `yaml:"wecom,omitempty"` + Pico *PicoSecurity `yaml:"pico,omitempty"` + IRC *IRCSecurity `yaml:"irc,omitempty"` } type TelegramSecurity struct { @@ -131,20 +129,7 @@ type OneBotSecurity struct { } type WeComSecurity struct { - Token string `yaml:"token,omitempty" env:"PICOCLAW_CHANNELS_WECOM_TOKEN"` - EncodingAESKey string `yaml:"encoding_aes_key,omitempty" env:"PICOCLAW_CHANNELS_WECOM_ENCODING_AES_KEY"` -} - -type WeComAppSecurity struct { - CorpSecret string `yaml:"corp_secret,omitempty" env:"PICOCLAW_CHANNELS_WECOM_APP_CORP_SECRET"` - Token string `yaml:"token,omitempty" env:"PICOCLAW_CHANNELS_WECOM_APP_TOKEN"` - EncodingAESKey string `yaml:"encoding_aes_key,omitempty" env:"PICOCLAW_CHANNELS_WECOM_APP_ENCODING_AES_KEY"` -} - -type WeComAIBotSecurity struct { - Secret string `yaml:"secret,omitempty" env:"PICOCLAW_CHANNELS_WECOM_AIBOT_SECRET"` - Token string `yaml:"token,omitempty" env:"PICOCLAW_CHANNELS_WECOM_AIBOT_TOKEN"` - EncodingAESKey string `yaml:"encoding_aes_key,omitempty" env:"PICOCLAW_CHANNELS_WECOM_AIBOT_ENCODING_AES_KEY"` + Secret string `yaml:"secret,omitempty" env:"PICOCLAW_CHANNELS_WECOM_SECRET"` } type PicoSecurity struct { diff --git a/pkg/config/security_integration_test.go 
b/pkg/config/security_integration_test.go index 218914590..03990ce5b 100644 --- a/pkg/config/security_integration_test.go +++ b/pkg/config/security_integration_test.go @@ -240,15 +240,7 @@ func TestAllSecurityKeysAccessible(t *testing.T) { }, "wecom": { "enabled": true, - "webhook_url": "https://qyapi.weixin.qq.com/cgi-bin/webhook" - }, - "wecom_app": { - "enabled": true, - "corp_id": "test_corp_id", - "agent_id": 123456 - }, - "wecom_aibot": { - "enabled": true + "bot_id": "test_wecom_bot_id" }, "pico": { "enabled": true @@ -315,15 +307,7 @@ channels: onebot: access_token: "onebot_test_access_token" wecom: - token: "wecom_test_webhook_token" - encoding_aes_key: "wecom_test_aes_key" - wecom_app: - corp_secret: "wecom_app_test_corp_secret" - token: "wecom_app_test_token" - encoding_aes_key: "wecom_app_test_aes_key" - wecom_aibot: - token: "wecom_aibot_test_token" - encoding_aes_key: "wecom_aibot_test_aes_key" + secret: "wecom_test_secret" pico: token: "pico_test_token" irc: @@ -409,24 +393,10 @@ skills: t.Logf("OneBot AccessToken(): %s", cfg.Channels.OneBot.AccessToken()) // WeCom - assert.Equal(t, "wecom_test_webhook_token", cfg.Channels.WeCom.Token()) - assert.Equal(t, "wecom_test_aes_key", cfg.Channels.WeCom.EncodingAESKey()) - t.Logf("WeCom Token(): %s", cfg.Channels.WeCom.Token()) - t.Logf("WeCom EncodingAESKey(): %s", cfg.Channels.WeCom.EncodingAESKey()) - - // WeCom App - assert.Equal(t, "wecom_app_test_corp_secret", cfg.Channels.WeComApp.CorpSecret()) - assert.Equal(t, "wecom_app_test_token", cfg.Channels.WeComApp.Token()) - assert.Equal(t, "wecom_app_test_aes_key", cfg.Channels.WeComApp.EncodingAESKey()) - t.Logf("WeComApp CorpSecret(): %s", cfg.Channels.WeComApp.CorpSecret()) - t.Logf("WeComApp Token(): %s", cfg.Channels.WeComApp.Token()) - t.Logf("WeComApp EncodingAESKey(): %s", cfg.Channels.WeComApp.EncodingAESKey()) - - // WeCom AI Bot - assert.Equal(t, "wecom_aibot_test_token", cfg.Channels.WeComAIBot.Token()) - assert.Equal(t, 
"wecom_aibot_test_aes_key", cfg.Channels.WeComAIBot.EncodingAESKey()) - t.Logf("WeComAIBot Token(): %s", cfg.Channels.WeComAIBot.Token()) - t.Logf("WeComAIBot EncodingAESKey(): %s", cfg.Channels.WeComAIBot.EncodingAESKey()) + assert.Equal(t, "test_wecom_bot_id", cfg.Channels.WeCom.BotID) + assert.Equal(t, "wecom_test_secret", cfg.Channels.WeCom.Secret()) + t.Logf("WeCom BotID: %s", cfg.Channels.WeCom.BotID) + t.Logf("WeCom Secret(): %s", cfg.Channels.WeCom.Secret()) // Pico assert.Equal(t, "pico_test_token", cfg.Channels.Pico.Token()) diff --git a/pkg/migrate/sources/openclaw/common.go b/pkg/migrate/sources/openclaw/common.go index 337c950d0..938f15b80 100644 --- a/pkg/migrate/sources/openclaw/common.go +++ b/pkg/migrate/sources/openclaw/common.go @@ -13,17 +13,16 @@ var migrateableDirs = []string{ } var supportedChannels = map[string]bool{ - "whatsapp": true, - "telegram": true, - "feishu": true, - "discord": true, - "maixcam": true, - "qq": true, - "dingtalk": true, - "slack": true, - "matrix": true, - "line": true, - "onebot": true, - "wecom": true, - "wecom_app": true, + "whatsapp": true, + "telegram": true, + "feishu": true, + "discord": true, + "maixcam": true, + "qq": true, + "dingtalk": true, + "slack": true, + "matrix": true, + "line": true, + "onebot": true, + "wecom": true, } diff --git a/web/backend/api/channels.go b/web/backend/api/channels.go index 21624d3ef..dd4c9af3d 100644 --- a/web/backend/api/channels.go +++ b/web/backend/api/channels.go @@ -22,8 +22,6 @@ var channelCatalog = []channelCatalogItem{ {Name: "qq", ConfigKey: "qq"}, {Name: "onebot", ConfigKey: "onebot"}, {Name: "wecom", ConfigKey: "wecom"}, - {Name: "wecom_app", ConfigKey: "wecom_app"}, - {Name: "wecom_aibot", ConfigKey: "wecom_aibot"}, {Name: "whatsapp", ConfigKey: "whatsapp", Variant: "bridge"}, {Name: "whatsapp_native", ConfigKey: "whatsapp", Variant: "native"}, {Name: "pico", ConfigKey: "pico"}, diff --git a/web/backend/api/config.go b/web/backend/api/config.go index 
e67e3e6d7..5a7f3cebc 100644 --- a/web/backend/api/config.go +++ b/web/backend/api/config.go @@ -209,6 +209,15 @@ func validateConfig(cfg *config.Config) []string { errs = append(errs, "channels.discord.token is required when discord channel is enabled") } + if cfg.Channels.WeCom.Enabled { + if cfg.Channels.WeCom.BotID == "" { + errs = append(errs, "channels.wecom.bot_id is required when wecom channel is enabled") + } + if cfg.Channels.WeCom.Secret() == "" { + errs = append(errs, "channels.wecom.secret is required when wecom channel is enabled") + } + } + if cfg.Tools.Exec.Enabled { if cfg.Tools.Exec.EnableDenyPatterns { errs = append( diff --git a/web/frontend/src/components/channels/channel-config-page.tsx b/web/frontend/src/components/channels/channel-config-page.tsx index 4996a6314..e621da70c 100644 --- a/web/frontend/src/components/channels/channel-config-page.tsx +++ b/web/frontend/src/components/channels/channel-config-page.tsx @@ -146,13 +146,7 @@ function isConfigured( case "weixin": return asString(config.account_id) !== "" case "wecom": - return asString(config.token) !== "" - case "wecom_app": - return ( - asString(config.corp_id) !== "" && asString(config.corp_secret) !== "" - ) - case "wecom_aibot": - return asString(config.token) !== "" + return asString(config.bot_id) !== "" case "whatsapp": return asString(config.bridge_url) !== "" case "whatsapp_native": @@ -193,11 +187,7 @@ function getRequiredFieldKeys(channelName: string): string[] { case "onebot": return ["ws_url"] case "wecom": - return ["token"] - case "wecom_app": - return ["corp_id", "corp_secret"] - case "wecom_aibot": - return ["token"] + return ["bot_id", "secret"] case "whatsapp": return ["bridge_url"] case "pico": diff --git a/web/frontend/src/components/channels/channel-forms/generic-form.tsx b/web/frontend/src/components/channels/channel-forms/generic-form.tsx index db14fc206..1a872542b 100644 --- a/web/frontend/src/components/channels/channel-forms/generic-form.tsx +++ 
b/web/frontend/src/components/channels/channel-forms/generic-form.tsx @@ -28,6 +28,7 @@ const SECRET_FIELDS = new Set([ "encoding_aes_key", "encrypt_key", "verification_token", + "secret", "password", "nickserv_password", "sasl_password", @@ -44,6 +45,7 @@ const OBJECT_FIELDS = new Set([ "allow_token_query", "allow_from", "allow_origins", + "groups", ]) function formatLabel(key: string): string { @@ -118,6 +120,14 @@ export function GenericForm({ app_id: t("channels.form.desc.appId"), client_id: t("channels.form.desc.clientId"), corp_id: t("channels.form.desc.corpId"), + bot_id: t("channels.form.desc.appId"), + websocket_url: t("channels.form.desc.wsUrl"), + dm_policy: t("channels.form.desc.genericField", { field: "DM policy" }), + group_policy: t("channels.form.desc.genericField", { field: "group policy" }), + group_allow_from: t("channels.form.desc.allowFrom"), + send_thinking_message: t("channels.form.desc.genericField", { + field: "thinking message behavior", + }), agent_id: t("channels.form.desc.agentId"), webhook_url: t("channels.form.desc.webhookUrl"), webhook_host: t("channels.form.desc.webhookHost"), diff --git a/web/frontend/src/hooks/use-sidebar-channels.ts b/web/frontend/src/hooks/use-sidebar-channels.ts index 5579a955b..be35f1a94 100644 --- a/web/frontend/src/hooks/use-sidebar-channels.ts +++ b/web/frontend/src/hooks/use-sidebar-channels.ts @@ -35,8 +35,6 @@ const CHANNEL_IMPORTANCE_ORDER = [ "slack", "line", "wecom", - "wecom_app", - "wecom_aibot", "dingtalk", "qq", "onebot", @@ -76,8 +74,6 @@ const CHANNEL_ICON_MAP: Record< line: IconBrandLine, qq: IconBrandQq, wecom: IconBrandWechat, - wecom_app: IconBrandWechat, - wecom_aibot: IconBrandWechat, whatsapp: IconBrandWhatsapp, whatsapp_native: IconBrandWhatsapp, matrix: IconBrandMatrix, diff --git a/web/frontend/src/i18n/locales/en.json b/web/frontend/src/i18n/locales/en.json index 0b0afa39d..207385aa1 100644 --- a/web/frontend/src/i18n/locales/en.json +++ b/web/frontend/src/i18n/locales/en.json @@ 
-233,8 +233,6 @@ "qq": "QQ", "onebot": "OneBot", "wecom": "WeCom", - "wecom_app": "WeCom App", - "wecom_aibot": "WeCom AI Bot", "whatsapp": "WhatsApp", "whatsapp_native": "WhatsApp Native", "pico": "Web", diff --git a/web/frontend/src/i18n/locales/zh.json b/web/frontend/src/i18n/locales/zh.json index e85e4dd44..8d452bac4 100644 --- a/web/frontend/src/i18n/locales/zh.json +++ b/web/frontend/src/i18n/locales/zh.json @@ -233,8 +233,6 @@ "qq": "QQ", "onebot": "OneBot", "wecom": "企业微信", - "wecom_app": "企业微信应用", - "wecom_aibot": "企业微信 AI 机器人", "whatsapp": "WhatsApp", "whatsapp_native": "WhatsApp Native", "pico": "Web", From b0bcf1d3c9258507decb92a3eec1551d9a92faa4 Mon Sep 17 00:00:00 2001 From: Hoshina Date: Tue, 24 Mar 2026 15:04:14 +0800 Subject: [PATCH 23/39] docs(wecom): update examples and docs --- config/config.example.json | 34 ++++------------------------ pkg/channels/README.md | 7 +++--- pkg/channels/README.zh.md | 7 +++--- pkg/config/SECURITY_CONFIG.md | 8 +------ pkg/config/example_security_usage.go | 8 +------ 5 files changed, 12 insertions(+), 52 deletions(-) diff --git a/config/config.example.json b/config/config.example.json index 88578701a..54d387548 100644 --- a/config/config.example.json +++ b/config/config.example.json @@ -182,39 +182,13 @@ "reasoning_channel_id": "" }, "wecom": { - "_comment": "WeCom Bot - Easier setup, supports group chats", - "enabled": false, - "token": "YOUR_TOKEN", - "encoding_aes_key": "YOUR_43_CHAR_ENCODING_AES_KEY", - "webhook_url": "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=YOUR_KEY", - "webhook_path": "/webhook/wecom", - "allow_from": [], - "reply_timeout": 5, - "reasoning_channel_id": "" - }, - "wecom_app": { - "_comment": "WeCom App (自建应用) - More features, proactive messaging, private chat only.", - "enabled": false, - "corp_id": "YOUR_CORP_ID", - "corp_secret": "YOUR_CORP_SECRET", - "agent_id": 1000002, - "token": "YOUR_TOKEN", - "encoding_aes_key": "YOUR_43_CHAR_ENCODING_AES_KEY", - "webhook_path": 
"/webhook/wecom-app", - "allow_from": [], - "reply_timeout": 5, - "reasoning_channel_id": "" - }, - "wecom_aibot": { - "_comment": "WeCom AI Bot (智能机器人) - Official WeCom AI Bot integration, supports proactive messaging and private chats.", + "_comment": "WeCom AI Bot over WebSocket.", "enabled": false, "bot_id": "YOUR_BOT_ID", "secret": "YOUR_SECRET", - "token": "YOUR_TOKEN", - "encoding_aes_key": "YOUR_43_CHAR_ENCODING_AES_KEY", - "webhook_path": "/webhook/wecom-aibot", - "max_steps": 10, - "welcome_message": "Hello! I'm your AI assistant. How can I help you today?", + "websocket_url": "wss://openws.work.weixin.qq.com", + "send_thinking_message": true, + "allow_from": [], "reasoning_channel_id": "" }, "pico": { diff --git a/pkg/channels/README.md b/pkg/channels/README.md index b7c56660b..7f238ece5 100644 --- a/pkg/channels/README.md +++ b/pkg/channels/README.md @@ -1255,8 +1255,7 @@ make test # Full test suite | `pkg/channels/onebot/` | `"onebot"` | ReactionCapable, MediaSender | | `pkg/channels/dingtalk/` | `"dingtalk"` | — | | `pkg/channels/feishu/` | `"feishu"` | — (architecture-specific build tags: `feishu_32.go` / `feishu_64.go`) | -| `pkg/channels/wecom/` | `"wecom"` | WebhookHandler, HealthChecker | -| `pkg/channels/wecom/` | `"wecom_app"` | MediaSender, WebhookHandler, HealthChecker | +| `pkg/channels/wecom/` | `"wecom"` | MediaSender | | `pkg/channels/qq/` | `"qq"` | — | | `pkg/channels/whatsapp/` | `"whatsapp"` | — (Bridge mode) | | `pkg/channels/whatsapp_native/` | `"whatsapp_native"` | — (Native whatsmeow mode) | @@ -1371,7 +1370,7 @@ agentLoop.Stop() // Stop Agent 2. **Feishu architecture-specific compilation**: The Feishu channel uses build tags to distinguish 32-bit and 64-bit architectures (`feishu_32.go` / `feishu_64.go`). Feishu uses the SDK's WebSocket mode (not HTTP webhook), so it does not implement `WebhookHandler`. -3. 
**WeCom has two factories**: `"wecom"` (Bot mode, webhook only) and `"wecom_app"` (App mode, supports MediaSender) are registered separately. Both implement `WebhookHandler` and `HealthChecker`. +3. **WeCom is now a single channel**: `"wecom"` is implemented as a WebSocket-based AI Bot channel with route persistence. Access control uses the shared channel allowlist mechanism. It no longer exposes the legacy webhook/app split. 4. **Pico Protocol**: `pkg/channels/pico/` implements a custom PicoClaw native protocol channel that receives messages via WebSocket webhook (`/pico/ws`). @@ -1381,4 +1380,4 @@ agentLoop.Stop() // Stop Agent 7. **PlaceholderConfig vs implementation**: `PlaceholderConfig` appears in 6 channel configs (Telegram, Discord, Slack, LINE, OneBot, Pico), but only channels that implement both `PlaceholderCapable` + `MessageEditor` (Telegram, Discord, Pico) can actually use placeholder message editing. The rest are reserved fields. -8. **ReasoningChannelID**: Most channel configs include a `reasoning_channel_id` field to route LLM reasoning/thinking output to a designated channel (WhatsApp, Telegram, Feishu, Discord, MaixCam, QQ, DingTalk, Slack, LINE, OneBot, WeCom, WeComApp). Note: `PicoConfig` does not currently expose this field. `BaseChannel` exposes this via the `WithReasoningChannelID` option and `ReasoningChannelID()` method. \ No newline at end of file +8. **ReasoningChannelID**: Most channel configs include a `reasoning_channel_id` field to route LLM reasoning/thinking output to a designated channel (WhatsApp, Telegram, Feishu, Discord, MaixCam, QQ, DingTalk, Slack, LINE, OneBot, WeCom). Note: `PicoConfig` does not currently expose this field. `BaseChannel` exposes this via the `WithReasoningChannelID` option and `ReasoningChannelID()` method. 
diff --git a/pkg/channels/README.zh.md b/pkg/channels/README.zh.md index 2c5e7356e..8bc8c8dbc 100644 --- a/pkg/channels/README.zh.md +++ b/pkg/channels/README.zh.md @@ -1254,8 +1254,7 @@ make test # 全量测试 | `pkg/channels/onebot/` | `"onebot"` | ReactionCapable, MediaSender | | `pkg/channels/dingtalk/` | `"dingtalk"` | — | | `pkg/channels/feishu/` | `"feishu"` | — (架构特定 build tags: `feishu_32.go` / `feishu_64.go`) | -| `pkg/channels/wecom/` | `"wecom"` | WebhookHandler, HealthChecker | -| `pkg/channels/wecom/` | `"wecom_app"` | MediaSender, WebhookHandler, HealthChecker | +| `pkg/channels/wecom/` | `"wecom"` | MediaSender | | `pkg/channels/qq/` | `"qq"` | — | | `pkg/channels/whatsapp/` | `"whatsapp"` | — (Bridge 模式) | | `pkg/channels/whatsapp_native/` | `"whatsapp_native"` | — (原生 whatsmeow 模式) | @@ -1370,7 +1369,7 @@ agentLoop.Stop() // 停止 Agent 2. **Feishu 架构特定编译**:Feishu channel 使用 build tags 区分 32 位和 64 位架构(`feishu_32.go` / `feishu_64.go`)。Feishu 使用 SDK 的 WebSocket 模式(非 HTTP webhook),因此不实现 `WebhookHandler`。 -3. **WeCom 有两个工厂**:`"wecom"`(Bot 模式,纯 webhook)和 `"wecom_app"`(应用模式,支持 MediaSender)分别注册。两者都实现了 `WebhookHandler` 和 `HealthChecker`。 +3. **WeCom 现在只有一个 channel**:`"wecom"` 采用 WebSocket AI Bot 实现,带路由持久化;访问控制走统一的 channel 白名单机制,不再保留旧的 webhook/app 双分支。 4. **Pico Protocol**:`pkg/channels/pico/` 实现了一个自定义的 PicoClaw 原生协议 channel,通过 WebSocket webhook (`/pico/ws`) 接收消息。 @@ -1380,4 +1379,4 @@ agentLoop.Stop() // 停止 Agent 7. **PlaceholderConfig 的配置与实现**:`PlaceholderConfig` 出现在 6 个 channel config 中(Telegram、Discord、Slack、LINE、OneBot、Pico),但只有实现了 `PlaceholderCapable` + `MessageEditor` 的 channel(Telegram、Discord、Pico)能真正使用占位消息编辑功能。其余 channel 的 `PlaceholderConfig` 为预留字段。 -8. 
**ReasoningChannelID**:大多数 channel config 都包含 `reasoning_channel_id` 字段,用于将 LLM 的思维链(reasoning/thinking)路由到指定 channel(WhatsApp、Telegram、Feishu、Discord、MaixCam、QQ、DingTalk、Slack、LINE、OneBot、WeCom、WeComApp)。注意:`PicoConfig` 目前不包含该字段。`BaseChannel` 通过 `WithReasoningChannelID` 选项和 `ReasoningChannelID()` 方法暴露此配置。 \ No newline at end of file +8. **ReasoningChannelID**:大多数 channel config 都包含 `reasoning_channel_id` 字段,用于将 LLM 的思维链(reasoning/thinking)路由到指定 channel(WhatsApp、Telegram、Feishu、Discord、MaixCam、QQ、DingTalk、Slack、LINE、OneBot、WeCom)。注意:`PicoConfig` 目前不包含该字段。`BaseChannel` 通过 `WithReasoningChannelID` 选项和 `ReasoningChannelID()` 方法暴露此配置。 diff --git a/pkg/config/SECURITY_CONFIG.md b/pkg/config/SECURITY_CONFIG.md index c5aed54ae..4f783aaa5 100644 --- a/pkg/config/SECURITY_CONFIG.md +++ b/pkg/config/SECURITY_CONFIG.md @@ -99,13 +99,7 @@ Examples: - `ref:channels.line.channel_secret` - `ref:channels.line.channel_access_token` - `ref:channels.onebot.access_token` -- `ref:channels.wecom.token` -- `ref:channels.wecom.encoding_aes_key` -- `ref:channels.wecom_app.corp_secret` -- `ref:channels.wecom_app.token` -- `ref:channels.wecom_app.encoding_aes_key` -- `ref:channels.wecom_aibot.token` -- `ref:channels.wecom_aibot.encoding_aes_key` +- `ref:channels.wecom.secret` - `ref:channels.pico.token` - `ref:channels.irc.password` - `ref:channels.irc.nickserv_password` diff --git a/pkg/config/example_security_usage.go b/pkg/config/example_security_usage.go index cba76c6bc..09aee9aa3 100644 --- a/pkg/config/example_security_usage.go +++ b/pkg/config/example_security_usage.go @@ -153,13 +153,7 @@ Both single and multiple keys should use the array format. 
- ref:channels.line.channel_secret - ref:channels.line.channel_access_token - ref:channels.onebot.access_token -- ref:channels.wecom.token -- ref:channels.wecom.encoding_aes_key -- ref:channels.wecom_app.corp_secret -- ref:channels.wecom_app.token -- ref:channels.wecom_app.encoding_aes_key -- ref:channels.wecom_aibot.token -- ref:channels.wecom_aibot.encoding_aes_key +- ref:channels.wecom.secret - ref:channels.pico.token - ref:channels.irc.password - ref:channels.irc.nickserv_password From e760cb737c9d8d0a00faf5cb75c352fdba914a73 Mon Sep 17 00:00:00 2001 From: Hoshina Date: Tue, 24 Mar 2026 15:04:46 +0800 Subject: [PATCH 24/39] feat(auth): add wecom cli qr login --- cmd/picoclaw/internal/auth/command.go | 1 + cmd/picoclaw/internal/auth/command_test.go | 1 + cmd/picoclaw/internal/auth/wecom.go | 407 +++++++++++++++++++++ cmd/picoclaw/internal/auth/wecom_test.go | 157 ++++++++ 4 files changed, 566 insertions(+) create mode 100644 cmd/picoclaw/internal/auth/wecom.go create mode 100644 cmd/picoclaw/internal/auth/wecom_test.go diff --git a/cmd/picoclaw/internal/auth/command.go b/cmd/picoclaw/internal/auth/command.go index 149095699..9de083d8d 100644 --- a/cmd/picoclaw/internal/auth/command.go +++ b/cmd/picoclaw/internal/auth/command.go @@ -17,6 +17,7 @@ func NewAuthCommand() *cobra.Command { newStatusCommand(), newModelsCommand(), newWeixinCommand(), + newWeComCommand(), ) return cmd diff --git a/cmd/picoclaw/internal/auth/command_test.go b/cmd/picoclaw/internal/auth/command_test.go index 12f2bc186..3c7f2d3d6 100644 --- a/cmd/picoclaw/internal/auth/command_test.go +++ b/cmd/picoclaw/internal/auth/command_test.go @@ -33,6 +33,7 @@ func TestNewAuthCommand(t *testing.T) { "status", "models", "weixin", + "wecom", } subcommands := cmd.Commands() diff --git a/cmd/picoclaw/internal/auth/wecom.go b/cmd/picoclaw/internal/auth/wecom.go new file mode 100644 index 000000000..8261f5f80 --- /dev/null +++ b/cmd/picoclaw/internal/auth/wecom.go @@ -0,0 +1,407 @@ +package auth + +import 
( + "context" + "encoding/json" + "errors" + "fmt" + "io" + "net/http" + "net/url" + "os" + "runtime" + "strconv" + "strings" + "time" + + "github.com/mdp/qrterminal/v3" + "github.com/spf13/cobra" + + "github.com/sipeed/picoclaw/cmd/picoclaw/internal" + "github.com/sipeed/picoclaw/pkg/config" +) + +const ( + wecomQRSourceID = "picoclaw" + wecomQRGenerateEndpoint = "https://work.weixin.qq.com/ai/qc/generate" + wecomQRQueryEndpoint = "https://work.weixin.qq.com/ai/qc/query_result" + wecomQRPageEndpoint = "https://work.weixin.qq.com/ai/qc/gen" + wecomQRHTTPTimeout = 15 * time.Second + wecomQRPollInterval = 3 * time.Second + wecomQRPollTimeout = 5 * time.Minute + wecomDefaultWebSocketURL = "wss://openws.work.weixin.qq.com" +) + +type wecomQRScanner func(context.Context, wecomQRFlowOptions) (wecomQRBotInfo, error) + +type wecomQRFlowOptions struct { + HTTPClient *http.Client + GenerateURL string + QueryURL string + QRCodePageURL string + SourceID string + PollInterval time.Duration + PollTimeout time.Duration + Writer io.Writer +} + +type wecomQRBotInfo struct { + BotID string + Secret string +} + +type wecomQRSession struct { + SCode string + AuthURL string +} + +type wecomQRGenerateResponse struct { + ErrCode int `json:"errcode,omitempty"` + ErrMsg string `json:"errmsg,omitempty"` + Data struct { + SCode string `json:"scode"` + AuthURL string `json:"auth_url"` + } `json:"data"` +} + +type wecomQRQueryResponse struct { + ErrCode int `json:"errcode,omitempty"` + ErrMsg string `json:"errmsg,omitempty"` + Data struct { + Status string `json:"status"` + BotInfo struct { + BotID string `json:"botid"` + Secret string `json:"secret"` + } `json:"bot_info"` + } `json:"data"` +} + +func newWeComCommand() *cobra.Command { + var timeout time.Duration + + cmd := &cobra.Command{ + Use: "wecom", + Short: "Scan a WeCom QR code and configure channels.wecom", + Args: cobra.NoArgs, + RunE: func(_ *cobra.Command, _ []string) error { + return authWeComCmd(timeout) + }, + } + + 
cmd.Flags().DurationVar(&timeout, "timeout", wecomQRPollTimeout, "How long to wait for QR confirmation") + + return cmd +} + +func authWeComCmd(timeout time.Duration) error { + return authWeComCmdWithScanner(context.Background(), os.Stdout, timeout, scanWeComQRCodeInteractive) +} + +func authWeComCmdWithScanner( + ctx context.Context, + writer io.Writer, + timeout time.Duration, + scanner wecomQRScanner, +) error { + if scanner == nil { + return fmt.Errorf("wecom QR scanner is nil") + } + if writer == nil { + writer = os.Stdout + } + + cfg, err := internal.LoadConfig() + if err != nil { + return fmt.Errorf("failed to load config: %w", err) + } + + opts := defaultWeComQRFlowOptions(timeout) + opts.Writer = writer + + botInfo, err := scanner(ctx, opts) + if err != nil { + return err + } + + applyWeComAuthResult(cfg, botInfo) + + if saveErr := config.SaveConfig(internal.GetConfigPath(), cfg); saveErr != nil { + return fmt.Errorf("failed to save config: %w", saveErr) + } + + fmt.Fprintln(writer) + fmt.Fprintln(writer, "WeCom connected.") + fmt.Fprintf(writer, "Bot ID: %s\n", botInfo.BotID) + fmt.Fprintf(writer, "Config: %s\n", internal.GetConfigPath()) + + return nil +} + +func defaultWeComQRFlowOptions(timeout time.Duration) wecomQRFlowOptions { + if timeout <= 0 { + timeout = wecomQRPollTimeout + } + + return wecomQRFlowOptions{ + HTTPClient: &http.Client{Timeout: wecomQRHTTPTimeout}, + GenerateURL: wecomQRGenerateEndpoint, + QueryURL: wecomQRQueryEndpoint, + QRCodePageURL: wecomQRPageEndpoint, + SourceID: wecomQRSourceID, + PollInterval: wecomQRPollInterval, + PollTimeout: timeout, + Writer: os.Stdout, + } +} + +func applyWeComAuthResult(cfg *config.Config, botInfo wecomQRBotInfo) { + cfg.Channels.WeCom.Enabled = true + cfg.Channels.WeCom.BotID = botInfo.BotID + cfg.Channels.WeCom.SetSecret(botInfo.Secret) + if strings.TrimSpace(cfg.Channels.WeCom.WebSocketURL) == "" { + cfg.Channels.WeCom.WebSocketURL = wecomDefaultWebSocketURL + } +} + +func 
scanWeComQRCodeInteractive(ctx context.Context, opts wecomQRFlowOptions) (wecomQRBotInfo, error) { + opts = normalizeWeComQRFlowOptions(opts) + + fmt.Fprintln(opts.Writer, "Requesting WeCom QR code...") + + session, err := fetchWeComQRCode(ctx, opts) + if err != nil { + return wecomQRBotInfo{}, err + } + + fmt.Fprintln(opts.Writer) + fmt.Fprintln(opts.Writer, "=======================================================") + fmt.Fprintln(opts.Writer, "Please scan the following QR code with WeCom:") + fmt.Fprintln(opts.Writer, "=======================================================") + fmt.Fprintln(opts.Writer) + + qrterminal.GenerateWithConfig(session.AuthURL, qrterminal.Config{ + Level: qrterminal.L, + Writer: opts.Writer, + HalfBlocks: true, + }) + + pageURL, err := buildWeComQRCodePageURL(opts.QRCodePageURL, opts.SourceID, session.SCode) + if err != nil { + return wecomQRBotInfo{}, err + } + + fmt.Fprintln(opts.Writer) + fmt.Fprintf(opts.Writer, "QR Code Link: %s\n", pageURL) + fmt.Fprintln(opts.Writer) + fmt.Fprintln(opts.Writer, "Waiting for scan...") + + return pollWeComQRCodeResult(ctx, opts, session.SCode) +} + +func normalizeWeComQRFlowOptions(opts wecomQRFlowOptions) wecomQRFlowOptions { + if opts.HTTPClient == nil { + opts.HTTPClient = &http.Client{Timeout: wecomQRHTTPTimeout} + } + if strings.TrimSpace(opts.GenerateURL) == "" { + opts.GenerateURL = wecomQRGenerateEndpoint + } + if strings.TrimSpace(opts.QueryURL) == "" { + opts.QueryURL = wecomQRQueryEndpoint + } + if strings.TrimSpace(opts.QRCodePageURL) == "" { + opts.QRCodePageURL = wecomQRPageEndpoint + } + if strings.TrimSpace(opts.SourceID) == "" { + opts.SourceID = wecomQRSourceID + } + if opts.PollInterval <= 0 { + opts.PollInterval = wecomQRPollInterval + } + if opts.PollTimeout <= 0 { + opts.PollTimeout = wecomQRPollTimeout + } + if opts.Writer == nil { + opts.Writer = os.Stdout + } + + return opts +} + +func fetchWeComQRCode(ctx context.Context, opts wecomQRFlowOptions) (wecomQRSession, error) { + 
generateURL, err := buildWeComQRGenerateURL(opts.GenerateURL, opts.SourceID, wecomPlatformCode()) + if err != nil { + return wecomQRSession{}, err + } + + var resp wecomQRGenerateResponse + if err := doWeComJSONGet(ctx, opts.HTTPClient, generateURL, &resp); err != nil { + return wecomQRSession{}, fmt.Errorf("failed to get WeCom QR code: %w", err) + } + if resp.ErrCode != 0 { + return wecomQRSession{}, fmt.Errorf( + "failed to get WeCom QR code: errcode=%d errmsg=%s", + resp.ErrCode, + resp.ErrMsg, + ) + } + if resp.Data.SCode == "" || resp.Data.AuthURL == "" { + return wecomQRSession{}, fmt.Errorf("failed to get WeCom QR code: response missing scode or auth_url") + } + + return wecomQRSession{ + SCode: resp.Data.SCode, + AuthURL: resp.Data.AuthURL, + }, nil +} + +func pollWeComQRCodeResult(ctx context.Context, opts wecomQRFlowOptions, scode string) (wecomQRBotInfo, error) { + if strings.TrimSpace(scode) == "" { + return wecomQRBotInfo{}, fmt.Errorf("missing WeCom QR scode") + } + + timeoutCtx, cancel := context.WithTimeout(ctx, opts.PollTimeout) + defer cancel() + + var scannedPrinted bool + + for { + status, err := queryWeComQRCodeStatus(timeoutCtx, opts, scode) + if err != nil { + if errors.Is(err, context.DeadlineExceeded) || errors.Is(timeoutCtx.Err(), context.DeadlineExceeded) { + return wecomQRBotInfo{}, fmt.Errorf("WeCom QR scan timed out after %s", opts.PollTimeout) + } + return wecomQRBotInfo{}, err + } + + switch strings.ToLower(status.Data.Status) { + case "success": + if status.Data.BotInfo.BotID == "" || status.Data.BotInfo.Secret == "" { + return wecomQRBotInfo{}, fmt.Errorf("WeCom QR scan succeeded but bot credentials are missing") + } + return wecomQRBotInfo{ + BotID: status.Data.BotInfo.BotID, + Secret: status.Data.BotInfo.Secret, + }, nil + case "expired": + return wecomQRBotInfo{}, fmt.Errorf("WeCom QR code expired, please retry") + case "scaned", "scanned": + if !scannedPrinted { + fmt.Fprintln(opts.Writer, "QR code scanned. 
Confirm the login in WeCom.") + scannedPrinted = true + } + } + + select { + case <-timeoutCtx.Done(): + if errors.Is(timeoutCtx.Err(), context.DeadlineExceeded) { + return wecomQRBotInfo{}, fmt.Errorf("WeCom QR scan timed out after %s", opts.PollTimeout) + } + return wecomQRBotInfo{}, timeoutCtx.Err() + case <-time.After(opts.PollInterval): + } + } +} + +func queryWeComQRCodeStatus(ctx context.Context, opts wecomQRFlowOptions, scode string) (wecomQRQueryResponse, error) { + queryURL, err := buildWeComQRQueryURL(opts.QueryURL, scode) + if err != nil { + return wecomQRQueryResponse{}, err + } + + var resp wecomQRQueryResponse + if err := doWeComJSONGet(ctx, opts.HTTPClient, queryURL, &resp); err != nil { + return wecomQRQueryResponse{}, fmt.Errorf("failed to query WeCom QR result: %w", err) + } + if resp.ErrCode != 0 { + return wecomQRQueryResponse{}, fmt.Errorf( + "failed to query WeCom QR result: errcode=%d errmsg=%s", + resp.ErrCode, + resp.ErrMsg, + ) + } + + return resp, nil +} + +func buildWeComQRGenerateURL(baseURL, sourceID string, platformCode int) (string, error) { + u, err := url.Parse(baseURL) + if err != nil { + return "", fmt.Errorf("invalid WeCom QR generate URL: %w", err) + } + + query := u.Query() + query.Set("source", sourceID) + query.Set("sourceID", sourceID) + query.Set("plat", strconv.Itoa(platformCode)) + u.RawQuery = query.Encode() + + return u.String(), nil +} + +func buildWeComQRQueryURL(baseURL, scode string) (string, error) { + u, err := url.Parse(baseURL) + if err != nil { + return "", fmt.Errorf("invalid WeCom QR query URL: %w", err) + } + + query := u.Query() + query.Set("scode", scode) + u.RawQuery = query.Encode() + + return u.String(), nil +} + +func buildWeComQRCodePageURL(baseURL, sourceID, scode string) (string, error) { + u, err := url.Parse(baseURL) + if err != nil { + return "", fmt.Errorf("invalid WeCom QR page URL: %w", err) + } + + query := u.Query() + query.Set("source", sourceID) + query.Set("sourceID", sourceID) + 
query.Set("scode", scode) + u.RawQuery = query.Encode() + + return u.String(), nil +} + +func doWeComJSONGet(ctx context.Context, client *http.Client, targetURL string, out any) error { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, targetURL, nil) + if err != nil { + return err + } + + resp, err := client.Do(req) + if err != nil { + return err + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + body, readErr := io.ReadAll(io.LimitReader(resp.Body, 8192)) + if readErr != nil { + return fmt.Errorf("unexpected status %s", resp.Status) + } + return fmt.Errorf("unexpected status %s: %s", resp.Status, strings.TrimSpace(string(body))) + } + + if err := json.NewDecoder(resp.Body).Decode(out); err != nil { + return fmt.Errorf("decode JSON response: %w", err) + } + + return nil +} + +func wecomPlatformCode() int { + switch runtime.GOOS { + case "darwin": + return 1 + case "windows": + return 2 + case "linux": + return 3 + default: + return 0 + } +} diff --git a/cmd/picoclaw/internal/auth/wecom_test.go b/cmd/picoclaw/internal/auth/wecom_test.go new file mode 100644 index 000000000..c2a4624ae --- /dev/null +++ b/cmd/picoclaw/internal/auth/wecom_test.go @@ -0,0 +1,157 @@ +package auth + +import ( + "bytes" + "context" + "net/http" + "net/http/httptest" + "net/url" + "path/filepath" + "strconv" + "sync/atomic" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/sipeed/picoclaw/cmd/picoclaw/internal" + "github.com/sipeed/picoclaw/pkg/config" +) + +func TestNewWeComCommand(t *testing.T) { + cmd := newWeComCommand() + + require.NotNil(t, cmd) + assert.Equal(t, "wecom", cmd.Use) + assert.Equal(t, "Scan a WeCom QR code and configure channels.wecom", cmd.Short) + assert.NotNil(t, cmd.Flags().Lookup("timeout")) +} + +func TestBuildWeComQRGenerateURL(t *testing.T) { + rawURL, err := buildWeComQRGenerateURL("https://example.com/ai/qc/generate", wecomQRSourceID, 3) + require.NoError(t, 
err) + + parsed, err := url.Parse(rawURL) + require.NoError(t, err) + + assert.Equal(t, wecomQRSourceID, parsed.Query().Get("source")) + assert.Equal(t, wecomQRSourceID, parsed.Query().Get("sourceID")) + assert.Equal(t, "3", parsed.Query().Get("plat")) +} + +func TestBuildWeComQRCodePageURL(t *testing.T) { + rawURL, err := buildWeComQRCodePageURL("https://example.com/ai/qc/gen", wecomQRSourceID, "scode-1") + require.NoError(t, err) + + parsed, err := url.Parse(rawURL) + require.NoError(t, err) + + assert.Equal(t, wecomQRSourceID, parsed.Query().Get("source")) + assert.Equal(t, wecomQRSourceID, parsed.Query().Get("sourceID")) + assert.Equal(t, "scode-1", parsed.Query().Get("scode")) +} + +func TestFetchWeComQRCode(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, "/generate", r.URL.Path) + assert.Equal(t, wecomQRSourceID, r.URL.Query().Get("source")) + assert.Equal(t, wecomQRSourceID, r.URL.Query().Get("sourceID")) + assert.Equal(t, strconv.Itoa(wecomPlatformCode()), r.URL.Query().Get("plat")) + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"data":{"scode":"scode-1","auth_url":"https://example.com/qr"}}`)) + })) + defer server.Close() + + opts := normalizeWeComQRFlowOptions(wecomQRFlowOptions{ + HTTPClient: server.Client(), + GenerateURL: server.URL + "/generate", + Writer: bytes.NewBuffer(nil), + }) + + session, err := fetchWeComQRCode(context.Background(), opts) + require.NoError(t, err) + assert.Equal(t, "scode-1", session.SCode) + assert.Equal(t, "https://example.com/qr", session.AuthURL) +} + +func TestPollWeComQRCodeResult(t *testing.T) { + var calls atomic.Int32 + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + call := calls.Add(1) + assert.Equal(t, "/query", r.URL.Path) + assert.Equal(t, "scode-1", r.URL.Query().Get("scode")) + w.Header().Set("Content-Type", "application/json") + switch call { + case 
1: + _, _ = w.Write([]byte(`{"data":{"status":"wait"}}`)) + case 2: + _, _ = w.Write([]byte(`{"data":{"status":"scaned"}}`)) + default: + _, _ = w.Write([]byte(`{"data":{"status":"success","bot_info":{"botid":"bot-1","secret":"secret-1"}}}`)) + } + })) + defer server.Close() + + var output bytes.Buffer + opts := normalizeWeComQRFlowOptions(wecomQRFlowOptions{ + HTTPClient: server.Client(), + QueryURL: server.URL + "/query", + PollInterval: time.Millisecond, + PollTimeout: time.Second, + Writer: &output, + }) + + botInfo, err := pollWeComQRCodeResult(context.Background(), opts, "scode-1") + require.NoError(t, err) + assert.Equal(t, "bot-1", botInfo.BotID) + assert.Equal(t, "secret-1", botInfo.Secret) + assert.Contains(t, output.String(), "QR code scanned. Confirm the login in WeCom.") +} + +func TestApplyWeComAuthResult(t *testing.T) { + cfg := config.DefaultConfig() + cfg.Channels.WeCom.WebSocketURL = "" + + applyWeComAuthResult(cfg, wecomQRBotInfo{ + BotID: "bot-1", + Secret: "secret-1", + }) + + assert.True(t, cfg.Channels.WeCom.Enabled) + assert.Equal(t, "bot-1", cfg.Channels.WeCom.BotID) + assert.Equal(t, "secret-1", cfg.Channels.WeCom.Secret()) + assert.Equal(t, wecomDefaultWebSocketURL, cfg.Channels.WeCom.WebSocketURL) +} + +func TestAuthWeComCmdWithScanner(t *testing.T) { + tmpDir := t.TempDir() + configPath := filepath.Join(tmpDir, "config.json") + + t.Setenv(config.EnvHome, tmpDir) + t.Setenv(config.EnvConfig, configPath) + + var output bytes.Buffer + err := authWeComCmdWithScanner( + context.Background(), + &output, + time.Second, + func(_ context.Context, opts wecomQRFlowOptions) (wecomQRBotInfo, error) { + assert.Equal(t, wecomQRSourceID, opts.SourceID) + return wecomQRBotInfo{ + BotID: "bot-1", + Secret: "secret-1", + }, nil + }, + ) + require.NoError(t, err) + + cfg, err := config.LoadConfig(internal.GetConfigPath()) + require.NoError(t, err) + assert.True(t, cfg.Channels.WeCom.Enabled) + assert.Equal(t, "bot-1", cfg.Channels.WeCom.BotID) + 
assert.Equal(t, "secret-1", cfg.Channels.WeCom.Secret()) + assert.Equal(t, wecomDefaultWebSocketURL, cfg.Channels.WeCom.WebSocketURL) + assert.Contains(t, output.String(), "WeCom connected.") +} From c3631d84ba5ab68c7cb008681e317f32c728fba7 Mon Sep 17 00:00:00 2001 From: Hoshina Date: Tue, 24 Mar 2026 16:12:28 +0800 Subject: [PATCH 25/39] feat(wecom): send media via temp uploads --- pkg/channels/wecom/media.go | 513 ++++++++++++++++++++++++++++++- pkg/channels/wecom/protocol.go | 64 +++- pkg/channels/wecom/wecom.go | 146 +++++++-- pkg/channels/wecom/wecom_test.go | 351 ++++++++++++++++++++- 4 files changed, 1037 insertions(+), 37 deletions(-) diff --git a/pkg/channels/wecom/media.go b/pkg/channels/wecom/media.go index defe226d4..ebcc481e8 100644 --- a/pkg/channels/wecom/media.go +++ b/pkg/channels/wecom/media.go @@ -4,7 +4,10 @@ import ( "context" "crypto/aes" "crypto/cipher" + "crypto/md5" "encoding/base64" + "encoding/hex" + "encoding/json" "fmt" "io" "mime" @@ -13,12 +16,73 @@ import ( "os" "path/filepath" "strings" + "time" "github.com/h2non/filetype" + "github.com/sipeed/picoclaw/pkg/bus" "github.com/sipeed/picoclaw/pkg/media" ) +const ( + wecomOutboundMediaMaxBytes = 20 << 20 + wecomOutboundImageMaxBytes = 2 << 20 + wecomOutboundVoiceMaxBytes = 2 << 20 + wecomOutboundVideoMaxBytes = 10 << 20 + wecomUploadChunkMaxBytes = 512 << 10 + wecomUploadMaxChunks = 100 + wecomUploadMinBytes = 5 +) + +type wecomOutboundMedia struct { + MsgType string + MediaID string + Title string + Description string +} + +func (m *wecomOutboundMedia) respondBody() wecomRespondMsgBody { + body := wecomRespondMsgBody{MsgType: m.MsgType} + switch m.MsgType { + case "file": + body.File = &wecomMediaRefContent{MediaID: m.MediaID} + case "image": + body.Image = &wecomMediaRefContent{MediaID: m.MediaID} + case "voice": + body.Voice = &wecomMediaRefContent{MediaID: m.MediaID} + case "video": + body.Video = &wecomVideoContent{ + MediaID: m.MediaID, + Title: m.Title, + Description: 
m.Description, + } + } + return body +} + +func (m *wecomOutboundMedia) sendBody(chatID string, chatType uint32) wecomSendMsgBody { + body := wecomSendMsgBody{ + ChatID: chatID, + ChatType: chatType, + MsgType: m.MsgType, + } + switch m.MsgType { + case "file": + body.File = &wecomMediaRefContent{MediaID: m.MediaID} + case "image": + body.Image = &wecomMediaRefContent{MediaID: m.MediaID} + case "voice": + body.Voice = &wecomMediaRefContent{MediaID: m.MediaID} + case "video": + body.Video = &wecomVideoContent{ + MediaID: m.MediaID, + Title: m.Title, + Description: m.Description, + } + } + return body +} + func decodeMediaAESKey(value string) ([]byte, error) { if value == "" { return nil, nil @@ -227,12 +291,11 @@ func (c *WeComChannel) storeRemoteMedia( return "", fmt.Errorf("download media returned HTTP %d", resp.StatusCode) } - const maxSize = 20 << 20 - data, err := io.ReadAll(io.LimitReader(resp.Body, maxSize+1)) + data, err := io.ReadAll(io.LimitReader(resp.Body, wecomOutboundMediaMaxBytes+1)) if err != nil { return "", fmt.Errorf("read media: %w", err) } - if len(data) > maxSize { + if len(data) > wecomOutboundMediaMaxBytes { return "", fmt.Errorf("media too large") } @@ -289,3 +352,447 @@ func (c *WeComChannel) storeRemoteMedia( } return ref, nil } + +func detectLocalWeComContentType(localPath, hint string) string { + contentType := normalizeWeComContentType(hint) + if !isGenericWeComContentType(contentType) { + return contentType + } + + if kind, err := filetype.MatchFile(localPath); err == nil && kind != filetype.Unknown { + return normalizeWeComContentType(kind.MIME.Value) + } + + if ext := strings.ToLower(filepath.Ext(localPath)); ext != "" { + if byExt := normalizeWeComContentType(mime.TypeByExtension(ext)); byExt != "" { + return byExt + } + } + + file, err := os.Open(localPath) + if err != nil { + return contentType + } + defer file.Close() + + buf := make([]byte, 512) + n, err := file.Read(buf) + if err != nil && err != io.EOF { + return contentType + 
} + if n == 0 { + return contentType + } + return normalizeWeComContentType(http.DetectContentType(buf[:n])) +} + +func writeWeComTempFile(prefix, filename string, data []byte) (string, error) { + mediaDir := filepath.Join(os.TempDir(), "picoclaw_media") + if err := os.MkdirAll(mediaDir, 0o700); err != nil { + return "", fmt.Errorf("mkdir media dir: %w", err) + } + + ext := strings.ToLower(filepath.Ext(filename)) + tmpFile, err := os.CreateTemp(mediaDir, prefix+"-*"+ext) + if err != nil { + return "", fmt.Errorf("create temp file: %w", err) + } + tmpPath := tmpFile.Name() + + if _, err := tmpFile.Write(data); err != nil { + _ = tmpFile.Close() + _ = os.Remove(tmpPath) + return "", fmt.Errorf("write temp file: %w", err) + } + if err := tmpFile.Close(); err != nil { + _ = os.Remove(tmpPath) + return "", fmt.Errorf("close temp file: %w", err) + } + return tmpPath, nil +} + +func (c *WeComChannel) downloadRemoteMediaToTemp( + ctx context.Context, + resourceURL, fallbackName string, +) (string, string, string, error) { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, resourceURL, nil) + if err != nil { + return "", "", "", fmt.Errorf("create request: %w", err) + } + + resp, err := c.mediaClient.Do(req) + if err != nil { + return "", "", "", fmt.Errorf("download media: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(io.LimitReader(resp.Body, 1024)) + return "", "", "", fmt.Errorf("download media returned HTTP %d: %s", resp.StatusCode, string(body)) + } + + data, err := io.ReadAll(io.LimitReader(resp.Body, wecomOutboundMediaMaxBytes+1)) + if err != nil { + return "", "", "", fmt.Errorf("read media: %w", err) + } + if len(data) > wecomOutboundMediaMaxBytes { + return "", "", "", fmt.Errorf("media too large") + } + + filename, contentType := detectWeComMediaMetadata( + data, + fallbackName, + resp.Header.Get("Content-Type"), + resourceURL, + resp.Header.Get("Content-Disposition"), + ) + tmpPath, err := 
writeWeComTempFile("wecom-outbound", filename, data) + if err != nil { + return "", "", "", err + } + return tmpPath, filename, contentType, nil +} + +func (c *WeComChannel) resolveOutboundPart( + ctx context.Context, + part bus.MediaPart, +) (string, string, string, func(), error) { + cleanup := func() {} + filename := sanitizeWeComFilename(part.Filename) + contentType := normalizeWeComContentType(part.ContentType) + ref := strings.TrimSpace(part.Ref) + + switch { + case ref == "": + return "", filename, contentType, cleanup, nil + + case strings.HasPrefix(ref, "http://") || strings.HasPrefix(ref, "https://"): + localPath, name, ct, err := c.downloadRemoteMediaToTemp(ctx, ref, filename) + if err != nil { + return "", "", "", cleanup, err + } + return localPath, name, ct, func() { _ = os.Remove(localPath) }, nil + + case strings.HasPrefix(ref, "media://"): + store := c.GetMediaStore() + if store == nil { + return "", "", "", cleanup, fmt.Errorf("no media store available") + } + + localPath, meta, err := store.ResolveWithMeta(ref) + if err != nil { + return "", "", "", cleanup, err + } + if filename == "" { + filename = sanitizeWeComFilename(meta.Filename) + } + if contentType == "" { + contentType = normalizeWeComContentType(meta.ContentType) + } + if strings.HasPrefix(localPath, "http://") || strings.HasPrefix(localPath, "https://") { + tmpPath, name, ct, err := c.downloadRemoteMediaToTemp(ctx, localPath, filename) + if err != nil { + return "", "", "", cleanup, err + } + return tmpPath, name, ct, func() { _ = os.Remove(tmpPath) }, nil + } + if _, err := os.Stat(localPath); err != nil { + return "", "", "", cleanup, err + } + if filename == "" { + filename = sanitizeWeComFilename(filepath.Base(localPath)) + } + if contentType == "" { + contentType = detectLocalWeComContentType(localPath, "") + } + return localPath, filename, contentType, cleanup, nil + + case strings.HasPrefix(ref, "file://"): + u, err := url.Parse(ref) + if err != nil { + return "", "", "", 
cleanup, err + } + localPath := u.Path + if _, err := os.Stat(localPath); err != nil { + return "", "", "", cleanup, err + } + if filename == "" { + filename = sanitizeWeComFilename(filepath.Base(localPath)) + } + if contentType == "" { + contentType = detectLocalWeComContentType(localPath, "") + } + return localPath, filename, contentType, cleanup, nil + + default: + if _, err := os.Stat(ref); err != nil { + return "", "", "", cleanup, err + } + if filename == "" { + filename = sanitizeWeComFilename(filepath.Base(ref)) + } + if contentType == "" { + contentType = detectLocalWeComContentType(ref, "") + } + return ref, filename, contentType, cleanup, nil + } +} + +func canWeComSendImage(contentType, ext string, size int64) bool { + if size > wecomOutboundImageMaxBytes { + return false + } + switch normalizeWeComContentType(contentType) { + case "image/jpeg", "image/jpg", "image/png", "image/gif": + return true + } + switch strings.ToLower(ext) { + case ".jpg", ".jpeg", ".png", ".gif": + return true + default: + return false + } +} + +func canWeComSendVoice(contentType, ext string, size int64) bool { + if size > wecomOutboundVoiceMaxBytes { + return false + } + contentType = normalizeWeComContentType(contentType) + return strings.Contains(contentType, "amr") || strings.EqualFold(ext, ".amr") +} + +func canWeComSendVideo(contentType, ext string, size int64) bool { + if size > wecomOutboundVideoMaxBytes { + return false + } + return normalizeWeComContentType(contentType) == "video/mp4" || strings.EqualFold(ext, ".mp4") +} + +func outboundWeComMediaKind(partType, filename, contentType string, size int64) string { + if size < wecomUploadMinBytes { + return "" + } + + partType = strings.ToLower(strings.TrimSpace(partType)) + contentType = normalizeWeComContentType(contentType) + ext := strings.ToLower(filepath.Ext(filename)) + + if partType == "file" { + if size <= wecomOutboundMediaMaxBytes { + return "file" + } + return "" + } + + if (partType == "image" || partType == 
"") && canWeComSendImage(contentType, ext, size) { + return "image" + } + if (partType == "audio" || partType == "voice" || partType == "") && canWeComSendVoice(contentType, ext, size) { + return "voice" + } + if (partType == "video" || partType == "") && canWeComSendVideo(contentType, ext, size) { + return "video" + } + if size <= wecomOutboundMediaMaxBytes { + return "file" + } + return "" +} + +func trimWeComBytes(value string, limit int) string { + value = strings.TrimSpace(value) + if limit <= 0 || len(value) <= limit { + return value + } + size := 0 + var out strings.Builder + for _, r := range value { + width := len(string(r)) + if size+width > limit { + break + } + size += width + out.WriteRune(r) + } + return out.String() +} + +func ensureWeComOutboundFilename(filename, localPath, contentType string) string { + filename = sanitizeWeComFilename(filename) + if filename == "" { + filename = sanitizeWeComFilename(filepath.Base(localPath)) + } + if filename == "" { + filename = "media" + } + if filepath.Ext(filename) == "" { + fallbackExt := inferMediaExt(contentType, strings.ToLower(filepath.Ext(localPath))) + if fallbackExt != "" { + filename += fallbackExt + } + } + filename = trimWeComBytes(filename, 256) + if filename == "" { + return "media" + } + return filename +} + +func buildWeComVideoContent(mediaID, filename, description string) *wecomVideoContent { + title := strings.TrimSuffix(filename, filepath.Ext(filename)) + title = trimWeComBytes(title, 64) + if title == "" { + title = "video" + } + description = trimWeComBytes(description, 512) + return &wecomVideoContent{ + MediaID: mediaID, + Title: title, + Description: description, + } +} + +func decodeWeComEnvelopeBody[T any](env wecomEnvelope) (T, error) { + var out T + if len(env.Body) == 0 { + return out, fmt.Errorf("wecom response body is empty") + } + if err := json.Unmarshal(env.Body, &out); err != nil { + return out, fmt.Errorf("decode wecom response body: %w", err) + } + return out, nil +} + 
+func (c *WeComChannel) uploadOutboundMedia( + ctx context.Context, + localPath, filename, contentType string, + part bus.MediaPart, +) (*wecomOutboundMedia, error) { + _ = ctx + + contentType = detectLocalWeComContentType(localPath, contentType) + filename = ensureWeComOutboundFilename(filename, localPath, contentType) + + data, err := os.ReadFile(localPath) + if err != nil { + return nil, fmt.Errorf("read media file: %w", err) + } + size := int64(len(data)) + kind := outboundWeComMediaKind(part.Type, filename, contentType, size) + if kind == "" { + return nil, fmt.Errorf("unsupported wecom media type or size for %q", filename) + } + + totalChunks := (len(data) + wecomUploadChunkMaxBytes - 1) / wecomUploadChunkMaxBytes + if totalChunks <= 0 || totalChunks > wecomUploadMaxChunks { + return nil, fmt.Errorf("wecom upload requires 1-%d chunks, got %d", wecomUploadMaxChunks, totalChunks) + } + + sum := md5.Sum(data) + initEnv, err := c.sendCommandAck(wecomCommand{ + Cmd: wecomCmdUploadMediaInit, + Headers: wecomHeaders{ReqID: randomID(10)}, + Body: wecomUploadMediaInitBody{ + Type: kind, + Filename: filename, + TotalSize: size, + TotalChunks: totalChunks, + MD5: hex.EncodeToString(sum[:]), + }, + }, wecomUploadTimeout) + if err != nil { + return nil, err + } + initResp, err := decodeWeComEnvelopeBody[wecomUploadMediaInitResponse](initEnv) + if err != nil { + return nil, err + } + if strings.TrimSpace(initResp.UploadID) == "" { + return nil, fmt.Errorf("wecom upload init returned empty upload_id") + } + + for idx, offset := 0, 0; offset < len(data); idx, offset = idx+1, offset+wecomUploadChunkMaxBytes { + end := offset + wecomUploadChunkMaxBytes + if end > len(data) { + end = len(data) + } + if err := c.sendCommand(wecomCommand{ + Cmd: wecomCmdUploadMediaChunk, + Headers: wecomHeaders{ReqID: randomID(10)}, + Body: wecomUploadMediaChunkBody{ + UploadID: initResp.UploadID, + ChunkIndex: idx, + Base64Data: base64.StdEncoding.EncodeToString(data[offset:end]), + }, + }, 
wecomUploadTimeout); err != nil { + return nil, err + } + } + + finishEnv, err := c.sendCommandAck(wecomCommand{ + Cmd: wecomCmdUploadMediaEnd, + Headers: wecomHeaders{ReqID: randomID(10)}, + Body: wecomUploadMediaFinishBody{ + UploadID: initResp.UploadID, + }, + }, wecomUploadTimeout) + if err != nil { + return nil, err + } + finishResp, err := decodeWeComEnvelopeBody[wecomUploadMediaFinishResponse](finishEnv) + if err != nil { + return nil, err + } + if strings.TrimSpace(finishResp.MediaID) == "" { + return nil, fmt.Errorf("wecom upload finish returned empty media_id") + } + + uploaded := &wecomOutboundMedia{ + MsgType: kind, + MediaID: finishResp.MediaID, + } + if kind == "video" { + video := buildWeComVideoContent(finishResp.MediaID, filename, part.Caption) + uploaded.Title = video.Title + uploaded.Description = video.Description + } + return uploaded, nil +} + +func fallbackWeComMediaText(part bus.MediaPart, kind, filename string) string { + var lines []string + if caption := strings.TrimSpace(part.Caption); caption != "" { + lines = append(lines, caption) + } + + label := kind + if label == "" { + label = "media" + } + if filename != "" { + lines = append(lines, fmt.Sprintf("[%s: %s]", label, filename)) + } else { + lines = append(lines, fmt.Sprintf("[%s attachment]", label)) + } + + ref := strings.TrimSpace(part.Ref) + if strings.HasPrefix(ref, "http://") || strings.HasPrefix(ref, "https://") { + lines = append(lines, ref) + } + + return strings.Join(lines, "\n") +} + +func (c *WeComChannel) resolveMediaRoute(chatID string) (wecomTurn, uint32, bool) { + if turn, ok := c.getTurn(chatID); ok { + if time.Since(turn.CreatedAt) <= wecomStreamMaxDuration { + return turn, turn.ChatType, true + } + c.deleteTurn(chatID) + } + if route, ok := c.routes.Get(chatID); ok { + return wecomTurn{ChatID: route.ChatID, ChatType: route.ChatType}, route.ChatType, false + } + return wecomTurn{ChatID: chatID}, 0, false +} diff --git a/pkg/channels/wecom/protocol.go 
b/pkg/channels/wecom/protocol.go index 6867d8856..0190e70e5 100644 --- a/pkg/channels/wecom/protocol.go +++ b/pkg/channels/wecom/protocol.go @@ -10,6 +10,9 @@ const ( wecomCmdEventCallback = "aibot_event_callback" wecomCmdRespondMsg = "aibot_respond_msg" wecomCmdSendMsg = "aibot_send_msg" + wecomCmdUploadMediaInit = "aibot_upload_media_init" + wecomCmdUploadMediaChunk = "aibot_upload_media_chunk" + wecomCmdUploadMediaEnd = "aibot_upload_media_finish" wecomMaxContentBytes = 20480 ) @@ -32,15 +35,26 @@ type wecomCommand struct { } type wecomSendMsgBody struct { - ChatID string `json:"chatid"` - ChatType uint32 `json:"chat_type,omitempty"` - MsgType string `json:"msgtype"` - Markdown *wecomMarkdownContent `json:"markdown,omitempty"` + ChatID string `json:"chatid"` + ChatType uint32 `json:"chat_type,omitempty"` + MsgType string `json:"msgtype"` + Markdown *wecomMarkdownContent `json:"markdown,omitempty"` + File *wecomMediaRefContent `json:"file,omitempty"` + Image *wecomMediaRefContent `json:"image,omitempty"` + Voice *wecomMediaRefContent `json:"voice,omitempty"` + Video *wecomVideoContent `json:"video,omitempty"` + TemplateCard map[string]any `json:"template_card,omitempty"` } type wecomRespondMsgBody struct { - MsgType string `json:"msgtype"` - Stream *wecomStreamContent `json:"stream,omitempty"` + MsgType string `json:"msgtype"` + Stream *wecomStreamContent `json:"stream,omitempty"` + Markdown *wecomMarkdownContent `json:"markdown,omitempty"` + File *wecomMediaRefContent `json:"file,omitempty"` + Image *wecomMediaRefContent `json:"image,omitempty"` + Voice *wecomMediaRefContent `json:"voice,omitempty"` + Video *wecomVideoContent `json:"video,omitempty"` + TemplateCard map[string]any `json:"template_card,omitempty"` } type wecomStreamContent struct { @@ -53,6 +67,44 @@ type wecomMarkdownContent struct { Content string `json:"content"` } +type wecomMediaRefContent struct { + MediaID string `json:"media_id"` +} + +type wecomVideoContent struct { + MediaID string 
`json:"media_id"` + Title string `json:"title,omitempty"` + Description string `json:"description,omitempty"` +} + +type wecomUploadMediaInitBody struct { + Type string `json:"type"` + Filename string `json:"filename"` + TotalSize int64 `json:"total_size"` + TotalChunks int `json:"total_chunks"` + MD5 string `json:"md5,omitempty"` +} + +type wecomUploadMediaInitResponse struct { + UploadID string `json:"upload_id"` +} + +type wecomUploadMediaChunkBody struct { + UploadID string `json:"upload_id"` + ChunkIndex int `json:"chunk_index"` + Base64Data string `json:"base64_data"` +} + +type wecomUploadMediaFinishBody struct { + UploadID string `json:"upload_id"` +} + +type wecomUploadMediaFinishResponse struct { + Type string `json:"type"` + MediaID string `json:"media_id"` + CreatedAt json.RawMessage `json:"created_at"` +} + type wecomIncomingMessage struct { MsgID string `json:"msgid"` AIBotID string `json:"aibotid"` diff --git a/pkg/channels/wecom/wecom.go b/pkg/channels/wecom/wecom.go index 11959c259..ac8f8d9c8 100644 --- a/pkg/channels/wecom/wecom.go +++ b/pkg/channels/wecom/wecom.go @@ -23,6 +23,7 @@ import ( const ( wecomConnectTimeout = 15 * time.Second wecomCommandTimeout = 10 * time.Second + wecomUploadTimeout = 30 * time.Second wecomHeartbeatInterval = 30 * time.Second wecomStreamMaxDuration = 5*time.Minute + 30*time.Second wecomRouteTTL = 30 * time.Minute @@ -49,7 +50,7 @@ type WeComChannel struct { recent *recentMessageSet routes *reqIDStore mediaClient *http.Client - commandSend func(wecomCommand, time.Duration) error + commandSend func(wecomCommand, time.Duration) (wecomEnvelope, error) } type wecomTurn struct { @@ -187,22 +188,74 @@ func (c *WeComChannel) SendMedia(ctx context.Context, msg bus.OutboundMediaMessa if !c.IsRunning() { return channels.ErrNotRunning } - var parts []string + + route, chatType, hasTurn := c.resolveMediaRoute(msg.ChatID) + chatID := route.ChatID + if chatID == "" { + chatID = msg.ChatID + } + for _, part := range msg.Parts { - 
switch { - case part.Caption != "": - parts = append(parts, part.Caption) - case part.Filename != "": - parts = append(parts, fmt.Sprintf("[media: %s]", part.Filename)) - default: - parts = append(parts, "[media attachments are not yet supported]") + if strings.TrimSpace(part.Ref) == "" { + if caption := strings.TrimSpace(part.Caption); caption != "" { + if err := c.sendActivePush(chatID, chatType, caption); err != nil { + return err + } + } + continue + } + + localPath, filename, contentType, cleanup, err := c.resolveOutboundPart(ctx, part) + if err != nil { + return fmt.Errorf("wecom resolve media %q: %v: %w", part.Ref, err, channels.ErrSendFailed) + } + + func() { + if cleanup != nil { + defer cleanup() + } + + uploaded, uploadErr := c.uploadOutboundMedia(ctx, localPath, filename, contentType, part) + if uploadErr != nil { + logger.WarnCF("wecom", "Falling back to placeholder after media upload failure", map[string]any{ + "chat_id": chatID, + "ref": part.Ref, + "filename": filename, + "content_type": contentType, + "error": uploadErr.Error(), + }) + if hasTurn { + if finishErr := c.sendStreamChunk(route, true, ""); finishErr != nil { + err = finishErr + return + } + c.deleteTurn(msg.ChatID) + hasTurn = false + } + err = c.sendActivePush(chatID, chatType, fallbackWeComMediaText(part, "", filename)) + return + } + + if hasTurn { + err = c.sendTurnMedia(route, uploaded) + c.deleteTurn(msg.ChatID) + hasTurn = false + } else { + err = c.sendActiveMedia(chatID, chatType, uploaded) + } + if err != nil { + return + } + if caption := strings.TrimSpace(part.Caption); caption != "" { + err = c.sendActivePush(chatID, chatType, caption) + } + }() + if err != nil { + return err } } - return c.Send(ctx, bus.OutboundMessage{ - Channel: msg.Channel, - ChatID: msg.ChatID, - Content: strings.Join(parts, "\n"), - }) + + return nil } func (c *WeComChannel) connectLoop() { @@ -620,6 +673,20 @@ func (c *WeComChannel) sendStreamChunk(turn wecomTurn, finish bool, content stri }, 
wecomCommandTimeout) } +func (c *WeComChannel) sendTurnMedia(turn wecomTurn, uploaded *wecomOutboundMedia) error { + if uploaded == nil { + return fmt.Errorf("wecom outbound media is nil: %w", channels.ErrSendFailed) + } + if err := c.sendCommand(wecomCommand{ + Cmd: wecomCmdRespondMsg, + Headers: wecomHeaders{ReqID: turn.ReqID}, + Body: uploaded.respondBody(), + }, wecomCommandTimeout); err != nil { + return err + } + return c.sendStreamChunk(turn, true, "") +} + func (c *WeComChannel) sendActivePush(chatID string, chatType uint32, content string) error { if strings.TrimSpace(chatID) == "" { return fmt.Errorf("empty chat ID: %w", channels.ErrSendFailed) @@ -641,24 +708,57 @@ func (c *WeComChannel) sendActivePush(chatID string, chatType uint32, content st return nil } +func (c *WeComChannel) sendActiveMedia(chatID string, chatType uint32, uploaded *wecomOutboundMedia) error { + if strings.TrimSpace(chatID) == "" { + return fmt.Errorf("empty chat ID: %w", channels.ErrSendFailed) + } + if uploaded == nil { + return fmt.Errorf("wecom outbound media is nil: %w", channels.ErrSendFailed) + } + return c.sendCommand(wecomCommand{ + Cmd: wecomCmdSendMsg, + Headers: wecomHeaders{ReqID: randomID(10)}, + Body: uploaded.sendBody(chatID, chatType), + }, wecomCommandTimeout) +} + func (c *WeComChannel) sendCommand(cmd wecomCommand, timeout time.Duration) error { + _, err := c.sendCommandAck(cmd, timeout) + return err +} + +func (c *WeComChannel) sendCommandAck(cmd wecomCommand, timeout time.Duration) (wecomEnvelope, error) { if c.commandSend != nil { return c.commandSend(cmd, timeout) } - return c.writeCurrent(cmd, timeout) + return c.writeCurrentAck(cmd, timeout) } func (c *WeComChannel) writeCurrent(cmd wecomCommand, timeout time.Duration) error { + _, err := c.writeCurrentAck(cmd, timeout) + return err +} + +func (c *WeComChannel) writeCurrentAck(cmd wecomCommand, timeout time.Duration) (wecomEnvelope, error) { c.connMu.Lock() conn := c.conn c.connMu.Unlock() if conn == nil { 
- return fmt.Errorf("wecom websocket not connected: %w", channels.ErrTemporary) + return wecomEnvelope{}, fmt.Errorf("wecom websocket not connected: %w", channels.ErrTemporary) } - return c.writeAndWait(conn, cmd, timeout) + return c.writeAndWaitAck(conn, cmd, timeout) } func (c *WeComChannel) writeAndWait(conn *websocket.Conn, cmd wecomCommand, timeout time.Duration) error { + _, err := c.writeAndWaitAck(conn, cmd, timeout) + return err +} + +func (c *WeComChannel) writeAndWaitAck( + conn *websocket.Conn, + cmd wecomCommand, + timeout time.Duration, +) (wecomEnvelope, error) { if cmd.Headers.ReqID == "" { cmd.Headers.ReqID = randomID(10) } @@ -674,13 +774,13 @@ func (c *WeComChannel) writeAndWait(conn *websocket.Conn, cmd wecomCommand, time data, err := json.Marshal(cmd) if err != nil { - return fmt.Errorf("%w: %v", channels.ErrSendFailed, err) + return wecomEnvelope{}, fmt.Errorf("%w: %v", channels.ErrSendFailed, err) } c.connMu.Lock() err = conn.WriteMessage(websocket.TextMessage, data) c.connMu.Unlock() if err != nil { - return fmt.Errorf("%w: %v", channels.ErrTemporary, err) + return wecomEnvelope{}, fmt.Errorf("%w: %v", channels.ErrTemporary, err) } timer := time.NewTimer(timeout) @@ -688,13 +788,13 @@ func (c *WeComChannel) writeAndWait(conn *websocket.Conn, cmd wecomCommand, time select { case env := <-waitCh: if env.ErrCode != 0 { - return fmt.Errorf("%w: wecom errcode=%d errmsg=%s", channels.ErrTemporary, env.ErrCode, env.ErrMsg) + return wecomEnvelope{}, fmt.Errorf("%w: wecom errcode=%d errmsg=%s", channels.ErrTemporary, env.ErrCode, env.ErrMsg) } - return nil + return env, nil case <-timer.C: - return fmt.Errorf("%w: timeout waiting for WeCom ack", channels.ErrTemporary) + return wecomEnvelope{}, fmt.Errorf("%w: timeout waiting for WeCom ack", channels.ErrTemporary) case <-c.ctx.Done(): - return c.ctx.Err() + return wecomEnvelope{}, c.ctx.Err() } } diff --git a/pkg/channels/wecom/wecom_test.go b/pkg/channels/wecom/wecom_test.go index 
e0ee2e628..45176015f 100644 --- a/pkg/channels/wecom/wecom_test.go +++ b/pkg/channels/wecom/wecom_test.go @@ -2,13 +2,16 @@ package wecom import ( "context" + "encoding/json" "errors" + "os" "path/filepath" "testing" "time" "github.com/sipeed/picoclaw/pkg/bus" "github.com/sipeed/picoclaw/pkg/config" + "github.com/sipeed/picoclaw/pkg/media" ) func TestDispatchIncoming_UsesActualChatIDAndStoresReqIDRoute(t *testing.T) { @@ -18,9 +21,9 @@ func TestDispatchIncoming_UsesActualChatIDAndStoresReqIDRoute(t *testing.T) { ch := newTestWeComChannel(t, messageBus) var commands []wecomCommand - ch.commandSend = func(cmd wecomCommand, _ time.Duration) error { + ch.commandSend = func(cmd wecomCommand, _ time.Duration) (wecomEnvelope, error) { commands = append(commands, cmd) - return nil + return wecomTestAck(nil), nil } msg := wecomIncomingMessage{ @@ -107,12 +110,12 @@ func TestSend_StreamFailureFallsBackToActualChatID(t *testing.T) { } var commands []wecomCommand - ch.commandSend = func(cmd wecomCommand, _ time.Duration) error { + ch.commandSend = func(cmd wecomCommand, _ time.Duration) (wecomEnvelope, error) { commands = append(commands, cmd) if len(commands) == 1 && cmd.Cmd == wecomCmdRespondMsg { - return errors.New("stream send failed") + return wecomEnvelope{}, errors.New("stream send failed") } - return nil + return wecomTestAck(nil), nil } if err := ch.Send(context.Background(), bus.OutboundMessage{ @@ -152,6 +155,301 @@ func TestSend_StreamFailureFallsBackToActualChatID(t *testing.T) { } } +func TestSendMedia_SendsActiveImage(t *testing.T) { + t.Parallel() + + ch := newTestWeComChannel(t, bus.NewMessageBus()) + ch.SetRunning(true) + + store := media.NewFileMediaStore() + ch.SetMediaStore(store) + + imageData := wecomTestJPEGData(t) + imagePath := filepath.Join(t.TempDir(), "photo.jpg") + if err := os.WriteFile(imagePath, imageData, 0o600); err != nil { + t.Fatalf("WriteFile() error = %v", err) + } + ref, err := store.Store(imagePath, media.MediaMeta{ + Filename: 
"photo.jpg", + ContentType: "image/jpeg", + Source: "test", + CleanupPolicy: media.CleanupPolicyForgetOnly, + }, "scope-1") + if err != nil { + t.Fatalf("Store() error = %v", err) + } + + var commands []wecomCommand + ch.commandSend = func(cmd wecomCommand, _ time.Duration) (wecomEnvelope, error) { + commands = append(commands, cmd) + switch cmd.Cmd { + case wecomCmdUploadMediaInit: + return wecomTestAck(wecomUploadMediaInitResponse{UploadID: "upload-1"}), nil + case wecomCmdUploadMediaEnd: + return wecomTestAck(wecomUploadMediaFinishResponse{ + Type: "image", + MediaID: "media-1", + }), nil + default: + return wecomTestAck(nil), nil + } + } + + err = ch.SendMedia(context.Background(), bus.OutboundMediaMessage{ + Channel: "wecom", + ChatID: "chat-1", + Parts: []bus.MediaPart{{ + Ref: ref, + Type: "image", + Filename: "photo.jpg", + ContentType: "image/jpeg", + }}, + }) + if err != nil { + t.Fatalf("SendMedia() error = %v", err) + } + + if len(commands) != 4 { + t.Fatalf("expected 4 commands, got %d", len(commands)) + } + if commands[0].Cmd != wecomCmdUploadMediaInit { + t.Fatalf("first command = %q, want %q", commands[0].Cmd, wecomCmdUploadMediaInit) + } + initBody, ok := commands[0].Body.(wecomUploadMediaInitBody) + if !ok { + t.Fatalf("unexpected init body type %T", commands[0].Body) + } + if initBody.Type != "image" || initBody.Filename != "photo.jpg" || initBody.TotalChunks != 1 { + t.Fatalf("init body = %+v", initBody) + } + if commands[1].Cmd != wecomCmdUploadMediaChunk { + t.Fatalf("second command = %q, want %q", commands[1].Cmd, wecomCmdUploadMediaChunk) + } + chunkBody, ok := commands[1].Body.(wecomUploadMediaChunkBody) + if !ok { + t.Fatalf("unexpected chunk body type %T", commands[1].Body) + } + if chunkBody.UploadID != "upload-1" || chunkBody.ChunkIndex != 0 || chunkBody.Base64Data == "" { + t.Fatalf("chunk body = %+v", chunkBody) + } + if commands[2].Cmd != wecomCmdUploadMediaEnd { + t.Fatalf("third command = %q, want %q", commands[2].Cmd, 
wecomCmdUploadMediaEnd) + } + if commands[3].Cmd != wecomCmdSendMsg { + t.Fatalf("fourth command = %q, want %q", commands[3].Cmd, wecomCmdSendMsg) + } + + body, ok := commands[3].Body.(wecomSendMsgBody) + if !ok { + t.Fatalf("unexpected send body type %T", commands[3].Body) + } + if body.MsgType != "image" || body.Image == nil { + t.Fatalf("send body = %+v", body) + } + if body.ChatID != "chat-1" { + t.Fatalf("send chatid = %q, want chat-1", body.ChatID) + } + if body.Image.MediaID != "media-1" { + t.Fatalf("image media_id = %q, want media-1", body.Image.MediaID) + } +} + +func TestSendMedia_UsesTurnImageAndFinishesStream(t *testing.T) { + t.Parallel() + + ch := newTestWeComChannel(t, bus.NewMessageBus()) + ch.SetRunning(true) + + store := media.NewFileMediaStore() + ch.SetMediaStore(store) + + imageData := wecomTestJPEGData(t) + imagePath := filepath.Join(t.TempDir(), "reply.jpg") + if err := os.WriteFile(imagePath, imageData, 0o600); err != nil { + t.Fatalf("WriteFile() error = %v", err) + } + ref, err := store.Store(imagePath, media.MediaMeta{ + Filename: "reply.jpg", + ContentType: "image/jpeg", + Source: "test", + CleanupPolicy: media.CleanupPolicyForgetOnly, + }, "scope-2") + if err != nil { + t.Fatalf("Store() error = %v", err) + } + + ch.queueTurn("chat-1", wecomTurn{ + ReqID: "req-1", + ChatID: "chat-1", + ChatType: 1, + StreamID: "stream-1", + CreatedAt: time.Now(), + }) + if err := ch.routes.Put("chat-1", "req-1", 1, time.Hour); err != nil { + t.Fatalf("Put() error = %v", err) + } + + var commands []wecomCommand + ch.commandSend = func(cmd wecomCommand, _ time.Duration) (wecomEnvelope, error) { + commands = append(commands, cmd) + switch cmd.Cmd { + case wecomCmdUploadMediaInit: + return wecomTestAck(wecomUploadMediaInitResponse{UploadID: "upload-2"}), nil + case wecomCmdUploadMediaEnd: + return wecomTestAck(wecomUploadMediaFinishResponse{ + Type: "image", + MediaID: "media-2", + }), nil + default: + return wecomTestAck(nil), nil + } + } + + err = 
ch.SendMedia(context.Background(), bus.OutboundMediaMessage{ + Channel: "wecom", + ChatID: "chat-1", + Parts: []bus.MediaPart{{ + Ref: ref, + Type: "image", + Filename: "reply.jpg", + ContentType: "image/jpeg", + }}, + }) + if err != nil { + t.Fatalf("SendMedia() error = %v", err) + } + + if len(commands) != 5 { + t.Fatalf("expected 5 commands, got %d", len(commands)) + } + if commands[0].Cmd != wecomCmdUploadMediaInit { + t.Fatalf("first command = %+v", commands[0]) + } + if commands[1].Cmd != wecomCmdUploadMediaChunk { + t.Fatalf("second command = %+v", commands[1]) + } + if commands[2].Cmd != wecomCmdUploadMediaEnd { + t.Fatalf("third command = %+v", commands[2]) + } + if commands[3].Cmd != wecomCmdRespondMsg || commands[3].Headers.ReqID != "req-1" { + t.Fatalf("fourth command = %+v", commands[3]) + } + if commands[4].Cmd != wecomCmdRespondMsg || commands[4].Headers.ReqID != "req-1" { + t.Fatalf("fifth command = %+v", commands[4]) + } + + imageBody, ok := commands[3].Body.(wecomRespondMsgBody) + if !ok { + t.Fatalf("unexpected image body type %T", commands[3].Body) + } + if imageBody.MsgType != "image" || imageBody.Image == nil { + t.Fatalf("image body = %+v", imageBody) + } + if imageBody.Image.MediaID != "media-2" { + t.Fatalf("image media_id = %q, want media-2", imageBody.Image.MediaID) + } + + streamBody, ok := commands[4].Body.(wecomRespondMsgBody) + if !ok { + t.Fatalf("unexpected finish body type %T", commands[4].Body) + } + if streamBody.MsgType != "stream" || streamBody.Stream == nil || !streamBody.Stream.Finish { + t.Fatalf("finish body = %+v", streamBody) + } + + if _, ok := ch.getTurn("chat-1"); ok { + t.Fatal("expected turn to be removed after media send") + } +} + +func TestSendMedia_SendsActiveFile(t *testing.T) { + t.Parallel() + + ch := newTestWeComChannel(t, bus.NewMessageBus()) + ch.SetRunning(true) + + store := media.NewFileMediaStore() + ch.SetMediaStore(store) + + filePath := filepath.Join(t.TempDir(), "report.pdf") + if err := 
os.WriteFile(filePath, []byte("%PDF-1.4"), 0o600); err != nil { + t.Fatalf("WriteFile() error = %v", err) + } + ref, err := store.Store(filePath, media.MediaMeta{ + Filename: "report.pdf", + ContentType: "application/pdf", + Source: "test", + CleanupPolicy: media.CleanupPolicyForgetOnly, + }, "scope-3") + if err != nil { + t.Fatalf("Store() error = %v", err) + } + + var commands []wecomCommand + ch.commandSend = func(cmd wecomCommand, _ time.Duration) (wecomEnvelope, error) { + commands = append(commands, cmd) + switch cmd.Cmd { + case wecomCmdUploadMediaInit: + return wecomTestAck(wecomUploadMediaInitResponse{UploadID: "upload-3"}), nil + case wecomCmdUploadMediaEnd: + return wecomTestAck(wecomUploadMediaFinishResponse{ + Type: "file", + MediaID: "media-3", + }), nil + default: + return wecomTestAck(nil), nil + } + } + + err = ch.SendMedia(context.Background(), bus.OutboundMediaMessage{ + Channel: "wecom", + ChatID: "chat-2", + Parts: []bus.MediaPart{{ + Ref: ref, + Type: "file", + Filename: "report.pdf", + ContentType: "application/pdf", + }}, + }) + if err != nil { + t.Fatalf("SendMedia() error = %v", err) + } + + if len(commands) != 4 { + t.Fatalf("expected 4 commands, got %d", len(commands)) + } + if commands[0].Cmd != wecomCmdUploadMediaInit { + t.Fatalf("first command = %q, want %q", commands[0].Cmd, wecomCmdUploadMediaInit) + } + initBody, ok := commands[0].Body.(wecomUploadMediaInitBody) + if !ok { + t.Fatalf("unexpected init body type %T", commands[0].Body) + } + if initBody.Type != "file" || initBody.Filename != "report.pdf" { + t.Fatalf("init body = %+v", initBody) + } + if commands[1].Cmd != wecomCmdUploadMediaChunk { + t.Fatalf("second command = %q, want %q", commands[1].Cmd, wecomCmdUploadMediaChunk) + } + if commands[2].Cmd != wecomCmdUploadMediaEnd { + t.Fatalf("third command = %q, want %q", commands[2].Cmd, wecomCmdUploadMediaEnd) + } + if commands[3].Cmd != wecomCmdSendMsg { + t.Fatalf("fourth command = %q, want %q", commands[3].Cmd, 
wecomCmdSendMsg) + } + + body, ok := commands[3].Body.(wecomSendMsgBody) + if !ok { + t.Fatalf("unexpected body type %T", commands[3].Body) + } + if body.MsgType != "file" || body.File == nil { + t.Fatalf("body = %+v", body) + } + if body.File.MediaID != "media-3" { + t.Fatalf("file media_id = %q, want media-3", body.File.MediaID) + } +} + func newTestWeComChannel(t *testing.T, messageBus *bus.MessageBus) *WeComChannel { t.Helper() @@ -165,3 +463,46 @@ func newTestWeComChannel(t *testing.T, messageBus *bus.MessageBus) *WeComChannel ch.routes = newReqIDStore(filepath.Join(t.TempDir(), "reqids.json")) return ch } + +func wecomTestJPEGData(t *testing.T) []byte { + t.Helper() + + const jpegBase64 = "/9j/4AAQSkZJRgABAQAAAQABAAD/2wBDAP//////////////////////////////////////////////////////////////////////////////////////" + + "//////////////////////////////////////////////////////////////////////////////////////////////2wBDAf//////////////////////////////////////////////////////////////////////////////////////" + + "//////////////////////////////////////////////////////////////////////////////////////////////wAARCAABAAEDASIAAhEBAxEB/8QAFQABAQAAAAAAAAAAAAAAAAAAAAb/xAAVEQEBAAAAAAAAAAAAAAAAAAAABf/aAAwDAQACEAMQAAAB6A//xAAVEAEBAAAAAAAAAAAAAAAAAAAAEf/aAAgBAQABBQJf/8QAFBEBAAAAAAAAAAAAAAAAAAAAEP/aAAgBAwEBPwF//8QAFBEBAAAAAAAAAAAAAAAAAAAAEP/aAAgBAgEBPwF//8QAFBABAAAAAAAAAAAAAAAAAAAAEP/aAAgBAQAGPwJf/8QAFBABAAAAAAAAAAAAAAAAAAAAEP/aAAgBAQABPyFf/9k=" + + return decodeTestBase64(t, jpegBase64) +} + +func TestDecodeWeComUploadFinish_AcceptsNumericCreatedAt(t *testing.T) { + t.Parallel() + + resp, err := decodeWeComEnvelopeBody[wecomUploadMediaFinishResponse](wecomEnvelope{ + Body: json.RawMessage(`{"type":"file","media_id":"media-1","created_at":1380000000}`), + }) + if err != nil { + t.Fatalf("decodeWeComEnvelopeBody() error = %v", err) + } + if resp.Type != "file" || resp.MediaID != "media-1" { + t.Fatalf("response = %+v", resp) + } + if string(resp.CreatedAt) != "1380000000" { + 
t.Fatalf("created_at = %s, want 1380000000", string(resp.CreatedAt)) + } +} + +func wecomTestAck(body any) wecomEnvelope { + var raw []byte + if body != nil { + encoded, err := json.Marshal(body) + if err != nil { + panic(err) + } + raw = encoded + } + return wecomEnvelope{ + ErrCode: 0, + ErrMsg: "ok", + Body: raw, + } +} From 11b6b10d5947478e70c80acfc6c0841f8142bcca Mon Sep 17 00:00:00 2001 From: Hoshina Date: Tue, 24 Mar 2026 16:34:21 +0800 Subject: [PATCH 26/39] fix(linter): fix ci lint err --- pkg/channels/wecom/media.go | 12 ++++++++---- pkg/channels/wecom/wecom.go | 12 ++++++------ pkg/channels/wecom/wecom_test.go | 5 +++-- 3 files changed, 17 insertions(+), 12 deletions(-) diff --git a/pkg/channels/wecom/media.go b/pkg/channels/wecom/media.go index ebcc481e8..974a3bf4d 100644 --- a/pkg/channels/wecom/media.go +++ b/pkg/channels/wecom/media.go @@ -216,7 +216,10 @@ func detectWeComFiletype(data []byte) (string, string) { return normalizeWeComContentType(kind.MIME.Value), ext } -func detectWeComMediaMetadata(data []byte, fallbackName, fallbackContentType, resourceURL, contentDisposition string) (string, string) { +func detectWeComMediaMetadata( + data []byte, + fallbackName, fallbackContentType, resourceURL, contentDisposition string, +) (string, string) { filename := candidateWeComFilename(resourceURL, contentDisposition, fallbackName) if filename == "" { filename = "media" @@ -717,7 +720,7 @@ func (c *WeComChannel) uploadOutboundMedia( if end > len(data) { end = len(data) } - if err := c.sendCommand(wecomCommand{ + sendErr := c.sendCommand(wecomCommand{ Cmd: wecomCmdUploadMediaChunk, Headers: wecomHeaders{ReqID: randomID(10)}, Body: wecomUploadMediaChunkBody{ @@ -725,8 +728,9 @@ func (c *WeComChannel) uploadOutboundMedia( ChunkIndex: idx, Base64Data: base64.StdEncoding.EncodeToString(data[offset:end]), }, - }, wecomUploadTimeout); err != nil { - return nil, err + }, wecomUploadTimeout) + if sendErr != nil { + return nil, sendErr } } diff --git 
a/pkg/channels/wecom/wecom.go b/pkg/channels/wecom/wecom.go index ac8f8d9c8..075c1732f 100644 --- a/pkg/channels/wecom/wecom.go +++ b/pkg/channels/wecom/wecom.go @@ -734,11 +734,6 @@ func (c *WeComChannel) sendCommandAck(cmd wecomCommand, timeout time.Duration) ( return c.writeCurrentAck(cmd, timeout) } -func (c *WeComChannel) writeCurrent(cmd wecomCommand, timeout time.Duration) error { - _, err := c.writeCurrentAck(cmd, timeout) - return err -} - func (c *WeComChannel) writeCurrentAck(cmd wecomCommand, timeout time.Duration) (wecomEnvelope, error) { c.connMu.Lock() conn := c.conn @@ -788,7 +783,12 @@ func (c *WeComChannel) writeAndWaitAck( select { case env := <-waitCh: if env.ErrCode != 0 { - return wecomEnvelope{}, fmt.Errorf("%w: wecom errcode=%d errmsg=%s", channels.ErrTemporary, env.ErrCode, env.ErrMsg) + return wecomEnvelope{}, fmt.Errorf( + "%w: wecom errcode=%d errmsg=%s", + channels.ErrTemporary, + env.ErrCode, + env.ErrMsg, + ) } return env, nil case <-timer.C: diff --git a/pkg/channels/wecom/wecom_test.go b/pkg/channels/wecom/wecom_test.go index 45176015f..478423307 100644 --- a/pkg/channels/wecom/wecom_test.go +++ b/pkg/channels/wecom/wecom_test.go @@ -285,8 +285,9 @@ func TestSendMedia_UsesTurnImageAndFinishesStream(t *testing.T) { StreamID: "stream-1", CreatedAt: time.Now(), }) - if err := ch.routes.Put("chat-1", "req-1", 1, time.Hour); err != nil { - t.Fatalf("Put() error = %v", err) + putErr := ch.routes.Put("chat-1", "req-1", 1, time.Hour) + if putErr != nil { + t.Fatalf("Put() error = %v", putErr) } var commands []wecomCommand From 3b498d2e4b2b991c4c33cbc55e04d01cec307bbb Mon Sep 17 00:00:00 2001 From: Hoshina Date: Tue, 24 Mar 2026 20:17:16 +0800 Subject: [PATCH 27/39] feat(wecom): add channel-side streaming support --- pkg/channels/wecom/protocol.go | 1 - pkg/channels/wecom/wecom.go | 193 +++++++++++++++++++++++-------- pkg/channels/wecom/wecom_test.go | 151 ++++++++++++++++++++++++ 3 files changed, 294 insertions(+), 51 deletions(-) diff 
--git a/pkg/channels/wecom/protocol.go b/pkg/channels/wecom/protocol.go index 0190e70e5..f42ce3bf4 100644 --- a/pkg/channels/wecom/protocol.go +++ b/pkg/channels/wecom/protocol.go @@ -13,7 +13,6 @@ const ( wecomCmdUploadMediaInit = "aibot_upload_media_init" wecomCmdUploadMediaChunk = "aibot_upload_media_chunk" wecomCmdUploadMediaEnd = "aibot_upload_media_finish" - wecomMaxContentBytes = 20480 ) type wecomEnvelope struct { diff --git a/pkg/channels/wecom/wecom.go b/pkg/channels/wecom/wecom.go index 075c1732f..26e971921 100644 --- a/pkg/channels/wecom/wecom.go +++ b/pkg/channels/wecom/wecom.go @@ -26,6 +26,7 @@ const ( wecomUploadTimeout = 30 * time.Second wecomHeartbeatInterval = 30 * time.Second wecomStreamMaxDuration = 5*time.Minute + 30*time.Second + wecomStreamMinInterval = 500 * time.Millisecond wecomRouteTTL = 30 * time.Minute wecomMediaTimeout = 30 * time.Second wecomRecentMessageMax = 1000 @@ -61,6 +62,17 @@ type wecomTurn struct { CreatedAt time.Time } +type wecomStreamer struct { + channel *WeComChannel + chatID string + turn wecomTurn + + mu sync.Mutex + closed bool + lastSentAt time.Time + content string +} + type recentMessageSet struct { mu sync.Mutex seen map[string]struct{} @@ -109,7 +121,6 @@ func NewChannel(cfg config.WeComConfig, messageBus *bus.MessageBus) (*WeComChann cfg, messageBus, cfg.AllowFrom, - channels.WithMaxMessageLength(wecomMaxContentBytes), channels.WithReasoningChannelID(cfg.ReasoningChannelID), ) @@ -152,6 +163,27 @@ func (c *WeComChannel) Stop(_ context.Context) error { return nil } +func (c *WeComChannel) BeginStream(_ context.Context, chatID string) (channels.Streamer, error) { + if !c.IsRunning() { + return nil, channels.ErrNotRunning + } + + turn, ok := c.getTurn(chatID) + if !ok { + return nil, fmt.Errorf("wecom streaming unavailable: no active turn") + } + if time.Since(turn.CreatedAt) > wecomStreamMaxDuration { + c.consumeTurn(chatID, turn) + return nil, fmt.Errorf("wecom streaming unavailable: turn expired") + } + + 
return &wecomStreamer{ + channel: c, + chatID: chatID, + turn: turn, + }, nil +} + func (c *WeComChannel) Send(ctx context.Context, msg bus.OutboundMessage) error { if !c.IsRunning() { return channels.ErrNotRunning @@ -164,11 +196,11 @@ func (c *WeComChannel) Send(ctx context.Context, msg bus.OutboundMessage) error if turn, ok := c.getTurn(msg.ChatID); ok { if time.Since(turn.CreatedAt) <= wecomStreamMaxDuration { if err := c.sendStreamReply(turn, content); err == nil { - c.deleteTurn(msg.ChatID) + c.consumeTurn(msg.ChatID, turn) return nil } } - c.deleteTurn(msg.ChatID) + c.consumeTurn(msg.ChatID, turn) } if route, ok := c.routes.Get(msg.ChatID); ok { @@ -649,13 +681,7 @@ func (c *WeComChannel) respondImmediate(reqID, content string) error { } func (c *WeComChannel) sendStreamReply(turn wecomTurn, content string) error { - chunks := splitContent(content, wecomMaxContentBytes) - for idx, chunk := range chunks { - if err := c.sendStreamChunk(turn, idx == len(chunks)-1, chunk); err != nil { - return err - } - } - return nil + return c.sendStreamChunk(turn, true, content) } func (c *WeComChannel) sendStreamChunk(turn wecomTurn, finish bool, content string) error { @@ -691,21 +717,16 @@ func (c *WeComChannel) sendActivePush(chatID string, chatType uint32, content st if strings.TrimSpace(chatID) == "" { return fmt.Errorf("empty chat ID: %w", channels.ErrSendFailed) } - for _, chunk := range splitContent(content, wecomMaxContentBytes) { - if err := c.sendCommand(wecomCommand{ - Cmd: wecomCmdSendMsg, - Headers: wecomHeaders{ReqID: randomID(10)}, - Body: wecomSendMsgBody{ - ChatID: chatID, - ChatType: chatType, - MsgType: "markdown", - Markdown: &wecomMarkdownContent{Content: chunk}, - }, - }, wecomCommandTimeout); err != nil { - return err - } - } - return nil + return c.sendCommand(wecomCommand{ + Cmd: wecomCmdSendMsg, + Headers: wecomHeaders{ReqID: randomID(10)}, + Body: wecomSendMsgBody{ + ChatID: chatID, + ChatType: chatType, + MsgType: "markdown", + Markdown: 
&wecomMarkdownContent{Content: content}, + }, + }, wecomCommandTimeout) } func (c *WeComChannel) sendActiveMedia(chatID string, chatType uint32, uploaded *wecomOutboundMedia) error { @@ -825,6 +846,26 @@ func (c *WeComChannel) queueTurn(chatID string, turn wecomTurn) { c.turns[chatID] = append(c.turns[chatID], turn) } +func (c *WeComChannel) consumeTurn(chatID string, turn wecomTurn) bool { + c.turnsMu.Lock() + defer c.turnsMu.Unlock() + + queue := c.turns[chatID] + if len(queue) == 0 { + return false + } + current := queue[0] + if current.ReqID != turn.ReqID || current.StreamID != turn.StreamID { + return false + } + if len(queue) == 1 { + delete(c.turns, chatID) + return true + } + c.turns[chatID] = queue[1:] + return true +} + func (c *WeComChannel) clearTurns() { c.turnsMu.Lock() c.turns = make(map[string][]wecomTurn) @@ -844,34 +885,86 @@ func randomID(n int) string { return string(buf) } -func splitContent(content string, maxBytes int) []string { - if content == "" { - return []string{""} +func (s *wecomStreamer) Update(ctx context.Context, content string) error { + s.mu.Lock() + defer s.mu.Unlock() + + if s.closed { + return nil } - if len(content) <= maxBytes { - return []string{content} + if err := s.validateActiveTurn(); err != nil { + return err } - chunks := channels.SplitMessage(content, maxBytes) - var result []string - for _, chunk := range chunks { - if len(chunk) <= maxBytes { - result = append(result, chunk) - continue - } - for len(chunk) > maxBytes { - end := maxBytes - for end > 0 && chunk[end]>>6 == 0b10 { - end-- + if err := ctx.Err(); err != nil { + return err + } + + if !s.lastSentAt.IsZero() { + wait := time.Until(s.lastSentAt.Add(wecomStreamMinInterval)) + if wait > 0 { + timer := time.NewTimer(wait) + defer timer.Stop() + select { + case <-ctx.Done(): + return ctx.Err() + case <-timer.C: } - if end == 0 { - end = maxBytes - } - result = append(result, chunk[:end]) - chunk = strings.TrimLeft(chunk[end:], " \t\r\n") - } - if chunk != "" { 
- result = append(result, chunk) } } - return result + + if err := s.channel.sendStreamChunk(s.turn, false, content); err != nil { + return err + } + s.content = content + s.lastSentAt = time.Now() + return nil +} + +func (s *wecomStreamer) Finalize(ctx context.Context, content string) error { + s.mu.Lock() + defer s.mu.Unlock() + + if s.closed { + return nil + } + if err := s.validateActiveTurn(); err != nil { + return err + } + if err := ctx.Err(); err != nil { + return err + } + if err := s.channel.sendStreamChunk(s.turn, true, content); err != nil { + return err + } + + s.content = content + s.closed = true + s.channel.consumeTurn(s.chatID, s.turn) + return nil +} + +func (s *wecomStreamer) Cancel(_ context.Context) { + s.mu.Lock() + defer s.mu.Unlock() + + if s.closed { + return + } + if s.validateActiveTurn() == nil { + _ = s.channel.sendStreamChunk(s.turn, true, s.content) + s.channel.consumeTurn(s.chatID, s.turn) + } + s.closed = true +} + +func (s *wecomStreamer) validateActiveTurn() error { + if time.Since(s.turn.CreatedAt) > wecomStreamMaxDuration { + s.channel.consumeTurn(s.chatID, s.turn) + return fmt.Errorf("wecom streaming unavailable: turn expired") + } + current, ok := s.channel.getTurn(s.chatID) + if !ok || current.ReqID != s.turn.ReqID || current.StreamID != s.turn.StreamID { + return fmt.Errorf("wecom streaming unavailable: turn no longer active") + } + return nil } diff --git a/pkg/channels/wecom/wecom_test.go b/pkg/channels/wecom/wecom_test.go index 478423307..c7a4adfc0 100644 --- a/pkg/channels/wecom/wecom_test.go +++ b/pkg/channels/wecom/wecom_test.go @@ -6,6 +6,7 @@ import ( "errors" "os" "path/filepath" + "strings" "testing" "time" @@ -86,6 +87,77 @@ func TestDispatchIncoming_UsesActualChatIDAndStoresReqIDRoute(t *testing.T) { } } +func TestNewChannel_DoesNotRegisterMessageSplitLimit(t *testing.T) { + t.Parallel() + + ch := newTestWeComChannel(t, bus.NewMessageBus()) + if got := ch.MaxMessageLength(); got != 0 { + 
t.Fatalf("MaxMessageLength() = %d, want 0", got) + } +} + +func TestBeginStream_UpdateAndFinalize(t *testing.T) { + t.Parallel() + + ch := newTestWeComChannel(t, bus.NewMessageBus()) + ch.SetRunning(true) + ch.queueTurn("chat-1", wecomTurn{ + ReqID: "req-1", + ChatID: "chat-1", + ChatType: 1, + StreamID: "stream-1", + CreatedAt: time.Now(), + }) + + var commands []wecomCommand + ch.commandSend = func(cmd wecomCommand, _ time.Duration) (wecomEnvelope, error) { + commands = append(commands, cmd) + return wecomTestAck(nil), nil + } + + streamer, err := ch.BeginStream(context.Background(), "chat-1") + if err != nil { + t.Fatalf("BeginStream() error = %v", err) + } + if err := streamer.Update(context.Background(), "draft"); err != nil { + t.Fatalf("Update() error = %v", err) + } + if err := streamer.Finalize(context.Background(), "final"); err != nil { + t.Fatalf("Finalize() error = %v", err) + } + + if len(commands) != 2 { + t.Fatalf("expected 2 commands, got %d", len(commands)) + } + for i, wantFinish := range []bool{false, true} { + if commands[i].Cmd != wecomCmdRespondMsg { + t.Fatalf("command[%d].Cmd = %q, want %q", i, commands[i].Cmd, wecomCmdRespondMsg) + } + body, ok := commands[i].Body.(wecomRespondMsgBody) + if !ok { + t.Fatalf("command[%d] body type = %T", i, commands[i].Body) + } + if body.Stream == nil { + t.Fatalf("command[%d] missing stream body", i) + } + if body.Stream.ID != "stream-1" { + t.Fatalf("command[%d] stream id = %q, want stream-1", i, body.Stream.ID) + } + if body.Stream.Finish != wantFinish { + t.Fatalf("command[%d] finish = %v, want %v", i, body.Stream.Finish, wantFinish) + } + } + if body := commands[0].Body.(wecomRespondMsgBody); body.Stream.Content != "draft" { + t.Fatalf("update content = %q, want draft", body.Stream.Content) + } + if body := commands[1].Body.(wecomRespondMsgBody); body.Stream.Content != "final" { + t.Fatalf("final content = %q, want final", body.Stream.Content) + } + if _, ok := ch.getTurn("chat-1"); ok { + 
t.Fatal("expected turn to be consumed after Finalize") + } +} + func TestSend_StreamFailureFallsBackToActualChatID(t *testing.T) { t.Parallel() @@ -155,6 +227,85 @@ func TestSend_StreamFailureFallsBackToActualChatID(t *testing.T) { } } +func TestSend_DoesNotSplitStreamReply(t *testing.T) { + t.Parallel() + + ch := newTestWeComChannel(t, bus.NewMessageBus()) + ch.SetRunning(true) + ch.queueTurn("chat-1", wecomTurn{ + ReqID: "req-1", + ChatID: "chat-1", + ChatType: 1, + StreamID: "stream-1", + CreatedAt: time.Now(), + }) + + var commands []wecomCommand + ch.commandSend = func(cmd wecomCommand, _ time.Duration) (wecomEnvelope, error) { + commands = append(commands, cmd) + return wecomTestAck(nil), nil + } + + content := strings.Repeat("\u4e2d", 30000) + if err := ch.Send(context.Background(), bus.OutboundMessage{ + Channel: "wecom", + ChatID: "chat-1", + Content: content, + }); err != nil { + t.Fatalf("Send() error = %v", err) + } + + if len(commands) != 1 { + t.Fatalf("expected 1 stream command, got %d", len(commands)) + } + body, ok := commands[0].Body.(wecomRespondMsgBody) + if !ok { + t.Fatalf("unexpected body type %T", commands[0].Body) + } + if body.Stream == nil || !body.Stream.Finish { + t.Fatalf("stream body = %+v", body.Stream) + } + if body.Stream.Content != content { + t.Fatalf("stream content length = %d, want %d", len(body.Stream.Content), len(content)) + } +} + +func TestSend_DoesNotSplitActivePush(t *testing.T) { + t.Parallel() + + ch := newTestWeComChannel(t, bus.NewMessageBus()) + ch.SetRunning(true) + + var commands []wecomCommand + ch.commandSend = func(cmd wecomCommand, _ time.Duration) (wecomEnvelope, error) { + commands = append(commands, cmd) + return wecomTestAck(nil), nil + } + + content := strings.Repeat("a", 30000) + if err := ch.Send(context.Background(), bus.OutboundMessage{ + Channel: "wecom", + ChatID: "chat-1", + Content: content, + }); err != nil { + t.Fatalf("Send() error = %v", err) + } + + if len(commands) != 1 { + 
t.Fatalf("expected 1 send command, got %d", len(commands)) + } + if commands[0].Cmd != wecomCmdSendMsg { + t.Fatalf("command = %q, want %q", commands[0].Cmd, wecomCmdSendMsg) + } + body, ok := commands[0].Body.(wecomSendMsgBody) + if !ok { + t.Fatalf("unexpected body type %T", commands[0].Body) + } + if body.Markdown == nil || body.Markdown.Content != content { + t.Fatalf("markdown content length = %d, want %d", len(body.Markdown.Content), len(content)) + } +} + func TestSendMedia_SendsActiveImage(t *testing.T) { t.Parallel() From cd48c3bde89ffcc66a1898657b9cd60dc805b274 Mon Sep 17 00:00:00 2001 From: Hoshina Date: Tue, 24 Mar 2026 20:27:31 +0800 Subject: [PATCH 28/39] fix(config): remove stale wecom security merge fields --- pkg/config/security.go | 10 +--------- 1 file changed, 1 insertion(+), 9 deletions(-) diff --git a/pkg/config/security.go b/pkg/config/security.go index 72f0c013f..47ad1a5b0 100644 --- a/pkg/config/security.go +++ b/pkg/config/security.go @@ -319,17 +319,9 @@ func mergeChannelsSecurity(dst, src *ChannelsSecurity) { if src.OneBot != nil && src.OneBot.AccessToken != "" { dst.OneBot = src.OneBot } - if src.WeCom != nil && (src.WeCom.Token != "" || src.WeCom.EncodingAESKey != "") { + if src.WeCom != nil && src.WeCom.Secret != "" { dst.WeCom = src.WeCom } - if src.WeComApp != nil && - (src.WeComApp.CorpSecret != "" || src.WeComApp.Token != "" || src.WeComApp.EncodingAESKey != "") { - dst.WeComApp = src.WeComApp - } - if src.WeComAIBot != nil && - (src.WeComAIBot.Secret != "" || src.WeComAIBot.Token != "" || src.WeComAIBot.EncodingAESKey != "") { - dst.WeComAIBot = src.WeComAIBot - } if src.Pico != nil && src.Pico.Token != "" { dst.Pico = src.Pico } From 4d7a629b7996145ff16a662832261c3e8b7954ed Mon Sep 17 00:00:00 2001 From: wenjie Date: Tue, 24 Mar 2026 20:33:32 +0800 Subject: [PATCH 29/39] feat(web): improve Weixin channel binding flow (#1968) - persist Weixin bindings, enable the channel automatically, and try to restart the gateway - refresh 
frontend channel and gateway state after successful binding - harden QR polling state handling and update related channel UI behavior - localize sidebar channel priority, add Weixin icon support, and add backend test coverage --- web/backend/api/weixin.go | 25 +++- web/backend/api/weixin_test.go | 56 ++++++++ web/frontend/src/api/channels.ts | 8 +- web/frontend/src/components/app-sidebar.tsx | 7 +- .../channels/channel-config-page.tsx | 104 ++++++++------ .../channels/channel-forms/weixin-form.tsx | 133 ++++++++++++++---- .../src/components/chat/user-message.tsx | 2 +- .../src/components/config/form-model.ts | 5 +- .../src/hooks/use-sidebar-channels.ts | 31 ++-- web/frontend/src/i18n/locales/en.json | 7 +- web/frontend/src/i18n/locales/zh.json | 7 +- 11 files changed, 290 insertions(+), 95 deletions(-) create mode 100644 web/backend/api/weixin_test.go diff --git a/web/backend/api/weixin.go b/web/backend/api/weixin.go index e7e94f39e..808b88c41 100644 --- a/web/backend/api/weixin.go +++ b/web/backend/api/weixin.go @@ -171,7 +171,7 @@ func (h *Handler) handlePollWeixinFlow(w http.ResponseWriter, r *http.Request) { h.setWeixinFlowError(flowID, "login confirmed but missing bot_token") break } - if saveErr := h.saveWeixinToken(statusResp.BotToken, statusResp.IlinkBotID); saveErr != nil { + if saveErr := h.saveWeixinBinding(statusResp.BotToken, statusResp.IlinkBotID); saveErr != nil { h.setWeixinFlowError(flowID, fmt.Sprintf("failed to save token: %v", saveErr)) logger.ErrorCF("weixin", "failed to save token", map[string]any{"error": saveErr.Error()}) break @@ -203,17 +203,34 @@ func (h *Handler) handlePollWeixinFlow(w http.ResponseWriter, r *http.Request) { _ = json.NewEncoder(w).Encode(resp) } -// saveWeixinToken writes the token and account ID into the config file. 
-func (h *Handler) saveWeixinToken(token, accountID string) error { +// saveWeixinBinding writes the token/account ID, enables the Weixin channel, +// and best-effort restarts the gateway when it is currently running. +func (h *Handler) saveWeixinBinding(token, accountID string) error { cfg, err := config.LoadConfig(h.configPath) if err != nil { return fmt.Errorf("load config: %w", err) } cfg.Channels.Weixin.SetToken(token) + cfg.Channels.Weixin.Enabled = true if accountID != "" { cfg.Channels.Weixin.AccountID = accountID } - return config.SaveConfig(h.configPath, cfg) + if err := config.SaveConfig(h.configPath, cfg); err != nil { + return err + } + + status := h.gatewayStatusData() + gatewayStatus, _ := status["gateway_status"].(string) + if gatewayStatus != "running" { + return nil + } + + if _, err := h.RestartGateway(); err != nil { + logger.ErrorCF("weixin", "failed to restart gateway after saving binding", map[string]any{ + "error": err.Error(), + }) + } + return nil } // generateQRDataURI encodes content as a QR code PNG and returns a data URI. 
diff --git a/web/backend/api/weixin_test.go b/web/backend/api/weixin_test.go new file mode 100644 index 000000000..03342b72b --- /dev/null +++ b/web/backend/api/weixin_test.go @@ -0,0 +1,56 @@ +package api + +import ( + "io" + "net/http" + "os" + "path/filepath" + "strconv" + "strings" + "testing" + "time" + + "github.com/sipeed/picoclaw/pkg/config" +) + +func TestSaveWeixinBindingReturnsSuccessWhenRestartFails(t *testing.T) { + resetGatewayTestState(t) + + configPath := filepath.Join(t.TempDir(), "config.json") + cfg := config.DefaultConfig() + if err := config.SaveConfig(configPath, cfg); err != nil { + t.Fatalf("SaveConfig() error = %v", err) + } + + originalHealthGet := gatewayHealthGet + gatewayHealthGet = func(url string, timeout time.Duration) (*http.Response, error) { + return &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(strings.NewReader( + `{"status":"ok","uptime":"1s","pid":` + strconv.Itoa(os.Getpid()) + `}`, + )), + }, nil + } + t.Cleanup(func() { + gatewayHealthGet = originalHealthGet + }) + + h := NewHandler(configPath) + if err := h.saveWeixinBinding("bot-token", "bot-account"); err != nil { + t.Fatalf("saveWeixinBinding() error = %v, want nil after config save succeeds", err) + } + + savedCfg, err := config.LoadConfig(configPath) + if err != nil { + t.Fatalf("LoadConfig() error = %v", err) + } + if got := savedCfg.Channels.Weixin.Token(); got != "bot-token" { + t.Fatalf("Weixin.Token() = %q, want %q", got, "bot-token") + } + if got := savedCfg.Channels.Weixin.AccountID; got != "bot-account" { + t.Fatalf("Weixin.AccountID = %q, want %q", got, "bot-account") + } + if !savedCfg.Channels.Weixin.Enabled { + t.Fatalf("Weixin.Enabled = false, want true") + } +} diff --git a/web/frontend/src/api/channels.ts b/web/frontend/src/api/channels.ts index c3d3a65f3..d4c3ac74b 100644 --- a/web/frontend/src/api/channels.ts +++ b/web/frontend/src/api/channels.ts @@ -76,8 +76,12 @@ export async function startWeixinFlow(): Promise { return 
request("/api/weixin/flows", { method: "POST" }) } -export async function pollWeixinFlow(flowID: string): Promise { - return request(`/api/weixin/flows/${encodeURIComponent(flowID)}`) +export async function pollWeixinFlow( + flowID: string, +): Promise { + return request( + `/api/weixin/flows/${encodeURIComponent(flowID)}`, + ) } export type { ChannelsCatalogResponse, ConfigActionResponse } diff --git a/web/frontend/src/components/app-sidebar.tsx b/web/frontend/src/components/app-sidebar.tsx index 702212857..0e135c0c1 100644 --- a/web/frontend/src/components/app-sidebar.tsx +++ b/web/frontend/src/components/app-sidebar.tsx @@ -67,14 +67,17 @@ const baseNavGroups: Omit[] = [ export function AppSidebar({ ...props }: React.ComponentProps) { const routerState = useRouterState() - const { t } = useTranslation() + const { i18n, t } = useTranslation() const currentPath = routerState.location.pathname const { channelItems, hasMoreChannels, showAllChannels, toggleShowAllChannels, - } = useSidebarChannels({ t }) + } = useSidebarChannels({ + language: (i18n.resolvedLanguage ?? i18n.language ?? 
"").toLowerCase(), + t, + }) const navGroups: NavGroup[] = React.useMemo(() => { return [ diff --git a/web/frontend/src/components/channels/channel-config-page.tsx b/web/frontend/src/components/channels/channel-config-page.tsx index 4996a6314..ee483d652 100644 --- a/web/frontend/src/components/channels/channel-config-page.tsx +++ b/web/frontend/src/components/channels/channel-config-page.tsx @@ -1,8 +1,6 @@ import { IconLoader2 } from "@tabler/icons-react" -import { useAtomValue } from "jotai" import { useCallback, useEffect, useMemo, useRef, useState } from "react" import { useTranslation } from "react-i18next" -import { toast } from "sonner" import { type ChannelConfig, @@ -21,7 +19,8 @@ import { WeixinForm } from "@/components/channels/channel-forms/weixin-form" import { PageHeader } from "@/components/page-header" import { Button } from "@/components/ui/button" import { Switch } from "@/components/ui/switch" -import { gatewayAtom } from "@/store/gateway" +import { useGateway } from "@/hooks/use-gateway" +import { refreshGatewayState } from "@/store/gateway" interface ChannelConfigPageProps { channelName: string @@ -241,7 +240,7 @@ const CHANNELS_WITHOUT_DOCS = new Set([ export function ChannelConfigPage({ channelName }: ChannelConfigPageProps) { const { t, i18n } = useTranslation() - const gateway = useAtomValue(gatewayAtom) + const { state: gatewayState } = useGateway() const [loading, setLoading] = useState(true) const [saving, setSaving] = useState(false) @@ -254,56 +253,59 @@ export function ChannelConfigPage({ channelName }: ChannelConfigPageProps) { const [editConfig, setEditConfig] = useState({}) const [enabled, setEnabled] = useState(false) - const loadData = useCallback(async (silent = false) => { - if (!silent) setLoading(true) - try { - const [catalog, appConfig] = await Promise.all([ - getChannelsCatalog(), - getAppConfig(), - ]) - const matched = - catalog.channels.find((item) => item.name === channelName) ?? 
null + const loadData = useCallback( + async (silent = false) => { + if (!silent) setLoading(true) + try { + const [catalog, appConfig] = await Promise.all([ + getChannelsCatalog(), + getAppConfig(), + ]) + const matched = + catalog.channels.find((item) => item.name === channelName) ?? null - if (!matched) { - setChannel(null) - setFetchError( - t("channels.page.notFound", { - name: channelName, - }), - ) - return + if (!matched) { + setChannel(null) + setFetchError( + t("channels.page.notFound", { + name: channelName, + }), + ) + return + } + + const channelsConfig = asRecord(asRecord(appConfig).channels) + const raw = asRecord(channelsConfig[matched.config_key]) + const normalized = normalizeConfig(matched, raw) + + setChannel(matched) + setBaseConfig(normalized) + setEditConfig(buildEditConfig(normalized)) + setEnabled(asBool(normalized.enabled)) + setFetchError("") + setServerError("") + setFieldErrors({}) + } catch (e) { + setFetchError(e instanceof Error ? e.message : t("channels.loadError")) + } finally { + if (!silent) setLoading(false) } - - const channelsConfig = asRecord(asRecord(appConfig).channels) - const raw = asRecord(channelsConfig[matched.config_key]) - const normalized = normalizeConfig(matched, raw) - - setChannel(matched) - setBaseConfig(normalized) - setEditConfig(buildEditConfig(normalized)) - setEnabled(asBool(normalized.enabled)) - setFetchError("") - setServerError("") - setFieldErrors({}) - } catch (e) { - setFetchError(e instanceof Error ? 
e.message : t("channels.loadError")) - } finally { - if (!silent) setLoading(false) - } - }, [channelName, t]) + }, + [channelName, t], + ) useEffect(() => { loadData() }, [loadData]) - const previousGatewayStatusRef = useRef(gateway.status) + const previousGatewayStatusRef = useRef(gatewayState) useEffect(() => { const previousStatus = previousGatewayStatusRef.current - if (previousStatus !== "running" && gateway.status === "running") { + if (previousStatus !== "running" && gatewayState === "running") { void loadData() } - previousGatewayStatusRef.current = gateway.status - }, [gateway.status, loadData]) + previousGatewayStatusRef.current = gatewayState + }, [gatewayState, loadData]) const savePayload = useMemo(() => { if (!channel) return null @@ -396,18 +398,28 @@ export function ChannelConfigPage({ channelName }: ChannelConfigPageProps) { [channel.config_key]: savePayload, }, }) - toast.success(t("channels.page.saveSuccess")) await loadData() } catch (e) { const message = e instanceof Error ? e.message : t("channels.page.saveError") setServerError(message) - toast.error(message) } finally { setSaving(false) } } + const handleWeixinBindSuccess = useCallback(async () => { + try { + setEnabled(true) + await Promise.all([loadData(true), refreshGatewayState({ force: true })]) + } catch (e) { + const message = + e instanceof Error ? 
e.message : t("channels.page.saveError") + setServerError(message) + await loadData(true) + } + }, [loadData, t]) + const renderForm = () => { if (!channel) return null const isEdit = configured @@ -455,7 +467,7 @@ export function ChannelConfigPage({ channelName }: ChannelConfigPageProps) { config={editConfig} onChange={handleChange} isEdit={isEdit} - onBindSuccess={() => void loadData(true)} + onBindSuccess={() => void handleWeixinBindSuccess()} /> ) default: diff --git a/web/frontend/src/components/channels/channel-forms/weixin-form.tsx b/web/frontend/src/components/channels/channel-forms/weixin-form.tsx index 765136b25..20e66ffc2 100644 --- a/web/frontend/src/components/channels/channel-forms/weixin-form.tsx +++ b/web/frontend/src/components/channels/channel-forms/weixin-form.tsx @@ -1,4 +1,10 @@ -import { IconLoader2, IconRefresh, IconCheck, IconX, IconQrcode } from "@tabler/icons-react" +import { + IconCheck, + IconLoader2, + IconQrcode, + IconRefresh, + IconX, +} from "@tabler/icons-react" import { useCallback, useEffect, useRef, useState } from "react" import { useTranslation } from "react-i18next" @@ -8,7 +14,14 @@ import { Field } from "@/components/shared-form" import { Button } from "@/components/ui/button" import { Input } from "@/components/ui/input" -type BindingState = "idle" | "loading" | "waiting" | "scaned" | "confirmed" | "expired" | "error" +type BindingState = + | "idle" + | "loading" + | "waiting" + | "scaned" + | "confirmed" + | "expired" + | "error" interface WeixinFormProps { config: ChannelConfig @@ -26,7 +39,12 @@ function asStringArray(value: unknown): string[] { return value.filter((item): item is string => typeof item === "string") } -export function WeixinForm({ config, onChange, isEdit, onBindSuccess }: WeixinFormProps) { +export function WeixinForm({ + config, + onChange, + isEdit, + onBindSuccess, +}: WeixinFormProps) { const { t } = useTranslation() const [bindState, setBindState] = useState("idle") @@ -35,10 +53,12 @@ export 
function WeixinForm({ config, onChange, isEdit, onBindSuccess }: WeixinFo const [errorMsg, setErrorMsg] = useState("") const pollTimerRef = useRef | null>(null) + const pollGenerationRef = useRef(0) const isBound = isEdit && asString(config.account_id) !== "" const existingAccountID = asString(config.account_id) const stopPolling = useCallback(() => { + pollGenerationRef.current += 1 if (pollTimerRef.current !== null) { clearInterval(pollTimerRef.current) pollTimerRef.current = null @@ -47,17 +67,32 @@ export function WeixinForm({ config, onChange, isEdit, onBindSuccess }: WeixinFo useEffect(() => () => stopPolling(), [stopPolling]) + useEffect(() => { + if (!existingAccountID) return + stopPolling() + setAccountID(existingAccountID) + setBindState("confirmed") + setErrorMsg("") + }, [existingAccountID, stopPolling]) + const startPolling = useCallback( (id: string) => { stopPolling() + const generation = pollGenerationRef.current + let inFlight = false pollTimerRef.current = setInterval(async () => { + if (inFlight) return + inFlight = true try { const resp = await pollWeixinFlow(id) + if (generation !== pollGenerationRef.current) { + return + } if (resp.status === "scaned") { setBindState("scaned") } else if (resp.status === "confirmed") { stopPolling() - setAccountID(resp.account_id ?? null) + setAccountID(resp.account_id ?? existingAccountID ?? null) setBindState("confirmed") onBindSuccess?.() } else if (resp.status === "expired") { @@ -70,10 +105,12 @@ export function WeixinForm({ config, onChange, isEdit, onBindSuccess }: WeixinFo } } catch { // transient network error — keep polling + } finally { + inFlight = false } }, 2000) }, - [stopPolling, onBindSuccess, t], + [existingAccountID, stopPolling, onBindSuccess, t], ) const handleBind = async () => { @@ -88,7 +125,9 @@ export function WeixinForm({ config, onChange, isEdit, onBindSuccess }: WeixinFo startPolling(resp.flow_id) } catch (e) { setBindState("error") - setErrorMsg(e instanceof Error ? 
e.message : t("channels.weixin.errorGeneric")) + setErrorMsg( + e instanceof Error ? e.message : t("channels.weixin.errorGeneric"), + ) } } @@ -111,9 +150,16 @@ export function WeixinForm({ config, onChange, isEdit, onBindSuccess }: WeixinFo {t("channels.weixin.bound")} {existingAccountID && ( -

{existingAccountID}

+

+ {existingAccountID} +

)} - @@ -122,7 +168,9 @@ export function WeixinForm({ config, onChange, isEdit, onBindSuccess }: WeixinFo } return (
-

{t("channels.weixin.notBound")}

+

+ {t("channels.weixin.notBound")} +

@@ -174,15 +237,25 @@ export function WeixinForm({ config, onChange, isEdit, onBindSuccess }: WeixinFo return (
- +

{t("channels.weixin.bound")}

{accountID && ( -

{accountID}

+

+ {accountID} +

)} - @@ -196,7 +269,9 @@ export function WeixinForm({ config, onChange, isEdit, onBindSuccess }: WeixinFo
-

{t("channels.weixin.expired")}

+

+ {t("channels.weixin.expired")} +

+
+ {testResult && ( +
+ {testResult.allowed + ? `${t("pages.config.pattern_detector_result_allowed")}${testResult.matchedWhitelist ? ` (${testResult.matchedWhitelist})` : ""}` + : testResult.blocked + ? `${t("pages.config.pattern_detector_result_blocked")}${testResult.matchedBlacklist ? ` (${testResult.matchedBlacklist})` : ""}` + : t("pages.config.pattern_detector_result_no_match")} +
+ )} +
+ +