From f6190b54de244134aabf7040facafcf28330726b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=82=86=E6=9C=88?= <2835601846@qq.com> Date: Fri, 15 May 2026 09:49:03 +0800 Subject: [PATCH] feat(web,api): fetch models and saved catalog support (#2832) * feat(web,api): add fetch models and saved catalog support Split from PR #2752 (part 2 of 3). Backend: - /api/models/catalog endpoint for browsing remote model catalogs - /api/models/fetch endpoint for fetching available models from providers - Credential reuse with provider/API base matching for security - Default API base resolution for providers without explicit base Frontend: - FetchModelsDialog for importing models from remote providers - CatalogDialog for browsing and importing from model catalogs - Static import for FetchModelsDialog (replaces dynamic import from PR1) - Dynamic import retained for TestModelDialog (PR3 territory) * fix(web,api): support bare-array responses in fetchOpenAICompatibleModels * fix(web,api): tighten maskAPIKeyValue to match maskAPIKey policy For 9-12 character keys, maskAPIKeyValue exposed first 4 + last 4 chars (only 1 char masked for a 9-char key). Now uses the same policy as maskAPIKey: first 3 + last 2 for 9-12 chars, first 3 + last 4 for longer keys. Adds tests covering all key length boundaries. --- web/backend/api/model_catalog.go | 169 +++++++++ web/backend/api/model_catalog_test.go | 87 +++++ web/backend/api/models.go | 213 +++++++++++ web/backend/api/models_test.go | 172 +++++++++ .../src/components/models/add-model-sheet.tsx | 31 +- .../src/components/models/catalog-dialog.tsx | 333 +++++++++++++++++- .../components/models/edit-model-sheet.tsx | 31 +- .../components/models/fetch-models-dialog.tsx | 226 +++++++++++- .../src/components/models/models-page.tsx | 24 +- 9 files changed, 1225 insertions(+), 61 deletions(-) create mode 100644 web/backend/api/model_catalog.go create mode 100644 web/backend/api/model_catalog_test.go diff --git a/web/backend/api/model_catalog.go b/web/backend/api/model_catalog.go new file mode 100644 index 000000000..ce50deafe --- /dev/null +++ b/web/backend/api/model_catalog.go @@ -0,0 +1,169 @@ +package api + +import ( + "crypto/sha256" + "encoding/json" + "fmt" + "net/http" + "os" + "path/filepath" + "strings" + "time" + + "github.com/sipeed/picoclaw/pkg/config" + "github.com/sipeed/picoclaw/pkg/fileutil" +) + +// CatalogModel represents a single model entry in a saved catalog. +type CatalogModel struct { + ID string `json:"id"` + OwnedBy string `json:"owned_by,omitempty"` + Extra map[string]any `json:"extra,omitempty"` +} + +// CatalogEntry is a saved list of upstream models fetched for a specific provider+key combination. +type CatalogEntry struct { + ID string `json:"id"` + Provider string `json:"provider"` + APIBase string `json:"api_base"` + APIKeyMask string `json:"api_key_mask"` + Models []CatalogModel `json:"models"` + FetchedAt string `json:"fetched_at"` +} + +// CatalogStore holds all saved model catalogs. +type CatalogStore struct { + Entries map[string]*CatalogEntry `json:"entries"` +} + +func catalogFilePath() string { + return filepath.Join(config.GetHome(), "model_catalogs.json") +} + +// generateCatalogKey creates a deterministic key for a provider+base+key combination. +func generateCatalogKey(provider, apiBase, apiKey string) string { + provider = strings.ToLower(strings.TrimSpace(provider)) + apiBase = strings.TrimRight(strings.TrimSpace(apiBase), "/") + hash := sha256.Sum256([]byte(apiKey)) + return fmt.Sprintf("%s|%s|%x", provider, apiBase, hash[:6]) +} + +// maskAPIKeyValue masks an API key for display. +// Keys longer than 12 chars show prefix + last 4 chars: "sk-****abcd". +// Keys 9-12 chars show prefix + last 2 chars: "sk-****cd". +// Shorter keys are fully masked as "****". +// Empty keys return empty string. +// Ensure at least 40% of the key will not be displayed. +func maskAPIKeyValue(key string) string { + key = strings.TrimSpace(key) + if key == "" { + return "" + } + if len(key) <= 8 { + return "****" + } + if len(key) <= 12 { + return key[:3] + "****" + key[len(key)-2:] + } + return key[:3] + "****" + key[len(key)-4:] +} + +func loadCatalogs() (*CatalogStore, error) { + path := catalogFilePath() + data, err := os.ReadFile(path) + if err != nil { + if os.IsNotExist(err) { + return &CatalogStore{Entries: make(map[string]*CatalogEntry)}, nil + } + return nil, err + } + var store CatalogStore + if err := json.Unmarshal(data, &store); err != nil { + return nil, err + } + if store.Entries == nil { + store.Entries = make(map[string]*CatalogEntry) + } + return &store, nil +} + +func saveCatalogs(store *CatalogStore) error { + path := catalogFilePath() + data, err := json.MarshalIndent(store, "", " ") + if err != nil { + return err + } + return fileutil.WriteFileAtomic(path, data, 0o600) +} + +// SaveCatalog persists a fetched model list for a given provider+key combination. +// If a catalog with the same key already exists, it is updated. +func SaveCatalog(provider, apiBase, apiKey string, models []CatalogModel) error { + store, err := loadCatalogs() + if err != nil { + return err + } + key := generateCatalogKey(provider, apiBase, apiKey) + store.Entries[key] = &CatalogEntry{ + ID: key, + Provider: strings.ToLower(strings.TrimSpace(provider)), + APIBase: strings.TrimRight(strings.TrimSpace(apiBase), "/"), + APIKeyMask: maskAPIKeyValue(apiKey), + Models: models, + FetchedAt: time.Now().UTC().Format(time.RFC3339), + } + return saveCatalogs(store) +} + +// handleListCatalogs returns all saved model catalogs. +// +// GET /api/models/catalog +func (h *Handler) handleListCatalogs(w http.ResponseWriter, r *http.Request) { + store, err := loadCatalogs() + if err != nil { + http.Error(w, fmt.Sprintf("Failed to load catalogs: %v", err), http.StatusInternalServerError) + return + } + + entries := make([]*CatalogEntry, 0, len(store.Entries)) + for _, e := range store.Entries { + entries = append(entries, e) + } + + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(map[string]any{ + "entries": entries, + "total": len(entries), + }) +} + +// handleDeleteCatalog deletes a saved model catalog by ID. +// +// DELETE /api/models/catalog/{id} +func (h *Handler) handleDeleteCatalog(w http.ResponseWriter, r *http.Request) { + id := r.PathValue("id") + if id == "" { + http.Error(w, "id is required", http.StatusBadRequest) + return + } + + store, err := loadCatalogs() + if err != nil { + http.Error(w, fmt.Sprintf("Failed to load catalogs: %v", err), http.StatusInternalServerError) + return + } + + if _, ok := store.Entries[id]; !ok { + http.Error(w, "catalog not found", http.StatusNotFound) + return + } + + delete(store.Entries, id) + if err := saveCatalogs(store); err != nil { + http.Error(w, fmt.Sprintf("Failed to save catalogs: %v", err), http.StatusInternalServerError) + return + } + + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(map[string]string{"status": "ok"}) +} diff --git a/web/backend/api/model_catalog_test.go b/web/backend/api/model_catalog_test.go new file mode 100644 index 000000000..76138fdcd --- /dev/null +++ b/web/backend/api/model_catalog_test.go @@ -0,0 +1,87 @@ +package api + +import ( + "strings" + "testing" +) + +func TestMaskAPIKeyValue(t *testing.T) { + tests := []struct { + name string + key string + want string + }{ + { + name: "empty key", + key: "", + want: "", + }, + { + name: "whitespace only", + key: " ", + want: "", + }, + { + name: "short key fully masked", + key: "abcd", + want: "****", + }, + { + name: "length 8 boundary fully masked", + key: "12345678", + want: "****", + }, + { + name: "length 9 boundary shows last 2", + key: "123456789", + want: "123****89", + }, + { + name: "length 10 shows last 2", + key: "1234567890", + want: "123****90", + }, + { + name: "length 12 boundary shows last 2", + key: "abcdefghijkl", + want: "abc****kl", + }, + { + name: "length 13 boundary shows last 4", + key: "abcdefghijklm", + want: "abc****jklm", + }, + { + name: "typical api key", + key: "sk-1234567890abcd", + want: "sk-****abcd", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + got := maskAPIKeyValue(tc.key) + if got != tc.want { + t.Fatalf("maskAPIKeyValue(%q) = %q, want %q", tc.key, got, tc.want) + } + + if tc.key != "" { + displayed := strings.Replace(got, "****", "", 1) + if len(strings.TrimSpace(tc.key)) <= 8 { + if displayed != "" { + t.Fatalf("maskAPIKeyValue(%q) displayed part = %q, want empty", tc.key, displayed) + } + } else { + if len(displayed)*10 > len(strings.TrimSpace(tc.key))*6 { + t.Fatalf( + "maskAPIKeyValue(%q) displayed length = %d, want at most 60%% of %d", + tc.key, + len(displayed), + len(strings.TrimSpace(tc.key)), + ) + } + } + } + }) + } +} diff --git a/web/backend/api/models.go b/web/backend/api/models.go index 8a66918f9..f1761cb70 100644 --- a/web/backend/api/models.go +++ b/web/backend/api/models.go @@ -1,6 +1,7 @@ package api import ( + "context" "encoding/json" "fmt" "io" @@ -8,6 +9,7 @@ import ( "strconv" "strings" "sync" + "time" "github.com/sipeed/picoclaw/pkg/audio/asr" "github.com/sipeed/picoclaw/pkg/config" @@ -15,6 +17,18 @@ import ( "github.com/sipeed/picoclaw/pkg/providers" ) +// fetchableProviders lists providers that support OpenAI-compatible /models listing. +var fetchableProviders = map[string]bool{ + "openai": true, "deepseek": true, "openrouter": true, + "qwen-portal": true, "qwen-intl": true, "moonshot": true, + "volcengine": true, "zhipu": true, "groq": true, + "mistral": true, "nvidia": true, "cerebras": true, + "venice": true, "shengsuanyun": true, "vivgrid": true, + "minimax": true, "longcat": true, "modelscope": true, + "mimo": true, "avian": true, "zai": true, "novita": true, + "litellm": true, "vllm": true, "lmstudio": true, "ollama": true, +} + // registerModelRoutes binds model list management endpoints to the ServeMux. func (h *Handler) registerModelRoutes(mux *http.ServeMux) { mux.HandleFunc("GET /api/models", h.handleListModels) @@ -22,6 +36,9 @@ func (h *Handler) registerModelRoutes(mux *http.ServeMux) { mux.HandleFunc("POST /api/models/default", h.handleSetDefaultModel) mux.HandleFunc("PUT /api/models/{index}", h.handleUpdateModel) mux.HandleFunc("DELETE /api/models/{index}", h.handleDeleteModel) + mux.HandleFunc("POST /api/models/fetch", h.handleFetchModels) + mux.HandleFunc("GET /api/models/catalog", h.handleListCatalogs) + mux.HandleFunc("DELETE /api/models/catalog/{id}", h.handleDeleteCatalog) } // modelResponse is the JSON structure returned for each model in the list. @@ -614,3 +631,199 @@ func maskAPIKey(key string) string { // Show first 3 chars and last 4 chars return key[:3] + "****" + key[len(key)-4:] } + +// handleFetchModels fetches available models from an upstream provider. +// +// POST /api/models/fetch +func (h *Handler) handleFetchModels(w http.ResponseWriter, r *http.Request) { + body, err := io.ReadAll(io.LimitReader(r.Body, 1<<20)) + if err != nil { + http.Error(w, "Failed to read request body", http.StatusBadRequest) + return + } + defer r.Body.Close() + + var req struct { + Provider string `json:"provider"` + APIKey string `json:"api_key"` + APIBase string `json:"api_base"` + } + if err = json.Unmarshal(body, &req); err != nil { + http.Error(w, fmt.Sprintf("Invalid JSON: %v", err), http.StatusBadRequest) + return + } + + if req.Provider == "" { + http.Error(w, "provider is required", http.StatusBadRequest) + return + } + + if !fetchableProviders[strings.ToLower(req.Provider)] { + http.Error(w, fmt.Sprintf("provider %q does not support model listing", req.Provider), http.StatusBadRequest) + return + } + + apiBase := strings.TrimSpace(req.APIBase) + if apiBase == "" { + apiBase = providers.DefaultAPIBaseForProtocol(req.Provider) + } + if apiBase == "" { + http.Error(w, fmt.Sprintf("No default API base for provider %q", req.Provider), http.StatusBadRequest) + return + } + + ctx, cancel := context.WithTimeout(r.Context(), 10*time.Second) + defer cancel() + + models, err := fetchUpstreamModels(ctx, req.Provider, apiBase, req.APIKey) + if err != nil { + http.Error(w, fmt.Sprintf("Failed to fetch models: %v", err), http.StatusBadGateway) + return + } + + // Auto-save fetched models to catalog + catalogModels := make([]CatalogModel, len(models)) + for i, m := range models { + catalogModels[i] = CatalogModel{ID: m.ID, OwnedBy: m.OwnedBy} + } + if saveErr := SaveCatalog(req.Provider, apiBase, req.APIKey, catalogModels); saveErr != nil { + // Log but don't fail the request — saving catalog is non-critical + logger.Warnf("Failed to save model catalog: %v", saveErr) + } + + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(map[string]any{ + "models": models, + "total": len(models), + }) +} + +type upstreamModel struct { + ID string `json:"id"` + OwnedBy string `json:"owned_by,omitempty"` +} + +func fetchUpstreamModels(ctx context.Context, provider, apiBase, apiKey string) ([]upstreamModel, error) { + apiBase = strings.TrimRight(strings.TrimSpace(apiBase), "/") + + var fetchURL string + switch strings.ToLower(provider) { + case "ollama": + // Strip /v1 suffix if present to get the Ollama root + root := apiBase + if strings.HasSuffix(root, "/v1") { + root = root[:len(root)-3] + } + root = strings.TrimRight(root, "/") + fetchURL = root + "/api/tags" + return fetchOllamaModels(ctx, fetchURL) + default: + // OpenAI-compatible: /v1/models + fetchURL = apiBase + "/models" + return fetchOpenAICompatibleModels(ctx, fetchURL, apiKey) + } +} + +func fetchOpenAICompatibleModels(ctx context.Context, fetchURL, apiKey string) ([]upstreamModel, error) { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, fetchURL, nil) + if err != nil { + return nil, err + } + if apiKey = strings.TrimSpace(apiKey); apiKey != "" { + req.Header.Set("Authorization", "Bearer "+apiKey) + } + + resp, err := http.DefaultClient.Do(req) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("upstream returned status %d", resp.StatusCode) + } + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("read response: %w", err) + } + + type modelItem struct { + ID string `json:"id"` + OwnedBy string `json:"owned_by"` + } + + // {"data": [...]} envelope. Distinguish "envelope shape with empty list" + // from "object without a data key" via Data being non-nil after unmarshal: + // json.Unmarshal sets Data to []modelItem{} for `{"data":[]}` but leaves + // it as nil when "data" is absent or null. + var envelope struct { + Data []modelItem `json:"data"` + } + if err := json.Unmarshal(body, &envelope); err == nil && envelope.Data != nil { + models := make([]upstreamModel, 0, len(envelope.Data)) + for _, m := range envelope.Data { + if m.ID != "" { + models = append(models, upstreamModel{ID: m.ID, OwnedBy: m.OwnedBy}) + } + } + return models, nil + } + + // Bare-array shape, including `[]`. + var arr []modelItem + if err := json.Unmarshal(body, &arr); err == nil { + models := make([]upstreamModel, 0, len(arr)) + for _, m := range arr { + if m.ID != "" { + models = append(models, upstreamModel{ID: m.ID, OwnedBy: m.OwnedBy}) + } + } + return models, nil + } + + preview := body + if len(preview) > 256 { + preview = preview[:256] + } + return nil, fmt.Errorf("decode response: unrecognized shape: %s", strings.TrimSpace(string(preview))) +} + +func fetchOllamaModels(ctx context.Context, fetchURL string) ([]upstreamModel, error) { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, fetchURL, nil) + if err != nil { + return nil, err + } + + resp, err := http.DefaultClient.Do(req) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("ollama returned status %d", resp.StatusCode) + } + + var parsed struct { + Models []struct { + Name string `json:"name"` + Model string `json:"model"` + } `json:"models"` + } + if err := json.NewDecoder(resp.Body).Decode(&parsed); err != nil { + return nil, err + } + + models := make([]upstreamModel, 0, len(parsed.Models)) + for _, m := range parsed.Models { + id := m.Name + if id == "" { + id = m.Model + } + if id != "" { + models = append(models, upstreamModel{ID: id}) + } + } + return models, nil +} diff --git a/web/backend/api/models_test.go b/web/backend/api/models_test.go index 0b1f04848..71b6b279c 100644 --- a/web/backend/api/models_test.go +++ b/web/backend/api/models_test.go @@ -3,6 +3,7 @@ package api import ( "bytes" "encoding/json" + "fmt" "net/http" "net/http/httptest" "strings" @@ -2219,3 +2220,174 @@ func TestMaskAPIKey(t *testing.T) { }) } } + +func TestFetchOpenAICompatibleModels_ResponseShapes(t *testing.T) { + tests := []struct { + name string + response string + apiKey string + wantLen int + wantFirst struct { + id, ownedBy string + } + wantSecond struct { + id, ownedBy string + } + }{ + { + name: "envelope shape", + response: `{"data":[{"id":"gpt-4o","owned_by":"openai"},{"id":"gpt-4o-mini","owned_by":"openai"}]}`, + apiKey: "test-key", + wantLen: 2, + wantFirst: struct { + id, ownedBy string + }{id: "gpt-4o", ownedBy: "openai"}, + wantSecond: struct { + id, ownedBy string + }{id: "gpt-4o-mini", ownedBy: "openai"}, + }, + { + name: "bare array shape", + response: `[{"id":"qwen-max","owned_by":"qwen"},{"id":"qwen-plus","owned_by":"qwen"}]`, + apiKey: "", + wantLen: 2, + wantFirst: struct { + id, ownedBy string + }{id: "qwen-max", ownedBy: "qwen"}, + wantSecond: struct { + id, ownedBy string + }{id: "qwen-plus", ownedBy: "qwen"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + fmt.Fprint(w, tt.response) + })) + defer srv.Close() + + models, err := fetchOpenAICompatibleModels(t.Context(), srv.URL+"/models", tt.apiKey) + if err != nil { + t.Fatalf("error = %v", err) + } + if len(models) != tt.wantLen { + t.Fatalf("len(models) = %d, want %d", len(models), tt.wantLen) + } + if models[0].ID != tt.wantFirst.id || models[0].OwnedBy != tt.wantFirst.ownedBy { + t.Fatalf("models[0] = %+v, want {ID:%s OwnedBy:%s}", models[0], tt.wantFirst.id, tt.wantFirst.ownedBy) + } + if models[1].ID != tt.wantSecond.id || models[1].OwnedBy != tt.wantSecond.ownedBy { + t.Fatalf("models[1] = %+v, want {ID:%s OwnedBy:%s}", models[1], tt.wantSecond.id, tt.wantSecond.ownedBy) + } + }) + } +} + +func TestFetchOpenAICompatibleModels_EmptyEnvelopeReturnsEmptySlice(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + fmt.Fprint(w, `{"data":[]}`) + })) + defer srv.Close() + + models, err := fetchOpenAICompatibleModels(t.Context(), srv.URL+"/models", "k") + if err != nil { + t.Fatalf("error = %v", err) + } + if len(models) != 0 { + t.Fatalf("len(models) = %d, want 0", len(models)) + } +} + +func TestFetchOpenAICompatibleModels_EmptyBareArrayReturnsEmptySlice(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + fmt.Fprint(w, `[]`) + })) + defer srv.Close() + + models, err := fetchOpenAICompatibleModels(t.Context(), srv.URL+"/models", "k") + if err != nil { + t.Fatalf("error = %v", err) + } + if len(models) != 0 { + t.Fatalf("len(models) = %d, want 0", len(models)) + } +} + +func TestFetchOpenAICompatibleModels_UnrecognizedShape(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + fmt.Fprint(w, `{"models":[],"error":"unsupported"}`) + })) + defer srv.Close() + + _, err := fetchOpenAICompatibleModels(t.Context(), srv.URL+"/models", "k") + if err == nil { + t.Fatal("error = nil, want unrecognized shape error") + } + if !strings.Contains(err.Error(), "unrecognized shape") { + t.Fatalf("error = %q, want it to contain 'unrecognized shape'", err.Error()) + } +} + +func TestFetchOpenAICompatibleModels_FiltersEmptyIDs(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + fmt.Fprint(w, `{"data":[`+ + `{"id":"gpt-4o","owned_by":"openai"},`+ + `{"id":"","owned_by":"openai"},`+ + `{"id":"gpt-4o-mini"}]}`) + })) + defer srv.Close() + + models, err := fetchOpenAICompatibleModels(t.Context(), srv.URL+"/models", "k") + if err != nil { + t.Fatalf("error = %v", err) + } + if len(models) != 2 { + t.Fatalf("len(models) = %d, want 2 (empty IDs should be filtered)", len(models)) + } + if models[0].ID != "gpt-4o" { + t.Fatalf("models[0].ID = %q, want %q", models[0].ID, "gpt-4o") + } + if models[1].ID != "gpt-4o-mini" { + t.Fatalf("models[1].ID = %q, want %q", models[1].ID, "gpt-4o-mini") + } +} + +func TestFetchOpenAICompatibleModels_SetsAuthorizationHeader(t *testing.T) { + var gotAuth string + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + gotAuth = r.Header.Get("Authorization") + w.Header().Set("Content-Type", "application/json") + fmt.Fprint(w, `{"data":[{"id":"m1"}]}`) + })) + defer srv.Close() + + if _, err := fetchOpenAICompatibleModels(t.Context(), srv.URL+"/models", "my-secret-key"); err != nil { + t.Fatalf("error = %v", err) + } + if gotAuth != "Bearer my-secret-key" { + t.Fatalf("Authorization = %q, want %q", gotAuth, "Bearer my-secret-key") + } +} + +func TestFetchOpenAICompatibleModels_NoAuthHeaderWhenKeyEmpty(t *testing.T) { + var gotAuth string + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + gotAuth = r.Header.Get("Authorization") + w.Header().Set("Content-Type", "application/json") + fmt.Fprint(w, `[{"id":"m1"}]`) + })) + defer srv.Close() + + if _, err := fetchOpenAICompatibleModels(t.Context(), srv.URL+"/models", ""); err != nil { + t.Fatalf("error = %v", err) + } + if gotAuth != "" { + t.Fatalf("Authorization = %q, want empty", gotAuth) + } +} diff --git a/web/frontend/src/components/models/add-model-sheet.tsx b/web/frontend/src/components/models/add-model-sheet.tsx index 06626441d..e24c0d81e 100644 --- a/web/frontend/src/components/models/add-model-sheet.tsx +++ b/web/frontend/src/components/models/add-model-sheet.tsx @@ -3,7 +3,7 @@ import { IconLoader2, IconPlugConnected, } from "@tabler/icons-react" -import { type ComponentType, useCallback, useEffect, useRef, useState } from "react" +import { useCallback, useEffect, useRef, useState } from "react" import { useTranslation } from "react-i18next" import { @@ -35,6 +35,7 @@ import { Textarea } from "@/components/ui/textarea" import { showSaveSuccessOrRestartToast } from "@/lib/restart-required" import { refreshGatewayState } from "@/store/gateway" +import { FetchModelsDialog } from "./fetch-models-dialog" import { type FieldValidation, validateModelField } from "./model-validation" import { ProviderCombobox } from "./provider-combobox" import { getProviderKey } from "./provider-label" @@ -141,17 +142,12 @@ export function AddModelSheet({ const debounceRef = useRef>(undefined) const scrollContainerRef = useRef(null) - // Dynamic imports for dialogs added in later PRs - const [FetchModelsDialogComp, setFetchModelsDialogComp] = useState void; onFill: (models: string[]) => void; - provider: string; apiKey: string; apiBase: string; - }> | null>(null) - const [TestModelDialogComp, setTestModelDialogComp] = useState void; inlineParams: { provider: string; model: string; apiBase: string; apiKey: string; authMethod: string }; }> | null>(null) useEffect(() => { - import("./fetch-models-dialog").then((m) => setFetchModelsDialogComp(() => m.FetchModelsDialog)).catch(() => {}) import("./test-model-dialog").then((m) => setTestModelDialogComp(() => m.TestModelDialog)).catch(() => {}) }, []) @@ -526,7 +522,6 @@ export function AddModelSheet({ size="sm" className="h-7 text-xs" onClick={() => setFetchOpen(true)} - disabled={!FetchModelsDialogComp} > {t("models.fetch.title")} @@ -740,16 +735,14 @@ export function AddModelSheet({ - {FetchModelsDialogComp && ( - setFetchOpen(false)} - onFill={handleFetchFill} - provider={form.provider} - apiKey={form.apiKey} - apiBase={form.apiBase} - /> - )} + setFetchOpen(false)} + onFill={handleFetchFill} + provider={form.provider} + apiKey={form.apiKey} + apiBase={form.apiBase} + /> {TestModelDialogComp && ( void + onModelAdded: () => void +} + +export function CatalogDialog({ + open, + onClose, + onModelAdded, +}: CatalogDialogProps) { + const { t } = useTranslation() + const [loading, setLoading] = useState(false) + const [entries, setEntries] = useState([]) + const [expandedId, setExpandedId] = useState(null) + const [selected, setSelected] = useState>>(new Map()) + const [adding, setAdding] = useState(false) + const [filter, setFilter] = useState("") + + const loadCatalogs = useCallback(async () => { + setLoading(true) + try { + const res = await getCatalogs() + setEntries(res.entries || []) + } catch (e) { + toast.error(e instanceof Error ? e.message : "Failed to load catalogs") + } finally { + setLoading(false) + } + }, []) + + useEffect(() => { + if (open) { + loadCatalogs() + setExpandedId(null) + setSelected(new Map()) + setFilter("") + } + }, [open, loadCatalogs]) + + const toggleExpand = (id: string) => { + setExpandedId((prev) => (prev === id ? null : id)) + } + + const toggleModel = (catalogId: string, modelId: string) => { + setSelected((prev) => { + const next = new Map(prev) + const set = new Set(next.get(catalogId) || []) + if (set.has(modelId)) set.delete(modelId) + else set.add(modelId) + next.set(catalogId, set) + return next + }) + } + + const toggleAll = (catalogId: string, models: CatalogModel[]) => { + setSelected((prev) => { + const next = new Map(prev) + const current = next.get(catalogId) || new Set() + const filtered = filter + ? models.filter((m) => + m.id.toLowerCase().includes(filter.toLowerCase()), + ) + : models + if (filtered.every((m) => current.has(m.id))) { + next.set(catalogId, new Set()) + } else { + next.set(catalogId, new Set(filtered.map((m) => m.id))) + } + return next + }) + } + + const handleDelete = async (id: string) => { + try { + await deleteCatalog(id) + setEntries((prev) => prev.filter((e) => e.id !== id)) + setSelected((prev) => { + const next = new Map(prev) + next.delete(id) + return next + }) + if (expandedId === id) setExpandedId(null) + } catch (e) { + toast.error(e instanceof Error ? e.message : "Failed to delete catalog") + } + } + + const handleAddSelected = async (entry: CatalogEntry) => { + const catalogSelected = selected.get(entry.id) || new Set() + if (catalogSelected.size === 0) return + + setAdding(true) + try { + const modelsToAdd = entry.models.filter((m) => catalogSelected.has(m.id)) + for (const model of modelsToAdd) { + await addModel({ + model_name: model.id, + provider: entry.provider || undefined, + model: model.id, + api_base: entry.api_base || undefined, + }) + } + await refreshGatewayState({ force: true }) + toast.success( + t("models.catalog.addSuccess", { count: modelsToAdd.length }), + ) + onModelAdded() + } catch (e) { + toast.error(e instanceof Error ? e.message : "Failed to add models") + } finally { + setAdding(false) + } + } + + const getFilteredModels = (models: CatalogModel[]) => + filter + ? models.filter((m) => m.id.toLowerCase().includes(filter.toLowerCase())) + : models + + return ( + !v && onClose()}> + + + {t("models.catalog.title")} + + {t("models.catalog.description")} + + + +
+ {loading && ( +
+ + {t("models.catalog.loading")} +
+ )} + + {!loading && entries.length === 0 && ( +
+ {t("models.catalog.empty")} +
+ )} + + {entries.length > 0 && ( + setFilter(e.target.value)} + className="h-8" + /> + )} + +
+ {entries.map((entry) => { + const isExpanded = expandedId === entry.id + const entrySelected = selected.get(entry.id) || new Set() + const filteredModels = getFilteredModels(entry.models) + + return ( +
+
toggleExpand(entry.id)} + > + {isExpanded ? ( + + ) : ( + + )} +
+
+ + {getProviderLabel(entry.provider)} + + + {entry.api_key_mask} + +
+
+ + {entry.models.length} {t("models.catalog.models")} + + {entry.api_base && ( + <> + | + {entry.api_base} + + )} + {entry.fetched_at && ( + <> + | + + {t("models.catalog.fetchedAt")}{" "} + {new Date(entry.fetched_at).toLocaleDateString()} + + + )} +
+
+
+ +
+
+ + {isExpanded && ( +
+
+ + {t("models.catalog.found", { + count: filteredModels.length, + })} + + +
+
+ {filteredModels.map((m) => ( + + ))} +
+ {entrySelected.size > 0 && ( +
+ {PROVIDER_MAP.get(entry.provider)?.requiresApiKey !== + false && ( +
+ {t("models.catalog.needApiKey")} +
+ )} +
+ +
+
+ )} +
+ )} +
+ ) + })} +
+
+ + + + +
+
+ ) } diff --git a/web/frontend/src/components/models/edit-model-sheet.tsx b/web/frontend/src/components/models/edit-model-sheet.tsx index 0581fe057..f8a645316 100644 --- a/web/frontend/src/components/models/edit-model-sheet.tsx +++ b/web/frontend/src/components/models/edit-model-sheet.tsx @@ -3,7 +3,7 @@ import { IconLoader2, IconPlugConnected, } from "@tabler/icons-react" -import { type ComponentType, useCallback, useEffect, useRef, useState } from "react" +import { useCallback, useEffect, useRef, useState } from "react" import { useTranslation } from "react-i18next" import { @@ -36,6 +36,7 @@ import { Textarea } from "@/components/ui/textarea" import { showSaveSuccessOrRestartToast } from "@/lib/restart-required" import { refreshGatewayState } from "@/store/gateway" +import { FetchModelsDialog } from "./fetch-models-dialog" import { type FieldValidation, validateModelField } from "./model-validation" import { ProviderCombobox } from "./provider-combobox" import { getProviderKey } from "./provider-label" @@ -158,17 +159,12 @@ export function EditModelSheet({ const debounceRef = useRef>(undefined) const scrollContainerRef = useRef(null) - // Dynamic imports for dialogs added in later PRs - const [FetchModelsDialogComp, setFetchModelsDialogComp] = useState void; onFill: (models: string[]) => void; - provider: string; apiKey: string; apiBase: string; - }> | null>(null) - const [TestModelDialogComp, setTestModelDialogComp] = useState void; inlineParams: { provider: string; model: string; apiBase: string; apiKey: string; authMethod: string; modelIndex?: number }; }> | null>(null) useEffect(() => { - import("./fetch-models-dialog").then((m) => setFetchModelsDialogComp(() => m.FetchModelsDialog)).catch(() => {}) import("./test-model-dialog").then((m) => setTestModelDialogComp(() => m.TestModelDialog)).catch(() => {}) }, []) @@ -477,7 +473,6 @@ export function EditModelSheet({ size="sm" className="h-7 text-xs" onClick={() => setFetchOpen(true)} - disabled={!FetchModelsDialogComp} > {t("models.fetch.title")} @@ -714,16 +709,14 @@ export function EditModelSheet({ /> )} - {FetchModelsDialogComp && ( - setFetchOpen(false)} - onFill={handleFetchFill} - provider={form.provider} - apiKey={form.apiKey} - apiBase={form.apiBase} - /> - )} + setFetchOpen(false)} + onFill={handleFetchFill} + provider={form.provider} + apiKey={form.apiKey} + apiBase={form.apiBase} + /> ) } diff --git a/web/frontend/src/components/models/fetch-models-dialog.tsx b/web/frontend/src/components/models/fetch-models-dialog.tsx index 8f7192e4a..09b602e6d 100644 --- a/web/frontend/src/components/models/fetch-models-dialog.tsx +++ b/web/frontend/src/components/models/fetch-models-dialog.tsx @@ -1,4 +1,224 @@ -// Placeholder: full implementation added in PR2 (Fetch Models & Saved Catalogs) -export function FetchModelsDialog() { - return null +import { IconDownload, IconLoader2 } from "@tabler/icons-react" +import { useCallback, useEffect, useState } from "react" +import { useTranslation } from "react-i18next" + +import { type UpstreamModel, fetchUpstreamModels } from "@/api/models" +import { Button } from "@/components/ui/button" +import { + Dialog, + DialogContent, + DialogDescription, + DialogFooter, + DialogHeader, + DialogTitle, +} from "@/components/ui/dialog" +import { Input } from "@/components/ui/input" + +import { PROVIDER_MAP } from "./provider-registry" + +interface FetchModelsDialogProps { + open: boolean + onClose: () => void + onFill: (models: string[]) => void + provider: string + apiKey: string + apiBase: string +} + +export function FetchModelsDialog({ + open, + onClose, + onFill, + provider, + apiKey, + apiBase, +}: FetchModelsDialogProps) { + const { t } = useTranslation() + const [fetching, setFetching] = useState(false) + const [models, setModels] = useState([]) + const [selected, setSelected] = useState>(new Set()) + const [error, setError] = useState("") + const [filter, setFilter] = useState("") + + const providerDef = PROVIDER_MAP.get(provider) + const needsKey = providerDef?.requiresApiKey !== false + + const handleFetch = useCallback(async () => { + setFetching(true) + setError("") + setModels([]) + setSelected(new Set()) + try { + const res = await fetchUpstreamModels({ + provider, + api_key: apiKey, + api_base: apiBase, + }) + setModels(res.models) + // Auto-select all by default + setSelected(new Set(res.models.map((m) => m.id))) + } catch (e) { + setError(e instanceof Error ? e.message : t("models.fetch.failed")) + } finally { + setFetching(false) + } + }, [provider, apiKey, apiBase, t]) + + // Auto-fetch when dialog opens (skip if provider requires API key but none is set) + useEffect(() => { + if (open && provider && !(needsKey && !apiKey)) { + handleFetch() + } + }, [open, provider, apiKey, needsKey, handleFetch]) + + const handleFill = () => { + onFill(Array.from(selected)) + handleClose() + } + + const handleClose = () => { + setModels([]) + setSelected(new Set()) + setError("") + setFilter("") + onClose() + } + + const toggleModel = (id: string) => { + setSelected((prev) => { + const next = new Set(prev) + if (next.has(id)) next.delete(id) + else next.add(id) + return next + }) + } + + const toggleAll = () => { + const filtered = models + .map((m) => m.id) + .filter( + (id) => !filter || id.toLowerCase().includes(filter.toLowerCase()), + ) + if (filtered.every((id) => selected.has(id))) { + setSelected(new Set()) + } else { + setSelected(new Set(filtered)) + } + } + + const filteredModels = filter + ? models.filter((m) => m.id.toLowerCase().includes(filter.toLowerCase())) + : models + + return ( + !v && handleClose()}> + + + + + {t("models.fetch.title")} + + + {t("models.fetch.description")} + {provider && ( + + {t("models.fetch.providerLabel")} {provider} + {apiBase && ` | ${apiBase}`} + + )} + + + +
+ {needsKey && !apiKey && ( +
+ {t("models.fetch.needApiKey")} +
+ )} + + {fetching && ( +
+ + {t("models.fetch.fetching")} +
+ )} + + {error && ( +
+
+ {error} +
+ +
+ )} + + {models.length > 0 && ( + <> + setFilter(e.target.value)} + className="h-8" + /> +
+ + {t("models.fetch.found", { count: models.length })} + {filter && + ` ${t("models.fetch.shown", { count: filteredModels.length })}`} + + +
+
+ {filteredModels.map((m) => ( + + ))} +
+ + )} +
+ + + + {models.length > 0 && ( + + )} + +
+
+ ) } diff --git a/web/frontend/src/components/models/models-page.tsx b/web/frontend/src/components/models/models-page.tsx index 238a51a5d..9c0c400db 100644 --- a/web/frontend/src/components/models/models-page.tsx +++ b/web/frontend/src/components/models/models-page.tsx @@ -4,7 +4,7 @@ import { IconPlus, IconStar, } from "@tabler/icons-react" -import { type ComponentType, useCallback, useEffect, useState } from "react" +import { useCallback, useEffect, useState } from "react" import { useTranslation } from "react-i18next" import { toast } from "sonner" @@ -20,6 +20,7 @@ import { showSaveSuccessOrRestartToast } from "@/lib/restart-required" import { refreshGatewayState } from "@/store/gateway" import { AddModelSheet } from "./add-model-sheet" +import { CatalogDialog } from "./catalog-dialog" import { DeleteModelDialog } from "./delete-model-dialog" import { EditModelSheet } from "./edit-model-sheet" import { getProviderKey, getProviderLabel } from "./provider-label" @@ -51,14 +52,6 @@ export function ModelsPage() { null, ) - // Dynamic import for CatalogDialog (added in PR2) - const [CatalogDialogComp, setCatalogDialogComp] = useState void; onModelAdded: () => void; - }> | null>(null) - useEffect(() => { - import("./catalog-dialog").then((m) => setCatalogDialogComp(() => m.CatalogDialog)).catch(() => {}) - }, []) - const fetchModels = useCallback(async () => { try { const data = await getModels() @@ -156,7 +149,6 @@ export function ModelsPage() { size="sm" variant="outline" onClick={() => setCatalogOpen(true)} - disabled={!CatalogDialogComp} > {t("models.catalog.button")} @@ -234,13 +226,11 @@ export function ModelsPage() { onDeleted={fetchModels} /> - {CatalogDialogComp && ( - setCatalogOpen(false)} - onModelAdded={fetchModels} - /> - )} + setCatalogOpen(false)} + onModelAdded={fetchModels} + /> ) }