diff --git a/web/backend/api/models.go b/web/backend/api/models.go index 64a7b5f1f..48babd8cd 100644 --- a/web/backend/api/models.go +++ b/web/backend/api/models.go @@ -108,7 +108,12 @@ func (h *Handler) handleAddModel(w http.ResponseWriter, r *http.Request) { } defer r.Body.Close() - var mc config.ModelConfig + type custom struct { + config.ModelConfig + APIKey string `json:"api_key"` + } + + var mc custom if err = json.Unmarshal(body, &mc); err != nil { http.Error(w, fmt.Sprintf("Invalid JSON: %v", err), http.StatusBadRequest) return @@ -119,13 +124,17 @@ func (h *Handler) handleAddModel(w http.ResponseWriter, r *http.Request) { return } + if mc.APIKey != "" { + mc.ModelConfig.SetAPIKey(mc.APIKey) + } + cfg, err := config.LoadConfig(h.configPath) if err != nil { http.Error(w, fmt.Sprintf("Failed to load config: %v", err), http.StatusInternalServerError) return } - cfg.ModelList = append(cfg.ModelList, &mc) + cfg.ModelList = append(cfg.ModelList, &mc.ModelConfig) if err := config.SaveConfig(h.configPath, cfg); err != nil { http.Error(w, fmt.Sprintf("Failed to save config: %v", err), http.StatusInternalServerError) diff --git a/web/backend/api/models_test.go b/web/backend/api/models_test.go index 0127ce675..9d3e72bd3 100644 --- a/web/backend/api/models_test.go +++ b/web/backend/api/models_test.go @@ -1,6 +1,7 @@ package api import ( + "bytes" "encoding/json" "net/http" "net/http/httptest" @@ -317,6 +318,44 @@ func TestHandleListModels_NormalizesWildcardLocalAPIBaseForProbe(t *testing.T) { } } +func TestHandleAddModel_PersistsAPIKey(t *testing.T) { + configPath, cleanup := setupOAuthTestEnv(t) + defer cleanup() + + h := NewHandler(configPath) + mux := http.NewServeMux() + h.RegisterRoutes(mux) + + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPost, "/api/models", bytes.NewBufferString(`{ + "model_name":"new-model", + "model":"openai/gpt-4o-mini", + "api_key":"sk-new-model-key" + }`)) + req.Header.Set("Content-Type", "application/json") + mux.ServeHTTP(rec, req) + + if rec.Code != http.StatusOK { + t.Fatalf("status = %d, want %d, body=%s", rec.Code, http.StatusOK, rec.Body.String()) + } + + cfg, err := config.LoadConfig(configPath) + if err != nil { + t.Fatalf("LoadConfig() error = %v", err) + } + if len(cfg.ModelList) != 2 { + t.Fatalf("len(model_list) = %d, want 2", len(cfg.ModelList)) + } + + added := cfg.ModelList[1] + if added.ModelName != "new-model" { + t.Fatalf("model_name = %q, want %q", added.ModelName, "new-model") + } + if added.APIKey() != "sk-new-model-key" { + t.Fatalf("api_key = %q, want %q", added.APIKey(), "sk-new-model-key") + } +} + func TestMaskAPIKey(t *testing.T) { tests := []struct { name string