mirror of
https://github.com/sipeed/picoclaw.git
synced 2026-06-12 18:08:54 +00:00
Merge upstream/main into fix/bugfixes
Resolve conflicts: - provider.go: keep upstream's serializeMessages (supersedes stripSystemParts) - provider_test.go: keep upstream's serializeMessages tests - loop_test.go: add slices import needed by upstream tests - shell.go: merge PR's --format deny fix with upstream's block device pattern, safePaths, and absolutePathPattern - shell_test.go: include tests from both branches Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
+180
-39
@@ -23,6 +23,7 @@ import (
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
"github.com/sipeed/picoclaw/pkg/constants"
|
||||
"github.com/sipeed/picoclaw/pkg/logger"
|
||||
"github.com/sipeed/picoclaw/pkg/mcp"
|
||||
"github.com/sipeed/picoclaw/pkg/media"
|
||||
"github.com/sipeed/picoclaw/pkg/providers"
|
||||
"github.com/sipeed/picoclaw/pkg/routing"
|
||||
@@ -46,19 +47,24 @@ type AgentLoop struct {
|
||||
|
||||
// processOptions configures how a message is processed
|
||||
type processOptions struct {
|
||||
SessionKey string // Session identifier for history/context
|
||||
Channel string // Target channel for tool execution
|
||||
ChatID string // Target chat ID for tool execution
|
||||
UserMessage string // User message content (may include prefix)
|
||||
DefaultResponse string // Response when LLM returns empty
|
||||
EnableSummary bool // Whether to trigger summarization
|
||||
SendResponse bool // Whether to send response via bus
|
||||
NoHistory bool // If true, don't load session history (for heartbeat)
|
||||
SessionKey string // Session identifier for history/context
|
||||
Channel string // Target channel for tool execution
|
||||
ChatID string // Target chat ID for tool execution
|
||||
UserMessage string // User message content (may include prefix)
|
||||
Media []string // media:// refs from inbound message
|
||||
DefaultResponse string // Response when LLM returns empty
|
||||
EnableSummary bool // Whether to trigger summarization
|
||||
SendResponse bool // Whether to send response via bus
|
||||
NoHistory bool // If true, don't load session history (for heartbeat)
|
||||
}
|
||||
|
||||
const defaultResponse = "I've completed processing but have no response to give. Increase `max_tool_iterations` in config.json."
|
||||
|
||||
func NewAgentLoop(cfg *config.Config, msgBus *bus.MessageBus, provider providers.LLMProvider) *AgentLoop {
|
||||
func NewAgentLoop(
|
||||
cfg *config.Config,
|
||||
msgBus *bus.MessageBus,
|
||||
provider providers.LLMProvider,
|
||||
) *AgentLoop {
|
||||
registry := NewAgentRegistry(cfg, provider)
|
||||
|
||||
// Register shared tools to all agents
|
||||
@@ -99,7 +105,7 @@ func registerSharedTools(
|
||||
}
|
||||
|
||||
// Web tools
|
||||
if searchTool := tools.NewWebSearchTool(tools.WebSearchToolOptions{
|
||||
searchTool, err := tools.NewWebSearchTool(tools.WebSearchToolOptions{
|
||||
BraveAPIKey: cfg.Tools.Web.Brave.APIKey,
|
||||
BraveMaxResults: cfg.Tools.Web.Brave.MaxResults,
|
||||
BraveEnabled: cfg.Tools.Web.Brave.Enabled,
|
||||
@@ -113,10 +119,18 @@ func registerSharedTools(
|
||||
PerplexityMaxResults: cfg.Tools.Web.Perplexity.MaxResults,
|
||||
PerplexityEnabled: cfg.Tools.Web.Perplexity.Enabled,
|
||||
Proxy: cfg.Tools.Web.Proxy,
|
||||
}); searchTool != nil {
|
||||
})
|
||||
if err != nil {
|
||||
logger.ErrorCF("agent", "Failed to create web search tool", map[string]any{"error": err.Error()})
|
||||
} else if searchTool != nil {
|
||||
agent.Tools.Register(searchTool)
|
||||
}
|
||||
agent.Tools.Register(tools.NewWebFetchToolWithProxy(50000, cfg.Tools.Web.Proxy))
|
||||
fetchTool, err := tools.NewWebFetchToolWithProxy(50000, cfg.Tools.Web.Proxy, cfg.Tools.Web.FetchLimitBytes)
|
||||
if err != nil {
|
||||
logger.ErrorCF("agent", "Failed to create web fetch tool", map[string]any{"error": err.Error()})
|
||||
} else {
|
||||
agent.Tools.Register(fetchTool)
|
||||
}
|
||||
|
||||
// Hardware tools (I2C, SPI) - Linux only, returns error on other platforms
|
||||
agent.Tools.Register(tools.NewI2CTool())
|
||||
@@ -162,6 +176,72 @@ func registerSharedTools(
|
||||
func (al *AgentLoop) Run(ctx context.Context) error {
|
||||
al.running.Store(true)
|
||||
|
||||
// Initialize MCP servers for all agents
|
||||
if al.cfg.Tools.MCP.Enabled {
|
||||
mcpManager := mcp.NewManager()
|
||||
// Ensure MCP connections are cleaned up on exit, regardless of initialization success
|
||||
// This fixes resource leak when LoadFromMCPConfig partially succeeds then fails
|
||||
defer func() {
|
||||
if err := mcpManager.Close(); err != nil {
|
||||
logger.ErrorCF("agent", "Failed to close MCP manager",
|
||||
map[string]any{
|
||||
"error": err.Error(),
|
||||
})
|
||||
}
|
||||
}()
|
||||
|
||||
defaultAgent := al.registry.GetDefaultAgent()
|
||||
var workspacePath string
|
||||
if defaultAgent != nil && defaultAgent.Workspace != "" {
|
||||
workspacePath = defaultAgent.Workspace
|
||||
} else {
|
||||
workspacePath = al.cfg.WorkspacePath()
|
||||
}
|
||||
|
||||
if err := mcpManager.LoadFromMCPConfig(ctx, al.cfg.Tools.MCP, workspacePath); err != nil {
|
||||
logger.WarnCF("agent", "Failed to load MCP servers, MCP tools will not be available",
|
||||
map[string]any{
|
||||
"error": err.Error(),
|
||||
})
|
||||
} else {
|
||||
// Register MCP tools for all agents
|
||||
servers := mcpManager.GetServers()
|
||||
uniqueTools := 0
|
||||
totalRegistrations := 0
|
||||
agentIDs := al.registry.ListAgentIDs()
|
||||
agentCount := len(agentIDs)
|
||||
|
||||
for serverName, conn := range servers {
|
||||
uniqueTools += len(conn.Tools)
|
||||
for _, tool := range conn.Tools {
|
||||
for _, agentID := range agentIDs {
|
||||
agent, ok := al.registry.GetAgent(agentID)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
mcpTool := tools.NewMCPTool(mcpManager, serverName, tool)
|
||||
agent.Tools.Register(mcpTool)
|
||||
totalRegistrations++
|
||||
logger.DebugCF("agent", "Registered MCP tool",
|
||||
map[string]any{
|
||||
"agent_id": agentID,
|
||||
"server": serverName,
|
||||
"tool": tool.Name,
|
||||
"name": mcpTool.Name(),
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
logger.InfoCF("agent", "MCP tools registered successfully",
|
||||
map[string]any{
|
||||
"server_count": len(servers),
|
||||
"unique_tools": uniqueTools,
|
||||
"total_registrations": totalRegistrations,
|
||||
"agent_count": agentCount,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
for al.running.Load() {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
@@ -302,7 +382,10 @@ func (al *AgentLoop) RecordLastChatID(chatID string) error {
|
||||
return al.state.SetLastChatID(chatID)
|
||||
}
|
||||
|
||||
func (al *AgentLoop) ProcessDirect(ctx context.Context, content, sessionKey string) (string, error) {
|
||||
func (al *AgentLoop) ProcessDirect(
|
||||
ctx context.Context,
|
||||
content, sessionKey string,
|
||||
) (string, error) {
|
||||
return al.ProcessDirectWithChannel(ctx, content, sessionKey, "cli", "direct")
|
||||
}
|
||||
|
||||
@@ -323,7 +406,10 @@ func (al *AgentLoop) ProcessDirectWithChannel(
|
||||
|
||||
// ProcessHeartbeat processes a heartbeat request without session history.
|
||||
// Each heartbeat is independent and doesn't accumulate context.
|
||||
func (al *AgentLoop) ProcessHeartbeat(ctx context.Context, content, channel, chatID string) (string, error) {
|
||||
func (al *AgentLoop) ProcessHeartbeat(
|
||||
ctx context.Context,
|
||||
content, channel, chatID string,
|
||||
) (string, error) {
|
||||
agent := al.registry.GetDefaultAgent()
|
||||
if agent == nil {
|
||||
return "", fmt.Errorf("no default agent for heartbeat")
|
||||
@@ -348,13 +434,16 @@ func (al *AgentLoop) processMessage(ctx context.Context, msg bus.InboundMessage)
|
||||
} else {
|
||||
logContent = utils.Truncate(msg.Content, 80)
|
||||
}
|
||||
logger.InfoCF("agent", fmt.Sprintf("Processing message from %s:%s: %s", msg.Channel, msg.SenderID, logContent),
|
||||
logger.InfoCF(
|
||||
"agent",
|
||||
fmt.Sprintf("Processing message from %s:%s: %s", msg.Channel, msg.SenderID, logContent),
|
||||
map[string]any{
|
||||
"channel": msg.Channel,
|
||||
"chat_id": msg.ChatID,
|
||||
"sender_id": msg.SenderID,
|
||||
"session_key": msg.SessionKey,
|
||||
})
|
||||
},
|
||||
)
|
||||
|
||||
// Route system messages to processSystemMessage
|
||||
if msg.Channel == "system" {
|
||||
@@ -409,15 +498,22 @@ func (al *AgentLoop) processMessage(ctx context.Context, msg bus.InboundMessage)
|
||||
Channel: msg.Channel,
|
||||
ChatID: msg.ChatID,
|
||||
UserMessage: msg.Content,
|
||||
Media: msg.Media,
|
||||
DefaultResponse: defaultResponse,
|
||||
EnableSummary: true,
|
||||
SendResponse: false,
|
||||
})
|
||||
}
|
||||
|
||||
func (al *AgentLoop) processSystemMessage(ctx context.Context, msg bus.InboundMessage) (string, error) {
|
||||
func (al *AgentLoop) processSystemMessage(
|
||||
ctx context.Context,
|
||||
msg bus.InboundMessage,
|
||||
) (string, error) {
|
||||
if msg.Channel != "system" {
|
||||
return "", fmt.Errorf("processSystemMessage called with non-system message channel: %s", msg.Channel)
|
||||
return "", fmt.Errorf(
|
||||
"processSystemMessage called with non-system message channel: %s",
|
||||
msg.Channel,
|
||||
)
|
||||
}
|
||||
|
||||
logger.InfoCF("agent", "Processing system message",
|
||||
@@ -475,14 +571,22 @@ func (al *AgentLoop) processSystemMessage(ctx context.Context, msg bus.InboundMe
|
||||
}
|
||||
|
||||
// runAgentLoop is the core message processing logic.
|
||||
func (al *AgentLoop) runAgentLoop(ctx context.Context, agent *AgentInstance, opts processOptions) (string, error) {
|
||||
func (al *AgentLoop) runAgentLoop(
|
||||
ctx context.Context,
|
||||
agent *AgentInstance,
|
||||
opts processOptions,
|
||||
) (string, error) {
|
||||
// 0. Record last channel for heartbeat notifications (skip internal channels)
|
||||
if opts.Channel != "" && opts.ChatID != "" {
|
||||
// Don't record internal channels (cli, system, subagent)
|
||||
if !constants.IsInternalChannel(opts.Channel) {
|
||||
channelKey := fmt.Sprintf("%s:%s", opts.Channel, opts.ChatID)
|
||||
if err := al.RecordLastChannel(channelKey); err != nil {
|
||||
logger.WarnCF("agent", "Failed to record last channel", map[string]any{"error": err.Error()})
|
||||
logger.WarnCF(
|
||||
"agent",
|
||||
"Failed to record last channel",
|
||||
map[string]any{"error": err.Error()},
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -501,11 +605,15 @@ func (al *AgentLoop) runAgentLoop(ctx context.Context, agent *AgentInstance, opt
|
||||
history,
|
||||
summary,
|
||||
opts.UserMessage,
|
||||
nil,
|
||||
opts.Media,
|
||||
opts.Channel,
|
||||
opts.ChatID,
|
||||
)
|
||||
|
||||
// Resolve media:// refs to base64 data URLs (streaming)
|
||||
maxMediaSize := al.cfg.Agents.Defaults.GetMaxMediaSize()
|
||||
messages = resolveMediaRefs(messages, al.mediaStore, maxMediaSize)
|
||||
|
||||
// 3. Save user message to session
|
||||
agent.Sessions.AddMessage(opts.SessionKey, "user", opts.UserMessage)
|
||||
|
||||
@@ -564,7 +672,10 @@ func (al *AgentLoop) targetReasoningChannelID(channelName string) (chatID string
|
||||
return ""
|
||||
}
|
||||
|
||||
func (al *AgentLoop) handleReasoning(ctx context.Context, reasoningContent, channelName, channelID string) {
|
||||
func (al *AgentLoop) handleReasoning(
|
||||
ctx context.Context,
|
||||
reasoningContent, channelName, channelID string,
|
||||
) {
|
||||
if reasoningContent == "" || channelName == "" || channelID == "" {
|
||||
return
|
||||
}
|
||||
@@ -657,22 +768,33 @@ func (al *AgentLoop) runLLMIteration(
|
||||
|
||||
callLLM := func() (*providers.LLMResponse, error) {
|
||||
if len(agent.Candidates) > 1 && al.fallback != nil {
|
||||
fbResult, fbErr := al.fallback.Execute(ctx, agent.Candidates,
|
||||
fbResult, fbErr := al.fallback.Execute(
|
||||
ctx,
|
||||
agent.Candidates,
|
||||
func(ctx context.Context, provider, model string) (*providers.LLMResponse, error) {
|
||||
return agent.Provider.Chat(ctx, messages, providerToolDefs, model, map[string]any{
|
||||
"max_tokens": agent.MaxTokens,
|
||||
"temperature": agent.Temperature,
|
||||
"prompt_cache_key": agent.ID,
|
||||
})
|
||||
return agent.Provider.Chat(
|
||||
ctx,
|
||||
messages,
|
||||
providerToolDefs,
|
||||
model,
|
||||
map[string]any{
|
||||
"max_tokens": agent.MaxTokens,
|
||||
"temperature": agent.Temperature,
|
||||
"prompt_cache_key": agent.ID,
|
||||
},
|
||||
)
|
||||
},
|
||||
)
|
||||
if fbErr != nil {
|
||||
return nil, fbErr
|
||||
}
|
||||
if fbResult.Provider != "" && len(fbResult.Attempts) > 0 {
|
||||
logger.InfoCF("agent", fmt.Sprintf("Fallback: succeeded with %s/%s after %d attempts",
|
||||
fbResult.Provider, fbResult.Model, len(fbResult.Attempts)+1),
|
||||
map[string]any{"agent_id": agent.ID, "iteration": iteration})
|
||||
logger.InfoCF(
|
||||
"agent",
|
||||
fmt.Sprintf("Fallback: succeeded with %s/%s after %d attempts",
|
||||
fbResult.Provider, fbResult.Model, len(fbResult.Attempts)+1),
|
||||
map[string]any{"agent_id": agent.ID, "iteration": iteration},
|
||||
)
|
||||
}
|
||||
return fbResult.Response, nil
|
||||
}
|
||||
@@ -723,10 +845,14 @@ func (al *AgentLoop) runLLMIteration(
|
||||
}
|
||||
|
||||
if isContextError && retry < maxRetries {
|
||||
logger.WarnCF("agent", "Context window error detected, attempting compression", map[string]any{
|
||||
"error": err.Error(),
|
||||
"retry": retry,
|
||||
})
|
||||
logger.WarnCF(
|
||||
"agent",
|
||||
"Context window error detected, attempting compression",
|
||||
map[string]any{
|
||||
"error": err.Error(),
|
||||
"retry": retry,
|
||||
},
|
||||
)
|
||||
|
||||
if retry == 0 && !constants.IsInternalChannel(opts.Channel) {
|
||||
al.bus.PublishOutbound(ctx, bus.OutboundMessage{
|
||||
@@ -758,7 +884,12 @@ func (al *AgentLoop) runLLMIteration(
|
||||
return "", iteration, fmt.Errorf("LLM call failed after retries: %w", err)
|
||||
}
|
||||
|
||||
go al.handleReasoning(ctx, response.Reasoning, opts.Channel, al.targetReasoningChannelID(opts.Channel))
|
||||
go al.handleReasoning(
|
||||
ctx,
|
||||
response.Reasoning,
|
||||
opts.Channel,
|
||||
al.targetReasoningChannelID(opts.Channel),
|
||||
)
|
||||
|
||||
logger.DebugCF("agent", "LLM response",
|
||||
map[string]any{
|
||||
@@ -1067,7 +1198,11 @@ func formatMessagesForLog(messages []providers.Message) string {
|
||||
for _, tc := range msg.ToolCalls {
|
||||
fmt.Fprintf(&sb, " - ID: %s, Type: %s, Name: %s\n", tc.ID, tc.Type, tc.Name)
|
||||
if tc.Function != nil {
|
||||
fmt.Fprintf(&sb, " Arguments: %s\n", utils.Truncate(tc.Function.Arguments, 200))
|
||||
fmt.Fprintf(
|
||||
&sb,
|
||||
" Arguments: %s\n",
|
||||
utils.Truncate(tc.Function.Arguments, 200),
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1096,7 +1231,11 @@ func formatToolsForLog(toolDefs []providers.ToolDefinition) string {
|
||||
fmt.Fprintf(&sb, " [%d] Type: %s, Name: %s\n", i, tool.Type, tool.Function.Name)
|
||||
fmt.Fprintf(&sb, " Description: %s\n", tool.Function.Description)
|
||||
if len(tool.Function.Parameters) > 0 {
|
||||
fmt.Fprintf(&sb, " Parameters: %s\n", utils.Truncate(fmt.Sprintf("%v", tool.Function.Parameters), 200))
|
||||
fmt.Fprintf(
|
||||
&sb,
|
||||
" Parameters: %s\n",
|
||||
utils.Truncate(fmt.Sprintf("%v", tool.Function.Parameters), 200),
|
||||
)
|
||||
}
|
||||
}
|
||||
sb.WriteString("]")
|
||||
@@ -1193,7 +1332,9 @@ func (al *AgentLoop) summarizeBatch(
|
||||
existingSummary string,
|
||||
) (string, error) {
|
||||
var sb strings.Builder
|
||||
sb.WriteString("Provide a concise summary of this conversation segment, preserving core context and key points.\n")
|
||||
sb.WriteString(
|
||||
"Provide a concise summary of this conversation segment, preserving core context and key points.\n",
|
||||
)
|
||||
if existingSummary != "" {
|
||||
sb.WriteString("Existing context: ")
|
||||
sb.WriteString(existingSummary)
|
||||
|
||||
Reference in New Issue
Block a user