mirror of
https://github.com/sipeed/picoclaw.git
synced 2026-06-12 18:08:54 +00:00
feat(web): migrate launcher to modular web frontend/backend and improve management UX (#1275)
* refactor: remove the legacy picoclaw-launcher * feat: create initial web frontend and backend structure * feat(packaging): add desktop entry for PicoClaw Launcher (#1062) - Add .desktop file with Terminal=true, named "PicoClaw Launcher" - Install to /usr/share/applications/ for app menu visibility - Add 512x512 PNG icon to /usr/share/icons/hicolor/ Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com> * `make dev`: If you haven't built it before, you need to run `build` first. * feat(web): comprehensive web UI and backend refactoring This commit introduces a major overhaul of both the frontend web UI and the Go backend API, transitioning to a highly modular architecture and integrating new core features. Backend: - Refactored monolithic API endpoints into domain-specific modules (config, gateway, log, models, pico, session). - Cleaned up obsolete files (`server.go`, `status.go`, WebSocket handlers) and outdated tests. - Implemented Gateway process lifecycle management (start/stop/restart) and real-time log streaming. Frontend: - Integrated Shadcn UI components to establish a modern, consistent design system. - Introduced a new application layout featuring a responsive sidebar (`app-sidebar`) and header. - Implemented internationalization (i18n) with initial support for English and Chinese. - Restructured API clients, hooks, and Zustand stores into logical domains. - Added new management pages for Settings, Logs, Models, Providers, and Credentials. - Upgraded the Pico chat interface with session history management and dynamic model selection. Build & Config: - Updated frontend dependencies, Vite configuration, and lockfiles. - Refined routing setup and overarching application stylesheets. * feat(web): enhance model management, sorting, and deletion logic - Implement model sorting in UI (default > configured > unconfigured) - Prevent deletion of default models in the frontend - Update backend to clear default settings when a model is deleted - Add existence validation when setting a default model via API - Group models in chat UI by type (API Key, OAuth, Local) - Conditionally display model selector in chat based on configuration status * refactor(web): refactor chat page into modular components/hooks and update i18n - split chat route into dedicated chat components (page, composer, empty state, messages, history, model selector) - extract model/session logic into use-chat-models and use-session-history hooks - update chat locale keys in en/zh and add empty-state/history-related translations * refactor(models): refactor models page into modular components and improve UX - split /models route into dedicated components (page, provider section, card, add/edit sheets, delete dialog) - add provider grouping/sorting, provider labels/icons, and a no-default hint in the models page - add "Set as default model" toggle to add/edit flows with safer defaults - introduce shared form helpers and new UI primitives (field, label, switch) - update i18n strings (en/zh) for models and gateway header text usage - apply minor UI polish (models nav icon, separator client directive) * fix(web): add SPA index fallback for embedded frontend routes Serve existing static assets as-is, keep /api/* and missing asset paths returning 404, and add tests for SPA fallback behavior on refresh. * fix(frontend/chat): normalize message timestamp units to prevent invalid far-future dates * chore: delete TestSPARouteFallsBackToIndex * feat: update build for web-based launcher (#1186) - Makefile: add build-launcher target (builds frontend + Go backend) - GoReleaser: point picoclaw-launcher build to web/backend, add frontend build hook, restore winres hook with updated paths - Restore icon.ico and winres config from main for Windows builds Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com> * feat(credentials): add multi-provider OAuth credential management - add backend `/api/oauth/*` endpoints for provider status, browser/device-code/token login, flow query/polling, and logout - extend API handler with OAuth flow/state tracking and route registration, plus OAuth unit tests - implement frontend credentials page/components for OpenAI, Anthropic, and Google Antigravity login/logout - add OAuth API client and `useCredentialsPage` hook, with new EN/ZH i18n strings * chore: remove placeholder index.html from dist (#1188) The .gitkeep is sufficient for go:embed to find the dist directory. Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com> * fix(frontend): polish model and credential UX; remove Providers nav - remove the Providers item from sidebar navigation and locale keys - simplify chat composer by dropping attach/voice action buttons - support ReactNode titles in credential cards and add provider brand icons - refine sheet header/footer styling and device-code footer button hierarchy - disable “Set default” when a model is unconfigured or already default * feat(web): Update config page (#1173) * feat(web): Update config page * fix(web): useEffect resets editorValue whenever config changes * fix(web): react-hooks/set-state-in-effect error & pnpm lint #1173 * feat(web): add channel management page for web console (#1190) * feat(web): add channel management page for web console Add a complete channel management UI that allows users to configure messaging channels (Telegram, Discord, Slack, Feishu, etc.) directly from the web console instead of manually editing config.json. Backend: GET/PUT/PATCH API endpoints for listing, updating, and toggling channels with secret field masking. Frontend: Channel cards grid with enable/disable toggles, per-channel configuration sheets with dedicated forms for major platforms and a generic fallback for others. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> * fix(web/channels): move channels to own sidebar group and fix sheet padding - Channels now has its own navigation group instead of being under Services - Fix edit sheet form content padding (px-1 -> px-4) to match header/footer - Fix naked return lint error in extractChannelInfo Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> --------- Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com> * fix(web): harden channel config updates and resolve frontend lint issues - validate channel PUT/PATCH updates before saving and return structured validation errors - require `enabled` in toggle requests to avoid silent false defaults - support editing `allow_origins` in the generic channel form and parse string/array inputs on backend - replace channel form `any` usage with `ChannelConfig` (`Record<string, unknown>`) and add safe value helpers - add i18n strings for allow-origins fields and apply related frontend formatting cleanups * fix(frontend): prevent false "Invalid JSON" errors in config editor * feat: add startup readiness checks and propagate start availability to UI - add gateway precondition validation for default model and credentials - auto-start gateway on backend boot when conditions are met - include gateway_start_allowed and gateway_start_reason in status updates - prevent frontend start actions when gateway cannot be started * feat(web): revamp channel config UX with catalog-based routing - replace legacy channel management endpoints with a backend channel catalog API - switch frontend channel updates to PATCH /api/config and per-channel config pages - add dynamic channel items in the sidebar with support for expand/collapse - migrate /channels to nested routes (/channels/$name) and remove old card/sheet flow - improve channel forms with clearer hints, required/error states, and reusable switch cards - fix Discord mention-only toggle to read/write group_trigger.mention_only * refactor(frontend): move shared-form to components and unify default-model switch with SwitchCardField * fix(frontend): improve model form validation and unify secret placeholder handling - block duplicate model aliases when adding a model (with localized error messages) - share masked secret placeholder logic across model and channel forms - refresh gateway state after setting the default model - apply minor UI cleanup to provider icon rendering * feat(web): add visual system config and launcher/autostart controls - add launcher config model and persistence (`launcher-config.json`) for port/public/CIDR settings - add system APIs for launch-at-login and launcher parameters - apply CIDR-based access-control middleware to backend HTTP routes - split config routing into visual config and raw JSON config pages - add frontend system API client and visual config sections for runtime/devices/launcher - expand i18n strings (en/zh) for new config UI - improve sidebar active matching and session ID generation fallback * refactor(frontend): remove i18n fallback strings and drop providers route - Replace `t(key, defaultValue)` calls with key-only translations across UI pages - Clean up locale files by pruning unused keys and adding missing shared keys - Remove the obsolete `/providers` page and update generated route tree * fix(backend): correct gateway status detection on Windows * fix(repo): keep web backend dist placeholder tracked --------- Co-authored-by: Guoguo <16666742+imguoguo@users.noreply.github.com> Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com> Co-authored-by: Dihubopen <dihubcn@gmail.com> Co-authored-by: Dihubopen <130813726+Dihubopen@users.noreply.github.com>
This commit is contained in:
@@ -0,0 +1,19 @@
|
||||
# Go build output
|
||||
*.exe
|
||||
*.dll
|
||||
*.so
|
||||
*.dylib
|
||||
*.test
|
||||
*.out
|
||||
picoclaw-web
|
||||
|
||||
# Frontend build artifacts (embedded by Go)
|
||||
dist/*
|
||||
!dist/.gitkeep
|
||||
|
||||
# OS
|
||||
.DS_Store
|
||||
|
||||
# Editors
|
||||
.vscode/
|
||||
.idea/
|
||||
@@ -0,0 +1,47 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
type channelCatalogItem struct {
|
||||
Name string `json:"name"`
|
||||
ConfigKey string `json:"config_key"`
|
||||
Variant string `json:"variant,omitempty"`
|
||||
}
|
||||
|
||||
var channelCatalog = []channelCatalogItem{
|
||||
{Name: "telegram", ConfigKey: "telegram"},
|
||||
{Name: "discord", ConfigKey: "discord"},
|
||||
{Name: "slack", ConfigKey: "slack"},
|
||||
{Name: "feishu", ConfigKey: "feishu"},
|
||||
{Name: "dingtalk", ConfigKey: "dingtalk"},
|
||||
{Name: "line", ConfigKey: "line"},
|
||||
{Name: "qq", ConfigKey: "qq"},
|
||||
{Name: "onebot", ConfigKey: "onebot"},
|
||||
{Name: "wecom", ConfigKey: "wecom"},
|
||||
{Name: "wecom_app", ConfigKey: "wecom_app"},
|
||||
{Name: "wecom_aibot", ConfigKey: "wecom_aibot"},
|
||||
{Name: "whatsapp", ConfigKey: "whatsapp", Variant: "bridge"},
|
||||
{Name: "whatsapp_native", ConfigKey: "whatsapp", Variant: "native"},
|
||||
{Name: "pico", ConfigKey: "pico"},
|
||||
{Name: "maixcam", ConfigKey: "maixcam"},
|
||||
{Name: "matrix", ConfigKey: "matrix"},
|
||||
{Name: "irc", ConfigKey: "irc"},
|
||||
}
|
||||
|
||||
// registerChannelRoutes binds read-only channel catalog endpoints to the ServeMux.
|
||||
func (h *Handler) registerChannelRoutes(mux *http.ServeMux) {
|
||||
mux.HandleFunc("GET /api/channels/catalog", h.handleListChannelCatalog)
|
||||
}
|
||||
|
||||
// handleListChannelCatalog returns the channels supported by backend.
|
||||
//
|
||||
// GET /api/channels/catalog
|
||||
func (h *Handler) handleListChannelCatalog(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(map[string]any{
|
||||
"channels": channelCatalog,
|
||||
})
|
||||
}
|
||||
@@ -0,0 +1,221 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
)
|
||||
|
||||
// registerConfigRoutes binds configuration management endpoints to the ServeMux.
|
||||
func (h *Handler) registerConfigRoutes(mux *http.ServeMux) {
|
||||
mux.HandleFunc("GET /api/config", h.handleGetConfig)
|
||||
mux.HandleFunc("PUT /api/config", h.handleUpdateConfig)
|
||||
mux.HandleFunc("PATCH /api/config", h.handlePatchConfig)
|
||||
}
|
||||
|
||||
// loadFilteredConfig loads the configuration and filters out default placeholder credentials
|
||||
// (like API limits/keys) if the configuration file has not been created yet by the user.
|
||||
func (h *Handler) loadFilteredConfig() (*config.Config, error) {
|
||||
cfg, err := config.LoadConfig(h.configPath)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
configExists := false
|
||||
if h.configPath != "" {
|
||||
if _, err := os.Stat(h.configPath); err == nil {
|
||||
configExists = true
|
||||
}
|
||||
}
|
||||
|
||||
if !configExists {
|
||||
for i := range cfg.ModelList {
|
||||
cfg.ModelList[i].APIKey = ""
|
||||
cfg.ModelList[i].AuthMethod = ""
|
||||
}
|
||||
}
|
||||
|
||||
return cfg, nil
|
||||
}
|
||||
|
||||
// handleGetConfig returns the complete system configuration.
|
||||
//
|
||||
// GET /api/config
|
||||
func (h *Handler) handleGetConfig(w http.ResponseWriter, r *http.Request) {
|
||||
cfg, err := h.loadFilteredConfig()
|
||||
if err != nil {
|
||||
http.Error(w, fmt.Sprintf("Failed to load config: %v", err), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
if err := json.NewEncoder(w).Encode(cfg); err != nil {
|
||||
http.Error(w, "Failed to encode response", http.StatusInternalServerError)
|
||||
}
|
||||
}
|
||||
|
||||
// handleUpdateConfig updates the complete system configuration.
|
||||
//
|
||||
// PUT /api/config
|
||||
func (h *Handler) handleUpdateConfig(w http.ResponseWriter, r *http.Request) {
|
||||
body, err := io.ReadAll(io.LimitReader(r.Body, 1<<20))
|
||||
if err != nil {
|
||||
http.Error(w, "Failed to read request body", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
defer r.Body.Close()
|
||||
|
||||
var cfg config.Config
|
||||
if err := json.Unmarshal(body, &cfg); err != nil {
|
||||
http.Error(w, fmt.Sprintf("Invalid JSON: %v", err), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
if errs := validateConfig(&cfg); len(errs) > 0 {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
json.NewEncoder(w).Encode(map[string]any{
|
||||
"status": "validation_error",
|
||||
"errors": errs,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
if err := config.SaveConfig(h.configPath, &cfg); err != nil {
|
||||
http.Error(w, fmt.Sprintf("Failed to save config: %v", err), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(map[string]string{"status": "ok"})
|
||||
}
|
||||
|
||||
// handlePatchConfig partially updates the system configuration using JSON Merge Patch (RFC 7396).
|
||||
// Only the fields present in the request body will be updated; all other fields remain unchanged.
|
||||
//
|
||||
// PATCH /api/config
|
||||
func (h *Handler) handlePatchConfig(w http.ResponseWriter, r *http.Request) {
|
||||
patchBody, err := io.ReadAll(io.LimitReader(r.Body, 1<<20))
|
||||
if err != nil {
|
||||
http.Error(w, "Failed to read request body", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
defer r.Body.Close()
|
||||
|
||||
// Validate the patch is valid JSON
|
||||
var patch map[string]any
|
||||
if err = json.Unmarshal(patchBody, &patch); err != nil {
|
||||
http.Error(w, fmt.Sprintf("Invalid JSON: %v", err), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// Load existing config and marshal to a map for merging
|
||||
cfg, err := config.LoadConfig(h.configPath)
|
||||
if err != nil {
|
||||
http.Error(w, fmt.Sprintf("Failed to load config: %v", err), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
existing, err := json.Marshal(cfg)
|
||||
if err != nil {
|
||||
http.Error(w, "Failed to serialize current config", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
var base map[string]any
|
||||
if err = json.Unmarshal(existing, &base); err != nil {
|
||||
http.Error(w, "Failed to parse current config", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// Recursively merge patch into base
|
||||
mergeMap(base, patch)
|
||||
|
||||
// Convert merged map back to Config struct
|
||||
merged, err := json.Marshal(base)
|
||||
if err != nil {
|
||||
http.Error(w, "Failed to serialize merged config", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
var newCfg config.Config
|
||||
if err := json.Unmarshal(merged, &newCfg); err != nil {
|
||||
http.Error(w, fmt.Sprintf("Merged config is invalid: %v", err), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
if errs := validateConfig(&newCfg); len(errs) > 0 {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
json.NewEncoder(w).Encode(map[string]any{
|
||||
"status": "validation_error",
|
||||
"errors": errs,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
if err := config.SaveConfig(h.configPath, &newCfg); err != nil {
|
||||
http.Error(w, fmt.Sprintf("Failed to save config: %v", err), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(map[string]string{"status": "ok"})
|
||||
}
|
||||
|
||||
// validateConfig checks the config for common errors before saving.
|
||||
// Returns a list of human-readable error strings; empty means valid.
|
||||
func validateConfig(cfg *config.Config) []string {
|
||||
var errs []string
|
||||
|
||||
// Validate model_list entries
|
||||
if err := cfg.ValidateModelList(); err != nil {
|
||||
errs = append(errs, err.Error())
|
||||
}
|
||||
|
||||
// Gateway port range
|
||||
if cfg.Gateway.Port != 0 && (cfg.Gateway.Port < 1 || cfg.Gateway.Port > 65535) {
|
||||
errs = append(errs, fmt.Sprintf("gateway.port %d is out of valid range (1-65535)", cfg.Gateway.Port))
|
||||
}
|
||||
|
||||
// Pico channel: token required when enabled
|
||||
if cfg.Channels.Pico.Enabled && cfg.Channels.Pico.Token == "" {
|
||||
errs = append(errs, "channels.pico.token is required when pico channel is enabled")
|
||||
}
|
||||
|
||||
// Telegram: token required when enabled
|
||||
if cfg.Channels.Telegram.Enabled && cfg.Channels.Telegram.Token == "" {
|
||||
errs = append(errs, "channels.telegram.token is required when telegram channel is enabled")
|
||||
}
|
||||
|
||||
// Discord: token required when enabled
|
||||
if cfg.Channels.Discord.Enabled && cfg.Channels.Discord.Token == "" {
|
||||
errs = append(errs, "channels.discord.token is required when discord channel is enabled")
|
||||
}
|
||||
|
||||
return errs
|
||||
}
|
||||
|
||||
// mergeMap recursively merges src into dst (JSON Merge Patch semantics).
|
||||
// - If a key in src has a null value, it is deleted from dst.
|
||||
// - If both dst and src have a nested object for the same key, merge recursively.
|
||||
// - Otherwise the value from src overwrites dst.
|
||||
func mergeMap(dst, src map[string]any) {
|
||||
for key, srcVal := range src {
|
||||
if srcVal == nil {
|
||||
delete(dst, key)
|
||||
continue
|
||||
}
|
||||
srcMap, srcIsMap := srcVal.(map[string]any)
|
||||
dstMap, dstIsMap := dst[key].(map[string]any)
|
||||
if srcIsMap && dstIsMap {
|
||||
mergeMap(dstMap, srcMap)
|
||||
} else {
|
||||
dst[key] = srcVal
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,62 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// GatewayEvent represents a state change event for the gateway process.
|
||||
type GatewayEvent struct {
|
||||
Status string `json:"gateway_status"` // "running", "starting", "stopped", "error"
|
||||
PID int `json:"pid,omitempty"`
|
||||
}
|
||||
|
||||
// EventBroadcaster manages SSE client subscriptions and broadcasts events.
|
||||
type EventBroadcaster struct {
|
||||
mu sync.RWMutex
|
||||
clients map[chan string]struct{}
|
||||
}
|
||||
|
||||
// NewEventBroadcaster creates a new broadcaster.
|
||||
func NewEventBroadcaster() *EventBroadcaster {
|
||||
return &EventBroadcaster{
|
||||
clients: make(map[chan string]struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
// Subscribe adds a new listener channel and returns it.
|
||||
// The caller must call Unsubscribe when done.
|
||||
func (b *EventBroadcaster) Subscribe() chan string {
|
||||
ch := make(chan string, 8)
|
||||
b.mu.Lock()
|
||||
b.clients[ch] = struct{}{}
|
||||
b.mu.Unlock()
|
||||
return ch
|
||||
}
|
||||
|
||||
// Unsubscribe removes a listener channel and closes it.
|
||||
func (b *EventBroadcaster) Unsubscribe(ch chan string) {
|
||||
b.mu.Lock()
|
||||
delete(b.clients, ch)
|
||||
b.mu.Unlock()
|
||||
close(ch)
|
||||
}
|
||||
|
||||
// Broadcast sends a GatewayEvent to all connected SSE clients.
|
||||
func (b *EventBroadcaster) Broadcast(event GatewayEvent) {
|
||||
data, err := json.Marshal(event)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
b.mu.RLock()
|
||||
defer b.mu.RUnlock()
|
||||
|
||||
for ch := range b.clients {
|
||||
// Non-blocking send; drop event if client is slow
|
||||
select {
|
||||
case ch <- string(data):
|
||||
default:
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,555 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
)
|
||||
|
||||
// gateway holds the state for the managed gateway process.
|
||||
var gateway = struct {
|
||||
mu sync.Mutex
|
||||
cmd *exec.Cmd
|
||||
logs *LogBuffer
|
||||
events *EventBroadcaster
|
||||
}{
|
||||
logs: NewLogBuffer(200),
|
||||
events: NewEventBroadcaster(),
|
||||
}
|
||||
|
||||
// registerGatewayRoutes binds gateway lifecycle endpoints to the ServeMux.
|
||||
func (h *Handler) registerGatewayRoutes(mux *http.ServeMux) {
|
||||
mux.HandleFunc("GET /api/gateway/status", h.handleGatewayStatus)
|
||||
mux.HandleFunc("GET /api/gateway/events", h.handleGatewayEvents)
|
||||
mux.HandleFunc("POST /api/gateway/start", h.handleGatewayStart)
|
||||
mux.HandleFunc("POST /api/gateway/stop", h.handleGatewayStop)
|
||||
mux.HandleFunc("POST /api/gateway/restart", h.handleGatewayRestart)
|
||||
}
|
||||
|
||||
// TryAutoStartGateway checks whether gateway start preconditions are met and
|
||||
// starts it when possible. Intended to be called by the backend at startup.
|
||||
func (h *Handler) TryAutoStartGateway() {
|
||||
gateway.mu.Lock()
|
||||
defer gateway.mu.Unlock()
|
||||
|
||||
if isGatewayProcessAliveLocked() {
|
||||
return
|
||||
}
|
||||
if gateway.cmd != nil && gateway.cmd.Process != nil {
|
||||
gateway.cmd = nil
|
||||
}
|
||||
|
||||
ready, reason, err := h.gatewayStartReady()
|
||||
if err != nil {
|
||||
log.Printf("Skip auto-starting gateway: %v", err)
|
||||
return
|
||||
}
|
||||
if !ready {
|
||||
log.Printf("Skip auto-starting gateway: %s", reason)
|
||||
return
|
||||
}
|
||||
|
||||
pid, err := h.startGatewayLocked()
|
||||
if err != nil {
|
||||
log.Printf("Failed to auto-start gateway: %v", err)
|
||||
return
|
||||
}
|
||||
log.Printf("Gateway auto-started (PID: %d)", pid)
|
||||
}
|
||||
|
||||
// gatewayStartReady validates whether current config can start the gateway.
|
||||
func (h *Handler) gatewayStartReady() (bool, string, error) {
|
||||
cfg, err := config.LoadConfig(h.configPath)
|
||||
if err != nil {
|
||||
return false, "", fmt.Errorf("failed to load config: %w", err)
|
||||
}
|
||||
|
||||
modelName := strings.TrimSpace(cfg.Agents.Defaults.GetModelName())
|
||||
if modelName == "" {
|
||||
return false, "no default model configured", nil
|
||||
}
|
||||
|
||||
modelCfg := lookupModelConfig(cfg, modelName)
|
||||
if modelCfg == nil {
|
||||
return false, fmt.Sprintf("default model %q is invalid", modelName), nil
|
||||
}
|
||||
|
||||
hasCredential := strings.TrimSpace(modelCfg.APIKey) != "" ||
|
||||
strings.TrimSpace(modelCfg.AuthMethod) != ""
|
||||
if !hasCredential {
|
||||
return false, fmt.Sprintf("default model %q has no credentials configured", modelName), nil
|
||||
}
|
||||
|
||||
return true, "", nil
|
||||
}
|
||||
|
||||
func lookupModelConfig(cfg *config.Config, modelName string) *config.ModelConfig {
|
||||
modelCfg, err := cfg.GetModelConfig(modelName)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
return modelCfg
|
||||
}
|
||||
|
||||
func isGatewayProcessAliveLocked() bool {
|
||||
return isCmdProcessAliveLocked(gateway.cmd)
|
||||
}
|
||||
|
||||
func isCmdProcessAliveLocked(cmd *exec.Cmd) bool {
|
||||
if cmd == nil || cmd.Process == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
// Wait() sets ProcessState when the process exits; use it when available.
|
||||
if cmd.ProcessState != nil && cmd.ProcessState.Exited() {
|
||||
return false
|
||||
}
|
||||
|
||||
// Windows does not support Signal(0) probing. If we still own cmd and it
|
||||
// has not reported exit, treat it as alive.
|
||||
if runtime.GOOS == "windows" {
|
||||
return true
|
||||
}
|
||||
|
||||
return cmd.Process.Signal(syscall.Signal(0)) == nil
|
||||
}
|
||||
|
||||
func (h *Handler) startGatewayLocked() (int, error) {
|
||||
// Locate the picoclaw executable
|
||||
execPath := findPicoclawBinary()
|
||||
|
||||
cmd := exec.Command(execPath, "gateway")
|
||||
|
||||
stdoutPipe, err := cmd.StdoutPipe()
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("failed to create stdout pipe: %w", err)
|
||||
}
|
||||
|
||||
stderrPipe, err := cmd.StderrPipe()
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("failed to create stderr pipe: %w", err)
|
||||
}
|
||||
|
||||
// Clear old logs for this new run
|
||||
gateway.logs.Reset()
|
||||
|
||||
// Ensure Pico Channel is configured before starting gateway
|
||||
if _, err := h.ensurePicoChannel(); err != nil {
|
||||
log.Printf("Warning: failed to ensure pico channel: %v", err)
|
||||
// Non-fatal: gateway can still start without pico channel
|
||||
}
|
||||
|
||||
if err := cmd.Start(); err != nil {
|
||||
return 0, fmt.Errorf("failed to start gateway: %w", err)
|
||||
}
|
||||
|
||||
gateway.cmd = cmd
|
||||
pid := cmd.Process.Pid
|
||||
log.Printf("Started picoclaw gateway (PID: %d) from %s", pid, execPath)
|
||||
|
||||
// Broadcast starting event
|
||||
gateway.events.Broadcast(GatewayEvent{Status: "starting", PID: pid})
|
||||
|
||||
// Capture stdout/stderr in background
|
||||
go scanPipe(stdoutPipe, gateway.logs)
|
||||
go scanPipe(stderrPipe, gateway.logs)
|
||||
|
||||
// Wait for exit in background and clean up
|
||||
go func() {
|
||||
if err := cmd.Wait(); err != nil {
|
||||
log.Printf("Gateway process exited: %v", err)
|
||||
} else {
|
||||
log.Printf("Gateway process exited normally")
|
||||
}
|
||||
|
||||
gateway.mu.Lock()
|
||||
if gateway.cmd == cmd {
|
||||
gateway.cmd = nil
|
||||
}
|
||||
gateway.mu.Unlock()
|
||||
|
||||
// Broadcast stopped event
|
||||
gateway.events.Broadcast(GatewayEvent{Status: "stopped"})
|
||||
}()
|
||||
|
||||
// Start a goroutine to probe health and broadcast "running" once ready
|
||||
go func() {
|
||||
for i := 0; i < 30; i++ { // try for up to 15 seconds
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
gateway.mu.Lock()
|
||||
stillOurs := gateway.cmd == cmd
|
||||
gateway.mu.Unlock()
|
||||
if !stillOurs {
|
||||
return
|
||||
}
|
||||
cfg, err := config.LoadConfig(h.configPath)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
healthHost := "127.0.0.1"
|
||||
if cfg.Gateway.Host != "" && cfg.Gateway.Host != "0.0.0.0" {
|
||||
healthHost = cfg.Gateway.Host
|
||||
}
|
||||
healthPort := cfg.Gateway.Port
|
||||
if healthPort == 0 {
|
||||
healthPort = 18790
|
||||
}
|
||||
healthURL := fmt.Sprintf("http://%s/health", net.JoinHostPort(healthHost, strconv.Itoa(healthPort)))
|
||||
client := http.Client{Timeout: 1 * time.Second}
|
||||
resp, err := client.Get(healthURL)
|
||||
if err == nil {
|
||||
resp.Body.Close()
|
||||
if resp.StatusCode == http.StatusOK {
|
||||
gateway.events.Broadcast(GatewayEvent{Status: "running", PID: pid})
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
return pid, nil
|
||||
}
|
||||
|
||||
// handleGatewayStart starts the picoclaw gateway subprocess.
|
||||
//
|
||||
// POST /api/gateway/start
|
||||
func (h *Handler) handleGatewayStart(w http.ResponseWriter, r *http.Request) {
|
||||
gateway.mu.Lock()
|
||||
defer gateway.mu.Unlock()
|
||||
|
||||
// Prevent duplicate starts
|
||||
if isGatewayProcessAliveLocked() {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusConflict)
|
||||
json.NewEncoder(w).Encode(map[string]any{
|
||||
"status": "already_running",
|
||||
"pid": gateway.cmd.Process.Pid,
|
||||
})
|
||||
return
|
||||
}
|
||||
if gateway.cmd != nil && gateway.cmd.Process != nil {
|
||||
gateway.cmd = nil
|
||||
}
|
||||
|
||||
ready, reason, err := h.gatewayStartReady()
|
||||
if err != nil {
|
||||
http.Error(
|
||||
w,
|
||||
fmt.Sprintf("Failed to validate gateway start conditions: %v", err),
|
||||
http.StatusInternalServerError,
|
||||
)
|
||||
return
|
||||
}
|
||||
if !ready {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
json.NewEncoder(w).Encode(map[string]any{
|
||||
"status": "precondition_failed",
|
||||
"message": reason,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
pid, err := h.startGatewayLocked()
|
||||
if err != nil {
|
||||
http.Error(w, fmt.Sprintf("Failed to start gateway: %v", err), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(map[string]any{
|
||||
"status": "ok",
|
||||
"pid": pid,
|
||||
})
|
||||
}
|
||||
|
||||
// handleGatewayStop stops the running gateway subprocess gracefully.
|
||||
//
|
||||
// POST /api/gateway/stop
|
||||
func (h *Handler) handleGatewayStop(w http.ResponseWriter, r *http.Request) {
|
||||
gateway.mu.Lock()
|
||||
defer gateway.mu.Unlock()
|
||||
|
||||
if gateway.cmd == nil || gateway.cmd.Process == nil {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(map[string]any{
|
||||
"status": "not_running",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
pid := gateway.cmd.Process.Pid
|
||||
|
||||
// Send SIGTERM for graceful shutdown (SIGKILL on Windows)
|
||||
var sigErr error
|
||||
if runtime.GOOS == "windows" {
|
||||
sigErr = gateway.cmd.Process.Kill()
|
||||
} else {
|
||||
sigErr = gateway.cmd.Process.Signal(syscall.SIGTERM)
|
||||
}
|
||||
|
||||
if sigErr != nil {
|
||||
http.Error(w, fmt.Sprintf("Failed to stop gateway (PID %d): %v", pid, sigErr), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
log.Printf("Sent stop signal to gateway (PID: %d)", pid)
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(map[string]any{
|
||||
"status": "ok",
|
||||
"pid": pid,
|
||||
})
|
||||
}
|
||||
|
||||
// handleGatewayRestart stops the gateway (if running) and starts a new instance.
|
||||
//
|
||||
// POST /api/gateway/restart
|
||||
func (h *Handler) handleGatewayRestart(w http.ResponseWriter, r *http.Request) {
|
||||
gateway.mu.Lock()
|
||||
|
||||
// Stop existing process if running
|
||||
if gateway.cmd != nil && gateway.cmd.Process != nil {
|
||||
if isCmdProcessAliveLocked(gateway.cmd) {
|
||||
// Process is alive, send SIGTERM
|
||||
if runtime.GOOS == "windows" {
|
||||
gateway.cmd.Process.Kill()
|
||||
} else {
|
||||
gateway.cmd.Process.Signal(syscall.SIGTERM)
|
||||
}
|
||||
|
||||
// Wait briefly for it to exit
|
||||
gateway.mu.Unlock()
|
||||
time.Sleep(2 * time.Second)
|
||||
gateway.mu.Lock()
|
||||
}
|
||||
gateway.cmd = nil
|
||||
}
|
||||
|
||||
gateway.mu.Unlock()
|
||||
|
||||
// Start fresh via the existing handler
|
||||
h.handleGatewayStart(w, r)
|
||||
}
|
||||
|
||||
// handleGatewayStatus returns the gateway run status, health info, and logs.
|
||||
//
|
||||
// GET /api/gateway/status
|
||||
func (h *Handler) handleGatewayStatus(w http.ResponseWriter, r *http.Request) {
|
||||
data := map[string]any{}
|
||||
|
||||
// Check process state
|
||||
gateway.mu.Lock()
|
||||
processAlive := isGatewayProcessAliveLocked()
|
||||
if processAlive {
|
||||
data["pid"] = gateway.cmd.Process.Pid
|
||||
}
|
||||
gateway.mu.Unlock()
|
||||
|
||||
if !processAlive {
|
||||
data["gateway_status"] = "stopped"
|
||||
} else {
|
||||
// Process is alive — probe its health endpoint
|
||||
cfg, err := config.LoadConfig(h.configPath)
|
||||
host := "127.0.0.1"
|
||||
port := 18790
|
||||
if err == nil && cfg != nil {
|
||||
if cfg.Gateway.Host != "" && cfg.Gateway.Host != "0.0.0.0" {
|
||||
host = cfg.Gateway.Host
|
||||
}
|
||||
if cfg.Gateway.Port != 0 {
|
||||
port = cfg.Gateway.Port
|
||||
}
|
||||
}
|
||||
|
||||
url := fmt.Sprintf("http://%s/health", net.JoinHostPort(host, strconv.Itoa(port)))
|
||||
client := http.Client{Timeout: 2 * time.Second}
|
||||
resp, err := client.Get(url)
|
||||
|
||||
if err != nil {
|
||||
data["gateway_status"] = "starting"
|
||||
} else {
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
data["gateway_status"] = "error"
|
||||
data["status_code"] = resp.StatusCode
|
||||
} else {
|
||||
var healthData map[string]any
|
||||
if decErr := json.NewDecoder(resp.Body).Decode(&healthData); decErr != nil {
|
||||
data["gateway_status"] = "error"
|
||||
} else {
|
||||
for k, v := range healthData {
|
||||
data[k] = v
|
||||
}
|
||||
data["gateway_status"] = "running"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
ready, reason, readyErr := h.gatewayStartReady()
|
||||
if readyErr != nil {
|
||||
data["gateway_start_allowed"] = false
|
||||
data["gateway_start_reason"] = readyErr.Error()
|
||||
} else {
|
||||
data["gateway_start_allowed"] = ready
|
||||
if !ready {
|
||||
data["gateway_start_reason"] = reason
|
||||
}
|
||||
}
|
||||
|
||||
// Append incremental log data
|
||||
appendGatewayLogs(r, data)
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(data)
|
||||
}
|
||||
|
||||
// appendGatewayLogs reads log_offset and log_run_id query params from the request
|
||||
// and populates the response data map with incremental log lines.
|
||||
func appendGatewayLogs(r *http.Request, data map[string]any) {
|
||||
clientOffset := 0
|
||||
clientRunID := -1
|
||||
|
||||
if v := r.URL.Query().Get("log_offset"); v != "" {
|
||||
if n, err := strconv.Atoi(v); err == nil {
|
||||
clientOffset = n
|
||||
}
|
||||
}
|
||||
|
||||
if v := r.URL.Query().Get("log_run_id"); v != "" {
|
||||
if n, err := strconv.Atoi(v); err == nil {
|
||||
clientRunID = n
|
||||
}
|
||||
}
|
||||
|
||||
runID := gateway.logs.RunID()
|
||||
|
||||
if runID == 0 {
|
||||
data["logs"] = []string{}
|
||||
data["log_total"] = 0
|
||||
data["log_run_id"] = 0
|
||||
return
|
||||
}
|
||||
|
||||
// If runID changed, reset offset to get all logs from new run
|
||||
offset := clientOffset
|
||||
if clientRunID != runID {
|
||||
offset = 0
|
||||
}
|
||||
|
||||
lines, total, runID := gateway.logs.LinesSince(offset)
|
||||
if lines == nil {
|
||||
lines = []string{}
|
||||
}
|
||||
|
||||
data["logs"] = lines
|
||||
data["log_total"] = total
|
||||
data["log_run_id"] = runID
|
||||
}
|
||||
|
||||
// handleGatewayEvents serves an SSE stream of gateway state change events.
|
||||
//
|
||||
// GET /api/gateway/events
|
||||
func (h *Handler) handleGatewayEvents(w http.ResponseWriter, r *http.Request) {
|
||||
flusher, ok := w.(http.Flusher)
|
||||
if !ok {
|
||||
http.Error(w, "SSE not supported", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "text/event-stream")
|
||||
w.Header().Set("Cache-Control", "no-cache")
|
||||
w.Header().Set("Connection", "keep-alive")
|
||||
w.Header().Set("Access-Control-Allow-Origin", "*")
|
||||
|
||||
// Subscribe to gateway events
|
||||
ch := gateway.events.Subscribe()
|
||||
defer gateway.events.Unsubscribe(ch)
|
||||
|
||||
// Send initial status so the client doesn't start blank
|
||||
initial := h.currentGatewayStatus()
|
||||
fmt.Fprintf(w, "data: %s\n\n", initial)
|
||||
flusher.Flush()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-r.Context().Done():
|
||||
return
|
||||
case data, ok := <-ch:
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
fmt.Fprintf(w, "data: %s\n\n", data)
|
||||
flusher.Flush()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// currentGatewayStatus returns the current gateway status as a JSON string.
|
||||
func (h *Handler) currentGatewayStatus() string {
|
||||
gateway.mu.Lock()
|
||||
defer gateway.mu.Unlock()
|
||||
|
||||
data := map[string]any{
|
||||
"gateway_status": "stopped",
|
||||
}
|
||||
if isGatewayProcessAliveLocked() {
|
||||
data["gateway_status"] = "running"
|
||||
data["pid"] = gateway.cmd.Process.Pid
|
||||
}
|
||||
|
||||
ready, reason, readyErr := h.gatewayStartReady()
|
||||
if readyErr != nil {
|
||||
data["gateway_start_allowed"] = false
|
||||
data["gateway_start_reason"] = readyErr.Error()
|
||||
} else {
|
||||
data["gateway_start_allowed"] = ready
|
||||
if !ready {
|
||||
data["gateway_start_reason"] = reason
|
||||
}
|
||||
}
|
||||
|
||||
encoded, _ := json.Marshal(data)
|
||||
return string(encoded)
|
||||
}
|
||||
|
||||
// findPicoclawBinary locates the picoclaw executable.
|
||||
// Tries the same directory as the current executable first, then falls back to $PATH.
|
||||
func findPicoclawBinary() string {
|
||||
if exe, err := os.Executable(); err == nil {
|
||||
dir := filepath.Dir(exe)
|
||||
candidate := filepath.Join(dir, "picoclaw")
|
||||
if runtime.GOOS == "windows" {
|
||||
candidate += ".exe"
|
||||
}
|
||||
if info, err := os.Stat(candidate); err == nil && !info.IsDir() {
|
||||
return candidate
|
||||
}
|
||||
}
|
||||
return "picoclaw"
|
||||
}
|
||||
|
||||
// scanPipe reads lines from r and appends them to buf. Returns when r reaches EOF.
|
||||
func scanPipe(r io.Reader, buf *LogBuffer) {
|
||||
scanner := bufio.NewScanner(r)
|
||||
scanner.Buffer(make([]byte, 0, 64*1024), 1024*1024)
|
||||
for scanner.Scan() {
|
||||
buf.Append(scanner.Text())
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,122 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
)
|
||||
|
||||
func TestGatewayStartReady_NoDefaultModel(t *testing.T) {
|
||||
configPath := filepath.Join(t.TempDir(), "config.json")
|
||||
h := NewHandler(configPath)
|
||||
|
||||
ready, reason, err := h.gatewayStartReady()
|
||||
if err != nil {
|
||||
t.Fatalf("gatewayStartReady() error = %v", err)
|
||||
}
|
||||
if ready {
|
||||
t.Fatalf("gatewayStartReady() ready = true, want false")
|
||||
}
|
||||
if reason != "no default model configured" {
|
||||
t.Fatalf("gatewayStartReady() reason = %q, want %q", reason, "no default model configured")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGatewayStartReady_InvalidDefaultModel(t *testing.T) {
|
||||
configPath := filepath.Join(t.TempDir(), "config.json")
|
||||
cfg := config.DefaultConfig()
|
||||
cfg.Agents.Defaults.Model = "missing-model"
|
||||
if err := config.SaveConfig(configPath, cfg); err != nil {
|
||||
t.Fatalf("SaveConfig() error = %v", err)
|
||||
}
|
||||
|
||||
h := NewHandler(configPath)
|
||||
ready, reason, err := h.gatewayStartReady()
|
||||
if err != nil {
|
||||
t.Fatalf("gatewayStartReady() error = %v", err)
|
||||
}
|
||||
if ready {
|
||||
t.Fatalf("gatewayStartReady() ready = true, want false")
|
||||
}
|
||||
if reason == "" {
|
||||
t.Fatalf("gatewayStartReady() reason is empty")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGatewayStartReady_ValidDefaultModel(t *testing.T) {
|
||||
configPath := filepath.Join(t.TempDir(), "config.json")
|
||||
cfg := config.DefaultConfig()
|
||||
cfg.Agents.Defaults.ModelName = cfg.ModelList[0].ModelName
|
||||
cfg.ModelList[0].APIKey = "test-key"
|
||||
if err := config.SaveConfig(configPath, cfg); err != nil {
|
||||
t.Fatalf("SaveConfig() error = %v", err)
|
||||
}
|
||||
|
||||
h := NewHandler(configPath)
|
||||
ready, reason, err := h.gatewayStartReady()
|
||||
if err != nil {
|
||||
t.Fatalf("gatewayStartReady() error = %v", err)
|
||||
}
|
||||
if !ready {
|
||||
t.Fatalf("gatewayStartReady() ready = false, want true (reason=%q)", reason)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGatewayStartReady_DefaultModelWithoutCredential(t *testing.T) {
|
||||
configPath := filepath.Join(t.TempDir(), "config.json")
|
||||
cfg := config.DefaultConfig()
|
||||
cfg.Agents.Defaults.ModelName = cfg.ModelList[0].ModelName
|
||||
cfg.ModelList[0].APIKey = ""
|
||||
cfg.ModelList[0].AuthMethod = ""
|
||||
if err := config.SaveConfig(configPath, cfg); err != nil {
|
||||
t.Fatalf("SaveConfig() error = %v", err)
|
||||
}
|
||||
|
||||
h := NewHandler(configPath)
|
||||
ready, reason, err := h.gatewayStartReady()
|
||||
if err != nil {
|
||||
t.Fatalf("gatewayStartReady() error = %v", err)
|
||||
}
|
||||
if ready {
|
||||
t.Fatalf("gatewayStartReady() ready = true, want false")
|
||||
}
|
||||
if !strings.Contains(reason, "no credentials configured") {
|
||||
t.Fatalf("gatewayStartReady() reason = %q, want contains %q", reason, "no credentials configured")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGatewayStatusIncludesStartConditionWhenNotReady(t *testing.T) {
|
||||
configPath := filepath.Join(t.TempDir(), "config.json")
|
||||
h := NewHandler(configPath)
|
||||
mux := http.NewServeMux()
|
||||
h.RegisterRoutes(mux)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/gateway/status", nil)
|
||||
mux.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("status = %d, want %d", rec.Code, http.StatusOK)
|
||||
}
|
||||
|
||||
var body map[string]any
|
||||
if err := json.Unmarshal(rec.Body.Bytes(), &body); err != nil {
|
||||
t.Fatalf("unmarshal response: %v", err)
|
||||
}
|
||||
|
||||
allowed, ok := body["gateway_start_allowed"].(bool)
|
||||
if !ok {
|
||||
t.Fatalf("gateway_start_allowed missing or not bool: %#v", body["gateway_start_allowed"])
|
||||
}
|
||||
if allowed {
|
||||
t.Fatalf("gateway_start_allowed = true, want false")
|
||||
}
|
||||
if _, ok := body["gateway_start_reason"].(string); !ok {
|
||||
t.Fatalf("gateway_start_reason missing or not string: %#v", body["gateway_start_reason"])
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,85 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"github.com/sipeed/picoclaw/web/backend/launcherconfig"
|
||||
)
|
||||
|
||||
type launcherConfigPayload struct {
|
||||
Port int `json:"port"`
|
||||
Public bool `json:"public"`
|
||||
AllowedCIDRs []string `json:"allowed_cidrs"`
|
||||
}
|
||||
|
||||
func (h *Handler) registerLauncherConfigRoutes(mux *http.ServeMux) {
|
||||
mux.HandleFunc("GET /api/system/launcher-config", h.handleGetLauncherConfig)
|
||||
mux.HandleFunc("PUT /api/system/launcher-config", h.handleUpdateLauncherConfig)
|
||||
}
|
||||
|
||||
func (h *Handler) launcherConfigPath() string {
|
||||
return launcherconfig.PathForAppConfig(h.configPath)
|
||||
}
|
||||
|
||||
func (h *Handler) launcherFallbackConfig() launcherconfig.Config {
|
||||
port := h.serverPort
|
||||
if port <= 0 {
|
||||
port = launcherconfig.DefaultPort
|
||||
}
|
||||
return launcherconfig.Config{
|
||||
Port: port,
|
||||
Public: h.serverPublic,
|
||||
AllowedCIDRs: append([]string(nil), h.serverCIDRs...),
|
||||
}
|
||||
}
|
||||
|
||||
func (h *Handler) loadLauncherConfig() (launcherconfig.Config, error) {
|
||||
return launcherconfig.Load(h.launcherConfigPath(), h.launcherFallbackConfig())
|
||||
}
|
||||
|
||||
func (h *Handler) handleGetLauncherConfig(w http.ResponseWriter, r *http.Request) {
|
||||
cfg, err := h.loadLauncherConfig()
|
||||
if err != nil {
|
||||
http.Error(w, fmt.Sprintf("Failed to load launcher config: %v", err), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(launcherConfigPayload{
|
||||
Port: cfg.Port,
|
||||
Public: cfg.Public,
|
||||
AllowedCIDRs: append([]string(nil), cfg.AllowedCIDRs...),
|
||||
})
|
||||
}
|
||||
|
||||
func (h *Handler) handleUpdateLauncherConfig(w http.ResponseWriter, r *http.Request) {
|
||||
var payload launcherConfigPayload
|
||||
if err := json.NewDecoder(r.Body).Decode(&payload); err != nil {
|
||||
http.Error(w, fmt.Sprintf("Invalid JSON: %v", err), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
cfg := launcherconfig.Config{
|
||||
Port: payload.Port,
|
||||
Public: payload.Public,
|
||||
AllowedCIDRs: append([]string(nil), payload.AllowedCIDRs...),
|
||||
}
|
||||
if err := launcherconfig.Validate(cfg); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
if err := launcherconfig.Save(h.launcherConfigPath(), cfg); err != nil {
|
||||
http.Error(w, fmt.Sprintf("Failed to save launcher config: %v", err), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(launcherConfigPayload{
|
||||
Port: cfg.Port,
|
||||
Public: cfg.Public,
|
||||
AllowedCIDRs: append([]string(nil), cfg.AllowedCIDRs...),
|
||||
})
|
||||
}
|
||||
@@ -0,0 +1,115 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/sipeed/picoclaw/web/backend/launcherconfig"
|
||||
)
|
||||
|
||||
func TestGetLauncherConfigUsesRuntimeFallback(t *testing.T) {
|
||||
configPath := filepath.Join(t.TempDir(), "config.json")
|
||||
h := NewHandler(configPath)
|
||||
h.SetServerOptions(19999, true, []string{"192.168.1.0/24"})
|
||||
|
||||
mux := http.NewServeMux()
|
||||
h.RegisterRoutes(mux)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/system/launcher-config", nil)
|
||||
mux.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("status = %d, want %d, body=%s", rec.Code, http.StatusOK, rec.Body.String())
|
||||
}
|
||||
|
||||
var got launcherConfigPayload
|
||||
if err := json.Unmarshal(rec.Body.Bytes(), &got); err != nil {
|
||||
t.Fatalf("unmarshal response: %v", err)
|
||||
}
|
||||
if got.Port != 19999 || !got.Public {
|
||||
t.Fatalf("response = %+v, want port=19999 public=true", got)
|
||||
}
|
||||
if len(got.AllowedCIDRs) != 1 || got.AllowedCIDRs[0] != "192.168.1.0/24" {
|
||||
t.Fatalf("response allowed_cidrs = %v, want [192.168.1.0/24]", got.AllowedCIDRs)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPutLauncherConfigPersists(t *testing.T) {
|
||||
configPath := filepath.Join(t.TempDir(), "config.json")
|
||||
h := NewHandler(configPath)
|
||||
|
||||
mux := http.NewServeMux()
|
||||
h.RegisterRoutes(mux)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(
|
||||
http.MethodPut,
|
||||
"/api/system/launcher-config",
|
||||
strings.NewReader(`{"port":18080,"public":true,"allowed_cidrs":["192.168.1.0/24"]}`),
|
||||
)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
mux.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("status = %d, want %d, body=%s", rec.Code, http.StatusOK, rec.Body.String())
|
||||
}
|
||||
|
||||
path := launcherconfig.PathForAppConfig(configPath)
|
||||
cfg, err := launcherconfig.Load(path, launcherconfig.Default())
|
||||
if err != nil {
|
||||
t.Fatalf("launcherconfig.Load() error = %v", err)
|
||||
}
|
||||
if cfg.Port != 18080 || !cfg.Public {
|
||||
t.Fatalf("saved config = %+v, want port=18080 public=true", cfg)
|
||||
}
|
||||
if len(cfg.AllowedCIDRs) != 1 || cfg.AllowedCIDRs[0] != "192.168.1.0/24" {
|
||||
t.Fatalf("saved config allowed_cidrs = %v, want [192.168.1.0/24]", cfg.AllowedCIDRs)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPutLauncherConfigRejectsInvalidPort(t *testing.T) {
|
||||
configPath := filepath.Join(t.TempDir(), "config.json")
|
||||
h := NewHandler(configPath)
|
||||
|
||||
mux := http.NewServeMux()
|
||||
h.RegisterRoutes(mux)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(
|
||||
http.MethodPut,
|
||||
"/api/system/launcher-config",
|
||||
strings.NewReader(`{"port":70000,"public":false}`),
|
||||
)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
mux.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusBadRequest {
|
||||
t.Fatalf("status = %d, want %d, body=%s", rec.Code, http.StatusBadRequest, rec.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestPutLauncherConfigRejectsInvalidCIDR(t *testing.T) {
|
||||
configPath := filepath.Join(t.TempDir(), "config.json")
|
||||
h := NewHandler(configPath)
|
||||
|
||||
mux := http.NewServeMux()
|
||||
h.RegisterRoutes(mux)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(
|
||||
http.MethodPut,
|
||||
"/api/system/launcher-config",
|
||||
strings.NewReader(`{"port":18080,"public":false,"allowed_cidrs":["bad-cidr"]}`),
|
||||
)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
mux.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusBadRequest {
|
||||
t.Fatalf("status = %d, want %d, body=%s", rec.Code, http.StatusBadRequest, rec.Body.String())
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,91 @@
|
||||
package api
|
||||
|
||||
import "sync"
|
||||
|
||||
// LogBuffer is a thread-safe ring buffer that stores the most recent N log lines.
|
||||
// It supports incremental reads via LinesSince and tracks a runID that increments
|
||||
// on each Reset (used to detect gateway restarts).
|
||||
type LogBuffer struct {
|
||||
mu sync.RWMutex
|
||||
lines []string
|
||||
cap int
|
||||
total int // total lines ever appended in current run
|
||||
runID int
|
||||
}
|
||||
|
||||
// NewLogBuffer creates a LogBuffer with the given capacity.
|
||||
func NewLogBuffer(capacity int) *LogBuffer {
|
||||
return &LogBuffer{
|
||||
lines: make([]string, 0, capacity),
|
||||
cap: capacity,
|
||||
}
|
||||
}
|
||||
|
||||
// Append adds a line to the buffer. If the buffer is full, the oldest line is evicted.
|
||||
func (b *LogBuffer) Append(line string) {
|
||||
b.mu.Lock()
|
||||
defer b.mu.Unlock()
|
||||
|
||||
if len(b.lines) < b.cap {
|
||||
b.lines = append(b.lines, line)
|
||||
} else {
|
||||
b.lines[b.total%b.cap] = line
|
||||
}
|
||||
|
||||
b.total++
|
||||
}
|
||||
|
||||
// Reset clears the buffer and increments the runID. Call this when starting a new gateway process.
|
||||
func (b *LogBuffer) Reset() {
|
||||
b.mu.Lock()
|
||||
defer b.mu.Unlock()
|
||||
|
||||
b.lines = b.lines[:0]
|
||||
b.total = 0
|
||||
b.runID++
|
||||
}
|
||||
|
||||
// LinesSince returns lines appended after the given offset, the current total count, and the runID.
|
||||
// If offset >= total, no lines are returned. If offset is too old (evicted), all buffered lines are returned.
|
||||
func (b *LogBuffer) LinesSince(offset int) (lines []string, total int, runID int) {
|
||||
b.mu.RLock()
|
||||
defer b.mu.RUnlock()
|
||||
|
||||
total = b.total
|
||||
runID = b.runID
|
||||
|
||||
if offset >= b.total {
|
||||
return nil, total, runID
|
||||
}
|
||||
|
||||
buffered := len(b.lines)
|
||||
|
||||
// How many new lines since offset
|
||||
newCount := b.total - offset
|
||||
if newCount > buffered {
|
||||
newCount = buffered
|
||||
}
|
||||
|
||||
result := make([]string, newCount)
|
||||
|
||||
if b.total <= b.cap {
|
||||
// Buffer hasn't wrapped yet — simple slice
|
||||
copy(result, b.lines[buffered-newCount:])
|
||||
} else {
|
||||
// Buffer has wrapped — read from ring
|
||||
start := (b.total - newCount) % b.cap
|
||||
for i := range newCount {
|
||||
result[i] = b.lines[(start+i)%b.cap]
|
||||
}
|
||||
}
|
||||
|
||||
return result, total, runID
|
||||
}
|
||||
|
||||
// RunID returns the current run identifier.
|
||||
func (b *LogBuffer) RunID() int {
|
||||
b.mu.RLock()
|
||||
defer b.mu.RUnlock()
|
||||
|
||||
return b.runID
|
||||
}
|
||||
@@ -0,0 +1,298 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strconv"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
)
|
||||
|
||||
// registerModelRoutes binds model list management endpoints to the ServeMux.
|
||||
func (h *Handler) registerModelRoutes(mux *http.ServeMux) {
|
||||
mux.HandleFunc("GET /api/models", h.handleListModels)
|
||||
mux.HandleFunc("POST /api/models", h.handleAddModel)
|
||||
mux.HandleFunc("POST /api/models/default", h.handleSetDefaultModel)
|
||||
mux.HandleFunc("PUT /api/models/{index}", h.handleUpdateModel)
|
||||
mux.HandleFunc("DELETE /api/models/{index}", h.handleDeleteModel)
|
||||
}
|
||||
|
||||
// modelResponse is the JSON structure returned for each model in the list.
|
||||
// All ModelConfig fields are included so the frontend can display and edit them.
|
||||
type modelResponse struct {
|
||||
Index int `json:"index"`
|
||||
ModelName string `json:"model_name"`
|
||||
Model string `json:"model"`
|
||||
APIBase string `json:"api_base,omitempty"`
|
||||
APIKey string `json:"api_key"`
|
||||
Proxy string `json:"proxy,omitempty"`
|
||||
AuthMethod string `json:"auth_method,omitempty"`
|
||||
// Advanced fields
|
||||
ConnectMode string `json:"connect_mode,omitempty"`
|
||||
Workspace string `json:"workspace,omitempty"`
|
||||
RPM int `json:"rpm,omitempty"`
|
||||
MaxTokensField string `json:"max_tokens_field,omitempty"`
|
||||
RequestTimeout int `json:"request_timeout,omitempty"`
|
||||
ThinkingLevel string `json:"thinking_level,omitempty"`
|
||||
// Meta
|
||||
Configured bool `json:"configured"`
|
||||
IsDefault bool `json:"is_default"`
|
||||
}
|
||||
|
||||
// handleListModels returns all model_list entries with masked API keys.
|
||||
//
|
||||
// GET /api/models
|
||||
func (h *Handler) handleListModels(w http.ResponseWriter, r *http.Request) {
|
||||
cfg, err := h.loadFilteredConfig()
|
||||
if err != nil {
|
||||
http.Error(w, fmt.Sprintf("Failed to load config: %v", err), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
defaultModel := cfg.Agents.Defaults.GetModelName()
|
||||
|
||||
models := make([]modelResponse, 0, len(cfg.ModelList))
|
||||
for i, m := range cfg.ModelList {
|
||||
models = append(models, modelResponse{
|
||||
Index: i,
|
||||
ModelName: m.ModelName,
|
||||
Model: m.Model,
|
||||
APIBase: m.APIBase,
|
||||
APIKey: maskAPIKey(m.APIKey),
|
||||
Proxy: m.Proxy,
|
||||
AuthMethod: m.AuthMethod,
|
||||
ConnectMode: m.ConnectMode,
|
||||
Workspace: m.Workspace,
|
||||
RPM: m.RPM,
|
||||
MaxTokensField: m.MaxTokensField,
|
||||
RequestTimeout: m.RequestTimeout,
|
||||
ThinkingLevel: m.ThinkingLevel,
|
||||
Configured: m.APIKey != "" || m.AuthMethod != "",
|
||||
IsDefault: m.ModelName == defaultModel,
|
||||
})
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(map[string]any{
|
||||
"models": models,
|
||||
"total": len(models),
|
||||
"default_model": defaultModel,
|
||||
})
|
||||
}
|
||||
|
||||
// handleAddModel appends a new model configuration entry.
|
||||
//
|
||||
// POST /api/models
|
||||
func (h *Handler) handleAddModel(w http.ResponseWriter, r *http.Request) {
|
||||
body, err := io.ReadAll(io.LimitReader(r.Body, 1<<20))
|
||||
if err != nil {
|
||||
http.Error(w, "Failed to read request body", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
defer r.Body.Close()
|
||||
|
||||
var mc config.ModelConfig
|
||||
if err = json.Unmarshal(body, &mc); err != nil {
|
||||
http.Error(w, fmt.Sprintf("Invalid JSON: %v", err), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
if err = mc.Validate(); err != nil {
|
||||
http.Error(w, fmt.Sprintf("Validation error: %v", err), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
cfg, err := config.LoadConfig(h.configPath)
|
||||
if err != nil {
|
||||
http.Error(w, fmt.Sprintf("Failed to load config: %v", err), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
cfg.ModelList = append(cfg.ModelList, mc)
|
||||
|
||||
if err := config.SaveConfig(h.configPath, cfg); err != nil {
|
||||
http.Error(w, fmt.Sprintf("Failed to save config: %v", err), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(map[string]any{
|
||||
"status": "ok",
|
||||
"index": len(cfg.ModelList) - 1,
|
||||
})
|
||||
}
|
||||
|
||||
// handleUpdateModel replaces a model configuration entry at the given index.
|
||||
// If the request body omits api_key (or sends an empty string), the existing
|
||||
// stored key is preserved so callers can update only api_base / proxy without
|
||||
// exposing or clearing the secret.
|
||||
//
|
||||
// PUT /api/models/{index}
|
||||
func (h *Handler) handleUpdateModel(w http.ResponseWriter, r *http.Request) {
|
||||
idx, err := strconv.Atoi(r.PathValue("index"))
|
||||
if err != nil {
|
||||
http.Error(w, "Invalid index", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
body, err := io.ReadAll(io.LimitReader(r.Body, 1<<20))
|
||||
if err != nil {
|
||||
http.Error(w, "Failed to read request body", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
defer r.Body.Close()
|
||||
|
||||
var mc config.ModelConfig
|
||||
if err = json.Unmarshal(body, &mc); err != nil {
|
||||
http.Error(w, fmt.Sprintf("Invalid JSON: %v", err), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
if err = mc.Validate(); err != nil {
|
||||
http.Error(w, fmt.Sprintf("Validation error: %v", err), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
cfg, err := config.LoadConfig(h.configPath)
|
||||
if err != nil {
|
||||
http.Error(w, fmt.Sprintf("Failed to load config: %v", err), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
if idx < 0 || idx >= len(cfg.ModelList) {
|
||||
http.Error(w, fmt.Sprintf("Index %d out of range (0-%d)", idx, len(cfg.ModelList)-1), http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
|
||||
// Preserve the existing API key when the caller omits it (empty string).
|
||||
// This lets the UI update api_base / proxy without clearing the stored secret.
|
||||
if mc.APIKey == "" {
|
||||
mc.APIKey = cfg.ModelList[idx].APIKey
|
||||
}
|
||||
|
||||
cfg.ModelList[idx] = mc
|
||||
|
||||
if err := config.SaveConfig(h.configPath, cfg); err != nil {
|
||||
http.Error(w, fmt.Sprintf("Failed to save config: %v", err), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(map[string]string{"status": "ok"})
|
||||
}
|
||||
|
||||
// handleDeleteModel removes a model configuration entry at the given index.
|
||||
//
|
||||
// DELETE /api/models/{index}
|
||||
func (h *Handler) handleDeleteModel(w http.ResponseWriter, r *http.Request) {
|
||||
idx, err := strconv.Atoi(r.PathValue("index"))
|
||||
if err != nil {
|
||||
http.Error(w, "Invalid index", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
cfg, err := config.LoadConfig(h.configPath)
|
||||
if err != nil {
|
||||
http.Error(w, fmt.Sprintf("Failed to load config: %v", err), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
if idx < 0 || idx >= len(cfg.ModelList) {
|
||||
http.Error(w, fmt.Sprintf("Index %d out of range (0-%d)", idx, len(cfg.ModelList)-1), http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
|
||||
deletedModelName := cfg.ModelList[idx].ModelName
|
||||
|
||||
cfg.ModelList = append(cfg.ModelList[:idx], cfg.ModelList[idx+1:]...)
|
||||
|
||||
// If the deleted model was the default, clear it.
|
||||
if cfg.Agents.Defaults.ModelName == deletedModelName {
|
||||
cfg.Agents.Defaults.ModelName = ""
|
||||
}
|
||||
if cfg.Agents.Defaults.Model == deletedModelName {
|
||||
cfg.Agents.Defaults.Model = ""
|
||||
}
|
||||
|
||||
if err := config.SaveConfig(h.configPath, cfg); err != nil {
|
||||
http.Error(w, fmt.Sprintf("Failed to save config: %v", err), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(map[string]string{"status": "ok"})
|
||||
}
|
||||
|
||||
// handleSetDefaultModel sets the default model for all agents.
|
||||
//
|
||||
// POST /api/models/default
|
||||
func (h *Handler) handleSetDefaultModel(w http.ResponseWriter, r *http.Request) {
|
||||
body, err := io.ReadAll(io.LimitReader(r.Body, 1<<20))
|
||||
if err != nil {
|
||||
http.Error(w, "Failed to read request body", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
defer r.Body.Close()
|
||||
|
||||
var req struct {
|
||||
ModelName string `json:"model_name"`
|
||||
}
|
||||
if err = json.Unmarshal(body, &req); err != nil {
|
||||
http.Error(w, fmt.Sprintf("Invalid JSON: %v", err), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
if req.ModelName == "" {
|
||||
http.Error(w, "model_name is required", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
cfg, err := config.LoadConfig(h.configPath)
|
||||
if err != nil {
|
||||
http.Error(w, fmt.Sprintf("Failed to load config: %v", err), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// Verify the model_name exists in model_list
|
||||
found := false
|
||||
for _, m := range cfg.ModelList {
|
||||
if m.ModelName == req.ModelName {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
http.Error(w, fmt.Sprintf("Model %q not found in model_list", req.ModelName), http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
|
||||
cfg.Agents.Defaults.ModelName = req.ModelName
|
||||
|
||||
if err := config.SaveConfig(h.configPath, cfg); err != nil {
|
||||
http.Error(w, fmt.Sprintf("Failed to save config: %v", err), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(map[string]string{
|
||||
"status": "ok",
|
||||
"default_model": req.ModelName,
|
||||
})
|
||||
}
|
||||
|
||||
// maskAPIKey returns a masked version of an API key for safe display.
|
||||
// Keys longer than 8 chars show prefix + last 4 chars: "sk-****abcd"
|
||||
// Shorter keys are fully masked as "****".
|
||||
// Empty keys return empty string.
|
||||
func maskAPIKey(key string) string {
|
||||
if key == "" {
|
||||
return ""
|
||||
}
|
||||
if len(key) <= 8 {
|
||||
return "****"
|
||||
}
|
||||
// Show first 3 chars and last 4 chars
|
||||
return key[:3] + "****" + key[len(key)-4:]
|
||||
}
|
||||
@@ -0,0 +1,844 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"html"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/auth"
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
"github.com/sipeed/picoclaw/pkg/providers"
|
||||
)
|
||||
|
||||
const (
|
||||
oauthProviderOpenAI = "openai"
|
||||
oauthProviderAnthropic = "anthropic"
|
||||
oauthProviderGoogleAntigravity = "google-antigravity"
|
||||
|
||||
oauthMethodBrowser = "browser"
|
||||
oauthMethodDeviceCode = "device_code"
|
||||
oauthMethodToken = "token"
|
||||
|
||||
oauthFlowPending = "pending"
|
||||
oauthFlowSuccess = "success"
|
||||
oauthFlowError = "error"
|
||||
oauthFlowExpired = "expired"
|
||||
)
|
||||
|
||||
const (
|
||||
oauthBrowserFlowTTL = 10 * time.Minute
|
||||
oauthDeviceCodeFlowTTL = 15 * time.Minute
|
||||
oauthTerminalFlowGC = 30 * time.Minute
|
||||
)
|
||||
|
||||
var oauthProviderOrder = []string{
|
||||
oauthProviderOpenAI,
|
||||
oauthProviderAnthropic,
|
||||
oauthProviderGoogleAntigravity,
|
||||
}
|
||||
|
||||
var oauthProviderMethods = map[string][]string{
|
||||
oauthProviderOpenAI: {oauthMethodBrowser, oauthMethodDeviceCode, oauthMethodToken},
|
||||
oauthProviderAnthropic: {oauthMethodToken},
|
||||
oauthProviderGoogleAntigravity: {oauthMethodBrowser},
|
||||
}
|
||||
|
||||
var oauthProviderLabels = map[string]string{
|
||||
oauthProviderOpenAI: "OpenAI",
|
||||
oauthProviderAnthropic: "Anthropic",
|
||||
oauthProviderGoogleAntigravity: "Google Antigravity",
|
||||
}
|
||||
|
||||
var (
|
||||
oauthNow = time.Now
|
||||
oauthGeneratePKCE = auth.GeneratePKCE
|
||||
oauthGenerateState = auth.GenerateState
|
||||
oauthBuildAuthorizeURL = auth.BuildAuthorizeURL
|
||||
oauthRequestDeviceCode = auth.RequestDeviceCode
|
||||
oauthPollDeviceCodeOnce = auth.PollDeviceCodeOnce
|
||||
oauthExchangeCodeForTokens = auth.ExchangeCodeForTokens
|
||||
oauthGetCredential = auth.GetCredential
|
||||
oauthSetCredential = auth.SetCredential
|
||||
oauthDeleteCredential = auth.DeleteCredential
|
||||
oauthLoadConfig = config.LoadConfig
|
||||
oauthSaveConfig = config.SaveConfig
|
||||
oauthFetchAntigravityProject = providers.FetchAntigravityProjectID
|
||||
oauthFetchGoogleUserEmailFunc = fetchGoogleUserEmail
|
||||
)
|
||||
|
||||
type oauthFlow struct {
|
||||
ID string
|
||||
Provider string
|
||||
Method string
|
||||
Status string
|
||||
CreatedAt time.Time
|
||||
UpdatedAt time.Time
|
||||
ExpiresAt time.Time
|
||||
Error string
|
||||
CodeVerifier string
|
||||
OAuthState string
|
||||
RedirectURI string
|
||||
DeviceAuthID string
|
||||
UserCode string
|
||||
VerifyURL string
|
||||
Interval int
|
||||
}
|
||||
|
||||
type oauthProviderStatus struct {
|
||||
Provider string `json:"provider"`
|
||||
DisplayName string `json:"display_name"`
|
||||
Methods []string `json:"methods"`
|
||||
LoggedIn bool `json:"logged_in"`
|
||||
Status string `json:"status"`
|
||||
AuthMethod string `json:"auth_method,omitempty"`
|
||||
ExpiresAt string `json:"expires_at,omitempty"`
|
||||
AccountID string `json:"account_id,omitempty"`
|
||||
Email string `json:"email,omitempty"`
|
||||
ProjectID string `json:"project_id,omitempty"`
|
||||
}
|
||||
|
||||
type oauthFlowResponse struct {
|
||||
FlowID string `json:"flow_id"`
|
||||
Provider string `json:"provider"`
|
||||
Method string `json:"method"`
|
||||
Status string `json:"status"`
|
||||
ExpiresAt string `json:"expires_at,omitempty"`
|
||||
Error string `json:"error,omitempty"`
|
||||
UserCode string `json:"user_code,omitempty"`
|
||||
VerifyURL string `json:"verify_url,omitempty"`
|
||||
Interval int `json:"interval,omitempty"`
|
||||
}
|
||||
|
||||
// registerOAuthRoutes binds OAuth login/logout endpoints to the ServeMux.
|
||||
func (h *Handler) registerOAuthRoutes(mux *http.ServeMux) {
|
||||
mux.HandleFunc("GET /api/oauth/providers", h.handleListOAuthProviders)
|
||||
mux.HandleFunc("POST /api/oauth/login", h.handleOAuthLogin)
|
||||
mux.HandleFunc("GET /api/oauth/flows/{id}", h.handleGetOAuthFlow)
|
||||
mux.HandleFunc("POST /api/oauth/flows/{id}/poll", h.handlePollOAuthFlow)
|
||||
mux.HandleFunc("POST /api/oauth/logout", h.handleOAuthLogout)
|
||||
mux.HandleFunc("GET /oauth/callback", h.handleOAuthCallback)
|
||||
}
|
||||
|
||||
func (h *Handler) handleListOAuthProviders(w http.ResponseWriter, r *http.Request) {
|
||||
providersResp := make([]oauthProviderStatus, 0, len(oauthProviderOrder))
|
||||
|
||||
for _, provider := range oauthProviderOrder {
|
||||
cred, err := oauthGetCredential(provider)
|
||||
if err != nil {
|
||||
http.Error(w, fmt.Sprintf("failed to load credentials: %v", err), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
item := oauthProviderStatus{
|
||||
Provider: provider,
|
||||
DisplayName: oauthProviderLabels[provider],
|
||||
Methods: oauthProviderMethods[provider],
|
||||
Status: "not_logged_in",
|
||||
}
|
||||
if cred != nil {
|
||||
item.LoggedIn = true
|
||||
item.AuthMethod = cred.AuthMethod
|
||||
item.AccountID = cred.AccountID
|
||||
item.Email = cred.Email
|
||||
item.ProjectID = cred.ProjectID
|
||||
if !cred.ExpiresAt.IsZero() {
|
||||
item.ExpiresAt = cred.ExpiresAt.Format(time.RFC3339)
|
||||
}
|
||||
switch {
|
||||
case cred.IsExpired():
|
||||
item.Status = "expired"
|
||||
case cred.NeedsRefresh():
|
||||
item.Status = "needs_refresh"
|
||||
default:
|
||||
item.Status = "connected"
|
||||
}
|
||||
}
|
||||
|
||||
providersResp = append(providersResp, item)
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_ = json.NewEncoder(w).Encode(map[string]any{
|
||||
"providers": providersResp,
|
||||
})
|
||||
}
|
||||
|
||||
func (h *Handler) handleOAuthLogin(w http.ResponseWriter, r *http.Request) {
|
||||
body, err := io.ReadAll(io.LimitReader(r.Body, 1<<20))
|
||||
if err != nil {
|
||||
http.Error(w, "failed to read request body", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
defer r.Body.Close()
|
||||
|
||||
var req struct {
|
||||
Provider string `json:"provider"`
|
||||
Method string `json:"method"`
|
||||
Token string `json:"token"`
|
||||
}
|
||||
if err = json.Unmarshal(body, &req); err != nil {
|
||||
http.Error(w, fmt.Sprintf("invalid JSON: %v", err), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
provider, err := normalizeOAuthProvider(req.Provider)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
method := strings.ToLower(strings.TrimSpace(req.Method))
|
||||
if !isOAuthMethodSupported(provider, method) {
|
||||
http.Error(
|
||||
w,
|
||||
fmt.Sprintf("unsupported login method %q for provider %q", method, provider),
|
||||
http.StatusBadRequest,
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
switch method {
|
||||
case oauthMethodToken:
|
||||
token := strings.TrimSpace(req.Token)
|
||||
if token == "" {
|
||||
http.Error(w, "token is required", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
cred := &auth.AuthCredential{
|
||||
AccessToken: token,
|
||||
Provider: provider,
|
||||
AuthMethod: oauthMethodToken,
|
||||
}
|
||||
if err := h.persistCredentialAndConfig(provider, oauthMethodToken, cred); err != nil {
|
||||
http.Error(w, fmt.Sprintf("token login failed: %v", err), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_ = json.NewEncoder(w).Encode(map[string]any{
|
||||
"status": "ok",
|
||||
"provider": provider,
|
||||
"method": method,
|
||||
})
|
||||
return
|
||||
|
||||
case oauthMethodDeviceCode:
|
||||
cfg := auth.OpenAIOAuthConfig()
|
||||
info, err := oauthRequestDeviceCode(cfg)
|
||||
if err != nil {
|
||||
http.Error(w, fmt.Sprintf("failed to request device code: %v", err), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
now := oauthNow()
|
||||
flow := &oauthFlow{
|
||||
ID: newOAuthFlowID(),
|
||||
Provider: provider,
|
||||
Method: method,
|
||||
Status: oauthFlowPending,
|
||||
CreatedAt: now,
|
||||
UpdatedAt: now,
|
||||
ExpiresAt: now.Add(oauthDeviceCodeFlowTTL),
|
||||
DeviceAuthID: info.DeviceAuthID,
|
||||
UserCode: info.UserCode,
|
||||
VerifyURL: info.VerifyURL,
|
||||
Interval: info.Interval,
|
||||
}
|
||||
h.storeOAuthFlow(flow)
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_ = json.NewEncoder(w).Encode(map[string]any{
|
||||
"status": "ok",
|
||||
"provider": provider,
|
||||
"method": method,
|
||||
"flow_id": flow.ID,
|
||||
"user_code": flow.UserCode,
|
||||
"verify_url": flow.VerifyURL,
|
||||
"interval": flow.Interval,
|
||||
"expires_at": flow.ExpiresAt.Format(time.RFC3339),
|
||||
})
|
||||
return
|
||||
|
||||
case oauthMethodBrowser:
|
||||
cfg, err := oauthConfigForProvider(provider)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
pkce, err := oauthGeneratePKCE()
|
||||
if err != nil {
|
||||
http.Error(w, fmt.Sprintf("failed to generate PKCE: %v", err), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
state, err := oauthGenerateState()
|
||||
if err != nil {
|
||||
http.Error(w, fmt.Sprintf("failed to generate state: %v", err), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
redirectURI := buildOAuthRedirectURI(r)
|
||||
authURL := oauthBuildAuthorizeURL(cfg, pkce, state, redirectURI)
|
||||
|
||||
now := oauthNow()
|
||||
flow := &oauthFlow{
|
||||
ID: newOAuthFlowID(),
|
||||
Provider: provider,
|
||||
Method: method,
|
||||
Status: oauthFlowPending,
|
||||
CreatedAt: now,
|
||||
UpdatedAt: now,
|
||||
ExpiresAt: now.Add(oauthBrowserFlowTTL),
|
||||
CodeVerifier: pkce.CodeVerifier,
|
||||
OAuthState: state,
|
||||
RedirectURI: redirectURI,
|
||||
}
|
||||
h.storeOAuthFlow(flow)
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_ = json.NewEncoder(w).Encode(map[string]any{
|
||||
"status": "ok",
|
||||
"provider": provider,
|
||||
"method": method,
|
||||
"flow_id": flow.ID,
|
||||
"auth_url": authURL,
|
||||
"expires_at": flow.ExpiresAt.Format(time.RFC3339),
|
||||
})
|
||||
return
|
||||
default:
|
||||
http.Error(w, "unsupported login method", http.StatusBadRequest)
|
||||
}
|
||||
}
|
||||
|
||||
func (h *Handler) handleGetOAuthFlow(w http.ResponseWriter, r *http.Request) {
|
||||
flowID := strings.TrimSpace(r.PathValue("id"))
|
||||
if flowID == "" {
|
||||
http.Error(w, "missing flow id", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
flow, ok := h.getOAuthFlow(flowID)
|
||||
if !ok {
|
||||
http.Error(w, "flow not found", http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_ = json.NewEncoder(w).Encode(flowToResponse(flow))
|
||||
}
|
||||
|
||||
func (h *Handler) handlePollOAuthFlow(w http.ResponseWriter, r *http.Request) {
|
||||
flowID := strings.TrimSpace(r.PathValue("id"))
|
||||
if flowID == "" {
|
||||
http.Error(w, "missing flow id", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
flow, ok := h.getOAuthFlow(flowID)
|
||||
if !ok {
|
||||
http.Error(w, "flow not found", http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
|
||||
if flow.Method != oauthMethodDeviceCode {
|
||||
http.Error(w, "flow does not support polling", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if flow.Status != oauthFlowPending {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_ = json.NewEncoder(w).Encode(flowToResponse(flow))
|
||||
return
|
||||
}
|
||||
|
||||
cfg := auth.OpenAIOAuthConfig()
|
||||
cred, err := oauthPollDeviceCodeOnce(cfg, flow.DeviceAuthID, flow.UserCode)
|
||||
if err != nil {
|
||||
if strings.Contains(strings.ToLower(err.Error()), "pending") {
|
||||
updated, _ := h.getOAuthFlow(flowID)
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_ = json.NewEncoder(w).Encode(flowToResponse(updated))
|
||||
return
|
||||
}
|
||||
h.setOAuthFlowError(flowID, fmt.Sprintf("device code poll failed: %v", err))
|
||||
updated, _ := h.getOAuthFlow(flowID)
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_ = json.NewEncoder(w).Encode(flowToResponse(updated))
|
||||
return
|
||||
}
|
||||
if cred == nil {
|
||||
updated, _ := h.getOAuthFlow(flowID)
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_ = json.NewEncoder(w).Encode(flowToResponse(updated))
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.persistCredentialAndConfig(flow.Provider, oauthMethodTokenOrOAuth(flow.Method), cred); err != nil {
|
||||
h.setOAuthFlowError(flowID, fmt.Sprintf("failed to save credential: %v", err))
|
||||
updated, _ := h.getOAuthFlow(flowID)
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_ = json.NewEncoder(w).Encode(flowToResponse(updated))
|
||||
return
|
||||
}
|
||||
|
||||
h.setOAuthFlowSuccess(flowID)
|
||||
updated, _ := h.getOAuthFlow(flowID)
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_ = json.NewEncoder(w).Encode(flowToResponse(updated))
|
||||
}
|
||||
|
||||
func (h *Handler) handleOAuthCallback(w http.ResponseWriter, r *http.Request) {
|
||||
state := strings.TrimSpace(r.URL.Query().Get("state"))
|
||||
if state == "" {
|
||||
renderOAuthCallbackPage(w, "", oauthFlowError, "Missing state", "missing_state")
|
||||
return
|
||||
}
|
||||
|
||||
flow, ok := h.getOAuthFlowByState(state)
|
||||
if !ok {
|
||||
renderOAuthCallbackPage(w, "", oauthFlowError, "OAuth flow not found", "flow_not_found")
|
||||
return
|
||||
}
|
||||
|
||||
if flow.Status != oauthFlowPending {
|
||||
renderOAuthCallbackPage(w, flow.ID, flow.Status, "Flow already completed", flow.Error)
|
||||
return
|
||||
}
|
||||
|
||||
if errMsg := strings.TrimSpace(r.URL.Query().Get("error")); errMsg != "" {
|
||||
if desc := strings.TrimSpace(r.URL.Query().Get("error_description")); desc != "" {
|
||||
errMsg += ": " + desc
|
||||
}
|
||||
h.setOAuthFlowError(flow.ID, errMsg)
|
||||
renderOAuthCallbackPage(w, flow.ID, oauthFlowError, "Authorization failed", errMsg)
|
||||
return
|
||||
}
|
||||
|
||||
code := strings.TrimSpace(r.URL.Query().Get("code"))
|
||||
if code == "" {
|
||||
h.setOAuthFlowError(flow.ID, "missing authorization code")
|
||||
renderOAuthCallbackPage(w, flow.ID, oauthFlowError, "Missing authorization code", "missing_code")
|
||||
return
|
||||
}
|
||||
|
||||
cfg, err := oauthConfigForProvider(flow.Provider)
|
||||
if err != nil {
|
||||
h.setOAuthFlowError(flow.ID, err.Error())
|
||||
renderOAuthCallbackPage(w, flow.ID, oauthFlowError, "Unsupported provider", err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
cred, err := oauthExchangeCodeForTokens(cfg, code, flow.CodeVerifier, flow.RedirectURI)
|
||||
if err != nil {
|
||||
h.setOAuthFlowError(flow.ID, fmt.Sprintf("token exchange failed: %v", err))
|
||||
renderOAuthCallbackPage(w, flow.ID, oauthFlowError, "Token exchange failed", err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.persistCredentialAndConfig(flow.Provider, oauthMethodTokenOrOAuth(flow.Method), cred); err != nil {
|
||||
h.setOAuthFlowError(flow.ID, fmt.Sprintf("failed to save credential: %v", err))
|
||||
renderOAuthCallbackPage(w, flow.ID, oauthFlowError, "Failed to save credential", err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
h.setOAuthFlowSuccess(flow.ID)
|
||||
renderOAuthCallbackPage(w, flow.ID, oauthFlowSuccess, "Authentication successful", "")
|
||||
}
|
||||
|
||||
func (h *Handler) handleOAuthLogout(w http.ResponseWriter, r *http.Request) {
|
||||
body, err := io.ReadAll(io.LimitReader(r.Body, 1<<20))
|
||||
if err != nil {
|
||||
http.Error(w, "failed to read request body", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
defer r.Body.Close()
|
||||
|
||||
var req struct {
|
||||
Provider string `json:"provider"`
|
||||
}
|
||||
if err = json.Unmarshal(body, &req); err != nil {
|
||||
http.Error(w, fmt.Sprintf("invalid JSON: %v", err), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
provider, err := normalizeOAuthProvider(req.Provider)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
if err := oauthDeleteCredential(provider); err != nil {
|
||||
http.Error(w, fmt.Sprintf("failed to delete credential: %v", err), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
if err := h.syncProviderAuthMethod(provider, ""); err != nil {
|
||||
http.Error(w, fmt.Sprintf("failed to update config: %v", err), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_ = json.NewEncoder(w).Encode(map[string]any{
|
||||
"status": "ok",
|
||||
"provider": provider,
|
||||
})
|
||||
}
|
||||
|
||||
func renderOAuthCallbackPage(w http.ResponseWriter, flowID, status, title, errMsg string) {
|
||||
payload := map[string]string{
|
||||
"type": "picoclaw-oauth-result",
|
||||
"flowId": flowID,
|
||||
"status": status,
|
||||
}
|
||||
if errMsg != "" {
|
||||
payload["error"] = errMsg
|
||||
}
|
||||
payloadJSON, _ := json.Marshal(payload)
|
||||
|
||||
message := title
|
||||
if errMsg != "" {
|
||||
message = fmt.Sprintf("%s: %s", title, errMsg)
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "text/html; charset=utf-8")
|
||||
if status == oauthFlowSuccess {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
} else {
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
}
|
||||
|
||||
_, _ = fmt.Fprintf(
|
||||
w,
|
||||
"<!doctype html><html><head><meta charset=\"utf-8\"><title>PicoClaw OAuth</title></head><body><script>(function(){var payload=%s;var hasOpener=false;try{if(window.opener&&!window.opener.closed){window.opener.postMessage(payload,window.location.origin);hasOpener=true}}catch(e){}var target='/credentials?oauth_flow_id='+encodeURIComponent(payload.flowId||'')+'&oauth_status='+encodeURIComponent(payload.status||'');setTimeout(function(){if(hasOpener){window.close();return}window.location.replace(target)},800)})();</script><div style=\"font-family:Inter,system-ui,sans-serif;padding:24px\"><h2>%s</h2><p>%s</p><p>You can close this window.</p></div></body></html>",
|
||||
string(payloadJSON),
|
||||
html.EscapeString(title),
|
||||
html.EscapeString(message),
|
||||
)
|
||||
}
|
||||
|
||||
func normalizeOAuthProvider(raw string) (string, error) {
|
||||
provider := strings.ToLower(strings.TrimSpace(raw))
|
||||
switch provider {
|
||||
case "antigravity":
|
||||
return oauthProviderGoogleAntigravity, nil
|
||||
case oauthProviderOpenAI, oauthProviderAnthropic, oauthProviderGoogleAntigravity:
|
||||
return provider, nil
|
||||
default:
|
||||
return "", fmt.Errorf("unsupported provider %q", raw)
|
||||
}
|
||||
}
|
||||
|
||||
func isOAuthMethodSupported(provider, method string) bool {
|
||||
methods := oauthProviderMethods[provider]
|
||||
for _, m := range methods {
|
||||
if m == method {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func oauthConfigForProvider(provider string) (auth.OAuthProviderConfig, error) {
|
||||
switch provider {
|
||||
case oauthProviderOpenAI:
|
||||
return auth.OpenAIOAuthConfig(), nil
|
||||
case oauthProviderGoogleAntigravity:
|
||||
return auth.GoogleAntigravityOAuthConfig(), nil
|
||||
default:
|
||||
return auth.OAuthProviderConfig{}, fmt.Errorf("provider %q does not support browser oauth", provider)
|
||||
}
|
||||
}
|
||||
|
||||
func oauthMethodTokenOrOAuth(method string) string {
|
||||
if method == oauthMethodToken {
|
||||
return oauthMethodToken
|
||||
}
|
||||
return "oauth"
|
||||
}
|
||||
|
||||
func buildOAuthRedirectURI(r *http.Request) string {
|
||||
scheme := "http"
|
||||
if r.TLS != nil {
|
||||
scheme = "https"
|
||||
}
|
||||
if forwarded := strings.TrimSpace(r.Header.Get("X-Forwarded-Proto")); forwarded != "" {
|
||||
scheme = strings.Split(forwarded, ",")[0]
|
||||
}
|
||||
return fmt.Sprintf("%s://%s/oauth/callback", scheme, r.Host)
|
||||
}
|
||||
|
||||
func flowToResponse(flow *oauthFlow) oauthFlowResponse {
|
||||
resp := oauthFlowResponse{
|
||||
FlowID: flow.ID,
|
||||
Provider: flow.Provider,
|
||||
Method: flow.Method,
|
||||
Status: flow.Status,
|
||||
Error: flow.Error,
|
||||
}
|
||||
if !flow.ExpiresAt.IsZero() {
|
||||
resp.ExpiresAt = flow.ExpiresAt.Format(time.RFC3339)
|
||||
}
|
||||
if flow.Method == oauthMethodDeviceCode {
|
||||
resp.UserCode = flow.UserCode
|
||||
resp.VerifyURL = flow.VerifyURL
|
||||
resp.Interval = flow.Interval
|
||||
}
|
||||
return resp
|
||||
}
|
||||
|
||||
func newOAuthFlowID() string {
|
||||
buf := make([]byte, 16)
|
||||
if _, err := rand.Read(buf); err != nil {
|
||||
return fmt.Sprintf("oauth_%d", time.Now().UnixNano())
|
||||
}
|
||||
return hex.EncodeToString(buf)
|
||||
}
|
||||
|
||||
func (h *Handler) storeOAuthFlow(flow *oauthFlow) {
|
||||
now := oauthNow()
|
||||
h.oauthMu.Lock()
|
||||
defer h.oauthMu.Unlock()
|
||||
|
||||
h.gcOAuthFlowsLocked(now)
|
||||
h.oauthFlows[flow.ID] = flow
|
||||
if flow.OAuthState != "" {
|
||||
h.oauthState[flow.OAuthState] = flow.ID
|
||||
}
|
||||
}
|
||||
|
||||
func (h *Handler) getOAuthFlow(flowID string) (*oauthFlow, bool) {
|
||||
now := oauthNow()
|
||||
h.oauthMu.Lock()
|
||||
defer h.oauthMu.Unlock()
|
||||
|
||||
h.gcOAuthFlowsLocked(now)
|
||||
flow, ok := h.oauthFlows[flowID]
|
||||
if !ok {
|
||||
return nil, false
|
||||
}
|
||||
cp := *flow
|
||||
return &cp, true
|
||||
}
|
||||
|
||||
func (h *Handler) getOAuthFlowByState(state string) (*oauthFlow, bool) {
|
||||
now := oauthNow()
|
||||
h.oauthMu.Lock()
|
||||
defer h.oauthMu.Unlock()
|
||||
|
||||
h.gcOAuthFlowsLocked(now)
|
||||
flowID, ok := h.oauthState[state]
|
||||
if !ok {
|
||||
return nil, false
|
||||
}
|
||||
flow, ok := h.oauthFlows[flowID]
|
||||
if !ok {
|
||||
delete(h.oauthState, state)
|
||||
return nil, false
|
||||
}
|
||||
cp := *flow
|
||||
return &cp, true
|
||||
}
|
||||
|
||||
func (h *Handler) setOAuthFlowSuccess(flowID string) {
|
||||
now := oauthNow()
|
||||
h.oauthMu.Lock()
|
||||
defer h.oauthMu.Unlock()
|
||||
|
||||
flow, ok := h.oauthFlows[flowID]
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
flow.Status = oauthFlowSuccess
|
||||
flow.Error = ""
|
||||
flow.UpdatedAt = now
|
||||
if flow.OAuthState != "" {
|
||||
delete(h.oauthState, flow.OAuthState)
|
||||
}
|
||||
}
|
||||
|
||||
func (h *Handler) setOAuthFlowError(flowID, errMsg string) {
|
||||
now := oauthNow()
|
||||
h.oauthMu.Lock()
|
||||
defer h.oauthMu.Unlock()
|
||||
|
||||
flow, ok := h.oauthFlows[flowID]
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
flow.Status = oauthFlowError
|
||||
flow.Error = errMsg
|
||||
flow.UpdatedAt = now
|
||||
if flow.OAuthState != "" {
|
||||
delete(h.oauthState, flow.OAuthState)
|
||||
}
|
||||
}
|
||||
|
||||
func (h *Handler) gcOAuthFlowsLocked(now time.Time) {
|
||||
for id, flow := range h.oauthFlows {
|
||||
if flow.Status == oauthFlowPending && !flow.ExpiresAt.IsZero() && now.After(flow.ExpiresAt) {
|
||||
flow.Status = oauthFlowExpired
|
||||
flow.Error = "flow expired"
|
||||
flow.UpdatedAt = now
|
||||
if flow.OAuthState != "" {
|
||||
delete(h.oauthState, flow.OAuthState)
|
||||
}
|
||||
}
|
||||
|
||||
if flow.Status != oauthFlowPending && now.Sub(flow.UpdatedAt) > oauthTerminalFlowGC {
|
||||
if flow.OAuthState != "" {
|
||||
delete(h.oauthState, flow.OAuthState)
|
||||
}
|
||||
delete(h.oauthFlows, id)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (h *Handler) persistCredentialAndConfig(provider, authMethod string, cred *auth.AuthCredential) error {
|
||||
if cred == nil {
|
||||
return fmt.Errorf("empty credential")
|
||||
}
|
||||
|
||||
cp := *cred
|
||||
cp.Provider = provider
|
||||
if cp.AuthMethod == "" {
|
||||
cp.AuthMethod = authMethod
|
||||
}
|
||||
|
||||
if provider == oauthProviderGoogleAntigravity {
|
||||
if cp.Email == "" {
|
||||
email, err := oauthFetchGoogleUserEmailFunc(cp.AccessToken)
|
||||
if err != nil {
|
||||
log.Printf("oauth warning: could not fetch google email: %v", err)
|
||||
} else {
|
||||
cp.Email = email
|
||||
}
|
||||
}
|
||||
if cp.ProjectID == "" {
|
||||
projectID, err := oauthFetchAntigravityProject(cp.AccessToken)
|
||||
if err != nil {
|
||||
log.Printf("oauth warning: could not fetch antigravity project id: %v", err)
|
||||
} else {
|
||||
cp.ProjectID = projectID
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if err := oauthSetCredential(provider, &cp); err != nil {
|
||||
return fmt.Errorf("saving credential: %w", err)
|
||||
}
|
||||
if err := h.syncProviderAuthMethod(provider, authMethod); err != nil {
|
||||
return fmt.Errorf("syncing provider auth config: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *Handler) syncProviderAuthMethod(provider, authMethod string) error {
|
||||
cfg, err := oauthLoadConfig(h.configPath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
switch provider {
|
||||
case oauthProviderOpenAI:
|
||||
cfg.Providers.OpenAI.AuthMethod = authMethod
|
||||
case oauthProviderAnthropic:
|
||||
cfg.Providers.Anthropic.AuthMethod = authMethod
|
||||
case oauthProviderGoogleAntigravity:
|
||||
cfg.Providers.Antigravity.AuthMethod = authMethod
|
||||
default:
|
||||
return fmt.Errorf("unsupported provider %q", provider)
|
||||
}
|
||||
|
||||
found := false
|
||||
for i := range cfg.ModelList {
|
||||
if modelBelongsToProvider(provider, cfg.ModelList[i].Model) {
|
||||
cfg.ModelList[i].AuthMethod = authMethod
|
||||
found = true
|
||||
}
|
||||
}
|
||||
|
||||
if !found && authMethod != "" {
|
||||
cfg.ModelList = append(cfg.ModelList, defaultModelConfigForProvider(provider, authMethod))
|
||||
}
|
||||
|
||||
return oauthSaveConfig(h.configPath, cfg)
|
||||
}
|
||||
|
||||
func modelBelongsToProvider(provider, model string) bool {
|
||||
lower := strings.ToLower(strings.TrimSpace(model))
|
||||
switch provider {
|
||||
case oauthProviderOpenAI:
|
||||
return lower == "openai" || strings.HasPrefix(lower, "openai/")
|
||||
case oauthProviderAnthropic:
|
||||
return lower == "anthropic" || strings.HasPrefix(lower, "anthropic/")
|
||||
case oauthProviderGoogleAntigravity:
|
||||
return lower == "antigravity" ||
|
||||
lower == "google-antigravity" ||
|
||||
strings.HasPrefix(lower, "antigravity/") ||
|
||||
strings.HasPrefix(lower, "google-antigravity/")
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func defaultModelConfigForProvider(provider, authMethod string) config.ModelConfig {
|
||||
switch provider {
|
||||
case oauthProviderOpenAI:
|
||||
return config.ModelConfig{
|
||||
ModelName: "gpt-5.2",
|
||||
Model: "openai/gpt-5.2",
|
||||
AuthMethod: authMethod,
|
||||
}
|
||||
case oauthProviderAnthropic:
|
||||
return config.ModelConfig{
|
||||
ModelName: "claude-sonnet-4.6",
|
||||
Model: "anthropic/claude-sonnet-4.6",
|
||||
AuthMethod: authMethod,
|
||||
}
|
||||
case oauthProviderGoogleAntigravity:
|
||||
return config.ModelConfig{
|
||||
ModelName: "gemini-flash",
|
||||
Model: "antigravity/gemini-3-flash",
|
||||
AuthMethod: authMethod,
|
||||
}
|
||||
default:
|
||||
return config.ModelConfig{}
|
||||
}
|
||||
}
|
||||
|
||||
func fetchGoogleUserEmail(accessToken string) (string, error) {
|
||||
req, err := http.NewRequest(http.MethodGet, "https://www.googleapis.com/oauth2/v2/userinfo", nil)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
req.Header.Set("Authorization", "Bearer "+accessToken)
|
||||
|
||||
client := &http.Client{Timeout: 10 * time.Second}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return "", fmt.Errorf("userinfo request failed: %s", string(body))
|
||||
}
|
||||
|
||||
var userInfo struct {
|
||||
Email string `json:"email"`
|
||||
}
|
||||
if err := json.Unmarshal(body, &userInfo); err != nil {
|
||||
return "", err
|
||||
}
|
||||
if userInfo.Email == "" {
|
||||
return "", fmt.Errorf("empty email in userinfo response")
|
||||
}
|
||||
return userInfo.Email, nil
|
||||
}
|
||||
@@ -0,0 +1,293 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/auth"
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
)
|
||||
|
||||
func TestOAuthLoginRejectsUnsupportedMethod(t *testing.T) {
|
||||
configPath, cleanup := setupOAuthTestEnv(t)
|
||||
defer cleanup()
|
||||
resetOAuthHooks(t)
|
||||
|
||||
h := NewHandler(configPath)
|
||||
mux := http.NewServeMux()
|
||||
h.RegisterRoutes(mux)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(
|
||||
http.MethodPost,
|
||||
"/api/oauth/login",
|
||||
strings.NewReader(`{"provider":"anthropic","method":"browser"}`),
|
||||
)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
mux.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusBadRequest {
|
||||
t.Fatalf("status = %d, want %d, body=%s", rec.Code, http.StatusBadRequest, rec.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestOAuthBrowserFlowCreatedAndQueried(t *testing.T) {
|
||||
configPath, cleanup := setupOAuthTestEnv(t)
|
||||
defer cleanup()
|
||||
resetOAuthHooks(t)
|
||||
|
||||
oauthGeneratePKCE = func() (auth.PKCECodes, error) {
|
||||
return auth.PKCECodes{CodeVerifier: "verifier-1", CodeChallenge: "challenge-1"}, nil
|
||||
}
|
||||
oauthGenerateState = func() (string, error) { return "state-1", nil }
|
||||
oauthBuildAuthorizeURL = func(cfg auth.OAuthProviderConfig, pkce auth.PKCECodes, state, redirectURI string) string {
|
||||
return "https://example.com/authorize?state=" + state
|
||||
}
|
||||
|
||||
h := NewHandler(configPath)
|
||||
mux := http.NewServeMux()
|
||||
h.RegisterRoutes(mux)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(
|
||||
http.MethodPost,
|
||||
"/api/oauth/login",
|
||||
strings.NewReader(`{"provider":"openai","method":"browser"}`),
|
||||
)
|
||||
req.Host = "localhost:18800"
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
mux.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("status = %d, want %d, body=%s", rec.Code, http.StatusOK, rec.Body.String())
|
||||
}
|
||||
|
||||
var loginResp map[string]any
|
||||
if err := json.Unmarshal(rec.Body.Bytes(), &loginResp); err != nil {
|
||||
t.Fatalf("unmarshal login response: %v", err)
|
||||
}
|
||||
flowID, _ := loginResp["flow_id"].(string)
|
||||
if flowID == "" {
|
||||
t.Fatalf("flow_id is empty: %v", loginResp)
|
||||
}
|
||||
if loginResp["auth_url"] != "https://example.com/authorize?state=state-1" {
|
||||
t.Fatalf("unexpected auth_url: %v", loginResp["auth_url"])
|
||||
}
|
||||
|
||||
rec2 := httptest.NewRecorder()
|
||||
req2 := httptest.NewRequest(http.MethodGet, "/api/oauth/flows/"+flowID, nil)
|
||||
mux.ServeHTTP(rec2, req2)
|
||||
if rec2.Code != http.StatusOK {
|
||||
t.Fatalf("flow status code = %d, want %d, body=%s", rec2.Code, http.StatusOK, rec2.Body.String())
|
||||
}
|
||||
var flowResp oauthFlowResponse
|
||||
if err := json.Unmarshal(rec2.Body.Bytes(), &flowResp); err != nil {
|
||||
t.Fatalf("unmarshal flow response: %v", err)
|
||||
}
|
||||
if flowResp.Status != oauthFlowPending {
|
||||
t.Fatalf("flow status = %q, want %q", flowResp.Status, oauthFlowPending)
|
||||
}
|
||||
if flowResp.Method != oauthMethodBrowser {
|
||||
t.Fatalf("flow method = %q, want %q", flowResp.Method, oauthMethodBrowser)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOAuthFlowExpiresWhenQueried(t *testing.T) {
|
||||
configPath, cleanup := setupOAuthTestEnv(t)
|
||||
defer cleanup()
|
||||
resetOAuthHooks(t)
|
||||
|
||||
now := time.Date(2026, 3, 6, 12, 0, 0, 0, time.UTC)
|
||||
oauthNow = func() time.Time { return now }
|
||||
|
||||
h := NewHandler(configPath)
|
||||
h.storeOAuthFlow(&oauthFlow{
|
||||
ID: "expired-flow",
|
||||
Provider: oauthProviderOpenAI,
|
||||
Method: oauthMethodBrowser,
|
||||
Status: oauthFlowPending,
|
||||
CreatedAt: now.Add(-20 * time.Minute),
|
||||
UpdatedAt: now.Add(-20 * time.Minute),
|
||||
ExpiresAt: now.Add(-1 * time.Minute),
|
||||
})
|
||||
|
||||
mux := http.NewServeMux()
|
||||
h.RegisterRoutes(mux)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/oauth/flows/expired-flow", nil)
|
||||
mux.ServeHTTP(rec, req)
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("status = %d, want %d, body=%s", rec.Code, http.StatusOK, rec.Body.String())
|
||||
}
|
||||
var flowResp oauthFlowResponse
|
||||
if err := json.Unmarshal(rec.Body.Bytes(), &flowResp); err != nil {
|
||||
t.Fatalf("unmarshal flow response: %v", err)
|
||||
}
|
||||
if flowResp.Status != oauthFlowExpired {
|
||||
t.Fatalf("flow status = %q, want %q", flowResp.Status, oauthFlowExpired)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOAuthCallbackUnknownState(t *testing.T) {
|
||||
configPath, cleanup := setupOAuthTestEnv(t)
|
||||
defer cleanup()
|
||||
resetOAuthHooks(t)
|
||||
|
||||
h := NewHandler(configPath)
|
||||
mux := http.NewServeMux()
|
||||
h.RegisterRoutes(mux)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/oauth/callback?state=unknown&code=abc", nil)
|
||||
mux.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusBadRequest {
|
||||
t.Fatalf("status = %d, want %d", rec.Code, http.StatusBadRequest)
|
||||
}
|
||||
if !strings.Contains(rec.Body.String(), "OAuth flow not found") {
|
||||
t.Fatalf("unexpected body: %s", rec.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestOAuthLogoutClearsCredentialAndConfig(t *testing.T) {
|
||||
configPath, cleanup := setupOAuthTestEnv(t)
|
||||
defer cleanup()
|
||||
resetOAuthHooks(t)
|
||||
|
||||
cfg, err := config.LoadConfig(configPath)
|
||||
if err != nil {
|
||||
t.Fatalf("LoadConfig error: %v", err)
|
||||
}
|
||||
cfg.Providers.OpenAI.AuthMethod = "oauth"
|
||||
cfg.ModelList = append(cfg.ModelList, config.ModelConfig{
|
||||
ModelName: "gpt-5.2",
|
||||
Model: "openai/gpt-5.2",
|
||||
AuthMethod: "oauth",
|
||||
})
|
||||
if err = config.SaveConfig(configPath, cfg); err != nil {
|
||||
t.Fatalf("SaveConfig error: %v", err)
|
||||
}
|
||||
if err = auth.SetCredential(oauthProviderOpenAI, &auth.AuthCredential{
|
||||
AccessToken: "token-before-logout",
|
||||
Provider: oauthProviderOpenAI,
|
||||
AuthMethod: "oauth",
|
||||
}); err != nil {
|
||||
t.Fatalf("SetCredential error: %v", err)
|
||||
}
|
||||
|
||||
h := NewHandler(configPath)
|
||||
mux := http.NewServeMux()
|
||||
h.RegisterRoutes(mux)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/oauth/logout", bytes.NewBufferString(`{"provider":"openai"}`))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
mux.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("status = %d, want %d, body=%s", rec.Code, http.StatusOK, rec.Body.String())
|
||||
}
|
||||
|
||||
cred, err := auth.GetCredential(oauthProviderOpenAI)
|
||||
if err != nil {
|
||||
t.Fatalf("GetCredential error: %v", err)
|
||||
}
|
||||
if cred != nil {
|
||||
t.Fatalf("expected credential deleted, got %#v", cred)
|
||||
}
|
||||
|
||||
updated, err := config.LoadConfig(configPath)
|
||||
if err != nil {
|
||||
t.Fatalf("LoadConfig error: %v", err)
|
||||
}
|
||||
if updated.Providers.OpenAI.AuthMethod != "" {
|
||||
t.Fatalf("providers.openai.auth_method = %q, want empty", updated.Providers.OpenAI.AuthMethod)
|
||||
}
|
||||
for _, m := range updated.ModelList {
|
||||
if strings.HasPrefix(m.Model, "openai/") && m.AuthMethod != "" {
|
||||
t.Fatalf("openai model auth_method = %q, want empty", m.AuthMethod)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func setupOAuthTestEnv(t *testing.T) (string, func()) {
|
||||
t.Helper()
|
||||
|
||||
tmp := t.TempDir()
|
||||
oldHome := os.Getenv("HOME")
|
||||
oldPicoHome := os.Getenv("PICOCLAW_HOME")
|
||||
|
||||
if err := os.Setenv("HOME", tmp); err != nil {
|
||||
t.Fatalf("set HOME: %v", err)
|
||||
}
|
||||
if err := os.Setenv("PICOCLAW_HOME", filepath.Join(tmp, ".picoclaw")); err != nil {
|
||||
t.Fatalf("set PICOCLAW_HOME: %v", err)
|
||||
}
|
||||
|
||||
cfg := config.DefaultConfig()
|
||||
cfg.ModelList = []config.ModelConfig{{
|
||||
ModelName: "custom-default",
|
||||
Model: "openai/gpt-4o",
|
||||
APIKey: "sk-default",
|
||||
}}
|
||||
cfg.Agents.Defaults.ModelName = "custom-default"
|
||||
|
||||
configPath := filepath.Join(tmp, "config.json")
|
||||
if err := config.SaveConfig(configPath, cfg); err != nil {
|
||||
t.Fatalf("SaveConfig error: %v", err)
|
||||
}
|
||||
|
||||
cleanup := func() {
|
||||
_ = os.Setenv("HOME", oldHome)
|
||||
if oldPicoHome == "" {
|
||||
_ = os.Unsetenv("PICOCLAW_HOME")
|
||||
} else {
|
||||
_ = os.Setenv("PICOCLAW_HOME", oldPicoHome)
|
||||
}
|
||||
}
|
||||
return configPath, cleanup
|
||||
}
|
||||
|
||||
func resetOAuthHooks(t *testing.T) {
|
||||
t.Helper()
|
||||
|
||||
origNow := oauthNow
|
||||
origGeneratePKCE := oauthGeneratePKCE
|
||||
origGenerateState := oauthGenerateState
|
||||
origBuildAuthorizeURL := oauthBuildAuthorizeURL
|
||||
origRequestDeviceCode := oauthRequestDeviceCode
|
||||
origPollDeviceCodeOnce := oauthPollDeviceCodeOnce
|
||||
origExchangeCodeForTokens := oauthExchangeCodeForTokens
|
||||
origGetCredential := oauthGetCredential
|
||||
origSetCredential := oauthSetCredential
|
||||
origDeleteCredential := oauthDeleteCredential
|
||||
origLoadConfig := oauthLoadConfig
|
||||
origSaveConfig := oauthSaveConfig
|
||||
origFetchProject := oauthFetchAntigravityProject
|
||||
origFetchGoogleEmail := oauthFetchGoogleUserEmailFunc
|
||||
|
||||
t.Cleanup(func() {
|
||||
oauthNow = origNow
|
||||
oauthGeneratePKCE = origGeneratePKCE
|
||||
oauthGenerateState = origGenerateState
|
||||
oauthBuildAuthorizeURL = origBuildAuthorizeURL
|
||||
oauthRequestDeviceCode = origRequestDeviceCode
|
||||
oauthPollDeviceCodeOnce = origPollDeviceCodeOnce
|
||||
oauthExchangeCodeForTokens = origExchangeCodeForTokens
|
||||
oauthGetCredential = origGetCredential
|
||||
oauthSetCredential = origSetCredential
|
||||
oauthDeleteCredential = origDeleteCredential
|
||||
oauthLoadConfig = origLoadConfig
|
||||
oauthSaveConfig = origSaveConfig
|
||||
oauthFetchAntigravityProject = origFetchProject
|
||||
oauthFetchGoogleUserEmailFunc = origFetchGoogleEmail
|
||||
})
|
||||
}
|
||||
@@ -0,0 +1,161 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
)
|
||||
|
||||
// registerPicoRoutes binds Pico Channel management endpoints to the ServeMux.
|
||||
func (h *Handler) registerPicoRoutes(mux *http.ServeMux) {
|
||||
mux.HandleFunc("GET /api/pico/token", h.handleGetPicoToken)
|
||||
mux.HandleFunc("POST /api/pico/token", h.handleRegenPicoToken)
|
||||
mux.HandleFunc("POST /api/pico/setup", h.handlePicoSetup)
|
||||
}
|
||||
|
||||
// handleGetPicoToken returns the current WS token and URL for the frontend.
|
||||
//
|
||||
// GET /api/pico/token
|
||||
func (h *Handler) handleGetPicoToken(w http.ResponseWriter, r *http.Request) {
|
||||
cfg, err := config.LoadConfig(h.configPath)
|
||||
if err != nil {
|
||||
http.Error(w, fmt.Sprintf("Failed to load config: %v", err), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
wsURL := buildWsURL(r, cfg)
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(map[string]any{
|
||||
"token": cfg.Channels.Pico.Token,
|
||||
"ws_url": wsURL,
|
||||
"enabled": cfg.Channels.Pico.Enabled,
|
||||
})
|
||||
}
|
||||
|
||||
// handleRegenPicoToken generates a new Pico WebSocket token and saves it.
|
||||
//
|
||||
// POST /api/pico/token
|
||||
func (h *Handler) handleRegenPicoToken(w http.ResponseWriter, r *http.Request) {
|
||||
cfg, err := config.LoadConfig(h.configPath)
|
||||
if err != nil {
|
||||
http.Error(w, fmt.Sprintf("Failed to load config: %v", err), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
token := generateSecureToken()
|
||||
cfg.Channels.Pico.Token = token
|
||||
|
||||
if err := config.SaveConfig(h.configPath, cfg); err != nil {
|
||||
http.Error(w, fmt.Sprintf("Failed to save config: %v", err), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
wsURL := fmt.Sprintf("ws://%s/pico/ws", net.JoinHostPort(cfg.Gateway.Host, strconv.Itoa(cfg.Gateway.Port)))
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(map[string]any{
|
||||
"token": token,
|
||||
"ws_url": wsURL,
|
||||
})
|
||||
}
|
||||
|
||||
// ensurePicoChannel checks if the Pico Channel is properly configured and
|
||||
// enables it with sensible defaults if not. Returns true if config was changed.
|
||||
func (h *Handler) ensurePicoChannel() (bool, error) {
|
||||
cfg, err := config.LoadConfig(h.configPath)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("failed to load config: %w", err)
|
||||
}
|
||||
|
||||
changed := false
|
||||
|
||||
if !cfg.Channels.Pico.Enabled {
|
||||
cfg.Channels.Pico.Enabled = true
|
||||
changed = true
|
||||
}
|
||||
|
||||
if cfg.Channels.Pico.Token == "" {
|
||||
cfg.Channels.Pico.Token = generateSecureToken()
|
||||
changed = true
|
||||
}
|
||||
|
||||
if !cfg.Channels.Pico.AllowTokenQuery {
|
||||
cfg.Channels.Pico.AllowTokenQuery = true
|
||||
changed = true
|
||||
}
|
||||
|
||||
// Make sure origins are allowed (frontend might be running on a different port like 5173 during dev)
|
||||
if len(cfg.Channels.Pico.AllowOrigins) == 0 {
|
||||
cfg.Channels.Pico.AllowOrigins = []string{"*"}
|
||||
changed = true
|
||||
}
|
||||
|
||||
if changed {
|
||||
if err := config.SaveConfig(h.configPath, cfg); err != nil {
|
||||
return false, fmt.Errorf("failed to save config: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return changed, nil
|
||||
}
|
||||
|
||||
// handlePicoSetup automatically configures everything needed for the Pico Channel to work.
|
||||
//
|
||||
// POST /api/pico/setup
|
||||
func (h *Handler) handlePicoSetup(w http.ResponseWriter, r *http.Request) {
|
||||
changed, err := h.ensurePicoChannel()
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
cfg, err := config.LoadConfig(h.configPath)
|
||||
if err != nil {
|
||||
http.Error(w, fmt.Sprintf("Failed to load config: %v", err), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
wsURL := buildWsURL(r, cfg)
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(map[string]any{
|
||||
"token": cfg.Channels.Pico.Token,
|
||||
"ws_url": wsURL,
|
||||
"enabled": true,
|
||||
"changed": changed,
|
||||
})
|
||||
}
|
||||
|
||||
// buildWsURL creates a WebSocket URL for the Pico Channel.
|
||||
// When the gateway host is "0.0.0.0" or empty, it uses the hostname from the
|
||||
// incoming HTTP request so the browser gets a connectable address.
|
||||
func buildWsURL(r *http.Request, cfg *config.Config) string {
|
||||
host := cfg.Gateway.Host
|
||||
if host == "" || host == "0.0.0.0" {
|
||||
// Use the hostname the browser used to reach this backend
|
||||
reqHost, _, err := net.SplitHostPort(r.Host)
|
||||
if err != nil {
|
||||
reqHost = r.Host // r.Host might not have a port
|
||||
}
|
||||
host = reqHost
|
||||
}
|
||||
return "ws://" + net.JoinHostPort(host, strconv.Itoa(cfg.Gateway.Port)) + "/pico/ws"
|
||||
}
|
||||
|
||||
// generateSecureToken creates a random 32-character hex string.
|
||||
func generateSecureToken() string {
|
||||
b := make([]byte, 16)
|
||||
if _, err := rand.Read(b); err != nil {
|
||||
// Fallback to something pseudo-random if crypto/rand fails
|
||||
return fmt.Sprintf("pico_%x", time.Now().UnixNano())
|
||||
}
|
||||
return hex.EncodeToString(b)
|
||||
}
|
||||
@@ -0,0 +1,66 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"sync"
|
||||
|
||||
"github.com/sipeed/picoclaw/web/backend/launcherconfig"
|
||||
)
|
||||
|
||||
// Handler serves HTTP API requests.
|
||||
type Handler struct {
|
||||
configPath string
|
||||
serverPort int
|
||||
serverPublic bool
|
||||
serverCIDRs []string
|
||||
oauthMu sync.Mutex
|
||||
oauthFlows map[string]*oauthFlow
|
||||
oauthState map[string]string
|
||||
}
|
||||
|
||||
// NewHandler creates an instance of the API handler.
|
||||
func NewHandler(configPath string) *Handler {
|
||||
return &Handler{
|
||||
configPath: configPath,
|
||||
serverPort: launcherconfig.DefaultPort,
|
||||
oauthFlows: make(map[string]*oauthFlow),
|
||||
oauthState: make(map[string]string),
|
||||
}
|
||||
}
|
||||
|
||||
// SetServerOptions stores current backend listen options for fallback behavior.
|
||||
func (h *Handler) SetServerOptions(port int, public bool, allowedCIDRs []string) {
|
||||
h.serverPort = port
|
||||
h.serverPublic = public
|
||||
h.serverCIDRs = append([]string(nil), allowedCIDRs...)
|
||||
}
|
||||
|
||||
// RegisterRoutes binds all API endpoint handlers to the ServeMux.
|
||||
func (h *Handler) RegisterRoutes(mux *http.ServeMux) {
|
||||
// Config CRUD
|
||||
h.registerConfigRoutes(mux)
|
||||
|
||||
// Pico Channel (WebSocket chat)
|
||||
h.registerPicoRoutes(mux)
|
||||
|
||||
// Gateway process lifecycle
|
||||
h.registerGatewayRoutes(mux)
|
||||
|
||||
// Session history
|
||||
h.registerSessionRoutes(mux)
|
||||
|
||||
// OAuth login and credential management
|
||||
h.registerOAuthRoutes(mux)
|
||||
|
||||
// Model list management
|
||||
h.registerModelRoutes(mux)
|
||||
|
||||
// Channel catalog (for frontend navigation/config pages)
|
||||
h.registerChannelRoutes(mux)
|
||||
|
||||
// OS startup / launch-at-login
|
||||
h.registerStartupRoutes(mux)
|
||||
|
||||
// Launcher service parameters (port/public)
|
||||
h.registerLauncherConfigRoutes(mux)
|
||||
}
|
||||
@@ -0,0 +1,286 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
"github.com/sipeed/picoclaw/pkg/providers"
|
||||
)
|
||||
|
||||
// registerSessionRoutes binds session list and detail endpoints to the ServeMux.
|
||||
func (h *Handler) registerSessionRoutes(mux *http.ServeMux) {
|
||||
mux.HandleFunc("GET /api/sessions", h.handleListSessions)
|
||||
mux.HandleFunc("GET /api/sessions/{id}", h.handleGetSession)
|
||||
mux.HandleFunc("DELETE /api/sessions/{id}", h.handleDeleteSession)
|
||||
}
|
||||
|
||||
// sessionFile mirrors the on-disk session JSON structure from pkg/session.
|
||||
type sessionFile struct {
|
||||
Key string `json:"key"`
|
||||
Messages []providers.Message `json:"messages"`
|
||||
Summary string `json:"summary,omitempty"`
|
||||
Created time.Time `json:"created"`
|
||||
Updated time.Time `json:"updated"`
|
||||
}
|
||||
|
||||
// sessionListItem is a lightweight summary returned by GET /api/sessions.
|
||||
type sessionListItem struct {
|
||||
ID string `json:"id"`
|
||||
Preview string `json:"preview"`
|
||||
MessageCount int `json:"message_count"`
|
||||
Created string `json:"created"`
|
||||
Updated string `json:"updated"`
|
||||
}
|
||||
|
||||
// picoSessionPrefix is the key prefix used by the gateway's routing for Pico
|
||||
// channel sessions. The full key format is:
|
||||
//
|
||||
// agent:main:pico:direct:pico:<session-uuid>
|
||||
//
|
||||
// The sanitized filename replaces ':' with '_', so on disk it becomes:
|
||||
//
|
||||
// agent_main_pico_direct_pico_<session-uuid>.json
|
||||
const picoSessionPrefix = "agent:main:pico:direct:pico:"
|
||||
|
||||
// extractPicoSessionID extracts the session UUID from a full session key.
|
||||
// Returns the UUID and true if the key matches the Pico session pattern.
|
||||
func extractPicoSessionID(key string) (string, bool) {
|
||||
if strings.HasPrefix(key, picoSessionPrefix) {
|
||||
return strings.TrimPrefix(key, picoSessionPrefix), true
|
||||
}
|
||||
return "", false
|
||||
}
|
||||
|
||||
// sessionsDir resolves the path to the gateway's session storage directory.
|
||||
// It reads the workspace from config, falling back to ~/.picoclaw/workspace.
|
||||
func (h *Handler) sessionsDir() (string, error) {
|
||||
cfg, err := config.LoadConfig(h.configPath)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
workspace := cfg.Agents.Defaults.Workspace
|
||||
if workspace == "" {
|
||||
home, _ := os.UserHomeDir()
|
||||
workspace = filepath.Join(home, ".picoclaw", "workspace")
|
||||
}
|
||||
|
||||
// Expand ~ prefix
|
||||
if len(workspace) > 0 && workspace[0] == '~' {
|
||||
home, _ := os.UserHomeDir()
|
||||
if len(workspace) > 1 && workspace[1] == '/' {
|
||||
workspace = home + workspace[1:]
|
||||
} else {
|
||||
workspace = home
|
||||
}
|
||||
}
|
||||
|
||||
return filepath.Join(workspace, "sessions"), nil
|
||||
}
|
||||
|
||||
// handleListSessions returns a list of Pico session summaries.
|
||||
//
|
||||
// GET /api/sessions
|
||||
func (h *Handler) handleListSessions(w http.ResponseWriter, r *http.Request) {
|
||||
dir, err := h.sessionsDir()
|
||||
if err != nil {
|
||||
http.Error(w, "failed to resolve sessions directory", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
entries, err := os.ReadDir(dir)
|
||||
if err != nil {
|
||||
// Directory doesn't exist yet = no sessions
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode([]sessionListItem{})
|
||||
return
|
||||
}
|
||||
|
||||
items := []sessionListItem{}
|
||||
|
||||
for _, entry := range entries {
|
||||
if entry.IsDir() || filepath.Ext(entry.Name()) != ".json" {
|
||||
continue
|
||||
}
|
||||
|
||||
data, err := os.ReadFile(filepath.Join(dir, entry.Name()))
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
var sess sessionFile
|
||||
if err := json.Unmarshal(data, &sess); err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
// Only include Pico channel sessions
|
||||
sessionID, ok := extractPicoSessionID(sess.Key)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
// Build a preview from the first user message
|
||||
preview := ""
|
||||
for _, msg := range sess.Messages {
|
||||
if msg.Role == "user" && strings.TrimSpace(msg.Content) != "" {
|
||||
preview = msg.Content
|
||||
break
|
||||
}
|
||||
}
|
||||
if len([]rune(preview)) > 60 {
|
||||
preview = string([]rune(preview)[:60]) + "..."
|
||||
}
|
||||
if preview == "" {
|
||||
preview = "(empty)"
|
||||
}
|
||||
|
||||
// Only count non-empty user and assistant messages
|
||||
validMessageCount := 0
|
||||
for _, msg := range sess.Messages {
|
||||
if (msg.Role == "user" || msg.Role == "assistant") && strings.TrimSpace(msg.Content) != "" {
|
||||
validMessageCount++
|
||||
}
|
||||
}
|
||||
|
||||
items = append(items, sessionListItem{
|
||||
ID: sessionID,
|
||||
Preview: preview,
|
||||
MessageCount: validMessageCount,
|
||||
Created: sess.Created.Format(time.RFC3339),
|
||||
Updated: sess.Updated.Format(time.RFC3339),
|
||||
})
|
||||
}
|
||||
|
||||
// Sort by updated descending (most recent first)
|
||||
sort.Slice(items, func(i, j int) bool {
|
||||
return items[i].Updated > items[j].Updated
|
||||
})
|
||||
|
||||
// Pagination parameters
|
||||
offsetStr := r.URL.Query().Get("offset")
|
||||
limitStr := r.URL.Query().Get("limit")
|
||||
|
||||
offset := 0
|
||||
limit := 20 // Default limit
|
||||
|
||||
if val, err := strconv.Atoi(offsetStr); err == nil && val >= 0 {
|
||||
offset = val
|
||||
}
|
||||
if val, err := strconv.Atoi(limitStr); err == nil && val > 0 {
|
||||
limit = val
|
||||
}
|
||||
|
||||
totalItems := len(items)
|
||||
|
||||
end := offset + limit
|
||||
if offset >= totalItems {
|
||||
items = []sessionListItem{} // Out of bounds, return empty
|
||||
} else {
|
||||
if end > totalItems {
|
||||
end = totalItems
|
||||
}
|
||||
items = items[offset:end]
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(items)
|
||||
}
|
||||
|
||||
// handleGetSession returns the full message history for a specific session.
|
||||
//
|
||||
// GET /api/sessions/{id}
|
||||
func (h *Handler) handleGetSession(w http.ResponseWriter, r *http.Request) {
|
||||
sessionID := r.PathValue("id")
|
||||
if sessionID == "" {
|
||||
http.Error(w, "missing session id", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
dir, err := h.sessionsDir()
|
||||
if err != nil {
|
||||
http.Error(w, "failed to resolve sessions directory", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// The sanitized filename replaces ':' with '_':
|
||||
// agent:main:pico:direct:pico:<uuid> -> agent_main_pico_direct_pico_<uuid>.json
|
||||
filename := strings.ReplaceAll(picoSessionPrefix+sessionID, ":", "_") + ".json"
|
||||
|
||||
data, err := os.ReadFile(filepath.Join(dir, filename))
|
||||
if err != nil {
|
||||
http.Error(w, "session not found", http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
|
||||
var sess sessionFile
|
||||
if err := json.Unmarshal(data, &sess); err != nil {
|
||||
http.Error(w, "failed to parse session", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// Convert to a simpler format for the frontend
|
||||
type chatMessage struct {
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content"`
|
||||
}
|
||||
|
||||
messages := make([]chatMessage, 0, len(sess.Messages))
|
||||
for _, msg := range sess.Messages {
|
||||
// Only include user and assistant messages that have actual content
|
||||
if (msg.Role == "user" || msg.Role == "assistant") && strings.TrimSpace(msg.Content) != "" {
|
||||
messages = append(messages, chatMessage{
|
||||
Role: msg.Role,
|
||||
Content: msg.Content,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(map[string]any{
|
||||
"id": sessionID,
|
||||
"messages": messages,
|
||||
"summary": sess.Summary,
|
||||
"created": sess.Created.Format(time.RFC3339),
|
||||
"updated": sess.Updated.Format(time.RFC3339),
|
||||
})
|
||||
}
|
||||
|
||||
// handleDeleteSession deletes a specific session.
|
||||
//
|
||||
// DELETE /api/sessions/{id}
|
||||
func (h *Handler) handleDeleteSession(w http.ResponseWriter, r *http.Request) {
|
||||
sessionID := r.PathValue("id")
|
||||
if sessionID == "" {
|
||||
http.Error(w, "missing session id", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
dir, err := h.sessionsDir()
|
||||
if err != nil {
|
||||
http.Error(w, "failed to resolve sessions directory", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// The sanitized filename replaces ':' with '_':
|
||||
// agent:main:pico:direct:pico:<uuid> -> agent_main_pico_direct_pico_<uuid>.json
|
||||
filename := strings.ReplaceAll(picoSessionPrefix+sessionID, ":", "_") + ".json"
|
||||
filePath := filepath.Join(dir, filename)
|
||||
|
||||
if err := os.Remove(filePath); err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
http.Error(w, "session not found", http.StatusNotFound)
|
||||
} else {
|
||||
http.Error(w, "failed to delete session", http.StatusInternalServerError)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
}
|
||||
@@ -0,0 +1,305 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const (
|
||||
autoStartEntryName = "PicoClawLauncher"
|
||||
launchAgentLabel = "io.picoclaw.launcher"
|
||||
)
|
||||
|
||||
type autoStartRequest struct {
|
||||
Enabled bool `json:"enabled"`
|
||||
}
|
||||
|
||||
type autoStartResponse struct {
|
||||
Enabled bool `json:"enabled"`
|
||||
Supported bool `json:"supported"`
|
||||
Platform string `json:"platform"`
|
||||
Message string `json:"message,omitempty"`
|
||||
}
|
||||
|
||||
var errAutoStartUnsupported = errors.New("autostart is not supported on this platform")
|
||||
|
||||
func (h *Handler) registerStartupRoutes(mux *http.ServeMux) {
|
||||
mux.HandleFunc("GET /api/system/autostart", h.handleGetAutoStart)
|
||||
mux.HandleFunc("PUT /api/system/autostart", h.handleSetAutoStart)
|
||||
}
|
||||
|
||||
func (h *Handler) handleGetAutoStart(w http.ResponseWriter, r *http.Request) {
|
||||
enabled, supported, message, err := h.getAutoStartStatus()
|
||||
if err != nil {
|
||||
http.Error(w, fmt.Sprintf("Failed to read startup setting: %v", err), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(autoStartResponse{
|
||||
Enabled: enabled,
|
||||
Supported: supported,
|
||||
Platform: runtime.GOOS,
|
||||
Message: message,
|
||||
})
|
||||
}
|
||||
|
||||
func (h *Handler) handleSetAutoStart(w http.ResponseWriter, r *http.Request) {
|
||||
var req autoStartRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
http.Error(w, fmt.Sprintf("Invalid JSON: %v", err), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.setAutoStart(req.Enabled); err != nil {
|
||||
if errors.Is(err, errAutoStartUnsupported) {
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
http.Error(w, fmt.Sprintf("Failed to update startup setting: %v", err), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
enabled, supported, message, err := h.getAutoStartStatus()
|
||||
if err != nil {
|
||||
http.Error(w, fmt.Sprintf("Failed to verify startup setting: %v", err), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(autoStartResponse{
|
||||
Enabled: enabled,
|
||||
Supported: supported,
|
||||
Platform: runtime.GOOS,
|
||||
Message: message,
|
||||
})
|
||||
}
|
||||
|
||||
func (h *Handler) resolveLaunchCommand() (string, []string, error) {
|
||||
exePath, err := os.Executable()
|
||||
if err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
|
||||
args := []string{"-no-browser"}
|
||||
if h.configPath != "" {
|
||||
args = append(args, h.configPath)
|
||||
}
|
||||
|
||||
return exePath, args, nil
|
||||
}
|
||||
|
||||
func (h *Handler) getAutoStartStatus() (enabled bool, supported bool, message string, err error) {
|
||||
switch runtime.GOOS {
|
||||
case "darwin":
|
||||
exists, err := fileExists(macLaunchAgentPath())
|
||||
return exists, true, "Changes apply on next login.", err
|
||||
case "linux":
|
||||
exists, err := fileExists(linuxAutoStartPath())
|
||||
return exists, true, "Changes apply on next login.", err
|
||||
case "windows":
|
||||
exists, err := windowsRunKeyExists()
|
||||
return exists, true, "Changes apply on next login.", err
|
||||
default:
|
||||
return false, false, "Current platform does not support launch at login.", nil
|
||||
}
|
||||
}
|
||||
|
||||
func (h *Handler) setAutoStart(enabled bool) error {
|
||||
exePath, args, err := h.resolveLaunchCommand()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
switch runtime.GOOS {
|
||||
case "darwin":
|
||||
return setDarwinAutoStart(enabled, exePath, args)
|
||||
case "linux":
|
||||
return setLinuxAutoStart(enabled, exePath, args)
|
||||
case "windows":
|
||||
return setWindowsAutoStart(enabled, exePath, args)
|
||||
default:
|
||||
return errAutoStartUnsupported
|
||||
}
|
||||
}
|
||||
|
||||
func fileExists(path string) (bool, error) {
|
||||
_, err := os.Stat(path)
|
||||
if err == nil {
|
||||
return true, nil
|
||||
}
|
||||
if os.IsNotExist(err) {
|
||||
return false, nil
|
||||
}
|
||||
return false, err
|
||||
}
|
||||
|
||||
func macLaunchAgentPath() string {
|
||||
home, _ := os.UserHomeDir()
|
||||
return filepath.Join(home, "Library", "LaunchAgents", launchAgentLabel+".plist")
|
||||
}
|
||||
|
||||
func setDarwinAutoStart(enabled bool, exePath string, args []string) error {
|
||||
plistPath := macLaunchAgentPath()
|
||||
if enabled {
|
||||
if err := os.MkdirAll(filepath.Dir(plistPath), 0o755); err != nil {
|
||||
return err
|
||||
}
|
||||
content := buildDarwinPlist(exePath, args)
|
||||
return os.WriteFile(plistPath, []byte(content), 0o644)
|
||||
}
|
||||
|
||||
if err := os.Remove(plistPath); err != nil && !os.IsNotExist(err) {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func xmlEscape(s string) string {
|
||||
var b bytes.Buffer
|
||||
for _, r := range s {
|
||||
switch r {
|
||||
case '&':
|
||||
b.WriteString("&")
|
||||
case '<':
|
||||
b.WriteString("<")
|
||||
case '>':
|
||||
b.WriteString(">")
|
||||
case '"':
|
||||
b.WriteString(""")
|
||||
case '\'':
|
||||
b.WriteString("'")
|
||||
default:
|
||||
b.WriteRune(r)
|
||||
}
|
||||
}
|
||||
return b.String()
|
||||
}
|
||||
|
||||
func buildDarwinPlist(exePath string, args []string) string {
|
||||
programArgs := make([]string, 0, len(args)+1)
|
||||
programArgs = append(programArgs, exePath)
|
||||
programArgs = append(programArgs, args...)
|
||||
|
||||
var b strings.Builder
|
||||
b.WriteString(`<?xml version="1.0" encoding="UTF-8"?>` + "\n")
|
||||
b.WriteString(
|
||||
`<!DOCTYPE plist PUBLIC "-//Apple//DTD PLIST 1.0//EN" "http://www.apple.com/DTDs/PropertyList-1.0.dtd">` + "\n",
|
||||
)
|
||||
b.WriteString(`<plist version="1.0">` + "\n")
|
||||
b.WriteString(`<dict>` + "\n")
|
||||
b.WriteString(` <key>Label</key>` + "\n")
|
||||
b.WriteString(` <string>` + launchAgentLabel + `</string>` + "\n")
|
||||
b.WriteString(` <key>ProgramArguments</key>` + "\n")
|
||||
b.WriteString(` <array>` + "\n")
|
||||
for _, arg := range programArgs {
|
||||
b.WriteString(` <string>` + xmlEscape(arg) + `</string>` + "\n")
|
||||
}
|
||||
b.WriteString(` </array>` + "\n")
|
||||
b.WriteString(` <key>RunAtLoad</key>` + "\n")
|
||||
b.WriteString(` <true/>` + "\n")
|
||||
b.WriteString(` <key>ProcessType</key>` + "\n")
|
||||
b.WriteString(` <string>Background</string>` + "\n")
|
||||
b.WriteString(`</dict>` + "\n")
|
||||
b.WriteString(`</plist>` + "\n")
|
||||
return b.String()
|
||||
}
|
||||
|
||||
func linuxAutoStartPath() string {
|
||||
home, _ := os.UserHomeDir()
|
||||
return filepath.Join(home, ".config", "autostart", "picoclaw-web.desktop")
|
||||
}
|
||||
|
||||
func shellQuote(s string) string {
|
||||
if s == "" {
|
||||
return "''"
|
||||
}
|
||||
if !strings.ContainsAny(s, " \t\n'\"\\$`") {
|
||||
return s
|
||||
}
|
||||
return "'" + strings.ReplaceAll(s, "'", "'\"'\"'") + "'"
|
||||
}
|
||||
|
||||
func buildLinuxExecLine(exePath string, args []string) string {
|
||||
parts := make([]string, 0, len(args)+1)
|
||||
parts = append(parts, shellQuote(exePath))
|
||||
for _, arg := range args {
|
||||
parts = append(parts, shellQuote(arg))
|
||||
}
|
||||
return strings.Join(parts, " ")
|
||||
}
|
||||
|
||||
func setLinuxAutoStart(enabled bool, exePath string, args []string) error {
|
||||
desktopPath := linuxAutoStartPath()
|
||||
if enabled {
|
||||
if err := os.MkdirAll(filepath.Dir(desktopPath), 0o755); err != nil {
|
||||
return err
|
||||
}
|
||||
content := strings.Join([]string{
|
||||
"[Desktop Entry]",
|
||||
"Type=Application",
|
||||
"Version=1.0",
|
||||
"Name=PicoClaw Web",
|
||||
"Comment=Start PicoClaw Web on login",
|
||||
"Exec=" + buildLinuxExecLine(exePath, args),
|
||||
"Terminal=false",
|
||||
"X-GNOME-Autostart-enabled=true",
|
||||
"NoDisplay=true",
|
||||
"",
|
||||
}, "\n")
|
||||
return os.WriteFile(desktopPath, []byte(content), 0o644)
|
||||
}
|
||||
|
||||
if err := os.Remove(desktopPath); err != nil && !os.IsNotExist(err) {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func windowsCommandLine(exePath string, args []string) string {
|
||||
parts := make([]string, 0, len(args)+1)
|
||||
parts = append(parts, fmt.Sprintf("%q", exePath))
|
||||
for _, arg := range args {
|
||||
parts = append(parts, fmt.Sprintf("%q", arg))
|
||||
}
|
||||
return strings.Join(parts, " ")
|
||||
}
|
||||
|
||||
func windowsRunKeyExists() (bool, error) {
|
||||
cmd := exec.Command("reg", "query", `HKCU\Software\Microsoft\Windows\CurrentVersion\Run`, "/v", autoStartEntryName)
|
||||
if err := cmd.Run(); err != nil {
|
||||
var exitErr *exec.ExitError
|
||||
if errors.As(err, &exitErr) {
|
||||
return false, nil
|
||||
}
|
||||
return false, err
|
||||
}
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func setWindowsAutoStart(enabled bool, exePath string, args []string) error {
|
||||
key := `HKCU\Software\Microsoft\Windows\CurrentVersion\Run`
|
||||
if enabled {
|
||||
commandLine := windowsCommandLine(exePath, args)
|
||||
cmd := exec.Command("reg", "add", key, "/v", autoStartEntryName, "/t", "REG_SZ", "/d", commandLine, "/f")
|
||||
return cmd.Run()
|
||||
}
|
||||
|
||||
cmd := exec.Command("reg", "delete", key, "/v", autoStartEntryName, "/f")
|
||||
if err := cmd.Run(); err != nil {
|
||||
var exitErr *exec.ExitError
|
||||
if errors.As(err, &exitErr) {
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,56 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/sipeed/picoclaw/web/backend/launcherconfig"
|
||||
)
|
||||
|
||||
func TestResolveLaunchCommandUsesConfigFileDefaults(t *testing.T) {
|
||||
configPath := filepath.Join(t.TempDir(), "config.json")
|
||||
h := NewHandler(configPath)
|
||||
|
||||
// Persist non-default launcher options to ensure resolveLaunchCommand does not
|
||||
// pin them into autostart args.
|
||||
launcherPath := launcherconfig.PathForAppConfig(configPath)
|
||||
if err := launcherconfig.Save(launcherPath, launcherconfig.Config{
|
||||
Port: 19999,
|
||||
Public: true,
|
||||
}); err != nil {
|
||||
t.Fatalf("launcherconfig.Save() error = %v", err)
|
||||
}
|
||||
|
||||
exePath, args, err := h.resolveLaunchCommand()
|
||||
if err != nil {
|
||||
t.Fatalf("resolveLaunchCommand() error = %v", err)
|
||||
}
|
||||
if exePath == "" {
|
||||
t.Fatal("resolveLaunchCommand() returned empty executable path")
|
||||
}
|
||||
if len(args) != 2 {
|
||||
t.Fatalf("args len = %d, want 2 (got %v)", len(args), args)
|
||||
}
|
||||
if args[0] != "-no-browser" {
|
||||
t.Fatalf("args[0] = %q, want %q", args[0], "-no-browser")
|
||||
}
|
||||
if args[1] != configPath {
|
||||
t.Fatalf("args[1] = %q, want %q", args[1], configPath)
|
||||
}
|
||||
for _, arg := range args {
|
||||
if arg == "-port" || arg == "-public" {
|
||||
t.Fatalf("autostart args should not pin network flags, got %v", args)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildDarwinPlistIncludesRunAtLoad(t *testing.T) {
|
||||
plist := buildDarwinPlist("/tmp/picoclaw-web", []string{"-no-browser", "/tmp/config.json"})
|
||||
if !strings.Contains(plist, "<key>RunAtLoad</key>") {
|
||||
t.Fatalf("plist missing RunAtLoad key:\n%s", plist)
|
||||
}
|
||||
if !strings.Contains(plist, "<true/>") {
|
||||
t.Fatalf("plist missing RunAtLoad true value:\n%s", plist)
|
||||
}
|
||||
}
|
||||
Vendored
@@ -0,0 +1,69 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"embed"
|
||||
"io/fs"
|
||||
"log"
|
||||
"net/http"
|
||||
"path"
|
||||
"strings"
|
||||
)
|
||||
|
||||
//go:embed all:dist
|
||||
var frontendFS embed.FS
|
||||
|
||||
// registerEmbedRoutes sets up the HTTP handler to serve the embedded frontend files
|
||||
func registerEmbedRoutes(mux *http.ServeMux) {
|
||||
// Attempt to get the subdirectory 'dist' where Vite usually builds
|
||||
subFS, err := fs.Sub(frontendFS, "dist")
|
||||
if err != nil {
|
||||
// Log a warning if dist doesn't exist yet (e.g., during development before a frontend build)
|
||||
log.Printf(
|
||||
"Warning: no 'dist' folder found in embedded frontend. " +
|
||||
"Ensure you run `pnpm build:backend` in the frontend directory " +
|
||||
"before building the Go backend.",
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
fileServer := http.FileServer(http.FS(subFS))
|
||||
|
||||
// Serve static assets and fallback to index.html for SPA routes.
|
||||
mux.Handle(
|
||||
"/",
|
||||
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodGet && r.Method != http.MethodHead {
|
||||
http.NotFound(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
// Keep unknown API paths as 404 instead of falling back to SPA entry.
|
||||
if r.URL.Path == "/api" || strings.HasPrefix(r.URL.Path, "/api/") {
|
||||
http.NotFound(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
cleanPath := path.Clean(strings.TrimPrefix(r.URL.Path, "/"))
|
||||
if cleanPath == "." {
|
||||
cleanPath = ""
|
||||
}
|
||||
|
||||
// Existing static files/directories should be served directly.
|
||||
if cleanPath != "" {
|
||||
if _, statErr := fs.Stat(subFS, cleanPath); statErr == nil {
|
||||
fileServer.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
// Missing asset-like paths should remain 404.
|
||||
if strings.Contains(path.Base(cleanPath), ".") {
|
||||
fileServer.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
indexReq := r.Clone(r.Context())
|
||||
indexReq.URL.Path = "/"
|
||||
fileServer.ServeHTTP(w, indexReq)
|
||||
}),
|
||||
)
|
||||
}
|
||||
@@ -0,0 +1,33 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestUnknownAPIPathStays404(t *testing.T) {
|
||||
mux := http.NewServeMux()
|
||||
registerEmbedRoutes(mux)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/not-found", nil)
|
||||
rr := httptest.NewRecorder()
|
||||
mux.ServeHTTP(rr, req)
|
||||
|
||||
if rr.Code != http.StatusNotFound {
|
||||
t.Fatalf("status = %d, want %d", rr.Code, http.StatusNotFound)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMissingAssetStays404(t *testing.T) {
|
||||
mux := http.NewServeMux()
|
||||
registerEmbedRoutes(mux)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/assets/not-found.js", nil)
|
||||
rr := httptest.NewRecorder()
|
||||
mux.ServeHTTP(rr, req)
|
||||
|
||||
if rr.Code != http.StatusNotFound {
|
||||
t.Fatalf("status = %d, want %d", rr.Code, http.StatusNotFound)
|
||||
}
|
||||
}
|
||||
Binary file not shown.
|
After Width: | Height: | Size: 44 KiB |
@@ -0,0 +1,113 @@
|
||||
package launcherconfig
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const (
|
||||
// FileName is the launcher-specific settings file name.
|
||||
FileName = "launcher-config.json"
|
||||
// DefaultPort is the default port for the web launcher.
|
||||
DefaultPort = 18800
|
||||
)
|
||||
|
||||
// Config stores launch parameters for the web backend service.
|
||||
type Config struct {
|
||||
Port int `json:"port"`
|
||||
Public bool `json:"public"`
|
||||
AllowedCIDRs []string `json:"allowed_cidrs,omitempty"`
|
||||
}
|
||||
|
||||
// Default returns default launcher settings.
|
||||
func Default() Config {
|
||||
return Config{Port: DefaultPort, Public: false}
|
||||
}
|
||||
|
||||
// Validate checks if launcher settings are valid.
|
||||
func Validate(cfg Config) error {
|
||||
if cfg.Port < 1 || cfg.Port > 65535 {
|
||||
return fmt.Errorf("port %d is out of range (1-65535)", cfg.Port)
|
||||
}
|
||||
for _, cidr := range cfg.AllowedCIDRs {
|
||||
if _, _, err := net.ParseCIDR(cidr); err != nil {
|
||||
return fmt.Errorf("invalid CIDR %q", cidr)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// NormalizeCIDRs trims entries, removes empty values, and deduplicates CIDRs.
|
||||
func NormalizeCIDRs(cidrs []string) []string {
|
||||
if len(cidrs) == 0 {
|
||||
return nil
|
||||
}
|
||||
out := make([]string, 0, len(cidrs))
|
||||
seen := make(map[string]struct{}, len(cidrs))
|
||||
for _, raw := range cidrs {
|
||||
trimmed := strings.TrimSpace(raw)
|
||||
if trimmed == "" {
|
||||
continue
|
||||
}
|
||||
if _, ok := seen[trimmed]; ok {
|
||||
continue
|
||||
}
|
||||
seen[trimmed] = struct{}{}
|
||||
out = append(out, trimmed)
|
||||
}
|
||||
if len(out) == 0 {
|
||||
return nil
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// PathForAppConfig returns launcher-config path near the app config file.
|
||||
func PathForAppConfig(appConfigPath string) string {
|
||||
dir := filepath.Dir(appConfigPath)
|
||||
if dir == "" || dir == "." {
|
||||
dir = "."
|
||||
}
|
||||
return filepath.Join(dir, FileName)
|
||||
}
|
||||
|
||||
// Load reads launcher settings; fallback is returned when file does not exist.
|
||||
func Load(path string, fallback Config) (Config, error) {
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return fallback, nil
|
||||
}
|
||||
return Config{}, err
|
||||
}
|
||||
|
||||
cfg := fallback
|
||||
if err := json.Unmarshal(data, &cfg); err != nil {
|
||||
return Config{}, err
|
||||
}
|
||||
cfg.AllowedCIDRs = NormalizeCIDRs(cfg.AllowedCIDRs)
|
||||
if err := Validate(cfg); err != nil {
|
||||
return Config{}, err
|
||||
}
|
||||
return cfg, nil
|
||||
}
|
||||
|
||||
// Save writes launcher settings to disk.
|
||||
func Save(path string, cfg Config) error {
|
||||
cfg.AllowedCIDRs = NormalizeCIDRs(cfg.AllowedCIDRs)
|
||||
if err := Validate(cfg); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil {
|
||||
return err
|
||||
}
|
||||
data, err := json.MarshalIndent(cfg, "", " ")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
data = append(data, '\n')
|
||||
return os.WriteFile(path, data, 0o600)
|
||||
}
|
||||
@@ -0,0 +1,89 @@
|
||||
package launcherconfig
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestLoadReturnsFallbackWhenMissing(t *testing.T) {
|
||||
path := filepath.Join(t.TempDir(), "launcher-config.json")
|
||||
fallback := Config{Port: 19999, Public: true}
|
||||
|
||||
got, err := Load(path, fallback)
|
||||
if err != nil {
|
||||
t.Fatalf("Load() error = %v", err)
|
||||
}
|
||||
if got.Port != fallback.Port || got.Public != fallback.Public {
|
||||
t.Fatalf("Load() = %+v, want %+v", got, fallback)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSaveAndLoadRoundTrip(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
path := filepath.Join(dir, "launcher-config.json")
|
||||
want := Config{
|
||||
Port: 18080,
|
||||
Public: true,
|
||||
AllowedCIDRs: []string{"192.168.1.0/24", "10.0.0.0/8"},
|
||||
}
|
||||
|
||||
if err := Save(path, want); err != nil {
|
||||
t.Fatalf("Save() error = %v", err)
|
||||
}
|
||||
got, err := Load(path, Default())
|
||||
if err != nil {
|
||||
t.Fatalf("Load() error = %v", err)
|
||||
}
|
||||
if got.Port != want.Port || got.Public != want.Public {
|
||||
t.Fatalf("Load() = %+v, want %+v", got, want)
|
||||
}
|
||||
if len(got.AllowedCIDRs) != len(want.AllowedCIDRs) {
|
||||
t.Fatalf("allowed_cidrs len = %d, want %d", len(got.AllowedCIDRs), len(want.AllowedCIDRs))
|
||||
}
|
||||
for i := range want.AllowedCIDRs {
|
||||
if got.AllowedCIDRs[i] != want.AllowedCIDRs[i] {
|
||||
t.Fatalf("allowed_cidrs[%d] = %q, want %q", i, got.AllowedCIDRs[i], want.AllowedCIDRs[i])
|
||||
}
|
||||
}
|
||||
|
||||
stat, err := os.Stat(path)
|
||||
if err != nil {
|
||||
t.Fatalf("Stat() error = %v", err)
|
||||
}
|
||||
if perm := stat.Mode().Perm(); perm != 0o600 {
|
||||
t.Fatalf("file perm = %o, want 600", perm)
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateRejectsInvalidPort(t *testing.T) {
|
||||
if err := Validate(Config{Port: 0, Public: false}); err == nil {
|
||||
t.Fatal("Validate() expected error for port 0")
|
||||
}
|
||||
if err := Validate(Config{Port: 65536, Public: false}); err == nil {
|
||||
t.Fatal("Validate() expected error for port 65536")
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateRejectsInvalidCIDR(t *testing.T) {
|
||||
err := Validate(Config{
|
||||
Port: 18800,
|
||||
AllowedCIDRs: []string{"192.168.1.0/24", "not-a-cidr"},
|
||||
})
|
||||
if err == nil {
|
||||
t.Fatal("Validate() expected error for invalid CIDR")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeCIDRs(t *testing.T) {
|
||||
got := NormalizeCIDRs([]string{" 192.168.1.0/24 ", "", "10.0.0.0/8", "192.168.1.0/24"})
|
||||
want := []string{"192.168.1.0/24", "10.0.0.0/8"}
|
||||
if len(got) != len(want) {
|
||||
t.Fatalf("len(got) = %d, want %d", len(got), len(want))
|
||||
}
|
||||
for i := range want {
|
||||
if got[i] != want[i] {
|
||||
t.Fatalf("got[%d] = %q, want %q", i, got[i], want[i])
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,164 @@
|
||||
// PicoClaw Web Console - Web-based chat and management interface
|
||||
//
|
||||
// Provides a web UI for chatting with PicoClaw via the Pico Channel WebSocket,
|
||||
// with configuration management and gateway process control.
|
||||
//
|
||||
// Usage:
|
||||
//
|
||||
// go build -o picoclaw-web ./web/backend/
|
||||
// ./picoclaw-web [config.json]
|
||||
// ./picoclaw-web -public config.json
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"flag"
|
||||
"fmt"
|
||||
"log"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/sipeed/picoclaw/web/backend/api"
|
||||
"github.com/sipeed/picoclaw/web/backend/launcherconfig"
|
||||
"github.com/sipeed/picoclaw/web/backend/middleware"
|
||||
)
|
||||
|
||||
func main() {
|
||||
port := flag.String("port", "18800", "Port to listen on")
|
||||
public := flag.Bool("public", false, "Listen on all interfaces (0.0.0.0) instead of localhost only")
|
||||
noBrowser := flag.Bool("no-browser", false, "Do not auto-open browser on startup")
|
||||
|
||||
flag.Usage = func() {
|
||||
fmt.Fprintf(os.Stderr, "PicoClaw Launcher - A web-based configuration editor\n\n")
|
||||
fmt.Fprintf(os.Stderr, "Usage: %s [options] [config.json]\n\n", os.Args[0])
|
||||
fmt.Fprintf(os.Stderr, "Arguments:\n")
|
||||
fmt.Fprintf(os.Stderr, " config.json Path to the configuration file (default: ~/.picoclaw/config.json)\n\n")
|
||||
fmt.Fprintf(os.Stderr, "Options:\n")
|
||||
flag.PrintDefaults()
|
||||
fmt.Fprintf(os.Stderr, "\nExamples:\n")
|
||||
fmt.Fprintf(os.Stderr, " %s Use default config path\n", os.Args[0])
|
||||
fmt.Fprintf(os.Stderr, " %s ./config.json Specify a config file\n", os.Args[0])
|
||||
fmt.Fprintf(
|
||||
os.Stderr,
|
||||
" %s -public ./config.json Allow access from other devices on the network\n",
|
||||
os.Args[0],
|
||||
)
|
||||
}
|
||||
flag.Parse()
|
||||
|
||||
// Resolve config path
|
||||
configPath := getDefaultConfigPath()
|
||||
if flag.NArg() > 0 {
|
||||
configPath = flag.Arg(0)
|
||||
}
|
||||
|
||||
absPath, err := filepath.Abs(configPath)
|
||||
if err != nil {
|
||||
log.Fatalf("Failed to resolve config path: %v", err)
|
||||
}
|
||||
|
||||
var explicitPort bool
|
||||
var explicitPublic bool
|
||||
flag.Visit(func(f *flag.Flag) {
|
||||
switch f.Name {
|
||||
case "port":
|
||||
explicitPort = true
|
||||
case "public":
|
||||
explicitPublic = true
|
||||
}
|
||||
})
|
||||
|
||||
launcherPath := launcherconfig.PathForAppConfig(absPath)
|
||||
launcherCfg, err := launcherconfig.Load(launcherPath, launcherconfig.Default())
|
||||
if err != nil {
|
||||
log.Printf("Warning: Failed to load %s: %v", launcherPath, err)
|
||||
launcherCfg = launcherconfig.Default()
|
||||
}
|
||||
|
||||
effectivePort := *port
|
||||
effectivePublic := *public
|
||||
if !explicitPort {
|
||||
effectivePort = strconv.Itoa(launcherCfg.Port)
|
||||
}
|
||||
if !explicitPublic {
|
||||
effectivePublic = launcherCfg.Public
|
||||
}
|
||||
|
||||
portNum, err := strconv.Atoi(effectivePort)
|
||||
if err != nil || portNum < 1 || portNum > 65535 {
|
||||
if err == nil {
|
||||
err = errors.New("must be in range 1-65535")
|
||||
}
|
||||
log.Fatalf("Invalid port %q: %v", effectivePort, err)
|
||||
}
|
||||
|
||||
// Determine listen address
|
||||
var addr string
|
||||
if effectivePublic {
|
||||
addr = "0.0.0.0:" + effectivePort
|
||||
} else {
|
||||
addr = "127.0.0.1:" + effectivePort
|
||||
}
|
||||
|
||||
// Initialize Server components
|
||||
mux := http.NewServeMux()
|
||||
|
||||
// API Routes (e.g. /api/status)
|
||||
apiHandler := api.NewHandler(absPath)
|
||||
apiHandler.SetServerOptions(portNum, effectivePublic, launcherCfg.AllowedCIDRs)
|
||||
apiHandler.RegisterRoutes(mux)
|
||||
|
||||
// Frontend Embedded Assets
|
||||
registerEmbedRoutes(mux)
|
||||
|
||||
accessControlledMux, err := middleware.IPAllowlist(launcherCfg.AllowedCIDRs, mux)
|
||||
if err != nil {
|
||||
log.Fatalf("Invalid allowed CIDR configuration: %v", err)
|
||||
}
|
||||
|
||||
// Apply middleware stack
|
||||
handler := middleware.Recoverer(
|
||||
middleware.Logger(
|
||||
middleware.JSONContentType(accessControlledMux),
|
||||
),
|
||||
)
|
||||
|
||||
// Print startup banner
|
||||
fmt.Print(banner)
|
||||
fmt.Println()
|
||||
fmt.Println(" Open the following URL in your browser:")
|
||||
fmt.Println()
|
||||
fmt.Printf(" >> http://localhost:%s <<\n", effectivePort)
|
||||
if effectivePublic {
|
||||
if ip := getLocalIP(); ip != "" {
|
||||
fmt.Printf(" >> http://%s:%s <<\n", ip, effectivePort)
|
||||
}
|
||||
}
|
||||
fmt.Println()
|
||||
|
||||
// Auto-open browser
|
||||
if !*noBrowser {
|
||||
go func() {
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
url := "http://localhost:" + effectivePort
|
||||
if err := openBrowser(url); err != nil {
|
||||
log.Printf("Warning: Failed to auto-open browser: %v", err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// Auto-start gateway after backend starts listening.
|
||||
go func() {
|
||||
time.Sleep(1 * time.Second)
|
||||
apiHandler.TryAutoStartGateway()
|
||||
}()
|
||||
|
||||
// Start the Server
|
||||
if err := http.ListenAndServe(addr, handler); err != nil {
|
||||
log.Fatalf("Server failed to start: %v", err)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,64 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// IPAllowlist restricts access to requests from configured CIDR ranges.
|
||||
// Loopback addresses are always allowed for local administration.
|
||||
// Empty CIDR list means no restriction.
|
||||
func IPAllowlist(allowedCIDRs []string, next http.Handler) (http.Handler, error) {
|
||||
if len(allowedCIDRs) == 0 {
|
||||
return next, nil
|
||||
}
|
||||
|
||||
nets := make([]*net.IPNet, 0, len(allowedCIDRs))
|
||||
for _, cidr := range allowedCIDRs {
|
||||
_, ipNet, err := net.ParseCIDR(cidr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid CIDR %q: %w", cidr, err)
|
||||
}
|
||||
nets = append(nets, ipNet)
|
||||
}
|
||||
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
ip := clientIPFromRemoteAddr(r.RemoteAddr)
|
||||
if ip == nil {
|
||||
rejectByPolicy(w, r)
|
||||
return
|
||||
}
|
||||
if ip.IsLoopback() {
|
||||
next.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
for _, ipNet := range nets {
|
||||
if ipNet.Contains(ip) {
|
||||
next.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
rejectByPolicy(w, r)
|
||||
}), nil
|
||||
}
|
||||
|
||||
func clientIPFromRemoteAddr(remoteAddr string) net.IP {
|
||||
host := remoteAddr
|
||||
if h, _, err := net.SplitHostPort(remoteAddr); err == nil {
|
||||
host = h
|
||||
}
|
||||
return net.ParseIP(host)
|
||||
}
|
||||
|
||||
func rejectByPolicy(w http.ResponseWriter, r *http.Request) {
|
||||
if strings.HasPrefix(r.URL.Path, "/api/") {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusForbidden)
|
||||
_, _ = w.Write([]byte(`{"error":"access denied by network policy"}`))
|
||||
return
|
||||
}
|
||||
http.Error(w, "Forbidden", http.StatusForbidden)
|
||||
}
|
||||
@@ -0,0 +1,86 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestIPAllowlist_EmptyCIDRsAllowsAll(t *testing.T) {
|
||||
h, err := IPAllowlist(nil, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
if err != nil {
|
||||
t.Fatalf("IPAllowlist() error = %v", err)
|
||||
}
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
req.RemoteAddr = "203.0.113.5:1234"
|
||||
h.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("status = %d, want %d", rec.Code, http.StatusOK)
|
||||
}
|
||||
}
|
||||
|
||||
func TestIPAllowlist_RejectsOutsideCIDR(t *testing.T) {
|
||||
h, err := IPAllowlist([]string{"192.168.1.0/24"}, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
if err != nil {
|
||||
t.Fatalf("IPAllowlist() error = %v", err)
|
||||
}
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/config", nil)
|
||||
req.RemoteAddr = "10.0.0.8:1234"
|
||||
h.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusForbidden {
|
||||
t.Fatalf("status = %d, want %d", rec.Code, http.StatusForbidden)
|
||||
}
|
||||
}
|
||||
|
||||
func TestIPAllowlist_AllowsInsideCIDR(t *testing.T) {
|
||||
h, err := IPAllowlist([]string{"192.168.1.0/24"}, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
if err != nil {
|
||||
t.Fatalf("IPAllowlist() error = %v", err)
|
||||
}
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
req.RemoteAddr = "192.168.1.88:1234"
|
||||
h.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("status = %d, want %d", rec.Code, http.StatusOK)
|
||||
}
|
||||
}
|
||||
|
||||
func TestIPAllowlist_AlwaysAllowsLoopback(t *testing.T) {
|
||||
h, err := IPAllowlist([]string{"192.168.1.0/24"}, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
if err != nil {
|
||||
t.Fatalf("IPAllowlist() error = %v", err)
|
||||
}
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
req.RemoteAddr = "127.0.0.1:1234"
|
||||
h.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("status = %d, want %d", rec.Code, http.StatusOK)
|
||||
}
|
||||
}
|
||||
|
||||
func TestIPAllowlist_InvalidCIDR(t *testing.T) {
|
||||
_, err := IPAllowlist([]string{"bad-cidr"}, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))
|
||||
if err == nil {
|
||||
t.Fatal("IPAllowlist() expected error for invalid CIDR")
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,70 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"log"
|
||||
"net/http"
|
||||
"runtime/debug"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// JSONContentType sets the Content-Type header to application/json for
|
||||
// API requests handled by the wrapped handler.
|
||||
// SSE endpoints (text/event-stream) are excluded.
|
||||
func JSONContentType(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if strings.HasPrefix(r.URL.Path, "/api/") && !strings.HasSuffix(r.URL.Path, "/events") {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
}
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
|
||||
// responseRecorder wraps http.ResponseWriter to capture the status code.
|
||||
type responseRecorder struct {
|
||||
http.ResponseWriter
|
||||
statusCode int
|
||||
}
|
||||
|
||||
func (rr *responseRecorder) WriteHeader(code int) {
|
||||
rr.statusCode = code
|
||||
rr.ResponseWriter.WriteHeader(code)
|
||||
}
|
||||
|
||||
// Flush delegates to the underlying ResponseWriter if it implements http.Flusher.
|
||||
// This is required for SSE (Server-Sent Events) to work through the middleware.
|
||||
func (rr *responseRecorder) Flush() {
|
||||
if f, ok := rr.ResponseWriter.(http.Flusher); ok {
|
||||
f.Flush()
|
||||
}
|
||||
}
|
||||
|
||||
// Unwrap returns the underlying ResponseWriter so that http.ResponseController
|
||||
// and interface checks (like http.Flusher) can see through the wrapper.
|
||||
func (rr *responseRecorder) Unwrap() http.ResponseWriter {
|
||||
return rr.ResponseWriter
|
||||
}
|
||||
|
||||
// Logger logs each HTTP request with method, path, status code, and duration.
|
||||
func Logger(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
start := time.Now()
|
||||
rec := &responseRecorder{ResponseWriter: w, statusCode: http.StatusOK}
|
||||
next.ServeHTTP(rec, r)
|
||||
log.Printf("%s %s %d %s", r.Method, r.URL.Path, rec.statusCode, time.Since(start))
|
||||
})
|
||||
}
|
||||
|
||||
// Recoverer recovers from panics in downstream handlers and returns a 500
|
||||
// Internal Server Error response.
|
||||
func Recoverer(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
defer func() {
|
||||
if err := recover(); err != nil {
|
||||
log.Printf("panic recovered: %v\n%s", err, debug.Stack())
|
||||
http.Error(w, `{"error":"internal server error"}`, http.StatusInternalServerError)
|
||||
}
|
||||
}()
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
@@ -0,0 +1,8 @@
|
||||
package model
|
||||
|
||||
// StatusResponse represents the response payload for the GET /api/status endpoint.
|
||||
type StatusResponse struct {
|
||||
Status string `json:"status"`
|
||||
Version string `json:"version"`
|
||||
Uptime string `json:"uptime"`
|
||||
}
|
||||
@@ -0,0 +1,61 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
)
|
||||
|
||||
const (
|
||||
colorBlue = "\x1b[38;2;62;93;185m"
|
||||
colorRed = "\x1b[38;2;213;70;70m"
|
||||
colorReset = "\x1b[0m"
|
||||
banner = "\r\n" +
|
||||
colorBlue + "██████╗ ██╗ ██████╗ ██████╗ " + colorRed + " ██████╗██╗ █████╗ ██╗ ██╗\n" +
|
||||
colorBlue + "██╔══██╗██║██╔════╝██╔═══██╗" + colorRed + "██╔════╝██║ ██╔══██╗██║ ██║\n" +
|
||||
colorBlue + "██████╔╝██║██║ ██║ ██║" + colorRed + "██║ ██║ ███████║██║ █╗ ██║\n" +
|
||||
colorBlue + "██╔═══╝ ██║██║ ██║ ██║" + colorRed + "██║ ██║ ██╔══██║██║███╗██║\n" +
|
||||
colorBlue + "██║ ██║╚██████╗╚██████╔╝" + colorRed + "╚██████╗███████╗██║ ██║╚███╔███╔╝\n" +
|
||||
colorBlue + "╚═╝ ╚═╝ ╚═════╝ ╚═════╝ " + colorRed + " ╚═════╝╚══════╝╚═╝ ╚═╝ ╚══╝╚══╝\n" +
|
||||
colorReset
|
||||
)
|
||||
|
||||
// getDefaultConfigPath returns the default path to the picoclaw config file.
|
||||
func getDefaultConfigPath() string {
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return "config.json"
|
||||
}
|
||||
return filepath.Join(home, ".picoclaw", "config.json")
|
||||
}
|
||||
|
||||
// getLocalIP returns the local IP address of the machine.
|
||||
func getLocalIP() string {
|
||||
addrs, err := net.InterfaceAddrs()
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
for _, a := range addrs {
|
||||
if ipnet, ok := a.(*net.IPNet); ok && !ipnet.IP.IsLoopback() && ipnet.IP.To4() != nil {
|
||||
return ipnet.IP.String()
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// openBrowser automatically opens the given URL in the default browser.
|
||||
func openBrowser(url string) error {
|
||||
switch runtime.GOOS {
|
||||
case "linux":
|
||||
return exec.Command("xdg-open", url).Start()
|
||||
case "windows":
|
||||
return exec.Command("rundll32", "url.dll,FileProtocolHandler", url).Start()
|
||||
case "darwin":
|
||||
return exec.Command("open", url).Start()
|
||||
default:
|
||||
return fmt.Errorf("unsupported platform")
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,22 @@
|
||||
{
|
||||
"RT_GROUP_ICON": {
|
||||
"APP": {
|
||||
"0000": "../icon.ico"
|
||||
}
|
||||
},
|
||||
"RT_MANIFEST": {
|
||||
"#1": {
|
||||
"0409": {
|
||||
"identity": {
|
||||
"name": "PicoClaw Launcher",
|
||||
"version": "0.0.0.0"
|
||||
},
|
||||
"description": "PicoClaw Launcher - Web-based configuration editor",
|
||||
"minimum-os": "win7",
|
||||
"execution-level": "asInvoker",
|
||||
"dpi-awareness": "system",
|
||||
"use-common-controls-v6": true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user