Merge branch 'main' into refactor-inbound-context-routing-session

# Conflicts:
#	pkg/agent/eventbus_test.go
#	pkg/agent/loop.go
#	pkg/bus/bus.go
#	pkg/bus/types.go
#	pkg/channels/pico/pico.go
#	pkg/channels/telegram/telegram.go
#	pkg/config/config.go
#	web/backend/api/session.go
#	web/backend/api/session_test.go
This commit is contained in:
Hoshina
2026-04-07 21:41:02 +08:00
282 changed files with 33064 additions and 3251 deletions
+5 -3
View File
@@ -602,14 +602,16 @@ func (cb *ContextBuilder) BuildMessages(
// Add conversation history
messages = append(messages, history...)
// Add current user message
if strings.TrimSpace(currentMessage) != "" {
// Add current user message. Media-only turns must still be preserved so
// multimodal providers receive the uploaded image even when the user sends
// no accompanying text.
if strings.TrimSpace(currentMessage) != "" || len(media) > 0 {
msg := providers.Message{
Role: "user",
Content: currentMessage,
}
if len(media) > 0 {
msg.Media = media
msg.Media = append([]string(nil), media...)
}
messages = append(messages, msg)
}
+11 -85
View File
@@ -6,10 +6,8 @@
package agent
import (
"encoding/json"
"unicode/utf8"
"github.com/sipeed/picoclaw/pkg/providers"
"github.com/sipeed/picoclaw/pkg/tokenizer"
)
// parseTurnBoundaries returns the starting index of each Turn in the history.
@@ -86,88 +84,16 @@ func findSafeBoundary(history []providers.Message, targetIndex int) int {
return 0
}
// estimateMessageTokens estimates the token count for a single message,
// including Content, ReasoningContent, ToolCalls arguments, ToolCallID
// metadata, and Media items. Uses a heuristic of 2.5 characters per token.
func estimateMessageTokens(msg providers.Message) int {
contentChars := utf8.RuneCountInString(msg.Content)
// SystemParts are structured system blocks used for cache-aware adapters.
// They carry the same content as Content, but in multiple blocks.
// We estimate them as an alternative representation, not additive.
systemPartsChars := 0
if len(msg.SystemParts) > 0 {
for _, part := range msg.SystemParts {
systemPartsChars += utf8.RuneCountInString(part.Text)
}
// Per-part overhead for JSON structure (type, text, cache_control).
const perPartOverhead = 20
systemPartsChars += len(msg.SystemParts) * perPartOverhead
}
// Use the larger of the two representations to stay conservative.
chars := contentChars
if systemPartsChars > chars {
chars = systemPartsChars
}
chars += utf8.RuneCountInString(msg.ReasoningContent)
for _, tc := range msg.ToolCalls {
chars += len(tc.ID) + len(tc.Type)
if tc.Function != nil {
// Count function name + arguments (the wire format for most providers).
// tc.Name mirrors tc.Function.Name — count only once to avoid double-counting.
chars += len(tc.Function.Name) + len(tc.Function.Arguments)
} else {
// Fallback: some provider formats use top-level Name without Function.
chars += len(tc.Name)
}
}
if msg.ToolCallID != "" {
chars += len(msg.ToolCallID)
}
// Per-message overhead for role label, JSON structure, separators.
const messageOverhead = 12
chars += messageOverhead
tokens := chars * 2 / 5
// Media items (images, files) are serialized by provider adapters into
// multipart or image_url payloads. Add a fixed per-item token estimate
// directly (not through the chars heuristic) since actual cost depends
// on resolution and provider-specific image tokenization.
const mediaTokensPerItem = 256
tokens += len(msg.Media) * mediaTokensPerItem
return tokens
// EstimateMessageTokens estimates the token count for a single message.
// Delegates to the shared tokenizer package for consistency across agent and seahorse.
func EstimateMessageTokens(msg providers.Message) int {
return tokenizer.EstimateMessageTokens(msg)
}
// estimateToolDefsTokens estimates the total token cost of tool definitions
// as they appear in the LLM request. Each tool's name, description, and
// JSON schema parameters contribute to the context window budget.
func estimateToolDefsTokens(defs []providers.ToolDefinition) int {
if len(defs) == 0 {
return 0
}
totalChars := 0
for _, d := range defs {
totalChars += len(d.Function.Name) + len(d.Function.Description)
if d.Function.Parameters != nil {
if paramJSON, err := json.Marshal(d.Function.Parameters); err == nil {
totalChars += len(paramJSON)
}
}
// Per-tool overhead: type field, JSON structure, separators.
totalChars += 20
}
return totalChars * 2 / 5
// EstimateToolDefsTokens estimates the total token cost of tool definitions
// as they appear in the LLM request. Delegates to the shared tokenizer package.
func EstimateToolDefsTokens(defs []providers.ToolDefinition) int {
return tokenizer.EstimateToolDefsTokens(defs)
}
// isOverContextBudget checks whether the assembled messages plus tool definitions
@@ -181,10 +107,10 @@ func isOverContextBudget(
) bool {
msgTokens := 0
for _, m := range messages {
msgTokens += estimateMessageTokens(m)
msgTokens += EstimateMessageTokens(m)
}
toolTokens := estimateToolDefsTokens(toolDefs)
toolTokens := EstimateToolDefsTokens(toolDefs)
total := msgTokens + toolTokens + maxTokens
return total > contextWindow
+19 -19
View File
@@ -417,9 +417,9 @@ func TestEstimateMessageTokens(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := estimateMessageTokens(tt.msg)
got := EstimateMessageTokens(tt.msg)
if got < tt.want {
t.Errorf("estimateMessageTokens() = %d, want >= %d", got, tt.want)
t.Errorf("EstimateMessageTokens() = %d, want >= %d", got, tt.want)
}
})
}
@@ -443,8 +443,8 @@ func TestEstimateMessageTokens_ToolCallsContribute(t *testing.T) {
},
}
plainTokens := estimateMessageTokens(plain)
withTCTokens := estimateMessageTokens(withTC)
plainTokens := EstimateMessageTokens(plain)
withTCTokens := EstimateMessageTokens(withTC)
if withTCTokens <= plainTokens {
t.Errorf("message with ToolCalls (%d tokens) should exceed plain message (%d tokens)",
@@ -457,7 +457,7 @@ func TestEstimateMessageTokens_MultibyteContent(t *testing.T) {
// but may map to different token counts. The heuristic should still produce
// reasonable estimates via RuneCountInString.
msg := msgUser("caf\u00e9 na\u00efve r\u00e9sum\u00e9 \u00fcber stra\u00dfe")
tokens := estimateMessageTokens(msg)
tokens := EstimateMessageTokens(msg)
if tokens <= 0 {
t.Errorf("multibyte message should produce positive token count, got %d", tokens)
}
@@ -481,7 +481,7 @@ func TestEstimateMessageTokens_LargeArguments(t *testing.T) {
},
}
tokens := estimateMessageTokens(msg)
tokens := EstimateMessageTokens(msg)
// 5000+ chars → at least 2000 tokens with the 2.5 char/token heuristic
if tokens < 2000 {
t.Errorf("large tool call arguments should produce significant token count, got %d", tokens)
@@ -496,8 +496,8 @@ func TestEstimateMessageTokens_ReasoningContent(t *testing.T) {
ReasoningContent: strings.Repeat("thinking step ", 200),
}
plainTokens := estimateMessageTokens(plain)
reasoningTokens := estimateMessageTokens(withReasoning)
plainTokens := EstimateMessageTokens(plain)
reasoningTokens := EstimateMessageTokens(withReasoning)
if reasoningTokens <= plainTokens {
t.Errorf("message with ReasoningContent (%d tokens) should exceed plain message (%d tokens)",
@@ -513,8 +513,8 @@ func TestEstimateMessageTokens_MediaItems(t *testing.T) {
Media: []string{"media://img1.png", "media://img2.png"},
}
plainTokens := estimateMessageTokens(plain)
mediaTokens := estimateMessageTokens(withMedia)
plainTokens := EstimateMessageTokens(plain)
mediaTokens := EstimateMessageTokens(withMedia)
if mediaTokens <= plainTokens {
t.Errorf("message with Media (%d tokens) should exceed plain message (%d tokens)",
@@ -540,8 +540,8 @@ func TestEstimateMessageTokens_SystemParts(t *testing.T) {
},
}
plainTokens := estimateMessageTokens(plain)
partsTokens := estimateMessageTokens(withParts)
plainTokens := EstimateMessageTokens(plain)
partsTokens := EstimateMessageTokens(withParts)
if partsTokens <= plainTokens {
t.Errorf("system message with SystemParts (%d) should exceed plain message (%d)",
@@ -549,7 +549,7 @@ func TestEstimateMessageTokens_SystemParts(t *testing.T) {
}
}
// --- estimateToolDefsTokens tests ---
// --- EstimateToolDefsTokens tests ---
func TestEstimateToolDefsTokens(t *testing.T) {
tests := []struct {
@@ -599,9 +599,9 @@ func TestEstimateToolDefsTokens(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := estimateToolDefsTokens(tt.defs)
got := EstimateToolDefsTokens(tt.defs)
if got < tt.want {
t.Errorf("estimateToolDefsTokens() = %d, want >= %d", got, tt.want)
t.Errorf("EstimateToolDefsTokens() = %d, want >= %d", got, tt.want)
}
})
}
@@ -624,8 +624,8 @@ func TestEstimateToolDefsTokens_ScalesWithCount(t *testing.T) {
}
}
one := estimateToolDefsTokens([]providers.ToolDefinition{makeTool("tool_a")})
three := estimateToolDefsTokens([]providers.ToolDefinition{
one := EstimateToolDefsTokens([]providers.ToolDefinition{makeTool("tool_a")})
three := EstimateToolDefsTokens([]providers.ToolDefinition{
makeTool("tool_a"), makeTool("tool_b"), makeTool("tool_c"),
})
@@ -770,7 +770,7 @@ func TestEstimateMessageTokens_WithReasoningAndMedia(t *testing.T) {
},
}
tokens := estimateMessageTokens(msg)
tokens := EstimateMessageTokens(msg)
// ReasoningContent alone is ~1700 chars → ~680 tokens.
// Content + TC + overhead adds more. Should be well above 500.
@@ -781,7 +781,7 @@ func TestEstimateMessageTokens_WithReasoningAndMedia(t *testing.T) {
// Compare without reasoning to ensure it's counted.
msgNoReasoning := msg
msgNoReasoning.ReasoningContent = ""
tokensNoReasoning := estimateMessageTokens(msgNoReasoning)
tokensNoReasoning := EstimateMessageTokens(msgNoReasoning)
if tokens <= tokensNoReasoning {
t.Errorf("reasoning content should add tokens: with=%d, without=%d", tokens, tokensNoReasoning)
+32
View File
@@ -707,6 +707,38 @@ func TestEmptyWorkspaceBaselineDetectsNewFiles(t *testing.T) {
}
}
func TestBuildMessages_IncludesMediaOnlyCurrentMessage(t *testing.T) {
tmpDir := setupWorkspace(t, nil)
defer os.RemoveAll(tmpDir)
cb := NewContextBuilder(tmpDir)
msgs := cb.BuildMessages(
nil,
"",
"",
[]string{"data:image/png;base64,abc123"},
"pico",
"chat-1",
"",
"",
)
if len(msgs) != 2 {
t.Fatalf("len(msgs) = %d, want 2", len(msgs))
}
userMsg := msgs[1]
if userMsg.Role != "user" {
t.Fatalf("userMsg.Role = %q, want %q", userMsg.Role, "user")
}
if userMsg.Content != "" {
t.Fatalf("userMsg.Content = %q, want empty string", userMsg.Content)
}
if len(userMsg.Media) != 1 || userMsg.Media[0] != "data:image/png;base64,abc123" {
t.Fatalf("userMsg.Media = %#v, want image payload", userMsg.Media)
}
}
// BenchmarkBuildMessagesWithCache measures caching performance.
func BenchmarkBuildMessagesWithCache(b *testing.B) {
tmpDir, _ := os.MkdirTemp("", "picoclaw-bench-*")
+379
View File
@@ -0,0 +1,379 @@
package agent
import (
"context"
"fmt"
"strings"
"sync"
"time"
"github.com/sipeed/picoclaw/pkg/logger"
"github.com/sipeed/picoclaw/pkg/providers"
)
// legacyContextManager wraps the existing summarization/compression logic
// as a ContextManager implementation. It is the default when no other
// ContextManager is configured.
type legacyContextManager struct {
al *AgentLoop
summarizing sync.Map // dedup for async Compact (post-turn)
}
func (m *legacyContextManager) Assemble(_ context.Context, req *AssembleRequest) (*AssembleResponse, error) {
// Legacy: read history from session, return as-is.
// Budget enforcement happens in BuildMessages caller via
// isOverContextBudget + forceCompression.
agent := m.al.registry.GetDefaultAgent()
if agent == nil {
return &AssembleResponse{}, nil
}
history := agent.Sessions.GetHistory(req.SessionKey)
summary := agent.Sessions.GetSummary(req.SessionKey)
return &AssembleResponse{
History: history,
Summary: summary,
}, nil
}
func (m *legacyContextManager) Compact(_ context.Context, req *CompactRequest) error {
switch req.Reason {
case ContextCompressReasonProactive, ContextCompressReasonRetry:
// Sync emergency compression — budget exceeded.
if result, ok := m.forceCompression(req.SessionKey); ok {
m.al.emitEvent(
EventKindContextCompress,
m.al.newTurnEventScope("", req.SessionKey, nil).meta(0, "forceCompression", "turn.context.compress"),
ContextCompressPayload{
Reason: req.Reason,
DroppedMessages: result.DroppedMessages,
RemainingMessages: result.RemainingMessages,
},
)
}
case ContextCompressReasonSummarize:
m.maybeSummarize(req.SessionKey)
}
return nil
}
func (m *legacyContextManager) Ingest(_ context.Context, _ *IngestRequest) error {
// Legacy: no-op. Messages are persisted by Sessions JSONL.
return nil
}
// maybeSummarize triggers summarization if the session history exceeds thresholds.
// It runs asynchronously in a goroutine.
func (m *legacyContextManager) maybeSummarize(sessionKey string) {
agent := m.al.registry.GetDefaultAgent()
if agent == nil {
return
}
newHistory := agent.Sessions.GetHistory(sessionKey)
tokenEstimate := m.estimateTokens(newHistory)
threshold := agent.ContextWindow * agent.SummarizeTokenPercent / 100
if len(newHistory) > agent.SummarizeMessageThreshold || tokenEstimate > threshold {
summarizeKey := agent.ID + ":" + sessionKey
if _, loading := m.summarizing.LoadOrStore(summarizeKey, true); !loading {
go func() {
defer m.summarizing.Delete(summarizeKey)
defer func() {
if r := recover(); r != nil {
logger.WarnCF("agent", "Summarization panic recovered", map[string]any{
"session_key": sessionKey,
"panic": r,
})
}
}()
logger.Debug("Memory threshold reached. Optimizing conversation history...")
m.summarizeSession(agent, sessionKey)
}()
}
}
}
type compressionResult struct {
DroppedMessages int
RemainingMessages int
}
// forceCompression aggressively reduces context when the limit is hit.
// It drops the oldest ~50% of Turns (a Turn is a complete user→LLM→response
// cycle, as defined in #1316), so tool-call sequences are never split.
func (m *legacyContextManager) forceCompression(sessionKey string) (compressionResult, bool) {
agent := m.al.registry.GetDefaultAgent()
if agent == nil {
return compressionResult{}, false
}
history := agent.Sessions.GetHistory(sessionKey)
if len(history) <= 2 {
return compressionResult{}, false
}
turns := parseTurnBoundaries(history)
var mid int
if len(turns) >= 2 {
mid = turns[len(turns)/2]
} else {
mid = findSafeBoundary(history, len(history)/2)
}
var keptHistory []providers.Message
if mid <= 0 {
for i := len(history) - 1; i >= 0; i-- {
if history[i].Role == "user" {
keptHistory = []providers.Message{history[i]}
break
}
}
} else {
keptHistory = history[mid:]
}
droppedCount := len(history) - len(keptHistory)
existingSummary := agent.Sessions.GetSummary(sessionKey)
compressionNote := fmt.Sprintf(
"[Emergency compression dropped %d oldest messages due to context limit]",
droppedCount,
)
if existingSummary != "" {
compressionNote = existingSummary + "\n\n" + compressionNote
}
agent.Sessions.SetSummary(sessionKey, compressionNote)
agent.Sessions.SetHistory(sessionKey, keptHistory)
agent.Sessions.Save(sessionKey)
logger.WarnCF("agent", "Forced compression executed", map[string]any{
"session_key": sessionKey,
"dropped_msgs": droppedCount,
"new_count": len(keptHistory),
})
return compressionResult{
DroppedMessages: droppedCount,
RemainingMessages: len(keptHistory),
}, true
}
func (m *legacyContextManager) summarizeSession(agent *AgentInstance, sessionKey string) {
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
defer cancel()
history := agent.Sessions.GetHistory(sessionKey)
summary := agent.Sessions.GetSummary(sessionKey)
if len(history) <= 4 {
return
}
safeCut := findSafeBoundary(history, len(history)-4)
if safeCut <= 0 {
return
}
keepCount := len(history) - safeCut
toSummarize := history[:safeCut]
maxMessageTokens := agent.ContextWindow / 2
validMessages := make([]providers.Message, 0)
omitted := false
for _, msg := range toSummarize {
if msg.Role != "user" && msg.Role != "assistant" {
continue
}
msgTokens := len(msg.Content) / 2
if msgTokens > maxMessageTokens {
omitted = true
continue
}
validMessages = append(validMessages, msg)
}
if len(validMessages) == 0 {
return
}
const (
maxSummarizationMessages = 10
llmMaxRetries = 3
)
var finalSummary string
if len(validMessages) > maxSummarizationMessages {
mid := len(validMessages) / 2
mid = m.findNearestUserMessage(validMessages, mid)
part1 := validMessages[:mid]
part2 := validMessages[mid:]
s1, _ := m.summarizeBatch(ctx, agent, part1, "")
s2, _ := m.summarizeBatch(ctx, agent, part2, "")
mergePrompt := fmt.Sprintf(
"Merge these two conversation summaries into one cohesive summary:\n\n1: %s\n\n2: %s",
s1, s2,
)
resp, err := m.retryLLMCall(ctx, agent, mergePrompt, llmMaxRetries)
if err == nil && resp.Content != "" {
finalSummary = resp.Content
} else {
finalSummary = s1 + " " + s2
}
} else {
finalSummary, _ = m.summarizeBatch(ctx, agent, validMessages, summary)
}
if omitted && finalSummary != "" {
finalSummary += "\n[Note: Some oversized messages were omitted from this summary for efficiency.]"
}
if finalSummary != "" {
agent.Sessions.SetSummary(sessionKey, finalSummary)
agent.Sessions.TruncateHistory(sessionKey, keepCount)
agent.Sessions.Save(sessionKey)
m.al.emitEvent(
EventKindSessionSummarize,
m.al.newTurnEventScope(agent.ID, sessionKey, nil).meta(0, "summarizeSession", "turn.session.summarize"),
SessionSummarizePayload{
SummarizedMessages: len(validMessages),
KeptMessages: keepCount,
SummaryLen: len(finalSummary),
OmittedOversized: omitted,
},
)
}
}
func (m *legacyContextManager) findNearestUserMessage(messages []providers.Message, mid int) int {
originalMid := mid
for mid > 0 && messages[mid].Role != "user" {
mid--
}
if messages[mid].Role == "user" {
return mid
}
mid = originalMid
for mid < len(messages) && messages[mid].Role != "user" {
mid++
}
if mid < len(messages) {
return mid
}
return originalMid
}
func (m *legacyContextManager) retryLLMCall(
ctx context.Context,
agent *AgentInstance,
prompt string,
maxRetries int,
) (*providers.LLMResponse, error) {
const llmTemperature = 0.3
var resp *providers.LLMResponse
var err error
for attempt := 0; attempt < maxRetries; attempt++ {
m.al.activeRequests.Add(1)
resp, err = func() (*providers.LLMResponse, error) {
defer m.al.activeRequests.Done()
return agent.Provider.Chat(
ctx,
[]providers.Message{{Role: "user", Content: prompt}},
nil,
agent.Model,
map[string]any{
"max_tokens": agent.MaxTokens,
"temperature": llmTemperature,
"prompt_cache_key": agent.ID,
},
)
}()
if err == nil && resp != nil && resp.Content != "" {
return resp, nil
}
if attempt < maxRetries-1 {
time.Sleep(time.Duration(attempt+1) * 100 * time.Millisecond)
}
}
return resp, err
}
func (m *legacyContextManager) summarizeBatch(
ctx context.Context,
agent *AgentInstance,
batch []providers.Message,
existingSummary string,
) (string, error) {
const (
llmMaxRetries = 3
fallbackMinContentLength = 200
fallbackMaxContentPercent = 10
)
var sb strings.Builder
sb.WriteString("Provide a concise summary of this conversation segment, preserving core context and key points.\n")
if existingSummary != "" {
sb.WriteString("Existing context: ")
sb.WriteString(existingSummary)
sb.WriteString("\n")
}
sb.WriteString("\nCONVERSATION:\n")
for _, msg := range batch {
fmt.Fprintf(&sb, "%s: %s\n", msg.Role, msg.Content)
}
prompt := sb.String()
response, err := m.retryLLMCall(ctx, agent, prompt, llmMaxRetries)
if err == nil && response.Content != "" {
return strings.TrimSpace(response.Content), nil
}
var fallback strings.Builder
fallback.WriteString("Conversation summary: ")
for i, msg := range batch {
if i > 0 {
fallback.WriteString(" | ")
}
content := strings.TrimSpace(msg.Content)
runes := []rune(content)
if len(runes) == 0 {
fallback.WriteString(fmt.Sprintf("%s: ", msg.Role))
continue
}
keepLength := len(runes) * fallbackMaxContentPercent / 100
if keepLength < fallbackMinContentLength {
keepLength = fallbackMinContentLength
}
if keepLength > len(runes) {
keepLength = len(runes)
}
content = string(runes[:keepLength])
if keepLength < len(runes) {
content += "..."
}
fallback.WriteString(fmt.Sprintf("%s: %s", msg.Role, content))
}
return fallback.String(), nil
}
func (m *legacyContextManager) estimateTokens(messages []providers.Message) int {
total := 0
for _, msg := range messages {
total += EstimateMessageTokens(msg)
}
return total
}
+90
View File
@@ -0,0 +1,90 @@
package agent
import (
"context"
"encoding/json"
"fmt"
"sync"
"github.com/sipeed/picoclaw/pkg/providers"
)
// ContextManager manages conversation context via a pluggable strategy.
// Exactly ONE ContextManager is active per AgentLoop, selected by config.
// The default ("legacy") preserves current summarization behavior.
type ContextManager interface {
// Assemble builds budget-aware context from the ContextManager's own storage.
// Called before BuildMessages. Returns assembled messages ready for LLM.
Assemble(ctx context.Context, req *AssembleRequest) (*AssembleResponse, error)
// Compact compresses conversation history.
// Called after turn completes (may be async internally) and on context overflow (sync).
Compact(ctx context.Context, req *CompactRequest) error
// Ingest records a message into the ContextManager's own storage.
// Called after each message is persisted to session JSONL.
Ingest(ctx context.Context, req *IngestRequest) error
}
// AssembleRequest is the input to Assemble.
type AssembleRequest struct {
SessionKey string // session identifier
Budget int // context window in tokens
MaxTokens int // max response tokens
}
// AssembleResponse is the output of Assemble.
type AssembleResponse struct {
History []providers.Message // assembled conversation history for BuildMessages
Summary string // conversation summary embedded into system prompt by BuildMessages
}
// CompactRequest is the input to Compact.
type CompactRequest struct {
SessionKey string // session identifier
Reason ContextCompressReason // proactive_budget | llm_retry | summarize
Budget int // context window budget (used for retry aggressive compaction)
}
// IngestRequest is the input to Ingest.
type IngestRequest struct {
SessionKey string // session identifier
Message providers.Message // the message just persisted
}
// ContextManagerFactory constructs a ContextManager from config.
// al provides access to the AgentLoop's runtime resources (provider, model, workspace, etc.)
// cfg is the raw JSON configuration from config.json (may be nil).
type ContextManagerFactory func(cfg json.RawMessage, al *AgentLoop) (ContextManager, error)
var (
cmRegistryMu sync.RWMutex
cmRegistry = map[string]ContextManagerFactory{}
)
// RegisterContextManager registers a named ContextManager factory.
func RegisterContextManager(name string, factory ContextManagerFactory) error {
if name == "" {
return fmt.Errorf("context manager name is required")
}
if factory == nil {
return fmt.Errorf("context manager %q factory is nil", name)
}
cmRegistryMu.Lock()
defer cmRegistryMu.Unlock()
if _, exists := cmRegistry[name]; exists {
return fmt.Errorf("context manager %q is already registered", name)
}
cmRegistry[name] = factory
return nil
}
func lookupContextManager(name string) (ContextManagerFactory, bool) {
cmRegistryMu.RLock()
defer cmRegistryMu.RUnlock()
f, ok := cmRegistry[name]
return f, ok
}
+764
View File
@@ -0,0 +1,764 @@
package agent
import (
"context"
"encoding/json"
"os"
"strings"
"sync"
"sync/atomic"
"testing"
"time"
"github.com/sipeed/picoclaw/pkg/bus"
"github.com/sipeed/picoclaw/pkg/config"
"github.com/sipeed/picoclaw/pkg/providers"
)
// ---------------------------------------------------------------------------
// Factory registry tests
// ---------------------------------------------------------------------------
func TestRegisterContextManager_Success(t *testing.T) {
cleanup := resetCMRegistry()
defer cleanup()
factory := func(cfg json.RawMessage, al *AgentLoop) (ContextManager, error) {
return &noopContextManager{}, nil
}
if err := RegisterContextManager("test_cm", factory); err != nil {
t.Fatalf("unexpected error: %v", err)
}
f, ok := lookupContextManager("test_cm")
if !ok {
t.Fatal("expected factory to be registered")
}
if f == nil {
t.Fatal("expected non-nil factory")
}
}
func TestRegisterContextManager_EmptyName(t *testing.T) {
cleanup := resetCMRegistry()
defer cleanup()
err := RegisterContextManager("", func(cfg json.RawMessage, al *AgentLoop) (ContextManager, error) {
return &noopContextManager{}, nil
})
if err == nil {
t.Fatal("expected error for empty name")
}
if !strings.Contains(err.Error(), "name is required") {
t.Fatalf("unexpected error: %v", err)
}
}
func TestRegisterContextManager_NilFactory(t *testing.T) {
cleanup := resetCMRegistry()
defer cleanup()
err := RegisterContextManager("nil_factory", nil)
if err == nil {
t.Fatal("expected error for nil factory")
}
if !strings.Contains(err.Error(), "factory is nil") {
t.Fatalf("unexpected error: %v", err)
}
}
func TestRegisterContextManager_Duplicate(t *testing.T) {
cleanup := resetCMRegistry()
defer cleanup()
factory := func(cfg json.RawMessage, al *AgentLoop) (ContextManager, error) {
return &noopContextManager{}, nil
}
if err := RegisterContextManager("dup_cm", factory); err != nil {
t.Fatalf("first registration failed: %v", err)
}
err := RegisterContextManager("dup_cm", factory)
if err == nil {
t.Fatal("expected error for duplicate registration")
}
if !strings.Contains(err.Error(), "already registered") {
t.Fatalf("unexpected error: %v", err)
}
}
func TestLookupContextManager_Unknown(t *testing.T) {
cleanup := resetCMRegistry()
defer cleanup()
_, ok := lookupContextManager("nonexistent")
if ok {
t.Fatal("expected lookup to fail for unknown name")
}
}
// ---------------------------------------------------------------------------
// resolveContextManager tests
// ---------------------------------------------------------------------------
func TestResolveContextManager_Default(t *testing.T) {
cleanup := resetCMRegistry()
defer cleanup()
cfg := &config.Config{
Agents: config.AgentsConfig{
Defaults: config.AgentDefaults{
Workspace: t.TempDir(),
ModelName: "test-model",
MaxTokens: 4096,
MaxToolIterations: 10,
ContextManager: "", // default → legacy
},
},
}
al := newCMTestAgentLoop(cfg)
cm := al.contextManager
if cm == nil {
t.Fatal("expected non-nil context manager")
}
if _, ok := cm.(*legacyContextManager); !ok {
t.Fatalf("expected *legacyContextManager, got %T", cm)
}
}
func TestResolveContextManager_ExplicitLegacy(t *testing.T) {
cleanup := resetCMRegistry()
defer cleanup()
cfg := &config.Config{
Agents: config.AgentsConfig{
Defaults: config.AgentDefaults{
Workspace: t.TempDir(),
ModelName: "test-model",
MaxTokens: 4096,
MaxToolIterations: 10,
ContextManager: "legacy",
},
},
}
al := newCMTestAgentLoop(cfg)
if _, ok := al.contextManager.(*legacyContextManager); !ok {
t.Fatalf("expected *legacyContextManager, got %T", al.contextManager)
}
}
func TestResolveContextManager_UnknownFallsBackToLegacy(t *testing.T) {
cleanup := resetCMRegistry()
defer cleanup()
cfg := &config.Config{
Agents: config.AgentsConfig{
Defaults: config.AgentDefaults{
Workspace: t.TempDir(),
ModelName: "test-model",
MaxTokens: 4096,
MaxToolIterations: 10,
ContextManager: "unknown_cm",
},
},
}
al := newCMTestAgentLoop(cfg)
if _, ok := al.contextManager.(*legacyContextManager); !ok {
t.Fatalf("expected fallback to *legacyContextManager, got %T", al.contextManager)
}
}
func TestResolveContextManager_RegisteredFactory(t *testing.T) {
cleanup := resetCMRegistry()
defer cleanup()
factory := func(cfg json.RawMessage, al *AgentLoop) (ContextManager, error) {
return &noopContextManager{}, nil
}
if err := RegisterContextManager("custom_cm", factory); err != nil {
t.Fatalf("register failed: %v", err)
}
cfg := &config.Config{
Agents: config.AgentsConfig{
Defaults: config.AgentDefaults{
Workspace: t.TempDir(),
ModelName: "test-model",
MaxTokens: 4096,
MaxToolIterations: 10,
ContextManager: "custom_cm",
},
},
}
al := newCMTestAgentLoop(cfg)
if _, ok := al.contextManager.(*noopContextManager); !ok {
t.Fatalf("expected *noopContextManager, got %T", al.contextManager)
}
}
func TestResolveContextManager_FactoryError(t *testing.T) {
cleanup := resetCMRegistry()
defer cleanup()
factory := func(cfg json.RawMessage, al *AgentLoop) (ContextManager, error) {
return nil, os.ErrPermission
}
if err := RegisterContextManager("broken_cm", factory); err != nil {
t.Fatalf("register failed: %v", err)
}
cfg := &config.Config{
Agents: config.AgentsConfig{
Defaults: config.AgentDefaults{
Workspace: t.TempDir(),
ModelName: "test-model",
MaxTokens: 4096,
MaxToolIterations: 10,
ContextManager: "broken_cm",
},
},
}
al := newCMTestAgentLoop(cfg)
// Should fall back to legacy when factory returns error
if _, ok := al.contextManager.(*legacyContextManager); !ok {
t.Fatalf("expected fallback to *legacyContextManager on factory error, got %T", al.contextManager)
}
}
// ---------------------------------------------------------------------------
// Legacy Assemble tests
// ---------------------------------------------------------------------------
func TestLegacyAssemble_Passthrough(t *testing.T) {
cfg := testConfig(t)
al := newCMTestAgentLoop(cfg)
agent := al.registry.GetDefaultAgent()
if agent == nil {
t.Fatal("expected default agent")
}
history := []providers.Message{
{Role: "user", Content: "hello"},
{Role: "assistant", Content: "hi there"},
}
agent.Sessions.SetHistory("test-session", history)
resp, err := al.contextManager.Assemble(context.Background(), &AssembleRequest{
SessionKey: "test-session",
Budget: 8000,
MaxTokens: 4096,
})
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(resp.History) != len(history) {
t.Fatalf("expected %d messages, got %d", len(history), len(resp.History))
}
for i, msg := range resp.History {
if msg.Content != history[i].Content || msg.Role != history[i].Role {
t.Fatalf("message %d mismatch: want %+v, got %+v", i, history[i], msg)
}
}
}
func TestLegacyAssemble_EmptyHistory(t *testing.T) {
cfg := testConfig(t)
al := newCMTestAgentLoop(cfg)
resp, err := al.contextManager.Assemble(context.Background(), &AssembleRequest{
SessionKey: "test-session",
Budget: 8000,
MaxTokens: 4096,
})
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(resp.History) != 0 {
t.Fatalf("expected empty messages, got %d", len(resp.History))
}
}
// ---------------------------------------------------------------------------
// Legacy Compact overflow tests
// ---------------------------------------------------------------------------
func TestLegacyCompact_Overflow(t *testing.T) {
cfg := testConfig(t)
al := newCMTestAgentLoop(cfg)
defaultAgent := al.registry.GetDefaultAgent()
if defaultAgent == nil {
t.Fatal("expected default agent")
}
history := []providers.Message{
{Role: "user", Content: "msg 1"},
{Role: "assistant", Content: "resp 1"},
{Role: "user", Content: "msg 2"},
{Role: "assistant", Content: "resp 2"},
{Role: "user", Content: "msg 3"},
}
defaultAgent.Sessions.SetHistory("session-overflow", history)
sub := al.SubscribeEvents(16)
defer al.UnsubscribeEvents(sub.ID)
err := al.contextManager.Compact(context.Background(), &CompactRequest{
SessionKey: "session-overflow",
Reason: ContextCompressReasonRetry,
})
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
// After overflow compression, history should be shorter
newHistory := defaultAgent.Sessions.GetHistory("session-overflow")
if len(newHistory) >= len(history) {
t.Fatalf("expected compressed history, got %d messages (was %d)", len(newHistory), len(history))
}
// Summary should contain compression note
summary := defaultAgent.Sessions.GetSummary("session-overflow")
if !strings.Contains(summary, "Emergency compression") {
t.Fatalf("expected compression note in summary, got %q", summary)
}
// Event should carry the proactive reason
events := collectEventStream(sub.C)
compressEvt, ok := findEvent(events, EventKindContextCompress)
if !ok {
t.Fatal("expected context compress event")
}
payload, ok := compressEvt.Payload.(ContextCompressPayload)
if !ok {
t.Fatalf("expected ContextCompressPayload, got %T", compressEvt.Payload)
}
if payload.Reason != ContextCompressReasonRetry {
t.Fatalf("expected retry reason, got %q", payload.Reason)
}
}
func TestLegacyCompact_Overflow_ProactiveReason(t *testing.T) {
cfg := testConfig(t)
al := newCMTestAgentLoop(cfg)
defaultAgent := al.registry.GetDefaultAgent()
if defaultAgent == nil {
t.Fatal("expected default agent")
}
history := []providers.Message{
{Role: "user", Content: "msg 1"},
{Role: "assistant", Content: "resp 1"},
{Role: "user", Content: "msg 2"},
{Role: "assistant", Content: "resp 2"},
{Role: "user", Content: "msg 3"},
}
defaultAgent.Sessions.SetHistory("session-proactive", history)
sub := al.SubscribeEvents(16)
defer al.UnsubscribeEvents(sub.ID)
err := al.contextManager.Compact(context.Background(), &CompactRequest{
SessionKey: "session-proactive",
Reason: ContextCompressReasonProactive,
})
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
events := collectEventStream(sub.C)
compressEvt, ok := findEvent(events, EventKindContextCompress)
if !ok {
t.Fatal("expected context compress event")
}
payload, ok := compressEvt.Payload.(ContextCompressPayload)
if !ok {
t.Fatalf("expected ContextCompressPayload, got %T", compressEvt.Payload)
}
if payload.Reason != ContextCompressReasonProactive {
t.Fatalf("expected proactive reason, got %q", payload.Reason)
}
}
func TestLegacyCompact_Overflow_TooShortToCompress(t *testing.T) {
cfg := testConfig(t)
al := newCMTestAgentLoop(cfg)
defaultAgent := al.registry.GetDefaultAgent()
if defaultAgent == nil {
t.Fatal("expected default agent")
}
history := []providers.Message{
{Role: "user", Content: "only one"},
}
defaultAgent.Sessions.SetHistory("session-tiny", history)
err := al.contextManager.Compact(context.Background(), &CompactRequest{
SessionKey: "session-tiny",
Reason: ContextCompressReasonRetry,
})
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
// History should be unchanged (too short to compress)
newHistory := defaultAgent.Sessions.GetHistory("session-tiny")
if len(newHistory) != len(history) {
t.Fatalf("expected history unchanged, got %d messages (was %d)", len(newHistory), len(history))
}
}
// ---------------------------------------------------------------------------
// Legacy Compact post-turn tests
// ---------------------------------------------------------------------------
func TestLegacyCompact_PostTurn_BelowThreshold(t *testing.T) {
cfg := testConfig(t)
al := newCMTestAgentLoop(cfg)
defaultAgent := al.registry.GetDefaultAgent()
if defaultAgent == nil {
t.Fatal("expected default agent")
}
// Small history, below summarization thresholds
history := []providers.Message{
{Role: "user", Content: "hi"},
{Role: "assistant", Content: "hello"},
}
defaultAgent.Sessions.SetHistory("session-small", history)
err := al.contextManager.Compact(context.Background(), &CompactRequest{
SessionKey: "session-small",
Reason: ContextCompressReasonSummarize,
})
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
// History should remain unchanged
newHistory := defaultAgent.Sessions.GetHistory("session-small")
if len(newHistory) != len(history) {
t.Fatalf("expected unchanged history, got %d messages (was %d)", len(newHistory), len(history))
}
}
func TestLegacyCompact_PostTurn_ExceedsMessageThreshold(t *testing.T) {
cfg := &config.Config{
Agents: config.AgentsConfig{
Defaults: config.AgentDefaults{
Workspace: t.TempDir(),
ModelName: "test-model",
MaxTokens: 4096,
MaxToolIterations: 10,
ContextWindow: 8000,
SummarizeMessageThreshold: 2,
SummarizeTokenPercent: 75,
},
},
}
msgBus := bus.NewMessageBus()
al := NewAgentLoop(cfg, msgBus, &simpleMockProvider{response: "summary"})
defaultAgent := al.registry.GetDefaultAgent()
if defaultAgent == nil {
t.Fatal("expected default agent")
}
// 6 messages > threshold of 2
history := []providers.Message{
{Role: "user", Content: "q1"},
{Role: "assistant", Content: "a1"},
{Role: "user", Content: "q2"},
{Role: "assistant", Content: "a2"},
{Role: "user", Content: "q3"},
{Role: "assistant", Content: "a3"},
}
defaultAgent.Sessions.SetHistory("session-threshold", history)
err := al.contextManager.Compact(context.Background(), &CompactRequest{
SessionKey: "session-threshold",
Reason: ContextCompressReasonSummarize,
})
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
// Wait for async summarization to complete via event
sub := al.SubscribeEvents(16)
defer al.UnsubscribeEvents(sub.ID)
waitForEvent(t, sub.C, 5*time.Second, func(evt Event) bool {
return evt.Kind == EventKindSessionSummarize
})
newHistory := defaultAgent.Sessions.GetHistory("session-threshold")
if len(newHistory) >= len(history) {
t.Fatalf("expected summarization to reduce history from %d messages, got %d", len(history), len(newHistory))
}
}
// ---------------------------------------------------------------------------
// Legacy Ingest tests
// ---------------------------------------------------------------------------
func TestLegacyIngest_NoOp(t *testing.T) {
cfg := testConfig(t)
al := newCMTestAgentLoop(cfg)
err := al.contextManager.Ingest(context.Background(), &IngestRequest{
SessionKey: "session-ingest",
Message: providers.Message{Role: "user", Content: "test"},
})
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
}
// ---------------------------------------------------------------------------
// Mock ContextManager — verifies dispatch through AgentLoop
// ---------------------------------------------------------------------------
func TestAgentLoop_UsesCustomContextManager(t *testing.T) {
cleanup := resetCMRegistry()
defer cleanup()
mock := &trackingContextManager{}
factory := func(cfg json.RawMessage, al *AgentLoop) (ContextManager, error) {
return mock, nil
}
if err := RegisterContextManager("tracking_cm", factory); err != nil {
t.Fatalf("register failed: %v", err)
}
cfg := &config.Config{
Agents: config.AgentsConfig{
Defaults: config.AgentDefaults{
Workspace: t.TempDir(),
ModelName: "test-model",
MaxTokens: 4096,
MaxToolIterations: 10,
ContextManager: "tracking_cm",
},
},
}
al := newCMTestAgentLoop(cfg)
// Verify the mock was installed
if al.contextManager != mock {
t.Fatalf("expected mock context manager, got %T", al.contextManager)
}
// Direct method calls
_, err := mock.Assemble(context.Background(), &AssembleRequest{
SessionKey: "s1",
Budget: 8000,
MaxTokens: 4096,
})
if err != nil {
t.Fatalf("Assemble error: %v", err)
}
if mock.assembleCalls.Load() != 1 {
t.Fatalf("expected 1 assemble call, got %d", mock.assembleCalls.Load())
}
err = mock.Compact(context.Background(), &CompactRequest{
SessionKey: "s1",
Reason: ContextCompressReasonRetry,
})
if err != nil {
t.Fatalf("Compact error: %v", err)
}
if mock.compactCalls.Load() != 1 {
t.Fatalf("expected 1 compact call, got %d", mock.compactCalls.Load())
}
err = mock.Ingest(context.Background(), &IngestRequest{
SessionKey: "s1",
Message: providers.Message{Role: "user", Content: "test"},
})
if err != nil {
t.Fatalf("Ingest error: %v", err)
}
if mock.ingestCalls.Load() != 1 {
t.Fatalf("expected 1 ingest call, got %d", mock.ingestCalls.Load())
}
}
func TestIngestCalledDuringTurn(t *testing.T) {
cleanup := resetCMRegistry()
defer cleanup()
mock := &trackingContextManager{}
factory := func(cfg json.RawMessage, al *AgentLoop) (ContextManager, error) {
return mock, nil
}
if err := RegisterContextManager("ingest_track_cm", factory); err != nil {
t.Fatalf("register failed: %v", err)
}
cfg := &config.Config{
Agents: config.AgentsConfig{
Defaults: config.AgentDefaults{
Workspace: t.TempDir(),
ModelName: "test-model",
MaxTokens: 4096,
MaxToolIterations: 10,
ContextManager: "ingest_track_cm",
},
},
}
msgBus := bus.NewMessageBus()
al := NewAgentLoop(cfg, msgBus, &simpleMockProvider{response: "done"})
defaultAgent := al.registry.GetDefaultAgent()
if defaultAgent == nil {
t.Fatal("expected default agent")
}
// Run a turn — ingestMessage is called for user message and final assistant message
_, err := al.runAgentLoop(context.Background(), defaultAgent, processOptions{
SessionKey: "session-ingest-turn",
Channel: "cli",
ChatID: "direct",
UserMessage: "test ingest",
DefaultResponse: defaultResponse,
EnableSummary: false,
SendResponse: false,
})
if err != nil {
t.Fatalf("runAgentLoop failed: %v", err)
}
// Should have at least 2 ingest calls: user message + final assistant message
if mock.ingestCalls.Load() < 2 {
t.Fatalf("expected >= 2 ingest calls during turn, got %d", mock.ingestCalls.Load())
}
}
// ---------------------------------------------------------------------------
// forceCompression edge cases (via legacy Compact)
// ---------------------------------------------------------------------------
func TestLegacyCompact_Overflow_SingleTurnKeepsLastUserMessage(t *testing.T) {
cfg := testConfig(t)
al := newCMTestAgentLoop(cfg)
defaultAgent := al.registry.GetDefaultAgent()
if defaultAgent == nil {
t.Fatal("expected default agent")
}
// History with only 2 messages — forceCompression should still handle it
history := []providers.Message{
{Role: "user", Content: "first question"},
{Role: "assistant", Content: "first answer"},
}
defaultAgent.Sessions.SetHistory("session-2msg", history)
err := al.contextManager.Compact(context.Background(), &CompactRequest{
SessionKey: "session-2msg",
Reason: ContextCompressReasonRetry,
})
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
newHistory := defaultAgent.Sessions.GetHistory("session-2msg")
// With 2 messages, forceCompression returns false (len <= 2), so no compression
if len(newHistory) != len(history) {
t.Fatalf("expected no compression for 2-message history, got %d", len(newHistory))
}
}
// ---------------------------------------------------------------------------
// Test helpers
// ---------------------------------------------------------------------------
// noopContextManager is a minimal ContextManager that does nothing.
type noopContextManager struct{}
func (m *noopContextManager) Assemble(_ context.Context, req *AssembleRequest) (*AssembleResponse, error) {
return &AssembleResponse{}, nil
}
func (m *noopContextManager) Compact(_ context.Context, _ *CompactRequest) error { return nil }
func (m *noopContextManager) Ingest(_ context.Context, _ *IngestRequest) error { return nil }
// trackingContextManager tracks call counts for each method.
type trackingContextManager struct {
assembleCalls atomic.Int64
compactCalls atomic.Int64
ingestCalls atomic.Int64
mu sync.Mutex
lastAssemble *AssembleRequest
lastCompact *CompactRequest
lastIngest *IngestRequest
}
func (m *trackingContextManager) Assemble(_ context.Context, req *AssembleRequest) (*AssembleResponse, error) {
m.assembleCalls.Add(1)
m.mu.Lock()
m.lastAssemble = req
m.mu.Unlock()
return &AssembleResponse{}, nil
}
func (m *trackingContextManager) Compact(_ context.Context, req *CompactRequest) error {
m.compactCalls.Add(1)
m.mu.Lock()
m.lastCompact = req
m.mu.Unlock()
return nil
}
func (m *trackingContextManager) Ingest(_ context.Context, req *IngestRequest) error {
m.ingestCalls.Add(1)
m.mu.Lock()
m.lastIngest = req
m.mu.Unlock()
return nil
}
// resetCMRegistry clears the global factory registry and returns a cleanup
// function that restores the original state after the test.
func resetCMRegistry() func() {
cmRegistryMu.Lock()
original := make(map[string]ContextManagerFactory, len(cmRegistry))
for k, v := range cmRegistry {
original[k] = v
}
cmRegistry = make(map[string]ContextManagerFactory)
cmRegistryMu.Unlock()
return func() {
cmRegistryMu.Lock()
cmRegistry = original
cmRegistryMu.Unlock()
}
}
func testConfig(t *testing.T) *config.Config {
t.Helper()
return &config.Config{
Agents: config.AgentsConfig{
Defaults: config.AgentDefaults{
Workspace: t.TempDir(),
ModelName: "test-model",
MaxTokens: 4096,
MaxToolIterations: 10,
},
},
}
}
func newCMTestAgentLoop(cfg *config.Config) *AgentLoop {
msgBus := bus.NewMessageBus()
return NewAgentLoop(cfg, msgBus, &simpleMockProvider{response: "test"})
}
+269
View File
@@ -0,0 +1,269 @@
//go:build !mipsle && !netbsd
package agent
import (
"context"
"encoding/json"
"fmt"
"github.com/sipeed/picoclaw/pkg/logger"
"github.com/sipeed/picoclaw/pkg/providers"
"github.com/sipeed/picoclaw/pkg/providers/protocoltypes"
"github.com/sipeed/picoclaw/pkg/seahorse"
"github.com/sipeed/picoclaw/pkg/session"
"github.com/sipeed/picoclaw/pkg/tokenizer"
)
// seahorseContextManager adapts seahorse.Engine to agent.ContextManager.
type seahorseContextManager struct {
engine *seahorse.Engine
sessions session.SessionStore // for startup bootstrap
}
// newSeahorseContextManager creates a seahorse-backed ContextManager.
func newSeahorseContextManager(_ json.RawMessage, al *AgentLoop) (ContextManager, error) {
if al == nil {
return nil, fmt.Errorf("seahorse: AgentLoop is required")
}
// Resolve workspace for DB path
// DB stores session data, so it goes in sessions/ directory
agent := al.registry.GetDefaultAgent()
dbPath := agent.Workspace + "/sessions/seahorse.db"
// Create CompleteFn from provider
completeFn := providerToCompleteFn(agent.Provider, agent.Model)
// Create engine
engine, err := seahorse.NewEngine(seahorse.Config{
DBPath: dbPath,
}, completeFn)
if err != nil {
return nil, fmt.Errorf("seahorse: create engine: %w", err)
}
mgr := &seahorseContextManager{
engine: engine,
sessions: agent.Sessions,
}
// Register seahorse tools with the agent's tool registry
retrieval := mgr.engine.GetRetrieval()
al.RegisterTool(seahorse.NewGrepTool(retrieval))
al.RegisterTool(seahorse.NewExpandTool(retrieval))
// Bootstrap all existing sessions at startup
if agent.Sessions != nil {
ctx := context.Background()
for _, sessionKey := range agent.Sessions.ListSessions() {
mgr.bootstrapSession(ctx, sessionKey)
}
}
return mgr, nil
}
// providerToCompleteFn wraps providers.LLMProvider as a seahorse.CompleteFn.
func providerToCompleteFn(provider providers.LLMProvider, model string) seahorse.CompleteFn {
return func(ctx context.Context, prompt string, opts seahorse.CompleteOptions) (string, error) {
resp, err := provider.Chat(
ctx,
[]providers.Message{{Role: "user", Content: prompt}},
nil, // no tools for summarization
model,
map[string]any{
"max_tokens": opts.MaxTokens,
"temperature": opts.Temperature,
"prompt_cache_key": "seahorse",
},
)
if err != nil {
return "", err
}
return resp.Content, nil
}
}
// Assemble builds budget-aware context from seahorse SQLite.
func (m *seahorseContextManager) Assemble(ctx context.Context, req *AssembleRequest) (*AssembleResponse, error) {
if req == nil {
return nil, fmt.Errorf("seahorse assemble: nil request")
}
budget := req.Budget
if budget <= 0 {
budget = 100000
}
// Reserve space for model response (spec lines 1400-1410)
effectiveBudget := budget - req.MaxTokens
if effectiveBudget <= 0 {
// MaxTokens >= budget is a configuration problem
// Use 50% as minimum to avoid guaranteed overflow
logger.WarnCF("agent", "MaxTokens >= budget, using 50% fallback",
map[string]any{"budget": budget, "max_tokens": req.MaxTokens})
effectiveBudget = budget / 2
}
result, err := m.engine.Assemble(ctx, req.SessionKey, seahorse.AssembleInput{
Budget: effectiveBudget,
})
if err != nil {
return nil, fmt.Errorf("seahorse assemble: %w", err)
}
history := seahorseToProviderMessages(result)
// Summary is already formatted as XML with system prompt addition by assembler
return &AssembleResponse{
History: history,
Summary: result.Summary,
}, nil
}
// Compact compresses conversation history via seahorse summarization.
func (m *seahorseContextManager) Compact(ctx context.Context, req *CompactRequest) error {
if req == nil {
return nil
}
// For retry (LLM overflow), use aggressive CompactUntilUnder to guarantee
// context shrinks below budget (spec lines ~1410).
if req.Reason == ContextCompressReasonRetry && req.Budget > 0 {
_, err := m.engine.CompactUntilUnder(ctx, req.SessionKey, req.Budget)
return err
}
_, err := m.engine.Compact(ctx, req.SessionKey, seahorse.CompactInput{
Force: req.Reason == ContextCompressReasonRetry,
Budget: &req.Budget,
})
return err
}
// Ingest records a message into seahorse SQLite.
// All existing sessions are bootstrapped at startup, so this only ingests new messages.
func (m *seahorseContextManager) Ingest(ctx context.Context, req *IngestRequest) error {
if req == nil {
return nil
}
msg := providerToSeahorseMessage(req.Message)
_, err := m.engine.Ingest(ctx, req.SessionKey, []seahorse.Message{msg})
return err
}
// bootstrapSession reconciles JSONL session history into seahorse SQLite.
func (m *seahorseContextManager) bootstrapSession(ctx context.Context, sessionKey string) {
if m.sessions == nil {
return
}
history := m.sessions.GetHistory(sessionKey)
if len(history) == 0 {
return
}
// Convert provider messages to seahorse messages
msgs := make([]seahorse.Message, len(history))
for i, h := range history {
msgs[i] = providerToSeahorseMessage(h)
}
if err := m.engine.Bootstrap(ctx, sessionKey, msgs); err != nil {
logger.WarnCF("seahorse", "bootstrap", map[string]any{
"session": sessionKey,
"error": err.Error(),
})
}
}
// providerToSeahorseMessage converts a providers.Message to a seahorse.Message.
func providerToSeahorseMessage(msg protocoltypes.Message) seahorse.Message {
result := seahorse.Message{
Role: msg.Role,
Content: msg.Content,
ReasoningContent: msg.ReasoningContent,
TokenCount: tokenizer.EstimateMessageTokens(msg),
}
// Convert ToolCalls → MessageParts
for _, tc := range msg.ToolCalls {
part := seahorse.MessagePart{
Type: "tool_use",
Name: tc.Function.Name,
Arguments: tc.Function.Arguments,
ToolCallID: tc.ID,
}
result.Parts = append(result.Parts, part)
}
// Convert tool result
if msg.ToolCallID != "" {
part := seahorse.MessagePart{
Type: "tool_result",
ToolCallID: msg.ToolCallID,
Text: msg.Content,
}
result.Parts = append(result.Parts, part)
}
// Convert media attachments
for _, mediaURI := range msg.Media {
part := seahorse.MessagePart{
Type: "media",
MediaURI: mediaURI,
}
result.Parts = append(result.Parts, part)
}
return result
}
// seahorseToProviderMessages converts a seahorse.AssembleResult to []providers.Message.
func seahorseToProviderMessages(result *seahorse.AssembleResult) []protocoltypes.Message {
messages := make([]protocoltypes.Message, 0, len(result.Messages))
// Convert assembled messages (which already include summary XML messages)
for _, msg := range result.Messages {
pm := protocoltypes.Message{
Role: msg.Role,
Content: msg.Content,
ReasoningContent: msg.ReasoningContent,
}
// Reconstruct ToolCalls from parts
for _, part := range msg.Parts {
if part.Type == "tool_use" {
pm.ToolCalls = append(pm.ToolCalls, protocoltypes.ToolCall{
ID: part.ToolCallID,
Type: "function", // Required by OpenAI-compatible APIs (GLM, etc.)
Function: &protocoltypes.FunctionCall{
Name: part.Name,
Arguments: part.Arguments,
},
})
}
if part.Type == "tool_result" {
pm.ToolCallID = part.ToolCallID
if pm.Content == "" && part.Text != "" {
pm.Content = part.Text
}
}
if part.Type == "media" && part.MediaURI != "" {
pm.Media = append(pm.Media, part.MediaURI)
}
}
messages = append(messages, pm)
}
return messages
}
func init() {
if err := RegisterContextManager("seahorse", newSeahorseContextManager); err != nil {
panic(fmt.Sprintf("register seahorse context manager: %v", err))
}
}
File diff suppressed because it is too large Load Diff
+20
View File
@@ -0,0 +1,20 @@
//go:build mipsle || netbsd
package agent
import (
"encoding/json"
"fmt"
)
// newSeahorseContextManager is unavailable on platforms where modernc sqlite/libc
// currently has no stable build path for this project.
func newSeahorseContextManager(_ json.RawMessage, _ *AgentLoop) (ContextManager, error) {
return nil, fmt.Errorf("seahorse context manager is unavailable on this platform")
}
func init() {
if err := RegisterContextManager("seahorse", newSeahorseContextManager); err != nil {
panic(fmt.Sprintf("register seahorse context manager: %v", err))
}
}
+2 -2
View File
@@ -511,8 +511,8 @@ func TestAgentLoop_EmitsSessionSummarizeEvent(t *testing.T) {
sub := al.SubscribeEvents(16)
defer al.UnsubscribeEvents(sub.ID)
turnScope := al.newTurnEventScope(defaultAgent.ID, "session-1", nil)
al.summarizeSession(defaultAgent, "session-1", turnScope)
lcm := &legacyContextManager{al: al}
lcm.summarizeSession(defaultAgent, "session-1")
events := collectEventStream(sub.C)
summaryEvt, ok := findEvent(events, EventKindSessionSummarize)
+2
View File
@@ -167,6 +167,8 @@ const (
ContextCompressReasonProactive ContextCompressReason = "proactive_budget"
// ContextCompressReasonRetry indicates compression during context-error retry handling.
ContextCompressReasonRetry ContextCompressReason = "llm_retry"
// ContextCompressReasonSummarize indicates post-turn async summarization.
ContextCompressReasonSummarize ContextCompressReason = "summarize"
)
// ContextCompressPayload describes a forced history compression.
+51 -1
View File
@@ -51,6 +51,10 @@ type AgentInstance struct {
// LightProvider is the concrete provider instance for the configured light model.
// It is only used when routing selects the light tier for a turn.
LightProvider providers.LLMProvider
// CandidateProviders maps "provider/model" keys to per-candidate LLMProvider
// instances. This allows each fallback model to use its own api_base and api_key
// from model_list, instead of inheriting the primary model's provider config.
CandidateProviders map[string]providers.LLMProvider
}
// NewAgentInstance creates an agent instance from config.
@@ -77,7 +81,12 @@ func NewAgentInstance(
if cfg.Tools.IsToolEnabled("read_file") {
maxReadFileSize := cfg.Tools.ReadFile.MaxReadFileSize
toolsRegistry.Register(tools.NewReadFileTool(workspace, readRestrict, maxReadFileSize, allowReadPaths))
switch cfg.Tools.ReadFile.EffectiveMode() {
case config.ReadFileModeLines:
toolsRegistry.Register(tools.NewReadFileLinesTool(workspace, readRestrict, maxReadFileSize, allowReadPaths))
default:
toolsRegistry.Register(tools.NewReadFileBytesTool(workspace, readRestrict, maxReadFileSize, allowReadPaths))
}
}
if cfg.Tools.IsToolEnabled("write_file") {
toolsRegistry.Register(tools.NewWriteFileTool(workspace, restrict, allowWritePaths))
@@ -170,6 +179,9 @@ func NewAgentInstance(
// Resolve fallback candidates
candidates := resolveModelCandidates(cfg, defaults.Provider, model, fallbacks)
candidateProviders := make(map[string]providers.LLMProvider)
populateCandidateProvidersFromNames(cfg, workspace, fallbacks, candidateProviders)
// Model routing setup: pre-resolve light model candidates at creation time
// to avoid repeated model_list lookups on every incoming message.
var router *routing.Router
@@ -194,6 +206,7 @@ func NewAgentInstance(
})
lightCandidates = resolved
lightProvider = lp
populateCandidateProvidersFromNames(cfg, workspace, []string{rc.LightModel}, candidateProviders)
}
}
} else {
@@ -225,6 +238,43 @@ func NewAgentInstance(
Router: router,
LightCandidates: lightCandidates,
LightProvider: lightProvider,
CandidateProviders: candidateProviders,
}
}
// populateCandidateProvidersFromNames resolves each model name (alias or
// "provider/model") via resolvedModelConfig and creates a dedicated LLMProvider
// for it. This reuses the canonical config resolution path (GetModelConfig) so
// alias handling and load-balancing stay consistent with the rest of the codebase.
func populateCandidateProvidersFromNames(
cfg *config.Config,
workspace string,
names []string,
out map[string]providers.LLMProvider,
) {
if cfg == nil || len(names) == 0 {
return
}
for _, name := range names {
mc, err := resolvedModelConfig(cfg, strings.TrimSpace(name), workspace)
if err != nil {
logger.WarnCF("agent",
"fallback provider: no model_list entry found; will inherit primary provider credentials",
map[string]any{"name": name, "error": err.Error()})
continue
}
protocol, modelID := providers.ExtractProtocol(strings.TrimSpace(mc.Model))
key := providers.ModelKey(providers.NormalizeProvider(protocol), modelID)
if _, exists := out[key]; exists {
continue
}
p, _, err := providers.CreateProviderFromConfig(mc)
if err != nil {
logger.WarnCF("agent", "fallback provider: failed to create provider",
map[string]any{"model": mc.Model, "error": err.Error()})
continue
}
out[key] = p
}
}
+287
View File
@@ -9,6 +9,7 @@ import (
"github.com/sipeed/picoclaw/pkg/config"
"github.com/sipeed/picoclaw/pkg/media"
"github.com/sipeed/picoclaw/pkg/providers"
)
func TestNewAgentInstance_UsesDefaultsTemperatureAndMaxTokens(t *testing.T) {
@@ -165,6 +166,58 @@ func TestNewAgentInstance_ResolveCandidatesFromModelListAlias(t *testing.T) {
}
}
func TestNewAgentInstance_PreservesDistinctLimiterIdentityForSharedResolvedModel(t *testing.T) {
tmpDir := t.TempDir()
cfg := &config.Config{
Agents: config.AgentsConfig{
Defaults: config.AgentDefaults{
Workspace: tmpDir,
ModelName: "glm-4.7",
ModelFallbacks: []string{"glm-4.7__key_1"},
},
},
ModelList: []*config.ModelConfig{
{
ModelName: "glm-4.7",
Model: "zhipu/glm-4.7",
RPM: 1,
},
{
ModelName: "glm-4.7__key_1",
Model: "zhipu/glm-4.7",
RPM: 3,
},
},
}
agent := NewAgentInstance(nil, &cfg.Agents.Defaults, cfg, &mockProvider{})
if len(agent.Candidates) != 2 {
t.Fatalf("len(Candidates) = %d, want 2", len(agent.Candidates))
}
first := agent.Candidates[0]
second := agent.Candidates[1]
if first.Provider != "zhipu" || first.Model != "glm-4.7" {
t.Fatalf("first candidate = %s/%s, want zhipu/glm-4.7", first.Provider, first.Model)
}
if second.Provider != "zhipu" || second.Model != "glm-4.7" {
t.Fatalf("second candidate = %s/%s, want zhipu/glm-4.7", second.Provider, second.Model)
}
if first.IdentityKey != "model_name:glm-4.7" {
t.Fatalf("first identity key = %q, want %q", first.IdentityKey, "model_name:glm-4.7")
}
if second.IdentityKey != "model_name:glm-4.7__key_1" {
t.Fatalf("second identity key = %q, want %q", second.IdentityKey, "model_name:glm-4.7__key_1")
}
if first.RPM != 1 {
t.Fatalf("first RPM = %d, want 1", first.RPM)
}
if second.RPM != 3 {
t.Fatalf("second RPM = %d, want 3", second.RPM)
}
}
func TestNewAgentInstance_AllowsMediaTempDirForReadListAndExec(t *testing.T) {
workspace := t.TempDir()
mediaDir := media.TempDir()
@@ -248,6 +301,240 @@ func TestNewAgentInstance_AllowsMediaTempDirForReadListAndExec(t *testing.T) {
}
}
// TestPopulateCandidateProviders_NilCfgIsNoop verifies that passing a nil
// config does not panic and leaves the output map empty.
func TestPopulateCandidateProviders_NilCfgIsNoop(t *testing.T) {
out := map[string]providers.LLMProvider{}
populateCandidateProvidersFromNames(nil, t.TempDir(), []string{"gpt-4o"}, out)
if len(out) != 0 {
t.Fatalf("expected empty map, got %d entries", len(out))
}
}
// TestPopulateCandidateProviders_SkipsExistingKeys verifies that a key already
// present in the output map is not overwritten.
func TestPopulateCandidateProviders_SkipsExistingKeys(t *testing.T) {
existing := &mockProvider{}
key := providers.ModelKey("openai", "gpt-4o")
out := map[string]providers.LLMProvider{key: existing}
cfg := &config.Config{
ModelList: []*config.ModelConfig{
{ModelName: "my-gpt", Model: "openai/gpt-4o", APIKeys: config.SimpleSecureStrings("test-key")},
},
}
populateCandidateProvidersFromNames(cfg, t.TempDir(), []string{"my-gpt"}, out)
if out[key] != existing {
t.Fatal("existing provider entry was overwritten; expected it to be preserved")
}
}
// TestPopulateCandidateProviders_ResolvesAlias verifies that a model_name
// alias (e.g. "my-gpt") is resolved via GetModelConfig and the provider
// is created using the underlying model's config.
func TestPopulateCandidateProviders_ResolvesAlias(t *testing.T) {
workspace := t.TempDir()
out := map[string]providers.LLMProvider{}
cfg := &config.Config{
ModelList: []*config.ModelConfig{
{ModelName: "my-gpt", Model: "openai/gpt-4o", APIBase: "https://api.openai.com/v1", Workspace: workspace},
},
}
populateCandidateProvidersFromNames(cfg, workspace, []string{"my-gpt"}, out)
key := providers.ModelKey("openai", "gpt-4o")
if out[key] == nil {
t.Fatalf("expected CandidateProviders[%q] to be populated for alias", key)
}
}
// TestPopulateCandidateProviders_ResolvesProtocolPrefix verifies that a
// model_list entry using full "provider/model" notation (e.g.
// "gemini/gemma-3-27b-it") is matched correctly when referenced by model_name.
func TestPopulateCandidateProviders_ResolvesProtocolPrefix(t *testing.T) {
workspace := t.TempDir()
out := map[string]providers.LLMProvider{}
cfg := &config.Config{
ModelList: []*config.ModelConfig{
{
ModelName: "gemma",
Model: "gemini/gemma-3-27b-it",
APIKeys: config.SimpleSecureStrings("gemini-test-key"),
Workspace: workspace,
},
},
}
populateCandidateProvidersFromNames(cfg, workspace, []string{"gemma"}, out)
key := providers.ModelKey("gemini", "gemma-3-27b-it")
if out[key] == nil {
t.Fatalf("expected CandidateProviders[%q] to be populated for protocol-prefixed model", key)
}
}
// TestPopulateCandidateProviders_EmptyNamesIsNoop verifies the early-exit
// path when the names slice is empty.
func TestPopulateCandidateProviders_EmptyNamesIsNoop(t *testing.T) {
out := map[string]providers.LLMProvider{}
cfg := &config.Config{
ModelList: []*config.ModelConfig{
{ModelName: "my-gpt", Model: "openai/gpt-4o", APIKeys: config.SimpleSecureStrings("key")},
},
}
populateCandidateProvidersFromNames(cfg, t.TempDir(), nil, out)
if len(out) != 0 {
t.Fatalf("expected empty map, got %d entries", len(out))
}
}
// TestPopulateCandidateProviders_EmptyModelListIsNoop verifies the early-exit
// path when model_list is empty — no provider can be created.
func TestPopulateCandidateProviders_EmptyModelListIsNoop(t *testing.T) {
out := map[string]providers.LLMProvider{}
cfg := &config.Config{}
populateCandidateProvidersFromNames(cfg, t.TempDir(), []string{"gpt-4o"}, out)
if len(out) != 0 {
t.Fatalf("expected empty map, got %d entries", len(out))
}
}
// TestPopulateCandidateProviders_UnmatchedNameIsSkipped verifies that a
// name with no matching model_list entry is skipped and does not
// cause a panic or leave a nil entry in the map.
func TestPopulateCandidateProviders_UnmatchedNameIsSkipped(t *testing.T) {
out := map[string]providers.LLMProvider{}
cfg := &config.Config{
ModelList: []*config.ModelConfig{
{ModelName: "my-gpt", Model: "openai/gpt-4o", APIKeys: config.SimpleSecureStrings("key")},
},
}
populateCandidateProvidersFromNames(cfg, t.TempDir(), []string{"nonexistent-model"}, out)
if len(out) != 0 {
t.Fatalf("expected empty map for unmatched name, got %d entries", len(out))
}
}
// TestNewAgentInstance_CandidateProvidersPopulatedForCrossProviderFallbacks
// mirrors the exact scenario from bug #2140: primary model on OpenRouter with
// Gemini fallbacks. Each entry must get its own provider instance so that
// fallback requests go to the correct API endpoint, not the primary's.
func TestNewAgentInstance_CandidateProvidersPopulatedForCrossProviderFallbacks(t *testing.T) {
workspace := t.TempDir()
cfg := &config.Config{
Agents: config.AgentsConfig{
Defaults: config.AgentDefaults{
Workspace: workspace,
ModelName: "mistral-small-3.1",
ModelFallbacks: []string{"gemma-3-27b", "gemini-images"},
},
},
ModelList: []*config.ModelConfig{
{
ModelName: "mistral-small-3.1",
Model: "openrouter/mistralai/mistral-small-3.1-24b-instruct:free",
APIBase: "https://openrouter.ai/api/v1",
APIKeys: config.SimpleSecureStrings("sk-or-test"),
Workspace: workspace,
},
{
ModelName: "gemma-3-27b",
Model: "gemini/gemma-3-27b-it",
APIKeys: config.SimpleSecureStrings("AIzaSy-test"),
Workspace: workspace,
},
{
ModelName: "gemini-images",
Model: "gemini/gemini-2.5-flash-lite",
APIKeys: config.SimpleSecureStrings("AIzaSy-test"),
Workspace: workspace,
},
},
}
primaryProvider := &mockProvider{}
agent := NewAgentInstance(nil, &cfg.Agents.Defaults, cfg, primaryProvider)
// Only fallback models need entries — the primary uses the injected provider directly.
wantKeys := []string{
providers.ModelKey("gemini", "gemma-3-27b-it"),
providers.ModelKey("gemini", "gemini-2.5-flash-lite"),
}
for _, key := range wantKeys {
p, ok := agent.CandidateProviders[key]
if !ok {
t.Errorf("CandidateProviders missing key %q", key)
continue
}
if p == nil {
t.Errorf("CandidateProviders[%q] is nil", key)
}
// Each fallback must use its own provider, not the injected primary.
if p == primaryProvider {
t.Errorf(
"CandidateProviders[%q] is the same instance as the primary provider; fallback would inherit primary credentials",
key,
)
}
}
if t.Failed() {
t.Logf("CandidateProviders keys present: %v", func() []string {
keys := make([]string, 0, len(agent.CandidateProviders))
for k := range agent.CandidateProviders {
keys = append(keys, k)
}
return keys
}())
}
}
func TestNewAgentInstance_ReadFileModeSelectsSchema(t *testing.T) {
workspace := t.TempDir()
cfg := &config.Config{
Agents: config.AgentsConfig{
Defaults: config.AgentDefaults{
Workspace: workspace,
ModelName: "test-model",
},
},
Tools: config.ToolsConfig{
ReadFile: config.ReadFileToolConfig{
Enabled: true,
Mode: config.ReadFileModeLines,
MaxReadFileSize: 4096,
},
},
}
agent := NewAgentInstance(nil, &cfg.Agents.Defaults, cfg, &mockProvider{})
readTool, ok := agent.Tools.Get("read_file")
if !ok {
t.Fatal("read_file tool not registered")
}
params := readTool.Parameters()
props, _ := params["properties"].(map[string]any)
if _, ok := props["start_line"]; !ok {
t.Fatalf("expected line-mode schema to expose start_line, got %#v", props)
}
if _, ok := props["max_lines"]; !ok {
t.Fatalf("expected line-mode schema to expose max_lines, got %#v", props)
}
if _, ok := props["offset"]; ok {
t.Fatalf("did not expect line-mode schema to expose offset, got %#v", props)
}
if _, ok := props["length"]; ok {
t.Fatalf("did not expect line-mode schema to expose length, got %#v", props)
}
}
func TestNewAgentInstance_InvalidExecConfigDoesNotExit(t *testing.T) {
workspace := t.TempDir()
+181 -374
View File
@@ -18,6 +18,8 @@ import (
"sync/atomic"
"time"
"github.com/sipeed/picoclaw/pkg/audio/asr"
"github.com/sipeed/picoclaw/pkg/audio/tts"
"github.com/sipeed/picoclaw/pkg/bus"
"github.com/sipeed/picoclaw/pkg/channels"
"github.com/sipeed/picoclaw/pkg/commands"
@@ -32,7 +34,6 @@ import (
"github.com/sipeed/picoclaw/pkg/state"
"github.com/sipeed/picoclaw/pkg/tools"
"github.com/sipeed/picoclaw/pkg/utils"
"github.com/sipeed/picoclaw/pkg/voice"
)
type AgentLoop struct {
@@ -48,11 +49,11 @@ type AgentLoop struct {
// Runtime state
running atomic.Bool
summarizing sync.Map
contextManager ContextManager
fallback *providers.FallbackChain
channelManager *channels.Manager
mediaStore media.MediaStore
transcriber voice.Transcriber
transcriber asr.Transcriber
cmdRegistry *commands.Registry
mcp mcpRuntime
hookRuntime hookRuntime
@@ -116,9 +117,18 @@ func NewAgentLoop(
) *AgentLoop {
registry := NewAgentRegistry(cfg, provider)
// Set up shared fallback chain
// Set up shared fallback chain with rate limiting.
cooldown := providers.NewCooldownTracker()
fallbackChain := providers.NewFallbackChain(cooldown)
rl := providers.NewRateLimiterRegistry()
// Register rate limiters for all agents' candidates so that RPM limits
// configured in ModelConfig are enforced before each LLM call.
for _, agentID := range registry.ListAgentIDs() {
if agent, ok := registry.GetAgent(agentID); ok {
rl.RegisterCandidates(agent.Candidates)
rl.RegisterCandidates(agent.LightCandidates)
}
}
fallbackChain := providers.NewFallbackChain(cooldown, rl)
// Create state manager using default agent's workspace for channel recording
defaultAgent := registry.GetDefaultAgent()
@@ -134,13 +144,13 @@ func NewAgentLoop(
registry: registry,
state: stateManager,
eventBus: eventBus,
summarizing: sync.Map{},
fallback: fallbackChain,
cmdRegistry: commands.NewRegistry(commands.BuiltinDefinitions()),
steering: newSteeringQueue(parseSteeringMode(cfg.Agents.Defaults.SteeringMode)),
}
al.hooks = NewHookManager(eventBus)
configureHookManagerFromConfig(al.hooks, cfg)
al.contextManager = al.resolveContextManager()
// Register shared tools to all agents (now that al is created)
registerSharedTools(al, cfg, msgBus, registry, provider)
@@ -157,6 +167,13 @@ func registerSharedTools(
provider providers.LLMProvider,
) {
allowReadPaths := buildAllowReadPatterns(cfg)
var ttsProvider tts.TTSProvider
if cfg.Tools.IsToolEnabled("send_tts") {
ttsProvider = tts.DetectTTS(cfg)
if ttsProvider == nil {
logger.WarnCF("voice-tts", "send_tts enabled but no TTS provider configured", nil)
}
}
for _, agentID := range registry.ListAgentIDs() {
agent, ok := registry.GetAgent(agentID)
@@ -267,6 +284,21 @@ func registerSharedTools(
agent.Tools.Register(sendFileTool)
}
if ttsProvider != nil {
agent.Tools.Register(tools.NewSendTTSTool(ttsProvider, nil))
}
if cfg.Tools.IsToolEnabled("load_image") {
loadImageTool := tools.NewLoadImageTool(
agent.Workspace,
cfg.Agents.Defaults.RestrictToWorkspace,
cfg.Agents.Defaults.GetMaxMediaSize(),
nil,
allowReadPaths,
)
agent.Tools.Register(loadImageTool)
}
// Skill discovery and installation tools
skills_enabled := cfg.Tools.IsToolEnabled("skills")
find_skills_enable := cfg.Tools.IsToolEnabled("find_skills")
@@ -309,6 +341,14 @@ func registerSharedTools(
subagentManager := tools.NewSubagentManager(provider, agent.Model, agent.Workspace)
subagentManager.SetLLMOptions(agent.MaxTokens, agent.Temperature)
// Inject a media resolver so the legacy RunToolLoop fallback path can
// resolve media:// refs in the same way the main AgentLoop does.
// This keeps subagent vision support working even when the optimized
// sub-turn spawner path is unavailable.
subagentManager.SetMediaResolver(func(msgs []providers.Message) []providers.Message {
return resolveMediaRefs(msgs, al.mediaStore, cfg.Agents.Defaults.GetMaxMediaSize())
})
// Set the spawner that links into AgentLoop's turnState
subagentManager.SetSpawner(func(
ctx context.Context,
@@ -1075,6 +1115,7 @@ func (al *AgentLoop) ReloadProviderAndConfig(
go func() {
defer func() {
if r := recover(); r != nil {
logger.RecoverPanicNoExit(r)
panicErr = fmt.Errorf("panic during registry creation: %v", r)
logger.ErrorCF("agent", "Panic during registry creation",
map[string]any{"panic": r})
@@ -1115,8 +1156,15 @@ func (al *AgentLoop) ReloadProviderAndConfig(
al.cfg = cfg
al.registry = registry
// Also update fallback chain with new config
al.fallback = providers.NewFallbackChain(providers.NewCooldownTracker())
// Also update fallback chain with new config; rebuild rate limiter registry.
newRL := providers.NewRateLimiterRegistry()
for _, agentID := range registry.ListAgentIDs() {
if agent, ok := registry.GetAgent(agentID); ok {
newRL.RegisterCandidates(agent.Candidates)
newRL.RegisterCandidates(agent.LightCandidates)
}
}
al.fallback = providers.NewFallbackChain(providers.NewCooldownTracker(), newRL)
al.mu.Unlock()
@@ -1174,10 +1222,15 @@ func (al *AgentLoop) SetMediaStore(s media.MediaStore) {
agent.Tools.SetMediaStore(s)
}
}
registry.ForEachTool("send_tts", func(t tools.Tool) {
if st, ok := t.(*tools.SendTTSTool); ok {
st.SetMediaStore(s)
}
})
}
// SetTranscriber injects a voice transcriber for agent-level audio transcription.
func (al *AgentLoop) SetTranscriber(t voice.Transcriber) {
func (al *AgentLoop) SetTranscriber(t asr.Transcriber) {
al.transcriber = t
}
@@ -1198,19 +1251,23 @@ func (al *AgentLoop) transcribeAudioInMessage(ctx context.Context, msg bus.Inbou
// Transcribe each audio media ref in order.
var transcriptions []string
var keptMedia []string
for _, ref := range msg.Media {
path, meta, err := al.mediaStore.ResolveWithMeta(ref)
if err != nil {
logger.WarnCF("voice", "Failed to resolve media ref", map[string]any{"ref": ref, "error": err})
keptMedia = append(keptMedia, ref)
continue
}
if !utils.IsAudioFile(meta.Filename, meta.ContentType) {
keptMedia = append(keptMedia, ref)
continue
}
result, err := al.transcriber.Transcribe(ctx, path)
if err != nil {
logger.WarnCF("voice", "Transcription failed", map[string]any{"ref": ref, "error": err})
transcriptions = append(transcriptions, "")
keptMedia = append(keptMedia, ref)
continue
}
transcriptions = append(transcriptions, result.Text)
@@ -1230,15 +1287,21 @@ func (al *AgentLoop) transcribeAudioInMessage(ctx context.Context, msg bus.Inbou
}
text := transcriptions[idx]
idx++
if text == "" {
return match
}
return "[voice: " + text + "]"
})
// Append any remaining transcriptions not matched by an annotation.
for ; idx < len(transcriptions); idx++ {
newContent += "\n[voice: " + transcriptions[idx] + "]"
if transcriptions[idx] != "" {
newContent += "\n[voice: " + transcriptions[idx] + "]"
}
}
msg.Content = newContent
msg.Media = keptMedia
return msg, true
}
@@ -1825,8 +1888,15 @@ func (al *AgentLoop) runTurn(ctx context.Context, ts *turnState) (turnResult, er
var history []providers.Message
var summary string
if !ts.opts.NoHistory {
history = ts.agent.Sessions.GetHistory(ts.sessionKey)
summary = ts.agent.Sessions.GetSummary(ts.sessionKey)
// ContextManager assembles budget-aware history and summary.
if resp, err := al.contextManager.Assemble(turnCtx, &AssembleRequest{
SessionKey: ts.sessionKey,
Budget: ts.agent.ContextWindow,
MaxTokens: ts.agent.MaxTokens,
}); err == nil && resp != nil {
history = resp.History
summary = resp.Summary
}
}
ts.captureRestorePoint(history, summary)
@@ -1851,22 +1921,28 @@ func (al *AgentLoop) runTurn(ctx context.Context, ts *turnState) (turnResult, er
if isOverContextBudget(ts.agent.ContextWindow, messages, toolDefs, ts.agent.MaxTokens) {
logger.WarnCF("agent", "Proactive compression: context budget exceeded before LLM call",
map[string]any{"session_key": ts.sessionKey})
if compression, ok := al.forceCompression(ts.agent, ts.sessionKey); ok {
al.emitEvent(
EventKindContextCompress,
ts.eventMeta("runTurn", "turn.context.compress"),
ContextCompressPayload{
Reason: ContextCompressReasonProactive,
DroppedMessages: compression.DroppedMessages,
RemainingMessages: compression.RemainingMessages,
},
)
ts.refreshRestorePointFromSession(ts.agent)
if err := al.contextManager.Compact(turnCtx, &CompactRequest{
SessionKey: ts.sessionKey,
Reason: ContextCompressReasonProactive,
Budget: ts.agent.ContextWindow,
}); err != nil {
logger.WarnCF("agent", "Proactive compact failed", map[string]any{
"session_key": ts.sessionKey,
"error": err.Error(),
})
}
ts.refreshRestorePointFromSession(ts.agent)
// Re-assemble from CM after compact.
if resp, err := al.contextManager.Assemble(turnCtx, &AssembleRequest{
SessionKey: ts.sessionKey,
Budget: ts.agent.ContextWindow,
MaxTokens: ts.agent.MaxTokens,
}); err == nil && resp != nil {
history = resp.History
summary = resp.Summary
}
newHistory := ts.agent.Sessions.GetHistory(ts.sessionKey)
newSummary := ts.agent.Sessions.GetSummary(ts.sessionKey)
messages = ts.agent.ContextBuilder.BuildMessages(
newHistory, newSummary, ts.userMessage,
history, summary, ts.userMessage,
ts.media, ts.channel, ts.chatID,
ts.opts.SenderID, ts.opts.SenderDisplayName,
activeSkillNames(ts.agent, ts.opts)...,
@@ -1888,6 +1964,7 @@ func (al *AgentLoop) runTurn(ctx context.Context, ts *turnState) (turnResult, er
ts.agent.Sessions.AddMessage(ts.sessionKey, rootMsg.Role, rootMsg.Content)
}
ts.recordPersistedMessage(rootMsg)
ts.ingestMessage(turnCtx, al, rootMsg)
}
activeCandidates, activeModel, usedLight := al.selectCandidates(ts.agent, ts.userMessage, messages)
@@ -1963,6 +2040,7 @@ turnLoop:
if !ts.opts.NoHistory {
ts.agent.Sessions.AddFullMessage(ts.sessionKey, pm)
ts.recordPersistedMessage(pm)
ts.ingestMessage(turnCtx, al, pm)
}
logger.InfoCF("agent", "Injected steering message into context",
map[string]any{
@@ -2016,6 +2094,14 @@ turnLoop:
providerToolDefs = filtered
}
// Resolve media:// refs produced by tool results (e.g. load_image).
// Skipped on iteration 1 because inbound user media is already resolved
// before entering the loop; only subsequent iterations can contain new
// tool-generated media refs that need base64 encoding.
if iteration > 1 {
messages = resolveMediaRefs(messages, al.mediaStore, maxMediaSize)
}
callMessages := messages
if gracefulTerminal {
callMessages = append(append([]providers.Message(nil), messages...), ts.interruptHintMessage())
@@ -2115,7 +2201,11 @@ turnLoop:
providerCtx,
activeCandidates,
func(ctx context.Context, provider, model string) (*providers.LLMResponse, error) {
return activeProvider.Chat(ctx, messagesForCall, toolDefsForCall, model, llmOpts)
candidateProvider := activeProvider
if cp, ok := ts.agent.CandidateProviders[providers.ModelKey(provider, model)]; ok {
candidateProvider = cp
}
return candidateProvider.Chat(ctx, messagesForCall, toolDefsForCall, model, llmOpts)
},
)
if fbErr != nil {
@@ -2221,23 +2311,28 @@ turnLoop:
))
}
if compression, ok := al.forceCompression(ts.agent, ts.sessionKey); ok {
al.emitEvent(
EventKindContextCompress,
ts.eventMeta("runTurn", "turn.context.compress"),
ContextCompressPayload{
Reason: ContextCompressReasonRetry,
DroppedMessages: compression.DroppedMessages,
RemainingMessages: compression.RemainingMessages,
},
)
ts.refreshRestorePointFromSession(ts.agent)
if compactErr := al.contextManager.Compact(turnCtx, &CompactRequest{
SessionKey: ts.sessionKey,
Reason: ContextCompressReasonRetry,
Budget: ts.agent.ContextWindow,
}); compactErr != nil {
logger.WarnCF("agent", "Context overflow compact failed", map[string]any{
"session_key": ts.sessionKey,
"error": compactErr.Error(),
})
}
ts.refreshRestorePointFromSession(ts.agent)
// Re-assemble from CM after compact.
if asmResp, asmErr := al.contextManager.Assemble(turnCtx, &AssembleRequest{
SessionKey: ts.sessionKey,
Budget: ts.agent.ContextWindow,
MaxTokens: ts.agent.MaxTokens,
}); asmErr == nil && asmResp != nil {
history = asmResp.History
summary = asmResp.Summary
}
newHistory := ts.agent.Sessions.GetHistory(ts.sessionKey)
newSummary := ts.agent.Sessions.GetSummary(ts.sessionKey)
messages = ts.agent.ContextBuilder.BuildMessages(
newHistory, newSummary, "",
history, summary, "",
nil, ts.channel, ts.chatID, ts.opts.SenderID, ts.opts.SenderDisplayName,
activeSkillNames(ts.agent, ts.opts)...,
)
@@ -2409,6 +2504,7 @@ turnLoop:
if !ts.opts.NoHistory {
ts.agent.Sessions.AddFullMessage(ts.sessionKey, assistantMsg)
ts.recordPersistedMessage(assistantMsg)
ts.ingestMessage(turnCtx, al, assistantMsg)
}
ts.setPhase(TurnPhaseTools)
@@ -2633,6 +2729,7 @@ turnLoop:
if toolResult == nil {
toolResult = tools.ErrorResult("hook returned nil tool result")
}
if len(toolResult.Media) > 0 && toolResult.ResponseHandled {
parts := make([]bus.MediaPart, 0, len(toolResult.Media))
for _, ref := range toolResult.Media {
@@ -2675,6 +2772,13 @@ turnLoop:
}
if len(toolResult.Media) > 0 && !toolResult.ResponseHandled {
// For tools like load_image that produce media refs without sending them
// to the user channel (ResponseHandled == false), both Media and ArtifactTags
// coexist on the result:
// - Media: carries media:// refs that resolveMediaRefs will base64-encode
// into image_url parts in the next LLM iteration (enabling vision).
// - ArtifactTags: exposes the local file path as a structured [file:…] tag
// in the tool result text, so the LLM knows an artifact was produced.
toolResult.ArtifactTags = buildArtifactTags(al.mediaStore, toolResult.Media)
}
@@ -2693,7 +2797,6 @@ turnLoop:
"content_len": len(toolResult.ForUser),
})
}
contentForLLM := toolResult.ContentForLLM()
// Filter sensitive data (API keys, tokens, secrets) before sending to LLM
@@ -2706,6 +2809,9 @@ turnLoop:
Content: contentForLLM,
ToolCallID: toolCallID,
}
if len(toolResult.Media) > 0 && !toolResult.ResponseHandled {
toolResultMsg.Media = append(toolResultMsg.Media, toolResult.Media...)
}
al.emitEvent(
EventKindToolExecEnd,
ts.eventMeta("runTurn", "turn.tool.end"),
@@ -2722,6 +2828,7 @@ turnLoop:
if !ts.opts.NoHistory {
ts.agent.Sessions.AddFullMessage(ts.sessionKey, toolResultMsg)
ts.recordPersistedMessage(toolResultMsg)
ts.ingestMessage(turnCtx, al, toolResultMsg)
}
if steerMsgs := al.dequeueSteeringMessagesForScope(ts.sessionKey); len(steerMsgs) > 0 {
@@ -2821,6 +2928,7 @@ turnLoop:
if !ts.opts.NoHistory {
ts.agent.Sessions.AddMessage(ts.sessionKey, summaryMsg.Role, summaryMsg.Content)
ts.recordPersistedMessage(summaryMsg)
ts.ingestMessage(turnCtx, al, summaryMsg)
if err := ts.agent.Sessions.Save(ts.sessionKey); err != nil {
turnStatus = TurnEndStatusError
al.emitEvent(
@@ -2835,7 +2943,7 @@ turnLoop:
}
}
if ts.opts.EnableSummary {
al.maybeSummarize(ts.agent, ts.sessionKey, ts.scope)
al.contextManager.Compact(turnCtx, &CompactRequest{SessionKey: ts.sessionKey, Reason: ContextCompressReasonSummarize, Budget: ts.agent.ContextWindow})
}
ts.setPhase(TurnPhaseCompleted)
@@ -2890,6 +2998,7 @@ turnLoop:
finalMsg := providers.Message{Role: "assistant", Content: finalContent}
ts.agent.Sessions.AddMessage(ts.sessionKey, finalMsg.Role, finalMsg.Content)
ts.recordPersistedMessage(finalMsg)
ts.ingestMessage(turnCtx, al, finalMsg)
if err := ts.agent.Sessions.Save(ts.sessionKey); err != nil {
turnStatus = TurnEndStatusError
al.emitEvent(
@@ -2905,7 +3014,14 @@ turnLoop:
}
if ts.opts.EnableSummary {
al.maybeSummarize(ts.agent, ts.sessionKey, ts.scope)
al.contextManager.Compact(
turnCtx,
&CompactRequest{
SessionKey: ts.sessionKey,
Reason: ContextCompressReasonSummarize,
Budget: ts.agent.ContextWindow,
},
)
}
ts.setPhase(TurnPhaseCompleted)
@@ -2984,103 +3100,28 @@ func (al *AgentLoop) selectCandidates(
return agent.LightCandidates, resolvedCandidateModel(agent.LightCandidates, agent.Router.LightModel()), true
}
// maybeSummarize triggers summarization if the session history exceeds thresholds.
func (al *AgentLoop) maybeSummarize(agent *AgentInstance, sessionKey string, turnScope turnEventScope) {
newHistory := agent.Sessions.GetHistory(sessionKey)
tokenEstimate := al.estimateTokens(newHistory)
threshold := agent.ContextWindow * agent.SummarizeTokenPercent / 100
if len(newHistory) > agent.SummarizeMessageThreshold || tokenEstimate > threshold {
summarizeKey := agent.ID + ":" + sessionKey
if _, loading := al.summarizing.LoadOrStore(summarizeKey, true); !loading {
go func() {
defer al.summarizing.Delete(summarizeKey)
logger.Debug("Memory threshold reached. Optimizing conversation history...")
al.summarizeSession(agent, sessionKey, turnScope)
}()
}
// resolveContextManager selects the ContextManager implementation based on config.
func (al *AgentLoop) resolveContextManager() ContextManager {
name := al.cfg.Agents.Defaults.ContextManager
if name == "" || name == "legacy" {
return &legacyContextManager{al: al}
}
}
type compressionResult struct {
DroppedMessages int
RemainingMessages int
}
// forceCompression aggressively reduces context when the limit is hit.
// It drops the oldest ~50% of Turns (a Turn is a complete user→LLM→response
// cycle, as defined in #1316), so tool-call sequences are never split.
//
// If the history is a single Turn with no safe split point, the function
// falls back to keeping only the most recent user message. This breaks
// Turn atomicity as a last resort to avoid a context-exceeded loop.
//
// Session history contains only user/assistant/tool messages — the system
// prompt is built dynamically by BuildMessages and is NOT stored here.
// The compression note is recorded in the session summary so that
// BuildMessages can include it in the next system prompt.
func (al *AgentLoop) forceCompression(agent *AgentInstance, sessionKey string) (compressionResult, bool) {
history := agent.Sessions.GetHistory(sessionKey)
if len(history) <= 2 {
return compressionResult{}, false
factory, ok := lookupContextManager(name)
if !ok {
logger.WarnCF("agent", "Unknown context manager, falling back to legacy", map[string]any{
"name": name,
})
return &legacyContextManager{al: al}
}
// Split at a Turn boundary so no tool-call sequence is torn apart.
// parseTurnBoundaries gives us the start of each Turn; we drop the
// oldest half of Turns and keep the most recent ones.
turns := parseTurnBoundaries(history)
var mid int
if len(turns) >= 2 {
mid = turns[len(turns)/2]
} else {
// Fewer than 2 Turns — fall back to message-level midpoint
// aligned to the nearest Turn boundary.
mid = findSafeBoundary(history, len(history)/2)
cm, err := factory(al.cfg.Agents.Defaults.ContextManagerConfig, al)
if err != nil {
logger.WarnCF("agent", "Failed to create context manager, falling back to legacy", map[string]any{
"name": name,
"error": err.Error(),
})
return &legacyContextManager{al: al}
}
var keptHistory []providers.Message
if mid <= 0 {
// No safe Turn boundary — the entire history is a single Turn
// (e.g. one user message followed by a massive tool response).
// Keeping everything would leave the agent stuck in a context-
// exceeded loop, so fall back to keeping only the most recent
// user message. This breaks Turn atomicity as a last resort.
for i := len(history) - 1; i >= 0; i-- {
if history[i].Role == "user" {
keptHistory = []providers.Message{history[i]}
break
}
}
} else {
keptHistory = history[mid:]
}
droppedCount := len(history) - len(keptHistory)
// Record compression in the session summary so BuildMessages includes it
// in the system prompt. We do not modify history messages themselves.
existingSummary := agent.Sessions.GetSummary(sessionKey)
compressionNote := fmt.Sprintf(
"[Emergency compression dropped %d oldest messages due to context limit]",
droppedCount,
)
if existingSummary != "" {
compressionNote = existingSummary + "\n\n" + compressionNote
}
agent.Sessions.SetSummary(sessionKey, compressionNote)
agent.Sessions.SetHistory(sessionKey, keptHistory)
agent.Sessions.Save(sessionKey)
logger.WarnCF("agent", "Forced compression executed", map[string]any{
"session_key": sessionKey,
"dropped_msgs": droppedCount,
"new_count": len(keptHistory),
})
return compressionResult{
DroppedMessages: droppedCount,
RemainingMessages: len(keptHistory),
}, true
return cm
}
// GetStartupInfo returns information about loaded tools and skills for logging.
@@ -3172,247 +3213,13 @@ func formatToolsForLog(toolDefs []providers.ToolDefinition) string {
}
// summarizeSession summarizes the conversation history for a session.
func (al *AgentLoop) summarizeSession(agent *AgentInstance, sessionKey string, turnScope turnEventScope) {
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
defer cancel()
history := agent.Sessions.GetHistory(sessionKey)
summary := agent.Sessions.GetSummary(sessionKey)
// Keep the most recent Turns for continuity, aligned to a Turn boundary
// so that no tool-call sequence is split.
if len(history) <= 4 {
return
}
safeCut := findSafeBoundary(history, len(history)-4)
if safeCut <= 0 {
return
}
keepCount := len(history) - safeCut
toSummarize := history[:safeCut]
// Oversized Message Guard
maxMessageTokens := agent.ContextWindow / 2
validMessages := make([]providers.Message, 0)
omitted := false
for _, m := range toSummarize {
if m.Role != "user" && m.Role != "assistant" {
continue
}
msgTokens := len(m.Content) / 2
if msgTokens > maxMessageTokens {
omitted = true
continue
}
validMessages = append(validMessages, m)
}
if len(validMessages) == 0 {
return
}
const (
maxSummarizationMessages = 10
llmMaxRetries = 3
llmTemperature = 0.3
fallbackMaxContentLength = 200
)
// Multi-Part Summarization
var finalSummary string
if len(validMessages) > maxSummarizationMessages {
mid := len(validMessages) / 2
mid = al.findNearestUserMessage(validMessages, mid)
part1 := validMessages[:mid]
part2 := validMessages[mid:]
s1, _ := al.summarizeBatch(ctx, agent, part1, "")
s2, _ := al.summarizeBatch(ctx, agent, part2, "")
mergePrompt := fmt.Sprintf(
"Merge these two conversation summaries into one cohesive summary:\n\n1: %s\n\n2: %s",
s1,
s2,
)
resp, err := al.retryLLMCall(ctx, agent, mergePrompt, llmMaxRetries)
if err == nil && resp.Content != "" {
finalSummary = resp.Content
} else {
finalSummary = s1 + " " + s2
}
} else {
finalSummary, _ = al.summarizeBatch(ctx, agent, validMessages, summary)
}
if omitted && finalSummary != "" {
finalSummary += "\n[Note: Some oversized messages were omitted from this summary for efficiency.]"
}
if finalSummary != "" {
agent.Sessions.SetSummary(sessionKey, finalSummary)
agent.Sessions.TruncateHistory(sessionKey, keepCount)
agent.Sessions.Save(sessionKey)
al.emitEvent(
EventKindSessionSummarize,
turnScope.meta(0, "summarizeSession", "turn.session.summarize"),
SessionSummarizePayload{
SummarizedMessages: len(validMessages),
KeptMessages: keepCount,
SummaryLen: len(finalSummary),
OmittedOversized: omitted,
},
)
}
}
// findNearestUserMessage finds the nearest user message to the given index.
// It searches backward first, then forward if no user message is found.
func (al *AgentLoop) findNearestUserMessage(messages []providers.Message, mid int) int {
originalMid := mid
for mid > 0 && messages[mid].Role != "user" {
mid--
}
if messages[mid].Role == "user" {
return mid
}
mid = originalMid
for mid < len(messages) && messages[mid].Role != "user" {
mid++
}
if mid < len(messages) {
return mid
}
return originalMid
}
// retryLLMCall calls the LLM with retry logic.
func (al *AgentLoop) retryLLMCall(
ctx context.Context,
agent *AgentInstance,
prompt string,
maxRetries int,
) (*providers.LLMResponse, error) {
const (
llmTemperature = 0.3
)
var resp *providers.LLMResponse
var err error
for attempt := 0; attempt < maxRetries; attempt++ {
al.activeRequests.Add(1)
resp, err = func() (*providers.LLMResponse, error) {
defer al.activeRequests.Done()
return agent.Provider.Chat(
ctx,
[]providers.Message{{Role: "user", Content: prompt}},
nil,
agent.Model,
map[string]any{
"max_tokens": agent.MaxTokens,
"temperature": llmTemperature,
"prompt_cache_key": agent.ID,
},
)
}()
if err == nil && resp != nil && resp.Content != "" {
return resp, nil
}
if attempt < maxRetries-1 {
time.Sleep(time.Duration(attempt+1) * 100 * time.Millisecond)
}
}
return resp, err
}
// summarizeBatch summarizes a batch of messages.
func (al *AgentLoop) summarizeBatch(
ctx context.Context,
agent *AgentInstance,
batch []providers.Message,
existingSummary string,
) (string, error) {
const (
llmMaxRetries = 3
llmTemperature = 0.3
fallbackMinContentLength = 200
fallbackMaxContentPercent = 10
)
var sb strings.Builder
sb.WriteString(
"Provide a concise summary of this conversation segment, preserving core context and key points.\n",
)
if existingSummary != "" {
sb.WriteString("Existing context: ")
sb.WriteString(existingSummary)
sb.WriteString("\n")
}
sb.WriteString("\nCONVERSATION:\n")
for _, m := range batch {
fmt.Fprintf(&sb, "%s: %s\n", m.Role, m.Content)
}
prompt := sb.String()
response, err := al.retryLLMCall(ctx, agent, prompt, llmMaxRetries)
if err == nil && response.Content != "" {
return strings.TrimSpace(response.Content), nil
}
var fallback strings.Builder
fallback.WriteString("Conversation summary: ")
for i, m := range batch {
if i > 0 {
fallback.WriteString(" | ")
}
content := strings.TrimSpace(m.Content)
runes := []rune(content)
if len(runes) == 0 {
fallback.WriteString(fmt.Sprintf("%s: ", m.Role))
continue
}
keepLength := len(runes) * fallbackMaxContentPercent / 100
if keepLength < fallbackMinContentLength {
keepLength = fallbackMinContentLength
}
if keepLength > len(runes) {
keepLength = len(runes)
}
content = string(runes[:keepLength])
if keepLength < len(runes) {
content += "..."
}
fallback.WriteString(fmt.Sprintf("%s: %s", m.Role, content))
}
return fallback.String(), nil
}
// estimateTokens estimates the number of tokens in a message list.
// Counts Content, ToolCalls arguments, and ToolCallID metadata so that
// tool-heavy conversations are not systematically undercounted.
func (al *AgentLoop) estimateTokens(messages []providers.Message) int {
total := 0
for _, m := range messages {
total += estimateMessageTokens(m)
}
return total
}
func (al *AgentLoop) handleCommand(
ctx context.Context,
msg bus.InboundMessage,
@@ -3609,7 +3416,7 @@ func (al *AgentLoop) buildCommandsRuntime(agent *AgentInstance, opts *processOpt
return "", fmt.Errorf("failed to initialize model %q: %w", value, err)
}
nextCandidates := resolveModelCandidates(cfg, cfg.Agents.Defaults.Provider, modelCfg.Model, agent.Fallbacks)
nextCandidates := resolveModelCandidates(cfg, cfg.Agents.Defaults.Provider, value, agent.Fallbacks)
if len(nextCandidates) == 0 {
return "", fmt.Errorf("model %q did not resolve to any provider candidates", value)
}
+2
View File
@@ -126,6 +126,8 @@ func (al *AgentLoop) ensureMCPInitialized(ctx context.Context) error {
}
mcpTool := tools.NewMCPTool(mcpManager, serverName, tool)
mcpTool.SetWorkspace(agent.Workspace)
mcpTool.SetMaxInlineTextRunes(al.cfg.Tools.MCP.GetMaxInlineTextChars())
if registerAsHidden {
agent.Tools.RegisterHidden(mcpTool)
+156
View File
@@ -2132,6 +2132,162 @@ func TestProcessMessage_ModelRoutingUsesLightProvider(t *testing.T) {
}
}
// TestProcessMessage_FallbackUsesPerCandidateProvider is the loop-level test for
// bug #2140. It verifies that when the primary model returns a rate-limit error
// the fallback closure routes the retry to the fallback model's own provider
// (its own api_base), not back to the primary provider's endpoint.
func TestProcessMessage_FallbackUsesPerCandidateProvider(t *testing.T) {
workspace := t.TempDir()
primaryCalls := 0
primaryServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
primaryCalls++
// Return 429 so FallbackChain classifies this as retriable and moves on.
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusTooManyRequests)
_ = json.NewEncoder(w).Encode(map[string]any{
"error": map[string]any{
"message": "rate limit exceeded",
"type": "rate_limit_error",
},
})
}))
defer primaryServer.Close()
fallbackCalls := 0
fallbackServer := newStrictChatCompletionTestServer(
t, "fallback", "gemma-3-27b-it", "fallback reply", &fallbackCalls,
)
defer fallbackServer.Close()
cfg := &config.Config{
Agents: config.AgentsConfig{
Defaults: config.AgentDefaults{
Workspace: workspace,
ModelName: "mistral-primary",
ModelFallbacks: []string{"gemma-fallback"},
MaxTokens: 4096,
MaxToolIterations: 3,
},
},
ModelList: []*config.ModelConfig{
{
ModelName: "mistral-primary",
Model: "openrouter/mistralai/mistral-small-3.1",
APIBase: primaryServer.URL,
APIKeys: config.SimpleSecureStrings("primary-key"),
Workspace: workspace,
},
{
ModelName: "gemma-fallback",
Model: "gemini/gemma-3-27b-it",
APIBase: fallbackServer.URL,
APIKeys: config.SimpleSecureStrings("fallback-key"),
Workspace: workspace,
},
},
}
provider, _, err := providers.CreateProvider(cfg)
if err != nil {
t.Fatalf("CreateProvider() error = %v", err)
}
msgBus := bus.NewMessageBus()
al := NewAgentLoop(cfg, msgBus, provider)
helper := testHelper{al: al}
resp := helper.executeAndGetResponse(t, context.Background(), bus.InboundMessage{
Channel: "telegram",
SenderID: "user1",
ChatID: "chat1",
Content: "hi",
})
if resp != "fallback reply" {
t.Fatalf("response = %q, want %q (fallback provider)", resp, "fallback reply")
}
if primaryCalls == 0 {
t.Fatal("primary server was never called; expected at least one attempt")
}
if fallbackCalls != 1 {
t.Fatalf("fallback server calls = %d, want 1", fallbackCalls)
}
}
// TestProcessMessage_FallbackUsesActiveProviderWhenCandidateNotRegistered verifies
// that when a candidate has no model_list entry it is absent from CandidateProviders
// and the fallback closure falls back to activeProvider instead of panicking.
func TestProcessMessage_FallbackUsesActiveProviderWhenCandidateNotRegistered(t *testing.T) {
workspace := t.TempDir()
// Primary server: returns 429 on first call, succeeds on second.
// Both the primary and the unregistered fallback share this server
// (same api_base) so activeProvider routes both calls here.
callCount := 0
primaryServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
callCount++
w.Header().Set("Content-Type", "application/json")
if callCount == 1 {
w.WriteHeader(http.StatusTooManyRequests)
_ = json.NewEncoder(w).Encode(map[string]any{
"error": map[string]any{"message": "rate limit", "type": "rate_limit_error"},
})
return
}
// Second call (fallback via activeProvider) succeeds.
_ = json.NewEncoder(w).Encode(map[string]any{
"choices": []map[string]any{
{"message": map[string]any{"content": "active provider reply"}, "finish_reason": "stop"},
},
})
}))
defer primaryServer.Close()
cfg := &config.Config{
Agents: config.AgentsConfig{
Defaults: config.AgentDefaults{
Workspace: workspace,
ModelName: "primary-model",
MaxTokens: 4096,
MaxToolIterations: 3,
// No model_list entry for this alias — absent from CandidateProviders.
ModelFallbacks: []string{"openrouter/fallback-model"},
},
},
ModelList: []*config.ModelConfig{
{
ModelName: "primary-model",
Model: "openrouter/primary-model",
APIBase: primaryServer.URL,
APIKeys: config.SimpleSecureStrings("primary-key"),
Workspace: workspace,
},
},
}
provider, _, err := providers.CreateProvider(cfg)
if err != nil {
t.Fatalf("CreateProvider() error = %v", err)
}
msgBus := bus.NewMessageBus()
al := NewAgentLoop(cfg, msgBus, provider)
helper := testHelper{al: al}
resp := helper.executeAndGetResponse(t, context.Background(), bus.InboundMessage{
Channel: "telegram",
SenderID: "user1",
ChatID: "chat1",
Content: "hi",
})
if resp != "active provider reply" {
t.Fatalf("response = %q, want %q", resp, "active provider reply")
}
if callCount < 2 {
t.Fatalf("primary server calls = %d, want >= 2 (one 429 + one success via activeProvider)", callCount)
}
}
// TestToolResult_SilentToolDoesNotSendUserMessage verifies silent tools don't trigger outbound
func TestToolResult_SilentToolDoesNotSendUserMessage(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "agent-test-*")
+116 -43
View File
@@ -8,44 +8,102 @@ import (
"github.com/sipeed/picoclaw/pkg/providers"
)
func buildModelListResolver(cfg *config.Config) func(raw string) (string, bool) {
ensureProtocol := func(model string) string {
model = strings.TrimSpace(model)
if model == "" {
return ""
}
if strings.Contains(model, "/") {
return model
}
return "openai/" + model
func ensureProtocolModel(model string) string {
model = strings.TrimSpace(model)
if model == "" {
return ""
}
if strings.Contains(model, "/") {
return model
}
return "openai/" + model
}
func modelConfigIdentityKey(mc *config.ModelConfig) string {
if mc == nil {
return ""
}
if name := strings.TrimSpace(mc.ModelName); name != "" {
return "model_name:" + name
}
return ""
}
func candidateFromModelConfig(
defaultProvider string,
mc *config.ModelConfig,
) (providers.FallbackCandidate, bool) {
if mc == nil {
return providers.FallbackCandidate{}, false
}
return func(raw string) (string, bool) {
raw = strings.TrimSpace(raw)
if raw == "" || cfg == nil {
return "", false
}
if mc, err := cfg.GetModelConfig(raw); err == nil && mc != nil && strings.TrimSpace(mc.Model) != "" {
return ensureProtocol(mc.Model), true
}
for i := range cfg.ModelList {
fullModel := strings.TrimSpace(cfg.ModelList[i].Model)
if fullModel == "" {
continue
}
if fullModel == raw {
return ensureProtocol(fullModel), true
}
_, modelID := providers.ExtractProtocol(fullModel)
if modelID == raw {
return ensureProtocol(fullModel), true
}
}
return "", false
ref := providers.ParseModelRef(ensureProtocolModel(mc.Model), defaultProvider)
if ref == nil {
return providers.FallbackCandidate{}, false
}
return providers.FallbackCandidate{
Provider: ref.Provider,
Model: ref.Model,
RPM: mc.RPM,
IdentityKey: modelConfigIdentityKey(mc),
}, true
}
func lookupModelConfigByRef(cfg *config.Config, raw string) *config.ModelConfig {
raw = strings.TrimSpace(raw)
if raw == "" || cfg == nil {
return nil
}
if mc, err := cfg.GetModelConfig(raw); err == nil && mc != nil && strings.TrimSpace(mc.Model) != "" {
return mc
}
for i := range cfg.ModelList {
mc := cfg.ModelList[i]
if mc == nil {
continue
}
fullModel := strings.TrimSpace(mc.Model)
if fullModel == "" {
continue
}
if fullModel == raw {
return mc
}
_, modelID := providers.ExtractProtocol(fullModel)
if modelID == raw {
return mc
}
}
return nil
}
func resolveModelCandidate(
cfg *config.Config,
defaultProvider string,
raw string,
) (providers.FallbackCandidate, bool) {
raw = strings.TrimSpace(raw)
if raw == "" {
return providers.FallbackCandidate{}, false
}
if mc := lookupModelConfigByRef(cfg, raw); mc != nil {
return candidateFromModelConfig(defaultProvider, mc)
}
ref := providers.ParseModelRef(raw, defaultProvider)
if ref == nil {
return providers.FallbackCandidate{}, false
}
return providers.FallbackCandidate{
Provider: ref.Provider,
Model: ref.Model,
}, true
}
func resolveModelCandidates(
@@ -54,14 +112,29 @@ func resolveModelCandidates(
primary string,
fallbacks []string,
) []providers.FallbackCandidate {
return providers.ResolveCandidatesWithLookup(
providers.ModelConfig{
Primary: primary,
Fallbacks: fallbacks,
},
defaultProvider,
buildModelListResolver(cfg),
)
seen := make(map[string]bool)
candidates := make([]providers.FallbackCandidate, 0, 1+len(fallbacks))
addCandidate := func(raw string) {
candidate, ok := resolveModelCandidate(cfg, defaultProvider, raw)
if !ok {
return
}
key := candidate.StableKey()
if seen[key] {
return
}
seen[key] = true
candidates = append(candidates, candidate)
}
addCandidate(primary)
for _, fallback := range fallbacks {
addCandidate(fallback)
}
return candidates
}
func resolvedCandidateModel(candidates []providers.FallbackCandidate, fallback string) string {
+6 -2
View File
@@ -432,6 +432,7 @@ func spawnSubTurn(
// 7. Defer cleanup: deliver result (for async), emit End event, and recover from panics
defer func() {
if r := recover(); r != nil {
logger.RecoverPanicNoExit(r)
err = fmt.Errorf("subturn panicked: %v", r)
result = nil
logger.ErrorCF("subturn", "SubTurn panicked", map[string]any{
@@ -515,6 +516,7 @@ func deliverSubTurnResult(al *AgentLoop, parentTS *turnState, childID string, re
// We use defer/recover to catch any unlikely channel panics if it were ever closed.
defer func() {
if r := recover(); r != nil {
logger.RecoverPanicNoExit(r)
logger.WarnCF("subturn", "recovered panic sending to pendingResults", map[string]any{
"parent_id": parentTS.turnID,
"child_id": childID,
@@ -607,6 +609,7 @@ type ephemeralSessionStoreIface interface {
SetHistory(key string, history []providers.Message)
TruncateHistory(key string, keepLast int)
Save(key string) error
ListSessions() []string
Close() error
}
@@ -666,8 +669,9 @@ func (e *ephemeralSessionStore) TruncateHistory(_ string, keepLast int) {
e.history = e.history[len(e.history)-keepLast:]
}
func (e *ephemeralSessionStore) Save(_ string) error { return nil }
func (e *ephemeralSessionStore) Close() error { return nil }
func (e *ephemeralSessionStore) Save(_ string) error { return nil }
func (e *ephemeralSessionStore) Close() error { return nil }
func (e *ephemeralSessionStore) ListSessions() []string { return nil }
func (e *ephemeralSessionStore) truncateLocked() {
if len(e.history) > maxEphemeralHistorySize {
+18
View File
@@ -8,6 +8,7 @@ import (
"time"
"github.com/sipeed/picoclaw/pkg/bus"
"github.com/sipeed/picoclaw/pkg/logger"
"github.com/sipeed/picoclaw/pkg/providers"
"github.com/sipeed/picoclaw/pkg/session"
"github.com/sipeed/picoclaw/pkg/tools"
@@ -341,6 +342,23 @@ func (ts *turnState) refreshRestorePointFromSession(agent *AgentInstance) {
ts.captureRestorePoint(history, summary)
}
// ingestMessage calls the ContextManager's Ingest method for a persisted message.
// Errors are logged but never block the turn.
func (ts *turnState) ingestMessage(ctx context.Context, al *AgentLoop, msg providers.Message) {
if al.contextManager == nil {
return
}
if err := al.contextManager.Ingest(ctx, &IngestRequest{
SessionKey: ts.sessionKey,
Message: msg,
}); err != nil {
logger.WarnCF("agent", "Context manager ingest failed", map[string]any{
"session_key": ts.sessionKey,
"error": err.Error(),
})
}
}
func (ts *turnState) restoreSession(agent *AgentInstance) error {
ts.mu.RLock()
history := append([]providers.Message(nil), ts.restorePointHistory...)