mirror of
https://github.com/sipeed/picoclaw.git
synced 2026-06-12 18:08:54 +00:00
983 lines
23 KiB
Go
983 lines
23 KiB
Go
package wecom
|
|
|
|
import (
|
|
"context"
|
|
"crypto/rand"
|
|
"encoding/json"
|
|
"fmt"
|
|
"math/big"
|
|
"net/http"
|
|
"strings"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/gorilla/websocket"
|
|
|
|
"github.com/sipeed/picoclaw/pkg/bus"
|
|
"github.com/sipeed/picoclaw/pkg/channels"
|
|
"github.com/sipeed/picoclaw/pkg/config"
|
|
"github.com/sipeed/picoclaw/pkg/identity"
|
|
"github.com/sipeed/picoclaw/pkg/logger"
|
|
)
|
|
|
|
const (
|
|
wecomConnectTimeout = 15 * time.Second
|
|
wecomCommandTimeout = 10 * time.Second
|
|
wecomUploadTimeout = 30 * time.Second
|
|
wecomHeartbeatInterval = 30 * time.Second
|
|
wecomStreamMaxDuration = 5*time.Minute + 30*time.Second
|
|
wecomStreamMinInterval = 500 * time.Millisecond
|
|
wecomRouteTTL = 30 * time.Minute
|
|
wecomMediaTimeout = 30 * time.Second
|
|
wecomRecentMessageMax = 1000
|
|
)
|
|
|
|
type WeComChannel struct {
|
|
*channels.BaseChannel
|
|
config *config.WeComSettings
|
|
|
|
ctx context.Context
|
|
cancel context.CancelFunc
|
|
|
|
conn *websocket.Conn
|
|
connMu sync.Mutex
|
|
|
|
pendingMu sync.Mutex
|
|
pending map[string]chan wecomEnvelope
|
|
|
|
turnsMu sync.Mutex
|
|
turns map[string][]wecomTurn
|
|
|
|
recent *recentMessageSet
|
|
routes *reqIDStore
|
|
mediaClient *http.Client
|
|
commandSend func(wecomCommand, time.Duration) (wecomEnvelope, error)
|
|
}
|
|
|
|
type wecomTurn struct {
|
|
ReqID string
|
|
ChatID string
|
|
ChatType uint32
|
|
StreamID string
|
|
CreatedAt time.Time
|
|
}
|
|
|
|
type wecomStreamer struct {
|
|
channel *WeComChannel
|
|
chatID string
|
|
turn wecomTurn
|
|
|
|
mu sync.Mutex
|
|
closed bool
|
|
lastSentAt time.Time
|
|
content string
|
|
}
|
|
|
|
type recentMessageSet struct {
|
|
mu sync.Mutex
|
|
seen map[string]struct{}
|
|
ring []string
|
|
idx int
|
|
}
|
|
|
|
func newRecentMessageSet(capacity int) *recentMessageSet {
|
|
if capacity <= 0 {
|
|
capacity = wecomRecentMessageMax
|
|
}
|
|
return &recentMessageSet{
|
|
seen: make(map[string]struct{}, capacity),
|
|
ring: make([]string, capacity),
|
|
}
|
|
}
|
|
|
|
func (s *recentMessageSet) Mark(id string) bool {
|
|
if id == "" {
|
|
return true
|
|
}
|
|
s.mu.Lock()
|
|
defer s.mu.Unlock()
|
|
if _, ok := s.seen[id]; ok {
|
|
return false
|
|
}
|
|
if old := s.ring[s.idx]; old != "" {
|
|
delete(s.seen, old)
|
|
}
|
|
s.ring[s.idx] = id
|
|
s.idx = (s.idx + 1) % len(s.ring)
|
|
s.seen[id] = struct{}{}
|
|
return true
|
|
}
|
|
|
|
func NewChannel(bc *config.Channel, cfg *config.WeComSettings, messageBus *bus.MessageBus) (*WeComChannel, error) {
|
|
if cfg.BotID == "" || cfg.Secret.String() == "" {
|
|
return nil, fmt.Errorf("wecom bot_id and secret are required")
|
|
}
|
|
if cfg.WebSocketURL == "" {
|
|
cfg.WebSocketURL = wecomDefaultWebSocketURL
|
|
}
|
|
|
|
base := channels.NewBaseChannel(
|
|
"wecom",
|
|
cfg,
|
|
messageBus,
|
|
bc.AllowFrom,
|
|
channels.WithReasoningChannelID(bc.ReasoningChannelID),
|
|
)
|
|
|
|
ch := &WeComChannel{
|
|
BaseChannel: base,
|
|
config: cfg,
|
|
pending: make(map[string]chan wecomEnvelope),
|
|
turns: make(map[string][]wecomTurn),
|
|
recent: newRecentMessageSet(wecomRecentMessageMax),
|
|
routes: newReqIDStore(""),
|
|
mediaClient: &http.Client{Timeout: wecomMediaTimeout},
|
|
}
|
|
ch.SetOwner(ch)
|
|
return ch, nil
|
|
}
|
|
|
|
func (c *WeComChannel) Name() string { return "wecom" }
|
|
|
|
func (c *WeComChannel) Start(ctx context.Context) error {
|
|
logger.InfoC("wecom", "Starting WeCom channel...")
|
|
c.ctx, c.cancel = context.WithCancel(ctx)
|
|
c.SetRunning(true)
|
|
go c.connectLoop()
|
|
return nil
|
|
}
|
|
|
|
func (c *WeComChannel) Stop(_ context.Context) error {
|
|
logger.InfoC("wecom", "Stopping WeCom channel...")
|
|
if c.cancel != nil {
|
|
c.cancel()
|
|
}
|
|
c.connMu.Lock()
|
|
if c.conn != nil {
|
|
_ = c.conn.Close()
|
|
c.conn = nil
|
|
}
|
|
c.connMu.Unlock()
|
|
c.clearTurns()
|
|
c.SetRunning(false)
|
|
return nil
|
|
}
|
|
|
|
func (c *WeComChannel) BeginStream(_ context.Context, chatID string) (channels.Streamer, error) {
|
|
if !c.IsRunning() {
|
|
return nil, channels.ErrNotRunning
|
|
}
|
|
|
|
turn, ok := c.getTurn(chatID)
|
|
if !ok {
|
|
return nil, fmt.Errorf("wecom streaming unavailable: no active turn")
|
|
}
|
|
if time.Since(turn.CreatedAt) > wecomStreamMaxDuration {
|
|
c.consumeTurn(chatID, turn)
|
|
return nil, fmt.Errorf("wecom streaming unavailable: turn expired")
|
|
}
|
|
|
|
return &wecomStreamer{
|
|
channel: c,
|
|
chatID: chatID,
|
|
turn: turn,
|
|
}, nil
|
|
}
|
|
|
|
func (c *WeComChannel) Send(ctx context.Context, msg bus.OutboundMessage) ([]string, error) {
|
|
if !c.IsRunning() {
|
|
return nil, channels.ErrNotRunning
|
|
}
|
|
content := strings.TrimSpace(msg.Content)
|
|
if content == "" {
|
|
return nil, nil
|
|
}
|
|
|
|
if turn, ok := c.getTurn(msg.ChatID); ok {
|
|
if time.Since(turn.CreatedAt) <= wecomStreamMaxDuration {
|
|
if err := c.sendStreamReply(turn, content); err == nil {
|
|
c.consumeTurn(msg.ChatID, turn)
|
|
return nil, nil
|
|
}
|
|
}
|
|
c.consumeTurn(msg.ChatID, turn)
|
|
}
|
|
|
|
if route, ok := c.routes.Get(msg.ChatID); ok {
|
|
if err := c.sendActivePush(route.ChatID, route.ChatType, content); err != nil {
|
|
return nil, err
|
|
}
|
|
return nil, nil
|
|
}
|
|
|
|
if err := c.sendActivePush(msg.ChatID, 0, content); err != nil {
|
|
return nil, err
|
|
}
|
|
return nil, nil
|
|
}
|
|
|
|
func (c *WeComChannel) SendMedia(ctx context.Context, msg bus.OutboundMediaMessage) ([]string, error) {
|
|
if !c.IsRunning() {
|
|
return nil, channels.ErrNotRunning
|
|
}
|
|
|
|
route, chatType, hasTurn := c.resolveMediaRoute(msg.ChatID)
|
|
chatID := route.ChatID
|
|
if chatID == "" {
|
|
chatID = msg.ChatID
|
|
}
|
|
|
|
for _, part := range msg.Parts {
|
|
if strings.TrimSpace(part.Ref) == "" {
|
|
if caption := strings.TrimSpace(part.Caption); caption != "" {
|
|
if err := c.sendActivePush(chatID, chatType, caption); err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
continue
|
|
}
|
|
|
|
localPath, filename, contentType, cleanup, err := c.resolveOutboundPart(ctx, part)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("wecom resolve media %q: %v: %w", part.Ref, err, channels.ErrSendFailed)
|
|
}
|
|
|
|
func() {
|
|
if cleanup != nil {
|
|
defer cleanup()
|
|
}
|
|
|
|
uploaded, uploadErr := c.uploadOutboundMedia(ctx, localPath, filename, contentType, part)
|
|
if uploadErr != nil {
|
|
logger.WarnCF("wecom", "Falling back to placeholder after media upload failure", map[string]any{
|
|
"chat_id": chatID,
|
|
"ref": part.Ref,
|
|
"filename": filename,
|
|
"content_type": contentType,
|
|
"error": uploadErr.Error(),
|
|
})
|
|
if hasTurn {
|
|
if finishErr := c.sendStreamChunk(route, true, ""); finishErr != nil {
|
|
err = finishErr
|
|
return
|
|
}
|
|
c.deleteTurn(msg.ChatID)
|
|
hasTurn = false
|
|
}
|
|
err = c.sendActivePush(chatID, chatType, fallbackWeComMediaText(part, "", filename))
|
|
return
|
|
}
|
|
|
|
if hasTurn {
|
|
err = c.sendTurnMedia(route, uploaded)
|
|
c.deleteTurn(msg.ChatID)
|
|
hasTurn = false
|
|
} else {
|
|
err = c.sendActiveMedia(chatID, chatType, uploaded)
|
|
}
|
|
if err != nil {
|
|
return
|
|
}
|
|
if caption := strings.TrimSpace(part.Caption); caption != "" {
|
|
err = c.sendActivePush(chatID, chatType, caption)
|
|
}
|
|
}()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
|
|
return nil, nil
|
|
}
|
|
|
|
func (c *WeComChannel) connectLoop() {
|
|
backoff := time.Second
|
|
for {
|
|
select {
|
|
case <-c.ctx.Done():
|
|
return
|
|
default:
|
|
}
|
|
|
|
if err := c.runConnection(); err != nil {
|
|
logger.WarnCF("wecom", "WeCom connection lost", map[string]any{
|
|
"error": err.Error(),
|
|
"backoff": backoff.String(),
|
|
})
|
|
select {
|
|
case <-time.After(backoff):
|
|
case <-c.ctx.Done():
|
|
return
|
|
}
|
|
if backoff < time.Minute {
|
|
backoff *= 2
|
|
if backoff > time.Minute {
|
|
backoff = time.Minute
|
|
}
|
|
}
|
|
continue
|
|
}
|
|
return
|
|
}
|
|
}
|
|
|
|
func (c *WeComChannel) runConnection() error {
|
|
dialCtx, cancel := context.WithTimeout(c.ctx, wecomConnectTimeout)
|
|
defer cancel()
|
|
|
|
conn, resp, err := websocket.DefaultDialer.DialContext(dialCtx, c.config.WebSocketURL, nil)
|
|
if resp != nil {
|
|
_ = resp.Body.Close()
|
|
}
|
|
if err != nil {
|
|
return fmt.Errorf("%w: %v", channels.ErrTemporary, err)
|
|
}
|
|
|
|
c.connMu.Lock()
|
|
c.conn = conn
|
|
c.connMu.Unlock()
|
|
defer func() {
|
|
c.connMu.Lock()
|
|
if c.conn == conn {
|
|
c.conn = nil
|
|
}
|
|
c.connMu.Unlock()
|
|
_ = conn.Close()
|
|
c.clearTurns()
|
|
}()
|
|
|
|
readErrCh := make(chan error, 1)
|
|
go func() {
|
|
readErrCh <- c.readLoop(conn)
|
|
}()
|
|
|
|
if writeErr := c.writeAndWait(conn, wecomCommand{
|
|
Cmd: wecomCmdSubscribe,
|
|
Headers: wecomHeaders{ReqID: randomID(10)},
|
|
Body: map[string]string{
|
|
"bot_id": c.config.BotID,
|
|
"secret": c.config.Secret.String(),
|
|
},
|
|
}, wecomCommandTimeout); writeErr != nil {
|
|
return writeErr
|
|
}
|
|
|
|
heartbeatDone := make(chan struct{})
|
|
go func() {
|
|
defer close(heartbeatDone)
|
|
c.heartbeatLoop(conn)
|
|
}()
|
|
|
|
err = <-readErrCh
|
|
_ = conn.Close()
|
|
<-heartbeatDone
|
|
return err
|
|
}
|
|
|
|
func (c *WeComChannel) heartbeatLoop(conn *websocket.Conn) {
|
|
ticker := time.NewTicker(wecomHeartbeatInterval)
|
|
defer ticker.Stop()
|
|
for {
|
|
select {
|
|
case <-ticker.C:
|
|
if err := c.writeAndWait(conn, wecomCommand{
|
|
Cmd: wecomCmdPing,
|
|
Headers: wecomHeaders{ReqID: randomID(10)},
|
|
}, wecomCommandTimeout); err != nil {
|
|
logger.WarnCF("wecom", "Heartbeat failed", map[string]any{"error": err.Error()})
|
|
_ = conn.Close()
|
|
return
|
|
}
|
|
case <-c.ctx.Done():
|
|
return
|
|
}
|
|
}
|
|
}
|
|
|
|
func (c *WeComChannel) readLoop(conn *websocket.Conn) error {
|
|
for {
|
|
_, raw, err := conn.ReadMessage()
|
|
if err != nil {
|
|
select {
|
|
case <-c.ctx.Done():
|
|
return nil
|
|
default:
|
|
return fmt.Errorf("%w: %v", channels.ErrTemporary, err)
|
|
}
|
|
}
|
|
|
|
var env wecomEnvelope
|
|
if err := json.Unmarshal(raw, &env); err != nil {
|
|
logger.WarnCF("wecom", "Failed to parse WebSocket message", map[string]any{"error": err.Error()})
|
|
continue
|
|
}
|
|
|
|
if env.Cmd == "" && env.Headers.ReqID != "" {
|
|
c.pendingMu.Lock()
|
|
ch, ok := c.pending[env.Headers.ReqID]
|
|
if ok {
|
|
delete(c.pending, env.Headers.ReqID)
|
|
}
|
|
c.pendingMu.Unlock()
|
|
if ok {
|
|
ch <- env
|
|
}
|
|
continue
|
|
}
|
|
|
|
go c.handleEnvelope(env)
|
|
}
|
|
}
|
|
|
|
func (c *WeComChannel) handleEnvelope(env wecomEnvelope) {
|
|
switch env.Cmd {
|
|
case wecomCmdMsgCallback:
|
|
c.handleMessageCallback(env)
|
|
case wecomCmdEventCallback:
|
|
c.handleEventCallback(env)
|
|
default:
|
|
logger.DebugCF("wecom", "Ignoring unsupported WeCom command", map[string]any{"cmd": env.Cmd})
|
|
}
|
|
}
|
|
|
|
func (c *WeComChannel) handleEventCallback(env wecomEnvelope) {
|
|
var msg wecomIncomingMessage
|
|
if err := json.Unmarshal(env.Body, &msg); err != nil {
|
|
logger.WarnCF("wecom", "Failed to parse WeCom event callback", map[string]any{"error": err.Error()})
|
|
}
|
|
}
|
|
|
|
func (c *WeComChannel) handleMessageCallback(env wecomEnvelope) {
|
|
var msg wecomIncomingMessage
|
|
if err := json.Unmarshal(env.Body, &msg); err != nil {
|
|
logger.WarnCF("wecom", "Failed to parse WeCom message callback", map[string]any{"error": err.Error()})
|
|
return
|
|
}
|
|
if !c.recent.Mark(msg.MsgID) {
|
|
return
|
|
}
|
|
|
|
reqID := env.Headers.ReqID
|
|
if reqID == "" {
|
|
logger.WarnC("wecom", "WeCom message callback missing req_id")
|
|
return
|
|
}
|
|
if msg.Event != nil && msg.Event.EventType != "" {
|
|
return
|
|
}
|
|
|
|
if err := c.dispatchIncoming(reqID, msg); err != nil {
|
|
logger.WarnCF("wecom", "Failed to dispatch WeCom message", map[string]any{
|
|
"req_id": reqID,
|
|
"error": err.Error(),
|
|
})
|
|
_ = c.respondImmediate(reqID, "The WeCom message could not be processed.")
|
|
}
|
|
}
|
|
|
|
func (c *WeComChannel) dispatchIncoming(reqID string, msg wecomIncomingMessage) error {
|
|
senderID := msg.From.UserID
|
|
if senderID == "" {
|
|
senderID = "unknown"
|
|
}
|
|
actualChatID := incomingChatID(msg)
|
|
chatType := incomingChatTypeCode(msg.ChatType)
|
|
peerKind := "direct"
|
|
if msg.ChatType == "group" {
|
|
peerKind = "group"
|
|
}
|
|
|
|
sender := bus.SenderInfo{
|
|
Platform: "wecom",
|
|
PlatformID: senderID,
|
|
CanonicalID: identity.BuildCanonicalID("wecom", senderID),
|
|
DisplayName: senderID,
|
|
}
|
|
|
|
var (
|
|
content string
|
|
quoteText string
|
|
mediaRefs []string
|
|
err error
|
|
)
|
|
scope := channels.BuildMediaScope("wecom", actualChatID, msg.MsgID)
|
|
switch msg.MsgType {
|
|
case "text":
|
|
if msg.Text != nil {
|
|
content = strings.TrimSpace(msg.Text.Content)
|
|
}
|
|
case "voice":
|
|
if msg.Voice != nil {
|
|
content = strings.TrimSpace(msg.Voice.Content)
|
|
}
|
|
case "image":
|
|
content = "[image]"
|
|
mediaRefs, err = c.collectSingleMedia(c.ctx, scope, msg.MsgID, &mediaPayload{
|
|
url: msg.Image.URL,
|
|
aesKey: msg.Image.AESKey,
|
|
}, "image", ".jpg")
|
|
case "file":
|
|
content = "[file]"
|
|
mediaRefs, err = c.collectSingleMedia(c.ctx, scope, msg.MsgID, &mediaPayload{
|
|
url: msg.File.URL,
|
|
aesKey: msg.File.AESKey,
|
|
}, "file", ".bin")
|
|
case "video":
|
|
content = "[video]"
|
|
mediaRefs, err = c.collectSingleMedia(c.ctx, scope, msg.MsgID, &mediaPayload{
|
|
url: msg.Video.URL,
|
|
aesKey: msg.Video.AESKey,
|
|
}, "video", ".mp4")
|
|
case "mixed":
|
|
content, mediaRefs, err = c.collectMixedMedia(c.ctx, scope, msg)
|
|
default:
|
|
return c.respondImmediate(reqID, "Unsupported WeCom message type: "+msg.MsgType)
|
|
}
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if msg.Quote != nil && msg.Quote.Text != nil {
|
|
quoteText = strings.TrimSpace(msg.Quote.Text.Content)
|
|
if content == "" {
|
|
content = quoteText
|
|
}
|
|
}
|
|
if content == "" && len(mediaRefs) == 0 {
|
|
return c.respondImmediate(reqID, "The WeCom message did not contain usable content.")
|
|
}
|
|
|
|
turn := wecomTurn{
|
|
ReqID: reqID,
|
|
ChatID: actualChatID,
|
|
ChatType: chatType,
|
|
StreamID: randomID(10),
|
|
CreatedAt: time.Now(),
|
|
}
|
|
c.queueTurn(actualChatID, turn)
|
|
if err := c.routes.Put(actualChatID, reqID, chatType, wecomRouteTTL); err != nil {
|
|
logger.WarnCF("wecom", "Failed to persist req_id route", map[string]any{
|
|
"chat_id": actualChatID,
|
|
"req_id": reqID,
|
|
"error": err.Error(),
|
|
})
|
|
}
|
|
|
|
opening := ""
|
|
if c.config.SendThinkingMessage {
|
|
opening = "Processing..."
|
|
}
|
|
if err := c.sendStreamChunk(turn, false, opening); err != nil {
|
|
return err
|
|
}
|
|
|
|
metadata := map[string]string{
|
|
"channel": "wecom",
|
|
"req_id": reqID,
|
|
"chat_id": actualChatID,
|
|
"chat_type": msg.ChatType,
|
|
"msg_id": msg.MsgID,
|
|
"msg_type": msg.MsgType,
|
|
}
|
|
if quoteText != "" {
|
|
metadata["quote_text"] = quoteText
|
|
}
|
|
|
|
inboundCtx := bus.InboundContext{
|
|
Channel: c.Name(),
|
|
Account: strings.TrimSpace(msg.AIBotID),
|
|
ChatID: actualChatID,
|
|
ChatType: peerKind,
|
|
SenderID: senderID,
|
|
MessageID: msg.MsgID,
|
|
ReplyHandles: map[string]string{
|
|
"req_id": reqID,
|
|
},
|
|
Raw: metadata,
|
|
}
|
|
|
|
c.HandleInboundContext(c.ctx, actualChatID, content, mediaRefs, inboundCtx, sender)
|
|
return nil
|
|
}
|
|
|
|
func (c *WeComChannel) collectSingleMedia(
|
|
ctx context.Context,
|
|
scope, msgID string,
|
|
payload interface {
|
|
GetURL() string
|
|
GetAESKey() string
|
|
},
|
|
label, fallbackExt string,
|
|
) ([]string, error) {
|
|
if payload == nil || payload.GetURL() == "" {
|
|
return nil, fmt.Errorf("%s payload is empty", label)
|
|
}
|
|
ref, err := c.storeRemoteMedia(ctx, scope, msgID, payload.GetURL(), payload.GetAESKey(), fallbackExt)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return []string{ref}, nil
|
|
}
|
|
|
|
type mediaPayload struct {
|
|
url string
|
|
aesKey string
|
|
}
|
|
|
|
func (p *mediaPayload) GetURL() string { return p.url }
|
|
func (p *mediaPayload) GetAESKey() string { return p.aesKey }
|
|
|
|
func (c *WeComChannel) collectMixedMedia(
|
|
ctx context.Context,
|
|
scope string,
|
|
msg wecomIncomingMessage,
|
|
) (string, []string, error) {
|
|
if msg.Mixed == nil {
|
|
return "", nil, fmt.Errorf("mixed message is empty")
|
|
}
|
|
|
|
var textParts []string
|
|
var refs []string
|
|
for idx, item := range msg.Mixed.MsgItem {
|
|
switch item.MsgType {
|
|
case "text":
|
|
if item.Text != nil && strings.TrimSpace(item.Text.Content) != "" {
|
|
textParts = append(textParts, strings.TrimSpace(item.Text.Content))
|
|
}
|
|
case "image":
|
|
if item.Image != nil && item.Image.URL != "" {
|
|
ref, err := c.storeRemoteMedia(
|
|
ctx,
|
|
scope,
|
|
fmt.Sprintf("%s-%d", msg.MsgID, idx),
|
|
item.Image.URL,
|
|
item.Image.AESKey,
|
|
".jpg",
|
|
)
|
|
if err != nil {
|
|
return "", nil, err
|
|
}
|
|
refs = append(refs, ref)
|
|
}
|
|
case "file":
|
|
if item.File != nil && item.File.URL != "" {
|
|
ref, err := c.storeRemoteMedia(
|
|
ctx,
|
|
scope,
|
|
fmt.Sprintf("%s-%d", msg.MsgID, idx),
|
|
item.File.URL,
|
|
item.File.AESKey,
|
|
".bin",
|
|
)
|
|
if err != nil {
|
|
return "", nil, err
|
|
}
|
|
refs = append(refs, ref)
|
|
}
|
|
}
|
|
}
|
|
|
|
content := strings.Join(textParts, "\n")
|
|
if content == "" && len(refs) > 0 {
|
|
content = "[media]"
|
|
}
|
|
return content, refs, nil
|
|
}
|
|
|
|
func (c *WeComChannel) respondImmediate(reqID, content string) error {
|
|
turn := wecomTurn{
|
|
ReqID: reqID,
|
|
StreamID: randomID(10),
|
|
CreatedAt: time.Now(),
|
|
}
|
|
return c.sendStreamChunk(turn, true, content)
|
|
}
|
|
|
|
func (c *WeComChannel) sendStreamReply(turn wecomTurn, content string) error {
|
|
return c.sendStreamChunk(turn, true, content)
|
|
}
|
|
|
|
func (c *WeComChannel) sendStreamChunk(turn wecomTurn, finish bool, content string) error {
|
|
return c.sendCommand(wecomCommand{
|
|
Cmd: wecomCmdRespondMsg,
|
|
Headers: wecomHeaders{ReqID: turn.ReqID},
|
|
Body: wecomRespondMsgBody{
|
|
MsgType: "stream",
|
|
Stream: &wecomStreamContent{
|
|
ID: turn.StreamID,
|
|
Finish: finish,
|
|
Content: content,
|
|
},
|
|
},
|
|
}, wecomCommandTimeout)
|
|
}
|
|
|
|
func (c *WeComChannel) sendTurnMedia(turn wecomTurn, uploaded *wecomOutboundMedia) error {
|
|
if uploaded == nil {
|
|
return fmt.Errorf("wecom outbound media is nil: %w", channels.ErrSendFailed)
|
|
}
|
|
if err := c.sendCommand(wecomCommand{
|
|
Cmd: wecomCmdRespondMsg,
|
|
Headers: wecomHeaders{ReqID: turn.ReqID},
|
|
Body: uploaded.respondBody(),
|
|
}, wecomCommandTimeout); err != nil {
|
|
return err
|
|
}
|
|
return c.sendStreamChunk(turn, true, "")
|
|
}
|
|
|
|
func (c *WeComChannel) sendActivePush(chatID string, chatType uint32, content string) error {
|
|
if strings.TrimSpace(chatID) == "" {
|
|
return fmt.Errorf("empty chat ID: %w", channels.ErrSendFailed)
|
|
}
|
|
return c.sendCommand(wecomCommand{
|
|
Cmd: wecomCmdSendMsg,
|
|
Headers: wecomHeaders{ReqID: randomID(10)},
|
|
Body: wecomSendMsgBody{
|
|
ChatID: chatID,
|
|
ChatType: chatType,
|
|
MsgType: "markdown",
|
|
Markdown: &wecomMarkdownContent{Content: content},
|
|
},
|
|
}, wecomCommandTimeout)
|
|
}
|
|
|
|
func (c *WeComChannel) sendActiveMedia(chatID string, chatType uint32, uploaded *wecomOutboundMedia) error {
|
|
if strings.TrimSpace(chatID) == "" {
|
|
return fmt.Errorf("empty chat ID: %w", channels.ErrSendFailed)
|
|
}
|
|
if uploaded == nil {
|
|
return fmt.Errorf("wecom outbound media is nil: %w", channels.ErrSendFailed)
|
|
}
|
|
return c.sendCommand(wecomCommand{
|
|
Cmd: wecomCmdSendMsg,
|
|
Headers: wecomHeaders{ReqID: randomID(10)},
|
|
Body: uploaded.sendBody(chatID, chatType),
|
|
}, wecomCommandTimeout)
|
|
}
|
|
|
|
func (c *WeComChannel) sendCommand(cmd wecomCommand, timeout time.Duration) error {
|
|
_, err := c.sendCommandAck(cmd, timeout)
|
|
return err
|
|
}
|
|
|
|
func (c *WeComChannel) sendCommandAck(cmd wecomCommand, timeout time.Duration) (wecomEnvelope, error) {
|
|
if c.commandSend != nil {
|
|
return c.commandSend(cmd, timeout)
|
|
}
|
|
return c.writeCurrentAck(cmd, timeout)
|
|
}
|
|
|
|
func (c *WeComChannel) writeCurrentAck(cmd wecomCommand, timeout time.Duration) (wecomEnvelope, error) {
|
|
c.connMu.Lock()
|
|
conn := c.conn
|
|
c.connMu.Unlock()
|
|
if conn == nil {
|
|
return wecomEnvelope{}, fmt.Errorf("wecom websocket not connected: %w", channels.ErrTemporary)
|
|
}
|
|
return c.writeAndWaitAck(conn, cmd, timeout)
|
|
}
|
|
|
|
func (c *WeComChannel) writeAndWait(conn *websocket.Conn, cmd wecomCommand, timeout time.Duration) error {
|
|
_, err := c.writeAndWaitAck(conn, cmd, timeout)
|
|
return err
|
|
}
|
|
|
|
func (c *WeComChannel) writeAndWaitAck(
|
|
conn *websocket.Conn,
|
|
cmd wecomCommand,
|
|
timeout time.Duration,
|
|
) (wecomEnvelope, error) {
|
|
if cmd.Headers.ReqID == "" {
|
|
cmd.Headers.ReqID = randomID(10)
|
|
}
|
|
waitCh := make(chan wecomEnvelope, 1)
|
|
c.pendingMu.Lock()
|
|
c.pending[cmd.Headers.ReqID] = waitCh
|
|
c.pendingMu.Unlock()
|
|
defer func() {
|
|
c.pendingMu.Lock()
|
|
delete(c.pending, cmd.Headers.ReqID)
|
|
c.pendingMu.Unlock()
|
|
}()
|
|
|
|
data, err := json.Marshal(cmd)
|
|
if err != nil {
|
|
return wecomEnvelope{}, fmt.Errorf("%w: %v", channels.ErrSendFailed, err)
|
|
}
|
|
c.connMu.Lock()
|
|
err = conn.WriteMessage(websocket.TextMessage, data)
|
|
c.connMu.Unlock()
|
|
if err != nil {
|
|
return wecomEnvelope{}, fmt.Errorf("%w: %v", channels.ErrTemporary, err)
|
|
}
|
|
|
|
timer := time.NewTimer(timeout)
|
|
defer timer.Stop()
|
|
select {
|
|
case env := <-waitCh:
|
|
if env.ErrCode != 0 {
|
|
return wecomEnvelope{}, fmt.Errorf(
|
|
"%w: wecom errcode=%d errmsg=%s",
|
|
channels.ErrTemporary,
|
|
env.ErrCode,
|
|
env.ErrMsg,
|
|
)
|
|
}
|
|
return env, nil
|
|
case <-timer.C:
|
|
return wecomEnvelope{}, fmt.Errorf("%w: timeout waiting for WeCom ack", channels.ErrTemporary)
|
|
case <-c.ctx.Done():
|
|
return wecomEnvelope{}, c.ctx.Err()
|
|
}
|
|
}
|
|
|
|
func (c *WeComChannel) getTurn(chatID string) (wecomTurn, bool) {
|
|
c.turnsMu.Lock()
|
|
defer c.turnsMu.Unlock()
|
|
queue := c.turns[chatID]
|
|
if len(queue) == 0 {
|
|
return wecomTurn{}, false
|
|
}
|
|
return queue[0], true
|
|
}
|
|
|
|
func (c *WeComChannel) deleteTurn(chatID string) {
|
|
c.turnsMu.Lock()
|
|
defer c.turnsMu.Unlock()
|
|
queue := c.turns[chatID]
|
|
if len(queue) <= 1 {
|
|
delete(c.turns, chatID)
|
|
return
|
|
}
|
|
c.turns[chatID] = queue[1:]
|
|
}
|
|
|
|
func (c *WeComChannel) queueTurn(chatID string, turn wecomTurn) {
|
|
c.turnsMu.Lock()
|
|
defer c.turnsMu.Unlock()
|
|
c.turns[chatID] = append(c.turns[chatID], turn)
|
|
}
|
|
|
|
func (c *WeComChannel) consumeTurn(chatID string, turn wecomTurn) bool {
|
|
c.turnsMu.Lock()
|
|
defer c.turnsMu.Unlock()
|
|
|
|
queue := c.turns[chatID]
|
|
if len(queue) == 0 {
|
|
return false
|
|
}
|
|
current := queue[0]
|
|
if current.ReqID != turn.ReqID || current.StreamID != turn.StreamID {
|
|
return false
|
|
}
|
|
if len(queue) == 1 {
|
|
delete(c.turns, chatID)
|
|
return true
|
|
}
|
|
c.turns[chatID] = queue[1:]
|
|
return true
|
|
}
|
|
|
|
func (c *WeComChannel) clearTurns() {
|
|
c.turnsMu.Lock()
|
|
c.turns = make(map[string][]wecomTurn)
|
|
c.turnsMu.Unlock()
|
|
}
|
|
|
|
func randomID(n int) string {
|
|
const alphabet = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
|
|
if n <= 0 {
|
|
n = 10
|
|
}
|
|
buf := make([]byte, n)
|
|
for i := range buf {
|
|
v, _ := rand.Int(rand.Reader, big.NewInt(int64(len(alphabet))))
|
|
buf[i] = alphabet[v.Int64()]
|
|
}
|
|
return string(buf)
|
|
}
|
|
|
|
func (s *wecomStreamer) Update(ctx context.Context, content string) error {
|
|
s.mu.Lock()
|
|
defer s.mu.Unlock()
|
|
|
|
if s.closed {
|
|
return nil
|
|
}
|
|
if err := s.validateActiveTurn(); err != nil {
|
|
return err
|
|
}
|
|
if err := ctx.Err(); err != nil {
|
|
return err
|
|
}
|
|
|
|
if !s.lastSentAt.IsZero() {
|
|
wait := time.Until(s.lastSentAt.Add(wecomStreamMinInterval))
|
|
if wait > 0 {
|
|
timer := time.NewTimer(wait)
|
|
defer timer.Stop()
|
|
select {
|
|
case <-ctx.Done():
|
|
return ctx.Err()
|
|
case <-timer.C:
|
|
}
|
|
}
|
|
}
|
|
|
|
if err := s.channel.sendStreamChunk(s.turn, false, content); err != nil {
|
|
return err
|
|
}
|
|
s.content = content
|
|
s.lastSentAt = time.Now()
|
|
return nil
|
|
}
|
|
|
|
func (s *wecomStreamer) Finalize(ctx context.Context, content string) error {
|
|
s.mu.Lock()
|
|
defer s.mu.Unlock()
|
|
|
|
if s.closed {
|
|
return nil
|
|
}
|
|
if err := s.validateActiveTurn(); err != nil {
|
|
return err
|
|
}
|
|
if err := ctx.Err(); err != nil {
|
|
return err
|
|
}
|
|
if err := s.channel.sendStreamChunk(s.turn, true, content); err != nil {
|
|
return err
|
|
}
|
|
|
|
s.content = content
|
|
s.closed = true
|
|
s.channel.consumeTurn(s.chatID, s.turn)
|
|
return nil
|
|
}
|
|
|
|
func (s *wecomStreamer) Cancel(_ context.Context) {
|
|
s.mu.Lock()
|
|
defer s.mu.Unlock()
|
|
|
|
if s.closed {
|
|
return
|
|
}
|
|
if s.validateActiveTurn() == nil {
|
|
_ = s.channel.sendStreamChunk(s.turn, true, s.content)
|
|
s.channel.consumeTurn(s.chatID, s.turn)
|
|
}
|
|
s.closed = true
|
|
}
|
|
|
|
func (s *wecomStreamer) validateActiveTurn() error {
|
|
if time.Since(s.turn.CreatedAt) > wecomStreamMaxDuration {
|
|
s.channel.consumeTurn(s.chatID, s.turn)
|
|
return fmt.Errorf("wecom streaming unavailable: turn expired")
|
|
}
|
|
current, ok := s.channel.getTurn(s.chatID)
|
|
if !ok || current.ReqID != s.turn.ReqID || current.StreamID != s.turn.StreamID {
|
|
return fmt.Errorf("wecom streaming unavailable: turn no longer active")
|
|
}
|
|
return nil
|
|
}
|