From 6ce0306c66a11d5c8ce994f0904cdcb72968b91d Mon Sep 17 00:00:00 2001 From: corevibe555 <45244658+corevibe555@users.noreply.github.com> Date: Tue, 7 Apr 2026 15:07:56 +0300 Subject: [PATCH] fix: use per candidate provider for model_fallbacks (#2143) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * fix: use per-candidate provider for model_fallbacks Each fallback model now uses its own api_base and api_key from model_list instead of inheriting the primary model's provider config. Previously, a single LLMProvider was created from the primary model's ModelConfig and reused for all fallback candidates — only the model ID string was swapped. This caused all fallback requests to be routed to the primary provider's endpoint, making cross-provider fallback chains non-functional (e.g., OpenRouter primary with Gemini fallback would send the Gemini request to OpenRouter's API). Fix: pre-create a per-candidate LLMProvider at agent initialization time by looking up each candidate's ModelConfig from model_list. The fallback run closure now selects the correct provider per candidate via CandidateProviders map, falling back to agent.Provider when no override is found. Fixes #2140 Made-with: Cursor test: add test for instance.go fix: fix test refactor: optimize fix: fix Golang lint issues chore: comment cleanup * refactor: use resolvedModelConfig() instead of buildModelIndex() * fix --- pkg/agent/instance.go | 45 +++++++++ pkg/agent/instance_test.go | 194 +++++++++++++++++++++++++++++++++++++ pkg/agent/loop.go | 6 +- pkg/agent/loop_test.go | 158 ++++++++++++++++++++++++++++++ 4 files changed, 402 insertions(+), 1 deletion(-) diff --git a/pkg/agent/instance.go b/pkg/agent/instance.go index bacfa49c5..48e5aa625 100644 --- a/pkg/agent/instance.go +++ b/pkg/agent/instance.go @@ -51,6 +51,10 @@ type AgentInstance struct { // LightProvider is the concrete provider instance for the configured light model. // It is only used when routing selects the light tier for a turn. LightProvider providers.LLMProvider + // CandidateProviders maps "provider/model" keys to per-candidate LLMProvider + // instances. This allows each fallback model to use its own api_base and api_key + // from model_list, instead of inheriting the primary model's provider config. + CandidateProviders map[string]providers.LLMProvider } // NewAgentInstance creates an agent instance from config. @@ -175,6 +179,9 @@ func NewAgentInstance( // Resolve fallback candidates candidates := resolveModelCandidates(cfg, defaults.Provider, model, fallbacks) + candidateProviders := make(map[string]providers.LLMProvider) + populateCandidateProvidersFromNames(cfg, workspace, fallbacks, candidateProviders) + // Model routing setup: pre-resolve light model candidates at creation time // to avoid repeated model_list lookups on every incoming message. var router *routing.Router @@ -199,6 +206,7 @@ func NewAgentInstance( }) lightCandidates = resolved lightProvider = lp + populateCandidateProvidersFromNames(cfg, workspace, []string{rc.LightModel}, candidateProviders) } } } else { @@ -230,6 +238,43 @@ func NewAgentInstance( Router: router, LightCandidates: lightCandidates, LightProvider: lightProvider, + CandidateProviders: candidateProviders, + } +} + +// populateCandidateProvidersFromNames resolves each model name (alias or +// "provider/model") via resolvedModelConfig and creates a dedicated LLMProvider +// for it. This reuses the canonical config resolution path (GetModelConfig) so +// alias handling and load-balancing stay consistent with the rest of the codebase. +func populateCandidateProvidersFromNames( + cfg *config.Config, + workspace string, + names []string, + out map[string]providers.LLMProvider, +) { + if cfg == nil || len(names) == 0 { + return + } + for _, name := range names { + mc, err := resolvedModelConfig(cfg, strings.TrimSpace(name), workspace) + if err != nil { + logger.WarnCF("agent", + "fallback provider: no model_list entry found; will inherit primary provider credentials", + map[string]any{"name": name, "error": err.Error()}) + continue + } + protocol, modelID := providers.ExtractProtocol(strings.TrimSpace(mc.Model)) + key := providers.ModelKey(providers.NormalizeProvider(protocol), modelID) + if _, exists := out[key]; exists { + continue + } + p, _, err := providers.CreateProviderFromConfig(mc) + if err != nil { + logger.WarnCF("agent", "fallback provider: failed to create provider", + map[string]any{"model": mc.Model, "error": err.Error()}) + continue + } + out[key] = p } } diff --git a/pkg/agent/instance_test.go b/pkg/agent/instance_test.go index ba907e88b..8c71296ed 100644 --- a/pkg/agent/instance_test.go +++ b/pkg/agent/instance_test.go @@ -9,6 +9,7 @@ import ( "github.com/sipeed/picoclaw/pkg/config" "github.com/sipeed/picoclaw/pkg/media" + "github.com/sipeed/picoclaw/pkg/providers" ) func TestNewAgentInstance_UsesDefaultsTemperatureAndMaxTokens(t *testing.T) { @@ -300,6 +301,199 @@ func TestNewAgentInstance_AllowsMediaTempDirForReadListAndExec(t *testing.T) { } } +// TestPopulateCandidateProviders_NilCfgIsNoop verifies that passing a nil +// config does not panic and leaves the output map empty. +func TestPopulateCandidateProviders_NilCfgIsNoop(t *testing.T) { + out := map[string]providers.LLMProvider{} + populateCandidateProvidersFromNames(nil, t.TempDir(), []string{"gpt-4o"}, out) + if len(out) != 0 { + t.Fatalf("expected empty map, got %d entries", len(out)) + } +} + +// TestPopulateCandidateProviders_SkipsExistingKeys verifies that a key already +// present in the output map is not overwritten. +func TestPopulateCandidateProviders_SkipsExistingKeys(t *testing.T) { + existing := &mockProvider{} + key := providers.ModelKey("openai", "gpt-4o") + out := map[string]providers.LLMProvider{key: existing} + + cfg := &config.Config{ + ModelList: []*config.ModelConfig{ + {ModelName: "my-gpt", Model: "openai/gpt-4o", APIKeys: config.SimpleSecureStrings("test-key")}, + }, + } + populateCandidateProvidersFromNames(cfg, t.TempDir(), []string{"my-gpt"}, out) + + if out[key] != existing { + t.Fatal("existing provider entry was overwritten; expected it to be preserved") + } +} + +// TestPopulateCandidateProviders_ResolvesAlias verifies that a model_name +// alias (e.g. "my-gpt") is resolved via GetModelConfig and the provider +// is created using the underlying model's config. +func TestPopulateCandidateProviders_ResolvesAlias(t *testing.T) { + workspace := t.TempDir() + out := map[string]providers.LLMProvider{} + + cfg := &config.Config{ + ModelList: []*config.ModelConfig{ + {ModelName: "my-gpt", Model: "openai/gpt-4o", APIBase: "https://api.openai.com/v1", Workspace: workspace}, + }, + } + populateCandidateProvidersFromNames(cfg, workspace, []string{"my-gpt"}, out) + + key := providers.ModelKey("openai", "gpt-4o") + if out[key] == nil { + t.Fatalf("expected CandidateProviders[%q] to be populated for alias", key) + } +} + +// TestPopulateCandidateProviders_ResolvesProtocolPrefix verifies that a +// model_list entry using full "provider/model" notation (e.g. +// "gemini/gemma-3-27b-it") is matched correctly when referenced by model_name. +func TestPopulateCandidateProviders_ResolvesProtocolPrefix(t *testing.T) { + workspace := t.TempDir() + out := map[string]providers.LLMProvider{} + + cfg := &config.Config{ + ModelList: []*config.ModelConfig{ + { + ModelName: "gemma", + Model: "gemini/gemma-3-27b-it", + APIKeys: config.SimpleSecureStrings("gemini-test-key"), + Workspace: workspace, + }, + }, + } + populateCandidateProvidersFromNames(cfg, workspace, []string{"gemma"}, out) + + key := providers.ModelKey("gemini", "gemma-3-27b-it") + if out[key] == nil { + t.Fatalf("expected CandidateProviders[%q] to be populated for protocol-prefixed model", key) + } +} + +// TestPopulateCandidateProviders_EmptyNamesIsNoop verifies the early-exit +// path when the names slice is empty. +func TestPopulateCandidateProviders_EmptyNamesIsNoop(t *testing.T) { + out := map[string]providers.LLMProvider{} + cfg := &config.Config{ + ModelList: []*config.ModelConfig{ + {ModelName: "my-gpt", Model: "openai/gpt-4o", APIKeys: config.SimpleSecureStrings("key")}, + }, + } + populateCandidateProvidersFromNames(cfg, t.TempDir(), nil, out) + if len(out) != 0 { + t.Fatalf("expected empty map, got %d entries", len(out)) + } +} + +// TestPopulateCandidateProviders_EmptyModelListIsNoop verifies the early-exit +// path when model_list is empty — no provider can be created. +func TestPopulateCandidateProviders_EmptyModelListIsNoop(t *testing.T) { + out := map[string]providers.LLMProvider{} + cfg := &config.Config{} + populateCandidateProvidersFromNames(cfg, t.TempDir(), []string{"gpt-4o"}, out) + if len(out) != 0 { + t.Fatalf("expected empty map, got %d entries", len(out)) + } +} + +// TestPopulateCandidateProviders_UnmatchedNameIsSkipped verifies that a +// name with no matching model_list entry is skipped and does not +// cause a panic or leave a nil entry in the map. +func TestPopulateCandidateProviders_UnmatchedNameIsSkipped(t *testing.T) { + out := map[string]providers.LLMProvider{} + cfg := &config.Config{ + ModelList: []*config.ModelConfig{ + {ModelName: "my-gpt", Model: "openai/gpt-4o", APIKeys: config.SimpleSecureStrings("key")}, + }, + } + populateCandidateProvidersFromNames(cfg, t.TempDir(), []string{"nonexistent-model"}, out) + + if len(out) != 0 { + t.Fatalf("expected empty map for unmatched name, got %d entries", len(out)) + } +} + +// TestNewAgentInstance_CandidateProvidersPopulatedForCrossProviderFallbacks +// mirrors the exact scenario from bug #2140: primary model on OpenRouter with +// Gemini fallbacks. Each entry must get its own provider instance so that +// fallback requests go to the correct API endpoint, not the primary's. +func TestNewAgentInstance_CandidateProvidersPopulatedForCrossProviderFallbacks(t *testing.T) { + workspace := t.TempDir() + + cfg := &config.Config{ + Agents: config.AgentsConfig{ + Defaults: config.AgentDefaults{ + Workspace: workspace, + ModelName: "mistral-small-3.1", + ModelFallbacks: []string{"gemma-3-27b", "gemini-images"}, + }, + }, + ModelList: []*config.ModelConfig{ + { + ModelName: "mistral-small-3.1", + Model: "openrouter/mistralai/mistral-small-3.1-24b-instruct:free", + APIBase: "https://openrouter.ai/api/v1", + APIKeys: config.SimpleSecureStrings("sk-or-test"), + Workspace: workspace, + }, + { + ModelName: "gemma-3-27b", + Model: "gemini/gemma-3-27b-it", + APIKeys: config.SimpleSecureStrings("AIzaSy-test"), + Workspace: workspace, + }, + { + ModelName: "gemini-images", + Model: "gemini/gemini-2.5-flash-lite", + APIKeys: config.SimpleSecureStrings("AIzaSy-test"), + Workspace: workspace, + }, + }, + } + + primaryProvider := &mockProvider{} + agent := NewAgentInstance(nil, &cfg.Agents.Defaults, cfg, primaryProvider) + + // Only fallback models need entries — the primary uses the injected provider directly. + wantKeys := []string{ + providers.ModelKey("gemini", "gemma-3-27b-it"), + providers.ModelKey("gemini", "gemini-2.5-flash-lite"), + } + + for _, key := range wantKeys { + p, ok := agent.CandidateProviders[key] + if !ok { + t.Errorf("CandidateProviders missing key %q", key) + continue + } + if p == nil { + t.Errorf("CandidateProviders[%q] is nil", key) + } + // Each fallback must use its own provider, not the injected primary. + if p == primaryProvider { + t.Errorf( + "CandidateProviders[%q] is the same instance as the primary provider; fallback would inherit primary credentials", + key, + ) + } + } + + if t.Failed() { + t.Logf("CandidateProviders keys present: %v", func() []string { + keys := make([]string, 0, len(agent.CandidateProviders)) + for k := range agent.CandidateProviders { + keys = append(keys, k) + } + return keys + }()) + } +} + func TestNewAgentInstance_ReadFileModeSelectsSchema(t *testing.T) { workspace := t.TempDir() diff --git a/pkg/agent/loop.go b/pkg/agent/loop.go index fc37ff8a0..369928d78 100644 --- a/pkg/agent/loop.go +++ b/pkg/agent/loop.go @@ -2020,7 +2020,11 @@ turnLoop: providerCtx, activeCandidates, func(ctx context.Context, provider, model string) (*providers.LLMResponse, error) { - return activeProvider.Chat(ctx, messagesForCall, toolDefsForCall, model, llmOpts) + candidateProvider := activeProvider + if cp, ok := ts.agent.CandidateProviders[providers.ModelKey(provider, model)]; ok { + candidateProvider = cp + } + return candidateProvider.Chat(ctx, messagesForCall, toolDefsForCall, model, llmOpts) }, ) if fbErr != nil { diff --git a/pkg/agent/loop_test.go b/pkg/agent/loop_test.go index 9513d8aca..3d04b81cc 100644 --- a/pkg/agent/loop_test.go +++ b/pkg/agent/loop_test.go @@ -1839,6 +1839,164 @@ func TestProcessMessage_ModelRoutingUsesLightProvider(t *testing.T) { } } +// TestProcessMessage_FallbackUsesPerCandidateProvider is the loop-level test for +// bug #2140. It verifies that when the primary model returns a rate-limit error +// the fallback closure routes the retry to the fallback model's own provider +// (its own api_base), not back to the primary provider's endpoint. +func TestProcessMessage_FallbackUsesPerCandidateProvider(t *testing.T) { + workspace := t.TempDir() + + primaryCalls := 0 + primaryServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + primaryCalls++ + // Return 429 so FallbackChain classifies this as retriable and moves on. + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusTooManyRequests) + _ = json.NewEncoder(w).Encode(map[string]any{ + "error": map[string]any{ + "message": "rate limit exceeded", + "type": "rate_limit_error", + }, + }) + })) + defer primaryServer.Close() + + fallbackCalls := 0 + fallbackServer := newStrictChatCompletionTestServer( + t, "fallback", "gemma-3-27b-it", "fallback reply", &fallbackCalls, + ) + defer fallbackServer.Close() + + cfg := &config.Config{ + Agents: config.AgentsConfig{ + Defaults: config.AgentDefaults{ + Workspace: workspace, + ModelName: "mistral-primary", + ModelFallbacks: []string{"gemma-fallback"}, + MaxTokens: 4096, + MaxToolIterations: 3, + }, + }, + ModelList: []*config.ModelConfig{ + { + ModelName: "mistral-primary", + Model: "openrouter/mistralai/mistral-small-3.1", + APIBase: primaryServer.URL, + APIKeys: config.SimpleSecureStrings("primary-key"), + Workspace: workspace, + }, + { + ModelName: "gemma-fallback", + Model: "gemini/gemma-3-27b-it", + APIBase: fallbackServer.URL, + APIKeys: config.SimpleSecureStrings("fallback-key"), + Workspace: workspace, + }, + }, + } + + provider, _, err := providers.CreateProvider(cfg) + if err != nil { + t.Fatalf("CreateProvider() error = %v", err) + } + msgBus := bus.NewMessageBus() + al := NewAgentLoop(cfg, msgBus, provider) + helper := testHelper{al: al} + + resp := helper.executeAndGetResponse(t, context.Background(), bus.InboundMessage{ + Channel: "telegram", + SenderID: "user1", + ChatID: "chat1", + Content: "hi", + Peer: bus.Peer{Kind: "direct", ID: "user1"}, + }) + + if resp != "fallback reply" { + t.Fatalf("response = %q, want %q (fallback provider)", resp, "fallback reply") + } + if primaryCalls == 0 { + t.Fatal("primary server was never called; expected at least one attempt") + } + if fallbackCalls != 1 { + t.Fatalf("fallback server calls = %d, want 1", fallbackCalls) + } +} + +// TestProcessMessage_FallbackUsesActiveProviderWhenCandidateNotRegistered verifies +// that when a candidate has no model_list entry it is absent from CandidateProviders +// and the fallback closure falls back to activeProvider instead of panicking. +func TestProcessMessage_FallbackUsesActiveProviderWhenCandidateNotRegistered(t *testing.T) { + workspace := t.TempDir() + + // Primary server: returns 429 on first call, succeeds on second. + // Both the primary and the unregistered fallback share this server + // (same api_base) so activeProvider routes both calls here. + callCount := 0 + primaryServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + callCount++ + w.Header().Set("Content-Type", "application/json") + if callCount == 1 { + w.WriteHeader(http.StatusTooManyRequests) + _ = json.NewEncoder(w).Encode(map[string]any{ + "error": map[string]any{"message": "rate limit", "type": "rate_limit_error"}, + }) + return + } + // Second call (fallback via activeProvider) succeeds. + _ = json.NewEncoder(w).Encode(map[string]any{ + "choices": []map[string]any{ + {"message": map[string]any{"content": "active provider reply"}, "finish_reason": "stop"}, + }, + }) + })) + defer primaryServer.Close() + + cfg := &config.Config{ + Agents: config.AgentsConfig{ + Defaults: config.AgentDefaults{ + Workspace: workspace, + ModelName: "primary-model", + MaxTokens: 4096, + MaxToolIterations: 3, + // No model_list entry for this alias — absent from CandidateProviders. + ModelFallbacks: []string{"openrouter/fallback-model"}, + }, + }, + ModelList: []*config.ModelConfig{ + { + ModelName: "primary-model", + Model: "openrouter/primary-model", + APIBase: primaryServer.URL, + APIKeys: config.SimpleSecureStrings("primary-key"), + Workspace: workspace, + }, + }, + } + + provider, _, err := providers.CreateProvider(cfg) + if err != nil { + t.Fatalf("CreateProvider() error = %v", err) + } + msgBus := bus.NewMessageBus() + al := NewAgentLoop(cfg, msgBus, provider) + + helper := testHelper{al: al} + resp := helper.executeAndGetResponse(t, context.Background(), bus.InboundMessage{ + Channel: "telegram", + SenderID: "user1", + ChatID: "chat1", + Content: "hi", + Peer: bus.Peer{Kind: "direct", ID: "user1"}, + }) + + if resp != "active provider reply" { + t.Fatalf("response = %q, want %q", resp, "active provider reply") + } + if callCount < 2 { + t.Fatalf("primary server calls = %d, want >= 2 (one 429 + one success via activeProvider)", callCount) + } +} + // TestToolResult_SilentToolDoesNotSendUserMessage verifies silent tools don't trigger outbound func TestToolResult_SilentToolDoesNotSendUserMessage(t *testing.T) { tmpDir, err := os.MkdirTemp("", "agent-test-*")