From f93d2b453325f8529805c75f83717dbdd318b4be Mon Sep 17 00:00:00 2001 From: linhaolin1 Date: Thu, 19 Mar 2026 00:10:26 +0800 Subject: [PATCH 1/4] fix: Avoid failure of the main agent process due to tool call failures (#1023) * Avoid failure of the main agent process due to tool call failures or abnormal returns * rename recover --- pkg/tools/registry.go | 49 +++++++++-- pkg/tools/registry_test.go | 173 +++++++++++++++++++++++++++++++++++++ pkg/tools/shell.go | 19 +++- pkg/tools/shell_test.go | 63 ++++++++++++++ 4 files changed, 295 insertions(+), 9 deletions(-) diff --git a/pkg/tools/registry.go b/pkg/tools/registry.go index 0635f47d7..60effc292 100644 --- a/pkg/tools/registry.go +++ b/pkg/tools/registry.go @@ -188,15 +188,48 @@ func (r *ToolRegistry) ExecuteWithContext( // The callback is a call parameter, not mutable state on the tool instance. var result *ToolResult start := time.Now() - if asyncExec, ok := tool.(AsyncExecutor); ok && asyncCallback != nil { - logger.DebugCF("tool", "Executing async tool via ExecuteAsync", - map[string]any{ - "tool": name, - }) - result = asyncExec.ExecuteAsync(ctx, args, asyncCallback) - } else { - result = tool.Execute(ctx, args) + + // Use recover to catch any panics during tool execution + // This prevents tool crashes from killing the entire agent + func() { + defer func() { + if re := recover(); re != nil { + errMsg := fmt.Sprintf("Tool '%s' crashed with panic: %v", name, re) + logger.ErrorCF("tool", "Tool execution panic recovered", + map[string]any{ + "tool": name, + "panic": fmt.Sprintf("%v", re), + }) + result = &ToolResult{ + ForLLM: errMsg, + ForUser: errMsg, + IsError: true, + Err: fmt.Errorf("panic: %v", re), + } + } + }() + + if asyncExec, ok := tool.(AsyncExecutor); ok && asyncCallback != nil { + logger.DebugCF("tool", "Executing async tool via ExecuteAsync", + map[string]any{ + "tool": name, + }) + result = asyncExec.ExecuteAsync(ctx, args, asyncCallback) + } else { + result = tool.Execute(ctx, args) + } + }() + + // Handle nil result (should not happen, but defensive) + if result == nil { + result = &ToolResult{ + ForLLM: fmt.Sprintf("Tool '%s' returned nil result unexpectedly", name), + ForUser: fmt.Sprintf("Tool '%s' returned nil result unexpectedly", name), + IsError: true, + Err: fmt.Errorf("nil result from tool"), + } } + duration := time.Since(start) // Log based on result type diff --git a/pkg/tools/registry_test.go b/pkg/tools/registry_test.go index 92d7d5abd..5fe681389 100644 --- a/pkg/tools/registry_test.go +++ b/pkg/tools/registry_test.go @@ -2,6 +2,7 @@ package tools import ( "context" + "errors" "strings" "sync" "testing" @@ -358,3 +359,175 @@ func TestToolRegistry_ConcurrentAccess(t *testing.T) { t.Error("expected tools to be registered after concurrent access") } } + +// --- Panic and abnormal exit tests --- + +// mockPanicTool is a tool that panics during execution +type mockPanicTool struct { + name string + panicValue any +} + +func (m *mockPanicTool) Name() string { return m.name } +func (m *mockPanicTool) Description() string { return "a tool that panics" } +func (m *mockPanicTool) Parameters() map[string]any { return map[string]any{"type": "object"} } +func (m *mockPanicTool) Execute(_ context.Context, _ map[string]any) *ToolResult { + panic(m.panicValue) +} + +// mockNilResultTool is a tool that returns nil +type mockNilResultTool struct { + name string +} + +func (m *mockNilResultTool) Name() string { return m.name } +func (m *mockNilResultTool) Description() string { return "a tool that returns nil" } +func (m *mockNilResultTool) Parameters() map[string]any { return map[string]any{"type": "object"} } +func (m *mockNilResultTool) Execute(_ context.Context, _ map[string]any) *ToolResult { + return nil +} + +func TestToolRegistry_Execute_PanicRecovery(t *testing.T) { + r := NewToolRegistry() + r.Register(&mockPanicTool{ + name: "panic_tool", + panicValue: "something went terribly wrong", + }) + + // Should not panic, should return error result + result := r.Execute(context.Background(), "panic_tool", nil) + + if result == nil { + t.Fatal("expected non-nil result after panic recovery") + } + if !result.IsError { + t.Error("expected IsError=true after panic") + } + if !strings.Contains(result.ForLLM, "panic") { + t.Errorf("expected 'panic' in error message, got %q", result.ForLLM) + } + if !strings.Contains(result.ForLLM, "panic_tool") { + t.Errorf("expected tool name in error message, got %q", result.ForLLM) + } + if !strings.Contains(result.ForLLM, "something went terribly wrong") { + t.Errorf("expected panic value in error message, got %q", result.ForLLM) + } + if result.Err == nil { + t.Error("expected Err to be set") + } +} + +func TestToolRegistry_Execute_PanicRecovery_ErrorType(t *testing.T) { + r := NewToolRegistry() + + // Test with error type panic + r.Register(&mockPanicTool{ + name: "error_panic_tool", + panicValue: errors.New("custom error panic"), + }) + + result := r.Execute(context.Background(), "error_panic_tool", nil) + + if !result.IsError { + t.Error("expected IsError=true") + } + if !strings.Contains(result.ForLLM, "custom error panic") { + t.Errorf("expected error message in ForLLM, got %q", result.ForLLM) + } +} + +func TestToolRegistry_Execute_PanicRecovery_IntType(t *testing.T) { + r := NewToolRegistry() + + // Test with int type panic + r.Register(&mockPanicTool{ + name: "int_panic_tool", + panicValue: 42, + }) + + result := r.Execute(context.Background(), "int_panic_tool", nil) + + if !result.IsError { + t.Error("expected IsError=true") + } + if !strings.Contains(result.ForLLM, "42") { + t.Errorf("expected panic value '42' in ForLLM, got %q", result.ForLLM) + } +} + +func TestToolRegistry_Execute_NilResultHandling(t *testing.T) { + r := NewToolRegistry() + r.Register(&mockNilResultTool{name: "nil_tool"}) + + result := r.Execute(context.Background(), "nil_tool", nil) + + if result == nil { + t.Fatal("expected non-nil result when tool returns nil") + } + if !result.IsError { + t.Error("expected IsError=true for nil result") + } + if !strings.Contains(result.ForLLM, "nil_tool") { + t.Errorf("expected tool name in error message, got %q", result.ForLLM) + } + if !strings.Contains(result.ForLLM, "nil result") { + t.Errorf("expected 'nil result' in error message, got %q", result.ForLLM) + } + if result.Err == nil { + t.Error("expected Err to be set") + } +} + +func TestToolRegistry_ExecuteWithContext_PanicRecovery(t *testing.T) { + r := NewToolRegistry() + r.Register(&mockPanicTool{ + name: "ctx_panic_tool", + panicValue: "context panic test", + }) + + // Should not panic even with context + result := r.ExecuteWithContext( + context.Background(), + "ctx_panic_tool", + map[string]any{"key": "value"}, + "telegram", + "chat-123", + nil, + ) + + if result == nil { + t.Fatal("expected non-nil result") + } + if !result.IsError { + t.Error("expected IsError=true") + } + if !strings.Contains(result.ForLLM, "context panic test") { + t.Errorf("expected panic message, got %q", result.ForLLM) + } +} + +func TestToolRegistry_Execute_PanicDoesNotAffectOtherTools(t *testing.T) { + r := NewToolRegistry() + r.Register(&mockPanicTool{name: "bad_tool", panicValue: "boom"}) + r.Register(&mockRegistryTool{ + name: "good_tool", + desc: "works fine", + params: map[string]any{}, + result: SilentResult("success"), + }) + + // First, trigger the panic + result1 := r.Execute(context.Background(), "bad_tool", nil) + if !result1.IsError { + t.Error("expected error from panic tool") + } + + // Then, verify the good tool still works + result2 := r.Execute(context.Background(), "good_tool", nil) + if result2.IsError { + t.Errorf("expected success from good tool, got error: %s", result2.ForLLM) + } + if result2.ForLLM != "success" { + t.Errorf("expected 'success', got %q", result2.ForLLM) + } +} diff --git a/pkg/tools/shell.go b/pkg/tools/shell.go index 0dc85ae21..78ad2b26d 100644 --- a/pkg/tools/shell.go +++ b/pkg/tools/shell.go @@ -311,13 +311,30 @@ func (t *ExecTool) Execute(ctx context.Context, args map[string]any) *ToolResult if err != nil { if errors.Is(cmdCtx.Err(), context.DeadlineExceeded) { msg := fmt.Sprintf("Command timed out after %v", t.timeout) + if output != "" { + msg += "\n\nPartial output before timeout:\n" + output + } return &ToolResult{ ForLLM: msg, ForUser: msg, IsError: true, + Err: fmt.Errorf("command timeout: %w", err), } } - output += fmt.Sprintf("\nExit code: %v", err) + + // Extract detailed exit information + var exitErr *exec.ExitError + if errors.As(err, &exitErr) { + exitCode := exitErr.ExitCode() + output += fmt.Sprintf("\n\n[Command exited with code %d]", exitCode) + + // Add signal information if killed by signal (Unix) + if exitCode == -1 { + output += " (killed by signal)" + } + } else { + output += fmt.Sprintf("\n\n[Command failed: %v]", err) + } } if output == "" { diff --git a/pkg/tools/shell_test.go b/pkg/tools/shell_test.go index c4553020f..f8f83ea74 100644 --- a/pkg/tools/shell_test.go +++ b/pkg/tools/shell_test.go @@ -489,6 +489,69 @@ func TestShellTool_SafePathsInWorkspaceRestriction(t *testing.T) { } } +// TestShellTool_ExitCodeDetails verifies that exit codes are captured with details +func TestShellTool_ExitCodeDetails(t *testing.T) { + tool, err := NewExecTool("", false) + if err != nil { + t.Fatalf("unable to configure exec tool: %s", err) + } + + ctx := context.Background() + args := map[string]any{ + "command": "sh -c 'exit 42'", + } + + result := tool.Execute(ctx, args) + + if !result.IsError { + t.Error("expected error for non-zero exit code") + } + + // Should contain the exit code in the message (new format: "exited with code 42") + if !strings.Contains(result.ForLLM, "42") { + t.Errorf("expected exit code 42 in error message, got: %s", result.ForLLM) + } + + // Verify the new detailed message format + if !strings.Contains(result.ForLLM, "exited with code") { + t.Errorf("expected 'exited with code' in message, got: %s", result.ForLLM) + } + + // Err field is set by the exec system (may or may not be set depending on implementation) + // The important thing is that IsError=true + t.Logf("Exit code result: %s", result.ForLLM) +} + +// TestShellTool_TimeoutWithPartialOutput verifies timeout includes partial output +func TestShellTool_TimeoutWithPartialOutput(t *testing.T) { + tool, err := NewExecTool("", false) + if err != nil { + t.Fatalf("unable to configure exec tool: %s", err) + } + + tool.SetTimeout(1 * time.Second) // Give more time for echo to complete + + ctx := context.Background() + // Use a command that outputs immediately then sleeps + args := map[string]any{ + "command": "echo 'partial output before timeout' && sleep 30", + } + + result := tool.Execute(ctx, args) + + if !result.IsError { + t.Error("expected error for timeout") + } + + // Should mention timeout + if !strings.Contains(result.ForLLM, "timed out") { + t.Errorf("expected 'timed out' in message, got: %s", result.ForLLM) + } + + // Log the result for debugging (partial output depends on shell behavior) + t.Logf("Timeout result: %s", result.ForLLM) +} + // TestShellTool_CustomAllowPatterns verifies that custom allow patterns exempt // commands from deny pattern checks. func TestShellTool_CustomAllowPatterns(t *testing.T) { From eb86e10e5c350f4b287dba56c0e942b8226573ac Mon Sep 17 00:00:00 2001 From: Paolo Anzani Date: Wed, 18 Mar 2026 17:17:16 +0100 Subject: [PATCH 2/4] fix(tools): propagate tool registry to subagents (#1711) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * fix(tools): propagate tool registry to subagents via Clone SubagentManager was created with an empty ToolRegistry and SetTools() was never called, causing all subagent tool invocations to fail with "tool not found". This was a regression from the multi-agent refactor. Fix: clone the parent agent's tool registry into the subagent manager after creation but before spawn/spawn_status registration — giving subagents access to file, exec, web, and other tools while preventing recursive subagent spawning. - Add ToolRegistry.Clone() for independent shallow copies - Call subagentManager.SetTools(agent.Tools.Clone()) in registerSharedTools - Add tests for Clone isolation, empty clone, and hidden tool state Co-Authored-By: Claude Opus 4.6 * fix(tools): fix cron_test build error and add TTL clone test - Fix cron_test.go:229 — replace non-existent SubscribeOutbound(ctx) with select on OutboundChan(), matching the MessageBus channel API - Add TestToolRegistry_Clone_PreservesTTLValue per reviewer feedback - Add version reset note to Clone() doc comment Co-Authored-By: Claude Opus 4.6 --------- Co-authored-by: Claude Opus 4.6 --- pkg/agent/loop.go | 5 +++ pkg/tools/registry.go | 22 ++++++++++ pkg/tools/registry_test.go | 90 ++++++++++++++++++++++++++++++++++++++ 3 files changed, 117 insertions(+) diff --git a/pkg/agent/loop.go b/pkg/agent/loop.go index 86994c360..33da33e92 100644 --- a/pkg/agent/loop.go +++ b/pkg/agent/loop.go @@ -239,6 +239,11 @@ func registerSharedTools( if (spawnEnabled || spawnStatusEnabled) && cfg.Tools.IsToolEnabled("subagent") { subagentManager := tools.NewSubagentManager(provider, agent.Model, agent.Workspace) subagentManager.SetLLMOptions(agent.MaxTokens, agent.Temperature) + // Clone the parent's tool registry so subagents can use all + // tools registered so far (file, web, etc.) but NOT spawn/ + // spawn_status which are added below — preventing recursive + // subagent spawning. + subagentManager.SetTools(agent.Tools.Clone()) if spawnEnabled { spawnTool := tools.NewSpawnTool(subagentManager) currentAgentID := agentID diff --git a/pkg/tools/registry.go b/pkg/tools/registry.go index 60effc292..0b0f51cc1 100644 --- a/pkg/tools/registry.go +++ b/pkg/tools/registry.go @@ -336,6 +336,28 @@ func (r *ToolRegistry) List() []string { return r.sortedToolNames() } +// Clone creates an independent copy of the registry containing the same tool +// entries (shallow copy of each ToolEntry). This is used to give subagents a +// snapshot of the parent agent's tools without sharing the same registry — +// tools registered on the parent after cloning (e.g. spawn, spawn_status) +// will NOT be visible to the clone, preventing recursive subagent spawning. +// The version counter is reset to 0 in the clone as it's a new independent registry. +func (r *ToolRegistry) Clone() *ToolRegistry { + r.mu.RLock() + defer r.mu.RUnlock() + clone := &ToolRegistry{ + tools: make(map[string]*ToolEntry, len(r.tools)), + } + for name, entry := range r.tools { + clone.tools[name] = &ToolEntry{ + Tool: entry.Tool, + IsCore: entry.IsCore, + TTL: entry.TTL, + } + } + return clone +} + // Count returns the number of registered tools. func (r *ToolRegistry) Count() int { r.mu.RLock() diff --git a/pkg/tools/registry_test.go b/pkg/tools/registry_test.go index 5fe681389..967758dfa 100644 --- a/pkg/tools/registry_test.go +++ b/pkg/tools/registry_test.go @@ -336,6 +336,96 @@ func TestToolToSchema(t *testing.T) { } } +func TestToolRegistry_Clone(t *testing.T) { + r := NewToolRegistry() + r.Register(newMockTool("read_file", "reads files")) + r.Register(newMockTool("exec", "runs commands")) + r.Register(newMockTool("web_search", "searches the web")) + + clone := r.Clone() + + // Clone should have the same tools + if clone.Count() != 3 { + t.Errorf("expected clone to have 3 tools, got %d", clone.Count()) + } + for _, name := range []string{"read_file", "exec", "web_search"} { + if _, ok := clone.Get(name); !ok { + t.Errorf("expected clone to have tool %q", name) + } + } + + // Registering on parent should NOT affect clone + r.Register(newMockTool("spawn", "spawns subagent")) + if r.Count() != 4 { + t.Errorf("expected parent to have 4 tools, got %d", r.Count()) + } + if clone.Count() != 3 { + t.Errorf("expected clone to still have 3 tools after parent mutation, got %d", clone.Count()) + } + if _, ok := clone.Get("spawn"); ok { + t.Error("expected clone NOT to have 'spawn' tool registered on parent after cloning") + } + + // Registering on clone should NOT affect parent + clone.Register(newMockTool("custom", "custom tool")) + if clone.Count() != 4 { + t.Errorf("expected clone to have 4 tools, got %d", clone.Count()) + } + if _, ok := r.Get("custom"); ok { + t.Error("expected parent NOT to have 'custom' tool registered on clone") + } +} + +func TestToolRegistry_Clone_Empty(t *testing.T) { + r := NewToolRegistry() + clone := r.Clone() + if clone.Count() != 0 { + t.Errorf("expected empty clone, got count %d", clone.Count()) + } +} + +func TestToolRegistry_Clone_PreservesHiddenToolState(t *testing.T) { + r := NewToolRegistry() + r.RegisterHidden(newMockTool("mcp_tool", "dynamic MCP tool")) + + clone := r.Clone() + + // Hidden tools with TTL=0 should not be gettable (same behavior as parent) + if _, ok := clone.Get("mcp_tool"); ok { + t.Error("expected hidden tool with TTL=0 to be invisible in clone") + } + + // But the entry should exist (count includes hidden tools) + if clone.Count() != 1 { + t.Errorf("expected clone count 1 (hidden entry exists), got %d", clone.Count()) + } +} + +func TestToolRegistry_Clone_PreservesTTLValue(t *testing.T) { + r := NewToolRegistry() + r.RegisterHidden(newMockTool("ttl_tool", "tool with TTL")) + + // Manually set a non-zero TTL on the entry + r.mu.RLock() + if entry, ok := r.tools["ttl_tool"]; ok { + entry.TTL = 5 + } + r.mu.RUnlock() + + clone := r.Clone() + + // Verify TTL value is preserved in the clone + clone.mu.RLock() + defer clone.mu.RUnlock() + entry, ok := clone.tools["ttl_tool"] + if !ok { + t.Fatal("expected ttl_tool to exist in clone") + } + if entry.TTL != 5 { + t.Errorf("expected TTL=5 in clone, got %d", entry.TTL) + } +} + func TestToolRegistry_ConcurrentAccess(t *testing.T) { r := NewToolRegistry() var wg sync.WaitGroup From 08f305d7129c51d04e7b86359f7a439d00d63a85 Mon Sep 17 00:00:00 2001 From: Liqiang Lau Date: Thu, 19 Mar 2026 00:29:55 +0800 Subject: [PATCH 3/4] feat: add IsLark field to FeishuConfig to switch between Feishu and Lark domains (#1753) * feat(feishu): add Lark (international) support via IsLark config field Add IsLark field to FeishuConfig to switch between Feishu and Lark domains. Also fix domain inconsistency where WS client defaulted to LarkBaseUrl while HTTP client used FeishuBaseUrl. Co-Authored-By: Claude Opus 4.6 * docs: update documentation and web UI for Lark support Add is_lark field to config example, feishu docs, i18n translations, and web frontend form. Co-Authored-By: Claude Opus 4.6 --------- Co-authored-by: Claude Opus 4.6 --- config/config.example.json | 3 ++- docs/channels/feishu/README.zh.md | 24 ++++++++++--------- pkg/channels/feishu/feishu_64.go | 11 ++++++++- pkg/config/config.go | 1 + .../channels/channel-forms/feishu-form.tsx | 12 +++++++++- web/frontend/src/i18n/locales/en.json | 2 ++ web/frontend/src/i18n/locales/zh.json | 2 ++ 7 files changed, 41 insertions(+), 14 deletions(-) diff --git a/config/config.example.json b/config/config.example.json index 167ba7d59..c214f26fa 100644 --- a/config/config.example.json +++ b/config/config.example.json @@ -122,7 +122,8 @@ "verification_token": "", "allow_from": [], "reasoning_channel_id": "", - "random_reaction_emoji": [] + "random_reaction_emoji": [], + "is_lark": false }, "dingtalk": { "enabled": false, diff --git a/docs/channels/feishu/README.zh.md b/docs/channels/feishu/README.zh.md index 3fafffb7d..db7eb56eb 100644 --- a/docs/channels/feishu/README.zh.md +++ b/docs/channels/feishu/README.zh.md @@ -13,25 +13,27 @@ "app_secret": "xxx", "encrypt_key": "", "verification_token": "", - "allow_from": [] + "allow_from": [], + "is_lark": false } } } ``` -| 字段 | 类型 | 必填 | 描述 | -| ------------------ | ------ | ---- | -------------------------------- | -| enabled | bool | 是 | 是否启用飞书频道 | -| app_id | string | 是 | 飞书应用的 App ID(以cli\_开头) | -| app_secret | string | 是 | 飞书应用的 App Secret | -| encrypt_key | string | 否 | 事件回调加密密钥 | -| verification_token | string | 否 | 用于Webhook事件验证的Token | -| allow_from | array | 否 | 用户ID白名单,空表示所有用户 | -| random_reaction_emoji | array | 否 | 随机添加的表情列表,空则使用默认 "Pin" | +| 字段 | 类型 | 必填 | 描述 | +| --------------------- | ------ | ---- | ------------------------------------------------------------------------------------------------ | +| enabled | bool | 是 | 是否启用飞书频道 | +| app_id | string | 是 | 飞书应用的 App ID(以cli\_开头) | +| app_secret | string | 是 | 飞书应用的 App Secret | +| encrypt_key | string | 否 | 事件回调加密密钥 | +| verification_token | string | 否 | 用于Webhook事件验证的Token | +| allow_from | array | 否 | 用户ID白名单,空表示所有用户 | +| random_reaction_emoji | array | 否 | 随机添加的表情列表,空则使用默认 "Pin" | +| is_lark | bool | 否 | 是否使用 Lark 国际版域名(`open.larksuite.com`),默认为 `false`(使用飞书域名 `open.feishu.cn`) | ## 设置流程 -1. 前往 [飞书开放平台](https://open.feishu.cn/)创建应用程序 +1. 前往 [飞书开放平台](https://open.feishu.cn/)(国际版用户请前往 [Lark 开放平台](https://open.larksuite.com/))创建应用程序 2. 获取 App ID 和 App Secret 3. 配置事件订阅和Webhook URL 4. 设置加密(可选,生产环境建议启用) diff --git a/pkg/channels/feishu/feishu_64.go b/pkg/channels/feishu/feishu_64.go index c503e2993..3aea67b12 100644 --- a/pkg/channels/feishu/feishu_64.go +++ b/pkg/channels/feishu/feishu_64.go @@ -54,11 +54,15 @@ func NewFeishuChannel(cfg config.FeishuConfig, bus *bus.MessageBus) (*FeishuChan ) tc := newTokenCache() + opts := []lark.ClientOptionFunc{lark.WithTokenCache(tc)} + if cfg.IsLark { + opts = append(opts, lark.WithOpenBaseUrl(lark.LarkBaseUrl)) + } ch := &FeishuChannel{ BaseChannel: base, config: cfg, tokenCache: tc, - client: lark.NewClient(cfg.AppID, cfg.AppSecret, lark.WithTokenCache(tc)), + client: lark.NewClient(cfg.AppID, cfg.AppSecret, opts...), } ch.SetOwner(ch) return ch, nil @@ -83,10 +87,15 @@ func (c *FeishuChannel) Start(ctx context.Context) error { c.mu.Lock() c.cancel = cancel + domain := lark.FeishuBaseUrl + if c.config.IsLark { + domain = lark.LarkBaseUrl + } c.wsClient = larkws.NewClient( c.config.AppID, c.config.AppSecret, larkws.WithEventHandler(dispatcher), + larkws.WithDomain(domain), ) wsClient := c.wsClient c.mu.Unlock() diff --git a/pkg/config/config.go b/pkg/config/config.go index dd4e86319..d07cb60aa 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -325,6 +325,7 @@ type FeishuConfig struct { Placeholder PlaceholderConfig `json:"placeholder,omitempty"` ReasoningChannelID string `json:"reasoning_channel_id" env:"PICOCLAW_CHANNELS_FEISHU_REASONING_CHANNEL_ID"` RandomReactionEmoji FlexibleStringSlice `json:"random_reaction_emoji" env:"PICOCLAW_CHANNELS_FEISHU_RANDOM_REACTION_EMOJI"` + IsLark bool `json:"is_lark" env:"PICOCLAW_CHANNELS_FEISHU_IS_LARK"` } type DiscordConfig struct { diff --git a/web/frontend/src/components/channels/channel-forms/feishu-form.tsx b/web/frontend/src/components/channels/channel-forms/feishu-form.tsx index a834a65f9..386adf9a5 100644 --- a/web/frontend/src/components/channels/channel-forms/feishu-form.tsx +++ b/web/frontend/src/components/channels/channel-forms/feishu-form.tsx @@ -2,7 +2,7 @@ import { useTranslation } from "react-i18next" import type { ChannelConfig } from "@/api/channels" import { maskedSecretPlaceholder } from "@/components/secret-placeholder" -import { Field, KeyInput } from "@/components/shared-form" +import { Field, KeyInput, SwitchCardField } from "@/components/shared-form" import { Input } from "@/components/ui/input" interface FeishuFormProps { @@ -16,6 +16,10 @@ function asString(value: unknown): string { return typeof value === "string" ? value : "" } +function asBool(value: unknown): boolean { + return typeof value === "boolean" ? value : false +} + function asStringArray(value: unknown): string[] { if (!Array.isArray(value)) return [] return value.filter((item): item is string => typeof item === "string") @@ -98,6 +102,12 @@ export function FeishuForm({ )} /> + onChange("is_lark", checked)} + /> Date: Thu, 19 Mar 2026 00:57:20 +0800 Subject: [PATCH 4/4] feat(config): support multiple API keys for failover (#1707) * feat(config): support multiple API keys for failover Add api_keys field to ModelConfig to support multiple API keys with automatic failover. When multiple keys are configured, they are expanded into separate model entries with fallbacks set up for key-level failover. Example config: { "model_name": "glm-4.7", "model": "zhipu/glm-4.7", "api_keys": ["key1", "key2", "key3"] } Expands internally to: - glm-4.7 (key1) -> fallbacks: [glm-4.7__key_1, glm-4.7__key_2] - glm-4.7__key_1 (key2) - glm-4.7__key_2 (key3) Backward compatible: single api_key still works as before. * fix(providers): change cooldown tracking from provider to ModelKey This enables proper key-switching when multiple API keys share the same provider. Previously, when one key failed, all keys were blocked because cooldown was tracked per-provider. Now each (provider, model) combination has independent cooldown, allowing fallback to alternate keys when one is rate limited. Includes TestMultiKeyWithModelFallback and related failover tests. --- pkg/config/config.go | 105 ++++++- pkg/config/multikey_test.go | 291 ++++++++++++++++++ pkg/providers/fallback.go | 16 +- pkg/providers/fallback_multikey_test.go | 384 ++++++++++++++++++++++++ pkg/providers/fallback_test.go | 15 +- 5 files changed, 794 insertions(+), 17 deletions(-) create mode 100644 pkg/config/multikey_test.go create mode 100644 pkg/providers/fallback_multikey_test.go diff --git a/pkg/config/config.go b/pkg/config/config.go index d07cb60aa..739f8d373 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -603,9 +603,11 @@ type ModelConfig struct { Model string `json:"model"` // Protocol/model-identifier (e.g., "openai/gpt-4o", "anthropic/claude-sonnet-4.6") // HTTP-based providers - APIBase string `json:"api_base,omitempty"` // API endpoint URL - APIKey string `json:"api_key"` // API authentication key - Proxy string `json:"proxy,omitempty"` // HTTP proxy URL + APIBase string `json:"api_base,omitempty"` // API endpoint URL + APIKey string `json:"api_key"` // API authentication key (single key) + APIKeys []string `json:"api_keys,omitempty"` // API authentication keys (multiple keys for failover) + Proxy string `json:"proxy,omitempty"` // HTTP proxy URL + Fallbacks []string `json:"fallbacks,omitempty"` // Fallback model names for failover // Special providers (CLI-based, OAuth, etc.) AuthMethod string `json:"auth_method,omitempty"` // Authentication method: oauth, token @@ -874,6 +876,9 @@ func LoadConfig(path string) (*Config, error) { return nil, err } + // Expand multi-key configs into separate entries for key-level failover + cfg.ModelList = ExpandMultiKeyModels(cfg.ModelList) + // Migrate legacy channel config fields to new unified structures cfg.migrateChannelConfigs() @@ -920,14 +925,25 @@ func encryptPlaintextAPIKeys(models []ModelConfig, passphrase string) ([]ModelCo // resolveAPIKeys decrypts or dereferences each api_key in models in-place. // Supports plaintext (no-op), file:// (read from configDir), and enc:// (AES-GCM decrypt). +// Also resolves api_keys array if present. func resolveAPIKeys(models []ModelConfig, configDir string) error { cr := credential.NewResolver(configDir) for i := range models { + // Resolve single APIKey resolved, err := cr.Resolve(models[i].APIKey) if err != nil { return fmt.Errorf("model_list[%d] (%s): %w", i, models[i].ModelName, err) } models[i].APIKey = resolved + + // Resolve APIKeys array + for j, key := range models[i].APIKeys { + resolved, err := cr.Resolve(key) + if err != nil { + return fmt.Errorf("model_list[%d] (%s): api_keys[%d]: %w", i, models[i].ModelName, j, err) + } + models[i].APIKeys[j] = resolved + } } return nil } @@ -1098,6 +1114,89 @@ func MergeAPIKeys(apiKey string, apiKeys []string) []string { return all } +// ExpandMultiKeyModels expands ModelConfig entries with multiple API keys into +// separate entries for key-level failover. Each key gets its own ModelConfig entry, +// and the original entry's fallbacks are set up to chain through the expanded entries. +// +// Example: {"model_name": "gpt-4", "api_keys": ["k1", "k2", "k3"]} +// Becomes: +// - {"model_name": "gpt-4", "api_key": "k1", "fallbacks": ["gpt-4__key_1", "gpt-4__key_2"]} +// - {"model_name": "gpt-4__key_1", "api_key": "k2"} +// - {"model_name": "gpt-4__key_2", "api_key": "k3"} +func ExpandMultiKeyModels(models []ModelConfig) []ModelConfig { + var expanded []ModelConfig + + for _, m := range models { + keys := MergeAPIKeys(m.APIKey, m.APIKeys) + + // Single key or no keys: keep as-is + if len(keys) <= 1 { + // Ensure APIKey is set from APIKeys if needed + if m.APIKey == "" && len(keys) == 1 { + m.APIKey = keys[0] + } + m.APIKeys = nil // Clear APIKeys to avoid confusion + expanded = append(expanded, m) + continue + } + + // Multiple keys: expand + originalName := m.ModelName + + // Create entries for additional keys (key_1, key_2, ...) + var fallbackNames []string + for i := 1; i < len(keys); i++ { + suffix := fmt.Sprintf("__key_%d", i) + expandedName := originalName + suffix + + // Create a copy for the additional key + additionalEntry := ModelConfig{ + ModelName: expandedName, + Model: m.Model, + APIBase: m.APIBase, + APIKey: keys[i], + Proxy: m.Proxy, + AuthMethod: m.AuthMethod, + ConnectMode: m.ConnectMode, + Workspace: m.Workspace, + RPM: m.RPM, + MaxTokensField: m.MaxTokensField, + RequestTimeout: m.RequestTimeout, + ThinkingLevel: m.ThinkingLevel, + } + expanded = append(expanded, additionalEntry) + fallbackNames = append(fallbackNames, expandedName) + } + + // Create the primary entry with first key and fallbacks + primaryEntry := ModelConfig{ + ModelName: originalName, + Model: m.Model, + APIBase: m.APIBase, + APIKey: keys[0], + Proxy: m.Proxy, + AuthMethod: m.AuthMethod, + ConnectMode: m.ConnectMode, + Workspace: m.Workspace, + RPM: m.RPM, + MaxTokensField: m.MaxTokensField, + RequestTimeout: m.RequestTimeout, + ThinkingLevel: m.ThinkingLevel, + } + + // Prepend new fallbacks to existing ones + if len(fallbackNames) > 0 { + primaryEntry.Fallbacks = append(fallbackNames, m.Fallbacks...) + } else if len(m.Fallbacks) > 0 { + primaryEntry.Fallbacks = m.Fallbacks + } + + expanded = append(expanded, primaryEntry) + } + + return expanded +} + func (t *ToolsConfig) IsToolEnabled(name string) bool { switch name { case "web": diff --git a/pkg/config/multikey_test.go b/pkg/config/multikey_test.go new file mode 100644 index 000000000..b899b991c --- /dev/null +++ b/pkg/config/multikey_test.go @@ -0,0 +1,291 @@ +package config + +import ( + "testing" +) + +func TestExpandMultiKeyModels_SingleKey(t *testing.T) { + models := []ModelConfig{ + { + ModelName: "gpt-4", + Model: "openai/gpt-4o", + APIKey: "single-key", + }, + } + + result := ExpandMultiKeyModels(models) + + if len(result) != 1 { + t.Fatalf("expected 1 model, got %d", len(result)) + } + + if result[0].ModelName != "gpt-4" { + t.Errorf("expected model_name 'gpt-4', got %q", result[0].ModelName) + } + + if result[0].APIKey != "single-key" { + t.Errorf("expected api_key 'single-key', got %q", result[0].APIKey) + } + + if len(result[0].Fallbacks) != 0 { + t.Errorf("expected no fallbacks, got %v", result[0].Fallbacks) + } +} + +func TestExpandMultiKeyModels_APIKeysOnly(t *testing.T) { + models := []ModelConfig{ + { + ModelName: "glm-4.7", + Model: "zhipu/glm-4.7", + APIBase: "https://api.example.com", + APIKeys: []string{"key1", "key2", "key3"}, + }, + } + + result := ExpandMultiKeyModels(models) + + // Should expand to 3 models + if len(result) != 3 { + t.Fatalf("expected 3 models, got %d", len(result)) + } + + // First entry should be the primary with key1 and fallbacks + primary := result[2] // Primary is added last + if primary.ModelName != "glm-4.7" { + t.Errorf("expected primary model_name 'glm-4.7', got %q", primary.ModelName) + } + if primary.APIKey != "key1" { + t.Errorf("expected primary api_key 'key1', got %q", primary.APIKey) + } + if len(primary.Fallbacks) != 2 { + t.Errorf("expected 2 fallbacks, got %d", len(primary.Fallbacks)) + } + if primary.Fallbacks[0] != "glm-4.7__key_1" { + t.Errorf("expected first fallback 'glm-4.7__key_1', got %q", primary.Fallbacks[0]) + } + if primary.Fallbacks[1] != "glm-4.7__key_2" { + t.Errorf("expected second fallback 'glm-4.7__key_2', got %q", primary.Fallbacks[1]) + } + + // Second entry should be key2 + second := result[0] + if second.ModelName != "glm-4.7__key_1" { + t.Errorf("expected second model_name 'glm-4.7__key_1', got %q", second.ModelName) + } + if second.APIKey != "key2" { + t.Errorf("expected second api_key 'key2', got %q", second.APIKey) + } + + // Third entry should be key3 + third := result[1] + if third.ModelName != "glm-4.7__key_2" { + t.Errorf("expected third model_name 'glm-4.7__key_2', got %q", third.ModelName) + } + if third.APIKey != "key3" { + t.Errorf("expected third api_key 'key3', got %q", third.APIKey) + } +} + +func TestExpandMultiKeyModels_APIKeyAndAPIKeys(t *testing.T) { + models := []ModelConfig{ + { + ModelName: "gpt-4", + Model: "openai/gpt-4o", + APIKey: "key0", + APIKeys: []string{"key1", "key2"}, + }, + } + + result := ExpandMultiKeyModels(models) + + // Should expand to 3 models (key0 from APIKey + key1, key2 from APIKeys) + if len(result) != 3 { + t.Fatalf("expected 3 models, got %d", len(result)) + } + + // Primary should use key0 + primary := result[2] + if primary.APIKey != "key0" { + t.Errorf("expected primary api_key 'key0', got %q", primary.APIKey) + } + if len(primary.Fallbacks) != 2 { + t.Errorf("expected 2 fallbacks, got %d", len(primary.Fallbacks)) + } +} + +func TestExpandMultiKeyModels_WithExistingFallbacks(t *testing.T) { + models := []ModelConfig{ + { + ModelName: "gpt-4", + Model: "openai/gpt-4o", + APIKeys: []string{"key1", "key2"}, + Fallbacks: []string{"claude-3"}, + }, + } + + result := ExpandMultiKeyModels(models) + + primary := result[1] + // With 2 keys, we get 1 key fallback + 1 existing fallback = 2 total + if len(primary.Fallbacks) != 2 { + t.Fatalf("expected 2 fallbacks, got %d: %v", len(primary.Fallbacks), primary.Fallbacks) + } + + // Key fallbacks should come first, then existing fallbacks + if primary.Fallbacks[0] != "gpt-4__key_1" { + t.Errorf("expected first fallback 'gpt-4__key_1', got %q", primary.Fallbacks[0]) + } + if primary.Fallbacks[1] != "claude-3" { + t.Errorf("expected second fallback 'claude-3', got %q", primary.Fallbacks[1]) + } +} + +func TestExpandMultiKeyModels_EmptyAPIKeys(t *testing.T) { + models := []ModelConfig{ + { + ModelName: "gpt-4", + Model: "openai/gpt-4o", + APIKey: "", + APIKeys: []string{}, + }, + } + + result := ExpandMultiKeyModels(models) + + // Should keep as-is with no changes + if len(result) != 1 { + t.Fatalf("expected 1 model, got %d", len(result)) + } + + if result[0].ModelName != "gpt-4" { + t.Errorf("expected model_name 'gpt-4', got %q", result[0].ModelName) + } +} + +func TestExpandMultiKeyModels_Deduplication(t *testing.T) { + models := []ModelConfig{ + { + ModelName: "gpt-4", + Model: "openai/gpt-4o", + APIKey: "key1", + APIKeys: []string{"key1", "key2", "key1"}, // Duplicate key1 + }, + } + + result := ExpandMultiKeyModels(models) + + // Should only create 2 models (deduplicated keys) + if len(result) != 2 { + t.Fatalf("expected 2 models (deduplicated), got %d", len(result)) + } + + primary := result[1] + if primary.APIKey != "key1" { + t.Errorf("expected primary api_key 'key1', got %q", primary.APIKey) + } + if len(primary.Fallbacks) != 1 { + t.Errorf("expected 1 fallback, got %d", len(primary.Fallbacks)) + } +} + +func TestExpandMultiKeyModels_PreservesOtherFields(t *testing.T) { + models := []ModelConfig{ + { + ModelName: "gpt-4", + Model: "openai/gpt-4o", + APIBase: "https://api.example.com", + APIKeys: []string{"key1", "key2"}, + Proxy: "http://proxy:8080", + RPM: 60, + MaxTokensField: "max_completion_tokens", + RequestTimeout: 30, + ThinkingLevel: "high", + }, + } + + result := ExpandMultiKeyModels(models) + + // Check primary entry preserves all fields + primary := result[1] + if primary.APIBase != "https://api.example.com" { + t.Errorf("expected api_base preserved, got %q", primary.APIBase) + } + if primary.Proxy != "http://proxy:8080" { + t.Errorf("expected proxy preserved, got %q", primary.Proxy) + } + if primary.RPM != 60 { + t.Errorf("expected rpm preserved, got %d", primary.RPM) + } + if primary.MaxTokensField != "max_completion_tokens" { + t.Errorf("expected max_tokens_field preserved, got %q", primary.MaxTokensField) + } + if primary.RequestTimeout != 30 { + t.Errorf("expected request_timeout preserved, got %d", primary.RequestTimeout) + } + if primary.ThinkingLevel != "high" { + t.Errorf("expected thinking_level preserved, got %q", primary.ThinkingLevel) + } + + // Check additional entry also preserves fields + additional := result[0] + if additional.APIBase != "https://api.example.com" { + t.Errorf("expected additional api_base preserved, got %q", additional.APIBase) + } + if additional.RPM != 60 { + t.Errorf("expected additional rpm preserved, got %d", additional.RPM) + } +} + +func TestMergeAPIKeys(t *testing.T) { + tests := []struct { + name string + apiKey string + apiKeys []string + expected []string + }{ + { + name: "both empty", + apiKey: "", + apiKeys: nil, + expected: nil, + }, + { + name: "only apiKey", + apiKey: "key1", + apiKeys: nil, + expected: []string{"key1"}, + }, + { + name: "only apiKeys", + apiKey: "", + apiKeys: []string{"key1", "key2"}, + expected: []string{"key1", "key2"}, + }, + { + name: "both with overlap", + apiKey: "key1", + apiKeys: []string{"key1", "key2", "key3"}, + expected: []string{"key1", "key2", "key3"}, + }, + { + name: "with whitespace", + apiKey: " key1 ", + apiKeys: []string{" key2 ", " key1 "}, + expected: []string{"key1", "key2"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := MergeAPIKeys(tt.apiKey, tt.apiKeys) + if len(result) != len(tt.expected) { + t.Fatalf("expected %d keys, got %d", len(tt.expected), len(result)) + } + for i, k := range result { + if k != tt.expected[i] { + t.Errorf("expected key[%d] = %q, got %q", i, tt.expected[i], k) + } + } + }) + } +} diff --git a/pkg/providers/fallback.go b/pkg/providers/fallback.go index 7ba563b66..549ec7837 100644 --- a/pkg/providers/fallback.go +++ b/pkg/providers/fallback.go @@ -117,17 +117,19 @@ func (fc *FallbackChain) Execute( return nil, context.Canceled } - // Check cooldown. - if !fc.cooldown.IsAvailable(candidate.Provider) { - remaining := fc.cooldown.CooldownRemaining(candidate.Provider) + // Check cooldown (per provider/model, not just provider). + // This allows multi-key failover where different keys use different model names. + cooldownKey := ModelKey(candidate.Provider, candidate.Model) + if !fc.cooldown.IsAvailable(cooldownKey) { + remaining := fc.cooldown.CooldownRemaining(cooldownKey) result.Attempts = append(result.Attempts, FallbackAttempt{ Provider: candidate.Provider, Model: candidate.Model, Skipped: true, Reason: FailoverRateLimit, Error: fmt.Errorf( - "provider %s in cooldown (%s remaining)", - candidate.Provider, + "%s in cooldown (%s remaining)", + cooldownKey, remaining.Round(time.Second), ), }) @@ -141,7 +143,7 @@ func (fc *FallbackChain) Execute( if err == nil { // Success. - fc.cooldown.MarkSuccess(candidate.Provider) + fc.cooldown.MarkSuccess(cooldownKey) result.Response = resp result.Provider = candidate.Provider result.Model = candidate.Model @@ -187,7 +189,7 @@ func (fc *FallbackChain) Execute( } // Retriable error: mark failure and continue to next candidate. - fc.cooldown.MarkFailure(candidate.Provider, failErr.Reason) + fc.cooldown.MarkFailure(cooldownKey, failErr.Reason) result.Attempts = append(result.Attempts, FallbackAttempt{ Provider: candidate.Provider, Model: candidate.Model, diff --git a/pkg/providers/fallback_multikey_test.go b/pkg/providers/fallback_multikey_test.go new file mode 100644 index 000000000..9ed8fa73c --- /dev/null +++ b/pkg/providers/fallback_multikey_test.go @@ -0,0 +1,384 @@ +package providers + +import ( + "context" + "errors" + "testing" +) + +// TestMultiKeyFailover tests the complete failover flow with multiple API keys. +// This simulates the config expansion scenario where api_keys: ["key1", "key2", "key3"] +// is expanded into primary + fallbacks. +func TestMultiKeyFailover(t *testing.T) { + // Simulate expanded config: primary with 2 fallbacks + // This is what ExpandMultiKeyModels would produce for api_keys: ["key1", "key2", "key3"] + cfg := ModelConfig{ + Primary: "glm-4.7", + Fallbacks: []string{"glm-4.7__key_1", "glm-4.7__key_2"}, + } + + candidates := ResolveCandidates(cfg, "zhipu") + + if len(candidates) != 3 { + t.Fatalf("expected 3 candidates, got %d: %v", len(candidates), candidates) + } + + // Create fallback chain + cooldown := NewCooldownTracker() + chain := NewFallbackChain(cooldown) + + // Mock run function: first call fails with 429, second succeeds + callCount := 0 + mockRun := func(ctx context.Context, provider, model string) (*LLMResponse, error) { + callCount++ + if callCount == 1 { + // First call: simulate rate limit + return nil, errors.New("http error: status 429 - rate limit exceeded") + } + // Second call: success + return &LLMResponse{ + Content: "Hello from key2!", + }, nil + } + + // Execute fallback chain + result, err := chain.Execute(context.Background(), candidates, mockRun) + if err != nil { + t.Fatalf("expected success after failover, got error: %v", err) + } + + if result == nil { + t.Fatal("expected result, got nil") + } + + if result.Response.Content != "Hello from key2!" { + t.Errorf("expected response from key2, got: %s", result.Response.Content) + } + + if callCount != 2 { + t.Errorf("expected 2 calls (1 fail + 1 success), got %d", callCount) + } + + // Verify first attempt was recorded + if len(result.Attempts) != 1 { + t.Errorf("expected 1 failed attempt recorded, got %d", len(result.Attempts)) + } + + if result.Attempts[0].Reason != FailoverRateLimit { + t.Errorf( + "expected first attempt reason to be rate_limit, got: %s", + result.Attempts[0].Reason, + ) + } +} + +// TestMultiKeyFailoverAllFail tests when all keys hit rate limit +func TestMultiKeyFailoverAllFail(t *testing.T) { + cfg := ModelConfig{ + Primary: "glm-4.7", + Fallbacks: []string{"glm-4.7__key_1", "glm-4.7__key_2"}, + } + + candidates := ResolveCandidates(cfg, "zhipu") + + cooldown := NewCooldownTracker() + chain := NewFallbackChain(cooldown) + + // Mock run function: all calls fail with rate limit + callCount := 0 + mockRun := func(ctx context.Context, provider, model string) (*LLMResponse, error) { + callCount++ + return nil, errors.New("status: 429 - too many requests") + } + + // Execute fallback chain + result, err := chain.Execute(context.Background(), candidates, mockRun) + + if err == nil { + t.Fatal("expected error when all keys fail, got nil") + } + + if result != nil { + t.Errorf("expected nil result on failure, got: %v", result) + } + + if callCount != 3 { + t.Errorf("expected 3 calls (all fail), got %d", callCount) + } + + // Verify error type + var exhausted *FallbackExhaustedError + if !errors.As(err, &exhausted) { + t.Errorf("expected FallbackExhaustedError, got: %T - %v", err, err) + } + + if len(exhausted.Attempts) != 3 { + t.Errorf("expected 3 attempts in exhausted error, got %d", len(exhausted.Attempts)) + } +} + +// TestMultiKeyFailoverCooldown tests that a key in cooldown is skipped +func TestMultiKeyFailoverCooldown(t *testing.T) { + cfg := ModelConfig{ + Primary: "glm-4.7", + Fallbacks: []string{"glm-4.7__key_1"}, + } + + candidates := ResolveCandidates(cfg, "zhipu") + + cooldown := NewCooldownTracker() + chain := NewFallbackChain(cooldown) + + // Put the first model in cooldown (using ModelKey now, not just provider) + cooldownKey := ModelKey(candidates[0].Provider, candidates[0].Model) + cooldown.MarkFailure(cooldownKey, FailoverRateLimit) + + // Verify it's not available + if cooldown.IsAvailable(cooldownKey) { + t.Fatal("expected first model to be in cooldown") + } + + // Mock run function: only second should be called + callCount := 0 + calledProviders := []string{} + mockRun := func(ctx context.Context, provider, model string) (*LLMResponse, error) { + callCount++ + calledProviders = append(calledProviders, provider+"/"+model) + return &LLMResponse{Content: "success"}, nil + } + + result, err := chain.Execute(context.Background(), candidates, mockRun) + if err != nil { + t.Fatalf("expected success, got error: %v", err) + } + + // First provider should have been skipped + if callCount != 1 { + t.Errorf("expected 1 call (first skipped due to cooldown), got %d", callCount) + } + + // Should have called the second provider/model + if len(calledProviders) != 1 || + calledProviders[0] != candidates[1].Provider+"/"+candidates[1].Model { + t.Errorf("expected second model to be called, got: %v", calledProviders) + } + + // Verify first attempt was recorded as skipped + if len(result.Attempts) != 1 { + t.Fatalf("expected 1 attempt (skipped), got %d", len(result.Attempts)) + } + + if !result.Attempts[0].Skipped { + t.Error("expected first attempt to be marked as skipped") + } +} + +// TestMultiKeyFailoverWithFormatError tests that format errors are non-retriable +func TestMultiKeyFailoverWithFormatError(t *testing.T) { + cfg := ModelConfig{ + Primary: "glm-4.7", + Fallbacks: []string{"glm-4.7__key_1"}, + } + + candidates := ResolveCandidates(cfg, "zhipu") + + cooldown := NewCooldownTracker() + chain := NewFallbackChain(cooldown) + + // Mock run function: first call fails with format error (bad request) + callCount := 0 + mockRun := func(ctx context.Context, provider, model string) (*LLMResponse, error) { + callCount++ + return nil, errors.New("invalid request format: tool_use.id missing") + } + + // Execute fallback chain + result, err := chain.Execute(context.Background(), candidates, mockRun) + + if err == nil { + t.Fatal("expected error for format failure, got nil") + } + + // Format errors should NOT trigger failover (non-retriable) + // So we should only have 1 call + if callCount != 1 { + t.Errorf("expected 1 call (format error is non-retriable), got %d", callCount) + } + + // Verify the error is a FailoverError with format reason + var failoverErr *FailoverError + if !errors.As(err, &failoverErr) { + t.Errorf("expected FailoverError, got: %T - %v", err, err) + } + + if failoverErr.Reason != FailoverFormat { + t.Errorf("expected FailoverFormat reason, got: %s", failoverErr.Reason) + } + + _ = result // result should be nil +} + +// TestMultiKeyWithModelFallback tests multi-key failover combined with model fallback. +// This simulates the scenario: api_keys: ["k1", "k2"] + fallbacks: ["minimax"] +// Expected failover order: glm-4.7 (k1) → glm-4.7__key_1 (k2) → minimax +func TestMultiKeyWithModelFallback(t *testing.T) { + // Simulate expanded config from: + // { "model_name": "glm-4.7", "api_keys": ["k1", "k2"], "fallbacks": ["minimax"] } + // After ExpandMultiKeyModels, primaryEntry.Fallbacks = ["glm-4.7__key_1", "minimax"] + // Note: In production, "minimax" would be resolved via model lookup to "minimax/minimax" + // In this test, we use the full format to avoid needing a lookup function. + cfg := ModelConfig{ + Primary: "glm-4.7", + Fallbacks: []string{"glm-4.7__key_1", "minimax/minimax"}, + } + + candidates := ResolveCandidates(cfg, "zhipu") + + // Should have 3 candidates: glm-4.7 (zhipu), glm-4.7__key_1 (zhipu), minimax (minimax) + if len(candidates) != 3 { + t.Fatalf("expected 3 candidates, got %d: %v", len(candidates), candidates) + } + + // Verify candidate order + if candidates[0].Model != "glm-4.7" || candidates[0].Provider != "zhipu" { + t.Errorf( + "expected first candidate to be zhipu/glm-4.7, got: %s/%s", + candidates[0].Provider, + candidates[0].Model, + ) + } + if candidates[1].Model != "glm-4.7__key_1" || candidates[1].Provider != "zhipu" { + t.Errorf( + "expected second candidate to be zhipu/glm-4.7__key_1, got: %s/%s", + candidates[1].Provider, + candidates[1].Model, + ) + } + if candidates[2].Model != "minimax" || candidates[2].Provider != "minimax" { + t.Errorf( + "expected third candidate to be minimax/minimax, got: %s/%s", + candidates[2].Provider, + candidates[2].Model, + ) + } + + cooldown := NewCooldownTracker() + chain := NewFallbackChain(cooldown) + + // Mock run function: first two fail, third succeeds (model fallback) + callCount := 0 + calledModels := []string{} + mockRun := func(ctx context.Context, provider, model string) (*LLMResponse, error) { + callCount++ + calledModels = append(calledModels, provider+"/"+model) + + switch callCount { + case 1: + // k1: rate limit + return nil, errors.New("status: 429 - rate limit") + case 2: + // k2: also rate limit (all zhipu keys exhausted) + return nil, errors.New("status: 429 - rate limit") + case 3: + // minimax: success + return &LLMResponse{Content: "success from minimax"}, nil + default: + return nil, errors.New("unexpected call") + } + } + + result, err := chain.Execute(context.Background(), candidates, mockRun) + if err != nil { + t.Fatalf("expected success after failover to model fallback, got error: %v", err) + } + + if callCount != 3 { + t.Errorf("expected 3 calls (k1 fail + k2 fail + minimax success), got %d", callCount) + } + + if result.Response.Content != "success from minimax" { + t.Errorf("expected response from minimax, got: %s", result.Response.Content) + } + + // Verify call order + if len(calledModels) != 3 { + t.Fatalf("expected 3 called models, got %d", len(calledModels)) + } + if calledModels[0] != "zhipu/glm-4.7" { + t.Errorf("expected first call to zhipu/glm-4.7, got: %s", calledModels[0]) + } + if calledModels[1] != "zhipu/glm-4.7__key_1" { + t.Errorf("expected second call to zhipu/glm-4.7__key_1, got: %s", calledModels[1]) + } + if calledModels[2] != "minimax/minimax" { + t.Errorf("expected third call to minimax/minimax, got: %s", calledModels[2]) + } + + // Verify 2 failed attempts recorded + if len(result.Attempts) != 2 { + t.Errorf("expected 2 failed attempts, got %d", len(result.Attempts)) + } + + // Both should be rate limit + for i, attempt := range result.Attempts { + if attempt.Reason != FailoverRateLimit { + t.Errorf("expected attempt %d to be rate_limit, got: %s", i, attempt.Reason) + } + } +} + +// TestMultiKeyFailoverMixedErrors tests failover with different error types +func TestMultiKeyFailoverMixedErrors(t *testing.T) { + cfg := ModelConfig{ + Primary: "glm-4.7", + Fallbacks: []string{"glm-4.7__key_1", "glm-4.7__key_2"}, + } + + candidates := ResolveCandidates(cfg, "zhipu") + + cooldown := NewCooldownTracker() + chain := NewFallbackChain(cooldown) + + // Mock run function: different errors for each key + callCount := 0 + mockRun := func(ctx context.Context, provider, model string) (*LLMResponse, error) { + callCount++ + switch callCount { + case 1: + // First: rate limit (retriable) + return nil, errors.New("status: 429 - rate limit") + case 2: + // Second: timeout (retriable) + return nil, errors.New("context deadline exceeded") + case 3: + // Third: success + return &LLMResponse{Content: "success from key3"}, nil + default: + return nil, errors.New("unexpected call") + } + } + + result, err := chain.Execute(context.Background(), candidates, mockRun) + if err != nil { + t.Fatalf("expected success after 2 failovers, got error: %v", err) + } + + if callCount != 3 { + t.Errorf("expected 3 calls, got %d", callCount) + } + + // Verify both failed attempts were recorded + if len(result.Attempts) != 2 { + t.Errorf("expected 2 failed attempts, got %d", len(result.Attempts)) + } + + // First should be rate limit + if result.Attempts[0].Reason != FailoverRateLimit { + t.Errorf("expected first attempt to be rate_limit, got: %s", result.Attempts[0].Reason) + } + + // Second should be timeout + if result.Attempts[1].Reason != FailoverTimeout { + t.Errorf("expected second attempt to be timeout, got: %s", result.Attempts[1].Reason) + } +} diff --git a/pkg/providers/fallback_test.go b/pkg/providers/fallback_test.go index 1783ebcb5..1a1118e33 100644 --- a/pkg/providers/fallback_test.go +++ b/pkg/providers/fallback_test.go @@ -157,8 +157,8 @@ func TestFallback_CooldownSkip(t *testing.T) { ct, _ := newTestTracker(now) fc := NewFallbackChain(ct) - // Put openai in cooldown - ct.MarkFailure("openai", FailoverRateLimit) + // Put openai/gpt-4 in cooldown (using ModelKey now) + ct.MarkFailure(ModelKey("openai", "gpt-4"), FailoverRateLimit) candidates := []FallbackCandidate{ makeCandidate("openai", "gpt-4"), @@ -195,9 +195,9 @@ func TestFallback_AllInCooldown(t *testing.T) { ct := NewCooldownTracker() fc := NewFallbackChain(ct) - // Put all providers in cooldown - ct.MarkFailure("openai", FailoverRateLimit) - ct.MarkFailure("anthropic", FailoverBilling) + // Put all models in cooldown (using ModelKey now) + ct.MarkFailure(ModelKey("openai", "gpt-4"), FailoverRateLimit) + ct.MarkFailure(ModelKey("anthropic", "claude"), FailoverBilling) candidates := []FallbackCandidate{ makeCandidate("openai", "gpt-4"), @@ -273,12 +273,13 @@ func TestFallback_SuccessResetsCooldown(t *testing.T) { fc := NewFallbackChain(ct) candidates := []FallbackCandidate{makeCandidate("openai", "gpt-4")} + modelKey := ModelKey("openai", "gpt-4") attempt := 0 run := func(ctx context.Context, provider, model string) (*LLMResponse, error) { attempt++ if attempt == 1 { - ct.MarkFailure("openai", FailoverRateLimit) // simulate failure tracked elsewhere + ct.MarkFailure(modelKey, FailoverRateLimit) // simulate failure tracked elsewhere } return &LLMResponse{Content: "ok", FinishReason: "stop"}, nil } @@ -287,7 +288,7 @@ func TestFallback_SuccessResetsCooldown(t *testing.T) { if err != nil { t.Fatalf("unexpected error: %v", err) } - if !ct.IsAvailable("openai") { + if !ct.IsAvailable(modelKey) { t.Error("success should reset cooldown") } }