mirror of
https://github.com/sipeed/picoclaw.git
synced 2026-06-12 18:08:54 +00:00
Merge branch 'main' into refactor-inbound-context-routing-session
# Conflicts: # pkg/agent/eventbus_test.go # pkg/agent/loop.go # pkg/bus/bus.go # pkg/bus/types.go # pkg/channels/pico/pico.go # pkg/channels/telegram/telegram.go # pkg/config/config.go # web/backend/api/session.go # web/backend/api/session_test.go
This commit is contained in:
@@ -602,14 +602,16 @@ func (cb *ContextBuilder) BuildMessages(
|
||||
// Add conversation history
|
||||
messages = append(messages, history...)
|
||||
|
||||
// Add current user message
|
||||
if strings.TrimSpace(currentMessage) != "" {
|
||||
// Add current user message. Media-only turns must still be preserved so
|
||||
// multimodal providers receive the uploaded image even when the user sends
|
||||
// no accompanying text.
|
||||
if strings.TrimSpace(currentMessage) != "" || len(media) > 0 {
|
||||
msg := providers.Message{
|
||||
Role: "user",
|
||||
Content: currentMessage,
|
||||
}
|
||||
if len(media) > 0 {
|
||||
msg.Media = media
|
||||
msg.Media = append([]string(nil), media...)
|
||||
}
|
||||
messages = append(messages, msg)
|
||||
}
|
||||
|
||||
+11
-85
@@ -6,10 +6,8 @@
|
||||
package agent
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"unicode/utf8"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/providers"
|
||||
"github.com/sipeed/picoclaw/pkg/tokenizer"
|
||||
)
|
||||
|
||||
// parseTurnBoundaries returns the starting index of each Turn in the history.
|
||||
@@ -86,88 +84,16 @@ func findSafeBoundary(history []providers.Message, targetIndex int) int {
|
||||
return 0
|
||||
}
|
||||
|
||||
// estimateMessageTokens estimates the token count for a single message,
|
||||
// including Content, ReasoningContent, ToolCalls arguments, ToolCallID
|
||||
// metadata, and Media items. Uses a heuristic of 2.5 characters per token.
|
||||
func estimateMessageTokens(msg providers.Message) int {
|
||||
contentChars := utf8.RuneCountInString(msg.Content)
|
||||
|
||||
// SystemParts are structured system blocks used for cache-aware adapters.
|
||||
// They carry the same content as Content, but in multiple blocks.
|
||||
// We estimate them as an alternative representation, not additive.
|
||||
systemPartsChars := 0
|
||||
if len(msg.SystemParts) > 0 {
|
||||
for _, part := range msg.SystemParts {
|
||||
systemPartsChars += utf8.RuneCountInString(part.Text)
|
||||
}
|
||||
// Per-part overhead for JSON structure (type, text, cache_control).
|
||||
const perPartOverhead = 20
|
||||
systemPartsChars += len(msg.SystemParts) * perPartOverhead
|
||||
}
|
||||
|
||||
// Use the larger of the two representations to stay conservative.
|
||||
chars := contentChars
|
||||
if systemPartsChars > chars {
|
||||
chars = systemPartsChars
|
||||
}
|
||||
|
||||
chars += utf8.RuneCountInString(msg.ReasoningContent)
|
||||
|
||||
for _, tc := range msg.ToolCalls {
|
||||
chars += len(tc.ID) + len(tc.Type)
|
||||
if tc.Function != nil {
|
||||
// Count function name + arguments (the wire format for most providers).
|
||||
// tc.Name mirrors tc.Function.Name — count only once to avoid double-counting.
|
||||
chars += len(tc.Function.Name) + len(tc.Function.Arguments)
|
||||
} else {
|
||||
// Fallback: some provider formats use top-level Name without Function.
|
||||
chars += len(tc.Name)
|
||||
}
|
||||
}
|
||||
|
||||
if msg.ToolCallID != "" {
|
||||
chars += len(msg.ToolCallID)
|
||||
}
|
||||
|
||||
// Per-message overhead for role label, JSON structure, separators.
|
||||
const messageOverhead = 12
|
||||
chars += messageOverhead
|
||||
|
||||
tokens := chars * 2 / 5
|
||||
|
||||
// Media items (images, files) are serialized by provider adapters into
|
||||
// multipart or image_url payloads. Add a fixed per-item token estimate
|
||||
// directly (not through the chars heuristic) since actual cost depends
|
||||
// on resolution and provider-specific image tokenization.
|
||||
const mediaTokensPerItem = 256
|
||||
tokens += len(msg.Media) * mediaTokensPerItem
|
||||
|
||||
return tokens
|
||||
// EstimateMessageTokens estimates the token count for a single message.
|
||||
// Delegates to the shared tokenizer package for consistency across agent and seahorse.
|
||||
func EstimateMessageTokens(msg providers.Message) int {
|
||||
return tokenizer.EstimateMessageTokens(msg)
|
||||
}
|
||||
|
||||
// estimateToolDefsTokens estimates the total token cost of tool definitions
|
||||
// as they appear in the LLM request. Each tool's name, description, and
|
||||
// JSON schema parameters contribute to the context window budget.
|
||||
func estimateToolDefsTokens(defs []providers.ToolDefinition) int {
|
||||
if len(defs) == 0 {
|
||||
return 0
|
||||
}
|
||||
|
||||
totalChars := 0
|
||||
for _, d := range defs {
|
||||
totalChars += len(d.Function.Name) + len(d.Function.Description)
|
||||
|
||||
if d.Function.Parameters != nil {
|
||||
if paramJSON, err := json.Marshal(d.Function.Parameters); err == nil {
|
||||
totalChars += len(paramJSON)
|
||||
}
|
||||
}
|
||||
|
||||
// Per-tool overhead: type field, JSON structure, separators.
|
||||
totalChars += 20
|
||||
}
|
||||
|
||||
return totalChars * 2 / 5
|
||||
// EstimateToolDefsTokens estimates the total token cost of tool definitions
|
||||
// as they appear in the LLM request. Delegates to the shared tokenizer package.
|
||||
func EstimateToolDefsTokens(defs []providers.ToolDefinition) int {
|
||||
return tokenizer.EstimateToolDefsTokens(defs)
|
||||
}
|
||||
|
||||
// isOverContextBudget checks whether the assembled messages plus tool definitions
|
||||
@@ -181,10 +107,10 @@ func isOverContextBudget(
|
||||
) bool {
|
||||
msgTokens := 0
|
||||
for _, m := range messages {
|
||||
msgTokens += estimateMessageTokens(m)
|
||||
msgTokens += EstimateMessageTokens(m)
|
||||
}
|
||||
|
||||
toolTokens := estimateToolDefsTokens(toolDefs)
|
||||
toolTokens := EstimateToolDefsTokens(toolDefs)
|
||||
total := msgTokens + toolTokens + maxTokens
|
||||
|
||||
return total > contextWindow
|
||||
|
||||
@@ -417,9 +417,9 @@ func TestEstimateMessageTokens(t *testing.T) {
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := estimateMessageTokens(tt.msg)
|
||||
got := EstimateMessageTokens(tt.msg)
|
||||
if got < tt.want {
|
||||
t.Errorf("estimateMessageTokens() = %d, want >= %d", got, tt.want)
|
||||
t.Errorf("EstimateMessageTokens() = %d, want >= %d", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -443,8 +443,8 @@ func TestEstimateMessageTokens_ToolCallsContribute(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
plainTokens := estimateMessageTokens(plain)
|
||||
withTCTokens := estimateMessageTokens(withTC)
|
||||
plainTokens := EstimateMessageTokens(plain)
|
||||
withTCTokens := EstimateMessageTokens(withTC)
|
||||
|
||||
if withTCTokens <= plainTokens {
|
||||
t.Errorf("message with ToolCalls (%d tokens) should exceed plain message (%d tokens)",
|
||||
@@ -457,7 +457,7 @@ func TestEstimateMessageTokens_MultibyteContent(t *testing.T) {
|
||||
// but may map to different token counts. The heuristic should still produce
|
||||
// reasonable estimates via RuneCountInString.
|
||||
msg := msgUser("caf\u00e9 na\u00efve r\u00e9sum\u00e9 \u00fcber stra\u00dfe")
|
||||
tokens := estimateMessageTokens(msg)
|
||||
tokens := EstimateMessageTokens(msg)
|
||||
if tokens <= 0 {
|
||||
t.Errorf("multibyte message should produce positive token count, got %d", tokens)
|
||||
}
|
||||
@@ -481,7 +481,7 @@ func TestEstimateMessageTokens_LargeArguments(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
tokens := estimateMessageTokens(msg)
|
||||
tokens := EstimateMessageTokens(msg)
|
||||
// 5000+ chars → at least 2000 tokens with the 2.5 char/token heuristic
|
||||
if tokens < 2000 {
|
||||
t.Errorf("large tool call arguments should produce significant token count, got %d", tokens)
|
||||
@@ -496,8 +496,8 @@ func TestEstimateMessageTokens_ReasoningContent(t *testing.T) {
|
||||
ReasoningContent: strings.Repeat("thinking step ", 200),
|
||||
}
|
||||
|
||||
plainTokens := estimateMessageTokens(plain)
|
||||
reasoningTokens := estimateMessageTokens(withReasoning)
|
||||
plainTokens := EstimateMessageTokens(plain)
|
||||
reasoningTokens := EstimateMessageTokens(withReasoning)
|
||||
|
||||
if reasoningTokens <= plainTokens {
|
||||
t.Errorf("message with ReasoningContent (%d tokens) should exceed plain message (%d tokens)",
|
||||
@@ -513,8 +513,8 @@ func TestEstimateMessageTokens_MediaItems(t *testing.T) {
|
||||
Media: []string{"media://img1.png", "media://img2.png"},
|
||||
}
|
||||
|
||||
plainTokens := estimateMessageTokens(plain)
|
||||
mediaTokens := estimateMessageTokens(withMedia)
|
||||
plainTokens := EstimateMessageTokens(plain)
|
||||
mediaTokens := EstimateMessageTokens(withMedia)
|
||||
|
||||
if mediaTokens <= plainTokens {
|
||||
t.Errorf("message with Media (%d tokens) should exceed plain message (%d tokens)",
|
||||
@@ -540,8 +540,8 @@ func TestEstimateMessageTokens_SystemParts(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
plainTokens := estimateMessageTokens(plain)
|
||||
partsTokens := estimateMessageTokens(withParts)
|
||||
plainTokens := EstimateMessageTokens(plain)
|
||||
partsTokens := EstimateMessageTokens(withParts)
|
||||
|
||||
if partsTokens <= plainTokens {
|
||||
t.Errorf("system message with SystemParts (%d) should exceed plain message (%d)",
|
||||
@@ -549,7 +549,7 @@ func TestEstimateMessageTokens_SystemParts(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// --- estimateToolDefsTokens tests ---
|
||||
// --- EstimateToolDefsTokens tests ---
|
||||
|
||||
func TestEstimateToolDefsTokens(t *testing.T) {
|
||||
tests := []struct {
|
||||
@@ -599,9 +599,9 @@ func TestEstimateToolDefsTokens(t *testing.T) {
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := estimateToolDefsTokens(tt.defs)
|
||||
got := EstimateToolDefsTokens(tt.defs)
|
||||
if got < tt.want {
|
||||
t.Errorf("estimateToolDefsTokens() = %d, want >= %d", got, tt.want)
|
||||
t.Errorf("EstimateToolDefsTokens() = %d, want >= %d", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -624,8 +624,8 @@ func TestEstimateToolDefsTokens_ScalesWithCount(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
one := estimateToolDefsTokens([]providers.ToolDefinition{makeTool("tool_a")})
|
||||
three := estimateToolDefsTokens([]providers.ToolDefinition{
|
||||
one := EstimateToolDefsTokens([]providers.ToolDefinition{makeTool("tool_a")})
|
||||
three := EstimateToolDefsTokens([]providers.ToolDefinition{
|
||||
makeTool("tool_a"), makeTool("tool_b"), makeTool("tool_c"),
|
||||
})
|
||||
|
||||
@@ -770,7 +770,7 @@ func TestEstimateMessageTokens_WithReasoningAndMedia(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
tokens := estimateMessageTokens(msg)
|
||||
tokens := EstimateMessageTokens(msg)
|
||||
|
||||
// ReasoningContent alone is ~1700 chars → ~680 tokens.
|
||||
// Content + TC + overhead adds more. Should be well above 500.
|
||||
@@ -781,7 +781,7 @@ func TestEstimateMessageTokens_WithReasoningAndMedia(t *testing.T) {
|
||||
// Compare without reasoning to ensure it's counted.
|
||||
msgNoReasoning := msg
|
||||
msgNoReasoning.ReasoningContent = ""
|
||||
tokensNoReasoning := estimateMessageTokens(msgNoReasoning)
|
||||
tokensNoReasoning := EstimateMessageTokens(msgNoReasoning)
|
||||
|
||||
if tokens <= tokensNoReasoning {
|
||||
t.Errorf("reasoning content should add tokens: with=%d, without=%d", tokens, tokensNoReasoning)
|
||||
|
||||
@@ -707,6 +707,38 @@ func TestEmptyWorkspaceBaselineDetectsNewFiles(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildMessages_IncludesMediaOnlyCurrentMessage(t *testing.T) {
|
||||
tmpDir := setupWorkspace(t, nil)
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
cb := NewContextBuilder(tmpDir)
|
||||
msgs := cb.BuildMessages(
|
||||
nil,
|
||||
"",
|
||||
"",
|
||||
[]string{"data:image/png;base64,abc123"},
|
||||
"pico",
|
||||
"chat-1",
|
||||
"",
|
||||
"",
|
||||
)
|
||||
|
||||
if len(msgs) != 2 {
|
||||
t.Fatalf("len(msgs) = %d, want 2", len(msgs))
|
||||
}
|
||||
|
||||
userMsg := msgs[1]
|
||||
if userMsg.Role != "user" {
|
||||
t.Fatalf("userMsg.Role = %q, want %q", userMsg.Role, "user")
|
||||
}
|
||||
if userMsg.Content != "" {
|
||||
t.Fatalf("userMsg.Content = %q, want empty string", userMsg.Content)
|
||||
}
|
||||
if len(userMsg.Media) != 1 || userMsg.Media[0] != "data:image/png;base64,abc123" {
|
||||
t.Fatalf("userMsg.Media = %#v, want image payload", userMsg.Media)
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkBuildMessagesWithCache measures caching performance.
|
||||
func BenchmarkBuildMessagesWithCache(b *testing.B) {
|
||||
tmpDir, _ := os.MkdirTemp("", "picoclaw-bench-*")
|
||||
|
||||
@@ -0,0 +1,379 @@
|
||||
package agent
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/logger"
|
||||
"github.com/sipeed/picoclaw/pkg/providers"
|
||||
)
|
||||
|
||||
// legacyContextManager wraps the existing summarization/compression logic
|
||||
// as a ContextManager implementation. It is the default when no other
|
||||
// ContextManager is configured.
|
||||
type legacyContextManager struct {
|
||||
al *AgentLoop
|
||||
summarizing sync.Map // dedup for async Compact (post-turn)
|
||||
}
|
||||
|
||||
func (m *legacyContextManager) Assemble(_ context.Context, req *AssembleRequest) (*AssembleResponse, error) {
|
||||
// Legacy: read history from session, return as-is.
|
||||
// Budget enforcement happens in BuildMessages caller via
|
||||
// isOverContextBudget + forceCompression.
|
||||
agent := m.al.registry.GetDefaultAgent()
|
||||
if agent == nil {
|
||||
return &AssembleResponse{}, nil
|
||||
}
|
||||
history := agent.Sessions.GetHistory(req.SessionKey)
|
||||
summary := agent.Sessions.GetSummary(req.SessionKey)
|
||||
return &AssembleResponse{
|
||||
History: history,
|
||||
Summary: summary,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (m *legacyContextManager) Compact(_ context.Context, req *CompactRequest) error {
|
||||
switch req.Reason {
|
||||
case ContextCompressReasonProactive, ContextCompressReasonRetry:
|
||||
// Sync emergency compression — budget exceeded.
|
||||
if result, ok := m.forceCompression(req.SessionKey); ok {
|
||||
m.al.emitEvent(
|
||||
EventKindContextCompress,
|
||||
m.al.newTurnEventScope("", req.SessionKey, nil).meta(0, "forceCompression", "turn.context.compress"),
|
||||
ContextCompressPayload{
|
||||
Reason: req.Reason,
|
||||
DroppedMessages: result.DroppedMessages,
|
||||
RemainingMessages: result.RemainingMessages,
|
||||
},
|
||||
)
|
||||
}
|
||||
case ContextCompressReasonSummarize:
|
||||
m.maybeSummarize(req.SessionKey)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *legacyContextManager) Ingest(_ context.Context, _ *IngestRequest) error {
|
||||
// Legacy: no-op. Messages are persisted by Sessions JSONL.
|
||||
return nil
|
||||
}
|
||||
|
||||
// maybeSummarize triggers summarization if the session history exceeds thresholds.
|
||||
// It runs asynchronously in a goroutine.
|
||||
func (m *legacyContextManager) maybeSummarize(sessionKey string) {
|
||||
agent := m.al.registry.GetDefaultAgent()
|
||||
if agent == nil {
|
||||
return
|
||||
}
|
||||
|
||||
newHistory := agent.Sessions.GetHistory(sessionKey)
|
||||
tokenEstimate := m.estimateTokens(newHistory)
|
||||
threshold := agent.ContextWindow * agent.SummarizeTokenPercent / 100
|
||||
|
||||
if len(newHistory) > agent.SummarizeMessageThreshold || tokenEstimate > threshold {
|
||||
summarizeKey := agent.ID + ":" + sessionKey
|
||||
if _, loading := m.summarizing.LoadOrStore(summarizeKey, true); !loading {
|
||||
go func() {
|
||||
defer m.summarizing.Delete(summarizeKey)
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
logger.WarnCF("agent", "Summarization panic recovered", map[string]any{
|
||||
"session_key": sessionKey,
|
||||
"panic": r,
|
||||
})
|
||||
}
|
||||
}()
|
||||
logger.Debug("Memory threshold reached. Optimizing conversation history...")
|
||||
m.summarizeSession(agent, sessionKey)
|
||||
}()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// compressionResult reports the outcome of one forceCompression run.
type compressionResult struct {
	// DroppedMessages is how many messages were removed from the history.
	DroppedMessages int
	// RemainingMessages is how many messages the history holds afterwards.
	RemainingMessages int
}
|
||||
|
||||
// forceCompression aggressively reduces context when the limit is hit.
|
||||
// It drops the oldest ~50% of Turns (a Turn is a complete user→LLM→response
|
||||
// cycle, as defined in #1316), so tool-call sequences are never split.
|
||||
func (m *legacyContextManager) forceCompression(sessionKey string) (compressionResult, bool) {
|
||||
agent := m.al.registry.GetDefaultAgent()
|
||||
if agent == nil {
|
||||
return compressionResult{}, false
|
||||
}
|
||||
|
||||
history := agent.Sessions.GetHistory(sessionKey)
|
||||
if len(history) <= 2 {
|
||||
return compressionResult{}, false
|
||||
}
|
||||
|
||||
turns := parseTurnBoundaries(history)
|
||||
var mid int
|
||||
if len(turns) >= 2 {
|
||||
mid = turns[len(turns)/2]
|
||||
} else {
|
||||
mid = findSafeBoundary(history, len(history)/2)
|
||||
}
|
||||
var keptHistory []providers.Message
|
||||
if mid <= 0 {
|
||||
for i := len(history) - 1; i >= 0; i-- {
|
||||
if history[i].Role == "user" {
|
||||
keptHistory = []providers.Message{history[i]}
|
||||
break
|
||||
}
|
||||
}
|
||||
} else {
|
||||
keptHistory = history[mid:]
|
||||
}
|
||||
|
||||
droppedCount := len(history) - len(keptHistory)
|
||||
|
||||
existingSummary := agent.Sessions.GetSummary(sessionKey)
|
||||
compressionNote := fmt.Sprintf(
|
||||
"[Emergency compression dropped %d oldest messages due to context limit]",
|
||||
droppedCount,
|
||||
)
|
||||
if existingSummary != "" {
|
||||
compressionNote = existingSummary + "\n\n" + compressionNote
|
||||
}
|
||||
agent.Sessions.SetSummary(sessionKey, compressionNote)
|
||||
|
||||
agent.Sessions.SetHistory(sessionKey, keptHistory)
|
||||
agent.Sessions.Save(sessionKey)
|
||||
|
||||
logger.WarnCF("agent", "Forced compression executed", map[string]any{
|
||||
"session_key": sessionKey,
|
||||
"dropped_msgs": droppedCount,
|
||||
"new_count": len(keptHistory),
|
||||
})
|
||||
|
||||
return compressionResult{
|
||||
DroppedMessages: droppedCount,
|
||||
RemainingMessages: len(keptHistory),
|
||||
}, true
|
||||
}
|
||||
|
||||
func (m *legacyContextManager) summarizeSession(agent *AgentInstance, sessionKey string) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
||||
defer cancel()
|
||||
|
||||
history := agent.Sessions.GetHistory(sessionKey)
|
||||
summary := agent.Sessions.GetSummary(sessionKey)
|
||||
|
||||
if len(history) <= 4 {
|
||||
return
|
||||
}
|
||||
|
||||
safeCut := findSafeBoundary(history, len(history)-4)
|
||||
if safeCut <= 0 {
|
||||
return
|
||||
}
|
||||
keepCount := len(history) - safeCut
|
||||
toSummarize := history[:safeCut]
|
||||
|
||||
maxMessageTokens := agent.ContextWindow / 2
|
||||
validMessages := make([]providers.Message, 0)
|
||||
omitted := false
|
||||
|
||||
for _, msg := range toSummarize {
|
||||
if msg.Role != "user" && msg.Role != "assistant" {
|
||||
continue
|
||||
}
|
||||
msgTokens := len(msg.Content) / 2
|
||||
if msgTokens > maxMessageTokens {
|
||||
omitted = true
|
||||
continue
|
||||
}
|
||||
validMessages = append(validMessages, msg)
|
||||
}
|
||||
|
||||
if len(validMessages) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
const (
|
||||
maxSummarizationMessages = 10
|
||||
llmMaxRetries = 3
|
||||
)
|
||||
|
||||
var finalSummary string
|
||||
if len(validMessages) > maxSummarizationMessages {
|
||||
mid := len(validMessages) / 2
|
||||
mid = m.findNearestUserMessage(validMessages, mid)
|
||||
|
||||
part1 := validMessages[:mid]
|
||||
part2 := validMessages[mid:]
|
||||
|
||||
s1, _ := m.summarizeBatch(ctx, agent, part1, "")
|
||||
s2, _ := m.summarizeBatch(ctx, agent, part2, "")
|
||||
|
||||
mergePrompt := fmt.Sprintf(
|
||||
"Merge these two conversation summaries into one cohesive summary:\n\n1: %s\n\n2: %s",
|
||||
s1, s2,
|
||||
)
|
||||
|
||||
resp, err := m.retryLLMCall(ctx, agent, mergePrompt, llmMaxRetries)
|
||||
if err == nil && resp.Content != "" {
|
||||
finalSummary = resp.Content
|
||||
} else {
|
||||
finalSummary = s1 + " " + s2
|
||||
}
|
||||
} else {
|
||||
finalSummary, _ = m.summarizeBatch(ctx, agent, validMessages, summary)
|
||||
}
|
||||
|
||||
if omitted && finalSummary != "" {
|
||||
finalSummary += "\n[Note: Some oversized messages were omitted from this summary for efficiency.]"
|
||||
}
|
||||
|
||||
if finalSummary != "" {
|
||||
agent.Sessions.SetSummary(sessionKey, finalSummary)
|
||||
agent.Sessions.TruncateHistory(sessionKey, keepCount)
|
||||
agent.Sessions.Save(sessionKey)
|
||||
m.al.emitEvent(
|
||||
EventKindSessionSummarize,
|
||||
m.al.newTurnEventScope(agent.ID, sessionKey, nil).meta(0, "summarizeSession", "turn.session.summarize"),
|
||||
SessionSummarizePayload{
|
||||
SummarizedMessages: len(validMessages),
|
||||
KeptMessages: keepCount,
|
||||
SummaryLen: len(finalSummary),
|
||||
OmittedOversized: omitted,
|
||||
},
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
func (m *legacyContextManager) findNearestUserMessage(messages []providers.Message, mid int) int {
|
||||
originalMid := mid
|
||||
|
||||
for mid > 0 && messages[mid].Role != "user" {
|
||||
mid--
|
||||
}
|
||||
|
||||
if messages[mid].Role == "user" {
|
||||
return mid
|
||||
}
|
||||
|
||||
mid = originalMid
|
||||
for mid < len(messages) && messages[mid].Role != "user" {
|
||||
mid++
|
||||
}
|
||||
|
||||
if mid < len(messages) {
|
||||
return mid
|
||||
}
|
||||
|
||||
return originalMid
|
||||
}
|
||||
|
||||
func (m *legacyContextManager) retryLLMCall(
|
||||
ctx context.Context,
|
||||
agent *AgentInstance,
|
||||
prompt string,
|
||||
maxRetries int,
|
||||
) (*providers.LLMResponse, error) {
|
||||
const llmTemperature = 0.3
|
||||
|
||||
var resp *providers.LLMResponse
|
||||
var err error
|
||||
|
||||
for attempt := 0; attempt < maxRetries; attempt++ {
|
||||
m.al.activeRequests.Add(1)
|
||||
resp, err = func() (*providers.LLMResponse, error) {
|
||||
defer m.al.activeRequests.Done()
|
||||
return agent.Provider.Chat(
|
||||
ctx,
|
||||
[]providers.Message{{Role: "user", Content: prompt}},
|
||||
nil,
|
||||
agent.Model,
|
||||
map[string]any{
|
||||
"max_tokens": agent.MaxTokens,
|
||||
"temperature": llmTemperature,
|
||||
"prompt_cache_key": agent.ID,
|
||||
},
|
||||
)
|
||||
}()
|
||||
|
||||
if err == nil && resp != nil && resp.Content != "" {
|
||||
return resp, nil
|
||||
}
|
||||
if attempt < maxRetries-1 {
|
||||
time.Sleep(time.Duration(attempt+1) * 100 * time.Millisecond)
|
||||
}
|
||||
}
|
||||
|
||||
return resp, err
|
||||
}
|
||||
|
||||
func (m *legacyContextManager) summarizeBatch(
|
||||
ctx context.Context,
|
||||
agent *AgentInstance,
|
||||
batch []providers.Message,
|
||||
existingSummary string,
|
||||
) (string, error) {
|
||||
const (
|
||||
llmMaxRetries = 3
|
||||
fallbackMinContentLength = 200
|
||||
fallbackMaxContentPercent = 10
|
||||
)
|
||||
|
||||
var sb strings.Builder
|
||||
sb.WriteString("Provide a concise summary of this conversation segment, preserving core context and key points.\n")
|
||||
if existingSummary != "" {
|
||||
sb.WriteString("Existing context: ")
|
||||
sb.WriteString(existingSummary)
|
||||
sb.WriteString("\n")
|
||||
}
|
||||
sb.WriteString("\nCONVERSATION:\n")
|
||||
for _, msg := range batch {
|
||||
fmt.Fprintf(&sb, "%s: %s\n", msg.Role, msg.Content)
|
||||
}
|
||||
prompt := sb.String()
|
||||
|
||||
response, err := m.retryLLMCall(ctx, agent, prompt, llmMaxRetries)
|
||||
if err == nil && response.Content != "" {
|
||||
return strings.TrimSpace(response.Content), nil
|
||||
}
|
||||
|
||||
var fallback strings.Builder
|
||||
fallback.WriteString("Conversation summary: ")
|
||||
for i, msg := range batch {
|
||||
if i > 0 {
|
||||
fallback.WriteString(" | ")
|
||||
}
|
||||
content := strings.TrimSpace(msg.Content)
|
||||
runes := []rune(content)
|
||||
if len(runes) == 0 {
|
||||
fallback.WriteString(fmt.Sprintf("%s: ", msg.Role))
|
||||
continue
|
||||
}
|
||||
|
||||
keepLength := len(runes) * fallbackMaxContentPercent / 100
|
||||
if keepLength < fallbackMinContentLength {
|
||||
keepLength = fallbackMinContentLength
|
||||
}
|
||||
if keepLength > len(runes) {
|
||||
keepLength = len(runes)
|
||||
}
|
||||
|
||||
content = string(runes[:keepLength])
|
||||
if keepLength < len(runes) {
|
||||
content += "..."
|
||||
}
|
||||
fallback.WriteString(fmt.Sprintf("%s: %s", msg.Role, content))
|
||||
}
|
||||
return fallback.String(), nil
|
||||
}
|
||||
|
||||
func (m *legacyContextManager) estimateTokens(messages []providers.Message) int {
|
||||
total := 0
|
||||
for _, msg := range messages {
|
||||
total += EstimateMessageTokens(msg)
|
||||
}
|
||||
return total
|
||||
}
|
||||
@@ -0,0 +1,90 @@
|
||||
package agent
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"sync"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/providers"
|
||||
)
|
||||
|
||||
// ContextManager manages conversation context via a pluggable strategy.
|
||||
// Exactly ONE ContextManager is active per AgentLoop, selected by config.
|
||||
// The default ("legacy") preserves current summarization behavior.
|
||||
type ContextManager interface {
|
||||
// Assemble builds budget-aware context from the ContextManager's own storage.
|
||||
// Called before BuildMessages. Returns assembled messages ready for LLM.
|
||||
Assemble(ctx context.Context, req *AssembleRequest) (*AssembleResponse, error)
|
||||
|
||||
// Compact compresses conversation history.
|
||||
// Called after turn completes (may be async internally) and on context overflow (sync).
|
||||
Compact(ctx context.Context, req *CompactRequest) error
|
||||
|
||||
// Ingest records a message into the ContextManager's own storage.
|
||||
// Called after each message is persisted to session JSONL.
|
||||
Ingest(ctx context.Context, req *IngestRequest) error
|
||||
}
|
||||
|
||||
// AssembleRequest carries the parameters of a ContextManager.Assemble call.
type AssembleRequest struct {
	// SessionKey identifies the session.
	SessionKey string
	// Budget is the context window, in tokens.
	Budget int
	// MaxTokens is the maximum number of response tokens.
	MaxTokens int
}
|
||||
|
||||
// AssembleResponse is the output of Assemble.
|
||||
type AssembleResponse struct {
|
||||
History []providers.Message // assembled conversation history for BuildMessages
|
||||
Summary string // conversation summary embedded into system prompt by BuildMessages
|
||||
}
|
||||
|
||||
// CompactRequest is the input to Compact.
|
||||
type CompactRequest struct {
|
||||
SessionKey string // session identifier
|
||||
Reason ContextCompressReason // proactive_budget | llm_retry | summarize
|
||||
Budget int // context window budget (used for retry aggressive compaction)
|
||||
}
|
||||
|
||||
// IngestRequest is the input to Ingest.
|
||||
type IngestRequest struct {
|
||||
SessionKey string // session identifier
|
||||
Message providers.Message // the message just persisted
|
||||
}
|
||||
|
||||
// ContextManagerFactory constructs a ContextManager from config.
|
||||
// al provides access to the AgentLoop's runtime resources (provider, model, workspace, etc.)
|
||||
// cfg is the raw JSON configuration from config.json (may be nil).
|
||||
type ContextManagerFactory func(cfg json.RawMessage, al *AgentLoop) (ContextManager, error)
|
||||
|
||||
var (
|
||||
cmRegistryMu sync.RWMutex
|
||||
cmRegistry = map[string]ContextManagerFactory{}
|
||||
)
|
||||
|
||||
// RegisterContextManager registers a named ContextManager factory.
|
||||
func RegisterContextManager(name string, factory ContextManagerFactory) error {
|
||||
if name == "" {
|
||||
return fmt.Errorf("context manager name is required")
|
||||
}
|
||||
if factory == nil {
|
||||
return fmt.Errorf("context manager %q factory is nil", name)
|
||||
}
|
||||
|
||||
cmRegistryMu.Lock()
|
||||
defer cmRegistryMu.Unlock()
|
||||
|
||||
if _, exists := cmRegistry[name]; exists {
|
||||
return fmt.Errorf("context manager %q is already registered", name)
|
||||
}
|
||||
cmRegistry[name] = factory
|
||||
return nil
|
||||
}
|
||||
|
||||
func lookupContextManager(name string) (ContextManagerFactory, bool) {
|
||||
cmRegistryMu.RLock()
|
||||
defer cmRegistryMu.RUnlock()
|
||||
|
||||
f, ok := cmRegistry[name]
|
||||
return f, ok
|
||||
}
|
||||
@@ -0,0 +1,764 @@
|
||||
package agent
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"os"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/bus"
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
"github.com/sipeed/picoclaw/pkg/providers"
|
||||
)
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Factory registry tests
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestRegisterContextManager_Success(t *testing.T) {
|
||||
cleanup := resetCMRegistry()
|
||||
defer cleanup()
|
||||
|
||||
factory := func(cfg json.RawMessage, al *AgentLoop) (ContextManager, error) {
|
||||
return &noopContextManager{}, nil
|
||||
}
|
||||
if err := RegisterContextManager("test_cm", factory); err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
f, ok := lookupContextManager("test_cm")
|
||||
if !ok {
|
||||
t.Fatal("expected factory to be registered")
|
||||
}
|
||||
if f == nil {
|
||||
t.Fatal("expected non-nil factory")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegisterContextManager_EmptyName(t *testing.T) {
|
||||
cleanup := resetCMRegistry()
|
||||
defer cleanup()
|
||||
|
||||
err := RegisterContextManager("", func(cfg json.RawMessage, al *AgentLoop) (ContextManager, error) {
|
||||
return &noopContextManager{}, nil
|
||||
})
|
||||
if err == nil {
|
||||
t.Fatal("expected error for empty name")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "name is required") {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegisterContextManager_NilFactory(t *testing.T) {
|
||||
cleanup := resetCMRegistry()
|
||||
defer cleanup()
|
||||
|
||||
err := RegisterContextManager("nil_factory", nil)
|
||||
if err == nil {
|
||||
t.Fatal("expected error for nil factory")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "factory is nil") {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegisterContextManager_Duplicate(t *testing.T) {
|
||||
cleanup := resetCMRegistry()
|
||||
defer cleanup()
|
||||
|
||||
factory := func(cfg json.RawMessage, al *AgentLoop) (ContextManager, error) {
|
||||
return &noopContextManager{}, nil
|
||||
}
|
||||
if err := RegisterContextManager("dup_cm", factory); err != nil {
|
||||
t.Fatalf("first registration failed: %v", err)
|
||||
}
|
||||
err := RegisterContextManager("dup_cm", factory)
|
||||
if err == nil {
|
||||
t.Fatal("expected error for duplicate registration")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "already registered") {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLookupContextManager_Unknown(t *testing.T) {
|
||||
cleanup := resetCMRegistry()
|
||||
defer cleanup()
|
||||
|
||||
_, ok := lookupContextManager("nonexistent")
|
||||
if ok {
|
||||
t.Fatal("expected lookup to fail for unknown name")
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// resolveContextManager tests
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestResolveContextManager_Default(t *testing.T) {
|
||||
cleanup := resetCMRegistry()
|
||||
defer cleanup()
|
||||
|
||||
cfg := &config.Config{
|
||||
Agents: config.AgentsConfig{
|
||||
Defaults: config.AgentDefaults{
|
||||
Workspace: t.TempDir(),
|
||||
ModelName: "test-model",
|
||||
MaxTokens: 4096,
|
||||
MaxToolIterations: 10,
|
||||
ContextManager: "", // default → legacy
|
||||
},
|
||||
},
|
||||
}
|
||||
al := newCMTestAgentLoop(cfg)
|
||||
|
||||
cm := al.contextManager
|
||||
if cm == nil {
|
||||
t.Fatal("expected non-nil context manager")
|
||||
}
|
||||
if _, ok := cm.(*legacyContextManager); !ok {
|
||||
t.Fatalf("expected *legacyContextManager, got %T", cm)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveContextManager_ExplicitLegacy(t *testing.T) {
|
||||
cleanup := resetCMRegistry()
|
||||
defer cleanup()
|
||||
|
||||
cfg := &config.Config{
|
||||
Agents: config.AgentsConfig{
|
||||
Defaults: config.AgentDefaults{
|
||||
Workspace: t.TempDir(),
|
||||
ModelName: "test-model",
|
||||
MaxTokens: 4096,
|
||||
MaxToolIterations: 10,
|
||||
ContextManager: "legacy",
|
||||
},
|
||||
},
|
||||
}
|
||||
al := newCMTestAgentLoop(cfg)
|
||||
|
||||
if _, ok := al.contextManager.(*legacyContextManager); !ok {
|
||||
t.Fatalf("expected *legacyContextManager, got %T", al.contextManager)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveContextManager_UnknownFallsBackToLegacy(t *testing.T) {
|
||||
cleanup := resetCMRegistry()
|
||||
defer cleanup()
|
||||
|
||||
cfg := &config.Config{
|
||||
Agents: config.AgentsConfig{
|
||||
Defaults: config.AgentDefaults{
|
||||
Workspace: t.TempDir(),
|
||||
ModelName: "test-model",
|
||||
MaxTokens: 4096,
|
||||
MaxToolIterations: 10,
|
||||
ContextManager: "unknown_cm",
|
||||
},
|
||||
},
|
||||
}
|
||||
al := newCMTestAgentLoop(cfg)
|
||||
|
||||
if _, ok := al.contextManager.(*legacyContextManager); !ok {
|
||||
t.Fatalf("expected fallback to *legacyContextManager, got %T", al.contextManager)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveContextManager_RegisteredFactory(t *testing.T) {
|
||||
cleanup := resetCMRegistry()
|
||||
defer cleanup()
|
||||
|
||||
factory := func(cfg json.RawMessage, al *AgentLoop) (ContextManager, error) {
|
||||
return &noopContextManager{}, nil
|
||||
}
|
||||
if err := RegisterContextManager("custom_cm", factory); err != nil {
|
||||
t.Fatalf("register failed: %v", err)
|
||||
}
|
||||
|
||||
cfg := &config.Config{
|
||||
Agents: config.AgentsConfig{
|
||||
Defaults: config.AgentDefaults{
|
||||
Workspace: t.TempDir(),
|
||||
ModelName: "test-model",
|
||||
MaxTokens: 4096,
|
||||
MaxToolIterations: 10,
|
||||
ContextManager: "custom_cm",
|
||||
},
|
||||
},
|
||||
}
|
||||
al := newCMTestAgentLoop(cfg)
|
||||
|
||||
if _, ok := al.contextManager.(*noopContextManager); !ok {
|
||||
t.Fatalf("expected *noopContextManager, got %T", al.contextManager)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveContextManager_FactoryError(t *testing.T) {
|
||||
cleanup := resetCMRegistry()
|
||||
defer cleanup()
|
||||
|
||||
factory := func(cfg json.RawMessage, al *AgentLoop) (ContextManager, error) {
|
||||
return nil, os.ErrPermission
|
||||
}
|
||||
if err := RegisterContextManager("broken_cm", factory); err != nil {
|
||||
t.Fatalf("register failed: %v", err)
|
||||
}
|
||||
|
||||
cfg := &config.Config{
|
||||
Agents: config.AgentsConfig{
|
||||
Defaults: config.AgentDefaults{
|
||||
Workspace: t.TempDir(),
|
||||
ModelName: "test-model",
|
||||
MaxTokens: 4096,
|
||||
MaxToolIterations: 10,
|
||||
ContextManager: "broken_cm",
|
||||
},
|
||||
},
|
||||
}
|
||||
al := newCMTestAgentLoop(cfg)
|
||||
|
||||
// Should fall back to legacy when factory returns error
|
||||
if _, ok := al.contextManager.(*legacyContextManager); !ok {
|
||||
t.Fatalf("expected fallback to *legacyContextManager on factory error, got %T", al.contextManager)
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Legacy Assemble tests
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestLegacyAssemble_Passthrough(t *testing.T) {
|
||||
cfg := testConfig(t)
|
||||
al := newCMTestAgentLoop(cfg)
|
||||
|
||||
agent := al.registry.GetDefaultAgent()
|
||||
if agent == nil {
|
||||
t.Fatal("expected default agent")
|
||||
}
|
||||
|
||||
history := []providers.Message{
|
||||
{Role: "user", Content: "hello"},
|
||||
{Role: "assistant", Content: "hi there"},
|
||||
}
|
||||
agent.Sessions.SetHistory("test-session", history)
|
||||
|
||||
resp, err := al.contextManager.Assemble(context.Background(), &AssembleRequest{
|
||||
SessionKey: "test-session",
|
||||
Budget: 8000,
|
||||
MaxTokens: 4096,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if len(resp.History) != len(history) {
|
||||
t.Fatalf("expected %d messages, got %d", len(history), len(resp.History))
|
||||
}
|
||||
for i, msg := range resp.History {
|
||||
if msg.Content != history[i].Content || msg.Role != history[i].Role {
|
||||
t.Fatalf("message %d mismatch: want %+v, got %+v", i, history[i], msg)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestLegacyAssemble_EmptyHistory(t *testing.T) {
|
||||
cfg := testConfig(t)
|
||||
al := newCMTestAgentLoop(cfg)
|
||||
|
||||
resp, err := al.contextManager.Assemble(context.Background(), &AssembleRequest{
|
||||
SessionKey: "test-session",
|
||||
Budget: 8000,
|
||||
MaxTokens: 4096,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if len(resp.History) != 0 {
|
||||
t.Fatalf("expected empty messages, got %d", len(resp.History))
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Legacy Compact overflow tests
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestLegacyCompact_Overflow(t *testing.T) {
|
||||
cfg := testConfig(t)
|
||||
al := newCMTestAgentLoop(cfg)
|
||||
|
||||
defaultAgent := al.registry.GetDefaultAgent()
|
||||
if defaultAgent == nil {
|
||||
t.Fatal("expected default agent")
|
||||
}
|
||||
|
||||
history := []providers.Message{
|
||||
{Role: "user", Content: "msg 1"},
|
||||
{Role: "assistant", Content: "resp 1"},
|
||||
{Role: "user", Content: "msg 2"},
|
||||
{Role: "assistant", Content: "resp 2"},
|
||||
{Role: "user", Content: "msg 3"},
|
||||
}
|
||||
defaultAgent.Sessions.SetHistory("session-overflow", history)
|
||||
|
||||
sub := al.SubscribeEvents(16)
|
||||
defer al.UnsubscribeEvents(sub.ID)
|
||||
|
||||
err := al.contextManager.Compact(context.Background(), &CompactRequest{
|
||||
SessionKey: "session-overflow",
|
||||
Reason: ContextCompressReasonRetry,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
// After overflow compression, history should be shorter
|
||||
newHistory := defaultAgent.Sessions.GetHistory("session-overflow")
|
||||
if len(newHistory) >= len(history) {
|
||||
t.Fatalf("expected compressed history, got %d messages (was %d)", len(newHistory), len(history))
|
||||
}
|
||||
|
||||
// Summary should contain compression note
|
||||
summary := defaultAgent.Sessions.GetSummary("session-overflow")
|
||||
if !strings.Contains(summary, "Emergency compression") {
|
||||
t.Fatalf("expected compression note in summary, got %q", summary)
|
||||
}
|
||||
|
||||
// Event should carry the proactive reason
|
||||
events := collectEventStream(sub.C)
|
||||
compressEvt, ok := findEvent(events, EventKindContextCompress)
|
||||
if !ok {
|
||||
t.Fatal("expected context compress event")
|
||||
}
|
||||
payload, ok := compressEvt.Payload.(ContextCompressPayload)
|
||||
if !ok {
|
||||
t.Fatalf("expected ContextCompressPayload, got %T", compressEvt.Payload)
|
||||
}
|
||||
if payload.Reason != ContextCompressReasonRetry {
|
||||
t.Fatalf("expected retry reason, got %q", payload.Reason)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLegacyCompact_Overflow_ProactiveReason(t *testing.T) {
|
||||
cfg := testConfig(t)
|
||||
al := newCMTestAgentLoop(cfg)
|
||||
|
||||
defaultAgent := al.registry.GetDefaultAgent()
|
||||
if defaultAgent == nil {
|
||||
t.Fatal("expected default agent")
|
||||
}
|
||||
|
||||
history := []providers.Message{
|
||||
{Role: "user", Content: "msg 1"},
|
||||
{Role: "assistant", Content: "resp 1"},
|
||||
{Role: "user", Content: "msg 2"},
|
||||
{Role: "assistant", Content: "resp 2"},
|
||||
{Role: "user", Content: "msg 3"},
|
||||
}
|
||||
defaultAgent.Sessions.SetHistory("session-proactive", history)
|
||||
|
||||
sub := al.SubscribeEvents(16)
|
||||
defer al.UnsubscribeEvents(sub.ID)
|
||||
|
||||
err := al.contextManager.Compact(context.Background(), &CompactRequest{
|
||||
SessionKey: "session-proactive",
|
||||
Reason: ContextCompressReasonProactive,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
events := collectEventStream(sub.C)
|
||||
compressEvt, ok := findEvent(events, EventKindContextCompress)
|
||||
if !ok {
|
||||
t.Fatal("expected context compress event")
|
||||
}
|
||||
payload, ok := compressEvt.Payload.(ContextCompressPayload)
|
||||
if !ok {
|
||||
t.Fatalf("expected ContextCompressPayload, got %T", compressEvt.Payload)
|
||||
}
|
||||
if payload.Reason != ContextCompressReasonProactive {
|
||||
t.Fatalf("expected proactive reason, got %q", payload.Reason)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLegacyCompact_Overflow_TooShortToCompress(t *testing.T) {
|
||||
cfg := testConfig(t)
|
||||
al := newCMTestAgentLoop(cfg)
|
||||
|
||||
defaultAgent := al.registry.GetDefaultAgent()
|
||||
if defaultAgent == nil {
|
||||
t.Fatal("expected default agent")
|
||||
}
|
||||
|
||||
history := []providers.Message{
|
||||
{Role: "user", Content: "only one"},
|
||||
}
|
||||
defaultAgent.Sessions.SetHistory("session-tiny", history)
|
||||
|
||||
err := al.contextManager.Compact(context.Background(), &CompactRequest{
|
||||
SessionKey: "session-tiny",
|
||||
Reason: ContextCompressReasonRetry,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
// History should be unchanged (too short to compress)
|
||||
newHistory := defaultAgent.Sessions.GetHistory("session-tiny")
|
||||
if len(newHistory) != len(history) {
|
||||
t.Fatalf("expected history unchanged, got %d messages (was %d)", len(newHistory), len(history))
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Legacy Compact post-turn tests
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestLegacyCompact_PostTurn_BelowThreshold(t *testing.T) {
|
||||
cfg := testConfig(t)
|
||||
al := newCMTestAgentLoop(cfg)
|
||||
|
||||
defaultAgent := al.registry.GetDefaultAgent()
|
||||
if defaultAgent == nil {
|
||||
t.Fatal("expected default agent")
|
||||
}
|
||||
|
||||
// Small history, below summarization thresholds
|
||||
history := []providers.Message{
|
||||
{Role: "user", Content: "hi"},
|
||||
{Role: "assistant", Content: "hello"},
|
||||
}
|
||||
defaultAgent.Sessions.SetHistory("session-small", history)
|
||||
|
||||
err := al.contextManager.Compact(context.Background(), &CompactRequest{
|
||||
SessionKey: "session-small",
|
||||
Reason: ContextCompressReasonSummarize,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
// History should remain unchanged
|
||||
newHistory := defaultAgent.Sessions.GetHistory("session-small")
|
||||
if len(newHistory) != len(history) {
|
||||
t.Fatalf("expected unchanged history, got %d messages (was %d)", len(newHistory), len(history))
|
||||
}
|
||||
}
|
||||
|
||||
func TestLegacyCompact_PostTurn_ExceedsMessageThreshold(t *testing.T) {
|
||||
cfg := &config.Config{
|
||||
Agents: config.AgentsConfig{
|
||||
Defaults: config.AgentDefaults{
|
||||
Workspace: t.TempDir(),
|
||||
ModelName: "test-model",
|
||||
MaxTokens: 4096,
|
||||
MaxToolIterations: 10,
|
||||
ContextWindow: 8000,
|
||||
SummarizeMessageThreshold: 2,
|
||||
SummarizeTokenPercent: 75,
|
||||
},
|
||||
},
|
||||
}
|
||||
msgBus := bus.NewMessageBus()
|
||||
al := NewAgentLoop(cfg, msgBus, &simpleMockProvider{response: "summary"})
|
||||
|
||||
defaultAgent := al.registry.GetDefaultAgent()
|
||||
if defaultAgent == nil {
|
||||
t.Fatal("expected default agent")
|
||||
}
|
||||
|
||||
// 6 messages > threshold of 2
|
||||
history := []providers.Message{
|
||||
{Role: "user", Content: "q1"},
|
||||
{Role: "assistant", Content: "a1"},
|
||||
{Role: "user", Content: "q2"},
|
||||
{Role: "assistant", Content: "a2"},
|
||||
{Role: "user", Content: "q3"},
|
||||
{Role: "assistant", Content: "a3"},
|
||||
}
|
||||
defaultAgent.Sessions.SetHistory("session-threshold", history)
|
||||
|
||||
err := al.contextManager.Compact(context.Background(), &CompactRequest{
|
||||
SessionKey: "session-threshold",
|
||||
Reason: ContextCompressReasonSummarize,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
// Wait for async summarization to complete via event
|
||||
sub := al.SubscribeEvents(16)
|
||||
defer al.UnsubscribeEvents(sub.ID)
|
||||
|
||||
waitForEvent(t, sub.C, 5*time.Second, func(evt Event) bool {
|
||||
return evt.Kind == EventKindSessionSummarize
|
||||
})
|
||||
|
||||
newHistory := defaultAgent.Sessions.GetHistory("session-threshold")
|
||||
if len(newHistory) >= len(history) {
|
||||
t.Fatalf("expected summarization to reduce history from %d messages, got %d", len(history), len(newHistory))
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Legacy Ingest tests
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestLegacyIngest_NoOp(t *testing.T) {
|
||||
cfg := testConfig(t)
|
||||
al := newCMTestAgentLoop(cfg)
|
||||
|
||||
err := al.contextManager.Ingest(context.Background(), &IngestRequest{
|
||||
SessionKey: "session-ingest",
|
||||
Message: providers.Message{Role: "user", Content: "test"},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Mock ContextManager — verifies dispatch through AgentLoop
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestAgentLoop_UsesCustomContextManager(t *testing.T) {
|
||||
cleanup := resetCMRegistry()
|
||||
defer cleanup()
|
||||
|
||||
mock := &trackingContextManager{}
|
||||
factory := func(cfg json.RawMessage, al *AgentLoop) (ContextManager, error) {
|
||||
return mock, nil
|
||||
}
|
||||
if err := RegisterContextManager("tracking_cm", factory); err != nil {
|
||||
t.Fatalf("register failed: %v", err)
|
||||
}
|
||||
|
||||
cfg := &config.Config{
|
||||
Agents: config.AgentsConfig{
|
||||
Defaults: config.AgentDefaults{
|
||||
Workspace: t.TempDir(),
|
||||
ModelName: "test-model",
|
||||
MaxTokens: 4096,
|
||||
MaxToolIterations: 10,
|
||||
ContextManager: "tracking_cm",
|
||||
},
|
||||
},
|
||||
}
|
||||
al := newCMTestAgentLoop(cfg)
|
||||
|
||||
// Verify the mock was installed
|
||||
if al.contextManager != mock {
|
||||
t.Fatalf("expected mock context manager, got %T", al.contextManager)
|
||||
}
|
||||
|
||||
// Direct method calls
|
||||
_, err := mock.Assemble(context.Background(), &AssembleRequest{
|
||||
SessionKey: "s1",
|
||||
Budget: 8000,
|
||||
MaxTokens: 4096,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Assemble error: %v", err)
|
||||
}
|
||||
if mock.assembleCalls.Load() != 1 {
|
||||
t.Fatalf("expected 1 assemble call, got %d", mock.assembleCalls.Load())
|
||||
}
|
||||
|
||||
err = mock.Compact(context.Background(), &CompactRequest{
|
||||
SessionKey: "s1",
|
||||
Reason: ContextCompressReasonRetry,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Compact error: %v", err)
|
||||
}
|
||||
if mock.compactCalls.Load() != 1 {
|
||||
t.Fatalf("expected 1 compact call, got %d", mock.compactCalls.Load())
|
||||
}
|
||||
|
||||
err = mock.Ingest(context.Background(), &IngestRequest{
|
||||
SessionKey: "s1",
|
||||
Message: providers.Message{Role: "user", Content: "test"},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Ingest error: %v", err)
|
||||
}
|
||||
if mock.ingestCalls.Load() != 1 {
|
||||
t.Fatalf("expected 1 ingest call, got %d", mock.ingestCalls.Load())
|
||||
}
|
||||
}
|
||||
|
||||
func TestIngestCalledDuringTurn(t *testing.T) {
|
||||
cleanup := resetCMRegistry()
|
||||
defer cleanup()
|
||||
|
||||
mock := &trackingContextManager{}
|
||||
factory := func(cfg json.RawMessage, al *AgentLoop) (ContextManager, error) {
|
||||
return mock, nil
|
||||
}
|
||||
if err := RegisterContextManager("ingest_track_cm", factory); err != nil {
|
||||
t.Fatalf("register failed: %v", err)
|
||||
}
|
||||
|
||||
cfg := &config.Config{
|
||||
Agents: config.AgentsConfig{
|
||||
Defaults: config.AgentDefaults{
|
||||
Workspace: t.TempDir(),
|
||||
ModelName: "test-model",
|
||||
MaxTokens: 4096,
|
||||
MaxToolIterations: 10,
|
||||
ContextManager: "ingest_track_cm",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
msgBus := bus.NewMessageBus()
|
||||
al := NewAgentLoop(cfg, msgBus, &simpleMockProvider{response: "done"})
|
||||
defaultAgent := al.registry.GetDefaultAgent()
|
||||
if defaultAgent == nil {
|
||||
t.Fatal("expected default agent")
|
||||
}
|
||||
|
||||
// Run a turn — ingestMessage is called for user message and final assistant message
|
||||
_, err := al.runAgentLoop(context.Background(), defaultAgent, processOptions{
|
||||
SessionKey: "session-ingest-turn",
|
||||
Channel: "cli",
|
||||
ChatID: "direct",
|
||||
UserMessage: "test ingest",
|
||||
DefaultResponse: defaultResponse,
|
||||
EnableSummary: false,
|
||||
SendResponse: false,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("runAgentLoop failed: %v", err)
|
||||
}
|
||||
|
||||
// Should have at least 2 ingest calls: user message + final assistant message
|
||||
if mock.ingestCalls.Load() < 2 {
|
||||
t.Fatalf("expected >= 2 ingest calls during turn, got %d", mock.ingestCalls.Load())
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// forceCompression edge cases (via legacy Compact)
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestLegacyCompact_Overflow_SingleTurnKeepsLastUserMessage(t *testing.T) {
|
||||
cfg := testConfig(t)
|
||||
al := newCMTestAgentLoop(cfg)
|
||||
|
||||
defaultAgent := al.registry.GetDefaultAgent()
|
||||
if defaultAgent == nil {
|
||||
t.Fatal("expected default agent")
|
||||
}
|
||||
|
||||
// History with only 2 messages — forceCompression should still handle it
|
||||
history := []providers.Message{
|
||||
{Role: "user", Content: "first question"},
|
||||
{Role: "assistant", Content: "first answer"},
|
||||
}
|
||||
defaultAgent.Sessions.SetHistory("session-2msg", history)
|
||||
|
||||
err := al.contextManager.Compact(context.Background(), &CompactRequest{
|
||||
SessionKey: "session-2msg",
|
||||
Reason: ContextCompressReasonRetry,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
newHistory := defaultAgent.Sessions.GetHistory("session-2msg")
|
||||
// With 2 messages, forceCompression returns false (len <= 2), so no compression
|
||||
if len(newHistory) != len(history) {
|
||||
t.Fatalf("expected no compression for 2-message history, got %d", len(newHistory))
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Test helpers
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// noopContextManager is a minimal ContextManager that does nothing.
|
||||
type noopContextManager struct{}
|
||||
|
||||
func (m *noopContextManager) Assemble(_ context.Context, req *AssembleRequest) (*AssembleResponse, error) {
|
||||
return &AssembleResponse{}, nil
|
||||
}
|
||||
func (m *noopContextManager) Compact(_ context.Context, _ *CompactRequest) error { return nil }
|
||||
func (m *noopContextManager) Ingest(_ context.Context, _ *IngestRequest) error { return nil }
|
||||
|
||||
// trackingContextManager tracks call counts for each method.
|
||||
type trackingContextManager struct {
|
||||
assembleCalls atomic.Int64
|
||||
compactCalls atomic.Int64
|
||||
ingestCalls atomic.Int64
|
||||
mu sync.Mutex
|
||||
lastAssemble *AssembleRequest
|
||||
lastCompact *CompactRequest
|
||||
lastIngest *IngestRequest
|
||||
}
|
||||
|
||||
func (m *trackingContextManager) Assemble(_ context.Context, req *AssembleRequest) (*AssembleResponse, error) {
|
||||
m.assembleCalls.Add(1)
|
||||
m.mu.Lock()
|
||||
m.lastAssemble = req
|
||||
m.mu.Unlock()
|
||||
return &AssembleResponse{}, nil
|
||||
}
|
||||
|
||||
func (m *trackingContextManager) Compact(_ context.Context, req *CompactRequest) error {
|
||||
m.compactCalls.Add(1)
|
||||
m.mu.Lock()
|
||||
m.lastCompact = req
|
||||
m.mu.Unlock()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *trackingContextManager) Ingest(_ context.Context, req *IngestRequest) error {
|
||||
m.ingestCalls.Add(1)
|
||||
m.mu.Lock()
|
||||
m.lastIngest = req
|
||||
m.mu.Unlock()
|
||||
return nil
|
||||
}
|
||||
|
||||
// resetCMRegistry clears the global factory registry and returns a cleanup
|
||||
// function that restores the original state after the test.
|
||||
func resetCMRegistry() func() {
|
||||
cmRegistryMu.Lock()
|
||||
original := make(map[string]ContextManagerFactory, len(cmRegistry))
|
||||
for k, v := range cmRegistry {
|
||||
original[k] = v
|
||||
}
|
||||
cmRegistry = make(map[string]ContextManagerFactory)
|
||||
cmRegistryMu.Unlock()
|
||||
|
||||
return func() {
|
||||
cmRegistryMu.Lock()
|
||||
cmRegistry = original
|
||||
cmRegistryMu.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
func testConfig(t *testing.T) *config.Config {
|
||||
t.Helper()
|
||||
return &config.Config{
|
||||
Agents: config.AgentsConfig{
|
||||
Defaults: config.AgentDefaults{
|
||||
Workspace: t.TempDir(),
|
||||
ModelName: "test-model",
|
||||
MaxTokens: 4096,
|
||||
MaxToolIterations: 10,
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func newCMTestAgentLoop(cfg *config.Config) *AgentLoop {
|
||||
msgBus := bus.NewMessageBus()
|
||||
return NewAgentLoop(cfg, msgBus, &simpleMockProvider{response: "test"})
|
||||
}
|
||||
@@ -0,0 +1,269 @@
|
||||
//go:build !mipsle && !netbsd
|
||||
|
||||
package agent
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/logger"
|
||||
"github.com/sipeed/picoclaw/pkg/providers"
|
||||
"github.com/sipeed/picoclaw/pkg/providers/protocoltypes"
|
||||
"github.com/sipeed/picoclaw/pkg/seahorse"
|
||||
"github.com/sipeed/picoclaw/pkg/session"
|
||||
"github.com/sipeed/picoclaw/pkg/tokenizer"
|
||||
)
|
||||
|
||||
// seahorseContextManager adapts seahorse.Engine to agent.ContextManager.
|
||||
type seahorseContextManager struct {
|
||||
engine *seahorse.Engine
|
||||
sessions session.SessionStore // for startup bootstrap
|
||||
}
|
||||
|
||||
// newSeahorseContextManager creates a seahorse-backed ContextManager.
|
||||
func newSeahorseContextManager(_ json.RawMessage, al *AgentLoop) (ContextManager, error) {
|
||||
if al == nil {
|
||||
return nil, fmt.Errorf("seahorse: AgentLoop is required")
|
||||
}
|
||||
|
||||
// Resolve workspace for DB path
|
||||
// DB stores session data, so it goes in sessions/ directory
|
||||
agent := al.registry.GetDefaultAgent()
|
||||
dbPath := agent.Workspace + "/sessions/seahorse.db"
|
||||
|
||||
// Create CompleteFn from provider
|
||||
completeFn := providerToCompleteFn(agent.Provider, agent.Model)
|
||||
|
||||
// Create engine
|
||||
engine, err := seahorse.NewEngine(seahorse.Config{
|
||||
DBPath: dbPath,
|
||||
}, completeFn)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("seahorse: create engine: %w", err)
|
||||
}
|
||||
|
||||
mgr := &seahorseContextManager{
|
||||
engine: engine,
|
||||
sessions: agent.Sessions,
|
||||
}
|
||||
|
||||
// Register seahorse tools with the agent's tool registry
|
||||
retrieval := mgr.engine.GetRetrieval()
|
||||
al.RegisterTool(seahorse.NewGrepTool(retrieval))
|
||||
al.RegisterTool(seahorse.NewExpandTool(retrieval))
|
||||
|
||||
// Bootstrap all existing sessions at startup
|
||||
if agent.Sessions != nil {
|
||||
ctx := context.Background()
|
||||
for _, sessionKey := range agent.Sessions.ListSessions() {
|
||||
mgr.bootstrapSession(ctx, sessionKey)
|
||||
}
|
||||
}
|
||||
|
||||
return mgr, nil
|
||||
}
|
||||
|
||||
// providerToCompleteFn wraps providers.LLMProvider as a seahorse.CompleteFn.
|
||||
func providerToCompleteFn(provider providers.LLMProvider, model string) seahorse.CompleteFn {
|
||||
return func(ctx context.Context, prompt string, opts seahorse.CompleteOptions) (string, error) {
|
||||
resp, err := provider.Chat(
|
||||
ctx,
|
||||
[]providers.Message{{Role: "user", Content: prompt}},
|
||||
nil, // no tools for summarization
|
||||
model,
|
||||
map[string]any{
|
||||
"max_tokens": opts.MaxTokens,
|
||||
"temperature": opts.Temperature,
|
||||
"prompt_cache_key": "seahorse",
|
||||
},
|
||||
)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return resp.Content, nil
|
||||
}
|
||||
}
|
||||
|
||||
// Assemble builds budget-aware context from seahorse SQLite.
|
||||
func (m *seahorseContextManager) Assemble(ctx context.Context, req *AssembleRequest) (*AssembleResponse, error) {
|
||||
if req == nil {
|
||||
return nil, fmt.Errorf("seahorse assemble: nil request")
|
||||
}
|
||||
|
||||
budget := req.Budget
|
||||
if budget <= 0 {
|
||||
budget = 100000
|
||||
}
|
||||
|
||||
// Reserve space for model response (spec lines 1400-1410)
|
||||
effectiveBudget := budget - req.MaxTokens
|
||||
if effectiveBudget <= 0 {
|
||||
// MaxTokens >= budget is a configuration problem
|
||||
// Use 50% as minimum to avoid guaranteed overflow
|
||||
logger.WarnCF("agent", "MaxTokens >= budget, using 50% fallback",
|
||||
map[string]any{"budget": budget, "max_tokens": req.MaxTokens})
|
||||
effectiveBudget = budget / 2
|
||||
}
|
||||
|
||||
result, err := m.engine.Assemble(ctx, req.SessionKey, seahorse.AssembleInput{
|
||||
Budget: effectiveBudget,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("seahorse assemble: %w", err)
|
||||
}
|
||||
|
||||
history := seahorseToProviderMessages(result)
|
||||
|
||||
// Summary is already formatted as XML with system prompt addition by assembler
|
||||
return &AssembleResponse{
|
||||
History: history,
|
||||
Summary: result.Summary,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Compact compresses conversation history via seahorse summarization.
|
||||
func (m *seahorseContextManager) Compact(ctx context.Context, req *CompactRequest) error {
|
||||
if req == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// For retry (LLM overflow), use aggressive CompactUntilUnder to guarantee
|
||||
// context shrinks below budget (spec lines ~1410).
|
||||
if req.Reason == ContextCompressReasonRetry && req.Budget > 0 {
|
||||
_, err := m.engine.CompactUntilUnder(ctx, req.SessionKey, req.Budget)
|
||||
return err
|
||||
}
|
||||
|
||||
_, err := m.engine.Compact(ctx, req.SessionKey, seahorse.CompactInput{
|
||||
Force: req.Reason == ContextCompressReasonRetry,
|
||||
Budget: &req.Budget,
|
||||
})
|
||||
return err
|
||||
}
|
||||
|
||||
// Ingest records a message into seahorse SQLite.
|
||||
// All existing sessions are bootstrapped at startup, so this only ingests new messages.
|
||||
func (m *seahorseContextManager) Ingest(ctx context.Context, req *IngestRequest) error {
|
||||
if req == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
msg := providerToSeahorseMessage(req.Message)
|
||||
_, err := m.engine.Ingest(ctx, req.SessionKey, []seahorse.Message{msg})
|
||||
return err
|
||||
}
|
||||
|
||||
// bootstrapSession reconciles JSONL session history into seahorse SQLite.
|
||||
func (m *seahorseContextManager) bootstrapSession(ctx context.Context, sessionKey string) {
|
||||
if m.sessions == nil {
|
||||
return
|
||||
}
|
||||
|
||||
history := m.sessions.GetHistory(sessionKey)
|
||||
if len(history) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
// Convert provider messages to seahorse messages
|
||||
msgs := make([]seahorse.Message, len(history))
|
||||
for i, h := range history {
|
||||
msgs[i] = providerToSeahorseMessage(h)
|
||||
}
|
||||
|
||||
if err := m.engine.Bootstrap(ctx, sessionKey, msgs); err != nil {
|
||||
logger.WarnCF("seahorse", "bootstrap", map[string]any{
|
||||
"session": sessionKey,
|
||||
"error": err.Error(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// providerToSeahorseMessage converts a providers.Message to a seahorse.Message.
|
||||
func providerToSeahorseMessage(msg protocoltypes.Message) seahorse.Message {
|
||||
result := seahorse.Message{
|
||||
Role: msg.Role,
|
||||
Content: msg.Content,
|
||||
ReasoningContent: msg.ReasoningContent,
|
||||
TokenCount: tokenizer.EstimateMessageTokens(msg),
|
||||
}
|
||||
|
||||
// Convert ToolCalls → MessageParts
|
||||
for _, tc := range msg.ToolCalls {
|
||||
part := seahorse.MessagePart{
|
||||
Type: "tool_use",
|
||||
Name: tc.Function.Name,
|
||||
Arguments: tc.Function.Arguments,
|
||||
ToolCallID: tc.ID,
|
||||
}
|
||||
result.Parts = append(result.Parts, part)
|
||||
}
|
||||
|
||||
// Convert tool result
|
||||
if msg.ToolCallID != "" {
|
||||
part := seahorse.MessagePart{
|
||||
Type: "tool_result",
|
||||
ToolCallID: msg.ToolCallID,
|
||||
Text: msg.Content,
|
||||
}
|
||||
result.Parts = append(result.Parts, part)
|
||||
}
|
||||
|
||||
// Convert media attachments
|
||||
for _, mediaURI := range msg.Media {
|
||||
part := seahorse.MessagePart{
|
||||
Type: "media",
|
||||
MediaURI: mediaURI,
|
||||
}
|
||||
result.Parts = append(result.Parts, part)
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// seahorseToProviderMessages converts a seahorse.AssembleResult to []providers.Message.
|
||||
func seahorseToProviderMessages(result *seahorse.AssembleResult) []protocoltypes.Message {
|
||||
messages := make([]protocoltypes.Message, 0, len(result.Messages))
|
||||
|
||||
// Convert assembled messages (which already include summary XML messages)
|
||||
for _, msg := range result.Messages {
|
||||
pm := protocoltypes.Message{
|
||||
Role: msg.Role,
|
||||
Content: msg.Content,
|
||||
ReasoningContent: msg.ReasoningContent,
|
||||
}
|
||||
|
||||
// Reconstruct ToolCalls from parts
|
||||
for _, part := range msg.Parts {
|
||||
if part.Type == "tool_use" {
|
||||
pm.ToolCalls = append(pm.ToolCalls, protocoltypes.ToolCall{
|
||||
ID: part.ToolCallID,
|
||||
Type: "function", // Required by OpenAI-compatible APIs (GLM, etc.)
|
||||
Function: &protocoltypes.FunctionCall{
|
||||
Name: part.Name,
|
||||
Arguments: part.Arguments,
|
||||
},
|
||||
})
|
||||
}
|
||||
if part.Type == "tool_result" {
|
||||
pm.ToolCallID = part.ToolCallID
|
||||
if pm.Content == "" && part.Text != "" {
|
||||
pm.Content = part.Text
|
||||
}
|
||||
}
|
||||
if part.Type == "media" && part.MediaURI != "" {
|
||||
pm.Media = append(pm.Media, part.MediaURI)
|
||||
}
|
||||
}
|
||||
|
||||
messages = append(messages, pm)
|
||||
}
|
||||
|
||||
return messages
|
||||
}
|
||||
|
||||
func init() {
|
||||
if err := RegisterContextManager("seahorse", newSeahorseContextManager); err != nil {
|
||||
panic(fmt.Sprintf("register seahorse context manager: %v", err))
|
||||
}
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,20 @@
|
||||
//go:build mipsle || netbsd
|
||||
|
||||
package agent
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
// newSeahorseContextManager is unavailable on platforms where modernc sqlite/libc
|
||||
// currently has no stable build path for this project.
|
||||
func newSeahorseContextManager(_ json.RawMessage, _ *AgentLoop) (ContextManager, error) {
|
||||
return nil, fmt.Errorf("seahorse context manager is unavailable on this platform")
|
||||
}
|
||||
|
||||
func init() {
|
||||
if err := RegisterContextManager("seahorse", newSeahorseContextManager); err != nil {
|
||||
panic(fmt.Sprintf("register seahorse context manager: %v", err))
|
||||
}
|
||||
}
|
||||
@@ -511,8 +511,8 @@ func TestAgentLoop_EmitsSessionSummarizeEvent(t *testing.T) {
|
||||
sub := al.SubscribeEvents(16)
|
||||
defer al.UnsubscribeEvents(sub.ID)
|
||||
|
||||
turnScope := al.newTurnEventScope(defaultAgent.ID, "session-1", nil)
|
||||
al.summarizeSession(defaultAgent, "session-1", turnScope)
|
||||
lcm := &legacyContextManager{al: al}
|
||||
lcm.summarizeSession(defaultAgent, "session-1")
|
||||
|
||||
events := collectEventStream(sub.C)
|
||||
summaryEvt, ok := findEvent(events, EventKindSessionSummarize)
|
||||
|
||||
@@ -167,6 +167,8 @@ const (
|
||||
ContextCompressReasonProactive ContextCompressReason = "proactive_budget"
|
||||
// ContextCompressReasonRetry indicates compression during context-error retry handling.
|
||||
ContextCompressReasonRetry ContextCompressReason = "llm_retry"
|
||||
// ContextCompressReasonSummarize indicates post-turn async summarization.
|
||||
ContextCompressReasonSummarize ContextCompressReason = "summarize"
|
||||
)
|
||||
|
||||
// ContextCompressPayload describes a forced history compression.
|
||||
|
||||
+51
-1
@@ -51,6 +51,10 @@ type AgentInstance struct {
|
||||
// LightProvider is the concrete provider instance for the configured light model.
|
||||
// It is only used when routing selects the light tier for a turn.
|
||||
LightProvider providers.LLMProvider
|
||||
// CandidateProviders maps "provider/model" keys to per-candidate LLMProvider
|
||||
// instances. This allows each fallback model to use its own api_base and api_key
|
||||
// from model_list, instead of inheriting the primary model's provider config.
|
||||
CandidateProviders map[string]providers.LLMProvider
|
||||
}
|
||||
|
||||
// NewAgentInstance creates an agent instance from config.
|
||||
@@ -77,7 +81,12 @@ func NewAgentInstance(
|
||||
|
||||
if cfg.Tools.IsToolEnabled("read_file") {
|
||||
maxReadFileSize := cfg.Tools.ReadFile.MaxReadFileSize
|
||||
toolsRegistry.Register(tools.NewReadFileTool(workspace, readRestrict, maxReadFileSize, allowReadPaths))
|
||||
switch cfg.Tools.ReadFile.EffectiveMode() {
|
||||
case config.ReadFileModeLines:
|
||||
toolsRegistry.Register(tools.NewReadFileLinesTool(workspace, readRestrict, maxReadFileSize, allowReadPaths))
|
||||
default:
|
||||
toolsRegistry.Register(tools.NewReadFileBytesTool(workspace, readRestrict, maxReadFileSize, allowReadPaths))
|
||||
}
|
||||
}
|
||||
if cfg.Tools.IsToolEnabled("write_file") {
|
||||
toolsRegistry.Register(tools.NewWriteFileTool(workspace, restrict, allowWritePaths))
|
||||
@@ -170,6 +179,9 @@ func NewAgentInstance(
|
||||
// Resolve fallback candidates
|
||||
candidates := resolveModelCandidates(cfg, defaults.Provider, model, fallbacks)
|
||||
|
||||
candidateProviders := make(map[string]providers.LLMProvider)
|
||||
populateCandidateProvidersFromNames(cfg, workspace, fallbacks, candidateProviders)
|
||||
|
||||
// Model routing setup: pre-resolve light model candidates at creation time
|
||||
// to avoid repeated model_list lookups on every incoming message.
|
||||
var router *routing.Router
|
||||
@@ -194,6 +206,7 @@ func NewAgentInstance(
|
||||
})
|
||||
lightCandidates = resolved
|
||||
lightProvider = lp
|
||||
populateCandidateProvidersFromNames(cfg, workspace, []string{rc.LightModel}, candidateProviders)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
@@ -225,6 +238,43 @@ func NewAgentInstance(
|
||||
Router: router,
|
||||
LightCandidates: lightCandidates,
|
||||
LightProvider: lightProvider,
|
||||
CandidateProviders: candidateProviders,
|
||||
}
|
||||
}
|
||||
|
||||
// populateCandidateProvidersFromNames resolves each model name (alias or
|
||||
// "provider/model") via resolvedModelConfig and creates a dedicated LLMProvider
|
||||
// for it. This reuses the canonical config resolution path (GetModelConfig) so
|
||||
// alias handling and load-balancing stay consistent with the rest of the codebase.
|
||||
func populateCandidateProvidersFromNames(
|
||||
cfg *config.Config,
|
||||
workspace string,
|
||||
names []string,
|
||||
out map[string]providers.LLMProvider,
|
||||
) {
|
||||
if cfg == nil || len(names) == 0 {
|
||||
return
|
||||
}
|
||||
for _, name := range names {
|
||||
mc, err := resolvedModelConfig(cfg, strings.TrimSpace(name), workspace)
|
||||
if err != nil {
|
||||
logger.WarnCF("agent",
|
||||
"fallback provider: no model_list entry found; will inherit primary provider credentials",
|
||||
map[string]any{"name": name, "error": err.Error()})
|
||||
continue
|
||||
}
|
||||
protocol, modelID := providers.ExtractProtocol(strings.TrimSpace(mc.Model))
|
||||
key := providers.ModelKey(providers.NormalizeProvider(protocol), modelID)
|
||||
if _, exists := out[key]; exists {
|
||||
continue
|
||||
}
|
||||
p, _, err := providers.CreateProviderFromConfig(mc)
|
||||
if err != nil {
|
||||
logger.WarnCF("agent", "fallback provider: failed to create provider",
|
||||
map[string]any{"model": mc.Model, "error": err.Error()})
|
||||
continue
|
||||
}
|
||||
out[key] = p
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
"github.com/sipeed/picoclaw/pkg/media"
|
||||
"github.com/sipeed/picoclaw/pkg/providers"
|
||||
)
|
||||
|
||||
func TestNewAgentInstance_UsesDefaultsTemperatureAndMaxTokens(t *testing.T) {
|
||||
@@ -165,6 +166,58 @@ func TestNewAgentInstance_ResolveCandidatesFromModelListAlias(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewAgentInstance_PreservesDistinctLimiterIdentityForSharedResolvedModel(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
|
||||
cfg := &config.Config{
|
||||
Agents: config.AgentsConfig{
|
||||
Defaults: config.AgentDefaults{
|
||||
Workspace: tmpDir,
|
||||
ModelName: "glm-4.7",
|
||||
ModelFallbacks: []string{"glm-4.7__key_1"},
|
||||
},
|
||||
},
|
||||
ModelList: []*config.ModelConfig{
|
||||
{
|
||||
ModelName: "glm-4.7",
|
||||
Model: "zhipu/glm-4.7",
|
||||
RPM: 1,
|
||||
},
|
||||
{
|
||||
ModelName: "glm-4.7__key_1",
|
||||
Model: "zhipu/glm-4.7",
|
||||
RPM: 3,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
agent := NewAgentInstance(nil, &cfg.Agents.Defaults, cfg, &mockProvider{})
|
||||
if len(agent.Candidates) != 2 {
|
||||
t.Fatalf("len(Candidates) = %d, want 2", len(agent.Candidates))
|
||||
}
|
||||
|
||||
first := agent.Candidates[0]
|
||||
second := agent.Candidates[1]
|
||||
if first.Provider != "zhipu" || first.Model != "glm-4.7" {
|
||||
t.Fatalf("first candidate = %s/%s, want zhipu/glm-4.7", first.Provider, first.Model)
|
||||
}
|
||||
if second.Provider != "zhipu" || second.Model != "glm-4.7" {
|
||||
t.Fatalf("second candidate = %s/%s, want zhipu/glm-4.7", second.Provider, second.Model)
|
||||
}
|
||||
if first.IdentityKey != "model_name:glm-4.7" {
|
||||
t.Fatalf("first identity key = %q, want %q", first.IdentityKey, "model_name:glm-4.7")
|
||||
}
|
||||
if second.IdentityKey != "model_name:glm-4.7__key_1" {
|
||||
t.Fatalf("second identity key = %q, want %q", second.IdentityKey, "model_name:glm-4.7__key_1")
|
||||
}
|
||||
if first.RPM != 1 {
|
||||
t.Fatalf("first RPM = %d, want 1", first.RPM)
|
||||
}
|
||||
if second.RPM != 3 {
|
||||
t.Fatalf("second RPM = %d, want 3", second.RPM)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewAgentInstance_AllowsMediaTempDirForReadListAndExec(t *testing.T) {
|
||||
workspace := t.TempDir()
|
||||
mediaDir := media.TempDir()
|
||||
@@ -248,6 +301,240 @@ func TestNewAgentInstance_AllowsMediaTempDirForReadListAndExec(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// TestPopulateCandidateProviders_NilCfgIsNoop verifies that passing a nil
|
||||
// config does not panic and leaves the output map empty.
|
||||
func TestPopulateCandidateProviders_NilCfgIsNoop(t *testing.T) {
|
||||
out := map[string]providers.LLMProvider{}
|
||||
populateCandidateProvidersFromNames(nil, t.TempDir(), []string{"gpt-4o"}, out)
|
||||
if len(out) != 0 {
|
||||
t.Fatalf("expected empty map, got %d entries", len(out))
|
||||
}
|
||||
}
|
||||
|
||||
// TestPopulateCandidateProviders_SkipsExistingKeys verifies that a key already
|
||||
// present in the output map is not overwritten.
|
||||
func TestPopulateCandidateProviders_SkipsExistingKeys(t *testing.T) {
|
||||
existing := &mockProvider{}
|
||||
key := providers.ModelKey("openai", "gpt-4o")
|
||||
out := map[string]providers.LLMProvider{key: existing}
|
||||
|
||||
cfg := &config.Config{
|
||||
ModelList: []*config.ModelConfig{
|
||||
{ModelName: "my-gpt", Model: "openai/gpt-4o", APIKeys: config.SimpleSecureStrings("test-key")},
|
||||
},
|
||||
}
|
||||
populateCandidateProvidersFromNames(cfg, t.TempDir(), []string{"my-gpt"}, out)
|
||||
|
||||
if out[key] != existing {
|
||||
t.Fatal("existing provider entry was overwritten; expected it to be preserved")
|
||||
}
|
||||
}
|
||||
|
||||
// TestPopulateCandidateProviders_ResolvesAlias verifies that a model_name
|
||||
// alias (e.g. "my-gpt") is resolved via GetModelConfig and the provider
|
||||
// is created using the underlying model's config.
|
||||
func TestPopulateCandidateProviders_ResolvesAlias(t *testing.T) {
|
||||
workspace := t.TempDir()
|
||||
out := map[string]providers.LLMProvider{}
|
||||
|
||||
cfg := &config.Config{
|
||||
ModelList: []*config.ModelConfig{
|
||||
{ModelName: "my-gpt", Model: "openai/gpt-4o", APIBase: "https://api.openai.com/v1", Workspace: workspace},
|
||||
},
|
||||
}
|
||||
populateCandidateProvidersFromNames(cfg, workspace, []string{"my-gpt"}, out)
|
||||
|
||||
key := providers.ModelKey("openai", "gpt-4o")
|
||||
if out[key] == nil {
|
||||
t.Fatalf("expected CandidateProviders[%q] to be populated for alias", key)
|
||||
}
|
||||
}
|
||||
|
||||
// TestPopulateCandidateProviders_ResolvesProtocolPrefix verifies that a
|
||||
// model_list entry using full "provider/model" notation (e.g.
|
||||
// "gemini/gemma-3-27b-it") is matched correctly when referenced by model_name.
|
||||
func TestPopulateCandidateProviders_ResolvesProtocolPrefix(t *testing.T) {
|
||||
workspace := t.TempDir()
|
||||
out := map[string]providers.LLMProvider{}
|
||||
|
||||
cfg := &config.Config{
|
||||
ModelList: []*config.ModelConfig{
|
||||
{
|
||||
ModelName: "gemma",
|
||||
Model: "gemini/gemma-3-27b-it",
|
||||
APIKeys: config.SimpleSecureStrings("gemini-test-key"),
|
||||
Workspace: workspace,
|
||||
},
|
||||
},
|
||||
}
|
||||
populateCandidateProvidersFromNames(cfg, workspace, []string{"gemma"}, out)
|
||||
|
||||
key := providers.ModelKey("gemini", "gemma-3-27b-it")
|
||||
if out[key] == nil {
|
||||
t.Fatalf("expected CandidateProviders[%q] to be populated for protocol-prefixed model", key)
|
||||
}
|
||||
}
|
||||
|
||||
// TestPopulateCandidateProviders_EmptyNamesIsNoop verifies the early-exit
|
||||
// path when the names slice is empty.
|
||||
func TestPopulateCandidateProviders_EmptyNamesIsNoop(t *testing.T) {
|
||||
out := map[string]providers.LLMProvider{}
|
||||
cfg := &config.Config{
|
||||
ModelList: []*config.ModelConfig{
|
||||
{ModelName: "my-gpt", Model: "openai/gpt-4o", APIKeys: config.SimpleSecureStrings("key")},
|
||||
},
|
||||
}
|
||||
populateCandidateProvidersFromNames(cfg, t.TempDir(), nil, out)
|
||||
if len(out) != 0 {
|
||||
t.Fatalf("expected empty map, got %d entries", len(out))
|
||||
}
|
||||
}
|
||||
|
||||
// TestPopulateCandidateProviders_EmptyModelListIsNoop verifies the early-exit
|
||||
// path when model_list is empty — no provider can be created.
|
||||
func TestPopulateCandidateProviders_EmptyModelListIsNoop(t *testing.T) {
|
||||
out := map[string]providers.LLMProvider{}
|
||||
cfg := &config.Config{}
|
||||
populateCandidateProvidersFromNames(cfg, t.TempDir(), []string{"gpt-4o"}, out)
|
||||
if len(out) != 0 {
|
||||
t.Fatalf("expected empty map, got %d entries", len(out))
|
||||
}
|
||||
}
|
||||
|
||||
// TestPopulateCandidateProviders_UnmatchedNameIsSkipped verifies that a
|
||||
// name with no matching model_list entry is skipped and does not
|
||||
// cause a panic or leave a nil entry in the map.
|
||||
func TestPopulateCandidateProviders_UnmatchedNameIsSkipped(t *testing.T) {
|
||||
out := map[string]providers.LLMProvider{}
|
||||
cfg := &config.Config{
|
||||
ModelList: []*config.ModelConfig{
|
||||
{ModelName: "my-gpt", Model: "openai/gpt-4o", APIKeys: config.SimpleSecureStrings("key")},
|
||||
},
|
||||
}
|
||||
populateCandidateProvidersFromNames(cfg, t.TempDir(), []string{"nonexistent-model"}, out)
|
||||
|
||||
if len(out) != 0 {
|
||||
t.Fatalf("expected empty map for unmatched name, got %d entries", len(out))
|
||||
}
|
||||
}
|
||||
|
||||
// TestNewAgentInstance_CandidateProvidersPopulatedForCrossProviderFallbacks
|
||||
// mirrors the exact scenario from bug #2140: primary model on OpenRouter with
|
||||
// Gemini fallbacks. Each entry must get its own provider instance so that
|
||||
// fallback requests go to the correct API endpoint, not the primary's.
|
||||
func TestNewAgentInstance_CandidateProvidersPopulatedForCrossProviderFallbacks(t *testing.T) {
|
||||
workspace := t.TempDir()
|
||||
|
||||
cfg := &config.Config{
|
||||
Agents: config.AgentsConfig{
|
||||
Defaults: config.AgentDefaults{
|
||||
Workspace: workspace,
|
||||
ModelName: "mistral-small-3.1",
|
||||
ModelFallbacks: []string{"gemma-3-27b", "gemini-images"},
|
||||
},
|
||||
},
|
||||
ModelList: []*config.ModelConfig{
|
||||
{
|
||||
ModelName: "mistral-small-3.1",
|
||||
Model: "openrouter/mistralai/mistral-small-3.1-24b-instruct:free",
|
||||
APIBase: "https://openrouter.ai/api/v1",
|
||||
APIKeys: config.SimpleSecureStrings("sk-or-test"),
|
||||
Workspace: workspace,
|
||||
},
|
||||
{
|
||||
ModelName: "gemma-3-27b",
|
||||
Model: "gemini/gemma-3-27b-it",
|
||||
APIKeys: config.SimpleSecureStrings("AIzaSy-test"),
|
||||
Workspace: workspace,
|
||||
},
|
||||
{
|
||||
ModelName: "gemini-images",
|
||||
Model: "gemini/gemini-2.5-flash-lite",
|
||||
APIKeys: config.SimpleSecureStrings("AIzaSy-test"),
|
||||
Workspace: workspace,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
primaryProvider := &mockProvider{}
|
||||
agent := NewAgentInstance(nil, &cfg.Agents.Defaults, cfg, primaryProvider)
|
||||
|
||||
// Only fallback models need entries — the primary uses the injected provider directly.
|
||||
wantKeys := []string{
|
||||
providers.ModelKey("gemini", "gemma-3-27b-it"),
|
||||
providers.ModelKey("gemini", "gemini-2.5-flash-lite"),
|
||||
}
|
||||
|
||||
for _, key := range wantKeys {
|
||||
p, ok := agent.CandidateProviders[key]
|
||||
if !ok {
|
||||
t.Errorf("CandidateProviders missing key %q", key)
|
||||
continue
|
||||
}
|
||||
if p == nil {
|
||||
t.Errorf("CandidateProviders[%q] is nil", key)
|
||||
}
|
||||
// Each fallback must use its own provider, not the injected primary.
|
||||
if p == primaryProvider {
|
||||
t.Errorf(
|
||||
"CandidateProviders[%q] is the same instance as the primary provider; fallback would inherit primary credentials",
|
||||
key,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
if t.Failed() {
|
||||
t.Logf("CandidateProviders keys present: %v", func() []string {
|
||||
keys := make([]string, 0, len(agent.CandidateProviders))
|
||||
for k := range agent.CandidateProviders {
|
||||
keys = append(keys, k)
|
||||
}
|
||||
return keys
|
||||
}())
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewAgentInstance_ReadFileModeSelectsSchema(t *testing.T) {
|
||||
workspace := t.TempDir()
|
||||
|
||||
cfg := &config.Config{
|
||||
Agents: config.AgentsConfig{
|
||||
Defaults: config.AgentDefaults{
|
||||
Workspace: workspace,
|
||||
ModelName: "test-model",
|
||||
},
|
||||
},
|
||||
Tools: config.ToolsConfig{
|
||||
ReadFile: config.ReadFileToolConfig{
|
||||
Enabled: true,
|
||||
Mode: config.ReadFileModeLines,
|
||||
MaxReadFileSize: 4096,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
agent := NewAgentInstance(nil, &cfg.Agents.Defaults, cfg, &mockProvider{})
|
||||
readTool, ok := agent.Tools.Get("read_file")
|
||||
if !ok {
|
||||
t.Fatal("read_file tool not registered")
|
||||
}
|
||||
|
||||
params := readTool.Parameters()
|
||||
props, _ := params["properties"].(map[string]any)
|
||||
if _, ok := props["start_line"]; !ok {
|
||||
t.Fatalf("expected line-mode schema to expose start_line, got %#v", props)
|
||||
}
|
||||
if _, ok := props["max_lines"]; !ok {
|
||||
t.Fatalf("expected line-mode schema to expose max_lines, got %#v", props)
|
||||
}
|
||||
if _, ok := props["offset"]; ok {
|
||||
t.Fatalf("did not expect line-mode schema to expose offset, got %#v", props)
|
||||
}
|
||||
if _, ok := props["length"]; ok {
|
||||
t.Fatalf("did not expect line-mode schema to expose length, got %#v", props)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewAgentInstance_InvalidExecConfigDoesNotExit(t *testing.T) {
|
||||
workspace := t.TempDir()
|
||||
|
||||
|
||||
+181
-374
@@ -18,6 +18,8 @@ import (
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/audio/asr"
|
||||
"github.com/sipeed/picoclaw/pkg/audio/tts"
|
||||
"github.com/sipeed/picoclaw/pkg/bus"
|
||||
"github.com/sipeed/picoclaw/pkg/channels"
|
||||
"github.com/sipeed/picoclaw/pkg/commands"
|
||||
@@ -32,7 +34,6 @@ import (
|
||||
"github.com/sipeed/picoclaw/pkg/state"
|
||||
"github.com/sipeed/picoclaw/pkg/tools"
|
||||
"github.com/sipeed/picoclaw/pkg/utils"
|
||||
"github.com/sipeed/picoclaw/pkg/voice"
|
||||
)
|
||||
|
||||
type AgentLoop struct {
|
||||
@@ -48,11 +49,11 @@ type AgentLoop struct {
|
||||
|
||||
// Runtime state
|
||||
running atomic.Bool
|
||||
summarizing sync.Map
|
||||
contextManager ContextManager
|
||||
fallback *providers.FallbackChain
|
||||
channelManager *channels.Manager
|
||||
mediaStore media.MediaStore
|
||||
transcriber voice.Transcriber
|
||||
transcriber asr.Transcriber
|
||||
cmdRegistry *commands.Registry
|
||||
mcp mcpRuntime
|
||||
hookRuntime hookRuntime
|
||||
@@ -116,9 +117,18 @@ func NewAgentLoop(
|
||||
) *AgentLoop {
|
||||
registry := NewAgentRegistry(cfg, provider)
|
||||
|
||||
// Set up shared fallback chain
|
||||
// Set up shared fallback chain with rate limiting.
|
||||
cooldown := providers.NewCooldownTracker()
|
||||
fallbackChain := providers.NewFallbackChain(cooldown)
|
||||
rl := providers.NewRateLimiterRegistry()
|
||||
// Register rate limiters for all agents' candidates so that RPM limits
|
||||
// configured in ModelConfig are enforced before each LLM call.
|
||||
for _, agentID := range registry.ListAgentIDs() {
|
||||
if agent, ok := registry.GetAgent(agentID); ok {
|
||||
rl.RegisterCandidates(agent.Candidates)
|
||||
rl.RegisterCandidates(agent.LightCandidates)
|
||||
}
|
||||
}
|
||||
fallbackChain := providers.NewFallbackChain(cooldown, rl)
|
||||
|
||||
// Create state manager using default agent's workspace for channel recording
|
||||
defaultAgent := registry.GetDefaultAgent()
|
||||
@@ -134,13 +144,13 @@ func NewAgentLoop(
|
||||
registry: registry,
|
||||
state: stateManager,
|
||||
eventBus: eventBus,
|
||||
summarizing: sync.Map{},
|
||||
fallback: fallbackChain,
|
||||
cmdRegistry: commands.NewRegistry(commands.BuiltinDefinitions()),
|
||||
steering: newSteeringQueue(parseSteeringMode(cfg.Agents.Defaults.SteeringMode)),
|
||||
}
|
||||
al.hooks = NewHookManager(eventBus)
|
||||
configureHookManagerFromConfig(al.hooks, cfg)
|
||||
al.contextManager = al.resolveContextManager()
|
||||
|
||||
// Register shared tools to all agents (now that al is created)
|
||||
registerSharedTools(al, cfg, msgBus, registry, provider)
|
||||
@@ -157,6 +167,13 @@ func registerSharedTools(
|
||||
provider providers.LLMProvider,
|
||||
) {
|
||||
allowReadPaths := buildAllowReadPatterns(cfg)
|
||||
var ttsProvider tts.TTSProvider
|
||||
if cfg.Tools.IsToolEnabled("send_tts") {
|
||||
ttsProvider = tts.DetectTTS(cfg)
|
||||
if ttsProvider == nil {
|
||||
logger.WarnCF("voice-tts", "send_tts enabled but no TTS provider configured", nil)
|
||||
}
|
||||
}
|
||||
|
||||
for _, agentID := range registry.ListAgentIDs() {
|
||||
agent, ok := registry.GetAgent(agentID)
|
||||
@@ -267,6 +284,21 @@ func registerSharedTools(
|
||||
agent.Tools.Register(sendFileTool)
|
||||
}
|
||||
|
||||
if ttsProvider != nil {
|
||||
agent.Tools.Register(tools.NewSendTTSTool(ttsProvider, nil))
|
||||
}
|
||||
|
||||
if cfg.Tools.IsToolEnabled("load_image") {
|
||||
loadImageTool := tools.NewLoadImageTool(
|
||||
agent.Workspace,
|
||||
cfg.Agents.Defaults.RestrictToWorkspace,
|
||||
cfg.Agents.Defaults.GetMaxMediaSize(),
|
||||
nil,
|
||||
allowReadPaths,
|
||||
)
|
||||
agent.Tools.Register(loadImageTool)
|
||||
}
|
||||
|
||||
// Skill discovery and installation tools
|
||||
skills_enabled := cfg.Tools.IsToolEnabled("skills")
|
||||
find_skills_enable := cfg.Tools.IsToolEnabled("find_skills")
|
||||
@@ -309,6 +341,14 @@ func registerSharedTools(
|
||||
subagentManager := tools.NewSubagentManager(provider, agent.Model, agent.Workspace)
|
||||
subagentManager.SetLLMOptions(agent.MaxTokens, agent.Temperature)
|
||||
|
||||
// Inject a media resolver so the legacy RunToolLoop fallback path can
|
||||
// resolve media:// refs in the same way the main AgentLoop does.
|
||||
// This keeps subagent vision support working even when the optimized
|
||||
// sub-turn spawner path is unavailable.
|
||||
subagentManager.SetMediaResolver(func(msgs []providers.Message) []providers.Message {
|
||||
return resolveMediaRefs(msgs, al.mediaStore, cfg.Agents.Defaults.GetMaxMediaSize())
|
||||
})
|
||||
|
||||
// Set the spawner that links into AgentLoop's turnState
|
||||
subagentManager.SetSpawner(func(
|
||||
ctx context.Context,
|
||||
@@ -1075,6 +1115,7 @@ func (al *AgentLoop) ReloadProviderAndConfig(
|
||||
go func() {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
logger.RecoverPanicNoExit(r)
|
||||
panicErr = fmt.Errorf("panic during registry creation: %v", r)
|
||||
logger.ErrorCF("agent", "Panic during registry creation",
|
||||
map[string]any{"panic": r})
|
||||
@@ -1115,8 +1156,15 @@ func (al *AgentLoop) ReloadProviderAndConfig(
|
||||
al.cfg = cfg
|
||||
al.registry = registry
|
||||
|
||||
// Also update fallback chain with new config
|
||||
al.fallback = providers.NewFallbackChain(providers.NewCooldownTracker())
|
||||
// Also update fallback chain with new config; rebuild rate limiter registry.
|
||||
newRL := providers.NewRateLimiterRegistry()
|
||||
for _, agentID := range registry.ListAgentIDs() {
|
||||
if agent, ok := registry.GetAgent(agentID); ok {
|
||||
newRL.RegisterCandidates(agent.Candidates)
|
||||
newRL.RegisterCandidates(agent.LightCandidates)
|
||||
}
|
||||
}
|
||||
al.fallback = providers.NewFallbackChain(providers.NewCooldownTracker(), newRL)
|
||||
|
||||
al.mu.Unlock()
|
||||
|
||||
@@ -1174,10 +1222,15 @@ func (al *AgentLoop) SetMediaStore(s media.MediaStore) {
|
||||
agent.Tools.SetMediaStore(s)
|
||||
}
|
||||
}
|
||||
registry.ForEachTool("send_tts", func(t tools.Tool) {
|
||||
if st, ok := t.(*tools.SendTTSTool); ok {
|
||||
st.SetMediaStore(s)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// SetTranscriber injects a voice transcriber for agent-level audio transcription.
|
||||
func (al *AgentLoop) SetTranscriber(t voice.Transcriber) {
|
||||
func (al *AgentLoop) SetTranscriber(t asr.Transcriber) {
|
||||
al.transcriber = t
|
||||
}
|
||||
|
||||
@@ -1198,19 +1251,23 @@ func (al *AgentLoop) transcribeAudioInMessage(ctx context.Context, msg bus.Inbou
|
||||
|
||||
// Transcribe each audio media ref in order.
|
||||
var transcriptions []string
|
||||
var keptMedia []string
|
||||
for _, ref := range msg.Media {
|
||||
path, meta, err := al.mediaStore.ResolveWithMeta(ref)
|
||||
if err != nil {
|
||||
logger.WarnCF("voice", "Failed to resolve media ref", map[string]any{"ref": ref, "error": err})
|
||||
keptMedia = append(keptMedia, ref)
|
||||
continue
|
||||
}
|
||||
if !utils.IsAudioFile(meta.Filename, meta.ContentType) {
|
||||
keptMedia = append(keptMedia, ref)
|
||||
continue
|
||||
}
|
||||
result, err := al.transcriber.Transcribe(ctx, path)
|
||||
if err != nil {
|
||||
logger.WarnCF("voice", "Transcription failed", map[string]any{"ref": ref, "error": err})
|
||||
transcriptions = append(transcriptions, "")
|
||||
keptMedia = append(keptMedia, ref)
|
||||
continue
|
||||
}
|
||||
transcriptions = append(transcriptions, result.Text)
|
||||
@@ -1230,15 +1287,21 @@ func (al *AgentLoop) transcribeAudioInMessage(ctx context.Context, msg bus.Inbou
|
||||
}
|
||||
text := transcriptions[idx]
|
||||
idx++
|
||||
if text == "" {
|
||||
return match
|
||||
}
|
||||
return "[voice: " + text + "]"
|
||||
})
|
||||
|
||||
// Append any remaining transcriptions not matched by an annotation.
|
||||
for ; idx < len(transcriptions); idx++ {
|
||||
newContent += "\n[voice: " + transcriptions[idx] + "]"
|
||||
if transcriptions[idx] != "" {
|
||||
newContent += "\n[voice: " + transcriptions[idx] + "]"
|
||||
}
|
||||
}
|
||||
|
||||
msg.Content = newContent
|
||||
msg.Media = keptMedia
|
||||
return msg, true
|
||||
}
|
||||
|
||||
@@ -1825,8 +1888,15 @@ func (al *AgentLoop) runTurn(ctx context.Context, ts *turnState) (turnResult, er
|
||||
var history []providers.Message
|
||||
var summary string
|
||||
if !ts.opts.NoHistory {
|
||||
history = ts.agent.Sessions.GetHistory(ts.sessionKey)
|
||||
summary = ts.agent.Sessions.GetSummary(ts.sessionKey)
|
||||
// ContextManager assembles budget-aware history and summary.
|
||||
if resp, err := al.contextManager.Assemble(turnCtx, &AssembleRequest{
|
||||
SessionKey: ts.sessionKey,
|
||||
Budget: ts.agent.ContextWindow,
|
||||
MaxTokens: ts.agent.MaxTokens,
|
||||
}); err == nil && resp != nil {
|
||||
history = resp.History
|
||||
summary = resp.Summary
|
||||
}
|
||||
}
|
||||
ts.captureRestorePoint(history, summary)
|
||||
|
||||
@@ -1851,22 +1921,28 @@ func (al *AgentLoop) runTurn(ctx context.Context, ts *turnState) (turnResult, er
|
||||
if isOverContextBudget(ts.agent.ContextWindow, messages, toolDefs, ts.agent.MaxTokens) {
|
||||
logger.WarnCF("agent", "Proactive compression: context budget exceeded before LLM call",
|
||||
map[string]any{"session_key": ts.sessionKey})
|
||||
if compression, ok := al.forceCompression(ts.agent, ts.sessionKey); ok {
|
||||
al.emitEvent(
|
||||
EventKindContextCompress,
|
||||
ts.eventMeta("runTurn", "turn.context.compress"),
|
||||
ContextCompressPayload{
|
||||
Reason: ContextCompressReasonProactive,
|
||||
DroppedMessages: compression.DroppedMessages,
|
||||
RemainingMessages: compression.RemainingMessages,
|
||||
},
|
||||
)
|
||||
ts.refreshRestorePointFromSession(ts.agent)
|
||||
if err := al.contextManager.Compact(turnCtx, &CompactRequest{
|
||||
SessionKey: ts.sessionKey,
|
||||
Reason: ContextCompressReasonProactive,
|
||||
Budget: ts.agent.ContextWindow,
|
||||
}); err != nil {
|
||||
logger.WarnCF("agent", "Proactive compact failed", map[string]any{
|
||||
"session_key": ts.sessionKey,
|
||||
"error": err.Error(),
|
||||
})
|
||||
}
|
||||
ts.refreshRestorePointFromSession(ts.agent)
|
||||
// Re-assemble from CM after compact.
|
||||
if resp, err := al.contextManager.Assemble(turnCtx, &AssembleRequest{
|
||||
SessionKey: ts.sessionKey,
|
||||
Budget: ts.agent.ContextWindow,
|
||||
MaxTokens: ts.agent.MaxTokens,
|
||||
}); err == nil && resp != nil {
|
||||
history = resp.History
|
||||
summary = resp.Summary
|
||||
}
|
||||
newHistory := ts.agent.Sessions.GetHistory(ts.sessionKey)
|
||||
newSummary := ts.agent.Sessions.GetSummary(ts.sessionKey)
|
||||
messages = ts.agent.ContextBuilder.BuildMessages(
|
||||
newHistory, newSummary, ts.userMessage,
|
||||
history, summary, ts.userMessage,
|
||||
ts.media, ts.channel, ts.chatID,
|
||||
ts.opts.SenderID, ts.opts.SenderDisplayName,
|
||||
activeSkillNames(ts.agent, ts.opts)...,
|
||||
@@ -1888,6 +1964,7 @@ func (al *AgentLoop) runTurn(ctx context.Context, ts *turnState) (turnResult, er
|
||||
ts.agent.Sessions.AddMessage(ts.sessionKey, rootMsg.Role, rootMsg.Content)
|
||||
}
|
||||
ts.recordPersistedMessage(rootMsg)
|
||||
ts.ingestMessage(turnCtx, al, rootMsg)
|
||||
}
|
||||
|
||||
activeCandidates, activeModel, usedLight := al.selectCandidates(ts.agent, ts.userMessage, messages)
|
||||
@@ -1963,6 +2040,7 @@ turnLoop:
|
||||
if !ts.opts.NoHistory {
|
||||
ts.agent.Sessions.AddFullMessage(ts.sessionKey, pm)
|
||||
ts.recordPersistedMessage(pm)
|
||||
ts.ingestMessage(turnCtx, al, pm)
|
||||
}
|
||||
logger.InfoCF("agent", "Injected steering message into context",
|
||||
map[string]any{
|
||||
@@ -2016,6 +2094,14 @@ turnLoop:
|
||||
providerToolDefs = filtered
|
||||
}
|
||||
|
||||
// Resolve media:// refs produced by tool results (e.g. load_image).
|
||||
// Skipped on iteration 1 because inbound user media is already resolved
|
||||
// before entering the loop; only subsequent iterations can contain new
|
||||
// tool-generated media refs that need base64 encoding.
|
||||
if iteration > 1 {
|
||||
messages = resolveMediaRefs(messages, al.mediaStore, maxMediaSize)
|
||||
}
|
||||
|
||||
callMessages := messages
|
||||
if gracefulTerminal {
|
||||
callMessages = append(append([]providers.Message(nil), messages...), ts.interruptHintMessage())
|
||||
@@ -2115,7 +2201,11 @@ turnLoop:
|
||||
providerCtx,
|
||||
activeCandidates,
|
||||
func(ctx context.Context, provider, model string) (*providers.LLMResponse, error) {
|
||||
return activeProvider.Chat(ctx, messagesForCall, toolDefsForCall, model, llmOpts)
|
||||
candidateProvider := activeProvider
|
||||
if cp, ok := ts.agent.CandidateProviders[providers.ModelKey(provider, model)]; ok {
|
||||
candidateProvider = cp
|
||||
}
|
||||
return candidateProvider.Chat(ctx, messagesForCall, toolDefsForCall, model, llmOpts)
|
||||
},
|
||||
)
|
||||
if fbErr != nil {
|
||||
@@ -2221,23 +2311,28 @@ turnLoop:
|
||||
))
|
||||
}
|
||||
|
||||
if compression, ok := al.forceCompression(ts.agent, ts.sessionKey); ok {
|
||||
al.emitEvent(
|
||||
EventKindContextCompress,
|
||||
ts.eventMeta("runTurn", "turn.context.compress"),
|
||||
ContextCompressPayload{
|
||||
Reason: ContextCompressReasonRetry,
|
||||
DroppedMessages: compression.DroppedMessages,
|
||||
RemainingMessages: compression.RemainingMessages,
|
||||
},
|
||||
)
|
||||
ts.refreshRestorePointFromSession(ts.agent)
|
||||
if compactErr := al.contextManager.Compact(turnCtx, &CompactRequest{
|
||||
SessionKey: ts.sessionKey,
|
||||
Reason: ContextCompressReasonRetry,
|
||||
Budget: ts.agent.ContextWindow,
|
||||
}); compactErr != nil {
|
||||
logger.WarnCF("agent", "Context overflow compact failed", map[string]any{
|
||||
"session_key": ts.sessionKey,
|
||||
"error": compactErr.Error(),
|
||||
})
|
||||
}
|
||||
ts.refreshRestorePointFromSession(ts.agent)
|
||||
// Re-assemble from CM after compact.
|
||||
if asmResp, asmErr := al.contextManager.Assemble(turnCtx, &AssembleRequest{
|
||||
SessionKey: ts.sessionKey,
|
||||
Budget: ts.agent.ContextWindow,
|
||||
MaxTokens: ts.agent.MaxTokens,
|
||||
}); asmErr == nil && asmResp != nil {
|
||||
history = asmResp.History
|
||||
summary = asmResp.Summary
|
||||
}
|
||||
|
||||
newHistory := ts.agent.Sessions.GetHistory(ts.sessionKey)
|
||||
newSummary := ts.agent.Sessions.GetSummary(ts.sessionKey)
|
||||
messages = ts.agent.ContextBuilder.BuildMessages(
|
||||
newHistory, newSummary, "",
|
||||
history, summary, "",
|
||||
nil, ts.channel, ts.chatID, ts.opts.SenderID, ts.opts.SenderDisplayName,
|
||||
activeSkillNames(ts.agent, ts.opts)...,
|
||||
)
|
||||
@@ -2409,6 +2504,7 @@ turnLoop:
|
||||
if !ts.opts.NoHistory {
|
||||
ts.agent.Sessions.AddFullMessage(ts.sessionKey, assistantMsg)
|
||||
ts.recordPersistedMessage(assistantMsg)
|
||||
ts.ingestMessage(turnCtx, al, assistantMsg)
|
||||
}
|
||||
|
||||
ts.setPhase(TurnPhaseTools)
|
||||
@@ -2633,6 +2729,7 @@ turnLoop:
|
||||
if toolResult == nil {
|
||||
toolResult = tools.ErrorResult("hook returned nil tool result")
|
||||
}
|
||||
|
||||
if len(toolResult.Media) > 0 && toolResult.ResponseHandled {
|
||||
parts := make([]bus.MediaPart, 0, len(toolResult.Media))
|
||||
for _, ref := range toolResult.Media {
|
||||
@@ -2675,6 +2772,13 @@ turnLoop:
|
||||
}
|
||||
|
||||
if len(toolResult.Media) > 0 && !toolResult.ResponseHandled {
|
||||
// For tools like load_image that produce media refs without sending them
|
||||
// to the user channel (ResponseHandled == false), both Media and ArtifactTags
|
||||
// coexist on the result:
|
||||
// - Media: carries media:// refs that resolveMediaRefs will base64-encode
|
||||
// into image_url parts in the next LLM iteration (enabling vision).
|
||||
// - ArtifactTags: exposes the local file path as a structured [file:…] tag
|
||||
// in the tool result text, so the LLM knows an artifact was produced.
|
||||
toolResult.ArtifactTags = buildArtifactTags(al.mediaStore, toolResult.Media)
|
||||
}
|
||||
|
||||
@@ -2693,7 +2797,6 @@ turnLoop:
|
||||
"content_len": len(toolResult.ForUser),
|
||||
})
|
||||
}
|
||||
|
||||
contentForLLM := toolResult.ContentForLLM()
|
||||
|
||||
// Filter sensitive data (API keys, tokens, secrets) before sending to LLM
|
||||
@@ -2706,6 +2809,9 @@ turnLoop:
|
||||
Content: contentForLLM,
|
||||
ToolCallID: toolCallID,
|
||||
}
|
||||
if len(toolResult.Media) > 0 && !toolResult.ResponseHandled {
|
||||
toolResultMsg.Media = append(toolResultMsg.Media, toolResult.Media...)
|
||||
}
|
||||
al.emitEvent(
|
||||
EventKindToolExecEnd,
|
||||
ts.eventMeta("runTurn", "turn.tool.end"),
|
||||
@@ -2722,6 +2828,7 @@ turnLoop:
|
||||
if !ts.opts.NoHistory {
|
||||
ts.agent.Sessions.AddFullMessage(ts.sessionKey, toolResultMsg)
|
||||
ts.recordPersistedMessage(toolResultMsg)
|
||||
ts.ingestMessage(turnCtx, al, toolResultMsg)
|
||||
}
|
||||
|
||||
if steerMsgs := al.dequeueSteeringMessagesForScope(ts.sessionKey); len(steerMsgs) > 0 {
|
||||
@@ -2821,6 +2928,7 @@ turnLoop:
|
||||
if !ts.opts.NoHistory {
|
||||
ts.agent.Sessions.AddMessage(ts.sessionKey, summaryMsg.Role, summaryMsg.Content)
|
||||
ts.recordPersistedMessage(summaryMsg)
|
||||
ts.ingestMessage(turnCtx, al, summaryMsg)
|
||||
if err := ts.agent.Sessions.Save(ts.sessionKey); err != nil {
|
||||
turnStatus = TurnEndStatusError
|
||||
al.emitEvent(
|
||||
@@ -2835,7 +2943,7 @@ turnLoop:
|
||||
}
|
||||
}
|
||||
if ts.opts.EnableSummary {
|
||||
al.maybeSummarize(ts.agent, ts.sessionKey, ts.scope)
|
||||
al.contextManager.Compact(turnCtx, &CompactRequest{SessionKey: ts.sessionKey, Reason: ContextCompressReasonSummarize, Budget: ts.agent.ContextWindow})
|
||||
}
|
||||
|
||||
ts.setPhase(TurnPhaseCompleted)
|
||||
@@ -2890,6 +2998,7 @@ turnLoop:
|
||||
finalMsg := providers.Message{Role: "assistant", Content: finalContent}
|
||||
ts.agent.Sessions.AddMessage(ts.sessionKey, finalMsg.Role, finalMsg.Content)
|
||||
ts.recordPersistedMessage(finalMsg)
|
||||
ts.ingestMessage(turnCtx, al, finalMsg)
|
||||
if err := ts.agent.Sessions.Save(ts.sessionKey); err != nil {
|
||||
turnStatus = TurnEndStatusError
|
||||
al.emitEvent(
|
||||
@@ -2905,7 +3014,14 @@ turnLoop:
|
||||
}
|
||||
|
||||
if ts.opts.EnableSummary {
|
||||
al.maybeSummarize(ts.agent, ts.sessionKey, ts.scope)
|
||||
al.contextManager.Compact(
|
||||
turnCtx,
|
||||
&CompactRequest{
|
||||
SessionKey: ts.sessionKey,
|
||||
Reason: ContextCompressReasonSummarize,
|
||||
Budget: ts.agent.ContextWindow,
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
ts.setPhase(TurnPhaseCompleted)
|
||||
@@ -2984,103 +3100,28 @@ func (al *AgentLoop) selectCandidates(
|
||||
return agent.LightCandidates, resolvedCandidateModel(agent.LightCandidates, agent.Router.LightModel()), true
|
||||
}
|
||||
|
||||
// maybeSummarize triggers summarization if the session history exceeds thresholds.
|
||||
func (al *AgentLoop) maybeSummarize(agent *AgentInstance, sessionKey string, turnScope turnEventScope) {
|
||||
newHistory := agent.Sessions.GetHistory(sessionKey)
|
||||
tokenEstimate := al.estimateTokens(newHistory)
|
||||
threshold := agent.ContextWindow * agent.SummarizeTokenPercent / 100
|
||||
|
||||
if len(newHistory) > agent.SummarizeMessageThreshold || tokenEstimate > threshold {
|
||||
summarizeKey := agent.ID + ":" + sessionKey
|
||||
if _, loading := al.summarizing.LoadOrStore(summarizeKey, true); !loading {
|
||||
go func() {
|
||||
defer al.summarizing.Delete(summarizeKey)
|
||||
logger.Debug("Memory threshold reached. Optimizing conversation history...")
|
||||
al.summarizeSession(agent, sessionKey, turnScope)
|
||||
}()
|
||||
}
|
||||
// resolveContextManager selects the ContextManager implementation based on config.
|
||||
func (al *AgentLoop) resolveContextManager() ContextManager {
|
||||
name := al.cfg.Agents.Defaults.ContextManager
|
||||
if name == "" || name == "legacy" {
|
||||
return &legacyContextManager{al: al}
|
||||
}
|
||||
}
|
||||
|
||||
// compressionResult reports the outcome of a forced history compression.
type compressionResult struct {
	// DroppedMessages is the number of messages removed from the session history.
	DroppedMessages int
	// RemainingMessages is the number of messages left in the session history.
	RemainingMessages int
}
|
||||
|
||||
// forceCompression aggressively reduces context when the limit is hit.
|
||||
// It drops the oldest ~50% of Turns (a Turn is a complete user→LLM→response
|
||||
// cycle, as defined in #1316), so tool-call sequences are never split.
|
||||
//
|
||||
// If the history is a single Turn with no safe split point, the function
|
||||
// falls back to keeping only the most recent user message. This breaks
|
||||
// Turn atomicity as a last resort to avoid a context-exceeded loop.
|
||||
//
|
||||
// Session history contains only user/assistant/tool messages — the system
|
||||
// prompt is built dynamically by BuildMessages and is NOT stored here.
|
||||
// The compression note is recorded in the session summary so that
|
||||
// BuildMessages can include it in the next system prompt.
|
||||
func (al *AgentLoop) forceCompression(agent *AgentInstance, sessionKey string) (compressionResult, bool) {
|
||||
history := agent.Sessions.GetHistory(sessionKey)
|
||||
if len(history) <= 2 {
|
||||
return compressionResult{}, false
|
||||
factory, ok := lookupContextManager(name)
|
||||
if !ok {
|
||||
logger.WarnCF("agent", "Unknown context manager, falling back to legacy", map[string]any{
|
||||
"name": name,
|
||||
})
|
||||
return &legacyContextManager{al: al}
|
||||
}
|
||||
|
||||
// Split at a Turn boundary so no tool-call sequence is torn apart.
|
||||
// parseTurnBoundaries gives us the start of each Turn; we drop the
|
||||
// oldest half of Turns and keep the most recent ones.
|
||||
turns := parseTurnBoundaries(history)
|
||||
var mid int
|
||||
if len(turns) >= 2 {
|
||||
mid = turns[len(turns)/2]
|
||||
} else {
|
||||
// Fewer than 2 Turns — fall back to message-level midpoint
|
||||
// aligned to the nearest Turn boundary.
|
||||
mid = findSafeBoundary(history, len(history)/2)
|
||||
cm, err := factory(al.cfg.Agents.Defaults.ContextManagerConfig, al)
|
||||
if err != nil {
|
||||
logger.WarnCF("agent", "Failed to create context manager, falling back to legacy", map[string]any{
|
||||
"name": name,
|
||||
"error": err.Error(),
|
||||
})
|
||||
return &legacyContextManager{al: al}
|
||||
}
|
||||
var keptHistory []providers.Message
|
||||
if mid <= 0 {
|
||||
// No safe Turn boundary — the entire history is a single Turn
|
||||
// (e.g. one user message followed by a massive tool response).
|
||||
// Keeping everything would leave the agent stuck in a context-
|
||||
// exceeded loop, so fall back to keeping only the most recent
|
||||
// user message. This breaks Turn atomicity as a last resort.
|
||||
for i := len(history) - 1; i >= 0; i-- {
|
||||
if history[i].Role == "user" {
|
||||
keptHistory = []providers.Message{history[i]}
|
||||
break
|
||||
}
|
||||
}
|
||||
} else {
|
||||
keptHistory = history[mid:]
|
||||
}
|
||||
|
||||
droppedCount := len(history) - len(keptHistory)
|
||||
|
||||
// Record compression in the session summary so BuildMessages includes it
|
||||
// in the system prompt. We do not modify history messages themselves.
|
||||
existingSummary := agent.Sessions.GetSummary(sessionKey)
|
||||
compressionNote := fmt.Sprintf(
|
||||
"[Emergency compression dropped %d oldest messages due to context limit]",
|
||||
droppedCount,
|
||||
)
|
||||
if existingSummary != "" {
|
||||
compressionNote = existingSummary + "\n\n" + compressionNote
|
||||
}
|
||||
agent.Sessions.SetSummary(sessionKey, compressionNote)
|
||||
|
||||
agent.Sessions.SetHistory(sessionKey, keptHistory)
|
||||
agent.Sessions.Save(sessionKey)
|
||||
|
||||
logger.WarnCF("agent", "Forced compression executed", map[string]any{
|
||||
"session_key": sessionKey,
|
||||
"dropped_msgs": droppedCount,
|
||||
"new_count": len(keptHistory),
|
||||
})
|
||||
|
||||
return compressionResult{
|
||||
DroppedMessages: droppedCount,
|
||||
RemainingMessages: len(keptHistory),
|
||||
}, true
|
||||
return cm
|
||||
}
|
||||
|
||||
// GetStartupInfo returns information about loaded tools and skills for logging.
|
||||
@@ -3172,247 +3213,13 @@ func formatToolsForLog(toolDefs []providers.ToolDefinition) string {
|
||||
}
|
||||
|
||||
// summarizeSession summarizes the conversation history for a session.
|
||||
func (al *AgentLoop) summarizeSession(agent *AgentInstance, sessionKey string, turnScope turnEventScope) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
||||
defer cancel()
|
||||
|
||||
history := agent.Sessions.GetHistory(sessionKey)
|
||||
summary := agent.Sessions.GetSummary(sessionKey)
|
||||
|
||||
// Keep the most recent Turns for continuity, aligned to a Turn boundary
|
||||
// so that no tool-call sequence is split.
|
||||
if len(history) <= 4 {
|
||||
return
|
||||
}
|
||||
|
||||
safeCut := findSafeBoundary(history, len(history)-4)
|
||||
if safeCut <= 0 {
|
||||
return
|
||||
}
|
||||
keepCount := len(history) - safeCut
|
||||
toSummarize := history[:safeCut]
|
||||
|
||||
// Oversized Message Guard
|
||||
maxMessageTokens := agent.ContextWindow / 2
|
||||
validMessages := make([]providers.Message, 0)
|
||||
omitted := false
|
||||
|
||||
for _, m := range toSummarize {
|
||||
if m.Role != "user" && m.Role != "assistant" {
|
||||
continue
|
||||
}
|
||||
msgTokens := len(m.Content) / 2
|
||||
if msgTokens > maxMessageTokens {
|
||||
omitted = true
|
||||
continue
|
||||
}
|
||||
validMessages = append(validMessages, m)
|
||||
}
|
||||
|
||||
if len(validMessages) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
const (
|
||||
maxSummarizationMessages = 10
|
||||
llmMaxRetries = 3
|
||||
llmTemperature = 0.3
|
||||
fallbackMaxContentLength = 200
|
||||
)
|
||||
|
||||
// Multi-Part Summarization
|
||||
var finalSummary string
|
||||
if len(validMessages) > maxSummarizationMessages {
|
||||
mid := len(validMessages) / 2
|
||||
|
||||
mid = al.findNearestUserMessage(validMessages, mid)
|
||||
|
||||
part1 := validMessages[:mid]
|
||||
part2 := validMessages[mid:]
|
||||
|
||||
s1, _ := al.summarizeBatch(ctx, agent, part1, "")
|
||||
s2, _ := al.summarizeBatch(ctx, agent, part2, "")
|
||||
|
||||
mergePrompt := fmt.Sprintf(
|
||||
"Merge these two conversation summaries into one cohesive summary:\n\n1: %s\n\n2: %s",
|
||||
s1,
|
||||
s2,
|
||||
)
|
||||
|
||||
resp, err := al.retryLLMCall(ctx, agent, mergePrompt, llmMaxRetries)
|
||||
if err == nil && resp.Content != "" {
|
||||
finalSummary = resp.Content
|
||||
} else {
|
||||
finalSummary = s1 + " " + s2
|
||||
}
|
||||
} else {
|
||||
finalSummary, _ = al.summarizeBatch(ctx, agent, validMessages, summary)
|
||||
}
|
||||
|
||||
if omitted && finalSummary != "" {
|
||||
finalSummary += "\n[Note: Some oversized messages were omitted from this summary for efficiency.]"
|
||||
}
|
||||
|
||||
if finalSummary != "" {
|
||||
agent.Sessions.SetSummary(sessionKey, finalSummary)
|
||||
agent.Sessions.TruncateHistory(sessionKey, keepCount)
|
||||
agent.Sessions.Save(sessionKey)
|
||||
al.emitEvent(
|
||||
EventKindSessionSummarize,
|
||||
turnScope.meta(0, "summarizeSession", "turn.session.summarize"),
|
||||
SessionSummarizePayload{
|
||||
SummarizedMessages: len(validMessages),
|
||||
KeptMessages: keepCount,
|
||||
SummaryLen: len(finalSummary),
|
||||
OmittedOversized: omitted,
|
||||
},
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
// findNearestUserMessage finds the nearest user message to the given index.
|
||||
// It searches backward first, then forward if no user message is found.
|
||||
func (al *AgentLoop) findNearestUserMessage(messages []providers.Message, mid int) int {
|
||||
originalMid := mid
|
||||
|
||||
for mid > 0 && messages[mid].Role != "user" {
|
||||
mid--
|
||||
}
|
||||
|
||||
if messages[mid].Role == "user" {
|
||||
return mid
|
||||
}
|
||||
|
||||
mid = originalMid
|
||||
for mid < len(messages) && messages[mid].Role != "user" {
|
||||
mid++
|
||||
}
|
||||
|
||||
if mid < len(messages) {
|
||||
return mid
|
||||
}
|
||||
|
||||
return originalMid
|
||||
}
|
||||
|
||||
// retryLLMCall calls the LLM with retry logic.
|
||||
func (al *AgentLoop) retryLLMCall(
|
||||
ctx context.Context,
|
||||
agent *AgentInstance,
|
||||
prompt string,
|
||||
maxRetries int,
|
||||
) (*providers.LLMResponse, error) {
|
||||
const (
|
||||
llmTemperature = 0.3
|
||||
)
|
||||
|
||||
var resp *providers.LLMResponse
|
||||
var err error
|
||||
|
||||
for attempt := 0; attempt < maxRetries; attempt++ {
|
||||
al.activeRequests.Add(1)
|
||||
resp, err = func() (*providers.LLMResponse, error) {
|
||||
defer al.activeRequests.Done()
|
||||
return agent.Provider.Chat(
|
||||
ctx,
|
||||
[]providers.Message{{Role: "user", Content: prompt}},
|
||||
nil,
|
||||
agent.Model,
|
||||
map[string]any{
|
||||
"max_tokens": agent.MaxTokens,
|
||||
"temperature": llmTemperature,
|
||||
"prompt_cache_key": agent.ID,
|
||||
},
|
||||
)
|
||||
}()
|
||||
|
||||
if err == nil && resp != nil && resp.Content != "" {
|
||||
return resp, nil
|
||||
}
|
||||
if attempt < maxRetries-1 {
|
||||
time.Sleep(time.Duration(attempt+1) * 100 * time.Millisecond)
|
||||
}
|
||||
}
|
||||
|
||||
return resp, err
|
||||
}
|
||||
|
||||
// summarizeBatch summarizes a batch of messages.
|
||||
func (al *AgentLoop) summarizeBatch(
|
||||
ctx context.Context,
|
||||
agent *AgentInstance,
|
||||
batch []providers.Message,
|
||||
existingSummary string,
|
||||
) (string, error) {
|
||||
const (
|
||||
llmMaxRetries = 3
|
||||
llmTemperature = 0.3
|
||||
fallbackMinContentLength = 200
|
||||
fallbackMaxContentPercent = 10
|
||||
)
|
||||
|
||||
var sb strings.Builder
|
||||
sb.WriteString(
|
||||
"Provide a concise summary of this conversation segment, preserving core context and key points.\n",
|
||||
)
|
||||
if existingSummary != "" {
|
||||
sb.WriteString("Existing context: ")
|
||||
sb.WriteString(existingSummary)
|
||||
sb.WriteString("\n")
|
||||
}
|
||||
sb.WriteString("\nCONVERSATION:\n")
|
||||
for _, m := range batch {
|
||||
fmt.Fprintf(&sb, "%s: %s\n", m.Role, m.Content)
|
||||
}
|
||||
prompt := sb.String()
|
||||
|
||||
response, err := al.retryLLMCall(ctx, agent, prompt, llmMaxRetries)
|
||||
if err == nil && response.Content != "" {
|
||||
return strings.TrimSpace(response.Content), nil
|
||||
}
|
||||
|
||||
var fallback strings.Builder
|
||||
fallback.WriteString("Conversation summary: ")
|
||||
for i, m := range batch {
|
||||
if i > 0 {
|
||||
fallback.WriteString(" | ")
|
||||
}
|
||||
content := strings.TrimSpace(m.Content)
|
||||
runes := []rune(content)
|
||||
if len(runes) == 0 {
|
||||
fallback.WriteString(fmt.Sprintf("%s: ", m.Role))
|
||||
continue
|
||||
}
|
||||
|
||||
keepLength := len(runes) * fallbackMaxContentPercent / 100
|
||||
if keepLength < fallbackMinContentLength {
|
||||
keepLength = fallbackMinContentLength
|
||||
}
|
||||
|
||||
if keepLength > len(runes) {
|
||||
keepLength = len(runes)
|
||||
}
|
||||
|
||||
content = string(runes[:keepLength])
|
||||
if keepLength < len(runes) {
|
||||
content += "..."
|
||||
}
|
||||
fallback.WriteString(fmt.Sprintf("%s: %s", m.Role, content))
|
||||
}
|
||||
return fallback.String(), nil
|
||||
}
|
||||
|
||||
// estimateTokens estimates the number of tokens in a message list.
|
||||
// Counts Content, ToolCalls arguments, and ToolCallID metadata so that
|
||||
// tool-heavy conversations are not systematically undercounted.
|
||||
func (al *AgentLoop) estimateTokens(messages []providers.Message) int {
|
||||
total := 0
|
||||
for _, m := range messages {
|
||||
total += estimateMessageTokens(m)
|
||||
}
|
||||
return total
|
||||
}
|
||||
|
||||
func (al *AgentLoop) handleCommand(
|
||||
ctx context.Context,
|
||||
msg bus.InboundMessage,
|
||||
@@ -3609,7 +3416,7 @@ func (al *AgentLoop) buildCommandsRuntime(agent *AgentInstance, opts *processOpt
|
||||
return "", fmt.Errorf("failed to initialize model %q: %w", value, err)
|
||||
}
|
||||
|
||||
nextCandidates := resolveModelCandidates(cfg, cfg.Agents.Defaults.Provider, modelCfg.Model, agent.Fallbacks)
|
||||
nextCandidates := resolveModelCandidates(cfg, cfg.Agents.Defaults.Provider, value, agent.Fallbacks)
|
||||
if len(nextCandidates) == 0 {
|
||||
return "", fmt.Errorf("model %q did not resolve to any provider candidates", value)
|
||||
}
|
||||
|
||||
@@ -126,6 +126,8 @@ func (al *AgentLoop) ensureMCPInitialized(ctx context.Context) error {
|
||||
}
|
||||
|
||||
mcpTool := tools.NewMCPTool(mcpManager, serverName, tool)
|
||||
mcpTool.SetWorkspace(agent.Workspace)
|
||||
mcpTool.SetMaxInlineTextRunes(al.cfg.Tools.MCP.GetMaxInlineTextChars())
|
||||
|
||||
if registerAsHidden {
|
||||
agent.Tools.RegisterHidden(mcpTool)
|
||||
|
||||
@@ -2132,6 +2132,162 @@ func TestProcessMessage_ModelRoutingUsesLightProvider(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// TestProcessMessage_FallbackUsesPerCandidateProvider is the loop-level test for
// bug #2140. It verifies that when the primary model returns a rate-limit error
// the fallback closure routes the retry to the fallback model's own provider
// (its own api_base), not back to the primary provider's endpoint.
func TestProcessMessage_FallbackUsesPerCandidateProvider(t *testing.T) {
	workspace := t.TempDir()

	// Primary endpoint: every request is counted and answered with HTTP 429.
	primaryCalls := 0
	primaryServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		primaryCalls++
		// Return 429 so FallbackChain classifies this as retriable and moves on.
		w.Header().Set("Content-Type", "application/json")
		w.WriteHeader(http.StatusTooManyRequests)
		_ = json.NewEncoder(w).Encode(map[string]any{
			"error": map[string]any{
				"message": "rate limit exceeded",
				"type":    "rate_limit_error",
			},
		})
	}))
	defer primaryServer.Close()

	// Fallback endpoint: built by a shared test helper.
	// NOTE(review): presumably the helper increments fallbackCalls per request
	// and replies with "fallback reply" — confirm against the helper.
	fallbackCalls := 0
	fallbackServer := newStrictChatCompletionTestServer(
		t, "fallback", "gemma-3-27b-it", "fallback reply", &fallbackCalls,
	)
	defer fallbackServer.Close()

	// Two model_list entries, each with its own api_base, so primary and
	// fallback point at different servers.
	cfg := &config.Config{
		Agents: config.AgentsConfig{
			Defaults: config.AgentDefaults{
				Workspace:         workspace,
				ModelName:         "mistral-primary",
				ModelFallbacks:    []string{"gemma-fallback"},
				MaxTokens:         4096,
				MaxToolIterations: 3,
			},
		},
		ModelList: []*config.ModelConfig{
			{
				ModelName: "mistral-primary",
				Model:     "openrouter/mistralai/mistral-small-3.1",
				APIBase:   primaryServer.URL,
				APIKeys:   config.SimpleSecureStrings("primary-key"),
				Workspace: workspace,
			},
			{
				ModelName: "gemma-fallback",
				Model:     "gemini/gemma-3-27b-it",
				APIBase:   fallbackServer.URL,
				APIKeys:   config.SimpleSecureStrings("fallback-key"),
				Workspace: workspace,
			},
		},
	}

	provider, _, err := providers.CreateProvider(cfg)
	if err != nil {
		t.Fatalf("CreateProvider() error = %v", err)
	}
	msgBus := bus.NewMessageBus()
	al := NewAgentLoop(cfg, msgBus, provider)
	helper := testHelper{al: al}

	resp := helper.executeAndGetResponse(t, context.Background(), bus.InboundMessage{
		Channel:  "telegram",
		SenderID: "user1",
		ChatID:   "chat1",
		Content:  "hi",
	})

	// The reply must come from the fallback server, after the primary was
	// tried at least once and the fallback exactly once.
	if resp != "fallback reply" {
		t.Fatalf("response = %q, want %q (fallback provider)", resp, "fallback reply")
	}
	if primaryCalls == 0 {
		t.Fatal("primary server was never called; expected at least one attempt")
	}
	if fallbackCalls != 1 {
		t.Fatalf("fallback server calls = %d, want 1", fallbackCalls)
	}
}
|
||||
|
||||
// TestProcessMessage_FallbackUsesActiveProviderWhenCandidateNotRegistered verifies
// that when a candidate has no model_list entry it is absent from CandidateProviders
// and the fallback closure falls back to activeProvider instead of panicking.
func TestProcessMessage_FallbackUsesActiveProviderWhenCandidateNotRegistered(t *testing.T) {
	workspace := t.TempDir()

	// Primary server: returns 429 on first call, succeeds on second.
	// Both the primary and the unregistered fallback share this server
	// (same api_base) so activeProvider routes both calls here.
	callCount := 0
	primaryServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		callCount++
		w.Header().Set("Content-Type", "application/json")
		if callCount == 1 {
			w.WriteHeader(http.StatusTooManyRequests)
			_ = json.NewEncoder(w).Encode(map[string]any{
				"error": map[string]any{"message": "rate limit", "type": "rate_limit_error"},
			})
			return
		}
		// Second call (fallback via activeProvider) succeeds.
		_ = json.NewEncoder(w).Encode(map[string]any{
			"choices": []map[string]any{
				{"message": map[string]any{"content": "active provider reply"}, "finish_reason": "stop"},
			},
		})
	}))
	defer primaryServer.Close()

	// Only the primary model is listed; the fallback name has no model_list entry.
	cfg := &config.Config{
		Agents: config.AgentsConfig{
			Defaults: config.AgentDefaults{
				Workspace:         workspace,
				ModelName:         "primary-model",
				MaxTokens:         4096,
				MaxToolIterations: 3,
				// No model_list entry for this alias — absent from CandidateProviders.
				ModelFallbacks: []string{"openrouter/fallback-model"},
			},
		},
		ModelList: []*config.ModelConfig{
			{
				ModelName: "primary-model",
				Model:     "openrouter/primary-model",
				APIBase:   primaryServer.URL,
				APIKeys:   config.SimpleSecureStrings("primary-key"),
				Workspace: workspace,
			},
		},
	}

	provider, _, err := providers.CreateProvider(cfg)
	if err != nil {
		t.Fatalf("CreateProvider() error = %v", err)
	}
	msgBus := bus.NewMessageBus()
	al := NewAgentLoop(cfg, msgBus, provider)

	helper := testHelper{al: al}
	resp := helper.executeAndGetResponse(t, context.Background(), bus.InboundMessage{
		Channel:  "telegram",
		SenderID: "user1",
		ChatID:   "chat1",
		Content:  "hi",
	})

	// The successful reply must have been served by the same server on a
	// later request, i.e. at least two requests reached it.
	if resp != "active provider reply" {
		t.Fatalf("response = %q, want %q", resp, "active provider reply")
	}
	if callCount < 2 {
		t.Fatalf("primary server calls = %d, want >= 2 (one 429 + one success via activeProvider)", callCount)
	}
}
|
||||
|
||||
// TestToolResult_SilentToolDoesNotSendUserMessage verifies silent tools don't trigger outbound
|
||||
func TestToolResult_SilentToolDoesNotSendUserMessage(t *testing.T) {
|
||||
tmpDir, err := os.MkdirTemp("", "agent-test-*")
|
||||
|
||||
+116
-43
@@ -8,44 +8,102 @@ import (
|
||||
"github.com/sipeed/picoclaw/pkg/providers"
|
||||
)
|
||||
|
||||
func buildModelListResolver(cfg *config.Config) func(raw string) (string, bool) {
|
||||
ensureProtocol := func(model string) string {
|
||||
model = strings.TrimSpace(model)
|
||||
if model == "" {
|
||||
return ""
|
||||
}
|
||||
if strings.Contains(model, "/") {
|
||||
return model
|
||||
}
|
||||
return "openai/" + model
|
||||
// ensureProtocolModel normalizes a model reference so that it carries a
// protocol prefix. Surrounding whitespace is trimmed first; an empty result
// stays empty, a value already containing "/" is returned as is, and anything
// else is prefixed with "openai/".
func ensureProtocolModel(model string) string {
	trimmed := strings.TrimSpace(model)
	switch {
	case trimmed == "":
		return ""
	case strings.Contains(trimmed, "/"):
		return trimmed
	default:
		return "openai/" + trimmed
	}
}
|
||||
|
||||
func modelConfigIdentityKey(mc *config.ModelConfig) string {
|
||||
if mc == nil {
|
||||
return ""
|
||||
}
|
||||
if name := strings.TrimSpace(mc.ModelName); name != "" {
|
||||
return "model_name:" + name
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func candidateFromModelConfig(
|
||||
defaultProvider string,
|
||||
mc *config.ModelConfig,
|
||||
) (providers.FallbackCandidate, bool) {
|
||||
if mc == nil {
|
||||
return providers.FallbackCandidate{}, false
|
||||
}
|
||||
|
||||
return func(raw string) (string, bool) {
|
||||
raw = strings.TrimSpace(raw)
|
||||
if raw == "" || cfg == nil {
|
||||
return "", false
|
||||
}
|
||||
|
||||
if mc, err := cfg.GetModelConfig(raw); err == nil && mc != nil && strings.TrimSpace(mc.Model) != "" {
|
||||
return ensureProtocol(mc.Model), true
|
||||
}
|
||||
|
||||
for i := range cfg.ModelList {
|
||||
fullModel := strings.TrimSpace(cfg.ModelList[i].Model)
|
||||
if fullModel == "" {
|
||||
continue
|
||||
}
|
||||
if fullModel == raw {
|
||||
return ensureProtocol(fullModel), true
|
||||
}
|
||||
_, modelID := providers.ExtractProtocol(fullModel)
|
||||
if modelID == raw {
|
||||
return ensureProtocol(fullModel), true
|
||||
}
|
||||
}
|
||||
|
||||
return "", false
|
||||
ref := providers.ParseModelRef(ensureProtocolModel(mc.Model), defaultProvider)
|
||||
if ref == nil {
|
||||
return providers.FallbackCandidate{}, false
|
||||
}
|
||||
|
||||
return providers.FallbackCandidate{
|
||||
Provider: ref.Provider,
|
||||
Model: ref.Model,
|
||||
RPM: mc.RPM,
|
||||
IdentityKey: modelConfigIdentityKey(mc),
|
||||
}, true
|
||||
}
|
||||
|
||||
func lookupModelConfigByRef(cfg *config.Config, raw string) *config.ModelConfig {
|
||||
raw = strings.TrimSpace(raw)
|
||||
if raw == "" || cfg == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
if mc, err := cfg.GetModelConfig(raw); err == nil && mc != nil && strings.TrimSpace(mc.Model) != "" {
|
||||
return mc
|
||||
}
|
||||
|
||||
for i := range cfg.ModelList {
|
||||
mc := cfg.ModelList[i]
|
||||
if mc == nil {
|
||||
continue
|
||||
}
|
||||
fullModel := strings.TrimSpace(mc.Model)
|
||||
if fullModel == "" {
|
||||
continue
|
||||
}
|
||||
if fullModel == raw {
|
||||
return mc
|
||||
}
|
||||
_, modelID := providers.ExtractProtocol(fullModel)
|
||||
if modelID == raw {
|
||||
return mc
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func resolveModelCandidate(
|
||||
cfg *config.Config,
|
||||
defaultProvider string,
|
||||
raw string,
|
||||
) (providers.FallbackCandidate, bool) {
|
||||
raw = strings.TrimSpace(raw)
|
||||
if raw == "" {
|
||||
return providers.FallbackCandidate{}, false
|
||||
}
|
||||
|
||||
if mc := lookupModelConfigByRef(cfg, raw); mc != nil {
|
||||
return candidateFromModelConfig(defaultProvider, mc)
|
||||
}
|
||||
|
||||
ref := providers.ParseModelRef(raw, defaultProvider)
|
||||
if ref == nil {
|
||||
return providers.FallbackCandidate{}, false
|
||||
}
|
||||
|
||||
return providers.FallbackCandidate{
|
||||
Provider: ref.Provider,
|
||||
Model: ref.Model,
|
||||
}, true
|
||||
}
|
||||
|
||||
func resolveModelCandidates(
|
||||
@@ -54,14 +112,29 @@ func resolveModelCandidates(
|
||||
primary string,
|
||||
fallbacks []string,
|
||||
) []providers.FallbackCandidate {
|
||||
return providers.ResolveCandidatesWithLookup(
|
||||
providers.ModelConfig{
|
||||
Primary: primary,
|
||||
Fallbacks: fallbacks,
|
||||
},
|
||||
defaultProvider,
|
||||
buildModelListResolver(cfg),
|
||||
)
|
||||
seen := make(map[string]bool)
|
||||
candidates := make([]providers.FallbackCandidate, 0, 1+len(fallbacks))
|
||||
|
||||
addCandidate := func(raw string) {
|
||||
candidate, ok := resolveModelCandidate(cfg, defaultProvider, raw)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
key := candidate.StableKey()
|
||||
if seen[key] {
|
||||
return
|
||||
}
|
||||
seen[key] = true
|
||||
candidates = append(candidates, candidate)
|
||||
}
|
||||
|
||||
addCandidate(primary)
|
||||
for _, fallback := range fallbacks {
|
||||
addCandidate(fallback)
|
||||
}
|
||||
|
||||
return candidates
|
||||
}
|
||||
|
||||
func resolvedCandidateModel(candidates []providers.FallbackCandidate, fallback string) string {
|
||||
|
||||
@@ -432,6 +432,7 @@ func spawnSubTurn(
|
||||
// 7. Defer cleanup: deliver result (for async), emit End event, and recover from panics
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
logger.RecoverPanicNoExit(r)
|
||||
err = fmt.Errorf("subturn panicked: %v", r)
|
||||
result = nil
|
||||
logger.ErrorCF("subturn", "SubTurn panicked", map[string]any{
|
||||
@@ -515,6 +516,7 @@ func deliverSubTurnResult(al *AgentLoop, parentTS *turnState, childID string, re
|
||||
// We use defer/recover to catch any unlikely channel panics if it were ever closed.
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
logger.RecoverPanicNoExit(r)
|
||||
logger.WarnCF("subturn", "recovered panic sending to pendingResults", map[string]any{
|
||||
"parent_id": parentTS.turnID,
|
||||
"child_id": childID,
|
||||
@@ -607,6 +609,7 @@ type ephemeralSessionStoreIface interface {
|
||||
SetHistory(key string, history []providers.Message)
|
||||
TruncateHistory(key string, keepLast int)
|
||||
Save(key string) error
|
||||
ListSessions() []string
|
||||
Close() error
|
||||
}
|
||||
|
||||
@@ -666,8 +669,9 @@ func (e *ephemeralSessionStore) TruncateHistory(_ string, keepLast int) {
|
||||
e.history = e.history[len(e.history)-keepLast:]
|
||||
}
|
||||
|
||||
func (e *ephemeralSessionStore) Save(_ string) error { return nil }
|
||||
func (e *ephemeralSessionStore) Close() error { return nil }
|
||||
func (e *ephemeralSessionStore) Save(_ string) error { return nil }
|
||||
func (e *ephemeralSessionStore) Close() error { return nil }
|
||||
func (e *ephemeralSessionStore) ListSessions() []string { return nil }
|
||||
|
||||
func (e *ephemeralSessionStore) truncateLocked() {
|
||||
if len(e.history) > maxEphemeralHistorySize {
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/bus"
|
||||
"github.com/sipeed/picoclaw/pkg/logger"
|
||||
"github.com/sipeed/picoclaw/pkg/providers"
|
||||
"github.com/sipeed/picoclaw/pkg/session"
|
||||
"github.com/sipeed/picoclaw/pkg/tools"
|
||||
@@ -341,6 +342,23 @@ func (ts *turnState) refreshRestorePointFromSession(agent *AgentInstance) {
|
||||
ts.captureRestorePoint(history, summary)
|
||||
}
|
||||
|
||||
// ingestMessage calls the ContextManager's Ingest method for a persisted message.
|
||||
// Errors are logged but never block the turn.
|
||||
func (ts *turnState) ingestMessage(ctx context.Context, al *AgentLoop, msg providers.Message) {
|
||||
if al.contextManager == nil {
|
||||
return
|
||||
}
|
||||
if err := al.contextManager.Ingest(ctx, &IngestRequest{
|
||||
SessionKey: ts.sessionKey,
|
||||
Message: msg,
|
||||
}); err != nil {
|
||||
logger.WarnCF("agent", "Context manager ingest failed", map[string]any{
|
||||
"session_key": ts.sessionKey,
|
||||
"error": err.Error(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (ts *turnState) restoreSession(agent *AgentInstance) error {
|
||||
ts.mu.RLock()
|
||||
history := append([]providers.Message(nil), ts.restorePointHistory...)
|
||||
|
||||
Reference in New Issue
Block a user