mirror of
https://github.com/sipeed/picoclaw.git
synced 2026-06-12 18:08:54 +00:00
Merge pull request #131 from Leeaandrob/feat/multi-agent-routing
feat: model fallback chain + multi-agent routing
This commit is contained in:
@@ -0,0 +1,145 @@
|
||||
package agent
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
"github.com/sipeed/picoclaw/pkg/providers"
|
||||
"github.com/sipeed/picoclaw/pkg/routing"
|
||||
"github.com/sipeed/picoclaw/pkg/session"
|
||||
"github.com/sipeed/picoclaw/pkg/tools"
|
||||
)
|
||||
|
||||
// AgentInstance represents a fully configured agent with its own workspace,
|
||||
// session manager, context builder, and tool registry.
|
||||
type AgentInstance struct {
|
||||
ID string
|
||||
Name string
|
||||
Model string
|
||||
Fallbacks []string
|
||||
Workspace string
|
||||
MaxIterations int
|
||||
ContextWindow int
|
||||
Provider providers.LLMProvider
|
||||
Sessions *session.SessionManager
|
||||
ContextBuilder *ContextBuilder
|
||||
Tools *tools.ToolRegistry
|
||||
Subagents *config.SubagentsConfig
|
||||
SkillsFilter []string
|
||||
Candidates []providers.FallbackCandidate
|
||||
}
|
||||
|
||||
// NewAgentInstance creates an agent instance from config.
|
||||
func NewAgentInstance(
|
||||
agentCfg *config.AgentConfig,
|
||||
defaults *config.AgentDefaults,
|
||||
cfg *config.Config,
|
||||
provider providers.LLMProvider,
|
||||
) *AgentInstance {
|
||||
workspace := resolveAgentWorkspace(agentCfg, defaults)
|
||||
os.MkdirAll(workspace, 0755)
|
||||
|
||||
model := resolveAgentModel(agentCfg, defaults)
|
||||
fallbacks := resolveAgentFallbacks(agentCfg, defaults)
|
||||
|
||||
restrict := defaults.RestrictToWorkspace
|
||||
toolsRegistry := tools.NewToolRegistry()
|
||||
toolsRegistry.Register(tools.NewReadFileTool(workspace, restrict))
|
||||
toolsRegistry.Register(tools.NewWriteFileTool(workspace, restrict))
|
||||
toolsRegistry.Register(tools.NewListDirTool(workspace, restrict))
|
||||
toolsRegistry.Register(tools.NewExecToolWithConfig(workspace, restrict, cfg))
|
||||
toolsRegistry.Register(tools.NewEditFileTool(workspace, restrict))
|
||||
toolsRegistry.Register(tools.NewAppendFileTool(workspace, restrict))
|
||||
|
||||
sessionsDir := filepath.Join(workspace, "sessions")
|
||||
sessionsManager := session.NewSessionManager(sessionsDir)
|
||||
|
||||
contextBuilder := NewContextBuilder(workspace)
|
||||
contextBuilder.SetToolsRegistry(toolsRegistry)
|
||||
|
||||
agentID := routing.DefaultAgentID
|
||||
agentName := ""
|
||||
var subagents *config.SubagentsConfig
|
||||
var skillsFilter []string
|
||||
|
||||
if agentCfg != nil {
|
||||
agentID = routing.NormalizeAgentID(agentCfg.ID)
|
||||
agentName = agentCfg.Name
|
||||
subagents = agentCfg.Subagents
|
||||
skillsFilter = agentCfg.Skills
|
||||
}
|
||||
|
||||
maxIter := defaults.MaxToolIterations
|
||||
if maxIter == 0 {
|
||||
maxIter = 20
|
||||
}
|
||||
|
||||
// Resolve fallback candidates
|
||||
modelCfg := providers.ModelConfig{
|
||||
Primary: model,
|
||||
Fallbacks: fallbacks,
|
||||
}
|
||||
candidates := providers.ResolveCandidates(modelCfg, defaults.Provider)
|
||||
|
||||
return &AgentInstance{
|
||||
ID: agentID,
|
||||
Name: agentName,
|
||||
Model: model,
|
||||
Fallbacks: fallbacks,
|
||||
Workspace: workspace,
|
||||
MaxIterations: maxIter,
|
||||
ContextWindow: defaults.MaxTokens,
|
||||
Provider: provider,
|
||||
Sessions: sessionsManager,
|
||||
ContextBuilder: contextBuilder,
|
||||
Tools: toolsRegistry,
|
||||
Subagents: subagents,
|
||||
SkillsFilter: skillsFilter,
|
||||
Candidates: candidates,
|
||||
}
|
||||
}
|
||||
|
||||
// resolveAgentWorkspace determines the workspace directory for an agent.
|
||||
func resolveAgentWorkspace(agentCfg *config.AgentConfig, defaults *config.AgentDefaults) string {
|
||||
if agentCfg != nil && strings.TrimSpace(agentCfg.Workspace) != "" {
|
||||
return expandHome(strings.TrimSpace(agentCfg.Workspace))
|
||||
}
|
||||
if agentCfg == nil || agentCfg.Default || agentCfg.ID == "" || routing.NormalizeAgentID(agentCfg.ID) == "main" {
|
||||
return expandHome(defaults.Workspace)
|
||||
}
|
||||
home, _ := os.UserHomeDir()
|
||||
id := routing.NormalizeAgentID(agentCfg.ID)
|
||||
return filepath.Join(home, ".picoclaw", "workspace-"+id)
|
||||
}
|
||||
|
||||
// resolveAgentModel resolves the primary model for an agent.
|
||||
func resolveAgentModel(agentCfg *config.AgentConfig, defaults *config.AgentDefaults) string {
|
||||
if agentCfg != nil && agentCfg.Model != nil && strings.TrimSpace(agentCfg.Model.Primary) != "" {
|
||||
return strings.TrimSpace(agentCfg.Model.Primary)
|
||||
}
|
||||
return defaults.Model
|
||||
}
|
||||
|
||||
// resolveAgentFallbacks resolves the fallback models for an agent.
|
||||
func resolveAgentFallbacks(agentCfg *config.AgentConfig, defaults *config.AgentDefaults) []string {
|
||||
if agentCfg != nil && agentCfg.Model != nil && agentCfg.Model.Fallbacks != nil {
|
||||
return agentCfg.Model.Fallbacks
|
||||
}
|
||||
return defaults.ModelFallbacks
|
||||
}
|
||||
|
||||
func expandHome(path string) string {
|
||||
if path == "" {
|
||||
return path
|
||||
}
|
||||
if path[0] == '~' {
|
||||
home, _ := os.UserHomeDir()
|
||||
if len(path) > 1 && path[1] == '/' {
|
||||
return home + path[1:]
|
||||
}
|
||||
return home
|
||||
}
|
||||
return path
|
||||
}
|
||||
+299
-287
@@ -10,8 +10,6 @@ import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
@@ -24,7 +22,7 @@ import (
|
||||
"github.com/sipeed/picoclaw/pkg/constants"
|
||||
"github.com/sipeed/picoclaw/pkg/logger"
|
||||
"github.com/sipeed/picoclaw/pkg/providers"
|
||||
"github.com/sipeed/picoclaw/pkg/session"
|
||||
"github.com/sipeed/picoclaw/pkg/routing"
|
||||
"github.com/sipeed/picoclaw/pkg/state"
|
||||
"github.com/sipeed/picoclaw/pkg/tools"
|
||||
"github.com/sipeed/picoclaw/pkg/utils"
|
||||
@@ -32,17 +30,12 @@ import (
|
||||
|
||||
type AgentLoop struct {
|
||||
bus *bus.MessageBus
|
||||
provider providers.LLMProvider
|
||||
workspace string
|
||||
model string
|
||||
contextWindow int // Maximum context window size in tokens
|
||||
maxIterations int
|
||||
sessions *session.SessionManager
|
||||
cfg *config.Config
|
||||
registry *AgentRegistry
|
||||
state *state.Manager
|
||||
contextBuilder *ContextBuilder
|
||||
tools *tools.ToolRegistry
|
||||
running atomic.Bool
|
||||
summarizing sync.Map // Tracks which sessions are currently being summarized
|
||||
summarizing sync.Map
|
||||
fallback *providers.FallbackChain
|
||||
channelManager *channels.Manager
|
||||
}
|
||||
|
||||
@@ -58,99 +51,83 @@ type processOptions struct {
|
||||
NoHistory bool // If true, don't load session history (for heartbeat)
|
||||
}
|
||||
|
||||
// createToolRegistry creates a tool registry with common tools.
|
||||
// This is shared between main agent and subagents.
|
||||
func createToolRegistry(workspace string, restrict bool, cfg *config.Config, msgBus *bus.MessageBus) *tools.ToolRegistry {
|
||||
registry := tools.NewToolRegistry()
|
||||
|
||||
// File system tools
|
||||
registry.Register(tools.NewReadFileTool(workspace, restrict))
|
||||
registry.Register(tools.NewWriteFileTool(workspace, restrict))
|
||||
registry.Register(tools.NewListDirTool(workspace, restrict))
|
||||
registry.Register(tools.NewEditFileTool(workspace, restrict))
|
||||
registry.Register(tools.NewAppendFileTool(workspace, restrict))
|
||||
|
||||
// Shell execution
|
||||
registry.Register(tools.NewExecToolWithConfig(workspace, restrict, cfg))
|
||||
|
||||
if searchTool := tools.NewWebSearchTool(tools.WebSearchToolOptions{
|
||||
BraveAPIKey: cfg.Tools.Web.Brave.APIKey,
|
||||
BraveMaxResults: cfg.Tools.Web.Brave.MaxResults,
|
||||
BraveEnabled: cfg.Tools.Web.Brave.Enabled,
|
||||
DuckDuckGoMaxResults: cfg.Tools.Web.DuckDuckGo.MaxResults,
|
||||
DuckDuckGoEnabled: cfg.Tools.Web.DuckDuckGo.Enabled,
|
||||
PerplexityAPIKey: cfg.Tools.Web.Perplexity.APIKey,
|
||||
PerplexityMaxResults: cfg.Tools.Web.Perplexity.MaxResults,
|
||||
PerplexityEnabled: cfg.Tools.Web.Perplexity.Enabled,
|
||||
}); searchTool != nil {
|
||||
registry.Register(searchTool)
|
||||
}
|
||||
registry.Register(tools.NewWebFetchTool(50000))
|
||||
|
||||
// Hardware tools (I2C, SPI) - Linux only, returns error on other platforms
|
||||
registry.Register(tools.NewI2CTool())
|
||||
registry.Register(tools.NewSPITool())
|
||||
|
||||
// Message tool - available to both agent and subagent
|
||||
// Subagent uses it to communicate directly with user
|
||||
messageTool := tools.NewMessageTool()
|
||||
messageTool.SetSendCallback(func(channel, chatID, content string) error {
|
||||
msgBus.PublishOutbound(bus.OutboundMessage{
|
||||
Channel: channel,
|
||||
ChatID: chatID,
|
||||
Content: content,
|
||||
})
|
||||
return nil
|
||||
})
|
||||
registry.Register(messageTool)
|
||||
|
||||
return registry
|
||||
}
|
||||
|
||||
func NewAgentLoop(cfg *config.Config, msgBus *bus.MessageBus, provider providers.LLMProvider) *AgentLoop {
|
||||
workspace := cfg.WorkspacePath()
|
||||
os.MkdirAll(workspace, 0755)
|
||||
registry := NewAgentRegistry(cfg, provider)
|
||||
|
||||
restrict := cfg.Agents.Defaults.RestrictToWorkspace
|
||||
// Register shared tools to all agents
|
||||
registerSharedTools(cfg, msgBus, registry, provider)
|
||||
|
||||
// Create tool registry for main agent
|
||||
toolsRegistry := createToolRegistry(workspace, restrict, cfg, msgBus)
|
||||
// Set up shared fallback chain
|
||||
cooldown := providers.NewCooldownTracker()
|
||||
fallbackChain := providers.NewFallbackChain(cooldown)
|
||||
|
||||
// Create subagent manager with its own tool registry
|
||||
subagentManager := tools.NewSubagentManager(provider, cfg.Agents.Defaults.Model, workspace, msgBus)
|
||||
subagentTools := createToolRegistry(workspace, restrict, cfg, msgBus)
|
||||
// Subagent doesn't need spawn/subagent tools to avoid recursion
|
||||
subagentManager.SetTools(subagentTools)
|
||||
|
||||
// Register spawn tool (for main agent)
|
||||
spawnTool := tools.NewSpawnTool(subagentManager)
|
||||
toolsRegistry.Register(spawnTool)
|
||||
|
||||
// Register subagent tool (synchronous execution)
|
||||
subagentTool := tools.NewSubagentTool(subagentManager)
|
||||
toolsRegistry.Register(subagentTool)
|
||||
|
||||
sessionsManager := session.NewSessionManager(filepath.Join(workspace, "sessions"))
|
||||
|
||||
// Create state manager for atomic state persistence
|
||||
stateManager := state.NewManager(workspace)
|
||||
|
||||
// Create context builder and set tools registry
|
||||
contextBuilder := NewContextBuilder(workspace)
|
||||
contextBuilder.SetToolsRegistry(toolsRegistry)
|
||||
// Create state manager using default agent's workspace for channel recording
|
||||
defaultAgent := registry.GetDefaultAgent()
|
||||
var stateManager *state.Manager
|
||||
if defaultAgent != nil {
|
||||
stateManager = state.NewManager(defaultAgent.Workspace)
|
||||
}
|
||||
|
||||
return &AgentLoop{
|
||||
bus: msgBus,
|
||||
provider: provider,
|
||||
workspace: workspace,
|
||||
model: cfg.Agents.Defaults.Model,
|
||||
contextWindow: cfg.Agents.Defaults.MaxTokens, // Restore context window for summarization
|
||||
maxIterations: cfg.Agents.Defaults.MaxToolIterations,
|
||||
sessions: sessionsManager,
|
||||
state: stateManager,
|
||||
contextBuilder: contextBuilder,
|
||||
tools: toolsRegistry,
|
||||
summarizing: sync.Map{},
|
||||
bus: msgBus,
|
||||
cfg: cfg,
|
||||
registry: registry,
|
||||
state: stateManager,
|
||||
summarizing: sync.Map{},
|
||||
fallback: fallbackChain,
|
||||
}
|
||||
}
|
||||
|
||||
// registerSharedTools registers tools that are shared across all agents (web, message, spawn).
|
||||
func registerSharedTools(cfg *config.Config, msgBus *bus.MessageBus, registry *AgentRegistry, provider providers.LLMProvider) {
|
||||
for _, agentID := range registry.ListAgentIDs() {
|
||||
agent, ok := registry.GetAgent(agentID)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
// Web tools
|
||||
if searchTool := tools.NewWebSearchTool(tools.WebSearchToolOptions{
|
||||
BraveAPIKey: cfg.Tools.Web.Brave.APIKey,
|
||||
BraveMaxResults: cfg.Tools.Web.Brave.MaxResults,
|
||||
BraveEnabled: cfg.Tools.Web.Brave.Enabled,
|
||||
DuckDuckGoMaxResults: cfg.Tools.Web.DuckDuckGo.MaxResults,
|
||||
DuckDuckGoEnabled: cfg.Tools.Web.DuckDuckGo.Enabled,
|
||||
PerplexityAPIKey: cfg.Tools.Web.Perplexity.APIKey,
|
||||
PerplexityMaxResults: cfg.Tools.Web.Perplexity.MaxResults,
|
||||
PerplexityEnabled: cfg.Tools.Web.Perplexity.Enabled,
|
||||
}); searchTool != nil {
|
||||
agent.Tools.Register(searchTool)
|
||||
}
|
||||
agent.Tools.Register(tools.NewWebFetchTool(50000))
|
||||
|
||||
// Hardware tools (I2C, SPI) - Linux only, returns error on other platforms
|
||||
agent.Tools.Register(tools.NewI2CTool())
|
||||
agent.Tools.Register(tools.NewSPITool())
|
||||
|
||||
// Message tool
|
||||
messageTool := tools.NewMessageTool()
|
||||
messageTool.SetSendCallback(func(channel, chatID, content string) error {
|
||||
msgBus.PublishOutbound(bus.OutboundMessage{
|
||||
Channel: channel,
|
||||
ChatID: chatID,
|
||||
Content: content,
|
||||
})
|
||||
return nil
|
||||
})
|
||||
agent.Tools.Register(messageTool)
|
||||
|
||||
// Spawn tool with allowlist checker
|
||||
subagentManager := tools.NewSubagentManager(provider, agent.Model, agent.Workspace, msgBus)
|
||||
spawnTool := tools.NewSpawnTool(subagentManager)
|
||||
currentAgentID := agentID
|
||||
spawnTool.SetAllowlistChecker(func(targetAgentID string) bool {
|
||||
return registry.CanSpawnSubagent(currentAgentID, targetAgentID)
|
||||
})
|
||||
agent.Tools.Register(spawnTool)
|
||||
|
||||
// Update context builder with the complete tools registry
|
||||
agent.ContextBuilder.SetToolsRegistry(agent.Tools)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -175,10 +152,14 @@ func (al *AgentLoop) Run(ctx context.Context) error {
|
||||
if response != "" {
|
||||
// Check if the message tool already sent a response during this round.
|
||||
// If so, skip publishing to avoid duplicate messages to the user.
|
||||
// Use default agent's tools to check (message tool is shared).
|
||||
alreadySent := false
|
||||
if tool, ok := al.tools.Get("message"); ok {
|
||||
if mt, ok := tool.(*tools.MessageTool); ok {
|
||||
alreadySent = mt.HasSentInRound()
|
||||
defaultAgent := al.registry.GetDefaultAgent()
|
||||
if defaultAgent != nil {
|
||||
if tool, ok := defaultAgent.Tools.Get("message"); ok {
|
||||
if mt, ok := tool.(*tools.MessageTool); ok {
|
||||
alreadySent = mt.HasSentInRound()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -201,7 +182,11 @@ func (al *AgentLoop) Stop() {
|
||||
}
|
||||
|
||||
func (al *AgentLoop) RegisterTool(tool tools.Tool) {
|
||||
al.tools.Register(tool)
|
||||
for _, agentID := range al.registry.ListAgentIDs() {
|
||||
if agent, ok := al.registry.GetAgent(agentID); ok {
|
||||
agent.Tools.Register(tool)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (al *AgentLoop) SetChannelManager(cm *channels.Manager) {
|
||||
@@ -211,12 +196,18 @@ func (al *AgentLoop) SetChannelManager(cm *channels.Manager) {
|
||||
// RecordLastChannel records the last active channel for this workspace.
|
||||
// This uses the atomic state save mechanism to prevent data loss on crash.
|
||||
func (al *AgentLoop) RecordLastChannel(channel string) error {
|
||||
if al.state == nil {
|
||||
return nil
|
||||
}
|
||||
return al.state.SetLastChannel(channel)
|
||||
}
|
||||
|
||||
// RecordLastChatID records the last active chat ID for this workspace.
|
||||
// This uses the atomic state save mechanism to prevent data loss on crash.
|
||||
func (al *AgentLoop) RecordLastChatID(chatID string) error {
|
||||
if al.state == nil {
|
||||
return nil
|
||||
}
|
||||
return al.state.SetLastChatID(chatID)
|
||||
}
|
||||
|
||||
@@ -239,7 +230,8 @@ func (al *AgentLoop) ProcessDirectWithChannel(ctx context.Context, content, sess
|
||||
// ProcessHeartbeat processes a heartbeat request without session history.
|
||||
// Each heartbeat is independent and doesn't accumulate context.
|
||||
func (al *AgentLoop) ProcessHeartbeat(ctx context.Context, content, channel, chatID string) (string, error) {
|
||||
return al.runAgentLoop(ctx, processOptions{
|
||||
agent := al.registry.GetDefaultAgent()
|
||||
return al.runAgentLoop(ctx, agent, processOptions{
|
||||
SessionKey: "heartbeat",
|
||||
Channel: channel,
|
||||
ChatID: chatID,
|
||||
@@ -277,9 +269,36 @@ func (al *AgentLoop) processMessage(ctx context.Context, msg bus.InboundMessage)
|
||||
return response, nil
|
||||
}
|
||||
|
||||
// Process as user message
|
||||
return al.runAgentLoop(ctx, processOptions{
|
||||
SessionKey: msg.SessionKey,
|
||||
// Route to determine agent and session key
|
||||
route := al.registry.ResolveRoute(routing.RouteInput{
|
||||
Channel: msg.Channel,
|
||||
AccountID: msg.Metadata["account_id"],
|
||||
Peer: extractPeer(msg),
|
||||
ParentPeer: extractParentPeer(msg),
|
||||
GuildID: msg.Metadata["guild_id"],
|
||||
TeamID: msg.Metadata["team_id"],
|
||||
})
|
||||
|
||||
agent, ok := al.registry.GetAgent(route.AgentID)
|
||||
if !ok {
|
||||
agent = al.registry.GetDefaultAgent()
|
||||
}
|
||||
|
||||
// Use routed session key, but honor pre-set agent-scoped keys (for ProcessDirect/cron)
|
||||
sessionKey := route.SessionKey
|
||||
if msg.SessionKey != "" && strings.HasPrefix(msg.SessionKey, "agent:") {
|
||||
sessionKey = msg.SessionKey
|
||||
}
|
||||
|
||||
logger.InfoCF("agent", "Routed message",
|
||||
map[string]interface{}{
|
||||
"agent_id": agent.ID,
|
||||
"session_key": sessionKey,
|
||||
"matched_by": route.MatchedBy,
|
||||
})
|
||||
|
||||
return al.runAgentLoop(ctx, agent, processOptions{
|
||||
SessionKey: sessionKey,
|
||||
Channel: msg.Channel,
|
||||
ChatID: msg.ChatID,
|
||||
UserMessage: msg.Content,
|
||||
@@ -290,7 +309,6 @@ func (al *AgentLoop) processMessage(ctx context.Context, msg bus.InboundMessage)
|
||||
}
|
||||
|
||||
func (al *AgentLoop) processSystemMessage(ctx context.Context, msg bus.InboundMessage) (string, error) {
|
||||
// Verify this is a system message
|
||||
if msg.Channel != "system" {
|
||||
return "", fmt.Errorf("processSystemMessage called with non-system message channel: %s", msg.Channel)
|
||||
}
|
||||
@@ -302,12 +320,13 @@ func (al *AgentLoop) processSystemMessage(ctx context.Context, msg bus.InboundMe
|
||||
})
|
||||
|
||||
// Parse origin channel from chat_id (format: "channel:chat_id")
|
||||
var originChannel string
|
||||
var originChannel, originChatID string
|
||||
if idx := strings.Index(msg.ChatID, ":"); idx > 0 {
|
||||
originChannel = msg.ChatID[:idx]
|
||||
originChatID = msg.ChatID[idx+1:]
|
||||
} else {
|
||||
// Fallback
|
||||
originChannel = "cli"
|
||||
originChatID = msg.ChatID
|
||||
}
|
||||
|
||||
// Extract subagent result from message content
|
||||
@@ -328,44 +347,47 @@ func (al *AgentLoop) processSystemMessage(ctx context.Context, msg bus.InboundMe
|
||||
return "", nil
|
||||
}
|
||||
|
||||
// Agent acts as dispatcher only - subagent handles user interaction via message tool
|
||||
// Don't forward result here, subagent should use message tool to communicate with user
|
||||
logger.InfoCF("agent", "Subagent completed",
|
||||
map[string]interface{}{
|
||||
"sender_id": msg.SenderID,
|
||||
"channel": originChannel,
|
||||
"content_len": len(content),
|
||||
})
|
||||
// Use default agent for system messages
|
||||
agent := al.registry.GetDefaultAgent()
|
||||
|
||||
// Agent only logs, does not respond to user
|
||||
return "", nil
|
||||
// Use the origin session for context
|
||||
sessionKey := routing.BuildAgentMainSessionKey(agent.ID)
|
||||
|
||||
return al.runAgentLoop(ctx, agent, processOptions{
|
||||
SessionKey: sessionKey,
|
||||
Channel: originChannel,
|
||||
ChatID: originChatID,
|
||||
UserMessage: fmt.Sprintf("[System: %s] %s", msg.SenderID, msg.Content),
|
||||
DefaultResponse: "Background task completed.",
|
||||
EnableSummary: false,
|
||||
SendResponse: true,
|
||||
})
|
||||
}
|
||||
|
||||
// runAgentLoop is the core message processing logic.
|
||||
// It handles context building, LLM calls, tool execution, and response handling.
|
||||
func (al *AgentLoop) runAgentLoop(ctx context.Context, opts processOptions) (string, error) {
|
||||
func (al *AgentLoop) runAgentLoop(ctx context.Context, agent *AgentInstance, opts processOptions) (string, error) {
|
||||
// 0. Record last channel for heartbeat notifications (skip internal channels)
|
||||
if opts.Channel != "" && opts.ChatID != "" {
|
||||
// Don't record internal channels (cli, system, subagent)
|
||||
if !constants.IsInternalChannel(opts.Channel) {
|
||||
channelKey := fmt.Sprintf("%s:%s", opts.Channel, opts.ChatID)
|
||||
if err := al.RecordLastChannel(channelKey); err != nil {
|
||||
logger.WarnCF("agent", "Failed to record last channel: %v", map[string]interface{}{"error": err.Error()})
|
||||
logger.WarnCF("agent", "Failed to record last channel", map[string]interface{}{"error": err.Error()})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 1. Update tool contexts
|
||||
al.updateToolContexts(opts.Channel, opts.ChatID)
|
||||
al.updateToolContexts(agent, opts.Channel, opts.ChatID)
|
||||
|
||||
// 2. Build messages (skip history for heartbeat)
|
||||
var history []providers.Message
|
||||
var summary string
|
||||
if !opts.NoHistory {
|
||||
history = al.sessions.GetHistory(opts.SessionKey)
|
||||
summary = al.sessions.GetSummary(opts.SessionKey)
|
||||
history = agent.Sessions.GetHistory(opts.SessionKey)
|
||||
summary = agent.Sessions.GetSummary(opts.SessionKey)
|
||||
}
|
||||
messages := al.contextBuilder.BuildMessages(
|
||||
messages := agent.ContextBuilder.BuildMessages(
|
||||
history,
|
||||
summary,
|
||||
opts.UserMessage,
|
||||
@@ -375,10 +397,10 @@ func (al *AgentLoop) runAgentLoop(ctx context.Context, opts processOptions) (str
|
||||
)
|
||||
|
||||
// 3. Save user message to session
|
||||
al.sessions.AddMessage(opts.SessionKey, "user", opts.UserMessage)
|
||||
agent.Sessions.AddMessage(opts.SessionKey, "user", opts.UserMessage)
|
||||
|
||||
// 4. Run LLM iteration loop
|
||||
finalContent, iteration, err := al.runLLMIteration(ctx, messages, opts)
|
||||
finalContent, iteration, err := al.runLLMIteration(ctx, agent, messages, opts)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
@@ -392,12 +414,12 @@ func (al *AgentLoop) runAgentLoop(ctx context.Context, opts processOptions) (str
|
||||
}
|
||||
|
||||
// 6. Save final assistant message to session
|
||||
al.sessions.AddMessage(opts.SessionKey, "assistant", finalContent)
|
||||
al.sessions.Save(opts.SessionKey)
|
||||
agent.Sessions.AddMessage(opts.SessionKey, "assistant", finalContent)
|
||||
agent.Sessions.Save(opts.SessionKey)
|
||||
|
||||
// 7. Optional: summarization
|
||||
if opts.EnableSummary {
|
||||
al.maybeSummarize(opts.SessionKey, opts.Channel, opts.ChatID)
|
||||
al.maybeSummarize(agent, opts.SessionKey, opts.Channel, opts.ChatID)
|
||||
}
|
||||
|
||||
// 8. Optional: send response via bus
|
||||
@@ -413,6 +435,7 @@ func (al *AgentLoop) runAgentLoop(ctx context.Context, opts processOptions) (str
|
||||
responsePreview := utils.Truncate(finalContent, 120)
|
||||
logger.InfoCF("agent", fmt.Sprintf("Response: %s", responsePreview),
|
||||
map[string]interface{}{
|
||||
"agent_id": agent.ID,
|
||||
"session_key": opts.SessionKey,
|
||||
"iterations": iteration,
|
||||
"final_length": len(finalContent),
|
||||
@@ -422,28 +445,29 @@ func (al *AgentLoop) runAgentLoop(ctx context.Context, opts processOptions) (str
|
||||
}
|
||||
|
||||
// runLLMIteration executes the LLM call loop with tool handling.
|
||||
// Returns the final content, iteration count, and any error.
|
||||
func (al *AgentLoop) runLLMIteration(ctx context.Context, messages []providers.Message, opts processOptions) (string, int, error) {
|
||||
func (al *AgentLoop) runLLMIteration(ctx context.Context, agent *AgentInstance, messages []providers.Message, opts processOptions) (string, int, error) {
|
||||
iteration := 0
|
||||
var finalContent string
|
||||
|
||||
for iteration < al.maxIterations {
|
||||
for iteration < agent.MaxIterations {
|
||||
iteration++
|
||||
|
||||
logger.DebugCF("agent", "LLM iteration",
|
||||
map[string]interface{}{
|
||||
"agent_id": agent.ID,
|
||||
"iteration": iteration,
|
||||
"max": al.maxIterations,
|
||||
"max": agent.MaxIterations,
|
||||
})
|
||||
|
||||
// Build tool definitions
|
||||
providerToolDefs := al.tools.ToProviderDefs()
|
||||
providerToolDefs := agent.Tools.ToProviderDefs()
|
||||
|
||||
// Log LLM request details
|
||||
logger.DebugCF("agent", "LLM request",
|
||||
map[string]interface{}{
|
||||
"agent_id": agent.ID,
|
||||
"iteration": iteration,
|
||||
"model": al.model,
|
||||
"model": agent.Model,
|
||||
"messages_count": len(messages),
|
||||
"tools_count": len(providerToolDefs),
|
||||
"max_tokens": 8192,
|
||||
@@ -459,23 +483,45 @@ func (al *AgentLoop) runLLMIteration(ctx context.Context, messages []providers.M
|
||||
"tools_json": formatToolsForLog(providerToolDefs),
|
||||
})
|
||||
|
||||
// Call LLM with fallback chain if candidates are configured.
|
||||
var response *providers.LLMResponse
|
||||
var err error
|
||||
|
||||
callLLM := func() (*providers.LLMResponse, error) {
|
||||
if len(agent.Candidates) > 1 && al.fallback != nil {
|
||||
fbResult, fbErr := al.fallback.Execute(ctx, agent.Candidates,
|
||||
func(ctx context.Context, provider, model string) (*providers.LLMResponse, error) {
|
||||
return agent.Provider.Chat(ctx, messages, providerToolDefs, model, map[string]interface{}{
|
||||
"max_tokens": 8192,
|
||||
"temperature": 0.7,
|
||||
})
|
||||
},
|
||||
)
|
||||
if fbErr != nil {
|
||||
return nil, fbErr
|
||||
}
|
||||
if fbResult.Provider != "" && len(fbResult.Attempts) > 0 {
|
||||
logger.InfoCF("agent", fmt.Sprintf("Fallback: succeeded with %s/%s after %d attempts",
|
||||
fbResult.Provider, fbResult.Model, len(fbResult.Attempts)+1),
|
||||
map[string]interface{}{"agent_id": agent.ID, "iteration": iteration})
|
||||
}
|
||||
return fbResult.Response, nil
|
||||
}
|
||||
return agent.Provider.Chat(ctx, messages, providerToolDefs, agent.Model, map[string]interface{}{
|
||||
"max_tokens": 8192,
|
||||
"temperature": 0.7,
|
||||
})
|
||||
}
|
||||
|
||||
// Retry loop for context/token errors
|
||||
maxRetries := 2
|
||||
for retry := 0; retry <= maxRetries; retry++ {
|
||||
response, err = al.provider.Chat(ctx, messages, providerToolDefs, al.model, map[string]interface{}{
|
||||
"max_tokens": 8192,
|
||||
"temperature": 0.7,
|
||||
})
|
||||
|
||||
response, err = callLLM()
|
||||
if err == nil {
|
||||
break // Success
|
||||
break
|
||||
}
|
||||
|
||||
errMsg := strings.ToLower(err.Error())
|
||||
// Check for context window errors (provider specific, but usually contain "token" or "invalid")
|
||||
isContextError := strings.Contains(errMsg, "token") ||
|
||||
strings.Contains(errMsg, "context") ||
|
||||
strings.Contains(errMsg, "invalidparameter") ||
|
||||
@@ -487,107 +533,30 @@ func (al *AgentLoop) runLLMIteration(ctx context.Context, messages []providers.M
|
||||
"retry": retry,
|
||||
})
|
||||
|
||||
// Notify user on first retry only
|
||||
if retry == 0 && !constants.IsInternalChannel(opts.Channel) && opts.SendResponse {
|
||||
if retry == 0 && !constants.IsInternalChannel(opts.Channel) {
|
||||
al.bus.PublishOutbound(bus.OutboundMessage{
|
||||
Channel: opts.Channel,
|
||||
ChatID: opts.ChatID,
|
||||
Content: "⚠️ Context window exceeded. Compressing history and retrying...",
|
||||
Content: "Context window exceeded. Compressing history and retrying...",
|
||||
})
|
||||
}
|
||||
|
||||
// Force compression
|
||||
al.forceCompression(opts.SessionKey)
|
||||
|
||||
// Rebuild messages with compressed history
|
||||
// Note: We need to reload history from session manager because forceCompression changed it
|
||||
newHistory := al.sessions.GetHistory(opts.SessionKey)
|
||||
newSummary := al.sessions.GetSummary(opts.SessionKey)
|
||||
|
||||
// Re-create messages for the next attempt
|
||||
// We keep the current user message (opts.UserMessage) effectively
|
||||
messages = al.contextBuilder.BuildMessages(
|
||||
newHistory,
|
||||
newSummary,
|
||||
opts.UserMessage,
|
||||
nil,
|
||||
opts.Channel,
|
||||
opts.ChatID,
|
||||
al.forceCompression(agent, opts.SessionKey)
|
||||
newHistory := agent.Sessions.GetHistory(opts.SessionKey)
|
||||
newSummary := agent.Sessions.GetSummary(opts.SessionKey)
|
||||
messages = agent.ContextBuilder.BuildMessages(
|
||||
newHistory, newSummary, "",
|
||||
nil, opts.Channel, opts.ChatID,
|
||||
)
|
||||
|
||||
// Important: If we are in the middle of a tool loop (iteration > 1),
|
||||
// rebuilding messages from session history might duplicate the flow or miss context
|
||||
// if intermediate steps weren't saved correctly.
|
||||
// However, al.sessions.AddFullMessage is called after every tool execution,
|
||||
// so GetHistory should reflect the current state including partial tool execution.
|
||||
// But we need to ensure we don't duplicate the user message which is appended in BuildMessages.
|
||||
// BuildMessages(history...) takes the stored history and appends the *current* user message.
|
||||
// If iteration > 1, the "current user message" was already added to history in step 3 of runAgentLoop.
|
||||
// So if we pass opts.UserMessage again, we might duplicate it?
|
||||
// Actually, step 3 is: al.sessions.AddMessage(opts.SessionKey, "user", opts.UserMessage)
|
||||
// So GetHistory ALREADY contains the user message!
|
||||
|
||||
// CORRECTION:
|
||||
// BuildMessages combines: [System] + [History] + [CurrentMessage]
|
||||
// But Step 3 added CurrentMessage to History.
|
||||
// So if we use GetHistory now, it has the user message.
|
||||
// If we pass opts.UserMessage to BuildMessages, it adds it AGAIN.
|
||||
|
||||
// For retry in the middle of a loop, we should rely on what's in the session.
|
||||
// BUT checking BuildMessages implementation:
|
||||
// It appends history... then appends currentMessage.
|
||||
|
||||
// Logic fix for retry:
|
||||
// If iteration == 1, opts.UserMessage corresponds to the user input.
|
||||
// If iteration > 1, we are processing tool results. The "messages" passed to Chat
|
||||
// already accumulated tool outputs.
|
||||
// Rebuilding from session history is safest because it persists state.
|
||||
// Start fresh with rebuilt history.
|
||||
|
||||
// Special case: standard BuildMessages appends "currentMessage".
|
||||
// If we are strictly retrying the *LLM call*, we want the exact same state as before but compressed.
|
||||
// However, the "messages" argument passed to runLLMIteration is constructed by the caller.
|
||||
// If we rebuild from Session, we need to know if "currentMessage" should be appended or is already in history.
|
||||
|
||||
// In runAgentLoop:
|
||||
// 3. sessions.AddMessage(userMsg)
|
||||
// 4. runLLMIteration(..., UserMessage)
|
||||
|
||||
// So History contains the user message.
|
||||
// BuildMessages typically appends the user message as a *new* pending message.
|
||||
// Wait, standard BuildMessages usage in runAgentLoop:
|
||||
// messages := BuildMessages(history (has old), UserMessage)
|
||||
// THEN AddMessage(UserMessage).
|
||||
// So "history" passed to BuildMessages does NOT contain the current UserMessage yet.
|
||||
|
||||
// But here, inside the loop, we have already saved it.
|
||||
// So GetHistory() includes the current user message.
|
||||
// If we call BuildMessages(GetHistory(), UserMessage), we get duplicates.
|
||||
|
||||
// Hack/Fix:
|
||||
// If we are retrying, we rebuild from Session History ONLY.
|
||||
// We pass empty string as "currentMessage" to BuildMessages
|
||||
// because the "current message" is already saved in history (step 3).
|
||||
|
||||
messages = al.contextBuilder.BuildMessages(
|
||||
newHistory,
|
||||
newSummary,
|
||||
"", // Empty because history already contains the relevant messages
|
||||
nil,
|
||||
opts.Channel,
|
||||
opts.ChatID,
|
||||
)
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
// Real error or success, break loop
|
||||
break
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
logger.ErrorCF("agent", "LLM call failed",
|
||||
map[string]interface{}{
|
||||
"agent_id": agent.ID,
|
||||
"iteration": iteration,
|
||||
"error": err.Error(),
|
||||
})
|
||||
@@ -599,6 +568,7 @@ func (al *AgentLoop) runLLMIteration(ctx context.Context, messages []providers.M
|
||||
finalContent = response.Content
|
||||
logger.InfoCF("agent", "LLM response without tool calls (direct answer)",
|
||||
map[string]interface{}{
|
||||
"agent_id": agent.ID,
|
||||
"iteration": iteration,
|
||||
"content_chars": len(finalContent),
|
||||
})
|
||||
@@ -612,6 +582,7 @@ func (al *AgentLoop) runLLMIteration(ctx context.Context, messages []providers.M
|
||||
}
|
||||
logger.InfoCF("agent", "LLM requested tool calls",
|
||||
map[string]interface{}{
|
||||
"agent_id": agent.ID,
|
||||
"tools": toolNames,
|
||||
"count": len(response.ToolCalls),
|
||||
"iteration": iteration,
|
||||
@@ -636,15 +607,15 @@ func (al *AgentLoop) runLLMIteration(ctx context.Context, messages []providers.M
|
||||
messages = append(messages, assistantMsg)
|
||||
|
||||
// Save assistant message with tool calls to session
|
||||
al.sessions.AddFullMessage(opts.SessionKey, assistantMsg)
|
||||
agent.Sessions.AddFullMessage(opts.SessionKey, assistantMsg)
|
||||
|
||||
// Execute tool calls
|
||||
for _, tc := range response.ToolCalls {
|
||||
// Log tool call with arguments preview
|
||||
argsJSON, _ := json.Marshal(tc.Arguments)
|
||||
argsPreview := utils.Truncate(string(argsJSON), 200)
|
||||
logger.InfoCF("agent", fmt.Sprintf("Tool call: %s(%s)", tc.Name, argsPreview),
|
||||
map[string]interface{}{
|
||||
"agent_id": agent.ID,
|
||||
"tool": tc.Name,
|
||||
"iteration": iteration,
|
||||
})
|
||||
@@ -665,7 +636,7 @@ func (al *AgentLoop) runLLMIteration(ctx context.Context, messages []providers.M
|
||||
}
|
||||
}
|
||||
|
||||
toolResult := al.tools.ExecuteWithContext(ctx, tc.Name, tc.Arguments, opts.Channel, opts.ChatID, asyncCallback)
|
||||
toolResult := agent.Tools.ExecuteWithContext(ctx, tc.Name, tc.Arguments, opts.Channel, opts.ChatID, asyncCallback)
|
||||
|
||||
// Send ForUser content to user immediately if not Silent
|
||||
if !toolResult.Silent && toolResult.ForUser != "" && opts.SendResponse {
|
||||
@@ -695,7 +666,7 @@ func (al *AgentLoop) runLLMIteration(ctx context.Context, messages []providers.M
|
||||
messages = append(messages, toolResultMsg)
|
||||
|
||||
// Save tool result message to session
|
||||
al.sessions.AddFullMessage(opts.SessionKey, toolResultMsg)
|
||||
agent.Sessions.AddFullMessage(opts.SessionKey, toolResultMsg)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -703,19 +674,19 @@ func (al *AgentLoop) runLLMIteration(ctx context.Context, messages []providers.M
|
||||
}
|
||||
|
||||
// updateToolContexts updates the context for tools that need channel/chatID info.
|
||||
func (al *AgentLoop) updateToolContexts(channel, chatID string) {
|
||||
func (al *AgentLoop) updateToolContexts(agent *AgentInstance, channel, chatID string) {
|
||||
// Use ContextualTool interface instead of type assertions
|
||||
if tool, ok := al.tools.Get("message"); ok {
|
||||
if tool, ok := agent.Tools.Get("message"); ok {
|
||||
if mt, ok := tool.(tools.ContextualTool); ok {
|
||||
mt.SetContext(channel, chatID)
|
||||
}
|
||||
}
|
||||
if tool, ok := al.tools.Get("spawn"); ok {
|
||||
if tool, ok := agent.Tools.Get("spawn"); ok {
|
||||
if st, ok := tool.(tools.ContextualTool); ok {
|
||||
st.SetContext(channel, chatID)
|
||||
}
|
||||
}
|
||||
if tool, ok := al.tools.Get("subagent"); ok {
|
||||
if tool, ok := agent.Tools.Get("subagent"); ok {
|
||||
if st, ok := tool.(tools.ContextualTool); ok {
|
||||
st.SetContext(channel, chatID)
|
||||
}
|
||||
@@ -723,24 +694,24 @@ func (al *AgentLoop) updateToolContexts(channel, chatID string) {
|
||||
}
|
||||
|
||||
// maybeSummarize triggers summarization if the session history exceeds thresholds.
|
||||
func (al *AgentLoop) maybeSummarize(sessionKey, channel, chatID string) {
|
||||
newHistory := al.sessions.GetHistory(sessionKey)
|
||||
func (al *AgentLoop) maybeSummarize(agent *AgentInstance, sessionKey, channel, chatID string) {
|
||||
newHistory := agent.Sessions.GetHistory(sessionKey)
|
||||
tokenEstimate := al.estimateTokens(newHistory)
|
||||
threshold := al.contextWindow * 75 / 100
|
||||
threshold := agent.ContextWindow * 75 / 100
|
||||
|
||||
if len(newHistory) > 20 || tokenEstimate > threshold {
|
||||
if _, loading := al.summarizing.LoadOrStore(sessionKey, true); !loading {
|
||||
summarizeKey := agent.ID + ":" + sessionKey
|
||||
if _, loading := al.summarizing.LoadOrStore(summarizeKey, true); !loading {
|
||||
go func() {
|
||||
defer al.summarizing.Delete(sessionKey)
|
||||
// Notify user about optimization if not an internal channel
|
||||
defer al.summarizing.Delete(summarizeKey)
|
||||
if !constants.IsInternalChannel(channel) {
|
||||
al.bus.PublishOutbound(bus.OutboundMessage{
|
||||
Channel: channel,
|
||||
ChatID: chatID,
|
||||
Content: "⚠️ Memory threshold reached. Optimizing conversation history...",
|
||||
Content: "Memory threshold reached. Optimizing conversation history...",
|
||||
})
|
||||
}
|
||||
al.summarizeSession(sessionKey)
|
||||
al.summarizeSession(agent, sessionKey)
|
||||
}()
|
||||
}
|
||||
}
|
||||
@@ -748,8 +719,8 @@ func (al *AgentLoop) maybeSummarize(sessionKey, channel, chatID string) {
|
||||
|
||||
// forceCompression aggressively reduces context when the limit is hit.
|
||||
// It drops the oldest 50% of messages (keeping system prompt and last user message).
|
||||
func (al *AgentLoop) forceCompression(sessionKey string) {
|
||||
history := al.sessions.GetHistory(sessionKey)
|
||||
func (al *AgentLoop) forceCompression(agent *AgentInstance, sessionKey string) {
|
||||
history := agent.Sessions.GetHistory(sessionKey)
|
||||
if len(history) <= 4 {
|
||||
return
|
||||
}
|
||||
@@ -796,8 +767,8 @@ func (al *AgentLoop) forceCompression(sessionKey string) {
|
||||
newHistory = append(newHistory, history[len(history)-1]) // Last message
|
||||
|
||||
// Update session
|
||||
al.sessions.SetHistory(sessionKey, newHistory)
|
||||
al.sessions.Save(sessionKey)
|
||||
agent.Sessions.SetHistory(sessionKey, newHistory)
|
||||
agent.Sessions.Save(sessionKey)
|
||||
|
||||
logger.WarnCF("agent", "Forced compression executed", map[string]interface{}{
|
||||
"session_key": sessionKey,
|
||||
@@ -810,15 +781,26 @@ func (al *AgentLoop) forceCompression(sessionKey string) {
|
||||
func (al *AgentLoop) GetStartupInfo() map[string]interface{} {
|
||||
info := make(map[string]interface{})
|
||||
|
||||
agent := al.registry.GetDefaultAgent()
|
||||
if agent == nil {
|
||||
return info
|
||||
}
|
||||
|
||||
// Tools info
|
||||
tools := al.tools.List()
|
||||
toolsList := agent.Tools.List()
|
||||
info["tools"] = map[string]interface{}{
|
||||
"count": len(tools),
|
||||
"names": tools,
|
||||
"count": len(toolsList),
|
||||
"names": toolsList,
|
||||
}
|
||||
|
||||
// Skills info
|
||||
info["skills"] = al.contextBuilder.GetSkillsInfo()
|
||||
info["skills"] = agent.ContextBuilder.GetSkillsInfo()
|
||||
|
||||
// Agents info
|
||||
info["agents"] = map[string]interface{}{
|
||||
"count": len(al.registry.ListAgentIDs()),
|
||||
"ids": al.registry.ListAgentIDs(),
|
||||
}
|
||||
|
||||
return info
|
||||
}
|
||||
@@ -875,12 +857,12 @@ func formatToolsForLog(tools []providers.ToolDefinition) string {
|
||||
}
|
||||
|
||||
// summarizeSession summarizes the conversation history for a session.
|
||||
func (al *AgentLoop) summarizeSession(sessionKey string) {
|
||||
func (al *AgentLoop) summarizeSession(agent *AgentInstance, sessionKey string) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
||||
defer cancel()
|
||||
|
||||
history := al.sessions.GetHistory(sessionKey)
|
||||
summary := al.sessions.GetSummary(sessionKey)
|
||||
history := agent.Sessions.GetHistory(sessionKey)
|
||||
summary := agent.Sessions.GetSummary(sessionKey)
|
||||
|
||||
// Keep last 4 messages for continuity
|
||||
if len(history) <= 4 {
|
||||
@@ -890,8 +872,7 @@ func (al *AgentLoop) summarizeSession(sessionKey string) {
|
||||
toSummarize := history[:len(history)-4]
|
||||
|
||||
// Oversized Message Guard
|
||||
// Skip messages larger than 50% of context window to prevent summarizer overflow
|
||||
maxMessageTokens := al.contextWindow / 2
|
||||
maxMessageTokens := agent.ContextWindow / 2
|
||||
validMessages := make([]providers.Message, 0)
|
||||
omitted := false
|
||||
|
||||
@@ -899,8 +880,7 @@ func (al *AgentLoop) summarizeSession(sessionKey string) {
|
||||
if m.Role != "user" && m.Role != "assistant" {
|
||||
continue
|
||||
}
|
||||
// Estimate tokens for this message
|
||||
msgTokens := len(m.Content) / 2 // Use safer estimate here too (2.5 -> 2 for integer division safety)
|
||||
msgTokens := len(m.Content) / 2
|
||||
if msgTokens > maxMessageTokens {
|
||||
omitted = true
|
||||
continue
|
||||
@@ -913,19 +893,17 @@ func (al *AgentLoop) summarizeSession(sessionKey string) {
|
||||
}
|
||||
|
||||
// Multi-Part Summarization
|
||||
// Split into two parts if history is significant
|
||||
var finalSummary string
|
||||
if len(validMessages) > 10 {
|
||||
mid := len(validMessages) / 2
|
||||
part1 := validMessages[:mid]
|
||||
part2 := validMessages[mid:]
|
||||
|
||||
s1, _ := al.summarizeBatch(ctx, part1, "")
|
||||
s2, _ := al.summarizeBatch(ctx, part2, "")
|
||||
s1, _ := al.summarizeBatch(ctx, agent, part1, "")
|
||||
s2, _ := al.summarizeBatch(ctx, agent, part2, "")
|
||||
|
||||
// Merge them
|
||||
mergePrompt := fmt.Sprintf("Merge these two conversation summaries into one cohesive summary:\n\n1: %s\n\n2: %s", s1, s2)
|
||||
resp, err := al.provider.Chat(ctx, []providers.Message{{Role: "user", Content: mergePrompt}}, nil, al.model, map[string]interface{}{
|
||||
resp, err := agent.Provider.Chat(ctx, []providers.Message{{Role: "user", Content: mergePrompt}}, nil, agent.Model, map[string]interface{}{
|
||||
"max_tokens": 1024,
|
||||
"temperature": 0.3,
|
||||
})
|
||||
@@ -935,7 +913,7 @@ func (al *AgentLoop) summarizeSession(sessionKey string) {
|
||||
finalSummary = s1 + " " + s2
|
||||
}
|
||||
} else {
|
||||
finalSummary, _ = al.summarizeBatch(ctx, validMessages, summary)
|
||||
finalSummary, _ = al.summarizeBatch(ctx, agent, validMessages, summary)
|
||||
}
|
||||
|
||||
if omitted && finalSummary != "" {
|
||||
@@ -943,14 +921,14 @@ func (al *AgentLoop) summarizeSession(sessionKey string) {
|
||||
}
|
||||
|
||||
if finalSummary != "" {
|
||||
al.sessions.SetSummary(sessionKey, finalSummary)
|
||||
al.sessions.TruncateHistory(sessionKey, 4)
|
||||
al.sessions.Save(sessionKey)
|
||||
agent.Sessions.SetSummary(sessionKey, finalSummary)
|
||||
agent.Sessions.TruncateHistory(sessionKey, 4)
|
||||
agent.Sessions.Save(sessionKey)
|
||||
}
|
||||
}
|
||||
|
||||
// summarizeBatch summarizes a batch of messages.
|
||||
func (al *AgentLoop) summarizeBatch(ctx context.Context, batch []providers.Message, existingSummary string) (string, error) {
|
||||
func (al *AgentLoop) summarizeBatch(ctx context.Context, agent *AgentInstance, batch []providers.Message, existingSummary string) (string, error) {
|
||||
prompt := "Provide a concise summary of this conversation segment, preserving core context and key points.\n"
|
||||
if existingSummary != "" {
|
||||
prompt += "Existing context: " + existingSummary + "\n"
|
||||
@@ -960,7 +938,7 @@ func (al *AgentLoop) summarizeBatch(ctx context.Context, batch []providers.Messa
|
||||
prompt += fmt.Sprintf("%s: %s\n", m.Role, m.Content)
|
||||
}
|
||||
|
||||
response, err := al.provider.Chat(ctx, []providers.Message{{Role: "user", Content: prompt}}, nil, al.model, map[string]interface{}{
|
||||
response, err := agent.Provider.Chat(ctx, []providers.Message{{Role: "user", Content: prompt}}, nil, agent.Model, map[string]interface{}{
|
||||
"max_tokens": 1024,
|
||||
"temperature": 0.3,
|
||||
})
|
||||
@@ -999,25 +977,31 @@ func (al *AgentLoop) handleCommand(ctx context.Context, msg bus.InboundMessage)
|
||||
switch cmd {
|
||||
case "/show":
|
||||
if len(args) < 1 {
|
||||
return "Usage: /show [model|channel]", true
|
||||
return "Usage: /show [model|channel|agents]", true
|
||||
}
|
||||
switch args[0] {
|
||||
case "model":
|
||||
return fmt.Sprintf("Current model: %s", al.model), true
|
||||
defaultAgent := al.registry.GetDefaultAgent()
|
||||
if defaultAgent == nil {
|
||||
return "No default agent configured", true
|
||||
}
|
||||
return fmt.Sprintf("Current model: %s", defaultAgent.Model), true
|
||||
case "channel":
|
||||
return fmt.Sprintf("Current channel: %s", msg.Channel), true
|
||||
case "agents":
|
||||
agentIDs := al.registry.ListAgentIDs()
|
||||
return fmt.Sprintf("Registered agents: %s", strings.Join(agentIDs, ", ")), true
|
||||
default:
|
||||
return fmt.Sprintf("Unknown show target: %s", args[0]), true
|
||||
}
|
||||
|
||||
case "/list":
|
||||
if len(args) < 1 {
|
||||
return "Usage: /list [models|channels]", true
|
||||
return "Usage: /list [models|channels|agents]", true
|
||||
}
|
||||
switch args[0] {
|
||||
case "models":
|
||||
// TODO: Fetch available models dynamically if possible
|
||||
return "Available models: glm-4.7, claude-3-5-sonnet, gpt-4o (configured in config.json/env)", true
|
||||
return "Available models: configured in config.json per agent", true
|
||||
case "channels":
|
||||
if al.channelManager == nil {
|
||||
return "Channel manager not initialized", true
|
||||
@@ -1027,6 +1011,9 @@ func (al *AgentLoop) handleCommand(ctx context.Context, msg bus.InboundMessage)
|
||||
return "No channels enabled", true
|
||||
}
|
||||
return fmt.Sprintf("Enabled channels: %s", strings.Join(channels, ", ")), true
|
||||
case "agents":
|
||||
agentIDs := al.registry.ListAgentIDs()
|
||||
return fmt.Sprintf("Registered agents: %s", strings.Join(agentIDs, ", ")), true
|
||||
default:
|
||||
return fmt.Sprintf("Unknown list target: %s", args[0]), true
|
||||
}
|
||||
@@ -1040,23 +1027,21 @@ func (al *AgentLoop) handleCommand(ctx context.Context, msg bus.InboundMessage)
|
||||
|
||||
switch target {
|
||||
case "model":
|
||||
oldModel := al.model
|
||||
al.model = value
|
||||
defaultAgent := al.registry.GetDefaultAgent()
|
||||
if defaultAgent == nil {
|
||||
return "No default agent configured", true
|
||||
}
|
||||
oldModel := defaultAgent.Model
|
||||
defaultAgent.Model = value
|
||||
return fmt.Sprintf("Switched model from %s to %s", oldModel, value), true
|
||||
case "channel":
|
||||
// This changes the 'default' channel for some operations, or effectively redirects output?
|
||||
// For now, let's just validate if the channel exists
|
||||
if al.channelManager == nil {
|
||||
return "Channel manager not initialized", true
|
||||
}
|
||||
if _, exists := al.channelManager.GetChannel(value); !exists && value != "cli" {
|
||||
return fmt.Sprintf("Channel '%s' not found or not enabled", value), true
|
||||
}
|
||||
|
||||
// If message came from CLI, maybe we want to redirect CLI output to this channel?
|
||||
// That would require state persistence about "redirected channel"
|
||||
// For now, just acknowledged.
|
||||
return fmt.Sprintf("Switched target channel to %s (Note: this currently only validates existence)", value), true
|
||||
return fmt.Sprintf("Switched target channel to %s", value), true
|
||||
default:
|
||||
return fmt.Sprintf("Unknown switch target: %s", target), true
|
||||
}
|
||||
@@ -1064,3 +1049,30 @@ func (al *AgentLoop) handleCommand(ctx context.Context, msg bus.InboundMessage)
|
||||
|
||||
return "", false
|
||||
}
|
||||
|
||||
// extractPeer extracts the routing peer from inbound message metadata.
|
||||
func extractPeer(msg bus.InboundMessage) *routing.RoutePeer {
|
||||
peerKind := msg.Metadata["peer_kind"]
|
||||
if peerKind == "" {
|
||||
return nil
|
||||
}
|
||||
peerID := msg.Metadata["peer_id"]
|
||||
if peerID == "" {
|
||||
if peerKind == "direct" {
|
||||
peerID = msg.SenderID
|
||||
} else {
|
||||
peerID = msg.ChatID
|
||||
}
|
||||
}
|
||||
return &routing.RoutePeer{Kind: peerKind, ID: peerID}
|
||||
}
|
||||
|
||||
// extractParentPeer extracts the parent peer (reply-to) from inbound message metadata.
|
||||
func extractParentPeer(msg bus.InboundMessage) *routing.RoutePeer {
|
||||
parentKind := msg.Metadata["parent_peer_kind"]
|
||||
parentID := msg.Metadata["parent_peer_id"]
|
||||
if parentKind == "" || parentID == "" {
|
||||
return nil
|
||||
}
|
||||
return &routing.RoutePeer{Kind: parentKind, ID: parentID}
|
||||
}
|
||||
|
||||
@@ -594,7 +594,11 @@ func TestAgentLoop_ContextExhaustionRetry(t *testing.T) {
|
||||
{Role: "assistant", Content: "Old response 2"},
|
||||
{Role: "user", Content: "Trigger message"},
|
||||
}
|
||||
al.sessions.SetHistory(sessionKey, history)
|
||||
defaultAgent := al.registry.GetDefaultAgent()
|
||||
if defaultAgent == nil {
|
||||
t.Fatal("No default agent found")
|
||||
}
|
||||
defaultAgent.Sessions.SetHistory(sessionKey, history)
|
||||
|
||||
// Call ProcessDirectWithChannel
|
||||
// Note: ProcessDirectWithChannel calls processMessage which will execute runLLMIteration
|
||||
@@ -614,7 +618,7 @@ func TestAgentLoop_ContextExhaustionRetry(t *testing.T) {
|
||||
}
|
||||
|
||||
// Check final history length
|
||||
finalHistory := al.sessions.GetHistory(sessionKey)
|
||||
finalHistory := defaultAgent.Sessions.GetHistory(sessionKey)
|
||||
// We verify that the history has been modified (compressed)
|
||||
// Original length: 6
|
||||
// Expected behavior: compression drops ~50% of history (mid slice)
|
||||
|
||||
@@ -0,0 +1,114 @@
|
||||
package agent
|
||||
|
||||
import (
|
||||
"sync"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
"github.com/sipeed/picoclaw/pkg/logger"
|
||||
"github.com/sipeed/picoclaw/pkg/providers"
|
||||
"github.com/sipeed/picoclaw/pkg/routing"
|
||||
)
|
||||
|
||||
// AgentRegistry manages multiple agent instances and routes messages to them.
|
||||
type AgentRegistry struct {
|
||||
agents map[string]*AgentInstance
|
||||
resolver *routing.RouteResolver
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
// NewAgentRegistry creates a registry from config, instantiating all agents.
|
||||
func NewAgentRegistry(
|
||||
cfg *config.Config,
|
||||
provider providers.LLMProvider,
|
||||
) *AgentRegistry {
|
||||
registry := &AgentRegistry{
|
||||
agents: make(map[string]*AgentInstance),
|
||||
resolver: routing.NewRouteResolver(cfg),
|
||||
}
|
||||
|
||||
agentConfigs := cfg.Agents.List
|
||||
if len(agentConfigs) == 0 {
|
||||
implicitAgent := &config.AgentConfig{
|
||||
ID: "main",
|
||||
Default: true,
|
||||
}
|
||||
instance := NewAgentInstance(implicitAgent, &cfg.Agents.Defaults, cfg, provider)
|
||||
registry.agents["main"] = instance
|
||||
logger.InfoCF("agent", "Created implicit main agent (no agents.list configured)", nil)
|
||||
} else {
|
||||
for i := range agentConfigs {
|
||||
ac := &agentConfigs[i]
|
||||
id := routing.NormalizeAgentID(ac.ID)
|
||||
instance := NewAgentInstance(ac, &cfg.Agents.Defaults, cfg, provider)
|
||||
registry.agents[id] = instance
|
||||
logger.InfoCF("agent", "Registered agent",
|
||||
map[string]interface{}{
|
||||
"agent_id": id,
|
||||
"name": ac.Name,
|
||||
"workspace": instance.Workspace,
|
||||
"model": instance.Model,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
return registry
|
||||
}
|
||||
|
||||
// GetAgent returns the agent instance for a given ID.
|
||||
func (r *AgentRegistry) GetAgent(agentID string) (*AgentInstance, bool) {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
id := routing.NormalizeAgentID(agentID)
|
||||
agent, ok := r.agents[id]
|
||||
return agent, ok
|
||||
}
|
||||
|
||||
// ResolveRoute determines which agent handles the message.
|
||||
func (r *AgentRegistry) ResolveRoute(input routing.RouteInput) routing.ResolvedRoute {
|
||||
return r.resolver.ResolveRoute(input)
|
||||
}
|
||||
|
||||
// ListAgentIDs returns all registered agent IDs.
|
||||
func (r *AgentRegistry) ListAgentIDs() []string {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
ids := make([]string, 0, len(r.agents))
|
||||
for id := range r.agents {
|
||||
ids = append(ids, id)
|
||||
}
|
||||
return ids
|
||||
}
|
||||
|
||||
// CanSpawnSubagent checks if parentAgentID is allowed to spawn targetAgentID.
|
||||
func (r *AgentRegistry) CanSpawnSubagent(parentAgentID, targetAgentID string) bool {
|
||||
parent, ok := r.GetAgent(parentAgentID)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
if parent.Subagents == nil || parent.Subagents.AllowAgents == nil {
|
||||
return false
|
||||
}
|
||||
targetNorm := routing.NormalizeAgentID(targetAgentID)
|
||||
for _, allowed := range parent.Subagents.AllowAgents {
|
||||
if allowed == "*" {
|
||||
return true
|
||||
}
|
||||
if routing.NormalizeAgentID(allowed) == targetNorm {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// GetDefaultAgent returns the default agent instance.
|
||||
func (r *AgentRegistry) GetDefaultAgent() *AgentInstance {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
if agent, ok := r.agents["main"]; ok {
|
||||
return agent
|
||||
}
|
||||
for _, agent := range r.agents {
|
||||
return agent
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,199 @@
|
||||
package agent
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
"github.com/sipeed/picoclaw/pkg/providers"
|
||||
)
|
||||
|
||||
type mockRegistryProvider struct{}
|
||||
|
||||
func (m *mockRegistryProvider) Chat(ctx context.Context, messages []providers.Message, tools []providers.ToolDefinition, model string, options map[string]interface{}) (*providers.LLMResponse, error) {
|
||||
return &providers.LLMResponse{Content: "mock", FinishReason: "stop"}, nil
|
||||
}
|
||||
|
||||
func (m *mockRegistryProvider) GetDefaultModel() string {
|
||||
return "mock-model"
|
||||
}
|
||||
|
||||
func testCfg(agents []config.AgentConfig) *config.Config {
|
||||
return &config.Config{
|
||||
Agents: config.AgentsConfig{
|
||||
Defaults: config.AgentDefaults{
|
||||
Workspace: "/tmp/picoclaw-test-registry",
|
||||
Model: "gpt-4",
|
||||
MaxTokens: 8192,
|
||||
MaxToolIterations: 10,
|
||||
},
|
||||
List: agents,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewAgentRegistry_ImplicitMain(t *testing.T) {
|
||||
cfg := testCfg(nil)
|
||||
registry := NewAgentRegistry(cfg, &mockRegistryProvider{})
|
||||
|
||||
ids := registry.ListAgentIDs()
|
||||
if len(ids) != 1 || ids[0] != "main" {
|
||||
t.Errorf("expected implicit main agent, got %v", ids)
|
||||
}
|
||||
|
||||
agent, ok := registry.GetAgent("main")
|
||||
if !ok || agent == nil {
|
||||
t.Fatal("expected to find 'main' agent")
|
||||
}
|
||||
if agent.ID != "main" {
|
||||
t.Errorf("agent.ID = %q, want 'main'", agent.ID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewAgentRegistry_ExplicitAgents(t *testing.T) {
|
||||
cfg := testCfg([]config.AgentConfig{
|
||||
{ID: "sales", Default: true, Name: "Sales Bot"},
|
||||
{ID: "support", Name: "Support Bot"},
|
||||
})
|
||||
registry := NewAgentRegistry(cfg, &mockRegistryProvider{})
|
||||
|
||||
ids := registry.ListAgentIDs()
|
||||
if len(ids) != 2 {
|
||||
t.Fatalf("expected 2 agents, got %d: %v", len(ids), ids)
|
||||
}
|
||||
|
||||
sales, ok := registry.GetAgent("sales")
|
||||
if !ok || sales == nil {
|
||||
t.Fatal("expected to find 'sales' agent")
|
||||
}
|
||||
if sales.Name != "Sales Bot" {
|
||||
t.Errorf("sales.Name = %q, want 'Sales Bot'", sales.Name)
|
||||
}
|
||||
|
||||
support, ok := registry.GetAgent("support")
|
||||
if !ok || support == nil {
|
||||
t.Fatal("expected to find 'support' agent")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAgentRegistry_GetAgent_Normalize(t *testing.T) {
|
||||
cfg := testCfg([]config.AgentConfig{
|
||||
{ID: "my-agent", Default: true},
|
||||
})
|
||||
registry := NewAgentRegistry(cfg, &mockRegistryProvider{})
|
||||
|
||||
agent, ok := registry.GetAgent("My-Agent")
|
||||
if !ok || agent == nil {
|
||||
t.Fatal("expected to find agent with normalized ID")
|
||||
}
|
||||
if agent.ID != "my-agent" {
|
||||
t.Errorf("agent.ID = %q, want 'my-agent'", agent.ID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAgentRegistry_GetDefaultAgent(t *testing.T) {
|
||||
cfg := testCfg([]config.AgentConfig{
|
||||
{ID: "alpha"},
|
||||
{ID: "beta", Default: true},
|
||||
})
|
||||
registry := NewAgentRegistry(cfg, &mockRegistryProvider{})
|
||||
|
||||
// GetDefaultAgent first checks for "main", then returns any
|
||||
agent := registry.GetDefaultAgent()
|
||||
if agent == nil {
|
||||
t.Fatal("expected a default agent")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAgentRegistry_CanSpawnSubagent(t *testing.T) {
|
||||
cfg := testCfg([]config.AgentConfig{
|
||||
{
|
||||
ID: "parent",
|
||||
Default: true,
|
||||
Subagents: &config.SubagentsConfig{
|
||||
AllowAgents: []string{"child1", "child2"},
|
||||
},
|
||||
},
|
||||
{ID: "child1"},
|
||||
{ID: "child2"},
|
||||
{ID: "restricted"},
|
||||
})
|
||||
registry := NewAgentRegistry(cfg, &mockRegistryProvider{})
|
||||
|
||||
if !registry.CanSpawnSubagent("parent", "child1") {
|
||||
t.Error("expected parent to be allowed to spawn child1")
|
||||
}
|
||||
if !registry.CanSpawnSubagent("parent", "child2") {
|
||||
t.Error("expected parent to be allowed to spawn child2")
|
||||
}
|
||||
if registry.CanSpawnSubagent("parent", "restricted") {
|
||||
t.Error("expected parent to NOT be allowed to spawn restricted")
|
||||
}
|
||||
if registry.CanSpawnSubagent("child1", "child2") {
|
||||
t.Error("expected child1 to NOT be allowed to spawn (no subagents config)")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAgentRegistry_CanSpawnSubagent_Wildcard(t *testing.T) {
|
||||
cfg := testCfg([]config.AgentConfig{
|
||||
{
|
||||
ID: "admin",
|
||||
Default: true,
|
||||
Subagents: &config.SubagentsConfig{
|
||||
AllowAgents: []string{"*"},
|
||||
},
|
||||
},
|
||||
{ID: "any-agent"},
|
||||
})
|
||||
registry := NewAgentRegistry(cfg, &mockRegistryProvider{})
|
||||
|
||||
if !registry.CanSpawnSubagent("admin", "any-agent") {
|
||||
t.Error("expected wildcard to allow spawning any agent")
|
||||
}
|
||||
if !registry.CanSpawnSubagent("admin", "nonexistent") {
|
||||
t.Error("expected wildcard to allow spawning even nonexistent agents")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAgentInstance_Model(t *testing.T) {
|
||||
model := &config.AgentModelConfig{Primary: "claude-opus"}
|
||||
cfg := testCfg([]config.AgentConfig{
|
||||
{ID: "custom", Default: true, Model: model},
|
||||
})
|
||||
registry := NewAgentRegistry(cfg, &mockRegistryProvider{})
|
||||
|
||||
agent, _ := registry.GetAgent("custom")
|
||||
if agent.Model != "claude-opus" {
|
||||
t.Errorf("agent.Model = %q, want 'claude-opus'", agent.Model)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAgentInstance_FallbackInheritance(t *testing.T) {
|
||||
cfg := testCfg([]config.AgentConfig{
|
||||
{ID: "inherit", Default: true},
|
||||
})
|
||||
cfg.Agents.Defaults.ModelFallbacks = []string{"openai/gpt-4o-mini", "anthropic/haiku"}
|
||||
registry := NewAgentRegistry(cfg, &mockRegistryProvider{})
|
||||
|
||||
agent, _ := registry.GetAgent("inherit")
|
||||
if len(agent.Fallbacks) != 2 {
|
||||
t.Errorf("expected 2 fallbacks inherited from defaults, got %d", len(agent.Fallbacks))
|
||||
}
|
||||
}
|
||||
|
||||
func TestAgentInstance_FallbackExplicitEmpty(t *testing.T) {
|
||||
model := &config.AgentModelConfig{
|
||||
Primary: "gpt-4",
|
||||
Fallbacks: []string{}, // explicitly empty = disable
|
||||
}
|
||||
cfg := testCfg([]config.AgentConfig{
|
||||
{ID: "no-fallback", Default: true, Model: model},
|
||||
})
|
||||
cfg.Agents.Defaults.ModelFallbacks = []string{"should-not-inherit"}
|
||||
registry := NewAgentRegistry(cfg, &mockRegistryProvider{})
|
||||
|
||||
agent, _ := registry.GetAgent("no-fallback")
|
||||
if len(agent.Fallbacks) != 0 {
|
||||
t.Errorf("expected 0 fallbacks (explicit empty), got %d: %v", len(agent.Fallbacks), agent.Fallbacks)
|
||||
}
|
||||
}
|
||||
+6
-11
@@ -2,7 +2,6 @@ package channels
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/bus"
|
||||
@@ -87,17 +86,13 @@ func (c *BaseChannel) HandleMessage(senderID, chatID, content string, media []st
|
||||
return
|
||||
}
|
||||
|
||||
// Build session key: channel:chatID
|
||||
sessionKey := fmt.Sprintf("%s:%s", c.name, chatID)
|
||||
|
||||
msg := bus.InboundMessage{
|
||||
Channel: c.name,
|
||||
SenderID: senderID,
|
||||
ChatID: chatID,
|
||||
Content: content,
|
||||
Media: media,
|
||||
SessionKey: sessionKey,
|
||||
Metadata: metadata,
|
||||
Channel: c.name,
|
||||
SenderID: senderID,
|
||||
ChatID: chatID,
|
||||
Content: content,
|
||||
Media: media,
|
||||
Metadata: metadata,
|
||||
}
|
||||
|
||||
c.bus.PublishInbound(msg)
|
||||
|
||||
@@ -376,6 +376,13 @@ func (c *DiscordChannel) handleMessage(s *discordgo.Session, m *discordgo.Messag
|
||||
"preview": utils.Truncate(content, 50),
|
||||
})
|
||||
|
||||
peerKind := "channel"
|
||||
peerID := m.ChannelID
|
||||
if m.GuildID == "" {
|
||||
peerKind = "direct"
|
||||
peerID = senderID
|
||||
}
|
||||
|
||||
metadata := map[string]string{
|
||||
"message_id": m.ID,
|
||||
"user_id": senderID,
|
||||
@@ -384,6 +391,8 @@ func (c *DiscordChannel) handleMessage(s *discordgo.Session, m *discordgo.Messag
|
||||
"guild_id": m.GuildID,
|
||||
"channel_id": m.ChannelID,
|
||||
"is_dm": fmt.Sprintf("%t", m.GuildID == ""),
|
||||
"peer_kind": peerKind,
|
||||
"peer_id": peerID,
|
||||
}
|
||||
|
||||
c.HandleMessage(senderID, m.ChannelID, content, mediaPaths, metadata)
|
||||
|
||||
@@ -25,6 +25,7 @@ type SlackChannel struct {
|
||||
api *slack.Client
|
||||
socketClient *socketmode.Client
|
||||
botUserID string
|
||||
teamID string
|
||||
transcriber *voice.GroqTranscriber
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
@@ -72,6 +73,7 @@ func (c *SlackChannel) Start(ctx context.Context) error {
|
||||
return fmt.Errorf("slack auth test failed: %w", err)
|
||||
}
|
||||
c.botUserID = authResp.UserID
|
||||
c.teamID = authResp.TeamID
|
||||
|
||||
logger.InfoCF("slack", "Slack bot connected", map[string]interface{}{
|
||||
"bot_user_id": c.botUserID,
|
||||
@@ -274,11 +276,21 @@ func (c *SlackChannel) handleMessageEvent(ev *slackevents.MessageEvent) {
|
||||
return
|
||||
}
|
||||
|
||||
peerKind := "channel"
|
||||
peerID := channelID
|
||||
if strings.HasPrefix(channelID, "D") {
|
||||
peerKind = "direct"
|
||||
peerID = senderID
|
||||
}
|
||||
|
||||
metadata := map[string]string{
|
||||
"message_ts": messageTS,
|
||||
"channel_id": channelID,
|
||||
"thread_ts": threadTS,
|
||||
"platform": "slack",
|
||||
"peer_kind": peerKind,
|
||||
"peer_id": peerID,
|
||||
"team_id": c.teamID,
|
||||
}
|
||||
|
||||
logger.DebugCF("slack", "Received message", map[string]interface{}{
|
||||
@@ -331,12 +343,22 @@ func (c *SlackChannel) handleAppMention(ev *slackevents.AppMentionEvent) {
|
||||
return
|
||||
}
|
||||
|
||||
mentionPeerKind := "channel"
|
||||
mentionPeerID := channelID
|
||||
if strings.HasPrefix(channelID, "D") {
|
||||
mentionPeerKind = "direct"
|
||||
mentionPeerID = senderID
|
||||
}
|
||||
|
||||
metadata := map[string]string{
|
||||
"message_ts": messageTS,
|
||||
"channel_id": channelID,
|
||||
"thread_ts": threadTS,
|
||||
"platform": "slack",
|
||||
"is_mention": "true",
|
||||
"peer_kind": mentionPeerKind,
|
||||
"peer_id": mentionPeerID,
|
||||
"team_id": c.teamID,
|
||||
}
|
||||
|
||||
c.HandleMessage(senderID, chatID, content, nil, metadata)
|
||||
@@ -373,6 +395,9 @@ func (c *SlackChannel) handleSlashCommand(event socketmode.Event) {
|
||||
"platform": "slack",
|
||||
"is_command": "true",
|
||||
"trigger_id": cmd.TriggerID,
|
||||
"peer_kind": "channel",
|
||||
"peer_id": channelID,
|
||||
"team_id": c.teamID,
|
||||
}
|
||||
|
||||
logger.DebugCF("slack", "Slash command received", map[string]interface{}{
|
||||
|
||||
@@ -347,12 +347,21 @@ func (c *TelegramChannel) handleMessage(ctx context.Context, message *telego.Mes
|
||||
c.placeholders.Store(chatIDStr, pID)
|
||||
}
|
||||
|
||||
peerKind := "direct"
|
||||
peerID := fmt.Sprintf("%d", user.ID)
|
||||
if message.Chat.Type != "private" {
|
||||
peerKind = "group"
|
||||
peerID = fmt.Sprintf("%d", chatID)
|
||||
}
|
||||
|
||||
metadata := map[string]string{
|
||||
"message_id": fmt.Sprintf("%d", message.MessageID),
|
||||
"user_id": fmt.Sprintf("%d", user.ID),
|
||||
"username": user.Username,
|
||||
"first_name": user.FirstName,
|
||||
"is_group": fmt.Sprintf("%t", message.Chat.Type != "private"),
|
||||
"peer_kind": peerKind,
|
||||
"peer_id": peerID,
|
||||
}
|
||||
|
||||
c.HandleMessage(fmt.Sprintf("%d", user.ID), fmt.Sprintf("%d", chatID), content, mediaPaths, metadata)
|
||||
|
||||
+116
-7
@@ -45,6 +45,8 @@ func (f *FlexibleStringSlice) UnmarshalJSON(data []byte) error {
|
||||
|
||||
type Config struct {
|
||||
Agents AgentsConfig `json:"agents"`
|
||||
Bindings []AgentBinding `json:"bindings,omitempty"`
|
||||
Session SessionConfig `json:"session,omitempty"`
|
||||
Channels ChannelsConfig `json:"channels"`
|
||||
Providers ProvidersConfig `json:"providers"`
|
||||
Gateway GatewayConfig `json:"gateway"`
|
||||
@@ -56,16 +58,97 @@ type Config struct {
|
||||
|
||||
type AgentsConfig struct {
|
||||
Defaults AgentDefaults `json:"defaults"`
|
||||
List []AgentConfig `json:"list,omitempty"`
|
||||
}
|
||||
|
||||
// AgentModelConfig supports both string and structured model config.
|
||||
// String format: "gpt-4" (just primary, no fallbacks)
|
||||
// Object format: {"primary": "gpt-4", "fallbacks": ["claude-haiku"]}
|
||||
type AgentModelConfig struct {
|
||||
Primary string `json:"primary,omitempty"`
|
||||
Fallbacks []string `json:"fallbacks,omitempty"`
|
||||
}
|
||||
|
||||
func (m *AgentModelConfig) UnmarshalJSON(data []byte) error {
|
||||
var s string
|
||||
if err := json.Unmarshal(data, &s); err == nil {
|
||||
m.Primary = s
|
||||
m.Fallbacks = nil
|
||||
return nil
|
||||
}
|
||||
type raw struct {
|
||||
Primary string `json:"primary"`
|
||||
Fallbacks []string `json:"fallbacks"`
|
||||
}
|
||||
var r raw
|
||||
if err := json.Unmarshal(data, &r); err != nil {
|
||||
return err
|
||||
}
|
||||
m.Primary = r.Primary
|
||||
m.Fallbacks = r.Fallbacks
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m AgentModelConfig) MarshalJSON() ([]byte, error) {
|
||||
if len(m.Fallbacks) == 0 && m.Primary != "" {
|
||||
return json.Marshal(m.Primary)
|
||||
}
|
||||
type raw struct {
|
||||
Primary string `json:"primary,omitempty"`
|
||||
Fallbacks []string `json:"fallbacks,omitempty"`
|
||||
}
|
||||
return json.Marshal(raw{Primary: m.Primary, Fallbacks: m.Fallbacks})
|
||||
}
|
||||
|
||||
type AgentConfig struct {
|
||||
ID string `json:"id"`
|
||||
Default bool `json:"default,omitempty"`
|
||||
Name string `json:"name,omitempty"`
|
||||
Workspace string `json:"workspace,omitempty"`
|
||||
Model *AgentModelConfig `json:"model,omitempty"`
|
||||
Skills []string `json:"skills,omitempty"`
|
||||
Subagents *SubagentsConfig `json:"subagents,omitempty"`
|
||||
}
|
||||
|
||||
type SubagentsConfig struct {
|
||||
AllowAgents []string `json:"allow_agents,omitempty"`
|
||||
Model *AgentModelConfig `json:"model,omitempty"`
|
||||
}
|
||||
|
||||
type PeerMatch struct {
|
||||
Kind string `json:"kind"`
|
||||
ID string `json:"id"`
|
||||
}
|
||||
|
||||
type BindingMatch struct {
|
||||
Channel string `json:"channel"`
|
||||
AccountID string `json:"account_id,omitempty"`
|
||||
Peer *PeerMatch `json:"peer,omitempty"`
|
||||
GuildID string `json:"guild_id,omitempty"`
|
||||
TeamID string `json:"team_id,omitempty"`
|
||||
}
|
||||
|
||||
type AgentBinding struct {
|
||||
AgentID string `json:"agent_id"`
|
||||
Match BindingMatch `json:"match"`
|
||||
}
|
||||
|
||||
type SessionConfig struct {
|
||||
DMScope string `json:"dm_scope,omitempty"`
|
||||
IdentityLinks map[string][]string `json:"identity_links,omitempty"`
|
||||
}
|
||||
|
||||
type AgentDefaults struct {
|
||||
Workspace string `json:"workspace" env:"PICOCLAW_AGENTS_DEFAULTS_WORKSPACE"`
|
||||
RestrictToWorkspace bool `json:"restrict_to_workspace" env:"PICOCLAW_AGENTS_DEFAULTS_RESTRICT_TO_WORKSPACE"`
|
||||
Provider string `json:"provider" env:"PICOCLAW_AGENTS_DEFAULTS_PROVIDER"`
|
||||
Model string `json:"model" env:"PICOCLAW_AGENTS_DEFAULTS_MODEL"`
|
||||
MaxTokens int `json:"max_tokens" env:"PICOCLAW_AGENTS_DEFAULTS_MAX_TOKENS"`
|
||||
Temperature float64 `json:"temperature" env:"PICOCLAW_AGENTS_DEFAULTS_TEMPERATURE"`
|
||||
MaxToolIterations int `json:"max_tool_iterations" env:"PICOCLAW_AGENTS_DEFAULTS_MAX_TOOL_ITERATIONS"`
|
||||
Workspace string `json:"workspace" env:"PICOCLAW_AGENTS_DEFAULTS_WORKSPACE"`
|
||||
RestrictToWorkspace bool `json:"restrict_to_workspace" env:"PICOCLAW_AGENTS_DEFAULTS_RESTRICT_TO_WORKSPACE"`
|
||||
Provider string `json:"provider" env:"PICOCLAW_AGENTS_DEFAULTS_PROVIDER"`
|
||||
Model string `json:"model" env:"PICOCLAW_AGENTS_DEFAULTS_MODEL"`
|
||||
ModelFallbacks []string `json:"model_fallbacks,omitempty"`
|
||||
ImageModel string `json:"image_model,omitempty" env:"PICOCLAW_AGENTS_DEFAULTS_IMAGE_MODEL"`
|
||||
ImageModelFallbacks []string `json:"image_model_fallbacks,omitempty"`
|
||||
MaxTokens int `json:"max_tokens" env:"PICOCLAW_AGENTS_DEFAULTS_MAX_TOKENS"`
|
||||
Temperature float64 `json:"temperature" env:"PICOCLAW_AGENTS_DEFAULTS_TEMPERATURE"`
|
||||
MaxToolIterations int `json:"max_tool_iterations" env:"PICOCLAW_AGENTS_DEFAULTS_MAX_TOOL_ITERATIONS"`
|
||||
}
|
||||
|
||||
type ChannelsConfig struct {
|
||||
@@ -461,6 +544,32 @@ func (c *Config) GetAPIBase() string {
|
||||
return ""
|
||||
}
|
||||
|
||||
// ModelConfig holds primary model and fallback list.
|
||||
type ModelConfig struct {
|
||||
Primary string
|
||||
Fallbacks []string
|
||||
}
|
||||
|
||||
// GetModelConfig returns the text model configuration with fallbacks.
|
||||
func (c *Config) GetModelConfig() ModelConfig {
|
||||
c.mu.RLock()
|
||||
defer c.mu.RUnlock()
|
||||
return ModelConfig{
|
||||
Primary: c.Agents.Defaults.Model,
|
||||
Fallbacks: c.Agents.Defaults.ModelFallbacks,
|
||||
}
|
||||
}
|
||||
|
||||
// GetImageModelConfig returns the image model configuration with fallbacks.
|
||||
func (c *Config) GetImageModelConfig() ModelConfig {
|
||||
c.mu.RLock()
|
||||
defer c.mu.RUnlock()
|
||||
return ModelConfig{
|
||||
Primary: c.Agents.Defaults.ImageModel,
|
||||
Fallbacks: c.Agents.Defaults.ImageModelFallbacks,
|
||||
}
|
||||
}
|
||||
|
||||
func expandHome(path string) string {
|
||||
if path == "" {
|
||||
return path
|
||||
|
||||
+181
-32
@@ -1,12 +1,193 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestAgentModelConfig_UnmarshalString(t *testing.T) {
|
||||
var m AgentModelConfig
|
||||
if err := json.Unmarshal([]byte(`"gpt-4"`), &m); err != nil {
|
||||
t.Fatalf("unmarshal string: %v", err)
|
||||
}
|
||||
if m.Primary != "gpt-4" {
|
||||
t.Errorf("Primary = %q, want 'gpt-4'", m.Primary)
|
||||
}
|
||||
if m.Fallbacks != nil {
|
||||
t.Errorf("Fallbacks = %v, want nil", m.Fallbacks)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAgentModelConfig_UnmarshalObject(t *testing.T) {
|
||||
var m AgentModelConfig
|
||||
data := `{"primary": "claude-opus", "fallbacks": ["gpt-4o-mini", "haiku"]}`
|
||||
if err := json.Unmarshal([]byte(data), &m); err != nil {
|
||||
t.Fatalf("unmarshal object: %v", err)
|
||||
}
|
||||
if m.Primary != "claude-opus" {
|
||||
t.Errorf("Primary = %q, want 'claude-opus'", m.Primary)
|
||||
}
|
||||
if len(m.Fallbacks) != 2 {
|
||||
t.Fatalf("Fallbacks len = %d, want 2", len(m.Fallbacks))
|
||||
}
|
||||
if m.Fallbacks[0] != "gpt-4o-mini" || m.Fallbacks[1] != "haiku" {
|
||||
t.Errorf("Fallbacks = %v", m.Fallbacks)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAgentModelConfig_MarshalString(t *testing.T) {
|
||||
m := AgentModelConfig{Primary: "gpt-4"}
|
||||
data, err := json.Marshal(m)
|
||||
if err != nil {
|
||||
t.Fatalf("marshal: %v", err)
|
||||
}
|
||||
if string(data) != `"gpt-4"` {
|
||||
t.Errorf("marshal = %s, want '\"gpt-4\"'", string(data))
|
||||
}
|
||||
}
|
||||
|
||||
func TestAgentModelConfig_MarshalObject(t *testing.T) {
|
||||
m := AgentModelConfig{Primary: "claude-opus", Fallbacks: []string{"haiku"}}
|
||||
data, err := json.Marshal(m)
|
||||
if err != nil {
|
||||
t.Fatalf("marshal: %v", err)
|
||||
}
|
||||
var result map[string]interface{}
|
||||
json.Unmarshal(data, &result)
|
||||
if result["primary"] != "claude-opus" {
|
||||
t.Errorf("primary = %v", result["primary"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestAgentConfig_FullParse(t *testing.T) {
|
||||
jsonData := `{
|
||||
"agents": {
|
||||
"defaults": {
|
||||
"workspace": "~/.picoclaw/workspace",
|
||||
"model": "glm-4.7",
|
||||
"max_tokens": 8192,
|
||||
"max_tool_iterations": 20
|
||||
},
|
||||
"list": [
|
||||
{
|
||||
"id": "sales",
|
||||
"default": true,
|
||||
"name": "Sales Bot",
|
||||
"model": "gpt-4"
|
||||
},
|
||||
{
|
||||
"id": "support",
|
||||
"name": "Support Bot",
|
||||
"model": {
|
||||
"primary": "claude-opus",
|
||||
"fallbacks": ["haiku"]
|
||||
},
|
||||
"subagents": {
|
||||
"allow_agents": ["sales"]
|
||||
}
|
||||
}
|
||||
]
|
||||
},
|
||||
"bindings": [
|
||||
{
|
||||
"agent_id": "support",
|
||||
"match": {
|
||||
"channel": "telegram",
|
||||
"account_id": "*",
|
||||
"peer": {"kind": "direct", "id": "user123"}
|
||||
}
|
||||
}
|
||||
],
|
||||
"session": {
|
||||
"dm_scope": "per-peer",
|
||||
"identity_links": {
|
||||
"john": ["telegram:123", "discord:john#1234"]
|
||||
}
|
||||
}
|
||||
}`
|
||||
|
||||
cfg := DefaultConfig()
|
||||
if err := json.Unmarshal([]byte(jsonData), cfg); err != nil {
|
||||
t.Fatalf("unmarshal: %v", err)
|
||||
}
|
||||
|
||||
if len(cfg.Agents.List) != 2 {
|
||||
t.Fatalf("agents.list len = %d, want 2", len(cfg.Agents.List))
|
||||
}
|
||||
|
||||
sales := cfg.Agents.List[0]
|
||||
if sales.ID != "sales" || !sales.Default || sales.Name != "Sales Bot" {
|
||||
t.Errorf("sales = %+v", sales)
|
||||
}
|
||||
if sales.Model == nil || sales.Model.Primary != "gpt-4" {
|
||||
t.Errorf("sales.Model = %+v", sales.Model)
|
||||
}
|
||||
|
||||
support := cfg.Agents.List[1]
|
||||
if support.ID != "support" || support.Name != "Support Bot" {
|
||||
t.Errorf("support = %+v", support)
|
||||
}
|
||||
if support.Model == nil || support.Model.Primary != "claude-opus" {
|
||||
t.Errorf("support.Model = %+v", support.Model)
|
||||
}
|
||||
if len(support.Model.Fallbacks) != 1 || support.Model.Fallbacks[0] != "haiku" {
|
||||
t.Errorf("support.Model.Fallbacks = %v", support.Model.Fallbacks)
|
||||
}
|
||||
if support.Subagents == nil || len(support.Subagents.AllowAgents) != 1 {
|
||||
t.Errorf("support.Subagents = %+v", support.Subagents)
|
||||
}
|
||||
|
||||
if len(cfg.Bindings) != 1 {
|
||||
t.Fatalf("bindings len = %d, want 1", len(cfg.Bindings))
|
||||
}
|
||||
binding := cfg.Bindings[0]
|
||||
if binding.AgentID != "support" || binding.Match.Channel != "telegram" {
|
||||
t.Errorf("binding = %+v", binding)
|
||||
}
|
||||
if binding.Match.Peer == nil || binding.Match.Peer.Kind != "direct" || binding.Match.Peer.ID != "user123" {
|
||||
t.Errorf("binding.Match.Peer = %+v", binding.Match.Peer)
|
||||
}
|
||||
|
||||
if cfg.Session.DMScope != "per-peer" {
|
||||
t.Errorf("Session.DMScope = %q", cfg.Session.DMScope)
|
||||
}
|
||||
if len(cfg.Session.IdentityLinks) != 1 {
|
||||
t.Errorf("Session.IdentityLinks = %v", cfg.Session.IdentityLinks)
|
||||
}
|
||||
links := cfg.Session.IdentityLinks["john"]
|
||||
if len(links) != 2 {
|
||||
t.Errorf("john links = %v", links)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfig_BackwardCompat_NoAgentsList(t *testing.T) {
|
||||
jsonData := `{
|
||||
"agents": {
|
||||
"defaults": {
|
||||
"workspace": "~/.picoclaw/workspace",
|
||||
"model": "glm-4.7",
|
||||
"max_tokens": 8192,
|
||||
"max_tool_iterations": 20
|
||||
}
|
||||
}
|
||||
}`
|
||||
|
||||
cfg := DefaultConfig()
|
||||
if err := json.Unmarshal([]byte(jsonData), cfg); err != nil {
|
||||
t.Fatalf("unmarshal: %v", err)
|
||||
}
|
||||
|
||||
if len(cfg.Agents.List) != 0 {
|
||||
t.Errorf("agents.list should be empty for backward compat, got %d", len(cfg.Agents.List))
|
||||
}
|
||||
if len(cfg.Bindings) != 0 {
|
||||
t.Errorf("bindings should be empty, got %d", len(cfg.Bindings))
|
||||
}
|
||||
}
|
||||
|
||||
// TestDefaultConfig_HeartbeatEnabled verifies heartbeat is enabled by default
|
||||
func TestDefaultConfig_HeartbeatEnabled(t *testing.T) {
|
||||
cfg := DefaultConfig()
|
||||
@@ -20,8 +201,6 @@ func TestDefaultConfig_HeartbeatEnabled(t *testing.T) {
|
||||
func TestDefaultConfig_WorkspacePath(t *testing.T) {
|
||||
cfg := DefaultConfig()
|
||||
|
||||
// Just verify the workspace is set, don't compare exact paths
|
||||
// since expandHome behavior may differ based on environment
|
||||
if cfg.Agents.Defaults.Workspace == "" {
|
||||
t.Error("Workspace should not be empty")
|
||||
}
|
||||
@@ -79,7 +258,6 @@ func TestDefaultConfig_Gateway(t *testing.T) {
|
||||
func TestDefaultConfig_Providers(t *testing.T) {
|
||||
cfg := DefaultConfig()
|
||||
|
||||
// Verify all providers are empty by default
|
||||
if cfg.Providers.Anthropic.APIKey != "" {
|
||||
t.Error("Anthropic API key should be empty by default")
|
||||
}
|
||||
@@ -89,46 +267,18 @@ func TestDefaultConfig_Providers(t *testing.T) {
|
||||
if cfg.Providers.OpenRouter.APIKey != "" {
|
||||
t.Error("OpenRouter API key should be empty by default")
|
||||
}
|
||||
if cfg.Providers.Groq.APIKey != "" {
|
||||
t.Error("Groq API key should be empty by default")
|
||||
}
|
||||
if cfg.Providers.Zhipu.APIKey != "" {
|
||||
t.Error("Zhipu API key should be empty by default")
|
||||
}
|
||||
if cfg.Providers.VLLM.APIKey != "" {
|
||||
t.Error("VLLM API key should be empty by default")
|
||||
}
|
||||
if cfg.Providers.Gemini.APIKey != "" {
|
||||
t.Error("Gemini API key should be empty by default")
|
||||
}
|
||||
}
|
||||
|
||||
// TestDefaultConfig_Channels verifies channels are disabled by default
|
||||
func TestDefaultConfig_Channels(t *testing.T) {
|
||||
cfg := DefaultConfig()
|
||||
|
||||
// Verify all channels are disabled by default
|
||||
if cfg.Channels.WhatsApp.Enabled {
|
||||
t.Error("WhatsApp should be disabled by default")
|
||||
}
|
||||
if cfg.Channels.Telegram.Enabled {
|
||||
t.Error("Telegram should be disabled by default")
|
||||
}
|
||||
if cfg.Channels.Feishu.Enabled {
|
||||
t.Error("Feishu should be disabled by default")
|
||||
}
|
||||
if cfg.Channels.Discord.Enabled {
|
||||
t.Error("Discord should be disabled by default")
|
||||
}
|
||||
if cfg.Channels.MaixCam.Enabled {
|
||||
t.Error("MaixCam should be disabled by default")
|
||||
}
|
||||
if cfg.Channels.QQ.Enabled {
|
||||
t.Error("QQ should be disabled by default")
|
||||
}
|
||||
if cfg.Channels.DingTalk.Enabled {
|
||||
t.Error("DingTalk should be disabled by default")
|
||||
}
|
||||
if cfg.Channels.Slack.Enabled {
|
||||
t.Error("Slack should be disabled by default")
|
||||
}
|
||||
@@ -178,7 +328,6 @@ func TestSaveConfig_FilePermissions(t *testing.T) {
|
||||
func TestConfig_Complete(t *testing.T) {
|
||||
cfg := DefaultConfig()
|
||||
|
||||
// Verify complete config structure
|
||||
if cfg.Agents.Defaults.Workspace == "" {
|
||||
t.Error("Workspace should not be empty")
|
||||
}
|
||||
|
||||
@@ -0,0 +1,119 @@
|
||||
//go:build integration
|
||||
|
||||
package providers
|
||||
|
||||
import (
|
||||
"context"
|
||||
exec "os/exec"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// TestIntegration_RealCodexCLI tests the CodexCliProvider with a real codex CLI.
|
||||
// Run with: go test -tags=integration ./pkg/providers/...
|
||||
func TestIntegration_RealCodexCLI(t *testing.T) {
|
||||
path, err := exec.LookPath("codex")
|
||||
if err != nil {
|
||||
t.Skip("codex CLI not found in PATH, skipping integration test")
|
||||
}
|
||||
t.Logf("Using codex CLI at: %s", path)
|
||||
|
||||
p := NewCodexCliProvider(t.TempDir())
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 90*time.Second)
|
||||
defer cancel()
|
||||
|
||||
resp, err := p.Chat(ctx, []Message{
|
||||
{Role: "user", Content: "Respond with only the word 'pong'. Nothing else."},
|
||||
}, nil, "", nil)
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("Chat() with real CLI error = %v", err)
|
||||
}
|
||||
|
||||
if resp.Content == "" {
|
||||
t.Error("Content is empty")
|
||||
}
|
||||
if resp.FinishReason != "stop" {
|
||||
t.Errorf("FinishReason = %q, want %q", resp.FinishReason, "stop")
|
||||
}
|
||||
if resp.Usage != nil {
|
||||
t.Logf("Usage: prompt=%d, completion=%d, total=%d",
|
||||
resp.Usage.PromptTokens, resp.Usage.CompletionTokens, resp.Usage.TotalTokens)
|
||||
}
|
||||
|
||||
t.Logf("Response content: %q", resp.Content)
|
||||
|
||||
if !strings.Contains(strings.ToLower(resp.Content), "pong") {
|
||||
t.Errorf("Content = %q, expected to contain 'pong'", resp.Content)
|
||||
}
|
||||
}
|
||||
|
||||
func TestIntegration_RealCodexCLI_WithSystemPrompt(t *testing.T) {
|
||||
if _, err := exec.LookPath("codex"); err != nil {
|
||||
t.Skip("codex CLI not found in PATH")
|
||||
}
|
||||
|
||||
p := NewCodexCliProvider(t.TempDir())
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 90*time.Second)
|
||||
defer cancel()
|
||||
|
||||
resp, err := p.Chat(ctx, []Message{
|
||||
{Role: "system", Content: "You are a calculator. Only respond with numbers. No text."},
|
||||
{Role: "user", Content: "What is 2+2?"},
|
||||
}, nil, "", nil)
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("Chat() error = %v", err)
|
||||
}
|
||||
|
||||
t.Logf("Response: %q", resp.Content)
|
||||
|
||||
if !strings.Contains(resp.Content, "4") {
|
||||
t.Errorf("Content = %q, expected to contain '4'", resp.Content)
|
||||
}
|
||||
}
|
||||
|
||||
func TestIntegration_RealCodexCLI_ParsesRealJSONL(t *testing.T) {
|
||||
if _, err := exec.LookPath("codex"); err != nil {
|
||||
t.Skip("codex CLI not found in PATH")
|
||||
}
|
||||
|
||||
// Run codex directly and verify our parser handles real output
|
||||
cmd := exec.Command("codex", "exec",
|
||||
"--json",
|
||||
"--dangerously-bypass-approvals-and-sandbox",
|
||||
"--skip-git-repo-check",
|
||||
"--color", "never",
|
||||
"-C", t.TempDir(),
|
||||
"-")
|
||||
cmd.Stdin = strings.NewReader("Say hi")
|
||||
|
||||
output, err := cmd.Output()
|
||||
if err != nil {
|
||||
// codex may write diagnostic noise to stderr but still produce valid output
|
||||
if len(output) == 0 {
|
||||
t.Fatalf("codex CLI failed: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
t.Logf("Raw CLI output (first 500 chars): %s", string(output[:min(len(output), 500)]))
|
||||
|
||||
// Verify our parser can handle real output
|
||||
p := NewCodexCliProvider("")
|
||||
resp, err := p.parseJSONLEvents(string(output))
|
||||
if err != nil {
|
||||
t.Fatalf("parseJSONLEvents() failed on real CLI output: %v", err)
|
||||
}
|
||||
|
||||
if resp.Content == "" {
|
||||
t.Error("parsed Content is empty")
|
||||
}
|
||||
if resp.FinishReason != "stop" {
|
||||
t.Errorf("FinishReason = %q, want stop", resp.FinishReason)
|
||||
}
|
||||
|
||||
t.Logf("Parsed: content=%q, finish=%s, usage=%+v", resp.Content, resp.FinishReason, resp.Usage)
|
||||
}
|
||||
@@ -0,0 +1,207 @@
|
||||
package providers
|
||||
|
||||
import (
|
||||
"math"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultFailureWindow = 24 * time.Hour
|
||||
)
|
||||
|
||||
// CooldownTracker manages per-provider cooldown state for the fallback chain.
|
||||
// Thread-safe via sync.RWMutex. In-memory only (resets on restart).
|
||||
type CooldownTracker struct {
|
||||
mu sync.RWMutex
|
||||
entries map[string]*cooldownEntry
|
||||
failureWindow time.Duration
|
||||
nowFunc func() time.Time // for testing
|
||||
}
|
||||
|
||||
type cooldownEntry struct {
|
||||
ErrorCount int
|
||||
FailureCounts map[FailoverReason]int
|
||||
CooldownEnd time.Time // standard cooldown expiry
|
||||
DisabledUntil time.Time // billing-specific disable expiry
|
||||
DisabledReason FailoverReason // reason for disable (billing)
|
||||
LastFailure time.Time
|
||||
}
|
||||
|
||||
// NewCooldownTracker creates a tracker with default 24h failure window.
|
||||
func NewCooldownTracker() *CooldownTracker {
|
||||
return &CooldownTracker{
|
||||
entries: make(map[string]*cooldownEntry),
|
||||
failureWindow: defaultFailureWindow,
|
||||
nowFunc: time.Now,
|
||||
}
|
||||
}
|
||||
|
||||
// MarkFailure records a failure for a provider and sets appropriate cooldown.
|
||||
// Resets error counts if last failure was more than failureWindow ago.
|
||||
func (ct *CooldownTracker) MarkFailure(provider string, reason FailoverReason) {
|
||||
ct.mu.Lock()
|
||||
defer ct.mu.Unlock()
|
||||
|
||||
now := ct.nowFunc()
|
||||
entry := ct.getOrCreate(provider)
|
||||
|
||||
// 24h failure window reset: if no failure in failureWindow, reset counters.
|
||||
if !entry.LastFailure.IsZero() && now.Sub(entry.LastFailure) > ct.failureWindow {
|
||||
entry.ErrorCount = 0
|
||||
entry.FailureCounts = make(map[FailoverReason]int)
|
||||
}
|
||||
|
||||
entry.ErrorCount++
|
||||
entry.FailureCounts[reason]++
|
||||
entry.LastFailure = now
|
||||
|
||||
if reason == FailoverBilling {
|
||||
billingCount := entry.FailureCounts[FailoverBilling]
|
||||
entry.DisabledUntil = now.Add(calculateBillingCooldown(billingCount))
|
||||
entry.DisabledReason = FailoverBilling
|
||||
} else {
|
||||
entry.CooldownEnd = now.Add(calculateStandardCooldown(entry.ErrorCount))
|
||||
}
|
||||
}
|
||||
|
||||
// MarkSuccess resets all counters and cooldowns for a provider.
|
||||
func (ct *CooldownTracker) MarkSuccess(provider string) {
|
||||
ct.mu.Lock()
|
||||
defer ct.mu.Unlock()
|
||||
|
||||
entry := ct.entries[provider]
|
||||
if entry == nil {
|
||||
return
|
||||
}
|
||||
|
||||
entry.ErrorCount = 0
|
||||
entry.FailureCounts = make(map[FailoverReason]int)
|
||||
entry.CooldownEnd = time.Time{}
|
||||
entry.DisabledUntil = time.Time{}
|
||||
entry.DisabledReason = ""
|
||||
}
|
||||
|
||||
// IsAvailable returns true if the provider is not in cooldown or disabled.
|
||||
func (ct *CooldownTracker) IsAvailable(provider string) bool {
|
||||
ct.mu.RLock()
|
||||
defer ct.mu.RUnlock()
|
||||
|
||||
entry := ct.entries[provider]
|
||||
if entry == nil {
|
||||
return true
|
||||
}
|
||||
|
||||
now := ct.nowFunc()
|
||||
|
||||
// Billing disable takes precedence (longer cooldown).
|
||||
if !entry.DisabledUntil.IsZero() && now.Before(entry.DisabledUntil) {
|
||||
return false
|
||||
}
|
||||
|
||||
// Standard cooldown.
|
||||
if !entry.CooldownEnd.IsZero() && now.Before(entry.CooldownEnd) {
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// CooldownRemaining returns how long until the provider becomes available.
|
||||
// Returns 0 if already available.
|
||||
func (ct *CooldownTracker) CooldownRemaining(provider string) time.Duration {
|
||||
ct.mu.RLock()
|
||||
defer ct.mu.RUnlock()
|
||||
|
||||
entry := ct.entries[provider]
|
||||
if entry == nil {
|
||||
return 0
|
||||
}
|
||||
|
||||
now := ct.nowFunc()
|
||||
var remaining time.Duration
|
||||
|
||||
if !entry.DisabledUntil.IsZero() && now.Before(entry.DisabledUntil) {
|
||||
d := entry.DisabledUntil.Sub(now)
|
||||
if d > remaining {
|
||||
remaining = d
|
||||
}
|
||||
}
|
||||
|
||||
if !entry.CooldownEnd.IsZero() && now.Before(entry.CooldownEnd) {
|
||||
d := entry.CooldownEnd.Sub(now)
|
||||
if d > remaining {
|
||||
remaining = d
|
||||
}
|
||||
}
|
||||
|
||||
return remaining
|
||||
}
|
||||
|
||||
// ErrorCount returns the current error count for a provider.
|
||||
func (ct *CooldownTracker) ErrorCount(provider string) int {
|
||||
ct.mu.RLock()
|
||||
defer ct.mu.RUnlock()
|
||||
|
||||
entry := ct.entries[provider]
|
||||
if entry == nil {
|
||||
return 0
|
||||
}
|
||||
return entry.ErrorCount
|
||||
}
|
||||
|
||||
// FailureCount returns the failure count for a specific reason.
|
||||
func (ct *CooldownTracker) FailureCount(provider string, reason FailoverReason) int {
|
||||
ct.mu.RLock()
|
||||
defer ct.mu.RUnlock()
|
||||
|
||||
entry := ct.entries[provider]
|
||||
if entry == nil {
|
||||
return 0
|
||||
}
|
||||
return entry.FailureCounts[reason]
|
||||
}
|
||||
|
||||
func (ct *CooldownTracker) getOrCreate(provider string) *cooldownEntry {
|
||||
entry := ct.entries[provider]
|
||||
if entry == nil {
|
||||
entry = &cooldownEntry{
|
||||
FailureCounts: make(map[FailoverReason]int),
|
||||
}
|
||||
ct.entries[provider] = entry
|
||||
}
|
||||
return entry
|
||||
}
|
||||
|
||||
// calculateStandardCooldown computes standard exponential backoff.
|
||||
// Formula from OpenClaw: min(1h, 1min * 5^min(n-1, 3))
|
||||
//
|
||||
// 1 error → 1 min
|
||||
// 2 errors → 5 min
|
||||
// 3 errors → 25 min
|
||||
// 4+ errors → 1 hour (cap)
|
||||
func calculateStandardCooldown(errorCount int) time.Duration {
|
||||
n := max(1, errorCount)
|
||||
exp := min(n-1, 3)
|
||||
ms := 60_000 * int(math.Pow(5, float64(exp)))
|
||||
ms = min(3_600_000, ms) // cap at 1 hour
|
||||
return time.Duration(ms) * time.Millisecond
|
||||
}
|
||||
|
||||
// calculateBillingCooldown computes billing-specific exponential backoff.
|
||||
// Formula from OpenClaw: min(24h, 5h * 2^min(n-1, 10))
|
||||
//
|
||||
// 1 error → 5 hours
|
||||
// 2 errors → 10 hours
|
||||
// 3 errors → 20 hours
|
||||
// 4+ errors → 24 hours (cap)
|
||||
func calculateBillingCooldown(billingErrorCount int) time.Duration {
|
||||
const baseMs = 5 * 60 * 60 * 1000 // 5 hours
|
||||
const maxMs = 24 * 60 * 60 * 1000 // 24 hours
|
||||
|
||||
n := max(1, billingErrorCount)
|
||||
exp := min(n-1, 10)
|
||||
raw := float64(baseMs) * math.Pow(2, float64(exp))
|
||||
ms := int(math.Min(float64(maxMs), raw))
|
||||
return time.Duration(ms) * time.Millisecond
|
||||
}
|
||||
@@ -0,0 +1,269 @@
|
||||
package providers
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func newTestTracker(now time.Time) (*CooldownTracker, *time.Time) {
|
||||
current := now
|
||||
ct := NewCooldownTracker()
|
||||
ct.nowFunc = func() time.Time { return current }
|
||||
return ct, ¤t
|
||||
}
|
||||
|
||||
func TestCooldown_InitiallyAvailable(t *testing.T) {
|
||||
ct := NewCooldownTracker()
|
||||
if !ct.IsAvailable("openai") {
|
||||
t.Error("new provider should be available")
|
||||
}
|
||||
if ct.ErrorCount("openai") != 0 {
|
||||
t.Error("new provider should have 0 errors")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCooldown_StandardEscalation(t *testing.T) {
|
||||
now := time.Now()
|
||||
ct, current := newTestTracker(now)
|
||||
|
||||
// 1st error → 1 min cooldown
|
||||
ct.MarkFailure("openai", FailoverRateLimit)
|
||||
if ct.IsAvailable("openai") {
|
||||
t.Error("should be in cooldown after 1st error")
|
||||
}
|
||||
|
||||
// Advance 61 seconds → available
|
||||
*current = now.Add(61 * time.Second)
|
||||
if !ct.IsAvailable("openai") {
|
||||
t.Error("should be available after 1 min cooldown")
|
||||
}
|
||||
|
||||
// 2nd error → 5 min cooldown
|
||||
ct.MarkFailure("openai", FailoverRateLimit)
|
||||
*current = now.Add(61*time.Second + 4*time.Minute)
|
||||
if ct.IsAvailable("openai") {
|
||||
t.Error("should be in cooldown (5 min) after 2nd error")
|
||||
}
|
||||
*current = now.Add(61*time.Second + 6*time.Minute)
|
||||
if !ct.IsAvailable("openai") {
|
||||
t.Error("should be available after 5 min cooldown")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCooldown_StandardCap(t *testing.T) {
|
||||
// Verify formula: 1m, 5m, 25m, 1h, 1h, 1h...
|
||||
expected := []time.Duration{
|
||||
1 * time.Minute,
|
||||
5 * time.Minute,
|
||||
25 * time.Minute,
|
||||
1 * time.Hour,
|
||||
1 * time.Hour,
|
||||
}
|
||||
|
||||
for i, want := range expected {
|
||||
got := calculateStandardCooldown(i + 1)
|
||||
if got != want {
|
||||
t.Errorf("calculateStandardCooldown(%d) = %v, want %v", i+1, got, want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestCooldown_BillingEscalation(t *testing.T) {
|
||||
now := time.Now()
|
||||
ct, current := newTestTracker(now)
|
||||
|
||||
// 1st billing error → 5h cooldown
|
||||
ct.MarkFailure("openai", FailoverBilling)
|
||||
if ct.IsAvailable("openai") {
|
||||
t.Error("should be disabled after billing error")
|
||||
}
|
||||
|
||||
// Advance 4h → still disabled
|
||||
*current = now.Add(4 * time.Hour)
|
||||
if ct.IsAvailable("openai") {
|
||||
t.Error("should still be disabled (5h cooldown)")
|
||||
}
|
||||
|
||||
// Advance 5h + 1s → available
|
||||
*current = now.Add(5*time.Hour + 1*time.Second)
|
||||
if !ct.IsAvailable("openai") {
|
||||
t.Error("should be available after 5h billing cooldown")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCooldown_BillingCap(t *testing.T) {
|
||||
expected := []time.Duration{
|
||||
5 * time.Hour,
|
||||
10 * time.Hour,
|
||||
20 * time.Hour,
|
||||
24 * time.Hour,
|
||||
24 * time.Hour,
|
||||
}
|
||||
|
||||
for i, want := range expected {
|
||||
got := calculateBillingCooldown(i + 1)
|
||||
if got != want {
|
||||
t.Errorf("calculateBillingCooldown(%d) = %v, want %v", i+1, got, want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestCooldown_SuccessReset(t *testing.T) {
|
||||
ct := NewCooldownTracker()
|
||||
|
||||
ct.MarkFailure("openai", FailoverRateLimit)
|
||||
ct.MarkFailure("openai", FailoverBilling)
|
||||
if ct.ErrorCount("openai") != 2 {
|
||||
t.Errorf("error count = %d, want 2", ct.ErrorCount("openai"))
|
||||
}
|
||||
|
||||
ct.MarkSuccess("openai")
|
||||
if ct.ErrorCount("openai") != 0 {
|
||||
t.Errorf("error count after success = %d, want 0", ct.ErrorCount("openai"))
|
||||
}
|
||||
if !ct.IsAvailable("openai") {
|
||||
t.Error("should be available after success")
|
||||
}
|
||||
if ct.FailureCount("openai", FailoverRateLimit) != 0 {
|
||||
t.Error("failure counts should be reset after success")
|
||||
}
|
||||
if ct.FailureCount("openai", FailoverBilling) != 0 {
|
||||
t.Error("billing failure count should be reset after success")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCooldown_FailureWindowReset(t *testing.T) {
|
||||
now := time.Now()
|
||||
ct, current := newTestTracker(now)
|
||||
|
||||
// 4 errors → 1h cooldown
|
||||
for i := 0; i < 4; i++ {
|
||||
ct.MarkFailure("openai", FailoverRateLimit)
|
||||
*current = current.Add(2 * time.Second) // small advance between errors
|
||||
}
|
||||
if ct.ErrorCount("openai") != 4 {
|
||||
t.Errorf("error count = %d, want 4", ct.ErrorCount("openai"))
|
||||
}
|
||||
|
||||
// Advance 25 hours (past 24h failure window)
|
||||
*current = now.Add(25 * time.Hour)
|
||||
|
||||
// Next error should reset counters first, then increment to 1
|
||||
ct.MarkFailure("openai", FailoverRateLimit)
|
||||
if ct.ErrorCount("openai") != 1 {
|
||||
t.Errorf("error count after window reset = %d, want 1 (reset + 1)", ct.ErrorCount("openai"))
|
||||
}
|
||||
}
|
||||
|
||||
func TestCooldown_PerReasonTracking(t *testing.T) {
|
||||
ct := NewCooldownTracker()
|
||||
|
||||
ct.MarkFailure("openai", FailoverRateLimit)
|
||||
ct.MarkFailure("openai", FailoverRateLimit)
|
||||
ct.MarkFailure("openai", FailoverBilling)
|
||||
ct.MarkFailure("openai", FailoverAuth)
|
||||
|
||||
if ct.FailureCount("openai", FailoverRateLimit) != 2 {
|
||||
t.Errorf("rate_limit count = %d, want 2", ct.FailureCount("openai", FailoverRateLimit))
|
||||
}
|
||||
if ct.FailureCount("openai", FailoverBilling) != 1 {
|
||||
t.Errorf("billing count = %d, want 1", ct.FailureCount("openai", FailoverBilling))
|
||||
}
|
||||
if ct.FailureCount("openai", FailoverAuth) != 1 {
|
||||
t.Errorf("auth count = %d, want 1", ct.FailureCount("openai", FailoverAuth))
|
||||
}
|
||||
if ct.ErrorCount("openai") != 4 {
|
||||
t.Errorf("total error count = %d, want 4", ct.ErrorCount("openai"))
|
||||
}
|
||||
}
|
||||
|
||||
func TestCooldown_BillingTakesPrecedence(t *testing.T) {
|
||||
now := time.Now()
|
||||
ct, current := newTestTracker(now)
|
||||
|
||||
// Standard cooldown (1 min) + billing disable (5h)
|
||||
ct.MarkFailure("openai", FailoverRateLimit) // 1 min cooldown
|
||||
ct.MarkFailure("openai", FailoverBilling) // 5h disable
|
||||
|
||||
// After 2 min: standard cooldown expired but billing still active
|
||||
*current = now.Add(2 * time.Minute)
|
||||
if ct.IsAvailable("openai") {
|
||||
t.Error("billing disable should take precedence over standard cooldown")
|
||||
}
|
||||
|
||||
// After 5h + 1s: both expired
|
||||
*current = now.Add(5*time.Hour + 1*time.Second)
|
||||
if !ct.IsAvailable("openai") {
|
||||
t.Error("should be available after all cooldowns expire")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCooldown_CooldownRemaining(t *testing.T) {
|
||||
now := time.Now()
|
||||
ct, current := newTestTracker(now)
|
||||
|
||||
// No failures → 0 remaining
|
||||
if ct.CooldownRemaining("openai") != 0 {
|
||||
t.Error("expected 0 remaining for new provider")
|
||||
}
|
||||
|
||||
ct.MarkFailure("openai", FailoverRateLimit)
|
||||
|
||||
*current = now.Add(30 * time.Second)
|
||||
remaining := ct.CooldownRemaining("openai")
|
||||
if remaining <= 0 || remaining > 1*time.Minute {
|
||||
t.Errorf("remaining = %v, expected ~30s", remaining)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCooldown_SuccessOnUnknownProvider(t *testing.T) {
|
||||
ct := NewCooldownTracker()
|
||||
// Should not panic
|
||||
ct.MarkSuccess("nonexistent")
|
||||
if !ct.IsAvailable("nonexistent") {
|
||||
t.Error("nonexistent provider should be available")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCooldown_ConcurrentAccess(t *testing.T) {
|
||||
ct := NewCooldownTracker()
|
||||
var wg sync.WaitGroup
|
||||
|
||||
for i := 0; i < 100; i++ {
|
||||
wg.Add(3)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
ct.MarkFailure("openai", FailoverRateLimit)
|
||||
}()
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
ct.IsAvailable("openai")
|
||||
}()
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
ct.MarkSuccess("openai")
|
||||
}()
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
// If we got here without panic, concurrent access is safe
|
||||
}
|
||||
|
||||
func TestCooldown_MultipleProviders(t *testing.T) {
|
||||
ct := NewCooldownTracker()
|
||||
|
||||
ct.MarkFailure("openai", FailoverRateLimit)
|
||||
ct.MarkFailure("anthropic", FailoverBilling)
|
||||
|
||||
if ct.IsAvailable("openai") {
|
||||
t.Error("openai should be in cooldown")
|
||||
}
|
||||
if ct.IsAvailable("anthropic") {
|
||||
t.Error("anthropic should be in cooldown")
|
||||
}
|
||||
// groq was never touched
|
||||
if !ct.IsAvailable("groq") {
|
||||
t.Error("groq should be available")
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,253 @@
|
||||
package providers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"regexp"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// errorPattern defines a single pattern (string or regex) for error classification.
|
||||
type errorPattern struct {
|
||||
substring string
|
||||
regex *regexp.Regexp
|
||||
}
|
||||
|
||||
func substr(s string) errorPattern { return errorPattern{substring: s} }
|
||||
func rxp(r string) errorPattern { return errorPattern{regex: regexp.MustCompile("(?i)" + r)} }
|
||||
|
||||
// Error patterns organized by FailoverReason, matching OpenClaw production (~40 patterns).
|
||||
var (
|
||||
rateLimitPatterns = []errorPattern{
|
||||
rxp(`rate[_ ]limit`),
|
||||
substr("too many requests"),
|
||||
substr("429"),
|
||||
substr("exceeded your current quota"),
|
||||
rxp(`exceeded.*quota`),
|
||||
rxp(`resource has been exhausted`),
|
||||
rxp(`resource.*exhausted`),
|
||||
substr("resource_exhausted"),
|
||||
substr("quota exceeded"),
|
||||
substr("usage limit"),
|
||||
}
|
||||
|
||||
overloadedPatterns = []errorPattern{
|
||||
rxp(`overloaded_error`),
|
||||
rxp(`"type"\s*:\s*"overloaded_error"`),
|
||||
substr("overloaded"),
|
||||
}
|
||||
|
||||
timeoutPatterns = []errorPattern{
|
||||
substr("timeout"),
|
||||
substr("timed out"),
|
||||
substr("deadline exceeded"),
|
||||
substr("context deadline exceeded"),
|
||||
}
|
||||
|
||||
billingPatterns = []errorPattern{
|
||||
rxp(`\b402\b`),
|
||||
substr("payment required"),
|
||||
substr("insufficient credits"),
|
||||
substr("credit balance"),
|
||||
substr("plans & billing"),
|
||||
substr("insufficient balance"),
|
||||
}
|
||||
|
||||
authPatterns = []errorPattern{
|
||||
rxp(`invalid[_ ]?api[_ ]?key`),
|
||||
substr("incorrect api key"),
|
||||
substr("invalid token"),
|
||||
substr("authentication"),
|
||||
substr("re-authenticate"),
|
||||
substr("oauth token refresh failed"),
|
||||
substr("unauthorized"),
|
||||
substr("forbidden"),
|
||||
substr("access denied"),
|
||||
substr("expired"),
|
||||
substr("token has expired"),
|
||||
rxp(`\b401\b`),
|
||||
rxp(`\b403\b`),
|
||||
substr("no credentials found"),
|
||||
substr("no api key found"),
|
||||
}
|
||||
|
||||
formatPatterns = []errorPattern{
|
||||
substr("string should match pattern"),
|
||||
substr("tool_use.id"),
|
||||
substr("tool_use_id"),
|
||||
substr("messages.1.content.1.tool_use.id"),
|
||||
substr("invalid request format"),
|
||||
}
|
||||
|
||||
imageDimensionPatterns = []errorPattern{
|
||||
rxp(`image dimensions exceed max`),
|
||||
}
|
||||
|
||||
imageSizePatterns = []errorPattern{
|
||||
rxp(`image exceeds.*mb`),
|
||||
}
|
||||
|
||||
// Transient HTTP status codes that map to timeout (server-side failures).
|
||||
transientStatusCodes = map[int]bool{
|
||||
500: true, 502: true, 503: true,
|
||||
521: true, 522: true, 523: true, 524: true,
|
||||
529: true,
|
||||
}
|
||||
)
|
||||
|
||||
// ClassifyError classifies an error into a FailoverError with reason.
|
||||
// Returns nil if the error is not classifiable (unknown errors should not trigger fallback).
|
||||
func ClassifyError(err error, provider, model string) *FailoverError {
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Context cancellation: user abort, never fallback.
|
||||
if err == context.Canceled {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Context deadline exceeded: treat as timeout, always fallback.
|
||||
if err == context.DeadlineExceeded {
|
||||
return &FailoverError{
|
||||
Reason: FailoverTimeout,
|
||||
Provider: provider,
|
||||
Model: model,
|
||||
Wrapped: err,
|
||||
}
|
||||
}
|
||||
|
||||
msg := strings.ToLower(err.Error())
|
||||
|
||||
// Image dimension/size errors: non-retriable, non-fallback.
|
||||
if IsImageDimensionError(msg) || IsImageSizeError(msg) {
|
||||
return &FailoverError{
|
||||
Reason: FailoverFormat,
|
||||
Provider: provider,
|
||||
Model: model,
|
||||
Wrapped: err,
|
||||
}
|
||||
}
|
||||
|
||||
// Try HTTP status code extraction first.
|
||||
if status := extractHTTPStatus(msg); status > 0 {
|
||||
if reason := classifyByStatus(status); reason != "" {
|
||||
return &FailoverError{
|
||||
Reason: reason,
|
||||
Provider: provider,
|
||||
Model: model,
|
||||
Status: status,
|
||||
Wrapped: err,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Message pattern matching (priority order from OpenClaw).
|
||||
if reason := classifyByMessage(msg); reason != "" {
|
||||
return &FailoverError{
|
||||
Reason: reason,
|
||||
Provider: provider,
|
||||
Model: model,
|
||||
Wrapped: err,
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// classifyByStatus maps HTTP status codes to FailoverReason.
|
||||
func classifyByStatus(status int) FailoverReason {
|
||||
switch {
|
||||
case status == 401 || status == 403:
|
||||
return FailoverAuth
|
||||
case status == 402:
|
||||
return FailoverBilling
|
||||
case status == 408:
|
||||
return FailoverTimeout
|
||||
case status == 429:
|
||||
return FailoverRateLimit
|
||||
case status == 400:
|
||||
return FailoverFormat
|
||||
case transientStatusCodes[status]:
|
||||
return FailoverTimeout
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// classifyByMessage matches error messages against patterns.
|
||||
// Priority order matters (from OpenClaw classifyFailoverReason).
|
||||
func classifyByMessage(msg string) FailoverReason {
|
||||
if matchesAny(msg, rateLimitPatterns) {
|
||||
return FailoverRateLimit
|
||||
}
|
||||
if matchesAny(msg, overloadedPatterns) {
|
||||
return FailoverRateLimit // Overloaded treated as rate_limit
|
||||
}
|
||||
if matchesAny(msg, billingPatterns) {
|
||||
return FailoverBilling
|
||||
}
|
||||
if matchesAny(msg, timeoutPatterns) {
|
||||
return FailoverTimeout
|
||||
}
|
||||
if matchesAny(msg, authPatterns) {
|
||||
return FailoverAuth
|
||||
}
|
||||
if matchesAny(msg, formatPatterns) {
|
||||
return FailoverFormat
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// extractHTTPStatus extracts an HTTP status code from an error message.
|
||||
// Looks for patterns like "status: 429", "status 429", "HTTP 429", or standalone "429".
|
||||
func extractHTTPStatus(msg string) int {
|
||||
// Common patterns in Go HTTP error messages
|
||||
patterns := []*regexp.Regexp{
|
||||
regexp.MustCompile(`status[:\s]+(\d{3})`),
|
||||
regexp.MustCompile(`HTTP[/\s]+\d*\.?\d*\s+(\d{3})`),
|
||||
}
|
||||
|
||||
for _, p := range patterns {
|
||||
if m := p.FindStringSubmatch(msg); len(m) > 1 {
|
||||
return parseDigits(m[1])
|
||||
}
|
||||
}
|
||||
|
||||
return 0
|
||||
}
|
||||
|
||||
// IsImageDimensionError returns true if the message indicates an image dimension error.
|
||||
func IsImageDimensionError(msg string) bool {
|
||||
return matchesAny(msg, imageDimensionPatterns)
|
||||
}
|
||||
|
||||
// IsImageSizeError returns true if the message indicates an image file size error.
|
||||
func IsImageSizeError(msg string) bool {
|
||||
return matchesAny(msg, imageSizePatterns)
|
||||
}
|
||||
|
||||
// matchesAny checks if msg matches any of the patterns.
|
||||
func matchesAny(msg string, patterns []errorPattern) bool {
|
||||
for _, p := range patterns {
|
||||
if p.regex != nil {
|
||||
if p.regex.MatchString(msg) {
|
||||
return true
|
||||
}
|
||||
} else if p.substring != "" {
|
||||
if strings.Contains(msg, p.substring) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// parseDigits converts a string of digits to an int.
|
||||
func parseDigits(s string) int {
|
||||
n := 0
|
||||
for _, c := range s {
|
||||
if c >= '0' && c <= '9' {
|
||||
n = n*10 + int(c-'0')
|
||||
}
|
||||
}
|
||||
return n
|
||||
}
|
||||
@@ -0,0 +1,337 @@
|
||||
package providers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestClassifyError_Nil(t *testing.T) {
|
||||
result := ClassifyError(nil, "openai", "gpt-4")
|
||||
if result != nil {
|
||||
t.Errorf("expected nil for nil error, got %+v", result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestClassifyError_ContextCanceled(t *testing.T) {
|
||||
result := ClassifyError(context.Canceled, "openai", "gpt-4")
|
||||
if result != nil {
|
||||
t.Errorf("expected nil for context.Canceled (user abort), got %+v", result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestClassifyError_ContextDeadlineExceeded(t *testing.T) {
|
||||
result := ClassifyError(context.DeadlineExceeded, "openai", "gpt-4")
|
||||
if result == nil {
|
||||
t.Fatal("expected non-nil for deadline exceeded")
|
||||
}
|
||||
if result.Reason != FailoverTimeout {
|
||||
t.Errorf("reason = %q, want timeout", result.Reason)
|
||||
}
|
||||
}
|
||||
|
||||
func TestClassifyError_StatusCodes(t *testing.T) {
|
||||
tests := []struct {
|
||||
status int
|
||||
reason FailoverReason
|
||||
}{
|
||||
{401, FailoverAuth},
|
||||
{403, FailoverAuth},
|
||||
{402, FailoverBilling},
|
||||
{408, FailoverTimeout},
|
||||
{429, FailoverRateLimit},
|
||||
{400, FailoverFormat},
|
||||
{500, FailoverTimeout},
|
||||
{502, FailoverTimeout},
|
||||
{503, FailoverTimeout},
|
||||
{521, FailoverTimeout},
|
||||
{522, FailoverTimeout},
|
||||
{523, FailoverTimeout},
|
||||
{524, FailoverTimeout},
|
||||
{529, FailoverTimeout},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
err := fmt.Errorf("API error: status: %d something went wrong", tt.status)
|
||||
result := ClassifyError(err, "test", "model")
|
||||
if result == nil {
|
||||
t.Errorf("status %d: expected non-nil", tt.status)
|
||||
continue
|
||||
}
|
||||
if result.Reason != tt.reason {
|
||||
t.Errorf("status %d: reason = %q, want %q", tt.status, result.Reason, tt.reason)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestClassifyError_RateLimitPatterns(t *testing.T) {
|
||||
patterns := []string{
|
||||
"rate limit exceeded",
|
||||
"rate_limit reached",
|
||||
"too many requests",
|
||||
"exceeded your current quota",
|
||||
"resource has been exhausted",
|
||||
"resource_exhausted",
|
||||
"quota exceeded",
|
||||
"usage limit reached",
|
||||
}
|
||||
|
||||
for _, msg := range patterns {
|
||||
err := errors.New(msg)
|
||||
result := ClassifyError(err, "openai", "gpt-4")
|
||||
if result == nil {
|
||||
t.Errorf("pattern %q: expected non-nil", msg)
|
||||
continue
|
||||
}
|
||||
if result.Reason != FailoverRateLimit {
|
||||
t.Errorf("pattern %q: reason = %q, want rate_limit", msg, result.Reason)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestClassifyError_OverloadedPatterns(t *testing.T) {
|
||||
patterns := []string{
|
||||
"overloaded_error",
|
||||
`{"type": "overloaded_error"}`,
|
||||
"server is overloaded",
|
||||
}
|
||||
|
||||
for _, msg := range patterns {
|
||||
err := errors.New(msg)
|
||||
result := ClassifyError(err, "anthropic", "claude")
|
||||
if result == nil {
|
||||
t.Errorf("pattern %q: expected non-nil", msg)
|
||||
continue
|
||||
}
|
||||
// Overloaded is treated as rate_limit
|
||||
if result.Reason != FailoverRateLimit {
|
||||
t.Errorf("pattern %q: reason = %q, want rate_limit", msg, result.Reason)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestClassifyError_BillingPatterns(t *testing.T) {
|
||||
patterns := []string{
|
||||
"payment required",
|
||||
"insufficient credits",
|
||||
"credit balance too low",
|
||||
"plans & billing page",
|
||||
"insufficient balance",
|
||||
}
|
||||
|
||||
for _, msg := range patterns {
|
||||
err := errors.New(msg)
|
||||
result := ClassifyError(err, "openai", "gpt-4")
|
||||
if result == nil {
|
||||
t.Errorf("pattern %q: expected non-nil", msg)
|
||||
continue
|
||||
}
|
||||
if result.Reason != FailoverBilling {
|
||||
t.Errorf("pattern %q: reason = %q, want billing", msg, result.Reason)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestClassifyError_TimeoutPatterns(t *testing.T) {
|
||||
patterns := []string{
|
||||
"request timeout",
|
||||
"connection timed out",
|
||||
"deadline exceeded",
|
||||
"context deadline exceeded",
|
||||
}
|
||||
|
||||
for _, msg := range patterns {
|
||||
err := errors.New(msg)
|
||||
result := ClassifyError(err, "openai", "gpt-4")
|
||||
if result == nil {
|
||||
t.Errorf("pattern %q: expected non-nil", msg)
|
||||
continue
|
||||
}
|
||||
if result.Reason != FailoverTimeout {
|
||||
t.Errorf("pattern %q: reason = %q, want timeout", msg, result.Reason)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestClassifyError_AuthPatterns(t *testing.T) {
|
||||
patterns := []string{
|
||||
"invalid api key",
|
||||
"invalid_api_key",
|
||||
"incorrect api key",
|
||||
"invalid token",
|
||||
"authentication failed",
|
||||
"re-authenticate",
|
||||
"oauth token refresh failed",
|
||||
"unauthorized access",
|
||||
"forbidden",
|
||||
"access denied",
|
||||
"expired",
|
||||
"token has expired",
|
||||
"no credentials found",
|
||||
"no api key found",
|
||||
}
|
||||
|
||||
for _, msg := range patterns {
|
||||
err := errors.New(msg)
|
||||
result := ClassifyError(err, "openai", "gpt-4")
|
||||
if result == nil {
|
||||
t.Errorf("pattern %q: expected non-nil", msg)
|
||||
continue
|
||||
}
|
||||
if result.Reason != FailoverAuth {
|
||||
t.Errorf("pattern %q: reason = %q, want auth", msg, result.Reason)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestClassifyError_FormatPatterns(t *testing.T) {
|
||||
patterns := []string{
|
||||
"string should match pattern",
|
||||
"tool_use.id is required",
|
||||
"invalid tool_use_id",
|
||||
"messages.1.content.1.tool_use.id must be valid",
|
||||
"invalid request format",
|
||||
}
|
||||
|
||||
for _, msg := range patterns {
|
||||
err := errors.New(msg)
|
||||
result := ClassifyError(err, "anthropic", "claude")
|
||||
if result == nil {
|
||||
t.Errorf("pattern %q: expected non-nil", msg)
|
||||
continue
|
||||
}
|
||||
if result.Reason != FailoverFormat {
|
||||
t.Errorf("pattern %q: reason = %q, want format", msg, result.Reason)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestClassifyError_ImageDimensionError(t *testing.T) {
|
||||
err := errors.New("image dimensions exceed max allowed 2048x2048")
|
||||
result := ClassifyError(err, "openai", "gpt-4o")
|
||||
if result == nil {
|
||||
t.Fatal("expected non-nil for image dimension error")
|
||||
}
|
||||
if result.Reason != FailoverFormat {
|
||||
t.Errorf("reason = %q, want format", result.Reason)
|
||||
}
|
||||
if result.IsRetriable() {
|
||||
t.Error("image dimension error should not be retriable")
|
||||
}
|
||||
}
|
||||
|
||||
func TestClassifyError_ImageSizeError(t *testing.T) {
|
||||
err := errors.New("image exceeds 20 mb limit")
|
||||
result := ClassifyError(err, "openai", "gpt-4o")
|
||||
if result == nil {
|
||||
t.Fatal("expected non-nil for image size error")
|
||||
}
|
||||
if result.Reason != FailoverFormat {
|
||||
t.Errorf("reason = %q, want format", result.Reason)
|
||||
}
|
||||
}
|
||||
|
||||
func TestClassifyError_UnknownError(t *testing.T) {
|
||||
err := errors.New("some completely random error")
|
||||
result := ClassifyError(err, "openai", "gpt-4")
|
||||
if result != nil {
|
||||
t.Errorf("expected nil for unknown error, got %+v", result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestClassifyError_ProviderModelPropagation(t *testing.T) {
|
||||
err := errors.New("rate limit exceeded")
|
||||
result := ClassifyError(err, "my-provider", "my-model")
|
||||
if result == nil {
|
||||
t.Fatal("expected non-nil")
|
||||
}
|
||||
if result.Provider != "my-provider" {
|
||||
t.Errorf("provider = %q, want my-provider", result.Provider)
|
||||
}
|
||||
if result.Model != "my-model" {
|
||||
t.Errorf("model = %q, want my-model", result.Model)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFailoverError_IsRetriable(t *testing.T) {
|
||||
tests := []struct {
|
||||
reason FailoverReason
|
||||
retriable bool
|
||||
}{
|
||||
{FailoverAuth, true},
|
||||
{FailoverRateLimit, true},
|
||||
{FailoverBilling, true},
|
||||
{FailoverTimeout, true},
|
||||
{FailoverOverloaded, true},
|
||||
{FailoverFormat, false},
|
||||
{FailoverUnknown, true},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
fe := &FailoverError{Reason: tt.reason}
|
||||
if fe.IsRetriable() != tt.retriable {
|
||||
t.Errorf("IsRetriable(%q) = %v, want %v", tt.reason, fe.IsRetriable(), tt.retriable)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestFailoverError_ErrorString(t *testing.T) {
|
||||
fe := &FailoverError{
|
||||
Reason: FailoverRateLimit,
|
||||
Provider: "openai",
|
||||
Model: "gpt-4",
|
||||
Status: 429,
|
||||
Wrapped: errors.New("too many requests"),
|
||||
}
|
||||
s := fe.Error()
|
||||
if s == "" {
|
||||
t.Error("expected non-empty error string")
|
||||
}
|
||||
}
|
||||
|
||||
func TestFailoverError_Unwrap(t *testing.T) {
|
||||
inner := errors.New("inner error")
|
||||
fe := &FailoverError{Reason: FailoverTimeout, Wrapped: inner}
|
||||
if fe.Unwrap() != inner {
|
||||
t.Error("Unwrap should return wrapped error")
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractHTTPStatus(t *testing.T) {
|
||||
tests := []struct {
|
||||
msg string
|
||||
want int
|
||||
}{
|
||||
{"status: 429 rate limited", 429},
|
||||
{"status 401 unauthorized", 401},
|
||||
{"HTTP/1.1 502 Bad Gateway", 502},
|
||||
{"no status code here", 0},
|
||||
{"random number 12345", 0},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
got := extractHTTPStatus(tt.msg)
|
||||
if got != tt.want {
|
||||
t.Errorf("extractHTTPStatus(%q) = %d, want %d", tt.msg, got, tt.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsImageDimensionError(t *testing.T) {
|
||||
if !IsImageDimensionError("image dimensions exceed max 4096x4096") {
|
||||
t.Error("should match image dimensions exceed max")
|
||||
}
|
||||
if IsImageDimensionError("normal error message") {
|
||||
t.Error("should not match normal error")
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsImageSizeError(t *testing.T) {
|
||||
if !IsImageSizeError("image exceeds 20 mb") {
|
||||
t.Error("should match image exceeds mb")
|
||||
}
|
||||
if IsImageSizeError("normal error message") {
|
||||
t.Error("should not match normal error")
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,283 @@
|
||||
package providers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// FallbackChain orchestrates model fallback across multiple candidates.
|
||||
type FallbackChain struct {
|
||||
cooldown *CooldownTracker
|
||||
}
|
||||
|
||||
// FallbackCandidate represents one model/provider to try.
|
||||
type FallbackCandidate struct {
|
||||
Provider string
|
||||
Model string
|
||||
}
|
||||
|
||||
// FallbackResult contains the successful response and metadata about all attempts.
|
||||
type FallbackResult struct {
|
||||
Response *LLMResponse
|
||||
Provider string
|
||||
Model string
|
||||
Attempts []FallbackAttempt
|
||||
}
|
||||
|
||||
// FallbackAttempt records one attempt in the fallback chain.
|
||||
type FallbackAttempt struct {
|
||||
Provider string
|
||||
Model string
|
||||
Error error
|
||||
Reason FailoverReason
|
||||
Duration time.Duration
|
||||
Skipped bool // true if skipped due to cooldown
|
||||
}
|
||||
|
||||
// NewFallbackChain creates a new fallback chain with the given cooldown tracker.
|
||||
func NewFallbackChain(cooldown *CooldownTracker) *FallbackChain {
|
||||
return &FallbackChain{cooldown: cooldown}
|
||||
}
|
||||
|
||||
// ResolveCandidates parses model config into a deduplicated candidate list.
|
||||
func ResolveCandidates(cfg ModelConfig, defaultProvider string) []FallbackCandidate {
|
||||
seen := make(map[string]bool)
|
||||
var candidates []FallbackCandidate
|
||||
|
||||
addCandidate := func(raw string) {
|
||||
ref := ParseModelRef(raw, defaultProvider)
|
||||
if ref == nil {
|
||||
return
|
||||
}
|
||||
key := ModelKey(ref.Provider, ref.Model)
|
||||
if seen[key] {
|
||||
return
|
||||
}
|
||||
seen[key] = true
|
||||
candidates = append(candidates, FallbackCandidate{
|
||||
Provider: ref.Provider,
|
||||
Model: ref.Model,
|
||||
})
|
||||
}
|
||||
|
||||
// Primary first.
|
||||
addCandidate(cfg.Primary)
|
||||
|
||||
// Then fallbacks.
|
||||
for _, fb := range cfg.Fallbacks {
|
||||
addCandidate(fb)
|
||||
}
|
||||
|
||||
return candidates
|
||||
}
|
||||
|
||||
// Execute runs the fallback chain for text/chat requests.
|
||||
// It tries each candidate in order, respecting cooldowns and error classification.
|
||||
//
|
||||
// Behavior:
|
||||
// - Candidates in cooldown are skipped (logged as skipped attempt).
|
||||
// - context.Canceled aborts immediately (user abort, no fallback).
|
||||
// - Non-retriable errors (format) abort immediately.
|
||||
// - Retriable errors trigger fallback to next candidate.
|
||||
// - Success marks provider as good (resets cooldown).
|
||||
// - If all fail, returns aggregate error with all attempts.
|
||||
func (fc *FallbackChain) Execute(
|
||||
ctx context.Context,
|
||||
candidates []FallbackCandidate,
|
||||
run func(ctx context.Context, provider, model string) (*LLMResponse, error),
|
||||
) (*FallbackResult, error) {
|
||||
if len(candidates) == 0 {
|
||||
return nil, fmt.Errorf("fallback: no candidates configured")
|
||||
}
|
||||
|
||||
result := &FallbackResult{
|
||||
Attempts: make([]FallbackAttempt, 0, len(candidates)),
|
||||
}
|
||||
|
||||
for i, candidate := range candidates {
|
||||
// Check context before each attempt.
|
||||
if ctx.Err() == context.Canceled {
|
||||
return nil, context.Canceled
|
||||
}
|
||||
|
||||
// Check cooldown.
|
||||
if !fc.cooldown.IsAvailable(candidate.Provider) {
|
||||
remaining := fc.cooldown.CooldownRemaining(candidate.Provider)
|
||||
result.Attempts = append(result.Attempts, FallbackAttempt{
|
||||
Provider: candidate.Provider,
|
||||
Model: candidate.Model,
|
||||
Skipped: true,
|
||||
Reason: FailoverRateLimit,
|
||||
Error: fmt.Errorf("provider %s in cooldown (%s remaining)", candidate.Provider, remaining.Round(time.Second)),
|
||||
})
|
||||
continue
|
||||
}
|
||||
|
||||
// Execute the run function.
|
||||
start := time.Now()
|
||||
resp, err := run(ctx, candidate.Provider, candidate.Model)
|
||||
elapsed := time.Since(start)
|
||||
|
||||
if err == nil {
|
||||
// Success.
|
||||
fc.cooldown.MarkSuccess(candidate.Provider)
|
||||
result.Response = resp
|
||||
result.Provider = candidate.Provider
|
||||
result.Model = candidate.Model
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// Context cancellation: abort immediately, no fallback.
|
||||
if ctx.Err() == context.Canceled {
|
||||
result.Attempts = append(result.Attempts, FallbackAttempt{
|
||||
Provider: candidate.Provider,
|
||||
Model: candidate.Model,
|
||||
Error: err,
|
||||
Duration: elapsed,
|
||||
})
|
||||
return nil, context.Canceled
|
||||
}
|
||||
|
||||
// Classify the error.
|
||||
failErr := ClassifyError(err, candidate.Provider, candidate.Model)
|
||||
|
||||
if failErr == nil {
|
||||
// Unclassifiable error: do not fallback, return immediately.
|
||||
result.Attempts = append(result.Attempts, FallbackAttempt{
|
||||
Provider: candidate.Provider,
|
||||
Model: candidate.Model,
|
||||
Error: err,
|
||||
Duration: elapsed,
|
||||
})
|
||||
return nil, fmt.Errorf("fallback: unclassified error from %s/%s: %w",
|
||||
candidate.Provider, candidate.Model, err)
|
||||
}
|
||||
|
||||
// Non-retriable error: abort immediately.
|
||||
if !failErr.IsRetriable() {
|
||||
result.Attempts = append(result.Attempts, FallbackAttempt{
|
||||
Provider: candidate.Provider,
|
||||
Model: candidate.Model,
|
||||
Error: failErr,
|
||||
Reason: failErr.Reason,
|
||||
Duration: elapsed,
|
||||
})
|
||||
return nil, failErr
|
||||
}
|
||||
|
||||
// Retriable error: mark failure and continue to next candidate.
|
||||
fc.cooldown.MarkFailure(candidate.Provider, failErr.Reason)
|
||||
result.Attempts = append(result.Attempts, FallbackAttempt{
|
||||
Provider: candidate.Provider,
|
||||
Model: candidate.Model,
|
||||
Error: failErr,
|
||||
Reason: failErr.Reason,
|
||||
Duration: elapsed,
|
||||
})
|
||||
|
||||
// If this was the last candidate, return aggregate error.
|
||||
if i == len(candidates)-1 {
|
||||
return nil, &FallbackExhaustedError{Attempts: result.Attempts}
|
||||
}
|
||||
}
|
||||
|
||||
// All candidates were skipped (all in cooldown).
|
||||
return nil, &FallbackExhaustedError{Attempts: result.Attempts}
|
||||
}
|
||||
|
||||
// ExecuteImage runs the fallback chain for image/vision requests.
|
||||
// Simpler than Execute: no cooldown checks (image endpoints have different rate limits).
|
||||
// Image dimension/size errors abort immediately (non-retriable).
|
||||
func (fc *FallbackChain) ExecuteImage(
|
||||
ctx context.Context,
|
||||
candidates []FallbackCandidate,
|
||||
run func(ctx context.Context, provider, model string) (*LLMResponse, error),
|
||||
) (*FallbackResult, error) {
|
||||
if len(candidates) == 0 {
|
||||
return nil, fmt.Errorf("image fallback: no candidates configured")
|
||||
}
|
||||
|
||||
result := &FallbackResult{
|
||||
Attempts: make([]FallbackAttempt, 0, len(candidates)),
|
||||
}
|
||||
|
||||
for i, candidate := range candidates {
|
||||
if ctx.Err() == context.Canceled {
|
||||
return nil, context.Canceled
|
||||
}
|
||||
|
||||
start := time.Now()
|
||||
resp, err := run(ctx, candidate.Provider, candidate.Model)
|
||||
elapsed := time.Since(start)
|
||||
|
||||
if err == nil {
|
||||
result.Response = resp
|
||||
result.Provider = candidate.Provider
|
||||
result.Model = candidate.Model
|
||||
return result, nil
|
||||
}
|
||||
|
||||
if ctx.Err() == context.Canceled {
|
||||
result.Attempts = append(result.Attempts, FallbackAttempt{
|
||||
Provider: candidate.Provider,
|
||||
Model: candidate.Model,
|
||||
Error: err,
|
||||
Duration: elapsed,
|
||||
})
|
||||
return nil, context.Canceled
|
||||
}
|
||||
|
||||
// Image dimension/size errors are non-retriable.
|
||||
errMsg := strings.ToLower(err.Error())
|
||||
if IsImageDimensionError(errMsg) || IsImageSizeError(errMsg) {
|
||||
result.Attempts = append(result.Attempts, FallbackAttempt{
|
||||
Provider: candidate.Provider,
|
||||
Model: candidate.Model,
|
||||
Error: err,
|
||||
Reason: FailoverFormat,
|
||||
Duration: elapsed,
|
||||
})
|
||||
return nil, &FailoverError{
|
||||
Reason: FailoverFormat,
|
||||
Provider: candidate.Provider,
|
||||
Model: candidate.Model,
|
||||
Wrapped: err,
|
||||
}
|
||||
}
|
||||
|
||||
// Any other error: record and try next.
|
||||
result.Attempts = append(result.Attempts, FallbackAttempt{
|
||||
Provider: candidate.Provider,
|
||||
Model: candidate.Model,
|
||||
Error: err,
|
||||
Duration: elapsed,
|
||||
})
|
||||
|
||||
if i == len(candidates)-1 {
|
||||
return nil, &FallbackExhaustedError{Attempts: result.Attempts}
|
||||
}
|
||||
}
|
||||
|
||||
return nil, &FallbackExhaustedError{Attempts: result.Attempts}
|
||||
}
|
||||
|
||||
// FallbackExhaustedError indicates all fallback candidates were tried and failed.
|
||||
type FallbackExhaustedError struct {
|
||||
Attempts []FallbackAttempt
|
||||
}
|
||||
|
||||
func (e *FallbackExhaustedError) Error() string {
|
||||
var sb strings.Builder
|
||||
sb.WriteString(fmt.Sprintf("fallback: all %d candidates failed:", len(e.Attempts)))
|
||||
for i, a := range e.Attempts {
|
||||
if a.Skipped {
|
||||
sb.WriteString(fmt.Sprintf("\n [%d] %s/%s: skipped (cooldown)", i+1, a.Provider, a.Model))
|
||||
} else {
|
||||
sb.WriteString(fmt.Sprintf("\n [%d] %s/%s: %v (reason=%s, %s)",
|
||||
i+1, a.Provider, a.Model, a.Error, a.Reason, a.Duration.Round(time.Millisecond)))
|
||||
}
|
||||
}
|
||||
return sb.String()
|
||||
}
|
||||
@@ -0,0 +1,473 @@
|
||||
package providers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func makeCandidate(provider, model string) FallbackCandidate {
|
||||
return FallbackCandidate{Provider: provider, Model: model}
|
||||
}
|
||||
|
||||
func successRun(content string) func(ctx context.Context, provider, model string) (*LLMResponse, error) {
|
||||
return func(ctx context.Context, provider, model string) (*LLMResponse, error) {
|
||||
return &LLMResponse{Content: content, FinishReason: "stop"}, nil
|
||||
}
|
||||
}
|
||||
|
||||
func failRun(err error) func(ctx context.Context, provider, model string) (*LLMResponse, error) {
|
||||
return func(ctx context.Context, provider, model string) (*LLMResponse, error) {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
func TestFallback_SingleCandidate_Success(t *testing.T) {
|
||||
ct := NewCooldownTracker()
|
||||
fc := NewFallbackChain(ct)
|
||||
|
||||
candidates := []FallbackCandidate{makeCandidate("openai", "gpt-4")}
|
||||
result, err := fc.Execute(context.Background(), candidates, successRun("hello"))
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if result.Response.Content != "hello" {
|
||||
t.Errorf("content = %q, want hello", result.Response.Content)
|
||||
}
|
||||
if result.Provider != "openai" || result.Model != "gpt-4" {
|
||||
t.Errorf("provider/model = %s/%s, want openai/gpt-4", result.Provider, result.Model)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFallback_SecondCandidateSuccess(t *testing.T) {
|
||||
ct := NewCooldownTracker()
|
||||
fc := NewFallbackChain(ct)
|
||||
|
||||
candidates := []FallbackCandidate{
|
||||
makeCandidate("openai", "gpt-4"),
|
||||
makeCandidate("anthropic", "claude-opus"),
|
||||
}
|
||||
|
||||
attempt := 0
|
||||
run := func(ctx context.Context, provider, model string) (*LLMResponse, error) {
|
||||
attempt++
|
||||
if attempt == 1 {
|
||||
return nil, errors.New("rate limit exceeded")
|
||||
}
|
||||
return &LLMResponse{Content: "from claude", FinishReason: "stop"}, nil
|
||||
}
|
||||
|
||||
result, err := fc.Execute(context.Background(), candidates, run)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if result.Provider != "anthropic" {
|
||||
t.Errorf("provider = %q, want anthropic", result.Provider)
|
||||
}
|
||||
if result.Response.Content != "from claude" {
|
||||
t.Errorf("content = %q, want 'from claude'", result.Response.Content)
|
||||
}
|
||||
if len(result.Attempts) != 1 {
|
||||
t.Errorf("attempts = %d, want 1 (failed attempt recorded)", len(result.Attempts))
|
||||
}
|
||||
}
|
||||
|
||||
func TestFallback_AllFail(t *testing.T) {
|
||||
ct := NewCooldownTracker()
|
||||
fc := NewFallbackChain(ct)
|
||||
|
||||
candidates := []FallbackCandidate{
|
||||
makeCandidate("openai", "gpt-4"),
|
||||
makeCandidate("anthropic", "claude"),
|
||||
makeCandidate("groq", "llama"),
|
||||
}
|
||||
|
||||
run := func(ctx context.Context, provider, model string) (*LLMResponse, error) {
|
||||
return nil, errors.New("rate limit exceeded")
|
||||
}
|
||||
|
||||
_, err := fc.Execute(context.Background(), candidates, run)
|
||||
if err == nil {
|
||||
t.Fatal("expected error when all candidates fail")
|
||||
}
|
||||
var exhausted *FallbackExhaustedError
|
||||
if !errors.As(err, &exhausted) {
|
||||
t.Errorf("expected FallbackExhaustedError, got %T: %v", err, err)
|
||||
}
|
||||
if len(exhausted.Attempts) != 3 {
|
||||
t.Errorf("attempts = %d, want 3", len(exhausted.Attempts))
|
||||
}
|
||||
}
|
||||
|
||||
func TestFallback_ContextCanceled(t *testing.T) {
|
||||
ct := NewCooldownTracker()
|
||||
fc := NewFallbackChain(ct)
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
candidates := []FallbackCandidate{
|
||||
makeCandidate("openai", "gpt-4"),
|
||||
makeCandidate("anthropic", "claude"),
|
||||
}
|
||||
|
||||
attempt := 0
|
||||
run := func(ctx context.Context, provider, model string) (*LLMResponse, error) {
|
||||
attempt++
|
||||
if attempt == 1 {
|
||||
cancel() // cancel context
|
||||
return nil, context.Canceled
|
||||
}
|
||||
t.Error("should not reach second candidate after cancel")
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
_, err := fc.Execute(ctx, candidates, run)
|
||||
if err != context.Canceled {
|
||||
t.Errorf("expected context.Canceled, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFallback_NonRetriableError(t *testing.T) {
|
||||
ct := NewCooldownTracker()
|
||||
fc := NewFallbackChain(ct)
|
||||
|
||||
candidates := []FallbackCandidate{
|
||||
makeCandidate("openai", "gpt-4"),
|
||||
makeCandidate("anthropic", "claude"),
|
||||
}
|
||||
|
||||
attempt := 0
|
||||
run := func(ctx context.Context, provider, model string) (*LLMResponse, error) {
|
||||
attempt++
|
||||
return nil, errors.New("string should match pattern")
|
||||
}
|
||||
|
||||
_, err := fc.Execute(context.Background(), candidates, run)
|
||||
if err == nil {
|
||||
t.Fatal("expected error for non-retriable")
|
||||
}
|
||||
var fe *FailoverError
|
||||
if !errors.As(err, &fe) {
|
||||
t.Fatalf("expected FailoverError, got %T", err)
|
||||
}
|
||||
if fe.Reason != FailoverFormat {
|
||||
t.Errorf("reason = %q, want format", fe.Reason)
|
||||
}
|
||||
if attempt != 1 {
|
||||
t.Errorf("attempt = %d, want 1 (non-retriable should not try next)", attempt)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFallback_CooldownSkip(t *testing.T) {
|
||||
now := time.Now()
|
||||
ct, _ := newTestTracker(now)
|
||||
fc := NewFallbackChain(ct)
|
||||
|
||||
// Put openai in cooldown
|
||||
ct.MarkFailure("openai", FailoverRateLimit)
|
||||
|
||||
candidates := []FallbackCandidate{
|
||||
makeCandidate("openai", "gpt-4"),
|
||||
makeCandidate("anthropic", "claude"),
|
||||
}
|
||||
|
||||
run := func(ctx context.Context, provider, model string) (*LLMResponse, error) {
|
||||
if provider == "openai" {
|
||||
t.Error("should not call openai (in cooldown)")
|
||||
}
|
||||
return &LLMResponse{Content: "claude response", FinishReason: "stop"}, nil
|
||||
}
|
||||
|
||||
result, err := fc.Execute(context.Background(), candidates, run)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if result.Provider != "anthropic" {
|
||||
t.Errorf("provider = %q, want anthropic", result.Provider)
|
||||
}
|
||||
// Should have 1 skipped attempt
|
||||
skipped := 0
|
||||
for _, a := range result.Attempts {
|
||||
if a.Skipped {
|
||||
skipped++
|
||||
}
|
||||
}
|
||||
if skipped != 1 {
|
||||
t.Errorf("skipped = %d, want 1", skipped)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFallback_AllInCooldown(t *testing.T) {
|
||||
ct := NewCooldownTracker()
|
||||
fc := NewFallbackChain(ct)
|
||||
|
||||
// Put all providers in cooldown
|
||||
ct.MarkFailure("openai", FailoverRateLimit)
|
||||
ct.MarkFailure("anthropic", FailoverBilling)
|
||||
|
||||
candidates := []FallbackCandidate{
|
||||
makeCandidate("openai", "gpt-4"),
|
||||
makeCandidate("anthropic", "claude"),
|
||||
}
|
||||
|
||||
_, err := fc.Execute(context.Background(), candidates,
|
||||
func(ctx context.Context, provider, model string) (*LLMResponse, error) {
|
||||
t.Error("should not call any provider (all in cooldown)")
|
||||
return nil, nil
|
||||
})
|
||||
|
||||
if err == nil {
|
||||
t.Fatal("expected error when all in cooldown")
|
||||
}
|
||||
var exhausted *FallbackExhaustedError
|
||||
if !errors.As(err, &exhausted) {
|
||||
t.Fatalf("expected FallbackExhaustedError, got %T", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFallback_NoCandidates(t *testing.T) {
|
||||
ct := NewCooldownTracker()
|
||||
fc := NewFallbackChain(ct)
|
||||
|
||||
_, err := fc.Execute(context.Background(), nil, successRun("ok"))
|
||||
if err == nil {
|
||||
t.Error("expected error for empty candidates")
|
||||
}
|
||||
}
|
||||
|
||||
func TestFallback_EmptyFallbacks(t *testing.T) {
|
||||
// Single primary, no fallbacks: should work like direct call
|
||||
ct := NewCooldownTracker()
|
||||
fc := NewFallbackChain(ct)
|
||||
|
||||
candidates := []FallbackCandidate{makeCandidate("openai", "gpt-4")}
|
||||
result, err := fc.Execute(context.Background(), candidates, successRun("ok"))
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if result.Response.Content != "ok" {
|
||||
t.Error("expected success with single candidate")
|
||||
}
|
||||
}
|
||||
|
||||
func TestFallback_UnclassifiedError(t *testing.T) {
|
||||
ct := NewCooldownTracker()
|
||||
fc := NewFallbackChain(ct)
|
||||
|
||||
candidates := []FallbackCandidate{
|
||||
makeCandidate("openai", "gpt-4"),
|
||||
makeCandidate("anthropic", "claude"),
|
||||
}
|
||||
|
||||
attempt := 0
|
||||
run := func(ctx context.Context, provider, model string) (*LLMResponse, error) {
|
||||
attempt++
|
||||
return nil, errors.New("completely unknown internal error")
|
||||
}
|
||||
|
||||
_, err := fc.Execute(context.Background(), candidates, run)
|
||||
if err == nil {
|
||||
t.Fatal("expected error for unclassified error")
|
||||
}
|
||||
if attempt != 1 {
|
||||
t.Errorf("attempt = %d, want 1 (should not fallback on unclassified)", attempt)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFallback_SuccessResetsCooldown(t *testing.T) {
|
||||
ct := NewCooldownTracker()
|
||||
fc := NewFallbackChain(ct)
|
||||
|
||||
candidates := []FallbackCandidate{makeCandidate("openai", "gpt-4")}
|
||||
|
||||
attempt := 0
|
||||
run := func(ctx context.Context, provider, model string) (*LLMResponse, error) {
|
||||
attempt++
|
||||
if attempt == 1 {
|
||||
ct.MarkFailure("openai", FailoverRateLimit) // simulate failure tracked elsewhere
|
||||
}
|
||||
return &LLMResponse{Content: "ok", FinishReason: "stop"}, nil
|
||||
}
|
||||
|
||||
_, err := fc.Execute(context.Background(), candidates, run)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if !ct.IsAvailable("openai") {
|
||||
t.Error("success should reset cooldown")
|
||||
}
|
||||
}
|
||||
|
||||
// --- Image Fallback Tests ---
|
||||
|
||||
func TestImageFallback_Success(t *testing.T) {
|
||||
ct := NewCooldownTracker()
|
||||
fc := NewFallbackChain(ct)
|
||||
|
||||
candidates := []FallbackCandidate{makeCandidate("openai", "gpt-4o")}
|
||||
result, err := fc.ExecuteImage(context.Background(), candidates, successRun("image result"))
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if result.Response.Content != "image result" {
|
||||
t.Error("expected image result")
|
||||
}
|
||||
}
|
||||
|
||||
func TestImageFallback_DimensionError(t *testing.T) {
|
||||
ct := NewCooldownTracker()
|
||||
fc := NewFallbackChain(ct)
|
||||
|
||||
candidates := []FallbackCandidate{
|
||||
makeCandidate("openai", "gpt-4o"),
|
||||
makeCandidate("anthropic", "claude"),
|
||||
}
|
||||
|
||||
attempt := 0
|
||||
run := func(ctx context.Context, provider, model string) (*LLMResponse, error) {
|
||||
attempt++
|
||||
return nil, errors.New("image dimensions exceed max 4096x4096")
|
||||
}
|
||||
|
||||
_, err := fc.ExecuteImage(context.Background(), candidates, run)
|
||||
if err == nil {
|
||||
t.Fatal("expected error for image dimension error")
|
||||
}
|
||||
if attempt != 1 {
|
||||
t.Errorf("attempt = %d, want 1 (image dimension error should not retry)", attempt)
|
||||
}
|
||||
}
|
||||
|
||||
func TestImageFallback_SizeError(t *testing.T) {
|
||||
ct := NewCooldownTracker()
|
||||
fc := NewFallbackChain(ct)
|
||||
|
||||
candidates := []FallbackCandidate{
|
||||
makeCandidate("openai", "gpt-4o"),
|
||||
makeCandidate("anthropic", "claude"),
|
||||
}
|
||||
|
||||
attempt := 0
|
||||
run := func(ctx context.Context, provider, model string) (*LLMResponse, error) {
|
||||
attempt++
|
||||
return nil, errors.New("image exceeds 20 mb")
|
||||
}
|
||||
|
||||
_, err := fc.ExecuteImage(context.Background(), candidates, run)
|
||||
if err == nil {
|
||||
t.Fatal("expected error for image size error")
|
||||
}
|
||||
if attempt != 1 {
|
||||
t.Errorf("attempt = %d, want 1 (image size error should not retry)", attempt)
|
||||
}
|
||||
}
|
||||
|
||||
func TestImageFallback_RetryOnOtherErrors(t *testing.T) {
|
||||
ct := NewCooldownTracker()
|
||||
fc := NewFallbackChain(ct)
|
||||
|
||||
candidates := []FallbackCandidate{
|
||||
makeCandidate("openai", "gpt-4o"),
|
||||
makeCandidate("anthropic", "claude-sonnet"),
|
||||
}
|
||||
|
||||
attempt := 0
|
||||
run := func(ctx context.Context, provider, model string) (*LLMResponse, error) {
|
||||
attempt++
|
||||
if attempt == 1 {
|
||||
return nil, errors.New("rate limit exceeded")
|
||||
}
|
||||
return &LLMResponse{Content: "image ok", FinishReason: "stop"}, nil
|
||||
}
|
||||
|
||||
result, err := fc.ExecuteImage(context.Background(), candidates, run)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if result.Provider != "anthropic" {
|
||||
t.Errorf("provider = %q, want anthropic", result.Provider)
|
||||
}
|
||||
}
|
||||
|
||||
func TestImageFallback_NoCandidates(t *testing.T) {
|
||||
ct := NewCooldownTracker()
|
||||
fc := NewFallbackChain(ct)
|
||||
|
||||
_, err := fc.ExecuteImage(context.Background(), nil, successRun("ok"))
|
||||
if err == nil {
|
||||
t.Error("expected error for empty candidates")
|
||||
}
|
||||
}
|
||||
|
||||
// --- ResolveCandidates Tests ---
|
||||
|
||||
func TestResolveCandidates_Simple(t *testing.T) {
|
||||
cfg := ModelConfig{
|
||||
Primary: "gpt-4",
|
||||
Fallbacks: []string{"anthropic/claude-opus", "groq/llama-3"},
|
||||
}
|
||||
|
||||
candidates := ResolveCandidates(cfg, "openai")
|
||||
if len(candidates) != 3 {
|
||||
t.Fatalf("candidates = %d, want 3", len(candidates))
|
||||
}
|
||||
|
||||
if candidates[0].Provider != "openai" || candidates[0].Model != "gpt-4" {
|
||||
t.Errorf("candidate[0] = %s/%s, want openai/gpt-4", candidates[0].Provider, candidates[0].Model)
|
||||
}
|
||||
if candidates[1].Provider != "anthropic" || candidates[1].Model != "claude-opus" {
|
||||
t.Errorf("candidate[1] = %s/%s, want anthropic/claude-opus", candidates[1].Provider, candidates[1].Model)
|
||||
}
|
||||
if candidates[2].Provider != "groq" || candidates[2].Model != "llama-3" {
|
||||
t.Errorf("candidate[2] = %s/%s, want groq/llama-3", candidates[2].Provider, candidates[2].Model)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveCandidates_Deduplication(t *testing.T) {
|
||||
cfg := ModelConfig{
|
||||
Primary: "openai/gpt-4",
|
||||
Fallbacks: []string{"openai/gpt-4", "anthropic/claude"},
|
||||
}
|
||||
|
||||
candidates := ResolveCandidates(cfg, "default")
|
||||
if len(candidates) != 2 {
|
||||
t.Errorf("candidates = %d, want 2 (duplicate removed)", len(candidates))
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveCandidates_EmptyFallbacks(t *testing.T) {
|
||||
cfg := ModelConfig{
|
||||
Primary: "gpt-4",
|
||||
Fallbacks: nil,
|
||||
}
|
||||
|
||||
candidates := ResolveCandidates(cfg, "openai")
|
||||
if len(candidates) != 1 {
|
||||
t.Errorf("candidates = %d, want 1", len(candidates))
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveCandidates_EmptyPrimary(t *testing.T) {
|
||||
cfg := ModelConfig{
|
||||
Primary: "",
|
||||
Fallbacks: []string{"anthropic/claude"},
|
||||
}
|
||||
|
||||
candidates := ResolveCandidates(cfg, "openai")
|
||||
if len(candidates) != 1 {
|
||||
t.Errorf("candidates = %d, want 1", len(candidates))
|
||||
}
|
||||
}
|
||||
|
||||
func TestFallbackExhaustedError_Message(t *testing.T) {
|
||||
e := &FallbackExhaustedError{
|
||||
Attempts: []FallbackAttempt{
|
||||
{Provider: "openai", Model: "gpt-4", Error: errors.New("rate limited"), Reason: FailoverRateLimit, Duration: 500 * time.Millisecond},
|
||||
{Provider: "anthropic", Model: "claude", Skipped: true},
|
||||
},
|
||||
}
|
||||
msg := e.Error()
|
||||
if msg == "" {
|
||||
t.Error("expected non-empty error message")
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,64 @@
|
||||
package providers
|
||||
|
||||
import "strings"
|
||||
|
||||
// ModelRef represents a parsed model reference with provider and model name.
|
||||
type ModelRef struct {
|
||||
Provider string
|
||||
Model string
|
||||
}
|
||||
|
||||
// ParseModelRef parses "anthropic/claude-opus" into {Provider: "anthropic", Model: "claude-opus"}.
|
||||
// If no slash present, uses defaultProvider.
|
||||
// Returns nil for empty input.
|
||||
func ParseModelRef(raw string, defaultProvider string) *ModelRef {
|
||||
raw = strings.TrimSpace(raw)
|
||||
if raw == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
if idx := strings.Index(raw, "/"); idx > 0 {
|
||||
provider := NormalizeProvider(raw[:idx])
|
||||
model := strings.TrimSpace(raw[idx+1:])
|
||||
if model == "" {
|
||||
return nil
|
||||
}
|
||||
return &ModelRef{Provider: provider, Model: model}
|
||||
}
|
||||
|
||||
return &ModelRef{
|
||||
Provider: NormalizeProvider(defaultProvider),
|
||||
Model: raw,
|
||||
}
|
||||
}
|
||||
|
||||
// NormalizeProvider normalizes provider identifiers to canonical form.
|
||||
func NormalizeProvider(provider string) string {
|
||||
p := strings.ToLower(strings.TrimSpace(provider))
|
||||
|
||||
switch p {
|
||||
case "z.ai", "z-ai":
|
||||
return "zai"
|
||||
case "opencode-zen":
|
||||
return "opencode"
|
||||
case "qwen":
|
||||
return "qwen-portal"
|
||||
case "kimi-code":
|
||||
return "kimi-coding"
|
||||
case "gpt":
|
||||
return "openai"
|
||||
case "claude":
|
||||
return "anthropic"
|
||||
case "glm":
|
||||
return "zhipu"
|
||||
case "google":
|
||||
return "gemini"
|
||||
}
|
||||
|
||||
return p
|
||||
}
|
||||
|
||||
// ModelKey returns a canonical "provider/model" key for deduplication.
|
||||
func ModelKey(provider, model string) string {
|
||||
return NormalizeProvider(provider) + "/" + strings.ToLower(strings.TrimSpace(model))
|
||||
}
|
||||
@@ -0,0 +1,125 @@
|
||||
package providers
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestParseModelRef_WithSlash(t *testing.T) {
|
||||
ref := ParseModelRef("anthropic/claude-opus", "openai")
|
||||
if ref == nil {
|
||||
t.Fatal("expected non-nil ref")
|
||||
}
|
||||
if ref.Provider != "anthropic" {
|
||||
t.Errorf("provider = %q, want anthropic", ref.Provider)
|
||||
}
|
||||
if ref.Model != "claude-opus" {
|
||||
t.Errorf("model = %q, want claude-opus", ref.Model)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseModelRef_WithoutSlash(t *testing.T) {
|
||||
ref := ParseModelRef("gpt-4", "openai")
|
||||
if ref == nil {
|
||||
t.Fatal("expected non-nil ref")
|
||||
}
|
||||
if ref.Provider != "openai" {
|
||||
t.Errorf("provider = %q, want openai", ref.Provider)
|
||||
}
|
||||
if ref.Model != "gpt-4" {
|
||||
t.Errorf("model = %q, want gpt-4", ref.Model)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseModelRef_Empty(t *testing.T) {
|
||||
ref := ParseModelRef("", "openai")
|
||||
if ref != nil {
|
||||
t.Errorf("expected nil for empty string, got %+v", ref)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseModelRef_EmptyModelAfterSlash(t *testing.T) {
|
||||
ref := ParseModelRef("openai/", "default")
|
||||
if ref != nil {
|
||||
t.Errorf("expected nil for empty model, got %+v", ref)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseModelRef_WhitespaceHandling(t *testing.T) {
|
||||
ref := ParseModelRef(" anthropic / claude-opus ", "openai")
|
||||
if ref == nil {
|
||||
t.Fatal("expected non-nil ref")
|
||||
}
|
||||
if ref.Provider != "anthropic" {
|
||||
t.Errorf("provider = %q, want anthropic", ref.Provider)
|
||||
}
|
||||
if ref.Model != "claude-opus" {
|
||||
t.Errorf("model = %q, want claude-opus", ref.Model)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeProvider(t *testing.T) {
|
||||
tests := []struct {
|
||||
input string
|
||||
want string
|
||||
}{
|
||||
{"OpenAI", "openai"},
|
||||
{"ANTHROPIC", "anthropic"},
|
||||
{"z.ai", "zai"},
|
||||
{"z-ai", "zai"},
|
||||
{"Z.AI", "zai"},
|
||||
{"opencode-zen", "opencode"},
|
||||
{"qwen", "qwen-portal"},
|
||||
{"kimi-code", "kimi-coding"},
|
||||
{"gpt", "openai"},
|
||||
{"claude", "anthropic"},
|
||||
{"glm", "zhipu"},
|
||||
{"google", "gemini"},
|
||||
{"groq", "groq"},
|
||||
{"", ""},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
got := NormalizeProvider(tt.input)
|
||||
if got != tt.want {
|
||||
t.Errorf("NormalizeProvider(%q) = %q, want %q", tt.input, got, tt.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestModelKey(t *testing.T) {
|
||||
tests := []struct {
|
||||
provider string
|
||||
model string
|
||||
want string
|
||||
}{
|
||||
{"openai", "gpt-4", "openai/gpt-4"},
|
||||
{"Anthropic", "Claude-Opus", "anthropic/claude-opus"},
|
||||
{"claude", "sonnet", "anthropic/sonnet"},
|
||||
{"z.ai", "Model-X", "zai/model-x"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
got := ModelKey(tt.provider, tt.model)
|
||||
if got != tt.want {
|
||||
t.Errorf("ModelKey(%q, %q) = %q, want %q", tt.provider, tt.model, got, tt.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseModelRef_ProviderNormalization(t *testing.T) {
|
||||
ref := ParseModelRef("Z.AI/model-x", "default")
|
||||
if ref == nil {
|
||||
t.Fatal("expected non-nil ref")
|
||||
}
|
||||
if ref.Provider != "zai" {
|
||||
t.Errorf("provider = %q, want zai", ref.Provider)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseModelRef_DefaultProviderNormalization(t *testing.T) {
|
||||
ref := ParseModelRef("gpt-4o", "GPT")
|
||||
if ref == nil {
|
||||
t.Fatal("expected non-nil ref")
|
||||
}
|
||||
if ref.Provider != "openai" {
|
||||
t.Errorf("provider = %q, want openai (normalized from GPT)", ref.Provider)
|
||||
}
|
||||
}
|
||||
@@ -2,6 +2,7 @@ package providers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/providers/protocoltypes"
|
||||
)
|
||||
@@ -18,3 +19,46 @@ type LLMProvider interface {
|
||||
Chat(ctx context.Context, messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) (*LLMResponse, error)
|
||||
GetDefaultModel() string
|
||||
}
|
||||
|
||||
// FailoverReason classifies why an LLM request failed for fallback decisions.
|
||||
type FailoverReason string
|
||||
|
||||
const (
|
||||
FailoverAuth FailoverReason = "auth"
|
||||
FailoverRateLimit FailoverReason = "rate_limit"
|
||||
FailoverBilling FailoverReason = "billing"
|
||||
FailoverTimeout FailoverReason = "timeout"
|
||||
FailoverFormat FailoverReason = "format"
|
||||
FailoverOverloaded FailoverReason = "overloaded"
|
||||
FailoverUnknown FailoverReason = "unknown"
|
||||
)
|
||||
|
||||
// FailoverError wraps an LLM provider error with classification metadata.
|
||||
type FailoverError struct {
|
||||
Reason FailoverReason
|
||||
Provider string
|
||||
Model string
|
||||
Status int
|
||||
Wrapped error
|
||||
}
|
||||
|
||||
func (e *FailoverError) Error() string {
|
||||
return fmt.Sprintf("failover(%s): provider=%s model=%s status=%d: %v",
|
||||
e.Reason, e.Provider, e.Model, e.Status, e.Wrapped)
|
||||
}
|
||||
|
||||
func (e *FailoverError) Unwrap() error {
|
||||
return e.Wrapped
|
||||
}
|
||||
|
||||
// IsRetriable returns true if this error should trigger fallback to next candidate.
|
||||
// Non-retriable: Format errors (bad request structure, image dimension/size).
|
||||
func (e *FailoverError) IsRetriable() bool {
|
||||
return e.Reason != FailoverFormat
|
||||
}
|
||||
|
||||
// ModelConfig holds primary model and fallback list.
|
||||
type ModelConfig struct {
|
||||
Primary string
|
||||
Fallbacks []string
|
||||
}
|
||||
|
||||
@@ -0,0 +1,66 @@
|
||||
package routing
|
||||
|
||||
import (
|
||||
"regexp"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const (
|
||||
DefaultAgentID = "main"
|
||||
DefaultMainKey = "main"
|
||||
DefaultAccountID = "default"
|
||||
MaxAgentIDLength = 64
|
||||
)
|
||||
|
||||
var (
|
||||
validIDRe = regexp.MustCompile(`^[a-z0-9][a-z0-9_-]{0,63}$`)
|
||||
invalidCharsRe = regexp.MustCompile(`[^a-z0-9_-]+`)
|
||||
leadingDashRe = regexp.MustCompile(`^-+`)
|
||||
trailingDashRe = regexp.MustCompile(`-+$`)
|
||||
)
|
||||
|
||||
// NormalizeAgentID sanitizes an agent ID to [a-z0-9][a-z0-9_-]{0,63}.
|
||||
// Invalid characters are collapsed to "-". Leading/trailing dashes stripped.
|
||||
// Empty input returns DefaultAgentID ("main").
|
||||
func NormalizeAgentID(id string) string {
|
||||
trimmed := strings.TrimSpace(id)
|
||||
if trimmed == "" {
|
||||
return DefaultAgentID
|
||||
}
|
||||
lower := strings.ToLower(trimmed)
|
||||
if validIDRe.MatchString(lower) {
|
||||
return lower
|
||||
}
|
||||
result := invalidCharsRe.ReplaceAllString(lower, "-")
|
||||
result = leadingDashRe.ReplaceAllString(result, "")
|
||||
result = trailingDashRe.ReplaceAllString(result, "")
|
||||
if len(result) > MaxAgentIDLength {
|
||||
result = result[:MaxAgentIDLength]
|
||||
}
|
||||
if result == "" {
|
||||
return DefaultAgentID
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// NormalizeAccountID sanitizes an account ID. Empty returns DefaultAccountID.
|
||||
func NormalizeAccountID(id string) string {
|
||||
trimmed := strings.TrimSpace(id)
|
||||
if trimmed == "" {
|
||||
return DefaultAccountID
|
||||
}
|
||||
lower := strings.ToLower(trimmed)
|
||||
if validIDRe.MatchString(lower) {
|
||||
return lower
|
||||
}
|
||||
result := invalidCharsRe.ReplaceAllString(lower, "-")
|
||||
result = leadingDashRe.ReplaceAllString(result, "")
|
||||
result = trailingDashRe.ReplaceAllString(result, "")
|
||||
if len(result) > MaxAgentIDLength {
|
||||
result = result[:MaxAgentIDLength]
|
||||
}
|
||||
if result == "" {
|
||||
return DefaultAccountID
|
||||
}
|
||||
return result
|
||||
}
|
||||
@@ -0,0 +1,86 @@
|
||||
package routing
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestNormalizeAgentID_Empty(t *testing.T) {
|
||||
if got := NormalizeAgentID(""); got != DefaultAgentID {
|
||||
t.Errorf("NormalizeAgentID('') = %q, want %q", got, DefaultAgentID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeAgentID_Whitespace(t *testing.T) {
|
||||
if got := NormalizeAgentID(" "); got != DefaultAgentID {
|
||||
t.Errorf("NormalizeAgentID(' ') = %q, want %q", got, DefaultAgentID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeAgentID_Valid(t *testing.T) {
|
||||
tests := []struct {
|
||||
input, want string
|
||||
}{
|
||||
{"main", "main"},
|
||||
{"Main", "main"},
|
||||
{"SALES", "sales"},
|
||||
{"support-bot", "support-bot"},
|
||||
{"agent_1", "agent_1"},
|
||||
{"a", "a"},
|
||||
{"0test", "0test"},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
if got := NormalizeAgentID(tt.input); got != tt.want {
|
||||
t.Errorf("NormalizeAgentID(%q) = %q, want %q", tt.input, got, tt.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeAgentID_InvalidChars(t *testing.T) {
|
||||
tests := []struct {
|
||||
input, want string
|
||||
}{
|
||||
{"Hello World", "hello-world"},
|
||||
{"agent@123", "agent-123"},
|
||||
{"foo.bar.baz", "foo-bar-baz"},
|
||||
{"--leading", "leading"},
|
||||
{"--both--", "both"},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
if got := NormalizeAgentID(tt.input); got != tt.want {
|
||||
t.Errorf("NormalizeAgentID(%q) = %q, want %q", tt.input, got, tt.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeAgentID_AllInvalid(t *testing.T) {
|
||||
if got := NormalizeAgentID("@@@"); got != DefaultAgentID {
|
||||
t.Errorf("NormalizeAgentID('@@@') = %q, want %q", got, DefaultAgentID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeAgentID_TruncatesAt64(t *testing.T) {
|
||||
long := ""
|
||||
for i := 0; i < 100; i++ {
|
||||
long += "a"
|
||||
}
|
||||
got := NormalizeAgentID(long)
|
||||
if len(got) > MaxAgentIDLength {
|
||||
t.Errorf("length = %d, want <= %d", len(got), MaxAgentIDLength)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeAccountID_Empty(t *testing.T) {
|
||||
if got := NormalizeAccountID(""); got != DefaultAccountID {
|
||||
t.Errorf("NormalizeAccountID('') = %q, want %q", got, DefaultAccountID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeAccountID_Valid(t *testing.T) {
|
||||
if got := NormalizeAccountID("MyBot"); got != "mybot" {
|
||||
t.Errorf("NormalizeAccountID('MyBot') = %q, want 'mybot'", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeAccountID_InvalidChars(t *testing.T) {
|
||||
if got := NormalizeAccountID("bot@home"); got != "bot-home" {
|
||||
t.Errorf("NormalizeAccountID('bot@home') = %q, want 'bot-home'", got)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,252 @@
|
||||
package routing
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
)
|
||||
|
||||
// RouteInput contains the routing context from an inbound message.
|
||||
type RouteInput struct {
|
||||
Channel string
|
||||
AccountID string
|
||||
Peer *RoutePeer
|
||||
ParentPeer *RoutePeer
|
||||
GuildID string
|
||||
TeamID string
|
||||
}
|
||||
|
||||
// ResolvedRoute is the result of agent routing.
|
||||
type ResolvedRoute struct {
|
||||
AgentID string
|
||||
Channel string
|
||||
AccountID string
|
||||
SessionKey string
|
||||
MainSessionKey string
|
||||
MatchedBy string // "binding.peer", "binding.peer.parent", "binding.guild", "binding.team", "binding.account", "binding.channel", "default"
|
||||
}
|
||||
|
||||
// RouteResolver determines which agent handles a message based on config bindings.
|
||||
type RouteResolver struct {
|
||||
cfg *config.Config
|
||||
}
|
||||
|
||||
// NewRouteResolver creates a new route resolver.
|
||||
func NewRouteResolver(cfg *config.Config) *RouteResolver {
|
||||
return &RouteResolver{cfg: cfg}
|
||||
}
|
||||
|
||||
// ResolveRoute determines which agent handles the message and constructs session keys.
|
||||
// Implements the 7-level priority cascade:
|
||||
// peer > parent_peer > guild > team > account > channel_wildcard > default
|
||||
func (r *RouteResolver) ResolveRoute(input RouteInput) ResolvedRoute {
|
||||
channel := strings.ToLower(strings.TrimSpace(input.Channel))
|
||||
accountID := NormalizeAccountID(input.AccountID)
|
||||
peer := input.Peer
|
||||
|
||||
dmScope := DMScope(r.cfg.Session.DMScope)
|
||||
if dmScope == "" {
|
||||
dmScope = DMScopeMain
|
||||
}
|
||||
identityLinks := r.cfg.Session.IdentityLinks
|
||||
|
||||
bindings := r.filterBindings(channel, accountID)
|
||||
|
||||
choose := func(agentID string, matchedBy string) ResolvedRoute {
|
||||
resolvedAgentID := r.pickAgentID(agentID)
|
||||
sessionKey := strings.ToLower(BuildAgentPeerSessionKey(SessionKeyParams{
|
||||
AgentID: resolvedAgentID,
|
||||
Channel: channel,
|
||||
AccountID: accountID,
|
||||
Peer: peer,
|
||||
DMScope: dmScope,
|
||||
IdentityLinks: identityLinks,
|
||||
}))
|
||||
mainSessionKey := strings.ToLower(BuildAgentMainSessionKey(resolvedAgentID))
|
||||
return ResolvedRoute{
|
||||
AgentID: resolvedAgentID,
|
||||
Channel: channel,
|
||||
AccountID: accountID,
|
||||
SessionKey: sessionKey,
|
||||
MainSessionKey: mainSessionKey,
|
||||
MatchedBy: matchedBy,
|
||||
}
|
||||
}
|
||||
|
||||
// Priority 1: Peer binding
|
||||
if peer != nil && strings.TrimSpace(peer.ID) != "" {
|
||||
if match := r.findPeerMatch(bindings, peer); match != nil {
|
||||
return choose(match.AgentID, "binding.peer")
|
||||
}
|
||||
}
|
||||
|
||||
// Priority 2: Parent peer binding
|
||||
parentPeer := input.ParentPeer
|
||||
if parentPeer != nil && strings.TrimSpace(parentPeer.ID) != "" {
|
||||
if match := r.findPeerMatch(bindings, parentPeer); match != nil {
|
||||
return choose(match.AgentID, "binding.peer.parent")
|
||||
}
|
||||
}
|
||||
|
||||
// Priority 3: Guild binding
|
||||
guildID := strings.TrimSpace(input.GuildID)
|
||||
if guildID != "" {
|
||||
if match := r.findGuildMatch(bindings, guildID); match != nil {
|
||||
return choose(match.AgentID, "binding.guild")
|
||||
}
|
||||
}
|
||||
|
||||
// Priority 4: Team binding
|
||||
teamID := strings.TrimSpace(input.TeamID)
|
||||
if teamID != "" {
|
||||
if match := r.findTeamMatch(bindings, teamID); match != nil {
|
||||
return choose(match.AgentID, "binding.team")
|
||||
}
|
||||
}
|
||||
|
||||
// Priority 5: Account binding
|
||||
if match := r.findAccountMatch(bindings); match != nil {
|
||||
return choose(match.AgentID, "binding.account")
|
||||
}
|
||||
|
||||
// Priority 6: Channel wildcard binding
|
||||
if match := r.findChannelWildcardMatch(bindings); match != nil {
|
||||
return choose(match.AgentID, "binding.channel")
|
||||
}
|
||||
|
||||
// Priority 7: Default agent
|
||||
return choose(r.resolveDefaultAgentID(), "default")
|
||||
}
|
||||
|
||||
func (r *RouteResolver) filterBindings(channel, accountID string) []config.AgentBinding {
|
||||
var filtered []config.AgentBinding
|
||||
for _, b := range r.cfg.Bindings {
|
||||
matchChannel := strings.ToLower(strings.TrimSpace(b.Match.Channel))
|
||||
if matchChannel == "" || matchChannel != channel {
|
||||
continue
|
||||
}
|
||||
if !matchesAccountID(b.Match.AccountID, accountID) {
|
||||
continue
|
||||
}
|
||||
filtered = append(filtered, b)
|
||||
}
|
||||
return filtered
|
||||
}
|
||||
|
||||
func matchesAccountID(matchAccountID, actual string) bool {
|
||||
trimmed := strings.TrimSpace(matchAccountID)
|
||||
if trimmed == "" {
|
||||
return actual == DefaultAccountID
|
||||
}
|
||||
if trimmed == "*" {
|
||||
return true
|
||||
}
|
||||
return strings.ToLower(trimmed) == strings.ToLower(actual)
|
||||
}
|
||||
|
||||
func (r *RouteResolver) findPeerMatch(bindings []config.AgentBinding, peer *RoutePeer) *config.AgentBinding {
|
||||
for i := range bindings {
|
||||
b := &bindings[i]
|
||||
if b.Match.Peer == nil {
|
||||
continue
|
||||
}
|
||||
peerKind := strings.ToLower(strings.TrimSpace(b.Match.Peer.Kind))
|
||||
peerID := strings.TrimSpace(b.Match.Peer.ID)
|
||||
if peerKind == "" || peerID == "" {
|
||||
continue
|
||||
}
|
||||
if peerKind == strings.ToLower(peer.Kind) && peerID == peer.ID {
|
||||
return b
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *RouteResolver) findGuildMatch(bindings []config.AgentBinding, guildID string) *config.AgentBinding {
|
||||
for i := range bindings {
|
||||
b := &bindings[i]
|
||||
matchGuild := strings.TrimSpace(b.Match.GuildID)
|
||||
if matchGuild != "" && matchGuild == guildID {
|
||||
return &bindings[i]
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *RouteResolver) findTeamMatch(bindings []config.AgentBinding, teamID string) *config.AgentBinding {
|
||||
for i := range bindings {
|
||||
b := &bindings[i]
|
||||
matchTeam := strings.TrimSpace(b.Match.TeamID)
|
||||
if matchTeam != "" && matchTeam == teamID {
|
||||
return &bindings[i]
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *RouteResolver) findAccountMatch(bindings []config.AgentBinding) *config.AgentBinding {
|
||||
for i := range bindings {
|
||||
b := &bindings[i]
|
||||
accountID := strings.TrimSpace(b.Match.AccountID)
|
||||
if accountID == "*" {
|
||||
continue
|
||||
}
|
||||
if b.Match.Peer != nil || b.Match.GuildID != "" || b.Match.TeamID != "" {
|
||||
continue
|
||||
}
|
||||
return &bindings[i]
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *RouteResolver) findChannelWildcardMatch(bindings []config.AgentBinding) *config.AgentBinding {
|
||||
for i := range bindings {
|
||||
b := &bindings[i]
|
||||
accountID := strings.TrimSpace(b.Match.AccountID)
|
||||
if accountID != "*" {
|
||||
continue
|
||||
}
|
||||
if b.Match.Peer != nil || b.Match.GuildID != "" || b.Match.TeamID != "" {
|
||||
continue
|
||||
}
|
||||
return &bindings[i]
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *RouteResolver) pickAgentID(agentID string) string {
|
||||
trimmed := strings.TrimSpace(agentID)
|
||||
if trimmed == "" {
|
||||
return NormalizeAgentID(r.resolveDefaultAgentID())
|
||||
}
|
||||
normalized := NormalizeAgentID(trimmed)
|
||||
agents := r.cfg.Agents.List
|
||||
if len(agents) == 0 {
|
||||
return normalized
|
||||
}
|
||||
for _, a := range agents {
|
||||
if NormalizeAgentID(a.ID) == normalized {
|
||||
return normalized
|
||||
}
|
||||
}
|
||||
return NormalizeAgentID(r.resolveDefaultAgentID())
|
||||
}
|
||||
|
||||
func (r *RouteResolver) resolveDefaultAgentID() string {
|
||||
agents := r.cfg.Agents.List
|
||||
if len(agents) == 0 {
|
||||
return DefaultAgentID
|
||||
}
|
||||
for _, a := range agents {
|
||||
if a.Default {
|
||||
id := strings.TrimSpace(a.ID)
|
||||
if id != "" {
|
||||
return NormalizeAgentID(id)
|
||||
}
|
||||
}
|
||||
}
|
||||
if id := strings.TrimSpace(agents[0].ID); id != "" {
|
||||
return NormalizeAgentID(id)
|
||||
}
|
||||
return DefaultAgentID
|
||||
}
|
||||
@@ -0,0 +1,297 @@
|
||||
package routing
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
)
|
||||
|
||||
func testConfig(agents []config.AgentConfig, bindings []config.AgentBinding) *config.Config {
|
||||
return &config.Config{
|
||||
Agents: config.AgentsConfig{
|
||||
Defaults: config.AgentDefaults{
|
||||
Workspace: "/tmp/picoclaw-test",
|
||||
Model: "gpt-4",
|
||||
},
|
||||
List: agents,
|
||||
},
|
||||
Bindings: bindings,
|
||||
Session: config.SessionConfig{
|
||||
DMScope: "per-peer",
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveRoute_DefaultAgent_NoBindings(t *testing.T) {
|
||||
cfg := testConfig(nil, nil)
|
||||
r := NewRouteResolver(cfg)
|
||||
|
||||
route := r.ResolveRoute(RouteInput{
|
||||
Channel: "telegram",
|
||||
Peer: &RoutePeer{Kind: "direct", ID: "user1"},
|
||||
})
|
||||
|
||||
if route.AgentID != DefaultAgentID {
|
||||
t.Errorf("AgentID = %q, want %q", route.AgentID, DefaultAgentID)
|
||||
}
|
||||
if route.MatchedBy != "default" {
|
||||
t.Errorf("MatchedBy = %q, want 'default'", route.MatchedBy)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveRoute_PeerBinding(t *testing.T) {
|
||||
agents := []config.AgentConfig{
|
||||
{ID: "sales", Default: true},
|
||||
{ID: "support"},
|
||||
}
|
||||
bindings := []config.AgentBinding{
|
||||
{
|
||||
AgentID: "support",
|
||||
Match: config.BindingMatch{
|
||||
Channel: "telegram",
|
||||
AccountID: "*",
|
||||
Peer: &config.PeerMatch{Kind: "direct", ID: "user123"},
|
||||
},
|
||||
},
|
||||
}
|
||||
cfg := testConfig(agents, bindings)
|
||||
r := NewRouteResolver(cfg)
|
||||
|
||||
route := r.ResolveRoute(RouteInput{
|
||||
Channel: "telegram",
|
||||
Peer: &RoutePeer{Kind: "direct", ID: "user123"},
|
||||
})
|
||||
|
||||
if route.AgentID != "support" {
|
||||
t.Errorf("AgentID = %q, want 'support'", route.AgentID)
|
||||
}
|
||||
if route.MatchedBy != "binding.peer" {
|
||||
t.Errorf("MatchedBy = %q, want 'binding.peer'", route.MatchedBy)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveRoute_GuildBinding(t *testing.T) {
|
||||
agents := []config.AgentConfig{
|
||||
{ID: "general", Default: true},
|
||||
{ID: "gaming"},
|
||||
}
|
||||
bindings := []config.AgentBinding{
|
||||
{
|
||||
AgentID: "gaming",
|
||||
Match: config.BindingMatch{
|
||||
Channel: "discord",
|
||||
AccountID: "*",
|
||||
GuildID: "guild-abc",
|
||||
},
|
||||
},
|
||||
}
|
||||
cfg := testConfig(agents, bindings)
|
||||
r := NewRouteResolver(cfg)
|
||||
|
||||
route := r.ResolveRoute(RouteInput{
|
||||
Channel: "discord",
|
||||
GuildID: "guild-abc",
|
||||
Peer: &RoutePeer{Kind: "channel", ID: "ch1"},
|
||||
})
|
||||
|
||||
if route.AgentID != "gaming" {
|
||||
t.Errorf("AgentID = %q, want 'gaming'", route.AgentID)
|
||||
}
|
||||
if route.MatchedBy != "binding.guild" {
|
||||
t.Errorf("MatchedBy = %q, want 'binding.guild'", route.MatchedBy)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveRoute_TeamBinding(t *testing.T) {
|
||||
agents := []config.AgentConfig{
|
||||
{ID: "general", Default: true},
|
||||
{ID: "work"},
|
||||
}
|
||||
bindings := []config.AgentBinding{
|
||||
{
|
||||
AgentID: "work",
|
||||
Match: config.BindingMatch{
|
||||
Channel: "slack",
|
||||
AccountID: "*",
|
||||
TeamID: "T12345",
|
||||
},
|
||||
},
|
||||
}
|
||||
cfg := testConfig(agents, bindings)
|
||||
r := NewRouteResolver(cfg)
|
||||
|
||||
route := r.ResolveRoute(RouteInput{
|
||||
Channel: "slack",
|
||||
TeamID: "T12345",
|
||||
Peer: &RoutePeer{Kind: "channel", ID: "C001"},
|
||||
})
|
||||
|
||||
if route.AgentID != "work" {
|
||||
t.Errorf("AgentID = %q, want 'work'", route.AgentID)
|
||||
}
|
||||
if route.MatchedBy != "binding.team" {
|
||||
t.Errorf("MatchedBy = %q, want 'binding.team'", route.MatchedBy)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveRoute_AccountBinding(t *testing.T) {
|
||||
agents := []config.AgentConfig{
|
||||
{ID: "default-agent", Default: true},
|
||||
{ID: "premium"},
|
||||
}
|
||||
bindings := []config.AgentBinding{
|
||||
{
|
||||
AgentID: "premium",
|
||||
Match: config.BindingMatch{
|
||||
Channel: "telegram",
|
||||
AccountID: "bot2",
|
||||
},
|
||||
},
|
||||
}
|
||||
cfg := testConfig(agents, bindings)
|
||||
r := NewRouteResolver(cfg)
|
||||
|
||||
route := r.ResolveRoute(RouteInput{
|
||||
Channel: "telegram",
|
||||
AccountID: "bot2",
|
||||
Peer: &RoutePeer{Kind: "direct", ID: "user1"},
|
||||
})
|
||||
|
||||
if route.AgentID != "premium" {
|
||||
t.Errorf("AgentID = %q, want 'premium'", route.AgentID)
|
||||
}
|
||||
if route.MatchedBy != "binding.account" {
|
||||
t.Errorf("MatchedBy = %q, want 'binding.account'", route.MatchedBy)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveRoute_ChannelWildcard(t *testing.T) {
|
||||
agents := []config.AgentConfig{
|
||||
{ID: "main", Default: true},
|
||||
{ID: "telegram-bot"},
|
||||
}
|
||||
bindings := []config.AgentBinding{
|
||||
{
|
||||
AgentID: "telegram-bot",
|
||||
Match: config.BindingMatch{
|
||||
Channel: "telegram",
|
||||
AccountID: "*",
|
||||
},
|
||||
},
|
||||
}
|
||||
cfg := testConfig(agents, bindings)
|
||||
r := NewRouteResolver(cfg)
|
||||
|
||||
route := r.ResolveRoute(RouteInput{
|
||||
Channel: "telegram",
|
||||
Peer: &RoutePeer{Kind: "direct", ID: "user1"},
|
||||
})
|
||||
|
||||
if route.AgentID != "telegram-bot" {
|
||||
t.Errorf("AgentID = %q, want 'telegram-bot'", route.AgentID)
|
||||
}
|
||||
if route.MatchedBy != "binding.channel" {
|
||||
t.Errorf("MatchedBy = %q, want 'binding.channel'", route.MatchedBy)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveRoute_PriorityOrder_PeerBeatsGuild(t *testing.T) {
|
||||
agents := []config.AgentConfig{
|
||||
{ID: "general", Default: true},
|
||||
{ID: "vip"},
|
||||
{ID: "gaming"},
|
||||
}
|
||||
bindings := []config.AgentBinding{
|
||||
{
|
||||
AgentID: "vip",
|
||||
Match: config.BindingMatch{
|
||||
Channel: "discord",
|
||||
AccountID: "*",
|
||||
Peer: &config.PeerMatch{Kind: "direct", ID: "user-vip"},
|
||||
},
|
||||
},
|
||||
{
|
||||
AgentID: "gaming",
|
||||
Match: config.BindingMatch{
|
||||
Channel: "discord",
|
||||
AccountID: "*",
|
||||
GuildID: "guild-1",
|
||||
},
|
||||
},
|
||||
}
|
||||
cfg := testConfig(agents, bindings)
|
||||
r := NewRouteResolver(cfg)
|
||||
|
||||
route := r.ResolveRoute(RouteInput{
|
||||
Channel: "discord",
|
||||
GuildID: "guild-1",
|
||||
Peer: &RoutePeer{Kind: "direct", ID: "user-vip"},
|
||||
})
|
||||
|
||||
if route.AgentID != "vip" {
|
||||
t.Errorf("AgentID = %q, want 'vip' (peer should beat guild)", route.AgentID)
|
||||
}
|
||||
if route.MatchedBy != "binding.peer" {
|
||||
t.Errorf("MatchedBy = %q, want 'binding.peer'", route.MatchedBy)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveRoute_InvalidAgentFallsToDefault(t *testing.T) {
|
||||
agents := []config.AgentConfig{
|
||||
{ID: "main", Default: true},
|
||||
}
|
||||
bindings := []config.AgentBinding{
|
||||
{
|
||||
AgentID: "nonexistent",
|
||||
Match: config.BindingMatch{
|
||||
Channel: "telegram",
|
||||
AccountID: "*",
|
||||
},
|
||||
},
|
||||
}
|
||||
cfg := testConfig(agents, bindings)
|
||||
r := NewRouteResolver(cfg)
|
||||
|
||||
route := r.ResolveRoute(RouteInput{
|
||||
Channel: "telegram",
|
||||
})
|
||||
|
||||
if route.AgentID != "main" {
|
||||
t.Errorf("AgentID = %q, want 'main' (invalid agent should fall to default)", route.AgentID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveRoute_DefaultAgentSelection(t *testing.T) {
|
||||
agents := []config.AgentConfig{
|
||||
{ID: "alpha"},
|
||||
{ID: "beta", Default: true},
|
||||
{ID: "gamma"},
|
||||
}
|
||||
cfg := testConfig(agents, nil)
|
||||
r := NewRouteResolver(cfg)
|
||||
|
||||
route := r.ResolveRoute(RouteInput{
|
||||
Channel: "cli",
|
||||
})
|
||||
|
||||
if route.AgentID != "beta" {
|
||||
t.Errorf("AgentID = %q, want 'beta' (marked as default)", route.AgentID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveRoute_NoDefaultUsesFirst(t *testing.T) {
|
||||
agents := []config.AgentConfig{
|
||||
{ID: "alpha"},
|
||||
{ID: "beta"},
|
||||
}
|
||||
cfg := testConfig(agents, nil)
|
||||
r := NewRouteResolver(cfg)
|
||||
|
||||
route := r.ResolveRoute(RouteInput{
|
||||
Channel: "cli",
|
||||
})
|
||||
|
||||
if route.AgentID != "alpha" {
|
||||
t.Errorf("AgentID = %q, want 'alpha' (first in list)", route.AgentID)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,183 @@
|
||||
package routing
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// DMScope controls DM session isolation granularity.
|
||||
type DMScope string
|
||||
|
||||
const (
|
||||
DMScopeMain DMScope = "main"
|
||||
DMScopePerPeer DMScope = "per-peer"
|
||||
DMScopePerChannelPeer DMScope = "per-channel-peer"
|
||||
DMScopePerAccountChannelPeer DMScope = "per-account-channel-peer"
|
||||
)
|
||||
|
||||
// RoutePeer represents a chat peer with kind and ID.
|
||||
type RoutePeer struct {
|
||||
Kind string // "direct", "group", "channel"
|
||||
ID string
|
||||
}
|
||||
|
||||
// SessionKeyParams holds all inputs for session key construction.
|
||||
type SessionKeyParams struct {
|
||||
AgentID string
|
||||
Channel string
|
||||
AccountID string
|
||||
Peer *RoutePeer
|
||||
DMScope DMScope
|
||||
IdentityLinks map[string][]string
|
||||
}
|
||||
|
||||
// ParsedSessionKey is the result of parsing an agent-scoped session key.
|
||||
type ParsedSessionKey struct {
|
||||
AgentID string
|
||||
Rest string
|
||||
}
|
||||
|
||||
// BuildAgentMainSessionKey returns "agent:<agentId>:main".
|
||||
func BuildAgentMainSessionKey(agentID string) string {
|
||||
return fmt.Sprintf("agent:%s:%s", NormalizeAgentID(agentID), DefaultMainKey)
|
||||
}
|
||||
|
||||
// BuildAgentPeerSessionKey constructs a session key based on agent, channel, peer, and DM scope.
|
||||
func BuildAgentPeerSessionKey(params SessionKeyParams) string {
|
||||
agentID := NormalizeAgentID(params.AgentID)
|
||||
|
||||
peer := params.Peer
|
||||
if peer == nil {
|
||||
peer = &RoutePeer{Kind: "direct"}
|
||||
}
|
||||
peerKind := strings.TrimSpace(peer.Kind)
|
||||
if peerKind == "" {
|
||||
peerKind = "direct"
|
||||
}
|
||||
|
||||
if peerKind == "direct" {
|
||||
dmScope := params.DMScope
|
||||
if dmScope == "" {
|
||||
dmScope = DMScopeMain
|
||||
}
|
||||
peerID := strings.TrimSpace(peer.ID)
|
||||
|
||||
// Resolve identity links (cross-platform collapse)
|
||||
if dmScope != DMScopeMain && peerID != "" {
|
||||
if linked := resolveLinkedPeerID(params.IdentityLinks, params.Channel, peerID); linked != "" {
|
||||
peerID = linked
|
||||
}
|
||||
}
|
||||
peerID = strings.ToLower(peerID)
|
||||
|
||||
switch dmScope {
|
||||
case DMScopePerAccountChannelPeer:
|
||||
if peerID != "" {
|
||||
channel := normalizeChannel(params.Channel)
|
||||
accountID := NormalizeAccountID(params.AccountID)
|
||||
return fmt.Sprintf("agent:%s:%s:%s:direct:%s", agentID, channel, accountID, peerID)
|
||||
}
|
||||
case DMScopePerChannelPeer:
|
||||
if peerID != "" {
|
||||
channel := normalizeChannel(params.Channel)
|
||||
return fmt.Sprintf("agent:%s:%s:direct:%s", agentID, channel, peerID)
|
||||
}
|
||||
case DMScopePerPeer:
|
||||
if peerID != "" {
|
||||
return fmt.Sprintf("agent:%s:direct:%s", agentID, peerID)
|
||||
}
|
||||
}
|
||||
return BuildAgentMainSessionKey(agentID)
|
||||
}
|
||||
|
||||
// Group/channel peers always get per-peer sessions
|
||||
channel := normalizeChannel(params.Channel)
|
||||
peerID := strings.ToLower(strings.TrimSpace(peer.ID))
|
||||
if peerID == "" {
|
||||
peerID = "unknown"
|
||||
}
|
||||
return fmt.Sprintf("agent:%s:%s:%s:%s", agentID, channel, peerKind, peerID)
|
||||
}
|
||||
|
||||
// ParseAgentSessionKey extracts agentId and rest from "agent:<agentId>:<rest>".
|
||||
func ParseAgentSessionKey(sessionKey string) *ParsedSessionKey {
|
||||
raw := strings.TrimSpace(sessionKey)
|
||||
if raw == "" {
|
||||
return nil
|
||||
}
|
||||
parts := strings.SplitN(raw, ":", 3)
|
||||
if len(parts) < 3 {
|
||||
return nil
|
||||
}
|
||||
if parts[0] != "agent" {
|
||||
return nil
|
||||
}
|
||||
agentID := strings.TrimSpace(parts[1])
|
||||
rest := parts[2]
|
||||
if agentID == "" || rest == "" {
|
||||
return nil
|
||||
}
|
||||
return &ParsedSessionKey{AgentID: agentID, Rest: rest}
|
||||
}
|
||||
|
||||
// IsSubagentSessionKey returns true if the session key represents a subagent.
|
||||
func IsSubagentSessionKey(sessionKey string) bool {
|
||||
raw := strings.TrimSpace(sessionKey)
|
||||
if raw == "" {
|
||||
return false
|
||||
}
|
||||
if strings.HasPrefix(strings.ToLower(raw), "subagent:") {
|
||||
return true
|
||||
}
|
||||
parsed := ParseAgentSessionKey(raw)
|
||||
if parsed == nil {
|
||||
return false
|
||||
}
|
||||
return strings.HasPrefix(strings.ToLower(parsed.Rest), "subagent:")
|
||||
}
|
||||
|
||||
func normalizeChannel(channel string) string {
|
||||
c := strings.TrimSpace(strings.ToLower(channel))
|
||||
if c == "" {
|
||||
return "unknown"
|
||||
}
|
||||
return c
|
||||
}
|
||||
|
||||
func resolveLinkedPeerID(identityLinks map[string][]string, channel, peerID string) string {
|
||||
if len(identityLinks) == 0 {
|
||||
return ""
|
||||
}
|
||||
peerID = strings.TrimSpace(peerID)
|
||||
if peerID == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
candidates := make(map[string]bool)
|
||||
rawCandidate := strings.ToLower(peerID)
|
||||
if rawCandidate != "" {
|
||||
candidates[rawCandidate] = true
|
||||
}
|
||||
channel = strings.ToLower(strings.TrimSpace(channel))
|
||||
if channel != "" {
|
||||
scopedCandidate := fmt.Sprintf("%s:%s", channel, strings.ToLower(peerID))
|
||||
candidates[scopedCandidate] = true
|
||||
}
|
||||
if len(candidates) == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
for canonical, ids := range identityLinks {
|
||||
canonicalName := strings.TrimSpace(canonical)
|
||||
if canonicalName == "" {
|
||||
continue
|
||||
}
|
||||
for _, id := range ids {
|
||||
normalized := strings.ToLower(strings.TrimSpace(id))
|
||||
if normalized != "" && candidates[normalized] {
|
||||
return canonicalName
|
||||
}
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
@@ -0,0 +1,162 @@
|
||||
package routing
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestBuildAgentMainSessionKey(t *testing.T) {
|
||||
got := BuildAgentMainSessionKey("sales")
|
||||
want := "agent:sales:main"
|
||||
if got != want {
|
||||
t.Errorf("BuildAgentMainSessionKey('sales') = %q, want %q", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildAgentMainSessionKey_Normalizes(t *testing.T) {
|
||||
got := BuildAgentMainSessionKey("Sales Bot")
|
||||
want := "agent:sales-bot:main"
|
||||
if got != want {
|
||||
t.Errorf("BuildAgentMainSessionKey('Sales Bot') = %q, want %q", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildAgentPeerSessionKey_DMScopeMain(t *testing.T) {
|
||||
got := BuildAgentPeerSessionKey(SessionKeyParams{
|
||||
AgentID: "main",
|
||||
Channel: "telegram",
|
||||
Peer: &RoutePeer{Kind: "direct", ID: "user123"},
|
||||
DMScope: DMScopeMain,
|
||||
})
|
||||
want := "agent:main:main"
|
||||
if got != want {
|
||||
t.Errorf("DMScopeMain = %q, want %q", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildAgentPeerSessionKey_DMScopePerPeer(t *testing.T) {
|
||||
got := BuildAgentPeerSessionKey(SessionKeyParams{
|
||||
AgentID: "main",
|
||||
Channel: "telegram",
|
||||
Peer: &RoutePeer{Kind: "direct", ID: "user123"},
|
||||
DMScope: DMScopePerPeer,
|
||||
})
|
||||
want := "agent:main:direct:user123"
|
||||
if got != want {
|
||||
t.Errorf("DMScopePerPeer = %q, want %q", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildAgentPeerSessionKey_DMScopePerChannelPeer(t *testing.T) {
|
||||
got := BuildAgentPeerSessionKey(SessionKeyParams{
|
||||
AgentID: "main",
|
||||
Channel: "telegram",
|
||||
Peer: &RoutePeer{Kind: "direct", ID: "user123"},
|
||||
DMScope: DMScopePerChannelPeer,
|
||||
})
|
||||
want := "agent:main:telegram:direct:user123"
|
||||
if got != want {
|
||||
t.Errorf("DMScopePerChannelPeer = %q, want %q", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildAgentPeerSessionKey_DMScopePerAccountChannelPeer(t *testing.T) {
|
||||
got := BuildAgentPeerSessionKey(SessionKeyParams{
|
||||
AgentID: "main",
|
||||
Channel: "telegram",
|
||||
AccountID: "bot1",
|
||||
Peer: &RoutePeer{Kind: "direct", ID: "User123"},
|
||||
DMScope: DMScopePerAccountChannelPeer,
|
||||
})
|
||||
want := "agent:main:telegram:bot1:direct:user123"
|
||||
if got != want {
|
||||
t.Errorf("DMScopePerAccountChannelPeer = %q, want %q", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildAgentPeerSessionKey_GroupPeer(t *testing.T) {
|
||||
got := BuildAgentPeerSessionKey(SessionKeyParams{
|
||||
AgentID: "main",
|
||||
Channel: "telegram",
|
||||
Peer: &RoutePeer{Kind: "group", ID: "chat456"},
|
||||
DMScope: DMScopePerPeer,
|
||||
})
|
||||
want := "agent:main:telegram:group:chat456"
|
||||
if got != want {
|
||||
t.Errorf("GroupPeer = %q, want %q", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildAgentPeerSessionKey_NilPeer(t *testing.T) {
|
||||
got := BuildAgentPeerSessionKey(SessionKeyParams{
|
||||
AgentID: "main",
|
||||
Channel: "telegram",
|
||||
Peer: nil,
|
||||
DMScope: DMScopePerPeer,
|
||||
})
|
||||
// nil peer defaults to direct with empty ID, falls to main
|
||||
want := "agent:main:main"
|
||||
if got != want {
|
||||
t.Errorf("NilPeer = %q, want %q", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildAgentPeerSessionKey_IdentityLink(t *testing.T) {
|
||||
links := map[string][]string{
|
||||
"john": {"telegram:user123", "discord:john#1234"},
|
||||
}
|
||||
got := BuildAgentPeerSessionKey(SessionKeyParams{
|
||||
AgentID: "main",
|
||||
Channel: "telegram",
|
||||
Peer: &RoutePeer{Kind: "direct", ID: "user123"},
|
||||
DMScope: DMScopePerPeer,
|
||||
IdentityLinks: links,
|
||||
})
|
||||
want := "agent:main:direct:john"
|
||||
if got != want {
|
||||
t.Errorf("IdentityLink = %q, want %q", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseAgentSessionKey_Valid(t *testing.T) {
|
||||
parsed := ParseAgentSessionKey("agent:sales:telegram:direct:user123")
|
||||
if parsed == nil {
|
||||
t.Fatal("expected non-nil result")
|
||||
}
|
||||
if parsed.AgentID != "sales" {
|
||||
t.Errorf("AgentID = %q, want 'sales'", parsed.AgentID)
|
||||
}
|
||||
if parsed.Rest != "telegram:direct:user123" {
|
||||
t.Errorf("Rest = %q, want 'telegram:direct:user123'", parsed.Rest)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseAgentSessionKey_Invalid(t *testing.T) {
|
||||
tests := []string{
|
||||
"",
|
||||
"foo:bar",
|
||||
"notprefix:sales:main",
|
||||
"agent::main",
|
||||
"agent:sales:",
|
||||
}
|
||||
for _, input := range tests {
|
||||
if got := ParseAgentSessionKey(input); got != nil {
|
||||
t.Errorf("ParseAgentSessionKey(%q) = %+v, want nil", input, got)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsSubagentSessionKey(t *testing.T) {
|
||||
tests := []struct {
|
||||
input string
|
||||
want bool
|
||||
}{
|
||||
{"subagent:task-1", true},
|
||||
{"agent:main:subagent:task-1", true},
|
||||
{"agent:main:main", false},
|
||||
{"agent:main:telegram:direct:user123", false},
|
||||
{"", false},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
if got := IsSubagentSessionKey(tt.input); got != tt.want {
|
||||
t.Errorf("IsSubagentSessionKey(%q) = %v, want %v", tt.input, got, tt.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
+22
-5
@@ -6,10 +6,11 @@ import (
|
||||
)
|
||||
|
||||
type SpawnTool struct {
|
||||
manager *SubagentManager
|
||||
originChannel string
|
||||
originChatID string
|
||||
callback AsyncCallback // For async completion notification
|
||||
manager *SubagentManager
|
||||
originChannel string
|
||||
originChatID string
|
||||
allowlistCheck func(targetAgentID string) bool
|
||||
callback AsyncCallback // For async completion notification
|
||||
}
|
||||
|
||||
func NewSpawnTool(manager *SubagentManager) *SpawnTool {
|
||||
@@ -45,6 +46,10 @@ func (t *SpawnTool) Parameters() map[string]interface{} {
|
||||
"type": "string",
|
||||
"description": "Optional short label for the task (for display)",
|
||||
},
|
||||
"agent_id": map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "Optional target agent ID to delegate the task to",
|
||||
},
|
||||
},
|
||||
"required": []string{"task"},
|
||||
}
|
||||
@@ -55,6 +60,10 @@ func (t *SpawnTool) SetContext(channel, chatID string) {
|
||||
t.originChatID = chatID
|
||||
}
|
||||
|
||||
func (t *SpawnTool) SetAllowlistChecker(check func(targetAgentID string) bool) {
|
||||
t.allowlistCheck = check
|
||||
}
|
||||
|
||||
func (t *SpawnTool) Execute(ctx context.Context, args map[string]interface{}) *ToolResult {
|
||||
task, ok := args["task"].(string)
|
||||
if !ok {
|
||||
@@ -62,13 +71,21 @@ func (t *SpawnTool) Execute(ctx context.Context, args map[string]interface{}) *T
|
||||
}
|
||||
|
||||
label, _ := args["label"].(string)
|
||||
agentID, _ := args["agent_id"].(string)
|
||||
|
||||
// Check allowlist if targeting a specific agent
|
||||
if agentID != "" && t.allowlistCheck != nil {
|
||||
if !t.allowlistCheck(agentID) {
|
||||
return ErrorResult(fmt.Sprintf("not allowed to spawn agent '%s'", agentID))
|
||||
}
|
||||
}
|
||||
|
||||
if t.manager == nil {
|
||||
return ErrorResult("Subagent manager not configured")
|
||||
}
|
||||
|
||||
// Pass callback to manager for async completion notification
|
||||
result, err := t.manager.Spawn(ctx, task, label, t.originChannel, t.originChatID, t.callback)
|
||||
result, err := t.manager.Spawn(ctx, task, label, agentID, t.originChannel, t.originChatID, t.callback)
|
||||
if err != nil {
|
||||
return ErrorResult(fmt.Sprintf("failed to spawn subagent: %v", err))
|
||||
}
|
||||
|
||||
@@ -14,6 +14,7 @@ type SubagentTask struct {
|
||||
ID string
|
||||
Task string
|
||||
Label string
|
||||
AgentID string
|
||||
OriginChannel string
|
||||
OriginChatID string
|
||||
Status string
|
||||
@@ -61,7 +62,7 @@ func (sm *SubagentManager) RegisterTool(tool Tool) {
|
||||
sm.tools.Register(tool)
|
||||
}
|
||||
|
||||
func (sm *SubagentManager) Spawn(ctx context.Context, task, label, originChannel, originChatID string, callback AsyncCallback) (string, error) {
|
||||
func (sm *SubagentManager) Spawn(ctx context.Context, task, label, agentID, originChannel, originChatID string, callback AsyncCallback) (string, error) {
|
||||
sm.mu.Lock()
|
||||
defer sm.mu.Unlock()
|
||||
|
||||
@@ -72,6 +73,7 @@ func (sm *SubagentManager) Spawn(ctx context.Context, task, label, originChannel
|
||||
ID: taskID,
|
||||
Task: task,
|
||||
Label: label,
|
||||
AgentID: agentID,
|
||||
OriginChannel: originChannel,
|
||||
OriginChatID: originChatID,
|
||||
Status: "running",
|
||||
|
||||
Reference in New Issue
Block a user