refactor: support explicit provider field in model list entries (#2609)

* refactor: support explicit model list providers

* fix(web): preserve explicit model providers

* fix(web): preserve legacy provider prefixes on model updates

fix(models): normalize explicit provider-prefixed ids

fix(api): preserve legacy model updates across providers

fix(agent): preserve config identity for explicit provider refs

* fix ci
This commit is contained in:
lxowalle
2026-04-22 11:28:47 +08:00
committed by GitHub
parent 3316ee6923
commit 77b0c43392
42 changed files with 1559 additions and 441 deletions
+12 -16
View File
@@ -87,7 +87,7 @@ func hasModelConfiguration(m *config.ModelConfig) bool {
apiKey := strings.TrimSpace(m.APIKey())
if authMethod == "oauth" || authMethod == "token" {
if provider, ok := oauthProviderForModel(m.Model); ok {
if provider, ok := oauthProviderForModel(m); ok {
cred, err := oauthGetCredential(provider)
if err != nil || cred == nil {
return false
@@ -123,7 +123,7 @@ func requiresRuntimeProbe(m *config.ModelConfig) bool {
return true
}
protocol := modelProtocol(m.Model)
protocol := modelProtocol(m)
switch protocol {
case "claude-cli", "claudecli", "codex-cli", "codexcli", "github-copilot", "copilot":
@@ -172,7 +172,7 @@ func (s *modelProbeCacheState) probe(cacheKey string, probeFunc func() bool) boo
func runLocalModelProbe(m *config.ModelConfig) bool {
apiBase := modelProbeAPIBase(m)
protocol, modelID := splitModel(m.Model)
protocol, modelID := splitModel(m)
switch protocol {
case "ollama":
return probeOllamaModelFunc(apiBase, modelID)
@@ -191,7 +191,7 @@ func runLocalModelProbe(m *config.ModelConfig) bool {
}
func modelProbeCacheKey(m *config.ModelConfig) string {
protocol, modelID := splitModel(m.Model)
protocol, modelID := splitModel(m)
apiBaseRaw := modelProbeAPIBase(m)
apiBase := strings.ToLower(strings.TrimRight(strings.TrimSpace(apiBaseRaw), "/"))
@@ -384,7 +384,7 @@ func modelProbeAPIBase(m *config.ModelConfig) string {
return normalizeModelProbeAPIBase(apiBase)
}
protocol := modelProtocol(m.Model)
protocol := modelProtocol(m)
if providers.IsEmptyAPIKeyAllowedForProtocol(protocol) {
return providers.DefaultAPIBaseForProtocol(protocol)
}
@@ -419,8 +419,8 @@ func normalizeModelProbeAPIBase(raw string) string {
return u.String()
}
func oauthProviderForModel(model string) (string, bool) {
switch modelProtocol(model) {
func oauthProviderForModel(m *config.ModelConfig) (string, bool) {
switch modelProtocol(m) {
case "openai":
return oauthProviderOpenAI, true
case "anthropic":
@@ -432,18 +432,14 @@ func oauthProviderForModel(model string) (string, bool) {
}
}
func modelProtocol(model string) string {
protocol, _ := splitModel(model)
func modelProtocol(m *config.ModelConfig) string {
protocol, _ := splitModel(m)
return protocol
}
func splitModel(model string) (protocol, modelID string) {
model = strings.ToLower(strings.TrimSpace(model))
protocol, _, found := strings.Cut(model, "/")
if !found {
return "openai", model
}
return protocol, strings.TrimSpace(model[strings.Index(model, "/")+1:])
func splitModel(m *config.ModelConfig) (protocol, modelID string) {
protocol, modelID = providers.ExtractProtocol(m)
return strings.ToLower(strings.TrimSpace(protocol)), strings.ToLower(strings.TrimSpace(modelID))
}
func hasLocalAPIBase(raw string) bool {
+41 -1
View File
@@ -6,10 +6,12 @@ import (
"io"
"net/http"
"strconv"
"strings"
"sync"
"github.com/sipeed/picoclaw/pkg/config"
"github.com/sipeed/picoclaw/pkg/logger"
"github.com/sipeed/picoclaw/pkg/providers"
)
// registerModelRoutes binds model list management endpoints to the ServeMux.
@@ -26,6 +28,7 @@ func (h *Handler) registerModelRoutes(mux *http.ServeMux) {
type modelResponse struct {
Index int `json:"index"`
ModelName string `json:"model_name"`
Provider string `json:"provider,omitempty"`
Model string `json:"model"`
APIBase string `json:"api_base,omitempty"`
APIKey string `json:"api_key"`
@@ -73,10 +76,12 @@ func (h *Handler) handleListModels(w http.ResponseWriter, r *http.Request) {
models := make([]modelResponse, 0, len(cfg.ModelList))
for i, m := range cfg.ModelList {
provider, modelID := providers.ExtractProtocol(m)
models = append(models, modelResponse{
Index: i,
ModelName: m.ModelName,
Model: m.Model,
Provider: provider,
Model: modelID,
APIBase: m.APIBase,
APIKey: maskAPIKey(m.APIKey()),
Proxy: m.Proxy,
@@ -176,6 +181,12 @@ func (h *Handler) handleUpdateModel(w http.ResponseWriter, r *http.Request) {
}
defer r.Body.Close()
var rawFields map[string]json.RawMessage
if err = json.Unmarshal(body, &rawFields); err != nil {
http.Error(w, fmt.Sprintf("Invalid JSON: %v", err), http.StatusBadRequest)
return
}
type custom struct {
config.ModelConfig
APIKey string `json:"api_key"`
@@ -226,6 +237,35 @@ func (h *Handler) handleUpdateModel(w http.ResponseWriter, r *http.Request) {
} else if len(mc.CustomHeaders) == 0 {
mc.CustomHeaders = nil
}
// Preserve the existing Provider when the caller omits it. This keeps the
// update API backward-compatible for clients that haven't started sending
// the new field yet, while still allowing explicit clearing via "".
if _, ok := rawFields["provider"]; !ok {
mc.Provider = cfg.ModelList[idx].Provider
// Older clients still round-trip the legacy model field only. When the
// stored config encodes provider/model in Model and has no explicit
// Provider field yet, continue preserving that hidden provider prefix.
// This keeps provider-omitted updates backward-compatible even when an
// older client edits the visible model ID.
if strings.TrimSpace(cfg.ModelList[idx].Provider) == "" {
existingProtocol, existingModelID := providers.ExtractProtocol(cfg.ModelList[idx])
existingRawModel := strings.TrimSpace(cfg.ModelList[idx].Model)
incomingModel := strings.TrimSpace(mc.Model)
if existingRawModel != "" && existingRawModel != existingModelID && incomingModel != "" {
if incomingModel == existingModelID {
mc.Model = existingRawModel
} else if strings.Contains(incomingModel, "/") && !strings.Contains(existingModelID, "/") {
// Older clients never saw the hidden provider prefix for simple
// legacy entries such as "openai/gpt-4o". If they now send an
// explicit provider/model string, treat it as the caller's full
// intent instead of re-applying the old hidden prefix.
mc.Model = incomingModel
} else if !strings.HasPrefix(incomingModel, existingProtocol+"/") {
mc.Model = existingProtocol + "/" + incomingModel
}
}
}
}
cfg.ModelList[idx] = &mc.ModelConfig
+477
View File
@@ -392,6 +392,49 @@ func TestHandleListModels_StatusMarksUnreachableLocalModel(t *testing.T) {
}
}
func TestHandleListModels_RuntimeProbeUsesExplicitProviderField(t *testing.T) {
configPath, cleanup := setupOAuthTestEnv(t)
defer cleanup()
resetOAuthHooks(t)
resetModelProbeHooks(t)
var gotProbe string
probeOpenAICompatibleModelFunc = func(apiBase, modelID, apiKey string) bool {
gotProbe = apiBase + "|" + modelID + "|" + apiKey
return true
}
cfg, err := config.LoadConfig(configPath)
if err != nil {
t.Fatalf("LoadConfig() error = %v", err)
}
cfg.ModelList = []*config.ModelConfig{{
ModelName: "vllm-local",
Provider: "vllm",
Model: "custom-model",
APIBase: "http://127.0.0.1:8000/v1",
}}
if err := config.SaveConfig(configPath, cfg); err != nil {
t.Fatalf("SaveConfig() error = %v", err)
}
h := NewHandler(configPath)
mux := http.NewServeMux()
h.RegisterRoutes(mux)
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/api/models", nil)
mux.ServeHTTP(rec, req)
if rec.Code != http.StatusOK {
t.Fatalf("status = %d, want %d, body=%s", rec.Code, http.StatusOK, rec.Body.String())
}
if gotProbe != "http://127.0.0.1:8000/v1|custom-model|" {
t.Fatalf("probe = %q, want %q", gotProbe, "http://127.0.0.1:8000/v1|custom-model|")
}
}
func TestHandleAddModel_PersistsAPIKey(t *testing.T) {
configPath, cleanup := setupOAuthTestEnv(t)
defer cleanup()
@@ -430,6 +473,76 @@ func TestHandleAddModel_PersistsAPIKey(t *testing.T) {
}
}
func TestHandleAddModel_PersistsProvider(t *testing.T) {
configPath, cleanup := setupOAuthTestEnv(t)
defer cleanup()
h := NewHandler(configPath)
mux := http.NewServeMux()
h.RegisterRoutes(mux)
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/api/models", bytes.NewBufferString(`{
"model_name":"nvidia-glm",
"provider":"nvidia",
"model":"z-ai/glm-5.1",
"api_key":"nv-key"
}`))
req.Header.Set("Content-Type", "application/json")
mux.ServeHTTP(rec, req)
if rec.Code != http.StatusOK {
t.Fatalf("status = %d, want %d, body=%s", rec.Code, http.StatusOK, rec.Body.String())
}
cfg, err := config.LoadConfig(configPath)
if err != nil {
t.Fatalf("LoadConfig() error = %v", err)
}
added := cfg.ModelList[len(cfg.ModelList)-1]
if added.Provider != "nvidia" {
t.Fatalf("provider = %q, want %q", added.Provider, "nvidia")
}
if added.Model != "z-ai/glm-5.1" {
t.Fatalf("model = %q, want %q", added.Model, "z-ai/glm-5.1")
}
}
func TestHandleAddModel_PreservesExplicitProviderPrefixedModel(t *testing.T) {
configPath, cleanup := setupOAuthTestEnv(t)
defer cleanup()
h := NewHandler(configPath)
mux := http.NewServeMux()
h.RegisterRoutes(mux)
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/api/models", bytes.NewBufferString(`{
"model_name":"openai-gpt",
"provider":"openai",
"model":"openai/gpt-4o-mini",
"api_key":"sk-openai"
}`))
req.Header.Set("Content-Type", "application/json")
mux.ServeHTTP(rec, req)
if rec.Code != http.StatusOK {
t.Fatalf("status = %d, want %d, body=%s", rec.Code, http.StatusOK, rec.Body.String())
}
cfg, err := config.LoadConfig(configPath)
if err != nil {
t.Fatalf("LoadConfig() error = %v", err)
}
added := cfg.ModelList[len(cfg.ModelList)-1]
if got := added.Provider; got != "openai" {
t.Fatalf("provider = %q, want %q", got, "openai")
}
if got := added.Model; got != "openai/gpt-4o-mini" {
t.Fatalf("model = %q, want %q", got, "openai/gpt-4o-mini")
}
}
func TestHandleAddModel_PersistsCustomHeaders(t *testing.T) {
configPath, cleanup := setupOAuthTestEnv(t)
defer cleanup()
@@ -536,6 +649,370 @@ func TestHandleUpdateModel_CustomHeadersPreserveAndClear(t *testing.T) {
}
}
func TestHandleUpdateModel_PersistsProvider(t *testing.T) {
configPath, cleanup := setupOAuthTestEnv(t)
defer cleanup()
cfg, err := config.LoadConfig(configPath)
if err != nil {
t.Fatalf("LoadConfig() error = %v", err)
}
cfg.ModelList = []*config.ModelConfig{{
ModelName: "editable",
Model: "gpt-4o",
Provider: "openai",
}}
if err = config.SaveConfig(configPath, cfg); err != nil {
t.Fatalf("SaveConfig() error = %v", err)
}
h := NewHandler(configPath)
mux := http.NewServeMux()
h.RegisterRoutes(mux)
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPut, "/api/models/0", bytes.NewBufferString(`{
"model_name":"editable",
"provider":"openrouter",
"model":"openai/gpt-4o"
}`))
req.Header.Set("Content-Type", "application/json")
mux.ServeHTTP(rec, req)
if rec.Code != http.StatusOK {
t.Fatalf("status = %d, want %d, body=%s", rec.Code, http.StatusOK, rec.Body.String())
}
updated, err := config.LoadConfig(configPath)
if err != nil {
t.Fatalf("LoadConfig() error = %v", err)
}
if got := updated.ModelList[0].Provider; got != "openrouter" {
t.Fatalf("provider = %q, want %q", got, "openrouter")
}
}
func TestHandleUpdateModel_PreservesExplicitProviderPrefixedModel(t *testing.T) {
configPath, cleanup := setupOAuthTestEnv(t)
defer cleanup()
cfg, err := config.LoadConfig(configPath)
if err != nil {
t.Fatalf("LoadConfig() error = %v", err)
}
cfg.ModelList = []*config.ModelConfig{{
ModelName: "editable",
Model: "gpt-4o",
Provider: "openai",
}}
if err = config.SaveConfig(configPath, cfg); err != nil {
t.Fatalf("SaveConfig() error = %v", err)
}
h := NewHandler(configPath)
mux := http.NewServeMux()
h.RegisterRoutes(mux)
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPut, "/api/models/0", bytes.NewBufferString(`{
"model_name":"editable",
"provider":"openai",
"model":"openai/gpt-5.4"
}`))
req.Header.Set("Content-Type", "application/json")
mux.ServeHTTP(rec, req)
if rec.Code != http.StatusOK {
t.Fatalf("status = %d, want %d, body=%s", rec.Code, http.StatusOK, rec.Body.String())
}
updated, err := config.LoadConfig(configPath)
if err != nil {
t.Fatalf("LoadConfig() error = %v", err)
}
if got := updated.ModelList[0].Provider; got != "openai" {
t.Fatalf("provider = %q, want %q", got, "openai")
}
if got := updated.ModelList[0].Model; got != "openai/gpt-5.4" {
t.Fatalf("model = %q, want %q", got, "openai/gpt-5.4")
}
}
func TestHandleListModels_PreservesExplicitProviderPrefixedModel(t *testing.T) {
configPath, cleanup := setupOAuthTestEnv(t)
defer cleanup()
cfg, err := config.LoadConfig(configPath)
if err != nil {
t.Fatalf("LoadConfig() error = %v", err)
}
cfg.ModelList = []*config.ModelConfig{{
ModelName: "openrouter-auto-explicit",
Provider: "openrouter",
Model: "openrouter/auto",
}}
if err := config.SaveConfig(configPath, cfg); err != nil {
t.Fatalf("SaveConfig() error = %v", err)
}
h := NewHandler(configPath)
mux := http.NewServeMux()
h.RegisterRoutes(mux)
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/api/models", nil)
mux.ServeHTTP(rec, req)
if rec.Code != http.StatusOK {
t.Fatalf("status = %d, want %d, body=%s", rec.Code, http.StatusOK, rec.Body.String())
}
var resp struct {
Models []modelResponse `json:"models"`
}
if err := json.Unmarshal(rec.Body.Bytes(), &resp); err != nil {
t.Fatalf("Unmarshal() error = %v", err)
}
if len(resp.Models) != 1 {
t.Fatalf("len(models) = %d, want 1", len(resp.Models))
}
if got := resp.Models[0].Provider; got != "openrouter" {
t.Fatalf("provider = %q, want %q", got, "openrouter")
}
if got := resp.Models[0].Model; got != "openrouter/auto" {
t.Fatalf("model = %q, want %q", got, "openrouter/auto")
}
}
func TestHandleUpdateModel_PreservesLegacyModelPrefixWhenProviderOmitted(t *testing.T) {
configPath, cleanup := setupOAuthTestEnv(t)
defer cleanup()
cfg, err := config.LoadConfig(configPath)
if err != nil {
t.Fatalf("LoadConfig() error = %v", err)
}
cfg.ModelList = []*config.ModelConfig{{
ModelName: "legacy-openrouter",
Model: "openrouter/openai/gpt-5.4",
}}
if err = config.SaveConfig(configPath, cfg); err != nil {
t.Fatalf("SaveConfig() error = %v", err)
}
h := NewHandler(configPath)
mux := http.NewServeMux()
h.RegisterRoutes(mux)
// Simulate an older client: it reads GET /api/models, ignores the new
// provider field, then PUTs the visible model string back unchanged.
recList := httptest.NewRecorder()
reqList := httptest.NewRequest(http.MethodGet, "/api/models", nil)
mux.ServeHTTP(recList, reqList)
if recList.Code != http.StatusOK {
t.Fatalf("list status = %d, want %d, body=%s", recList.Code, http.StatusOK, recList.Body.String())
}
var listResp struct {
Models []modelResponse `json:"models"`
}
if err = json.Unmarshal(recList.Body.Bytes(), &listResp); err != nil {
t.Fatalf("Unmarshal() error = %v", err)
}
if len(listResp.Models) != 1 {
t.Fatalf("len(models) = %d, want 1", len(listResp.Models))
}
if got := listResp.Models[0].Provider; got != "openrouter" {
t.Fatalf("provider = %q, want %q", got, "openrouter")
}
if got := listResp.Models[0].Model; got != "openai/gpt-5.4" {
t.Fatalf("model = %q, want %q", got, "openai/gpt-5.4")
}
recUpdate := httptest.NewRecorder()
reqUpdate := httptest.NewRequest(http.MethodPut, "/api/models/0", bytes.NewBufferString(`{
"model_name":"legacy-openrouter",
"model":"openai/gpt-5.4"
}`))
reqUpdate.Header.Set("Content-Type", "application/json")
mux.ServeHTTP(recUpdate, reqUpdate)
if recUpdate.Code != http.StatusOK {
t.Fatalf("update status = %d, want %d, body=%s", recUpdate.Code, http.StatusOK, recUpdate.Body.String())
}
updated, err := config.LoadConfig(configPath)
if err != nil {
t.Fatalf("LoadConfig() error = %v", err)
}
if got := updated.ModelList[0].Provider; got != "" {
t.Fatalf("provider = %q, want empty", got)
}
if got := updated.ModelList[0].Model; got != "openrouter/openai/gpt-5.4" {
t.Fatalf("model = %q, want %q", got, "openrouter/openai/gpt-5.4")
}
}
func TestHandleUpdateModel_PreservesLegacyModelPrefixWhenProviderOmittedAndModelChanges(t *testing.T) {
configPath, cleanup := setupOAuthTestEnv(t)
defer cleanup()
cfg, err := config.LoadConfig(configPath)
if err != nil {
t.Fatalf("LoadConfig() error = %v", err)
}
cfg.ModelList = []*config.ModelConfig{{
ModelName: "legacy-openrouter",
Model: "openrouter/openai/gpt-5.4",
}}
if err = config.SaveConfig(configPath, cfg); err != nil {
t.Fatalf("SaveConfig() error = %v", err)
}
h := NewHandler(configPath)
mux := http.NewServeMux()
h.RegisterRoutes(mux)
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPut, "/api/models/0", bytes.NewBufferString(`{
"model_name":"legacy-openrouter",
"model":"openai/gpt-5.5"
}`))
req.Header.Set("Content-Type", "application/json")
mux.ServeHTTP(rec, req)
if rec.Code != http.StatusOK {
t.Fatalf("status = %d, want %d, body=%s", rec.Code, http.StatusOK, rec.Body.String())
}
updated, err := config.LoadConfig(configPath)
if err != nil {
t.Fatalf("LoadConfig() error = %v", err)
}
if got := updated.ModelList[0].Provider; got != "" {
t.Fatalf("provider = %q, want empty", got)
}
if got := updated.ModelList[0].Model; got != "openrouter/openai/gpt-5.5" {
t.Fatalf("model = %q, want %q", got, "openrouter/openai/gpt-5.5")
}
}
func TestHandleListModels_ReturnsProviderField(t *testing.T) {
configPath, cleanup := setupOAuthTestEnv(t)
defer cleanup()
cfg, err := config.LoadConfig(configPath)
if err != nil {
t.Fatalf("LoadConfig() error = %v", err)
}
cfg.ModelList = []*config.ModelConfig{{
ModelName: "nvidia-glm",
Provider: "nvidia",
Model: "z-ai/glm-5.1",
APIKeys: config.SimpleSecureStrings("nv-key"),
}}
if err := config.SaveConfig(configPath, cfg); err != nil {
t.Fatalf("SaveConfig() error = %v", err)
}
h := NewHandler(configPath)
mux := http.NewServeMux()
h.RegisterRoutes(mux)
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/api/models", nil)
mux.ServeHTTP(rec, req)
if rec.Code != http.StatusOK {
t.Fatalf("status = %d, want %d, body=%s", rec.Code, http.StatusOK, rec.Body.String())
}
var resp struct {
Models []modelResponse `json:"models"`
}
if err := json.Unmarshal(rec.Body.Bytes(), &resp); err != nil {
t.Fatalf("Unmarshal() error = %v", err)
}
if len(resp.Models) != 1 {
t.Fatalf("len(models) = %d, want 1", len(resp.Models))
}
if got := resp.Models[0].Provider; got != "nvidia" {
t.Fatalf("provider = %q, want %q", got, "nvidia")
}
}
func TestHandleListModels_ReturnsEffectiveProviderField(t *testing.T) {
configPath, cleanup := setupOAuthTestEnv(t)
defer cleanup()
cfg, err := config.LoadConfig(configPath)
if err != nil {
t.Fatalf("LoadConfig() error = %v", err)
}
cfg.ModelList = []*config.ModelConfig{
{
ModelName: "plain-openai",
Model: "gpt-4o",
},
{
ModelName: "explicit-google",
Provider: "google",
Model: "gemini-2.5-pro",
},
{
ModelName: "explicit-qwen-intl",
Provider: "qwen-international",
Model: "qwen3-coder-plus",
},
}
if err := config.SaveConfig(configPath, cfg); err != nil {
t.Fatalf("SaveConfig() error = %v", err)
}
h := NewHandler(configPath)
mux := http.NewServeMux()
h.RegisterRoutes(mux)
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/api/models", nil)
mux.ServeHTTP(rec, req)
if rec.Code != http.StatusOK {
t.Fatalf("status = %d, want %d, body=%s", rec.Code, http.StatusOK, rec.Body.String())
}
var resp struct {
Models []modelResponse `json:"models"`
}
if err := json.Unmarshal(rec.Body.Bytes(), &resp); err != nil {
t.Fatalf("Unmarshal() error = %v", err)
}
if len(resp.Models) != 3 {
t.Fatalf("len(models) = %d, want 3", len(resp.Models))
}
if got := resp.Models[0].Provider; got != "openai" {
t.Fatalf("provider[0] = %q, want %q", got, "openai")
}
if got := resp.Models[0].Model; got != "gpt-4o" {
t.Fatalf("model[0] = %q, want %q", got, "gpt-4o")
}
if got := resp.Models[1].Provider; got != "gemini" {
t.Fatalf("provider[1] = %q, want %q", got, "gemini")
}
if got := resp.Models[1].Model; got != "gemini-2.5-pro" {
t.Fatalf("model[1] = %q, want %q", got, "gemini-2.5-pro")
}
if got := resp.Models[2].Provider; got != "qwen-intl" {
t.Fatalf("provider[2] = %q, want %q", got, "qwen-intl")
}
if got := resp.Models[2].Model; got != "qwen3-coder-plus" {
t.Fatalf("model[2] = %q, want %q", got, "qwen3-coder-plus")
}
}
// TestHandleSetDefaultModel_RejectsNonexistentModel tests that setting a non-existent
// model as default returns 404. This covers the case where virtual models (which are
// filtered by SaveConfig) cannot be set as default.
+12 -12
View File
@@ -746,7 +746,7 @@ func (h *Handler) syncProviderAuthMethod(provider, authMethod string) error {
found := false
for i := range cfg.ModelList {
if modelBelongsToProvider(provider, cfg.ModelList[i].Model) {
if modelBelongsToProvider(provider, cfg.ModelList[i]) {
cfg.ModelList[i].AuthMethod = authMethod
found = true
}
@@ -759,18 +759,15 @@ func (h *Handler) syncProviderAuthMethod(provider, authMethod string) error {
return oauthSaveConfig(h.configPath, cfg)
}
func modelBelongsToProvider(provider, model string) bool {
lower := strings.ToLower(strings.TrimSpace(model))
func modelBelongsToProvider(provider string, modelCfg *config.ModelConfig) bool {
protocol, _ := providers.ExtractProtocol(modelCfg)
switch provider {
case oauthProviderOpenAI:
return lower == "openai" || strings.HasPrefix(lower, "openai/")
return protocol == "openai"
case oauthProviderAnthropic:
return lower == "anthropic" || strings.HasPrefix(lower, "anthropic/")
return protocol == "anthropic"
case oauthProviderGoogleAntigravity:
return lower == "antigravity" ||
lower == "google-antigravity" ||
strings.HasPrefix(lower, "antigravity/") ||
strings.HasPrefix(lower, "google-antigravity/")
return protocol == "antigravity" || protocol == "google-antigravity"
default:
return false
}
@@ -781,19 +778,22 @@ func defaultModelConfigForProvider(provider, authMethod string) *config.ModelCon
case oauthProviderOpenAI:
return &config.ModelConfig{
ModelName: "gpt-5.4",
Model: "openai/gpt-5.4",
Provider: "openai",
Model: "gpt-5.4",
AuthMethod: authMethod,
}
case oauthProviderAnthropic:
return &config.ModelConfig{
ModelName: "claude-sonnet-4.6",
Model: "anthropic/claude-sonnet-4.6",
Provider: "anthropic",
Model: "claude-sonnet-4.6",
AuthMethod: authMethod,
}
case oauthProviderGoogleAntigravity:
return &config.ModelConfig{
ModelName: "gemini-flash",
Model: "antigravity/gemini-3-flash",
Provider: "antigravity",
Model: "gemini-3-flash",
AuthMethod: authMethod,
}
default:
+48
View File
@@ -214,6 +214,54 @@ func TestOAuthLogoutClearsCredentialAndConfig(t *testing.T) {
}
}
func TestOAuthLogoutClearsAuthMethodForExplicitProviderField(t *testing.T) {
configPath, cleanup := setupOAuthTestEnv(t)
defer cleanup()
resetOAuthHooks(t)
cfg, err := config.LoadConfig(configPath)
if err != nil {
t.Fatalf("LoadConfig error: %v", err)
}
cfg.ModelList = append(cfg.ModelList, &config.ModelConfig{
ModelName: "gpt-5.4",
Provider: "openai",
Model: "gpt-5.4",
AuthMethod: "oauth",
})
if err = config.SaveConfig(configPath, cfg); err != nil {
t.Fatalf("SaveConfig error: %v", err)
}
if err = auth.SetCredential(oauthProviderOpenAI, &auth.AuthCredential{
AccessToken: "token-before-logout",
Provider: oauthProviderOpenAI,
AuthMethod: "oauth",
}); err != nil {
t.Fatalf("SetCredential error: %v", err)
}
h := NewHandler(configPath)
mux := http.NewServeMux()
h.RegisterRoutes(mux)
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/api/oauth/logout", bytes.NewBufferString(`{"provider":"openai"}`))
req.Header.Set("Content-Type", "application/json")
mux.ServeHTTP(rec, req)
if rec.Code != http.StatusOK {
t.Fatalf("status = %d, want %d, body=%s", rec.Code, http.StatusOK, rec.Body.String())
}
updated, err := config.LoadConfig(configPath)
if err != nil {
t.Fatalf("LoadConfig error: %v", err)
}
if got := updated.ModelList[len(updated.ModelList)-1].AuthMethod; got != "" {
t.Fatalf("auth_method = %q, want empty", got)
}
}
func setupOAuthTestEnv(t *testing.T) (string, func()) {
t.Helper()