fix: use per-candidate provider for model_fallbacks (#2143)

* fix: use per-candidate provider for model_fallbacks

Each fallback model now uses its own api_base and api_key from
model_list instead of inheriting the primary model's provider config.

Previously, a single LLMProvider was created from the primary model's
ModelConfig and reused for all fallback candidates — only the model ID
string was swapped. This caused all fallback requests to be routed to
the primary provider's endpoint, making cross-provider fallback chains
non-functional (e.g., OpenRouter primary with Gemini fallback would
send the Gemini request to OpenRouter's API).

Fix: pre-create a per-candidate LLMProvider at agent initialization
time by looking up each candidate's ModelConfig from model_list. The
fallback run closure now selects the correct provider per candidate
via CandidateProviders map, falling back to agent.Provider when no
override is found.

Fixes #2140

Made-with: Cursor

test: add test for instance.go

fix: fix test

refactor: optimize

fix: fix Golang lint issues

chore: comment cleanup

* refactor: use resolvedModelConfig() instead of buildModelIndex()

* fix
This commit is contained in:
corevibe555
2026-04-07 15:07:56 +03:00
committed by GitHub
parent 1fc2710999
commit 6ce0306c66
4 changed files with 402 additions and 1 deletions
+45
View File
@@ -51,6 +51,10 @@ type AgentInstance struct {
// LightProvider is the concrete provider instance for the configured light model.
// It is only used when routing selects the light tier for a turn.
LightProvider providers.LLMProvider
// CandidateProviders maps "provider/model" keys to per-candidate LLMProvider
// instances. This allows each fallback model to use its own api_base and api_key
// from model_list, instead of inheriting the primary model's provider config.
CandidateProviders map[string]providers.LLMProvider
}
// NewAgentInstance creates an agent instance from config.
@@ -175,6 +179,9 @@ func NewAgentInstance(
// Resolve fallback candidates
candidates := resolveModelCandidates(cfg, defaults.Provider, model, fallbacks)
candidateProviders := make(map[string]providers.LLMProvider)
populateCandidateProvidersFromNames(cfg, workspace, fallbacks, candidateProviders)
// Model routing setup: pre-resolve light model candidates at creation time
// to avoid repeated model_list lookups on every incoming message.
var router *routing.Router
@@ -199,6 +206,7 @@ func NewAgentInstance(
})
lightCandidates = resolved
lightProvider = lp
populateCandidateProvidersFromNames(cfg, workspace, []string{rc.LightModel}, candidateProviders)
}
}
} else {
@@ -230,6 +238,43 @@ func NewAgentInstance(
Router: router,
LightCandidates: lightCandidates,
LightProvider: lightProvider,
CandidateProviders: candidateProviders,
}
}
// populateCandidateProvidersFromNames resolves each model name (alias or
// "provider/model") via resolvedModelConfig and registers a dedicated
// LLMProvider for it in out. Going through resolvedModelConfig keeps alias
// handling consistent with the other callers of that helper.
//
// Entries are keyed by providers.ModelKey applied to the normalized protocol
// and the model ID extracted from the resolved config's Model field. A key
// that is already present in out is left untouched, so the first registration
// wins.
//
// The function is best-effort: a name that cannot be resolved, or whose
// provider cannot be created, is logged as a warning and skipped. A nil cfg,
// a nil out map, or an empty names slice makes the call a no-op.
func populateCandidateProvidersFromNames(
	cfg *config.Config,
	workspace string,
	names []string,
	out map[string]providers.LLMProvider,
) {
	// Writing to a nil map panics, and a map allocated here could not be
	// handed back to the caller, so without a destination there is nothing
	// useful to do.
	if cfg == nil || out == nil || len(names) == 0 {
		return
	}

	for _, name := range names {
		mc, err := resolvedModelConfig(cfg, strings.TrimSpace(name), workspace)
		if err != nil {
			logger.WarnCF("agent",
				"fallback provider: no model_list entry found; will inherit primary provider credentials",
				map[string]any{"name": name, "error": err.Error()})
			continue
		}

		protocol, modelID := providers.ExtractProtocol(strings.TrimSpace(mc.Model))
		key := providers.ModelKey(providers.NormalizeProvider(protocol), modelID)
		if _, exists := out[key]; exists {
			// Never replace a provider that was registered earlier.
			continue
		}

		p, _, err := providers.CreateProviderFromConfig(mc)
		if err != nil {
			logger.WarnCF("agent", "fallback provider: failed to create provider",
				map[string]any{"model": mc.Model, "error": err.Error()})
			continue
		}
		// Defensive: a nil entry would be picked up by map lookups that only
		// check the "ok" flag and then dereferenced, so never store one.
		if p == nil {
			logger.WarnCF("agent", "fallback provider: provider factory returned no provider",
				map[string]any{"model": mc.Model})
			continue
		}
		out[key] = p
	}
}
+194
View File
@@ -9,6 +9,7 @@ import (
"github.com/sipeed/picoclaw/pkg/config"
"github.com/sipeed/picoclaw/pkg/media"
"github.com/sipeed/picoclaw/pkg/providers"
)
func TestNewAgentInstance_UsesDefaultsTemperatureAndMaxTokens(t *testing.T) {
@@ -300,6 +301,199 @@ func TestNewAgentInstance_AllowsMediaTempDirForReadListAndExec(t *testing.T) {
}
}
// TestPopulateCandidateProviders_NilCfgIsNoop checks that a nil config is
// tolerated: the call must not panic and must add nothing to the output map.
func TestPopulateCandidateProviders_NilCfgIsNoop(t *testing.T) {
	registry := make(map[string]providers.LLMProvider)

	populateCandidateProvidersFromNames(nil, t.TempDir(), []string{"gpt-4o"}, registry)

	if n := len(registry); n != 0 {
		t.Fatalf("expected empty map, got %d entries", n)
	}
}
// TestPopulateCandidateProviders_SkipsExistingKeys checks that an entry which
// is already registered under a key survives a later call that resolves to
// the same key.
func TestPopulateCandidateProviders_SkipsExistingKeys(t *testing.T) {
	sentinel := &mockProvider{}
	key := providers.ModelKey("openai", "gpt-4o")
	registry := map[string]providers.LLMProvider{key: sentinel}

	entry := &config.ModelConfig{
		ModelName: "my-gpt",
		Model:     "openai/gpt-4o",
		APIKeys:   config.SimpleSecureStrings("test-key"),
	}
	cfg := &config.Config{ModelList: []*config.ModelConfig{entry}}

	populateCandidateProvidersFromNames(cfg, t.TempDir(), []string{"my-gpt"}, registry)

	// The pre-seeded provider must still be the one stored under the key.
	if got := registry[key]; got != sentinel {
		t.Fatal("existing provider entry was overwritten; expected it to be preserved")
	}
}
// TestPopulateCandidateProviders_ResolvesAlias verifies that a model_name
// alias (e.g. "my-gpt") is resolved via GetModelConfig and the provider
// is created using the underlying model's config.
func TestPopulateCandidateProviders_ResolvesAlias(t *testing.T) {
workspace := t.TempDir()
out := map[string]providers.LLMProvider{}
cfg := &config.Config{
ModelList: []*config.ModelConfig{
{ModelName: "my-gpt", Model: "openai/gpt-4o", APIBase: "https://api.openai.com/v1", Workspace: workspace},
},
}
populateCandidateProvidersFromNames(cfg, workspace, []string{"my-gpt"}, out)
key := providers.ModelKey("openai", "gpt-4o")
if out[key] == nil {
t.Fatalf("expected CandidateProviders[%q] to be populated for alias", key)
}
}
// TestPopulateCandidateProviders_ResolvesProtocolPrefix verifies that a
// model_list entry using full "provider/model" notation (e.g.
// "gemini/gemma-3-27b-it") is matched correctly when referenced by model_name.
func TestPopulateCandidateProviders_ResolvesProtocolPrefix(t *testing.T) {
workspace := t.TempDir()
out := map[string]providers.LLMProvider{}
cfg := &config.Config{
ModelList: []*config.ModelConfig{
{
ModelName: "gemma",
Model: "gemini/gemma-3-27b-it",
APIKeys: config.SimpleSecureStrings("gemini-test-key"),
Workspace: workspace,
},
},
}
populateCandidateProvidersFromNames(cfg, workspace, []string{"gemma"}, out)
key := providers.ModelKey("gemini", "gemma-3-27b-it")
if out[key] == nil {
t.Fatalf("expected CandidateProviders[%q] to be populated for protocol-prefixed model", key)
}
}
// TestPopulateCandidateProviders_EmptyNamesIsNoop verifies the early-exit
// path when the names slice is empty.
func TestPopulateCandidateProviders_EmptyNamesIsNoop(t *testing.T) {
out := map[string]providers.LLMProvider{}
cfg := &config.Config{
ModelList: []*config.ModelConfig{
{ModelName: "my-gpt", Model: "openai/gpt-4o", APIKeys: config.SimpleSecureStrings("key")},
},
}
populateCandidateProvidersFromNames(cfg, t.TempDir(), nil, out)
if len(out) != 0 {
t.Fatalf("expected empty map, got %d entries", len(out))
}
}
// TestPopulateCandidateProviders_EmptyModelListIsNoop verifies the early-exit
// path when model_list is empty — no provider can be created.
func TestPopulateCandidateProviders_EmptyModelListIsNoop(t *testing.T) {
out := map[string]providers.LLMProvider{}
cfg := &config.Config{}
populateCandidateProvidersFromNames(cfg, t.TempDir(), []string{"gpt-4o"}, out)
if len(out) != 0 {
t.Fatalf("expected empty map, got %d entries", len(out))
}
}
// TestPopulateCandidateProviders_UnmatchedNameIsSkipped verifies that a
// name with no matching model_list entry is skipped and does not
// cause a panic or leave a nil entry in the map.
func TestPopulateCandidateProviders_UnmatchedNameIsSkipped(t *testing.T) {
out := map[string]providers.LLMProvider{}
cfg := &config.Config{
ModelList: []*config.ModelConfig{
{ModelName: "my-gpt", Model: "openai/gpt-4o", APIKeys: config.SimpleSecureStrings("key")},
},
}
populateCandidateProvidersFromNames(cfg, t.TempDir(), []string{"nonexistent-model"}, out)
if len(out) != 0 {
t.Fatalf("expected empty map for unmatched name, got %d entries", len(out))
}
}
// TestNewAgentInstance_CandidateProvidersPopulatedForCrossProviderFallbacks
// reproduces the configuration from bug #2140: an OpenRouter primary model
// with two Gemini fallbacks. Every fallback must be registered in
// CandidateProviders with a provider instance of its own, distinct from the
// primary provider injected into NewAgentInstance.
func TestNewAgentInstance_CandidateProvidersPopulatedForCrossProviderFallbacks(t *testing.T) {
	ws := t.TempDir()

	modelList := []*config.ModelConfig{
		{
			ModelName: "mistral-small-3.1",
			Model:     "openrouter/mistralai/mistral-small-3.1-24b-instruct:free",
			APIBase:   "https://openrouter.ai/api/v1",
			APIKeys:   config.SimpleSecureStrings("sk-or-test"),
			Workspace: ws,
		},
		{
			ModelName: "gemma-3-27b",
			Model:     "gemini/gemma-3-27b-it",
			APIKeys:   config.SimpleSecureStrings("AIzaSy-test"),
			Workspace: ws,
		},
		{
			ModelName: "gemini-images",
			Model:     "gemini/gemini-2.5-flash-lite",
			APIKeys:   config.SimpleSecureStrings("AIzaSy-test"),
			Workspace: ws,
		},
	}
	cfg := &config.Config{
		Agents: config.AgentsConfig{
			Defaults: config.AgentDefaults{
				Workspace:      ws,
				ModelName:      "mistral-small-3.1",
				ModelFallbacks: []string{"gemma-3-27b", "gemini-images"},
			},
		},
		ModelList: modelList,
	}

	primary := &mockProvider{}
	agent := NewAgentInstance(nil, &cfg.Agents.Defaults, cfg, primary)

	// Only the fallbacks are expected here; the primary model is served by
	// the injected provider directly.
	for _, key := range []string{
		providers.ModelKey("gemini", "gemma-3-27b-it"),
		providers.ModelKey("gemini", "gemini-2.5-flash-lite"),
	} {
		got, found := agent.CandidateProviders[key]
		if !found {
			t.Errorf("CandidateProviders missing key %q", key)
			continue
		}
		if got == nil {
			t.Errorf("CandidateProviders[%q] is nil", key)
		}
		// A fallback that aliases the primary instance would reuse the
		// primary's credentials and endpoint.
		if got == primary {
			t.Errorf(
				"CandidateProviders[%q] is the same instance as the primary provider; fallback would inherit primary credentials",
				key,
			)
		}
	}

	if !t.Failed() {
		return
	}

	// On failure, dump the keys that were registered to ease diagnosis.
	present := make([]string, 0, len(agent.CandidateProviders))
	for k := range agent.CandidateProviders {
		present = append(present, k)
	}
	t.Logf("CandidateProviders keys present: %v", present)
}
func TestNewAgentInstance_ReadFileModeSelectsSchema(t *testing.T) {
workspace := t.TempDir()
+5 -1
View File
@@ -2020,7 +2020,11 @@ turnLoop:
providerCtx,
activeCandidates,
func(ctx context.Context, provider, model string) (*providers.LLMResponse, error) {
return activeProvider.Chat(ctx, messagesForCall, toolDefsForCall, model, llmOpts)
candidateProvider := activeProvider
if cp, ok := ts.agent.CandidateProviders[providers.ModelKey(provider, model)]; ok {
candidateProvider = cp
}
return candidateProvider.Chat(ctx, messagesForCall, toolDefsForCall, model, llmOpts)
},
)
if fbErr != nil {
+158
View File
@@ -1839,6 +1839,164 @@ func TestProcessMessage_ModelRoutingUsesLightProvider(t *testing.T) {
}
}
// TestProcessMessage_FallbackUsesPerCandidateProvider is the loop-level
// regression test for bug #2140. The primary endpoint always answers with a
// rate-limit error; the test then expects the reply to come from the fallback
// model's own endpoint (its own api_base) rather than from the primary's.
func TestProcessMessage_FallbackUsesPerCandidateProvider(t *testing.T) {
	ws := t.TempDir()

	// Primary endpoint: count the hit and answer 429 so the fallback chain is
	// expected to treat the failure as retriable and move to the next candidate.
	var primaryHits int
	rateLimited := func(w http.ResponseWriter, _ *http.Request) {
		primaryHits++
		w.Header().Set("Content-Type", "application/json")
		w.WriteHeader(http.StatusTooManyRequests)
		payload := map[string]any{
			"error": map[string]any{
				"message": "rate limit exceeded",
				"type":    "rate_limit_error",
			},
		}
		_ = json.NewEncoder(w).Encode(payload)
	}
	primaryServer := httptest.NewServer(http.HandlerFunc(rateLimited))
	defer primaryServer.Close()

	// Fallback endpoint: a separate server that records its own call count.
	var fallbackHits int
	fallbackServer := newStrictChatCompletionTestServer(
		t, "fallback", "gemma-3-27b-it", "fallback reply", &fallbackHits,
	)
	defer fallbackServer.Close()

	primaryModel := &config.ModelConfig{
		ModelName: "mistral-primary",
		Model:     "openrouter/mistralai/mistral-small-3.1",
		APIBase:   primaryServer.URL,
		APIKeys:   config.SimpleSecureStrings("primary-key"),
		Workspace: ws,
	}
	fallbackModel := &config.ModelConfig{
		ModelName: "gemma-fallback",
		Model:     "gemini/gemma-3-27b-it",
		APIBase:   fallbackServer.URL,
		APIKeys:   config.SimpleSecureStrings("fallback-key"),
		Workspace: ws,
	}
	cfg := &config.Config{
		Agents: config.AgentsConfig{
			Defaults: config.AgentDefaults{
				Workspace:         ws,
				ModelName:         "mistral-primary",
				ModelFallbacks:    []string{"gemma-fallback"},
				MaxTokens:         4096,
				MaxToolIterations: 3,
			},
		},
		ModelList: []*config.ModelConfig{primaryModel, fallbackModel},
	}

	primaryProvider, _, err := providers.CreateProvider(cfg)
	if err != nil {
		t.Fatalf("CreateProvider() error = %v", err)
	}

	loop := NewAgentLoop(cfg, bus.NewMessageBus(), primaryProvider)
	helper := testHelper{al: loop}

	inbound := bus.InboundMessage{
		Channel:  "telegram",
		SenderID: "user1",
		ChatID:   "chat1",
		Content:  "hi",
		Peer:     bus.Peer{Kind: "direct", ID: "user1"},
	}
	got := helper.executeAndGetResponse(t, context.Background(), inbound)

	if got != "fallback reply" {
		t.Fatalf("response = %q, want %q (fallback provider)", got, "fallback reply")
	}
	if primaryHits == 0 {
		t.Fatal("primary server was never called; expected at least one attempt")
	}
	if fallbackHits != 1 {
		t.Fatalf("fallback server calls = %d, want 1", fallbackHits)
	}
}
// TestProcessMessage_FallbackUsesActiveProviderWhenCandidateNotRegistered
// covers a fallback candidate that has no model_list entry. Such a candidate
// is expected to be missing from CandidateProviders, so the fallback request
// should be sent through the active (primary) provider instead of panicking.
func TestProcessMessage_FallbackUsesActiveProviderWhenCandidateNotRegistered(t *testing.T) {
	ws := t.TempDir()

	// A single server stands in for both attempts: the first request gets a
	// 429, every later one succeeds. Because the unregistered fallback is
	// expected to reuse the active provider, its request lands here as well.
	var hits int
	handler := func(w http.ResponseWriter, _ *http.Request) {
		hits++
		w.Header().Set("Content-Type", "application/json")
		if hits != 1 {
			// Second and later calls (fallback via the active provider) succeed.
			_ = json.NewEncoder(w).Encode(map[string]any{
				"choices": []map[string]any{
					{"message": map[string]any{"content": "active provider reply"}, "finish_reason": "stop"},
				},
			})
			return
		}
		w.WriteHeader(http.StatusTooManyRequests)
		_ = json.NewEncoder(w).Encode(map[string]any{
			"error": map[string]any{"message": "rate limit", "type": "rate_limit_error"},
		})
	}
	primaryServer := httptest.NewServer(http.HandlerFunc(handler))
	defer primaryServer.Close()

	defaults := config.AgentDefaults{
		Workspace:         ws,
		ModelName:         "primary-model",
		MaxTokens:         4096,
		MaxToolIterations: 3,
		// This fallback has no model_list entry, so it is not expected to
		// appear in CandidateProviders.
		ModelFallbacks: []string{"openrouter/fallback-model"},
	}
	cfg := &config.Config{
		Agents: config.AgentsConfig{Defaults: defaults},
		ModelList: []*config.ModelConfig{{
			ModelName: "primary-model",
			Model:     "openrouter/primary-model",
			APIBase:   primaryServer.URL,
			APIKeys:   config.SimpleSecureStrings("primary-key"),
			Workspace: ws,
		}},
	}

	primaryProvider, _, err := providers.CreateProvider(cfg)
	if err != nil {
		t.Fatalf("CreateProvider() error = %v", err)
	}

	loop := NewAgentLoop(cfg, bus.NewMessageBus(), primaryProvider)
	helper := testHelper{al: loop}

	inbound := bus.InboundMessage{
		Channel:  "telegram",
		SenderID: "user1",
		ChatID:   "chat1",
		Content:  "hi",
		Peer:     bus.Peer{Kind: "direct", ID: "user1"},
	}
	got := helper.executeAndGetResponse(t, context.Background(), inbound)

	if got != "active provider reply" {
		t.Fatalf("response = %q, want %q", got, "active provider reply")
	}
	if hits < 2 {
		t.Fatalf("primary server calls = %d, want >= 2 (one 429 + one success via activeProvider)", hits)
	}
}
// TestToolResult_SilentToolDoesNotSendUserMessage verifies silent tools don't trigger outbound
func TestToolResult_SilentToolDoesNotSendUserMessage(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "agent-test-*")