From 60b68b305a1eb61057945ed6efe2b2ef2b3264ee Mon Sep 17 00:00:00 2001 From: Hoshina Date: Mon, 23 Feb 2026 04:55:15 +0800 Subject: [PATCH] feat(channels): add typing/placeholder automation and Pico Protocol channel (Phase 10 + 7) Phase 10: Define TypingCapable, MessageEditor, PlaceholderRecorder interfaces. Manager orchestrates outbound typing stop and placeholder editing via preSend. Migrate Telegram, Discord, Slack, OneBot to register state with Manager instead of handling locally in Send. Phase 7: Add native WebSocket Pico Protocol channel as reference implementation of all optional capability interfaces. --- cmd/picoclaw/internal/gateway/helpers.go | 1 + pkg/channels/base.go | 27 +- pkg/channels/discord/discord.go | 14 +- pkg/channels/interfaces.go | 24 ++ pkg/channels/manager.go | 56 +++ pkg/channels/manager_test.go | 216 ++++++++++++ pkg/channels/onebot/onebot.go | 13 +- pkg/channels/pico/init.go | 13 + pkg/channels/pico/pico.go | 430 +++++++++++++++++++++++ pkg/channels/pico/protocol.go | 46 +++ pkg/channels/slack/slack.go | 24 ++ pkg/channels/telegram/telegram.go | 113 +++--- pkg/config/config.go | 12 + pkg/config/defaults.go | 14 + 14 files changed, 913 insertions(+), 90 deletions(-) create mode 100644 pkg/channels/interfaces.go create mode 100644 pkg/channels/pico/init.go create mode 100644 pkg/channels/pico/pico.go create mode 100644 pkg/channels/pico/protocol.go diff --git a/cmd/picoclaw/internal/gateway/helpers.go b/cmd/picoclaw/internal/gateway/helpers.go index 758157f53..6ac41fab1 100644 --- a/cmd/picoclaw/internal/gateway/helpers.go +++ b/cmd/picoclaw/internal/gateway/helpers.go @@ -18,6 +18,7 @@ import ( _ "github.com/sipeed/picoclaw/pkg/channels/line" _ "github.com/sipeed/picoclaw/pkg/channels/maixcam" _ "github.com/sipeed/picoclaw/pkg/channels/onebot" + _ "github.com/sipeed/picoclaw/pkg/channels/pico" _ "github.com/sipeed/picoclaw/pkg/channels/qq" _ "github.com/sipeed/picoclaw/pkg/channels/slack" _ "github.com/sipeed/picoclaw/pkg/channels/telegram" diff --git a/pkg/channels/base.go b/pkg/channels/base.go index e345aedf0..c22a27eb9 100644 --- a/pkg/channels/base.go +++ b/pkg/channels/base.go @@ -44,14 +44,15 @@ type MessageLengthProvider interface { } type BaseChannel struct { - config any - bus *bus.MessageBus - running atomic.Bool - name string - allowList []string - maxMessageLength int - groupTrigger config.GroupTriggerConfig - mediaStore media.MediaStore + config any + bus *bus.MessageBus + running atomic.Bool + name string + allowList []string + maxMessageLength int + groupTrigger config.GroupTriggerConfig + mediaStore media.MediaStore + placeholderRecorder PlaceholderRecorder } func NewBaseChannel( @@ -203,6 +204,16 @@ func (c *BaseChannel) SetMediaStore(s media.MediaStore) { c.mediaStore = s } // GetMediaStore returns the injected MediaStore (may be nil). func (c *BaseChannel) GetMediaStore() media.MediaStore { return c.mediaStore } +// SetPlaceholderRecorder injects a PlaceholderRecorder into the channel. +func (c *BaseChannel) SetPlaceholderRecorder(r PlaceholderRecorder) { + c.placeholderRecorder = r +} + +// GetPlaceholderRecorder returns the injected PlaceholderRecorder (may be nil). +func (c *BaseChannel) GetPlaceholderRecorder() PlaceholderRecorder { + return c.placeholderRecorder +} + // BuildMediaScope constructs a scope key for media lifecycle tracking. func BuildMediaScope(channel, chatID, messageID string) string { id := messageID diff --git a/pkg/channels/discord/discord.go b/pkg/channels/discord/discord.go index 4ef4906c1..ee698da61 100644 --- a/pkg/channels/discord/discord.go +++ b/pkg/channels/discord/discord.go @@ -106,8 +106,6 @@ func (c *DiscordChannel) Stop(ctx context.Context) error { } func (c *DiscordChannel) Send(ctx context.Context, msg bus.OutboundMessage) error { - c.stopTyping(msg.ChatID) - if !c.IsRunning() { return channels.ErrNotRunning } @@ -126,8 +124,6 @@ func (c *DiscordChannel) Send(ctx context.Context, msg bus.OutboundMessage) erro // SendMedia implements the channels.MediaSender interface. func (c *DiscordChannel) SendMedia(ctx context.Context, msg bus.OutboundMediaMessage) error { - c.stopTyping(msg.ChatID) - if !c.IsRunning() { return channels.ErrNotRunning } @@ -221,6 +217,12 @@ func (c *DiscordChannel) SendMedia(ctx context.Context, msg bus.OutboundMediaMes } } +// EditMessage implements channels.MessageEditor. +func (c *DiscordChannel) EditMessage(ctx context.Context, chatID string, messageID string, content string) error { + _, err := c.session.ChannelMessageEdit(chatID, messageID, content) + return err +} + func (c *DiscordChannel) sendChunk(ctx context.Context, channelID, content string) error { // Use the passed ctx for timeout control sendCtx, cancel := context.WithTimeout(ctx, sendTimeout) @@ -350,6 +352,10 @@ func (c *DiscordChannel) handleMessage(s *discordgo.Session, m *discordgo.Messag // Start typing after all early returns — guaranteed to have a matching Send() c.startTyping(m.ChannelID) + // Register typing stop with Manager for outbound orchestration + if rec := c.GetPlaceholderRecorder(); rec != nil { + rec.RecordTypingStop("discord", m.ChannelID, func() { c.stopTyping(m.ChannelID) }) + } logger.DebugCF("discord", "Received message", map[string]any{ "sender_name": senderName, diff --git a/pkg/channels/interfaces.go b/pkg/channels/interfaces.go new file mode 100644 index 000000000..32bfe95f8 --- /dev/null +++ b/pkg/channels/interfaces.go @@ -0,0 +1,24 @@ +package channels + +import "context" + +// TypingCapable — channels that can show a typing/thinking indicator. +// StartTyping begins the indicator and returns a stop function. +// The stop function MUST be idempotent and safe to call multiple times. +type TypingCapable interface { + StartTyping(ctx context.Context, chatID string) (stop func(), err error) +} + +// MessageEditor — channels that can edit an existing message. +// messageID is always string; channels convert platform-specific types internally. +type MessageEditor interface { + EditMessage(ctx context.Context, chatID string, messageID string, content string) error +} + +// PlaceholderRecorder is injected into channels by Manager. +// Channels call these methods on inbound to register typing/placeholder state. +// Manager uses the registered state on outbound to stop typing and edit placeholders. +type PlaceholderRecorder interface { + RecordPlaceholder(channel, chatID, placeholderID string) + RecordTypingStop(channel, chatID string, stop func()) +} diff --git a/pkg/channels/manager.go b/pkg/channels/manager.go index 92412edeb..4b1a43b7b 100644 --- a/pkg/channels/manager.go +++ b/pkg/channels/manager.go @@ -62,12 +62,55 @@ type Manager struct { mux *http.ServeMux httpServer *http.Server mu sync.RWMutex + placeholders sync.Map // "channel:chatID" → placeholderID (string) + typingStops sync.Map // "channel:chatID" → func() } type asyncTask struct { cancel context.CancelFunc } +// RecordPlaceholder registers a placeholder message for later editing. +// Implements PlaceholderRecorder. +func (m *Manager) RecordPlaceholder(channel, chatID, placeholderID string) { + key := channel + ":" + chatID + m.placeholders.Store(key, placeholderID) +} + +// RecordTypingStop registers a typing stop function for later invocation. +// Implements PlaceholderRecorder. +func (m *Manager) RecordTypingStop(channel, chatID string, stop func()) { + key := channel + ":" + chatID + m.typingStops.Store(key, stop) +} + +// preSend handles typing stop and placeholder editing before sending a message. +// Returns true if the message was edited into a placeholder (skip Send). +func (m *Manager) preSend(ctx context.Context, name string, msg bus.OutboundMessage, ch Channel) bool { + key := name + ":" + msg.ChatID + + // 1. Stop typing + if v, loaded := m.typingStops.LoadAndDelete(key); loaded { + if stop, ok := v.(func()); ok { + stop() // idempotent, safe + } + } + + // 2. Try editing placeholder + if v, loaded := m.placeholders.LoadAndDelete(key); loaded { + if placeholderID, ok := v.(string); ok && placeholderID != "" { + if editor, ok := ch.(MessageEditor); ok { + if err := editor.EditMessage(ctx, msg.ChatID, placeholderID, msg.Content); err == nil { + return true // edited successfully, skip Send + } + // edit failed → fall through to normal Send + } + } + } + + return false +} + func NewManager(cfg *config.Config, messageBus *bus.MessageBus, store media.MediaStore) (*Manager, error) { m := &Manager{ channels: make(map[string]Channel), @@ -109,6 +152,10 @@ func (m *Manager) initChannel(name, displayName string) { setter.SetMediaStore(m.mediaStore) } } + // Inject PlaceholderRecorder if channel supports it + if setter, ok := ch.(interface{ SetPlaceholderRecorder(PlaceholderRecorder) }); ok { + setter.SetPlaceholderRecorder(m) + } m.channels[name] = ch m.workers[name] = newChannelWorker(name, ch) logger.InfoCF("channels", "Channel enabled successfully", map[string]any{ @@ -168,6 +215,10 @@ func (m *Manager) initChannels() error { m.initChannel("wecom_app", "WeCom App") } + if m.config.Channels.Pico.Enabled && m.config.Channels.Pico.Token != "" { + m.initChannel("pico", "Pico") + } + logger.InfoCF("channels", "Channel initialization completed", map[string]any{ "enabled_channels": len(m.channels), }) @@ -383,6 +434,11 @@ func (m *Manager) sendWithRetry(ctx context.Context, name string, w *channelWork return } + // Pre-send: stop typing and try to edit placeholder + if m.preSend(ctx, name, msg, w.ch) { + return // placeholder was edited successfully, skip Send + } + var lastErr error for attempt := 0; attempt <= maxRetries; attempt++ { lastErr = w.ch.Send(ctx, msg) diff --git a/pkg/channels/manager_test.go b/pkg/channels/manager_test.go index 162c9f8c9..0573c0a8e 100644 --- a/pkg/channels/manager_test.go +++ b/pkg/channels/manager_test.go @@ -416,3 +416,219 @@ func TestSendWithRetry_ExponentialBackoff(t *testing.T) { t.Fatalf("expected %d calls, got %d", maxRetries+1, callCount.Load()) } } + +// --- Phase 10: preSend orchestration tests --- + +// mockMessageEditor is a channel that supports MessageEditor. +type mockMessageEditor struct { + mockChannel + editFn func(ctx context.Context, chatID, messageID, content string) error +} + +func (m *mockMessageEditor) EditMessage(ctx context.Context, chatID, messageID, content string) error { + return m.editFn(ctx, chatID, messageID, content) +} + +func TestPreSend_PlaceholderEditSuccess(t *testing.T) { + m := newTestManager() + var sendCalled bool + var editCalled bool + + ch := &mockMessageEditor{ + mockChannel: mockChannel{ + sendFn: func(_ context.Context, _ bus.OutboundMessage) error { + sendCalled = true + return nil + }, + }, + editFn: func(_ context.Context, chatID, messageID, content string) error { + editCalled = true + if chatID != "123" { + t.Fatalf("expected chatID 123, got %s", chatID) + } + if messageID != "456" { + t.Fatalf("expected messageID 456, got %s", messageID) + } + if content != "hello" { + t.Fatalf("expected content 'hello', got %s", content) + } + return nil + }, + } + + // Register placeholder + m.RecordPlaceholder("test", "123", "456") + + msg := bus.OutboundMessage{Channel: "test", ChatID: "123", Content: "hello"} + edited := m.preSend(context.Background(), "test", msg, ch) + + if !edited { + t.Fatal("expected preSend to return true (placeholder edited)") + } + if !editCalled { + t.Fatal("expected EditMessage to be called") + } + if sendCalled { + t.Fatal("expected Send to NOT be called when placeholder edited") + } +} + +func TestPreSend_PlaceholderEditFails_FallsThrough(t *testing.T) { + m := newTestManager() + + ch := &mockMessageEditor{ + mockChannel: mockChannel{ + sendFn: func(_ context.Context, _ bus.OutboundMessage) error { + return nil + }, + }, + editFn: func(_ context.Context, _, _, _ string) error { + return fmt.Errorf("edit failed") + }, + } + + m.RecordPlaceholder("test", "123", "456") + + msg := bus.OutboundMessage{Channel: "test", ChatID: "123", Content: "hello"} + edited := m.preSend(context.Background(), "test", msg, ch) + + if edited { + t.Fatal("expected preSend to return false when edit fails") + } +} + +func TestPreSend_TypingStopCalled(t *testing.T) { + m := newTestManager() + var stopCalled bool + + ch := &mockChannel{ + sendFn: func(_ context.Context, _ bus.OutboundMessage) error { + return nil + }, + } + + m.RecordTypingStop("test", "123", func() { + stopCalled = true + }) + + msg := bus.OutboundMessage{Channel: "test", ChatID: "123", Content: "hello"} + m.preSend(context.Background(), "test", msg, ch) + + if !stopCalled { + t.Fatal("expected typing stop func to be called") + } +} + +func TestPreSend_NoRegisteredState(t *testing.T) { + m := newTestManager() + + ch := &mockChannel{ + sendFn: func(_ context.Context, _ bus.OutboundMessage) error { + return nil + }, + } + + msg := bus.OutboundMessage{Channel: "test", ChatID: "123", Content: "hello"} + edited := m.preSend(context.Background(), "test", msg, ch) + + if edited { + t.Fatal("expected preSend to return false with no registered state") + } +} + +func TestPreSend_TypingAndPlaceholder(t *testing.T) { + m := newTestManager() + var stopCalled bool + var editCalled bool + + ch := &mockMessageEditor{ + mockChannel: mockChannel{ + sendFn: func(_ context.Context, _ bus.OutboundMessage) error { + return nil + }, + }, + editFn: func(_ context.Context, _, _, _ string) error { + editCalled = true + return nil + }, + } + + m.RecordTypingStop("test", "123", func() { + stopCalled = true + }) + m.RecordPlaceholder("test", "123", "456") + + msg := bus.OutboundMessage{Channel: "test", ChatID: "123", Content: "hello"} + edited := m.preSend(context.Background(), "test", msg, ch) + + if !stopCalled { + t.Fatal("expected typing stop to be called") + } + if !editCalled { + t.Fatal("expected EditMessage to be called") + } + if !edited { + t.Fatal("expected preSend to return true") + } +} + +func TestRecordPlaceholder_ConcurrentSafe(t *testing.T) { + m := newTestManager() + + var wg sync.WaitGroup + for i := 0; i < 100; i++ { + wg.Add(1) + go func(i int) { + defer wg.Done() + chatID := fmt.Sprintf("chat_%d", i%10) + m.RecordPlaceholder("test", chatID, fmt.Sprintf("msg_%d", i)) + }(i) + } + wg.Wait() +} + +func TestRecordTypingStop_ConcurrentSafe(t *testing.T) { + m := newTestManager() + + var wg sync.WaitGroup + for i := 0; i < 100; i++ { + wg.Add(1) + go func(i int) { + defer wg.Done() + chatID := fmt.Sprintf("chat_%d", i%10) + m.RecordTypingStop("test", chatID, func() {}) + }(i) + } + wg.Wait() +} + +func TestSendWithRetry_PreSendEditsPlaceholder(t *testing.T) { + m := newTestManager() + var sendCalled bool + + ch := &mockMessageEditor{ + mockChannel: mockChannel{ + sendFn: func(_ context.Context, _ bus.OutboundMessage) error { + sendCalled = true + return nil + }, + }, + editFn: func(_ context.Context, _, _, _ string) error { + return nil // edit succeeds + }, + } + + m.RecordPlaceholder("test", "123", "456") + + w := &channelWorker{ + ch: ch, + limiter: rate.NewLimiter(rate.Inf, 1), + } + + msg := bus.OutboundMessage{Channel: "test", ChatID: "123", Content: "hello"} + m.sendWithRetry(context.Background(), "test", w, msg) + + if sendCalled { + t.Fatal("expected Send to NOT be called when placeholder was edited") + } +} diff --git a/pkg/channels/onebot/onebot.go b/pkg/channels/onebot/onebot.go index f32cb4948..682025b67 100644 --- a/pkg/channels/onebot/onebot.go +++ b/pkg/channels/onebot/onebot.go @@ -418,12 +418,6 @@ func (c *OneBotChannel) Send(ctx context.Context, msg bus.OutboundMessage) error return fmt.Errorf("onebot send: %w", channels.ErrTemporary) } - if msgID, ok := c.pendingEmojiMsg.LoadAndDelete(msg.ChatID); ok { - if mid, ok := msgID.(string); ok && mid != "" { - c.setMsgEmojiLike(mid, 289, false) - } - } - return nil } @@ -1037,6 +1031,13 @@ func (c *OneBotChannel) handleMessage(raw *oneBotRawEvent) { if raw.MessageType == "group" && messageID != "" && messageID != "0" { c.setMsgEmojiLike(messageID, 289, true) c.pendingEmojiMsg.Store(chatID, messageID) + // Register emoji stop with Manager for outbound orchestration + if rec := c.GetPlaceholderRecorder(); rec != nil { + capturedMsgID := messageID + rec.RecordTypingStop("onebot", chatID, func() { + c.setMsgEmojiLike(capturedMsgID, 289, false) + }) + } } c.HandleMessage(peer, messageID, senderID, chatID, content, parsed.Media, metadata) diff --git a/pkg/channels/pico/init.go b/pkg/channels/pico/init.go new file mode 100644 index 000000000..96d764418 --- /dev/null +++ b/pkg/channels/pico/init.go @@ -0,0 +1,13 @@ +package pico + +import ( + "github.com/sipeed/picoclaw/pkg/bus" + "github.com/sipeed/picoclaw/pkg/channels" + "github.com/sipeed/picoclaw/pkg/config" +) + +func init() { + channels.RegisterFactory("pico", func(cfg *config.Config, b *bus.MessageBus) (channels.Channel, error) { + return NewPicoChannel(cfg.Channels.Pico, b) + }) +} diff --git a/pkg/channels/pico/pico.go b/pkg/channels/pico/pico.go new file mode 100644 index 000000000..1c28ca732 --- /dev/null +++ b/pkg/channels/pico/pico.go @@ -0,0 +1,430 @@ +package pico + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "strings" + "sync" + "sync/atomic" + "time" + + "github.com/google/uuid" + "github.com/gorilla/websocket" + + "github.com/sipeed/picoclaw/pkg/bus" + "github.com/sipeed/picoclaw/pkg/channels" + "github.com/sipeed/picoclaw/pkg/config" + "github.com/sipeed/picoclaw/pkg/logger" +) + +// picoConn represents a single WebSocket connection. +type picoConn struct { + id string + conn *websocket.Conn + sessionID string + writeMu sync.Mutex + closed atomic.Bool +} + +// writeJSON sends a JSON message to the connection with write locking. +func (pc *picoConn) writeJSON(v any) error { + if pc.closed.Load() { + return fmt.Errorf("connection closed") + } + pc.writeMu.Lock() + defer pc.writeMu.Unlock() + return pc.conn.WriteJSON(v) +} + +// close closes the connection. +func (pc *picoConn) close() { + if pc.closed.CompareAndSwap(false, true) { + pc.conn.Close() + } +} + +// PicoChannel implements the native Pico Protocol WebSocket channel. +// It serves as the reference implementation for all optional capability interfaces. +type PicoChannel struct { + *channels.BaseChannel + config config.PicoConfig + upgrader websocket.Upgrader + connections sync.Map // connID → *picoConn + connCount atomic.Int32 + ctx context.Context + cancel context.CancelFunc +} + +// NewPicoChannel creates a new Pico Protocol channel. +func NewPicoChannel(cfg config.PicoConfig, messageBus *bus.MessageBus) (*PicoChannel, error) { + if cfg.Token == "" { + return nil, fmt.Errorf("pico token is required") + } + + base := channels.NewBaseChannel("pico", cfg, messageBus, cfg.AllowFrom) + + allowOrigins := cfg.AllowOrigins + checkOrigin := func(r *http.Request) bool { + if len(allowOrigins) == 0 { + return true // allow all if not configured + } + origin := r.Header.Get("Origin") + for _, allowed := range allowOrigins { + if allowed == "*" || allowed == origin { + return true + } + } + return false + } + + return &PicoChannel{ + BaseChannel: base, + config: cfg, + upgrader: websocket.Upgrader{ + CheckOrigin: checkOrigin, + ReadBufferSize: 1024, + WriteBufferSize: 1024, + }, + }, nil +} + +// Start implements Channel. +func (c *PicoChannel) Start(ctx context.Context) error { + logger.InfoC("pico", "Starting Pico Protocol channel") + c.ctx, c.cancel = context.WithCancel(ctx) + c.SetRunning(true) + logger.InfoC("pico", "Pico Protocol channel started") + return nil +} + +// Stop implements Channel. +func (c *PicoChannel) Stop(ctx context.Context) error { + logger.InfoC("pico", "Stopping Pico Protocol channel") + c.SetRunning(false) + + // Close all connections + c.connections.Range(func(key, value any) bool { + if pc, ok := value.(*picoConn); ok { + pc.close() + } + c.connections.Delete(key) + return true + }) + + if c.cancel != nil { + c.cancel() + } + + logger.InfoC("pico", "Pico Protocol channel stopped") + return nil +} + +// WebhookPath implements channels.WebhookHandler. +func (c *PicoChannel) WebhookPath() string { return "/pico/" } + +// ServeHTTP implements http.Handler for the shared HTTP server. +func (c *PicoChannel) ServeHTTP(w http.ResponseWriter, r *http.Request) { + path := strings.TrimPrefix(r.URL.Path, "/pico") + + switch { + case path == "/ws" || path == "/ws/": + c.handleWebSocket(w, r) + default: + http.NotFound(w, r) + } +} + +// Send implements Channel — sends a message to the appropriate WebSocket connection. +func (c *PicoChannel) Send(ctx context.Context, msg bus.OutboundMessage) error { + if !c.IsRunning() { + return channels.ErrNotRunning + } + + outMsg := newMessage(TypeMessageCreate, map[string]any{ + "content": msg.Content, + }) + + return c.broadcastToSession(msg.ChatID, outMsg) +} + +// EditMessage implements channels.MessageEditor. +func (c *PicoChannel) EditMessage(ctx context.Context, chatID string, messageID string, content string) error { + outMsg := newMessage(TypeMessageUpdate, map[string]any{ + "message_id": messageID, + "content": content, + }) + return c.broadcastToSession(chatID, outMsg) +} + +// StartTyping implements channels.TypingCapable. +func (c *PicoChannel) StartTyping(ctx context.Context, chatID string) (func(), error) { + startMsg := newMessage(TypeTypingStart, nil) + if err := c.broadcastToSession(chatID, startMsg); err != nil { + return func() {}, err + } + return func() { + stopMsg := newMessage(TypeTypingStop, nil) + c.broadcastToSession(chatID, stopMsg) + }, nil +} + +// broadcastToSession sends a message to all connections with a matching session. +func (c *PicoChannel) broadcastToSession(chatID string, msg PicoMessage) error { + // chatID format: "pico:" + sessionID := strings.TrimPrefix(chatID, "pico:") + msg.SessionID = sessionID + + var sent bool + c.connections.Range(func(key, value any) bool { + pc, ok := value.(*picoConn) + if !ok { + return true + } + if pc.sessionID == sessionID { + if err := pc.writeJSON(msg); err != nil { + logger.DebugCF("pico", "Write to connection failed", map[string]any{ + "conn_id": pc.id, + "error": err.Error(), + }) + } else { + sent = true + } + } + return true + }) + + if !sent { + return fmt.Errorf("no active connections for session %s: %w", sessionID, channels.ErrSendFailed) + } + return nil +} + +// handleWebSocket upgrades the HTTP connection and manages the WebSocket lifecycle. +func (c *PicoChannel) handleWebSocket(w http.ResponseWriter, r *http.Request) { + if !c.IsRunning() { + http.Error(w, "channel not running", http.StatusServiceUnavailable) + return + } + + // Authenticate + if !c.authenticate(r) { + http.Error(w, "unauthorized", http.StatusUnauthorized) + return + } + + // Check connection limit + maxConns := c.config.MaxConnections + if maxConns <= 0 { + maxConns = 100 + } + if int(c.connCount.Load()) >= maxConns { + http.Error(w, "too many connections", http.StatusServiceUnavailable) + return + } + + conn, err := c.upgrader.Upgrade(w, r, nil) + if err != nil { + logger.ErrorCF("pico", "WebSocket upgrade failed", map[string]any{ + "error": err.Error(), + }) + return + } + + // Determine session ID from query param or generate one + sessionID := r.URL.Query().Get("session_id") + if sessionID == "" { + sessionID = uuid.New().String() + } + + pc := &picoConn{ + id: uuid.New().String(), + conn: conn, + sessionID: sessionID, + } + + c.connections.Store(pc.id, pc) + c.connCount.Add(1) + + logger.InfoCF("pico", "WebSocket client connected", map[string]any{ + "conn_id": pc.id, + "session_id": sessionID, + }) + + go c.readLoop(pc) +} + +// authenticate checks the Bearer token from header or query parameter. +func (c *PicoChannel) authenticate(r *http.Request) bool { + token := c.config.Token + if token == "" { + return false + } + + // Check Authorization header + auth := r.Header.Get("Authorization") + if strings.HasPrefix(auth, "Bearer ") { + if strings.TrimPrefix(auth, "Bearer ") == token { + return true + } + } + + // Check query parameter + if r.URL.Query().Get("token") == token { + return true + } + + return false +} + +// readLoop reads messages from a WebSocket connection. +func (c *PicoChannel) readLoop(pc *picoConn) { + defer func() { + pc.close() + c.connections.Delete(pc.id) + c.connCount.Add(-1) + logger.InfoCF("pico", "WebSocket client disconnected", map[string]any{ + "conn_id": pc.id, + "session_id": pc.sessionID, + }) + }() + + readTimeout := time.Duration(c.config.ReadTimeout) * time.Second + if readTimeout <= 0 { + readTimeout = 60 * time.Second + } + + _ = pc.conn.SetReadDeadline(time.Now().Add(readTimeout)) + pc.conn.SetPongHandler(func(appData string) error { + _ = pc.conn.SetReadDeadline(time.Now().Add(readTimeout)) + return nil + }) + + // Start ping ticker + pingInterval := time.Duration(c.config.PingInterval) * time.Second + if pingInterval <= 0 { + pingInterval = 30 * time.Second + } + go c.pingLoop(pc, pingInterval) + + for { + select { + case <-c.ctx.Done(): + return + default: + } + + _, rawMsg, err := pc.conn.ReadMessage() + if err != nil { + if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseNormalClosure) { + logger.DebugCF("pico", "WebSocket read error", map[string]any{ + "conn_id": pc.id, + "error": err.Error(), + }) + } + return + } + + _ = pc.conn.SetReadDeadline(time.Now().Add(readTimeout)) + + var msg PicoMessage + if err := json.Unmarshal(rawMsg, &msg); err != nil { + errMsg := newError("invalid_message", "failed to parse message") + pc.writeJSON(errMsg) + continue + } + + c.handleMessage(pc, msg) + } +} + +// pingLoop sends periodic ping frames to keep the connection alive. +func (c *PicoChannel) pingLoop(pc *picoConn, interval time.Duration) { + ticker := time.NewTicker(interval) + defer ticker.Stop() + + for { + select { + case <-c.ctx.Done(): + return + case <-ticker.C: + if pc.closed.Load() { + return + } + pc.writeMu.Lock() + err := pc.conn.WriteMessage(websocket.PingMessage, nil) + pc.writeMu.Unlock() + if err != nil { + return + } + } + } +} + +// handleMessage processes an inbound Pico Protocol message. +func (c *PicoChannel) handleMessage(pc *picoConn, msg PicoMessage) { + switch msg.Type { + case TypePing: + pong := newMessage(TypePong, nil) + pong.ID = msg.ID + pc.writeJSON(pong) + + case TypeMessageSend: + c.handleMessageSend(pc, msg) + + default: + errMsg := newError("unknown_type", fmt.Sprintf("unknown message type: %s", msg.Type)) + pc.writeJSON(errMsg) + } +} + +// handleMessageSend processes an inbound message.send from a client. +func (c *PicoChannel) handleMessageSend(pc *picoConn, msg PicoMessage) { + content, _ := msg.Payload["content"].(string) + if strings.TrimSpace(content) == "" { + errMsg := newError("empty_content", "message content is empty") + pc.writeJSON(errMsg) + return + } + + sessionID := msg.SessionID + if sessionID == "" { + sessionID = pc.sessionID + } + + chatID := "pico:" + sessionID + senderID := "pico-user" + + peer := bus.Peer{Kind: "direct", ID: "pico:" + sessionID} + + metadata := map[string]string{ + "platform": "pico", + "session_id": sessionID, + "conn_id": pc.id, + } + + logger.DebugCF("pico", "Received message", map[string]any{ + "session_id": sessionID, + "preview": truncate(content, 50), + }) + + // Register typing with Manager + if rec := c.GetPlaceholderRecorder(); rec != nil { + stop, err := c.StartTyping(c.ctx, chatID) + if err == nil { + rec.RecordTypingStop("pico", chatID, stop) + } + } + + c.HandleMessage(peer, msg.ID, senderID, chatID, content, nil, metadata) +} + +// truncate truncates a string to maxLen runes. +func truncate(s string, maxLen int) string { + runes := []rune(s) + if len(runes) <= maxLen { + return s + } + return string(runes[:maxLen]) + "..." +} diff --git a/pkg/channels/pico/protocol.go b/pkg/channels/pico/protocol.go new file mode 100644 index 000000000..ca18df1dd --- /dev/null +++ b/pkg/channels/pico/protocol.go @@ -0,0 +1,46 @@ +package pico + +import "time" + +// Protocol message types. +const ( + // Client → Server + TypeMessageSend = "message.send" + TypeMediaSend = "media.send" + TypePing = "ping" + + // Server → Client + TypeMessageCreate = "message.create" + TypeMessageUpdate = "message.update" + TypeMediaCreate = "media.create" + TypeTypingStart = "typing.start" + TypeTypingStop = "typing.stop" + TypeError = "error" + TypePong = "pong" +) + +// PicoMessage is the wire format for all Pico Protocol messages. +type PicoMessage struct { + Type string `json:"type"` + ID string `json:"id,omitempty"` + SessionID string `json:"session_id,omitempty"` + Timestamp int64 `json:"timestamp,omitempty"` + Payload map[string]any `json:"payload,omitempty"` +} + +// newMessage creates a PicoMessage with the given type and payload. +func newMessage(msgType string, payload map[string]any) PicoMessage { + return PicoMessage{ + Type: msgType, + Timestamp: time.Now().UnixMilli(), + Payload: payload, + } +} + +// newError creates an error PicoMessage. +func newError(code, message string) PicoMessage { + return newMessage(TypeError, map[string]any{ + "code": code, + "message": message, + }) +} diff --git a/pkg/channels/slack/slack.go b/pkg/channels/slack/slack.go index 6fba2e0b4..e64525310 100644 --- a/pkg/channels/slack/slack.go +++ b/pkg/channels/slack/slack.go @@ -274,6 +274,18 @@ func (c *SlackChannel) handleMessageEvent(ev *slackevents.MessageEvent) { Timestamp: messageTS, }) + // Register typing stop (remove "eyes" reaction) with Manager + if rec := c.GetPlaceholderRecorder(); rec != nil { + capturedChannelID := channelID + capturedMessageTS := messageTS + rec.RecordTypingStop("slack", chatID, func() { + c.api.RemoveReaction("eyes", slack.ItemRef{ + Channel: capturedChannelID, + Timestamp: capturedMessageTS, + }) + }) + } + c.pendingAcks.Store(chatID, slackMessageRef{ ChannelID: channelID, Timestamp: messageTS, @@ -380,6 +392,18 @@ func (c *SlackChannel) handleAppMention(ev *slackevents.AppMentionEvent) { Timestamp: messageTS, }) + // Register typing stop (remove "eyes" reaction) with Manager + if rec := c.GetPlaceholderRecorder(); rec != nil { + capturedChannelID := channelID + capturedMessageTS := messageTS + rec.RecordTypingStop("slack", chatID, func() { + c.api.RemoveReaction("eyes", slack.ItemRef{ + Channel: capturedChannelID, + Timestamp: capturedMessageTS, + }) + }) + } + c.pendingAcks.Store(chatID, slackMessageRef{ ChannelID: channelID, Timestamp: messageTS, diff --git a/pkg/channels/telegram/telegram.go b/pkg/channels/telegram/telegram.go index c5c055163..98477f3a8 100644 --- a/pkg/channels/telegram/telegram.go +++ b/pkg/channels/telegram/telegram.go @@ -7,8 +7,8 @@ import ( "net/url" "os" "regexp" + "strconv" "strings" - "sync" "time" "github.com/mymmrac/telego" @@ -26,25 +26,13 @@ import ( type TelegramChannel struct { *channels.BaseChannel - bot *telego.Bot - bh *telegohandler.BotHandler - commands TelegramCommander - config *config.Config - chatIDs map[string]int64 - ctx context.Context - cancel context.CancelFunc - placeholders sync.Map // chatID -> messageID - stopThinking sync.Map // chatID -> thinkingCancel -} - -type thinkingCancel struct { - fn context.CancelFunc -} - -func (c *thinkingCancel) Cancel() { - if c != nil && c.fn != nil { - c.fn() - } + bot *telego.Bot + bh *telegohandler.BotHandler + commands TelegramCommander + config *config.Config + chatIDs map[string]int64 + ctx context.Context + cancel context.CancelFunc } func NewTelegramChannel(cfg *config.Config, bus *bus.MessageBus) (*TelegramChannel, error) { @@ -85,13 +73,11 @@ func NewTelegramChannel(cfg *config.Config, bus *bus.MessageBus) (*TelegramChann ) return &TelegramChannel{ - BaseChannel: base, - commands: NewTelegramCommands(bot, cfg), - bot: bot, - config: cfg, - chatIDs: make(map[string]int64), - placeholders: sync.Map{}, - stopThinking: sync.Map{}, + BaseChannel: base, + commands: NewTelegramCommands(bot, cfg), + bot: bot, + config: cfg, + chatIDs: make(map[string]int64), }, nil } @@ -149,21 +135,6 @@ func (c *TelegramChannel) Stop(ctx context.Context) error { logger.InfoC("telegram", "Stopping Telegram bot...") c.SetRunning(false) - // Clean up all thinking cancel functions to avoid context leaks - c.stopThinking.Range(func(key, value any) bool { - if cf, ok := value.(*thinkingCancel); ok && cf != nil { - cf.Cancel() - } - c.stopThinking.Delete(key) - return true - }) - - // Clean up placeholder state - c.placeholders.Range(func(key, value any) bool { - c.placeholders.Delete(key) - return true - }) - // Stop the bot handler if c.bh != nil { c.bh.Stop() @@ -187,28 +158,9 @@ func (c *TelegramChannel) Send(ctx context.Context, msg bus.OutboundMessage) err return fmt.Errorf("invalid chat ID %s: %w", msg.ChatID, channels.ErrSendFailed) } - // Stop thinking animation - if stop, ok := c.stopThinking.Load(msg.ChatID); ok { - if cf, ok := stop.(*thinkingCancel); ok && cf != nil { - cf.Cancel() - } - c.stopThinking.Delete(msg.ChatID) - } - htmlContent := markdownToTelegramHTML(msg.Content) - // Try to edit placeholder - if pID, ok := c.placeholders.Load(msg.ChatID); ok { - c.placeholders.Delete(msg.ChatID) - editMsg := tu.EditMessageText(tu.ID(chatID), pID.(int), htmlContent) - editMsg.ParseMode = telego.ModeHTML - - if _, err = c.bot.EditMessageText(ctx, editMsg); err == nil { - return nil - } - // Fallback to new message if edit fails - } - + // Typing/placeholder handled by Manager.preSend — just send the message tgMsg := tu.Message(tu.ID(chatID), htmlContent) tgMsg.ParseMode = telego.ModeHTML @@ -225,6 +177,23 @@ func (c *TelegramChannel) Send(ctx context.Context, msg bus.OutboundMessage) err return nil } +// EditMessage implements channels.MessageEditor. +func (c *TelegramChannel) EditMessage(ctx context.Context, chatID string, messageID string, content string) error { + cid, err := parseChatID(chatID) + if err != nil { + return err + } + mid, err := strconv.Atoi(messageID) + if err != nil { + return err + } + htmlContent := markdownToTelegramHTML(content) + editMsg := tu.EditMessageText(tu.ID(cid), mid, htmlContent) + editMsg.ParseMode = telego.ModeHTML + _, err = c.bot.EditMessageText(ctx, editMsg) + return err +} + // SendMedia implements the channels.MediaSender interface. func (c *TelegramChannel) SendMedia(ctx context.Context, msg bus.OutboundMediaMessage) error { if !c.IsRunning() { @@ -445,21 +414,21 @@ func (c *TelegramChannel) handleMessage(ctx context.Context, message *telego.Mes }) } - // Stop any previous thinking animation - if prevStop, ok := c.stopThinking.Load(chatIDStr); ok { - if cf, ok := prevStop.(*thinkingCancel); ok && cf != nil { - cf.Cancel() - } - } - - // Create cancel function for thinking state + // Create cancel function for thinking state and register with Manager _, thinkCancel := context.WithTimeout(ctx, 5*time.Minute) - c.stopThinking.Store(chatIDStr, &thinkingCancel{fn: thinkCancel}) + if rec := c.GetPlaceholderRecorder(); rec != nil { + rec.RecordTypingStop("telegram", chatIDStr, thinkCancel) + } else { + // No recorder — cancel immediately to avoid context leak + thinkCancel() + } pMsg, err := c.bot.SendMessage(ctx, tu.Message(tu.ID(chatID), "Thinking... 💭")) if err == nil { pID := pMsg.MessageID - c.placeholders.Store(chatIDStr, pID) + if rec := c.GetPlaceholderRecorder(); rec != nil { + rec.RecordPlaceholder("telegram", chatIDStr, fmt.Sprintf("%d", pID)) + } } peerKind := "direct" diff --git a/pkg/config/config.go b/pkg/config/config.go index 0c89d05eb..d32e8db90 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -202,6 +202,7 @@ type ChannelsConfig struct { OneBot OneBotConfig `json:"onebot"` WeCom WeComConfig `json:"wecom"` WeComApp WeComAppConfig `json:"wecom_app"` + Pico PicoConfig `json:"pico"` } // GroupTriggerConfig controls when the bot responds in group chats. @@ -343,6 +344,17 @@ type WeComAppConfig struct { GroupTrigger GroupTriggerConfig `json:"group_trigger,omitempty"` } +type PicoConfig struct { + Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_PICO_ENABLED"` + Token string `json:"token" env:"PICOCLAW_CHANNELS_PICO_TOKEN"` + AllowOrigins []string `json:"allow_origins,omitempty"` + PingInterval int `json:"ping_interval,omitempty"` // seconds, default 30 + ReadTimeout int `json:"read_timeout,omitempty"` // seconds, default 60 + WriteTimeout int `json:"write_timeout,omitempty"` // seconds, default 10 + MaxConnections int `json:"max_connections,omitempty"` // default 100 + AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_PICO_ALLOW_FROM"` +} + type HeartbeatConfig struct { Enabled bool `json:"enabled" env:"PICOCLAW_HEARTBEAT_ENABLED"` Interval int `json:"interval" env:"PICOCLAW_HEARTBEAT_INTERVAL"` // minutes, min 5 diff --git a/pkg/config/defaults.go b/pkg/config/defaults.go index 5c53a3963..8445510e2 100644 --- a/pkg/config/defaults.go +++ b/pkg/config/defaults.go @@ -33,6 +33,11 @@ func DefaultConfig() *Config { Enabled: false, Token: "", AllowFrom: FlexibleStringSlice{}, + Typing: TypingConfig{Enabled: true}, + Placeholder: PlaceholderConfig{ + Enabled: true, + Text: "Thinking... 💭", + }, }, Feishu: FeishuConfig{ Enabled: false, @@ -114,6 +119,15 @@ func DefaultConfig() *Config { AllowFrom: FlexibleStringSlice{}, ReplyTimeout: 5, }, + Pico: PicoConfig{ + Enabled: false, + Token: "", + PingInterval: 30, + ReadTimeout: 60, + WriteTimeout: 10, + MaxConnections: 100, + AllowFrom: FlexibleStringSlice{}, + }, }, Providers: ProvidersConfig{ OpenAI: OpenAIProviderConfig{WebSearch: true},