feat(wecom): add channel-side streaming support

This commit is contained in:
Hoshina
2026-03-24 20:17:16 +08:00
parent 11b6b10d59
commit 3b498d2e4b
3 changed files with 294 additions and 51 deletions
-1
View File
@@ -13,7 +13,6 @@ const (
wecomCmdUploadMediaInit = "aibot_upload_media_init"
wecomCmdUploadMediaChunk = "aibot_upload_media_chunk"
wecomCmdUploadMediaEnd = "aibot_upload_media_finish"
wecomMaxContentBytes = 20480
)
type wecomEnvelope struct {
+143 -50
View File
@@ -26,6 +26,7 @@ const (
wecomUploadTimeout = 30 * time.Second
wecomHeartbeatInterval = 30 * time.Second
wecomStreamMaxDuration = 5*time.Minute + 30*time.Second
wecomStreamMinInterval = 500 * time.Millisecond
wecomRouteTTL = 30 * time.Minute
wecomMediaTimeout = 30 * time.Second
wecomRecentMessageMax = 1000
@@ -61,6 +62,17 @@ type wecomTurn struct {
CreatedAt time.Time
}
type wecomStreamer struct {
channel *WeComChannel
chatID string
turn wecomTurn
mu sync.Mutex
closed bool
lastSentAt time.Time
content string
}
type recentMessageSet struct {
mu sync.Mutex
seen map[string]struct{}
@@ -109,7 +121,6 @@ func NewChannel(cfg config.WeComConfig, messageBus *bus.MessageBus) (*WeComChann
cfg,
messageBus,
cfg.AllowFrom,
channels.WithMaxMessageLength(wecomMaxContentBytes),
channels.WithReasoningChannelID(cfg.ReasoningChannelID),
)
@@ -152,6 +163,27 @@ func (c *WeComChannel) Stop(_ context.Context) error {
return nil
}
func (c *WeComChannel) BeginStream(_ context.Context, chatID string) (channels.Streamer, error) {
if !c.IsRunning() {
return nil, channels.ErrNotRunning
}
turn, ok := c.getTurn(chatID)
if !ok {
return nil, fmt.Errorf("wecom streaming unavailable: no active turn")
}
if time.Since(turn.CreatedAt) > wecomStreamMaxDuration {
c.consumeTurn(chatID, turn)
return nil, fmt.Errorf("wecom streaming unavailable: turn expired")
}
return &wecomStreamer{
channel: c,
chatID: chatID,
turn: turn,
}, nil
}
func (c *WeComChannel) Send(ctx context.Context, msg bus.OutboundMessage) error {
if !c.IsRunning() {
return channels.ErrNotRunning
@@ -164,11 +196,11 @@ func (c *WeComChannel) Send(ctx context.Context, msg bus.OutboundMessage) error
if turn, ok := c.getTurn(msg.ChatID); ok {
if time.Since(turn.CreatedAt) <= wecomStreamMaxDuration {
if err := c.sendStreamReply(turn, content); err == nil {
c.deleteTurn(msg.ChatID)
c.consumeTurn(msg.ChatID, turn)
return nil
}
}
c.deleteTurn(msg.ChatID)
c.consumeTurn(msg.ChatID, turn)
}
if route, ok := c.routes.Get(msg.ChatID); ok {
@@ -649,13 +681,7 @@ func (c *WeComChannel) respondImmediate(reqID, content string) error {
}
func (c *WeComChannel) sendStreamReply(turn wecomTurn, content string) error {
chunks := splitContent(content, wecomMaxContentBytes)
for idx, chunk := range chunks {
if err := c.sendStreamChunk(turn, idx == len(chunks)-1, chunk); err != nil {
return err
}
}
return nil
return c.sendStreamChunk(turn, true, content)
}
func (c *WeComChannel) sendStreamChunk(turn wecomTurn, finish bool, content string) error {
@@ -691,21 +717,16 @@ func (c *WeComChannel) sendActivePush(chatID string, chatType uint32, content st
if strings.TrimSpace(chatID) == "" {
return fmt.Errorf("empty chat ID: %w", channels.ErrSendFailed)
}
for _, chunk := range splitContent(content, wecomMaxContentBytes) {
if err := c.sendCommand(wecomCommand{
Cmd: wecomCmdSendMsg,
Headers: wecomHeaders{ReqID: randomID(10)},
Body: wecomSendMsgBody{
ChatID: chatID,
ChatType: chatType,
MsgType: "markdown",
Markdown: &wecomMarkdownContent{Content: chunk},
},
}, wecomCommandTimeout); err != nil {
return err
}
}
return nil
return c.sendCommand(wecomCommand{
Cmd: wecomCmdSendMsg,
Headers: wecomHeaders{ReqID: randomID(10)},
Body: wecomSendMsgBody{
ChatID: chatID,
ChatType: chatType,
MsgType: "markdown",
Markdown: &wecomMarkdownContent{Content: content},
},
}, wecomCommandTimeout)
}
func (c *WeComChannel) sendActiveMedia(chatID string, chatType uint32, uploaded *wecomOutboundMedia) error {
@@ -825,6 +846,26 @@ func (c *WeComChannel) queueTurn(chatID string, turn wecomTurn) {
c.turns[chatID] = append(c.turns[chatID], turn)
}
func (c *WeComChannel) consumeTurn(chatID string, turn wecomTurn) bool {
c.turnsMu.Lock()
defer c.turnsMu.Unlock()
queue := c.turns[chatID]
if len(queue) == 0 {
return false
}
current := queue[0]
if current.ReqID != turn.ReqID || current.StreamID != turn.StreamID {
return false
}
if len(queue) == 1 {
delete(c.turns, chatID)
return true
}
c.turns[chatID] = queue[1:]
return true
}
func (c *WeComChannel) clearTurns() {
c.turnsMu.Lock()
c.turns = make(map[string][]wecomTurn)
@@ -844,34 +885,86 @@ func randomID(n int) string {
return string(buf)
}
func splitContent(content string, maxBytes int) []string {
if content == "" {
return []string{""}
func (s *wecomStreamer) Update(ctx context.Context, content string) error {
s.mu.Lock()
defer s.mu.Unlock()
if s.closed {
return nil
}
if len(content) <= maxBytes {
return []string{content}
if err := s.validateActiveTurn(); err != nil {
return err
}
chunks := channels.SplitMessage(content, maxBytes)
var result []string
for _, chunk := range chunks {
if len(chunk) <= maxBytes {
result = append(result, chunk)
continue
}
for len(chunk) > maxBytes {
end := maxBytes
for end > 0 && chunk[end]>>6 == 0b10 {
end--
if err := ctx.Err(); err != nil {
return err
}
if !s.lastSentAt.IsZero() {
wait := time.Until(s.lastSentAt.Add(wecomStreamMinInterval))
if wait > 0 {
timer := time.NewTimer(wait)
defer timer.Stop()
select {
case <-ctx.Done():
return ctx.Err()
case <-timer.C:
}
if end == 0 {
end = maxBytes
}
result = append(result, chunk[:end])
chunk = strings.TrimLeft(chunk[end:], " \t\r\n")
}
if chunk != "" {
result = append(result, chunk)
}
}
return result
if err := s.channel.sendStreamChunk(s.turn, false, content); err != nil {
return err
}
s.content = content
s.lastSentAt = time.Now()
return nil
}
func (s *wecomStreamer) Finalize(ctx context.Context, content string) error {
s.mu.Lock()
defer s.mu.Unlock()
if s.closed {
return nil
}
if err := s.validateActiveTurn(); err != nil {
return err
}
if err := ctx.Err(); err != nil {
return err
}
if err := s.channel.sendStreamChunk(s.turn, true, content); err != nil {
return err
}
s.content = content
s.closed = true
s.channel.consumeTurn(s.chatID, s.turn)
return nil
}
func (s *wecomStreamer) Cancel(_ context.Context) {
s.mu.Lock()
defer s.mu.Unlock()
if s.closed {
return
}
if s.validateActiveTurn() == nil {
_ = s.channel.sendStreamChunk(s.turn, true, s.content)
s.channel.consumeTurn(s.chatID, s.turn)
}
s.closed = true
}
func (s *wecomStreamer) validateActiveTurn() error {
if time.Since(s.turn.CreatedAt) > wecomStreamMaxDuration {
s.channel.consumeTurn(s.chatID, s.turn)
return fmt.Errorf("wecom streaming unavailable: turn expired")
}
current, ok := s.channel.getTurn(s.chatID)
if !ok || current.ReqID != s.turn.ReqID || current.StreamID != s.turn.StreamID {
return fmt.Errorf("wecom streaming unavailable: turn no longer active")
}
return nil
}
+151
View File
@@ -6,6 +6,7 @@ import (
"errors"
"os"
"path/filepath"
"strings"
"testing"
"time"
@@ -86,6 +87,77 @@ func TestDispatchIncoming_UsesActualChatIDAndStoresReqIDRoute(t *testing.T) {
}
}
func TestNewChannel_DoesNotRegisterMessageSplitLimit(t *testing.T) {
t.Parallel()
ch := newTestWeComChannel(t, bus.NewMessageBus())
if got := ch.MaxMessageLength(); got != 0 {
t.Fatalf("MaxMessageLength() = %d, want 0", got)
}
}
func TestBeginStream_UpdateAndFinalize(t *testing.T) {
t.Parallel()
ch := newTestWeComChannel(t, bus.NewMessageBus())
ch.SetRunning(true)
ch.queueTurn("chat-1", wecomTurn{
ReqID: "req-1",
ChatID: "chat-1",
ChatType: 1,
StreamID: "stream-1",
CreatedAt: time.Now(),
})
var commands []wecomCommand
ch.commandSend = func(cmd wecomCommand, _ time.Duration) (wecomEnvelope, error) {
commands = append(commands, cmd)
return wecomTestAck(nil), nil
}
streamer, err := ch.BeginStream(context.Background(), "chat-1")
if err != nil {
t.Fatalf("BeginStream() error = %v", err)
}
if err := streamer.Update(context.Background(), "draft"); err != nil {
t.Fatalf("Update() error = %v", err)
}
if err := streamer.Finalize(context.Background(), "final"); err != nil {
t.Fatalf("Finalize() error = %v", err)
}
if len(commands) != 2 {
t.Fatalf("expected 2 commands, got %d", len(commands))
}
for i, wantFinish := range []bool{false, true} {
if commands[i].Cmd != wecomCmdRespondMsg {
t.Fatalf("command[%d].Cmd = %q, want %q", i, commands[i].Cmd, wecomCmdRespondMsg)
}
body, ok := commands[i].Body.(wecomRespondMsgBody)
if !ok {
t.Fatalf("command[%d] body type = %T", i, commands[i].Body)
}
if body.Stream == nil {
t.Fatalf("command[%d] missing stream body", i)
}
if body.Stream.ID != "stream-1" {
t.Fatalf("command[%d] stream id = %q, want stream-1", i, body.Stream.ID)
}
if body.Stream.Finish != wantFinish {
t.Fatalf("command[%d] finish = %v, want %v", i, body.Stream.Finish, wantFinish)
}
}
if body := commands[0].Body.(wecomRespondMsgBody); body.Stream.Content != "draft" {
t.Fatalf("update content = %q, want draft", body.Stream.Content)
}
if body := commands[1].Body.(wecomRespondMsgBody); body.Stream.Content != "final" {
t.Fatalf("final content = %q, want final", body.Stream.Content)
}
if _, ok := ch.getTurn("chat-1"); ok {
t.Fatal("expected turn to be consumed after Finalize")
}
}
func TestSend_StreamFailureFallsBackToActualChatID(t *testing.T) {
t.Parallel()
@@ -155,6 +227,85 @@ func TestSend_StreamFailureFallsBackToActualChatID(t *testing.T) {
}
}
func TestSend_DoesNotSplitStreamReply(t *testing.T) {
t.Parallel()
ch := newTestWeComChannel(t, bus.NewMessageBus())
ch.SetRunning(true)
ch.queueTurn("chat-1", wecomTurn{
ReqID: "req-1",
ChatID: "chat-1",
ChatType: 1,
StreamID: "stream-1",
CreatedAt: time.Now(),
})
var commands []wecomCommand
ch.commandSend = func(cmd wecomCommand, _ time.Duration) (wecomEnvelope, error) {
commands = append(commands, cmd)
return wecomTestAck(nil), nil
}
content := strings.Repeat("\u4e2d", 30000)
if err := ch.Send(context.Background(), bus.OutboundMessage{
Channel: "wecom",
ChatID: "chat-1",
Content: content,
}); err != nil {
t.Fatalf("Send() error = %v", err)
}
if len(commands) != 1 {
t.Fatalf("expected 1 stream command, got %d", len(commands))
}
body, ok := commands[0].Body.(wecomRespondMsgBody)
if !ok {
t.Fatalf("unexpected body type %T", commands[0].Body)
}
if body.Stream == nil || !body.Stream.Finish {
t.Fatalf("stream body = %+v", body.Stream)
}
if body.Stream.Content != content {
t.Fatalf("stream content length = %d, want %d", len(body.Stream.Content), len(content))
}
}
func TestSend_DoesNotSplitActivePush(t *testing.T) {
t.Parallel()
ch := newTestWeComChannel(t, bus.NewMessageBus())
ch.SetRunning(true)
var commands []wecomCommand
ch.commandSend = func(cmd wecomCommand, _ time.Duration) (wecomEnvelope, error) {
commands = append(commands, cmd)
return wecomTestAck(nil), nil
}
content := strings.Repeat("a", 30000)
if err := ch.Send(context.Background(), bus.OutboundMessage{
Channel: "wecom",
ChatID: "chat-1",
Content: content,
}); err != nil {
t.Fatalf("Send() error = %v", err)
}
if len(commands) != 1 {
t.Fatalf("expected 1 send command, got %d", len(commands))
}
if commands[0].Cmd != wecomCmdSendMsg {
t.Fatalf("command = %q, want %q", commands[0].Cmd, wecomCmdSendMsg)
}
body, ok := commands[0].Body.(wecomSendMsgBody)
if !ok {
t.Fatalf("unexpected body type %T", commands[0].Body)
}
if body.Markdown == nil || body.Markdown.Content != content {
t.Fatalf("markdown content length = %d, want %d", len(body.Markdown.Content), len(content))
}
}
func TestSendMedia_SendsActiveImage(t *testing.T) {
t.Parallel()