diff --git a/.goreleaser.yaml b/.goreleaser.yaml index 49d230506..90bdc8437 100644 --- a/.goreleaser.yaml +++ b/.goreleaser.yaml @@ -31,7 +31,7 @@ builds: - loong64 - arm goarm: - - "7" + - "7" main: ./cmd/picoclaw ignore: - goos: windows diff --git a/Makefile b/Makefile index a14723616..c7375a544 100644 --- a/Makefile +++ b/Makefile @@ -87,11 +87,46 @@ build: generate @echo "Build complete: $(BINARY_PATH)" @ln -sf $(BINARY_NAME)-$(PLATFORM)-$(ARCH) $(BUILD_DIR)/$(BINARY_NAME) +## build-whatsapp-native: Build with WhatsApp native (whatsmeow) support; larger binary +build-whatsapp-native: generate +## @echo "Building $(BINARY_NAME) with WhatsApp native for $(PLATFORM)/$(ARCH)..." + @echo "Building for multiple platforms..." + @mkdir -p $(BUILD_DIR) + GOOS=linux GOARCH=amd64 $(GO) build -tags whatsapp_native $(LDFLAGS) -o $(BUILD_DIR)/$(BINARY_NAME)-linux-amd64 ./$(CMD_DIR) + GOOS=linux GOARCH=arm GOARM=7 $(GO) build -tags whatsapp_native $(LDFLAGS) -o $(BUILD_DIR)/$(BINARY_NAME)-linux-arm ./$(CMD_DIR) + GOOS=linux GOARCH=arm64 $(GO) build -tags whatsapp_native $(LDFLAGS) -o $(BUILD_DIR)/$(BINARY_NAME)-linux-arm64 ./$(CMD_DIR) + GOOS=linux GOARCH=loong64 $(GO) build -tags whatsapp_native $(LDFLAGS) -o $(BUILD_DIR)/$(BINARY_NAME)-linux-loong64 ./$(CMD_DIR) + GOOS=linux GOARCH=riscv64 $(GO) build -tags whatsapp_native $(LDFLAGS) -o $(BUILD_DIR)/$(BINARY_NAME)-linux-riscv64 ./$(CMD_DIR) + GOOS=darwin GOARCH=arm64 $(GO) build -tags whatsapp_native $(LDFLAGS) -o $(BUILD_DIR)/$(BINARY_NAME)-darwin-arm64 ./$(CMD_DIR) + GOOS=windows GOARCH=amd64 $(GO) build -tags whatsapp_native $(LDFLAGS) -o $(BUILD_DIR)/$(BINARY_NAME)-windows-amd64.exe ./$(CMD_DIR) +## @$(GO) build $(GOFLAGS) -tags whatsapp_native $(LDFLAGS) -o $(BINARY_PATH) ./$(CMD_DIR) + @echo "Build complete" +## @ln -sf $(BINARY_NAME)-$(PLATFORM)-$(ARCH) $(BUILD_DIR)/$(BINARY_NAME) + +## build-linux-arm: Build for Linux ARMv7 (e.g. 
Raspberry Pi Zero 2 W 32-bit) +build-linux-arm: generate + @echo "Building for linux/arm (GOARM=7)..." + @mkdir -p $(BUILD_DIR) + GOOS=linux GOARCH=arm GOARM=7 $(GO) build $(LDFLAGS) -o $(BUILD_DIR)/$(BINARY_NAME)-linux-arm ./$(CMD_DIR) + @echo "Build complete: $(BUILD_DIR)/$(BINARY_NAME)-linux-arm" + +## build-linux-arm64: Build for Linux ARM64 (e.g. Raspberry Pi Zero 2 W 64-bit) +build-linux-arm64: generate + @echo "Building for linux/arm64..." + @mkdir -p $(BUILD_DIR) + GOOS=linux GOARCH=arm64 $(GO) build $(LDFLAGS) -o $(BUILD_DIR)/$(BINARY_NAME)-linux-arm64 ./$(CMD_DIR) + @echo "Build complete: $(BUILD_DIR)/$(BINARY_NAME)-linux-arm64" + +## build-pi-zero: Build for Raspberry Pi Zero 2 W (32-bit and 64-bit) +build-pi-zero: build-linux-arm build-linux-arm64 + @echo "Pi Zero 2 W builds: $(BUILD_DIR)/$(BINARY_NAME)-linux-arm (32-bit), $(BUILD_DIR)/$(BINARY_NAME)-linux-arm64 (64-bit)" + ## build-all: Build picoclaw for all platforms build-all: generate @echo "Building for multiple platforms..." 
@mkdir -p $(BUILD_DIR) GOOS=linux GOARCH=amd64 $(GO) build $(LDFLAGS) -o $(BUILD_DIR)/$(BINARY_NAME)-linux-amd64 ./$(CMD_DIR) + GOOS=linux GOARCH=arm GOARM=7 $(GO) build $(LDFLAGS) -o $(BUILD_DIR)/$(BINARY_NAME)-linux-arm ./$(CMD_DIR) GOOS=linux GOARCH=arm64 $(GO) build $(LDFLAGS) -o $(BUILD_DIR)/$(BINARY_NAME)-linux-arm64 ./$(CMD_DIR) GOOS=linux GOARCH=loong64 $(GO) build $(LDFLAGS) -o $(BUILD_DIR)/$(BINARY_NAME)-linux-loong64 ./$(CMD_DIR) GOOS=linux GOARCH=riscv64 $(GO) build $(LDFLAGS) -o $(BUILD_DIR)/$(BINARY_NAME)-linux-riscv64 ./$(CMD_DIR) diff --git a/README.md b/README.md index 6f5cce4d0..b040d0605 100644 --- a/README.md +++ b/README.md @@ -154,10 +154,15 @@ make build # Build for multiple platforms make build-all +# Build for Raspberry Pi Zero 2 W (32-bit: make build-linux-arm; 64-bit: make build-linux-arm64) +make build-pi-zero + # Build And Install make install ``` +**Raspberry Pi Zero 2 W:** Use the binary that matches your OS: 32-bit Raspberry Pi OS → `make build-linux-arm` (output: `build/picoclaw-linux-arm`); 64-bit → `make build-linux-arm64` (output: `build/picoclaw-linux-arm64`). Or run `make build-pi-zero` to build both. + ## 🐳 Docker Compose You can also run PicoClaw using Docker Compose without installing anything locally. @@ -288,12 +293,13 @@ That's it! You have a working AI assistant in 2 minutes. ## 💬 Chat Apps -Talk to your picoclaw through Telegram, Discord, DingTalk, LINE, or WeCom +Talk to your picoclaw through Telegram, Discord, WhatsApp, DingTalk, LINE, or WeCom | Channel | Setup | | ------------ | ---------------------------------- | | **Telegram** | Easy (just a token) | | **Discord** | Easy (bot token + intents) | +| **WhatsApp** | Easy (native: QR scan; or bridge URL) | | **QQ** | Easy (AppID + AppSecret) | | **DingTalk** | Medium (app credentials) | | **LINE** | Medium (credentials + webhook URL) | @@ -384,6 +390,33 @@ picoclaw gateway +
+WhatsApp (native via whatsmeow) + +PicoClaw can connect to WhatsApp in two ways: + +- **Native (recommended):** In-process using [whatsmeow](https://github.com/tulir/whatsmeow). No separate bridge. Set `"use_native": true` and leave `bridge_url` empty. On first run, scan the QR code with WhatsApp (Linked Devices). The session is stored under your workspace (e.g. `<workspace>/whatsapp/`). The native channel is **optional** and is not compiled into the default binary (to keep it small); to enable it, build with `-tags whatsapp_native` (e.g. `make build-whatsapp-native` or `go build -tags whatsapp_native ./cmd/...`). +- **Bridge:** Connect to an external WebSocket bridge. Set `bridge_url` (e.g. `ws://localhost:3001`) and keep `use_native` set to `false`. + +**Configure (native)** + +```json +{ + "channels": { + "whatsapp": { + "enabled": true, + "use_native": true, + "session_store_path": "", + "allow_from": [] + } + } +} +``` + +If `session_store_path` is empty, the session is stored in `<workspace>/whatsapp/`. Run `picoclaw gateway`; on first run, scan the QR code printed in the terminal with WhatsApp → Linked Devices. + +
+
QQ @@ -1070,7 +1103,11 @@ picoclaw agent -m "Hello" "allow_from": [""] }, "whatsapp": { - "enabled": false + "enabled": false, + "bridge_url": "ws://localhost:3001", + "use_native": false, + "session_store_path": "", + "allow_from": [] }, "feishu": { "enabled": false, diff --git a/assets/picoclaw_detect_person.mp4 b/assets/picoclaw_detect_person.mp4 deleted file mode 100644 index b56999689..000000000 Binary files a/assets/picoclaw_detect_person.mp4 and /dev/null differ diff --git a/assets/wechat.png b/assets/wechat.png index e30c34e4e..1900c7556 100644 Binary files a/assets/wechat.png and b/assets/wechat.png differ diff --git a/cmd/picoclaw/internal/agent/helpers.go b/cmd/picoclaw/internal/agent/helpers.go index 746e9755e..f754abc65 100644 --- a/cmd/picoclaw/internal/agent/helpers.go +++ b/cmd/picoclaw/internal/agent/helpers.go @@ -48,6 +48,7 @@ func agentCmd(message, sessionKey, model string, debug bool) error { } msgBus := bus.NewMessageBus() + defer msgBus.Close() agentLoop := agent.NewAgentLoop(cfg, msgBus, provider) // Print agent startup info (only for interactive mode) diff --git a/cmd/picoclaw/internal/gateway/helpers.go b/cmd/picoclaw/internal/gateway/helpers.go index a06625dc9..baa489b92 100644 --- a/cmd/picoclaw/internal/gateway/helpers.go +++ b/cmd/picoclaw/internal/gateway/helpers.go @@ -2,29 +2,39 @@ package gateway import ( "context" - "errors" "fmt" - "net/http" "os" "os/signal" "path/filepath" - "strings" "time" "github.com/sipeed/picoclaw/cmd/picoclaw/internal" "github.com/sipeed/picoclaw/pkg/agent" "github.com/sipeed/picoclaw/pkg/bus" "github.com/sipeed/picoclaw/pkg/channels" + _ "github.com/sipeed/picoclaw/pkg/channels/dingtalk" + _ "github.com/sipeed/picoclaw/pkg/channels/discord" + _ "github.com/sipeed/picoclaw/pkg/channels/feishu" + _ "github.com/sipeed/picoclaw/pkg/channels/line" + _ "github.com/sipeed/picoclaw/pkg/channels/maixcam" + _ "github.com/sipeed/picoclaw/pkg/channels/onebot" + _ "github.com/sipeed/picoclaw/pkg/channels/pico" + _ 
"github.com/sipeed/picoclaw/pkg/channels/qq" + _ "github.com/sipeed/picoclaw/pkg/channels/slack" + _ "github.com/sipeed/picoclaw/pkg/channels/telegram" + _ "github.com/sipeed/picoclaw/pkg/channels/wecom" + _ "github.com/sipeed/picoclaw/pkg/channels/whatsapp" + _ "github.com/sipeed/picoclaw/pkg/channels/whatsapp_native" "github.com/sipeed/picoclaw/pkg/config" "github.com/sipeed/picoclaw/pkg/cron" "github.com/sipeed/picoclaw/pkg/devices" "github.com/sipeed/picoclaw/pkg/health" "github.com/sipeed/picoclaw/pkg/heartbeat" "github.com/sipeed/picoclaw/pkg/logger" + "github.com/sipeed/picoclaw/pkg/media" "github.com/sipeed/picoclaw/pkg/providers" "github.com/sipeed/picoclaw/pkg/state" "github.com/sipeed/picoclaw/pkg/tools" - "github.com/sipeed/picoclaw/pkg/voice" ) func gatewayCmd(debug bool) error { @@ -105,49 +115,23 @@ func gatewayCmd(debug bool) error { return tools.SilentResult(response) }) - channelManager, err := channels.NewManager(cfg, msgBus) + // Create media store for file lifecycle management with TTL cleanup + mediaStore := media.NewFileMediaStoreWithCleanup(media.MediaCleanerConfig{ + Enabled: cfg.Tools.MediaCleanup.Enabled, + MaxAge: time.Duration(cfg.Tools.MediaCleanup.MaxAge) * time.Minute, + Interval: time.Duration(cfg.Tools.MediaCleanup.Interval) * time.Minute, + }) + mediaStore.Start() + + channelManager, err := channels.NewManager(cfg, msgBus, mediaStore) if err != nil { + mediaStore.Stop() return fmt.Errorf("error creating channel manager: %w", err) } - // Inject channel manager into agent loop for command handling + // Inject channel manager and media store into agent loop agentLoop.SetChannelManager(channelManager) - - var transcriber *voice.GroqTranscriber - groqAPIKey := cfg.Providers.Groq.APIKey - if groqAPIKey == "" { - for _, mc := range cfg.ModelList { - if strings.HasPrefix(mc.Model, "groq/") && mc.APIKey != "" { - groqAPIKey = mc.APIKey - break - } - } - } - if groqAPIKey != "" { - transcriber = voice.NewGroqTranscriber(groqAPIKey) - 
logger.InfoC("voice", "Groq voice transcription enabled") - } - - if transcriber != nil { - if telegramChannel, ok := channelManager.GetChannel("telegram"); ok { - if tc, ok := telegramChannel.(*channels.TelegramChannel); ok { - tc.SetTranscriber(transcriber) - logger.InfoC("voice", "Groq transcription attached to Telegram channel") - } - } - if discordChannel, ok := channelManager.GetChannel("discord"); ok { - if dc, ok := discordChannel.(*channels.DiscordChannel); ok { - dc.SetTranscriber(transcriber) - logger.InfoC("voice", "Groq transcription attached to Discord channel") - } - } - if slackChannel, ok := channelManager.GetChannel("slack"); ok { - if sc, ok := slackChannel.(*channels.SlackChannel); ok { - sc.SetTranscriber(transcriber) - logger.InfoC("voice", "Groq transcription attached to Slack channel") - } - } - } + agentLoop.SetMediaStore(mediaStore) enabledChannels := channelManager.GetEnabledChannels() if len(enabledChannels) > 0 { @@ -184,16 +168,15 @@ func gatewayCmd(debug bool) error { fmt.Println("✓ Device event service started") } + // Setup shared HTTP server with health endpoints and webhook handlers + healthServer := health.NewServer(cfg.Gateway.Host, cfg.Gateway.Port) + addr := fmt.Sprintf("%s:%d", cfg.Gateway.Host, cfg.Gateway.Port) + channelManager.SetupHTTPServer(addr, healthServer) + if err := channelManager.StartAll(ctx); err != nil { fmt.Printf("Error starting channels: %v\n", err) } - healthServer := health.NewServer(cfg.Gateway.Host, cfg.Gateway.Port) - go func() { - if err := healthServer.Start(); err != nil && !errors.Is(err, http.ErrServerClosed) { - logger.ErrorCF("health", "Health server error", map[string]any{"error": err.Error()}) - } - }() fmt.Printf("✓ Health endpoints available at http://%s:%d/health and /ready\n", cfg.Gateway.Host, cfg.Gateway.Port) go agentLoop.Run(ctx) @@ -207,12 +190,19 @@ func gatewayCmd(debug bool) error { cp.Close() } cancel() - healthServer.Stop(context.Background()) + msgBus.Close() + + // Use a fresh 
context with timeout for graceful shutdown, + // since the original ctx is already canceled. + shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), 15*time.Second) + defer shutdownCancel() + + channelManager.StopAll(shutdownCtx) deviceService.Stop() heartbeatService.Stop() cronService.Stop() + mediaStore.Stop() agentLoop.Stop() - channelManager.StopAll(ctx) fmt.Println("✓ Gateway stopped") return nil diff --git a/config/config.example.json b/config/config.example.json index 9575039f8..55a823009 100644 --- a/config/config.example.json +++ b/config/config.example.json @@ -52,30 +52,37 @@ "proxy": "", "allow_from": [ "YOUR_USER_ID" - ] + ], + "reasoning_channel_id": "" }, "discord": { "enabled": false, "token": "YOUR_DISCORD_BOT_TOKEN", "allow_from": [], - "mention_only": false + "mention_only": false, + "reasoning_channel_id": "" }, "qq": { "enabled": false, "app_id": "YOUR_QQ_APP_ID", "app_secret": "YOUR_QQ_APP_SECRET", - "allow_from": [] + "allow_from": [], + "reasoning_channel_id": "" }, "maixcam": { "enabled": false, "host": "0.0.0.0", "port": 18790, - "allow_from": [] + "allow_from": [], + "reasoning_channel_id": "" }, "whatsapp": { "enabled": false, "bridge_url": "ws://localhost:3001", - "allow_from": [] + "use_native": false, + "session_store_path": "", + "allow_from": [], + "reasoning_channel_id": "" }, "feishu": { "enabled": false, @@ -83,19 +90,22 @@ "app_secret": "", "encrypt_key": "", "verification_token": "", - "allow_from": [] + "allow_from": [], + "reasoning_channel_id": "" }, "dingtalk": { "enabled": false, "client_id": "YOUR_CLIENT_ID", "client_secret": "YOUR_CLIENT_SECRET", - "allow_from": [] + "allow_from": [], + "reasoning_channel_id": "" }, "slack": { "enabled": false, "bot_token": "xoxb-YOUR-BOT-TOKEN", "app_token": "xapp-YOUR-APP-TOKEN", - "allow_from": [] + "allow_from": [], + "reasoning_channel_id": "" }, "line": { "enabled": false, @@ -104,7 +114,8 @@ "webhook_host": "0.0.0.0", "webhook_port": 18791, "webhook_path": 
"/webhook/line", - "allow_from": [] + "allow_from": [], + "reasoning_channel_id": "" }, "onebot": { "enabled": false, @@ -112,7 +123,8 @@ "access_token": "", "reconnect_interval": 5, "group_trigger_prefix": [], - "allow_from": [] + "allow_from": [], + "reasoning_channel_id": "" }, "wecom": { "_comment": "WeCom Bot (智能机器人) - Easier setup, supports group chats", @@ -124,7 +136,8 @@ "webhook_port": 18793, "webhook_path": "/webhook/wecom", "allow_from": [], - "reply_timeout": 5 + "reply_timeout": 5, + "reasoning_channel_id": "" }, "wecom_app": { "_comment": "WeCom App (自建应用) - More features, proactive messaging, private chat only. See docs/wecom-app-configuration.md", @@ -138,7 +151,8 @@ "webhook_port": 18792, "webhook_path": "/webhook/wecom-app", "allow_from": [], - "reply_timeout": 5 + "reply_timeout": 5, + "reasoning_channel_id": "" } }, "providers": { diff --git a/docs/design/issue-783-investigation-and-fix-plan.zh.md b/docs/design/issue-783-investigation-and-fix-plan.zh.md new file mode 100644 index 000000000..1c9fc1e70 --- /dev/null +++ b/docs/design/issue-783-investigation-and-fix-plan.zh.md @@ -0,0 +1,61 @@ +# Issue #783 调研与修复执行文档 + +## 1. 问题澄清(已确认) + +- 现象:当 `agents.*.model.primary/fallbacks` 使用 `model_name` 别名(如 `step-3.5-flash`)时,fallback 链路将别名当作真实 `provider/model` 解析,导致 `provider` 可能为空、`model` 可能错误。 +- 根因:`ResolveCandidates` 仅对字符串做 `ParseModelRef`,未先通过 `model_list` 将别名映射到真实 `model` 字段。 +- 影响: + - fallback 执行可能把别名直接发给 OpenAI-compatible provider,触发 `Unknown Model`。 + - `defaults.provider` 为空时,日志出现 `provider=` 空值。 + +## 2. 本次目标 + +- 修复 fallback 候选解析:优先通过 `model_list` 解析别名。 +- 兼容旧行为:若未命中 `model_list`,继续走原有 `ParseModelRef` 兜底。 +- 补充测试:覆盖别名、嵌套路径模型(如 `openrouter/stepfun/...`)、空默认 provider。 +- 验证代码风格:与当前仓库风格保持一致(命名、错误处理、测试结构)。 + +## 3. 
联网最佳实践调研结论(已完成) + +- [x] 查阅 OpenAI-compatible 网关(如 OpenRouter)对 `model` 字段的推荐处理。 +- [x] 查阅多 provider/fallback 设计最佳实践(候选解析、日志可观测性)。 +- [x] 将外部建议映射为本仓库可执行约束。 + +外部参考要点(来自 OpenRouter/LiteLLM/Cloudflare AI Gateway 等官方文档): + +- 优先显式配置,不依赖字符串切分推断 provider。 +- 对网关模型标识应保留完整路径语义,避免截断导致 Unknown Model。 +- fallback 与 primary 应复用同一解析策略,避免“主路径正确、降级路径错误”。 + +参考链接: + +- OpenRouter Provider Routing: https://openrouter.ai/docs/guides/routing/provider-selection +- OpenRouter Model Fallbacks: https://openrouter.ai/docs/guides/routing/model-fallbacks +- OpenRouter Chat Completion API: https://openrouter.ai/docs/api-reference/chat-completion +- LiteLLM Router Architecture: https://docs.litellm.ai/docs/router_architecture +- Cloudflare AI Gateway Chat Completion: https://developers.cloudflare.com/ai-gateway/usage/chat-completion/ + +与本仓库对应的可执行约束: + +- 在 fallback candidate 构建阶段先做 `model_name -> model_list.model` 映射。 +- 未命中映射时保留旧解析行为,保证兼容性。 +- 用新增测试锁定“别名 + 嵌套模型路径 + 空默认 provider”场景。 + +## 4. 实施步骤(顺序执行) + +- [x] Step 1: 对齐现有代码模式,定位最小改动点(`pkg/agent` + `pkg/providers`)。 +- [x] Step 2: 实现“基于 model_list 的 fallback 候选解析”。 +- [x] Step 3: 增加/更新单元测试,覆盖 issue 场景。 +- [x] Step 4: 代码风格一致性复核(与现有文件风格对照)。 +- [x] Step 5: 运行质量门禁(LSP + `make check`)。 + +## 5. 执行记录 + +- 状态:已完成 +- 已完成改动: + - `pkg/providers/fallback.go`:新增 `ResolveCandidatesWithLookup`,并保持 `ResolveCandidates` 向后兼容。 + - `pkg/agent/instance.go`:在构建 fallback candidates 前,优先通过 `model_list` 解析别名,并对无协议模型补齐默认 `openai/` 前缀后再解析。 + - `pkg/providers/fallback_test.go`:新增别名解析与去重测试。 + - `pkg/agent/instance_test.go`:新增 agent 侧别名解析到嵌套模型路径、无协议模型解析测试。 +- 风格对齐检查(完成):与 `pkg/providers/fallback_test.go`、`pkg/providers/model_ref_test.go` 现有模式一致。 +- 质量验证(完成):先 `make generate`,后 `make check` 全量通过。 diff --git a/docs/troubleshooting.md b/docs/troubleshooting.md new file mode 100644 index 000000000..219d2c6e3 --- /dev/null +++ b/docs/troubleshooting.md @@ -0,0 +1,43 @@ +# Troubleshooting + +## "model ... 
not found in model_list" or OpenRouter "free is not a valid model ID" + +**Symptom:** You see either: + +- `Error creating provider: model "openrouter/free" not found in model_list` +- OpenRouter returns 400: `"free is not a valid model ID"` + +**Cause:** The `model` field in your `model_list` entry is what gets sent to the API. For OpenRouter you must use the **full** model ID, not a shorthand. + +- **Wrong:** `"model": "free"` → OpenRouter receives `free` and rejects it. +- **Right:** `"model": "openrouter/free"` → OpenRouter receives `openrouter/free` (auto free-tier routing). + +**Fix:** In `~/.picoclaw/config.json` (or your config path): + +1. **agents.defaults.model** must match a `model_name` in `model_list` (e.g. `"openrouter-free"`). +2. That entry’s **model** must be a valid OpenRouter model ID, for example: + - `"openrouter/free"` – auto free-tier + - `"google/gemini-2.0-flash-exp:free"` + - `"meta-llama/llama-3.1-8b-instruct:free"` + +Example snippet: + +```json +{ + "agents": { + "defaults": { + "model": "openrouter-free" + } + }, + "model_list": [ + { + "model_name": "openrouter-free", + "model": "openrouter/free", + "api_key": "sk-or-v1-YOUR_OPENROUTER_KEY", + "api_base": "https://openrouter.ai/api/v1" + } + ] +} +``` + +Get your key at [OpenRouter Keys](https://openrouter.ai/keys). 
diff --git a/go.mod b/go.mod index 98e20d07d..d7f9b1901 100644 --- a/go.mod +++ b/go.mod @@ -11,6 +11,7 @@ require ( github.com/google/uuid v1.6.0 github.com/gorilla/websocket v1.5.3 github.com/larksuite/oapi-sdk-go/v3 v3.5.3 + github.com/mdp/qrterminal/v3 v3.2.1 github.com/mymmrac/telego v1.6.0 github.com/open-dingtalk/dingtalk-stream-sdk-go v0.9.1 github.com/openai/openai-go/v3 v3.22.0 @@ -18,15 +19,40 @@ require ( github.com/spf13/cobra v1.10.2 github.com/stretchr/testify v1.11.1 github.com/tencent-connect/botgo v0.2.1 + go.mau.fi/whatsmeow v0.0.0-20260219150138-7ae702b1eed4 golang.org/x/oauth2 v0.35.0 + golang.org/x/time v0.14.0 + google.golang.org/protobuf v1.36.11 + modernc.org/sqlite v1.46.1 ) require ( + filippo.io/edwards25519 v1.1.0 // indirect + github.com/beeper/argo-go v1.1.2 // indirect + github.com/coder/websocket v1.8.14 // indirect github.com/davecgh/go-spew v1.1.1 // indirect + github.com/dustin/go-humanize v1.0.1 // indirect + github.com/elliotchance/orderedmap/v3 v3.1.0 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect + github.com/mattn/go-colorable v0.1.14 // indirect + github.com/mattn/go-isatty v0.0.20 // indirect + github.com/ncruces/go-strftime v1.0.0 // indirect + github.com/petermattis/goid v0.0.0-20260113132338-7c7de50cc741 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect + github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect + github.com/rs/zerolog v1.34.0 // indirect github.com/spf13/pflag v1.0.10 // indirect + github.com/vektah/gqlparser/v2 v2.5.27 // indirect + go.mau.fi/libsignal v0.2.1 // indirect + go.mau.fi/util v0.9.6 // indirect + golang.org/x/exp v0.0.0-20260212183809-81e46e3db34a // indirect + golang.org/x/term v0.40.0 // indirect + golang.org/x/text v0.34.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect + modernc.org/libc v1.67.6 // indirect + modernc.org/mathutil v1.7.1 // indirect + modernc.org/memory v1.11.0 // indirect + rsc.io/qr v0.2.0 // indirect ) require ( 
diff --git a/go.sum b/go.sum index abbb11cd6..941ab67ce 100644 --- a/go.sum +++ b/go.sum @@ -1,10 +1,20 @@ cloud.google.com/go/compute/metadata v0.3.0/go.mod h1:zFmK7XCadkQkj6TtorcaGlCW1hT1fIilQDwofLpJ20k= +filippo.io/edwards25519 v1.1.0 h1:FNf4tywRC1HmFuKW5xopWpigGjJKiJSV0Cqo0cJWDaA= +filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4= +github.com/DATA-DOG/go-sqlmock v1.5.2 h1:OcvFkGmslmlZibjAjaHm3L//6LiuBgolP7OputlJIzU= +github.com/DATA-DOG/go-sqlmock v1.5.2/go.mod h1:88MAG/4G7SMwSE3CeA0ZKzrT5CiOU3OJ+JlNzwDqpNU= github.com/adhocore/gronx v1.19.6 h1:5KNVcoR9ACgL9HhEqCm5QXsab/gI4QDIybTAWcXDKDc= github.com/adhocore/gronx v1.19.6/go.mod h1:7oUY1WAU8rEJWmAxXR2DN0JaO4gi9khSgKjiRypqteg= +github.com/agnivade/levenshtein v1.2.1 h1:EHBY3UOn1gwdy/VbFwgo4cxecRznFk7fKWN1KOX7eoM= +github.com/agnivade/levenshtein v1.2.1/go.mod h1:QVVI16kDrtSuwcpd0p1+xMC6Z/VfhtCyDIjcwga4/DU= +github.com/andreyvit/diff v0.0.0-20170406064948-c7f18ee00883 h1:bvNMNQO63//z+xNgfBlViaCIJKLlCJ6/fmUseuG0wVQ= +github.com/andreyvit/diff v0.0.0-20170406064948-c7f18ee00883/go.mod h1:rCTlJbsFo29Kk6CurOXKm700vrz8f0KW0JNfpkRJY/8= github.com/andybalholm/brotli v1.2.0 h1:ukwgCxwYrmACq68yiUqwIWnGY0cTPox/M94sVwToPjQ= github.com/andybalholm/brotli v1.2.0/go.mod h1:rzTDkvFWvIrjDXZHkuS16NPggd91W3kUSvPlQ1pLaKY= github.com/anthropics/anthropic-sdk-go v1.22.1 h1:xbsc3vJKCX/ELDZSpTNfz9wCgrFsamwFewPb1iI0Xh0= github.com/anthropics/anthropic-sdk-go v1.22.1/go.mod h1:WTz31rIUHUHqai2UslPpw5CwXrQP3geYBioRV4WOLvE= +github.com/beeper/argo-go v1.1.2 h1:UQI2G8F+NLfGTOmTUI0254pGKx/HUU/etbUGTJv91Fs= +github.com/beeper/argo-go v1.1.2/go.mod h1:M+LJAnyowKVQ6Rdj6XYGEn+qcVFkb3R/MUpqkGR0hM4= github.com/bwmarrin/discordgo v0.29.0 h1:FmWeXFaKUwrcL3Cx65c20bTRW+vOb6k8AnaP+EgjDno= github.com/bwmarrin/discordgo v0.29.0/go.mod h1:NJZpH+1AfhIcyQsPeuBKsUtYrRnjkyu0kIVMCHkZtRY= github.com/bytedance/gopkg v0.1.3 h1:TPBSwH8RsouGCBcMBktLt1AymVo2TVsBVCY4b6TnZ/M= @@ -25,12 +35,19 @@ github.com/chzyer/test v1.0.0 
h1:p3BQDXSxOhOG0P9z6/hGnII4LGiEPOYBhs8asl/fC04= github.com/chzyer/test v1.0.0/go.mod h1:2JlltgoNkt4TW/z9V/IzDdFaMTM2JPIi26O1pF38GC8= github.com/cloudwego/base64x v0.1.6 h1:t11wG9AECkCDk5fMSoxmufanudBtJ+/HemLstXDLI2M= github.com/cloudwego/base64x v0.1.6/go.mod h1:OFcloc187FXDaYHvrNIjxSe8ncn0OOM8gEHfghB2IPU= +github.com/coder/websocket v1.8.14 h1:9L0p0iKiNOibykf283eHkKUHHrpG7f65OE3BhhO7v9g= +github.com/coder/websocket v1.8.14/go.mod h1:NX3SzP+inril6yawo5CQXx8+fk145lPDC6pumgx0mVg= +github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc= github.com/cpuguy83/go-md2man/v2 v2.0.6/go.mod h1:oOW0eioCTA6cOiMLiUPZOpcVxMig6NIQQ7OS05n1F4g= github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc= +github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY= +github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto= +github.com/elliotchance/orderedmap/v3 v3.1.0 h1:j4DJ5ObEmMBt/lcwIecKcoRxIQUEnw0L804lXYDt/pg= +github.com/elliotchance/orderedmap/v3 v3.1.0/go.mod h1:G+Hc2RwaZvJMcS4JpGCOyViCnGeKf0bTYCGTO4uhjSo= github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo= github.com/fsnotify/fsnotify v1.4.9/go.mod h1:znqG4EE+3YCdAaPaxE2ZRY/06pZUdp0tY4IgpuI1SZQ= github.com/github/copilot-sdk/go v0.1.23 h1:uExtO/inZQndCZMiSAA1hvXINiz9tqo/MZgQzFzurxw= @@ -42,6 +59,7 @@ github.com/go-resty/resty/v2 v2.17.1/go.mod h1:kCKZ3wWmwJaNc7S29BRtUhJwy7iqmn+2m github.com/go-task/slim-sprig v0.0.0-20210107165309-348f09dbbbc0/go.mod 
h1:fyg7847qk6SyHyPtNmDHnmrv/HOrqktSC+C9fM+CJOE= github.com/go-test/deep v1.1.1 h1:0r/53hagsehfO4bzD2Pgr/+RgHqhmf+k1Bpse2cTu1U= github.com/go-test/deep v1.1.1/go.mod h1:5C2ZWiW0ErCdrYzpqxLbTX7MG14M9iiw8DgHncVwcsE= +github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q= github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q= github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= @@ -63,6 +81,8 @@ github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= github.com/google/jsonschema-go v0.4.2 h1:tmrUohrwoLZZS/P3x7ex0WAVknEkBZM46iALbcqoRA8= github.com/google/jsonschema-go v0.4.2/go.mod h1:r5quNTdLOYEz95Ru18zA0ydNbBuYoo9tgaYcxEYhJVE= +github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e h1:ijClszYn+mADRFY17kjQEVQ1XRhq2/JR1M3sGqeJoxs= +github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e/go.mod h1:boTsfXsheKC2y+lKOCMpSfarhxDeIzfZG1jqGcPl3cA= github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= @@ -72,6 +92,8 @@ github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aN github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= github.com/grbit/go-json v0.11.0 h1:bAbyMdYrYl/OjYsSqLH99N2DyQ291mHy726Mx+sYrnc= github.com/grbit/go-json v0.11.0/go.mod h1:IYpHsdybQ386+6g3VE6AXQ3uTGa5mquBme5/ZWmtzek= +github.com/hashicorp/golang-lru/v2 v2.0.7 h1:a+bsQ5rvGLjzHuww6tVxozPZFVghXaHOwFs4luLUK2k= +github.com/hashicorp/golang-lru/v2 v2.0.7/go.mod h1:QeFd9opnmA6QUJc5vARoKUSoFhyfM2/ZepoAG6RGpeM= github.com/hpcloud/tail v1.0.0/go.mod 
h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU= github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8= github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw= @@ -91,8 +113,21 @@ github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= github.com/larksuite/oapi-sdk-go/v3 v3.5.3 h1:xvf8Dv29kBXC5/DNDCLhHkAFW8l/0LlQJimO5Zn+JUk= github.com/larksuite/oapi-sdk-go/v3 v3.5.3/go.mod h1:ZEplY+kwuIrj/nqw5uSCINNATcH3KdxSN7y+UxYY5fI= +github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg= +github.com/mattn/go-colorable v0.1.14 h1:9A9LHSqF/7dyVVX6g0U9cwm9pG3kP9gSzcuIPHPsaIE= +github.com/mattn/go-colorable v0.1.14/go.mod h1:6LmQG8QLFO4G5z1gPvYEzlUgJ2wF+stgPZH1UqBm1s8= +github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= +github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= +github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= +github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= +github.com/mattn/go-sqlite3 v1.14.34 h1:3NtcvcUnFBPsuRcno8pUtupspG/GM+9nZ88zgJcp6Zk= +github.com/mattn/go-sqlite3 v1.14.34/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y= +github.com/mdp/qrterminal/v3 v3.2.1 h1:6+yQjiiOsSuXT5n9/m60E54vdgFsw0zhADHhHLrFet4= +github.com/mdp/qrterminal/v3 v3.2.1/go.mod h1:jOTmXvnBsMy5xqLniO0R++Jmjs2sTm9dFSuQ5kpz/SU= github.com/mymmrac/telego v1.6.0 h1:Zc8rgyHozvd/7ZgyrigyHdAF9koHYMfilYfyB6wlFC0= github.com/mymmrac/telego v1.6.0/go.mod h1:xt6ZWA8zi8KmuzryE1ImEdl9JSwjHNpM4yhC7D8hU4Y= +github.com/ncruces/go-strftime v1.0.0 h1:HMFp8mLCTPp341M/ZnA4qaf7ZlsbTc+miZjCLOFAw7w= +github.com/ncruces/go-strftime v1.0.0/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls= github.com/nxadm/tail v1.4.4/go.mod 
h1:kenIhsEOeOJmVchQTgglprH7qJGnHDVpk1VPCcaMI8A= github.com/nxadm/tail v1.4.8/go.mod h1:+ncqLTQzXmGhMZNUePPaPqPvBxHAIsmXswZKocGu+AU= github.com/onsi/ginkgo v1.6.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE= @@ -105,13 +140,23 @@ github.com/open-dingtalk/dingtalk-stream-sdk-go v0.9.1 h1:Lb/Uzkiw2Ugt2Xf03J5wmv github.com/open-dingtalk/dingtalk-stream-sdk-go v0.9.1/go.mod h1:ln3IqPYYocZbYvl9TAOrG/cxGR9xcn4pnZRLdCTEGEU= github.com/openai/openai-go/v3 v3.22.0 h1:6MEoNoV8sbjOVmXdvhmuX3BjVbVdcExbVyGixiyJ8ys= github.com/openai/openai-go/v3 v3.22.0/go.mod h1:cdufnVK14cWcT9qA1rRtrXx4FTRsgbDPW7Ia7SS5cZo= +github.com/petermattis/goid v0.0.0-20260113132338-7c7de50cc741 h1:KPpdlQLZcHfTMQRi6bFQ7ogNO0ltFT4PmtwTLW4W+14= +github.com/petermattis/goid v0.0.0-20260113132338-7c7de50cc741/go.mod h1:pxMtw7cyUw6B2bRH0ZBANSPg+AoSud1I1iyJHI69jH4= github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA= +github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE= +github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo= github.com/rogpeppe/go-internal v1.6.1/go.mod h1:xXDCJY+GAPziupqXw64V24skbSoqbTEfhy4qGm1nDQc= github.com/rogpeppe/go-internal v1.9.0 h1:73kH8U+JUqXU8lRuOHeVHaa/SZPifC7BkcraZVejAe8= github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs= +github.com/rs/xid v1.6.0/go.mod h1:7XoLgs4eV+QndskICGsho+ADou8ySMSjJKDIan90Nz0= +github.com/rs/zerolog v1.34.0 h1:k43nTLIwcTVQAncfCw4KZ2VY6ukYoZaBPNOE8txlOeY= +github.com/rs/zerolog v1.34.0/go.mod h1:bJsvje4Z08ROH4Nhs5iH600c3IkWhwp44iRc54W6wYQ= 
github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= +github.com/sergi/go-diff v1.3.1 h1:xkr+Oxo4BOQKmkn/B9eMK0g5Kg/983T9DqqPHwYqD+8= +github.com/sergi/go-diff v1.3.1/go.mod h1:aMJSSKb2lpPvRNec0+w3fl7LP9IOFzdc9Pa4NFbPK1I= github.com/slack-go/slack v0.17.3 h1:zV5qO3Q+WJAQ/XwbGfNFrRMaJ5T/naqaonyPV/1TP4g= github.com/slack-go/slack v0.17.3/go.mod h1:X+UqOufi3LYQHDnMG1vxf0J8asC6+WllXrVrhl8/Prk= github.com/spf13/cobra v1.10.2 h1:DMTTonx5m65Ic0GOoRY2c16WCbHxOOw6xxezuLaBpcU= @@ -153,11 +198,19 @@ github.com/valyala/fasthttp v1.69.0 h1:fNLLESD2SooWeh2cidsuFtOcrEi4uB4m1mPrkJMZy github.com/valyala/fasthttp v1.69.0/go.mod h1:4wA4PfAraPlAsJ5jMSqCE2ug5tqUPwKXxVj8oNECGcw= github.com/valyala/fastjson v1.6.7 h1:ZE4tRy0CIkh+qDc5McjatheGX2czdn8slQjomexVpBM= github.com/valyala/fastjson v1.6.7/go.mod h1:CLCAqky6SMuOcxStkYQvblddUtoRxhYMGLrsQns1aXY= +github.com/vektah/gqlparser/v2 v2.5.27 h1:RHPD3JOplpk5mP5JGX8RKZkt2/Vwj/PZv0HxTdwFp0s= +github.com/vektah/gqlparser/v2 v2.5.27/go.mod h1:D1/VCZtV3LPnQrcPBeR/q5jkSQIPti0uYCP/RI0gIeo= github.com/xyproto/randomstring v1.0.5 h1:YtlWPoRdgMu3NZtP45drfy1GKoojuR7hmRcnhZqKjWU= github.com/xyproto/randomstring v1.0.5/go.mod h1:rgmS5DeNXLivK7YprL0pY+lTuhNQW3iGxZ18UQApw/E= github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= +go.mau.fi/libsignal v0.2.1 h1:vRZG4EzTn70XY6Oh/pVKrQGuMHBkAWlGRC22/85m9L0= +go.mau.fi/libsignal v0.2.1/go.mod h1:iVvjrHyfQqWajOUaMEsIfo3IqgVMrhWcPiiEzk7NgoU= +go.mau.fi/util v0.9.6 h1:2nsvxm49KhI3wrFltr0+wSUBlnQ4CMtykuELjpIU+ts= +go.mau.fi/util v0.9.6/go.mod h1:sIJpRH7Iy5Ad1SBuxQoatxtIeErgzxCtjd/2hCMkYMI= +go.mau.fi/whatsmeow v0.0.0-20260219150138-7ae702b1eed4 h1:hsmlwsM+VqfF70cpdZEeIUKer2XWCQmQPK0u0tHy3ZQ= +go.mau.fi/whatsmeow v0.0.0-20260219150138-7ae702b1eed4/go.mod 
h1:mXCRFyPEPn4jqWz6Afirn8vY7DpHCPnlKq6I2cWwFHM= go.uber.org/mock v0.6.0 h1:hyF9dfmbgIX5EfOdasqLsWD6xqpNZlXblLB/Dbnwv3Y= go.uber.org/mock v0.6.0/go.mod h1:KiVJ4BqZJaMj4svdfmHM0AUx4NJYO8ZNpPnZn1Z+BBU= go.yaml.in/yaml/v3 v3.0.4/go.mod h1:DhzuOOF2ATzADvBadXxruRBLzYTpT36CKvDb3+aBEFg= @@ -171,10 +224,14 @@ golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5y golang.org/x/crypto v0.16.0/go.mod h1:gCAAfMLgwOJRpTjQ2zCCt2OcSfYMTeZVSRtQlPC7Nq4= golang.org/x/crypto v0.48.0 h1:/VRzVqiRSggnhY7gNRxPauEQ5Drw9haKdM0jqfcCFts= golang.org/x/crypto v0.48.0/go.mod h1:r0kV5h3qnFPlQnBSrULhlsRfryS2pmewsg+XfMgkVos= +golang.org/x/exp v0.0.0-20260212183809-81e46e3db34a h1:ovFr6Z0MNmU7nH8VaX5xqw+05ST2uO1exVfZPVqRC5o= +golang.org/x/exp v0.0.0-20260212183809-81e46e3db34a/go.mod h1:K79w1Vqn7PoiZn+TkNpx3BUWUQksGO3JcVX6qIjytmA= golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= +golang.org/x/mod v0.33.0 h1:tHFzIWbBifEmbwtGz65eaWyGiGZatSrT9prnU8DbVL8= +golang.org/x/mod v0.33.0/go.mod h1:swjeQEj+6r7fODbD2cqrnje9PnziFuw4bmLbBZFrQ5w= golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= @@ -217,8 +274,11 @@ golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.0.0-20220310020820-b874c991c1a5/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys 
v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.15.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/sys v0.41.0 h1:Ivj+2Cp/ylzLiEU89QhWblYnOE9zerudt9Ftecq2C6k= golang.org/x/sys v0.41.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= @@ -227,6 +287,8 @@ golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuX golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k= golang.org/x/term v0.8.0/go.mod h1:xPskH00ivmX89bAKVGSKKtLOWNx2+17Eiy94tnKShWo= golang.org/x/term v0.15.0/go.mod h1:BDl952bC7+uMoWR75FIrCDx79TPU9oHkTZ9yRbYOrX0= +golang.org/x/term v0.40.0 h1:36e4zGLqU4yhjlmxEaagx2KuYbJq3EwY8K943ZsHcvg= +golang.org/x/term v0.40.0/go.mod h1:w2P8uVp06p2iyKKuvXIm7N/y0UCRt3UfJTfZ7oOpglM= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= @@ -234,8 +296,10 @@ golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= -golang.org/x/time v0.12.0 h1:ScB/8o8olJvc+CQPWrK3fPZNfh7qgwCrY0zJmoEQLSE= -golang.org/x/time v0.12.0/go.mod h1:CDIdPxbZBQxdj6cxyCIdrNogrJKMJ7pr37NYpMcMDSg= 
+golang.org/x/text v0.34.0 h1:oL/Qq0Kdaqxa1KbNeMKwQq0reLCCaFtqu2eNuSeNHbk= +golang.org/x/text v0.34.0/go.mod h1:homfLqTYRFyVYemLBFl5GgL/DWEiH5wcsQ5gSh1yziA= +golang.org/x/time v0.14.0 h1:MRx4UaLrDotUKUdCIqzPC48t1Y9hANFKIRpNx+Te8PI= +golang.org/x/time v0.14.0/go.mod h1:eL/Oa2bBBK0TkX57Fyni+NgnyQQN4LitPmob2Hjnqw4= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.0.0-20200619180055-7c47624df98f/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE= @@ -243,6 +307,8 @@ golang.org/x/tools v0.0.0-20201224043029-2b0845dc783e/go.mod h1:emZCQorbCU4vsT4f golang.org/x/tools v0.0.0-20210106214847-113979e3529a/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA= golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU= +golang.org/x/tools v0.42.0 h1:uNgphsn75Tdz5Ji2q36v/nsFSfR/9BRFvqhGBaJGd5k= +golang.org/x/tools v0.42.0/go.mod h1:Ma6lCIwGZvHK6XtgbswSoWroEkhugApmsXyrUmBhfr0= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= @@ -255,6 +321,8 @@ google.golang.org/protobuf v1.21.0/go.mod h1:47Nbq4nVaFHyn7ilMalzfO3qCViNmqZ2kzi google.golang.org/protobuf v1.23.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU= google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc= +google.golang.org/protobuf v1.36.11 h1:fV6ZwhNocDyBLK0dj+fg8ektcVegBBuEolpbTQyBNVE= +google.golang.org/protobuf 
v1.36.11/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j23XfzDpco= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= @@ -269,3 +337,33 @@ gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +modernc.org/cc/v4 v4.27.1 h1:9W30zRlYrefrDV2JE2O8VDtJ1yPGownxciz5rrbQZis= +modernc.org/cc/v4 v4.27.1/go.mod h1:uVtb5OGqUKpoLWhqwNQo/8LwvoiEBLvZXIQ/SmO6mL0= +modernc.org/ccgo/v4 v4.30.1 h1:4r4U1J6Fhj98NKfSjnPUN7Ze2c6MnAdL0hWw6+LrJpc= +modernc.org/ccgo/v4 v4.30.1/go.mod h1:bIOeI1JL54Utlxn+LwrFyjCx2n2RDiYEaJVSrgdrRfM= +modernc.org/fileutil v1.3.40 h1:ZGMswMNc9JOCrcrakF1HrvmergNLAmxOPjizirpfqBA= +modernc.org/fileutil v1.3.40/go.mod h1:HxmghZSZVAz/LXcMNwZPA/DRrQZEVP9VX0V4LQGQFOc= +modernc.org/gc/v2 v2.6.5 h1:nyqdV8q46KvTpZlsw66kWqwXRHdjIlJOhG6kxiV/9xI= +modernc.org/gc/v2 v2.6.5/go.mod h1:YgIahr1ypgfe7chRuJi2gD7DBQiKSLMPgBQe9oIiito= +modernc.org/gc/v3 v3.1.1 h1:k8T3gkXWY9sEiytKhcgyiZ2L0DTyCQ/nvX+LoCljoRE= +modernc.org/gc/v3 v3.1.1/go.mod h1:HFK/6AGESC7Ex+EZJhJ2Gni6cTaYpSMmU/cT9RmlfYY= +modernc.org/goabi0 v0.2.0 h1:HvEowk7LxcPd0eq6mVOAEMai46V+i7Jrj13t4AzuNks= +modernc.org/goabi0 v0.2.0/go.mod h1:CEFRnnJhKvWT1c1JTI3Avm+tgOWbkOu5oPA8eH8LnMI= +modernc.org/libc v1.67.6 h1:eVOQvpModVLKOdT+LvBPjdQqfrZq+pC39BygcT+E7OI= +modernc.org/libc v1.67.6/go.mod h1:JAhxUVlolfYDErnwiqaLvUqc8nfb2r6S6slAgZOnaiE= +modernc.org/mathutil v1.7.1 h1:GCZVGXdaN8gTqB1Mf/usp1Y/hSqgI2vAGGP4jZMCxOU= +modernc.org/mathutil v1.7.1/go.mod 
h1:4p5IwJITfppl0G4sUEDtCr4DthTaT47/N3aT6MhfgJg= +modernc.org/memory v1.11.0 h1:o4QC8aMQzmcwCK3t3Ux/ZHmwFPzE6hf2Y5LbkRs+hbI= +modernc.org/memory v1.11.0/go.mod h1:/JP4VbVC+K5sU2wZi9bHoq2MAkCnrt2r98UGeSK7Mjw= +modernc.org/opt v0.1.4 h1:2kNGMRiUjrp4LcaPuLY2PzUfqM/w9N23quVwhKt5Qm8= +modernc.org/opt v0.1.4/go.mod h1:03fq9lsNfvkYSfxrfUhZCWPk1lm4cq4N+Bh//bEtgns= +modernc.org/sortutil v1.2.1 h1:+xyoGf15mM3NMlPDnFqrteY07klSFxLElE2PVuWIJ7w= +modernc.org/sortutil v1.2.1/go.mod h1:7ZI3a3REbai7gzCLcotuw9AC4VZVpYMjDzETGsSMqJE= +modernc.org/sqlite v1.46.1 h1:eFJ2ShBLIEnUWlLy12raN0Z1plqmFX9Qe3rjQTKt6sU= +modernc.org/sqlite v1.46.1/go.mod h1:CzbrU2lSB1DKUusvwGz7rqEKIq+NUd8GWuBBZDs9/nA= +modernc.org/strutil v1.2.1 h1:UneZBkQA+DX2Rp35KcM69cSsNES9ly8mQWD71HKlOA0= +modernc.org/strutil v1.2.1/go.mod h1:EHkiggD70koQxjVdSBM3JKM7k6L0FbGE5eymy9i3B9A= +modernc.org/token v1.1.0 h1:Xl7Ap9dKaEs5kLoOQeQmPWevfnk/DM5qcLcYlA8ys6Y= +modernc.org/token v1.1.0/go.mod h1:UGzOrNV1mAFSEB63lOFHIpNRUVMvYTc6yu1SMY/XTDM= +rsc.io/qr v0.2.0 h1:6vBLea5/NRMVTz8V66gipeLycZMl/+UlFmk8DvqQ6WY= +rsc.io/qr v0.2.0/go.mod h1:IF+uZjkb9fqyeF/4tlBoynqmQxUoPfWEKh921coOuXs= diff --git a/pkg/agent/instance.go b/pkg/agent/instance.go index a6fd365c7..65a1fe04d 100644 --- a/pkg/agent/instance.go +++ b/pkg/agent/instance.go @@ -92,7 +92,47 @@ func NewAgentInstance( Primary: model, Fallbacks: fallbacks, } - candidates := providers.ResolveCandidates(modelCfg, defaults.Provider) + resolveFromModelList := func(raw string) (string, bool) { + ensureProtocol := func(model string) string { + model = strings.TrimSpace(model) + if model == "" { + return "" + } + if strings.Contains(model, "/") { + return model + } + return "openai/" + model + } + + raw = strings.TrimSpace(raw) + if raw == "" { + return "", false + } + + if cfg != nil { + if mc, err := cfg.GetModelConfig(raw); err == nil && mc != nil && strings.TrimSpace(mc.Model) != "" { + return ensureProtocol(mc.Model), true + } + + for i := range cfg.ModelList { + fullModel := 
strings.TrimSpace(cfg.ModelList[i].Model) + if fullModel == "" { + continue + } + if fullModel == raw { + return ensureProtocol(fullModel), true + } + _, modelID := providers.ExtractProtocol(fullModel) + if modelID == raw { + return ensureProtocol(fullModel), true + } + } + } + + return "", false + } + + candidates := providers.ResolveCandidatesWithLookup(modelCfg, defaults.Provider, resolveFromModelList) return &AgentInstance{ ID: agentID, diff --git a/pkg/agent/instance_test.go b/pkg/agent/instance_test.go index fcc8e9bea..af1bf2ead 100644 --- a/pkg/agent/instance_test.go +++ b/pkg/agent/instance_test.go @@ -93,3 +93,77 @@ func TestNewAgentInstance_DefaultsTemperatureWhenUnset(t *testing.T) { t.Fatalf("Temperature = %f, want %f", agent.Temperature, 0.7) } } + +func TestNewAgentInstance_ResolveCandidatesFromModelListAlias(t *testing.T) { + tmpDir, err := os.MkdirTemp("", "agent-instance-test-*") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + defer os.RemoveAll(tmpDir) + + cfg := &config.Config{ + Agents: config.AgentsConfig{ + Defaults: config.AgentDefaults{ + Workspace: tmpDir, + Model: "step-3.5-flash", + }, + }, + ModelList: []config.ModelConfig{ + { + ModelName: "step-3.5-flash", + Model: "openrouter/stepfun/step-3.5-flash:free", + APIBase: "https://openrouter.ai/api/v1", + }, + }, + } + + provider := &mockProvider{} + agent := NewAgentInstance(nil, &cfg.Agents.Defaults, cfg, provider) + + if len(agent.Candidates) != 1 { + t.Fatalf("len(Candidates) = %d, want 1", len(agent.Candidates)) + } + if agent.Candidates[0].Provider != "openrouter" { + t.Fatalf("candidate provider = %q, want %q", agent.Candidates[0].Provider, "openrouter") + } + if agent.Candidates[0].Model != "stepfun/step-3.5-flash:free" { + t.Fatalf("candidate model = %q, want %q", agent.Candidates[0].Model, "stepfun/step-3.5-flash:free") + } +} + +func TestNewAgentInstance_ResolveCandidatesFromModelListAliasWithoutProtocol(t *testing.T) { + tmpDir, err := os.MkdirTemp("", 
"agent-instance-test-*") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + defer os.RemoveAll(tmpDir) + + cfg := &config.Config{ + Agents: config.AgentsConfig{ + Defaults: config.AgentDefaults{ + Workspace: tmpDir, + Model: "glm-5", + }, + }, + ModelList: []config.ModelConfig{ + { + ModelName: "glm-5", + Model: "glm-5", + APIBase: "https://api.z.ai/api/coding/paas/v4", + }, + }, + } + + provider := &mockProvider{} + agent := NewAgentInstance(nil, &cfg.Agents.Defaults, cfg, provider) + + if len(agent.Candidates) != 1 { + t.Fatalf("len(Candidates) = %d, want 1", len(agent.Candidates)) + } + if agent.Candidates[0].Provider != "openai" { + t.Fatalf("candidate provider = %q, want %q", agent.Candidates[0].Provider, "openai") + } + if agent.Candidates[0].Model != "glm-5" { + t.Fatalf("candidate model = %q, want %q", agent.Candidates[0].Model, "glm-5") + } +} diff --git a/pkg/agent/loop.go b/pkg/agent/loop.go index 693f2227b..29827d0b2 100644 --- a/pkg/agent/loop.go +++ b/pkg/agent/loop.go @@ -10,6 +10,7 @@ import ( "context" "encoding/json" "fmt" + "path/filepath" "strings" "sync" "sync/atomic" @@ -21,6 +22,7 @@ import ( "github.com/sipeed/picoclaw/pkg/config" "github.com/sipeed/picoclaw/pkg/constants" "github.com/sipeed/picoclaw/pkg/logger" + "github.com/sipeed/picoclaw/pkg/media" "github.com/sipeed/picoclaw/pkg/providers" "github.com/sipeed/picoclaw/pkg/routing" "github.com/sipeed/picoclaw/pkg/skills" @@ -38,6 +40,7 @@ type AgentLoop struct { summarizing sync.Map fallback *providers.FallbackChain channelManager *channels.Manager + mediaStore media.MediaStore } // processOptions configures how a message is processed @@ -52,6 +55,8 @@ type processOptions struct { NoHistory bool // If true, don't load session history (for heartbeat) } +const defaultResponse = "I've completed processing but have no response to give. Increase `max_tool_iterations` in config.json." 
+ func NewAgentLoop(cfg *config.Config, msgBus *bus.MessageBus, provider providers.LLMProvider) *AgentLoop { registry := NewAgentRegistry(cfg, provider) @@ -119,12 +124,13 @@ func registerSharedTools( // Message tool messageTool := tools.NewMessageTool() messageTool.SetSendCallback(func(channel, chatID, content string) error { - msgBus.PublishOutbound(bus.OutboundMessage{ + pubCtx, pubCancel := context.WithTimeout(context.Background(), 5*time.Second) + defer pubCancel() + return msgBus.PublishOutbound(pubCtx, bus.OutboundMessage{ Channel: channel, ChatID: chatID, Content: content, }) - return nil }) agent.Tools.Register(messageTool) @@ -165,33 +171,61 @@ func (al *AgentLoop) Run(ctx context.Context) error { continue } - response, err := al.processMessage(ctx, msg) - if err != nil { - response = fmt.Sprintf("Error processing message: %v", err) - } + // Process message + func() { + // TODO: Re-enable media cleanup after inbound media is properly consumed by the agent. + // Currently disabled because files are deleted before the LLM can access their content. + // defer func() { + // if al.mediaStore != nil && msg.MediaScope != "" { + // if releaseErr := al.mediaStore.ReleaseAll(msg.MediaScope); releaseErr != nil { + // logger.WarnCF("agent", "Failed to release media", map[string]any{ + // "scope": msg.MediaScope, + // "error": releaseErr.Error(), + // }) + // } + // } + // }() - if response != "" { - // Check if the message tool already sent a response during this round. - // If so, skip publishing to avoid duplicate messages to the user. - // Use default agent's tools to check (message tool is shared). 
- alreadySent := false - defaultAgent := al.registry.GetDefaultAgent() - if defaultAgent != nil { - if tool, ok := defaultAgent.Tools.Get("message"); ok { - if mt, ok := tool.(*tools.MessageTool); ok { - alreadySent = mt.HasSentInRound() + response, err := al.processMessage(ctx, msg) + if err != nil { + response = fmt.Sprintf("Error processing message: %v", err) + } + + if response != "" { + // Check if the message tool already sent a response during this round. + // If so, skip publishing to avoid duplicate messages to the user. + // Use default agent's tools to check (message tool is shared). + alreadySent := false + defaultAgent := al.registry.GetDefaultAgent() + if defaultAgent != nil { + if tool, ok := defaultAgent.Tools.Get("message"); ok { + if mt, ok := tool.(*tools.MessageTool); ok { + alreadySent = mt.HasSentInRound() + } } } - } - if !alreadySent { - al.bus.PublishOutbound(bus.OutboundMessage{ - Channel: msg.Channel, - ChatID: msg.ChatID, - Content: response, - }) + if !alreadySent { + al.bus.PublishOutbound(ctx, bus.OutboundMessage{ + Channel: msg.Channel, + ChatID: msg.ChatID, + Content: response, + }) + logger.InfoCF("agent", "Published outbound response", + map[string]any{ + "channel": msg.Channel, + "chat_id": msg.ChatID, + "content_len": len(response), + }) + } else { + logger.DebugCF( + "agent", + "Skipped outbound (message tool already sent)", + map[string]any{"channel": msg.Channel}, + ) + } } - } + }() } } @@ -214,6 +248,41 @@ func (al *AgentLoop) SetChannelManager(cm *channels.Manager) { al.channelManager = cm } +// SetMediaStore injects a MediaStore for media lifecycle management. +func (al *AgentLoop) SetMediaStore(s media.MediaStore) { + al.mediaStore = s +} + +// inferMediaType determines the media type ("image", "audio", "video", "file") +// from a filename and MIME content type. 
+func inferMediaType(filename, contentType string) string { + ct := strings.ToLower(contentType) + fn := strings.ToLower(filename) + + if strings.HasPrefix(ct, "image/") { + return "image" + } + if strings.HasPrefix(ct, "audio/") || ct == "application/ogg" { + return "audio" + } + if strings.HasPrefix(ct, "video/") { + return "video" + } + + // Fallback: infer from extension + ext := filepath.Ext(fn) + switch ext { + case ".jpg", ".jpeg", ".png", ".gif", ".webp", ".bmp", ".svg": + return "image" + case ".mp3", ".wav", ".ogg", ".m4a", ".flac", ".aac", ".wma", ".opus": + return "audio" + case ".mp4", ".avi", ".mov", ".webm", ".mkv": + return "video" + } + + return "file" +} + // RecordLastChannel records the last active channel for this workspace. // This uses the atomic state save mechanism to prevent data loss on crash. func (al *AgentLoop) RecordLastChannel(channel string) error { @@ -255,12 +324,15 @@ func (al *AgentLoop) ProcessDirectWithChannel( // Each heartbeat is independent and doesn't accumulate context. func (al *AgentLoop) ProcessHeartbeat(ctx context.Context, content, channel, chatID string) (string, error) { agent := al.registry.GetDefaultAgent() + if agent == nil { + return "", fmt.Errorf("no default agent for heartbeat") + } return al.runAgentLoop(ctx, agent, processOptions{ SessionKey: "heartbeat", Channel: channel, ChatID: chatID, UserMessage: content, - DefaultResponse: "I've completed processing but have no response to give.", + DefaultResponse: defaultResponse, EnableSummary: false, SendResponse: false, NoHistory: true, // Don't load session history for heartbeat @@ -307,6 +379,16 @@ func (al *AgentLoop) processMessage(ctx context.Context, msg bus.InboundMessage) if !ok { agent = al.registry.GetDefaultAgent() } + if agent == nil { + return "", fmt.Errorf("no agent available for route (agent_id=%s)", route.AgentID) + } + + // Reset message-tool state for this round so we don't skip publishing due to a previous round. 
+ if tool, ok := agent.Tools.Get("message"); ok { + if mt, ok := tool.(tools.ContextualTool); ok { + mt.SetContext(msg.Channel, msg.ChatID) + } + } // Use routed session key, but honor pre-set agent-scoped keys (for ProcessDirect/cron) sessionKey := route.SessionKey @@ -326,7 +408,7 @@ func (al *AgentLoop) processMessage(ctx context.Context, msg bus.InboundMessage) Channel: msg.Channel, ChatID: msg.ChatID, UserMessage: msg.Content, - DefaultResponse: "I've completed processing but have no response to give.", + DefaultResponse: defaultResponse, EnableSummary: true, SendResponse: false, }) @@ -373,6 +455,9 @@ func (al *AgentLoop) processSystemMessage(ctx context.Context, msg bus.InboundMe // Use default agent for system messages agent := al.registry.GetDefaultAgent() + if agent == nil { + return "", fmt.Errorf("no default agent for system message") + } // Use the origin session for context sessionKey := routing.BuildAgentMainSessionKey(agent.ID) @@ -448,7 +533,7 @@ func (al *AgentLoop) runAgentLoop(ctx context.Context, agent *AgentInstance, opt // 8. Optional: send response via bus if opts.SendResponse { - al.bus.PublishOutbound(bus.OutboundMessage{ + al.bus.PublishOutbound(ctx, bus.OutboundMessage{ Channel: opts.Channel, ChatID: opts.ChatID, Content: finalContent, @@ -468,6 +553,34 @@ func (al *AgentLoop) runAgentLoop(ctx context.Context, agent *AgentInstance, opt return finalContent, nil } +func (al *AgentLoop) targetReasoningChannelID(channelName string) (chatID string) { + if al.channelManager == nil { + return "" + } + if ch, ok := al.channelManager.GetChannel(channelName); ok { + return ch.ReasoningChannelID() + } + return "" +} + +func (al *AgentLoop) handleReasoning(ctx context.Context, reasoningContent, channelName, channelID string) { + if reasoningContent == "" || channelName == "" || channelID == "" { + return + } + + // Check context cancellation before attempting to publish, + // since PublishOutbound's select may race between send and ctx.Done(). 
+ if ctx.Err() != nil { + return + } + + al.bus.PublishOutbound(ctx, bus.OutboundMessage{ + Channel: channelName, + ChatID: channelID, + Content: reasoningContent, + }) +} + // runLLMIteration executes the LLM call loop with tool handling. func (al *AgentLoop) runLLMIteration( ctx context.Context, @@ -565,7 +678,7 @@ func (al *AgentLoop) runLLMIteration( }) if retry == 0 && !constants.IsInternalChannel(opts.Channel) { - al.bus.PublishOutbound(bus.OutboundMessage{ + al.bus.PublishOutbound(ctx, bus.OutboundMessage{ Channel: opts.Channel, ChatID: opts.ChatID, Content: "Context window exceeded. Compressing history and retrying...", @@ -594,6 +707,18 @@ func (al *AgentLoop) runLLMIteration( return "", iteration, fmt.Errorf("LLM call failed after retries: %w", err) } + go al.handleReasoning(ctx, response.Reasoning, opts.Channel, al.targetReasoningChannelID(opts.Channel)) + + logger.DebugCF("agent", "LLM response", + map[string]any{ + "agent_id": agent.ID, + "iteration": iteration, + "content_chars": len(response.Content), + "tool_calls": len(response.ToolCalls), + "reasoning": response.Reasoning, + "target_channel": al.targetReasoningChannelID(opts.Channel), + "channel": opts.Channel, + }) // Check if no tool calls - we're done if len(response.ToolCalls) == 0 { finalContent = response.Content @@ -695,7 +820,7 @@ func (al *AgentLoop) runLLMIteration( // Send ForUser content to user immediately if not Silent if !toolResult.Silent && toolResult.ForUser != "" && opts.SendResponse { - al.bus.PublishOutbound(bus.OutboundMessage{ + al.bus.PublishOutbound(ctx, bus.OutboundMessage{ Channel: opts.Channel, ChatID: opts.ChatID, Content: toolResult.ForUser, @@ -707,6 +832,28 @@ func (al *AgentLoop) runLLMIteration( }) } + // If tool returned media refs, publish them as outbound media + if len(toolResult.Media) > 0 && opts.SendResponse { + parts := make([]bus.MediaPart, 0, len(toolResult.Media)) + for _, ref := range toolResult.Media { + part := bus.MediaPart{Ref: ref} + // Populate 
metadata from MediaStore when available + if al.mediaStore != nil { + if _, meta, err := al.mediaStore.ResolveWithMeta(ref); err == nil { + part.Filename = meta.Filename + part.ContentType = meta.ContentType + part.Type = inferMediaType(meta.Filename, meta.ContentType) + } + } + parts = append(parts, part) + } + al.bus.PublishOutboundMedia(ctx, bus.OutboundMediaMessage{ + Channel: opts.Channel, + ChatID: opts.ChatID, + Parts: parts, + }) + } + // Determine content for LLM based on tool result contentForLLM := toolResult.ForLLM if contentForLLM == "" && toolResult.Err != nil { @@ -1119,21 +1266,20 @@ func (al *AgentLoop) handleCommand(ctx context.Context, msg bus.InboundMessage) return "", false } -// extractPeer extracts the routing peer from inbound message metadata. +// extractPeer extracts the routing peer from the inbound message's structured Peer field. func extractPeer(msg bus.InboundMessage) *routing.RoutePeer { - peerKind := msg.Metadata["peer_kind"] - if peerKind == "" { + if msg.Peer.Kind == "" { return nil } - peerID := msg.Metadata["peer_id"] + peerID := msg.Peer.ID if peerID == "" { - if peerKind == "direct" { + if msg.Peer.Kind == "direct" { peerID = msg.SenderID } else { peerID = msg.ChatID } } - return &routing.RoutePeer{Kind: peerKind, ID: peerID} + return &routing.RoutePeer{Kind: msg.Peer.Kind, ID: peerID} } // extractParentPeer extracts the parent peer (reply-to) from inbound message metadata. 
diff --git a/pkg/agent/loop_test.go b/pkg/agent/loop_test.go index aa9f823b7..2a23c889c 100644 --- a/pkg/agent/loop_test.go +++ b/pkg/agent/loop_test.go @@ -10,11 +10,23 @@ import ( "time" "github.com/sipeed/picoclaw/pkg/bus" + "github.com/sipeed/picoclaw/pkg/channels" "github.com/sipeed/picoclaw/pkg/config" "github.com/sipeed/picoclaw/pkg/providers" "github.com/sipeed/picoclaw/pkg/tools" ) +type fakeChannel struct{ id string } + +func (f *fakeChannel) Name() string { return "fake" } +func (f *fakeChannel) Start(ctx context.Context) error { return nil } +func (f *fakeChannel) Stop(ctx context.Context) error { return nil } +func (f *fakeChannel) Send(ctx context.Context, msg bus.OutboundMessage) error { return nil } +func (f *fakeChannel) IsRunning() bool { return true } +func (f *fakeChannel) IsAllowed(string) bool { return true } +func (f *fakeChannel) IsAllowedSender(sender bus.SenderInfo) bool { return true } +func (f *fakeChannel) ReasoningChannelID() string { return f.id } + func TestRecordLastChannel(t *testing.T) { // Create temp workspace tmpDir, err := os.MkdirTemp("", "agent-test-*") @@ -620,3 +632,158 @@ func TestAgentLoop_ContextExhaustionRetry(t *testing.T) { t.Errorf("Expected history to be compressed (len < 8), got %d", len(finalHistory)) } } + +func TestTargetReasoningChannelID_AllChannels(t *testing.T) { + tmpDir, err := os.MkdirTemp("", "agent-test-*") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + defer os.RemoveAll(tmpDir) + + cfg := &config.Config{ + Agents: config.AgentsConfig{ + Defaults: config.AgentDefaults{ + Workspace: tmpDir, + Model: "test-model", + MaxTokens: 4096, + MaxToolIterations: 10, + }, + }, + } + + al := NewAgentLoop(cfg, bus.NewMessageBus(), &mockProvider{}) + chManager, err := channels.NewManager(&config.Config{}, bus.NewMessageBus(), nil) + if err != nil { + t.Fatalf("Failed to create channel manager: %v", err) + } + for name, id := range map[string]string{ + "whatsapp": "rid-whatsapp", + 
"telegram": "rid-telegram", + "feishu": "rid-feishu", + "discord": "rid-discord", + "maixcam": "rid-maixcam", + "qq": "rid-qq", + "dingtalk": "rid-dingtalk", + "slack": "rid-slack", + "line": "rid-line", + "onebot": "rid-onebot", + "wecom": "rid-wecom", + "wecom_app": "rid-wecom-app", + } { + chManager.RegisterChannel(name, &fakeChannel{id: id}) + } + al.SetChannelManager(chManager) + tests := []struct { + channel string + wantID string + }{ + {channel: "whatsapp", wantID: "rid-whatsapp"}, + {channel: "telegram", wantID: "rid-telegram"}, + {channel: "feishu", wantID: "rid-feishu"}, + {channel: "discord", wantID: "rid-discord"}, + {channel: "maixcam", wantID: "rid-maixcam"}, + {channel: "qq", wantID: "rid-qq"}, + {channel: "dingtalk", wantID: "rid-dingtalk"}, + {channel: "slack", wantID: "rid-slack"}, + {channel: "line", wantID: "rid-line"}, + {channel: "onebot", wantID: "rid-onebot"}, + {channel: "wecom", wantID: "rid-wecom"}, + {channel: "wecom_app", wantID: "rid-wecom-app"}, + {channel: "unknown", wantID: ""}, + } + + for _, tt := range tests { + t.Run(tt.channel, func(t *testing.T) { + got := al.targetReasoningChannelID(tt.channel) + if got != tt.wantID { + t.Fatalf("targetReasoningChannelID(%q) = %q, want %q", tt.channel, got, tt.wantID) + } + }) + } +} + +func TestHandleReasoning(t *testing.T) { + newLoop := func(t *testing.T) (*AgentLoop, *bus.MessageBus) { + t.Helper() + tmpDir, err := os.MkdirTemp("", "agent-test-*") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + t.Cleanup(func() { _ = os.RemoveAll(tmpDir) }) + cfg := &config.Config{ + Agents: config.AgentsConfig{ + Defaults: config.AgentDefaults{ + Workspace: tmpDir, + Model: "test-model", + MaxTokens: 4096, + MaxToolIterations: 10, + }, + }, + } + msgBus := bus.NewMessageBus() + return NewAgentLoop(cfg, msgBus, &mockProvider{}), msgBus + } + + t.Run("skips when any required field is empty", func(t *testing.T) { + al, msgBus := newLoop(t) + al.handleReasoning(context.Background(), 
"reasoning", "telegram", "") + + ctx, cancel := context.WithTimeout(context.Background(), 20*time.Millisecond) + defer cancel() + if msg, ok := msgBus.SubscribeOutbound(ctx); ok { + t.Fatalf("expected no outbound message, got %+v", msg) + } + }) + + t.Run("publishes one message for non telegram", func(t *testing.T) { + al, msgBus := newLoop(t) + al.handleReasoning(context.Background(), "hello reasoning", "slack", "channel-1") + + ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond) + defer cancel() + msg, ok := msgBus.SubscribeOutbound(ctx) + if !ok { + t.Fatal("expected an outbound message") + } + if msg.Channel != "slack" || msg.ChatID != "channel-1" || msg.Content != "hello reasoning" { + t.Fatalf("unexpected outbound message: %+v", msg) + } + }) + + t.Run("publishes one message for telegram", func(t *testing.T) { + al, msgBus := newLoop(t) + reasoning := "hello telegram reasoning" + al.handleReasoning(context.Background(), reasoning, "telegram", "tg-chat") + + ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond) + defer cancel() + msg, ok := msgBus.SubscribeOutbound(ctx) + if !ok { + t.Fatal("expected outbound message") + } + + if msg.Channel != "telegram" { + t.Fatalf("expected telegram channel message, got %+v", msg) + } + if msg.ChatID != "tg-chat" { + t.Fatalf("expected chatID tg-chat, got %+v", msg) + } + if msg.Content != reasoning { + t.Fatalf("content mismatch: got %q want %q", msg.Content, reasoning) + } + }) + t.Run("expired ctx", func(t *testing.T) { + al, msgBus := newLoop(t) + reasoning := "hello telegram reasoning" + ctx, cancel := context.WithCancel(context.Background()) + cancel() + al.handleReasoning(ctx, reasoning, "telegram", "tg-chat") + + ctx, cancel = context.WithTimeout(context.Background(), 200*time.Millisecond) + defer cancel() + msg, ok := msgBus.SubscribeOutbound(ctx) + if ok { + t.Fatalf("expected no outbound message, got %+v", msg) + } + }) +} diff --git a/pkg/bus/bus.go 
b/pkg/bus/bus.go index 58c0a25d5..f5ff9587d 100644 --- a/pkg/bus/bus.go +++ b/pkg/bus/bus.go @@ -2,81 +2,156 @@ package bus import ( "context" - "sync" + "errors" + "sync/atomic" + + "github.com/sipeed/picoclaw/pkg/logger" ) +// ErrBusClosed is returned when publishing to a closed MessageBus. +var ErrBusClosed = errors.New("message bus closed") + +const defaultBusBufferSize = 64 + type MessageBus struct { - inbound chan InboundMessage - outbound chan OutboundMessage - handlers map[string]MessageHandler - closed bool - mu sync.RWMutex + inbound chan InboundMessage + outbound chan OutboundMessage + outboundMedia chan OutboundMediaMessage + done chan struct{} + closed atomic.Bool } func NewMessageBus() *MessageBus { return &MessageBus{ - inbound: make(chan InboundMessage, 100), - outbound: make(chan OutboundMessage, 100), - handlers: make(map[string]MessageHandler), + inbound: make(chan InboundMessage, defaultBusBufferSize), + outbound: make(chan OutboundMessage, defaultBusBufferSize), + outboundMedia: make(chan OutboundMediaMessage, defaultBusBufferSize), + done: make(chan struct{}), } } -func (mb *MessageBus) PublishInbound(msg InboundMessage) { - mb.mu.RLock() - defer mb.mu.RUnlock() - if mb.closed { - return +func (mb *MessageBus) PublishInbound(ctx context.Context, msg InboundMessage) error { + if mb.closed.Load() { + return ErrBusClosed + } + if err := ctx.Err(); err != nil { + return err + } + select { + case mb.inbound <- msg: + return nil + case <-mb.done: + return ErrBusClosed + case <-ctx.Done(): + return ctx.Err() } - mb.inbound <- msg } func (mb *MessageBus) ConsumeInbound(ctx context.Context) (InboundMessage, bool) { select { - case msg := <-mb.inbound: - return msg, true + case msg, ok := <-mb.inbound: + return msg, ok + case <-mb.done: + return InboundMessage{}, false case <-ctx.Done(): return InboundMessage{}, false } } -func (mb *MessageBus) PublishOutbound(msg OutboundMessage) { - mb.mu.RLock() - defer mb.mu.RUnlock() - if mb.closed { - return +func 
(mb *MessageBus) PublishOutbound(ctx context.Context, msg OutboundMessage) error { + if mb.closed.Load() { + return ErrBusClosed + } + if err := ctx.Err(); err != nil { + return err + } + select { + case mb.outbound <- msg: + return nil + case <-mb.done: + return ErrBusClosed + case <-ctx.Done(): + return ctx.Err() } - mb.outbound <- msg } func (mb *MessageBus) SubscribeOutbound(ctx context.Context) (OutboundMessage, bool) { select { - case msg := <-mb.outbound: - return msg, true + case msg, ok := <-mb.outbound: + return msg, ok + case <-mb.done: + return OutboundMessage{}, false case <-ctx.Done(): return OutboundMessage{}, false } } -func (mb *MessageBus) RegisterHandler(channel string, handler MessageHandler) { - mb.mu.Lock() - defer mb.mu.Unlock() - mb.handlers[channel] = handler +func (mb *MessageBus) PublishOutboundMedia(ctx context.Context, msg OutboundMediaMessage) error { + if mb.closed.Load() { + return ErrBusClosed + } + if err := ctx.Err(); err != nil { + return err + } + select { + case mb.outboundMedia <- msg: + return nil + case <-mb.done: + return ErrBusClosed + case <-ctx.Done(): + return ctx.Err() + } } -func (mb *MessageBus) GetHandler(channel string) (MessageHandler, bool) { - mb.mu.RLock() - defer mb.mu.RUnlock() - handler, ok := mb.handlers[channel] - return handler, ok +func (mb *MessageBus) SubscribeOutboundMedia(ctx context.Context) (OutboundMediaMessage, bool) { + select { + case msg, ok := <-mb.outboundMedia: + return msg, ok + case <-mb.done: + return OutboundMediaMessage{}, false + case <-ctx.Done(): + return OutboundMediaMessage{}, false + } } func (mb *MessageBus) Close() { - mb.mu.Lock() - defer mb.mu.Unlock() - if mb.closed { - return + if mb.closed.CompareAndSwap(false, true) { + close(mb.done) + + // Drain buffered channels so messages aren't silently lost. + // Channels are NOT closed to avoid send-on-closed panics from concurrent publishers. 
+ drained := 0 + for { + select { + case <-mb.inbound: + drained++ + default: + goto doneInbound + } + } + doneInbound: + for { + select { + case <-mb.outbound: + drained++ + default: + goto doneOutbound + } + } + doneOutbound: + for { + select { + case <-mb.outboundMedia: + drained++ + default: + goto doneMedia + } + } + doneMedia: + if drained > 0 { + logger.DebugCF("bus", "Drained buffered messages during close", map[string]any{ + "count": drained, + }) + } } - mb.closed = true - close(mb.inbound) - close(mb.outbound) } diff --git a/pkg/bus/bus_test.go b/pkg/bus/bus_test.go new file mode 100644 index 000000000..e07b8c7fe --- /dev/null +++ b/pkg/bus/bus_test.go @@ -0,0 +1,229 @@ +package bus + +import ( + "context" + "sync" + "testing" + "time" +) + +func TestPublishConsume(t *testing.T) { + mb := NewMessageBus() + defer mb.Close() + + ctx := context.Background() + + msg := InboundMessage{ + Channel: "test", + SenderID: "user1", + ChatID: "chat1", + Content: "hello", + } + + if err := mb.PublishInbound(ctx, msg); err != nil { + t.Fatalf("PublishInbound failed: %v", err) + } + + got, ok := mb.ConsumeInbound(ctx) + if !ok { + t.Fatal("ConsumeInbound returned ok=false") + } + if got.Content != "hello" { + t.Fatalf("expected content 'hello', got %q", got.Content) + } + if got.Channel != "test" { + t.Fatalf("expected channel 'test', got %q", got.Channel) + } +} + +func TestPublishOutboundSubscribe(t *testing.T) { + mb := NewMessageBus() + defer mb.Close() + + ctx := context.Background() + + msg := OutboundMessage{ + Channel: "telegram", + ChatID: "123", + Content: "world", + } + + if err := mb.PublishOutbound(ctx, msg); err != nil { + t.Fatalf("PublishOutbound failed: %v", err) + } + + got, ok := mb.SubscribeOutbound(ctx) + if !ok { + t.Fatal("SubscribeOutbound returned ok=false") + } + if got.Content != "world" { + t.Fatalf("expected content 'world', got %q", got.Content) + } +} + +func TestPublishInbound_ContextCancel(t *testing.T) { + mb := NewMessageBus() + defer 
mb.Close() + + // Fill the buffer + ctx := context.Background() + for i := range defaultBusBufferSize { + if err := mb.PublishInbound(ctx, InboundMessage{Content: "fill"}); err != nil { + t.Fatalf("fill failed at %d: %v", i, err) + } + } + + // Now buffer is full; publish with a canceled context + cancelCtx, cancel := context.WithCancel(context.Background()) + cancel() + + err := mb.PublishInbound(cancelCtx, InboundMessage{Content: "overflow"}) + if err == nil { + t.Fatal("expected error from canceled context, got nil") + } + if err != context.Canceled { + t.Fatalf("expected context.Canceled, got %v", err) + } +} + +func TestPublishInbound_BusClosed(t *testing.T) { + mb := NewMessageBus() + mb.Close() + + err := mb.PublishInbound(context.Background(), InboundMessage{Content: "test"}) + if err != ErrBusClosed { + t.Fatalf("expected ErrBusClosed, got %v", err) + } +} + +func TestPublishOutbound_BusClosed(t *testing.T) { + mb := NewMessageBus() + mb.Close() + + err := mb.PublishOutbound(context.Background(), OutboundMessage{Content: "test"}) + if err != ErrBusClosed { + t.Fatalf("expected ErrBusClosed, got %v", err) + } +} + +func TestConsumeInbound_ContextCancel(t *testing.T) { + mb := NewMessageBus() + defer mb.Close() + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + _, ok := mb.ConsumeInbound(ctx) + if ok { + t.Fatal("expected ok=false when context is canceled") + } +} + +func TestConsumeInbound_BusClosed(t *testing.T) { + mb := NewMessageBus() + mb.Close() + + ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + defer cancel() + + _, ok := mb.ConsumeInbound(ctx) + if ok { + t.Fatal("expected ok=false when bus is closed") + } +} + +func TestSubscribeOutbound_BusClosed(t *testing.T) { + mb := NewMessageBus() + mb.Close() + + ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + defer cancel() + + _, ok := mb.SubscribeOutbound(ctx) + if ok { + t.Fatal("expected ok=false when bus is 
closed") + } +} + +func TestConcurrentPublishClose(t *testing.T) { + mb := NewMessageBus() + ctx := context.Background() + + const numGoroutines = 100 + var wg sync.WaitGroup + wg.Add(numGoroutines + 1) + + // Spawn many goroutines trying to publish + for range numGoroutines { + go func() { + defer wg.Done() + // Use a short timeout context so we don't block forever after close + publishCtx, cancel := context.WithTimeout(ctx, 50*time.Millisecond) + defer cancel() + // Errors are expected; we just must not panic or deadlock + _ = mb.PublishInbound(publishCtx, InboundMessage{Content: "concurrent"}) + }() + } + + // Close from another goroutine + go func() { + defer wg.Done() + time.Sleep(5 * time.Millisecond) + mb.Close() + }() + + // Must complete without deadlock + done := make(chan struct{}) + go func() { + wg.Wait() + close(done) + }() + + select { + case <-done: + // success + case <-time.After(5 * time.Second): + t.Fatal("test timed out - possible deadlock") + } +} + +func TestPublishInbound_FullBuffer(t *testing.T) { + mb := NewMessageBus() + defer mb.Close() + + ctx := context.Background() + + // Fill the buffer + for i := range defaultBusBufferSize { + if err := mb.PublishInbound(ctx, InboundMessage{Content: "fill"}); err != nil { + t.Fatalf("fill failed at %d: %v", i, err) + } + } + + // Buffer is full; publish with short timeout + timeoutCtx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond) + defer cancel() + + err := mb.PublishInbound(timeoutCtx, InboundMessage{Content: "overflow"}) + if err == nil { + t.Fatal("expected error when buffer is full and context times out") + } + if err != context.DeadlineExceeded { + t.Fatalf("expected context.DeadlineExceeded, got %v", err) + } +} + +func TestCloseIdempotent(t *testing.T) { + mb := NewMessageBus() + + // Multiple Close calls must not panic + mb.Close() + mb.Close() + mb.Close() + + // After close, publish should return ErrBusClosed + err := mb.PublishInbound(context.Background(), 
InboundMessage{Content: "test"}) + if err != ErrBusClosed { + t.Fatalf("expected ErrBusClosed after multiple closes, got %v", err) + } +} diff --git a/pkg/bus/types.go b/pkg/bus/types.go index 44f9181a5..7ad8f0417 100644 --- a/pkg/bus/types.go +++ b/pkg/bus/types.go @@ -1,11 +1,30 @@ package bus +// Peer identifies the routing peer for a message (direct, group, channel, etc.) +type Peer struct { + Kind string `json:"kind"` // "direct" | "group" | "channel" | "" + ID string `json:"id"` +} + +// SenderInfo provides structured sender identity information. +type SenderInfo struct { + Platform string `json:"platform,omitempty"` // "telegram", "discord", "slack", ... + PlatformID string `json:"platform_id,omitempty"` // raw platform ID, e.g. "123456" + CanonicalID string `json:"canonical_id,omitempty"` // "platform:id" format + Username string `json:"username,omitempty"` // username (e.g. @alice) + DisplayName string `json:"display_name,omitempty"` // display name +} + type InboundMessage struct { Channel string `json:"channel"` SenderID string `json:"sender_id"` + Sender SenderInfo `json:"sender"` ChatID string `json:"chat_id"` Content string `json:"content"` Media []string `json:"media,omitempty"` + Peer Peer `json:"peer"` // routing peer + MessageID string `json:"message_id,omitempty"` // platform message ID + MediaScope string `json:"media_scope,omitempty"` // media lifecycle scope SessionKey string `json:"session_key"` Metadata map[string]string `json:"metadata,omitempty"` } @@ -16,4 +35,18 @@ type OutboundMessage struct { Content string `json:"content"` } -type MessageHandler func(InboundMessage) error +// MediaPart describes a single media attachment to send. +type MediaPart struct { + Type string `json:"type"` // "image" | "audio" | "video" | "file" + Ref string `json:"ref"` // media store ref, e.g. 
"media://abc123" + Caption string `json:"caption,omitempty"` // optional caption text + Filename string `json:"filename,omitempty"` // original filename hint + ContentType string `json:"content_type,omitempty"` // MIME type hint +} + +// OutboundMediaMessage carries media attachments from Agent to channels via the bus. +type OutboundMediaMessage struct { + Channel string `json:"channel"` + ChatID string `json:"chat_id"` + Parts []MediaPart `json:"parts"` +} diff --git a/pkg/channels/README.md b/pkg/channels/README.md new file mode 100644 index 000000000..52b9f98f4 --- /dev/null +++ b/pkg/channels/README.md @@ -0,0 +1,1331 @@ +# PicoClaw Channel System Refactor: Complete Development Guide + +> **Branch**: `refactor/channel-system` +> **Status**: Active development (~40 commits) +> **Scope**: `pkg/channels/`, `pkg/bus/`, `pkg/media/`, `pkg/identity/`, `cmd/picoclaw/internal/gateway/` + +--- + +## Table of Contents + +- [Part 1: Architecture Overview](#part-1-architecture-overview) +- [Part 2: Migration Guide — From main Branch to Refactored Branch](#part-2-migration-guide--from-main-branch-to-refactored-branch) +- [Part 3: New Channel Development Guide — Implementing a Channel from Scratch](#part-3-new-channel-development-guide--implementing-a-channel-from-scratch) +- [Part 4: Core Subsystem Details](#part-4-core-subsystem-details) +- [Part 5: Key Design Decisions and Conventions](#part-5-key-design-decisions-and-conventions) +- [Appendix: Complete File Listing and Interface Quick Reference](#appendix-complete-file-listing-and-interface-quick-reference) + +--- + +## Part 1: Architecture Overview + +### 1.1 Before and After Comparison + +**Before Refactor (main branch)**: + +``` +pkg/channels/ +├── telegram.go # Each channel directly in the channels package +├── discord.go +├── slack.go +├── manager.go # Manager directly references each channel type +├── ... 
+``` + +- All channel implementations lived at the top level of `pkg/channels/` +- Manager constructed each channel via `switch` or `if-else` chains +- Routing info like Peer and MessageID was buried in `Metadata map[string]string` +- No rate limiting or retry on message sending +- No unified media file lifecycle management +- Each channel ran its own HTTP server +- Group chat trigger filtering logic was scattered across channels + +**After Refactor (refactor/channel-system branch)**: + +``` +pkg/channels/ +├── base.go # BaseChannel shared abstraction layer +├── interfaces.go # Optional capability interfaces (TypingCapable, MessageEditor, ReactionCapable, PlaceholderCapable, PlaceholderRecorder) +├── media.go # MediaSender optional interface +├── webhook.go # WebhookHandler, HealthChecker optional interfaces +├── errors.go # Sentinel errors (ErrNotRunning, ErrRateLimit, ErrTemporary, ErrSendFailed) +├── errutil.go # Error classification helpers +├── registry.go # Factory registry (RegisterFactory / getFactory) +├── manager.go # Unified orchestration: Worker queues, rate limiting, retries, Typing/Placeholder, shared HTTP +├── split.go # Smart long-message splitting (preserves code block integrity) +├── telegram/ # Each channel in its own sub-package +│ ├── init.go # Factory registration +│ ├── telegram.go # Implementation +│ └── telegram_commands.go +├── discord/ +│ ├── init.go +│ └── discord.go +├── slack/ line/ onebot/ dingtalk/ feishu/ wecom/ qq/ whatsapp/ maixcam/ pico/ +│ └── ... 
+ +pkg/bus/ +├── bus.go # MessageBus (buffer 64, safe close + drain) +├── types.go # Structured message types (Peer, SenderInfo, MediaPart, InboundMessage, OutboundMessage, OutboundMediaMessage) + +pkg/media/ +├── store.go # MediaStore interface + FileMediaStore implementation (two-phase release, TTL cleanup) + +pkg/identity/ +├── identity.go # Unified user identity: canonical "platform:id" format + backward-compatible matching +``` + +### 1.2 Message Flow Overview + +``` +┌────────────┐ InboundMessage ┌───────────┐ LLM + Tools ┌────────────┐ +│ Telegram │──┐ │ │ │ │ +│ Discord │──┤ PublishInbound() │ │ PublishOutbound() │ │ +│ Slack │──┼──────────────────────▶ │ MessageBus │ ◀─────────────────── │ AgentLoop │ +│ LINE │──┤ (buffered chan, 64) │ │ (buffered chan, 64) │ │ +│ ... │──┘ │ │ │ │ +└────────────┘ └─────┬─────┘ └────────────┘ + │ + SubscribeOutbound() │ SubscribeOutboundMedia() + ▼ + ┌───────────────────┐ + │ Manager │ + │ ├── dispatchOutbound() Route to Worker queues + │ ├── dispatchOutboundMedia() + │ ├── runWorker() Message split + sendWithRetry() + │ ├── runMediaWorker() sendMediaWithRetry() + │ ├── preSend() Stop Typing + Undo Reaction + Edit Placeholder + │ └── runTTLJanitor() Clean up expired Typing/Placeholder + └────────┬──────────┘ + │ + channel.Send() / SendMedia() + │ + ▼ + ┌────────────────┐ + │ Platform APIs │ + └────────────────┘ +``` + +### 1.3 Key Design Principles + +| Principle | Description | +|-----------|-------------| +| **Sub-package Isolation** | Each channel is a standalone Go sub-package, depending on `BaseChannel` and interfaces from the `channels` parent package | +| **Factory Registration** | Sub-packages self-register via `init()`, Manager looks up factories by name, eliminating import coupling | +| **Capability Discovery** | Optional capabilities are declared via interfaces (`MediaSender`, `TypingCapable`, `ReactionCapable`, `PlaceholderCapable`, `MessageEditor`, `WebhookHandler`), discovered by Manager via runtime type 
assertions | +| **Structured Messages** | Peer, MessageID, and SenderInfo promoted from Metadata to first-class fields on InboundMessage | +| **Error Classification** | Channels return sentinel errors (`ErrRateLimit`, `ErrTemporary`, etc.), Manager uses these to determine retry strategy | +| **Centralized Orchestration** | Rate limiting, message splitting, retries, and Typing/Reaction/Placeholder management are all handled by Manager and BaseChannel; channels only need to implement Send | + +--- + +## Part 2: Migration Guide — From main Branch to Refactored Branch + +### 2.1 If You Have Unmerged Channel Changes + +#### Step 1: Identify which files you modified + +On the main branch, channel files were directly in `pkg/channels/` top level, e.g.: +- `pkg/channels/telegram.go` +- `pkg/channels/discord.go` + +After refactoring, these files have been removed and code moved to corresponding sub-packages: +- `pkg/channels/telegram/telegram.go` +- `pkg/channels/discord/discord.go` + +#### Step 2: Understand the structural change mapping + +| main branch file | Refactored branch location | Changes | +|---|---|---| +| `pkg/channels/telegram.go` | `pkg/channels/telegram/telegram.go` + `init.go` | Package name changed from `channels` to `telegram` | +| `pkg/channels/discord.go` | `pkg/channels/discord/discord.go` + `init.go` | Same as above | +| `pkg/channels/manager.go` | `pkg/channels/manager.go` | Extensively rewritten | +| _(did not exist)_ | `pkg/channels/base.go` | New shared abstraction layer | +| _(did not exist)_ | `pkg/channels/registry.go` | New factory registry | +| _(did not exist)_ | `pkg/channels/errors.go` + `errutil.go` | New error classification system | +| _(did not exist)_ | `pkg/channels/interfaces.go` | New optional capability interfaces | +| _(did not exist)_ | `pkg/channels/media.go` | New MediaSender interface | +| _(did not exist)_ | `pkg/channels/webhook.go` | New WebhookHandler/HealthChecker | +| _(did not exist)_ | `pkg/channels/split.go` | New 
message splitting (migrated from utils) | +| _(did not exist)_ | `pkg/bus/types.go` | New structured message types | +| _(did not exist)_ | `pkg/media/store.go` | New media file lifecycle management | +| _(did not exist)_ | `pkg/identity/identity.go` | New unified user identity | + +#### Step 3: Migrate your channel code + +Using Telegram as an example, the main changes are: + +**3a. Package declaration and imports** + +```go +// Old code (main branch) +package channels + +import ( + "github.com/sipeed/picoclaw/pkg/bus" + "github.com/sipeed/picoclaw/pkg/config" +) + +// New code (refactored branch) +package telegram + +import ( + "github.com/sipeed/picoclaw/pkg/bus" + "github.com/sipeed/picoclaw/pkg/channels" // Reference parent package + "github.com/sipeed/picoclaw/pkg/config" + "github.com/sipeed/picoclaw/pkg/identity" // New + "github.com/sipeed/picoclaw/pkg/media" // New (if media support needed) +) +``` + +**3b. Struct embeds BaseChannel** + +```go +// Old code: directly held bus, config, etc. fields +type TelegramChannel struct { + bus *bus.MessageBus + config *config.Config + running bool + allowList []string + // ... +} + +// New code: embed BaseChannel, which provides bus, running, allowList, etc. +type TelegramChannel struct { + *channels.BaseChannel // Embed shared abstraction + bot *telego.Bot + config *config.Config + // ... only channel-specific fields +} +``` + +**3c. Constructor** + +```go +// Old code: direct assignment +func NewTelegramChannel(cfg *config.Config, bus *bus.MessageBus) (*TelegramChannel, error) { + return &TelegramChannel{ + bus: bus, + config: cfg, + allowList: cfg.Channels.Telegram.AllowFrom, + // ... 
+ }, nil +} + +// New code: use NewBaseChannel + functional options +func NewTelegramChannel(cfg *config.Config, bus *bus.MessageBus) (*TelegramChannel, error) { + base := channels.NewBaseChannel( + "telegram", // Name + cfg.Channels.Telegram, // Raw config (any type) + bus, // Message bus + cfg.Channels.Telegram.AllowFrom, // Allow list + channels.WithMaxMessageLength(4096), // Platform message length limit + channels.WithGroupTrigger(cfg.Channels.Telegram.GroupTrigger), // Group trigger config + ) + return &TelegramChannel{ + BaseChannel: base, + bot: bot, + config: cfg, + }, nil +} +``` + +**3d. Start/Stop lifecycle** + +```go +// New code: use SetRunning atomic operation +func (c *TelegramChannel) Start(ctx context.Context) error { + // ... initialize bot, webhook, etc. + c.SetRunning(true) // Must be called after ready + go bh.Start() + return nil +} + +func (c *TelegramChannel) Stop(ctx context.Context) error { + c.SetRunning(false) // Must be called before cleanup + // ... stop bot handler, cancel context + return nil +} +``` + +**3e. Send method error returns** + +```go +// Old code: returns plain error +func (c *TelegramChannel) Send(ctx context.Context, msg bus.OutboundMessage) error { + if !c.running { return fmt.Errorf("not running") } + // ... + if err != nil { return err } +} + +// New code: must return sentinel errors for Manager to determine retry strategy +func (c *TelegramChannel) Send(ctx context.Context, msg bus.OutboundMessage) error { + if !c.IsRunning() { + return channels.ErrNotRunning // ← Manager will not retry + } + // ... + if err != nil { + // Use ClassifySendError to wrap error based on HTTP status code + return channels.ClassifySendError(statusCode, err) + // Or manually wrap: + // return fmt.Errorf("%w: %v", channels.ErrTemporary, err) + // return fmt.Errorf("%w: %v", channels.ErrRateLimit, err) + // return fmt.Errorf("%w: %v", channels.ErrSendFailed, err) + } + return nil +} +``` + +**3f. 
Message reception (Inbound)** + +```go +// Old code: directly construct InboundMessage and publish +msg := bus.InboundMessage{ + Channel: "telegram", + SenderID: senderID, + ChatID: chatID, + Content: content, + Metadata: map[string]string{ + "peer_kind": "group", // Routing info buried in metadata + "peer_id": chatID, + "message_id": msgID, + }, +} +c.bus.PublishInbound(ctx, msg) + +// New code: use BaseChannel.HandleMessage with structured fields +sender := bus.SenderInfo{ + Platform: "telegram", + PlatformID: strconv.FormatInt(from.ID, 10), + CanonicalID: identity.BuildCanonicalID("telegram", strconv.FormatInt(from.ID, 10)), + Username: from.Username, + DisplayName: from.FirstName, +} + +peer := bus.Peer{ + Kind: "group", // or "direct" + ID: chatID, +} + +// HandleMessage internally calls IsAllowedSender for permission checks, builds MediaScope, and publishes to bus +c.HandleMessage(ctx, peer, messageID, senderID, chatID, content, mediaRefs, metadata, sender) +``` + +**3g. Add factory registration (required)** + +Create `init.go` for your channel: + +```go +// pkg/channels/telegram/init.go +package telegram + +import ( + "github.com/sipeed/picoclaw/pkg/bus" + "github.com/sipeed/picoclaw/pkg/channels" + "github.com/sipeed/picoclaw/pkg/config" +) + +func init() { + channels.RegisterFactory("telegram", func(cfg *config.Config, b *bus.MessageBus) (channels.Channel, error) { + return NewTelegramChannel(cfg, b) + }) +} +``` + +**3h. 
Import sub-package in Gateway** + +```go +// cmd/picoclaw/internal/gateway/helpers.go +import ( + _ "github.com/sipeed/picoclaw/pkg/channels/telegram" // Triggers init() registration + _ "github.com/sipeed/picoclaw/pkg/channels/discord" + _ "github.com/sipeed/picoclaw/pkg/channels/your_new_channel" // New addition +) +``` + +#### Step 4: Migrate bus message usage + +If your code directly reads routing fields from `InboundMessage.Metadata`: + +```go +// Old code +peerKind := msg.Metadata["peer_kind"] +peerID := msg.Metadata["peer_id"] +msgID := msg.Metadata["message_id"] + +// New code +peerKind := msg.Peer.Kind // First-class field +peerID := msg.Peer.ID // First-class field +msgID := msg.MessageID // First-class field +sender := msg.Sender // bus.SenderInfo struct +scope := msg.MediaScope // Media lifecycle scope +``` + +#### Step 5: Migrate allow-list checks + +```go +// Old code +if !c.isAllowed(senderID) { return } + +// New code: prefer structured check +if !c.IsAllowedSender(sender) { return } +// Or fall back to string check: +if !c.IsAllowed(senderID) { return } +``` + +`BaseChannel.HandleMessage` already handles this logic internally — no need to duplicate the check in your channel. + +### 2.2 If You Have Manager Modifications + +The Manager has been completely rewritten. 
Your modifications will need to account for the new architecture: + +| Old Manager Responsibility | New Manager Responsibility | +|---|---| +| Directly construct channels (switch/if-else) | Look up and construct via factory registry | +| Directly call channel.Send | Per-channel Worker queues + rate limiting + retries | +| No message splitting | Automatic splitting based on MaxMessageLength | +| Each channel runs its own HTTP server | Unified shared HTTP server | +| No Typing/Placeholder management | Unified preSend handles Typing stop + Reaction undo + Placeholder edit; inbound-side BaseChannel.HandleMessage auto-orchestrates Typing/Reaction/Placeholder | +| No TTL cleanup | runTTLJanitor periodically cleans up expired Typing/Reaction/Placeholder entries | + +### 2.3 If You Have Agent Loop Modifications + +Main changes to the Agent Loop: + +1. **MediaStore injection**: `agentLoop.SetMediaStore(mediaStore)` — Agent resolves media references produced by tools via MediaStore +2. **ChannelManager injection**: `agentLoop.SetChannelManager(channelManager)` — Agent can query channel state +3. **OutboundMediaMessage**: Agent now sends media messages via `bus.PublishOutboundMedia()` instead of embedding them in text replies +4. **extractPeer**: Routing uses `msg.Peer` structured fields instead of Metadata lookups + +--- + +## Part 3: New Channel Development Guide — Implementing a Channel from Scratch + +### 3.1 Minimum Implementation Checklist + +To add a new chat platform (e.g., `matrix`), you need to: + +1. ✅ Create sub-package directory `pkg/channels/matrix/` +2. ✅ Create `init.go` — factory registration +3. ✅ Create `matrix.go` — channel implementation +4. ✅ Add blank import in Gateway helpers +5. ✅ Add config check in Manager.initChannels() +6. 
✅ Add config struct in `pkg/config/` + +### 3.2 Complete Template + +#### `pkg/channels/matrix/init.go` + +```go +package matrix + +import ( + "github.com/sipeed/picoclaw/pkg/bus" + "github.com/sipeed/picoclaw/pkg/channels" + "github.com/sipeed/picoclaw/pkg/config" +) + +func init() { + channels.RegisterFactory("matrix", func(cfg *config.Config, b *bus.MessageBus) (channels.Channel, error) { + return NewMatrixChannel(cfg, b) + }) +} +``` + +#### `pkg/channels/matrix/matrix.go` + +```go +package matrix + +import ( + "context" + "fmt" + + "github.com/sipeed/picoclaw/pkg/bus" + "github.com/sipeed/picoclaw/pkg/channels" + "github.com/sipeed/picoclaw/pkg/config" + "github.com/sipeed/picoclaw/pkg/identity" + "github.com/sipeed/picoclaw/pkg/logger" +) + +// MatrixChannel implements channels.Channel for the Matrix protocol. +type MatrixChannel struct { + *channels.BaseChannel // Must embed + config *config.Config + ctx context.Context + cancel context.CancelFunc + // ... Matrix SDK client, etc. +} + +func NewMatrixChannel(cfg *config.Config, msgBus *bus.MessageBus) (*MatrixChannel, error) { + matrixCfg := cfg.Channels.Matrix // Assumes this field exists in config + + base := channels.NewBaseChannel( + "matrix", // Channel name (globally unique) + matrixCfg, // Raw config + msgBus, // Message bus + matrixCfg.AllowFrom, // Allow list + channels.WithMaxMessageLength(65536), // Matrix message length limit + channels.WithGroupTrigger(matrixCfg.GroupTrigger), + ) + + return &MatrixChannel{ + BaseChannel: base, + config: cfg, + }, nil +} + +// ========== Required Channel Interface Methods ========== + +func (c *MatrixChannel) Start(ctx context.Context) error { + c.ctx, c.cancel = context.WithCancel(ctx) + + // 1. Initialize Matrix client + // 2. Start listening for messages + // 3. 
Mark as running + c.SetRunning(true) + + logger.InfoC("matrix", "Matrix channel started") + return nil +} + +func (c *MatrixChannel) Stop(ctx context.Context) error { + c.SetRunning(false) + + if c.cancel != nil { + c.cancel() + } + + logger.InfoC("matrix", "Matrix channel stopped") + return nil +} + +func (c *MatrixChannel) Send(ctx context.Context, msg bus.OutboundMessage) error { + // 1. Check running state + if !c.IsRunning() { + return channels.ErrNotRunning + } + + // 2. Send message to Matrix + err := c.sendToMatrix(ctx, msg.ChatID, msg.Content) + if err != nil { + // 3. Must use error classification wrapping + // If you have an HTTP status code: + // return channels.ClassifySendError(statusCode, err) + // If it's a network error: + // return channels.ClassifyNetError(err) + // If manual classification is needed: + return fmt.Errorf("%w: %v", channels.ErrTemporary, err) + } + + return nil +} + +// ========== Incoming Message Handling ========== + +func (c *MatrixChannel) handleIncoming(roomID, senderID, displayName, content string, msgID string) { + // 1. Construct structured sender identity + sender := bus.SenderInfo{ + Platform: "matrix", + PlatformID: senderID, + CanonicalID: identity.BuildCanonicalID("matrix", senderID), + Username: senderID, + DisplayName: displayName, + } + + // 2. Determine Peer type (direct vs group) + peer := bus.Peer{ + Kind: "group", // or "direct" + ID: roomID, + } + + // 3. Group chat filtering (if applicable) + isGroup := peer.Kind == "group" + if isGroup { + isMentioned := false // Detect @mentions based on platform specifics + shouldRespond, cleanContent := c.ShouldRespondInGroup(isMentioned, content) + if !shouldRespond { + return + } + content = cleanContent + } + + // 4. Handle media attachments (if any) + var mediaRefs []string + store := c.GetMediaStore() + if store != nil { + // Download attachment locally → store.Store() → get ref + // mediaRefs = append(mediaRefs, ref) + } + + // 5. 
Call HandleMessage to publish to bus + // HandleMessage internally will: + // - Check IsAllowedSender/IsAllowed + // - Build MediaScope + // - Publish InboundMessage + c.HandleMessage( + c.ctx, + peer, + msgID, // Platform message ID + senderID, // Raw sender ID + roomID, // Chat/room ID + content, // Message content + mediaRefs, // Media reference list + nil, // Extra metadata (usually nil) + sender, // SenderInfo (variadic parameter) + ) +} + +// ========== Internal Methods ========== + +func (c *MatrixChannel) sendToMatrix(ctx context.Context, roomID, content string) error { + // Actual Matrix SDK call + return nil +} +``` + +### 3.3 Optional Capability Interfaces + +Depending on platform capabilities, your channel can optionally implement the following interfaces: + +#### MediaSender — Send Media Attachments + +```go +// If the platform supports sending images/files/audio/video +func (c *MatrixChannel) SendMedia(ctx context.Context, msg bus.OutboundMediaMessage) error { + if !c.IsRunning() { + return channels.ErrNotRunning + } + + store := c.GetMediaStore() + if store == nil { + return fmt.Errorf("no media store: %w", channels.ErrSendFailed) + } + + for _, part := range msg.Parts { + localPath, err := store.Resolve(part.Ref) + if err != nil { + logger.ErrorCF("matrix", "Failed to resolve media", map[string]any{ + "ref": part.Ref, "error": err.Error(), + }) + continue + } + + // Call the appropriate API based on part.Type ("image"|"audio"|"video"|"file") + switch part.Type { + case "image": + // Upload image to Matrix + default: + // Upload file to Matrix + } + } + return nil +} +``` + +#### TypingCapable — Typing Indicator + +```go +// If the platform supports "typing..." 
indicators +func (c *MatrixChannel) StartTyping(ctx context.Context, chatID string) (stop func(), err error) { + // Call Matrix API to send typing indicator + // The returned stop function must be idempotent + stopped := false + return func() { + if !stopped { + stopped = true + // Call Matrix API to stop typing + } + }, nil +} +``` + +#### ReactionCapable — Message Reaction Indicator + +```go +// If the platform supports adding emoji reactions to inbound messages (e.g., Slack's 👀, OneBot's emoji 289) +func (c *MatrixChannel) ReactToMessage(ctx context.Context, chatID, messageID string) (undo func(), err error) { + // Call Matrix API to add reaction to message + // The returned undo function removes the reaction, must be idempotent + err = c.addReaction(chatID, messageID, "eyes") + if err != nil { + return func() {}, err + } + return func() { + c.removeReaction(chatID, messageID, "eyes") + }, nil +} +``` + +#### MessageEditor — Message Editing + +```go +// If the platform supports editing sent messages (used for Placeholder replacement) +func (c *MatrixChannel) EditMessage(ctx context.Context, chatID, messageID, content string) error { + // Call Matrix API to edit message + return nil +} +``` + +#### WebhookHandler — HTTP Webhook Reception + +```go +// If the channel receives messages via webhook (rather than long-polling/WebSocket) +func (c *MatrixChannel) WebhookPath() string { + return "/webhook/matrix" // Path will be registered on the shared HTTP server +} + +func (c *MatrixChannel) ServeHTTP(w http.ResponseWriter, r *http.Request) { + // Handle webhook request +} +``` + +#### HealthChecker — Health Check Endpoint + +```go +func (c *MatrixChannel) HealthPath() string { + return "/health/matrix" +} + +func (c *MatrixChannel) HealthHandler(w http.ResponseWriter, r *http.Request) { + if c.IsRunning() { + w.WriteHeader(http.StatusOK) + w.Write([]byte("OK")) + } else { + w.WriteHeader(http.StatusServiceUnavailable) + } +} +``` + +### 3.4 Inbound-side 
Typing/Reaction/Placeholder Auto-orchestration + +`BaseChannel.HandleMessage` automatically detects whether the channel implements `TypingCapable`, `ReactionCapable`, and/or `PlaceholderCapable` **before** publishing the inbound message, and triggers the corresponding indicators. The three pipelines are completely independent and do not interfere with each other: + +```go +// Automatically executed inside BaseChannel.HandleMessage (no manual calls needed): +if c.owner != nil && c.placeholderRecorder != nil { + // Typing — independent pipeline + if tc, ok := c.owner.(TypingCapable); ok { + if stop, err := tc.StartTyping(ctx, chatID); err == nil { + c.placeholderRecorder.RecordTypingStop(c.name, chatID, stop) + } + } + // Reaction — independent pipeline + if rc, ok := c.owner.(ReactionCapable); ok && messageID != "" { + if undo, err := rc.ReactToMessage(ctx, chatID, messageID); err == nil { + c.placeholderRecorder.RecordReactionUndo(c.name, chatID, undo) + } + } + // Placeholder — independent pipeline + if pc, ok := c.owner.(PlaceholderCapable); ok { + if phID, err := pc.SendPlaceholder(ctx, chatID); err == nil && phID != "" { + c.placeholderRecorder.RecordPlaceholder(c.name, chatID, phID) + } + } +} +``` + +**This means**: +- Channels implementing `TypingCapable` (Telegram, Discord, LINE, Pico) do not need to manually call `StartTyping` + `RecordTypingStop` in `handleMessage` +- Channels implementing `ReactionCapable` (Slack, OneBot) do not need to manually call `ReactToMessage` + `RecordReactionUndo` in `handleMessage` +- Channels implementing `PlaceholderCapable` (Telegram, Discord, Pico) do not need to manually send placeholder messages and call `RecordPlaceholder` in `handleMessage` +- Channels only need to implement the corresponding interface; `HandleMessage` handles orchestration automatically +- Channels that don't implement these interfaces are unaffected (type assertions will fail and be skipped) +- `PlaceholderCapable`'s `SendPlaceholder` method internally
decides whether to send based on the configured `PlaceholderConfig.Enabled`; returning `("", nil)` skips registration + +**Owner Injection**: Manager automatically calls `SetOwner(ch)` in `initChannel` to inject the concrete channel into BaseChannel — no manual setup required from developers. + +When the Agent finishes processing a message, Manager's `preSend` automatically: +1. Calls the recorded `stop()` to stop Typing +2. Calls the recorded `undo()` to undo Reaction +3. If there is a Placeholder and the channel implements `MessageEditor`, attempts to edit the Placeholder with the final reply (skipping Send) + +### 3.5 Register Configuration and Gateway Integration + +#### Add configuration in `pkg/config/config.go` + +```go +type ChannelsConfig struct { + // ... existing channels + Matrix MatrixChannelConfig `yaml:"matrix" json:"matrix"` +} + +type MatrixChannelConfig struct { + Enabled bool `yaml:"enabled" json:"enabled"` + HomeServer string `yaml:"home_server" json:"home_server"` + Token string `yaml:"token" json:"token"` + AllowFrom []string `yaml:"allow_from" json:"allow_from"` + GroupTrigger GroupTriggerConfig `yaml:"group_trigger" json:"group_trigger"` +} +``` + +#### Add entry in Manager.initChannels() + +```go +// In the initChannels() method of pkg/channels/manager.go +if m.config.Channels.Matrix.Enabled && m.config.Channels.Matrix.Token != "" { + m.initChannel("matrix", "Matrix") +} +``` + +#### Add blank import in Gateway + +```go +// cmd/picoclaw/internal/gateway/helpers.go +import ( + _ "github.com/sipeed/picoclaw/pkg/channels/matrix" +) +``` + +--- + +## Part 4: Core Subsystem Details + +### 4.1 MessageBus + +**Files**: `pkg/bus/bus.go`, `pkg/bus/types.go` + +```go +type MessageBus struct { + inbound chan InboundMessage // buffer = 64 + outbound chan OutboundMessage // buffer = 64 + outboundMedia chan OutboundMediaMessage // buffer = 64 + done chan struct{} // Close signal + closed atomic.Bool // Prevents double-close +} +``` + +**Key Behaviors**: 
+ +| Method | Behavior | +|--------|----------| +| `PublishInbound(ctx, msg)` | Check closed → send to inbound channel → block/timeout/close | +| `ConsumeInbound(ctx)` | Read from inbound → block/close/cancel | +| `PublishOutbound(ctx, msg)` | Send to outbound channel | +| `SubscribeOutbound(ctx)` | Read from outbound (called by Manager dispatcher) | +| `PublishOutboundMedia(ctx, msg)` | Send to outboundMedia channel | +| `SubscribeOutboundMedia(ctx)` | Read from outboundMedia (called by Manager media dispatcher) | +| `Close()` | CAS close → close(done) → drain all channels (**does not close the channels themselves** to avoid concurrent send-on-closed panic) | + +**Design Notes**: +- Buffer size increased from 16 to 64 to reduce blocking under burst load +- `Close()` does not close the underlying channels (only closes the `done` signal channel), because there may be concurrent `Publish` goroutines +- Drain loop ensures buffered messages are not silently dropped + +### 4.2 Structured Message Types + +**File**: `pkg/bus/types.go` + +```go +// Routing peer +type Peer struct { + Kind string `json:"kind"` // "direct" | "group" | "channel" | "" + ID string `json:"id"` +} + +// Sender identity information +type SenderInfo struct { + Platform string `json:"platform,omitempty"` // "telegram", "discord", ... + PlatformID string `json:"platform_id,omitempty"` // Platform-native ID + CanonicalID string `json:"canonical_id,omitempty"` // "platform:id" canonical format + Username string `json:"username,omitempty"` + DisplayName string `json:"display_name,omitempty"` +} + +// Inbound message +type InboundMessage struct { + Channel string // Source channel name + SenderID string // Sender ID (prefer CanonicalID) + Sender SenderInfo // Structured sender info + ChatID string // Chat/room ID + Content string // Message text + Media []string // Media reference list (media://...) 
+ Peer Peer // Routing peer (first-class field) + MessageID string // Platform message ID (first-class field) + MediaScope string // Media lifecycle scope + SessionKey string // Session key + Metadata map[string]string // Only for channel-specific extensions +} + +// Outbound text message +type OutboundMessage struct { + Channel string + ChatID string + Content string +} + +// Outbound media message +type OutboundMediaMessage struct { + Channel string + ChatID string + Parts []MediaPart +} + +// Media part +type MediaPart struct { + Type string // "image" | "audio" | "video" | "file" + Ref string // "media://uuid" + Caption string + Filename string + ContentType string +} +``` + +### 4.3 BaseChannel + +**File**: `pkg/channels/base.go` + +BaseChannel is the shared abstraction layer for all channels, providing the following capabilities: + +| Method/Feature | Description | +|---|---| +| `Name() string` | Channel name | +| `IsRunning() bool` | Atomically read running state | +| `SetRunning(bool)` | Atomically set running state | +| `MaxMessageLength() int` | Message length limit (rune count), 0 = unlimited | +| `IsAllowed(senderID string) bool` | Legacy allow-list check (supports `"id\|username"` and `"@username"` formats) | +| `IsAllowedSender(sender SenderInfo) bool` | New allow-list check (delegates to `identity.MatchAllowed`) | +| `ShouldRespondInGroup(isMentioned, content) (bool, string)` | Unified group chat trigger filtering logic | +| `HandleMessage(...)` | Unified inbound message handling: permission check → build MediaScope → auto-trigger Typing/Reaction → publish to Bus | +| `SetMediaStore(s) / GetMediaStore()` | MediaStore injected by Manager | +| `SetPlaceholderRecorder(r) / GetPlaceholderRecorder()` | PlaceholderRecorder injected by Manager | +| `SetOwner(ch)` | Concrete channel reference injected by Manager (used for Typing/Reaction type assertions in HandleMessage) | + +**Functional Options**: + +```go +channels.WithMaxMessageLength(4096) // Set 
platform message length limit +channels.WithGroupTrigger(groupTriggerCfg) // Set group trigger configuration +``` + +### 4.4 Factory Registry + +**File**: `pkg/channels/registry.go` + +```go +type ChannelFactory func(cfg *config.Config, bus *bus.MessageBus) (Channel, error) + +func RegisterFactory(name string, f ChannelFactory) // Called in sub-package init() +func getFactory(name string) (ChannelFactory, bool) // Called internally by Manager +``` + +The factory registry is protected by `sync.RWMutex` and registrations occur during `init()` phase (completed at process startup). Manager looks up factories by name in `initChannel()` and calls them. + +### 4.5 Error Classification and Retries + +**Files**: `pkg/channels/errors.go`, `pkg/channels/errutil.go` + +#### Sentinel Errors + +```go +var ( + ErrNotRunning = errors.New("channel not running") // Permanent: do not retry + ErrRateLimit = errors.New("rate limited") // Fixed delay: retry after 1s + ErrTemporary = errors.New("temporary failure") // Exponential backoff: 500ms * 2^attempt, max 8s + ErrSendFailed = errors.New("send failed") // Permanent: do not retry +) +``` + +#### Error Classification Helpers + +```go +// Automatically classify based on HTTP status code +func ClassifySendError(statusCode int, rawErr error) error { + // 429 → ErrRateLimit + // 5xx → ErrTemporary + // 4xx → ErrSendFailed +} + +// Wrap network errors as temporary +func ClassifyNetError(err error) error { + // → ErrTemporary +} +``` + +#### Manager Retry Strategy (`sendWithRetry`) + +``` +Max retries: 3 +Rate limit delay: 1 second +Base backoff: 500 milliseconds +Max backoff: 8 seconds + +Retry logic: + ErrNotRunning → Fail immediately, no retry + ErrSendFailed → Fail immediately, no retry + ErrRateLimit → Wait 1s → retry + ErrTemporary → Wait 500ms * 2^attempt (max 8s) → retry + Other unknown → Wait 500ms * 2^attempt (max 8s) → retry +``` + +### 4.6 Manager Orchestration + +**File**: `pkg/channels/manager.go` + +#### Per-channel Worker 
Architecture + +```go +type channelWorker struct { + ch Channel // Channel instance + queue chan bus.OutboundMessage // Outbound text queue (buffered 16) + mediaQueue chan bus.OutboundMediaMessage // Outbound media queue (buffered 16) + done chan struct{} // Text worker completion signal + mediaDone chan struct{} // Media worker completion signal + limiter *rate.Limiter // Per-channel rate limiter +} +``` + +#### Per-channel Rate Limit Configuration + +```go +var channelRateConfig = map[string]float64{ + "telegram": 20, // 20 msg/s + "discord": 1, // 1 msg/s + "slack": 1, // 1 msg/s + "line": 10, // 10 msg/s +} +// Default: 10 msg/s +// burst = max(1, ceil(rate/2)) +``` + +#### Lifecycle Management + +``` +StartAll: + 1. Iterate registered channels → channel.Start(ctx) + 2. Create channelWorker for each successfully started channel + 3. Start goroutines: + - runWorker (per-channel outbound text) + - runMediaWorker (per-channel outbound media) + - dispatchOutbound (route from bus to worker queues) + - dispatchOutboundMedia (route from bus to media worker queues) + - runTTLJanitor (every 10s clean up expired typing/placeholder) + 4. Start shared HTTP server (if configured) + +StopAll: + 1. Shut down shared HTTP server (5s timeout) + 2. Cancel dispatcher context + 3. Close text worker queues → wait for drain to complete + 4. Close media worker queues → wait for drain to complete + 5. 
Stop each channel (channel.Stop) +``` + +#### Typing/Reaction/Placeholder Management + +```go +// Manager implements PlaceholderRecorder interface +func (m *Manager) RecordPlaceholder(channel, chatID, placeholderID string) +func (m *Manager) RecordTypingStop(channel, chatID string, stop func()) +func (m *Manager) RecordReactionUndo(channel, chatID string, undo func()) + +// Inbound side: BaseChannel.HandleMessage auto-orchestrates +// BaseChannel.HandleMessage, before PublishInbound, auto-triggers via owner type assertions: +// - TypingCapable.StartTyping → RecordTypingStop +// - ReactionCapable.ReactToMessage → RecordReactionUndo +// - PlaceholderCapable.SendPlaceholder → RecordPlaceholder +// All three are independent and do not interfere with each other. Channels don't need to call these manually. + +// Outbound side: pre-send processing +func (m *Manager) preSend(ctx, name, msg, ch) bool { + key := name + ":" + msg.ChatID + // 1. Stop Typing (call stored stop function) + // 2. Undo Reaction (call stored undo function) + // 3. Attempt to edit Placeholder (if channel implements MessageEditor) + // Success → return true (skip Send) + // Failure → return false (proceed with Send) +} +``` + +Manager storage is fully separated; three pipelines do not interfere: + +```go +Manager { + typingStops sync.Map // "channel:chatID" → typingEntry ← manages TypingCapable + reactionUndos sync.Map // "channel:chatID" → reactionEntry ← manages ReactionCapable + placeholders sync.Map // "channel:chatID" → placeholderEntry +} +``` + +TTL Cleanup: +- Typing stop functions: 5-minute TTL (auto-calls stop and deletes on expiry) +- Reaction undo functions: 5-minute TTL (auto-calls undo and deletes on expiry) +- Placeholder IDs: 10-minute TTL (deletes on expiry) +- Cleanup interval: 10 seconds + +### 4.7 Message Splitting + +**File**: `pkg/channels/split.go` + +`SplitMessage(content string, maxLen int) []string` + +Smart splitting strategy: +1. 
Calculate effective split point = maxLen - 10% buffer (to reserve space for code block closure) +2. Prefer splitting at newlines +3. Otherwise split at spaces/tabs +4. Detect unclosed code blocks (` ``` `) +5. If a code block is unclosed: + - Attempt to extend to maxLen to include the closing fence + - If the code block is too long, inject close/reopen fences (`\n```\n` + header) + - Last resort: split before the code block starts + +### 4.8 MediaStore + +**File**: `pkg/media/store.go` + +```go +type MediaStore interface { + Store(localPath string, meta MediaMeta, scope string) (ref string, err error) + Resolve(ref string) (localPath string, err error) + ResolveWithMeta(ref string) (localPath string, meta MediaMeta, err error) + ReleaseAll(scope string) error +} +``` + +**FileMediaStore Implementation**: +- Pure in-memory mapping, no file copy/move +- Reference format: `media://<uuid>` +- Scope format: `channel:chatID:messageID` (generated by `BuildMediaScope`) +- **Two-phase operation**: + - Phase 1 (holding lock): collect and delete entries from map + - Phase 2 (no lock): delete files from disk + - Purpose: minimize lock contention +- **TTL Cleanup**: `NewFileMediaStoreWithCleanup` → `Start()` launches background cleanup goroutine +- Cleanup interval and max TTL are controlled by configuration + +### 4.9 Identity + +**File**: `pkg/identity/identity.go` + +```go +// Build canonical ID +func BuildCanonicalID(platform, platformID string) string +// → "telegram:123456" + +// Parse canonical ID +func ParseCanonicalID(canonical string) (platform, id string, ok bool) + +// Match against allow list (backward-compatible) +func MatchAllowed(sender bus.SenderInfo, allowed string) bool +``` + +`MatchAllowed` supported allow-list formats: +| Format | Matching | +|--------|----------| +| `"123456"` | Matches `sender.PlatformID` | +| `"@alice"` | Matches `sender.Username` | +| `"123456\|alice"` | Matches PlatformID or Username (legacy format compatibility) | +| `"telegram:123456"` |
Exact match on `sender.CanonicalID` (new format) | + +### 4.10 Shared HTTP Server + +**File**: `pkg/channels/manager.go`'s `SetupHTTPServer` + +Manager creates a single `http.Server` and auto-discovers and registers: +- Channels implementing `WebhookHandler` → mounted at `wh.WebhookPath()` +- Channels implementing `HealthChecker` → mounted at `hc.HealthPath()` +- Global health endpoint registered by `health.Server.RegisterOnMux` + +Timeout configuration: ReadTimeout = 30s, WriteTimeout = 30s + +--- + +## Part 5: Key Design Decisions and Conventions + +### 5.1 Mandatory Conventions + +1. **Error classification is a contract**: A channel's `Send` method **must** return sentinel errors (or wrap them). Manager's retry strategy relies entirely on `errors.Is` checks. Returning unclassified errors will cause Manager to treat them as "unknown errors" (exponential backoff retry). + +2. **SetRunning is a lifecycle signal**: **Must** call `c.SetRunning(true)` after successful `Start`, and **must** call `c.SetRunning(false)` at the beginning of `Stop`. **Must** check `c.IsRunning()` in `Send` and return `ErrNotRunning`. + +3. **HandleMessage includes permission checks**: Do not perform your own permission checks before calling `HandleMessage` (unless you need platform-specific preprocessing before the check). `HandleMessage` already calls `IsAllowedSender`/`IsAllowed` internally. + +4. **Message splitting is handled by Manager**: A channel's `Send` method does not need to handle long message splitting. Manager automatically splits based on `MaxMessageLength()` before calling `Send`. Channels only need to declare the limit via `WithMaxMessageLength`. + +5. **Typing/Reaction/Placeholder is handled by BaseChannel + Manager automatically**: A channel's `Send` method does not need to manage Typing stop, Reaction undo, or Placeholder editing. 
`BaseChannel.HandleMessage` auto-triggers `TypingCapable`, `ReactionCapable`, and `PlaceholderCapable` on the inbound side (via `owner` type assertions); Manager's `preSend` auto-stops Typing, undoes Reaction, and edits Placeholder on the outbound side. Channels only need to implement the corresponding interfaces. + +6. **Factory registration belongs in init()**: Each sub-package must have an `init.go` file calling `channels.RegisterFactory`. Gateway must trigger registration via blank imports (`_ "pkg/channels/xxx"`). + +### 5.2 Metadata Field Usage Conventions + +**Do NOT put the following information in Metadata anymore**: +- `peer_kind` / `peer_id` → Use `InboundMessage.Peer` +- `message_id` → Use `InboundMessage.MessageID` +- `sender_platform` / `sender_username` → Use `InboundMessage.Sender` + +**Metadata should only be used for**: +- Channel-specific extension information (e.g., Telegram's `reply_to_message_id`) +- Temporary information that doesn't fit into structured fields + +### 5.3 Concurrency Safety Conventions + +- `BaseChannel.running`: Uses `atomic.Bool`, thread-safe +- `Manager.channels` / `Manager.workers`: Protected by `sync.RWMutex` +- `Manager.placeholders` / `Manager.typingStops` / `Manager.reactionUndos`: Uses `sync.Map` +- `MessageBus.closed`: Uses `atomic.Bool` +- `FileMediaStore`: Uses `sync.RWMutex`, two-phase operation to minimize lock-hold time +- Channel Worker queue: Go channel, inherently concurrent-safe + +### 5.4 Testing Conventions + +Existing test files: +- `pkg/channels/base_test.go` — BaseChannel unit tests +- `pkg/channels/manager_test.go` — Manager unit tests +- `pkg/channels/split_test.go` — Message splitting tests +- `pkg/channels/errors_test.go` — Error type tests +- `pkg/channels/errutil_test.go` — Error classification tests + +To add tests for a new channel: +```bash +go test ./pkg/channels/matrix/ -v # Sub-package tests +go test ./pkg/channels/ -run TestSpecific -v # Framework tests +make test # Full test suite +``` + 
+--- + +## Appendix: Complete File Listing and Interface Quick Reference + +### A.1 Framework Layer Files + +| File | Responsibility | +|------|---------------| +| `pkg/channels/base.go` | BaseChannel struct, Channel interface, MessageLengthProvider, BaseChannelOption, HandleMessage | +| `pkg/channels/interfaces.go` | TypingCapable, MessageEditor, ReactionCapable, PlaceholderCapable, PlaceholderRecorder interfaces | +| `pkg/channels/media.go` | MediaSender interface | +| `pkg/channels/webhook.go` | WebhookHandler, HealthChecker interfaces | +| `pkg/channels/errors.go` | ErrNotRunning, ErrRateLimit, ErrTemporary, ErrSendFailed sentinels | +| `pkg/channels/errutil.go` | ClassifySendError, ClassifyNetError helpers | +| `pkg/channels/registry.go` | RegisterFactory, getFactory factory registry | +| `pkg/channels/manager.go` | Manager: Worker queues, rate limiting, retries, preSend, shared HTTP, TTL janitor | +| `pkg/channels/split.go` | SplitMessage long-message splitting | +| `pkg/bus/bus.go` | MessageBus implementation | +| `pkg/bus/types.go` | Peer, SenderInfo, InboundMessage, OutboundMessage, OutboundMediaMessage, MediaPart | +| `pkg/media/store.go` | MediaStore interface, FileMediaStore implementation | +| `pkg/identity/identity.go` | BuildCanonicalID, ParseCanonicalID, MatchAllowed | + +### A.2 Channel Sub-packages + +| Sub-package | Registered Name | Optional Interfaces | +|-------------|----------------|-------------------| +| `pkg/channels/telegram/` | `"telegram"` | MessageEditor, MediaSender, TypingCapable, PlaceholderCapable | +| `pkg/channels/discord/` | `"discord"` | MessageEditor, TypingCapable, PlaceholderCapable | +| `pkg/channels/slack/` | `"slack"` | ReactionCapable | +| `pkg/channels/line/` | `"line"` | WebhookHandler, HealthChecker, TypingCapable | +| `pkg/channels/onebot/` | `"onebot"` | ReactionCapable | +| `pkg/channels/dingtalk/` | `"dingtalk"` | WebhookHandler | +| `pkg/channels/feishu/` | `"feishu"` | WebhookHandler (architecture-specific 
build tags) | +| `pkg/channels/wecom/` | `"wecom"` + `"wecom_app"` | WebhookHandler | +| `pkg/channels/qq/` | `"qq"` | — | +| `pkg/channels/whatsapp/` | `"whatsapp"` | — | +| `pkg/channels/maixcam/` | `"maixcam"` | — | +| `pkg/channels/pico/` | `"pico"` | WebhookHandler (Pico Protocol), TypingCapable, PlaceholderCapable | + +### A.3 Interface Quick Reference + +```go +// ===== Required ===== +type Channel interface { + Name() string + Start(ctx context.Context) error + Stop(ctx context.Context) error + Send(ctx context.Context, msg bus.OutboundMessage) error + IsRunning() bool + IsAllowed(senderID string) bool + IsAllowedSender(sender bus.SenderInfo) bool +} + +// ===== Optional ===== +type MediaSender interface { + SendMedia(ctx context.Context, msg bus.OutboundMediaMessage) error +} + +type TypingCapable interface { + StartTyping(ctx context.Context, chatID string) (stop func(), err error) +} + +type ReactionCapable interface { + ReactToMessage(ctx context.Context, chatID, messageID string) (undo func(), err error) +} + +type PlaceholderCapable interface { + SendPlaceholder(ctx context.Context, chatID string) (messageID string, err error) +} + +type MessageEditor interface { + EditMessage(ctx context.Context, chatID, messageID, content string) error +} + +type WebhookHandler interface { + WebhookPath() string + http.Handler +} + +type HealthChecker interface { + HealthPath() string + HealthHandler(w http.ResponseWriter, r *http.Request) +} + +type MessageLengthProvider interface { + MaxMessageLength() int +} + +// ===== Injected by Manager ===== +type PlaceholderRecorder interface { + RecordPlaceholder(channel, chatID, placeholderID string) + RecordTypingStop(channel, chatID string, stop func()) + RecordReactionUndo(channel, chatID string, undo func()) +} +``` + +### A.4 Gateway Startup Sequence (Complete Bootstrap Flow) + +```go +// 1. 
Create core components +msgBus := bus.NewMessageBus() +provider := providers.CreateProvider(cfg) +agentLoop := agent.NewAgentLoop(cfg, msgBus, provider) + +// 2. Create media store (with TTL cleanup) +mediaStore := media.NewFileMediaStoreWithCleanup(cleanerConfig) +mediaStore.Start() + +// 3. Create Channel Manager (triggers initChannels → factory lookup → construct → inject MediaStore/PlaceholderRecorder/Owner) +channelManager := channels.NewManager(cfg, msgBus, mediaStore) + +// 4. Inject references +agentLoop.SetChannelManager(channelManager) +agentLoop.SetMediaStore(mediaStore) + +// 5. Configure shared HTTP server +channelManager.SetupHTTPServer(addr, healthServer) + +// 6. Start +channelManager.StartAll(ctx) // Start channels + workers + dispatchers + HTTP server +go agentLoop.Run(ctx) // Start Agent message loop + +// 7. Shutdown (signal-triggered) +cancel() // Cancel context +msgBus.Close() // Signal close + drain +channelManager.StopAll(shutdownCtx) // Stop HTTP + workers + channels +mediaStore.Stop() // Stop TTL cleanup +agentLoop.Stop() // Stop Agent +``` + +### A.5 Per-channel Rate Limit Reference + +| Channel | Rate (msg/s) | Burst | +|---------|-------------|-------| +| telegram | 20 | 10 | +| discord | 1 | 1 | +| slack | 1 | 1 | +| line | 10 | 5 | +| _others_ | 10 (default) | 5 | + +### A.6 Known Limitations and Caveats + +1. **Media cleanup temporarily disabled**: The `ReleaseAll` call in the Agent loop is commented out (`refactor(loop): disable media cleanup to prevent premature file deletion`) because session boundaries are not yet clearly defined. TTL cleanup remains active. + +2. **Feishu architecture-specific compilation**: The Feishu channel uses build tags to distinguish 32-bit and 64-bit architectures (`feishu_32.go` / `feishu_64.go`). + +3. **WeCom has two factories**: `"wecom"` (Bot mode) and `"wecom_app"` (App mode) are registered separately. + +4. 
**Pico Protocol**: `pkg/channels/pico/` implements a custom PicoClaw native protocol channel that receives messages via webhook. \ No newline at end of file diff --git a/pkg/channels/README.zh.md b/pkg/channels/README.zh.md new file mode 100644 index 000000000..0a9487cd0 --- /dev/null +++ b/pkg/channels/README.zh.md @@ -0,0 +1,1331 @@ +# PicoClaw Channel System 重构:完整开发指南 + +> **分支**: `refactor/channel-system` +> **状态**: 活跃开发中(约 40 commits) +> **影响范围**: `pkg/channels/`, `pkg/bus/`, `pkg/media/`, `pkg/identity/`, `cmd/picoclaw/internal/gateway/` + +--- + +## 目录 + +- [第一部分:架构总览](#第一部分架构总览) +- [第二部分:迁移指南——从 main 分支迁移到重构分支](#第二部分迁移指南从-main-分支迁移到重构分支) +- [第三部分:新 Channel 开发指南——从零实现一个新 Channel](#第三部分新-channel-开发指南从零实现一个新-channel) +- [第四部分:核心子系统详解](#第四部分核心子系统详解) +- [第五部分:关键设计决策与约定](#第五部分关键设计决策与约定) +- [附录:完整文件清单与接口速查表](#附录完整文件清单与接口速查表) + +--- + +## 第一部分:架构总览 + +### 1.1 重构前后对比 + +**重构前(main 分支)**: + +``` +pkg/channels/ +├── telegram.go # 每个 channel 直接放在 channels 包内 +├── discord.go +├── slack.go +├── manager.go # Manager 直接引用各 channel 类型 +├── ... 
+``` + +- Channel 实现全部在 `pkg/channels/` 包的顶层 +- Manager 通过 `switch` 或 `if-else` 链条直接构造各 channel +- Peer、MessageID 等路由信息埋在 `Metadata map[string]string` 中 +- 消息发送没有速率限制和重试 +- 没有统一的媒体文件生命周期管理 +- 各 channel 各自启动 HTTP 服务器 +- 群聊触发过滤逻辑分散在各 channel 中 + +**重构后(refactor/channel-system 分支)**: + +``` +pkg/channels/ +├── base.go # BaseChannel 共享抽象层 +├── interfaces.go # 可选能力接口(TypingCapable, MessageEditor, ReactionCapable, PlaceholderCapable, PlaceholderRecorder) +├── media.go # MediaSender 可选接口 +├── webhook.go # WebhookHandler, HealthChecker 可选接口 +├── errors.go # 错误哨兵值(ErrNotRunning, ErrRateLimit, ErrTemporary, ErrSendFailed) +├── errutil.go # 错误分类帮助函数 +├── registry.go # 工厂注册表(RegisterFactory / getFactory) +├── manager.go # 统一编排:Worker 队列、速率限制、重试、Typing/Placeholder、共享 HTTP +├── split.go # 长消息智能分割(保留代码块完整性) +├── telegram/ # 每个 channel 独立子包 +│ ├── init.go # 工厂注册 +│ ├── telegram.go # 实现 +│ └── telegram_commands.go +├── discord/ +│ ├── init.go +│ └── discord.go +├── slack/ line/ onebot/ dingtalk/ feishu/ wecom/ qq/ whatsapp/ maixcam/ pico/ +│ └── ... + +pkg/bus/ +├── bus.go # MessageBus(缓冲区 64,安全关闭+排水) +├── types.go # 结构化消息类型(Peer, SenderInfo, MediaPart, InboundMessage, OutboundMessage, OutboundMediaMessage) + +pkg/media/ +├── store.go # MediaStore 接口 + FileMediaStore 实现(两阶段释放,TTL 清理) + +pkg/identity/ +├── identity.go # 统一用户身份:规范 "platform:id" 格式 + 向后兼容匹配 +``` + +### 1.2 消息流转全景图 + +``` +┌────────────┐ InboundMessage ┌───────────┐ LLM + Tools ┌────────────┐ +│ Telegram │──┐ │ │ │ │ +│ Discord │──┤ PublishInbound() │ │ PublishOutbound() │ │ +│ Slack │──┼──────────────────────▶ │ MessageBus │ ◀─────────────────── │ AgentLoop │ +│ LINE │──┤ (buffered chan, 64) │ │ (buffered chan, 64) │ │ +│ ... 
│──┘ │ │ │ │ +└────────────┘ └─────┬─────┘ └────────────┘ + │ + SubscribeOutbound() │ SubscribeOutboundMedia() + ▼ + ┌───────────────────┐ + │ Manager │ + │ ├── dispatchOutbound() 路由到 Worker 队列 + │ ├── dispatchOutboundMedia() + │ ├── runWorker() 消息分割 + sendWithRetry() + │ ├── runMediaWorker() sendMediaWithRetry() + │ ├── preSend() 停止 Typing + 撤销 Reaction + 编辑 Placeholder + │ └── runTTLJanitor() 清理过期 Typing/Placeholder + └────────┬──────────┘ + │ + channel.Send() / SendMedia() + │ + ▼ + ┌────────────────┐ + │ 各平台 API/SDK │ + └────────────────┘ +``` + +### 1.3 关键设计原则 + +| 原则 | 说明 | +|------|------| +| **子包隔离** | 每个 channel 一个独立 Go 子包,依赖 `channels` 父包提供的 `BaseChannel` 和接口 | +| **工厂注册** | 各子包通过 `init()` 自注册,Manager 通过名字查找工厂,消除 import 耦合 | +| **能力发现** | 可选能力通过接口(`MediaSender`, `TypingCapable`, `ReactionCapable`, `PlaceholderCapable`, `MessageEditor`, `WebhookHandler`)声明,Manager 运行时类型断言发现 | +| **结构化消息** | Peer、MessageID、SenderInfo 从 Metadata 提升为 InboundMessage 的一等字段 | +| **错误分类** | Channel 返回哨兵错误(`ErrRateLimit`, `ErrTemporary` 等),Manager 据此决定重试策略 | +| **集中编排** | 速率限制、消息分割、重试、Typing/Reaction/Placeholder 全部由 Manager 和 BaseChannel 统一处理,Channel 只负责 Send | + +--- + +## 第二部分:迁移指南——从 main 分支迁移到重构分支 + +### 2.1 如果你有未合并的 Channel 修改 + +#### 步骤 1:确认你修改了哪些文件 + +在 main 分支上,Channel 文件直接位于 `pkg/channels/` 顶层,例如: +- `pkg/channels/telegram.go` +- `pkg/channels/discord.go` + +重构后,这些文件已被删除,代码移动到了对应子包: +- `pkg/channels/telegram/telegram.go` +- `pkg/channels/discord/discord.go` + +#### 步骤 2:理解结构变化映射 + +| main 分支文件 | 重构分支位置 | 变化 | +|---|---|---| +| `pkg/channels/telegram.go` | `pkg/channels/telegram/telegram.go` + `init.go` | 包名从 `channels` 变为 `telegram` | +| `pkg/channels/discord.go` | `pkg/channels/discord/discord.go` + `init.go` | 同上 | +| `pkg/channels/manager.go` | `pkg/channels/manager.go` | 大幅重写 | +| _(不存在)_ | `pkg/channels/base.go` | 新增共享抽象层 | +| _(不存在)_ | `pkg/channels/registry.go` | 新增工厂注册表 | +| _(不存在)_ | `pkg/channels/errors.go` + `errutil.go` | 新增错误分类体系 | +| _(不存在)_ | 
`pkg/channels/interfaces.go` | 新增可选能力接口 | +| _(不存在)_ | `pkg/channels/media.go` | 新增 MediaSender 接口 | +| _(不存在)_ | `pkg/channels/webhook.go` | 新增 WebhookHandler/HealthChecker | +| _(不存在)_ | `pkg/channels/split.go` | 新增消息分割(从 utils 迁入) | +| _(不存在)_ | `pkg/bus/types.go` | 新增结构化消息类型 | +| _(不存在)_ | `pkg/media/store.go` | 新增媒体文件生命周期管理 | +| _(不存在)_ | `pkg/identity/identity.go` | 新增统一用户身份 | + +#### 步骤 3:迁移你的 Channel 代码 + +以 Telegram 为例,主要改动项: + +**3a. 包声明和导入** + +```go +// 旧代码(main 分支) +package channels + +import ( + "github.com/sipeed/picoclaw/pkg/bus" + "github.com/sipeed/picoclaw/pkg/config" +) + +// 新代码(重构分支) +package telegram + +import ( + "github.com/sipeed/picoclaw/pkg/bus" + "github.com/sipeed/picoclaw/pkg/channels" // 引用父包 + "github.com/sipeed/picoclaw/pkg/config" + "github.com/sipeed/picoclaw/pkg/identity" // 新增 + "github.com/sipeed/picoclaw/pkg/media" // 新增(如需媒体) +) +``` + +**3b. 结构体嵌入 BaseChannel** + +```go +// 旧代码:直接持有 bus、config 等字段 +type TelegramChannel struct { + bus *bus.MessageBus + config *config.Config + running bool + allowList []string + // ... +} + +// 新代码:嵌入 BaseChannel,它提供 bus、running、allowList 等 +type TelegramChannel struct { + *channels.BaseChannel // 嵌入共享抽象 + bot *telego.Bot + config *config.Config + // ... 只保留 channel 特有字段 +} +``` + +**3c. 构造函数** + +```go +// 旧代码:直接赋值 +func NewTelegramChannel(cfg *config.Config, bus *bus.MessageBus) (*TelegramChannel, error) { + return &TelegramChannel{ + bus: bus, + config: cfg, + allowList: cfg.Channels.Telegram.AllowFrom, + // ... 
+ }, nil +} + +// 新代码:使用 NewBaseChannel + 功能选项 +func NewTelegramChannel(cfg *config.Config, bus *bus.MessageBus) (*TelegramChannel, error) { + base := channels.NewBaseChannel( + "telegram", // 名称 + cfg.Channels.Telegram, // 原始配置(any 类型) + bus, // 消息总线 + cfg.Channels.Telegram.AllowFrom, // 允许列表 + channels.WithMaxMessageLength(4096), // 平台消息长度上限 + channels.WithGroupTrigger(cfg.Channels.Telegram.GroupTrigger), // 群聊触发配置 + ) + return &TelegramChannel{ + BaseChannel: base, + bot: bot, + config: cfg, + }, nil +} +``` + +**3d. Start/Stop 生命周期** + +```go +// 新代码:使用 SetRunning 原子操作 +func (c *TelegramChannel) Start(ctx context.Context) error { + // ... 初始化 bot、webhook 等 + c.SetRunning(true) // 必须在就绪后调用 + go bh.Start() + return nil +} + +func (c *TelegramChannel) Stop(ctx context.Context) error { + c.SetRunning(false) // 必须在清理前调用 + // ... 停止 bot handler、取消 context + return nil +} +``` + +**3e. Send 方法的错误返回** + +```go +// 旧代码:返回普通 error +func (c *TelegramChannel) Send(ctx context.Context, msg bus.OutboundMessage) error { + if !c.running { return fmt.Errorf("not running") } + // ... + if err != nil { return err } +} + +// 新代码:必须返回哨兵错误,供 Manager 判断重试策略 +func (c *TelegramChannel) Send(ctx context.Context, msg bus.OutboundMessage) error { + if !c.IsRunning() { + return channels.ErrNotRunning // ← Manager 不会重试 + } + // ... + if err != nil { + // 使用 ClassifySendError 根据 HTTP 状态码包装错误 + return channels.ClassifySendError(statusCode, err) + // 或手动包装: + // return fmt.Errorf("%w: %v", channels.ErrTemporary, err) + // return fmt.Errorf("%w: %v", channels.ErrRateLimit, err) + // return fmt.Errorf("%w: %v", channels.ErrSendFailed, err) + } + return nil +} +``` + +**3f. 
消息接收(Inbound)** + +```go +// 旧代码:直接构造 InboundMessage 并发布 +msg := bus.InboundMessage{ + Channel: "telegram", + SenderID: senderID, + ChatID: chatID, + Content: content, + Metadata: map[string]string{ + "peer_kind": "group", // 路由信息埋在 metadata + "peer_id": chatID, + "message_id": msgID, + }, +} +c.bus.PublishInbound(ctx, msg) + +// 新代码:使用 BaseChannel.HandleMessage,传入结构化字段 +sender := bus.SenderInfo{ + Platform: "telegram", + PlatformID: strconv.FormatInt(from.ID, 10), + CanonicalID: identity.BuildCanonicalID("telegram", strconv.FormatInt(from.ID, 10)), + Username: from.Username, + DisplayName: from.FirstName, +} + +peer := bus.Peer{ + Kind: "group", // 或 "direct" + ID: chatID, +} + +// HandleMessage 内部调用 IsAllowedSender 检查权限,构建 MediaScope,发布到 bus +c.HandleMessage(ctx, peer, messageID, senderID, chatID, content, mediaRefs, metadata, sender) +``` + +**3g. 添加工厂注册(必需)** + +为你的 channel 创建 `init.go`: + +```go +// pkg/channels/telegram/init.go +package telegram + +import ( + "github.com/sipeed/picoclaw/pkg/bus" + "github.com/sipeed/picoclaw/pkg/channels" + "github.com/sipeed/picoclaw/pkg/config" +) + +func init() { + channels.RegisterFactory("telegram", func(cfg *config.Config, b *bus.MessageBus) (channels.Channel, error) { + return NewTelegramChannel(cfg, b) + }) +} +``` + +**3h. 
在 Gateway 中导入子包** + +```go +// cmd/picoclaw/internal/gateway/helpers.go +import ( + _ "github.com/sipeed/picoclaw/pkg/channels/telegram" // 触发 init() 注册 + _ "github.com/sipeed/picoclaw/pkg/channels/discord" + _ "github.com/sipeed/picoclaw/pkg/channels/your_new_channel" // 新增 +) +``` + +#### 步骤 4:迁移 Bus 消息使用方式 + +如果你的代码直接读取 `InboundMessage.Metadata` 中的路由字段: + +```go +// 旧代码 +peerKind := msg.Metadata["peer_kind"] +peerID := msg.Metadata["peer_id"] +msgID := msg.Metadata["message_id"] + +// 新代码 +peerKind := msg.Peer.Kind // 一等字段 +peerID := msg.Peer.ID // 一等字段 +msgID := msg.MessageID // 一等字段 +sender := msg.Sender // bus.SenderInfo 结构体 +scope := msg.MediaScope // 媒体生命周期作用域 +``` + +#### 步骤 5:迁移允许列表检查 + +```go +// 旧代码 +if !c.isAllowed(senderID) { return } + +// 新代码:优先使用结构化检查 +if !c.IsAllowedSender(sender) { return } +// 或回退到字符串检查: +if !c.IsAllowed(senderID) { return } +``` + +`BaseChannel.HandleMessage` 方法内部已经处理了这个逻辑,无需在 channel 中重复检查。 + +### 2.2 如果你有 Manager 的修改 + +Manager 已被完全重写。你的修改需要理解新架构: + +| 旧 Manager 职责 | 新 Manager 职责 | +|---|---| +| 直接构造 channel(switch/if-else) | 通过工厂注册表查找并构造 | +| 直接调用 channel.Send | 通过 per-channel Worker 队列 + 速率限制 + 重试 | +| 无消息分割 | 自动根据 MaxMessageLength 分割长消息 | +| 各 channel 自建 HTTP 服务器 | 统一共享 HTTP 服务器 | +| 无 Typing/Placeholder 管理 | 统一 preSend 处理 Typing 停止 + Reaction 撤销 + Placeholder 编辑;入站侧 BaseChannel.HandleMessage 自动编排 Typing/Reaction/Placeholder | +| 无 TTL 清理 | runTTLJanitor 定期清理过期 Typing/Reaction/Placeholder 条目 | + +### 2.3 如果你有 Agent Loop 的修改 + +Agent Loop 的主要变化: + +1. **MediaStore 注入**:`agentLoop.SetMediaStore(mediaStore)` — Agent 通过 MediaStore 解析工具产生的媒体引用 +2. **ChannelManager 注入**:`agentLoop.SetChannelManager(channelManager)` — Agent 可查询 channel 状态 +3. **OutboundMediaMessage**:Agent 现在通过 `bus.PublishOutboundMedia()` 发送媒体消息,而非嵌入文本回复 +4. **extractPeer**:路由使用 `msg.Peer` 结构化字段而非 Metadata 查找 + +--- + +## 第三部分:新 Channel 开发指南——从零实现一个新 Channel + +### 3.1 最小实现清单 + +要添加一个新的聊天平台(例如 `matrix`),你需要: + +1. ✅ 创建子包目录 `pkg/channels/matrix/` +2. 
✅ 创建 `init.go` — 工厂注册 +3. ✅ 创建 `matrix.go` — Channel 实现 +4. ✅ 在 Gateway helpers 中添加 blank import +5. ✅ 在 Manager.initChannels() 中添加配置检查 +6. ✅ 在 `pkg/config/` 中添加配置结构体 + +### 3.2 完整模板 + +#### `pkg/channels/matrix/init.go` + +```go +package matrix + +import ( + "github.com/sipeed/picoclaw/pkg/bus" + "github.com/sipeed/picoclaw/pkg/channels" + "github.com/sipeed/picoclaw/pkg/config" +) + +func init() { + channels.RegisterFactory("matrix", func(cfg *config.Config, b *bus.MessageBus) (channels.Channel, error) { + return NewMatrixChannel(cfg, b) + }) +} +``` + +#### `pkg/channels/matrix/matrix.go` + +```go +package matrix + +import ( + "context" + "fmt" + + "github.com/sipeed/picoclaw/pkg/bus" + "github.com/sipeed/picoclaw/pkg/channels" + "github.com/sipeed/picoclaw/pkg/config" + "github.com/sipeed/picoclaw/pkg/identity" + "github.com/sipeed/picoclaw/pkg/logger" +) + +// MatrixChannel implements channels.Channel for the Matrix protocol. +type MatrixChannel struct { + *channels.BaseChannel // 必须嵌入 + config *config.Config + ctx context.Context + cancel context.CancelFunc + // ... Matrix SDK 客户端等 +} + +func NewMatrixChannel(cfg *config.Config, msgBus *bus.MessageBus) (*MatrixChannel, error) { + matrixCfg := cfg.Channels.Matrix // 假设配置中有此字段 + + base := channels.NewBaseChannel( + "matrix", // channel 名称(全局唯一) + matrixCfg, // 原始配置 + msgBus, // 消息总线 + matrixCfg.AllowFrom, // 允许列表 + channels.WithMaxMessageLength(65536), // Matrix 消息长度限制 + channels.WithGroupTrigger(matrixCfg.GroupTrigger), + ) + + return &MatrixChannel{ + BaseChannel: base, + config: cfg, + }, nil +} + +// ========== 必须实现的 Channel 接口方法 ========== + +func (c *MatrixChannel) Start(ctx context.Context) error { + c.ctx, c.cancel = context.WithCancel(ctx) + + // 1. 初始化 Matrix 客户端 + // 2. 开始监听消息 + // 3. 
标记为运行中 + c.SetRunning(true) + + logger.InfoC("matrix", "Matrix channel started") + return nil +} + +func (c *MatrixChannel) Stop(ctx context.Context) error { + c.SetRunning(false) + + if c.cancel != nil { + c.cancel() + } + + logger.InfoC("matrix", "Matrix channel stopped") + return nil +} + +func (c *MatrixChannel) Send(ctx context.Context, msg bus.OutboundMessage) error { + // 1. 检查运行状态 + if !c.IsRunning() { + return channels.ErrNotRunning + } + + // 2. 发送消息到 Matrix + err := c.sendToMatrix(ctx, msg.ChatID, msg.Content) + if err != nil { + // 3. 必须使用错误分类包装 + // 如果你有 HTTP 状态码: + // return channels.ClassifySendError(statusCode, err) + // 如果是网络错误: + // return channels.ClassifyNetError(err) + // 如果需要手动分类: + return fmt.Errorf("%w: %v", channels.ErrTemporary, err) + } + + return nil +} + +// ========== 消息接收处理 ========== + +func (c *MatrixChannel) handleIncoming(roomID, senderID, displayName, content string, msgID string) { + // 1. 构造结构化发送者身份 + sender := bus.SenderInfo{ + Platform: "matrix", + PlatformID: senderID, + CanonicalID: identity.BuildCanonicalID("matrix", senderID), + Username: senderID, + DisplayName: displayName, + } + + // 2. 确定 Peer 类型(直聊 vs 群聊) + peer := bus.Peer{ + Kind: "group", // 或 "direct" + ID: roomID, + } + + // 3. 群聊过滤(如适用) + isGroup := peer.Kind == "group" + if isGroup { + isMentioned := false // 根据平台特性检测 @提及 + shouldRespond, cleanContent := c.ShouldRespondInGroup(isMentioned, content) + if !shouldRespond { + return + } + content = cleanContent + } + + // 4. 处理媒体附件(如有) + var mediaRefs []string + store := c.GetMediaStore() + if store != nil { + // 下载附件到本地 → store.Store() → 获取 ref + // mediaRefs = append(mediaRefs, ref) + } + + // 5. 
调用 HandleMessage 发布到 bus + // HandleMessage 内部会: + // - 检查 IsAllowedSender/IsAllowed + // - 构建 MediaScope + // - 发布 InboundMessage + c.HandleMessage( + c.ctx, + peer, + msgID, // 平台消息 ID + senderID, // 原始发送者 ID + roomID, // 聊天/房间 ID + content, // 消息内容 + mediaRefs, // 媒体引用列表 + nil, // 额外 metadata(通常 nil) + sender, // SenderInfo(variadic 参数) + ) +} + +// ========== 内部方法 ========== + +func (c *MatrixChannel) sendToMatrix(ctx context.Context, roomID, content string) error { + // 实际的 Matrix SDK 调用 + return nil +} +``` + +### 3.3 可选能力接口 + +根据平台能力,你的 Channel 可以选择性实现以下接口: + +#### MediaSender — 发送媒体附件 + +```go +// 如果平台支持发送图片/文件/音频/视频 +func (c *MatrixChannel) SendMedia(ctx context.Context, msg bus.OutboundMediaMessage) error { + if !c.IsRunning() { + return channels.ErrNotRunning + } + + store := c.GetMediaStore() + if store == nil { + return fmt.Errorf("no media store: %w", channels.ErrSendFailed) + } + + for _, part := range msg.Parts { + localPath, err := store.Resolve(part.Ref) + if err != nil { + logger.ErrorCF("matrix", "Failed to resolve media", map[string]any{ + "ref": part.Ref, "error": err.Error(), + }) + continue + } + + // 根据 part.Type ("image"|"audio"|"video"|"file") 调用对应 API + switch part.Type { + case "image": + // 上传图片到 Matrix + default: + // 上传文件到 Matrix + } + } + return nil +} +``` + +#### TypingCapable — Typing 指示器 + +```go +// 如果平台支持 "正在输入..." 
提示 +func (c *MatrixChannel) StartTyping(ctx context.Context, chatID string) (stop func(), err error) { + // 调用 Matrix API 发送 typing 指示器 + // 返回的 stop 函数必须是幂等的 + stopped := false + return func() { + if !stopped { + stopped = true + // 调用 Matrix API 停止 typing + } + }, nil +} +``` + +#### ReactionCapable — 消息反应指示器 + +```go +// 如果平台支持对入站消息添加 emoji 反应(如 Slack 的 👀、OneBot 的表情 289) +func (c *MatrixChannel) ReactToMessage(ctx context.Context, chatID, messageID string) (undo func(), err error) { + // 调用 Matrix API 添加反应到消息 + // 返回的 undo 函数移除反应,必须是幂等的 + err = c.addReaction(chatID, messageID, "eyes") + if err != nil { + return func() {}, err + } + return func() { + c.removeReaction(chatID, messageID, "eyes") + }, nil +} +``` + +#### MessageEditor — 消息编辑 + +```go +// 如果平台支持编辑已发送的消息(用于 Placeholder 替换) +func (c *MatrixChannel) EditMessage(ctx context.Context, chatID, messageID, content string) error { + // 调用 Matrix API 编辑消息 + return nil +} +``` + +#### WebhookHandler — HTTP Webhook 接收 + +```go +// 如果 channel 通过 webhook 接收消息(而非长轮询/WebSocket) +func (c *MatrixChannel) WebhookPath() string { + return "/webhook/matrix" // 路径会被注册到共享 HTTP 服务器 +} + +func (c *MatrixChannel) ServeHTTP(w http.ResponseWriter, r *http.Request) { + // 处理 webhook 请求 +} +``` + +#### HealthChecker — 健康检查端点 + +```go +func (c *MatrixChannel) HealthPath() string { + return "/health/matrix" +} + +func (c *MatrixChannel) HealthHandler(w http.ResponseWriter, r *http.Request) { + if c.IsRunning() { + w.WriteHeader(http.StatusOK) + w.Write([]byte("OK")) + } else { + w.WriteHeader(http.StatusServiceUnavailable) + } +} +``` + +### 3.4 入站侧 Typing/Reaction/Placeholder 自动编排 + +`BaseChannel.HandleMessage` 在发布入站消息**之前**,自动检测 channel 是否实现了 `TypingCapable`、`ReactionCapable` 和/或 `PlaceholderCapable`,并触发相应的指示器。三条管道完全独立,互不干扰: + +```go +// BaseChannel.HandleMessage 内部自动执行(无需 channel 手动调用): +if c.owner != nil && c.placeholderRecorder != nil { + // Typing — 独立管道 + if tc, ok := c.owner.(TypingCapable); ok { + if stop, err := 
tc.StartTyping(ctx, chatID); err == nil { + c.placeholderRecorder.RecordTypingStop(c.name, chatID, stop) + } + } + // Reaction — 独立管道 + if rc, ok := c.owner.(ReactionCapable); ok && messageID != "" { + if undo, err := rc.ReactToMessage(ctx, chatID, messageID); err == nil { + c.placeholderRecorder.RecordReactionUndo(c.name, chatID, undo) + } + } + // Placeholder — 独立管道 + if pc, ok := c.owner.(PlaceholderCapable); ok { + if phID, err := pc.SendPlaceholder(ctx, chatID); err == nil && phID != "" { + c.placeholderRecorder.RecordPlaceholder(c.name, chatID, phID) + } + } +} +``` + +**这意味着**: +- 实现 `TypingCapable` 的 channel(Telegram、Discord、LINE、Pico)无需在 `handleMessage` 中手动调用 `StartTyping` + `RecordTypingStop` +- 实现 `ReactionCapable` 的 channel(Slack、OneBot)无需在 `handleMessage` 中手动调用 `AddReaction` + `RecordTypingStop` +- 实现 `PlaceholderCapable` 的 channel(Telegram、Discord、Pico)无需在 `handleMessage` 中手动发送占位消息并调用 `RecordPlaceholder` +- Channel 只需实现对应接口,`HandleMessage` 会自动完成编排 +- 不实现这些接口的 channel 不受影响(类型断言会失败,跳过) +- `PlaceholderCapable` 的 `SendPlaceholder` 方法内部根据配置的 `PlaceholderConfig.Enabled` 决定是否发送;返回 `("", nil)` 时跳过注册 + +**Owner 注入**:Manager 在 `initChannel` 中自动调用 `SetOwner(ch)` 将具体 channel 注入 BaseChannel,无需开发者手动设置。 + +当 Agent 处理完消息后,Manager 的 `preSend` 会自动: +1. 调用已记录的 `stop()` 停止 Typing +2. 调用已记录的 `undo()` 撤销 Reaction +3. 如果有 Placeholder,且 channel 实现了 `MessageEditor`,尝试编辑 Placeholder 为最终回复(跳过 Send) + +### 3.5 注册配置和 Gateway 接入 + +#### 在 `pkg/config/config.go` 中添加配置 + +```go +type ChannelsConfig struct { + // ... 
现有 channels + Matrix MatrixChannelConfig `yaml:"matrix" json:"matrix"` +} + +type MatrixChannelConfig struct { + Enabled bool `yaml:"enabled" json:"enabled"` + HomeServer string `yaml:"home_server" json:"home_server"` + Token string `yaml:"token" json:"token"` + AllowFrom []string `yaml:"allow_from" json:"allow_from"` + GroupTrigger GroupTriggerConfig `yaml:"group_trigger" json:"group_trigger"` +} +``` + +#### 在 Manager.initChannels() 中添加入口 + +```go +// pkg/channels/manager.go 的 initChannels() 方法中 +if m.config.Channels.Matrix.Enabled && m.config.Channels.Matrix.Token != "" { + m.initChannel("matrix", "Matrix") +} +``` + +#### 在 Gateway 中添加 blank import + +```go +// cmd/picoclaw/internal/gateway/helpers.go +import ( + _ "github.com/sipeed/picoclaw/pkg/channels/matrix" +) +``` + +--- + +## 第四部分:核心子系统详解 + +### 4.1 MessageBus + +**文件**:`pkg/bus/bus.go`、`pkg/bus/types.go` + +```go +type MessageBus struct { + inbound chan InboundMessage // 缓冲区 = 64 + outbound chan OutboundMessage // 缓冲区 = 64 + outboundMedia chan OutboundMediaMessage // 缓冲区 = 64 + done chan struct{} // 关闭信号 + closed atomic.Bool // 防止重复关闭 +} +``` + +**关键行为**: + +| 方法 | 行为 | +|------|------| +| `PublishInbound(ctx, msg)` | 检查 closed → 发送到 inbound channel → 阻塞/超时/关闭 | +| `ConsumeInbound(ctx)` | 从 inbound 读取 → 阻塞/关闭/取消 | +| `PublishOutbound(ctx, msg)` | 发送到 outbound channel | +| `SubscribeOutbound(ctx)` | 从 outbound 读取(Manager dispatcher 调用) | +| `PublishOutboundMedia(ctx, msg)` | 发送到 outboundMedia channel | +| `SubscribeOutboundMedia(ctx)` | 从 outboundMedia 读取(Manager media dispatcher 调用) | +| `Close()` | CAS 关闭 → close(done) → 排水所有 channel(**不关闭 channel 本身**,避免并发 send-on-closed panic) | + +**设计要点**: +- 缓冲区从 16 增至 64,减少突发负载下的阻塞 +- `Close()` 不关闭底层 channel(只关闭 `done` 信号通道),因为可能有正在并发 `Publish` 的 goroutine +- 排水循环确保 buffered 消息不被静默丢弃 + +### 4.2 结构化消息类型 + +**文件**:`pkg/bus/types.go` + +```go +// 路由对等体 +type Peer struct { + Kind string `json:"kind"` // "direct" | "group" | "channel" | "" + ID string `json:"id"` +} 
+ +// 发送者身份信息 +type SenderInfo struct { + Platform string `json:"platform,omitempty"` // "telegram", "discord", ... + PlatformID string `json:"platform_id,omitempty"` // 平台原始 ID + CanonicalID string `json:"canonical_id,omitempty"` // "platform:id" 规范格式 + Username string `json:"username,omitempty"` + DisplayName string `json:"display_name,omitempty"` +} + +// 入站消息 +type InboundMessage struct { + Channel string // 来源 channel 名称 + SenderID string // 发送者 ID(优先使用 CanonicalID) + Sender SenderInfo // 结构化发送者信息 + ChatID string // 聊天/房间 ID + Content string // 消息文本 + Media []string // 媒体引用列表(media://...) + Peer Peer // 路由对等体(一等字段) + MessageID string // 平台消息 ID(一等字段) + MediaScope string // 媒体生命周期作用域 + SessionKey string // 会话键 + Metadata map[string]string // 仅用于 channel 特有扩展 +} + +// 出站文本消息 +type OutboundMessage struct { + Channel string + ChatID string + Content string +} + +// 出站媒体消息 +type OutboundMediaMessage struct { + Channel string + ChatID string + Parts []MediaPart +} + +// 媒体片段 +type MediaPart struct { + Type string // "image" | "audio" | "video" | "file" + Ref string // "media://uuid" + Caption string + Filename string + ContentType string +} +``` + +### 4.3 BaseChannel + +**文件**:`pkg/channels/base.go` + +BaseChannel 是所有 channel 的共享抽象层,提供以下能力: + +| 方法/特性 | 说明 | +|---|---| +| `Name() string` | Channel 名称 | +| `IsRunning() bool` | 原子读取运行状态 | +| `SetRunning(bool)` | 原子设置运行状态 | +| `MaxMessageLength() int` | 消息长度限制(rune 计数),0 = 无限制 | +| `IsAllowed(senderID string) bool` | 旧格式允许列表检查(支持 `"id\|username"` 和 `"@username"` 格式) | +| `IsAllowedSender(sender SenderInfo) bool` | 新格式允许列表检查(委托给 `identity.MatchAllowed`) | +| `ShouldRespondInGroup(isMentioned, content) (bool, string)` | 统一群聊触发过滤逻辑 | +| `HandleMessage(...)` | 统一入站消息处理:权限检查 → 构建 MediaScope → 自动触发 Typing/Reaction → 发布到 Bus | +| `SetMediaStore(s) / GetMediaStore()` | Manager 注入的媒体存储 | +| `SetPlaceholderRecorder(r) / GetPlaceholderRecorder()` | Manager 注入的占位符记录器 | +| `SetOwner(ch) ` | Manager 注入的具体 channel 引用(用于 
HandleMessage 内部的 Typing/Reaction 类型断言) | + +**功能选项**: + +```go +channels.WithMaxMessageLength(4096) // 设置平台消息长度限制 +channels.WithGroupTrigger(groupTriggerCfg) // 设置群聊触发配置 +``` + +### 4.4 工厂注册表 + +**文件**:`pkg/channels/registry.go` + +```go +type ChannelFactory func(cfg *config.Config, bus *bus.MessageBus) (Channel, error) + +func RegisterFactory(name string, f ChannelFactory) // 子包 init() 中调用 +func getFactory(name string) (ChannelFactory, bool) // Manager 内部调用 +``` + +工厂注册表使用 `sync.RWMutex` 保护,在 `init()` 阶段注册(进程启动时完成)。Manager 在 `initChannel()` 中通过名字查找工厂并调用它。 + +### 4.5 错误分类与重试 + +**文件**:`pkg/channels/errors.go`、`pkg/channels/errutil.go` + +#### 哨兵错误 + +```go +var ( + ErrNotRunning = errors.New("channel not running") // 永久:不重试 + ErrRateLimit = errors.New("rate limited") // 固定延迟:1s 后重试 + ErrTemporary = errors.New("temporary failure") // 指数退避:500ms * 2^attempt,最大 8s + ErrSendFailed = errors.New("send failed") // 永久:不重试 +) +``` + +#### 错误分类帮助函数 + +```go +// 根据 HTTP 状态码自动分类 +func ClassifySendError(statusCode int, rawErr error) error { + // 429 → ErrRateLimit + // 5xx → ErrTemporary + // 4xx → ErrSendFailed +} + +// 网络错误统一包装为临时错误 +func ClassifyNetError(err error) error { + // → ErrTemporary +} +``` + +#### Manager 重试策略(`sendWithRetry`) + +``` +最大重试次数: 3 +速率限制延迟: 1 秒 +基础退避: 500 毫秒 +最大退避: 8 秒 + +重试逻辑: + ErrNotRunning → 立即失败,不重试 + ErrSendFailed → 立即失败,不重试 + ErrRateLimit → 等待 1s → 重试 + ErrTemporary → 等待 500ms * 2^attempt(最大 8s) → 重试 + 其他未知错误 → 等待 500ms * 2^attempt(最大 8s) → 重试 +``` + +### 4.6 Manager 编排 + +**文件**:`pkg/channels/manager.go` + +#### Per-channel Worker 架构 + +```go +type channelWorker struct { + ch Channel // channel 实例 + queue chan bus.OutboundMessage // 出站文本队列(缓冲 16) + mediaQueue chan bus.OutboundMediaMessage // 出站媒体队列(缓冲 16) + done chan struct{} // 文本 worker 完成信号 + mediaDone chan struct{} // 媒体 worker 完成信号 + limiter *rate.Limiter // per-channel 速率限制器 +} +``` + +#### Per-channel 速率限制配置 + +```go +var channelRateConfig = map[string]float64{ + "telegram": 20, // 20 
msg/s + "discord": 1, // 1 msg/s + "slack": 1, // 1 msg/s + "line": 10, // 10 msg/s +} +// 默认: 10 msg/s +// burst = max(1, ceil(rate/2)) +``` + +#### 生命周期管理 + +``` +StartAll: + 1. 遍历已注册 channels → channel.Start(ctx) + 2. 为每个启动成功的 channel 创建 channelWorker + 3. 启动 goroutines: + - runWorker (per-channel 出站文本) + - runMediaWorker (per-channel 出站媒体) + - dispatchOutbound (从 bus 路由到 worker 队列) + - dispatchOutboundMedia (从 bus 路由到 media worker 队列) + - runTTLJanitor (每 10s 清理过期 typing/reaction/placeholder) + 4. 启动共享 HTTP 服务器(如已配置) + +StopAll: + 1. 关闭共享 HTTP 服务器(5s 超时) + 2. 取消 dispatcher context + 3. 关闭 text worker 队列 → 等待排水完成 + 4. 关闭 media worker 队列 → 等待排水完成 + 5. 停止每个 channel(channel.Stop) +``` + +#### Typing/Reaction/Placeholder 管理 + +```go +// Manager 实现 PlaceholderRecorder 接口 +func (m *Manager) RecordPlaceholder(channel, chatID, placeholderID string) +func (m *Manager) RecordTypingStop(channel, chatID string, stop func()) +func (m *Manager) RecordReactionUndo(channel, chatID string, undo func()) + +// 入站侧:BaseChannel.HandleMessage 自动编排 +// BaseChannel.HandleMessage 在 PublishInbound 之前,通过 owner 类型断言自动触发: +// - TypingCapable.StartTyping → RecordTypingStop +// - ReactionCapable.ReactToMessage → RecordReactionUndo +// - PlaceholderCapable.SendPlaceholder → RecordPlaceholder +// 三者独立,互不干扰。Channel 无需手动调用。 + +// 出站侧:发送前处理 +func (m *Manager) preSend(ctx, name, msg, ch) bool { + key := name + ":" + msg.ChatID + // 1. 停止 Typing(调用存储的 stop 函数) + // 2. 撤销 Reaction(调用存储的 undo 函数) + // 3. 
尝试编辑 Placeholder(如果 channel 实现了 MessageEditor) + // 成功 → return true(跳过 Send) + // 失败 → return false(继续 Send) +} +``` + +Manager 存储完全分离,三条管道互不干扰: + +```go +Manager { + typingStops sync.Map // "channel:chatID" → typingEntry ← 管 TypingCapable + reactionUndos sync.Map // "channel:chatID" → reactionEntry ← 管 ReactionCapable + placeholders sync.Map // "channel:chatID" → placeholderEntry +} +``` + +TTL 清理: +- Typing 停止函数:5 分钟 TTL(到期后自动调用 stop 并删除) +- Reaction 撤销函数:5 分钟 TTL(到期后自动调用 undo 并删除) +- Placeholder ID:10 分钟 TTL(到期后删除) +- 清理间隔:10 秒 + +### 4.7 消息分割 + +**文件**:`pkg/channels/split.go` + +`SplitMessage(content string, maxLen int) []string` + +智能分割策略: +1. 计算有效分割点 = maxLen - 10% 缓冲区(为代码块闭合留空间) +2. 优先在换行符处分割 +3. 其次在空格/制表符处分割 +4. 检测未闭合的代码块(` ``` `) +5. 如果代码块未闭合: + - 尝试扩展到 maxLen 以包含闭合围栏 + - 如果代码块太长,注入闭合/重开围栏(`\n```\n` + header) + - 最后手段:在代码块开始前分割 + +### 4.8 MediaStore + +**文件**:`pkg/media/store.go` + +```go +type MediaStore interface { + Store(localPath string, meta MediaMeta, scope string) (ref string, err error) + Resolve(ref string) (localPath string, err error) + ResolveWithMeta(ref string) (localPath string, meta MediaMeta, err error) + ReleaseAll(scope string) error +} +``` + +**FileMediaStore 实现**: +- 纯内存映射,不复制/移动文件 +- 引用格式:`media://<uuid>` +- Scope 格式:`channel:chatID:messageID`(由 `BuildMediaScope` 生成) +- **两阶段操作**: + - Phase 1(持锁):从 map 中收集并删除条目 + - Phase 2(无锁):从磁盘删除文件 + - 目的:最小化锁争用 +- **TTL 清理**:`NewFileMediaStoreWithCleanup` → `Start()` 启动后台清理协程 +- 清理间隔和最大存活时间由配置控制 + +### 4.9 Identity + +**文件**:`pkg/identity/identity.go` + +```go +// 构建规范 ID +func BuildCanonicalID(platform, platformID string) string +// → "telegram:123456" + +// 解析规范 ID +func ParseCanonicalID(canonical string) (platform, id string, ok bool) + +// 匹配允许列表(向后兼容) +func MatchAllowed(sender bus.SenderInfo, allowed string) bool +``` + +`MatchAllowed` 支持的允许列表格式: +| 格式 | 匹配方式 | +|------|----------| +| `"123456"` | 匹配 `sender.PlatformID` | +| `"@alice"` | 匹配 `sender.Username` | +| `"123456\|alice"` | 匹配 
PlatformID 或 Username(旧格式兼容) | +| `"telegram:123456"` | 精确匹配 `sender.CanonicalID`(新格式) | + +### 4.10 共享 HTTP 服务器 + +**文件**:`pkg/channels/manager.go` 的 `SetupHTTPServer` + +Manager 创建单一 `http.Server`,自动发现和注册: +- 实现 `WebhookHandler` 的 channel → 挂载到 `wh.WebhookPath()` +- 实现 `HealthChecker` 的 channel → 挂载到 `hc.HealthPath()` +- Health 全局端点由 `health.Server.RegisterOnMux` 注册 + +超时配置:ReadTimeout = 30s, WriteTimeout = 30s + +--- + +## 第五部分:关键设计决策与约定 + +### 5.1 必须遵守的约定 + +1. **错误分类是合约**:Channel 的 `Send` 方法**必须**返回哨兵错误(或包装它们)。Manager 的重试策略完全依赖 `errors.Is` 检查。如果返回未分类的错误,Manager 会按"未知错误"处理(指数退避重试)。 + +2. **SetRunning 是生命周期信号**:`Start` 成功后**必须**调用 `c.SetRunning(true)`,`Stop` 开始时**必须**调用 `c.SetRunning(false)`。`Send` 中**必须**检查 `c.IsRunning()` 并返回 `ErrNotRunning`。 + +3. **HandleMessage 包含权限检查**:不要在调用 `HandleMessage` 之前自行进行权限检查(除非你需要在检查前做平台特定的预处理)。`HandleMessage` 内部已经调用 `IsAllowedSender`/`IsAllowed`。 + +4. **消息分割由 Manager 处理**:Channel 的 `Send` 方法不需要处理长消息分割。Manager 会在调用 `Send` 之前根据 `MaxMessageLength()` 自动分割。Channel 只需通过 `WithMaxMessageLength` 声明限制。 + +5. **Typing/Reaction/Placeholder 由 BaseChannel + Manager 自动处理**:Channel 的 `Send` 方法不需要管理 Typing 停止、Reaction 撤销或 Placeholder 编辑。`BaseChannel.HandleMessage` 在入站侧自动触发 `TypingCapable`、`ReactionCapable` 和 `PlaceholderCapable`(通过 `owner` 类型断言);Manager 的 `preSend` 在出站侧自动停止 Typing、撤销 Reaction、编辑 Placeholder。Channel 只需实现对应接口即可。 + +6. 
**工厂注册在 init() 中**:每个子包必须有 `init.go` 文件调用 `channels.RegisterFactory`。Gateway 必须通过 blank import(`_ "pkg/channels/xxx"`)触发注册。 + +### 5.2 Metadata 字段使用约定 + +**不要再把以下信息放入 Metadata**: +- `peer_kind` / `peer_id` → 使用 `InboundMessage.Peer` +- `message_id` → 使用 `InboundMessage.MessageID` +- `sender_platform` / `sender_username` → 使用 `InboundMessage.Sender` + +**Metadata 仅用于**: +- Channel 特有的扩展信息(如 Telegram 的 `reply_to_message_id`) +- 不适合放入结构化字段的临时信息 + +### 5.3 并发安全约定 + +- `BaseChannel.running`:使用 `atomic.Bool`,线程安全 +- `Manager.channels` / `Manager.workers`:使用 `sync.RWMutex` 保护 +- `Manager.placeholders` / `Manager.typingStops` / `Manager.reactionUndos`:使用 `sync.Map` +- `MessageBus.closed`:使用 `atomic.Bool` +- `FileMediaStore`:使用 `sync.RWMutex`,两阶段操作减少持锁时间 +- Channel Worker queue:Go channel,天然并发安全 + +### 5.4 测试约定 + +已有测试文件: +- `pkg/channels/base_test.go` — BaseChannel 单元测试 +- `pkg/channels/manager_test.go` — Manager 单元测试 +- `pkg/channels/split_test.go` — 消息分割测试 +- `pkg/channels/errors_test.go` — 错误类型测试 +- `pkg/channels/errutil_test.go` — 错误分类测试 + +为新 channel 添加测试时: +```bash +go test ./pkg/channels/matrix/ -v # 子包测试 +go test ./pkg/channels/ -run TestSpecific -v # 框架测试 +make test # 全量测试 +``` + +--- + +## 附录:完整文件清单与接口速查表 + +### A.1 框架层文件 + +| 文件 | 职责 | +|------|------| +| `pkg/channels/base.go` | BaseChannel 结构体、Channel 接口、MessageLengthProvider、BaseChannelOption、HandleMessage | +| `pkg/channels/interfaces.go` | TypingCapable、MessageEditor、ReactionCapable、PlaceholderCapable、PlaceholderRecorder 接口 | +| `pkg/channels/media.go` | MediaSender 接口 | +| `pkg/channels/webhook.go` | WebhookHandler、HealthChecker 接口 | +| `pkg/channels/errors.go` | ErrNotRunning、ErrRateLimit、ErrTemporary、ErrSendFailed 哨兵 | +| `pkg/channels/errutil.go` | ClassifySendError、ClassifyNetError 帮助函数 | +| `pkg/channels/registry.go` | RegisterFactory、getFactory 工厂注册表 | +| `pkg/channels/manager.go` | Manager:Worker 队列、速率限制、重试、preSend、共享 HTTP、TTL janitor | +| `pkg/channels/split.go` | SplitMessage 长消息分割 | +| 
`pkg/bus/bus.go` | MessageBus 实现 | +| `pkg/bus/types.go` | Peer、SenderInfo、InboundMessage、OutboundMessage、OutboundMediaMessage、MediaPart | +| `pkg/media/store.go` | MediaStore 接口、FileMediaStore 实现 | +| `pkg/identity/identity.go` | BuildCanonicalID、ParseCanonicalID、MatchAllowed | + +### A.2 Channel 子包 + +| 子包 | 注册名 | 可选接口 | +|------|--------|----------| +| `pkg/channels/telegram/` | `"telegram"` | MessageEditor, MediaSender, TypingCapable, PlaceholderCapable | +| `pkg/channels/discord/` | `"discord"` | MessageEditor, TypingCapable, PlaceholderCapable | +| `pkg/channels/slack/` | `"slack"` | ReactionCapable | +| `pkg/channels/line/` | `"line"` | WebhookHandler, HealthChecker, TypingCapable | +| `pkg/channels/onebot/` | `"onebot"` | ReactionCapable | +| `pkg/channels/dingtalk/` | `"dingtalk"` | WebhookHandler | +| `pkg/channels/feishu/` | `"feishu"` | WebhookHandler (架构特定 build tags) | +| `pkg/channels/wecom/` | `"wecom"` + `"wecom_app"` | WebhookHandler | +| `pkg/channels/qq/` | `"qq"` | — | +| `pkg/channels/whatsapp/` | `"whatsapp"` | — | +| `pkg/channels/maixcam/` | `"maixcam"` | — | +| `pkg/channels/pico/` | `"pico"` | WebhookHandler (Pico Protocol), TypingCapable, PlaceholderCapable | + +### A.3 接口速查表 + +```go +// ===== 必须实现 ===== +type Channel interface { + Name() string + Start(ctx context.Context) error + Stop(ctx context.Context) error + Send(ctx context.Context, msg bus.OutboundMessage) error + IsRunning() bool + IsAllowed(senderID string) bool + IsAllowedSender(sender bus.SenderInfo) bool +} + +// ===== 可选实现 ===== +type MediaSender interface { + SendMedia(ctx context.Context, msg bus.OutboundMediaMessage) error +} + +type TypingCapable interface { + StartTyping(ctx context.Context, chatID string) (stop func(), err error) +} + +type ReactionCapable interface { + ReactToMessage(ctx context.Context, chatID, messageID string) (undo func(), err error) +} + +type PlaceholderCapable interface { + SendPlaceholder(ctx context.Context, chatID string) (messageID 
string, err error) +} + +type MessageEditor interface { + EditMessage(ctx context.Context, chatID, messageID, content string) error +} + +type WebhookHandler interface { + WebhookPath() string + http.Handler +} + +type HealthChecker interface { + HealthPath() string + HealthHandler(w http.ResponseWriter, r *http.Request) +} + +type MessageLengthProvider interface { + MaxMessageLength() int +} + +// ===== 由 Manager 注入 ===== +type PlaceholderRecorder interface { + RecordPlaceholder(channel, chatID, placeholderID string) + RecordTypingStop(channel, chatID string, stop func()) + RecordReactionUndo(channel, chatID string, undo func()) +} +``` + +### A.4 Gateway 启动序列(完整引导流程) + +```go +// 1. 创建核心组件 +msgBus := bus.NewMessageBus() +provider := providers.CreateProvider(cfg) +agentLoop := agent.NewAgentLoop(cfg, msgBus, provider) + +// 2. 创建媒体存储(带 TTL 清理) +mediaStore := media.NewFileMediaStoreWithCleanup(cleanerConfig) +mediaStore.Start() + +// 3. 创建 Channel Manager(触发 initChannels → 工厂查找 → 构造 → 注入 MediaStore/PlaceholderRecorder/Owner) +channelManager := channels.NewManager(cfg, msgBus, mediaStore) + +// 4. 注入引用 +agentLoop.SetChannelManager(channelManager) +agentLoop.SetMediaStore(mediaStore) + +// 5. 配置共享 HTTP 服务器 +channelManager.SetupHTTPServer(addr, healthServer) + +// 6. 启动 +channelManager.StartAll(ctx) // 启动 channels + workers + dispatchers + HTTP server +go agentLoop.Run(ctx) // 启动 Agent 消息循环 + +// 7. 关闭(信号触发) +cancel() // 取消 context +msgBus.Close() // 信号关闭 + 排水 +channelManager.StopAll(shutdownCtx) // 停止 HTTP + workers + channels +mediaStore.Stop() // 停止 TTL 清理 +agentLoop.Stop() // 停止 Agent +``` + +### A.5 Per-channel 速率限制参考 + +| Channel | 速率 (msg/s) | Burst | +|---------|-------------|-------| +| telegram | 20 | 10 | +| discord | 1 | 1 | +| slack | 1 | 1 | +| line | 10 | 5 | +| _其他_ | 10 (默认) | 5 | + +### A.6 已知限制和注意事项 + +1. 
**媒体清理暂时禁用**:Agent loop 中的 `ReleaseAll` 调用被注释掉了(`refactor(loop): disable media cleanup to prevent premature file deletion`),因为会话边界尚未明确定义。TTL 清理仍然有效。 + +2. **Feishu 架构特定编译**:Feishu channel 使用 build tags 区分 32 位和 64 位架构(`feishu_32.go` / `feishu_64.go`)。 + +3. **WeCom 有两个工厂**:`"wecom"`(Bot 模式)和 `"wecom_app"`(应用模式)分别注册。 + +4. **Pico Protocol**:`pkg/channels/pico/` 实现了一个自定义的 PicoClaw 原生协议 channel,通过 webhook 接收消息。 \ No newline at end of file diff --git a/pkg/channels/base.go b/pkg/channels/base.go index cd6419ebb..063a66523 100644 --- a/pkg/channels/base.go +++ b/pkg/channels/base.go @@ -2,11 +2,44 @@ package channels import ( "context" + "crypto/rand" + "encoding/binary" + "encoding/hex" + "strconv" "strings" + "sync/atomic" + "time" "github.com/sipeed/picoclaw/pkg/bus" + "github.com/sipeed/picoclaw/pkg/config" + "github.com/sipeed/picoclaw/pkg/identity" + "github.com/sipeed/picoclaw/pkg/logger" + "github.com/sipeed/picoclaw/pkg/media" ) +var ( + uniqueIDCounter uint64 + uniqueIDPrefix string +) + +func init() { + // One-time read from crypto/rand for a unique prefix (single syscall). + var b [8]byte + if _, err := rand.Read(b[:]); err != nil { + // fallback to time-based prefix + binary.BigEndian.PutUint64(b[:], uint64(time.Now().UnixNano())) + } + uniqueIDPrefix = hex.EncodeToString(b[:]) +} + +// uniqueID generates a process-unique ID using a random prefix and an atomic counter. +// This ID is intended for internal correlation (e.g. media scope keys) and is NOT +// cryptographically secure — it must not be used in contexts where unpredictability matters. 
+func uniqueID() string { + n := atomic.AddUint64(&uniqueIDCounter, 1) + return uniqueIDPrefix + strconv.FormatUint(n, 16) +} + type Channel interface { Name() string Start(ctx context.Context) error @@ -14,32 +47,126 @@ type Channel interface { Send(ctx context.Context, msg bus.OutboundMessage) error IsRunning() bool IsAllowed(senderID string) bool + IsAllowedSender(sender bus.SenderInfo) bool + ReasoningChannelID() string +} + +// BaseChannelOption is a functional option for configuring a BaseChannel. +type BaseChannelOption func(*BaseChannel) + +// WithMaxMessageLength sets the maximum message length (in runes) for a channel. +// Messages exceeding this limit will be automatically split by the Manager. +// A value of 0 means no limit. +func WithMaxMessageLength(n int) BaseChannelOption { + return func(c *BaseChannel) { c.maxMessageLength = n } +} + +// WithGroupTrigger sets the group trigger configuration for a channel. +func WithGroupTrigger(gt config.GroupTriggerConfig) BaseChannelOption { + return func(c *BaseChannel) { c.groupTrigger = gt } +} + +// WithReasoningChannelID sets the reasoning channel ID where thoughts should be sent. +func WithReasoningChannelID(id string) BaseChannelOption { + return func(c *BaseChannel) { c.reasoningChannelID = id } +} + +// MessageLengthProvider is an opt-in interface that channels implement +// to advertise their maximum message length. The Manager uses this via +// type assertion to decide whether to split outbound messages. 
+type MessageLengthProvider interface { + MaxMessageLength() int } type BaseChannel struct { - config any - bus *bus.MessageBus - running bool - name string - allowList []string + config any + bus *bus.MessageBus + running atomic.Bool + name string + allowList []string + maxMessageLength int + groupTrigger config.GroupTriggerConfig + mediaStore media.MediaStore + placeholderRecorder PlaceholderRecorder + owner Channel // the concrete channel that embeds this BaseChannel + reasoningChannelID string } -func NewBaseChannel(name string, config any, bus *bus.MessageBus, allowList []string) *BaseChannel { - return &BaseChannel{ +func NewBaseChannel( + name string, + config any, + bus *bus.MessageBus, + allowList []string, + opts ...BaseChannelOption, +) *BaseChannel { + bc := &BaseChannel{ config: config, bus: bus, name: name, allowList: allowList, - running: false, } + for _, opt := range opts { + opt(bc) + } + return bc +} + +// MaxMessageLength returns the maximum message length (in runes) for this channel. +// A value of 0 means no limit. +func (c *BaseChannel) MaxMessageLength() int { + return c.maxMessageLength +} + +// ShouldRespondInGroup determines whether the bot should respond in a group chat. +// Each channel is responsible for: +// 1. Detecting isMentioned (platform-specific) +// 2. Stripping bot mention from content (platform-specific) +// 3. 
Calling this method to get the group response decision +// +// Logic: +// - If isMentioned → always respond +// - If mention_only configured and not mentioned → ignore +// - If prefixes configured → respond if content starts with any prefix (strip it) +// - If prefixes configured but no match and not mentioned → ignore +// - Otherwise (no group_trigger configured) → respond to all (permissive default) +func (c *BaseChannel) ShouldRespondInGroup(isMentioned bool, content string) (bool, string) { + gt := c.groupTrigger + + // Mentioned → always respond + if isMentioned { + return true, strings.TrimSpace(content) + } + + // mention_only → require mention + if gt.MentionOnly { + return false, content + } + + // Prefix matching + if len(gt.Prefixes) > 0 { + for _, prefix := range gt.Prefixes { + if prefix != "" && strings.HasPrefix(content, prefix) { + return true, strings.TrimSpace(strings.TrimPrefix(content, prefix)) + } + } + // Prefixes configured but none matched and not mentioned → ignore + return false, content + } + + // No group_trigger configured → permissive (respond to all) + return true, strings.TrimSpace(content) } func (c *BaseChannel) Name() string { return c.name } +func (c *BaseChannel) ReasoningChannelID() string { + return c.reasoningChannelID +} + func (c *BaseChannel) IsRunning() bool { - return c.running + return c.running.Load() } func (c *BaseChannel) IsAllowed(senderID string) bool { @@ -81,23 +208,130 @@ func (c *BaseChannel) IsAllowed(senderID string) bool { return false } -func (c *BaseChannel) HandleMessage(senderID, chatID, content string, media []string, metadata map[string]string) { - if !c.IsAllowed(senderID) { - return +// IsAllowedSender checks whether a structured SenderInfo is permitted by the allow-list. +// It delegates to identity.MatchAllowed for each entry, providing unified matching +// across all legacy formats and the new canonical "platform:id" format. 
+func (c *BaseChannel) IsAllowedSender(sender bus.SenderInfo) bool { + if len(c.allowList) == 0 { + return true } + for _, allowed := range c.allowList { + if identity.MatchAllowed(sender, allowed) { + return true + } + } + + return false +} + +func (c *BaseChannel) HandleMessage( + ctx context.Context, + peer bus.Peer, + messageID, senderID, chatID, content string, + media []string, + metadata map[string]string, + senderOpts ...bus.SenderInfo, +) { + // Use SenderInfo-based allow check when available, else fall back to string + var sender bus.SenderInfo + if len(senderOpts) > 0 { + sender = senderOpts[0] + } + if sender.CanonicalID != "" || sender.PlatformID != "" { + if !c.IsAllowedSender(sender) { + return + } + } else { + if !c.IsAllowed(senderID) { + return + } + } + + // Set SenderID to canonical if available, otherwise keep the raw senderID + resolvedSenderID := senderID + if sender.CanonicalID != "" { + resolvedSenderID = sender.CanonicalID + } + + scope := BuildMediaScope(c.name, chatID, messageID) + msg := bus.InboundMessage{ - Channel: c.name, - SenderID: senderID, - ChatID: chatID, - Content: content, - Media: media, - Metadata: metadata, + Channel: c.name, + SenderID: resolvedSenderID, + Sender: sender, + ChatID: chatID, + Content: content, + Media: media, + Peer: peer, + MessageID: messageID, + MediaScope: scope, + Metadata: metadata, } - c.bus.PublishInbound(msg) + // Auto-trigger typing indicator, message reaction, and placeholder before publishing. + // Each capability is independent — all three may fire for the same message. 
+ if c.owner != nil && c.placeholderRecorder != nil { + // Typing — independent pipeline + if tc, ok := c.owner.(TypingCapable); ok { + if stop, err := tc.StartTyping(ctx, chatID); err == nil { + c.placeholderRecorder.RecordTypingStop(c.name, chatID, stop) + } + } + // Reaction — independent pipeline + if rc, ok := c.owner.(ReactionCapable); ok && messageID != "" { + if undo, err := rc.ReactToMessage(ctx, chatID, messageID); err == nil { + c.placeholderRecorder.RecordReactionUndo(c.name, chatID, undo) + } + } + // Placeholder — independent pipeline + if pc, ok := c.owner.(PlaceholderCapable); ok { + if phID, err := pc.SendPlaceholder(ctx, chatID); err == nil && phID != "" { + c.placeholderRecorder.RecordPlaceholder(c.name, chatID, phID) + } + } + } + + if err := c.bus.PublishInbound(ctx, msg); err != nil { + logger.ErrorCF("channels", "Failed to publish inbound message", map[string]any{ + "channel": c.name, + "chat_id": chatID, + "error": err.Error(), + }) + } } -func (c *BaseChannel) setRunning(running bool) { - c.running = running +func (c *BaseChannel) SetRunning(running bool) { + c.running.Store(running) +} + +// SetMediaStore injects a MediaStore into the channel. +func (c *BaseChannel) SetMediaStore(s media.MediaStore) { c.mediaStore = s } + +// GetMediaStore returns the injected MediaStore (may be nil). +func (c *BaseChannel) GetMediaStore() media.MediaStore { return c.mediaStore } + +// SetPlaceholderRecorder injects a PlaceholderRecorder into the channel. +func (c *BaseChannel) SetPlaceholderRecorder(r PlaceholderRecorder) { + c.placeholderRecorder = r +} + +// GetPlaceholderRecorder returns the injected PlaceholderRecorder (may be nil). +func (c *BaseChannel) GetPlaceholderRecorder() PlaceholderRecorder { + return c.placeholderRecorder +} + +// SetOwner injects the concrete channel that embeds this BaseChannel. +// This allows HandleMessage to auto-trigger TypingCapable / ReactionCapable / PlaceholderCapable. 
+func (c *BaseChannel) SetOwner(ch Channel) { + c.owner = ch +} + +// BuildMediaScope constructs a scope key for media lifecycle tracking. +func BuildMediaScope(channel, chatID, messageID string) string { + id := messageID + if id == "" { + id = uniqueID() + } + return channel + ":" + chatID + ":" + id } diff --git a/pkg/channels/base_test.go b/pkg/channels/base_test.go index 78c6d1d66..6132b8bf9 100644 --- a/pkg/channels/base_test.go +++ b/pkg/channels/base_test.go @@ -1,6 +1,11 @@ package channels -import "testing" +import ( + "testing" + + "github.com/sipeed/picoclaw/pkg/bus" + "github.com/sipeed/picoclaw/pkg/config" +) func TestBaseChannelIsAllowed(t *testing.T) { tests := []struct { @@ -50,3 +55,211 @@ func TestBaseChannelIsAllowed(t *testing.T) { }) } } + +func TestShouldRespondInGroup(t *testing.T) { + tests := []struct { + name string + gt config.GroupTriggerConfig + isMentioned bool + content string + wantRespond bool + wantContent string + }{ + { + name: "no config - permissive default", + gt: config.GroupTriggerConfig{}, + isMentioned: false, + content: "hello world", + wantRespond: true, + wantContent: "hello world", + }, + { + name: "no config - mentioned", + gt: config.GroupTriggerConfig{}, + isMentioned: true, + content: "hello world", + wantRespond: true, + wantContent: "hello world", + }, + { + name: "mention_only - not mentioned", + gt: config.GroupTriggerConfig{MentionOnly: true}, + isMentioned: false, + content: "hello world", + wantRespond: false, + wantContent: "hello world", + }, + { + name: "mention_only - mentioned", + gt: config.GroupTriggerConfig{MentionOnly: true}, + isMentioned: true, + content: "hello world", + wantRespond: true, + wantContent: "hello world", + }, + { + name: "prefix match", + gt: config.GroupTriggerConfig{Prefixes: []string{"/ask"}}, + isMentioned: false, + content: "/ask hello", + wantRespond: true, + wantContent: "hello", + }, + { + name: "prefix no match - not mentioned", + gt: config.GroupTriggerConfig{Prefixes: 
[]string{"/ask"}}, + isMentioned: false, + content: "hello world", + wantRespond: false, + wantContent: "hello world", + }, + { + name: "prefix no match - but mentioned", + gt: config.GroupTriggerConfig{Prefixes: []string{"/ask"}}, + isMentioned: true, + content: "hello world", + wantRespond: true, + wantContent: "hello world", + }, + { + name: "multiple prefixes - second matches", + gt: config.GroupTriggerConfig{Prefixes: []string{"/ask", "/bot"}}, + isMentioned: false, + content: "/bot help me", + wantRespond: true, + wantContent: "help me", + }, + { + name: "mention_only with prefixes - mentioned overrides", + gt: config.GroupTriggerConfig{MentionOnly: true, Prefixes: []string{"/ask"}}, + isMentioned: true, + content: "hello", + wantRespond: true, + wantContent: "hello", + }, + { + name: "mention_only with prefixes - not mentioned, no prefix", + gt: config.GroupTriggerConfig{MentionOnly: true, Prefixes: []string{"/ask"}}, + isMentioned: false, + content: "hello", + wantRespond: false, + wantContent: "hello", + }, + { + name: "empty prefix in list is skipped", + gt: config.GroupTriggerConfig{Prefixes: []string{"", "/ask"}}, + isMentioned: false, + content: "/ask test", + wantRespond: true, + wantContent: "test", + }, + { + name: "prefix strips leading whitespace after prefix", + gt: config.GroupTriggerConfig{Prefixes: []string{"/ask "}}, + isMentioned: false, + content: "/ask hello", + wantRespond: true, + wantContent: "hello", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ch := NewBaseChannel("test", nil, nil, nil, WithGroupTrigger(tt.gt)) + gotRespond, gotContent := ch.ShouldRespondInGroup(tt.isMentioned, tt.content) + if gotRespond != tt.wantRespond { + t.Errorf("ShouldRespondInGroup() respond = %v, want %v", gotRespond, tt.wantRespond) + } + if gotContent != tt.wantContent { + t.Errorf("ShouldRespondInGroup() content = %q, want %q", gotContent, tt.wantContent) + } + }) + } +} + +func TestIsAllowedSender(t *testing.T) { + 
tests := []struct { + name string + allowList []string + sender bus.SenderInfo + want bool + }{ + { + name: "empty allowlist allows all", + allowList: nil, + sender: bus.SenderInfo{PlatformID: "anyone"}, + want: true, + }, + { + name: "numeric ID matches PlatformID", + allowList: []string{"123456"}, + sender: bus.SenderInfo{ + Platform: "telegram", + PlatformID: "123456", + CanonicalID: "telegram:123456", + }, + want: true, + }, + { + name: "canonical format matches", + allowList: []string{"telegram:123456"}, + sender: bus.SenderInfo{ + Platform: "telegram", + PlatformID: "123456", + CanonicalID: "telegram:123456", + }, + want: true, + }, + { + name: "canonical format wrong platform", + allowList: []string{"discord:123456"}, + sender: bus.SenderInfo{ + Platform: "telegram", + PlatformID: "123456", + CanonicalID: "telegram:123456", + }, + want: false, + }, + { + name: "@username matches", + allowList: []string{"@alice"}, + sender: bus.SenderInfo{ + Platform: "telegram", + PlatformID: "123456", + CanonicalID: "telegram:123456", + Username: "alice", + }, + want: true, + }, + { + name: "compound id|username matches by ID", + allowList: []string{"123456|alice"}, + sender: bus.SenderInfo{ + Platform: "telegram", + PlatformID: "123456", + CanonicalID: "telegram:123456", + Username: "alice", + }, + want: true, + }, + { + name: "non matching sender denied", + allowList: []string{"654321"}, + sender: bus.SenderInfo{ + Platform: "telegram", + PlatformID: "123456", + CanonicalID: "telegram:123456", + }, + want: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ch := NewBaseChannel("test", nil, nil, tt.allowList) + if got := ch.IsAllowedSender(tt.sender); got != tt.want { + t.Fatalf("IsAllowedSender(%+v) = %v, want %v", tt.sender, got, tt.want) + } + }) + } +} diff --git a/pkg/channels/dingtalk.go b/pkg/channels/dingtalk/dingtalk.go similarity index 83% rename from pkg/channels/dingtalk.go rename to pkg/channels/dingtalk/dingtalk.go index 
662fba3b7..8642ad362 100644 --- a/pkg/channels/dingtalk.go +++ b/pkg/channels/dingtalk/dingtalk.go @@ -1,7 +1,7 @@ // PicoClaw - Ultra-lightweight personal AI agent // DingTalk channel implementation using Stream Mode -package channels +package dingtalk import ( "context" @@ -12,7 +12,9 @@ import ( "github.com/open-dingtalk/dingtalk-stream-sdk-go/client" "github.com/sipeed/picoclaw/pkg/bus" + "github.com/sipeed/picoclaw/pkg/channels" "github.com/sipeed/picoclaw/pkg/config" + "github.com/sipeed/picoclaw/pkg/identity" "github.com/sipeed/picoclaw/pkg/logger" "github.com/sipeed/picoclaw/pkg/utils" ) @@ -20,7 +22,7 @@ import ( // DingTalkChannel implements the Channel interface for DingTalk (钉钉) // It uses WebSocket for receiving messages via stream mode and API for sending type DingTalkChannel struct { - *BaseChannel + *channels.BaseChannel config config.DingTalkConfig clientID string clientSecret string @@ -37,7 +39,11 @@ func NewDingTalkChannel(cfg config.DingTalkConfig, messageBus *bus.MessageBus) ( return nil, fmt.Errorf("dingtalk client_id and client_secret are required") } - base := NewBaseChannel("dingtalk", cfg, messageBus, cfg.AllowFrom) + base := channels.NewBaseChannel("dingtalk", cfg, messageBus, cfg.AllowFrom, + channels.WithMaxMessageLength(20000), + channels.WithGroupTrigger(cfg.GroupTrigger), + channels.WithReasoningChannelID(cfg.ReasoningChannelID), + ) return &DingTalkChannel{ BaseChannel: base, @@ -70,7 +76,7 @@ func (c *DingTalkChannel) Start(ctx context.Context) error { return fmt.Errorf("failed to start stream client: %w", err) } - c.setRunning(true) + c.SetRunning(true) logger.InfoC("dingtalk", "DingTalk channel started (Stream Mode)") return nil } @@ -87,7 +93,7 @@ func (c *DingTalkChannel) Stop(ctx context.Context) error { c.streamClient.Close() } - c.setRunning(false) + c.SetRunning(false) logger.InfoC("dingtalk", "DingTalk channel stopped") return nil } @@ -95,7 +101,7 @@ func (c *DingTalkChannel) Stop(ctx context.Context) error { // Send 
sends a message to DingTalk via the chatbot reply API func (c *DingTalkChannel) Send(ctx context.Context, msg bus.OutboundMessage) error { if !c.IsRunning() { - return fmt.Errorf("dingtalk channel not running") + return channels.ErrNotRunning } // Get session webhook from storage @@ -159,12 +165,17 @@ func (c *DingTalkChannel) onChatBotMessageReceived( "session_webhook": data.SessionWebhook, } + var peer bus.Peer if data.ConversationType == "1" { - metadata["peer_kind"] = "direct" - metadata["peer_id"] = senderID + peer = bus.Peer{Kind: "direct", ID: senderID} } else { - metadata["peer_kind"] = "group" - metadata["peer_id"] = data.ConversationId + peer = bus.Peer{Kind: "group", ID: data.ConversationId} + // In group chats, apply unified group trigger filtering + respond, cleaned := c.ShouldRespondInGroup(false, content) + if !respond { + return nil, nil + } + content = cleaned } logger.DebugCF("dingtalk", "Received message", map[string]any{ @@ -173,8 +184,20 @@ func (c *DingTalkChannel) onChatBotMessageReceived( "preview": utils.Truncate(content, 50), }) + // Build sender info + sender := bus.SenderInfo{ + Platform: "dingtalk", + PlatformID: senderID, + CanonicalID: identity.BuildCanonicalID("dingtalk", senderID), + DisplayName: senderNick, + } + + if !c.IsAllowedSender(sender) { + return nil, nil + } + // Handle the message through the base channel - c.HandleMessage(senderID, chatID, content, nil, metadata) + c.HandleMessage(ctx, peer, "", senderID, chatID, content, nil, metadata, sender) // Return nil to indicate we've handled the message asynchronously // The response will be sent through the message bus @@ -197,7 +220,7 @@ func (c *DingTalkChannel) SendDirectReply(ctx context.Context, sessionWebhook, c contentBytes, ) if err != nil { - return fmt.Errorf("failed to send reply: %w", err) + return fmt.Errorf("dingtalk send: %w", channels.ErrTemporary) } return nil diff --git a/pkg/channels/dingtalk/init.go b/pkg/channels/dingtalk/init.go new file mode 100644 index 
000000000..5f49bce8c --- /dev/null +++ b/pkg/channels/dingtalk/init.go @@ -0,0 +1,13 @@ +package dingtalk + +import ( + "github.com/sipeed/picoclaw/pkg/bus" + "github.com/sipeed/picoclaw/pkg/channels" + "github.com/sipeed/picoclaw/pkg/config" +) + +func init() { + channels.RegisterFactory("dingtalk", func(cfg *config.Config, b *bus.MessageBus) (channels.Channel, error) { + return NewDingTalkChannel(cfg.Channels.DingTalk, b) + }) +} diff --git a/pkg/channels/discord.go b/pkg/channels/discord/discord.go similarity index 52% rename from pkg/channels/discord.go rename to pkg/channels/discord/discord.go index f6faa3373..cd6a2560f 100644 --- a/pkg/channels/discord.go +++ b/pkg/channels/discord/discord.go @@ -1,4 +1,4 @@ -package channels +package discord import ( "context" @@ -11,26 +11,27 @@ import ( "github.com/bwmarrin/discordgo" "github.com/sipeed/picoclaw/pkg/bus" + "github.com/sipeed/picoclaw/pkg/channels" "github.com/sipeed/picoclaw/pkg/config" + "github.com/sipeed/picoclaw/pkg/identity" "github.com/sipeed/picoclaw/pkg/logger" + "github.com/sipeed/picoclaw/pkg/media" "github.com/sipeed/picoclaw/pkg/utils" - "github.com/sipeed/picoclaw/pkg/voice" ) const ( - transcriptionTimeout = 30 * time.Second - sendTimeout = 10 * time.Second + sendTimeout = 10 * time.Second ) type DiscordChannel struct { - *BaseChannel - session *discordgo.Session - config config.DiscordConfig - transcriber *voice.GroqTranscriber - ctx context.Context - typingMu sync.Mutex - typingStop map[string]chan struct{} // chatID → stop signal - botUserID string // stored for mention checking + *channels.BaseChannel + session *discordgo.Session + config config.DiscordConfig + ctx context.Context + cancel context.CancelFunc + typingMu sync.Mutex + typingStop map[string]chan struct{} // chatID → stop signal + botUserID string // stored for mention checking } func NewDiscordChannel(cfg config.DiscordConfig, bus *bus.MessageBus) (*DiscordChannel, error) { @@ -39,33 +40,25 @@ func NewDiscordChannel(cfg 
config.DiscordConfig, bus *bus.MessageBus) (*DiscordC return nil, fmt.Errorf("failed to create discord session: %w", err) } - base := NewBaseChannel("discord", cfg, bus, cfg.AllowFrom) + base := channels.NewBaseChannel("discord", cfg, bus, cfg.AllowFrom, + channels.WithMaxMessageLength(2000), + channels.WithGroupTrigger(cfg.GroupTrigger), + channels.WithReasoningChannelID(cfg.ReasoningChannelID), + ) return &DiscordChannel{ BaseChannel: base, session: session, config: cfg, - transcriber: nil, ctx: context.Background(), typingStop: make(map[string]chan struct{}), }, nil } -func (c *DiscordChannel) SetTranscriber(transcriber *voice.GroqTranscriber) { - c.transcriber = transcriber -} - -func (c *DiscordChannel) getContext() context.Context { - if c.ctx == nil { - return context.Background() - } - return c.ctx -} - func (c *DiscordChannel) Start(ctx context.Context) error { logger.InfoC("discord", "Starting Discord bot") - c.ctx = ctx + c.ctx, c.cancel = context.WithCancel(ctx) // Get bot user ID before opening session to avoid race condition botUser, err := c.session.User("@me") @@ -80,7 +73,7 @@ func (c *DiscordChannel) Start(ctx context.Context) error { return fmt.Errorf("failed to open discord session: %w", err) } - c.setRunning(true) + c.SetRunning(true) logger.InfoCF("discord", "Discord bot connected", map[string]any{ "username": botUser.Username, @@ -92,7 +85,7 @@ func (c *DiscordChannel) Start(ctx context.Context) error { func (c *DiscordChannel) Stop(ctx context.Context) error { logger.InfoC("discord", "Stopping Discord bot") - c.setRunning(false) + c.SetRunning(false) // Stop all typing goroutines before closing session c.typingMu.Lock() @@ -102,6 +95,11 @@ func (c *DiscordChannel) Stop(ctx context.Context) error { } c.typingMu.Unlock() + // Cancel our context so typing goroutines using c.ctx.Done() exit + if c.cancel != nil { + c.cancel() + } + if err := c.session.Close(); err != nil { return fmt.Errorf("failed to close discord session: %w", err) } @@ 
-110,10 +108,8 @@ func (c *DiscordChannel) Stop(ctx context.Context) error { } func (c *DiscordChannel) Send(ctx context.Context, msg bus.OutboundMessage) error { - c.stopTyping(msg.ChatID) - if !c.IsRunning() { - return fmt.Errorf("discord bot not running") + return channels.ErrNotRunning } channelID := msg.ChatID @@ -121,20 +117,133 @@ func (c *DiscordChannel) Send(ctx context.Context, msg bus.OutboundMessage) erro return fmt.Errorf("channel ID is empty") } - runes := []rune(msg.Content) - if len(runes) == 0 { + if len([]rune(msg.Content)) == 0 { return nil } - chunks := utils.SplitMessage(msg.Content, 2000) // Split messages into chunks, Discord length limit: 2000 chars + return c.sendChunk(ctx, channelID, msg.Content) +} - for _, chunk := range chunks { - if err := c.sendChunk(ctx, channelID, chunk); err != nil { - return err +// SendMedia implements the channels.MediaSender interface. +func (c *DiscordChannel) SendMedia(ctx context.Context, msg bus.OutboundMediaMessage) error { + if !c.IsRunning() { + return channels.ErrNotRunning + } + + channelID := msg.ChatID + if channelID == "" { + return fmt.Errorf("channel ID is empty") + } + + store := c.GetMediaStore() + if store == nil { + return fmt.Errorf("no media store available: %w", channels.ErrSendFailed) + } + + // Collect all files into a single ChannelMessageSendComplex call + files := make([]*discordgo.File, 0, len(msg.Parts)) + var caption string + + for _, part := range msg.Parts { + localPath, err := store.Resolve(part.Ref) + if err != nil { + logger.ErrorCF("discord", "Failed to resolve media ref", map[string]any{ + "ref": part.Ref, + "error": err.Error(), + }) + continue + } + + file, err := os.Open(localPath) + if err != nil { + logger.ErrorCF("discord", "Failed to open media file", map[string]any{ + "path": localPath, + "error": err.Error(), + }) + continue + } + // Note: discordgo reads from the Reader and we can't close it before send + + filename := part.Filename + if filename == "" { + filename 
= "file" + } + + files = append(files, &discordgo.File{ + Name: filename, + ContentType: part.ContentType, + Reader: file, + }) + + if part.Caption != "" && caption == "" { + caption = part.Caption } } - return nil + if len(files) == 0 { + return nil + } + + sendCtx, cancel := context.WithTimeout(ctx, sendTimeout) + defer cancel() + + done := make(chan error, 1) + go func() { + _, err := c.session.ChannelMessageSendComplex(channelID, &discordgo.MessageSend{ + Content: caption, + Files: files, + }) + done <- err + }() + + select { + case err := <-done: + // Close all file readers + for _, f := range files { + if closer, ok := f.Reader.(*os.File); ok { + closer.Close() + } + } + if err != nil { + return fmt.Errorf("discord send media: %w", channels.ErrTemporary) + } + return nil + case <-sendCtx.Done(): + // Close all file readers + for _, f := range files { + if closer, ok := f.Reader.(*os.File); ok { + closer.Close() + } + } + return sendCtx.Err() + } +} + +// EditMessage implements channels.MessageEditor. +func (c *DiscordChannel) EditMessage(ctx context.Context, chatID string, messageID string, content string) error { + _, err := c.session.ChannelMessageEdit(chatID, messageID, content) + return err +} + +// SendPlaceholder implements channels.PlaceholderCapable. +// It sends a placeholder message that will later be edited to the actual +// response via EditMessage (channels.MessageEditor). +func (c *DiscordChannel) SendPlaceholder(ctx context.Context, chatID string) (string, error) { + if !c.config.Placeholder.Enabled { + return "", nil + } + + text := c.config.Placeholder.Text + if text == "" { + text = "Thinking... 
💭" + } + + msg, err := c.session.ChannelMessageSend(chatID, text) + if err != nil { + return "", err + } + + return msg.ID, nil } func (c *DiscordChannel) sendChunk(ctx context.Context, channelID, content string) error { @@ -151,11 +260,11 @@ func (c *DiscordChannel) sendChunk(ctx context.Context, channelID, content strin select { case err := <-done: if err != nil { - return fmt.Errorf("failed to send discord message: %w", err) + return fmt.Errorf("discord send: %w", channels.ErrTemporary) } return nil case <-sendCtx.Done(): - return fmt.Errorf("send message timeout: %w", sendCtx.Err()) + return sendCtx.Err() } } @@ -176,17 +285,32 @@ func (c *DiscordChannel) handleMessage(s *discordgo.Session, m *discordgo.Messag return } - // Check allowlist first to avoid downloading attachments and transcribing for rejected users - if !c.IsAllowed(m.Author.ID) { + // Check allowlist first to avoid downloading attachments for rejected users + sender := bus.SenderInfo{ + Platform: "discord", + PlatformID: m.Author.ID, + CanonicalID: identity.BuildCanonicalID("discord", m.Author.ID), + Username: m.Author.Username, + } + // Build display name + displayName := m.Author.Username + if m.Author.Discriminator != "" && m.Author.Discriminator != "0" { + displayName += "#" + m.Author.Discriminator + } + sender.DisplayName = displayName + + if !c.IsAllowedSender(sender) { logger.DebugCF("discord", "Message rejected by allowlist", map[string]any{ "user_id": m.Author.ID, }) return } - // If configured to only respond to mentions, check if bot is mentioned - // Skip this check for DMs (GuildID is empty) - DMs should always be responded to - if c.config.MentionOnly && m.GuildID != "" { + content := m.Content + + // In guild (group) channels, apply unified group trigger filtering + // DMs (GuildID is empty) always get a response + if m.GuildID != "" { isMentioned := false for _, mention := range m.Mentions { if mention.ID == c.botUserID { @@ -194,36 +318,39 @@ func (c *DiscordChannel) 
handleMessage(s *discordgo.Session, m *discordgo.Messag break } } - if !isMentioned { - logger.DebugCF("discord", "Message ignored - bot not mentioned", map[string]any{ + content = c.stripBotMention(content) + respond, cleaned := c.ShouldRespondInGroup(isMentioned, content) + if !respond { + logger.DebugCF("discord", "Group message ignored by group trigger", map[string]any{ "user_id": m.Author.ID, }) return } + content = cleaned + } else { + // DMs: just strip bot mention without filtering + content = c.stripBotMention(content) } senderID := m.Author.ID - senderName := m.Author.Username - if m.Author.Discriminator != "" && m.Author.Discriminator != "0" { - senderName += "#" + m.Author.Discriminator - } - content := m.Content - content = c.stripBotMention(content) mediaPaths := make([]string, 0, len(m.Attachments)) - localFiles := make([]string, 0, len(m.Attachments)) - // Ensure temp files are cleaned up when function returns - defer func() { - for _, file := range localFiles { - if err := os.Remove(file); err != nil { - logger.DebugCF("discord", "Failed to cleanup temp file", map[string]any{ - "file": file, - "error": err.Error(), - }) + scope := channels.BuildMediaScope("discord", m.ChannelID, m.ID) + + // Helper to register a local file with the media store + storeMedia := func(localPath, filename string) string { + if store := c.GetMediaStore(); store != nil { + ref, err := store.Store(localPath, media.MediaMeta{ + Filename: filename, + Source: "discord", + }, scope) + if err == nil { + return ref } } - }() + return localPath // fallback + } for _, attachment := range m.Attachments { isAudio := utils.IsAudioFile(attachment.Filename, attachment.ContentType) @@ -231,30 +358,8 @@ func (c *DiscordChannel) handleMessage(s *discordgo.Session, m *discordgo.Messag if isAudio { localPath := c.downloadAttachment(attachment.URL, attachment.Filename) if localPath != "" { - localFiles = append(localFiles, localPath) - - var transcribedText string - if c.transcriber != nil 
&& c.transcriber.IsAvailable() { - ctx, cancel := context.WithTimeout(c.getContext(), transcriptionTimeout) - result, err := c.transcriber.Transcribe(ctx, localPath) - cancel() // Release context resources immediately to avoid leaks in for loop - - if err != nil { - logger.ErrorCF("discord", "Voice transcription failed", map[string]any{ - "error": err.Error(), - }) - transcribedText = fmt.Sprintf("[audio: %s (transcription failed)]", attachment.Filename) - } else { - transcribedText = fmt.Sprintf("[audio transcription: %s]", result.Text) - logger.DebugCF("discord", "Audio transcribed successfully", map[string]any{ - "text": result.Text, - }) - } - } else { - transcribedText = fmt.Sprintf("[audio: %s]", attachment.Filename) - } - - content = appendContent(content, transcribedText) + mediaPaths = append(mediaPaths, storeMedia(localPath, attachment.Filename)) + content = appendContent(content, fmt.Sprintf("[audio: %s]", attachment.Filename)) } else { logger.WarnCF("discord", "Failed to download audio attachment", map[string]any{ "url": attachment.URL, @@ -277,11 +382,8 @@ func (c *DiscordChannel) handleMessage(s *discordgo.Session, m *discordgo.Messag content = "[media only]" } - // Start typing after all early returns — guaranteed to have a matching Send() - c.startTyping(m.ChannelID) - logger.DebugCF("discord", "Received message", map[string]any{ - "sender_name": senderName, + "sender_name": sender.DisplayName, "sender_id": senderID, "preview": utils.Truncate(content, 50), }) @@ -293,19 +395,18 @@ func (c *DiscordChannel) handleMessage(s *discordgo.Session, m *discordgo.Messag peerID = senderID } + peer := bus.Peer{Kind: peerKind, ID: peerID} + metadata := map[string]string{ - "message_id": m.ID, "user_id": senderID, "username": m.Author.Username, - "display_name": senderName, + "display_name": sender.DisplayName, "guild_id": m.GuildID, "channel_id": m.ChannelID, "is_dm": fmt.Sprintf("%t", m.GuildID == ""), - "peer_kind": peerKind, - "peer_id": peerID, } - 
c.HandleMessage(senderID, m.ChannelID, content, mediaPaths, metadata) + c.HandleMessage(c.ctx, peer, m.ID, senderID, m.ChannelID, content, mediaPaths, metadata, sender) } // startTyping starts a continuous typing indicator loop for the given chatID. @@ -354,6 +455,13 @@ func (c *DiscordChannel) stopTyping(chatID string) { } } +// StartTyping implements channels.TypingCapable. +// It starts a continuous typing indicator and returns an idempotent stop function. +func (c *DiscordChannel) StartTyping(ctx context.Context, chatID string) (func(), error) { + c.startTyping(chatID) + return func() { c.stopTyping(chatID) }, nil +} + func (c *DiscordChannel) downloadAttachment(url, filename string) string { return utils.DownloadFile(url, filename, utils.DownloadOptions{ LoggerPrefix: "discord", diff --git a/pkg/channels/discord/init.go b/pkg/channels/discord/init.go new file mode 100644 index 000000000..15a539804 --- /dev/null +++ b/pkg/channels/discord/init.go @@ -0,0 +1,13 @@ +package discord + +import ( + "github.com/sipeed/picoclaw/pkg/bus" + "github.com/sipeed/picoclaw/pkg/channels" + "github.com/sipeed/picoclaw/pkg/config" +) + +func init() { + channels.RegisterFactory("discord", func(cfg *config.Config, b *bus.MessageBus) (channels.Channel, error) { + return NewDiscordChannel(cfg.Channels.Discord, b) + }) +} diff --git a/pkg/channels/errors.go b/pkg/channels/errors.go new file mode 100644 index 000000000..09ee88b3f --- /dev/null +++ b/pkg/channels/errors.go @@ -0,0 +1,21 @@ +package channels + +import "errors" + +var ( + // ErrNotRunning indicates the channel is not running. + // Manager will not retry. + ErrNotRunning = errors.New("channel not running") + + // ErrRateLimit indicates the platform returned a rate-limit response (e.g. HTTP 429). + // Manager will wait a fixed delay and retry. + ErrRateLimit = errors.New("rate limited") + + // ErrTemporary indicates a transient failure (e.g. network timeout, 5xx). + // Manager will use exponential backoff and retry. 
+ ErrTemporary = errors.New("temporary failure") + + // ErrSendFailed indicates a permanent failure (e.g. invalid chat ID, 4xx non-429). + // Manager will not retry. + ErrSendFailed = errors.New("send failed") +) diff --git a/pkg/channels/errors_test.go b/pkg/channels/errors_test.go new file mode 100644 index 000000000..e5592345a --- /dev/null +++ b/pkg/channels/errors_test.go @@ -0,0 +1,56 @@ +package channels + +import ( + "errors" + "fmt" + "testing" +) + +func TestErrorsIs(t *testing.T) { + wrapped := fmt.Errorf("telegram API: %w", ErrRateLimit) + if !errors.Is(wrapped, ErrRateLimit) { + t.Error("wrapped ErrRateLimit should match") + } + if errors.Is(wrapped, ErrTemporary) { + t.Error("wrapped ErrRateLimit should not match ErrTemporary") + } +} + +func TestErrorsIsAllTypes(t *testing.T) { + sentinels := []error{ErrNotRunning, ErrRateLimit, ErrTemporary, ErrSendFailed} + + for _, sentinel := range sentinels { + wrapped := fmt.Errorf("context: %w", sentinel) + if !errors.Is(wrapped, sentinel) { + t.Errorf("wrapped %v should match itself", sentinel) + } + + // Verify it doesn't match other sentinel errors + for _, other := range sentinels { + if other == sentinel { + continue + } + if errors.Is(wrapped, other) { + t.Errorf("wrapped %v should not match %v", sentinel, other) + } + } + } +} + +func TestErrorMessages(t *testing.T) { + tests := []struct { + err error + want string + }{ + {ErrNotRunning, "channel not running"}, + {ErrRateLimit, "rate limited"}, + {ErrTemporary, "temporary failure"}, + {ErrSendFailed, "send failed"}, + } + + for _, tt := range tests { + if got := tt.err.Error(); got != tt.want { + t.Errorf("error message = %q, want %q", got, tt.want) + } + } +} diff --git a/pkg/channels/errutil.go b/pkg/channels/errutil.go new file mode 100644 index 000000000..319e3c980 --- /dev/null +++ b/pkg/channels/errutil.go @@ -0,0 +1,30 @@ +package channels + +import ( + "fmt" + "net/http" +) + +// ClassifySendError wraps a raw error with the appropriate sentinel 
based on +// an HTTP status code. Channels that perform HTTP API calls should use this +// in their Send path. +func ClassifySendError(statusCode int, rawErr error) error { + switch { + case statusCode == http.StatusTooManyRequests: + return fmt.Errorf("%w: %v", ErrRateLimit, rawErr) + case statusCode >= 500: + return fmt.Errorf("%w: %v", ErrTemporary, rawErr) + case statusCode >= 400: + return fmt.Errorf("%w: %v", ErrSendFailed, rawErr) + default: + return rawErr + } +} + +// ClassifyNetError wraps a network/timeout error as ErrTemporary. +func ClassifyNetError(err error) error { + if err == nil { + return nil + } + return fmt.Errorf("%w: %v", ErrTemporary, err) +} diff --git a/pkg/channels/errutil_test.go b/pkg/channels/errutil_test.go new file mode 100644 index 000000000..e3d35f65b --- /dev/null +++ b/pkg/channels/errutil_test.go @@ -0,0 +1,97 @@ +package channels + +import ( + "errors" + "fmt" + "testing" +) + +func TestClassifySendError(t *testing.T) { + raw := fmt.Errorf("some API error") + + tests := []struct { + name string + statusCode int + wantIs error + wantNil bool + }{ + {"429 -> ErrRateLimit", 429, ErrRateLimit, false}, + {"500 -> ErrTemporary", 500, ErrTemporary, false}, + {"502 -> ErrTemporary", 502, ErrTemporary, false}, + {"503 -> ErrTemporary", 503, ErrTemporary, false}, + {"400 -> ErrSendFailed", 400, ErrSendFailed, false}, + {"403 -> ErrSendFailed", 403, ErrSendFailed, false}, + {"404 -> ErrSendFailed", 404, ErrSendFailed, false}, + {"200 -> raw error", 200, nil, false}, + {"201 -> raw error", 201, nil, false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := ClassifySendError(tt.statusCode, raw) + if err == nil { + t.Fatal("expected non-nil error") + } + if tt.wantIs != nil { + if !errors.Is(err, tt.wantIs) { + t.Errorf("errors.Is(err, %v) = false, want true; err = %v", tt.wantIs, err) + } + } else { + // Should return the raw error unchanged + if err != raw { + t.Errorf("expected raw error to be returned 
unchanged for status %d, got %v", tt.statusCode, err) + } + } + }) + } +} + +func TestClassifySendErrorNoFalsePositive(t *testing.T) { + raw := fmt.Errorf("some error") + + // 429 should NOT match ErrTemporary or ErrSendFailed + err := ClassifySendError(429, raw) + if errors.Is(err, ErrTemporary) { + t.Error("429 should not match ErrTemporary") + } + if errors.Is(err, ErrSendFailed) { + t.Error("429 should not match ErrSendFailed") + } + + // 500 should NOT match ErrRateLimit or ErrSendFailed + err = ClassifySendError(500, raw) + if errors.Is(err, ErrRateLimit) { + t.Error("500 should not match ErrRateLimit") + } + if errors.Is(err, ErrSendFailed) { + t.Error("500 should not match ErrSendFailed") + } + + // 400 should NOT match ErrRateLimit or ErrTemporary + err = ClassifySendError(400, raw) + if errors.Is(err, ErrRateLimit) { + t.Error("400 should not match ErrRateLimit") + } + if errors.Is(err, ErrTemporary) { + t.Error("400 should not match ErrTemporary") + } +} + +func TestClassifyNetError(t *testing.T) { + t.Run("nil error returns nil", func(t *testing.T) { + if err := ClassifyNetError(nil); err != nil { + t.Errorf("expected nil, got %v", err) + } + }) + + t.Run("non-nil error wraps as ErrTemporary", func(t *testing.T) { + raw := fmt.Errorf("connection refused") + err := ClassifyNetError(raw) + if err == nil { + t.Fatal("expected non-nil error") + } + if !errors.Is(err, ErrTemporary) { + t.Errorf("errors.Is(err, ErrTemporary) = false, want true; err = %v", err) + } + }) +} diff --git a/pkg/channels/feishu/common.go b/pkg/channels/feishu/common.go new file mode 100644 index 000000000..e8a057741 --- /dev/null +++ b/pkg/channels/feishu/common.go @@ -0,0 +1,9 @@ +package feishu + +// stringValue safely dereferences a *string pointer. 
+func stringValue(v *string) string { + if v == nil { + return "" + } + return *v +} diff --git a/pkg/channels/feishu_32.go b/pkg/channels/feishu/feishu_32.go similarity index 93% rename from pkg/channels/feishu_32.go rename to pkg/channels/feishu/feishu_32.go index 5109b8195..d0ec758c6 100644 --- a/pkg/channels/feishu_32.go +++ b/pkg/channels/feishu/feishu_32.go @@ -1,18 +1,19 @@ //go:build !amd64 && !arm64 && !riscv64 && !mips64 && !ppc64 -package channels +package feishu import ( "context" "errors" "github.com/sipeed/picoclaw/pkg/bus" + "github.com/sipeed/picoclaw/pkg/channels" "github.com/sipeed/picoclaw/pkg/config" ) // FeishuChannel is a stub implementation for 32-bit architectures type FeishuChannel struct { - *BaseChannel + *channels.BaseChannel } // NewFeishuChannel returns an error on 32-bit architectures where the Feishu SDK is not supported diff --git a/pkg/channels/feishu_64.go b/pkg/channels/feishu/feishu_64.go similarity index 78% rename from pkg/channels/feishu_64.go rename to pkg/channels/feishu/feishu_64.go index 42e74980f..1db1bf669 100644 --- a/pkg/channels/feishu_64.go +++ b/pkg/channels/feishu/feishu_64.go @@ -1,6 +1,6 @@ //go:build amd64 || arm64 || riscv64 || mips64 || ppc64 -package channels +package feishu import ( "context" @@ -15,13 +15,15 @@ import ( larkws "github.com/larksuite/oapi-sdk-go/v3/ws" "github.com/sipeed/picoclaw/pkg/bus" + "github.com/sipeed/picoclaw/pkg/channels" "github.com/sipeed/picoclaw/pkg/config" + "github.com/sipeed/picoclaw/pkg/identity" "github.com/sipeed/picoclaw/pkg/logger" "github.com/sipeed/picoclaw/pkg/utils" ) type FeishuChannel struct { - *BaseChannel + *channels.BaseChannel config config.FeishuConfig client *lark.Client wsClient *larkws.Client @@ -31,7 +33,10 @@ type FeishuChannel struct { } func NewFeishuChannel(cfg config.FeishuConfig, bus *bus.MessageBus) (*FeishuChannel, error) { - base := NewBaseChannel("feishu", cfg, bus, cfg.AllowFrom) + base := channels.NewBaseChannel("feishu", cfg, bus, 
cfg.AllowFrom, + channels.WithGroupTrigger(cfg.GroupTrigger), + channels.WithReasoningChannelID(cfg.ReasoningChannelID), + ) return &FeishuChannel{ BaseChannel: base, @@ -60,7 +65,7 @@ func (c *FeishuChannel) Start(ctx context.Context) error { wsClient := c.wsClient c.mu.Unlock() - c.setRunning(true) + c.SetRunning(true) logger.InfoC("feishu", "Feishu channel started (websocket mode)") go func() { @@ -83,14 +88,14 @@ func (c *FeishuChannel) Stop(ctx context.Context) error { c.wsClient = nil c.mu.Unlock() - c.setRunning(false) + c.SetRunning(false) logger.InfoC("feishu", "Feishu channel stopped") return nil } func (c *FeishuChannel) Send(ctx context.Context, msg bus.OutboundMessage) error { if !c.IsRunning() { - return fmt.Errorf("feishu channel not running") + return channels.ErrNotRunning } if msg.ChatID == "" { @@ -114,11 +119,11 @@ func (c *FeishuChannel) Send(ctx context.Context, msg bus.OutboundMessage) error resp, err := c.client.Im.V1.Message.Create(ctx, req) if err != nil { - return fmt.Errorf("failed to send feishu message: %w", err) + return fmt.Errorf("feishu send: %w", channels.ErrTemporary) } if !resp.Success() { - return fmt.Errorf("feishu api error: code=%d msg=%s", resp.Code, resp.Msg) + return fmt.Errorf("feishu api error (code=%d msg=%s): %w", resp.Code, resp.Msg, channels.ErrTemporary) } logger.DebugCF("feishu", "Feishu message sent", map[string]any{ @@ -128,7 +133,7 @@ func (c *FeishuChannel) Send(ctx context.Context, msg bus.OutboundMessage) error return nil } -func (c *FeishuChannel) handleMessageReceive(_ context.Context, event *larkim.P2MessageReceiveV1) error { +func (c *FeishuChannel) handleMessageReceive(ctx context.Context, event *larkim.P2MessageReceiveV1) error { if event == nil || event.Event == nil || event.Event.Message == nil { return nil } @@ -152,8 +157,9 @@ func (c *FeishuChannel) handleMessageReceive(_ context.Context, event *larkim.P2 } metadata := map[string]string{} - if messageID := stringValue(message.MessageId); messageID 
!= "" { - metadata["message_id"] = messageID + messageID := "" + if mid := stringValue(message.MessageId); mid != "" { + messageID = mid } if messageType := stringValue(message.MessageType); messageType != "" { metadata["message_type"] = messageType @@ -166,12 +172,17 @@ func (c *FeishuChannel) handleMessageReceive(_ context.Context, event *larkim.P2 } chatType := stringValue(message.ChatType) + var peer bus.Peer if chatType == "p2p" { - metadata["peer_kind"] = "direct" - metadata["peer_id"] = senderID + peer = bus.Peer{Kind: "direct", ID: senderID} } else { - metadata["peer_kind"] = "group" - metadata["peer_id"] = chatID + peer = bus.Peer{Kind: "group", ID: chatID} + // In group chats, apply unified group trigger filtering + respond, cleaned := c.ShouldRespondInGroup(false, content) + if !respond { + return nil + } + content = cleaned } logger.InfoCF("feishu", "Feishu message received", map[string]any{ @@ -180,7 +191,17 @@ func (c *FeishuChannel) handleMessageReceive(_ context.Context, event *larkim.P2 "preview": utils.Truncate(content, 80), }) - c.HandleMessage(senderID, chatID, content, nil, metadata) + senderInfo := bus.SenderInfo{ + Platform: "feishu", + PlatformID: senderID, + CanonicalID: identity.BuildCanonicalID("feishu", senderID), + } + + if !c.IsAllowedSender(senderInfo) { + return nil + } + + c.HandleMessage(ctx, peer, messageID, senderID, chatID, content, nil, metadata, senderInfo) return nil } @@ -218,10 +239,3 @@ func extractFeishuMessageContent(message *larkim.EventMessage) string { return *message.Content } - -func stringValue(v *string) string { - if v == nil { - return "" - } - return *v -} diff --git a/pkg/channels/feishu/init.go b/pkg/channels/feishu/init.go new file mode 100644 index 000000000..7e5a62dae --- /dev/null +++ b/pkg/channels/feishu/init.go @@ -0,0 +1,13 @@ +package feishu + +import ( + "github.com/sipeed/picoclaw/pkg/bus" + "github.com/sipeed/picoclaw/pkg/channels" + "github.com/sipeed/picoclaw/pkg/config" +) + +func init() { + 
channels.RegisterFactory("feishu", func(cfg *config.Config, b *bus.MessageBus) (channels.Channel, error) { + return NewFeishuChannel(cfg.Channels.Feishu, b) + }) +} diff --git a/pkg/channels/interfaces.go b/pkg/channels/interfaces.go new file mode 100644 index 000000000..74caeeac5 --- /dev/null +++ b/pkg/channels/interfaces.go @@ -0,0 +1,41 @@ +package channels + +import "context" + +// TypingCapable — channels that can show a typing/thinking indicator. +// StartTyping begins the indicator and returns a stop function. +// The stop function MUST be idempotent and safe to call multiple times. +type TypingCapable interface { + StartTyping(ctx context.Context, chatID string) (stop func(), err error) +} + +// MessageEditor — channels that can edit an existing message. +// messageID is always string; channels convert platform-specific types internally. +type MessageEditor interface { + EditMessage(ctx context.Context, chatID string, messageID string, content string) error +} + +// ReactionCapable — channels that can add a reaction (e.g. 👀) to an inbound message. +// ReactToMessage adds a reaction and returns an undo function to remove it. +// The undo function MUST be idempotent and safe to call multiple times. +type ReactionCapable interface { + ReactToMessage(ctx context.Context, chatID, messageID string) (undo func(), err error) +} + +// PlaceholderCapable — channels that can send a placeholder message +// (e.g. "Thinking... 💭") that will later be edited to the actual response. +// The channel MUST also implement MessageEditor for the placeholder to be useful. +// SendPlaceholder returns the platform message ID of the placeholder so that +// Manager.preSend can later edit it via MessageEditor.EditMessage. +type PlaceholderCapable interface { + SendPlaceholder(ctx context.Context, chatID string) (messageID string, err error) +} + +// PlaceholderRecorder is injected into channels by Manager. 
+// Channels call these methods on inbound to register typing/placeholder state. +// Manager uses the registered state on outbound to stop typing and edit placeholders. +type PlaceholderRecorder interface { + RecordPlaceholder(channel, chatID, placeholderID string) + RecordTypingStop(channel, chatID string, stop func()) + RecordReactionUndo(channel, chatID string, undo func()) +} diff --git a/pkg/channels/line/init.go b/pkg/channels/line/init.go new file mode 100644 index 000000000..9265575cc --- /dev/null +++ b/pkg/channels/line/init.go @@ -0,0 +1,13 @@ +package line + +import ( + "github.com/sipeed/picoclaw/pkg/bus" + "github.com/sipeed/picoclaw/pkg/channels" + "github.com/sipeed/picoclaw/pkg/config" +) + +func init() { + channels.RegisterFactory("line", func(cfg *config.Config, b *bus.MessageBus) (channels.Channel, error) { + return NewLINEChannel(cfg.Channels.LINE, b) + }) +} diff --git a/pkg/channels/line.go b/pkg/channels/line/line.go similarity index 75% rename from pkg/channels/line.go rename to pkg/channels/line/line.go index 44134996f..9fac2831c 100644 --- a/pkg/channels/line.go +++ b/pkg/channels/line/line.go @@ -1,4 +1,4 @@ -package channels +package line import ( "bytes" @@ -10,14 +10,16 @@ import ( "fmt" "io" "net/http" - "os" "strings" "sync" "time" "github.com/sipeed/picoclaw/pkg/bus" + "github.com/sipeed/picoclaw/pkg/channels" "github.com/sipeed/picoclaw/pkg/config" + "github.com/sipeed/picoclaw/pkg/identity" "github.com/sipeed/picoclaw/pkg/logger" + "github.com/sipeed/picoclaw/pkg/media" "github.com/sipeed/picoclaw/pkg/utils" ) @@ -41,9 +43,8 @@ type replyTokenEntry struct { // using the LINE Messaging API with HTTP webhook for receiving messages // and REST API for sending messages. type LINEChannel struct { - *BaseChannel + *channels.BaseChannel config config.LINEConfig - httpServer *http.Server botUserID string // Bot's user ID botBasicID string // Bot's basic ID (e.g. @216ru...) 
botDisplayName string // Bot's display name for text-based mention detection @@ -59,7 +60,11 @@ func NewLINEChannel(cfg config.LINEConfig, messageBus *bus.MessageBus) (*LINECha return nil, fmt.Errorf("line channel_secret and channel_access_token are required") } - base := NewBaseChannel("line", cfg, messageBus, cfg.AllowFrom) + base := channels.NewBaseChannel("line", cfg, messageBus, cfg.AllowFrom, + channels.WithMaxMessageLength(5000), + channels.WithGroupTrigger(cfg.GroupTrigger), + channels.WithReasoningChannelID(cfg.ReasoningChannelID), + ) return &LINEChannel{ BaseChannel: base, @@ -67,7 +72,7 @@ func NewLINEChannel(cfg config.LINEConfig, messageBus *bus.MessageBus) (*LINECha }, nil } -// Start launches the HTTP webhook server. +// Start initializes the LINE channel. func (c *LINEChannel) Start(ctx context.Context) error { logger.InfoC("line", "Starting LINE channel (Webhook Mode)") @@ -86,32 +91,7 @@ func (c *LINEChannel) Start(ctx context.Context) error { }) } - mux := http.NewServeMux() - path := c.config.WebhookPath - if path == "" { - path = "/webhook/line" - } - mux.HandleFunc(path, c.webhookHandler) - - addr := fmt.Sprintf("%s:%d", c.config.WebhookHost, c.config.WebhookPort) - c.httpServer = &http.Server{ - Addr: addr, - Handler: mux, - } - - go func() { - logger.InfoCF("line", "LINE webhook server listening", map[string]any{ - "addr": addr, - "path": path, - }) - if err := c.httpServer.ListenAndServe(); err != nil && err != http.ErrServerClosed { - logger.ErrorCF("line", "Webhook server error", map[string]any{ - "error": err.Error(), - }) - } - }() - - c.setRunning(true) + c.SetRunning(true) logger.InfoC("line", "LINE channel started (Webhook Mode)") return nil } @@ -150,7 +130,7 @@ func (c *LINEChannel) fetchBotInfo() error { return nil } -// Stop gracefully shuts down the HTTP server. +// Stop gracefully stops the LINE channel. 
func (c *LINEChannel) Stop(ctx context.Context) error { logger.InfoC("line", "Stopping LINE channel") @@ -158,21 +138,24 @@ func (c *LINEChannel) Stop(ctx context.Context) error { c.cancel() } - if c.httpServer != nil { - shutdownCtx, cancel := context.WithTimeout(ctx, 5*time.Second) - defer cancel() - if err := c.httpServer.Shutdown(shutdownCtx); err != nil { - logger.ErrorCF("line", "Webhook server shutdown error", map[string]any{ - "error": err.Error(), - }) - } - } - - c.setRunning(false) + c.SetRunning(false) logger.InfoC("line", "LINE channel stopped") return nil } +// WebhookPath returns the path for registering on the shared HTTP server. +func (c *LINEChannel) WebhookPath() string { + if c.config.WebhookPath != "" { + return c.config.WebhookPath + } + return "/webhook/line" +} + +// ServeHTTP implements http.Handler for the shared HTTP server. +func (c *LINEChannel) ServeHTTP(w http.ResponseWriter, r *http.Request) { + c.webhookHandler(w, r) +} + // webhookHandler handles incoming LINE webhook requests. 
func (c *LINEChannel) webhookHandler(w http.ResponseWriter, r *http.Request) { if r.Method != http.MethodPost { @@ -284,14 +267,6 @@ func (c *LINEChannel) processEvent(event lineEvent) { return } - // In group chats, only respond when the bot is mentioned - if isGroup && !c.isBotMentioned(msg) { - logger.DebugCF("line", "Ignoring group message without mention", map[string]any{ - "chat_id": chatID, - }) - return - } - // Store reply token for later use if event.ReplyToken != "" { c.replyTokens.Store(chatID, replyTokenEntry{ @@ -307,18 +282,22 @@ func (c *LINEChannel) processEvent(event lineEvent) { var content string var mediaPaths []string - localFiles := []string{} - defer func() { - for _, file := range localFiles { - if err := os.Remove(file); err != nil { - logger.DebugCF("line", "Failed to cleanup temp file", map[string]any{ - "file": file, - "error": err.Error(), - }) + scope := channels.BuildMediaScope("line", chatID, msg.ID) + + // Helper to register a local file with the media store + storeMedia := func(localPath, filename string) string { + if store := c.GetMediaStore(); store != nil { + ref, err := store.Store(localPath, media.MediaMeta{ + Filename: filename, + Source: "line", + }, scope) + if err == nil { + return ref } } - }() + return localPath // fallback + } switch msg.Type { case "text": @@ -330,22 +309,19 @@ func (c *LINEChannel) processEvent(event lineEvent) { case "image": localPath := c.downloadContent(msg.ID, "image.jpg") if localPath != "" { - localFiles = append(localFiles, localPath) - mediaPaths = append(mediaPaths, localPath) + mediaPaths = append(mediaPaths, storeMedia(localPath, "image.jpg")) content = "[image]" } case "audio": localPath := c.downloadContent(msg.ID, "audio.m4a") if localPath != "" { - localFiles = append(localFiles, localPath) - mediaPaths = append(mediaPaths, localPath) + mediaPaths = append(mediaPaths, storeMedia(localPath, "audio.m4a")) content = "[audio]" } case "video": localPath := c.downloadContent(msg.ID, 
"video.mp4") if localPath != "" { - localFiles = append(localFiles, localPath) - mediaPaths = append(mediaPaths, localPath) + mediaPaths = append(mediaPaths, storeMedia(localPath, "video.mp4")) content = "[video]" } case "file": @@ -360,18 +336,29 @@ func (c *LINEChannel) processEvent(event lineEvent) { return } + // In group chats, apply unified group trigger filtering + if isGroup { + isMentioned := c.isBotMentioned(msg) + respond, cleaned := c.ShouldRespondInGroup(isMentioned, content) + if !respond { + logger.DebugCF("line", "Ignoring group message by group trigger", map[string]any{ + "chat_id": chatID, + }) + return + } + content = cleaned + } + metadata := map[string]string{ "platform": "line", "source_type": event.Source.Type, - "message_id": msg.ID, } + var peer bus.Peer if isGroup { - metadata["peer_kind"] = "group" - metadata["peer_id"] = chatID + peer = bus.Peer{Kind: "group", ID: chatID} } else { - metadata["peer_kind"] = "direct" - metadata["peer_id"] = senderID + peer = bus.Peer{Kind: "direct", ID: senderID} } logger.DebugCF("line", "Received message", map[string]any{ @@ -382,10 +369,17 @@ func (c *LINEChannel) processEvent(event lineEvent) { "preview": utils.Truncate(content, 50), }) - // Show typing/loading indicator (requires user ID, not group ID) - c.sendLoading(senderID) + sender := bus.SenderInfo{ + Platform: "line", + PlatformID: senderID, + CanonicalID: identity.BuildCanonicalID("line", senderID), + } - c.HandleMessage(senderID, chatID, content, mediaPaths, metadata) + if !c.IsAllowedSender(sender) { + return + } + + c.HandleMessage(c.ctx, peer, msg.ID, senderID, chatID, content, mediaPaths, metadata, sender) } // isBotMentioned checks if the bot is mentioned in the message. @@ -491,7 +485,7 @@ func (c *LINEChannel) resolveChatID(source lineSource) string { // using a cached reply token, then falls back to the Push API. 
func (c *LINEChannel) Send(ctx context.Context, msg bus.OutboundMessage) error { if !c.IsRunning() { - return fmt.Errorf("line channel not running") + return channels.ErrNotRunning } // Load and consume quote token for this chat @@ -519,6 +513,36 @@ func (c *LINEChannel) Send(ctx context.Context, msg bus.OutboundMessage) error { return c.sendPush(ctx, msg.ChatID, msg.Content, quoteToken) } +// SendMedia implements the channels.MediaSender interface. +// LINE requires media to be accessible via public URL; since we only have local files, +// we fall back to sending a text message with the filename/caption. +// For full support, an external file hosting service would be needed. +func (c *LINEChannel) SendMedia(ctx context.Context, msg bus.OutboundMediaMessage) error { + if !c.IsRunning() { + return channels.ErrNotRunning + } + + store := c.GetMediaStore() + if store == nil { + return fmt.Errorf("no media store available: %w", channels.ErrSendFailed) + } + + // LINE Messaging API requires publicly accessible URLs for media messages. + // Since we only have local file paths, send caption text as fallback. + for _, part := range msg.Parts { + caption := part.Caption + if caption == "" { + caption = fmt.Sprintf("[%s: %s]", part.Type, part.Filename) + } + + if err := c.sendPush(ctx, msg.ChatID, caption, ""); err != nil { + return err + } + } + + return nil +} + // buildTextMessage creates a text message object, optionally with quoteToken. func buildTextMessage(content, quoteToken string) map[string]string { msg := map[string]string{ @@ -551,17 +575,58 @@ func (c *LINEChannel) sendPush(ctx context.Context, to, content, quoteToken stri return c.callAPI(ctx, linePushEndpoint, payload) } +// StartTyping implements channels.TypingCapable using LINE's loading animation. +// +// NOTE: The LINE loading animation API only works for 1:1 chats. 
+// Group/room chat IDs (starting with "C" or "R") are detected automatically; +// for these, a no-op stop function is returned without calling the API. +func (c *LINEChannel) StartTyping(ctx context.Context, chatID string) (func(), error) { + if chatID == "" { + return func() {}, nil + } + + // Group/room chats: LINE loading animation is 1:1 only. + if strings.HasPrefix(chatID, "C") || strings.HasPrefix(chatID, "R") { + return func() {}, nil + } + + typingCtx, cancel := context.WithCancel(ctx) + var once sync.Once + stop := func() { once.Do(cancel) } + + // Send immediately, then refresh periodically for long-running tasks. + if err := c.sendLoading(typingCtx, chatID); err != nil { + stop() + return stop, err + } + + ticker := time.NewTicker(50 * time.Second) + go func() { + defer ticker.Stop() + for { + select { + case <-typingCtx.Done(): + return + case <-ticker.C: + if err := c.sendLoading(typingCtx, chatID); err != nil { + logger.DebugCF("line", "Failed to refresh loading indicator", map[string]any{ + "error": err.Error(), + }) + } + } + } + }() + + return stop, nil +} + // sendLoading sends a loading animation indicator to the chat. -func (c *LINEChannel) sendLoading(chatID string) { +func (c *LINEChannel) sendLoading(ctx context.Context, chatID string) error { payload := map[string]any{ "chatId": chatID, "loadingSeconds": 60, } - if err := c.callAPI(c.ctx, lineLoadingEndpoint, payload); err != nil { - logger.DebugCF("line", "Failed to send loading indicator", map[string]any{ - "error": err.Error(), - }) - } + return c.callAPI(ctx, lineLoadingEndpoint, payload) } // callAPI makes an authenticated POST request to the LINE API. 
@@ -582,13 +647,13 @@ func (c *LINEChannel) callAPI(ctx context.Context, endpoint string, payload any) client := &http.Client{Timeout: 30 * time.Second} resp, err := client.Do(req) if err != nil { - return fmt.Errorf("API request failed: %w", err) + return channels.ClassifyNetError(err) } defer resp.Body.Close() if resp.StatusCode != http.StatusOK { respBody, _ := io.ReadAll(resp.Body) - return fmt.Errorf("LINE API error (status %d): %s", resp.StatusCode, string(respBody)) + return channels.ClassifySendError(resp.StatusCode, fmt.Errorf("LINE API error: %s", string(respBody))) } return nil diff --git a/pkg/channels/maixcam/init.go b/pkg/channels/maixcam/init.go new file mode 100644 index 000000000..5a269b22b --- /dev/null +++ b/pkg/channels/maixcam/init.go @@ -0,0 +1,13 @@ +package maixcam + +import ( + "github.com/sipeed/picoclaw/pkg/bus" + "github.com/sipeed/picoclaw/pkg/channels" + "github.com/sipeed/picoclaw/pkg/config" +) + +func init() { + channels.RegisterFactory("maixcam", func(cfg *config.Config, b *bus.MessageBus) (channels.Channel, error) { + return NewMaixCamChannel(cfg.Channels.MaixCam, b) + }) +} diff --git a/pkg/channels/maixcam.go b/pkg/channels/maixcam/maixcam.go similarity index 78% rename from pkg/channels/maixcam.go rename to pkg/channels/maixcam/maixcam.go index 34ce62b20..ff9a3ed1a 100644 --- a/pkg/channels/maixcam.go +++ b/pkg/channels/maixcam/maixcam.go @@ -1,4 +1,4 @@ -package channels +package maixcam import ( "context" @@ -6,16 +6,21 @@ import ( "fmt" "net" "sync" + "time" "github.com/sipeed/picoclaw/pkg/bus" + "github.com/sipeed/picoclaw/pkg/channels" "github.com/sipeed/picoclaw/pkg/config" + "github.com/sipeed/picoclaw/pkg/identity" "github.com/sipeed/picoclaw/pkg/logger" ) type MaixCamChannel struct { - *BaseChannel + *channels.BaseChannel config config.MaixCamConfig listener net.Listener + ctx context.Context + cancel context.CancelFunc clients map[net.Conn]bool clientsMux sync.RWMutex } @@ -28,7 +33,13 @@ type MaixCamMessage struct { 
} func NewMaixCamChannel(cfg config.MaixCamConfig, bus *bus.MessageBus) (*MaixCamChannel, error) { - base := NewBaseChannel("maixcam", cfg, bus, cfg.AllowFrom) + base := channels.NewBaseChannel( + "maixcam", + cfg, + bus, + cfg.AllowFrom, + channels.WithReasoningChannelID(cfg.ReasoningChannelID), + ) return &MaixCamChannel{ BaseChannel: base, @@ -40,37 +51,40 @@ func NewMaixCamChannel(cfg config.MaixCamConfig, bus *bus.MessageBus) (*MaixCamC func (c *MaixCamChannel) Start(ctx context.Context) error { logger.InfoC("maixcam", "Starting MaixCam channel server") + c.ctx, c.cancel = context.WithCancel(ctx) + addr := fmt.Sprintf("%s:%d", c.config.Host, c.config.Port) listener, err := net.Listen("tcp", addr) if err != nil { + c.cancel() return fmt.Errorf("failed to listen on %s: %w", addr, err) } c.listener = listener - c.setRunning(true) + c.SetRunning(true) logger.InfoCF("maixcam", "MaixCam server listening", map[string]any{ "host": c.config.Host, "port": c.config.Port, }) - go c.acceptConnections(ctx) + go c.acceptConnections() return nil } -func (c *MaixCamChannel) acceptConnections(ctx context.Context) { +func (c *MaixCamChannel) acceptConnections() { logger.DebugC("maixcam", "Starting connection acceptor") for { select { - case <-ctx.Done(): + case <-c.ctx.Done(): logger.InfoC("maixcam", "Stopping connection acceptor") return default: conn, err := c.listener.Accept() if err != nil { - if c.running { + if c.IsRunning() { logger.ErrorCF("maixcam", "Failed to accept connection", map[string]any{ "error": err.Error(), }) @@ -86,12 +100,12 @@ func (c *MaixCamChannel) acceptConnections(ctx context.Context) { c.clients[conn] = true c.clientsMux.Unlock() - go c.handleConnection(conn, ctx) + go c.handleConnection(conn) } } } -func (c *MaixCamChannel) handleConnection(conn net.Conn, ctx context.Context) { +func (c *MaixCamChannel) handleConnection(conn net.Conn) { logger.DebugC("maixcam", "Handling MaixCam connection") defer func() { @@ -106,7 +120,7 @@ func (c 
*MaixCamChannel) handleConnection(conn net.Conn, ctx context.Context) { for { select { - case <-ctx.Done(): + case <-c.ctx.Done(): return default: var msg MaixCamMessage @@ -170,11 +184,29 @@ func (c *MaixCamChannel) handlePersonDetection(msg MaixCamMessage) { "y": fmt.Sprintf("%.0f", y), "w": fmt.Sprintf("%.0f", w), "h": fmt.Sprintf("%.0f", h), - "peer_kind": "channel", - "peer_id": "default", } - c.HandleMessage(senderID, chatID, content, []string{}, metadata) + sender := bus.SenderInfo{ + Platform: "maixcam", + PlatformID: "maixcam", + CanonicalID: identity.BuildCanonicalID("maixcam", "maixcam"), + } + + if !c.IsAllowedSender(sender) { + return + } + + c.HandleMessage( + c.ctx, + bus.Peer{Kind: "channel", ID: "default"}, + "", + senderID, + chatID, + content, + []string{}, + metadata, + sender, + ) } func (c *MaixCamChannel) handleStatusUpdate(msg MaixCamMessage) { @@ -185,7 +217,12 @@ func (c *MaixCamChannel) handleStatusUpdate(msg MaixCamMessage) { func (c *MaixCamChannel) Stop(ctx context.Context) error { logger.InfoC("maixcam", "Stopping MaixCam channel") - c.setRunning(false) + c.SetRunning(false) + + // Cancel context first to signal goroutines to exit + if c.cancel != nil { + c.cancel() + } if c.listener != nil { c.listener.Close() @@ -205,7 +242,14 @@ func (c *MaixCamChannel) Stop(ctx context.Context) error { func (c *MaixCamChannel) Send(ctx context.Context, msg bus.OutboundMessage) error { if !c.IsRunning() { - return fmt.Errorf("maixcam channel not running") + return channels.ErrNotRunning + } + + // Check ctx before entering write path + select { + case <-ctx.Done(): + return ctx.Err() + default: } c.clientsMux.RLock() @@ -230,13 +274,15 @@ func (c *MaixCamChannel) Send(ctx context.Context, msg bus.OutboundMessage) erro var sendErr error for conn := range c.clients { + _ = conn.SetWriteDeadline(time.Now().Add(10 * time.Second)) if _, err := conn.Write(data); err != nil { logger.ErrorCF("maixcam", "Failed to send to client", map[string]any{ "client": 
conn.RemoteAddr().String(), "error": err.Error(), }) - sendErr = err + sendErr = fmt.Errorf("maixcam send: %w", channels.ErrTemporary) } + _ = conn.SetWriteDeadline(time.Time{}) } return sendErr diff --git a/pkg/channels/manager.go b/pkg/channels/manager.go index 75edaf49e..31af9672c 100644 --- a/pkg/channels/manager.go +++ b/pkg/channels/manager.go @@ -8,32 +8,152 @@ package channels import ( "context" + "errors" "fmt" + "math" + "net/http" "sync" + "time" + + "golang.org/x/time/rate" "github.com/sipeed/picoclaw/pkg/bus" "github.com/sipeed/picoclaw/pkg/config" "github.com/sipeed/picoclaw/pkg/constants" + "github.com/sipeed/picoclaw/pkg/health" "github.com/sipeed/picoclaw/pkg/logger" + "github.com/sipeed/picoclaw/pkg/media" ) +const ( + defaultChannelQueueSize = 16 + defaultRateLimit = 10 // default 10 msg/s + maxRetries = 3 + rateLimitDelay = 1 * time.Second + baseBackoff = 500 * time.Millisecond + maxBackoff = 8 * time.Second + + janitorInterval = 10 * time.Second + typingStopTTL = 5 * time.Minute + placeholderTTL = 10 * time.Minute +) + +// typingEntry wraps a typing stop function with a creation timestamp for TTL eviction. +type typingEntry struct { + stop func() + createdAt time.Time +} + +// reactionEntry wraps a reaction undo function with a creation timestamp for TTL eviction. +type reactionEntry struct { + undo func() + createdAt time.Time +} + +// placeholderEntry wraps a placeholder ID with a creation timestamp for TTL eviction. +type placeholderEntry struct { + id string + createdAt time.Time +} + +// channelRateConfig maps channel name to per-second rate limit. 
+var channelRateConfig = map[string]float64{ + "telegram": 20, + "discord": 1, + "slack": 1, + "line": 10, +} + +type channelWorker struct { + ch Channel + queue chan bus.OutboundMessage + mediaQueue chan bus.OutboundMediaMessage + done chan struct{} + mediaDone chan struct{} + limiter *rate.Limiter +} + type Manager struct { - channels map[string]Channel - bus *bus.MessageBus - config *config.Config - dispatchTask *asyncTask - mu sync.RWMutex + channels map[string]Channel + workers map[string]*channelWorker + bus *bus.MessageBus + config *config.Config + mediaStore media.MediaStore + dispatchTask *asyncTask + mux *http.ServeMux + httpServer *http.Server + mu sync.RWMutex + placeholders sync.Map // "channel:chatID" → placeholderID (string) + typingStops sync.Map // "channel:chatID" → func() + reactionUndos sync.Map // "channel:chatID" → reactionEntry } type asyncTask struct { cancel context.CancelFunc } -func NewManager(cfg *config.Config, messageBus *bus.MessageBus) (*Manager, error) { +// RecordPlaceholder registers a placeholder message for later editing. +// Implements PlaceholderRecorder. +func (m *Manager) RecordPlaceholder(channel, chatID, placeholderID string) { + key := channel + ":" + chatID + m.placeholders.Store(key, placeholderEntry{id: placeholderID, createdAt: time.Now()}) +} + +// RecordTypingStop registers a typing stop function for later invocation. +// Implements PlaceholderRecorder. +func (m *Manager) RecordTypingStop(channel, chatID string, stop func()) { + key := channel + ":" + chatID + m.typingStops.Store(key, typingEntry{stop: stop, createdAt: time.Now()}) +} + +// RecordReactionUndo registers a reaction undo function for later invocation. +// Implements PlaceholderRecorder. 
+func (m *Manager) RecordReactionUndo(channel, chatID string, undo func()) { + key := channel + ":" + chatID + m.reactionUndos.Store(key, reactionEntry{undo: undo, createdAt: time.Now()}) +} + +// preSend handles typing stop, reaction undo, and placeholder editing before sending a message. +// Returns true if the message was edited into a placeholder (skip Send). +func (m *Manager) preSend(ctx context.Context, name string, msg bus.OutboundMessage, ch Channel) bool { + key := name + ":" + msg.ChatID + + // 1. Stop typing + if v, loaded := m.typingStops.LoadAndDelete(key); loaded { + if entry, ok := v.(typingEntry); ok { + entry.stop() // idempotent, safe + } + } + + // 2. Undo reaction + if v, loaded := m.reactionUndos.LoadAndDelete(key); loaded { + if entry, ok := v.(reactionEntry); ok { + entry.undo() // idempotent, safe + } + } + + // 3. Try editing placeholder + if v, loaded := m.placeholders.LoadAndDelete(key); loaded { + if entry, ok := v.(placeholderEntry); ok && entry.id != "" { + if editor, ok := ch.(MessageEditor); ok { + if err := editor.EditMessage(ctx, msg.ChatID, entry.id, msg.Content); err == nil { + return true // edited successfully, skip Send + } + // edit failed → fall through to normal Send + } + } + } + + return false +} + +func NewManager(cfg *config.Config, messageBus *bus.MessageBus, store media.MediaStore) (*Manager, error) { m := &Manager{ - channels: make(map[string]Channel), - bus: messageBus, - config: cfg, + channels: make(map[string]Channel), + workers: make(map[string]*channelWorker), + bus: messageBus, + config: cfg, + mediaStore: store, } if err := m.initChannels(); err != nil { @@ -43,163 +163,104 @@ func NewManager(cfg *config.Config, messageBus *bus.MessageBus) (*Manager, error return m, nil } +// initChannel is a helper that looks up a factory by name and creates the channel. 
+func (m *Manager) initChannel(name, displayName string) { + f, ok := getFactory(name) + if !ok { + logger.WarnCF("channels", "Factory not registered", map[string]any{ + "channel": displayName, + }) + return + } + logger.DebugCF("channels", "Attempting to initialize channel", map[string]any{ + "channel": displayName, + }) + ch, err := f(m.config, m.bus) + if err != nil { + logger.ErrorCF("channels", "Failed to initialize channel", map[string]any{ + "channel": displayName, + "error": err.Error(), + }) + } else { + // Inject MediaStore if channel supports it + if m.mediaStore != nil { + if setter, ok := ch.(interface{ SetMediaStore(s media.MediaStore) }); ok { + setter.SetMediaStore(m.mediaStore) + } + } + // Inject PlaceholderRecorder if channel supports it + if setter, ok := ch.(interface{ SetPlaceholderRecorder(r PlaceholderRecorder) }); ok { + setter.SetPlaceholderRecorder(m) + } + // Inject owner reference so BaseChannel.HandleMessage can auto-trigger typing/reaction + if setter, ok := ch.(interface{ SetOwner(ch Channel) }); ok { + setter.SetOwner(ch) + } + m.channels[name] = ch + logger.InfoCF("channels", "Channel enabled successfully", map[string]any{ + "channel": displayName, + }) + } +} + func (m *Manager) initChannels() error { logger.InfoC("channels", "Initializing channel manager") if m.config.Channels.Telegram.Enabled && m.config.Channels.Telegram.Token != "" { - logger.DebugC("channels", "Attempting to initialize Telegram channel") - telegram, err := NewTelegramChannel(m.config, m.bus) - if err != nil { - logger.ErrorCF("channels", "Failed to initialize Telegram channel", map[string]any{ - "error": err.Error(), - }) - } else { - m.channels["telegram"] = telegram - logger.InfoC("channels", "Telegram channel enabled successfully") - } + m.initChannel("telegram", "Telegram") } - if m.config.Channels.WhatsApp.Enabled && m.config.Channels.WhatsApp.BridgeURL != "" { - logger.DebugC("channels", "Attempting to initialize WhatsApp channel") - whatsapp, err := 
NewWhatsAppChannel(m.config.Channels.WhatsApp, m.bus) - if err != nil { - logger.ErrorCF("channels", "Failed to initialize WhatsApp channel", map[string]any{ - "error": err.Error(), - }) - } else { - m.channels["whatsapp"] = whatsapp - logger.InfoC("channels", "WhatsApp channel enabled successfully") + if m.config.Channels.WhatsApp.Enabled { + waCfg := m.config.Channels.WhatsApp + if waCfg.UseNative { + m.initChannel("whatsapp_native", "WhatsApp Native") + } else if waCfg.BridgeURL != "" { + m.initChannel("whatsapp", "WhatsApp") } } if m.config.Channels.Feishu.Enabled { - logger.DebugC("channels", "Attempting to initialize Feishu channel") - feishu, err := NewFeishuChannel(m.config.Channels.Feishu, m.bus) - if err != nil { - logger.ErrorCF("channels", "Failed to initialize Feishu channel", map[string]any{ - "error": err.Error(), - }) - } else { - m.channels["feishu"] = feishu - logger.InfoC("channels", "Feishu channel enabled successfully") - } + m.initChannel("feishu", "Feishu") } if m.config.Channels.Discord.Enabled && m.config.Channels.Discord.Token != "" { - logger.DebugC("channels", "Attempting to initialize Discord channel") - discord, err := NewDiscordChannel(m.config.Channels.Discord, m.bus) - if err != nil { - logger.ErrorCF("channels", "Failed to initialize Discord channel", map[string]any{ - "error": err.Error(), - }) - } else { - m.channels["discord"] = discord - logger.InfoC("channels", "Discord channel enabled successfully") - } + m.initChannel("discord", "Discord") } if m.config.Channels.MaixCam.Enabled { - logger.DebugC("channels", "Attempting to initialize MaixCam channel") - maixcam, err := NewMaixCamChannel(m.config.Channels.MaixCam, m.bus) - if err != nil { - logger.ErrorCF("channels", "Failed to initialize MaixCam channel", map[string]any{ - "error": err.Error(), - }) - } else { - m.channels["maixcam"] = maixcam - logger.InfoC("channels", "MaixCam channel enabled successfully") - } + m.initChannel("maixcam", "MaixCam") } if 
m.config.Channels.QQ.Enabled { - logger.DebugC("channels", "Attempting to initialize QQ channel") - qq, err := NewQQChannel(m.config.Channels.QQ, m.bus) - if err != nil { - logger.ErrorCF("channels", "Failed to initialize QQ channel", map[string]any{ - "error": err.Error(), - }) - } else { - m.channels["qq"] = qq - logger.InfoC("channels", "QQ channel enabled successfully") - } + m.initChannel("qq", "QQ") } if m.config.Channels.DingTalk.Enabled && m.config.Channels.DingTalk.ClientID != "" { - logger.DebugC("channels", "Attempting to initialize DingTalk channel") - dingtalk, err := NewDingTalkChannel(m.config.Channels.DingTalk, m.bus) - if err != nil { - logger.ErrorCF("channels", "Failed to initialize DingTalk channel", map[string]any{ - "error": err.Error(), - }) - } else { - m.channels["dingtalk"] = dingtalk - logger.InfoC("channels", "DingTalk channel enabled successfully") - } + m.initChannel("dingtalk", "DingTalk") } if m.config.Channels.Slack.Enabled && m.config.Channels.Slack.BotToken != "" { - logger.DebugC("channels", "Attempting to initialize Slack channel") - slackCh, err := NewSlackChannel(m.config.Channels.Slack, m.bus) - if err != nil { - logger.ErrorCF("channels", "Failed to initialize Slack channel", map[string]any{ - "error": err.Error(), - }) - } else { - m.channels["slack"] = slackCh - logger.InfoC("channels", "Slack channel enabled successfully") - } + m.initChannel("slack", "Slack") } if m.config.Channels.LINE.Enabled && m.config.Channels.LINE.ChannelAccessToken != "" { - logger.DebugC("channels", "Attempting to initialize LINE channel") - line, err := NewLINEChannel(m.config.Channels.LINE, m.bus) - if err != nil { - logger.ErrorCF("channels", "Failed to initialize LINE channel", map[string]any{ - "error": err.Error(), - }) - } else { - m.channels["line"] = line - logger.InfoC("channels", "LINE channel enabled successfully") - } + m.initChannel("line", "LINE") } if m.config.Channels.OneBot.Enabled && m.config.Channels.OneBot.WSUrl != "" { - 
logger.DebugC("channels", "Attempting to initialize OneBot channel") - onebot, err := NewOneBotChannel(m.config.Channels.OneBot, m.bus) - if err != nil { - logger.ErrorCF("channels", "Failed to initialize OneBot channel", map[string]any{ - "error": err.Error(), - }) - } else { - m.channels["onebot"] = onebot - logger.InfoC("channels", "OneBot channel enabled successfully") - } + m.initChannel("onebot", "OneBot") } if m.config.Channels.WeCom.Enabled && m.config.Channels.WeCom.Token != "" { - logger.DebugC("channels", "Attempting to initialize WeCom channel") - wecom, err := NewWeComBotChannel(m.config.Channels.WeCom, m.bus) - if err != nil { - logger.ErrorCF("channels", "Failed to initialize WeCom channel", map[string]any{ - "error": err.Error(), - }) - } else { - m.channels["wecom"] = wecom - logger.InfoC("channels", "WeCom channel enabled successfully") - } + m.initChannel("wecom", "WeCom") } if m.config.Channels.WeComApp.Enabled && m.config.Channels.WeComApp.CorpID != "" { - logger.DebugC("channels", "Attempting to initialize WeCom App channel") - wecomApp, err := NewWeComAppChannel(m.config.Channels.WeComApp, m.bus) - if err != nil { - logger.ErrorCF("channels", "Failed to initialize WeCom App channel", map[string]any{ - "error": err.Error(), - }) - } else { - m.channels["wecom_app"] = wecomApp - logger.InfoC("channels", "WeCom App channel enabled successfully") - } + m.initChannel("wecom_app", "WeCom App") + } + + if m.config.Channels.Pico.Enabled && m.config.Channels.Pico.Token != "" { + m.initChannel("pico", "Pico") } logger.InfoCF("channels", "Channel initialization completed", map[string]any{ @@ -209,6 +270,43 @@ func (m *Manager) initChannels() error { return nil } +// SetupHTTPServer creates a shared HTTP server with the given listen address. +// It registers health endpoints from the health server and discovers channels +// that implement WebhookHandler and/or HealthChecker to register their handlers. 
+func (m *Manager) SetupHTTPServer(addr string, healthServer *health.Server) { + m.mux = http.NewServeMux() + + // Register health endpoints + if healthServer != nil { + healthServer.RegisterOnMux(m.mux) + } + + // Discover and register webhook handlers and health checkers + for name, ch := range m.channels { + if wh, ok := ch.(WebhookHandler); ok { + m.mux.Handle(wh.WebhookPath(), wh) + logger.InfoCF("channels", "Webhook handler registered", map[string]any{ + "channel": name, + "path": wh.WebhookPath(), + }) + } + if hc, ok := ch.(HealthChecker); ok { + m.mux.HandleFunc(hc.HealthPath(), hc.HealthHandler) + logger.InfoCF("channels", "Health endpoint registered", map[string]any{ + "channel": name, + "path": hc.HealthPath(), + }) + } + } + + m.httpServer = &http.Server{ + Addr: addr, + Handler: m.mux, + ReadTimeout: 30 * time.Second, + WriteTimeout: 30 * time.Second, + } +} + func (m *Manager) StartAll(ctx context.Context) error { m.mu.Lock() defer m.mu.Unlock() @@ -223,8 +321,6 @@ func (m *Manager) StartAll(ctx context.Context) error { dispatchCtx, cancel := context.WithCancel(ctx) m.dispatchTask = &asyncTask{cancel: cancel} - go m.dispatchOutbound(dispatchCtx) - for name, channel := range m.channels { logger.InfoCF("channels", "Starting channel", map[string]any{ "channel": name, @@ -234,7 +330,34 @@ func (m *Manager) StartAll(ctx context.Context) error { "channel": name, "error": err.Error(), }) + continue } + // Lazily create worker only after channel starts successfully + w := newChannelWorker(name, channel) + m.workers[name] = w + go m.runWorker(dispatchCtx, name, w) + go m.runMediaWorker(dispatchCtx, name, w) + } + + // Start the dispatcher that reads from the bus and routes to workers + go m.dispatchOutbound(dispatchCtx) + go m.dispatchOutboundMedia(dispatchCtx) + + // Start the TTL janitor that cleans up stale typing/placeholder entries + go m.runTTLJanitor(dispatchCtx) + + // Start shared HTTP server if configured + if m.httpServer != nil { + go func() { + 
logger.InfoCF("channels", "Shared HTTP server listening", map[string]any{ + "addr": m.httpServer.Addr, + }) + if err := m.httpServer.ListenAndServe(); err != nil && err != http.ErrServerClosed { + logger.ErrorCF("channels", "Shared HTTP server error", map[string]any{ + "error": err.Error(), + }) + } + }() } logger.InfoC("channels", "All channels started") @@ -247,11 +370,48 @@ func (m *Manager) StopAll(ctx context.Context) error { logger.InfoC("channels", "Stopping all channels") + // Shutdown shared HTTP server first + if m.httpServer != nil { + shutdownCtx, cancel := context.WithTimeout(ctx, 5*time.Second) + defer cancel() + if err := m.httpServer.Shutdown(shutdownCtx); err != nil { + logger.ErrorCF("channels", "Shared HTTP server shutdown error", map[string]any{ + "error": err.Error(), + }) + } + m.httpServer = nil + } + + // Cancel dispatcher if m.dispatchTask != nil { m.dispatchTask.cancel() m.dispatchTask = nil } + // Close all worker queues and wait for them to drain + for _, w := range m.workers { + if w != nil { + close(w.queue) + } + } + for _, w := range m.workers { + if w != nil { + <-w.done + } + } + // Close all media worker queues and wait for them to drain + for _, w := range m.workers { + if w != nil { + close(w.mediaQueue) + } + } + for _, w := range m.workers { + if w != nil { + <-w.mediaDone + } + } + + // Stop all channels for name, channel := range m.channels { logger.InfoCF("channels", "Stopping channel", map[string]any{ "channel": name, @@ -268,42 +428,316 @@ func (m *Manager) StopAll(ctx context.Context) error { return nil } +// newChannelWorker creates a channelWorker with a rate limiter configured +// for the given channel name. 
+func newChannelWorker(name string, ch Channel) *channelWorker { + rateVal := float64(defaultRateLimit) + if r, ok := channelRateConfig[name]; ok { + rateVal = r + } + burst := int(math.Max(1, math.Ceil(rateVal/2))) + + return &channelWorker{ + ch: ch, + queue: make(chan bus.OutboundMessage, defaultChannelQueueSize), + mediaQueue: make(chan bus.OutboundMediaMessage, defaultChannelQueueSize), + done: make(chan struct{}), + mediaDone: make(chan struct{}), + limiter: rate.NewLimiter(rate.Limit(rateVal), burst), + } +} + +// runWorker processes outbound messages for a single channel, splitting +// messages that exceed the channel's maximum message length. +func (m *Manager) runWorker(ctx context.Context, name string, w *channelWorker) { + defer close(w.done) + for { + select { + case msg, ok := <-w.queue: + if !ok { + return + } + maxLen := 0 + if mlp, ok := w.ch.(MessageLengthProvider); ok { + maxLen = mlp.MaxMessageLength() + } + if maxLen > 0 && len([]rune(msg.Content)) > maxLen { + chunks := SplitMessage(msg.Content, maxLen) + for _, chunk := range chunks { + chunkMsg := msg + chunkMsg.Content = chunk + m.sendWithRetry(ctx, name, w, chunkMsg) + } + } else { + m.sendWithRetry(ctx, name, w, msg) + } + case <-ctx.Done(): + return + } + } +} + +// sendWithRetry sends a message through the channel with rate limiting and +// retry logic. 
It classifies errors to determine the retry strategy: +// - ErrNotRunning / ErrSendFailed: permanent, no retry +// - ErrRateLimit: fixed delay retry +// - ErrTemporary / unknown: exponential backoff retry +func (m *Manager) sendWithRetry(ctx context.Context, name string, w *channelWorker, msg bus.OutboundMessage) { + // Rate limit: wait for token + if err := w.limiter.Wait(ctx); err != nil { + // ctx canceled, shutting down + return + } + + // Pre-send: stop typing and try to edit placeholder + if m.preSend(ctx, name, msg, w.ch) { + return // placeholder was edited successfully, skip Send + } + + var lastErr error + for attempt := 0; attempt <= maxRetries; attempt++ { + lastErr = w.ch.Send(ctx, msg) + if lastErr == nil { + return + } + + // Permanent failures — don't retry + if errors.Is(lastErr, ErrNotRunning) || errors.Is(lastErr, ErrSendFailed) { + break + } + + // Last attempt exhausted — don't sleep + if attempt == maxRetries { + break + } + + // Rate limit error — fixed delay + if errors.Is(lastErr, ErrRateLimit) { + select { + case <-time.After(rateLimitDelay): + continue + case <-ctx.Done(): + return + } + } + + // ErrTemporary or unknown error — exponential backoff + backoff := min(time.Duration(float64(baseBackoff)*math.Pow(2, float64(attempt))), maxBackoff) + select { + case <-time.After(backoff): + case <-ctx.Done(): + return + } + } + + // All retries exhausted or permanent failure + logger.ErrorCF("channels", "Send failed", map[string]any{ + "channel": name, + "chat_id": msg.ChatID, + "error": lastErr.Error(), + "retries": maxRetries, + }) +} + func (m *Manager) dispatchOutbound(ctx context.Context) { logger.InfoC("channels", "Outbound dispatcher started") for { - select { - case <-ctx.Done(): + msg, ok := m.bus.SubscribeOutbound(ctx) + if !ok { logger.InfoC("channels", "Outbound dispatcher stopped") return - default: - msg, ok := m.bus.SubscribeOutbound(ctx) + } + + // Silently skip internal channels + if constants.IsInternalChannel(msg.Channel) { + 
continue + } + + m.mu.RLock() + _, exists := m.channels[msg.Channel] + w, wExists := m.workers[msg.Channel] + m.mu.RUnlock() + + if !exists { + logger.WarnCF("channels", "Unknown channel for outbound message", map[string]any{ + "channel": msg.Channel, + }) + continue + } + + if wExists && w != nil { + select { + case w.queue <- msg: + case <-ctx.Done(): + return + } + } else if exists { + logger.WarnCF("channels", "Channel has no active worker, skipping message", map[string]any{ + "channel": msg.Channel, + }) + } + } +} + +func (m *Manager) dispatchOutboundMedia(ctx context.Context) { + logger.InfoC("channels", "Outbound media dispatcher started") + + for { + msg, ok := m.bus.SubscribeOutboundMedia(ctx) + if !ok { + logger.InfoC("channels", "Outbound media dispatcher stopped") + return + } + + // Silently skip internal channels + if constants.IsInternalChannel(msg.Channel) { + continue + } + + m.mu.RLock() + _, exists := m.channels[msg.Channel] + w, wExists := m.workers[msg.Channel] + m.mu.RUnlock() + + if !exists { + logger.WarnCF("channels", "Unknown channel for outbound media message", map[string]any{ + "channel": msg.Channel, + }) + continue + } + + if wExists && w != nil { + select { + case w.mediaQueue <- msg: + case <-ctx.Done(): + return + } + } else if exists { + logger.WarnCF("channels", "Channel has no active worker, skipping media message", map[string]any{ + "channel": msg.Channel, + }) + } + } +} + +// runMediaWorker processes outbound media messages for a single channel. +func (m *Manager) runMediaWorker(ctx context.Context, name string, w *channelWorker) { + defer close(w.mediaDone) + for { + select { + case msg, ok := <-w.mediaQueue: if !ok { + return + } + m.sendMediaWithRetry(ctx, name, w, msg) + case <-ctx.Done(): + return + } + } +} + +// sendMediaWithRetry sends a media message through the channel with rate limiting and +// retry logic. If the channel does not implement MediaSender, it silently skips. 
+func (m *Manager) sendMediaWithRetry(ctx context.Context, name string, w *channelWorker, msg bus.OutboundMediaMessage) { + ms, ok := w.ch.(MediaSender) + if !ok { + logger.DebugCF("channels", "Channel does not support MediaSender, skipping media", map[string]any{ + "channel": name, + }) + return + } + + // Rate limit: wait for token + if err := w.limiter.Wait(ctx); err != nil { + return + } + + var lastErr error + for attempt := 0; attempt <= maxRetries; attempt++ { + lastErr = ms.SendMedia(ctx, msg) + if lastErr == nil { + return + } + + // Permanent failures — don't retry + if errors.Is(lastErr, ErrNotRunning) || errors.Is(lastErr, ErrSendFailed) { + break + } + + // Last attempt exhausted — don't sleep + if attempt == maxRetries { + break + } + + // Rate limit error — fixed delay + if errors.Is(lastErr, ErrRateLimit) { + select { + case <-time.After(rateLimitDelay): continue + case <-ctx.Done(): + return } + } - // Silently skip internal channels - if constants.IsInternalChannel(msg.Channel) { - continue - } + // ErrTemporary or unknown error — exponential backoff + backoff := min(time.Duration(float64(baseBackoff)*math.Pow(2, float64(attempt))), maxBackoff) + select { + case <-time.After(backoff): + case <-ctx.Done(): + return + } + } - m.mu.RLock() - channel, exists := m.channels[msg.Channel] - m.mu.RUnlock() + // All retries exhausted or permanent failure + logger.ErrorCF("channels", "SendMedia failed", map[string]any{ + "channel": name, + "chat_id": msg.ChatID, + "error": lastErr.Error(), + "retries": maxRetries, + }) +} - if !exists { - logger.WarnCF("channels", "Unknown channel for outbound message", map[string]any{ - "channel": msg.Channel, - }) - continue - } +// runTTLJanitor periodically scans the typingStops and placeholders maps +// and evicts entries that have exceeded their TTL. This prevents memory +// accumulation when outbound paths fail to trigger preSend (e.g. LLM errors). 
+func (m *Manager) runTTLJanitor(ctx context.Context) { + ticker := time.NewTicker(janitorInterval) + defer ticker.Stop() - if err := channel.Send(ctx, msg); err != nil { - logger.ErrorCF("channels", "Error sending message to channel", map[string]any{ - "channel": msg.Channel, - "error": err.Error(), - }) - } + for { + select { + case <-ctx.Done(): + return + case now := <-ticker.C: + m.typingStops.Range(func(key, value any) bool { + if entry, ok := value.(typingEntry); ok { + if now.Sub(entry.createdAt) > typingStopTTL { + if _, loaded := m.typingStops.LoadAndDelete(key); loaded { + entry.stop() // idempotent, safe + } + } + } + return true + }) + m.reactionUndos.Range(func(key, value any) bool { + if entry, ok := value.(reactionEntry); ok { + if now.Sub(entry.createdAt) > typingStopTTL { + if _, loaded := m.reactionUndos.LoadAndDelete(key); loaded { + entry.undo() // idempotent, safe + } + } + } + return true + }) + m.placeholders.Range(func(key, value any) bool { + if entry, ok := value.(placeholderEntry); ok { + if now.Sub(entry.createdAt) > placeholderTTL { + m.placeholders.Delete(key) + } + } + return true + }) } } } @@ -349,12 +783,20 @@ func (m *Manager) RegisterChannel(name string, channel Channel) { func (m *Manager) UnregisterChannel(name string) { m.mu.Lock() defer m.mu.Unlock() + if w, ok := m.workers[name]; ok && w != nil { + close(w.queue) + <-w.done + close(w.mediaQueue) + <-w.mediaDone + } + delete(m.workers, name) delete(m.channels, name) } func (m *Manager) SendToChannel(ctx context.Context, channelName, chatID, content string) error { m.mu.RLock() - channel, exists := m.channels[channelName] + _, exists := m.channels[channelName] + w, wExists := m.workers[channelName] m.mu.RUnlock() if !exists { @@ -367,5 +809,16 @@ func (m *Manager) SendToChannel(ctx context.Context, channelName, chatID, conten Content: content, } + if wExists && w != nil { + select { + case w.queue <- msg: + return nil + case <-ctx.Done(): + return ctx.Err() + } + } + + // 
Fallback: direct send (should not happen) + channel, _ := m.channels[channelName] return channel.Send(ctx, msg) } diff --git a/pkg/channels/manager_test.go b/pkg/channels/manager_test.go new file mode 100644 index 000000000..f09ecfe2f --- /dev/null +++ b/pkg/channels/manager_test.go @@ -0,0 +1,862 @@ +package channels + +import ( + "context" + "errors" + "fmt" + "sync" + "sync/atomic" + "testing" + "time" + + "golang.org/x/time/rate" + + "github.com/sipeed/picoclaw/pkg/bus" +) + +// mockChannel is a test double that delegates Send to a configurable function. +type mockChannel struct { + BaseChannel + sendFn func(ctx context.Context, msg bus.OutboundMessage) error +} + +func (m *mockChannel) Send(ctx context.Context, msg bus.OutboundMessage) error { + return m.sendFn(ctx, msg) +} + +func (m *mockChannel) Start(ctx context.Context) error { return nil } +func (m *mockChannel) Stop(ctx context.Context) error { return nil } + +// newTestManager creates a minimal Manager suitable for unit tests. 
+func newTestManager() *Manager { + return &Manager{ + channels: make(map[string]Channel), + workers: make(map[string]*channelWorker), + } +} + +func TestSendWithRetry_Success(t *testing.T) { + m := newTestManager() + var callCount int + ch := &mockChannel{ + sendFn: func(_ context.Context, _ bus.OutboundMessage) error { + callCount++ + return nil + }, + } + w := &channelWorker{ + ch: ch, + limiter: rate.NewLimiter(rate.Inf, 1), + } + + ctx := context.Background() + msg := bus.OutboundMessage{Channel: "test", ChatID: "1", Content: "hello"} + + m.sendWithRetry(ctx, "test", w, msg) + + if callCount != 1 { + t.Fatalf("expected 1 Send call, got %d", callCount) + } +} + +func TestSendWithRetry_TemporaryThenSuccess(t *testing.T) { + m := newTestManager() + var callCount int + ch := &mockChannel{ + sendFn: func(_ context.Context, _ bus.OutboundMessage) error { + callCount++ + if callCount <= 2 { + return fmt.Errorf("network error: %w", ErrTemporary) + } + return nil + }, + } + w := &channelWorker{ + ch: ch, + limiter: rate.NewLimiter(rate.Inf, 1), + } + + ctx := context.Background() + msg := bus.OutboundMessage{Channel: "test", ChatID: "1", Content: "hello"} + + m.sendWithRetry(ctx, "test", w, msg) + + if callCount != 3 { + t.Fatalf("expected 3 Send calls (2 failures + 1 success), got %d", callCount) + } +} + +func TestSendWithRetry_PermanentFailure(t *testing.T) { + m := newTestManager() + var callCount int + ch := &mockChannel{ + sendFn: func(_ context.Context, _ bus.OutboundMessage) error { + callCount++ + return fmt.Errorf("bad chat ID: %w", ErrSendFailed) + }, + } + w := &channelWorker{ + ch: ch, + limiter: rate.NewLimiter(rate.Inf, 1), + } + + ctx := context.Background() + msg := bus.OutboundMessage{Channel: "test", ChatID: "1", Content: "hello"} + + m.sendWithRetry(ctx, "test", w, msg) + + if callCount != 1 { + t.Fatalf("expected 1 Send call (no retry for permanent failure), got %d", callCount) + } +} + +func TestSendWithRetry_NotRunning(t *testing.T) { + m := 
newTestManager() + var callCount int + ch := &mockChannel{ + sendFn: func(_ context.Context, _ bus.OutboundMessage) error { + callCount++ + return ErrNotRunning + }, + } + w := &channelWorker{ + ch: ch, + limiter: rate.NewLimiter(rate.Inf, 1), + } + + ctx := context.Background() + msg := bus.OutboundMessage{Channel: "test", ChatID: "1", Content: "hello"} + + m.sendWithRetry(ctx, "test", w, msg) + + if callCount != 1 { + t.Fatalf("expected 1 Send call (no retry for ErrNotRunning), got %d", callCount) + } +} + +func TestSendWithRetry_RateLimitRetry(t *testing.T) { + m := newTestManager() + var callCount int + ch := &mockChannel{ + sendFn: func(_ context.Context, _ bus.OutboundMessage) error { + callCount++ + if callCount == 1 { + return fmt.Errorf("429: %w", ErrRateLimit) + } + return nil + }, + } + w := &channelWorker{ + ch: ch, + limiter: rate.NewLimiter(rate.Inf, 1), + } + + ctx := context.Background() + msg := bus.OutboundMessage{Channel: "test", ChatID: "1", Content: "hello"} + + start := time.Now() + m.sendWithRetry(ctx, "test", w, msg) + elapsed := time.Since(start) + + if callCount != 2 { + t.Fatalf("expected 2 Send calls (1 rate limit + 1 success), got %d", callCount) + } + // Should have waited at least rateLimitDelay (1s) but allow some slack + if elapsed < 900*time.Millisecond { + t.Fatalf("expected at least ~1s delay for rate limit retry, got %v", elapsed) + } +} + +func TestSendWithRetry_MaxRetriesExhausted(t *testing.T) { + m := newTestManager() + var callCount int + ch := &mockChannel{ + sendFn: func(_ context.Context, _ bus.OutboundMessage) error { + callCount++ + return fmt.Errorf("timeout: %w", ErrTemporary) + }, + } + w := &channelWorker{ + ch: ch, + limiter: rate.NewLimiter(rate.Inf, 1), + } + + ctx := context.Background() + msg := bus.OutboundMessage{Channel: "test", ChatID: "1", Content: "hello"} + + m.sendWithRetry(ctx, "test", w, msg) + + expected := maxRetries + 1 // initial attempt + maxRetries retries + if callCount != expected { + 
t.Fatalf("expected %d Send calls, got %d", expected, callCount) + } +} + +func TestSendWithRetry_UnknownError(t *testing.T) { + m := newTestManager() + var callCount int + ch := &mockChannel{ + sendFn: func(_ context.Context, _ bus.OutboundMessage) error { + callCount++ + if callCount == 1 { + return errors.New("random unexpected error") + } + return nil + }, + } + w := &channelWorker{ + ch: ch, + limiter: rate.NewLimiter(rate.Inf, 1), + } + + ctx := context.Background() + msg := bus.OutboundMessage{Channel: "test", ChatID: "1", Content: "hello"} + + m.sendWithRetry(ctx, "test", w, msg) + + if callCount != 2 { + t.Fatalf("expected 2 Send calls (unknown error treated as temporary), got %d", callCount) + } +} + +func TestSendWithRetry_ContextCancelled(t *testing.T) { + m := newTestManager() + var callCount int + ch := &mockChannel{ + sendFn: func(_ context.Context, _ bus.OutboundMessage) error { + callCount++ + return fmt.Errorf("timeout: %w", ErrTemporary) + }, + } + w := &channelWorker{ + ch: ch, + limiter: rate.NewLimiter(rate.Inf, 1), + } + + ctx, cancel := context.WithCancel(context.Background()) + msg := bus.OutboundMessage{Channel: "test", ChatID: "1", Content: "hello"} + + // Cancel context after first Send attempt returns + ch.sendFn = func(_ context.Context, _ bus.OutboundMessage) error { + callCount++ + cancel() + return fmt.Errorf("timeout: %w", ErrTemporary) + } + + m.sendWithRetry(ctx, "test", w, msg) + + // Should have called Send once, then noticed ctx canceled during backoff + if callCount != 1 { + t.Fatalf("expected 1 Send call before context cancellation, got %d", callCount) + } +} + +func TestWorkerRateLimiter(t *testing.T) { + m := newTestManager() + + var mu sync.Mutex + var sendTimes []time.Time + + ch := &mockChannel{ + sendFn: func(_ context.Context, _ bus.OutboundMessage) error { + mu.Lock() + sendTimes = append(sendTimes, time.Now()) + mu.Unlock() + return nil + }, + } + + // Create a worker with a low rate: 2 msg/s, burst 1 + w := 
&channelWorker{ + ch: ch, + queue: make(chan bus.OutboundMessage, 10), + done: make(chan struct{}), + limiter: rate.NewLimiter(2, 1), + } + + ctx := t.Context() + + go m.runWorker(ctx, "test", w) + + // Enqueue 4 messages + for i := range 4 { + w.queue <- bus.OutboundMessage{Channel: "test", ChatID: "1", Content: fmt.Sprintf("msg%d", i)} + } + + // Wait enough time for all messages to be sent (4 msgs at 2/s = ~2s, give extra margin) + time.Sleep(3 * time.Second) + + mu.Lock() + times := make([]time.Time, len(sendTimes)) + copy(times, sendTimes) + mu.Unlock() + + if len(times) != 4 { + t.Fatalf("expected 4 sends, got %d", len(times)) + } + + // Verify rate limiting: total duration should be at least 1s + // (first message immediate, then ~500ms between each subsequent one at 2/s) + totalDuration := times[len(times)-1].Sub(times[0]) + if totalDuration < 1*time.Second { + t.Fatalf("expected total duration >= 1s for 4 msgs at 2/s rate, got %v", totalDuration) + } +} + +func TestNewChannelWorker_DefaultRate(t *testing.T) { + ch := &mockChannel{} + w := newChannelWorker("unknown_channel", ch) + + if w.limiter == nil { + t.Fatal("expected limiter to be non-nil") + } + if w.limiter.Limit() != rate.Limit(defaultRateLimit) { + t.Fatalf("expected rate limit %v, got %v", rate.Limit(defaultRateLimit), w.limiter.Limit()) + } +} + +func TestNewChannelWorker_ConfiguredRate(t *testing.T) { + ch := &mockChannel{} + + for name, expectedRate := range channelRateConfig { + w := newChannelWorker(name, ch) + if w.limiter.Limit() != rate.Limit(expectedRate) { + t.Fatalf("channel %s: expected rate %v, got %v", name, expectedRate, w.limiter.Limit()) + } + } +} + +func TestRunWorker_MessageSplitting(t *testing.T) { + m := newTestManager() + + var mu sync.Mutex + var received []string + + ch := &mockChannelWithLength{ + mockChannel: mockChannel{ + sendFn: func(_ context.Context, msg bus.OutboundMessage) error { + mu.Lock() + received = append(received, msg.Content) + mu.Unlock() + return nil 
+ }, + }, + maxLen: 5, + } + + w := &channelWorker{ + ch: ch, + queue: make(chan bus.OutboundMessage, 10), + done: make(chan struct{}), + limiter: rate.NewLimiter(rate.Inf, 1), + } + + ctx := t.Context() + + go m.runWorker(ctx, "test", w) + + // Send a message that should be split + w.queue <- bus.OutboundMessage{Channel: "test", ChatID: "1", Content: "hello world"} + + time.Sleep(100 * time.Millisecond) + + mu.Lock() + count := len(received) + mu.Unlock() + + if count < 2 { + t.Fatalf("expected message to be split into at least 2 chunks, got %d", count) + } +} + +// mockChannelWithLength implements MessageLengthProvider. +type mockChannelWithLength struct { + mockChannel + maxLen int +} + +func (m *mockChannelWithLength) MaxMessageLength() int { + return m.maxLen +} + +func TestSendWithRetry_ExponentialBackoff(t *testing.T) { + m := newTestManager() + + var callTimes []time.Time + var callCount atomic.Int32 + ch := &mockChannel{ + sendFn: func(_ context.Context, _ bus.OutboundMessage) error { + callTimes = append(callTimes, time.Now()) + callCount.Add(1) + return fmt.Errorf("timeout: %w", ErrTemporary) + }, + } + w := &channelWorker{ + ch: ch, + limiter: rate.NewLimiter(rate.Inf, 1), + } + + ctx := context.Background() + msg := bus.OutboundMessage{Channel: "test", ChatID: "1", Content: "hello"} + + start := time.Now() + m.sendWithRetry(ctx, "test", w, msg) + totalElapsed := time.Since(start) + + // With maxRetries=3: attempts at 0, ~500ms, ~1.5s, ~3.5s + // Total backoff: 500ms + 1s + 2s = 3.5s + // Allow some margin + if totalElapsed < 3*time.Second { + t.Fatalf("expected total elapsed >= 3s for exponential backoff, got %v", totalElapsed) + } + + if int(callCount.Load()) != maxRetries+1 { + t.Fatalf("expected %d calls, got %d", maxRetries+1, callCount.Load()) + } +} + +// --- Phase 10: preSend orchestration tests --- + +// mockMessageEditor is a channel that supports MessageEditor. 
+type mockMessageEditor struct { + mockChannel + editFn func(ctx context.Context, chatID, messageID, content string) error +} + +func (m *mockMessageEditor) EditMessage(ctx context.Context, chatID, messageID, content string) error { + return m.editFn(ctx, chatID, messageID, content) +} + +func TestPreSend_PlaceholderEditSuccess(t *testing.T) { + m := newTestManager() + var sendCalled bool + var editCalled bool + + ch := &mockMessageEditor{ + mockChannel: mockChannel{ + sendFn: func(_ context.Context, _ bus.OutboundMessage) error { + sendCalled = true + return nil + }, + }, + editFn: func(_ context.Context, chatID, messageID, content string) error { + editCalled = true + if chatID != "123" { + t.Fatalf("expected chatID 123, got %s", chatID) + } + if messageID != "456" { + t.Fatalf("expected messageID 456, got %s", messageID) + } + if content != "hello" { + t.Fatalf("expected content 'hello', got %s", content) + } + return nil + }, + } + + // Register placeholder + m.RecordPlaceholder("test", "123", "456") + + msg := bus.OutboundMessage{Channel: "test", ChatID: "123", Content: "hello"} + edited := m.preSend(context.Background(), "test", msg, ch) + + if !edited { + t.Fatal("expected preSend to return true (placeholder edited)") + } + if !editCalled { + t.Fatal("expected EditMessage to be called") + } + if sendCalled { + t.Fatal("expected Send to NOT be called when placeholder edited") + } +} + +func TestPreSend_PlaceholderEditFails_FallsThrough(t *testing.T) { + m := newTestManager() + + ch := &mockMessageEditor{ + mockChannel: mockChannel{ + sendFn: func(_ context.Context, _ bus.OutboundMessage) error { + return nil + }, + }, + editFn: func(_ context.Context, _, _, _ string) error { + return fmt.Errorf("edit failed") + }, + } + + m.RecordPlaceholder("test", "123", "456") + + msg := bus.OutboundMessage{Channel: "test", ChatID: "123", Content: "hello"} + edited := m.preSend(context.Background(), "test", msg, ch) + + if edited { + t.Fatal("expected preSend to return 
false when edit fails") + } +} + +func TestPreSend_TypingStopCalled(t *testing.T) { + m := newTestManager() + var stopCalled bool + + ch := &mockChannel{ + sendFn: func(_ context.Context, _ bus.OutboundMessage) error { + return nil + }, + } + + m.RecordTypingStop("test", "123", func() { + stopCalled = true + }) + + msg := bus.OutboundMessage{Channel: "test", ChatID: "123", Content: "hello"} + m.preSend(context.Background(), "test", msg, ch) + + if !stopCalled { + t.Fatal("expected typing stop func to be called") + } +} + +func TestPreSend_NoRegisteredState(t *testing.T) { + m := newTestManager() + + ch := &mockChannel{ + sendFn: func(_ context.Context, _ bus.OutboundMessage) error { + return nil + }, + } + + msg := bus.OutboundMessage{Channel: "test", ChatID: "123", Content: "hello"} + edited := m.preSend(context.Background(), "test", msg, ch) + + if edited { + t.Fatal("expected preSend to return false with no registered state") + } +} + +func TestPreSend_TypingAndPlaceholder(t *testing.T) { + m := newTestManager() + var stopCalled bool + var editCalled bool + + ch := &mockMessageEditor{ + mockChannel: mockChannel{ + sendFn: func(_ context.Context, _ bus.OutboundMessage) error { + return nil + }, + }, + editFn: func(_ context.Context, _, _, _ string) error { + editCalled = true + return nil + }, + } + + m.RecordTypingStop("test", "123", func() { + stopCalled = true + }) + m.RecordPlaceholder("test", "123", "456") + + msg := bus.OutboundMessage{Channel: "test", ChatID: "123", Content: "hello"} + edited := m.preSend(context.Background(), "test", msg, ch) + + if !stopCalled { + t.Fatal("expected typing stop to be called") + } + if !editCalled { + t.Fatal("expected EditMessage to be called") + } + if !edited { + t.Fatal("expected preSend to return true") + } +} + +func TestRecordPlaceholder_ConcurrentSafe(t *testing.T) { + m := newTestManager() + + var wg sync.WaitGroup + for i := range 100 { + wg.Add(1) + go func(i int) { + defer wg.Done() + chatID := 
fmt.Sprintf("chat_%d", i%10) + m.RecordPlaceholder("test", chatID, fmt.Sprintf("msg_%d", i)) + }(i) + } + wg.Wait() +} + +func TestRecordTypingStop_ConcurrentSafe(t *testing.T) { + m := newTestManager() + + var wg sync.WaitGroup + for i := range 100 { + wg.Add(1) + go func(i int) { + defer wg.Done() + chatID := fmt.Sprintf("chat_%d", i%10) + m.RecordTypingStop("test", chatID, func() {}) + }(i) + } + wg.Wait() +} + +func TestSendWithRetry_PreSendEditsPlaceholder(t *testing.T) { + m := newTestManager() + var sendCalled bool + + ch := &mockMessageEditor{ + mockChannel: mockChannel{ + sendFn: func(_ context.Context, _ bus.OutboundMessage) error { + sendCalled = true + return nil + }, + }, + editFn: func(_ context.Context, _, _, _ string) error { + return nil // edit succeeds + }, + } + + m.RecordPlaceholder("test", "123", "456") + + w := &channelWorker{ + ch: ch, + limiter: rate.NewLimiter(rate.Inf, 1), + } + + msg := bus.OutboundMessage{Channel: "test", ChatID: "123", Content: "hello"} + m.sendWithRetry(context.Background(), "test", w, msg) + + if sendCalled { + t.Fatal("expected Send to NOT be called when placeholder was edited") + } +} + +// --- Dispatcher exit tests (Step 1) --- + +func TestDispatcherExitsOnCancel(t *testing.T) { + mb := bus.NewMessageBus() + defer mb.Close() + + m := &Manager{ + channels: make(map[string]Channel), + workers: make(map[string]*channelWorker), + bus: mb, + } + + ctx, cancel := context.WithCancel(context.Background()) + done := make(chan struct{}) + + go func() { + m.dispatchOutbound(ctx) + close(done) + }() + + // Cancel context and verify the dispatcher exits quickly + cancel() + + select { + case <-done: + // success + case <-time.After(2 * time.Second): + t.Fatal("dispatchOutbound did not exit within 2s after context cancel") + } +} + +func TestDispatcherMediaExitsOnCancel(t *testing.T) { + mb := bus.NewMessageBus() + defer mb.Close() + + m := &Manager{ + channels: make(map[string]Channel), + workers: 
make(map[string]*channelWorker), + bus: mb, + } + + ctx, cancel := context.WithCancel(context.Background()) + done := make(chan struct{}) + + go func() { + m.dispatchOutboundMedia(ctx) + close(done) + }() + + cancel() + + select { + case <-done: + // success + case <-time.After(2 * time.Second): + t.Fatal("dispatchOutboundMedia did not exit within 2s after context cancel") + } +} + +// --- TTL Janitor tests (Step 2) --- + +func TestTypingStopJanitorEviction(t *testing.T) { + m := newTestManager() + + var stopCalled atomic.Bool + // Store a typing entry with a creation time far in the past + m.typingStops.Store("test:123", typingEntry{ + stop: func() { stopCalled.Store(true) }, + createdAt: time.Now().Add(-10 * time.Minute), // well past typingStopTTL + }) + + // Run janitor with a short-lived context + ctx, cancel := context.WithCancel(context.Background()) + + // Manually trigger the janitor logic once by simulating a tick + go func() { + // Override janitor to run immediately + now := time.Now() + m.typingStops.Range(func(key, value any) bool { + if entry, ok := value.(typingEntry); ok { + if now.Sub(entry.createdAt) > typingStopTTL { + if _, loaded := m.typingStops.LoadAndDelete(key); loaded { + entry.stop() + } + } + } + return true + }) + cancel() + }() + + <-ctx.Done() + + if !stopCalled.Load() { + t.Fatal("expected typing stop function to be called by janitor eviction") + } + + // Verify entry was deleted + if _, loaded := m.typingStops.Load("test:123"); loaded { + t.Fatal("expected typing entry to be deleted after eviction") + } +} + +func TestPlaceholderJanitorEviction(t *testing.T) { + m := newTestManager() + + // Store a placeholder entry with a creation time far in the past + m.placeholders.Store("test:456", placeholderEntry{ + id: "msg_old", + createdAt: time.Now().Add(-20 * time.Minute), // well past placeholderTTL + }) + + // Simulate janitor logic + now := time.Now() + m.placeholders.Range(func(key, value any) bool { + if entry, ok := 
value.(placeholderEntry); ok { + if now.Sub(entry.createdAt) > placeholderTTL { + m.placeholders.Delete(key) + } + } + return true + }) + + // Verify entry was deleted + if _, loaded := m.placeholders.Load("test:456"); loaded { + t.Fatal("expected placeholder entry to be deleted after eviction") + } +} + +func TestPreSendStillWorksWithWrappedTypes(t *testing.T) { + m := newTestManager() + var stopCalled bool + var editCalled bool + + ch := &mockMessageEditor{ + mockChannel: mockChannel{ + sendFn: func(_ context.Context, _ bus.OutboundMessage) error { + return nil + }, + }, + editFn: func(_ context.Context, chatID, messageID, content string) error { + editCalled = true + if messageID != "ph_id" { + t.Fatalf("expected messageID ph_id, got %s", messageID) + } + return nil + }, + } + + // Use the new wrapped types via the public API + m.RecordTypingStop("test", "chat1", func() { + stopCalled = true + }) + m.RecordPlaceholder("test", "chat1", "ph_id") + + msg := bus.OutboundMessage{Channel: "test", ChatID: "chat1", Content: "response"} + edited := m.preSend(context.Background(), "test", msg, ch) + + if !stopCalled { + t.Fatal("expected typing stop to be called via wrapped type") + } + if !editCalled { + t.Fatal("expected EditMessage to be called via wrapped type") + } + if !edited { + t.Fatal("expected preSend to return true") + } +} + +// --- Lazy worker creation tests (Step 6) --- + +func TestLazyWorkerCreation(t *testing.T) { + m := newTestManager() + + ch := &mockChannel{ + sendFn: func(_ context.Context, _ bus.OutboundMessage) error { + return nil + }, + } + + // RegisterChannel should NOT create a worker + m.RegisterChannel("lazy", ch) + + m.mu.RLock() + _, chExists := m.channels["lazy"] + _, wExists := m.workers["lazy"] + m.mu.RUnlock() + + if !chExists { + t.Fatal("expected channel to be registered") + } + if wExists { + t.Fatal("expected worker to NOT be created by RegisterChannel (lazy creation)") + } +} + +// --- FastID uniqueness test (Step 5) --- + +func 
TestBuildMediaScope_FastIDUniqueness(t *testing.T) { + seen := make(map[string]bool) + + for range 1000 { + scope := BuildMediaScope("test", "chat1", "") + if seen[scope] { + t.Fatalf("duplicate scope generated: %s", scope) + } + seen[scope] = true + } + + // Verify format: "channel:chatID:id" + scope := BuildMediaScope("telegram", "42", "") + parts := 0 + for _, c := range scope { + if c == ':' { + parts++ + } + } + if parts != 2 { + t.Fatalf("expected scope to have 2 colons (channel:chatID:id), got: %s", scope) + } +} + +func TestBuildMediaScope_WithMessageID(t *testing.T) { + scope := BuildMediaScope("discord", "chat99", "msg123") + expected := "discord:chat99:msg123" + if scope != expected { + t.Fatalf("expected %s, got %s", expected, scope) + } +} diff --git a/pkg/channels/media.go b/pkg/channels/media.go new file mode 100644 index 000000000..c645a6180 --- /dev/null +++ b/pkg/channels/media.go @@ -0,0 +1,15 @@ +package channels + +import ( + "context" + + "github.com/sipeed/picoclaw/pkg/bus" +) + +// MediaSender is an optional interface for channels that can send +// media attachments (images, files, audio, video). +// Manager discovers channels implementing this interface via type +// assertion and routes OutboundMediaMessage to them. 
+type MediaSender interface { + SendMedia(ctx context.Context, msg bus.OutboundMediaMessage) error +} diff --git a/pkg/channels/onebot/init.go b/pkg/channels/onebot/init.go new file mode 100644 index 000000000..84c06dfd6 --- /dev/null +++ b/pkg/channels/onebot/init.go @@ -0,0 +1,13 @@ +package onebot + +import ( + "github.com/sipeed/picoclaw/pkg/bus" + "github.com/sipeed/picoclaw/pkg/channels" + "github.com/sipeed/picoclaw/pkg/config" +) + +func init() { + channels.RegisterFactory("onebot", func(cfg *config.Config, b *bus.MessageBus) (channels.Channel, error) { + return NewOneBotChannel(cfg.Channels.OneBot, b) + }) +} diff --git a/pkg/channels/onebot.go b/pkg/channels/onebot/onebot.go similarity index 75% rename from pkg/channels/onebot.go rename to pkg/channels/onebot/onebot.go index 1b0cbc4ab..62a9eb34a 100644 --- a/pkg/channels/onebot.go +++ b/pkg/channels/onebot/onebot.go @@ -1,10 +1,9 @@ -package channels +package onebot import ( "context" "encoding/json" "fmt" - "os" "strconv" "strings" "sync" @@ -14,30 +13,30 @@ import ( "github.com/gorilla/websocket" "github.com/sipeed/picoclaw/pkg/bus" + "github.com/sipeed/picoclaw/pkg/channels" "github.com/sipeed/picoclaw/pkg/config" + "github.com/sipeed/picoclaw/pkg/identity" "github.com/sipeed/picoclaw/pkg/logger" + "github.com/sipeed/picoclaw/pkg/media" "github.com/sipeed/picoclaw/pkg/utils" - "github.com/sipeed/picoclaw/pkg/voice" ) type OneBotChannel struct { - *BaseChannel - config config.OneBotConfig - conn *websocket.Conn - ctx context.Context - cancel context.CancelFunc - dedup map[string]struct{} - dedupRing []string - dedupIdx int - mu sync.Mutex - writeMu sync.Mutex - echoCounter int64 - selfID int64 - pending map[string]chan json.RawMessage - pendingMu sync.Mutex - transcriber *voice.GroqTranscriber - lastMessageID sync.Map - pendingEmojiMsg sync.Map + *channels.BaseChannel + config config.OneBotConfig + conn *websocket.Conn + ctx context.Context + cancel context.CancelFunc + dedup map[string]struct{} + 
dedupRing []string + dedupIdx int + mu sync.Mutex + writeMu sync.Mutex + echoCounter int64 + selfID int64 + pending map[string]chan json.RawMessage + pendingMu sync.Mutex + lastMessageID sync.Map } type oneBotRawEvent struct { @@ -98,7 +97,10 @@ type oneBotMessageSegment struct { } func NewOneBotChannel(cfg config.OneBotConfig, messageBus *bus.MessageBus) (*OneBotChannel, error) { - base := NewBaseChannel("onebot", cfg, messageBus, cfg.AllowFrom) + base := channels.NewBaseChannel("onebot", cfg, messageBus, cfg.AllowFrom, + channels.WithGroupTrigger(cfg.GroupTrigger), + channels.WithReasoningChannelID(cfg.ReasoningChannelID), + ) const dedupSize = 1024 return &OneBotChannel{ @@ -111,10 +113,6 @@ func NewOneBotChannel(cfg config.OneBotConfig, messageBus *bus.MessageBus) (*One }, nil } -func (c *OneBotChannel) SetTranscriber(transcriber *voice.GroqTranscriber) { - c.transcriber = transcriber -} - func (c *OneBotChannel) setMsgEmojiLike(messageID string, emojiID int, set bool) { go func() { _, err := c.sendAPIRequest("set_msg_emoji_like", map[string]any{ @@ -131,6 +129,22 @@ func (c *OneBotChannel) setMsgEmojiLike(messageID string, emojiID int, set bool) }() } +// ReactToMessage implements channels.ReactionCapable. +// It adds an emoji reaction (ID 289) to group messages and returns an undo function. +// Private messages return a no-op since reactions are only meaningful in groups. 
+func (c *OneBotChannel) ReactToMessage(ctx context.Context, chatID, messageID string) (func(), error) { + // Only react in group chats + if !strings.HasPrefix(chatID, "group:") { + return func() {}, nil + } + + c.setMsgEmojiLike(messageID, 289, true) + + return func() { + c.setMsgEmojiLike(messageID, 289, false) + }, nil +} + func (c *OneBotChannel) Start(ctx context.Context) error { if c.config.WSUrl == "" { return fmt.Errorf("OneBot ws_url not configured") @@ -159,7 +173,7 @@ func (c *OneBotChannel) Start(ctx context.Context) error { } } - c.setRunning(true) + c.SetRunning(true) logger.InfoC("onebot", "OneBot channel started successfully") return nil @@ -300,7 +314,9 @@ func (c *OneBotChannel) sendAPIRequest(action string, params any, timeout time.D } c.writeMu.Lock() + _ = conn.SetWriteDeadline(time.Now().Add(10 * time.Second)) err = conn.WriteMessage(websocket.TextMessage, data) + _ = conn.SetWriteDeadline(time.Time{}) c.writeMu.Unlock() if err != nil { @@ -309,6 +325,9 @@ func (c *OneBotChannel) sendAPIRequest(action string, params any, timeout time.D select { case resp := <-ch: + if resp == nil { + return nil, fmt.Errorf("API request %s: channel stopped", action) + } return resp, nil case <-time.After(timeout): return nil, fmt.Errorf("API request %s timed out after %v", action, timeout) @@ -346,7 +365,7 @@ func (c *OneBotChannel) reconnectLoop() { func (c *OneBotChannel) Stop(ctx context.Context) error { logger.InfoC("onebot", "Stopping OneBot channel") - c.setRunning(false) + c.SetRunning(false) if c.cancel != nil { c.cancel() @@ -354,7 +373,10 @@ func (c *OneBotChannel) Stop(ctx context.Context) error { c.pendingMu.Lock() for echo, ch := range c.pending { - close(ch) + select { + case ch <- nil: // non-blocking wake for blocked sendAPIRequest goroutines + default: + } delete(c.pending, echo) } c.pendingMu.Unlock() @@ -371,7 +393,14 @@ func (c *OneBotChannel) Stop(ctx context.Context) error { func (c *OneBotChannel) Send(ctx context.Context, msg 
bus.OutboundMessage) error { if !c.IsRunning() { - return fmt.Errorf("OneBot channel not running") + return channels.ErrNotRunning + } + + // Check ctx before entering write path + select { + case <-ctx.Done(): + return ctx.Err() + default: } c.mu.Lock() @@ -401,20 +430,127 @@ func (c *OneBotChannel) Send(ctx context.Context, msg bus.OutboundMessage) error } c.writeMu.Lock() + _ = conn.SetWriteDeadline(time.Now().Add(10 * time.Second)) err = conn.WriteMessage(websocket.TextMessage, data) + _ = conn.SetWriteDeadline(time.Time{}) c.writeMu.Unlock() if err != nil { logger.ErrorCF("onebot", "Failed to send message", map[string]any{ "error": err.Error(), }) - return err + return fmt.Errorf("onebot send: %w", channels.ErrTemporary) } - if msgID, ok := c.pendingEmojiMsg.LoadAndDelete(msg.ChatID); ok { - if mid, ok := msgID.(string); ok && mid != "" { - c.setMsgEmojiLike(mid, 289, false) + return nil +} + +// SendMedia implements the channels.MediaSender interface. +func (c *OneBotChannel) SendMedia(ctx context.Context, msg bus.OutboundMediaMessage) error { + if !c.IsRunning() { + return channels.ErrNotRunning + } + + select { + case <-ctx.Done(): + return ctx.Err() + default: + } + + c.mu.Lock() + conn := c.conn + c.mu.Unlock() + + if conn == nil { + return fmt.Errorf("OneBot WebSocket not connected") + } + + store := c.GetMediaStore() + if store == nil { + return fmt.Errorf("no media store available: %w", channels.ErrSendFailed) + } + + // Build media segments + var segments []oneBotMessageSegment + for _, part := range msg.Parts { + localPath, err := store.Resolve(part.Ref) + if err != nil { + logger.ErrorCF("onebot", "Failed to resolve media ref", map[string]any{ + "ref": part.Ref, + "error": err.Error(), + }) + continue } + + var segType string + switch part.Type { + case "image": + segType = "image" + case "video": + segType = "video" + case "audio": + segType = "record" + default: + segType = "file" + } + + segments = append(segments, oneBotMessageSegment{ + Type: 
segType, + Data: map[string]any{"file": "file://" + localPath}, + }) + + if part.Caption != "" { + segments = append(segments, oneBotMessageSegment{ + Type: "text", + Data: map[string]any{"text": part.Caption}, + }) + } + } + + if len(segments) == 0 { + return nil + } + + chatID := msg.ChatID + var action, idKey string + var rawID string + if rest, ok := strings.CutPrefix(chatID, "group:"); ok { + action, idKey, rawID = "send_group_msg", "group_id", rest + } else if rest, ok := strings.CutPrefix(chatID, "private:"); ok { + action, idKey, rawID = "send_private_msg", "user_id", rest + } else { + action, idKey, rawID = "send_private_msg", "user_id", chatID + } + + id, err := strconv.ParseInt(rawID, 10, 64) + if err != nil { + return fmt.Errorf("invalid %s in chatID: %s: %w", idKey, chatID, channels.ErrSendFailed) + } + + echo := fmt.Sprintf("send_%d", atomic.AddInt64(&c.echoCounter, 1)) + + req := oneBotAPIRequest{ + Action: action, + Params: map[string]any{idKey: id, "message": segments}, + Echo: echo, + } + + data, err := json.Marshal(req) + if err != nil { + return fmt.Errorf("failed to marshal OneBot request: %w", err) + } + + c.writeMu.Lock() + _ = conn.SetWriteDeadline(time.Now().Add(10 * time.Second)) + err = conn.WriteMessage(websocket.TextMessage, data) + _ = conn.SetWriteDeadline(time.Time{}) + c.writeMu.Unlock() + + if err != nil { + logger.ErrorCF("onebot", "Failed to send media message", map[string]any{ + "error": err.Error(), + }) + return fmt.Errorf("onebot send media: %w", channels.ErrTemporary) } return nil @@ -571,11 +707,15 @@ type parseMessageResult struct { Text string IsBotMentioned bool Media []string - LocalFiles []string ReplyTo string } -func (c *OneBotChannel) parseMessageSegments(raw json.RawMessage, selfID int64) parseMessageResult { +func (c *OneBotChannel) parseMessageSegments( + raw json.RawMessage, + selfID int64, + store media.MediaStore, + scope string, +) parseMessageResult { if len(raw) == 0 { return parseMessageResult{} } @@ 
-602,10 +742,23 @@ func (c *OneBotChannel) parseMessageSegments(raw json.RawMessage, selfID int64) var textParts []string mentioned := false selfIDStr := strconv.FormatInt(selfID, 10) - var media []string - var localFiles []string + var mediaRefs []string var replyTo string + // Helper to register a local file with the media store + storeFile := func(localPath, filename string) string { + if store != nil { + ref, err := store.Store(localPath, media.MediaMeta{ + Filename: filename, + Source: "onebot", + }, scope) + if err == nil { + return ref + } + } + return localPath // fallback + } + for _, seg := range segments { segType, _ := seg["type"].(string) data, _ := seg["data"].(map[string]any) @@ -641,8 +794,7 @@ func (c *OneBotChannel) parseMessageSegments(raw json.RawMessage, selfID int64) LoggerPrefix: "onebot", }) if localPath != "" { - media = append(media, localPath) - localFiles = append(localFiles, localPath) + mediaRefs = append(mediaRefs, storeFile(localPath, filename)) textParts = append(textParts, fmt.Sprintf("[%s]", segType)) } } @@ -656,24 +808,8 @@ func (c *OneBotChannel) parseMessageSegments(raw json.RawMessage, selfID int64) LoggerPrefix: "onebot", }) if localPath != "" { - localFiles = append(localFiles, localPath) - if c.transcriber != nil && c.transcriber.IsAvailable() { - tctx, tcancel := context.WithTimeout(c.ctx, 30*time.Second) - result, err := c.transcriber.Transcribe(tctx, localPath) - tcancel() - if err != nil { - logger.WarnCF("onebot", "Voice transcription failed", map[string]any{ - "error": err.Error(), - }) - textParts = append(textParts, "[voice (transcription failed)]") - media = append(media, localPath) - } else { - textParts = append(textParts, fmt.Sprintf("[voice transcription: %s]", result.Text)) - } - } else { - textParts = append(textParts, "[voice]") - media = append(media, localPath) - } + textParts = append(textParts, "[voice]") + mediaRefs = append(mediaRefs, storeFile(localPath, "voice.amr")) } } } @@ -701,8 +837,7 @@ func 
(c *OneBotChannel) parseMessageSegments(raw json.RawMessage, selfID int64) return parseMessageResult{ Text: strings.TrimSpace(strings.Join(textParts, "")), IsBotMentioned: mentioned, - Media: media, - LocalFiles: localFiles, + Media: mediaRefs, ReplyTo: replyTo, } } @@ -711,7 +846,13 @@ func (c *OneBotChannel) handleRawEvent(raw *oneBotRawEvent) { switch raw.PostType { case "message": if userID, err := parseJSONInt64(raw.UserID); err == nil && userID > 0 { - if !c.IsAllowed(strconv.FormatInt(userID, 10)) { + // Build minimal sender for allowlist check + sender := bus.SenderInfo{ + Platform: "onebot", + PlatformID: strconv.FormatInt(userID, 10), + CanonicalID: identity.BuildCanonicalID("onebot", strconv.FormatInt(userID, 10)), + } + if !c.IsAllowedSender(sender) { logger.DebugCF("onebot", "Message rejected by allowlist", map[string]any{ "user_id": userID, }) @@ -794,7 +935,17 @@ func (c *OneBotChannel) handleMessage(raw *oneBotRawEvent) { selfID = atomic.LoadInt64(&c.selfID) } - parsed := c.parseMessageSegments(raw.Message, selfID) + // Compute scope for media store before parsing (parsing may download files) + var chatIDForScope string + switch raw.MessageType { + case "group": + chatIDForScope = "group:" + strconv.FormatInt(groupID, 10) + default: + chatIDForScope = "private:" + strconv.FormatInt(userID, 10) + } + scope := channels.BuildMediaScope("onebot", chatIDForScope, messageID) + + parsed := c.parseMessageSegments(raw.Message, selfID, c.GetMediaStore(), scope) isBotMentioned := parsed.IsBotMentioned content := raw.RawMessage @@ -823,20 +974,6 @@ func (c *OneBotChannel) handleMessage(raw *oneBotRawEvent) { } } - // Clean up temp files when done - if len(parsed.LocalFiles) > 0 { - defer func() { - for _, f := range parsed.LocalFiles { - if err := os.Remove(f); err != nil { - logger.DebugCF("onebot", "Failed to remove temp file", map[string]any{ - "path": f, - "error": err.Error(), - }) - } - } - }() - } - if c.isDuplicate(messageID) { logger.DebugCF("onebot", 
"Duplicate message, skipping", map[string]any{ "message_id": messageID, @@ -854,9 +991,9 @@ func (c *OneBotChannel) handleMessage(raw *oneBotRawEvent) { senderID := strconv.FormatInt(userID, 10) var chatID string - metadata := map[string]string{ - "message_id": messageID, - } + var peer bus.Peer + + metadata := map[string]string{} if parsed.ReplyTo != "" { metadata["reply_to_message_id"] = parsed.ReplyTo @@ -865,14 +1002,12 @@ func (c *OneBotChannel) handleMessage(raw *oneBotRawEvent) { switch raw.MessageType { case "private": chatID = "private:" + senderID - metadata["peer_kind"] = "direct" - metadata["peer_id"] = senderID + peer = bus.Peer{Kind: "direct", ID: senderID} case "group": groupIDStr := strconv.FormatInt(groupID, 10) chatID = "group:" + groupIDStr - metadata["peer_kind"] = "group" - metadata["peer_id"] = groupIDStr + peer = bus.Peer{Kind: "group", ID: groupIDStr} metadata["group_id"] = groupIDStr senderUserID, _ := parseJSONInt64(sender.UserID) @@ -886,8 +1021,8 @@ func (c *OneBotChannel) handleMessage(raw *oneBotRawEvent) { metadata["sender_name"] = sender.Nickname } - triggered, strippedContent := c.checkGroupTrigger(content, isBotMentioned) - if !triggered { + respond, strippedContent := c.ShouldRespondInGroup(isBotMentioned, content) + if !respond { logger.DebugCF("onebot", "Group message ignored (no trigger)", map[string]any{ "sender": senderID, "group": groupIDStr, @@ -922,12 +1057,21 @@ func (c *OneBotChannel) handleMessage(raw *oneBotRawEvent) { c.lastMessageID.Store(chatID, messageID) - if raw.MessageType == "group" && messageID != "" && messageID != "0" { - c.setMsgEmojiLike(messageID, 289, true) - c.pendingEmojiMsg.Store(chatID, messageID) + senderInfo := bus.SenderInfo{ + Platform: "onebot", + PlatformID: senderID, + CanonicalID: identity.BuildCanonicalID("onebot", senderID), + DisplayName: sender.Nickname, } - c.HandleMessage(senderID, chatID, content, parsed.Media, metadata) + if !c.IsAllowedSender(senderInfo) { + logger.DebugCF("onebot", 
"Message rejected by allowlist (senderInfo)", map[string]any{ + "sender": senderID, + }) + return + } + + c.HandleMessage(c.ctx, peer, messageID, senderID, chatID, content, parsed.Media, metadata, senderInfo) } func (c *OneBotChannel) isDuplicate(messageID string) bool { @@ -959,23 +1103,3 @@ func truncate(s string, n int) string { } return string(runes[:n]) + "..." } - -func (c *OneBotChannel) checkGroupTrigger( - content string, - isBotMentioned bool, -) (triggered bool, strippedContent string) { - if isBotMentioned { - return true, strings.TrimSpace(content) - } - - for _, prefix := range c.config.GroupTriggerPrefix { - if prefix == "" { - continue - } - if after, ok := strings.CutPrefix(content, prefix); ok { - return true, strings.TrimSpace(after) - } - } - - return false, content -} diff --git a/pkg/channels/pico/init.go b/pkg/channels/pico/init.go new file mode 100644 index 000000000..96d764418 --- /dev/null +++ b/pkg/channels/pico/init.go @@ -0,0 +1,13 @@ +package pico + +import ( + "github.com/sipeed/picoclaw/pkg/bus" + "github.com/sipeed/picoclaw/pkg/channels" + "github.com/sipeed/picoclaw/pkg/config" +) + +func init() { + channels.RegisterFactory("pico", func(cfg *config.Config, b *bus.MessageBus) (channels.Channel, error) { + return NewPicoChannel(cfg.Channels.Pico, b) + }) +} diff --git a/pkg/channels/pico/pico.go b/pkg/channels/pico/pico.go new file mode 100644 index 000000000..8d8b62a67 --- /dev/null +++ b/pkg/channels/pico/pico.go @@ -0,0 +1,462 @@ +package pico + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "strings" + "sync" + "sync/atomic" + "time" + + "github.com/google/uuid" + "github.com/gorilla/websocket" + + "github.com/sipeed/picoclaw/pkg/bus" + "github.com/sipeed/picoclaw/pkg/channels" + "github.com/sipeed/picoclaw/pkg/config" + "github.com/sipeed/picoclaw/pkg/identity" + "github.com/sipeed/picoclaw/pkg/logger" +) + +// picoConn represents a single WebSocket connection. 
+type picoConn struct { + id string + conn *websocket.Conn + sessionID string + writeMu sync.Mutex + closed atomic.Bool +} + +// writeJSON sends a JSON message to the connection with write locking. +func (pc *picoConn) writeJSON(v any) error { + if pc.closed.Load() { + return fmt.Errorf("connection closed") + } + pc.writeMu.Lock() + defer pc.writeMu.Unlock() + return pc.conn.WriteJSON(v) +} + +// close closes the connection. +func (pc *picoConn) close() { + if pc.closed.CompareAndSwap(false, true) { + pc.conn.Close() + } +} + +// PicoChannel implements the native Pico Protocol WebSocket channel. +// It serves as the reference implementation for all optional capability interfaces. +type PicoChannel struct { + *channels.BaseChannel + config config.PicoConfig + upgrader websocket.Upgrader + connections sync.Map // connID → *picoConn + connCount atomic.Int32 + ctx context.Context + cancel context.CancelFunc +} + +// NewPicoChannel creates a new Pico Protocol channel. +func NewPicoChannel(cfg config.PicoConfig, messageBus *bus.MessageBus) (*PicoChannel, error) { + if cfg.Token == "" { + return nil, fmt.Errorf("pico token is required") + } + + base := channels.NewBaseChannel("pico", cfg, messageBus, cfg.AllowFrom) + + allowOrigins := cfg.AllowOrigins + checkOrigin := func(r *http.Request) bool { + if len(allowOrigins) == 0 { + return true // allow all if not configured + } + origin := r.Header.Get("Origin") + for _, allowed := range allowOrigins { + if allowed == "*" || allowed == origin { + return true + } + } + return false + } + + return &PicoChannel{ + BaseChannel: base, + config: cfg, + upgrader: websocket.Upgrader{ + CheckOrigin: checkOrigin, + ReadBufferSize: 1024, + WriteBufferSize: 1024, + }, + }, nil +} + +// Start implements Channel. 
+func (c *PicoChannel) Start(ctx context.Context) error { + logger.InfoC("pico", "Starting Pico Protocol channel") + c.ctx, c.cancel = context.WithCancel(ctx) + c.SetRunning(true) + logger.InfoC("pico", "Pico Protocol channel started") + return nil +} + +// Stop implements Channel. +func (c *PicoChannel) Stop(ctx context.Context) error { + logger.InfoC("pico", "Stopping Pico Protocol channel") + c.SetRunning(false) + + // Close all connections + c.connections.Range(func(key, value any) bool { + if pc, ok := value.(*picoConn); ok { + pc.close() + } + c.connections.Delete(key) + return true + }) + + if c.cancel != nil { + c.cancel() + } + + logger.InfoC("pico", "Pico Protocol channel stopped") + return nil +} + +// WebhookPath implements channels.WebhookHandler. +func (c *PicoChannel) WebhookPath() string { return "/pico/" } + +// ServeHTTP implements http.Handler for the shared HTTP server. +func (c *PicoChannel) ServeHTTP(w http.ResponseWriter, r *http.Request) { + path := strings.TrimPrefix(r.URL.Path, "/pico") + + switch { + case path == "/ws" || path == "/ws/": + c.handleWebSocket(w, r) + default: + http.NotFound(w, r) + } +} + +// Send implements Channel — sends a message to the appropriate WebSocket connection. +func (c *PicoChannel) Send(ctx context.Context, msg bus.OutboundMessage) error { + if !c.IsRunning() { + return channels.ErrNotRunning + } + + outMsg := newMessage(TypeMessageCreate, map[string]any{ + "content": msg.Content, + }) + + return c.broadcastToSession(msg.ChatID, outMsg) +} + +// EditMessage implements channels.MessageEditor. +func (c *PicoChannel) EditMessage(ctx context.Context, chatID string, messageID string, content string) error { + outMsg := newMessage(TypeMessageUpdate, map[string]any{ + "message_id": messageID, + "content": content, + }) + return c.broadcastToSession(chatID, outMsg) +} + +// StartTyping implements channels.TypingCapable. 
+func (c *PicoChannel) StartTyping(ctx context.Context, chatID string) (func(), error) { + startMsg := newMessage(TypeTypingStart, nil) + if err := c.broadcastToSession(chatID, startMsg); err != nil { + return func() {}, err + } + return func() { + stopMsg := newMessage(TypeTypingStop, nil) + c.broadcastToSession(chatID, stopMsg) + }, nil +} + +// SendPlaceholder implements channels.PlaceholderCapable. +// It sends a placeholder message via the Pico Protocol that will later be +// edited to the actual response via EditMessage (channels.MessageEditor). +func (c *PicoChannel) SendPlaceholder(ctx context.Context, chatID string) (string, error) { + if !c.config.Placeholder.Enabled { + return "", nil + } + + text := c.config.Placeholder.Text + if text == "" { + text = "Thinking... 💭" + } + + msgID := uuid.New().String() + outMsg := newMessage(TypeMessageCreate, map[string]any{ + "content": text, + "message_id": msgID, + }) + + if err := c.broadcastToSession(chatID, outMsg); err != nil { + return "", err + } + + return msgID, nil +} + +// broadcastToSession sends a message to all connections with a matching session. +func (c *PicoChannel) broadcastToSession(chatID string, msg PicoMessage) error { + // chatID format: "pico:" + sessionID := strings.TrimPrefix(chatID, "pico:") + msg.SessionID = sessionID + + var sent bool + c.connections.Range(func(key, value any) bool { + pc, ok := value.(*picoConn) + if !ok { + return true + } + if pc.sessionID == sessionID { + if err := pc.writeJSON(msg); err != nil { + logger.DebugCF("pico", "Write to connection failed", map[string]any{ + "conn_id": pc.id, + "error": err.Error(), + }) + } else { + sent = true + } + } + return true + }) + + if !sent { + return fmt.Errorf("no active connections for session %s: %w", sessionID, channels.ErrSendFailed) + } + return nil +} + +// handleWebSocket upgrades the HTTP connection and manages the WebSocket lifecycle. 
+func (c *PicoChannel) handleWebSocket(w http.ResponseWriter, r *http.Request) { + if !c.IsRunning() { + http.Error(w, "channel not running", http.StatusServiceUnavailable) + return + } + + // Authenticate + if !c.authenticate(r) { + http.Error(w, "unauthorized", http.StatusUnauthorized) + return + } + + // Check connection limit + maxConns := c.config.MaxConnections + if maxConns <= 0 { + maxConns = 100 + } + if int(c.connCount.Load()) >= maxConns { + http.Error(w, "too many connections", http.StatusServiceUnavailable) + return + } + + conn, err := c.upgrader.Upgrade(w, r, nil) + if err != nil { + logger.ErrorCF("pico", "WebSocket upgrade failed", map[string]any{ + "error": err.Error(), + }) + return + } + + // Determine session ID from query param or generate one + sessionID := r.URL.Query().Get("session_id") + if sessionID == "" { + sessionID = uuid.New().String() + } + + pc := &picoConn{ + id: uuid.New().String(), + conn: conn, + sessionID: sessionID, + } + + c.connections.Store(pc.id, pc) + c.connCount.Add(1) + + logger.InfoCF("pico", "WebSocket client connected", map[string]any{ + "conn_id": pc.id, + "session_id": sessionID, + }) + + go c.readLoop(pc) +} + +// authenticate checks the Bearer token from the Authorization header. +// Query parameter authentication is only allowed when AllowTokenQuery is explicitly enabled. +func (c *PicoChannel) authenticate(r *http.Request) bool { + token := c.config.Token + if token == "" { + return false + } + + // Check Authorization header + auth := r.Header.Get("Authorization") + if after, ok := strings.CutPrefix(auth, "Bearer "); ok { + if after == token { + return true + } + } + + // Check query parameter only when explicitly allowed + if c.config.AllowTokenQuery { + if r.URL.Query().Get("token") == token { + return true + } + } + + return false +} + +// readLoop reads messages from a WebSocket connection. 
+func (c *PicoChannel) readLoop(pc *picoConn) { + defer func() { + pc.close() + c.connections.Delete(pc.id) + c.connCount.Add(-1) + logger.InfoCF("pico", "WebSocket client disconnected", map[string]any{ + "conn_id": pc.id, + "session_id": pc.sessionID, + }) + }() + + readTimeout := time.Duration(c.config.ReadTimeout) * time.Second + if readTimeout <= 0 { + readTimeout = 60 * time.Second + } + + _ = pc.conn.SetReadDeadline(time.Now().Add(readTimeout)) + pc.conn.SetPongHandler(func(appData string) error { + _ = pc.conn.SetReadDeadline(time.Now().Add(readTimeout)) + return nil + }) + + // Start ping ticker + pingInterval := time.Duration(c.config.PingInterval) * time.Second + if pingInterval <= 0 { + pingInterval = 30 * time.Second + } + go c.pingLoop(pc, pingInterval) + + for { + select { + case <-c.ctx.Done(): + return + default: + } + + _, rawMsg, err := pc.conn.ReadMessage() + if err != nil { + if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseNormalClosure) { + logger.DebugCF("pico", "WebSocket read error", map[string]any{ + "conn_id": pc.id, + "error": err.Error(), + }) + } + return + } + + _ = pc.conn.SetReadDeadline(time.Now().Add(readTimeout)) + + var msg PicoMessage + if err := json.Unmarshal(rawMsg, &msg); err != nil { + errMsg := newError("invalid_message", "failed to parse message") + pc.writeJSON(errMsg) + continue + } + + c.handleMessage(pc, msg) + } +} + +// pingLoop sends periodic ping frames to keep the connection alive. +func (c *PicoChannel) pingLoop(pc *picoConn, interval time.Duration) { + ticker := time.NewTicker(interval) + defer ticker.Stop() + + for { + select { + case <-c.ctx.Done(): + return + case <-ticker.C: + if pc.closed.Load() { + return + } + pc.writeMu.Lock() + err := pc.conn.WriteMessage(websocket.PingMessage, nil) + pc.writeMu.Unlock() + if err != nil { + return + } + } + } +} + +// handleMessage processes an inbound Pico Protocol message. 
+func (c *PicoChannel) handleMessage(pc *picoConn, msg PicoMessage) { + switch msg.Type { + case TypePing: + pong := newMessage(TypePong, nil) + pong.ID = msg.ID + pc.writeJSON(pong) + + case TypeMessageSend: + c.handleMessageSend(pc, msg) + + default: + errMsg := newError("unknown_type", fmt.Sprintf("unknown message type: %s", msg.Type)) + pc.writeJSON(errMsg) + } +} + +// handleMessageSend processes an inbound message.send from a client. +func (c *PicoChannel) handleMessageSend(pc *picoConn, msg PicoMessage) { + content, _ := msg.Payload["content"].(string) + if strings.TrimSpace(content) == "" { + errMsg := newError("empty_content", "message content is empty") + pc.writeJSON(errMsg) + return + } + + sessionID := msg.SessionID + if sessionID == "" { + sessionID = pc.sessionID + } + + chatID := "pico:" + sessionID + senderID := "pico-user" + + peer := bus.Peer{Kind: "direct", ID: "pico:" + sessionID} + + metadata := map[string]string{ + "platform": "pico", + "session_id": sessionID, + "conn_id": pc.id, + } + + logger.DebugCF("pico", "Received message", map[string]any{ + "session_id": sessionID, + "preview": truncate(content, 50), + }) + + sender := bus.SenderInfo{ + Platform: "pico", + PlatformID: senderID, + CanonicalID: identity.BuildCanonicalID("pico", senderID), + } + + if !c.IsAllowedSender(sender) { + return + } + + c.HandleMessage(c.ctx, peer, msg.ID, senderID, chatID, content, nil, metadata, sender) +} + +// truncate truncates a string to maxLen runes. +func truncate(s string, maxLen int) string { + runes := []rune(s) + if len(runes) <= maxLen { + return s + } + return string(runes[:maxLen]) + "..." +} diff --git a/pkg/channels/pico/protocol.go b/pkg/channels/pico/protocol.go new file mode 100644 index 000000000..0a630e193 --- /dev/null +++ b/pkg/channels/pico/protocol.go @@ -0,0 +1,46 @@ +package pico + +import "time" + +// Protocol message types. +const ( + // TypeMessageSend is sent from client to server. 
+ TypeMessageSend = "message.send" + TypeMediaSend = "media.send" + TypePing = "ping" + + // TypeMessageCreate is sent from server to client. + TypeMessageCreate = "message.create" + TypeMessageUpdate = "message.update" + TypeMediaCreate = "media.create" + TypeTypingStart = "typing.start" + TypeTypingStop = "typing.stop" + TypeError = "error" + TypePong = "pong" +) + +// PicoMessage is the wire format for all Pico Protocol messages. +type PicoMessage struct { + Type string `json:"type"` + ID string `json:"id,omitempty"` + SessionID string `json:"session_id,omitempty"` + Timestamp int64 `json:"timestamp,omitempty"` + Payload map[string]any `json:"payload,omitempty"` +} + +// newMessage creates a PicoMessage with the given type and payload. +func newMessage(msgType string, payload map[string]any) PicoMessage { + return PicoMessage{ + Type: msgType, + Timestamp: time.Now().UnixMilli(), + Payload: payload, + } +} + +// newError creates an error PicoMessage. +func newError(code, message string) PicoMessage { + return newMessage(TypeError, map[string]any{ + "code": code, + "message": message, + }) +} diff --git a/pkg/channels/qq/init.go b/pkg/channels/qq/init.go new file mode 100644 index 000000000..15b955089 --- /dev/null +++ b/pkg/channels/qq/init.go @@ -0,0 +1,13 @@ +package qq + +import ( + "github.com/sipeed/picoclaw/pkg/bus" + "github.com/sipeed/picoclaw/pkg/channels" + "github.com/sipeed/picoclaw/pkg/config" +) + +func init() { + channels.RegisterFactory("qq", func(cfg *config.Config, b *bus.MessageBus) (channels.Channel, error) { + return NewQQChannel(cfg.Channels.QQ, b) + }) +} diff --git a/pkg/channels/qq.go b/pkg/channels/qq/qq.go similarity index 74% rename from pkg/channels/qq.go rename to pkg/channels/qq/qq.go index b10776db6..112964143 100644 --- a/pkg/channels/qq.go +++ b/pkg/channels/qq/qq.go @@ -1,4 +1,4 @@ -package channels +package qq import ( "context" @@ -14,12 +14,14 @@ import ( "golang.org/x/oauth2" "github.com/sipeed/picoclaw/pkg/bus" + 
"github.com/sipeed/picoclaw/pkg/channels" "github.com/sipeed/picoclaw/pkg/config" + "github.com/sipeed/picoclaw/pkg/identity" "github.com/sipeed/picoclaw/pkg/logger" ) type QQChannel struct { - *BaseChannel + *channels.BaseChannel config config.QQConfig api openapi.OpenAPI tokenSource oauth2.TokenSource @@ -31,7 +33,10 @@ type QQChannel struct { } func NewQQChannel(cfg config.QQConfig, messageBus *bus.MessageBus) (*QQChannel, error) { - base := NewBaseChannel("qq", cfg, messageBus, cfg.AllowFrom) + base := channels.NewBaseChannel("qq", cfg, messageBus, cfg.AllowFrom, + channels.WithGroupTrigger(cfg.GroupTrigger), + channels.WithReasoningChannelID(cfg.ReasoningChannelID), + ) return &QQChannel{ BaseChannel: base, @@ -90,11 +95,11 @@ func (c *QQChannel) Start(ctx context.Context) error { logger.ErrorCF("qq", "WebSocket session error", map[string]any{ "error": err.Error(), }) - c.setRunning(false) + c.SetRunning(false) } }() - c.setRunning(true) + c.SetRunning(true) logger.InfoC("qq", "QQ bot started successfully") return nil @@ -102,7 +107,7 @@ func (c *QQChannel) Start(ctx context.Context) error { func (c *QQChannel) Stop(ctx context.Context) error { logger.InfoC("qq", "Stopping QQ bot") - c.setRunning(false) + c.SetRunning(false) if c.cancel != nil { c.cancel() @@ -113,7 +118,7 @@ func (c *QQChannel) Stop(ctx context.Context) error { func (c *QQChannel) Send(ctx context.Context, msg bus.OutboundMessage) error { if !c.IsRunning() { - return fmt.Errorf("QQ bot not running") + return channels.ErrNotRunning } // construct message @@ -127,7 +132,7 @@ func (c *QQChannel) Send(ctx context.Context, msg bus.OutboundMessage) error { logger.ErrorCF("qq", "Failed to send C2C message", map[string]any{ "error": err.Error(), }) - return err + return fmt.Errorf("qq send: %w", channels.ErrTemporary) } return nil @@ -162,20 +167,35 @@ func (c *QQChannel) handleC2CMessage() event.C2CMessageEventHandler { "length": len(content), }) - // forward to message bus - metadata := 
map[string]string{ - "message_id": data.ID, - "peer_kind": "direct", - "peer_id": senderID, + // 转发到消息总线 + metadata := map[string]string{} + + sender := bus.SenderInfo{ + Platform: "qq", + PlatformID: data.Author.ID, + CanonicalID: identity.BuildCanonicalID("qq", data.Author.ID), } - c.HandleMessage(senderID, senderID, content, []string{}, metadata) + if !c.IsAllowedSender(sender) { + return nil + } + + c.HandleMessage(c.ctx, + bus.Peer{Kind: "direct", ID: senderID}, + data.ID, + senderID, + senderID, + content, + []string{}, + metadata, + sender, + ) return nil } } -// handleGroupATMessage handles group @messages +// handleGroupATMessage handles QQ group @ messages func (c *QQChannel) handleGroupATMessage() event.GroupATMessageEventHandler { return func(event *dto.WSPayload, data *dto.WSGroupATMessageData) error { // deduplication check @@ -192,34 +212,57 @@ func (c *QQChannel) handleGroupATMessage() event.GroupATMessageEventHandler { return nil } - // extract message content (remove @bot part) + // extract message content (remove @ bot part) content := data.Content if content == "" { logger.DebugC("qq", "Received empty group message, ignoring") return nil } + // GroupAT event means bot is always mentioned; apply group trigger filtering + respond, cleaned := c.ShouldRespondInGroup(true, content) + if !respond { + return nil + } + content = cleaned + logger.InfoCF("qq", "Received group AT message", map[string]any{ "sender": senderID, "group": data.GroupID, "length": len(content), }) - // forward to message bus (use GroupID as ChatID) + // 转发到消息总线(使用 GroupID 作为 ChatID) metadata := map[string]string{ - "message_id": data.ID, - "group_id": data.GroupID, - "peer_kind": "group", - "peer_id": data.GroupID, + "group_id": data.GroupID, } - c.HandleMessage(senderID, data.GroupID, content, []string{}, metadata) + sender := bus.SenderInfo{ + Platform: "qq", + PlatformID: data.Author.ID, + CanonicalID: identity.BuildCanonicalID("qq", data.Author.ID), + } + + if 
!c.IsAllowedSender(sender) { + return nil + } + + c.HandleMessage(c.ctx, + bus.Peer{Kind: "group", ID: data.GroupID}, + data.ID, + senderID, + data.GroupID, + content, + []string{}, + metadata, + sender, + ) return nil } } -// isDuplicate checks if message is duplicate +// isDuplicate 检查消息是否重复 func (c *QQChannel) isDuplicate(messageID string) bool { c.mu.Lock() defer c.mu.Unlock() @@ -230,9 +273,9 @@ func (c *QQChannel) isDuplicate(messageID string) bool { c.processedIDs[messageID] = true - // simple cleanup: limit map size + // 简单清理:限制 map 大小 if len(c.processedIDs) > 10000 { - // clear half + // 清空一半 count := 0 for id := range c.processedIDs { if count >= 5000 { diff --git a/pkg/channels/registry.go b/pkg/channels/registry.go new file mode 100644 index 000000000..36a05bf3e --- /dev/null +++ b/pkg/channels/registry.go @@ -0,0 +1,32 @@ +package channels + +import ( + "sync" + + "github.com/sipeed/picoclaw/pkg/bus" + "github.com/sipeed/picoclaw/pkg/config" +) + +// ChannelFactory is a constructor function that creates a Channel from config and message bus. +// Each channel subpackage registers one or more factories via init(). +type ChannelFactory func(cfg *config.Config, bus *bus.MessageBus) (Channel, error) + +var ( + factoriesMu sync.RWMutex + factories = map[string]ChannelFactory{} +) + +// RegisterFactory registers a named channel factory. Called from subpackage init() functions. +func RegisterFactory(name string, f ChannelFactory) { + factoriesMu.Lock() + defer factoriesMu.Unlock() + factories[name] = f +} + +// getFactory looks up a channel factory by name. 
+func getFactory(name string) (ChannelFactory, bool) { + factoriesMu.RLock() + defer factoriesMu.RUnlock() + f, ok := factories[name] + return f, ok +} diff --git a/pkg/channels/slack/init.go b/pkg/channels/slack/init.go new file mode 100644 index 000000000..c131bb291 --- /dev/null +++ b/pkg/channels/slack/init.go @@ -0,0 +1,13 @@ +package slack + +import ( + "github.com/sipeed/picoclaw/pkg/bus" + "github.com/sipeed/picoclaw/pkg/channels" + "github.com/sipeed/picoclaw/pkg/config" +) + +func init() { + channels.RegisterFactory("slack", func(cfg *config.Config, b *bus.MessageBus) (channels.Channel, error) { + return NewSlackChannel(cfg.Channels.Slack, b) + }) +} diff --git a/pkg/channels/slack.go b/pkg/channels/slack/slack.go similarity index 66% rename from pkg/channels/slack.go rename to pkg/channels/slack/slack.go index cfb731b16..024b1b023 100644 --- a/pkg/channels/slack.go +++ b/pkg/channels/slack/slack.go @@ -1,32 +1,31 @@ -package channels +package slack import ( "context" "fmt" - "os" "strings" "sync" - "time" "github.com/slack-go/slack" "github.com/slack-go/slack/slackevents" "github.com/slack-go/slack/socketmode" "github.com/sipeed/picoclaw/pkg/bus" + "github.com/sipeed/picoclaw/pkg/channels" "github.com/sipeed/picoclaw/pkg/config" + "github.com/sipeed/picoclaw/pkg/identity" "github.com/sipeed/picoclaw/pkg/logger" + "github.com/sipeed/picoclaw/pkg/media" "github.com/sipeed/picoclaw/pkg/utils" - "github.com/sipeed/picoclaw/pkg/voice" ) type SlackChannel struct { - *BaseChannel + *channels.BaseChannel config config.SlackConfig api *slack.Client socketClient *socketmode.Client botUserID string teamID string - transcriber *voice.GroqTranscriber ctx context.Context cancel context.CancelFunc pendingAcks sync.Map @@ -49,7 +48,11 @@ func NewSlackChannel(cfg config.SlackConfig, messageBus *bus.MessageBus) (*Slack socketClient := socketmode.New(api) - base := NewBaseChannel("slack", cfg, messageBus, cfg.AllowFrom) + base := channels.NewBaseChannel("slack", cfg, 
messageBus, cfg.AllowFrom, + channels.WithMaxMessageLength(40000), + channels.WithGroupTrigger(cfg.GroupTrigger), + channels.WithReasoningChannelID(cfg.ReasoningChannelID), + ) return &SlackChannel{ BaseChannel: base, @@ -59,10 +62,6 @@ func NewSlackChannel(cfg config.SlackConfig, messageBus *bus.MessageBus) (*Slack }, nil } -func (c *SlackChannel) SetTranscriber(transcriber *voice.GroqTranscriber) { - c.transcriber = transcriber -} - func (c *SlackChannel) Start(ctx context.Context) error { logger.InfoC("slack", "Starting Slack channel (Socket Mode)") @@ -92,7 +91,7 @@ func (c *SlackChannel) Start(ctx context.Context) error { } }() - c.setRunning(true) + c.SetRunning(true) logger.InfoC("slack", "Slack channel started (Socket Mode)") return nil } @@ -104,14 +103,14 @@ func (c *SlackChannel) Stop(ctx context.Context) error { c.cancel() } - c.setRunning(false) + c.SetRunning(false) logger.InfoC("slack", "Slack channel stopped") return nil } func (c *SlackChannel) Send(ctx context.Context, msg bus.OutboundMessage) error { if !c.IsRunning() { - return fmt.Errorf("slack channel not running") + return channels.ErrNotRunning } channelID, threadTS := parseSlackChatID(msg.ChatID) @@ -129,7 +128,7 @@ func (c *SlackChannel) Send(ctx context.Context, msg bus.OutboundMessage) error _, _, err := c.api.PostMessageContext(ctx, channelID, opts...) if err != nil { - return fmt.Errorf("failed to send slack message: %w", err) + return fmt.Errorf("slack send: %w", channels.ErrTemporary) } if ref, ok := c.pendingAcks.LoadAndDelete(msg.ChatID); ok { @@ -148,6 +147,82 @@ func (c *SlackChannel) Send(ctx context.Context, msg bus.OutboundMessage) error return nil } +// SendMedia implements the channels.MediaSender interface. 
+func (c *SlackChannel) SendMedia(ctx context.Context, msg bus.OutboundMediaMessage) error { + if !c.IsRunning() { + return channels.ErrNotRunning + } + + channelID, _ := parseSlackChatID(msg.ChatID) + if channelID == "" { + return fmt.Errorf("invalid slack chat ID: %s", msg.ChatID) + } + + store := c.GetMediaStore() + if store == nil { + return fmt.Errorf("no media store available: %w", channels.ErrSendFailed) + } + + for _, part := range msg.Parts { + localPath, err := store.Resolve(part.Ref) + if err != nil { + logger.ErrorCF("slack", "Failed to resolve media ref", map[string]any{ + "ref": part.Ref, + "error": err.Error(), + }) + continue + } + + filename := part.Filename + if filename == "" { + filename = "file" + } + + title := part.Caption + if title == "" { + title = filename + } + + _, err = c.api.UploadFileV2Context(ctx, slack.UploadFileV2Parameters{ + Channel: channelID, + File: localPath, + Filename: filename, + Title: title, + }) + if err != nil { + logger.ErrorCF("slack", "Failed to upload media", map[string]any{ + "filename": filename, + "error": err.Error(), + }) + return fmt.Errorf("slack send media: %w", channels.ErrTemporary) + } + } + + return nil +} + +// ReactToMessage implements channels.ReactionCapable. +// It adds an "eyes" (👀) reaction to the inbound message and returns an undo function +// that removes the reaction. 
+func (c *SlackChannel) ReactToMessage(ctx context.Context, chatID, messageID string) (func(), error) { + channelID, _ := parseSlackChatID(chatID) + if channelID == "" { + return func() {}, nil + } + + c.api.AddReaction("eyes", slack.ItemRef{ + Channel: channelID, + Timestamp: messageID, + }) + + return func() { + c.api.RemoveReaction("eyes", slack.ItemRef{ + Channel: channelID, + Timestamp: messageID, + }) + }, nil +} + func (c *SlackChannel) eventLoop() { for { select { @@ -201,7 +276,12 @@ func (c *SlackChannel) handleMessageEvent(ev *slackevents.MessageEvent) { } // check allowlist to avoid downloading attachments for rejected users - if !c.IsAllowed(ev.User) { + sender := bus.SenderInfo{ + Platform: "slack", + PlatformID: ev.User, + CanonicalID: identity.BuildCanonicalID("slack", ev.User), + } + if !c.IsAllowedSender(sender) { logger.DebugCF("slack", "Message rejected by allowlist", map[string]any{ "user_id": ev.User, }) @@ -218,11 +298,6 @@ func (c *SlackChannel) handleMessageEvent(ev *slackevents.MessageEvent) { chatID = channelID + "/" + threadTS } - c.api.AddReaction("eyes", slack.ItemRef{ - Channel: channelID, - Timestamp: messageTS, - }) - c.pendingAcks.Store(chatID, slackMessageRef{ ChannelID: channelID, Timestamp: messageTS, @@ -231,20 +306,32 @@ func (c *SlackChannel) handleMessageEvent(ev *slackevents.MessageEvent) { content := ev.Text content = c.stripBotMention(content) - var mediaPaths []string - localFiles := []string{} // track local files that need cleanup + // In non-DM channels, apply group trigger filtering + if !strings.HasPrefix(channelID, "D") { + respond, cleaned := c.ShouldRespondInGroup(false, content) + if !respond { + return + } + content = cleaned + } - // ensure temp files are cleaned up when function returns - defer func() { - for _, file := range localFiles { - if err := os.Remove(file); err != nil { - logger.DebugCF("slack", "Failed to cleanup temp file", map[string]any{ - "file": file, - "error": err.Error(), - }) + var 
mediaPaths []string + + scope := channels.BuildMediaScope("slack", chatID, messageTS) + + // Helper to register a local file with the media store + storeMedia := func(localPath, filename string) string { + if store := c.GetMediaStore(); store != nil { + ref, err := store.Store(localPath, media.MediaMeta{ + Filename: filename, + Source: "slack", + }, scope) + if err == nil { + return ref } } - }() + return localPath // fallback + } if ev.Message != nil && len(ev.Message.Files) > 0 { for _, file := range ev.Message.Files { @@ -252,23 +339,8 @@ func (c *SlackChannel) handleMessageEvent(ev *slackevents.MessageEvent) { if localPath == "" { continue } - localFiles = append(localFiles, localPath) - mediaPaths = append(mediaPaths, localPath) - - if utils.IsAudioFile(file.Name, file.Mimetype) && c.transcriber != nil && c.transcriber.IsAvailable() { - ctx, cancel := context.WithTimeout(c.ctx, 30*time.Second) - defer cancel() - result, err := c.transcriber.Transcribe(ctx, localPath) - - if err != nil { - logger.ErrorCF("slack", "Voice transcription failed", map[string]any{"error": err.Error()}) - content += fmt.Sprintf("\n[audio: %s (transcription failed)]", file.Name) - } else { - content += fmt.Sprintf("\n[voice transcription: %s]", result.Text) - } - } else { - content += fmt.Sprintf("\n[file: %s]", file.Name) - } + mediaPaths = append(mediaPaths, storeMedia(localPath, file.Name)) + content += fmt.Sprintf("\n[file: %s]", file.Name) } } @@ -283,13 +355,13 @@ func (c *SlackChannel) handleMessageEvent(ev *slackevents.MessageEvent) { peerID = senderID } + peer := bus.Peer{Kind: peerKind, ID: peerID} + metadata := map[string]string{ "message_ts": messageTS, "channel_id": channelID, "thread_ts": threadTS, "platform": "slack", - "peer_kind": peerKind, - "peer_id": peerID, "team_id": c.teamID, } @@ -300,7 +372,7 @@ func (c *SlackChannel) handleMessageEvent(ev *slackevents.MessageEvent) { "has_thread": threadTS != "", }) - c.HandleMessage(senderID, chatID, content, mediaPaths, 
metadata) + c.HandleMessage(c.ctx, peer, messageTS, senderID, chatID, content, mediaPaths, metadata, sender) } func (c *SlackChannel) handleAppMention(ev *slackevents.AppMentionEvent) { @@ -308,7 +380,11 @@ func (c *SlackChannel) handleAppMention(ev *slackevents.AppMentionEvent) { return } - if !c.IsAllowed(ev.User) { + if !c.IsAllowedSender(bus.SenderInfo{ + Platform: "slack", + PlatformID: ev.User, + CanonicalID: identity.BuildCanonicalID("slack", ev.User), + }) { logger.DebugCF("slack", "Mention rejected by allowlist", map[string]any{ "user_id": ev.User, }) @@ -316,6 +392,11 @@ func (c *SlackChannel) handleAppMention(ev *slackevents.AppMentionEvent) { } senderID := ev.User + mentionSender := bus.SenderInfo{ + Platform: "slack", + PlatformID: senderID, + CanonicalID: identity.BuildCanonicalID("slack", senderID), + } channelID := ev.Channel threadTS := ev.ThreadTimeStamp messageTS := ev.TimeStamp @@ -327,11 +408,6 @@ func (c *SlackChannel) handleAppMention(ev *slackevents.AppMentionEvent) { chatID = channelID + "/" + messageTS } - c.api.AddReaction("eyes", slack.ItemRef{ - Channel: channelID, - Timestamp: messageTS, - }) - c.pendingAcks.Store(chatID, slackMessageRef{ ChannelID: channelID, Timestamp: messageTS, @@ -350,18 +426,18 @@ func (c *SlackChannel) handleAppMention(ev *slackevents.AppMentionEvent) { mentionPeerID = senderID } + mentionPeer := bus.Peer{Kind: mentionPeerKind, ID: mentionPeerID} + metadata := map[string]string{ "message_ts": messageTS, "channel_id": channelID, "thread_ts": threadTS, "platform": "slack", "is_mention": "true", - "peer_kind": mentionPeerKind, - "peer_id": mentionPeerID, "team_id": c.teamID, } - c.HandleMessage(senderID, chatID, content, nil, metadata) + c.HandleMessage(c.ctx, mentionPeer, messageTS, senderID, chatID, content, nil, metadata, mentionSender) } func (c *SlackChannel) handleSlashCommand(event socketmode.Event) { @@ -374,7 +450,12 @@ func (c *SlackChannel) handleSlashCommand(event socketmode.Event) { 
c.socketClient.Ack(*event.Request) } - if !c.IsAllowed(cmd.UserID) { + cmdSender := bus.SenderInfo{ + Platform: "slack", + PlatformID: cmd.UserID, + CanonicalID: identity.BuildCanonicalID("slack", cmd.UserID), + } + if !c.IsAllowedSender(cmdSender) { logger.DebugCF("slack", "Slash command rejected by allowlist", map[string]any{ "user_id": cmd.UserID, }) @@ -395,8 +476,6 @@ func (c *SlackChannel) handleSlashCommand(event socketmode.Event) { "platform": "slack", "is_command": "true", "trigger_id": cmd.TriggerID, - "peer_kind": "channel", - "peer_id": channelID, "team_id": c.teamID, } @@ -406,7 +485,17 @@ func (c *SlackChannel) handleSlashCommand(event socketmode.Event) { "text": utils.Truncate(content, 50), }) - c.HandleMessage(senderID, chatID, content, nil, metadata) + c.HandleMessage( + c.ctx, + bus.Peer{Kind: "channel", ID: channelID}, + "", + senderID, + chatID, + content, + nil, + metadata, + cmdSender, + ) } func (c *SlackChannel) downloadSlackFile(file slack.File) string { diff --git a/pkg/channels/slack_test.go b/pkg/channels/slack/slack_test.go similarity index 99% rename from pkg/channels/slack_test.go rename to pkg/channels/slack/slack_test.go index 3707c2703..30e0d2d73 100644 --- a/pkg/channels/slack_test.go +++ b/pkg/channels/slack/slack_test.go @@ -1,4 +1,4 @@ -package channels +package slack import ( "testing" diff --git a/pkg/channels/split.go b/pkg/channels/split.go new file mode 100644 index 000000000..bb26c6d8f --- /dev/null +++ b/pkg/channels/split.go @@ -0,0 +1,208 @@ +package channels + +import ( + "strings" +) + +// SplitMessage splits long messages into chunks, preserving code block integrity. +// The maxLen parameter is measured in runes (Unicode characters), not bytes. +// The function reserves a buffer (10% of maxLen, min 50) to leave room for closing code blocks, +// but may extend to maxLen when needed. 
+// Call SplitMessage with the full text content and the maximum allowed length of a single message; +// it returns a slice of message chunks that each respect maxLen and avoid splitting fenced code blocks. +func SplitMessage(content string, maxLen int) []string { + if maxLen <= 0 { + if content == "" { + return nil + } + return []string{content} + } + + runes := []rune(content) + totalLen := len(runes) + var messages []string + + // Dynamic buffer: 10% of maxLen, but at least 50 chars if possible + codeBlockBuffer := max(maxLen/10, 50) + if codeBlockBuffer > maxLen/2 { + codeBlockBuffer = maxLen / 2 + } + + start := 0 + for start < totalLen { + remaining := totalLen - start + if remaining <= maxLen { + messages = append(messages, string(runes[start:totalLen])) + break + } + + // Effective split point: maxLen minus buffer, to leave room for code blocks + effectiveLimit := max(maxLen-codeBlockBuffer, maxLen/2) + + end := start + effectiveLimit + + // Find natural split point within the effective limit + msgEnd := findLastNewlineInRange(runes, start, end, 200) + if msgEnd <= start { + msgEnd = findLastSpaceInRange(runes, start, end, 100) + } + if msgEnd <= start { + msgEnd = end + } + + // Check if this would end with an incomplete code block + unclosedIdx := findLastUnclosedCodeBlockInRange(runes, start, msgEnd) + + if unclosedIdx >= 0 { + // Message would end with incomplete code block + // Try to extend up to maxLen to include the closing ``` + if totalLen > msgEnd { + closingIdx := findNextClosingCodeBlockInRange(runes, msgEnd, totalLen) + if closingIdx > 0 && closingIdx-start <= maxLen { + // Extend to include the closing ``` + msgEnd = closingIdx + } else { + // Code block is too long to fit in one chunk or missing closing fence. + // Try to split inside by injecting closing and reopening fences. 
+ headerEnd := findNewlineFrom(runes, unclosedIdx) + var header string + if headerEnd == -1 { + header = strings.TrimSpace(string(runes[unclosedIdx : unclosedIdx+3])) + } else { + header = strings.TrimSpace(string(runes[unclosedIdx:headerEnd])) + } + headerEndIdx := unclosedIdx + len([]rune(header)) + if headerEnd != -1 { + headerEndIdx = headerEnd + } + + // If we have a reasonable amount of content after the header, split inside + if msgEnd > headerEndIdx+20 { + // Find a better split point closer to maxLen + innerLimit := min( + // Leave room for "\n```" + start+maxLen-5, totalLen) + betterEnd := findLastNewlineInRange(runes, start, innerLimit, 200) + if betterEnd > headerEndIdx { + msgEnd = betterEnd + } else { + msgEnd = innerLimit + } + chunk := strings.TrimRight(string(runes[start:msgEnd]), " \t\n\r") + "\n```" + messages = append(messages, chunk) + remaining := strings.TrimSpace(header + "\n" + string(runes[msgEnd:totalLen])) + // Replace the tail of runes with the reconstructed remaining + runes = []rune(remaining) + totalLen = len(runes) + start = 0 + continue + } + + // Otherwise, try to split before the code block starts + newEnd := findLastNewlineInRange(runes, start, unclosedIdx, 200) + if newEnd <= start { + newEnd = findLastSpaceInRange(runes, start, unclosedIdx, 100) + } + if newEnd > start { + msgEnd = newEnd + } else { + // If we can't split before, we MUST split inside (last resort) + if unclosedIdx-start > 20 { + msgEnd = unclosedIdx + } else { + splitAt := min(start+maxLen-5, totalLen) + chunk := strings.TrimRight(string(runes[start:splitAt]), " \t\n\r") + "\n```" + messages = append(messages, chunk) + remaining := strings.TrimSpace(header + "\n" + string(runes[splitAt:totalLen])) + runes = []rune(remaining) + totalLen = len(runes) + start = 0 + continue + } + } + } + } + } + + if msgEnd <= start { + msgEnd = start + effectiveLimit + } + + messages = append(messages, string(runes[start:msgEnd])) + // Advance start, skipping leading whitespace 
of next chunk + start = msgEnd + for start < totalLen && (runes[start] == ' ' || runes[start] == '\t' || runes[start] == '\n' || runes[start] == '\r') { + start++ + } + } + + return messages +} + +// findLastUnclosedCodeBlockInRange finds the last opening ``` that doesn't have a closing ``` +// within runes[start:end]. Returns the absolute rune index or -1. +func findLastUnclosedCodeBlockInRange(runes []rune, start, end int) int { + inCodeBlock := false + lastOpenIdx := -1 + + for i := start; i < end; i++ { + if i+2 < end && runes[i] == '`' && runes[i+1] == '`' && runes[i+2] == '`' { + if !inCodeBlock { + lastOpenIdx = i + } + inCodeBlock = !inCodeBlock + i += 2 + } + } + + if inCodeBlock { + return lastOpenIdx + } + return -1 +} + +// findNextClosingCodeBlockInRange finds the next closing ``` starting from startIdx +// within runes[startIdx:end]. Returns the absolute index after the closing ``` or -1. +func findNextClosingCodeBlockInRange(runes []rune, startIdx, end int) int { + for i := startIdx; i < end; i++ { + if i+2 < end && runes[i] == '`' && runes[i+1] == '`' && runes[i+2] == '`' { + return i + 3 + } + } + return -1 +} + +// findNewlineFrom finds the first newline character starting from the given index. +// Returns the absolute index or -1 if not found. +func findNewlineFrom(runes []rune, from int) int { + for i := from; i < len(runes); i++ { + if runes[i] == '\n' { + return i + } + } + return -1 +} + +// findLastNewlineInRange finds the last newline within the last searchWindow runes +// of the range runes[start:end]. Returns the absolute index or start-1 (indicating not found). +func findLastNewlineInRange(runes []rune, start, end, searchWindow int) int { + searchStart := max(end-searchWindow, start) + for i := end - 1; i >= searchStart; i-- { + if runes[i] == '\n' { + return i + } + } + return start - 1 +} + +// findLastSpaceInRange finds the last space/tab within the last searchWindow runes +// of the range runes[start:end]. 
Returns the absolute index or start-1 (indicating not found). +func findLastSpaceInRange(runes []rune, start, end, searchWindow int) int { + searchStart := max(end-searchWindow, start) + for i := end - 1; i >= searchStart; i-- { + if runes[i] == ' ' || runes[i] == '\t' { + return i + } + } + return start - 1 +} diff --git a/pkg/channels/split_test.go b/pkg/channels/split_test.go new file mode 100644 index 000000000..a922f9558 --- /dev/null +++ b/pkg/channels/split_test.go @@ -0,0 +1,362 @@ +package channels + +import ( + "strings" + "testing" +) + +func TestSplitMessage(t *testing.T) { + longText := strings.Repeat("a", 2500) + longCode := "```go\n" + strings.Repeat("fmt.Println(\"hello\")\n", 100) + "```" // ~2100 chars + + tests := []struct { + name string + content string + maxLen int + expectChunks int // Check number of chunks + checkContent func(t *testing.T, chunks []string) // Custom validation + }{ + { + name: "Empty message", + content: "", + maxLen: 2000, + expectChunks: 0, + }, + { + name: "Short message fits in one chunk", + content: "Hello world", + maxLen: 2000, + expectChunks: 1, + }, + { + name: "Simple split regular text", + content: longText, + maxLen: 2000, + expectChunks: 2, + checkContent: func(t *testing.T, chunks []string) { + if len([]rune(chunks[0])) > 2000 { + t.Errorf("Chunk 0 too large: %d runes", len([]rune(chunks[0]))) + } + if len([]rune(chunks[0]))+len([]rune(chunks[1])) != len([]rune(longText)) { + t.Errorf( + "Total rune length mismatch. Got %d, want %d", + len([]rune(chunks[0]))+len([]rune(chunks[1])), + len([]rune(longText)), + ) + } + }, + }, + { + name: "Split at newline", + // 1750 chars then newline, then more chars. + // Dynamic buffer: 2000 / 10 = 200. + // Effective limit: 2000 - 200 = 1800. + // Split should happen at newline because it's at 1750 (< 1800). + // Total length must > 2000 to trigger split. 1750 + 1 + 300 = 2051. 
+ content: strings.Repeat("a", 1750) + "\n" + strings.Repeat("b", 300), + maxLen: 2000, + expectChunks: 2, + checkContent: func(t *testing.T, chunks []string) { + if len([]rune(chunks[0])) != 1750 { + t.Errorf("Expected chunk 0 to be 1750 runes (split at newline), got %d", len([]rune(chunks[0]))) + } + if chunks[1] != strings.Repeat("b", 300) { + t.Errorf("Chunk 1 content mismatch. Len: %d", len([]rune(chunks[1]))) + } + }, + }, + { + name: "Long code block split", + content: "Prefix\n" + longCode, + maxLen: 2000, + expectChunks: 2, + checkContent: func(t *testing.T, chunks []string) { + // Check that first chunk ends with closing fence + if !strings.HasSuffix(chunks[0], "\n```") { + t.Error("First chunk should end with injected closing fence") + } + // Check that second chunk starts with execution header + if !strings.HasPrefix(chunks[1], "```go") { + t.Error("Second chunk should start with injected code block header") + } + }, + }, + { + name: "Preserve Unicode characters (rune-aware)", + content: strings.Repeat("\u4e16", 2500), // 2500 runes, 7500 bytes + maxLen: 2000, + expectChunks: 2, + checkContent: func(t *testing.T, chunks []string) { + // Verify chunks contain valid unicode and don't split mid-rune + for i, chunk := range chunks { + runeCount := len([]rune(chunk)) + if runeCount > 2000 { + t.Errorf("Chunk %d has %d runes, exceeds maxLen 2000", i, runeCount) + } + if !strings.Contains(chunk, "\u4e16") { + t.Errorf("Chunk %d should contain unicode characters", i) + } + } + // Verify total rune count is preserved + totalRunes := 0 + for _, chunk := range chunks { + totalRunes += len([]rune(chunk)) + } + if totalRunes != 2500 { + t.Errorf("Total rune count mismatch. 
Got %d, want 2500", totalRunes) + } + }, + }, + { + name: "Zero maxLen returns single chunk", + content: "Hello world", + maxLen: 0, + expectChunks: 1, + checkContent: func(t *testing.T, chunks []string) { + if chunks[0] != "Hello world" { + t.Errorf("Expected original content, got %q", chunks[0]) + } + }, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + got := SplitMessage(tc.content, tc.maxLen) + + if tc.expectChunks == 0 { + if len(got) != 0 { + t.Errorf("Expected 0 chunks, got %d", len(got)) + } + return + } + + if len(got) != tc.expectChunks { + t.Errorf("Expected %d chunks, got %d", tc.expectChunks, len(got)) + // Log sizes for debugging + for i, c := range got { + t.Logf("Chunk %d length: %d", i, len(c)) + } + return // Stop further checks if count assumes specific split + } + + if tc.checkContent != nil { + tc.checkContent(t, got) + } + }) + } +} + +// --- Helper function tests for index-based rune operations --- + +func TestFindLastNewlineInRange(t *testing.T) { + runes := []rune("aaa\nbbb\nccc") + // Indices: 0123 4567 89 10 + + tests := []struct { + name string + start, end int + searchWindow int + want int + }{ + {"finds last newline in full range", 0, 11, 200, 7}, + {"finds newline within search window", 0, 11, 4, 7}, + {"narrow window misses newline outside window", 4, 11, 3, 3}, // returns start-1 (not found) + {"no newline in range", 0, 3, 200, -1}, // start-1 = -1 + {"range limited to first segment", 0, 4, 200, 3}, + {"search window of 1 at newline", 0, 8, 1, 7}, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + got := findLastNewlineInRange(runes, tc.start, tc.end, tc.searchWindow) + if got != tc.want { + t.Errorf("findLastNewlineInRange(runes, %d, %d, %d) = %d, want %d", + tc.start, tc.end, tc.searchWindow, got, tc.want) + } + }) + } +} + +func TestFindLastSpaceInRange(t *testing.T) { + runes := []rune("abc def\tghi") + // Indices: 0123 4567 89 10 + + tests := []struct { + name string + 
start, end int + searchWindow int + want int + }{ + {"finds tab as last space/tab", 0, 11, 200, 7}, + {"finds space when tab out of window", 0, 7, 200, 3}, + {"no space in range", 0, 3, 200, -1}, + {"narrow window finds tab", 5, 11, 4, 7}, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + got := findLastSpaceInRange(runes, tc.start, tc.end, tc.searchWindow) + if got != tc.want { + t.Errorf("findLastSpaceInRange(runes, %d, %d, %d) = %d, want %d", + tc.start, tc.end, tc.searchWindow, got, tc.want) + } + }) + } +} + +func TestFindNewlineFrom(t *testing.T) { + runes := []rune("hello\nworld\n") + + tests := []struct { + name string + from int + want int + }{ + {"from start", 0, 5}, + {"from after first newline", 6, 11}, + {"from past all newlines", 12, -1}, + {"from newline itself", 5, 5}, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + got := findNewlineFrom(runes, tc.from) + if got != tc.want { + t.Errorf("findNewlineFrom(runes, %d) = %d, want %d", tc.from, got, tc.want) + } + }) + } +} + +func TestFindLastUnclosedCodeBlockInRange(t *testing.T) { + tests := []struct { + name string + content string + start, end int + want int + }{ + { + name: "no code blocks", + content: "hello world", + start: 0, end: 11, + want: -1, + }, + { + name: "complete code block", + content: "```go\ncode\n```", + start: 0, end: 14, + want: -1, + }, + { + name: "unclosed code block", + content: "text\n```go\ncode here", + start: 0, end: 20, + want: 5, + }, + { + name: "closed then unclosed", + content: "```a\n```\n```b\ncode", + start: 0, end: 17, + want: 9, + }, + { + name: "search within subrange", + content: "```a\n```\n```b\ncode", + start: 9, end: 17, + want: 9, + }, + { + name: "subrange with no code blocks", + content: "```a\n```\nhello", + start: 9, end: 14, + want: -1, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + runes := []rune(tc.content) + got := findLastUnclosedCodeBlockInRange(runes, tc.start, 
tc.end) + if got != tc.want { + t.Errorf("findLastUnclosedCodeBlockInRange(%q, %d, %d) = %d, want %d", + tc.content, tc.start, tc.end, got, tc.want) + } + }) + } +} + +func TestFindNextClosingCodeBlockInRange(t *testing.T) { + tests := []struct { + name string + content string + startIdx int + end int + want int + }{ + { + name: "finds closing fence", + content: "code\n```\nmore", + startIdx: 0, end: 13, + want: 8, // position after ``` + }, + { + name: "no closing fence", + content: "just code here", + startIdx: 0, end: 14, + want: -1, + }, + { + name: "fence at start of search", + content: "```end", + startIdx: 0, end: 6, + want: 3, + }, + { + name: "fence outside range", + content: "code\n```", + startIdx: 0, end: 4, + want: -1, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + runes := []rune(tc.content) + got := findNextClosingCodeBlockInRange(runes, tc.startIdx, tc.end) + if got != tc.want { + t.Errorf("findNextClosingCodeBlockInRange(%q, %d, %d) = %d, want %d", + tc.content, tc.startIdx, tc.end, got, tc.want) + } + }) + } +} + +func TestSplitMessage_CodeBlockIntegrity(t *testing.T) { + // Focused test for the core requirement: splitting inside a code block preserves syntax highlighting + + // 60 chars total approximately + content := "```go\npackage main\n\nfunc main() {\n\tprintln(\"Hello\")\n}\n```" + maxLen := 40 + + chunks := SplitMessage(content, maxLen) + + if len(chunks) != 2 { + t.Fatalf("Expected 2 chunks, got %d: %q", len(chunks), chunks) + } + + // First chunk must end with "\n```" + if !strings.HasSuffix(chunks[0], "\n```") { + t.Errorf("First chunk should end with closing fence. Got: %q", chunks[0]) + } + + // Second chunk must start with the header "```go" + if !strings.HasPrefix(chunks[1], "```go") { + t.Errorf("Second chunk should start with code block header. 
Got: %q", chunks[1]) + } + + // First chunk should contain meaningful content + if len([]rune(chunks[0])) > 40 { + t.Errorf("First chunk exceeded maxLen: length %d runes", len([]rune(chunks[0]))) + } +} diff --git a/pkg/channels/telegram/init.go b/pkg/channels/telegram/init.go new file mode 100644 index 000000000..ac87bb805 --- /dev/null +++ b/pkg/channels/telegram/init.go @@ -0,0 +1,13 @@ +package telegram + +import ( + "github.com/sipeed/picoclaw/pkg/bus" + "github.com/sipeed/picoclaw/pkg/channels" + "github.com/sipeed/picoclaw/pkg/config" +) + +func init() { + channels.RegisterFactory("telegram", func(cfg *config.Config, b *bus.MessageBus) (channels.Channel, error) { + return NewTelegramChannel(cfg, b) + }) +} diff --git a/pkg/channels/telegram.go b/pkg/channels/telegram/telegram.go similarity index 53% rename from pkg/channels/telegram.go rename to pkg/channels/telegram/telegram.go index 6592d9bc0..a11cf53b8 100644 --- a/pkg/channels/telegram.go +++ b/pkg/channels/telegram/telegram.go @@ -1,4 +1,4 @@ -package channels +package telegram import ( "context" @@ -7,8 +7,8 @@ import ( "net/url" "os" "regexp" + "strconv" "strings" - "sync" "time" "github.com/mymmrac/telego" @@ -17,10 +17,12 @@ import ( tu "github.com/mymmrac/telego/telegoutil" "github.com/sipeed/picoclaw/pkg/bus" + "github.com/sipeed/picoclaw/pkg/channels" "github.com/sipeed/picoclaw/pkg/config" + "github.com/sipeed/picoclaw/pkg/identity" "github.com/sipeed/picoclaw/pkg/logger" + "github.com/sipeed/picoclaw/pkg/media" "github.com/sipeed/picoclaw/pkg/utils" - "github.com/sipeed/picoclaw/pkg/voice" ) var ( @@ -37,24 +39,14 @@ var ( ) type TelegramChannel struct { - *BaseChannel - bot *telego.Bot - commands TelegramCommander - config *config.Config - chatIDs map[string]int64 - transcriber *voice.GroqTranscriber - placeholders sync.Map // chatID -> messageID - stopThinking sync.Map // chatID -> thinkingCancel -} - -type thinkingCancel struct { - fn context.CancelFunc -} - -func (c *thinkingCancel) 
Cancel() { - if c != nil && c.fn != nil { - c.fn() - } + *channels.BaseChannel + bot *telego.Bot + bh *telegohandler.BotHandler + commands TelegramCommander + config *config.Config + chatIDs map[string]int64 + ctx context.Context + cancel context.CancelFunc } func NewTelegramChannel(cfg *config.Config, bus *bus.MessageBus) (*TelegramChannel, error) { @@ -85,38 +77,44 @@ func NewTelegramChannel(cfg *config.Config, bus *bus.MessageBus) (*TelegramChann return nil, fmt.Errorf("failed to create telegram bot: %w", err) } - base := NewBaseChannel("telegram", telegramCfg, bus, telegramCfg.AllowFrom) + base := channels.NewBaseChannel( + "telegram", + telegramCfg, + bus, + telegramCfg.AllowFrom, + channels.WithMaxMessageLength(4096), + channels.WithGroupTrigger(telegramCfg.GroupTrigger), + channels.WithReasoningChannelID(telegramCfg.ReasoningChannelID), + ) return &TelegramChannel{ - BaseChannel: base, - commands: NewTelegramCommands(bot, cfg), - bot: bot, - config: cfg, - chatIDs: make(map[string]int64), - transcriber: nil, - placeholders: sync.Map{}, - stopThinking: sync.Map{}, + BaseChannel: base, + commands: NewTelegramCommands(bot, cfg), + bot: bot, + config: cfg, + chatIDs: make(map[string]int64), }, nil } -func (c *TelegramChannel) SetTranscriber(transcriber *voice.GroqTranscriber) { - c.transcriber = transcriber -} - func (c *TelegramChannel) Start(ctx context.Context) error { logger.InfoC("telegram", "Starting Telegram bot (polling mode)...") - updates, err := c.bot.UpdatesViaLongPolling(ctx, &telego.GetUpdatesParams{ + c.ctx, c.cancel = context.WithCancel(ctx) + + updates, err := c.bot.UpdatesViaLongPolling(c.ctx, &telego.GetUpdatesParams{ Timeout: 30, }) if err != nil { + c.cancel() return fmt.Errorf("failed to start long polling: %w", err) } bh, err := telegohandler.NewBotHandler(c.bot, updates) if err != nil { + c.cancel() return fmt.Errorf("failed to create bot handler: %w", err) } + c.bh = bh bh.HandleMessage(func(ctx *th.Context, message telego.Message) error 
{ c.commands.Help(ctx, message) @@ -138,59 +136,46 @@ func (c *TelegramChannel) Start(ctx context.Context) error { return c.handleMessage(ctx, &message) }, th.AnyMessage()) - c.setRunning(true) + c.SetRunning(true) logger.InfoCF("telegram", "Telegram bot connected", map[string]any{ "username": c.bot.Username(), }) go bh.Start() - go func() { - <-ctx.Done() - bh.Stop() - }() - return nil } func (c *TelegramChannel) Stop(ctx context.Context) error { logger.InfoC("telegram", "Stopping Telegram bot...") - c.setRunning(false) + c.SetRunning(false) + + // Stop the bot handler + if c.bh != nil { + c.bh.Stop() + } + + // Cancel our context (stops long polling) + if c.cancel != nil { + c.cancel() + } + return nil } func (c *TelegramChannel) Send(ctx context.Context, msg bus.OutboundMessage) error { if !c.IsRunning() { - return fmt.Errorf("telegram bot not running") + return channels.ErrNotRunning } chatID, err := parseChatID(msg.ChatID) if err != nil { - return fmt.Errorf("invalid chat ID: %w", err) - } - - // Stop thinking animation - if stop, ok := c.stopThinking.Load(msg.ChatID); ok { - if cf, ok := stop.(*thinkingCancel); ok && cf != nil { - cf.Cancel() - } - c.stopThinking.Delete(msg.ChatID) + return fmt.Errorf("invalid chat ID %s: %w", msg.ChatID, channels.ErrSendFailed) } htmlContent := markdownToTelegramHTML(msg.Content) - // Try to edit placeholder - if pID, ok := c.placeholders.Load(msg.ChatID); ok { - c.placeholders.Delete(msg.ChatID) - editMsg := tu.EditMessageText(tu.ID(chatID), pID.(int), htmlContent) - editMsg.ParseMode = telego.ModeHTML - - if _, err = c.bot.EditMessageText(ctx, editMsg); err == nil { - return nil - } - // Fallback to new message if edit fails - } - + // Typing/placeholder handled by Manager.preSend — just send the message tgMsg := tu.Message(tu.ID(chatID), htmlContent) tgMsg.ParseMode = telego.ModeHTML @@ -199,9 +184,164 @@ func (c *TelegramChannel) Send(ctx context.Context, msg bus.OutboundMessage) err "error": err.Error(), }) 
tgMsg.ParseMode = "" - _, err = c.bot.SendMessage(ctx, tgMsg) + if _, err = c.bot.SendMessage(ctx, tgMsg); err != nil { + return fmt.Errorf("telegram send: %w", channels.ErrTemporary) + } + } + + return nil +} + +// StartTyping implements channels.TypingCapable. +// It sends ChatAction(typing) immediately and then repeats every 4 seconds +// (Telegram's typing indicator expires after ~5s) in a background goroutine. +// The returned stop function is idempotent and cancels the goroutine. +func (c *TelegramChannel) StartTyping(ctx context.Context, chatID string) (func(), error) { + cid, err := parseChatID(chatID) + if err != nil { + return func() {}, err + } + + // Send the first typing action immediately + _ = c.bot.SendChatAction(ctx, tu.ChatAction(tu.ID(cid), telego.ChatActionTyping)) + + typingCtx, cancel := context.WithCancel(ctx) + go func() { + ticker := time.NewTicker(4 * time.Second) + defer ticker.Stop() + for { + select { + case <-typingCtx.Done(): + return + case <-ticker.C: + _ = c.bot.SendChatAction(typingCtx, tu.ChatAction(tu.ID(cid), telego.ChatActionTyping)) + } + } + }() + + return cancel, nil +} + +// EditMessage implements channels.MessageEditor. +func (c *TelegramChannel) EditMessage(ctx context.Context, chatID string, messageID string, content string) error { + cid, err := parseChatID(chatID) + if err != nil { return err } + mid, err := strconv.Atoi(messageID) + if err != nil { + return err + } + htmlContent := markdownToTelegramHTML(content) + editMsg := tu.EditMessageText(tu.ID(cid), mid, htmlContent) + editMsg.ParseMode = telego.ModeHTML + _, err = c.bot.EditMessageText(ctx, editMsg) + return err +} + +// SendPlaceholder implements channels.PlaceholderCapable. +// It sends a placeholder message (e.g. "Thinking... 💭") that will later be +// edited to the actual response via EditMessage (channels.MessageEditor). 
+func (c *TelegramChannel) SendPlaceholder(ctx context.Context, chatID string) (string, error) { + phCfg := c.config.Channels.Telegram.Placeholder + if !phCfg.Enabled { + return "", nil + } + + text := phCfg.Text + if text == "" { + text = "Thinking... 💭" + } + + cid, err := parseChatID(chatID) + if err != nil { + return "", err + } + + pMsg, err := c.bot.SendMessage(ctx, tu.Message(tu.ID(cid), text)) + if err != nil { + return "", err + } + + return fmt.Sprintf("%d", pMsg.MessageID), nil +} + +// SendMedia implements the channels.MediaSender interface. +func (c *TelegramChannel) SendMedia(ctx context.Context, msg bus.OutboundMediaMessage) error { + if !c.IsRunning() { + return channels.ErrNotRunning + } + + chatID, err := parseChatID(msg.ChatID) + if err != nil { + return fmt.Errorf("invalid chat ID %s: %w", msg.ChatID, channels.ErrSendFailed) + } + + store := c.GetMediaStore() + if store == nil { + return fmt.Errorf("no media store available: %w", channels.ErrSendFailed) + } + + for _, part := range msg.Parts { + localPath, err := store.Resolve(part.Ref) + if err != nil { + logger.ErrorCF("telegram", "Failed to resolve media ref", map[string]any{ + "ref": part.Ref, + "error": err.Error(), + }) + continue + } + + file, err := os.Open(localPath) + if err != nil { + logger.ErrorCF("telegram", "Failed to open media file", map[string]any{ + "path": localPath, + "error": err.Error(), + }) + continue + } + + switch part.Type { + case "image": + params := &telego.SendPhotoParams{ + ChatID: tu.ID(chatID), + Photo: telego.InputFile{File: file}, + Caption: part.Caption, + } + _, err = c.bot.SendPhoto(ctx, params) + case "audio": + params := &telego.SendAudioParams{ + ChatID: tu.ID(chatID), + Audio: telego.InputFile{File: file}, + Caption: part.Caption, + } + _, err = c.bot.SendAudio(ctx, params) + case "video": + params := &telego.SendVideoParams{ + ChatID: tu.ID(chatID), + Video: telego.InputFile{File: file}, + Caption: part.Caption, + } + _, err = c.bot.SendVideo(ctx, 
params) + default: // "file" or unknown types + params := &telego.SendDocumentParams{ + ChatID: tu.ID(chatID), + Document: telego.InputFile{File: file}, + Caption: part.Caption, + } + _, err = c.bot.SendDocument(ctx, params) + } + + file.Close() + + if err != nil { + logger.ErrorCF("telegram", "Failed to send media", map[string]any{ + "type": part.Type, + "error": err.Error(), + }) + return fmt.Errorf("telegram send media: %w", channels.ErrTemporary) + } + } return nil } @@ -216,37 +356,46 @@ func (c *TelegramChannel) handleMessage(ctx context.Context, message *telego.Mes return fmt.Errorf("message sender (user) is nil") } - senderID := fmt.Sprintf("%d", user.ID) - if user.Username != "" { - senderID = fmt.Sprintf("%d|%s", user.ID, user.Username) + platformID := fmt.Sprintf("%d", user.ID) + sender := bus.SenderInfo{ + Platform: "telegram", + PlatformID: platformID, + CanonicalID: identity.BuildCanonicalID("telegram", platformID), + Username: user.Username, + DisplayName: user.FirstName, } // check allowlist to avoid downloading attachments for rejected users - if !c.IsAllowed(senderID) { + if !c.IsAllowedSender(sender) { logger.DebugCF("telegram", "Message rejected by allowlist", map[string]any{ - "user_id": senderID, + "user_id": platformID, }) return nil } chatID := message.Chat.ID - c.chatIDs[senderID] = chatID + c.chatIDs[platformID] = chatID content := "" mediaPaths := []string{} - localFiles := []string{} // track local files that need cleanup - // ensure temp files are cleaned up when function returns - defer func() { - for _, file := range localFiles { - if err := os.Remove(file); err != nil { - logger.DebugCF("telegram", "Failed to cleanup temp file", map[string]any{ - "file": file, - "error": err.Error(), - }) + chatIDStr := fmt.Sprintf("%d", chatID) + messageIDStr := fmt.Sprintf("%d", message.MessageID) + scope := channels.BuildMediaScope("telegram", chatIDStr, messageIDStr) + + // Helper to register a local file with the media store + storeMedia := 
func(localPath, filename string) string { + if store := c.GetMediaStore(); store != nil { + ref, err := store.Store(localPath, media.MediaMeta{ + Filename: filename, + Source: "telegram", + }, scope) + if err == nil { + return ref } } - }() + return localPath // fallback: use raw path + } if message.Text != "" { content += message.Text @@ -263,8 +412,7 @@ func (c *TelegramChannel) handleMessage(ctx context.Context, message *telego.Mes photo := message.Photo[len(message.Photo)-1] photoPath := c.downloadPhoto(ctx, photo.FileID) if photoPath != "" { - localFiles = append(localFiles, photoPath) - mediaPaths = append(mediaPaths, photoPath) + mediaPaths = append(mediaPaths, storeMedia(photoPath, "photo.jpg")) if content != "" { content += "\n" } @@ -275,43 +423,19 @@ func (c *TelegramChannel) handleMessage(ctx context.Context, message *telego.Mes if message.Voice != nil { voicePath := c.downloadFile(ctx, message.Voice.FileID, ".ogg") if voicePath != "" { - localFiles = append(localFiles, voicePath) - mediaPaths = append(mediaPaths, voicePath) - - var transcribedText string - if c.transcriber != nil && c.transcriber.IsAvailable() { - transcriberCtx, cancel := context.WithTimeout(ctx, 30*time.Second) - defer cancel() - - result, err := c.transcriber.Transcribe(transcriberCtx, voicePath) - if err != nil { - logger.ErrorCF("telegram", "Voice transcription failed", map[string]any{ - "error": err.Error(), - "path": voicePath, - }) - transcribedText = "[voice (transcription failed)]" - } else { - transcribedText = fmt.Sprintf("[voice transcription: %s]", result.Text) - logger.InfoCF("telegram", "Voice transcribed successfully", map[string]any{ - "text": result.Text, - }) - } - } else { - transcribedText = "[voice]" - } + mediaPaths = append(mediaPaths, storeMedia(voicePath, "voice.ogg")) if content != "" { content += "\n" } - content += transcribedText + content += "[voice]" } } if message.Audio != nil { audioPath := c.downloadFile(ctx, message.Audio.FileID, ".mp3") if 
audioPath != "" { - localFiles = append(localFiles, audioPath) - mediaPaths = append(mediaPaths, audioPath) + mediaPaths = append(mediaPaths, storeMedia(audioPath, "audio.mp3")) if content != "" { content += "\n" } @@ -322,8 +446,7 @@ func (c *TelegramChannel) handleMessage(ctx context.Context, message *telego.Mes if message.Document != nil { docPath := c.downloadFile(ctx, message.Document.FileID, "") if docPath != "" { - localFiles = append(localFiles, docPath) - mediaPaths = append(mediaPaths, docPath) + mediaPaths = append(mediaPaths, storeMedia(docPath, "document")) if content != "" { content += "\n" } @@ -335,37 +458,26 @@ func (c *TelegramChannel) handleMessage(ctx context.Context, message *telego.Mes content = "[empty message]" } + // In group chats, apply unified group trigger filtering + if message.Chat.Type != "private" { + isMentioned := c.isBotMentioned(message) + if isMentioned { + content = c.stripBotMention(content) + } + respond, cleaned := c.ShouldRespondInGroup(isMentioned, content) + if !respond { + return nil + } + content = cleaned + } + logger.DebugCF("telegram", "Received message", map[string]any{ - "sender_id": senderID, + "sender_id": sender.CanonicalID, "chat_id": fmt.Sprintf("%d", chatID), "preview": utils.Truncate(content, 50), }) - // Thinking indicator - err := c.bot.SendChatAction(ctx, tu.ChatAction(tu.ID(chatID), telego.ChatActionTyping)) - if err != nil { - logger.ErrorCF("telegram", "Failed to send chat action", map[string]any{ - "error": err.Error(), - }) - } - - // Stop any previous thinking animation - chatIDStr := fmt.Sprintf("%d", chatID) - if prevStop, ok := c.stopThinking.Load(chatIDStr); ok { - if cf, ok := prevStop.(*thinkingCancel); ok && cf != nil { - cf.Cancel() - } - } - - // Create cancel function for thinking state - _, thinkCancel := context.WithTimeout(ctx, 5*time.Minute) - c.stopThinking.Store(chatIDStr, &thinkingCancel{fn: thinkCancel}) - - pMsg, err := c.bot.SendMessage(ctx, tu.Message(tu.ID(chatID), 
"Thinking... 💭")) - if err == nil { - pID := pMsg.MessageID - c.placeholders.Store(chatIDStr, pID) - } + // Placeholder is now auto-triggered by BaseChannel.HandleMessage via PlaceholderCapable peerKind := "direct" peerID := fmt.Sprintf("%d", user.ID) @@ -374,17 +486,26 @@ func (c *TelegramChannel) handleMessage(ctx context.Context, message *telego.Mes peerID = fmt.Sprintf("%d", chatID) } + peer := bus.Peer{Kind: peerKind, ID: peerID} + messageID := fmt.Sprintf("%d", message.MessageID) + metadata := map[string]string{ - "message_id": fmt.Sprintf("%d", message.MessageID), "user_id": fmt.Sprintf("%d", user.ID), "username": user.Username, "first_name": user.FirstName, "is_group": fmt.Sprintf("%t", message.Chat.Type != "private"), - "peer_kind": peerKind, - "peer_id": peerID, } - c.HandleMessage(fmt.Sprintf("%d", user.ID), fmt.Sprintf("%d", chatID), content, mediaPaths, metadata) + c.HandleMessage(c.ctx, + peer, + messageID, + platformID, + fmt.Sprintf("%d", chatID), + content, + mediaPaths, + metadata, + sender, + ) return nil } @@ -537,3 +658,52 @@ func escapeHTML(text string) string { text = strings.ReplaceAll(text, ">", ">") return text } + +// isBotMentioned checks if the bot is mentioned in the message via entities. 
+func (c *TelegramChannel) isBotMentioned(message *telego.Message) bool { + botUsername := c.bot.Username() + if botUsername == "" { + return false + } + + entities := message.Entities + if entities == nil { + entities = message.CaptionEntities + } + + for _, entity := range entities { + if entity.Type == "mention" { + // Extract the mention text from the message + text := message.Text + if text == "" { + text = message.Caption + } + runes := []rune(text) + end := entity.Offset + entity.Length + if end <= len(runes) { + mention := string(runes[entity.Offset:end]) + if strings.EqualFold(mention, "@"+botUsername) { + return true + } + } + } + if entity.Type == "text_mention" && entity.User != nil { + if entity.User.Username == botUsername { + return true + } + } + } + return false +} + +// stripBotMention removes the @bot mention from the content. +func (c *TelegramChannel) stripBotMention(content string) string { + botUsername := c.bot.Username() + if botUsername == "" { + return content + } + // Case-insensitive replacement + re := regexp.MustCompile(`(?i)@` + regexp.QuoteMeta(botUsername)) + content = re.ReplaceAllString(content, "") + return strings.TrimSpace(content) +} diff --git a/pkg/channels/telegram_commands.go b/pkg/channels/telegram/telegram_commands.go similarity index 98% rename from pkg/channels/telegram_commands.go rename to pkg/channels/telegram/telegram_commands.go index f28434f46..496fc5e4f 100644 --- a/pkg/channels/telegram_commands.go +++ b/pkg/channels/telegram/telegram_commands.go @@ -1,4 +1,4 @@ -package channels +package telegram import ( "context" @@ -119,7 +119,7 @@ func (c *cmd) List(ctx context.Context, message telego.Message) error { if provider == "" { provider = "configured default" } - response = fmt.Sprintf("Configured Model: %s\nProvider: %s\n\nTo change models, update config.yaml", + response = fmt.Sprintf("Configured Model: %s\nProvider: %s\n\nTo change models, update config.json", c.config.Agents.Defaults.GetModelName(), 
provider) case "channels": diff --git a/pkg/channels/webhook.go b/pkg/channels/webhook.go new file mode 100644 index 000000000..3cf27baf6 --- /dev/null +++ b/pkg/channels/webhook.go @@ -0,0 +1,20 @@ +package channels + +import "net/http" + +// WebhookHandler is an optional interface for channels that receive messages +// via HTTP webhooks. Manager discovers channels implementing this interface +// and registers them on the shared HTTP server. +type WebhookHandler interface { + // WebhookPath returns the path to mount this handler on the shared server. + // Examples: "/webhook/line", "/webhook/wecom" + WebhookPath() string + http.Handler // ServeHTTP(w http.ResponseWriter, r *http.Request) +} + +// HealthChecker is an optional interface for channels that expose +// a health check endpoint on the shared HTTP server. +type HealthChecker interface { + HealthPath() string + HealthHandler(w http.ResponseWriter, r *http.Request) +} diff --git a/pkg/channels/wecom_app.go b/pkg/channels/wecom/app.go similarity index 67% rename from pkg/channels/wecom_app.go rename to pkg/channels/wecom/app.go index 302603445..42a74e8c9 100644 --- a/pkg/channels/wecom_app.go +++ b/pkg/channels/wecom/app.go @@ -1,8 +1,4 @@ -// PicoClaw - Ultra-lightweight personal AI agent -// WeCom App (企业微信自建应用) channel implementation -// Supports receiving messages via webhook callback and sending messages proactively - -package channels +package wecom import ( "bytes" @@ -11,14 +7,19 @@ import ( "encoding/xml" "fmt" "io" + "mime/multipart" "net/http" "net/url" + "os" + "path/filepath" "strings" "sync" "time" "github.com/sipeed/picoclaw/pkg/bus" + "github.com/sipeed/picoclaw/pkg/channels" "github.com/sipeed/picoclaw/pkg/config" + "github.com/sipeed/picoclaw/pkg/identity" "github.com/sipeed/picoclaw/pkg/logger" "github.com/sipeed/picoclaw/pkg/utils" ) @@ -29,9 +30,8 @@ const ( // WeComAppChannel implements the Channel interface for WeCom App (企业微信自建应用) type WeComAppChannel struct { - *BaseChannel + 
*channels.BaseChannel config config.WeComAppConfig - server *http.Server accessToken string tokenExpiry time.Time tokenMu sync.RWMutex @@ -123,7 +123,11 @@ func NewWeComAppChannel(cfg config.WeComAppConfig, messageBus *bus.MessageBus) ( return nil, fmt.Errorf("wecom_app corp_id, corp_secret and agent_id are required") } - base := NewBaseChannel("wecom_app", cfg, messageBus, cfg.AllowFrom) + base := channels.NewBaseChannel("wecom_app", cfg, messageBus, cfg.AllowFrom, + channels.WithMaxMessageLength(2048), + channels.WithGroupTrigger(cfg.GroupTrigger), + channels.WithReasoningChannelID(cfg.ReasoningChannelID), + ) return &WeComAppChannel{ BaseChannel: base, @@ -137,7 +141,7 @@ func (c *WeComAppChannel) Name() string { return "wecom_app" } -// Start initializes the WeCom App channel with HTTP webhook server +// Start initializes the WeCom App channel func (c *WeComAppChannel) Start(ctx context.Context) error { logger.InfoC("wecom_app", "Starting WeCom App channel...") @@ -153,37 +157,8 @@ func (c *WeComAppChannel) Start(ctx context.Context) error { // Start token refresh goroutine go c.tokenRefreshLoop() - // Setup HTTP server for webhook - mux := http.NewServeMux() - webhookPath := c.config.WebhookPath - if webhookPath == "" { - webhookPath = "/webhook/wecom-app" - } - mux.HandleFunc(webhookPath, c.handleWebhook) - - // Health check endpoint - mux.HandleFunc("/health/wecom-app", c.handleHealth) - - addr := fmt.Sprintf("%s:%d", c.config.WebhookHost, c.config.WebhookPort) - c.server = &http.Server{ - Addr: addr, - Handler: mux, - } - - c.setRunning(true) - logger.InfoCF("wecom_app", "WeCom App channel started", map[string]any{ - "address": addr, - "path": webhookPath, - }) - - // Start server in goroutine - go func() { - if err := c.server.ListenAndServe(); err != nil && err != http.ErrServerClosed { - logger.ErrorCF("wecom_app", "HTTP server error", map[string]any{ - "error": err.Error(), - }) - } - }() + c.SetRunning(true) + logger.InfoC("wecom_app", "WeCom App 
channel started") return nil } @@ -196,13 +171,7 @@ func (c *WeComAppChannel) Stop(ctx context.Context) error { c.cancel() } - if c.server != nil { - shutdownCtx, cancel := context.WithTimeout(ctx, 5*time.Second) - defer cancel() - c.server.Shutdown(shutdownCtx) - } - - c.setRunning(false) + c.SetRunning(false) logger.InfoC("wecom_app", "WeCom App channel stopped") return nil } @@ -210,7 +179,7 @@ func (c *WeComAppChannel) Stop(ctx context.Context) error { // Send sends a message to WeCom user proactively using access token func (c *WeComAppChannel) Send(ctx context.Context, msg bus.OutboundMessage) error { if !c.IsRunning() { - return fmt.Errorf("wecom_app channel not running") + return channels.ErrNotRunning } accessToken := c.getAccessToken() @@ -226,6 +195,220 @@ func (c *WeComAppChannel) Send(ctx context.Context, msg bus.OutboundMessage) err return c.sendTextMessage(ctx, accessToken, msg.ChatID, msg.Content) } +// SendMedia implements the channels.MediaSender interface. +func (c *WeComAppChannel) SendMedia(ctx context.Context, msg bus.OutboundMediaMessage) error { + if !c.IsRunning() { + return channels.ErrNotRunning + } + + accessToken := c.getAccessToken() + if accessToken == "" { + return fmt.Errorf("no valid access token available: %w", channels.ErrTemporary) + } + + store := c.GetMediaStore() + if store == nil { + return fmt.Errorf("no media store available: %w", channels.ErrSendFailed) + } + + for _, part := range msg.Parts { + localPath, err := store.Resolve(part.Ref) + if err != nil { + logger.ErrorCF("wecom_app", "Failed to resolve media ref", map[string]any{ + "ref": part.Ref, + "error": err.Error(), + }) + continue + } + + // Map part type to WeCom media type + var mediaType string + switch part.Type { + case "image": + mediaType = "image" + case "audio": + mediaType = "voice" + case "video": + mediaType = "video" + default: + mediaType = "file" + } + + // Upload media to get media_id + mediaID, err := c.uploadMedia(ctx, accessToken, mediaType, 
localPath) + if err != nil { + logger.ErrorCF("wecom_app", "Failed to upload media", map[string]any{ + "type": mediaType, + "error": err.Error(), + }) + // Fallback: send caption as text + if part.Caption != "" { + _ = c.sendTextMessage(ctx, accessToken, msg.ChatID, part.Caption) + } + continue + } + + // Send media message using the media_id + if mediaType == "image" { + err = c.sendImageMessage(ctx, accessToken, msg.ChatID, mediaID) + } else { + // For non-image types, send as text fallback with caption + caption := part.Caption + if caption == "" { + caption = fmt.Sprintf("[%s: %s]", part.Type, part.Filename) + } + err = c.sendTextMessage(ctx, accessToken, msg.ChatID, caption) + } + + if err != nil { + return err + } + } + + return nil +} + +// uploadMedia uploads a local file to WeCom temporary media storage. +func (c *WeComAppChannel) uploadMedia(ctx context.Context, accessToken, mediaType, localPath string) (string, error) { + apiURL := fmt.Sprintf("%s/cgi-bin/media/upload?access_token=%s&type=%s", + wecomAPIBase, url.QueryEscape(accessToken), url.QueryEscape(mediaType)) + + file, err := os.Open(localPath) + if err != nil { + return "", fmt.Errorf("failed to open file: %w", err) + } + defer file.Close() + + body := &bytes.Buffer{} + writer := multipart.NewWriter(body) + + filename := filepath.Base(localPath) + formFile, err := writer.CreateFormFile("media", filename) + if err != nil { + return "", fmt.Errorf("failed to create form file: %w", err) + } + + if _, err = io.Copy(formFile, file); err != nil { + return "", fmt.Errorf("failed to copy file content: %w", err) + } + writer.Close() + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, apiURL, body) + if err != nil { + return "", fmt.Errorf("failed to create request: %w", err) + } + req.Header.Set("Content-Type", writer.FormDataContentType()) + + client := &http.Client{Timeout: 30 * time.Second} + resp, err := client.Do(req) + if err != nil { + return "", channels.ClassifyNetError(err) + } + 
defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + respBody, _ := io.ReadAll(resp.Body) + return "", channels.ClassifySendError(resp.StatusCode, fmt.Errorf("wecom upload error: %s", string(respBody))) + } + + var result struct { + ErrCode int `json:"errcode"` + ErrMsg string `json:"errmsg"` + MediaID string `json:"media_id"` + } + if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { + return "", fmt.Errorf("failed to parse upload response: %w", err) + } + + if result.ErrCode != 0 { + return "", fmt.Errorf("upload API error: %s (code: %d)", result.ErrMsg, result.ErrCode) + } + + return result.MediaID, nil +} + +// sendImageMessage sends an image message using a media_id. +func (c *WeComAppChannel) sendImageMessage(ctx context.Context, accessToken, userID, mediaID string) error { + apiURL := fmt.Sprintf("%s/cgi-bin/message/send?access_token=%s", wecomAPIBase, accessToken) + + msg := WeComImageMessage{ + ToUser: userID, + MsgType: "image", + AgentID: c.config.AgentID, + } + msg.Image.MediaID = mediaID + + jsonData, err := json.Marshal(msg) + if err != nil { + return fmt.Errorf("failed to marshal message: %w", err) + } + + timeout := c.config.ReplyTimeout + if timeout <= 0 { + timeout = 5 + } + + reqCtx, cancel := context.WithTimeout(ctx, time.Duration(timeout)*time.Second) + defer cancel() + + req, err := http.NewRequestWithContext(reqCtx, http.MethodPost, apiURL, bytes.NewBuffer(jsonData)) + if err != nil { + return fmt.Errorf("failed to create request: %w", err) + } + req.Header.Set("Content-Type", "application/json") + + client := &http.Client{Timeout: time.Duration(timeout) * time.Second} + resp, err := client.Do(req) + if err != nil { + return channels.ClassifyNetError(err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + respBody, _ := io.ReadAll(resp.Body) + return channels.ClassifySendError(resp.StatusCode, fmt.Errorf("wecom_app API error: %s", string(respBody))) + } + + respBody, err := 
io.ReadAll(resp.Body) + if err != nil { + return fmt.Errorf("failed to read response: %w", err) + } + + var sendResp WeComSendMessageResponse + if err := json.Unmarshal(respBody, &sendResp); err != nil { + return fmt.Errorf("failed to parse response: %w", err) + } + + if sendResp.ErrCode != 0 { + return fmt.Errorf("API error: %s (code: %d)", sendResp.ErrMsg, sendResp.ErrCode) + } + + return nil +} + +// WebhookPath returns the path for registering on the shared HTTP server. +func (c *WeComAppChannel) WebhookPath() string { + if c.config.WebhookPath != "" { + return c.config.WebhookPath + } + return "/webhook/wecom-app" +} + +// ServeHTTP implements http.Handler for the shared HTTP server. +func (c *WeComAppChannel) ServeHTTP(w http.ResponseWriter, r *http.Request) { + c.handleWebhook(w, r) +} + +// HealthPath returns the health check endpoint path. +func (c *WeComAppChannel) HealthPath() string { + return "/health/wecom-app" +} + +// HealthHandler handles health check requests. +func (c *WeComAppChannel) HealthHandler(w http.ResponseWriter, r *http.Request) { + c.handleHealth(w, r) +} + // handleWebhook handles incoming webhook requests from WeCom func (c *WeComAppChannel) handleWebhook(w http.ResponseWriter, r *http.Request) { ctx := r.Context() @@ -279,7 +462,7 @@ func (c *WeComAppChannel) handleVerification(ctx context.Context, w http.Respons } // Verify signature - if !WeComVerifySignature(c.config.Token, msgSignature, timestamp, nonce, echostr) { + if !verifySignature(c.config.Token, msgSignature, timestamp, nonce, echostr) { logger.WarnCF("wecom_app", "Signature verification failed", map[string]any{ "token": c.config.Token, "msg_signature": msgSignature, @@ -298,7 +481,7 @@ func (c *WeComAppChannel) handleVerification(ctx context.Context, w http.Respons "encoding_aes_key": c.config.EncodingAESKey, "corp_id": c.config.CorpID, }) - decryptedEchoStr, err := WeComDecryptMessageWithVerify(echostr, c.config.EncodingAESKey, c.config.CorpID) + decryptedEchoStr, err 
:= decryptMessageWithVerify(echostr, c.config.EncodingAESKey, c.config.CorpID) if err != nil { logger.ErrorCF("wecom_app", "Failed to decrypt echostr", map[string]any{ "error": err.Error(), @@ -357,7 +540,7 @@ func (c *WeComAppChannel) handleMessageCallback(ctx context.Context, w http.Resp } // Verify signature - if !WeComVerifySignature(c.config.Token, msgSignature, timestamp, nonce, encryptedMsg.Encrypt) { + if !verifySignature(c.config.Token, msgSignature, timestamp, nonce, encryptedMsg.Encrypt) { logger.WarnC("wecom_app", "Message signature verification failed") http.Error(w, "Invalid signature", http.StatusForbidden) return @@ -365,7 +548,7 @@ func (c *WeComAppChannel) handleMessageCallback(ctx context.Context, w http.Resp // Decrypt message with CorpID verification // For WeCom App (自建应用), receiveid should be corp_id - decryptedMsg, err := WeComDecryptMessageWithVerify(encryptedMsg.Encrypt, c.config.EncodingAESKey, c.config.CorpID) + decryptedMsg, err := decryptMessageWithVerify(encryptedMsg.Encrypt, c.config.EncodingAESKey, c.config.CorpID) if err != nil { logger.ErrorCF("wecom_app", "Failed to decrypt message", map[string]any{ "error": err.Error(), @@ -428,6 +611,9 @@ func (c *WeComAppChannel) processMessage(ctx context.Context, msg WeComXMLMessag // Build metadata // WeCom App only supports direct messages (private chat) + peer := bus.Peer{Kind: "direct", ID: senderID} + messageID := fmt.Sprintf("%d", msg.MsgId) + metadata := map[string]string{ "msg_type": msg.MsgType, "msg_id": fmt.Sprintf("%d", msg.MsgId), @@ -435,8 +621,6 @@ func (c *WeComAppChannel) processMessage(ctx context.Context, msg WeComXMLMessag "platform": "wecom_app", "media_id": msg.MediaId, "create_time": fmt.Sprintf("%d", msg.CreateTime), - "peer_kind": "direct", - "peer_id": senderID, } content := msg.Content @@ -447,8 +631,15 @@ func (c *WeComAppChannel) processMessage(ctx context.Context, msg WeComXMLMessag "preview": utils.Truncate(content, 50), }) + // Build sender info + appSender := 
bus.SenderInfo{ + Platform: "wecom", + PlatformID: senderID, + CanonicalID: identity.BuildCanonicalID("wecom", senderID), + } + // Handle the message through the base channel - c.HandleMessage(senderID, chatID, content, nil, metadata) + c.HandleMessage(ctx, peer, messageID, senderID, chatID, content, nil, metadata, appSender) } // tokenRefreshLoop periodically refreshes the access token @@ -550,10 +741,15 @@ func (c *WeComAppChannel) sendTextMessage(ctx context.Context, accessToken, user client := &http.Client{Timeout: time.Duration(timeout) * time.Second} resp, err := client.Do(req) if err != nil { - return fmt.Errorf("failed to send message: %w", err) + return channels.ClassifyNetError(err) } defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + return channels.ClassifySendError(resp.StatusCode, fmt.Errorf("wecom_app API error: %s", string(body))) + } + body, err := io.ReadAll(resp.Body) if err != nil { return fmt.Errorf("failed to read response: %w", err) diff --git a/pkg/channels/wecom_app_test.go b/pkg/channels/wecom/app_test.go similarity index 95% rename from pkg/channels/wecom_app_test.go rename to pkg/channels/wecom/app_test.go index ba911d49f..0d15e955b 100644 --- a/pkg/channels/wecom_app_test.go +++ b/pkg/channels/wecom/app_test.go @@ -1,7 +1,4 @@ -// PicoClaw - Ultra-lightweight personal AI agent -// WeCom App (企业微信自建应用) channel tests - -package channels +package wecom import ( "bytes" @@ -197,7 +194,7 @@ func TestWeComAppVerifySignature(t *testing.T) { msgEncrypt := "test_message" expectedSig := generateSignatureApp("test_token", timestamp, nonce, msgEncrypt) - if !WeComVerifySignature(ch.config.Token, expectedSig, timestamp, nonce, msgEncrypt) { + if !verifySignature(ch.config.Token, expectedSig, timestamp, nonce, msgEncrypt) { t.Error("valid signature should pass verification") } }) @@ -207,7 +204,7 @@ func TestWeComAppVerifySignature(t *testing.T) { nonce := "test_nonce" msgEncrypt := "test_message" - 
if WeComVerifySignature(ch.config.Token, "invalid_sig", timestamp, nonce, msgEncrypt) { + if verifySignature(ch.config.Token, "invalid_sig", timestamp, nonce, msgEncrypt) { t.Error("invalid signature should fail verification") } }) @@ -221,7 +218,7 @@ func TestWeComAppVerifySignature(t *testing.T) { } chEmpty, _ := NewWeComAppChannel(cfgEmpty, msgBus) - if !WeComVerifySignature(chEmpty.config.Token, "any_sig", "any_ts", "any_nonce", "any_msg") { + if !verifySignature(chEmpty.config.Token, "any_sig", "any_ts", "any_nonce", "any_msg") { t.Error("empty token should skip verification and return true") } }) @@ -243,7 +240,7 @@ func TestWeComAppDecryptMessage(t *testing.T) { plainText := "hello world" encoded := base64.StdEncoding.EncodeToString([]byte(plainText)) - result, err := WeComDecryptMessage(encoded, ch.config.EncodingAESKey) + result, err := decryptMessage(encoded, ch.config.EncodingAESKey) if err != nil { t.Fatalf("unexpected error: %v", err) } @@ -268,7 +265,7 @@ func TestWeComAppDecryptMessage(t *testing.T) { t.Fatalf("failed to encrypt test message: %v", err) } - result, err := WeComDecryptMessage(encrypted, ch.config.EncodingAESKey) + result, err := decryptMessage(encrypted, ch.config.EncodingAESKey) if err != nil { t.Fatalf("unexpected error: %v", err) } @@ -286,7 +283,7 @@ func TestWeComAppDecryptMessage(t *testing.T) { } ch, _ := NewWeComAppChannel(cfg, msgBus) - _, err := WeComDecryptMessage("invalid_base64!!!", ch.config.EncodingAESKey) + _, err := decryptMessage("invalid_base64!!!", ch.config.EncodingAESKey) if err == nil { t.Error("expected error for invalid base64, got nil") } @@ -301,7 +298,7 @@ func TestWeComAppDecryptMessage(t *testing.T) { } ch, _ := NewWeComAppChannel(cfg, msgBus) - _, err := WeComDecryptMessage(base64.StdEncoding.EncodeToString([]byte("test")), ch.config.EncodingAESKey) + _, err := decryptMessage(base64.StdEncoding.EncodeToString([]byte("test")), ch.config.EncodingAESKey) if err == nil { t.Error("expected error for invalid 
AES key, got nil") } @@ -319,7 +316,7 @@ func TestWeComAppDecryptMessage(t *testing.T) { // Encrypt a very short message that results in ciphertext less than block size shortData := make([]byte, 8) - _, err := WeComDecryptMessage(base64.StdEncoding.EncodeToString(shortData), ch.config.EncodingAESKey) + _, err := decryptMessage(base64.StdEncoding.EncodeToString(shortData), ch.config.EncodingAESKey) if err == nil { t.Error("expected error for short ciphertext, got nil") } @@ -361,7 +358,7 @@ func TestWeComAppPKCS7Unpad(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - result, err := pkcs7UnpadWeCom(tt.input) + result, err := pkcs7Unpad(tt.input) if tt.expected == nil { // This case should return an error if err == nil { @@ -852,6 +849,28 @@ func TestWeComAppMessageStructures(t *testing.T) { } }) + t.Run("WeComImageMessage structure", func(t *testing.T) { + msg := WeComImageMessage{ + ToUser: "user123", + MsgType: "image", + AgentID: 1000002, + } + msg.Image.MediaID = "media_123456" + + if msg.Image.MediaID != "media_123456" { + t.Errorf("Image.MediaID = %q, want %q", msg.Image.MediaID, "media_123456") + } + if msg.ToUser != "user123" { + t.Errorf("ToUser = %q, want %q", msg.ToUser, "user123") + } + if msg.MsgType != "image" { + t.Errorf("MsgType = %q, want %q", msg.MsgType, "image") + } + if msg.AgentID != 1000002 { + t.Errorf("AgentID = %d, want %d", msg.AgentID, 1000002) + } + }) + t.Run("WeComAccessTokenResponse structure", func(t *testing.T) { jsonData := `{ "errcode": 0, diff --git a/pkg/channels/wecom.go b/pkg/channels/wecom/bot.go similarity index 66% rename from pkg/channels/wecom.go rename to pkg/channels/wecom/bot.go index e24157f5c..4c576b84b 100644 --- a/pkg/channels/wecom.go +++ b/pkg/channels/wecom/bot.go @@ -1,29 +1,21 @@ -// PicoClaw - Ultra-lightweight personal AI agent -// WeCom Bot (企业微信智能机器人) channel implementation -// Uses webhook callback mode for receiving messages and webhook API for sending replies - -package 
channels +package wecom import ( "bytes" "context" - "crypto/aes" - "crypto/cipher" - "crypto/sha1" - "encoding/base64" - "encoding/binary" "encoding/json" "encoding/xml" "fmt" "io" "net/http" - "sort" "strings" "sync" "time" "github.com/sipeed/picoclaw/pkg/bus" + "github.com/sipeed/picoclaw/pkg/channels" "github.com/sipeed/picoclaw/pkg/config" + "github.com/sipeed/picoclaw/pkg/identity" "github.com/sipeed/picoclaw/pkg/logger" "github.com/sipeed/picoclaw/pkg/utils" ) @@ -31,9 +23,8 @@ import ( // WeComBotChannel implements the Channel interface for WeCom Bot (企业微信智能机器人) // Uses webhook callback mode - simpler than WeCom App but only supports passive replies type WeComBotChannel struct { - *BaseChannel + *channels.BaseChannel config config.WeComConfig - server *http.Server ctx context.Context cancel context.CancelFunc processedMsgs map[string]bool // Message deduplication: msg_id -> processed @@ -96,7 +87,11 @@ func NewWeComBotChannel(cfg config.WeComConfig, messageBus *bus.MessageBus) (*We return nil, fmt.Errorf("wecom token and webhook_url are required") } - base := NewBaseChannel("wecom", cfg, messageBus, cfg.AllowFrom) + base := channels.NewBaseChannel("wecom", cfg, messageBus, cfg.AllowFrom, + channels.WithMaxMessageLength(2048), + channels.WithGroupTrigger(cfg.GroupTrigger), + channels.WithReasoningChannelID(cfg.ReasoningChannelID), + ) return &WeComBotChannel{ BaseChannel: base, @@ -110,43 +105,14 @@ func (c *WeComBotChannel) Name() string { return "wecom" } -// Start initializes the WeCom Bot channel with HTTP webhook server +// Start initializes the WeCom Bot channel func (c *WeComBotChannel) Start(ctx context.Context) error { logger.InfoC("wecom", "Starting WeCom Bot channel...") c.ctx, c.cancel = context.WithCancel(ctx) - // Setup HTTP server for webhook - mux := http.NewServeMux() - webhookPath := c.config.WebhookPath - if webhookPath == "" { - webhookPath = "/webhook/wecom" - } - mux.HandleFunc(webhookPath, c.handleWebhook) - - // Health check endpoint 
- mux.HandleFunc("/health/wecom", c.handleHealth) - - addr := fmt.Sprintf("%s:%d", c.config.WebhookHost, c.config.WebhookPort) - c.server = &http.Server{ - Addr: addr, - Handler: mux, - } - - c.setRunning(true) - logger.InfoCF("wecom", "WeCom Bot channel started", map[string]any{ - "address": addr, - "path": webhookPath, - }) - - // Start server in goroutine - go func() { - if err := c.server.ListenAndServe(); err != nil && err != http.ErrServerClosed { - logger.ErrorCF("wecom", "HTTP server error", map[string]any{ - "error": err.Error(), - }) - } - }() + c.SetRunning(true) + logger.InfoC("wecom", "WeCom Bot channel started") return nil } @@ -159,13 +125,7 @@ func (c *WeComBotChannel) Stop(ctx context.Context) error { c.cancel() } - if c.server != nil { - shutdownCtx, cancel := context.WithTimeout(ctx, 5*time.Second) - defer cancel() - c.server.Shutdown(shutdownCtx) - } - - c.setRunning(false) + c.SetRunning(false) logger.InfoC("wecom", "WeCom Bot channel stopped") return nil } @@ -175,7 +135,7 @@ func (c *WeComBotChannel) Stop(ctx context.Context) error { // For delayed responses, we use the webhook URL func (c *WeComBotChannel) Send(ctx context.Context, msg bus.OutboundMessage) error { if !c.IsRunning() { - return fmt.Errorf("wecom channel not running") + return channels.ErrNotRunning } logger.DebugCF("wecom", "Sending message via webhook", map[string]any{ @@ -186,6 +146,29 @@ func (c *WeComBotChannel) Send(ctx context.Context, msg bus.OutboundMessage) err return c.sendWebhookReply(ctx, msg.ChatID, msg.Content) } +// WebhookPath returns the path for registering on the shared HTTP server. +func (c *WeComBotChannel) WebhookPath() string { + if c.config.WebhookPath != "" { + return c.config.WebhookPath + } + return "/webhook/wecom" +} + +// ServeHTTP implements http.Handler for the shared HTTP server. +func (c *WeComBotChannel) ServeHTTP(w http.ResponseWriter, r *http.Request) { + c.handleWebhook(w, r) +} + +// HealthPath returns the health check endpoint path. 
+func (c *WeComBotChannel) HealthPath() string { + return "/health/wecom" +} + +// HealthHandler handles health check requests. +func (c *WeComBotChannel) HealthHandler(w http.ResponseWriter, r *http.Request) { + c.handleHealth(w, r) +} + // handleWebhook handles incoming webhook requests from WeCom func (c *WeComBotChannel) handleWebhook(w http.ResponseWriter, r *http.Request) { ctx := r.Context() @@ -219,7 +202,7 @@ func (c *WeComBotChannel) handleVerification(ctx context.Context, w http.Respons } // Verify signature - if !WeComVerifySignature(c.config.Token, msgSignature, timestamp, nonce, echostr) { + if !verifySignature(c.config.Token, msgSignature, timestamp, nonce, echostr) { logger.WarnC("wecom", "Signature verification failed") http.Error(w, "Invalid signature", http.StatusForbidden) return @@ -228,7 +211,7 @@ func (c *WeComBotChannel) handleVerification(ctx context.Context, w http.Respons // Decrypt echostr // For AIBOT (智能机器人), receiveid should be empty string "" // Reference: https://developer.work.weixin.qq.com/document/path/101033 - decryptedEchoStr, err := WeComDecryptMessageWithVerify(echostr, c.config.EncodingAESKey, "") + decryptedEchoStr, err := decryptMessageWithVerify(echostr, c.config.EncodingAESKey, "") if err != nil { logger.ErrorCF("wecom", "Failed to decrypt echostr", map[string]any{ "error": err.Error(), @@ -281,7 +264,7 @@ func (c *WeComBotChannel) handleMessageCallback(ctx context.Context, w http.Resp } // Verify signature - if !WeComVerifySignature(c.config.Token, msgSignature, timestamp, nonce, encryptedMsg.Encrypt) { + if !verifySignature(c.config.Token, msgSignature, timestamp, nonce, encryptedMsg.Encrypt) { logger.WarnC("wecom", "Message signature verification failed") http.Error(w, "Invalid signature", http.StatusForbidden) return @@ -290,7 +273,7 @@ func (c *WeComBotChannel) handleMessageCallback(ctx context.Context, w http.Resp // Decrypt message // For AIBOT (智能机器人), receiveid should be empty string "" // Reference: 
https://developer.work.weixin.qq.com/document/path/101033 - decryptedMsg, err := WeComDecryptMessageWithVerify(encryptedMsg.Encrypt, c.config.EncodingAESKey, "") + decryptedMsg, err := decryptMessageWithVerify(encryptedMsg.Encrypt, c.config.EncodingAESKey, "") if err != nil { logger.ErrorCF("wecom", "Failed to decrypt message", map[string]any{ "error": err.Error(), @@ -387,12 +370,21 @@ func (c *WeComBotChannel) processMessage(ctx context.Context, msg WeComBotMessag } // Build metadata + peer := bus.Peer{Kind: peerKind, ID: peerID} + + // In group chats, apply unified group trigger filtering + if isGroupChat { + respond, cleaned := c.ShouldRespondInGroup(false, content) + if !respond { + return + } + content = cleaned + } + metadata := map[string]string{ "msg_type": msg.MsgType, "msg_id": msg.MsgID, "platform": "wecom", - "peer_kind": peerKind, - "peer_id": peerID, "response_url": msg.ResponseURL, } if isGroupChat { @@ -408,8 +400,19 @@ func (c *WeComBotChannel) processMessage(ctx context.Context, msg WeComBotMessag "preview": utils.Truncate(content, 50), }) + // Build sender info + sender := bus.SenderInfo{ + Platform: "wecom", + PlatformID: senderID, + CanonicalID: identity.BuildCanonicalID("wecom", senderID), + } + + if !c.IsAllowedSender(sender) { + return + } + // Handle the message through the base channel - c.HandleMessage(senderID, chatID, content, nil, metadata) + c.HandleMessage(ctx, peer, msg.MsgID, senderID, chatID, content, nil, metadata, sender) } // sendWebhookReply sends a reply using the webhook URL @@ -442,10 +445,15 @@ func (c *WeComBotChannel) sendWebhookReply(ctx context.Context, userID, content client := &http.Client{Timeout: time.Duration(timeout) * time.Second} resp, err := client.Do(req) if err != nil { - return fmt.Errorf("failed to send webhook reply: %w", err) + return channels.ClassifyNetError(err) } defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + return 
channels.ClassifySendError(resp.StatusCode, fmt.Errorf("webhook API error: %s", string(body))) + } + body, err := io.ReadAll(resp.Body) if err != nil { return fmt.Errorf("failed to read response: %w", err) @@ -477,129 +485,3 @@ func (c *WeComBotChannel) handleHealth(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") json.NewEncoder(w).Encode(status) } - -// WeCom common utilities for both WeCom Bot and WeCom App -// The following functions were moved from wecom_common.go - -// WeComVerifySignature verifies the message signature for WeCom -// This is a common function used by both WeCom Bot and WeCom App -func WeComVerifySignature(token, msgSignature, timestamp, nonce, msgEncrypt string) bool { - if token == "" { - return true // Skip verification if token is not set - } - - // Sort parameters - params := []string{token, timestamp, nonce, msgEncrypt} - sort.Strings(params) - - // Concatenate - str := strings.Join(params, "") - - // SHA1 hash - hash := sha1.Sum([]byte(str)) - expectedSignature := fmt.Sprintf("%x", hash) - - return expectedSignature == msgSignature -} - -// WeComDecryptMessage decrypts the encrypted message using AES -// This is a common function used by both WeCom Bot and WeCom App -// For AIBOT, receiveid should be the aibotid; for other apps, it should be corp_id -func WeComDecryptMessage(encryptedMsg, encodingAESKey string) (string, error) { - return WeComDecryptMessageWithVerify(encryptedMsg, encodingAESKey, "") -} - -// WeComDecryptMessageWithVerify decrypts the encrypted message and optionally verifies receiveid -// receiveid: for AIBOT use aibotid, for WeCom App use corp_id. If empty, skip verification. 
-func WeComDecryptMessageWithVerify(encryptedMsg, encodingAESKey, receiveid string) (string, error) { - if encodingAESKey == "" { - // No encryption, return as is (base64 decode) - decoded, err := base64.StdEncoding.DecodeString(encryptedMsg) - if err != nil { - return "", err - } - return string(decoded), nil - } - - // Decode AES key (base64) - aesKey, err := base64.StdEncoding.DecodeString(encodingAESKey + "=") - if err != nil { - return "", fmt.Errorf("failed to decode AES key: %w", err) - } - - // Decode encrypted message - cipherText, err := base64.StdEncoding.DecodeString(encryptedMsg) - if err != nil { - return "", fmt.Errorf("failed to decode message: %w", err) - } - - // AES decrypt - block, err := aes.NewCipher(aesKey) - if err != nil { - return "", fmt.Errorf("failed to create cipher: %w", err) - } - - if len(cipherText) < aes.BlockSize { - return "", fmt.Errorf("ciphertext too short") - } - - // IV is the first 16 bytes of AESKey - iv := aesKey[:aes.BlockSize] - mode := cipher.NewCBCDecrypter(block, iv) - plainText := make([]byte, len(cipherText)) - mode.CryptBlocks(plainText, cipherText) - - // Remove PKCS7 padding - plainText, err = pkcs7UnpadWeCom(plainText) - if err != nil { - return "", fmt.Errorf("failed to unpad: %w", err) - } - - // Parse message structure - // Format: random(16) + msg_len(4) + msg + receiveid - if len(plainText) < 20 { - return "", fmt.Errorf("decrypted message too short") - } - - msgLen := binary.BigEndian.Uint32(plainText[16:20]) - if int(msgLen) > len(plainText)-20 { - return "", fmt.Errorf("invalid message length") - } - - msg := plainText[20 : 20+msgLen] - - // Verify receiveid if provided - if receiveid != "" && len(plainText) > 20+int(msgLen) { - actualReceiveID := string(plainText[20+msgLen:]) - if actualReceiveID != receiveid { - return "", fmt.Errorf("receiveid mismatch: expected %s, got %s", receiveid, actualReceiveID) - } - } - - return string(msg), nil -} - -// pkcs7UnpadWeCom removes PKCS7 padding with validation 
-// WeCom uses block size of 32 (not standard AES block size of 16) -const wecomBlockSize = 32 - -func pkcs7UnpadWeCom(data []byte) ([]byte, error) { - if len(data) == 0 { - return data, nil - } - padding := int(data[len(data)-1]) - // WeCom uses 32-byte block size for PKCS7 padding - if padding == 0 || padding > wecomBlockSize { - return nil, fmt.Errorf("invalid padding size: %d", padding) - } - if padding > len(data) { - return nil, fmt.Errorf("padding size larger than data") - } - // Verify all padding bytes - for i := range padding { - if data[len(data)-1-i] != byte(padding) { - return nil, fmt.Errorf("invalid padding byte at position %d", i) - } - } - return data[:len(data)-padding], nil -} diff --git a/pkg/channels/wecom_test.go b/pkg/channels/wecom/bot_test.go similarity index 95% rename from pkg/channels/wecom_test.go rename to pkg/channels/wecom/bot_test.go index 88aed8d2b..97b503ce8 100644 --- a/pkg/channels/wecom_test.go +++ b/pkg/channels/wecom/bot_test.go @@ -1,7 +1,4 @@ -// PicoClaw - Ultra-lightweight personal AI agent -// WeCom Bot (企业微信智能机器人) channel tests - -package channels +package wecom import ( "bytes" @@ -177,7 +174,7 @@ func TestWeComBotVerifySignature(t *testing.T) { msgEncrypt := "test_message" expectedSig := generateSignature("test_token", timestamp, nonce, msgEncrypt) - if !WeComVerifySignature(ch.config.Token, expectedSig, timestamp, nonce, msgEncrypt) { + if !verifySignature(ch.config.Token, expectedSig, timestamp, nonce, msgEncrypt) { t.Error("valid signature should pass verification") } }) @@ -187,7 +184,7 @@ func TestWeComBotVerifySignature(t *testing.T) { nonce := "test_nonce" msgEncrypt := "test_message" - if WeComVerifySignature(ch.config.Token, "invalid_sig", timestamp, nonce, msgEncrypt) { + if verifySignature(ch.config.Token, "invalid_sig", timestamp, nonce, msgEncrypt) { t.Error("invalid signature should fail verification") } }) @@ -202,7 +199,7 @@ func TestWeComBotVerifySignature(t *testing.T) { config: cfgEmpty, } - if 
!WeComVerifySignature(chEmpty.config.Token, "any_sig", "any_ts", "any_nonce", "any_msg") { + if !verifySignature(chEmpty.config.Token, "any_sig", "any_ts", "any_nonce", "any_msg") { t.Error("empty token should skip verification and return true") } }) @@ -223,7 +220,7 @@ func TestWeComBotDecryptMessage(t *testing.T) { plainText := "hello world" encoded := base64.StdEncoding.EncodeToString([]byte(plainText)) - result, err := WeComDecryptMessage(encoded, ch.config.EncodingAESKey) + result, err := decryptMessage(encoded, ch.config.EncodingAESKey) if err != nil { t.Fatalf("unexpected error: %v", err) } @@ -247,7 +244,7 @@ func TestWeComBotDecryptMessage(t *testing.T) { t.Fatalf("failed to encrypt test message: %v", err) } - result, err := WeComDecryptMessage(encrypted, ch.config.EncodingAESKey) + result, err := decryptMessage(encrypted, ch.config.EncodingAESKey) if err != nil { t.Fatalf("unexpected error: %v", err) } @@ -264,7 +261,7 @@ func TestWeComBotDecryptMessage(t *testing.T) { } ch, _ := NewWeComBotChannel(cfg, msgBus) - _, err := WeComDecryptMessage("invalid_base64!!!", ch.config.EncodingAESKey) + _, err := decryptMessage("invalid_base64!!!", ch.config.EncodingAESKey) if err == nil { t.Error("expected error for invalid base64, got nil") } @@ -278,7 +275,7 @@ func TestWeComBotDecryptMessage(t *testing.T) { } ch, _ := NewWeComBotChannel(cfg, msgBus) - _, err := WeComDecryptMessage(base64.StdEncoding.EncodeToString([]byte("test")), ch.config.EncodingAESKey) + _, err := decryptMessage(base64.StdEncoding.EncodeToString([]byte("test")), ch.config.EncodingAESKey) if err == nil { t.Error("expected error for invalid AES key, got nil") } @@ -320,20 +317,20 @@ func TestWeComBotPKCS7Unpad(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - result, err := pkcs7UnpadWeCom(tt.input) + result, err := pkcs7Unpad(tt.input) if tt.expected == nil { // This case should return an error if err == nil { - t.Errorf("pkcs7UnpadWeCom() expected error for 
invalid padding, got result: %v", result) + t.Errorf("pkcs7Unpad() expected error for invalid padding, got result: %v", result) } return } if err != nil { - t.Errorf("pkcs7UnpadWeCom() unexpected error: %v", err) + t.Errorf("pkcs7Unpad() unexpected error: %v", err) return } if !bytes.Equal(result, tt.expected) { - t.Errorf("pkcs7UnpadWeCom() = %v, want %v", result, tt.expected) + t.Errorf("pkcs7Unpad() = %v, want %v", result, tt.expected) } }) } diff --git a/pkg/channels/wecom/common.go b/pkg/channels/wecom/common.go new file mode 100644 index 000000000..39a27d04c --- /dev/null +++ b/pkg/channels/wecom/common.go @@ -0,0 +1,134 @@ +package wecom + +import ( + "crypto/aes" + "crypto/cipher" + "crypto/sha1" + "encoding/base64" + "encoding/binary" + "fmt" + "sort" + "strings" +) + +// blockSize is the PKCS7 block size used by WeCom (32) +const blockSize = 32 + +// verifySignature verifies the message signature for WeCom +// This is a common function used by both WeCom Bot and WeCom App +func verifySignature(token, msgSignature, timestamp, nonce, msgEncrypt string) bool { + if token == "" { + return true // Skip verification if token is not set + } + + // Sort parameters + params := []string{token, timestamp, nonce, msgEncrypt} + sort.Strings(params) + + // Concatenate + str := strings.Join(params, "") + + // SHA1 hash + hash := sha1.Sum([]byte(str)) + expectedSignature := fmt.Sprintf("%x", hash) + + return expectedSignature == msgSignature +} + +// decryptMessage decrypts the encrypted message using AES +// For AIBOT, receiveid should be the aibotid; for other apps, it should be corp_id +func decryptMessage(encryptedMsg, encodingAESKey string) (string, error) { + return decryptMessageWithVerify(encryptedMsg, encodingAESKey, "") +} + +// decryptMessageWithVerify decrypts the encrypted message and optionally verifies receiveid +// receiveid: for AIBOT use aibotid, for WeCom App use corp_id. If empty, skip verification. 
+func decryptMessageWithVerify(encryptedMsg, encodingAESKey, receiveid string) (string, error) { + if encodingAESKey == "" { + // No encryption, return as is (base64 decode) + decoded, err := base64.StdEncoding.DecodeString(encryptedMsg) + if err != nil { + return "", err + } + return string(decoded), nil + } + + // Decode AES key (base64) + aesKey, err := base64.StdEncoding.DecodeString(encodingAESKey + "=") + if err != nil { + return "", fmt.Errorf("failed to decode AES key: %w", err) + } + + // Decode encrypted message + cipherText, err := base64.StdEncoding.DecodeString(encryptedMsg) + if err != nil { + return "", fmt.Errorf("failed to decode message: %w", err) + } + + // AES decrypt + block, err := aes.NewCipher(aesKey) + if err != nil { + return "", fmt.Errorf("failed to create cipher: %w", err) + } + + if len(cipherText) < aes.BlockSize { + return "", fmt.Errorf("ciphertext too short") + } + + // IV is the first 16 bytes of AESKey + iv := aesKey[:aes.BlockSize] + mode := cipher.NewCBCDecrypter(block, iv) + plainText := make([]byte, len(cipherText)) + mode.CryptBlocks(plainText, cipherText) + + // Remove PKCS7 padding + plainText, err = pkcs7Unpad(plainText) + if err != nil { + return "", fmt.Errorf("failed to unpad: %w", err) + } + + // Parse message structure + // Format: random(16) + msg_len(4) + msg + receiveid + if len(plainText) < 20 { + return "", fmt.Errorf("decrypted message too short") + } + + msgLen := binary.BigEndian.Uint32(plainText[16:20]) + if int(msgLen) > len(plainText)-20 { + return "", fmt.Errorf("invalid message length") + } + + msg := plainText[20 : 20+msgLen] + + // Verify receiveid if provided + if receiveid != "" && len(plainText) > 20+int(msgLen) { + actualReceiveID := string(plainText[20+msgLen:]) + if actualReceiveID != receiveid { + return "", fmt.Errorf("receiveid mismatch: expected %s, got %s", receiveid, actualReceiveID) + } + } + + return string(msg), nil +} + +// pkcs7Unpad removes PKCS7 padding with validation +func 
pkcs7Unpad(data []byte) ([]byte, error) { + if len(data) == 0 { + return data, nil + } + padding := int(data[len(data)-1]) + // WeCom uses 32-byte block size for PKCS7 padding + if padding == 0 || padding > blockSize { + return nil, fmt.Errorf("invalid padding size: %d", padding) + } + if padding > len(data) { + return nil, fmt.Errorf("padding size larger than data") + } + // Verify all padding bytes + for i := range padding { + if data[len(data)-1-i] != byte(padding) { + return nil, fmt.Errorf("invalid padding byte at position %d", i) + } + } + return data[:len(data)-padding], nil +} diff --git a/pkg/channels/wecom/init.go b/pkg/channels/wecom/init.go new file mode 100644 index 000000000..3ef1ecdf3 --- /dev/null +++ b/pkg/channels/wecom/init.go @@ -0,0 +1,16 @@ +package wecom + +import ( + "github.com/sipeed/picoclaw/pkg/bus" + "github.com/sipeed/picoclaw/pkg/channels" + "github.com/sipeed/picoclaw/pkg/config" +) + +func init() { + channels.RegisterFactory("wecom", func(cfg *config.Config, b *bus.MessageBus) (channels.Channel, error) { + return NewWeComBotChannel(cfg.Channels.WeCom, b) + }) + channels.RegisterFactory("wecom_app", func(cfg *config.Config, b *bus.MessageBus) (channels.Channel, error) { + return NewWeComAppChannel(cfg.Channels.WeComApp, b) + }) +} diff --git a/pkg/channels/whatsapp/init.go b/pkg/channels/whatsapp/init.go new file mode 100644 index 000000000..d9c2669c3 --- /dev/null +++ b/pkg/channels/whatsapp/init.go @@ -0,0 +1,13 @@ +package whatsapp + +import ( + "github.com/sipeed/picoclaw/pkg/bus" + "github.com/sipeed/picoclaw/pkg/channels" + "github.com/sipeed/picoclaw/pkg/config" +) + +func init() { + channels.RegisterFactory("whatsapp", func(cfg *config.Config, b *bus.MessageBus) (channels.Channel, error) { + return NewWhatsAppChannel(cfg.Channels.WhatsApp, b) + }) +} diff --git a/pkg/channels/whatsapp.go b/pkg/channels/whatsapp/whatsapp.go similarity index 54% rename from pkg/channels/whatsapp.go rename to pkg/channels/whatsapp/whatsapp.go 
index 2dc4017ac..70b3e02bf 100644 --- a/pkg/channels/whatsapp.go +++ b/pkg/channels/whatsapp/whatsapp.go @@ -1,31 +1,42 @@ -package channels +package whatsapp import ( "context" "encoding/json" "fmt" - "log" "sync" "time" "github.com/gorilla/websocket" "github.com/sipeed/picoclaw/pkg/bus" + "github.com/sipeed/picoclaw/pkg/channels" "github.com/sipeed/picoclaw/pkg/config" + "github.com/sipeed/picoclaw/pkg/identity" + "github.com/sipeed/picoclaw/pkg/logger" "github.com/sipeed/picoclaw/pkg/utils" ) type WhatsAppChannel struct { - *BaseChannel + *channels.BaseChannel conn *websocket.Conn config config.WhatsAppConfig url string + ctx context.Context + cancel context.CancelFunc mu sync.Mutex connected bool } func NewWhatsAppChannel(cfg config.WhatsAppConfig, bus *bus.MessageBus) (*WhatsAppChannel, error) { - base := NewBaseChannel("whatsapp", cfg, bus, cfg.AllowFrom) + base := channels.NewBaseChannel( + "whatsapp", + cfg, + bus, + cfg.AllowFrom, + channels.WithMaxMessageLength(65536), + channels.WithReasoningChannelID(cfg.ReasoningChannelID), + ) return &WhatsAppChannel{ BaseChannel: base, @@ -36,7 +47,11 @@ func NewWhatsAppChannel(cfg config.WhatsAppConfig, bus *bus.MessageBus) (*WhatsA } func (c *WhatsAppChannel) Start(ctx context.Context) error { - log.Printf("Starting WhatsApp channel connecting to %s...", c.url) + logger.InfoCF("whatsapp", "Starting WhatsApp channel", map[string]any{ + "bridge_url": c.url, + }) + + c.ctx, c.cancel = context.WithCancel(ctx) dialer := websocket.DefaultDialer dialer.HandshakeTimeout = 10 * time.Second @@ -46,6 +61,7 @@ func (c *WhatsAppChannel) Start(ctx context.Context) error { resp.Body.Close() } if err != nil { + c.cancel() return fmt.Errorf("failed to connect to WhatsApp bridge: %w", err) } @@ -54,39 +70,57 @@ func (c *WhatsAppChannel) Start(ctx context.Context) error { c.connected = true c.mu.Unlock() - c.setRunning(true) - log.Println("WhatsApp channel connected") + c.SetRunning(true) + logger.InfoC("whatsapp", "WhatsApp channel 
connected") - go c.listen(ctx) + go c.listen() return nil } func (c *WhatsAppChannel) Stop(ctx context.Context) error { - log.Println("Stopping WhatsApp channel...") + logger.InfoC("whatsapp", "Stopping WhatsApp channel...") + + // Cancel context first to signal listen goroutine to exit + if c.cancel != nil { + c.cancel() + } c.mu.Lock() defer c.mu.Unlock() if c.conn != nil { if err := c.conn.Close(); err != nil { - log.Printf("Error closing WhatsApp connection: %v", err) + logger.ErrorCF("whatsapp", "Error closing WhatsApp connection", map[string]any{ + "error": err.Error(), + }) } c.conn = nil } c.connected = false - c.setRunning(false) + c.SetRunning(false) return nil } func (c *WhatsAppChannel) Send(ctx context.Context, msg bus.OutboundMessage) error { + if !c.IsRunning() { + return channels.ErrNotRunning + } + + // Check ctx before acquiring lock + select { + case <-ctx.Done(): + return ctx.Err() + default: + } + c.mu.Lock() defer c.mu.Unlock() if c.conn == nil { - return fmt.Errorf("whatsapp connection not established") + return fmt.Errorf("whatsapp connection not established: %w", channels.ErrTemporary) } payload := map[string]any{ @@ -100,17 +134,20 @@ func (c *WhatsAppChannel) Send(ctx context.Context, msg bus.OutboundMessage) err return fmt.Errorf("failed to marshal message: %w", err) } + _ = c.conn.SetWriteDeadline(time.Now().Add(10 * time.Second)) if err := c.conn.WriteMessage(websocket.TextMessage, data); err != nil { - return fmt.Errorf("failed to send message: %w", err) + _ = c.conn.SetWriteDeadline(time.Time{}) + return fmt.Errorf("whatsapp send: %w", channels.ErrTemporary) } + _ = c.conn.SetWriteDeadline(time.Time{}) return nil } -func (c *WhatsAppChannel) listen(ctx context.Context) { +func (c *WhatsAppChannel) listen() { for { select { - case <-ctx.Done(): + case <-c.ctx.Done(): return default: c.mu.Lock() @@ -124,14 +161,18 @@ func (c *WhatsAppChannel) listen(ctx context.Context) { _, message, err := conn.ReadMessage() if err != nil { - 
log.Printf("WhatsApp read error: %v", err) + logger.ErrorCF("whatsapp", "WhatsApp read error", map[string]any{ + "error": err.Error(), + }) time.Sleep(2 * time.Second) continue } var msg map[string]any if err := json.Unmarshal(message, &msg); err != nil { - log.Printf("Failed to unmarshal WhatsApp message: %v", err) + logger.ErrorCF("whatsapp", "Failed to unmarshal WhatsApp message", map[string]any{ + "error": err.Error(), + }) continue } @@ -174,22 +215,38 @@ func (c *WhatsAppChannel) handleIncomingMessage(msg map[string]any) { } metadata := make(map[string]string) - if messageID, ok := msg["id"].(string); ok { - metadata["message_id"] = messageID + var messageID string + if mid, ok := msg["id"].(string); ok { + messageID = mid } if userName, ok := msg["from_name"].(string); ok { metadata["user_name"] = userName } + var peer bus.Peer if chatID == senderID { - metadata["peer_kind"] = "direct" - metadata["peer_id"] = senderID + peer = bus.Peer{Kind: "direct", ID: senderID} } else { - metadata["peer_kind"] = "group" - metadata["peer_id"] = chatID + peer = bus.Peer{Kind: "group", ID: chatID} } - log.Printf("WhatsApp message from %s: %s...", senderID, utils.Truncate(content, 50)) + logger.InfoCF("whatsapp", "WhatsApp message received", map[string]any{ + "sender": senderID, + "preview": utils.Truncate(content, 50), + }) - c.HandleMessage(senderID, chatID, content, mediaPaths, metadata) + sender := bus.SenderInfo{ + Platform: "whatsapp", + PlatformID: senderID, + CanonicalID: identity.BuildCanonicalID("whatsapp", senderID), + } + if display, ok := metadata["user_name"]; ok { + sender.DisplayName = display + } + + if !c.IsAllowedSender(sender) { + return + } + + c.HandleMessage(c.ctx, peer, messageID, senderID, chatID, content, mediaPaths, metadata, sender) } diff --git a/pkg/channels/whatsapp_native/init.go b/pkg/channels/whatsapp_native/init.go new file mode 100644 index 000000000..df13e8539 --- /dev/null +++ b/pkg/channels/whatsapp_native/init.go @@ -0,0 +1,20 @@ 
+package whatsapp + +import ( + "path/filepath" + + "github.com/sipeed/picoclaw/pkg/bus" + "github.com/sipeed/picoclaw/pkg/channels" + "github.com/sipeed/picoclaw/pkg/config" +) + +func init() { + channels.RegisterFactory("whatsapp_native", func(cfg *config.Config, b *bus.MessageBus) (channels.Channel, error) { + waCfg := cfg.Channels.WhatsApp + storePath := waCfg.SessionStorePath + if storePath == "" { + storePath = filepath.Join(cfg.WorkspacePath(), "whatsapp") + } + return NewWhatsAppNativeChannel(waCfg, b, storePath) + }) +} diff --git a/pkg/channels/whatsapp_native/whatsapp_native.go b/pkg/channels/whatsapp_native/whatsapp_native.go new file mode 100644 index 000000000..23115bda7 --- /dev/null +++ b/pkg/channels/whatsapp_native/whatsapp_native.go @@ -0,0 +1,341 @@ +//go:build whatsapp_native + +// PicoClaw - Ultra-lightweight personal AI agent +// License: MIT +// +// Copyright (c) 2026 PicoClaw contributors + +package whatsapp + +import ( + "context" + "database/sql" + "fmt" + "os" + "path/filepath" + "strings" + "sync" + "time" + + "github.com/mdp/qrterminal/v3" + "go.mau.fi/whatsmeow" + "go.mau.fi/whatsmeow/proto/waE2E" + "go.mau.fi/whatsmeow/store/sqlstore" + "go.mau.fi/whatsmeow/types" + "go.mau.fi/whatsmeow/types/events" + waLog "go.mau.fi/whatsmeow/util/log" + "google.golang.org/protobuf/proto" + _ "modernc.org/sqlite" + + "github.com/sipeed/picoclaw/pkg/bus" + "github.com/sipeed/picoclaw/pkg/channels" + "github.com/sipeed/picoclaw/pkg/config" + "github.com/sipeed/picoclaw/pkg/identity" + "github.com/sipeed/picoclaw/pkg/logger" + "github.com/sipeed/picoclaw/pkg/utils" +) + +const ( + sqliteDriver = "sqlite" + whatsappDBName = "store.db" + + reconnectInitial = 5 * time.Second + reconnectMax = 5 * time.Minute + reconnectMultiplier = 2.0 +) + +// WhatsAppNativeChannel implements the WhatsApp channel using whatsmeow (in-process, no external bridge). 
+type WhatsAppNativeChannel struct { + *channels.BaseChannel + config config.WhatsAppConfig + storePath string + client *whatsmeow.Client + container *sqlstore.Container + mu sync.Mutex + runCtx context.Context + runCancel context.CancelFunc + reconnectMu sync.Mutex + reconnecting bool +} + +// NewWhatsAppNativeChannel creates a WhatsApp channel that uses whatsmeow for connection. +// storePath is the directory for the SQLite session store (e.g. workspace/whatsapp). +func NewWhatsAppNativeChannel( + cfg config.WhatsAppConfig, + bus *bus.MessageBus, + storePath string, +) (channels.Channel, error) { + base := channels.NewBaseChannel("whatsapp_native", cfg, bus, cfg.AllowFrom, channels.WithMaxMessageLength(65536)) + if storePath == "" { + storePath = "whatsapp" + } + c := &WhatsAppNativeChannel{ + BaseChannel: base, + config: cfg, + storePath: storePath, + } + return c, nil +} + +func (c *WhatsAppNativeChannel) Start(ctx context.Context) error { + logger.InfoCF("whatsapp", "Starting WhatsApp native channel (whatsmeow)", map[string]any{"store": c.storePath}) + + if err := os.MkdirAll(c.storePath, 0o700); err != nil { + return fmt.Errorf("create session store dir: %w", err) + } + + dbPath := filepath.Join(c.storePath, whatsappDBName) + connStr := "file:" + dbPath + "?_foreign_keys=on" + + db, err := sql.Open(sqliteDriver, connStr) + if err != nil { + return fmt.Errorf("open whatsapp store: %w", err) + } + db.SetMaxOpenConns(1) + db.SetMaxIdleConns(1) + if _, err = db.ExecContext(ctx, "PRAGMA foreign_keys = ON"); err != nil { + _ = db.Close() + return fmt.Errorf("enable foreign keys: %w", err) + } + + waLogger := waLog.Stdout("WhatsApp", "WARN", true) + container := sqlstore.NewWithDB(db, sqliteDriver, waLogger) + if err = container.Upgrade(ctx); err != nil { + _ = db.Close() + return fmt.Errorf("open whatsapp store: %w", err) + } + + deviceStore, err := container.GetFirstDevice(ctx) + if err != nil { + _ = container.Close() + return fmt.Errorf("get device store: %w", 
err) + } + + client := whatsmeow.NewClient(deviceStore, waLogger) + client.AddEventHandler(c.eventHandler) + + c.mu.Lock() + c.container = container + c.client = client + c.mu.Unlock() + + if client.Store.ID == nil { + qrChan, err := client.GetQRChannel(ctx) + if err != nil { + _ = container.Close() + return fmt.Errorf("get QR channel: %w", err) + } + if err := client.Connect(); err != nil { + _ = container.Close() + return fmt.Errorf("connect: %w", err) + } + for evt := range qrChan { + if evt.Event == "code" { + logger.InfoCF("whatsapp", "Scan this QR code with WhatsApp (Linked Devices):", nil) + qrterminal.GenerateWithConfig(evt.Code, qrterminal.Config{ + Level: qrterminal.L, + Writer: os.Stdout, + HalfBlocks: true, + }) + } else { + logger.InfoCF("whatsapp", "WhatsApp login event", map[string]any{"event": evt.Event}) + } + } + } else { + if err := client.Connect(); err != nil { + _ = container.Close() + return fmt.Errorf("connect: %w", err) + } + } + + c.runCtx, c.runCancel = context.WithCancel(ctx) + c.SetRunning(true) + logger.InfoC("whatsapp", "WhatsApp native channel connected") + return nil +} + +func (c *WhatsAppNativeChannel) Stop(ctx context.Context) error { + logger.InfoC("whatsapp", "Stopping WhatsApp native channel") + if c.runCancel != nil { + c.runCancel() + } + c.mu.Lock() + client := c.client + container := c.container + c.client = nil + c.container = nil + c.mu.Unlock() + + if client != nil { + client.Disconnect() + } + if container != nil { + _ = container.Close() + } + c.SetRunning(false) + return nil +} + +func (c *WhatsAppNativeChannel) eventHandler(evt any) { + switch evt.(type) { + case *events.Message: + c.handleIncoming(evt.(*events.Message)) + case *events.Disconnected: + logger.InfoCF("whatsapp", "WhatsApp disconnected, will attempt reconnection", nil) + c.reconnectMu.Lock() + if c.reconnecting { + c.reconnectMu.Unlock() + return + } + c.reconnecting = true + c.reconnectMu.Unlock() + go c.reconnectWithBackoff() + } +} + +func (c 
*WhatsAppNativeChannel) reconnectWithBackoff() { + defer func() { + c.reconnectMu.Lock() + c.reconnecting = false + c.reconnectMu.Unlock() + }() + + backoff := reconnectInitial + for { + select { + case <-c.runCtx.Done(): + return + default: + } + + c.mu.Lock() + client := c.client + c.mu.Unlock() + if client == nil { + return + } + + logger.InfoCF("whatsapp", "WhatsApp reconnecting", map[string]any{"backoff": backoff.String()}) + err := client.Connect() + if err == nil { + logger.InfoC("whatsapp", "WhatsApp reconnected") + return + } + + logger.WarnCF("whatsapp", "WhatsApp reconnect failed", map[string]any{"error": err.Error()}) + + select { + case <-c.runCtx.Done(): + return + case <-time.After(backoff): + if backoff < reconnectMax { + next := time.Duration(float64(backoff) * reconnectMultiplier) + if next > reconnectMax { + next = reconnectMax + } + backoff = next + } + } + } +} + +func (c *WhatsAppNativeChannel) handleIncoming(evt *events.Message) { + if evt.Message == nil { + return + } + senderID := evt.Info.Sender.String() + chatID := evt.Info.Chat.String() + content := evt.Message.GetConversation() + if content == "" && evt.Message.ExtendedTextMessage != nil { + content = evt.Message.ExtendedTextMessage.GetText() + } + content = utils.SanitizeMessageContent(content) + + if content == "" { + return + } + + var mediaPaths []string + + metadata := make(map[string]string) + metadata["message_id"] = evt.Info.ID + if evt.Info.PushName != "" { + metadata["user_name"] = evt.Info.PushName + } + if evt.Info.Chat.Server == types.GroupServer { + metadata["peer_kind"] = "group" + metadata["peer_id"] = chatID + } else { + metadata["peer_kind"] = "direct" + metadata["peer_id"] = senderID + } + + peerKind := "direct" + if evt.Info.Chat.Server == types.GroupServer { + peerKind = "group" + } + peer := bus.Peer{Kind: peerKind, ID: chatID} + messageID := evt.Info.ID + sender := bus.SenderInfo{ + Platform: "whatsapp", + PlatformID: senderID, + CanonicalID: 
identity.BuildCanonicalID("whatsapp", senderID), + DisplayName: evt.Info.PushName, + } + + if !c.IsAllowedSender(sender) { + return + } + + logger.DebugCF( + "whatsapp", + "WhatsApp message received", + map[string]any{"sender_id": senderID, "content_preview": utils.Truncate(content, 50)}, + ) + c.HandleMessage(c.runCtx, peer, messageID, senderID, chatID, content, mediaPaths, metadata, sender) +} + +func (c *WhatsAppNativeChannel) Send(ctx context.Context, msg bus.OutboundMessage) error { + if !c.IsRunning() { + return channels.ErrNotRunning + } + select { + case <-ctx.Done(): + return ctx.Err() + default: + } + + c.mu.Lock() + client := c.client + c.mu.Unlock() + + if client == nil || !client.IsConnected() { + return fmt.Errorf("whatsapp connection not established: %w", channels.ErrTemporary) + } + + to, err := parseJID(msg.ChatID) + if err != nil { + return fmt.Errorf("invalid chat id %q: %w", msg.ChatID, err) + } + + waMsg := &waE2E.Message{ + Conversation: proto.String(msg.Content), + } + + if _, err = client.SendMessage(ctx, to, waMsg); err != nil { + return fmt.Errorf("whatsapp send: %w", channels.ErrTemporary) + } + return nil +} + +// parseJID converts a chat ID (phone number or JID string) to types.JID. 
+func parseJID(s string) (types.JID, error) { + s = strings.TrimSpace(s) + if s == "" { + return types.JID{}, fmt.Errorf("empty chat id") + } + if strings.Contains(s, "@") { + return types.ParseJID(s) + } + return types.NewJID(s, types.DefaultUserServer), nil +} diff --git a/pkg/channels/whatsapp_native/whatsapp_native_stub.go b/pkg/channels/whatsapp_native/whatsapp_native_stub.go new file mode 100644 index 000000000..984af23e7 --- /dev/null +++ b/pkg/channels/whatsapp_native/whatsapp_native_stub.go @@ -0,0 +1,21 @@ +//go:build !whatsapp_native + +package whatsapp + +import ( + "fmt" + + "github.com/sipeed/picoclaw/pkg/bus" + "github.com/sipeed/picoclaw/pkg/channels" + "github.com/sipeed/picoclaw/pkg/config" +) + +// NewWhatsAppNativeChannel returns an error when the binary was not built with -tags whatsapp_native. +// Build with: go build -tags whatsapp_native ./cmd/... +func NewWhatsAppNativeChannel( + cfg config.WhatsAppConfig, + bus *bus.MessageBus, + storePath string, +) (channels.Channel, error) { + return nil, fmt.Errorf("whatsapp native not compiled in; build with -tags whatsapp_native") +} diff --git a/pkg/config/config.go b/pkg/config/config.go index ca5803c35..d84772d2b 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -172,7 +172,7 @@ type AgentDefaults struct { RestrictToWorkspace bool `json:"restrict_to_workspace" env:"PICOCLAW_AGENTS_DEFAULTS_RESTRICT_TO_WORKSPACE"` Provider string `json:"provider" env:"PICOCLAW_AGENTS_DEFAULTS_PROVIDER"` ModelName string `json:"model_name,omitempty" env:"PICOCLAW_AGENTS_DEFAULTS_MODEL_NAME"` - Model string `json:"model,omitempty" env:"PICOCLAW_AGENTS_DEFAULTS_MODEL"` // Deprecated: use model_name instead + Model string `json:"model" env:"PICOCLAW_AGENTS_DEFAULTS_MODEL"` // Deprecated: use model_name instead ModelFallbacks []string `json:"model_fallbacks,omitempty"` ImageModel string `json:"image_model,omitempty" env:"PICOCLAW_AGENTS_DEFAULTS_IMAGE_MODEL"` ImageModelFallbacks []string 
`json:"image_model_fallbacks,omitempty"` @@ -203,108 +203,173 @@ type ChannelsConfig struct { OneBot OneBotConfig `json:"onebot"` WeCom WeComConfig `json:"wecom"` WeComApp WeComAppConfig `json:"wecom_app"` + Pico PicoConfig `json:"pico"` +} + +// GroupTriggerConfig controls when the bot responds in group chats. +type GroupTriggerConfig struct { + MentionOnly bool `json:"mention_only,omitempty"` + Prefixes []string `json:"prefixes,omitempty"` +} + +// TypingConfig controls typing indicator behavior (Phase 10). +type TypingConfig struct { + Enabled bool `json:"enabled,omitempty"` +} + +// PlaceholderConfig controls placeholder message behavior (Phase 10). +type PlaceholderConfig struct { + Enabled bool `json:"enabled,omitempty"` + Text string `json:"text,omitempty"` } type WhatsAppConfig struct { - Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_WHATSAPP_ENABLED"` - BridgeURL string `json:"bridge_url" env:"PICOCLAW_CHANNELS_WHATSAPP_BRIDGE_URL"` - AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_WHATSAPP_ALLOW_FROM"` + Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_WHATSAPP_ENABLED"` + BridgeURL string `json:"bridge_url" env:"PICOCLAW_CHANNELS_WHATSAPP_BRIDGE_URL"` + UseNative bool `json:"use_native" env:"PICOCLAW_CHANNELS_WHATSAPP_USE_NATIVE"` + SessionStorePath string `json:"session_store_path" env:"PICOCLAW_CHANNELS_WHATSAPP_SESSION_STORE_PATH"` + AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_WHATSAPP_ALLOW_FROM"` + ReasoningChannelID string `json:"reasoning_channel_id" env:"PICOCLAW_CHANNELS_WHATSAPP_REASONING_CHANNEL_ID"` } type TelegramConfig struct { - Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_TELEGRAM_ENABLED"` - Token string `json:"token" env:"PICOCLAW_CHANNELS_TELEGRAM_TOKEN"` - Proxy string `json:"proxy" env:"PICOCLAW_CHANNELS_TELEGRAM_PROXY"` - AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_TELEGRAM_ALLOW_FROM"` + Enabled bool `json:"enabled" 
env:"PICOCLAW_CHANNELS_TELEGRAM_ENABLED"` + Token string `json:"token" env:"PICOCLAW_CHANNELS_TELEGRAM_TOKEN"` + Proxy string `json:"proxy" env:"PICOCLAW_CHANNELS_TELEGRAM_PROXY"` + AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_TELEGRAM_ALLOW_FROM"` + GroupTrigger GroupTriggerConfig `json:"group_trigger,omitempty"` + Typing TypingConfig `json:"typing,omitempty"` + Placeholder PlaceholderConfig `json:"placeholder,omitempty"` + ReasoningChannelID string `json:"reasoning_channel_id" env:"PICOCLAW_CHANNELS_TELEGRAM_REASONING_CHANNEL_ID"` } type FeishuConfig struct { - Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_FEISHU_ENABLED"` - AppID string `json:"app_id" env:"PICOCLAW_CHANNELS_FEISHU_APP_ID"` - AppSecret string `json:"app_secret" env:"PICOCLAW_CHANNELS_FEISHU_APP_SECRET"` - EncryptKey string `json:"encrypt_key" env:"PICOCLAW_CHANNELS_FEISHU_ENCRYPT_KEY"` - VerificationToken string `json:"verification_token" env:"PICOCLAW_CHANNELS_FEISHU_VERIFICATION_TOKEN"` - AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_FEISHU_ALLOW_FROM"` + Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_FEISHU_ENABLED"` + AppID string `json:"app_id" env:"PICOCLAW_CHANNELS_FEISHU_APP_ID"` + AppSecret string `json:"app_secret" env:"PICOCLAW_CHANNELS_FEISHU_APP_SECRET"` + EncryptKey string `json:"encrypt_key" env:"PICOCLAW_CHANNELS_FEISHU_ENCRYPT_KEY"` + VerificationToken string `json:"verification_token" env:"PICOCLAW_CHANNELS_FEISHU_VERIFICATION_TOKEN"` + AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_FEISHU_ALLOW_FROM"` + GroupTrigger GroupTriggerConfig `json:"group_trigger,omitempty"` + ReasoningChannelID string `json:"reasoning_channel_id" env:"PICOCLAW_CHANNELS_FEISHU_REASONING_CHANNEL_ID"` } type DiscordConfig struct { - Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_DISCORD_ENABLED"` - Token string `json:"token" env:"PICOCLAW_CHANNELS_DISCORD_TOKEN"` - AllowFrom FlexibleStringSlice 
`json:"allow_from" env:"PICOCLAW_CHANNELS_DISCORD_ALLOW_FROM"` - MentionOnly bool `json:"mention_only" env:"PICOCLAW_CHANNELS_DISCORD_MENTION_ONLY"` + Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_DISCORD_ENABLED"` + Token string `json:"token" env:"PICOCLAW_CHANNELS_DISCORD_TOKEN"` + AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_DISCORD_ALLOW_FROM"` + MentionOnly bool `json:"mention_only" env:"PICOCLAW_CHANNELS_DISCORD_MENTION_ONLY"` + GroupTrigger GroupTriggerConfig `json:"group_trigger,omitempty"` + Typing TypingConfig `json:"typing,omitempty"` + Placeholder PlaceholderConfig `json:"placeholder,omitempty"` + ReasoningChannelID string `json:"reasoning_channel_id" env:"PICOCLAW_CHANNELS_DISCORD_REASONING_CHANNEL_ID"` } type MaixCamConfig struct { - Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_MAIXCAM_ENABLED"` - Host string `json:"host" env:"PICOCLAW_CHANNELS_MAIXCAM_HOST"` - Port int `json:"port" env:"PICOCLAW_CHANNELS_MAIXCAM_PORT"` - AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_MAIXCAM_ALLOW_FROM"` + Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_MAIXCAM_ENABLED"` + Host string `json:"host" env:"PICOCLAW_CHANNELS_MAIXCAM_HOST"` + Port int `json:"port" env:"PICOCLAW_CHANNELS_MAIXCAM_PORT"` + AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_MAIXCAM_ALLOW_FROM"` + ReasoningChannelID string `json:"reasoning_channel_id" env:"PICOCLAW_CHANNELS_MAIXCAM_REASONING_CHANNEL_ID"` } type QQConfig struct { - Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_QQ_ENABLED"` - AppID string `json:"app_id" env:"PICOCLAW_CHANNELS_QQ_APP_ID"` - AppSecret string `json:"app_secret" env:"PICOCLAW_CHANNELS_QQ_APP_SECRET"` - AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_QQ_ALLOW_FROM"` + Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_QQ_ENABLED"` + AppID string `json:"app_id" env:"PICOCLAW_CHANNELS_QQ_APP_ID"` + AppSecret string `json:"app_secret" 
env:"PICOCLAW_CHANNELS_QQ_APP_SECRET"` + AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_QQ_ALLOW_FROM"` + GroupTrigger GroupTriggerConfig `json:"group_trigger,omitempty"` + ReasoningChannelID string `json:"reasoning_channel_id" env:"PICOCLAW_CHANNELS_QQ_REASONING_CHANNEL_ID"` } type DingTalkConfig struct { - Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_DINGTALK_ENABLED"` - ClientID string `json:"client_id" env:"PICOCLAW_CHANNELS_DINGTALK_CLIENT_ID"` - ClientSecret string `json:"client_secret" env:"PICOCLAW_CHANNELS_DINGTALK_CLIENT_SECRET"` - AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_DINGTALK_ALLOW_FROM"` + Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_DINGTALK_ENABLED"` + ClientID string `json:"client_id" env:"PICOCLAW_CHANNELS_DINGTALK_CLIENT_ID"` + ClientSecret string `json:"client_secret" env:"PICOCLAW_CHANNELS_DINGTALK_CLIENT_SECRET"` + AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_DINGTALK_ALLOW_FROM"` + GroupTrigger GroupTriggerConfig `json:"group_trigger,omitempty"` + ReasoningChannelID string `json:"reasoning_channel_id" env:"PICOCLAW_CHANNELS_DINGTALK_REASONING_CHANNEL_ID"` } type SlackConfig struct { - Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_SLACK_ENABLED"` - BotToken string `json:"bot_token" env:"PICOCLAW_CHANNELS_SLACK_BOT_TOKEN"` - AppToken string `json:"app_token" env:"PICOCLAW_CHANNELS_SLACK_APP_TOKEN"` - AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_SLACK_ALLOW_FROM"` + Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_SLACK_ENABLED"` + BotToken string `json:"bot_token" env:"PICOCLAW_CHANNELS_SLACK_BOT_TOKEN"` + AppToken string `json:"app_token" env:"PICOCLAW_CHANNELS_SLACK_APP_TOKEN"` + AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_SLACK_ALLOW_FROM"` + GroupTrigger GroupTriggerConfig `json:"group_trigger,omitempty"` + Typing TypingConfig `json:"typing,omitempty"` + Placeholder 
PlaceholderConfig `json:"placeholder,omitempty"` + ReasoningChannelID string `json:"reasoning_channel_id" env:"PICOCLAW_CHANNELS_SLACK_REASONING_CHANNEL_ID"` } type LINEConfig struct { - Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_LINE_ENABLED"` - ChannelSecret string `json:"channel_secret" env:"PICOCLAW_CHANNELS_LINE_CHANNEL_SECRET"` - ChannelAccessToken string `json:"channel_access_token" env:"PICOCLAW_CHANNELS_LINE_CHANNEL_ACCESS_TOKEN"` - WebhookHost string `json:"webhook_host" env:"PICOCLAW_CHANNELS_LINE_WEBHOOK_HOST"` - WebhookPort int `json:"webhook_port" env:"PICOCLAW_CHANNELS_LINE_WEBHOOK_PORT"` - WebhookPath string `json:"webhook_path" env:"PICOCLAW_CHANNELS_LINE_WEBHOOK_PATH"` - AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_LINE_ALLOW_FROM"` + Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_LINE_ENABLED"` + ChannelSecret string `json:"channel_secret" env:"PICOCLAW_CHANNELS_LINE_CHANNEL_SECRET"` + ChannelAccessToken string `json:"channel_access_token" env:"PICOCLAW_CHANNELS_LINE_CHANNEL_ACCESS_TOKEN"` + WebhookHost string `json:"webhook_host" env:"PICOCLAW_CHANNELS_LINE_WEBHOOK_HOST"` + WebhookPort int `json:"webhook_port" env:"PICOCLAW_CHANNELS_LINE_WEBHOOK_PORT"` + WebhookPath string `json:"webhook_path" env:"PICOCLAW_CHANNELS_LINE_WEBHOOK_PATH"` + AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_LINE_ALLOW_FROM"` + GroupTrigger GroupTriggerConfig `json:"group_trigger,omitempty"` + Typing TypingConfig `json:"typing,omitempty"` + Placeholder PlaceholderConfig `json:"placeholder,omitempty"` + ReasoningChannelID string `json:"reasoning_channel_id" env:"PICOCLAW_CHANNELS_LINE_REASONING_CHANNEL_ID"` } type OneBotConfig struct { - Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_ONEBOT_ENABLED"` - WSUrl string `json:"ws_url" env:"PICOCLAW_CHANNELS_ONEBOT_WS_URL"` - AccessToken string `json:"access_token" env:"PICOCLAW_CHANNELS_ONEBOT_ACCESS_TOKEN"` - ReconnectInterval int 
`json:"reconnect_interval" env:"PICOCLAW_CHANNELS_ONEBOT_RECONNECT_INTERVAL"` - GroupTriggerPrefix []string `json:"group_trigger_prefix" env:"PICOCLAW_CHANNELS_ONEBOT_GROUP_TRIGGER_PREFIX"` - AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_ONEBOT_ALLOW_FROM"` + Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_ONEBOT_ENABLED"` + WSUrl string `json:"ws_url" env:"PICOCLAW_CHANNELS_ONEBOT_WS_URL"` + AccessToken string `json:"access_token" env:"PICOCLAW_CHANNELS_ONEBOT_ACCESS_TOKEN"` + ReconnectInterval int `json:"reconnect_interval" env:"PICOCLAW_CHANNELS_ONEBOT_RECONNECT_INTERVAL"` + GroupTriggerPrefix []string `json:"group_trigger_prefix" env:"PICOCLAW_CHANNELS_ONEBOT_GROUP_TRIGGER_PREFIX"` + AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_ONEBOT_ALLOW_FROM"` + GroupTrigger GroupTriggerConfig `json:"group_trigger,omitempty"` + Typing TypingConfig `json:"typing,omitempty"` + Placeholder PlaceholderConfig `json:"placeholder,omitempty"` + ReasoningChannelID string `json:"reasoning_channel_id" env:"PICOCLAW_CHANNELS_ONEBOT_REASONING_CHANNEL_ID"` } type WeComConfig struct { - Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_WECOM_ENABLED"` - Token string `json:"token" env:"PICOCLAW_CHANNELS_WECOM_TOKEN"` - EncodingAESKey string `json:"encoding_aes_key" env:"PICOCLAW_CHANNELS_WECOM_ENCODING_AES_KEY"` - WebhookURL string `json:"webhook_url" env:"PICOCLAW_CHANNELS_WECOM_WEBHOOK_URL"` - WebhookHost string `json:"webhook_host" env:"PICOCLAW_CHANNELS_WECOM_WEBHOOK_HOST"` - WebhookPort int `json:"webhook_port" env:"PICOCLAW_CHANNELS_WECOM_WEBHOOK_PORT"` - WebhookPath string `json:"webhook_path" env:"PICOCLAW_CHANNELS_WECOM_WEBHOOK_PATH"` - AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_WECOM_ALLOW_FROM"` - ReplyTimeout int `json:"reply_timeout" env:"PICOCLAW_CHANNELS_WECOM_REPLY_TIMEOUT"` + Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_WECOM_ENABLED"` + Token string `json:"token" 
env:"PICOCLAW_CHANNELS_WECOM_TOKEN"` + EncodingAESKey string `json:"encoding_aes_key" env:"PICOCLAW_CHANNELS_WECOM_ENCODING_AES_KEY"` + WebhookURL string `json:"webhook_url" env:"PICOCLAW_CHANNELS_WECOM_WEBHOOK_URL"` + WebhookHost string `json:"webhook_host" env:"PICOCLAW_CHANNELS_WECOM_WEBHOOK_HOST"` + WebhookPort int `json:"webhook_port" env:"PICOCLAW_CHANNELS_WECOM_WEBHOOK_PORT"` + WebhookPath string `json:"webhook_path" env:"PICOCLAW_CHANNELS_WECOM_WEBHOOK_PATH"` + AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_WECOM_ALLOW_FROM"` + ReplyTimeout int `json:"reply_timeout" env:"PICOCLAW_CHANNELS_WECOM_REPLY_TIMEOUT"` + GroupTrigger GroupTriggerConfig `json:"group_trigger,omitempty"` + ReasoningChannelID string `json:"reasoning_channel_id" env:"PICOCLAW_CHANNELS_WECOM_REASONING_CHANNEL_ID"` } type WeComAppConfig struct { - Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_WECOM_APP_ENABLED"` - CorpID string `json:"corp_id" env:"PICOCLAW_CHANNELS_WECOM_APP_CORP_ID"` - CorpSecret string `json:"corp_secret" env:"PICOCLAW_CHANNELS_WECOM_APP_CORP_SECRET"` - AgentID int64 `json:"agent_id" env:"PICOCLAW_CHANNELS_WECOM_APP_AGENT_ID"` - Token string `json:"token" env:"PICOCLAW_CHANNELS_WECOM_APP_TOKEN"` - EncodingAESKey string `json:"encoding_aes_key" env:"PICOCLAW_CHANNELS_WECOM_APP_ENCODING_AES_KEY"` - WebhookHost string `json:"webhook_host" env:"PICOCLAW_CHANNELS_WECOM_APP_WEBHOOK_HOST"` - WebhookPort int `json:"webhook_port" env:"PICOCLAW_CHANNELS_WECOM_APP_WEBHOOK_PORT"` - WebhookPath string `json:"webhook_path" env:"PICOCLAW_CHANNELS_WECOM_APP_WEBHOOK_PATH"` - AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_WECOM_APP_ALLOW_FROM"` - ReplyTimeout int `json:"reply_timeout" env:"PICOCLAW_CHANNELS_WECOM_APP_REPLY_TIMEOUT"` + Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_WECOM_APP_ENABLED"` + CorpID string `json:"corp_id" env:"PICOCLAW_CHANNELS_WECOM_APP_CORP_ID"` + CorpSecret string `json:"corp_secret" 
env:"PICOCLAW_CHANNELS_WECOM_APP_CORP_SECRET"` + AgentID int64 `json:"agent_id" env:"PICOCLAW_CHANNELS_WECOM_APP_AGENT_ID"` + Token string `json:"token" env:"PICOCLAW_CHANNELS_WECOM_APP_TOKEN"` + EncodingAESKey string `json:"encoding_aes_key" env:"PICOCLAW_CHANNELS_WECOM_APP_ENCODING_AES_KEY"` + WebhookHost string `json:"webhook_host" env:"PICOCLAW_CHANNELS_WECOM_APP_WEBHOOK_HOST"` + WebhookPort int `json:"webhook_port" env:"PICOCLAW_CHANNELS_WECOM_APP_WEBHOOK_PORT"` + WebhookPath string `json:"webhook_path" env:"PICOCLAW_CHANNELS_WECOM_APP_WEBHOOK_PATH"` + AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_WECOM_APP_ALLOW_FROM"` + ReplyTimeout int `json:"reply_timeout" env:"PICOCLAW_CHANNELS_WECOM_APP_REPLY_TIMEOUT"` + GroupTrigger GroupTriggerConfig `json:"group_trigger,omitempty"` + ReasoningChannelID string `json:"reasoning_channel_id" env:"PICOCLAW_CHANNELS_WECOM_APP_REASONING_CHANNEL_ID"` +} + +type PicoConfig struct { + Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_PICO_ENABLED"` + Token string `json:"token" env:"PICOCLAW_CHANNELS_PICO_TOKEN"` + AllowTokenQuery bool `json:"allow_token_query,omitempty"` + AllowOrigins []string `json:"allow_origins,omitempty"` + PingInterval int `json:"ping_interval,omitempty"` + ReadTimeout int `json:"read_timeout,omitempty"` + WriteTimeout int `json:"write_timeout,omitempty"` + MaxConnections int `json:"max_connections,omitempty"` + AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_PICO_ALLOW_FROM"` + Placeholder PlaceholderConfig `json:"placeholder,omitempty"` } type HeartbeatConfig struct { @@ -470,11 +535,18 @@ type ExecConfig struct { CustomDenyPatterns []string `json:"custom_deny_patterns" env:"PICOCLAW_TOOLS_EXEC_CUSTOM_DENY_PATTERNS"` } +type MediaCleanupConfig struct { + Enabled bool `json:"enabled" env:"PICOCLAW_MEDIA_CLEANUP_ENABLED"` + MaxAge int `json:"max_age_minutes" env:"PICOCLAW_MEDIA_CLEANUP_MAX_AGE"` + Interval int `json:"interval_minutes" 
env:"PICOCLAW_MEDIA_CLEANUP_INTERVAL"` +} + type ToolsConfig struct { - Web WebToolsConfig `json:"web"` - Cron CronToolsConfig `json:"cron"` - Exec ExecConfig `json:"exec"` - Skills SkillsToolsConfig `json:"skills"` + Web WebToolsConfig `json:"web"` + Cron CronToolsConfig `json:"cron"` + Exec ExecConfig `json:"exec"` + Skills SkillsToolsConfig `json:"skills"` + MediaCleanup MediaCleanupConfig `json:"media_cleanup"` } type SkillsToolsConfig struct { @@ -537,6 +609,9 @@ func LoadConfig(path string) (*Config, error) { return nil, err } + // Migrate legacy channel config fields to new unified structures + cfg.migrateChannelConfigs() + // Auto-migrate: if only legacy providers config exists, convert to model_list if len(cfg.ModelList) == 0 && cfg.HasProvidersConfig() { cfg.ModelList = ConvertProvidersToModelList(cfg) @@ -550,6 +625,18 @@ func LoadConfig(path string) (*Config, error) { return cfg, nil } +func (c *Config) migrateChannelConfigs() { + // Discord: mention_only -> group_trigger.mention_only + if c.Channels.Discord.MentionOnly && !c.Channels.Discord.GroupTrigger.MentionOnly { + c.Channels.Discord.GroupTrigger.MentionOnly = true + } + + // OneBot: group_trigger_prefix -> group_trigger.prefixes + if len(c.Channels.OneBot.GroupTriggerPrefix) > 0 && len(c.Channels.OneBot.GroupTrigger.Prefixes) == 0 { + c.Channels.OneBot.GroupTrigger.Prefixes = c.Channels.OneBot.GroupTriggerPrefix + } +} + func SaveConfig(path string, cfg *Config) error { data, err := json.MarshalIndent(cfg, "", " ") if err != nil { diff --git a/pkg/config/config_test.go b/pkg/config/config_test.go index bf56b7f34..12fd10b50 100644 --- a/pkg/config/config_test.go +++ b/pkg/config/config_test.go @@ -5,6 +5,7 @@ import ( "os" "path/filepath" "runtime" + "strings" "testing" ) @@ -210,8 +211,8 @@ func TestDefaultConfig_WorkspacePath(t *testing.T) { func TestDefaultConfig_Model(t *testing.T) { cfg := DefaultConfig() - if cfg.Agents.Defaults.Model == "" { - t.Error("Model should not be empty") + if 
cfg.Agents.Defaults.Model != "" { + t.Error("Model should be empty") } } @@ -324,6 +325,25 @@ func TestSaveConfig_FilePermissions(t *testing.T) { } } +func TestSaveConfig_IncludesEmptyLegacyModelField(t *testing.T) { + tmpDir := t.TempDir() + path := filepath.Join(tmpDir, "config.json") + + cfg := DefaultConfig() + if err := SaveConfig(path, cfg); err != nil { + t.Fatalf("SaveConfig failed: %v", err) + } + + data, err := os.ReadFile(path) + if err != nil { + t.Fatalf("ReadFile failed: %v", err) + } + + if !strings.Contains(string(data), `"model": ""`) { + t.Fatalf("saved config should include empty legacy model field, got: %s", string(data)) + } +} + // TestConfig_Complete verifies all config fields are set func TestConfig_Complete(t *testing.T) { cfg := DefaultConfig() @@ -331,8 +351,8 @@ func TestConfig_Complete(t *testing.T) { if cfg.Agents.Defaults.Workspace == "" { t.Error("Workspace should not be empty") } - if cfg.Agents.Defaults.Model == "" { - t.Error("Model should not be empty") + if cfg.Agents.Defaults.Model != "" { + t.Error("Model should be empty") } if cfg.Agents.Defaults.Temperature != nil { t.Error("Temperature should be nil when not provided") diff --git a/pkg/config/defaults.go b/pkg/config/defaults.go index cf799140d..ebb924859 100644 --- a/pkg/config/defaults.go +++ b/pkg/config/defaults.go @@ -13,10 +13,10 @@ func DefaultConfig() *Config { Workspace: "~/.picoclaw/workspace", RestrictToWorkspace: true, Provider: "", - Model: "glm-4.7", - MaxTokens: 8192, + Model: "", + MaxTokens: 32768, Temperature: nil, // nil means use provider default - MaxToolIterations: 20, + MaxToolIterations: 50, }, }, Bindings: []AgentBinding{}, @@ -25,14 +25,21 @@ func DefaultConfig() *Config { }, Channels: ChannelsConfig{ WhatsApp: WhatsAppConfig{ - Enabled: false, - BridgeURL: "ws://localhost:3001", - AllowFrom: FlexibleStringSlice{}, + Enabled: false, + BridgeURL: "ws://localhost:3001", + UseNative: false, + SessionStorePath: "", + AllowFrom: FlexibleStringSlice{}, 
}, Telegram: TelegramConfig{ Enabled: false, Token: "", AllowFrom: FlexibleStringSlice{}, + Typing: TypingConfig{Enabled: true}, + Placeholder: PlaceholderConfig{ + Enabled: true, + Text: "Thinking... 💭", + }, }, Feishu: FeishuConfig{ Enabled: false, @@ -80,6 +87,7 @@ func DefaultConfig() *Config { WebhookPort: 18791, WebhookPath: "/webhook/line", AllowFrom: FlexibleStringSlice{}, + GroupTrigger: GroupTriggerConfig{MentionOnly: true}, }, OneBot: OneBotConfig{ Enabled: false, @@ -113,6 +121,15 @@ func DefaultConfig() *Config { AllowFrom: FlexibleStringSlice{}, ReplyTimeout: 5, }, + Pico: PicoConfig{ + Enabled: false, + Token: "", + PingInterval: 30, + ReadTimeout: 60, + WriteTimeout: 10, + MaxConnections: 100, + AllowFrom: FlexibleStringSlice{}, + }, }, Providers: ProvidersConfig{ OpenAI: OpenAIProviderConfig{WebSearch: true}, @@ -276,6 +293,11 @@ func DefaultConfig() *Config { Port: 18790, }, Tools: ToolsConfig{ + MediaCleanup: MediaCleanupConfig{ + Enabled: true, + MaxAge: 30, + Interval: 5, + }, Web: WebToolsConfig{ Proxy: "", Brave: BraveConfig{ diff --git a/pkg/devices/service.go b/pkg/devices/service.go index 1541d3c57..1bafe6085 100644 --- a/pkg/devices/service.go +++ b/pkg/devices/service.go @@ -4,6 +4,7 @@ import ( "context" "strings" "sync" + "time" "github.com/sipeed/picoclaw/pkg/bus" "github.com/sipeed/picoclaw/pkg/constants" @@ -127,7 +128,9 @@ func (s *Service) sendNotification(ev *events.DeviceEvent) { } msg := ev.FormatMessage() - msgBus.PublishOutbound(bus.OutboundMessage{ + pubCtx, pubCancel := context.WithTimeout(context.Background(), 5*time.Second) + defer pubCancel() + msgBus.PublishOutbound(pubCtx, bus.OutboundMessage{ Channel: platform, ChatID: userID, Content: msg, diff --git a/pkg/health/server.go b/pkg/health/server.go index d1acfb662..5609ebdf6 100644 --- a/pkg/health/server.go +++ b/pkg/health/server.go @@ -155,6 +155,13 @@ func (s *Server) readyHandler(w http.ResponseWriter, r *http.Request) { }) } +// RegisterOnMux registers /health and 
/ready handlers onto the given mux. +// This allows the health endpoints to be served by a shared HTTP server. +func (s *Server) RegisterOnMux(mux *http.ServeMux) { + mux.HandleFunc("/health", s.healthHandler) + mux.HandleFunc("/ready", s.readyHandler) +} + func statusString(ok bool) string { if ok { return "ok" diff --git a/pkg/heartbeat/service.go b/pkg/heartbeat/service.go index 58462c120..09c93fc6b 100644 --- a/pkg/heartbeat/service.go +++ b/pkg/heartbeat/service.go @@ -7,6 +7,7 @@ package heartbeat import ( + "context" "fmt" "os" "path/filepath" @@ -308,7 +309,9 @@ func (hs *HeartbeatService) sendResponse(response string) { return } - msgBus.PublishOutbound(bus.OutboundMessage{ + pubCtx, pubCancel := context.WithTimeout(context.Background(), 5*time.Second) + defer pubCancel() + msgBus.PublishOutbound(pubCtx, bus.OutboundMessage{ Channel: platform, ChatID: userID, Content: response, diff --git a/pkg/identity/identity.go b/pkg/identity/identity.go new file mode 100644 index 000000000..6bc09c210 --- /dev/null +++ b/pkg/identity/identity.go @@ -0,0 +1,107 @@ +// Package identity provides unified user identity utilities for PicoClaw. +// It introduces a canonical "platform:id" format and matching logic +// that is backward-compatible with all legacy allow-list formats. +package identity + +import ( + "strings" + + "github.com/sipeed/picoclaw/pkg/bus" +) + +// BuildCanonicalID constructs a canonical "platform:id" identifier. +// Both values are trimmed; only platform is lowercased (platformID keeps its case). +func BuildCanonicalID(platform, platformID string) string { + p := strings.ToLower(strings.TrimSpace(platform)) + id := strings.TrimSpace(platformID) + if p == "" || id == "" { + return "" + } + return p + ":" + id +} + +// ParseCanonicalID splits a canonical ID ("platform:id") into its parts. +// Returns ok=false if there is no colon or either side of the first colon is empty.
+func ParseCanonicalID(canonical string) (platform, id string, ok bool) { + canonical = strings.TrimSpace(canonical) + idx := strings.Index(canonical, ":") + if idx <= 0 || idx == len(canonical)-1 { + return "", "", false + } + return canonical[:idx], canonical[idx+1:], true +} + +// MatchAllowed checks whether the given sender matches a single allow-list entry. +// It is backward-compatible with all legacy formats: +// +// - "123456" → matches sender.PlatformID +// - "@alice" → matches sender.Username +// - "123456|alice" → matches PlatformID or Username +// - "telegram:123456" → case-insensitive match on sender.CanonicalID +func MatchAllowed(sender bus.SenderInfo, allowed string) bool { + allowed = strings.TrimSpace(allowed) + if allowed == "" { + return false + } + + // Try canonical match first: "platform:id" format + if platform, id, ok := ParseCanonicalID(allowed); ok { + // Only treat as canonical if the platform portion looks like a known platform name + // (not a pure-numeric string, which could be a compound ID) + if !isNumeric(platform) { + candidate := BuildCanonicalID(platform, id) + if candidate != "" && sender.CanonicalID != "" { + return strings.EqualFold(sender.CanonicalID, candidate) + } + // If sender has no canonical ID, try matching platform + platformID + return strings.EqualFold(platform, sender.Platform) && + sender.PlatformID == id + } + } + + // Strip leading "@" for username matching + trimmed := strings.TrimPrefix(allowed, "@") + + // Split compound "id|username" format + allowedID := trimmed + allowedUser := "" + if idx := strings.Index(trimmed, "|"); idx > 0 { + allowedID = trimmed[:idx] + allowedUser = trimmed[idx+1:] + } + + // Match against PlatformID + if sender.PlatformID != "" && sender.PlatformID == allowedID { + return true + } + + // Match against Username + if sender.Username != "" { + if sender.Username == trimmed || sender.Username == allowedUser { + return true + } + } + + // Match compound sender format against allowed parts + if
allowedUser != "" && sender.PlatformID != "" && sender.PlatformID == allowedID { + return true + } + if allowedUser != "" && sender.Username != "" && sender.Username == allowedUser { + return true + } + + return false +} + +// isNumeric returns true if s consists entirely of digits. +func isNumeric(s string) bool { + if s == "" { + return false + } + for _, r := range s { + if r < '0' || r > '9' { + return false + } + } + return true +} diff --git a/pkg/identity/identity_test.go b/pkg/identity/identity_test.go new file mode 100644 index 000000000..3d24bd794 --- /dev/null +++ b/pkg/identity/identity_test.go @@ -0,0 +1,229 @@ +package identity + +import ( + "testing" + + "github.com/sipeed/picoclaw/pkg/bus" +) + +func TestBuildCanonicalID(t *testing.T) { + tests := []struct { + platform string + platformID string + want string + }{ + {"telegram", "123456", "telegram:123456"}, + {"Discord", "98765432", "discord:98765432"}, + {"SLACK", "U123ABC", "slack:U123ABC"}, + {"", "123", ""}, + {"telegram", "", ""}, + {" telegram ", " 123 ", "telegram:123"}, + } + + for _, tt := range tests { + got := BuildCanonicalID(tt.platform, tt.platformID) + if got != tt.want { + t.Errorf("BuildCanonicalID(%q, %q) = %q, want %q", + tt.platform, tt.platformID, got, tt.want) + } + } +} + +func TestParseCanonicalID(t *testing.T) { + tests := []struct { + input string + wantPlatform string + wantID string + wantOk bool + }{ + {"telegram:123456", "telegram", "123456", true}, + {"discord:98765432", "discord", "98765432", true}, + {"slack:U123ABC", "slack", "U123ABC", true}, + {"nocolon", "", "", false}, + {"", "", "", false}, + {":missing", "", "", false}, + {"missing:", "", "", false}, + } + + for _, tt := range tests { + platform, id, ok := ParseCanonicalID(tt.input) + if ok != tt.wantOk || platform != tt.wantPlatform || id != tt.wantID { + t.Errorf("ParseCanonicalID(%q) = (%q, %q, %v), want (%q, %q, %v)", + tt.input, platform, id, ok, + tt.wantPlatform, tt.wantID, tt.wantOk) + } + } +} + 
+func TestMatchAllowed(t *testing.T) { + telegramSender := bus.SenderInfo{ + Platform: "telegram", + PlatformID: "123456", + CanonicalID: "telegram:123456", + Username: "alice", + DisplayName: "Alice Smith", + } + + discordSender := bus.SenderInfo{ + Platform: "discord", + PlatformID: "98765432", + CanonicalID: "discord:98765432", + Username: "bob", + DisplayName: "bob#1234", + } + + noCanonicalSender := bus.SenderInfo{ + Platform: "telegram", + PlatformID: "999", + Username: "carol", + } + + tests := []struct { + name string + sender bus.SenderInfo + allowed string + want bool + }{ + // Pure numeric ID matching + { + name: "numeric ID matches PlatformID", + sender: telegramSender, + allowed: "123456", + want: true, + }, + { + name: "numeric ID does not match", + sender: telegramSender, + allowed: "654321", + want: false, + }, + // Username matching + { + name: "@username matches Username", + sender: telegramSender, + allowed: "@alice", + want: true, + }, + { + name: "@username does not match", + sender: telegramSender, + allowed: "@bob", + want: false, + }, + // Compound format "id|username" + { + name: "compound matches by ID", + sender: telegramSender, + allowed: "123456|alice", + want: true, + }, + { + name: "compound matches by username", + sender: telegramSender, + allowed: "999|alice", + want: true, + }, + { + name: "compound does not match", + sender: telegramSender, + allowed: "654321|bob", + want: false, + }, + // Canonical format "platform:id" + { + name: "canonical matches exactly", + sender: telegramSender, + allowed: "telegram:123456", + want: true, + }, + { + name: "canonical case-insensitive platform", + sender: telegramSender, + allowed: "Telegram:123456", + want: true, + }, + { + name: "canonical wrong platform", + sender: telegramSender, + allowed: "discord:123456", + want: false, + }, + { + name: "canonical wrong ID", + sender: telegramSender, + allowed: "telegram:654321", + want: false, + }, + // Cross-platform canonical + { + name: "discord 
canonical match", + sender: discordSender, + allowed: "discord:98765432", + want: true, + }, + { + name: "telegram canonical does not match discord sender", + sender: discordSender, + allowed: "telegram:98765432", + want: false, + }, + // Sender without canonical ID + { + name: "canonical match falls back to platform+platformID", + sender: noCanonicalSender, + allowed: "telegram:999", + want: true, + }, + { + name: "platform mismatch on fallback", + sender: noCanonicalSender, + allowed: "discord:999", + want: false, + }, + // Empty allowed string + { + name: "empty allowed never matches", + sender: telegramSender, + allowed: "", + want: false, + }, + // Whitespace handling + { + name: "trimmed allowed matches", + sender: telegramSender, + allowed: " 123456 ", + want: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := MatchAllowed(tt.sender, tt.allowed) + if got != tt.want { + t.Errorf("MatchAllowed(%+v, %q) = %v, want %v", + tt.sender, tt.allowed, got, tt.want) + } + }) + } +} + +func TestIsNumeric(t *testing.T) { + tests := []struct { + input string + want bool + }{ + {"123456", true}, + {"0", true}, + {"", false}, + {"abc", false}, + {"12a34", false}, + {"telegram", false}, + } + + for _, tt := range tests { + got := isNumeric(tt.input) + if got != tt.want { + t.Errorf("isNumeric(%q) = %v, want %v", tt.input, got, tt.want) + } + } +} diff --git a/pkg/media/store.go b/pkg/media/store.go new file mode 100644 index 000000000..30220986c --- /dev/null +++ b/pkg/media/store.go @@ -0,0 +1,271 @@ +package media + +import ( + "fmt" + "os" + "sync" + "time" + + "github.com/google/uuid" + + "github.com/sipeed/picoclaw/pkg/logger" +) + +// MediaMeta holds metadata about a stored media file. +type MediaMeta struct { + Filename string + ContentType string + Source string // "telegram", "discord", "tool:image-gen", etc. +} + +// MediaStore manages the lifecycle of media files associated with processing scopes. 
+type MediaStore interface { + // Store registers an existing local file under the given scope. + // Returns a ref identifier (e.g. "media://<id>"). + // Store does not move or copy the file; it only records the mapping. + Store(localPath string, meta MediaMeta, scope string) (ref string, err error) + + // Resolve returns the local file path for a given ref. + Resolve(ref string) (localPath string, err error) + + // ResolveWithMeta returns the local file path and metadata for a given ref. + ResolveWithMeta(ref string) (localPath string, meta MediaMeta, err error) + + // ReleaseAll deletes all files registered under the given scope + // and removes the mapping entries. File-not-exist errors are ignored. + ReleaseAll(scope string) error +} + +// mediaEntry holds the path, metadata, and registration time for a stored media file. +type mediaEntry struct { + path string + meta MediaMeta + storedAt time.Time +} + +// MediaCleanerConfig configures the background TTL cleanup. +type MediaCleanerConfig struct { + Enabled bool + MaxAge time.Duration + Interval time.Duration +} + +// FileMediaStore is a pure in-memory implementation of MediaStore. +// Files are expected to already exist on disk (e.g. in /tmp/picoclaw_media/). +type FileMediaStore struct { + mu sync.RWMutex + refs map[string]mediaEntry + scopeToRefs map[string]map[string]struct{} + refToScope map[string]string + + cleanerCfg MediaCleanerConfig + stop chan struct{} + startOnce sync.Once + stopOnce sync.Once + nowFunc func() time.Time // for testing +} + +// NewFileMediaStore creates a new FileMediaStore without background cleanup. +func NewFileMediaStore() *FileMediaStore { + return &FileMediaStore{ + refs: make(map[string]mediaEntry), + scopeToRefs: make(map[string]map[string]struct{}), + refToScope: make(map[string]string), + nowFunc: time.Now, + } +} + +// NewFileMediaStoreWithCleanup creates a FileMediaStore with TTL-based background cleanup.
+func NewFileMediaStoreWithCleanup(cfg MediaCleanerConfig) *FileMediaStore { + return &FileMediaStore{ + refs: make(map[string]mediaEntry), + scopeToRefs: make(map[string]map[string]struct{}), + refToScope: make(map[string]string), + cleanerCfg: cfg, + stop: make(chan struct{}), + nowFunc: time.Now, + } +} + +// Store registers a local file under the given scope. The file must exist. +func (s *FileMediaStore) Store(localPath string, meta MediaMeta, scope string) (string, error) { + if _, err := os.Stat(localPath); err != nil { + return "", fmt.Errorf("media store: %s: %w", localPath, err) + } + + ref := "media://" + uuid.New().String() + + s.mu.Lock() + defer s.mu.Unlock() + + s.refs[ref] = mediaEntry{path: localPath, meta: meta, storedAt: s.nowFunc()} + if s.scopeToRefs[scope] == nil { + s.scopeToRefs[scope] = make(map[string]struct{}) + } + s.scopeToRefs[scope][ref] = struct{}{} + s.refToScope[ref] = scope + + return ref, nil +} + +// Resolve returns the local path for the given ref. +func (s *FileMediaStore) Resolve(ref string) (string, error) { + s.mu.RLock() + defer s.mu.RUnlock() + + entry, ok := s.refs[ref] + if !ok { + return "", fmt.Errorf("media store: unknown ref: %s", ref) + } + return entry.path, nil +} + +// ResolveWithMeta returns the local path and metadata for the given ref. +func (s *FileMediaStore) ResolveWithMeta(ref string) (string, MediaMeta, error) { + s.mu.RLock() + defer s.mu.RUnlock() + + entry, ok := s.refs[ref] + if !ok { + return "", MediaMeta{}, fmt.Errorf("media store: unknown ref: %s", ref) + } + return entry.path, entry.meta, nil +} + +// ReleaseAll removes all files under the given scope and cleans up mappings. +// Phase 1 (under lock): remove entries from maps. +// Phase 2 (no lock): delete files from disk. 
+func (s *FileMediaStore) ReleaseAll(scope string) error { + // Phase 1: collect paths and remove from maps under lock + var paths []string + + s.mu.Lock() + refs, ok := s.scopeToRefs[scope] + if !ok { + s.mu.Unlock() + return nil + } + + for ref := range refs { + if entry, exists := s.refs[ref]; exists { + paths = append(paths, entry.path) + } + delete(s.refs, ref) + delete(s.refToScope, ref) + } + delete(s.scopeToRefs, scope) + s.mu.Unlock() + + // Phase 2: delete files without holding the lock + for _, p := range paths { + if err := os.Remove(p); err != nil && !os.IsNotExist(err) { + logger.WarnCF("media", "release: failed to remove file", map[string]any{ + "path": p, + "error": err.Error(), + }) + } + } + + return nil +} + +// CleanExpired removes all entries older than MaxAge. +// Phase 1 (under lock): identify expired entries and remove from maps. +// Phase 2 (no lock): delete files from disk to minimize lock contention. +func (s *FileMediaStore) CleanExpired() int { + if s.cleanerCfg.MaxAge <= 0 { + return 0 + } + + // Phase 1: collect expired entries under lock + type expiredEntry struct { + ref string + path string + } + + s.mu.Lock() + cutoff := s.nowFunc().Add(-s.cleanerCfg.MaxAge) + var expired []expiredEntry + + for ref, entry := range s.refs { + if entry.storedAt.Before(cutoff) { + expired = append(expired, expiredEntry{ref: ref, path: entry.path}) + + if scope, ok := s.refToScope[ref]; ok { + if scopeRefs, ok := s.scopeToRefs[scope]; ok { + delete(scopeRefs, ref) + if len(scopeRefs) == 0 { + delete(s.scopeToRefs, scope) + } + } + } + + delete(s.refs, ref) + delete(s.refToScope, ref) + } + } + s.mu.Unlock() + + // Phase 2: delete files without holding the lock + for _, e := range expired { + if err := os.Remove(e.path); err != nil && !os.IsNotExist(err) { + logger.WarnCF("media", "cleanup: failed to remove file", map[string]any{ + "path": e.path, + "error": err.Error(), + }) + } + } + + return len(expired) +} + +// Start begins the background cleanup 
goroutine if cleanup is enabled. +// Safe to call multiple times; only the first call starts the goroutine. +func (s *FileMediaStore) Start() { + if !s.cleanerCfg.Enabled || s.stop == nil { + return + } + if s.cleanerCfg.Interval <= 0 || s.cleanerCfg.MaxAge <= 0 { + logger.WarnCF("media", "cleanup: skipped due to invalid config", map[string]any{ + "interval": s.cleanerCfg.Interval.String(), + "max_age": s.cleanerCfg.MaxAge.String(), + }) + return + } + + s.startOnce.Do(func() { + logger.InfoCF("media", "cleanup enabled", map[string]any{ + "interval": s.cleanerCfg.Interval.String(), + "max_age": s.cleanerCfg.MaxAge.String(), + }) + + go func() { + ticker := time.NewTicker(s.cleanerCfg.Interval) + defer ticker.Stop() + + for { + select { + case <-ticker.C: + if n := s.CleanExpired(); n > 0 { + logger.InfoCF("media", "cleanup: removed expired entries", map[string]any{ + "count": n, + }) + } + case <-s.stop: + return + } + } + }() + }) +} + +// Stop terminates the background cleanup goroutine. +// Safe to call multiple times; only the first call closes the channel. 
+func (s *FileMediaStore) Stop() { + if s.stop == nil { + return + } + s.stopOnce.Do(func() { + close(s.stop) + }) +} diff --git a/pkg/media/store_test.go b/pkg/media/store_test.go new file mode 100644 index 000000000..1dcfdf350 --- /dev/null +++ b/pkg/media/store_test.go @@ -0,0 +1,530 @@ +package media + +import ( + "fmt" + "os" + "path/filepath" + "strings" + "sync" + "testing" + "time" +) + +func createTempFile(t *testing.T, dir, name string) string { + t.Helper() + path := filepath.Join(dir, name) + if err := os.WriteFile(path, []byte("test content"), 0o644); err != nil { + t.Fatalf("failed to create temp file: %v", err) + } + return path +} + +func TestStoreAndResolve(t *testing.T) { + dir := t.TempDir() + store := NewFileMediaStore() + + path := createTempFile(t, dir, "photo.jpg") + + ref, err := store.Store(path, MediaMeta{Filename: "photo.jpg", Source: "telegram"}, "scope1") + if err != nil { + t.Fatalf("Store failed: %v", err) + } + + if !strings.HasPrefix(ref, "media://") { + t.Errorf("ref should start with media://, got %q", ref) + } + + resolved, err := store.Resolve(ref) + if err != nil { + t.Fatalf("Resolve failed: %v", err) + } + if resolved != path { + t.Errorf("Resolve returned %q, want %q", resolved, path) + } +} + +func TestReleaseAll(t *testing.T) { + dir := t.TempDir() + store := NewFileMediaStore() + + paths := make([]string, 3) + refs := make([]string, 3) + for i := range 3 { + paths[i] = createTempFile(t, dir, strings.Repeat("a", i+1)+".jpg") + var err error + refs[i], err = store.Store(paths[i], MediaMeta{Source: "test"}, "scope1") + if err != nil { + t.Fatalf("Store failed: %v", err) + } + } + + if err := store.ReleaseAll("scope1"); err != nil { + t.Fatalf("ReleaseAll failed: %v", err) + } + + // Files should be deleted + for _, p := range paths { + if _, err := os.Stat(p); !os.IsNotExist(err) { + t.Errorf("file %q should have been deleted", p) + } + } + + // Refs should be unresolvable + for _, ref := range refs { + if _, err := 
store.Resolve(ref); err == nil { + t.Errorf("Resolve(%q) should fail after ReleaseAll", ref) + } + } +} + +func TestMultiScopeIsolation(t *testing.T) { + dir := t.TempDir() + store := NewFileMediaStore() + + pathA := createTempFile(t, dir, "fileA.jpg") + pathB := createTempFile(t, dir, "fileB.jpg") + + refA, _ := store.Store(pathA, MediaMeta{Source: "test"}, "scopeA") + refB, _ := store.Store(pathB, MediaMeta{Source: "test"}, "scopeB") + + // Release only scopeA + if err := store.ReleaseAll("scopeA"); err != nil { + t.Fatalf("ReleaseAll(scopeA) failed: %v", err) + } + + // scopeA file should be gone + if _, err := os.Stat(pathA); !os.IsNotExist(err) { + t.Error("file A should have been deleted") + } + if _, err := store.Resolve(refA); err == nil { + t.Error("refA should be unresolvable after release") + } + + // scopeB file should still exist + if _, err := os.Stat(pathB); err != nil { + t.Error("file B should still exist") + } + resolved, err := store.Resolve(refB) + if err != nil { + t.Fatalf("refB should still resolve: %v", err) + } + if resolved != pathB { + t.Errorf("resolved %q, want %q", resolved, pathB) + } +} + +func TestReleaseAllIdempotent(t *testing.T) { + store := NewFileMediaStore() + + // ReleaseAll on non-existent scope should not error + if err := store.ReleaseAll("nonexistent"); err != nil { + t.Fatalf("ReleaseAll on empty scope should not error: %v", err) + } + + // Create and release, then release again + dir := t.TempDir() + path := createTempFile(t, dir, "file.jpg") + _, _ = store.Store(path, MediaMeta{Source: "test"}, "scope1") + + if err := store.ReleaseAll("scope1"); err != nil { + t.Fatalf("first ReleaseAll failed: %v", err) + } + if err := store.ReleaseAll("scope1"); err != nil { + t.Fatalf("second ReleaseAll should not error: %v", err) + } +} + +func TestReleaseAllCleansMappingsIfRefsMissing(t *testing.T) { + dir := t.TempDir() + store := NewFileMediaStore() + + path := createTempFile(t, dir, "file.jpg") + ref, err := store.Store(path, 
MediaMeta{Source: "test"}, "scope1") + if err != nil { + t.Fatalf("Store failed: %v", err) + } + + // Simulate internal inconsistency: scopeToRefs/refToScope contains ref but refs map doesn't. + store.mu.Lock() + delete(store.refs, ref) + store.mu.Unlock() + + if err := store.ReleaseAll("scope1"); err != nil { + t.Fatalf("ReleaseAll failed: %v", err) + } + + // ReleaseAll should still clean mappings (even if it can't delete the file without the path). + store.mu.RLock() + defer store.mu.RUnlock() + if _, ok := store.refToScope[ref]; ok { + t.Error("refToScope should not contain ref after ReleaseAll") + } + if _, ok := store.scopeToRefs["scope1"]; ok { + t.Error("scopeToRefs should not contain scope1 after ReleaseAll") + } +} + +func TestStoreNonexistentFile(t *testing.T) { + store := NewFileMediaStore() + + _, err := store.Store("/nonexistent/path/file.jpg", MediaMeta{Source: "test"}, "scope1") + if err == nil { + t.Fatal("Store should fail for nonexistent file") + } + // Error message should include the underlying os error, not just "file does not exist" + if !strings.Contains(err.Error(), "no such file or directory") && + !strings.Contains(err.Error(), "cannot find") { + t.Errorf("Error should contain OS error detail, got: %v", err) + } +} + +func TestResolveWithMeta(t *testing.T) { + dir := t.TempDir() + store := NewFileMediaStore() + + path := createTempFile(t, dir, "image.png") + meta := MediaMeta{ + Filename: "image.png", + ContentType: "image/png", + Source: "telegram", + } + + ref, err := store.Store(path, meta, "scope1") + if err != nil { + t.Fatalf("Store failed: %v", err) + } + + resolvedPath, resolvedMeta, err := store.ResolveWithMeta(ref) + if err != nil { + t.Fatalf("ResolveWithMeta failed: %v", err) + } + if resolvedPath != path { + t.Errorf("ResolveWithMeta path = %q, want %q", resolvedPath, path) + } + if resolvedMeta.Filename != meta.Filename { + t.Errorf("ResolveWithMeta Filename = %q, want %q", resolvedMeta.Filename, meta.Filename) + } + if 
resolvedMeta.ContentType != meta.ContentType { + t.Errorf("ResolveWithMeta ContentType = %q, want %q", resolvedMeta.ContentType, meta.ContentType) + } + if resolvedMeta.Source != meta.Source { + t.Errorf("ResolveWithMeta Source = %q, want %q", resolvedMeta.Source, meta.Source) + } + + // Unknown ref should fail + _, _, err = store.ResolveWithMeta("media://nonexistent") + if err == nil { + t.Error("ResolveWithMeta should fail for unknown ref") + } +} + +func TestConcurrentSafety(t *testing.T) { + dir := t.TempDir() + store := NewFileMediaStore() + + const goroutines = 20 + const filesPerGoroutine = 5 + + var wg sync.WaitGroup + wg.Add(goroutines) + + for g := range goroutines { + go func(gIdx int) { + defer wg.Done() + scope := strings.Repeat("s", gIdx+1) + + for i := range filesPerGoroutine { + path := createTempFile(t, dir, strings.Repeat("f", gIdx*filesPerGoroutine+i+1)+".tmp") + ref, err := store.Store(path, MediaMeta{Source: "test"}, scope) + if err != nil { + t.Errorf("Store failed: %v", err) + return + } + + if _, err := store.Resolve(ref); err != nil { + t.Errorf("Resolve failed: %v", err) + } + } + + if err := store.ReleaseAll(scope); err != nil { + t.Errorf("ReleaseAll failed: %v", err) + } + }(g) + } + + wg.Wait() +} + +// --- TTL cleanup tests --- + +func newTestStoreWithCleanup(maxAge time.Duration) *FileMediaStore { + s := NewFileMediaStoreWithCleanup(MediaCleanerConfig{ + Enabled: true, + MaxAge: maxAge, + Interval: time.Hour, // won't tick in tests + }) + return s +} + +func TestCleanExpiredRemovesOldEntries(t *testing.T) { + dir := t.TempDir() + now := time.Now() + store := newTestStoreWithCleanup(10 * time.Minute) + store.nowFunc = func() time.Time { return now.Add(-20 * time.Minute) } + + path := createTempFile(t, dir, "old.jpg") + ref, err := store.Store(path, MediaMeta{Source: "test"}, "scope1") + if err != nil { + t.Fatalf("Store failed: %v", err) + } + + // Advance clock to present + store.nowFunc = func() time.Time { return now } + removed := 
store.CleanExpired() + + if removed != 1 { + t.Errorf("expected 1 removed, got %d", removed) + } + if _, err := store.Resolve(ref); err == nil { + t.Error("expired ref should be unresolvable") + } + if _, err := os.Stat(path); !os.IsNotExist(err) { + t.Error("expired file should be deleted") + } +} + +func TestCleanExpiredKeepsNonExpired(t *testing.T) { + dir := t.TempDir() + now := time.Now() + store := newTestStoreWithCleanup(10 * time.Minute) + store.nowFunc = func() time.Time { return now } + + path := createTempFile(t, dir, "fresh.jpg") + ref, err := store.Store(path, MediaMeta{Source: "test"}, "scope1") + if err != nil { + t.Fatalf("Store failed: %v", err) + } + + removed := store.CleanExpired() + if removed != 0 { + t.Errorf("expected 0 removed, got %d", removed) + } + + if _, err := store.Resolve(ref); err != nil { + t.Errorf("fresh ref should still resolve: %v", err) + } + if _, err := os.Stat(path); err != nil { + t.Error("fresh file should still exist") + } +} + +func TestCleanExpiredMixedAges(t *testing.T) { + dir := t.TempDir() + now := time.Now() + store := newTestStoreWithCleanup(10 * time.Minute) + + // Store old entry + store.nowFunc = func() time.Time { return now.Add(-20 * time.Minute) } + oldPath := createTempFile(t, dir, "old.jpg") + oldRef, _ := store.Store(oldPath, MediaMeta{Source: "test"}, "scope1") + + // Store fresh entry + store.nowFunc = func() time.Time { return now } + freshPath := createTempFile(t, dir, "fresh.jpg") + freshRef, _ := store.Store(freshPath, MediaMeta{Source: "test"}, "scope1") + + removed := store.CleanExpired() + if removed != 1 { + t.Errorf("expected 1 removed, got %d", removed) + } + + if _, err := store.Resolve(oldRef); err == nil { + t.Error("old ref should be gone") + } + if _, err := store.Resolve(freshRef); err != nil { + t.Errorf("fresh ref should still resolve: %v", err) + } +} + +func TestCleanExpiredCleansEmptyScopes(t *testing.T) { + dir := t.TempDir() + now := time.Now() + store := 
newTestStoreWithCleanup(10 * time.Minute) + + // Store old entry as the only one in scope + store.nowFunc = func() time.Time { return now.Add(-20 * time.Minute) } + path := createTempFile(t, dir, "only.jpg") + store.Store(path, MediaMeta{Source: "test"}, "lonely_scope") + + store.nowFunc = func() time.Time { return now } + store.CleanExpired() + + store.mu.RLock() + defer store.mu.RUnlock() + if _, ok := store.scopeToRefs["lonely_scope"]; ok { + t.Error("empty scope should be cleaned up") + } +} + +func TestStartStopLifecycle(t *testing.T) { + store := NewFileMediaStoreWithCleanup(MediaCleanerConfig{ + Enabled: true, + MaxAge: time.Minute, + Interval: 50 * time.Millisecond, + }) + + // Start and stop should not panic + store.Start() + // Double start should not spawn a second goroutine + store.Start() + time.Sleep(100 * time.Millisecond) + store.Stop() + + // Double stop should not panic + store.Stop() +} + +func TestCleanExpiredZeroMaxAge(t *testing.T) { + store := NewFileMediaStoreWithCleanup(MediaCleanerConfig{ + Enabled: true, + MaxAge: 0, + Interval: time.Hour, + }) + + dir := t.TempDir() + path := createTempFile(t, dir, "file.jpg") + ref, _ := store.Store(path, MediaMeta{Source: "test"}, "scope1") + + // Zero MaxAge should be a no-op + removed := store.CleanExpired() + if removed != 0 { + t.Errorf("expected 0 removed with zero MaxAge, got %d", removed) + } + if _, err := store.Resolve(ref); err != nil { + t.Errorf("ref should still resolve: %v", err) + } +} + +func TestStartDisabledIsNoop(t *testing.T) { + store := NewFileMediaStoreWithCleanup(MediaCleanerConfig{ + Enabled: false, + MaxAge: time.Minute, + Interval: time.Minute, + }) + // Should not start any goroutine or panic + store.Start() + store.Stop() +} + +func TestStartZeroIntervalNoPanic(t *testing.T) { + store := NewFileMediaStoreWithCleanup(MediaCleanerConfig{ + Enabled: true, + MaxAge: time.Minute, + Interval: 0, + }) + // Zero interval should not panic (time.NewTicker panics on <= 0) + 
store.Start() + store.Stop() +} + +func TestStartZeroMaxAgeNoPanic(t *testing.T) { + store := NewFileMediaStoreWithCleanup(MediaCleanerConfig{ + Enabled: true, + MaxAge: 0, + Interval: time.Minute, + }) + store.Start() + store.Stop() +} + +func TestConcurrentCleanupSafety(t *testing.T) { + dir := t.TempDir() + store := newTestStoreWithCleanup(50 * time.Millisecond) + store.nowFunc = time.Now + + const workers = 10 + const ops = 20 + var wg sync.WaitGroup + wg.Add(workers * 4) + + // Store workers + for w := range workers { + go func(wIdx int) { + defer wg.Done() + scope := fmt.Sprintf("scope-%d", wIdx) + for i := range ops { + p := createTempFile(t, dir, fmt.Sprintf("w%d-f%d.tmp", wIdx, i)) + store.Store(p, MediaMeta{Source: "test"}, scope) + } + }(w) + } + + // Resolve workers + for range workers { + go func() { + defer wg.Done() + for range ops { + store.Resolve("media://nonexistent") + } + }() + } + + // ReleaseAll workers + for w := range workers { + go func(wIdx int) { + defer wg.Done() + for range ops { + store.ReleaseAll(fmt.Sprintf("scope-%d", wIdx)) + } + }(w) + } + + // CleanExpired workers + for range workers { + go func() { + defer wg.Done() + for range ops { + store.CleanExpired() + } + }() + } + + wg.Wait() +} + +func TestRefToScopeConsistency(t *testing.T) { + dir := t.TempDir() + store := NewFileMediaStore() + + // Store entries in two scopes + ref1, _ := store.Store(createTempFile(t, dir, "a.jpg"), MediaMeta{Source: "test"}, "s1") + ref2, _ := store.Store(createTempFile(t, dir, "b.jpg"), MediaMeta{Source: "test"}, "s1") + ref3, _ := store.Store(createTempFile(t, dir, "c.jpg"), MediaMeta{Source: "test"}, "s2") + + store.mu.RLock() + checkRef := func(ref, expectedScope string) { + t.Helper() + if scope, ok := store.refToScope[ref]; !ok || scope != expectedScope { + t.Errorf("refToScope[%s] = %q, want %q", ref, scope, expectedScope) + } + } + checkRef(ref1, "s1") + checkRef(ref2, "s1") + checkRef(ref3, "s2") + store.mu.RUnlock() + + // Release s1 and 
verify refToScope is cleaned + store.ReleaseAll("s1") + + store.mu.RLock() + defer store.mu.RUnlock() + if _, ok := store.refToScope[ref1]; ok { + t.Error("refToScope should not contain ref1 after ReleaseAll") + } + if _, ok := store.refToScope[ref2]; ok { + t.Error("refToScope should not contain ref2 after ReleaseAll") + } + if _, ok := store.refToScope[ref3]; !ok { + t.Error("refToScope should still contain ref3") + } +} diff --git a/pkg/migrate/config.go b/pkg/migrate/config.go index 869b39827..ea91565e8 100644 --- a/pkg/migrate/config.go +++ b/pkg/migrate/config.go @@ -165,6 +165,12 @@ func ConvertConfig(data map[string]any) (*config.Config, []string, error) { if v, ok := getString(cMap, "bridge_url"); ok { cfg.Channels.WhatsApp.BridgeURL = v } + if v, ok := getBool(cMap, "use_native"); ok { + cfg.Channels.WhatsApp.UseNative = v + } + if v, ok := getString(cMap, "session_store_path"); ok { + cfg.Channels.WhatsApp.SessionStorePath = v + } case "feishu": cfg.Channels.Feishu.Enabled = enabled cfg.Channels.Feishu.AllowFrom = allowFrom diff --git a/pkg/migrate/migrate_test.go b/pkg/migrate/migrate_test.go index b6b3d70aa..9216442bb 100644 --- a/pkg/migrate/migrate_test.go +++ b/pkg/migrate/migrate_test.go @@ -296,8 +296,8 @@ func TestConvertConfig(t *testing.T) { if len(warnings) != 0 { t.Errorf("expected no warnings, got %v", warnings) } - if cfg.Agents.Defaults.Model != "glm-4.7" { - t.Errorf("default model should be glm-4.7, got %q", cfg.Agents.Defaults.Model) + if cfg.Agents.Defaults.Model != "" { + t.Errorf("default model should be empty, got %q", cfg.Agents.Defaults.Model) } }) } diff --git a/pkg/providers/fallback.go b/pkg/providers/fallback.go index ecd451ec9..7ba563b66 100644 --- a/pkg/providers/fallback.go +++ b/pkg/providers/fallback.go @@ -43,11 +43,26 @@ func NewFallbackChain(cooldown *CooldownTracker) *FallbackChain { // ResolveCandidates parses model config into a deduplicated candidate list. 
func ResolveCandidates(cfg ModelConfig, defaultProvider string) []FallbackCandidate { + return ResolveCandidatesWithLookup(cfg, defaultProvider, nil) +} + +func ResolveCandidatesWithLookup( + cfg ModelConfig, + defaultProvider string, + lookup func(raw string) (resolved string, ok bool), +) []FallbackCandidate { seen := make(map[string]bool) var candidates []FallbackCandidate addCandidate := func(raw string) { - ref := ParseModelRef(raw, defaultProvider) + candidateRaw := strings.TrimSpace(raw) + if lookup != nil { + if resolved, ok := lookup(candidateRaw); ok { + candidateRaw = resolved + } + } + + ref := ParseModelRef(candidateRaw, defaultProvider) if ref == nil { return } diff --git a/pkg/providers/fallback_test.go b/pkg/providers/fallback_test.go index ebba054ef..1783ebcb5 100644 --- a/pkg/providers/fallback_test.go +++ b/pkg/providers/fallback_test.go @@ -453,6 +453,75 @@ func TestResolveCandidates_EmptyPrimary(t *testing.T) { } } +func TestResolveCandidatesWithLookup_AliasResolvesToNestedModel(t *testing.T) { + cfg := ModelConfig{ + Primary: "step-3.5-flash", + Fallbacks: nil, + } + + lookup := func(raw string) (string, bool) { + if raw == "step-3.5-flash" { + return "openrouter/stepfun/step-3.5-flash:free", true + } + return "", false + } + + candidates := ResolveCandidatesWithLookup(cfg, "", lookup) + if len(candidates) != 1 { + t.Fatalf("candidates = %d, want 1", len(candidates)) + } + if candidates[0].Provider != "openrouter" { + t.Fatalf("provider = %q, want openrouter", candidates[0].Provider) + } + if candidates[0].Model != "stepfun/step-3.5-flash:free" { + t.Fatalf("model = %q, want stepfun/step-3.5-flash:free", candidates[0].Model) + } +} + +func TestResolveCandidatesWithLookup_DeduplicateAfterLookup(t *testing.T) { + cfg := ModelConfig{ + Primary: "step-3.5-flash", + Fallbacks: []string{"openrouter/stepfun/step-3.5-flash:free"}, + } + + lookup := func(raw string) (string, bool) { + if raw == "step-3.5-flash" { + return 
"openrouter/stepfun/step-3.5-flash:free", true + } + return "", false + } + + candidates := ResolveCandidatesWithLookup(cfg, "", lookup) + if len(candidates) != 1 { + t.Fatalf("candidates = %d, want 1", len(candidates)) + } +} + +func TestResolveCandidatesWithLookup_AliasWithoutProtocolUsesDefaultProvider(t *testing.T) { + cfg := ModelConfig{ + Primary: "glm-5", + Fallbacks: nil, + } + + lookup := func(raw string) (string, bool) { + if raw == "glm-5" { + return "glm-5", true + } + return "", false + } + + candidates := ResolveCandidatesWithLookup(cfg, "openai", lookup) + if len(candidates) != 1 { + t.Fatalf("candidates = %d, want 1", len(candidates)) + } + if candidates[0].Provider != "openai" { + t.Fatalf("provider = %q, want openai", candidates[0].Provider) + } + if candidates[0].Model != "glm-5" { + t.Fatalf("model = %q, want glm-5", candidates[0].Model) + } +} + func TestFallbackExhaustedError_Message(t *testing.T) { e := &FallbackExhaustedError{ Attempts: []FallbackAttempt{ diff --git a/pkg/providers/openai_compat/provider.go b/pkg/providers/openai_compat/provider.go index cd606d533..d922ed5f7 100644 --- a/pkg/providers/openai_compat/provider.go +++ b/pkg/providers/openai_compat/provider.go @@ -25,6 +25,7 @@ type ( ToolFunctionDefinition = protocoltypes.ToolFunctionDefinition ExtraContent = protocoltypes.ExtraContent GoogleExtra = protocoltypes.GoogleExtra + ReasoningDetail = protocoltypes.ReasoningDetail ) type Provider struct { @@ -198,8 +199,10 @@ func parseResponse(body []byte) (*LLMResponse, error) { var apiResponse struct { Choices []struct { Message struct { - Content string `json:"content"` - ReasoningContent string `json:"reasoning_content"` + Content string `json:"content"` + ReasoningContent string `json:"reasoning_content"` + Reasoning string `json:"reasoning"` + ReasoningDetails []ReasoningDetail `json:"reasoning_details"` ToolCalls []struct { ID string `json:"id"` Type string `json:"type"` @@ -274,6 +277,8 @@ func parseResponse(body []byte) 
(*LLMResponse, error) { return &LLMResponse{ Content: choice.Message.Content, ReasoningContent: choice.Message.ReasoningContent, + Reasoning: choice.Message.Reasoning, + ReasoningDetails: choice.Message.ReasoningDetails, ToolCalls: toolCalls, FinishReason: choice.FinishReason, Usage: apiResponse.Usage, diff --git a/pkg/providers/protocoltypes/types.go b/pkg/providers/protocoltypes/types.go index 33f052c5a..99f13334e 100644 --- a/pkg/providers/protocoltypes/types.go +++ b/pkg/providers/protocoltypes/types.go @@ -25,11 +25,20 @@ type FunctionCall struct { } type LLMResponse struct { - Content string `json:"content"` - ReasoningContent string `json:"reasoning_content,omitempty"` - ToolCalls []ToolCall `json:"tool_calls,omitempty"` - FinishReason string `json:"finish_reason"` - Usage *UsageInfo `json:"usage,omitempty"` + Content string `json:"content"` + ReasoningContent string `json:"reasoning_content,omitempty"` + ToolCalls []ToolCall `json:"tool_calls,omitempty"` + FinishReason string `json:"finish_reason"` + Usage *UsageInfo `json:"usage,omitempty"` + Reasoning string `json:"reasoning,omitempty"` + ReasoningDetails []ReasoningDetail `json:"reasoning_details,omitempty"` +} + +type ReasoningDetail struct { + Format string `json:"format"` + Index int `json:"index"` + Type string `json:"type"` + Text string `json:"text"` } type UsageInfo struct { diff --git a/pkg/routing/session_key.go b/pkg/routing/session_key.go index e12f0d1d8..eab592bec 100644 --- a/pkg/routing/session_key.go +++ b/pkg/routing/session_key.go @@ -163,6 +163,15 @@ func resolveLinkedPeerID(identityLinks map[string][]string, channel, peerID stri scopedCandidate := fmt.Sprintf("%s:%s", channel, strings.ToLower(peerID)) candidates[scopedCandidate] = true } + + // If peerID is already in canonical "platform:id" format, also add the + // bare ID part as a candidate for backward compatibility with identity_links + // that use raw IDs (e.g. "123" instead of "telegram:123"). 
+ if idx := strings.Index(rawCandidate, ":"); idx > 0 && idx < len(rawCandidate)-1 { + bareID := rawCandidate[idx+1:] + candidates[bareID] = true + } + if len(candidates) == 0 { return "" } diff --git a/pkg/routing/session_key_test.go b/pkg/routing/session_key_test.go index 81e4ce018..ad7a1ca02 100644 --- a/pkg/routing/session_key_test.go +++ b/pkg/routing/session_key_test.go @@ -115,6 +115,51 @@ func TestBuildAgentPeerSessionKey_IdentityLink(t *testing.T) { } } +func TestResolveLinkedPeerID_CanonicalPeerID(t *testing.T) { + // When peerID is already in canonical "platform:id" format, + // it should match identity_links that use the bare ID. + links := map[string][]string{ + "john": {"123"}, + } + got := resolveLinkedPeerID(links, "telegram", "telegram:123") + if got != "john" { + t.Errorf("resolveLinkedPeerID with canonical peerID = %q, want %q", got, "john") + } +} + +func TestResolveLinkedPeerID_CanonicalInLinks(t *testing.T) { + // When identity_links contain canonical IDs and peerID is canonical too + links := map[string][]string{ + "john": {"telegram:123", "discord:456"}, + } + got := resolveLinkedPeerID(links, "telegram", "telegram:123") + if got != "john" { + t.Errorf("resolveLinkedPeerID canonical in links = %q, want %q", got, "john") + } +} + +func TestResolveLinkedPeerID_BarePeerIDMatchesCanonicalLink(t *testing.T) { + // When peerID is bare "123" and links have "telegram:123", + // the scoped candidate "telegram:123" should match. 
+ links := map[string][]string{ + "john": {"telegram:123"}, + } + got := resolveLinkedPeerID(links, "telegram", "123") + if got != "john" { + t.Errorf("resolveLinkedPeerID bare peer matches canonical link = %q, want %q", got, "john") + } +} + +func TestResolveLinkedPeerID_NoMatch(t *testing.T) { + links := map[string][]string{ + "john": {"telegram:123"}, + } + got := resolveLinkedPeerID(links, "discord", "999") + if got != "" { + t.Errorf("resolveLinkedPeerID no match = %q, want empty", got) + } +} + func TestParseAgentSessionKey_Valid(t *testing.T) { parsed := ParseAgentSessionKey("agent:sales:telegram:direct:user123") if parsed == nil { diff --git a/pkg/tools/cron.go b/pkg/tools/cron.go index 3140e5e25..0afd51f2d 100644 --- a/pkg/tools/cron.go +++ b/pkg/tools/cron.go @@ -296,7 +296,9 @@ func (t *CronTool) ExecuteJob(ctx context.Context, job *cron.CronJob) string { output = fmt.Sprintf("Scheduled command '%s' executed:\n%s", job.Payload.Command, result.ForLLM) } - t.msgBus.PublishOutbound(bus.OutboundMessage{ + pubCtx, pubCancel := context.WithTimeout(context.Background(), 5*time.Second) + defer pubCancel() + t.msgBus.PublishOutbound(pubCtx, bus.OutboundMessage{ Channel: channel, ChatID: chatID, Content: output, @@ -306,7 +308,9 @@ func (t *CronTool) ExecuteJob(ctx context.Context, job *cron.CronJob) string { // If deliver=true, send message directly without agent processing if job.Payload.Deliver { - t.msgBus.PublishOutbound(bus.OutboundMessage{ + pubCtx, pubCancel := context.WithTimeout(context.Background(), 5*time.Second) + defer pubCancel() + t.msgBus.PublishOutbound(pubCtx, bus.OutboundMessage{ Channel: channel, ChatID: chatID, Content: job.Payload.Message, diff --git a/pkg/tools/result.go b/pkg/tools/result.go index b13055b1c..cab833284 100644 --- a/pkg/tools/result.go +++ b/pkg/tools/result.go @@ -30,6 +30,10 @@ type ToolResult struct { // Err is the underlying error (not JSON serialized). // Used for internal error handling and logging. 
Err error `json:"-"` + + // Media contains media store refs produced by this tool. + // When non-empty, the agent will publish these as OutboundMediaMessage. + Media []string `json:"media,omitempty"` } // NewToolResult creates a basic ToolResult with content for the LLM. @@ -120,6 +124,19 @@ func UserResult(content string) *ToolResult { } } +// MediaResult creates a ToolResult with media refs for the user. +// The agent will publish these refs as OutboundMediaMessage. +// +// Example: +// +// result := MediaResult("Image generated successfully", []string{"media://abc123"}) +func MediaResult(forLLM string, mediaRefs []string) *ToolResult { + return &ToolResult{ + ForLLM: forLLM, + Media: mediaRefs, + } +} + // MarshalJSON implements custom JSON serialization. // The Err field is excluded from JSON output via the json:"-" tag. func (tr *ToolResult) MarshalJSON() ([]byte, error) { diff --git a/pkg/tools/subagent.go b/pkg/tools/subagent.go index ad371a649..69f1a49a2 100644 --- a/pkg/tools/subagent.go +++ b/pkg/tools/subagent.go @@ -218,7 +218,9 @@ After completing the task, provide a clear summary of what was done.` // Send announce message back to main agent if sm.bus != nil { announceContent := fmt.Sprintf("Task '%s' completed.\n\nResult:\n%s", task.Label, task.Result) - sm.bus.PublishInbound(bus.InboundMessage{ + pubCtx, pubCancel := context.WithTimeout(context.Background(), 5*time.Second) + defer pubCancel() + sm.bus.PublishInbound(pubCtx, bus.InboundMessage{ Channel: "system", SenderID: fmt.Sprintf("subagent:%s", task.ID), // Format: "original_channel:original_chat_id" for routing back diff --git a/pkg/utils/message.go b/pkg/utils/message.go deleted file mode 100644 index a65506edc..000000000 --- a/pkg/utils/message.go +++ /dev/null @@ -1,167 +0,0 @@ -package utils - -import ( - "strings" -) - -// SplitMessage splits long messages into chunks, preserving code block integrity. 
-// The function reserves a buffer (10% of maxLen, min 50) to leave room for closing code blocks, -// but may extend to maxLen when needed. -// Call SplitMessage with the full text content and the maximum allowed length of a single message; -// it returns a slice of message chunks that each respect maxLen and avoid splitting fenced code blocks. -func SplitMessage(content string, maxLen int) []string { - var messages []string - - // Dynamic buffer: 10% of maxLen, but at least 50 chars if possible - codeBlockBuffer := max(maxLen/10, 50) - if codeBlockBuffer > maxLen/2 { - codeBlockBuffer = maxLen / 2 - } - - for len(content) > 0 { - if len(content) <= maxLen { - messages = append(messages, content) - break - } - - // Effective split point: maxLen minus buffer, to leave room for code blocks - effectiveLimit := max(maxLen-codeBlockBuffer, maxLen/2) - - // Find natural split point within the effective limit - msgEnd := findLastNewline(content[:effectiveLimit], 200) - if msgEnd <= 0 { - msgEnd = findLastSpace(content[:effectiveLimit], 100) - } - if msgEnd <= 0 { - msgEnd = effectiveLimit - } - - // Check if this would end with an incomplete code block - candidate := content[:msgEnd] - unclosedIdx := findLastUnclosedCodeBlock(candidate) - - if unclosedIdx >= 0 { - // Message would end with incomplete code block - // Try to extend up to maxLen to include the closing ``` - if len(content) > msgEnd { - closingIdx := findNextClosingCodeBlock(content, msgEnd) - if closingIdx > 0 && closingIdx <= maxLen { - // Extend to include the closing ``` - msgEnd = closingIdx - } else { - // Code block is too long to fit in one chunk or missing closing fence. - // Try to split inside by injecting closing and reopening fences. 
- headerEnd := strings.Index(content[unclosedIdx:], "\n") - if headerEnd == -1 { - headerEnd = unclosedIdx + 3 - } else { - headerEnd += unclosedIdx - } - header := strings.TrimSpace(content[unclosedIdx:headerEnd]) - - // If we have a reasonable amount of content after the header, split inside - if msgEnd > headerEnd+20 { - // Find a better split point closer to maxLen - innerLimit := maxLen - 5 // Leave room for "\n```" - betterEnd := findLastNewline(content[:innerLimit], 200) - if betterEnd > headerEnd { - msgEnd = betterEnd - } else { - msgEnd = innerLimit - } - messages = append(messages, strings.TrimRight(content[:msgEnd], " \t\n\r")+"\n```") - content = strings.TrimSpace(header + "\n" + content[msgEnd:]) - continue - } - - // Otherwise, try to split before the code block starts - newEnd := findLastNewline(content[:unclosedIdx], 200) - if newEnd <= 0 { - newEnd = findLastSpace(content[:unclosedIdx], 100) - } - if newEnd > 0 { - msgEnd = newEnd - } else { - // If we can't split before, we MUST split inside (last resort) - if unclosedIdx > 20 { - msgEnd = unclosedIdx - } else { - msgEnd = maxLen - 5 - messages = append(messages, strings.TrimRight(content[:msgEnd], " \t\n\r")+"\n```") - content = strings.TrimSpace(header + "\n" + content[msgEnd:]) - continue - } - } - } - } - } - - if msgEnd <= 0 { - msgEnd = effectiveLimit - } - - messages = append(messages, content[:msgEnd]) - content = strings.TrimSpace(content[msgEnd:]) - } - - return messages -} - -// findLastUnclosedCodeBlock finds the last opening ``` that doesn't have a closing ``` -// Returns the position of the opening ``` or -1 if all code blocks are complete -func findLastUnclosedCodeBlock(text string) int { - inCodeBlock := false - lastOpenIdx := -1 - - for i := 0; i < len(text); i++ { - if i+2 < len(text) && text[i] == '`' && text[i+1] == '`' && text[i+2] == '`' { - // Toggle code block state on each fence - if !inCodeBlock { - // Entering a code block: record this opening fence - lastOpenIdx = i - 
} - inCodeBlock = !inCodeBlock - i += 2 - } - } - - if inCodeBlock { - return lastOpenIdx - } - return -1 -} - -// findNextClosingCodeBlock finds the next closing ``` starting from a position -// Returns the position after the closing ``` or -1 if not found -func findNextClosingCodeBlock(text string, startIdx int) int { - for i := startIdx; i < len(text); i++ { - if i+2 < len(text) && text[i] == '`' && text[i+1] == '`' && text[i+2] == '`' { - return i + 3 - } - } - return -1 -} - -// findLastNewline finds the last newline character within the last N characters -// Returns the position of the newline or -1 if not found -func findLastNewline(s string, searchWindow int) int { - searchStart := max(len(s)-searchWindow, 0) - for i := len(s) - 1; i >= searchStart; i-- { - if s[i] == '\n' { - return i - } - } - return -1 -} - -// findLastSpace finds the last space character within the last N characters -// Returns the position of the space or -1 if not found -func findLastSpace(s string, searchWindow int) int { - searchStart := max(len(s)-searchWindow, 0) - for i := len(s) - 1; i >= searchStart; i-- { - if s[i] == ' ' || s[i] == '\t' { - return i - } - } - return -1 -} diff --git a/pkg/utils/message_test.go b/pkg/utils/message_test.go deleted file mode 100644 index 338509437..000000000 --- a/pkg/utils/message_test.go +++ /dev/null @@ -1,151 +0,0 @@ -package utils - -import ( - "strings" - "testing" -) - -func TestSplitMessage(t *testing.T) { - longText := strings.Repeat("a", 2500) - longCode := "```go\n" + strings.Repeat("fmt.Println(\"hello\")\n", 100) + "```" // ~2100 chars - - tests := []struct { - name string - content string - maxLen int - expectChunks int // Check number of chunks - checkContent func(t *testing.T, chunks []string) // Custom validation - }{ - { - name: "Empty message", - content: "", - maxLen: 2000, - expectChunks: 0, - }, - { - name: "Short message fits in one chunk", - content: "Hello world", - maxLen: 2000, - expectChunks: 1, - }, - { - name: 
"Simple split regular text", - content: longText, - maxLen: 2000, - expectChunks: 2, - checkContent: func(t *testing.T, chunks []string) { - if len(chunks[0]) > 2000 { - t.Errorf("Chunk 0 too large: %d", len(chunks[0])) - } - if len(chunks[0])+len(chunks[1]) != len(longText) { - t.Errorf("Total length mismatch. Got %d, want %d", len(chunks[0])+len(chunks[1]), len(longText)) - } - }, - }, - { - name: "Split at newline", - // 1750 chars then newline, then more chars. - // Dynamic buffer: 2000 / 10 = 200. - // Effective limit: 2000 - 200 = 1800. - // Split should happen at newline because it's at 1750 (< 1800). - // Total length must > 2000 to trigger split. 1750 + 1 + 300 = 2051. - content: strings.Repeat("a", 1750) + "\n" + strings.Repeat("b", 300), - maxLen: 2000, - expectChunks: 2, - checkContent: func(t *testing.T, chunks []string) { - if len(chunks[0]) != 1750 { - t.Errorf("Expected chunk 0 to be 1750 length (split at newline), got %d", len(chunks[0])) - } - if chunks[1] != strings.Repeat("b", 300) { - t.Errorf("Chunk 1 content mismatch. Len: %d", len(chunks[1])) - } - }, - }, - { - name: "Long code block split", - content: "Prefix\n" + longCode, - maxLen: 2000, - expectChunks: 2, - checkContent: func(t *testing.T, chunks []string) { - // Check that first chunk ends with closing fence - if !strings.HasSuffix(chunks[0], "\n```") { - t.Error("First chunk should end with injected closing fence") - } - // Check that second chunk starts with execution header - if !strings.HasPrefix(chunks[1], "```go") { - t.Error("Second chunk should start with injected code block header") - } - }, - }, - { - name: "Preserve Unicode characters", - content: strings.Repeat("\u4e16", 1000), // 3000 bytes - maxLen: 2000, - expectChunks: 2, - checkContent: func(t *testing.T, chunks []string) { - // Just verify we didn't panic and got valid strings. - // Go strings are UTF-8, if we split mid-rune it would be bad, - // but standard slicing might do that. 
- // Let's assume standard behavior is acceptable or check if it produces invalid rune? - if !strings.Contains(chunks[0], "\u4e16") { - t.Error("Chunk should contain unicode characters") - } - }, - }, - } - - for _, tc := range tests { - t.Run(tc.name, func(t *testing.T) { - got := SplitMessage(tc.content, tc.maxLen) - - if tc.expectChunks == 0 { - if len(got) != 0 { - t.Errorf("Expected 0 chunks, got %d", len(got)) - } - return - } - - if len(got) != tc.expectChunks { - t.Errorf("Expected %d chunks, got %d", tc.expectChunks, len(got)) - // Log sizes for debugging - for i, c := range got { - t.Logf("Chunk %d length: %d", i, len(c)) - } - return // Stop further checks if count assumes specific split - } - - if tc.checkContent != nil { - tc.checkContent(t, got) - } - }) - } -} - -func TestSplitMessage_CodeBlockIntegrity(t *testing.T) { - // Focused test for the core requirement: splitting inside a code block preserves syntax highlighting - - // 60 chars total approximately - content := "```go\npackage main\n\nfunc main() {\n\tprintln(\"Hello\")\n}\n```" - maxLen := 40 - - chunks := SplitMessage(content, maxLen) - - if len(chunks) != 2 { - t.Fatalf("Expected 2 chunks, got %d: %q", len(chunks), chunks) - } - - // First chunk must end with "\n```" - if !strings.HasSuffix(chunks[0], "\n```") { - t.Errorf("First chunk should end with closing fence. Got: %q", chunks[0]) - } - - // Second chunk must start with the header "```go" - if !strings.HasPrefix(chunks[1], "```go") { - t.Errorf("Second chunk should start with code block header. 
// SanitizeMessageContent removes Unicode control characters, format characters
// (RTL overrides, zero-width characters, BOM) and other non-graphic runes that
// could confuse an LLM or cause display issues in the agent UI.
//
// A rune is kept when unicode.IsGraphic reports true for it (letters, marks,
// numbers, punctuation, symbols and spaces) or when it is one of the
// whitespace controls '\n', '\r' or '\t'. Everything else is dropped.
// Bytes that are not valid UTF-8 are decoded by the range loop as U+FFFD and
// are therefore emitted as the replacement character.
//
// When nothing has to be removed or re-encoded, the input string is returned
// as-is without allocating.
func SanitizeMessageContent(input string) string {
	var sb strings.Builder
	rewriting := false

	for i, r := range input {
		keep := unicode.IsGraphic(r) || r == '\n' || r == '\r' || r == '\t'

		if !rewriting {
			// U+FFFD may stand for an invalid byte that must be re-encoded,
			// so it forces the copying path even though it is itself kept.
			if keep && r != unicode.ReplacementChar {
				continue
			}
			// First rune that requires a rewrite: everything before it is
			// valid, kept text and can be copied over verbatim in one go.
			sb.Grow(len(input))
			sb.WriteString(input[:i])
			rewriting = true
		}

		if keep {
			sb.WriteRune(r)
		}
	}

	if !rewriting {
		// Common case: the message was already clean.
		return input
	}
	return sb.String()
}
diff --git a/pkg/utils/string_test.go b/pkg/utils/string_test.go index a44ead228..e3b5af052 100644 --- a/pkg/utils/string_test.go +++ b/pkg/utils/string_test.go @@ -104,3 +104,27 @@ func TestTruncate(t *testing.T) { }) } } + +func TestSanitizeMessageContent(t *testing.T) { + tests := []struct { + name string + input string + want string + }{ + {"empty", "", ""}, + {"plain text unchanged", "Hello world", "Hello world"}, + {"strip ZWSP", "Hello\u200bworld", "Helloworld"}, + {"strip RTL override", "Hi\u202eevil", "Hievil"}, + {"strip BOM", "\uFEFFcontent", "content"}, + {"strip multiple", "a\u200c\u202ab\u202cc", "abc"}, + {"unicode letters preserved", "café \u65e5\u672c\u8a9e", "café \u65e5\u672c\u8a9e"}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := SanitizeMessageContent(tt.input) + if got != tt.want { + t.Errorf("SanitizeMessageContent(%q) = %q, want %q", tt.input, got, tt.want) + } + }) + } +}