diff --git a/.github/workflows/pr.yml b/.github/workflows/pr.yml index 902d4d4eb..2d544d4f0 100644 --- a/.github/workflows/pr.yml +++ b/.github/workflows/pr.yml @@ -23,10 +23,13 @@ jobs: uses: golangci/golangci-lint-action@v9 with: version: v2.10.1 + args: --build-tags=goolm,stdjson vuln_check: name: Security Check runs-on: ubuntu-latest + env: + GOFLAGS: -tags=goolm,stdjson steps: - name: Checkout uses: actions/checkout@v6 @@ -59,4 +62,4 @@ jobs: run: go generate ./... - name: Run go test - run: go test ./... + run: go test -tags goolm,stdjson ./... diff --git a/.gitignore b/.gitignore index 8b5f95215..72f3b1761 100644 --- a/.gitignore +++ b/.gitignore @@ -40,6 +40,7 @@ tasks/ # Plans docs/plans/ +docs/superpowers/ # Editors .vscode/ diff --git a/.goreleaser.yaml b/.goreleaser.yaml index a73f87f30..ea93d0377 100644 --- a/.goreleaser.yaml +++ b/.goreleaser.yaml @@ -15,6 +15,7 @@ builds: env: - CGO_ENABLED=0 tags: + - goolm - stdjson ldflags: - -s -w @@ -57,6 +58,7 @@ builds: env: - CGO_ENABLED=0 tags: + - goolm - stdjson ldflags: - -s -w @@ -95,6 +97,7 @@ builds: env: - CGO_ENABLED=0 tags: + - goolm - stdjson ldflags: - -s -w diff --git a/Makefile b/Makefile index 411cd9dc5..8ffadb4ef 100644 --- a/Makefile +++ b/Makefile @@ -17,7 +17,13 @@ LDFLAGS=-X $(CONFIG_PKG).Version=$(VERSION) -X $(CONFIG_PKG).GitCommit=$(GIT_COM # Go variables GO?=CGO_ENABLED=0 go WEB_GO?=$(GO) -GOFLAGS?=-v -tags stdjson +GO_BUILD_TAGS?=goolm,stdjson +GOFLAGS?=-v -tags $(GO_BUILD_TAGS) +comma:=, +empty:= +space:=$(empty) $(empty) +GO_BUILD_TAGS_NO_GOOLM:=$(subst $(space),$(comma),$(strip $(filter-out goolm,$(subst $(comma),$(space),$(GO_BUILD_TAGS))))) +GOFLAGS_NO_GOOLM?=-v -tags $(GO_BUILD_TAGS_NO_GOOLM) # Patch MIPS LE ELF e_flags (offset 36) for NaN2008-only kernels (e.g. Ingenic X2600). # @@ -130,15 +136,15 @@ build-whatsapp-native: generate ## @echo "Building $(BINARY_NAME) with WhatsApp native for $(PLATFORM)/$(ARCH)..." @echo "Building for multiple platforms..." 
@mkdir -p $(BUILD_DIR) - GOOS=linux GOARCH=amd64 $(GO) build -tags whatsapp_native -ldflags "$(LDFLAGS)" -o $(BUILD_DIR)/$(BINARY_NAME)-linux-amd64 ./$(CMD_DIR) - GOOS=linux GOARCH=arm GOARM=7 $(GO) build -tags whatsapp_native -ldflags "$(LDFLAGS)" -o $(BUILD_DIR)/$(BINARY_NAME)-linux-arm ./$(CMD_DIR) - GOOS=linux GOARCH=arm64 $(GO) build -tags whatsapp_native -ldflags "$(LDFLAGS)" -o $(BUILD_DIR)/$(BINARY_NAME)-linux-arm64 ./$(CMD_DIR) - GOOS=linux GOARCH=loong64 $(GO) build -tags whatsapp_native -ldflags "$(LDFLAGS)" -o $(BUILD_DIR)/$(BINARY_NAME)-linux-loong64 ./$(CMD_DIR) - GOOS=linux GOARCH=riscv64 $(GO) build -tags whatsapp_native -ldflags "$(LDFLAGS)" -o $(BUILD_DIR)/$(BINARY_NAME)-linux-riscv64 ./$(CMD_DIR) - GOOS=linux GOARCH=mipsle GOMIPS=softfloat $(GO) build -tags whatsapp_native -ldflags "$(LDFLAGS)" -o $(BUILD_DIR)/$(BINARY_NAME)-linux-mipsle ./$(CMD_DIR) + GOOS=linux GOARCH=amd64 $(GO) build -tags $(GO_BUILD_TAGS),whatsapp_native -ldflags "$(LDFLAGS)" -o $(BUILD_DIR)/$(BINARY_NAME)-linux-amd64 ./$(CMD_DIR) + GOOS=linux GOARCH=arm GOARM=7 $(GO) build -tags $(GO_BUILD_TAGS),whatsapp_native -ldflags "$(LDFLAGS)" -o $(BUILD_DIR)/$(BINARY_NAME)-linux-arm ./$(CMD_DIR) + GOOS=linux GOARCH=arm64 $(GO) build -tags $(GO_BUILD_TAGS),whatsapp_native -ldflags "$(LDFLAGS)" -o $(BUILD_DIR)/$(BINARY_NAME)-linux-arm64 ./$(CMD_DIR) + GOOS=linux GOARCH=loong64 $(GO) build -tags $(GO_BUILD_TAGS),whatsapp_native -ldflags "$(LDFLAGS)" -o $(BUILD_DIR)/$(BINARY_NAME)-linux-loong64 ./$(CMD_DIR) + GOOS=linux GOARCH=riscv64 $(GO) build -tags $(GO_BUILD_TAGS),whatsapp_native -ldflags "$(LDFLAGS)" -o $(BUILD_DIR)/$(BINARY_NAME)-linux-riscv64 ./$(CMD_DIR) + GOOS=linux GOARCH=mipsle GOMIPS=softfloat $(GO) build -tags $(GO_BUILD_TAGS_NO_GOOLM),whatsapp_native -ldflags "$(LDFLAGS)" -o $(BUILD_DIR)/$(BINARY_NAME)-linux-mipsle ./$(CMD_DIR) $(call PATCH_MIPS_FLAGS,$(BUILD_DIR)/$(BINARY_NAME)-linux-mipsle) - GOOS=darwin GOARCH=arm64 $(GO) build -tags whatsapp_native -ldflags 
"$(LDFLAGS)" -o $(BUILD_DIR)/$(BINARY_NAME)-darwin-arm64 ./$(CMD_DIR) - GOOS=windows GOARCH=amd64 $(GO) build -tags whatsapp_native -ldflags "$(LDFLAGS)" -o $(BUILD_DIR)/$(BINARY_NAME)-windows-amd64.exe ./$(CMD_DIR) + GOOS=darwin GOARCH=arm64 $(GO) build -tags $(GO_BUILD_TAGS),whatsapp_native -ldflags "$(LDFLAGS)" -o $(BUILD_DIR)/$(BINARY_NAME)-darwin-arm64 ./$(CMD_DIR) + GOOS=windows GOARCH=amd64 $(GO) build -tags $(GO_BUILD_TAGS),whatsapp_native -ldflags "$(LDFLAGS)" -o $(BUILD_DIR)/$(BINARY_NAME)-windows-amd64.exe ./$(CMD_DIR) ## @$(GO) build $(GOFLAGS) -tags whatsapp_native -ldflags "$(LDFLAGS)" -o $(BINARY_PATH) ./$(CMD_DIR) @echo "Build complete" ## @ln -sf $(BINARY_NAME)-$(PLATFORM)-$(ARCH) $(BUILD_DIR)/$(BINARY_NAME) @@ -147,21 +153,21 @@ build-whatsapp-native: generate build-linux-arm: generate @echo "Building for linux/arm (GOARM=7)..." @mkdir -p $(BUILD_DIR) - GOOS=linux GOARCH=arm GOARM=7 $(GO) build -ldflags "$(LDFLAGS)" -o $(BUILD_DIR)/$(BINARY_NAME)-linux-arm ./$(CMD_DIR) + GOOS=linux GOARCH=arm GOARM=7 $(GO) build $(GOFLAGS) -ldflags "$(LDFLAGS)" -o $(BUILD_DIR)/$(BINARY_NAME)-linux-arm ./$(CMD_DIR) @echo "Build complete: $(BUILD_DIR)/$(BINARY_NAME)-linux-arm" ## build-linux-arm64: Build for Linux ARM64 (e.g. Raspberry Pi Zero 2 W 64-bit) build-linux-arm64: generate @echo "Building for linux/arm64..." @mkdir -p $(BUILD_DIR) - GOOS=linux GOARCH=arm64 $(GO) build -ldflags "$(LDFLAGS)" -o $(BUILD_DIR)/$(BINARY_NAME)-linux-arm64 ./$(CMD_DIR) + GOOS=linux GOARCH=arm64 $(GO) build $(GOFLAGS) -ldflags "$(LDFLAGS)" -o $(BUILD_DIR)/$(BINARY_NAME)-linux-arm64 ./$(CMD_DIR) @echo "Build complete: $(BUILD_DIR)/$(BINARY_NAME)-linux-arm64" ## build-linux-mipsle: Build for Linux MIPS32 LE build-linux-mipsle: generate @echo "Building for linux/mipsle (softfloat)..." 
@mkdir -p $(BUILD_DIR) - GOOS=linux GOARCH=mipsle GOMIPS=softfloat $(GO) build -ldflags "$(LDFLAGS)" -o $(BUILD_DIR)/$(BINARY_NAME)-linux-mipsle ./$(CMD_DIR) + GOOS=linux GOARCH=mipsle GOMIPS=softfloat $(GO) build $(GOFLAGS_NO_GOOLM) -ldflags "$(LDFLAGS)" -o $(BUILD_DIR)/$(BINARY_NAME)-linux-mipsle ./$(CMD_DIR) $(call PATCH_MIPS_FLAGS,$(BUILD_DIR)/$(BINARY_NAME)-linux-mipsle) @echo "Build complete: $(BUILD_DIR)/$(BINARY_NAME)-linux-mipsle" @@ -173,18 +179,18 @@ build-pi-zero: build-linux-arm build-linux-arm64 build-all: generate @echo "Building for multiple platforms..." @mkdir -p $(BUILD_DIR) - GOOS=linux GOARCH=amd64 $(GO) build -ldflags "$(LDFLAGS)" -o $(BUILD_DIR)/$(BINARY_NAME)-linux-amd64 ./$(CMD_DIR) - GOOS=linux GOARCH=arm GOARM=7 $(GO) build -ldflags "$(LDFLAGS)" -o $(BUILD_DIR)/$(BINARY_NAME)-linux-arm ./$(CMD_DIR) - GOOS=linux GOARCH=arm64 $(GO) build -ldflags "$(LDFLAGS)" -o $(BUILD_DIR)/$(BINARY_NAME)-linux-arm64 ./$(CMD_DIR) - GOOS=linux GOARCH=loong64 $(GO) build -ldflags "$(LDFLAGS)" -o $(BUILD_DIR)/$(BINARY_NAME)-linux-loong64 ./$(CMD_DIR) - GOOS=linux GOARCH=riscv64 $(GO) build -ldflags "$(LDFLAGS)" -o $(BUILD_DIR)/$(BINARY_NAME)-linux-riscv64 ./$(CMD_DIR) - GOOS=linux GOARCH=mipsle GOMIPS=softfloat $(GO) build -ldflags "$(LDFLAGS)" -o $(BUILD_DIR)/$(BINARY_NAME)-linux-mipsle ./$(CMD_DIR) + GOOS=linux GOARCH=amd64 $(GO) build $(GOFLAGS) -ldflags "$(LDFLAGS)" -o $(BUILD_DIR)/$(BINARY_NAME)-linux-amd64 ./$(CMD_DIR) + GOOS=linux GOARCH=arm GOARM=7 $(GO) build $(GOFLAGS) -ldflags "$(LDFLAGS)" -o $(BUILD_DIR)/$(BINARY_NAME)-linux-arm ./$(CMD_DIR) + GOOS=linux GOARCH=arm64 $(GO) build $(GOFLAGS) -ldflags "$(LDFLAGS)" -o $(BUILD_DIR)/$(BINARY_NAME)-linux-arm64 ./$(CMD_DIR) + GOOS=linux GOARCH=loong64 $(GO) build $(GOFLAGS) -ldflags "$(LDFLAGS)" -o $(BUILD_DIR)/$(BINARY_NAME)-linux-loong64 ./$(CMD_DIR) + GOOS=linux GOARCH=riscv64 $(GO) build $(GOFLAGS) -ldflags "$(LDFLAGS)" -o $(BUILD_DIR)/$(BINARY_NAME)-linux-riscv64 ./$(CMD_DIR) + GOOS=linux 
GOARCH=mipsle GOMIPS=softfloat $(GO) build $(GOFLAGS_NO_GOOLM) -ldflags "$(LDFLAGS)" -o $(BUILD_DIR)/$(BINARY_NAME)-linux-mipsle ./$(CMD_DIR) $(call PATCH_MIPS_FLAGS,$(BUILD_DIR)/$(BINARY_NAME)-linux-mipsle) - GOOS=linux GOARCH=arm GOARM=7 $(GO) build -ldflags "$(LDFLAGS)" -o $(BUILD_DIR)/$(BINARY_NAME)-linux-armv7 ./$(CMD_DIR) - GOOS=darwin GOARCH=arm64 $(GO) build -ldflags "$(LDFLAGS)" -o $(BUILD_DIR)/$(BINARY_NAME)-darwin-arm64 ./$(CMD_DIR) - GOOS=windows GOARCH=amd64 $(GO) build -ldflags "$(LDFLAGS)" -o $(BUILD_DIR)/$(BINARY_NAME)-windows-amd64.exe ./$(CMD_DIR) - GOOS=netbsd GOARCH=amd64 $(GO) build -ldflags "$(LDFLAGS)" -o $(BUILD_DIR)/$(BINARY_NAME)-netbsd-amd64 ./$(CMD_DIR) - GOOS=netbsd GOARCH=arm64 $(GO) build -ldflags "$(LDFLAGS)" -o $(BUILD_DIR)/$(BINARY_NAME)-netbsd-arm64 ./$(CMD_DIR) + GOOS=linux GOARCH=arm GOARM=7 $(GO) build $(GOFLAGS) -ldflags "$(LDFLAGS)" -o $(BUILD_DIR)/$(BINARY_NAME)-linux-armv7 ./$(CMD_DIR) + GOOS=darwin GOARCH=arm64 $(GO) build $(GOFLAGS) -ldflags "$(LDFLAGS)" -o $(BUILD_DIR)/$(BINARY_NAME)-darwin-arm64 ./$(CMD_DIR) + GOOS=windows GOARCH=amd64 $(GO) build $(GOFLAGS) -ldflags "$(LDFLAGS)" -o $(BUILD_DIR)/$(BINARY_NAME)-windows-amd64.exe ./$(CMD_DIR) + GOOS=netbsd GOARCH=amd64 $(GO) build $(GOFLAGS) -ldflags "$(LDFLAGS)" -o $(BUILD_DIR)/$(BINARY_NAME)-netbsd-amd64 ./$(CMD_DIR) + GOOS=netbsd GOARCH=arm64 $(GO) build $(GOFLAGS) -ldflags "$(LDFLAGS)" -o $(BUILD_DIR)/$(BINARY_NAME)-netbsd-arm64 ./$(CMD_DIR) @echo "All builds complete" ## install: Install picoclaw to system and copy builtin skills @@ -221,13 +227,13 @@ clean: ## vet: Run go vet for static analysis vet: generate - @packages="$$(go list ./...)" && \ - $(GO) vet $$(printf '%s\n' "$$packages" | grep -v '^github.com/sipeed/picoclaw/web/') + @packages="$$($(GO) list $(GOFLAGS) ./...)" && \ + $(GO) vet $(GOFLAGS) $$(printf '%s\n' "$$packages" | grep -v '^github.com/sipeed/picoclaw/web/') @cd web/backend && $(WEB_GO) vet ./... 
## test: Test Go code test: generate - @$(GO) test $$(go list ./... | grep -v github.com/sipeed/picoclaw/web/) + @$(GO) test $(GOFLAGS) $$($(GO) list $(GOFLAGS) ./... | grep -v github.com/sipeed/picoclaw/web/) @cd web && make test ## fmt: Format Go code @@ -236,11 +242,11 @@ fmt: ## lint: Run linters lint: - @$(GOLANGCI_LINT) run + @CGO_ENABLED=0 $(GOLANGCI_LINT) run --build-tags $(GO_BUILD_TAGS) ## fix: Fix linting issues fix: - @$(GOLANGCI_LINT) run --fix + @CGO_ENABLED=0 $(GOLANGCI_LINT) run --fix --build-tags $(GO_BUILD_TAGS) ## deps: Download dependencies deps: diff --git a/README.fr.md b/README.fr.md index 301456262..a4fa628c9 100644 --- a/README.fr.md +++ b/README.fr.md @@ -524,7 +524,7 @@ Connectez PicoClaw au réseau social des Agents simplement en envoyant un seul m | Commande | Description | | ------------------------- | ---------------------------------------- | | `picoclaw onboard` | Initialiser la config & le workspace | -| `picoclaw onboard weixin` | Connecter un compte WeChat via QR | +| `picoclaw auth weixin` | Connecter un compte WeChat via QR | | `picoclaw agent -m "..."` | Chatter avec l'agent | | `picoclaw agent` | Mode chat interactif | | `picoclaw gateway` | Démarrer le gateway | diff --git a/README.id.md b/README.id.md index 6b7025ffd..6d62dcb9b 100644 --- a/README.id.md +++ b/README.id.md @@ -520,7 +520,7 @@ Hubungkan PicoClaw ke Jaringan Sosial Agent hanya dengan mengirim satu pesan mel | Perintah | Deskripsi | | -------------------------- | -------------------------------- | | `picoclaw onboard` | Inisialisasi konfigurasi & workspace | -| `picoclaw onboard weixin` | Hubungkan akun WeChat via QR | +| `picoclaw auth weixin` | Hubungkan akun WeChat via QR | | `picoclaw agent -m "..."` | Chat dengan agent | | `picoclaw agent` | Mode chat interaktif | | `picoclaw gateway` | Mulai gateway | diff --git a/README.it.md b/README.it.md index dae541a17..1ed73ee54 100644 --- a/README.it.md +++ b/README.it.md @@ -520,7 +520,7 @@ Connetti PicoClaw al 
Social Network degli Agent semplicemente inviando un singol | Comando | Descrizione | | ------------------------- | ---------------------------------- | | `picoclaw onboard` | Inizializza config & workspace | -| `picoclaw onboard weixin` | Connetti account WeChat tramite QR | +| `picoclaw auth weixin` | Connetti account WeChat tramite QR | | `picoclaw agent -m "..."` | Chatta con l'agent | | `picoclaw agent` | Modalità chat interattiva | | `picoclaw gateway` | Avvia il gateway | diff --git a/README.ja.md b/README.ja.md index 3096d4022..9165986ba 100644 --- a/README.ja.md +++ b/README.ja.md @@ -520,7 +520,7 @@ CLI または統合チャットアプリからメッセージを 1 つ送るだ | コマンド | 説明 | | ------------------------- | ------------------------------ | | `picoclaw onboard` | 設定&ワークスペースの初期化 | -| `picoclaw onboard weixin` | WeChat アカウントを QR で接続 | +| `picoclaw auth weixin` | WeChat アカウントを QR で接続 | | `picoclaw agent -m "..."` | Agent とチャット | | `picoclaw agent` | インタラクティブチャットモード | | `picoclaw gateway` | Gateway を起動 | diff --git a/README.md b/README.md index 3ddce3a3f..f627e261e 100644 --- a/README.md +++ b/README.md @@ -526,7 +526,7 @@ Connect PicoClaw to the Agent Social Network simply by sending a single message | Command | Description | | ------------------------- | -------------------------------- | | `picoclaw onboard` | Initialize config & workspace | -| `picoclaw onboard weixin` | Connect WeChat account via QR | +| `picoclaw auth weixin` | Connect WeChat account via QR | | `picoclaw agent -m "..."` | Chat with the agent | | `picoclaw agent` | Interactive chat mode | | `picoclaw gateway` | Start the gateway | diff --git a/README.pt-br.md b/README.pt-br.md index 3c039f190..d4b303e24 100644 --- a/README.pt-br.md +++ b/README.pt-br.md @@ -520,7 +520,7 @@ Conecte o PicoClaw à Rede Social de Agents simplesmente enviando uma única men | Comando | Descrição | | ------------------------- | -------------------------------------- | | `picoclaw onboard` | Inicializar config e workspace | -| `picoclaw onboard 
weixin` | Conectar conta WeChat via QR | +| `picoclaw auth weixin` | Conectar conta WeChat via QR | | `picoclaw agent -m "..."` | Conversar com o agent | | `picoclaw agent` | Modo de chat interativo | | `picoclaw gateway` | Iniciar o gateway | diff --git a/README.vi.md b/README.vi.md index b63fd4ef7..ceeb02b63 100644 --- a/README.vi.md +++ b/README.vi.md @@ -520,7 +520,7 @@ Kết nối PicoClaw với Mạng xã hội Agent chỉ bằng cách gửi một | Lệnh | Mô tả | | ------------------------- | ---------------------------------------- | | `picoclaw onboard` | Khởi tạo cấu hình & workspace | -| `picoclaw onboard weixin` | Kết nối tài khoản WeChat qua QR | +| `picoclaw auth weixin` | Kết nối tài khoản WeChat qua QR | | `picoclaw agent -m "..."` | Trò chuyện với agent | | `picoclaw agent` | Chế độ trò chuyện tương tác | | `picoclaw gateway` | Khởi động gateway | diff --git a/README.zh.md b/README.zh.md index de96e5164..93abf89d3 100644 --- a/README.zh.md +++ b/README.zh.md @@ -520,7 +520,7 @@ PicoClaw 原生支持 [MCP](https://modelcontextprotocol.io/) — 连接任意 M | 命令 | 说明 | | ------------------------- | ---------------------- | | `picoclaw onboard` | 初始化配置与工作区 | -| `picoclaw onboard weixin` | 扫码连接微信个人号 | +| `picoclaw auth weixin` | 扫码连接微信个人号 | | `picoclaw agent -m "..."` | 与 Agent 对话 | | `picoclaw agent` | 交互式对话模式 | | `picoclaw gateway` | 启动网关 | diff --git a/cmd/picoclaw/internal/auth/command.go b/cmd/picoclaw/internal/auth/command.go index 12a0a3a8c..9de083d8d 100644 --- a/cmd/picoclaw/internal/auth/command.go +++ b/cmd/picoclaw/internal/auth/command.go @@ -16,6 +16,8 @@ func NewAuthCommand() *cobra.Command { newLogoutCommand(), newStatusCommand(), newModelsCommand(), + newWeixinCommand(), + newWeComCommand(), ) return cmd diff --git a/cmd/picoclaw/internal/auth/command_test.go b/cmd/picoclaw/internal/auth/command_test.go index 48dc704dd..3c7f2d3d6 100644 --- a/cmd/picoclaw/internal/auth/command_test.go +++ b/cmd/picoclaw/internal/auth/command_test.go @@ -32,6 +32,8 @@ func 
TestNewAuthCommand(t *testing.T) { "logout", "status", "models", + "weixin", + "wecom", } subcommands := cmd.Commands() diff --git a/cmd/picoclaw/internal/auth/wecom.go b/cmd/picoclaw/internal/auth/wecom.go new file mode 100644 index 000000000..8261f5f80 --- /dev/null +++ b/cmd/picoclaw/internal/auth/wecom.go @@ -0,0 +1,407 @@ +package auth + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "io" + "net/http" + "net/url" + "os" + "runtime" + "strconv" + "strings" + "time" + + "github.com/mdp/qrterminal/v3" + "github.com/spf13/cobra" + + "github.com/sipeed/picoclaw/cmd/picoclaw/internal" + "github.com/sipeed/picoclaw/pkg/config" +) + +const ( + wecomQRSourceID = "picoclaw" + wecomQRGenerateEndpoint = "https://work.weixin.qq.com/ai/qc/generate" + wecomQRQueryEndpoint = "https://work.weixin.qq.com/ai/qc/query_result" + wecomQRPageEndpoint = "https://work.weixin.qq.com/ai/qc/gen" + wecomQRHTTPTimeout = 15 * time.Second + wecomQRPollInterval = 3 * time.Second + wecomQRPollTimeout = 5 * time.Minute + wecomDefaultWebSocketURL = "wss://openws.work.weixin.qq.com" +) + +type wecomQRScanner func(context.Context, wecomQRFlowOptions) (wecomQRBotInfo, error) + +type wecomQRFlowOptions struct { + HTTPClient *http.Client + GenerateURL string + QueryURL string + QRCodePageURL string + SourceID string + PollInterval time.Duration + PollTimeout time.Duration + Writer io.Writer +} + +type wecomQRBotInfo struct { + BotID string + Secret string +} + +type wecomQRSession struct { + SCode string + AuthURL string +} + +type wecomQRGenerateResponse struct { + ErrCode int `json:"errcode,omitempty"` + ErrMsg string `json:"errmsg,omitempty"` + Data struct { + SCode string `json:"scode"` + AuthURL string `json:"auth_url"` + } `json:"data"` +} + +type wecomQRQueryResponse struct { + ErrCode int `json:"errcode,omitempty"` + ErrMsg string `json:"errmsg,omitempty"` + Data struct { + Status string `json:"status"` + BotInfo struct { + BotID string `json:"botid"` + Secret string 
`json:"secret"` + } `json:"bot_info"` + } `json:"data"` +} + +func newWeComCommand() *cobra.Command { + var timeout time.Duration + + cmd := &cobra.Command{ + Use: "wecom", + Short: "Scan a WeCom QR code and configure channels.wecom", + Args: cobra.NoArgs, + RunE: func(_ *cobra.Command, _ []string) error { + return authWeComCmd(timeout) + }, + } + + cmd.Flags().DurationVar(&timeout, "timeout", wecomQRPollTimeout, "How long to wait for QR confirmation") + + return cmd +} + +func authWeComCmd(timeout time.Duration) error { + return authWeComCmdWithScanner(context.Background(), os.Stdout, timeout, scanWeComQRCodeInteractive) +} + +func authWeComCmdWithScanner( + ctx context.Context, + writer io.Writer, + timeout time.Duration, + scanner wecomQRScanner, +) error { + if scanner == nil { + return fmt.Errorf("wecom QR scanner is nil") + } + if writer == nil { + writer = os.Stdout + } + + cfg, err := internal.LoadConfig() + if err != nil { + return fmt.Errorf("failed to load config: %w", err) + } + + opts := defaultWeComQRFlowOptions(timeout) + opts.Writer = writer + + botInfo, err := scanner(ctx, opts) + if err != nil { + return err + } + + applyWeComAuthResult(cfg, botInfo) + + if saveErr := config.SaveConfig(internal.GetConfigPath(), cfg); saveErr != nil { + return fmt.Errorf("failed to save config: %w", saveErr) + } + + fmt.Fprintln(writer) + fmt.Fprintln(writer, "WeCom connected.") + fmt.Fprintf(writer, "Bot ID: %s\n", botInfo.BotID) + fmt.Fprintf(writer, "Config: %s\n", internal.GetConfigPath()) + + return nil +} + +func defaultWeComQRFlowOptions(timeout time.Duration) wecomQRFlowOptions { + if timeout <= 0 { + timeout = wecomQRPollTimeout + } + + return wecomQRFlowOptions{ + HTTPClient: &http.Client{Timeout: wecomQRHTTPTimeout}, + GenerateURL: wecomQRGenerateEndpoint, + QueryURL: wecomQRQueryEndpoint, + QRCodePageURL: wecomQRPageEndpoint, + SourceID: wecomQRSourceID, + PollInterval: wecomQRPollInterval, + PollTimeout: timeout, + Writer: os.Stdout, + } +} + +func 
applyWeComAuthResult(cfg *config.Config, botInfo wecomQRBotInfo) { + cfg.Channels.WeCom.Enabled = true + cfg.Channels.WeCom.BotID = botInfo.BotID + cfg.Channels.WeCom.SetSecret(botInfo.Secret) + if strings.TrimSpace(cfg.Channels.WeCom.WebSocketURL) == "" { + cfg.Channels.WeCom.WebSocketURL = wecomDefaultWebSocketURL + } +} + +func scanWeComQRCodeInteractive(ctx context.Context, opts wecomQRFlowOptions) (wecomQRBotInfo, error) { + opts = normalizeWeComQRFlowOptions(opts) + + fmt.Fprintln(opts.Writer, "Requesting WeCom QR code...") + + session, err := fetchWeComQRCode(ctx, opts) + if err != nil { + return wecomQRBotInfo{}, err + } + + fmt.Fprintln(opts.Writer) + fmt.Fprintln(opts.Writer, "=======================================================") + fmt.Fprintln(opts.Writer, "Please scan the following QR code with WeCom:") + fmt.Fprintln(opts.Writer, "=======================================================") + fmt.Fprintln(opts.Writer) + + qrterminal.GenerateWithConfig(session.AuthURL, qrterminal.Config{ + Level: qrterminal.L, + Writer: opts.Writer, + HalfBlocks: true, + }) + + pageURL, err := buildWeComQRCodePageURL(opts.QRCodePageURL, opts.SourceID, session.SCode) + if err != nil { + return wecomQRBotInfo{}, err + } + + fmt.Fprintln(opts.Writer) + fmt.Fprintf(opts.Writer, "QR Code Link: %s\n", pageURL) + fmt.Fprintln(opts.Writer) + fmt.Fprintln(opts.Writer, "Waiting for scan...") + + return pollWeComQRCodeResult(ctx, opts, session.SCode) +} + +func normalizeWeComQRFlowOptions(opts wecomQRFlowOptions) wecomQRFlowOptions { + if opts.HTTPClient == nil { + opts.HTTPClient = &http.Client{Timeout: wecomQRHTTPTimeout} + } + if strings.TrimSpace(opts.GenerateURL) == "" { + opts.GenerateURL = wecomQRGenerateEndpoint + } + if strings.TrimSpace(opts.QueryURL) == "" { + opts.QueryURL = wecomQRQueryEndpoint + } + if strings.TrimSpace(opts.QRCodePageURL) == "" { + opts.QRCodePageURL = wecomQRPageEndpoint + } + if strings.TrimSpace(opts.SourceID) == "" { + opts.SourceID = 
wecomQRSourceID + } + if opts.PollInterval <= 0 { + opts.PollInterval = wecomQRPollInterval + } + if opts.PollTimeout <= 0 { + opts.PollTimeout = wecomQRPollTimeout + } + if opts.Writer == nil { + opts.Writer = os.Stdout + } + + return opts +} + +func fetchWeComQRCode(ctx context.Context, opts wecomQRFlowOptions) (wecomQRSession, error) { + generateURL, err := buildWeComQRGenerateURL(opts.GenerateURL, opts.SourceID, wecomPlatformCode()) + if err != nil { + return wecomQRSession{}, err + } + + var resp wecomQRGenerateResponse + if err := doWeComJSONGet(ctx, opts.HTTPClient, generateURL, &resp); err != nil { + return wecomQRSession{}, fmt.Errorf("failed to get WeCom QR code: %w", err) + } + if resp.ErrCode != 0 { + return wecomQRSession{}, fmt.Errorf( + "failed to get WeCom QR code: errcode=%d errmsg=%s", + resp.ErrCode, + resp.ErrMsg, + ) + } + if resp.Data.SCode == "" || resp.Data.AuthURL == "" { + return wecomQRSession{}, fmt.Errorf("failed to get WeCom QR code: response missing scode or auth_url") + } + + return wecomQRSession{ + SCode: resp.Data.SCode, + AuthURL: resp.Data.AuthURL, + }, nil +} + +func pollWeComQRCodeResult(ctx context.Context, opts wecomQRFlowOptions, scode string) (wecomQRBotInfo, error) { + if strings.TrimSpace(scode) == "" { + return wecomQRBotInfo{}, fmt.Errorf("missing WeCom QR scode") + } + + timeoutCtx, cancel := context.WithTimeout(ctx, opts.PollTimeout) + defer cancel() + + var scannedPrinted bool + + for { + status, err := queryWeComQRCodeStatus(timeoutCtx, opts, scode) + if err != nil { + if errors.Is(err, context.DeadlineExceeded) || errors.Is(timeoutCtx.Err(), context.DeadlineExceeded) { + return wecomQRBotInfo{}, fmt.Errorf("WeCom QR scan timed out after %s", opts.PollTimeout) + } + return wecomQRBotInfo{}, err + } + + switch strings.ToLower(status.Data.Status) { + case "success": + if status.Data.BotInfo.BotID == "" || status.Data.BotInfo.Secret == "" { + return wecomQRBotInfo{}, fmt.Errorf("WeCom QR scan succeeded but bot 
credentials are missing") + } + return wecomQRBotInfo{ + BotID: status.Data.BotInfo.BotID, + Secret: status.Data.BotInfo.Secret, + }, nil + case "expired": + return wecomQRBotInfo{}, fmt.Errorf("WeCom QR code expired, please retry") + case "scaned", "scanned": + if !scannedPrinted { + fmt.Fprintln(opts.Writer, "QR code scanned. Confirm the login in WeCom.") + scannedPrinted = true + } + } + + select { + case <-timeoutCtx.Done(): + if errors.Is(timeoutCtx.Err(), context.DeadlineExceeded) { + return wecomQRBotInfo{}, fmt.Errorf("WeCom QR scan timed out after %s", opts.PollTimeout) + } + return wecomQRBotInfo{}, timeoutCtx.Err() + case <-time.After(opts.PollInterval): + } + } +} + +func queryWeComQRCodeStatus(ctx context.Context, opts wecomQRFlowOptions, scode string) (wecomQRQueryResponse, error) { + queryURL, err := buildWeComQRQueryURL(opts.QueryURL, scode) + if err != nil { + return wecomQRQueryResponse{}, err + } + + var resp wecomQRQueryResponse + if err := doWeComJSONGet(ctx, opts.HTTPClient, queryURL, &resp); err != nil { + return wecomQRQueryResponse{}, fmt.Errorf("failed to query WeCom QR result: %w", err) + } + if resp.ErrCode != 0 { + return wecomQRQueryResponse{}, fmt.Errorf( + "failed to query WeCom QR result: errcode=%d errmsg=%s", + resp.ErrCode, + resp.ErrMsg, + ) + } + + return resp, nil +} + +func buildWeComQRGenerateURL(baseURL, sourceID string, platformCode int) (string, error) { + u, err := url.Parse(baseURL) + if err != nil { + return "", fmt.Errorf("invalid WeCom QR generate URL: %w", err) + } + + query := u.Query() + query.Set("source", sourceID) + query.Set("sourceID", sourceID) + query.Set("plat", strconv.Itoa(platformCode)) + u.RawQuery = query.Encode() + + return u.String(), nil +} + +func buildWeComQRQueryURL(baseURL, scode string) (string, error) { + u, err := url.Parse(baseURL) + if err != nil { + return "", fmt.Errorf("invalid WeCom QR query URL: %w", err) + } + + query := u.Query() + query.Set("scode", scode) + u.RawQuery = 
query.Encode() + + return u.String(), nil +} + +func buildWeComQRCodePageURL(baseURL, sourceID, scode string) (string, error) { + u, err := url.Parse(baseURL) + if err != nil { + return "", fmt.Errorf("invalid WeCom QR page URL: %w", err) + } + + query := u.Query() + query.Set("source", sourceID) + query.Set("sourceID", sourceID) + query.Set("scode", scode) + u.RawQuery = query.Encode() + + return u.String(), nil +} + +func doWeComJSONGet(ctx context.Context, client *http.Client, targetURL string, out any) error { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, targetURL, nil) + if err != nil { + return err + } + + resp, err := client.Do(req) + if err != nil { + return err + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + body, readErr := io.ReadAll(io.LimitReader(resp.Body, 8192)) + if readErr != nil { + return fmt.Errorf("unexpected status %s", resp.Status) + } + return fmt.Errorf("unexpected status %s: %s", resp.Status, strings.TrimSpace(string(body))) + } + + if err := json.NewDecoder(resp.Body).Decode(out); err != nil { + return fmt.Errorf("decode JSON response: %w", err) + } + + return nil +} + +func wecomPlatformCode() int { + switch runtime.GOOS { + case "darwin": + return 1 + case "windows": + return 2 + case "linux": + return 3 + default: + return 0 + } +} diff --git a/cmd/picoclaw/internal/auth/wecom_test.go b/cmd/picoclaw/internal/auth/wecom_test.go new file mode 100644 index 000000000..c2a4624ae --- /dev/null +++ b/cmd/picoclaw/internal/auth/wecom_test.go @@ -0,0 +1,157 @@ +package auth + +import ( + "bytes" + "context" + "net/http" + "net/http/httptest" + "net/url" + "path/filepath" + "strconv" + "sync/atomic" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/sipeed/picoclaw/cmd/picoclaw/internal" + "github.com/sipeed/picoclaw/pkg/config" +) + +func TestNewWeComCommand(t *testing.T) { + cmd := newWeComCommand() + + require.NotNil(t, cmd) + 
assert.Equal(t, "wecom", cmd.Use) + assert.Equal(t, "Scan a WeCom QR code and configure channels.wecom", cmd.Short) + assert.NotNil(t, cmd.Flags().Lookup("timeout")) +} + +func TestBuildWeComQRGenerateURL(t *testing.T) { + rawURL, err := buildWeComQRGenerateURL("https://example.com/ai/qc/generate", wecomQRSourceID, 3) + require.NoError(t, err) + + parsed, err := url.Parse(rawURL) + require.NoError(t, err) + + assert.Equal(t, wecomQRSourceID, parsed.Query().Get("source")) + assert.Equal(t, wecomQRSourceID, parsed.Query().Get("sourceID")) + assert.Equal(t, "3", parsed.Query().Get("plat")) +} + +func TestBuildWeComQRCodePageURL(t *testing.T) { + rawURL, err := buildWeComQRCodePageURL("https://example.com/ai/qc/gen", wecomQRSourceID, "scode-1") + require.NoError(t, err) + + parsed, err := url.Parse(rawURL) + require.NoError(t, err) + + assert.Equal(t, wecomQRSourceID, parsed.Query().Get("source")) + assert.Equal(t, wecomQRSourceID, parsed.Query().Get("sourceID")) + assert.Equal(t, "scode-1", parsed.Query().Get("scode")) +} + +func TestFetchWeComQRCode(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, "/generate", r.URL.Path) + assert.Equal(t, wecomQRSourceID, r.URL.Query().Get("source")) + assert.Equal(t, wecomQRSourceID, r.URL.Query().Get("sourceID")) + assert.Equal(t, strconv.Itoa(wecomPlatformCode()), r.URL.Query().Get("plat")) + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"data":{"scode":"scode-1","auth_url":"https://example.com/qr"}}`)) + })) + defer server.Close() + + opts := normalizeWeComQRFlowOptions(wecomQRFlowOptions{ + HTTPClient: server.Client(), + GenerateURL: server.URL + "/generate", + Writer: bytes.NewBuffer(nil), + }) + + session, err := fetchWeComQRCode(context.Background(), opts) + require.NoError(t, err) + assert.Equal(t, "scode-1", session.SCode) + assert.Equal(t, "https://example.com/qr", session.AuthURL) +} + +func 
TestPollWeComQRCodeResult(t *testing.T) { + var calls atomic.Int32 + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + call := calls.Add(1) + assert.Equal(t, "/query", r.URL.Path) + assert.Equal(t, "scode-1", r.URL.Query().Get("scode")) + w.Header().Set("Content-Type", "application/json") + switch call { + case 1: + _, _ = w.Write([]byte(`{"data":{"status":"wait"}}`)) + case 2: + _, _ = w.Write([]byte(`{"data":{"status":"scaned"}}`)) + default: + _, _ = w.Write([]byte(`{"data":{"status":"success","bot_info":{"botid":"bot-1","secret":"secret-1"}}}`)) + } + })) + defer server.Close() + + var output bytes.Buffer + opts := normalizeWeComQRFlowOptions(wecomQRFlowOptions{ + HTTPClient: server.Client(), + QueryURL: server.URL + "/query", + PollInterval: time.Millisecond, + PollTimeout: time.Second, + Writer: &output, + }) + + botInfo, err := pollWeComQRCodeResult(context.Background(), opts, "scode-1") + require.NoError(t, err) + assert.Equal(t, "bot-1", botInfo.BotID) + assert.Equal(t, "secret-1", botInfo.Secret) + assert.Contains(t, output.String(), "QR code scanned. 
Confirm the login in WeCom.") +} + +func TestApplyWeComAuthResult(t *testing.T) { + cfg := config.DefaultConfig() + cfg.Channels.WeCom.WebSocketURL = "" + + applyWeComAuthResult(cfg, wecomQRBotInfo{ + BotID: "bot-1", + Secret: "secret-1", + }) + + assert.True(t, cfg.Channels.WeCom.Enabled) + assert.Equal(t, "bot-1", cfg.Channels.WeCom.BotID) + assert.Equal(t, "secret-1", cfg.Channels.WeCom.Secret()) + assert.Equal(t, wecomDefaultWebSocketURL, cfg.Channels.WeCom.WebSocketURL) +} + +func TestAuthWeComCmdWithScanner(t *testing.T) { + tmpDir := t.TempDir() + configPath := filepath.Join(tmpDir, "config.json") + + t.Setenv(config.EnvHome, tmpDir) + t.Setenv(config.EnvConfig, configPath) + + var output bytes.Buffer + err := authWeComCmdWithScanner( + context.Background(), + &output, + time.Second, + func(_ context.Context, opts wecomQRFlowOptions) (wecomQRBotInfo, error) { + assert.Equal(t, wecomQRSourceID, opts.SourceID) + return wecomQRBotInfo{ + BotID: "bot-1", + Secret: "secret-1", + }, nil + }, + ) + require.NoError(t, err) + + cfg, err := config.LoadConfig(internal.GetConfigPath()) + require.NoError(t, err) + assert.True(t, cfg.Channels.WeCom.Enabled) + assert.Equal(t, "bot-1", cfg.Channels.WeCom.BotID) + assert.Equal(t, "secret-1", cfg.Channels.WeCom.Secret()) + assert.Equal(t, wecomDefaultWebSocketURL, cfg.Channels.WeCom.WebSocketURL) + assert.Contains(t, output.String(), "WeCom connected.") +} diff --git a/cmd/picoclaw/internal/onboard/weixin.go b/cmd/picoclaw/internal/auth/weixin.go similarity index 98% rename from cmd/picoclaw/internal/onboard/weixin.go rename to cmd/picoclaw/internal/auth/weixin.go index 2e1c2ad75..948a81495 100644 --- a/cmd/picoclaw/internal/onboard/weixin.go +++ b/cmd/picoclaw/internal/auth/weixin.go @@ -1,4 +1,4 @@ -package onboard +package auth import ( "context" @@ -27,7 +27,7 @@ to authorize your account. On success, the bot token is saved to the picoclaw config so you can start the gateway immediately. 
Example: - picoclaw onboard weixin`, + picoclaw auth weixin`, RunE: func(cmd *cobra.Command, _ []string) error { return runWeixinOnboard(baseURL, proxy, time.Duration(timeout)*time.Second) }, diff --git a/cmd/picoclaw/internal/onboard/command.go b/cmd/picoclaw/internal/onboard/command.go index 1f94c6718..4be19b2a5 100644 --- a/cmd/picoclaw/internal/onboard/command.go +++ b/cmd/picoclaw/internal/onboard/command.go @@ -16,7 +16,7 @@ func NewOnboardCommand() *cobra.Command { cmd := &cobra.Command{ Use: "onboard", Aliases: []string{"o"}, - Short: "Initialize picoclaw configuration, workspace, and channel accounts", + Short: "Initialize picoclaw configuration and workspace", // Run without subcommands → original onboard flow Run: func(cmd *cobra.Command, args []string) { if len(args) == 0 { @@ -30,8 +30,5 @@ func NewOnboardCommand() *cobra.Command { cmd.Flags().BoolVar(&encrypt, "enc", false, "Enable credential encryption (generates SSH key and prompts for passphrase)") - // Channel onboarding subcommands - cmd.AddCommand(newWeixinCommand()) - return cmd } diff --git a/cmd/picoclaw/internal/onboard/command_test.go b/cmd/picoclaw/internal/onboard/command_test.go index 6b9fb6e95..56936190b 100644 --- a/cmd/picoclaw/internal/onboard/command_test.go +++ b/cmd/picoclaw/internal/onboard/command_test.go @@ -13,7 +13,7 @@ func TestNewOnboardCommand(t *testing.T) { require.NotNil(t, cmd) assert.Equal(t, "onboard", cmd.Use) - assert.Equal(t, "Initialize picoclaw configuration, workspace, and channel accounts", cmd.Short) + assert.Equal(t, "Initialize picoclaw configuration and workspace", cmd.Short) assert.Len(t, cmd.Aliases, 1) assert.True(t, cmd.HasAlias("o")) @@ -28,6 +28,5 @@ func TestNewOnboardCommand(t *testing.T) { encFlag := cmd.Flags().Lookup("enc") require.NotNil(t, encFlag, "expected --enc flag to be registered") assert.Equal(t, "false", encFlag.DefValue, "--enc should default to false") - assert.True(t, cmd.HasSubCommands()) - assert.NotNil(t, cmd.Commands()) + 
assert.False(t, cmd.HasSubCommands()) } diff --git a/config/config.example.json b/config/config.example.json index 82aee2904..90b603d72 100644 --- a/config/config.example.json +++ b/config/config.example.json @@ -162,7 +162,9 @@ "enabled": true, "text": "Thinking... 💭" }, - "reasoning_channel_id": "" + "reasoning_channel_id": "", + "crypto_database_path": "", + "crypto_passphrase": "YOUR_MATRIX_CRYPTO_PICKLE_KEY" }, "line": { "enabled": false, @@ -182,39 +184,13 @@ "reasoning_channel_id": "" }, "wecom": { - "_comment": "WeCom Bot - Easier setup, supports group chats", - "enabled": false, - "token": "YOUR_TOKEN", - "encoding_aes_key": "YOUR_43_CHAR_ENCODING_AES_KEY", - "webhook_url": "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=YOUR_KEY", - "webhook_path": "/webhook/wecom", - "allow_from": [], - "reply_timeout": 5, - "reasoning_channel_id": "" - }, - "wecom_app": { - "_comment": "WeCom App (自建应用) - More features, proactive messaging, private chat only.", - "enabled": false, - "corp_id": "YOUR_CORP_ID", - "corp_secret": "YOUR_CORP_SECRET", - "agent_id": 1000002, - "token": "YOUR_TOKEN", - "encoding_aes_key": "YOUR_43_CHAR_ENCODING_AES_KEY", - "webhook_path": "/webhook/wecom-app", - "allow_from": [], - "reply_timeout": 5, - "reasoning_channel_id": "" - }, - "wecom_aibot": { - "_comment": "WeCom AI Bot (智能机器人) - Official WeCom AI Bot integration, supports proactive messaging and private chats.", + "_comment": "WeCom AI Bot over WebSocket.", "enabled": false, "bot_id": "YOUR_BOT_ID", "secret": "YOUR_SECRET", - "token": "YOUR_TOKEN", - "encoding_aes_key": "YOUR_43_CHAR_ENCODING_AES_KEY", - "webhook_path": "/webhook/wecom-aibot", - "max_steps": 10, - "welcome_message": "Hello! I'm your AI assistant. 
How can I help you today?", + "websocket_url": "wss://openws.work.weixin.qq.com", + "send_thinking_message": true, + "allow_from": [], "reasoning_channel_id": "" }, "pico": { diff --git a/docs/channels/matrix/README.md b/docs/channels/matrix/README.md index 2ed19245a..dd4b45eba 100644 --- a/docs/channels/matrix/README.md +++ b/docs/channels/matrix/README.md @@ -25,7 +25,9 @@ Add this to `config.json`: "text": "Thinking..." }, "reasoning_channel_id": "", - "message_format": "richtext" + "message_format": "richtext", + "crypto_database_path": "", + "crypto_passphrase": "YOUR_MATRIX_CRYPTO_PICKLE_KEY" } } } @@ -46,6 +48,8 @@ Add this to `config.json`: | placeholder | object | No | Placeholder message config | | reasoning_channel_id | string | No | Target channel for reasoning output | | message_format | string | No | Output format: `"richtext"` (default) renders markdown as HTML; `"plain"` sends plain text only | +| crypto_database_path | string | No | Path to store the crypto database (uses workspace path `~/.picoclaw/workspace` if empty) | +| crypto_passphrase | string | No | Serialization key for encrypting session keys in the database; must remain unchanged once set | ## 3. Currently Supported @@ -58,6 +62,7 @@ Add this to `config.json`: - Typing state (`m.typing`) - Placeholder message + final reply replacement - Auto-join invited rooms (can be disabled) +- End-to-end encryption (E2EE) support for encrypted messages ## 4. TODO diff --git a/docs/channels/matrix/README.zh.md b/docs/channels/matrix/README.zh.md index 8db3e4383..cd68a057e 100644 --- a/docs/channels/matrix/README.zh.md +++ b/docs/channels/matrix/README.zh.md @@ -24,7 +24,10 @@ "enabled": true, "text": "Thinking... 
💭" }, - "reasoning_channel_id": "" + "reasoning_channel_id": "", + "message_format": "richtext", + "crypto_database_path": "", + "crypto_passphrase": "YOUR_MATRIX_CRYPTO_PICKLE_KEY" } } } @@ -45,6 +48,8 @@ | placeholder | object | 否 | 占位消息配置 | | reasoning_channel_id | string | 否 | 思维链输出目标通道 | | message_format | string | 否 | 消息格式:`richtext`(富文本)或 `plain`(纯文本) | +| crypto_database_path | string | 否 | 加密数据库存储路径(为空时使用工作空间路径 `~/.picoclaw/workspace`) | +| crypto_passphrase | string | 否 | 加密数据库中 session key 的序列化密钥;设置后不能更改 | ## 3. 当前支持 @@ -56,6 +61,7 @@ - Typing 状态(`m.typing`) - 占位消息(`Thinking... 💭`)+ 最终回复替换 - 自动加入邀请房间(可关闭) +- 端对端加密(E2EE)消息支持 ## 4. TODO diff --git a/docs/channels/wecom/README.md b/docs/channels/wecom/README.md new file mode 100644 index 000000000..ecdfbc47b --- /dev/null +++ b/docs/channels/wecom/README.md @@ -0,0 +1,104 @@ +> Back to [README](../../../README.md) + +# WeCom + +PicoClaw now exposes WeCom as a single `channels.wecom` channel built on the official WeCom AI Bot WebSocket API. +This replaces the legacy `wecom`, `wecom_app`, and `wecom_aibot` split with one configuration model. + +## What This Channel Supports + +- Direct chat and group chat delivery +- Channel-side streaming replies over WeCom's AI Bot protocol +- Incoming text, voice, image, file, video, and mixed messages +- Outbound text and media replies (`image`, `file`, `voice`, `video`) +- QR-based CLI onboarding with `picoclaw auth wecom` +- Shared allowlist and `reasoning_channel_id` routing + +> No public webhook callback URL is required for this channel. PicoClaw opens an outbound WebSocket connection to WeCom. + +## Quick Start + +### Option 1: QR Login From CLI + +Run: + +```bash +picoclaw auth wecom +``` + +The command prints a QR code in the terminal, waits for confirmation in WeCom, and then writes the resulting +`bot_id` and `secret` into `channels.wecom`. 
+ +Use `--timeout` if you want to wait longer: + +```bash +picoclaw auth wecom --timeout 10m +``` + +### Option 2: Configure Manually + +```json +{ + "channels": { + "wecom": { + "enabled": true, + "bot_id": "YOUR_BOT_ID", + "secret": "YOUR_SECRET", + "websocket_url": "wss://openws.work.weixin.qq.com", + "send_thinking_message": true, + "allow_from": [], + "reasoning_channel_id": "" + } + } +} +``` + +## Configuration + +| Field | Type | Required | Description | +| ----- | ---- | -------- | ----------- | +| `enabled` | bool | No | Enables the WeCom channel. | +| `bot_id` | string | Yes | WeCom AI Bot identifier. Required when the channel is enabled. | +| `secret` | string | Yes | WeCom AI Bot secret. Required when the channel is enabled. | +| `websocket_url` | string | No | WebSocket endpoint. Defaults to `wss://openws.work.weixin.qq.com`. | +| `send_thinking_message` | bool | No | Sends an initial `Processing...` chunk before the final streamed reply. Defaults to `true`. | +| `allow_from` | array | No | Sender allowlist. Empty means allow all senders. | +| `reasoning_channel_id` | string | No | Optional destination for reasoning/thinking output. | + +## Runtime Behavior + +- PicoClaw keeps the active WeCom turn so normal replies can continue the same stream when possible. +- If streaming is no longer available, replies fall back to active push delivery to the resolved chat route. +- Incoming media is downloaded into the media store before being handed to the agent. +- Outbound media is uploaded to WeCom in temporary chunks and then sent as a regular media message. + +## Migration Notes + +The old multi-channel WeCom model has been removed. Migrate existing configuration as follows. + +| Previous config | Now | +| --------------- | --- | +| `channels.wecom` webhook bot | Replace with `channels.wecom` using `bot_id` + `secret`. | +| `channels.wecom_app` | Remove it and use `channels.wecom`. | +| `channels.wecom_aibot` | Move the config to `channels.wecom`. 
| +| `token`, `encoding_aes_key`, `webhook_url`, `webhook_path` | No longer used by the WeCom channel. | +| `corp_id`, `corp_secret`, `agent_id` | No longer used by the WeCom channel. | +| `welcome_message`, `processing_message`, `max_steps` under WeCom | No longer part of the WeCom channel config. | + +## Troubleshooting + +### `picoclaw auth wecom` times out + +- Re-run with a larger `--timeout`. +- Make sure the QR code was confirmed inside WeCom, not only scanned. + +### WebSocket connection fails + +- Verify `bot_id` and `secret`. +- Confirm the host can reach `wss://openws.work.weixin.qq.com`. + +### Replies do not arrive + +- Check whether `allow_from` blocks the sender. +- Check launcher or startup validation for missing `channels.wecom.bot_id` / `channels.wecom.secret`. + diff --git a/docs/channels/wecom/README.zh.md b/docs/channels/wecom/README.zh.md new file mode 100644 index 000000000..6b4a5e495 --- /dev/null +++ b/docs/channels/wecom/README.zh.md @@ -0,0 +1,104 @@ +> 返回 [README](../../../README.zh.md) + +# 企业微信 + +PicoClaw 现在将企业微信统一为一个 `channels.wecom` 渠道,并基于企业微信官方 AI Bot WebSocket 协议实现。 +这取代了旧的 `wecom`、`wecom_app`、`wecom_aibot` 三套配置模型。 + +## 当前渠道能力 + +- 支持私聊和群聊 +- 支持企业微信侧流式回复 +- 支持接收文本、语音、图片、文件、视频和 mixed 消息 +- 支持发送文本与媒体消息(`image`、`file`、`voice`、`video`) +- 支持通过 `picoclaw auth wecom` 扫码写入配置 +- 支持统一白名单与 `reasoning_channel_id` + +> 这个渠道不再需要公网 webhook 回调地址。PicoClaw 会主动向企业微信发起 WebSocket 连接。 + +## 快速开始 + +### 方式 1:命令行扫码登录 + +运行: + +```bash +picoclaw auth wecom +``` + +该命令会在终端打印二维码,等待你在企业微信中确认,然后把生成的 `bot_id` 和 `secret` 写入 +`channels.wecom`。 + +如果需要更长等待时间,可以加 `--timeout`: + +```bash +picoclaw auth wecom --timeout 10m +``` + +### 方式 2:手动配置 + +```json +{ + "channels": { + "wecom": { + "enabled": true, + "bot_id": "YOUR_BOT_ID", + "secret": "YOUR_SECRET", + "websocket_url": "wss://openws.work.weixin.qq.com", + "send_thinking_message": true, + "allow_from": [], + "reasoning_channel_id": "" + } + } +} +``` + +## 配置字段 + +| 字段 | 类型 | 必填 | 说明 | +| ---- | ---- | ---- 
| ---- | +| `enabled` | bool | 否 | 是否启用企业微信渠道。 | +| `bot_id` | string | 是 | 企业微信 AI Bot 标识。渠道启用时必填。 | +| `secret` | string | 是 | 企业微信 AI Bot 密钥。渠道启用时必填。 | +| `websocket_url` | string | 否 | WebSocket 地址,默认 `wss://openws.work.weixin.qq.com`。 | +| `send_thinking_message` | bool | 否 | 是否在流式最终回复前先发送一段 `Processing...` 开场消息,默认 `true`。 | +| `allow_from` | array | 否 | 发送者白名单;空数组表示允许所有发送者。 | +| `reasoning_channel_id` | string | 否 | 可选的 reasoning/thinking 输出目标。 | + +## 运行时行为 + +- PicoClaw 会保留当前会话对应的企业微信 turn,优先继续同一个流式回复。 +- 如果流式上下文已经失效,回复会自动回退到主动推送消息。 +- 收到的媒体会先下载到 media store,再交给 Agent 处理。 +- 发出的媒体会先按分片上传到企业微信,再作为普通媒体消息发送。 + +## 迁移说明 + +旧的多通道企业微信模型已被移除,请按下表迁移现有配置。 + +| 旧配置 | 现在怎么做 | +| ------ | ---------- | +| `channels.wecom` webhook 机器人 | 改为使用 `bot_id` + `secret` 的 `channels.wecom`。 | +| `channels.wecom_app` | 删除,统一迁移到 `channels.wecom`。 | +| `channels.wecom_aibot` | 配置迁移到 `channels.wecom`。 | +| `token`、`encoding_aes_key`、`webhook_url`、`webhook_path` | 企业微信渠道不再使用这些字段。 | +| `corp_id`、`corp_secret`、`agent_id` | 企业微信渠道不再使用这些字段。 | +| 企业微信下的 `welcome_message`、`processing_message`、`max_steps` | 不再属于企业微信渠道配置。 | + +## 常见问题 + +### `picoclaw auth wecom` 超时 + +- 用更大的 `--timeout` 重新执行。 +- 确认是在企业微信里完成了确认,而不只是扫描二维码。 + +### WebSocket 连接失败 + +- 检查 `bot_id` 和 `secret` 是否正确。 +- 确认运行环境可以访问 `wss://openws.work.weixin.qq.com`。 + +### 消息没有回到企业微信 + +- 检查 `allow_from` 是否拦截了发送者。 +- 检查启动日志或 launcher 校验,确认 `channels.wecom.bot_id` / `channels.wecom.secret` 已填写。 + diff --git a/docs/channels/weixin/README.md b/docs/channels/weixin/README.md index 22687fec4..0c51ff3c5 100644 --- a/docs/channels/weixin/README.md +++ b/docs/channels/weixin/README.md @@ -7,7 +7,7 @@ PicoClaw supports connecting to your personal WeChat account using the official The easiest way to set up the Weixin channel is using the interactive onboarding command: ```bash -picoclaw onboard weixin +picoclaw auth weixin ``` This command will: diff --git a/docs/channels/weixin/README.zh.md b/docs/channels/weixin/README.zh.md index 
d5e6f0a49..0f1181878 100644 --- a/docs/channels/weixin/README.zh.md +++ b/docs/channels/weixin/README.zh.md @@ -7,7 +7,7 @@ PicoClaw 支持使用腾讯官方 iLink API 连接您的个人微信账号。 最简单的方法是使用交互式 onboarding 命令进行一键激活: ```bash -picoclaw onboard weixin +picoclaw auth weixin ``` 该命令将: diff --git a/docs/chat-apps.md b/docs/chat-apps.md index d300f5544..3d01994ff 100644 --- a/docs/chat-apps.md +++ b/docs/chat-apps.md @@ -6,7 +6,7 @@ Talk to your picoclaw through Telegram, Discord, WhatsApp, Matrix, QQ, DingTalk, LINE, WeCom, Feishu, Slack, IRC, OneBot, MaixCam, or Pico (native protocol) -> **Note**: All webhook-based channels (LINE, WeCom, etc.) are served on a single shared Gateway HTTP server (`gateway.host`:`gateway.port`, default `127.0.0.1:18790`). There are no per-channel ports to configure. Note: Feishu uses WebSocket/SDK mode and does not use the shared HTTP webhook server. +> **Note**: Channels that rely on HTTP callbacks share a single Gateway HTTP server (`gateway.host`:`gateway.port`, default `127.0.0.1:18790`). Socket/stream-based channels such as Feishu, DingTalk, and WeCom do not rely on the shared webhook server for inbound delivery. 
| Channel | Difficulty | Description | Documentation | | -------------------- | ------------------ | ----------------------------------------------------- | ---------------------------------------------------------------------------------------------------------------- | @@ -19,7 +19,7 @@ Talk to your picoclaw through Telegram, Discord, WhatsApp, Matrix, QQ, DingTalk, | **QQ** | ⭐⭐ Medium | Official bot API, Chinese community | [Docs](channels/qq/README.md) | | **DingTalk** | ⭐⭐ Medium | Stream mode (no public IP needed), enterprise | [Docs](channels/dingtalk/README.md) | | **LINE** | ⭐⭐⭐ Advanced | HTTPS Webhook required | [Docs](channels/line/README.md) | -| **WeCom (企业微信)** | ⭐⭐⭐ Advanced | Group Bot (Webhook), custom App (API), AI Bot | [Bot](channels/wecom/wecom_bot/README.md) / [App](channels/wecom/wecom_app/README.md) / [AI Bot](channels/wecom/wecom_aibot/README.md) | +| **WeCom (企业微信)** | ⭐⭐⭐ Advanced | Official AI Bot over WebSocket, streaming + media | [Docs](channels/wecom/README.md) | | **Feishu (飞书)** | ⭐⭐⭐ Advanced | Enterprise collaboration, feature-rich | [Docs](channels/feishu/README.md) | | **IRC** | ⭐⭐ Medium | Server + TLS configuration | [Docs](#irc) | | **OneBot** | ⭐⭐ Medium | NapCat/Go-CQHTTP compatible, community ecosystem | [Docs](channels/onebot/README.md) | @@ -190,7 +190,7 @@ PicoClaw supports connecting to your personal WeChat account using the official Run the interactive QR login flow: ```bash -picoclaw onboard weixin +picoclaw auth weixin ``` Scan the printed QR code with your WeChat mobile app. On success, the token is saved to your config. @@ -380,102 +380,34 @@ picoclaw gateway
WeCom (企业微信) -PicoClaw supports three types of WeCom integration: +PicoClaw now exposes WeCom as a single AI Bot channel over WebSocket. +No public webhook callback URL is required. -**Option 1: WeCom Bot (Bot)** - Easier setup, supports group chats -**Option 2: WeCom App (Custom App)** - More features, proactive messaging, private chat only -**Option 3: WeCom AI Bot (AI Bot)** - Official AI Bot, streaming replies, supports group & private chat +See [WeCom Configuration Guide](channels/wecom/README.md) for the full configuration reference and migration notes. -See [WeCom AI Bot Configuration Guide](channels/wecom/wecom_aibot/README.md) for detailed setup instructions. +**Quick Setup - Recommended** -**Quick Setup - WeCom Bot:** +**1. Authenticate** -**1. Create a bot** +```bash +picoclaw auth wecom +``` -* Go to WeCom Admin Console → Group Chat → Add Group Bot -* Copy the webhook URL (format: `https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=xxx`) +This command shows a QR code, waits for approval in WeCom, and writes `bot_id` + `secret` into `channels.wecom`. -**2. Configure** +**2. Configure manually if needed** ```json { "channels": { "wecom": { "enabled": true, - "token": "YOUR_TOKEN", - "encoding_aes_key": "YOUR_ENCODING_AES_KEY", - "webhook_url": "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=YOUR_KEY", - "webhook_path": "/webhook/wecom", - "allow_from": [] - } - } -} -``` - -> WeCom webhook is served on the shared Gateway server (`gateway.host`:`gateway.port`, default `127.0.0.1:18790`). - -**Quick Setup - WeCom App:** - -**1. Create an app** - -* Go to WeCom Admin Console → App Management → Create App -* Copy **AgentId** and **Secret** -* Go to "My Company" page, copy **CorpID** - -**2. Configure receive message** - -* In App details, click "Receive Message" → "Set API" -* Set URL to `http://your-server:18790/webhook/wecom-app` -* Generate **Token** and **EncodingAESKey** - -**3. 
Configure** - -```json -{ - "channels": { - "wecom_app": { - "enabled": true, - "corp_id": "wwxxxxxxxxxxxxxxxx", - "corp_secret": "YOUR_CORP_SECRET", - "agent_id": 1000002, - "token": "YOUR_TOKEN", - "encoding_aes_key": "YOUR_ENCODING_AES_KEY", - "webhook_path": "/webhook/wecom-app", - "allow_from": [] - } - } -} -``` - -**4. Run** - -```bash -picoclaw gateway -``` - -> **Note**: WeCom webhook callbacks are served on the Gateway port (default 18790). Use a reverse proxy for HTTPS. - -**Quick Setup - WeCom AI Bot:** - -**1. Create an AI Bot** - -* Go to WeCom Admin Console → App Management → AI Bot -* In the AI Bot settings, configure callback URL: `http://your-server:18790/webhook/wecom-aibot` -* Copy **Token** and click "Random Generate" for **EncodingAESKey** - -**2. Configure** - -```json -{ - "channels": { - "wecom_aibot": { - "enabled": true, - "token": "YOUR_TOKEN", - "encoding_aes_key": "YOUR_43_CHAR_ENCODING_AES_KEY", - "webhook_path": "/webhook/wecom-aibot", + "bot_id": "YOUR_BOT_ID", + "secret": "YOUR_SECRET", + "websocket_url": "wss://openws.work.weixin.qq.com", + "send_thinking_message": true, "allow_from": [], - "welcome_message": "Hello! How can I help you?", - "processing_message": "⏳ Processing, please wait. The results will be sent shortly." + "reasoning_channel_id": "" } } } @@ -487,7 +419,7 @@ picoclaw gateway picoclaw gateway ``` -> **Note**: WeCom AI Bot uses streaming pull protocol — no reply timeout concerns. Long tasks (>30 seconds) automatically switch to `response_url` push delivery. +> Legacy `wecom_app` and `wecom_aibot` entries have been replaced by the unified `channels.wecom` config. 
diff --git a/docs/fr/chat-apps.md b/docs/fr/chat-apps.md index daff951f4..c36e002ff 100644 --- a/docs/fr/chat-apps.md +++ b/docs/fr/chat-apps.md @@ -179,7 +179,7 @@ PicoClaw prend en charge la connexion à votre compte WeChat personnel via l'API Lancez le flux de connexion interactif par QR code : ```bash -picoclaw onboard weixin +picoclaw auth weixin ``` Scannez le QR code affiché avec votre application WeChat mobile. Une fois connecté, le token est sauvegardé dans votre configuration. diff --git a/docs/ja/chat-apps.md b/docs/ja/chat-apps.md index 789c0125f..341dc4aba 100644 --- a/docs/ja/chat-apps.md +++ b/docs/ja/chat-apps.md @@ -184,7 +184,7 @@ PicoClaw は Tencent iLink 公式 API を使用して WeChat 個人アカウン インタラクティブな QR ログインフローを実行します: ```bash -picoclaw onboard weixin +picoclaw auth weixin ``` WeChat モバイルアプリで表示された QR コードをスキャンしてください。ログイン成功後、トークンが設定ファイルに保存されます。 diff --git a/docs/pt-br/chat-apps.md b/docs/pt-br/chat-apps.md index 4fa59b1b2..92fda329c 100644 --- a/docs/pt-br/chat-apps.md +++ b/docs/pt-br/chat-apps.md @@ -179,7 +179,7 @@ O PicoClaw suporta conexão com sua conta pessoal do WeChat usando a API oficial Execute o fluxo de login interativo por QR code: ```bash -picoclaw onboard weixin +picoclaw auth weixin ``` Escaneie o QR code exibido com seu aplicativo WeChat mobile. Após o login bem-sucedido, o token é salvo na sua configuração. diff --git a/docs/vi/chat-apps.md b/docs/vi/chat-apps.md index d907e5e91..5e2a81ccf 100644 --- a/docs/vi/chat-apps.md +++ b/docs/vi/chat-apps.md @@ -179,7 +179,7 @@ PicoClaw hỗ trợ kết nối với tài khoản WeChat cá nhân của bạn Chạy luồng đăng nhập QR tương tác: ```bash -picoclaw onboard weixin +picoclaw auth weixin ``` Quét mã QR được in ra bằng ứng dụng WeChat trên điện thoại. Sau khi đăng nhập thành công, token sẽ được lưu vào cấu hình. 
diff --git a/docs/zh/chat-apps.md b/docs/zh/chat-apps.md index aeba7d460..47add38ac 100644 --- a/docs/zh/chat-apps.md +++ b/docs/zh/chat-apps.md @@ -6,7 +6,7 @@ PicoClaw 支持多种聊天平台,使您的 Agent 能够连接到任何地方。 -> **注意**: 所有 Webhook 类渠道(LINE、WeCom 等)均挂载在同一个 Gateway HTTP 服务器上(`gateway.host`:`gateway.port`,默认 `127.0.0.1:18790`),无需为每个渠道单独配置端口。注意:飞书(Feishu)使用 WebSocket/SDK 模式,不通过该共享 HTTP webhook 服务器接收消息。 +> **注意**: 依赖 HTTP 回调的渠道共用同一个 Gateway HTTP 服务器(`gateway.host`:`gateway.port`,默认 `127.0.0.1:18790`),无需为每个渠道单独配置端口。飞书、钉钉、企业微信这类 Socket/Stream 模式渠道不依赖共享 webhook 服务器来接收入站消息。 ### 核心渠道 @@ -21,7 +21,7 @@ PicoClaw 支持多种聊天平台,使您的 Agent 能够连接到任何地方 | **QQ** | ⭐⭐ 中等 | 官方机器人 API,适合国内社群 | [查看文档](../channels/qq/README.zh.md) | | **钉钉 (DingTalk)** | ⭐⭐ 中等 | Stream 模式无需公网,企业办公首选 | [查看文档](../channels/dingtalk/README.zh.md) | | **LINE** | ⭐⭐⭐ 较难 | 需要 HTTPS Webhook | [查看文档](../channels/line/README.zh.md) | -| **企业微信 (WeCom)** | ⭐⭐⭐ 较难 | 支持群机器人(Webhook)、自建应用(API)和智能机器人(AI Bot) | [Bot 文档](../channels/wecom/wecom_bot/README.zh.md) / [App 文档](../channels/wecom/wecom_app/README.zh.md) / [AI Bot 文档](../channels/wecom/wecom_aibot/README.zh.md) | +| **企业微信 (WeCom)** | ⭐⭐⭐ 较难 | 官方 AI Bot WebSocket 接入,支持流式回复和媒体消息 | [查看文档](../channels/wecom/README.zh.md) | | **飞书 (Feishu)** | ⭐⭐⭐ 较难 | 企业级协作,功能丰富 | [查看文档](../channels/feishu/README.zh.md) | | **IRC** | ⭐⭐ 中等 | 服务器 + TLS 配置 | [查看文档](#irc) | | **OneBot** | ⭐⭐ 中等 | 兼容 NapCat/Go-CQHTTP,社区生态丰富 | [查看文档](../channels/onebot/README.zh.md) | @@ -191,7 +191,7 @@ PicoClaw 通过腾讯 iLink 官方 API 支持连接微信个人号。 运行交互式扫码登录流程: ```bash -picoclaw onboard weixin +picoclaw auth weixin ``` 用微信手机端扫描打印出的二维码。登录成功后,token 会自动保存到配置文件。 @@ -492,102 +492,34 @@ picoclaw gateway
企业微信 (WeCom) -PicoClaw 支持三种企业微信集成方式: +PicoClaw 现在将企业微信统一为一个基于 WebSocket 的 AI Bot 渠道。 +它不再需要公网 webhook 回调地址。 -**方式 1: 群机器人 (Bot)** — 设置简单,支持群聊 -**方式 2: 自建应用 (App)** — 功能更多,支持主动推送,仅私聊 -**方式 3: 智能机器人 (AI Bot)** — 官方 AI Bot,流式回复,支持群聊和私聊 +完整配置说明和迁移说明请参考 [企业微信配置指南](../channels/wecom/README.zh.md)。 -详细设置请参考 [企业微信 AI Bot 配置指南](../channels/wecom/wecom_aibot/README.zh.md)。 +**推荐快速接入** -**快速设置 — 群机器人:** +**1. 认证** -**1. 创建 Bot** +```bash +picoclaw auth wecom +``` -* 企业微信管理后台 → 群聊 → 添加群机器人 -* 复制 Webhook URL(格式:`https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=xxx`) +该命令会显示二维码,等待你在企业微信里确认,然后把 `bot_id` 和 `secret` 写入 `channels.wecom`。 -**2. 配置** +**2. 如需手动配置** ```json { "channels": { "wecom": { "enabled": true, - "token": "YOUR_TOKEN", - "encoding_aes_key": "YOUR_ENCODING_AES_KEY", - "webhook_url": "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=YOUR_KEY", - "webhook_path": "/webhook/wecom", - "allow_from": [] - } - } -} -``` - -> WeCom Webhook 挂载在共享 Gateway 服务器上(`gateway.host`:`gateway.port`,默认 `127.0.0.1:18790`)。 - -**快速设置 — 自建应用:** - -**1. 创建应用** - -* 企业微信管理后台 → 应用管理 → 创建应用 -* 复制 **AgentId** 和 **Secret** -* 前往"我的企业"页面,复制 **CorpID** - -**2. 配置接收消息** - -* 在应用详情中,点击"接收消息" → "设置 API" -* 设置 URL 为 `http://your-server:18790/webhook/wecom-app` -* 生成 **Token** 和 **EncodingAESKey** - -**3. 配置** - -```json -{ - "channels": { - "wecom_app": { - "enabled": true, - "corp_id": "wwxxxxxxxxxxxxxxxx", - "corp_secret": "YOUR_CORP_SECRET", - "agent_id": 1000002, - "token": "YOUR_TOKEN", - "encoding_aes_key": "YOUR_ENCODING_AES_KEY", - "webhook_path": "/webhook/wecom-app", - "allow_from": [] - } - } -} -``` - -**4. 运行** - -```bash -picoclaw gateway -``` - -> **注意**: WeCom Webhook 回调挂载在 Gateway 端口(默认 18790)。使用反向代理配置 HTTPS。 - -**快速设置 — 智能机器人 (AI Bot):** - -**1. 创建 AI Bot** - -* 企业微信管理后台 → 应用管理 → AI Bot -* 在 AI Bot 设置中配置回调 URL:`http://your-server:18790/webhook/wecom-aibot` -* 复制 **Token** 并点击"随机生成" **EncodingAESKey** - -**2. 
配置** - -```json -{ - "channels": { - "wecom_aibot": { - "enabled": true, - "token": "YOUR_TOKEN", - "encoding_aes_key": "YOUR_43_CHAR_ENCODING_AES_KEY", - "webhook_path": "/webhook/wecom-aibot", + "bot_id": "YOUR_BOT_ID", + "secret": "YOUR_SECRET", + "websocket_url": "wss://openws.work.weixin.qq.com", + "send_thinking_message": true, "allow_from": [], - "welcome_message": "你好!有什么可以帮你的?", - "processing_message": "⏳ Processing, please wait. The results will be sent shortly." + "reasoning_channel_id": "" } } } @@ -599,7 +531,7 @@ picoclaw gateway picoclaw gateway ``` -> **注意**: 企业微信 AI Bot 使用流式拉取协议,无回复超时问题。长任务(>30 秒)会自动切换到 `response_url` 推送投递。 +> 旧的 `wecom_app` 和 `wecom_aibot` 配置已经被统一的 `channels.wecom` 替代。 
diff --git a/go.mod b/go.mod index bce41d0d3..e9ef37e98 100644 --- a/go.mod +++ b/go.mod @@ -20,6 +20,7 @@ require ( github.com/gorilla/websocket v1.5.3 github.com/h2non/filetype v1.1.3 github.com/larksuite/oapi-sdk-go/v3 v3.5.3 + github.com/mattn/go-sqlite3 v1.14.34 github.com/mdp/qrterminal/v3 v3.2.1 github.com/modelcontextprotocol/go-sdk v1.4.1 github.com/mymmrac/telego v1.7.0 @@ -31,6 +32,7 @@ require ( github.com/spf13/cobra v1.10.2 github.com/stretchr/testify v1.11.1 github.com/tencent-connect/botgo v0.2.1 + go.mau.fi/util v0.9.7 go.mau.fi/whatsmeow v0.0.0-20260219150138-7ae702b1eed4 golang.org/x/oauth2 v0.36.0 golang.org/x/term v0.41.0 @@ -77,7 +79,6 @@ require ( github.com/spf13/pflag v1.0.10 // indirect github.com/vektah/gqlparser/v2 v2.5.27 // indirect go.mau.fi/libsignal v0.2.1 // indirect - go.mau.fi/util v0.9.7 // indirect golang.org/x/exp v0.0.0-20260312153236-7ab1446f8b90 // indirect golang.org/x/text v0.35.0 // indirect modernc.org/libc v1.67.6 // indirect diff --git a/pkg/agent/loop.go b/pkg/agent/loop.go index 725d42614..995c5720b 100644 --- a/pkg/agent/loop.go +++ b/pkg/agent/loop.go @@ -96,14 +96,15 @@ type continuationTarget struct { } const ( - defaultResponse = "The model returned an empty response. This may indicate a provider error or token limit." - toolLimitResponse = "I've reached `max_tool_iterations` without a final response. Increase `max_tool_iterations` in config.json if this task needs more tool steps." - sessionKeyAgentPrefix = "agent:" - metadataKeyAccountID = "account_id" - metadataKeyGuildID = "guild_id" - metadataKeyTeamID = "team_id" - metadataKeyParentPeerKind = "parent_peer_kind" - metadataKeyParentPeerID = "parent_peer_id" + defaultResponse = "The model returned an empty response. This may indicate a provider error or token limit." + toolLimitResponse = "I've reached `max_tool_iterations` without a final response. Increase `max_tool_iterations` in config.json if this task needs more tool steps." 
+ handledToolResponseSummary = "Requested output delivered via tool attachment." + sessionKeyAgentPrefix = "agent:" + metadataKeyAccountID = "account_id" + metadataKeyGuildID = "guild_id" + metadataKeyTeamID = "team_id" + metadataKeyParentPeerKind = "parent_peer_kind" + metadataKeyParentPeerID = "parent_peer_id" ) func NewAgentLoop( @@ -1030,13 +1031,13 @@ func (al *AgentLoop) GetConfig() *config.Config { func (al *AgentLoop) SetMediaStore(s media.MediaStore) { al.mediaStore = s - // Propagate store to send_file tools in all agents. + // Propagate store to all registered tools that can emit media. registry := al.GetRegistry() - registry.ForEachTool("send_file", func(t tools.Tool) { - if sf, ok := t.(*tools.SendFileTool); ok { - sf.SetMediaStore(s) + for _, agentID := range registry.ListAgentIDs() { + if agent, ok := registry.GetAgent(agentID); ok { + agent.Tools.SetMediaStore(s) } - }) + } } // SetTranscriber injects a voice transcriber for agent-level audio transcription. @@ -2165,6 +2166,7 @@ turnLoop: "iteration": iteration, }) + allResponsesHandled := len(normalizedToolCalls) > 0 assistantMsg := providers.Message{ Role: "assistant", Content: response.Content, @@ -2221,6 +2223,7 @@ turnLoop: toolArgs = toolReq.Arguments } case HookActionDenyTool: + allResponsesHandled = false denyContent := hookDeniedToolContent("Tool execution denied by hook", decision.Reason) al.emitEvent( EventKindToolExecSkipped, @@ -2260,6 +2263,7 @@ turnLoop: ChatID: ts.chatID, }) if !approval.Approved { + allResponsesHandled = false denyContent := hookDeniedToolContent("Tool execution denied by approval hook", approval.Reason) al.emitEvent( EventKindToolExecSkipped, @@ -2333,10 +2337,7 @@ turnLoop: } // Determine content for the agent loop (ForLLM or error). 
- content := result.ForLLM - if content == "" && result.Err != nil { - content = result.Err.Error() - } + content := result.ContentForLLM() if content == "" { return } @@ -2420,6 +2421,50 @@ turnLoop: if toolResult == nil { toolResult = tools.ErrorResult("hook returned nil tool result") } + if len(toolResult.Media) > 0 && toolResult.ResponseHandled { + parts := make([]bus.MediaPart, 0, len(toolResult.Media)) + for _, ref := range toolResult.Media { + part := bus.MediaPart{Ref: ref} + if al.mediaStore != nil { + if _, meta, err := al.mediaStore.ResolveWithMeta(ref); err == nil { + part.Filename = meta.Filename + part.ContentType = meta.ContentType + part.Type = inferMediaType(meta.Filename, meta.ContentType) + } + } + parts = append(parts, part) + } + outboundMedia := bus.OutboundMediaMessage{ + Channel: ts.channel, + ChatID: ts.chatID, + Parts: parts, + } + if al.channelManager != nil && ts.channel != "" && !constants.IsInternalChannel(ts.channel) { + if err := al.channelManager.SendMedia(ctx, outboundMedia); err != nil { + logger.WarnCF("agent", "Failed to deliver handled tool media", + map[string]any{ + "agent_id": ts.agent.ID, + "tool": toolName, + "channel": ts.channel, + "chat_id": ts.chatID, + "error": err.Error(), + }) + toolResult = tools.ErrorResult(fmt.Sprintf("failed to deliver attachment: %v", err)).WithError(err) + } + } else if al.bus != nil { + al.bus.PublishOutboundMedia(ctx, outboundMedia) + // Queuing media is only best-effort; it has not been delivered yet. 
+ toolResult.ResponseHandled = false + } + } + + if len(toolResult.Media) > 0 && !toolResult.ResponseHandled { + toolResult.ArtifactTags = buildArtifactTags(al.mediaStore, toolResult.Media) + } + + if !toolResult.ResponseHandled { + allResponsesHandled = false + } if !toolResult.Silent && toolResult.ForUser != "" && ts.opts.SendResponse { al.bus.PublishOutbound(ctx, bus.OutboundMessage{ @@ -2434,30 +2479,7 @@ turnLoop: }) } - if len(toolResult.Media) > 0 { - parts := make([]bus.MediaPart, 0, len(toolResult.Media)) - for _, ref := range toolResult.Media { - part := bus.MediaPart{Ref: ref} - if al.mediaStore != nil { - if _, meta, err := al.mediaStore.ResolveWithMeta(ref); err == nil { - part.Filename = meta.Filename - part.ContentType = meta.ContentType - part.Type = inferMediaType(meta.Filename, meta.ContentType) - } - } - parts = append(parts, part) - } - al.bus.PublishOutboundMedia(ctx, bus.OutboundMediaMessage{ - Channel: ts.channel, - ChatID: ts.chatID, - Parts: parts, - }) - } - - contentForLLM := toolResult.ForLLM - if contentForLLM == "" && toolResult.Err != nil { - contentForLLM = toolResult.Err.Error() - } + contentForLLM := toolResult.ContentForLLM() // Filter sensitive data (API keys, tokens, secrets) before sending to LLM if al.cfg.Tools.IsFilterSensitiveDataEnabled() { @@ -2552,6 +2574,70 @@ turnLoop: } } + if allResponsesHandled { + if len(pendingMessages) > 0 { + logger.InfoCF("agent", "Pending steering exists after handled tool delivery; continuing turn before finalizing", + map[string]any{ + "agent_id": ts.agent.ID, + "steering_count": len(pendingMessages), + "session_key": ts.sessionKey, + }) + finalContent = "" + goto turnLoop + } + + if steerMsgs := al.dequeueSteeringMessagesForScope(ts.sessionKey); len(steerMsgs) > 0 { + logger.InfoCF("agent", "Steering arrived after handled tool delivery; continuing turn before finalizing", + map[string]any{ + "agent_id": ts.agent.ID, + "steering_count": len(steerMsgs), + "session_key": ts.sessionKey, + }) + 
pendingMessages = append(pendingMessages, steerMsgs...) + finalContent = "" + goto turnLoop + } + + summaryMsg := providers.Message{ + Role: "assistant", + Content: handledToolResponseSummary, + } + + if !ts.opts.NoHistory { + ts.agent.Sessions.AddMessage(ts.sessionKey, summaryMsg.Role, summaryMsg.Content) + ts.recordPersistedMessage(summaryMsg) + if err := ts.agent.Sessions.Save(ts.sessionKey); err != nil { + turnStatus = TurnEndStatusError + al.emitEvent( + EventKindError, + ts.eventMeta("runTurn", "turn.error"), + ErrorPayload{ + Stage: "session_save", + Message: err.Error(), + }, + ) + return turnResult{}, err + } + } + if ts.opts.EnableSummary { + al.maybeSummarize(ts.agent, ts.sessionKey, ts.scope) + } + + ts.setPhase(TurnPhaseCompleted) + ts.setFinalContent("") + logger.InfoCF("agent", "Tool output satisfied delivery; ending turn without follow-up LLM", + map[string]any{ + "agent_id": ts.agent.ID, + "iteration": iteration, + "tool_count": len(normalizedToolCalls), + }) + return turnResult{ + finalContent: "", + status: turnStatus, + followUps: append([]bus.InboundMessage(nil), ts.followUps...), + }, nil + } + ts.agent.Tools.TickTTL() logger.DebugCF("agent", "TTL tick after tool execution", map[string]any{ "agent_id": ts.agent.ID, "iteration": iteration, @@ -3159,6 +3245,97 @@ func (al *AgentLoop) handleCommand( } } +func activeSkillNames(agent *AgentInstance, opts processOptions) []string { + if agent == nil { + return nil + } + + combined := make([]string, 0, len(agent.SkillsFilter)+len(opts.ForcedSkills)) + combined = append(combined, agent.SkillsFilter...) + combined = append(combined, opts.ForcedSkills...) 
+ if len(combined) == 0 { + return nil + } + + var resolved []string + seen := make(map[string]struct{}, len(combined)) + for _, name := range combined { + name = strings.TrimSpace(name) + if name == "" { + continue + } + if agent.ContextBuilder != nil { + if canonical, ok := agent.ContextBuilder.ResolveSkillName(name); ok { + name = canonical + } + } + key := strings.ToLower(name) + if _, ok := seen[key]; ok { + continue + } + seen[key] = struct{}{} + resolved = append(resolved, name) + } + + return resolved +} + +func (al *AgentLoop) applyExplicitSkillCommand( + raw string, + agent *AgentInstance, + opts *processOptions, +) (matched bool, handled bool, reply string) { + cmdName, ok := commands.CommandName(raw) + if !ok || cmdName != "use" { + return false, false, "" + } + + if agent == nil || agent.ContextBuilder == nil { + return true, true, commandsUnavailableSkillMessage() + } + + parts := strings.Fields(strings.TrimSpace(raw)) + if len(parts) < 2 { + return true, true, buildUseCommandHelp(agent) + } + + arg := strings.TrimSpace(parts[1]) + if strings.EqualFold(arg, "clear") || strings.EqualFold(arg, "off") { + if opts != nil { + al.clearPendingSkills(opts.SessionKey) + } + return true, true, "Cleared pending skill override." + } + + skillName, ok := agent.ContextBuilder.ResolveSkillName(arg) + if !ok { + return true, true, fmt.Sprintf("Unknown skill: %s\nUse /list skills to see installed skills.", arg) + } + + if len(parts) < 3 { + if opts == nil || strings.TrimSpace(opts.SessionKey) == "" { + return true, true, commandsUnavailableSkillMessage() + } + al.setPendingSkills(opts.SessionKey, []string{skillName}) + return true, true, fmt.Sprintf( + "Skill %q is armed for your next message. 
Send your next prompt normally, or use /use clear to cancel.", + skillName, + ) + } + + message := strings.TrimSpace(strings.Join(parts[2:], " ")) + if message == "" { + return true, true, buildUseCommandHelp(agent) + } + + if opts != nil { + opts.ForcedSkills = append(opts.ForcedSkills, skillName) + opts.UserMessage = message + } + + return true, false, "" +} + func (al *AgentLoop) buildCommandsRuntime(agent *AgentInstance, opts *processOptions) *commands.Runtime { registry := al.GetRegistry() cfg := al.GetConfig() @@ -3199,6 +3376,9 @@ func (al *AgentLoop) buildCommandsRuntime(agent *AgentInstance, opts *processOpt return al.reloadFunc() } if agent != nil { + if agent.ContextBuilder != nil { + rt.ListSkillNames = agent.ContextBuilder.ListSkillNames + } rt.GetModelInfo = func() (string, string) { return agent.Model, resolvedCandidateProvider(agent.Candidates, cfg.Agents.Defaults.Provider) } @@ -3251,79 +3431,6 @@ func (al *AgentLoop) buildCommandsRuntime(agent *AgentInstance, opts *processOpt return rt } -func activeSkillNames(agent *AgentInstance, opts processOptions) []string { - var out []string - seen := make(map[string]struct{}) - - appendNames := func(names []string) { - for _, name := range names { - name = strings.TrimSpace(name) - if name == "" { - continue - } - if _, exists := seen[name]; exists { - continue - } - seen[name] = struct{}{} - out = append(out, name) - } - } - - if agent != nil { - appendNames(agent.SkillsFilter) - } - appendNames(opts.ForcedSkills) - - return out -} - -func (al *AgentLoop) applyExplicitSkillCommand( - raw string, - agent *AgentInstance, - opts *processOptions, -) (matched bool, handled bool, reply string) { - commandName, ok := commands.CommandName(raw) - if !ok || commandName != "use" { - return false, false, "" - } - - if agent == nil || agent.ContextBuilder == nil { - return true, true, commandsUnavailableSkillMessage() - } - - fields := strings.Fields(strings.TrimSpace(raw)) - if len(fields) < 2 { - return true, true, 
buildUseCommandHelp(agent) - } - - if strings.EqualFold(fields[1], "clear") || strings.EqualFold(fields[1], "off") { - al.clearPendingSkills(opts.SessionKey) - return true, true, "Cleared pending skill override." - } - - canonicalSkill, ok := agent.ContextBuilder.ResolveSkillName(fields[1]) - if !ok { - return true, true, fmt.Sprintf("Unknown skill: %s\nUse /list skills to see installed skills.", fields[1]) - } - - if len(fields) == 2 { - al.setPendingSkills(opts.SessionKey, []string{canonicalSkill}) - return true, true, fmt.Sprintf( - "Skill %q is armed for your next message.\nSend your next request normally, or use /use clear to cancel.", - canonicalSkill, - ) - } - - message := strings.TrimSpace(strings.Join(fields[2:], " ")) - if message == "" { - return true, true, buildUseCommandHelp(agent) - } - - opts.UserMessage = message - opts.ForcedSkills = append(opts.ForcedSkills, canonicalSkill) - return true, false, "" -} - func commandsUnavailableSkillMessage() string { return "Skill selection is unavailable in the current context." } diff --git a/pkg/agent/loop_media.go b/pkg/agent/loop_media.go index 1380f0214..e8314c10d 100644 --- a/pkg/agent/loop_media.go +++ b/pkg/agent/loop_media.go @@ -87,6 +87,24 @@ func resolveMediaRefs(messages []providers.Message, store media.MediaStore, maxS return result } +func buildArtifactTags(store media.MediaStore, refs []string) []string { + if store == nil || len(refs) == 0 { + return nil + } + + tags := make([]string, 0, len(refs)) + for _, ref := range refs { + localPath, meta, err := store.ResolveWithMeta(ref) + if err != nil { + continue + } + mime := detectMIME(localPath, meta) + tags = append(tags, buildPathTag(mime, localPath)) + } + + return tags +} + // detectMIME determines the MIME type from metadata or magic-bytes detection. // Returns empty string if detection fails. 
func detectMIME(localPath string, meta media.MediaMeta) string { diff --git a/pkg/agent/loop_test.go b/pkg/agent/loop_test.go index 976d25c4b..e0a5dffb3 100644 --- a/pkg/agent/loop_test.go +++ b/pkg/agent/loop_test.go @@ -33,6 +33,41 @@ func (f *fakeChannel) IsAllowed(string) bool { func (f *fakeChannel) IsAllowedSender(sender bus.SenderInfo) bool { return true } func (f *fakeChannel) ReasoningChannelID() string { return f.id } +type fakeMediaChannel struct { + fakeChannel + sentMedia []bus.OutboundMediaMessage +} + +func (f *fakeMediaChannel) SendMedia(ctx context.Context, msg bus.OutboundMediaMessage) error { + f.sentMedia = append(f.sentMedia, msg) + return nil +} + +func newStartedTestChannelManager( + t *testing.T, + msgBus *bus.MessageBus, + store media.MediaStore, + name string, + ch channels.Channel, +) *channels.Manager { + t.Helper() + + cm, err := channels.NewManager(&config.Config{}, msgBus, store) + if err != nil { + t.Fatalf("NewManager() error = %v", err) + } + cm.RegisterChannel(name, ch) + if err := cm.StartAll(context.Background()); err != nil { + t.Fatalf("StartAll() error = %v", err) + } + t.Cleanup(func() { + if err := cm.StopAll(context.Background()); err != nil { + t.Fatalf("StopAll() error = %v", err) + } + }) + return cm +} + type recordingProvider struct { lastMessages []providers.Message } @@ -289,6 +324,86 @@ func TestProcessMessage_UseCommandArmsSkillForNextMessage(t *testing.T) { } } +func TestApplyExplicitSkillCommand_ArmsSkillForNextMessage(t *testing.T) { + al, cfg, _, _, cleanup := newTestAgentLoop(t) + defer cleanup() + + if err := os.MkdirAll(filepath.Join(cfg.Agents.Defaults.Workspace, "skills", "finance-news"), 0o755); err != nil { + t.Fatalf("MkdirAll(skill) error = %v", err) + } + if err := os.WriteFile( + filepath.Join(cfg.Agents.Defaults.Workspace, "skills", "finance-news", "SKILL.md"), + []byte("# Finance News\n\nUse web tools for current finance updates.\n"), + 0o644, + ); err != nil { + t.Fatalf("WriteFile(SKILL.md) 
error = %v", err) + } + + agent := al.GetRegistry().GetDefaultAgent() + if agent == nil { + t.Fatal("expected default agent") + } + + opts := &processOptions{SessionKey: "agent:main:test"} + matched, handled, reply := al.applyExplicitSkillCommand("/use finance-news", agent, opts) + if !matched { + t.Fatal("expected /use command to match") + } + if !handled { + t.Fatal("expected /use without inline message to be handled immediately") + } + if !strings.Contains(reply, `Skill "finance-news" is armed for your next message`) { + t.Fatalf("unexpected reply: %q", reply) + } + + pending := al.takePendingSkills(opts.SessionKey) + if len(pending) != 1 || pending[0] != "finance-news" { + t.Fatalf("pending skills = %#v, want [finance-news]", pending) + } +} + +func TestApplyExplicitSkillCommand_InlineMessageMutatesOptions(t *testing.T) { + al, cfg, _, _, cleanup := newTestAgentLoop(t) + defer cleanup() + + if err := os.MkdirAll(filepath.Join(cfg.Agents.Defaults.Workspace, "skills", "finance-news"), 0o755); err != nil { + t.Fatalf("MkdirAll(skill) error = %v", err) + } + if err := os.WriteFile( + filepath.Join(cfg.Agents.Defaults.Workspace, "skills", "finance-news", "SKILL.md"), + []byte("# Finance News\n\nUse web tools for current finance updates.\n"), + 0o644, + ); err != nil { + t.Fatalf("WriteFile(SKILL.md) error = %v", err) + } + + agent := al.GetRegistry().GetDefaultAgent() + if agent == nil { + t.Fatal("expected default agent") + } + + opts := &processOptions{ + SessionKey: "agent:main:test", + UserMessage: "/use finance-news dammi le ultime news", + } + matched, handled, reply := al.applyExplicitSkillCommand(opts.UserMessage, agent, opts) + if !matched { + t.Fatal("expected /use command to match") + } + if handled { + t.Fatal("expected /use with inline message to fall through into normal agent execution") + } + if reply != "" { + t.Fatalf("unexpected reply: %q", reply) + } + if opts.UserMessage != "dammi le ultime news" { + t.Fatalf("opts.UserMessage = %q, want %q", 
opts.UserMessage, "dammi le ultime news") + } + if len(opts.ForcedSkills) != 1 || opts.ForcedSkills[0] != "finance-news" { + t.Fatalf("opts.ForcedSkills = %#v, want [finance-news]", opts.ForcedSkills) + } +} + func TestRecordLastChannel(t *testing.T) { al, cfg, msgBus, provider, cleanup := newTestAgentLoop(t) defer cleanup() @@ -455,6 +570,217 @@ func TestToolRegistry_GetDefinitions(t *testing.T) { } } +func TestProcessMessage_MediaToolHandledSkipsFollowUpLLMAndFinalText(t *testing.T) { + tmpDir := t.TempDir() + cfg := &config.Config{ + Agents: config.AgentsConfig{ + Defaults: config.AgentDefaults{ + Workspace: tmpDir, + ModelName: "test-model", + MaxTokens: 4096, + MaxToolIterations: 10, + }, + }, + } + + msgBus := bus.NewMessageBus() + provider := &handledMediaProvider{} + al := NewAgentLoop(cfg, msgBus, provider) + + store := media.NewFileMediaStore() + al.SetMediaStore(store) + telegramChannel := &fakeMediaChannel{fakeChannel: fakeChannel{id: "rid-telegram"}} + al.SetChannelManager(newStartedTestChannelManager(t, msgBus, store, "telegram", telegramChannel)) + + imagePath := filepath.Join(tmpDir, "screen.png") + if err := os.WriteFile(imagePath, []byte("fake screenshot"), 0o644); err != nil { + t.Fatalf("WriteFile(imagePath) error = %v", err) + } + + al.RegisterTool(&handledMediaTool{ + store: store, + path: imagePath, + }) + + response, err := al.processMessage(context.Background(), bus.InboundMessage{ + Channel: "telegram", + ChatID: "chat1", + SenderID: "user1", + Content: "take a screenshot of the screen and send it to me", + }) + if err != nil { + t.Fatalf("processMessage() error = %v", err) + } + if response != "" { + t.Fatalf("expected no final response when media tool already handled delivery, got %q", response) + } + if provider.calls != 1 { + t.Fatalf("expected exactly 1 LLM call, got %d", provider.calls) + } + if len(provider.toolCounts) != 1 { + t.Fatalf("expected tool counts for 1 provider call, got %d", len(provider.toolCounts)) + } + if 
provider.toolCounts[0] == 0 { + t.Fatal("expected tools to be available on the first LLM call") + } + + if len(telegramChannel.sentMedia) != 1 { + t.Fatalf("expected exactly 1 synchronously sent media message, got %d", len(telegramChannel.sentMedia)) + } + if telegramChannel.sentMedia[0].Channel != "telegram" || telegramChannel.sentMedia[0].ChatID != "chat1" { + t.Fatalf("unexpected sent media target: %+v", telegramChannel.sentMedia[0]) + } + if len(telegramChannel.sentMedia[0].Parts) != 1 { + t.Fatalf("expected exactly 1 sent media part, got %d", len(telegramChannel.sentMedia[0].Parts)) + } + + select { + case extra := <-msgBus.OutboundMediaChan(): + t.Fatalf("expected handled media to bypass async queue, got %+v", extra) + default: + } + + defaultAgent := al.GetRegistry().GetDefaultAgent() + if defaultAgent == nil { + t.Fatal("expected default agent") + } + route, _, err := al.resolveMessageRoute(bus.InboundMessage{ + Channel: "telegram", + ChatID: "chat1", + SenderID: "user1", + Content: "take a screenshot of the screen and send it to me", + }) + if err != nil { + t.Fatalf("resolveMessageRoute() error = %v", err) + } + sessionKey := resolveScopeKey(route, "") + history := defaultAgent.Sessions.GetHistory(sessionKey) + if len(history) == 0 { + t.Fatal("expected session history to be saved") + } + last := history[len(history)-1] + if last.Role != "assistant" || last.Content != "Requested output delivered via tool attachment." 
{ + t.Fatalf("expected handled assistant summary in history, got %+v", last) + } +} + +func TestProcessMessage_HandledToolProcessesQueuedSteeringBeforeReturning(t *testing.T) { + tmpDir := t.TempDir() + cfg := &config.Config{ + Agents: config.AgentsConfig{ + Defaults: config.AgentDefaults{ + Workspace: tmpDir, + ModelName: "test-model", + MaxTokens: 4096, + MaxToolIterations: 10, + }, + }, + } + + msgBus := bus.NewMessageBus() + provider := &handledMediaWithSteeringProvider{} + al := NewAgentLoop(cfg, msgBus, provider) + + store := media.NewFileMediaStore() + al.SetMediaStore(store) + telegramChannel := &fakeMediaChannel{fakeChannel: fakeChannel{id: "rid-telegram"}} + al.SetChannelManager(newStartedTestChannelManager(t, msgBus, store, "telegram", telegramChannel)) + + imagePath := filepath.Join(tmpDir, "screen-steering.png") + if err := os.WriteFile(imagePath, []byte("fake screenshot"), 0o644); err != nil { + t.Fatalf("WriteFile(imagePath) error = %v", err) + } + + al.RegisterTool(&handledMediaWithSteeringTool{ + store: store, + path: imagePath, + loop: al, + }) + + response, err := al.processMessage(context.Background(), bus.InboundMessage{ + Channel: "telegram", + ChatID: "chat1", + SenderID: "user1", + Content: "take a screenshot of the screen and send it to me", + }) + if err != nil { + t.Fatalf("processMessage() error = %v", err) + } + if response != "Handled the queued steering message." 
{ + t.Fatalf("response = %q, want queued steering response", response) + } + if provider.calls != 2 { + t.Fatalf("expected 2 LLM calls after queued steering, got %d", provider.calls) + } + if len(telegramChannel.sentMedia) != 1 { + t.Fatalf("expected exactly 1 synchronously sent media message, got %d", len(telegramChannel.sentMedia)) + } +} + +func TestProcessMessage_MediaArtifactCanBeForwardedBySendFile(t *testing.T) { + tmpDir := t.TempDir() + cfg := config.DefaultConfig() + cfg.Agents.Defaults.Workspace = tmpDir + cfg.Agents.Defaults.ModelName = "test-model" + cfg.Agents.Defaults.MaxTokens = 4096 + cfg.Agents.Defaults.MaxToolIterations = 10 + + msgBus := bus.NewMessageBus() + provider := &artifactThenSendProvider{} + al := NewAgentLoop(cfg, msgBus, provider) + + store := media.NewFileMediaStore() + al.SetMediaStore(store) + telegramChannel := &fakeMediaChannel{fakeChannel: fakeChannel{id: "rid-telegram"}} + al.SetChannelManager(newStartedTestChannelManager(t, msgBus, store, "telegram", telegramChannel)) + + mediaDir := media.TempDir() + if err := os.MkdirAll(mediaDir, 0o700); err != nil { + t.Fatalf("MkdirAll(mediaDir) error = %v", err) + } + imagePath := filepath.Join(mediaDir, "artifact-screen.png") + if err := os.WriteFile(imagePath, []byte("fake screenshot"), 0o644); err != nil { + t.Fatalf("WriteFile(imagePath) error = %v", err) + } + + al.RegisterTool(&mediaArtifactTool{ + store: store, + path: imagePath, + }) + + response, err := al.processMessage(context.Background(), bus.InboundMessage{ + Channel: "telegram", + ChatID: "chat1", + SenderID: "user1", + Content: "take a screenshot of the screen and send it to me", + }) + if err != nil { + t.Fatalf("processMessage() error = %v", err) + } + if response != "" { + t.Fatalf("expected no final response after send_file handled delivery, got %q", response) + } + if provider.calls != 2 { + t.Fatalf("expected 2 LLM calls (artifact + send_file), got %d", provider.calls) + } + + if len(telegramChannel.sentMedia) != 1 
{ + t.Fatalf("expected exactly 1 synchronously sent media message, got %d", len(telegramChannel.sentMedia)) + } + if telegramChannel.sentMedia[0].Channel != "telegram" || telegramChannel.sentMedia[0].ChatID != "chat1" { + t.Fatalf("unexpected sent media target: %+v", telegramChannel.sentMedia[0]) + } + if len(telegramChannel.sentMedia[0].Parts) != 1 { + t.Fatalf("expected exactly 1 sent media part, got %d", len(telegramChannel.sentMedia[0].Parts)) + } + + select { + case extra := <-msgBus.OutboundMediaChan(): + t.Fatalf("expected synchronous send_file delivery to bypass async queue, got %+v", extra) + default: + } +} + // TestAgentLoop_GetStartupInfo verifies startup info contains tools func TestAgentLoop_GetStartupInfo(t *testing.T) { tmpDir, err := os.MkdirTemp("", "agent-test-*") @@ -600,6 +926,98 @@ func (m *countingMockProvider) GetDefaultModel() string { return "counting-mock-model" } +type handledMediaProvider struct { + calls int + toolCounts []int +} + +func (m *handledMediaProvider) Chat( + ctx context.Context, + messages []providers.Message, + tools []providers.ToolDefinition, + model string, + opts map[string]any, +) (*providers.LLMResponse, error) { + m.calls++ + m.toolCounts = append(m.toolCounts, len(tools)) + if m.calls == 1 { + return &providers.LLMResponse{ + Content: "Taking the screenshot now.", + ToolCalls: []providers.ToolCall{{ + ID: "call_handled_media", + Type: "function", + Name: "handled_media_tool", + Arguments: map[string]any{}, + }}, + }, nil + } + return &providers.LLMResponse{}, nil +} + +func (m *handledMediaProvider) GetDefaultModel() string { + return "handled-media-model" +} + +type artifactThenSendProvider struct { + calls int +} + +func (m *artifactThenSendProvider) Chat( + ctx context.Context, + messages []providers.Message, + tools []providers.ToolDefinition, + model string, + opts map[string]any, +) (*providers.LLMResponse, error) { + m.calls++ + if m.calls == 1 { + return &providers.LLMResponse{ + Content: "Taking the 
screenshot now.", + ToolCalls: []providers.ToolCall{{ + ID: "call_artifact_media", + Type: "function", + Name: "media_artifact_tool", + Arguments: map[string]any{}, + }}, + }, nil + } + + var artifactPath string + for i := len(messages) - 1; i >= 0; i-- { + if messages[i].Role != "tool" { + continue + } + start := strings.Index(messages[i].Content, "[file:") + if start < 0 { + continue + } + rest := messages[i].Content[start+len("[file:"):] + end := strings.Index(rest, "]") + if end < 0 { + continue + } + artifactPath = rest[:end] + break + } + if artifactPath == "" { + return nil, fmt.Errorf("provider did not receive artifact path in tool result") + } + + return &providers.LLMResponse{ + Content: "", + ToolCalls: []providers.ToolCall{{ + ID: "call_send_file", + Type: "function", + Name: "send_file", + Arguments: map[string]any{"path": artifactPath}, + }}, + }, nil +} + +func (m *artifactThenSendProvider) GetDefaultModel() string { + return "artifact-then-send-model" +} + type toolLimitOnlyProvider struct{} func (m *toolLimitOnlyProvider) Chat( @@ -636,8 +1054,9 @@ func (m *mockCustomTool) Description() string { func (m *mockCustomTool) Parameters() map[string]any { return map[string]any{ - "type": "object", - "properties": map[string]any{}, + "type": "object", + "properties": map[string]any{}, + "additionalProperties": true, } } @@ -645,6 +1064,135 @@ func (m *mockCustomTool) Execute(ctx context.Context, args map[string]any) *tool return tools.SilentResult("Custom tool executed") } +type handledMediaTool struct { + store media.MediaStore + path string +} + +func (m *handledMediaTool) Name() string { return "handled_media_tool" } +func (m *handledMediaTool) Description() string { + return "Returns a media attachment and fully handles the user response" +} + +func (m *handledMediaTool) Parameters() map[string]any { + return map[string]any{ + "type": "object", + "properties": map[string]any{}, + } +} + +func (m *handledMediaTool) Execute(ctx context.Context, args 
map[string]any) *tools.ToolResult { + ref, err := m.store.Store(m.path, media.MediaMeta{ + Filename: filepath.Base(m.path), + ContentType: "image/png", + Source: "test:handled_media_tool", + }, "test:handled_media") + if err != nil { + return tools.ErrorResult(err.Error()).WithError(err) + } + return tools.MediaResult("Attachment delivered by tool.", []string{ref}).WithResponseHandled() +} + +type handledMediaWithSteeringProvider struct { + calls int +} + +func (m *handledMediaWithSteeringProvider) Chat( + ctx context.Context, + messages []providers.Message, + tools []providers.ToolDefinition, + model string, + opts map[string]any, +) (*providers.LLMResponse, error) { + m.calls++ + if m.calls == 1 { + return &providers.LLMResponse{ + Content: "Taking the screenshot now.", + ToolCalls: []providers.ToolCall{{ + ID: "call_handled_media_steering", + Type: "function", + Name: "handled_media_with_steering_tool", + Arguments: map[string]any{}, + }}, + }, nil + } + + for _, msg := range messages { + if msg.Role == "user" && msg.Content == "what about this instead?" 
{ + return &providers.LLMResponse{Content: "Handled the queued steering message."}, nil + } + } + + return nil, fmt.Errorf("provider did not receive queued steering message") +} + +func (m *handledMediaWithSteeringProvider) GetDefaultModel() string { + return "handled-media-with-steering-model" +} + +type handledMediaWithSteeringTool struct { + store media.MediaStore + path string + loop *AgentLoop +} + +func (m *handledMediaWithSteeringTool) Name() string { return "handled_media_with_steering_tool" } +func (m *handledMediaWithSteeringTool) Description() string { + return "Returns handled media and enqueues a steering message during execution" +} + +func (m *handledMediaWithSteeringTool) Parameters() map[string]any { + return map[string]any{ + "type": "object", + "properties": map[string]any{}, + } +} + +func (m *handledMediaWithSteeringTool) Execute(ctx context.Context, args map[string]any) *tools.ToolResult { + if err := m.loop.Steer(providers.Message{Role: "user", Content: "what about this instead?"}); err != nil { + return tools.ErrorResult(err.Error()).WithError(err) + } + + ref, err := m.store.Store(m.path, media.MediaMeta{ + Filename: filepath.Base(m.path), + ContentType: "image/png", + Source: "test:handled_media_with_steering_tool", + }, "test:handled_media_with_steering") + if err != nil { + return tools.ErrorResult(err.Error()).WithError(err) + } + return tools.MediaResult("Attachment delivered by tool.", []string{ref}).WithResponseHandled() +} + +type mediaArtifactTool struct { + store media.MediaStore + path string +} + +func (m *mediaArtifactTool) Name() string { return "media_artifact_tool" } +func (m *mediaArtifactTool) Description() string { + return "Returns a media artifact that the agent can forward or save later" +} + +func (m *mediaArtifactTool) Parameters() map[string]any { + return map[string]any{ + "type": "object", + "properties": map[string]any{}, + } +} + +func (m *mediaArtifactTool) Execute(ctx context.Context, args map[string]any) 
*tools.ToolResult { + ref, err := m.store.Store(m.path, media.MediaMeta{ + Filename: filepath.Base(m.path), + ContentType: "image/png", + Source: "test:media_artifact_tool", + }, "test:media_artifact") + if err != nil { + return tools.ErrorResult(err.Error()).WithError(err) + } + return tools.MediaResult("Artifact created.", []string{ref}) +} + type toolLimitTestTool struct{} func (m *toolLimitTestTool) Name() string { @@ -1494,18 +2042,17 @@ func TestTargetReasoningChannelID_AllChannels(t *testing.T) { t.Fatalf("Failed to create channel manager: %v", err) } for name, id := range map[string]string{ - "whatsapp": "rid-whatsapp", - "telegram": "rid-telegram", - "feishu": "rid-feishu", - "discord": "rid-discord", - "maixcam": "rid-maixcam", - "qq": "rid-qq", - "dingtalk": "rid-dingtalk", - "slack": "rid-slack", - "line": "rid-line", - "onebot": "rid-onebot", - "wecom": "rid-wecom", - "wecom_app": "rid-wecom-app", + "whatsapp": "rid-whatsapp", + "telegram": "rid-telegram", + "feishu": "rid-feishu", + "discord": "rid-discord", + "maixcam": "rid-maixcam", + "qq": "rid-qq", + "dingtalk": "rid-dingtalk", + "slack": "rid-slack", + "line": "rid-line", + "onebot": "rid-onebot", + "wecom": "rid-wecom", } { chManager.RegisterChannel(name, &fakeChannel{id: id}) } @@ -1525,7 +2072,6 @@ func TestTargetReasoningChannelID_AllChannels(t *testing.T) { {channel: "line", wantID: "rid-line"}, {channel: "onebot", wantID: "rid-onebot"}, {channel: "wecom", wantID: "rid-wecom"}, - {channel: "wecom_app", wantID: "rid-wecom-app"}, {channel: "unknown", wantID: ""}, } diff --git a/pkg/channels/README.md b/pkg/channels/README.md index b7c56660b..7f238ece5 100644 --- a/pkg/channels/README.md +++ b/pkg/channels/README.md @@ -1255,8 +1255,7 @@ make test # Full test suite | `pkg/channels/onebot/` | `"onebot"` | ReactionCapable, MediaSender | | `pkg/channels/dingtalk/` | `"dingtalk"` | — | | `pkg/channels/feishu/` | `"feishu"` | — (architecture-specific build tags: `feishu_32.go` / `feishu_64.go`) | 
-| `pkg/channels/wecom/` | `"wecom"` | WebhookHandler, HealthChecker | -| `pkg/channels/wecom/` | `"wecom_app"` | MediaSender, WebhookHandler, HealthChecker | +| `pkg/channels/wecom/` | `"wecom"` | MediaSender | | `pkg/channels/qq/` | `"qq"` | — | | `pkg/channels/whatsapp/` | `"whatsapp"` | — (Bridge mode) | | `pkg/channels/whatsapp_native/` | `"whatsapp_native"` | — (Native whatsmeow mode) | @@ -1371,7 +1370,7 @@ agentLoop.Stop() // Stop Agent 2. **Feishu architecture-specific compilation**: The Feishu channel uses build tags to distinguish 32-bit and 64-bit architectures (`feishu_32.go` / `feishu_64.go`). Feishu uses the SDK's WebSocket mode (not HTTP webhook), so it does not implement `WebhookHandler`. -3. **WeCom has two factories**: `"wecom"` (Bot mode, webhook only) and `"wecom_app"` (App mode, supports MediaSender) are registered separately. Both implement `WebhookHandler` and `HealthChecker`. +3. **WeCom is now a single channel**: `"wecom"` is implemented as a WebSocket-based AI Bot channel with route persistence. Access control uses the shared channel allowlist mechanism. It no longer exposes the legacy webhook/app split. 4. **Pico Protocol**: `pkg/channels/pico/` implements a custom PicoClaw native protocol channel that receives messages via WebSocket webhook (`/pico/ws`). @@ -1381,4 +1380,4 @@ agentLoop.Stop() // Stop Agent 7. **PlaceholderConfig vs implementation**: `PlaceholderConfig` appears in 6 channel configs (Telegram, Discord, Slack, LINE, OneBot, Pico), but only channels that implement both `PlaceholderCapable` + `MessageEditor` (Telegram, Discord, Pico) can actually use placeholder message editing. The rest are reserved fields. -8. **ReasoningChannelID**: Most channel configs include a `reasoning_channel_id` field to route LLM reasoning/thinking output to a designated channel (WhatsApp, Telegram, Feishu, Discord, MaixCam, QQ, DingTalk, Slack, LINE, OneBot, WeCom, WeComApp). Note: `PicoConfig` does not currently expose this field. 
`BaseChannel` exposes this via the `WithReasoningChannelID` option and `ReasoningChannelID()` method. \ No newline at end of file +8. **ReasoningChannelID**: Most channel configs include a `reasoning_channel_id` field to route LLM reasoning/thinking output to a designated channel (WhatsApp, Telegram, Feishu, Discord, MaixCam, QQ, DingTalk, Slack, LINE, OneBot, WeCom). Note: `PicoConfig` does not currently expose this field. `BaseChannel` exposes this via the `WithReasoningChannelID` option and `ReasoningChannelID()` method. diff --git a/pkg/channels/README.zh.md b/pkg/channels/README.zh.md index 2c5e7356e..8bc8c8dbc 100644 --- a/pkg/channels/README.zh.md +++ b/pkg/channels/README.zh.md @@ -1254,8 +1254,7 @@ make test # 全量测试 | `pkg/channels/onebot/` | `"onebot"` | ReactionCapable, MediaSender | | `pkg/channels/dingtalk/` | `"dingtalk"` | — | | `pkg/channels/feishu/` | `"feishu"` | — (架构特定 build tags: `feishu_32.go` / `feishu_64.go`) | -| `pkg/channels/wecom/` | `"wecom"` | WebhookHandler, HealthChecker | -| `pkg/channels/wecom/` | `"wecom_app"` | MediaSender, WebhookHandler, HealthChecker | +| `pkg/channels/wecom/` | `"wecom"` | MediaSender | | `pkg/channels/qq/` | `"qq"` | — | | `pkg/channels/whatsapp/` | `"whatsapp"` | — (Bridge 模式) | | `pkg/channels/whatsapp_native/` | `"whatsapp_native"` | — (原生 whatsmeow 模式) | @@ -1370,7 +1369,7 @@ agentLoop.Stop() // 停止 Agent 2. **Feishu 架构特定编译**:Feishu channel 使用 build tags 区分 32 位和 64 位架构(`feishu_32.go` / `feishu_64.go`)。Feishu 使用 SDK 的 WebSocket 模式(非 HTTP webhook),因此不实现 `WebhookHandler`。 -3. **WeCom 有两个工厂**:`"wecom"`(Bot 模式,纯 webhook)和 `"wecom_app"`(应用模式,支持 MediaSender)分别注册。两者都实现了 `WebhookHandler` 和 `HealthChecker`。 +3. **WeCom 现在只有一个 channel**:`"wecom"` 采用 WebSocket AI Bot 实现,带路由持久化;访问控制走统一的 channel 白名单机制,不再保留旧的 webhook/app 双分支。 4. **Pico Protocol**:`pkg/channels/pico/` 实现了一个自定义的 PicoClaw 原生协议 channel,通过 WebSocket webhook (`/pico/ws`) 接收消息。 @@ -1380,4 +1379,4 @@ agentLoop.Stop() // 停止 Agent 7. 
**PlaceholderConfig 的配置与实现**:`PlaceholderConfig` 出现在 6 个 channel config 中(Telegram、Discord、Slack、LINE、OneBot、Pico),但只有实现了 `PlaceholderCapable` + `MessageEditor` 的 channel(Telegram、Discord、Pico)能真正使用占位消息编辑功能。其余 channel 的 `PlaceholderConfig` 为预留字段。 -8. **ReasoningChannelID**:大多数 channel config 都包含 `reasoning_channel_id` 字段,用于将 LLM 的思维链(reasoning/thinking)路由到指定 channel(WhatsApp、Telegram、Feishu、Discord、MaixCam、QQ、DingTalk、Slack、LINE、OneBot、WeCom、WeComApp)。注意:`PicoConfig` 目前不包含该字段。`BaseChannel` 通过 `WithReasoningChannelID` 选项和 `ReasoningChannelID()` 方法暴露此配置。 \ No newline at end of file +8. **ReasoningChannelID**:大多数 channel config 都包含 `reasoning_channel_id` 字段,用于将 LLM 的思维链(reasoning/thinking)路由到指定 channel(WhatsApp、Telegram、Feishu、Discord、MaixCam、QQ、DingTalk、Slack、LINE、OneBot、WeCom)。注意:`PicoConfig` 目前不包含该字段。`BaseChannel` 通过 `WithReasoningChannelID` 选项和 `ReasoningChannelID()` 方法暴露此配置。 diff --git a/pkg/channels/manager.go b/pkg/channels/manager.go index f04d989a3..c8269dc77 100644 --- a/pkg/channels/manager.go +++ b/pkg/channels/manager.go @@ -206,6 +206,40 @@ func (m *Manager) preSend(ctx context.Context, name string, msg bus.OutboundMess return false } +// preSendMedia handles typing stop, reaction undo, and placeholder cleanup +// before sending media attachments. Unlike preSend for text messages, media +// delivery never edits the placeholder because there is no text payload to +// replace it with; it only attempts to delete the placeholder when possible. +func (m *Manager) preSendMedia(ctx context.Context, name string, msg bus.OutboundMediaMessage, ch Channel) { + key := name + ":" + msg.ChatID + + // 1. Stop typing + if v, loaded := m.typingStops.LoadAndDelete(key); loaded { + if entry, ok := v.(typingEntry); ok { + entry.stop() // idempotent, safe + } + } + + // 2. Undo reaction + if v, loaded := m.reactionUndos.LoadAndDelete(key); loaded { + if entry, ok := v.(reactionEntry); ok { + entry.undo() // idempotent, safe + } + } + + // 3. 
Clear any finalized stream marker for this chat before media delivery. + m.streamActive.LoadAndDelete(key) + + // 4. Delete placeholder if present. + if v, loaded := m.placeholders.LoadAndDelete(key); loaded { + if entry, ok := v.(placeholderEntry); ok && entry.id != "" { + if deleter, ok := ch.(MessageDeleter); ok { + deleter.DeleteMessage(ctx, msg.ChatID, entry.id) // best effort + } + } + } +} + func NewManager(cfg *config.Config, messageBus *bus.MessageBus, store media.MediaStore) (*Manager, error) { m := &Manager{ channels: make(map[string]Channel), @@ -371,19 +405,10 @@ func (m *Manager) initChannels(channels *config.ChannelsConfig) error { m.initChannel("onebot", "OneBot") } - if channels.WeCom.Enabled && channels.WeCom.Token() != "" { + if channels.WeCom.Enabled && channels.WeCom.BotID != "" && channels.WeCom.Secret() != "" { m.initChannel("wecom", "WeCom") } - if channels.WeComAIBot.Enabled && (channels.WeComAIBot.Token() != "" || - (channels.WeComAIBot.Secret() != "" && channels.WeComAIBot.BotID != "")) { - m.initChannel("wecom_aibot", "WeCom AI Bot") - } - - if channels.WeComApp.Enabled && channels.WeComApp.CorpID != "" { - m.initChannel("wecom_app", "WeCom App") - } - if channels.Weixin.Enabled && channels.Weixin.Token() != "" { m.initChannel("weixin", "Weixin") } @@ -774,7 +799,7 @@ func (m *Manager) runMediaWorker(ctx context.Context, name string, w *channelWor if !ok { return } - m.sendMediaWithRetry(ctx, name, w, msg) + _ = m.sendMediaWithRetry(ctx, name, w, msg) case <-ctx.Done(): return } @@ -782,26 +807,37 @@ func (m *Manager) runMediaWorker(ctx context.Context, name string, w *channelWor } // sendMediaWithRetry sends a media message through the channel with rate limiting and -// retry logic. If the channel does not implement MediaSender, it silently skips. -func (m *Manager) sendMediaWithRetry(ctx context.Context, name string, w *channelWorker, msg bus.OutboundMediaMessage) { +// retry logic. 
It returns nil on success, or the last error after retries, +// including when the channel does not support MediaSender. +func (m *Manager) sendMediaWithRetry( + ctx context.Context, + name string, + w *channelWorker, + msg bus.OutboundMediaMessage, +) error { ms, ok := w.ch.(MediaSender) if !ok { - logger.DebugCF("channels", "Channel does not support MediaSender, skipping media", map[string]any{ + err := fmt.Errorf("channel %q does not support media sending", name) + logger.WarnCF("channels", "Channel does not support MediaSender", map[string]any{ "channel": name, + "error": err.Error(), }) - return + return err } // Rate limit: wait for token if err := w.limiter.Wait(ctx); err != nil { - return + return err } + // Pre-send: stop typing and clean up any placeholder before sending media. + m.preSendMedia(ctx, name, msg, w.ch) + var lastErr error for attempt := 0; attempt <= maxRetries; attempt++ { lastErr = ms.SendMedia(ctx, msg) if lastErr == nil { - return + return nil } // Permanent failures — don't retry @@ -820,7 +856,7 @@ func (m *Manager) sendMediaWithRetry(ctx context.Context, name string, w *channe case <-time.After(rateLimitDelay): continue case <-ctx.Done(): - return + return ctx.Err() } } @@ -829,7 +865,7 @@ func (m *Manager) sendMediaWithRetry(ctx context.Context, name string, w *channe select { case <-time.After(backoff): case <-ctx.Done(): - return + return ctx.Err() } } @@ -840,6 +876,7 @@ func (m *Manager) sendMediaWithRetry(ctx context.Context, name string, w *channe "error": lastErr.Error(), "retries": maxRetries, }) + return lastErr } // runTTLJanitor periodically scans the typingStops and placeholders maps @@ -1032,6 +1069,26 @@ func (m *Manager) SendMessage(ctx context.Context, msg bus.OutboundMessage) erro return nil } +// SendMedia sends outbound media synchronously through the channel worker's +// rate limiter and retry logic. 
It blocks until the media is delivered (or all +// retries are exhausted), which preserves ordering when later agent behavior +// depends on actual media delivery. +func (m *Manager) SendMedia(ctx context.Context, msg bus.OutboundMediaMessage) error { + m.mu.RLock() + _, exists := m.channels[msg.Channel] + w, wExists := m.workers[msg.Channel] + m.mu.RUnlock() + + if !exists { + return fmt.Errorf("channel %s not found", msg.Channel) + } + if !wExists || w == nil { + return fmt.Errorf("channel %s has no active worker", msg.Channel) + } + + return m.sendMediaWithRetry(ctx, msg.Channel, w, msg) +} + func (m *Manager) SendToChannel(ctx context.Context, channelName, chatID, content string) error { m.mu.RLock() _, exists := m.channels[channelName] diff --git a/pkg/channels/manager_channel.go b/pkg/channels/manager_channel.go index 86572e336..163218b75 100644 --- a/pkg/channels/manager_channel.go +++ b/pkg/channels/manager_channel.go @@ -49,15 +49,7 @@ func hiddenValues(key string, value map[string]any, ch config.ChannelsConfig) { value["token"] = ch.LINE.ChannelAccessToken() value["secret"] = ch.LINE.ChannelSecret() case "wecom": - value["token"] = ch.WeCom.Token() - value["key"] = ch.WeCom.EncodingAESKey() - case "wecom_app": - value["token"] = ch.WeComApp.Token() - value["secret"] = ch.WeComApp.CorpSecret() - case "wecom_aibot": - value["token"] = ch.WeComAIBot.Token() - value["key"] = ch.WeComAIBot.EncodingAESKey() - value["secret"] = ch.WeComAIBot.Secret() + value["secret"] = ch.WeCom.Secret() case "dingtalk": value["secret"] = ch.QQ.AppSecret() case "qq": @@ -156,16 +148,7 @@ func updateKeys(newcfg, old *config.ChannelsConfig) { newcfg.LINE.SetChannelSecret(old.LINE.ChannelSecret()) } if newcfg.WeCom.Enabled { - newcfg.WeCom.SetToken(old.WeCom.Token()) - newcfg.WeCom.SetEncodingAESKey(old.WeCom.EncodingAESKey()) - } - if newcfg.WeComApp.Enabled { - newcfg.WeComApp.SetToken(old.WeComApp.Token()) - newcfg.WeComApp.SetCorpSecret(old.WeComApp.CorpSecret()) - } - if 
newcfg.WeComAIBot.Enabled { - newcfg.WeComAIBot.SetToken(old.WeComAIBot.Token()) - newcfg.WeComAIBot.SetEncodingAESKey(old.WeComAIBot.EncodingAESKey()) + newcfg.WeCom.SetSecret(old.WeCom.Secret()) } if newcfg.DingTalk.Enabled { newcfg.DingTalk.SetClientSecret(old.DingTalk.ClientSecret()) diff --git a/pkg/channels/manager_test.go b/pkg/channels/manager_test.go index 7dfec9ebf..b4fd2ba3d 100644 --- a/pkg/channels/manager_test.go +++ b/pkg/channels/manager_test.go @@ -4,6 +4,7 @@ import ( "context" "errors" "fmt" + "strings" "sync" "sync/atomic" "testing" @@ -43,6 +44,40 @@ func (m *mockChannel) EditMessage(ctx context.Context, chatID, messageID, conten return nil } +type mockMediaChannel struct { + mockChannel + sendMediaFn func(ctx context.Context, msg bus.OutboundMediaMessage) error + sentMediaMessages []bus.OutboundMediaMessage +} + +func (m *mockMediaChannel) SendMedia(ctx context.Context, msg bus.OutboundMediaMessage) error { + m.sentMediaMessages = append(m.sentMediaMessages, msg) + if m.sendMediaFn != nil { + return m.sendMediaFn(ctx, msg) + } + return nil +} + +type mockDeletingMediaChannel struct { + mockMediaChannel + deleteCalls int + lastDeleted struct { + chatID string + messageID string + } +} + +func (m *mockDeletingMediaChannel) DeleteMessage( + _ context.Context, + chatID string, + messageID string, +) error { + m.deleteCalls++ + m.lastDeleted.chatID = chatID + m.lastDeleted.messageID = messageID + return nil +} + // newTestManager creates a minimal Manager suitable for unit tests. 
func newTestManager() *Manager { return &Manager{ @@ -208,6 +243,125 @@ func TestSendWithRetry_MaxRetriesExhausted(t *testing.T) { } } +func TestSendMedia_Success(t *testing.T) { + m := newTestManager() + var callCount int + ch := &mockMediaChannel{ + sendMediaFn: func(_ context.Context, _ bus.OutboundMediaMessage) error { + callCount++ + return nil + }, + } + w := &channelWorker{ + ch: ch, + limiter: rate.NewLimiter(rate.Inf, 1), + } + m.channels["test"] = ch + m.workers["test"] = w + + err := m.SendMedia(context.Background(), bus.OutboundMediaMessage{ + Channel: "test", + ChatID: "chat1", + Parts: []bus.MediaPart{{Ref: "media://abc"}}, + }) + if err != nil { + t.Fatalf("SendMedia() error = %v", err) + } + if callCount != 1 { + t.Fatalf("expected 1 SendMedia call, got %d", callCount) + } +} + +func TestSendMedia_PropagatesFailure(t *testing.T) { + m := newTestManager() + ch := &mockMediaChannel{ + sendMediaFn: func(_ context.Context, _ bus.OutboundMediaMessage) error { + return fmt.Errorf("bad upload: %w", ErrSendFailed) + }, + } + w := &channelWorker{ + ch: ch, + limiter: rate.NewLimiter(rate.Inf, 1), + } + m.channels["test"] = ch + m.workers["test"] = w + + err := m.SendMedia(context.Background(), bus.OutboundMediaMessage{ + Channel: "test", + ChatID: "chat1", + Parts: []bus.MediaPart{{Ref: "media://abc"}}, + }) + if err == nil { + t.Fatal("expected SendMedia to return error") + } + if !errors.Is(err, ErrSendFailed) { + t.Fatalf("expected ErrSendFailed, got %v", err) + } +} + +func TestSendMedia_UnsupportedChannelReturnsError(t *testing.T) { + m := newTestManager() + ch := &mockChannel{ + sendFn: func(_ context.Context, _ bus.OutboundMessage) error { + return nil + }, + } + w := &channelWorker{ + ch: ch, + limiter: rate.NewLimiter(rate.Inf, 1), + } + m.channels["test"] = ch + m.workers["test"] = w + + err := m.SendMedia(context.Background(), bus.OutboundMediaMessage{ + Channel: "test", + ChatID: "chat1", + Parts: []bus.MediaPart{{Ref: "media://abc"}}, + }) + if 
err == nil { + t.Fatal("expected SendMedia to return error for unsupported channel") + } + if !strings.Contains(err.Error(), "does not support media sending") { + t.Fatalf("unexpected error: %v", err) + } +} + +func TestSendMedia_DeletesPlaceholderBeforeSending(t *testing.T) { + m := newTestManager() + ch := &mockDeletingMediaChannel{ + mockMediaChannel: mockMediaChannel{ + sendMediaFn: func(_ context.Context, _ bus.OutboundMediaMessage) error { + return nil + }, + }, + } + w := &channelWorker{ + ch: ch, + limiter: rate.NewLimiter(rate.Inf, 1), + } + m.channels["test"] = ch + m.workers["test"] = w + m.RecordPlaceholder("test", "chat1", "placeholder-1") + + err := m.SendMedia(context.Background(), bus.OutboundMediaMessage{ + Channel: "test", + ChatID: "chat1", + Parts: []bus.MediaPart{{Ref: "media://abc"}}, + }) + if err != nil { + t.Fatalf("SendMedia() error = %v", err) + } + if ch.deleteCalls != 1 { + t.Fatalf("expected placeholder delete to be called once, got %d", ch.deleteCalls) + } + if ch.lastDeleted.chatID != "chat1" || ch.lastDeleted.messageID != "placeholder-1" { + t.Fatalf("unexpected placeholder deletion target: %+v", ch.lastDeleted) + } + if len(ch.sentMediaMessages) != 1 { + t.Fatalf("expected media to be sent once, got %d", len(ch.sentMediaMessages)) + } +} + func TestSendWithRetry_UnknownError(t *testing.T) { m := newTestManager() var callCount int diff --git a/pkg/channels/matrix/init.go b/pkg/channels/matrix/init.go index 6677f855e..4d6ad45a7 100644 --- a/pkg/channels/matrix/init.go +++ b/pkg/channels/matrix/init.go @@ -1,6 +1,8 @@ package matrix import ( + "path/filepath" + "github.com/sipeed/picoclaw/pkg/bus" "github.com/sipeed/picoclaw/pkg/channels" "github.com/sipeed/picoclaw/pkg/config" @@ -8,6 +10,11 @@ import ( func init() { channels.RegisterFactory("matrix", func(cfg *config.Config, b *bus.MessageBus) (channels.Channel, error) { - return NewMatrixChannel(cfg.Channels.Matrix, b) + matrixCfg := cfg.Channels.Matrix + cryptoDatabasePath := 
matrixCfg.CryptoDatabasePath + if cryptoDatabasePath == "" { + cryptoDatabasePath = filepath.Join(cfg.WorkspacePath(), "matrix") + } + return NewMatrixChannel(matrixCfg, b, cryptoDatabasePath) }) } diff --git a/pkg/channels/matrix/matrix.go b/pkg/channels/matrix/matrix.go index 98c607d0b..50b86158d 100644 --- a/pkg/channels/matrix/matrix.go +++ b/pkg/channels/matrix/matrix.go @@ -2,6 +2,7 @@ package matrix import ( "context" + "database/sql" "fmt" "html" "io" @@ -17,9 +18,12 @@ import ( "github.com/gomarkdown/markdown" mdhtml "github.com/gomarkdown/markdown/html" "github.com/gomarkdown/markdown/parser" + "go.mau.fi/util/dbutil" "maunium.net/go/mautrix" + "maunium.net/go/mautrix/crypto/cryptohelper" "maunium.net/go/mautrix/event" "maunium.net/go/mautrix/id" + _ "modernc.org/sqlite" "github.com/sipeed/picoclaw/pkg/bus" "github.com/sipeed/picoclaw/pkg/channels" @@ -30,6 +34,9 @@ import ( ) const ( + sqliteDriver = "sqlite" + dbName = "store.db" + typingRefreshInterval = 20 * time.Second typingServerTTL = 30 * time.Second roomKindCacheTTL = 5 * time.Minute @@ -181,9 +188,16 @@ type MatrixChannel struct { roomKindCache *roomKindCache localpartMentionR *regexp.Regexp + + cryptoHelper *cryptohelper.CryptoHelper + cryptoDbPath string } -func NewMatrixChannel(cfg config.MatrixConfig, messageBus *bus.MessageBus) (*MatrixChannel, error) { +func NewMatrixChannel( + cfg config.MatrixConfig, + messageBus *bus.MessageBus, + cryptoDatabasePath string, +) (*MatrixChannel, error) { homeserver := strings.TrimSpace(cfg.Homeserver) userID := strings.TrimSpace(cfg.UserID) accessToken := strings.TrimSpace(cfg.AccessToken()) @@ -230,6 +244,7 @@ func NewMatrixChannel(cfg config.MatrixConfig, messageBus *bus.MessageBus) (*Mat roomKindCache: newRoomKindCache(roomKindCacheMaxEntries, roomKindCacheTTL), localpartMentionR: localpartMentionRegexp(matrixLocalpart(client.UserID)), typingMu: sync.Mutex{}, + cryptoDbPath: cryptoDatabasePath, }, nil } @@ -239,7 +254,21 @@ func (c *MatrixChannel) 
Start(ctx context.Context) error { c.ctx, c.cancel = context.WithCancel(ctx) c.startTime = time.Now() + // Initialize crypto helper if database and passphrase are configured + if c.cryptoDbPath != "" && c.config.CryptoPassphrase != "" { + if err := c.initCrypto(ctx); err != nil { + logger.WarnCF( + "matrix", + "Failed to initialize crypto, continuing without encryption support", + map[string]any{ + "error": err.Error(), + }, + ) + } + } + c.syncer.OnEventType(event.EventMessage, c.handleMessageEvent) + c.syncer.OnEventType(event.EventEncrypted, c.handleMessageEvent) c.syncer.OnEventType(event.StateMember, c.handleMemberEvent) c.SetRunning(true) @@ -266,10 +295,84 @@ func (c *MatrixChannel) Stop(ctx context.Context) error { } c.stopTypingSessions(ctx) + // Close crypto helper if initialized + if c.cryptoHelper != nil { + c.cryptoHelper.Close() + c.cryptoHelper = nil + c.client.Crypto = nil + } + logger.InfoC("matrix", "Matrix channel stopped") return nil } +func (c *MatrixChannel) initCrypto(ctx context.Context) error { + logger.InfoC("matrix", "Initializing crypto helper") + + // Ensure the crypto database directory exists + if err := os.MkdirAll(c.cryptoDbPath, 0o700); err != nil { + return fmt.Errorf("create crypto database directory: %w", err) + } + + // Create database with sqlite driver (modernc.org/sqlite) + dbPath := filepath.Join(c.cryptoDbPath, dbName) + connStr := "file:" + dbPath + "?_foreign_keys=on" + + db, err := sql.Open(sqliteDriver, connStr) + if err != nil { + return fmt.Errorf("open crypto database: %w", err) + } + db.SetMaxOpenConns(1) + db.SetMaxIdleConns(1) + + // Execute PRAGMA statements + // This is equivalent to the "sqlite3-fk-wal" dialect used by cryptohelper + pragmaStmts := []string{ + "PRAGMA foreign_keys = ON", + "PRAGMA journal_mode = WAL", + "PRAGMA synchronous = NORMAL", + "PRAGMA busy_timeout = 5000", + } + for _, pragma := range pragmaStmts { + if _, err = db.ExecContext(ctx, pragma); err != nil { + _ = db.Close() + return 
fmt.Errorf("execute %s: %w", pragma, err) + } + } + + // Wrap with dbutil for dialect support + wrappedDB, err := dbutil.NewWithDB(db, sqliteDriver) + if err != nil { + _ = db.Close() + return fmt.Errorf("wrap database: %w", err) + } + + cryptoHelper, err := cryptohelper.NewCryptoHelper(c.client, []byte(c.config.CryptoPassphrase), wrappedDB) + if err != nil { + return fmt.Errorf("create crypto helper: %w", err) + } + + if c.client.DeviceID == "" { + resp, whoamiErr := c.client.Whoami(ctx) + if whoamiErr != nil { + _ = db.Close() + return fmt.Errorf("get device ID via whoami: %w", whoamiErr) + } + c.client.DeviceID = resp.DeviceID + } + + if err = cryptoHelper.Init(ctx); err != nil { + cryptoHelper.Close() + return fmt.Errorf("init crypto helper: %w", err) + } + + c.client.Crypto = cryptoHelper + c.cryptoHelper = cryptoHelper + + logger.InfoC("matrix", "Crypto helper initialized successfully") + return nil +} + func markdownToHTML(md string) string { p := parser.NewWithExtensions(parser.CommonExtensions | parser.AutoHeadingIDs) renderer := mdhtml.NewRenderer(mdhtml.RendererOptions{Flags: mdhtml.CommonFlags}) @@ -548,9 +651,26 @@ func (c *MatrixChannel) handleMessageEvent(ctx context.Context, evt *event.Event return } - msgEvt := evt.Content.AsMessage() - if msgEvt == nil { - return + var msgEvt *event.MessageEventContent + switch evt.Type { + case event.EventMessage: + // When crypto is enabled, events marked WasEncrypted=true are + // re-dispatched by c.cryptoHelper after decryption and will be + // processed again in the EventEncrypted branch. Skip to avoid duplication. + if c.client.Crypto != nil && evt.Mautrix.WasEncrypted { + return + } + + msgEvt = evt.Content.AsMessage() + if msgEvt == nil || msgEvt.MsgType == "" { + return + } + case event.EventEncrypted: + var ok bool + msgEvt, ok = c.decryptEvent(ctx, evt) + if !ok { + return + } } // Ignore edits. 
@@ -642,6 +762,36 @@ func (c *MatrixChannel) handleMessageEvent(ctx context.Context, evt *event.Event ) } +// decryptEvent decrypts an encrypted event and returns the decrypted message event content. +// It returns the decrypted content and a boolean indicating whether decryption was successful. +func (c *MatrixChannel) decryptEvent(ctx context.Context, evt *event.Event) (*event.MessageEventContent, bool) { + if c.client.Crypto == nil { + logger.DebugCF("matrix", "Received encrypted message but crypto is not enabled", map[string]any{ + "room_id": evt.RoomID.String(), + }) + return nil, false + } + + decrypted, err := c.client.Crypto.Decrypt(ctx, evt) + if err != nil { + logger.WarnCF("matrix", "Failed to decrypt message", map[string]any{ + "room_id": evt.RoomID.String(), + "error": err.Error(), + }) + return nil, false + } + + if decrypted.Type != event.EventMessage { + logger.DebugCF("matrix", "Decrypted event is not a message event", map[string]any{ + "room_id": evt.RoomID.String(), + "type": decrypted.Type.String(), + }) + return nil, false + } + + return decrypted.Content.AsMessage(), true +} + func (c *MatrixChannel) extractInboundContent( ctx context.Context, msgEvt *event.MessageEventContent, diff --git a/pkg/channels/pico/pico.go b/pkg/channels/pico/pico.go index 86ce98b06..f3ba55a92 100644 --- a/pkg/channels/pico/pico.go +++ b/pkg/channels/pico/pico.go @@ -54,12 +54,13 @@ func (pc *picoConn) close() { // It serves as the reference implementation for all optional capability interfaces. 
type PicoChannel struct { *channels.BaseChannel - config config.PicoConfig - upgrader websocket.Upgrader - connections sync.Map // connID → *picoConn - connCount atomic.Int32 - ctx context.Context - cancel context.CancelFunc + config config.PicoConfig + upgrader websocket.Upgrader + connections map[string]*picoConn // connID -> *picoConn + sessionConnections map[string]map[string]*picoConn // sessionID -> connID -> *picoConn + connsMu sync.RWMutex + ctx context.Context + cancel context.CancelFunc } // NewPicoChannel creates a new Pico Protocol channel. @@ -92,9 +93,104 @@ func NewPicoChannel(cfg config.PicoConfig, messageBus *bus.MessageBus) (*PicoCha ReadBufferSize: 1024, WriteBufferSize: 1024, }, + connections: make(map[string]*picoConn), + sessionConnections: make(map[string]map[string]*picoConn), }, nil } +// createAndAddConnection checks MaxConnections and registers a connection atomically. +func (c *PicoChannel) createAndAddConnection(conn *websocket.Conn, sessionID string, maxConns int) (*picoConn, error) { + c.connsMu.Lock() + defer c.connsMu.Unlock() + if len(c.connections) >= maxConns { + return nil, channels.ErrTemporary + } + + var connID string + for { + connID = uuid.New().String() + if _, exists := c.connections[connID]; !exists { + break + } + } + + pc := &picoConn{ + id: connID, + conn: conn, + sessionID: sessionID, + } + + c.connections[pc.id] = pc + bySession, ok := c.sessionConnections[pc.sessionID] + if !ok { + bySession = make(map[string]*picoConn) + c.sessionConnections[pc.sessionID] = bySession + } + bySession[pc.id] = pc + + return pc, nil +} + +// removeConnection deletes a connection from indexes and returns it when found. 
+func (c *PicoChannel) removeConnection(connID string) *picoConn { + c.connsMu.Lock() + defer c.connsMu.Unlock() + + pc, ok := c.connections[connID] + if !ok { + return nil + } + + delete(c.connections, connID) + if bySession, ok := c.sessionConnections[pc.sessionID]; ok { + delete(bySession, connID) + if len(bySession) == 0 { + delete(c.sessionConnections, pc.sessionID) + } + } + + return pc +} + +// takeAllConnections snapshots and clears all connection indexes. +func (c *PicoChannel) takeAllConnections() []*picoConn { + c.connsMu.Lock() + defer c.connsMu.Unlock() + + all := make([]*picoConn, 0, len(c.connections)) + for _, pc := range c.connections { + all = append(all, pc) + } + clear(c.connections) + clear(c.sessionConnections) + + return all +} + +// sessionConnectionsSnapshot returns all active connections for a session. +func (c *PicoChannel) sessionConnectionsSnapshot(sessionID string) []*picoConn { + c.connsMu.RLock() + defer c.connsMu.RUnlock() + + bySession, ok := c.sessionConnections[sessionID] + if !ok || len(bySession) == 0 { + return nil + } + + conns := make([]*picoConn, 0, len(bySession)) + for _, pc := range bySession { + conns = append(conns, pc) + } + return conns +} + +// currentConnCount returns a lock-protected snapshot of active connection count. +func (c *PicoChannel) currentConnCount() int { + c.connsMu.RLock() + defer c.connsMu.RUnlock() + return len(c.connections) +} + // Start implements Channel. 
func (c *PicoChannel) Start(ctx context.Context) error { logger.InfoC("pico", "Starting Pico Protocol channel") @@ -110,13 +206,9 @@ func (c *PicoChannel) Stop(ctx context.Context) error { c.SetRunning(false) // Close all connections - c.connections.Range(func(key, value any) bool { - if pc, ok := value.(*picoConn); ok { - pc.close() - } - c.connections.Delete(key) - return true - }) + for _, pc := range c.takeAllConnections() { + pc.close() + } if c.cancel != nil { c.cancel() @@ -133,8 +225,8 @@ func (c *PicoChannel) WebhookPath() string { return "/pico/" } func (c *PicoChannel) ServeHTTP(w http.ResponseWriter, r *http.Request) { path := strings.TrimPrefix(r.URL.Path, "/pico") - switch { - case path == "/ws" || path == "/ws/": + switch path { + case "/ws", "/ws/": c.handleWebSocket(w, r) default: http.NotFound(w, r) @@ -208,23 +300,16 @@ func (c *PicoChannel) broadcastToSession(chatID string, msg PicoMessage) error { msg.SessionID = sessionID var sent bool - c.connections.Range(func(key, value any) bool { - pc, ok := value.(*picoConn) - if !ok { - return true + for _, pc := range c.sessionConnectionsSnapshot(sessionID) { + if err := pc.writeJSON(msg); err != nil { + logger.DebugCF("pico", "Write to connection failed", map[string]any{ + "conn_id": pc.id, + "error": err.Error(), + }) + } else { + sent = true } - if pc.sessionID == sessionID { - if err := pc.writeJSON(msg); err != nil { - logger.DebugCF("pico", "Write to connection failed", map[string]any{ - "conn_id": pc.id, - "error": err.Error(), - }) - } else { - sent = true - } - } - return true - }) + } if !sent { return fmt.Errorf("no active connections for session %s: %w", sessionID, channels.ErrSendFailed) @@ -250,7 +335,7 @@ func (c *PicoChannel) handleWebSocket(w http.ResponseWriter, r *http.Request) { if maxConns <= 0 { maxConns = 100 } - if int(c.connCount.Load()) >= maxConns { + if c.currentConnCount() >= maxConns { http.Error(w, "too many connections", http.StatusServiceUnavailable) return } @@ -275,15 
+360,17 @@ func (c *PicoChannel) handleWebSocket(w http.ResponseWriter, r *http.Request) { sessionID = uuid.New().String() } - pc := &picoConn{ - id: uuid.New().String(), - conn: conn, - sessionID: sessionID, + pc, err := c.createAndAddConnection(conn, sessionID, maxConns) + if err != nil { + _ = conn.WriteControl( + websocket.CloseMessage, + websocket.FormatCloseMessage(websocket.CloseTryAgainLater, "too many connections"), + time.Now().Add(2*time.Second), + ) + _ = conn.Close() + return } - c.connections.Store(pc.id, pc) - c.connCount.Add(1) - logger.InfoCF("pico", "WebSocket client connected", map[string]any{ "conn_id": pc.id, "session_id": sessionID, @@ -341,12 +428,12 @@ func (c *PicoChannel) matchedSubprotocol(r *http.Request) string { func (c *PicoChannel) readLoop(pc *picoConn) { defer func() { pc.close() - c.connections.Delete(pc.id) - c.connCount.Add(-1) - logger.InfoCF("pico", "WebSocket client disconnected", map[string]any{ - "conn_id": pc.id, - "session_id": pc.sessionID, - }) + if removed := c.removeConnection(pc.id); removed != nil { + logger.InfoCF("pico", "WebSocket client disconnected", map[string]any{ + "conn_id": removed.id, + "session_id": removed.sessionID, + }) + } }() readTimeout := time.Duration(c.config.ReadTimeout) * time.Second diff --git a/pkg/channels/pico/pico_test.go b/pkg/channels/pico/pico_test.go new file mode 100644 index 000000000..e712767ad --- /dev/null +++ b/pkg/channels/pico/pico_test.go @@ -0,0 +1,144 @@ +package pico + +import ( + "context" + "errors" + "fmt" + "sync" + "testing" + + "github.com/sipeed/picoclaw/pkg/bus" + "github.com/sipeed/picoclaw/pkg/channels" + "github.com/sipeed/picoclaw/pkg/config" +) + +func newTestPicoChannel(t *testing.T) *PicoChannel { + t.Helper() + + cfg := config.PicoConfig{} + cfg.SetToken("test-token") + ch, err := NewPicoChannel(cfg, bus.NewMessageBus()) + if err != nil { + t.Fatalf("NewPicoChannel: %v", err) + } + + ch.ctx = context.Background() + return ch +} + +func 
TestCreateAndAddConnection_RespectsMaxConnectionsConcurrently(t *testing.T) { + ch := newTestPicoChannel(t) + + const ( + maxConns = 5 + goroutines = 64 + sessionID = "session-a" + ) + + var wg sync.WaitGroup + var mu sync.Mutex + successCount := 0 + errCount := 0 + + wg.Add(goroutines) + for i := 0; i < goroutines; i++ { + go func() { + defer wg.Done() + + pc, err := ch.createAndAddConnection(nil, sessionID, maxConns) + mu.Lock() + defer mu.Unlock() + + if err == nil { + successCount++ + if pc == nil { + t.Errorf("pc is nil on success") + } + return + } + if !errors.Is(err, channels.ErrTemporary) { + t.Errorf("unexpected error: %v", err) + return + } + errCount++ + }() + } + wg.Wait() + + if successCount > maxConns { + t.Fatalf("successCount=%d > maxConns=%d", successCount, maxConns) + } + if successCount+errCount != goroutines { + t.Fatalf("success=%d err=%d total=%d want=%d", successCount, errCount, successCount+errCount, goroutines) + } + if got := ch.currentConnCount(); got != maxConns { + t.Fatalf("currentConnCount=%d want=%d", got, maxConns) + } +} + +func TestRemoveConnection_CleansBothIndexes(t *testing.T) { + ch := newTestPicoChannel(t) + + pc, err := ch.createAndAddConnection(nil, "session-cleanup", 10) + if err != nil { + t.Fatalf("createAndAddConnection: %v", err) + } + + removed := ch.removeConnection(pc.id) + if removed == nil { + t.Fatal("removeConnection returned nil") + } + + ch.connsMu.RLock() + defer ch.connsMu.RUnlock() + + if _, ok := ch.connections[pc.id]; ok { + t.Fatalf("connID %s still exists in connections", pc.id) + } + if _, ok := ch.sessionConnections[pc.sessionID]; ok { + t.Fatalf("session %s still exists in sessionConnections", pc.sessionID) + } + if got := len(ch.connections); got != 0 { + t.Fatalf("len(connections)=%d want=0", got) + } +} + +func TestBroadcastToSession_TargetsOnlyRequestedSession(t *testing.T) { + ch := newTestPicoChannel(t) + + target := &picoConn{id: "target", sessionID: "s-target"} + target.closed.Store(true) + 
ch.addConnForTest(target) + + other := &picoConn{id: "other", sessionID: "s-other"} + ch.addConnForTest(other) + + err := ch.broadcastToSession("pico:s-target", newMessage(TypeMessageCreate, map[string]any{"content": "hello"})) + if err == nil { + t.Fatal("expected send failure due to closed target connection") + } + if !errors.Is(err, channels.ErrSendFailed) { + t.Fatalf("expected ErrSendFailed, got %v", err) + } +} + +func (c *PicoChannel) addConnForTest(pc *picoConn) { + c.connsMu.Lock() + defer c.connsMu.Unlock() + if c.connections == nil { + c.connections = make(map[string]*picoConn) + } + if c.sessionConnections == nil { + c.sessionConnections = make(map[string]map[string]*picoConn) + } + if _, exists := c.connections[pc.id]; exists { + panic(fmt.Sprintf("duplicate conn id in test: %s", pc.id)) + } + c.connections[pc.id] = pc + bySession, ok := c.sessionConnections[pc.sessionID] + if !ok { + bySession = make(map[string]*picoConn) + c.sessionConnections[pc.sessionID] = bySession + } + bySession[pc.id] = pc +} diff --git a/pkg/channels/telegram/telegram.go b/pkg/channels/telegram/telegram.go index d0011d21b..e7da1d615 100644 --- a/pkg/channels/telegram/telegram.go +++ b/pkg/channels/telegram/telegram.go @@ -642,8 +642,12 @@ func (c *TelegramChannel) handleMessage(ctx context.Context, message *telego.Mes } } + if content == "" && len(mediaPaths) == 0 { + return nil + } + if content == "" { - content = "[empty message]" + content = "[media only]" } // In group chats, apply unified group trigger filtering diff --git a/pkg/channels/telegram/telegram_test.go b/pkg/channels/telegram/telegram_test.go index 6bf1077af..fd189d9a7 100644 --- a/pkg/channels/telegram/telegram_test.go +++ b/pkg/channels/telegram/telegram_test.go @@ -641,3 +641,35 @@ func TestHandleMessage_ReplyThread_NonForum_NoIsolation(t *testing.T) { assert.Empty(t, inbound.Metadata["parent_peer_kind"]) assert.Empty(t, inbound.Metadata["parent_peer_id"]) } + +func TestHandleMessage_EmptyContent_Ignored(t 
*testing.T) { + messageBus := bus.NewMessageBus() + ch := &TelegramChannel{ + BaseChannel: channels.NewBaseChannel("telegram", nil, messageBus, nil), + chatIDs: make(map[string]int64), + ctx: context.Background(), + } + + // Service message with no text/caption/media (like ForumTopicCreated) + msg := &telego.Message{ + MessageID: 123, + Chat: telego.Chat{ + ID: 456, + Type: "group", + }, + From: &telego.User{ + ID: 789, + FirstName: "User", + }, + } + + err := ch.handleMessage(context.Background(), msg) + require.NoError(t, err) + + // Should NOT publish to message bus + select { + case <-messageBus.InboundChan(): + t.Fatal("Empty message should not be published to message bus") + default: + } +} diff --git a/pkg/channels/wecom/aibot.go b/pkg/channels/wecom/aibot.go deleted file mode 100644 index c5e148185..000000000 --- a/pkg/channels/wecom/aibot.go +++ /dev/null @@ -1,1099 +0,0 @@ -package wecom - -import ( - "bytes" - "context" - "crypto/rand" - "encoding/base64" - "encoding/json" - "fmt" - "io" - "math/big" - "net/http" - "strings" - "sync" - "time" - - "github.com/sipeed/picoclaw/pkg/bus" - "github.com/sipeed/picoclaw/pkg/channels" - "github.com/sipeed/picoclaw/pkg/config" - "github.com/sipeed/picoclaw/pkg/identity" - "github.com/sipeed/picoclaw/pkg/logger" - "github.com/sipeed/picoclaw/pkg/utils" -) - -// responseURLHTTPClient is a shared HTTP client for posting to WeCom response_url. -// Reusing it enables connection pooling across replies. 
-var responseURLHTTPClient = &http.Client{Timeout: 15 * time.Second} - -// WeComAIBotChannel implements the Channel interface for WeCom AI Bot (企业微信智能机器人) -type WeComAIBotChannel struct { - *channels.BaseChannel - config config.WeComAIBotConfig - ctx context.Context - cancel context.CancelFunc - streamTasks map[string]*streamTask // streamID -> task (for poll lookups) - chatTasks map[string][]*streamTask // chatID -> in-flight tasks queue (FIFO) - taskMu sync.RWMutex -} - -// streamTask represents a streaming task for AI Bot. -// -// Mutable fields (Finished, StreamClosed, StreamClosedAt) must be read/written -// while holding WeComAIBotChannel.taskMu. Immutable fields (StreamID, ChatID, -// ResponseURL, Question, CreatedTime, Deadline, answerCh, ctx, cancel) are set -// once at creation and never modified, so they are safe to read without a lock. -type streamTask struct { - // immutable after creation - StreamID string - ChatID string // used by Send() to find this task - ResponseURL string // temporary URL for proactive reply (valid 1 hour, use once) - Question string - CreatedTime time.Time - Deadline time.Time // ~30s, we close the stream here and switch to response_url - answerCh chan string // receives agent reply from Send() - ctx context.Context // canceled when task is removed; used to interrupt the agent goroutine - cancel context.CancelFunc // call on task removal to cancel ctx - - // mutable — guarded by WeComAIBotChannel.taskMu - StreamClosed bool // stream returned finish:true; waiting for agent to reply via response_url - StreamClosedAt time.Time // set when StreamClosed becomes true; used for accelerated cleanup - Finished bool // fully done -} - -// WeComAIBotMessage represents the decrypted JSON message from WeCom AI Bot -// Ref: https://developer.work.weixin.qq.com/document/path/100719 -type WeComAIBotMessage struct { - MsgID string `json:"msgid"` - AIBotID string `json:"aibotid"` - ChatID string `json:"chatid"` // only for group chat - ChatType 
string `json:"chattype"` // "single" or "group" - From struct { - UserID string `json:"userid"` - } `json:"from"` - ResponseURL string `json:"response_url"` // temporary URL for proactive reply - MsgType string `json:"msgtype"` - // text message - Text *struct { - Content string `json:"content"` - } `json:"text,omitempty"` - // stream polling refresh - Stream *struct { - ID string `json:"id"` - } `json:"stream,omitempty"` - // image message - Image *struct { - URL string `json:"url"` - } `json:"image,omitempty"` - // mixed message (text + image) - Mixed *struct { - MsgItem []struct { - MsgType string `json:"msgtype"` - Text *struct { - Content string `json:"content"` - } `json:"text,omitempty"` - Image *struct { - URL string `json:"url"` - } `json:"image,omitempty"` - } `json:"msg_item"` - } `json:"mixed,omitempty"` - // event field - Event *struct { - EventType string `json:"eventtype"` - } `json:"event,omitempty"` -} - -// WeComAIBotMsgItemImage holds the image payload inside a stream message item. -type WeComAIBotMsgItemImage struct { - Base64 string `json:"base64"` - MD5 string `json:"md5"` -} - -// WeComAIBotMsgItem is a single item inside a stream's msg_item list. -type WeComAIBotMsgItem struct { - MsgType string `json:"msgtype"` - Image *WeComAIBotMsgItemImage `json:"image,omitempty"` -} - -// WeComAIBotStreamInfo represents the detailed stream content in streaming responses. 
-type WeComAIBotStreamInfo struct { - ID string `json:"id"` - Finish bool `json:"finish"` - Content string `json:"content,omitempty"` - MsgItem []WeComAIBotMsgItem `json:"msg_item,omitempty"` -} - -// WeComAIBotStreamResponse represents the streaming response format -type WeComAIBotStreamResponse struct { - MsgType string `json:"msgtype"` - Stream WeComAIBotStreamInfo `json:"stream"` -} - -// WeComAIBotEncryptedResponse represents the encrypted response wrapper -// Fields match WXBizJsonMsgCrypt.generate() in Python SDK -type WeComAIBotEncryptedResponse struct { - Encrypt string `json:"encrypt"` - MsgSignature string `json:"msgsignature"` - Timestamp string `json:"timestamp"` - Nonce string `json:"nonce"` -} - -// NewWeComAIBotChannel creates a WeCom AI Bot channel instance. -// If cfg.BotID and cfg.secret are both set, it returns a WeComAIBotWSChannel -// using the WebSocket long-connection API. -// Otherwise it returns the webhook-mode WeComAIBotChannel (requires Token + -// EncodingAESKey). -func NewWeComAIBotChannel( - cfg config.WeComAIBotConfig, - messageBus *bus.MessageBus, -) (channels.Channel, error) { - // WebSocket long-connection mode takes priority when BotID + secret are set. - if cfg.BotID != "" && cfg.Secret() != "" { - logger.InfoC("wecom_aibot", "BotID and secret provided, using WebSocket mode") - return newWeComAIBotWSChannel(cfg, messageBus) - } - // Webhook (short-connection) mode. 
- if cfg.Token() == "" || cfg.EncodingAESKey() == "" { - return nil, fmt.Errorf( - "WeCom AI Bot requires either (bot_id + secret) for WebSocket mode " + - "or (token + encoding_aes_key) for webhook mode") - } - if cfg.ProcessingMessage == "" { - cfg.ProcessingMessage = config.DefaultWeComAIBotProcessingMessage - } - - base := channels.NewBaseChannel("wecom_aibot", cfg, messageBus, cfg.AllowFrom, - channels.WithMaxMessageLength(2048), - channels.WithReasoningChannelID(cfg.ReasoningChannelID), - ) - - return &WeComAIBotChannel{ - BaseChannel: base, - config: cfg, - streamTasks: make(map[string]*streamTask), - chatTasks: make(map[string][]*streamTask), - }, nil -} - -// Name returns the channel name -func (c *WeComAIBotChannel) Name() string { - return "wecom_aibot" -} - -// Start initializes the WeCom AI Bot channel -func (c *WeComAIBotChannel) Start(ctx context.Context) error { - logger.InfoC("wecom_aibot", "Starting WeCom AI Bot channel...") - - c.ctx, c.cancel = context.WithCancel(ctx) - - // Start cleanup goroutine for old tasks - go c.cleanupLoop() - - c.SetRunning(true) - logger.InfoC("wecom_aibot", "WeCom AI Bot channel started") - - return nil -} - -// Stop gracefully stops the WeCom AI Bot channel -func (c *WeComAIBotChannel) Stop(ctx context.Context) error { - logger.InfoC("wecom_aibot", "Stopping WeCom AI Bot channel...") - - if c.cancel != nil { - c.cancel() - } - - c.SetRunning(false) - logger.InfoC("wecom_aibot", "WeCom AI Bot channel stopped") - return nil -} - -// Send delivers the agent reply into the active streamTask for msg.ChatID. -// It writes into the earliest unfinished task in the queue (FIFO per chatID). -// If the stream has already closed (deadline passed), it posts directly to response_url. 
-func (c *WeComAIBotChannel) Send(ctx context.Context, msg bus.OutboundMessage) error { - if !c.IsRunning() { - return channels.ErrNotRunning - } - c.taskMu.Lock() - queue := c.chatTasks[msg.ChatID] - // Only compact Finished tasks at the head of the queue. - // Tasks that are Finished in the middle are NOT removed here: doing a full - // scan on every Send() call would be O(n) and is unnecessary given that - // removeTask() always splices the task out of the queue immediately. - // Any Finished task left stranded in the middle (e.g. due to an unexpected - // code path) will be collected by cleanupOldTasks. - for len(queue) > 0 && queue[0].Finished { - queue = queue[1:] - } - c.chatTasks[msg.ChatID] = queue - var task *streamTask - var streamClosed bool - var responseURL string - if len(queue) > 0 { - task = queue[0] - // Read mutable fields while holding c.taskMu to avoid data races. - streamClosed = task.StreamClosed - responseURL = task.ResponseURL - } - c.taskMu.Unlock() - - if task == nil { - logger.DebugCF( - "wecom_aibot", - "Send: no active task for chat (may have timed out)", - map[string]any{ - "chat_id": msg.ChatID, - }, - ) - return nil - } - - if streamClosed { - // Stream already ended with a "please wait" notice; send the real reply via response_url. - // Note: task.StreamID and task.ChatID are immutable, safe to read without a lock. 
- logger.InfoCF("wecom_aibot", "Sending reply via response_url", map[string]any{ - "stream_id": task.StreamID, - "chat_id": msg.ChatID, - }) - if responseURL != "" { - if err := c.sendViaResponseURL(responseURL, msg.Content); err != nil { - logger.ErrorCF("wecom_aibot", "Failed to send via response_url", map[string]any{ - "error": err, - "stream_id": task.StreamID, - }) - c.removeTask(task) - return fmt.Errorf("response_url delivery failed: %w", channels.ErrSendFailed) - } - } else { - logger.WarnCF("wecom_aibot", "Stream closed but no response_url available", map[string]any{ - "stream_id": task.StreamID, - }) - } - c.removeTask(task) - return nil - } - - // Stream still open: deliver via answerCh for the next poll response. - select { - case task.answerCh <- msg.Content: - case <-task.ctx.Done(): - // Task was canceled (cleanup removed it); silently drop the reply. - return nil - case <-ctx.Done(): - return ctx.Err() - } - return nil -} - -// WebhookPath returns the path for registering on the shared HTTP server -func (c *WeComAIBotChannel) WebhookPath() string { - if c.config.WebhookPath == "" { - return "/webhook/wecom-aibot" - } - return c.config.WebhookPath -} - -// ServeHTTP implements http.Handler for the shared HTTP server -func (c *WeComAIBotChannel) ServeHTTP(w http.ResponseWriter, r *http.Request) { - c.handleWebhook(w, r) -} - -// HealthPath returns the health check endpoint path -func (c *WeComAIBotChannel) HealthPath() string { - return c.WebhookPath() + "/health" -} - -// HealthHandler handles health check requests -func (c *WeComAIBotChannel) HealthHandler(w http.ResponseWriter, r *http.Request) { - c.handleHealth(w, r) -} - -// handleWebhook handles incoming webhook requests from WeCom AI Bot -func (c *WeComAIBotChannel) handleWebhook(w http.ResponseWriter, r *http.Request) { - ctx := r.Context() - - // Log all incoming requests for debugging - logger.DebugCF("wecom_aibot", "Received webhook request", map[string]any{ - "method": r.Method, - "path": 
r.URL.Path, - "query": r.URL.RawQuery, - }) - - switch r.Method { - case http.MethodGet: - // URL verification - c.handleVerification(ctx, w, r) - case http.MethodPost: - // Message callback - c.handleMessageCallback(ctx, w, r) - default: - http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) - } -} - -// handleVerification handles the URL verification request from WeCom -func (c *WeComAIBotChannel) handleVerification( - ctx context.Context, - w http.ResponseWriter, - r *http.Request, -) { - msgSignature := r.URL.Query().Get("msg_signature") - timestamp := r.URL.Query().Get("timestamp") - nonce := r.URL.Query().Get("nonce") - echostr := r.URL.Query().Get("echostr") - - logger.DebugCF("wecom_aibot", "URL verification request", map[string]any{ - "msg_signature": msgSignature, - "timestamp": timestamp, - "nonce": nonce, - }) - - // Verify signature - if !verifySignature(c.config.Token(), msgSignature, timestamp, nonce, echostr) { - logger.ErrorC("wecom_aibot", "Signature verification failed") - http.Error(w, "Signature verification failed", http.StatusUnauthorized) - return - } - - // Decrypt echostr - // For WeCom AI Bot (智能机器人), receiveid should be empty string - decrypted, err := decryptMessageWithVerify(echostr, c.config.EncodingAESKey(), "") - if err != nil { - logger.ErrorCF("wecom_aibot", "Failed to decrypt echostr", map[string]any{ - "error": err, - }) - http.Error(w, "Decryption failed", http.StatusInternalServerError) - return - } - - // Remove BOM and whitespace as per WeCom documentation - decrypted = strings.TrimPrefix(decrypted, "\ufeff") - decrypted = strings.TrimSpace(decrypted) - - logger.InfoC("wecom_aibot", "URL verification successful") - w.Header().Set("Content-Type", "text/plain; charset=utf-8") - w.WriteHeader(http.StatusOK) - w.Write([]byte(decrypted)) -} - -// handleMessageCallback handles incoming messages from WeCom AI Bot -func (c *WeComAIBotChannel) handleMessageCallback( - ctx context.Context, - w http.ResponseWriter, - r 
*http.Request, -) { - msgSignature := r.URL.Query().Get("msg_signature") - timestamp := r.URL.Query().Get("timestamp") - nonce := r.URL.Query().Get("nonce") - - // Read request body (limit to 4 MB to prevent memory exhaustion) - const maxBodySize = 4 << 20 // 4 MB - body, err := io.ReadAll(io.LimitReader(r.Body, maxBodySize+1)) - if err != nil { - logger.ErrorCF("wecom_aibot", "Failed to read request body", map[string]any{ - "error": err, - }) - http.Error(w, "Failed to read body", http.StatusBadRequest) - return - } - if len(body) > maxBodySize { - http.Error(w, "Request body too large", http.StatusRequestEntityTooLarge) - return - } - - // Parse JSON body to get encrypted message - // Format: {"encrypt": "base64_encrypted_string"} - var encryptedMsg struct { - Encrypt string `json:"encrypt"` - } - if unmarshalErr := json.Unmarshal(body, &encryptedMsg); unmarshalErr != nil { - logger.ErrorCF("wecom_aibot", "Failed to parse JSON body", map[string]any{ - "error": unmarshalErr, - "body": string(body), - }) - http.Error(w, "Failed to parse JSON", http.StatusBadRequest) - return - } - - // Verify signature - if !verifySignature(c.config.Token(), msgSignature, timestamp, nonce, encryptedMsg.Encrypt) { - logger.ErrorC("wecom_aibot", "Signature verification failed") - http.Error(w, "Signature verification failed", http.StatusUnauthorized) - return - } - - // Decrypt message - // For WeCom AI Bot (智能机器人), receiveid is empty string - decrypted, err := decryptMessageWithVerify(encryptedMsg.Encrypt, c.config.EncodingAESKey(), "") - if err != nil { - logger.ErrorCF("wecom_aibot", "Failed to decrypt message", map[string]any{ - "error": err, - }) - http.Error(w, "Decryption failed", http.StatusInternalServerError) - return - } - - // Parse decrypted JSON message - var msg WeComAIBotMessage - if unmarshalErr := json.Unmarshal([]byte(decrypted), &msg); unmarshalErr != nil { - logger.ErrorCF("wecom_aibot", "Failed to parse decrypted JSON", map[string]any{ - "error": unmarshalErr, - 
"decrypted": decrypted, - }) - http.Error(w, "Failed to parse message", http.StatusInternalServerError) - return - } - - logger.DebugCF("wecom_aibot", "Decrypted message", map[string]any{ - "msgtype": msg.MsgType, - }) - - // Process the message and get streaming response - response := c.processMessage(ctx, msg, timestamp, nonce) - - // Check if response is empty (e.g. due to unsupported message type) - if response == "" { - response = c.encryptEmptyResponse(timestamp, nonce) - } - - // Return encrypted JSON response - w.Header().Set("Content-Type", "application/json; charset=utf-8") - w.WriteHeader(http.StatusOK) - w.Write([]byte(response)) -} - -// processMessage processes the received message and returns encrypted response -func (c *WeComAIBotChannel) processMessage( - ctx context.Context, - msg WeComAIBotMessage, - timestamp, nonce string, -) string { - logger.DebugCF("wecom_aibot", "Processing message", map[string]any{ - "msgtype": msg.MsgType, - }) - - switch msg.MsgType { - case "text": - return c.handleTextMessage(ctx, msg, timestamp, nonce) - case "stream": - return c.handleStreamMessage(ctx, msg, timestamp, nonce) - case "image": - return c.handleImageMessage(ctx, msg, timestamp, nonce) - case "mixed": - return c.handleMixedMessage(ctx, msg, timestamp, nonce) - case "event": - return c.handleEventMessage(ctx, msg, timestamp, nonce) - default: - logger.WarnCF("wecom_aibot", "Unsupported message type", map[string]any{ - "msgtype": msg.MsgType, - }) - return c.encryptResponse("", timestamp, nonce, WeComAIBotStreamResponse{ - MsgType: "stream", - Stream: WeComAIBotStreamInfo{ - ID: c.generateStreamID(), - Finish: true, - Content: "Unsupported message type: " + msg.MsgType, - }, - }) - } -} - -// handleTextMessage handles text messages by starting a new streaming task -func (c *WeComAIBotChannel) handleTextMessage( - ctx context.Context, - msg WeComAIBotMessage, - timestamp, nonce string, -) string { - if msg.Text == nil { - logger.ErrorC("wecom_aibot", "text 
message missing text field") - return c.encryptEmptyResponse(timestamp, nonce) - } - - content := msg.Text.Content - userID := msg.From.UserID - if userID == "" { - userID = "unknown" - } - - // chatID: group chat uses chatid, single chat uses userid - chatID := msg.ChatID - if chatID == "" { - chatID = userID - } - - streamID := c.generateStreamID() - - // WeCom stops sending stream-refresh callbacks after 6 minutes. - // Set a slightly shorter deadline so we can send a timeout notice before it gives up. - deadline := time.Now().Add(30 * time.Second) - - // Each task gets its own context derived from the channel lifetime context. - // Canceling taskCancel interrupts the agent goroutine when the task is removed. - taskCtx, taskCancel := context.WithCancel(c.ctx) - - task := &streamTask{ - StreamID: streamID, - ChatID: chatID, - ResponseURL: msg.ResponseURL, - Question: content, - CreatedTime: time.Now(), - Deadline: deadline, - Finished: false, - answerCh: make(chan string, 1), - ctx: taskCtx, - cancel: taskCancel, - } - - c.taskMu.Lock() - c.streamTasks[streamID] = task - c.chatTasks[chatID] = append(c.chatTasks[chatID], task) - c.taskMu.Unlock() - - // Publish to agent asynchronously; agent will call Send() with reply. - // Use task.ctx (not c.ctx) so the agent goroutine is canceled when the task is removed. 
- go func() { - sender := bus.SenderInfo{ - Platform: "wecom_aibot", - PlatformID: userID, - CanonicalID: identity.BuildCanonicalID("wecom_aibot", userID), - DisplayName: userID, - } - peerKind := "direct" - if msg.ChatType == "group" { - peerKind = "group" - } - peer := bus.Peer{Kind: peerKind, ID: chatID} - metadata := map[string]string{ - "channel": "wecom_aibot", - "chat_type": msg.ChatType, - "msg_type": "text", - "msgid": msg.MsgID, - "aibotid": msg.AIBotID, - "stream_id": streamID, - "response_url": msg.ResponseURL, - } - c.HandleMessage(task.ctx, peer, msg.MsgID, userID, chatID, - content, nil, metadata, sender) - }() - - // Return first streaming response immediately (finish=false, content empty) - return c.getStreamResponse(task, timestamp, nonce) -} - -// handleStreamMessage handles stream polling requests -func (c *WeComAIBotChannel) handleStreamMessage( - ctx context.Context, - msg WeComAIBotMessage, - timestamp, nonce string, -) string { - if msg.Stream == nil { - logger.ErrorC("wecom_aibot", "Stream message missing stream field") - return c.encryptEmptyResponse(timestamp, nonce) - } - - streamID := msg.Stream.ID - - c.taskMu.RLock() - task, exists := c.streamTasks[streamID] - c.taskMu.RUnlock() - - if !exists { - logger.DebugCF( - "wecom_aibot", - "Stream task not found (may be from previous session)", - map[string]any{ - "stream_id": streamID, - }, - ) - return c.encryptResponse(streamID, timestamp, nonce, WeComAIBotStreamResponse{ - MsgType: "stream", - Stream: WeComAIBotStreamInfo{ - ID: streamID, - Finish: true, - Content: "Task not found or already finished. 
Please resend your message to start a new session.", - }, - }) - } - - // Get next response - return c.getStreamResponse(task, timestamp, nonce) -} - -// handleImageMessage handles image messages -func (c *WeComAIBotChannel) handleImageMessage( - ctx context.Context, - msg WeComAIBotMessage, - timestamp, nonce string, -) string { - logger.WarnC("wecom_aibot", "Image message type not yet fully implemented") - if msg.Image == nil { - logger.ErrorC("wecom_aibot", "Image message missing image field") - return c.encryptEmptyResponse(timestamp, nonce) - } - - imageURL := msg.Image.URL - - // For now, just acknowledge receipt without echoing the image - return c.encryptResponse("", timestamp, nonce, WeComAIBotStreamResponse{ - MsgType: "stream", - Stream: WeComAIBotStreamInfo{ - ID: c.generateStreamID(), - Finish: true, - Content: fmt.Sprintf( - "Image received (URL: %s), but image messages are not yet supported", - imageURL, - ), - }, - }) -} - -// handleMixedMessage handles mixed (text + image) messages -func (c *WeComAIBotChannel) handleMixedMessage( - ctx context.Context, - msg WeComAIBotMessage, - timestamp, nonce string, -) string { - logger.WarnC("wecom_aibot", "Mixed message type not yet fully implemented") - return c.encryptResponse("", timestamp, nonce, WeComAIBotStreamResponse{ - MsgType: "stream", - Stream: WeComAIBotStreamInfo{ - ID: c.generateStreamID(), - Finish: true, - Content: "Mixed message type is not yet supported", - }, - }) -} - -// handleEventMessage handles event messages -func (c *WeComAIBotChannel) handleEventMessage( - ctx context.Context, - msg WeComAIBotMessage, - timestamp, nonce string, -) string { - eventType := "" - if msg.Event != nil { - eventType = msg.Event.EventType - } - logger.DebugCF("wecom_aibot", "Received event", map[string]any{ - "event_type": eventType, - }) - - // Send welcome message when user opens the chat window - if eventType == "enter_chat" && c.config.WelcomeMessage != "" { - streamID := c.generateStreamID() - return 
c.encryptResponse(streamID, timestamp, nonce, WeComAIBotStreamResponse{ - MsgType: "stream", - Stream: WeComAIBotStreamInfo{ - ID: streamID, - Finish: true, - Content: c.config.WelcomeMessage, - }, - }) - } - - return c.encryptEmptyResponse(timestamp, nonce) -} - -// getStreamResponse gets the next streaming response for a task. -// - If agent replied: return finish=true with the real answer. -// - If deadline passed: return finish=true with a "please wait" notice, keep task alive for response_url. -// - Otherwise: return finish=false (empty), client will poll again. -func (c *WeComAIBotChannel) getStreamResponse(task *streamTask, timestamp, nonce string) string { - var content string - var finish bool - var closeStreamOnly bool // close stream but do NOT remove task (response_url still pending) - - select { - case answer := <-task.answerCh: - // Agent replied before deadline — normal finish. - content = answer - finish = true - default: - if time.Now().After(task.Deadline) { - // Deadline reached: close the stream with a notice, then wait for agent via response_url. - content = c.config.ProcessingMessage - finish = true - closeStreamOnly = true - logger.InfoCF( - "wecom_aibot", - "Stream deadline reached, switching to response_url mode", - map[string]any{ - "stream_id": task.StreamID, - "chat_id": task.ChatID, - "response_url": task.ResponseURL != "", - }, - ) - } - // else: still waiting, return finish=false - } - - if finish && !closeStreamOnly { - // Normal finish: remove from all maps. - c.removeTask(task) - } else if closeStreamOnly { - // Mark stream as closed and remove from streamTasks under a single lock - // to keep StreamClosed/StreamClosedAt consistent with map membership. 
- c.taskMu.Lock() - task.StreamClosed = true - task.StreamClosedAt = time.Now() - delete(c.streamTasks, task.StreamID) - c.taskMu.Unlock() - } - - response := WeComAIBotStreamResponse{ - MsgType: "stream", - Stream: WeComAIBotStreamInfo{ - ID: task.StreamID, - Finish: finish, - Content: content, - }, - } - - return c.encryptResponse(task.StreamID, timestamp, nonce, response) -} - -// removeTask removes a task from both streamTasks and chatTasks, marks it finished, -// and cancels its context to interrupt the associated agent goroutine. -func (c *WeComAIBotChannel) removeTask(task *streamTask) { - // Cancel first so the agent goroutine stops as soon as possible, - // before we acquire the write lock. - task.cancel() - - c.taskMu.Lock() - task.Finished = true // written under c.taskMu, consistent with all readers - delete(c.streamTasks, task.StreamID) - queue := c.chatTasks[task.ChatID] - for i, t := range queue { - if t == task { - c.chatTasks[task.ChatID] = append(queue[:i], queue[i+1:]...) - break - } - } - if len(c.chatTasks[task.ChatID]) == 0 { - delete(c.chatTasks, task.ChatID) - } - c.taskMu.Unlock() -} - -// sendViaResponseURL posts a markdown reply to the WeCom response_url. -// response_url is valid for 1 hour and can only be used once per callback. -// Returned errors are wrapped with channels.ErrRateLimit, channels.ErrTemporary, -// or channels.ErrSendFailed so the manager can apply the right retry policy. 
-func (c *WeComAIBotChannel) sendViaResponseURL(responseURL, content string) error { - payload := map[string]any{ - "msgtype": "markdown", - "markdown": map[string]string{ - "content": content, - }, - } - body, err := json.Marshal(payload) - if err != nil { - return fmt.Errorf("failed to marshal payload: %w", err) - } - - ctx, cancel := context.WithTimeout(c.ctx, 15*time.Second) - defer cancel() - - req, err := http.NewRequestWithContext(ctx, http.MethodPost, responseURL, bytes.NewBuffer(body)) - if err != nil { - return fmt.Errorf("failed to create request: %w", err) - } - req.Header.Set("Content-Type", "application/json; charset=utf-8") - - resp, err := responseURLHTTPClient.Do(req) - if err != nil { - return fmt.Errorf("post to response_url failed: %w: %w", channels.ErrTemporary, err) - } - defer resp.Body.Close() - - if resp.StatusCode == http.StatusOK { - return nil - } - - const maxErrBody = 64 << 10 // 64 KB is more than enough for any error response - respBody, err := io.ReadAll(io.LimitReader(resp.Body, maxErrBody)) - if err != nil { - return fmt.Errorf("reading response_url body: %w: %w", channels.ErrTemporary, err) - } - switch { - case resp.StatusCode == http.StatusTooManyRequests: - return fmt.Errorf("response_url rate limited (%d): %s: %w", - resp.StatusCode, respBody, channels.ErrRateLimit) - case resp.StatusCode >= 500: - return fmt.Errorf("response_url server error (%d): %s: %w", - resp.StatusCode, respBody, channels.ErrTemporary) - default: - return fmt.Errorf("response_url returned %d: %s: %w", - resp.StatusCode, respBody, channels.ErrSendFailed) - } -} - -// encryptResponse encrypts a streaming response -func (c *WeComAIBotChannel) encryptResponse( - streamID, timestamp, nonce string, - response WeComAIBotStreamResponse, -) string { - // Marshal response to JSON - plaintext, err := json.Marshal(response) - if err != nil { - logger.ErrorCF("wecom_aibot", "Failed to marshal response", map[string]any{ - "error": err, - }) - return "" - } - - 
logger.DebugCF("wecom_aibot", "Encrypting response", map[string]any{ - "stream_id": streamID, - "finish": response.Stream.Finish, - "preview": utils.Truncate(response.Stream.Content, 100), - }) - - // Encrypt message - encrypted, err := c.encryptMessage(string(plaintext), "") - if err != nil { - logger.ErrorCF("wecom_aibot", "Failed to encrypt message", map[string]any{ - "error": err, - }) - return "" - } - - // Generate signature - signature := computeSignature(c.config.Token(), timestamp, nonce, encrypted) - - // Build encrypted response - encryptedResp := WeComAIBotEncryptedResponse{ - Encrypt: encrypted, - MsgSignature: signature, - Timestamp: timestamp, - Nonce: nonce, - } - - respJSON, err := json.Marshal(encryptedResp) - if err != nil { - logger.ErrorCF("wecom_aibot", "Failed to marshal encrypted response", map[string]any{ - "error": err, - }) - return "" - } - - logger.DebugCF("wecom_aibot", "Response encrypted", map[string]any{ - "stream_id": streamID, - }) - - return string(respJSON) -} - -// encryptEmptyResponse returns a minimal valid encrypted response -func (c *WeComAIBotChannel) encryptEmptyResponse(timestamp, nonce string) string { - // Construct a zero-value stream response and encrypt it so that - // WeCom always receives a syntactically valid encrypted JSON object. 
- emptyResp := WeComAIBotStreamResponse{} - return c.encryptResponse("", timestamp, nonce, emptyResp) -} - -// encryptMessage encrypts a plain text message for WeCom AI Bot -func (c *WeComAIBotChannel) encryptMessage(plaintext, receiveid string) (string, error) { - aesKey, err := decodeWeComAESKey(c.config.EncodingAESKey()) - if err != nil { - return "", err - } - - frame, err := packWeComFrame(plaintext, receiveid) - if err != nil { - return "", err - } - - // PKCS7 padding then AES-CBC encrypt - paddedFrame := pkcs7Pad(frame, blockSize) - ciphertext, err := encryptAESCBC(aesKey, paddedFrame) - if err != nil { - return "", err - } - - return base64.StdEncoding.EncodeToString(ciphertext), nil -} - -// func (c *WeComAIBotChannel) downloadAndDecryptImage( -// ctx context.Context, -// imageURL string, -// ) ([]byte, error) { -// // Download image -// req, err := http.NewRequestWithContext(ctx, http.MethodGet, imageURL, nil) -// if err != nil { -// return nil, fmt.Errorf("failed to create request: %w", err) -// } - -// client := &http.Client{ -// Timeout: 15 * time.Second, -// } - -// resp, err := client.Do(req) -// if err != nil { -// return nil, fmt.Errorf("failed to download image: %w", err) -// } -// defer resp.Body.Close() - -// if resp.StatusCode != http.StatusOK { -// return nil, fmt.Errorf("download failed with status: %d", resp.StatusCode) -// } - -// // Limit image download to 20 MB to prevent memory exhaustion -// const maxImageSize = 20 << 20 // 20 MB -// encryptedData, err := io.ReadAll(io.LimitReader(resp.Body, maxImageSize+1)) -// if err != nil { -// return nil, fmt.Errorf("failed to read image data: %w", err) -// } -// if len(encryptedData) > maxImageSize { -// return nil, fmt.Errorf("image too large (exceeds %d MB)", maxImageSize>>20) -// } - -// logger.DebugCF("wecom_aibot", "Image downloaded", map[string]any{ -// "size": len(encryptedData), -// }) - -// // Decode AES key -// aesKey, err := decodeWeComAESKey(c.config.EncodingAESKey) -// if err != nil 
{ -// return nil, err -// } - -// // Decrypt image (AES-CBC with IV = first 16 bytes of key, PKCS7 padding stripped) -// decryptedData, err := decryptAESCBC(aesKey, encryptedData) -// if err != nil { -// return nil, fmt.Errorf("failed to decrypt image: %w", err) -// } - -// logger.DebugCF("wecom_aibot", "Image decrypted", map[string]any{ -// "size": len(decryptedData), -// }) - -// return decryptedData, nil -// } - -// generateRandomID generates a cryptographically random alphanumeric ID of -// length n. Used for stream IDs and WebSocket request IDs. -func generateRandomID(n int) string { - const letters = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789" - b := make([]byte, n) - for i := range b { - num, _ := rand.Int(rand.Reader, big.NewInt(int64(len(letters)))) - b[i] = letters[num.Int64()] - } - return string(b) -} - -// generateStreamID generates a random 10-character stream ID (webhook mode). -func (c *WeComAIBotChannel) generateStreamID() string { - return generateRandomID(10) -} - -// cleanupLoop periodically cleans up old streaming tasks -func (c *WeComAIBotChannel) cleanupLoop() { - ticker := time.NewTicker(5 * time.Minute) - defer ticker.Stop() - - for { - select { - case <-ticker.C: - c.cleanupOldTasks() - case <-c.ctx.Done(): - return - } - } -} - -// cleanupOldTasks removes tasks that have exceeded their expected lifetime: -// - Active tasks (in streamTasks): cleaned up after 1 hour (response_url validity window). -// - StreamClosed tasks (in chatTasks only): cleaned up after streamClosedGracePeriod. -// These tasks are waiting for the agent to call Send() via response_url. If the agent -// crashes or times out without calling Send(), we must not let them accumulate indefinitely. -// The grace period is generous enough to cover typical LLM latency but far shorter than 1 hour, -// preventing chatTasks from filling up when many requests time out in quick succession. 
-const ( - streamClosedGracePeriod = 10 * time.Minute // max wait for agent after stream closes - taskMaxLifetime = 1 * time.Hour // absolute max (≈ response_url validity) -) - -func (c *WeComAIBotChannel) cleanupOldTasks() { - c.taskMu.Lock() - defer c.taskMu.Unlock() - - now := time.Now() - cutoff := now.Add(-taskMaxLifetime) - for id, task := range c.streamTasks { - if task.CreatedTime.Before(cutoff) { - delete(c.streamTasks, id) - task.cancel() // interrupt agent goroutine still waiting for LLM - queue := c.chatTasks[task.ChatID] - for i, t := range queue { - if t == task { - c.chatTasks[task.ChatID] = append(queue[:i], queue[i+1:]...) - break - } - } - if len(c.chatTasks[task.ChatID]) == 0 { - delete(c.chatTasks, task.ChatID) - } - logger.DebugCF("wecom_aibot", "Cleaned up expired task", map[string]any{ - "stream_id": id, - }) - } - } - // Clean up StreamClosed tasks from chatTasks. - // Two expiry conditions are checked: - // 1. Absolute expiry: task was created more than taskMaxLifetime ago. - // 2. Grace expiry: stream closed more than streamClosedGracePeriod ago - // (agent had enough time to reply; it is not coming back). - for chatID, queue := range c.chatTasks { - filtered := queue[:0] - for i, t := range queue { - absoluteExpired := t.CreatedTime.Before(cutoff) - graceExpired := t.StreamClosed && - !t.StreamClosedAt.IsZero() && - t.StreamClosedAt.Before(now.Add(-streamClosedGracePeriod)) - if t.Finished { - // Finished tasks should have been removed by removeTask(). - // Finding one here (especially not at position 0) means an - // unexpected code path left it stranded, causing the queue to - // grow silently. Log a warning so it is visible, then drop it. 
- if i > 0 { - logger.WarnCF("wecom_aibot", - "Found stranded Finished task in the middle of chatTasks queue; "+ - "this should not happen — removeTask() should have spliced it out", - map[string]any{ - "chat_id": chatID, - "stream_id": t.StreamID, - "position": i, - }) - } - // The task is already finished; its context was already canceled - // by removeTask(), so no further action is required. - continue - } else if !absoluteExpired && !graceExpired { - filtered = append(filtered, t) - } else { - t.cancel() // cancel any lingering agent goroutine - } - } - if len(filtered) == 0 { - delete(c.chatTasks, chatID) - } else { - c.chatTasks[chatID] = filtered - } - } -} - -// handleHealth handles health check requests -func (c *WeComAIBotChannel) handleHealth(w http.ResponseWriter, r *http.Request) { - status := "ok" - if !c.IsRunning() { - status = "not running" - } - - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(http.StatusOK) - json.NewEncoder(w).Encode(map[string]string{ - "status": status, - }) -} diff --git a/pkg/channels/wecom/aibot_test.go b/pkg/channels/wecom/aibot_test.go deleted file mode 100644 index 11c4393d6..000000000 --- a/pkg/channels/wecom/aibot_test.go +++ /dev/null @@ -1,559 +0,0 @@ -package wecom - -import ( - "context" - "encoding/json" - "testing" - "time" - - "github.com/sipeed/picoclaw/pkg/bus" - "github.com/sipeed/picoclaw/pkg/channels" - "github.com/sipeed/picoclaw/pkg/config" -) - -// ---- Webhook mode tests ---- - -func TestNewWeComAIBotChannel_WebhookMode(t *testing.T) { - t.Run("success with valid config", func(t *testing.T) { - cfg := config.WeComAIBotConfig{} - cfg.Enabled = true - cfg.SetToken("test_token") - cfg.SetEncodingAESKey("testkey1234567890123456789012345678901234567") - cfg.WebhookPath = "/webhook/test" - - messageBus := bus.NewMessageBus() - ch, err := NewWeComAIBotChannel(cfg, messageBus) - if err != nil { - t.Fatalf("Expected no error, got %v", err) - } - if ch == nil { - t.Fatal("Expected channel to 
be created") - } - if ch.Name() != "wecom_aibot" { - t.Errorf("Expected name 'wecom_aibot', got '%s'", ch.Name()) - } - // Webhook mode must implement WebhookHandler. - if _, ok := ch.(channels.WebhookHandler); !ok { - t.Error("Webhook mode channel should implement WebhookHandler") - } - }) - - t.Run("error with missing token", func(t *testing.T) { - cfg := config.WeComAIBotConfig{} - cfg.Enabled = true - cfg.SetEncodingAESKey("testkey1234567890123456789012345678901234567") - - messageBus := bus.NewMessageBus() - _, err := NewWeComAIBotChannel(cfg, messageBus) - if err == nil { - t.Fatal("Expected error for missing token, got nil") - } - }) - - t.Run("error with missing encoding key", func(t *testing.T) { - cfg := config.WeComAIBotConfig{} - cfg.Enabled = true - cfg.SetToken("test_token") - - messageBus := bus.NewMessageBus() - _, err := NewWeComAIBotChannel(cfg, messageBus) - if err == nil { - t.Fatal("Expected error for missing encoding key, got nil") - } - }) -} - -func TestWeComAIBotWebhookChannelStartStop(t *testing.T) { - cfg := config.WeComAIBotConfig{ - Enabled: true, - } - cfg.SetToken("test_token") - cfg.SetEncodingAESKey("testkey1234567890123456789012345678901234567") - - messageBus := bus.NewMessageBus() - ch, err := NewWeComAIBotChannel(cfg, messageBus) - if err != nil { - t.Fatalf("Failed to create channel: %v", err) - } - - ctx := context.Background() - - if err := ch.Start(ctx); err != nil { - t.Fatalf("Failed to start channel: %v", err) - } - if !ch.IsRunning() { - t.Error("Expected channel to be running after Start") - } - - if err := ch.Stop(ctx); err != nil { - t.Fatalf("Failed to stop channel: %v", err) - } - if ch.IsRunning() { - t.Error("Expected channel to be stopped after Stop") - } -} - -func TestWeComAIBotChannelWebhookPath(t *testing.T) { - t.Run("default path", func(t *testing.T) { - cfg := config.WeComAIBotConfig{} - cfg.Enabled = true - cfg.SetToken("test_token") - cfg.SetEncodingAESKey("testkey1234567890123456789012345678901234567") 
- - messageBus := bus.NewMessageBus() - ch, _ := NewWeComAIBotChannel(cfg, messageBus) - - wh, ok := ch.(channels.WebhookHandler) - if !ok { - t.Fatal("Expected channel to implement WebhookHandler") - } - expectedPath := "/webhook/wecom-aibot" - if wh.WebhookPath() != expectedPath { - t.Errorf("Expected webhook path '%s', got '%s'", expectedPath, wh.WebhookPath()) - } - }) - - t.Run("custom path", func(t *testing.T) { - customPath := "/custom/webhook" - cfg := config.WeComAIBotConfig{} - cfg.Enabled = true - cfg.SetToken("test_token") - cfg.SetEncodingAESKey("testkey1234567890123456789012345678901234567") - cfg.WebhookPath = customPath - - messageBus := bus.NewMessageBus() - ch, _ := NewWeComAIBotChannel(cfg, messageBus) - - wh, ok := ch.(channels.WebhookHandler) - if !ok { - t.Fatal("Expected channel to implement WebhookHandler") - } - if wh.WebhookPath() != customPath { - t.Errorf("Expected webhook path '%s', got '%s'", customPath, wh.WebhookPath()) - } - }) -} - -func TestWeComAIBotChannelGetStreamResponseProcessingMessage(t *testing.T) { - validAESKey := "abcdefghijklmnopqrstuvwxyz0123456789ABCDEFG" - - t.Run("uses default processing message", func(t *testing.T) { - cfg := config.WeComAIBotConfig{ - Enabled: true, - } - cfg.SetToken("test_token") - cfg.SetEncodingAESKey(validAESKey) - - messageBus := bus.NewMessageBus() - channel, err := NewWeComAIBotChannel(cfg, messageBus) - if err != nil { - t.Fatalf("Failed to create channel: %v", err) - } - ch, ok := channel.(*WeComAIBotChannel) - if !ok { - t.Fatal("Expected webhook mode channel") - } - - task := &streamTask{ - StreamID: "stream-default", - ChatID: "chat-default", - Deadline: time.Now().Add(-time.Second), - } - ch.streamTasks[task.StreamID] = task - ch.chatTasks[task.ChatID] = []*streamTask{task} - - resp := decodeStreamResponse(t, ch, ch.getStreamResponse(task, "1234567890", "nonce")) - - if !resp.Stream.Finish { - t.Fatal("Expected finished stream response after deadline") - } - if resp.Stream.Content 
!= config.DefaultWeComAIBotProcessingMessage { - t.Fatalf("Expected default processing message %q, got %q", - config.DefaultWeComAIBotProcessingMessage, resp.Stream.Content) - } - if !task.StreamClosed { - t.Fatal("Expected task stream to be marked closed") - } - if _, ok := ch.streamTasks[task.StreamID]; ok { - t.Fatal("Expected closed stream task to be removed from streamTasks") - } - if len(ch.chatTasks[task.ChatID]) != 1 { - t.Fatalf("Expected task to remain queued for response_url delivery, got %d entries", - len(ch.chatTasks[task.ChatID])) - } - }) - - t.Run("uses custom processing message", func(t *testing.T) { - cfg := config.WeComAIBotConfig{ - Enabled: true, - ProcessingMessage: "Please wait a moment. The result will be delivered in a follow-up message.", - } - cfg.SetToken("test_token") - cfg.SetEncodingAESKey(validAESKey) - - messageBus := bus.NewMessageBus() - channel, err := NewWeComAIBotChannel(cfg, messageBus) - if err != nil { - t.Fatalf("Failed to create channel: %v", err) - } - ch, ok := channel.(*WeComAIBotChannel) - if !ok { - t.Fatal("Expected webhook mode channel") - } - - task := &streamTask{ - StreamID: "stream-custom", - ChatID: "chat-custom", - Deadline: time.Now().Add(-time.Second), - } - - resp := decodeStreamResponse(t, ch, ch.getStreamResponse(task, "1234567890", "nonce")) - - if resp.Stream.Content != cfg.ProcessingMessage { - t.Fatalf("Expected custom processing message %q, got %q", cfg.ProcessingMessage, resp.Stream.Content) - } - }) -} - -func TestGenerateStreamID(t *testing.T) { - cfg := config.WeComAIBotConfig{} - cfg.Enabled = true - cfg.SetToken("test_token") - cfg.SetEncodingAESKey("testkey1234567890123456789012345678901234567") - - messageBus := bus.NewMessageBus() - ch, _ := NewWeComAIBotChannel(cfg, messageBus) - webhookCh, ok := ch.(*WeComAIBotChannel) - if !ok { - t.Fatal("Expected webhook mode channel") - } - - ids := make(map[string]bool) - for i := 0; i < 100; i++ { - id := webhookCh.generateStreamID() - if len(id) != 
10 { - t.Errorf("Expected stream ID length 10, got %d", len(id)) - } - if ids[id] { - t.Errorf("Duplicate stream ID generated: %s", id) - } - ids[id] = true - } -} - -func TestEncryptDecrypt(t *testing.T) { - // Use a valid 43-character base64 key (企业微信标准格式) - cfg := config.WeComAIBotConfig{} - cfg.Enabled = true - cfg.SetToken("test_token") - cfg.SetEncodingAESKey("abcdefghijklmnopqrstuvwxyz0123456789ABCDEFG") // 43 characters - - messageBus := bus.NewMessageBus() - ch, _ := NewWeComAIBotChannel(cfg, messageBus) - webhookCh, ok := ch.(*WeComAIBotChannel) - if !ok { - t.Fatal("Expected webhook mode channel") - } - - plaintext := "Hello, World!" - receiveid := "" - - encrypted, err := webhookCh.encryptMessage(plaintext, receiveid) - if err != nil { - t.Fatalf("Failed to encrypt message: %v", err) - } - if encrypted == "" { - t.Fatal("Encrypted message is empty") - } - - // Decrypt - decrypted, err := decryptMessageWithVerify(encrypted, cfg.EncodingAESKey(), receiveid) - if err != nil { - t.Fatalf("Failed to decrypt message: %v", err) - } - if decrypted != plaintext { - t.Errorf("Expected decrypted message '%s', got '%s'", plaintext, decrypted) - } -} - -func TestGenerateSignature(t *testing.T) { - token := "test_token" - timestamp := "1234567890" - nonce := "test_nonce" - encrypt := "encrypted_msg" - - signature := computeSignature(token, timestamp, nonce, encrypt) - if signature == "" { - t.Error("Generated signature is empty") - } - if !verifySignature(token, signature, timestamp, nonce, encrypt) { - t.Error("Generated signature does not verify correctly") - } -} - -func decodeStreamResponse(t *testing.T, ch *WeComAIBotChannel, encryptedResponse string) WeComAIBotStreamResponse { - t.Helper() - - var wrapped WeComAIBotEncryptedResponse - if err := json.Unmarshal([]byte(encryptedResponse), &wrapped); err != nil { - t.Fatalf("Failed to unmarshal encrypted response: %v", err) - } - - plaintext, err := decryptMessageWithVerify(wrapped.Encrypt, 
ch.config.EncodingAESKey(), "") - if err != nil { - t.Fatalf("Failed to decrypt response: %v", err) - } - - var resp WeComAIBotStreamResponse - if err := json.Unmarshal([]byte(plaintext), &resp); err != nil { - t.Fatalf("Failed to unmarshal decrypted response: %v", err) - } - - return resp -} - -// ---- WebSocket long-connection mode tests ---- - -func TestNewWeComAIBotChannel_WSMode(t *testing.T) { - t.Run("success with bot_id and secret", func(t *testing.T) { - cfg := config.WeComAIBotConfig{ - Enabled: true, - BotID: "test_bot_id", - } - cfg.SetSecret("test_secret") - messageBus := bus.NewMessageBus() - ch, err := NewWeComAIBotChannel(cfg, messageBus) - if err != nil { - t.Fatalf("Expected no error, got %v", err) - } - if ch == nil { - t.Fatal("Expected channel to be created") - } - if ch.Name() != "wecom_aibot" { - t.Errorf("Expected name 'wecom_aibot', got '%s'", ch.Name()) - } - // WebSocket mode must NOT implement WebhookHandler. - if _, ok := ch.(channels.WebhookHandler); ok { - t.Error("WebSocket mode channel should NOT implement WebhookHandler") - } - }) - - t.Run("ws mode takes priority over webhook fields", func(t *testing.T) { - cfg := config.WeComAIBotConfig{ - Enabled: true, - BotID: "test_bot_id", - } - cfg.SetSecret("test_secret") - cfg.SetToken("also_set") - cfg.SetEncodingAESKey("testkey1234567890123456789012345678901234567") - messageBus := bus.NewMessageBus() - ch, err := NewWeComAIBotChannel(cfg, messageBus) - if err != nil { - t.Fatalf("Expected no error, got %v", err) - } - if _, ok := ch.(*WeComAIBotWSChannel); !ok { - t.Error("Expected WebSocket mode channel when both BotID+secret and Token+Key are set") - } - }) - - t.Run("error with missing bot_id", func(t *testing.T) { - cfg := config.WeComAIBotConfig{ - Enabled: true, - } - cfg.SetSecret("test_secret") - messageBus := bus.NewMessageBus() - _, err := NewWeComAIBotChannel(cfg, messageBus) - // Missing bot_id alone means neither WS mode nor webhook mode is fully configured. 
- if err == nil { - t.Fatal("Expected error for missing bot_id, got nil") - } - }) - - t.Run("error with missing secret", func(t *testing.T) { - cfg := config.WeComAIBotConfig{ - Enabled: true, - BotID: "test_bot_id", - } - messageBus := bus.NewMessageBus() - _, err := NewWeComAIBotChannel(cfg, messageBus) - if err == nil { - t.Fatal("Expected error for missing secret, got nil") - } - }) -} - -func TestWeComAIBotWSChannelStartStop(t *testing.T) { - cfg := config.WeComAIBotConfig{ - Enabled: true, - BotID: "test_bot_id", - } - cfg.SetSecret("test_secret") - messageBus := bus.NewMessageBus() - ch, err := NewWeComAIBotChannel(cfg, messageBus) - if err != nil { - t.Fatalf("Failed to create channel: %v", err) - } - - ctx := context.Background() - - // Start launches a background goroutine; it should not block or return an error. - if err := ch.Start(ctx); err != nil { - t.Fatalf("Failed to start channel: %v", err) - } - if !ch.IsRunning() { - t.Error("Expected channel to be running after Start") - } - - // Stop should work regardless of whether the WebSocket actually connected. - if err := ch.Stop(ctx); err != nil { - t.Fatalf("Failed to stop channel: %v", err) - } - if ch.IsRunning() { - t.Error("Expected channel to be stopped after Stop") - } -} - -func TestGenerateRandomID(t *testing.T) { - ids := make(map[string]bool) - for i := 0; i < 200; i++ { - id := generateRandomID(10) - if len(id) != 10 { - t.Errorf("Expected ID length 10, got %d", len(id)) - } - if ids[id] { - t.Errorf("Duplicate ID generated: %s", id) - } - ids[id] = true - } -} - -func TestWSGenerateID(t *testing.T) { - ids := make(map[string]bool) - for i := 0; i < 200; i++ { - id := wsGenerateID() - if len(id) != 10 { - t.Errorf("Expected ID length 10, got %d", len(id)) - } - if ids[id] { - t.Errorf("Duplicate wsGenerateID result: %s", id) - } - ids[id] = true - } -} - -// ---- Webhook streaming fallback tests ---- - -// makeWebhookChannel creates a started WeComAIBotChannel for testing. 
-func makeWebhookChannel(t *testing.T) *WeComAIBotChannel { - t.Helper() - cfg := config.WeComAIBotConfig{ - Enabled: true, - } - cfg.SetToken("test_token") - cfg.SetEncodingAESKey("abcdefghijklmnopqrstuvwxyz0123456789ABCDEFG") - ch, err := NewWeComAIBotChannel(cfg, bus.NewMessageBus()) - if err != nil { - t.Fatalf("create channel: %v", err) - } - wc := ch.(*WeComAIBotChannel) - wc.ctx, wc.cancel = context.WithCancel(context.Background()) - return wc -} - -// makeStreamTask creates and registers a streamTask for testing. -func makeStreamTask(t *testing.T, ch *WeComAIBotChannel, streamID, chatID string, deadline time.Time) *streamTask { - t.Helper() - task := &streamTask{ - StreamID: streamID, - ChatID: chatID, - Deadline: deadline, - answerCh: make(chan string, 1), - } - task.ctx, task.cancel = context.WithCancel(ch.ctx) - ch.taskMu.Lock() - ch.streamTasks[streamID] = task - ch.chatTasks[chatID] = append(ch.chatTasks[chatID], task) - ch.taskMu.Unlock() - return task -} - -// TestGetStreamResponse_ImmediateAnswer verifies that when the agent has already -// placed its answer in answerCh, getStreamResponse returns a finish=true response -// and fully removes the task. 
-func TestGetStreamResponse_ImmediateAnswer(t *testing.T) { - ch := makeWebhookChannel(t) - defer ch.cancel() - - task := makeStreamTask(t, ch, "stream-1", "chat-1", time.Now().Add(30*time.Second)) - task.answerCh <- "hello from agent" - - result := ch.getStreamResponse(task, "ts123", "nonce123") - if result == "" { - t.Fatal("expected non-empty encrypted response") - } - - ch.taskMu.RLock() - _, exists := ch.streamTasks["stream-1"] - ch.taskMu.RUnlock() - if exists { - t.Error("task should have been removed from streamTasks after normal finish") - } - if !task.Finished { - t.Error("task.Finished should be true after normal finish") - } -} - -// TestGetStreamResponse_DeadlinePassed verifies that when the stream deadline has -// elapsed (no agent reply yet), getStreamResponse closes the stream but keeps the -// task alive so the response_url fallback can still deliver the answer. -func TestGetStreamResponse_DeadlinePassed(t *testing.T) { - ch := makeWebhookChannel(t) - defer ch.cancel() - - task := makeStreamTask(t, ch, "stream-2", "chat-2", time.Now().Add(-time.Millisecond)) - - result := ch.getStreamResponse(task, "ts456", "nonce456") - if result == "" { - t.Fatal("expected non-empty encrypted response") - } - - ch.taskMu.RLock() - _, stillStreaming := ch.streamTasks["stream-2"] - ch.taskMu.RUnlock() - if stillStreaming { - t.Error("task should have been removed from streamTasks after deadline") - } - if !task.StreamClosed { - t.Error("task.StreamClosed should be true after deadline") - } - if task.Finished { - t.Error("task.Finished must remain false: agent reply still expected via response_url") - } -} - -// TestGetStreamResponse_StillPending verifies that when neither the agent has -// replied nor the deadline has passed, getStreamResponse returns without altering -// task state (client should poll again). 
-func TestGetStreamResponse_StillPending(t *testing.T) { - ch := makeWebhookChannel(t) - defer ch.cancel() - - task := makeStreamTask(t, ch, "stream-3", "chat-3", time.Now().Add(30*time.Second)) - - result := ch.getStreamResponse(task, "ts789", "nonce789") - if result == "" { - t.Fatal("expected non-empty encrypted response") - } - - ch.taskMu.RLock() - _, exists := ch.streamTasks["stream-3"] - ch.taskMu.RUnlock() - if !exists { - t.Error("pending task should still be in streamTasks") - } - if task.Finished || task.StreamClosed { - t.Error("pending task should not be finished or stream-closed") - } - // Cleanup. - ch.removeTask(task) -} diff --git a/pkg/channels/wecom/aibot_ws.go b/pkg/channels/wecom/aibot_ws.go deleted file mode 100644 index 53dd7071f..000000000 --- a/pkg/channels/wecom/aibot_ws.go +++ /dev/null @@ -1,1347 +0,0 @@ -package wecom - -import ( - "context" - "encoding/base64" - "encoding/json" - "fmt" - "io" - "net/http" - "os" - "path/filepath" - "strings" - "sync" - "time" - - "github.com/gorilla/websocket" - - "github.com/sipeed/picoclaw/pkg/bus" - "github.com/sipeed/picoclaw/pkg/channels" - "github.com/sipeed/picoclaw/pkg/config" - "github.com/sipeed/picoclaw/pkg/identity" - "github.com/sipeed/picoclaw/pkg/logger" - "github.com/sipeed/picoclaw/pkg/media" - "github.com/sipeed/picoclaw/pkg/utils" -) - -// Long-connection WebSocket endpoint. -// Ref: https://developer.work.weixin.qq.com/document/path/101463 -const ( - wsEndpoint = "wss://openws.work.weixin.qq.com" - wsHeartbeatInterval = 30 * time.Second - wsConnectTimeout = 15 * time.Second - wsSubscribeTimeout = 10 * time.Second - wsSendMsgTimeout = 10 * time.Second - wsRespondMsgTimeout = 10 * time.Second - wsWelcomeMsgTimeout = 5 * time.Second // WeCom requires welcome reply within 5 seconds - wsMaxReconnectWait = 60 * time.Second - wsInitialReconnect = time.Second - - // WeCom requires finish=true within 6 minutes of the first stream frame. 
- // wsStreamTickInterval controls how often we send an in-progress hint. - // wsStreamMaxDuration is a safety margin below the 6-minute hard limit. - wsStreamTickInterval = 30 * time.Second - wsStreamMaxDuration = 5*time.Minute + 30*time.Second - - // wsImageDownloadTimeout caps the time we spend downloading an inbound image. - wsImageDownloadTimeout = 30 * time.Second - - // Keep req_id -> chat route for late fallback pushes after stream window closes. - wsLateReplyRouteTTL = 30 * time.Minute - - // wsStreamMaxContentBytes is the maximum UTF-8 byte length for the content field - // of a single WeCom AI Bot stream / text / markdown frame. - // Ref: https://developer.work.weixin.qq.com/document/path/101463 - wsStreamMaxContentBytes = 20480 -) - -// wsImageHTTPClient is a shared HTTP client for downloading inbound images. -// Reusing it enables connection pooling across multiple image downloads. -var wsImageHTTPClient = &http.Client{Timeout: wsImageDownloadTimeout} - -// WeComAIBotWSChannel implements channels.Channel for WeCom AI Bot using the -// WebSocket long-connection API. -// Unlike the webhook counterpart it does NOT implement WebhookHandler, so the -// HTTP manager will not register any callback URL for it. -type WeComAIBotWSChannel struct { - *channels.BaseChannel - config config.WeComAIBotConfig - ctx context.Context - cancel context.CancelFunc - - // conn is the active WebSocket connection; nil when disconnected. - // All writes are serialized through connMu. - conn *websocket.Conn - connMu sync.Mutex - - // dedupe prevents duplicate message processing (WeCom may re-deliver). - dedupe *MessageDeduplicator - - // reqStates holds per-req_id runtime state. - // It unifies active task state and late-reply fallback routing. - reqStates map[string]*wsReqState - reqStatesMu sync.Mutex - - // reqPending correlates command req_ids with response channels. - // Used only for subscribe/ping command-response pairs. 
- reqPending map[string]chan wsEnvelope - reqPendingMu sync.Mutex -} - -// wsTask tracks one in-progress agent reply for a single chat turn. -type wsTask struct { - ReqID string // req_id echoed in all replies for this turn - ChatID string - ChatType uint32 - StreamID string // our generated stream.id - answerCh chan string // agent delivers its reply here via Send() - ctx context.Context - cancel context.CancelFunc -} - -type wsReqState struct { - Task *wsTask - Route wsLateReplyRoute -} - -type wsLateReplyRoute struct { - ChatID string - ChatType uint32 - ReadyAt time.Time - ExpiresAt time.Time -} - -// ---- WebSocket protocol types ---- - -// wsEnvelope is the generic JSON envelope for all WebSocket messages. -type wsEnvelope struct { - Cmd string `json:"cmd,omitempty"` - Headers wsHeaders `json:"headers"` - Body json.RawMessage `json:"body,omitempty"` - ErrCode int `json:"errcode,omitempty"` - ErrMsg string `json:"errmsg,omitempty"` -} - -type wsHeaders struct { - ReqID string `json:"req_id"` -} - -// wsCommand is an outgoing request sent over the WebSocket. -type wsCommand struct { - Cmd string `json:"cmd"` - Headers wsHeaders `json:"headers"` - Body any `json:"body,omitempty"` -} - -type wsSendMsgBody struct { - ChatID string `json:"chatid"` - ChatType uint32 `json:"chat_type,omitempty"` - MsgType string `json:"msgtype"` - Markdown *wsMarkdownContent `json:"markdown,omitempty"` -} - -// wsRespondMsgBody is the body for aibot_respond_msg / aibot_respond_welcome_msg. -type wsRespondMsgBody struct { - MsgType string `json:"msgtype"` - Stream *wsStreamContent `json:"stream,omitempty"` - Text *wsTextContent `json:"text,omitempty"` - Markdown *wsMarkdownContent `json:"markdown,omitempty"` - Image *wsImageContent `json:"image,omitempty"` -} - -type wsStreamContent struct { - ID string `json:"id"` - Finish bool `json:"finish"` - Content string `json:"content,omitempty"` -} - -// wsImageContent carries a base64-encoded image payload for outbound messages. 
-type wsImageContent struct { - Base64 string `json:"base64"` - MD5 string `json:"md5"` -} - -type wsTextContent struct { - Content string `json:"content"` -} - -type wsMarkdownContent struct { - Content string `json:"content"` -} - -// WeComAIBotWSMessage is the decoded body of aibot_msg_callback / -// aibot_event_callback in WebSocket long-connection mode. -// The structure mirrors WeComAIBotMessage but includes extra fields -// that only appear in long-connection callbacks (Voice, AESKey on Image/File). -type WeComAIBotWSMessage struct { - MsgID string `json:"msgid"` - CreateTime int64 `json:"create_time,omitempty"` - AIBotID string `json:"aibotid"` - ChatID string `json:"chatid,omitempty"` - ChatType string `json:"chattype,omitempty"` // "single" | "group" - From struct { - UserID string `json:"userid"` - } `json:"from"` - MsgType string `json:"msgtype"` - Text *struct { - Content string `json:"content"` - } `json:"text,omitempty"` - Image *struct { - URL string `json:"url"` - AESKey string `json:"aeskey,omitempty"` // long-connection: per-resource decrypt key - } `json:"image,omitempty"` - Voice *struct { - Content string `json:"content"` // WeCom transcribes voice to text in callbacks - } `json:"voice,omitempty"` - Mixed *struct { - MsgItem []struct { - MsgType string `json:"msgtype"` - Text *struct { - Content string `json:"content"` - } `json:"text,omitempty"` - Image *struct { - URL string `json:"url"` - AESKey string `json:"aeskey,omitempty"` - } `json:"image,omitempty"` - } `json:"msg_item"` - } `json:"mixed,omitempty"` - Event *struct { - EventType string `json:"eventtype"` - } `json:"event,omitempty"` - File *struct { - URL string `json:"url"` - AESKey string `json:"aeskey,omitempty"` - } `json:"file,omitempty"` - Video *struct { - URL string `json:"url"` - AESKey string `json:"aeskey,omitempty"` - } `json:"video,omitempty"` -} - -// ---- Constructor ---- - -// newWeComAIBotWSChannel creates a WeComAIBotWSChannel for WebSocket mode. 
-func newWeComAIBotWSChannel( - cfg config.WeComAIBotConfig, - messageBus *bus.MessageBus, -) (*WeComAIBotWSChannel, error) { - if cfg.BotID == "" || cfg.Secret() == "" { - return nil, fmt.Errorf("bot_id and secret are required for WeCom AI Bot WebSocket mode") - } - - base := channels.NewBaseChannel("wecom_aibot", cfg, messageBus, cfg.AllowFrom, - channels.WithReasoningChannelID(cfg.ReasoningChannelID), - ) - - return &WeComAIBotWSChannel{ - BaseChannel: base, - config: cfg, - dedupe: NewMessageDeduplicator(wecomMaxProcessedMessages), - reqStates: make(map[string]*wsReqState), - reqPending: make(map[string]chan wsEnvelope), - }, nil -} - -// ---- Channel interface ---- - -// Name implements channels.Channel. -func (c *WeComAIBotWSChannel) Name() string { return "wecom_aibot" } - -// Start connects to the WeCom WebSocket endpoint and begins message processing. -func (c *WeComAIBotWSChannel) Start(ctx context.Context) error { - logger.InfoC("wecom_aibot", "Starting WeCom AI Bot channel (WebSocket long-connection mode)...") - c.ctx, c.cancel = context.WithCancel(ctx) - c.SetRunning(true) - go c.connectLoop() - logger.InfoC("wecom_aibot", "WeCom AI Bot channel started (WebSocket mode)") - return nil -} - -// Stop shuts down the channel and closes the WebSocket connection. -func (c *WeComAIBotWSChannel) Stop(_ context.Context) error { - logger.InfoC("wecom_aibot", "Stopping WeCom AI Bot channel (WebSocket mode)...") - if c.cancel != nil { - c.cancel() - } - c.connMu.Lock() - if c.conn != nil { - c.conn.Close() - c.conn = nil - } - c.connMu.Unlock() - c.SetRunning(false) - logger.InfoC("wecom_aibot", "WeCom AI Bot channel stopped") - return nil -} - -// Send delivers the agent reply for msg.ChatID. -// The waiting task goroutine picks it up and writes the final stream response. 
-func (c *WeComAIBotWSChannel) Send(ctx context.Context, msg bus.OutboundMessage) error { - if !c.IsRunning() { - return channels.ErrNotRunning - } - - // msg.ChatID carries the inbound req_id (set by dispatchWSAgentTask). - // For cron-triggered messages, msg.ChatID is the real WeCom chat/user ID - // and there will be no matching entry in reqStates; fall through to proactive push. - task, route, ok := c.getReqState(msg.ChatID) - if !ok { - // No req_id record found — this is a cron/scheduler-originated message. - // Send it as a proactive markdown push using the chat ID directly. - logger.InfoCF("wecom_aibot", "Send: no req_id state, delivering via proactive push (cron/scheduler)", - map[string]any{"chat_id": msg.ChatID}) - if err := c.wsSendActivePush(msg.ChatID, 0, msg.Content); err != nil { - logger.WarnCF("wecom_aibot", "Proactive push failed", - map[string]any{"chat_id": msg.ChatID, "error": err.Error()}) - return fmt.Errorf("websocket delivery failed: %w", channels.ErrSendFailed) - } - return nil - } - - if task == nil { - if time.Now().Before(route.ReadyAt) { - // Keep using aibot_respond_msg within stream window; do not proactively - // push unless wsStreamMaxDuration has elapsed. 
- logger.WarnCF("wecom_aibot", "Send: stream window still open, skip proactive push", - map[string]any{"req_id": msg.ChatID, "ready_at": route.ReadyAt.Format(time.RFC3339)}) - return nil - } - - if err := c.wsSendActivePush(route.ChatID, route.ChatType, msg.Content); err != nil { - logger.WarnCF("wecom_aibot", "Late reply proactive push failed", - map[string]any{"req_id": msg.ChatID, "chat_id": route.ChatID, "error": err.Error()}) - return fmt.Errorf("websocket delivery failed: %w", channels.ErrSendFailed) - } - logger.InfoCF("wecom_aibot", "Late reply delivered via proactive push", - map[string]any{"req_id": msg.ChatID, "chat_id": route.ChatID, "chat_type": route.ChatType}) - c.deleteReqState(msg.ChatID) - return nil - } - - // Non-blocking fast path: when answerCh has space, deliver without racing - // against task.ctx.Done() (which fires when the task is canceled by a new - // incoming message, but the response must still be sent). - select { - case task.answerCh <- msg.Content: - return nil - default: - } - // answerCh was full; block with cancellation guards. - select { - case task.answerCh <- msg.Content: - case <-task.ctx.Done(): - return nil - case <-ctx.Done(): - return ctx.Err() - } - return nil -} - -// ---- Connection management ---- - -// wsBackoffResetDuration is the minimum duration a WebSocket connection must -// stay up before we reset the reconnect backoff to its initial value. This -// prevents a short burst of failures from causing long waits after later, -// stable connection periods. -const wsBackoffResetDuration = time.Minute - -// connectLoop maintains the WebSocket connection, reconnecting on failure with -// exponential backoff. 
-func (c *WeComAIBotWSChannel) connectLoop() { - backoff := wsInitialReconnect - for { - select { - case <-c.ctx.Done(): - return - default: - } - - logger.InfoC("wecom_aibot", "Connecting to WeCom WebSocket endpoint...") - start := time.Now() - if err := c.runConnection(); err != nil { - elapsed := time.Since(start) - // If the connection was stable for long enough, reset backoff so that - // a previous burst of failures does not keep us at the maximum delay. - if elapsed >= wsBackoffResetDuration { - backoff = wsInitialReconnect - } - select { - case <-c.ctx.Done(): - return - default: - logger.WarnCF("wecom_aibot", "WebSocket connection lost, reconnecting", - map[string]any{"error": err.Error(), "backoff": backoff.String()}) - select { - case <-time.After(backoff): - case <-c.ctx.Done(): - return - } - if backoff < wsMaxReconnectWait { - backoff *= 2 - if backoff > wsMaxReconnectWait { - backoff = wsMaxReconnectWait - } - } - } - } else { - // Clean exit (context canceled); stop reconnecting. - return - } - } -} - -// runConnection dials, subscribes, and runs the read/heartbeat loops until the -// connection closes or the channel context is canceled. -func (c *WeComAIBotWSChannel) runConnection() error { - dialCtx, dialCancel := context.WithTimeout(c.ctx, wsConnectTimeout) - conn, httpResp, err := websocket.DefaultDialer.DialContext(dialCtx, wsEndpoint, nil) - dialCancel() - if httpResp != nil { - httpResp.Body.Close() - } - if err != nil { - return fmt.Errorf("dial failed: %w", err) - } - - c.connMu.Lock() - c.conn = conn - c.connMu.Unlock() - - defer func() { - c.connMu.Lock() - if c.conn == conn { - c.conn = nil - } - c.connMu.Unlock() - // Cancel any tasks that were started over this connection so their - // agent goroutines do not keep running after the connection is gone. 
- c.cancelAllTasks() - }() - - // ---- Read loop (must start BEFORE subscribing) ---- - // sendAndWait blocks waiting for the subscribe response on reqPending; - // readLoop is the only goroutine that delivers messages to reqPending. - // Starting readLoop first avoids a deadlock where sendAndWait times out - // because no one reads the server's reply. - readErrCh := make(chan error, 1) - go func() { readErrCh <- c.readLoop(conn) }() - - // ---- Subscribe ---- - reqID := wsGenerateID() - resp, err := c.sendAndWait(conn, reqID, wsCommand{ - Cmd: "aibot_subscribe", - Headers: wsHeaders{ReqID: reqID}, - Body: map[string]string{ - "bot_id": c.config.BotID, - "secret": c.config.Secret(), - }, - }, wsSubscribeTimeout) - if err != nil { - conn.Close() // stop readLoop - <-readErrCh - return fmt.Errorf("subscribe failed: %w", err) - } - if resp.ErrCode != 0 { - conn.Close() - <-readErrCh - return fmt.Errorf("subscribe rejected (errcode=%d): %s", resp.ErrCode, resp.ErrMsg) - } - - logger.InfoC("wecom_aibot", "WebSocket subscription successful") - - // ---- Heartbeat goroutine ---- - hbDone := make(chan struct{}) - go func() { - defer close(hbDone) - c.heartbeatLoop(conn) - }() - - // Wait for the read loop to exit, then tear down the heartbeat. - readErr := <-readErrCh - conn.Close() // signal heartbeat to stop (idempotent) - <-hbDone - return readErr -} - -// sendAndWait registers a pending-response slot, sends cmd, and blocks until -// the matching response arrives or the timeout/context fires. 
-func (c *WeComAIBotWSChannel) sendAndWait( - conn *websocket.Conn, - reqID string, - cmd wsCommand, - timeout time.Duration, -) (wsEnvelope, error) { - ch := make(chan wsEnvelope, 1) - c.reqPendingMu.Lock() - c.reqPending[reqID] = ch - c.reqPendingMu.Unlock() - - cleanup := func() { - c.reqPendingMu.Lock() - delete(c.reqPending, reqID) - c.reqPendingMu.Unlock() - } - - data, err := json.Marshal(cmd) - if err != nil { - cleanup() - return wsEnvelope{}, fmt.Errorf("marshal command: %w", err) - } - c.connMu.Lock() - err = conn.WriteMessage(websocket.TextMessage, data) - c.connMu.Unlock() - if err != nil { - cleanup() - return wsEnvelope{}, fmt.Errorf("write command: %w", err) - } - - timer := time.NewTimer(timeout) - defer timer.Stop() - select { - case env := <-ch: - return env, nil - case <-timer.C: - cleanup() - return wsEnvelope{}, fmt.Errorf("timeout waiting for response (req_id=%s)", reqID) - case <-c.ctx.Done(): - cleanup() - return wsEnvelope{}, c.ctx.Err() - } -} - -// heartbeatLoop sends a ping every wsHeartbeatInterval until conn is closed. -// It validates the server's pong response via sendAndWait; a failed pong -// triggers a reconnection by closing the connection. 
-func (c *WeComAIBotWSChannel) heartbeatLoop(conn *websocket.Conn) { - ticker := time.NewTicker(wsHeartbeatInterval) - defer ticker.Stop() - for { - select { - case <-ticker.C: - reqID := wsGenerateID() - resp, err := c.sendAndWait(conn, reqID, wsCommand{ - Cmd: "ping", - Headers: wsHeaders{ReqID: reqID}, - }, wsHeartbeatInterval) - if err != nil { - logger.WarnCF("wecom_aibot", "Heartbeat failed, closing connection", - map[string]any{"error": err.Error()}) - conn.Close() - return - } - if resp.ErrCode != 0 { - logger.WarnCF("wecom_aibot", "Heartbeat rejected", - map[string]any{"errcode": resp.ErrCode, "errmsg": resp.ErrMsg}) - conn.Close() - return - } - logger.DebugCF("wecom_aibot", "Heartbeat pong received", map[string]any{"req_id": reqID}) - case <-c.ctx.Done(): - return - } - } -} - -// readLoop reads WebSocket messages and dispatches them until the connection -// closes or the channel is stopped. -func (c *WeComAIBotWSChannel) readLoop(conn *websocket.Conn) error { - for { - _, raw, err := conn.ReadMessage() - if err != nil { - select { - case <-c.ctx.Done(): - return nil // clean shutdown - default: - return fmt.Errorf("read error: %w", err) - } - } - - var env wsEnvelope - if err := json.Unmarshal(raw, &env); err != nil { - logger.WarnCF("wecom_aibot", "Failed to parse WebSocket message", - map[string]any{"error": err.Error(), "raw": string(raw)}) - continue - } - - // Command responses have an empty Cmd field; forward to any waiting - // sendAndWait() call, or silently drop if no one is waiting (e.g. - // late responses after timeout). - if env.Cmd == "" && env.Headers.ReqID != "" { - c.reqPendingMu.Lock() - ch, ok := c.reqPending[env.Headers.ReqID] - if ok { - delete(c.reqPending, env.Headers.ReqID) - } - c.reqPendingMu.Unlock() - if ok { - ch <- env - } - continue - } - - // Dispatch to appropriate handler in a separate goroutine so the - // read loop is never blocked by a slow agent. 
- go c.handleEnvelope(env) - } -} - -// ---- Message / event handlers ---- - -// handleEnvelope routes a WebSocket envelope to the right handler. -func (c *WeComAIBotWSChannel) handleEnvelope(env wsEnvelope) { - switch env.Cmd { - case "aibot_msg_callback": - c.handleMsgCallback(env) - case "aibot_event_callback": - c.handleEventCallback(env) - default: - logger.DebugCF("wecom_aibot", "Unhandled WebSocket command", - map[string]any{"cmd": env.Cmd}) - } -} - -// handleMsgCallback processes aibot_msg_callback. -func (c *WeComAIBotWSChannel) handleMsgCallback(env wsEnvelope) { - var msg WeComAIBotWSMessage - if err := json.Unmarshal(env.Body, &msg); err != nil { - logger.WarnCF("wecom_aibot", "Failed to parse msg callback body", - map[string]any{"error": err.Error()}) - return - } - - // Deduplicate by msgid (WeCom may re-deliver on network issues). - if msg.MsgID != "" && !c.dedupe.MarkMessageProcessed(msg.MsgID) { - logger.DebugCF("wecom_aibot", "Duplicate message ignored", - map[string]any{"msgid": msg.MsgID}) - return - } - - reqID := env.Headers.ReqID - switch msg.MsgType { - case "text": - c.handleWSTextMessage(reqID, msg) - case "image": - c.handleWSImageMessage(reqID, msg) - case "voice": - c.handleWSVoiceMessage(reqID, msg) - case "mixed": - c.handleWSMixedMessage(reqID, msg) - case "file": - c.handleWSFileMessage(reqID, msg) - case "video": - c.handleWSVideoMessage(reqID, msg) - default: - logger.WarnCF("wecom_aibot", "Unsupported message type", - map[string]any{"msgtype": msg.MsgType}) - c.wsSendStreamFinish(reqID, wsGenerateID(), - "Unsupported message type: "+msg.MsgType) - } -} - -// handleEventCallback processes aibot_event_callback. -func (c *WeComAIBotWSChannel) handleEventCallback(env wsEnvelope) { - var msg WeComAIBotWSMessage - if err := json.Unmarshal(env.Body, &msg); err != nil { - logger.WarnCF("wecom_aibot", "Failed to parse event callback body", - map[string]any{"error": err.Error()}) - return - } - - // Deduplicate by msgid. 
- if msg.MsgID != "" && !c.dedupe.MarkMessageProcessed(msg.MsgID) { - logger.DebugCF("wecom_aibot", "Duplicate event ignored", - map[string]any{"msgid": msg.MsgID}) - return - } - - var eventType string - if msg.Event != nil { - eventType = msg.Event.EventType - } - logger.DebugCF("wecom_aibot", "Received event callback", - map[string]any{"event_type": eventType}) - - switch eventType { - case "enter_chat": - if c.config.WelcomeMessage != "" { - c.wsSendWelcomeMsg(env.Headers.ReqID, c.config.WelcomeMessage) - } - case "disconnected_event": - // The server will close this connection after sending this event. - // connectLoop will detect the closure and reconnect automatically. - logger.WarnC("wecom_aibot", - "Received disconnected_event: this connection is being replaced by a newer one") - default: - logger.DebugCF("wecom_aibot", "Unhandled event type", - map[string]any{"event_type": eventType}) - } -} - -// handleWSTextMessage dispatches a plain-text message to the agent and streams -// the reply back over the WebSocket connection. -func (c *WeComAIBotWSChannel) handleWSTextMessage(reqID string, msg WeComAIBotWSMessage) { - if msg.Text == nil { - logger.ErrorC("wecom_aibot", "text message missing text field") - return - } - c.dispatchWSAgentTask(reqID, msg, msg.Text.Content, nil) -} - -// handleWSImageMessage downloads and stores the inbound image, then dispatches -// it to the agent as a media-tagged message. -func (c *WeComAIBotWSChannel) handleWSImageMessage(reqID string, msg WeComAIBotWSMessage) { - if msg.Image == nil { - logger.WarnC("wecom_aibot", "Image message missing image field") - c.wsSendStreamFinish(reqID, wsGenerateID(), "Image message could not be processed.") - return - } - c.wsHandleMediaMessage(reqID, msg, msg.Image.URL, msg.Image.AESKey, "image") -} - -// wsHandleMediaMessage is a shared helper for image, file and video messages. -// It downloads the resource, stores it in MediaStore, and dispatches to the agent. 
-func (c *WeComAIBotWSChannel) wsHandleMediaMessage( - reqID string, msg WeComAIBotWSMessage, - resourceURL, aesKey, label string, -) { - chatID := wsChatID(msg) - - ctx, cancel := context.WithTimeout(c.ctx, wsImageDownloadTimeout) - defer cancel() - - ref, err := c.storeWSMedia(ctx, chatID, msg.MsgID, resourceURL, aesKey, wsLabelToDefaultExt(label)) - if err != nil { - logger.WarnCF("wecom_aibot", "Failed to download/store WS "+label, - map[string]any{"error": err.Error(), "url": resourceURL}) - c.wsSendStreamFinish(reqID, wsGenerateID(), - strings.ToUpper(label[:1])+label[1:]+" message could not be processed.") - return - } - - c.dispatchWSAgentTask(reqID, msg, "["+label+"]", []string{ref}) -} - -// handleWSMixedMessage handles mixed text+image messages. -// All text parts are collected into the content string; all image parts are -// downloaded and stored in MediaStore before dispatching to the agent. -func (c *WeComAIBotWSChannel) handleWSMixedMessage(reqID string, msg WeComAIBotWSMessage) { - if msg.Mixed == nil { - logger.WarnC("wecom_aibot", "Mixed message has no content") - c.wsSendStreamFinish(reqID, wsGenerateID(), "Mixed message type is not yet fully supported.") - return - } - - chatID := wsChatID(msg) - - ctx, cancel := context.WithTimeout(c.ctx, wsImageDownloadTimeout) - defer cancel() - - var textParts []string - var mediaRefs []string - for _, item := range msg.Mixed.MsgItem { - switch item.MsgType { - case "text": - if item.Text != nil && item.Text.Content != "" { - textParts = append(textParts, item.Text.Content) - } - case "image": - if item.Image != nil { - ref, err := c.storeWSMedia(ctx, chatID, - msg.MsgID+"-"+wsGenerateID(), item.Image.URL, item.Image.AESKey, ".jpg") - if err != nil { - logger.WarnCF("wecom_aibot", "Failed to download/store mixed image", - map[string]any{"error": err.Error()}) - } else { - mediaRefs = append(mediaRefs, ref) - } - } - default: - logger.WarnCF("wecom_aibot", "Unsupported item type in mixed message", - 
map[string]any{"msgtype": item.MsgType}) - } - } - - if len(textParts) == 0 && len(mediaRefs) == 0 { - logger.WarnC("wecom_aibot", "Mixed message has no usable content") - c.wsSendStreamFinish(reqID, wsGenerateID(), "Mixed message type is not yet fully supported.") - return - } - - content := strings.Join(textParts, "\n") - if content == "" { - content = "[images]" - } - c.dispatchWSAgentTask(reqID, msg, content, mediaRefs) -} - -// dispatchWSAgentTask registers a new agent task, sends the opening stream frame, -// and starts a goroutine that runs the agent and streams the reply back. -// content is the text forwarded to the agent; mediaRefs are optional media -// store references attached to the inbound message. -func (c *WeComAIBotWSChannel) dispatchWSAgentTask( - reqID string, - msg WeComAIBotWSMessage, - content string, - mediaRefs []string, -) { - userID := msg.From.UserID - if userID == "" { - userID = "unknown" - } - // actualChatID is the real WeCom chat/user ID used for peer identification. - // reqID is used as the routing chatID so each turn is independently addressable. - actualChatID := wsChatID(msg) - - streamID := wsGenerateID() - chatType := wsChatTypeValue(msg.ChatType) - taskCtx, taskCancel := context.WithCancel(c.ctx) - - task := &wsTask{ - ReqID: reqID, - ChatID: actualChatID, - ChatType: chatType, - StreamID: streamID, - answerCh: make(chan string, 1), - ctx: taskCtx, - cancel: taskCancel, - } - // Each req_id is unique per WeCom turn; tasks run concurrently, no cancellation. - c.setReqState(reqID, &wsReqState{ - Task: task, - Route: wsLateReplyRoute{ - ChatID: actualChatID, - ChatType: chatType, - ReadyAt: time.Now().Add(wsStreamMaxDuration), - ExpiresAt: time.Now().Add(wsLateReplyRouteTTL), - }, - }) - - logger.DebugCF("wecom_aibot", "Registered new agent task", - map[string]any{"chat_id": actualChatID, "req_id": reqID, "stream_id": streamID}) - - // Send an empty stream opening frame (finish=false) immediately. 
- c.wsSendStreamChunk(reqID, streamID, false, "") - - go func() { - defer func() { - taskCancel() - c.clearReqTask(reqID, task) - }() - - sender := bus.SenderInfo{ - Platform: "wecom_aibot", - PlatformID: userID, - CanonicalID: identity.BuildCanonicalID("wecom_aibot", userID), - DisplayName: userID, - } - peerKind := "direct" - if msg.ChatType == "group" { - peerKind = "group" - } - peer := bus.Peer{Kind: peerKind, ID: actualChatID} - metadata := map[string]string{ - "channel": "wecom_aibot", - "chat_id": actualChatID, - "chat_type": msg.ChatType, - "msg_type": msg.MsgType, - "msgid": msg.MsgID, - "aibotid": msg.AIBotID, - "stream_id": streamID, - } - // Pass reqID as chatID: OutboundMessage.ChatID = reqID → Send() finds tasks[reqID]. - c.HandleMessage(taskCtx, peer, reqID, userID, reqID, - content, mediaRefs, metadata, sender) - - // Wait for the agent reply. While waiting, send periodic finish=false - // hints so the user knows processing is still in progress. - // WeCom requires finish=true within 6 minutes of the first stream frame; - // wsStreamMaxDuration enforces that limit with a safety margin. - waitHints := []string{ - "⏳ Processing, please wait...", - "⏳ Still processing, please wait...", - "⏳ Almost there, please wait...", - } - ticker := time.NewTicker(wsStreamTickInterval) - defer ticker.Stop() - deadlineTimer := time.NewTimer(wsStreamMaxDuration) - defer deadlineTimer.Stop() - tickCount := 0 - for { - select { - case answer := <-task.answerCh: - // Split the answer into byte-bounded chunks and send as stream frames. - // All but the last carry finish=false; the final frame closes the stream. 
- chunks := splitWSContent(answer, wsStreamMaxContentBytes) - for i, chunk := range chunks { - c.wsSendStreamChunk(reqID, streamID, i == len(chunks)-1, chunk) - } - c.deleteReqState(reqID) - return - case <-ticker.C: - hint := waitHints[tickCount%len(waitHints)] - tickCount++ - logger.DebugCF("wecom_aibot", "Sending stream progress hint", - map[string]any{"chat_id": actualChatID, "tick": tickCount}) - c.wsSendStreamChunk(reqID, streamID, false, hint) - case <-deadlineTimer.C: - logger.WarnCF("wecom_aibot", - "Stream response deadline reached, closing stream; late reply will be pushed", - map[string]any{"chat_id": actualChatID}) - c.wsSendStreamFinish(reqID, streamID, - "⏳ Processing is taking longer than expected, the response will be sent as a follow-up message.") - return - case <-taskCtx.Done(): - // Give a short grace period so that a response queued in the bus - // just before cancellation can still be delivered. This closes a - // race where a rapid second message cancels this task after the - // agent already published but before Send() wrote to answerCh. - // - // The connection is gone at this point, so we cannot use - // wsSendStreamFinish. Try wsSendActivePush on the (possibly - // already-restored) connection; if that also fails, leave the - // route intact so Send() can push the reply once reconnected. - select { - case answer := <-task.answerCh: - if err := c.wsSendActivePush(task.ChatID, task.ChatType, answer); err != nil { - logger.WarnCF("wecom_aibot", - "Grace-period push failed after task cancellation; reply may be lost", - map[string]any{"req_id": reqID, "chat_id": task.ChatID, "error": err.Error()}) - } else { - c.deleteReqState(reqID) - } - case <-time.After(100 * time.Millisecond): - } - return - } - } - }() -} - -// handleWSVoiceMessage handles voice messages. -// WeCom transcribes voice to text in the callback; if the transcription is -// present it is dispatched as plain text to the agent. 
-func (c *WeComAIBotWSChannel) handleWSVoiceMessage(reqID string, msg WeComAIBotWSMessage) { - if msg.Voice != nil && msg.Voice.Content != "" { - c.dispatchWSAgentTask(reqID, msg, msg.Voice.Content, nil) - return - } - c.wsSendStreamFinish(reqID, wsGenerateID(), "Voice messages are not yet supported.") -} - -// handleWSFileMessage handles file messages. -func (c *WeComAIBotWSChannel) handleWSFileMessage(reqID string, msg WeComAIBotWSMessage) { - if msg.File == nil { - logger.WarnC("wecom_aibot", "File message missing file field") - c.wsSendStreamFinish(reqID, wsGenerateID(), "File message could not be processed.") - return - } - c.wsHandleMediaMessage(reqID, msg, msg.File.URL, msg.File.AESKey, "file") -} - -// handleWSVideoMessage handles video messages. -func (c *WeComAIBotWSChannel) handleWSVideoMessage(reqID string, msg WeComAIBotWSMessage) { - if msg.Video == nil { - logger.WarnC("wecom_aibot", "Video message missing video field") - c.wsSendStreamFinish(reqID, wsGenerateID(), "Video message could not be processed.") - return - } - c.wsHandleMediaMessage(reqID, msg, msg.Video.URL, msg.Video.AESKey, "video") -} - -// ---- WebSocket write helpers ---- - -// wsSendStreamChunk sends an aibot_respond_msg stream frame. 
-func (c *WeComAIBotWSChannel) wsSendStreamChunk(reqID, streamID string, finish bool, content string) { - logger.DebugCF("wecom_aibot", "Sending stream chunk", map[string]any{ - "stream_id": streamID, - "finish": finish, - "preview": utils.Truncate(content, 100), - }) - cmd := wsCommand{ - Cmd: "aibot_respond_msg", - Headers: wsHeaders{ReqID: reqID}, - Body: wsRespondMsgBody{ - MsgType: "stream", - Stream: &wsStreamContent{ - ID: streamID, - Finish: finish, - Content: content, - }, - }, - } - if err := c.writeWSAndWait(cmd, wsRespondMsgTimeout); err != nil { - logger.WarnCF("wecom_aibot", "Stream chunk ack failed", map[string]any{ - "req_id": reqID, - "stream_id": streamID, - "finish": finish, - "error": err, - }) - } -} - -// wsSendStreamFinish sends the final aibot_respond_msg frame (finish=true, no images). -func (c *WeComAIBotWSChannel) wsSendStreamFinish(reqID, streamID, content string) { - c.wsSendStreamChunk(reqID, streamID, true, content) -} - -// wsSendWelcomeMsg sends a text welcome message via aibot_respond_welcome_msg. -func (c *WeComAIBotWSChannel) wsSendWelcomeMsg(reqID, content string) { - logger.DebugCF("wecom_aibot", "Sending welcome message", map[string]any{"req_id": reqID}) - cmd := wsCommand{ - Cmd: "aibot_respond_welcome_msg", - Headers: wsHeaders{ReqID: reqID}, - Body: wsRespondMsgBody{ - MsgType: "text", - Text: &wsTextContent{Content: content}, - }, - } - if err := c.writeWSAndWait(cmd, wsWelcomeMsgTimeout); err != nil { - logger.WarnCF("wecom_aibot", "Welcome message ack failed", - map[string]any{"req_id": reqID, "error": err.Error()}) - } -} - -// wsSendActivePush sends a proactive markdown message using aibot_send_msg. -// Long content is automatically split into byte-bounded chunks (≤ wsStreamMaxContentBytes -// each) and delivered as consecutive messages. -// It is used as a fallback for late replies after stream response window expires. 
-func (c *WeComAIBotWSChannel) wsSendActivePush(chatID string, chatType uint32, content string) error { - if chatID == "" { - return fmt.Errorf("chatid is empty") - } - for _, chunk := range splitWSContent(content, wsStreamMaxContentBytes) { - reqID := wsGenerateID() - if err := c.writeWSAndWait(wsCommand{ - Cmd: "aibot_send_msg", - Headers: wsHeaders{ReqID: reqID}, - Body: wsSendMsgBody{ - ChatID: chatID, - ChatType: chatType, - MsgType: "markdown", - Markdown: &wsMarkdownContent{Content: chunk}, - }, - }, wsSendMsgTimeout); err != nil { - return err - } - } - return nil -} - -// writeWSAndWait writes cmd to the active connection and validates the command response. -func (c *WeComAIBotWSChannel) writeWSAndWait(cmd wsCommand, timeout time.Duration) error { - if cmd.Headers.ReqID == "" { - return fmt.Errorf("req_id is empty") - } - - c.connMu.Lock() - conn := c.conn - c.connMu.Unlock() - if conn == nil { - return fmt.Errorf("websocket not connected") - } - - resp, err := c.sendAndWait(conn, cmd.Headers.ReqID, cmd, timeout) - if err != nil { - return err - } - if resp.ErrCode != 0 { - return fmt.Errorf("%s rejected (errcode=%d): %s", cmd.Cmd, resp.ErrCode, resp.ErrMsg) - } - return nil -} - -// cancelAllTasks cancels every pending agent task; called when the connection drops. -// It also expires each task's stream window (ReadyAt = now) so that when the agent -// eventually delivers its reply via Send(), the message is forwarded via -// wsSendActivePush on the restored connection instead of being silently discarded. -func (c *WeComAIBotWSChannel) cancelAllTasks() { - c.reqStatesMu.Lock() - defer c.reqStatesMu.Unlock() - now := time.Now() - for _, state := range c.reqStates { - if state != nil && state.Task != nil { - state.Task.cancel() - state.Task = nil - // Expire the stream window immediately so Send() uses wsSendActivePush. 
- state.Route.ReadyAt = now - } - } -} - -func (c *WeComAIBotWSChannel) setReqState(reqID string, state *wsReqState) { - c.reqStatesMu.Lock() - defer c.reqStatesMu.Unlock() - now := time.Now() - for k, v := range c.reqStates { - if v == nil || now.After(v.Route.ExpiresAt) { - delete(c.reqStates, k) - } - } - c.reqStates[reqID] = state -} - -func (c *WeComAIBotWSChannel) getReqState(reqID string) (*wsTask, wsLateReplyRoute, bool) { - c.reqStatesMu.Lock() - defer c.reqStatesMu.Unlock() - state, ok := c.reqStates[reqID] - if !ok || state == nil { - return nil, wsLateReplyRoute{}, false - } - if time.Now().After(state.Route.ExpiresAt) { - delete(c.reqStates, reqID) - return nil, wsLateReplyRoute{}, false - } - return state.Task, state.Route, true -} - -func (c *WeComAIBotWSChannel) deleteReqState(reqID string) { - c.reqStatesMu.Lock() - delete(c.reqStates, reqID) - c.reqStatesMu.Unlock() -} - -func (c *WeComAIBotWSChannel) clearReqTask(reqID string, task *wsTask) { - c.reqStatesMu.Lock() - defer c.reqStatesMu.Unlock() - state, ok := c.reqStates[reqID] - if !ok || state == nil { - return - } - if state.Task == task { - state.Task = nil - } -} - -func wsChatTypeValue(chatType string) uint32 { - if chatType == "group" { - return 2 - } - return 1 -} - -// wsChatID returns the effective chat ID from a WS message. -// For group messages it is msg.ChatID; for single chats it falls back to the sender's UserID. -func wsChatID(msg WeComAIBotWSMessage) string { - if msg.ChatID != "" { - return msg.ChatID - } - return msg.From.UserID -} - -// wsGenerateID generates a random 10-character alphanumeric ID. -// It is package-level (not a method) so it can be shared by both channel modes. -func wsGenerateID() string { - return generateRandomID(10) -} - -// ---- Inbound media download helpers ---- - -// storeWSMedia downloads the resource at resourceURL (with optional AES-CBC -// decryption) and stores it in the MediaStore. 
The file extension is inferred -// from the HTTP Content-Type response header; defaultExt is used as a fallback -// when the content type is absent or unrecognized. -func (c *WeComAIBotWSChannel) storeWSMedia( - ctx context.Context, - chatID, msgID, resourceURL, aesKey, defaultExt string, -) (string, error) { - store := c.GetMediaStore() - if store == nil { - return "", fmt.Errorf("no media store available") - } - - const maxSize = 20 << 20 // 20 MB - - req, err := http.NewRequestWithContext(ctx, http.MethodGet, resourceURL, nil) - if err != nil { - return "", fmt.Errorf("create request: %w", err) - } - resp, err := wsImageHTTPClient.Do(req) - if err != nil { - return "", fmt.Errorf("download: %w", err) - } - defer resp.Body.Close() - if resp.StatusCode != http.StatusOK { - return "", fmt.Errorf("download HTTP %d", resp.StatusCode) - } - - // Infer file extension from the Content-Type response header. - ext := wsMediaExtFromContentType(resp.Header.Get("Content-Type")) - if ext == "" { - ext = defaultExt - } - - // Buffer the media in memory, bounded to maxSize. - data, err := io.ReadAll(io.LimitReader(resp.Body, int64(maxSize)+1)) - if err != nil { - return "", fmt.Errorf("read media: %w", err) - } - if len(data) > maxSize { - return "", fmt.Errorf("media too large (> %d MB)", maxSize>>20) - } - - // AES-CBC decryption if a key is present. - if aesKey != "" { - key, decErr := base64.StdEncoding.DecodeString(aesKey) - if decErr != nil || len(key) != 32 { - key, decErr = decodeWeComAESKey(aesKey) - if decErr != nil { - return "", fmt.Errorf("decode media AES key: %w", decErr) - } - } - data, err = decryptAESCBC(key, data) - if err != nil { - return "", fmt.Errorf("decrypt media: %w", err) - } - } - - // Write to a temp file. The file is owned by the MediaStore and deleted by - // store.ReleaseAll — no caller-side cleanup needed. 
- mediaDir := filepath.Join(os.TempDir(), "picoclaw_media") - if err = os.MkdirAll(mediaDir, 0o700); err != nil { - return "", fmt.Errorf("mkdir: %w", err) - } - tmpFile, err := os.CreateTemp(mediaDir, msgID+"-*"+ext) - if err != nil { - return "", fmt.Errorf("create temp file: %w", err) - } - tmpPath := tmpFile.Name() - _, writeErr := tmpFile.Write(data) - closeErr := tmpFile.Close() - if writeErr != nil { - os.Remove(tmpPath) - return "", fmt.Errorf("write media: %w", writeErr) - } - if closeErr != nil { - os.Remove(tmpPath) - return "", fmt.Errorf("close media: %w", closeErr) - } - - scope := channels.BuildMediaScope("wecom_aibot", chatID, msgID) - ref, err := store.Store(tmpPath, media.MediaMeta{ - Filename: msgID + ext, - Source: "wecom_aibot", - CleanupPolicy: media.CleanupPolicyDeleteOnCleanup, - }, scope) - if err != nil { - os.Remove(tmpPath) - return "", fmt.Errorf("store: %w", err) - } - return ref, nil -} - -// wsMediaExtFromContentType returns the lowercase file extension (with leading -// dot) for the given Content-Type value, or "" when the type is unrecognized. -func wsMediaExtFromContentType(contentType string) string { - if contentType == "" { - return "" - } - // Strip parameters (e.g. "image/jpeg; charset=utf-8" → "image/jpeg"). 
- mt := strings.ToLower(strings.TrimSpace(strings.SplitN(contentType, ";", 2)[0])) - switch mt { - case "image/jpeg", "image/jpg": - return ".jpg" - case "image/png": - return ".png" - case "image/gif": - return ".gif" - case "image/webp": - return ".webp" - case "video/mp4": - return ".mp4" - case "video/mpeg", "video/x-mpeg": - return ".mpeg" - case "video/quicktime": - return ".mov" - case "video/webm": - return ".webm" - case "audio/mpeg", "audio/mp3": - return ".mp3" - case "audio/ogg": - return ".ogg" - case "audio/wav": - return ".wav" - case "application/pdf": - return ".pdf" - case "application/zip": - return ".zip" - case "application/x-rar-compressed", "application/vnd.rar": - return ".rar" - case "text/plain": - return ".txt" - case "application/msword": - return ".doc" - case "application/vnd.openxmlformats-officedocument.wordprocessingml.document": - return ".docx" - case "application/vnd.ms-excel": - return ".xls" - case "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet": - return ".xlsx" - case "application/vnd.ms-powerpoint": - return ".ppt" - case "application/vnd.openxmlformats-officedocument.presentationml.presentation": - return ".pptx" - } - return "" -} - -// wsLabelToDefaultExt returns the default file extension for the given media label -// used in wsHandleMediaMessage. It is the fallback when Content-Type detection fails. -func wsLabelToDefaultExt(label string) string { - switch label { - case "image": - return ".jpg" - case "video": - return ".mp4" - default: // "file" and any future labels - return ".bin" - } -} - -// ---- Content length helpers ---- - -// splitWSContent splits content into chunks each fitting within maxBytes UTF-8 -// bytes, preserving code block integrity via channels.SplitMessage. -// When SplitMessage still produces an oversized chunk (e.g. dense CJK content), -// splitAtByteBoundary is applied as a last-resort byte-level fallback. 
-func splitWSContent(content string, maxBytes int) []string { - if len(content) <= maxBytes { - return []string{content} - } - // SplitMessage works in runes. Use maxBytes as the rune limit: for pure ASCII - // this is exact; for multibyte content the byte verification below catches - // any chunk that still overflows. - chunks := channels.SplitMessage(content, maxBytes) - var result []string - for _, chunk := range chunks { - if len(chunk) <= maxBytes { - result = append(result, chunk) - } else { - // Still too large in bytes (e.g. dense CJK); force-split at UTF-8 boundaries. - result = append(result, splitAtByteBoundary(chunk, maxBytes)...) - } - } - return result -} - -// splitAtByteBoundary splits s into parts each ≤ maxBytes bytes by walking back -// from the hard byte limit to find a valid UTF-8 rune start boundary. -// This is a last-resort fallback; it does not try to preserve code blocks. -func splitAtByteBoundary(s string, maxBytes int) []string { - var parts []string - for len(s) > maxBytes { - end := maxBytes - // Walk back past any UTF-8 continuation bytes (high two bits == 10). - for end > 0 && s[end]>>6 == 0b10 { - end-- - } - if end == 0 { - end = maxBytes // shouldn't happen with valid UTF-8 - } - parts = append(parts, s[:end]) - s = strings.TrimLeft(s[end:], " \t\n\r") - } - if s != "" { - parts = append(parts, s) - } - return parts -} diff --git a/pkg/channels/wecom/aibot_ws_test.go b/pkg/channels/wecom/aibot_ws_test.go deleted file mode 100644 index f2f8833a1..000000000 --- a/pkg/channels/wecom/aibot_ws_test.go +++ /dev/null @@ -1,295 +0,0 @@ -package wecom - -import ( - "bytes" - "context" - "net/http" - "net/http/httptest" - "os" - "strings" - "testing" - - "github.com/sipeed/picoclaw/pkg/bus" - "github.com/sipeed/picoclaw/pkg/channels" - "github.com/sipeed/picoclaw/pkg/config" - "github.com/sipeed/picoclaw/pkg/media" -) - -// newTestWSChannel creates a WeComAIBotWSChannel ready for unit testing. 
-func newTestWSChannel(t *testing.T) *WeComAIBotWSChannel { - t.Helper() - cfg := config.WeComAIBotConfig{ - Enabled: true, - BotID: "test_bot_id", - } - cfg.SetSecret("test_secret") - ch, err := newWeComAIBotWSChannel(cfg, bus.NewMessageBus()) - if err != nil { - t.Fatalf("create WS channel: %v", err) - } - return ch -} - -// TestStoreWSMedia_NilStore verifies that storeWSMedia returns an error when no -// MediaStore has been injected. -func TestStoreWSMedia_NilStore(t *testing.T) { - ch := newTestWSChannel(t) - _, err := ch.storeWSMedia(context.Background(), "chat1", "msg1", "http://any", "", ".jpg") - if err == nil { - t.Fatal("expected error when no MediaStore is set") - } -} - -// TestStoreWSMedia_HTTPError verifies that storeWSMedia propagates HTTP errors -// from the media server. -func TestStoreWSMedia_HTTPError(t *testing.T) { - srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { - http.Error(w, "not found", http.StatusNotFound) - })) - defer srv.Close() - - ch := newTestWSChannel(t) - ch.SetMediaStore(media.NewFileMediaStore()) - - _, err := ch.storeWSMedia(context.Background(), "chat1", "msg1", srv.URL, "", ".jpg") - if err == nil { - t.Fatal("expected error for HTTP 404") - } -} - -// TestStoreWSMedia_ServerUnavailable verifies that storeWSMedia returns a clear -// error when the media server cannot be reached. -func TestStoreWSMedia_ServerUnavailable(t *testing.T) { - ch := newTestWSChannel(t) - ch.SetMediaStore(media.NewFileMediaStore()) - - // Port 1 is reserved and will refuse the connection immediately. - _, err := ch.storeWSMedia(context.Background(), "chat1", "msg1", "http://127.0.0.1:1", "", ".jpg") - if err == nil { - t.Fatal("expected error for unreachable server") - } -} - -// TestStoreWSMedia_Success_NoAES verifies the happy path: the media is downloaded, -// a media ref is returned, and the file persists and is readable via Resolve until -// ReleaseAll is called. 
The server returns no Content-Type, so the defaultExt is used. -func TestStoreWSMedia_Success_NoAES(t *testing.T) { - imageData := bytes.Repeat([]byte("x"), 256) - srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { - w.WriteHeader(http.StatusOK) - _, _ = w.Write(imageData) - })) - defer srv.Close() - - ch := newTestWSChannel(t) - store := media.NewFileMediaStore() - ch.SetMediaStore(store) - - ref, err := ch.storeWSMedia(context.Background(), "chat1", "msg1", srv.URL, "", ".jpg") - if err != nil { - t.Fatalf("expected no error, got %v", err) - } - if ref == "" { - t.Fatal("expected non-empty ref") - } - - // File must be accessible after storeWSMedia returns (no premature deletion). - path, err := store.Resolve(ref) - if err != nil { - t.Fatalf("ref should resolve: %v", err) - } - got, err := os.ReadFile(path) - if err != nil { - t.Fatalf("file should exist at %s: %v", path, err) - } - if !bytes.Equal(got, imageData) { - t.Errorf("content mismatch: got len=%d, want len=%d", len(got), len(imageData)) - } - - // ReleaseAll must delete the file (store owns lifecycle). - scope := channels.BuildMediaScope("wecom_aibot", "chat1", "msg1") - if err := store.ReleaseAll(scope); err != nil { - t.Fatalf("ReleaseAll failed: %v", err) - } - if _, err := os.Stat(path); !os.IsNotExist(err) { - t.Errorf("file should have been deleted by ReleaseAll, stat err: %v", err) - } -} - -// TestStoreWSMedia_MultipleMessages verifies that concurrent media messages with -// different msgIDs do not collide and each resolve to distinct files. 
-func TestStoreWSMedia_MultipleMessages(t *testing.T) { - imageA := bytes.Repeat([]byte("a"), 64) - imageB := bytes.Repeat([]byte("b"), 64) - - srvA := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { - w.WriteHeader(http.StatusOK) - _, _ = w.Write(imageA) - })) - defer srvA.Close() - srvB := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { - w.WriteHeader(http.StatusOK) - _, _ = w.Write(imageB) - })) - defer srvB.Close() - - ch := newTestWSChannel(t) - store := media.NewFileMediaStore() - ch.SetMediaStore(store) - - refA, err := ch.storeWSMedia(context.Background(), "chat1", "msgA", srvA.URL, "", ".jpg") - if err != nil { - t.Fatalf("storeWSMedia A: %v", err) - } - refB, err := ch.storeWSMedia(context.Background(), "chat1", "msgB", srvB.URL, "", ".jpg") - if err != nil { - t.Fatalf("storeWSMedia B: %v", err) - } - if refA == refB { - t.Fatal("distinct messages must produce distinct refs") - } - - pathA, _ := store.Resolve(refA) - pathB, _ := store.Resolve(refB) - if pathA == pathB { - t.Fatal("distinct messages must be stored at distinct paths") - } - - gotA, _ := os.ReadFile(pathA) - gotB, _ := os.ReadFile(pathB) - if !bytes.Equal(gotA, imageA) { - t.Errorf("content mismatch for message A") - } - if !bytes.Equal(gotB, imageB) { - t.Errorf("content mismatch for message B") - } -} - -// TestStoreWSMedia_ContentTypeExt verifies that the file extension is inferred -// from the HTTP Content-Type header and the defaultExt fallback is used when the -// type is absent or unrecognized. -func TestStoreWSMedia_ContentTypeExt(t *testing.T) { - tests := []struct { - contentType string - wantExt string - }{ - {"image/jpeg", ".jpg"}, - {"image/png", ".png"}, - {"video/mp4", ".mp4"}, - {"application/pdf", ".pdf"}, - {"application/zip", ".zip"}, - // With parameters stripped. - {"video/mp4; codecs=avc1", ".mp4"}, - // Unknown type → falls back to defaultExt. 
- {"", ""}, - {"application/octet-stream", ""}, - } - for _, tc := range tests { - got := wsMediaExtFromContentType(tc.contentType) - if got != tc.wantExt { - t.Errorf("wsMediaExtFromContentType(%q) = %q, want %q", tc.contentType, got, tc.wantExt) - } - } - - // End-to-end: server returns Content-Type: video/mp4, defaultExt is .bin. - // The stored file should carry the .mp4 extension, not .bin. - payload := bytes.Repeat([]byte("v"), 128) - srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { - w.Header().Set("Content-Type", "video/mp4") - w.WriteHeader(http.StatusOK) - _, _ = w.Write(payload) - })) - defer srv.Close() - - ch := newTestWSChannel(t) - store := media.NewFileMediaStore() - ch.SetMediaStore(store) - - ref, err := ch.storeWSMedia(context.Background(), "chat1", "vid1", srv.URL, "", ".bin") - if err != nil { - t.Fatalf("storeWSMedia: %v", err) - } - path, err := store.Resolve(ref) - if err != nil { - t.Fatalf("resolve: %v", err) - } - if ext := path[len(path)-4:]; ext != ".mp4" { - t.Errorf("expected .mp4 extension from Content-Type, got %q", ext) - } -} - -// TestSplitWSContent verifies byte-aware splitting of stream content. -func TestSplitWSContent(t *testing.T) { - t.Run("short content is not split", func(t *testing.T) { - chunks := splitWSContent("hello", 20480) - if len(chunks) != 1 || chunks[0] != "hello" { - t.Fatalf("unexpected chunks: %v", chunks) - } - }) - - t.Run("ASCII content split at byte boundary", func(t *testing.T) { - // Build a string just over the limit. - content := strings.Repeat("a", 20481) - chunks := splitWSContent(content, 20480) - if len(chunks) < 2 { - t.Fatalf("expected >= 2 chunks, got %d", len(chunks)) - } - for i, c := range chunks { - if len(c) > 20480 { - t.Errorf("chunk %d has %d bytes, want <= 20480", i, len(c)) - } - } - // Reassembled content must equal the original (possibly without leading - // whitespace that splitWSContent trims between chunks). 
- joined := strings.Join(chunks, "") - if len(joined) < len(content)-len(chunks) { - t.Errorf("joined length %d too short (original %d)", len(joined), len(content)) - } - }) - - t.Run("CJK content split within byte limit", func(t *testing.T) { - // Each CJK rune is 3 bytes in UTF-8. - // 7000 CJK chars = 21000 bytes, which exceeds 20480. - content := strings.Repeat("\u4e2d", 7000) - chunks := splitWSContent(content, 20480) - if len(chunks) < 2 { - t.Fatalf("expected >= 2 chunks for 21000-byte CJK content, got %d", len(chunks)) - } - for i, c := range chunks { - if len(c) > 20480 { - t.Errorf("chunk %d has %d bytes, want <= 20480", i, len(c)) - } - // Every chunk must be valid UTF-8. - if !strings.ContainsRune(c, '\u4e2d') && len(c) > 0 { - // quick plausibility check — content was pure CJK - } - } - }) -} - -// TestSplitAtByteBoundary verifies the last-resort byte-boundary splitter. -func TestSplitAtByteBoundary(t *testing.T) { - t.Run("ASCII fits in one chunk", func(t *testing.T) { - parts := splitAtByteBoundary("hello world", 100) - if len(parts) != 1 { - t.Fatalf("expected 1 part, got %d", len(parts)) - } - }) - - t.Run("splits at byte boundary, never mid-rune", func(t *testing.T) { - // 10 CJK characters = 30 bytes; split at 20 bytes. - s := strings.Repeat("\u6587", 10) // 10 × 3 bytes = 30 bytes - parts := splitAtByteBoundary(s, 20) - for i, p := range parts { - if len(p) > 20 { - t.Errorf("part %d has %d bytes, want <= 20", i, len(p)) - } - // Must be valid UTF-8 (no torn multi-byte sequences). 
- for j, r := range p { - if r == '\uFFFD' { - t.Errorf("part %d has replacement rune at position %d: torn UTF-8", i, j) - } - } - } - }) -} diff --git a/pkg/channels/wecom/app.go b/pkg/channels/wecom/app.go deleted file mode 100644 index fccfc60a3..000000000 --- a/pkg/channels/wecom/app.go +++ /dev/null @@ -1,756 +0,0 @@ -package wecom - -import ( - "bytes" - "context" - "encoding/json" - "encoding/xml" - "fmt" - "io" - "mime/multipart" - "net/http" - "net/url" - "os" - "path/filepath" - "strings" - "sync" - "time" - - "github.com/sipeed/picoclaw/pkg/bus" - "github.com/sipeed/picoclaw/pkg/channels" - "github.com/sipeed/picoclaw/pkg/config" - "github.com/sipeed/picoclaw/pkg/identity" - "github.com/sipeed/picoclaw/pkg/logger" - "github.com/sipeed/picoclaw/pkg/utils" -) - -const ( - wecomAPIBase = "https://qyapi.weixin.qq.com" -) - -// WeComAppChannel implements the Channel interface for WeCom App (企业微信自建应用) -type WeComAppChannel struct { - *channels.BaseChannel - config config.WeComAppConfig - client *http.Client - accessToken string - tokenExpiry time.Time - tokenMu sync.RWMutex - ctx context.Context - cancel context.CancelFunc - processedMsgs *MessageDeduplicator -} - -// WeComXMLMessage represents the XML message structure from WeCom -type WeComXMLMessage struct { - XMLName xml.Name `xml:"xml"` - ToUserName string `xml:"ToUserName"` - FromUserName string `xml:"FromUserName"` - CreateTime int64 `xml:"CreateTime"` - MsgType string `xml:"MsgType"` - Content string `xml:"Content"` - MsgId int64 `xml:"MsgId"` - AgentID int64 `xml:"AgentID"` - PicUrl string `xml:"PicUrl"` - MediaId string `xml:"MediaId"` - Format string `xml:"Format"` - ThumbMediaId string `xml:"ThumbMediaId"` - LocationX float64 `xml:"Location_X"` - LocationY float64 `xml:"Location_Y"` - Scale int `xml:"Scale"` - Label string `xml:"Label"` - Title string `xml:"Title"` - Description string `xml:"Description"` - Url string `xml:"Url"` - Event string `xml:"Event"` - EventKey string `xml:"EventKey"` -} - 
-// WeComTextMessage represents text message for sending -type WeComTextMessage struct { - ToUser string `json:"touser"` - MsgType string `json:"msgtype"` - AgentID int64 `json:"agentid"` - Text struct { - Content string `json:"content"` - } `json:"text"` - Safe int `json:"safe,omitempty"` -} - -// WeComMarkdownMessage represents markdown message for sending -type WeComMarkdownMessage struct { - ToUser string `json:"touser"` - MsgType string `json:"msgtype"` - AgentID int64 `json:"agentid"` - Markdown struct { - Content string `json:"content"` - } `json:"markdown"` -} - -// WeComImageMessage represents image message for sending -type WeComImageMessage struct { - ToUser string `json:"touser"` - MsgType string `json:"msgtype"` - AgentID int64 `json:"agentid"` - Image struct { - MediaID string `json:"media_id"` - } `json:"image"` -} - -// WeComAccessTokenResponse represents the access token API response -type WeComAccessTokenResponse struct { - ErrCode int `json:"errcode"` - ErrMsg string `json:"errmsg"` - AccessToken string `json:"access_token"` - ExpiresIn int `json:"expires_in"` -} - -// WeComSendMessageResponse represents the send message API response -type WeComSendMessageResponse struct { - ErrCode int `json:"errcode"` - ErrMsg string `json:"errmsg"` - InvalidUser string `json:"invaliduser"` - InvalidParty string `json:"invalidparty"` - InvalidTag string `json:"invalidtag"` -} - -// PKCS7Padding adds PKCS7 padding -type PKCS7Padding struct{} - -// NewWeComAppChannel creates a new WeCom App channel instance -func NewWeComAppChannel(cfg config.WeComAppConfig, messageBus *bus.MessageBus) (*WeComAppChannel, error) { - if cfg.CorpID == "" || cfg.CorpSecret() == "" || cfg.AgentID == 0 { - return nil, fmt.Errorf("wecom_app corp_id, corp_secret and agent_id are required") - } - - base := channels.NewBaseChannel("wecom_app", cfg, messageBus, cfg.AllowFrom, - channels.WithMaxMessageLength(2048), - channels.WithGroupTrigger(cfg.GroupTrigger), - 
channels.WithReasoningChannelID(cfg.ReasoningChannelID), - ) - - // Client timeout must be >= the configured ReplyTimeout so the - // per-request context deadline is always the effective limit. - clientTimeout := 30 * time.Second - if d := time.Duration(cfg.ReplyTimeout) * time.Second; d > clientTimeout { - clientTimeout = d - } - - ctx, cancel := context.WithCancel(context.Background()) - return &WeComAppChannel{ - BaseChannel: base, - config: cfg, - client: &http.Client{Timeout: clientTimeout}, - ctx: ctx, - cancel: cancel, - processedMsgs: NewMessageDeduplicator(wecomMaxProcessedMessages), - }, nil -} - -// Name returns the channel name -func (c *WeComAppChannel) Name() string { - return "wecom_app" -} - -// Start initializes the WeCom App channel -func (c *WeComAppChannel) Start(ctx context.Context) error { - logger.InfoC("wecom_app", "Starting WeCom App channel...") - - // Cancel the context created in the constructor to avoid a resource leak. - if c.cancel != nil { - c.cancel() - } - c.ctx, c.cancel = context.WithCancel(ctx) - - // Get initial access token - if err := c.refreshAccessToken(); err != nil { - logger.WarnCF("wecom_app", "Failed to get initial access token", map[string]any{ - "error": err.Error(), - }) - } - - // Start token refresh goroutine - go c.tokenRefreshLoop() - - c.SetRunning(true) - logger.InfoC("wecom_app", "WeCom App channel started") - - return nil -} - -// Stop gracefully stops the WeCom App channel -func (c *WeComAppChannel) Stop(ctx context.Context) error { - logger.InfoC("wecom_app", "Stopping WeCom App channel...") - - if c.cancel != nil { - c.cancel() - } - - c.SetRunning(false) - logger.InfoC("wecom_app", "WeCom App channel stopped") - return nil -} - -// Send sends a message to WeCom user proactively using access token -func (c *WeComAppChannel) Send(ctx context.Context, msg bus.OutboundMessage) error { - if !c.IsRunning() { - return channels.ErrNotRunning - } - - accessToken := c.getAccessToken() - if accessToken == "" { - 
return fmt.Errorf("no valid access token available") - } - - logger.DebugCF("wecom_app", "Sending message", map[string]any{ - "chat_id": msg.ChatID, - "preview": utils.Truncate(msg.Content, 100), - }) - - return c.sendTextMessage(ctx, accessToken, msg.ChatID, msg.Content) -} - -// SendMedia implements the channels.MediaSender interface. -func (c *WeComAppChannel) SendMedia(ctx context.Context, msg bus.OutboundMediaMessage) error { - if !c.IsRunning() { - return channels.ErrNotRunning - } - - accessToken := c.getAccessToken() - if accessToken == "" { - return fmt.Errorf("no valid access token available: %w", channels.ErrTemporary) - } - - store := c.GetMediaStore() - if store == nil { - return fmt.Errorf("no media store available: %w", channels.ErrSendFailed) - } - - for _, part := range msg.Parts { - localPath, err := store.Resolve(part.Ref) - if err != nil { - logger.ErrorCF("wecom_app", "Failed to resolve media ref", map[string]any{ - "ref": part.Ref, - "error": err.Error(), - }) - continue - } - - // Map part type to WeCom media type - var mediaType string - switch part.Type { - case "image": - mediaType = "image" - case "audio": - mediaType = "voice" - case "video": - mediaType = "video" - default: - mediaType = "file" - } - - // Upload media to get media_id - mediaID, err := c.uploadMedia(ctx, accessToken, mediaType, localPath) - if err != nil { - logger.ErrorCF("wecom_app", "Failed to upload media", map[string]any{ - "type": mediaType, - "error": err.Error(), - }) - // Fallback: send caption as text - if part.Caption != "" { - _ = c.sendTextMessage(ctx, accessToken, msg.ChatID, part.Caption) - } - continue - } - - // Send media message using the media_id - if mediaType == "image" { - err = c.sendImageMessage(ctx, accessToken, msg.ChatID, mediaID) - } else { - // For non-image types, send as text fallback with caption - caption := part.Caption - if caption == "" { - caption = fmt.Sprintf("[%s: %s]", part.Type, part.Filename) - } - err = c.sendTextMessage(ctx, 
accessToken, msg.ChatID, caption) - } - - if err != nil { - return err - } - } - - return nil -} - -// uploadMedia uploads a local file to WeCom temporary media storage. -func (c *WeComAppChannel) uploadMedia(ctx context.Context, accessToken, mediaType, localPath string) (string, error) { - apiURL := fmt.Sprintf("%s/cgi-bin/media/upload?access_token=%s&type=%s", - wecomAPIBase, url.QueryEscape(accessToken), url.QueryEscape(mediaType)) - - file, err := os.Open(localPath) - if err != nil { - return "", fmt.Errorf("failed to open file: %w", err) - } - defer file.Close() - - body := &bytes.Buffer{} - writer := multipart.NewWriter(body) - - filename := filepath.Base(localPath) - formFile, err := writer.CreateFormFile("media", filename) - if err != nil { - return "", fmt.Errorf("failed to create form file: %w", err) - } - - if _, err = io.Copy(formFile, file); err != nil { - return "", fmt.Errorf("failed to copy file content: %w", err) - } - writer.Close() - - req, err := http.NewRequestWithContext(ctx, http.MethodPost, apiURL, body) - if err != nil { - return "", fmt.Errorf("failed to create request: %w", err) - } - req.Header.Set("Content-Type", writer.FormDataContentType()) - - resp, err := c.client.Do(req) - if err != nil { - return "", channels.ClassifyNetError(err) - } - defer resp.Body.Close() - - if resp.StatusCode != http.StatusOK { - respBody, readErr := io.ReadAll(resp.Body) - if readErr != nil { - return "", channels.ClassifySendError( - resp.StatusCode, - fmt.Errorf("reading wecom upload error response: %w", readErr), - ) - } - return "", channels.ClassifySendError( - resp.StatusCode, - fmt.Errorf("wecom upload error: %s", string(respBody)), - ) - } - - var result struct { - ErrCode int `json:"errcode"` - ErrMsg string `json:"errmsg"` - MediaID string `json:"media_id"` - } - if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { - return "", fmt.Errorf("failed to parse upload response: %w", err) - } - - if result.ErrCode != 0 { - return "", 
fmt.Errorf("upload API error: %s (code: %d)", result.ErrMsg, result.ErrCode) - } - - return result.MediaID, nil -} - -// sendWeComMessage marshals payload and POSTs it to the WeCom message API. -func (c *WeComAppChannel) sendWeComMessage(ctx context.Context, accessToken string, payload any) error { - apiURL := fmt.Sprintf("%s/cgi-bin/message/send?access_token=%s", wecomAPIBase, accessToken) - - jsonData, err := json.Marshal(payload) - if err != nil { - return fmt.Errorf("failed to marshal message: %w", err) - } - - timeout := c.config.ReplyTimeout - if timeout <= 0 { - timeout = 5 - } - - reqCtx, cancel := context.WithTimeout(ctx, time.Duration(timeout)*time.Second) - defer cancel() - - req, err := http.NewRequestWithContext(reqCtx, http.MethodPost, apiURL, bytes.NewBuffer(jsonData)) - if err != nil { - return fmt.Errorf("failed to create request: %w", err) - } - req.Header.Set("Content-Type", "application/json") - - resp, err := c.client.Do(req) - if err != nil { - return channels.ClassifyNetError(err) - } - defer resp.Body.Close() - - if resp.StatusCode != http.StatusOK { - respBody, readErr := io.ReadAll(resp.Body) - if readErr != nil { - return channels.ClassifySendError( - resp.StatusCode, - fmt.Errorf("reading wecom_app error response: %w", readErr), - ) - } - return channels.ClassifySendError( - resp.StatusCode, - fmt.Errorf("wecom_app API error: %s", string(respBody)), - ) - } - - respBody, err := io.ReadAll(resp.Body) - if err != nil { - return fmt.Errorf("failed to read response: %w", err) - } - - var sendResp WeComSendMessageResponse - if err := json.Unmarshal(respBody, &sendResp); err != nil { - return fmt.Errorf("failed to parse response: %w", err) - } - - if sendResp.ErrCode != 0 { - return fmt.Errorf("API error: %s (code: %d)", sendResp.ErrMsg, sendResp.ErrCode) - } - - return nil -} - -// sendImageMessage sends an image message using a media_id. 
-func (c *WeComAppChannel) sendImageMessage(ctx context.Context, accessToken, userID, mediaID string) error { - msg := WeComImageMessage{ - ToUser: userID, - MsgType: "image", - AgentID: c.config.AgentID, - } - msg.Image.MediaID = mediaID - return c.sendWeComMessage(ctx, accessToken, msg) -} - -// WebhookPath returns the path for registering on the shared HTTP server. -func (c *WeComAppChannel) WebhookPath() string { - if c.config.WebhookPath != "" { - return c.config.WebhookPath - } - return "/webhook/wecom-app" -} - -// ServeHTTP implements http.Handler for the shared HTTP server. -func (c *WeComAppChannel) ServeHTTP(w http.ResponseWriter, r *http.Request) { - c.handleWebhook(w, r) -} - -// HealthPath returns the health check endpoint path. -func (c *WeComAppChannel) HealthPath() string { - return "/health/wecom-app" -} - -// HealthHandler handles health check requests. -func (c *WeComAppChannel) HealthHandler(w http.ResponseWriter, r *http.Request) { - c.handleHealth(w, r) -} - -// handleWebhook handles incoming webhook requests from WeCom -func (c *WeComAppChannel) handleWebhook(w http.ResponseWriter, r *http.Request) { - ctx := r.Context() - - // Log all incoming requests for debugging - logger.DebugCF("wecom_app", "Received webhook request", map[string]any{ - "method": r.Method, - "url": r.URL.String(), - "path": r.URL.Path, - "query": r.URL.RawQuery, - }) - - if r.Method == http.MethodGet { - // Handle verification request - c.handleVerification(ctx, w, r) - return - } - - if r.Method == http.MethodPost { - // Handle message callback - c.handleMessageCallback(ctx, w, r) - return - } - - logger.WarnCF("wecom_app", "Method not allowed", map[string]any{ - "method": r.Method, - }) - http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) -} - -// handleVerification handles the URL verification request from WeCom -func (c *WeComAppChannel) handleVerification(ctx context.Context, w http.ResponseWriter, r *http.Request) { - query := r.URL.Query() - 
msgSignature := query.Get("msg_signature") - timestamp := query.Get("timestamp") - nonce := query.Get("nonce") - echostr := query.Get("echostr") - - logger.DebugCF("wecom_app", "Handling verification request", map[string]any{ - "msg_signature": msgSignature, - "timestamp": timestamp, - "nonce": nonce, - "echostr": echostr, - "corp_id": c.config.CorpID, - }) - - if msgSignature == "" || timestamp == "" || nonce == "" || echostr == "" { - logger.ErrorC("wecom_app", "Missing parameters in verification request") - http.Error(w, "Missing parameters", http.StatusBadRequest) - return - } - - // Verify signature - if !verifySignature(c.config.Token(), msgSignature, timestamp, nonce, echostr) { - logger.WarnCF("wecom_app", "Signature verification failed", map[string]any{ - "token": c.config.Token(), - "msg_signature": msgSignature, - "timestamp": timestamp, - "nonce": nonce, - }) - http.Error(w, "Invalid signature", http.StatusForbidden) - return - } - - logger.DebugC("wecom_app", "Signature verification passed") - - // Decrypt echostr with CorpID verification - // For WeCom App (自建应用), receiveid should be corp_id - logger.DebugCF("wecom_app", "Attempting to decrypt echostr", map[string]any{ - "encoding_aes_key": c.config.EncodingAESKey(), - "corp_id": c.config.CorpID, - }) - decryptedEchoStr, err := decryptMessageWithVerify(echostr, c.config.EncodingAESKey(), c.config.CorpID) - if err != nil { - logger.ErrorCF("wecom_app", "Failed to decrypt echostr", map[string]any{ - "error": err.Error(), - "encoding_aes_key": c.config.EncodingAESKey, - "corp_id": c.config.CorpID, - }) - http.Error(w, "Decryption failed", http.StatusInternalServerError) - return - } - - logger.DebugCF("wecom_app", "Successfully decrypted echostr", map[string]any{ - "decrypted": decryptedEchoStr, - }) - - // Remove BOM and whitespace as per WeCom documentation - // The response must be plain text without quotes, BOM, or newlines - decryptedEchoStr = strings.TrimSpace(decryptedEchoStr) - decryptedEchoStr = 
strings.TrimPrefix(decryptedEchoStr, "\xef\xbb\xbf") // Remove UTF-8 BOM - w.Write([]byte(decryptedEchoStr)) -} - -// handleMessageCallback handles incoming messages from WeCom -func (c *WeComAppChannel) handleMessageCallback(ctx context.Context, w http.ResponseWriter, r *http.Request) { - query := r.URL.Query() - msgSignature := query.Get("msg_signature") - timestamp := query.Get("timestamp") - nonce := query.Get("nonce") - - if msgSignature == "" || timestamp == "" || nonce == "" { - http.Error(w, "Missing parameters", http.StatusBadRequest) - return - } - - // Read request body - body, err := io.ReadAll(r.Body) - if err != nil { - http.Error(w, "Failed to read body", http.StatusBadRequest) - return - } - defer r.Body.Close() - - // Parse XML to get encrypted message - var encryptedMsg struct { - XMLName xml.Name `xml:"xml"` - ToUserName string `xml:"ToUserName"` - Encrypt string `xml:"Encrypt"` - AgentID string `xml:"AgentID"` - } - - if err = xml.Unmarshal(body, &encryptedMsg); err != nil { - logger.ErrorCF("wecom_app", "Failed to parse XML", map[string]any{ - "error": err.Error(), - }) - http.Error(w, "Invalid XML", http.StatusBadRequest) - return - } - - // Verify signature - if !verifySignature(c.config.Token(), msgSignature, timestamp, nonce, encryptedMsg.Encrypt) { - logger.WarnC("wecom_app", "Message signature verification failed") - http.Error(w, "Invalid signature", http.StatusForbidden) - return - } - - // Decrypt message with CorpID verification - // For WeCom App (自建应用), receiveid should be corp_id - decryptedMsg, err := decryptMessageWithVerify(encryptedMsg.Encrypt, c.config.EncodingAESKey(), c.config.CorpID) - if err != nil { - logger.ErrorCF("wecom_app", "Failed to decrypt message", map[string]any{ - "error": err.Error(), - }) - http.Error(w, "Decryption failed", http.StatusInternalServerError) - return - } - - // Parse decrypted XML message - var msg WeComXMLMessage - if err := xml.Unmarshal([]byte(decryptedMsg), &msg); err != nil { - 
logger.ErrorCF("wecom_app", "Failed to parse decrypted message", map[string]any{ - "error": err.Error(), - }) - http.Error(w, "Invalid message format", http.StatusBadRequest) - return - } - - // Process the message with the channel's long-lived context (not the HTTP - // request context, which is canceled as soon as we return the response). - go c.processMessage(c.ctx, msg) - - // Return success response immediately - // WeCom App requires response within configured timeout (default 5 seconds) - w.Write([]byte("success")) -} - -// processMessage processes the received message -func (c *WeComAppChannel) processMessage(ctx context.Context, msg WeComXMLMessage) { - // Skip non-text messages for now (can be extended) - if msg.MsgType != "text" && msg.MsgType != "image" && msg.MsgType != "voice" { - logger.DebugCF("wecom_app", "Skipping non-supported message type", map[string]any{ - "msg_type": msg.MsgType, - }) - return - } - - // Message deduplication: Use msg_id to prevent duplicate processing - // As per WeCom documentation, use msg_id for deduplication - msgID := fmt.Sprintf("%d", msg.MsgId) - if !c.processedMsgs.MarkMessageProcessed(msgID) { - logger.DebugCF("wecom_app", "Skipping duplicate message", map[string]any{ - "msg_id": msgID, - }) - return - } - - senderID := msg.FromUserName - chatID := senderID // WeCom App uses user ID as chat ID for direct messages - - // Build metadata - // WeCom App only supports direct messages (private chat) - peer := bus.Peer{Kind: "direct", ID: senderID} - messageID := fmt.Sprintf("%d", msg.MsgId) - - metadata := map[string]string{ - "msg_type": msg.MsgType, - "msg_id": fmt.Sprintf("%d", msg.MsgId), - "agent_id": fmt.Sprintf("%d", msg.AgentID), - "platform": "wecom_app", - "media_id": msg.MediaId, - "create_time": fmt.Sprintf("%d", msg.CreateTime), - } - - content := msg.Content - - logger.DebugCF("wecom_app", "Received message", map[string]any{ - "sender_id": senderID, - "msg_type": msg.MsgType, - "preview": 
utils.Truncate(content, 50), - }) - - // Build sender info - appSender := bus.SenderInfo{ - Platform: "wecom", - PlatformID: senderID, - CanonicalID: identity.BuildCanonicalID("wecom", senderID), - } - - // Handle the message through the base channel - c.HandleMessage(ctx, peer, messageID, senderID, chatID, content, nil, metadata, appSender) -} - -// tokenRefreshLoop periodically refreshes the access token -func (c *WeComAppChannel) tokenRefreshLoop() { - ticker := time.NewTicker(5 * time.Minute) - defer ticker.Stop() - - for { - select { - case <-c.ctx.Done(): - return - case <-ticker.C: - if err := c.refreshAccessToken(); err != nil { - logger.ErrorCF("wecom_app", "Failed to refresh access token", map[string]any{ - "error": err.Error(), - }) - } - } - } -} - -// refreshAccessToken gets a new access token from WeCom API -func (c *WeComAppChannel) refreshAccessToken() error { - apiURL := fmt.Sprintf("%s/cgi-bin/gettoken?corpid=%s&corpsecret=%s", - wecomAPIBase, url.QueryEscape(c.config.CorpID), url.QueryEscape(c.config.CorpSecret())) - - resp, err := http.Get(apiURL) - if err != nil { - return fmt.Errorf("failed to request access token: %w", err) - } - defer resp.Body.Close() - - body, err := io.ReadAll(resp.Body) - if err != nil { - return fmt.Errorf("failed to read response: %w", err) - } - - var tokenResp WeComAccessTokenResponse - if err := json.Unmarshal(body, &tokenResp); err != nil { - return fmt.Errorf("failed to parse response: %w", err) - } - - if tokenResp.ErrCode != 0 { - return fmt.Errorf("API error: %s (code: %d)", tokenResp.ErrMsg, tokenResp.ErrCode) - } - - c.tokenMu.Lock() - c.accessToken = tokenResp.AccessToken - c.tokenExpiry = time.Now().Add(time.Duration(tokenResp.ExpiresIn-300) * time.Second) // Refresh 5 minutes early - c.tokenMu.Unlock() - - logger.DebugC("wecom_app", "Access token refreshed successfully") - return nil -} - -// getAccessToken returns the current valid access token -func (c *WeComAppChannel) getAccessToken() string { - 
c.tokenMu.RLock() - defer c.tokenMu.RUnlock() - - if time.Now().After(c.tokenExpiry) { - return "" - } - - return c.accessToken -} - -// sendTextMessage sends a text message to a user. -func (c *WeComAppChannel) sendTextMessage(ctx context.Context, accessToken, userID, content string) error { - msg := WeComTextMessage{ - ToUser: userID, - MsgType: "text", - AgentID: c.config.AgentID, - } - msg.Text.Content = content - return c.sendWeComMessage(ctx, accessToken, msg) -} - -// handleHealth handles health check requests -func (c *WeComAppChannel) handleHealth(w http.ResponseWriter, r *http.Request) { - status := map[string]any{ - "status": "ok", - "running": c.IsRunning(), - "has_token": c.getAccessToken() != "", - } - - w.Header().Set("Content-Type", "application/json") - json.NewEncoder(w).Encode(status) -} diff --git a/pkg/channels/wecom/app_test.go b/pkg/channels/wecom/app_test.go deleted file mode 100644 index 502544441..000000000 --- a/pkg/channels/wecom/app_test.go +++ /dev/null @@ -1,1060 +0,0 @@ -package wecom - -import ( - "bytes" - "context" - "crypto/aes" - "crypto/cipher" - "crypto/sha1" - "encoding/base64" - "encoding/binary" - "encoding/json" - "encoding/xml" - "fmt" - "net/http" - "net/http/httptest" - "sort" - "strings" - "testing" - "time" - - "github.com/sipeed/picoclaw/pkg/bus" - "github.com/sipeed/picoclaw/pkg/config" -) - -// generateTestAESKeyApp generates a valid test AES key for WeCom App -func generateTestAESKeyApp() string { - // AES key needs to be 32 bytes (256 bits) for AES-256 - key := make([]byte, 32) - for i := range key { - key[i] = byte(i + 1) - } - // Return base64 encoded key without padding - return base64.StdEncoding.EncodeToString(key)[:43] -} - -// encryptTestMessageApp encrypts a message for testing WeCom App -func encryptTestMessageApp(message, aesKey string) (string, error) { - // Decode AES key - key, err := base64.StdEncoding.DecodeString(aesKey + "=") - if err != nil { - return "", err - } - - // Prepare message: 
random(16) + msg_len(4) + msg + corp_id - random := make([]byte, 0, 16) - for i := range 16 { - random = append(random, byte(i+1)) - } - - msgBytes := []byte(message) - corpID := []byte("test_corp_id") - - msgLen := uint32(len(msgBytes)) - lenBytes := make([]byte, 4) - binary.BigEndian.PutUint32(lenBytes, msgLen) - - plainText := append(random, lenBytes...) - plainText = append(plainText, msgBytes...) - plainText = append(plainText, corpID...) - - // PKCS7 padding - blockSize := aes.BlockSize - padding := blockSize - len(plainText)%blockSize - padText := bytes.Repeat([]byte{byte(padding)}, padding) - plainText = append(plainText, padText...) - - // Encrypt - block, err := aes.NewCipher(key) - if err != nil { - return "", err - } - - mode := cipher.NewCBCEncrypter(block, key[:aes.BlockSize]) - cipherText := make([]byte, len(plainText)) - mode.CryptBlocks(cipherText, plainText) - - return base64.StdEncoding.EncodeToString(cipherText), nil -} - -// generateSignatureApp generates a signature for testing WeCom App -func generateSignatureApp(token, timestamp, nonce, msgEncrypt string) string { - params := []string{token, timestamp, nonce, msgEncrypt} - sort.Strings(params) - str := strings.Join(params, "") - hash := sha1.Sum([]byte(str)) - return fmt.Sprintf("%x", hash) -} - -func TestNewWeComAppChannel(t *testing.T) { - msgBus := bus.NewMessageBus() - - t.Run("missing corp_id", func(t *testing.T) { - cfg := config.WeComAppConfig{ - CorpID: "", - AgentID: 1000002, - } - cfg.SetCorpSecret("test_secret") - _, err := NewWeComAppChannel(cfg, msgBus) - if err == nil { - t.Error("expected error for missing corp_id, got nil") - } - }) - - t.Run("missing corp_secret", func(t *testing.T) { - cfg := config.WeComAppConfig{ - CorpID: "test_corp_id", - AgentID: 1000002, - } - _, err := NewWeComAppChannel(cfg, msgBus) - if err == nil { - t.Error("expected error for missing corp_secret, got nil") - } - }) - - t.Run("missing agent_id", func(t *testing.T) { - cfg := 
config.WeComAppConfig{ - CorpID: "test_corp_id", - AgentID: 0, - } - cfg.SetCorpSecret("test_secret") - _, err := NewWeComAppChannel(cfg, msgBus) - if err == nil { - t.Error("expected error for missing agent_id, got nil") - } - }) - - t.Run("valid config", func(t *testing.T) { - cfg := config.WeComAppConfig{ - CorpID: "test_corp_id", - AgentID: 1000002, - AllowFrom: []string{"user1", "user2"}, - } - cfg.SetCorpSecret("test_secret") - ch, err := NewWeComAppChannel(cfg, msgBus) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - if ch.Name() != "wecom_app" { - t.Errorf("Name() = %q, want %q", ch.Name(), "wecom_app") - } - if ch.IsRunning() { - t.Error("new channel should not be running") - } - }) -} - -func TestWeComAppChannelIsAllowed(t *testing.T) { - msgBus := bus.NewMessageBus() - - t.Run("empty allowlist allows all", func(t *testing.T) { - cfg := config.WeComAppConfig{ - CorpID: "test_corp_id", - AgentID: 1000002, - AllowFrom: []string{}, - } - cfg.SetCorpSecret("test_secret") - ch, _ := NewWeComAppChannel(cfg, msgBus) - if !ch.IsAllowed("any_user") { - t.Error("empty allowlist should allow all users") - } - }) - - t.Run("allowlist restricts users", func(t *testing.T) { - cfg := config.WeComAppConfig{ - CorpID: "test_corp_id", - AgentID: 1000002, - AllowFrom: []string{"allowed_user"}, - } - cfg.SetCorpSecret("test_secret") - ch, _ := NewWeComAppChannel(cfg, msgBus) - if !ch.IsAllowed("allowed_user") { - t.Error("allowed user should pass allowlist check") - } - if ch.IsAllowed("blocked_user") { - t.Error("non-allowed user should be blocked") - } - }) -} - -func TestWeComAppVerifySignature(t *testing.T) { - msgBus := bus.NewMessageBus() - cfg := config.WeComAppConfig{} - cfg.CorpID = "test_corp_id" - cfg.SetCorpSecret("test_secret") - cfg.AgentID = 1000002 - cfg.SetToken("test_token") - ch, _ := NewWeComAppChannel(cfg, msgBus) - - t.Run("valid signature", func(t *testing.T) { - timestamp := "1234567890" - nonce := "test_nonce" - msgEncrypt := 
"test_message" - expectedSig := generateSignatureApp("test_token", timestamp, nonce, msgEncrypt) - - if !verifySignature(ch.config.Token(), expectedSig, timestamp, nonce, msgEncrypt) { - t.Error("valid signature should pass verification") - } - }) - - t.Run("invalid signature", func(t *testing.T) { - timestamp := "1234567890" - nonce := "test_nonce" - msgEncrypt := "test_message" - - if verifySignature(ch.config.Token(), "invalid_sig", timestamp, nonce, msgEncrypt) { - t.Error("invalid signature should fail verification") - } - }) - - t.Run("empty token rejects verification (fail-closed)", func(t *testing.T) { - cfgEmpty := config.WeComAppConfig{} - cfgEmpty.CorpID = "test_corp_id" - cfgEmpty.SetCorpSecret("test_secret") - cfgEmpty.AgentID = 1000002 - cfgEmpty.SetToken("") - chEmpty, _ := NewWeComAppChannel(cfgEmpty, msgBus) - - if verifySignature(chEmpty.config.Token(), "any_sig", "any_ts", "any_nonce", "any_msg") { - t.Error("empty token should reject verification (fail-closed)") - } - }) -} - -func TestWeComAppDecryptMessage(t *testing.T) { - msgBus := bus.NewMessageBus() - - t.Run("decrypt without AES key", func(t *testing.T) { - cfg := config.WeComAppConfig{} - cfg.CorpID = "test_corp_id" - cfg.SetCorpSecret("test_secret") - cfg.AgentID = 1000002 - cfg.SetEncodingAESKey("") - ch, _ := NewWeComAppChannel(cfg, msgBus) - - // Without AES key, message should be base64 decoded only - plainText := "hello world" - encoded := base64.StdEncoding.EncodeToString([]byte(plainText)) - - result, err := decryptMessage(encoded, ch.config.EncodingAESKey()) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - if result != plainText { - t.Errorf("decryptMessage() = %q, want %q", result, plainText) - } - }) - - t.Run("decrypt with AES key", func(t *testing.T) { - aesKey := generateTestAESKeyApp() - cfg := config.WeComAppConfig{ - CorpID: "test_corp_id", - AgentID: 1000002, - } - cfg.SetCorpSecret("test_secret") - cfg.SetEncodingAESKey(aesKey) - ch, _ := 
NewWeComAppChannel(cfg, msgBus) - - originalMsg := "Hello" - encrypted, err := encryptTestMessageApp(originalMsg, aesKey) - if err != nil { - t.Fatalf("failed to encrypt test message: %v", err) - } - - result, err := decryptMessage(encrypted, ch.config.EncodingAESKey()) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - if result != originalMsg { - t.Errorf("WeComDecryptMessage() = %q, want %q", result, originalMsg) - } - }) - - t.Run("invalid base64", func(t *testing.T) { - cfg := config.WeComAppConfig{ - CorpID: "test_corp_id", - AgentID: 1000002, - } - cfg.SetCorpSecret("test_secret") - cfg.SetEncodingAESKey("") - ch, _ := NewWeComAppChannel(cfg, msgBus) - - _, err := decryptMessage("invalid_base64!!!", ch.config.EncodingAESKey()) - if err == nil { - t.Error("expected error for invalid base64, got nil") - } - }) - - t.Run("invalid AES key", func(t *testing.T) { - cfg := config.WeComAppConfig{} - cfg.CorpID = "test_corp_id" - cfg.SetCorpSecret("test_secret") - cfg.AgentID = 1000002 - cfg.SetEncodingAESKey("invalid_key") - ch, _ := NewWeComAppChannel(cfg, msgBus) - - _, err := decryptMessage(base64.StdEncoding.EncodeToString([]byte("test")), ch.config.EncodingAESKey()) - if err == nil { - t.Error("expected error for invalid AES key, got nil") - } - }) - - t.Run("ciphertext too short", func(t *testing.T) { - aesKey := generateTestAESKeyApp() - cfg := config.WeComAppConfig{} - cfg.CorpID = "test_corp_id" - cfg.SetCorpSecret("test_secret") - cfg.AgentID = 1000002 - cfg.SetEncodingAESKey(aesKey) - ch, _ := NewWeComAppChannel(cfg, msgBus) - - // Encrypt a very short message that results in ciphertext less than block size - shortData := make([]byte, 8) - _, err := decryptMessage(base64.StdEncoding.EncodeToString(shortData), ch.config.EncodingAESKey()) - if err == nil { - t.Error("expected error for short ciphertext, got nil") - } - }) -} - -func TestWeComAppHandleVerification(t *testing.T) { - msgBus := bus.NewMessageBus() - aesKey := 
generateTestAESKeyApp() - cfg := config.WeComAppConfig{} - cfg.CorpID = "test_corp_id" - cfg.SetCorpSecret("test_secret") - cfg.AgentID = 1000002 - cfg.SetToken("test_token") - cfg.SetEncodingAESKey(aesKey) - ch, _ := NewWeComAppChannel(cfg, msgBus) - - t.Run("valid verification request", func(t *testing.T) { - echostr := "test_echostr_123" - encryptedEchostr, _ := encryptTestMessageApp(echostr, aesKey) - timestamp := "1234567890" - nonce := "test_nonce" - signature := generateSignatureApp("test_token", timestamp, nonce, encryptedEchostr) - - req := httptest.NewRequest( - http.MethodGet, - "/webhook/wecom-app?msg_signature="+signature+"×tamp="+timestamp+"&nonce="+nonce+"&echostr="+encryptedEchostr, - nil, - ) - w := httptest.NewRecorder() - - ch.handleVerification(context.Background(), w, req) - - if w.Code != http.StatusOK { - t.Errorf("status code = %d, want %d", w.Code, http.StatusOK) - } - if w.Body.String() != echostr { - t.Errorf("response body = %q, want %q", w.Body.String(), echostr) - } - }) - - t.Run("missing parameters", func(t *testing.T) { - req := httptest.NewRequest(http.MethodGet, "/webhook/wecom-app?msg_signature=sig×tamp=ts", nil) - w := httptest.NewRecorder() - - ch.handleVerification(context.Background(), w, req) - - if w.Code != http.StatusBadRequest { - t.Errorf("status code = %d, want %d", w.Code, http.StatusBadRequest) - } - }) - - t.Run("invalid signature", func(t *testing.T) { - echostr := "test_echostr" - encryptedEchostr, _ := encryptTestMessageApp(echostr, aesKey) - timestamp := "1234567890" - nonce := "test_nonce" - - req := httptest.NewRequest( - http.MethodGet, - "/webhook/wecom-app?msg_signature=invalid_sig×tamp="+timestamp+"&nonce="+nonce+"&echostr="+encryptedEchostr, - nil, - ) - w := httptest.NewRecorder() - - ch.handleVerification(context.Background(), w, req) - - if w.Code != http.StatusForbidden { - t.Errorf("status code = %d, want %d", w.Code, http.StatusForbidden) - } - }) -} - -func TestWeComAppHandleMessageCallback(t 
*testing.T) { - msgBus := bus.NewMessageBus() - aesKey := generateTestAESKeyApp() - cfg := config.WeComAppConfig{} - cfg.CorpID = "test_corp_id" - cfg.SetCorpSecret("test_secret") - cfg.AgentID = 1000002 - cfg.SetToken("test_token") - cfg.SetEncodingAESKey(aesKey) - ch, _ := NewWeComAppChannel(cfg, msgBus) - - t.Run("valid message callback", func(t *testing.T) { - // Create XML message - xmlMsg := WeComXMLMessage{ - ToUserName: "corp_id", - FromUserName: "user123", - CreateTime: 1234567890, - MsgType: "text", - Content: "Hello World", - MsgId: 123456, - AgentID: 1000002, - } - xmlData, _ := xml.Marshal(xmlMsg) - - // Encrypt message - encrypted, _ := encryptTestMessageApp(string(xmlData), aesKey) - - // Create encrypted XML wrapper - encryptedWrapper := struct { - XMLName xml.Name `xml:"xml"` - Encrypt string `xml:"Encrypt"` - }{ - Encrypt: encrypted, - } - wrapperData, _ := xml.Marshal(encryptedWrapper) - - timestamp := "1234567890" - nonce := "test_nonce" - signature := generateSignatureApp("test_token", timestamp, nonce, encrypted) - - req := httptest.NewRequest( - http.MethodPost, - "/webhook/wecom-app?msg_signature="+signature+"×tamp="+timestamp+"&nonce="+nonce, - bytes.NewReader(wrapperData), - ) - w := httptest.NewRecorder() - - ch.handleMessageCallback(context.Background(), w, req) - - if w.Code != http.StatusOK { - t.Errorf("status code = %d, want %d", w.Code, http.StatusOK) - } - if w.Body.String() != "success" { - t.Errorf("response body = %q, want %q", w.Body.String(), "success") - } - }) - - t.Run("missing parameters", func(t *testing.T) { - req := httptest.NewRequest(http.MethodPost, "/webhook/wecom-app?msg_signature=sig", nil) - w := httptest.NewRecorder() - - ch.handleMessageCallback(context.Background(), w, req) - - if w.Code != http.StatusBadRequest { - t.Errorf("status code = %d, want %d", w.Code, http.StatusBadRequest) - } - }) - - t.Run("invalid XML", func(t *testing.T) { - timestamp := "1234567890" - nonce := "test_nonce" - signature := 
generateSignatureApp("test_token", timestamp, nonce, "") - - req := httptest.NewRequest( - http.MethodPost, - "/webhook/wecom-app?msg_signature="+signature+"×tamp="+timestamp+"&nonce="+nonce, - strings.NewReader("invalid xml"), - ) - w := httptest.NewRecorder() - - ch.handleMessageCallback(context.Background(), w, req) - - if w.Code != http.StatusBadRequest { - t.Errorf("status code = %d, want %d", w.Code, http.StatusBadRequest) - } - }) - - t.Run("invalid signature", func(t *testing.T) { - encryptedWrapper := struct { - XMLName xml.Name `xml:"xml"` - Encrypt string `xml:"Encrypt"` - }{ - Encrypt: "encrypted_data", - } - wrapperData, _ := xml.Marshal(encryptedWrapper) - - timestamp := "1234567890" - nonce := "test_nonce" - - req := httptest.NewRequest( - http.MethodPost, - "/webhook/wecom-app?msg_signature=invalid_sig×tamp="+timestamp+"&nonce="+nonce, - bytes.NewReader(wrapperData), - ) - w := httptest.NewRecorder() - - ch.handleMessageCallback(context.Background(), w, req) - - if w.Code != http.StatusForbidden { - t.Errorf("status code = %d, want %d", w.Code, http.StatusForbidden) - } - }) -} - -func TestWeComAppProcessMessage(t *testing.T) { - msgBus := bus.NewMessageBus() - cfg := config.WeComAppConfig{ - CorpID: "test_corp_id", - AgentID: 1000002, - } - cfg.SetCorpSecret("test_secret") - ch, _ := NewWeComAppChannel(cfg, msgBus) - - t.Run("process text message", func(t *testing.T) { - msg := WeComXMLMessage{ - ToUserName: "corp_id", - FromUserName: "user123", - CreateTime: 1234567890, - MsgType: "text", - Content: "Hello World", - MsgId: 123456, - AgentID: 1000002, - } - - // Should not panic - ch.processMessage(context.Background(), msg) - }) - - t.Run("process image message", func(t *testing.T) { - msg := WeComXMLMessage{ - ToUserName: "corp_id", - FromUserName: "user123", - CreateTime: 1234567890, - MsgType: "image", - PicUrl: "https://example.com/image.jpg", - MediaId: "media_123", - MsgId: 123456, - AgentID: 1000002, - } - - // Should not panic - 
ch.processMessage(context.Background(), msg) - }) - - t.Run("process voice message", func(t *testing.T) { - msg := WeComXMLMessage{ - ToUserName: "corp_id", - FromUserName: "user123", - CreateTime: 1234567890, - MsgType: "voice", - MediaId: "media_123", - Format: "amr", - MsgId: 123456, - AgentID: 1000002, - } - - // Should not panic - ch.processMessage(context.Background(), msg) - }) - - t.Run("skip unsupported message type", func(t *testing.T) { - msg := WeComXMLMessage{ - ToUserName: "corp_id", - FromUserName: "user123", - CreateTime: 1234567890, - MsgType: "video", - MsgId: 123456, - AgentID: 1000002, - } - - // Should not panic - ch.processMessage(context.Background(), msg) - }) - - t.Run("process event message", func(t *testing.T) { - msg := WeComXMLMessage{ - ToUserName: "corp_id", - FromUserName: "user123", - CreateTime: 1234567890, - MsgType: "event", - Event: "subscribe", - MsgId: 123456, - AgentID: 1000002, - } - - // Should not panic - ch.processMessage(context.Background(), msg) - }) -} - -func TestWeComAppHandleWebhook(t *testing.T) { - msgBus := bus.NewMessageBus() - cfg := config.WeComAppConfig{} - cfg.CorpID = "test_corp_id" - cfg.SetCorpSecret("test_secret") - cfg.AgentID = 1000002 - cfg.SetToken("test_token") - ch, _ := NewWeComAppChannel(cfg, msgBus) - - t.Run("GET request calls verification", func(t *testing.T) { - echostr := "test_echostr" - encoded := base64.StdEncoding.EncodeToString([]byte(echostr)) - timestamp := "1234567890" - nonce := "test_nonce" - signature := generateSignatureApp("test_token", timestamp, nonce, encoded) - - req := httptest.NewRequest( - http.MethodGet, - "/webhook/wecom-app?msg_signature="+signature+"×tamp="+timestamp+"&nonce="+nonce+"&echostr="+encoded, - nil, - ) - w := httptest.NewRecorder() - - ch.handleWebhook(w, req) - - if w.Code != http.StatusOK { - t.Errorf("status code = %d, want %d", w.Code, http.StatusOK) - } - }) - - t.Run("POST request calls message callback", func(t *testing.T) { - encryptedWrapper := 
struct { - XMLName xml.Name `xml:"xml"` - Encrypt string `xml:"Encrypt"` - }{ - Encrypt: base64.StdEncoding.EncodeToString([]byte("test")), - } - wrapperData, _ := xml.Marshal(encryptedWrapper) - - timestamp := "1234567890" - nonce := "test_nonce" - signature := generateSignatureApp("test_token", timestamp, nonce, encryptedWrapper.Encrypt) - - req := httptest.NewRequest( - http.MethodPost, - "/webhook/wecom-app?msg_signature="+signature+"×tamp="+timestamp+"&nonce="+nonce, - bytes.NewReader(wrapperData), - ) - w := httptest.NewRecorder() - - ch.handleWebhook(w, req) - - // Should not be method not allowed - if w.Code == http.StatusMethodNotAllowed { - t.Error("POST request should not return Method Not Allowed") - } - }) - - t.Run("unsupported method", func(t *testing.T) { - req := httptest.NewRequest(http.MethodPut, "/webhook/wecom-app", nil) - w := httptest.NewRecorder() - - ch.handleWebhook(w, req) - - if w.Code != http.StatusMethodNotAllowed { - t.Errorf("status code = %d, want %d", w.Code, http.StatusMethodNotAllowed) - } - }) -} - -func TestWeComAppHandleHealth(t *testing.T) { - msgBus := bus.NewMessageBus() - cfg := config.WeComAppConfig{ - CorpID: "test_corp_id", - AgentID: 1000002, - } - cfg.SetCorpSecret("test_secret") - ch, _ := NewWeComAppChannel(cfg, msgBus) - - req := httptest.NewRequest(http.MethodGet, "/health/wecom-app", nil) - w := httptest.NewRecorder() - - ch.handleHealth(w, req) - - if w.Code != http.StatusOK { - t.Errorf("status code = %d, want %d", w.Code, http.StatusOK) - } - - contentType := w.Header().Get("Content-Type") - if contentType != "application/json" { - t.Errorf("Content-Type = %q, want %q", contentType, "application/json") - } - - body := w.Body.String() - if !strings.Contains(body, "status") || !strings.Contains(body, "running") || !strings.Contains(body, "has_token") { - t.Errorf("response body should contain status, running, and has_token fields, got: %s", body) - } -} - -func TestWeComAppAccessToken(t *testing.T) { - msgBus := 
bus.NewMessageBus() - cfg := config.WeComAppConfig{ - CorpID: "test_corp_id", - AgentID: 1000002, - } - cfg.SetCorpSecret("test_secret") - ch, _ := NewWeComAppChannel(cfg, msgBus) - - t.Run("get empty access token initially", func(t *testing.T) { - token := ch.getAccessToken() - if token != "" { - t.Errorf("getAccessToken() = %q, want empty string", token) - } - }) - - t.Run("set and get access token", func(t *testing.T) { - ch.tokenMu.Lock() - ch.accessToken = "test_token_123" - ch.tokenExpiry = time.Now().Add(1 * time.Hour) - ch.tokenMu.Unlock() - - token := ch.getAccessToken() - if token != "test_token_123" { - t.Errorf("getAccessToken() = %q, want %q", token, "test_token_123") - } - }) - - t.Run("expired token returns empty", func(t *testing.T) { - ch.tokenMu.Lock() - ch.accessToken = "expired_token" - ch.tokenExpiry = time.Now().Add(-1 * time.Hour) - ch.tokenMu.Unlock() - - token := ch.getAccessToken() - if token != "" { - t.Errorf("getAccessToken() = %q, want empty string for expired token", token) - } - }) -} - -func TestWeComAppMessageStructures(t *testing.T) { - t.Run("WeComTextMessage structure", func(t *testing.T) { - msg := WeComTextMessage{ - ToUser: "user123", - MsgType: "text", - AgentID: 1000002, - } - msg.Text.Content = "Hello World" - - if msg.ToUser != "user123" { - t.Errorf("ToUser = %q, want %q", msg.ToUser, "user123") - } - if msg.MsgType != "text" { - t.Errorf("MsgType = %q, want %q", msg.MsgType, "text") - } - if msg.AgentID != 1000002 { - t.Errorf("AgentID = %d, want %d", msg.AgentID, 1000002) - } - if msg.Text.Content != "Hello World" { - t.Errorf("Text.Content = %q, want %q", msg.Text.Content, "Hello World") - } - - // Test JSON marshaling - jsonData, err := json.Marshal(msg) - if err != nil { - t.Fatalf("failed to marshal JSON: %v", err) - } - - var unmarshaled WeComTextMessage - err = json.Unmarshal(jsonData, &unmarshaled) - if err != nil { - t.Fatalf("failed to unmarshal JSON: %v", err) - } - - if unmarshaled.ToUser != msg.ToUser { - 
t.Errorf("JSON round-trip failed for ToUser") - } - }) - - t.Run("WeComMarkdownMessage structure", func(t *testing.T) { - msg := WeComMarkdownMessage{ - ToUser: "user123", - MsgType: "markdown", - AgentID: 1000002, - } - msg.Markdown.Content = "# Hello\nWorld" - - if msg.Markdown.Content != "# Hello\nWorld" { - t.Errorf("Markdown.Content = %q, want %q", msg.Markdown.Content, "# Hello\nWorld") - } - - // Test JSON marshaling - jsonData, err := json.Marshal(msg) - if err != nil { - t.Fatalf("failed to marshal JSON: %v", err) - } - - if !bytes.Contains(jsonData, []byte("markdown")) { - t.Error("JSON should contain 'markdown' field") - } - }) - - t.Run("WeComImageMessage structure", func(t *testing.T) { - msg := WeComImageMessage{ - ToUser: "user123", - MsgType: "image", - AgentID: 1000002, - } - msg.Image.MediaID = "media_123456" - - if msg.Image.MediaID != "media_123456" { - t.Errorf("Image.MediaID = %q, want %q", msg.Image.MediaID, "media_123456") - } - if msg.ToUser != "user123" { - t.Errorf("ToUser = %q, want %q", msg.ToUser, "user123") - } - if msg.MsgType != "image" { - t.Errorf("MsgType = %q, want %q", msg.MsgType, "image") - } - if msg.AgentID != 1000002 { - t.Errorf("AgentID = %d, want %d", msg.AgentID, 1000002) - } - }) - - t.Run("WeComAccessTokenResponse structure", func(t *testing.T) { - jsonData := `{ - "errcode": 0, - "errmsg": "ok", - "access_token": "test_access_token", - "expires_in": 7200 - }` - - var resp WeComAccessTokenResponse - err := json.Unmarshal([]byte(jsonData), &resp) - if err != nil { - t.Fatalf("failed to unmarshal JSON: %v", err) - } - - if resp.ErrCode != 0 { - t.Errorf("ErrCode = %d, want %d", resp.ErrCode, 0) - } - if resp.ErrMsg != "ok" { - t.Errorf("ErrMsg = %q, want %q", resp.ErrMsg, "ok") - } - if resp.AccessToken != "test_access_token" { - t.Errorf("AccessToken = %q, want %q", resp.AccessToken, "test_access_token") - } - if resp.ExpiresIn != 7200 { - t.Errorf("ExpiresIn = %d, want %d", resp.ExpiresIn, 7200) - } - }) - - 
t.Run("WeComSendMessageResponse structure", func(t *testing.T) { - jsonData := `{ - "errcode": 0, - "errmsg": "ok", - "invaliduser": "", - "invalidparty": "", - "invalidtag": "" - }` - - var resp WeComSendMessageResponse - err := json.Unmarshal([]byte(jsonData), &resp) - if err != nil { - t.Fatalf("failed to unmarshal JSON: %v", err) - } - - if resp.ErrCode != 0 { - t.Errorf("ErrCode = %d, want %d", resp.ErrCode, 0) - } - if resp.ErrMsg != "ok" { - t.Errorf("ErrMsg = %q, want %q", resp.ErrMsg, "ok") - } - }) -} - -func TestWeComAppXMLMessageStructure(t *testing.T) { - xmlData := ` - - - - 1234567890 - - - 1234567890123456 - 1000002 -` - - var msg WeComXMLMessage - err := xml.Unmarshal([]byte(xmlData), &msg) - if err != nil { - t.Fatalf("failed to unmarshal XML: %v", err) - } - - if msg.ToUserName != "corp_id" { - t.Errorf("ToUserName = %q, want %q", msg.ToUserName, "corp_id") - } - if msg.FromUserName != "user123" { - t.Errorf("FromUserName = %q, want %q", msg.FromUserName, "user123") - } - if msg.CreateTime != 1234567890 { - t.Errorf("CreateTime = %d, want %d", msg.CreateTime, 1234567890) - } - if msg.MsgType != "text" { - t.Errorf("MsgType = %q, want %q", msg.MsgType, "text") - } - if msg.Content != "Hello World" { - t.Errorf("Content = %q, want %q", msg.Content, "Hello World") - } - if msg.MsgId != 1234567890123456 { - t.Errorf("MsgId = %d, want %d", msg.MsgId, 1234567890123456) - } - if msg.AgentID != 1000002 { - t.Errorf("AgentID = %d, want %d", msg.AgentID, 1000002) - } -} - -func TestWeComAppXMLMessageImage(t *testing.T) { - xmlData := ` - - - - 1234567890 - - - - 1234567890123456 - 1000002 -` - - var msg WeComXMLMessage - err := xml.Unmarshal([]byte(xmlData), &msg) - if err != nil { - t.Fatalf("failed to unmarshal XML: %v", err) - } - - if msg.MsgType != "image" { - t.Errorf("MsgType = %q, want %q", msg.MsgType, "image") - } - if msg.PicUrl != "https://example.com/image.jpg" { - t.Errorf("PicUrl = %q, want %q", msg.PicUrl, "https://example.com/image.jpg") - 
} - if msg.MediaId != "media_123" { - t.Errorf("MediaId = %q, want %q", msg.MediaId, "media_123") - } -} - -func TestWeComAppXMLMessageVoice(t *testing.T) { - xmlData := ` - - - - 1234567890 - - - - 1234567890123456 - 1000002 -` - - var msg WeComXMLMessage - err := xml.Unmarshal([]byte(xmlData), &msg) - if err != nil { - t.Fatalf("failed to unmarshal XML: %v", err) - } - - if msg.MsgType != "voice" { - t.Errorf("MsgType = %q, want %q", msg.MsgType, "voice") - } - if msg.Format != "amr" { - t.Errorf("Format = %q, want %q", msg.Format, "amr") - } -} - -func TestWeComAppXMLMessageLocation(t *testing.T) { - xmlData := ` - - - - 1234567890 - - 39.9042 - 116.4074 - 16 - - 1234567890123456 - 1000002 -` - - var msg WeComXMLMessage - err := xml.Unmarshal([]byte(xmlData), &msg) - if err != nil { - t.Fatalf("failed to unmarshal XML: %v", err) - } - - if msg.MsgType != "location" { - t.Errorf("MsgType = %q, want %q", msg.MsgType, "location") - } - if msg.LocationX != 39.9042 { - t.Errorf("LocationX = %f, want %f", msg.LocationX, 39.9042) - } - if msg.LocationY != 116.4074 { - t.Errorf("LocationY = %f, want %f", msg.LocationY, 116.4074) - } - if msg.Scale != 16 { - t.Errorf("Scale = %d, want %d", msg.Scale, 16) - } - if msg.Label != "Beijing" { - t.Errorf("Label = %q, want %q", msg.Label, "Beijing") - } -} - -func TestWeComAppXMLMessageLink(t *testing.T) { - xmlData := ` - - - - 1234567890 - - <![CDATA[Link Title]]> - - - 1234567890123456 - 1000002 -` - - var msg WeComXMLMessage - err := xml.Unmarshal([]byte(xmlData), &msg) - if err != nil { - t.Fatalf("failed to unmarshal XML: %v", err) - } - - if msg.MsgType != "link" { - t.Errorf("MsgType = %q, want %q", msg.MsgType, "link") - } - if msg.Title != "Link Title" { - t.Errorf("Title = %q, want %q", msg.Title, "Link Title") - } - if msg.Description != "Link Description" { - t.Errorf("Description = %q, want %q", msg.Description, "Link Description") - } - if msg.Url != "https://example.com" { - t.Errorf("Url = %q, want %q", 
msg.Url, "https://example.com") - } -} - -func TestWeComAppXMLMessageEvent(t *testing.T) { - xmlData := ` - - - - 1234567890 - - - - 1000002 -` - - var msg WeComXMLMessage - err := xml.Unmarshal([]byte(xmlData), &msg) - if err != nil { - t.Fatalf("failed to unmarshal XML: %v", err) - } - - if msg.MsgType != "event" { - t.Errorf("MsgType = %q, want %q", msg.MsgType, "event") - } - if msg.Event != "subscribe" { - t.Errorf("Event = %q, want %q", msg.Event, "subscribe") - } - if msg.EventKey != "event_key_123" { - t.Errorf("EventKey = %q, want %q", msg.EventKey, "event_key_123") - } -} diff --git a/pkg/channels/wecom/bot.go b/pkg/channels/wecom/bot.go deleted file mode 100644 index 22461b768..000000000 --- a/pkg/channels/wecom/bot.go +++ /dev/null @@ -1,499 +0,0 @@ -package wecom - -import ( - "bytes" - "context" - "encoding/json" - "encoding/xml" - "fmt" - "io" - "net/http" - "strings" - "time" - - "github.com/sipeed/picoclaw/pkg/bus" - "github.com/sipeed/picoclaw/pkg/channels" - "github.com/sipeed/picoclaw/pkg/config" - "github.com/sipeed/picoclaw/pkg/identity" - "github.com/sipeed/picoclaw/pkg/logger" - "github.com/sipeed/picoclaw/pkg/utils" -) - -// WeComBotChannel implements the Channel interface for WeCom Bot (企业微信智能机器人) -// Uses webhook callback mode - simpler than WeCom App but only supports passive replies -type WeComBotChannel struct { - *channels.BaseChannel - config config.WeComConfig - client *http.Client - ctx context.Context - cancel context.CancelFunc - processedMsgs *MessageDeduplicator -} - -// WeComBotMessage represents the JSON message structure from WeCom Bot (AIBOT) -type WeComBotMessage struct { - MsgID string `json:"msgid"` - AIBotID string `json:"aibotid"` - ChatID string `json:"chatid"` // Session ID, only present for group chats - ChatType string `json:"chattype"` // "single" for DM, "group" for group chat - From struct { - UserID string `json:"userid"` - } `json:"from"` - ResponseURL string `json:"response_url"` - MsgType string 
`json:"msgtype"` // text, image, voice, file, mixed - Text struct { - Content string `json:"content"` - } `json:"text"` - Image struct { - URL string `json:"url"` - } `json:"image"` - Voice struct { - Content string `json:"content"` // Voice to text content - } `json:"voice"` - File struct { - URL string `json:"url"` - } `json:"file"` - Mixed struct { - MsgItem []struct { - MsgType string `json:"msgtype"` - Text struct { - Content string `json:"content"` - } `json:"text"` - Image struct { - URL string `json:"url"` - } `json:"image"` - } `json:"msg_item"` - } `json:"mixed"` - Quote struct { - MsgType string `json:"msgtype"` - Text struct { - Content string `json:"content"` - } `json:"text"` - } `json:"quote"` -} - -// WeComBotReplyMessage represents the reply message structure -type WeComBotReplyMessage struct { - MsgType string `json:"msgtype"` - Text struct { - Content string `json:"content"` - } `json:"text,omitempty"` -} - -// NewWeComBotChannel creates a new WeCom Bot channel instance -func NewWeComBotChannel(cfg config.WeComConfig, messageBus *bus.MessageBus) (*WeComBotChannel, error) { - if cfg.Token() == "" || cfg.WebhookURL == "" { - return nil, fmt.Errorf("wecom token and webhook_url are required") - } - - base := channels.NewBaseChannel("wecom", cfg, messageBus, cfg.AllowFrom, - channels.WithMaxMessageLength(2048), - channels.WithGroupTrigger(cfg.GroupTrigger), - channels.WithReasoningChannelID(cfg.ReasoningChannelID), - ) - - // Client timeout must be >= the configured ReplyTimeout so the - // per-request context deadline is always the effective limit. 
- clientTimeout := 30 * time.Second - if d := time.Duration(cfg.ReplyTimeout) * time.Second; d > clientTimeout { - clientTimeout = d - } - - ctx, cancel := context.WithCancel(context.Background()) - return &WeComBotChannel{ - BaseChannel: base, - config: cfg, - client: &http.Client{Timeout: clientTimeout}, - ctx: ctx, - cancel: cancel, - processedMsgs: NewMessageDeduplicator(wecomMaxProcessedMessages), - }, nil -} - -// Name returns the channel name -func (c *WeComBotChannel) Name() string { - return "wecom" -} - -// Start initializes the WeCom Bot channel -func (c *WeComBotChannel) Start(ctx context.Context) error { - logger.InfoC("wecom", "Starting WeCom Bot channel...") - - // Cancel the context created in the constructor to avoid a resource leak. - if c.cancel != nil { - c.cancel() - } - c.ctx, c.cancel = context.WithCancel(ctx) - - c.SetRunning(true) - logger.InfoC("wecom", "WeCom Bot channel started") - - return nil -} - -// Stop gracefully stops the WeCom Bot channel -func (c *WeComBotChannel) Stop(ctx context.Context) error { - logger.InfoC("wecom", "Stopping WeCom Bot channel...") - - if c.cancel != nil { - c.cancel() - } - - c.SetRunning(false) - logger.InfoC("wecom", "WeCom Bot channel stopped") - return nil -} - -// Send sends a message to WeCom user via webhook API -// Note: WeCom Bot can only reply within the configured timeout (default 5 seconds) of receiving a message -// For delayed responses, we use the webhook URL -func (c *WeComBotChannel) Send(ctx context.Context, msg bus.OutboundMessage) error { - if !c.IsRunning() { - return channels.ErrNotRunning - } - - logger.DebugCF("wecom", "Sending message via webhook", map[string]any{ - "chat_id": msg.ChatID, - "preview": utils.Truncate(msg.Content, 100), - }) - - return c.sendWebhookReply(ctx, msg.ChatID, msg.Content) -} - -// WebhookPath returns the path for registering on the shared HTTP server. 
-func (c *WeComBotChannel) WebhookPath() string { - if c.config.WebhookPath != "" { - return c.config.WebhookPath - } - return "/webhook/wecom" -} - -// ServeHTTP implements http.Handler for the shared HTTP server. -func (c *WeComBotChannel) ServeHTTP(w http.ResponseWriter, r *http.Request) { - c.handleWebhook(w, r) -} - -// HealthPath returns the health check endpoint path. -func (c *WeComBotChannel) HealthPath() string { - return "/health/wecom" -} - -// HealthHandler handles health check requests. -func (c *WeComBotChannel) HealthHandler(w http.ResponseWriter, r *http.Request) { - c.handleHealth(w, r) -} - -// handleWebhook handles incoming webhook requests from WeCom -func (c *WeComBotChannel) handleWebhook(w http.ResponseWriter, r *http.Request) { - ctx := r.Context() - - if r.Method == http.MethodGet { - // Handle verification request - c.handleVerification(ctx, w, r) - return - } - - if r.Method == http.MethodPost { - // Handle message callback - c.handleMessageCallback(ctx, w, r) - return - } - - http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) -} - -// handleVerification handles the URL verification request from WeCom -func (c *WeComBotChannel) handleVerification(ctx context.Context, w http.ResponseWriter, r *http.Request) { - query := r.URL.Query() - msgSignature := query.Get("msg_signature") - timestamp := query.Get("timestamp") - nonce := query.Get("nonce") - echostr := query.Get("echostr") - - if msgSignature == "" || timestamp == "" || nonce == "" || echostr == "" { - http.Error(w, "Missing parameters", http.StatusBadRequest) - return - } - - // Verify signature - if !verifySignature(c.config.Token(), msgSignature, timestamp, nonce, echostr) { - logger.WarnC("wecom", "Signature verification failed") - http.Error(w, "Invalid signature", http.StatusForbidden) - return - } - - // Decrypt echostr - // For AIBOT (智能机器人), receiveid should be empty string "" - // Reference: https://developer.work.weixin.qq.com/document/path/101033 - 
decryptedEchoStr, err := decryptMessageWithVerify(echostr, c.config.EncodingAESKey(), "") - if err != nil { - logger.ErrorCF("wecom", "Failed to decrypt echostr", map[string]any{ - "error": err.Error(), - }) - http.Error(w, "Decryption failed", http.StatusInternalServerError) - return - } - - // Remove BOM and whitespace as per WeCom documentation - // The response must be plain text without quotes, BOM, or newlines - decryptedEchoStr = strings.TrimSpace(decryptedEchoStr) - decryptedEchoStr = strings.TrimPrefix(decryptedEchoStr, "\xef\xbb\xbf") // Remove UTF-8 BOM - w.Write([]byte(decryptedEchoStr)) -} - -// handleMessageCallback handles incoming messages from WeCom -func (c *WeComBotChannel) handleMessageCallback(ctx context.Context, w http.ResponseWriter, r *http.Request) { - query := r.URL.Query() - msgSignature := query.Get("msg_signature") - timestamp := query.Get("timestamp") - nonce := query.Get("nonce") - - if msgSignature == "" || timestamp == "" || nonce == "" { - http.Error(w, "Missing parameters", http.StatusBadRequest) - return - } - - // Read request body - body, err := io.ReadAll(r.Body) - if err != nil { - http.Error(w, "Failed to read body", http.StatusBadRequest) - return - } - defer r.Body.Close() - - // Parse XML to get encrypted message - var encryptedMsg struct { - XMLName xml.Name `xml:"xml"` - ToUserName string `xml:"ToUserName"` - Encrypt string `xml:"Encrypt"` - AgentID string `xml:"AgentID"` - } - - if err = xml.Unmarshal(body, &encryptedMsg); err != nil { - logger.ErrorCF("wecom", "Failed to parse XML", map[string]any{ - "error": err.Error(), - }) - http.Error(w, "Invalid XML", http.StatusBadRequest) - return - } - - // Verify signature - if !verifySignature(c.config.Token(), msgSignature, timestamp, nonce, encryptedMsg.Encrypt) { - logger.WarnC("wecom", "Message signature verification failed") - http.Error(w, "Invalid signature", http.StatusForbidden) - return - } - - // Decrypt message - // For AIBOT (智能机器人), receiveid should be empty 
string "" - // Reference: https://developer.work.weixin.qq.com/document/path/101033 - decryptedMsg, err := decryptMessageWithVerify(encryptedMsg.Encrypt, c.config.EncodingAESKey(), "") - if err != nil { - logger.ErrorCF("wecom", "Failed to decrypt message", map[string]any{ - "error": err.Error(), - }) - http.Error(w, "Decryption failed", http.StatusInternalServerError) - return - } - - // Parse decrypted JSON message (AIBOT uses JSON format) - var msg WeComBotMessage - if err := json.Unmarshal([]byte(decryptedMsg), &msg); err != nil { - logger.ErrorCF("wecom", "Failed to parse decrypted message", map[string]any{ - "error": err.Error(), - }) - http.Error(w, "Invalid message format", http.StatusBadRequest) - return - } - - // Process the message with the channel's long-lived context (not the HTTP - // request context, which is canceled as soon as we return the response). - go c.processMessage(c.ctx, msg) - - // Return success response immediately - // WeCom Bot requires response within configured timeout (default 5 seconds) - w.Write([]byte("success")) -} - -// processMessage processes the received message -func (c *WeComBotChannel) processMessage(ctx context.Context, msg WeComBotMessage) { - // Skip unsupported message types - if msg.MsgType != "text" && msg.MsgType != "image" && msg.MsgType != "voice" && msg.MsgType != "file" && - msg.MsgType != "mixed" { - logger.DebugCF("wecom", "Skipping non-supported message type", map[string]any{ - "msg_type": msg.MsgType, - }) - return - } - - // Message deduplication: Use msg_id to prevent duplicate processing - msgID := msg.MsgID - if !c.processedMsgs.MarkMessageProcessed(msgID) { - logger.DebugCF("wecom", "Skipping duplicate message", map[string]any{ - "msg_id": msgID, - }) - return - } - - senderID := msg.From.UserID - - // Determine if this is a group chat or direct message - // ChatType: "single" for DM, "group" for group chat - isGroupChat := msg.ChatType == "group" - - var chatID, peerKind, peerID string - if 
isGroupChat { - // Group chat: use ChatID as chatID and peer_id - chatID = msg.ChatID - peerKind = "group" - peerID = msg.ChatID - } else { - // Direct message: use senderID as chatID and peer_id - chatID = senderID - peerKind = "direct" - peerID = senderID - } - - // Extract content based on message type - var content string - switch msg.MsgType { - case "text": - content = msg.Text.Content - case "voice": - content = msg.Voice.Content // Voice to text content - case "mixed": - // For mixed messages, concatenate text items - for _, item := range msg.Mixed.MsgItem { - if item.MsgType == "text" { - content += item.Text.Content - } - } - case "image", "file": - // For image and file, we don't have text content - content = "" - } - - // Build metadata - peer := bus.Peer{Kind: peerKind, ID: peerID} - - // In group chats, apply unified group trigger filtering - if isGroupChat { - respond, cleaned := c.ShouldRespondInGroup(false, content) - if !respond { - return - } - content = cleaned - } - - metadata := map[string]string{ - "msg_type": msg.MsgType, - "msg_id": msg.MsgID, - "platform": "wecom", - "response_url": msg.ResponseURL, - } - if isGroupChat { - metadata["chat_id"] = msg.ChatID - metadata["sender_id"] = senderID - } - - logger.DebugCF("wecom", "Received message", map[string]any{ - "sender_id": senderID, - "msg_type": msg.MsgType, - "peer_kind": peerKind, - "is_group_chat": isGroupChat, - "preview": utils.Truncate(content, 50), - }) - - // Build sender info - sender := bus.SenderInfo{ - Platform: "wecom", - PlatformID: senderID, - CanonicalID: identity.BuildCanonicalID("wecom", senderID), - } - - if !c.IsAllowedSender(sender) { - return - } - - // Handle the message through the base channel - c.HandleMessage(ctx, peer, msg.MsgID, senderID, chatID, content, nil, metadata, sender) -} - -// sendWebhookReply sends a reply using the webhook URL -func (c *WeComBotChannel) sendWebhookReply(ctx context.Context, userID, content string) error { - reply := 
WeComBotReplyMessage{ - MsgType: "text", - } - reply.Text.Content = content - - jsonData, err := json.Marshal(reply) - if err != nil { - return fmt.Errorf("failed to marshal reply: %w", err) - } - - // Use configurable timeout (default 5 seconds) - timeout := c.config.ReplyTimeout - if timeout <= 0 { - timeout = 5 - } - - reqCtx, cancel := context.WithTimeout(ctx, time.Duration(timeout)*time.Second) - defer cancel() - - req, err := http.NewRequestWithContext(reqCtx, http.MethodPost, c.config.WebhookURL, bytes.NewBuffer(jsonData)) - if err != nil { - return fmt.Errorf("failed to create request: %w", err) - } - req.Header.Set("Content-Type", "application/json") - - resp, err := c.client.Do(req) - if err != nil { - return channels.ClassifyNetError(err) - } - defer resp.Body.Close() - - if resp.StatusCode != http.StatusOK { - body, readErr := io.ReadAll(resp.Body) - if readErr != nil { - return channels.ClassifySendError( - resp.StatusCode, - fmt.Errorf("reading webhook error response: %w", readErr), - ) - } - return channels.ClassifySendError( - resp.StatusCode, - fmt.Errorf("webhook API error: %s", string(body)), - ) - } - - body, err := io.ReadAll(resp.Body) - if err != nil { - return fmt.Errorf("failed to read response: %w", err) - } - - // Check response - var result struct { - ErrCode int `json:"errcode"` - ErrMsg string `json:"errmsg"` - } - if err := json.Unmarshal(body, &result); err != nil { - return fmt.Errorf("failed to parse response: %w", err) - } - - if result.ErrCode != 0 { - return fmt.Errorf("webhook API error: %s (code: %d)", result.ErrMsg, result.ErrCode) - } - - return nil -} - -// handleHealth handles health check requests -func (c *WeComBotChannel) handleHealth(w http.ResponseWriter, r *http.Request) { - status := map[string]any{ - "status": "ok", - "running": c.IsRunning(), - } - - w.Header().Set("Content-Type", "application/json") - json.NewEncoder(w).Encode(status) -} diff --git a/pkg/channels/wecom/bot_test.go b/pkg/channels/wecom/bot_test.go 
deleted file mode 100644 index 7b50a86f7..000000000 --- a/pkg/channels/wecom/bot_test.go +++ /dev/null @@ -1,734 +0,0 @@ -package wecom - -import ( - "bytes" - "context" - "crypto/aes" - "crypto/cipher" - "crypto/sha1" - "encoding/base64" - "encoding/binary" - "encoding/json" - "encoding/xml" - "fmt" - "net/http" - "net/http/httptest" - "sort" - "strings" - "testing" - - "github.com/sipeed/picoclaw/pkg/bus" - "github.com/sipeed/picoclaw/pkg/config" -) - -// generateTestAESKey generates a valid test AES key -func generateTestAESKey() string { - // AES key needs to be 32 bytes (256 bits) for AES-256 - key := make([]byte, 32) - for i := range key { - key[i] = byte(i) - } - // Return base64 encoded key without padding - return base64.StdEncoding.EncodeToString(key)[:43] -} - -// encryptTestMessage encrypts a message for testing (AIBOT JSON format) -func encryptTestMessage(message, aesKey string) (string, error) { - // Decode AES key - key, err := base64.StdEncoding.DecodeString(aesKey + "=") - if err != nil { - return "", err - } - - // Prepare message: random(16) + msg_len(4) + msg + receiveid - random := make([]byte, 0, 16) - for i := range 16 { - random = append(random, byte(i)) - } - - msgBytes := []byte(message) - receiveID := []byte("test_aibot_id") - - msgLen := uint32(len(msgBytes)) - lenBytes := make([]byte, 4) - binary.BigEndian.PutUint32(lenBytes, msgLen) - - plainText := append(random, lenBytes...) - plainText = append(plainText, msgBytes...) - plainText = append(plainText, receiveID...) - - // PKCS7 padding - blockSize := aes.BlockSize - padding := blockSize - len(plainText)%blockSize - padText := bytes.Repeat([]byte{byte(padding)}, padding) - plainText = append(plainText, padText...) 
- - // Encrypt - block, err := aes.NewCipher(key) - if err != nil { - return "", err - } - - mode := cipher.NewCBCEncrypter(block, key[:aes.BlockSize]) - cipherText := make([]byte, len(plainText)) - mode.CryptBlocks(cipherText, plainText) - - return base64.StdEncoding.EncodeToString(cipherText), nil -} - -// generateSignature generates a signature for testing -func generateSignature(token, timestamp, nonce, msgEncrypt string) string { - params := []string{token, timestamp, nonce, msgEncrypt} - sort.Strings(params) - str := strings.Join(params, "") - hash := sha1.Sum([]byte(str)) - return fmt.Sprintf("%x", hash) -} - -func TestNewWeComBotChannel(t *testing.T) { - msgBus := bus.NewMessageBus() - - t.Run("missing token", func(t *testing.T) { - cfg := config.WeComConfig{} - cfg.SetToken("") - cfg.WebhookURL = "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test" - _, err := NewWeComBotChannel(cfg, msgBus) - if err == nil { - t.Error("expected error for missing token, got nil") - } - }) - - t.Run("missing webhook_url", func(t *testing.T) { - cfg := config.WeComConfig{} - cfg.SetToken("test_token") - cfg.WebhookURL = "" - _, err := NewWeComBotChannel(cfg, msgBus) - if err == nil { - t.Error("expected error for missing webhook_url, got nil") - } - }) - - t.Run("valid config", func(t *testing.T) { - cfg := config.WeComConfig{} - cfg.SetToken("test_token") - cfg.WebhookURL = "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test" - cfg.AllowFrom = []string{"user1", "user2"} - ch, err := NewWeComBotChannel(cfg, msgBus) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - if ch.Name() != "wecom" { - t.Errorf("Name() = %q, want %q", ch.Name(), "wecom") - } - if ch.IsRunning() { - t.Error("new channel should not be running") - } - }) -} - -func TestWeComBotChannelIsAllowed(t *testing.T) { - msgBus := bus.NewMessageBus() - - t.Run("empty allowlist allows all", func(t *testing.T) { - cfg := config.WeComConfig{} - cfg.SetToken("test_token") - cfg.WebhookURL = 
"https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test" - cfg.AllowFrom = []string{} - ch, _ := NewWeComBotChannel(cfg, msgBus) - if !ch.IsAllowed("any_user") { - t.Error("empty allowlist should allow all users") - } - }) - - t.Run("allowlist restricts users", func(t *testing.T) { - cfg := config.WeComConfig{} - cfg.SetToken("test_token") - cfg.WebhookURL = "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test" - cfg.AllowFrom = []string{"allowed_user"} - ch, _ := NewWeComBotChannel(cfg, msgBus) - if !ch.IsAllowed("allowed_user") { - t.Error("allowed user should pass allowlist check") - } - if ch.IsAllowed("blocked_user") { - t.Error("non-allowed user should be blocked") - } - }) -} - -func TestWeComBotVerifySignature(t *testing.T) { - msgBus := bus.NewMessageBus() - cfg := config.WeComConfig{} - cfg.SetToken("test_token") - cfg.WebhookURL = "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test" - ch, _ := NewWeComBotChannel(cfg, msgBus) - - t.Run("valid signature", func(t *testing.T) { - timestamp := "1234567890" - nonce := "test_nonce" - msgEncrypt := "test_message" - expectedSig := generateSignature("test_token", timestamp, nonce, msgEncrypt) - - if !verifySignature(ch.config.Token(), expectedSig, timestamp, nonce, msgEncrypt) { - t.Error("valid signature should pass verification") - } - }) - - t.Run("invalid signature", func(t *testing.T) { - timestamp := "1234567890" - nonce := "test_nonce" - msgEncrypt := "test_message" - - if verifySignature(ch.config.Token(), "invalid_sig", timestamp, nonce, msgEncrypt) { - t.Error("invalid signature should fail verification") - } - }) - - t.Run("empty token rejects verification (fail-closed)", func(t *testing.T) { - cfgEmpty := config.WeComConfig{} - cfgEmpty.SetToken("") - cfgEmpty.WebhookURL = "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test" - chEmpty := &WeComBotChannel{ - config: cfgEmpty, - } - - if verifySignature(chEmpty.config.Token(), "any_sig", "any_ts", "any_nonce", "any_msg") { - 
t.Error("empty token should reject verification (fail-closed)") - } - }) -} - -func TestWeComBotDecryptMessage(t *testing.T) { - msgBus := bus.NewMessageBus() - - t.Run("decrypt without AES key", func(t *testing.T) { - cfg := config.WeComConfig{} - cfg.SetToken("test_token") - cfg.WebhookURL = "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test" - cfg.SetEncodingAESKey("") - ch, _ := NewWeComBotChannel(cfg, msgBus) - - // Without AES key, message should be base64 decoded only - plainText := "hello world" - encoded := base64.StdEncoding.EncodeToString([]byte(plainText)) - - result, err := decryptMessage(encoded, ch.config.EncodingAESKey()) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - if result != plainText { - t.Errorf("decryptMessage() = %q, want %q", result, plainText) - } - }) - - t.Run("decrypt with AES key", func(t *testing.T) { - aesKey := generateTestAESKey() - cfg := config.WeComConfig{} - cfg.SetToken("test_token") - cfg.WebhookURL = "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test" - cfg.SetEncodingAESKey(aesKey) - ch, _ := NewWeComBotChannel(cfg, msgBus) - - originalMsg := "Hello" - encrypted, err := encryptTestMessage(originalMsg, aesKey) - if err != nil { - t.Fatalf("failed to encrypt test message: %v", err) - } - - result, err := decryptMessage(encrypted, ch.config.EncodingAESKey()) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - if result != originalMsg { - t.Errorf("WeComDecryptMessage() = %q, want %q", result, originalMsg) - } - }) - - t.Run("invalid base64", func(t *testing.T) { - cfg := config.WeComConfig{} - cfg.SetToken("test_token") - cfg.WebhookURL = "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test" - cfg.SetEncodingAESKey("") - ch, _ := NewWeComBotChannel(cfg, msgBus) - - _, err := decryptMessage("invalid_base64!!!", ch.config.EncodingAESKey()) - if err == nil { - t.Error("expected error for invalid base64, got nil") - } - }) - - t.Run("invalid AES key", func(t *testing.T) { - cfg 
:= config.WeComConfig{} - cfg.SetToken("test_token") - cfg.WebhookURL = "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test" - cfg.SetEncodingAESKey("invalid_key") - ch, _ := NewWeComBotChannel(cfg, msgBus) - - _, err := decryptMessage(base64.StdEncoding.EncodeToString([]byte("test")), ch.config.EncodingAESKey()) - if err == nil { - t.Error("expected error for invalid AES key, got nil") - } - }) -} - -func TestWeComBotPKCS7Unpad(t *testing.T) { - tests := []struct { - name string - input []byte - expected []byte - }{ - { - name: "empty input", - input: []byte{}, - expected: []byte{}, - }, - { - name: "valid padding 3 bytes", - input: append([]byte("hello"), bytes.Repeat([]byte{3}, 3)...), - expected: []byte("hello"), - }, - { - name: "valid padding 16 bytes (full block)", - input: append([]byte("123456789012345"), bytes.Repeat([]byte{16}, 16)...), - expected: []byte("123456789012345"), - }, - { - name: "invalid padding larger than data", - input: []byte{20}, - expected: nil, // should return error - }, - { - name: "invalid padding zero", - input: append([]byte("test"), byte(0)), - expected: nil, // should return error - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - result, err := pkcs7Unpad(tt.input) - if tt.expected == nil { - // This case should return an error - if err == nil { - t.Errorf("pkcs7Unpad() expected error for invalid padding, got result: %v", result) - } - return - } - if err != nil { - t.Errorf("pkcs7Unpad() unexpected error: %v", err) - return - } - if !bytes.Equal(result, tt.expected) { - t.Errorf("pkcs7Unpad() = %v, want %v", result, tt.expected) - } - }) - } -} - -func TestWeComBotHandleVerification(t *testing.T) { - msgBus := bus.NewMessageBus() - aesKey := generateTestAESKey() - cfg := config.WeComConfig{} - cfg.SetToken("test_token") - cfg.SetEncodingAESKey(aesKey) - cfg.WebhookURL = "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test" - ch, _ := NewWeComBotChannel(cfg, msgBus) - - t.Run("valid 
verification request", func(t *testing.T) { - echostr := "test_echostr_123" - encryptedEchostr, _ := encryptTestMessage(echostr, aesKey) - timestamp := "1234567890" - nonce := "test_nonce" - signature := generateSignature("test_token", timestamp, nonce, encryptedEchostr) - - req := httptest.NewRequest( - http.MethodGet, - "/webhook/wecom?msg_signature="+signature+"×tamp="+timestamp+"&nonce="+nonce+"&echostr="+encryptedEchostr, - nil, - ) - w := httptest.NewRecorder() - - ch.handleVerification(context.Background(), w, req) - - if w.Code != http.StatusOK { - t.Errorf("status code = %d, want %d", w.Code, http.StatusOK) - } - if w.Body.String() != echostr { - t.Errorf("response body = %q, want %q", w.Body.String(), echostr) - } - }) - - t.Run("missing parameters", func(t *testing.T) { - req := httptest.NewRequest(http.MethodGet, "/webhook/wecom?msg_signature=sig×tamp=ts", nil) - w := httptest.NewRecorder() - - ch.handleVerification(context.Background(), w, req) - - if w.Code != http.StatusBadRequest { - t.Errorf("status code = %d, want %d", w.Code, http.StatusBadRequest) - } - }) - - t.Run("invalid signature", func(t *testing.T) { - echostr := "test_echostr" - encryptedEchostr, _ := encryptTestMessage(echostr, aesKey) - timestamp := "1234567890" - nonce := "test_nonce" - - req := httptest.NewRequest( - http.MethodGet, - "/webhook/wecom?msg_signature=invalid_sig×tamp="+timestamp+"&nonce="+nonce+"&echostr="+encryptedEchostr, - nil, - ) - w := httptest.NewRecorder() - - ch.handleVerification(context.Background(), w, req) - - if w.Code != http.StatusForbidden { - t.Errorf("status code = %d, want %d", w.Code, http.StatusForbidden) - } - }) -} - -func TestWeComBotHandleMessageCallback(t *testing.T) { - msgBus := bus.NewMessageBus() - aesKey := generateTestAESKey() - cfg := config.WeComConfig{} - cfg.SetToken("test_token") - cfg.SetEncodingAESKey(aesKey) - cfg.WebhookURL = "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test" - ch, _ := NewWeComBotChannel(cfg, msgBus) - 
- runBotMessageCallback := func(t *testing.T, jsonMsg string) *httptest.ResponseRecorder { - t.Helper() - encrypted, _ := encryptTestMessage(jsonMsg, aesKey) - encryptedWrapper := struct { - XMLName xml.Name `xml:"xml"` - Encrypt string `xml:"Encrypt"` - }{ - Encrypt: encrypted, - } - wrapperData, _ := xml.Marshal(encryptedWrapper) - timestamp := "1234567890" - nonce := "test_nonce" - signature := generateSignature("test_token", timestamp, nonce, encrypted) - req := httptest.NewRequest( - http.MethodPost, - "/webhook/wecom?msg_signature="+signature+"×tamp="+timestamp+"&nonce="+nonce, - bytes.NewReader(wrapperData), - ) - w := httptest.NewRecorder() - ch.handleMessageCallback(context.Background(), w, req) - return w - } - - t.Run("valid direct message callback", func(t *testing.T) { - w := runBotMessageCallback(t, `{ - "msgid": "test_msg_id_123", - "aibotid": "test_aibot_id", - "chattype": "single", - "from": {"userid": "user123"}, - "response_url": "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test", - "msgtype": "text", - "text": {"content": "Hello World"} - }`) - if w.Code != http.StatusOK { - t.Errorf("status code = %d, want %d", w.Code, http.StatusOK) - } - if w.Body.String() != "success" { - t.Errorf("response body = %q, want %q", w.Body.String(), "success") - } - }) - - t.Run("valid group message callback", func(t *testing.T) { - w := runBotMessageCallback(t, `{ - "msgid": "test_msg_id_456", - "aibotid": "test_aibot_id", - "chatid": "group_chat_id_123", - "chattype": "group", - "from": {"userid": "user456"}, - "response_url": "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test", - "msgtype": "text", - "text": {"content": "Hello Group"} - }`) - if w.Code != http.StatusOK { - t.Errorf("status code = %d, want %d", w.Code, http.StatusOK) - } - if w.Body.String() != "success" { - t.Errorf("response body = %q, want %q", w.Body.String(), "success") - } - }) - - t.Run("missing parameters", func(t *testing.T) { - req := 
httptest.NewRequest(http.MethodPost, "/webhook/wecom?msg_signature=sig", nil) - w := httptest.NewRecorder() - - ch.handleMessageCallback(context.Background(), w, req) - - if w.Code != http.StatusBadRequest { - t.Errorf("status code = %d, want %d", w.Code, http.StatusBadRequest) - } - }) - - t.Run("invalid XML", func(t *testing.T) { - timestamp := "1234567890" - nonce := "test_nonce" - signature := generateSignature("test_token", timestamp, nonce, "") - - req := httptest.NewRequest( - http.MethodPost, - "/webhook/wecom?msg_signature="+signature+"×tamp="+timestamp+"&nonce="+nonce, - strings.NewReader("invalid xml"), - ) - w := httptest.NewRecorder() - - ch.handleMessageCallback(context.Background(), w, req) - - if w.Code != http.StatusBadRequest { - t.Errorf("status code = %d, want %d", w.Code, http.StatusBadRequest) - } - }) - - t.Run("invalid signature", func(t *testing.T) { - encryptedWrapper := struct { - XMLName xml.Name `xml:"xml"` - Encrypt string `xml:"Encrypt"` - }{ - Encrypt: "encrypted_data", - } - wrapperData, _ := xml.Marshal(encryptedWrapper) - - timestamp := "1234567890" - nonce := "test_nonce" - - req := httptest.NewRequest( - http.MethodPost, - "/webhook/wecom?msg_signature=invalid_sig×tamp="+timestamp+"&nonce="+nonce, - bytes.NewReader(wrapperData), - ) - w := httptest.NewRecorder() - - ch.handleMessageCallback(context.Background(), w, req) - - if w.Code != http.StatusForbidden { - t.Errorf("status code = %d, want %d", w.Code, http.StatusForbidden) - } - }) -} - -func TestWeComBotProcessMessage(t *testing.T) { - msgBus := bus.NewMessageBus() - cfg := config.WeComConfig{} - cfg.SetToken("test_token") - cfg.WebhookURL = "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test" - ch, _ := NewWeComBotChannel(cfg, msgBus) - - t.Run("process direct text message", func(t *testing.T) { - msg := WeComBotMessage{ - MsgID: "test_msg_id_123", - AIBotID: "test_aibot_id", - ChatType: "single", - ResponseURL: 
"https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test", - MsgType: "text", - } - msg.From.UserID = "user123" - msg.Text.Content = "Hello World" - - // Should not panic - ch.processMessage(context.Background(), msg) - }) - - t.Run("process group text message", func(t *testing.T) { - msg := WeComBotMessage{ - MsgID: "test_msg_id_456", - AIBotID: "test_aibot_id", - ChatID: "group_chat_id_123", - ChatType: "group", - ResponseURL: "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test", - MsgType: "text", - } - msg.From.UserID = "user456" - msg.Text.Content = "Hello Group" - - // Should not panic - ch.processMessage(context.Background(), msg) - }) - - t.Run("process voice message", func(t *testing.T) { - msg := WeComBotMessage{ - MsgID: "test_msg_id_789", - AIBotID: "test_aibot_id", - ChatType: "single", - ResponseURL: "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test", - MsgType: "voice", - } - msg.From.UserID = "user123" - msg.Voice.Content = "Voice message text" - - // Should not panic - ch.processMessage(context.Background(), msg) - }) - - t.Run("skip unsupported message type", func(t *testing.T) { - msg := WeComBotMessage{ - MsgID: "test_msg_id_000", - AIBotID: "test_aibot_id", - ChatType: "single", - ResponseURL: "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test", - MsgType: "video", - } - msg.From.UserID = "user123" - - // Should not panic - ch.processMessage(context.Background(), msg) - }) -} - -func TestWeComBotHandleWebhook(t *testing.T) { - msgBus := bus.NewMessageBus() - cfg := config.WeComConfig{} - cfg.SetToken("test_token") - cfg.WebhookURL = "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test" - ch, _ := NewWeComBotChannel(cfg, msgBus) - - t.Run("GET request calls verification", func(t *testing.T) { - echostr := "test_echostr" - encoded := base64.StdEncoding.EncodeToString([]byte(echostr)) - timestamp := "1234567890" - nonce := "test_nonce" - signature := generateSignature("test_token", timestamp, nonce, encoded) - - req 
:= httptest.NewRequest( - http.MethodGet, - "/webhook/wecom?msg_signature="+signature+"×tamp="+timestamp+"&nonce="+nonce+"&echostr="+encoded, - nil, - ) - w := httptest.NewRecorder() - - ch.handleWebhook(w, req) - - if w.Code != http.StatusOK { - t.Errorf("status code = %d, want %d", w.Code, http.StatusOK) - } - }) - - t.Run("POST request calls message callback", func(t *testing.T) { - encryptedWrapper := struct { - XMLName xml.Name `xml:"xml"` - Encrypt string `xml:"Encrypt"` - }{ - Encrypt: base64.StdEncoding.EncodeToString([]byte("test")), - } - wrapperData, _ := xml.Marshal(encryptedWrapper) - - timestamp := "1234567890" - nonce := "test_nonce" - signature := generateSignature("test_token", timestamp, nonce, encryptedWrapper.Encrypt) - - req := httptest.NewRequest( - http.MethodPost, - "/webhook/wecom?msg_signature="+signature+"×tamp="+timestamp+"&nonce="+nonce, - bytes.NewReader(wrapperData), - ) - w := httptest.NewRecorder() - - ch.handleWebhook(w, req) - - // Should not be method not allowed - if w.Code == http.StatusMethodNotAllowed { - t.Error("POST request should not return Method Not Allowed") - } - }) - - t.Run("unsupported method", func(t *testing.T) { - req := httptest.NewRequest(http.MethodPut, "/webhook/wecom", nil) - w := httptest.NewRecorder() - - ch.handleWebhook(w, req) - - if w.Code != http.StatusMethodNotAllowed { - t.Errorf("status code = %d, want %d", w.Code, http.StatusMethodNotAllowed) - } - }) -} - -func TestWeComBotHandleHealth(t *testing.T) { - msgBus := bus.NewMessageBus() - cfg := config.WeComConfig{} - cfg.SetToken("test_token") - cfg.WebhookURL = "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test" - ch, _ := NewWeComBotChannel(cfg, msgBus) - - req := httptest.NewRequest(http.MethodGet, "/health/wecom", nil) - w := httptest.NewRecorder() - - ch.handleHealth(w, req) - - if w.Code != http.StatusOK { - t.Errorf("status code = %d, want %d", w.Code, http.StatusOK) - } - - contentType := w.Header().Get("Content-Type") - if 
contentType != "application/json" { - t.Errorf("Content-Type = %q, want %q", contentType, "application/json") - } - - body := w.Body.String() - if !strings.Contains(body, "status") || !strings.Contains(body, "running") { - t.Errorf("response body should contain status and running fields, got: %s", body) - } -} - -func TestWeComBotReplyMessage(t *testing.T) { - msg := WeComBotReplyMessage{ - MsgType: "text", - } - msg.Text.Content = "Hello World" - - if msg.MsgType != "text" { - t.Errorf("MsgType = %q, want %q", msg.MsgType, "text") - } - if msg.Text.Content != "Hello World" { - t.Errorf("Text.Content = %q, want %q", msg.Text.Content, "Hello World") - } -} - -func TestWeComBotMessageStructure(t *testing.T) { - jsonData := `{ - "msgid": "test_msg_id_123", - "aibotid": "test_aibot_id", - "chatid": "group_chat_id_123", - "chattype": "group", - "from": {"userid": "user123"}, - "response_url": "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test", - "msgtype": "text", - "text": {"content": "Hello World"} - }` - - var msg WeComBotMessage - err := json.Unmarshal([]byte(jsonData), &msg) - if err != nil { - t.Fatalf("failed to unmarshal JSON: %v", err) - } - - if msg.MsgID != "test_msg_id_123" { - t.Errorf("MsgID = %q, want %q", msg.MsgID, "test_msg_id_123") - } - if msg.AIBotID != "test_aibot_id" { - t.Errorf("AIBotID = %q, want %q", msg.AIBotID, "test_aibot_id") - } - if msg.ChatID != "group_chat_id_123" { - t.Errorf("ChatID = %q, want %q", msg.ChatID, "group_chat_id_123") - } - if msg.ChatType != "group" { - t.Errorf("ChatType = %q, want %q", msg.ChatType, "group") - } - if msg.From.UserID != "user123" { - t.Errorf("From.UserID = %q, want %q", msg.From.UserID, "user123") - } - if msg.MsgType != "text" { - t.Errorf("MsgType = %q, want %q", msg.MsgType, "text") - } - if msg.Text.Content != "Hello World" { - t.Errorf("Text.Content = %q, want %q", msg.Text.Content, "Hello World") - } -} diff --git a/pkg/channels/wecom/common.go b/pkg/channels/wecom/common.go deleted 
file mode 100644 index 9a622a2fc..000000000 --- a/pkg/channels/wecom/common.go +++ /dev/null @@ -1,199 +0,0 @@ -package wecom - -import ( - "bytes" - "crypto/aes" - "crypto/cipher" - "crypto/rand" - "crypto/sha1" - "encoding/base64" - "encoding/binary" - "fmt" - "math/big" - "sort" - "strings" -) - -// blockSize is the PKCS7 block size used by WeCom (32) -const blockSize = 32 - -// computeSignature computes the WeCom message signature from the given parameters. -// It sorts [token, timestamp, nonce, encrypt], concatenates them and returns the SHA1 hex digest. -func computeSignature(token, timestamp, nonce, encrypt string) string { - params := []string{token, timestamp, nonce, encrypt} - sort.Strings(params) - str := strings.Join(params, "") - hash := sha1.Sum([]byte(str)) - return fmt.Sprintf("%x", hash) -} - -// verifySignature verifies the message signature for WeCom -// This is a common function used by both WeCom Bot and WeCom App -func verifySignature(token, msgSignature, timestamp, nonce, msgEncrypt string) bool { - if token == "" { - return false - } - return computeSignature(token, timestamp, nonce, msgEncrypt) == msgSignature -} - -// decryptMessage decrypts the encrypted message using AES -// For AIBOT, receiveid should be the aibotid; for other apps, it should be corp_id -func decryptMessage(encryptedMsg, encodingAESKey string) (string, error) { - return decryptMessageWithVerify(encryptedMsg, encodingAESKey, "") -} - -// decryptMessageWithVerify decrypts the encrypted message and optionally verifies receiveid -// receiveid: for AIBOT use aibotid, for WeCom App use corp_id. If empty, skip verification. 
-func decryptMessageWithVerify(encryptedMsg, encodingAESKey, receiveid string) (string, error) { - if encodingAESKey == "" { - // No encryption, return as is (base64 decode) - decoded, err := base64.StdEncoding.DecodeString(encryptedMsg) - if err != nil { - return "", err - } - return string(decoded), nil - } - - aesKey, err := decodeWeComAESKey(encodingAESKey) - if err != nil { - return "", err - } - - cipherText, err := base64.StdEncoding.DecodeString(encryptedMsg) - if err != nil { - return "", fmt.Errorf("failed to decode message: %w", err) - } - - plainText, err := decryptAESCBC(aesKey, cipherText) - if err != nil { - return "", err - } - - return unpackWeComFrame(plainText, receiveid) -} - -// decodeWeComAESKey base64-decodes the 43-character EncodingAESKey (trailing "=" is -// appended automatically) and validates that the result is exactly 32 bytes. -// It is the single place that handles this repeated pattern in both encrypt and decrypt paths. -func decodeWeComAESKey(encodingAESKey string) ([]byte, error) { - aesKey, err := base64.StdEncoding.DecodeString(encodingAESKey + "=") - if err != nil { - return nil, fmt.Errorf("failed to decode AES key: %w", err) - } - if len(aesKey) != 32 { - return nil, fmt.Errorf("invalid AES key length: %d", len(aesKey)) - } - return aesKey, nil -} - -// encryptAESCBC encrypts plaintext using AES-CBC with the given key, mirroring -// decryptAESCBC. IV = aesKey[:aes.BlockSize]. The caller must PKCS7-pad the -// plaintext to a multiple of aes.BlockSize before calling. 
-func encryptAESCBC(aesKey, plaintext []byte) ([]byte, error) { - block, err := aes.NewCipher(aesKey) - if err != nil { - return nil, fmt.Errorf("failed to create cipher: %w", err) - } - iv := aesKey[:aes.BlockSize] - ciphertext := make([]byte, len(plaintext)) - cipher.NewCBCEncrypter(block, iv).CryptBlocks(ciphertext, plaintext) - return ciphertext, nil -} - -// packWeComFrame builds the WeCom wire format: -// -// random(16 ASCII digits) + msg_len(4, big-endian) + msg + receiveid -func packWeComFrame(msg, receiveid string) ([]byte, error) { - randomBytes := make([]byte, 16) - for i := range 16 { - n, err := rand.Int(rand.Reader, big.NewInt(10)) - if err != nil { - return nil, fmt.Errorf("failed to generate random: %w", err) - } - randomBytes[i] = byte('0' + n.Int64()) - } - msgBytes := []byte(msg) - msgLenBytes := make([]byte, 4) - binary.BigEndian.PutUint32(msgLenBytes, uint32(len(msgBytes))) - var buf bytes.Buffer - buf.Write(randomBytes) - buf.Write(msgLenBytes) - buf.Write(msgBytes) - buf.WriteString(receiveid) - return buf.Bytes(), nil -} - -// unpackWeComFrame parses the WeCom wire format produced by packWeComFrame. -// If receiveid is non-empty it verifies the frame's trailing receiveid field. -func unpackWeComFrame(data []byte, receiveid string) (string, error) { - if len(data) < 20 { - return "", fmt.Errorf("decrypted frame too short: %d bytes", len(data)) - } - msgLen := binary.BigEndian.Uint32(data[16:20]) - if int(msgLen) > len(data)-20 { - return "", fmt.Errorf("invalid message length: %d", msgLen) - } - msg := data[20 : 20+msgLen] - if receiveid != "" && len(data) > 20+int(msgLen) { - actualReceiveID := string(data[20+msgLen:]) - if actualReceiveID != receiveid { - return "", fmt.Errorf("receiveid mismatch: expected %s, got %s", receiveid, actualReceiveID) - } - } - return string(msg), nil -} - -// decryptAESCBC decrypts ciphertext using AES-CBC with the given key. -// IV = aesKey[:aes.BlockSize]. 
PKCS7 padding is stripped from the returned plaintext. -func decryptAESCBC(aesKey, ciphertext []byte) ([]byte, error) { - if len(ciphertext) == 0 { - return nil, fmt.Errorf("ciphertext is empty") - } - if len(ciphertext)%aes.BlockSize != 0 { - return nil, fmt.Errorf("ciphertext length %d is not a multiple of block size", len(ciphertext)) - } - block, err := aes.NewCipher(aesKey) - if err != nil { - return nil, fmt.Errorf("failed to create cipher: %w", err) - } - iv := aesKey[:aes.BlockSize] - plaintext := make([]byte, len(ciphertext)) - cipher.NewCBCDecrypter(block, iv).CryptBlocks(plaintext, ciphertext) - plaintext, err = pkcs7Unpad(plaintext) - if err != nil { - return nil, fmt.Errorf("failed to unpad: %w", err) - } - return plaintext, nil -} - -// pkcs7Pad adds PKCS7 padding -func pkcs7Pad(data []byte, blockSize int) []byte { - padding := blockSize - (len(data) % blockSize) - if padding == 0 { - padding = blockSize - } - padText := bytes.Repeat([]byte{byte(padding)}, padding) - return append(data, padText...) 
-} - -// pkcs7Unpad removes PKCS7 padding with validation -func pkcs7Unpad(data []byte) ([]byte, error) { - if len(data) == 0 { - return data, nil - } - padding := int(data[len(data)-1]) - // WeCom uses 32-byte block size for PKCS7 padding - if padding == 0 || padding > blockSize { - return nil, fmt.Errorf("invalid padding size: %d", padding) - } - if padding > len(data) { - return nil, fmt.Errorf("padding size larger than data") - } - // Verify all padding bytes - for i := range padding { - if data[len(data)-1-i] != byte(padding) { - return nil, fmt.Errorf("invalid padding byte at position %d", i) - } - } - return data[:len(data)-padding], nil -} diff --git a/pkg/channels/wecom/dedupe.go b/pkg/channels/wecom/dedupe.go deleted file mode 100644 index 865be668e..000000000 --- a/pkg/channels/wecom/dedupe.go +++ /dev/null @@ -1,54 +0,0 @@ -package wecom - -import "sync" - -const wecomMaxProcessedMessages = 1000 - -// MessageDeduplicator provides thread-safe message deduplication using a circular queue (ring buffer) -// combined with a hash map. This ensures fast O(1) lookups while naturally evicting the oldest -// messages without causing "amnesia cliffs" when the limit is reached. -type MessageDeduplicator struct { - mu sync.Mutex - msgs map[string]bool - ring []string - idx int - max int -} - -// NewMessageDeduplicator creates a new deduplicator with the specified capacity. -func NewMessageDeduplicator(maxEntries int) *MessageDeduplicator { - if maxEntries <= 0 { - maxEntries = wecomMaxProcessedMessages - } - return &MessageDeduplicator{ - msgs: make(map[string]bool, maxEntries), - ring: make([]string, maxEntries), - max: maxEntries, - } -} - -// MarkMessageProcessed marks msgID as processed and returns false for duplicates. -func (d *MessageDeduplicator) MarkMessageProcessed(msgID string) bool { - d.mu.Lock() - defer d.mu.Unlock() - - // 1. Check for duplicate - if d.msgs[msgID] { - return false - } - - // 2. 
Evict the oldest message at our current ring position (if any) - oldestID := d.ring[d.idx] - if oldestID != "" { - delete(d.msgs, oldestID) - } - - // 3. Store the new message - d.msgs[msgID] = true - d.ring[d.idx] = msgID - - // 4. Advance the circle queue index - d.idx = (d.idx + 1) % d.max - - return true -} diff --git a/pkg/channels/wecom/dedupe_test.go b/pkg/channels/wecom/dedupe_test.go deleted file mode 100644 index 10dff4cfe..000000000 --- a/pkg/channels/wecom/dedupe_test.go +++ /dev/null @@ -1,83 +0,0 @@ -package wecom - -import ( - "sync" - "testing" -) - -func TestMessageDeduplicator_DuplicateDetection(t *testing.T) { - d := NewMessageDeduplicator(wecomMaxProcessedMessages) - - if ok := d.MarkMessageProcessed("msg-1"); !ok { - t.Fatalf("first message should be accepted") - } - - if ok := d.MarkMessageProcessed("msg-1"); ok { - t.Fatalf("duplicate message should be rejected") - } -} - -func TestMessageDeduplicator_ConcurrentSameMessage(t *testing.T) { - d := NewMessageDeduplicator(wecomMaxProcessedMessages) - - const goroutines = 64 - var wg sync.WaitGroup - wg.Add(goroutines) - - results := make(chan bool, goroutines) - for i := 0; i < goroutines; i++ { - go func() { - defer wg.Done() - results <- d.MarkMessageProcessed("msg-concurrent") - }() - } - - wg.Wait() - close(results) - - successes := 0 - for ok := range results { - if ok { - successes++ - } - } - - if successes != 1 { - t.Fatalf("expected exactly 1 successful mark, got %d", successes) - } -} - -func TestMessageDeduplicator_CircularQueueEviction(t *testing.T) { - // Create a deduplicator with a very small capacity to test eviction easily. - capacity := 3 - d := NewMessageDeduplicator(capacity) - - // Fill the queue. - d.MarkMessageProcessed("msg-1") - d.MarkMessageProcessed("msg-2") - d.MarkMessageProcessed("msg-3") - - // At this point, the queue is full. msg-1 is the oldest. 
- if len(d.msgs) != 3 { - t.Fatalf("expected map size to be 3, got %d", len(d.msgs)) - } - - // This should evict msg-1 and add msg-4. - if ok := d.MarkMessageProcessed("msg-4"); !ok { - t.Fatalf("msg-4 should be accepted") - } - - if len(d.msgs) != 3 { - t.Fatalf("expected map size to remain at max capacity (3), got %d", len(d.msgs)) - } - - // msg-1 should now be forgotten (evicted). - if ok := d.MarkMessageProcessed("msg-1"); !ok { - t.Fatalf("msg-1 should be accepted again because it was evicted") - } - - // msg-2 should have been evicted when we added msg-1 back. - if ok := d.MarkMessageProcessed("msg-2"); !ok { - t.Fatalf("msg-2 should be accepted again because it was evicted") - } -} diff --git a/pkg/channels/wecom/init.go b/pkg/channels/wecom/init.go index bc5a70fa3..3aad84d42 100644 --- a/pkg/channels/wecom/init.go +++ b/pkg/channels/wecom/init.go @@ -8,12 +8,6 @@ import ( func init() { channels.RegisterFactory("wecom", func(cfg *config.Config, b *bus.MessageBus) (channels.Channel, error) { - return NewWeComBotChannel(cfg.Channels.WeCom, b) - }) - channels.RegisterFactory("wecom_app", func(cfg *config.Config, b *bus.MessageBus) (channels.Channel, error) { - return NewWeComAppChannel(cfg.Channels.WeComApp, b) - }) - channels.RegisterFactory("wecom_aibot", func(cfg *config.Config, b *bus.MessageBus) (channels.Channel, error) { - return NewWeComAIBotChannel(cfg.Channels.WeComAIBot, b) + return NewChannel(cfg.Channels.WeCom, b) }) } diff --git a/pkg/channels/wecom/media.go b/pkg/channels/wecom/media.go new file mode 100644 index 000000000..974a3bf4d --- /dev/null +++ b/pkg/channels/wecom/media.go @@ -0,0 +1,802 @@ +package wecom + +import ( + "context" + "crypto/aes" + "crypto/cipher" + "crypto/md5" + "encoding/base64" + "encoding/hex" + "encoding/json" + "fmt" + "io" + "mime" + "net/http" + "net/url" + "os" + "path/filepath" + "strings" + "time" + + "github.com/h2non/filetype" + + "github.com/sipeed/picoclaw/pkg/bus" + "github.com/sipeed/picoclaw/pkg/media" 
// decodeMediaAESKey decodes the base64 text form of a media AES key.
//
// An empty value means "no key" and yields (nil, nil). Surrounding whitespace
// and any amount of trailing "=" padding are ignored, so both the padded
// (44 character) and unpadded (43 character) spellings of a 32-byte key are
// accepted. The decoded key must be exactly 32 bytes.
//
// Errors:
//   - "decode AES key: ..." when the text is not valid base64;
//   - "invalid AES key length N" when it decodes to anything but 32 bytes.
func decodeMediaAESKey(value string) ([]byte, error) {
	if value == "" {
		return nil, nil
	}

	// Normalise to the unpadded form and decode once. Decoding the padded form
	// first and retrying with an extra "=" reported a base64 syntax error for
	// keys that were well-formed but simply had the wrong length.
	trimmed := strings.TrimRight(strings.TrimSpace(value), "=")
	key, err := base64.RawStdEncoding.DecodeString(trimmed)
	if err != nil {
		return nil, fmt.Errorf("decode AES key: %w", err)
	}
	if len(key) != 32 {
		return nil, fmt.Errorf("invalid AES key length %d", len(key))
	}
	return key, nil
}
// pkcs7Unpad validates and strips PKCS#7-style padding from data.
//
// The final byte gives the padding length, which must be between 1 and 32 and
// no larger than len(data); every padding byte must carry that same value.
// The upper bound is 32 rather than aes.BlockSize.
//
// It returns the unpadded prefix of data (sharing the same backing array), or
// an error for empty input, an out-of-range padding length, or a padding byte
// that does not match.
func pkcs7Unpad(data []byte) ([]byte, error) {
	total := len(data)
	if total == 0 {
		return nil, fmt.Errorf("empty plaintext")
	}

	marker := data[total-1]
	padLen := int(marker)
	if padLen < 1 || padLen > 32 || padLen > total {
		return nil, fmt.Errorf("invalid padding size %d", padLen)
	}

	// Split into payload and padding, then check that the padding is uniform.
	payload, tail := data[:total-padLen], data[total-padLen:]
	for _, b := range tail {
		if b != marker {
			return nil, fmt.Errorf("invalid padding byte")
		}
	}
	return payload, nil
}
|| name == "/" || name == "" { + return "" + } + return name +} + +func candidateWeComFilename(resourceURL, contentDisposition, fallbackName string) string { + if _, params, err := mime.ParseMediaType(contentDisposition); err == nil { + if name := sanitizeWeComFilename(params["filename"]); name != "" { + return name + } + if name := sanitizeWeComFilename(params["filename*"]); name != "" { + return name + } + } + + if parsed, err := url.Parse(resourceURL); err == nil { + query := parsed.Query() + for _, key := range []string{"filename", "file_name", "name"} { + if name := sanitizeWeComFilename(query.Get(key)); name != "" { + return name + } + } + if name := sanitizeWeComFilename(parsed.Path); name != "" { + return name + } + } + + return sanitizeWeComFilename(fallbackName) +} + +func detectWeComFiletype(data []byte) (string, string) { + kind, err := filetype.Match(data) + if err != nil || kind == filetype.Unknown { + return "", "" + } + ext := "" + if kind.Extension != "" { + ext = "." + strings.ToLower(kind.Extension) + } + return normalizeWeComContentType(kind.MIME.Value), ext +} + +func detectWeComMediaMetadata( + data []byte, + fallbackName, fallbackContentType, resourceURL, contentDisposition string, +) (string, string) { + filename := candidateWeComFilename(resourceURL, contentDisposition, fallbackName) + if filename == "" { + filename = "media" + } + + ext := strings.ToLower(filepath.Ext(filename)) + contentType := normalizeWeComContentType(fallbackContentType) + detectedType, detectedExt := detectWeComFiletype(data) + + if ext != "" && isGenericWeComContentType(contentType) { + if byExt := normalizeWeComContentType(mime.TypeByExtension(ext)); byExt != "" { + contentType = byExt + } + } + + if detectedType != "" { + switch { + case contentType == "": + contentType = detectedType + case isGenericWeComContentType(contentType): + contentType = detectedType + case strings.HasPrefix(detectedType, "image/") && !strings.HasPrefix(contentType, "image/"): + 
contentType = detectedType + case strings.HasPrefix(detectedType, "audio/") && !strings.HasPrefix(contentType, "audio/"): + contentType = detectedType + case strings.HasPrefix(detectedType, "video/") && !strings.HasPrefix(contentType, "video/"): + contentType = detectedType + } + } + + if contentType == "" && ext != "" { + contentType = normalizeWeComContentType(mime.TypeByExtension(ext)) + } + if contentType == "" { + contentType = normalizeWeComContentType(http.DetectContentType(data)) + } + + if ext == "" { + ext = detectedExt + } + if ext == "" && contentType != "" { + if exts, err := mime.ExtensionsByType(contentType); err == nil && len(exts) > 0 { + ext = strings.ToLower(exts[0]) + } + } + + if filepath.Ext(filename) == "" && ext != "" { + filename += ext + } + return filename, contentType +} + +func (c *WeComChannel) storeRemoteMedia( + ctx context.Context, + scope, msgID, resourceURL, aesKey, fallbackExt string, +) (string, error) { + store := c.GetMediaStore() + if store == nil { + return "", fmt.Errorf("no media store available") + } + + req, err := http.NewRequestWithContext(ctx, http.MethodGet, resourceURL, nil) + if err != nil { + return "", fmt.Errorf("create request: %w", err) + } + resp, err := c.mediaClient.Do(req) + if err != nil { + return "", fmt.Errorf("download media: %w", err) + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + return "", fmt.Errorf("download media returned HTTP %d", resp.StatusCode) + } + + data, err := io.ReadAll(io.LimitReader(resp.Body, wecomOutboundMediaMaxBytes+1)) + if err != nil { + return "", fmt.Errorf("read media: %w", err) + } + if len(data) > wecomOutboundMediaMaxBytes { + return "", fmt.Errorf("media too large") + } + + if aesKey != "" { + key, keyErr := decodeMediaAESKey(aesKey) + if keyErr != nil { + return "", keyErr + } + data, err = decryptAESCBC(key, data) + if err != nil { + return "", fmt.Errorf("decrypt media: %w", err) + } + } + + filename, contentType := detectWeComMediaMetadata( + 
data, + msgID+fallbackExt, + resp.Header.Get("Content-Type"), + resourceURL, + resp.Header.Get("Content-Disposition"), + ) + ext := filepath.Ext(filename) + if ext == "" { + ext = inferMediaExt(contentType, fallbackExt) + } + mediaDir := filepath.Join(os.TempDir(), "picoclaw_media") + if mkdirErr := os.MkdirAll(mediaDir, 0o700); mkdirErr != nil { + return "", fmt.Errorf("mkdir media dir: %w", mkdirErr) + } + tmpFile, err := os.CreateTemp(mediaDir, msgID+"-*"+ext) + if err != nil { + return "", fmt.Errorf("create temp file: %w", err) + } + tmpPath := tmpFile.Name() + if _, writeErr := tmpFile.Write(data); writeErr != nil { + tmpFile.Close() + _ = os.Remove(tmpPath) + return "", fmt.Errorf("write temp file: %w", writeErr) + } + if closeErr := tmpFile.Close(); closeErr != nil { + _ = os.Remove(tmpPath) + return "", fmt.Errorf("close temp file: %w", closeErr) + } + + ref, err := store.Store(tmpPath, media.MediaMeta{ + Filename: filename, + ContentType: contentType, + Source: "wecom", + CleanupPolicy: media.CleanupPolicyDeleteOnCleanup, + }, scope) + if err != nil { + _ = os.Remove(tmpPath) + return "", err + } + return ref, nil +} + +func detectLocalWeComContentType(localPath, hint string) string { + contentType := normalizeWeComContentType(hint) + if !isGenericWeComContentType(contentType) { + return contentType + } + + if kind, err := filetype.MatchFile(localPath); err == nil && kind != filetype.Unknown { + return normalizeWeComContentType(kind.MIME.Value) + } + + if ext := strings.ToLower(filepath.Ext(localPath)); ext != "" { + if byExt := normalizeWeComContentType(mime.TypeByExtension(ext)); byExt != "" { + return byExt + } + } + + file, err := os.Open(localPath) + if err != nil { + return contentType + } + defer file.Close() + + buf := make([]byte, 512) + n, err := file.Read(buf) + if err != nil && err != io.EOF { + return contentType + } + if n == 0 { + return contentType + } + return normalizeWeComContentType(http.DetectContentType(buf[:n])) +} + +func 
writeWeComTempFile(prefix, filename string, data []byte) (string, error) { + mediaDir := filepath.Join(os.TempDir(), "picoclaw_media") + if err := os.MkdirAll(mediaDir, 0o700); err != nil { + return "", fmt.Errorf("mkdir media dir: %w", err) + } + + ext := strings.ToLower(filepath.Ext(filename)) + tmpFile, err := os.CreateTemp(mediaDir, prefix+"-*"+ext) + if err != nil { + return "", fmt.Errorf("create temp file: %w", err) + } + tmpPath := tmpFile.Name() + + if _, err := tmpFile.Write(data); err != nil { + _ = tmpFile.Close() + _ = os.Remove(tmpPath) + return "", fmt.Errorf("write temp file: %w", err) + } + if err := tmpFile.Close(); err != nil { + _ = os.Remove(tmpPath) + return "", fmt.Errorf("close temp file: %w", err) + } + return tmpPath, nil +} + +func (c *WeComChannel) downloadRemoteMediaToTemp( + ctx context.Context, + resourceURL, fallbackName string, +) (string, string, string, error) { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, resourceURL, nil) + if err != nil { + return "", "", "", fmt.Errorf("create request: %w", err) + } + + resp, err := c.mediaClient.Do(req) + if err != nil { + return "", "", "", fmt.Errorf("download media: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(io.LimitReader(resp.Body, 1024)) + return "", "", "", fmt.Errorf("download media returned HTTP %d: %s", resp.StatusCode, string(body)) + } + + data, err := io.ReadAll(io.LimitReader(resp.Body, wecomOutboundMediaMaxBytes+1)) + if err != nil { + return "", "", "", fmt.Errorf("read media: %w", err) + } + if len(data) > wecomOutboundMediaMaxBytes { + return "", "", "", fmt.Errorf("media too large") + } + + filename, contentType := detectWeComMediaMetadata( + data, + fallbackName, + resp.Header.Get("Content-Type"), + resourceURL, + resp.Header.Get("Content-Disposition"), + ) + tmpPath, err := writeWeComTempFile("wecom-outbound", filename, data) + if err != nil { + return "", "", "", err + } + return tmpPath, 
filename, contentType, nil +} + +func (c *WeComChannel) resolveOutboundPart( + ctx context.Context, + part bus.MediaPart, +) (string, string, string, func(), error) { + cleanup := func() {} + filename := sanitizeWeComFilename(part.Filename) + contentType := normalizeWeComContentType(part.ContentType) + ref := strings.TrimSpace(part.Ref) + + switch { + case ref == "": + return "", filename, contentType, cleanup, nil + + case strings.HasPrefix(ref, "http://") || strings.HasPrefix(ref, "https://"): + localPath, name, ct, err := c.downloadRemoteMediaToTemp(ctx, ref, filename) + if err != nil { + return "", "", "", cleanup, err + } + return localPath, name, ct, func() { _ = os.Remove(localPath) }, nil + + case strings.HasPrefix(ref, "media://"): + store := c.GetMediaStore() + if store == nil { + return "", "", "", cleanup, fmt.Errorf("no media store available") + } + + localPath, meta, err := store.ResolveWithMeta(ref) + if err != nil { + return "", "", "", cleanup, err + } + if filename == "" { + filename = sanitizeWeComFilename(meta.Filename) + } + if contentType == "" { + contentType = normalizeWeComContentType(meta.ContentType) + } + if strings.HasPrefix(localPath, "http://") || strings.HasPrefix(localPath, "https://") { + tmpPath, name, ct, err := c.downloadRemoteMediaToTemp(ctx, localPath, filename) + if err != nil { + return "", "", "", cleanup, err + } + return tmpPath, name, ct, func() { _ = os.Remove(tmpPath) }, nil + } + if _, err := os.Stat(localPath); err != nil { + return "", "", "", cleanup, err + } + if filename == "" { + filename = sanitizeWeComFilename(filepath.Base(localPath)) + } + if contentType == "" { + contentType = detectLocalWeComContentType(localPath, "") + } + return localPath, filename, contentType, cleanup, nil + + case strings.HasPrefix(ref, "file://"): + u, err := url.Parse(ref) + if err != nil { + return "", "", "", cleanup, err + } + localPath := u.Path + if _, err := os.Stat(localPath); err != nil { + return "", "", "", cleanup, err + 
} + if filename == "" { + filename = sanitizeWeComFilename(filepath.Base(localPath)) + } + if contentType == "" { + contentType = detectLocalWeComContentType(localPath, "") + } + return localPath, filename, contentType, cleanup, nil + + default: + if _, err := os.Stat(ref); err != nil { + return "", "", "", cleanup, err + } + if filename == "" { + filename = sanitizeWeComFilename(filepath.Base(ref)) + } + if contentType == "" { + contentType = detectLocalWeComContentType(ref, "") + } + return ref, filename, contentType, cleanup, nil + } +} + +func canWeComSendImage(contentType, ext string, size int64) bool { + if size > wecomOutboundImageMaxBytes { + return false + } + switch normalizeWeComContentType(contentType) { + case "image/jpeg", "image/jpg", "image/png", "image/gif": + return true + } + switch strings.ToLower(ext) { + case ".jpg", ".jpeg", ".png", ".gif": + return true + default: + return false + } +} + +func canWeComSendVoice(contentType, ext string, size int64) bool { + if size > wecomOutboundVoiceMaxBytes { + return false + } + contentType = normalizeWeComContentType(contentType) + return strings.Contains(contentType, "amr") || strings.EqualFold(ext, ".amr") +} + +func canWeComSendVideo(contentType, ext string, size int64) bool { + if size > wecomOutboundVideoMaxBytes { + return false + } + return normalizeWeComContentType(contentType) == "video/mp4" || strings.EqualFold(ext, ".mp4") +} + +func outboundWeComMediaKind(partType, filename, contentType string, size int64) string { + if size < wecomUploadMinBytes { + return "" + } + + partType = strings.ToLower(strings.TrimSpace(partType)) + contentType = normalizeWeComContentType(contentType) + ext := strings.ToLower(filepath.Ext(filename)) + + if partType == "file" { + if size <= wecomOutboundMediaMaxBytes { + return "file" + } + return "" + } + + if (partType == "image" || partType == "") && canWeComSendImage(contentType, ext, size) { + return "image" + } + if (partType == "audio" || partType == "voice" || 
// trimWeComBytes trims surrounding whitespace from value and, when limit is
// positive, shortens the result to at most limit bytes without cutting a rune
// in half.
//
// A non-positive limit disables truncation. When truncation happens the
// string is re-encoded rune by rune, so bytes that are not valid UTF-8 come
// out as U+FFFD (and are budgeted at that replacement's encoded width).
func trimWeComBytes(value string, limit int) string {
	trimmed := strings.TrimSpace(value)
	if limit <= 0 || len(trimmed) <= limit {
		return trimmed
	}

	kept := make([]rune, 0, limit)
	used := 0
	for _, r := range trimmed {
		used += len(string(r))
		if used > limit {
			break
		}
		kept = append(kept, r)
	}
	return string(kept)
}
bus.MediaPart, +) (*wecomOutboundMedia, error) { + _ = ctx + + contentType = detectLocalWeComContentType(localPath, contentType) + filename = ensureWeComOutboundFilename(filename, localPath, contentType) + + data, err := os.ReadFile(localPath) + if err != nil { + return nil, fmt.Errorf("read media file: %w", err) + } + size := int64(len(data)) + kind := outboundWeComMediaKind(part.Type, filename, contentType, size) + if kind == "" { + return nil, fmt.Errorf("unsupported wecom media type or size for %q", filename) + } + + totalChunks := (len(data) + wecomUploadChunkMaxBytes - 1) / wecomUploadChunkMaxBytes + if totalChunks <= 0 || totalChunks > wecomUploadMaxChunks { + return nil, fmt.Errorf("wecom upload requires 1-%d chunks, got %d", wecomUploadMaxChunks, totalChunks) + } + + sum := md5.Sum(data) + initEnv, err := c.sendCommandAck(wecomCommand{ + Cmd: wecomCmdUploadMediaInit, + Headers: wecomHeaders{ReqID: randomID(10)}, + Body: wecomUploadMediaInitBody{ + Type: kind, + Filename: filename, + TotalSize: size, + TotalChunks: totalChunks, + MD5: hex.EncodeToString(sum[:]), + }, + }, wecomUploadTimeout) + if err != nil { + return nil, err + } + initResp, err := decodeWeComEnvelopeBody[wecomUploadMediaInitResponse](initEnv) + if err != nil { + return nil, err + } + if strings.TrimSpace(initResp.UploadID) == "" { + return nil, fmt.Errorf("wecom upload init returned empty upload_id") + } + + for idx, offset := 0, 0; offset < len(data); idx, offset = idx+1, offset+wecomUploadChunkMaxBytes { + end := offset + wecomUploadChunkMaxBytes + if end > len(data) { + end = len(data) + } + sendErr := c.sendCommand(wecomCommand{ + Cmd: wecomCmdUploadMediaChunk, + Headers: wecomHeaders{ReqID: randomID(10)}, + Body: wecomUploadMediaChunkBody{ + UploadID: initResp.UploadID, + ChunkIndex: idx, + Base64Data: base64.StdEncoding.EncodeToString(data[offset:end]), + }, + }, wecomUploadTimeout) + if sendErr != nil { + return nil, sendErr + } + } + + finishEnv, err := 
c.sendCommandAck(wecomCommand{ + Cmd: wecomCmdUploadMediaEnd, + Headers: wecomHeaders{ReqID: randomID(10)}, + Body: wecomUploadMediaFinishBody{ + UploadID: initResp.UploadID, + }, + }, wecomUploadTimeout) + if err != nil { + return nil, err + } + finishResp, err := decodeWeComEnvelopeBody[wecomUploadMediaFinishResponse](finishEnv) + if err != nil { + return nil, err + } + if strings.TrimSpace(finishResp.MediaID) == "" { + return nil, fmt.Errorf("wecom upload finish returned empty media_id") + } + + uploaded := &wecomOutboundMedia{ + MsgType: kind, + MediaID: finishResp.MediaID, + } + if kind == "video" { + video := buildWeComVideoContent(finishResp.MediaID, filename, part.Caption) + uploaded.Title = video.Title + uploaded.Description = video.Description + } + return uploaded, nil +} + +func fallbackWeComMediaText(part bus.MediaPart, kind, filename string) string { + var lines []string + if caption := strings.TrimSpace(part.Caption); caption != "" { + lines = append(lines, caption) + } + + label := kind + if label == "" { + label = "media" + } + if filename != "" { + lines = append(lines, fmt.Sprintf("[%s: %s]", label, filename)) + } else { + lines = append(lines, fmt.Sprintf("[%s attachment]", label)) + } + + ref := strings.TrimSpace(part.Ref) + if strings.HasPrefix(ref, "http://") || strings.HasPrefix(ref, "https://") { + lines = append(lines, ref) + } + + return strings.Join(lines, "\n") +} + +func (c *WeComChannel) resolveMediaRoute(chatID string) (wecomTurn, uint32, bool) { + if turn, ok := c.getTurn(chatID); ok { + if time.Since(turn.CreatedAt) <= wecomStreamMaxDuration { + return turn, turn.ChatType, true + } + c.deleteTurn(chatID) + } + if route, ok := c.routes.Get(chatID); ok { + return wecomTurn{ChatID: route.ChatID, ChatType: route.ChatType}, route.ChatType, false + } + return wecomTurn{ChatID: chatID}, 0, false +} diff --git a/pkg/channels/wecom/media_test.go b/pkg/channels/wecom/media_test.go new file mode 100644 index 000000000..d5307e5d2 --- /dev/null 
+++ b/pkg/channels/wecom/media_test.go @@ -0,0 +1,180 @@ +package wecom + +import ( + "bytes" + "context" + "encoding/base64" + "io" + "net/http" + "strings" + "testing" + + basechannels "github.com/sipeed/picoclaw/pkg/channels" + "github.com/sipeed/picoclaw/pkg/media" +) + +func TestStoreRemoteMedia_DetectsJPEGContentTypeFromBody(t *testing.T) { + t.Parallel() + + const jpegBase64 = "/9j/4AAQSkZJRgABAQAAAQABAAD/2wBDAP//////////////////////////////////////////////////////////////////////////////////////" + + "//////////////////////////////////////////////////////////////////////////////////////////////2wBDAf//////////////////////////////////////////////////////////////////////////////////////" + + "//////////////////////////////////////////////////////////////////////////////////////////////wAARCAABAAEDASIAAhEBAxEB/8QAFQABAQAAAAAAAAAAAAAAAAAAAAb/xAAVEQEBAAAAAAAAAAAAAAAAAAAABf/aAAwDAQACEAMQAAAB6A//xAAVEAEBAAAAAAAAAAAAAAAAAAAAEf/aAAgBAQABBQJf/8QAFBEBAAAAAAAAAAAAAAAAAAAAEP/aAAgBAwEBPwF//8QAFBEBAAAAAAAAAAAAAAAAAAAAEP/aAAgBAgEBPwF//8QAFBABAAAAAAAAAAAAAAAAAAAAEP/aAAgBAQAGPwJf/8QAFBABAAAAAAAAAAAAAAAAAAAAEP/aAAgBAQABPyFf/9k=" + + jpegData := decodeTestBase64(t, jpegBase64) + store := media.NewFileMediaStore() + ch := &WeComChannel{ + BaseChannel: basechannels.NewBaseChannel("wecom", nil, nil, nil), + mediaClient: &http.Client{ + Transport: roundTripFunc(func(*http.Request) (*http.Response, error) { + return &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"application/octet-stream"}}, + Body: io.NopCloser(bytes.NewReader(jpegData)), + }, nil + }), + }, + } + ch.SetMediaStore(store) + + ref, err := ch.storeRemoteMedia(context.Background(), "test-scope", "msg-1", "https://wecom.example/media", "", "") + if err != nil { + t.Fatalf("storeRemoteMedia returned error: %v", err) + } + t.Cleanup(func() { + _ = store.ReleaseAll("test-scope") + }) + + _, meta, err := store.ResolveWithMeta(ref) + if err != nil { + t.Fatalf("resolve media ref: 
%v", err) + } + if meta.ContentType != "image/jpeg" { + t.Fatalf("expected image/jpeg content type, got %q", meta.ContentType) + } + if !strings.HasSuffix(meta.Filename, ".jpg") && !strings.HasSuffix(meta.Filename, ".jpeg") { + t.Fatalf("expected jpeg filename, got %q", meta.Filename) + } +} + +func TestDetectWeComMediaMetadata_UsesFallbackExtensionWhenBodyUnknown(t *testing.T) { + t.Parallel() + + filename, contentType := detectWeComMediaMetadata([]byte("not a real image"), "msg-2.pdf", "", "", "") + if filename != "msg-2.pdf" { + t.Fatalf("expected fallback filename to be preserved, got %q", filename) + } + if contentType != "application/pdf" { + t.Fatalf("expected application/pdf from fallback extension, got %q", contentType) + } +} + +func TestStoreRemoteMedia_PreservesSuffixFromURL(t *testing.T) { + t.Parallel() + + docxLikeData := []byte("PK\x03\x04fake office payload") + store := media.NewFileMediaStore() + ch := &WeComChannel{ + BaseChannel: basechannels.NewBaseChannel("wecom", nil, nil, nil), + mediaClient: &http.Client{ + Transport: roundTripFunc(func(*http.Request) (*http.Response, error) { + return &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"application/octet-stream"}}, + Body: io.NopCloser(bytes.NewReader(docxLikeData)), + }, nil + }), + }, + } + ch.SetMediaStore(store) + + ref, err := ch.storeRemoteMedia( + context.Background(), + "test-scope", + "msg-docx", + "https://wecom.example/media/report.docx?signature=1", + "", + ".bin", + ) + if err != nil { + t.Fatalf("storeRemoteMedia returned error: %v", err) + } + t.Cleanup(func() { + _ = store.ReleaseAll("test-scope") + }) + + localPath, meta, err := store.ResolveWithMeta(ref) + if err != nil { + t.Fatalf("resolve media ref: %v", err) + } + if !strings.HasSuffix(meta.Filename, ".docx") { + t.Fatalf("expected docx filename, got %q", meta.Filename) + } + if !strings.HasSuffix(strings.ToLower(localPath), ".docx") { + t.Fatalf("expected docx temp path, got 
// decodeTestBase64 decodes a standard-alphabet base64 test fixture and fails
// the calling test immediately if the fixture is malformed.
func decodeTestBase64(t *testing.T, value string) []byte {
	t.Helper()

	// The fixture is already fully in memory, so decode it directly instead of
	// streaming it through a decoder and io.ReadAll.
	data, err := base64.StdEncoding.DecodeString(value)
	if err != nil {
		t.Fatalf("decode base64 fixture: %v", err)
	}
	return data
}
const (
	// wecomDefaultWebSocketURL is the built-in WebSocket endpoint.
	// NOTE(review): presumably used when no URL is configured; confirm at the
	// call site.
	wecomDefaultWebSocketURL = "wss://openws.work.weixin.qq.com"

	// Command identifiers. These are the string values placed in
	// wecomCommand.Cmd (serialized as the "cmd" JSON field).
	wecomCmdSubscribe     = "aibot_subscribe"
	wecomCmdPing          = "ping"
	wecomCmdMsgCallback   = "aibot_msg_callback"
	wecomCmdEventCallback = "aibot_event_callback"
	wecomCmdRespondMsg    = "aibot_respond_msg"
	wecomCmdSendMsg       = "aibot_send_msg"

	// Chunked media upload: uploadOutboundMedia sends one init command, one
	// chunk command per slice of data, then one end command.
	wecomCmdUploadMediaInit  = "aibot_upload_media_init"
	wecomCmdUploadMediaChunk = "aibot_upload_media_chunk"
	// Note the naming mismatch: the Go identifier says "End" but the wire
	// value is "..._finish".
	wecomCmdUploadMediaEnd = "aibot_upload_media_finish"
)
// wecomUploadMediaInitBody is the body of the wecomCmdUploadMediaInit command
// that uploadOutboundMedia sends before streaming any chunks.
type wecomUploadMediaInitBody struct {
	// Type is the media kind chosen by outboundWeComMediaKind
	// ("file", "image", "voice" or "video").
	Type string `json:"type"`
	// Filename is the name to upload under.
	Filename string `json:"filename"`
	// TotalSize is the payload length in bytes.
	TotalSize int64 `json:"total_size"`
	// TotalChunks is the number of chunk commands that will follow.
	TotalChunks int `json:"total_chunks"`
	// MD5 is the hex-encoded MD5 digest of the whole payload; omitted from the
	// JSON when empty.
	MD5 string `json:"md5,omitempty"`
}
`json:"content"` + } `json:"text,omitempty"` + Image *struct { + URL string `json:"url"` + AESKey string `json:"aeskey,omitempty"` + } `json:"image,omitempty"` + File *struct { + URL string `json:"url"` + AESKey string `json:"aeskey,omitempty"` + } `json:"file,omitempty"` + } `json:"msg_item"` + } `json:"mixed,omitempty"` + Quote *struct { + MsgType string `json:"msgtype"` + Text *struct { + Content string `json:"content"` + } `json:"text,omitempty"` + } `json:"quote,omitempty"` + Event *struct { + EventType string `json:"eventtype"` + } `json:"event,omitempty"` +} + +func incomingChatID(msg wecomIncomingMessage) string { + if msg.ChatID != "" { + return msg.ChatID + } + return msg.From.UserID +} + +func incomingChatTypeCode(kind string) uint32 { + if kind == "group" { + return 2 + } + return 1 +} diff --git a/pkg/channels/wecom/reqid_store.go b/pkg/channels/wecom/reqid_store.go new file mode 100644 index 000000000..59e64e63d --- /dev/null +++ b/pkg/channels/wecom/reqid_store.go @@ -0,0 +1,113 @@ +package wecom + +import ( + "encoding/json" + "errors" + "os" + "path/filepath" + "sync" + "time" +) + +type wecomRoute struct { + ReqID string `json:"req_id"` + ChatID string `json:"chat_id"` + ChatType uint32 `json:"chat_type"` + ExpiresAt time.Time `json:"expires_at"` +} + +type reqIDStore struct { + mu sync.Mutex + path string + routes map[string]wecomRoute +} + +func newReqIDStore(path string) *reqIDStore { + if path == "" { + path = defaultReqIDStorePath() + } + s := &reqIDStore{ + path: path, + routes: make(map[string]wecomRoute), + } + _ = s.load() + return s +} + +func defaultReqIDStorePath() string { + if home, err := os.UserHomeDir(); err == nil && home != "" { + return filepath.Join(home, ".picoclaw", "wecom", "reqid-store.json") + } + return filepath.Join(os.TempDir(), "picoclaw-wecom-reqid-store.json") +} + +func (s *reqIDStore) Put(chatID, reqID string, chatType uint32, ttl time.Duration) error { + if reqID == "" || chatID == "" { + return nil + } + 
s.mu.Lock() + defer s.mu.Unlock() + s.deleteExpiredLocked(time.Now()) + s.routes[chatID] = wecomRoute{ + ReqID: reqID, + ChatID: chatID, + ChatType: chatType, + ExpiresAt: time.Now().Add(ttl), + } + return s.saveLocked() +} + +func (s *reqIDStore) Get(chatID string) (wecomRoute, bool) { + s.mu.Lock() + defer s.mu.Unlock() + s.deleteExpiredLocked(time.Now()) + route, ok := s.routes[chatID] + return route, ok +} + +func (s *reqIDStore) Delete(chatID string) error { + s.mu.Lock() + defer s.mu.Unlock() + delete(s.routes, chatID) + return s.saveLocked() +} + +func (s *reqIDStore) load() error { + s.mu.Lock() + defer s.mu.Unlock() + + data, err := os.ReadFile(s.path) + if err != nil { + if errors.Is(err, os.ErrNotExist) { + return nil + } + return err + } + + var routes map[string]wecomRoute + if err := json.Unmarshal(data, &routes); err != nil { + return err + } + s.routes = routes + s.deleteExpiredLocked(time.Now()) + return nil +} + +func (s *reqIDStore) deleteExpiredLocked(now time.Time) { + for chatID, route := range s.routes { + if !route.ExpiresAt.IsZero() && now.After(route.ExpiresAt) { + delete(s.routes, chatID) + } + } +} + +func (s *reqIDStore) saveLocked() error { + if err := os.MkdirAll(filepath.Dir(s.path), 0o700); err != nil { + return err + } + data, err := json.MarshalIndent(s.routes, "", " ") + if err != nil { + return err + } + return os.WriteFile(s.path, data, 0o600) +} diff --git a/pkg/channels/wecom/reqid_store_test.go b/pkg/channels/wecom/reqid_store_test.go new file mode 100644 index 000000000..e68e82500 --- /dev/null +++ b/pkg/channels/wecom/reqid_store_test.go @@ -0,0 +1,24 @@ +package wecom + +import ( + "path/filepath" + "testing" + "time" +) + +func TestReqIDStorePersistsRoutes(t *testing.T) { + storePath := filepath.Join(t.TempDir(), "reqids.json") + store := newReqIDStore(storePath) + if err := store.Put("chat-1", "req-1", 2, time.Hour); err != nil { + t.Fatalf("Put() error = %v", err) + } + + reloaded := newReqIDStore(storePath) + route, 
ok := reloaded.Get("chat-1") + if !ok { + t.Fatal("expected persisted route to be loaded") + } + if route.ChatID != "chat-1" || route.ReqID != "req-1" || route.ChatType != 2 { + t.Fatalf("loaded route = %+v", route) + } +} diff --git a/pkg/channels/wecom/wecom.go b/pkg/channels/wecom/wecom.go new file mode 100644 index 000000000..26e971921 --- /dev/null +++ b/pkg/channels/wecom/wecom.go @@ -0,0 +1,970 @@ +package wecom + +import ( + "context" + "crypto/rand" + "encoding/json" + "fmt" + "math/big" + "net/http" + "strings" + "sync" + "time" + + "github.com/gorilla/websocket" + + "github.com/sipeed/picoclaw/pkg/bus" + "github.com/sipeed/picoclaw/pkg/channels" + "github.com/sipeed/picoclaw/pkg/config" + "github.com/sipeed/picoclaw/pkg/identity" + "github.com/sipeed/picoclaw/pkg/logger" +) + +const ( + wecomConnectTimeout = 15 * time.Second + wecomCommandTimeout = 10 * time.Second + wecomUploadTimeout = 30 * time.Second + wecomHeartbeatInterval = 30 * time.Second + wecomStreamMaxDuration = 5*time.Minute + 30*time.Second + wecomStreamMinInterval = 500 * time.Millisecond + wecomRouteTTL = 30 * time.Minute + wecomMediaTimeout = 30 * time.Second + wecomRecentMessageMax = 1000 +) + +type WeComChannel struct { + *channels.BaseChannel + config config.WeComConfig + + ctx context.Context + cancel context.CancelFunc + + conn *websocket.Conn + connMu sync.Mutex + + pendingMu sync.Mutex + pending map[string]chan wecomEnvelope + + turnsMu sync.Mutex + turns map[string][]wecomTurn + + recent *recentMessageSet + routes *reqIDStore + mediaClient *http.Client + commandSend func(wecomCommand, time.Duration) (wecomEnvelope, error) +} + +type wecomTurn struct { + ReqID string + ChatID string + ChatType uint32 + StreamID string + CreatedAt time.Time +} + +type wecomStreamer struct { + channel *WeComChannel + chatID string + turn wecomTurn + + mu sync.Mutex + closed bool + lastSentAt time.Time + content string +} + +type recentMessageSet struct { + mu sync.Mutex + seen map[string]struct{} + 
ring []string + idx int +} + +func newRecentMessageSet(capacity int) *recentMessageSet { + if capacity <= 0 { + capacity = wecomRecentMessageMax + } + return &recentMessageSet{ + seen: make(map[string]struct{}, capacity), + ring: make([]string, capacity), + } +} + +func (s *recentMessageSet) Mark(id string) bool { + if id == "" { + return true + } + s.mu.Lock() + defer s.mu.Unlock() + if _, ok := s.seen[id]; ok { + return false + } + if old := s.ring[s.idx]; old != "" { + delete(s.seen, old) + } + s.ring[s.idx] = id + s.idx = (s.idx + 1) % len(s.ring) + s.seen[id] = struct{}{} + return true +} + +func NewChannel(cfg config.WeComConfig, messageBus *bus.MessageBus) (*WeComChannel, error) { + if cfg.BotID == "" || cfg.Secret() == "" { + return nil, fmt.Errorf("wecom bot_id and secret are required") + } + if cfg.WebSocketURL == "" { + cfg.WebSocketURL = wecomDefaultWebSocketURL + } + + base := channels.NewBaseChannel( + "wecom", + cfg, + messageBus, + cfg.AllowFrom, + channels.WithReasoningChannelID(cfg.ReasoningChannelID), + ) + + ch := &WeComChannel{ + BaseChannel: base, + config: cfg, + pending: make(map[string]chan wecomEnvelope), + turns: make(map[string][]wecomTurn), + recent: newRecentMessageSet(wecomRecentMessageMax), + routes: newReqIDStore(""), + mediaClient: &http.Client{Timeout: wecomMediaTimeout}, + } + ch.SetOwner(ch) + return ch, nil +} + +func (c *WeComChannel) Name() string { return "wecom" } + +func (c *WeComChannel) Start(ctx context.Context) error { + logger.InfoC("wecom", "Starting WeCom channel...") + c.ctx, c.cancel = context.WithCancel(ctx) + c.SetRunning(true) + go c.connectLoop() + return nil +} + +func (c *WeComChannel) Stop(_ context.Context) error { + logger.InfoC("wecom", "Stopping WeCom channel...") + if c.cancel != nil { + c.cancel() + } + c.connMu.Lock() + if c.conn != nil { + _ = c.conn.Close() + c.conn = nil + } + c.connMu.Unlock() + c.clearTurns() + c.SetRunning(false) + return nil +} + +func (c *WeComChannel) BeginStream(_ 
context.Context, chatID string) (channels.Streamer, error) { + if !c.IsRunning() { + return nil, channels.ErrNotRunning + } + + turn, ok := c.getTurn(chatID) + if !ok { + return nil, fmt.Errorf("wecom streaming unavailable: no active turn") + } + if time.Since(turn.CreatedAt) > wecomStreamMaxDuration { + c.consumeTurn(chatID, turn) + return nil, fmt.Errorf("wecom streaming unavailable: turn expired") + } + + return &wecomStreamer{ + channel: c, + chatID: chatID, + turn: turn, + }, nil +} + +func (c *WeComChannel) Send(ctx context.Context, msg bus.OutboundMessage) error { + if !c.IsRunning() { + return channels.ErrNotRunning + } + content := strings.TrimSpace(msg.Content) + if content == "" { + return nil + } + + if turn, ok := c.getTurn(msg.ChatID); ok { + if time.Since(turn.CreatedAt) <= wecomStreamMaxDuration { + if err := c.sendStreamReply(turn, content); err == nil { + c.consumeTurn(msg.ChatID, turn) + return nil + } + } + c.consumeTurn(msg.ChatID, turn) + } + + if route, ok := c.routes.Get(msg.ChatID); ok { + if err := c.sendActivePush(route.ChatID, route.ChatType, content); err != nil { + return err + } + return nil + } + + if err := c.sendActivePush(msg.ChatID, 0, content); err != nil { + return err + } + return nil +} + +func (c *WeComChannel) SendMedia(ctx context.Context, msg bus.OutboundMediaMessage) error { + if !c.IsRunning() { + return channels.ErrNotRunning + } + + route, chatType, hasTurn := c.resolveMediaRoute(msg.ChatID) + chatID := route.ChatID + if chatID == "" { + chatID = msg.ChatID + } + + for _, part := range msg.Parts { + if strings.TrimSpace(part.Ref) == "" { + if caption := strings.TrimSpace(part.Caption); caption != "" { + if err := c.sendActivePush(chatID, chatType, caption); err != nil { + return err + } + } + continue + } + + localPath, filename, contentType, cleanup, err := c.resolveOutboundPart(ctx, part) + if err != nil { + return fmt.Errorf("wecom resolve media %q: %v: %w", part.Ref, err, channels.ErrSendFailed) + } + + func() { + 
if cleanup != nil { + defer cleanup() + } + + uploaded, uploadErr := c.uploadOutboundMedia(ctx, localPath, filename, contentType, part) + if uploadErr != nil { + logger.WarnCF("wecom", "Falling back to placeholder after media upload failure", map[string]any{ + "chat_id": chatID, + "ref": part.Ref, + "filename": filename, + "content_type": contentType, + "error": uploadErr.Error(), + }) + if hasTurn { + if finishErr := c.sendStreamChunk(route, true, ""); finishErr != nil { + err = finishErr + return + } + c.deleteTurn(msg.ChatID) + hasTurn = false + } + err = c.sendActivePush(chatID, chatType, fallbackWeComMediaText(part, "", filename)) + return + } + + if hasTurn { + err = c.sendTurnMedia(route, uploaded) + c.deleteTurn(msg.ChatID) + hasTurn = false + } else { + err = c.sendActiveMedia(chatID, chatType, uploaded) + } + if err != nil { + return + } + if caption := strings.TrimSpace(part.Caption); caption != "" { + err = c.sendActivePush(chatID, chatType, caption) + } + }() + if err != nil { + return err + } + } + + return nil +} + +func (c *WeComChannel) connectLoop() { + backoff := time.Second + for { + select { + case <-c.ctx.Done(): + return + default: + } + + if err := c.runConnection(); err != nil { + logger.WarnCF("wecom", "WeCom connection lost", map[string]any{ + "error": err.Error(), + "backoff": backoff.String(), + }) + select { + case <-time.After(backoff): + case <-c.ctx.Done(): + return + } + if backoff < time.Minute { + backoff *= 2 + if backoff > time.Minute { + backoff = time.Minute + } + } + continue + } + return + } +} + +func (c *WeComChannel) runConnection() error { + dialCtx, cancel := context.WithTimeout(c.ctx, wecomConnectTimeout) + defer cancel() + + conn, resp, err := websocket.DefaultDialer.DialContext(dialCtx, c.config.WebSocketURL, nil) + if resp != nil { + _ = resp.Body.Close() + } + if err != nil { + return fmt.Errorf("%w: %v", channels.ErrTemporary, err) + } + + c.connMu.Lock() + c.conn = conn + c.connMu.Unlock() + defer func() { + 
c.connMu.Lock() + if c.conn == conn { + c.conn = nil + } + c.connMu.Unlock() + _ = conn.Close() + c.clearTurns() + }() + + readErrCh := make(chan error, 1) + go func() { + readErrCh <- c.readLoop(conn) + }() + + if writeErr := c.writeAndWait(conn, wecomCommand{ + Cmd: wecomCmdSubscribe, + Headers: wecomHeaders{ReqID: randomID(10)}, + Body: map[string]string{ + "bot_id": c.config.BotID, + "secret": c.config.Secret(), + }, + }, wecomCommandTimeout); writeErr != nil { + return writeErr + } + + heartbeatDone := make(chan struct{}) + go func() { + defer close(heartbeatDone) + c.heartbeatLoop(conn) + }() + + err = <-readErrCh + _ = conn.Close() + <-heartbeatDone + return err +} + +func (c *WeComChannel) heartbeatLoop(conn *websocket.Conn) { + ticker := time.NewTicker(wecomHeartbeatInterval) + defer ticker.Stop() + for { + select { + case <-ticker.C: + if err := c.writeAndWait(conn, wecomCommand{ + Cmd: wecomCmdPing, + Headers: wecomHeaders{ReqID: randomID(10)}, + }, wecomCommandTimeout); err != nil { + logger.WarnCF("wecom", "Heartbeat failed", map[string]any{"error": err.Error()}) + _ = conn.Close() + return + } + case <-c.ctx.Done(): + return + } + } +} + +func (c *WeComChannel) readLoop(conn *websocket.Conn) error { + for { + _, raw, err := conn.ReadMessage() + if err != nil { + select { + case <-c.ctx.Done(): + return nil + default: + return fmt.Errorf("%w: %v", channels.ErrTemporary, err) + } + } + + var env wecomEnvelope + if err := json.Unmarshal(raw, &env); err != nil { + logger.WarnCF("wecom", "Failed to parse WebSocket message", map[string]any{"error": err.Error()}) + continue + } + + if env.Cmd == "" && env.Headers.ReqID != "" { + c.pendingMu.Lock() + ch, ok := c.pending[env.Headers.ReqID] + if ok { + delete(c.pending, env.Headers.ReqID) + } + c.pendingMu.Unlock() + if ok { + ch <- env + } + continue + } + + go c.handleEnvelope(env) + } +} + +func (c *WeComChannel) handleEnvelope(env wecomEnvelope) { + switch env.Cmd { + case wecomCmdMsgCallback: + 
c.handleMessageCallback(env) + case wecomCmdEventCallback: + c.handleEventCallback(env) + default: + logger.DebugCF("wecom", "Ignoring unsupported WeCom command", map[string]any{"cmd": env.Cmd}) + } +} + +func (c *WeComChannel) handleEventCallback(env wecomEnvelope) { + var msg wecomIncomingMessage + if err := json.Unmarshal(env.Body, &msg); err != nil { + logger.WarnCF("wecom", "Failed to parse WeCom event callback", map[string]any{"error": err.Error()}) + } +} + +func (c *WeComChannel) handleMessageCallback(env wecomEnvelope) { + var msg wecomIncomingMessage + if err := json.Unmarshal(env.Body, &msg); err != nil { + logger.WarnCF("wecom", "Failed to parse WeCom message callback", map[string]any{"error": err.Error()}) + return + } + if !c.recent.Mark(msg.MsgID) { + return + } + + reqID := env.Headers.ReqID + if reqID == "" { + logger.WarnC("wecom", "WeCom message callback missing req_id") + return + } + if msg.Event != nil && msg.Event.EventType != "" { + return + } + + if err := c.dispatchIncoming(reqID, msg); err != nil { + logger.WarnCF("wecom", "Failed to dispatch WeCom message", map[string]any{ + "req_id": reqID, + "error": err.Error(), + }) + _ = c.respondImmediate(reqID, "The WeCom message could not be processed.") + } +} + +func (c *WeComChannel) dispatchIncoming(reqID string, msg wecomIncomingMessage) error { + senderID := msg.From.UserID + if senderID == "" { + senderID = "unknown" + } + actualChatID := incomingChatID(msg) + chatType := incomingChatTypeCode(msg.ChatType) + peerKind := "direct" + if msg.ChatType == "group" { + peerKind = "group" + } + + sender := bus.SenderInfo{ + Platform: "wecom", + PlatformID: senderID, + CanonicalID: identity.BuildCanonicalID("wecom", senderID), + DisplayName: senderID, + } + + var ( + content string + quoteText string + mediaRefs []string + err error + ) + scope := channels.BuildMediaScope("wecom", actualChatID, msg.MsgID) + switch msg.MsgType { + case "text": + if msg.Text != nil { + content = 
strings.TrimSpace(msg.Text.Content) + } + case "voice": + if msg.Voice != nil { + content = strings.TrimSpace(msg.Voice.Content) + } + case "image": + content = "[image]" + mediaRefs, err = c.collectSingleMedia(c.ctx, scope, msg.MsgID, &mediaPayload{ + url: msg.Image.URL, + aesKey: msg.Image.AESKey, + }, "image", ".jpg") + case "file": + content = "[file]" + mediaRefs, err = c.collectSingleMedia(c.ctx, scope, msg.MsgID, &mediaPayload{ + url: msg.File.URL, + aesKey: msg.File.AESKey, + }, "file", ".bin") + case "video": + content = "[video]" + mediaRefs, err = c.collectSingleMedia(c.ctx, scope, msg.MsgID, &mediaPayload{ + url: msg.Video.URL, + aesKey: msg.Video.AESKey, + }, "video", ".mp4") + case "mixed": + content, mediaRefs, err = c.collectMixedMedia(c.ctx, scope, msg) + default: + return c.respondImmediate(reqID, "Unsupported WeCom message type: "+msg.MsgType) + } + if err != nil { + return err + } + if msg.Quote != nil && msg.Quote.Text != nil { + quoteText = strings.TrimSpace(msg.Quote.Text.Content) + if content == "" { + content = quoteText + } + } + if content == "" && len(mediaRefs) == 0 { + return c.respondImmediate(reqID, "The WeCom message did not contain usable content.") + } + + turn := wecomTurn{ + ReqID: reqID, + ChatID: actualChatID, + ChatType: chatType, + StreamID: randomID(10), + CreatedAt: time.Now(), + } + c.queueTurn(actualChatID, turn) + if err := c.routes.Put(actualChatID, reqID, chatType, wecomRouteTTL); err != nil { + logger.WarnCF("wecom", "Failed to persist req_id route", map[string]any{ + "chat_id": actualChatID, + "req_id": reqID, + "error": err.Error(), + }) + } + + opening := "" + if c.config.SendThinkingMessage { + opening = "Processing..." 
+ } + if err := c.sendStreamChunk(turn, false, opening); err != nil { + return err + } + + peer := bus.Peer{Kind: peerKind, ID: actualChatID} + metadata := map[string]string{ + "channel": "wecom", + "req_id": reqID, + "chat_id": actualChatID, + "chat_type": msg.ChatType, + "msg_id": msg.MsgID, + "msg_type": msg.MsgType, + } + if quoteText != "" { + metadata["quote_text"] = quoteText + } + + c.HandleMessage(c.ctx, peer, msg.MsgID, senderID, actualChatID, content, mediaRefs, metadata, sender) + return nil +} + +func (c *WeComChannel) collectSingleMedia( + ctx context.Context, + scope, msgID string, + payload interface { + GetURL() string + GetAESKey() string + }, + label, fallbackExt string, +) ([]string, error) { + if payload == nil || payload.GetURL() == "" { + return nil, fmt.Errorf("%s payload is empty", label) + } + ref, err := c.storeRemoteMedia(ctx, scope, msgID, payload.GetURL(), payload.GetAESKey(), fallbackExt) + if err != nil { + return nil, err + } + return []string{ref}, nil +} + +type mediaPayload struct { + url string + aesKey string +} + +func (p *mediaPayload) GetURL() string { return p.url } +func (p *mediaPayload) GetAESKey() string { return p.aesKey } + +func (c *WeComChannel) collectMixedMedia( + ctx context.Context, + scope string, + msg wecomIncomingMessage, +) (string, []string, error) { + if msg.Mixed == nil { + return "", nil, fmt.Errorf("mixed message is empty") + } + + var textParts []string + var refs []string + for idx, item := range msg.Mixed.MsgItem { + switch item.MsgType { + case "text": + if item.Text != nil && strings.TrimSpace(item.Text.Content) != "" { + textParts = append(textParts, strings.TrimSpace(item.Text.Content)) + } + case "image": + if item.Image != nil && item.Image.URL != "" { + ref, err := c.storeRemoteMedia( + ctx, + scope, + fmt.Sprintf("%s-%d", msg.MsgID, idx), + item.Image.URL, + item.Image.AESKey, + ".jpg", + ) + if err != nil { + return "", nil, err + } + refs = append(refs, ref) + } + case "file": + if 
item.File != nil && item.File.URL != "" { + ref, err := c.storeRemoteMedia( + ctx, + scope, + fmt.Sprintf("%s-%d", msg.MsgID, idx), + item.File.URL, + item.File.AESKey, + ".bin", + ) + if err != nil { + return "", nil, err + } + refs = append(refs, ref) + } + } + } + + content := strings.Join(textParts, "\n") + if content == "" && len(refs) > 0 { + content = "[media]" + } + return content, refs, nil +} + +func (c *WeComChannel) respondImmediate(reqID, content string) error { + turn := wecomTurn{ + ReqID: reqID, + StreamID: randomID(10), + CreatedAt: time.Now(), + } + return c.sendStreamChunk(turn, true, content) +} + +func (c *WeComChannel) sendStreamReply(turn wecomTurn, content string) error { + return c.sendStreamChunk(turn, true, content) +} + +func (c *WeComChannel) sendStreamChunk(turn wecomTurn, finish bool, content string) error { + return c.sendCommand(wecomCommand{ + Cmd: wecomCmdRespondMsg, + Headers: wecomHeaders{ReqID: turn.ReqID}, + Body: wecomRespondMsgBody{ + MsgType: "stream", + Stream: &wecomStreamContent{ + ID: turn.StreamID, + Finish: finish, + Content: content, + }, + }, + }, wecomCommandTimeout) +} + +func (c *WeComChannel) sendTurnMedia(turn wecomTurn, uploaded *wecomOutboundMedia) error { + if uploaded == nil { + return fmt.Errorf("wecom outbound media is nil: %w", channels.ErrSendFailed) + } + if err := c.sendCommand(wecomCommand{ + Cmd: wecomCmdRespondMsg, + Headers: wecomHeaders{ReqID: turn.ReqID}, + Body: uploaded.respondBody(), + }, wecomCommandTimeout); err != nil { + return err + } + return c.sendStreamChunk(turn, true, "") +} + +func (c *WeComChannel) sendActivePush(chatID string, chatType uint32, content string) error { + if strings.TrimSpace(chatID) == "" { + return fmt.Errorf("empty chat ID: %w", channels.ErrSendFailed) + } + return c.sendCommand(wecomCommand{ + Cmd: wecomCmdSendMsg, + Headers: wecomHeaders{ReqID: randomID(10)}, + Body: wecomSendMsgBody{ + ChatID: chatID, + ChatType: chatType, + MsgType: "markdown", + Markdown: 
&wecomMarkdownContent{Content: content}, + }, + }, wecomCommandTimeout) +} + +func (c *WeComChannel) sendActiveMedia(chatID string, chatType uint32, uploaded *wecomOutboundMedia) error { + if strings.TrimSpace(chatID) == "" { + return fmt.Errorf("empty chat ID: %w", channels.ErrSendFailed) + } + if uploaded == nil { + return fmt.Errorf("wecom outbound media is nil: %w", channels.ErrSendFailed) + } + return c.sendCommand(wecomCommand{ + Cmd: wecomCmdSendMsg, + Headers: wecomHeaders{ReqID: randomID(10)}, + Body: uploaded.sendBody(chatID, chatType), + }, wecomCommandTimeout) +} + +func (c *WeComChannel) sendCommand(cmd wecomCommand, timeout time.Duration) error { + _, err := c.sendCommandAck(cmd, timeout) + return err +} + +func (c *WeComChannel) sendCommandAck(cmd wecomCommand, timeout time.Duration) (wecomEnvelope, error) { + if c.commandSend != nil { + return c.commandSend(cmd, timeout) + } + return c.writeCurrentAck(cmd, timeout) +} + +func (c *WeComChannel) writeCurrentAck(cmd wecomCommand, timeout time.Duration) (wecomEnvelope, error) { + c.connMu.Lock() + conn := c.conn + c.connMu.Unlock() + if conn == nil { + return wecomEnvelope{}, fmt.Errorf("wecom websocket not connected: %w", channels.ErrTemporary) + } + return c.writeAndWaitAck(conn, cmd, timeout) +} + +func (c *WeComChannel) writeAndWait(conn *websocket.Conn, cmd wecomCommand, timeout time.Duration) error { + _, err := c.writeAndWaitAck(conn, cmd, timeout) + return err +} + +func (c *WeComChannel) writeAndWaitAck( + conn *websocket.Conn, + cmd wecomCommand, + timeout time.Duration, +) (wecomEnvelope, error) { + if cmd.Headers.ReqID == "" { + cmd.Headers.ReqID = randomID(10) + } + waitCh := make(chan wecomEnvelope, 1) + c.pendingMu.Lock() + c.pending[cmd.Headers.ReqID] = waitCh + c.pendingMu.Unlock() + defer func() { + c.pendingMu.Lock() + delete(c.pending, cmd.Headers.ReqID) + c.pendingMu.Unlock() + }() + + data, err := json.Marshal(cmd) + if err != nil { + return wecomEnvelope{}, fmt.Errorf("%w: %v", 
// randomID returns an n-character identifier drawn from [a-zA-Z0-9]. A
// non-positive n yields 10 characters.
//
// Characters come from crypto/rand. If the random source reports an error, the
// affected character is derived from the clock instead, so the function always
// returns a full-length ID rather than panicking.
func randomID(n int) string {
	const alphabet = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
	if n <= 0 {
		n = 10
	}

	// The bound never changes, so allocate it once instead of per character.
	limit := big.NewInt(int64(len(alphabet)))
	buf := make([]byte, n)
	for i := range buf {
		v, err := rand.Int(rand.Reader, limit)
		if err != nil {
			// rand.Int returns a nil value alongside the error; calling
			// v.Int64() would dereference it. UnixNano is non-negative here,
			// so the modulo stays inside the alphabet.
			buf[i] = alphabet[(time.Now().UnixNano()+int64(i))%int64(len(alphabet))]
			continue
		}
		buf[i] = alphabet[v.Int64()]
	}
	return string(buf)
}
longer active") + } + return nil +} diff --git a/pkg/channels/wecom/wecom_test.go b/pkg/channels/wecom/wecom_test.go new file mode 100644 index 000000000..c7a4adfc0 --- /dev/null +++ b/pkg/channels/wecom/wecom_test.go @@ -0,0 +1,660 @@ +package wecom + +import ( + "context" + "encoding/json" + "errors" + "os" + "path/filepath" + "strings" + "testing" + "time" + + "github.com/sipeed/picoclaw/pkg/bus" + "github.com/sipeed/picoclaw/pkg/config" + "github.com/sipeed/picoclaw/pkg/media" +) + +func TestDispatchIncoming_UsesActualChatIDAndStoresReqIDRoute(t *testing.T) { + t.Parallel() + + messageBus := bus.NewMessageBus() + ch := newTestWeComChannel(t, messageBus) + + var commands []wecomCommand + ch.commandSend = func(cmd wecomCommand, _ time.Duration) (wecomEnvelope, error) { + commands = append(commands, cmd) + return wecomTestAck(nil), nil + } + + msg := wecomIncomingMessage{ + MsgID: "msg-1", + ChatID: "chat-1", + ChatType: "direct", + MsgType: "text", + Text: &struct { + Content string `json:"content"` + }{Content: "hello"}, + } + msg.From.UserID = "user-1" + + if err := ch.dispatchIncoming("req-1", msg); err != nil { + t.Fatalf("dispatchIncoming() error = %v", err) + } + + select { + case inbound := <-messageBus.InboundChan(): + if inbound.ChatID != "chat-1" { + t.Fatalf("inbound ChatID = %q, want chat-1", inbound.ChatID) + } + if inbound.MessageID != "msg-1" { + t.Fatalf("inbound MessageID = %q, want msg-1", inbound.MessageID) + } + if inbound.Peer.ID != "chat-1" { + t.Fatalf("inbound Peer.ID = %q, want chat-1", inbound.Peer.ID) + } + if inbound.Metadata["req_id"] != "req-1" { + t.Fatalf("inbound req_id = %q, want req-1", inbound.Metadata["req_id"]) + } + default: + t.Fatal("expected inbound message to be published") + } + + turn, ok := ch.getTurn("chat-1") + if !ok { + t.Fatal("expected queued turn for chat-1") + } + if turn.ReqID != "req-1" { + t.Fatalf("turn.ReqID = %q, want req-1", turn.ReqID) + } + + route, ok := ch.routes.Get("chat-1") + if !ok { + 
t.Fatal("expected persisted route for chat-1") + } + if route.ReqID != "req-1" || route.ChatType != 1 { + t.Fatalf("route = %+v", route) + } + + if len(commands) != 1 { + t.Fatalf("expected 1 opening command, got %d", len(commands)) + } + if commands[0].Cmd != wecomCmdRespondMsg { + t.Fatalf("opening command = %q, want %q", commands[0].Cmd, wecomCmdRespondMsg) + } + if commands[0].Headers.ReqID != "req-1" { + t.Fatalf("opening req_id = %q, want req-1", commands[0].Headers.ReqID) + } +} + +func TestNewChannel_DoesNotRegisterMessageSplitLimit(t *testing.T) { + t.Parallel() + + ch := newTestWeComChannel(t, bus.NewMessageBus()) + if got := ch.MaxMessageLength(); got != 0 { + t.Fatalf("MaxMessageLength() = %d, want 0", got) + } +} + +func TestBeginStream_UpdateAndFinalize(t *testing.T) { + t.Parallel() + + ch := newTestWeComChannel(t, bus.NewMessageBus()) + ch.SetRunning(true) + ch.queueTurn("chat-1", wecomTurn{ + ReqID: "req-1", + ChatID: "chat-1", + ChatType: 1, + StreamID: "stream-1", + CreatedAt: time.Now(), + }) + + var commands []wecomCommand + ch.commandSend = func(cmd wecomCommand, _ time.Duration) (wecomEnvelope, error) { + commands = append(commands, cmd) + return wecomTestAck(nil), nil + } + + streamer, err := ch.BeginStream(context.Background(), "chat-1") + if err != nil { + t.Fatalf("BeginStream() error = %v", err) + } + if err := streamer.Update(context.Background(), "draft"); err != nil { + t.Fatalf("Update() error = %v", err) + } + if err := streamer.Finalize(context.Background(), "final"); err != nil { + t.Fatalf("Finalize() error = %v", err) + } + + if len(commands) != 2 { + t.Fatalf("expected 2 commands, got %d", len(commands)) + } + for i, wantFinish := range []bool{false, true} { + if commands[i].Cmd != wecomCmdRespondMsg { + t.Fatalf("command[%d].Cmd = %q, want %q", i, commands[i].Cmd, wecomCmdRespondMsg) + } + body, ok := commands[i].Body.(wecomRespondMsgBody) + if !ok { + t.Fatalf("command[%d] body type = %T", i, commands[i].Body) + } + if 
body.Stream == nil { + t.Fatalf("command[%d] missing stream body", i) + } + if body.Stream.ID != "stream-1" { + t.Fatalf("command[%d] stream id = %q, want stream-1", i, body.Stream.ID) + } + if body.Stream.Finish != wantFinish { + t.Fatalf("command[%d] finish = %v, want %v", i, body.Stream.Finish, wantFinish) + } + } + if body := commands[0].Body.(wecomRespondMsgBody); body.Stream.Content != "draft" { + t.Fatalf("update content = %q, want draft", body.Stream.Content) + } + if body := commands[1].Body.(wecomRespondMsgBody); body.Stream.Content != "final" { + t.Fatalf("final content = %q, want final", body.Stream.Content) + } + if _, ok := ch.getTurn("chat-1"); ok { + t.Fatal("expected turn to be consumed after Finalize") + } +} + +func TestSend_StreamFailureFallsBackToActualChatID(t *testing.T) { + t.Parallel() + + ch := newTestWeComChannel(t, bus.NewMessageBus()) + ch.SetRunning(true) + ch.queueTurn("chat-1", wecomTurn{ + ReqID: "req-1", + ChatID: "chat-1", + ChatType: 1, + StreamID: "stream-1", + CreatedAt: time.Now(), + }) + ch.queueTurn("chat-1", wecomTurn{ + ReqID: "req-2", + ChatID: "chat-1", + ChatType: 1, + StreamID: "stream-2", + CreatedAt: time.Now(), + }) + if err := ch.routes.Put("chat-1", "req-2", 1, time.Hour); err != nil { + t.Fatalf("Put() error = %v", err) + } + + var commands []wecomCommand + ch.commandSend = func(cmd wecomCommand, _ time.Duration) (wecomEnvelope, error) { + commands = append(commands, cmd) + if len(commands) == 1 && cmd.Cmd == wecomCmdRespondMsg { + return wecomEnvelope{}, errors.New("stream send failed") + } + return wecomTestAck(nil), nil + } + + if err := ch.Send(context.Background(), bus.OutboundMessage{ + Channel: "wecom", + ChatID: "chat-1", + Content: "hello", + }); err != nil { + t.Fatalf("Send() error = %v", err) + } + + if len(commands) != 2 { + t.Fatalf("expected 2 commands, got %d", len(commands)) + } + if commands[0].Cmd != wecomCmdRespondMsg || commands[0].Headers.ReqID != "req-1" { + t.Fatalf("first command = %+v", 
commands[0]) + } + if commands[1].Cmd != wecomCmdSendMsg { + t.Fatalf("second command = %q, want %q", commands[1].Cmd, wecomCmdSendMsg) + } + body, ok := commands[1].Body.(wecomSendMsgBody) + if !ok { + t.Fatalf("unexpected send body type %T", commands[1].Body) + } + if body.ChatID != "chat-1" { + t.Fatalf("send chatid = %q, want chat-1", body.ChatID) + } + if body.ChatType != 1 { + t.Fatalf("send chat_type = %d, want 1", body.ChatType) + } + + nextTurn, ok := ch.getTurn("chat-1") + if !ok { + t.Fatal("expected second turn to remain queued") + } + if nextTurn.ReqID != "req-2" { + t.Fatalf("next queued req_id = %q, want req-2", nextTurn.ReqID) + } +} + +func TestSend_DoesNotSplitStreamReply(t *testing.T) { + t.Parallel() + + ch := newTestWeComChannel(t, bus.NewMessageBus()) + ch.SetRunning(true) + ch.queueTurn("chat-1", wecomTurn{ + ReqID: "req-1", + ChatID: "chat-1", + ChatType: 1, + StreamID: "stream-1", + CreatedAt: time.Now(), + }) + + var commands []wecomCommand + ch.commandSend = func(cmd wecomCommand, _ time.Duration) (wecomEnvelope, error) { + commands = append(commands, cmd) + return wecomTestAck(nil), nil + } + + content := strings.Repeat("\u4e2d", 30000) + if err := ch.Send(context.Background(), bus.OutboundMessage{ + Channel: "wecom", + ChatID: "chat-1", + Content: content, + }); err != nil { + t.Fatalf("Send() error = %v", err) + } + + if len(commands) != 1 { + t.Fatalf("expected 1 stream command, got %d", len(commands)) + } + body, ok := commands[0].Body.(wecomRespondMsgBody) + if !ok { + t.Fatalf("unexpected body type %T", commands[0].Body) + } + if body.Stream == nil || !body.Stream.Finish { + t.Fatalf("stream body = %+v", body.Stream) + } + if body.Stream.Content != content { + t.Fatalf("stream content length = %d, want %d", len(body.Stream.Content), len(content)) + } +} + +func TestSend_DoesNotSplitActivePush(t *testing.T) { + t.Parallel() + + ch := newTestWeComChannel(t, bus.NewMessageBus()) + ch.SetRunning(true) + + var commands []wecomCommand + 
ch.commandSend = func(cmd wecomCommand, _ time.Duration) (wecomEnvelope, error) { + commands = append(commands, cmd) + return wecomTestAck(nil), nil + } + + content := strings.Repeat("a", 30000) + if err := ch.Send(context.Background(), bus.OutboundMessage{ + Channel: "wecom", + ChatID: "chat-1", + Content: content, + }); err != nil { + t.Fatalf("Send() error = %v", err) + } + + if len(commands) != 1 { + t.Fatalf("expected 1 send command, got %d", len(commands)) + } + if commands[0].Cmd != wecomCmdSendMsg { + t.Fatalf("command = %q, want %q", commands[0].Cmd, wecomCmdSendMsg) + } + body, ok := commands[0].Body.(wecomSendMsgBody) + if !ok { + t.Fatalf("unexpected body type %T", commands[0].Body) + } + if body.Markdown == nil || body.Markdown.Content != content { + t.Fatalf("markdown content length = %d, want %d", len(body.Markdown.Content), len(content)) + } +} + +func TestSendMedia_SendsActiveImage(t *testing.T) { + t.Parallel() + + ch := newTestWeComChannel(t, bus.NewMessageBus()) + ch.SetRunning(true) + + store := media.NewFileMediaStore() + ch.SetMediaStore(store) + + imageData := wecomTestJPEGData(t) + imagePath := filepath.Join(t.TempDir(), "photo.jpg") + if err := os.WriteFile(imagePath, imageData, 0o600); err != nil { + t.Fatalf("WriteFile() error = %v", err) + } + ref, err := store.Store(imagePath, media.MediaMeta{ + Filename: "photo.jpg", + ContentType: "image/jpeg", + Source: "test", + CleanupPolicy: media.CleanupPolicyForgetOnly, + }, "scope-1") + if err != nil { + t.Fatalf("Store() error = %v", err) + } + + var commands []wecomCommand + ch.commandSend = func(cmd wecomCommand, _ time.Duration) (wecomEnvelope, error) { + commands = append(commands, cmd) + switch cmd.Cmd { + case wecomCmdUploadMediaInit: + return wecomTestAck(wecomUploadMediaInitResponse{UploadID: "upload-1"}), nil + case wecomCmdUploadMediaEnd: + return wecomTestAck(wecomUploadMediaFinishResponse{ + Type: "image", + MediaID: "media-1", + }), nil + default: + return wecomTestAck(nil), nil 
+ } + } + + err = ch.SendMedia(context.Background(), bus.OutboundMediaMessage{ + Channel: "wecom", + ChatID: "chat-1", + Parts: []bus.MediaPart{{ + Ref: ref, + Type: "image", + Filename: "photo.jpg", + ContentType: "image/jpeg", + }}, + }) + if err != nil { + t.Fatalf("SendMedia() error = %v", err) + } + + if len(commands) != 4 { + t.Fatalf("expected 4 commands, got %d", len(commands)) + } + if commands[0].Cmd != wecomCmdUploadMediaInit { + t.Fatalf("first command = %q, want %q", commands[0].Cmd, wecomCmdUploadMediaInit) + } + initBody, ok := commands[0].Body.(wecomUploadMediaInitBody) + if !ok { + t.Fatalf("unexpected init body type %T", commands[0].Body) + } + if initBody.Type != "image" || initBody.Filename != "photo.jpg" || initBody.TotalChunks != 1 { + t.Fatalf("init body = %+v", initBody) + } + if commands[1].Cmd != wecomCmdUploadMediaChunk { + t.Fatalf("second command = %q, want %q", commands[1].Cmd, wecomCmdUploadMediaChunk) + } + chunkBody, ok := commands[1].Body.(wecomUploadMediaChunkBody) + if !ok { + t.Fatalf("unexpected chunk body type %T", commands[1].Body) + } + if chunkBody.UploadID != "upload-1" || chunkBody.ChunkIndex != 0 || chunkBody.Base64Data == "" { + t.Fatalf("chunk body = %+v", chunkBody) + } + if commands[2].Cmd != wecomCmdUploadMediaEnd { + t.Fatalf("third command = %q, want %q", commands[2].Cmd, wecomCmdUploadMediaEnd) + } + if commands[3].Cmd != wecomCmdSendMsg { + t.Fatalf("fourth command = %q, want %q", commands[3].Cmd, wecomCmdSendMsg) + } + + body, ok := commands[3].Body.(wecomSendMsgBody) + if !ok { + t.Fatalf("unexpected send body type %T", commands[3].Body) + } + if body.MsgType != "image" || body.Image == nil { + t.Fatalf("send body = %+v", body) + } + if body.ChatID != "chat-1" { + t.Fatalf("send chatid = %q, want chat-1", body.ChatID) + } + if body.Image.MediaID != "media-1" { + t.Fatalf("image media_id = %q, want media-1", body.Image.MediaID) + } +} + +func TestSendMedia_UsesTurnImageAndFinishesStream(t *testing.T) { + 
t.Parallel() + + ch := newTestWeComChannel(t, bus.NewMessageBus()) + ch.SetRunning(true) + + store := media.NewFileMediaStore() + ch.SetMediaStore(store) + + imageData := wecomTestJPEGData(t) + imagePath := filepath.Join(t.TempDir(), "reply.jpg") + if err := os.WriteFile(imagePath, imageData, 0o600); err != nil { + t.Fatalf("WriteFile() error = %v", err) + } + ref, err := store.Store(imagePath, media.MediaMeta{ + Filename: "reply.jpg", + ContentType: "image/jpeg", + Source: "test", + CleanupPolicy: media.CleanupPolicyForgetOnly, + }, "scope-2") + if err != nil { + t.Fatalf("Store() error = %v", err) + } + + ch.queueTurn("chat-1", wecomTurn{ + ReqID: "req-1", + ChatID: "chat-1", + ChatType: 1, + StreamID: "stream-1", + CreatedAt: time.Now(), + }) + putErr := ch.routes.Put("chat-1", "req-1", 1, time.Hour) + if putErr != nil { + t.Fatalf("Put() error = %v", putErr) + } + + var commands []wecomCommand + ch.commandSend = func(cmd wecomCommand, _ time.Duration) (wecomEnvelope, error) { + commands = append(commands, cmd) + switch cmd.Cmd { + case wecomCmdUploadMediaInit: + return wecomTestAck(wecomUploadMediaInitResponse{UploadID: "upload-2"}), nil + case wecomCmdUploadMediaEnd: + return wecomTestAck(wecomUploadMediaFinishResponse{ + Type: "image", + MediaID: "media-2", + }), nil + default: + return wecomTestAck(nil), nil + } + } + + err = ch.SendMedia(context.Background(), bus.OutboundMediaMessage{ + Channel: "wecom", + ChatID: "chat-1", + Parts: []bus.MediaPart{{ + Ref: ref, + Type: "image", + Filename: "reply.jpg", + ContentType: "image/jpeg", + }}, + }) + if err != nil { + t.Fatalf("SendMedia() error = %v", err) + } + + if len(commands) != 5 { + t.Fatalf("expected 5 commands, got %d", len(commands)) + } + if commands[0].Cmd != wecomCmdUploadMediaInit { + t.Fatalf("first command = %+v", commands[0]) + } + if commands[1].Cmd != wecomCmdUploadMediaChunk { + t.Fatalf("second command = %+v", commands[1]) + } + if commands[2].Cmd != wecomCmdUploadMediaEnd { + 
t.Fatalf("third command = %+v", commands[2]) + } + if commands[3].Cmd != wecomCmdRespondMsg || commands[3].Headers.ReqID != "req-1" { + t.Fatalf("fourth command = %+v", commands[3]) + } + if commands[4].Cmd != wecomCmdRespondMsg || commands[4].Headers.ReqID != "req-1" { + t.Fatalf("fifth command = %+v", commands[4]) + } + + imageBody, ok := commands[3].Body.(wecomRespondMsgBody) + if !ok { + t.Fatalf("unexpected image body type %T", commands[3].Body) + } + if imageBody.MsgType != "image" || imageBody.Image == nil { + t.Fatalf("image body = %+v", imageBody) + } + if imageBody.Image.MediaID != "media-2" { + t.Fatalf("image media_id = %q, want media-2", imageBody.Image.MediaID) + } + + streamBody, ok := commands[4].Body.(wecomRespondMsgBody) + if !ok { + t.Fatalf("unexpected finish body type %T", commands[4].Body) + } + if streamBody.MsgType != "stream" || streamBody.Stream == nil || !streamBody.Stream.Finish { + t.Fatalf("finish body = %+v", streamBody) + } + + if _, ok := ch.getTurn("chat-1"); ok { + t.Fatal("expected turn to be removed after media send") + } +} + +func TestSendMedia_SendsActiveFile(t *testing.T) { + t.Parallel() + + ch := newTestWeComChannel(t, bus.NewMessageBus()) + ch.SetRunning(true) + + store := media.NewFileMediaStore() + ch.SetMediaStore(store) + + filePath := filepath.Join(t.TempDir(), "report.pdf") + if err := os.WriteFile(filePath, []byte("%PDF-1.4"), 0o600); err != nil { + t.Fatalf("WriteFile() error = %v", err) + } + ref, err := store.Store(filePath, media.MediaMeta{ + Filename: "report.pdf", + ContentType: "application/pdf", + Source: "test", + CleanupPolicy: media.CleanupPolicyForgetOnly, + }, "scope-3") + if err != nil { + t.Fatalf("Store() error = %v", err) + } + + var commands []wecomCommand + ch.commandSend = func(cmd wecomCommand, _ time.Duration) (wecomEnvelope, error) { + commands = append(commands, cmd) + switch cmd.Cmd { + case wecomCmdUploadMediaInit: + return wecomTestAck(wecomUploadMediaInitResponse{UploadID: "upload-3"}), 
nil + case wecomCmdUploadMediaEnd: + return wecomTestAck(wecomUploadMediaFinishResponse{ + Type: "file", + MediaID: "media-3", + }), nil + default: + return wecomTestAck(nil), nil + } + } + + err = ch.SendMedia(context.Background(), bus.OutboundMediaMessage{ + Channel: "wecom", + ChatID: "chat-2", + Parts: []bus.MediaPart{{ + Ref: ref, + Type: "file", + Filename: "report.pdf", + ContentType: "application/pdf", + }}, + }) + if err != nil { + t.Fatalf("SendMedia() error = %v", err) + } + + if len(commands) != 4 { + t.Fatalf("expected 4 commands, got %d", len(commands)) + } + if commands[0].Cmd != wecomCmdUploadMediaInit { + t.Fatalf("first command = %q, want %q", commands[0].Cmd, wecomCmdUploadMediaInit) + } + initBody, ok := commands[0].Body.(wecomUploadMediaInitBody) + if !ok { + t.Fatalf("unexpected init body type %T", commands[0].Body) + } + if initBody.Type != "file" || initBody.Filename != "report.pdf" { + t.Fatalf("init body = %+v", initBody) + } + if commands[1].Cmd != wecomCmdUploadMediaChunk { + t.Fatalf("second command = %q, want %q", commands[1].Cmd, wecomCmdUploadMediaChunk) + } + if commands[2].Cmd != wecomCmdUploadMediaEnd { + t.Fatalf("third command = %q, want %q", commands[2].Cmd, wecomCmdUploadMediaEnd) + } + if commands[3].Cmd != wecomCmdSendMsg { + t.Fatalf("fourth command = %q, want %q", commands[3].Cmd, wecomCmdSendMsg) + } + + body, ok := commands[3].Body.(wecomSendMsgBody) + if !ok { + t.Fatalf("unexpected body type %T", commands[3].Body) + } + if body.MsgType != "file" || body.File == nil { + t.Fatalf("body = %+v", body) + } + if body.File.MediaID != "media-3" { + t.Fatalf("file media_id = %q, want media-3", body.File.MediaID) + } +} + +func newTestWeComChannel(t *testing.T, messageBus *bus.MessageBus) *WeComChannel { + t.Helper() + + cfg := config.WeComConfig{BotID: "bot-1"} + cfg.SetSecret("secret-1") + ch, err := NewChannel(cfg, messageBus) + if err != nil { + t.Fatalf("NewChannel() error = %v", err) + } + ch.ctx = context.Background() + 
ch.routes = newReqIDStore(filepath.Join(t.TempDir(), "reqids.json")) + return ch +} + +func wecomTestJPEGData(t *testing.T) []byte { + t.Helper() + + const jpegBase64 = "/9j/4AAQSkZJRgABAQAAAQABAAD/2wBDAP//////////////////////////////////////////////////////////////////////////////////////" + + "//////////////////////////////////////////////////////////////////////////////////////////////2wBDAf//////////////////////////////////////////////////////////////////////////////////////" + + "//////////////////////////////////////////////////////////////////////////////////////////////wAARCAABAAEDASIAAhEBAxEB/8QAFQABAQAAAAAAAAAAAAAAAAAAAAb/xAAVEQEBAAAAAAAAAAAAAAAAAAAABf/aAAwDAQACEAMQAAAB6A//xAAVEAEBAAAAAAAAAAAAAAAAAAAAEf/aAAgBAQABBQJf/8QAFBEBAAAAAAAAAAAAAAAAAAAAEP/aAAgBAwEBPwF//8QAFBEBAAAAAAAAAAAAAAAAAAAAEP/aAAgBAgEBPwF//8QAFBABAAAAAAAAAAAAAAAAAAAAEP/aAAgBAQAGPwJf/8QAFBABAAAAAAAAAAAAAAAAAAAAEP/aAAgBAQABPyFf/9k=" + + return decodeTestBase64(t, jpegBase64) +} + +func TestDecodeWeComUploadFinish_AcceptsNumericCreatedAt(t *testing.T) { + t.Parallel() + + resp, err := decodeWeComEnvelopeBody[wecomUploadMediaFinishResponse](wecomEnvelope{ + Body: json.RawMessage(`{"type":"file","media_id":"media-1","created_at":1380000000}`), + }) + if err != nil { + t.Fatalf("decodeWeComEnvelopeBody() error = %v", err) + } + if resp.Type != "file" || resp.MediaID != "media-1" { + t.Fatalf("response = %+v", resp) + } + if string(resp.CreatedAt) != "1380000000" { + t.Fatalf("created_at = %s, want 1380000000", string(resp.CreatedAt)) + } +} + +func wecomTestAck(body any) wecomEnvelope { + var raw []byte + if body != nil { + encoded, err := json.Marshal(body) + if err != nil { + panic(err) + } + raw = encoded + } + return wecomEnvelope{ + ErrCode: 0, + ErrMsg: "ok", + Body: raw, + } +} diff --git a/pkg/config/config.go b/pkg/config/config.go index f0d9aa580..47bb7e8f1 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -106,6 +106,7 @@ func (c *Config) WithSecurity(sec *SecurityConfig) 
*Config { c.security = sec return c } + sec = normalizeSecurityConfig(sec) err := applySecurityConfig(c, sec) if err != nil { return nil @@ -320,10 +321,7 @@ type AgentDefaults struct { ToolFeedback ToolFeedbackConfig `json:"tool_feedback,omitempty"` } -const ( - DefaultMaxMediaSize = 20 * 1024 * 1024 // 20 MB - DefaultWeComAIBotProcessingMessage = "⏳ Processing, please wait. The results will be sent shortly." -) +const DefaultMaxMediaSize = 20 * 1024 * 1024 // 20 MB func (d *AgentDefaults) GetMaxMediaSize() int { if d.MaxMediaSize > 0 { @@ -363,9 +361,7 @@ type ChannelsConfig struct { Matrix MatrixConfig `json:"matrix"` LINE LINEConfig `json:"line"` OneBot OneBotConfig `json:"onebot"` - WeCom WeComConfig `json:"wecom"` - WeComApp WeComAppConfig `json:"wecom_app"` - WeComAIBot WeComAIBotConfig `json:"wecom_aibot"` + WeCom WeComConfig `json:"wecom" envPrefix:"PICOCLAW_CHANNELS_WECOM_"` Weixin WeixinConfig `json:"weixin"` Pico PicoConfig `json:"pico"` PicoClient PicoClientConfig `json:"pico_client"` @@ -385,7 +381,7 @@ type TypingConfig struct { // PlaceholderConfig controls placeholder message behavior (Phase 10). 
type PlaceholderConfig struct { - Enabled bool `json:"enabled,omitempty"` + Enabled bool `json:"enabled"` Text string `json:"text,omitempty"` } @@ -590,18 +586,20 @@ func (c *SlackConfig) SetAppToken(token string) { } type MatrixConfig struct { - Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_MATRIX_ENABLED"` - Homeserver string `json:"homeserver" env:"PICOCLAW_CHANNELS_MATRIX_HOMESERVER"` - UserID string `json:"user_id" env:"PICOCLAW_CHANNELS_MATRIX_USER_ID"` + Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_MATRIX_ENABLED"` + Homeserver string `json:"homeserver" env:"PICOCLAW_CHANNELS_MATRIX_HOMESERVER"` + UserID string `json:"user_id" env:"PICOCLAW_CHANNELS_MATRIX_USER_ID"` accessToken string - DeviceID string `json:"device_id,omitempty" env:"PICOCLAW_CHANNELS_MATRIX_DEVICE_ID"` - JoinOnInvite bool `json:"join_on_invite" env:"PICOCLAW_CHANNELS_MATRIX_JOIN_ON_INVITE"` - MessageFormat string `json:"message_format,omitempty" env:"PICOCLAW_CHANNELS_MATRIX_MESSAGE_FORMAT"` - AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_MATRIX_ALLOW_FROM"` + DeviceID string `json:"device_id,omitempty" env:"PICOCLAW_CHANNELS_MATRIX_DEVICE_ID"` + JoinOnInvite bool `json:"join_on_invite" env:"PICOCLAW_CHANNELS_MATRIX_JOIN_ON_INVITE"` + MessageFormat string `json:"message_format,omitempty" env:"PICOCLAW_CHANNELS_MATRIX_MESSAGE_FORMAT"` + AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_MATRIX_ALLOW_FROM"` GroupTrigger GroupTriggerConfig `json:"group_trigger,omitempty"` Placeholder PlaceholderConfig `json:"placeholder,omitempty"` - ReasoningChannelID string `json:"reasoning_channel_id" env:"PICOCLAW_CHANNELS_MATRIX_REASONING_CHANNEL_ID"` + ReasoningChannelID string `json:"reasoning_channel_id" env:"PICOCLAW_CHANNELS_MATRIX_REASONING_CHANNEL_ID"` secDirty bool + CryptoDatabasePath string `json:"crypto_database_path,omitempty" env:"PICOCLAW_CHANNELS_MATRIX_CRYPTO_DATABASE_PATH"` + CryptoPassphrase string 
`json:"crypto_passphrase,omitempty" env:"PICOCLAW_CHANNELS_MATRIX_CRYPTO_PASSPHRASE"` } // AccessToken returns the Matrix access token @@ -677,136 +675,28 @@ func (c *OneBotConfig) SetAccessToken(token string) { c.secDirty = true } +type WeComGroupConfig struct { + AllowFrom FlexibleStringSlice `json:"allow_from,omitempty"` +} + type WeComConfig struct { - Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_WECOM_ENABLED"` - token string - encodingAESKey string - WebhookURL string `json:"webhook_url" env:"PICOCLAW_CHANNELS_WECOM_WEBHOOK_URL"` - WebhookHost string `json:"webhook_host" env:"PICOCLAW_CHANNELS_WECOM_WEBHOOK_HOST"` - WebhookPort int `json:"webhook_port" env:"PICOCLAW_CHANNELS_WECOM_WEBHOOK_PORT"` - WebhookPath string `json:"webhook_path" env:"PICOCLAW_CHANNELS_WECOM_WEBHOOK_PATH"` - AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_WECOM_ALLOW_FROM"` - ReplyTimeout int `json:"reply_timeout" env:"PICOCLAW_CHANNELS_WECOM_REPLY_TIMEOUT"` - GroupTrigger GroupTriggerConfig `json:"group_trigger,omitempty"` - ReasoningChannelID string `json:"reasoning_channel_id" env:"PICOCLAW_CHANNELS_WECOM_REASONING_CHANNEL_ID"` - secDirty bool + Enabled bool `json:"enabled" env:"ENABLED"` + BotID string `json:"bot_id" env:"BOT_ID"` + secret string + WebSocketURL string `json:"websocket_url,omitempty" env:"WEBSOCKET_URL"` + SendThinkingMessage bool `json:"send_thinking_message" env:"SEND_THINKING_MESSAGE"` + AllowFrom FlexibleStringSlice `json:"allow_from" env:"ALLOW_FROM"` + ReasoningChannelID string `json:"reasoning_channel_id" env:"REASONING_CHANNEL_ID"` + secDirty bool } -// Token returns the WeCom token -func (c *WeComConfig) Token() string { - return c.token -} - -// SetToken sets the WeCom token -func (c *WeComConfig) SetToken(token string) { - c.token = token - c.secDirty = true -} - -// EncodingAESKey returns the WeCom encoding AES key -func (c *WeComConfig) EncodingAESKey() string { - return c.encodingAESKey -} - -// SetEncodingAESKey sets the 
WeCom encoding AES key -func (c *WeComConfig) SetEncodingAESKey(key string) { - c.encodingAESKey = key - c.secDirty = true -} - -type WeComAppConfig struct { - Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_WECOM_APP_ENABLED"` - CorpID string `json:"corp_id" env:"PICOCLAW_CHANNELS_WECOM_APP_CORP_ID"` - corpSecret string - AgentID int64 `json:"agent_id" env:"PICOCLAW_CHANNELS_WECOM_APP_AGENT_ID"` - token string - encodingAESKey string - WebhookHost string `json:"webhook_host" env:"PICOCLAW_CHANNELS_WECOM_APP_WEBHOOK_HOST"` - WebhookPort int `json:"webhook_port" env:"PICOCLAW_CHANNELS_WECOM_APP_WEBHOOK_PORT"` - WebhookPath string `json:"webhook_path" env:"PICOCLAW_CHANNELS_WECOM_APP_WEBHOOK_PATH"` - AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_WECOM_APP_ALLOW_FROM"` - ReplyTimeout int `json:"reply_timeout" env:"PICOCLAW_CHANNELS_WECOM_APP_REPLY_TIMEOUT"` - GroupTrigger GroupTriggerConfig `json:"group_trigger,omitempty"` - ReasoningChannelID string `json:"reasoning_channel_id" env:"PICOCLAW_CHANNELS_WECOM_APP_REASONING_CHANNEL_ID"` - secDirty bool -} - -// CorpSecret returns the corporate secret for WeCom app -func (c *WeComAppConfig) CorpSecret() string { - return c.corpSecret -} - -// SetCorpSecret sets the corporate secret for WeCom app -func (c *WeComAppConfig) SetCorpSecret(secret string) { - c.corpSecret = secret - c.secDirty = true -} - -// Token returns the webhook token for WeCom app -func (c *WeComAppConfig) Token() string { - return c.token -} - -// SetToken sets the webhook token for WeCom app -func (c *WeComAppConfig) SetToken(token string) { - c.token = token - c.secDirty = true -} - -// EncodingAESKey returns the encoding AES key for WeCom app -func (c *WeComAppConfig) EncodingAESKey() string { - return c.encodingAESKey -} - -// SetEncodingAESKey sets the encoding AES key for WeCom app -func (c *WeComAppConfig) SetEncodingAESKey(key string) { - c.encodingAESKey = key - c.secDirty = true -} - -type WeComAIBotConfig struct 
{ - Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_WECOM_AIBOT_ENABLED"` - BotID string `json:"bot_id,omitempty" env:"PICOCLAW_CHANNELS_WECOM_AIBOT_BOT_ID"` - secret string - token string - encodingAESKey string - WebhookPath string `json:"webhook_path,omitempty" env:"PICOCLAW_CHANNELS_WECOM_AIBOT_WEBHOOK_PATH"` - AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_WECOM_AIBOT_ALLOW_FROM"` - ReplyTimeout int `json:"reply_timeout" env:"PICOCLAW_CHANNELS_WECOM_AIBOT_REPLY_TIMEOUT"` - MaxSteps int `json:"max_steps" env:"PICOCLAW_CHANNELS_WECOM_AIBOT_MAX_STEPS"` // Maximum streaming steps - WelcomeMessage string `json:"welcome_message" env:"PICOCLAW_CHANNELS_WECOM_AIBOT_WELCOME_MESSAGE"` // Sent on enter_chat event; empty = no welcome - ProcessingMessage string `json:"processing_message,omitempty" env:"PICOCLAW_CHANNELS_WECOM_AIBOT_PROCESSING_MESSAGE"` - ReasoningChannelID string `json:"reasoning_channel_id" env:"PICOCLAW_CHANNELS_WECOM_AIBOT_REASONING_CHANNEL_ID"` - secDirty bool -} - -// Token returns the webhook token for WeCom AI bot -func (c *WeComAIBotConfig) Token() string { - return c.token -} - -// EncodingAESKey returns the encoding AES key for WeCom AI bot -func (c *WeComAIBotConfig) EncodingAESKey() string { - return c.encodingAESKey -} - -// SetToken sets the token for WeCom AI bot -func (c *WeComAIBotConfig) SetToken(token string) { - c.token = token - c.secDirty = true -} - -// SetEncodingAESKey sets the encoding AES key for WeCom AI bot -func (c *WeComAIBotConfig) SetEncodingAESKey(key string) { - c.encodingAESKey = key - c.secDirty = true -} - -func (c *WeComAIBotConfig) Secret() string { +// Secret returns the WeCom bot secret. +func (c *WeComConfig) Secret() string { return c.secret } -func (c *WeComAIBotConfig) SetSecret(secret string) { +// SetSecret sets the WeCom bot secret. 
+func (c *WeComConfig) SetSecret(secret string) { c.secret = secret c.secDirty = true } @@ -814,6 +704,7 @@ func (c *WeComAIBotConfig) SetSecret(secret string) { type WeixinConfig struct { Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_WEIXIN_ENABLED"` token string + AccountID string `json:"account_id,omitempty" env:"PICOCLAW_CHANNELS_WEIXIN_ACCOUNT_ID"` BaseURL string `json:"base_url" env:"PICOCLAW_CHANNELS_WEIXIN_BASE_URL"` CDNBaseURL string `json:"cdn_base_url" env:"PICOCLAW_CHANNELS_WEIXIN_CDN_BASE_URL"` Proxy string `json:"proxy" env:"PICOCLAW_CHANNELS_WEIXIN_PROXY"` @@ -966,6 +857,10 @@ type ModelConfig struct { secModelName string apiKeys []string secDirty bool + + // isVirtual marks this model as a virtual model generated from multi-key expansion. + // Virtual models should not be persisted to config files. + isVirtual bool } // APIKey returns the first API key from apiKeys @@ -976,6 +871,11 @@ func (c *ModelConfig) APIKey() string { return "" } +// IsVirtual returns true if this model was generated from multi-key expansion. +func (c *ModelConfig) IsVirtual() bool { + return c.isVirtual +} + // Validate checks if the ModelConfig has all required fields. 
func (c *ModelConfig) Validate() error { if c.ModelName == "" { @@ -1393,6 +1293,18 @@ func LoadConfig(path string) (*Config, error) { if err != nil { return nil, err } + // Load existing security config and merge with migrated one to prevent data loss + existingSec, secErr := loadSecurityConfig(securityPath(path)) + if secErr != nil { + logger.WarnF("failed to load existing security config during migration", map[string]any{"error": secErr}) + } + if existingSec != nil && cfg.security != nil { + cfg.security = mergeSecurityConfig(existingSec, cfg.security) + // Re-apply the merged security config to update all channels and models + if err = applySecurityConfig(cfg, cfg.security); err != nil { + logger.WarnF("failed to re-apply merged security config during migration", map[string]any{"error": err}) + } + } defer func(cfg *Config) { _ = SaveConfig(path, cfg) }(cfg) @@ -1609,39 +1521,10 @@ func applySecurityConfig(cfg *Config, sec *SecurityConfig) error { cfg.Channels.OneBot.accessToken = sec.Channels.OneBot.AccessToken } - // Handle WeCom token and encoding key + // Handle WeCom bot secret if sec.Channels.WeCom != nil { - if sec.Channels.WeCom.Token != "" { - cfg.Channels.WeCom.token = sec.Channels.WeCom.Token - } - if sec.Channels.WeCom.EncodingAESKey != "" { - cfg.Channels.WeCom.encodingAESKey = sec.Channels.WeCom.EncodingAESKey - } - } - - // Handle WeCom App credentials - if sec.Channels.WeComApp != nil { - if sec.Channels.WeComApp.CorpSecret != "" { - cfg.Channels.WeComApp.corpSecret = sec.Channels.WeComApp.CorpSecret - } - if sec.Channels.WeComApp.Token != "" { - cfg.Channels.WeComApp.token = sec.Channels.WeComApp.Token - } - if sec.Channels.WeComApp.EncodingAESKey != "" { - cfg.Channels.WeComApp.encodingAESKey = sec.Channels.WeComApp.EncodingAESKey - } - } - - // Handle WeCom AI Bot credentials - if sec.Channels.WeComAIBot != nil { - if sec.Channels.WeComAIBot.Token != "" { - cfg.Channels.WeComAIBot.token = sec.Channels.WeComAIBot.Token - } - if 
sec.Channels.WeComAIBot.EncodingAESKey != "" { - cfg.Channels.WeComAIBot.encodingAESKey = sec.Channels.WeComAIBot.EncodingAESKey - } - if sec.Channels.WeComAIBot.Secret != "" { - cfg.Channels.WeComAIBot.secret = sec.Channels.WeComAIBot.Secret + if sec.Channels.WeCom.Secret != "" { + cfg.Channels.WeCom.secret = sec.Channels.WeCom.Secret } } @@ -1768,6 +1651,7 @@ func SaveConfig(path string, cfg *Config) error { logger.ErrorC("config", "security is nil") return fmt.Errorf("security is nil") } + cfg.security = normalizeSecurityConfig(cfg.security) // Ensure version is always set when saving if cfg.Version == 0 { cfg.Version = CurrentVersion @@ -1864,27 +1748,10 @@ func SaveConfig(path string, cfg *Config) error { } if cfg.Channels.WeCom.secDirty { cfg.security.Channels.WeCom = &WeComSecurity{ - Token: cfg.Channels.WeCom.Token(), - EncodingAESKey: cfg.Channels.WeCom.EncodingAESKey(), + Secret: cfg.Channels.WeCom.Secret(), } cfg.Channels.WeCom.secDirty = false } - if cfg.Channels.WeComApp.secDirty { - cfg.security.Channels.WeComApp = &WeComAppSecurity{ - CorpSecret: cfg.Channels.WeComApp.CorpSecret(), - Token: cfg.Channels.WeComApp.Token(), - EncodingAESKey: cfg.Channels.WeComApp.EncodingAESKey(), - } - cfg.Channels.WeComApp.secDirty = false - } - if cfg.Channels.WeComAIBot.secDirty { - cfg.security.Channels.WeComAIBot = &WeComAIBotSecurity{ - Token: cfg.Channels.WeComAIBot.Token(), - EncodingAESKey: cfg.Channels.WeComAIBot.EncodingAESKey(), - Secret: cfg.Channels.WeComAIBot.Secret(), - } - cfg.Channels.WeComAIBot.secDirty = false - } if cfg.Tools.Web.Brave.secDirty { cfg.security.Web.Brave = &BraveSecurity{ APIKeys: cfg.Tools.Web.Brave.APIKeys(), @@ -1942,10 +1809,24 @@ func SaveConfig(path string, cfg *Config) error { return err } + // Filter out virtual models before serializing to config file + nonVirtualModels := make([]*ModelConfig, 0, len(cfg.ModelList)) + for _, m := range cfg.ModelList { + if !m.isVirtual { + nonVirtualModels = append(nonVirtualModels, m) + } + 
} + // Temporarily replace ModelList with filtered version for serialization + originalModelList := cfg.ModelList + cfg.ModelList = nonVirtualModels + data, err := json.MarshalIndent(cfg, "", " ") + // Restore original ModelList after serialization + cfg.ModelList = originalModelList if err != nil { return err } + logger.Infof("saving config to %s", path) return fileutil.WriteFileAtomic(path, data, 0o600) } @@ -2009,6 +1890,17 @@ func (c *Config) ValidateModelList() error { func (c *Config) SecurityCopyFrom(cfg *Config) { c.security = cfg.security + if c.security != nil { + if err := applySecurityConfig(c, c.security); err != nil { + logger.Errorf("failed to apply security config in SecurityCopyFrom: %v", err) + } + } +} + +// ApplySecurity re-applies the stored security config to populate private fields (tokens, API keys, etc.). +// Call this after SecurityCopyFrom when you need private fields to be accessible for validation or use. +func (c *Config) ApplySecurity() error { + return applySecurityConfig(c, c.security) } func MergeAPIKeys(apiKey string, apiKeys []string) []string { @@ -2148,6 +2040,7 @@ func expandMultiKeyModels(models []*ModelConfig) []*ModelConfig { RequestTimeout: m.RequestTimeout, ThinkingLevel: m.ThinkingLevel, ExtraBody: m.ExtraBody, + isVirtual: true, } expanded = append(expanded, additionalEntry) fallbackNames = append(fallbackNames, expandedName) diff --git a/pkg/config/config_old.go b/pkg/config/config_old.go index 01909f5a9..44c9435d1 100644 --- a/pkg/config/config_old.go +++ b/pkg/config/config_old.go @@ -85,23 +85,21 @@ type toolsConfigV0 struct { } type channelsConfigV0 struct { - WhatsApp WhatsAppConfig `json:"whatsapp"` - Telegram telegramConfigV0 `json:"telegram"` - Feishu feishuConfigV0 `json:"feishu"` - Discord discordConfigV0 `json:"discord"` - MaixCam maixcamConfigV0 `json:"maixcam"` - Weixin weixinConfigV0 `json:"weixin"` - QQ qqConfigV0 `json:"qq"` - DingTalk dingtalkConfigV0 `json:"dingtalk"` - Slack slackConfigV0 
`json:"slack"` - Matrix matrixConfigV0 `json:"matrix"` - LINE lineConfigV0 `json:"line"` - OneBot onebotConfigV0 `json:"onebot"` - WeCom wecomConfigV0 `json:"wecom"` - WeComApp wecomappConfigV0 `json:"wecom_app"` - WeComAIBot wecomaibotConfigV0 `json:"wecom_aibot"` - Pico picoConfigV0 `json:"pico"` - IRC ircConfigV0 `json:"irc"` + WhatsApp WhatsAppConfig `json:"whatsapp"` + Telegram telegramConfigV0 `json:"telegram"` + Feishu feishuConfigV0 `json:"feishu"` + Discord discordConfigV0 `json:"discord"` + MaixCam maixcamConfigV0 `json:"maixcam"` + Weixin weixinConfigV0 `json:"weixin"` + QQ qqConfigV0 `json:"qq"` + DingTalk dingtalkConfigV0 `json:"dingtalk"` + Slack slackConfigV0 `json:"slack"` + Matrix matrixConfigV0 `json:"matrix"` + LINE lineConfigV0 `json:"line"` + OneBot onebotConfigV0 `json:"onebot"` + WeCom wecomConfigV0 `json:"wecom" envPrefix:"PICOCLAW_CHANNELS_WECOM_"` + Pico picoConfigV0 `json:"pico"` + IRC ircConfigV0 `json:"irc"` } func (v *channelsConfigV0) ToChannelsConfig() (ChannelsConfig, ChannelsSecurity) { @@ -117,45 +115,39 @@ func (v *channelsConfigV0) ToChannelsConfig() (ChannelsConfig, ChannelsSecurity) line, lineSecurity := v.LINE.ToLINEConfig() onebot, onebotSecurity := v.OneBot.ToOneBotConfig() wecom, wecomSecurity := v.WeCom.ToWeComConfig() - wecomapp, wecomappSecurity := v.WeComApp.ToWeComAppConfig() - wecomaibot, wecomaibotSecurity := v.WeComAIBot.ToWeComAIBotConfig() pico, picoSecurity := v.Pico.ToPicoConfig() irc, ircSecurity := v.IRC.ToIRCConfig() return ChannelsConfig{ - WhatsApp: v.WhatsApp, - Telegram: telegram, - Feishu: feishu, - Discord: discord, - MaixCam: maixcam, - QQ: qq, - Weixin: weixin, - DingTalk: dingtalk, - Slack: slack, - Matrix: matrix, - LINE: line, - OneBot: onebot, - WeCom: wecom, - WeComApp: wecomapp, - WeComAIBot: wecomaibot, - Pico: pico, - IRC: irc, + WhatsApp: v.WhatsApp, + Telegram: telegram, + Feishu: feishu, + Discord: discord, + MaixCam: maixcam, + QQ: qq, + Weixin: weixin, + DingTalk: dingtalk, + Slack: 
slack, + Matrix: matrix, + LINE: line, + OneBot: onebot, + WeCom: wecom, + Pico: pico, + IRC: irc, }, ChannelsSecurity{ - Telegram: telegramSecurity, - Feishu: feishuSecurity, - Discord: discordSecurity, - QQ: qqSecurity, - Weixin: weixinSecurity, - DingTalk: dingtalkSecurity, - Slack: slackSecurity, - Matrix: matrixSecurity, - LINE: lineSecurity, - OneBot: onebotSecurity, - WeCom: wecomSecurity, - WeComApp: wecomappSecurity, - WeComAIBot: wecomaibotSecurity, - Pico: picoSecurity, - IRC: ircSecurity, + Telegram: telegramSecurity, + Feishu: feishuSecurity, + Discord: discordSecurity, + QQ: qqSecurity, + Weixin: weixinSecurity, + DingTalk: dingtalkSecurity, + Slack: slackSecurity, + Matrix: matrixSecurity, + LINE: lineSecurity, + OneBot: onebotSecurity, + WeCom: wecomSecurity, + Pico: picoSecurity, + IRC: ircSecurity, } } @@ -473,39 +465,32 @@ func (v *onebotConfigV0) ToOneBotConfig() (OneBotConfig, *OneBotSecurity) { } type wecomConfigV0 struct { - Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_WECOM_ENABLED"` - Token string `json:"token" env:"PICOCLAW_CHANNELS_WECOM_TOKEN"` - EncodingAESKey string `json:"encoding_aes_key" env:"PICOCLAW_CHANNELS_WECOM_ENCODING_AES_KEY"` - WebhookURL string `json:"webhook_url" env:"PICOCLAW_CHANNELS_WECOM_WEBHOOK_URL"` - WebhookHost string `json:"webhook_host" env:"PICOCLAW_CHANNELS_WECOM_WEBHOOK_HOST"` - WebhookPort int `json:"webhook_port" env:"PICOCLAW_CHANNELS_WECOM_WEBHOOK_PORT"` - WebhookPath string `json:"webhook_path" env:"PICOCLAW_CHANNELS_WECOM_WEBHOOK_PATH"` - AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_WECOM_ALLOW_FROM"` - ReplyTimeout int `json:"reply_timeout" env:"PICOCLAW_CHANNELS_WECOM_REPLY_TIMEOUT"` - GroupTrigger GroupTriggerConfig `json:"group_trigger,omitempty"` - ReasoningChannelID string `json:"reasoning_channel_id" env:"PICOCLAW_CHANNELS_WECOM_REASONING_CHANNEL_ID"` + Enabled bool `json:"enabled" env:"ENABLED"` + BotID string `json:"bot_id" env:"BOT_ID"` + Secret string 
`json:"secret" env:"SECRET"` + WebSocketURL string `json:"websocket_url,omitempty" env:"WEBSOCKET_URL"` + SendThinkingMessage bool `json:"send_thinking_message" env:"SEND_THINKING_MESSAGE"` + DMPolicy string `json:"dm_policy,omitempty" env:"DM_POLICY"` + AllowFrom FlexibleStringSlice `json:"allow_from" env:"ALLOW_FROM"` + GroupPolicy string `json:"group_policy,omitempty" env:"GROUP_POLICY"` + GroupAllowFrom FlexibleStringSlice `json:"group_allow_from,omitempty" env:"GROUP_ALLOW_FROM"` + Groups map[string]WeComGroupConfig `json:"groups,omitempty"` + ReasoningChannelID string `json:"reasoning_channel_id" env:"REASONING_CHANNEL_ID"` } func (v *wecomConfigV0) ToWeComConfig() (WeComConfig, *WeComSecurity) { var sec *WeComSecurity - if v.Token != "" || v.EncodingAESKey != "" { - sec = &WeComSecurity{ - Token: v.Token, - EncodingAESKey: v.EncodingAESKey, - } + if v.Secret != "" { + sec = &WeComSecurity{Secret: v.Secret} } return WeComConfig{ - Enabled: v.Enabled, - token: v.Token, - encodingAESKey: v.EncodingAESKey, - WebhookURL: v.WebhookURL, - WebhookHost: v.WebhookHost, - WebhookPort: v.WebhookPort, - WebhookPath: v.WebhookPath, - AllowFrom: v.AllowFrom, - ReplyTimeout: v.ReplyTimeout, - GroupTrigger: v.GroupTrigger, - ReasoningChannelID: v.ReasoningChannelID, + Enabled: v.Enabled, + BotID: v.BotID, + secret: v.Secret, + WebSocketURL: v.WebSocketURL, + SendThinkingMessage: v.SendThinkingMessage, + AllowFrom: v.AllowFrom, + ReasoningChannelID: v.ReasoningChannelID, }, sec } @@ -537,81 +522,6 @@ func (v *weixinConfigV0) ToWeiXinConfig() (WeixinConfig, *WeixinSecurity) { }, sec } -type wecomappConfigV0 struct { - Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_WECOM_APP_ENABLED"` - CorpID string `json:"corp_id" env:"PICOCLAW_CHANNELS_WECOM_APP_CORP_ID"` - CorpSecret string `json:"corp_secret" env:"PICOCLAW_CHANNELS_WECOM_APP_CORP_SECRET"` - AgentID int64 `json:"agent_id" env:"PICOCLAW_CHANNELS_WECOM_APP_AGENT_ID"` - Token string `json:"token" 
env:"PICOCLAW_CHANNELS_WECOM_APP_TOKEN"` - EncodingAESKey string `json:"encoding_aes_key" env:"PICOCLAW_CHANNELS_WECOM_APP_ENCODING_AES_KEY"` - WebhookHost string `json:"webhook_host" env:"PICOCLAW_CHANNELS_WECOM_APP_WEBHOOK_HOST"` - WebhookPort int `json:"webhook_port" env:"PICOCLAW_CHANNELS_WECOM_APP_WEBHOOK_PORT"` - WebhookPath string `json:"webhook_path" env:"PICOCLAW_CHANNELS_WECOM_APP_WEBHOOK_PATH"` - AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_WECOM_APP_ALLOW_FROM"` - ReplyTimeout int `json:"reply_timeout" env:"PICOCLAW_CHANNELS_WECOM_APP_REPLY_TIMEOUT"` - GroupTrigger GroupTriggerConfig `json:"group_trigger,omitempty"` - ReasoningChannelID string `json:"reasoning_channel_id" env:"PICOCLAW_CHANNELS_WECOM_APP_REASONING_CHANNEL_ID"` -} - -func (v *wecomappConfigV0) ToWeComAppConfig() (WeComAppConfig, *WeComAppSecurity) { - var sec *WeComAppSecurity - if v.CorpSecret != "" || v.Token != "" || v.EncodingAESKey != "" { - sec = &WeComAppSecurity{ - CorpSecret: v.CorpSecret, - Token: v.Token, - EncodingAESKey: v.EncodingAESKey, - } - } - return WeComAppConfig{ - Enabled: v.Enabled, - CorpID: v.CorpID, - corpSecret: v.CorpSecret, - AgentID: v.AgentID, - token: v.Token, - encodingAESKey: v.EncodingAESKey, - WebhookHost: v.WebhookHost, - WebhookPort: v.WebhookPort, - WebhookPath: v.WebhookPath, - AllowFrom: v.AllowFrom, - ReplyTimeout: v.ReplyTimeout, - GroupTrigger: v.GroupTrigger, - ReasoningChannelID: v.ReasoningChannelID, - }, sec -} - -type wecomaibotConfigV0 struct { - Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_WECOM_AIBOT_ENABLED"` - Token string `json:"token" env:"PICOCLAW_CHANNELS_WECOM_AIBOT_TOKEN"` - Secret string `json:"secret" env:"PICOCLAW_CHANNELS_WECOM_AIBOT_SECRET"` - EncodingAESKey string `json:"encoding_aes_key" env:"PICOCLAW_CHANNELS_WECOM_AIBOT_ENCODING_AES_KEY"` - WebhookPath string `json:"webhook_path" env:"PICOCLAW_CHANNELS_WECOM_AIBOT_WEBHOOK_PATH"` - AllowFrom FlexibleStringSlice `json:"allow_from" 
env:"PICOCLAW_CHANNELS_WECOM_AIBOT_ALLOW_FROM"` - ReplyTimeout int `json:"reply_timeout" env:"PICOCLAW_CHANNELS_WECOM_AIBOT_REPLY_TIMEOUT"` - MaxSteps int `json:"max_steps" env:"PICOCLAW_CHANNELS_WECOM_AIBOT_MAX_STEPS"` - WelcomeMessage string `json:"welcome_message" env:"PICOCLAW_CHANNELS_WECOM_AIBOT_WELCOME_MESSAGE"` - ReasoningChannelID string `json:"reasoning_channel_id" env:"PICOCLAW_CHANNELS_WECOM_AIBOT_REASONING_CHANNEL_ID"` -} - -func (v *wecomaibotConfigV0) ToWeComAIBotConfig() (WeComAIBotConfig, *WeComAIBotSecurity) { - var sec *WeComAIBotSecurity - if v.Token != "" || v.Secret != "" || v.EncodingAESKey != "" { - sec = &WeComAIBotSecurity{ - Token: v.Token, - Secret: v.Secret, - EncodingAESKey: v.EncodingAESKey, - } - } - return WeComAIBotConfig{ - Enabled: v.Enabled, - WebhookPath: v.WebhookPath, - AllowFrom: v.AllowFrom, - ReplyTimeout: v.ReplyTimeout, - MaxSteps: v.MaxSteps, - WelcomeMessage: v.WelcomeMessage, - ReasoningChannelID: v.ReasoningChannelID, - }, sec -} - type picoConfigV0 struct { Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_PICO_ENABLED"` Token string `json:"token" env:"PICOCLAW_CHANNELS_PICO_TOKEN"` diff --git a/pkg/config/config_test.go b/pkg/config/config_test.go index b356d474f..bedd46f6e 100644 --- a/pkg/config/config_test.go +++ b/pkg/config/config_test.go @@ -360,6 +360,96 @@ func TestSaveConfig_IncludesEmptyLegacyModelField(t *testing.T) { } } +func TestSaveConfig_PreservesDisabledTelegramPlaceholder(t *testing.T) { + tmpDir := t.TempDir() + path := filepath.Join(tmpDir, "config.json") + + cfg := DefaultConfig() + cfg.Channels.Telegram.Placeholder.Enabled = false + + if err := SaveConfig(path, cfg); err != nil { + t.Fatalf("SaveConfig failed: %v", err) + } + + data, err := os.ReadFile(path) + if err != nil { + t.Fatalf("ReadFile failed: %v", err) + } + if !strings.Contains(string(data), `"placeholder": {`) { + t.Fatalf("saved config should include telegram placeholder config, got: %s", string(data)) + } + if 
!strings.Contains(string(data), `"enabled": false`) { + t.Fatalf("saved config should persist placeholder.enabled=false, got: %s", string(data)) + } + + loaded, err := LoadConfig(path) + if err != nil { + t.Fatalf("LoadConfig failed: %v", err) + } + if loaded.Channels.Telegram.Placeholder.Enabled { + t.Fatal("telegram placeholder should remain disabled after SaveConfig/LoadConfig round-trip") + } +} + +// TestSaveConfig_FiltersVirtualModels verifies that SaveConfig does not write +// virtual models (generated by expandMultiKeyModels) to the config file. +func TestSaveConfig_FiltersVirtualModels(t *testing.T) { + tmpDir := t.TempDir() + path := filepath.Join(tmpDir, "config.json") + + cfg := DefaultConfig() + + // Manually add a virtual model to ModelList (simulating what expandMultiKeyModels does) + primaryModel := &ModelConfig{ + ModelName: "gpt-4", + Model: "openai/gpt-4o", + apiKeys: []string{"key1"}, + } + virtualModel := &ModelConfig{ + ModelName: "gpt-4__key_1", + Model: "openai/gpt-4o", + apiKeys: []string{"key2"}, + isVirtual: true, + } + cfg.ModelList = []*ModelConfig{primaryModel, virtualModel} + + // SaveConfig should filter out virtual models + if err := SaveConfig(path, cfg); err != nil { + t.Fatalf("SaveConfig failed: %v", err) + } + + // Reload and verify + reloaded, err := LoadConfig(path) + if err != nil { + t.Fatalf("LoadConfig failed: %v", err) + } + + // Should only have the primary model, not the virtual one + if len(reloaded.ModelList) != 1 { + t.Fatalf("expected 1 model after reload, got %d", len(reloaded.ModelList)) + } + + if reloaded.ModelList[0].ModelName != "gpt-4" { + t.Errorf("expected model_name 'gpt-4', got %q", reloaded.ModelList[0].ModelName) + } + + // Verify virtual model was not persisted + for _, m := range reloaded.ModelList { + if m.ModelName == "gpt-4__key_1" { + t.Errorf("virtual model gpt-4__key_1 should not have been saved") + } + } + + // Verify the saved file does not contain the virtual model name + data, err := 
os.ReadFile(path) + if err != nil { + t.Fatalf("ReadFile failed: %v", err) + } + if strings.Contains(string(data), "gpt-4__key_1") { + t.Errorf("saved config should not contain virtual model name 'gpt-4__key_1'") + } +} + // TestConfig_Complete verifies all config fields are set func TestConfig_Complete(t *testing.T) { cfg := DefaultConfig() @@ -1372,8 +1462,7 @@ func TestFilterSensitiveData_AllTokenTypes(t *testing.T) { Feishu: &FeishuSecurity{AppSecret: "feishu-app-secret-123", EncryptKey: "feishu-encrypt-key"}, DingTalk: &DingTalkSecurity{ClientSecret: "dingtalk-client-secret"}, OneBot: &OneBotSecurity{AccessToken: "onebot-access-token"}, - WeCom: &WeComSecurity{Token: "wecom-token", EncodingAESKey: "wecom-aes-key"}, - WeComApp: &WeComAppSecurity{CorpSecret: "wecom-app-secret", Token: "wecom-app-token"}, + WeCom: &WeComSecurity{Secret: "wecom-secret"}, Pico: &PicoSecurity{Token: "pico-token-abc123"}, IRC: &IRCSecurity{ Password: "irc-password", diff --git a/pkg/config/defaults.go b/pkg/config/defaults.go index c1d0ea0f6..ba1a5a0cf 100644 --- a/pkg/config/defaults.go +++ b/pkg/config/defaults.go @@ -113,6 +113,8 @@ func DefaultConfig() *Config { Enabled: true, Text: "Thinking... 💭", }, + CryptoDatabasePath: "", + CryptoPassphrase: "", }, LINE: LINEConfig{ Enabled: false, @@ -129,32 +131,11 @@ func DefaultConfig() *Config { AllowFrom: FlexibleStringSlice{}, }, WeCom: WeComConfig{ - Enabled: false, - WebhookURL: "", - WebhookHost: "0.0.0.0", - WebhookPort: 18793, - WebhookPath: "/webhook/wecom", - AllowFrom: FlexibleStringSlice{}, - ReplyTimeout: 5, - }, - WeComApp: WeComAppConfig{ - Enabled: false, - CorpID: "", - AgentID: 0, - WebhookHost: "0.0.0.0", - WebhookPort: 18792, - WebhookPath: "/webhook/wecom-app", - AllowFrom: FlexibleStringSlice{}, - ReplyTimeout: 5, - }, - WeComAIBot: WeComAIBotConfig{ - Enabled: false, - WebhookPath: "/webhook/wecom-aibot", - AllowFrom: FlexibleStringSlice{}, - ReplyTimeout: 5, - MaxSteps: 10, - WelcomeMessage: "Hello! 
I'm your AI assistant. How can I help you today?", - ProcessingMessage: DefaultWeComAIBotProcessingMessage, + Enabled: false, + BotID: "", + WebSocketURL: "wss://openws.work.weixin.qq.com", + SendThinkingMessage: true, + AllowFrom: FlexibleStringSlice{}, }, Weixin: WeixinConfig{ Enabled: false, diff --git a/pkg/config/example_security_usage.go b/pkg/config/example_security_usage.go index 0a6749537..42a1831b0 100644 --- a/pkg/config/example_security_usage.go +++ b/pkg/config/example_security_usage.go @@ -104,30 +104,30 @@ Note: Sensitive fields are omitted because they're loaded from .security.yml ], "channels": { "telegram": { - "enabled": true" + "enabled": true // token is automatically loaded from .security.yml }, "discord": { - "enabled": true" + "enabled": true // token is automatically loaded from .security.yml } }, "tools": { "web": { "brave": { - "enabled": true" + "enabled": true // api_key is automatically loaded from .security.yml }, "tavily": { - "enabled": true" + "enabled": true // api_key is automatically loaded from .security.yml }, "glm_search": { - "enabled": true" + "enabled": true // api_key is automatically loaded from .security.yml }, "baidu_search": { - "enabled": true" + "enabled": true // api_key is automatically loaded from .security.yml } } @@ -237,8 +237,6 @@ channels: nickserv_password: "value" sasl_password: "value" -``` - ## Web Tool API Keys **Brave, Tavily, Perplexity:** @@ -429,13 +427,13 @@ web: "tools": { "web": { "brave": { - "enabled": true" + "enabled": true }, "tavily": { - "enabled": true" + "enabled": true }, "glm_search": { - "enabled": true" + "enabled": true } } } diff --git a/pkg/config/migration_integration_test.go b/pkg/config/migration_integration_test.go index c884a6b5d..49d2a5831 100644 --- a/pkg/config/migration_integration_test.go +++ b/pkg/config/migration_integration_test.go @@ -566,3 +566,118 @@ func TestMigration_Integration_ModelNameField(t *testing.T) { t.Errorf("ModelFallbacks[0] = %q, want %q", 
cfg.Agents.Defaults.ModelFallbacks[0], "deepseek-chat") } } + +// TestMigration_PreservesExistingSecurityConfig tests that when migrating from v0 to v1, +// existing .security.yml values (e.g., loaded from environment variables) are preserved +// and not overwritten by empty values from the legacy config. +func TestMigration_PreservesExistingSecurityConfig(t *testing.T) { + tmpDir := t.TempDir() + configPath := filepath.Join(tmpDir, "config.json") + securityPath := filepath.Join(tmpDir, ".security.yml") + + // Create a legacy config (version 0) with model_list and channel config + // The model_list doesn't have api_keys, they should come from existing .security.yml + legacyConfig := `{ + "agents": { + "defaults": { + "provider": "openai", + "model": "gpt-4" + } + }, + "model_list": [ + { + "model_name": "openai", + "model": "openai/gpt-4" + } + ], + "channels": { + "telegram": { + "enabled": true + } + }, + "gateway": { + "host": "127.0.0.1", + "port": 18790 + }, + "tools": { + "web": {"enabled": true} + }, + "heartbeat": { + "enabled": true, + "interval": 30 + }, + "devices": { + "enabled": false + } + }` + + // Create an existing .security.yml with values that might come from env vars + existingSecurity := `model_list: + openai:0: + api_keys: + - sk-existing-key-from-env +channels: + telegram: + token: existing-telegram-token-from-env + discord: + token: existing-discord-token-from-env +web: + brave: + api_keys: + - existing-brave-key +` + + if err := os.WriteFile(configPath, []byte(legacyConfig), 0o600); err != nil { + t.Fatalf("Failed to write legacy config: %v", err) + } + + if err := os.WriteFile(securityPath, []byte(existingSecurity), 0o600); err != nil { + t.Fatalf("Failed to write existing security config: %v", err) + } + + // Load the config - this should trigger migration + cfg, err := LoadConfig(configPath) + if err != nil { + t.Fatalf("LoadConfig failed: %v", err) + } + + // Verify that the migrated config has the existing security values + // Telegram 
token should be preserved + if cfg.Channels.Telegram.Token() != "existing-telegram-token-from-env" { + t.Errorf("Telegram token was overwritten: got %q, want %q", + cfg.Channels.Telegram.Token(), "existing-telegram-token-from-env") + } + + // Discord token should be preserved (even though legacy config didn't have it) + if cfg.Channels.Discord.Token() != "existing-discord-token-from-env" { + t.Errorf("Discord token was overwritten: got %q, want %q", + cfg.Channels.Discord.Token(), "existing-discord-token-from-env") + } + + // Model API key should be preserved + if cfg.ModelList[0].APIKey() != "sk-existing-key-from-env" { + t.Errorf("Model API key was overwritten: got %q, want %q", + cfg.ModelList[0].APIKey(), "sk-existing-key-from-env") + } + + // Brave API key should be preserved + if cfg.Tools.Web.Brave.APIKey() != "existing-brave-key" { + t.Errorf("Brave API key was overwritten: got %q, want %q", + cfg.Tools.Web.Brave.APIKey(), "existing-brave-key") + } + + // Reload the security config from disk to verify it wasn't corrupted + reloadedSec, err := loadSecurityConfig(securityPath) + if err != nil { + t.Fatalf("Failed to reload security config: %v", err) + } + + if reloadedSec.Channels.Telegram == nil || + reloadedSec.Channels.Telegram.Token != "existing-telegram-token-from-env" { + t.Error("Telegram token not preserved in .security.yml file") + } + + if reloadedSec.Channels.Discord == nil || reloadedSec.Channels.Discord.Token != "existing-discord-token-from-env" { + t.Error("Discord token not preserved in .security.yml file") + } +} diff --git a/pkg/config/multikey_test.go b/pkg/config/multikey_test.go index cc529905c..c17fcc53b 100644 --- a/pkg/config/multikey_test.go +++ b/pkg/config/multikey_test.go @@ -232,6 +232,78 @@ func TestExpandMultiKeyModels_PreservesOtherFields(t *testing.T) { } } +func TestExpandMultiKeyModels_IsVirtualFlag(t *testing.T) { + models := []*ModelConfig{ + { + ModelName: "gpt-4", + Model: "openai/gpt-4o", + apiKeys: []string{"key1", 
"key2", "key3"}, + }, + } + + result := expandMultiKeyModels(models) + + // Should expand to 3 models + if len(result) != 3 { + t.Fatalf("expected 3 models, got %d", len(result)) + } + + // Primary model should NOT be virtual + primary := result[2] + if primary.isVirtual { + t.Errorf("primary model should not be virtual") + } + if primary.ModelName != "gpt-4" { + t.Errorf("expected primary model_name 'gpt-4', got %q", primary.ModelName) + } + + // Virtual models should have isVirtual = true + virtual1 := result[0] + if !virtual1.isVirtual { + t.Errorf("gpt-4__key_1 should be virtual") + } + if virtual1.ModelName != "gpt-4__key_1" { + t.Errorf("expected virtual model_name 'gpt-4__key_1', got %q", virtual1.ModelName) + } + + virtual2 := result[1] + if !virtual2.isVirtual { + t.Errorf("gpt-4__key_2 should be virtual") + } + if virtual2.ModelName != "gpt-4__key_2" { + t.Errorf("expected virtual model_name 'gpt-4__key_2', got %q", virtual2.ModelName) + } + + // IsVirtual() method should work + if !virtual1.IsVirtual() { + t.Errorf("IsVirtual() should return true for virtual model") + } + if primary.IsVirtual() { + t.Errorf("IsVirtual() should return false for primary model") + } +} + +func TestExpandMultiKeyModels_SingleKey_NotVirtual(t *testing.T) { + models := []*ModelConfig{ + { + ModelName: "gpt-4", + Model: "openai/gpt-4o", + apiKeys: []string{"single-key"}, + }, + } + + result := expandMultiKeyModels(models) + + if len(result) != 1 { + t.Fatalf("expected 1 model, got %d", len(result)) + } + + // Single key model should NOT be virtual + if result[0].isVirtual { + t.Errorf("single key model should not be virtual") + } +} + func TestMergeAPIKeys(t *testing.T) { tests := []struct { name string diff --git a/pkg/config/security.go b/pkg/config/security.go index 816d465c7..47ad1a5b0 100644 --- a/pkg/config/security.go +++ b/pkg/config/security.go @@ -25,13 +25,32 @@ const ( SecurityConfigFile = ".security.yml" ) +func normalizeSecurityConfig(sec *SecurityConfig) 
*SecurityConfig { + if sec == nil { + sec = &SecurityConfig{} + } + if sec.ModelList == nil { + sec.ModelList = map[string]ModelSecurityEntry{} + } + if sec.Channels == nil { + sec.Channels = &ChannelsSecurity{} + } + if sec.Web == nil { + sec.Web = &WebToolsSecurity{} + } + if sec.Skills == nil { + sec.Skills = &SkillsSecurity{} + } + return sec +} + // SecurityConfig stores all sensitive data (API keys, tokens, secrets, passwords) // This data is loaded from security.yml and kept separate from the main config type SecurityConfig struct { // Model API keys. Map key is model_name, can include suffix like "abc:0", "abc:1" // for load balancing with same model_name. The suffix ":N" is used to distinguish // multiple configs that share the same base model_name. - ModelList map[string]ModelSecurityEntry `yaml:"model_list,omitempty"` + ModelList map[string]ModelSecurityEntry `yaml:"model_list"` // Channel tokens/secrets Channels *ChannelsSecurity `yaml:"channels,omitempty"` @@ -50,21 +69,19 @@ type ModelSecurityEntry struct { // ChannelsSecurity stores channel-related security data type ChannelsSecurity struct { - Telegram *TelegramSecurity `yaml:"telegram,omitempty"` - Feishu *FeishuSecurity `yaml:"feishu,omitempty"` - Discord *DiscordSecurity `yaml:"discord,omitempty"` - Weixin *WeixinSecurity `yaml:"weixin,omitempty"` - QQ *QQSecurity `yaml:"qq,omitempty"` - DingTalk *DingTalkSecurity `yaml:"dingtalk,omitempty"` - Slack *SlackSecurity `yaml:"slack,omitempty"` - Matrix *MatrixSecurity `yaml:"matrix,omitempty"` - LINE *LINESecurity `yaml:"line,omitempty"` - OneBot *OneBotSecurity `yaml:"onebot,omitempty"` - WeCom *WeComSecurity `yaml:"wecom,omitempty"` - WeComApp *WeComAppSecurity `yaml:"wecom_app,omitempty"` - WeComAIBot *WeComAIBotSecurity `yaml:"wecom_aibot,omitempty"` - Pico *PicoSecurity `yaml:"pico,omitempty"` - IRC *IRCSecurity `yaml:"irc,omitempty"` + Telegram *TelegramSecurity `yaml:"telegram,omitempty"` + Feishu *FeishuSecurity `yaml:"feishu,omitempty"` + 
Discord *DiscordSecurity `yaml:"discord,omitempty"` + Weixin *WeixinSecurity `yaml:"weixin,omitempty"` + QQ *QQSecurity `yaml:"qq,omitempty"` + DingTalk *DingTalkSecurity `yaml:"dingtalk,omitempty"` + Slack *SlackSecurity `yaml:"slack,omitempty"` + Matrix *MatrixSecurity `yaml:"matrix,omitempty"` + LINE *LINESecurity `yaml:"line,omitempty"` + OneBot *OneBotSecurity `yaml:"onebot,omitempty"` + WeCom *WeComSecurity `yaml:"wecom,omitempty"` + Pico *PicoSecurity `yaml:"pico,omitempty"` + IRC *IRCSecurity `yaml:"irc,omitempty"` } type TelegramSecurity struct { @@ -112,20 +129,7 @@ type OneBotSecurity struct { } type WeComSecurity struct { - Token string `yaml:"token,omitempty" env:"PICOCLAW_CHANNELS_WECOM_TOKEN"` - EncodingAESKey string `yaml:"encoding_aes_key,omitempty" env:"PICOCLAW_CHANNELS_WECOM_ENCODING_AES_KEY"` -} - -type WeComAppSecurity struct { - CorpSecret string `yaml:"corp_secret,omitempty" env:"PICOCLAW_CHANNELS_WECOM_APP_CORP_SECRET"` - Token string `yaml:"token,omitempty" env:"PICOCLAW_CHANNELS_WECOM_APP_TOKEN"` - EncodingAESKey string `yaml:"encoding_aes_key,omitempty" env:"PICOCLAW_CHANNELS_WECOM_APP_ENCODING_AES_KEY"` -} - -type WeComAIBotSecurity struct { - Secret string `yaml:"secret,omitempty" env:"PICOCLAW_CHANNELS_WECOM_AIBOT_SECRET"` - Token string `yaml:"token,omitempty" env:"PICOCLAW_CHANNELS_WECOM_AIBOT_TOKEN"` - EncodingAESKey string `yaml:"encoding_aes_key,omitempty" env:"PICOCLAW_CHANNELS_WECOM_AIBOT_ENCODING_AES_KEY"` + Secret string `yaml:"secret,omitempty" env:"PICOCLAW_CHANNELS_WECOM_SECRET"` } type PicoSecurity struct { @@ -191,7 +195,7 @@ func loadSecurityConfig(securityPath string) (*SecurityConfig, error) { data, err := os.ReadFile(securityPath) if err != nil { if os.IsNotExist(err) { - return &SecurityConfig{}, nil + return normalizeSecurityConfig(nil), nil } return nil, fmt.Errorf("failed to read security config: %w", err) } @@ -210,7 +214,7 @@ func loadSecurityConfig(securityPath string) (*SecurityConfig, error) { return nil, 
err } - return &sec, nil + return normalizeSecurityConfig(&sec), nil } // saveSecurityConfig saves the security configuration to security.yml @@ -225,6 +229,134 @@ func saveSecurityConfig(securityPath string, sec *SecurityConfig) error { return fileutil.WriteFileAtomic(securityPath, buf.Bytes(), 0o600) } +// mergeSecurityConfig merges two SecurityConfig instances, preferring non-empty values from 'newer'. +// This is used during config migration to preserve existing security data while adding new entries. +func mergeSecurityConfig(existing, newer *SecurityConfig) *SecurityConfig { + if existing == nil { + return normalizeSecurityConfig(newer) + } + if newer == nil { + return normalizeSecurityConfig(existing) + } + + result := normalizeSecurityConfig(nil) + + // Merge ModelList: prefer newer if it has keys, otherwise use existing + for k, v := range existing.ModelList { + result.ModelList[k] = v + } + for k, v := range newer.ModelList { + if len(v.APIKeys) > 0 { + result.ModelList[k] = v + } + } + + // Merge Channels + if existing.Channels != nil { + result.Channels = existing.Channels + } + if newer.Channels != nil { + if result.Channels == nil { + result.Channels = &ChannelsSecurity{} + } + mergeChannelsSecurity(result.Channels, newer.Channels) + } + + // Merge Web + if existing.Web != nil { + result.Web = existing.Web + } + if newer.Web != nil { + if result.Web == nil { + result.Web = &WebToolsSecurity{} + } + mergeWebToolsSecurity(result.Web, newer.Web) + } + + // Merge Skills + if existing.Skills != nil { + result.Skills = existing.Skills + } + if newer.Skills != nil { + if result.Skills == nil { + result.Skills = &SkillsSecurity{} + } + mergeSkillsSecurity(result.Skills, newer.Skills) + } + + return result +} + +func mergeChannelsSecurity(dst, src *ChannelsSecurity) { + if src.Telegram != nil && src.Telegram.Token != "" { + dst.Telegram = src.Telegram + } + if src.Feishu != nil && + (src.Feishu.AppSecret != "" || src.Feishu.EncryptKey != "" || 
src.Feishu.VerificationToken != "") { + dst.Feishu = src.Feishu + } + if src.Discord != nil && src.Discord.Token != "" { + dst.Discord = src.Discord + } + if src.Weixin != nil && src.Weixin.Token != "" { + dst.Weixin = src.Weixin + } + if src.QQ != nil && src.QQ.AppSecret != "" { + dst.QQ = src.QQ + } + if src.DingTalk != nil && src.DingTalk.ClientSecret != "" { + dst.DingTalk = src.DingTalk + } + if src.Slack != nil && (src.Slack.BotToken != "" || src.Slack.AppToken != "") { + dst.Slack = src.Slack + } + if src.Matrix != nil && src.Matrix.AccessToken != "" { + dst.Matrix = src.Matrix + } + if src.LINE != nil && (src.LINE.ChannelSecret != "" || src.LINE.ChannelAccessToken != "") { + dst.LINE = src.LINE + } + if src.OneBot != nil && src.OneBot.AccessToken != "" { + dst.OneBot = src.OneBot + } + if src.WeCom != nil && src.WeCom.Secret != "" { + dst.WeCom = src.WeCom + } + if src.Pico != nil && src.Pico.Token != "" { + dst.Pico = src.Pico + } + if src.IRC != nil && (src.IRC.Password != "" || src.IRC.NickServPassword != "" || src.IRC.SASLPassword != "") { + dst.IRC = src.IRC + } +} + +func mergeWebToolsSecurity(dst, src *WebToolsSecurity) { + if src.Brave != nil && len(src.Brave.APIKeys) > 0 { + dst.Brave = src.Brave + } + if src.Tavily != nil && len(src.Tavily.APIKeys) > 0 { + dst.Tavily = src.Tavily + } + if src.Perplexity != nil && len(src.Perplexity.APIKeys) > 0 { + dst.Perplexity = src.Perplexity + } + if src.GLMSearch != nil && src.GLMSearch.APIKey != "" { + dst.GLMSearch = src.GLMSearch + } + if src.BaiduSearch != nil && src.BaiduSearch.APIKey != "" { + dst.BaiduSearch = src.BaiduSearch + } +} + +func mergeSkillsSecurity(dst, src *SkillsSecurity) { + if src.Github != nil && src.Github.Token != "" { + dst.Github = src.Github + } + if src.ClawHub != nil && src.ClawHub.AuthToken != "" { + dst.ClawHub = src.ClawHub + } +} + // SensitiveDataCache caches the compiled regex for filtering sensitive data. 
// SensitiveDataCache caches the strings.Replacer for filtering sensitive data. // Computed once on first access via sync.Once. diff --git a/pkg/config/security_integration_test.go b/pkg/config/security_integration_test.go index c1e1a2340..03990ce5b 100644 --- a/pkg/config/security_integration_test.go +++ b/pkg/config/security_integration_test.go @@ -17,13 +17,12 @@ import ( // Test JSON unmarshal of private fields func TestJSONUnmarshalPrivateFields(t *testing.T) { - //nolint: govet type testStruct struct { PublicField string `json:"public"` - privateField string `json:"private"` + privateField string } - data := `{"public": "pub", "private": "priv"}` + data := `{"public": "pub", "privateField": "priv"}` var s testStruct if err := json.Unmarshal([]byte(data), &s); err != nil { t.Fatalf("JSON unmarshal failed: %v", err) @@ -35,9 +34,8 @@ func TestJSONUnmarshalPrivateFields(t *testing.T) { if s.PublicField != "pub" { t.Errorf("PublicField = %q, want 'pub'", s.PublicField) } - // This should fail because privateField is unexported - if s.privateField != "priv" { - t.Logf("privateField = %q, want 'priv' - THIS IS EXPECTED TO FAIL", s.privateField) + if s.privateField != "" { + t.Errorf("privateField = %q, want empty because unexported fields are ignored", s.privateField) } } @@ -242,15 +240,7 @@ func TestAllSecurityKeysAccessible(t *testing.T) { }, "wecom": { "enabled": true, - "webhook_url": "https://qyapi.weixin.qq.com/cgi-bin/webhook" - }, - "wecom_app": { - "enabled": true, - "corp_id": "test_corp_id", - "agent_id": 123456 - }, - "wecom_aibot": { - "enabled": true + "bot_id": "test_wecom_bot_id" }, "pico": { "enabled": true @@ -317,15 +307,7 @@ channels: onebot: access_token: "onebot_test_access_token" wecom: - token: "wecom_test_webhook_token" - encoding_aes_key: "wecom_test_aes_key" - wecom_app: - corp_secret: "wecom_app_test_corp_secret" - token: "wecom_app_test_token" - encoding_aes_key: "wecom_app_test_aes_key" - wecom_aibot: - token: "wecom_aibot_test_token" 
- encoding_aes_key: "wecom_aibot_test_aes_key" + secret: "wecom_test_secret" pico: token: "pico_test_token" irc: @@ -411,24 +393,10 @@ skills: t.Logf("OneBot AccessToken(): %s", cfg.Channels.OneBot.AccessToken()) // WeCom - assert.Equal(t, "wecom_test_webhook_token", cfg.Channels.WeCom.Token()) - assert.Equal(t, "wecom_test_aes_key", cfg.Channels.WeCom.EncodingAESKey()) - t.Logf("WeCom Token(): %s", cfg.Channels.WeCom.Token()) - t.Logf("WeCom EncodingAESKey(): %s", cfg.Channels.WeCom.EncodingAESKey()) - - // WeCom App - assert.Equal(t, "wecom_app_test_corp_secret", cfg.Channels.WeComApp.CorpSecret()) - assert.Equal(t, "wecom_app_test_token", cfg.Channels.WeComApp.Token()) - assert.Equal(t, "wecom_app_test_aes_key", cfg.Channels.WeComApp.EncodingAESKey()) - t.Logf("WeComApp CorpSecret(): %s", cfg.Channels.WeComApp.CorpSecret()) - t.Logf("WeComApp Token(): %s", cfg.Channels.WeComApp.Token()) - t.Logf("WeComApp EncodingAESKey(): %s", cfg.Channels.WeComApp.EncodingAESKey()) - - // WeCom AI Bot - assert.Equal(t, "wecom_aibot_test_token", cfg.Channels.WeComAIBot.Token()) - assert.Equal(t, "wecom_aibot_test_aes_key", cfg.Channels.WeComAIBot.EncodingAESKey()) - t.Logf("WeComAIBot Token(): %s", cfg.Channels.WeComAIBot.Token()) - t.Logf("WeComAIBot EncodingAESKey(): %s", cfg.Channels.WeComAIBot.EncodingAESKey()) + assert.Equal(t, "test_wecom_bot_id", cfg.Channels.WeCom.BotID) + assert.Equal(t, "wecom_test_secret", cfg.Channels.WeCom.Secret()) + t.Logf("WeCom BotID: %s", cfg.Channels.WeCom.BotID) + t.Logf("WeCom Secret(): %s", cfg.Channels.WeCom.Secret()) // Pico assert.Equal(t, "pico_test_token", cfg.Channels.Pico.Token()) diff --git a/pkg/config/security_test.go b/pkg/config/security_test.go index af08a67db..0f260ed59 100644 --- a/pkg/config/security_test.go +++ b/pkg/config/security_test.go @@ -20,6 +20,9 @@ func TestSecurityConfig(t *testing.T) { require.NoError(t, err) assert.NotNil(t, sec) assert.Empty(t, sec.ModelList) + assert.NotNil(t, sec.Channels) + 
assert.NotNil(t, sec.Web) + assert.NotNil(t, sec.Skills) }) } diff --git a/pkg/gateway/channel_matrix.go b/pkg/gateway/channel_matrix.go new file mode 100644 index 000000000..f753c60e2 --- /dev/null +++ b/pkg/gateway/channel_matrix.go @@ -0,0 +1,21 @@ +//go:build !mipsle && !netbsd + +package gateway + +import ( + // Matrix currently pulls in mautrix crypto and modernc sqlite transitively. + // + // We exclude it on: + // - linux/mipsle: mautrix crypto falls back to libolm when the `goolm` build + // tag is unavailable, and modernc.org/sqlite/modernc.org/libc also lacks a + // working build path for our mipsle + softfloat target. + // - netbsd/*: modernc.org/sqlite v1.46.1 fails to compile due to broken + // generated mutex code on NetBSD (for example sqlite_netbsd_amd64.go calls + // mu.enter/mu.leave, but the generated mutex type does not define them). + // + // This means Matrix is currently unavailable on those targets. The proper + // long-term fix is to split Matrix basic support from its E2EE/sqlite-backed + // crypto path, or to upgrade/replace the upstream sqlite dependency once the + // affected targets are supported. 
+ _ "github.com/sipeed/picoclaw/pkg/channels/matrix" +) diff --git a/pkg/gateway/gateway.go b/pkg/gateway/gateway.go index fc2465747..03d7dfe0c 100644 --- a/pkg/gateway/gateway.go +++ b/pkg/gateway/gateway.go @@ -20,7 +20,6 @@ import ( _ "github.com/sipeed/picoclaw/pkg/channels/irc" _ "github.com/sipeed/picoclaw/pkg/channels/line" _ "github.com/sipeed/picoclaw/pkg/channels/maixcam" - _ "github.com/sipeed/picoclaw/pkg/channels/matrix" _ "github.com/sipeed/picoclaw/pkg/channels/onebot" _ "github.com/sipeed/picoclaw/pkg/channels/pico" _ "github.com/sipeed/picoclaw/pkg/channels/qq" diff --git a/pkg/migrate/sources/openclaw/common.go b/pkg/migrate/sources/openclaw/common.go index 337c950d0..938f15b80 100644 --- a/pkg/migrate/sources/openclaw/common.go +++ b/pkg/migrate/sources/openclaw/common.go @@ -13,17 +13,16 @@ var migrateableDirs = []string{ } var supportedChannels = map[string]bool{ - "whatsapp": true, - "telegram": true, - "feishu": true, - "discord": true, - "maixcam": true, - "qq": true, - "dingtalk": true, - "slack": true, - "matrix": true, - "line": true, - "onebot": true, - "wecom": true, - "wecom_app": true, + "whatsapp": true, + "telegram": true, + "feishu": true, + "discord": true, + "maixcam": true, + "qq": true, + "dingtalk": true, + "slack": true, + "matrix": true, + "line": true, + "onebot": true, + "wecom": true, } diff --git a/pkg/tools/mcp_tool.go b/pkg/tools/mcp_tool.go index 6e53cf354..5bffb4e89 100644 --- a/pkg/tools/mcp_tool.go +++ b/pkg/tools/mcp_tool.go @@ -5,9 +5,13 @@ import ( "encoding/json" "fmt" "hash/fnv" + "os" "strings" + "time" "github.com/modelcontextprotocol/go-sdk/mcp" + + "github.com/sipeed/picoclaw/pkg/media" ) // MCPManager defines the interface for MCP manager operations @@ -25,6 +29,7 @@ type MCPTool struct { manager MCPManager serverName string tool *mcp.Tool + mediaStore media.MediaStore } // NewMCPTool creates a new MCP tool wrapper @@ -36,6 +41,10 @@ func NewMCPTool(manager MCPManager, serverName string, tool *mcp.Tool) 
*MCPTool } } +func (t *MCPTool) SetMediaStore(store media.MediaStore) { + t.mediaStore = store +} + // sanitizeIdentifierComponent normalizes a string so it can be safely used // as part of a tool/function identifier for downstream providers. // It: @@ -218,13 +227,7 @@ func (t *MCPTool) Execute(ctx context.Context, args map[string]any) *ToolResult WithError(fmt.Errorf("MCP tool error: %s", errMsg)) } - // Extract text content from result - output := extractContentText(result.Content) - - return &ToolResult{ - ForLLM: output, - IsError: false, - } + return t.normalizeResultContent(ctx, result.Content) } // extractContentText extracts text from MCP content array @@ -233,14 +236,269 @@ func extractContentText(content []mcp.Content) string { for _, c := range content { switch v := c.(type) { case *mcp.TextContent: - parts = append(parts, v.Text) + parts = append(parts, sanitizeToolLLMContent(v.Text)) case *mcp.ImageContent: - // For images, just indicate that an image was returned - parts = append(parts, fmt.Sprintf("[Image: %s]", v.MIMEType)) + parts = append(parts, fmt.Sprintf("[Image: %s]", normalizedMIMEType(v.MIMEType))) + case *mcp.AudioContent: + parts = append(parts, fmt.Sprintf("[Audio: %s]", normalizedMIMEType(v.MIMEType))) + case *mcp.ResourceLink: + parts = append(parts, summarizeResourceLink(v)) + case *mcp.EmbeddedResource: + parts = append(parts, summarizeEmbeddedResource(v)) default: // For other content types, use string representation parts = append(parts, fmt.Sprintf("[Content: %T]", v)) } } - return strings.Join(parts, "\n") + return sanitizeToolLLMContent(strings.Join(parts, "\n")) +} + +func (t *MCPTool) normalizeResultContent(ctx context.Context, content []mcp.Content) *ToolResult { + llmParts := make([]string, 0, len(content)) + mediaRefs := make([]string, 0, len(content)) + + for _, c := range content { + switch v := c.(type) { + case *mcp.TextContent: + text := strings.TrimSpace(sanitizeToolLLMContent(v.Text)) + if text != "" { + llmParts = 
append(llmParts, text) + } + case *mcp.ImageContent: + ref, note := t.storeBinaryContent( + ctx, + "image", + normalizedMIMEType(v.MIMEType), + v.Data, + v.Annotations, + ) + if ref != "" { + mediaRefs = append(mediaRefs, ref) + } + if note != "" { + llmParts = append(llmParts, note) + } + case *mcp.AudioContent: + ref, note := t.storeBinaryContent( + ctx, + "audio", + normalizedMIMEType(v.MIMEType), + v.Data, + v.Annotations, + ) + if ref != "" { + mediaRefs = append(mediaRefs, ref) + } + if note != "" { + llmParts = append(llmParts, note) + } + case *mcp.ResourceLink: + llmParts = append(llmParts, summarizeResourceLink(v)) + case *mcp.EmbeddedResource: + ref, note := t.storeEmbeddedResource(ctx, v) + if ref != "" { + mediaRefs = append(mediaRefs, ref) + } + if note != "" { + llmParts = append(llmParts, note) + } + default: + llmParts = append(llmParts, fmt.Sprintf("[MCP returned unsupported content type %T]", v)) + } + } + + result := &ToolResult{ + ForLLM: strings.Join(compactStrings(llmParts), "\n"), + Media: mediaRefs, + } + return result +} + +func (t *MCPTool) storeEmbeddedResource(ctx context.Context, content *mcp.EmbeddedResource) (string, string) { + if content == nil || content.Resource == nil { + return "", "[MCP returned an embedded resource without data.]" + } + + resource := content.Resource + if len(resource.Blob) > 0 { + return t.storeBinaryContent( + ctx, + "resource", + normalizedMIMEType(resource.MIMEType), + resource.Blob, + content.Annotations, + ) + } + + if strings.TrimSpace(resource.Text) != "" { + return "", sanitizeToolLLMContent(resource.Text) + } + + return "", summarizeEmbeddedResource(content) +} + +func (t *MCPTool) storeBinaryContent( + ctx context.Context, + kind string, + mimeType string, + data []byte, + annotations *mcp.Annotations, +) (string, string) { + if len(data) == 0 { + return "", fmt.Sprintf("[MCP returned %s content (%s) but it was empty.]", kind, mimeType) + } + if !annotationsAllowUser(annotations) { + return "", 
fmt.Sprintf( + "[MCP returned %s content (%s) for non-user audience; omitted from model context.]", + kind, + mimeType, + ) + } + if t.mediaStore == nil { + return "", fmt.Sprintf( + "[MCP returned %s content (%s); omitted from model context because media delivery is unavailable.]", + kind, + mimeType, + ) + } + + channel := ToolChannel(ctx) + chatID := ToolChatID(ctx) + if channel == "" || chatID == "" { + return "", fmt.Sprintf( + "[MCP returned %s content (%s); omitted from model context because no target chat was available.]", + kind, + mimeType, + ) + } + + dir := media.TempDir() + if err := os.MkdirAll(dir, 0o700); err != nil { + return "", fmt.Sprintf("[MCP returned %s content (%s) but it could not be stored.]", kind, mimeType) + } + + ext := extensionForMIMEType(mimeType) + tmpFile, err := os.CreateTemp(dir, "mcp-*"+ext) + if err != nil { + return "", fmt.Sprintf("[MCP returned %s content (%s) but it could not be stored.]", kind, mimeType) + } + tmpPath := tmpFile.Name() + if _, err = tmpFile.Write(data); err != nil { + _ = tmpFile.Close() + _ = os.Remove(tmpPath) + return "", fmt.Sprintf("[MCP returned %s content (%s) but it could not be stored.]", kind, mimeType) + } + if err = tmpFile.Close(); err != nil { + _ = os.Remove(tmpPath) + return "", fmt.Sprintf("[MCP returned %s content (%s) but it could not be stored.]", kind, mimeType) + } + + scope := fmt.Sprintf( + "tool:mcp:%s:%s:%s:%d", + sanitizeIdentifierComponent(t.serverName), + channel, + chatID, + time.Now().UnixNano(), + ) + filename := fmt.Sprintf( + "%s_%s%s", + sanitizeIdentifierComponent(t.serverName), + sanitizeIdentifierComponent(t.tool.Name), + ext, + ) + + ref, err := t.mediaStore.Store(tmpPath, media.MediaMeta{ + Filename: filename, + ContentType: mimeType, + Source: fmt.Sprintf( + "tool:mcp:%s:%s", + sanitizeIdentifierComponent(t.serverName), + sanitizeIdentifierComponent(t.tool.Name), + ), + }, scope) + if err != nil { + _ = os.Remove(tmpPath) + return "", fmt.Sprintf( + "[MCP returned 
%s content (%s) but it could not be registered as media.]", + kind, + mimeType, + ) + } + + return ref, fmt.Sprintf( + "[MCP returned %s content (%s); omitted from model context and stored as a local media artifact.]", + kind, + mimeType, + ) +} + +func summarizeResourceLink(content *mcp.ResourceLink) string { + if content == nil { + return "[MCP returned an empty resource link.]" + } + + parts := []string{"[MCP returned resource link"} + if content.Name != "" { + parts = append(parts, fmt.Sprintf("name=%q", content.Name)) + } + if content.URI != "" { + parts = append(parts, fmt.Sprintf("uri=%q", content.URI)) + } + if content.MIMEType != "" { + parts = append(parts, fmt.Sprintf("mime=%q", content.MIMEType)) + } + if content.Description != "" { + desc := strings.TrimSpace(content.Description) + if len(desc) > 200 { + desc = desc[:200] + "..." + } + parts = append(parts, fmt.Sprintf("description=%q", desc)) + } + return strings.Join(parts, ", ") + "]" +} + +func summarizeEmbeddedResource(content *mcp.EmbeddedResource) string { + if content == nil || content.Resource == nil { + return "[MCP returned an embedded resource.]" + } + + resource := content.Resource + if resource.URI != "" { + return fmt.Sprintf( + "[MCP returned embedded resource %q (%s).]", + resource.URI, + normalizedMIMEType(resource.MIMEType), + ) + } + return fmt.Sprintf("[MCP returned embedded resource (%s).]", normalizedMIMEType(resource.MIMEType)) +} + +func annotationsAllowUser(annotations *mcp.Annotations) bool { + if annotations == nil || len(annotations.Audience) == 0 { + return true + } + for _, audience := range annotations.Audience { + if strings.EqualFold(string(audience), "user") { + return true + } + } + return false +} + +func normalizedMIMEType(mimeType string) string { + if strings.TrimSpace(mimeType) == "" { + return "application/octet-stream" + } + return mimeType +} + +func compactStrings(parts []string) []string { + compact := make([]string, 0, len(parts)) + for _, part := range 
parts { + if strings.TrimSpace(part) == "" { + continue + } + compact = append(compact, part) + } + return compact } diff --git a/pkg/tools/mcp_tool_test.go b/pkg/tools/mcp_tool_test.go index 95bb0f992..8bbac3bc7 100644 --- a/pkg/tools/mcp_tool_test.go +++ b/pkg/tools/mcp_tool_test.go @@ -3,10 +3,14 @@ package tools import ( "context" "fmt" + "os" + "path/filepath" "strings" "testing" "github.com/modelcontextprotocol/go-sdk/mcp" + + "github.com/sipeed/picoclaw/pkg/media" ) // MockMCPManager is a mock implementation of MCPManager interface for testing @@ -490,3 +494,143 @@ func TestMCPTool_Parameters_MapSchema(t *testing.T) { t.Errorf("Name type should be 'string', got '%v'", nameParam["type"]) } } + +func TestMCPTool_Execute_ImageContentStoredAsMedia(t *testing.T) { + store := media.NewFileMediaStore() + manager := &MockMCPManager{ + callToolFunc: func(ctx context.Context, serverName, toolName string, arguments map[string]any) (*mcp.CallToolResult, error) { + return &mcp.CallToolResult{ + Content: []mcp.Content{ + &mcp.ImageContent{ + Data: []byte("fake-image-bytes"), + MIMEType: "image/png", + }, + }, + }, nil + }, + } + + mcpTool := NewMCPTool(manager, "screenshoto", &mcp.Tool{Name: "take_screenshot"}) + mcpTool.SetMediaStore(store) + + result := mcpTool.Execute(WithToolContext(context.Background(), "telegram", "chat-42"), nil) + + if result.IsError { + t.Fatalf("expected success, got %q", result.ForLLM) + } + if len(result.Media) != 1 { + t.Fatalf("expected 1 media ref, got %d", len(result.Media)) + } + if result.ResponseHandled { + t.Fatal("expected MCP image artifact not to mark response as handled") + } + if !strings.Contains(result.ForLLM, "stored as a local media artifact") { + t.Fatalf("expected local media artifact note, got %q", result.ForLLM) + } + + path, meta, err := store.ResolveWithMeta(result.Media[0]) + if err != nil { + t.Fatalf("expected stored media ref to resolve: %v", err) + } + if meta.ContentType != "image/png" { + t.Fatalf("expected 
image/png content type, got %q", meta.ContentType) + } + if filepath.Ext(path) != ".png" { + t.Fatalf("expected png temp file, got %q", path) + } + data, err := os.ReadFile(path) + if err != nil { + t.Fatalf("expected stored media file to be readable: %v", err) + } + if string(data) != "fake-image-bytes" { + t.Fatalf("expected stored media bytes to match input, got %q", string(data)) + } +} + +func TestMCPTool_Execute_EmbeddedResourceBlobStoredAsMedia(t *testing.T) { + store := media.NewFileMediaStore() + manager := &MockMCPManager{ + callToolFunc: func(ctx context.Context, serverName, toolName string, arguments map[string]any) (*mcp.CallToolResult, error) { + return &mcp.CallToolResult{ + Content: []mcp.Content{ + &mcp.EmbeddedResource{ + Resource: &mcp.ResourceContents{ + URI: "file:///tmp/report.png", + MIMEType: "image/png", + Blob: []byte("blob-bytes"), + }, + }, + }, + }, nil + }, + } + + mcpTool := NewMCPTool(manager, "grafana", &mcp.Tool{Name: "get_dashboard_image"}) + mcpTool.SetMediaStore(store) + + result := mcpTool.Execute(WithToolContext(context.Background(), "telegram", "chat-42"), nil) + + if len(result.Media) != 1 { + t.Fatalf("expected embedded resource blob to be stored as media, got %d refs", len(result.Media)) + } + path, _, err := store.ResolveWithMeta(result.Media[0]) + if err != nil { + t.Fatalf("expected stored media ref to resolve: %v", err) + } + data, err := os.ReadFile(path) + if err != nil { + t.Fatalf("expected stored media file to be readable: %v", err) + } + if string(data) != "blob-bytes" { + t.Fatalf("expected stored blob bytes to match input, got %q", string(data)) + } +} + +func TestMCPTool_Execute_RespectsUserAudienceForBinaryContent(t *testing.T) { + store := media.NewFileMediaStore() + manager := &MockMCPManager{ + callToolFunc: func(ctx context.Context, serverName, toolName string, arguments map[string]any) (*mcp.CallToolResult, error) { + return &mcp.CallToolResult{ + Content: []mcp.Content{ + &mcp.ImageContent{ + Data: 
[]byte("assistant-only"), + MIMEType: "image/png", + Annotations: &mcp.Annotations{Audience: []mcp.Role{"assistant"}}, + }, + }, + }, nil + }, + } + + mcpTool := NewMCPTool(manager, "screenshoto", &mcp.Tool{Name: "take_screenshot"}) + mcpTool.SetMediaStore(store) + + result := mcpTool.Execute(WithToolContext(context.Background(), "telegram", "chat-42"), nil) + + if len(result.Media) != 0 { + t.Fatalf("expected no media ref for non-user audience, got %d", len(result.Media)) + } + if !strings.Contains(result.ForLLM, "non-user audience") { + t.Fatalf("expected audience note, got %q", result.ForLLM) + } +} + +func TestMCPTool_Execute_LargeBase64TextIsOmittedFromContext(t *testing.T) { + manager := &MockMCPManager{ + callToolFunc: func(ctx context.Context, serverName, toolName string, arguments map[string]any) (*mcp.CallToolResult, error) { + return &mcp.CallToolResult{ + Content: []mcp.Content{ + &mcp.TextContent{Text: strings.Repeat("QUJD", 400)}, + }, + }, nil + }, + } + + mcpTool := NewMCPTool(manager, "test_server", &mcp.Tool{Name: "dump_payload"}) + + result := mcpTool.Execute(context.Background(), nil) + + if result.ForLLM != largeBase64OmittedMessage { + t.Fatalf("expected sanitized large base64 note, got %q", result.ForLLM) + } +} diff --git a/pkg/tools/normalization.go b/pkg/tools/normalization.go new file mode 100644 index 000000000..3a76c5d92 --- /dev/null +++ b/pkg/tools/normalization.go @@ -0,0 +1,292 @@ +package tools + +import ( + "encoding/base64" + "fmt" + "mime" + "os" + "path/filepath" + "regexp" + "strings" + "time" + "unicode" + + "github.com/sipeed/picoclaw/pkg/media" +) + +const ( + largeBase64OmittedMessage = "[Tool returned a large base64-like payload; omitted from model context.]" + inlineMediaOmittedMessage = "[Tool returned inline media content; omitted from model context.]" + inlineMediaStoredMessage = "[Tool returned inline media content (%s); omitted from model context and registered as a media attachment.]" +) + +var ( + 
inlineMarkdownDataURLRe = regexp.MustCompile(`!\[[^\]]*\]\((data:[^)]+)\)`) + inlineRawDataURLRe = regexp.MustCompile(`data:[^;\s]+;base64,[A-Za-z0-9+/=\r\n]+`) +) + +func normalizeToolResult( + result *ToolResult, + toolName string, + store media.MediaStore, + channel string, + chatID string, +) *ToolResult { + if result == nil { + return nil + } + + notes := make([]string, 0, 2) + seen := make(map[string]struct{}) + + if store != nil && channel != "" && chatID != "" { + var refs []string + var extractedNotes []string + + result.ForLLM, refs, extractedNotes = extractInlineMediaRefs( + result.ForLLM, + toolName, + store, + channel, + chatID, + seen, + ) + result.Media = append(result.Media, refs...) + notes = append(notes, extractedNotes...) + + result.ForUser, refs, extractedNotes = extractInlineMediaRefs( + result.ForUser, + toolName, + store, + channel, + chatID, + seen, + ) + result.Media = append(result.Media, refs...) + notes = append(notes, extractedNotes...) + } + + result.ForLLM = sanitizeToolLLMContent(result.ForLLM) + + if len(result.Media) > 0 && len(notes) > 0 { + if strings.TrimSpace(result.ForLLM) == "" { + result.ForLLM = strings.Join(notes, "\n") + } else { + result.ForLLM = strings.TrimSpace(result.ForLLM) + "\n" + strings.Join(notes, "\n") + } + } + if len(result.Media) > 0 && strings.TrimSpace(result.ForLLM) == "" { + result.ForLLM = "[Tool returned media content; omitted from model context and registered as a media attachment.]" + } + + return result +} + +func sanitizeToolLLMContent(text string) string { + trimmed := strings.TrimSpace(text) + if trimmed == "" { + return text + } + if inlineMarkdownDataURLRe.MatchString(trimmed) || inlineRawDataURLRe.MatchString(trimmed) { + cleaned := inlineMarkdownDataURLRe.ReplaceAllString(trimmed, "") + cleaned = inlineRawDataURLRe.ReplaceAllString(cleaned, "") + cleaned = strings.TrimSpace(cleaned) + if cleaned == "" { + return inlineMediaOmittedMessage + } + return cleaned + "\n" + 
inlineMediaOmittedMessage + } + if looksLikeLargeBase64Payload(trimmed) { + return largeBase64OmittedMessage + } + return text +} + +func looksLikeLargeBase64Payload(text string) bool { + trimmed := strings.TrimSpace(text) + if len(trimmed) < 1024 { + return false + } + + nonSpace := 0 + base64Like := 0 + spaceCount := 0 + + for _, r := range trimmed { + if unicode.IsSpace(r) { + spaceCount++ + continue + } + nonSpace++ + if (r >= 'A' && r <= 'Z') || + (r >= 'a' && r <= 'z') || + (r >= '0' && r <= '9') || + r == '+' || r == '/' || r == '=' { + base64Like++ + } + } + + if nonSpace == 0 { + return false + } + + ratio := float64(base64Like) / float64(nonSpace) + return ratio >= 0.97 && spaceCount <= len(trimmed)/128 +} + +func extractInlineMediaRefs( + text string, + toolName string, + store media.MediaStore, + channel string, + chatID string, + seen map[string]struct{}, +) (cleaned string, refs []string, notes []string) { + cleaned = text + + matches := inlineMarkdownDataURLRe.FindAllStringSubmatch(cleaned, -1) + for _, match := range matches { + if len(match) < 2 { + continue + } + dataURL := match[1] + ref, note := storeInlineDataURL(toolName, store, channel, chatID, dataURL, seen) + if ref != "" { + refs = append(refs, ref) + } + if note != "" { + notes = append(notes, note) + } + cleaned = strings.ReplaceAll(cleaned, match[0], "") + } + + rawMatches := inlineRawDataURLRe.FindAllString(cleaned, -1) + for _, dataURL := range rawMatches { + ref, note := storeInlineDataURL(toolName, store, channel, chatID, dataURL, seen) + if ref != "" { + refs = append(refs, ref) + } + if note != "" { + notes = append(notes, note) + } + cleaned = strings.ReplaceAll(cleaned, dataURL, "") + } + + return strings.TrimSpace(cleaned), refs, notes +} + +func storeInlineDataURL( + toolName string, + store media.MediaStore, + channel string, + chatID string, + dataURL string, + seen map[string]struct{}, +) (ref string, note string) { + dataURL = strings.TrimSpace(dataURL) + if _, ok := 
seen[dataURL]; ok { + return "", "" + } + seen[dataURL] = struct{}{} + + if !strings.HasPrefix(strings.ToLower(dataURL), "data:") { + return "", "" + } + + comma := strings.IndexByte(dataURL, ',') + if comma <= 5 { + return "", "[Tool returned inline media content that could not be parsed.]" + } + + metaPart := dataURL[:comma] + payload := dataURL[comma+1:] + if !strings.Contains(strings.ToLower(metaPart), ";base64") { + return "", "[Tool returned inline media content that was not base64-encoded.]" + } + + mimeType := strings.TrimSpace(strings.TrimPrefix(metaPart, "data:")) + if semi := strings.IndexByte(mimeType, ';'); semi >= 0 { + mimeType = mimeType[:semi] + } + if mimeType == "" { + mimeType = "application/octet-stream" + } + + payload = strings.NewReplacer("\n", "", "\r", "", "\t", "", " ", "").Replace(payload) + decoded, err := base64.StdEncoding.DecodeString(payload) + if err != nil { + return "", fmt.Sprintf("[Tool returned inline media content (%s) that could not be decoded.]", mimeType) + } + + dir := media.TempDir() + if err = os.MkdirAll(dir, 0o700); err != nil { + return "", fmt.Sprintf("[Tool returned inline media content (%s) but it could not be stored.]", mimeType) + } + + ext := extensionForMIMEType(mimeType) + tmpFile, err := os.CreateTemp(dir, "tool-inline-*"+ext) + if err != nil { + return "", fmt.Sprintf("[Tool returned inline media content (%s) but it could not be stored.]", mimeType) + } + tmpPath := tmpFile.Name() + if _, err = tmpFile.Write(decoded); err != nil { + tmpFile.Close() + _ = os.Remove(tmpPath) + return "", fmt.Sprintf("[Tool returned inline media content (%s) but it could not be stored.]", mimeType) + } + if err = tmpFile.Close(); err != nil { + _ = os.Remove(tmpPath) + return "", fmt.Sprintf("[Tool returned inline media content (%s) but it could not be stored.]", mimeType) + } + + filename := sanitizeIdentifierComponent(toolName) + ext + scope := fmt.Sprintf( + "tool:inline:%s:%s:%s:%d", + 
sanitizeIdentifierComponent(toolName), + channel, + chatID, + time.Now().UnixNano(), + ) + + ref, err = store.Store(tmpPath, media.MediaMeta{ + Filename: filename, + ContentType: mimeType, + Source: fmt.Sprintf("tool:inline:%s", sanitizeIdentifierComponent(toolName)), + }, scope) + if err != nil { + _ = os.Remove(tmpPath) + return "", fmt.Sprintf("[Tool returned inline media content (%s) but it could not be registered.]", mimeType) + } + + return ref, fmt.Sprintf(inlineMediaStoredMessage, mimeType) +} + +func extensionForMIMEType(mimeType string) string { + if mimeType == "" { + return ".bin" + } + if exts, err := mime.ExtensionsByType(mimeType); err == nil && len(exts) > 0 { + return exts[0] + } + + switch strings.ToLower(mimeType) { + case "image/jpeg": + return ".jpg" + case "image/png": + return ".png" + case "image/gif": + return ".gif" + case "image/webp": + return ".webp" + case "audio/wav", "audio/x-wav": + return ".wav" + case "audio/mpeg": + return ".mp3" + case "audio/ogg": + return ".ogg" + case "video/mp4": + return ".mp4" + default: + return filepath.Ext(mimeType) + } +} diff --git a/pkg/tools/registry.go b/pkg/tools/registry.go index ed373a28f..56af8d695 100644 --- a/pkg/tools/registry.go +++ b/pkg/tools/registry.go @@ -9,6 +9,7 @@ import ( "time" "github.com/sipeed/picoclaw/pkg/logger" + "github.com/sipeed/picoclaw/pkg/media" "github.com/sipeed/picoclaw/pkg/providers" ) @@ -19,9 +20,14 @@ type ToolEntry struct { } type ToolRegistry struct { - tools map[string]*ToolEntry - mu sync.RWMutex - version atomic.Uint64 // incremented on Register/RegisterHidden for cache invalidation + tools map[string]*ToolEntry + mu sync.RWMutex + version atomic.Uint64 // incremented on Register/RegisterHidden for cache invalidation + mediaStore media.MediaStore +} + +type mediaStoreAware interface { + SetMediaStore(store media.MediaStore) } func NewToolRegistry() *ToolRegistry { @@ -43,6 +49,9 @@ func (r *ToolRegistry) Register(tool Tool) { IsCore: true, TTL: 0, // Core 
tools do not use TTL } + if aware, ok := tool.(mediaStoreAware); ok && r.mediaStore != nil { + aware.SetMediaStore(r.mediaStore) + } r.version.Add(1) logger.DebugCF("tools", "Registered core tool", map[string]any{"name": name}) } @@ -61,10 +70,27 @@ func (r *ToolRegistry) RegisterHidden(tool Tool) { IsCore: false, TTL: 0, } + if aware, ok := tool.(mediaStoreAware); ok && r.mediaStore != nil { + aware.SetMediaStore(r.mediaStore) + } r.version.Add(1) logger.DebugCF("tools", "Registered hidden tool", map[string]any{"name": name}) } +// SetMediaStore injects a MediaStore into all registered tools that can +// consume it, and remembers it for future registrations. +func (r *ToolRegistry) SetMediaStore(store media.MediaStore) { + r.mu.Lock() + defer r.mu.Unlock() + + r.mediaStore = store + for _, entry := range r.tools { + if aware, ok := entry.Tool.(mediaStoreAware); ok { + aware.SetMediaStore(store) + } + } +} + // PromoteTools atomically sets the TTL for multiple non-core tools. // This prevents a concurrent TickTTL from decrementing between promotions. func (r *ToolRegistry) PromoteTools(names []string, ttl int) { @@ -180,6 +206,14 @@ func (r *ToolRegistry) ExecuteWithContext( return ErrorResult(fmt.Sprintf("tool %q not found", name)).WithError(fmt.Errorf("tool not found")) } + // Validate arguments against the tool's declared schema. + if err := validateToolArgs(tool.Parameters(), args); err != nil { + logger.WarnCF("tool", "Tool argument validation failed", + map[string]any{"tool": name, "error": err.Error()}) + return ErrorResult(fmt.Sprintf("invalid arguments for tool %q: %s", name, err)). + WithError(fmt.Errorf("argument validation failed: %w", err)) + } + // Inject channel/chatID into ctx so tools read them via ToolChannel(ctx)/ToolChatID(ctx). // Always inject — tools validate what they require. 
ctx = WithToolContext(ctx, channel, chatID) @@ -230,6 +264,8 @@ func (r *ToolRegistry) ExecuteWithContext( } } + result = normalizeToolResult(result, name, r.mediaStore, channel, chatID) + duration := time.Since(start) // Log based on result type @@ -251,7 +287,7 @@ func (r *ToolRegistry) ExecuteWithContext( map[string]any{ "tool": name, "duration_ms": duration.Milliseconds(), - "result_length": len(result.ForLLM), + "result_length": len(result.ContentForLLM()), }) } @@ -346,7 +382,8 @@ func (r *ToolRegistry) Clone() *ToolRegistry { r.mu.RLock() defer r.mu.RUnlock() clone := &ToolRegistry{ - tools: make(map[string]*ToolEntry, len(r.tools)), + tools: make(map[string]*ToolEntry, len(r.tools)), + mediaStore: r.mediaStore, } for name, entry := range r.tools { clone.tools[name] = &ToolEntry{ diff --git a/pkg/tools/registry_test.go b/pkg/tools/registry_test.go index 967758dfa..db52749f6 100644 --- a/pkg/tools/registry_test.go +++ b/pkg/tools/registry_test.go @@ -3,10 +3,13 @@ package tools import ( "context" "errors" + "os" + "path/filepath" "strings" "sync" "testing" + "github.com/sipeed/picoclaw/pkg/media" "github.com/sipeed/picoclaw/pkg/providers" ) @@ -46,6 +49,15 @@ func (m *mockAsyncRegistryTool) ExecuteAsync(_ context.Context, args map[string] return m.result } +type mockMediaStoreAwareTool struct { + mockRegistryTool + store media.MediaStore +} + +func (m *mockMediaStoreAwareTool) SetMediaStore(store media.MediaStore) { + m.store = store +} + // --- helpers --- func newMockTool(name, desc string) *mockRegistryTool { @@ -621,3 +633,102 @@ func TestToolRegistry_Execute_PanicDoesNotAffectOtherTools(t *testing.T) { t.Errorf("expected 'success', got %q", result2.ForLLM) } } + +func TestToolRegistry_SetMediaStore_PropagatesToExistingAndNewTools(t *testing.T) { + r := NewToolRegistry() + store := media.NewFileMediaStore() + + existing := &mockMediaStoreAwareTool{ + mockRegistryTool: *newMockTool("existing", "existing tool"), + } + r.Register(existing) + + 
r.SetMediaStore(store) + if existing.store != store { + t.Fatal("expected existing tool to receive media store") + } + + later := &mockMediaStoreAwareTool{ + mockRegistryTool: *newMockTool("later", "later tool"), + } + r.Register(later) + + if later.store != store { + t.Fatal("expected newly registered tool to inherit media store") + } +} + +func TestToolRegistry_ExecuteWithContext_SanitizesLargeBase64Payload(t *testing.T) { + r := NewToolRegistry() + payload := strings.Repeat("QUJD", 400) + r.Register(&mockRegistryTool{ + name: "base64_tool", + desc: "returns huge base64", + params: map[string]any{}, + result: SilentResult(payload), + }) + + result := r.ExecuteWithContext(context.Background(), "base64_tool", nil, "telegram", "chat-1", nil) + + if result.ForLLM != largeBase64OmittedMessage { + t.Fatalf("expected sanitized payload, got %q", result.ForLLM) + } +} + +func TestToolRegistry_ExecuteWithContext_ExtractsInlineMediaDataURL(t *testing.T) { + r := NewToolRegistry() + store := media.NewFileMediaStore() + r.SetMediaStore(store) + + payload := "![screenshot](data:image/png;base64,aGVsbG8=)" + r.Register(&mockRegistryTool{ + name: "inline_media_tool", + desc: "returns inline data url", + params: map[string]any{}, + result: SilentResult(payload), + }) + + result := r.ExecuteWithContext(context.Background(), "inline_media_tool", nil, "telegram", "chat-42", nil) + + if len(result.Media) != 1 { + t.Fatalf("expected 1 media ref, got %d", len(result.Media)) + } + if strings.Contains(result.ForLLM, "data:image/png;base64") { + t.Fatalf("expected inline data URL to be stripped from ForLLM, got %q", result.ForLLM) + } + if !strings.Contains(result.ForLLM, "registered as a media attachment") { + t.Fatalf("expected delivery note in ForLLM, got %q", result.ForLLM) + } + + path, err := store.Resolve(result.Media[0]) + if err != nil { + t.Fatalf("expected stored media ref to resolve: %v", err) + } + if _, err := os.Stat(path); err != nil { + t.Fatalf("expected stored media 
file to exist: %v", err) + } + if filepath.Ext(path) != ".png" { + t.Fatalf("expected stored inline media to use png extension, got %q", path) + } +} + +func TestToolRegistry_ExecuteWithContext_SanitizesInlineMediaWithoutStore(t *testing.T) { + r := NewToolRegistry() + + payload := "before ![img](data:image/png;base64,aGVsbG8=) after" + r.Register(&mockRegistryTool{ + name: "inline_media_no_store", + desc: "returns inline data url without store", + params: map[string]any{}, + result: SilentResult(payload), + }) + + result := r.ExecuteWithContext(context.Background(), "inline_media_no_store", nil, "telegram", "chat-42", nil) + + if strings.Contains(result.ForLLM, "data:image/png;base64") { + t.Fatalf("expected inline data URL to be removed from ForLLM, got %q", result.ForLLM) + } + if !strings.Contains(result.ForLLM, inlineMediaOmittedMessage) { + t.Fatalf("expected inline media omission note, got %q", result.ForLLM) + } +} diff --git a/pkg/tools/result.go b/pkg/tools/result.go index bf34b7bc6..c81213125 100644 --- a/pkg/tools/result.go +++ b/pkg/tools/result.go @@ -2,10 +2,16 @@ package tools import ( "encoding/json" + "strings" "github.com/sipeed/picoclaw/pkg/providers" ) +const ( + handledToolLLMNote = "The requested output has already been delivered to the user in the current chat. Do not call send_file or any other delivery tool again. If you reply, provide only a brief confirmation." + artifactPathsLLMNote = "Use `send_file` with one of these paths to send it to the user, or use file/exec tools to save it inside the workspace if requested." +) + // ToolResult represents the structured return value from tool execution. // It provides clear semantics for different types of results and supports // async operations, user-facing messages, and error handling. @@ -43,6 +49,48 @@ type ToolResult struct { // Only populated by SubTurn executions; used by evaluator_optimizer // to carry stateful worker context across evaluation iterations. 
Messages []providers.Message `json:"-"` + + // ArtifactTags exposes local artifact paths back to the LLM in a structured + // form, e.g. "[file:/tmp/example.png]". This is used when a tool produced a + // reusable local artifact but did not deliver it to the user yet. + ArtifactTags []string `json:"artifact_tags,omitempty"` + + // ResponseHandled indicates that this tool execution already satisfied the + // user's request at the channel/output level, so the agent loop can stop + // without a follow-up assistant response. + ResponseHandled bool `json:"response_handled,omitempty"` +} + +// ContentForLLM returns the normalized textual content to append to the +// conversation after a tool call. Errors fall back to Err when ForLLM is empty. +func (tr *ToolResult) ContentForLLM() string { + if tr == nil { + return "" + } + content := tr.ForLLM + if content == "" && tr.Err != nil { + content = tr.Err.Error() + } + if tr.ResponseHandled { + if content == "" { + return handledToolLLMNote + } + if !strings.Contains(content, handledToolLLMNote) { + content += "\n" + handledToolLLMNote + } + } + if len(tr.ArtifactTags) > 0 { + artifactNote := "Local artifact paths: " + strings.Join(tr.ArtifactTags, " ") + "\n" + artifactPathsLLMNote + if content == "" { + content = artifactNote + } else if !strings.Contains(content, artifactNote) { + content += "\n" + artifactNote + } + } + if content != "" { + return content + } + return "" } // NewToolResult creates a basic ToolResult with content for the LLM. @@ -167,3 +215,9 @@ func (tr *ToolResult) WithError(err error) *ToolResult { tr.Err = err return tr } + +// WithResponseHandled marks the tool result as already delivered to the user. 
+func (tr *ToolResult) WithResponseHandled() *ToolResult { + tr.ResponseHandled = true + return tr +} diff --git a/pkg/tools/result_test.go b/pkg/tools/result_test.go index a234e33f3..5f08cb4fa 100644 --- a/pkg/tools/result_test.go +++ b/pkg/tools/result_test.go @@ -3,6 +3,7 @@ package tools import ( "encoding/json" "errors" + "strings" "testing" ) @@ -227,3 +228,41 @@ func TestToolResultJSONStructure(t *testing.T) { t.Errorf("Expected silent false, got %v", parsed["silent"]) } } + +func TestToolResultContentForLLM_AppendsHandledDeliveryNote(t *testing.T) { + result := MediaResult("Screenshot attached.", []string{"media://example"}).WithResponseHandled() + + content := result.ContentForLLM() + if !strings.Contains(content, "Screenshot attached.") { + t.Fatalf("expected original content in ContentForLLM, got %q", content) + } + if !strings.Contains(content, handledToolLLMNote) { + t.Fatalf("expected handled delivery note in ContentForLLM, got %q", content) + } +} + +func TestToolResultContentForLLM_UsesHandledDeliveryNoteWhenEmpty(t *testing.T) { + result := (&ToolResult{}).WithResponseHandled() + + if got := result.ContentForLLM(); got != handledToolLLMNote { + t.Fatalf("ContentForLLM() = %q, want %q", got, handledToolLLMNote) + } +} + +func TestToolResultContentForLLM_AppendsArtifactPaths(t *testing.T) { + result := &ToolResult{ + ForLLM: "Artifact created.", + ArtifactTags: []string{"[file:/tmp/example.png]"}, + } + + content := result.ContentForLLM() + if !strings.Contains(content, "Artifact created.") { + t.Fatalf("expected original content in ContentForLLM, got %q", content) + } + if !strings.Contains(content, "Local artifact paths: [file:/tmp/example.png]") { + t.Fatalf("expected artifact path note in ContentForLLM, got %q", content) + } + if !strings.Contains(content, artifactPathsLLMNote) { + t.Fatalf("expected artifact guidance note in ContentForLLM, got %q", content) + } +} diff --git a/pkg/tools/send_file.go b/pkg/tools/send_file.go index 
57b99a845..44198381e 100644 --- a/pkg/tools/send_file.go +++ b/pkg/tools/send_file.go @@ -142,7 +142,7 @@ func (t *SendFileTool) Execute(ctx context.Context, args map[string]any) *ToolRe return ErrorResult(fmt.Sprintf("failed to register media: %v", err)) } - return MediaResult(fmt.Sprintf("File %q sent to user", filename), []string{ref}) + return MediaResult(fmt.Sprintf("File %q sent to user", filename), []string{ref}).WithResponseHandled() } // detectMediaType determines the MIME type of a file. diff --git a/pkg/tools/send_file_test.go b/pkg/tools/send_file_test.go index 0a99e8028..f36baf7d0 100644 --- a/pkg/tools/send_file_test.go +++ b/pkg/tools/send_file_test.go @@ -104,6 +104,9 @@ func TestSendFileTool_Success(t *testing.T) { if result.Media[0][:8] != "media://" { t.Errorf("expected media:// ref, got %q", result.Media[0]) } + if !result.ResponseHandled { + t.Fatal("expected send_file success to mark response handled") + } _, meta, err := store.ResolveWithMeta(result.Media[0]) if err != nil { diff --git a/pkg/tools/toolloop.go b/pkg/tools/toolloop.go index 244f0d4a2..387813e94 100644 --- a/pkg/tools/toolloop.go +++ b/pkg/tools/toolloop.go @@ -159,10 +159,7 @@ func RunToolLoop( // Append results in original order for _, r := range results { - contentForLLM := r.result.ForLLM - if contentForLLM == "" && r.result.Err != nil { - contentForLLM = r.result.Err.Error() - } + contentForLLM := r.result.ContentForLLM() messages = append(messages, providers.Message{ Role: "tool", diff --git a/pkg/tools/validate.go b/pkg/tools/validate.go new file mode 100644 index 000000000..940344708 --- /dev/null +++ b/pkg/tools/validate.go @@ -0,0 +1,209 @@ +package tools + +import ( + "fmt" + "math" +) + +// validateToolArgs validates args against a JSON Schema-like map. +// schema is expected to have optional keys: "properties", "required", "additionalProperties". 
+func validateToolArgs(schema map[string]any, args map[string]any) error { + if len(schema) == 0 { + return nil + } + + if args == nil { + args = map[string]any{} + } + + if err := checkRequired(schema, args); err != nil { + return err + } + + propsRaw, ok := schema["properties"] + if !ok { + return nil // no properties defined — accept any args + } + + props, ok := propsRaw.(map[string]any) + if !ok { + return nil + } + + additional := allowsAdditional(schema) + + for key, val := range args { + propSchemaRaw, known := props[key] + if !known { + if !additional { + return fmt.Errorf("unexpected property %q", key) + } + continue + } + propSchema, ok := propSchemaRaw.(map[string]any) + if !ok { + continue // can't validate without a proper schema map + } + if err := checkType(key, val, propSchema); err != nil { + return err + } + } + + return nil +} + +// checkRequired verifies that every field listed in schema["required"] is present in args. +func checkRequired(schema map[string]any, args map[string]any) error { + reqRaw, ok := schema["required"] + if !ok { + return nil + } + + var required []string + + switch r := reqRaw.(type) { + case []string: + required = r + case []any: + for _, v := range r { + s, ok := v.(string) + if ok { + required = append(required, s) + } + } + default: + return nil + } + + for _, field := range required { + if _, present := args[field]; !present { + return fmt.Errorf("missing required property %q", field) + } + } + return nil +} + +// allowsAdditional reports whether the schema explicitly sets +// "additionalProperties" to the boolean true. It returns false when the key is +// absent or holds a non-boolean value, so extra properties are rejected by default. +func allowsAdditional(schema map[string]any) bool { + v, ok := schema["additionalProperties"] + if !ok { + return false + } + b, ok := v.(bool) + return ok && b +} + +// checkType validates that val matches the JSON Schema type declared in propSchema.
+func checkType(key string, val any, propSchema map[string]any) error { + typeRaw, ok := propSchema["type"] + if !ok { + return nil // no type constraint + } + typeName, ok := typeRaw.(string) + if !ok { + return nil + } + + switch typeName { + case "string": + if _, ok := val.(string); !ok { + return fmt.Errorf("property %q: expected string, got %T", key, val) + } + case "integer": + switch v := val.(type) { + case float64: + if v != math.Trunc(v) { + return fmt.Errorf("property %q: expected integer, got float64 with fractional part", key) + } + case int: + // ok + case int64: + // ok + default: + return fmt.Errorf("property %q: expected integer, got %T", key, val) + } + case "number": + switch val.(type) { + case float64, int, int64: + // ok + default: + return fmt.Errorf("property %q: expected number, got %T", key, val) + } + case "boolean": + if _, ok := val.(bool); !ok { + return fmt.Errorf("property %q: expected boolean, got %T", key, val) + } + case "array": + arr, ok := val.([]any) + if !ok { + return fmt.Errorf("property %q: expected array, got %T", key, val) + } + if err := checkArrayItems(key, arr, propSchema); err != nil { + return err + } + case "object": + obj, ok := val.(map[string]any) + if !ok { + return fmt.Errorf("property %q: expected object, got %T", key, val) + } + if err := validateToolArgs(propSchema, obj); err != nil { + return fmt.Errorf("property %q: %w", key, err) + } + } + + if err := checkEnum(key, val, propSchema); err != nil { + return err + } + + return nil +} + +// checkArrayItems validates each element of arr against the "items" sub-schema. 
+func checkArrayItems(key string, arr []any, propSchema map[string]any) error { + itemsRaw, ok := propSchema["items"] + if !ok { + return nil + } + itemSchema, ok := itemsRaw.(map[string]any) + if !ok { + return nil + } + for i, elem := range arr { + elemKey := fmt.Sprintf("%s[%d]", key, i) + if err := checkType(elemKey, elem, itemSchema); err != nil { + return err + } + } + return nil +} + +// checkEnum validates that val is one of the allowed enum values in propSchema. +func checkEnum(key string, val any, propSchema map[string]any) error { + enumRaw, ok := propSchema["enum"] + if !ok { + return nil + } + + switch ev := enumRaw.(type) { + case []any: + for _, allowed := range ev { + if val == allowed { + return nil + } + } + case []string: + s, ok := val.(string) + if ok { + for _, allowed := range ev { + if s == allowed { + return nil + } + } + } + default: + return nil // unknown enum format, skip + } + + return fmt.Errorf("property %q: value %v is not in enum", key, val) +} diff --git a/pkg/tools/validate_test.go b/pkg/tools/validate_test.go new file mode 100644 index 000000000..e7f4f619a --- /dev/null +++ b/pkg/tools/validate_test.go @@ -0,0 +1,465 @@ +package tools + +import ( + "context" + "strings" + "testing" +) + +// Ensure imports are used. 
+var ( + _ = context.Background + _ = strings.Contains +) + +func TestValidateToolArgs(t *testing.T) { + baseSchema := map[string]any{ + "type": "object", + "properties": map[string]any{ + "name": map[string]any{"type": "string"}, + "age": map[string]any{"type": "integer"}, + }, + "required": []string{"name"}, + } + + tests := []struct { + name string + schema map[string]any + args map[string]any + wantErr string // empty means no error expected + }{ + { + name: "valid args all required present", + schema: baseSchema, + args: map[string]any{"name": "alice", "age": float64(30)}, + }, + { + name: "missing required field", + schema: baseSchema, + args: map[string]any{"age": float64(30)}, + wantErr: "missing required property \"name\"", + }, + { + name: "wrong type string field gets number", + schema: baseSchema, + args: map[string]any{"name": float64(42)}, + wantErr: "expected string", + }, + { + name: "nil args with required fields", + schema: baseSchema, + args: nil, + wantErr: "missing required property \"name\"", + }, + { + name: "nil args no required fields", + schema: map[string]any{ + "type": "object", + "properties": map[string]any{ + "name": map[string]any{"type": "string"}, + }, + }, + args: nil, + }, + { + name: "empty args no required fields", + schema: map[string]any{ + "type": "object", + "properties": map[string]any{ + "name": map[string]any{"type": "string"}, + }, + }, + args: map[string]any{}, + }, + { + name: "optional field correct type", + schema: baseSchema, + args: map[string]any{"name": "bob", "age": float64(25)}, + }, + { + name: "optional field wrong type", + schema: baseSchema, + args: map[string]any{"name": "bob", "age": "twenty"}, + wantErr: "expected integer", + }, + { + name: "integer as float64 no fractional part", + schema: baseSchema, + args: map[string]any{"name": "carol", "age": float64(42)}, + }, + { + name: "actual float for integer field", + schema: baseSchema, + args: map[string]any{"name": "dave", "age": float64(42.5)}, + 
wantErr: "expected integer, got float64 with fractional part", + }, + { + name: "number type accepts float", + schema: map[string]any{ + "type": "object", + "properties": map[string]any{ + "score": map[string]any{"type": "number"}, + }, + }, + args: map[string]any{"score": float64(3.14)}, + }, + { + name: "number type accepts integer", + schema: map[string]any{ + "type": "object", + "properties": map[string]any{ + "score": map[string]any{"type": "number"}, + }, + }, + args: map[string]any{"score": float64(10)}, + }, + { + name: "boolean type valid", + schema: map[string]any{ + "type": "object", + "properties": map[string]any{ + "flag": map[string]any{"type": "boolean"}, + }, + }, + args: map[string]any{"flag": true}, + }, + { + name: "boolean type wrong", + schema: map[string]any{ + "type": "object", + "properties": map[string]any{ + "flag": map[string]any{"type": "boolean"}, + }, + }, + args: map[string]any{"flag": "true"}, + wantErr: "expected boolean", + }, + { + name: "required as []any from MCP deserialization", + schema: map[string]any{ + "type": "object", + "properties": map[string]any{ + "cmd": map[string]any{"type": "string"}, + }, + "required": []any{"cmd"}, + }, + args: map[string]any{}, + wantErr: "missing required property \"cmd\"", + }, + { + name: "enum valid value []any", + schema: map[string]any{ + "type": "object", + "properties": map[string]any{ + "color": map[string]any{"type": "string", "enum": []any{"red", "green", "blue"}}, + }, + }, + args: map[string]any{"color": "red"}, + }, + { + name: "enum invalid value []any", + schema: map[string]any{ + "type": "object", + "properties": map[string]any{ + "color": map[string]any{"type": "string", "enum": []any{"red", "green", "blue"}}, + }, + }, + args: map[string]any{"color": "yellow"}, + wantErr: "not in enum", + }, + { + name: "enum valid value []string", + schema: map[string]any{ + "type": "object", + "properties": map[string]any{ + "color": map[string]any{"type": "string", "enum": []string{"red", 
"green", "blue"}}, + }, + }, + args: map[string]any{"color": "green"}, + }, + { + name: "enum invalid value []string", + schema: map[string]any{ + "type": "object", + "properties": map[string]any{ + "color": map[string]any{"type": "string", "enum": []string{"red", "green", "blue"}}, + }, + }, + args: map[string]any{"color": "yellow"}, + wantErr: "not in enum", + }, + { + name: "extra unexpected property rejected", + schema: baseSchema, + args: map[string]any{"name": "eve", "hobby": "chess"}, + wantErr: "unexpected property \"hobby\"", + }, + { + name: "extra property allowed with additionalProperties true", + schema: map[string]any{ + "type": "object", + "properties": map[string]any{ + "name": map[string]any{"type": "string"}, + }, + "additionalProperties": true, + }, + args: map[string]any{"name": "eve", "hobby": "chess"}, + }, + { + name: "nested object valid", + schema: map[string]any{ + "type": "object", + "properties": map[string]any{ + "address": map[string]any{ + "type": "object", + "properties": map[string]any{ + "city": map[string]any{"type": "string"}, + }, + "required": []string{"city"}, + }, + }, + }, + args: map[string]any{ + "address": map[string]any{"city": "Berlin"}, + }, + }, + { + name: "nested object wrong type", + schema: map[string]any{ + "type": "object", + "properties": map[string]any{ + "address": map[string]any{ + "type": "object", + "properties": map[string]any{ + "city": map[string]any{"type": "string"}, + }, + }, + }, + }, + args: map[string]any{"address": "not an object"}, + wantErr: "expected object", + }, + { + name: "array with valid element types", + schema: map[string]any{ + "type": "object", + "properties": map[string]any{ + "tags": map[string]any{ + "type": "array", + "items": map[string]any{"type": "string"}, + }, + }, + }, + args: map[string]any{"tags": []any{"a", "b", "c"}}, + }, + { + name: "array with wrong element types", + schema: map[string]any{ + "type": "object", + "properties": map[string]any{ + "tags": map[string]any{ 
+ "type": "array", + "items": map[string]any{"type": "string"}, + }, + }, + }, + args: map[string]any{"tags": []any{"a", float64(2)}}, + wantErr: "expected string", + }, + { + name: "schema with no properties key accepts any args", + schema: map[string]any{ + "type": "object", + }, + args: map[string]any{"anything": "goes"}, + }, + { + name: "empty schema accepts anything", + schema: map[string]any{}, + args: map[string]any{"foo": "bar"}, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + err := validateToolArgs(tc.schema, tc.args) + if tc.wantErr == "" { + if err != nil { + t.Fatalf("expected no error, got: %v", err) + } + return + } + if err == nil { + t.Fatalf("expected error containing %q, got nil", tc.wantErr) + } + if !strings.Contains(err.Error(), tc.wantErr) { + t.Fatalf("expected error containing %q, got: %v", tc.wantErr, err) + } + }) + } +} + +func TestValidateToolArgs_RegistryIntegration(t *testing.T) { + r := NewToolRegistry() + r.Register(&mockRegistryTool{ + name: "read_file", + desc: "reads a file", + params: map[string]any{ + "type": "object", + "properties": map[string]any{ + "path": map[string]any{"type": "string"}, + }, + "required": []string{"path"}, + }, + result: SilentResult("file contents"), + }) + + // Valid args — should succeed + result := r.Execute(context.Background(), "read_file", map[string]any{"path": "/tmp/x"}) + if result.IsError { + t.Errorf("expected success, got error: %s", result.ForLLM) + } + + // Missing required field — should fail with validation error + result = r.Execute(context.Background(), "read_file", map[string]any{}) + if !result.IsError { + t.Error("expected validation error for missing required field") + } + if !strings.Contains(result.ForLLM, "missing required p") { + t.Errorf("expected 'missing required p...' 
in error, got %q", result.ForLLM) + } + if result.Err == nil { + t.Error("expected Err to be set via WithError") + } + + // Wrong type — should fail with validation error + result = r.Execute(context.Background(), "read_file", map[string]any{"path": 123.0}) + if !result.IsError { + t.Error("expected validation error for wrong type") + } + if !strings.Contains(result.ForLLM, "expected string") { + t.Errorf("expected 'expected string' in error, got %q", result.ForLLM) + } + + // Extra property — should fail with validation error + result = r.Execute(context.Background(), "read_file", map[string]any{"path": "/x", "__inject": true}) + if !result.IsError { + t.Error("expected validation error for extra property") + } + if !strings.Contains(result.ForLLM, "unexpected prop") { + t.Errorf("expected 'unexpected prop...' in error, got %q", result.ForLLM) + } +} + +func TestValidateToolArgs_RealSchemas(t *testing.T) { + execSchema := map[string]any{ + "type": "object", + "properties": map[string]any{ + "command": map[string]any{"type": "string"}, + "working_dir": map[string]any{"type": "string"}, + }, + "required": []string{"command"}, + } + + cronSchema := map[string]any{ + "type": "object", + "properties": map[string]any{ + "action": map[string]any{ + "type": "string", + "enum": []any{"add", "list", "remove", "enable", "disable"}, + }, + }, + "required": []string{"action"}, + } + + webSearchSchema := map[string]any{ + "type": "object", + "properties": map[string]any{ + "query": map[string]any{"type": "string"}, + "count": map[string]any{"type": "integer"}, + }, + "required": []string{"query"}, + } + + tests := []struct { + name string + schema map[string]any + args map[string]any + wantErr string + }{ + // ExecTool + { + name: "exec valid args", + schema: execSchema, + args: map[string]any{"command": "ls -la", "working_dir": "/tmp"}, + }, + { + name: "exec missing required command", + schema: execSchema, + args: map[string]any{"working_dir": "/tmp"}, + wantErr: "missing 
required property \"command\"", + }, + { + name: "exec wrong type for command", + schema: execSchema, + args: map[string]any{"command": float64(123)}, + wantErr: "expected string", + }, + { + name: "exec extra injected arg", + schema: execSchema, + args: map[string]any{"command": "ls", "malicious": "payload"}, + wantErr: "unexpected property \"malicious\"", + }, + + // CronTool + { + name: "cron valid enum value", + schema: cronSchema, + args: map[string]any{"action": "add"}, + }, + { + name: "cron invalid enum value", + schema: cronSchema, + args: map[string]any{"action": "destroy"}, + wantErr: "not in enum", + }, + + // WebSearchTool + { + name: "websearch valid args", + schema: webSearchSchema, + args: map[string]any{"query": "golang testing", "count": float64(10)}, + }, + { + name: "websearch missing required query", + schema: webSearchSchema, + args: map[string]any{"count": float64(5)}, + wantErr: "missing required property \"query\"", + }, + { + name: "websearch wrong type for count", + schema: webSearchSchema, + args: map[string]any{"query": "test", "count": "ten"}, + wantErr: "expected integer", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + err := validateToolArgs(tc.schema, tc.args) + if tc.wantErr == "" { + if err != nil { + t.Fatalf("expected no error, got: %v", err) + } + return + } + if err == nil { + t.Fatalf("expected error containing %q, got nil", tc.wantErr) + } + if !strings.Contains(err.Error(), tc.wantErr) { + t.Fatalf("expected error containing %q, got: %v", tc.wantErr, err) + } + }) + } +} diff --git a/web/backend/api/channels.go b/web/backend/api/channels.go index 507882823..dd4c9af3d 100644 --- a/web/backend/api/channels.go +++ b/web/backend/api/channels.go @@ -12,6 +12,7 @@ type channelCatalogItem struct { } var channelCatalog = []channelCatalogItem{ + {Name: "weixin", ConfigKey: "weixin"}, {Name: "telegram", ConfigKey: "telegram"}, {Name: "discord", ConfigKey: "discord"}, {Name: "slack", ConfigKey: 
"slack"}, @@ -21,8 +22,6 @@ var channelCatalog = []channelCatalogItem{ {Name: "qq", ConfigKey: "qq"}, {Name: "onebot", ConfigKey: "onebot"}, {Name: "wecom", ConfigKey: "wecom"}, - {Name: "wecom_app", ConfigKey: "wecom_app"}, - {Name: "wecom_aibot", ConfigKey: "wecom_aibot"}, {Name: "whatsapp", ConfigKey: "whatsapp", Variant: "bridge"}, {Name: "whatsapp_native", ConfigKey: "whatsapp", Variant: "native"}, {Name: "pico", ConfigKey: "pico"}, diff --git a/web/backend/api/config.go b/web/backend/api/config.go index 7cdfde174..618b8438d 100644 --- a/web/backend/api/config.go +++ b/web/backend/api/config.go @@ -6,6 +6,7 @@ import ( "io" "net/http" "regexp" + "strings" "github.com/sipeed/picoclaw/pkg/config" "github.com/sipeed/picoclaw/pkg/logger" @@ -16,6 +17,7 @@ func (h *Handler) registerConfigRoutes(mux *http.ServeMux) { mux.HandleFunc("GET /api/config", h.handleGetConfig) mux.HandleFunc("PUT /api/config", h.handleUpdateConfig) mux.HandleFunc("PATCH /api/config", h.handlePatchConfig) + mux.HandleFunc("POST /api/config/test-command-patterns", h.handleTestCommandPatterns) } // handleGetConfig returns the complete system configuration. @@ -54,6 +56,15 @@ func (h *Handler) handleUpdateConfig(w http.ResponseWriter, r *http.Request) { cfg.Tools.Exec.AllowRemote = config.DefaultConfig().Tools.Exec.AllowRemote } + // Load existing config and copy security credentials before validation, + // so that security-managed fields (e.g. pico token) are available. 
+ oldCfg, err := config.LoadConfig(h.configPath) + if err != nil { + http.Error(w, fmt.Sprintf("Failed to load config: %v", err), http.StatusInternalServerError) + return + } + cfg.SecurityCopyFrom(oldCfg) + if errs := validateConfig(&cfg); len(errs) > 0 { w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusBadRequest) @@ -64,13 +75,7 @@ func (h *Handler) handleUpdateConfig(w http.ResponseWriter, r *http.Request) { return } - logger.Infof("new config: %+v", cfg) - oldCfg, err := config.LoadConfig(h.configPath) - if err != nil { - http.Error(w, fmt.Sprintf("Failed to load config: %v", err), http.StatusInternalServerError) - return - } - cfg.SecurityCopyFrom(oldCfg) + logger.Infof("configuration updated successfully") if err := config.SaveConfig(h.configPath, &cfg); err != nil { http.Error(w, fmt.Sprintf("Failed to save config: %v", err), http.StatusInternalServerError) @@ -149,6 +154,14 @@ func (h *Handler) handlePatchConfig(w http.ResponseWriter, r *http.Request) { return } + // Restore security fields (tokens/keys) from the loaded config before validation, + // because private fields are lost during JSON round-trip. 
+ newCfg.SecurityCopyFrom(cfg) + if err := newCfg.ApplySecurity(); err != nil { + http.Error(w, fmt.Sprintf("Failed to apply security config: %v", err), http.StatusInternalServerError) + return + } + if errs := validateConfig(&newCfg); len(errs) > 0 { w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusBadRequest) @@ -159,8 +172,6 @@ func (h *Handler) handlePatchConfig(w http.ResponseWriter, r *http.Request) { return } - newCfg.SecurityCopyFrom(cfg) - if err := config.SaveConfig(h.configPath, &newCfg); err != nil { http.Error(w, fmt.Sprintf("Failed to save config: %v", err), http.StatusInternalServerError) return @@ -170,6 +181,70 @@ func (h *Handler) handlePatchConfig(w http.ResponseWriter, r *http.Request) { json.NewEncoder(w).Encode(map[string]string{"status": "ok"}) } +// handleTestCommandPatterns tests a command against whitelist and blacklist patterns. +// +// POST /api/config/test-command-patterns +func (h *Handler) handleTestCommandPatterns(w http.ResponseWriter, r *http.Request) { + body, err := io.ReadAll(io.LimitReader(r.Body, 1<<20)) + if err != nil { + http.Error(w, "Failed to read request body", http.StatusBadRequest) + return + } + defer r.Body.Close() + + var req struct { + AllowPatterns []string `json:"allow_patterns"` + DenyPatterns []string `json:"deny_patterns"` + Command string `json:"command"` + } + if err := json.Unmarshal(body, &req); err != nil { + http.Error(w, fmt.Sprintf("Invalid JSON: %v", err), http.StatusBadRequest) + return + } + + lower := strings.ToLower(strings.TrimSpace(req.Command)) + + type result struct { + Allowed bool `json:"allowed"` + Blocked bool `json:"blocked"` + MatchedWhitelist *string `json:"matched_whitelist,omitempty"` + MatchedBlacklist *string `json:"matched_blacklist,omitempty"` + } + + resp := result{Allowed: false, Blocked: false} + + // Check whitelist first + for _, pattern := range req.AllowPatterns { + re, err := regexp.Compile(pattern) + if err != nil { + continue // skip invalid 
patterns + } + if re.MatchString(lower) { + resp.Allowed = true + resp.MatchedWhitelist = &pattern + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(resp) + return + } + } + + // Check blacklist + for _, pattern := range req.DenyPatterns { + re, err := regexp.Compile(pattern) + if err != nil { + continue + } + if re.MatchString(lower) { + resp.Blocked = true + resp.MatchedBlacklist = &pattern + break + } + } + + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(resp) +} + // validateConfig checks the config for common errors before saving. // Returns a list of human-readable error strings; empty means valid. func validateConfig(cfg *config.Config) []string { @@ -200,6 +275,15 @@ func validateConfig(cfg *config.Config) []string { errs = append(errs, "channels.discord.token is required when discord channel is enabled") } + if cfg.Channels.WeCom.Enabled { + if cfg.Channels.WeCom.BotID == "" { + errs = append(errs, "channels.wecom.bot_id is required when wecom channel is enabled") + } + if cfg.Channels.WeCom.Secret() == "" { + errs = append(errs, "channels.wecom.secret is required when wecom channel is enabled") + } + } + if cfg.Tools.Exec.Enabled { if cfg.Tools.Exec.EnableDenyPatterns { errs = append( diff --git a/web/backend/api/config_test.go b/web/backend/api/config_test.go index bbf285e14..36acd95b0 100644 --- a/web/backend/api/config_test.go +++ b/web/backend/api/config_test.go @@ -4,6 +4,8 @@ import ( "bytes" "net/http" "net/http/httptest" + "os" + "path/filepath" "testing" "github.com/sipeed/picoclaw/pkg/config" @@ -141,6 +143,120 @@ func TestHandlePatchConfig_AllowsInvalidExecRegexPatternsWhenExecDisabled(t *tes } } +// setupPicoEnabledEnv creates a test environment with Pico channel enabled and +// its token stored only in .security.yml (not in the JSON payload). 
+func setupPicoEnabledEnv(t *testing.T) (string, func()) { + t.Helper() + + tmp := t.TempDir() + oldHome := os.Getenv("HOME") + oldPicoHome := os.Getenv("PICOCLAW_HOME") + + if err := os.Setenv("HOME", tmp); err != nil { + t.Fatalf("set HOME: %v", err) + } + if err := os.Setenv("PICOCLAW_HOME", filepath.Join(tmp, ".picoclaw")); err != nil { + t.Fatalf("set PICOCLAW_HOME: %v", err) + } + + cfg := config.DefaultConfig() + cfg.ModelList = []*config.ModelConfig{{ + ModelName: "custom-default", + Model: "openai/gpt-4o", + }} + cfg.Agents.Defaults.ModelName = "custom-default" + cfg.Channels.Pico.Enabled = true + cfg.WithSecurity(&config.SecurityConfig{ + ModelList: map[string]config.ModelSecurityEntry{ + "custom-default": {APIKeys: []string{"sk-default"}}, + }, + Channels: &config.ChannelsSecurity{ + Pico: &config.PicoSecurity{Token: "test-pico-token"}, + }, + }) + + configPath := filepath.Join(tmp, "config.json") + if err := config.SaveConfig(configPath, cfg); err != nil { + t.Fatalf("SaveConfig error: %v", err) + } + + cleanup := func() { + _ = os.Setenv("HOME", oldHome) + if oldPicoHome == "" { + _ = os.Unsetenv("PICOCLAW_HOME") + } else { + _ = os.Setenv("PICOCLAW_HOME", oldPicoHome) + } + } + return configPath, cleanup +} + +func TestHandleUpdateConfig_SucceedsWhenPicoTokenInSecurityOnly(t *testing.T) { + configPath, cleanup := setupPicoEnabledEnv(t) + defer cleanup() + + h := NewHandler(configPath) + mux := http.NewServeMux() + h.RegisterRoutes(mux) + + // PUT request with pico enabled but no token in JSON — token is in .security.yml + req := httptest.NewRequest(http.MethodPut, "/api/config", bytes.NewBufferString(`{ + "version": 1, + "agents": { + "defaults": { + "workspace": "~/.picoclaw/workspace", + "model_name": "custom-default" + } + }, + "channels": { + "pico": { + "enabled": true, + "ping_interval": 30, + "read_timeout": 60, + "write_timeout": 10, + "max_connections": 100 + } + }, + "model_list": [ + { + "model_name": "custom-default", + "model": 
"openai/gpt-4o", + "api_keys": ["sk-default"] + } + ] + }`)) + req.Header.Set("Content-Type", "application/json") + + rec := httptest.NewRecorder() + mux.ServeHTTP(rec, req) + if rec.Code != http.StatusOK { + t.Fatalf("PUT /api/config status = %d, want %d, body=%s", rec.Code, http.StatusOK, rec.Body.String()) + } +} + +func TestHandlePatchConfig_SucceedsWhenPicoTokenInSecurityOnly(t *testing.T) { + configPath, cleanup := setupPicoEnabledEnv(t) + defer cleanup() + + h := NewHandler(configPath) + mux := http.NewServeMux() + h.RegisterRoutes(mux) + + // PATCH request changing an unrelated field — pico token still in .security.yml + req := httptest.NewRequest(http.MethodPatch, "/api/config", bytes.NewBufferString(`{ + "gateway": { + "log_level": "info" + } + }`)) + req.Header.Set("Content-Type", "application/json") + + rec := httptest.NewRecorder() + mux.ServeHTTP(rec, req) + if rec.Code != http.StatusOK { + t.Fatalf("PATCH /api/config status = %d, want %d, body=%s", rec.Code, http.StatusOK, rec.Body.String()) + } +} + func TestHandlePatchConfig_AllowsInvalidDenyRegexPatternsWhenDenyPatternsDisabled(t *testing.T) { configPath, cleanup := setupOAuthTestEnv(t) defer cleanup() @@ -166,3 +282,170 @@ func TestHandlePatchConfig_AllowsInvalidDenyRegexPatternsWhenDenyPatternsDisable t.Fatalf("status = %d, want %d, body=%s", rec.Code, http.StatusOK, rec.Body.String()) } } + +// testCommandPatterns is a helper that sets up a handler and sends a test-command-patterns request. 
+func testCommandPatterns(t *testing.T, configPath string, body string) *httptest.ResponseRecorder { + t.Helper() + h := NewHandler(configPath) + mux := http.NewServeMux() + h.RegisterRoutes(mux) + req := httptest.NewRequest(http.MethodPost, "/api/config/test-command-patterns", bytes.NewBufferString(body)) + req.Header.Set("Content-Type", "application/json") + rec := httptest.NewRecorder() + mux.ServeHTTP(rec, req) + return rec +} + +func TestHandleTestCommandPatterns_MatchesWhitelist(t *testing.T) { + configPath, cleanup := setupOAuthTestEnv(t) + defer cleanup() + + rec := testCommandPatterns(t, configPath, `{ + "allow_patterns": ["^echo\\s+hello"], + "deny_patterns": ["^rm\\s+-rf"], + "command": "echo hello world" + }`) + if rec.Code != http.StatusOK { + t.Fatalf("status = %d, want %d, body=%s", rec.Code, http.StatusOK, rec.Body.String()) + } + if !bytes.Contains(rec.Body.Bytes(), []byte(`"allowed":true`)) { + t.Fatalf("expected allowed=true, body=%s", rec.Body.String()) + } + if bytes.Contains(rec.Body.Bytes(), []byte(`"blocked":true`)) { + t.Fatalf("expected blocked=false when whitelist matches, body=%s", rec.Body.String()) + } +} + +func TestHandleTestCommandPatterns_MatchesBlacklistNotWhitelist(t *testing.T) { + configPath, cleanup := setupOAuthTestEnv(t) + defer cleanup() + + rec := testCommandPatterns(t, configPath, `{ + "allow_patterns": ["^echo\\s+hello"], + "deny_patterns": ["^rm\\s+-rf"], + "command": "rm -rf /tmp" + }`) + if rec.Code != http.StatusOK { + t.Fatalf("status = %d, want %d, body=%s", rec.Code, http.StatusOK, rec.Body.String()) + } + if !bytes.Contains(rec.Body.Bytes(), []byte(`"blocked":true`)) { + t.Fatalf("expected blocked=true, body=%s", rec.Body.String()) + } + if bytes.Contains(rec.Body.Bytes(), []byte(`"allowed":true`)) { + t.Fatalf("expected allowed=false when blacklist matches but not whitelist, body=%s", rec.Body.String()) + } +} + +func TestHandleTestCommandPatterns_MatchesNeither(t *testing.T) { + configPath, cleanup := 
setupOAuthTestEnv(t) + defer cleanup() + + rec := testCommandPatterns(t, configPath, `{ + "allow_patterns": ["^echo\\s+hello"], + "deny_patterns": ["^rm\\s+-rf"], + "command": "ls -la" + }`) + if rec.Code != http.StatusOK { + t.Fatalf("status = %d, want %d, body=%s", rec.Code, http.StatusOK, rec.Body.String()) + } + if bytes.Contains(rec.Body.Bytes(), []byte(`"allowed":true`)) { + t.Fatalf("expected allowed=false, body=%s", rec.Body.String()) + } + if bytes.Contains(rec.Body.Bytes(), []byte(`"blocked":true`)) { + t.Fatalf("expected blocked=false, body=%s", rec.Body.String()) + } +} + +func TestHandleTestCommandPatterns_CaseInsensitiveWithGoFlag(t *testing.T) { + configPath, cleanup := setupOAuthTestEnv(t) + defer cleanup() + + rec := testCommandPatterns(t, configPath, `{ + "allow_patterns": ["(?i)^ECHO"], + "deny_patterns": [], + "command": "echo hello" + }`) + if rec.Code != http.StatusOK { + t.Fatalf("status = %d, want %d, body=%s", rec.Code, http.StatusOK, rec.Body.String()) + } + if !bytes.Contains(rec.Body.Bytes(), []byte(`"allowed":true`)) { + t.Fatalf("expected allowed=true with Go (?i) flag, body=%s", rec.Body.String()) + } +} + +func TestHandleTestCommandPatterns_EmptyPatterns(t *testing.T) { + configPath, cleanup := setupOAuthTestEnv(t) + defer cleanup() + + rec := testCommandPatterns(t, configPath, `{ + "allow_patterns": [], + "deny_patterns": [], + "command": "rm -rf /tmp" + }`) + if rec.Code != http.StatusOK { + t.Fatalf("status = %d, want %d, body=%s", rec.Code, http.StatusOK, rec.Body.String()) + } + if bytes.Contains(rec.Body.Bytes(), []byte(`"allowed":true`)) { + t.Fatalf("expected allowed=false with empty patterns, body=%s", rec.Body.String()) + } + if bytes.Contains(rec.Body.Bytes(), []byte(`"blocked":true`)) { + t.Fatalf("expected blocked=false with empty patterns, body=%s", rec.Body.String()) + } +} + +func TestHandleTestCommandPatterns_InvalidRegexSkipped(t *testing.T) { + configPath, cleanup := setupOAuthTestEnv(t) + defer cleanup() + + rec 
:= testCommandPatterns(t, configPath, `{ + "allow_patterns": ["([[", "^echo"], + "deny_patterns": [], + "command": "echo hello" + }`) + if rec.Code != http.StatusOK { + t.Fatalf("status = %d, want %d, body=%s", rec.Code, http.StatusOK, rec.Body.String()) + } + if !bytes.Contains(rec.Body.Bytes(), []byte(`"allowed":true`)) { + t.Fatalf("expected allowed=true, invalid pattern skipped and valid one matched, body=%s", rec.Body.String()) + } +} + +func TestHandleTestCommandPatterns_ReturnsMatchedPattern(t *testing.T) { + configPath, cleanup := setupOAuthTestEnv(t) + defer cleanup() + + rec := testCommandPatterns(t, configPath, `{ + "allow_patterns": [], + "deny_patterns": ["\\$(?i)[a-zA-Z_]*(SECRET|KEY|PASSWORD|TOKEN|AUTH)[a-zA-Z0-9_]*"], + "command": "echo $GITHUB_API_KEY" + }`) + if rec.Code != http.StatusOK { + t.Fatalf("status = %d, want %d, body=%s", rec.Code, http.StatusOK, rec.Body.String()) + } + if !bytes.Contains(rec.Body.Bytes(), []byte(`"blocked":true`)) { + t.Fatalf("expected blocked=true, body=%s", rec.Body.String()) + } + if !bytes.Contains(rec.Body.Bytes(), []byte(`matched_blacklist`)) { + t.Fatalf("expected matched_blacklist field, body=%s", rec.Body.String()) + } +} + +func TestHandleTestCommandPatterns_InvalidJSON(t *testing.T) { + configPath, cleanup := setupOAuthTestEnv(t) + defer cleanup() + + h := NewHandler(configPath) + mux := http.NewServeMux() + h.RegisterRoutes(mux) + req := httptest.NewRequest( + http.MethodPost, + "/api/config/test-command-patterns", + bytes.NewBufferString(`{invalid json}`), + ) + req.Header.Set("Content-Type", "application/json") + rec := httptest.NewRecorder() + mux.ServeHTTP(rec, req) + if rec.Code != http.StatusBadRequest { + t.Fatalf("status = %d, want %d, body=%s", rec.Code, http.StatusBadRequest, rec.Body.String()) + } +} diff --git a/web/backend/api/gateway.go b/web/backend/api/gateway.go index 7f72f12b8..4bde5ce82 100644 --- a/web/backend/api/gateway.go +++ b/web/backend/api/gateway.go @@ -407,7 +407,7 @@ func (h 
*Handler) startGatewayLocked(initialStatus string, existingPid int) (int gateway.logs.Reset() // Ensure Pico Channel is configured before starting gateway - if _, err := h.ensurePicoChannel(""); err != nil { + if _, err := h.EnsurePicoChannel(""); err != nil { logger.ErrorC("gateway", fmt.Sprintf("Warning: failed to ensure pico channel: %v", err)) // Non-fatal: gateway can still start without pico channel } diff --git a/web/backend/api/models.go b/web/backend/api/models.go index 1e3b5f90a..ce7719906 100644 --- a/web/backend/api/models.go +++ b/web/backend/api/models.go @@ -42,6 +42,7 @@ type modelResponse struct { // Meta Configured bool `json:"configured"` IsDefault bool `json:"is_default"` + IsVirtual bool `json:"is_virtual"` } // handleListModels returns all model_list entries with masked API keys. @@ -86,6 +87,7 @@ func (h *Handler) handleListModels(w http.ResponseWriter, r *http.Request) { ExtraBody: m.ExtraBody, Configured: configured[i], IsDefault: m.ModelName == defaultModel, + IsVirtual: m.IsVirtual(), }) } @@ -108,7 +110,12 @@ func (h *Handler) handleAddModel(w http.ResponseWriter, r *http.Request) { } defer r.Body.Close() - var mc config.ModelConfig + type custom struct { + config.ModelConfig + APIKey string `json:"api_key"` + } + + var mc custom if err = json.Unmarshal(body, &mc); err != nil { http.Error(w, fmt.Sprintf("Invalid JSON: %v", err), http.StatusBadRequest) return @@ -119,13 +126,17 @@ func (h *Handler) handleAddModel(w http.ResponseWriter, r *http.Request) { return } + if mc.APIKey != "" { + mc.ModelConfig.SetAPIKey(mc.APIKey) + } + cfg, err := config.LoadConfig(h.configPath) if err != nil { http.Error(w, fmt.Sprintf("Failed to load config: %v", err), http.StatusInternalServerError) return } - cfg.ModelList = append(cfg.ModelList, &mc) + cfg.ModelList = append(cfg.ModelList, &mc.ModelConfig) if err := config.SaveConfig(h.configPath, cfg); err != nil { http.Error(w, fmt.Sprintf("Failed to save config: %v", err), http.StatusInternalServerError) 
@@ -279,11 +290,13 @@ func (h *Handler) handleSetDefaultModel(w http.ResponseWriter, r *http.Request) return } - // Verify the model_name exists in model_list + // Verify the model_name exists in model_list and is not a virtual model found := false + isVirtual := false for _, m := range cfg.ModelList { if m.ModelName == req.ModelName { found = true + isVirtual = m.IsVirtual() break } } @@ -291,6 +304,10 @@ func (h *Handler) handleSetDefaultModel(w http.ResponseWriter, r *http.Request) http.Error(w, fmt.Sprintf("Model %q not found in model_list", req.ModelName), http.StatusNotFound) return } + if isVirtual { + http.Error(w, fmt.Sprintf("Cannot set virtual model %q as default", req.ModelName), http.StatusBadRequest) + return + } cfg.Agents.Defaults.ModelName = req.ModelName @@ -307,16 +324,25 @@ func (h *Handler) handleSetDefaultModel(w http.ResponseWriter, r *http.Request) } // maskAPIKey returns a masked version of an API key for safe display. -// Keys longer than 8 chars show prefix + last 4 chars: "sk-****abcd" +// Keys longer than 12 chars show prefix + last 4 chars: "sk-****abcd". +// Keys 9-12 chars show prefix + last 2 chars: "sk-****cd". // Shorter keys are fully masked as "****". // Empty keys return empty string. +// Ensure at least 40% of the key will not be displayed. 
// maskAPIKey renders an API key in a form that is safe to display.
//
//   - an empty key yields an empty string;
//   - keys of up to 8 bytes are fully hidden as "****";
//   - keys of 9 to 12 bytes keep the first 3 and last 2 bytes: "sk-****cd";
//   - longer keys keep the first 3 and last 4 bytes: "sk-****abcd".
//
// Lengths are measured in bytes (len), not in runes.
func maskAPIKey(key string) string {
	const (
		mask      = "****"
		prefixLen = 3
	)

	n := len(key)
	switch {
	case n == 0:
		return ""
	case n <= 8:
		return mask
	}

	// Mid-length keys reveal a shorter tail so that less of the key is shown.
	tailLen := 4
	if n <= 12 {
		tailLen = 2
	}
	return key[:prefixLen] + mask + key[n-tailLen:]
}
non-existent +// model as default returns 404. This covers the case where virtual models (which are +// filtered by SaveConfig) cannot be set as default. +func TestHandleSetDefaultModel_RejectsNonexistentModel(t *testing.T) { + configPath, cleanup := setupOAuthTestEnv(t) + defer cleanup() + + // First save a valid config with a primary model + cfg, err := config.LoadConfig(configPath) + if err != nil { + t.Fatalf("LoadConfig() error = %v", err) + } + cfg.ModelList = []*config.ModelConfig{ + {ModelName: "gpt-4", Model: "openai/gpt-4o"}, + } + if err := config.SaveConfig(configPath, cfg); err != nil { + t.Fatalf("SaveConfig() error = %v", err) + } + + // Try to set a non-existent model (like a virtual model name) as default + h := NewHandler(configPath) + mux := http.NewServeMux() + h.RegisterRoutes(mux) + + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPost, "/api/models/default", bytes.NewBufferString(`{ + "model_name": "gpt-4__key_1" + }`)) + req.Header.Set("Content-Type", "application/json") + mux.ServeHTTP(rec, req) + + // Should return 404 because the virtual model doesn't exist in the persisted config + if rec.Code != http.StatusNotFound { + t.Fatalf("status = %d, want %d, body=%s", rec.Code, http.StatusNotFound, rec.Body.String()) + } + if !strings.Contains(rec.Body.String(), "not found") { + t.Fatalf("error message should mention 'not found', got: %s", rec.Body.String()) + } +} + +func TestMaskAPIKey(t *testing.T) { + tests := []struct { + name string + key string + want string + }{ + { + name: "empty key", + key: "", + want: "", + }, + { + name: "short key fully masked", + key: "abcd", + want: "****", + }, + { + name: "length 8 boundary fully masked", + key: "12345678", + want: "****", + }, + { + name: "length 9 boundary shows last 2", + key: "123456789", + want: "123****89", + }, + { + name: "length 12 boundary shows last 2", + key: "abcdefghijkl", + want: "abc****kl", + }, + { + name: "length 13 boundary shows last 4", + key: 
"abcdefghijklm", + want: "abc****jklm", + }, + { + name: "typical api key", + key: "sk-1234567890abcd", + want: "sk-****abcd", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + got := maskAPIKey(tc.key) + if got != tc.want { + t.Fatalf("maskAPIKey(%q) = %q, want %q", tc.key, got, tc.want) + } + + if tc.key != "" { + displayed := strings.Replace(tc.want, "****", "", 1) + if len(tc.key) <= 8 { + if displayed != "" { + t.Fatalf("maskAPIKey(%q) displayed part = %q, want empty", tc.key, displayed) + } + } else { + if len(displayed)*10 > len(tc.key)*6 { + t.Fatalf( + "maskAPIKey(%q) displayed length = %d, want at most 60%% of %d", + tc.key, + len(displayed), + len(tc.key), + ) + } + } + } + }) + } +} diff --git a/web/backend/api/pico.go b/web/backend/api/pico.go index 8fbb8737f..4faafc2ae 100644 --- a/web/backend/api/pico.go +++ b/web/backend/api/pico.go @@ -90,14 +90,14 @@ func (h *Handler) handleRegenPicoToken(w http.ResponseWriter, r *http.Request) { }) } -// ensurePicoChannel enables the Pico channel with sane defaults if it isn't +// EnsurePicoChannel enables the Pico channel with sane defaults if it isn't // already configured. Returns true when the config was modified. // // callerOrigin is the Origin header from the setup request. If non-empty and // no origins are configured yet, it's written as the allowed origin so the // WebSocket handshake works for whatever host the caller is on (LAN, custom // port, etc.). Pass "" when there's no request context. 
-func (h *Handler) ensurePicoChannel(callerOrigin string) (bool, error) { +func (h *Handler) EnsurePicoChannel(callerOrigin string) (bool, error) { cfg, err := config.LoadConfig(h.configPath) if err != nil { return false, fmt.Errorf("failed to load config: %w", err) @@ -134,7 +134,7 @@ func (h *Handler) ensurePicoChannel(callerOrigin string) (bool, error) { // // POST /api/pico/setup func (h *Handler) handlePicoSetup(w http.ResponseWriter, r *http.Request) { - changed, err := h.ensurePicoChannel(r.Header.Get("Origin")) + changed, err := h.EnsurePicoChannel(r.Header.Get("Origin")) if err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) return diff --git a/web/backend/api/pico_test.go b/web/backend/api/pico_test.go index 263253cb2..051e356cf 100644 --- a/web/backend/api/pico_test.go +++ b/web/backend/api/pico_test.go @@ -6,6 +6,7 @@ import ( "net/http" "net/http/httptest" "net/url" + "os" "path/filepath" "strconv" "testing" @@ -17,12 +18,12 @@ func TestEnsurePicoChannel_FreshConfig(t *testing.T) { configPath := filepath.Join(t.TempDir(), "config.json") h := NewHandler(configPath) - changed, err := h.ensurePicoChannel("") + changed, err := h.EnsurePicoChannel("") if err != nil { - t.Fatalf("ensurePicoChannel() error = %v", err) + t.Fatalf("EnsurePicoChannel() error = %v", err) } if !changed { - t.Fatal("ensurePicoChannel() should report changed on a fresh config") + t.Fatal("EnsurePicoChannel() should report changed on a fresh config") } cfg, err := config.LoadConfig(configPath) @@ -42,8 +43,8 @@ func TestEnsurePicoChannel_DoesNotEnableTokenQuery(t *testing.T) { configPath := filepath.Join(t.TempDir(), "config.json") h := NewHandler(configPath) - if _, err := h.ensurePicoChannel(""); err != nil { - t.Fatalf("ensurePicoChannel() error = %v", err) + if _, err := h.EnsurePicoChannel(""); err != nil { + t.Fatalf("EnsurePicoChannel() error = %v", err) } cfg, err := config.LoadConfig(configPath) @@ -60,8 +61,8 @@ func 
TestEnsurePicoChannel_DoesNotSetWildcardOrigins(t *testing.T) { configPath := filepath.Join(t.TempDir(), "config.json") h := NewHandler(configPath) - if _, err := h.ensurePicoChannel("http://localhost:18800"); err != nil { - t.Fatalf("ensurePicoChannel() error = %v", err) + if _, err := h.EnsurePicoChannel("http://localhost:18800"); err != nil { + t.Fatalf("EnsurePicoChannel() error = %v", err) } cfg, err := config.LoadConfig(configPath) @@ -80,8 +81,8 @@ func TestEnsurePicoChannel_NoOriginWithoutCaller(t *testing.T) { configPath := filepath.Join(t.TempDir(), "config.json") h := NewHandler(configPath) - if _, err := h.ensurePicoChannel(""); err != nil { - t.Fatalf("ensurePicoChannel() error = %v", err) + if _, err := h.EnsurePicoChannel(""); err != nil { + t.Fatalf("EnsurePicoChannel() error = %v", err) } cfg, err := config.LoadConfig(configPath) @@ -101,8 +102,8 @@ func TestEnsurePicoChannel_SetsCallerOrigin(t *testing.T) { h := NewHandler(configPath) lanOrigin := "http://192.168.1.9:18800" - if _, err := h.ensurePicoChannel(lanOrigin); err != nil { - t.Fatalf("ensurePicoChannel() error = %v", err) + if _, err := h.EnsurePicoChannel(lanOrigin); err != nil { + t.Fatalf("EnsurePicoChannel() error = %v", err) } cfg, err := config.LoadConfig(configPath) @@ -130,12 +131,12 @@ func TestEnsurePicoChannel_PreservesUserSettings(t *testing.T) { h := NewHandler(configPath) - changed, err := h.ensurePicoChannel("") + changed, err := h.EnsurePicoChannel("") if err != nil { - t.Fatalf("ensurePicoChannel() error = %v", err) + t.Fatalf("EnsurePicoChannel() error = %v", err) } if changed { - t.Error("ensurePicoChannel() should not change a fully configured config") + t.Error("EnsurePicoChannel() should not change a fully configured config") } cfg, err = config.LoadConfig(configPath) @@ -154,6 +155,71 @@ func TestEnsurePicoChannel_PreservesUserSettings(t *testing.T) { } } +func TestEnsurePicoChannel_ExistingConfigWithoutSecurityFile(t *testing.T) { + configPath := 
filepath.Join(t.TempDir(), "config.json") + + cfg := config.DefaultConfig() + raw, err := json.Marshal(cfg) + if err != nil { + t.Fatalf("Marshal() error = %v", err) + } + if err = os.WriteFile(configPath, raw, 0o600); err != nil { + t.Fatalf("WriteFile() error = %v", err) + } + + h := NewHandler(configPath) + + changed, err := h.EnsurePicoChannel("") + if err != nil { + t.Fatalf("EnsurePicoChannel() error = %v", err) + } + if !changed { + t.Fatal("EnsurePicoChannel() should report changed when pico is missing") + } + + cfg, err = config.LoadConfig(configPath) + if err != nil { + t.Fatalf("LoadConfig() error = %v", err) + } + + if !cfg.Channels.Pico.Enabled { + t.Error("expected Pico to be enabled after setup") + } + if cfg.Channels.Pico.Token() == "" { + t.Error("expected a non-empty token after setup") + } + if _, err := os.Stat(filepath.Join(filepath.Dir(configPath), config.SecurityConfigFile)); err != nil { + t.Fatalf("expected .security.yml to be created: %v", err) + } +} + +func TestEnsurePicoChannel_ConfiguresPicoWithoutGateway(t *testing.T) { + configPath := filepath.Join(t.TempDir(), "config.json") + + cfg := config.DefaultConfig() + cfg.Agents.Defaults.ModelName = "" + if err := config.SaveConfig(configPath, cfg); err != nil { + t.Fatalf("SaveConfig() error = %v", err) + } + + h := NewHandler(configPath) + if _, err := h.EnsurePicoChannel(""); err != nil { + t.Fatalf("EnsurePicoChannel() error = %v", err) + } + + cfg, err := config.LoadConfig(configPath) + if err != nil { + t.Fatalf("LoadConfig() error = %v", err) + } + + if !cfg.Channels.Pico.Enabled { + t.Error("expected Pico to be enabled after launcher startup setup") + } + if cfg.Channels.Pico.Token() == "" { + t.Error("expected a non-empty token after launcher startup setup") + } +} + func TestEnsurePicoChannel_Idempotent(t *testing.T) { configPath := filepath.Join(t.TempDir(), "config.json") h := NewHandler(configPath) @@ -161,20 +227,20 @@ func TestEnsurePicoChannel_Idempotent(t *testing.T) { 
origin := "http://localhost:18800" // First call sets things up - if _, err := h.ensurePicoChannel(origin); err != nil { - t.Fatalf("first ensurePicoChannel() error = %v", err) + if _, err := h.EnsurePicoChannel(origin); err != nil { + t.Fatalf("first EnsurePicoChannel() error = %v", err) } cfg1, _ := config.LoadConfig(configPath) token1 := cfg1.Channels.Pico.Token() // Second call should be a no-op - changed, err := h.ensurePicoChannel(origin) + changed, err := h.EnsurePicoChannel(origin) if err != nil { - t.Fatalf("second ensurePicoChannel() error = %v", err) + t.Fatalf("second EnsurePicoChannel() error = %v", err) } if changed { - t.Error("second ensurePicoChannel() should not report changed") + t.Error("second EnsurePicoChannel() should not report changed") } cfg2, _ := config.LoadConfig(configPath) diff --git a/web/backend/api/router.go b/web/backend/api/router.go index e4df86ed9..d09f68eac 100644 --- a/web/backend/api/router.go +++ b/web/backend/api/router.go @@ -17,15 +17,18 @@ type Handler struct { oauthMu sync.Mutex oauthFlows map[string]*oauthFlow oauthState map[string]string + weixinMu sync.Mutex + weixinFlows map[string]*weixinFlow } // NewHandler creates an instance of the API handler. func NewHandler(configPath string) *Handler { return &Handler{ - configPath: configPath, - serverPort: launcherconfig.DefaultPort, - oauthFlows: make(map[string]*oauthFlow), - oauthState: make(map[string]string), + configPath: configPath, + serverPort: launcherconfig.DefaultPort, + oauthFlows: make(map[string]*oauthFlow), + oauthState: make(map[string]string), + weixinFlows: make(map[string]*weixinFlow), } } @@ -69,6 +72,9 @@ func (h *Handler) RegisterRoutes(mux *http.ServeMux) { // Launcher service parameters (port/public) h.registerLauncherConfigRoutes(mux) + + // WeChat QR login flow + h.registerWeixinRoutes(mux) } // Shutdown gracefully shuts down the handler, stopping the gateway if it was started by this handler. 
diff --git a/web/backend/api/weixin.go b/web/backend/api/weixin.go new file mode 100644 index 000000000..808b88c41 --- /dev/null +++ b/web/backend/api/weixin.go @@ -0,0 +1,317 @@ +package api + +import ( + "context" + "crypto/rand" + "encoding/base64" + "encoding/hex" + "encoding/json" + "fmt" + "net/http" + "strings" + "time" + + "rsc.io/qr" + + "github.com/sipeed/picoclaw/pkg/channels/weixin" + "github.com/sipeed/picoclaw/pkg/config" + "github.com/sipeed/picoclaw/pkg/logger" +) + +const ( + weixinFlowTTL = 5 * time.Minute + weixinFlowGCAge = 30 * time.Minute + weixinBaseURL = "https://ilinkai.weixin.qq.com/" + weixinBotType = "3" +) + +const ( + weixinStatusWait = "wait" + weixinStatusScanned = "scaned" + weixinStatusConfirmed = "confirmed" + weixinStatusExpired = "expired" + weixinStatusError = "error" +) + +type weixinFlow struct { + ID string + Qrcode string // qrcode token from WeChat API (used for status polling) + QRDataURI string // base64 PNG data URI for display + AccountID string // IlinkBotID returned on confirmed + Status string // wait / scaned / confirmed / expired / error + Error string + CreatedAt time.Time + UpdatedAt time.Time + ExpiresAt time.Time +} + +type weixinFlowResponse struct { + FlowID string `json:"flow_id"` + Status string `json:"status"` + QRDataURI string `json:"qr_data_uri,omitempty"` + AccountID string `json:"account_id,omitempty"` + Error string `json:"error,omitempty"` +} + +// registerWeixinRoutes binds WeChat QR login endpoints to the ServeMux. +func (h *Handler) registerWeixinRoutes(mux *http.ServeMux) { + mux.HandleFunc("POST /api/weixin/flows", h.handleStartWeixinFlow) + mux.HandleFunc("GET /api/weixin/flows/{id}", h.handlePollWeixinFlow) +} + +// handleStartWeixinFlow starts a new WeChat QR login flow. 
+// +// POST /api/weixin/flows +func (h *Handler) handleStartWeixinFlow(w http.ResponseWriter, r *http.Request) { + ctx, cancel := context.WithTimeout(r.Context(), 15*time.Second) + defer cancel() + + api, err := weixin.NewApiClient(weixinBaseURL, "", "") + if err != nil { + http.Error(w, fmt.Sprintf("failed to create weixin client: %v", err), http.StatusInternalServerError) + return + } + + qrResp, err := api.GetQRCode(ctx, weixinBotType) + if err != nil { + http.Error(w, fmt.Sprintf("failed to get QR code: %v", err), http.StatusInternalServerError) + return + } + + dataURI, err := generateQRDataURI(qrResp.QrcodeImgContent) + if err != nil { + http.Error(w, fmt.Sprintf("failed to generate QR image: %v", err), http.StatusInternalServerError) + return + } + + now := time.Now() + flow := &weixinFlow{ + ID: newWeixinFlowID(), + Qrcode: qrResp.Qrcode, + QRDataURI: dataURI, + Status: weixinStatusWait, + CreatedAt: now, + UpdatedAt: now, + ExpiresAt: now.Add(weixinFlowTTL), + } + h.storeWeixinFlow(flow) + + logger.InfoCF("weixin", "QR flow started", map[string]any{"flow_id": flow.ID}) + + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(weixinFlowResponse{ + FlowID: flow.ID, + Status: flow.Status, + QRDataURI: flow.QRDataURI, + }) +} + +// handlePollWeixinFlow polls the WeChat API for QR code status and updates the flow. 
+// +// GET /api/weixin/flows/{id} +func (h *Handler) handlePollWeixinFlow(w http.ResponseWriter, r *http.Request) { + flowID := strings.TrimSpace(r.PathValue("id")) + if flowID == "" { + http.Error(w, "missing flow id", http.StatusBadRequest) + return + } + + flow, ok := h.getWeixinFlow(flowID) + if !ok { + http.Error(w, "flow not found", http.StatusNotFound) + return + } + + // Return terminal states directly without polling WeChat again + if flow.Status == weixinStatusConfirmed || + flow.Status == weixinStatusExpired || + flow.Status == weixinStatusError { + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(weixinFlowResponse{ + FlowID: flow.ID, + Status: flow.Status, + Error: flow.Error, + }) + return + } + + ctx, cancel := context.WithTimeout(r.Context(), 10*time.Second) + defer cancel() + + api, err := weixin.NewApiClient(weixinBaseURL, "", "") + if err != nil { + h.setWeixinFlowError(flowID, fmt.Sprintf("client error: %v", err)) + flow, _ = h.getWeixinFlow(flowID) + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(weixinFlowResponse{FlowID: flow.ID, Status: flow.Status, Error: flow.Error}) + return + } + + statusResp, err := api.GetQRCodeStatus(ctx, flow.Qrcode) + if err != nil { + // Transient error — keep current status, return it + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(weixinFlowResponse{ + FlowID: flow.ID, + Status: flow.Status, + QRDataURI: flow.QRDataURI, + }) + return + } + + switch statusResp.Status { + case weixinStatusWait: + // no change + + case weixinStatusScanned: + h.updateWeixinFlowStatus(flowID, weixinStatusScanned) + + case weixinStatusConfirmed: + if statusResp.BotToken == "" { + h.setWeixinFlowError(flowID, "login confirmed but missing bot_token") + break + } + if saveErr := h.saveWeixinBinding(statusResp.BotToken, statusResp.IlinkBotID); saveErr != nil { + h.setWeixinFlowError(flowID, fmt.Sprintf("failed to save token: %v", 
saveErr)) + logger.ErrorCF("weixin", "failed to save token", map[string]any{"error": saveErr.Error()}) + break + } + h.setWeixinFlowConfirmed(flowID, statusResp.IlinkBotID) + logger.InfoCF("weixin", "QR login confirmed, token saved", map[string]any{ + "flow_id": flowID, + "account_id": statusResp.IlinkBotID, + }) + + case weixinStatusExpired: + h.updateWeixinFlowStatus(flowID, weixinStatusExpired) + + default: + // unknown status, keep as-is + } + + flow, _ = h.getWeixinFlow(flowID) + w.Header().Set("Content-Type", "application/json") + resp := weixinFlowResponse{ + FlowID: flow.ID, + Status: flow.Status, + AccountID: flow.AccountID, + Error: flow.Error, + } + if flow.Status == weixinStatusWait || flow.Status == weixinStatusScanned { + resp.QRDataURI = flow.QRDataURI + } + _ = json.NewEncoder(w).Encode(resp) +} + +// saveWeixinBinding writes the token/account ID, enables the Weixin channel, +// and best-effort restarts the gateway when it is currently running. +func (h *Handler) saveWeixinBinding(token, accountID string) error { + cfg, err := config.LoadConfig(h.configPath) + if err != nil { + return fmt.Errorf("load config: %w", err) + } + cfg.Channels.Weixin.SetToken(token) + cfg.Channels.Weixin.Enabled = true + if accountID != "" { + cfg.Channels.Weixin.AccountID = accountID + } + if err := config.SaveConfig(h.configPath, cfg); err != nil { + return err + } + + status := h.gatewayStatusData() + gatewayStatus, _ := status["gateway_status"].(string) + if gatewayStatus != "running" { + return nil + } + + if _, err := h.RestartGateway(); err != nil { + logger.ErrorCF("weixin", "failed to restart gateway after saving binding", map[string]any{ + "error": err.Error(), + }) + } + return nil +} + +// generateQRDataURI encodes content as a QR code PNG and returns a data URI. 
+func generateQRDataURI(content string) (string, error) {
+	qrCode, err := qr.Encode(content, qr.L)
+	if err != nil {
+		return "", fmt.Errorf("qr encode: %w", err)
+	}
+	return "data:image/png;base64," + base64.StdEncoding.EncodeToString(qrCode.PNG()), nil
+}
+
+// newWeixinFlowID returns a "wx_"-prefixed flow identifier: 12 bytes from
+// rand.Read rendered as hex, or the current Unix time in nanoseconds when
+// reading those bytes fails.
+func newWeixinFlowID() string {
+	raw := make([]byte, 12)
+	if _, err := rand.Read(raw); err == nil {
+		return "wx_" + hex.EncodeToString(raw)
+	}
+	return fmt.Sprintf("wx_%d", time.Now().UnixNano())
+}
+
+// storeWeixinFlow registers flow under its ID, after running a GC pass over
+// the flows that are already stored.
+func (h *Handler) storeWeixinFlow(flow *weixinFlow) {
+	h.weixinMu.Lock()
+	defer h.weixinMu.Unlock()
+
+	h.gcWeixinFlowsLocked(time.Now())
+	h.weixinFlows[flow.ID] = flow
+}
+
+// getWeixinFlow runs a GC pass and then looks up flowID. On a hit it returns a
+// copy of the stored flow, so callers can read it without holding the lock;
+// changes made to the copy do not reach the stored flow.
+func (h *Handler) getWeixinFlow(flowID string) (*weixinFlow, bool) {
+	h.weixinMu.Lock()
+	defer h.weixinMu.Unlock()
+
+	h.gcWeixinFlowsLocked(time.Now())
+	if stored, found := h.weixinFlows[flowID]; found {
+		snapshot := *stored
+		return &snapshot, true
+	}
+	return nil, false
+}
+
+// updateWeixinFlowStatus sets the status of a stored flow and refreshes its
+// UpdatedAt timestamp. Unknown flow IDs are ignored.
+func (h *Handler) updateWeixinFlowStatus(flowID, status string) {
+	h.weixinMu.Lock()
+	defer h.weixinMu.Unlock()
+
+	target, found := h.weixinFlows[flowID]
+	if !found {
+		return
+	}
+	target.Status = status
+	target.UpdatedAt = time.Now()
+}
+
+// setWeixinFlowConfirmed marks a stored flow as confirmed and records the
+// account ID it was bound to. Unknown flow IDs are ignored.
+func (h *Handler) setWeixinFlowConfirmed(flowID, accountID string) {
+	h.weixinMu.Lock()
+	defer h.weixinMu.Unlock()
+
+	target, found := h.weixinFlows[flowID]
+	if !found {
+		return
+	}
+	target.Status = weixinStatusConfirmed
+	target.AccountID = accountID
+	target.UpdatedAt = time.Now()
+}
+
+// setWeixinFlowError marks a stored flow as failed and records errMsg as its
+// error text. Unknown flow IDs are ignored.
+func (h *Handler) setWeixinFlowError(flowID, errMsg string) {
+	h.weixinMu.Lock()
+	defer h.weixinMu.Unlock()
+
+	target, found := h.weixinFlows[flowID]
+	if !found {
+		return
+	}
+	target.Status = weixinStatusError
+	target.Error = errMsg
+	target.UpdatedAt = time.Now()
+}
+
+// gcWeixinFlowsLocked makes one pass over the stored flows as of now. The
+// caller must hold h.weixinMu.
+//
+// A waiting or scanned flow whose non-zero ExpiresAt has passed is marked as
+// expired. A flow that is neither waiting nor scanned is deleted once its last
+// update is older than weixinFlowGCAge.
+func (h *Handler) gcWeixinFlowsLocked(now time.Time) {
+	for id, flow := range h.weixinFlows {
+		pending := flow.Status == weixinStatusWait || flow.Status == weixinStatusScanned
+		if pending && !flow.ExpiresAt.IsZero() && now.After(flow.ExpiresAt) {
+			flow.Status = weixinStatusExpired
+			flow.UpdatedAt = now
+		}
+
+		// The status is read again because the step above may have changed it.
+		switch flow.Status {
+		case weixinStatusWait, weixinStatusScanned:
+			continue
+		}
+		if now.Sub(flow.UpdatedAt) > weixinFlowGCAge {
+			delete(h.weixinFlows, id)
+		}
+	}
+}
diff --git a/web/backend/api/weixin_test.go b/web/backend/api/weixin_test.go
new file mode 100644
index 000000000..03342b72b
--- /dev/null
+++ b/web/backend/api/weixin_test.go
@@ -0,0 +1,56 @@
+package api
+
+import (
+	"io"
+	"net/http"
+	"os"
+	"path/filepath"
+	"strconv"
+	"strings"
+	"testing"
+	"time"
+
+	"github.com/sipeed/picoclaw/pkg/config"
+)
+
+// The health probe is stubbed to answer "ok" with this test process's PID,
+// presumably so that saveWeixinBinding sees a running gateway and attempts a
+// restart; the binding must be saved and nil returned regardless.
+func TestSaveWeixinBindingReturnsSuccessWhenRestartFails(t *testing.T) {
+	resetGatewayTestState(t)
+
+	configPath := filepath.Join(t.TempDir(), "config.json")
+	if err := config.SaveConfig(configPath, config.DefaultConfig()); err != nil {
+		t.Fatalf("SaveConfig() error = %v", err)
+	}
+
+	prevHealthGet := gatewayHealthGet
+	gatewayHealthGet = func(url string, timeout time.Duration) (*http.Response, error) {
+		return &http.Response{
+			StatusCode: http.StatusOK,
+			Body: io.NopCloser(strings.NewReader(
+				`{"status":"ok","uptime":"1s","pid":` + strconv.Itoa(os.Getpid()) + `}`,
+			)),
+		}, nil
+	}
+	t.Cleanup(func() { gatewayHealthGet = prevHealthGet })
+
+	handler := NewHandler(configPath)
+	if err := handler.saveWeixinBinding("bot-token", "bot-account"); err != nil {
+		t.Fatalf("saveWeixinBinding() error = %v, want nil after config save succeeds", err)
+	}
+
+	reloaded, err := config.LoadConfig(configPath)
+	if err != nil {
+		t.Fatalf("LoadConfig() error = %v", err)
+	}
+	if got := reloaded.Channels.Weixin.Token(); got != "bot-token" {
+		t.Fatalf("Weixin.Token() = %q, want %q", got, "bot-token")
+	}
+	if got := reloaded.Channels.Weixin.AccountID; got != "bot-account" {
+		t.Fatalf("Weixin.AccountID = %q, want %q", got, "bot-account")
+	}
+	if !reloaded.Channels.Weixin.Enabled {
+		t.Fatalf("Weixin.Enabled = false, want true")
+	}
+}
diff --git a/web/backend/main.go b/web/backend/main.go
index 8183731fe..2f181603e 100644
--- a/web/backend/main.go
+++
b/web/backend/main.go @@ -169,6 +169,9 @@ func main() { // API Routes (e.g. /api/status) apiHandler = api.NewHandler(absPath) + if _, err = apiHandler.EnsurePicoChannel(""); err != nil { + logger.ErrorC("web", fmt.Sprintf("Warning: failed to ensure pico channel on startup: %v", err)) + } apiHandler.SetServerOptions(portNum, effectivePublic, explicitPublic, launcherCfg.AllowedCIDRs) apiHandler.RegisterRoutes(mux) diff --git a/web/frontend/src/api/channels.ts b/web/frontend/src/api/channels.ts index ecd77632c..d4c3ac74b 100644 --- a/web/frontend/src/api/channels.ts +++ b/web/frontend/src/api/channels.ts @@ -62,4 +62,26 @@ export async function patchAppConfig( }) } +// WeChat QR login flow API + +export interface WeixinFlowResponse { + flow_id: string + status: "wait" | "scaned" | "confirmed" | "expired" | "error" + qr_data_uri?: string + account_id?: string + error?: string +} + +export async function startWeixinFlow(): Promise { + return request("/api/weixin/flows", { method: "POST" }) +} + +export async function pollWeixinFlow( + flowID: string, +): Promise { + return request( + `/api/weixin/flows/${encodeURIComponent(flowID)}`, + ) +} + export type { ChannelsCatalogResponse, ConfigActionResponse } diff --git a/web/frontend/src/api/models.ts b/web/frontend/src/api/models.ts index 2fd042593..aa66a7389 100644 --- a/web/frontend/src/api/models.ts +++ b/web/frontend/src/api/models.ts @@ -21,6 +21,7 @@ export interface ModelInfo { // Meta configured: boolean is_default: boolean + is_virtual: boolean } interface ModelsListResponse { diff --git a/web/frontend/src/components/app-sidebar.tsx b/web/frontend/src/components/app-sidebar.tsx index 702212857..0e135c0c1 100644 --- a/web/frontend/src/components/app-sidebar.tsx +++ b/web/frontend/src/components/app-sidebar.tsx @@ -67,14 +67,17 @@ const baseNavGroups: Omit[] = [ export function AppSidebar({ ...props }: React.ComponentProps) { const routerState = useRouterState() - const { t } = useTranslation() + const { i18n, t } = 
useTranslation() const currentPath = routerState.location.pathname const { channelItems, hasMoreChannels, showAllChannels, toggleShowAllChannels, - } = useSidebarChannels({ t }) + } = useSidebarChannels({ + language: (i18n.resolvedLanguage ?? i18n.language ?? "").toLowerCase(), + t, + }) const navGroups: NavGroup[] = React.useMemo(() => { return [ diff --git a/web/frontend/src/components/channels/channel-config-page.tsx b/web/frontend/src/components/channels/channel-config-page.tsx index b19d11e6a..7f1f695bc 100644 --- a/web/frontend/src/components/channels/channel-config-page.tsx +++ b/web/frontend/src/components/channels/channel-config-page.tsx @@ -1,8 +1,6 @@ import { IconLoader2 } from "@tabler/icons-react" -import { useAtomValue } from "jotai" import { useCallback, useEffect, useMemo, useRef, useState } from "react" import { useTranslation } from "react-i18next" -import { toast } from "sonner" import { type ChannelConfig, @@ -17,10 +15,12 @@ import { FeishuForm } from "@/components/channels/channel-forms/feishu-form" import { GenericForm } from "@/components/channels/channel-forms/generic-form" import { SlackForm } from "@/components/channels/channel-forms/slack-form" import { TelegramForm } from "@/components/channels/channel-forms/telegram-form" +import { WeixinForm } from "@/components/channels/channel-forms/weixin-form" import { PageHeader } from "@/components/page-header" import { Button } from "@/components/ui/button" import { Switch } from "@/components/ui/switch" -import { gatewayAtom } from "@/store/gateway" +import { useGateway } from "@/hooks/use-gateway" +import { refreshGatewayState } from "@/store/gateway" interface ChannelConfigPageProps { channelName: string @@ -142,14 +142,10 @@ function isConfigured( ) case "onebot": return asString(config.ws_url) !== "" + case "weixin": + return asString(config.account_id) !== "" case "wecom": - return asString(config.token) !== "" - case "wecom_app": - return ( - asString(config.corp_id) !== "" && 
asString(config.corp_secret) !== "" - ) - case "wecom_aibot": - return asString(config.token) !== "" + return asString(config.bot_id) !== "" case "whatsapp": return asString(config.bridge_url) !== "" case "whatsapp_native": @@ -190,11 +186,7 @@ function getRequiredFieldKeys(channelName: string): string[] { case "onebot": return ["ws_url"] case "wecom": - return ["token"] - case "wecom_app": - return ["corp_id", "corp_secret"] - case "wecom_aibot": - return ["token"] + return ["bot_id", "secret"] case "whatsapp": return ["bridge_url"] case "pico": @@ -238,7 +230,7 @@ const CHANNELS_WITHOUT_DOCS = new Set([ export function ChannelConfigPage({ channelName }: ChannelConfigPageProps) { const { t, i18n } = useTranslation() - const gateway = useAtomValue(gatewayAtom) + const { state: gatewayState } = useGateway() const [loading, setLoading] = useState(true) const [saving, setSaving] = useState(false) @@ -251,56 +243,59 @@ export function ChannelConfigPage({ channelName }: ChannelConfigPageProps) { const [editConfig, setEditConfig] = useState({}) const [enabled, setEnabled] = useState(false) - const loadData = useCallback(async () => { - setLoading(true) - try { - const [catalog, appConfig] = await Promise.all([ - getChannelsCatalog(), - getAppConfig(), - ]) - const matched = - catalog.channels.find((item) => item.name === channelName) ?? null + const loadData = useCallback( + async (silent = false) => { + if (!silent) setLoading(true) + try { + const [catalog, appConfig] = await Promise.all([ + getChannelsCatalog(), + getAppConfig(), + ]) + const matched = + catalog.channels.find((item) => item.name === channelName) ?? 
null - if (!matched) { - setChannel(null) - setFetchError( - t("channels.page.notFound", { - name: channelName, - }), - ) - return + if (!matched) { + setChannel(null) + setFetchError( + t("channels.page.notFound", { + name: channelName, + }), + ) + return + } + + const channelsConfig = asRecord(asRecord(appConfig).channels) + const raw = asRecord(channelsConfig[matched.config_key]) + const normalized = normalizeConfig(matched, raw) + + setChannel(matched) + setBaseConfig(normalized) + setEditConfig(buildEditConfig(normalized)) + setEnabled(asBool(normalized.enabled)) + setFetchError("") + setServerError("") + setFieldErrors({}) + } catch (e) { + setFetchError(e instanceof Error ? e.message : t("channels.loadError")) + } finally { + if (!silent) setLoading(false) } - - const channelsConfig = asRecord(asRecord(appConfig).channels) - const raw = asRecord(channelsConfig[matched.config_key]) - const normalized = normalizeConfig(matched, raw) - - setChannel(matched) - setBaseConfig(normalized) - setEditConfig(buildEditConfig(normalized)) - setEnabled(asBool(normalized.enabled)) - setFetchError("") - setServerError("") - setFieldErrors({}) - } catch (e) { - setFetchError(e instanceof Error ? 
e.message : t("channels.loadError")) - } finally { - setLoading(false) - } - }, [channelName, t]) + }, + [channelName, t], + ) useEffect(() => { loadData() }, [loadData]) - const previousGatewayStatusRef = useRef(gateway.status) + const previousGatewayStatusRef = useRef(gatewayState) useEffect(() => { const previousStatus = previousGatewayStatusRef.current - if (previousStatus !== "running" && gateway.status === "running") { + if (previousStatus !== "running" && gatewayState === "running") { void loadData() } - previousGatewayStatusRef.current = gateway.status - }, [gateway.status, loadData]) + previousGatewayStatusRef.current = gatewayState + }, [gatewayState, loadData]) const savePayload = useMemo(() => { if (!channel) return null @@ -393,18 +388,28 @@ export function ChannelConfigPage({ channelName }: ChannelConfigPageProps) { [channel.config_key]: savePayload, }, }) - toast.success(t("channels.page.saveSuccess")) await loadData() } catch (e) { const message = e instanceof Error ? e.message : t("channels.page.saveError") setServerError(message) - toast.error(message) } finally { setSaving(false) } } + const handleWeixinBindSuccess = useCallback(async () => { + try { + setEnabled(true) + await Promise.all([loadData(true), refreshGatewayState({ force: true })]) + } catch (e) { + const message = + e instanceof Error ? e.message : t("channels.page.saveError") + setServerError(message) + await loadData(true) + } + }, [loadData, t]) + const renderForm = () => { if (!channel) return null const isEdit = configured @@ -446,6 +451,15 @@ export function ChannelConfigPage({ channelName }: ChannelConfigPageProps) { fieldErrors={fieldErrors} /> ) + case "weixin": + return ( + void handleWeixinBindSuccess()} + /> + ) default: return ( void + isEdit: boolean + onBindSuccess?: () => void +} + +function asString(value: unknown): string { + return typeof value === "string" ? 
value : "" +} + +function asStringArray(value: unknown): string[] { + if (!Array.isArray(value)) return [] + return value.filter((item): item is string => typeof item === "string") +} + +export function WeixinForm({ + config, + onChange, + isEdit, + onBindSuccess, +}: WeixinFormProps) { + const { t } = useTranslation() + + const [bindState, setBindState] = useState("idle") + const [qrDataURI, setQrDataURI] = useState(null) + const [accountID, setAccountID] = useState(null) + const [errorMsg, setErrorMsg] = useState("") + + const pollTimerRef = useRef | null>(null) + const pollGenerationRef = useRef(0) + const isBound = isEdit && asString(config.account_id) !== "" + const existingAccountID = asString(config.account_id) + + const stopPolling = useCallback(() => { + pollGenerationRef.current += 1 + if (pollTimerRef.current !== null) { + clearInterval(pollTimerRef.current) + pollTimerRef.current = null + } + }, []) + + useEffect(() => () => stopPolling(), [stopPolling]) + + useEffect(() => { + if (!existingAccountID) return + stopPolling() + setAccountID(existingAccountID) + setBindState("confirmed") + setErrorMsg("") + }, [existingAccountID, stopPolling]) + + const startPolling = useCallback( + (id: string) => { + stopPolling() + const generation = pollGenerationRef.current + let inFlight = false + pollTimerRef.current = setInterval(async () => { + if (inFlight) return + inFlight = true + try { + const resp = await pollWeixinFlow(id) + if (generation !== pollGenerationRef.current) { + return + } + if (resp.status === "scaned") { + setBindState("scaned") + } else if (resp.status === "confirmed") { + stopPolling() + setAccountID(resp.account_id ?? existingAccountID ?? null) + setBindState("confirmed") + onBindSuccess?.() + } else if (resp.status === "expired") { + stopPolling() + setBindState("expired") + } else if (resp.status === "error") { + stopPolling() + setBindState("error") + setErrorMsg(resp.error ?? 
t("channels.weixin.errorGeneric")) + } + } catch { + // transient network error — keep polling + } finally { + inFlight = false + } + }, 2000) + }, + [existingAccountID, stopPolling, onBindSuccess, t], + ) + + const handleBind = async () => { + setBindState("loading") + setErrorMsg("") + setQrDataURI(null) + stopPolling() + try { + const resp = await startWeixinFlow() + setQrDataURI(resp.qr_data_uri ?? null) + setBindState("waiting") + startPolling(resp.flow_id) + } catch (e) { + setBindState("error") + setErrorMsg( + e instanceof Error ? e.message : t("channels.weixin.errorGeneric"), + ) + } + } + + const handleRebind = () => { + stopPolling() + setBindState("idle") + setQrDataURI(null) + setAccountID(null) + setErrorMsg("") + void handleBind() + } + + const renderBindSection = () => { + if (bindState === "idle") { + if (isBound) { + return ( +
+
+ + {t("channels.weixin.bound")} +
+ {existingAccountID && ( +

+ {existingAccountID} +

+ )} + +
+ ) + } + return ( +
+

+ {t("channels.weixin.notBound")} +

+ +
+ ) + } + + if (bindState === "loading") { + return ( +
+ +

+ {t("channels.weixin.generating")} +

+
+ ) + } + + if (bindState === "waiting" || bindState === "scaned") { + return ( +
+ {qrDataURI ? ( + WeChat QR Code + ) : ( +
+ +
+ )} + {bindState === "scaned" ? ( +
+ + {t("channels.weixin.scanned")} +
+ ) : ( +

+ {t("channels.weixin.scanHint")} +

+ )} + +
+ ) + } + + if (bindState === "confirmed") { + return ( +
+
+ +
+

+ {t("channels.weixin.bound")} +

+ {accountID && ( +

+ {accountID} +

+ )} + +
+ ) + } + + if (bindState === "expired") { + return ( +
+
+ +
+

+ {t("channels.weixin.expired")} +

+ +
+ ) + } + + if (bindState === "error") { + return ( +
+
+ +
+

+ {errorMsg || t("channels.weixin.errorGeneric")} +

+ +
+ ) + } + + return null + } + + return ( +
+ {/* QR Bind Section */} +
+
+

+ {t("channels.weixin.bindTitle")} +

+

+ {t("channels.weixin.bindDesc")} +

+
+ {renderBindSection()} +
+ + {/* allow_from */} + + + onChange( + "allow_from", + e.target.value + .split(",") + .map((s: string) => s.trim()) + .filter(Boolean), + ) + } + placeholder={t("channels.field.allowFromPlaceholder")} + /> + + + {/* proxy */} + + onChange("proxy", e.target.value)} + placeholder="http://localhost:7890" + /> + +
+ ) +} diff --git a/web/frontend/src/components/chat/user-message.tsx b/web/frontend/src/components/chat/user-message.tsx index 84978e907..0035f14fa 100644 --- a/web/frontend/src/components/chat/user-message.tsx +++ b/web/frontend/src/components/chat/user-message.tsx @@ -5,7 +5,7 @@ interface UserMessageProps { export function UserMessage({ content }: UserMessageProps) { return (
-
+
{content}
diff --git a/web/frontend/src/components/config/config-sections.tsx b/web/frontend/src/components/config/config-sections.tsx index 5482b0a35..1f7426d22 100644 --- a/web/frontend/src/components/config/config-sections.tsx +++ b/web/frontend/src/components/config/config-sections.tsx @@ -1,3 +1,4 @@ +import { useState } from "react" import type { ReactNode } from "react" import { useTranslation } from "react-i18next" @@ -7,6 +8,7 @@ import { type LauncherForm, } from "@/components/config/form-model" import { Field, SwitchCardField } from "@/components/shared-form" +import { Button } from "@/components/ui/button" import { Card, CardContent, @@ -201,6 +203,56 @@ interface ExecSectionProps { export function ExecSection({ form, onFieldChange }: ExecSectionProps) { const { t } = useTranslation() + const [testCommand, setTestCommand] = useState("") + const [testResult, setTestResult] = useState<{ + allowed: boolean + blocked: boolean + matchedWhitelist: string | null + matchedBlacklist: string | null + } | null>(null) + const [isLoading, setIsLoading] = useState(false) + + const testPatterns = async () => { + if (!testCommand.trim()) { + setTestResult(null) + return + } + + const allowPatterns = form.customAllowPatternsText + .split("\n") + .map((p) => p.trim()) + .filter((p) => p.length > 0) + const denyPatterns = form.enableDenyPatterns + ? form.customDenyPatternsText + .split("\n") + .map((p) => p.trim()) + .filter((p) => p.length > 0) + : [] + + setIsLoading(true) + try { + const res = await fetch("/api/config/test-command-patterns", { + method: "POST", + headers: { "Content-Type": "application/json" }, + body: JSON.stringify({ + allow_patterns: allowPatterns, + deny_patterns: denyPatterns, + command: testCommand, + }), + }) + const data = await res.json() + setTestResult({ + allowed: data.allowed, + blocked: data.blocked, + matchedWhitelist: data.matched_whitelist ?? null, + matchedBlacklist: data.matched_blacklist ?? 
null, + }) + } catch { + setTestResult(null) + } finally { + setIsLoading(false) + } + } return ( @@ -266,6 +318,50 @@ export function ExecSection({ form, onFieldChange }: ExecSectionProps) { /> + +
+
+ setTestCommand(e.target.value)} + onKeyDown={(e) => { + if (e.key === "Enter") { + testPatterns() + } + }} + /> + +
+ {testResult && ( +
+ {testResult.allowed + ? `${t("pages.config.pattern_detector_result_allowed")}${testResult.matchedWhitelist ? ` (${testResult.matchedWhitelist})` : ""}` + : testResult.blocked + ? `${t("pages.config.pattern_detector_result_blocked")}${testResult.matchedBlacklist ? ` (${testResult.matchedBlacklist})` : ""}` + : t("pages.config.pattern_detector_result_no_match")} +
+ )} +
+
+ )} + {model.is_virtual && ( + + {t("models.badge.virtual")} + + )}
diff --git a/web/frontend/src/hooks/use-sidebar-channels.ts b/web/frontend/src/hooks/use-sidebar-channels.ts index 5579a955b..22fc24e57 100644 --- a/web/frontend/src/hooks/use-sidebar-channels.ts +++ b/web/frontend/src/hooks/use-sidebar-channels.ts @@ -28,15 +28,10 @@ import { getChannelDisplayName } from "@/components/channels/channel-display-nam import { gatewayAtom } from "@/store/gateway" const DEFAULT_VISIBLE_CHANNELS = 4 -const CHANNEL_IMPORTANCE_ORDER = [ - "discord", - "feishu", - "telegram", +const CHANNEL_IMPORTANCE_TAIL = [ "slack", "line", "wecom", - "wecom_app", - "wecom_aibot", "dingtalk", "qq", "onebot", @@ -47,9 +42,13 @@ const CHANNEL_IMPORTANCE_ORDER = [ "whatsapp", "whatsapp_native", ] -const CHANNEL_IMPORTANCE_INDEX = new Map( - CHANNEL_IMPORTANCE_ORDER.map((name, index) => [name, index]), -) + +function getChannelImportanceOrder(language: string): string[] { + const priority = language.startsWith("zh") + ? ["feishu", "weixin", "discord", "telegram"] + : ["discord", "telegram", "feishu", "weixin"] + return [...priority, ...CHANNEL_IMPORTANCE_TAIL] +} function IconLark({ className }: { className?: string }) { return React.createElement("span", { @@ -75,9 +74,8 @@ const CHANNEL_ICON_MAP: Record< dingtalk: IconBrandDingtalk, line: IconBrandLine, qq: IconBrandQq, + weixin: IconBrandWechat, wecom: IconBrandWechat, - wecom_app: IconBrandWechat, - wecom_aibot: IconBrandWechat, whatsapp: IconBrandWhatsapp, whatsapp_native: IconBrandWhatsapp, matrix: IconBrandMatrix, @@ -134,10 +132,11 @@ export interface SidebarChannelNavItem { } interface UseSidebarChannelsOptions { + language: string t: TFunction } -export function useSidebarChannels({ t }: UseSidebarChannelsOptions) { +export function useSidebarChannels({ language, t }: UseSidebarChannelsOptions) { const gateway = useAtomValue(gatewayAtom) const [channels, setChannels] = React.useState([]) const [enabledMap, setEnabledMap] = React.useState>( @@ -183,6 +182,12 @@ export function useSidebarChannels({ t 
}: UseSidebarChannelsOptions) { previousGatewayStatusRef.current = gateway.status }, [gateway.status, reloadChannels]) + const channelImportanceIndex = React.useMemo(() => { + return new Map( + getChannelImportanceOrder(language).map((name, index) => [name, index]), + ) + }, [language]) + const sortedChannels = React.useMemo(() => { const list = [...channels] list.sort((a, b) => { @@ -193,9 +198,9 @@ export function useSidebarChannels({ t }: UseSidebarChannelsOptions) { } const aImportance = - CHANNEL_IMPORTANCE_INDEX.get(a.name) ?? Number.MAX_SAFE_INTEGER + channelImportanceIndex.get(a.name) ?? Number.MAX_SAFE_INTEGER const bImportance = - CHANNEL_IMPORTANCE_INDEX.get(b.name) ?? Number.MAX_SAFE_INTEGER + channelImportanceIndex.get(b.name) ?? Number.MAX_SAFE_INTEGER if (aImportance !== bImportance) { return aImportance - bImportance } @@ -205,7 +210,7 @@ export function useSidebarChannels({ t }: UseSidebarChannelsOptions) { ) }) return list - }, [channels, enabledMap, t]) + }, [channelImportanceIndex, channels, enabledMap, t]) const hasMoreChannels = sortedChannels.length > DEFAULT_VISIBLE_CHANNELS const visibleChannels = showAllChannels diff --git a/web/frontend/src/i18n/locales/en.json b/web/frontend/src/i18n/locales/en.json index 66e39ad0e..eebf6e9fc 100644 --- a/web/frontend/src/i18n/locales/en.json +++ b/web/frontend/src/i18n/locales/en.json @@ -154,7 +154,8 @@ "unconfigured": "Not configured" }, "badge": { - "default": "Default" + "default": "Default", + "virtual": "Virtual" }, "action": { "edit": "Edit API key", @@ -233,14 +234,31 @@ "qq": "QQ", "onebot": "OneBot", "wecom": "WeCom", - "wecom_app": "WeCom App", - "wecom_aibot": "WeCom AI Bot", "whatsapp": "WhatsApp", "whatsapp_native": "WhatsApp Native", "pico": "Web", "maixcam": "MaixCam", "matrix": "Matrix", - "irc": "IRC" + "irc": "IRC", + "weixin": "WeChat" + }, + "weixin": { + "warningTitle": "Testing phase, use with caution", + "warningDesc": "The WeChat channel is still experimental and may carry a 
risk of account suspension. Use it only if you understand and accept the risk.", + "bindEnableSuccess": "WeChat connected and the channel has been enabled automatically.", + "bindTitle": "WeChat Account Binding", + "bindDesc": "Scan the QR code with WeChat to bind your personal account.", + "bind": "Bind WeChat", + "rebind": "Re-bind", + "bound": "WeChat Bound", + "notBound": "WeChat account not bound yet.", + "generating": "Generating QR code...", + "scanHint": "Open WeChat and scan the QR code", + "scanned": "Scanned — please confirm in WeChat", + "expired": "QR code expired", + "retry": "Try Again", + "refresh": "Refresh QR", + "errorGeneric": "An error occurred. Please try again." }, "field": { "token": "Bot Token", @@ -273,7 +291,9 @@ "saveError": "Failed to save channel configuration", "enabled": "enabled", "docLink": "Documentation", - "enableLabel": "Enable channel" + "enableLabel": "Enable channel", + "restartRequiredTitle": "Gateway restart required", + "restartRequiredDesc": "The latest {{name}} configuration has been saved. Restart the gateway for it to take effect." }, "form": { "desc": { @@ -413,6 +433,13 @@ "custom_allow_patterns": "Command Whitelist", "custom_allow_patterns_hint": "Add extra command-allow rules, one regular expression per line. 
A command matching any rule here skips blacklist matching, but other safety limits still apply.", "custom_patterns_placeholder": "^rm\\s+-rf\\b\n^git\\s+push\\b", + "pattern_detector_title": "Pattern Detection Tool", + "pattern_detector_hint": "Enter a command to test if it matches any blacklist or whitelist patterns.", + "pattern_detector_input_placeholder": "Enter a command to test, e.g., rm -rf /tmp", + "pattern_detector_test_button": "Test", + "pattern_detector_result_allowed": "Allowed (matches whitelist)", + "pattern_detector_result_blocked": "Blocked (matches blacklist)", + "pattern_detector_result_no_match": "No match (will use default rules)", "allow_shell_execution": "Allow Scheduled Commands", "allow_shell_execution_hint": "Allow scheduled tasks to run commands by default. When disabled, users must pass command_confirm=true to schedule a command task.", "cron_exec_timeout": "Scheduled Command Timeout (minutes)", diff --git a/web/frontend/src/i18n/locales/zh.json b/web/frontend/src/i18n/locales/zh.json index 65f2a5548..848bea15f 100644 --- a/web/frontend/src/i18n/locales/zh.json +++ b/web/frontend/src/i18n/locales/zh.json @@ -154,7 +154,8 @@ "unconfigured": "未配置" }, "badge": { - "default": "默认" + "default": "默认", + "virtual": "虚拟" }, "action": { "edit": "编辑 API Key", @@ -233,14 +234,31 @@ "qq": "QQ", "onebot": "OneBot", "wecom": "企业微信", - "wecom_app": "企业微信应用", - "wecom_aibot": "企业微信 AI 机器人", "whatsapp": "WhatsApp", "whatsapp_native": "WhatsApp Native", "pico": "Web", "maixcam": "MaixCam", "matrix": "Matrix", - "irc": "IRC" + "irc": "IRC", + "weixin": "微信" + }, + "weixin": { + "warningTitle": "测试阶段,请谨慎使用", + "warningDesc": "微信 Channel 当前仍处于测试阶段,存在封号风险。请仅在充分了解风险的前提下使用。", + "bindEnableSuccess": "微信已连接,频道已自动启用。", + "bindTitle": "微信账号绑定", + "bindDesc": "使用微信扫描二维码以绑定您的个人微信账号。", + "bind": "绑定微信", + "rebind": "重新绑定", + "bound": "微信已绑定", + "notBound": "尚未绑定微信账号。", + "generating": "正在生成二维码...", + "scanHint": "打开微信,扫描二维码", + "scanned": "已扫码 — 请在微信中确认", + "expired": 
"二维码已过期", + "retry": "重试", + "refresh": "刷新二维码", + "errorGeneric": "发生错误,请重试。" }, "field": { "token": "Bot Token", @@ -273,7 +291,9 @@ "saveError": "保存频道配置失败", "enabled": "已启用", "docLink": "配置文档", - "enableLabel": "启用频道" + "enableLabel": "启用频道", + "restartRequiredTitle": "需要重启服务", + "restartRequiredDesc": "{{name}} 的最新配置已保存。重启服务后才能正式生效。" }, "form": { "desc": { @@ -413,6 +433,13 @@ "custom_allow_patterns": "命令白名单", "custom_allow_patterns_hint": "用于补充额外的命令放行规则,每行一个正则表达式。命中任意一条规则的命令会跳过黑名单检查,但仍受其他安全限制约束。", "custom_patterns_placeholder": "^rm\\s+-rf\\b\n^git\\s+push\\b", + "pattern_detector_title": "规则检测工具", + "pattern_detector_hint": "输入命令以检测其是否匹配黑名单或白名单规则。", + "pattern_detector_input_placeholder": "输入要检测的命令,例如 rm -rf /tmp", + "pattern_detector_test_button": "检测", + "pattern_detector_result_allowed": "允许(匹配白名单)", + "pattern_detector_result_blocked": "阻止(匹配黑名单)", + "pattern_detector_result_no_match": "无匹配(将使用默认规则)", "allow_shell_execution": "允许定时任务运行命令", "allow_shell_execution_hint": "开启后,定时任务默认允许运行命令。关闭后,必须显式传入 command_confirm=true 才能创建运行命令的定时任务。", "cron_exec_timeout": "定时命令超时(分钟)",