mirror of
https://github.com/sipeed/picoclaw.git
synced 2026-06-12 18:08:54 +00:00
feat(wecom): add channel-side streaming support
This commit is contained in:
@@ -13,7 +13,6 @@ const (
|
||||
wecomCmdUploadMediaInit = "aibot_upload_media_init"
|
||||
wecomCmdUploadMediaChunk = "aibot_upload_media_chunk"
|
||||
wecomCmdUploadMediaEnd = "aibot_upload_media_finish"
|
||||
wecomMaxContentBytes = 20480
|
||||
)
|
||||
|
||||
type wecomEnvelope struct {
|
||||
|
||||
+143
-50
@@ -26,6 +26,7 @@ const (
|
||||
wecomUploadTimeout = 30 * time.Second
|
||||
wecomHeartbeatInterval = 30 * time.Second
|
||||
wecomStreamMaxDuration = 5*time.Minute + 30*time.Second
|
||||
wecomStreamMinInterval = 500 * time.Millisecond
|
||||
wecomRouteTTL = 30 * time.Minute
|
||||
wecomMediaTimeout = 30 * time.Second
|
||||
wecomRecentMessageMax = 1000
|
||||
@@ -61,6 +62,17 @@ type wecomTurn struct {
|
||||
CreatedAt time.Time
|
||||
}
|
||||
|
||||
type wecomStreamer struct {
|
||||
channel *WeComChannel
|
||||
chatID string
|
||||
turn wecomTurn
|
||||
|
||||
mu sync.Mutex
|
||||
closed bool
|
||||
lastSentAt time.Time
|
||||
content string
|
||||
}
|
||||
|
||||
type recentMessageSet struct {
|
||||
mu sync.Mutex
|
||||
seen map[string]struct{}
|
||||
@@ -109,7 +121,6 @@ func NewChannel(cfg config.WeComConfig, messageBus *bus.MessageBus) (*WeComChann
|
||||
cfg,
|
||||
messageBus,
|
||||
cfg.AllowFrom,
|
||||
channels.WithMaxMessageLength(wecomMaxContentBytes),
|
||||
channels.WithReasoningChannelID(cfg.ReasoningChannelID),
|
||||
)
|
||||
|
||||
@@ -152,6 +163,27 @@ func (c *WeComChannel) Stop(_ context.Context) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *WeComChannel) BeginStream(_ context.Context, chatID string) (channels.Streamer, error) {
|
||||
if !c.IsRunning() {
|
||||
return nil, channels.ErrNotRunning
|
||||
}
|
||||
|
||||
turn, ok := c.getTurn(chatID)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("wecom streaming unavailable: no active turn")
|
||||
}
|
||||
if time.Since(turn.CreatedAt) > wecomStreamMaxDuration {
|
||||
c.consumeTurn(chatID, turn)
|
||||
return nil, fmt.Errorf("wecom streaming unavailable: turn expired")
|
||||
}
|
||||
|
||||
return &wecomStreamer{
|
||||
channel: c,
|
||||
chatID: chatID,
|
||||
turn: turn,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (c *WeComChannel) Send(ctx context.Context, msg bus.OutboundMessage) error {
|
||||
if !c.IsRunning() {
|
||||
return channels.ErrNotRunning
|
||||
@@ -164,11 +196,11 @@ func (c *WeComChannel) Send(ctx context.Context, msg bus.OutboundMessage) error
|
||||
if turn, ok := c.getTurn(msg.ChatID); ok {
|
||||
if time.Since(turn.CreatedAt) <= wecomStreamMaxDuration {
|
||||
if err := c.sendStreamReply(turn, content); err == nil {
|
||||
c.deleteTurn(msg.ChatID)
|
||||
c.consumeTurn(msg.ChatID, turn)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
c.deleteTurn(msg.ChatID)
|
||||
c.consumeTurn(msg.ChatID, turn)
|
||||
}
|
||||
|
||||
if route, ok := c.routes.Get(msg.ChatID); ok {
|
||||
@@ -649,13 +681,7 @@ func (c *WeComChannel) respondImmediate(reqID, content string) error {
|
||||
}
|
||||
|
||||
func (c *WeComChannel) sendStreamReply(turn wecomTurn, content string) error {
|
||||
chunks := splitContent(content, wecomMaxContentBytes)
|
||||
for idx, chunk := range chunks {
|
||||
if err := c.sendStreamChunk(turn, idx == len(chunks)-1, chunk); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
return c.sendStreamChunk(turn, true, content)
|
||||
}
|
||||
|
||||
func (c *WeComChannel) sendStreamChunk(turn wecomTurn, finish bool, content string) error {
|
||||
@@ -691,21 +717,16 @@ func (c *WeComChannel) sendActivePush(chatID string, chatType uint32, content st
|
||||
if strings.TrimSpace(chatID) == "" {
|
||||
return fmt.Errorf("empty chat ID: %w", channels.ErrSendFailed)
|
||||
}
|
||||
for _, chunk := range splitContent(content, wecomMaxContentBytes) {
|
||||
if err := c.sendCommand(wecomCommand{
|
||||
Cmd: wecomCmdSendMsg,
|
||||
Headers: wecomHeaders{ReqID: randomID(10)},
|
||||
Body: wecomSendMsgBody{
|
||||
ChatID: chatID,
|
||||
ChatType: chatType,
|
||||
MsgType: "markdown",
|
||||
Markdown: &wecomMarkdownContent{Content: chunk},
|
||||
},
|
||||
}, wecomCommandTimeout); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
return c.sendCommand(wecomCommand{
|
||||
Cmd: wecomCmdSendMsg,
|
||||
Headers: wecomHeaders{ReqID: randomID(10)},
|
||||
Body: wecomSendMsgBody{
|
||||
ChatID: chatID,
|
||||
ChatType: chatType,
|
||||
MsgType: "markdown",
|
||||
Markdown: &wecomMarkdownContent{Content: content},
|
||||
},
|
||||
}, wecomCommandTimeout)
|
||||
}
|
||||
|
||||
func (c *WeComChannel) sendActiveMedia(chatID string, chatType uint32, uploaded *wecomOutboundMedia) error {
|
||||
@@ -825,6 +846,26 @@ func (c *WeComChannel) queueTurn(chatID string, turn wecomTurn) {
|
||||
c.turns[chatID] = append(c.turns[chatID], turn)
|
||||
}
|
||||
|
||||
func (c *WeComChannel) consumeTurn(chatID string, turn wecomTurn) bool {
|
||||
c.turnsMu.Lock()
|
||||
defer c.turnsMu.Unlock()
|
||||
|
||||
queue := c.turns[chatID]
|
||||
if len(queue) == 0 {
|
||||
return false
|
||||
}
|
||||
current := queue[0]
|
||||
if current.ReqID != turn.ReqID || current.StreamID != turn.StreamID {
|
||||
return false
|
||||
}
|
||||
if len(queue) == 1 {
|
||||
delete(c.turns, chatID)
|
||||
return true
|
||||
}
|
||||
c.turns[chatID] = queue[1:]
|
||||
return true
|
||||
}
|
||||
|
||||
func (c *WeComChannel) clearTurns() {
|
||||
c.turnsMu.Lock()
|
||||
c.turns = make(map[string][]wecomTurn)
|
||||
@@ -844,34 +885,86 @@ func randomID(n int) string {
|
||||
return string(buf)
|
||||
}
|
||||
|
||||
func splitContent(content string, maxBytes int) []string {
|
||||
if content == "" {
|
||||
return []string{""}
|
||||
func (s *wecomStreamer) Update(ctx context.Context, content string) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
if s.closed {
|
||||
return nil
|
||||
}
|
||||
if len(content) <= maxBytes {
|
||||
return []string{content}
|
||||
if err := s.validateActiveTurn(); err != nil {
|
||||
return err
|
||||
}
|
||||
chunks := channels.SplitMessage(content, maxBytes)
|
||||
var result []string
|
||||
for _, chunk := range chunks {
|
||||
if len(chunk) <= maxBytes {
|
||||
result = append(result, chunk)
|
||||
continue
|
||||
}
|
||||
for len(chunk) > maxBytes {
|
||||
end := maxBytes
|
||||
for end > 0 && chunk[end]>>6 == 0b10 {
|
||||
end--
|
||||
if err := ctx.Err(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if !s.lastSentAt.IsZero() {
|
||||
wait := time.Until(s.lastSentAt.Add(wecomStreamMinInterval))
|
||||
if wait > 0 {
|
||||
timer := time.NewTimer(wait)
|
||||
defer timer.Stop()
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case <-timer.C:
|
||||
}
|
||||
if end == 0 {
|
||||
end = maxBytes
|
||||
}
|
||||
result = append(result, chunk[:end])
|
||||
chunk = strings.TrimLeft(chunk[end:], " \t\r\n")
|
||||
}
|
||||
if chunk != "" {
|
||||
result = append(result, chunk)
|
||||
}
|
||||
}
|
||||
return result
|
||||
|
||||
if err := s.channel.sendStreamChunk(s.turn, false, content); err != nil {
|
||||
return err
|
||||
}
|
||||
s.content = content
|
||||
s.lastSentAt = time.Now()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *wecomStreamer) Finalize(ctx context.Context, content string) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
if s.closed {
|
||||
return nil
|
||||
}
|
||||
if err := s.validateActiveTurn(); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := ctx.Err(); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := s.channel.sendStreamChunk(s.turn, true, content); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
s.content = content
|
||||
s.closed = true
|
||||
s.channel.consumeTurn(s.chatID, s.turn)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *wecomStreamer) Cancel(_ context.Context) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
if s.closed {
|
||||
return
|
||||
}
|
||||
if s.validateActiveTurn() == nil {
|
||||
_ = s.channel.sendStreamChunk(s.turn, true, s.content)
|
||||
s.channel.consumeTurn(s.chatID, s.turn)
|
||||
}
|
||||
s.closed = true
|
||||
}
|
||||
|
||||
func (s *wecomStreamer) validateActiveTurn() error {
|
||||
if time.Since(s.turn.CreatedAt) > wecomStreamMaxDuration {
|
||||
s.channel.consumeTurn(s.chatID, s.turn)
|
||||
return fmt.Errorf("wecom streaming unavailable: turn expired")
|
||||
}
|
||||
current, ok := s.channel.getTurn(s.chatID)
|
||||
if !ok || current.ReqID != s.turn.ReqID || current.StreamID != s.turn.StreamID {
|
||||
return fmt.Errorf("wecom streaming unavailable: turn no longer active")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"errors"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@@ -86,6 +87,77 @@ func TestDispatchIncoming_UsesActualChatIDAndStoresReqIDRoute(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewChannel_DoesNotRegisterMessageSplitLimit(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ch := newTestWeComChannel(t, bus.NewMessageBus())
|
||||
if got := ch.MaxMessageLength(); got != 0 {
|
||||
t.Fatalf("MaxMessageLength() = %d, want 0", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBeginStream_UpdateAndFinalize(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ch := newTestWeComChannel(t, bus.NewMessageBus())
|
||||
ch.SetRunning(true)
|
||||
ch.queueTurn("chat-1", wecomTurn{
|
||||
ReqID: "req-1",
|
||||
ChatID: "chat-1",
|
||||
ChatType: 1,
|
||||
StreamID: "stream-1",
|
||||
CreatedAt: time.Now(),
|
||||
})
|
||||
|
||||
var commands []wecomCommand
|
||||
ch.commandSend = func(cmd wecomCommand, _ time.Duration) (wecomEnvelope, error) {
|
||||
commands = append(commands, cmd)
|
||||
return wecomTestAck(nil), nil
|
||||
}
|
||||
|
||||
streamer, err := ch.BeginStream(context.Background(), "chat-1")
|
||||
if err != nil {
|
||||
t.Fatalf("BeginStream() error = %v", err)
|
||||
}
|
||||
if err := streamer.Update(context.Background(), "draft"); err != nil {
|
||||
t.Fatalf("Update() error = %v", err)
|
||||
}
|
||||
if err := streamer.Finalize(context.Background(), "final"); err != nil {
|
||||
t.Fatalf("Finalize() error = %v", err)
|
||||
}
|
||||
|
||||
if len(commands) != 2 {
|
||||
t.Fatalf("expected 2 commands, got %d", len(commands))
|
||||
}
|
||||
for i, wantFinish := range []bool{false, true} {
|
||||
if commands[i].Cmd != wecomCmdRespondMsg {
|
||||
t.Fatalf("command[%d].Cmd = %q, want %q", i, commands[i].Cmd, wecomCmdRespondMsg)
|
||||
}
|
||||
body, ok := commands[i].Body.(wecomRespondMsgBody)
|
||||
if !ok {
|
||||
t.Fatalf("command[%d] body type = %T", i, commands[i].Body)
|
||||
}
|
||||
if body.Stream == nil {
|
||||
t.Fatalf("command[%d] missing stream body", i)
|
||||
}
|
||||
if body.Stream.ID != "stream-1" {
|
||||
t.Fatalf("command[%d] stream id = %q, want stream-1", i, body.Stream.ID)
|
||||
}
|
||||
if body.Stream.Finish != wantFinish {
|
||||
t.Fatalf("command[%d] finish = %v, want %v", i, body.Stream.Finish, wantFinish)
|
||||
}
|
||||
}
|
||||
if body := commands[0].Body.(wecomRespondMsgBody); body.Stream.Content != "draft" {
|
||||
t.Fatalf("update content = %q, want draft", body.Stream.Content)
|
||||
}
|
||||
if body := commands[1].Body.(wecomRespondMsgBody); body.Stream.Content != "final" {
|
||||
t.Fatalf("final content = %q, want final", body.Stream.Content)
|
||||
}
|
||||
if _, ok := ch.getTurn("chat-1"); ok {
|
||||
t.Fatal("expected turn to be consumed after Finalize")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSend_StreamFailureFallsBackToActualChatID(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
@@ -155,6 +227,85 @@ func TestSend_StreamFailureFallsBackToActualChatID(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestSend_DoesNotSplitStreamReply(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ch := newTestWeComChannel(t, bus.NewMessageBus())
|
||||
ch.SetRunning(true)
|
||||
ch.queueTurn("chat-1", wecomTurn{
|
||||
ReqID: "req-1",
|
||||
ChatID: "chat-1",
|
||||
ChatType: 1,
|
||||
StreamID: "stream-1",
|
||||
CreatedAt: time.Now(),
|
||||
})
|
||||
|
||||
var commands []wecomCommand
|
||||
ch.commandSend = func(cmd wecomCommand, _ time.Duration) (wecomEnvelope, error) {
|
||||
commands = append(commands, cmd)
|
||||
return wecomTestAck(nil), nil
|
||||
}
|
||||
|
||||
content := strings.Repeat("\u4e2d", 30000)
|
||||
if err := ch.Send(context.Background(), bus.OutboundMessage{
|
||||
Channel: "wecom",
|
||||
ChatID: "chat-1",
|
||||
Content: content,
|
||||
}); err != nil {
|
||||
t.Fatalf("Send() error = %v", err)
|
||||
}
|
||||
|
||||
if len(commands) != 1 {
|
||||
t.Fatalf("expected 1 stream command, got %d", len(commands))
|
||||
}
|
||||
body, ok := commands[0].Body.(wecomRespondMsgBody)
|
||||
if !ok {
|
||||
t.Fatalf("unexpected body type %T", commands[0].Body)
|
||||
}
|
||||
if body.Stream == nil || !body.Stream.Finish {
|
||||
t.Fatalf("stream body = %+v", body.Stream)
|
||||
}
|
||||
if body.Stream.Content != content {
|
||||
t.Fatalf("stream content length = %d, want %d", len(body.Stream.Content), len(content))
|
||||
}
|
||||
}
|
||||
|
||||
func TestSend_DoesNotSplitActivePush(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ch := newTestWeComChannel(t, bus.NewMessageBus())
|
||||
ch.SetRunning(true)
|
||||
|
||||
var commands []wecomCommand
|
||||
ch.commandSend = func(cmd wecomCommand, _ time.Duration) (wecomEnvelope, error) {
|
||||
commands = append(commands, cmd)
|
||||
return wecomTestAck(nil), nil
|
||||
}
|
||||
|
||||
content := strings.Repeat("a", 30000)
|
||||
if err := ch.Send(context.Background(), bus.OutboundMessage{
|
||||
Channel: "wecom",
|
||||
ChatID: "chat-1",
|
||||
Content: content,
|
||||
}); err != nil {
|
||||
t.Fatalf("Send() error = %v", err)
|
||||
}
|
||||
|
||||
if len(commands) != 1 {
|
||||
t.Fatalf("expected 1 send command, got %d", len(commands))
|
||||
}
|
||||
if commands[0].Cmd != wecomCmdSendMsg {
|
||||
t.Fatalf("command = %q, want %q", commands[0].Cmd, wecomCmdSendMsg)
|
||||
}
|
||||
body, ok := commands[0].Body.(wecomSendMsgBody)
|
||||
if !ok {
|
||||
t.Fatalf("unexpected body type %T", commands[0].Body)
|
||||
}
|
||||
if body.Markdown == nil || body.Markdown.Content != content {
|
||||
t.Fatalf("markdown content length = %d, want %d", len(body.Markdown.Content), len(content))
|
||||
}
|
||||
}
|
||||
|
||||
func TestSendMedia_SendsActiveImage(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user