From 3b498d2e4b2b991c4c33cbc55e04d01cec307bbb Mon Sep 17 00:00:00 2001 From: Hoshina Date: Tue, 24 Mar 2026 20:17:16 +0800 Subject: [PATCH] feat(wecom): add channel-side streaming support --- pkg/channels/wecom/protocol.go | 1 - pkg/channels/wecom/wecom.go | 193 +++++++++++++++++++++++-------- pkg/channels/wecom/wecom_test.go | 151 ++++++++++++++++++++++++ 3 files changed, 294 insertions(+), 51 deletions(-) diff --git a/pkg/channels/wecom/protocol.go b/pkg/channels/wecom/protocol.go index 0190e70e5..f42ce3bf4 100644 --- a/pkg/channels/wecom/protocol.go +++ b/pkg/channels/wecom/protocol.go @@ -13,7 +13,6 @@ const ( wecomCmdUploadMediaInit = "aibot_upload_media_init" wecomCmdUploadMediaChunk = "aibot_upload_media_chunk" wecomCmdUploadMediaEnd = "aibot_upload_media_finish" - wecomMaxContentBytes = 20480 ) type wecomEnvelope struct { diff --git a/pkg/channels/wecom/wecom.go b/pkg/channels/wecom/wecom.go index 075c1732f..26e971921 100644 --- a/pkg/channels/wecom/wecom.go +++ b/pkg/channels/wecom/wecom.go @@ -26,6 +26,7 @@ const ( wecomUploadTimeout = 30 * time.Second wecomHeartbeatInterval = 30 * time.Second wecomStreamMaxDuration = 5*time.Minute + 30*time.Second + wecomStreamMinInterval = 500 * time.Millisecond wecomRouteTTL = 30 * time.Minute wecomMediaTimeout = 30 * time.Second wecomRecentMessageMax = 1000 @@ -61,6 +62,17 @@ type wecomTurn struct { CreatedAt time.Time } +type wecomStreamer struct { + channel *WeComChannel + chatID string + turn wecomTurn + + mu sync.Mutex + closed bool + lastSentAt time.Time + content string +} + type recentMessageSet struct { mu sync.Mutex seen map[string]struct{} @@ -109,7 +121,6 @@ func NewChannel(cfg config.WeComConfig, messageBus *bus.MessageBus) (*WeComChann cfg, messageBus, cfg.AllowFrom, - channels.WithMaxMessageLength(wecomMaxContentBytes), channels.WithReasoningChannelID(cfg.ReasoningChannelID), ) @@ -152,6 +163,27 @@ func (c *WeComChannel) Stop(_ context.Context) error { return nil } +func (c *WeComChannel) BeginStream(_ context.Context, chatID string) (channels.Streamer, error) { + if !c.IsRunning() { + return nil, channels.ErrNotRunning + } + + turn, ok := c.getTurn(chatID) + if !ok { + return nil, fmt.Errorf("wecom streaming unavailable: no active turn") + } + if time.Since(turn.CreatedAt) > wecomStreamMaxDuration { + c.consumeTurn(chatID, turn) + return nil, fmt.Errorf("wecom streaming unavailable: turn expired") + } + + return &wecomStreamer{ + channel: c, + chatID: chatID, + turn: turn, + }, nil +} + func (c *WeComChannel) Send(ctx context.Context, msg bus.OutboundMessage) error { if !c.IsRunning() { return channels.ErrNotRunning @@ -164,11 +196,11 @@ func (c *WeComChannel) Send(ctx context.Context, msg bus.OutboundMessage) error if turn, ok := c.getTurn(msg.ChatID); ok { if time.Since(turn.CreatedAt) <= wecomStreamMaxDuration { if err := c.sendStreamReply(turn, content); err == nil { - c.deleteTurn(msg.ChatID) + c.consumeTurn(msg.ChatID, turn) return nil } } - c.deleteTurn(msg.ChatID) + c.consumeTurn(msg.ChatID, turn) } if route, ok := c.routes.Get(msg.ChatID); ok { @@ -649,13 +681,7 @@ func (c *WeComChannel) respondImmediate(reqID, content string) error { } func (c *WeComChannel) sendStreamReply(turn wecomTurn, content string) error { - chunks := splitContent(content, wecomMaxContentBytes) - for idx, chunk := range chunks { - if err := c.sendStreamChunk(turn, idx == len(chunks)-1, chunk); err != nil { - return err - } - } - return nil + return c.sendStreamChunk(turn, true, content) } func (c *WeComChannel) sendStreamChunk(turn wecomTurn, finish bool, content string) error { @@ -691,21 +717,16 @@ func (c *WeComChannel) sendActivePush(chatID string, chatType uint32, content st if strings.TrimSpace(chatID) == "" { return fmt.Errorf("empty chat ID: %w", channels.ErrSendFailed) } - for _, chunk := range splitContent(content, wecomMaxContentBytes) { - if err := c.sendCommand(wecomCommand{ - Cmd: wecomCmdSendMsg, - Headers: wecomHeaders{ReqID: randomID(10)}, - Body: wecomSendMsgBody{ - ChatID: chatID, - ChatType: chatType, - MsgType: "markdown", - Markdown: &wecomMarkdownContent{Content: chunk}, - }, - }, wecomCommandTimeout); err != nil { - return err - } - } - return nil + return c.sendCommand(wecomCommand{ + Cmd: wecomCmdSendMsg, + Headers: wecomHeaders{ReqID: randomID(10)}, + Body: wecomSendMsgBody{ + ChatID: chatID, + ChatType: chatType, + MsgType: "markdown", + Markdown: &wecomMarkdownContent{Content: content}, + }, + }, wecomCommandTimeout) } func (c *WeComChannel) sendActiveMedia(chatID string, chatType uint32, uploaded *wecomOutboundMedia) error { @@ -825,6 +846,26 @@ func (c *WeComChannel) queueTurn(chatID string, turn wecomTurn) { c.turns[chatID] = append(c.turns[chatID], turn) } +func (c *WeComChannel) consumeTurn(chatID string, turn wecomTurn) bool { + c.turnsMu.Lock() + defer c.turnsMu.Unlock() + + queue := c.turns[chatID] + if len(queue) == 0 { + return false + } + current := queue[0] + if current.ReqID != turn.ReqID || current.StreamID != turn.StreamID { + return false + } + if len(queue) == 1 { + delete(c.turns, chatID) + return true + } + c.turns[chatID] = queue[1:] + return true +} + func (c *WeComChannel) clearTurns() { c.turnsMu.Lock() c.turns = make(map[string][]wecomTurn) @@ -844,34 +885,86 @@ func randomID(n int) string { return string(buf) } -func splitContent(content string, maxBytes int) []string { - if content == "" { - return []string{""} +func (s *wecomStreamer) Update(ctx context.Context, content string) error { + s.mu.Lock() + defer s.mu.Unlock() + + if s.closed { + return nil } - if len(content) <= maxBytes { - return []string{content} + if err := s.validateActiveTurn(); err != nil { + return err } - chunks := channels.SplitMessage(content, maxBytes) - var result []string - for _, chunk := range chunks { - if len(chunk) <= maxBytes { - result = append(result, chunk) - continue - } - for len(chunk) > maxBytes { - end := maxBytes - for end > 0 && chunk[end]>>6 == 0b10 { - end-- + if err := ctx.Err(); err != nil { + return err + } + + if !s.lastSentAt.IsZero() { + wait := time.Until(s.lastSentAt.Add(wecomStreamMinInterval)) + if wait > 0 { + timer := time.NewTimer(wait) + defer timer.Stop() + select { + case <-ctx.Done(): + return ctx.Err() + case <-timer.C: } - if end == 0 { - end = maxBytes - } - result = append(result, chunk[:end]) - chunk = strings.TrimLeft(chunk[end:], " \t\r\n") - } - if chunk != "" { - result = append(result, chunk) } } - return result + + if err := s.channel.sendStreamChunk(s.turn, false, content); err != nil { + return err + } + s.content = content + s.lastSentAt = time.Now() + return nil +} + +func (s *wecomStreamer) Finalize(ctx context.Context, content string) error { + s.mu.Lock() + defer s.mu.Unlock() + + if s.closed { + return nil + } + if err := s.validateActiveTurn(); err != nil { + return err + } + if err := ctx.Err(); err != nil { + return err + } + if err := s.channel.sendStreamChunk(s.turn, true, content); err != nil { + return err + } + + s.content = content + s.closed = true + s.channel.consumeTurn(s.chatID, s.turn) + return nil +} + +func (s *wecomStreamer) Cancel(_ context.Context) { + s.mu.Lock() + defer s.mu.Unlock() + + if s.closed { + return + } + if s.validateActiveTurn() == nil { + _ = s.channel.sendStreamChunk(s.turn, true, s.content) + s.channel.consumeTurn(s.chatID, s.turn) + } + s.closed = true +} + +func (s *wecomStreamer) validateActiveTurn() error { + if time.Since(s.turn.CreatedAt) > wecomStreamMaxDuration { + s.channel.consumeTurn(s.chatID, s.turn) + return fmt.Errorf("wecom streaming unavailable: turn expired") + } + current, ok := s.channel.getTurn(s.chatID) + if !ok || current.ReqID != s.turn.ReqID || current.StreamID != s.turn.StreamID { + return fmt.Errorf("wecom streaming unavailable: turn no longer active") + } + return nil } diff --git a/pkg/channels/wecom/wecom_test.go b/pkg/channels/wecom/wecom_test.go index 478423307..c7a4adfc0 100644 --- a/pkg/channels/wecom/wecom_test.go +++ b/pkg/channels/wecom/wecom_test.go @@ -6,6 +6,7 @@ import ( "errors" "os" "path/filepath" + "strings" "testing" "time" @@ -86,6 +87,77 @@ func TestDispatchIncoming_UsesActualChatIDAndStoresReqIDRoute(t *testing.T) { } } +func TestNewChannel_DoesNotRegisterMessageSplitLimit(t *testing.T) { + t.Parallel() + + ch := newTestWeComChannel(t, bus.NewMessageBus()) + if got := ch.MaxMessageLength(); got != 0 { + t.Fatalf("MaxMessageLength() = %d, want 0", got) + } +} + +func TestBeginStream_UpdateAndFinalize(t *testing.T) { + t.Parallel() + + ch := newTestWeComChannel(t, bus.NewMessageBus()) + ch.SetRunning(true) + ch.queueTurn("chat-1", wecomTurn{ + ReqID: "req-1", + ChatID: "chat-1", + ChatType: 1, + StreamID: "stream-1", + CreatedAt: time.Now(), + }) + + var commands []wecomCommand + ch.commandSend = func(cmd wecomCommand, _ time.Duration) (wecomEnvelope, error) { + commands = append(commands, cmd) + return wecomTestAck(nil), nil + } + + streamer, err := ch.BeginStream(context.Background(), "chat-1") + if err != nil { + t.Fatalf("BeginStream() error = %v", err) + } + if err := streamer.Update(context.Background(), "draft"); err != nil { + t.Fatalf("Update() error = %v", err) + } + if err := streamer.Finalize(context.Background(), "final"); err != nil { + t.Fatalf("Finalize() error = %v", err) + } + + if len(commands) != 2 { + t.Fatalf("expected 2 commands, got %d", len(commands)) + } + for i, wantFinish := range []bool{false, true} { + if commands[i].Cmd != wecomCmdRespondMsg { + t.Fatalf("command[%d].Cmd = %q, want %q", i, commands[i].Cmd, wecomCmdRespondMsg) + } + body, ok := commands[i].Body.(wecomRespondMsgBody) + if !ok { + t.Fatalf("command[%d] body type = %T", i, commands[i].Body) + } + if body.Stream == nil { + t.Fatalf("command[%d] missing stream body", i) + } + if body.Stream.ID != "stream-1" { + t.Fatalf("command[%d] stream id = %q, want stream-1", i, body.Stream.ID) + } + if body.Stream.Finish != wantFinish { + t.Fatalf("command[%d] finish = %v, want %v", i, body.Stream.Finish, wantFinish) + } + } + if body := commands[0].Body.(wecomRespondMsgBody); body.Stream.Content != "draft" { + t.Fatalf("update content = %q, want draft", body.Stream.Content) + } + if body := commands[1].Body.(wecomRespondMsgBody); body.Stream.Content != "final" { + t.Fatalf("final content = %q, want final", body.Stream.Content) + } + if _, ok := ch.getTurn("chat-1"); ok { + t.Fatal("expected turn to be consumed after Finalize") + } +} + func TestSend_StreamFailureFallsBackToActualChatID(t *testing.T) { t.Parallel() @@ -155,6 +227,85 @@ func TestSend_StreamFailureFallsBackToActualChatID(t *testing.T) { } } +func TestSend_DoesNotSplitStreamReply(t *testing.T) { + t.Parallel() + + ch := newTestWeComChannel(t, bus.NewMessageBus()) + ch.SetRunning(true) + ch.queueTurn("chat-1", wecomTurn{ + ReqID: "req-1", + ChatID: "chat-1", + ChatType: 1, + StreamID: "stream-1", + CreatedAt: time.Now(), + }) + + var commands []wecomCommand + ch.commandSend = func(cmd wecomCommand, _ time.Duration) (wecomEnvelope, error) { + commands = append(commands, cmd) + return wecomTestAck(nil), nil + } + + content := strings.Repeat("\u4e2d", 30000) + if err := ch.Send(context.Background(), bus.OutboundMessage{ + Channel: "wecom", + ChatID: "chat-1", + Content: content, + }); err != nil { + t.Fatalf("Send() error = %v", err) + } + + if len(commands) != 1 { + t.Fatalf("expected 1 stream command, got %d", len(commands)) + } + body, ok := commands[0].Body.(wecomRespondMsgBody) + if !ok { + t.Fatalf("unexpected body type %T", commands[0].Body) + } + if body.Stream == nil || !body.Stream.Finish { + t.Fatalf("stream body = %+v", body.Stream) + } + if body.Stream.Content != content { + t.Fatalf("stream content length = %d, want %d", len(body.Stream.Content), len(content)) + } +} + +func TestSend_DoesNotSplitActivePush(t *testing.T) { + t.Parallel() + + ch := newTestWeComChannel(t, bus.NewMessageBus()) + ch.SetRunning(true) + + var commands []wecomCommand + ch.commandSend = func(cmd wecomCommand, _ time.Duration) (wecomEnvelope, error) { + commands = append(commands, cmd) + return wecomTestAck(nil), nil + } + + content := strings.Repeat("a", 30000) + if err := ch.Send(context.Background(), bus.OutboundMessage{ + Channel: "wecom", + ChatID: "chat-1", + Content: content, + }); err != nil { + t.Fatalf("Send() error = %v", err) + } + + if len(commands) != 1 { + t.Fatalf("expected 1 send command, got %d", len(commands)) + } + if commands[0].Cmd != wecomCmdSendMsg { + t.Fatalf("command = %q, want %q", commands[0].Cmd, wecomCmdSendMsg) + } + body, ok := commands[0].Body.(wecomSendMsgBody) + if !ok { + t.Fatalf("unexpected body type %T", commands[0].Body) + } + if body.Markdown == nil || body.Markdown.Content != content { + t.Fatalf("markdown content length = %d, want %d", len(body.Markdown.Content), len(content)) + } +} + func TestSendMedia_SendsActiveImage(t *testing.T) { t.Parallel()