mirror of
https://github.com/sipeed/picoclaw.git
synced 2026-06-12 18:08:54 +00:00
451db2f5d8
* feat(channels): unify tool feedback animation across discord telegram and feishu * fix(tool-feedback): unify fallback and single-message delivery * fix(channels): finalize tool feedback in place * fix ci * feat: improve tool feedback * fix review blockers in pico token cache and tool feedback fix(provider): preserve function thought signatures fix(feishu): recover tool feedback after edit fallback * * delete dead code * fix(pico): clean up tool feedback progress state * fix ci * fix(web): preserve tool feedback line breaks in chat * fix(channels): preserve tool feedback progress state fix(pico): preserve context usage when finalizing tool feedback chore: record branch review pass fix: preserve tool feedback finalization state fix(web): handle pico history update fallback * fix ci
978 lines
26 KiB
Go
978 lines
26 KiB
Go
package api
|
|
|
|
import (
|
|
"encoding/json"
|
|
"io"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"net/url"
|
|
"os"
|
|
"path/filepath"
|
|
"strconv"
|
|
"strings"
|
|
"testing"
|
|
|
|
"github.com/sipeed/picoclaw/pkg/config"
|
|
ppid "github.com/sipeed/picoclaw/pkg/pid"
|
|
)
|
|
|
|
// newPicoProxyRequest builds an in-memory request for path on the launcher
// host and stamps it with the launcher's own Origin header, so proxy tests
// look like same-origin traffic from the launcher UI.
func newPicoProxyRequest(method, path string) *http.Request {
	const launcherOrigin = "http://launcher.local:18800"

	r := httptest.NewRequest(method, launcherOrigin+path, nil)
	r.Header.Set("Origin", launcherOrigin)
	return r
}
|
|
|
|
func TestEnsurePicoChannel_FreshConfig(t *testing.T) {
|
|
configPath := filepath.Join(t.TempDir(), "config.json")
|
|
h := NewHandler(configPath)
|
|
|
|
changed, err := h.EnsurePicoChannel()
|
|
if err != nil {
|
|
t.Fatalf("EnsurePicoChannel() error = %v", err)
|
|
}
|
|
if !changed {
|
|
t.Fatal("EnsurePicoChannel() should report changed on a fresh config")
|
|
}
|
|
|
|
cfg, err := config.LoadConfig(configPath)
|
|
if err != nil {
|
|
t.Fatalf("LoadConfig() error = %v", err)
|
|
}
|
|
|
|
bc := cfg.Channels["pico"]
|
|
decoded, err := bc.GetDecoded()
|
|
if err != nil {
|
|
t.Fatalf("GetDecoded() error = %v", err)
|
|
}
|
|
picoCfg := decoded.(*config.PicoSettings)
|
|
if !bc.Enabled {
|
|
t.Error("expected Pico to be enabled after setup")
|
|
}
|
|
if picoCfg.Token.String() == "" {
|
|
t.Error("expected a non-empty token after setup")
|
|
}
|
|
}
|
|
|
|
func TestEnsurePicoChannel_DoesNotEnableTokenQuery(t *testing.T) {
|
|
configPath := filepath.Join(t.TempDir(), "config.json")
|
|
h := NewHandler(configPath)
|
|
|
|
if _, err := h.EnsurePicoChannel(); err != nil {
|
|
t.Fatalf("EnsurePicoChannel() error = %v", err)
|
|
}
|
|
|
|
cfg, err := config.LoadConfig(configPath)
|
|
if err != nil {
|
|
t.Fatalf("LoadConfig() error = %v", err)
|
|
}
|
|
|
|
bc := cfg.Channels["pico"]
|
|
decoded, err := bc.GetDecoded()
|
|
if err != nil {
|
|
t.Fatalf("GetDecoded() error = %v", err)
|
|
}
|
|
picoCfg := decoded.(*config.PicoSettings)
|
|
if picoCfg.AllowTokenQuery {
|
|
t.Error("setup must not enable allow_token_query by default")
|
|
}
|
|
}
|
|
|
|
func TestEnsurePicoChannel_LeavesAllowOriginsEmptyByDefault(t *testing.T) {
|
|
configPath := filepath.Join(t.TempDir(), "config.json")
|
|
h := NewHandler(configPath)
|
|
|
|
if _, err := h.EnsurePicoChannel(); err != nil {
|
|
t.Fatalf("EnsurePicoChannel() error = %v", err)
|
|
}
|
|
|
|
cfg, err := config.LoadConfig(configPath)
|
|
if err != nil {
|
|
t.Fatalf("LoadConfig() error = %v", err)
|
|
}
|
|
|
|
bc := cfg.Channels["pico"]
|
|
decoded, err := bc.GetDecoded()
|
|
if err != nil {
|
|
t.Fatalf("GetDecoded() error = %v", err)
|
|
}
|
|
picoCfg := decoded.(*config.PicoSettings)
|
|
if len(picoCfg.AllowOrigins) != 0 {
|
|
t.Errorf("allow_origins = %v, want empty", picoCfg.AllowOrigins)
|
|
}
|
|
}
|
|
|
|
func TestEnsurePicoChannel_NoOriginConfigurationRequired(t *testing.T) {
|
|
configPath := filepath.Join(t.TempDir(), "config.json")
|
|
h := NewHandler(configPath)
|
|
|
|
if _, err := h.EnsurePicoChannel(); err != nil {
|
|
t.Fatalf("EnsurePicoChannel() error = %v", err)
|
|
}
|
|
|
|
cfg, err := config.LoadConfig(configPath)
|
|
if err != nil {
|
|
t.Fatalf("LoadConfig() error = %v", err)
|
|
}
|
|
|
|
bc := cfg.Channels["pico"]
|
|
decoded, err := bc.GetDecoded()
|
|
if err != nil {
|
|
t.Fatalf("GetDecoded() error = %v", err)
|
|
}
|
|
picoCfg := decoded.(*config.PicoSettings)
|
|
if len(picoCfg.AllowOrigins) != 0 {
|
|
t.Errorf("allow_origins = %v, want empty", picoCfg.AllowOrigins)
|
|
}
|
|
}
|
|
|
|
func TestEnsurePicoChannel_PreservesUserSettings(t *testing.T) {
|
|
configPath := filepath.Join(t.TempDir(), "config.json")
|
|
|
|
// Pre-configure with custom user settings
|
|
cfg := config.DefaultConfig()
|
|
bc := cfg.Channels["pico"]
|
|
decoded, err := bc.GetDecoded()
|
|
if err != nil {
|
|
t.Fatalf("GetDecoded() error = %v", err)
|
|
}
|
|
picoCfg := decoded.(*config.PicoSettings)
|
|
bc.Enabled = true
|
|
picoCfg.SetToken("user-custom-token")
|
|
picoCfg.AllowTokenQuery = true
|
|
picoCfg.AllowOrigins = []string{"https://myapp.example.com"}
|
|
if err = config.SaveConfig(configPath, cfg); err != nil {
|
|
t.Fatalf("SaveConfig() error = %v", err)
|
|
}
|
|
|
|
h := NewHandler(configPath)
|
|
|
|
changed, err := h.EnsurePicoChannel()
|
|
if err != nil {
|
|
t.Fatalf("EnsurePicoChannel() error = %v", err)
|
|
}
|
|
if changed {
|
|
t.Error("EnsurePicoChannel() should not change a fully configured config")
|
|
}
|
|
|
|
cfg, err = config.LoadConfig(configPath)
|
|
if err != nil {
|
|
t.Fatalf("LoadConfig() error = %v", err)
|
|
}
|
|
|
|
bc = cfg.Channels["pico"]
|
|
decoded, err = bc.GetDecoded()
|
|
if err != nil {
|
|
t.Fatalf("GetDecoded() error = %v", err)
|
|
}
|
|
picoCfg = decoded.(*config.PicoSettings)
|
|
if picoCfg.Token.String() != "user-custom-token" {
|
|
t.Errorf("token = %q, want %q", picoCfg.Token.String(), "user-custom-token")
|
|
}
|
|
if !picoCfg.AllowTokenQuery {
|
|
t.Error("user's allow_token_query=true must be preserved")
|
|
}
|
|
if len(picoCfg.AllowOrigins) != 1 || picoCfg.AllowOrigins[0] != "https://myapp.example.com" {
|
|
t.Errorf("allow_origins = %v, want [https://myapp.example.com]", picoCfg.AllowOrigins)
|
|
}
|
|
}
|
|
|
|
func TestEnsurePicoChannel_ExistingConfigWithoutSecurityFile(t *testing.T) {
|
|
configPath := filepath.Join(t.TempDir(), "config.json")
|
|
|
|
cfg := config.DefaultConfig()
|
|
raw, err := json.Marshal(cfg)
|
|
if err != nil {
|
|
t.Fatalf("Marshal() error = %v", err)
|
|
}
|
|
if err = os.WriteFile(configPath, raw, 0o600); err != nil {
|
|
t.Fatalf("WriteFile() error = %v", err)
|
|
}
|
|
|
|
h := NewHandler(configPath)
|
|
|
|
changed, err := h.EnsurePicoChannel()
|
|
if err != nil {
|
|
t.Fatalf("EnsurePicoChannel() error = %v", err)
|
|
}
|
|
if !changed {
|
|
t.Fatal("EnsurePicoChannel() should report changed when pico is missing")
|
|
}
|
|
|
|
cfg, err = config.LoadConfig(configPath)
|
|
if err != nil {
|
|
t.Fatalf("LoadConfig() error = %v", err)
|
|
}
|
|
|
|
bc := cfg.Channels["pico"]
|
|
decoded, err := bc.GetDecoded()
|
|
if err != nil {
|
|
t.Fatalf("GetDecoded() error = %v", err)
|
|
}
|
|
picoCfg := decoded.(*config.PicoSettings)
|
|
if !bc.Enabled {
|
|
t.Error("expected Pico to be enabled after setup")
|
|
}
|
|
if picoCfg.Token.String() == "" {
|
|
t.Error("expected a non-empty token after setup")
|
|
}
|
|
if _, err := os.Stat(filepath.Join(filepath.Dir(configPath), config.SecurityConfigFile)); err != nil {
|
|
t.Fatalf("expected .security.yml to be created: %v", err)
|
|
}
|
|
}
|
|
|
|
func TestEnsurePicoChannel_ConfiguresPicoWithoutGateway(t *testing.T) {
|
|
configPath := filepath.Join(t.TempDir(), "config.json")
|
|
|
|
cfg := config.DefaultConfig()
|
|
cfg.Agents.Defaults.ModelName = ""
|
|
if err := config.SaveConfig(configPath, cfg); err != nil {
|
|
t.Fatalf("SaveConfig() error = %v", err)
|
|
}
|
|
|
|
h := NewHandler(configPath)
|
|
if _, err := h.EnsurePicoChannel(); err != nil {
|
|
t.Fatalf("EnsurePicoChannel() error = %v", err)
|
|
}
|
|
|
|
cfg, err := config.LoadConfig(configPath)
|
|
if err != nil {
|
|
t.Fatalf("LoadConfig() error = %v", err)
|
|
}
|
|
|
|
bc := cfg.Channels["pico"]
|
|
decoded, err := bc.GetDecoded()
|
|
if err != nil {
|
|
t.Fatalf("GetDecoded() error = %v", err)
|
|
}
|
|
picoCfg := decoded.(*config.PicoSettings)
|
|
if !bc.Enabled {
|
|
t.Error("expected Pico to be enabled after launcher startup setup")
|
|
}
|
|
if picoCfg.Token.String() == "" {
|
|
t.Error("expected a non-empty token after launcher startup setup")
|
|
}
|
|
}
|
|
|
|
func TestEnsurePicoChannel_Idempotent(t *testing.T) {
|
|
configPath := filepath.Join(t.TempDir(), "config.json")
|
|
h := NewHandler(configPath)
|
|
|
|
// First call sets things up
|
|
if _, err := h.EnsurePicoChannel(); err != nil {
|
|
t.Fatalf("first EnsurePicoChannel() error = %v", err)
|
|
}
|
|
|
|
cfg1, _ := config.LoadConfig(configPath)
|
|
bc := cfg1.Channels["pico"]
|
|
decoded, err := bc.GetDecoded()
|
|
if err != nil {
|
|
t.Fatalf("GetDecoded() error = %v", err)
|
|
}
|
|
picoCfg := decoded.(*config.PicoSettings)
|
|
token1 := picoCfg.Token.String()
|
|
|
|
// Second call should be a no-op
|
|
changed, err := h.EnsurePicoChannel()
|
|
if err != nil {
|
|
t.Fatalf("second EnsurePicoChannel() error = %v", err)
|
|
}
|
|
if changed {
|
|
t.Error("second EnsurePicoChannel() should not report changed")
|
|
}
|
|
|
|
cfg2, _ := config.LoadConfig(configPath)
|
|
bc = cfg2.Channels["pico"]
|
|
decoded, err = bc.GetDecoded()
|
|
if err != nil {
|
|
t.Fatalf("GetDecoded() error = %v", err)
|
|
}
|
|
picoCfg = decoded.(*config.PicoSettings)
|
|
if picoCfg.Token.String() != token1 {
|
|
t.Error("token should not change on subsequent calls")
|
|
}
|
|
}
|
|
|
|
func TestHandlePicoSetup_DoesNotPersistRequestOrigin(t *testing.T) {
|
|
configPath := filepath.Join(t.TempDir(), "config.json")
|
|
h := NewHandler(configPath)
|
|
|
|
req := httptest.NewRequest("POST", "/api/pico/setup", nil)
|
|
req.Header.Set("Origin", "http://10.0.0.5:3000")
|
|
rec := httptest.NewRecorder()
|
|
|
|
h.handlePicoSetup(rec, req)
|
|
|
|
if rec.Code != http.StatusOK {
|
|
t.Fatalf("status = %d, want %d", rec.Code, http.StatusOK)
|
|
}
|
|
|
|
cfg, err := config.LoadConfig(configPath)
|
|
if err != nil {
|
|
t.Fatalf("LoadConfig() error = %v", err)
|
|
}
|
|
|
|
bc := cfg.Channels["pico"]
|
|
decoded, err := bc.GetDecoded()
|
|
if err != nil {
|
|
t.Fatalf("GetDecoded() error = %v", err)
|
|
}
|
|
picoCfg := decoded.(*config.PicoSettings)
|
|
if len(picoCfg.AllowOrigins) != 0 {
|
|
t.Errorf("allow_origins = %v, want empty", picoCfg.AllowOrigins)
|
|
}
|
|
}
|
|
|
|
func TestHandlePicoSetup_Response(t *testing.T) {
|
|
configPath := filepath.Join(t.TempDir(), "config.json")
|
|
h := NewHandler(configPath)
|
|
|
|
req := httptest.NewRequest("POST", "/api/pico/setup", nil)
|
|
rec := httptest.NewRecorder()
|
|
|
|
h.handlePicoSetup(rec, req)
|
|
|
|
if rec.Code != http.StatusOK {
|
|
t.Fatalf("status = %d, want %d", rec.Code, http.StatusOK)
|
|
}
|
|
|
|
var resp map[string]any
|
|
if err := json.NewDecoder(rec.Body).Decode(&resp); err != nil {
|
|
t.Fatalf("failed to decode response: %v", err)
|
|
}
|
|
|
|
if _, ok := resp["token"]; ok {
|
|
t.Error("response must not expose the raw pico token")
|
|
}
|
|
if resp["ws_url"] == nil || resp["ws_url"] == "" {
|
|
t.Error("response should contain ws_url")
|
|
}
|
|
if resp["enabled"] != true {
|
|
t.Error("response should have enabled=true")
|
|
}
|
|
if resp["changed"] != true {
|
|
t.Error("response should have changed=true on first setup")
|
|
}
|
|
if resp["configured"] != true {
|
|
t.Error("response should have configured=true")
|
|
}
|
|
}
|
|
|
|
func TestHandleGetPicoInfo_OmitsToken(t *testing.T) {
|
|
configPath := filepath.Join(t.TempDir(), "config.json")
|
|
h := NewHandler(configPath)
|
|
|
|
if _, err := h.EnsurePicoChannel(); err != nil {
|
|
t.Fatalf("EnsurePicoChannel() error = %v", err)
|
|
}
|
|
|
|
req := httptest.NewRequest(http.MethodGet, "http://launcher.local/api/pico/info", nil)
|
|
rec := httptest.NewRecorder()
|
|
|
|
h.handleGetPicoInfo(rec, req)
|
|
|
|
if rec.Code != http.StatusOK {
|
|
t.Fatalf("status = %d, want %d", rec.Code, http.StatusOK)
|
|
}
|
|
|
|
var resp map[string]any
|
|
if err := json.NewDecoder(rec.Body).Decode(&resp); err != nil {
|
|
t.Fatalf("failed to decode response: %v", err)
|
|
}
|
|
|
|
if _, ok := resp["token"]; ok {
|
|
t.Fatal("info response must not expose the raw pico token")
|
|
}
|
|
if resp["enabled"] != true {
|
|
t.Fatalf("enabled = %#v, want true", resp["enabled"])
|
|
}
|
|
if resp["configured"] != true {
|
|
t.Fatalf("configured = %#v, want true", resp["configured"])
|
|
}
|
|
if resp["ws_url"] == nil || resp["ws_url"] == "" {
|
|
t.Fatal("response should contain ws_url")
|
|
}
|
|
}
|
|
|
|
func TestHandleRegenPicoToken_RefreshesGatewayTokenCache(t *testing.T) {
|
|
configPath := filepath.Join(t.TempDir(), "config.json")
|
|
h := NewHandler(configPath)
|
|
|
|
if _, err := h.EnsurePicoChannel(); err != nil {
|
|
t.Fatalf("EnsurePicoChannel() error = %v", err)
|
|
}
|
|
|
|
origPicoToken := gateway.picoToken
|
|
t.Cleanup(func() {
|
|
gateway.mu.Lock()
|
|
gateway.picoToken = origPicoToken
|
|
gateway.mu.Unlock()
|
|
})
|
|
|
|
gateway.mu.Lock()
|
|
gateway.picoToken = "stale-token"
|
|
gateway.mu.Unlock()
|
|
|
|
req := httptest.NewRequest(http.MethodPost, "http://launcher.local/api/pico/token", nil)
|
|
rec := httptest.NewRecorder()
|
|
h.handleRegenPicoToken(rec, req)
|
|
|
|
if rec.Code != http.StatusOK {
|
|
t.Fatalf("status = %d, want %d", rec.Code, http.StatusOK)
|
|
}
|
|
|
|
cfg, err := config.LoadConfig(configPath)
|
|
if err != nil {
|
|
t.Fatalf("LoadConfig() error = %v", err)
|
|
}
|
|
|
|
bc := cfg.Channels["pico"]
|
|
decoded, err := bc.GetDecoded()
|
|
if err != nil {
|
|
t.Fatalf("GetDecoded() error = %v", err)
|
|
}
|
|
token := decoded.(*config.PicoSettings).Token.String()
|
|
if token == "" {
|
|
t.Fatal("expected regenerated pico token to be persisted")
|
|
}
|
|
if token == "stale-token" {
|
|
t.Fatal("expected regenerated pico token to differ from stale cache")
|
|
}
|
|
|
|
gateway.mu.Lock()
|
|
defer gateway.mu.Unlock()
|
|
if gateway.picoToken != token {
|
|
t.Fatalf("gateway.picoToken = %q, want %q", gateway.picoToken, token)
|
|
}
|
|
}
|
|
|
|
func TestHandleWebSocketProxyReloadsGatewayTargetFromConfig(t *testing.T) {
|
|
origMatcher := gatewayProcessMatcher
|
|
gatewayProcessMatcher = func(int) (bool, bool) { return true, true }
|
|
t.Cleanup(func() { gatewayProcessMatcher = origMatcher })
|
|
|
|
home := t.TempDir()
|
|
t.Setenv("PICOCLAW_HOME", home)
|
|
|
|
configPath := filepath.Join(t.TempDir(), "config.json")
|
|
h := NewHandler(configPath)
|
|
handler := h.handleWebSocketProxy()
|
|
|
|
server1 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
if r.URL.Path != "/pico/ws" {
|
|
t.Fatalf("server1 path = %q, want %q", r.URL.Path, "/pico/ws")
|
|
}
|
|
w.WriteHeader(http.StatusOK)
|
|
_, _ = io.WriteString(w, "server1")
|
|
}))
|
|
defer server1.Close()
|
|
|
|
server2 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
if r.URL.Path != "/pico/ws" {
|
|
t.Fatalf("server2 path = %q, want %q", r.URL.Path, "/pico/ws")
|
|
}
|
|
w.WriteHeader(http.StatusOK)
|
|
_, _ = io.WriteString(w, "server2")
|
|
}))
|
|
defer server2.Close()
|
|
|
|
cfg := config.DefaultConfig()
|
|
cfg.Gateway.Host = "127.0.0.1"
|
|
cfg.Gateway.Port = mustGatewayTestPort(t, server1.URL)
|
|
if err := config.SaveConfig(configPath, cfg); err != nil {
|
|
t.Fatalf("SaveConfig() error = %v", err)
|
|
}
|
|
cmd := startGatewayLikeProcess(t)
|
|
t.Cleanup(func() {
|
|
if cmd.Process != nil {
|
|
_ = cmd.Process.Kill()
|
|
}
|
|
_ = cmd.Wait()
|
|
})
|
|
writeTestPidFile(t, ppid.PidFileData{
|
|
PID: cmd.Process.Pid,
|
|
Token: "test-token",
|
|
Host: cfg.Gateway.Host,
|
|
Port: cfg.Gateway.Port,
|
|
})
|
|
origPidData := gateway.pidData
|
|
origPicoToken := gateway.picoToken
|
|
t.Cleanup(func() {
|
|
ppid.RemovePidFile(globalConfigDir())
|
|
gateway.pidData = origPidData
|
|
gateway.picoToken = origPicoToken
|
|
})
|
|
|
|
gateway.pidData = &ppid.PidFileData{}
|
|
gateway.picoToken = "pico"
|
|
req1 := newPicoProxyRequest(http.MethodGet, "/pico/ws")
|
|
rec1 := httptest.NewRecorder()
|
|
handler(rec1, req1)
|
|
|
|
if rec1.Code != http.StatusOK {
|
|
t.Fatalf("first status = %d, want %d", rec1.Code, http.StatusOK)
|
|
}
|
|
if body := rec1.Body.String(); body != "server1" {
|
|
t.Fatalf("first body = %q, want %q", body, "server1")
|
|
}
|
|
|
|
cfg.Gateway.Port = mustGatewayTestPort(t, server2.URL)
|
|
if err := config.SaveConfig(configPath, cfg); err != nil {
|
|
t.Fatalf("SaveConfig() error = %v", err)
|
|
}
|
|
|
|
req2 := newPicoProxyRequest(http.MethodGet, "/pico/ws")
|
|
rec2 := httptest.NewRecorder()
|
|
handler(rec2, req2)
|
|
|
|
if rec2.Code != http.StatusOK {
|
|
t.Fatalf("second status = %d, want %d", rec2.Code, http.StatusOK)
|
|
}
|
|
if body := rec2.Body.String(); body != "server2" {
|
|
t.Fatalf("second body = %q, want %q", body, "server2")
|
|
}
|
|
}
|
|
|
|
func TestHandleWebSocketProxyLoadsCachedPicoTokenWhenMissing(t *testing.T) {
|
|
origMatcher := gatewayProcessMatcher
|
|
gatewayProcessMatcher = func(int) (bool, bool) { return true, true }
|
|
t.Cleanup(func() { gatewayProcessMatcher = origMatcher })
|
|
|
|
home := t.TempDir()
|
|
t.Setenv("PICOCLAW_HOME", home)
|
|
|
|
configPath := filepath.Join(t.TempDir(), "config.json")
|
|
h := NewHandler(configPath)
|
|
handler := h.handleWebSocketProxy()
|
|
|
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
if r.URL.Path != "/pico/ws" {
|
|
t.Fatalf("path = %q, want %q", r.URL.Path, "/pico/ws")
|
|
}
|
|
w.WriteHeader(http.StatusOK)
|
|
_, _ = io.WriteString(w, "proxied")
|
|
}))
|
|
defer server.Close()
|
|
|
|
cfg := config.DefaultConfig()
|
|
cfg.Gateway.Host = "127.0.0.1"
|
|
cfg.Gateway.Port = mustGatewayTestPort(t, server.URL)
|
|
bc := cfg.Channels["pico"]
|
|
decoded, err := bc.GetDecoded()
|
|
if err != nil {
|
|
t.Fatalf("GetDecoded() error = %v", err)
|
|
}
|
|
picoCfg := decoded.(*config.PicoSettings)
|
|
bc.Enabled = true
|
|
picoCfg.SetToken("cached-token")
|
|
if err := config.SaveConfig(configPath, cfg); err != nil {
|
|
t.Fatalf("SaveConfig() error = %v", err)
|
|
}
|
|
cmd := startGatewayLikeProcess(t)
|
|
t.Cleanup(func() {
|
|
if cmd.Process != nil {
|
|
_ = cmd.Process.Kill()
|
|
}
|
|
_ = cmd.Wait()
|
|
})
|
|
writeTestPidFile(t, ppid.PidFileData{
|
|
PID: cmd.Process.Pid,
|
|
Token: "test-token",
|
|
Host: cfg.Gateway.Host,
|
|
Port: cfg.Gateway.Port,
|
|
})
|
|
t.Cleanup(func() {
|
|
ppid.RemovePidFile(globalConfigDir())
|
|
})
|
|
|
|
origPidData := gateway.pidData
|
|
origPicoToken := gateway.picoToken
|
|
t.Cleanup(func() {
|
|
gateway.pidData = origPidData
|
|
gateway.picoToken = origPicoToken
|
|
})
|
|
|
|
gateway.pidData = &ppid.PidFileData{}
|
|
gateway.picoToken = ""
|
|
|
|
req := newPicoProxyRequest(http.MethodGet, "/pico/ws?session_id=test-session")
|
|
rec := httptest.NewRecorder()
|
|
handler(rec, req)
|
|
|
|
if rec.Code != http.StatusOK {
|
|
t.Fatalf("status = %d, want %d", rec.Code, http.StatusOK)
|
|
}
|
|
if body := rec.Body.String(); body != "proxied" {
|
|
t.Fatalf("body = %q, want %q", body, "proxied")
|
|
}
|
|
if gateway.picoToken != "cached-token" {
|
|
t.Fatalf("gateway.picoToken = %q, want %q", gateway.picoToken, "cached-token")
|
|
}
|
|
}
|
|
|
|
func TestHandleWebSocketProxyLoadsPidDataOnDemand(t *testing.T) {
|
|
origMatcher := gatewayProcessMatcher
|
|
gatewayProcessMatcher = func(int) (bool, bool) { return true, true }
|
|
t.Cleanup(func() { gatewayProcessMatcher = origMatcher })
|
|
|
|
home := t.TempDir()
|
|
t.Setenv("PICOCLAW_HOME", home)
|
|
|
|
configPath := filepath.Join(t.TempDir(), "config.json")
|
|
h := NewHandler(configPath)
|
|
handler := h.handleWebSocketProxy()
|
|
|
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
if r.URL.Path != "/pico/ws" {
|
|
t.Fatalf("path = %q, want %q", r.URL.Path, "/pico/ws")
|
|
}
|
|
w.WriteHeader(http.StatusOK)
|
|
_, _ = io.WriteString(w, r.Header.Get(protocolKey))
|
|
}))
|
|
defer server.Close()
|
|
|
|
cfg := config.DefaultConfig()
|
|
cfg.Gateway.Host = "127.0.0.1"
|
|
cfg.Gateway.Port = mustGatewayTestPort(t, server.URL)
|
|
bc := cfg.Channels["pico"]
|
|
bc.Enabled = true
|
|
decoded, err := bc.GetDecoded()
|
|
if err != nil {
|
|
t.Fatalf("GetDecoded() error = %v", err)
|
|
}
|
|
decoded.(*config.PicoSettings).SetToken("ui-token")
|
|
if err := config.SaveConfig(configPath, cfg); err != nil {
|
|
t.Fatalf("SaveConfig() error = %v", err)
|
|
}
|
|
|
|
cmd := startGatewayLikeProcess(t)
|
|
t.Cleanup(func() {
|
|
if cmd.Process != nil {
|
|
_ = cmd.Process.Kill()
|
|
}
|
|
_ = cmd.Wait()
|
|
})
|
|
pidData := ppid.PidFileData{
|
|
PID: cmd.Process.Pid,
|
|
Token: "test-token",
|
|
Host: cfg.Gateway.Host,
|
|
Port: cfg.Gateway.Port,
|
|
}
|
|
writeTestPidFile(t, pidData)
|
|
t.Cleanup(func() {
|
|
ppid.RemovePidFile(globalConfigDir())
|
|
})
|
|
|
|
origPidData := gateway.pidData
|
|
origPicoToken := gateway.picoToken
|
|
origStatus := gateway.runtimeStatus
|
|
t.Cleanup(func() {
|
|
gateway.mu.Lock()
|
|
gateway.pidData = origPidData
|
|
gateway.picoToken = origPicoToken
|
|
gateway.runtimeStatus = origStatus
|
|
gateway.mu.Unlock()
|
|
})
|
|
|
|
gateway.mu.Lock()
|
|
gateway.pidData = nil
|
|
gateway.picoToken = ""
|
|
setGatewayRuntimeStatusLocked("stopped")
|
|
gateway.mu.Unlock()
|
|
|
|
req := newPicoProxyRequest(http.MethodGet, "/pico/ws?session_id=test-session")
|
|
rec := httptest.NewRecorder()
|
|
handler(rec, req)
|
|
|
|
if rec.Code != http.StatusOK {
|
|
t.Fatalf("status = %d, want %d", rec.Code, http.StatusOK)
|
|
}
|
|
|
|
expected := tokenPrefix + "ui-token"
|
|
if got := rec.Body.String(); got != expected {
|
|
t.Fatalf("forwarded protocol = %q, want %q", got, expected)
|
|
}
|
|
|
|
gateway.mu.Lock()
|
|
defer gateway.mu.Unlock()
|
|
if gateway.pidData == nil {
|
|
t.Fatal("gateway.pidData should be loaded from pid file")
|
|
}
|
|
if gateway.runtimeStatus != "running" {
|
|
t.Fatalf("runtimeStatus = %q, want %q", gateway.runtimeStatus, "running")
|
|
}
|
|
}
|
|
|
|
func TestCreatePicoHTTPProxyInjectsGatewayAuth(t *testing.T) {
|
|
configPath := filepath.Join(t.TempDir(), "config.json")
|
|
h := NewHandler(configPath)
|
|
|
|
cfg := config.DefaultConfig()
|
|
cfg.Gateway.Host = "127.0.0.1"
|
|
cfg.Gateway.Port = 18790
|
|
bc := cfg.Channels["pico"]
|
|
bc.Enabled = true
|
|
decoded, err := bc.GetDecoded()
|
|
if err != nil {
|
|
t.Fatalf("GetDecoded() error = %v", err)
|
|
}
|
|
decoded.(*config.PicoSettings).SetToken("ui-token")
|
|
if err := config.SaveConfig(configPath, cfg); err != nil {
|
|
t.Fatalf("SaveConfig() error = %v", err)
|
|
}
|
|
|
|
proxy := h.createPicoHTTPProxy("ui-token")
|
|
var capturedPath string
|
|
var capturedAuth string
|
|
proxy.Transport = roundTripFunc(func(req *http.Request) (*http.Response, error) {
|
|
capturedPath = req.URL.Path
|
|
capturedAuth = req.Header.Get("Authorization")
|
|
return &http.Response{
|
|
StatusCode: http.StatusOK,
|
|
Header: make(http.Header),
|
|
Body: io.NopCloser(strings.NewReader("proxied")),
|
|
Request: req,
|
|
}, nil
|
|
})
|
|
|
|
req := httptest.NewRequest(http.MethodGet, "/pico/media/attachment-1", nil)
|
|
rec := httptest.NewRecorder()
|
|
proxy.ServeHTTP(rec, req)
|
|
|
|
if rec.Code != http.StatusOK {
|
|
t.Fatalf("status = %d, want %d", rec.Code, http.StatusOK)
|
|
}
|
|
if capturedPath != "/pico/media/attachment-1" {
|
|
t.Fatalf("capturedPath = %q, want %q", capturedPath, "/pico/media/attachment-1")
|
|
}
|
|
expected := "Bearer ui-token"
|
|
if capturedAuth != expected {
|
|
t.Fatalf("Authorization = %q, want %q", capturedAuth, expected)
|
|
}
|
|
}
|
|
|
|
func TestHandlePicoMediaProxyUsesRawBearerToken(t *testing.T) {
|
|
home := t.TempDir()
|
|
t.Setenv("PICOCLAW_HOME", home)
|
|
|
|
configPath := filepath.Join(t.TempDir(), "config.json")
|
|
h := NewHandler(configPath)
|
|
handler := h.handlePicoMediaProxy()
|
|
|
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
if r.URL.Path != "/pico/media/attachment-1" {
|
|
t.Fatalf("path = %q, want %q", r.URL.Path, "/pico/media/attachment-1")
|
|
}
|
|
if got := r.Header.Get("Authorization"); got != "Bearer ui-token" {
|
|
t.Fatalf("Authorization = %q, want %q", got, "Bearer ui-token")
|
|
}
|
|
w.WriteHeader(http.StatusOK)
|
|
_, _ = io.WriteString(w, "proxied-media")
|
|
}))
|
|
defer server.Close()
|
|
|
|
cfg := config.DefaultConfig()
|
|
cfg.Gateway.Host = "127.0.0.1"
|
|
cfg.Gateway.Port = mustGatewayTestPort(t, server.URL)
|
|
bc := cfg.Channels["pico"]
|
|
bc.Enabled = true
|
|
decoded, err := bc.GetDecoded()
|
|
if err != nil {
|
|
t.Fatalf("GetDecoded() error = %v", err)
|
|
}
|
|
decoded.(*config.PicoSettings).SetToken("ui-token")
|
|
if err := config.SaveConfig(configPath, cfg); err != nil {
|
|
t.Fatalf("SaveConfig() error = %v", err)
|
|
}
|
|
|
|
cmd := startGatewayLikeProcess(t)
|
|
t.Cleanup(func() {
|
|
if cmd.Process != nil {
|
|
_ = cmd.Process.Kill()
|
|
}
|
|
_ = cmd.Wait()
|
|
})
|
|
|
|
origPidData := gateway.pidData
|
|
origPicoToken := gateway.picoToken
|
|
origCmd := gateway.cmd
|
|
t.Cleanup(func() {
|
|
gateway.mu.Lock()
|
|
gateway.pidData = origPidData
|
|
gateway.picoToken = origPicoToken
|
|
gateway.cmd = origCmd
|
|
gateway.mu.Unlock()
|
|
})
|
|
|
|
gateway.mu.Lock()
|
|
gateway.pidData = &ppid.PidFileData{PID: cmd.Process.Pid}
|
|
gateway.picoToken = "ui-token"
|
|
gateway.cmd = cmd
|
|
gateway.mu.Unlock()
|
|
|
|
req := newPicoProxyRequest(http.MethodGet, "/pico/media/attachment-1")
|
|
rec := httptest.NewRecorder()
|
|
handler(rec, req)
|
|
|
|
if rec.Code != http.StatusOK {
|
|
t.Fatalf("status = %d, want %d", rec.Code, http.StatusOK)
|
|
}
|
|
if body := rec.Body.String(); body != "proxied-media" {
|
|
t.Fatalf("body = %q, want %q", body, "proxied-media")
|
|
}
|
|
}
|
|
|
|
func TestHandleWebSocketProxyRejectsStalePidDataAfterProcessExit(t *testing.T) {
|
|
tmpDir := t.TempDir()
|
|
t.Setenv("HOME", tmpDir)
|
|
t.Setenv("PICOCLAW_HOME", filepath.Join(tmpDir, ".picoclaw"))
|
|
|
|
configPath := filepath.Join(tmpDir, "config.json")
|
|
h := NewHandler(configPath)
|
|
handler := h.handleWebSocketProxy()
|
|
|
|
cfg := config.DefaultConfig()
|
|
bc := cfg.Channels["pico"]
|
|
bc.Enabled = true
|
|
decoded, err := bc.GetDecoded()
|
|
if err != nil {
|
|
t.Fatalf("GetDecoded() error = %v", err)
|
|
}
|
|
decoded.(*config.PicoSettings).SetToken("ui-token")
|
|
if err := config.SaveConfig(configPath, cfg); err != nil {
|
|
t.Fatalf("SaveConfig() error = %v", err)
|
|
}
|
|
|
|
cmd := startLongRunningProcess(t)
|
|
if cmd.Process != nil {
|
|
_ = cmd.Process.Kill()
|
|
}
|
|
_ = cmd.Wait()
|
|
|
|
origPidData := gateway.pidData
|
|
origPicoToken := gateway.picoToken
|
|
origCmd := gateway.cmd
|
|
origStatus := gateway.runtimeStatus
|
|
t.Cleanup(func() {
|
|
gateway.mu.Lock()
|
|
gateway.pidData = origPidData
|
|
gateway.picoToken = origPicoToken
|
|
gateway.cmd = origCmd
|
|
gateway.runtimeStatus = origStatus
|
|
gateway.mu.Unlock()
|
|
})
|
|
|
|
gateway.mu.Lock()
|
|
gateway.pidData = &ppid.PidFileData{PID: cmd.Process.Pid, Token: "stale-token"}
|
|
gateway.picoToken = "ui-token"
|
|
gateway.cmd = cmd
|
|
setGatewayRuntimeStatusLocked("running")
|
|
gateway.mu.Unlock()
|
|
|
|
req := newPicoProxyRequest(http.MethodGet, "/pico/ws?session_id=test-session")
|
|
rec := httptest.NewRecorder()
|
|
handler(rec, req)
|
|
|
|
if rec.Code != http.StatusServiceUnavailable {
|
|
t.Fatalf("status = %d, want %d", rec.Code, http.StatusServiceUnavailable)
|
|
}
|
|
gateway.mu.Lock()
|
|
defer gateway.mu.Unlock()
|
|
if gateway.pidData != nil {
|
|
t.Fatal("gateway.pidData should be cleared after stale process exit is detected")
|
|
}
|
|
}
|
|
|
|
func TestHandleWebSocketProxy_AllowsArbitraryOrigin(t *testing.T) {
|
|
origMatcher := gatewayProcessMatcher
|
|
gatewayProcessMatcher = func(int) (bool, bool) { return true, true }
|
|
t.Cleanup(func() { gatewayProcessMatcher = origMatcher })
|
|
|
|
home := t.TempDir()
|
|
t.Setenv("PICOCLAW_HOME", home)
|
|
|
|
configPath := filepath.Join(t.TempDir(), "config.json")
|
|
h := NewHandler(configPath)
|
|
handler := h.handleWebSocketProxy()
|
|
|
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
if r.URL.Path != "/pico/ws" {
|
|
t.Fatalf("path = %q, want %q", r.URL.Path, "/pico/ws")
|
|
}
|
|
w.WriteHeader(http.StatusOK)
|
|
_, _ = io.WriteString(w, "proxied")
|
|
}))
|
|
defer server.Close()
|
|
|
|
cfg := config.DefaultConfig()
|
|
cfg.Gateway.Host = "127.0.0.1"
|
|
cfg.Gateway.Port = mustGatewayTestPort(t, server.URL)
|
|
bc := cfg.Channels["pico"]
|
|
bc.Enabled = true
|
|
decoded, err := bc.GetDecoded()
|
|
if err != nil {
|
|
t.Fatalf("GetDecoded() error = %v", err)
|
|
}
|
|
decoded.(*config.PicoSettings).SetToken("ui-token")
|
|
if err := config.SaveConfig(configPath, cfg); err != nil {
|
|
t.Fatalf("SaveConfig() error = %v", err)
|
|
}
|
|
|
|
cmd := startGatewayLikeProcess(t)
|
|
t.Cleanup(func() {
|
|
if cmd.Process != nil {
|
|
_ = cmd.Process.Kill()
|
|
}
|
|
_ = cmd.Wait()
|
|
})
|
|
writeTestPidFile(t, ppid.PidFileData{
|
|
PID: cmd.Process.Pid,
|
|
Token: "test-token",
|
|
Host: cfg.Gateway.Host,
|
|
Port: cfg.Gateway.Port,
|
|
})
|
|
t.Cleanup(func() {
|
|
ppid.RemovePidFile(globalConfigDir())
|
|
})
|
|
|
|
origPidData := gateway.pidData
|
|
origPicoToken := gateway.picoToken
|
|
t.Cleanup(func() {
|
|
gateway.pidData = origPidData
|
|
gateway.picoToken = origPicoToken
|
|
})
|
|
|
|
gateway.pidData = &ppid.PidFileData{}
|
|
gateway.picoToken = "ui-token"
|
|
|
|
req := httptest.NewRequest(http.MethodGet, "http://launcher.local/pico/ws?session_id=test-session", nil)
|
|
req.Header.Set("Origin", "http://evil.example")
|
|
rec := httptest.NewRecorder()
|
|
handler(rec, req)
|
|
|
|
if rec.Code != http.StatusOK {
|
|
t.Fatalf("status = %d, want %d", rec.Code, http.StatusOK)
|
|
}
|
|
}
|
|
|
|
// mustGatewayTestPort returns the numeric port of rawURL (for example the URL
// of an httptest.Server). It fails the test immediately if the URL cannot be
// parsed or its port component is not a decimal integer.
func mustGatewayTestPort(t *testing.T, rawURL string) int {
	t.Helper()

	u, parseErr := url.Parse(rawURL)
	if parseErr != nil {
		t.Fatalf("url.Parse() error = %v", parseErr)
	}

	portText := u.Port()
	n, convErr := strconv.Atoi(portText)
	if convErr != nil {
		t.Fatalf("Atoi(%q) error = %v", portText, convErr)
	}
	return n
}
|
|
|
|
// roundTripFunc adapts a plain function to the http.RoundTripper interface,
// letting tests stub a client or proxy Transport without any network I/O.
type roundTripFunc func(*http.Request) (*http.Response, error)

// Compile-time check that roundTripFunc satisfies http.RoundTripper.
var _ http.RoundTripper = roundTripFunc(nil)

// RoundTrip hands r to the wrapped function and returns its result as-is.
func (f roundTripFunc) RoundTrip(r *http.Request) (*http.Response, error) {
	return f(r)
}
|