Merge remote-tracking branch 'origin/main' into feat/refactor-provider-by-protocol

This commit is contained in:
yinwm
2026-02-20 00:11:46 +08:00
75 changed files with 10647 additions and 1384 deletions
+145
View File
@@ -0,0 +1,145 @@
package agent
import (
"os"
"path/filepath"
"strings"
"github.com/sipeed/picoclaw/pkg/config"
"github.com/sipeed/picoclaw/pkg/providers"
"github.com/sipeed/picoclaw/pkg/routing"
"github.com/sipeed/picoclaw/pkg/session"
"github.com/sipeed/picoclaw/pkg/tools"
)
// AgentInstance represents a fully configured agent with its own workspace,
// session manager, context builder, and tool registry.
type AgentInstance struct {
ID string
Name string
Model string
Fallbacks []string
Workspace string
MaxIterations int
ContextWindow int
Provider providers.LLMProvider
Sessions *session.SessionManager
ContextBuilder *ContextBuilder
Tools *tools.ToolRegistry
Subagents *config.SubagentsConfig
SkillsFilter []string
Candidates []providers.FallbackCandidate
}
// NewAgentInstance creates an agent instance from config.
func NewAgentInstance(
agentCfg *config.AgentConfig,
defaults *config.AgentDefaults,
cfg *config.Config,
provider providers.LLMProvider,
) *AgentInstance {
workspace := resolveAgentWorkspace(agentCfg, defaults)
os.MkdirAll(workspace, 0755)
model := resolveAgentModel(agentCfg, defaults)
fallbacks := resolveAgentFallbacks(agentCfg, defaults)
restrict := defaults.RestrictToWorkspace
toolsRegistry := tools.NewToolRegistry()
toolsRegistry.Register(tools.NewReadFileTool(workspace, restrict))
toolsRegistry.Register(tools.NewWriteFileTool(workspace, restrict))
toolsRegistry.Register(tools.NewListDirTool(workspace, restrict))
toolsRegistry.Register(tools.NewExecToolWithConfig(workspace, restrict, cfg))
toolsRegistry.Register(tools.NewEditFileTool(workspace, restrict))
toolsRegistry.Register(tools.NewAppendFileTool(workspace, restrict))
sessionsDir := filepath.Join(workspace, "sessions")
sessionsManager := session.NewSessionManager(sessionsDir)
contextBuilder := NewContextBuilder(workspace)
contextBuilder.SetToolsRegistry(toolsRegistry)
agentID := routing.DefaultAgentID
agentName := ""
var subagents *config.SubagentsConfig
var skillsFilter []string
if agentCfg != nil {
agentID = routing.NormalizeAgentID(agentCfg.ID)
agentName = agentCfg.Name
subagents = agentCfg.Subagents
skillsFilter = agentCfg.Skills
}
maxIter := defaults.MaxToolIterations
if maxIter == 0 {
maxIter = 20
}
// Resolve fallback candidates
modelCfg := providers.ModelConfig{
Primary: model,
Fallbacks: fallbacks,
}
candidates := providers.ResolveCandidates(modelCfg, defaults.Provider)
return &AgentInstance{
ID: agentID,
Name: agentName,
Model: model,
Fallbacks: fallbacks,
Workspace: workspace,
MaxIterations: maxIter,
ContextWindow: defaults.MaxTokens,
Provider: provider,
Sessions: sessionsManager,
ContextBuilder: contextBuilder,
Tools: toolsRegistry,
Subagents: subagents,
SkillsFilter: skillsFilter,
Candidates: candidates,
}
}
// resolveAgentWorkspace determines the workspace directory for an agent.
func resolveAgentWorkspace(agentCfg *config.AgentConfig, defaults *config.AgentDefaults) string {
if agentCfg != nil && strings.TrimSpace(agentCfg.Workspace) != "" {
return expandHome(strings.TrimSpace(agentCfg.Workspace))
}
if agentCfg == nil || agentCfg.Default || agentCfg.ID == "" || routing.NormalizeAgentID(agentCfg.ID) == "main" {
return expandHome(defaults.Workspace)
}
home, _ := os.UserHomeDir()
id := routing.NormalizeAgentID(agentCfg.ID)
return filepath.Join(home, ".picoclaw", "workspace-"+id)
}
// resolveAgentModel resolves the primary model for an agent.
func resolveAgentModel(agentCfg *config.AgentConfig, defaults *config.AgentDefaults) string {
if agentCfg != nil && agentCfg.Model != nil && strings.TrimSpace(agentCfg.Model.Primary) != "" {
return strings.TrimSpace(agentCfg.Model.Primary)
}
return defaults.Model
}
// resolveAgentFallbacks resolves the fallback models for an agent.
func resolveAgentFallbacks(agentCfg *config.AgentConfig, defaults *config.AgentDefaults) []string {
if agentCfg != nil && agentCfg.Model != nil && agentCfg.Model.Fallbacks != nil {
return agentCfg.Model.Fallbacks
}
return defaults.ModelFallbacks
}
func expandHome(path string) string {
if path == "" {
return path
}
if path[0] == '~' {
home, _ := os.UserHomeDir()
if len(path) > 1 && path[1] == '/' {
return home + path[1:]
}
return home
}
return path
}
+299 -287
View File
@@ -10,8 +10,6 @@ import (
"context"
"encoding/json"
"fmt"
"os"
"path/filepath"
"strings"
"sync"
"sync/atomic"
@@ -24,7 +22,7 @@ import (
"github.com/sipeed/picoclaw/pkg/constants"
"github.com/sipeed/picoclaw/pkg/logger"
"github.com/sipeed/picoclaw/pkg/providers"
"github.com/sipeed/picoclaw/pkg/session"
"github.com/sipeed/picoclaw/pkg/routing"
"github.com/sipeed/picoclaw/pkg/state"
"github.com/sipeed/picoclaw/pkg/tools"
"github.com/sipeed/picoclaw/pkg/utils"
@@ -32,17 +30,12 @@ import (
type AgentLoop struct {
bus *bus.MessageBus
provider providers.LLMProvider
workspace string
model string
contextWindow int // Maximum context window size in tokens
maxIterations int
sessions *session.SessionManager
cfg *config.Config
registry *AgentRegistry
state *state.Manager
contextBuilder *ContextBuilder
tools *tools.ToolRegistry
running atomic.Bool
summarizing sync.Map // Tracks which sessions are currently being summarized
summarizing sync.Map
fallback *providers.FallbackChain
channelManager *channels.Manager
}
@@ -58,99 +51,83 @@ type processOptions struct {
NoHistory bool // If true, don't load session history (for heartbeat)
}
// createToolRegistry creates a tool registry with common tools.
// This is shared between main agent and subagents.
func createToolRegistry(workspace string, restrict bool, cfg *config.Config, msgBus *bus.MessageBus) *tools.ToolRegistry {
registry := tools.NewToolRegistry()
// File system tools
registry.Register(tools.NewReadFileTool(workspace, restrict))
registry.Register(tools.NewWriteFileTool(workspace, restrict))
registry.Register(tools.NewListDirTool(workspace, restrict))
registry.Register(tools.NewEditFileTool(workspace, restrict))
registry.Register(tools.NewAppendFileTool(workspace, restrict))
// Shell execution
registry.Register(tools.NewExecTool(workspace, restrict))
if searchTool := tools.NewWebSearchTool(tools.WebSearchToolOptions{
BraveAPIKey: cfg.Tools.Web.Brave.APIKey,
BraveMaxResults: cfg.Tools.Web.Brave.MaxResults,
BraveEnabled: cfg.Tools.Web.Brave.Enabled,
DuckDuckGoMaxResults: cfg.Tools.Web.DuckDuckGo.MaxResults,
DuckDuckGoEnabled: cfg.Tools.Web.DuckDuckGo.Enabled,
PerplexityAPIKey: cfg.Tools.Web.Perplexity.APIKey,
PerplexityMaxResults: cfg.Tools.Web.Perplexity.MaxResults,
PerplexityEnabled: cfg.Tools.Web.Perplexity.Enabled,
}); searchTool != nil {
registry.Register(searchTool)
}
registry.Register(tools.NewWebFetchTool(50000))
// Hardware tools (I2C, SPI) - Linux only, returns error on other platforms
registry.Register(tools.NewI2CTool())
registry.Register(tools.NewSPITool())
// Message tool - available to both agent and subagent
// Subagent uses it to communicate directly with user
messageTool := tools.NewMessageTool()
messageTool.SetSendCallback(func(channel, chatID, content string) error {
msgBus.PublishOutbound(bus.OutboundMessage{
Channel: channel,
ChatID: chatID,
Content: content,
})
return nil
})
registry.Register(messageTool)
return registry
}
func NewAgentLoop(cfg *config.Config, msgBus *bus.MessageBus, provider providers.LLMProvider) *AgentLoop {
workspace := cfg.WorkspacePath()
os.MkdirAll(workspace, 0755)
registry := NewAgentRegistry(cfg, provider)
restrict := cfg.Agents.Defaults.RestrictToWorkspace
// Register shared tools to all agents
registerSharedTools(cfg, msgBus, registry, provider)
// Create tool registry for main agent
toolsRegistry := createToolRegistry(workspace, restrict, cfg, msgBus)
// Set up shared fallback chain
cooldown := providers.NewCooldownTracker()
fallbackChain := providers.NewFallbackChain(cooldown)
// Create subagent manager with its own tool registry
subagentManager := tools.NewSubagentManager(provider, cfg.Agents.Defaults.Model, workspace, msgBus)
subagentTools := createToolRegistry(workspace, restrict, cfg, msgBus)
// Subagent doesn't need spawn/subagent tools to avoid recursion
subagentManager.SetTools(subagentTools)
// Register spawn tool (for main agent)
spawnTool := tools.NewSpawnTool(subagentManager)
toolsRegistry.Register(spawnTool)
// Register subagent tool (synchronous execution)
subagentTool := tools.NewSubagentTool(subagentManager)
toolsRegistry.Register(subagentTool)
sessionsManager := session.NewSessionManager(filepath.Join(workspace, "sessions"))
// Create state manager for atomic state persistence
stateManager := state.NewManager(workspace)
// Create context builder and set tools registry
contextBuilder := NewContextBuilder(workspace)
contextBuilder.SetToolsRegistry(toolsRegistry)
// Create state manager using default agent's workspace for channel recording
defaultAgent := registry.GetDefaultAgent()
var stateManager *state.Manager
if defaultAgent != nil {
stateManager = state.NewManager(defaultAgent.Workspace)
}
return &AgentLoop{
bus: msgBus,
provider: provider,
workspace: workspace,
model: cfg.Agents.Defaults.Model,
contextWindow: cfg.Agents.Defaults.MaxTokens, // Restore context window for summarization
maxIterations: cfg.Agents.Defaults.MaxToolIterations,
sessions: sessionsManager,
state: stateManager,
contextBuilder: contextBuilder,
tools: toolsRegistry,
summarizing: sync.Map{},
bus: msgBus,
cfg: cfg,
registry: registry,
state: stateManager,
summarizing: sync.Map{},
fallback: fallbackChain,
}
}
// registerSharedTools registers tools that are shared across all agents (web, message, spawn).
func registerSharedTools(cfg *config.Config, msgBus *bus.MessageBus, registry *AgentRegistry, provider providers.LLMProvider) {
for _, agentID := range registry.ListAgentIDs() {
agent, ok := registry.GetAgent(agentID)
if !ok {
continue
}
// Web tools
if searchTool := tools.NewWebSearchTool(tools.WebSearchToolOptions{
BraveAPIKey: cfg.Tools.Web.Brave.APIKey,
BraveMaxResults: cfg.Tools.Web.Brave.MaxResults,
BraveEnabled: cfg.Tools.Web.Brave.Enabled,
DuckDuckGoMaxResults: cfg.Tools.Web.DuckDuckGo.MaxResults,
DuckDuckGoEnabled: cfg.Tools.Web.DuckDuckGo.Enabled,
PerplexityAPIKey: cfg.Tools.Web.Perplexity.APIKey,
PerplexityMaxResults: cfg.Tools.Web.Perplexity.MaxResults,
PerplexityEnabled: cfg.Tools.Web.Perplexity.Enabled,
}); searchTool != nil {
agent.Tools.Register(searchTool)
}
agent.Tools.Register(tools.NewWebFetchTool(50000))
// Hardware tools (I2C, SPI) - Linux only, returns error on other platforms
agent.Tools.Register(tools.NewI2CTool())
agent.Tools.Register(tools.NewSPITool())
// Message tool
messageTool := tools.NewMessageTool()
messageTool.SetSendCallback(func(channel, chatID, content string) error {
msgBus.PublishOutbound(bus.OutboundMessage{
Channel: channel,
ChatID: chatID,
Content: content,
})
return nil
})
agent.Tools.Register(messageTool)
// Spawn tool with allowlist checker
subagentManager := tools.NewSubagentManager(provider, agent.Model, agent.Workspace, msgBus)
spawnTool := tools.NewSpawnTool(subagentManager)
currentAgentID := agentID
spawnTool.SetAllowlistChecker(func(targetAgentID string) bool {
return registry.CanSpawnSubagent(currentAgentID, targetAgentID)
})
agent.Tools.Register(spawnTool)
// Update context builder with the complete tools registry
agent.ContextBuilder.SetToolsRegistry(agent.Tools)
}
}
@@ -175,10 +152,14 @@ func (al *AgentLoop) Run(ctx context.Context) error {
if response != "" {
// Check if the message tool already sent a response during this round.
// If so, skip publishing to avoid duplicate messages to the user.
// Use default agent's tools to check (message tool is shared).
alreadySent := false
if tool, ok := al.tools.Get("message"); ok {
if mt, ok := tool.(*tools.MessageTool); ok {
alreadySent = mt.HasSentInRound()
defaultAgent := al.registry.GetDefaultAgent()
if defaultAgent != nil {
if tool, ok := defaultAgent.Tools.Get("message"); ok {
if mt, ok := tool.(*tools.MessageTool); ok {
alreadySent = mt.HasSentInRound()
}
}
}
@@ -201,7 +182,11 @@ func (al *AgentLoop) Stop() {
}
func (al *AgentLoop) RegisterTool(tool tools.Tool) {
al.tools.Register(tool)
for _, agentID := range al.registry.ListAgentIDs() {
if agent, ok := al.registry.GetAgent(agentID); ok {
agent.Tools.Register(tool)
}
}
}
func (al *AgentLoop) SetChannelManager(cm *channels.Manager) {
@@ -211,12 +196,18 @@ func (al *AgentLoop) SetChannelManager(cm *channels.Manager) {
// RecordLastChannel records the last active channel for this workspace.
// This uses the atomic state save mechanism to prevent data loss on crash.
func (al *AgentLoop) RecordLastChannel(channel string) error {
if al.state == nil {
return nil
}
return al.state.SetLastChannel(channel)
}
// RecordLastChatID records the last active chat ID for this workspace.
// This uses the atomic state save mechanism to prevent data loss on crash.
func (al *AgentLoop) RecordLastChatID(chatID string) error {
if al.state == nil {
return nil
}
return al.state.SetLastChatID(chatID)
}
@@ -239,7 +230,8 @@ func (al *AgentLoop) ProcessDirectWithChannel(ctx context.Context, content, sess
// ProcessHeartbeat processes a heartbeat request without session history.
// Each heartbeat is independent and doesn't accumulate context.
func (al *AgentLoop) ProcessHeartbeat(ctx context.Context, content, channel, chatID string) (string, error) {
return al.runAgentLoop(ctx, processOptions{
agent := al.registry.GetDefaultAgent()
return al.runAgentLoop(ctx, agent, processOptions{
SessionKey: "heartbeat",
Channel: channel,
ChatID: chatID,
@@ -277,9 +269,36 @@ func (al *AgentLoop) processMessage(ctx context.Context, msg bus.InboundMessage)
return response, nil
}
// Process as user message
return al.runAgentLoop(ctx, processOptions{
SessionKey: msg.SessionKey,
// Route to determine agent and session key
route := al.registry.ResolveRoute(routing.RouteInput{
Channel: msg.Channel,
AccountID: msg.Metadata["account_id"],
Peer: extractPeer(msg),
ParentPeer: extractParentPeer(msg),
GuildID: msg.Metadata["guild_id"],
TeamID: msg.Metadata["team_id"],
})
agent, ok := al.registry.GetAgent(route.AgentID)
if !ok {
agent = al.registry.GetDefaultAgent()
}
// Use routed session key, but honor pre-set agent-scoped keys (for ProcessDirect/cron)
sessionKey := route.SessionKey
if msg.SessionKey != "" && strings.HasPrefix(msg.SessionKey, "agent:") {
sessionKey = msg.SessionKey
}
logger.InfoCF("agent", "Routed message",
map[string]interface{}{
"agent_id": agent.ID,
"session_key": sessionKey,
"matched_by": route.MatchedBy,
})
return al.runAgentLoop(ctx, agent, processOptions{
SessionKey: sessionKey,
Channel: msg.Channel,
ChatID: msg.ChatID,
UserMessage: msg.Content,
@@ -290,7 +309,6 @@ func (al *AgentLoop) processMessage(ctx context.Context, msg bus.InboundMessage)
}
func (al *AgentLoop) processSystemMessage(ctx context.Context, msg bus.InboundMessage) (string, error) {
// Verify this is a system message
if msg.Channel != "system" {
return "", fmt.Errorf("processSystemMessage called with non-system message channel: %s", msg.Channel)
}
@@ -302,12 +320,13 @@ func (al *AgentLoop) processSystemMessage(ctx context.Context, msg bus.InboundMe
})
// Parse origin channel from chat_id (format: "channel:chat_id")
var originChannel string
var originChannel, originChatID string
if idx := strings.Index(msg.ChatID, ":"); idx > 0 {
originChannel = msg.ChatID[:idx]
originChatID = msg.ChatID[idx+1:]
} else {
// Fallback
originChannel = "cli"
originChatID = msg.ChatID
}
// Extract subagent result from message content
@@ -328,44 +347,47 @@ func (al *AgentLoop) processSystemMessage(ctx context.Context, msg bus.InboundMe
return "", nil
}
// Agent acts as dispatcher only - subagent handles user interaction via message tool
// Don't forward result here, subagent should use message tool to communicate with user
logger.InfoCF("agent", "Subagent completed",
map[string]interface{}{
"sender_id": msg.SenderID,
"channel": originChannel,
"content_len": len(content),
})
// Use default agent for system messages
agent := al.registry.GetDefaultAgent()
// Agent only logs, does not respond to user
return "", nil
// Use the origin session for context
sessionKey := routing.BuildAgentMainSessionKey(agent.ID)
return al.runAgentLoop(ctx, agent, processOptions{
SessionKey: sessionKey,
Channel: originChannel,
ChatID: originChatID,
UserMessage: fmt.Sprintf("[System: %s] %s", msg.SenderID, msg.Content),
DefaultResponse: "Background task completed.",
EnableSummary: false,
SendResponse: true,
})
}
// runAgentLoop is the core message processing logic.
// It handles context building, LLM calls, tool execution, and response handling.
func (al *AgentLoop) runAgentLoop(ctx context.Context, opts processOptions) (string, error) {
func (al *AgentLoop) runAgentLoop(ctx context.Context, agent *AgentInstance, opts processOptions) (string, error) {
// 0. Record last channel for heartbeat notifications (skip internal channels)
if opts.Channel != "" && opts.ChatID != "" {
// Don't record internal channels (cli, system, subagent)
if !constants.IsInternalChannel(opts.Channel) {
channelKey := fmt.Sprintf("%s:%s", opts.Channel, opts.ChatID)
if err := al.RecordLastChannel(channelKey); err != nil {
logger.WarnCF("agent", "Failed to record last channel: %v", map[string]interface{}{"error": err.Error()})
logger.WarnCF("agent", "Failed to record last channel", map[string]interface{}{"error": err.Error()})
}
}
}
// 1. Update tool contexts
al.updateToolContexts(opts.Channel, opts.ChatID)
al.updateToolContexts(agent, opts.Channel, opts.ChatID)
// 2. Build messages (skip history for heartbeat)
var history []providers.Message
var summary string
if !opts.NoHistory {
history = al.sessions.GetHistory(opts.SessionKey)
summary = al.sessions.GetSummary(opts.SessionKey)
history = agent.Sessions.GetHistory(opts.SessionKey)
summary = agent.Sessions.GetSummary(opts.SessionKey)
}
messages := al.contextBuilder.BuildMessages(
messages := agent.ContextBuilder.BuildMessages(
history,
summary,
opts.UserMessage,
@@ -375,10 +397,10 @@ func (al *AgentLoop) runAgentLoop(ctx context.Context, opts processOptions) (str
)
// 3. Save user message to session
al.sessions.AddMessage(opts.SessionKey, "user", opts.UserMessage)
agent.Sessions.AddMessage(opts.SessionKey, "user", opts.UserMessage)
// 4. Run LLM iteration loop
finalContent, iteration, err := al.runLLMIteration(ctx, messages, opts)
finalContent, iteration, err := al.runLLMIteration(ctx, agent, messages, opts)
if err != nil {
return "", err
}
@@ -392,12 +414,12 @@ func (al *AgentLoop) runAgentLoop(ctx context.Context, opts processOptions) (str
}
// 6. Save final assistant message to session
al.sessions.AddMessage(opts.SessionKey, "assistant", finalContent)
al.sessions.Save(opts.SessionKey)
agent.Sessions.AddMessage(opts.SessionKey, "assistant", finalContent)
agent.Sessions.Save(opts.SessionKey)
// 7. Optional: summarization
if opts.EnableSummary {
al.maybeSummarize(opts.SessionKey, opts.Channel, opts.ChatID)
al.maybeSummarize(agent, opts.SessionKey, opts.Channel, opts.ChatID)
}
// 8. Optional: send response via bus
@@ -413,6 +435,7 @@ func (al *AgentLoop) runAgentLoop(ctx context.Context, opts processOptions) (str
responsePreview := utils.Truncate(finalContent, 120)
logger.InfoCF("agent", fmt.Sprintf("Response: %s", responsePreview),
map[string]interface{}{
"agent_id": agent.ID,
"session_key": opts.SessionKey,
"iterations": iteration,
"final_length": len(finalContent),
@@ -422,28 +445,29 @@ func (al *AgentLoop) runAgentLoop(ctx context.Context, opts processOptions) (str
}
// runLLMIteration executes the LLM call loop with tool handling.
// Returns the final content, iteration count, and any error.
func (al *AgentLoop) runLLMIteration(ctx context.Context, messages []providers.Message, opts processOptions) (string, int, error) {
func (al *AgentLoop) runLLMIteration(ctx context.Context, agent *AgentInstance, messages []providers.Message, opts processOptions) (string, int, error) {
iteration := 0
var finalContent string
for iteration < al.maxIterations {
for iteration < agent.MaxIterations {
iteration++
logger.DebugCF("agent", "LLM iteration",
map[string]interface{}{
"agent_id": agent.ID,
"iteration": iteration,
"max": al.maxIterations,
"max": agent.MaxIterations,
})
// Build tool definitions
providerToolDefs := al.tools.ToProviderDefs()
providerToolDefs := agent.Tools.ToProviderDefs()
// Log LLM request details
logger.DebugCF("agent", "LLM request",
map[string]interface{}{
"agent_id": agent.ID,
"iteration": iteration,
"model": al.model,
"model": agent.Model,
"messages_count": len(messages),
"tools_count": len(providerToolDefs),
"max_tokens": 8192,
@@ -459,23 +483,45 @@ func (al *AgentLoop) runLLMIteration(ctx context.Context, messages []providers.M
"tools_json": formatToolsForLog(providerToolDefs),
})
// Call LLM with fallback chain if candidates are configured.
var response *providers.LLMResponse
var err error
callLLM := func() (*providers.LLMResponse, error) {
if len(agent.Candidates) > 1 && al.fallback != nil {
fbResult, fbErr := al.fallback.Execute(ctx, agent.Candidates,
func(ctx context.Context, provider, model string) (*providers.LLMResponse, error) {
return agent.Provider.Chat(ctx, messages, providerToolDefs, model, map[string]interface{}{
"max_tokens": 8192,
"temperature": 0.7,
})
},
)
if fbErr != nil {
return nil, fbErr
}
if fbResult.Provider != "" && len(fbResult.Attempts) > 0 {
logger.InfoCF("agent", fmt.Sprintf("Fallback: succeeded with %s/%s after %d attempts",
fbResult.Provider, fbResult.Model, len(fbResult.Attempts)+1),
map[string]interface{}{"agent_id": agent.ID, "iteration": iteration})
}
return fbResult.Response, nil
}
return agent.Provider.Chat(ctx, messages, providerToolDefs, agent.Model, map[string]interface{}{
"max_tokens": 8192,
"temperature": 0.7,
})
}
// Retry loop for context/token errors
maxRetries := 2
for retry := 0; retry <= maxRetries; retry++ {
response, err = al.provider.Chat(ctx, messages, providerToolDefs, al.model, map[string]interface{}{
"max_tokens": 8192,
"temperature": 0.7,
})
response, err = callLLM()
if err == nil {
break // Success
break
}
errMsg := strings.ToLower(err.Error())
// Check for context window errors (provider specific, but usually contain "token" or "invalid")
isContextError := strings.Contains(errMsg, "token") ||
strings.Contains(errMsg, "context") ||
strings.Contains(errMsg, "invalidparameter") ||
@@ -487,107 +533,30 @@ func (al *AgentLoop) runLLMIteration(ctx context.Context, messages []providers.M
"retry": retry,
})
// Notify user on first retry only
if retry == 0 && !constants.IsInternalChannel(opts.Channel) && opts.SendResponse {
if retry == 0 && !constants.IsInternalChannel(opts.Channel) {
al.bus.PublishOutbound(bus.OutboundMessage{
Channel: opts.Channel,
ChatID: opts.ChatID,
Content: "⚠️ Context window exceeded. Compressing history and retrying...",
Content: "Context window exceeded. Compressing history and retrying...",
})
}
// Force compression
al.forceCompression(opts.SessionKey)
// Rebuild messages with compressed history
// Note: We need to reload history from session manager because forceCompression changed it
newHistory := al.sessions.GetHistory(opts.SessionKey)
newSummary := al.sessions.GetSummary(opts.SessionKey)
// Re-create messages for the next attempt
// We keep the current user message (opts.UserMessage) effectively
messages = al.contextBuilder.BuildMessages(
newHistory,
newSummary,
opts.UserMessage,
nil,
opts.Channel,
opts.ChatID,
al.forceCompression(agent, opts.SessionKey)
newHistory := agent.Sessions.GetHistory(opts.SessionKey)
newSummary := agent.Sessions.GetSummary(opts.SessionKey)
messages = agent.ContextBuilder.BuildMessages(
newHistory, newSummary, "",
nil, opts.Channel, opts.ChatID,
)
// Important: If we are in the middle of a tool loop (iteration > 1),
// rebuilding messages from session history might duplicate the flow or miss context
// if intermediate steps weren't saved correctly.
// However, al.sessions.AddFullMessage is called after every tool execution,
// so GetHistory should reflect the current state including partial tool execution.
// But we need to ensure we don't duplicate the user message which is appended in BuildMessages.
// BuildMessages(history...) takes the stored history and appends the *current* user message.
// If iteration > 1, the "current user message" was already added to history in step 3 of runAgentLoop.
// So if we pass opts.UserMessage again, we might duplicate it?
// Actually, step 3 is: al.sessions.AddMessage(opts.SessionKey, "user", opts.UserMessage)
// So GetHistory ALREADY contains the user message!
// CORRECTION:
// BuildMessages combines: [System] + [History] + [CurrentMessage]
// But Step 3 added CurrentMessage to History.
// So if we use GetHistory now, it has the user message.
// If we pass opts.UserMessage to BuildMessages, it adds it AGAIN.
// For retry in the middle of a loop, we should rely on what's in the session.
// BUT checking BuildMessages implementation:
// It appends history... then appends currentMessage.
// Logic fix for retry:
// If iteration == 1, opts.UserMessage corresponds to the user input.
// If iteration > 1, we are processing tool results. The "messages" passed to Chat
// already accumulated tool outputs.
// Rebuilding from session history is safest because it persists state.
// Start fresh with rebuilt history.
// Special case: standard BuildMessages appends "currentMessage".
// If we are strictly retrying the *LLM call*, we want the exact same state as before but compressed.
// However, the "messages" argument passed to runLLMIteration is constructed by the caller.
// If we rebuild from Session, we need to know if "currentMessage" should be appended or is already in history.
// In runAgentLoop:
// 3. sessions.AddMessage(userMsg)
// 4. runLLMIteration(..., UserMessage)
// So History contains the user message.
// BuildMessages typically appends the user message as a *new* pending message.
// Wait, standard BuildMessages usage in runAgentLoop:
// messages := BuildMessages(history (has old), UserMessage)
// THEN AddMessage(UserMessage).
// So "history" passed to BuildMessages does NOT contain the current UserMessage yet.
// But here, inside the loop, we have already saved it.
// So GetHistory() includes the current user message.
// If we call BuildMessages(GetHistory(), UserMessage), we get duplicates.
// Hack/Fix:
// If we are retrying, we rebuild from Session History ONLY.
// We pass empty string as "currentMessage" to BuildMessages
// because the "current message" is already saved in history (step 3).
messages = al.contextBuilder.BuildMessages(
newHistory,
newSummary,
"", // Empty because history already contains the relevant messages
nil,
opts.Channel,
opts.ChatID,
)
continue
}
// Real error or success, break loop
break
}
if err != nil {
logger.ErrorCF("agent", "LLM call failed",
map[string]interface{}{
"agent_id": agent.ID,
"iteration": iteration,
"error": err.Error(),
})
@@ -599,6 +568,7 @@ func (al *AgentLoop) runLLMIteration(ctx context.Context, messages []providers.M
finalContent = response.Content
logger.InfoCF("agent", "LLM response without tool calls (direct answer)",
map[string]interface{}{
"agent_id": agent.ID,
"iteration": iteration,
"content_chars": len(finalContent),
})
@@ -617,6 +587,7 @@ func (al *AgentLoop) runLLMIteration(ctx context.Context, messages []providers.M
}
logger.InfoCF("agent", "LLM requested tool calls",
map[string]interface{}{
"agent_id": agent.ID,
"tools": toolNames,
"count": len(normalizedToolCalls),
"iteration": iteration,
@@ -649,15 +620,15 @@ func (al *AgentLoop) runLLMIteration(ctx context.Context, messages []providers.M
messages = append(messages, assistantMsg)
// Save assistant message with tool calls to session
al.sessions.AddFullMessage(opts.SessionKey, assistantMsg)
agent.Sessions.AddFullMessage(opts.SessionKey, assistantMsg)
// Execute tool calls
for _, tc := range normalizedToolCalls {
// Log tool call with arguments preview
argsJSON, _ := json.Marshal(tc.Arguments)
argsPreview := utils.Truncate(string(argsJSON), 200)
logger.InfoCF("agent", fmt.Sprintf("Tool call: %s(%s)", tc.Name, argsPreview),
map[string]interface{}{
"agent_id": agent.ID,
"tool": tc.Name,
"iteration": iteration,
})
@@ -678,7 +649,7 @@ func (al *AgentLoop) runLLMIteration(ctx context.Context, messages []providers.M
}
}
toolResult := al.tools.ExecuteWithContext(ctx, tc.Name, tc.Arguments, opts.Channel, opts.ChatID, asyncCallback)
toolResult := agent.Tools.ExecuteWithContext(ctx, tc.Name, tc.Arguments, opts.Channel, opts.ChatID, asyncCallback)
// Send ForUser content to user immediately if not Silent
if !toolResult.Silent && toolResult.ForUser != "" && opts.SendResponse {
@@ -708,7 +679,7 @@ func (al *AgentLoop) runLLMIteration(ctx context.Context, messages []providers.M
messages = append(messages, toolResultMsg)
// Save tool result message to session
al.sessions.AddFullMessage(opts.SessionKey, toolResultMsg)
agent.Sessions.AddFullMessage(opts.SessionKey, toolResultMsg)
}
}
@@ -716,19 +687,19 @@ func (al *AgentLoop) runLLMIteration(ctx context.Context, messages []providers.M
}
// updateToolContexts updates the context for tools that need channel/chatID info.
func (al *AgentLoop) updateToolContexts(channel, chatID string) {
func (al *AgentLoop) updateToolContexts(agent *AgentInstance, channel, chatID string) {
// Use ContextualTool interface instead of type assertions
if tool, ok := al.tools.Get("message"); ok {
if tool, ok := agent.Tools.Get("message"); ok {
if mt, ok := tool.(tools.ContextualTool); ok {
mt.SetContext(channel, chatID)
}
}
if tool, ok := al.tools.Get("spawn"); ok {
if tool, ok := agent.Tools.Get("spawn"); ok {
if st, ok := tool.(tools.ContextualTool); ok {
st.SetContext(channel, chatID)
}
}
if tool, ok := al.tools.Get("subagent"); ok {
if tool, ok := agent.Tools.Get("subagent"); ok {
if st, ok := tool.(tools.ContextualTool); ok {
st.SetContext(channel, chatID)
}
@@ -736,24 +707,24 @@ func (al *AgentLoop) updateToolContexts(channel, chatID string) {
}
// maybeSummarize triggers summarization if the session history exceeds thresholds.
func (al *AgentLoop) maybeSummarize(sessionKey, channel, chatID string) {
newHistory := al.sessions.GetHistory(sessionKey)
func (al *AgentLoop) maybeSummarize(agent *AgentInstance, sessionKey, channel, chatID string) {
newHistory := agent.Sessions.GetHistory(sessionKey)
tokenEstimate := al.estimateTokens(newHistory)
threshold := al.contextWindow * 75 / 100
threshold := agent.ContextWindow * 75 / 100
if len(newHistory) > 20 || tokenEstimate > threshold {
if _, loading := al.summarizing.LoadOrStore(sessionKey, true); !loading {
summarizeKey := agent.ID + ":" + sessionKey
if _, loading := al.summarizing.LoadOrStore(summarizeKey, true); !loading {
go func() {
defer al.summarizing.Delete(sessionKey)
// Notify user about optimization if not an internal channel
defer al.summarizing.Delete(summarizeKey)
if !constants.IsInternalChannel(channel) {
al.bus.PublishOutbound(bus.OutboundMessage{
Channel: channel,
ChatID: chatID,
Content: "⚠️ Memory threshold reached. Optimizing conversation history...",
Content: "Memory threshold reached. Optimizing conversation history...",
})
}
al.summarizeSession(sessionKey)
al.summarizeSession(agent, sessionKey)
}()
}
}
@@ -761,8 +732,8 @@ func (al *AgentLoop) maybeSummarize(sessionKey, channel, chatID string) {
// forceCompression aggressively reduces context when the limit is hit.
// It drops the oldest 50% of messages (keeping system prompt and last user message).
func (al *AgentLoop) forceCompression(sessionKey string) {
history := al.sessions.GetHistory(sessionKey)
func (al *AgentLoop) forceCompression(agent *AgentInstance, sessionKey string) {
history := agent.Sessions.GetHistory(sessionKey)
if len(history) <= 4 {
return
}
@@ -799,8 +770,8 @@ func (al *AgentLoop) forceCompression(sessionKey string) {
newHistory = append(newHistory, history[len(history)-1]) // Last message
// Update session
al.sessions.SetHistory(sessionKey, newHistory)
al.sessions.Save(sessionKey)
agent.Sessions.SetHistory(sessionKey, newHistory)
agent.Sessions.Save(sessionKey)
logger.WarnCF("agent", "Forced compression executed", map[string]interface{}{
"session_key": sessionKey,
@@ -813,15 +784,26 @@ func (al *AgentLoop) forceCompression(sessionKey string) {
func (al *AgentLoop) GetStartupInfo() map[string]interface{} {
info := make(map[string]interface{})
agent := al.registry.GetDefaultAgent()
if agent == nil {
return info
}
// Tools info
tools := al.tools.List()
toolsList := agent.Tools.List()
info["tools"] = map[string]interface{}{
"count": len(tools),
"names": tools,
"count": len(toolsList),
"names": toolsList,
}
// Skills info
info["skills"] = al.contextBuilder.GetSkillsInfo()
info["skills"] = agent.ContextBuilder.GetSkillsInfo()
// Agents info
info["agents"] = map[string]interface{}{
"count": len(al.registry.ListAgentIDs()),
"ids": al.registry.ListAgentIDs(),
}
return info
}
@@ -878,12 +860,12 @@ func formatToolsForLog(tools []providers.ToolDefinition) string {
}
// summarizeSession summarizes the conversation history for a session.
func (al *AgentLoop) summarizeSession(sessionKey string) {
func (al *AgentLoop) summarizeSession(agent *AgentInstance, sessionKey string) {
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
defer cancel()
history := al.sessions.GetHistory(sessionKey)
summary := al.sessions.GetSummary(sessionKey)
history := agent.Sessions.GetHistory(sessionKey)
summary := agent.Sessions.GetSummary(sessionKey)
// Keep last 4 messages for continuity
if len(history) <= 4 {
@@ -893,8 +875,7 @@ func (al *AgentLoop) summarizeSession(sessionKey string) {
toSummarize := history[:len(history)-4]
// Oversized Message Guard
// Skip messages larger than 50% of context window to prevent summarizer overflow
maxMessageTokens := al.contextWindow / 2
maxMessageTokens := agent.ContextWindow / 2
validMessages := make([]providers.Message, 0)
omitted := false
@@ -902,8 +883,7 @@ func (al *AgentLoop) summarizeSession(sessionKey string) {
if m.Role != "user" && m.Role != "assistant" {
continue
}
// Estimate tokens for this message
msgTokens := len(m.Content) / 2 // Use safer estimate here too (2.5 -> 2 for integer division safety)
msgTokens := len(m.Content) / 2
if msgTokens > maxMessageTokens {
omitted = true
continue
@@ -916,19 +896,17 @@ func (al *AgentLoop) summarizeSession(sessionKey string) {
}
// Multi-Part Summarization
// Split into two parts if history is significant
var finalSummary string
if len(validMessages) > 10 {
mid := len(validMessages) / 2
part1 := validMessages[:mid]
part2 := validMessages[mid:]
s1, _ := al.summarizeBatch(ctx, part1, "")
s2, _ := al.summarizeBatch(ctx, part2, "")
s1, _ := al.summarizeBatch(ctx, agent, part1, "")
s2, _ := al.summarizeBatch(ctx, agent, part2, "")
// Merge them
mergePrompt := fmt.Sprintf("Merge these two conversation summaries into one cohesive summary:\n\n1: %s\n\n2: %s", s1, s2)
resp, err := al.provider.Chat(ctx, []providers.Message{{Role: "user", Content: mergePrompt}}, nil, al.model, map[string]interface{}{
resp, err := agent.Provider.Chat(ctx, []providers.Message{{Role: "user", Content: mergePrompt}}, nil, agent.Model, map[string]interface{}{
"max_tokens": 1024,
"temperature": 0.3,
})
@@ -938,7 +916,7 @@ func (al *AgentLoop) summarizeSession(sessionKey string) {
finalSummary = s1 + " " + s2
}
} else {
finalSummary, _ = al.summarizeBatch(ctx, validMessages, summary)
finalSummary, _ = al.summarizeBatch(ctx, agent, validMessages, summary)
}
if omitted && finalSummary != "" {
@@ -946,14 +924,14 @@ func (al *AgentLoop) summarizeSession(sessionKey string) {
}
if finalSummary != "" {
al.sessions.SetSummary(sessionKey, finalSummary)
al.sessions.TruncateHistory(sessionKey, 4)
al.sessions.Save(sessionKey)
agent.Sessions.SetSummary(sessionKey, finalSummary)
agent.Sessions.TruncateHistory(sessionKey, 4)
agent.Sessions.Save(sessionKey)
}
}
// summarizeBatch summarizes a batch of messages.
func (al *AgentLoop) summarizeBatch(ctx context.Context, batch []providers.Message, existingSummary string) (string, error) {
func (al *AgentLoop) summarizeBatch(ctx context.Context, agent *AgentInstance, batch []providers.Message, existingSummary string) (string, error) {
prompt := "Provide a concise summary of this conversation segment, preserving core context and key points.\n"
if existingSummary != "" {
prompt += "Existing context: " + existingSummary + "\n"
@@ -963,7 +941,7 @@ func (al *AgentLoop) summarizeBatch(ctx context.Context, batch []providers.Messa
prompt += fmt.Sprintf("%s: %s\n", m.Role, m.Content)
}
response, err := al.provider.Chat(ctx, []providers.Message{{Role: "user", Content: prompt}}, nil, al.model, map[string]interface{}{
response, err := agent.Provider.Chat(ctx, []providers.Message{{Role: "user", Content: prompt}}, nil, agent.Model, map[string]interface{}{
"max_tokens": 1024,
"temperature": 0.3,
})
@@ -1002,25 +980,31 @@ func (al *AgentLoop) handleCommand(ctx context.Context, msg bus.InboundMessage)
switch cmd {
case "/show":
if len(args) < 1 {
return "Usage: /show [model|channel]", true
return "Usage: /show [model|channel|agents]", true
}
switch args[0] {
case "model":
return fmt.Sprintf("Current model: %s", al.model), true
defaultAgent := al.registry.GetDefaultAgent()
if defaultAgent == nil {
return "No default agent configured", true
}
return fmt.Sprintf("Current model: %s", defaultAgent.Model), true
case "channel":
return fmt.Sprintf("Current channel: %s", msg.Channel), true
case "agents":
agentIDs := al.registry.ListAgentIDs()
return fmt.Sprintf("Registered agents: %s", strings.Join(agentIDs, ", ")), true
default:
return fmt.Sprintf("Unknown show target: %s", args[0]), true
}
case "/list":
if len(args) < 1 {
return "Usage: /list [models|channels]", true
return "Usage: /list [models|channels|agents]", true
}
switch args[0] {
case "models":
// TODO: Fetch available models dynamically if possible
return "Available models: glm-4.7, claude-3-5-sonnet, gpt-4o (configured in config.json/env)", true
return "Available models: configured in config.json per agent", true
case "channels":
if al.channelManager == nil {
return "Channel manager not initialized", true
@@ -1030,6 +1014,9 @@ func (al *AgentLoop) handleCommand(ctx context.Context, msg bus.InboundMessage)
return "No channels enabled", true
}
return fmt.Sprintf("Enabled channels: %s", strings.Join(channels, ", ")), true
case "agents":
agentIDs := al.registry.ListAgentIDs()
return fmt.Sprintf("Registered agents: %s", strings.Join(agentIDs, ", ")), true
default:
return fmt.Sprintf("Unknown list target: %s", args[0]), true
}
@@ -1043,23 +1030,21 @@ func (al *AgentLoop) handleCommand(ctx context.Context, msg bus.InboundMessage)
switch target {
case "model":
oldModel := al.model
al.model = value
defaultAgent := al.registry.GetDefaultAgent()
if defaultAgent == nil {
return "No default agent configured", true
}
oldModel := defaultAgent.Model
defaultAgent.Model = value
return fmt.Sprintf("Switched model from %s to %s", oldModel, value), true
case "channel":
// This changes the 'default' channel for some operations, or effectively redirects output?
// For now, let's just validate if the channel exists
if al.channelManager == nil {
return "Channel manager not initialized", true
}
if _, exists := al.channelManager.GetChannel(value); !exists && value != "cli" {
return fmt.Sprintf("Channel '%s' not found or not enabled", value), true
}
// If message came from CLI, maybe we want to redirect CLI output to this channel?
// That would require state persistence about "redirected channel"
// For now, just acknowledged.
return fmt.Sprintf("Switched target channel to %s (Note: this currently only validates existence)", value), true
return fmt.Sprintf("Switched target channel to %s", value), true
default:
return fmt.Sprintf("Unknown switch target: %s", target), true
}
@@ -1067,3 +1052,30 @@ func (al *AgentLoop) handleCommand(ctx context.Context, msg bus.InboundMessage)
return "", false
}
// extractPeer extracts the routing peer from inbound message metadata.
func extractPeer(msg bus.InboundMessage) *routing.RoutePeer {
peerKind := msg.Metadata["peer_kind"]
if peerKind == "" {
return nil
}
peerID := msg.Metadata["peer_id"]
if peerID == "" {
if peerKind == "direct" {
peerID = msg.SenderID
} else {
peerID = msg.ChatID
}
}
return &routing.RoutePeer{Kind: peerKind, ID: peerID}
}
// extractParentPeer extracts the parent peer (reply-to) from inbound message metadata.
func extractParentPeer(msg bus.InboundMessage) *routing.RoutePeer {
parentKind := msg.Metadata["parent_peer_kind"]
parentID := msg.Metadata["parent_peer_id"]
if parentKind == "" || parentID == "" {
return nil
}
return &routing.RoutePeer{Kind: parentKind, ID: parentID}
}
+6 -2
View File
@@ -594,7 +594,11 @@ func TestAgentLoop_ContextExhaustionRetry(t *testing.T) {
{Role: "assistant", Content: "Old response 2"},
{Role: "user", Content: "Trigger message"},
}
al.sessions.SetHistory(sessionKey, history)
defaultAgent := al.registry.GetDefaultAgent()
if defaultAgent == nil {
t.Fatal("No default agent found")
}
defaultAgent.Sessions.SetHistory(sessionKey, history)
// Call ProcessDirectWithChannel
// Note: ProcessDirectWithChannel calls processMessage which will execute runLLMIteration
@@ -614,7 +618,7 @@ func TestAgentLoop_ContextExhaustionRetry(t *testing.T) {
}
// Check final history length
finalHistory := al.sessions.GetHistory(sessionKey)
finalHistory := defaultAgent.Sessions.GetHistory(sessionKey)
// We verify that the history has been modified (compressed)
// Original length: 6
// Expected behavior: compression drops ~50% of history (mid slice)
+114
View File
@@ -0,0 +1,114 @@
package agent
import (
"sync"
"github.com/sipeed/picoclaw/pkg/config"
"github.com/sipeed/picoclaw/pkg/logger"
"github.com/sipeed/picoclaw/pkg/providers"
"github.com/sipeed/picoclaw/pkg/routing"
)
// AgentRegistry manages multiple agent instances and routes messages to them.
type AgentRegistry struct {
agents map[string]*AgentInstance
resolver *routing.RouteResolver
mu sync.RWMutex
}
// NewAgentRegistry creates a registry from config, instantiating all agents.
func NewAgentRegistry(
cfg *config.Config,
provider providers.LLMProvider,
) *AgentRegistry {
registry := &AgentRegistry{
agents: make(map[string]*AgentInstance),
resolver: routing.NewRouteResolver(cfg),
}
agentConfigs := cfg.Agents.List
if len(agentConfigs) == 0 {
implicitAgent := &config.AgentConfig{
ID: "main",
Default: true,
}
instance := NewAgentInstance(implicitAgent, &cfg.Agents.Defaults, cfg, provider)
registry.agents["main"] = instance
logger.InfoCF("agent", "Created implicit main agent (no agents.list configured)", nil)
} else {
for i := range agentConfigs {
ac := &agentConfigs[i]
id := routing.NormalizeAgentID(ac.ID)
instance := NewAgentInstance(ac, &cfg.Agents.Defaults, cfg, provider)
registry.agents[id] = instance
logger.InfoCF("agent", "Registered agent",
map[string]interface{}{
"agent_id": id,
"name": ac.Name,
"workspace": instance.Workspace,
"model": instance.Model,
})
}
}
return registry
}
// GetAgent returns the agent instance for a given ID.
func (r *AgentRegistry) GetAgent(agentID string) (*AgentInstance, bool) {
r.mu.RLock()
defer r.mu.RUnlock()
id := routing.NormalizeAgentID(agentID)
agent, ok := r.agents[id]
return agent, ok
}
// ResolveRoute determines which agent handles the message.
func (r *AgentRegistry) ResolveRoute(input routing.RouteInput) routing.ResolvedRoute {
return r.resolver.ResolveRoute(input)
}
// ListAgentIDs returns all registered agent IDs.
func (r *AgentRegistry) ListAgentIDs() []string {
r.mu.RLock()
defer r.mu.RUnlock()
ids := make([]string, 0, len(r.agents))
for id := range r.agents {
ids = append(ids, id)
}
return ids
}
// CanSpawnSubagent checks if parentAgentID is allowed to spawn targetAgentID.
func (r *AgentRegistry) CanSpawnSubagent(parentAgentID, targetAgentID string) bool {
parent, ok := r.GetAgent(parentAgentID)
if !ok {
return false
}
if parent.Subagents == nil || parent.Subagents.AllowAgents == nil {
return false
}
targetNorm := routing.NormalizeAgentID(targetAgentID)
for _, allowed := range parent.Subagents.AllowAgents {
if allowed == "*" {
return true
}
if routing.NormalizeAgentID(allowed) == targetNorm {
return true
}
}
return false
}
// GetDefaultAgent returns the default agent instance.
func (r *AgentRegistry) GetDefaultAgent() *AgentInstance {
r.mu.RLock()
defer r.mu.RUnlock()
if agent, ok := r.agents["main"]; ok {
return agent
}
for _, agent := range r.agents {
return agent
}
return nil
}
+199
View File
@@ -0,0 +1,199 @@
package agent
import (
"context"
"testing"
"github.com/sipeed/picoclaw/pkg/config"
"github.com/sipeed/picoclaw/pkg/providers"
)
type mockRegistryProvider struct{}
func (m *mockRegistryProvider) Chat(ctx context.Context, messages []providers.Message, tools []providers.ToolDefinition, model string, options map[string]interface{}) (*providers.LLMResponse, error) {
return &providers.LLMResponse{Content: "mock", FinishReason: "stop"}, nil
}
func (m *mockRegistryProvider) GetDefaultModel() string {
return "mock-model"
}
func testCfg(agents []config.AgentConfig) *config.Config {
return &config.Config{
Agents: config.AgentsConfig{
Defaults: config.AgentDefaults{
Workspace: "/tmp/picoclaw-test-registry",
Model: "gpt-4",
MaxTokens: 8192,
MaxToolIterations: 10,
},
List: agents,
},
}
}
func TestNewAgentRegistry_ImplicitMain(t *testing.T) {
cfg := testCfg(nil)
registry := NewAgentRegistry(cfg, &mockRegistryProvider{})
ids := registry.ListAgentIDs()
if len(ids) != 1 || ids[0] != "main" {
t.Errorf("expected implicit main agent, got %v", ids)
}
agent, ok := registry.GetAgent("main")
if !ok || agent == nil {
t.Fatal("expected to find 'main' agent")
}
if agent.ID != "main" {
t.Errorf("agent.ID = %q, want 'main'", agent.ID)
}
}
func TestNewAgentRegistry_ExplicitAgents(t *testing.T) {
cfg := testCfg([]config.AgentConfig{
{ID: "sales", Default: true, Name: "Sales Bot"},
{ID: "support", Name: "Support Bot"},
})
registry := NewAgentRegistry(cfg, &mockRegistryProvider{})
ids := registry.ListAgentIDs()
if len(ids) != 2 {
t.Fatalf("expected 2 agents, got %d: %v", len(ids), ids)
}
sales, ok := registry.GetAgent("sales")
if !ok || sales == nil {
t.Fatal("expected to find 'sales' agent")
}
if sales.Name != "Sales Bot" {
t.Errorf("sales.Name = %q, want 'Sales Bot'", sales.Name)
}
support, ok := registry.GetAgent("support")
if !ok || support == nil {
t.Fatal("expected to find 'support' agent")
}
}
func TestAgentRegistry_GetAgent_Normalize(t *testing.T) {
cfg := testCfg([]config.AgentConfig{
{ID: "my-agent", Default: true},
})
registry := NewAgentRegistry(cfg, &mockRegistryProvider{})
agent, ok := registry.GetAgent("My-Agent")
if !ok || agent == nil {
t.Fatal("expected to find agent with normalized ID")
}
if agent.ID != "my-agent" {
t.Errorf("agent.ID = %q, want 'my-agent'", agent.ID)
}
}
func TestAgentRegistry_GetDefaultAgent(t *testing.T) {
cfg := testCfg([]config.AgentConfig{
{ID: "alpha"},
{ID: "beta", Default: true},
})
registry := NewAgentRegistry(cfg, &mockRegistryProvider{})
// GetDefaultAgent first checks for "main", then returns any
agent := registry.GetDefaultAgent()
if agent == nil {
t.Fatal("expected a default agent")
}
}
func TestAgentRegistry_CanSpawnSubagent(t *testing.T) {
cfg := testCfg([]config.AgentConfig{
{
ID: "parent",
Default: true,
Subagents: &config.SubagentsConfig{
AllowAgents: []string{"child1", "child2"},
},
},
{ID: "child1"},
{ID: "child2"},
{ID: "restricted"},
})
registry := NewAgentRegistry(cfg, &mockRegistryProvider{})
if !registry.CanSpawnSubagent("parent", "child1") {
t.Error("expected parent to be allowed to spawn child1")
}
if !registry.CanSpawnSubagent("parent", "child2") {
t.Error("expected parent to be allowed to spawn child2")
}
if registry.CanSpawnSubagent("parent", "restricted") {
t.Error("expected parent to NOT be allowed to spawn restricted")
}
if registry.CanSpawnSubagent("child1", "child2") {
t.Error("expected child1 to NOT be allowed to spawn (no subagents config)")
}
}
func TestAgentRegistry_CanSpawnSubagent_Wildcard(t *testing.T) {
cfg := testCfg([]config.AgentConfig{
{
ID: "admin",
Default: true,
Subagents: &config.SubagentsConfig{
AllowAgents: []string{"*"},
},
},
{ID: "any-agent"},
})
registry := NewAgentRegistry(cfg, &mockRegistryProvider{})
if !registry.CanSpawnSubagent("admin", "any-agent") {
t.Error("expected wildcard to allow spawning any agent")
}
if !registry.CanSpawnSubagent("admin", "nonexistent") {
t.Error("expected wildcard to allow spawning even nonexistent agents")
}
}
func TestAgentInstance_Model(t *testing.T) {
model := &config.AgentModelConfig{Primary: "claude-opus"}
cfg := testCfg([]config.AgentConfig{
{ID: "custom", Default: true, Model: model},
})
registry := NewAgentRegistry(cfg, &mockRegistryProvider{})
agent, _ := registry.GetAgent("custom")
if agent.Model != "claude-opus" {
t.Errorf("agent.Model = %q, want 'claude-opus'", agent.Model)
}
}
func TestAgentInstance_FallbackInheritance(t *testing.T) {
cfg := testCfg([]config.AgentConfig{
{ID: "inherit", Default: true},
})
cfg.Agents.Defaults.ModelFallbacks = []string{"openai/gpt-4o-mini", "anthropic/haiku"}
registry := NewAgentRegistry(cfg, &mockRegistryProvider{})
agent, _ := registry.GetAgent("inherit")
if len(agent.Fallbacks) != 2 {
t.Errorf("expected 2 fallbacks inherited from defaults, got %d", len(agent.Fallbacks))
}
}
func TestAgentInstance_FallbackExplicitEmpty(t *testing.T) {
model := &config.AgentModelConfig{
Primary: "gpt-4",
Fallbacks: []string{}, // explicitly empty = disable
}
cfg := testCfg([]config.AgentConfig{
{ID: "no-fallback", Default: true, Model: model},
})
cfg.Agents.Defaults.ModelFallbacks = []string{"should-not-inherit"}
registry := NewAgentRegistry(cfg, &mockRegistryProvider{})
agent, _ := registry.GetAgent("no-fallback")
if len(agent.Fallbacks) != 0 {
t.Errorf("expected 0 fallbacks (explicit empty), got %d: %v", len(agent.Fallbacks), agent.Fallbacks)
}
}
+6 -11
View File
@@ -2,7 +2,6 @@ package channels
import (
"context"
"fmt"
"strings"
"github.com/sipeed/picoclaw/pkg/bus"
@@ -87,17 +86,13 @@ func (c *BaseChannel) HandleMessage(senderID, chatID, content string, media []st
return
}
// Build session key: channel:chatID
sessionKey := fmt.Sprintf("%s:%s", c.name, chatID)
msg := bus.InboundMessage{
Channel: c.name,
SenderID: senderID,
ChatID: chatID,
Content: content,
Media: media,
SessionKey: sessionKey,
Metadata: metadata,
Channel: c.name,
SenderID: senderID,
ChatID: chatID,
Content: content,
Media: media,
Metadata: metadata,
}
c.bus.PublishInbound(msg)
+73 -134
View File
@@ -4,7 +4,7 @@ import (
"context"
"fmt"
"os"
"strings"
"sync"
"time"
"github.com/bwmarrin/discordgo"
@@ -26,6 +26,8 @@ type DiscordChannel struct {
config config.DiscordConfig
transcriber *voice.GroqTranscriber
ctx context.Context
typingMu sync.Mutex
typingStop map[string]chan struct{} // chatID → stop signal
}
func NewDiscordChannel(cfg config.DiscordConfig, bus *bus.MessageBus) (*DiscordChannel, error) {
@@ -42,6 +44,7 @@ func NewDiscordChannel(cfg config.DiscordConfig, bus *bus.MessageBus) (*DiscordC
config: cfg,
transcriber: nil,
ctx: context.Background(),
typingStop: make(map[string]chan struct{}),
}, nil
}
@@ -84,6 +87,14 @@ func (c *DiscordChannel) Stop(ctx context.Context) error {
logger.InfoC("discord", "Stopping Discord bot")
c.setRunning(false)
// Stop all typing goroutines before closing session
c.typingMu.Lock()
for chatID, stop := range c.typingStop {
close(stop)
delete(c.typingStop, chatID)
}
c.typingMu.Unlock()
if err := c.session.Close(); err != nil {
return fmt.Errorf("failed to close discord session: %w", err)
}
@@ -92,6 +103,8 @@ func (c *DiscordChannel) Stop(ctx context.Context) error {
}
func (c *DiscordChannel) Send(ctx context.Context, msg bus.OutboundMessage) error {
c.stopTyping(msg.ChatID)
if !c.IsRunning() {
return fmt.Errorf("discord bot not running")
}
@@ -106,7 +119,7 @@ func (c *DiscordChannel) Send(ctx context.Context, msg bus.OutboundMessage) erro
return nil
}
chunks := splitMessage(msg.Content, 1500) // Discord has a limit of 2000 characters per message, leave 500 for natural split e.g. code blocks
chunks := utils.SplitMessage(msg.Content, 2000) // Split messages into chunks, Discord length limit: 2000 chars
for _, chunk := range chunks {
if err := c.sendChunk(ctx, channelID, chunk); err != nil {
@@ -117,132 +130,6 @@ func (c *DiscordChannel) Send(ctx context.Context, msg bus.OutboundMessage) erro
return nil
}
// splitMessage splits long messages into chunks, preserving code block integrity
// Uses natural boundaries (newlines, spaces) and extends messages slightly to avoid breaking code blocks
func splitMessage(content string, limit int) []string {
var messages []string
for len(content) > 0 {
if len(content) <= limit {
messages = append(messages, content)
break
}
msgEnd := limit
// Find natural split point within the limit
msgEnd = findLastNewline(content[:limit], 200)
if msgEnd <= 0 {
msgEnd = findLastSpace(content[:limit], 100)
}
if msgEnd <= 0 {
msgEnd = limit
}
// Check if this would end with an incomplete code block
candidate := content[:msgEnd]
unclosedIdx := findLastUnclosedCodeBlock(candidate)
if unclosedIdx >= 0 {
// Message would end with incomplete code block
// Try to extend to include the closing ``` (with some buffer)
extendedLimit := limit + 500 // Allow 500 char buffer for code blocks
if len(content) > extendedLimit {
closingIdx := findNextClosingCodeBlock(content, msgEnd)
if closingIdx > 0 && closingIdx <= extendedLimit {
// Extend to include the closing ```
msgEnd = closingIdx
} else {
// Can't find closing, split before the code block
msgEnd = findLastNewline(content[:unclosedIdx], 200)
if msgEnd <= 0 {
msgEnd = findLastSpace(content[:unclosedIdx], 100)
}
if msgEnd <= 0 {
msgEnd = unclosedIdx
}
}
} else {
// Remaining content fits within extended limit
msgEnd = len(content)
}
}
if msgEnd <= 0 {
msgEnd = limit
}
messages = append(messages, content[:msgEnd])
content = strings.TrimSpace(content[msgEnd:])
}
return messages
}
// findLastUnclosedCodeBlock finds the last opening ``` that doesn't have a closing ```
// Returns the position of the opening ``` or -1 if all code blocks are complete
func findLastUnclosedCodeBlock(text string) int {
count := 0
lastOpenIdx := -1
for i := 0; i < len(text); i++ {
if i+2 < len(text) && text[i] == '`' && text[i+1] == '`' && text[i+2] == '`' {
if count == 0 {
lastOpenIdx = i
}
count++
i += 2
}
}
// If odd number of ``` markers, last one is unclosed
if count%2 == 1 {
return lastOpenIdx
}
return -1
}
// findNextClosingCodeBlock finds the next closing ``` starting from a position
// Returns the position after the closing ``` or -1 if not found
func findNextClosingCodeBlock(text string, startIdx int) int {
for i := startIdx; i < len(text); i++ {
if i+2 < len(text) && text[i] == '`' && text[i+1] == '`' && text[i+2] == '`' {
return i + 3
}
}
return -1
}
// findLastNewline finds the last newline character within the last N characters
// Returns the position of the newline or -1 if not found
func findLastNewline(s string, searchWindow int) int {
searchStart := len(s) - searchWindow
if searchStart < 0 {
searchStart = 0
}
for i := len(s) - 1; i >= searchStart; i-- {
if s[i] == '\n' {
return i
}
}
return -1
}
// findLastSpace finds the last space character within the last N characters
// Returns the position of the space or -1 if not found
func findLastSpace(s string, searchWindow int) int {
searchStart := len(s) - searchWindow
if searchStart < 0 {
searchStart = 0
}
for i := len(s) - 1; i >= searchStart; i-- {
if s[i] == ' ' || s[i] == '\t' {
return i
}
}
return -1
}
func (c *DiscordChannel) sendChunk(ctx context.Context, channelID, content string) error {
// 使用传入的 ctx 进行超时控制
sendCtx, cancel := context.WithTimeout(ctx, sendTimeout)
@@ -282,12 +169,6 @@ func (c *DiscordChannel) handleMessage(s *discordgo.Session, m *discordgo.Messag
return
}
if err := c.session.ChannelTyping(m.ChannelID); err != nil {
logger.ErrorCF("discord", "Failed to send typing indicator", map[string]any{
"error": err.Error(),
})
}
// 检查白名单,避免为被拒绝的用户下载附件和转录
if !c.IsAllowed(m.Author.ID) {
logger.DebugCF("discord", "Message rejected by allowlist", map[string]any{
@@ -370,12 +251,22 @@ func (c *DiscordChannel) handleMessage(s *discordgo.Session, m *discordgo.Messag
content = "[media only]"
}
// Start typing after all early returns — guaranteed to have a matching Send()
c.startTyping(m.ChannelID)
logger.DebugCF("discord", "Received message", map[string]any{
"sender_name": senderName,
"sender_id": senderID,
"preview": utils.Truncate(content, 50),
})
peerKind := "channel"
peerID := m.ChannelID
if m.GuildID == "" {
peerKind = "direct"
peerID = senderID
}
metadata := map[string]string{
"message_id": m.ID,
"user_id": senderID,
@@ -384,11 +275,59 @@ func (c *DiscordChannel) handleMessage(s *discordgo.Session, m *discordgo.Messag
"guild_id": m.GuildID,
"channel_id": m.ChannelID,
"is_dm": fmt.Sprintf("%t", m.GuildID == ""),
"peer_kind": peerKind,
"peer_id": peerID,
}
c.HandleMessage(senderID, m.ChannelID, content, mediaPaths, metadata)
}
// startTyping starts a continuous typing indicator loop for the given chatID.
// It stops any existing typing loop for that chatID before starting a new one.
func (c *DiscordChannel) startTyping(chatID string) {
c.typingMu.Lock()
// Stop existing loop for this chatID if any
if stop, ok := c.typingStop[chatID]; ok {
close(stop)
}
stop := make(chan struct{})
c.typingStop[chatID] = stop
c.typingMu.Unlock()
go func() {
if err := c.session.ChannelTyping(chatID); err != nil {
logger.DebugCF("discord", "ChannelTyping error", map[string]interface{}{"chatID": chatID, "err": err})
}
ticker := time.NewTicker(8 * time.Second)
defer ticker.Stop()
timeout := time.After(5 * time.Minute)
for {
select {
case <-stop:
return
case <-timeout:
return
case <-c.ctx.Done():
return
case <-ticker.C:
if err := c.session.ChannelTyping(chatID); err != nil {
logger.DebugCF("discord", "ChannelTyping error", map[string]interface{}{"chatID": chatID, "err": err})
}
}
}
}()
}
// stopTyping stops the typing indicator loop for the given chatID.
func (c *DiscordChannel) stopTyping(chatID string) {
c.typingMu.Lock()
defer c.typingMu.Unlock()
if stop, ok := c.typingStop[chatID]; ok {
close(stop)
delete(c.typingStop, chatID)
}
}
func (c *DiscordChannel) downloadAttachment(url, filename string) string {
return utils.DownloadFile(url, filename, utils.DownloadOptions{
LoggerPrefix: "discord",
-2
View File
@@ -18,7 +18,6 @@ type MaixCamChannel struct {
listener net.Listener
clients map[net.Conn]bool
clientsMux sync.RWMutex
running bool
}
type MaixCamMessage struct {
@@ -35,7 +34,6 @@ func NewMaixCamChannel(cfg config.MaixCamConfig, bus *bus.MessageBus) (*MaixCamC
BaseChannel: base,
config: cfg,
clients: make(map[net.Conn]bool),
running: false,
}, nil
}
+498 -209
View File
@@ -4,9 +4,11 @@ import (
"context"
"encoding/json"
"fmt"
"os"
"strconv"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/gorilla/websocket"
@@ -14,20 +16,28 @@ import (
"github.com/sipeed/picoclaw/pkg/bus"
"github.com/sipeed/picoclaw/pkg/config"
"github.com/sipeed/picoclaw/pkg/logger"
"github.com/sipeed/picoclaw/pkg/utils"
"github.com/sipeed/picoclaw/pkg/voice"
)
type OneBotChannel struct {
*BaseChannel
config config.OneBotConfig
conn *websocket.Conn
ctx context.Context
cancel context.CancelFunc
dedup map[string]struct{}
dedupRing []string
dedupIdx int
mu sync.Mutex
writeMu sync.Mutex
echoCounter int64
config config.OneBotConfig
conn *websocket.Conn
ctx context.Context
cancel context.CancelFunc
dedup map[string]struct{}
dedupRing []string
dedupIdx int
mu sync.Mutex
writeMu sync.Mutex
echoCounter int64
selfID int64
pending map[string]chan json.RawMessage
pendingMu sync.Mutex
transcriber *voice.GroqTranscriber
lastMessageID sync.Map
pendingEmojiMsg sync.Map
}
type oneBotRawEvent struct {
@@ -43,9 +53,11 @@ type oneBotRawEvent struct {
SelfID json.RawMessage `json:"self_id"`
Time json.RawMessage `json:"time"`
MetaEventType string `json:"meta_event_type"`
NoticeType string `json:"notice_type"`
Echo string `json:"echo"`
RetCode json.RawMessage `json:"retcode"`
Status BotStatus `json:"status"`
Status json.RawMessage `json:"status"`
Data json.RawMessage `json:"data"`
}
type BotStatus struct {
@@ -53,42 +65,36 @@ type BotStatus struct {
Good bool `json:"good"`
}
func isAPIResponse(raw json.RawMessage) bool {
if len(raw) == 0 {
return false
}
var s string
if json.Unmarshal(raw, &s) == nil {
return s == "ok" || s == "failed"
}
var bs BotStatus
if json.Unmarshal(raw, &bs) == nil {
return bs.Online || bs.Good
}
return false
}
type oneBotSender struct {
UserID json.RawMessage `json:"user_id"`
Nickname string `json:"nickname"`
Card string `json:"card"`
}
type oneBotEvent struct {
PostType string
MessageType string
SubType string
MessageID string
UserID int64
GroupID int64
Content string
RawContent string
IsBotMentioned bool
Sender oneBotSender
SelfID int64
Time int64
MetaEventType string
}
type oneBotAPIRequest struct {
Action string `json:"action"`
Params interface{} `json:"params"`
Echo string `json:"echo,omitempty"`
}
type oneBotSendPrivateMsgParams struct {
UserID int64 `json:"user_id"`
Message string `json:"message"`
}
type oneBotSendGroupMsgParams struct {
GroupID int64 `json:"group_id"`
Message string `json:"message"`
type oneBotMessageSegment struct {
Type string `json:"type"`
Data map[string]interface{} `json:"data"`
}
func NewOneBotChannel(cfg config.OneBotConfig, messageBus *bus.MessageBus) (*OneBotChannel, error) {
@@ -101,9 +107,30 @@ func NewOneBotChannel(cfg config.OneBotConfig, messageBus *bus.MessageBus) (*One
dedup: make(map[string]struct{}, dedupSize),
dedupRing: make([]string, dedupSize),
dedupIdx: 0,
pending: make(map[string]chan json.RawMessage),
}, nil
}
func (c *OneBotChannel) SetTranscriber(transcriber *voice.GroqTranscriber) {
c.transcriber = transcriber
}
func (c *OneBotChannel) setMsgEmojiLike(messageID string, emojiID int, set bool) {
go func() {
_, err := c.sendAPIRequest("set_msg_emoji_like", map[string]interface{}{
"message_id": messageID,
"emoji_id": emojiID,
"set": set,
}, 5*time.Second)
if err != nil {
logger.DebugCF("onebot", "Failed to set emoji like", map[string]interface{}{
"message_id": messageID,
"error": err.Error(),
})
}
}()
}
func (c *OneBotChannel) Start(ctx context.Context) error {
if c.config.WSUrl == "" {
return fmt.Errorf("OneBot ws_url not configured")
@@ -121,12 +148,12 @@ func (c *OneBotChannel) Start(ctx context.Context) error {
})
} else {
go c.listen()
c.fetchSelfID()
}
if c.config.ReconnectInterval > 0 {
go c.reconnectLoop()
} else {
// If reconnect is disabled but initial connection failed, we cannot recover
if c.conn == nil {
return fmt.Errorf("failed to connect to OneBot and reconnect is disabled")
}
@@ -152,14 +179,141 @@ func (c *OneBotChannel) connect() error {
return err
}
conn.SetPongHandler(func(appData string) error {
_ = conn.SetReadDeadline(time.Now().Add(60 * time.Second))
return nil
})
_ = conn.SetReadDeadline(time.Now().Add(60 * time.Second))
c.mu.Lock()
c.conn = conn
c.mu.Unlock()
go c.pinger(conn)
logger.InfoC("onebot", "WebSocket connected")
return nil
}
func (c *OneBotChannel) pinger(conn *websocket.Conn) {
ticker := time.NewTicker(30 * time.Second)
defer ticker.Stop()
for {
select {
case <-c.ctx.Done():
return
case <-ticker.C:
c.writeMu.Lock()
err := conn.WriteMessage(websocket.PingMessage, nil)
c.writeMu.Unlock()
if err != nil {
logger.DebugCF("onebot", "Ping write failed, stopping pinger", map[string]interface{}{
"error": err.Error(),
})
return
}
}
}
}
func (c *OneBotChannel) fetchSelfID() {
resp, err := c.sendAPIRequest("get_login_info", nil, 5*time.Second)
if err != nil {
logger.WarnCF("onebot", "Failed to get_login_info", map[string]interface{}{
"error": err.Error(),
})
return
}
type loginInfo struct {
UserID json.RawMessage `json:"user_id"`
Nickname string `json:"nickname"`
}
for _, extract := range []func() (*loginInfo, error){
func() (*loginInfo, error) {
var w struct {
Data loginInfo `json:"data"`
}
err := json.Unmarshal(resp, &w)
return &w.Data, err
},
func() (*loginInfo, error) {
var f loginInfo
err := json.Unmarshal(resp, &f)
return &f, err
},
} {
info, err := extract()
if err != nil || len(info.UserID) == 0 {
continue
}
if uid, err := parseJSONInt64(info.UserID); err == nil && uid > 0 {
atomic.StoreInt64(&c.selfID, uid)
logger.InfoCF("onebot", "Bot self ID retrieved", map[string]interface{}{
"self_id": uid,
"nickname": info.Nickname,
})
return
}
}
logger.WarnCF("onebot", "Could not parse self ID from get_login_info response", map[string]interface{}{
"response": string(resp),
})
}
func (c *OneBotChannel) sendAPIRequest(action string, params interface{}, timeout time.Duration) (json.RawMessage, error) {
c.mu.Lock()
conn := c.conn
c.mu.Unlock()
if conn == nil {
return nil, fmt.Errorf("WebSocket not connected")
}
echo := fmt.Sprintf("api_%d_%d", time.Now().UnixNano(), atomic.AddInt64(&c.echoCounter, 1))
ch := make(chan json.RawMessage, 1)
c.pendingMu.Lock()
c.pending[echo] = ch
c.pendingMu.Unlock()
defer func() {
c.pendingMu.Lock()
delete(c.pending, echo)
c.pendingMu.Unlock()
}()
req := oneBotAPIRequest{
Action: action,
Params: params,
Echo: echo,
}
data, err := json.Marshal(req)
if err != nil {
return nil, fmt.Errorf("failed to marshal API request: %w", err)
}
c.writeMu.Lock()
err = conn.WriteMessage(websocket.TextMessage, data)
c.writeMu.Unlock()
if err != nil {
return nil, fmt.Errorf("failed to write API request: %w", err)
}
select {
case resp := <-ch:
return resp, nil
case <-time.After(timeout):
return nil, fmt.Errorf("API request %s timed out after %v", action, timeout)
case <-c.ctx.Done():
return nil, fmt.Errorf("context cancelled")
}
}
func (c *OneBotChannel) reconnectLoop() {
interval := time.Duration(c.config.ReconnectInterval) * time.Second
if interval < 5*time.Second {
@@ -183,6 +337,7 @@ func (c *OneBotChannel) reconnectLoop() {
})
} else {
go c.listen()
c.fetchSelfID()
}
}
}
@@ -197,6 +352,13 @@ func (c *OneBotChannel) Stop(ctx context.Context) error {
c.cancel()
}
c.pendingMu.Lock()
for echo, ch := range c.pending {
close(ch)
delete(c.pending, echo)
}
c.pendingMu.Unlock()
c.mu.Lock()
if c.conn != nil {
c.conn.Close()
@@ -225,10 +387,7 @@ func (c *OneBotChannel) Send(ctx context.Context, msg bus.OutboundMessage) error
return err
}
c.writeMu.Lock()
c.echoCounter++
echo := fmt.Sprintf("send_%d", c.echoCounter)
c.writeMu.Unlock()
echo := fmt.Sprintf("send_%d", atomic.AddInt64(&c.echoCounter, 1))
req := oneBotAPIRequest{
Action: action,
@@ -252,67 +411,78 @@ func (c *OneBotChannel) Send(ctx context.Context, msg bus.OutboundMessage) error
return err
}
if msgID, ok := c.pendingEmojiMsg.LoadAndDelete(msg.ChatID); ok {
if mid, ok := msgID.(string); ok && mid != "" {
c.setMsgEmojiLike(mid, 289, false)
}
}
return nil
}
func (c *OneBotChannel) buildMessageSegments(chatID, content string) []oneBotMessageSegment {
var segments []oneBotMessageSegment
if lastMsgID, ok := c.lastMessageID.Load(chatID); ok {
if msgID, ok := lastMsgID.(string); ok && msgID != "" {
segments = append(segments, oneBotMessageSegment{
Type: "reply",
Data: map[string]interface{}{"id": msgID},
})
}
}
segments = append(segments, oneBotMessageSegment{
Type: "text",
Data: map[string]interface{}{"text": content},
})
return segments
}
func (c *OneBotChannel) buildSendRequest(msg bus.OutboundMessage) (string, interface{}, error) {
chatID := msg.ChatID
segments := c.buildMessageSegments(chatID, msg.Content)
if len(chatID) > 6 && chatID[:6] == "group:" {
groupID, err := strconv.ParseInt(chatID[6:], 10, 64)
if err != nil {
return "", nil, fmt.Errorf("invalid group ID in chatID: %s", chatID)
}
return "send_group_msg", oneBotSendGroupMsgParams{
GroupID: groupID,
Message: msg.Content,
}, nil
var action, idKey string
var rawID string
if rest, ok := strings.CutPrefix(chatID, "group:"); ok {
action, idKey, rawID = "send_group_msg", "group_id", rest
} else if rest, ok := strings.CutPrefix(chatID, "private:"); ok {
action, idKey, rawID = "send_private_msg", "user_id", rest
} else {
action, idKey, rawID = "send_private_msg", "user_id", chatID
}
if len(chatID) > 8 && chatID[:8] == "private:" {
userID, err := strconv.ParseInt(chatID[8:], 10, 64)
if err != nil {
return "", nil, fmt.Errorf("invalid user ID in chatID: %s", chatID)
}
return "send_private_msg", oneBotSendPrivateMsgParams{
UserID: userID,
Message: msg.Content,
}, nil
}
userID, err := strconv.ParseInt(chatID, 10, 64)
id, err := strconv.ParseInt(rawID, 10, 64)
if err != nil {
return "", nil, fmt.Errorf("invalid chatID for OneBot: %s", chatID)
return "", nil, fmt.Errorf("invalid %s in chatID: %s", idKey, chatID)
}
return "send_private_msg", oneBotSendPrivateMsgParams{
UserID: userID,
Message: msg.Content,
}, nil
return action, map[string]interface{}{idKey: id, "message": segments}, nil
}
func (c *OneBotChannel) listen() {
c.mu.Lock()
conn := c.conn
c.mu.Unlock()
if conn == nil {
logger.WarnC("onebot", "WebSocket connection is nil, listener exiting")
return
}
for {
select {
case <-c.ctx.Done():
return
default:
c.mu.Lock()
conn := c.conn
c.mu.Unlock()
if conn == nil {
logger.WarnC("onebot", "WebSocket connection is nil, listener exiting")
return
}
_, message, err := conn.ReadMessage()
if err != nil {
logger.ErrorCF("onebot", "WebSocket read error", map[string]interface{}{
"error": err.Error(),
})
c.mu.Lock()
if c.conn != nil {
if c.conn == conn {
c.conn.Close()
c.conn = nil
}
@@ -320,10 +490,7 @@ func (c *OneBotChannel) listen() {
return
}
logger.DebugCF("onebot", "Raw WebSocket message received", map[string]interface{}{
"length": len(message),
"payload": string(message),
})
_ = conn.SetReadDeadline(time.Now().Add(60 * time.Second))
var raw oneBotRawEvent
if err := json.Unmarshal(message, &raw); err != nil {
@@ -334,20 +501,37 @@ func (c *OneBotChannel) listen() {
continue
}
if raw.Echo != "" || raw.Status.Online || raw.Status.Good {
logger.DebugCF("onebot", "Received API response, skipping", map[string]interface{}{
"echo": raw.Echo,
"status": raw.Status,
})
logger.DebugCF("onebot", "WebSocket event", map[string]interface{}{
"length": len(message),
"post_type": raw.PostType,
"sub_type": raw.SubType,
})
if raw.Echo != "" {
c.pendingMu.Lock()
ch, ok := c.pending[raw.Echo]
c.pendingMu.Unlock()
if ok {
select {
case ch <- message:
default:
}
} else {
logger.DebugCF("onebot", "Received API response (no waiter)", map[string]interface{}{
"echo": raw.Echo,
"status": string(raw.Status),
})
}
continue
}
logger.DebugCF("onebot", "Parsed raw event", map[string]interface{}{
"post_type": raw.PostType,
"message_type": raw.MessageType,
"sub_type": raw.SubType,
"meta_event_type": raw.MetaEventType,
})
if isAPIResponse(raw.Status) {
logger.DebugCF("onebot", "Received API response without echo, skipping", map[string]interface{}{
"status": string(raw.Status),
})
continue
}
c.handleRawEvent(&raw)
}
@@ -386,9 +570,12 @@ func parseJSONString(raw json.RawMessage) string {
type parseMessageResult struct {
Text string
IsBotMentioned bool
Media []string
LocalFiles []string
ReplyTo string
}
func parseMessageContentEx(raw json.RawMessage, selfID int64) parseMessageResult {
func (c *OneBotChannel) parseMessageSegments(raw json.RawMessage, selfID int64) parseMessageResult {
if len(raw) == 0 {
return parseMessageResult{}
}
@@ -408,60 +595,155 @@ func parseMessageContentEx(raw json.RawMessage, selfID int64) parseMessageResult
}
var segments []map[string]interface{}
if err := json.Unmarshal(raw, &segments); err == nil {
var text string
mentioned := false
selfIDStr := strconv.FormatInt(selfID, 10)
for _, seg := range segments {
segType, _ := seg["type"].(string)
data, _ := seg["data"].(map[string]interface{})
switch segType {
case "text":
if data != nil {
if t, ok := data["text"].(string); ok {
text += t
}
if err := json.Unmarshal(raw, &segments); err != nil {
return parseMessageResult{}
}
var textParts []string
mentioned := false
selfIDStr := strconv.FormatInt(selfID, 10)
var media []string
var localFiles []string
var replyTo string
for _, seg := range segments {
segType, _ := seg["type"].(string)
data, _ := seg["data"].(map[string]interface{})
switch segType {
case "text":
if data != nil {
if t, ok := data["text"].(string); ok {
textParts = append(textParts, t)
}
case "at":
if data != nil && selfID > 0 {
qqVal := fmt.Sprintf("%v", data["qq"])
if qqVal == selfIDStr || qqVal == "all" {
mentioned = true
}
case "at":
if data != nil && selfID > 0 {
qqVal := fmt.Sprintf("%v", data["qq"])
if qqVal == selfIDStr || qqVal == "all" {
mentioned = true
}
}
case "image", "video", "file":
if data != nil {
url, _ := data["url"].(string)
if url != "" {
defaults := map[string]string{"image": "image.jpg", "video": "video.mp4", "file": "file"}
filename := defaults[segType]
if f, ok := data["file"].(string); ok && f != "" {
filename = f
} else if n, ok := data["name"].(string); ok && n != "" {
filename = n
}
localPath := utils.DownloadFile(url, filename, utils.DownloadOptions{
LoggerPrefix: "onebot",
})
if localPath != "" {
media = append(media, localPath)
localFiles = append(localFiles, localPath)
textParts = append(textParts, fmt.Sprintf("[%s]", segType))
}
}
}
case "record":
if data != nil {
url, _ := data["url"].(string)
if url != "" {
localPath := utils.DownloadFile(url, "voice.amr", utils.DownloadOptions{
LoggerPrefix: "onebot",
})
if localPath != "" {
localFiles = append(localFiles, localPath)
if c.transcriber != nil && c.transcriber.IsAvailable() {
tctx, tcancel := context.WithTimeout(c.ctx, 30*time.Second)
result, err := c.transcriber.Transcribe(tctx, localPath)
tcancel()
if err != nil {
logger.WarnCF("onebot", "Voice transcription failed", map[string]interface{}{
"error": err.Error(),
})
textParts = append(textParts, "[voice (transcription failed)]")
media = append(media, localPath)
} else {
textParts = append(textParts, fmt.Sprintf("[voice transcription: %s]", result.Text))
}
} else {
textParts = append(textParts, "[voice]")
media = append(media, localPath)
}
}
}
}
case "reply":
if data != nil {
if id, ok := data["id"]; ok {
replyTo = fmt.Sprintf("%v", id)
}
}
case "face":
if data != nil {
faceID, _ := data["id"]
textParts = append(textParts, fmt.Sprintf("[face:%v]", faceID))
}
case "forward":
textParts = append(textParts, "[forward message]")
default:
}
return parseMessageResult{Text: strings.TrimSpace(text), IsBotMentioned: mentioned}
}
return parseMessageResult{}
return parseMessageResult{
Text: strings.TrimSpace(strings.Join(textParts, "")),
IsBotMentioned: mentioned,
Media: media,
LocalFiles: localFiles,
ReplyTo: replyTo,
}
}
func (c *OneBotChannel) handleRawEvent(raw *oneBotRawEvent) {
switch raw.PostType {
case "message":
evt, err := c.normalizeMessageEvent(raw)
if err != nil {
logger.WarnCF("onebot", "Failed to normalize message event", map[string]interface{}{
"error": err.Error(),
})
return
if userID, err := parseJSONInt64(raw.UserID); err == nil && userID > 0 {
if !c.IsAllowed(strconv.FormatInt(userID, 10)) {
logger.DebugCF("onebot", "Message rejected by allowlist", map[string]interface{}{
"user_id": userID,
})
return
}
}
c.handleMessage(evt)
c.handleMessage(raw)
case "message_sent":
logger.DebugCF("onebot", "Bot sent message event", map[string]interface{}{
"message_type": raw.MessageType,
"message_id": parseJSONString(raw.MessageID),
})
case "meta_event":
c.handleMetaEvent(raw)
case "notice":
logger.DebugCF("onebot", "Notice event received", map[string]interface{}{
"sub_type": raw.SubType,
})
c.handleNoticeEvent(raw)
case "request":
logger.DebugCF("onebot", "Request event received", map[string]interface{}{
"sub_type": raw.SubType,
})
case "":
logger.DebugCF("onebot", "Event with empty post_type (possibly API response)", map[string]interface{}{
"echo": raw.Echo,
"status": raw.Status,
})
default:
logger.DebugCF("onebot", "Unknown post_type", map[string]interface{}{
"post_type": raw.PostType,
@@ -469,18 +751,51 @@ func (c *OneBotChannel) handleRawEvent(raw *oneBotRawEvent) {
}
}
func (c *OneBotChannel) normalizeMessageEvent(raw *oneBotRawEvent) (*oneBotEvent, error) {
func (c *OneBotChannel) handleMetaEvent(raw *oneBotRawEvent) {
if raw.MetaEventType == "lifecycle" {
logger.InfoCF("onebot", "Lifecycle event", map[string]interface{}{"sub_type": raw.SubType})
} else if raw.MetaEventType != "heartbeat" {
logger.DebugCF("onebot", "Meta event: "+raw.MetaEventType, nil)
}
}
func (c *OneBotChannel) handleNoticeEvent(raw *oneBotRawEvent) {
fields := map[string]interface{}{
"notice_type": raw.NoticeType,
"sub_type": raw.SubType,
"group_id": parseJSONString(raw.GroupID),
"user_id": parseJSONString(raw.UserID),
"message_id": parseJSONString(raw.MessageID),
}
switch raw.NoticeType {
case "group_recall", "group_increase", "group_decrease",
"friend_add", "group_admin", "group_ban":
logger.InfoCF("onebot", "Notice: "+raw.NoticeType, fields)
default:
logger.DebugCF("onebot", "Notice: "+raw.NoticeType, fields)
}
}
func (c *OneBotChannel) handleMessage(raw *oneBotRawEvent) {
// Parse fields from raw event
userID, err := parseJSONInt64(raw.UserID)
if err != nil {
return nil, fmt.Errorf("parse user_id: %w (raw: %s)", err, string(raw.UserID))
logger.WarnCF("onebot", "Failed to parse user_id", map[string]interface{}{
"error": err.Error(),
"raw": string(raw.UserID),
})
return
}
groupID, _ := parseJSONInt64(raw.GroupID)
selfID, _ := parseJSONInt64(raw.SelfID)
ts, _ := parseJSONInt64(raw.Time)
messageID := parseJSONString(raw.MessageID)
parsed := parseMessageContentEx(raw.Message, selfID)
if selfID == 0 {
selfID = atomic.LoadInt64(&c.selfID)
}
parsed := c.parseMessageSegments(raw.Message, selfID)
isBotMentioned := parsed.IsBotMentioned
content := raw.RawMessage
@@ -495,6 +810,10 @@ func (c *OneBotChannel) normalizeMessageEvent(raw *oneBotRawEvent) (*oneBotEvent
}
}
if parsed.Text != "" && content != parsed.Text && (len(parsed.Media) > 0 || parsed.ReplyTo != "") {
content = parsed.Text
}
var sender oneBotSender
if len(raw.Sender) > 0 {
if err := json.Unmarshal(raw.Sender, &sender); err != nil {
@@ -505,137 +824,107 @@ func (c *OneBotChannel) normalizeMessageEvent(raw *oneBotRawEvent) (*oneBotEvent
}
}
logger.DebugCF("onebot", "Normalized message event", map[string]interface{}{
"message_type": raw.MessageType,
"user_id": userID,
"group_id": groupID,
"message_id": messageID,
"content_len": len(content),
"nickname": sender.Nickname,
})
return &oneBotEvent{
PostType: raw.PostType,
MessageType: raw.MessageType,
SubType: raw.SubType,
MessageID: messageID,
UserID: userID,
GroupID: groupID,
Content: content,
RawContent: raw.RawMessage,
IsBotMentioned: isBotMentioned,
Sender: sender,
SelfID: selfID,
Time: ts,
MetaEventType: raw.MetaEventType,
}, nil
}
func (c *OneBotChannel) handleMetaEvent(raw *oneBotRawEvent) {
switch raw.MetaEventType {
case "lifecycle":
logger.InfoCF("onebot", "Lifecycle event", map[string]interface{}{
"sub_type": raw.SubType,
})
case "heartbeat":
logger.DebugC("onebot", "Heartbeat received")
default:
logger.DebugCF("onebot", "Unknown meta_event_type", map[string]interface{}{
"meta_event_type": raw.MetaEventType,
})
// Clean up temp files when done
if len(parsed.LocalFiles) > 0 {
defer func() {
for _, f := range parsed.LocalFiles {
if err := os.Remove(f); err != nil {
logger.DebugCF("onebot", "Failed to remove temp file", map[string]interface{}{
"path": f,
"error": err.Error(),
})
}
}
}()
}
}
func (c *OneBotChannel) handleMessage(evt *oneBotEvent) {
if c.isDuplicate(evt.MessageID) {
if c.isDuplicate(messageID) {
logger.DebugCF("onebot", "Duplicate message, skipping", map[string]interface{}{
"message_id": evt.MessageID,
"message_id": messageID,
})
return
}
content := evt.Content
if content == "" {
logger.DebugCF("onebot", "Received empty message, ignoring", map[string]interface{}{
"message_id": evt.MessageID,
"message_id": messageID,
})
return
}
senderID := strconv.FormatInt(evt.UserID, 10)
senderID := strconv.FormatInt(userID, 10)
var chatID string
metadata := map[string]string{
"message_id": evt.MessageID,
"message_id": messageID,
}
switch evt.MessageType {
if parsed.ReplyTo != "" {
metadata["reply_to_message_id"] = parsed.ReplyTo
}
switch raw.MessageType {
case "private":
chatID = "private:" + senderID
logger.InfoCF("onebot", "Received private message", map[string]interface{}{
"sender": senderID,
"message_id": evt.MessageID,
"length": len(content),
"content": truncate(content, 100),
})
case "group":
groupIDStr := strconv.FormatInt(evt.GroupID, 10)
groupIDStr := strconv.FormatInt(groupID, 10)
chatID = "group:" + groupIDStr
metadata["group_id"] = groupIDStr
senderUserID, _ := parseJSONInt64(evt.Sender.UserID)
senderUserID, _ := parseJSONInt64(sender.UserID)
if senderUserID > 0 {
metadata["sender_user_id"] = strconv.FormatInt(senderUserID, 10)
}
if evt.Sender.Card != "" {
metadata["sender_name"] = evt.Sender.Card
} else if evt.Sender.Nickname != "" {
metadata["sender_name"] = evt.Sender.Nickname
if sender.Card != "" {
metadata["sender_name"] = sender.Card
} else if sender.Nickname != "" {
metadata["sender_name"] = sender.Nickname
}
triggered, strippedContent := c.checkGroupTrigger(content, evt.IsBotMentioned)
triggered, strippedContent := c.checkGroupTrigger(content, isBotMentioned)
if !triggered {
logger.DebugCF("onebot", "Group message ignored (no trigger)", map[string]interface{}{
"sender": senderID,
"group": groupIDStr,
"is_mentioned": evt.IsBotMentioned,
"is_mentioned": isBotMentioned,
"content": truncate(content, 100),
})
return
}
content = strippedContent
logger.InfoCF("onebot", "Received group message", map[string]interface{}{
"sender": senderID,
"group": groupIDStr,
"message_id": evt.MessageID,
"is_mentioned": evt.IsBotMentioned,
"length": len(content),
"content": truncate(content, 100),
})
default:
logger.WarnCF("onebot", "Unknown message type, cannot route", map[string]interface{}{
"type": evt.MessageType,
"message_id": evt.MessageID,
"user_id": evt.UserID,
"type": raw.MessageType,
"message_id": messageID,
"user_id": userID,
})
return
}
if evt.Sender.Nickname != "" {
metadata["nickname"] = evt.Sender.Nickname
}
logger.DebugCF("onebot", "Forwarding message to bus", map[string]interface{}{
"sender_id": senderID,
"chat_id": chatID,
"content": truncate(content, 100),
logger.InfoCF("onebot", "Received "+raw.MessageType+" message", map[string]interface{}{
"sender": senderID,
"chat_id": chatID,
"message_id": messageID,
"length": len(content),
"content": truncate(content, 100),
"media_count": len(parsed.Media),
})
c.HandleMessage(senderID, chatID, content, []string{}, metadata)
if sender.Nickname != "" {
metadata["nickname"] = sender.Nickname
}
c.lastMessageID.Store(chatID, messageID)
if raw.MessageType == "group" && messageID != "" && messageID != "0" {
c.setMsgEmojiLike(messageID, 289, true)
c.pendingEmojiMsg.Store(chatID, messageID)
}
c.HandleMessage(senderID, chatID, content, parsed.Media, metadata)
}
func (c *OneBotChannel) isDuplicate(messageID string) bool {
+25
View File
@@ -25,6 +25,7 @@ type SlackChannel struct {
api *slack.Client
socketClient *socketmode.Client
botUserID string
teamID string
transcriber *voice.GroqTranscriber
ctx context.Context
cancel context.CancelFunc
@@ -72,6 +73,7 @@ func (c *SlackChannel) Start(ctx context.Context) error {
return fmt.Errorf("slack auth test failed: %w", err)
}
c.botUserID = authResp.UserID
c.teamID = authResp.TeamID
logger.InfoCF("slack", "Slack bot connected", map[string]interface{}{
"bot_user_id": c.botUserID,
@@ -274,11 +276,21 @@ func (c *SlackChannel) handleMessageEvent(ev *slackevents.MessageEvent) {
return
}
peerKind := "channel"
peerID := channelID
if strings.HasPrefix(channelID, "D") {
peerKind = "direct"
peerID = senderID
}
metadata := map[string]string{
"message_ts": messageTS,
"channel_id": channelID,
"thread_ts": threadTS,
"platform": "slack",
"peer_kind": peerKind,
"peer_id": peerID,
"team_id": c.teamID,
}
logger.DebugCF("slack", "Received message", map[string]interface{}{
@@ -331,12 +343,22 @@ func (c *SlackChannel) handleAppMention(ev *slackevents.AppMentionEvent) {
return
}
mentionPeerKind := "channel"
mentionPeerID := channelID
if strings.HasPrefix(channelID, "D") {
mentionPeerKind = "direct"
mentionPeerID = senderID
}
metadata := map[string]string{
"message_ts": messageTS,
"channel_id": channelID,
"thread_ts": threadTS,
"platform": "slack",
"is_mention": "true",
"peer_kind": mentionPeerKind,
"peer_id": mentionPeerID,
"team_id": c.teamID,
}
c.HandleMessage(senderID, chatID, content, nil, metadata)
@@ -373,6 +395,9 @@ func (c *SlackChannel) handleSlashCommand(event socketmode.Event) {
"platform": "slack",
"is_command": "true",
"trigger_id": cmd.TriggerID,
"peer_kind": "channel",
"peer_id": channelID,
"team_id": c.teamID,
}
logger.DebugCF("slack", "Slash command received", map[string]interface{}{
+9
View File
@@ -354,12 +354,21 @@ func (c *TelegramChannel) handleMessage(ctx context.Context, message *telego.Mes
c.placeholders.Store(chatIDStr, pID)
}
peerKind := "direct"
peerID := fmt.Sprintf("%d", user.ID)
if message.Chat.Type != "private" {
peerKind = "group"
peerID = fmt.Sprintf("%d", chatID)
}
metadata := map[string]string{
"message_id": fmt.Sprintf("%d", message.MessageID),
"user_id": fmt.Sprintf("%d", user.ID),
"username": user.Username,
"first_name": user.FirstName,
"is_group": fmt.Sprintf("%t", message.Chat.Type != "private"),
"peer_kind": peerKind,
"peer_id": peerID,
}
c.HandleMessage(fmt.Sprintf("%d", user.ID), fmt.Sprintf("%d", chatID), content, mediaPaths, metadata)
+118 -24
View File
@@ -46,6 +46,8 @@ func (f *FlexibleStringSlice) UnmarshalJSON(data []byte) error {
type Config struct {
Agents AgentsConfig `json:"agents"`
Bindings []AgentBinding `json:"bindings,omitempty"`
Session SessionConfig `json:"session,omitempty"`
Channels ChannelsConfig `json:"channels"`
Providers ProvidersConfig `json:"providers"`
ModelList []ModelConfig `json:"model_list"` // New model-centric provider configuration
@@ -59,16 +61,97 @@ type Config struct {
type AgentsConfig struct {
Defaults AgentDefaults `json:"defaults"`
List []AgentConfig `json:"list,omitempty"`
}
// AgentModelConfig supports both string and structured model config.
// String format: "gpt-4" (just primary, no fallbacks)
// Object format: {"primary": "gpt-4", "fallbacks": ["claude-haiku"]}
type AgentModelConfig struct {
Primary string `json:"primary,omitempty"`
Fallbacks []string `json:"fallbacks,omitempty"`
}
func (m *AgentModelConfig) UnmarshalJSON(data []byte) error {
var s string
if err := json.Unmarshal(data, &s); err == nil {
m.Primary = s
m.Fallbacks = nil
return nil
}
type raw struct {
Primary string `json:"primary"`
Fallbacks []string `json:"fallbacks"`
}
var r raw
if err := json.Unmarshal(data, &r); err != nil {
return err
}
m.Primary = r.Primary
m.Fallbacks = r.Fallbacks
return nil
}
func (m AgentModelConfig) MarshalJSON() ([]byte, error) {
if len(m.Fallbacks) == 0 && m.Primary != "" {
return json.Marshal(m.Primary)
}
type raw struct {
Primary string `json:"primary,omitempty"`
Fallbacks []string `json:"fallbacks,omitempty"`
}
return json.Marshal(raw{Primary: m.Primary, Fallbacks: m.Fallbacks})
}
type AgentConfig struct {
ID string `json:"id"`
Default bool `json:"default,omitempty"`
Name string `json:"name,omitempty"`
Workspace string `json:"workspace,omitempty"`
Model *AgentModelConfig `json:"model,omitempty"`
Skills []string `json:"skills,omitempty"`
Subagents *SubagentsConfig `json:"subagents,omitempty"`
}
type SubagentsConfig struct {
AllowAgents []string `json:"allow_agents,omitempty"`
Model *AgentModelConfig `json:"model,omitempty"`
}
type PeerMatch struct {
Kind string `json:"kind"`
ID string `json:"id"`
}
type BindingMatch struct {
Channel string `json:"channel"`
AccountID string `json:"account_id,omitempty"`
Peer *PeerMatch `json:"peer,omitempty"`
GuildID string `json:"guild_id,omitempty"`
TeamID string `json:"team_id,omitempty"`
}
type AgentBinding struct {
AgentID string `json:"agent_id"`
Match BindingMatch `json:"match"`
}
type SessionConfig struct {
DMScope string `json:"dm_scope,omitempty"`
IdentityLinks map[string][]string `json:"identity_links,omitempty"`
}
type AgentDefaults struct {
Workspace string `json:"workspace" env:"PICOCLAW_AGENTS_DEFAULTS_WORKSPACE"`
RestrictToWorkspace bool `json:"restrict_to_workspace" env:"PICOCLAW_AGENTS_DEFAULTS_RESTRICT_TO_WORKSPACE"`
Provider string `json:"provider" env:"PICOCLAW_AGENTS_DEFAULTS_PROVIDER"`
Model string `json:"model" env:"PICOCLAW_AGENTS_DEFAULTS_MODEL"`
MaxTokens int `json:"max_tokens" env:"PICOCLAW_AGENTS_DEFAULTS_MAX_TOKENS"`
Temperature float64 `json:"temperature" env:"PICOCLAW_AGENTS_DEFAULTS_TEMPERATURE"`
MaxToolIterations int `json:"max_tool_iterations" env:"PICOCLAW_AGENTS_DEFAULTS_MAX_TOOL_ITERATIONS"`
Workspace string `json:"workspace" env:"PICOCLAW_AGENTS_DEFAULTS_WORKSPACE"`
RestrictToWorkspace bool `json:"restrict_to_workspace" env:"PICOCLAW_AGENTS_DEFAULTS_RESTRICT_TO_WORKSPACE"`
Provider string `json:"provider" env:"PICOCLAW_AGENTS_DEFAULTS_PROVIDER"`
Model string `json:"model" env:"PICOCLAW_AGENTS_DEFAULTS_MODEL"`
ModelFallbacks []string `json:"model_fallbacks,omitempty"`
ImageModel string `json:"image_model,omitempty" env:"PICOCLAW_AGENTS_DEFAULTS_IMAGE_MODEL"`
ImageModelFallbacks []string `json:"image_model_fallbacks,omitempty"`
MaxTokens int `json:"max_tokens" env:"PICOCLAW_AGENTS_DEFAULTS_MAX_TOKENS"`
Temperature float64 `json:"temperature" env:"PICOCLAW_AGENTS_DEFAULTS_TEMPERATURE"`
MaxToolIterations int `json:"max_tool_iterations" env:"PICOCLAW_AGENTS_DEFAULTS_MAX_TOOL_ITERATIONS"`
}
type ChannelsConfig struct {
@@ -170,23 +253,23 @@ type DevicesConfig struct {
}
type ProvidersConfig struct {
Anthropic ProviderConfig `json:"anthropic"`
OpenAI ProviderConfig `json:"openai"`
OpenRouter ProviderConfig `json:"openrouter"`
Groq ProviderConfig `json:"groq"`
Zhipu ProviderConfig `json:"zhipu"`
VLLM ProviderConfig `json:"vllm"`
Gemini ProviderConfig `json:"gemini"`
Nvidia ProviderConfig `json:"nvidia"`
Ollama ProviderConfig `json:"ollama"`
Moonshot ProviderConfig `json:"moonshot"`
ShengSuanYun ProviderConfig `json:"shengsuanyun"`
DeepSeek ProviderConfig `json:"deepseek"`
Cerebras ProviderConfig `json:"cerebras"`
VolcEngine ProviderConfig `json:"volcengine"`
GitHubCopilot ProviderConfig `json:"github_copilot"`
Antigravity ProviderConfig `json:"antigravity"`
Qwen ProviderConfig `json:"qwen"`
Anthropic ProviderConfig `json:"anthropic"`
OpenAI OpenAIProviderConfig `json:"openai"`
OpenRouter ProviderConfig `json:"openrouter"`
Groq ProviderConfig `json:"groq"`
Zhipu ProviderConfig `json:"zhipu"`
VLLM ProviderConfig `json:"vllm"`
Gemini ProviderConfig `json:"gemini"`
Nvidia ProviderConfig `json:"nvidia"`
Ollama ProviderConfig `json:"ollama"`
Moonshot ProviderConfig `json:"moonshot"`
ShengSuanYun ProviderConfig `json:"shengsuanyun"`
DeepSeek ProviderConfig `json:"deepseek"`
Cerebras ProviderConfig `json:"cerebras"`
VolcEngine ProviderConfig `json:"volcengine"`
GitHubCopilot ProviderConfig `json:"github_copilot"`
Antigravity ProviderConfig `json:"antigravity"`
Qwen ProviderConfig `json:"qwen"`
}
type ProviderConfig struct {
@@ -197,6 +280,11 @@ type ProviderConfig struct {
ConnectMode string `json:"connect_mode,omitempty" env:"PICOCLAW_PROVIDERS_{{.Name}}_CONNECT_MODE"` //only for Github Copilot, `stdio` or `grpc`
}
type OpenAIProviderConfig struct {
ProviderConfig
WebSearch bool `json:"web_search" env:"PICOCLAW_PROVIDERS_OPENAI_WEB_SEARCH"`
}
// ModelConfig represents a model-centric provider configuration.
// It allows adding new providers (especially OpenAI-compatible ones) via configuration only.
// The model field uses protocol prefix format: [protocol/]model-identifier
@@ -265,9 +353,15 @@ type CronToolsConfig struct {
ExecTimeoutMinutes int `json:"exec_timeout_minutes" env:"PICOCLAW_TOOLS_CRON_EXEC_TIMEOUT_MINUTES"` // 0 means no timeout
}
type ExecConfig struct {
EnableDenyPatterns bool `json:"enable_deny_patterns" env:"PICOCLAW_TOOLS_EXEC_ENABLE_DENY_PATTERNS"`
CustomDenyPatterns []string `json:"custom_deny_patterns" env:"PICOCLAW_TOOLS_EXEC_CUSTOM_DENY_PATTERNS"`
}
type ToolsConfig struct {
Web WebToolsConfig `json:"web"`
Cron CronToolsConfig `json:"cron"`
Exec ExecConfig `json:"exec"`
}
func LoadConfig(path string) (*Config, error) {
+220 -32
View File
@@ -1,12 +1,193 @@
package config
import (
"encoding/json"
"os"
"path/filepath"
"runtime"
"testing"
)
func TestAgentModelConfig_UnmarshalString(t *testing.T) {
var m AgentModelConfig
if err := json.Unmarshal([]byte(`"gpt-4"`), &m); err != nil {
t.Fatalf("unmarshal string: %v", err)
}
if m.Primary != "gpt-4" {
t.Errorf("Primary = %q, want 'gpt-4'", m.Primary)
}
if m.Fallbacks != nil {
t.Errorf("Fallbacks = %v, want nil", m.Fallbacks)
}
}
func TestAgentModelConfig_UnmarshalObject(t *testing.T) {
var m AgentModelConfig
data := `{"primary": "claude-opus", "fallbacks": ["gpt-4o-mini", "haiku"]}`
if err := json.Unmarshal([]byte(data), &m); err != nil {
t.Fatalf("unmarshal object: %v", err)
}
if m.Primary != "claude-opus" {
t.Errorf("Primary = %q, want 'claude-opus'", m.Primary)
}
if len(m.Fallbacks) != 2 {
t.Fatalf("Fallbacks len = %d, want 2", len(m.Fallbacks))
}
if m.Fallbacks[0] != "gpt-4o-mini" || m.Fallbacks[1] != "haiku" {
t.Errorf("Fallbacks = %v", m.Fallbacks)
}
}
func TestAgentModelConfig_MarshalString(t *testing.T) {
m := AgentModelConfig{Primary: "gpt-4"}
data, err := json.Marshal(m)
if err != nil {
t.Fatalf("marshal: %v", err)
}
if string(data) != `"gpt-4"` {
t.Errorf("marshal = %s, want '\"gpt-4\"'", string(data))
}
}
func TestAgentModelConfig_MarshalObject(t *testing.T) {
m := AgentModelConfig{Primary: "claude-opus", Fallbacks: []string{"haiku"}}
data, err := json.Marshal(m)
if err != nil {
t.Fatalf("marshal: %v", err)
}
var result map[string]interface{}
json.Unmarshal(data, &result)
if result["primary"] != "claude-opus" {
t.Errorf("primary = %v", result["primary"])
}
}
func TestAgentConfig_FullParse(t *testing.T) {
jsonData := `{
"agents": {
"defaults": {
"workspace": "~/.picoclaw/workspace",
"model": "glm-4.7",
"max_tokens": 8192,
"max_tool_iterations": 20
},
"list": [
{
"id": "sales",
"default": true,
"name": "Sales Bot",
"model": "gpt-4"
},
{
"id": "support",
"name": "Support Bot",
"model": {
"primary": "claude-opus",
"fallbacks": ["haiku"]
},
"subagents": {
"allow_agents": ["sales"]
}
}
]
},
"bindings": [
{
"agent_id": "support",
"match": {
"channel": "telegram",
"account_id": "*",
"peer": {"kind": "direct", "id": "user123"}
}
}
],
"session": {
"dm_scope": "per-peer",
"identity_links": {
"john": ["telegram:123", "discord:john#1234"]
}
}
}`
cfg := DefaultConfig()
if err := json.Unmarshal([]byte(jsonData), cfg); err != nil {
t.Fatalf("unmarshal: %v", err)
}
if len(cfg.Agents.List) != 2 {
t.Fatalf("agents.list len = %d, want 2", len(cfg.Agents.List))
}
sales := cfg.Agents.List[0]
if sales.ID != "sales" || !sales.Default || sales.Name != "Sales Bot" {
t.Errorf("sales = %+v", sales)
}
if sales.Model == nil || sales.Model.Primary != "gpt-4" {
t.Errorf("sales.Model = %+v", sales.Model)
}
support := cfg.Agents.List[1]
if support.ID != "support" || support.Name != "Support Bot" {
t.Errorf("support = %+v", support)
}
if support.Model == nil || support.Model.Primary != "claude-opus" {
t.Errorf("support.Model = %+v", support.Model)
}
if len(support.Model.Fallbacks) != 1 || support.Model.Fallbacks[0] != "haiku" {
t.Errorf("support.Model.Fallbacks = %v", support.Model.Fallbacks)
}
if support.Subagents == nil || len(support.Subagents.AllowAgents) != 1 {
t.Errorf("support.Subagents = %+v", support.Subagents)
}
if len(cfg.Bindings) != 1 {
t.Fatalf("bindings len = %d, want 1", len(cfg.Bindings))
}
binding := cfg.Bindings[0]
if binding.AgentID != "support" || binding.Match.Channel != "telegram" {
t.Errorf("binding = %+v", binding)
}
if binding.Match.Peer == nil || binding.Match.Peer.Kind != "direct" || binding.Match.Peer.ID != "user123" {
t.Errorf("binding.Match.Peer = %+v", binding.Match.Peer)
}
if cfg.Session.DMScope != "per-peer" {
t.Errorf("Session.DMScope = %q", cfg.Session.DMScope)
}
if len(cfg.Session.IdentityLinks) != 1 {
t.Errorf("Session.IdentityLinks = %v", cfg.Session.IdentityLinks)
}
links := cfg.Session.IdentityLinks["john"]
if len(links) != 2 {
t.Errorf("john links = %v", links)
}
}
func TestConfig_BackwardCompat_NoAgentsList(t *testing.T) {
jsonData := `{
"agents": {
"defaults": {
"workspace": "~/.picoclaw/workspace",
"model": "glm-4.7",
"max_tokens": 8192,
"max_tool_iterations": 20
}
}
}`
cfg := DefaultConfig()
if err := json.Unmarshal([]byte(jsonData), cfg); err != nil {
t.Fatalf("unmarshal: %v", err)
}
if len(cfg.Agents.List) != 0 {
t.Errorf("agents.list should be empty for backward compat, got %d", len(cfg.Agents.List))
}
if len(cfg.Bindings) != 0 {
t.Errorf("bindings should be empty, got %d", len(cfg.Bindings))
}
}
// TestDefaultConfig_HeartbeatEnabled verifies heartbeat is enabled by default
func TestDefaultConfig_HeartbeatEnabled(t *testing.T) {
cfg := DefaultConfig()
@@ -20,8 +201,6 @@ func TestDefaultConfig_HeartbeatEnabled(t *testing.T) {
func TestDefaultConfig_WorkspacePath(t *testing.T) {
cfg := DefaultConfig()
// Just verify the workspace is set, don't compare exact paths
// since expandHome behavior may differ based on environment
if cfg.Agents.Defaults.Workspace == "" {
t.Error("Workspace should not be empty")
}
@@ -79,7 +258,6 @@ func TestDefaultConfig_Gateway(t *testing.T) {
func TestDefaultConfig_Providers(t *testing.T) {
cfg := DefaultConfig()
// Verify all providers are empty by default
if cfg.Providers.Anthropic.APIKey != "" {
t.Error("Anthropic API key should be empty by default")
}
@@ -89,46 +267,18 @@ func TestDefaultConfig_Providers(t *testing.T) {
if cfg.Providers.OpenRouter.APIKey != "" {
t.Error("OpenRouter API key should be empty by default")
}
if cfg.Providers.Groq.APIKey != "" {
t.Error("Groq API key should be empty by default")
}
if cfg.Providers.Zhipu.APIKey != "" {
t.Error("Zhipu API key should be empty by default")
}
if cfg.Providers.VLLM.APIKey != "" {
t.Error("VLLM API key should be empty by default")
}
if cfg.Providers.Gemini.APIKey != "" {
t.Error("Gemini API key should be empty by default")
}
}
// TestDefaultConfig_Channels verifies channels are disabled by default
func TestDefaultConfig_Channels(t *testing.T) {
cfg := DefaultConfig()
// Verify all channels are disabled by default
if cfg.Channels.WhatsApp.Enabled {
t.Error("WhatsApp should be disabled by default")
}
if cfg.Channels.Telegram.Enabled {
t.Error("Telegram should be disabled by default")
}
if cfg.Channels.Feishu.Enabled {
t.Error("Feishu should be disabled by default")
}
if cfg.Channels.Discord.Enabled {
t.Error("Discord should be disabled by default")
}
if cfg.Channels.MaixCam.Enabled {
t.Error("MaixCam should be disabled by default")
}
if cfg.Channels.QQ.Enabled {
t.Error("QQ should be disabled by default")
}
if cfg.Channels.DingTalk.Enabled {
t.Error("DingTalk should be disabled by default")
}
if cfg.Channels.Slack.Enabled {
t.Error("Slack should be disabled by default")
}
@@ -178,7 +328,6 @@ func TestSaveConfig_FilePermissions(t *testing.T) {
func TestConfig_Complete(t *testing.T) {
cfg := DefaultConfig()
// Verify complete config structure
if cfg.Agents.Defaults.Workspace == "" {
t.Error("Workspace should not be empty")
}
@@ -204,3 +353,42 @@ func TestConfig_Complete(t *testing.T) {
t.Error("Heartbeat should be enabled by default")
}
}
func TestDefaultConfig_OpenAIWebSearchEnabled(t *testing.T) {
cfg := DefaultConfig()
if !cfg.Providers.OpenAI.WebSearch {
t.Fatal("DefaultConfig().Providers.OpenAI.WebSearch should be true")
}
}
func TestLoadConfig_OpenAIWebSearchDefaultsTrueWhenUnset(t *testing.T) {
dir := t.TempDir()
configPath := filepath.Join(dir, "config.json")
if err := os.WriteFile(configPath, []byte(`{"providers":{"openai":{"api_base":""}}}`), 0o600); err != nil {
t.Fatalf("WriteFile() error: %v", err)
}
cfg, err := LoadConfig(configPath)
if err != nil {
t.Fatalf("LoadConfig() error: %v", err)
}
if !cfg.Providers.OpenAI.WebSearch {
t.Fatal("OpenAI codex web search should remain true when unset in config file")
}
}
func TestLoadConfig_OpenAIWebSearchCanBeDisabled(t *testing.T) {
dir := t.TempDir()
configPath := filepath.Join(dir, "config.json")
if err := os.WriteFile(configPath, []byte(`{"providers":{"openai":{"web_search":false}}}`), 0o600); err != nil {
t.Fatalf("WriteFile() error: %v", err)
}
cfg, err := LoadConfig(configPath)
if err != nil {
t.Fatalf("LoadConfig() error: %v", err)
}
if cfg.Providers.OpenAI.WebSearch {
t.Fatal("OpenAI codex web search should be false when disabled in config file")
}
}
+7 -6
View File
@@ -1,15 +1,16 @@
// Package constants provides shared constants across the codebase.
package constants
// InternalChannels defines channels that are used for internal communication
// internalChannels defines channels that are used for internal communication
// and should not be exposed to external users or recorded as last active channel.
var InternalChannels = map[string]bool{
"cli": true,
"system": true,
"subagent": true,
var internalChannels = map[string]struct{}{
"cli": {},
"system": {},
"subagent": {},
}
// IsInternalChannel returns true if the channel is an internal channel.
func IsInternalChannel(channel string) bool {
return InternalChannels[channel]
_, found := internalChannels[channel]
return found
}
+11 -1
View File
@@ -110,7 +110,10 @@ func ConvertConfig(data map[string]interface{}) (*config.Config, []string, error
case "anthropic":
cfg.Providers.Anthropic = pc
case "openai":
cfg.Providers.OpenAI = pc
cfg.Providers.OpenAI = config.OpenAIProviderConfig{
ProviderConfig: pc,
WebSearch: getBoolOrDefault(pMap, "web_search", true),
}
case "openrouter":
cfg.Providers.OpenRouter = pc
case "groq":
@@ -374,6 +377,13 @@ func getBool(data map[string]interface{}, key string) (bool, bool) {
return b, ok
}
func getBoolOrDefault(data map[string]interface{}, key string, defaultVal bool) bool {
if v, ok := getBool(data, key); ok {
return v
}
return defaultVal
}
func getStringSlice(data map[string]interface{}, key string) []string {
v, ok := data[key]
if !ok {
+18
View File
@@ -299,6 +299,24 @@ func TestConvertConfig(t *testing.T) {
})
}
func TestSupportedProvidersCompatibility(t *testing.T) {
expected := []string{
"anthropic",
"openai",
"openrouter",
"groq",
"zhipu",
"vllm",
"gemini",
}
for _, provider := range expected {
if !supportedProviders[provider] {
t.Fatalf("supportedProviders missing expected key %q", provider)
}
}
}
func TestMergeConfig(t *testing.T) {
t.Run("fills empty fields", func(t *testing.T) {
existing := config.DefaultConfig()
+248
View File
@@ -0,0 +1,248 @@
package anthropicprovider
import (
"context"
"encoding/json"
"fmt"
"log"
"strings"
"github.com/anthropics/anthropic-sdk-go"
"github.com/anthropics/anthropic-sdk-go/option"
"github.com/sipeed/picoclaw/pkg/providers/protocoltypes"
)
type ToolCall = protocoltypes.ToolCall
type FunctionCall = protocoltypes.FunctionCall
type LLMResponse = protocoltypes.LLMResponse
type UsageInfo = protocoltypes.UsageInfo
type Message = protocoltypes.Message
type ToolDefinition = protocoltypes.ToolDefinition
type ToolFunctionDefinition = protocoltypes.ToolFunctionDefinition
const defaultBaseURL = "https://api.anthropic.com"
type Provider struct {
client *anthropic.Client
tokenSource func() (string, error)
baseURL string
}
func NewProvider(token string) *Provider {
return NewProviderWithBaseURL(token, "")
}
func NewProviderWithBaseURL(token, apiBase string) *Provider {
baseURL := normalizeBaseURL(apiBase)
client := anthropic.NewClient(
option.WithAuthToken(token),
option.WithBaseURL(baseURL),
)
return &Provider{
client: &client,
baseURL: baseURL,
}
}
func NewProviderWithClient(client *anthropic.Client) *Provider {
return &Provider{
client: client,
baseURL: defaultBaseURL,
}
}
func NewProviderWithTokenSource(token string, tokenSource func() (string, error)) *Provider {
return NewProviderWithTokenSourceAndBaseURL(token, tokenSource, "")
}
func NewProviderWithTokenSourceAndBaseURL(token string, tokenSource func() (string, error), apiBase string) *Provider {
p := NewProviderWithBaseURL(token, apiBase)
p.tokenSource = tokenSource
return p
}
func (p *Provider) Chat(ctx context.Context, messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) (*LLMResponse, error) {
var opts []option.RequestOption
if p.tokenSource != nil {
tok, err := p.tokenSource()
if err != nil {
return nil, fmt.Errorf("refreshing token: %w", err)
}
opts = append(opts, option.WithAuthToken(tok))
}
params, err := buildParams(messages, tools, model, options)
if err != nil {
return nil, err
}
resp, err := p.client.Messages.New(ctx, params, opts...)
if err != nil {
return nil, fmt.Errorf("claude API call: %w", err)
}
return parseResponse(resp), nil
}
func (p *Provider) GetDefaultModel() string {
return "claude-sonnet-4-5-20250929"
}
func (p *Provider) BaseURL() string {
return p.baseURL
}
func buildParams(messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) (anthropic.MessageNewParams, error) {
var system []anthropic.TextBlockParam
var anthropicMessages []anthropic.MessageParam
for _, msg := range messages {
switch msg.Role {
case "system":
system = append(system, anthropic.TextBlockParam{Text: msg.Content})
case "user":
if msg.ToolCallID != "" {
anthropicMessages = append(anthropicMessages,
anthropic.NewUserMessage(anthropic.NewToolResultBlock(msg.ToolCallID, msg.Content, false)),
)
} else {
anthropicMessages = append(anthropicMessages,
anthropic.NewUserMessage(anthropic.NewTextBlock(msg.Content)),
)
}
case "assistant":
if len(msg.ToolCalls) > 0 {
var blocks []anthropic.ContentBlockParamUnion
if msg.Content != "" {
blocks = append(blocks, anthropic.NewTextBlock(msg.Content))
}
for _, tc := range msg.ToolCalls {
blocks = append(blocks, anthropic.NewToolUseBlock(tc.ID, tc.Arguments, tc.Name))
}
anthropicMessages = append(anthropicMessages, anthropic.NewAssistantMessage(blocks...))
} else {
anthropicMessages = append(anthropicMessages,
anthropic.NewAssistantMessage(anthropic.NewTextBlock(msg.Content)),
)
}
case "tool":
anthropicMessages = append(anthropicMessages,
anthropic.NewUserMessage(anthropic.NewToolResultBlock(msg.ToolCallID, msg.Content, false)),
)
}
}
maxTokens := int64(4096)
if mt, ok := options["max_tokens"].(int); ok {
maxTokens = int64(mt)
}
params := anthropic.MessageNewParams{
Model: anthropic.Model(model),
Messages: anthropicMessages,
MaxTokens: maxTokens,
}
if len(system) > 0 {
params.System = system
}
if temp, ok := options["temperature"].(float64); ok {
params.Temperature = anthropic.Float(temp)
}
if len(tools) > 0 {
params.Tools = translateTools(tools)
}
return params, nil
}
func translateTools(tools []ToolDefinition) []anthropic.ToolUnionParam {
result := make([]anthropic.ToolUnionParam, 0, len(tools))
for _, t := range tools {
tool := anthropic.ToolParam{
Name: t.Function.Name,
InputSchema: anthropic.ToolInputSchemaParam{
Properties: t.Function.Parameters["properties"],
},
}
if desc := t.Function.Description; desc != "" {
tool.Description = anthropic.String(desc)
}
if req, ok := t.Function.Parameters["required"].([]interface{}); ok {
required := make([]string, 0, len(req))
for _, r := range req {
if s, ok := r.(string); ok {
required = append(required, s)
}
}
tool.InputSchema.Required = required
}
result = append(result, anthropic.ToolUnionParam{OfTool: &tool})
}
return result
}
func parseResponse(resp *anthropic.Message) *LLMResponse {
var content string
var toolCalls []ToolCall
for _, block := range resp.Content {
switch block.Type {
case "text":
tb := block.AsText()
content += tb.Text
case "tool_use":
tu := block.AsToolUse()
var args map[string]interface{}
if err := json.Unmarshal(tu.Input, &args); err != nil {
log.Printf("anthropic: failed to decode tool call input for %q: %v", tu.Name, err)
args = map[string]interface{}{"raw": string(tu.Input)}
}
toolCalls = append(toolCalls, ToolCall{
ID: tu.ID,
Name: tu.Name,
Arguments: args,
})
}
}
finishReason := "stop"
switch resp.StopReason {
case anthropic.StopReasonToolUse:
finishReason = "tool_calls"
case anthropic.StopReasonMaxTokens:
finishReason = "length"
case anthropic.StopReasonEndTurn:
finishReason = "stop"
}
return &LLMResponse{
Content: content,
ToolCalls: toolCalls,
FinishReason: finishReason,
Usage: &UsageInfo{
PromptTokens: int(resp.Usage.InputTokens),
CompletionTokens: int(resp.Usage.OutputTokens),
TotalTokens: int(resp.Usage.InputTokens + resp.Usage.OutputTokens),
},
}
}
func normalizeBaseURL(apiBase string) string {
base := strings.TrimSpace(apiBase)
if base == "" {
return defaultBaseURL
}
base = strings.TrimRight(base, "/")
if strings.HasSuffix(base, "/v1") {
base = strings.TrimSuffix(base, "/v1")
}
if base == "" {
return defaultBaseURL
}
return base
}
+265
View File
@@ -0,0 +1,265 @@
package anthropicprovider
import (
"encoding/json"
"net/http"
"net/http/httptest"
"sync/atomic"
"testing"
"github.com/anthropics/anthropic-sdk-go"
anthropicoption "github.com/anthropics/anthropic-sdk-go/option"
)
func TestBuildParams_BasicMessage(t *testing.T) {
messages := []Message{
{Role: "user", Content: "Hello"},
}
params, err := buildParams(messages, nil, "claude-sonnet-4-5-20250929", map[string]interface{}{
"max_tokens": 1024,
})
if err != nil {
t.Fatalf("buildParams() error: %v", err)
}
if string(params.Model) != "claude-sonnet-4-5-20250929" {
t.Errorf("Model = %q, want %q", params.Model, "claude-sonnet-4-5-20250929")
}
if params.MaxTokens != 1024 {
t.Errorf("MaxTokens = %d, want 1024", params.MaxTokens)
}
if len(params.Messages) != 1 {
t.Fatalf("len(Messages) = %d, want 1", len(params.Messages))
}
}
func TestBuildParams_SystemMessage(t *testing.T) {
messages := []Message{
{Role: "system", Content: "You are helpful"},
{Role: "user", Content: "Hi"},
}
params, err := buildParams(messages, nil, "claude-sonnet-4-5-20250929", map[string]interface{}{})
if err != nil {
t.Fatalf("buildParams() error: %v", err)
}
if len(params.System) != 1 {
t.Fatalf("len(System) = %d, want 1", len(params.System))
}
if params.System[0].Text != "You are helpful" {
t.Errorf("System[0].Text = %q, want %q", params.System[0].Text, "You are helpful")
}
if len(params.Messages) != 1 {
t.Fatalf("len(Messages) = %d, want 1", len(params.Messages))
}
}
func TestBuildParams_ToolCallMessage(t *testing.T) {
messages := []Message{
{Role: "user", Content: "What's the weather?"},
{
Role: "assistant",
Content: "",
ToolCalls: []ToolCall{
{
ID: "call_1",
Name: "get_weather",
Arguments: map[string]interface{}{"city": "SF"},
},
},
},
{Role: "tool", Content: `{"temp": 72}`, ToolCallID: "call_1"},
}
params, err := buildParams(messages, nil, "claude-sonnet-4-5-20250929", map[string]interface{}{})
if err != nil {
t.Fatalf("buildParams() error: %v", err)
}
if len(params.Messages) != 3 {
t.Fatalf("len(Messages) = %d, want 3", len(params.Messages))
}
}
func TestBuildParams_WithTools(t *testing.T) {
tools := []ToolDefinition{
{
Type: "function",
Function: ToolFunctionDefinition{
Name: "get_weather",
Description: "Get weather for a city",
Parameters: map[string]interface{}{
"type": "object",
"properties": map[string]interface{}{
"city": map[string]interface{}{"type": "string"},
},
"required": []interface{}{"city"},
},
},
},
}
params, err := buildParams([]Message{{Role: "user", Content: "Hi"}}, tools, "claude-sonnet-4-5-20250929", map[string]interface{}{})
if err != nil {
t.Fatalf("buildParams() error: %v", err)
}
if len(params.Tools) != 1 {
t.Fatalf("len(Tools) = %d, want 1", len(params.Tools))
}
}
func TestParseResponse_TextOnly(t *testing.T) {
resp := &anthropic.Message{
Content: []anthropic.ContentBlockUnion{},
Usage: anthropic.Usage{
InputTokens: 10,
OutputTokens: 20,
},
}
result := parseResponse(resp)
if result.Usage.PromptTokens != 10 {
t.Errorf("PromptTokens = %d, want 10", result.Usage.PromptTokens)
}
if result.Usage.CompletionTokens != 20 {
t.Errorf("CompletionTokens = %d, want 20", result.Usage.CompletionTokens)
}
if result.FinishReason != "stop" {
t.Errorf("FinishReason = %q, want %q", result.FinishReason, "stop")
}
}
func TestParseResponse_StopReasons(t *testing.T) {
tests := []struct {
stopReason anthropic.StopReason
want string
}{
{anthropic.StopReasonEndTurn, "stop"},
{anthropic.StopReasonMaxTokens, "length"},
{anthropic.StopReasonToolUse, "tool_calls"},
}
for _, tt := range tests {
resp := &anthropic.Message{
StopReason: tt.stopReason,
}
result := parseResponse(resp)
if result.FinishReason != tt.want {
t.Errorf("StopReason %q: FinishReason = %q, want %q", tt.stopReason, result.FinishReason, tt.want)
}
}
}
func TestProvider_ChatRoundTrip(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/v1/messages" {
http.Error(w, "not found", http.StatusNotFound)
return
}
if r.Header.Get("Authorization") != "Bearer test-token" {
http.Error(w, "unauthorized", http.StatusUnauthorized)
return
}
var reqBody map[string]interface{}
json.NewDecoder(r.Body).Decode(&reqBody)
resp := map[string]interface{}{
"id": "msg_test",
"type": "message",
"role": "assistant",
"model": reqBody["model"],
"stop_reason": "end_turn",
"content": []map[string]interface{}{
{"type": "text", "text": "Hello! How can I help you?"},
},
"usage": map[string]interface{}{
"input_tokens": 15,
"output_tokens": 8,
},
}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(resp)
}))
defer server.Close()
provider := NewProviderWithClient(createAnthropicTestClient(server.URL, "test-token"))
messages := []Message{{Role: "user", Content: "Hello"}}
resp, err := provider.Chat(t.Context(), messages, nil, "claude-sonnet-4-5-20250929", map[string]interface{}{"max_tokens": 1024})
if err != nil {
t.Fatalf("Chat() error: %v", err)
}
if resp.Content != "Hello! How can I help you?" {
t.Errorf("Content = %q, want %q", resp.Content, "Hello! How can I help you?")
}
if resp.FinishReason != "stop" {
t.Errorf("FinishReason = %q, want %q", resp.FinishReason, "stop")
}
if resp.Usage.PromptTokens != 15 {
t.Errorf("PromptTokens = %d, want 15", resp.Usage.PromptTokens)
}
}
func TestProvider_GetDefaultModel(t *testing.T) {
p := NewProvider("test-token")
if got := p.GetDefaultModel(); got != "claude-sonnet-4-5-20250929" {
t.Errorf("GetDefaultModel() = %q, want %q", got, "claude-sonnet-4-5-20250929")
}
}
func TestProvider_NewProviderWithBaseURL_NormalizesV1Suffix(t *testing.T) {
p := NewProviderWithBaseURL("token", "https://api.anthropic.com/v1/")
if got := p.BaseURL(); got != "https://api.anthropic.com" {
t.Fatalf("BaseURL() = %q, want %q", got, "https://api.anthropic.com")
}
}
func TestProvider_ChatUsesTokenSource(t *testing.T) {
var requests int32
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/v1/messages" {
http.Error(w, "not found", http.StatusNotFound)
return
}
atomic.AddInt32(&requests, 1)
if got := r.Header.Get("Authorization"); got != "Bearer refreshed-token" {
http.Error(w, "unauthorized", http.StatusUnauthorized)
return
}
var reqBody map[string]interface{}
json.NewDecoder(r.Body).Decode(&reqBody)
resp := map[string]interface{}{
"id": "msg_test",
"type": "message",
"role": "assistant",
"model": reqBody["model"],
"stop_reason": "end_turn",
"content": []map[string]interface{}{
{"type": "text", "text": "ok"},
},
"usage": map[string]interface{}{
"input_tokens": 1,
"output_tokens": 1,
},
}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(resp)
}))
defer server.Close()
p := NewProviderWithTokenSourceAndBaseURL("stale-token", func() (string, error) {
return "refreshed-token", nil
}, server.URL)
_, err := p.Chat(t.Context(), []Message{{Role: "user", Content: "hello"}}, nil, "claude-sonnet-4-5-20250929", map[string]interface{}{})
if err != nil {
t.Fatalf("Chat() error: %v", err)
}
if got := atomic.LoadInt32(&requests); got != 1 {
t.Fatalf("requests = %d, want 1", got)
}
}
func createAnthropicTestClient(baseURL, token string) *anthropic.Client {
c := anthropic.NewClient(
anthropicoption.WithAuthToken(token),
anthropicoption.WithBaseURL(baseURL),
)
return &c
}
+28 -170
View File
@@ -2,200 +2,58 @@ package providers
import (
"context"
"encoding/json"
"fmt"
"github.com/anthropics/anthropic-sdk-go"
"github.com/anthropics/anthropic-sdk-go/option"
"github.com/sipeed/picoclaw/pkg/auth"
anthropicprovider "github.com/sipeed/picoclaw/pkg/providers/anthropic"
)
type ClaudeProvider struct {
client *anthropic.Client
tokenSource func() (string, error)
delegate *anthropicprovider.Provider
}
func NewClaudeProvider(token string) *ClaudeProvider {
client := anthropic.NewClient(
option.WithAuthToken(token),
option.WithBaseURL("https://api.anthropic.com"),
)
return &ClaudeProvider{client: &client}
return &ClaudeProvider{
delegate: anthropicprovider.NewProvider(token),
}
}
func NewClaudeProviderWithBaseURL(token, apiBase string) *ClaudeProvider {
return &ClaudeProvider{
delegate: anthropicprovider.NewProviderWithBaseURL(token, apiBase),
}
}
func NewClaudeProviderWithTokenSource(token string, tokenSource func() (string, error)) *ClaudeProvider {
p := NewClaudeProvider(token)
p.tokenSource = tokenSource
return p
return &ClaudeProvider{
delegate: anthropicprovider.NewProviderWithTokenSource(token, tokenSource),
}
}
func NewClaudeProviderWithTokenSourceAndBaseURL(token string, tokenSource func() (string, error), apiBase string) *ClaudeProvider {
return &ClaudeProvider{
delegate: anthropicprovider.NewProviderWithTokenSourceAndBaseURL(token, tokenSource, apiBase),
}
}
func newClaudeProviderWithDelegate(delegate *anthropicprovider.Provider) *ClaudeProvider {
return &ClaudeProvider{delegate: delegate}
}
func (p *ClaudeProvider) Chat(ctx context.Context, messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) (*LLMResponse, error) {
var opts []option.RequestOption
if p.tokenSource != nil {
tok, err := p.tokenSource()
if err != nil {
return nil, fmt.Errorf("refreshing token: %w", err)
}
opts = append(opts, option.WithAuthToken(tok))
}
params, err := buildClaudeParams(messages, tools, model, options)
resp, err := p.delegate.Chat(ctx, messages, tools, model, options)
if err != nil {
return nil, err
}
resp, err := p.client.Messages.New(ctx, params, opts...)
if err != nil {
return nil, fmt.Errorf("claude API call: %w", err)
}
return parseClaudeResponse(resp), nil
return resp, nil
}
func (p *ClaudeProvider) GetDefaultModel() string {
return "claude-sonnet-4-5-20250929"
}
func buildClaudeParams(messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) (anthropic.MessageNewParams, error) {
var system []anthropic.TextBlockParam
var anthropicMessages []anthropic.MessageParam
for _, msg := range messages {
switch msg.Role {
case "system":
system = append(system, anthropic.TextBlockParam{Text: msg.Content})
case "user":
if msg.ToolCallID != "" {
anthropicMessages = append(anthropicMessages,
anthropic.NewUserMessage(anthropic.NewToolResultBlock(msg.ToolCallID, msg.Content, false)),
)
} else {
anthropicMessages = append(anthropicMessages,
anthropic.NewUserMessage(anthropic.NewTextBlock(msg.Content)),
)
}
case "assistant":
if len(msg.ToolCalls) > 0 {
var blocks []anthropic.ContentBlockParamUnion
if msg.Content != "" {
blocks = append(blocks, anthropic.NewTextBlock(msg.Content))
}
for _, tc := range msg.ToolCalls {
blocks = append(blocks, anthropic.NewToolUseBlock(tc.ID, tc.Arguments, tc.Name))
}
anthropicMessages = append(anthropicMessages, anthropic.NewAssistantMessage(blocks...))
} else {
anthropicMessages = append(anthropicMessages,
anthropic.NewAssistantMessage(anthropic.NewTextBlock(msg.Content)),
)
}
case "tool":
anthropicMessages = append(anthropicMessages,
anthropic.NewUserMessage(anthropic.NewToolResultBlock(msg.ToolCallID, msg.Content, false)),
)
}
}
maxTokens := int64(4096)
if mt, ok := options["max_tokens"].(int); ok {
maxTokens = int64(mt)
}
params := anthropic.MessageNewParams{
Model: anthropic.Model(model),
Messages: anthropicMessages,
MaxTokens: maxTokens,
}
if len(system) > 0 {
params.System = system
}
if temp, ok := options["temperature"].(float64); ok {
params.Temperature = anthropic.Float(temp)
}
if len(tools) > 0 {
params.Tools = translateToolsForClaude(tools)
}
return params, nil
}
func translateToolsForClaude(tools []ToolDefinition) []anthropic.ToolUnionParam {
result := make([]anthropic.ToolUnionParam, 0, len(tools))
for _, t := range tools {
tool := anthropic.ToolParam{
Name: t.Function.Name,
InputSchema: anthropic.ToolInputSchemaParam{
Properties: t.Function.Parameters["properties"],
},
}
if desc := t.Function.Description; desc != "" {
tool.Description = anthropic.String(desc)
}
if req, ok := t.Function.Parameters["required"].([]interface{}); ok {
required := make([]string, 0, len(req))
for _, r := range req {
if s, ok := r.(string); ok {
required = append(required, s)
}
}
tool.InputSchema.Required = required
}
result = append(result, anthropic.ToolUnionParam{OfTool: &tool})
}
return result
}
func parseClaudeResponse(resp *anthropic.Message) *LLMResponse {
var content string
var toolCalls []ToolCall
for _, block := range resp.Content {
switch block.Type {
case "text":
tb := block.AsText()
content += tb.Text
case "tool_use":
tu := block.AsToolUse()
var args map[string]interface{}
if err := json.Unmarshal(tu.Input, &args); err != nil {
args = map[string]interface{}{"raw": string(tu.Input)}
}
toolCalls = append(toolCalls, ToolCall{
ID: tu.ID,
Name: tu.Name,
Arguments: args,
})
}
}
finishReason := "stop"
switch resp.StopReason {
case anthropic.StopReasonToolUse:
finishReason = "tool_calls"
case anthropic.StopReasonMaxTokens:
finishReason = "length"
case anthropic.StopReasonEndTurn:
finishReason = "stop"
}
return &LLMResponse{
Content: content,
ToolCalls: toolCalls,
FinishReason: finishReason,
Usage: &UsageInfo{
PromptTokens: int(resp.Usage.InputTokens),
CompletionTokens: int(resp.Usage.OutputTokens),
TotalTokens: int(resp.Usage.InputTokens + resp.Usage.OutputTokens),
},
}
return p.delegate.GetDefaultModel()
}
func createClaudeTokenSource() func() (string, error) {
return func() (string, error) {
cred, err := auth.GetCredential("anthropic")
cred, err := getCredential("anthropic")
if err != nil {
return "", fmt.Errorf("loading auth credentials: %w", err)
}
+3 -134
View File
@@ -8,140 +8,9 @@ import (
"github.com/anthropics/anthropic-sdk-go"
anthropicoption "github.com/anthropics/anthropic-sdk-go/option"
anthropicprovider "github.com/sipeed/picoclaw/pkg/providers/anthropic"
)
func TestBuildClaudeParams_BasicMessage(t *testing.T) {
messages := []Message{
{Role: "user", Content: "Hello"},
}
params, err := buildClaudeParams(messages, nil, "claude-sonnet-4-5-20250929", map[string]interface{}{
"max_tokens": 1024,
})
if err != nil {
t.Fatalf("buildClaudeParams() error: %v", err)
}
if string(params.Model) != "claude-sonnet-4-5-20250929" {
t.Errorf("Model = %q, want %q", params.Model, "claude-sonnet-4-5-20250929")
}
if params.MaxTokens != 1024 {
t.Errorf("MaxTokens = %d, want 1024", params.MaxTokens)
}
if len(params.Messages) != 1 {
t.Fatalf("len(Messages) = %d, want 1", len(params.Messages))
}
}
func TestBuildClaudeParams_SystemMessage(t *testing.T) {
messages := []Message{
{Role: "system", Content: "You are helpful"},
{Role: "user", Content: "Hi"},
}
params, err := buildClaudeParams(messages, nil, "claude-sonnet-4-5-20250929", map[string]interface{}{})
if err != nil {
t.Fatalf("buildClaudeParams() error: %v", err)
}
if len(params.System) != 1 {
t.Fatalf("len(System) = %d, want 1", len(params.System))
}
if params.System[0].Text != "You are helpful" {
t.Errorf("System[0].Text = %q, want %q", params.System[0].Text, "You are helpful")
}
if len(params.Messages) != 1 {
t.Fatalf("len(Messages) = %d, want 1", len(params.Messages))
}
}
func TestBuildClaudeParams_ToolCallMessage(t *testing.T) {
messages := []Message{
{Role: "user", Content: "What's the weather?"},
{
Role: "assistant",
Content: "",
ToolCalls: []ToolCall{
{
ID: "call_1",
Name: "get_weather",
Arguments: map[string]interface{}{"city": "SF"},
},
},
},
{Role: "tool", Content: `{"temp": 72}`, ToolCallID: "call_1"},
}
params, err := buildClaudeParams(messages, nil, "claude-sonnet-4-5-20250929", map[string]interface{}{})
if err != nil {
t.Fatalf("buildClaudeParams() error: %v", err)
}
if len(params.Messages) != 3 {
t.Fatalf("len(Messages) = %d, want 3", len(params.Messages))
}
}
func TestBuildClaudeParams_WithTools(t *testing.T) {
tools := []ToolDefinition{
{
Type: "function",
Function: ToolFunctionDefinition{
Name: "get_weather",
Description: "Get weather for a city",
Parameters: map[string]interface{}{
"type": "object",
"properties": map[string]interface{}{
"city": map[string]interface{}{"type": "string"},
},
"required": []interface{}{"city"},
},
},
},
}
params, err := buildClaudeParams([]Message{{Role: "user", Content: "Hi"}}, tools, "claude-sonnet-4-5-20250929", map[string]interface{}{})
if err != nil {
t.Fatalf("buildClaudeParams() error: %v", err)
}
if len(params.Tools) != 1 {
t.Fatalf("len(Tools) = %d, want 1", len(params.Tools))
}
}
func TestParseClaudeResponse_TextOnly(t *testing.T) {
resp := &anthropic.Message{
Content: []anthropic.ContentBlockUnion{},
Usage: anthropic.Usage{
InputTokens: 10,
OutputTokens: 20,
},
}
result := parseClaudeResponse(resp)
if result.Usage.PromptTokens != 10 {
t.Errorf("PromptTokens = %d, want 10", result.Usage.PromptTokens)
}
if result.Usage.CompletionTokens != 20 {
t.Errorf("CompletionTokens = %d, want 20", result.Usage.CompletionTokens)
}
if result.FinishReason != "stop" {
t.Errorf("FinishReason = %q, want %q", result.FinishReason, "stop")
}
}
func TestParseClaudeResponse_StopReasons(t *testing.T) {
tests := []struct {
stopReason anthropic.StopReason
want string
}{
{anthropic.StopReasonEndTurn, "stop"},
{anthropic.StopReasonMaxTokens, "length"},
{anthropic.StopReasonToolUse, "tool_calls"},
}
for _, tt := range tests {
resp := &anthropic.Message{
StopReason: tt.stopReason,
}
result := parseClaudeResponse(resp)
if result.FinishReason != tt.want {
t.Errorf("StopReason %q: FinishReason = %q, want %q", tt.stopReason, result.FinishReason, tt.want)
}
}
}
func TestClaudeProvider_ChatRoundTrip(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/v1/messages" {
@@ -175,8 +44,8 @@ func TestClaudeProvider_ChatRoundTrip(t *testing.T) {
}))
defer server.Close()
provider := NewClaudeProvider("test-token")
provider.client = createAnthropicTestClient(server.URL, "test-token")
delegate := anthropicprovider.NewProviderWithClient(createAnthropicTestClient(server.URL, "test-token"))
provider := newClaudeProviderWithDelegate(delegate)
messages := []Message{{Role: "user", Content: "Hello"}}
resp, err := provider.Chat(t.Context(), messages, nil, "claude-sonnet-4-5-20250929", map[string]interface{}{"max_tokens": 1024})
@@ -0,0 +1,119 @@
//go:build integration
package providers
import (
"context"
exec "os/exec"
"strings"
"testing"
"time"
)
// TestIntegration_RealCodexCLI tests the CodexCliProvider with a real codex CLI.
// Run with: go test -tags=integration ./pkg/providers/...
func TestIntegration_RealCodexCLI(t *testing.T) {
path, err := exec.LookPath("codex")
if err != nil {
t.Skip("codex CLI not found in PATH, skipping integration test")
}
t.Logf("Using codex CLI at: %s", path)
p := NewCodexCliProvider(t.TempDir())
ctx, cancel := context.WithTimeout(context.Background(), 90*time.Second)
defer cancel()
resp, err := p.Chat(ctx, []Message{
{Role: "user", Content: "Respond with only the word 'pong'. Nothing else."},
}, nil, "", nil)
if err != nil {
t.Fatalf("Chat() with real CLI error = %v", err)
}
if resp.Content == "" {
t.Error("Content is empty")
}
if resp.FinishReason != "stop" {
t.Errorf("FinishReason = %q, want %q", resp.FinishReason, "stop")
}
if resp.Usage != nil {
t.Logf("Usage: prompt=%d, completion=%d, total=%d",
resp.Usage.PromptTokens, resp.Usage.CompletionTokens, resp.Usage.TotalTokens)
}
t.Logf("Response content: %q", resp.Content)
if !strings.Contains(strings.ToLower(resp.Content), "pong") {
t.Errorf("Content = %q, expected to contain 'pong'", resp.Content)
}
}
func TestIntegration_RealCodexCLI_WithSystemPrompt(t *testing.T) {
if _, err := exec.LookPath("codex"); err != nil {
t.Skip("codex CLI not found in PATH")
}
p := NewCodexCliProvider(t.TempDir())
ctx, cancel := context.WithTimeout(context.Background(), 90*time.Second)
defer cancel()
resp, err := p.Chat(ctx, []Message{
{Role: "system", Content: "You are a calculator. Only respond with numbers. No text."},
{Role: "user", Content: "What is 2+2?"},
}, nil, "", nil)
if err != nil {
t.Fatalf("Chat() error = %v", err)
}
t.Logf("Response: %q", resp.Content)
if !strings.Contains(resp.Content, "4") {
t.Errorf("Content = %q, expected to contain '4'", resp.Content)
}
}
func TestIntegration_RealCodexCLI_ParsesRealJSONL(t *testing.T) {
if _, err := exec.LookPath("codex"); err != nil {
t.Skip("codex CLI not found in PATH")
}
// Run codex directly and verify our parser handles real output
cmd := exec.Command("codex", "exec",
"--json",
"--dangerously-bypass-approvals-and-sandbox",
"--skip-git-repo-check",
"--color", "never",
"-C", t.TempDir(),
"-")
cmd.Stdin = strings.NewReader("Say hi")
output, err := cmd.Output()
if err != nil {
// codex may write diagnostic noise to stderr but still produce valid output
if len(output) == 0 {
t.Fatalf("codex CLI failed: %v", err)
}
}
t.Logf("Raw CLI output (first 500 chars): %s", string(output[:min(len(output), 500)]))
// Verify our parser can handle real output
p := NewCodexCliProvider("")
resp, err := p.parseJSONLEvents(string(output))
if err != nil {
t.Fatalf("parseJSONLEvents() failed on real CLI output: %v", err)
}
if resp.Content == "" {
t.Error("parsed Content is empty")
}
if resp.FinishReason != "stop" {
t.Errorf("FinishReason = %q, want stop", resp.FinishReason)
}
t.Logf("Parsed: content=%q, finish=%s, usage=%+v", resp.Content, resp.FinishReason, resp.Usage)
}
+59 -18
View File
@@ -18,9 +18,10 @@ const codexDefaultModel = "gpt-5.2"
const codexDefaultInstructions = "You are Codex, a coding assistant."
type CodexProvider struct {
client *openai.Client
accountID string
tokenSource func() (string, string, error)
client *openai.Client
accountID string
tokenSource func() (string, string, error)
enableWebSearch bool
}
const defaultCodexInstructions = "You are Codex, a coding assistant."
@@ -37,8 +38,9 @@ func NewCodexProvider(token, accountID string) *CodexProvider {
}
client := openai.NewClient(opts...)
return &CodexProvider{
client: &client,
accountID: accountID,
client: &client,
accountID: accountID,
enableWebSearch: true,
}
}
@@ -78,7 +80,7 @@ func (p *CodexProvider) Chat(ctx context.Context, messages []Message, tools []To
})
}
params := buildCodexParams(messages, tools, resolvedModel, options)
params := buildCodexParams(messages, tools, resolvedModel, options, p.enableWebSearch)
stream := p.client.Responses.NewStreaming(ctx, params, opts...)
defer stream.Close()
@@ -182,7 +184,7 @@ func resolveCodexModel(model string) (string, string) {
return codexDefaultModel, "unsupported model family"
}
func buildCodexParams(messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) responses.ResponseNewParams {
func buildCodexParams(messages []Message, tools []ToolDefinition, model string, options map[string]interface{}, enableWebSearch bool) responses.ResponseNewParams {
var inputItems responses.ResponseInputParam
var instructions string
@@ -217,12 +219,18 @@ func buildCodexParams(messages []Message, tools []ToolDefinition, model string,
})
}
for _, tc := range msg.ToolCalls {
argsJSON, _ := json.Marshal(tc.Arguments)
name, args, ok := resolveCodexToolCall(tc)
if !ok {
logger.WarnCF("provider.codex", "Skipping invalid tool call in history", map[string]interface{}{
"call_id": tc.ID,
})
continue
}
inputItems = append(inputItems, responses.ResponseInputItemUnionParam{
OfFunctionCall: &responses.ResponseFunctionToolCallParam{
CallID: tc.ID,
Name: tc.Name,
Arguments: string(argsJSON),
Name: name,
Arguments: args,
},
})
}
@@ -260,20 +268,50 @@ func buildCodexParams(messages []Message, tools []ToolDefinition, model string,
params.Instructions = openai.Opt(defaultCodexInstructions)
}
if maxTokens, ok := options["max_tokens"].(int); ok {
params.MaxOutputTokens = openai.Opt(int64(maxTokens))
}
if len(tools) > 0 {
params.Tools = translateToolsForCodex(tools)
if len(tools) > 0 || enableWebSearch {
params.Tools = translateToolsForCodex(tools, enableWebSearch)
}
return params
}
func translateToolsForCodex(tools []ToolDefinition) []responses.ToolUnionParam {
result := make([]responses.ToolUnionParam, 0, len(tools))
func resolveCodexToolCall(tc ToolCall) (name string, arguments string, ok bool) {
name = tc.Name
if name == "" && tc.Function != nil {
name = tc.Function.Name
}
if name == "" {
return "", "", false
}
if len(tc.Arguments) > 0 {
argsJSON, err := json.Marshal(tc.Arguments)
if err != nil {
return "", "", false
}
return name, string(argsJSON), true
}
if tc.Function != nil && tc.Function.Arguments != "" {
return name, tc.Function.Arguments, true
}
return name, "{}", true
}
func translateToolsForCodex(tools []ToolDefinition, enableWebSearch bool) []responses.ToolUnionParam {
capHint := len(tools)
if enableWebSearch {
capHint++
}
result := make([]responses.ToolUnionParam, 0, capHint)
for _, t := range tools {
if t.Type != "function" {
continue
}
if enableWebSearch && strings.EqualFold(t.Function.Name, "web_search") {
continue
}
ft := responses.FunctionToolParam{
Name: t.Function.Name,
Parameters: t.Function.Parameters,
@@ -284,6 +322,9 @@ func translateToolsForCodex(tools []ToolDefinition) []responses.ToolUnionParam {
}
result = append(result, responses.ToolUnionParam{OfFunction: &ft})
}
if enableWebSearch {
result = append(result, responses.ToolParamOfWebSearch(responses.WebSearchToolTypeWebSearch))
}
return result
}
+172 -5
View File
@@ -19,7 +19,7 @@ func TestBuildCodexParams_BasicMessage(t *testing.T) {
params := buildCodexParams(messages, nil, "gpt-4o", map[string]interface{}{
"max_tokens": 2048,
"temperature": 0.7,
})
}, true)
if params.Model != "gpt-4o" {
t.Errorf("Model = %q, want %q", params.Model, "gpt-4o")
}
@@ -29,6 +29,9 @@ func TestBuildCodexParams_BasicMessage(t *testing.T) {
if params.Instructions.Or("") != defaultCodexInstructions {
t.Errorf("Instructions = %q, want %q", params.Instructions.Or(""), defaultCodexInstructions)
}
if params.MaxOutputTokens.Valid() {
t.Fatalf("MaxOutputTokens should not be set for Codex backend")
}
}
func TestBuildCodexParams_SystemAsInstructions(t *testing.T) {
@@ -36,7 +39,7 @@ func TestBuildCodexParams_SystemAsInstructions(t *testing.T) {
{Role: "system", Content: "You are helpful"},
{Role: "user", Content: "Hi"},
}
params := buildCodexParams(messages, nil, "gpt-4o", map[string]interface{}{})
params := buildCodexParams(messages, nil, "gpt-4o", map[string]interface{}{}, true)
if !params.Instructions.Valid() {
t.Fatal("Instructions should be set")
}
@@ -56,7 +59,7 @@ func TestBuildCodexParams_ToolCallConversation(t *testing.T) {
},
{Role: "tool", Content: `{"temp": 72}`, ToolCallID: "call_1"},
}
params := buildCodexParams(messages, nil, "gpt-4o", map[string]interface{}{})
params := buildCodexParams(messages, nil, "gpt-4o", map[string]interface{}{}, false)
if params.Input.OfInputItemList == nil {
t.Fatal("Input.OfInputItemList should not be nil")
}
@@ -65,6 +68,45 @@ func TestBuildCodexParams_ToolCallConversation(t *testing.T) {
}
}
func TestBuildCodexParams_ToolCallFunctionFallback(t *testing.T) {
messages := []Message{
{Role: "user", Content: "Read a file"},
{
Role: "assistant",
ToolCalls: []ToolCall{
{
ID: "call_1",
Type: "function",
Function: &FunctionCall{
Name: "read_file",
Arguments: `{"path":"README.md"}`,
},
},
},
},
{Role: "tool", Content: "ok", ToolCallID: "call_1"},
}
params := buildCodexParams(messages, nil, "gpt-4o", map[string]interface{}{}, false)
if params.Input.OfInputItemList == nil {
t.Fatal("Input.OfInputItemList should not be nil")
}
if len(params.Input.OfInputItemList) != 3 {
t.Fatalf("len(Input items) = %d, want 3", len(params.Input.OfInputItemList))
}
fc := params.Input.OfInputItemList[1].OfFunctionCall
if fc == nil {
t.Fatal("assistant tool call should be converted to function_call input item")
}
if fc.Name != "read_file" {
t.Errorf("Function call name = %q, want %q", fc.Name, "read_file")
}
if fc.Arguments != `{"path":"README.md"}` {
t.Errorf("Function call arguments = %q, want %q", fc.Arguments, `{"path":"README.md"}`)
}
}
func TestBuildCodexParams_WithTools(t *testing.T) {
tools := []ToolDefinition{
{
@@ -81,7 +123,7 @@ func TestBuildCodexParams_WithTools(t *testing.T) {
},
},
}
params := buildCodexParams([]Message{{Role: "user", Content: "Hi"}}, tools, "gpt-4o", map[string]interface{}{})
params := buildCodexParams([]Message{{Role: "user", Content: "Hi"}}, tools, "gpt-4o", map[string]interface{}{}, false)
if len(params.Tools) != 1 {
t.Fatalf("len(Tools) = %d, want 1", len(params.Tools))
}
@@ -94,12 +136,61 @@ func TestBuildCodexParams_WithTools(t *testing.T) {
}
func TestBuildCodexParams_StoreIsFalse(t *testing.T) {
params := buildCodexParams([]Message{{Role: "user", Content: "Hi"}}, nil, "gpt-4o", map[string]interface{}{})
params := buildCodexParams([]Message{{Role: "user", Content: "Hi"}}, nil, "gpt-4o", map[string]interface{}{}, false)
if !params.Store.Valid() || params.Store.Or(true) != false {
t.Error("Store should be explicitly set to false")
}
}
func TestBuildCodexParams_DefaultWebSearchEnabled(t *testing.T) {
params := buildCodexParams([]Message{{Role: "user", Content: "Hi"}}, nil, "gpt-4o", map[string]interface{}{}, true)
if len(params.Tools) != 1 {
t.Fatalf("len(Tools) = %d, want 1", len(params.Tools))
}
if params.Tools[0].OfWebSearch == nil {
t.Fatal("Tool should include built-in web_search")
}
if params.Tools[0].OfWebSearch.Type != responses.WebSearchToolTypeWebSearch {
t.Errorf("Web search tool type = %q, want %q", params.Tools[0].OfWebSearch.Type, responses.WebSearchToolTypeWebSearch)
}
}
func TestBuildCodexParams_WebSearchFunctionReplacedWithBuiltin(t *testing.T) {
tools := []ToolDefinition{
{
Type: "function",
Function: ToolFunctionDefinition{
Name: "web_search",
Description: "local web search",
Parameters: map[string]interface{}{
"type": "object",
},
},
},
{
Type: "function",
Function: ToolFunctionDefinition{
Name: "read_file",
Description: "read file",
Parameters: map[string]interface{}{
"type": "object",
},
},
},
}
params := buildCodexParams([]Message{{Role: "user", Content: "Hi"}}, tools, "gpt-4o", map[string]interface{}{}, true)
if len(params.Tools) != 2 {
t.Fatalf("len(Tools) = %d, want 2", len(params.Tools))
}
if params.Tools[0].OfFunction == nil || params.Tools[0].OfFunction.Name != "read_file" {
t.Fatalf("first tool should be function read_file, got %#v", params.Tools[0])
}
if params.Tools[1].OfWebSearch == nil {
t.Fatalf("second tool should be built-in web_search, got %#v", params.Tools[1])
}
}
func TestParseCodexResponse_TextOutput(t *testing.T) {
respJSON := `{
"id": "resp_test",
@@ -214,6 +305,20 @@ func TestCodexProvider_ChatRoundTrip(t *testing.T) {
http.Error(w, "stream must be true", http.StatusBadRequest)
return
}
if _, ok := reqBody["max_output_tokens"]; ok {
http.Error(w, "max_output_tokens is not supported", http.StatusBadRequest)
return
}
toolsAny, ok := reqBody["tools"].([]interface{})
if !ok || len(toolsAny) != 1 {
http.Error(w, "missing default web search tool", http.StatusBadRequest)
return
}
toolObj, ok := toolsAny[0].(map[string]interface{})
if !ok || toolObj["type"] != "web_search" {
http.Error(w, "expected web_search tool", http.StatusBadRequest)
return
}
resp := map[string]interface{}{
"id": "resp_test",
@@ -261,6 +366,64 @@ func TestCodexProvider_ChatRoundTrip(t *testing.T) {
}
}
func TestCodexProvider_ChatRoundTrip_WebSearchDisabled(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/responses" {
http.Error(w, "not found: "+r.URL.Path, http.StatusNotFound)
return
}
var reqBody map[string]interface{}
if err := json.NewDecoder(r.Body).Decode(&reqBody); err != nil {
http.Error(w, "invalid json", http.StatusBadRequest)
return
}
if _, ok := reqBody["tools"]; ok {
http.Error(w, "tools should be absent when web search disabled", http.StatusBadRequest)
return
}
resp := map[string]interface{}{
"id": "resp_test",
"object": "response",
"status": "completed",
"output": []map[string]interface{}{
{
"id": "msg_1",
"type": "message",
"role": "assistant",
"status": "completed",
"content": []map[string]interface{}{
{"type": "output_text", "text": "Hi from Codex!"},
},
},
},
"usage": map[string]interface{}{
"input_tokens": 4,
"output_tokens": 3,
"total_tokens": 7,
"input_tokens_details": map[string]interface{}{"cached_tokens": 0},
"output_tokens_details": map[string]interface{}{"reasoning_tokens": 0},
},
}
writeCompletedSSE(w, resp)
}))
defer server.Close()
provider := NewCodexProvider("test-token", "acc-123")
provider.enableWebSearch = false
provider.client = createOpenAITestClient(server.URL, "test-token", "acc-123")
messages := []Message{{Role: "user", Content: "Hello"}}
resp, err := provider.Chat(t.Context(), messages, nil, "gpt-4o", map[string]interface{}{})
if err != nil {
t.Fatalf("Chat() error: %v", err)
}
if resp.Content != "Hi from Codex!" {
t.Errorf("Content = %q, want %q", resp.Content, "Hi from Codex!")
}
}
func TestCodexProvider_ChatRoundTrip_TokenSourceFallbackAccountID(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/responses" {
@@ -293,6 +456,10 @@ func TestCodexProvider_ChatRoundTrip_TokenSourceFallbackAccountID(t *testing.T)
http.Error(w, "temperature is not supported", http.StatusBadRequest)
return
}
if _, ok := reqBody["max_output_tokens"]; ok {
http.Error(w, "max_output_tokens is not supported", http.StatusBadRequest)
return
}
if reqBody["stream"] != true {
http.Error(w, "stream must be true", http.StatusBadRequest)
return
+207
View File
@@ -0,0 +1,207 @@
package providers
import (
"math"
"sync"
"time"
)
const (
defaultFailureWindow = 24 * time.Hour
)
// CooldownTracker manages per-provider cooldown state for the fallback chain.
// Thread-safe via sync.RWMutex. In-memory only (resets on restart).
type CooldownTracker struct {
mu sync.RWMutex
entries map[string]*cooldownEntry
failureWindow time.Duration
nowFunc func() time.Time // for testing
}
type cooldownEntry struct {
ErrorCount int
FailureCounts map[FailoverReason]int
CooldownEnd time.Time // standard cooldown expiry
DisabledUntil time.Time // billing-specific disable expiry
DisabledReason FailoverReason // reason for disable (billing)
LastFailure time.Time
}
// NewCooldownTracker creates a tracker with default 24h failure window.
func NewCooldownTracker() *CooldownTracker {
return &CooldownTracker{
entries: make(map[string]*cooldownEntry),
failureWindow: defaultFailureWindow,
nowFunc: time.Now,
}
}
// MarkFailure records a failure for a provider and sets appropriate cooldown.
// Resets error counts if last failure was more than failureWindow ago.
func (ct *CooldownTracker) MarkFailure(provider string, reason FailoverReason) {
ct.mu.Lock()
defer ct.mu.Unlock()
now := ct.nowFunc()
entry := ct.getOrCreate(provider)
// 24h failure window reset: if no failure in failureWindow, reset counters.
if !entry.LastFailure.IsZero() && now.Sub(entry.LastFailure) > ct.failureWindow {
entry.ErrorCount = 0
entry.FailureCounts = make(map[FailoverReason]int)
}
entry.ErrorCount++
entry.FailureCounts[reason]++
entry.LastFailure = now
if reason == FailoverBilling {
billingCount := entry.FailureCounts[FailoverBilling]
entry.DisabledUntil = now.Add(calculateBillingCooldown(billingCount))
entry.DisabledReason = FailoverBilling
} else {
entry.CooldownEnd = now.Add(calculateStandardCooldown(entry.ErrorCount))
}
}
// MarkSuccess resets all counters and cooldowns for a provider.
func (ct *CooldownTracker) MarkSuccess(provider string) {
ct.mu.Lock()
defer ct.mu.Unlock()
entry := ct.entries[provider]
if entry == nil {
return
}
entry.ErrorCount = 0
entry.FailureCounts = make(map[FailoverReason]int)
entry.CooldownEnd = time.Time{}
entry.DisabledUntil = time.Time{}
entry.DisabledReason = ""
}
// IsAvailable returns true if the provider is not in cooldown or disabled.
func (ct *CooldownTracker) IsAvailable(provider string) bool {
ct.mu.RLock()
defer ct.mu.RUnlock()
entry := ct.entries[provider]
if entry == nil {
return true
}
now := ct.nowFunc()
// Billing disable takes precedence (longer cooldown).
if !entry.DisabledUntil.IsZero() && now.Before(entry.DisabledUntil) {
return false
}
// Standard cooldown.
if !entry.CooldownEnd.IsZero() && now.Before(entry.CooldownEnd) {
return false
}
return true
}
// CooldownRemaining returns how long until the provider becomes available.
// Returns 0 if already available.
func (ct *CooldownTracker) CooldownRemaining(provider string) time.Duration {
ct.mu.RLock()
defer ct.mu.RUnlock()
entry := ct.entries[provider]
if entry == nil {
return 0
}
now := ct.nowFunc()
var remaining time.Duration
if !entry.DisabledUntil.IsZero() && now.Before(entry.DisabledUntil) {
d := entry.DisabledUntil.Sub(now)
if d > remaining {
remaining = d
}
}
if !entry.CooldownEnd.IsZero() && now.Before(entry.CooldownEnd) {
d := entry.CooldownEnd.Sub(now)
if d > remaining {
remaining = d
}
}
return remaining
}
// ErrorCount returns the current error count for a provider.
func (ct *CooldownTracker) ErrorCount(provider string) int {
ct.mu.RLock()
defer ct.mu.RUnlock()
entry := ct.entries[provider]
if entry == nil {
return 0
}
return entry.ErrorCount
}
// FailureCount returns the failure count for a specific reason.
func (ct *CooldownTracker) FailureCount(provider string, reason FailoverReason) int {
ct.mu.RLock()
defer ct.mu.RUnlock()
entry := ct.entries[provider]
if entry == nil {
return 0
}
return entry.FailureCounts[reason]
}
func (ct *CooldownTracker) getOrCreate(provider string) *cooldownEntry {
entry := ct.entries[provider]
if entry == nil {
entry = &cooldownEntry{
FailureCounts: make(map[FailoverReason]int),
}
ct.entries[provider] = entry
}
return entry
}
// calculateStandardCooldown computes standard exponential backoff.
// Formula from OpenClaw: min(1h, 1min * 5^min(n-1, 3))
//
// 1 error → 1 min
// 2 errors → 5 min
// 3 errors → 25 min
// 4+ errors → 1 hour (cap)
func calculateStandardCooldown(errorCount int) time.Duration {
n := max(1, errorCount)
exp := min(n-1, 3)
ms := 60_000 * int(math.Pow(5, float64(exp)))
ms = min(3_600_000, ms) // cap at 1 hour
return time.Duration(ms) * time.Millisecond
}
// calculateBillingCooldown computes billing-specific exponential backoff.
// Formula from OpenClaw: min(24h, 5h * 2^min(n-1, 10))
//
// 1 error → 5 hours
// 2 errors → 10 hours
// 3 errors → 20 hours
// 4+ errors → 24 hours (cap)
func calculateBillingCooldown(billingErrorCount int) time.Duration {
const baseMs = 5 * 60 * 60 * 1000 // 5 hours
const maxMs = 24 * 60 * 60 * 1000 // 24 hours
n := max(1, billingErrorCount)
exp := min(n-1, 10)
raw := float64(baseMs) * math.Pow(2, float64(exp))
ms := int(math.Min(float64(maxMs), raw))
return time.Duration(ms) * time.Millisecond
}
+269
View File
@@ -0,0 +1,269 @@
package providers
import (
"sync"
"testing"
"time"
)
func newTestTracker(now time.Time) (*CooldownTracker, *time.Time) {
current := now
ct := NewCooldownTracker()
ct.nowFunc = func() time.Time { return current }
return ct, &current
}
func TestCooldown_InitiallyAvailable(t *testing.T) {
ct := NewCooldownTracker()
if !ct.IsAvailable("openai") {
t.Error("new provider should be available")
}
if ct.ErrorCount("openai") != 0 {
t.Error("new provider should have 0 errors")
}
}
func TestCooldown_StandardEscalation(t *testing.T) {
now := time.Now()
ct, current := newTestTracker(now)
// 1st error → 1 min cooldown
ct.MarkFailure("openai", FailoverRateLimit)
if ct.IsAvailable("openai") {
t.Error("should be in cooldown after 1st error")
}
// Advance 61 seconds → available
*current = now.Add(61 * time.Second)
if !ct.IsAvailable("openai") {
t.Error("should be available after 1 min cooldown")
}
// 2nd error → 5 min cooldown
ct.MarkFailure("openai", FailoverRateLimit)
*current = now.Add(61*time.Second + 4*time.Minute)
if ct.IsAvailable("openai") {
t.Error("should be in cooldown (5 min) after 2nd error")
}
*current = now.Add(61*time.Second + 6*time.Minute)
if !ct.IsAvailable("openai") {
t.Error("should be available after 5 min cooldown")
}
}
func TestCooldown_StandardCap(t *testing.T) {
// Verify formula: 1m, 5m, 25m, 1h, 1h, 1h...
expected := []time.Duration{
1 * time.Minute,
5 * time.Minute,
25 * time.Minute,
1 * time.Hour,
1 * time.Hour,
}
for i, want := range expected {
got := calculateStandardCooldown(i + 1)
if got != want {
t.Errorf("calculateStandardCooldown(%d) = %v, want %v", i+1, got, want)
}
}
}
func TestCooldown_BillingEscalation(t *testing.T) {
now := time.Now()
ct, current := newTestTracker(now)
// 1st billing error → 5h cooldown
ct.MarkFailure("openai", FailoverBilling)
if ct.IsAvailable("openai") {
t.Error("should be disabled after billing error")
}
// Advance 4h → still disabled
*current = now.Add(4 * time.Hour)
if ct.IsAvailable("openai") {
t.Error("should still be disabled (5h cooldown)")
}
// Advance 5h + 1s → available
*current = now.Add(5*time.Hour + 1*time.Second)
if !ct.IsAvailable("openai") {
t.Error("should be available after 5h billing cooldown")
}
}
func TestCooldown_BillingCap(t *testing.T) {
expected := []time.Duration{
5 * time.Hour,
10 * time.Hour,
20 * time.Hour,
24 * time.Hour,
24 * time.Hour,
}
for i, want := range expected {
got := calculateBillingCooldown(i + 1)
if got != want {
t.Errorf("calculateBillingCooldown(%d) = %v, want %v", i+1, got, want)
}
}
}
func TestCooldown_SuccessReset(t *testing.T) {
ct := NewCooldownTracker()
ct.MarkFailure("openai", FailoverRateLimit)
ct.MarkFailure("openai", FailoverBilling)
if ct.ErrorCount("openai") != 2 {
t.Errorf("error count = %d, want 2", ct.ErrorCount("openai"))
}
ct.MarkSuccess("openai")
if ct.ErrorCount("openai") != 0 {
t.Errorf("error count after success = %d, want 0", ct.ErrorCount("openai"))
}
if !ct.IsAvailable("openai") {
t.Error("should be available after success")
}
if ct.FailureCount("openai", FailoverRateLimit) != 0 {
t.Error("failure counts should be reset after success")
}
if ct.FailureCount("openai", FailoverBilling) != 0 {
t.Error("billing failure count should be reset after success")
}
}
func TestCooldown_FailureWindowReset(t *testing.T) {
now := time.Now()
ct, current := newTestTracker(now)
// 4 errors → 1h cooldown
for i := 0; i < 4; i++ {
ct.MarkFailure("openai", FailoverRateLimit)
*current = current.Add(2 * time.Second) // small advance between errors
}
if ct.ErrorCount("openai") != 4 {
t.Errorf("error count = %d, want 4", ct.ErrorCount("openai"))
}
// Advance 25 hours (past 24h failure window)
*current = now.Add(25 * time.Hour)
// Next error should reset counters first, then increment to 1
ct.MarkFailure("openai", FailoverRateLimit)
if ct.ErrorCount("openai") != 1 {
t.Errorf("error count after window reset = %d, want 1 (reset + 1)", ct.ErrorCount("openai"))
}
}
func TestCooldown_PerReasonTracking(t *testing.T) {
ct := NewCooldownTracker()
ct.MarkFailure("openai", FailoverRateLimit)
ct.MarkFailure("openai", FailoverRateLimit)
ct.MarkFailure("openai", FailoverBilling)
ct.MarkFailure("openai", FailoverAuth)
if ct.FailureCount("openai", FailoverRateLimit) != 2 {
t.Errorf("rate_limit count = %d, want 2", ct.FailureCount("openai", FailoverRateLimit))
}
if ct.FailureCount("openai", FailoverBilling) != 1 {
t.Errorf("billing count = %d, want 1", ct.FailureCount("openai", FailoverBilling))
}
if ct.FailureCount("openai", FailoverAuth) != 1 {
t.Errorf("auth count = %d, want 1", ct.FailureCount("openai", FailoverAuth))
}
if ct.ErrorCount("openai") != 4 {
t.Errorf("total error count = %d, want 4", ct.ErrorCount("openai"))
}
}
func TestCooldown_BillingTakesPrecedence(t *testing.T) {
now := time.Now()
ct, current := newTestTracker(now)
// Standard cooldown (1 min) + billing disable (5h)
ct.MarkFailure("openai", FailoverRateLimit) // 1 min cooldown
ct.MarkFailure("openai", FailoverBilling) // 5h disable
// After 2 min: standard cooldown expired but billing still active
*current = now.Add(2 * time.Minute)
if ct.IsAvailable("openai") {
t.Error("billing disable should take precedence over standard cooldown")
}
// After 5h + 1s: both expired
*current = now.Add(5*time.Hour + 1*time.Second)
if !ct.IsAvailable("openai") {
t.Error("should be available after all cooldowns expire")
}
}
func TestCooldown_CooldownRemaining(t *testing.T) {
now := time.Now()
ct, current := newTestTracker(now)
// No failures → 0 remaining
if ct.CooldownRemaining("openai") != 0 {
t.Error("expected 0 remaining for new provider")
}
ct.MarkFailure("openai", FailoverRateLimit)
*current = now.Add(30 * time.Second)
remaining := ct.CooldownRemaining("openai")
if remaining <= 0 || remaining > 1*time.Minute {
t.Errorf("remaining = %v, expected ~30s", remaining)
}
}
func TestCooldown_SuccessOnUnknownProvider(t *testing.T) {
ct := NewCooldownTracker()
// Should not panic
ct.MarkSuccess("nonexistent")
if !ct.IsAvailable("nonexistent") {
t.Error("nonexistent provider should be available")
}
}
func TestCooldown_ConcurrentAccess(t *testing.T) {
ct := NewCooldownTracker()
var wg sync.WaitGroup
for i := 0; i < 100; i++ {
wg.Add(3)
go func() {
defer wg.Done()
ct.MarkFailure("openai", FailoverRateLimit)
}()
go func() {
defer wg.Done()
ct.IsAvailable("openai")
}()
go func() {
defer wg.Done()
ct.MarkSuccess("openai")
}()
}
wg.Wait()
// If we got here without panic, concurrent access is safe
}
func TestCooldown_MultipleProviders(t *testing.T) {
ct := NewCooldownTracker()
ct.MarkFailure("openai", FailoverRateLimit)
ct.MarkFailure("anthropic", FailoverBilling)
if ct.IsAvailable("openai") {
t.Error("openai should be in cooldown")
}
if ct.IsAvailable("anthropic") {
t.Error("anthropic should be in cooldown")
}
// groq was never touched
if !ct.IsAvailable("groq") {
t.Error("groq should be available")
}
}
+253
View File
@@ -0,0 +1,253 @@
package providers
import (
"context"
"regexp"
"strings"
)
// errorPattern defines a single pattern (string or regex) for error classification.
type errorPattern struct {
substring string
regex *regexp.Regexp
}
func substr(s string) errorPattern { return errorPattern{substring: s} }
func rxp(r string) errorPattern { return errorPattern{regex: regexp.MustCompile("(?i)" + r)} }
// Error patterns organized by FailoverReason, matching OpenClaw production (~40 patterns).
var (
rateLimitPatterns = []errorPattern{
rxp(`rate[_ ]limit`),
substr("too many requests"),
substr("429"),
substr("exceeded your current quota"),
rxp(`exceeded.*quota`),
rxp(`resource has been exhausted`),
rxp(`resource.*exhausted`),
substr("resource_exhausted"),
substr("quota exceeded"),
substr("usage limit"),
}
overloadedPatterns = []errorPattern{
rxp(`overloaded_error`),
rxp(`"type"\s*:\s*"overloaded_error"`),
substr("overloaded"),
}
timeoutPatterns = []errorPattern{
substr("timeout"),
substr("timed out"),
substr("deadline exceeded"),
substr("context deadline exceeded"),
}
billingPatterns = []errorPattern{
rxp(`\b402\b`),
substr("payment required"),
substr("insufficient credits"),
substr("credit balance"),
substr("plans & billing"),
substr("insufficient balance"),
}
authPatterns = []errorPattern{
rxp(`invalid[_ ]?api[_ ]?key`),
substr("incorrect api key"),
substr("invalid token"),
substr("authentication"),
substr("re-authenticate"),
substr("oauth token refresh failed"),
substr("unauthorized"),
substr("forbidden"),
substr("access denied"),
substr("expired"),
substr("token has expired"),
rxp(`\b401\b`),
rxp(`\b403\b`),
substr("no credentials found"),
substr("no api key found"),
}
formatPatterns = []errorPattern{
substr("string should match pattern"),
substr("tool_use.id"),
substr("tool_use_id"),
substr("messages.1.content.1.tool_use.id"),
substr("invalid request format"),
}
imageDimensionPatterns = []errorPattern{
rxp(`image dimensions exceed max`),
}
imageSizePatterns = []errorPattern{
rxp(`image exceeds.*mb`),
}
// Transient HTTP status codes that map to timeout (server-side failures).
transientStatusCodes = map[int]bool{
500: true, 502: true, 503: true,
521: true, 522: true, 523: true, 524: true,
529: true,
}
)
// ClassifyError classifies an error into a FailoverError with reason.
// Returns nil if the error is not classifiable (unknown errors should not trigger fallback).
func ClassifyError(err error, provider, model string) *FailoverError {
if err == nil {
return nil
}
// Context cancellation: user abort, never fallback.
if err == context.Canceled {
return nil
}
// Context deadline exceeded: treat as timeout, always fallback.
if err == context.DeadlineExceeded {
return &FailoverError{
Reason: FailoverTimeout,
Provider: provider,
Model: model,
Wrapped: err,
}
}
msg := strings.ToLower(err.Error())
// Image dimension/size errors: non-retriable, non-fallback.
if IsImageDimensionError(msg) || IsImageSizeError(msg) {
return &FailoverError{
Reason: FailoverFormat,
Provider: provider,
Model: model,
Wrapped: err,
}
}
// Try HTTP status code extraction first.
if status := extractHTTPStatus(msg); status > 0 {
if reason := classifyByStatus(status); reason != "" {
return &FailoverError{
Reason: reason,
Provider: provider,
Model: model,
Status: status,
Wrapped: err,
}
}
}
// Message pattern matching (priority order from OpenClaw).
if reason := classifyByMessage(msg); reason != "" {
return &FailoverError{
Reason: reason,
Provider: provider,
Model: model,
Wrapped: err,
}
}
return nil
}
// classifyByStatus maps HTTP status codes to FailoverReason.
func classifyByStatus(status int) FailoverReason {
switch {
case status == 401 || status == 403:
return FailoverAuth
case status == 402:
return FailoverBilling
case status == 408:
return FailoverTimeout
case status == 429:
return FailoverRateLimit
case status == 400:
return FailoverFormat
case transientStatusCodes[status]:
return FailoverTimeout
}
return ""
}
// classifyByMessage matches error messages against patterns.
// Priority order matters (from OpenClaw classifyFailoverReason).
func classifyByMessage(msg string) FailoverReason {
if matchesAny(msg, rateLimitPatterns) {
return FailoverRateLimit
}
if matchesAny(msg, overloadedPatterns) {
return FailoverRateLimit // Overloaded treated as rate_limit
}
if matchesAny(msg, billingPatterns) {
return FailoverBilling
}
if matchesAny(msg, timeoutPatterns) {
return FailoverTimeout
}
if matchesAny(msg, authPatterns) {
return FailoverAuth
}
if matchesAny(msg, formatPatterns) {
return FailoverFormat
}
return ""
}
// extractHTTPStatus extracts an HTTP status code from an error message.
// Looks for patterns like "status: 429", "status 429", "HTTP 429", or standalone "429".
func extractHTTPStatus(msg string) int {
// Common patterns in Go HTTP error messages
patterns := []*regexp.Regexp{
regexp.MustCompile(`status[:\s]+(\d{3})`),
regexp.MustCompile(`HTTP[/\s]+\d*\.?\d*\s+(\d{3})`),
}
for _, p := range patterns {
if m := p.FindStringSubmatch(msg); len(m) > 1 {
return parseDigits(m[1])
}
}
return 0
}
// IsImageDimensionError returns true if the message indicates an image dimension error.
func IsImageDimensionError(msg string) bool {
return matchesAny(msg, imageDimensionPatterns)
}
// IsImageSizeError returns true if the message indicates an image file size error.
func IsImageSizeError(msg string) bool {
return matchesAny(msg, imageSizePatterns)
}
// matchesAny checks if msg matches any of the patterns.
func matchesAny(msg string, patterns []errorPattern) bool {
for _, p := range patterns {
if p.regex != nil {
if p.regex.MatchString(msg) {
return true
}
} else if p.substring != "" {
if strings.Contains(msg, p.substring) {
return true
}
}
}
return false
}
// parseDigits converts a string of digits to an int.
func parseDigits(s string) int {
n := 0
for _, c := range s {
if c >= '0' && c <= '9' {
n = n*10 + int(c-'0')
}
}
return n
}
+337
View File
@@ -0,0 +1,337 @@
package providers
import (
"context"
"errors"
"fmt"
"testing"
)
func TestClassifyError_Nil(t *testing.T) {
result := ClassifyError(nil, "openai", "gpt-4")
if result != nil {
t.Errorf("expected nil for nil error, got %+v", result)
}
}
func TestClassifyError_ContextCanceled(t *testing.T) {
result := ClassifyError(context.Canceled, "openai", "gpt-4")
if result != nil {
t.Errorf("expected nil for context.Canceled (user abort), got %+v", result)
}
}
func TestClassifyError_ContextDeadlineExceeded(t *testing.T) {
result := ClassifyError(context.DeadlineExceeded, "openai", "gpt-4")
if result == nil {
t.Fatal("expected non-nil for deadline exceeded")
}
if result.Reason != FailoverTimeout {
t.Errorf("reason = %q, want timeout", result.Reason)
}
}
func TestClassifyError_StatusCodes(t *testing.T) {
tests := []struct {
status int
reason FailoverReason
}{
{401, FailoverAuth},
{403, FailoverAuth},
{402, FailoverBilling},
{408, FailoverTimeout},
{429, FailoverRateLimit},
{400, FailoverFormat},
{500, FailoverTimeout},
{502, FailoverTimeout},
{503, FailoverTimeout},
{521, FailoverTimeout},
{522, FailoverTimeout},
{523, FailoverTimeout},
{524, FailoverTimeout},
{529, FailoverTimeout},
}
for _, tt := range tests {
err := fmt.Errorf("API error: status: %d something went wrong", tt.status)
result := ClassifyError(err, "test", "model")
if result == nil {
t.Errorf("status %d: expected non-nil", tt.status)
continue
}
if result.Reason != tt.reason {
t.Errorf("status %d: reason = %q, want %q", tt.status, result.Reason, tt.reason)
}
}
}
func TestClassifyError_RateLimitPatterns(t *testing.T) {
patterns := []string{
"rate limit exceeded",
"rate_limit reached",
"too many requests",
"exceeded your current quota",
"resource has been exhausted",
"resource_exhausted",
"quota exceeded",
"usage limit reached",
}
for _, msg := range patterns {
err := errors.New(msg)
result := ClassifyError(err, "openai", "gpt-4")
if result == nil {
t.Errorf("pattern %q: expected non-nil", msg)
continue
}
if result.Reason != FailoverRateLimit {
t.Errorf("pattern %q: reason = %q, want rate_limit", msg, result.Reason)
}
}
}
func TestClassifyError_OverloadedPatterns(t *testing.T) {
patterns := []string{
"overloaded_error",
`{"type": "overloaded_error"}`,
"server is overloaded",
}
for _, msg := range patterns {
err := errors.New(msg)
result := ClassifyError(err, "anthropic", "claude")
if result == nil {
t.Errorf("pattern %q: expected non-nil", msg)
continue
}
// Overloaded is treated as rate_limit
if result.Reason != FailoverRateLimit {
t.Errorf("pattern %q: reason = %q, want rate_limit", msg, result.Reason)
}
}
}
func TestClassifyError_BillingPatterns(t *testing.T) {
patterns := []string{
"payment required",
"insufficient credits",
"credit balance too low",
"plans & billing page",
"insufficient balance",
}
for _, msg := range patterns {
err := errors.New(msg)
result := ClassifyError(err, "openai", "gpt-4")
if result == nil {
t.Errorf("pattern %q: expected non-nil", msg)
continue
}
if result.Reason != FailoverBilling {
t.Errorf("pattern %q: reason = %q, want billing", msg, result.Reason)
}
}
}
func TestClassifyError_TimeoutPatterns(t *testing.T) {
patterns := []string{
"request timeout",
"connection timed out",
"deadline exceeded",
"context deadline exceeded",
}
for _, msg := range patterns {
err := errors.New(msg)
result := ClassifyError(err, "openai", "gpt-4")
if result == nil {
t.Errorf("pattern %q: expected non-nil", msg)
continue
}
if result.Reason != FailoverTimeout {
t.Errorf("pattern %q: reason = %q, want timeout", msg, result.Reason)
}
}
}
func TestClassifyError_AuthPatterns(t *testing.T) {
patterns := []string{
"invalid api key",
"invalid_api_key",
"incorrect api key",
"invalid token",
"authentication failed",
"re-authenticate",
"oauth token refresh failed",
"unauthorized access",
"forbidden",
"access denied",
"expired",
"token has expired",
"no credentials found",
"no api key found",
}
for _, msg := range patterns {
err := errors.New(msg)
result := ClassifyError(err, "openai", "gpt-4")
if result == nil {
t.Errorf("pattern %q: expected non-nil", msg)
continue
}
if result.Reason != FailoverAuth {
t.Errorf("pattern %q: reason = %q, want auth", msg, result.Reason)
}
}
}
func TestClassifyError_FormatPatterns(t *testing.T) {
patterns := []string{
"string should match pattern",
"tool_use.id is required",
"invalid tool_use_id",
"messages.1.content.1.tool_use.id must be valid",
"invalid request format",
}
for _, msg := range patterns {
err := errors.New(msg)
result := ClassifyError(err, "anthropic", "claude")
if result == nil {
t.Errorf("pattern %q: expected non-nil", msg)
continue
}
if result.Reason != FailoverFormat {
t.Errorf("pattern %q: reason = %q, want format", msg, result.Reason)
}
}
}
func TestClassifyError_ImageDimensionError(t *testing.T) {
err := errors.New("image dimensions exceed max allowed 2048x2048")
result := ClassifyError(err, "openai", "gpt-4o")
if result == nil {
t.Fatal("expected non-nil for image dimension error")
}
if result.Reason != FailoverFormat {
t.Errorf("reason = %q, want format", result.Reason)
}
if result.IsRetriable() {
t.Error("image dimension error should not be retriable")
}
}
func TestClassifyError_ImageSizeError(t *testing.T) {
err := errors.New("image exceeds 20 mb limit")
result := ClassifyError(err, "openai", "gpt-4o")
if result == nil {
t.Fatal("expected non-nil for image size error")
}
if result.Reason != FailoverFormat {
t.Errorf("reason = %q, want format", result.Reason)
}
}
func TestClassifyError_UnknownError(t *testing.T) {
err := errors.New("some completely random error")
result := ClassifyError(err, "openai", "gpt-4")
if result != nil {
t.Errorf("expected nil for unknown error, got %+v", result)
}
}
func TestClassifyError_ProviderModelPropagation(t *testing.T) {
err := errors.New("rate limit exceeded")
result := ClassifyError(err, "my-provider", "my-model")
if result == nil {
t.Fatal("expected non-nil")
}
if result.Provider != "my-provider" {
t.Errorf("provider = %q, want my-provider", result.Provider)
}
if result.Model != "my-model" {
t.Errorf("model = %q, want my-model", result.Model)
}
}
func TestFailoverError_IsRetriable(t *testing.T) {
tests := []struct {
reason FailoverReason
retriable bool
}{
{FailoverAuth, true},
{FailoverRateLimit, true},
{FailoverBilling, true},
{FailoverTimeout, true},
{FailoverOverloaded, true},
{FailoverFormat, false},
{FailoverUnknown, true},
}
for _, tt := range tests {
fe := &FailoverError{Reason: tt.reason}
if fe.IsRetriable() != tt.retriable {
t.Errorf("IsRetriable(%q) = %v, want %v", tt.reason, fe.IsRetriable(), tt.retriable)
}
}
}
func TestFailoverError_ErrorString(t *testing.T) {
fe := &FailoverError{
Reason: FailoverRateLimit,
Provider: "openai",
Model: "gpt-4",
Status: 429,
Wrapped: errors.New("too many requests"),
}
s := fe.Error()
if s == "" {
t.Error("expected non-empty error string")
}
}
func TestFailoverError_Unwrap(t *testing.T) {
inner := errors.New("inner error")
fe := &FailoverError{Reason: FailoverTimeout, Wrapped: inner}
if fe.Unwrap() != inner {
t.Error("Unwrap should return wrapped error")
}
}
func TestExtractHTTPStatus(t *testing.T) {
tests := []struct {
msg string
want int
}{
{"status: 429 rate limited", 429},
{"status 401 unauthorized", 401},
{"HTTP/1.1 502 Bad Gateway", 502},
{"no status code here", 0},
{"random number 12345", 0},
}
for _, tt := range tests {
got := extractHTTPStatus(tt.msg)
if got != tt.want {
t.Errorf("extractHTTPStatus(%q) = %d, want %d", tt.msg, got, tt.want)
}
}
}
func TestIsImageDimensionError(t *testing.T) {
if !IsImageDimensionError("image dimensions exceed max 4096x4096") {
t.Error("should match image dimensions exceed max")
}
if IsImageDimensionError("normal error message") {
t.Error("should not match normal error")
}
}
func TestIsImageSizeError(t *testing.T) {
if !IsImageSizeError("image exceeds 20 mb") {
t.Error("should match image exceeds mb")
}
if IsImageSizeError("normal error message") {
t.Error("should not match normal error")
}
}
+360
View File
@@ -0,0 +1,360 @@
package providers
import (
"fmt"
"strings"
"github.com/sipeed/picoclaw/pkg/auth"
"github.com/sipeed/picoclaw/pkg/config"
)
const defaultAnthropicAPIBase = "https://api.anthropic.com/v1"
var getCredential = auth.GetCredential
type providerType int
const (
providerTypeHTTPCompat providerType = iota
providerTypeClaudeAuth
providerTypeCodexAuth
providerTypeCodexCLIToken
providerTypeClaudeCLI
providerTypeCodexCLI
providerTypeGitHubCopilot
)
type providerSelection struct {
providerType providerType
apiKey string
apiBase string
proxy string
model string
workspace string
connectMode string
enableWebSearch bool
}
func createClaudeAuthProvider(apiBase string) (LLMProvider, error) {
if apiBase == "" {
apiBase = defaultAnthropicAPIBase
}
cred, err := getCredential("anthropic")
if err != nil {
return nil, fmt.Errorf("loading auth credentials: %w", err)
}
if cred == nil {
return nil, fmt.Errorf("no credentials for anthropic. Run: picoclaw auth login --provider anthropic")
}
return NewClaudeProviderWithTokenSourceAndBaseURL(cred.AccessToken, createClaudeTokenSource(), apiBase), nil
}
func createCodexAuthProvider(enableWebSearch bool) (LLMProvider, error) {
cred, err := getCredential("openai")
if err != nil {
return nil, fmt.Errorf("loading auth credentials: %w", err)
}
if cred == nil {
return nil, fmt.Errorf("no credentials for openai. Run: picoclaw auth login --provider openai")
}
p := NewCodexProviderWithTokenSource(cred.AccessToken, cred.AccountID, createCodexTokenSource())
p.enableWebSearch = enableWebSearch
return p, nil
}
func resolveProviderSelection(cfg *config.Config) (providerSelection, error) {
model := cfg.Agents.Defaults.Model
providerName := strings.ToLower(cfg.Agents.Defaults.Provider)
lowerModel := strings.ToLower(model)
sel := providerSelection{
providerType: providerTypeHTTPCompat,
model: model,
}
// First, prefer explicit provider configuration.
if providerName != "" {
switch providerName {
case "groq":
if cfg.Providers.Groq.APIKey != "" {
sel.apiKey = cfg.Providers.Groq.APIKey
sel.apiBase = cfg.Providers.Groq.APIBase
sel.proxy = cfg.Providers.Groq.Proxy
if sel.apiBase == "" {
sel.apiBase = "https://api.groq.com/openai/v1"
}
}
case "openai", "gpt":
if cfg.Providers.OpenAI.APIKey != "" || cfg.Providers.OpenAI.AuthMethod != "" {
sel.enableWebSearch = cfg.Providers.OpenAI.WebSearch
if cfg.Providers.OpenAI.AuthMethod == "codex-cli" {
sel.providerType = providerTypeCodexCLIToken
return sel, nil
}
if cfg.Providers.OpenAI.AuthMethod == "oauth" || cfg.Providers.OpenAI.AuthMethod == "token" {
sel.providerType = providerTypeCodexAuth
return sel, nil
}
sel.apiKey = cfg.Providers.OpenAI.APIKey
sel.apiBase = cfg.Providers.OpenAI.APIBase
sel.proxy = cfg.Providers.OpenAI.Proxy
if sel.apiBase == "" {
sel.apiBase = "https://api.openai.com/v1"
}
}
case "anthropic", "claude":
if cfg.Providers.Anthropic.APIKey != "" || cfg.Providers.Anthropic.AuthMethod != "" {
if cfg.Providers.Anthropic.AuthMethod == "oauth" || cfg.Providers.Anthropic.AuthMethod == "token" {
sel.apiBase = cfg.Providers.Anthropic.APIBase
if sel.apiBase == "" {
sel.apiBase = defaultAnthropicAPIBase
}
sel.providerType = providerTypeClaudeAuth
return sel, nil
}
sel.apiKey = cfg.Providers.Anthropic.APIKey
sel.apiBase = cfg.Providers.Anthropic.APIBase
sel.proxy = cfg.Providers.Anthropic.Proxy
if sel.apiBase == "" {
sel.apiBase = defaultAnthropicAPIBase
}
}
case "openrouter":
if cfg.Providers.OpenRouter.APIKey != "" {
sel.apiKey = cfg.Providers.OpenRouter.APIKey
sel.proxy = cfg.Providers.OpenRouter.Proxy
if cfg.Providers.OpenRouter.APIBase != "" {
sel.apiBase = cfg.Providers.OpenRouter.APIBase
} else {
sel.apiBase = "https://openrouter.ai/api/v1"
}
}
case "zhipu", "glm":
if cfg.Providers.Zhipu.APIKey != "" {
sel.apiKey = cfg.Providers.Zhipu.APIKey
sel.apiBase = cfg.Providers.Zhipu.APIBase
sel.proxy = cfg.Providers.Zhipu.Proxy
if sel.apiBase == "" {
sel.apiBase = "https://open.bigmodel.cn/api/paas/v4"
}
}
case "gemini", "google":
if cfg.Providers.Gemini.APIKey != "" {
sel.apiKey = cfg.Providers.Gemini.APIKey
sel.apiBase = cfg.Providers.Gemini.APIBase
sel.proxy = cfg.Providers.Gemini.Proxy
if sel.apiBase == "" {
sel.apiBase = "https://generativelanguage.googleapis.com/v1beta"
}
}
case "vllm":
if cfg.Providers.VLLM.APIBase != "" {
sel.apiKey = cfg.Providers.VLLM.APIKey
sel.apiBase = cfg.Providers.VLLM.APIBase
sel.proxy = cfg.Providers.VLLM.Proxy
}
case "shengsuanyun":
if cfg.Providers.ShengSuanYun.APIKey != "" {
sel.apiKey = cfg.Providers.ShengSuanYun.APIKey
sel.apiBase = cfg.Providers.ShengSuanYun.APIBase
sel.proxy = cfg.Providers.ShengSuanYun.Proxy
if sel.apiBase == "" {
sel.apiBase = "https://router.shengsuanyun.com/api/v1"
}
}
case "nvidia":
if cfg.Providers.Nvidia.APIKey != "" {
sel.apiKey = cfg.Providers.Nvidia.APIKey
sel.apiBase = cfg.Providers.Nvidia.APIBase
sel.proxy = cfg.Providers.Nvidia.Proxy
if sel.apiBase == "" {
sel.apiBase = "https://integrate.api.nvidia.com/v1"
}
}
case "claude-cli", "claude-code", "claudecode":
workspace := cfg.WorkspacePath()
if workspace == "" {
workspace = "."
}
sel.providerType = providerTypeClaudeCLI
sel.workspace = workspace
return sel, nil
case "codex-cli", "codex-code":
workspace := cfg.WorkspacePath()
if workspace == "" {
workspace = "."
}
sel.providerType = providerTypeCodexCLI
sel.workspace = workspace
return sel, nil
case "deepseek":
if cfg.Providers.DeepSeek.APIKey != "" {
sel.apiKey = cfg.Providers.DeepSeek.APIKey
sel.apiBase = cfg.Providers.DeepSeek.APIBase
sel.proxy = cfg.Providers.DeepSeek.Proxy
if sel.apiBase == "" {
sel.apiBase = "https://api.deepseek.com/v1"
}
if model != "deepseek-chat" && model != "deepseek-reasoner" {
sel.model = "deepseek-chat"
}
}
case "github_copilot", "copilot":
sel.providerType = providerTypeGitHubCopilot
if cfg.Providers.GitHubCopilot.APIBase != "" {
sel.apiBase = cfg.Providers.GitHubCopilot.APIBase
} else {
sel.apiBase = "localhost:4321"
}
sel.connectMode = cfg.Providers.GitHubCopilot.ConnectMode
return sel, nil
}
}
// Fallback: infer provider from model and configured keys.
if sel.apiKey == "" && sel.apiBase == "" {
switch {
case (strings.Contains(lowerModel, "kimi") || strings.Contains(lowerModel, "moonshot") || strings.HasPrefix(model, "moonshot/")) && cfg.Providers.Moonshot.APIKey != "":
sel.apiKey = cfg.Providers.Moonshot.APIKey
sel.apiBase = cfg.Providers.Moonshot.APIBase
sel.proxy = cfg.Providers.Moonshot.Proxy
if sel.apiBase == "" {
sel.apiBase = "https://api.moonshot.cn/v1"
}
case strings.HasPrefix(model, "openrouter/") ||
strings.HasPrefix(model, "anthropic/") ||
strings.HasPrefix(model, "openai/") ||
strings.HasPrefix(model, "meta-llama/") ||
strings.HasPrefix(model, "deepseek/") ||
strings.HasPrefix(model, "google/"):
sel.apiKey = cfg.Providers.OpenRouter.APIKey
sel.proxy = cfg.Providers.OpenRouter.Proxy
if cfg.Providers.OpenRouter.APIBase != "" {
sel.apiBase = cfg.Providers.OpenRouter.APIBase
} else {
sel.apiBase = "https://openrouter.ai/api/v1"
}
case (strings.Contains(lowerModel, "claude") || strings.HasPrefix(model, "anthropic/")) &&
(cfg.Providers.Anthropic.APIKey != "" || cfg.Providers.Anthropic.AuthMethod != ""):
if cfg.Providers.Anthropic.AuthMethod == "oauth" || cfg.Providers.Anthropic.AuthMethod == "token" {
sel.apiBase = cfg.Providers.Anthropic.APIBase
if sel.apiBase == "" {
sel.apiBase = defaultAnthropicAPIBase
}
sel.providerType = providerTypeClaudeAuth
return sel, nil
}
sel.apiKey = cfg.Providers.Anthropic.APIKey
sel.apiBase = cfg.Providers.Anthropic.APIBase
sel.proxy = cfg.Providers.Anthropic.Proxy
if sel.apiBase == "" {
sel.apiBase = defaultAnthropicAPIBase
}
case (strings.Contains(lowerModel, "gpt") || strings.HasPrefix(model, "openai/")) &&
(cfg.Providers.OpenAI.APIKey != "" || cfg.Providers.OpenAI.AuthMethod != ""):
sel.enableWebSearch = cfg.Providers.OpenAI.WebSearch
if cfg.Providers.OpenAI.AuthMethod == "codex-cli" {
sel.providerType = providerTypeCodexCLIToken
return sel, nil
}
if cfg.Providers.OpenAI.AuthMethod == "oauth" || cfg.Providers.OpenAI.AuthMethod == "token" {
sel.providerType = providerTypeCodexAuth
return sel, nil
}
sel.apiKey = cfg.Providers.OpenAI.APIKey
sel.apiBase = cfg.Providers.OpenAI.APIBase
sel.proxy = cfg.Providers.OpenAI.Proxy
if sel.apiBase == "" {
sel.apiBase = "https://api.openai.com/v1"
}
case (strings.Contains(lowerModel, "gemini") || strings.HasPrefix(model, "google/")) && cfg.Providers.Gemini.APIKey != "":
sel.apiKey = cfg.Providers.Gemini.APIKey
sel.apiBase = cfg.Providers.Gemini.APIBase
sel.proxy = cfg.Providers.Gemini.Proxy
if sel.apiBase == "" {
sel.apiBase = "https://generativelanguage.googleapis.com/v1beta"
}
case (strings.Contains(lowerModel, "glm") || strings.Contains(lowerModel, "zhipu") || strings.Contains(lowerModel, "zai")) && cfg.Providers.Zhipu.APIKey != "":
sel.apiKey = cfg.Providers.Zhipu.APIKey
sel.apiBase = cfg.Providers.Zhipu.APIBase
sel.proxy = cfg.Providers.Zhipu.Proxy
if sel.apiBase == "" {
sel.apiBase = "https://open.bigmodel.cn/api/paas/v4"
}
case (strings.Contains(lowerModel, "groq") || strings.HasPrefix(model, "groq/")) && cfg.Providers.Groq.APIKey != "":
sel.apiKey = cfg.Providers.Groq.APIKey
sel.apiBase = cfg.Providers.Groq.APIBase
sel.proxy = cfg.Providers.Groq.Proxy
if sel.apiBase == "" {
sel.apiBase = "https://api.groq.com/openai/v1"
}
case (strings.Contains(lowerModel, "nvidia") || strings.HasPrefix(model, "nvidia/")) && cfg.Providers.Nvidia.APIKey != "":
sel.apiKey = cfg.Providers.Nvidia.APIKey
sel.apiBase = cfg.Providers.Nvidia.APIBase
sel.proxy = cfg.Providers.Nvidia.Proxy
if sel.apiBase == "" {
sel.apiBase = "https://integrate.api.nvidia.com/v1"
}
case (strings.Contains(lowerModel, "ollama") || strings.HasPrefix(model, "ollama/")) && cfg.Providers.Ollama.APIKey != "":
sel.apiKey = cfg.Providers.Ollama.APIKey
sel.apiBase = cfg.Providers.Ollama.APIBase
sel.proxy = cfg.Providers.Ollama.Proxy
if sel.apiBase == "" {
sel.apiBase = "http://localhost:11434/v1"
}
case cfg.Providers.VLLM.APIBase != "":
sel.apiKey = cfg.Providers.VLLM.APIKey
sel.apiBase = cfg.Providers.VLLM.APIBase
sel.proxy = cfg.Providers.VLLM.Proxy
default:
if cfg.Providers.OpenRouter.APIKey != "" {
sel.apiKey = cfg.Providers.OpenRouter.APIKey
sel.proxy = cfg.Providers.OpenRouter.Proxy
if cfg.Providers.OpenRouter.APIBase != "" {
sel.apiBase = cfg.Providers.OpenRouter.APIBase
} else {
sel.apiBase = "https://openrouter.ai/api/v1"
}
} else {
return providerSelection{}, fmt.Errorf("no API key configured for model: %s", model)
}
}
}
if sel.providerType == providerTypeHTTPCompat {
if sel.apiKey == "" && !strings.HasPrefix(model, "bedrock/") {
return providerSelection{}, fmt.Errorf("no API key configured for provider (model: %s)", model)
}
if sel.apiBase == "" {
return providerSelection{}, fmt.Errorf("no API base configured for provider (model: %s)", model)
}
}
return sel, nil
}
func CreateProvider(cfg *config.Config) (LLMProvider, error) {
sel, err := resolveProviderSelection(cfg)
if err != nil {
return nil, err
}
switch sel.providerType {
case providerTypeClaudeAuth:
return createClaudeAuthProvider(sel.apiBase)
case providerTypeCodexAuth:
return createCodexAuthProvider(sel.enableWebSearch)
case providerTypeCodexCLIToken:
c := NewCodexProviderWithTokenSource("", "", CreateCodexCliTokenSource())
c.enableWebSearch = sel.enableWebSearch
return c, nil
case providerTypeClaudeCLI:
return NewClaudeCliProvider(sel.workspace), nil
case providerTypeCodexCLI:
return NewCodexCliProvider(sel.workspace), nil
case providerTypeGitHubCopilot:
return NewGitHubCopilotProvider(sel.apiBase, sel.connectMode, sel.model)
default:
return NewHTTPProvider(sel.apiKey, sel.apiBase, sel.proxy), nil
}
}
+299
View File
@@ -0,0 +1,299 @@
package providers
import (
"strings"
"testing"
"github.com/sipeed/picoclaw/pkg/auth"
"github.com/sipeed/picoclaw/pkg/config"
)
func TestResolveProviderSelection(t *testing.T) {
tests := []struct {
name string
setup func(*config.Config)
wantType providerType
wantAPIBase string
wantProxy string
wantErrSubstr string
}{
{
name: "explicit claude-cli provider routes to cli provider type",
setup: func(cfg *config.Config) {
cfg.Agents.Defaults.Provider = "claude-cli"
cfg.Agents.Defaults.Workspace = "/tmp/ws"
},
wantType: providerTypeClaudeCLI,
},
{
name: "explicit copilot provider routes to github copilot type",
setup: func(cfg *config.Config) {
cfg.Agents.Defaults.Provider = "copilot"
},
wantType: providerTypeGitHubCopilot,
wantAPIBase: "localhost:4321",
},
{
name: "explicit deepseek provider uses deepseek defaults",
setup: func(cfg *config.Config) {
cfg.Agents.Defaults.Provider = "deepseek"
cfg.Agents.Defaults.Model = "deepseek/deepseek-chat"
cfg.Providers.DeepSeek.APIKey = "deepseek-key"
cfg.Providers.DeepSeek.Proxy = "http://127.0.0.1:7890"
},
wantType: providerTypeHTTPCompat,
wantAPIBase: "https://api.deepseek.com/v1",
wantProxy: "http://127.0.0.1:7890",
},
{
name: "explicit shengsuanyun provider uses defaults",
setup: func(cfg *config.Config) {
cfg.Agents.Defaults.Provider = "shengsuanyun"
cfg.Providers.ShengSuanYun.APIKey = "ssy-key"
cfg.Providers.ShengSuanYun.Proxy = "http://127.0.0.1:7890"
},
wantType: providerTypeHTTPCompat,
wantAPIBase: "https://router.shengsuanyun.com/api/v1",
wantProxy: "http://127.0.0.1:7890",
},
{
name: "explicit nvidia provider uses defaults",
setup: func(cfg *config.Config) {
cfg.Agents.Defaults.Provider = "nvidia"
cfg.Providers.Nvidia.APIKey = "nvapi-test"
cfg.Providers.Nvidia.Proxy = "http://127.0.0.1:7890"
},
wantType: providerTypeHTTPCompat,
wantAPIBase: "https://integrate.api.nvidia.com/v1",
wantProxy: "http://127.0.0.1:7890",
},
{
name: "openrouter model uses openrouter defaults",
setup: func(cfg *config.Config) {
cfg.Agents.Defaults.Model = "openrouter/auto"
cfg.Providers.OpenRouter.APIKey = "sk-or-test"
},
wantType: providerTypeHTTPCompat,
wantAPIBase: "https://openrouter.ai/api/v1",
},
{
name: "anthropic oauth routes to claude auth provider",
setup: func(cfg *config.Config) {
cfg.Agents.Defaults.Model = "claude-sonnet-4-5-20250929"
cfg.Providers.Anthropic.AuthMethod = "oauth"
},
wantType: providerTypeClaudeAuth,
},
{
name: "openai oauth routes to codex auth provider",
setup: func(cfg *config.Config) {
cfg.Agents.Defaults.Model = "gpt-4o"
cfg.Providers.OpenAI.AuthMethod = "oauth"
},
wantType: providerTypeCodexAuth,
},
{
name: "openai codex-cli auth routes to codex cli token provider",
setup: func(cfg *config.Config) {
cfg.Agents.Defaults.Model = "gpt-4o"
cfg.Providers.OpenAI.AuthMethod = "codex-cli"
},
wantType: providerTypeCodexCLIToken,
},
{
name: "explicit codex-code provider routes to codex cli provider type",
setup: func(cfg *config.Config) {
cfg.Agents.Defaults.Provider = "codex-code"
cfg.Agents.Defaults.Workspace = "/tmp/ws"
},
wantType: providerTypeCodexCLI,
},
{
name: "zhipu model uses zhipu base default",
setup: func(cfg *config.Config) {
cfg.Agents.Defaults.Model = "glm-4.7"
cfg.Providers.Zhipu.APIKey = "zhipu-key"
},
wantType: providerTypeHTTPCompat,
wantAPIBase: "https://open.bigmodel.cn/api/paas/v4",
},
{
name: "groq model uses groq base default",
setup: func(cfg *config.Config) {
cfg.Agents.Defaults.Model = "groq/llama-3.3-70b"
cfg.Providers.Groq.APIKey = "gsk-key"
},
wantType: providerTypeHTTPCompat,
wantAPIBase: "https://api.groq.com/openai/v1",
},
{
name: "ollama model uses ollama base default",
setup: func(cfg *config.Config) {
cfg.Agents.Defaults.Model = "ollama/qwen2.5:14b"
cfg.Providers.Ollama.APIKey = "ollama-key"
},
wantType: providerTypeHTTPCompat,
wantAPIBase: "http://localhost:11434/v1",
},
{
name: "moonshot model keeps proxy and default base",
setup: func(cfg *config.Config) {
cfg.Agents.Defaults.Model = "moonshot/kimi-k2.5"
cfg.Providers.Moonshot.APIKey = "moonshot-key"
cfg.Providers.Moonshot.Proxy = "http://127.0.0.1:7890"
},
wantType: providerTypeHTTPCompat,
wantAPIBase: "https://api.moonshot.cn/v1",
wantProxy: "http://127.0.0.1:7890",
},
{
name: "missing keys returns model config error",
setup: func(cfg *config.Config) {
cfg.Agents.Defaults.Model = "custom-model"
},
wantErrSubstr: "no API key configured for model",
},
{
name: "openrouter prefix without key returns provider key error",
setup: func(cfg *config.Config) {
cfg.Agents.Defaults.Model = "openrouter/auto"
},
wantErrSubstr: "no API key configured for provider",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
cfg := config.DefaultConfig()
tt.setup(cfg)
got, err := resolveProviderSelection(cfg)
if tt.wantErrSubstr != "" {
if err == nil {
t.Fatalf("expected error containing %q, got nil", tt.wantErrSubstr)
}
if !strings.Contains(err.Error(), tt.wantErrSubstr) {
t.Fatalf("error = %q, want substring %q", err.Error(), tt.wantErrSubstr)
}
return
}
if err != nil {
t.Fatalf("resolveProviderSelection() error = %v", err)
}
if got.providerType != tt.wantType {
t.Fatalf("providerType = %v, want %v", got.providerType, tt.wantType)
}
if tt.wantAPIBase != "" && got.apiBase != tt.wantAPIBase {
t.Fatalf("apiBase = %q, want %q", got.apiBase, tt.wantAPIBase)
}
if tt.wantProxy != "" && got.proxy != tt.wantProxy {
t.Fatalf("proxy = %q, want %q", got.proxy, tt.wantProxy)
}
})
}
}
func TestCreateProviderReturnsHTTPProviderForOpenRouter(t *testing.T) {
cfg := config.DefaultConfig()
cfg.Agents.Defaults.Model = "openrouter/auto"
cfg.Providers.OpenRouter.APIKey = "sk-or-test"
provider, err := CreateProvider(cfg)
if err != nil {
t.Fatalf("CreateProvider() error = %v", err)
}
if _, ok := provider.(*HTTPProvider); !ok {
t.Fatalf("provider type = %T, want *HTTPProvider", provider)
}
}
func TestCreateProviderReturnsCodexCliProviderForCodexCode(t *testing.T) {
cfg := config.DefaultConfig()
cfg.Agents.Defaults.Provider = "codex-code"
provider, err := CreateProvider(cfg)
if err != nil {
t.Fatalf("CreateProvider() error = %v", err)
}
if _, ok := provider.(*CodexCliProvider); !ok {
t.Fatalf("provider type = %T, want *CodexCliProvider", provider)
}
}
func TestCreateProviderReturnsCodexProviderForCodexCliAuthMethod(t *testing.T) {
cfg := config.DefaultConfig()
cfg.Agents.Defaults.Provider = "openai"
cfg.Providers.OpenAI.AuthMethod = "codex-cli"
provider, err := CreateProvider(cfg)
if err != nil {
t.Fatalf("CreateProvider() error = %v", err)
}
if _, ok := provider.(*CodexProvider); !ok {
t.Fatalf("provider type = %T, want *CodexProvider", provider)
}
}
func TestCreateProviderReturnsClaudeProviderForAnthropicOAuth(t *testing.T) {
originalGetCredential := getCredential
t.Cleanup(func() { getCredential = originalGetCredential })
getCredential = func(provider string) (*auth.AuthCredential, error) {
if provider != "anthropic" {
t.Fatalf("provider = %q, want anthropic", provider)
}
return &auth.AuthCredential{
AccessToken: "anthropic-token",
}, nil
}
cfg := config.DefaultConfig()
cfg.Agents.Defaults.Provider = "anthropic"
cfg.Providers.Anthropic.AuthMethod = "oauth"
cfg.Providers.Anthropic.APIBase = "https://proxy.example.com/v1"
provider, err := CreateProvider(cfg)
if err != nil {
t.Fatalf("CreateProvider() error = %v", err)
}
claudeProvider, ok := provider.(*ClaudeProvider)
if !ok {
t.Fatalf("provider type = %T, want *ClaudeProvider", provider)
}
if got := claudeProvider.delegate.BaseURL(); got != "https://proxy.example.com" {
t.Fatalf("anthropic baseURL = %q, want %q", got, "https://proxy.example.com")
}
}
func TestCreateProviderReturnsCodexProviderForOpenAIOAuth(t *testing.T) {
originalGetCredential := getCredential
t.Cleanup(func() { getCredential = originalGetCredential })
getCredential = func(provider string) (*auth.AuthCredential, error) {
if provider != "openai" {
t.Fatalf("provider = %q, want openai", provider)
}
return &auth.AuthCredential{
AccessToken: "openai-token",
AccountID: "acct_123",
}, nil
}
cfg := config.DefaultConfig()
cfg.Agents.Defaults.Provider = "openai"
cfg.Providers.OpenAI.AuthMethod = "oauth"
provider, err := CreateProvider(cfg)
if err != nil {
t.Fatalf("CreateProvider() error = %v", err)
}
if _, ok := provider.(*CodexProvider); !ok {
t.Fatalf("provider type = %T, want *CodexProvider", provider)
}
}
+283
View File
@@ -0,0 +1,283 @@
package providers
import (
"context"
"fmt"
"strings"
"time"
)
// FallbackChain orchestrates model fallback across multiple candidates.
type FallbackChain struct {
cooldown *CooldownTracker
}
// FallbackCandidate represents one model/provider to try.
type FallbackCandidate struct {
Provider string
Model string
}
// FallbackResult contains the successful response and metadata about all attempts.
type FallbackResult struct {
Response *LLMResponse
Provider string
Model string
Attempts []FallbackAttempt
}
// FallbackAttempt records one attempt in the fallback chain.
type FallbackAttempt struct {
Provider string
Model string
Error error
Reason FailoverReason
Duration time.Duration
Skipped bool // true if skipped due to cooldown
}
// NewFallbackChain creates a new fallback chain with the given cooldown tracker.
func NewFallbackChain(cooldown *CooldownTracker) *FallbackChain {
return &FallbackChain{cooldown: cooldown}
}
// ResolveCandidates parses model config into a deduplicated candidate list.
func ResolveCandidates(cfg ModelConfig, defaultProvider string) []FallbackCandidate {
seen := make(map[string]bool)
var candidates []FallbackCandidate
addCandidate := func(raw string) {
ref := ParseModelRef(raw, defaultProvider)
if ref == nil {
return
}
key := ModelKey(ref.Provider, ref.Model)
if seen[key] {
return
}
seen[key] = true
candidates = append(candidates, FallbackCandidate{
Provider: ref.Provider,
Model: ref.Model,
})
}
// Primary first.
addCandidate(cfg.Primary)
// Then fallbacks.
for _, fb := range cfg.Fallbacks {
addCandidate(fb)
}
return candidates
}
// Execute runs the fallback chain for text/chat requests.
// It tries each candidate in order, respecting cooldowns and error classification.
//
// Behavior:
// - Candidates in cooldown are skipped (logged as skipped attempt).
// - context.Canceled aborts immediately (user abort, no fallback).
// - Non-retriable errors (format) abort immediately.
// - Retriable errors trigger fallback to next candidate.
// - Success marks provider as good (resets cooldown).
// - If all fail, returns aggregate error with all attempts.
func (fc *FallbackChain) Execute(
ctx context.Context,
candidates []FallbackCandidate,
run func(ctx context.Context, provider, model string) (*LLMResponse, error),
) (*FallbackResult, error) {
if len(candidates) == 0 {
return nil, fmt.Errorf("fallback: no candidates configured")
}
result := &FallbackResult{
Attempts: make([]FallbackAttempt, 0, len(candidates)),
}
for i, candidate := range candidates {
// Check context before each attempt.
if ctx.Err() == context.Canceled {
return nil, context.Canceled
}
// Check cooldown.
if !fc.cooldown.IsAvailable(candidate.Provider) {
remaining := fc.cooldown.CooldownRemaining(candidate.Provider)
result.Attempts = append(result.Attempts, FallbackAttempt{
Provider: candidate.Provider,
Model: candidate.Model,
Skipped: true,
Reason: FailoverRateLimit,
Error: fmt.Errorf("provider %s in cooldown (%s remaining)", candidate.Provider, remaining.Round(time.Second)),
})
continue
}
// Execute the run function.
start := time.Now()
resp, err := run(ctx, candidate.Provider, candidate.Model)
elapsed := time.Since(start)
if err == nil {
// Success.
fc.cooldown.MarkSuccess(candidate.Provider)
result.Response = resp
result.Provider = candidate.Provider
result.Model = candidate.Model
return result, nil
}
// Context cancellation: abort immediately, no fallback.
if ctx.Err() == context.Canceled {
result.Attempts = append(result.Attempts, FallbackAttempt{
Provider: candidate.Provider,
Model: candidate.Model,
Error: err,
Duration: elapsed,
})
return nil, context.Canceled
}
// Classify the error.
failErr := ClassifyError(err, candidate.Provider, candidate.Model)
if failErr == nil {
// Unclassifiable error: do not fallback, return immediately.
result.Attempts = append(result.Attempts, FallbackAttempt{
Provider: candidate.Provider,
Model: candidate.Model,
Error: err,
Duration: elapsed,
})
return nil, fmt.Errorf("fallback: unclassified error from %s/%s: %w",
candidate.Provider, candidate.Model, err)
}
// Non-retriable error: abort immediately.
if !failErr.IsRetriable() {
result.Attempts = append(result.Attempts, FallbackAttempt{
Provider: candidate.Provider,
Model: candidate.Model,
Error: failErr,
Reason: failErr.Reason,
Duration: elapsed,
})
return nil, failErr
}
// Retriable error: mark failure and continue to next candidate.
fc.cooldown.MarkFailure(candidate.Provider, failErr.Reason)
result.Attempts = append(result.Attempts, FallbackAttempt{
Provider: candidate.Provider,
Model: candidate.Model,
Error: failErr,
Reason: failErr.Reason,
Duration: elapsed,
})
// If this was the last candidate, return aggregate error.
if i == len(candidates)-1 {
return nil, &FallbackExhaustedError{Attempts: result.Attempts}
}
}
// All candidates were skipped (all in cooldown).
return nil, &FallbackExhaustedError{Attempts: result.Attempts}
}
// ExecuteImage runs the fallback chain for image/vision requests.
// Simpler than Execute: no cooldown checks (image endpoints have different rate limits).
// Image dimension/size errors abort immediately (non-retriable).
func (fc *FallbackChain) ExecuteImage(
ctx context.Context,
candidates []FallbackCandidate,
run func(ctx context.Context, provider, model string) (*LLMResponse, error),
) (*FallbackResult, error) {
if len(candidates) == 0 {
return nil, fmt.Errorf("image fallback: no candidates configured")
}
result := &FallbackResult{
Attempts: make([]FallbackAttempt, 0, len(candidates)),
}
for i, candidate := range candidates {
if ctx.Err() == context.Canceled {
return nil, context.Canceled
}
start := time.Now()
resp, err := run(ctx, candidate.Provider, candidate.Model)
elapsed := time.Since(start)
if err == nil {
result.Response = resp
result.Provider = candidate.Provider
result.Model = candidate.Model
return result, nil
}
if ctx.Err() == context.Canceled {
result.Attempts = append(result.Attempts, FallbackAttempt{
Provider: candidate.Provider,
Model: candidate.Model,
Error: err,
Duration: elapsed,
})
return nil, context.Canceled
}
// Image dimension/size errors are non-retriable.
errMsg := strings.ToLower(err.Error())
if IsImageDimensionError(errMsg) || IsImageSizeError(errMsg) {
result.Attempts = append(result.Attempts, FallbackAttempt{
Provider: candidate.Provider,
Model: candidate.Model,
Error: err,
Reason: FailoverFormat,
Duration: elapsed,
})
return nil, &FailoverError{
Reason: FailoverFormat,
Provider: candidate.Provider,
Model: candidate.Model,
Wrapped: err,
}
}
// Any other error: record and try next.
result.Attempts = append(result.Attempts, FallbackAttempt{
Provider: candidate.Provider,
Model: candidate.Model,
Error: err,
Duration: elapsed,
})
if i == len(candidates)-1 {
return nil, &FallbackExhaustedError{Attempts: result.Attempts}
}
}
return nil, &FallbackExhaustedError{Attempts: result.Attempts}
}
// FallbackExhaustedError indicates all fallback candidates were tried and failed.
type FallbackExhaustedError struct {
Attempts []FallbackAttempt
}
func (e *FallbackExhaustedError) Error() string {
var sb strings.Builder
sb.WriteString(fmt.Sprintf("fallback: all %d candidates failed:", len(e.Attempts)))
for i, a := range e.Attempts {
if a.Skipped {
sb.WriteString(fmt.Sprintf("\n [%d] %s/%s: skipped (cooldown)", i+1, a.Provider, a.Model))
} else {
sb.WriteString(fmt.Sprintf("\n [%d] %s/%s: %v (reason=%s, %s)",
i+1, a.Provider, a.Model, a.Error, a.Reason, a.Duration.Round(time.Millisecond)))
}
}
return sb.String()
}
+473
View File
@@ -0,0 +1,473 @@
package providers
import (
"context"
"errors"
"testing"
"time"
)
func makeCandidate(provider, model string) FallbackCandidate {
return FallbackCandidate{Provider: provider, Model: model}
}
func successRun(content string) func(ctx context.Context, provider, model string) (*LLMResponse, error) {
return func(ctx context.Context, provider, model string) (*LLMResponse, error) {
return &LLMResponse{Content: content, FinishReason: "stop"}, nil
}
}
func failRun(err error) func(ctx context.Context, provider, model string) (*LLMResponse, error) {
return func(ctx context.Context, provider, model string) (*LLMResponse, error) {
return nil, err
}
}
func TestFallback_SingleCandidate_Success(t *testing.T) {
ct := NewCooldownTracker()
fc := NewFallbackChain(ct)
candidates := []FallbackCandidate{makeCandidate("openai", "gpt-4")}
result, err := fc.Execute(context.Background(), candidates, successRun("hello"))
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if result.Response.Content != "hello" {
t.Errorf("content = %q, want hello", result.Response.Content)
}
if result.Provider != "openai" || result.Model != "gpt-4" {
t.Errorf("provider/model = %s/%s, want openai/gpt-4", result.Provider, result.Model)
}
}
func TestFallback_SecondCandidateSuccess(t *testing.T) {
ct := NewCooldownTracker()
fc := NewFallbackChain(ct)
candidates := []FallbackCandidate{
makeCandidate("openai", "gpt-4"),
makeCandidate("anthropic", "claude-opus"),
}
attempt := 0
run := func(ctx context.Context, provider, model string) (*LLMResponse, error) {
attempt++
if attempt == 1 {
return nil, errors.New("rate limit exceeded")
}
return &LLMResponse{Content: "from claude", FinishReason: "stop"}, nil
}
result, err := fc.Execute(context.Background(), candidates, run)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if result.Provider != "anthropic" {
t.Errorf("provider = %q, want anthropic", result.Provider)
}
if result.Response.Content != "from claude" {
t.Errorf("content = %q, want 'from claude'", result.Response.Content)
}
if len(result.Attempts) != 1 {
t.Errorf("attempts = %d, want 1 (failed attempt recorded)", len(result.Attempts))
}
}
func TestFallback_AllFail(t *testing.T) {
ct := NewCooldownTracker()
fc := NewFallbackChain(ct)
candidates := []FallbackCandidate{
makeCandidate("openai", "gpt-4"),
makeCandidate("anthropic", "claude"),
makeCandidate("groq", "llama"),
}
run := func(ctx context.Context, provider, model string) (*LLMResponse, error) {
return nil, errors.New("rate limit exceeded")
}
_, err := fc.Execute(context.Background(), candidates, run)
if err == nil {
t.Fatal("expected error when all candidates fail")
}
var exhausted *FallbackExhaustedError
if !errors.As(err, &exhausted) {
t.Errorf("expected FallbackExhaustedError, got %T: %v", err, err)
}
if len(exhausted.Attempts) != 3 {
t.Errorf("attempts = %d, want 3", len(exhausted.Attempts))
}
}
func TestFallback_ContextCanceled(t *testing.T) {
ct := NewCooldownTracker()
fc := NewFallbackChain(ct)
ctx, cancel := context.WithCancel(context.Background())
candidates := []FallbackCandidate{
makeCandidate("openai", "gpt-4"),
makeCandidate("anthropic", "claude"),
}
attempt := 0
run := func(ctx context.Context, provider, model string) (*LLMResponse, error) {
attempt++
if attempt == 1 {
cancel() // cancel context
return nil, context.Canceled
}
t.Error("should not reach second candidate after cancel")
return nil, nil
}
_, err := fc.Execute(ctx, candidates, run)
if err != context.Canceled {
t.Errorf("expected context.Canceled, got %v", err)
}
}
func TestFallback_NonRetriableError(t *testing.T) {
ct := NewCooldownTracker()
fc := NewFallbackChain(ct)
candidates := []FallbackCandidate{
makeCandidate("openai", "gpt-4"),
makeCandidate("anthropic", "claude"),
}
attempt := 0
run := func(ctx context.Context, provider, model string) (*LLMResponse, error) {
attempt++
return nil, errors.New("string should match pattern")
}
_, err := fc.Execute(context.Background(), candidates, run)
if err == nil {
t.Fatal("expected error for non-retriable")
}
var fe *FailoverError
if !errors.As(err, &fe) {
t.Fatalf("expected FailoverError, got %T", err)
}
if fe.Reason != FailoverFormat {
t.Errorf("reason = %q, want format", fe.Reason)
}
if attempt != 1 {
t.Errorf("attempt = %d, want 1 (non-retriable should not try next)", attempt)
}
}
func TestFallback_CooldownSkip(t *testing.T) {
now := time.Now()
ct, _ := newTestTracker(now)
fc := NewFallbackChain(ct)
// Put openai in cooldown
ct.MarkFailure("openai", FailoverRateLimit)
candidates := []FallbackCandidate{
makeCandidate("openai", "gpt-4"),
makeCandidate("anthropic", "claude"),
}
run := func(ctx context.Context, provider, model string) (*LLMResponse, error) {
if provider == "openai" {
t.Error("should not call openai (in cooldown)")
}
return &LLMResponse{Content: "claude response", FinishReason: "stop"}, nil
}
result, err := fc.Execute(context.Background(), candidates, run)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if result.Provider != "anthropic" {
t.Errorf("provider = %q, want anthropic", result.Provider)
}
// Should have 1 skipped attempt
skipped := 0
for _, a := range result.Attempts {
if a.Skipped {
skipped++
}
}
if skipped != 1 {
t.Errorf("skipped = %d, want 1", skipped)
}
}
func TestFallback_AllInCooldown(t *testing.T) {
ct := NewCooldownTracker()
fc := NewFallbackChain(ct)
// Put all providers in cooldown
ct.MarkFailure("openai", FailoverRateLimit)
ct.MarkFailure("anthropic", FailoverBilling)
candidates := []FallbackCandidate{
makeCandidate("openai", "gpt-4"),
makeCandidate("anthropic", "claude"),
}
_, err := fc.Execute(context.Background(), candidates,
func(ctx context.Context, provider, model string) (*LLMResponse, error) {
t.Error("should not call any provider (all in cooldown)")
return nil, nil
})
if err == nil {
t.Fatal("expected error when all in cooldown")
}
var exhausted *FallbackExhaustedError
if !errors.As(err, &exhausted) {
t.Fatalf("expected FallbackExhaustedError, got %T", err)
}
}
func TestFallback_NoCandidates(t *testing.T) {
ct := NewCooldownTracker()
fc := NewFallbackChain(ct)
_, err := fc.Execute(context.Background(), nil, successRun("ok"))
if err == nil {
t.Error("expected error for empty candidates")
}
}
func TestFallback_EmptyFallbacks(t *testing.T) {
// Single primary, no fallbacks: should work like direct call
ct := NewCooldownTracker()
fc := NewFallbackChain(ct)
candidates := []FallbackCandidate{makeCandidate("openai", "gpt-4")}
result, err := fc.Execute(context.Background(), candidates, successRun("ok"))
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if result.Response.Content != "ok" {
t.Error("expected success with single candidate")
}
}
func TestFallback_UnclassifiedError(t *testing.T) {
ct := NewCooldownTracker()
fc := NewFallbackChain(ct)
candidates := []FallbackCandidate{
makeCandidate("openai", "gpt-4"),
makeCandidate("anthropic", "claude"),
}
attempt := 0
run := func(ctx context.Context, provider, model string) (*LLMResponse, error) {
attempt++
return nil, errors.New("completely unknown internal error")
}
_, err := fc.Execute(context.Background(), candidates, run)
if err == nil {
t.Fatal("expected error for unclassified error")
}
if attempt != 1 {
t.Errorf("attempt = %d, want 1 (should not fallback on unclassified)", attempt)
}
}
func TestFallback_SuccessResetsCooldown(t *testing.T) {
ct := NewCooldownTracker()
fc := NewFallbackChain(ct)
candidates := []FallbackCandidate{makeCandidate("openai", "gpt-4")}
attempt := 0
run := func(ctx context.Context, provider, model string) (*LLMResponse, error) {
attempt++
if attempt == 1 {
ct.MarkFailure("openai", FailoverRateLimit) // simulate failure tracked elsewhere
}
return &LLMResponse{Content: "ok", FinishReason: "stop"}, nil
}
_, err := fc.Execute(context.Background(), candidates, run)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if !ct.IsAvailable("openai") {
t.Error("success should reset cooldown")
}
}
// --- Image Fallback Tests ---
func TestImageFallback_Success(t *testing.T) {
ct := NewCooldownTracker()
fc := NewFallbackChain(ct)
candidates := []FallbackCandidate{makeCandidate("openai", "gpt-4o")}
result, err := fc.ExecuteImage(context.Background(), candidates, successRun("image result"))
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if result.Response.Content != "image result" {
t.Error("expected image result")
}
}
func TestImageFallback_DimensionError(t *testing.T) {
ct := NewCooldownTracker()
fc := NewFallbackChain(ct)
candidates := []FallbackCandidate{
makeCandidate("openai", "gpt-4o"),
makeCandidate("anthropic", "claude"),
}
attempt := 0
run := func(ctx context.Context, provider, model string) (*LLMResponse, error) {
attempt++
return nil, errors.New("image dimensions exceed max 4096x4096")
}
_, err := fc.ExecuteImage(context.Background(), candidates, run)
if err == nil {
t.Fatal("expected error for image dimension error")
}
if attempt != 1 {
t.Errorf("attempt = %d, want 1 (image dimension error should not retry)", attempt)
}
}
func TestImageFallback_SizeError(t *testing.T) {
ct := NewCooldownTracker()
fc := NewFallbackChain(ct)
candidates := []FallbackCandidate{
makeCandidate("openai", "gpt-4o"),
makeCandidate("anthropic", "claude"),
}
attempt := 0
run := func(ctx context.Context, provider, model string) (*LLMResponse, error) {
attempt++
return nil, errors.New("image exceeds 20 mb")
}
_, err := fc.ExecuteImage(context.Background(), candidates, run)
if err == nil {
t.Fatal("expected error for image size error")
}
if attempt != 1 {
t.Errorf("attempt = %d, want 1 (image size error should not retry)", attempt)
}
}
func TestImageFallback_RetryOnOtherErrors(t *testing.T) {
ct := NewCooldownTracker()
fc := NewFallbackChain(ct)
candidates := []FallbackCandidate{
makeCandidate("openai", "gpt-4o"),
makeCandidate("anthropic", "claude-sonnet"),
}
attempt := 0
run := func(ctx context.Context, provider, model string) (*LLMResponse, error) {
attempt++
if attempt == 1 {
return nil, errors.New("rate limit exceeded")
}
return &LLMResponse{Content: "image ok", FinishReason: "stop"}, nil
}
result, err := fc.ExecuteImage(context.Background(), candidates, run)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if result.Provider != "anthropic" {
t.Errorf("provider = %q, want anthropic", result.Provider)
}
}
func TestImageFallback_NoCandidates(t *testing.T) {
ct := NewCooldownTracker()
fc := NewFallbackChain(ct)
_, err := fc.ExecuteImage(context.Background(), nil, successRun("ok"))
if err == nil {
t.Error("expected error for empty candidates")
}
}
// --- ResolveCandidates Tests ---
func TestResolveCandidates_Simple(t *testing.T) {
cfg := ModelConfig{
Primary: "gpt-4",
Fallbacks: []string{"anthropic/claude-opus", "groq/llama-3"},
}
candidates := ResolveCandidates(cfg, "openai")
if len(candidates) != 3 {
t.Fatalf("candidates = %d, want 3", len(candidates))
}
if candidates[0].Provider != "openai" || candidates[0].Model != "gpt-4" {
t.Errorf("candidate[0] = %s/%s, want openai/gpt-4", candidates[0].Provider, candidates[0].Model)
}
if candidates[1].Provider != "anthropic" || candidates[1].Model != "claude-opus" {
t.Errorf("candidate[1] = %s/%s, want anthropic/claude-opus", candidates[1].Provider, candidates[1].Model)
}
if candidates[2].Provider != "groq" || candidates[2].Model != "llama-3" {
t.Errorf("candidate[2] = %s/%s, want groq/llama-3", candidates[2].Provider, candidates[2].Model)
}
}
func TestResolveCandidates_Deduplication(t *testing.T) {
cfg := ModelConfig{
Primary: "openai/gpt-4",
Fallbacks: []string{"openai/gpt-4", "anthropic/claude"},
}
candidates := ResolveCandidates(cfg, "default")
if len(candidates) != 2 {
t.Errorf("candidates = %d, want 2 (duplicate removed)", len(candidates))
}
}
func TestResolveCandidates_EmptyFallbacks(t *testing.T) {
cfg := ModelConfig{
Primary: "gpt-4",
Fallbacks: nil,
}
candidates := ResolveCandidates(cfg, "openai")
if len(candidates) != 1 {
t.Errorf("candidates = %d, want 1", len(candidates))
}
}
func TestResolveCandidates_EmptyPrimary(t *testing.T) {
cfg := ModelConfig{
Primary: "",
Fallbacks: []string{"anthropic/claude"},
}
candidates := ResolveCandidates(cfg, "openai")
if len(candidates) != 1 {
t.Errorf("candidates = %d, want 1", len(candidates))
}
}
func TestFallbackExhaustedError_Message(t *testing.T) {
e := &FallbackExhaustedError{
Attempts: []FallbackAttempt{
{Provider: "openai", Model: "gpt-4", Error: errors.New("rate limited"), Reason: FailoverRateLimit, Duration: 500 * time.Millisecond},
{Provider: "anthropic", Model: "claude", Skipped: true},
},
}
msg := e.Error()
if msg == "" {
t.Error("expected non-empty error message")
}
}
+8 -180
View File
@@ -7,201 +7,29 @@
package providers
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"strings"
"time"
"github.com/sipeed/picoclaw/pkg/providers/openai_compat"
)
type HTTPProvider struct {
apiKey string
apiBase string
maxTokensField string // Field name for max tokens (e.g., "max_completion_tokens" for o1/glm models)
httpClient *http.Client
delegate *openai_compat.Provider
}
func NewHTTPProvider(apiKey, apiBase, proxy string) *HTTPProvider {
return NewHTTPProviderWithMaxTokensField(apiKey, apiBase, proxy, "")
return &HTTPProvider{
delegate: openai_compat.NewProvider(apiKey, apiBase, proxy),
}
}
func NewHTTPProviderWithMaxTokensField(apiKey, apiBase, proxy, maxTokensField string) *HTTPProvider {
client := &http.Client{
Timeout: 120 * time.Second,
}
if proxy != "" {
proxyURL, err := url.Parse(proxy)
if err == nil {
client.Transport = &http.Transport{
Proxy: http.ProxyURL(proxyURL),
}
}
}
return &HTTPProvider{
apiKey: apiKey,
apiBase: strings.TrimRight(apiBase, "/"),
maxTokensField: maxTokensField,
httpClient: client,
delegate: openai_compat.NewProviderWithMaxTokensField(apiKey, apiBase, proxy, maxTokensField),
}
}
func (p *HTTPProvider) Chat(ctx context.Context, messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) (*LLMResponse, error) {
if p.apiBase == "" {
return nil, fmt.Errorf("API base not configured")
}
// Strip provider prefix from model name (e.g., moonshot/kimi-k2.5 -> kimi-k2.5, groq/openai/gpt-oss-120b -> openai/gpt-oss-120b, ollama/qwen2.5:14b -> qwen2.5:14b)
if idx := strings.Index(model, "/"); idx != -1 {
prefix := model[:idx]
if prefix == "moonshot" || prefix == "nvidia" || prefix == "groq" || prefix == "ollama" || prefix == "qwen" || prefix == "cerebras" {
model = model[idx+1:]
}
}
requestBody := map[string]interface{}{
"model": model,
"messages": messages,
}
if len(tools) > 0 {
requestBody["tools"] = tools
requestBody["tool_choice"] = "auto"
}
if maxTokens, ok := options["max_tokens"].(int); ok {
// Use configured max_tokens_field if specified, otherwise fallback to model-based detection
fieldName := p.maxTokensField
if fieldName == "" {
// Fallback: detect from model name for backward compatibility
lowerModel := strings.ToLower(model)
if strings.Contains(lowerModel, "glm") || strings.Contains(lowerModel, "o1") {
fieldName = "max_completion_tokens"
} else {
fieldName = "max_tokens"
}
}
requestBody[fieldName] = maxTokens
}
if temperature, ok := options["temperature"].(float64); ok {
lowerModel := strings.ToLower(model)
// Kimi k2 models only support temperature=1
if strings.Contains(lowerModel, "kimi") && strings.Contains(lowerModel, "k2") {
requestBody["temperature"] = 1.0
} else {
requestBody["temperature"] = temperature
}
}
jsonData, err := json.Marshal(requestBody)
if err != nil {
return nil, fmt.Errorf("failed to marshal request: %w", err)
}
req, err := http.NewRequestWithContext(ctx, "POST", p.apiBase+"/chat/completions", bytes.NewReader(jsonData))
if err != nil {
return nil, fmt.Errorf("failed to create request: %w", err)
}
req.Header.Set("Content-Type", "application/json")
if p.apiKey != "" {
req.Header.Set("Authorization", "Bearer "+p.apiKey)
}
resp, err := p.httpClient.Do(req)
if err != nil {
return nil, fmt.Errorf("failed to send request: %w", err)
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("failed to read response: %w", err)
}
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("API request failed:\n Status: %d\n Body: %s", resp.StatusCode, string(body))
}
return p.parseResponse(body)
}
func (p *HTTPProvider) parseResponse(body []byte) (*LLMResponse, error) {
var apiResponse struct {
Choices []struct {
Message struct {
Content string `json:"content"`
ToolCalls []struct {
ID string `json:"id"`
Type string `json:"type"`
Function *struct {
Name string `json:"name"`
Arguments string `json:"arguments"`
ThoughtSignature string `json:"thought_signature"`
} `json:"function"`
} `json:"tool_calls"`
} `json:"message"`
FinishReason string `json:"finish_reason"`
} `json:"choices"`
Usage *UsageInfo `json:"usage"`
}
if err := json.Unmarshal(body, &apiResponse); err != nil {
return nil, fmt.Errorf("failed to unmarshal response: %w", err)
}
if len(apiResponse.Choices) == 0 {
return &LLMResponse{
Content: "",
FinishReason: "stop",
}, nil
}
choice := apiResponse.Choices[0]
toolCalls := make([]ToolCall, 0, len(choice.Message.ToolCalls))
for _, tc := range choice.Message.ToolCalls {
arguments := make(map[string]interface{})
name := ""
thoughtSignature := ""
argsStr := ""
if tc.Function != nil {
name = tc.Function.Name
thoughtSignature = tc.Function.ThoughtSignature
argsStr = tc.Function.Arguments
if argsStr != "" {
if err := json.Unmarshal([]byte(argsStr), &arguments); err != nil {
arguments["raw"] = argsStr
}
}
}
toolCalls = append(toolCalls, ToolCall{
ID: tc.ID,
Type: tc.Type,
Function: &FunctionCall{
Name: name,
Arguments: argsStr,
ThoughtSignature: thoughtSignature,
},
Name: name,
Arguments: arguments,
})
}
return &LLMResponse{
Content: choice.Message.Content,
ToolCalls: toolCalls,
FinishReason: choice.FinishReason,
Usage: apiResponse.Usage,
}, nil
return p.delegate.Chat(ctx, messages, tools, model, options)
}
func (p *HTTPProvider) GetDefaultModel() string {
+64
View File
@@ -0,0 +1,64 @@
package providers
import "strings"
// ModelRef represents a parsed model reference with provider and model name.
type ModelRef struct {
Provider string
Model string
}
// ParseModelRef parses "anthropic/claude-opus" into {Provider: "anthropic", Model: "claude-opus"}.
// If no slash present, uses defaultProvider.
// Returns nil for empty input.
func ParseModelRef(raw string, defaultProvider string) *ModelRef {
raw = strings.TrimSpace(raw)
if raw == "" {
return nil
}
if idx := strings.Index(raw, "/"); idx > 0 {
provider := NormalizeProvider(raw[:idx])
model := strings.TrimSpace(raw[idx+1:])
if model == "" {
return nil
}
return &ModelRef{Provider: provider, Model: model}
}
return &ModelRef{
Provider: NormalizeProvider(defaultProvider),
Model: raw,
}
}
// NormalizeProvider normalizes provider identifiers to canonical form.
func NormalizeProvider(provider string) string {
p := strings.ToLower(strings.TrimSpace(provider))
switch p {
case "z.ai", "z-ai":
return "zai"
case "opencode-zen":
return "opencode"
case "qwen":
return "qwen-portal"
case "kimi-code":
return "kimi-coding"
case "gpt":
return "openai"
case "claude":
return "anthropic"
case "glm":
return "zhipu"
case "google":
return "gemini"
}
return p
}
// ModelKey returns a canonical "provider/model" key for deduplication.
func ModelKey(provider, model string) string {
return NormalizeProvider(provider) + "/" + strings.ToLower(strings.TrimSpace(model))
}
+125
View File
@@ -0,0 +1,125 @@
package providers
import "testing"
func TestParseModelRef_WithSlash(t *testing.T) {
ref := ParseModelRef("anthropic/claude-opus", "openai")
if ref == nil {
t.Fatal("expected non-nil ref")
}
if ref.Provider != "anthropic" {
t.Errorf("provider = %q, want anthropic", ref.Provider)
}
if ref.Model != "claude-opus" {
t.Errorf("model = %q, want claude-opus", ref.Model)
}
}
func TestParseModelRef_WithoutSlash(t *testing.T) {
ref := ParseModelRef("gpt-4", "openai")
if ref == nil {
t.Fatal("expected non-nil ref")
}
if ref.Provider != "openai" {
t.Errorf("provider = %q, want openai", ref.Provider)
}
if ref.Model != "gpt-4" {
t.Errorf("model = %q, want gpt-4", ref.Model)
}
}
func TestParseModelRef_Empty(t *testing.T) {
ref := ParseModelRef("", "openai")
if ref != nil {
t.Errorf("expected nil for empty string, got %+v", ref)
}
}
func TestParseModelRef_EmptyModelAfterSlash(t *testing.T) {
ref := ParseModelRef("openai/", "default")
if ref != nil {
t.Errorf("expected nil for empty model, got %+v", ref)
}
}
func TestParseModelRef_WhitespaceHandling(t *testing.T) {
ref := ParseModelRef(" anthropic / claude-opus ", "openai")
if ref == nil {
t.Fatal("expected non-nil ref")
}
if ref.Provider != "anthropic" {
t.Errorf("provider = %q, want anthropic", ref.Provider)
}
if ref.Model != "claude-opus" {
t.Errorf("model = %q, want claude-opus", ref.Model)
}
}
func TestNormalizeProvider(t *testing.T) {
tests := []struct {
input string
want string
}{
{"OpenAI", "openai"},
{"ANTHROPIC", "anthropic"},
{"z.ai", "zai"},
{"z-ai", "zai"},
{"Z.AI", "zai"},
{"opencode-zen", "opencode"},
{"qwen", "qwen-portal"},
{"kimi-code", "kimi-coding"},
{"gpt", "openai"},
{"claude", "anthropic"},
{"glm", "zhipu"},
{"google", "gemini"},
{"groq", "groq"},
{"", ""},
}
for _, tt := range tests {
got := NormalizeProvider(tt.input)
if got != tt.want {
t.Errorf("NormalizeProvider(%q) = %q, want %q", tt.input, got, tt.want)
}
}
}
func TestModelKey(t *testing.T) {
tests := []struct {
provider string
model string
want string
}{
{"openai", "gpt-4", "openai/gpt-4"},
{"Anthropic", "Claude-Opus", "anthropic/claude-opus"},
{"claude", "sonnet", "anthropic/sonnet"},
{"z.ai", "Model-X", "zai/model-x"},
}
for _, tt := range tests {
got := ModelKey(tt.provider, tt.model)
if got != tt.want {
t.Errorf("ModelKey(%q, %q) = %q, want %q", tt.provider, tt.model, got, tt.want)
}
}
}
func TestParseModelRef_ProviderNormalization(t *testing.T) {
ref := ParseModelRef("Z.AI/model-x", "default")
if ref == nil {
t.Fatal("expected non-nil ref")
}
if ref.Provider != "zai" {
t.Errorf("provider = %q, want zai", ref.Provider)
}
}
func TestParseModelRef_DefaultProviderNormalization(t *testing.T) {
ref := ParseModelRef("gpt-4o", "GPT")
if ref == nil {
t.Fatal("expected non-nil ref")
}
if ref.Provider != "openai" {
t.Errorf("provider = %q, want openai (normalized from GPT)", ref.Provider)
}
}
+232
View File
@@ -0,0 +1,232 @@
package openai_compat
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"log"
"net/http"
"net/url"
"strings"
"time"
"github.com/sipeed/picoclaw/pkg/providers/protocoltypes"
)
type ToolCall = protocoltypes.ToolCall
type FunctionCall = protocoltypes.FunctionCall
type LLMResponse = protocoltypes.LLMResponse
type UsageInfo = protocoltypes.UsageInfo
type Message = protocoltypes.Message
type ToolDefinition = protocoltypes.ToolDefinition
type ToolFunctionDefinition = protocoltypes.ToolFunctionDefinition
type Provider struct {
apiKey string
apiBase string
httpClient *http.Client
}
func NewProvider(apiKey, apiBase, proxy string) *Provider {
client := &http.Client{
Timeout: 120 * time.Second,
}
if proxy != "" {
parsed, err := url.Parse(proxy)
if err == nil {
client.Transport = &http.Transport{
Proxy: http.ProxyURL(parsed),
}
} else {
log.Printf("openai_compat: invalid proxy URL %q: %v", proxy, err)
}
}
return &Provider{
apiKey: apiKey,
apiBase: strings.TrimRight(apiBase, "/"),
httpClient: client,
}
}
func (p *Provider) Chat(ctx context.Context, messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) (*LLMResponse, error) {
if p.apiBase == "" {
return nil, fmt.Errorf("API base not configured")
}
model = normalizeModel(model, p.apiBase)
requestBody := map[string]interface{}{
"model": model,
"messages": messages,
}
if len(tools) > 0 {
requestBody["tools"] = tools
requestBody["tool_choice"] = "auto"
}
if maxTokens, ok := asInt(options["max_tokens"]); ok {
lowerModel := strings.ToLower(model)
if strings.Contains(lowerModel, "glm") || strings.Contains(lowerModel, "o1") || strings.Contains(lowerModel, "gpt-5") {
requestBody["max_completion_tokens"] = maxTokens
} else {
requestBody["max_tokens"] = maxTokens
}
}
if temperature, ok := asFloat(options["temperature"]); ok {
lowerModel := strings.ToLower(model)
// Kimi k2 models only support temperature=1.
if strings.Contains(lowerModel, "kimi") && strings.Contains(lowerModel, "k2") {
requestBody["temperature"] = 1.0
} else {
requestBody["temperature"] = temperature
}
}
jsonData, err := json.Marshal(requestBody)
if err != nil {
return nil, fmt.Errorf("failed to marshal request: %w", err)
}
req, err := http.NewRequestWithContext(ctx, "POST", p.apiBase+"/chat/completions", bytes.NewReader(jsonData))
if err != nil {
return nil, fmt.Errorf("failed to create request: %w", err)
}
req.Header.Set("Content-Type", "application/json")
if p.apiKey != "" {
req.Header.Set("Authorization", "Bearer "+p.apiKey)
}
resp, err := p.httpClient.Do(req)
if err != nil {
return nil, fmt.Errorf("failed to send request: %w", err)
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("failed to read response: %w", err)
}
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("API request failed:\n Status: %d\n Body: %s", resp.StatusCode, string(body))
}
return parseResponse(body)
}
func parseResponse(body []byte) (*LLMResponse, error) {
var apiResponse struct {
Choices []struct {
Message struct {
Content string `json:"content"`
ToolCalls []struct {
ID string `json:"id"`
Type string `json:"type"`
Function *struct {
Name string `json:"name"`
Arguments string `json:"arguments"`
} `json:"function"`
} `json:"tool_calls"`
} `json:"message"`
FinishReason string `json:"finish_reason"`
} `json:"choices"`
Usage *UsageInfo `json:"usage"`
}
if err := json.Unmarshal(body, &apiResponse); err != nil {
return nil, fmt.Errorf("failed to unmarshal response: %w", err)
}
if len(apiResponse.Choices) == 0 {
return &LLMResponse{
Content: "",
FinishReason: "stop",
}, nil
}
choice := apiResponse.Choices[0]
toolCalls := make([]ToolCall, 0, len(choice.Message.ToolCalls))
for _, tc := range choice.Message.ToolCalls {
arguments := make(map[string]interface{})
name := ""
if tc.Function != nil {
name = tc.Function.Name
if tc.Function.Arguments != "" {
if err := json.Unmarshal([]byte(tc.Function.Arguments), &arguments); err != nil {
log.Printf("openai_compat: failed to decode tool call arguments for %q: %v", name, err)
arguments["raw"] = tc.Function.Arguments
}
}
}
toolCalls = append(toolCalls, ToolCall{
ID: tc.ID,
Name: name,
Arguments: arguments,
})
}
return &LLMResponse{
Content: choice.Message.Content,
ToolCalls: toolCalls,
FinishReason: choice.FinishReason,
Usage: apiResponse.Usage,
}, nil
}
func normalizeModel(model, apiBase string) string {
idx := strings.Index(model, "/")
if idx == -1 {
return model
}
if strings.Contains(strings.ToLower(apiBase), "openrouter.ai") {
return model
}
prefix := strings.ToLower(model[:idx])
switch prefix {
case "moonshot", "nvidia", "groq", "ollama", "deepseek", "google", "openrouter", "zhipu":
return model[idx+1:]
default:
return model
}
}
func asInt(v interface{}) (int, bool) {
switch val := v.(type) {
case int:
return val, true
case int64:
return int(val), true
case float64:
return int(val), true
case float32:
return int(val), true
default:
return 0, false
}
}
func asFloat(v interface{}) (float64, bool) {
switch val := v.(type) {
case float64:
return val, true
case float32:
return float64(val), true
case int:
return float64(val), true
case int64:
return float64(val), true
default:
return 0, false
}
}
@@ -0,0 +1,277 @@
package openai_compat
import (
"encoding/json"
"net/http"
"net/http/httptest"
"net/url"
"testing"
)
func TestProviderChat_UsesMaxCompletionTokensForGLM(t *testing.T) {
var requestBody map[string]interface{}
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/chat/completions" {
http.Error(w, "not found", http.StatusNotFound)
return
}
if err := json.NewDecoder(r.Body).Decode(&requestBody); err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
resp := map[string]interface{}{
"choices": []map[string]interface{}{
{
"message": map[string]interface{}{"content": "ok"},
"finish_reason": "stop",
},
},
}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(resp)
}))
defer server.Close()
p := NewProvider("key", server.URL, "")
_, err := p.Chat(t.Context(), []Message{{Role: "user", Content: "hi"}}, nil, "glm-4.7", map[string]interface{}{"max_tokens": 1234})
if err != nil {
t.Fatalf("Chat() error = %v", err)
}
if _, ok := requestBody["max_completion_tokens"]; !ok {
t.Fatalf("expected max_completion_tokens in request body")
}
if _, ok := requestBody["max_tokens"]; ok {
t.Fatalf("did not expect max_tokens key for glm model")
}
}
func TestProviderChat_ParsesToolCalls(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
resp := map[string]interface{}{
"choices": []map[string]interface{}{
{
"message": map[string]interface{}{
"content": "",
"tool_calls": []map[string]interface{}{
{
"id": "call_1",
"type": "function",
"function": map[string]interface{}{
"name": "get_weather",
"arguments": "{\"city\":\"SF\"}",
},
},
},
},
"finish_reason": "tool_calls",
},
},
"usage": map[string]interface{}{
"prompt_tokens": 10,
"completion_tokens": 5,
"total_tokens": 15,
},
}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(resp)
}))
defer server.Close()
p := NewProvider("key", server.URL, "")
out, err := p.Chat(t.Context(), []Message{{Role: "user", Content: "hi"}}, nil, "gpt-4o", nil)
if err != nil {
t.Fatalf("Chat() error = %v", err)
}
if len(out.ToolCalls) != 1 {
t.Fatalf("len(ToolCalls) = %d, want 1", len(out.ToolCalls))
}
if out.ToolCalls[0].Name != "get_weather" {
t.Fatalf("ToolCalls[0].Name = %q, want %q", out.ToolCalls[0].Name, "get_weather")
}
if out.ToolCalls[0].Arguments["city"] != "SF" {
t.Fatalf("ToolCalls[0].Arguments[city] = %v, want SF", out.ToolCalls[0].Arguments["city"])
}
}
func TestProviderChat_HTTPError(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
http.Error(w, "bad request", http.StatusBadRequest)
}))
defer server.Close()
p := NewProvider("key", server.URL, "")
_, err := p.Chat(t.Context(), []Message{{Role: "user", Content: "hi"}}, nil, "gpt-4o", nil)
if err == nil {
t.Fatal("expected error, got nil")
}
}
func TestProviderChat_StripsMoonshotPrefixAndNormalizesKimiTemperature(t *testing.T) {
var requestBody map[string]interface{}
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if err := json.NewDecoder(r.Body).Decode(&requestBody); err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
resp := map[string]interface{}{
"choices": []map[string]interface{}{
{
"message": map[string]interface{}{"content": "ok"},
"finish_reason": "stop",
},
},
}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(resp)
}))
defer server.Close()
p := NewProvider("key", server.URL, "")
_, err := p.Chat(
t.Context(),
[]Message{{Role: "user", Content: "hi"}},
nil,
"moonshot/kimi-k2.5",
map[string]interface{}{"temperature": 0.3},
)
if err != nil {
t.Fatalf("Chat() error = %v", err)
}
if requestBody["model"] != "kimi-k2.5" {
t.Fatalf("model = %v, want kimi-k2.5", requestBody["model"])
}
if requestBody["temperature"] != 1.0 {
t.Fatalf("temperature = %v, want 1.0", requestBody["temperature"])
}
}
func TestProviderChat_StripsGroqAndOllamaPrefixes(t *testing.T) {
tests := []struct {
name string
input string
wantModel string
}{
{
name: "strips groq prefix and keeps nested model",
input: "groq/openai/gpt-oss-120b",
wantModel: "openai/gpt-oss-120b",
},
{
name: "strips ollama prefix",
input: "ollama/qwen2.5:14b",
wantModel: "qwen2.5:14b",
},
{
name: "strips deepseek prefix",
input: "deepseek/deepseek-chat",
wantModel: "deepseek-chat",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var requestBody map[string]interface{}
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if err := json.NewDecoder(r.Body).Decode(&requestBody); err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
resp := map[string]interface{}{
"choices": []map[string]interface{}{
{
"message": map[string]interface{}{"content": "ok"},
"finish_reason": "stop",
},
},
}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(resp)
}))
defer server.Close()
p := NewProvider("key", server.URL, "")
_, err := p.Chat(t.Context(), []Message{{Role: "user", Content: "hi"}}, nil, tt.input, nil)
if err != nil {
t.Fatalf("Chat() error = %v", err)
}
if requestBody["model"] != tt.wantModel {
t.Fatalf("model = %v, want %s", requestBody["model"], tt.wantModel)
}
})
}
}
func TestProvider_ProxyConfigured(t *testing.T) {
proxyURL := "http://127.0.0.1:8080"
p := NewProvider("key", "https://example.com", proxyURL)
transport, ok := p.httpClient.Transport.(*http.Transport)
if !ok || transport == nil {
t.Fatalf("expected http transport with proxy, got %T", p.httpClient.Transport)
}
req := &http.Request{URL: &url.URL{Scheme: "https", Host: "api.example.com"}}
gotProxy, err := transport.Proxy(req)
if err != nil {
t.Fatalf("proxy function returned error: %v", err)
}
if gotProxy == nil || gotProxy.String() != proxyURL {
t.Fatalf("proxy = %v, want %s", gotProxy, proxyURL)
}
}
func TestProviderChat_AcceptsNumericOptionTypes(t *testing.T) {
var requestBody map[string]interface{}
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if err := json.NewDecoder(r.Body).Decode(&requestBody); err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
resp := map[string]interface{}{
"choices": []map[string]interface{}{
{
"message": map[string]interface{}{"content": "ok"},
"finish_reason": "stop",
},
},
}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(resp)
}))
defer server.Close()
p := NewProvider("key", server.URL, "")
_, err := p.Chat(
t.Context(),
[]Message{{Role: "user", Content: "hi"}},
nil,
"gpt-4o",
map[string]interface{}{"max_tokens": float64(512), "temperature": 1},
)
if err != nil {
t.Fatalf("Chat() error = %v", err)
}
if requestBody["max_tokens"] != float64(512) {
t.Fatalf("max_tokens = %v, want 512", requestBody["max_tokens"])
}
if requestBody["temperature"] != float64(1) {
t.Fatalf("temperature = %v, want 1", requestBody["temperature"])
}
}
func TestNormalizeModel_UsesAPIBase(t *testing.T) {
if got := normalizeModel("deepseek/deepseek-chat", "https://api.deepseek.com/v1"); got != "deepseek-chat" {
t.Fatalf("normalizeModel(deepseek) = %q, want %q", got, "deepseek-chat")
}
if got := normalizeModel("openrouter/auto", "https://openrouter.ai/api/v1"); got != "openrouter/auto" {
t.Fatalf("normalizeModel(openrouter) = %q, want %q", got, "openrouter/auto")
}
}
+45
View File
@@ -0,0 +1,45 @@
package protocoltypes
type ToolCall struct {
ID string `json:"id"`
Type string `json:"type,omitempty"`
Function *FunctionCall `json:"function,omitempty"`
Name string `json:"name,omitempty"`
Arguments map[string]interface{} `json:"arguments,omitempty"`
}
type FunctionCall struct {
Name string `json:"name"`
Arguments string `json:"arguments"`
}
type LLMResponse struct {
Content string `json:"content"`
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
FinishReason string `json:"finish_reason"`
Usage *UsageInfo `json:"usage,omitempty"`
}
type UsageInfo struct {
PromptTokens int `json:"prompt_tokens"`
CompletionTokens int `json:"completion_tokens"`
TotalTokens int `json:"total_tokens"`
}
type Message struct {
Role string `json:"role"`
Content string `json:"content"`
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
ToolCallID string `json:"tool_call_id,omitempty"`
}
type ToolDefinition struct {
Type string `json:"type"`
Function ToolFunctionDefinition `json:"function"`
}
type ToolFunctionDefinition struct {
Name string `json:"name"`
Description string `json:"description"`
Parameters map[string]interface{} `json:"parameters"`
}
+51 -40
View File
@@ -1,53 +1,64 @@
package providers
import "context"
import (
"context"
"fmt"
type ToolCall struct {
ID string `json:"id"`
Type string `json:"type,omitempty"`
Function *FunctionCall `json:"function,omitempty"`
Name string `json:"name,omitempty"`
Arguments map[string]interface{} `json:"arguments,omitempty"`
}
"github.com/sipeed/picoclaw/pkg/providers/protocoltypes"
)
type FunctionCall struct {
Name string `json:"name"`
Arguments string `json:"arguments"`
ThoughtSignature string `json:"thought_signature,omitempty"`
}
type LLMResponse struct {
Content string `json:"content"`
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
FinishReason string `json:"finish_reason"`
Usage *UsageInfo `json:"usage,omitempty"`
}
type UsageInfo struct {
PromptTokens int `json:"prompt_tokens"`
CompletionTokens int `json:"completion_tokens"`
TotalTokens int `json:"total_tokens"`
}
type Message struct {
Role string `json:"role"`
Content string `json:"content"`
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
ToolCallID string `json:"tool_call_id,omitempty"`
}
type ToolCall = protocoltypes.ToolCall
type FunctionCall = protocoltypes.FunctionCall
type LLMResponse = protocoltypes.LLMResponse
type UsageInfo = protocoltypes.UsageInfo
type Message = protocoltypes.Message
type ToolDefinition = protocoltypes.ToolDefinition
type ToolFunctionDefinition = protocoltypes.ToolFunctionDefinition
type LLMProvider interface {
Chat(ctx context.Context, messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) (*LLMResponse, error)
GetDefaultModel() string
}
type ToolDefinition struct {
Type string `json:"type"`
Function ToolFunctionDefinition `json:"function"`
// FailoverReason classifies why an LLM request failed for fallback decisions.
type FailoverReason string
const (
FailoverAuth FailoverReason = "auth"
FailoverRateLimit FailoverReason = "rate_limit"
FailoverBilling FailoverReason = "billing"
FailoverTimeout FailoverReason = "timeout"
FailoverFormat FailoverReason = "format"
FailoverOverloaded FailoverReason = "overloaded"
FailoverUnknown FailoverReason = "unknown"
)
// FailoverError wraps an LLM provider error with classification metadata.
type FailoverError struct {
Reason FailoverReason
Provider string
Model string
Status int
Wrapped error
}
type ToolFunctionDefinition struct {
Name string `json:"name"`
Description string `json:"description"`
Parameters map[string]interface{} `json:"parameters"`
func (e *FailoverError) Error() string {
return fmt.Sprintf("failover(%s): provider=%s model=%s status=%d: %v",
e.Reason, e.Provider, e.Model, e.Status, e.Wrapped)
}
func (e *FailoverError) Unwrap() error {
return e.Wrapped
}
// IsRetriable returns true if this error should trigger fallback to next candidate.
// Non-retriable: Format errors (bad request structure, image dimension/size).
func (e *FailoverError) IsRetriable() bool {
return e.Reason != FailoverFormat
}
// ModelConfig holds primary model and fallback list.
type ModelConfig struct {
Primary string
Fallbacks []string
}
+66
View File
@@ -0,0 +1,66 @@
package routing
import (
"regexp"
"strings"
)
const (
DefaultAgentID = "main"
DefaultMainKey = "main"
DefaultAccountID = "default"
MaxAgentIDLength = 64
)
var (
validIDRe = regexp.MustCompile(`^[a-z0-9][a-z0-9_-]{0,63}$`)
invalidCharsRe = regexp.MustCompile(`[^a-z0-9_-]+`)
leadingDashRe = regexp.MustCompile(`^-+`)
trailingDashRe = regexp.MustCompile(`-+$`)
)
// NormalizeAgentID sanitizes an agent ID to [a-z0-9][a-z0-9_-]{0,63}.
// Invalid characters are collapsed to "-". Leading/trailing dashes stripped.
// Empty input returns DefaultAgentID ("main").
func NormalizeAgentID(id string) string {
trimmed := strings.TrimSpace(id)
if trimmed == "" {
return DefaultAgentID
}
lower := strings.ToLower(trimmed)
if validIDRe.MatchString(lower) {
return lower
}
result := invalidCharsRe.ReplaceAllString(lower, "-")
result = leadingDashRe.ReplaceAllString(result, "")
result = trailingDashRe.ReplaceAllString(result, "")
if len(result) > MaxAgentIDLength {
result = result[:MaxAgentIDLength]
}
if result == "" {
return DefaultAgentID
}
return result
}
// NormalizeAccountID sanitizes an account ID. Empty returns DefaultAccountID.
func NormalizeAccountID(id string) string {
trimmed := strings.TrimSpace(id)
if trimmed == "" {
return DefaultAccountID
}
lower := strings.ToLower(trimmed)
if validIDRe.MatchString(lower) {
return lower
}
result := invalidCharsRe.ReplaceAllString(lower, "-")
result = leadingDashRe.ReplaceAllString(result, "")
result = trailingDashRe.ReplaceAllString(result, "")
if len(result) > MaxAgentIDLength {
result = result[:MaxAgentIDLength]
}
if result == "" {
return DefaultAccountID
}
return result
}
+86
View File
@@ -0,0 +1,86 @@
package routing
import "testing"
func TestNormalizeAgentID_Empty(t *testing.T) {
if got := NormalizeAgentID(""); got != DefaultAgentID {
t.Errorf("NormalizeAgentID('') = %q, want %q", got, DefaultAgentID)
}
}
func TestNormalizeAgentID_Whitespace(t *testing.T) {
if got := NormalizeAgentID(" "); got != DefaultAgentID {
t.Errorf("NormalizeAgentID(' ') = %q, want %q", got, DefaultAgentID)
}
}
func TestNormalizeAgentID_Valid(t *testing.T) {
tests := []struct {
input, want string
}{
{"main", "main"},
{"Main", "main"},
{"SALES", "sales"},
{"support-bot", "support-bot"},
{"agent_1", "agent_1"},
{"a", "a"},
{"0test", "0test"},
}
for _, tt := range tests {
if got := NormalizeAgentID(tt.input); got != tt.want {
t.Errorf("NormalizeAgentID(%q) = %q, want %q", tt.input, got, tt.want)
}
}
}
func TestNormalizeAgentID_InvalidChars(t *testing.T) {
tests := []struct {
input, want string
}{
{"Hello World", "hello-world"},
{"agent@123", "agent-123"},
{"foo.bar.baz", "foo-bar-baz"},
{"--leading", "leading"},
{"--both--", "both"},
}
for _, tt := range tests {
if got := NormalizeAgentID(tt.input); got != tt.want {
t.Errorf("NormalizeAgentID(%q) = %q, want %q", tt.input, got, tt.want)
}
}
}
func TestNormalizeAgentID_AllInvalid(t *testing.T) {
if got := NormalizeAgentID("@@@"); got != DefaultAgentID {
t.Errorf("NormalizeAgentID('@@@') = %q, want %q", got, DefaultAgentID)
}
}
func TestNormalizeAgentID_TruncatesAt64(t *testing.T) {
long := ""
for i := 0; i < 100; i++ {
long += "a"
}
got := NormalizeAgentID(long)
if len(got) > MaxAgentIDLength {
t.Errorf("length = %d, want <= %d", len(got), MaxAgentIDLength)
}
}
func TestNormalizeAccountID_Empty(t *testing.T) {
if got := NormalizeAccountID(""); got != DefaultAccountID {
t.Errorf("NormalizeAccountID('') = %q, want %q", got, DefaultAccountID)
}
}
func TestNormalizeAccountID_Valid(t *testing.T) {
if got := NormalizeAccountID("MyBot"); got != "mybot" {
t.Errorf("NormalizeAccountID('MyBot') = %q, want 'mybot'", got)
}
}
func TestNormalizeAccountID_InvalidChars(t *testing.T) {
if got := NormalizeAccountID("bot@home"); got != "bot-home" {
t.Errorf("NormalizeAccountID('bot@home') = %q, want 'bot-home'", got)
}
}
+252
View File
@@ -0,0 +1,252 @@
package routing
import (
"strings"
"github.com/sipeed/picoclaw/pkg/config"
)
// RouteInput contains the routing context from an inbound message.
type RouteInput struct {
Channel string
AccountID string
Peer *RoutePeer
ParentPeer *RoutePeer
GuildID string
TeamID string
}
// ResolvedRoute is the result of agent routing.
type ResolvedRoute struct {
AgentID string
Channel string
AccountID string
SessionKey string
MainSessionKey string
MatchedBy string // "binding.peer", "binding.peer.parent", "binding.guild", "binding.team", "binding.account", "binding.channel", "default"
}
// RouteResolver determines which agent handles a message based on config bindings.
type RouteResolver struct {
cfg *config.Config
}
// NewRouteResolver creates a new route resolver.
func NewRouteResolver(cfg *config.Config) *RouteResolver {
return &RouteResolver{cfg: cfg}
}
// ResolveRoute determines which agent handles the message and constructs session keys.
// Implements the 7-level priority cascade:
// peer > parent_peer > guild > team > account > channel_wildcard > default
func (r *RouteResolver) ResolveRoute(input RouteInput) ResolvedRoute {
channel := strings.ToLower(strings.TrimSpace(input.Channel))
accountID := NormalizeAccountID(input.AccountID)
peer := input.Peer
dmScope := DMScope(r.cfg.Session.DMScope)
if dmScope == "" {
dmScope = DMScopeMain
}
identityLinks := r.cfg.Session.IdentityLinks
bindings := r.filterBindings(channel, accountID)
choose := func(agentID string, matchedBy string) ResolvedRoute {
resolvedAgentID := r.pickAgentID(agentID)
sessionKey := strings.ToLower(BuildAgentPeerSessionKey(SessionKeyParams{
AgentID: resolvedAgentID,
Channel: channel,
AccountID: accountID,
Peer: peer,
DMScope: dmScope,
IdentityLinks: identityLinks,
}))
mainSessionKey := strings.ToLower(BuildAgentMainSessionKey(resolvedAgentID))
return ResolvedRoute{
AgentID: resolvedAgentID,
Channel: channel,
AccountID: accountID,
SessionKey: sessionKey,
MainSessionKey: mainSessionKey,
MatchedBy: matchedBy,
}
}
// Priority 1: Peer binding
if peer != nil && strings.TrimSpace(peer.ID) != "" {
if match := r.findPeerMatch(bindings, peer); match != nil {
return choose(match.AgentID, "binding.peer")
}
}
// Priority 2: Parent peer binding
parentPeer := input.ParentPeer
if parentPeer != nil && strings.TrimSpace(parentPeer.ID) != "" {
if match := r.findPeerMatch(bindings, parentPeer); match != nil {
return choose(match.AgentID, "binding.peer.parent")
}
}
// Priority 3: Guild binding
guildID := strings.TrimSpace(input.GuildID)
if guildID != "" {
if match := r.findGuildMatch(bindings, guildID); match != nil {
return choose(match.AgentID, "binding.guild")
}
}
// Priority 4: Team binding
teamID := strings.TrimSpace(input.TeamID)
if teamID != "" {
if match := r.findTeamMatch(bindings, teamID); match != nil {
return choose(match.AgentID, "binding.team")
}
}
// Priority 5: Account binding
if match := r.findAccountMatch(bindings); match != nil {
return choose(match.AgentID, "binding.account")
}
// Priority 6: Channel wildcard binding
if match := r.findChannelWildcardMatch(bindings); match != nil {
return choose(match.AgentID, "binding.channel")
}
// Priority 7: Default agent
return choose(r.resolveDefaultAgentID(), "default")
}
func (r *RouteResolver) filterBindings(channel, accountID string) []config.AgentBinding {
var filtered []config.AgentBinding
for _, b := range r.cfg.Bindings {
matchChannel := strings.ToLower(strings.TrimSpace(b.Match.Channel))
if matchChannel == "" || matchChannel != channel {
continue
}
if !matchesAccountID(b.Match.AccountID, accountID) {
continue
}
filtered = append(filtered, b)
}
return filtered
}
func matchesAccountID(matchAccountID, actual string) bool {
trimmed := strings.TrimSpace(matchAccountID)
if trimmed == "" {
return actual == DefaultAccountID
}
if trimmed == "*" {
return true
}
return strings.ToLower(trimmed) == strings.ToLower(actual)
}
func (r *RouteResolver) findPeerMatch(bindings []config.AgentBinding, peer *RoutePeer) *config.AgentBinding {
for i := range bindings {
b := &bindings[i]
if b.Match.Peer == nil {
continue
}
peerKind := strings.ToLower(strings.TrimSpace(b.Match.Peer.Kind))
peerID := strings.TrimSpace(b.Match.Peer.ID)
if peerKind == "" || peerID == "" {
continue
}
if peerKind == strings.ToLower(peer.Kind) && peerID == peer.ID {
return b
}
}
return nil
}
func (r *RouteResolver) findGuildMatch(bindings []config.AgentBinding, guildID string) *config.AgentBinding {
for i := range bindings {
b := &bindings[i]
matchGuild := strings.TrimSpace(b.Match.GuildID)
if matchGuild != "" && matchGuild == guildID {
return &bindings[i]
}
}
return nil
}
func (r *RouteResolver) findTeamMatch(bindings []config.AgentBinding, teamID string) *config.AgentBinding {
for i := range bindings {
b := &bindings[i]
matchTeam := strings.TrimSpace(b.Match.TeamID)
if matchTeam != "" && matchTeam == teamID {
return &bindings[i]
}
}
return nil
}
func (r *RouteResolver) findAccountMatch(bindings []config.AgentBinding) *config.AgentBinding {
for i := range bindings {
b := &bindings[i]
accountID := strings.TrimSpace(b.Match.AccountID)
if accountID == "*" {
continue
}
if b.Match.Peer != nil || b.Match.GuildID != "" || b.Match.TeamID != "" {
continue
}
return &bindings[i]
}
return nil
}
func (r *RouteResolver) findChannelWildcardMatch(bindings []config.AgentBinding) *config.AgentBinding {
for i := range bindings {
b := &bindings[i]
accountID := strings.TrimSpace(b.Match.AccountID)
if accountID != "*" {
continue
}
if b.Match.Peer != nil || b.Match.GuildID != "" || b.Match.TeamID != "" {
continue
}
return &bindings[i]
}
return nil
}
func (r *RouteResolver) pickAgentID(agentID string) string {
trimmed := strings.TrimSpace(agentID)
if trimmed == "" {
return NormalizeAgentID(r.resolveDefaultAgentID())
}
normalized := NormalizeAgentID(trimmed)
agents := r.cfg.Agents.List
if len(agents) == 0 {
return normalized
}
for _, a := range agents {
if NormalizeAgentID(a.ID) == normalized {
return normalized
}
}
return NormalizeAgentID(r.resolveDefaultAgentID())
}
func (r *RouteResolver) resolveDefaultAgentID() string {
agents := r.cfg.Agents.List
if len(agents) == 0 {
return DefaultAgentID
}
for _, a := range agents {
if a.Default {
id := strings.TrimSpace(a.ID)
if id != "" {
return NormalizeAgentID(id)
}
}
}
if id := strings.TrimSpace(agents[0].ID); id != "" {
return NormalizeAgentID(id)
}
return DefaultAgentID
}
+297
View File
@@ -0,0 +1,297 @@
package routing
import (
"testing"
"github.com/sipeed/picoclaw/pkg/config"
)
func testConfig(agents []config.AgentConfig, bindings []config.AgentBinding) *config.Config {
return &config.Config{
Agents: config.AgentsConfig{
Defaults: config.AgentDefaults{
Workspace: "/tmp/picoclaw-test",
Model: "gpt-4",
},
List: agents,
},
Bindings: bindings,
Session: config.SessionConfig{
DMScope: "per-peer",
},
}
}
func TestResolveRoute_DefaultAgent_NoBindings(t *testing.T) {
cfg := testConfig(nil, nil)
r := NewRouteResolver(cfg)
route := r.ResolveRoute(RouteInput{
Channel: "telegram",
Peer: &RoutePeer{Kind: "direct", ID: "user1"},
})
if route.AgentID != DefaultAgentID {
t.Errorf("AgentID = %q, want %q", route.AgentID, DefaultAgentID)
}
if route.MatchedBy != "default" {
t.Errorf("MatchedBy = %q, want 'default'", route.MatchedBy)
}
}
func TestResolveRoute_PeerBinding(t *testing.T) {
agents := []config.AgentConfig{
{ID: "sales", Default: true},
{ID: "support"},
}
bindings := []config.AgentBinding{
{
AgentID: "support",
Match: config.BindingMatch{
Channel: "telegram",
AccountID: "*",
Peer: &config.PeerMatch{Kind: "direct", ID: "user123"},
},
},
}
cfg := testConfig(agents, bindings)
r := NewRouteResolver(cfg)
route := r.ResolveRoute(RouteInput{
Channel: "telegram",
Peer: &RoutePeer{Kind: "direct", ID: "user123"},
})
if route.AgentID != "support" {
t.Errorf("AgentID = %q, want 'support'", route.AgentID)
}
if route.MatchedBy != "binding.peer" {
t.Errorf("MatchedBy = %q, want 'binding.peer'", route.MatchedBy)
}
}
func TestResolveRoute_GuildBinding(t *testing.T) {
agents := []config.AgentConfig{
{ID: "general", Default: true},
{ID: "gaming"},
}
bindings := []config.AgentBinding{
{
AgentID: "gaming",
Match: config.BindingMatch{
Channel: "discord",
AccountID: "*",
GuildID: "guild-abc",
},
},
}
cfg := testConfig(agents, bindings)
r := NewRouteResolver(cfg)
route := r.ResolveRoute(RouteInput{
Channel: "discord",
GuildID: "guild-abc",
Peer: &RoutePeer{Kind: "channel", ID: "ch1"},
})
if route.AgentID != "gaming" {
t.Errorf("AgentID = %q, want 'gaming'", route.AgentID)
}
if route.MatchedBy != "binding.guild" {
t.Errorf("MatchedBy = %q, want 'binding.guild'", route.MatchedBy)
}
}
func TestResolveRoute_TeamBinding(t *testing.T) {
agents := []config.AgentConfig{
{ID: "general", Default: true},
{ID: "work"},
}
bindings := []config.AgentBinding{
{
AgentID: "work",
Match: config.BindingMatch{
Channel: "slack",
AccountID: "*",
TeamID: "T12345",
},
},
}
cfg := testConfig(agents, bindings)
r := NewRouteResolver(cfg)
route := r.ResolveRoute(RouteInput{
Channel: "slack",
TeamID: "T12345",
Peer: &RoutePeer{Kind: "channel", ID: "C001"},
})
if route.AgentID != "work" {
t.Errorf("AgentID = %q, want 'work'", route.AgentID)
}
if route.MatchedBy != "binding.team" {
t.Errorf("MatchedBy = %q, want 'binding.team'", route.MatchedBy)
}
}
func TestResolveRoute_AccountBinding(t *testing.T) {
agents := []config.AgentConfig{
{ID: "default-agent", Default: true},
{ID: "premium"},
}
bindings := []config.AgentBinding{
{
AgentID: "premium",
Match: config.BindingMatch{
Channel: "telegram",
AccountID: "bot2",
},
},
}
cfg := testConfig(agents, bindings)
r := NewRouteResolver(cfg)
route := r.ResolveRoute(RouteInput{
Channel: "telegram",
AccountID: "bot2",
Peer: &RoutePeer{Kind: "direct", ID: "user1"},
})
if route.AgentID != "premium" {
t.Errorf("AgentID = %q, want 'premium'", route.AgentID)
}
if route.MatchedBy != "binding.account" {
t.Errorf("MatchedBy = %q, want 'binding.account'", route.MatchedBy)
}
}
func TestResolveRoute_ChannelWildcard(t *testing.T) {
agents := []config.AgentConfig{
{ID: "main", Default: true},
{ID: "telegram-bot"},
}
bindings := []config.AgentBinding{
{
AgentID: "telegram-bot",
Match: config.BindingMatch{
Channel: "telegram",
AccountID: "*",
},
},
}
cfg := testConfig(agents, bindings)
r := NewRouteResolver(cfg)
route := r.ResolveRoute(RouteInput{
Channel: "telegram",
Peer: &RoutePeer{Kind: "direct", ID: "user1"},
})
if route.AgentID != "telegram-bot" {
t.Errorf("AgentID = %q, want 'telegram-bot'", route.AgentID)
}
if route.MatchedBy != "binding.channel" {
t.Errorf("MatchedBy = %q, want 'binding.channel'", route.MatchedBy)
}
}
func TestResolveRoute_PriorityOrder_PeerBeatsGuild(t *testing.T) {
agents := []config.AgentConfig{
{ID: "general", Default: true},
{ID: "vip"},
{ID: "gaming"},
}
bindings := []config.AgentBinding{
{
AgentID: "vip",
Match: config.BindingMatch{
Channel: "discord",
AccountID: "*",
Peer: &config.PeerMatch{Kind: "direct", ID: "user-vip"},
},
},
{
AgentID: "gaming",
Match: config.BindingMatch{
Channel: "discord",
AccountID: "*",
GuildID: "guild-1",
},
},
}
cfg := testConfig(agents, bindings)
r := NewRouteResolver(cfg)
route := r.ResolveRoute(RouteInput{
Channel: "discord",
GuildID: "guild-1",
Peer: &RoutePeer{Kind: "direct", ID: "user-vip"},
})
if route.AgentID != "vip" {
t.Errorf("AgentID = %q, want 'vip' (peer should beat guild)", route.AgentID)
}
if route.MatchedBy != "binding.peer" {
t.Errorf("MatchedBy = %q, want 'binding.peer'", route.MatchedBy)
}
}
func TestResolveRoute_InvalidAgentFallsToDefault(t *testing.T) {
agents := []config.AgentConfig{
{ID: "main", Default: true},
}
bindings := []config.AgentBinding{
{
AgentID: "nonexistent",
Match: config.BindingMatch{
Channel: "telegram",
AccountID: "*",
},
},
}
cfg := testConfig(agents, bindings)
r := NewRouteResolver(cfg)
route := r.ResolveRoute(RouteInput{
Channel: "telegram",
})
if route.AgentID != "main" {
t.Errorf("AgentID = %q, want 'main' (invalid agent should fall to default)", route.AgentID)
}
}
func TestResolveRoute_DefaultAgentSelection(t *testing.T) {
agents := []config.AgentConfig{
{ID: "alpha"},
{ID: "beta", Default: true},
{ID: "gamma"},
}
cfg := testConfig(agents, nil)
r := NewRouteResolver(cfg)
route := r.ResolveRoute(RouteInput{
Channel: "cli",
})
if route.AgentID != "beta" {
t.Errorf("AgentID = %q, want 'beta' (marked as default)", route.AgentID)
}
}
func TestResolveRoute_NoDefaultUsesFirst(t *testing.T) {
agents := []config.AgentConfig{
{ID: "alpha"},
{ID: "beta"},
}
cfg := testConfig(agents, nil)
r := NewRouteResolver(cfg)
route := r.ResolveRoute(RouteInput{
Channel: "cli",
})
if route.AgentID != "alpha" {
t.Errorf("AgentID = %q, want 'alpha' (first in list)", route.AgentID)
}
}
+183
View File
@@ -0,0 +1,183 @@
package routing
import (
"fmt"
"strings"
)
// DMScope controls DM session isolation granularity.
type DMScope string
const (
DMScopeMain DMScope = "main"
DMScopePerPeer DMScope = "per-peer"
DMScopePerChannelPeer DMScope = "per-channel-peer"
DMScopePerAccountChannelPeer DMScope = "per-account-channel-peer"
)
// RoutePeer represents a chat peer with kind and ID.
type RoutePeer struct {
Kind string // "direct", "group", "channel"
ID string
}
// SessionKeyParams holds all inputs for session key construction.
type SessionKeyParams struct {
AgentID string
Channel string
AccountID string
Peer *RoutePeer
DMScope DMScope
IdentityLinks map[string][]string
}
// ParsedSessionKey is the result of parsing an agent-scoped session key.
type ParsedSessionKey struct {
AgentID string
Rest string
}
// BuildAgentMainSessionKey returns "agent:<agentId>:main".
func BuildAgentMainSessionKey(agentID string) string {
return fmt.Sprintf("agent:%s:%s", NormalizeAgentID(agentID), DefaultMainKey)
}
// BuildAgentPeerSessionKey constructs a session key based on agent, channel, peer, and DM scope.
func BuildAgentPeerSessionKey(params SessionKeyParams) string {
agentID := NormalizeAgentID(params.AgentID)
peer := params.Peer
if peer == nil {
peer = &RoutePeer{Kind: "direct"}
}
peerKind := strings.TrimSpace(peer.Kind)
if peerKind == "" {
peerKind = "direct"
}
if peerKind == "direct" {
dmScope := params.DMScope
if dmScope == "" {
dmScope = DMScopeMain
}
peerID := strings.TrimSpace(peer.ID)
// Resolve identity links (cross-platform collapse)
if dmScope != DMScopeMain && peerID != "" {
if linked := resolveLinkedPeerID(params.IdentityLinks, params.Channel, peerID); linked != "" {
peerID = linked
}
}
peerID = strings.ToLower(peerID)
switch dmScope {
case DMScopePerAccountChannelPeer:
if peerID != "" {
channel := normalizeChannel(params.Channel)
accountID := NormalizeAccountID(params.AccountID)
return fmt.Sprintf("agent:%s:%s:%s:direct:%s", agentID, channel, accountID, peerID)
}
case DMScopePerChannelPeer:
if peerID != "" {
channel := normalizeChannel(params.Channel)
return fmt.Sprintf("agent:%s:%s:direct:%s", agentID, channel, peerID)
}
case DMScopePerPeer:
if peerID != "" {
return fmt.Sprintf("agent:%s:direct:%s", agentID, peerID)
}
}
return BuildAgentMainSessionKey(agentID)
}
// Group/channel peers always get per-peer sessions
channel := normalizeChannel(params.Channel)
peerID := strings.ToLower(strings.TrimSpace(peer.ID))
if peerID == "" {
peerID = "unknown"
}
return fmt.Sprintf("agent:%s:%s:%s:%s", agentID, channel, peerKind, peerID)
}
// ParseAgentSessionKey extracts agentId and rest from "agent:<agentId>:<rest>".
func ParseAgentSessionKey(sessionKey string) *ParsedSessionKey {
raw := strings.TrimSpace(sessionKey)
if raw == "" {
return nil
}
parts := strings.SplitN(raw, ":", 3)
if len(parts) < 3 {
return nil
}
if parts[0] != "agent" {
return nil
}
agentID := strings.TrimSpace(parts[1])
rest := parts[2]
if agentID == "" || rest == "" {
return nil
}
return &ParsedSessionKey{AgentID: agentID, Rest: rest}
}
// IsSubagentSessionKey returns true if the session key represents a subagent.
func IsSubagentSessionKey(sessionKey string) bool {
raw := strings.TrimSpace(sessionKey)
if raw == "" {
return false
}
if strings.HasPrefix(strings.ToLower(raw), "subagent:") {
return true
}
parsed := ParseAgentSessionKey(raw)
if parsed == nil {
return false
}
return strings.HasPrefix(strings.ToLower(parsed.Rest), "subagent:")
}
func normalizeChannel(channel string) string {
c := strings.TrimSpace(strings.ToLower(channel))
if c == "" {
return "unknown"
}
return c
}
func resolveLinkedPeerID(identityLinks map[string][]string, channel, peerID string) string {
if len(identityLinks) == 0 {
return ""
}
peerID = strings.TrimSpace(peerID)
if peerID == "" {
return ""
}
candidates := make(map[string]bool)
rawCandidate := strings.ToLower(peerID)
if rawCandidate != "" {
candidates[rawCandidate] = true
}
channel = strings.ToLower(strings.TrimSpace(channel))
if channel != "" {
scopedCandidate := fmt.Sprintf("%s:%s", channel, strings.ToLower(peerID))
candidates[scopedCandidate] = true
}
if len(candidates) == 0 {
return ""
}
for canonical, ids := range identityLinks {
canonicalName := strings.TrimSpace(canonical)
if canonicalName == "" {
continue
}
for _, id := range ids {
normalized := strings.ToLower(strings.TrimSpace(id))
if normalized != "" && candidates[normalized] {
return canonicalName
}
}
}
return ""
}
+162
View File
@@ -0,0 +1,162 @@
package routing
import "testing"
func TestBuildAgentMainSessionKey(t *testing.T) {
got := BuildAgentMainSessionKey("sales")
want := "agent:sales:main"
if got != want {
t.Errorf("BuildAgentMainSessionKey('sales') = %q, want %q", got, want)
}
}
func TestBuildAgentMainSessionKey_Normalizes(t *testing.T) {
got := BuildAgentMainSessionKey("Sales Bot")
want := "agent:sales-bot:main"
if got != want {
t.Errorf("BuildAgentMainSessionKey('Sales Bot') = %q, want %q", got, want)
}
}
func TestBuildAgentPeerSessionKey_DMScopeMain(t *testing.T) {
got := BuildAgentPeerSessionKey(SessionKeyParams{
AgentID: "main",
Channel: "telegram",
Peer: &RoutePeer{Kind: "direct", ID: "user123"},
DMScope: DMScopeMain,
})
want := "agent:main:main"
if got != want {
t.Errorf("DMScopeMain = %q, want %q", got, want)
}
}
func TestBuildAgentPeerSessionKey_DMScopePerPeer(t *testing.T) {
got := BuildAgentPeerSessionKey(SessionKeyParams{
AgentID: "main",
Channel: "telegram",
Peer: &RoutePeer{Kind: "direct", ID: "user123"},
DMScope: DMScopePerPeer,
})
want := "agent:main:direct:user123"
if got != want {
t.Errorf("DMScopePerPeer = %q, want %q", got, want)
}
}
func TestBuildAgentPeerSessionKey_DMScopePerChannelPeer(t *testing.T) {
got := BuildAgentPeerSessionKey(SessionKeyParams{
AgentID: "main",
Channel: "telegram",
Peer: &RoutePeer{Kind: "direct", ID: "user123"},
DMScope: DMScopePerChannelPeer,
})
want := "agent:main:telegram:direct:user123"
if got != want {
t.Errorf("DMScopePerChannelPeer = %q, want %q", got, want)
}
}
func TestBuildAgentPeerSessionKey_DMScopePerAccountChannelPeer(t *testing.T) {
got := BuildAgentPeerSessionKey(SessionKeyParams{
AgentID: "main",
Channel: "telegram",
AccountID: "bot1",
Peer: &RoutePeer{Kind: "direct", ID: "User123"},
DMScope: DMScopePerAccountChannelPeer,
})
want := "agent:main:telegram:bot1:direct:user123"
if got != want {
t.Errorf("DMScopePerAccountChannelPeer = %q, want %q", got, want)
}
}
func TestBuildAgentPeerSessionKey_GroupPeer(t *testing.T) {
got := BuildAgentPeerSessionKey(SessionKeyParams{
AgentID: "main",
Channel: "telegram",
Peer: &RoutePeer{Kind: "group", ID: "chat456"},
DMScope: DMScopePerPeer,
})
want := "agent:main:telegram:group:chat456"
if got != want {
t.Errorf("GroupPeer = %q, want %q", got, want)
}
}
func TestBuildAgentPeerSessionKey_NilPeer(t *testing.T) {
got := BuildAgentPeerSessionKey(SessionKeyParams{
AgentID: "main",
Channel: "telegram",
Peer: nil,
DMScope: DMScopePerPeer,
})
// nil peer defaults to direct with empty ID, falls to main
want := "agent:main:main"
if got != want {
t.Errorf("NilPeer = %q, want %q", got, want)
}
}
func TestBuildAgentPeerSessionKey_IdentityLink(t *testing.T) {
links := map[string][]string{
"john": {"telegram:user123", "discord:john#1234"},
}
got := BuildAgentPeerSessionKey(SessionKeyParams{
AgentID: "main",
Channel: "telegram",
Peer: &RoutePeer{Kind: "direct", ID: "user123"},
DMScope: DMScopePerPeer,
IdentityLinks: links,
})
want := "agent:main:direct:john"
if got != want {
t.Errorf("IdentityLink = %q, want %q", got, want)
}
}
func TestParseAgentSessionKey_Valid(t *testing.T) {
parsed := ParseAgentSessionKey("agent:sales:telegram:direct:user123")
if parsed == nil {
t.Fatal("expected non-nil result")
}
if parsed.AgentID != "sales" {
t.Errorf("AgentID = %q, want 'sales'", parsed.AgentID)
}
if parsed.Rest != "telegram:direct:user123" {
t.Errorf("Rest = %q, want 'telegram:direct:user123'", parsed.Rest)
}
}
func TestParseAgentSessionKey_Invalid(t *testing.T) {
tests := []string{
"",
"foo:bar",
"notprefix:sales:main",
"agent::main",
"agent:sales:",
}
for _, input := range tests {
if got := ParseAgentSessionKey(input); got != nil {
t.Errorf("ParseAgentSessionKey(%q) = %+v, want nil", input, got)
}
}
}
func TestIsSubagentSessionKey(t *testing.T) {
tests := []struct {
input string
want bool
}{
{"subagent:task-1", true},
{"agent:main:subagent:task-1", true},
{"agent:main:main", false},
{"agent:main:telegram:direct:user123", false},
{"", false},
}
for _, tt := range tests {
if got := IsSubagentSessionKey(tt.input); got != tt.want {
t.Errorf("IsSubagentSessionKey(%q) = %v, want %v", tt.input, got, tt.want)
}
}
}
-53
View File
@@ -8,7 +8,6 @@ import (
"net/http"
"os"
"path/filepath"
"strings"
"time"
)
@@ -24,12 +23,6 @@ type AvailableSkill struct {
Tags []string `json:"tags"`
}
type BuiltinSkill struct {
Name string `json:"name"`
Path string `json:"path"`
Enabled bool `json:"enabled"`
}
func NewSkillInstaller(workspace string) *SkillInstaller {
return &SkillInstaller{
workspace: workspace,
@@ -123,49 +116,3 @@ func (si *SkillInstaller) ListAvailableSkills(ctx context.Context) ([]AvailableS
return skills, nil
}
func (si *SkillInstaller) ListBuiltinSkills() []BuiltinSkill {
builtinSkillsDir := filepath.Join(filepath.Dir(si.workspace), "picoclaw", "skills")
entries, err := os.ReadDir(builtinSkillsDir)
if err != nil {
return nil
}
var skills []BuiltinSkill
for _, entry := range entries {
if entry.IsDir() {
_ = entry
skillName := entry.Name()
skillFile := filepath.Join(builtinSkillsDir, skillName, "SKILL.md")
data, err := os.ReadFile(skillFile)
description := ""
if err == nil {
content := string(data)
if idx := strings.Index(content, "\n"); idx > 0 {
firstLine := content[:idx]
if strings.Contains(firstLine, "description:") {
descLine := strings.Index(content[idx:], "\n")
if descLine > 0 {
description = strings.TrimSpace(content[idx+descLine : idx+descLine])
}
}
}
}
// skill := BuiltinSkill{
// Name: skillName,
// Path: description,
// Enabled: true,
// }
status := "✓"
fmt.Printf(" %s %s\n", status, entry.Name())
if description != "" {
fmt.Printf(" %s\n", description)
}
}
}
return skills
}
+22 -5
View File
@@ -9,6 +9,8 @@ import (
"path/filepath"
"regexp"
"strings"
"github.com/sipeed/picoclaw/pkg/logger"
)
var namePattern = regexp.MustCompile(`^[a-zA-Z0-9]+(-[a-zA-Z0-9]+)*$`)
@@ -251,6 +253,11 @@ func (sl *SkillsLoader) BuildSkillsSummary() string {
func (sl *SkillsLoader) getSkillMetadata(skillPath string) *SkillMetadata {
content, err := os.ReadFile(skillPath)
if err != nil {
logger.WarnCF("skills", "Failed to read skill metadata",
map[string]interface{}{
"skill_path": skillPath,
"error": err.Error(),
})
return nil
}
@@ -283,10 +290,15 @@ func (sl *SkillsLoader) getSkillMetadata(skillPath string) *SkillMetadata {
// parseSimpleYAML parses simple key: value YAML format
// Example: name: github\n description: "..."
// Normalizes line endings to handle \n (Unix), \r\n (Windows), and \r (classic Mac)
func (sl *SkillsLoader) parseSimpleYAML(content string) map[string]string {
result := make(map[string]string)
for _, line := range strings.Split(content, "\n") {
// Normalize line endings: convert \r\n and \r to \n
normalized := strings.ReplaceAll(content, "\r\n", "\n")
normalized = strings.ReplaceAll(normalized, "\r", "\n")
for _, line := range strings.Split(normalized, "\n") {
line = strings.TrimSpace(line)
if line == "" || strings.HasPrefix(line, "#") {
continue
@@ -306,9 +318,10 @@ func (sl *SkillsLoader) parseSimpleYAML(content string) map[string]string {
}
func (sl *SkillsLoader) extractFrontmatter(content string) string {
// (?s) enables DOTALL mode so . matches newlines
// Match first ---, capture everything until next --- on its own line
re := regexp.MustCompile(`(?s)^---\n(.*)\n---`)
// Support \n (Unix), \r\n (Windows), and \r (classic Mac) line endings for frontmatter blocks
// (?s) enables DOTALL so . matches newlines;
// ^--- at start, then ... --- at start of line, honoring all three line ending types
re := regexp.MustCompile(`(?s)^---(?:\r\n|\n|\r)(.*?)(?:\r\n|\n|\r)---`)
match := re.FindStringSubmatch(content)
if len(match) > 1 {
return match[1]
@@ -317,7 +330,11 @@ func (sl *SkillsLoader) extractFrontmatter(content string) string {
}
func (sl *SkillsLoader) stripFrontmatter(content string) string {
re := regexp.MustCompile(`^---\n.*?\n---\n`)
// Support \n (Unix), \r\n (Windows), and \r (classic Mac) line endings for frontmatter blocks
// (?s) enables DOTALL so . matches newlines;
// ^--- at start, then ... --- at start of line, honoring all three line ending types
// Match zero or more trailing line endings after closing --- (handles both with and without blank lines)
re := regexp.MustCompile(`(?s)^---(?:\r\n|\n|\r)(.*?)(?:\r\n|\n|\r)---(?:\r\n|\n|\r)*`)
return re.ReplaceAllString(content, "")
}
+102
View File
@@ -75,3 +75,105 @@ func TestSkillsInfoValidate(t *testing.T) {
})
}
}
func TestExtractFrontmatter(t *testing.T) {
sl := &SkillsLoader{}
testcases := []struct {
name string
content string
expectedName string
expectedDesc string
lineEndingType string
}{
{
name: "unix-line-endings",
lineEndingType: "Unix (\\n)",
content: "---\nname: test-skill\ndescription: A test skill\n---\n\n# Skill Content",
expectedName: "test-skill",
expectedDesc: "A test skill",
},
{
name: "windows-line-endings",
lineEndingType: "Windows (\\r\\n)",
content: "---\r\nname: test-skill\r\ndescription: A test skill\r\n---\r\n\r\n# Skill Content",
expectedName: "test-skill",
expectedDesc: "A test skill",
},
{
name: "classic-mac-line-endings",
lineEndingType: "Classic Mac (\\r)",
content: "---\rname: test-skill\rdescription: A test skill\r---\r\r# Skill Content",
expectedName: "test-skill",
expectedDesc: "A test skill",
},
}
for _, tc := range testcases {
t.Run(tc.name, func(t *testing.T) {
// Extract frontmatter
frontmatter := sl.extractFrontmatter(tc.content)
assert.NotEmpty(t, frontmatter, "Frontmatter should be extracted for %s line endings", tc.lineEndingType)
// Parse YAML to get name and description (parseSimpleYAML now handles all line ending types)
yamlMeta := sl.parseSimpleYAML(frontmatter)
assert.Equal(t, tc.expectedName, yamlMeta["name"], "Name should be correctly parsed from frontmatter with %s line endings", tc.lineEndingType)
assert.Equal(t, tc.expectedDesc, yamlMeta["description"], "Description should be correctly parsed from frontmatter with %s line endings", tc.lineEndingType)
})
}
}
func TestStripFrontmatter(t *testing.T) {
sl := &SkillsLoader{}
testcases := []struct {
name string
content string
expectedContent string
lineEndingType string
}{
{
name: "unix-line-endings",
lineEndingType: "Unix (\\n)",
content: "---\nname: test-skill\ndescription: A test skill\n---\n\n# Skill Content",
expectedContent: "# Skill Content",
},
{
name: "windows-line-endings",
lineEndingType: "Windows (\\r\\n)",
content: "---\r\nname: test-skill\r\ndescription: A test skill\r\n---\r\n\r\n# Skill Content",
expectedContent: "# Skill Content",
},
{
name: "classic-mac-line-endings",
lineEndingType: "Classic Mac (\\r)",
content: "---\rname: test-skill\rdescription: A test skill\r---\r\r# Skill Content",
expectedContent: "# Skill Content",
},
{
name: "unix-line-endings-without-trailing-newline",
lineEndingType: "Unix (\\n) without trailing newline",
content: "---\nname: test-skill\ndescription: A test skill\n---\n# Skill Content",
expectedContent: "# Skill Content",
},
{
name: "windows-line-endings-without-trailing-newline",
lineEndingType: "Windows (\\r\\n) without trailing newline",
content: "---\r\nname: test-skill\r\ndescription: A test skill\r\n---\r\n# Skill Content",
expectedContent: "# Skill Content",
},
{
name: "no-frontmatter",
lineEndingType: "No frontmatter",
content: "# Skill Content\n\nSome content here.",
expectedContent: "# Skill Content\n\nSome content here.",
},
}
for _, tc := range testcases {
t.Run(tc.name, func(t *testing.T) {
result := sl.stripFrontmatter(tc.content)
assert.Equal(t, tc.expectedContent, result, "Frontmatter should be stripped correctly for %s", tc.lineEndingType)
})
}
}
+4 -3
View File
@@ -7,6 +7,7 @@ import (
"time"
"github.com/sipeed/picoclaw/pkg/bus"
"github.com/sipeed/picoclaw/pkg/config"
"github.com/sipeed/picoclaw/pkg/cron"
"github.com/sipeed/picoclaw/pkg/utils"
)
@@ -29,9 +30,9 @@ type CronTool struct {
// NewCronTool creates a new CronTool
// execTimeout: 0 means no timeout, >0 sets the timeout duration
func NewCronTool(cronService *cron.CronService, executor JobExecutor, msgBus *bus.MessageBus, workspace string, restrict bool, execTimeout time.Duration) *CronTool {
execTool := NewExecTool(workspace, restrict)
execTool.SetTimeout(execTimeout) // 0 means no timeout
func NewCronTool(cronService *cron.CronService, executor JobExecutor, msgBus *bus.MessageBus, workspace string, restrict bool, execTimeout time.Duration, config *config.Config) *CronTool {
execTool := NewExecToolWithConfig(workspace, restrict, config)
execTool.SetTimeout(execTimeout)
return &CronTool{
cronService: cronService,
executor: executor,
+77 -9
View File
@@ -11,6 +11,8 @@ import (
"runtime"
"strings"
"time"
"github.com/sipeed/picoclaw/pkg/config"
)
type ExecTool struct {
@@ -21,16 +23,82 @@ type ExecTool struct {
restrictToWorkspace bool
}
var defaultDenyPatterns = []*regexp.Regexp{
regexp.MustCompile(`\brm\s+-[rf]{1,2}\b`),
regexp.MustCompile(`\bdel\s+/[fq]\b`),
regexp.MustCompile(`\brmdir\s+/s\b`),
regexp.MustCompile(`\b(format|mkfs|diskpart)\b\s`), // Match disk wiping commands (must be followed by space/args)
regexp.MustCompile(`\bdd\s+if=`),
regexp.MustCompile(`>\s*/dev/sd[a-z]\b`), // Block writes to disk devices (but allow /dev/null)
regexp.MustCompile(`\b(shutdown|reboot|poweroff)\b`),
regexp.MustCompile(`:\(\)\s*\{.*\};\s*:`),
regexp.MustCompile(`\$\([^)]+\)`),
regexp.MustCompile(`\$\{[^}]+\}`),
regexp.MustCompile("`[^`]+`"),
regexp.MustCompile(`\|\s*sh\b`),
regexp.MustCompile(`\|\s*bash\b`),
regexp.MustCompile(`;\s*rm\s+-[rf]`),
regexp.MustCompile(`&&\s*rm\s+-[rf]`),
regexp.MustCompile(`\|\|\s*rm\s+-[rf]`),
regexp.MustCompile(`>\s*/dev/null\s*>&?\s*\d?`),
regexp.MustCompile(`<<\s*EOF`),
regexp.MustCompile(`\$\(\s*cat\s+`),
regexp.MustCompile(`\$\(\s*curl\s+`),
regexp.MustCompile(`\$\(\s*wget\s+`),
regexp.MustCompile(`\$\(\s*which\s+`),
regexp.MustCompile(`\bsudo\b`),
regexp.MustCompile(`\bchmod\s+[0-7]{3,4}\b`),
regexp.MustCompile(`\bchown\b`),
regexp.MustCompile(`\bpkill\b`),
regexp.MustCompile(`\bkillall\b`),
regexp.MustCompile(`\bkill\s+-[9]\b`),
regexp.MustCompile(`\bcurl\b.*\|\s*(sh|bash)`),
regexp.MustCompile(`\bwget\b.*\|\s*(sh|bash)`),
regexp.MustCompile(`\bnpm\s+install\s+-g\b`),
regexp.MustCompile(`\bpip\s+install\s+--user\b`),
regexp.MustCompile(`\bapt\s+(install|remove|purge)\b`),
regexp.MustCompile(`\byum\s+(install|remove)\b`),
regexp.MustCompile(`\bdnf\s+(install|remove)\b`),
regexp.MustCompile(`\bdocker\s+run\b`),
regexp.MustCompile(`\bdocker\s+exec\b`),
regexp.MustCompile(`\bgit\s+push\b`),
regexp.MustCompile(`\bgit\s+force\b`),
regexp.MustCompile(`\bssh\b.*@`),
regexp.MustCompile(`\beval\b`),
regexp.MustCompile(`\bsource\s+.*\.sh\b`),
}
func NewExecTool(workingDir string, restrict bool) *ExecTool {
denyPatterns := []*regexp.Regexp{
regexp.MustCompile(`\brm\s+-[rf]{1,2}\b`),
regexp.MustCompile(`\bdel\s+/[fq]\b`),
regexp.MustCompile(`\brmdir\s+/s\b`),
regexp.MustCompile(`\b(format|mkfs|diskpart)\b\s`), // Match disk wiping commands (must be followed by space/args)
regexp.MustCompile(`\bdd\s+if=`),
regexp.MustCompile(`>\s*/dev/sd[a-z]\b`), // Block writes to disk devices (but allow /dev/null)
regexp.MustCompile(`\b(shutdown|reboot|poweroff)\b`),
regexp.MustCompile(`:\(\)\s*\{.*\};\s*:`),
return NewExecToolWithConfig(workingDir, restrict, nil)
}
func NewExecToolWithConfig(workingDir string, restrict bool, config *config.Config) *ExecTool {
denyPatterns := make([]*regexp.Regexp, 0)
enableDenyPatterns := true
if config != nil {
execConfig := config.Tools.Exec
enableDenyPatterns = execConfig.EnableDenyPatterns
if enableDenyPatterns {
if len(execConfig.CustomDenyPatterns) > 0 {
fmt.Printf("Using custom deny patterns: %v\n", execConfig.CustomDenyPatterns)
for _, pattern := range execConfig.CustomDenyPatterns {
re, err := regexp.Compile(pattern)
if err != nil {
fmt.Printf("Invalid custom deny pattern %q: %v\n", pattern, err)
continue
}
denyPatterns = append(denyPatterns, re)
}
} else {
denyPatterns = append(denyPatterns, defaultDenyPatterns...)
}
} else {
// If deny patterns are disabled, we won't add any patterns, allowing all commands.
fmt.Println("Warning: deny patterns are disabled. All commands will be allowed.")
}
} else {
denyPatterns = append(denyPatterns, defaultDenyPatterns...)
}
return &ExecTool{
+22 -5
View File
@@ -6,10 +6,11 @@ import (
)
type SpawnTool struct {
manager *SubagentManager
originChannel string
originChatID string
callback AsyncCallback // For async completion notification
manager *SubagentManager
originChannel string
originChatID string
allowlistCheck func(targetAgentID string) bool
callback AsyncCallback // For async completion notification
}
func NewSpawnTool(manager *SubagentManager) *SpawnTool {
@@ -45,6 +46,10 @@ func (t *SpawnTool) Parameters() map[string]interface{} {
"type": "string",
"description": "Optional short label for the task (for display)",
},
"agent_id": map[string]interface{}{
"type": "string",
"description": "Optional target agent ID to delegate the task to",
},
},
"required": []string{"task"},
}
@@ -55,6 +60,10 @@ func (t *SpawnTool) SetContext(channel, chatID string) {
t.originChatID = chatID
}
func (t *SpawnTool) SetAllowlistChecker(check func(targetAgentID string) bool) {
t.allowlistCheck = check
}
func (t *SpawnTool) Execute(ctx context.Context, args map[string]interface{}) *ToolResult {
task, ok := args["task"].(string)
if !ok {
@@ -62,13 +71,21 @@ func (t *SpawnTool) Execute(ctx context.Context, args map[string]interface{}) *T
}
label, _ := args["label"].(string)
agentID, _ := args["agent_id"].(string)
// Check allowlist if targeting a specific agent
if agentID != "" && t.allowlistCheck != nil {
if !t.allowlistCheck(agentID) {
return ErrorResult(fmt.Sprintf("not allowed to spawn agent '%s'", agentID))
}
}
if t.manager == nil {
return ErrorResult("Subagent manager not configured")
}
// Pass callback to manager for async completion notification
result, err := t.manager.Spawn(ctx, task, label, t.originChannel, t.originChatID, t.callback)
result, err := t.manager.Spawn(ctx, task, label, agentID, t.originChannel, t.originChatID, t.callback)
if err != nil {
return ErrorResult(fmt.Sprintf("failed to spawn subagent: %v", err))
}
+3 -1
View File
@@ -14,6 +14,7 @@ type SubagentTask struct {
ID string
Task string
Label string
AgentID string
OriginChannel string
OriginChatID string
Status string
@@ -61,7 +62,7 @@ func (sm *SubagentManager) RegisterTool(tool Tool) {
sm.tools.Register(tool)
}
func (sm *SubagentManager) Spawn(ctx context.Context, task, label, originChannel, originChatID string, callback AsyncCallback) (string, error) {
func (sm *SubagentManager) Spawn(ctx context.Context, task, label, agentID, originChannel, originChatID string, callback AsyncCallback) (string, error) {
sm.mu.Lock()
defer sm.mu.Unlock()
@@ -72,6 +73,7 @@ func (sm *SubagentManager) Spawn(ctx context.Context, task, label, originChannel
ID: taskID,
Task: task,
Label: label,
AgentID: agentID,
OriginChannel: originChannel,
OriginChatID: originChatID,
Status: "running",
+1
View File
@@ -116,6 +116,7 @@ func RunToolLoop(ctx context.Context, config ToolLoopConfig, messages []provider
Name: tc.Name,
Arguments: string(argumentsJSON),
},
Name: tc.Name,
})
}
messages = append(messages, assistantMsg)
+4 -2
View File
@@ -492,8 +492,10 @@ func (t *WebFetchTool) extractText(htmlContent string) string {
result = strings.TrimSpace(result)
re = regexp.MustCompile(`\s+`)
result = re.ReplaceAllLiteralString(result, " ")
re = regexp.MustCompile(`[^\S\n]+`)
result = re.ReplaceAllString(result, " ")
re = regexp.MustCompile(`\n{3,}`)
result = re.ReplaceAllString(result, "\n\n")
lines := strings.Split(result, "\n")
var cleanLines []string
+74
View File
@@ -234,6 +234,80 @@ func TestWebTool_WebFetch_HTMLExtraction(t *testing.T) {
}
}
// TestWebFetchTool_extractText verifies text extraction preserves newlines
func TestWebFetchTool_extractText(t *testing.T) {
tool := &WebFetchTool{}
tests := []struct {
name string
input string
wantFunc func(t *testing.T, got string)
}{
{
name: "preserves newlines between block elements",
input: "<html><body><h1>Title</h1>\n<p>Paragraph 1</p>\n<p>Paragraph 2</p></body></html>",
wantFunc: func(t *testing.T, got string) {
lines := strings.Split(got, "\n")
if len(lines) < 2 {
t.Errorf("Expected multiple lines, got %d: %q", len(lines), got)
}
if !strings.Contains(got, "Title") || !strings.Contains(got, "Paragraph 1") || !strings.Contains(got, "Paragraph 2") {
t.Errorf("Missing expected text: %q", got)
}
},
},
{
name: "removes script and style tags",
input: "<script>alert('x');</script><style>body{}</style><p>Keep this</p>",
wantFunc: func(t *testing.T, got string) {
if strings.Contains(got, "alert") || strings.Contains(got, "body{}") {
t.Errorf("Expected script/style content removed, got: %q", got)
}
if !strings.Contains(got, "Keep this") {
t.Errorf("Expected 'Keep this' to remain, got: %q", got)
}
},
},
{
name: "collapses excessive blank lines",
input: "<p>A</p>\n\n\n\n\n<p>B</p>",
wantFunc: func(t *testing.T, got string) {
if strings.Contains(got, "\n\n\n") {
t.Errorf("Expected excessive blank lines collapsed, got: %q", got)
}
},
},
{
name: "collapses horizontal whitespace",
input: "<p>hello world</p>",
wantFunc: func(t *testing.T, got string) {
if strings.Contains(got, " ") {
t.Errorf("Expected spaces collapsed, got: %q", got)
}
if !strings.Contains(got, "hello world") {
t.Errorf("Expected 'hello world', got: %q", got)
}
},
},
{
name: "empty input",
input: "",
wantFunc: func(t *testing.T, got string) {
if got != "" {
t.Errorf("Expected empty string, got: %q", got)
}
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := tool.extractText(tt.input)
tt.wantFunc(t, got)
})
}
}
// TestWebTool_WebFetch_MissingDomain verifies error handling for URL without domain
func TestWebTool_WebFetch_MissingDomain(t *testing.T) {
tool := NewWebFetchTool(50000)
+179
View File
@@ -0,0 +1,179 @@
package utils
import (
"strings"
)
// SplitMessage splits long messages into chunks, preserving code block integrity.
// The function reserves a buffer (10% of maxLen, min 50) to leave room for closing code blocks,
// but may extend to maxLen when needed.
// Call SplitMessage with the full text content and the maximum allowed length of a single message;
// it returns a slice of message chunks that each respect maxLen and avoid splitting fenced code blocks.
func SplitMessage(content string, maxLen int) []string {
var messages []string
// Dynamic buffer: 10% of maxLen, but at least 50 chars if possible
codeBlockBuffer := maxLen / 10
if codeBlockBuffer < 50 {
codeBlockBuffer = 50
}
if codeBlockBuffer > maxLen/2 {
codeBlockBuffer = maxLen / 2
}
for len(content) > 0 {
if len(content) <= maxLen {
messages = append(messages, content)
break
}
// Effective split point: maxLen minus buffer, to leave room for code blocks
effectiveLimit := maxLen - codeBlockBuffer
if effectiveLimit < maxLen/2 {
effectiveLimit = maxLen / 2
}
// Find natural split point within the effective limit
msgEnd := findLastNewline(content[:effectiveLimit], 200)
if msgEnd <= 0 {
msgEnd = findLastSpace(content[:effectiveLimit], 100)
}
if msgEnd <= 0 {
msgEnd = effectiveLimit
}
// Check if this would end with an incomplete code block
candidate := content[:msgEnd]
unclosedIdx := findLastUnclosedCodeBlock(candidate)
if unclosedIdx >= 0 {
// Message would end with incomplete code block
// Try to extend up to maxLen to include the closing ```
if len(content) > msgEnd {
closingIdx := findNextClosingCodeBlock(content, msgEnd)
if closingIdx > 0 && closingIdx <= maxLen {
// Extend to include the closing ```
msgEnd = closingIdx
} else {
// Code block is too long to fit in one chunk or missing closing fence.
// Try to split inside by injecting closing and reopening fences.
headerEnd := strings.Index(content[unclosedIdx:], "\n")
if headerEnd == -1 {
headerEnd = unclosedIdx + 3
} else {
headerEnd += unclosedIdx
}
header := strings.TrimSpace(content[unclosedIdx:headerEnd])
// If we have a reasonable amount of content after the header, split inside
if msgEnd > headerEnd+20 {
// Find a better split point closer to maxLen
innerLimit := maxLen - 5 // Leave room for "\n```"
betterEnd := findLastNewline(content[:innerLimit], 200)
if betterEnd > headerEnd {
msgEnd = betterEnd
} else {
msgEnd = innerLimit
}
messages = append(messages, strings.TrimRight(content[:msgEnd], " \t\n\r")+"\n```")
content = strings.TrimSpace(header + "\n" + content[msgEnd:])
continue
}
// Otherwise, try to split before the code block starts
newEnd := findLastNewline(content[:unclosedIdx], 200)
if newEnd <= 0 {
newEnd = findLastSpace(content[:unclosedIdx], 100)
}
if newEnd > 0 {
msgEnd = newEnd
} else {
// If we can't split before, we MUST split inside (last resort)
if unclosedIdx > 20 {
msgEnd = unclosedIdx
} else {
msgEnd = maxLen - 5
messages = append(messages, strings.TrimRight(content[:msgEnd], " \t\n\r")+"\n```")
content = strings.TrimSpace(header + "\n" + content[msgEnd:])
continue
}
}
}
}
}
if msgEnd <= 0 {
msgEnd = effectiveLimit
}
messages = append(messages, content[:msgEnd])
content = strings.TrimSpace(content[msgEnd:])
}
return messages
}
// findLastUnclosedCodeBlock finds the last opening ``` that doesn't have a closing ```
// Returns the position of the opening ``` or -1 if all code blocks are complete
func findLastUnclosedCodeBlock(text string) int {
inCodeBlock := false
lastOpenIdx := -1
for i := 0; i < len(text); i++ {
if i+2 < len(text) && text[i] == '`' && text[i+1] == '`' && text[i+2] == '`' {
// Toggle code block state on each fence
if !inCodeBlock {
// Entering a code block: record this opening fence
lastOpenIdx = i
}
inCodeBlock = !inCodeBlock
i += 2
}
}
if inCodeBlock {
return lastOpenIdx
}
return -1
}
// findNextClosingCodeBlock finds the next closing ``` starting from a position
// Returns the position after the closing ``` or -1 if not found
func findNextClosingCodeBlock(text string, startIdx int) int {
for i := startIdx; i < len(text); i++ {
if i+2 < len(text) && text[i] == '`' && text[i+1] == '`' && text[i+2] == '`' {
return i + 3
}
}
return -1
}
// findLastNewline finds the last newline character within the last N characters
// Returns the position of the newline or -1 if not found
func findLastNewline(s string, searchWindow int) int {
searchStart := len(s) - searchWindow
if searchStart < 0 {
searchStart = 0
}
for i := len(s) - 1; i >= searchStart; i-- {
if s[i] == '\n' {
return i
}
}
return -1
}
// findLastSpace finds the last space character within the last N characters
// Returns the position of the space or -1 if not found
func findLastSpace(s string, searchWindow int) int {
searchStart := len(s) - searchWindow
if searchStart < 0 {
searchStart = 0
}
for i := len(s) - 1; i >= searchStart; i-- {
if s[i] == ' ' || s[i] == '\t' {
return i
}
}
return -1
}
+151
View File
@@ -0,0 +1,151 @@
package utils
import (
"strings"
"testing"
)
func TestSplitMessage(t *testing.T) {
longText := strings.Repeat("a", 2500)
longCode := "```go\n" + strings.Repeat("fmt.Println(\"hello\")\n", 100) + "```" // ~2100 chars
tests := []struct {
name string
content string
maxLen int
expectChunks int // Check number of chunks
checkContent func(t *testing.T, chunks []string) // Custom validation
}{
{
name: "Empty message",
content: "",
maxLen: 2000,
expectChunks: 0,
},
{
name: "Short message fits in one chunk",
content: "Hello world",
maxLen: 2000,
expectChunks: 1,
},
{
name: "Simple split regular text",
content: longText,
maxLen: 2000,
expectChunks: 2,
checkContent: func(t *testing.T, chunks []string) {
if len(chunks[0]) > 2000 {
t.Errorf("Chunk 0 too large: %d", len(chunks[0]))
}
if len(chunks[0])+len(chunks[1]) != len(longText) {
t.Errorf("Total length mismatch. Got %d, want %d", len(chunks[0])+len(chunks[1]), len(longText))
}
},
},
{
name: "Split at newline",
// 1750 chars then newline, then more chars.
// Dynamic buffer: 2000 / 10 = 200.
// Effective limit: 2000 - 200 = 1800.
// Split should happen at newline because it's at 1750 (< 1800).
// Total length must > 2000 to trigger split. 1750 + 1 + 300 = 2051.
content: strings.Repeat("a", 1750) + "\n" + strings.Repeat("b", 300),
maxLen: 2000,
expectChunks: 2,
checkContent: func(t *testing.T, chunks []string) {
if len(chunks[0]) != 1750 {
t.Errorf("Expected chunk 0 to be 1750 length (split at newline), got %d", len(chunks[0]))
}
if chunks[1] != strings.Repeat("b", 300) {
t.Errorf("Chunk 1 content mismatch. Len: %d", len(chunks[1]))
}
},
},
{
name: "Long code block split",
content: "Prefix\n" + longCode,
maxLen: 2000,
expectChunks: 2,
checkContent: func(t *testing.T, chunks []string) {
// Check that first chunk ends with closing fence
if !strings.HasSuffix(chunks[0], "\n```") {
t.Error("First chunk should end with injected closing fence")
}
// Check that second chunk starts with execution header
if !strings.HasPrefix(chunks[1], "```go") {
t.Error("Second chunk should start with injected code block header")
}
},
},
{
name: "Preserve Unicode characters",
content: strings.Repeat("\u4e16", 1000), // 3000 bytes
maxLen: 2000,
expectChunks: 2,
checkContent: func(t *testing.T, chunks []string) {
// Just verify we didn't panic and got valid strings.
// Go strings are UTF-8, if we split mid-rune it would be bad,
// but standard slicing might do that.
// Let's assume standard behavior is acceptable or check if it produces invalid rune?
if !strings.Contains(chunks[0], "\u4e16") {
t.Error("Chunk should contain unicode characters")
}
},
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
got := SplitMessage(tc.content, tc.maxLen)
if tc.expectChunks == 0 {
if len(got) != 0 {
t.Errorf("Expected 0 chunks, got %d", len(got))
}
return
}
if len(got) != tc.expectChunks {
t.Errorf("Expected %d chunks, got %d", tc.expectChunks, len(got))
// Log sizes for debugging
for i, c := range got {
t.Logf("Chunk %d length: %d", i, len(c))
}
return // Stop further checks if count assumes specific split
}
if tc.checkContent != nil {
tc.checkContent(t, got)
}
})
}
}
func TestSplitMessage_CodeBlockIntegrity(t *testing.T) {
// Focused test for the core requirement: splitting inside a code block preserves syntax highlighting
// 60 chars total approximately
content := "```go\npackage main\n\nfunc main() {\n\tprintln(\"Hello\")\n}\n```"
maxLen := 40
chunks := SplitMessage(content, maxLen)
if len(chunks) != 2 {
t.Fatalf("Expected 2 chunks, got %d: %q", len(chunks), chunks)
}
// First chunk must end with "\n```"
if !strings.HasSuffix(chunks[0], "\n```") {
t.Errorf("First chunk should end with closing fence. Got: %q", chunks[0])
}
// Second chunk must start with the header "```go"
if !strings.HasPrefix(chunks[1], "```go") {
t.Errorf("Second chunk should start with code block header. Got: %q", chunks[1])
}
// First chunk should contain meaningful content
if len(chunks[0]) > 40 {
t.Errorf("First chunk exceeded maxLen: length %d", len(chunks[0]))
}
}