From 30938df40b89e147a697c941fad38d731557c306 Mon Sep 17 00:00:00 2001 From: Guoguo Date: Thu, 21 May 2026 15:51:45 +0800 Subject: [PATCH] fix(web): use stored API key when fetching models for saved providers (#2910) When editing an existing model, the edit form initializes apiKey as empty for security. This caused "Fetch Available Models" to reject with "please enter API Key first" even though the key is saved server-side. Add model_index support: the frontend passes the model's index to the backend, which looks up the stored key from config. The key never leaves the backend. Provider and API base are validated to prevent a stored key from being sent to an unrelated endpoint. Co-authored-by: Claude Opus 4.7 (1M context) --- web/backend/api/models.go | 43 ++++++- web/backend/api/models_test.go | 105 ++++++++++++++++++ web/frontend/src/api/models.ts | 1 + .../components/models/edit-model-sheet.tsx | 1 + .../components/models/fetch-models-dialog.tsx | 12 +- 5 files changed, 153 insertions(+), 9 deletions(-) diff --git a/web/backend/api/models.go b/web/backend/api/models.go index 5c1eb23ad..18bd457bf 100644 --- a/web/backend/api/models.go +++ b/web/backend/api/models.go @@ -640,9 +640,10 @@ func (h *Handler) handleFetchModels(w http.ResponseWriter, r *http.Request) { defer r.Body.Close() var req struct { - Provider string `json:"provider"` - APIKey string `json:"api_key"` - APIBase string `json:"api_base"` + Provider string `json:"provider"` + APIKey string `json:"api_key"` + APIBase string `json:"api_base"` + ModelIndex *int `json:"model_index,omitempty"` } if err = json.Unmarshal(body, &req); err != nil { http.Error(w, fmt.Sprintf("Invalid JSON: %v", err), http.StatusBadRequest) @@ -659,7 +660,15 @@ func (h *Handler) handleFetchModels(w http.ResponseWriter, r *http.Request) { return } + apiKey := strings.TrimSpace(req.APIKey) apiBase := strings.TrimSpace(req.APIBase) + + if apiKey == "" && req.ModelIndex != nil { + if stored := h.lookupStoredAPIKey(*req.ModelIndex, req.Provider, apiBase); stored != "" { + apiKey = stored + } + } + if apiBase == "" { apiBase = providers.DefaultAPIBaseForProtocol(req.Provider) } @@ -671,7 +680,7 @@ func (h *Handler) handleFetchModels(w http.ResponseWriter, r *http.Request) { ctx, cancel := context.WithTimeout(r.Context(), 10*time.Second) defer cancel() - models, err := fetchUpstreamModels(ctx, req.Provider, apiBase, req.APIKey) + models, err := fetchUpstreamModels(ctx, req.Provider, apiBase, apiKey) if err != nil { http.Error(w, fmt.Sprintf("Failed to fetch models: %v", err), http.StatusBadGateway) return @@ -682,7 +691,7 @@ func (h *Handler) handleFetchModels(w http.ResponseWriter, r *http.Request) { for i, m := range models { catalogModels[i] = CatalogModel{ID: m.ID, OwnedBy: m.OwnedBy} } - if saveErr := SaveCatalog(req.Provider, apiBase, req.APIKey, catalogModels); saveErr != nil { + if saveErr := SaveCatalog(req.Provider, apiBase, apiKey, catalogModels); saveErr != nil { // Log but don't fail the request — saving catalog is non-critical logger.Warnf("Failed to save model catalog: %v", saveErr) } @@ -694,6 +703,30 @@ func (h *Handler) handleFetchModels(w http.ResponseWriter, r *http.Request) { }) } +func (h *Handler) lookupStoredAPIKey(index int, reqProvider, reqAPIBase string) string { + cfg, err := config.LoadConfig(h.configPath) + if err != nil || index < 0 || index >= len(cfg.ModelList) { + return "" + } + stored := cfg.ModelList[index] + storedProvider, _ := providers.ExtractProtocol(stored) + if providers.NormalizeProvider(reqProvider) != providers.NormalizeProvider(storedProvider) { + return "" + } + effectiveReqBase := strings.TrimSpace(reqAPIBase) + if effectiveReqBase == "" { + effectiveReqBase = providers.DefaultAPIBaseForProtocol(reqProvider) + } + effectiveStoredBase := strings.TrimSpace(stored.APIBase) + if effectiveStoredBase == "" { + effectiveStoredBase = providers.DefaultAPIBaseForProtocol(storedProvider) + } + if normalizeAPIBaseForCompare(effectiveReqBase) != normalizeAPIBaseForCompare(effectiveStoredBase) { + return "" + } + return stored.APIKey() +} + type upstreamModel struct { ID string `json:"id"` OwnedBy string `json:"owned_by,omitempty"` diff --git a/web/backend/api/models_test.go b/web/backend/api/models_test.go index 94a3da3b3..22b02c22a 100644 --- a/web/backend/api/models_test.go +++ b/web/backend/api/models_test.go @@ -6,6 +6,8 @@ import ( "fmt" "net/http" "net/http/httptest" + "os" + "path/filepath" "strings" "sync" "testing" @@ -2576,3 +2578,106 @@ func TestHandleFetchModels_SiliconFlowUsesOpenAICompatibleEndpoint(t *testing.T) t.Fatalf("model id = %q, want %q", resp.Models[0].ID, "deepseek-ai/DeepSeek-V3") } } + +func TestHandleFetchModels_ModelIndexUsesStoredKey(t *testing.T) { + var gotAuth string + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + gotAuth = r.Header.Get("Authorization") + w.Header().Set("Content-Type", "application/json") + fmt.Fprint(w, `{"data":[{"id":"gpt-4o","owned_by":"openai"}]}`) + })) + defer srv.Close() + + tmp := t.TempDir() + oldHome := os.Getenv("PICOCLAW_HOME") + t.Setenv("PICOCLAW_HOME", filepath.Join(tmp, ".picoclaw")) + defer func() { + if oldHome != "" { + os.Setenv("PICOCLAW_HOME", oldHome) + } else { + os.Unsetenv("PICOCLAW_HOME") + } + }() + + cfg := config.DefaultConfig() + cfg.ModelList = []*config.ModelConfig{ + { + ModelName: "my-openai", + Provider: "openai", + Model: "gpt-4o", + APIKeys: config.SimpleSecureStrings("sk-stored-secret"), + APIBase: srv.URL, + }, + } + configPath := filepath.Join(tmp, "config.json") + if err := config.SaveConfig(configPath, cfg); err != nil { + t.Fatalf("SaveConfig error: %v", err) + } + + h := NewHandler(configPath) + mux := http.NewServeMux() + h.RegisterRoutes(mux) + + idx := 0 + body := fmt.Sprintf(`{"provider":"openai","api_base":"%s","model_index":%d}`, srv.URL, idx) + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPost, "/api/models/fetch", bytes.NewBufferString(body)) + req.Header.Set("Content-Type", "application/json") + mux.ServeHTTP(rec, req) + + if rec.Code != http.StatusOK { + t.Fatalf("status = %d, want %d, body=%s", rec.Code, http.StatusOK, rec.Body.String()) + } + if gotAuth != "Bearer sk-stored-secret" { + t.Fatalf("Authorization = %q, want stored key to be used", gotAuth) + } + + var resp struct { + Models []upstreamModel `json:"models"` + } + if err := json.Unmarshal(rec.Body.Bytes(), &resp); err != nil { + t.Fatalf("Unmarshal error: %v", err) + } + if len(resp.Models) != 1 || resp.Models[0].ID != "gpt-4o" { + t.Fatalf("unexpected response: %+v", resp) + } +} + +func TestHandleFetchModels_ModelIndexProviderMismatchRejectsKey(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Header.Get("Authorization") != "" { + t.Error("stored key should NOT be sent to mismatched provider") + } + w.Header().Set("Content-Type", "application/json") + fmt.Fprint(w, `{"data":[]}`) + })) + defer srv.Close() + + tmp := t.TempDir() + t.Setenv("PICOCLAW_HOME", filepath.Join(tmp, ".picoclaw")) + + cfg := config.DefaultConfig() + cfg.ModelList = []*config.ModelConfig{ + { + ModelName: "my-openai", + Provider: "openai", + Model: "gpt-4o", + APIKeys: config.SimpleSecureStrings("sk-openai-secret"), + APIBase: "https://api.openai.com/v1", + }, + } + configPath := filepath.Join(tmp, "config.json") + if err := config.SaveConfig(configPath, cfg); err != nil { + t.Fatalf("SaveConfig error: %v", err) + } + + h := NewHandler(configPath) + mux := http.NewServeMux() + h.RegisterRoutes(mux) + + body := fmt.Sprintf(`{"provider":"siliconflow","api_base":"%s","model_index":0}`, srv.URL) + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPost, "/api/models/fetch", bytes.NewBufferString(body)) + req.Header.Set("Content-Type", "application/json") + mux.ServeHTTP(rec, req) +} diff --git a/web/frontend/src/api/models.ts b/web/frontend/src/api/models.ts index 5f89cbfcf..a319eda55 100644 --- a/web/frontend/src/api/models.ts +++ b/web/frontend/src/api/models.ts @@ -167,6 +167,7 @@ export interface FetchModelsRequest { provider: string api_key?: string api_base?: string + model_index?: number } export interface FetchModelsResponse { diff --git a/web/frontend/src/components/models/edit-model-sheet.tsx b/web/frontend/src/components/models/edit-model-sheet.tsx index 44cd6b668..d54747481 100644 --- a/web/frontend/src/components/models/edit-model-sheet.tsx +++ b/web/frontend/src/components/models/edit-model-sheet.tsx @@ -804,6 +804,7 @@ export function EditModelSheet({ provider={canonicalProvider} apiKey={form.apiKey} apiBase={effectiveApiBase} + modelIndex={model?.index} backendOptions={providerOptions} /> diff --git a/web/frontend/src/components/models/fetch-models-dialog.tsx b/web/frontend/src/components/models/fetch-models-dialog.tsx index 8ea9e52a7..164b2d30e 100644 --- a/web/frontend/src/components/models/fetch-models-dialog.tsx +++ b/web/frontend/src/components/models/fetch-models-dialog.tsx @@ -30,6 +30,7 @@ interface FetchModelsDialogProps { provider: string apiKey: string apiBase: string + modelIndex?: number backendOptions?: ModelProviderOption[] } @@ -40,6 +41,7 @@ export function FetchModelsDialog({ provider, apiKey, apiBase, + modelIndex, backendOptions, }: FetchModelsDialogProps) { const { t } = useTranslation() @@ -52,6 +54,7 @@ export function FetchModelsDialog({ const canonicalProvider = getCanonicalProviderKey(provider, backendOptions) const providerDef = getProviderCatalogMap(backendOptions).get(canonicalProvider) const needsKey = providerDef?.requiresApiKey !== false + const hasKey = !!apiKey || modelIndex !== undefined const handleFetch = useCallback(async () => { setFetching(true) @@ -63,6 +66,7 @@ export function FetchModelsDialog({ provider: canonicalProvider, api_key: apiKey, api_base: apiBase, + model_index: modelIndex, }) setModels(res.models) // Auto-select all by default @@ -72,14 +76,14 @@ export function FetchModelsDialog({ } finally { setFetching(false) } - }, [canonicalProvider, apiKey, apiBase, t]) + }, [canonicalProvider, apiKey, apiBase, modelIndex, t]) // Auto-fetch when dialog opens (skip if provider requires API key but none is set) useEffect(() => { - if (open && provider && !(needsKey && !apiKey)) { + if (open && provider && !(needsKey && !hasKey)) { handleFetch() } - }, [open, provider, apiKey, needsKey, handleFetch]) + }, [open, provider, hasKey, needsKey, handleFetch]) const handleFill = () => { onFill(Array.from(selected)) @@ -140,7 +144,7 @@ export function FetchModelsDialog({
- {needsKey && !apiKey && ( + {needsKey && !hasKey && (
{t("models.fetch.needApiKey")}