mirror of
https://github.com/sipeed/picoclaw.git
synced 2026-06-12 18:08:54 +00:00
merge: sync upstream/main into feat/multi-agent-routing
Resolve conflicts: - pkg/agent/loop.go: integrate context compression, command handling, utf8 token estimation, and summarization notification into multi-agent routing architecture - pkg/config/config_test.go: merge imports from both branches - pkg/agent/loop_test.go: update test to use registry-based sessions
This commit is contained in:
+242
-30
@@ -14,8 +14,10 @@ import (
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
"unicode/utf8"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/bus"
|
||||
"github.com/sipeed/picoclaw/pkg/channels"
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
"github.com/sipeed/picoclaw/pkg/constants"
|
||||
"github.com/sipeed/picoclaw/pkg/logger"
|
||||
@@ -27,13 +29,14 @@ import (
|
||||
)
|
||||
|
||||
type AgentLoop struct {
|
||||
bus *bus.MessageBus
|
||||
cfg *config.Config
|
||||
registry *AgentRegistry
|
||||
state *state.Manager
|
||||
running atomic.Bool
|
||||
summarizing sync.Map
|
||||
fallback *providers.FallbackChain
|
||||
bus *bus.MessageBus
|
||||
cfg *config.Config
|
||||
registry *AgentRegistry
|
||||
state *state.Manager
|
||||
running atomic.Bool
|
||||
summarizing sync.Map
|
||||
fallback *providers.FallbackChain
|
||||
channelManager *channels.Manager
|
||||
}
|
||||
|
||||
// processOptions configures how a message is processed
|
||||
@@ -183,6 +186,10 @@ func (al *AgentLoop) RegisterTool(tool tools.Tool) {
|
||||
}
|
||||
}
|
||||
|
||||
func (al *AgentLoop) SetChannelManager(cm *channels.Manager) {
|
||||
al.channelManager = cm
|
||||
}
|
||||
|
||||
// RecordLastChannel records the last active channel for this workspace.
|
||||
// This uses the atomic state save mechanism to prevent data loss on crash.
|
||||
func (al *AgentLoop) RecordLastChannel(channel string) error {
|
||||
@@ -254,6 +261,11 @@ func (al *AgentLoop) processMessage(ctx context.Context, msg bus.InboundMessage)
|
||||
return al.processSystemMessage(ctx, msg)
|
||||
}
|
||||
|
||||
// Check for commands
|
||||
if response, handled := al.handleCommand(ctx, msg); handled {
|
||||
return response, nil
|
||||
}
|
||||
|
||||
// Route to determine agent and session key
|
||||
route := al.registry.ResolveRoute(routing.RouteInput{
|
||||
Channel: msg.Channel,
|
||||
@@ -404,7 +416,7 @@ func (al *AgentLoop) runAgentLoop(ctx context.Context, agent *AgentInstance, opt
|
||||
|
||||
// 7. Optional: summarization
|
||||
if opts.EnableSummary {
|
||||
al.maybeSummarize(agent, opts.SessionKey)
|
||||
al.maybeSummarize(agent, opts.SessionKey, opts.Channel, opts.ChatID)
|
||||
}
|
||||
|
||||
// 8. Optional: send response via bus
|
||||
@@ -472,32 +484,72 @@ func (al *AgentLoop) runLLMIteration(ctx context.Context, agent *AgentInstance,
|
||||
var response *providers.LLMResponse
|
||||
var err error
|
||||
|
||||
if len(agent.Candidates) > 1 && al.fallback != nil {
|
||||
fbResult, fbErr := al.fallback.Execute(ctx, agent.Candidates,
|
||||
func(ctx context.Context, provider, model string) (*providers.LLMResponse, error) {
|
||||
return agent.Provider.Chat(ctx, messages, providerToolDefs, model, map[string]interface{}{
|
||||
"max_tokens": 8192,
|
||||
"temperature": 0.7,
|
||||
})
|
||||
},
|
||||
)
|
||||
if fbErr != nil {
|
||||
err = fbErr
|
||||
} else {
|
||||
response = fbResult.Response
|
||||
callLLM := func() (*providers.LLMResponse, error) {
|
||||
if len(agent.Candidates) > 1 && al.fallback != nil {
|
||||
fbResult, fbErr := al.fallback.Execute(ctx, agent.Candidates,
|
||||
func(ctx context.Context, provider, model string) (*providers.LLMResponse, error) {
|
||||
return agent.Provider.Chat(ctx, messages, providerToolDefs, model, map[string]interface{}{
|
||||
"max_tokens": 8192,
|
||||
"temperature": 0.7,
|
||||
})
|
||||
},
|
||||
)
|
||||
if fbErr != nil {
|
||||
return nil, fbErr
|
||||
}
|
||||
if fbResult.Provider != "" && len(fbResult.Attempts) > 0 {
|
||||
logger.InfoCF("agent", fmt.Sprintf("Fallback: succeeded with %s/%s after %d attempts",
|
||||
fbResult.Provider, fbResult.Model, len(fbResult.Attempts)+1),
|
||||
map[string]interface{}{"agent_id": agent.ID, "iteration": iteration})
|
||||
}
|
||||
return fbResult.Response, nil
|
||||
}
|
||||
} else {
|
||||
response, err = agent.Provider.Chat(ctx, messages, providerToolDefs, agent.Model, map[string]interface{}{
|
||||
return agent.Provider.Chat(ctx, messages, providerToolDefs, agent.Model, map[string]interface{}{
|
||||
"max_tokens": 8192,
|
||||
"temperature": 0.7,
|
||||
})
|
||||
}
|
||||
|
||||
// Retry loop for context/token errors
|
||||
maxRetries := 2
|
||||
for retry := 0; retry <= maxRetries; retry++ {
|
||||
response, err = callLLM()
|
||||
if err == nil {
|
||||
break
|
||||
}
|
||||
|
||||
errMsg := strings.ToLower(err.Error())
|
||||
isContextError := strings.Contains(errMsg, "token") ||
|
||||
strings.Contains(errMsg, "context") ||
|
||||
strings.Contains(errMsg, "invalidparameter") ||
|
||||
strings.Contains(errMsg, "length")
|
||||
|
||||
if isContextError && retry < maxRetries {
|
||||
logger.WarnCF("agent", "Context window error detected, attempting compression", map[string]interface{}{
|
||||
"error": err.Error(),
|
||||
"retry": retry,
|
||||
})
|
||||
|
||||
if retry == 0 && !constants.IsInternalChannel(opts.Channel) {
|
||||
al.bus.PublishOutbound(bus.OutboundMessage{
|
||||
Channel: opts.Channel,
|
||||
ChatID: opts.ChatID,
|
||||
Content: "Context window exceeded. Compressing history and retrying...",
|
||||
})
|
||||
}
|
||||
|
||||
al.forceCompression(agent, opts.SessionKey)
|
||||
newHistory := agent.Sessions.GetHistory(opts.SessionKey)
|
||||
newSummary := agent.Sessions.GetSummary(opts.SessionKey)
|
||||
messages = agent.ContextBuilder.BuildMessages(
|
||||
newHistory, newSummary, "",
|
||||
nil, opts.Channel, opts.ChatID,
|
||||
)
|
||||
continue
|
||||
}
|
||||
break
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
logger.ErrorCF("agent", "LLM call failed",
|
||||
map[string]interface{}{
|
||||
@@ -505,7 +557,7 @@ func (al *AgentLoop) runLLMIteration(ctx context.Context, agent *AgentInstance,
|
||||
"iteration": iteration,
|
||||
"error": err.Error(),
|
||||
})
|
||||
return "", iteration, fmt.Errorf("LLM call failed: %w", err)
|
||||
return "", iteration, fmt.Errorf("LLM call failed after retries: %w", err)
|
||||
}
|
||||
|
||||
// Check if no tool calls - we're done
|
||||
@@ -639,7 +691,7 @@ func (al *AgentLoop) updateToolContexts(agent *AgentInstance, channel, chatID st
|
||||
}
|
||||
|
||||
// maybeSummarize triggers summarization if the session history exceeds thresholds.
|
||||
func (al *AgentLoop) maybeSummarize(agent *AgentInstance, sessionKey string) {
|
||||
func (al *AgentLoop) maybeSummarize(agent *AgentInstance, sessionKey, channel, chatID string) {
|
||||
newHistory := agent.Sessions.GetHistory(sessionKey)
|
||||
tokenEstimate := al.estimateTokens(newHistory)
|
||||
threshold := agent.ContextWindow * 75 / 100
|
||||
@@ -649,12 +701,79 @@ func (al *AgentLoop) maybeSummarize(agent *AgentInstance, sessionKey string) {
|
||||
if _, loading := al.summarizing.LoadOrStore(summarizeKey, true); !loading {
|
||||
go func() {
|
||||
defer al.summarizing.Delete(summarizeKey)
|
||||
if !constants.IsInternalChannel(channel) {
|
||||
al.bus.PublishOutbound(bus.OutboundMessage{
|
||||
Channel: channel,
|
||||
ChatID: chatID,
|
||||
Content: "Memory threshold reached. Optimizing conversation history...",
|
||||
})
|
||||
}
|
||||
al.summarizeSession(agent, sessionKey)
|
||||
}()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// forceCompression aggressively reduces context when the limit is hit.
|
||||
// It drops the oldest 50% of messages (keeping system prompt and last user message).
|
||||
func (al *AgentLoop) forceCompression(agent *AgentInstance, sessionKey string) {
|
||||
history := agent.Sessions.GetHistory(sessionKey)
|
||||
if len(history) <= 4 {
|
||||
return
|
||||
}
|
||||
|
||||
// Keep system prompt (usually [0]) and the very last message (user's trigger)
|
||||
// We want to drop the oldest half of the *conversation*
|
||||
// Assuming [0] is system, [1:] is conversation
|
||||
conversation := history[1 : len(history)-1]
|
||||
if len(conversation) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
// Helper to find the mid-point of the conversation
|
||||
mid := len(conversation) / 2
|
||||
|
||||
// New history structure:
|
||||
// 1. System Prompt
|
||||
// 2. [Summary of dropped part] - synthesized
|
||||
// 3. Second half of conversation
|
||||
// 4. Last message
|
||||
|
||||
// Simplified approach for emergency: Drop first half of conversation
|
||||
// and rely on existing summary if present, or create a placeholder.
|
||||
|
||||
droppedCount := mid
|
||||
keptConversation := conversation[mid:]
|
||||
|
||||
newHistory := make([]providers.Message, 0)
|
||||
newHistory = append(newHistory, history[0]) // System prompt
|
||||
|
||||
// Add a note about compression
|
||||
compressionNote := fmt.Sprintf("[System: Emergency compression dropped %d oldest messages due to context limit]", droppedCount)
|
||||
// If there was an existing summary, we might lose it if it was in the dropped part (which is just messages).
|
||||
// The summary is stored separately in session.Summary, so it persists!
|
||||
// We just need to ensure the user knows there's a gap.
|
||||
|
||||
// We only modify the messages list here
|
||||
newHistory = append(newHistory, providers.Message{
|
||||
Role: "system",
|
||||
Content: compressionNote,
|
||||
})
|
||||
|
||||
newHistory = append(newHistory, keptConversation...)
|
||||
newHistory = append(newHistory, history[len(history)-1]) // Last message
|
||||
|
||||
// Update session
|
||||
agent.Sessions.SetHistory(sessionKey, newHistory)
|
||||
agent.Sessions.Save(sessionKey)
|
||||
|
||||
logger.WarnCF("agent", "Forced compression executed", map[string]interface{}{
|
||||
"session_key": sessionKey,
|
||||
"dropped_msgs": droppedCount,
|
||||
"new_count": len(newHistory),
|
||||
})
|
||||
}
|
||||
|
||||
// GetStartupInfo returns information about loaded tools and skills for logging.
|
||||
func (al *AgentLoop) GetStartupInfo() map[string]interface{} {
|
||||
info := make(map[string]interface{})
|
||||
@@ -693,7 +812,7 @@ func formatMessagesForLog(messages []providers.Message) string {
|
||||
result += "[\n"
|
||||
for i, msg := range messages {
|
||||
result += fmt.Sprintf(" [%d] Role: %s\n", i, msg.Role)
|
||||
if msg.ToolCalls != nil && len(msg.ToolCalls) > 0 {
|
||||
if len(msg.ToolCalls) > 0 {
|
||||
result += " ToolCalls:\n"
|
||||
for _, tc := range msg.ToolCalls {
|
||||
result += fmt.Sprintf(" - ID: %s, Type: %s, Name: %s\n", tc.ID, tc.Type, tc.Name)
|
||||
@@ -758,7 +877,7 @@ func (al *AgentLoop) summarizeSession(agent *AgentInstance, sessionKey string) {
|
||||
if m.Role != "user" && m.Role != "assistant" {
|
||||
continue
|
||||
}
|
||||
msgTokens := len(m.Content) / 4
|
||||
msgTokens := len(m.Content) / 2
|
||||
if msgTokens > maxMessageTokens {
|
||||
omitted = true
|
||||
continue
|
||||
@@ -827,12 +946,105 @@ func (al *AgentLoop) summarizeBatch(ctx context.Context, agent *AgentInstance, b
|
||||
}
|
||||
|
||||
// estimateTokens estimates the number of tokens in a message list.
|
||||
// Uses a safe heuristic of 2.5 characters per token to account for CJK and other
|
||||
// overheads better than the previous 3 chars/token.
|
||||
func (al *AgentLoop) estimateTokens(messages []providers.Message) int {
|
||||
total := 0
|
||||
totalChars := 0
|
||||
for _, m := range messages {
|
||||
total += len(m.Content) / 4
|
||||
totalChars += utf8.RuneCountInString(m.Content)
|
||||
}
|
||||
return total
|
||||
// 2.5 chars per token = totalChars * 2 / 5
|
||||
return totalChars * 2 / 5
|
||||
}
|
||||
|
||||
func (al *AgentLoop) handleCommand(ctx context.Context, msg bus.InboundMessage) (string, bool) {
|
||||
content := strings.TrimSpace(msg.Content)
|
||||
if !strings.HasPrefix(content, "/") {
|
||||
return "", false
|
||||
}
|
||||
|
||||
parts := strings.Fields(content)
|
||||
if len(parts) == 0 {
|
||||
return "", false
|
||||
}
|
||||
|
||||
cmd := parts[0]
|
||||
args := parts[1:]
|
||||
|
||||
switch cmd {
|
||||
case "/show":
|
||||
if len(args) < 1 {
|
||||
return "Usage: /show [model|channel|agents]", true
|
||||
}
|
||||
switch args[0] {
|
||||
case "model":
|
||||
defaultAgent := al.registry.GetDefaultAgent()
|
||||
if defaultAgent == nil {
|
||||
return "No default agent configured", true
|
||||
}
|
||||
return fmt.Sprintf("Current model: %s", defaultAgent.Model), true
|
||||
case "channel":
|
||||
return fmt.Sprintf("Current channel: %s", msg.Channel), true
|
||||
case "agents":
|
||||
agentIDs := al.registry.ListAgentIDs()
|
||||
return fmt.Sprintf("Registered agents: %s", strings.Join(agentIDs, ", ")), true
|
||||
default:
|
||||
return fmt.Sprintf("Unknown show target: %s", args[0]), true
|
||||
}
|
||||
|
||||
case "/list":
|
||||
if len(args) < 1 {
|
||||
return "Usage: /list [models|channels|agents]", true
|
||||
}
|
||||
switch args[0] {
|
||||
case "models":
|
||||
return "Available models: configured in config.json per agent", true
|
||||
case "channels":
|
||||
if al.channelManager == nil {
|
||||
return "Channel manager not initialized", true
|
||||
}
|
||||
channels := al.channelManager.GetEnabledChannels()
|
||||
if len(channels) == 0 {
|
||||
return "No channels enabled", true
|
||||
}
|
||||
return fmt.Sprintf("Enabled channels: %s", strings.Join(channels, ", ")), true
|
||||
case "agents":
|
||||
agentIDs := al.registry.ListAgentIDs()
|
||||
return fmt.Sprintf("Registered agents: %s", strings.Join(agentIDs, ", ")), true
|
||||
default:
|
||||
return fmt.Sprintf("Unknown list target: %s", args[0]), true
|
||||
}
|
||||
|
||||
case "/switch":
|
||||
if len(args) < 3 || args[1] != "to" {
|
||||
return "Usage: /switch [model|channel] to <name>", true
|
||||
}
|
||||
target := args[0]
|
||||
value := args[2]
|
||||
|
||||
switch target {
|
||||
case "model":
|
||||
defaultAgent := al.registry.GetDefaultAgent()
|
||||
if defaultAgent == nil {
|
||||
return "No default agent configured", true
|
||||
}
|
||||
oldModel := defaultAgent.Model
|
||||
defaultAgent.Model = value
|
||||
return fmt.Sprintf("Switched model from %s to %s", oldModel, value), true
|
||||
case "channel":
|
||||
if al.channelManager == nil {
|
||||
return "Channel manager not initialized", true
|
||||
}
|
||||
if _, exists := al.channelManager.GetChannel(value); !exists && value != "cli" {
|
||||
return fmt.Sprintf("Channel '%s' not found or not enabled", value), true
|
||||
}
|
||||
return fmt.Sprintf("Switched target channel to %s", value), true
|
||||
default:
|
||||
return fmt.Sprintf("Unknown switch target: %s", target), true
|
||||
}
|
||||
}
|
||||
|
||||
return "", false
|
||||
}
|
||||
|
||||
// extractPeer extracts the routing peer from inbound message metadata.
|
||||
|
||||
Reference in New Issue
Block a user