Files
picoclaw/web/backend/api/pico_test.go
T
lxowalle 451db2f5d8 Feat(channels): unify animated tool feedback across chat channels and Pico (#2622)
* feat(channels): unify tool feedback animation across discord telegram and feishu

* fix(tool-feedback): unify fallback and single-message delivery

* fix(channels): finalize tool feedback in place

* fix ci

* feat: improve tool feedback

* fix review blockers in pico token cache and tool feedback

fix(provider): preserve function thought signatures

fix(feishu): recover tool feedback after edit fallback

* * delete dead code

* fix(pico): clean up tool feedback progress state

* fix ci

* fix(web): preserve tool feedback line breaks in chat

* fix(channels): preserve tool feedback progress state

fix(pico): preserve context usage when finalizing tool feedback

chore: record branch review pass

fix: preserve tool feedback finalization state

fix(web): handle pico history update fallback

* fix ci
2026-04-23 10:35:50 +08:00

978 lines
26 KiB
Go

package api
import (
"encoding/json"
"io"
"net/http"
"net/http/httptest"
"net/url"
"os"
"path/filepath"
"strconv"
"strings"
"testing"
"github.com/sipeed/picoclaw/pkg/config"
ppid "github.com/sipeed/picoclaw/pkg/pid"
)
func newPicoProxyRequest(method, path string) *http.Request {
req := httptest.NewRequest(method, "http://launcher.local:18800"+path, nil)
req.Header.Set("Origin", "http://launcher.local:18800")
return req
}
func TestEnsurePicoChannel_FreshConfig(t *testing.T) {
configPath := filepath.Join(t.TempDir(), "config.json")
h := NewHandler(configPath)
changed, err := h.EnsurePicoChannel()
if err != nil {
t.Fatalf("EnsurePicoChannel() error = %v", err)
}
if !changed {
t.Fatal("EnsurePicoChannel() should report changed on a fresh config")
}
cfg, err := config.LoadConfig(configPath)
if err != nil {
t.Fatalf("LoadConfig() error = %v", err)
}
bc := cfg.Channels["pico"]
decoded, err := bc.GetDecoded()
if err != nil {
t.Fatalf("GetDecoded() error = %v", err)
}
picoCfg := decoded.(*config.PicoSettings)
if !bc.Enabled {
t.Error("expected Pico to be enabled after setup")
}
if picoCfg.Token.String() == "" {
t.Error("expected a non-empty token after setup")
}
}
func TestEnsurePicoChannel_DoesNotEnableTokenQuery(t *testing.T) {
configPath := filepath.Join(t.TempDir(), "config.json")
h := NewHandler(configPath)
if _, err := h.EnsurePicoChannel(); err != nil {
t.Fatalf("EnsurePicoChannel() error = %v", err)
}
cfg, err := config.LoadConfig(configPath)
if err != nil {
t.Fatalf("LoadConfig() error = %v", err)
}
bc := cfg.Channels["pico"]
decoded, err := bc.GetDecoded()
if err != nil {
t.Fatalf("GetDecoded() error = %v", err)
}
picoCfg := decoded.(*config.PicoSettings)
if picoCfg.AllowTokenQuery {
t.Error("setup must not enable allow_token_query by default")
}
}
func TestEnsurePicoChannel_LeavesAllowOriginsEmptyByDefault(t *testing.T) {
configPath := filepath.Join(t.TempDir(), "config.json")
h := NewHandler(configPath)
if _, err := h.EnsurePicoChannel(); err != nil {
t.Fatalf("EnsurePicoChannel() error = %v", err)
}
cfg, err := config.LoadConfig(configPath)
if err != nil {
t.Fatalf("LoadConfig() error = %v", err)
}
bc := cfg.Channels["pico"]
decoded, err := bc.GetDecoded()
if err != nil {
t.Fatalf("GetDecoded() error = %v", err)
}
picoCfg := decoded.(*config.PicoSettings)
if len(picoCfg.AllowOrigins) != 0 {
t.Errorf("allow_origins = %v, want empty", picoCfg.AllowOrigins)
}
}
func TestEnsurePicoChannel_NoOriginConfigurationRequired(t *testing.T) {
configPath := filepath.Join(t.TempDir(), "config.json")
h := NewHandler(configPath)
if _, err := h.EnsurePicoChannel(); err != nil {
t.Fatalf("EnsurePicoChannel() error = %v", err)
}
cfg, err := config.LoadConfig(configPath)
if err != nil {
t.Fatalf("LoadConfig() error = %v", err)
}
bc := cfg.Channels["pico"]
decoded, err := bc.GetDecoded()
if err != nil {
t.Fatalf("GetDecoded() error = %v", err)
}
picoCfg := decoded.(*config.PicoSettings)
if len(picoCfg.AllowOrigins) != 0 {
t.Errorf("allow_origins = %v, want empty", picoCfg.AllowOrigins)
}
}
func TestEnsurePicoChannel_PreservesUserSettings(t *testing.T) {
configPath := filepath.Join(t.TempDir(), "config.json")
// Pre-configure with custom user settings
cfg := config.DefaultConfig()
bc := cfg.Channels["pico"]
decoded, err := bc.GetDecoded()
if err != nil {
t.Fatalf("GetDecoded() error = %v", err)
}
picoCfg := decoded.(*config.PicoSettings)
bc.Enabled = true
picoCfg.SetToken("user-custom-token")
picoCfg.AllowTokenQuery = true
picoCfg.AllowOrigins = []string{"https://myapp.example.com"}
if err = config.SaveConfig(configPath, cfg); err != nil {
t.Fatalf("SaveConfig() error = %v", err)
}
h := NewHandler(configPath)
changed, err := h.EnsurePicoChannel()
if err != nil {
t.Fatalf("EnsurePicoChannel() error = %v", err)
}
if changed {
t.Error("EnsurePicoChannel() should not change a fully configured config")
}
cfg, err = config.LoadConfig(configPath)
if err != nil {
t.Fatalf("LoadConfig() error = %v", err)
}
bc = cfg.Channels["pico"]
decoded, err = bc.GetDecoded()
if err != nil {
t.Fatalf("GetDecoded() error = %v", err)
}
picoCfg = decoded.(*config.PicoSettings)
if picoCfg.Token.String() != "user-custom-token" {
t.Errorf("token = %q, want %q", picoCfg.Token.String(), "user-custom-token")
}
if !picoCfg.AllowTokenQuery {
t.Error("user's allow_token_query=true must be preserved")
}
if len(picoCfg.AllowOrigins) != 1 || picoCfg.AllowOrigins[0] != "https://myapp.example.com" {
t.Errorf("allow_origins = %v, want [https://myapp.example.com]", picoCfg.AllowOrigins)
}
}
func TestEnsurePicoChannel_ExistingConfigWithoutSecurityFile(t *testing.T) {
configPath := filepath.Join(t.TempDir(), "config.json")
cfg := config.DefaultConfig()
raw, err := json.Marshal(cfg)
if err != nil {
t.Fatalf("Marshal() error = %v", err)
}
if err = os.WriteFile(configPath, raw, 0o600); err != nil {
t.Fatalf("WriteFile() error = %v", err)
}
h := NewHandler(configPath)
changed, err := h.EnsurePicoChannel()
if err != nil {
t.Fatalf("EnsurePicoChannel() error = %v", err)
}
if !changed {
t.Fatal("EnsurePicoChannel() should report changed when pico is missing")
}
cfg, err = config.LoadConfig(configPath)
if err != nil {
t.Fatalf("LoadConfig() error = %v", err)
}
bc := cfg.Channels["pico"]
decoded, err := bc.GetDecoded()
if err != nil {
t.Fatalf("GetDecoded() error = %v", err)
}
picoCfg := decoded.(*config.PicoSettings)
if !bc.Enabled {
t.Error("expected Pico to be enabled after setup")
}
if picoCfg.Token.String() == "" {
t.Error("expected a non-empty token after setup")
}
if _, err := os.Stat(filepath.Join(filepath.Dir(configPath), config.SecurityConfigFile)); err != nil {
t.Fatalf("expected .security.yml to be created: %v", err)
}
}
func TestEnsurePicoChannel_ConfiguresPicoWithoutGateway(t *testing.T) {
configPath := filepath.Join(t.TempDir(), "config.json")
cfg := config.DefaultConfig()
cfg.Agents.Defaults.ModelName = ""
if err := config.SaveConfig(configPath, cfg); err != nil {
t.Fatalf("SaveConfig() error = %v", err)
}
h := NewHandler(configPath)
if _, err := h.EnsurePicoChannel(); err != nil {
t.Fatalf("EnsurePicoChannel() error = %v", err)
}
cfg, err := config.LoadConfig(configPath)
if err != nil {
t.Fatalf("LoadConfig() error = %v", err)
}
bc := cfg.Channels["pico"]
decoded, err := bc.GetDecoded()
if err != nil {
t.Fatalf("GetDecoded() error = %v", err)
}
picoCfg := decoded.(*config.PicoSettings)
if !bc.Enabled {
t.Error("expected Pico to be enabled after launcher startup setup")
}
if picoCfg.Token.String() == "" {
t.Error("expected a non-empty token after launcher startup setup")
}
}
func TestEnsurePicoChannel_Idempotent(t *testing.T) {
configPath := filepath.Join(t.TempDir(), "config.json")
h := NewHandler(configPath)
// First call sets things up
if _, err := h.EnsurePicoChannel(); err != nil {
t.Fatalf("first EnsurePicoChannel() error = %v", err)
}
cfg1, _ := config.LoadConfig(configPath)
bc := cfg1.Channels["pico"]
decoded, err := bc.GetDecoded()
if err != nil {
t.Fatalf("GetDecoded() error = %v", err)
}
picoCfg := decoded.(*config.PicoSettings)
token1 := picoCfg.Token.String()
// Second call should be a no-op
changed, err := h.EnsurePicoChannel()
if err != nil {
t.Fatalf("second EnsurePicoChannel() error = %v", err)
}
if changed {
t.Error("second EnsurePicoChannel() should not report changed")
}
cfg2, _ := config.LoadConfig(configPath)
bc = cfg2.Channels["pico"]
decoded, err = bc.GetDecoded()
if err != nil {
t.Fatalf("GetDecoded() error = %v", err)
}
picoCfg = decoded.(*config.PicoSettings)
if picoCfg.Token.String() != token1 {
t.Error("token should not change on subsequent calls")
}
}
func TestHandlePicoSetup_DoesNotPersistRequestOrigin(t *testing.T) {
configPath := filepath.Join(t.TempDir(), "config.json")
h := NewHandler(configPath)
req := httptest.NewRequest("POST", "/api/pico/setup", nil)
req.Header.Set("Origin", "http://10.0.0.5:3000")
rec := httptest.NewRecorder()
h.handlePicoSetup(rec, req)
if rec.Code != http.StatusOK {
t.Fatalf("status = %d, want %d", rec.Code, http.StatusOK)
}
cfg, err := config.LoadConfig(configPath)
if err != nil {
t.Fatalf("LoadConfig() error = %v", err)
}
bc := cfg.Channels["pico"]
decoded, err := bc.GetDecoded()
if err != nil {
t.Fatalf("GetDecoded() error = %v", err)
}
picoCfg := decoded.(*config.PicoSettings)
if len(picoCfg.AllowOrigins) != 0 {
t.Errorf("allow_origins = %v, want empty", picoCfg.AllowOrigins)
}
}
func TestHandlePicoSetup_Response(t *testing.T) {
configPath := filepath.Join(t.TempDir(), "config.json")
h := NewHandler(configPath)
req := httptest.NewRequest("POST", "/api/pico/setup", nil)
rec := httptest.NewRecorder()
h.handlePicoSetup(rec, req)
if rec.Code != http.StatusOK {
t.Fatalf("status = %d, want %d", rec.Code, http.StatusOK)
}
var resp map[string]any
if err := json.NewDecoder(rec.Body).Decode(&resp); err != nil {
t.Fatalf("failed to decode response: %v", err)
}
if _, ok := resp["token"]; ok {
t.Error("response must not expose the raw pico token")
}
if resp["ws_url"] == nil || resp["ws_url"] == "" {
t.Error("response should contain ws_url")
}
if resp["enabled"] != true {
t.Error("response should have enabled=true")
}
if resp["changed"] != true {
t.Error("response should have changed=true on first setup")
}
if resp["configured"] != true {
t.Error("response should have configured=true")
}
}
func TestHandleGetPicoInfo_OmitsToken(t *testing.T) {
configPath := filepath.Join(t.TempDir(), "config.json")
h := NewHandler(configPath)
if _, err := h.EnsurePicoChannel(); err != nil {
t.Fatalf("EnsurePicoChannel() error = %v", err)
}
req := httptest.NewRequest(http.MethodGet, "http://launcher.local/api/pico/info", nil)
rec := httptest.NewRecorder()
h.handleGetPicoInfo(rec, req)
if rec.Code != http.StatusOK {
t.Fatalf("status = %d, want %d", rec.Code, http.StatusOK)
}
var resp map[string]any
if err := json.NewDecoder(rec.Body).Decode(&resp); err != nil {
t.Fatalf("failed to decode response: %v", err)
}
if _, ok := resp["token"]; ok {
t.Fatal("info response must not expose the raw pico token")
}
if resp["enabled"] != true {
t.Fatalf("enabled = %#v, want true", resp["enabled"])
}
if resp["configured"] != true {
t.Fatalf("configured = %#v, want true", resp["configured"])
}
if resp["ws_url"] == nil || resp["ws_url"] == "" {
t.Fatal("response should contain ws_url")
}
}
func TestHandleRegenPicoToken_RefreshesGatewayTokenCache(t *testing.T) {
configPath := filepath.Join(t.TempDir(), "config.json")
h := NewHandler(configPath)
if _, err := h.EnsurePicoChannel(); err != nil {
t.Fatalf("EnsurePicoChannel() error = %v", err)
}
origPicoToken := gateway.picoToken
t.Cleanup(func() {
gateway.mu.Lock()
gateway.picoToken = origPicoToken
gateway.mu.Unlock()
})
gateway.mu.Lock()
gateway.picoToken = "stale-token"
gateway.mu.Unlock()
req := httptest.NewRequest(http.MethodPost, "http://launcher.local/api/pico/token", nil)
rec := httptest.NewRecorder()
h.handleRegenPicoToken(rec, req)
if rec.Code != http.StatusOK {
t.Fatalf("status = %d, want %d", rec.Code, http.StatusOK)
}
cfg, err := config.LoadConfig(configPath)
if err != nil {
t.Fatalf("LoadConfig() error = %v", err)
}
bc := cfg.Channels["pico"]
decoded, err := bc.GetDecoded()
if err != nil {
t.Fatalf("GetDecoded() error = %v", err)
}
token := decoded.(*config.PicoSettings).Token.String()
if token == "" {
t.Fatal("expected regenerated pico token to be persisted")
}
if token == "stale-token" {
t.Fatal("expected regenerated pico token to differ from stale cache")
}
gateway.mu.Lock()
defer gateway.mu.Unlock()
if gateway.picoToken != token {
t.Fatalf("gateway.picoToken = %q, want %q", gateway.picoToken, token)
}
}
func TestHandleWebSocketProxyReloadsGatewayTargetFromConfig(t *testing.T) {
origMatcher := gatewayProcessMatcher
gatewayProcessMatcher = func(int) (bool, bool) { return true, true }
t.Cleanup(func() { gatewayProcessMatcher = origMatcher })
home := t.TempDir()
t.Setenv("PICOCLAW_HOME", home)
configPath := filepath.Join(t.TempDir(), "config.json")
h := NewHandler(configPath)
handler := h.handleWebSocketProxy()
server1 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/pico/ws" {
t.Fatalf("server1 path = %q, want %q", r.URL.Path, "/pico/ws")
}
w.WriteHeader(http.StatusOK)
_, _ = io.WriteString(w, "server1")
}))
defer server1.Close()
server2 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/pico/ws" {
t.Fatalf("server2 path = %q, want %q", r.URL.Path, "/pico/ws")
}
w.WriteHeader(http.StatusOK)
_, _ = io.WriteString(w, "server2")
}))
defer server2.Close()
cfg := config.DefaultConfig()
cfg.Gateway.Host = "127.0.0.1"
cfg.Gateway.Port = mustGatewayTestPort(t, server1.URL)
if err := config.SaveConfig(configPath, cfg); err != nil {
t.Fatalf("SaveConfig() error = %v", err)
}
cmd := startGatewayLikeProcess(t)
t.Cleanup(func() {
if cmd.Process != nil {
_ = cmd.Process.Kill()
}
_ = cmd.Wait()
})
writeTestPidFile(t, ppid.PidFileData{
PID: cmd.Process.Pid,
Token: "test-token",
Host: cfg.Gateway.Host,
Port: cfg.Gateway.Port,
})
origPidData := gateway.pidData
origPicoToken := gateway.picoToken
t.Cleanup(func() {
ppid.RemovePidFile(globalConfigDir())
gateway.pidData = origPidData
gateway.picoToken = origPicoToken
})
gateway.pidData = &ppid.PidFileData{}
gateway.picoToken = "pico"
req1 := newPicoProxyRequest(http.MethodGet, "/pico/ws")
rec1 := httptest.NewRecorder()
handler(rec1, req1)
if rec1.Code != http.StatusOK {
t.Fatalf("first status = %d, want %d", rec1.Code, http.StatusOK)
}
if body := rec1.Body.String(); body != "server1" {
t.Fatalf("first body = %q, want %q", body, "server1")
}
cfg.Gateway.Port = mustGatewayTestPort(t, server2.URL)
if err := config.SaveConfig(configPath, cfg); err != nil {
t.Fatalf("SaveConfig() error = %v", err)
}
req2 := newPicoProxyRequest(http.MethodGet, "/pico/ws")
rec2 := httptest.NewRecorder()
handler(rec2, req2)
if rec2.Code != http.StatusOK {
t.Fatalf("second status = %d, want %d", rec2.Code, http.StatusOK)
}
if body := rec2.Body.String(); body != "server2" {
t.Fatalf("second body = %q, want %q", body, "server2")
}
}
func TestHandleWebSocketProxyLoadsCachedPicoTokenWhenMissing(t *testing.T) {
origMatcher := gatewayProcessMatcher
gatewayProcessMatcher = func(int) (bool, bool) { return true, true }
t.Cleanup(func() { gatewayProcessMatcher = origMatcher })
home := t.TempDir()
t.Setenv("PICOCLAW_HOME", home)
configPath := filepath.Join(t.TempDir(), "config.json")
h := NewHandler(configPath)
handler := h.handleWebSocketProxy()
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/pico/ws" {
t.Fatalf("path = %q, want %q", r.URL.Path, "/pico/ws")
}
w.WriteHeader(http.StatusOK)
_, _ = io.WriteString(w, "proxied")
}))
defer server.Close()
cfg := config.DefaultConfig()
cfg.Gateway.Host = "127.0.0.1"
cfg.Gateway.Port = mustGatewayTestPort(t, server.URL)
bc := cfg.Channels["pico"]
decoded, err := bc.GetDecoded()
if err != nil {
t.Fatalf("GetDecoded() error = %v", err)
}
picoCfg := decoded.(*config.PicoSettings)
bc.Enabled = true
picoCfg.SetToken("cached-token")
if err := config.SaveConfig(configPath, cfg); err != nil {
t.Fatalf("SaveConfig() error = %v", err)
}
cmd := startGatewayLikeProcess(t)
t.Cleanup(func() {
if cmd.Process != nil {
_ = cmd.Process.Kill()
}
_ = cmd.Wait()
})
writeTestPidFile(t, ppid.PidFileData{
PID: cmd.Process.Pid,
Token: "test-token",
Host: cfg.Gateway.Host,
Port: cfg.Gateway.Port,
})
t.Cleanup(func() {
ppid.RemovePidFile(globalConfigDir())
})
origPidData := gateway.pidData
origPicoToken := gateway.picoToken
t.Cleanup(func() {
gateway.pidData = origPidData
gateway.picoToken = origPicoToken
})
gateway.pidData = &ppid.PidFileData{}
gateway.picoToken = ""
req := newPicoProxyRequest(http.MethodGet, "/pico/ws?session_id=test-session")
rec := httptest.NewRecorder()
handler(rec, req)
if rec.Code != http.StatusOK {
t.Fatalf("status = %d, want %d", rec.Code, http.StatusOK)
}
if body := rec.Body.String(); body != "proxied" {
t.Fatalf("body = %q, want %q", body, "proxied")
}
if gateway.picoToken != "cached-token" {
t.Fatalf("gateway.picoToken = %q, want %q", gateway.picoToken, "cached-token")
}
}
func TestHandleWebSocketProxyLoadsPidDataOnDemand(t *testing.T) {
origMatcher := gatewayProcessMatcher
gatewayProcessMatcher = func(int) (bool, bool) { return true, true }
t.Cleanup(func() { gatewayProcessMatcher = origMatcher })
home := t.TempDir()
t.Setenv("PICOCLAW_HOME", home)
configPath := filepath.Join(t.TempDir(), "config.json")
h := NewHandler(configPath)
handler := h.handleWebSocketProxy()
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/pico/ws" {
t.Fatalf("path = %q, want %q", r.URL.Path, "/pico/ws")
}
w.WriteHeader(http.StatusOK)
_, _ = io.WriteString(w, r.Header.Get(protocolKey))
}))
defer server.Close()
cfg := config.DefaultConfig()
cfg.Gateway.Host = "127.0.0.1"
cfg.Gateway.Port = mustGatewayTestPort(t, server.URL)
bc := cfg.Channels["pico"]
bc.Enabled = true
decoded, err := bc.GetDecoded()
if err != nil {
t.Fatalf("GetDecoded() error = %v", err)
}
decoded.(*config.PicoSettings).SetToken("ui-token")
if err := config.SaveConfig(configPath, cfg); err != nil {
t.Fatalf("SaveConfig() error = %v", err)
}
cmd := startGatewayLikeProcess(t)
t.Cleanup(func() {
if cmd.Process != nil {
_ = cmd.Process.Kill()
}
_ = cmd.Wait()
})
pidData := ppid.PidFileData{
PID: cmd.Process.Pid,
Token: "test-token",
Host: cfg.Gateway.Host,
Port: cfg.Gateway.Port,
}
writeTestPidFile(t, pidData)
t.Cleanup(func() {
ppid.RemovePidFile(globalConfigDir())
})
origPidData := gateway.pidData
origPicoToken := gateway.picoToken
origStatus := gateway.runtimeStatus
t.Cleanup(func() {
gateway.mu.Lock()
gateway.pidData = origPidData
gateway.picoToken = origPicoToken
gateway.runtimeStatus = origStatus
gateway.mu.Unlock()
})
gateway.mu.Lock()
gateway.pidData = nil
gateway.picoToken = ""
setGatewayRuntimeStatusLocked("stopped")
gateway.mu.Unlock()
req := newPicoProxyRequest(http.MethodGet, "/pico/ws?session_id=test-session")
rec := httptest.NewRecorder()
handler(rec, req)
if rec.Code != http.StatusOK {
t.Fatalf("status = %d, want %d", rec.Code, http.StatusOK)
}
expected := tokenPrefix + "ui-token"
if got := rec.Body.String(); got != expected {
t.Fatalf("forwarded protocol = %q, want %q", got, expected)
}
gateway.mu.Lock()
defer gateway.mu.Unlock()
if gateway.pidData == nil {
t.Fatal("gateway.pidData should be loaded from pid file")
}
if gateway.runtimeStatus != "running" {
t.Fatalf("runtimeStatus = %q, want %q", gateway.runtimeStatus, "running")
}
}
func TestCreatePicoHTTPProxyInjectsGatewayAuth(t *testing.T) {
configPath := filepath.Join(t.TempDir(), "config.json")
h := NewHandler(configPath)
cfg := config.DefaultConfig()
cfg.Gateway.Host = "127.0.0.1"
cfg.Gateway.Port = 18790
bc := cfg.Channels["pico"]
bc.Enabled = true
decoded, err := bc.GetDecoded()
if err != nil {
t.Fatalf("GetDecoded() error = %v", err)
}
decoded.(*config.PicoSettings).SetToken("ui-token")
if err := config.SaveConfig(configPath, cfg); err != nil {
t.Fatalf("SaveConfig() error = %v", err)
}
proxy := h.createPicoHTTPProxy("ui-token")
var capturedPath string
var capturedAuth string
proxy.Transport = roundTripFunc(func(req *http.Request) (*http.Response, error) {
capturedPath = req.URL.Path
capturedAuth = req.Header.Get("Authorization")
return &http.Response{
StatusCode: http.StatusOK,
Header: make(http.Header),
Body: io.NopCloser(strings.NewReader("proxied")),
Request: req,
}, nil
})
req := httptest.NewRequest(http.MethodGet, "/pico/media/attachment-1", nil)
rec := httptest.NewRecorder()
proxy.ServeHTTP(rec, req)
if rec.Code != http.StatusOK {
t.Fatalf("status = %d, want %d", rec.Code, http.StatusOK)
}
if capturedPath != "/pico/media/attachment-1" {
t.Fatalf("capturedPath = %q, want %q", capturedPath, "/pico/media/attachment-1")
}
expected := "Bearer ui-token"
if capturedAuth != expected {
t.Fatalf("Authorization = %q, want %q", capturedAuth, expected)
}
}
func TestHandlePicoMediaProxyUsesRawBearerToken(t *testing.T) {
home := t.TempDir()
t.Setenv("PICOCLAW_HOME", home)
configPath := filepath.Join(t.TempDir(), "config.json")
h := NewHandler(configPath)
handler := h.handlePicoMediaProxy()
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/pico/media/attachment-1" {
t.Fatalf("path = %q, want %q", r.URL.Path, "/pico/media/attachment-1")
}
if got := r.Header.Get("Authorization"); got != "Bearer ui-token" {
t.Fatalf("Authorization = %q, want %q", got, "Bearer ui-token")
}
w.WriteHeader(http.StatusOK)
_, _ = io.WriteString(w, "proxied-media")
}))
defer server.Close()
cfg := config.DefaultConfig()
cfg.Gateway.Host = "127.0.0.1"
cfg.Gateway.Port = mustGatewayTestPort(t, server.URL)
bc := cfg.Channels["pico"]
bc.Enabled = true
decoded, err := bc.GetDecoded()
if err != nil {
t.Fatalf("GetDecoded() error = %v", err)
}
decoded.(*config.PicoSettings).SetToken("ui-token")
if err := config.SaveConfig(configPath, cfg); err != nil {
t.Fatalf("SaveConfig() error = %v", err)
}
cmd := startGatewayLikeProcess(t)
t.Cleanup(func() {
if cmd.Process != nil {
_ = cmd.Process.Kill()
}
_ = cmd.Wait()
})
origPidData := gateway.pidData
origPicoToken := gateway.picoToken
origCmd := gateway.cmd
t.Cleanup(func() {
gateway.mu.Lock()
gateway.pidData = origPidData
gateway.picoToken = origPicoToken
gateway.cmd = origCmd
gateway.mu.Unlock()
})
gateway.mu.Lock()
gateway.pidData = &ppid.PidFileData{PID: cmd.Process.Pid}
gateway.picoToken = "ui-token"
gateway.cmd = cmd
gateway.mu.Unlock()
req := newPicoProxyRequest(http.MethodGet, "/pico/media/attachment-1")
rec := httptest.NewRecorder()
handler(rec, req)
if rec.Code != http.StatusOK {
t.Fatalf("status = %d, want %d", rec.Code, http.StatusOK)
}
if body := rec.Body.String(); body != "proxied-media" {
t.Fatalf("body = %q, want %q", body, "proxied-media")
}
}
func TestHandleWebSocketProxyRejectsStalePidDataAfterProcessExit(t *testing.T) {
tmpDir := t.TempDir()
t.Setenv("HOME", tmpDir)
t.Setenv("PICOCLAW_HOME", filepath.Join(tmpDir, ".picoclaw"))
configPath := filepath.Join(tmpDir, "config.json")
h := NewHandler(configPath)
handler := h.handleWebSocketProxy()
cfg := config.DefaultConfig()
bc := cfg.Channels["pico"]
bc.Enabled = true
decoded, err := bc.GetDecoded()
if err != nil {
t.Fatalf("GetDecoded() error = %v", err)
}
decoded.(*config.PicoSettings).SetToken("ui-token")
if err := config.SaveConfig(configPath, cfg); err != nil {
t.Fatalf("SaveConfig() error = %v", err)
}
cmd := startLongRunningProcess(t)
if cmd.Process != nil {
_ = cmd.Process.Kill()
}
_ = cmd.Wait()
origPidData := gateway.pidData
origPicoToken := gateway.picoToken
origCmd := gateway.cmd
origStatus := gateway.runtimeStatus
t.Cleanup(func() {
gateway.mu.Lock()
gateway.pidData = origPidData
gateway.picoToken = origPicoToken
gateway.cmd = origCmd
gateway.runtimeStatus = origStatus
gateway.mu.Unlock()
})
gateway.mu.Lock()
gateway.pidData = &ppid.PidFileData{PID: cmd.Process.Pid, Token: "stale-token"}
gateway.picoToken = "ui-token"
gateway.cmd = cmd
setGatewayRuntimeStatusLocked("running")
gateway.mu.Unlock()
req := newPicoProxyRequest(http.MethodGet, "/pico/ws?session_id=test-session")
rec := httptest.NewRecorder()
handler(rec, req)
if rec.Code != http.StatusServiceUnavailable {
t.Fatalf("status = %d, want %d", rec.Code, http.StatusServiceUnavailable)
}
gateway.mu.Lock()
defer gateway.mu.Unlock()
if gateway.pidData != nil {
t.Fatal("gateway.pidData should be cleared after stale process exit is detected")
}
}
func TestHandleWebSocketProxy_AllowsArbitraryOrigin(t *testing.T) {
origMatcher := gatewayProcessMatcher
gatewayProcessMatcher = func(int) (bool, bool) { return true, true }
t.Cleanup(func() { gatewayProcessMatcher = origMatcher })
home := t.TempDir()
t.Setenv("PICOCLAW_HOME", home)
configPath := filepath.Join(t.TempDir(), "config.json")
h := NewHandler(configPath)
handler := h.handleWebSocketProxy()
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/pico/ws" {
t.Fatalf("path = %q, want %q", r.URL.Path, "/pico/ws")
}
w.WriteHeader(http.StatusOK)
_, _ = io.WriteString(w, "proxied")
}))
defer server.Close()
cfg := config.DefaultConfig()
cfg.Gateway.Host = "127.0.0.1"
cfg.Gateway.Port = mustGatewayTestPort(t, server.URL)
bc := cfg.Channels["pico"]
bc.Enabled = true
decoded, err := bc.GetDecoded()
if err != nil {
t.Fatalf("GetDecoded() error = %v", err)
}
decoded.(*config.PicoSettings).SetToken("ui-token")
if err := config.SaveConfig(configPath, cfg); err != nil {
t.Fatalf("SaveConfig() error = %v", err)
}
cmd := startGatewayLikeProcess(t)
t.Cleanup(func() {
if cmd.Process != nil {
_ = cmd.Process.Kill()
}
_ = cmd.Wait()
})
writeTestPidFile(t, ppid.PidFileData{
PID: cmd.Process.Pid,
Token: "test-token",
Host: cfg.Gateway.Host,
Port: cfg.Gateway.Port,
})
t.Cleanup(func() {
ppid.RemovePidFile(globalConfigDir())
})
origPidData := gateway.pidData
origPicoToken := gateway.picoToken
t.Cleanup(func() {
gateway.pidData = origPidData
gateway.picoToken = origPicoToken
})
gateway.pidData = &ppid.PidFileData{}
gateway.picoToken = "ui-token"
req := httptest.NewRequest(http.MethodGet, "http://launcher.local/pico/ws?session_id=test-session", nil)
req.Header.Set("Origin", "http://evil.example")
rec := httptest.NewRecorder()
handler(rec, req)
if rec.Code != http.StatusOK {
t.Fatalf("status = %d, want %d", rec.Code, http.StatusOK)
}
}
func mustGatewayTestPort(t *testing.T, rawURL string) int {
t.Helper()
parsed, err := url.Parse(rawURL)
if err != nil {
t.Fatalf("url.Parse() error = %v", err)
}
port, err := strconv.Atoi(parsed.Port())
if err != nil {
t.Fatalf("Atoi(%q) error = %v", parsed.Port(), err)
}
return port
}
type roundTripFunc func(*http.Request) (*http.Response, error)
func (fn roundTripFunc) RoundTrip(req *http.Request) (*http.Response, error) {
return fn(req)
}