From 544940807f4eee0dc8dd136a79a85aa2eef23e87 Mon Sep 17 00:00:00 2001 From: Amir Mamaghani <67312799+amirmamaghani@users.noreply.github.com> Date: Fri, 20 Mar 2026 13:43:40 +0100 Subject: [PATCH] feat(pico): add pico_client outbound WebSocket channel (#1198) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * feat(pico): add pico_client outbound WebSocket channel Add a client-mode counterpart to the existing pico server channel. pico_client connects to a remote Pico Protocol WebSocket server, enabling picoclaw to bridge messages with external Pico-compatible services. Includes config, factory registration, manager wiring, 8 unit tests, and a minimal echo-server example for interactive testing. * fix(pico): address PR #1198 review — goroutine leak, race, auth - Add per-connection context cancel to picoConn to prevent pingLoop goroutine leak on disconnect - Re-acquire mutex in StartTyping stop closure to avoid stale conn race - Remove query-param token auth from echo server (header-only) - Move ListenAndServe to main goroutine where log.Fatal is safe Co-Authored-By: Claude Opus 4.6 * fix: replace ConsumeInbound with InboundChan select in client test MessageBus does not expose a ConsumeInbound method. Use a select on InboundChan() with context cancellation, matching the pattern used in the bus package tests. Co-Authored-By: Claude Opus 4.6 (1M context) --------- Co-authored-by: Claude Opus 4.6 --- config/config.example.json | 19 ++ examples/pico-echo-server/README.md | 47 ++++ examples/pico-echo-server/main.go | 160 ++++++++++++++ pkg/channels/manager.go | 4 + pkg/channels/pico/client.go | 319 ++++++++++++++++++++++++++++ pkg/channels/pico/client_test.go | 264 +++++++++++++++++++++++ pkg/channels/pico/init.go | 3 + pkg/channels/pico/pico.go | 4 + pkg/config/config.go | 11 + 9 files changed, 831 insertions(+) create mode 100644 examples/pico-echo-server/README.md create mode 100644 examples/pico-echo-server/main.go create mode 100644 pkg/channels/pico/client.go create mode 100644 pkg/channels/pico/client_test.go diff --git a/config/config.example.json b/config/config.example.json index 221e89491..69ac062ac 100644 --- a/config/config.example.json +++ b/config/config.example.json @@ -213,6 +213,25 @@ "welcome_message": "Hello! I'm your AI assistant. How can I help you today?", "reasoning_channel_id": "" }, + "pico": { + "enabled": false, + "token": "YOUR_PICO_TOKEN", + "allow_token_query": false, + "allow_origins": [], + "ping_interval": 30, + "read_timeout": 60, + "max_connections": 100, + "allow_from": [] + }, + "pico_client": { + "enabled": false, + "url": "wss://remote-pico-server/pico/ws", + "token": "YOUR_PICO_TOKEN", + "session_id": "", + "ping_interval": 30, + "read_timeout": 60, + "allow_from": [] + }, "irc": { "enabled": false, "server": "irc.libera.chat:6697", diff --git a/examples/pico-echo-server/README.md b/examples/pico-echo-server/README.md new file mode 100644 index 000000000..f6b5d8020 --- /dev/null +++ b/examples/pico-echo-server/README.md @@ -0,0 +1,47 @@ +# pico-echo-server + +Minimal Pico Protocol WebSocket server for testing the `pico_client` channel. + +## Usage + +```bash +go run ./examples/pico-echo-server -addr :9090 -token secret +``` + +### Flags + +| Flag | Default | Description | +|----------|---------|------------------------------------| +| `-addr` | `:9090` | Listen address | +| `-token` | (none) | Auth token; empty disables auth | + +## How it works + +- Listens for WebSocket connections at `/ws` +- Authenticates via `Authorization: Bearer ` header or `?token=` query param +- Prints received `message.send` content to stdout +- Responds to `ping` with `pong` +- Lines typed into stdin are broadcast as `message.create` to all connected clients + +## Testing with pico_client + +1. Start the server: + ```bash + go run ./examples/pico-echo-server -token mytoken + ``` + +2. Configure `pico_client` in your `config.json`: + ```json + { + "channels": { + "pico_client": { + "enabled": true, + "url": "ws://localhost:9090/ws", + "token": "mytoken", + "session_id": "test-session" + } + } + } + ``` + +3. Start picoclaw — the client connects and you can exchange messages interactively via stdin/stdout. diff --git a/examples/pico-echo-server/main.go b/examples/pico-echo-server/main.go new file mode 100644 index 000000000..46970fb34 --- /dev/null +++ b/examples/pico-echo-server/main.go @@ -0,0 +1,160 @@ +// pico-echo-server is a minimal Pico Protocol WebSocket server for testing +// the pico_client channel. It accepts connections, prints received messages +// to stdout, and forwards stdin lines as message.create to all connected clients. +// +// Usage: +// +// go run ./examples/pico-echo-server -addr :9090 -token secret +// +// Then configure pico_client with url=ws://localhost:9090/ws&token=secret. +package main + +import ( + "bufio" + "encoding/json" + "flag" + "fmt" + "log" + "net/http" + "os" + "strings" + "sync" + "time" + + "github.com/gorilla/websocket" +) + +type picoMessage struct { + Type string `json:"type"` + ID string `json:"id,omitempty"` + SessionID string `json:"session_id,omitempty"` + Timestamp int64 `json:"timestamp,omitempty"` + Payload map[string]any `json:"payload,omitempty"` +} + +var upgrader = websocket.Upgrader{CheckOrigin: func(*http.Request) bool { return true }} + +type server struct { + token string + mu sync.Mutex + conns map[*websocket.Conn]string // conn → sessionID +} + +func (s *server) handleWS(w http.ResponseWriter, r *http.Request) { + if s.token != "" { + auth := r.Header.Get("Authorization") + if auth != "Bearer "+s.token { + http.Error(w, "unauthorized", http.StatusUnauthorized) + return + } + } + + conn, err := upgrader.Upgrade(w, r, nil) + if err != nil { + log.Printf("upgrade: %v", err) + return + } + + sessionID := r.URL.Query().Get("session_id") + if sessionID == "" { + sessionID = fmt.Sprintf("sess-%d", time.Now().UnixMilli()) + } + + s.mu.Lock() + s.conns[conn] = sessionID + s.mu.Unlock() + + log.Printf("[+] client connected (session=%s)", sessionID) + + defer func() { + s.mu.Lock() + delete(s.conns, conn) + s.mu.Unlock() + conn.Close() + log.Printf("[-] client disconnected (session=%s)", sessionID) + }() + + for { + _, raw, err := conn.ReadMessage() + if err != nil { + if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseNormalClosure) { + log.Printf("read error: %v", err) + } + return + } + + var msg picoMessage + if err := json.Unmarshal(raw, &msg); err != nil { + log.Printf("bad json: %v", err) + continue + } + + switch msg.Type { + case "ping": + pong := picoMessage{Type: "pong", ID: msg.ID, Timestamp: time.Now().UnixMilli()} + conn.WriteJSON(pong) + + case "message.send": + content, _ := msg.Payload["content"].(string) + fmt.Printf("[%s] %s\n", sessionID, content) + + case "typing.start": + log.Printf("[%s] typing...", sessionID) + + case "typing.stop": + log.Printf("[%s] stopped typing", sessionID) + + default: + log.Printf("[%s] unknown type: %s", sessionID, msg.Type) + } + } +} + +func (s *server) broadcast(content string) { + msg := picoMessage{ + Type: "message.create", + Timestamp: time.Now().UnixMilli(), + Payload: map[string]any{"content": content}, + } + + s.mu.Lock() + defer s.mu.Unlock() + + for conn, sid := range s.conns { + msg.SessionID = sid + if err := conn.WriteJSON(msg); err != nil { + log.Printf("write to %s failed: %v", sid, err) + } + } +} + +func main() { + addr := flag.String("addr", ":9090", "listen address") + token := flag.String("token", "", "auth token (empty = no auth)") + flag.Parse() + + s := &server{ + token: *token, + conns: make(map[*websocket.Conn]string), + } + + http.HandleFunc("/ws", s.handleWS) + + log.Printf("listening on %s", *addr) + log.Printf("connect with: ws://localhost%s/ws", *addr) + fmt.Println("Type messages to send to connected clients (Ctrl+C to quit):") + + go func() { + scanner := bufio.NewScanner(os.Stdin) + for scanner.Scan() { + line := strings.TrimSpace(scanner.Text()) + if line == "" { + continue + } + s.broadcast(line) + log.Printf("[server] sent: %s", line) + } + }() + + log.Fatal(http.ListenAndServe(*addr, nil)) +} diff --git a/pkg/channels/manager.go b/pkg/channels/manager.go index c980daf66..741fad53e 100644 --- a/pkg/channels/manager.go +++ b/pkg/channels/manager.go @@ -323,6 +323,10 @@ func (m *Manager) initChannels(channels *config.ChannelsConfig) error { m.initChannel("pico", "Pico") } + if channels.PicoClient.Enabled && channels.PicoClient.URL != "" { + m.initChannel("pico_client", "Pico Client") + } + if channels.IRC.Enabled && channels.IRC.Server != "" { m.initChannel("irc", "IRC") } diff --git a/pkg/channels/pico/client.go b/pkg/channels/pico/client.go new file mode 100644 index 000000000..2c335050d --- /dev/null +++ b/pkg/channels/pico/client.go @@ -0,0 +1,319 @@ +package pico + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "strings" + "sync" + "time" + + "github.com/google/uuid" + "github.com/gorilla/websocket" + + "github.com/sipeed/picoclaw/pkg/bus" + "github.com/sipeed/picoclaw/pkg/channels" + "github.com/sipeed/picoclaw/pkg/config" + "github.com/sipeed/picoclaw/pkg/identity" + "github.com/sipeed/picoclaw/pkg/logger" +) + +// PicoClientChannel connects to a remote Pico Protocol WebSocket server. +type PicoClientChannel struct { + *channels.BaseChannel + config config.PicoClientConfig + conn *picoConn + mu sync.Mutex + ctx context.Context + cancel context.CancelFunc +} + +// NewPicoClientChannel creates a new Pico Protocol client channel. +func NewPicoClientChannel( + cfg config.PicoClientConfig, + messageBus *bus.MessageBus, +) (*PicoClientChannel, error) { + if cfg.URL == "" { + return nil, fmt.Errorf("pico_client url is required") + } + + base := channels.NewBaseChannel("pico_client", cfg, messageBus, cfg.AllowFrom) + + return &PicoClientChannel{ + BaseChannel: base, + config: cfg, + }, nil +} + +// Start dials the remote server and begins reading. +func (c *PicoClientChannel) Start(ctx context.Context) error { + logger.InfoC("pico_client", "Starting Pico Client channel") + c.ctx, c.cancel = context.WithCancel(ctx) + + if err := c.dial(); err != nil { + c.cancel() + return fmt.Errorf("pico_client initial connect: %w", err) + } + + c.SetRunning(true) + go c.reconnectLoop() + + logger.InfoCF("pico_client", "Connected", map[string]any{"url": c.config.URL}) + return nil +} + +// Stop closes the connection. +func (c *PicoClientChannel) Stop(ctx context.Context) error { + logger.InfoC("pico_client", "Stopping Pico Client channel") + c.SetRunning(false) + if c.cancel != nil { + c.cancel() + } + c.mu.Lock() + if c.conn != nil { + c.conn.close() + } + c.mu.Unlock() + logger.InfoC("pico_client", "Pico Client channel stopped") + return nil +} + +func (c *PicoClientChannel) dial() error { + header := http.Header{} + if c.config.Token != "" { + header.Set("Authorization", "Bearer "+c.config.Token) + } + + ws, resp, err := websocket.DefaultDialer.DialContext(c.ctx, c.config.URL, header) + if resp != nil && resp.Body != nil { + resp.Body.Close() + } + if err != nil { + return err + } + + connCtx, connCancel := context.WithCancel(c.ctx) + + pc := &picoConn{ + id: uuid.New().String(), + conn: ws, + sessionID: c.config.SessionID, + cancel: connCancel, + } + if pc.sessionID == "" { + pc.sessionID = uuid.New().String() + } + + c.mu.Lock() + c.conn = pc + c.mu.Unlock() + + go c.readLoop(connCtx, pc) + return nil +} + +// reconnectLoop re-dials when the connection drops. +func (c *PicoClientChannel) reconnectLoop() { + for { + select { + case <-c.ctx.Done(): + return + default: + } + + c.mu.Lock() + pc := c.conn + c.mu.Unlock() + + if pc == nil || pc.closed.Load() { + backoff := 5 * time.Second + logger.InfoC("pico_client", "Reconnecting...") + if err := c.dial(); err != nil { + logger.WarnCF("pico_client", "Reconnect failed", map[string]any{ + "error": err.Error(), + }) + select { + case <-c.ctx.Done(): + return + case <-time.After(backoff): + } + continue + } + logger.InfoC("pico_client", "Reconnected") + } + + select { + case <-c.ctx.Done(): + return + case <-time.After(1 * time.Second): + } + } +} + +func (c *PicoClientChannel) readLoop(connCtx context.Context, pc *picoConn) { + defer pc.close() + + readTimeout := time.Duration(c.config.ReadTimeout) * time.Second + if readTimeout <= 0 { + readTimeout = 60 * time.Second + } + + _ = pc.conn.SetReadDeadline(time.Now().Add(readTimeout)) + pc.conn.SetPongHandler(func(string) error { + return pc.conn.SetReadDeadline(time.Now().Add(readTimeout)) + }) + + pingInterval := time.Duration(c.config.PingInterval) * time.Second + if pingInterval <= 0 { + pingInterval = 30 * time.Second + } + go c.pingLoop(connCtx, pc, pingInterval) + + for { + select { + case <-connCtx.Done(): + return + default: + } + + _, raw, err := pc.conn.ReadMessage() + if err != nil { + if websocket.IsUnexpectedCloseError( + err, + websocket.CloseGoingAway, + websocket.CloseNormalClosure, + ) { + logger.DebugCF("pico_client", "Read error", map[string]any{ + "error": err.Error(), + }) + } + return + } + + _ = pc.conn.SetReadDeadline(time.Now().Add(readTimeout)) + + var msg PicoMessage + if err := json.Unmarshal(raw, &msg); err != nil { + continue + } + + c.handleInbound(pc, msg) + } +} + +func (c *PicoClientChannel) pingLoop(connCtx context.Context, pc *picoConn, interval time.Duration) { + ticker := time.NewTicker(interval) + defer ticker.Stop() + for { + select { + case <-connCtx.Done(): + return + case <-ticker.C: + if pc.closed.Load() { + return + } + pc.writeMu.Lock() + err := pc.conn.WriteMessage(websocket.PingMessage, nil) + pc.writeMu.Unlock() + if err != nil { + return + } + } + } +} + +// handleInbound processes messages from the remote server. +// In client mode the server sends message.create (responses) and the client +// sends message.send (user input). We treat message.create from the server +// as inbound user messages to feed into the agent loop. +func (c *PicoClientChannel) handleInbound(pc *picoConn, msg PicoMessage) { + switch msg.Type { + case TypePong: + // response to our ping, ignore + case TypeMessageCreate: + // Server sent us a message — treat as inbound + c.handleServerMessage(pc, msg) + default: + logger.DebugCF("pico_client", "Ignoring message type", map[string]any{ + "type": msg.Type, + }) + } +} + +func (c *PicoClientChannel) handleServerMessage(pc *picoConn, msg PicoMessage) { + content, _ := msg.Payload["content"].(string) + if strings.TrimSpace(content) == "" { + return + } + + sessionID := msg.SessionID + if sessionID == "" { + sessionID = pc.sessionID + } + + chatID := "pico_client:" + sessionID + senderID := "pico-remote" + peer := bus.Peer{Kind: "direct", ID: chatID} + + sender := bus.SenderInfo{ + Platform: "pico_client", + PlatformID: senderID, + CanonicalID: identity.BuildCanonicalID("pico_client", senderID), + } + + if !c.IsAllowedSender(sender) { + return + } + + c.HandleMessage(c.ctx, peer, msg.ID, senderID, chatID, content, nil, map[string]string{ + "platform": "pico_client", + "session_id": sessionID, + }, sender) +} + +// Send sends a message to the remote server. +func (c *PicoClientChannel) Send(ctx context.Context, msg bus.OutboundMessage) error { + if !c.IsRunning() { + return channels.ErrNotRunning + } + c.mu.Lock() + pc := c.conn + c.mu.Unlock() + if pc == nil || pc.closed.Load() { + return channels.ErrSendFailed + } + + outMsg := newMessage(TypeMessageSend, map[string]any{ + "content": msg.Content, + }) + outMsg.SessionID = strings.TrimPrefix(msg.ChatID, "pico_client:") + return pc.writeJSON(outMsg) +} + +// StartTyping implements channels.TypingCapable. +func (c *PicoClientChannel) StartTyping(ctx context.Context, chatID string) (func(), error) { + c.mu.Lock() + pc := c.conn + c.mu.Unlock() + if pc == nil || pc.closed.Load() { + return func() {}, nil + } + + startMsg := newMessage(TypeTypingStart, nil) + startMsg.SessionID = strings.TrimPrefix(chatID, "pico_client:") + if err := pc.writeJSON(startMsg); err != nil { + return func() {}, err + } + return func() { + c.mu.Lock() + currentPC := c.conn + c.mu.Unlock() + if currentPC == nil { + return + } + stopMsg := newMessage(TypeTypingStop, nil) + stopMsg.SessionID = strings.TrimPrefix(chatID, "pico_client:") + currentPC.writeJSON(stopMsg) + }, nil +} diff --git a/pkg/channels/pico/client_test.go b/pkg/channels/pico/client_test.go new file mode 100644 index 000000000..118c9abea --- /dev/null +++ b/pkg/channels/pico/client_test.go @@ -0,0 +1,264 @@ +package pico + +import ( + "context" + "encoding/json" + "errors" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "github.com/gorilla/websocket" + + "github.com/sipeed/picoclaw/pkg/bus" + "github.com/sipeed/picoclaw/pkg/channels" + "github.com/sipeed/picoclaw/pkg/config" +) + +func TestNewPicoClientChannel_MissingURL(t *testing.T) { + _, err := NewPicoClientChannel(config.PicoClientConfig{}, bus.NewMessageBus()) + if err == nil { + t.Fatal("expected error for missing URL") + } + if !strings.Contains(err.Error(), "url is required") { + t.Fatalf("unexpected error: %v", err) + } +} + +func TestNewPicoClientChannel_OK(t *testing.T) { + ch, err := NewPicoClientChannel(config.PicoClientConfig{ + URL: "ws://localhost:9999/ws", + }, bus.NewMessageBus()) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if ch.Name() != "pico_client" { + t.Fatalf("name = %q, want pico_client", ch.Name()) + } +} + +func TestSend_NotRunning(t *testing.T) { + ch, err := NewPicoClientChannel(config.PicoClientConfig{ + URL: "ws://localhost:9999/ws", + }, bus.NewMessageBus()) + if err != nil { + t.Fatal(err) + } + err = ch.Send(context.Background(), bus.OutboundMessage{Content: "hi"}) + if !errors.Is(err, channels.ErrNotRunning) { + t.Fatalf("expected ErrNotRunning, got %v", err) + } +} + +// testServer starts a WS server that echoes message.send back as message.create. +func testServer(t *testing.T, token string) *httptest.Server { + t.Helper() + upgrader := websocket.Upgrader{CheckOrigin: func(*http.Request) bool { return true }} + + return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if token != "" { + auth := r.Header.Get("Authorization") + if auth != "Bearer "+token { + http.Error(w, "unauthorized", http.StatusUnauthorized) + return + } + } + + conn, err := upgrader.Upgrade(w, r, nil) + if err != nil { + t.Logf("upgrade error: %v", err) + return + } + defer conn.Close() + + for { + _, raw, err := conn.ReadMessage() + if err != nil { + return + } + + var msg PicoMessage + if err := json.Unmarshal(raw, &msg); err != nil { + continue + } + + if msg.Type == TypeMessageSend { + reply := newMessage(TypeMessageCreate, msg.Payload) + reply.SessionID = msg.SessionID + if err := conn.WriteJSON(reply); err != nil { + return + } + } + } + })) +} + +func wsURL(httpURL string) string { + return "ws" + strings.TrimPrefix(httpURL, "http") +} + +func TestClientChannel_ConnectAndSend(t *testing.T) { + srv := testServer(t, "test-token") + defer srv.Close() + + mb := bus.NewMessageBus() + ch, err := NewPicoClientChannel(config.PicoClientConfig{ + URL: wsURL(srv.URL), + Token: "test-token", + SessionID: "sess-1", + PingInterval: 60, + ReadTimeout: 10, + }, mb) + if err != nil { + t.Fatal(err) + } + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + if err = ch.Start(ctx); err != nil { + t.Fatalf("Start: %v", err) + } + defer ch.Stop(ctx) + + // Send a message + err = ch.Send(ctx, bus.OutboundMessage{ + ChatID: "pico_client:sess-1", + Content: "hello", + }) + if err != nil { + t.Fatalf("Send: %v", err) + } +} + +func TestClientChannel_AuthFailure(t *testing.T) { + srv := testServer(t, "correct-token") + defer srv.Close() + + ch, err := NewPicoClientChannel(config.PicoClientConfig{ + URL: wsURL(srv.URL), + Token: "wrong-token", + }, bus.NewMessageBus()) + if err != nil { + t.Fatal(err) + } + + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + err = ch.Start(ctx) + if err == nil { + ch.Stop(ctx) + t.Fatal("expected auth failure") + } +} + +func TestClientChannel_ReceivesServerMessage(t *testing.T) { + srv := testServer(t, "") + defer srv.Close() + + mb := bus.NewMessageBus() + + ch, err := NewPicoClientChannel(config.PicoClientConfig{ + URL: wsURL(srv.URL), + SessionID: "sess-echo", + ReadTimeout: 10, + }, mb) + if err != nil { + t.Fatal(err) + } + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + if err = ch.Start(ctx); err != nil { + t.Fatalf("Start: %v", err) + } + defer ch.Stop(ctx) + + // Send a message; the echo server replies with message.create + err = ch.Send(ctx, bus.OutboundMessage{ + ChatID: "pico_client:sess-echo", + Content: "ping", + }) + if err != nil { + t.Fatalf("Send: %v", err) + } + + // The echoed message.create is processed by handleServerMessage which + // calls HandleMessage → PublishInbound. Consume it from the bus. + select { + case msg := <-mb.InboundChan(): + if msg.Content != "ping" { + t.Fatalf("received = %q, want %q", msg.Content, "ping") + } + case <-ctx.Done(): + t.Fatal("timed out waiting for echoed message") + } +} + +func TestClientChannel_StartTyping(t *testing.T) { + srv := testServer(t, "") + defer srv.Close() + + ch, err := NewPicoClientChannel(config.PicoClientConfig{ + URL: wsURL(srv.URL), + SessionID: "sess-type", + ReadTimeout: 10, + }, bus.NewMessageBus()) + if err != nil { + t.Fatal(err) + } + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + if err = ch.Start(ctx); err != nil { + t.Fatalf("Start: %v", err) + } + defer ch.Stop(ctx) + + stop, err := ch.StartTyping(ctx, "pico_client:sess-type") + if err != nil { + t.Fatalf("StartTyping: %v", err) + } + stop() // should not panic +} + +func TestSend_ClosedConnection(t *testing.T) { + srv := testServer(t, "") + defer srv.Close() + + ch, err := NewPicoClientChannel(config.PicoClientConfig{ + URL: wsURL(srv.URL), + SessionID: "sess-close", + ReadTimeout: 10, + }, bus.NewMessageBus()) + if err != nil { + t.Fatal(err) + } + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + if err = ch.Start(ctx); err != nil { + t.Fatalf("Start: %v", err) + } + + // Force close the underlying connection + ch.mu.Lock() + ch.conn.close() + ch.mu.Unlock() + + err = ch.Send(ctx, bus.OutboundMessage{ + ChatID: "pico_client:sess-close", + Content: "should fail", + }) + if !errors.Is(err, channels.ErrSendFailed) { + t.Fatalf("expected ErrSendFailed, got %v", err) + } + + ch.Stop(ctx) +} diff --git a/pkg/channels/pico/init.go b/pkg/channels/pico/init.go index 96d764418..0319279d8 100644 --- a/pkg/channels/pico/init.go +++ b/pkg/channels/pico/init.go @@ -10,4 +10,7 @@ func init() { channels.RegisterFactory("pico", func(cfg *config.Config, b *bus.MessageBus) (channels.Channel, error) { return NewPicoChannel(cfg.Channels.Pico, b) }) + channels.RegisterFactory("pico_client", func(cfg *config.Config, b *bus.MessageBus) (channels.Channel, error) { + return NewPicoClientChannel(cfg.Channels.PicoClient, b) + }) } diff --git a/pkg/channels/pico/pico.go b/pkg/channels/pico/pico.go index 206e71f92..77e7bbdb6 100644 --- a/pkg/channels/pico/pico.go +++ b/pkg/channels/pico/pico.go @@ -27,6 +27,7 @@ type picoConn struct { sessionID string writeMu sync.Mutex closed atomic.Bool + cancel context.CancelFunc // cancels per-connection goroutines (e.g. pingLoop) } // writeJSON sends a JSON message to the connection with write locking. @@ -42,6 +43,9 @@ func (pc *picoConn) writeJSON(v any) error { // close closes the connection. func (pc *picoConn) close() { if pc.closed.CompareAndSwap(false, true) { + if pc.cancel != nil { + pc.cancel() + } pc.conn.Close() } } diff --git a/pkg/config/config.go b/pkg/config/config.go index 33a5db8ae..f524e952a 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -297,6 +297,7 @@ type ChannelsConfig struct { WeComApp WeComAppConfig `json:"wecom_app"` WeComAIBot WeComAIBotConfig `json:"wecom_aibot"` Pico PicoConfig `json:"pico"` + PicoClient PicoClientConfig `json:"pico_client"` IRC IRCConfig `json:"irc"` } @@ -504,6 +505,16 @@ type PicoConfig struct { Placeholder PlaceholderConfig `json:"placeholder,omitempty"` } +type PicoClientConfig struct { + Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_PICO_CLIENT_ENABLED"` + URL string `json:"url" env:"PICOCLAW_CHANNELS_PICO_CLIENT_URL"` + Token string `json:"token" env:"PICOCLAW_CHANNELS_PICO_CLIENT_TOKEN"` + SessionID string `json:"session_id,omitempty"` + PingInterval int `json:"ping_interval,omitempty"` + ReadTimeout int `json:"read_timeout,omitempty"` + AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_PICO_CLIENT_ALLOW_FROM"` +} + type IRCConfig struct { Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_IRC_ENABLED"` Server string `json:"server" env:"PICOCLAW_CHANNELS_IRC_SERVER"`