diff --git a/cmd/picoclaw/internal/gateway/helpers.go b/cmd/picoclaw/internal/gateway/helpers.go index 758157f53..6ac41fab1 100644 --- a/cmd/picoclaw/internal/gateway/helpers.go +++ b/cmd/picoclaw/internal/gateway/helpers.go @@ -18,6 +18,7 @@ import ( _ "github.com/sipeed/picoclaw/pkg/channels/line" _ "github.com/sipeed/picoclaw/pkg/channels/maixcam" _ "github.com/sipeed/picoclaw/pkg/channels/onebot" + _ "github.com/sipeed/picoclaw/pkg/channels/pico" _ "github.com/sipeed/picoclaw/pkg/channels/qq" _ "github.com/sipeed/picoclaw/pkg/channels/slack" _ "github.com/sipeed/picoclaw/pkg/channels/telegram" diff --git a/pkg/channels/base.go b/pkg/channels/base.go index e345aedf0..c22a27eb9 100644 --- a/pkg/channels/base.go +++ b/pkg/channels/base.go @@ -44,14 +44,15 @@ type MessageLengthProvider interface { } type BaseChannel struct { - config any - bus *bus.MessageBus - running atomic.Bool - name string - allowList []string - maxMessageLength int - groupTrigger config.GroupTriggerConfig - mediaStore media.MediaStore + config any + bus *bus.MessageBus + running atomic.Bool + name string + allowList []string + maxMessageLength int + groupTrigger config.GroupTriggerConfig + mediaStore media.MediaStore + placeholderRecorder PlaceholderRecorder } func NewBaseChannel( @@ -203,6 +204,16 @@ func (c *BaseChannel) SetMediaStore(s media.MediaStore) { c.mediaStore = s } // GetMediaStore returns the injected MediaStore (may be nil). func (c *BaseChannel) GetMediaStore() media.MediaStore { return c.mediaStore } +// SetPlaceholderRecorder injects a PlaceholderRecorder into the channel. +func (c *BaseChannel) SetPlaceholderRecorder(r PlaceholderRecorder) { + c.placeholderRecorder = r +} + +// GetPlaceholderRecorder returns the injected PlaceholderRecorder (may be nil). +func (c *BaseChannel) GetPlaceholderRecorder() PlaceholderRecorder { + return c.placeholderRecorder +} + // BuildMediaScope constructs a scope key for media lifecycle tracking. func BuildMediaScope(channel, chatID, messageID string) string { id := messageID diff --git a/pkg/channels/discord/discord.go b/pkg/channels/discord/discord.go index 4ef4906c1..ee698da61 100644 --- a/pkg/channels/discord/discord.go +++ b/pkg/channels/discord/discord.go @@ -106,8 +106,6 @@ func (c *DiscordChannel) Stop(ctx context.Context) error { } func (c *DiscordChannel) Send(ctx context.Context, msg bus.OutboundMessage) error { - c.stopTyping(msg.ChatID) - if !c.IsRunning() { return channels.ErrNotRunning } @@ -126,8 +124,6 @@ func (c *DiscordChannel) Send(ctx context.Context, msg bus.OutboundMessage) erro // SendMedia implements the channels.MediaSender interface. func (c *DiscordChannel) SendMedia(ctx context.Context, msg bus.OutboundMediaMessage) error { - c.stopTyping(msg.ChatID) - if !c.IsRunning() { return channels.ErrNotRunning } @@ -221,6 +217,12 @@ func (c *DiscordChannel) SendMedia(ctx context.Context, msg bus.OutboundMediaMes } } +// EditMessage implements channels.MessageEditor. +func (c *DiscordChannel) EditMessage(ctx context.Context, chatID string, messageID string, content string) error { + _, err := c.session.ChannelMessageEdit(chatID, messageID, content) + return err +} + func (c *DiscordChannel) sendChunk(ctx context.Context, channelID, content string) error { // Use the passed ctx for timeout control sendCtx, cancel := context.WithTimeout(ctx, sendTimeout) @@ -350,6 +352,10 @@ func (c *DiscordChannel) handleMessage(s *discordgo.Session, m *discordgo.Messag // Start typing after all early returns — guaranteed to have a matching Send() c.startTyping(m.ChannelID) + // Register typing stop with Manager for outbound orchestration + if rec := c.GetPlaceholderRecorder(); rec != nil { + rec.RecordTypingStop("discord", m.ChannelID, func() { c.stopTyping(m.ChannelID) }) + } logger.DebugCF("discord", "Received message", map[string]any{ "sender_name": senderName, diff --git a/pkg/channels/interfaces.go b/pkg/channels/interfaces.go new file mode 100644 index 000000000..32bfe95f8 --- /dev/null +++ b/pkg/channels/interfaces.go @@ -0,0 +1,24 @@ +package channels + +import "context" + +// TypingCapable — channels that can show a typing/thinking indicator. +// StartTyping begins the indicator and returns a stop function. +// The stop function MUST be idempotent and safe to call multiple times. +type TypingCapable interface { + StartTyping(ctx context.Context, chatID string) (stop func(), err error) +} + +// MessageEditor — channels that can edit an existing message. +// messageID is always string; channels convert platform-specific types internally. +type MessageEditor interface { + EditMessage(ctx context.Context, chatID string, messageID string, content string) error +} + +// PlaceholderRecorder is injected into channels by Manager. +// Channels call these methods on inbound to register typing/placeholder state. +// Manager uses the registered state on outbound to stop typing and edit placeholders. +type PlaceholderRecorder interface { + RecordPlaceholder(channel, chatID, placeholderID string) + RecordTypingStop(channel, chatID string, stop func()) +} diff --git a/pkg/channels/manager.go b/pkg/channels/manager.go index 92412edeb..4b1a43b7b 100644 --- a/pkg/channels/manager.go +++ b/pkg/channels/manager.go @@ -62,12 +62,55 @@ type Manager struct { mux *http.ServeMux httpServer *http.Server mu sync.RWMutex + placeholders sync.Map // "channel:chatID" → placeholderID (string) + typingStops sync.Map // "channel:chatID" → func() } type asyncTask struct { cancel context.CancelFunc } +// RecordPlaceholder registers a placeholder message for later editing. +// Implements PlaceholderRecorder. +func (m *Manager) RecordPlaceholder(channel, chatID, placeholderID string) { + key := channel + ":" + chatID + m.placeholders.Store(key, placeholderID) +} + +// RecordTypingStop registers a typing stop function for later invocation. +// Implements PlaceholderRecorder. +func (m *Manager) RecordTypingStop(channel, chatID string, stop func()) { + key := channel + ":" + chatID + m.typingStops.Store(key, stop) +} + +// preSend handles typing stop and placeholder editing before sending a message. +// Returns true if the message was edited into a placeholder (skip Send). +func (m *Manager) preSend(ctx context.Context, name string, msg bus.OutboundMessage, ch Channel) bool { + key := name + ":" + msg.ChatID + + // 1. Stop typing + if v, loaded := m.typingStops.LoadAndDelete(key); loaded { + if stop, ok := v.(func()); ok { + stop() // idempotent, safe + } + } + + // 2. Try editing placeholder + if v, loaded := m.placeholders.LoadAndDelete(key); loaded { + if placeholderID, ok := v.(string); ok && placeholderID != "" { + if editor, ok := ch.(MessageEditor); ok { + if err := editor.EditMessage(ctx, msg.ChatID, placeholderID, msg.Content); err == nil { + return true // edited successfully, skip Send + } + // edit failed → fall through to normal Send + } + } + } + + return false +} + func NewManager(cfg *config.Config, messageBus *bus.MessageBus, store media.MediaStore) (*Manager, error) { m := &Manager{ channels: make(map[string]Channel), @@ -109,6 +152,10 @@ func (m *Manager) initChannel(name, displayName string) { setter.SetMediaStore(m.mediaStore) } } + // Inject PlaceholderRecorder if channel supports it + if setter, ok := ch.(interface{ SetPlaceholderRecorder(PlaceholderRecorder) }); ok { + setter.SetPlaceholderRecorder(m) + } m.channels[name] = ch m.workers[name] = newChannelWorker(name, ch) logger.InfoCF("channels", "Channel enabled successfully", map[string]any{ @@ -168,6 +215,10 @@ func (m *Manager) initChannels() error { m.initChannel("wecom_app", "WeCom App") } + if m.config.Channels.Pico.Enabled && m.config.Channels.Pico.Token != "" { + m.initChannel("pico", "Pico") + } + logger.InfoCF("channels", "Channel initialization completed", map[string]any{ "enabled_channels": len(m.channels), }) @@ -383,6 +434,11 @@ func (m *Manager) sendWithRetry(ctx context.Context, name string, w *channelWork return } + // Pre-send: stop typing and try to edit placeholder + if m.preSend(ctx, name, msg, w.ch) { + return // placeholder was edited successfully, skip Send + } + var lastErr error for attempt := 0; attempt <= maxRetries; attempt++ { lastErr = w.ch.Send(ctx, msg) diff --git a/pkg/channels/manager_test.go b/pkg/channels/manager_test.go index 162c9f8c9..0573c0a8e 100644 --- a/pkg/channels/manager_test.go +++ b/pkg/channels/manager_test.go @@ -416,3 +416,219 @@ func TestSendWithRetry_ExponentialBackoff(t *testing.T) { t.Fatalf("expected %d calls, got %d", maxRetries+1, callCount.Load()) } } + +// --- Phase 10: preSend orchestration tests --- + +// mockMessageEditor is a channel that supports MessageEditor. +type mockMessageEditor struct { + mockChannel + editFn func(ctx context.Context, chatID, messageID, content string) error +} + +func (m *mockMessageEditor) EditMessage(ctx context.Context, chatID, messageID, content string) error { + return m.editFn(ctx, chatID, messageID, content) +} + +func TestPreSend_PlaceholderEditSuccess(t *testing.T) { + m := newTestManager() + var sendCalled bool + var editCalled bool + + ch := &mockMessageEditor{ + mockChannel: mockChannel{ + sendFn: func(_ context.Context, _ bus.OutboundMessage) error { + sendCalled = true + return nil + }, + }, + editFn: func(_ context.Context, chatID, messageID, content string) error { + editCalled = true + if chatID != "123" { + t.Fatalf("expected chatID 123, got %s", chatID) + } + if messageID != "456" { + t.Fatalf("expected messageID 456, got %s", messageID) + } + if content != "hello" { + t.Fatalf("expected content 'hello', got %s", content) + } + return nil + }, + } + + // Register placeholder + m.RecordPlaceholder("test", "123", "456") + + msg := bus.OutboundMessage{Channel: "test", ChatID: "123", Content: "hello"} + edited := m.preSend(context.Background(), "test", msg, ch) + + if !edited { + t.Fatal("expected preSend to return true (placeholder edited)") + } + if !editCalled { + t.Fatal("expected EditMessage to be called") + } + if sendCalled { + t.Fatal("expected Send to NOT be called when placeholder edited") + } +} + +func TestPreSend_PlaceholderEditFails_FallsThrough(t *testing.T) { + m := newTestManager() + + ch := &mockMessageEditor{ + mockChannel: mockChannel{ + sendFn: func(_ context.Context, _ bus.OutboundMessage) error { + return nil + }, + }, + editFn: func(_ context.Context, _, _, _ string) error { + return fmt.Errorf("edit failed") + }, + } + + m.RecordPlaceholder("test", "123", "456") + + msg := bus.OutboundMessage{Channel: "test", ChatID: "123", Content: "hello"} + edited := m.preSend(context.Background(), "test", msg, ch) + + if edited { + t.Fatal("expected preSend to return false when edit fails") + } +} + +func TestPreSend_TypingStopCalled(t *testing.T) { + m := newTestManager() + var stopCalled bool + + ch := &mockChannel{ + sendFn: func(_ context.Context, _ bus.OutboundMessage) error { + return nil + }, + } + + m.RecordTypingStop("test", "123", func() { + stopCalled = true + }) + + msg := bus.OutboundMessage{Channel: "test", ChatID: "123", Content: "hello"} + m.preSend(context.Background(), "test", msg, ch) + + if !stopCalled { + t.Fatal("expected typing stop func to be called") + } +} + +func TestPreSend_NoRegisteredState(t *testing.T) { + m := newTestManager() + + ch := &mockChannel{ + sendFn: func(_ context.Context, _ bus.OutboundMessage) error { + return nil + }, + } + + msg := bus.OutboundMessage{Channel: "test", ChatID: "123", Content: "hello"} + edited := m.preSend(context.Background(), "test", msg, ch) + + if edited { + t.Fatal("expected preSend to return false with no registered state") + } +} + +func TestPreSend_TypingAndPlaceholder(t *testing.T) { + m := newTestManager() + var stopCalled bool + var editCalled bool + + ch := &mockMessageEditor{ + mockChannel: mockChannel{ + sendFn: func(_ context.Context, _ bus.OutboundMessage) error { + return nil + }, + }, + editFn: func(_ context.Context, _, _, _ string) error { + editCalled = true + return nil + }, + } + + m.RecordTypingStop("test", "123", func() { + stopCalled = true + }) + m.RecordPlaceholder("test", "123", "456") + + msg := bus.OutboundMessage{Channel: "test", ChatID: "123", Content: "hello"} + edited := m.preSend(context.Background(), "test", msg, ch) + + if !stopCalled { + t.Fatal("expected typing stop to be called") + } + if !editCalled { + t.Fatal("expected EditMessage to be called") + } + if !edited { + t.Fatal("expected preSend to return true") + } +} + +func TestRecordPlaceholder_ConcurrentSafe(t *testing.T) { + m := newTestManager() + + var wg sync.WaitGroup + for i := 0; i < 100; i++ { + wg.Add(1) + go func(i int) { + defer wg.Done() + chatID := fmt.Sprintf("chat_%d", i%10) + m.RecordPlaceholder("test", chatID, fmt.Sprintf("msg_%d", i)) + }(i) + } + wg.Wait() +} + +func TestRecordTypingStop_ConcurrentSafe(t *testing.T) { + m := newTestManager() + + var wg sync.WaitGroup + for i := 0; i < 100; i++ { + wg.Add(1) + go func(i int) { + defer wg.Done() + chatID := fmt.Sprintf("chat_%d", i%10) + m.RecordTypingStop("test", chatID, func() {}) + }(i) + } + wg.Wait() +} + +func TestSendWithRetry_PreSendEditsPlaceholder(t *testing.T) { + m := newTestManager() + var sendCalled bool + + ch := &mockMessageEditor{ + mockChannel: mockChannel{ + sendFn: func(_ context.Context, _ bus.OutboundMessage) error { + sendCalled = true + return nil + }, + }, + editFn: func(_ context.Context, _, _, _ string) error { + return nil // edit succeeds + }, + } + + m.RecordPlaceholder("test", "123", "456") + + w := &channelWorker{ + ch: ch, + limiter: rate.NewLimiter(rate.Inf, 1), + } + + msg := bus.OutboundMessage{Channel: "test", ChatID: "123", Content: "hello"} + m.sendWithRetry(context.Background(), "test", w, msg) + + if sendCalled { + t.Fatal("expected Send to NOT be called when placeholder was edited") + } +} diff --git a/pkg/channels/onebot/onebot.go b/pkg/channels/onebot/onebot.go index f32cb4948..682025b67 100644 --- a/pkg/channels/onebot/onebot.go +++ b/pkg/channels/onebot/onebot.go @@ -418,12 +418,6 @@ func (c *OneBotChannel) Send(ctx context.Context, msg bus.OutboundMessage) error return fmt.Errorf("onebot send: %w", channels.ErrTemporary) } - if msgID, ok := c.pendingEmojiMsg.LoadAndDelete(msg.ChatID); ok { - if mid, ok := msgID.(string); ok && mid != "" { - c.setMsgEmojiLike(mid, 289, false) - } - } - return nil } @@ -1037,6 +1031,13 @@ func (c *OneBotChannel) handleMessage(raw *oneBotRawEvent) { if raw.MessageType == "group" && messageID != "" && messageID != "0" { c.setMsgEmojiLike(messageID, 289, true) c.pendingEmojiMsg.Store(chatID, messageID) + // Register emoji stop with Manager for outbound orchestration + if rec := c.GetPlaceholderRecorder(); rec != nil { + capturedMsgID := messageID + rec.RecordTypingStop("onebot", chatID, func() { + c.setMsgEmojiLike(capturedMsgID, 289, false) + }) + } } c.HandleMessage(peer, messageID, senderID, chatID, content, parsed.Media, metadata) diff --git a/pkg/channels/pico/init.go b/pkg/channels/pico/init.go new file mode 100644 index 000000000..96d764418 --- /dev/null +++ b/pkg/channels/pico/init.go @@ -0,0 +1,13 @@ +package pico + +import ( + "github.com/sipeed/picoclaw/pkg/bus" + "github.com/sipeed/picoclaw/pkg/channels" + "github.com/sipeed/picoclaw/pkg/config" +) + +func init() { + channels.RegisterFactory("pico", func(cfg *config.Config, b *bus.MessageBus) (channels.Channel, error) { + return NewPicoChannel(cfg.Channels.Pico, b) + }) +} diff --git a/pkg/channels/pico/pico.go b/pkg/channels/pico/pico.go new file mode 100644 index 000000000..1c28ca732 --- /dev/null +++ b/pkg/channels/pico/pico.go @@ -0,0 +1,430 @@ +package pico + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "strings" + "sync" + "sync/atomic" + "time" + + "github.com/google/uuid" + "github.com/gorilla/websocket" + + "github.com/sipeed/picoclaw/pkg/bus" + "github.com/sipeed/picoclaw/pkg/channels" + "github.com/sipeed/picoclaw/pkg/config" + "github.com/sipeed/picoclaw/pkg/logger" +) + +// picoConn represents a single WebSocket connection. +type picoConn struct { + id string + conn *websocket.Conn + sessionID string + writeMu sync.Mutex + closed atomic.Bool +} + +// writeJSON sends a JSON message to the connection with write locking. +func (pc *picoConn) writeJSON(v any) error { + if pc.closed.Load() { + return fmt.Errorf("connection closed") + } + pc.writeMu.Lock() + defer pc.writeMu.Unlock() + return pc.conn.WriteJSON(v) +} + +// close closes the connection. +func (pc *picoConn) close() { + if pc.closed.CompareAndSwap(false, true) { + pc.conn.Close() + } +} + +// PicoChannel implements the native Pico Protocol WebSocket channel. +// It serves as the reference implementation for all optional capability interfaces. +type PicoChannel struct { + *channels.BaseChannel + config config.PicoConfig + upgrader websocket.Upgrader + connections sync.Map // connID → *picoConn + connCount atomic.Int32 + ctx context.Context + cancel context.CancelFunc +} + +// NewPicoChannel creates a new Pico Protocol channel. +func NewPicoChannel(cfg config.PicoConfig, messageBus *bus.MessageBus) (*PicoChannel, error) { + if cfg.Token == "" { + return nil, fmt.Errorf("pico token is required") + } + + base := channels.NewBaseChannel("pico", cfg, messageBus, cfg.AllowFrom) + + allowOrigins := cfg.AllowOrigins + checkOrigin := func(r *http.Request) bool { + if len(allowOrigins) == 0 { + return true // allow all if not configured + } + origin := r.Header.Get("Origin") + for _, allowed := range allowOrigins { + if allowed == "*" || allowed == origin { + return true + } + } + return false + } + + return &PicoChannel{ + BaseChannel: base, + config: cfg, + upgrader: websocket.Upgrader{ + CheckOrigin: checkOrigin, + ReadBufferSize: 1024, + WriteBufferSize: 1024, + }, + }, nil +} + +// Start implements Channel. +func (c *PicoChannel) Start(ctx context.Context) error { + logger.InfoC("pico", "Starting Pico Protocol channel") + c.ctx, c.cancel = context.WithCancel(ctx) + c.SetRunning(true) + logger.InfoC("pico", "Pico Protocol channel started") + return nil +} + +// Stop implements Channel. +func (c *PicoChannel) Stop(ctx context.Context) error { + logger.InfoC("pico", "Stopping Pico Protocol channel") + c.SetRunning(false) + + // Close all connections + c.connections.Range(func(key, value any) bool { + if pc, ok := value.(*picoConn); ok { + pc.close() + } + c.connections.Delete(key) + return true + }) + + if c.cancel != nil { + c.cancel() + } + + logger.InfoC("pico", "Pico Protocol channel stopped") + return nil +} + +// WebhookPath implements channels.WebhookHandler. +func (c *PicoChannel) WebhookPath() string { return "/pico/" } + +// ServeHTTP implements http.Handler for the shared HTTP server. +func (c *PicoChannel) ServeHTTP(w http.ResponseWriter, r *http.Request) { + path := strings.TrimPrefix(r.URL.Path, "/pico") + + switch { + case path == "/ws" || path == "/ws/": + c.handleWebSocket(w, r) + default: + http.NotFound(w, r) + } +} + +// Send implements Channel — sends a message to the appropriate WebSocket connection. +func (c *PicoChannel) Send(ctx context.Context, msg bus.OutboundMessage) error { + if !c.IsRunning() { + return channels.ErrNotRunning + } + + outMsg := newMessage(TypeMessageCreate, map[string]any{ + "content": msg.Content, + }) + + return c.broadcastToSession(msg.ChatID, outMsg) +} + +// EditMessage implements channels.MessageEditor. +func (c *PicoChannel) EditMessage(ctx context.Context, chatID string, messageID string, content string) error { + outMsg := newMessage(TypeMessageUpdate, map[string]any{ + "message_id": messageID, + "content": content, + }) + return c.broadcastToSession(chatID, outMsg) +} + +// StartTyping implements channels.TypingCapable. +func (c *PicoChannel) StartTyping(ctx context.Context, chatID string) (func(), error) { + startMsg := newMessage(TypeTypingStart, nil) + if err := c.broadcastToSession(chatID, startMsg); err != nil { + return func() {}, err + } + return func() { + stopMsg := newMessage(TypeTypingStop, nil) + c.broadcastToSession(chatID, stopMsg) + }, nil +} + +// broadcastToSession sends a message to all connections with a matching session. +func (c *PicoChannel) broadcastToSession(chatID string, msg PicoMessage) error { + // chatID format: "pico:" + sessionID := strings.TrimPrefix(chatID, "pico:") + msg.SessionID = sessionID + + var sent bool + c.connections.Range(func(key, value any) bool { + pc, ok := value.(*picoConn) + if !ok { + return true + } + if pc.sessionID == sessionID { + if err := pc.writeJSON(msg); err != nil { + logger.DebugCF("pico", "Write to connection failed", map[string]any{ + "conn_id": pc.id, + "error": err.Error(), + }) + } else { + sent = true + } + } + return true + }) + + if !sent { + return fmt.Errorf("no active connections for session %s: %w", sessionID, channels.ErrSendFailed) + } + return nil +} + +// handleWebSocket upgrades the HTTP connection and manages the WebSocket lifecycle. +func (c *PicoChannel) handleWebSocket(w http.ResponseWriter, r *http.Request) { + if !c.IsRunning() { + http.Error(w, "channel not running", http.StatusServiceUnavailable) + return + } + + // Authenticate + if !c.authenticate(r) { + http.Error(w, "unauthorized", http.StatusUnauthorized) + return + } + + // Check connection limit + maxConns := c.config.MaxConnections + if maxConns <= 0 { + maxConns = 100 + } + if int(c.connCount.Load()) >= maxConns { + http.Error(w, "too many connections", http.StatusServiceUnavailable) + return + } + + conn, err := c.upgrader.Upgrade(w, r, nil) + if err != nil { + logger.ErrorCF("pico", "WebSocket upgrade failed", map[string]any{ + "error": err.Error(), + }) + return + } + + // Determine session ID from query param or generate one + sessionID := r.URL.Query().Get("session_id") + if sessionID == "" { + sessionID = uuid.New().String() + } + + pc := &picoConn{ + id: uuid.New().String(), + conn: conn, + sessionID: sessionID, + } + + c.connections.Store(pc.id, pc) + c.connCount.Add(1) + + logger.InfoCF("pico", "WebSocket client connected", map[string]any{ + "conn_id": pc.id, + "session_id": sessionID, + }) + + go c.readLoop(pc) +} + +// authenticate checks the Bearer token from header or query parameter. +func (c *PicoChannel) authenticate(r *http.Request) bool { + token := c.config.Token + if token == "" { + return false + } + + // Check Authorization header + auth := r.Header.Get("Authorization") + if strings.HasPrefix(auth, "Bearer ") { + if strings.TrimPrefix(auth, "Bearer ") == token { + return true + } + } + + // Check query parameter + if r.URL.Query().Get("token") == token { + return true + } + + return false +} + +// readLoop reads messages from a WebSocket connection. +func (c *PicoChannel) readLoop(pc *picoConn) { + defer func() { + pc.close() + c.connections.Delete(pc.id) + c.connCount.Add(-1) + logger.InfoCF("pico", "WebSocket client disconnected", map[string]any{ + "conn_id": pc.id, + "session_id": pc.sessionID, + }) + }() + + readTimeout := time.Duration(c.config.ReadTimeout) * time.Second + if readTimeout <= 0 { + readTimeout = 60 * time.Second + } + + _ = pc.conn.SetReadDeadline(time.Now().Add(readTimeout)) + pc.conn.SetPongHandler(func(appData string) error { + _ = pc.conn.SetReadDeadline(time.Now().Add(readTimeout)) + return nil + }) + + // Start ping ticker + pingInterval := time.Duration(c.config.PingInterval) * time.Second + if pingInterval <= 0 { + pingInterval = 30 * time.Second + } + go c.pingLoop(pc, pingInterval) + + for { + select { + case <-c.ctx.Done(): + return + default: + } + + _, rawMsg, err := pc.conn.ReadMessage() + if err != nil { + if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseNormalClosure) { + logger.DebugCF("pico", "WebSocket read error", map[string]any{ + "conn_id": pc.id, + "error": err.Error(), + }) + } + return + } + + _ = pc.conn.SetReadDeadline(time.Now().Add(readTimeout)) + + var msg PicoMessage + if err := json.Unmarshal(rawMsg, &msg); err != nil { + errMsg := newError("invalid_message", "failed to parse message") + pc.writeJSON(errMsg) + continue + } + + c.handleMessage(pc, msg) + } +} + +// pingLoop sends periodic ping frames to keep the connection alive. +func (c *PicoChannel) pingLoop(pc *picoConn, interval time.Duration) { + ticker := time.NewTicker(interval) + defer ticker.Stop() + + for { + select { + case <-c.ctx.Done(): + return + case <-ticker.C: + if pc.closed.Load() { + return + } + pc.writeMu.Lock() + err := pc.conn.WriteMessage(websocket.PingMessage, nil) + pc.writeMu.Unlock() + if err != nil { + return + } + } + } +} + +// handleMessage processes an inbound Pico Protocol message. +func (c *PicoChannel) handleMessage(pc *picoConn, msg PicoMessage) { + switch msg.Type { + case TypePing: + pong := newMessage(TypePong, nil) + pong.ID = msg.ID + pc.writeJSON(pong) + + case TypeMessageSend: + c.handleMessageSend(pc, msg) + + default: + errMsg := newError("unknown_type", fmt.Sprintf("unknown message type: %s", msg.Type)) + pc.writeJSON(errMsg) + } +} + +// handleMessageSend processes an inbound message.send from a client. +func (c *PicoChannel) handleMessageSend(pc *picoConn, msg PicoMessage) { + content, _ := msg.Payload["content"].(string) + if strings.TrimSpace(content) == "" { + errMsg := newError("empty_content", "message content is empty") + pc.writeJSON(errMsg) + return + } + + sessionID := msg.SessionID + if sessionID == "" { + sessionID = pc.sessionID + } + + chatID := "pico:" + sessionID + senderID := "pico-user" + + peer := bus.Peer{Kind: "direct", ID: "pico:" + sessionID} + + metadata := map[string]string{ + "platform": "pico", + "session_id": sessionID, + "conn_id": pc.id, + } + + logger.DebugCF("pico", "Received message", map[string]any{ + "session_id": sessionID, + "preview": truncate(content, 50), + }) + + // Register typing with Manager + if rec := c.GetPlaceholderRecorder(); rec != nil { + stop, err := c.StartTyping(c.ctx, chatID) + if err == nil { + rec.RecordTypingStop("pico", chatID, stop) + } + } + + c.HandleMessage(peer, msg.ID, senderID, chatID, content, nil, metadata) +} + +// truncate truncates a string to maxLen runes. +func truncate(s string, maxLen int) string { + runes := []rune(s) + if len(runes) <= maxLen { + return s + } + return string(runes[:maxLen]) + "..." +} diff --git a/pkg/channels/pico/protocol.go b/pkg/channels/pico/protocol.go new file mode 100644 index 000000000..ca18df1dd --- /dev/null +++ b/pkg/channels/pico/protocol.go @@ -0,0 +1,46 @@ +package pico + +import "time" + +// Protocol message types. +const ( + // Client → Server + TypeMessageSend = "message.send" + TypeMediaSend = "media.send" + TypePing = "ping" + + // Server → Client + TypeMessageCreate = "message.create" + TypeMessageUpdate = "message.update" + TypeMediaCreate = "media.create" + TypeTypingStart = "typing.start" + TypeTypingStop = "typing.stop" + TypeError = "error" + TypePong = "pong" +) + +// PicoMessage is the wire format for all Pico Protocol messages. +type PicoMessage struct { + Type string `json:"type"` + ID string `json:"id,omitempty"` + SessionID string `json:"session_id,omitempty"` + Timestamp int64 `json:"timestamp,omitempty"` + Payload map[string]any `json:"payload,omitempty"` +} + +// newMessage creates a PicoMessage with the given type and payload. +func newMessage(msgType string, payload map[string]any) PicoMessage { + return PicoMessage{ + Type: msgType, + Timestamp: time.Now().UnixMilli(), + Payload: payload, + } +} + +// newError creates an error PicoMessage. +func newError(code, message string) PicoMessage { + return newMessage(TypeError, map[string]any{ + "code": code, + "message": message, + }) +} diff --git a/pkg/channels/slack/slack.go b/pkg/channels/slack/slack.go index 6fba2e0b4..e64525310 100644 --- a/pkg/channels/slack/slack.go +++ b/pkg/channels/slack/slack.go @@ -274,6 +274,18 @@ func (c *SlackChannel) handleMessageEvent(ev *slackevents.MessageEvent) { Timestamp: messageTS, }) + // Register typing stop (remove "eyes" reaction) with Manager + if rec := c.GetPlaceholderRecorder(); rec != nil { + capturedChannelID := channelID + capturedMessageTS := messageTS + rec.RecordTypingStop("slack", chatID, func() { + c.api.RemoveReaction("eyes", slack.ItemRef{ + Channel: capturedChannelID, + Timestamp: capturedMessageTS, + }) + }) + } + c.pendingAcks.Store(chatID, slackMessageRef{ ChannelID: channelID, Timestamp: messageTS, @@ -380,6 +392,18 @@ func (c *SlackChannel) handleAppMention(ev *slackevents.AppMentionEvent) { Timestamp: messageTS, }) + // Register typing stop (remove "eyes" reaction) with Manager + if rec := c.GetPlaceholderRecorder(); rec != nil { + capturedChannelID := channelID + capturedMessageTS := messageTS + rec.RecordTypingStop("slack", chatID, func() { + c.api.RemoveReaction("eyes", slack.ItemRef{ + Channel: capturedChannelID, + Timestamp: capturedMessageTS, + }) + }) + } + c.pendingAcks.Store(chatID, slackMessageRef{ ChannelID: channelID, Timestamp: messageTS, diff --git a/pkg/channels/telegram/telegram.go b/pkg/channels/telegram/telegram.go index c5c055163..98477f3a8 100644 --- a/pkg/channels/telegram/telegram.go +++ b/pkg/channels/telegram/telegram.go @@ -7,8 +7,8 @@ import ( "net/url" "os" "regexp" + "strconv" "strings" - "sync" "time" "github.com/mymmrac/telego" @@ -26,25 +26,13 @@ import ( type TelegramChannel struct { *channels.BaseChannel - bot *telego.Bot - bh *telegohandler.BotHandler - commands TelegramCommander - config *config.Config - chatIDs map[string]int64 - ctx context.Context - cancel context.CancelFunc - placeholders sync.Map // chatID -> messageID - stopThinking sync.Map // chatID -> thinkingCancel -} - -type thinkingCancel struct { - fn context.CancelFunc -} - -func (c *thinkingCancel) Cancel() { - if c != nil && c.fn != nil { - c.fn() - } + bot *telego.Bot + bh *telegohandler.BotHandler + commands TelegramCommander + config *config.Config + chatIDs map[string]int64 + ctx context.Context + cancel context.CancelFunc } func NewTelegramChannel(cfg *config.Config, bus *bus.MessageBus) (*TelegramChannel, error) { @@ -85,13 +73,11 @@ func NewTelegramChannel(cfg *config.Config, bus *bus.MessageBus) (*TelegramChann ) return &TelegramChannel{ - BaseChannel: base, - commands: NewTelegramCommands(bot, cfg), - bot: bot, - config: cfg, - chatIDs: make(map[string]int64), - placeholders: sync.Map{}, - stopThinking: sync.Map{}, + BaseChannel: base, + commands: NewTelegramCommands(bot, cfg), + bot: bot, + config: cfg, + chatIDs: make(map[string]int64), }, nil } @@ -149,21 +135,6 @@ func (c *TelegramChannel) Stop(ctx context.Context) error { logger.InfoC("telegram", "Stopping Telegram bot...") c.SetRunning(false) - // Clean up all thinking cancel functions to avoid context leaks - c.stopThinking.Range(func(key, value any) bool { - if cf, ok := value.(*thinkingCancel); ok && cf != nil { - cf.Cancel() - } - c.stopThinking.Delete(key) - return true - }) - - // Clean up placeholder state - c.placeholders.Range(func(key, value any) bool { - c.placeholders.Delete(key) - return true - }) - // Stop the bot handler if c.bh != nil { c.bh.Stop() @@ -187,28 +158,9 @@ func (c *TelegramChannel) Send(ctx context.Context, msg bus.OutboundMessage) err return fmt.Errorf("invalid chat ID %s: %w", msg.ChatID, channels.ErrSendFailed) } - // Stop thinking animation - if stop, ok := c.stopThinking.Load(msg.ChatID); ok { - if cf, ok := stop.(*thinkingCancel); ok && cf != nil { - cf.Cancel() - } - c.stopThinking.Delete(msg.ChatID) - } - htmlContent := markdownToTelegramHTML(msg.Content) - // Try to edit placeholder - if pID, ok := c.placeholders.Load(msg.ChatID); ok { - c.placeholders.Delete(msg.ChatID) - editMsg := tu.EditMessageText(tu.ID(chatID), pID.(int), htmlContent) - editMsg.ParseMode = telego.ModeHTML - - if _, err = c.bot.EditMessageText(ctx, editMsg); err == nil { - return nil - } - // Fallback to new message if edit fails - } - + // Typing/placeholder handled by Manager.preSend — just send the message tgMsg := tu.Message(tu.ID(chatID), htmlContent) tgMsg.ParseMode = telego.ModeHTML @@ -225,6 +177,23 @@ func (c *TelegramChannel) Send(ctx context.Context, msg bus.OutboundMessage) err return nil } +// EditMessage implements channels.MessageEditor. +func (c *TelegramChannel) EditMessage(ctx context.Context, chatID string, messageID string, content string) error { + cid, err := parseChatID(chatID) + if err != nil { + return err + } + mid, err := strconv.Atoi(messageID) + if err != nil { + return err + } + htmlContent := markdownToTelegramHTML(content) + editMsg := tu.EditMessageText(tu.ID(cid), mid, htmlContent) + editMsg.ParseMode = telego.ModeHTML + _, err = c.bot.EditMessageText(ctx, editMsg) + return err +} + // SendMedia implements the channels.MediaSender interface. func (c *TelegramChannel) SendMedia(ctx context.Context, msg bus.OutboundMediaMessage) error { if !c.IsRunning() { @@ -445,21 +414,21 @@ func (c *TelegramChannel) handleMessage(ctx context.Context, message *telego.Mes }) } - // Stop any previous thinking animation - if prevStop, ok := c.stopThinking.Load(chatIDStr); ok { - if cf, ok := prevStop.(*thinkingCancel); ok && cf != nil { - cf.Cancel() - } - } - - // Create cancel function for thinking state + // Create cancel function for thinking state and register with Manager _, thinkCancel := context.WithTimeout(ctx, 5*time.Minute) - c.stopThinking.Store(chatIDStr, &thinkingCancel{fn: thinkCancel}) + if rec := c.GetPlaceholderRecorder(); rec != nil { + rec.RecordTypingStop("telegram", chatIDStr, thinkCancel) + } else { + // No recorder — cancel immediately to avoid context leak + thinkCancel() + } pMsg, err := c.bot.SendMessage(ctx, tu.Message(tu.ID(chatID), "Thinking... 💭")) if err == nil { pID := pMsg.MessageID - c.placeholders.Store(chatIDStr, pID) + if rec := c.GetPlaceholderRecorder(); rec != nil { + rec.RecordPlaceholder("telegram", chatIDStr, fmt.Sprintf("%d", pID)) + } } peerKind := "direct" diff --git a/pkg/config/config.go b/pkg/config/config.go index 0c89d05eb..d32e8db90 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -202,6 +202,7 @@ type ChannelsConfig struct { OneBot OneBotConfig `json:"onebot"` WeCom WeComConfig `json:"wecom"` WeComApp WeComAppConfig `json:"wecom_app"` + Pico PicoConfig `json:"pico"` } // GroupTriggerConfig controls when the bot responds in group chats. @@ -343,6 +344,17 @@ type WeComAppConfig struct { GroupTrigger GroupTriggerConfig `json:"group_trigger,omitempty"` } +type PicoConfig struct { + Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_PICO_ENABLED"` + Token string `json:"token" env:"PICOCLAW_CHANNELS_PICO_TOKEN"` + AllowOrigins []string `json:"allow_origins,omitempty"` + PingInterval int `json:"ping_interval,omitempty"` // seconds, default 30 + ReadTimeout int `json:"read_timeout,omitempty"` // seconds, default 60 + WriteTimeout int `json:"write_timeout,omitempty"` // seconds, default 10 + MaxConnections int `json:"max_connections,omitempty"` // default 100 + AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_PICO_ALLOW_FROM"` +} + type HeartbeatConfig struct { Enabled bool `json:"enabled" env:"PICOCLAW_HEARTBEAT_ENABLED"` Interval int `json:"interval" env:"PICOCLAW_HEARTBEAT_INTERVAL"` // minutes, min 5 diff --git a/pkg/config/defaults.go b/pkg/config/defaults.go index 5c53a3963..8445510e2 100644 --- a/pkg/config/defaults.go +++ b/pkg/config/defaults.go @@ -33,6 +33,11 @@ func DefaultConfig() *Config { Enabled: false, Token: "", AllowFrom: FlexibleStringSlice{}, + Typing: TypingConfig{Enabled: true}, + Placeholder: PlaceholderConfig{ + Enabled: true, + Text: "Thinking... 💭", + }, }, Feishu: FeishuConfig{ Enabled: false, @@ -114,6 +119,15 @@ func DefaultConfig() *Config { AllowFrom: FlexibleStringSlice{}, ReplyTimeout: 5, }, + Pico: PicoConfig{ + Enabled: false, + Token: "", + PingInterval: 30, + ReadTimeout: 60, + WriteTimeout: 10, + MaxConnections: 100, + AllowFrom: FlexibleStringSlice{}, + }, }, Providers: ProvidersConfig{ OpenAI: OpenAIProviderConfig{WebSearch: true},