mirror of
https://github.com/sipeed/picoclaw.git
synced 2026-06-12 18:08:54 +00:00
Merge branch 'upstream-main' into feat/subturn-poc
This commit is contained in:
@@ -122,7 +122,8 @@
|
||||
"verification_token": "",
|
||||
"allow_from": [],
|
||||
"reasoning_channel_id": "",
|
||||
"random_reaction_emoji": []
|
||||
"random_reaction_emoji": [],
|
||||
"is_lark": false
|
||||
},
|
||||
"dingtalk": {
|
||||
"enabled": false,
|
||||
|
||||
@@ -13,25 +13,27 @@
|
||||
"app_secret": "xxx",
|
||||
"encrypt_key": "",
|
||||
"verification_token": "",
|
||||
"allow_from": []
|
||||
"allow_from": [],
|
||||
"is_lark": false
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
| 字段 | 类型 | 必填 | 描述 |
|
||||
| ------------------ | ------ | ---- | -------------------------------- |
|
||||
| enabled | bool | 是 | 是否启用飞书频道 |
|
||||
| app_id | string | 是 | 飞书应用的 App ID(以cli\_开头) |
|
||||
| app_secret | string | 是 | 飞书应用的 App Secret |
|
||||
| encrypt_key | string | 否 | 事件回调加密密钥 |
|
||||
| verification_token | string | 否 | 用于Webhook事件验证的Token |
|
||||
| allow_from | array | 否 | 用户ID白名单,空表示所有用户 |
|
||||
| random_reaction_emoji | array | 否 | 随机添加的表情列表,空则使用默认 "Pin" |
|
||||
| 字段 | 类型 | 必填 | 描述 |
|
||||
| --------------------- | ------ | ---- | ------------------------------------------------------------------------------------------------ |
|
||||
| enabled | bool | 是 | 是否启用飞书频道 |
|
||||
| app_id | string | 是 | 飞书应用的 App ID(以cli\_开头) |
|
||||
| app_secret | string | 是 | 飞书应用的 App Secret |
|
||||
| encrypt_key | string | 否 | 事件回调加密密钥 |
|
||||
| verification_token | string | 否 | 用于Webhook事件验证的Token |
|
||||
| allow_from | array | 否 | 用户ID白名单,空表示所有用户 |
|
||||
| random_reaction_emoji | array | 否 | 随机添加的表情列表,空则使用默认 "Pin" |
|
||||
| is_lark | bool | 否 | 是否使用 Lark 国际版域名(`open.larksuite.com`),默认为 `false`(使用飞书域名 `open.feishu.cn`) |
|
||||
|
||||
## 设置流程
|
||||
|
||||
1. 前往 [飞书开放平台](https://open.feishu.cn/)创建应用程序
|
||||
1. 前往 [飞书开放平台](https://open.feishu.cn/)(国际版用户请前往 [Lark 开放平台](https://open.larksuite.com/))创建应用程序
|
||||
2. 获取 App ID 和 App Secret
|
||||
3. 配置事件订阅和Webhook URL
|
||||
4. 设置加密(可选,生产环境建议启用)
|
||||
|
||||
@@ -307,6 +307,11 @@ func registerSharedTools(
|
||||
return spawnSubTurn(ctx, al, parentTS, cfg)
|
||||
})
|
||||
|
||||
// Clone the parent's tool registry so subagents can use all
|
||||
// tools registered so far (file, web, etc.) but NOT spawn/
|
||||
// spawn_status which are added below — preventing recursive
|
||||
// subagent spawning.
|
||||
subagentManager.SetTools(agent.Tools.Clone())
|
||||
if spawnEnabled {
|
||||
spawnTool := tools.NewSpawnTool(subagentManager)
|
||||
currentAgentID := agentID
|
||||
|
||||
@@ -54,11 +54,15 @@ func NewFeishuChannel(cfg config.FeishuConfig, bus *bus.MessageBus) (*FeishuChan
|
||||
)
|
||||
|
||||
tc := newTokenCache()
|
||||
opts := []lark.ClientOptionFunc{lark.WithTokenCache(tc)}
|
||||
if cfg.IsLark {
|
||||
opts = append(opts, lark.WithOpenBaseUrl(lark.LarkBaseUrl))
|
||||
}
|
||||
ch := &FeishuChannel{
|
||||
BaseChannel: base,
|
||||
config: cfg,
|
||||
tokenCache: tc,
|
||||
client: lark.NewClient(cfg.AppID, cfg.AppSecret, lark.WithTokenCache(tc)),
|
||||
client: lark.NewClient(cfg.AppID, cfg.AppSecret, opts...),
|
||||
}
|
||||
ch.SetOwner(ch)
|
||||
return ch, nil
|
||||
@@ -83,10 +87,15 @@ func (c *FeishuChannel) Start(ctx context.Context) error {
|
||||
|
||||
c.mu.Lock()
|
||||
c.cancel = cancel
|
||||
domain := lark.FeishuBaseUrl
|
||||
if c.config.IsLark {
|
||||
domain = lark.LarkBaseUrl
|
||||
}
|
||||
c.wsClient = larkws.NewClient(
|
||||
c.config.AppID,
|
||||
c.config.AppSecret,
|
||||
larkws.WithEventHandler(dispatcher),
|
||||
larkws.WithDomain(domain),
|
||||
)
|
||||
wsClient := c.wsClient
|
||||
c.mu.Unlock()
|
||||
|
||||
+103
-3
@@ -326,6 +326,7 @@ type FeishuConfig struct {
|
||||
Placeholder PlaceholderConfig `json:"placeholder,omitempty"`
|
||||
ReasoningChannelID string `json:"reasoning_channel_id" env:"PICOCLAW_CHANNELS_FEISHU_REASONING_CHANNEL_ID"`
|
||||
RandomReactionEmoji FlexibleStringSlice `json:"random_reaction_emoji" env:"PICOCLAW_CHANNELS_FEISHU_RANDOM_REACTION_EMOJI"`
|
||||
IsLark bool `json:"is_lark" env:"PICOCLAW_CHANNELS_FEISHU_IS_LARK"`
|
||||
}
|
||||
|
||||
type DiscordConfig struct {
|
||||
@@ -603,9 +604,11 @@ type ModelConfig struct {
|
||||
Model string `json:"model"` // Protocol/model-identifier (e.g., "openai/gpt-4o", "anthropic/claude-sonnet-4.6")
|
||||
|
||||
// HTTP-based providers
|
||||
APIBase string `json:"api_base,omitempty"` // API endpoint URL
|
||||
APIKey string `json:"api_key"` // API authentication key
|
||||
Proxy string `json:"proxy,omitempty"` // HTTP proxy URL
|
||||
APIBase string `json:"api_base,omitempty"` // API endpoint URL
|
||||
APIKey string `json:"api_key"` // API authentication key (single key)
|
||||
APIKeys []string `json:"api_keys,omitempty"` // API authentication keys (multiple keys for failover)
|
||||
Proxy string `json:"proxy,omitempty"` // HTTP proxy URL
|
||||
Fallbacks []string `json:"fallbacks,omitempty"` // Fallback model names for failover
|
||||
|
||||
// Special providers (CLI-based, OAuth, etc.)
|
||||
AuthMethod string `json:"auth_method,omitempty"` // Authentication method: oauth, token
|
||||
@@ -874,6 +877,9 @@ func LoadConfig(path string) (*Config, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Expand multi-key configs into separate entries for key-level failover
|
||||
cfg.ModelList = ExpandMultiKeyModels(cfg.ModelList)
|
||||
|
||||
// Migrate legacy channel config fields to new unified structures
|
||||
cfg.migrateChannelConfigs()
|
||||
|
||||
@@ -920,14 +926,25 @@ func encryptPlaintextAPIKeys(models []ModelConfig, passphrase string) ([]ModelCo
|
||||
|
||||
// resolveAPIKeys decrypts or dereferences each api_key in models in-place.
|
||||
// Supports plaintext (no-op), file:// (read from configDir), and enc:// (AES-GCM decrypt).
|
||||
// Also resolves api_keys array if present.
|
||||
func resolveAPIKeys(models []ModelConfig, configDir string) error {
|
||||
cr := credential.NewResolver(configDir)
|
||||
for i := range models {
|
||||
// Resolve single APIKey
|
||||
resolved, err := cr.Resolve(models[i].APIKey)
|
||||
if err != nil {
|
||||
return fmt.Errorf("model_list[%d] (%s): %w", i, models[i].ModelName, err)
|
||||
}
|
||||
models[i].APIKey = resolved
|
||||
|
||||
// Resolve APIKeys array
|
||||
for j, key := range models[i].APIKeys {
|
||||
resolved, err := cr.Resolve(key)
|
||||
if err != nil {
|
||||
return fmt.Errorf("model_list[%d] (%s): api_keys[%d]: %w", i, models[i].ModelName, j, err)
|
||||
}
|
||||
models[i].APIKeys[j] = resolved
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -1098,6 +1115,89 @@ func MergeAPIKeys(apiKey string, apiKeys []string) []string {
|
||||
return all
|
||||
}
|
||||
|
||||
// ExpandMultiKeyModels expands ModelConfig entries with multiple API keys into
|
||||
// separate entries for key-level failover. Each key gets its own ModelConfig entry,
|
||||
// and the original entry's fallbacks are set up to chain through the expanded entries.
|
||||
//
|
||||
// Example: {"model_name": "gpt-4", "api_keys": ["k1", "k2", "k3"]}
|
||||
// Becomes:
|
||||
// - {"model_name": "gpt-4", "api_key": "k1", "fallbacks": ["gpt-4__key_1", "gpt-4__key_2"]}
|
||||
// - {"model_name": "gpt-4__key_1", "api_key": "k2"}
|
||||
// - {"model_name": "gpt-4__key_2", "api_key": "k3"}
|
||||
func ExpandMultiKeyModels(models []ModelConfig) []ModelConfig {
|
||||
var expanded []ModelConfig
|
||||
|
||||
for _, m := range models {
|
||||
keys := MergeAPIKeys(m.APIKey, m.APIKeys)
|
||||
|
||||
// Single key or no keys: keep as-is
|
||||
if len(keys) <= 1 {
|
||||
// Ensure APIKey is set from APIKeys if needed
|
||||
if m.APIKey == "" && len(keys) == 1 {
|
||||
m.APIKey = keys[0]
|
||||
}
|
||||
m.APIKeys = nil // Clear APIKeys to avoid confusion
|
||||
expanded = append(expanded, m)
|
||||
continue
|
||||
}
|
||||
|
||||
// Multiple keys: expand
|
||||
originalName := m.ModelName
|
||||
|
||||
// Create entries for additional keys (key_1, key_2, ...)
|
||||
var fallbackNames []string
|
||||
for i := 1; i < len(keys); i++ {
|
||||
suffix := fmt.Sprintf("__key_%d", i)
|
||||
expandedName := originalName + suffix
|
||||
|
||||
// Create a copy for the additional key
|
||||
additionalEntry := ModelConfig{
|
||||
ModelName: expandedName,
|
||||
Model: m.Model,
|
||||
APIBase: m.APIBase,
|
||||
APIKey: keys[i],
|
||||
Proxy: m.Proxy,
|
||||
AuthMethod: m.AuthMethod,
|
||||
ConnectMode: m.ConnectMode,
|
||||
Workspace: m.Workspace,
|
||||
RPM: m.RPM,
|
||||
MaxTokensField: m.MaxTokensField,
|
||||
RequestTimeout: m.RequestTimeout,
|
||||
ThinkingLevel: m.ThinkingLevel,
|
||||
}
|
||||
expanded = append(expanded, additionalEntry)
|
||||
fallbackNames = append(fallbackNames, expandedName)
|
||||
}
|
||||
|
||||
// Create the primary entry with first key and fallbacks
|
||||
primaryEntry := ModelConfig{
|
||||
ModelName: originalName,
|
||||
Model: m.Model,
|
||||
APIBase: m.APIBase,
|
||||
APIKey: keys[0],
|
||||
Proxy: m.Proxy,
|
||||
AuthMethod: m.AuthMethod,
|
||||
ConnectMode: m.ConnectMode,
|
||||
Workspace: m.Workspace,
|
||||
RPM: m.RPM,
|
||||
MaxTokensField: m.MaxTokensField,
|
||||
RequestTimeout: m.RequestTimeout,
|
||||
ThinkingLevel: m.ThinkingLevel,
|
||||
}
|
||||
|
||||
// Prepend new fallbacks to existing ones
|
||||
if len(fallbackNames) > 0 {
|
||||
primaryEntry.Fallbacks = append(fallbackNames, m.Fallbacks...)
|
||||
} else if len(m.Fallbacks) > 0 {
|
||||
primaryEntry.Fallbacks = m.Fallbacks
|
||||
}
|
||||
|
||||
expanded = append(expanded, primaryEntry)
|
||||
}
|
||||
|
||||
return expanded
|
||||
}
|
||||
|
||||
func (t *ToolsConfig) IsToolEnabled(name string) bool {
|
||||
switch name {
|
||||
case "web":
|
||||
|
||||
@@ -0,0 +1,291 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestExpandMultiKeyModels_SingleKey(t *testing.T) {
|
||||
models := []ModelConfig{
|
||||
{
|
||||
ModelName: "gpt-4",
|
||||
Model: "openai/gpt-4o",
|
||||
APIKey: "single-key",
|
||||
},
|
||||
}
|
||||
|
||||
result := ExpandMultiKeyModels(models)
|
||||
|
||||
if len(result) != 1 {
|
||||
t.Fatalf("expected 1 model, got %d", len(result))
|
||||
}
|
||||
|
||||
if result[0].ModelName != "gpt-4" {
|
||||
t.Errorf("expected model_name 'gpt-4', got %q", result[0].ModelName)
|
||||
}
|
||||
|
||||
if result[0].APIKey != "single-key" {
|
||||
t.Errorf("expected api_key 'single-key', got %q", result[0].APIKey)
|
||||
}
|
||||
|
||||
if len(result[0].Fallbacks) != 0 {
|
||||
t.Errorf("expected no fallbacks, got %v", result[0].Fallbacks)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExpandMultiKeyModels_APIKeysOnly(t *testing.T) {
|
||||
models := []ModelConfig{
|
||||
{
|
||||
ModelName: "glm-4.7",
|
||||
Model: "zhipu/glm-4.7",
|
||||
APIBase: "https://api.example.com",
|
||||
APIKeys: []string{"key1", "key2", "key3"},
|
||||
},
|
||||
}
|
||||
|
||||
result := ExpandMultiKeyModels(models)
|
||||
|
||||
// Should expand to 3 models
|
||||
if len(result) != 3 {
|
||||
t.Fatalf("expected 3 models, got %d", len(result))
|
||||
}
|
||||
|
||||
// First entry should be the primary with key1 and fallbacks
|
||||
primary := result[2] // Primary is added last
|
||||
if primary.ModelName != "glm-4.7" {
|
||||
t.Errorf("expected primary model_name 'glm-4.7', got %q", primary.ModelName)
|
||||
}
|
||||
if primary.APIKey != "key1" {
|
||||
t.Errorf("expected primary api_key 'key1', got %q", primary.APIKey)
|
||||
}
|
||||
if len(primary.Fallbacks) != 2 {
|
||||
t.Errorf("expected 2 fallbacks, got %d", len(primary.Fallbacks))
|
||||
}
|
||||
if primary.Fallbacks[0] != "glm-4.7__key_1" {
|
||||
t.Errorf("expected first fallback 'glm-4.7__key_1', got %q", primary.Fallbacks[0])
|
||||
}
|
||||
if primary.Fallbacks[1] != "glm-4.7__key_2" {
|
||||
t.Errorf("expected second fallback 'glm-4.7__key_2', got %q", primary.Fallbacks[1])
|
||||
}
|
||||
|
||||
// Second entry should be key2
|
||||
second := result[0]
|
||||
if second.ModelName != "glm-4.7__key_1" {
|
||||
t.Errorf("expected second model_name 'glm-4.7__key_1', got %q", second.ModelName)
|
||||
}
|
||||
if second.APIKey != "key2" {
|
||||
t.Errorf("expected second api_key 'key2', got %q", second.APIKey)
|
||||
}
|
||||
|
||||
// Third entry should be key3
|
||||
third := result[1]
|
||||
if third.ModelName != "glm-4.7__key_2" {
|
||||
t.Errorf("expected third model_name 'glm-4.7__key_2', got %q", third.ModelName)
|
||||
}
|
||||
if third.APIKey != "key3" {
|
||||
t.Errorf("expected third api_key 'key3', got %q", third.APIKey)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExpandMultiKeyModels_APIKeyAndAPIKeys(t *testing.T) {
|
||||
models := []ModelConfig{
|
||||
{
|
||||
ModelName: "gpt-4",
|
||||
Model: "openai/gpt-4o",
|
||||
APIKey: "key0",
|
||||
APIKeys: []string{"key1", "key2"},
|
||||
},
|
||||
}
|
||||
|
||||
result := ExpandMultiKeyModels(models)
|
||||
|
||||
// Should expand to 3 models (key0 from APIKey + key1, key2 from APIKeys)
|
||||
if len(result) != 3 {
|
||||
t.Fatalf("expected 3 models, got %d", len(result))
|
||||
}
|
||||
|
||||
// Primary should use key0
|
||||
primary := result[2]
|
||||
if primary.APIKey != "key0" {
|
||||
t.Errorf("expected primary api_key 'key0', got %q", primary.APIKey)
|
||||
}
|
||||
if len(primary.Fallbacks) != 2 {
|
||||
t.Errorf("expected 2 fallbacks, got %d", len(primary.Fallbacks))
|
||||
}
|
||||
}
|
||||
|
||||
func TestExpandMultiKeyModels_WithExistingFallbacks(t *testing.T) {
|
||||
models := []ModelConfig{
|
||||
{
|
||||
ModelName: "gpt-4",
|
||||
Model: "openai/gpt-4o",
|
||||
APIKeys: []string{"key1", "key2"},
|
||||
Fallbacks: []string{"claude-3"},
|
||||
},
|
||||
}
|
||||
|
||||
result := ExpandMultiKeyModels(models)
|
||||
|
||||
primary := result[1]
|
||||
// With 2 keys, we get 1 key fallback + 1 existing fallback = 2 total
|
||||
if len(primary.Fallbacks) != 2 {
|
||||
t.Fatalf("expected 2 fallbacks, got %d: %v", len(primary.Fallbacks), primary.Fallbacks)
|
||||
}
|
||||
|
||||
// Key fallbacks should come first, then existing fallbacks
|
||||
if primary.Fallbacks[0] != "gpt-4__key_1" {
|
||||
t.Errorf("expected first fallback 'gpt-4__key_1', got %q", primary.Fallbacks[0])
|
||||
}
|
||||
if primary.Fallbacks[1] != "claude-3" {
|
||||
t.Errorf("expected second fallback 'claude-3', got %q", primary.Fallbacks[1])
|
||||
}
|
||||
}
|
||||
|
||||
func TestExpandMultiKeyModels_EmptyAPIKeys(t *testing.T) {
|
||||
models := []ModelConfig{
|
||||
{
|
||||
ModelName: "gpt-4",
|
||||
Model: "openai/gpt-4o",
|
||||
APIKey: "",
|
||||
APIKeys: []string{},
|
||||
},
|
||||
}
|
||||
|
||||
result := ExpandMultiKeyModels(models)
|
||||
|
||||
// Should keep as-is with no changes
|
||||
if len(result) != 1 {
|
||||
t.Fatalf("expected 1 model, got %d", len(result))
|
||||
}
|
||||
|
||||
if result[0].ModelName != "gpt-4" {
|
||||
t.Errorf("expected model_name 'gpt-4', got %q", result[0].ModelName)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExpandMultiKeyModels_Deduplication(t *testing.T) {
|
||||
models := []ModelConfig{
|
||||
{
|
||||
ModelName: "gpt-4",
|
||||
Model: "openai/gpt-4o",
|
||||
APIKey: "key1",
|
||||
APIKeys: []string{"key1", "key2", "key1"}, // Duplicate key1
|
||||
},
|
||||
}
|
||||
|
||||
result := ExpandMultiKeyModels(models)
|
||||
|
||||
// Should only create 2 models (deduplicated keys)
|
||||
if len(result) != 2 {
|
||||
t.Fatalf("expected 2 models (deduplicated), got %d", len(result))
|
||||
}
|
||||
|
||||
primary := result[1]
|
||||
if primary.APIKey != "key1" {
|
||||
t.Errorf("expected primary api_key 'key1', got %q", primary.APIKey)
|
||||
}
|
||||
if len(primary.Fallbacks) != 1 {
|
||||
t.Errorf("expected 1 fallback, got %d", len(primary.Fallbacks))
|
||||
}
|
||||
}
|
||||
|
||||
func TestExpandMultiKeyModels_PreservesOtherFields(t *testing.T) {
|
||||
models := []ModelConfig{
|
||||
{
|
||||
ModelName: "gpt-4",
|
||||
Model: "openai/gpt-4o",
|
||||
APIBase: "https://api.example.com",
|
||||
APIKeys: []string{"key1", "key2"},
|
||||
Proxy: "http://proxy:8080",
|
||||
RPM: 60,
|
||||
MaxTokensField: "max_completion_tokens",
|
||||
RequestTimeout: 30,
|
||||
ThinkingLevel: "high",
|
||||
},
|
||||
}
|
||||
|
||||
result := ExpandMultiKeyModels(models)
|
||||
|
||||
// Check primary entry preserves all fields
|
||||
primary := result[1]
|
||||
if primary.APIBase != "https://api.example.com" {
|
||||
t.Errorf("expected api_base preserved, got %q", primary.APIBase)
|
||||
}
|
||||
if primary.Proxy != "http://proxy:8080" {
|
||||
t.Errorf("expected proxy preserved, got %q", primary.Proxy)
|
||||
}
|
||||
if primary.RPM != 60 {
|
||||
t.Errorf("expected rpm preserved, got %d", primary.RPM)
|
||||
}
|
||||
if primary.MaxTokensField != "max_completion_tokens" {
|
||||
t.Errorf("expected max_tokens_field preserved, got %q", primary.MaxTokensField)
|
||||
}
|
||||
if primary.RequestTimeout != 30 {
|
||||
t.Errorf("expected request_timeout preserved, got %d", primary.RequestTimeout)
|
||||
}
|
||||
if primary.ThinkingLevel != "high" {
|
||||
t.Errorf("expected thinking_level preserved, got %q", primary.ThinkingLevel)
|
||||
}
|
||||
|
||||
// Check additional entry also preserves fields
|
||||
additional := result[0]
|
||||
if additional.APIBase != "https://api.example.com" {
|
||||
t.Errorf("expected additional api_base preserved, got %q", additional.APIBase)
|
||||
}
|
||||
if additional.RPM != 60 {
|
||||
t.Errorf("expected additional rpm preserved, got %d", additional.RPM)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMergeAPIKeys(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
apiKey string
|
||||
apiKeys []string
|
||||
expected []string
|
||||
}{
|
||||
{
|
||||
name: "both empty",
|
||||
apiKey: "",
|
||||
apiKeys: nil,
|
||||
expected: nil,
|
||||
},
|
||||
{
|
||||
name: "only apiKey",
|
||||
apiKey: "key1",
|
||||
apiKeys: nil,
|
||||
expected: []string{"key1"},
|
||||
},
|
||||
{
|
||||
name: "only apiKeys",
|
||||
apiKey: "",
|
||||
apiKeys: []string{"key1", "key2"},
|
||||
expected: []string{"key1", "key2"},
|
||||
},
|
||||
{
|
||||
name: "both with overlap",
|
||||
apiKey: "key1",
|
||||
apiKeys: []string{"key1", "key2", "key3"},
|
||||
expected: []string{"key1", "key2", "key3"},
|
||||
},
|
||||
{
|
||||
name: "with whitespace",
|
||||
apiKey: " key1 ",
|
||||
apiKeys: []string{" key2 ", " key1 "},
|
||||
expected: []string{"key1", "key2"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := MergeAPIKeys(tt.apiKey, tt.apiKeys)
|
||||
if len(result) != len(tt.expected) {
|
||||
t.Fatalf("expected %d keys, got %d", len(tt.expected), len(result))
|
||||
}
|
||||
for i, k := range result {
|
||||
if k != tt.expected[i] {
|
||||
t.Errorf("expected key[%d] = %q, got %q", i, tt.expected[i], k)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -117,17 +117,19 @@ func (fc *FallbackChain) Execute(
|
||||
return nil, context.Canceled
|
||||
}
|
||||
|
||||
// Check cooldown.
|
||||
if !fc.cooldown.IsAvailable(candidate.Provider) {
|
||||
remaining := fc.cooldown.CooldownRemaining(candidate.Provider)
|
||||
// Check cooldown (per provider/model, not just provider).
|
||||
// This allows multi-key failover where different keys use different model names.
|
||||
cooldownKey := ModelKey(candidate.Provider, candidate.Model)
|
||||
if !fc.cooldown.IsAvailable(cooldownKey) {
|
||||
remaining := fc.cooldown.CooldownRemaining(cooldownKey)
|
||||
result.Attempts = append(result.Attempts, FallbackAttempt{
|
||||
Provider: candidate.Provider,
|
||||
Model: candidate.Model,
|
||||
Skipped: true,
|
||||
Reason: FailoverRateLimit,
|
||||
Error: fmt.Errorf(
|
||||
"provider %s in cooldown (%s remaining)",
|
||||
candidate.Provider,
|
||||
"%s in cooldown (%s remaining)",
|
||||
cooldownKey,
|
||||
remaining.Round(time.Second),
|
||||
),
|
||||
})
|
||||
@@ -141,7 +143,7 @@ func (fc *FallbackChain) Execute(
|
||||
|
||||
if err == nil {
|
||||
// Success.
|
||||
fc.cooldown.MarkSuccess(candidate.Provider)
|
||||
fc.cooldown.MarkSuccess(cooldownKey)
|
||||
result.Response = resp
|
||||
result.Provider = candidate.Provider
|
||||
result.Model = candidate.Model
|
||||
@@ -187,7 +189,7 @@ func (fc *FallbackChain) Execute(
|
||||
}
|
||||
|
||||
// Retriable error: mark failure and continue to next candidate.
|
||||
fc.cooldown.MarkFailure(candidate.Provider, failErr.Reason)
|
||||
fc.cooldown.MarkFailure(cooldownKey, failErr.Reason)
|
||||
result.Attempts = append(result.Attempts, FallbackAttempt{
|
||||
Provider: candidate.Provider,
|
||||
Model: candidate.Model,
|
||||
|
||||
@@ -0,0 +1,384 @@
|
||||
package providers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// TestMultiKeyFailover tests the complete failover flow with multiple API keys.
|
||||
// This simulates the config expansion scenario where api_keys: ["key1", "key2", "key3"]
|
||||
// is expanded into primary + fallbacks.
|
||||
func TestMultiKeyFailover(t *testing.T) {
|
||||
// Simulate expanded config: primary with 2 fallbacks
|
||||
// This is what ExpandMultiKeyModels would produce for api_keys: ["key1", "key2", "key3"]
|
||||
cfg := ModelConfig{
|
||||
Primary: "glm-4.7",
|
||||
Fallbacks: []string{"glm-4.7__key_1", "glm-4.7__key_2"},
|
||||
}
|
||||
|
||||
candidates := ResolveCandidates(cfg, "zhipu")
|
||||
|
||||
if len(candidates) != 3 {
|
||||
t.Fatalf("expected 3 candidates, got %d: %v", len(candidates), candidates)
|
||||
}
|
||||
|
||||
// Create fallback chain
|
||||
cooldown := NewCooldownTracker()
|
||||
chain := NewFallbackChain(cooldown)
|
||||
|
||||
// Mock run function: first call fails with 429, second succeeds
|
||||
callCount := 0
|
||||
mockRun := func(ctx context.Context, provider, model string) (*LLMResponse, error) {
|
||||
callCount++
|
||||
if callCount == 1 {
|
||||
// First call: simulate rate limit
|
||||
return nil, errors.New("http error: status 429 - rate limit exceeded")
|
||||
}
|
||||
// Second call: success
|
||||
return &LLMResponse{
|
||||
Content: "Hello from key2!",
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Execute fallback chain
|
||||
result, err := chain.Execute(context.Background(), candidates, mockRun)
|
||||
if err != nil {
|
||||
t.Fatalf("expected success after failover, got error: %v", err)
|
||||
}
|
||||
|
||||
if result == nil {
|
||||
t.Fatal("expected result, got nil")
|
||||
}
|
||||
|
||||
if result.Response.Content != "Hello from key2!" {
|
||||
t.Errorf("expected response from key2, got: %s", result.Response.Content)
|
||||
}
|
||||
|
||||
if callCount != 2 {
|
||||
t.Errorf("expected 2 calls (1 fail + 1 success), got %d", callCount)
|
||||
}
|
||||
|
||||
// Verify first attempt was recorded
|
||||
if len(result.Attempts) != 1 {
|
||||
t.Errorf("expected 1 failed attempt recorded, got %d", len(result.Attempts))
|
||||
}
|
||||
|
||||
if result.Attempts[0].Reason != FailoverRateLimit {
|
||||
t.Errorf(
|
||||
"expected first attempt reason to be rate_limit, got: %s",
|
||||
result.Attempts[0].Reason,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
// TestMultiKeyFailoverAllFail tests when all keys hit rate limit
|
||||
func TestMultiKeyFailoverAllFail(t *testing.T) {
|
||||
cfg := ModelConfig{
|
||||
Primary: "glm-4.7",
|
||||
Fallbacks: []string{"glm-4.7__key_1", "glm-4.7__key_2"},
|
||||
}
|
||||
|
||||
candidates := ResolveCandidates(cfg, "zhipu")
|
||||
|
||||
cooldown := NewCooldownTracker()
|
||||
chain := NewFallbackChain(cooldown)
|
||||
|
||||
// Mock run function: all calls fail with rate limit
|
||||
callCount := 0
|
||||
mockRun := func(ctx context.Context, provider, model string) (*LLMResponse, error) {
|
||||
callCount++
|
||||
return nil, errors.New("status: 429 - too many requests")
|
||||
}
|
||||
|
||||
// Execute fallback chain
|
||||
result, err := chain.Execute(context.Background(), candidates, mockRun)
|
||||
|
||||
if err == nil {
|
||||
t.Fatal("expected error when all keys fail, got nil")
|
||||
}
|
||||
|
||||
if result != nil {
|
||||
t.Errorf("expected nil result on failure, got: %v", result)
|
||||
}
|
||||
|
||||
if callCount != 3 {
|
||||
t.Errorf("expected 3 calls (all fail), got %d", callCount)
|
||||
}
|
||||
|
||||
// Verify error type
|
||||
var exhausted *FallbackExhaustedError
|
||||
if !errors.As(err, &exhausted) {
|
||||
t.Errorf("expected FallbackExhaustedError, got: %T - %v", err, err)
|
||||
}
|
||||
|
||||
if len(exhausted.Attempts) != 3 {
|
||||
t.Errorf("expected 3 attempts in exhausted error, got %d", len(exhausted.Attempts))
|
||||
}
|
||||
}
|
||||
|
||||
// TestMultiKeyFailoverCooldown tests that a key in cooldown is skipped
|
||||
func TestMultiKeyFailoverCooldown(t *testing.T) {
|
||||
cfg := ModelConfig{
|
||||
Primary: "glm-4.7",
|
||||
Fallbacks: []string{"glm-4.7__key_1"},
|
||||
}
|
||||
|
||||
candidates := ResolveCandidates(cfg, "zhipu")
|
||||
|
||||
cooldown := NewCooldownTracker()
|
||||
chain := NewFallbackChain(cooldown)
|
||||
|
||||
// Put the first model in cooldown (using ModelKey now, not just provider)
|
||||
cooldownKey := ModelKey(candidates[0].Provider, candidates[0].Model)
|
||||
cooldown.MarkFailure(cooldownKey, FailoverRateLimit)
|
||||
|
||||
// Verify it's not available
|
||||
if cooldown.IsAvailable(cooldownKey) {
|
||||
t.Fatal("expected first model to be in cooldown")
|
||||
}
|
||||
|
||||
// Mock run function: only second should be called
|
||||
callCount := 0
|
||||
calledProviders := []string{}
|
||||
mockRun := func(ctx context.Context, provider, model string) (*LLMResponse, error) {
|
||||
callCount++
|
||||
calledProviders = append(calledProviders, provider+"/"+model)
|
||||
return &LLMResponse{Content: "success"}, nil
|
||||
}
|
||||
|
||||
result, err := chain.Execute(context.Background(), candidates, mockRun)
|
||||
if err != nil {
|
||||
t.Fatalf("expected success, got error: %v", err)
|
||||
}
|
||||
|
||||
// First provider should have been skipped
|
||||
if callCount != 1 {
|
||||
t.Errorf("expected 1 call (first skipped due to cooldown), got %d", callCount)
|
||||
}
|
||||
|
||||
// Should have called the second provider/model
|
||||
if len(calledProviders) != 1 ||
|
||||
calledProviders[0] != candidates[1].Provider+"/"+candidates[1].Model {
|
||||
t.Errorf("expected second model to be called, got: %v", calledProviders)
|
||||
}
|
||||
|
||||
// Verify first attempt was recorded as skipped
|
||||
if len(result.Attempts) != 1 {
|
||||
t.Fatalf("expected 1 attempt (skipped), got %d", len(result.Attempts))
|
||||
}
|
||||
|
||||
if !result.Attempts[0].Skipped {
|
||||
t.Error("expected first attempt to be marked as skipped")
|
||||
}
|
||||
}
|
||||
|
||||
// TestMultiKeyFailoverWithFormatError tests that format errors are non-retriable
|
||||
func TestMultiKeyFailoverWithFormatError(t *testing.T) {
|
||||
cfg := ModelConfig{
|
||||
Primary: "glm-4.7",
|
||||
Fallbacks: []string{"glm-4.7__key_1"},
|
||||
}
|
||||
|
||||
candidates := ResolveCandidates(cfg, "zhipu")
|
||||
|
||||
cooldown := NewCooldownTracker()
|
||||
chain := NewFallbackChain(cooldown)
|
||||
|
||||
// Mock run function: first call fails with format error (bad request)
|
||||
callCount := 0
|
||||
mockRun := func(ctx context.Context, provider, model string) (*LLMResponse, error) {
|
||||
callCount++
|
||||
return nil, errors.New("invalid request format: tool_use.id missing")
|
||||
}
|
||||
|
||||
// Execute fallback chain
|
||||
result, err := chain.Execute(context.Background(), candidates, mockRun)
|
||||
|
||||
if err == nil {
|
||||
t.Fatal("expected error for format failure, got nil")
|
||||
}
|
||||
|
||||
// Format errors should NOT trigger failover (non-retriable)
|
||||
// So we should only have 1 call
|
||||
if callCount != 1 {
|
||||
t.Errorf("expected 1 call (format error is non-retriable), got %d", callCount)
|
||||
}
|
||||
|
||||
// Verify the error is a FailoverError with format reason
|
||||
var failoverErr *FailoverError
|
||||
if !errors.As(err, &failoverErr) {
|
||||
t.Errorf("expected FailoverError, got: %T - %v", err, err)
|
||||
}
|
||||
|
||||
if failoverErr.Reason != FailoverFormat {
|
||||
t.Errorf("expected FailoverFormat reason, got: %s", failoverErr.Reason)
|
||||
}
|
||||
|
||||
_ = result // result should be nil
|
||||
}
|
||||
|
||||
// TestMultiKeyWithModelFallback tests multi-key failover combined with model fallback.
|
||||
// This simulates the scenario: api_keys: ["k1", "k2"] + fallbacks: ["minimax"]
|
||||
// Expected failover order: glm-4.7 (k1) → glm-4.7__key_1 (k2) → minimax
|
||||
func TestMultiKeyWithModelFallback(t *testing.T) {
|
||||
// Simulate expanded config from:
|
||||
// { "model_name": "glm-4.7", "api_keys": ["k1", "k2"], "fallbacks": ["minimax"] }
|
||||
// After ExpandMultiKeyModels, primaryEntry.Fallbacks = ["glm-4.7__key_1", "minimax"]
|
||||
// Note: In production, "minimax" would be resolved via model lookup to "minimax/minimax"
|
||||
// In this test, we use the full format to avoid needing a lookup function.
|
||||
cfg := ModelConfig{
|
||||
Primary: "glm-4.7",
|
||||
Fallbacks: []string{"glm-4.7__key_1", "minimax/minimax"},
|
||||
}
|
||||
|
||||
candidates := ResolveCandidates(cfg, "zhipu")
|
||||
|
||||
// Should have 3 candidates: glm-4.7 (zhipu), glm-4.7__key_1 (zhipu), minimax (minimax)
|
||||
if len(candidates) != 3 {
|
||||
t.Fatalf("expected 3 candidates, got %d: %v", len(candidates), candidates)
|
||||
}
|
||||
|
||||
// Verify candidate order
|
||||
if candidates[0].Model != "glm-4.7" || candidates[0].Provider != "zhipu" {
|
||||
t.Errorf(
|
||||
"expected first candidate to be zhipu/glm-4.7, got: %s/%s",
|
||||
candidates[0].Provider,
|
||||
candidates[0].Model,
|
||||
)
|
||||
}
|
||||
if candidates[1].Model != "glm-4.7__key_1" || candidates[1].Provider != "zhipu" {
|
||||
t.Errorf(
|
||||
"expected second candidate to be zhipu/glm-4.7__key_1, got: %s/%s",
|
||||
candidates[1].Provider,
|
||||
candidates[1].Model,
|
||||
)
|
||||
}
|
||||
if candidates[2].Model != "minimax" || candidates[2].Provider != "minimax" {
|
||||
t.Errorf(
|
||||
"expected third candidate to be minimax/minimax, got: %s/%s",
|
||||
candidates[2].Provider,
|
||||
candidates[2].Model,
|
||||
)
|
||||
}
|
||||
|
||||
cooldown := NewCooldownTracker()
|
||||
chain := NewFallbackChain(cooldown)
|
||||
|
||||
// Mock run function: first two fail, third succeeds (model fallback)
|
||||
callCount := 0
|
||||
calledModels := []string{}
|
||||
mockRun := func(ctx context.Context, provider, model string) (*LLMResponse, error) {
|
||||
callCount++
|
||||
calledModels = append(calledModels, provider+"/"+model)
|
||||
|
||||
switch callCount {
|
||||
case 1:
|
||||
// k1: rate limit
|
||||
return nil, errors.New("status: 429 - rate limit")
|
||||
case 2:
|
||||
// k2: also rate limit (all zhipu keys exhausted)
|
||||
return nil, errors.New("status: 429 - rate limit")
|
||||
case 3:
|
||||
// minimax: success
|
||||
return &LLMResponse{Content: "success from minimax"}, nil
|
||||
default:
|
||||
return nil, errors.New("unexpected call")
|
||||
}
|
||||
}
|
||||
|
||||
result, err := chain.Execute(context.Background(), candidates, mockRun)
|
||||
if err != nil {
|
||||
t.Fatalf("expected success after failover to model fallback, got error: %v", err)
|
||||
}
|
||||
|
||||
if callCount != 3 {
|
||||
t.Errorf("expected 3 calls (k1 fail + k2 fail + minimax success), got %d", callCount)
|
||||
}
|
||||
|
||||
if result.Response.Content != "success from minimax" {
|
||||
t.Errorf("expected response from minimax, got: %s", result.Response.Content)
|
||||
}
|
||||
|
||||
// Verify call order
|
||||
if len(calledModels) != 3 {
|
||||
t.Fatalf("expected 3 called models, got %d", len(calledModels))
|
||||
}
|
||||
if calledModels[0] != "zhipu/glm-4.7" {
|
||||
t.Errorf("expected first call to zhipu/glm-4.7, got: %s", calledModels[0])
|
||||
}
|
||||
if calledModels[1] != "zhipu/glm-4.7__key_1" {
|
||||
t.Errorf("expected second call to zhipu/glm-4.7__key_1, got: %s", calledModels[1])
|
||||
}
|
||||
if calledModels[2] != "minimax/minimax" {
|
||||
t.Errorf("expected third call to minimax/minimax, got: %s", calledModels[2])
|
||||
}
|
||||
|
||||
// Verify 2 failed attempts recorded
|
||||
if len(result.Attempts) != 2 {
|
||||
t.Errorf("expected 2 failed attempts, got %d", len(result.Attempts))
|
||||
}
|
||||
|
||||
// Both should be rate limit
|
||||
for i, attempt := range result.Attempts {
|
||||
if attempt.Reason != FailoverRateLimit {
|
||||
t.Errorf("expected attempt %d to be rate_limit, got: %s", i, attempt.Reason)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestMultiKeyFailoverMixedErrors tests failover with different error types
|
||||
func TestMultiKeyFailoverMixedErrors(t *testing.T) {
|
||||
cfg := ModelConfig{
|
||||
Primary: "glm-4.7",
|
||||
Fallbacks: []string{"glm-4.7__key_1", "glm-4.7__key_2"},
|
||||
}
|
||||
|
||||
candidates := ResolveCandidates(cfg, "zhipu")
|
||||
|
||||
cooldown := NewCooldownTracker()
|
||||
chain := NewFallbackChain(cooldown)
|
||||
|
||||
// Mock run function: different errors for each key
|
||||
callCount := 0
|
||||
mockRun := func(ctx context.Context, provider, model string) (*LLMResponse, error) {
|
||||
callCount++
|
||||
switch callCount {
|
||||
case 1:
|
||||
// First: rate limit (retriable)
|
||||
return nil, errors.New("status: 429 - rate limit")
|
||||
case 2:
|
||||
// Second: timeout (retriable)
|
||||
return nil, errors.New("context deadline exceeded")
|
||||
case 3:
|
||||
// Third: success
|
||||
return &LLMResponse{Content: "success from key3"}, nil
|
||||
default:
|
||||
return nil, errors.New("unexpected call")
|
||||
}
|
||||
}
|
||||
|
||||
result, err := chain.Execute(context.Background(), candidates, mockRun)
|
||||
if err != nil {
|
||||
t.Fatalf("expected success after 2 failovers, got error: %v", err)
|
||||
}
|
||||
|
||||
if callCount != 3 {
|
||||
t.Errorf("expected 3 calls, got %d", callCount)
|
||||
}
|
||||
|
||||
// Verify both failed attempts were recorded
|
||||
if len(result.Attempts) != 2 {
|
||||
t.Errorf("expected 2 failed attempts, got %d", len(result.Attempts))
|
||||
}
|
||||
|
||||
// First should be rate limit
|
||||
if result.Attempts[0].Reason != FailoverRateLimit {
|
||||
t.Errorf("expected first attempt to be rate_limit, got: %s", result.Attempts[0].Reason)
|
||||
}
|
||||
|
||||
// Second should be timeout
|
||||
if result.Attempts[1].Reason != FailoverTimeout {
|
||||
t.Errorf("expected second attempt to be timeout, got: %s", result.Attempts[1].Reason)
|
||||
}
|
||||
}
|
||||
@@ -157,8 +157,8 @@ func TestFallback_CooldownSkip(t *testing.T) {
|
||||
ct, _ := newTestTracker(now)
|
||||
fc := NewFallbackChain(ct)
|
||||
|
||||
// Put openai in cooldown
|
||||
ct.MarkFailure("openai", FailoverRateLimit)
|
||||
// Put openai/gpt-4 in cooldown (using ModelKey now)
|
||||
ct.MarkFailure(ModelKey("openai", "gpt-4"), FailoverRateLimit)
|
||||
|
||||
candidates := []FallbackCandidate{
|
||||
makeCandidate("openai", "gpt-4"),
|
||||
@@ -195,9 +195,9 @@ func TestFallback_AllInCooldown(t *testing.T) {
|
||||
ct := NewCooldownTracker()
|
||||
fc := NewFallbackChain(ct)
|
||||
|
||||
// Put all providers in cooldown
|
||||
ct.MarkFailure("openai", FailoverRateLimit)
|
||||
ct.MarkFailure("anthropic", FailoverBilling)
|
||||
// Put all models in cooldown (using ModelKey now)
|
||||
ct.MarkFailure(ModelKey("openai", "gpt-4"), FailoverRateLimit)
|
||||
ct.MarkFailure(ModelKey("anthropic", "claude"), FailoverBilling)
|
||||
|
||||
candidates := []FallbackCandidate{
|
||||
makeCandidate("openai", "gpt-4"),
|
||||
@@ -273,12 +273,13 @@ func TestFallback_SuccessResetsCooldown(t *testing.T) {
|
||||
fc := NewFallbackChain(ct)
|
||||
|
||||
candidates := []FallbackCandidate{makeCandidate("openai", "gpt-4")}
|
||||
modelKey := ModelKey("openai", "gpt-4")
|
||||
|
||||
attempt := 0
|
||||
run := func(ctx context.Context, provider, model string) (*LLMResponse, error) {
|
||||
attempt++
|
||||
if attempt == 1 {
|
||||
ct.MarkFailure("openai", FailoverRateLimit) // simulate failure tracked elsewhere
|
||||
ct.MarkFailure(modelKey, FailoverRateLimit) // simulate failure tracked elsewhere
|
||||
}
|
||||
return &LLMResponse{Content: "ok", FinishReason: "stop"}, nil
|
||||
}
|
||||
@@ -287,7 +288,7 @@ func TestFallback_SuccessResetsCooldown(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if !ct.IsAvailable("openai") {
|
||||
if !ct.IsAvailable(modelKey) {
|
||||
t.Error("success should reset cooldown")
|
||||
}
|
||||
}
|
||||
|
||||
+63
-8
@@ -188,15 +188,48 @@ func (r *ToolRegistry) ExecuteWithContext(
|
||||
// The callback is a call parameter, not mutable state on the tool instance.
|
||||
var result *ToolResult
|
||||
start := time.Now()
|
||||
if asyncExec, ok := tool.(AsyncExecutor); ok && asyncCallback != nil {
|
||||
logger.DebugCF("tool", "Executing async tool via ExecuteAsync",
|
||||
map[string]any{
|
||||
"tool": name,
|
||||
})
|
||||
result = asyncExec.ExecuteAsync(ctx, args, asyncCallback)
|
||||
} else {
|
||||
result = tool.Execute(ctx, args)
|
||||
|
||||
// Use recover to catch any panics during tool execution
|
||||
// This prevents tool crashes from killing the entire agent
|
||||
func() {
|
||||
defer func() {
|
||||
if re := recover(); re != nil {
|
||||
errMsg := fmt.Sprintf("Tool '%s' crashed with panic: %v", name, re)
|
||||
logger.ErrorCF("tool", "Tool execution panic recovered",
|
||||
map[string]any{
|
||||
"tool": name,
|
||||
"panic": fmt.Sprintf("%v", re),
|
||||
})
|
||||
result = &ToolResult{
|
||||
ForLLM: errMsg,
|
||||
ForUser: errMsg,
|
||||
IsError: true,
|
||||
Err: fmt.Errorf("panic: %v", re),
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
if asyncExec, ok := tool.(AsyncExecutor); ok && asyncCallback != nil {
|
||||
logger.DebugCF("tool", "Executing async tool via ExecuteAsync",
|
||||
map[string]any{
|
||||
"tool": name,
|
||||
})
|
||||
result = asyncExec.ExecuteAsync(ctx, args, asyncCallback)
|
||||
} else {
|
||||
result = tool.Execute(ctx, args)
|
||||
}
|
||||
}()
|
||||
|
||||
// Handle nil result (should not happen, but defensive)
|
||||
if result == nil {
|
||||
result = &ToolResult{
|
||||
ForLLM: fmt.Sprintf("Tool '%s' returned nil result unexpectedly", name),
|
||||
ForUser: fmt.Sprintf("Tool '%s' returned nil result unexpectedly", name),
|
||||
IsError: true,
|
||||
Err: fmt.Errorf("nil result from tool"),
|
||||
}
|
||||
}
|
||||
|
||||
duration := time.Since(start)
|
||||
|
||||
// Log based on result type
|
||||
@@ -303,6 +336,28 @@ func (r *ToolRegistry) List() []string {
|
||||
return r.sortedToolNames()
|
||||
}
|
||||
|
||||
// Clone creates an independent copy of the registry containing the same tool
|
||||
// entries (shallow copy of each ToolEntry). This is used to give subagents a
|
||||
// snapshot of the parent agent's tools without sharing the same registry —
|
||||
// tools registered on the parent after cloning (e.g. spawn, spawn_status)
|
||||
// will NOT be visible to the clone, preventing recursive subagent spawning.
|
||||
// The version counter is reset to 0 in the clone as it's a new independent registry.
|
||||
func (r *ToolRegistry) Clone() *ToolRegistry {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
clone := &ToolRegistry{
|
||||
tools: make(map[string]*ToolEntry, len(r.tools)),
|
||||
}
|
||||
for name, entry := range r.tools {
|
||||
clone.tools[name] = &ToolEntry{
|
||||
Tool: entry.Tool,
|
||||
IsCore: entry.IsCore,
|
||||
TTL: entry.TTL,
|
||||
}
|
||||
}
|
||||
return clone
|
||||
}
|
||||
|
||||
// Count returns the number of registered tools.
|
||||
func (r *ToolRegistry) Count() int {
|
||||
r.mu.RLock()
|
||||
|
||||
@@ -2,6 +2,7 @@ package tools
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
@@ -335,6 +336,96 @@ func TestToolToSchema(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestToolRegistry_Clone(t *testing.T) {
|
||||
r := NewToolRegistry()
|
||||
r.Register(newMockTool("read_file", "reads files"))
|
||||
r.Register(newMockTool("exec", "runs commands"))
|
||||
r.Register(newMockTool("web_search", "searches the web"))
|
||||
|
||||
clone := r.Clone()
|
||||
|
||||
// Clone should have the same tools
|
||||
if clone.Count() != 3 {
|
||||
t.Errorf("expected clone to have 3 tools, got %d", clone.Count())
|
||||
}
|
||||
for _, name := range []string{"read_file", "exec", "web_search"} {
|
||||
if _, ok := clone.Get(name); !ok {
|
||||
t.Errorf("expected clone to have tool %q", name)
|
||||
}
|
||||
}
|
||||
|
||||
// Registering on parent should NOT affect clone
|
||||
r.Register(newMockTool("spawn", "spawns subagent"))
|
||||
if r.Count() != 4 {
|
||||
t.Errorf("expected parent to have 4 tools, got %d", r.Count())
|
||||
}
|
||||
if clone.Count() != 3 {
|
||||
t.Errorf("expected clone to still have 3 tools after parent mutation, got %d", clone.Count())
|
||||
}
|
||||
if _, ok := clone.Get("spawn"); ok {
|
||||
t.Error("expected clone NOT to have 'spawn' tool registered on parent after cloning")
|
||||
}
|
||||
|
||||
// Registering on clone should NOT affect parent
|
||||
clone.Register(newMockTool("custom", "custom tool"))
|
||||
if clone.Count() != 4 {
|
||||
t.Errorf("expected clone to have 4 tools, got %d", clone.Count())
|
||||
}
|
||||
if _, ok := r.Get("custom"); ok {
|
||||
t.Error("expected parent NOT to have 'custom' tool registered on clone")
|
||||
}
|
||||
}
|
||||
|
||||
func TestToolRegistry_Clone_Empty(t *testing.T) {
|
||||
r := NewToolRegistry()
|
||||
clone := r.Clone()
|
||||
if clone.Count() != 0 {
|
||||
t.Errorf("expected empty clone, got count %d", clone.Count())
|
||||
}
|
||||
}
|
||||
|
||||
func TestToolRegistry_Clone_PreservesHiddenToolState(t *testing.T) {
|
||||
r := NewToolRegistry()
|
||||
r.RegisterHidden(newMockTool("mcp_tool", "dynamic MCP tool"))
|
||||
|
||||
clone := r.Clone()
|
||||
|
||||
// Hidden tools with TTL=0 should not be gettable (same behavior as parent)
|
||||
if _, ok := clone.Get("mcp_tool"); ok {
|
||||
t.Error("expected hidden tool with TTL=0 to be invisible in clone")
|
||||
}
|
||||
|
||||
// But the entry should exist (count includes hidden tools)
|
||||
if clone.Count() != 1 {
|
||||
t.Errorf("expected clone count 1 (hidden entry exists), got %d", clone.Count())
|
||||
}
|
||||
}
|
||||
|
||||
func TestToolRegistry_Clone_PreservesTTLValue(t *testing.T) {
|
||||
r := NewToolRegistry()
|
||||
r.RegisterHidden(newMockTool("ttl_tool", "tool with TTL"))
|
||||
|
||||
// Manually set a non-zero TTL on the entry
|
||||
r.mu.RLock()
|
||||
if entry, ok := r.tools["ttl_tool"]; ok {
|
||||
entry.TTL = 5
|
||||
}
|
||||
r.mu.RUnlock()
|
||||
|
||||
clone := r.Clone()
|
||||
|
||||
// Verify TTL value is preserved in the clone
|
||||
clone.mu.RLock()
|
||||
defer clone.mu.RUnlock()
|
||||
entry, ok := clone.tools["ttl_tool"]
|
||||
if !ok {
|
||||
t.Fatal("expected ttl_tool to exist in clone")
|
||||
}
|
||||
if entry.TTL != 5 {
|
||||
t.Errorf("expected TTL=5 in clone, got %d", entry.TTL)
|
||||
}
|
||||
}
|
||||
|
||||
func TestToolRegistry_ConcurrentAccess(t *testing.T) {
|
||||
r := NewToolRegistry()
|
||||
var wg sync.WaitGroup
|
||||
@@ -358,3 +449,175 @@ func TestToolRegistry_ConcurrentAccess(t *testing.T) {
|
||||
t.Error("expected tools to be registered after concurrent access")
|
||||
}
|
||||
}
|
||||
|
||||
// --- Panic and abnormal exit tests ---
|
||||
|
||||
// mockPanicTool is a tool that panics during execution
|
||||
type mockPanicTool struct {
|
||||
name string
|
||||
panicValue any
|
||||
}
|
||||
|
||||
func (m *mockPanicTool) Name() string { return m.name }
|
||||
func (m *mockPanicTool) Description() string { return "a tool that panics" }
|
||||
func (m *mockPanicTool) Parameters() map[string]any { return map[string]any{"type": "object"} }
|
||||
func (m *mockPanicTool) Execute(_ context.Context, _ map[string]any) *ToolResult {
|
||||
panic(m.panicValue)
|
||||
}
|
||||
|
||||
// mockNilResultTool is a tool that returns nil
|
||||
type mockNilResultTool struct {
|
||||
name string
|
||||
}
|
||||
|
||||
func (m *mockNilResultTool) Name() string { return m.name }
|
||||
func (m *mockNilResultTool) Description() string { return "a tool that returns nil" }
|
||||
func (m *mockNilResultTool) Parameters() map[string]any { return map[string]any{"type": "object"} }
|
||||
func (m *mockNilResultTool) Execute(_ context.Context, _ map[string]any) *ToolResult {
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestToolRegistry_Execute_PanicRecovery(t *testing.T) {
|
||||
r := NewToolRegistry()
|
||||
r.Register(&mockPanicTool{
|
||||
name: "panic_tool",
|
||||
panicValue: "something went terribly wrong",
|
||||
})
|
||||
|
||||
// Should not panic, should return error result
|
||||
result := r.Execute(context.Background(), "panic_tool", nil)
|
||||
|
||||
if result == nil {
|
||||
t.Fatal("expected non-nil result after panic recovery")
|
||||
}
|
||||
if !result.IsError {
|
||||
t.Error("expected IsError=true after panic")
|
||||
}
|
||||
if !strings.Contains(result.ForLLM, "panic") {
|
||||
t.Errorf("expected 'panic' in error message, got %q", result.ForLLM)
|
||||
}
|
||||
if !strings.Contains(result.ForLLM, "panic_tool") {
|
||||
t.Errorf("expected tool name in error message, got %q", result.ForLLM)
|
||||
}
|
||||
if !strings.Contains(result.ForLLM, "something went terribly wrong") {
|
||||
t.Errorf("expected panic value in error message, got %q", result.ForLLM)
|
||||
}
|
||||
if result.Err == nil {
|
||||
t.Error("expected Err to be set")
|
||||
}
|
||||
}
|
||||
|
||||
func TestToolRegistry_Execute_PanicRecovery_ErrorType(t *testing.T) {
|
||||
r := NewToolRegistry()
|
||||
|
||||
// Test with error type panic
|
||||
r.Register(&mockPanicTool{
|
||||
name: "error_panic_tool",
|
||||
panicValue: errors.New("custom error panic"),
|
||||
})
|
||||
|
||||
result := r.Execute(context.Background(), "error_panic_tool", nil)
|
||||
|
||||
if !result.IsError {
|
||||
t.Error("expected IsError=true")
|
||||
}
|
||||
if !strings.Contains(result.ForLLM, "custom error panic") {
|
||||
t.Errorf("expected error message in ForLLM, got %q", result.ForLLM)
|
||||
}
|
||||
}
|
||||
|
||||
func TestToolRegistry_Execute_PanicRecovery_IntType(t *testing.T) {
|
||||
r := NewToolRegistry()
|
||||
|
||||
// Test with int type panic
|
||||
r.Register(&mockPanicTool{
|
||||
name: "int_panic_tool",
|
||||
panicValue: 42,
|
||||
})
|
||||
|
||||
result := r.Execute(context.Background(), "int_panic_tool", nil)
|
||||
|
||||
if !result.IsError {
|
||||
t.Error("expected IsError=true")
|
||||
}
|
||||
if !strings.Contains(result.ForLLM, "42") {
|
||||
t.Errorf("expected panic value '42' in ForLLM, got %q", result.ForLLM)
|
||||
}
|
||||
}
|
||||
|
||||
func TestToolRegistry_Execute_NilResultHandling(t *testing.T) {
|
||||
r := NewToolRegistry()
|
||||
r.Register(&mockNilResultTool{name: "nil_tool"})
|
||||
|
||||
result := r.Execute(context.Background(), "nil_tool", nil)
|
||||
|
||||
if result == nil {
|
||||
t.Fatal("expected non-nil result when tool returns nil")
|
||||
}
|
||||
if !result.IsError {
|
||||
t.Error("expected IsError=true for nil result")
|
||||
}
|
||||
if !strings.Contains(result.ForLLM, "nil_tool") {
|
||||
t.Errorf("expected tool name in error message, got %q", result.ForLLM)
|
||||
}
|
||||
if !strings.Contains(result.ForLLM, "nil result") {
|
||||
t.Errorf("expected 'nil result' in error message, got %q", result.ForLLM)
|
||||
}
|
||||
if result.Err == nil {
|
||||
t.Error("expected Err to be set")
|
||||
}
|
||||
}
|
||||
|
||||
func TestToolRegistry_ExecuteWithContext_PanicRecovery(t *testing.T) {
|
||||
r := NewToolRegistry()
|
||||
r.Register(&mockPanicTool{
|
||||
name: "ctx_panic_tool",
|
||||
panicValue: "context panic test",
|
||||
})
|
||||
|
||||
// Should not panic even with context
|
||||
result := r.ExecuteWithContext(
|
||||
context.Background(),
|
||||
"ctx_panic_tool",
|
||||
map[string]any{"key": "value"},
|
||||
"telegram",
|
||||
"chat-123",
|
||||
nil,
|
||||
)
|
||||
|
||||
if result == nil {
|
||||
t.Fatal("expected non-nil result")
|
||||
}
|
||||
if !result.IsError {
|
||||
t.Error("expected IsError=true")
|
||||
}
|
||||
if !strings.Contains(result.ForLLM, "context panic test") {
|
||||
t.Errorf("expected panic message, got %q", result.ForLLM)
|
||||
}
|
||||
}
|
||||
|
||||
func TestToolRegistry_Execute_PanicDoesNotAffectOtherTools(t *testing.T) {
|
||||
r := NewToolRegistry()
|
||||
r.Register(&mockPanicTool{name: "bad_tool", panicValue: "boom"})
|
||||
r.Register(&mockRegistryTool{
|
||||
name: "good_tool",
|
||||
desc: "works fine",
|
||||
params: map[string]any{},
|
||||
result: SilentResult("success"),
|
||||
})
|
||||
|
||||
// First, trigger the panic
|
||||
result1 := r.Execute(context.Background(), "bad_tool", nil)
|
||||
if !result1.IsError {
|
||||
t.Error("expected error from panic tool")
|
||||
}
|
||||
|
||||
// Then, verify the good tool still works
|
||||
result2 := r.Execute(context.Background(), "good_tool", nil)
|
||||
if result2.IsError {
|
||||
t.Errorf("expected success from good tool, got error: %s", result2.ForLLM)
|
||||
}
|
||||
if result2.ForLLM != "success" {
|
||||
t.Errorf("expected 'success', got %q", result2.ForLLM)
|
||||
}
|
||||
}
|
||||
|
||||
+18
-1
@@ -311,13 +311,30 @@ func (t *ExecTool) Execute(ctx context.Context, args map[string]any) *ToolResult
|
||||
if err != nil {
|
||||
if errors.Is(cmdCtx.Err(), context.DeadlineExceeded) {
|
||||
msg := fmt.Sprintf("Command timed out after %v", t.timeout)
|
||||
if output != "" {
|
||||
msg += "\n\nPartial output before timeout:\n" + output
|
||||
}
|
||||
return &ToolResult{
|
||||
ForLLM: msg,
|
||||
ForUser: msg,
|
||||
IsError: true,
|
||||
Err: fmt.Errorf("command timeout: %w", err),
|
||||
}
|
||||
}
|
||||
output += fmt.Sprintf("\nExit code: %v", err)
|
||||
|
||||
// Extract detailed exit information
|
||||
var exitErr *exec.ExitError
|
||||
if errors.As(err, &exitErr) {
|
||||
exitCode := exitErr.ExitCode()
|
||||
output += fmt.Sprintf("\n\n[Command exited with code %d]", exitCode)
|
||||
|
||||
// Add signal information if killed by signal (Unix)
|
||||
if exitCode == -1 {
|
||||
output += " (killed by signal)"
|
||||
}
|
||||
} else {
|
||||
output += fmt.Sprintf("\n\n[Command failed: %v]", err)
|
||||
}
|
||||
}
|
||||
|
||||
if output == "" {
|
||||
|
||||
@@ -489,6 +489,69 @@ func TestShellTool_SafePathsInWorkspaceRestriction(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// TestShellTool_ExitCodeDetails verifies that exit codes are captured with details
|
||||
func TestShellTool_ExitCodeDetails(t *testing.T) {
|
||||
tool, err := NewExecTool("", false)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to configure exec tool: %s", err)
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
args := map[string]any{
|
||||
"command": "sh -c 'exit 42'",
|
||||
}
|
||||
|
||||
result := tool.Execute(ctx, args)
|
||||
|
||||
if !result.IsError {
|
||||
t.Error("expected error for non-zero exit code")
|
||||
}
|
||||
|
||||
// Should contain the exit code in the message (new format: "exited with code 42")
|
||||
if !strings.Contains(result.ForLLM, "42") {
|
||||
t.Errorf("expected exit code 42 in error message, got: %s", result.ForLLM)
|
||||
}
|
||||
|
||||
// Verify the new detailed message format
|
||||
if !strings.Contains(result.ForLLM, "exited with code") {
|
||||
t.Errorf("expected 'exited with code' in message, got: %s", result.ForLLM)
|
||||
}
|
||||
|
||||
// Err field is set by the exec system (may or may not be set depending on implementation)
|
||||
// The important thing is that IsError=true
|
||||
t.Logf("Exit code result: %s", result.ForLLM)
|
||||
}
|
||||
|
||||
// TestShellTool_TimeoutWithPartialOutput verifies timeout includes partial output
|
||||
func TestShellTool_TimeoutWithPartialOutput(t *testing.T) {
|
||||
tool, err := NewExecTool("", false)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to configure exec tool: %s", err)
|
||||
}
|
||||
|
||||
tool.SetTimeout(1 * time.Second) // Give more time for echo to complete
|
||||
|
||||
ctx := context.Background()
|
||||
// Use a command that outputs immediately then sleeps
|
||||
args := map[string]any{
|
||||
"command": "echo 'partial output before timeout' && sleep 30",
|
||||
}
|
||||
|
||||
result := tool.Execute(ctx, args)
|
||||
|
||||
if !result.IsError {
|
||||
t.Error("expected error for timeout")
|
||||
}
|
||||
|
||||
// Should mention timeout
|
||||
if !strings.Contains(result.ForLLM, "timed out") {
|
||||
t.Errorf("expected 'timed out' in message, got: %s", result.ForLLM)
|
||||
}
|
||||
|
||||
// Log the result for debugging (partial output depends on shell behavior)
|
||||
t.Logf("Timeout result: %s", result.ForLLM)
|
||||
}
|
||||
|
||||
// TestShellTool_CustomAllowPatterns verifies that custom allow patterns exempt
|
||||
// commands from deny pattern checks.
|
||||
func TestShellTool_CustomAllowPatterns(t *testing.T) {
|
||||
|
||||
@@ -2,7 +2,7 @@ import { useTranslation } from "react-i18next"
|
||||
|
||||
import type { ChannelConfig } from "@/api/channels"
|
||||
import { maskedSecretPlaceholder } from "@/components/secret-placeholder"
|
||||
import { Field, KeyInput } from "@/components/shared-form"
|
||||
import { Field, KeyInput, SwitchCardField } from "@/components/shared-form"
|
||||
import { Input } from "@/components/ui/input"
|
||||
|
||||
interface FeishuFormProps {
|
||||
@@ -16,6 +16,10 @@ function asString(value: unknown): string {
|
||||
return typeof value === "string" ? value : ""
|
||||
}
|
||||
|
||||
function asBool(value: unknown): boolean {
|
||||
return typeof value === "boolean" ? value : false
|
||||
}
|
||||
|
||||
function asStringArray(value: unknown): string[] {
|
||||
if (!Array.isArray(value)) return []
|
||||
return value.filter((item): item is string => typeof item === "string")
|
||||
@@ -98,6 +102,12 @@ export function FeishuForm({
|
||||
)}
|
||||
/>
|
||||
</Field>
|
||||
<SwitchCardField
|
||||
label={t("channels.field.isLark")}
|
||||
hint={t("channels.form.desc.isLark")}
|
||||
checked={asBool(config.is_lark)}
|
||||
onCheckedChange={(checked) => onChange("is_lark", checked)}
|
||||
/>
|
||||
<Field
|
||||
label={t("channels.field.allowFrom")}
|
||||
hint={t("channels.form.desc.allowFrom")}
|
||||
|
||||
@@ -259,6 +259,7 @@
|
||||
"placeholderText": "Placeholder Text",
|
||||
"groupTriggerMentionOnly": "Group Mention Only",
|
||||
"groupTriggerPrefixes": "Group Trigger Prefixes",
|
||||
"isLark": "Lark (International)",
|
||||
"allowFrom": "Allow From",
|
||||
"allowFromPlaceholder": "e.g. 123456, 789012",
|
||||
"allowOrigins": "Allow Origins",
|
||||
@@ -290,6 +291,7 @@
|
||||
"placeholderEnabled": "Enable temporary placeholder messages before the final reply is sent.",
|
||||
"groupTriggerMentionOnly": "In group chats, respond only when the bot is mentioned.",
|
||||
"groupTriggerPrefixes": "Custom group-chat trigger prefixes, separated by commas.",
|
||||
"isLark": "Use Lark international domain (open.larksuite.com) instead of Feishu domain (open.feishu.cn).",
|
||||
"allowFrom": "Allowed user or group IDs, separated by commas.",
|
||||
"allowOrigins": "Allowed origin domains, separated by commas.",
|
||||
"wsUrl": "WebSocket service URL.",
|
||||
|
||||
@@ -259,6 +259,7 @@
|
||||
"placeholderText": "占位文案",
|
||||
"groupTriggerMentionOnly": "群聊仅提及时响应",
|
||||
"groupTriggerPrefixes": "群聊触发前缀",
|
||||
"isLark": "Lark(国际版)",
|
||||
"allowFrom": "允许来源",
|
||||
"allowFromPlaceholder": "例如 123456, 789012",
|
||||
"allowOrigins": "允许来源域名",
|
||||
@@ -290,6 +291,7 @@
|
||||
"placeholderEnabled": "在最终回复发送前,先发送临时占位消息。",
|
||||
"groupTriggerMentionOnly": "在群聊中仅当提及机器人时才响应。",
|
||||
"groupTriggerPrefixes": "群聊触发前缀,多个值用逗号分隔。",
|
||||
"isLark": "使用 Lark 国际版域名(open.larksuite.com)替代飞书域名(open.feishu.cn)。",
|
||||
"allowFrom": "允许访问的用户或群组 ID,多个值用逗号分隔。",
|
||||
"allowOrigins": "允许访问的来源域名,多个值用逗号分隔。",
|
||||
"wsUrl": "WebSocket 服务地址。",
|
||||
|
||||
Reference in New Issue
Block a user