diff --git a/config/config.example.json b/config/config.example.json index 167ba7d59..c214f26fa 100644 --- a/config/config.example.json +++ b/config/config.example.json @@ -122,7 +122,8 @@ "verification_token": "", "allow_from": [], "reasoning_channel_id": "", - "random_reaction_emoji": [] + "random_reaction_emoji": [], + "is_lark": false }, "dingtalk": { "enabled": false, diff --git a/docs/channels/feishu/README.zh.md b/docs/channels/feishu/README.zh.md index 3fafffb7d..db7eb56eb 100644 --- a/docs/channels/feishu/README.zh.md +++ b/docs/channels/feishu/README.zh.md @@ -13,25 +13,27 @@ "app_secret": "xxx", "encrypt_key": "", "verification_token": "", - "allow_from": [] + "allow_from": [], + "is_lark": false } } } ``` -| 字段 | 类型 | 必填 | 描述 | -| ------------------ | ------ | ---- | -------------------------------- | -| enabled | bool | 是 | 是否启用飞书频道 | -| app_id | string | 是 | 飞书应用的 App ID(以cli\_开头) | -| app_secret | string | 是 | 飞书应用的 App Secret | -| encrypt_key | string | 否 | 事件回调加密密钥 | -| verification_token | string | 否 | 用于Webhook事件验证的Token | -| allow_from | array | 否 | 用户ID白名单,空表示所有用户 | -| random_reaction_emoji | array | 否 | 随机添加的表情列表,空则使用默认 "Pin" | +| 字段 | 类型 | 必填 | 描述 | +| --------------------- | ------ | ---- | ------------------------------------------------------------------------------------------------ | +| enabled | bool | 是 | 是否启用飞书频道 | +| app_id | string | 是 | 飞书应用的 App ID(以cli\_开头) | +| app_secret | string | 是 | 飞书应用的 App Secret | +| encrypt_key | string | 否 | 事件回调加密密钥 | +| verification_token | string | 否 | 用于Webhook事件验证的Token | +| allow_from | array | 否 | 用户ID白名单,空表示所有用户 | +| random_reaction_emoji | array | 否 | 随机添加的表情列表,空则使用默认 "Pin" | +| is_lark | bool | 否 | 是否使用 Lark 国际版域名(`open.larksuite.com`),默认为 `false`(使用飞书域名 `open.feishu.cn`) | ## 设置流程 -1. 前往 [飞书开放平台](https://open.feishu.cn/)创建应用程序 +1. 前往 [飞书开放平台](https://open.feishu.cn/)(国际版用户请前往 [Lark 开放平台](https://open.larksuite.com/))创建应用程序 2. 获取 App ID 和 App Secret 3. 配置事件订阅和Webhook URL 4. 设置加密(可选,生产环境建议启用) diff --git a/pkg/agent/loop.go b/pkg/agent/loop.go index aa9dbc3e8..8e9a70f2e 100644 --- a/pkg/agent/loop.go +++ b/pkg/agent/loop.go @@ -307,6 +307,11 @@ func registerSharedTools( return spawnSubTurn(ctx, al, parentTS, cfg) }) + // Clone the parent's tool registry so subagents can use all + // tools registered so far (file, web, etc.) but NOT spawn/ + // spawn_status which are added below — preventing recursive + // subagent spawning. + subagentManager.SetTools(agent.Tools.Clone()) if spawnEnabled { spawnTool := tools.NewSpawnTool(subagentManager) currentAgentID := agentID diff --git a/pkg/channels/feishu/feishu_64.go b/pkg/channels/feishu/feishu_64.go index c503e2993..3aea67b12 100644 --- a/pkg/channels/feishu/feishu_64.go +++ b/pkg/channels/feishu/feishu_64.go @@ -54,11 +54,15 @@ func NewFeishuChannel(cfg config.FeishuConfig, bus *bus.MessageBus) (*FeishuChan ) tc := newTokenCache() + opts := []lark.ClientOptionFunc{lark.WithTokenCache(tc)} + if cfg.IsLark { + opts = append(opts, lark.WithOpenBaseUrl(lark.LarkBaseUrl)) + } ch := &FeishuChannel{ BaseChannel: base, config: cfg, tokenCache: tc, - client: lark.NewClient(cfg.AppID, cfg.AppSecret, lark.WithTokenCache(tc)), + client: lark.NewClient(cfg.AppID, cfg.AppSecret, opts...), } ch.SetOwner(ch) return ch, nil @@ -83,10 +87,15 @@ func (c *FeishuChannel) Start(ctx context.Context) error { c.mu.Lock() c.cancel = cancel + domain := lark.FeishuBaseUrl + if c.config.IsLark { + domain = lark.LarkBaseUrl + } c.wsClient = larkws.NewClient( c.config.AppID, c.config.AppSecret, larkws.WithEventHandler(dispatcher), + larkws.WithDomain(domain), ) wsClient := c.wsClient c.mu.Unlock() diff --git a/pkg/config/config.go b/pkg/config/config.go index 64791e6e5..fe0fd711d 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -326,6 +326,7 @@ type FeishuConfig struct { Placeholder PlaceholderConfig `json:"placeholder,omitempty"` ReasoningChannelID string `json:"reasoning_channel_id" env:"PICOCLAW_CHANNELS_FEISHU_REASONING_CHANNEL_ID"` RandomReactionEmoji FlexibleStringSlice `json:"random_reaction_emoji" env:"PICOCLAW_CHANNELS_FEISHU_RANDOM_REACTION_EMOJI"` + IsLark bool `json:"is_lark" env:"PICOCLAW_CHANNELS_FEISHU_IS_LARK"` } type DiscordConfig struct { @@ -603,9 +604,11 @@ type ModelConfig struct { Model string `json:"model"` // Protocol/model-identifier (e.g., "openai/gpt-4o", "anthropic/claude-sonnet-4.6") // HTTP-based providers - APIBase string `json:"api_base,omitempty"` // API endpoint URL - APIKey string `json:"api_key"` // API authentication key - Proxy string `json:"proxy,omitempty"` // HTTP proxy URL + APIBase string `json:"api_base,omitempty"` // API endpoint URL + APIKey string `json:"api_key"` // API authentication key (single key) + APIKeys []string `json:"api_keys,omitempty"` // API authentication keys (multiple keys for failover) + Proxy string `json:"proxy,omitempty"` // HTTP proxy URL + Fallbacks []string `json:"fallbacks,omitempty"` // Fallback model names for failover // Special providers (CLI-based, OAuth, etc.) AuthMethod string `json:"auth_method,omitempty"` // Authentication method: oauth, token @@ -874,6 +877,9 @@ func LoadConfig(path string) (*Config, error) { return nil, err } + // Expand multi-key configs into separate entries for key-level failover + cfg.ModelList = ExpandMultiKeyModels(cfg.ModelList) + // Migrate legacy channel config fields to new unified structures cfg.migrateChannelConfigs() @@ -920,14 +926,25 @@ func encryptPlaintextAPIKeys(models []ModelConfig, passphrase string) ([]ModelCo // resolveAPIKeys decrypts or dereferences each api_key in models in-place. // Supports plaintext (no-op), file:// (read from configDir), and enc:// (AES-GCM decrypt). +// Also resolves api_keys array if present. func resolveAPIKeys(models []ModelConfig, configDir string) error { cr := credential.NewResolver(configDir) for i := range models { + // Resolve single APIKey resolved, err := cr.Resolve(models[i].APIKey) if err != nil { return fmt.Errorf("model_list[%d] (%s): %w", i, models[i].ModelName, err) } models[i].APIKey = resolved + + // Resolve APIKeys array + for j, key := range models[i].APIKeys { + resolved, err := cr.Resolve(key) + if err != nil { + return fmt.Errorf("model_list[%d] (%s): api_keys[%d]: %w", i, models[i].ModelName, j, err) + } + models[i].APIKeys[j] = resolved + } } return nil } @@ -1098,6 +1115,89 @@ func MergeAPIKeys(apiKey string, apiKeys []string) []string { return all } +// ExpandMultiKeyModels expands ModelConfig entries with multiple API keys into +// separate entries for key-level failover. Each key gets its own ModelConfig entry, +// and the original entry's fallbacks are set up to chain through the expanded entries. +// +// Example: {"model_name": "gpt-4", "api_keys": ["k1", "k2", "k3"]} +// Becomes: +// - {"model_name": "gpt-4", "api_key": "k1", "fallbacks": ["gpt-4__key_1", "gpt-4__key_2"]} +// - {"model_name": "gpt-4__key_1", "api_key": "k2"} +// - {"model_name": "gpt-4__key_2", "api_key": "k3"} +func ExpandMultiKeyModels(models []ModelConfig) []ModelConfig { + var expanded []ModelConfig + + for _, m := range models { + keys := MergeAPIKeys(m.APIKey, m.APIKeys) + + // Single key or no keys: keep as-is + if len(keys) <= 1 { + // Ensure APIKey is set from APIKeys if needed + if m.APIKey == "" && len(keys) == 1 { + m.APIKey = keys[0] + } + m.APIKeys = nil // Clear APIKeys to avoid confusion + expanded = append(expanded, m) + continue + } + + // Multiple keys: expand + originalName := m.ModelName + + // Create entries for additional keys (key_1, key_2, ...) + var fallbackNames []string + for i := 1; i < len(keys); i++ { + suffix := fmt.Sprintf("__key_%d", i) + expandedName := originalName + suffix + + // Create a copy for the additional key + additionalEntry := ModelConfig{ + ModelName: expandedName, + Model: m.Model, + APIBase: m.APIBase, + APIKey: keys[i], + Proxy: m.Proxy, + AuthMethod: m.AuthMethod, + ConnectMode: m.ConnectMode, + Workspace: m.Workspace, + RPM: m.RPM, + MaxTokensField: m.MaxTokensField, + RequestTimeout: m.RequestTimeout, + ThinkingLevel: m.ThinkingLevel, + } + expanded = append(expanded, additionalEntry) + fallbackNames = append(fallbackNames, expandedName) + } + + // Create the primary entry with first key and fallbacks + primaryEntry := ModelConfig{ + ModelName: originalName, + Model: m.Model, + APIBase: m.APIBase, + APIKey: keys[0], + Proxy: m.Proxy, + AuthMethod: m.AuthMethod, + ConnectMode: m.ConnectMode, + Workspace: m.Workspace, + RPM: m.RPM, + MaxTokensField: m.MaxTokensField, + RequestTimeout: m.RequestTimeout, + ThinkingLevel: m.ThinkingLevel, + } + + // Prepend new fallbacks to existing ones + if len(fallbackNames) > 0 { + primaryEntry.Fallbacks = append(fallbackNames, m.Fallbacks...) + } else if len(m.Fallbacks) > 0 { + primaryEntry.Fallbacks = m.Fallbacks + } + + expanded = append(expanded, primaryEntry) + } + + return expanded +} + func (t *ToolsConfig) IsToolEnabled(name string) bool { switch name { case "web": diff --git a/pkg/config/multikey_test.go b/pkg/config/multikey_test.go new file mode 100644 index 000000000..b899b991c --- /dev/null +++ b/pkg/config/multikey_test.go @@ -0,0 +1,291 @@ +package config + +import ( + "testing" +) + +func TestExpandMultiKeyModels_SingleKey(t *testing.T) { + models := []ModelConfig{ + { + ModelName: "gpt-4", + Model: "openai/gpt-4o", + APIKey: "single-key", + }, + } + + result := ExpandMultiKeyModels(models) + + if len(result) != 1 { + t.Fatalf("expected 1 model, got %d", len(result)) + } + + if result[0].ModelName != "gpt-4" { + t.Errorf("expected model_name 'gpt-4', got %q", result[0].ModelName) + } + + if result[0].APIKey != "single-key" { + t.Errorf("expected api_key 'single-key', got %q", result[0].APIKey) + } + + if len(result[0].Fallbacks) != 0 { + t.Errorf("expected no fallbacks, got %v", result[0].Fallbacks) + } +} + +func TestExpandMultiKeyModels_APIKeysOnly(t *testing.T) { + models := []ModelConfig{ + { + ModelName: "glm-4.7", + Model: "zhipu/glm-4.7", + APIBase: "https://api.example.com", + APIKeys: []string{"key1", "key2", "key3"}, + }, + } + + result := ExpandMultiKeyModels(models) + + // Should expand to 3 models + if len(result) != 3 { + t.Fatalf("expected 3 models, got %d", len(result)) + } + + // First entry should be the primary with key1 and fallbacks + primary := result[2] // Primary is added last + if primary.ModelName != "glm-4.7" { + t.Errorf("expected primary model_name 'glm-4.7', got %q", primary.ModelName) + } + if primary.APIKey != "key1" { + t.Errorf("expected primary api_key 'key1', got %q", primary.APIKey) + } + if len(primary.Fallbacks) != 2 { + t.Errorf("expected 2 fallbacks, got %d", len(primary.Fallbacks)) + } + if primary.Fallbacks[0] != "glm-4.7__key_1" { + t.Errorf("expected first fallback 'glm-4.7__key_1', got %q", primary.Fallbacks[0]) + } + if primary.Fallbacks[1] != "glm-4.7__key_2" { + t.Errorf("expected second fallback 'glm-4.7__key_2', got %q", primary.Fallbacks[1]) + } + + // Second entry should be key2 + second := result[0] + if second.ModelName != "glm-4.7__key_1" { + t.Errorf("expected second model_name 'glm-4.7__key_1', got %q", second.ModelName) + } + if second.APIKey != "key2" { + t.Errorf("expected second api_key 'key2', got %q", second.APIKey) + } + + // Third entry should be key3 + third := result[1] + if third.ModelName != "glm-4.7__key_2" { + t.Errorf("expected third model_name 'glm-4.7__key_2', got %q", third.ModelName) + } + if third.APIKey != "key3" { + t.Errorf("expected third api_key 'key3', got %q", third.APIKey) + } +} + +func TestExpandMultiKeyModels_APIKeyAndAPIKeys(t *testing.T) { + models := []ModelConfig{ + { + ModelName: "gpt-4", + Model: "openai/gpt-4o", + APIKey: "key0", + APIKeys: []string{"key1", "key2"}, + }, + } + + result := ExpandMultiKeyModels(models) + + // Should expand to 3 models (key0 from APIKey + key1, key2 from APIKeys) + if len(result) != 3 { + t.Fatalf("expected 3 models, got %d", len(result)) + } + + // Primary should use key0 + primary := result[2] + if primary.APIKey != "key0" { + t.Errorf("expected primary api_key 'key0', got %q", primary.APIKey) + } + if len(primary.Fallbacks) != 2 { + t.Errorf("expected 2 fallbacks, got %d", len(primary.Fallbacks)) + } +} + +func TestExpandMultiKeyModels_WithExistingFallbacks(t *testing.T) { + models := []ModelConfig{ + { + ModelName: "gpt-4", + Model: "openai/gpt-4o", + APIKeys: []string{"key1", "key2"}, + Fallbacks: []string{"claude-3"}, + }, + } + + result := ExpandMultiKeyModels(models) + + primary := result[1] + // With 2 keys, we get 1 key fallback + 1 existing fallback = 2 total + if len(primary.Fallbacks) != 2 { + t.Fatalf("expected 2 fallbacks, got %d: %v", len(primary.Fallbacks), primary.Fallbacks) + } + + // Key fallbacks should come first, then existing fallbacks + if primary.Fallbacks[0] != "gpt-4__key_1" { + t.Errorf("expected first fallback 'gpt-4__key_1', got %q", primary.Fallbacks[0]) + } + if primary.Fallbacks[1] != "claude-3" { + t.Errorf("expected second fallback 'claude-3', got %q", primary.Fallbacks[1]) + } +} + +func TestExpandMultiKeyModels_EmptyAPIKeys(t *testing.T) { + models := []ModelConfig{ + { + ModelName: "gpt-4", + Model: "openai/gpt-4o", + APIKey: "", + APIKeys: []string{}, + }, + } + + result := ExpandMultiKeyModels(models) + + // Should keep as-is with no changes + if len(result) != 1 { + t.Fatalf("expected 1 model, got %d", len(result)) + } + + if result[0].ModelName != "gpt-4" { + t.Errorf("expected model_name 'gpt-4', got %q", result[0].ModelName) + } +} + +func TestExpandMultiKeyModels_Deduplication(t *testing.T) { + models := []ModelConfig{ + { + ModelName: "gpt-4", + Model: "openai/gpt-4o", + APIKey: "key1", + APIKeys: []string{"key1", "key2", "key1"}, // Duplicate key1 + }, + } + + result := ExpandMultiKeyModels(models) + + // Should only create 2 models (deduplicated keys) + if len(result) != 2 { + t.Fatalf("expected 2 models (deduplicated), got %d", len(result)) + } + + primary := result[1] + if primary.APIKey != "key1" { + t.Errorf("expected primary api_key 'key1', got %q", primary.APIKey) + } + if len(primary.Fallbacks) != 1 { + t.Errorf("expected 1 fallback, got %d", len(primary.Fallbacks)) + } +} + +func TestExpandMultiKeyModels_PreservesOtherFields(t *testing.T) { + models := []ModelConfig{ + { + ModelName: "gpt-4", + Model: "openai/gpt-4o", + APIBase: "https://api.example.com", + APIKeys: []string{"key1", "key2"}, + Proxy: "http://proxy:8080", + RPM: 60, + MaxTokensField: "max_completion_tokens", + RequestTimeout: 30, + ThinkingLevel: "high", + }, + } + + result := ExpandMultiKeyModels(models) + + // Check primary entry preserves all fields + primary := result[1] + if primary.APIBase != "https://api.example.com" { + t.Errorf("expected api_base preserved, got %q", primary.APIBase) + } + if primary.Proxy != "http://proxy:8080" { + t.Errorf("expected proxy preserved, got %q", primary.Proxy) + } + if primary.RPM != 60 { + t.Errorf("expected rpm preserved, got %d", primary.RPM) + } + if primary.MaxTokensField != "max_completion_tokens" { + t.Errorf("expected max_tokens_field preserved, got %q", primary.MaxTokensField) + } + if primary.RequestTimeout != 30 { + t.Errorf("expected request_timeout preserved, got %d", primary.RequestTimeout) + } + if primary.ThinkingLevel != "high" { + t.Errorf("expected thinking_level preserved, got %q", primary.ThinkingLevel) + } + + // Check additional entry also preserves fields + additional := result[0] + if additional.APIBase != "https://api.example.com" { + t.Errorf("expected additional api_base preserved, got %q", additional.APIBase) + } + if additional.RPM != 60 { + t.Errorf("expected additional rpm preserved, got %d", additional.RPM) + } +} + +func TestMergeAPIKeys(t *testing.T) { + tests := []struct { + name string + apiKey string + apiKeys []string + expected []string + }{ + { + name: "both empty", + apiKey: "", + apiKeys: nil, + expected: nil, + }, + { + name: "only apiKey", + apiKey: "key1", + apiKeys: nil, + expected: []string{"key1"}, + }, + { + name: "only apiKeys", + apiKey: "", + apiKeys: []string{"key1", "key2"}, + expected: []string{"key1", "key2"}, + }, + { + name: "both with overlap", + apiKey: "key1", + apiKeys: []string{"key1", "key2", "key3"}, + expected: []string{"key1", "key2", "key3"}, + }, + { + name: "with whitespace", + apiKey: " key1 ", + apiKeys: []string{" key2 ", " key1 "}, + expected: []string{"key1", "key2"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := MergeAPIKeys(tt.apiKey, tt.apiKeys) + if len(result) != len(tt.expected) { + t.Fatalf("expected %d keys, got %d", len(tt.expected), len(result)) + } + for i, k := range result { + if k != tt.expected[i] { + t.Errorf("expected key[%d] = %q, got %q", i, tt.expected[i], k) + } + } + }) + } +} diff --git a/pkg/providers/fallback.go b/pkg/providers/fallback.go index 7ba563b66..549ec7837 100644 --- a/pkg/providers/fallback.go +++ b/pkg/providers/fallback.go @@ -117,17 +117,19 @@ func (fc *FallbackChain) Execute( return nil, context.Canceled } - // Check cooldown. - if !fc.cooldown.IsAvailable(candidate.Provider) { - remaining := fc.cooldown.CooldownRemaining(candidate.Provider) + // Check cooldown (per provider/model, not just provider). + // This allows multi-key failover where different keys use different model names. + cooldownKey := ModelKey(candidate.Provider, candidate.Model) + if !fc.cooldown.IsAvailable(cooldownKey) { + remaining := fc.cooldown.CooldownRemaining(cooldownKey) result.Attempts = append(result.Attempts, FallbackAttempt{ Provider: candidate.Provider, Model: candidate.Model, Skipped: true, Reason: FailoverRateLimit, Error: fmt.Errorf( - "provider %s in cooldown (%s remaining)", - candidate.Provider, + "%s in cooldown (%s remaining)", + cooldownKey, remaining.Round(time.Second), ), }) @@ -141,7 +143,7 @@ func (fc *FallbackChain) Execute( if err == nil { // Success. - fc.cooldown.MarkSuccess(candidate.Provider) + fc.cooldown.MarkSuccess(cooldownKey) result.Response = resp result.Provider = candidate.Provider result.Model = candidate.Model @@ -187,7 +189,7 @@ func (fc *FallbackChain) Execute( } // Retriable error: mark failure and continue to next candidate. - fc.cooldown.MarkFailure(candidate.Provider, failErr.Reason) + fc.cooldown.MarkFailure(cooldownKey, failErr.Reason) result.Attempts = append(result.Attempts, FallbackAttempt{ Provider: candidate.Provider, Model: candidate.Model, diff --git a/pkg/providers/fallback_multikey_test.go b/pkg/providers/fallback_multikey_test.go new file mode 100644 index 000000000..9ed8fa73c --- /dev/null +++ b/pkg/providers/fallback_multikey_test.go @@ -0,0 +1,384 @@ +package providers + +import ( + "context" + "errors" + "testing" +) + +// TestMultiKeyFailover tests the complete failover flow with multiple API keys. +// This simulates the config expansion scenario where api_keys: ["key1", "key2", "key3"] +// is expanded into primary + fallbacks. +func TestMultiKeyFailover(t *testing.T) { + // Simulate expanded config: primary with 2 fallbacks + // This is what ExpandMultiKeyModels would produce for api_keys: ["key1", "key2", "key3"] + cfg := ModelConfig{ + Primary: "glm-4.7", + Fallbacks: []string{"glm-4.7__key_1", "glm-4.7__key_2"}, + } + + candidates := ResolveCandidates(cfg, "zhipu") + + if len(candidates) != 3 { + t.Fatalf("expected 3 candidates, got %d: %v", len(candidates), candidates) + } + + // Create fallback chain + cooldown := NewCooldownTracker() + chain := NewFallbackChain(cooldown) + + // Mock run function: first call fails with 429, second succeeds + callCount := 0 + mockRun := func(ctx context.Context, provider, model string) (*LLMResponse, error) { + callCount++ + if callCount == 1 { + // First call: simulate rate limit + return nil, errors.New("http error: status 429 - rate limit exceeded") + } + // Second call: success + return &LLMResponse{ + Content: "Hello from key2!", + }, nil + } + + // Execute fallback chain + result, err := chain.Execute(context.Background(), candidates, mockRun) + if err != nil { + t.Fatalf("expected success after failover, got error: %v", err) + } + + if result == nil { + t.Fatal("expected result, got nil") + } + + if result.Response.Content != "Hello from key2!" { + t.Errorf("expected response from key2, got: %s", result.Response.Content) + } + + if callCount != 2 { + t.Errorf("expected 2 calls (1 fail + 1 success), got %d", callCount) + } + + // Verify first attempt was recorded + if len(result.Attempts) != 1 { + t.Errorf("expected 1 failed attempt recorded, got %d", len(result.Attempts)) + } + + if result.Attempts[0].Reason != FailoverRateLimit { + t.Errorf( + "expected first attempt reason to be rate_limit, got: %s", + result.Attempts[0].Reason, + ) + } +} + +// TestMultiKeyFailoverAllFail tests when all keys hit rate limit +func TestMultiKeyFailoverAllFail(t *testing.T) { + cfg := ModelConfig{ + Primary: "glm-4.7", + Fallbacks: []string{"glm-4.7__key_1", "glm-4.7__key_2"}, + } + + candidates := ResolveCandidates(cfg, "zhipu") + + cooldown := NewCooldownTracker() + chain := NewFallbackChain(cooldown) + + // Mock run function: all calls fail with rate limit + callCount := 0 + mockRun := func(ctx context.Context, provider, model string) (*LLMResponse, error) { + callCount++ + return nil, errors.New("status: 429 - too many requests") + } + + // Execute fallback chain + result, err := chain.Execute(context.Background(), candidates, mockRun) + + if err == nil { + t.Fatal("expected error when all keys fail, got nil") + } + + if result != nil { + t.Errorf("expected nil result on failure, got: %v", result) + } + + if callCount != 3 { + t.Errorf("expected 3 calls (all fail), got %d", callCount) + } + + // Verify error type + var exhausted *FallbackExhaustedError + if !errors.As(err, &exhausted) { + t.Errorf("expected FallbackExhaustedError, got: %T - %v", err, err) + } + + if len(exhausted.Attempts) != 3 { + t.Errorf("expected 3 attempts in exhausted error, got %d", len(exhausted.Attempts)) + } +} + +// TestMultiKeyFailoverCooldown tests that a key in cooldown is skipped +func TestMultiKeyFailoverCooldown(t *testing.T) { + cfg := ModelConfig{ + Primary: "glm-4.7", + Fallbacks: []string{"glm-4.7__key_1"}, + } + + candidates := ResolveCandidates(cfg, "zhipu") + + cooldown := NewCooldownTracker() + chain := NewFallbackChain(cooldown) + + // Put the first model in cooldown (using ModelKey now, not just provider) + cooldownKey := ModelKey(candidates[0].Provider, candidates[0].Model) + cooldown.MarkFailure(cooldownKey, FailoverRateLimit) + + // Verify it's not available + if cooldown.IsAvailable(cooldownKey) { + t.Fatal("expected first model to be in cooldown") + } + + // Mock run function: only second should be called + callCount := 0 + calledProviders := []string{} + mockRun := func(ctx context.Context, provider, model string) (*LLMResponse, error) { + callCount++ + calledProviders = append(calledProviders, provider+"/"+model) + return &LLMResponse{Content: "success"}, nil + } + + result, err := chain.Execute(context.Background(), candidates, mockRun) + if err != nil { + t.Fatalf("expected success, got error: %v", err) + } + + // First provider should have been skipped + if callCount != 1 { + t.Errorf("expected 1 call (first skipped due to cooldown), got %d", callCount) + } + + // Should have called the second provider/model + if len(calledProviders) != 1 || + calledProviders[0] != candidates[1].Provider+"/"+candidates[1].Model { + t.Errorf("expected second model to be called, got: %v", calledProviders) + } + + // Verify first attempt was recorded as skipped + if len(result.Attempts) != 1 { + t.Fatalf("expected 1 attempt (skipped), got %d", len(result.Attempts)) + } + + if !result.Attempts[0].Skipped { + t.Error("expected first attempt to be marked as skipped") + } +} + +// TestMultiKeyFailoverWithFormatError tests that format errors are non-retriable +func TestMultiKeyFailoverWithFormatError(t *testing.T) { + cfg := ModelConfig{ + Primary: "glm-4.7", + Fallbacks: []string{"glm-4.7__key_1"}, + } + + candidates := ResolveCandidates(cfg, "zhipu") + + cooldown := NewCooldownTracker() + chain := NewFallbackChain(cooldown) + + // Mock run function: first call fails with format error (bad request) + callCount := 0 + mockRun := func(ctx context.Context, provider, model string) (*LLMResponse, error) { + callCount++ + return nil, errors.New("invalid request format: tool_use.id missing") + } + + // Execute fallback chain + result, err := chain.Execute(context.Background(), candidates, mockRun) + + if err == nil { + t.Fatal("expected error for format failure, got nil") + } + + // Format errors should NOT trigger failover (non-retriable) + // So we should only have 1 call + if callCount != 1 { + t.Errorf("expected 1 call (format error is non-retriable), got %d", callCount) + } + + // Verify the error is a FailoverError with format reason + var failoverErr *FailoverError + if !errors.As(err, &failoverErr) { + t.Errorf("expected FailoverError, got: %T - %v", err, err) + } + + if failoverErr.Reason != FailoverFormat { + t.Errorf("expected FailoverFormat reason, got: %s", failoverErr.Reason) + } + + _ = result // result should be nil +} + +// TestMultiKeyWithModelFallback tests multi-key failover combined with model fallback. +// This simulates the scenario: api_keys: ["k1", "k2"] + fallbacks: ["minimax"] +// Expected failover order: glm-4.7 (k1) → glm-4.7__key_1 (k2) → minimax +func TestMultiKeyWithModelFallback(t *testing.T) { + // Simulate expanded config from: + // { "model_name": "glm-4.7", "api_keys": ["k1", "k2"], "fallbacks": ["minimax"] } + // After ExpandMultiKeyModels, primaryEntry.Fallbacks = ["glm-4.7__key_1", "minimax"] + // Note: In production, "minimax" would be resolved via model lookup to "minimax/minimax" + // In this test, we use the full format to avoid needing a lookup function. + cfg := ModelConfig{ + Primary: "glm-4.7", + Fallbacks: []string{"glm-4.7__key_1", "minimax/minimax"}, + } + + candidates := ResolveCandidates(cfg, "zhipu") + + // Should have 3 candidates: glm-4.7 (zhipu), glm-4.7__key_1 (zhipu), minimax (minimax) + if len(candidates) != 3 { + t.Fatalf("expected 3 candidates, got %d: %v", len(candidates), candidates) + } + + // Verify candidate order + if candidates[0].Model != "glm-4.7" || candidates[0].Provider != "zhipu" { + t.Errorf( + "expected first candidate to be zhipu/glm-4.7, got: %s/%s", + candidates[0].Provider, + candidates[0].Model, + ) + } + if candidates[1].Model != "glm-4.7__key_1" || candidates[1].Provider != "zhipu" { + t.Errorf( + "expected second candidate to be zhipu/glm-4.7__key_1, got: %s/%s", + candidates[1].Provider, + candidates[1].Model, + ) + } + if candidates[2].Model != "minimax" || candidates[2].Provider != "minimax" { + t.Errorf( + "expected third candidate to be minimax/minimax, got: %s/%s", + candidates[2].Provider, + candidates[2].Model, + ) + } + + cooldown := NewCooldownTracker() + chain := NewFallbackChain(cooldown) + + // Mock run function: first two fail, third succeeds (model fallback) + callCount := 0 + calledModels := []string{} + mockRun := func(ctx context.Context, provider, model string) (*LLMResponse, error) { + callCount++ + calledModels = append(calledModels, provider+"/"+model) + + switch callCount { + case 1: + // k1: rate limit + return nil, errors.New("status: 429 - rate limit") + case 2: + // k2: also rate limit (all zhipu keys exhausted) + return nil, errors.New("status: 429 - rate limit") + case 3: + // minimax: success + return &LLMResponse{Content: "success from minimax"}, nil + default: + return nil, errors.New("unexpected call") + } + } + + result, err := chain.Execute(context.Background(), candidates, mockRun) + if err != nil { + t.Fatalf("expected success after failover to model fallback, got error: %v", err) + } + + if callCount != 3 { + t.Errorf("expected 3 calls (k1 fail + k2 fail + minimax success), got %d", callCount) + } + + if result.Response.Content != "success from minimax" { + t.Errorf("expected response from minimax, got: %s", result.Response.Content) + } + + // Verify call order + if len(calledModels) != 3 { + t.Fatalf("expected 3 called models, got %d", len(calledModels)) + } + if calledModels[0] != "zhipu/glm-4.7" { + t.Errorf("expected first call to zhipu/glm-4.7, got: %s", calledModels[0]) + } + if calledModels[1] != "zhipu/glm-4.7__key_1" { + t.Errorf("expected second call to zhipu/glm-4.7__key_1, got: %s", calledModels[1]) + } + if calledModels[2] != "minimax/minimax" { + t.Errorf("expected third call to minimax/minimax, got: %s", calledModels[2]) + } + + // Verify 2 failed attempts recorded + if len(result.Attempts) != 2 { + t.Errorf("expected 2 failed attempts, got %d", len(result.Attempts)) + } + + // Both should be rate limit + for i, attempt := range result.Attempts { + if attempt.Reason != FailoverRateLimit { + t.Errorf("expected attempt %d to be rate_limit, got: %s", i, attempt.Reason) + } + } +} + +// TestMultiKeyFailoverMixedErrors tests failover with different error types +func TestMultiKeyFailoverMixedErrors(t *testing.T) { + cfg := ModelConfig{ + Primary: "glm-4.7", + Fallbacks: []string{"glm-4.7__key_1", "glm-4.7__key_2"}, + } + + candidates := ResolveCandidates(cfg, "zhipu") + + cooldown := NewCooldownTracker() + chain := NewFallbackChain(cooldown) + + // Mock run function: different errors for each key + callCount := 0 + mockRun := func(ctx context.Context, provider, model string) (*LLMResponse, error) { + callCount++ + switch callCount { + case 1: + // First: rate limit (retriable) + return nil, errors.New("status: 429 - rate limit") + case 2: + // Second: timeout (retriable) + return nil, errors.New("context deadline exceeded") + case 3: + // Third: success + return &LLMResponse{Content: "success from key3"}, nil + default: + return nil, errors.New("unexpected call") + } + } + + result, err := chain.Execute(context.Background(), candidates, mockRun) + if err != nil { + t.Fatalf("expected success after 2 failovers, got error: %v", err) + } + + if callCount != 3 { + t.Errorf("expected 3 calls, got %d", callCount) + } + + // Verify both failed attempts were recorded + if len(result.Attempts) != 2 { + t.Errorf("expected 2 failed attempts, got %d", len(result.Attempts)) + } + + // First should be rate limit + if result.Attempts[0].Reason != FailoverRateLimit { + t.Errorf("expected first attempt to be rate_limit, got: %s", result.Attempts[0].Reason) + } + + // Second should be timeout + if result.Attempts[1].Reason != FailoverTimeout { + t.Errorf("expected second attempt to be timeout, got: %s", result.Attempts[1].Reason) + } +} diff --git a/pkg/providers/fallback_test.go b/pkg/providers/fallback_test.go index 1783ebcb5..1a1118e33 100644 --- a/pkg/providers/fallback_test.go +++ b/pkg/providers/fallback_test.go @@ -157,8 +157,8 @@ func TestFallback_CooldownSkip(t *testing.T) { ct, _ := newTestTracker(now) fc := NewFallbackChain(ct) - // Put openai in cooldown - ct.MarkFailure("openai", FailoverRateLimit) + // Put openai/gpt-4 in cooldown (using ModelKey now) + ct.MarkFailure(ModelKey("openai", "gpt-4"), FailoverRateLimit) candidates := []FallbackCandidate{ makeCandidate("openai", "gpt-4"), @@ -195,9 +195,9 @@ func TestFallback_AllInCooldown(t *testing.T) { ct := NewCooldownTracker() fc := NewFallbackChain(ct) - // Put all providers in cooldown - ct.MarkFailure("openai", FailoverRateLimit) - ct.MarkFailure("anthropic", FailoverBilling) + // Put all models in cooldown (using ModelKey now) + ct.MarkFailure(ModelKey("openai", "gpt-4"), FailoverRateLimit) + ct.MarkFailure(ModelKey("anthropic", "claude"), FailoverBilling) candidates := []FallbackCandidate{ makeCandidate("openai", "gpt-4"), @@ -273,12 +273,13 @@ func TestFallback_SuccessResetsCooldown(t *testing.T) { fc := NewFallbackChain(ct) candidates := []FallbackCandidate{makeCandidate("openai", "gpt-4")} + modelKey := ModelKey("openai", "gpt-4") attempt := 0 run := func(ctx context.Context, provider, model string) (*LLMResponse, error) { attempt++ if attempt == 1 { - ct.MarkFailure("openai", FailoverRateLimit) // simulate failure tracked elsewhere + ct.MarkFailure(modelKey, FailoverRateLimit) // simulate failure tracked elsewhere } return &LLMResponse{Content: "ok", FinishReason: "stop"}, nil } @@ -287,7 +288,7 @@ func TestFallback_SuccessResetsCooldown(t *testing.T) { if err != nil { t.Fatalf("unexpected error: %v", err) } - if !ct.IsAvailable("openai") { + if !ct.IsAvailable(modelKey) { t.Error("success should reset cooldown") } } diff --git a/pkg/tools/registry.go b/pkg/tools/registry.go index c879e802b..e05fcc2e6 100644 --- a/pkg/tools/registry.go +++ b/pkg/tools/registry.go @@ -188,15 +188,48 @@ func (r *ToolRegistry) ExecuteWithContext( // The callback is a call parameter, not mutable state on the tool instance. var result *ToolResult start := time.Now() - if asyncExec, ok := tool.(AsyncExecutor); ok && asyncCallback != nil { - logger.DebugCF("tool", "Executing async tool via ExecuteAsync", - map[string]any{ - "tool": name, - }) - result = asyncExec.ExecuteAsync(ctx, args, asyncCallback) - } else { - result = tool.Execute(ctx, args) + + // Use recover to catch any panics during tool execution + // This prevents tool crashes from killing the entire agent + func() { + defer func() { + if re := recover(); re != nil { + errMsg := fmt.Sprintf("Tool '%s' crashed with panic: %v", name, re) + logger.ErrorCF("tool", "Tool execution panic recovered", + map[string]any{ + "tool": name, + "panic": fmt.Sprintf("%v", re), + }) + result = &ToolResult{ + ForLLM: errMsg, + ForUser: errMsg, + IsError: true, + Err: fmt.Errorf("panic: %v", re), + } + } + }() + + if asyncExec, ok := tool.(AsyncExecutor); ok && asyncCallback != nil { + logger.DebugCF("tool", "Executing async tool via ExecuteAsync", + map[string]any{ + "tool": name, + }) + result = asyncExec.ExecuteAsync(ctx, args, asyncCallback) + } else { + result = tool.Execute(ctx, args) + } + }() + + // Handle nil result (should not happen, but defensive) + if result == nil { + result = &ToolResult{ + ForLLM: fmt.Sprintf("Tool '%s' returned nil result unexpectedly", name), + ForUser: fmt.Sprintf("Tool '%s' returned nil result unexpectedly", name), + IsError: true, + Err: fmt.Errorf("nil result from tool"), + } } + duration := time.Since(start) // Log based on result type @@ -303,6 +336,28 @@ func (r *ToolRegistry) List() []string { return r.sortedToolNames() } +// Clone creates an independent copy of the registry containing the same tool +// entries (shallow copy of each ToolEntry). This is used to give subagents a +// snapshot of the parent agent's tools without sharing the same registry — +// tools registered on the parent after cloning (e.g. spawn, spawn_status) +// will NOT be visible to the clone, preventing recursive subagent spawning. +// The version counter is reset to 0 in the clone as it's a new independent registry. +func (r *ToolRegistry) Clone() *ToolRegistry { + r.mu.RLock() + defer r.mu.RUnlock() + clone := &ToolRegistry{ + tools: make(map[string]*ToolEntry, len(r.tools)), + } + for name, entry := range r.tools { + clone.tools[name] = &ToolEntry{ + Tool: entry.Tool, + IsCore: entry.IsCore, + TTL: entry.TTL, + } + } + return clone +} + // Count returns the number of registered tools. func (r *ToolRegistry) Count() int { r.mu.RLock() diff --git a/pkg/tools/registry_test.go b/pkg/tools/registry_test.go index 92d7d5abd..967758dfa 100644 --- a/pkg/tools/registry_test.go +++ b/pkg/tools/registry_test.go @@ -2,6 +2,7 @@ package tools import ( "context" + "errors" "strings" "sync" "testing" @@ -335,6 +336,96 @@ func TestToolToSchema(t *testing.T) { } } +func TestToolRegistry_Clone(t *testing.T) { + r := NewToolRegistry() + r.Register(newMockTool("read_file", "reads files")) + r.Register(newMockTool("exec", "runs commands")) + r.Register(newMockTool("web_search", "searches the web")) + + clone := r.Clone() + + // Clone should have the same tools + if clone.Count() != 3 { + t.Errorf("expected clone to have 3 tools, got %d", clone.Count()) + } + for _, name := range []string{"read_file", "exec", "web_search"} { + if _, ok := clone.Get(name); !ok { + t.Errorf("expected clone to have tool %q", name) + } + } + + // Registering on parent should NOT affect clone + r.Register(newMockTool("spawn", "spawns subagent")) + if r.Count() != 4 { + t.Errorf("expected parent to have 4 tools, got %d", r.Count()) + } + if clone.Count() != 3 { + t.Errorf("expected clone to still have 3 tools after parent mutation, got %d", clone.Count()) + } + if _, ok := clone.Get("spawn"); ok { + t.Error("expected clone NOT to have 'spawn' tool registered on parent after cloning") + } + + // Registering on clone should NOT affect parent + clone.Register(newMockTool("custom", "custom tool")) + if clone.Count() != 4 { + t.Errorf("expected clone to have 4 tools, got %d", clone.Count()) + } + if _, ok := r.Get("custom"); ok { + t.Error("expected parent NOT to have 'custom' tool registered on clone") + } +} + +func TestToolRegistry_Clone_Empty(t *testing.T) { + r := NewToolRegistry() + clone := r.Clone() + if clone.Count() != 0 { + t.Errorf("expected empty clone, got count %d", clone.Count()) + } +} + +func TestToolRegistry_Clone_PreservesHiddenToolState(t *testing.T) { + r := NewToolRegistry() + r.RegisterHidden(newMockTool("mcp_tool", "dynamic MCP tool")) + + clone := r.Clone() + + // Hidden tools with TTL=0 should not be gettable (same behavior as parent) + if _, ok := clone.Get("mcp_tool"); ok { + t.Error("expected hidden tool with TTL=0 to be invisible in clone") + } + + // But the entry should exist (count includes hidden tools) + if clone.Count() != 1 { + t.Errorf("expected clone count 1 (hidden entry exists), got %d", clone.Count()) + } +} + +func TestToolRegistry_Clone_PreservesTTLValue(t *testing.T) { + r := NewToolRegistry() + r.RegisterHidden(newMockTool("ttl_tool", "tool with TTL")) + + // Manually set a non-zero TTL on the entry + r.mu.RLock() + if entry, ok := r.tools["ttl_tool"]; ok { + entry.TTL = 5 + } + r.mu.RUnlock() + + clone := r.Clone() + + // Verify TTL value is preserved in the clone + clone.mu.RLock() + defer clone.mu.RUnlock() + entry, ok := clone.tools["ttl_tool"] + if !ok { + t.Fatal("expected ttl_tool to exist in clone") + } + if entry.TTL != 5 { + t.Errorf("expected TTL=5 in clone, got %d", entry.TTL) + } +} + func TestToolRegistry_ConcurrentAccess(t *testing.T) { r := NewToolRegistry() var wg sync.WaitGroup @@ -358,3 +449,175 @@ func TestToolRegistry_ConcurrentAccess(t *testing.T) { t.Error("expected tools to be registered after concurrent access") } } + +// --- Panic and abnormal exit tests --- + +// mockPanicTool is a tool that panics during execution +type mockPanicTool struct { + name string + panicValue any +} + +func (m *mockPanicTool) Name() string { return m.name } +func (m *mockPanicTool) Description() string { return "a tool that panics" } +func (m *mockPanicTool) Parameters() map[string]any { return map[string]any{"type": "object"} } +func (m *mockPanicTool) Execute(_ context.Context, _ map[string]any) *ToolResult { + panic(m.panicValue) +} + +// mockNilResultTool is a tool that returns nil +type mockNilResultTool struct { + name string +} + +func (m *mockNilResultTool) Name() string { return m.name } +func (m *mockNilResultTool) Description() string { return "a tool that returns nil" } +func (m *mockNilResultTool) Parameters() map[string]any { return map[string]any{"type": "object"} } +func (m *mockNilResultTool) Execute(_ context.Context, _ map[string]any) *ToolResult { + return nil +} + +func TestToolRegistry_Execute_PanicRecovery(t *testing.T) { + r := NewToolRegistry() + r.Register(&mockPanicTool{ + name: "panic_tool", + panicValue: "something went terribly wrong", + }) + + // Should not panic, should return error result + result := r.Execute(context.Background(), "panic_tool", nil) + + if result == nil { + t.Fatal("expected non-nil result after panic recovery") + } + if !result.IsError { + t.Error("expected IsError=true after panic") + } + if !strings.Contains(result.ForLLM, "panic") { + t.Errorf("expected 'panic' in error message, got %q", result.ForLLM) + } + if !strings.Contains(result.ForLLM, "panic_tool") { + t.Errorf("expected tool name in error message, got %q", result.ForLLM) + } + if !strings.Contains(result.ForLLM, "something went terribly wrong") { + t.Errorf("expected panic value in error message, got %q", result.ForLLM) + } + if result.Err == nil { + t.Error("expected Err to be set") + } +} + +func TestToolRegistry_Execute_PanicRecovery_ErrorType(t *testing.T) { + r := NewToolRegistry() + + // Test with error type panic + r.Register(&mockPanicTool{ + name: "error_panic_tool", + panicValue: errors.New("custom error panic"), + }) + + result := r.Execute(context.Background(), "error_panic_tool", nil) + + if !result.IsError { + t.Error("expected IsError=true") + } + if !strings.Contains(result.ForLLM, "custom error panic") { + t.Errorf("expected error message in ForLLM, got %q", result.ForLLM) + } +} + +func TestToolRegistry_Execute_PanicRecovery_IntType(t *testing.T) { + r := NewToolRegistry() + + // Test with int type panic + r.Register(&mockPanicTool{ + name: "int_panic_tool", + panicValue: 42, + }) + + result := r.Execute(context.Background(), "int_panic_tool", nil) + + if !result.IsError { + t.Error("expected IsError=true") + } + if !strings.Contains(result.ForLLM, "42") { + t.Errorf("expected panic value '42' in ForLLM, got %q", result.ForLLM) + } +} + +func TestToolRegistry_Execute_NilResultHandling(t *testing.T) { + r := NewToolRegistry() + r.Register(&mockNilResultTool{name: "nil_tool"}) + + result := r.Execute(context.Background(), "nil_tool", nil) + + if result == nil { + t.Fatal("expected non-nil result when tool returns nil") + } + if !result.IsError { + t.Error("expected IsError=true for nil result") + } + if !strings.Contains(result.ForLLM, "nil_tool") { + t.Errorf("expected tool name in error message, got %q", result.ForLLM) + } + if !strings.Contains(result.ForLLM, "nil result") { + t.Errorf("expected 'nil result' in error message, got %q", result.ForLLM) + } + if result.Err == nil { + t.Error("expected Err to be set") + } +} + +func TestToolRegistry_ExecuteWithContext_PanicRecovery(t *testing.T) { + r := NewToolRegistry() + r.Register(&mockPanicTool{ + name: "ctx_panic_tool", + panicValue: "context panic test", + }) + + // Should not panic even with context + result := r.ExecuteWithContext( + context.Background(), + "ctx_panic_tool", + map[string]any{"key": "value"}, + "telegram", + "chat-123", + nil, + ) + + if result == nil { + t.Fatal("expected non-nil result") + } + if !result.IsError { + t.Error("expected IsError=true") + } + if !strings.Contains(result.ForLLM, "context panic test") { + t.Errorf("expected panic message, got %q", result.ForLLM) + } +} + +func TestToolRegistry_Execute_PanicDoesNotAffectOtherTools(t *testing.T) { + r := NewToolRegistry() + r.Register(&mockPanicTool{name: "bad_tool", panicValue: "boom"}) + r.Register(&mockRegistryTool{ + name: "good_tool", + desc: "works fine", + params: map[string]any{}, + result: SilentResult("success"), + }) + + // First, trigger the panic + result1 := r.Execute(context.Background(), "bad_tool", nil) + if !result1.IsError { + t.Error("expected error from panic tool") + } + + // Then, verify the good tool still works + result2 := r.Execute(context.Background(), "good_tool", nil) + if result2.IsError { + t.Errorf("expected success from good tool, got error: %s", result2.ForLLM) + } + if result2.ForLLM != "success" { + t.Errorf("expected 'success', got %q", result2.ForLLM) + } +} diff --git a/pkg/tools/shell.go b/pkg/tools/shell.go index 0dc85ae21..78ad2b26d 100644 --- a/pkg/tools/shell.go +++ b/pkg/tools/shell.go @@ -311,13 +311,30 @@ func (t *ExecTool) Execute(ctx context.Context, args map[string]any) *ToolResult if err != nil { if errors.Is(cmdCtx.Err(), context.DeadlineExceeded) { msg := fmt.Sprintf("Command timed out after %v", t.timeout) + if output != "" { + msg += "\n\nPartial output before timeout:\n" + output + } return &ToolResult{ ForLLM: msg, ForUser: msg, IsError: true, + Err: fmt.Errorf("command timeout: %w", err), } } - output += fmt.Sprintf("\nExit code: %v", err) + + // Extract detailed exit information + var exitErr *exec.ExitError + if errors.As(err, &exitErr) { + exitCode := exitErr.ExitCode() + output += fmt.Sprintf("\n\n[Command exited with code %d]", exitCode) + + // Add signal information if killed by signal (Unix) + if exitCode == -1 { + output += " (killed by signal)" + } + } else { + output += fmt.Sprintf("\n\n[Command failed: %v]", err) + } } if output == "" { diff --git a/pkg/tools/shell_test.go b/pkg/tools/shell_test.go index c4553020f..f8f83ea74 100644 --- a/pkg/tools/shell_test.go +++ b/pkg/tools/shell_test.go @@ -489,6 +489,69 @@ func TestShellTool_SafePathsInWorkspaceRestriction(t *testing.T) { } } +// TestShellTool_ExitCodeDetails verifies that exit codes are captured with details +func TestShellTool_ExitCodeDetails(t *testing.T) { + tool, err := NewExecTool("", false) + if err != nil { + t.Fatalf("unable to configure exec tool: %s", err) + } + + ctx := context.Background() + args := map[string]any{ + "command": "sh -c 'exit 42'", + } + + result := tool.Execute(ctx, args) + + if !result.IsError { + t.Error("expected error for non-zero exit code") + } + + // Should contain the exit code in the message (new format: "exited with code 42") + if !strings.Contains(result.ForLLM, "42") { + t.Errorf("expected exit code 42 in error message, got: %s", result.ForLLM) + } + + // Verify the new detailed message format + if !strings.Contains(result.ForLLM, "exited with code") { + t.Errorf("expected 'exited with code' in message, got: %s", result.ForLLM) + } + + // Err field is set by the exec system (may or may not be set depending on implementation) + // The important thing is that IsError=true + t.Logf("Exit code result: %s", result.ForLLM) +} + +// TestShellTool_TimeoutWithPartialOutput verifies timeout includes partial output +func TestShellTool_TimeoutWithPartialOutput(t *testing.T) { + tool, err := NewExecTool("", false) + if err != nil { + t.Fatalf("unable to configure exec tool: %s", err) + } + + tool.SetTimeout(1 * time.Second) // Give more time for echo to complete + + ctx := context.Background() + // Use a command that outputs immediately then sleeps + args := map[string]any{ + "command": "echo 'partial output before timeout' && sleep 30", + } + + result := tool.Execute(ctx, args) + + if !result.IsError { + t.Error("expected error for timeout") + } + + // Should mention timeout + if !strings.Contains(result.ForLLM, "timed out") { + t.Errorf("expected 'timed out' in message, got: %s", result.ForLLM) + } + + // Log the result for debugging (partial output depends on shell behavior) + t.Logf("Timeout result: %s", result.ForLLM) +} + // TestShellTool_CustomAllowPatterns verifies that custom allow patterns exempt // commands from deny pattern checks. func TestShellTool_CustomAllowPatterns(t *testing.T) { diff --git a/web/frontend/src/components/channels/channel-forms/feishu-form.tsx b/web/frontend/src/components/channels/channel-forms/feishu-form.tsx index a834a65f9..386adf9a5 100644 --- a/web/frontend/src/components/channels/channel-forms/feishu-form.tsx +++ b/web/frontend/src/components/channels/channel-forms/feishu-form.tsx @@ -2,7 +2,7 @@ import { useTranslation } from "react-i18next" import type { ChannelConfig } from "@/api/channels" import { maskedSecretPlaceholder } from "@/components/secret-placeholder" -import { Field, KeyInput } from "@/components/shared-form" +import { Field, KeyInput, SwitchCardField } from "@/components/shared-form" import { Input } from "@/components/ui/input" interface FeishuFormProps { @@ -16,6 +16,10 @@ function asString(value: unknown): string { return typeof value === "string" ? value : "" } +function asBool(value: unknown): boolean { + return typeof value === "boolean" ? value : false +} + function asStringArray(value: unknown): string[] { if (!Array.isArray(value)) return [] return value.filter((item): item is string => typeof item === "string") @@ -98,6 +102,12 @@ export function FeishuForm({ )} /> + onChange("is_lark", checked)} + />