mirror of
https://github.com/sipeed/picoclaw.git
synced 2026-06-12 18:08:54 +00:00
fix: use per candidate provider for model_fallbacks (#2143)
* fix: use per-candidate provider for model_fallbacks Each fallback model now uses its own api_base and api_key from model_list instead of inheriting the primary model's provider config. Previously, a single LLMProvider was created from the primary model's ModelConfig and reused for all fallback candidates — only the model ID string was swapped. This caused all fallback requests to be routed to the primary provider's endpoint, making cross-provider fallback chains non-functional (e.g., OpenRouter primary with Gemini fallback would send the Gemini request to OpenRouter's API). Fix: pre-create a per-candidate LLMProvider at agent initialization time by looking up each candidate's ModelConfig from model_list. The fallback run closure now selects the correct provider per candidate via CandidateProviders map, falling back to agent.Provider when no override is found. Fixes #2140 Made-with: Cursor test: add test for instance.go fix: fix test refactor: optimize fix: fix Golang lint issues chore: comment cleanup * refactor: use resolvedModelConfig() instead of buildModelIndex() * fix
This commit is contained in:
@@ -51,6 +51,10 @@ type AgentInstance struct {
|
||||
// LightProvider is the concrete provider instance for the configured light model.
|
||||
// It is only used when routing selects the light tier for a turn.
|
||||
LightProvider providers.LLMProvider
|
||||
// CandidateProviders maps "provider/model" keys to per-candidate LLMProvider
|
||||
// instances. This allows each fallback model to use its own api_base and api_key
|
||||
// from model_list, instead of inheriting the primary model's provider config.
|
||||
CandidateProviders map[string]providers.LLMProvider
|
||||
}
|
||||
|
||||
// NewAgentInstance creates an agent instance from config.
|
||||
@@ -175,6 +179,9 @@ func NewAgentInstance(
|
||||
// Resolve fallback candidates
|
||||
candidates := resolveModelCandidates(cfg, defaults.Provider, model, fallbacks)
|
||||
|
||||
candidateProviders := make(map[string]providers.LLMProvider)
|
||||
populateCandidateProvidersFromNames(cfg, workspace, fallbacks, candidateProviders)
|
||||
|
||||
// Model routing setup: pre-resolve light model candidates at creation time
|
||||
// to avoid repeated model_list lookups on every incoming message.
|
||||
var router *routing.Router
|
||||
@@ -199,6 +206,7 @@ func NewAgentInstance(
|
||||
})
|
||||
lightCandidates = resolved
|
||||
lightProvider = lp
|
||||
populateCandidateProvidersFromNames(cfg, workspace, []string{rc.LightModel}, candidateProviders)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
@@ -230,6 +238,43 @@ func NewAgentInstance(
|
||||
Router: router,
|
||||
LightCandidates: lightCandidates,
|
||||
LightProvider: lightProvider,
|
||||
CandidateProviders: candidateProviders,
|
||||
}
|
||||
}
|
||||
|
||||
// populateCandidateProvidersFromNames resolves each model name (alias or
|
||||
// "provider/model") via resolvedModelConfig and creates a dedicated LLMProvider
|
||||
// for it. This reuses the canonical config resolution path (GetModelConfig) so
|
||||
// alias handling and load-balancing stay consistent with the rest of the codebase.
|
||||
func populateCandidateProvidersFromNames(
|
||||
cfg *config.Config,
|
||||
workspace string,
|
||||
names []string,
|
||||
out map[string]providers.LLMProvider,
|
||||
) {
|
||||
if cfg == nil || len(names) == 0 {
|
||||
return
|
||||
}
|
||||
for _, name := range names {
|
||||
mc, err := resolvedModelConfig(cfg, strings.TrimSpace(name), workspace)
|
||||
if err != nil {
|
||||
logger.WarnCF("agent",
|
||||
"fallback provider: no model_list entry found; will inherit primary provider credentials",
|
||||
map[string]any{"name": name, "error": err.Error()})
|
||||
continue
|
||||
}
|
||||
protocol, modelID := providers.ExtractProtocol(strings.TrimSpace(mc.Model))
|
||||
key := providers.ModelKey(providers.NormalizeProvider(protocol), modelID)
|
||||
if _, exists := out[key]; exists {
|
||||
continue
|
||||
}
|
||||
p, _, err := providers.CreateProviderFromConfig(mc)
|
||||
if err != nil {
|
||||
logger.WarnCF("agent", "fallback provider: failed to create provider",
|
||||
map[string]any{"model": mc.Model, "error": err.Error()})
|
||||
continue
|
||||
}
|
||||
out[key] = p
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
"github.com/sipeed/picoclaw/pkg/media"
|
||||
"github.com/sipeed/picoclaw/pkg/providers"
|
||||
)
|
||||
|
||||
func TestNewAgentInstance_UsesDefaultsTemperatureAndMaxTokens(t *testing.T) {
|
||||
@@ -300,6 +301,199 @@ func TestNewAgentInstance_AllowsMediaTempDirForReadListAndExec(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// TestPopulateCandidateProviders_NilCfgIsNoop verifies that passing a nil
|
||||
// config does not panic and leaves the output map empty.
|
||||
func TestPopulateCandidateProviders_NilCfgIsNoop(t *testing.T) {
|
||||
out := map[string]providers.LLMProvider{}
|
||||
populateCandidateProvidersFromNames(nil, t.TempDir(), []string{"gpt-4o"}, out)
|
||||
if len(out) != 0 {
|
||||
t.Fatalf("expected empty map, got %d entries", len(out))
|
||||
}
|
||||
}
|
||||
|
||||
// TestPopulateCandidateProviders_SkipsExistingKeys verifies that a key already
|
||||
// present in the output map is not overwritten.
|
||||
func TestPopulateCandidateProviders_SkipsExistingKeys(t *testing.T) {
|
||||
existing := &mockProvider{}
|
||||
key := providers.ModelKey("openai", "gpt-4o")
|
||||
out := map[string]providers.LLMProvider{key: existing}
|
||||
|
||||
cfg := &config.Config{
|
||||
ModelList: []*config.ModelConfig{
|
||||
{ModelName: "my-gpt", Model: "openai/gpt-4o", APIKeys: config.SimpleSecureStrings("test-key")},
|
||||
},
|
||||
}
|
||||
populateCandidateProvidersFromNames(cfg, t.TempDir(), []string{"my-gpt"}, out)
|
||||
|
||||
if out[key] != existing {
|
||||
t.Fatal("existing provider entry was overwritten; expected it to be preserved")
|
||||
}
|
||||
}
|
||||
|
||||
// TestPopulateCandidateProviders_ResolvesAlias verifies that a model_name
|
||||
// alias (e.g. "my-gpt") is resolved via GetModelConfig and the provider
|
||||
// is created using the underlying model's config.
|
||||
func TestPopulateCandidateProviders_ResolvesAlias(t *testing.T) {
|
||||
workspace := t.TempDir()
|
||||
out := map[string]providers.LLMProvider{}
|
||||
|
||||
cfg := &config.Config{
|
||||
ModelList: []*config.ModelConfig{
|
||||
{ModelName: "my-gpt", Model: "openai/gpt-4o", APIBase: "https://api.openai.com/v1", Workspace: workspace},
|
||||
},
|
||||
}
|
||||
populateCandidateProvidersFromNames(cfg, workspace, []string{"my-gpt"}, out)
|
||||
|
||||
key := providers.ModelKey("openai", "gpt-4o")
|
||||
if out[key] == nil {
|
||||
t.Fatalf("expected CandidateProviders[%q] to be populated for alias", key)
|
||||
}
|
||||
}
|
||||
|
||||
// TestPopulateCandidateProviders_ResolvesProtocolPrefix verifies that a
|
||||
// model_list entry using full "provider/model" notation (e.g.
|
||||
// "gemini/gemma-3-27b-it") is matched correctly when referenced by model_name.
|
||||
func TestPopulateCandidateProviders_ResolvesProtocolPrefix(t *testing.T) {
|
||||
workspace := t.TempDir()
|
||||
out := map[string]providers.LLMProvider{}
|
||||
|
||||
cfg := &config.Config{
|
||||
ModelList: []*config.ModelConfig{
|
||||
{
|
||||
ModelName: "gemma",
|
||||
Model: "gemini/gemma-3-27b-it",
|
||||
APIKeys: config.SimpleSecureStrings("gemini-test-key"),
|
||||
Workspace: workspace,
|
||||
},
|
||||
},
|
||||
}
|
||||
populateCandidateProvidersFromNames(cfg, workspace, []string{"gemma"}, out)
|
||||
|
||||
key := providers.ModelKey("gemini", "gemma-3-27b-it")
|
||||
if out[key] == nil {
|
||||
t.Fatalf("expected CandidateProviders[%q] to be populated for protocol-prefixed model", key)
|
||||
}
|
||||
}
|
||||
|
||||
// TestPopulateCandidateProviders_EmptyNamesIsNoop verifies the early-exit
|
||||
// path when the names slice is empty.
|
||||
func TestPopulateCandidateProviders_EmptyNamesIsNoop(t *testing.T) {
|
||||
out := map[string]providers.LLMProvider{}
|
||||
cfg := &config.Config{
|
||||
ModelList: []*config.ModelConfig{
|
||||
{ModelName: "my-gpt", Model: "openai/gpt-4o", APIKeys: config.SimpleSecureStrings("key")},
|
||||
},
|
||||
}
|
||||
populateCandidateProvidersFromNames(cfg, t.TempDir(), nil, out)
|
||||
if len(out) != 0 {
|
||||
t.Fatalf("expected empty map, got %d entries", len(out))
|
||||
}
|
||||
}
|
||||
|
||||
// TestPopulateCandidateProviders_EmptyModelListIsNoop verifies the early-exit
|
||||
// path when model_list is empty — no provider can be created.
|
||||
func TestPopulateCandidateProviders_EmptyModelListIsNoop(t *testing.T) {
|
||||
out := map[string]providers.LLMProvider{}
|
||||
cfg := &config.Config{}
|
||||
populateCandidateProvidersFromNames(cfg, t.TempDir(), []string{"gpt-4o"}, out)
|
||||
if len(out) != 0 {
|
||||
t.Fatalf("expected empty map, got %d entries", len(out))
|
||||
}
|
||||
}
|
||||
|
||||
// TestPopulateCandidateProviders_UnmatchedNameIsSkipped verifies that a
|
||||
// name with no matching model_list entry is skipped and does not
|
||||
// cause a panic or leave a nil entry in the map.
|
||||
func TestPopulateCandidateProviders_UnmatchedNameIsSkipped(t *testing.T) {
|
||||
out := map[string]providers.LLMProvider{}
|
||||
cfg := &config.Config{
|
||||
ModelList: []*config.ModelConfig{
|
||||
{ModelName: "my-gpt", Model: "openai/gpt-4o", APIKeys: config.SimpleSecureStrings("key")},
|
||||
},
|
||||
}
|
||||
populateCandidateProvidersFromNames(cfg, t.TempDir(), []string{"nonexistent-model"}, out)
|
||||
|
||||
if len(out) != 0 {
|
||||
t.Fatalf("expected empty map for unmatched name, got %d entries", len(out))
|
||||
}
|
||||
}
|
||||
|
||||
// TestNewAgentInstance_CandidateProvidersPopulatedForCrossProviderFallbacks
|
||||
// mirrors the exact scenario from bug #2140: primary model on OpenRouter with
|
||||
// Gemini fallbacks. Each entry must get its own provider instance so that
|
||||
// fallback requests go to the correct API endpoint, not the primary's.
|
||||
func TestNewAgentInstance_CandidateProvidersPopulatedForCrossProviderFallbacks(t *testing.T) {
|
||||
workspace := t.TempDir()
|
||||
|
||||
cfg := &config.Config{
|
||||
Agents: config.AgentsConfig{
|
||||
Defaults: config.AgentDefaults{
|
||||
Workspace: workspace,
|
||||
ModelName: "mistral-small-3.1",
|
||||
ModelFallbacks: []string{"gemma-3-27b", "gemini-images"},
|
||||
},
|
||||
},
|
||||
ModelList: []*config.ModelConfig{
|
||||
{
|
||||
ModelName: "mistral-small-3.1",
|
||||
Model: "openrouter/mistralai/mistral-small-3.1-24b-instruct:free",
|
||||
APIBase: "https://openrouter.ai/api/v1",
|
||||
APIKeys: config.SimpleSecureStrings("sk-or-test"),
|
||||
Workspace: workspace,
|
||||
},
|
||||
{
|
||||
ModelName: "gemma-3-27b",
|
||||
Model: "gemini/gemma-3-27b-it",
|
||||
APIKeys: config.SimpleSecureStrings("AIzaSy-test"),
|
||||
Workspace: workspace,
|
||||
},
|
||||
{
|
||||
ModelName: "gemini-images",
|
||||
Model: "gemini/gemini-2.5-flash-lite",
|
||||
APIKeys: config.SimpleSecureStrings("AIzaSy-test"),
|
||||
Workspace: workspace,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
primaryProvider := &mockProvider{}
|
||||
agent := NewAgentInstance(nil, &cfg.Agents.Defaults, cfg, primaryProvider)
|
||||
|
||||
// Only fallback models need entries — the primary uses the injected provider directly.
|
||||
wantKeys := []string{
|
||||
providers.ModelKey("gemini", "gemma-3-27b-it"),
|
||||
providers.ModelKey("gemini", "gemini-2.5-flash-lite"),
|
||||
}
|
||||
|
||||
for _, key := range wantKeys {
|
||||
p, ok := agent.CandidateProviders[key]
|
||||
if !ok {
|
||||
t.Errorf("CandidateProviders missing key %q", key)
|
||||
continue
|
||||
}
|
||||
if p == nil {
|
||||
t.Errorf("CandidateProviders[%q] is nil", key)
|
||||
}
|
||||
// Each fallback must use its own provider, not the injected primary.
|
||||
if p == primaryProvider {
|
||||
t.Errorf(
|
||||
"CandidateProviders[%q] is the same instance as the primary provider; fallback would inherit primary credentials",
|
||||
key,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
if t.Failed() {
|
||||
t.Logf("CandidateProviders keys present: %v", func() []string {
|
||||
keys := make([]string, 0, len(agent.CandidateProviders))
|
||||
for k := range agent.CandidateProviders {
|
||||
keys = append(keys, k)
|
||||
}
|
||||
return keys
|
||||
}())
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewAgentInstance_ReadFileModeSelectsSchema(t *testing.T) {
|
||||
workspace := t.TempDir()
|
||||
|
||||
|
||||
+5
-1
@@ -2020,7 +2020,11 @@ turnLoop:
|
||||
providerCtx,
|
||||
activeCandidates,
|
||||
func(ctx context.Context, provider, model string) (*providers.LLMResponse, error) {
|
||||
return activeProvider.Chat(ctx, messagesForCall, toolDefsForCall, model, llmOpts)
|
||||
candidateProvider := activeProvider
|
||||
if cp, ok := ts.agent.CandidateProviders[providers.ModelKey(provider, model)]; ok {
|
||||
candidateProvider = cp
|
||||
}
|
||||
return candidateProvider.Chat(ctx, messagesForCall, toolDefsForCall, model, llmOpts)
|
||||
},
|
||||
)
|
||||
if fbErr != nil {
|
||||
|
||||
@@ -1839,6 +1839,164 @@ func TestProcessMessage_ModelRoutingUsesLightProvider(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// TestProcessMessage_FallbackUsesPerCandidateProvider is the loop-level test for
|
||||
// bug #2140. It verifies that when the primary model returns a rate-limit error
|
||||
// the fallback closure routes the retry to the fallback model's own provider
|
||||
// (its own api_base), not back to the primary provider's endpoint.
|
||||
func TestProcessMessage_FallbackUsesPerCandidateProvider(t *testing.T) {
|
||||
workspace := t.TempDir()
|
||||
|
||||
primaryCalls := 0
|
||||
primaryServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
primaryCalls++
|
||||
// Return 429 so FallbackChain classifies this as retriable and moves on.
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusTooManyRequests)
|
||||
_ = json.NewEncoder(w).Encode(map[string]any{
|
||||
"error": map[string]any{
|
||||
"message": "rate limit exceeded",
|
||||
"type": "rate_limit_error",
|
||||
},
|
||||
})
|
||||
}))
|
||||
defer primaryServer.Close()
|
||||
|
||||
fallbackCalls := 0
|
||||
fallbackServer := newStrictChatCompletionTestServer(
|
||||
t, "fallback", "gemma-3-27b-it", "fallback reply", &fallbackCalls,
|
||||
)
|
||||
defer fallbackServer.Close()
|
||||
|
||||
cfg := &config.Config{
|
||||
Agents: config.AgentsConfig{
|
||||
Defaults: config.AgentDefaults{
|
||||
Workspace: workspace,
|
||||
ModelName: "mistral-primary",
|
||||
ModelFallbacks: []string{"gemma-fallback"},
|
||||
MaxTokens: 4096,
|
||||
MaxToolIterations: 3,
|
||||
},
|
||||
},
|
||||
ModelList: []*config.ModelConfig{
|
||||
{
|
||||
ModelName: "mistral-primary",
|
||||
Model: "openrouter/mistralai/mistral-small-3.1",
|
||||
APIBase: primaryServer.URL,
|
||||
APIKeys: config.SimpleSecureStrings("primary-key"),
|
||||
Workspace: workspace,
|
||||
},
|
||||
{
|
||||
ModelName: "gemma-fallback",
|
||||
Model: "gemini/gemma-3-27b-it",
|
||||
APIBase: fallbackServer.URL,
|
||||
APIKeys: config.SimpleSecureStrings("fallback-key"),
|
||||
Workspace: workspace,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
provider, _, err := providers.CreateProvider(cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("CreateProvider() error = %v", err)
|
||||
}
|
||||
msgBus := bus.NewMessageBus()
|
||||
al := NewAgentLoop(cfg, msgBus, provider)
|
||||
helper := testHelper{al: al}
|
||||
|
||||
resp := helper.executeAndGetResponse(t, context.Background(), bus.InboundMessage{
|
||||
Channel: "telegram",
|
||||
SenderID: "user1",
|
||||
ChatID: "chat1",
|
||||
Content: "hi",
|
||||
Peer: bus.Peer{Kind: "direct", ID: "user1"},
|
||||
})
|
||||
|
||||
if resp != "fallback reply" {
|
||||
t.Fatalf("response = %q, want %q (fallback provider)", resp, "fallback reply")
|
||||
}
|
||||
if primaryCalls == 0 {
|
||||
t.Fatal("primary server was never called; expected at least one attempt")
|
||||
}
|
||||
if fallbackCalls != 1 {
|
||||
t.Fatalf("fallback server calls = %d, want 1", fallbackCalls)
|
||||
}
|
||||
}
|
||||
|
||||
// TestProcessMessage_FallbackUsesActiveProviderWhenCandidateNotRegistered verifies
|
||||
// that when a candidate has no model_list entry it is absent from CandidateProviders
|
||||
// and the fallback closure falls back to activeProvider instead of panicking.
|
||||
func TestProcessMessage_FallbackUsesActiveProviderWhenCandidateNotRegistered(t *testing.T) {
|
||||
workspace := t.TempDir()
|
||||
|
||||
// Primary server: returns 429 on first call, succeeds on second.
|
||||
// Both the primary and the unregistered fallback share this server
|
||||
// (same api_base) so activeProvider routes both calls here.
|
||||
callCount := 0
|
||||
primaryServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
callCount++
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
if callCount == 1 {
|
||||
w.WriteHeader(http.StatusTooManyRequests)
|
||||
_ = json.NewEncoder(w).Encode(map[string]any{
|
||||
"error": map[string]any{"message": "rate limit", "type": "rate_limit_error"},
|
||||
})
|
||||
return
|
||||
}
|
||||
// Second call (fallback via activeProvider) succeeds.
|
||||
_ = json.NewEncoder(w).Encode(map[string]any{
|
||||
"choices": []map[string]any{
|
||||
{"message": map[string]any{"content": "active provider reply"}, "finish_reason": "stop"},
|
||||
},
|
||||
})
|
||||
}))
|
||||
defer primaryServer.Close()
|
||||
|
||||
cfg := &config.Config{
|
||||
Agents: config.AgentsConfig{
|
||||
Defaults: config.AgentDefaults{
|
||||
Workspace: workspace,
|
||||
ModelName: "primary-model",
|
||||
MaxTokens: 4096,
|
||||
MaxToolIterations: 3,
|
||||
// No model_list entry for this alias — absent from CandidateProviders.
|
||||
ModelFallbacks: []string{"openrouter/fallback-model"},
|
||||
},
|
||||
},
|
||||
ModelList: []*config.ModelConfig{
|
||||
{
|
||||
ModelName: "primary-model",
|
||||
Model: "openrouter/primary-model",
|
||||
APIBase: primaryServer.URL,
|
||||
APIKeys: config.SimpleSecureStrings("primary-key"),
|
||||
Workspace: workspace,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
provider, _, err := providers.CreateProvider(cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("CreateProvider() error = %v", err)
|
||||
}
|
||||
msgBus := bus.NewMessageBus()
|
||||
al := NewAgentLoop(cfg, msgBus, provider)
|
||||
|
||||
helper := testHelper{al: al}
|
||||
resp := helper.executeAndGetResponse(t, context.Background(), bus.InboundMessage{
|
||||
Channel: "telegram",
|
||||
SenderID: "user1",
|
||||
ChatID: "chat1",
|
||||
Content: "hi",
|
||||
Peer: bus.Peer{Kind: "direct", ID: "user1"},
|
||||
})
|
||||
|
||||
if resp != "active provider reply" {
|
||||
t.Fatalf("response = %q, want %q", resp, "active provider reply")
|
||||
}
|
||||
if callCount < 2 {
|
||||
t.Fatalf("primary server calls = %d, want >= 2 (one 429 + one success via activeProvider)", callCount)
|
||||
}
|
||||
}
|
||||
|
||||
// TestToolResult_SilentToolDoesNotSendUserMessage verifies silent tools don't trigger outbound
|
||||
func TestToolResult_SilentToolDoesNotSendUserMessage(t *testing.T) {
|
||||
tmpDir, err := os.MkdirTemp("", "agent-test-*")
|
||||
|
||||
Reference in New Issue
Block a user