Files
picoclaw/pkg/channels/pico/client_test.go
T
Amir Mamaghani 544940807f feat(pico): add pico_client outbound WebSocket channel (#1198)
* feat(pico): add pico_client outbound WebSocket channel

Add a client-mode counterpart to the existing pico server channel.
pico_client connects to a remote Pico Protocol WebSocket server,
enabling picoclaw to bridge messages with external Pico-compatible
services.

Includes config, factory registration, manager wiring, 8 unit tests,
and a minimal echo-server example for interactive testing.

* fix(pico): address PR #1198 review — goroutine leak, race, auth

- Add per-connection context cancel to picoConn to prevent pingLoop
  goroutine leak on disconnect
- Re-acquire mutex in StartTyping stop closure to avoid stale conn race
- Remove query-param token auth from echo server (header-only)
- Move ListenAndServe to main goroutine where log.Fatal is safe

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

* fix: replace ConsumeInbound with InboundChan select in client test

MessageBus does not expose a ConsumeInbound method. Use a select on
InboundChan() with context cancellation, matching the pattern used in
the bus package tests.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

---------

Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-20 20:43:40 +08:00

265 lines
5.9 KiB
Go

package pico
import (
"context"
"encoding/json"
"errors"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
"github.com/gorilla/websocket"
"github.com/sipeed/picoclaw/pkg/bus"
"github.com/sipeed/picoclaw/pkg/channels"
"github.com/sipeed/picoclaw/pkg/config"
)
func TestNewPicoClientChannel_MissingURL(t *testing.T) {
_, err := NewPicoClientChannel(config.PicoClientConfig{}, bus.NewMessageBus())
if err == nil {
t.Fatal("expected error for missing URL")
}
if !strings.Contains(err.Error(), "url is required") {
t.Fatalf("unexpected error: %v", err)
}
}
func TestNewPicoClientChannel_OK(t *testing.T) {
ch, err := NewPicoClientChannel(config.PicoClientConfig{
URL: "ws://localhost:9999/ws",
}, bus.NewMessageBus())
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if ch.Name() != "pico_client" {
t.Fatalf("name = %q, want pico_client", ch.Name())
}
}
func TestSend_NotRunning(t *testing.T) {
ch, err := NewPicoClientChannel(config.PicoClientConfig{
URL: "ws://localhost:9999/ws",
}, bus.NewMessageBus())
if err != nil {
t.Fatal(err)
}
err = ch.Send(context.Background(), bus.OutboundMessage{Content: "hi"})
if !errors.Is(err, channels.ErrNotRunning) {
t.Fatalf("expected ErrNotRunning, got %v", err)
}
}
// testServer starts a WS server that echoes message.send back as message.create.
func testServer(t *testing.T, token string) *httptest.Server {
t.Helper()
upgrader := websocket.Upgrader{CheckOrigin: func(*http.Request) bool { return true }}
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if token != "" {
auth := r.Header.Get("Authorization")
if auth != "Bearer "+token {
http.Error(w, "unauthorized", http.StatusUnauthorized)
return
}
}
conn, err := upgrader.Upgrade(w, r, nil)
if err != nil {
t.Logf("upgrade error: %v", err)
return
}
defer conn.Close()
for {
_, raw, err := conn.ReadMessage()
if err != nil {
return
}
var msg PicoMessage
if err := json.Unmarshal(raw, &msg); err != nil {
continue
}
if msg.Type == TypeMessageSend {
reply := newMessage(TypeMessageCreate, msg.Payload)
reply.SessionID = msg.SessionID
if err := conn.WriteJSON(reply); err != nil {
return
}
}
}
}))
}
func wsURL(httpURL string) string {
return "ws" + strings.TrimPrefix(httpURL, "http")
}
func TestClientChannel_ConnectAndSend(t *testing.T) {
srv := testServer(t, "test-token")
defer srv.Close()
mb := bus.NewMessageBus()
ch, err := NewPicoClientChannel(config.PicoClientConfig{
URL: wsURL(srv.URL),
Token: "test-token",
SessionID: "sess-1",
PingInterval: 60,
ReadTimeout: 10,
}, mb)
if err != nil {
t.Fatal(err)
}
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
if err = ch.Start(ctx); err != nil {
t.Fatalf("Start: %v", err)
}
defer ch.Stop(ctx)
// Send a message
err = ch.Send(ctx, bus.OutboundMessage{
ChatID: "pico_client:sess-1",
Content: "hello",
})
if err != nil {
t.Fatalf("Send: %v", err)
}
}
func TestClientChannel_AuthFailure(t *testing.T) {
srv := testServer(t, "correct-token")
defer srv.Close()
ch, err := NewPicoClientChannel(config.PicoClientConfig{
URL: wsURL(srv.URL),
Token: "wrong-token",
}, bus.NewMessageBus())
if err != nil {
t.Fatal(err)
}
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
err = ch.Start(ctx)
if err == nil {
ch.Stop(ctx)
t.Fatal("expected auth failure")
}
}
func TestClientChannel_ReceivesServerMessage(t *testing.T) {
srv := testServer(t, "")
defer srv.Close()
mb := bus.NewMessageBus()
ch, err := NewPicoClientChannel(config.PicoClientConfig{
URL: wsURL(srv.URL),
SessionID: "sess-echo",
ReadTimeout: 10,
}, mb)
if err != nil {
t.Fatal(err)
}
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
if err = ch.Start(ctx); err != nil {
t.Fatalf("Start: %v", err)
}
defer ch.Stop(ctx)
// Send a message; the echo server replies with message.create
err = ch.Send(ctx, bus.OutboundMessage{
ChatID: "pico_client:sess-echo",
Content: "ping",
})
if err != nil {
t.Fatalf("Send: %v", err)
}
// The echoed message.create is processed by handleServerMessage which
// calls HandleMessage → PublishInbound. Consume it from the bus.
select {
case msg := <-mb.InboundChan():
if msg.Content != "ping" {
t.Fatalf("received = %q, want %q", msg.Content, "ping")
}
case <-ctx.Done():
t.Fatal("timed out waiting for echoed message")
}
}
func TestClientChannel_StartTyping(t *testing.T) {
srv := testServer(t, "")
defer srv.Close()
ch, err := NewPicoClientChannel(config.PicoClientConfig{
URL: wsURL(srv.URL),
SessionID: "sess-type",
ReadTimeout: 10,
}, bus.NewMessageBus())
if err != nil {
t.Fatal(err)
}
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
if err = ch.Start(ctx); err != nil {
t.Fatalf("Start: %v", err)
}
defer ch.Stop(ctx)
stop, err := ch.StartTyping(ctx, "pico_client:sess-type")
if err != nil {
t.Fatalf("StartTyping: %v", err)
}
stop() // should not panic
}
func TestSend_ClosedConnection(t *testing.T) {
srv := testServer(t, "")
defer srv.Close()
ch, err := NewPicoClientChannel(config.PicoClientConfig{
URL: wsURL(srv.URL),
SessionID: "sess-close",
ReadTimeout: 10,
}, bus.NewMessageBus())
if err != nil {
t.Fatal(err)
}
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
if err = ch.Start(ctx); err != nil {
t.Fatalf("Start: %v", err)
}
// Force close the underlying connection
ch.mu.Lock()
ch.conn.close()
ch.mu.Unlock()
err = ch.Send(ctx, bus.OutboundMessage{
ChatID: "pico_client:sess-close",
Content: "should fail",
})
if !errors.Is(err, channels.ErrSendFailed) {
t.Fatalf("expected ErrSendFailed, got %v", err)
}
ch.Stop(ctx)
}