mirror of
https://github.com/sipeed/picoclaw.git
synced 2026-06-12 18:08:54 +00:00
fix(web): use stored API key when fetching models for saved providers (#2910)
When editing an existing model, the edit form initializes apiKey as empty for security. This caused "Fetch Available Models" to reject with "please enter API Key first" even though the key is saved server-side. Add model_index support: the frontend passes the model's index to the backend, which looks up the stored key from config. The key never leaves the backend. Provider and API base are validated to prevent a stored key from being sent to an unrelated endpoint. Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -640,9 +640,10 @@ func (h *Handler) handleFetchModels(w http.ResponseWriter, r *http.Request) {
|
||||
defer r.Body.Close()
|
||||
|
||||
var req struct {
|
||||
Provider string `json:"provider"`
|
||||
APIKey string `json:"api_key"`
|
||||
APIBase string `json:"api_base"`
|
||||
Provider string `json:"provider"`
|
||||
APIKey string `json:"api_key"`
|
||||
APIBase string `json:"api_base"`
|
||||
ModelIndex *int `json:"model_index,omitempty"`
|
||||
}
|
||||
if err = json.Unmarshal(body, &req); err != nil {
|
||||
http.Error(w, fmt.Sprintf("Invalid JSON: %v", err), http.StatusBadRequest)
|
||||
@@ -659,7 +660,15 @@ func (h *Handler) handleFetchModels(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
apiKey := strings.TrimSpace(req.APIKey)
|
||||
apiBase := strings.TrimSpace(req.APIBase)
|
||||
|
||||
if apiKey == "" && req.ModelIndex != nil {
|
||||
if stored := h.lookupStoredAPIKey(*req.ModelIndex, req.Provider, apiBase); stored != "" {
|
||||
apiKey = stored
|
||||
}
|
||||
}
|
||||
|
||||
if apiBase == "" {
|
||||
apiBase = providers.DefaultAPIBaseForProtocol(req.Provider)
|
||||
}
|
||||
@@ -671,7 +680,7 @@ func (h *Handler) handleFetchModels(w http.ResponseWriter, r *http.Request) {
|
||||
ctx, cancel := context.WithTimeout(r.Context(), 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
models, err := fetchUpstreamModels(ctx, req.Provider, apiBase, req.APIKey)
|
||||
models, err := fetchUpstreamModels(ctx, req.Provider, apiBase, apiKey)
|
||||
if err != nil {
|
||||
http.Error(w, fmt.Sprintf("Failed to fetch models: %v", err), http.StatusBadGateway)
|
||||
return
|
||||
@@ -682,7 +691,7 @@ func (h *Handler) handleFetchModels(w http.ResponseWriter, r *http.Request) {
|
||||
for i, m := range models {
|
||||
catalogModels[i] = CatalogModel{ID: m.ID, OwnedBy: m.OwnedBy}
|
||||
}
|
||||
if saveErr := SaveCatalog(req.Provider, apiBase, req.APIKey, catalogModels); saveErr != nil {
|
||||
if saveErr := SaveCatalog(req.Provider, apiBase, apiKey, catalogModels); saveErr != nil {
|
||||
// Log but don't fail the request — saving catalog is non-critical
|
||||
logger.Warnf("Failed to save model catalog: %v", saveErr)
|
||||
}
|
||||
@@ -694,6 +703,30 @@ func (h *Handler) handleFetchModels(w http.ResponseWriter, r *http.Request) {
|
||||
})
|
||||
}
|
||||
|
||||
func (h *Handler) lookupStoredAPIKey(index int, reqProvider, reqAPIBase string) string {
|
||||
cfg, err := config.LoadConfig(h.configPath)
|
||||
if err != nil || index < 0 || index >= len(cfg.ModelList) {
|
||||
return ""
|
||||
}
|
||||
stored := cfg.ModelList[index]
|
||||
storedProvider, _ := providers.ExtractProtocol(stored)
|
||||
if providers.NormalizeProvider(reqProvider) != providers.NormalizeProvider(storedProvider) {
|
||||
return ""
|
||||
}
|
||||
effectiveReqBase := strings.TrimSpace(reqAPIBase)
|
||||
if effectiveReqBase == "" {
|
||||
effectiveReqBase = providers.DefaultAPIBaseForProtocol(reqProvider)
|
||||
}
|
||||
effectiveStoredBase := strings.TrimSpace(stored.APIBase)
|
||||
if effectiveStoredBase == "" {
|
||||
effectiveStoredBase = providers.DefaultAPIBaseForProtocol(storedProvider)
|
||||
}
|
||||
if normalizeAPIBaseForCompare(effectiveReqBase) != normalizeAPIBaseForCompare(effectiveStoredBase) {
|
||||
return ""
|
||||
}
|
||||
return stored.APIKey()
|
||||
}
|
||||
|
||||
type upstreamModel struct {
|
||||
ID string `json:"id"`
|
||||
OwnedBy string `json:"owned_by,omitempty"`
|
||||
|
||||
@@ -6,6 +6,8 @@ import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
@@ -2576,3 +2578,106 @@ func TestHandleFetchModels_SiliconFlowUsesOpenAICompatibleEndpoint(t *testing.T)
|
||||
t.Fatalf("model id = %q, want %q", resp.Models[0].ID, "deepseek-ai/DeepSeek-V3")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleFetchModels_ModelIndexUsesStoredKey(t *testing.T) {
|
||||
var gotAuth string
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
gotAuth = r.Header.Get("Authorization")
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
fmt.Fprint(w, `{"data":[{"id":"gpt-4o","owned_by":"openai"}]}`)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
tmp := t.TempDir()
|
||||
oldHome := os.Getenv("PICOCLAW_HOME")
|
||||
t.Setenv("PICOCLAW_HOME", filepath.Join(tmp, ".picoclaw"))
|
||||
defer func() {
|
||||
if oldHome != "" {
|
||||
os.Setenv("PICOCLAW_HOME", oldHome)
|
||||
} else {
|
||||
os.Unsetenv("PICOCLAW_HOME")
|
||||
}
|
||||
}()
|
||||
|
||||
cfg := config.DefaultConfig()
|
||||
cfg.ModelList = []*config.ModelConfig{
|
||||
{
|
||||
ModelName: "my-openai",
|
||||
Provider: "openai",
|
||||
Model: "gpt-4o",
|
||||
APIKeys: config.SimpleSecureStrings("sk-stored-secret"),
|
||||
APIBase: srv.URL,
|
||||
},
|
||||
}
|
||||
configPath := filepath.Join(tmp, "config.json")
|
||||
if err := config.SaveConfig(configPath, cfg); err != nil {
|
||||
t.Fatalf("SaveConfig error: %v", err)
|
||||
}
|
||||
|
||||
h := NewHandler(configPath)
|
||||
mux := http.NewServeMux()
|
||||
h.RegisterRoutes(mux)
|
||||
|
||||
idx := 0
|
||||
body := fmt.Sprintf(`{"provider":"openai","api_base":"%s","model_index":%d}`, srv.URL, idx)
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/models/fetch", bytes.NewBufferString(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
mux.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("status = %d, want %d, body=%s", rec.Code, http.StatusOK, rec.Body.String())
|
||||
}
|
||||
if gotAuth != "Bearer sk-stored-secret" {
|
||||
t.Fatalf("Authorization = %q, want stored key to be used", gotAuth)
|
||||
}
|
||||
|
||||
var resp struct {
|
||||
Models []upstreamModel `json:"models"`
|
||||
}
|
||||
if err := json.Unmarshal(rec.Body.Bytes(), &resp); err != nil {
|
||||
t.Fatalf("Unmarshal error: %v", err)
|
||||
}
|
||||
if len(resp.Models) != 1 || resp.Models[0].ID != "gpt-4o" {
|
||||
t.Fatalf("unexpected response: %+v", resp)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleFetchModels_ModelIndexProviderMismatchRejectsKey(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Header.Get("Authorization") != "" {
|
||||
t.Error("stored key should NOT be sent to mismatched provider")
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
fmt.Fprint(w, `{"data":[]}`)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
tmp := t.TempDir()
|
||||
t.Setenv("PICOCLAW_HOME", filepath.Join(tmp, ".picoclaw"))
|
||||
|
||||
cfg := config.DefaultConfig()
|
||||
cfg.ModelList = []*config.ModelConfig{
|
||||
{
|
||||
ModelName: "my-openai",
|
||||
Provider: "openai",
|
||||
Model: "gpt-4o",
|
||||
APIKeys: config.SimpleSecureStrings("sk-openai-secret"),
|
||||
APIBase: "https://api.openai.com/v1",
|
||||
},
|
||||
}
|
||||
configPath := filepath.Join(tmp, "config.json")
|
||||
if err := config.SaveConfig(configPath, cfg); err != nil {
|
||||
t.Fatalf("SaveConfig error: %v", err)
|
||||
}
|
||||
|
||||
h := NewHandler(configPath)
|
||||
mux := http.NewServeMux()
|
||||
h.RegisterRoutes(mux)
|
||||
|
||||
body := fmt.Sprintf(`{"provider":"siliconflow","api_base":"%s","model_index":0}`, srv.URL)
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/models/fetch", bytes.NewBufferString(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
mux.ServeHTTP(rec, req)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user