From c3631d84ba5ab68c7cb008681e317f32c728fba7 Mon Sep 17 00:00:00 2001 From: Hoshina Date: Tue, 24 Mar 2026 16:12:28 +0800 Subject: [PATCH] feat(wecom): send media via temp uploads --- pkg/channels/wecom/media.go | 513 ++++++++++++++++++++++++++++++- pkg/channels/wecom/protocol.go | 64 +++- pkg/channels/wecom/wecom.go | 146 +++++++-- pkg/channels/wecom/wecom_test.go | 351 ++++++++++++++++++++- 4 files changed, 1037 insertions(+), 37 deletions(-) diff --git a/pkg/channels/wecom/media.go b/pkg/channels/wecom/media.go index defe226d4..ebcc481e8 100644 --- a/pkg/channels/wecom/media.go +++ b/pkg/channels/wecom/media.go @@ -4,7 +4,10 @@ import ( "context" "crypto/aes" "crypto/cipher" + "crypto/md5" "encoding/base64" + "encoding/hex" + "encoding/json" "fmt" "io" "mime" @@ -13,12 +16,73 @@ import ( "os" "path/filepath" "strings" + "time" "github.com/h2non/filetype" + "github.com/sipeed/picoclaw/pkg/bus" "github.com/sipeed/picoclaw/pkg/media" ) +const ( + wecomOutboundMediaMaxBytes = 20 << 20 + wecomOutboundImageMaxBytes = 2 << 20 + wecomOutboundVoiceMaxBytes = 2 << 20 + wecomOutboundVideoMaxBytes = 10 << 20 + wecomUploadChunkMaxBytes = 512 << 10 + wecomUploadMaxChunks = 100 + wecomUploadMinBytes = 5 +) + +type wecomOutboundMedia struct { + MsgType string + MediaID string + Title string + Description string +} + +func (m *wecomOutboundMedia) respondBody() wecomRespondMsgBody { + body := wecomRespondMsgBody{MsgType: m.MsgType} + switch m.MsgType { + case "file": + body.File = &wecomMediaRefContent{MediaID: m.MediaID} + case "image": + body.Image = &wecomMediaRefContent{MediaID: m.MediaID} + case "voice": + body.Voice = &wecomMediaRefContent{MediaID: m.MediaID} + case "video": + body.Video = &wecomVideoContent{ + MediaID: m.MediaID, + Title: m.Title, + Description: m.Description, + } + } + return body +} + +func (m *wecomOutboundMedia) sendBody(chatID string, chatType uint32) wecomSendMsgBody { + body := wecomSendMsgBody{ + ChatID: chatID, + ChatType: chatType, + MsgType: m.MsgType, + } + switch m.MsgType { + case "file": + body.File = &wecomMediaRefContent{MediaID: m.MediaID} + case "image": + body.Image = &wecomMediaRefContent{MediaID: m.MediaID} + case "voice": + body.Voice = &wecomMediaRefContent{MediaID: m.MediaID} + case "video": + body.Video = &wecomVideoContent{ + MediaID: m.MediaID, + Title: m.Title, + Description: m.Description, + } + } + return body +} + func decodeMediaAESKey(value string) ([]byte, error) { if value == "" { return nil, nil @@ -227,12 +291,11 @@ func (c *WeComChannel) storeRemoteMedia( return "", fmt.Errorf("download media returned HTTP %d", resp.StatusCode) } - const maxSize = 20 << 20 - data, err := io.ReadAll(io.LimitReader(resp.Body, maxSize+1)) + data, err := io.ReadAll(io.LimitReader(resp.Body, wecomOutboundMediaMaxBytes+1)) if err != nil { return "", fmt.Errorf("read media: %w", err) } - if len(data) > maxSize { + if len(data) > wecomOutboundMediaMaxBytes { return "", fmt.Errorf("media too large") } @@ -289,3 +352,447 @@ func (c *WeComChannel) storeRemoteMedia( } return ref, nil } + +func detectLocalWeComContentType(localPath, hint string) string { + contentType := normalizeWeComContentType(hint) + if !isGenericWeComContentType(contentType) { + return contentType + } + + if kind, err := filetype.MatchFile(localPath); err == nil && kind != filetype.Unknown { + return normalizeWeComContentType(kind.MIME.Value) + } + + if ext := strings.ToLower(filepath.Ext(localPath)); ext != "" { + if byExt := normalizeWeComContentType(mime.TypeByExtension(ext)); byExt != "" { + return byExt + } + } + + file, err := os.Open(localPath) + if err != nil { + return contentType + } + defer file.Close() + + buf := make([]byte, 512) + n, err := file.Read(buf) + if err != nil && err != io.EOF { + return contentType + } + if n == 0 { + return contentType + } + return normalizeWeComContentType(http.DetectContentType(buf[:n])) +} + +func writeWeComTempFile(prefix, filename string, data []byte) (string, error) { + mediaDir := filepath.Join(os.TempDir(), "picoclaw_media") + if err := os.MkdirAll(mediaDir, 0o700); err != nil { + return "", fmt.Errorf("mkdir media dir: %w", err) + } + + ext := strings.ToLower(filepath.Ext(filename)) + tmpFile, err := os.CreateTemp(mediaDir, prefix+"-*"+ext) + if err != nil { + return "", fmt.Errorf("create temp file: %w", err) + } + tmpPath := tmpFile.Name() + + if _, err := tmpFile.Write(data); err != nil { + _ = tmpFile.Close() + _ = os.Remove(tmpPath) + return "", fmt.Errorf("write temp file: %w", err) + } + if err := tmpFile.Close(); err != nil { + _ = os.Remove(tmpPath) + return "", fmt.Errorf("close temp file: %w", err) + } + return tmpPath, nil +} + +func (c *WeComChannel) downloadRemoteMediaToTemp( + ctx context.Context, + resourceURL, fallbackName string, +) (string, string, string, error) { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, resourceURL, nil) + if err != nil { + return "", "", "", fmt.Errorf("create request: %w", err) + } + + resp, err := c.mediaClient.Do(req) + if err != nil { + return "", "", "", fmt.Errorf("download media: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(io.LimitReader(resp.Body, 1024)) + return "", "", "", fmt.Errorf("download media returned HTTP %d: %s", resp.StatusCode, string(body)) + } + + data, err := io.ReadAll(io.LimitReader(resp.Body, wecomOutboundMediaMaxBytes+1)) + if err != nil { + return "", "", "", fmt.Errorf("read media: %w", err) + } + if len(data) > wecomOutboundMediaMaxBytes { + return "", "", "", fmt.Errorf("media too large") + } + + filename, contentType := detectWeComMediaMetadata( + data, + fallbackName, + resp.Header.Get("Content-Type"), + resourceURL, + resp.Header.Get("Content-Disposition"), + ) + tmpPath, err := writeWeComTempFile("wecom-outbound", filename, data) + if err != nil { + return "", "", "", err + } + return tmpPath, filename, contentType, nil +} + +func (c *WeComChannel) resolveOutboundPart( + ctx context.Context, + part bus.MediaPart, +) (string, string, string, func(), error) { + cleanup := func() {} + filename := sanitizeWeComFilename(part.Filename) + contentType := normalizeWeComContentType(part.ContentType) + ref := strings.TrimSpace(part.Ref) + + switch { + case ref == "": + return "", filename, contentType, cleanup, nil + + case strings.HasPrefix(ref, "http://") || strings.HasPrefix(ref, "https://"): + localPath, name, ct, err := c.downloadRemoteMediaToTemp(ctx, ref, filename) + if err != nil { + return "", "", "", cleanup, err + } + return localPath, name, ct, func() { _ = os.Remove(localPath) }, nil + + case strings.HasPrefix(ref, "media://"): + store := c.GetMediaStore() + if store == nil { + return "", "", "", cleanup, fmt.Errorf("no media store available") + } + + localPath, meta, err := store.ResolveWithMeta(ref) + if err != nil { + return "", "", "", cleanup, err + } + if filename == "" { + filename = sanitizeWeComFilename(meta.Filename) + } + if contentType == "" { + contentType = normalizeWeComContentType(meta.ContentType) + } + if strings.HasPrefix(localPath, "http://") || strings.HasPrefix(localPath, "https://") { + tmpPath, name, ct, err := c.downloadRemoteMediaToTemp(ctx, localPath, filename) + if err != nil { + return "", "", "", cleanup, err + } + return tmpPath, name, ct, func() { _ = os.Remove(tmpPath) }, nil + } + if _, err := os.Stat(localPath); err != nil { + return "", "", "", cleanup, err + } + if filename == "" { + filename = sanitizeWeComFilename(filepath.Base(localPath)) + } + if contentType == "" { + contentType = detectLocalWeComContentType(localPath, "") + } + return localPath, filename, contentType, cleanup, nil + + case strings.HasPrefix(ref, "file://"): + u, err := url.Parse(ref) + if err != nil { + return "", "", "", cleanup, err + } + localPath := u.Path + if _, err := os.Stat(localPath); err != nil { + return "", "", "", cleanup, err + } + if filename == "" { + filename = sanitizeWeComFilename(filepath.Base(localPath)) + } + if contentType == "" { + contentType = detectLocalWeComContentType(localPath, "") + } + return localPath, filename, contentType, cleanup, nil + + default: + if _, err := os.Stat(ref); err != nil { + return "", "", "", cleanup, err + } + if filename == "" { + filename = sanitizeWeComFilename(filepath.Base(ref)) + } + if contentType == "" { + contentType = detectLocalWeComContentType(ref, "") + } + return ref, filename, contentType, cleanup, nil + } +} + +func canWeComSendImage(contentType, ext string, size int64) bool { + if size > wecomOutboundImageMaxBytes { + return false + } + switch normalizeWeComContentType(contentType) { + case "image/jpeg", "image/jpg", "image/png", "image/gif": + return true + } + switch strings.ToLower(ext) { + case ".jpg", ".jpeg", ".png", ".gif": + return true + default: + return false + } +} + +func canWeComSendVoice(contentType, ext string, size int64) bool { + if size > wecomOutboundVoiceMaxBytes { + return false + } + contentType = normalizeWeComContentType(contentType) + return strings.Contains(contentType, "amr") || strings.EqualFold(ext, ".amr") +} + +func canWeComSendVideo(contentType, ext string, size int64) bool { + if size > wecomOutboundVideoMaxBytes { + return false + } + return normalizeWeComContentType(contentType) == "video/mp4" || strings.EqualFold(ext, ".mp4") +} + +func outboundWeComMediaKind(partType, filename, contentType string, size int64) string { + if size < wecomUploadMinBytes { + return "" + } + + partType = strings.ToLower(strings.TrimSpace(partType)) + contentType = normalizeWeComContentType(contentType) + ext := strings.ToLower(filepath.Ext(filename)) + + if partType == "file" { + if size <= wecomOutboundMediaMaxBytes { + return "file" + } + return "" + } + + if (partType == "image" || partType == "") && canWeComSendImage(contentType, ext, size) { + return "image" + } + if (partType == "audio" || partType == "voice" || partType == "") && canWeComSendVoice(contentType, ext, size) { + return "voice" + } + if (partType == "video" || partType == "") && canWeComSendVideo(contentType, ext, size) { + return "video" + } + if size <= wecomOutboundMediaMaxBytes { + return "file" + } + return "" +} + +func trimWeComBytes(value string, limit int) string { + value = strings.TrimSpace(value) + if limit <= 0 || len(value) <= limit { + return value + } + size := 0 + var out strings.Builder + for _, r := range value { + width := len(string(r)) + if size+width > limit { + break + } + size += width + out.WriteRune(r) + } + return out.String() +} + +func ensureWeComOutboundFilename(filename, localPath, contentType string) string { + filename = sanitizeWeComFilename(filename) + if filename == "" { + filename = sanitizeWeComFilename(filepath.Base(localPath)) + } + if filename == "" { + filename = "media" + } + if filepath.Ext(filename) == "" { + fallbackExt := inferMediaExt(contentType, strings.ToLower(filepath.Ext(localPath))) + if fallbackExt != "" { + filename += fallbackExt + } + } + filename = trimWeComBytes(filename, 256) + if filename == "" { + return "media" + } + return filename +} + +func buildWeComVideoContent(mediaID, filename, description string) *wecomVideoContent { + title := strings.TrimSuffix(filename, filepath.Ext(filename)) + title = trimWeComBytes(title, 64) + if title == "" { + title = "video" + } + description = trimWeComBytes(description, 512) + return &wecomVideoContent{ + MediaID: mediaID, + Title: title, + Description: description, + } +} + +func decodeWeComEnvelopeBody[T any](env wecomEnvelope) (T, error) { + var out T + if len(env.Body) == 0 { + return out, fmt.Errorf("wecom response body is empty") + } + if err := json.Unmarshal(env.Body, &out); err != nil { + return out, fmt.Errorf("decode wecom response body: %w", err) + } + return out, nil +} + +func (c *WeComChannel) uploadOutboundMedia( + ctx context.Context, + localPath, filename, contentType string, + part bus.MediaPart, +) (*wecomOutboundMedia, error) { + _ = ctx + + contentType = detectLocalWeComContentType(localPath, contentType) + filename = ensureWeComOutboundFilename(filename, localPath, contentType) + + data, err := os.ReadFile(localPath) + if err != nil { + return nil, fmt.Errorf("read media file: %w", err) + } + size := int64(len(data)) + kind := outboundWeComMediaKind(part.Type, filename, contentType, size) + if kind == "" { + return nil, fmt.Errorf("unsupported wecom media type or size for %q", filename) + } + + totalChunks := (len(data) + wecomUploadChunkMaxBytes - 1) / wecomUploadChunkMaxBytes + if totalChunks <= 0 || totalChunks > wecomUploadMaxChunks { + return nil, fmt.Errorf("wecom upload requires 1-%d chunks, got %d", wecomUploadMaxChunks, totalChunks) + } + + sum := md5.Sum(data) + initEnv, err := c.sendCommandAck(wecomCommand{ + Cmd: wecomCmdUploadMediaInit, + Headers: wecomHeaders{ReqID: randomID(10)}, + Body: wecomUploadMediaInitBody{ + Type: kind, + Filename: filename, + TotalSize: size, + TotalChunks: totalChunks, + MD5: hex.EncodeToString(sum[:]), + }, + }, wecomUploadTimeout) + if err != nil { + return nil, err + } + initResp, err := decodeWeComEnvelopeBody[wecomUploadMediaInitResponse](initEnv) + if err != nil { + return nil, err + } + if strings.TrimSpace(initResp.UploadID) == "" { + return nil, fmt.Errorf("wecom upload init returned empty upload_id") + } + + for idx, offset := 0, 0; offset < len(data); idx, offset = idx+1, offset+wecomUploadChunkMaxBytes { + end := offset + wecomUploadChunkMaxBytes + if end > len(data) { + end = len(data) + } + if err := c.sendCommand(wecomCommand{ + Cmd: wecomCmdUploadMediaChunk, + Headers: wecomHeaders{ReqID: randomID(10)}, + Body: wecomUploadMediaChunkBody{ + UploadID: initResp.UploadID, + ChunkIndex: idx, + Base64Data: base64.StdEncoding.EncodeToString(data[offset:end]), + }, + }, wecomUploadTimeout); err != nil { + return nil, err + } + } + + finishEnv, err := c.sendCommandAck(wecomCommand{ + Cmd: wecomCmdUploadMediaEnd, + Headers: wecomHeaders{ReqID: randomID(10)}, + Body: wecomUploadMediaFinishBody{ + UploadID: initResp.UploadID, + }, + }, wecomUploadTimeout) + if err != nil { + return nil, err + } + finishResp, err := decodeWeComEnvelopeBody[wecomUploadMediaFinishResponse](finishEnv) + if err != nil { + return nil, err + } + if strings.TrimSpace(finishResp.MediaID) == "" { + return nil, fmt.Errorf("wecom upload finish returned empty media_id") + } + + uploaded := &wecomOutboundMedia{ + MsgType: kind, + MediaID: finishResp.MediaID, + } + if kind == "video" { + video := buildWeComVideoContent(finishResp.MediaID, filename, part.Caption) + uploaded.Title = video.Title + uploaded.Description = video.Description + } + return uploaded, nil +} + +func fallbackWeComMediaText(part bus.MediaPart, kind, filename string) string { + var lines []string + if caption := strings.TrimSpace(part.Caption); caption != "" { + lines = append(lines, caption) + } + + label := kind + if label == "" { + label = "media" + } + if filename != "" { + lines = append(lines, fmt.Sprintf("[%s: %s]", label, filename)) + } else { + lines = append(lines, fmt.Sprintf("[%s attachment]", label)) + } + + ref := strings.TrimSpace(part.Ref) + if strings.HasPrefix(ref, "http://") || strings.HasPrefix(ref, "https://") { + lines = append(lines, ref) + } + + return strings.Join(lines, "\n") +} + +func (c *WeComChannel) resolveMediaRoute(chatID string) (wecomTurn, uint32, bool) { + if turn, ok := c.getTurn(chatID); ok { + if time.Since(turn.CreatedAt) <= wecomStreamMaxDuration { + return turn, turn.ChatType, true + } + c.deleteTurn(chatID) + } + if route, ok := c.routes.Get(chatID); ok { + return wecomTurn{ChatID: route.ChatID, ChatType: route.ChatType}, route.ChatType, false + } + return wecomTurn{ChatID: chatID}, 0, false +} diff --git a/pkg/channels/wecom/protocol.go b/pkg/channels/wecom/protocol.go index 6867d8856..0190e70e5 100644 --- a/pkg/channels/wecom/protocol.go +++ b/pkg/channels/wecom/protocol.go @@ -10,6 +10,9 @@ const ( wecomCmdEventCallback = "aibot_event_callback" wecomCmdRespondMsg = "aibot_respond_msg" wecomCmdSendMsg = "aibot_send_msg" + wecomCmdUploadMediaInit = "aibot_upload_media_init" + wecomCmdUploadMediaChunk = "aibot_upload_media_chunk" + wecomCmdUploadMediaEnd = "aibot_upload_media_finish" wecomMaxContentBytes = 20480 ) @@ -32,15 +35,26 @@ type wecomCommand struct { } type wecomSendMsgBody struct { - ChatID string `json:"chatid"` - ChatType uint32 `json:"chat_type,omitempty"` - MsgType string `json:"msgtype"` - Markdown *wecomMarkdownContent `json:"markdown,omitempty"` + ChatID string `json:"chatid"` + ChatType uint32 `json:"chat_type,omitempty"` + MsgType string `json:"msgtype"` + Markdown *wecomMarkdownContent `json:"markdown,omitempty"` + File *wecomMediaRefContent `json:"file,omitempty"` + Image *wecomMediaRefContent `json:"image,omitempty"` + Voice *wecomMediaRefContent `json:"voice,omitempty"` + Video *wecomVideoContent `json:"video,omitempty"` + TemplateCard map[string]any `json:"template_card,omitempty"` } type wecomRespondMsgBody struct { - MsgType string `json:"msgtype"` - Stream *wecomStreamContent `json:"stream,omitempty"` + MsgType string `json:"msgtype"` + Stream *wecomStreamContent `json:"stream,omitempty"` + Markdown *wecomMarkdownContent `json:"markdown,omitempty"` + File *wecomMediaRefContent `json:"file,omitempty"` + Image *wecomMediaRefContent `json:"image,omitempty"` + Voice *wecomMediaRefContent `json:"voice,omitempty"` + Video *wecomVideoContent `json:"video,omitempty"` + TemplateCard map[string]any `json:"template_card,omitempty"` } type wecomStreamContent struct { @@ -53,6 +67,44 @@ type wecomMarkdownContent struct { Content string `json:"content"` } +type wecomMediaRefContent struct { + MediaID string `json:"media_id"` +} + +type wecomVideoContent struct { + MediaID string `json:"media_id"` + Title string `json:"title,omitempty"` + Description string `json:"description,omitempty"` +} + +type wecomUploadMediaInitBody struct { + Type string `json:"type"` + Filename string `json:"filename"` + TotalSize int64 `json:"total_size"` + TotalChunks int `json:"total_chunks"` + MD5 string `json:"md5,omitempty"` +} + +type wecomUploadMediaInitResponse struct { + UploadID string `json:"upload_id"` +} + +type wecomUploadMediaChunkBody struct { + UploadID string `json:"upload_id"` + ChunkIndex int `json:"chunk_index"` + Base64Data string `json:"base64_data"` +} + +type wecomUploadMediaFinishBody struct { + UploadID string `json:"upload_id"` +} + +type wecomUploadMediaFinishResponse struct { + Type string `json:"type"` + MediaID string `json:"media_id"` + CreatedAt json.RawMessage `json:"created_at"` +} + type wecomIncomingMessage struct { MsgID string `json:"msgid"` AIBotID string `json:"aibotid"` diff --git a/pkg/channels/wecom/wecom.go b/pkg/channels/wecom/wecom.go index 11959c259..ac8f8d9c8 100644 --- a/pkg/channels/wecom/wecom.go +++ b/pkg/channels/wecom/wecom.go @@ -23,6 +23,7 @@ import ( const ( wecomConnectTimeout = 15 * time.Second wecomCommandTimeout = 10 * time.Second + wecomUploadTimeout = 30 * time.Second wecomHeartbeatInterval = 30 * time.Second wecomStreamMaxDuration = 5*time.Minute + 30*time.Second wecomRouteTTL = 30 * time.Minute @@ -49,7 +50,7 @@ type WeComChannel struct { recent *recentMessageSet routes *reqIDStore mediaClient *http.Client - commandSend func(wecomCommand, time.Duration) error + commandSend func(wecomCommand, time.Duration) (wecomEnvelope, error) } type wecomTurn struct { @@ -187,22 +188,74 @@ func (c *WeComChannel) SendMedia(ctx context.Context, msg bus.OutboundMediaMessa if !c.IsRunning() { return channels.ErrNotRunning } - var parts []string + + route, chatType, hasTurn := c.resolveMediaRoute(msg.ChatID) + chatID := route.ChatID + if chatID == "" { + chatID = msg.ChatID + } + for _, part := range msg.Parts { - switch { - case part.Caption != "": - parts = append(parts, part.Caption) - case part.Filename != "": - parts = append(parts, fmt.Sprintf("[media: %s]", part.Filename)) - default: - parts = append(parts, "[media attachments are not yet supported]") + if strings.TrimSpace(part.Ref) == "" { + if caption := strings.TrimSpace(part.Caption); caption != "" { + if err := c.sendActivePush(chatID, chatType, caption); err != nil { + return err + } + } + continue + } + + localPath, filename, contentType, cleanup, err := c.resolveOutboundPart(ctx, part) + if err != nil { + return fmt.Errorf("wecom resolve media %q: %v: %w", part.Ref, err, channels.ErrSendFailed) + } + + func() { + if cleanup != nil { + defer cleanup() + } + + uploaded, uploadErr := c.uploadOutboundMedia(ctx, localPath, filename, contentType, part) + if uploadErr != nil { + logger.WarnCF("wecom", "Falling back to placeholder after media upload failure", map[string]any{ + "chat_id": chatID, + "ref": part.Ref, + "filename": filename, + "content_type": contentType, + "error": uploadErr.Error(), + }) + if hasTurn { + if finishErr := c.sendStreamChunk(route, true, ""); finishErr != nil { + err = finishErr + return + } + c.deleteTurn(msg.ChatID) + hasTurn = false + } + err = c.sendActivePush(chatID, chatType, fallbackWeComMediaText(part, "", filename)) + return + } + + if hasTurn { + err = c.sendTurnMedia(route, uploaded) + c.deleteTurn(msg.ChatID) + hasTurn = false + } else { + err = c.sendActiveMedia(chatID, chatType, uploaded) + } + if err != nil { + return + } + if caption := strings.TrimSpace(part.Caption); caption != "" { + err = c.sendActivePush(chatID, chatType, caption) + } + }() + if err != nil { + return err } } - return c.Send(ctx, bus.OutboundMessage{ - Channel: msg.Channel, - ChatID: msg.ChatID, - Content: strings.Join(parts, "\n"), - }) + + return nil } func (c *WeComChannel) connectLoop() { @@ -620,6 +673,20 @@ func (c *WeComChannel) sendStreamChunk(turn wecomTurn, finish bool, content stri }, wecomCommandTimeout) } +func (c *WeComChannel) sendTurnMedia(turn wecomTurn, uploaded *wecomOutboundMedia) error { + if uploaded == nil { + return fmt.Errorf("wecom outbound media is nil: %w", channels.ErrSendFailed) + } + if err := c.sendCommand(wecomCommand{ + Cmd: wecomCmdRespondMsg, + Headers: wecomHeaders{ReqID: turn.ReqID}, + Body: uploaded.respondBody(), + }, wecomCommandTimeout); err != nil { + return err + } + return c.sendStreamChunk(turn, true, "") +} + func (c *WeComChannel) sendActivePush(chatID string, chatType uint32, content string) error { if strings.TrimSpace(chatID) == "" { return fmt.Errorf("empty chat ID: %w", channels.ErrSendFailed) @@ -641,24 +708,57 @@ func (c *WeComChannel) sendActivePush(chatID string, chatType uint32, content st return nil } +func (c *WeComChannel) sendActiveMedia(chatID string, chatType uint32, uploaded *wecomOutboundMedia) error { + if strings.TrimSpace(chatID) == "" { + return fmt.Errorf("empty chat ID: %w", channels.ErrSendFailed) + } + if uploaded == nil { + return fmt.Errorf("wecom outbound media is nil: %w", channels.ErrSendFailed) + } + return c.sendCommand(wecomCommand{ + Cmd: wecomCmdSendMsg, + Headers: wecomHeaders{ReqID: randomID(10)}, + Body: uploaded.sendBody(chatID, chatType), + }, wecomCommandTimeout) +} + func (c *WeComChannel) sendCommand(cmd wecomCommand, timeout time.Duration) error { + _, err := c.sendCommandAck(cmd, timeout) + return err +} + +func (c *WeComChannel) sendCommandAck(cmd wecomCommand, timeout time.Duration) (wecomEnvelope, error) { if c.commandSend != nil { return c.commandSend(cmd, timeout) } - return c.writeCurrent(cmd, timeout) + return c.writeCurrentAck(cmd, timeout) } func (c *WeComChannel) writeCurrent(cmd wecomCommand, timeout time.Duration) error { + _, err := c.writeCurrentAck(cmd, timeout) + return err +} + +func (c *WeComChannel) writeCurrentAck(cmd wecomCommand, timeout time.Duration) (wecomEnvelope, error) { c.connMu.Lock() conn := c.conn c.connMu.Unlock() if conn == nil { - return fmt.Errorf("wecom websocket not connected: %w", channels.ErrTemporary) + return wecomEnvelope{}, fmt.Errorf("wecom websocket not connected: %w", channels.ErrTemporary) } - return c.writeAndWait(conn, cmd, timeout) + return c.writeAndWaitAck(conn, cmd, timeout) } func (c *WeComChannel) writeAndWait(conn *websocket.Conn, cmd wecomCommand, timeout time.Duration) error { + _, err := c.writeAndWaitAck(conn, cmd, timeout) + return err +} + +func (c *WeComChannel) writeAndWaitAck( + conn *websocket.Conn, + cmd wecomCommand, + timeout time.Duration, +) (wecomEnvelope, error) { if cmd.Headers.ReqID == "" { cmd.Headers.ReqID = randomID(10) } @@ -674,13 +774,13 @@ func (c *WeComChannel) writeAndWait(conn *websocket.Conn, cmd wecomCommand, time data, err := json.Marshal(cmd) if err != nil { - return fmt.Errorf("%w: %v", channels.ErrSendFailed, err) + return wecomEnvelope{}, fmt.Errorf("%w: %v", channels.ErrSendFailed, err) } c.connMu.Lock() err = conn.WriteMessage(websocket.TextMessage, data) c.connMu.Unlock() if err != nil { - return fmt.Errorf("%w: %v", channels.ErrTemporary, err) + return wecomEnvelope{}, fmt.Errorf("%w: %v", channels.ErrTemporary, err) } timer := time.NewTimer(timeout) @@ -688,13 +788,13 @@ func (c *WeComChannel) writeAndWait(conn *websocket.Conn, cmd wecomCommand, time select { case env := <-waitCh: if env.ErrCode != 0 { - return fmt.Errorf("%w: wecom errcode=%d errmsg=%s", channels.ErrTemporary, env.ErrCode, env.ErrMsg) + return wecomEnvelope{}, fmt.Errorf("%w: wecom errcode=%d errmsg=%s", channels.ErrTemporary, env.ErrCode, env.ErrMsg) } - return nil + return env, nil case <-timer.C: - return fmt.Errorf("%w: timeout waiting for WeCom ack", channels.ErrTemporary) + return wecomEnvelope{}, fmt.Errorf("%w: timeout waiting for WeCom ack", channels.ErrTemporary) case <-c.ctx.Done(): - return c.ctx.Err() + return wecomEnvelope{}, c.ctx.Err() } } diff --git a/pkg/channels/wecom/wecom_test.go b/pkg/channels/wecom/wecom_test.go index e0ee2e628..45176015f 100644 --- a/pkg/channels/wecom/wecom_test.go +++ b/pkg/channels/wecom/wecom_test.go @@ -2,13 +2,16 @@ package wecom import ( "context" + "encoding/json" "errors" + "os" "path/filepath" "testing" "time" "github.com/sipeed/picoclaw/pkg/bus" "github.com/sipeed/picoclaw/pkg/config" + "github.com/sipeed/picoclaw/pkg/media" ) func TestDispatchIncoming_UsesActualChatIDAndStoresReqIDRoute(t *testing.T) { @@ -18,9 +21,9 @@ func TestDispatchIncoming_UsesActualChatIDAndStoresReqIDRoute(t *testing.T) { ch := newTestWeComChannel(t, messageBus) var commands []wecomCommand - ch.commandSend = func(cmd wecomCommand, _ time.Duration) error { + ch.commandSend = func(cmd wecomCommand, _ time.Duration) (wecomEnvelope, error) { commands = append(commands, cmd) - return nil + return wecomTestAck(nil), nil } msg := wecomIncomingMessage{ @@ -107,12 +110,12 @@ func TestSend_StreamFailureFallsBackToActualChatID(t *testing.T) { } var commands []wecomCommand - ch.commandSend = func(cmd wecomCommand, _ time.Duration) error { + ch.commandSend = func(cmd wecomCommand, _ time.Duration) (wecomEnvelope, error) { commands = append(commands, cmd) if len(commands) == 1 && cmd.Cmd == wecomCmdRespondMsg { - return errors.New("stream send failed") + return wecomEnvelope{}, errors.New("stream send failed") } - return nil + return wecomTestAck(nil), nil } if err := ch.Send(context.Background(), bus.OutboundMessage{ @@ -152,6 +155,301 @@ func TestSend_StreamFailureFallsBackToActualChatID(t *testing.T) { } } +func TestSendMedia_SendsActiveImage(t *testing.T) { + t.Parallel() + + ch := newTestWeComChannel(t, bus.NewMessageBus()) + ch.SetRunning(true) + + store := media.NewFileMediaStore() + ch.SetMediaStore(store) + + imageData := wecomTestJPEGData(t) + imagePath := filepath.Join(t.TempDir(), "photo.jpg") + if err := os.WriteFile(imagePath, imageData, 0o600); err != nil { + t.Fatalf("WriteFile() error = %v", err) + } + ref, err := store.Store(imagePath, media.MediaMeta{ + Filename: "photo.jpg", + ContentType: "image/jpeg", + Source: "test", + CleanupPolicy: media.CleanupPolicyForgetOnly, + }, "scope-1") + if err != nil { + t.Fatalf("Store() error = %v", err) + } + + var commands []wecomCommand + ch.commandSend = func(cmd wecomCommand, _ time.Duration) (wecomEnvelope, error) { + commands = append(commands, cmd) + switch cmd.Cmd { + case wecomCmdUploadMediaInit: + return wecomTestAck(wecomUploadMediaInitResponse{UploadID: "upload-1"}), nil + case wecomCmdUploadMediaEnd: + return wecomTestAck(wecomUploadMediaFinishResponse{ + Type: "image", + MediaID: "media-1", + }), nil + default: + return wecomTestAck(nil), nil + } + } + + err = ch.SendMedia(context.Background(), bus.OutboundMediaMessage{ + Channel: "wecom", + ChatID: "chat-1", + Parts: []bus.MediaPart{{ + Ref: ref, + Type: "image", + Filename: "photo.jpg", + ContentType: "image/jpeg", + }}, + }) + if err != nil { + t.Fatalf("SendMedia() error = %v", err) + } + + if len(commands) != 4 { + t.Fatalf("expected 4 commands, got %d", len(commands)) + } + if commands[0].Cmd != wecomCmdUploadMediaInit { + t.Fatalf("first command = %q, want %q", commands[0].Cmd, wecomCmdUploadMediaInit) + } + initBody, ok := commands[0].Body.(wecomUploadMediaInitBody) + if !ok { + t.Fatalf("unexpected init body type %T", commands[0].Body) + } + if initBody.Type != "image" || initBody.Filename != "photo.jpg" || initBody.TotalChunks != 1 { + t.Fatalf("init body = %+v", initBody) + } + if commands[1].Cmd != wecomCmdUploadMediaChunk { + t.Fatalf("second command = %q, want %q", commands[1].Cmd, wecomCmdUploadMediaChunk) + } + chunkBody, ok := commands[1].Body.(wecomUploadMediaChunkBody) + if !ok { + t.Fatalf("unexpected chunk body type %T", commands[1].Body) + } + if chunkBody.UploadID != "upload-1" || chunkBody.ChunkIndex != 0 || chunkBody.Base64Data == "" { + t.Fatalf("chunk body = %+v", chunkBody) + } + if commands[2].Cmd != wecomCmdUploadMediaEnd { + t.Fatalf("third command = %q, want %q", commands[2].Cmd, wecomCmdUploadMediaEnd) + } + if commands[3].Cmd != wecomCmdSendMsg { + t.Fatalf("fourth command = %q, want %q", commands[3].Cmd, wecomCmdSendMsg) + } + + body, ok := commands[3].Body.(wecomSendMsgBody) + if !ok { + t.Fatalf("unexpected send body type %T", commands[3].Body) + } + if body.MsgType != "image" || body.Image == nil { + t.Fatalf("send body = %+v", body) + } + if body.ChatID != "chat-1" { + t.Fatalf("send chatid = %q, want chat-1", body.ChatID) + } + if body.Image.MediaID != "media-1" { + t.Fatalf("image media_id = %q, want media-1", body.Image.MediaID) + } +} + +func TestSendMedia_UsesTurnImageAndFinishesStream(t *testing.T) { + t.Parallel() + + ch := newTestWeComChannel(t, bus.NewMessageBus()) + ch.SetRunning(true) + + store := media.NewFileMediaStore() + ch.SetMediaStore(store) + + imageData := wecomTestJPEGData(t) + imagePath := filepath.Join(t.TempDir(), "reply.jpg") + if err := os.WriteFile(imagePath, imageData, 0o600); err != nil { + t.Fatalf("WriteFile() error = %v", err) + } + ref, err := store.Store(imagePath, media.MediaMeta{ + Filename: "reply.jpg", + ContentType: "image/jpeg", + Source: "test", + CleanupPolicy: media.CleanupPolicyForgetOnly, + }, "scope-2") + if err != nil { + t.Fatalf("Store() error = %v", err) + } + + ch.queueTurn("chat-1", wecomTurn{ + ReqID: "req-1", + ChatID: "chat-1", + ChatType: 1, + StreamID: "stream-1", + CreatedAt: time.Now(), + }) + if err := ch.routes.Put("chat-1", "req-1", 1, time.Hour); err != nil { + t.Fatalf("Put() error = %v", err) + } + + var commands []wecomCommand + ch.commandSend = func(cmd wecomCommand, _ time.Duration) (wecomEnvelope, error) { + commands = append(commands, cmd) + switch cmd.Cmd { + case wecomCmdUploadMediaInit: + return wecomTestAck(wecomUploadMediaInitResponse{UploadID: "upload-2"}), nil + case wecomCmdUploadMediaEnd: + return wecomTestAck(wecomUploadMediaFinishResponse{ + Type: "image", + MediaID: "media-2", + }), nil + default: + return wecomTestAck(nil), nil + } + } + + err = ch.SendMedia(context.Background(), bus.OutboundMediaMessage{ + Channel: "wecom", + ChatID: "chat-1", + Parts: []bus.MediaPart{{ + Ref: ref, + Type: "image", + Filename: "reply.jpg", + ContentType: "image/jpeg", + }}, + }) + if err != nil { + t.Fatalf("SendMedia() error = %v", err) + } + + if len(commands) != 5 { + t.Fatalf("expected 5 commands, got %d", len(commands)) + } + if commands[0].Cmd != wecomCmdUploadMediaInit { + t.Fatalf("first command = %+v", commands[0]) + } + if commands[1].Cmd != wecomCmdUploadMediaChunk { + t.Fatalf("second command = %+v", commands[1]) + } + if commands[2].Cmd != wecomCmdUploadMediaEnd { + t.Fatalf("third command = %+v", commands[2]) + } + if commands[3].Cmd != wecomCmdRespondMsg || commands[3].Headers.ReqID != "req-1" { + t.Fatalf("fourth command = %+v", commands[3]) + } + if commands[4].Cmd != wecomCmdRespondMsg || commands[4].Headers.ReqID != "req-1" { + t.Fatalf("fifth command = %+v", commands[4]) + } + + imageBody, ok := commands[3].Body.(wecomRespondMsgBody) + if !ok { + t.Fatalf("unexpected image body type %T", commands[3].Body) + } + if imageBody.MsgType != "image" || imageBody.Image == nil { + t.Fatalf("image body = %+v", imageBody) + } + if imageBody.Image.MediaID != "media-2" { + t.Fatalf("image media_id = %q, want media-2", imageBody.Image.MediaID) + } + + streamBody, ok := commands[4].Body.(wecomRespondMsgBody) + if !ok { + t.Fatalf("unexpected finish body type %T", commands[4].Body) + } + if streamBody.MsgType != "stream" || streamBody.Stream == nil || !streamBody.Stream.Finish { + t.Fatalf("finish body = %+v", streamBody) + } + + if _, ok := ch.getTurn("chat-1"); ok { + t.Fatal("expected turn to be removed after media send") + } +} + +func TestSendMedia_SendsActiveFile(t *testing.T) { + t.Parallel() + + ch := newTestWeComChannel(t, bus.NewMessageBus()) + ch.SetRunning(true) + + store := media.NewFileMediaStore() + ch.SetMediaStore(store) + + filePath := filepath.Join(t.TempDir(), "report.pdf") + if err := os.WriteFile(filePath, []byte("%PDF-1.4"), 0o600); err != nil { + t.Fatalf("WriteFile() error = %v", err) + } + ref, err := store.Store(filePath, media.MediaMeta{ + Filename: "report.pdf", + ContentType: "application/pdf", + Source: "test", + CleanupPolicy: media.CleanupPolicyForgetOnly, + }, "scope-3") + if err != nil { + t.Fatalf("Store() error = %v", err) + } + + var commands []wecomCommand + ch.commandSend = func(cmd wecomCommand, _ time.Duration) (wecomEnvelope, error) { + commands = append(commands, cmd) + switch cmd.Cmd { + case wecomCmdUploadMediaInit: + return wecomTestAck(wecomUploadMediaInitResponse{UploadID: "upload-3"}), nil + case wecomCmdUploadMediaEnd: + return wecomTestAck(wecomUploadMediaFinishResponse{ + Type: "file", + MediaID: "media-3", + }), nil + default: + return wecomTestAck(nil), nil + } + } + + err = ch.SendMedia(context.Background(), bus.OutboundMediaMessage{ + Channel: "wecom", + ChatID: "chat-2", + Parts: []bus.MediaPart{{ + Ref: ref, + Type: "file", + Filename: "report.pdf", + ContentType: "application/pdf", + }}, + }) + if err != nil { + t.Fatalf("SendMedia() error = %v", err) + } + + if len(commands) != 4 { + t.Fatalf("expected 4 commands, got %d", len(commands)) + } + if commands[0].Cmd != wecomCmdUploadMediaInit { + t.Fatalf("first command = %q, want %q", commands[0].Cmd, wecomCmdUploadMediaInit) + } + initBody, ok := commands[0].Body.(wecomUploadMediaInitBody) + if !ok { + t.Fatalf("unexpected init body type %T", commands[0].Body) + } + if initBody.Type != "file" || initBody.Filename != "report.pdf" { + t.Fatalf("init body = %+v", initBody) + } + if commands[1].Cmd != wecomCmdUploadMediaChunk { + t.Fatalf("second command = %q, want %q", commands[1].Cmd, wecomCmdUploadMediaChunk) + } + if commands[2].Cmd != wecomCmdUploadMediaEnd { + t.Fatalf("third command = %q, want %q", commands[2].Cmd, wecomCmdUploadMediaEnd) + } + if commands[3].Cmd != wecomCmdSendMsg { + t.Fatalf("fourth command = %q, want %q", commands[3].Cmd, wecomCmdSendMsg) + } + + body, ok := commands[3].Body.(wecomSendMsgBody) + if !ok { + t.Fatalf("unexpected body type %T", commands[3].Body) + } + if body.MsgType != "file" || body.File == nil { + t.Fatalf("body = %+v", body) + } + if body.File.MediaID != "media-3" { + t.Fatalf("file media_id = %q, want media-3", body.File.MediaID) + } +} + func newTestWeComChannel(t *testing.T, messageBus *bus.MessageBus) *WeComChannel { t.Helper() @@ -165,3 +463,46 @@ func newTestWeComChannel(t *testing.T, messageBus *bus.MessageBus) *WeComChannel ch.routes = newReqIDStore(filepath.Join(t.TempDir(), "reqids.json")) return ch } + +func wecomTestJPEGData(t *testing.T) []byte { + t.Helper() + + const jpegBase64 = "/9j/4AAQSkZJRgABAQAAAQABAAD/2wBDAP//////////////////////////////////////////////////////////////////////////////////////" + + "//////////////////////////////////////////////////////////////////////////////////////////////2wBDAf//////////////////////////////////////////////////////////////////////////////////////" + + "//////////////////////////////////////////////////////////////////////////////////////////////wAARCAABAAEDASIAAhEBAxEB/8QAFQABAQAAAAAAAAAAAAAAAAAAAAb/xAAVEQEBAAAAAAAAAAAAAAAAAAAABf/aAAwDAQACEAMQAAAB6A//xAAVEAEBAAAAAAAAAAAAAAAAAAAAEf/aAAgBAQABBQJf/8QAFBEBAAAAAAAAAAAAAAAAAAAAEP/aAAgBAwEBPwF//8QAFBEBAAAAAAAAAAAAAAAAAAAAEP/aAAgBAgEBPwF//8QAFBABAAAAAAAAAAAAAAAAAAAAEP/aAAgBAQAGPwJf/8QAFBABAAAAAAAAAAAAAAAAAAAAEP/aAAgBAQABPyFf/9k=" + + return decodeTestBase64(t, jpegBase64) +} + +func TestDecodeWeComUploadFinish_AcceptsNumericCreatedAt(t *testing.T) { + t.Parallel() + + resp, err := decodeWeComEnvelopeBody[wecomUploadMediaFinishResponse](wecomEnvelope{ + Body: json.RawMessage(`{"type":"file","media_id":"media-1","created_at":1380000000}`), + }) + if err != nil { + t.Fatalf("decodeWeComEnvelopeBody() error = %v", err) + } + if resp.Type != "file" || resp.MediaID != "media-1" { + t.Fatalf("response = %+v", resp) + } + if string(resp.CreatedAt) != "1380000000" { + t.Fatalf("created_at = %s, want 1380000000", string(resp.CreatedAt)) + } +} + +func wecomTestAck(body any) wecomEnvelope { + var raw []byte + if body != nil { + encoded, err := json.Marshal(body) + if err != nil { + panic(err) + } + raw = encoded + } + return wecomEnvelope{ + ErrCode: 0, + ErrMsg: "ok", + Body: raw, + } +}