mirror of
https://github.com/sipeed/picoclaw.git
synced 2026-06-12 18:08:54 +00:00
feat: add ContextManager abstraction for pluggable context management (#2203)
- Define ContextManager interface with Assemble/Compact/Ingest methods
- Implement legacyContextManager wrapping existing summarization logic
- Wire Assemble (before BuildMessages), Compact (post-turn + overflow), and Ingest (after message persistence) into agent loop
- Add ContextManager config field and factory registry with config passthrough
- Remove old maybeSummarize/summarizeSession/summarizeBatch/etc. from loop.go
- All existing tests pass with default (legacy) config

Co-authored-by: Liu Yuan <namei.unix@gmail.com>
This commit is contained in:
@@ -0,0 +1,379 @@
|
||||
package agent
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/logger"
|
||||
"github.com/sipeed/picoclaw/pkg/providers"
|
||||
)
|
||||
|
||||
// legacyContextManager wraps the existing summarization/compression logic
|
||||
// as a ContextManager implementation. It is the default when no other
|
||||
// ContextManager is configured.
|
||||
type legacyContextManager struct {
|
||||
al *AgentLoop
|
||||
summarizing sync.Map // dedup for async Compact (post-turn)
|
||||
}
|
||||
|
||||
func (m *legacyContextManager) Assemble(_ context.Context, req *AssembleRequest) (*AssembleResponse, error) {
|
||||
// Legacy: read history from session, return as-is.
|
||||
// Budget enforcement happens in BuildMessages caller via
|
||||
// isOverContextBudget + forceCompression.
|
||||
agent := m.al.registry.GetDefaultAgent()
|
||||
if agent == nil {
|
||||
return &AssembleResponse{}, nil
|
||||
}
|
||||
history := agent.Sessions.GetHistory(req.SessionKey)
|
||||
summary := agent.Sessions.GetSummary(req.SessionKey)
|
||||
return &AssembleResponse{
|
||||
History: history,
|
||||
Summary: summary,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (m *legacyContextManager) Compact(_ context.Context, req *CompactRequest) error {
|
||||
switch req.Reason {
|
||||
case ContextCompressReasonProactive, ContextCompressReasonRetry:
|
||||
// Sync emergency compression — budget exceeded.
|
||||
if result, ok := m.forceCompression(req.SessionKey); ok {
|
||||
m.al.emitEvent(
|
||||
EventKindContextCompress,
|
||||
m.al.newTurnEventScope("", req.SessionKey).meta(0, "forceCompression", "turn.context.compress"),
|
||||
ContextCompressPayload{
|
||||
Reason: req.Reason,
|
||||
DroppedMessages: result.DroppedMessages,
|
||||
RemainingMessages: result.RemainingMessages,
|
||||
},
|
||||
)
|
||||
}
|
||||
case ContextCompressReasonSummarize:
|
||||
m.maybeSummarize(req.SessionKey)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *legacyContextManager) Ingest(_ context.Context, _ *IngestRequest) error {
|
||||
// Legacy: no-op. Messages are persisted by Sessions JSONL.
|
||||
return nil
|
||||
}
|
||||
|
||||
// maybeSummarize triggers summarization if the session history exceeds thresholds.
|
||||
// It runs asynchronously in a goroutine.
|
||||
func (m *legacyContextManager) maybeSummarize(sessionKey string) {
|
||||
agent := m.al.registry.GetDefaultAgent()
|
||||
if agent == nil {
|
||||
return
|
||||
}
|
||||
|
||||
newHistory := agent.Sessions.GetHistory(sessionKey)
|
||||
tokenEstimate := m.estimateTokens(newHistory)
|
||||
threshold := agent.ContextWindow * agent.SummarizeTokenPercent / 100
|
||||
|
||||
if len(newHistory) > agent.SummarizeMessageThreshold || tokenEstimate > threshold {
|
||||
summarizeKey := agent.ID + ":" + sessionKey
|
||||
if _, loading := m.summarizing.LoadOrStore(summarizeKey, true); !loading {
|
||||
go func() {
|
||||
defer m.summarizing.Delete(summarizeKey)
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
logger.WarnCF("agent", "Summarization panic recovered", map[string]any{
|
||||
"session_key": sessionKey,
|
||||
"panic": r,
|
||||
})
|
||||
}
|
||||
}()
|
||||
logger.Debug("Memory threshold reached. Optimizing conversation history...")
|
||||
m.summarizeSession(agent, sessionKey)
|
||||
}()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// compressionResult reports the outcome of one forceCompression run.
type compressionResult struct {
	// DroppedMessages is how many messages were removed from the history.
	DroppedMessages int
	// RemainingMessages is how many messages the history holds afterwards.
	RemainingMessages int
}
|
||||
|
||||
// forceCompression aggressively reduces context when the limit is hit.
|
||||
// It drops the oldest ~50% of Turns (a Turn is a complete user→LLM→response
|
||||
// cycle, as defined in #1316), so tool-call sequences are never split.
|
||||
func (m *legacyContextManager) forceCompression(sessionKey string) (compressionResult, bool) {
|
||||
agent := m.al.registry.GetDefaultAgent()
|
||||
if agent == nil {
|
||||
return compressionResult{}, false
|
||||
}
|
||||
|
||||
history := agent.Sessions.GetHistory(sessionKey)
|
||||
if len(history) <= 2 {
|
||||
return compressionResult{}, false
|
||||
}
|
||||
|
||||
turns := parseTurnBoundaries(history)
|
||||
var mid int
|
||||
if len(turns) >= 2 {
|
||||
mid = turns[len(turns)/2]
|
||||
} else {
|
||||
mid = findSafeBoundary(history, len(history)/2)
|
||||
}
|
||||
var keptHistory []providers.Message
|
||||
if mid <= 0 {
|
||||
for i := len(history) - 1; i >= 0; i-- {
|
||||
if history[i].Role == "user" {
|
||||
keptHistory = []providers.Message{history[i]}
|
||||
break
|
||||
}
|
||||
}
|
||||
} else {
|
||||
keptHistory = history[mid:]
|
||||
}
|
||||
|
||||
droppedCount := len(history) - len(keptHistory)
|
||||
|
||||
existingSummary := agent.Sessions.GetSummary(sessionKey)
|
||||
compressionNote := fmt.Sprintf(
|
||||
"[Emergency compression dropped %d oldest messages due to context limit]",
|
||||
droppedCount,
|
||||
)
|
||||
if existingSummary != "" {
|
||||
compressionNote = existingSummary + "\n\n" + compressionNote
|
||||
}
|
||||
agent.Sessions.SetSummary(sessionKey, compressionNote)
|
||||
|
||||
agent.Sessions.SetHistory(sessionKey, keptHistory)
|
||||
agent.Sessions.Save(sessionKey)
|
||||
|
||||
logger.WarnCF("agent", "Forced compression executed", map[string]any{
|
||||
"session_key": sessionKey,
|
||||
"dropped_msgs": droppedCount,
|
||||
"new_count": len(keptHistory),
|
||||
})
|
||||
|
||||
return compressionResult{
|
||||
DroppedMessages: droppedCount,
|
||||
RemainingMessages: len(keptHistory),
|
||||
}, true
|
||||
}
|
||||
|
||||
func (m *legacyContextManager) summarizeSession(agent *AgentInstance, sessionKey string) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
||||
defer cancel()
|
||||
|
||||
history := agent.Sessions.GetHistory(sessionKey)
|
||||
summary := agent.Sessions.GetSummary(sessionKey)
|
||||
|
||||
if len(history) <= 4 {
|
||||
return
|
||||
}
|
||||
|
||||
safeCut := findSafeBoundary(history, len(history)-4)
|
||||
if safeCut <= 0 {
|
||||
return
|
||||
}
|
||||
keepCount := len(history) - safeCut
|
||||
toSummarize := history[:safeCut]
|
||||
|
||||
maxMessageTokens := agent.ContextWindow / 2
|
||||
validMessages := make([]providers.Message, 0)
|
||||
omitted := false
|
||||
|
||||
for _, msg := range toSummarize {
|
||||
if msg.Role != "user" && msg.Role != "assistant" {
|
||||
continue
|
||||
}
|
||||
msgTokens := len(msg.Content) / 2
|
||||
if msgTokens > maxMessageTokens {
|
||||
omitted = true
|
||||
continue
|
||||
}
|
||||
validMessages = append(validMessages, msg)
|
||||
}
|
||||
|
||||
if len(validMessages) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
const (
|
||||
maxSummarizationMessages = 10
|
||||
llmMaxRetries = 3
|
||||
)
|
||||
|
||||
var finalSummary string
|
||||
if len(validMessages) > maxSummarizationMessages {
|
||||
mid := len(validMessages) / 2
|
||||
mid = m.findNearestUserMessage(validMessages, mid)
|
||||
|
||||
part1 := validMessages[:mid]
|
||||
part2 := validMessages[mid:]
|
||||
|
||||
s1, _ := m.summarizeBatch(ctx, agent, part1, "")
|
||||
s2, _ := m.summarizeBatch(ctx, agent, part2, "")
|
||||
|
||||
mergePrompt := fmt.Sprintf(
|
||||
"Merge these two conversation summaries into one cohesive summary:\n\n1: %s\n\n2: %s",
|
||||
s1, s2,
|
||||
)
|
||||
|
||||
resp, err := m.retryLLMCall(ctx, agent, mergePrompt, llmMaxRetries)
|
||||
if err == nil && resp.Content != "" {
|
||||
finalSummary = resp.Content
|
||||
} else {
|
||||
finalSummary = s1 + " " + s2
|
||||
}
|
||||
} else {
|
||||
finalSummary, _ = m.summarizeBatch(ctx, agent, validMessages, summary)
|
||||
}
|
||||
|
||||
if omitted && finalSummary != "" {
|
||||
finalSummary += "\n[Note: Some oversized messages were omitted from this summary for efficiency.]"
|
||||
}
|
||||
|
||||
if finalSummary != "" {
|
||||
agent.Sessions.SetSummary(sessionKey, finalSummary)
|
||||
agent.Sessions.TruncateHistory(sessionKey, keepCount)
|
||||
agent.Sessions.Save(sessionKey)
|
||||
m.al.emitEvent(
|
||||
EventKindSessionSummarize,
|
||||
m.al.newTurnEventScope(agent.ID, sessionKey).meta(0, "summarizeSession", "turn.session.summarize"),
|
||||
SessionSummarizePayload{
|
||||
SummarizedMessages: len(validMessages),
|
||||
KeptMessages: keepCount,
|
||||
SummaryLen: len(finalSummary),
|
||||
OmittedOversized: omitted,
|
||||
},
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
func (m *legacyContextManager) findNearestUserMessage(messages []providers.Message, mid int) int {
|
||||
originalMid := mid
|
||||
|
||||
for mid > 0 && messages[mid].Role != "user" {
|
||||
mid--
|
||||
}
|
||||
|
||||
if messages[mid].Role == "user" {
|
||||
return mid
|
||||
}
|
||||
|
||||
mid = originalMid
|
||||
for mid < len(messages) && messages[mid].Role != "user" {
|
||||
mid++
|
||||
}
|
||||
|
||||
if mid < len(messages) {
|
||||
return mid
|
||||
}
|
||||
|
||||
return originalMid
|
||||
}
|
||||
|
||||
func (m *legacyContextManager) retryLLMCall(
|
||||
ctx context.Context,
|
||||
agent *AgentInstance,
|
||||
prompt string,
|
||||
maxRetries int,
|
||||
) (*providers.LLMResponse, error) {
|
||||
const llmTemperature = 0.3
|
||||
|
||||
var resp *providers.LLMResponse
|
||||
var err error
|
||||
|
||||
for attempt := 0; attempt < maxRetries; attempt++ {
|
||||
m.al.activeRequests.Add(1)
|
||||
resp, err = func() (*providers.LLMResponse, error) {
|
||||
defer m.al.activeRequests.Done()
|
||||
return agent.Provider.Chat(
|
||||
ctx,
|
||||
[]providers.Message{{Role: "user", Content: prompt}},
|
||||
nil,
|
||||
agent.Model,
|
||||
map[string]any{
|
||||
"max_tokens": agent.MaxTokens,
|
||||
"temperature": llmTemperature,
|
||||
"prompt_cache_key": agent.ID,
|
||||
},
|
||||
)
|
||||
}()
|
||||
|
||||
if err == nil && resp != nil && resp.Content != "" {
|
||||
return resp, nil
|
||||
}
|
||||
if attempt < maxRetries-1 {
|
||||
time.Sleep(time.Duration(attempt+1) * 100 * time.Millisecond)
|
||||
}
|
||||
}
|
||||
|
||||
return resp, err
|
||||
}
|
||||
|
||||
func (m *legacyContextManager) summarizeBatch(
|
||||
ctx context.Context,
|
||||
agent *AgentInstance,
|
||||
batch []providers.Message,
|
||||
existingSummary string,
|
||||
) (string, error) {
|
||||
const (
|
||||
llmMaxRetries = 3
|
||||
fallbackMinContentLength = 200
|
||||
fallbackMaxContentPercent = 10
|
||||
)
|
||||
|
||||
var sb strings.Builder
|
||||
sb.WriteString("Provide a concise summary of this conversation segment, preserving core context and key points.\n")
|
||||
if existingSummary != "" {
|
||||
sb.WriteString("Existing context: ")
|
||||
sb.WriteString(existingSummary)
|
||||
sb.WriteString("\n")
|
||||
}
|
||||
sb.WriteString("\nCONVERSATION:\n")
|
||||
for _, msg := range batch {
|
||||
fmt.Fprintf(&sb, "%s: %s\n", msg.Role, msg.Content)
|
||||
}
|
||||
prompt := sb.String()
|
||||
|
||||
response, err := m.retryLLMCall(ctx, agent, prompt, llmMaxRetries)
|
||||
if err == nil && response.Content != "" {
|
||||
return strings.TrimSpace(response.Content), nil
|
||||
}
|
||||
|
||||
var fallback strings.Builder
|
||||
fallback.WriteString("Conversation summary: ")
|
||||
for i, msg := range batch {
|
||||
if i > 0 {
|
||||
fallback.WriteString(" | ")
|
||||
}
|
||||
content := strings.TrimSpace(msg.Content)
|
||||
runes := []rune(content)
|
||||
if len(runes) == 0 {
|
||||
fallback.WriteString(fmt.Sprintf("%s: ", msg.Role))
|
||||
continue
|
||||
}
|
||||
|
||||
keepLength := len(runes) * fallbackMaxContentPercent / 100
|
||||
if keepLength < fallbackMinContentLength {
|
||||
keepLength = fallbackMinContentLength
|
||||
}
|
||||
if keepLength > len(runes) {
|
||||
keepLength = len(runes)
|
||||
}
|
||||
|
||||
content = string(runes[:keepLength])
|
||||
if keepLength < len(runes) {
|
||||
content += "..."
|
||||
}
|
||||
fallback.WriteString(fmt.Sprintf("%s: %s", msg.Role, content))
|
||||
}
|
||||
return fallback.String(), nil
|
||||
}
|
||||
|
||||
func (m *legacyContextManager) estimateTokens(messages []providers.Message) int {
|
||||
total := 0
|
||||
for _, msg := range messages {
|
||||
total += estimateMessageTokens(msg)
|
||||
}
|
||||
return total
|
||||
}
|
||||
@@ -0,0 +1,89 @@
|
||||
package agent
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"sync"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/providers"
|
||||
)
|
||||
|
||||
// ContextManager manages conversation context via a pluggable strategy.
|
||||
// Exactly ONE ContextManager is active per AgentLoop, selected by config.
|
||||
// The default ("legacy") preserves current summarization behavior.
|
||||
type ContextManager interface {
|
||||
// Assemble builds budget-aware context from the ContextManager's own storage.
|
||||
// Called before BuildMessages. Returns assembled messages ready for LLM.
|
||||
Assemble(ctx context.Context, req *AssembleRequest) (*AssembleResponse, error)
|
||||
|
||||
// Compact compresses conversation history.
|
||||
// Called after turn completes (may be async internally) and on context overflow (sync).
|
||||
Compact(ctx context.Context, req *CompactRequest) error
|
||||
|
||||
// Ingest records a message into the ContextManager's own storage.
|
||||
// Called after each message is persisted to session JSONL.
|
||||
Ingest(ctx context.Context, req *IngestRequest) error
|
||||
}
|
||||
|
||||
// AssembleRequest carries the parameters of a ContextManager.Assemble call.
type AssembleRequest struct {
	// SessionKey identifies the session whose context is assembled.
	SessionKey string
	// Budget is the context window size, in tokens.
	Budget int
	// MaxTokens is the maximum number of response tokens.
	MaxTokens int
}
|
||||
|
||||
// AssembleResponse is the output of Assemble.
|
||||
type AssembleResponse struct {
|
||||
History []providers.Message // assembled conversation history for BuildMessages
|
||||
Summary string // conversation summary embedded into system prompt by BuildMessages
|
||||
}
|
||||
|
||||
// CompactRequest is the input to Compact.
|
||||
type CompactRequest struct {
|
||||
SessionKey string // session identifier
|
||||
Reason ContextCompressReason // proactive_budget | llm_retry | summarize
|
||||
}
|
||||
|
||||
// IngestRequest is the input to Ingest.
|
||||
type IngestRequest struct {
|
||||
SessionKey string // session identifier
|
||||
Message providers.Message // the message just persisted
|
||||
}
|
||||
|
||||
// ContextManagerFactory constructs a ContextManager from config.
|
||||
// al provides access to the AgentLoop's runtime resources (provider, model, workspace, etc.)
|
||||
// cfg is the raw JSON configuration from config.json (may be nil).
|
||||
type ContextManagerFactory func(cfg json.RawMessage, al *AgentLoop) (ContextManager, error)
|
||||
|
||||
var (
|
||||
cmRegistryMu sync.RWMutex
|
||||
cmRegistry = map[string]ContextManagerFactory{}
|
||||
)
|
||||
|
||||
// RegisterContextManager registers a named ContextManager factory.
|
||||
func RegisterContextManager(name string, factory ContextManagerFactory) error {
|
||||
if name == "" {
|
||||
return fmt.Errorf("context manager name is required")
|
||||
}
|
||||
if factory == nil {
|
||||
return fmt.Errorf("context manager %q factory is nil", name)
|
||||
}
|
||||
|
||||
cmRegistryMu.Lock()
|
||||
defer cmRegistryMu.Unlock()
|
||||
|
||||
if _, exists := cmRegistry[name]; exists {
|
||||
return fmt.Errorf("context manager %q is already registered", name)
|
||||
}
|
||||
cmRegistry[name] = factory
|
||||
return nil
|
||||
}
|
||||
|
||||
func lookupContextManager(name string) (ContextManagerFactory, bool) {
|
||||
cmRegistryMu.RLock()
|
||||
defer cmRegistryMu.RUnlock()
|
||||
|
||||
f, ok := cmRegistry[name]
|
||||
return f, ok
|
||||
}
|
||||
@@ -0,0 +1,764 @@
|
||||
package agent
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"os"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/bus"
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
"github.com/sipeed/picoclaw/pkg/providers"
|
||||
)
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Factory registry tests
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestRegisterContextManager_Success(t *testing.T) {
|
||||
cleanup := resetCMRegistry()
|
||||
defer cleanup()
|
||||
|
||||
factory := func(cfg json.RawMessage, al *AgentLoop) (ContextManager, error) {
|
||||
return &noopContextManager{}, nil
|
||||
}
|
||||
if err := RegisterContextManager("test_cm", factory); err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
f, ok := lookupContextManager("test_cm")
|
||||
if !ok {
|
||||
t.Fatal("expected factory to be registered")
|
||||
}
|
||||
if f == nil {
|
||||
t.Fatal("expected non-nil factory")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegisterContextManager_EmptyName(t *testing.T) {
|
||||
cleanup := resetCMRegistry()
|
||||
defer cleanup()
|
||||
|
||||
err := RegisterContextManager("", func(cfg json.RawMessage, al *AgentLoop) (ContextManager, error) {
|
||||
return &noopContextManager{}, nil
|
||||
})
|
||||
if err == nil {
|
||||
t.Fatal("expected error for empty name")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "name is required") {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegisterContextManager_NilFactory(t *testing.T) {
|
||||
cleanup := resetCMRegistry()
|
||||
defer cleanup()
|
||||
|
||||
err := RegisterContextManager("nil_factory", nil)
|
||||
if err == nil {
|
||||
t.Fatal("expected error for nil factory")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "factory is nil") {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegisterContextManager_Duplicate(t *testing.T) {
|
||||
cleanup := resetCMRegistry()
|
||||
defer cleanup()
|
||||
|
||||
factory := func(cfg json.RawMessage, al *AgentLoop) (ContextManager, error) {
|
||||
return &noopContextManager{}, nil
|
||||
}
|
||||
if err := RegisterContextManager("dup_cm", factory); err != nil {
|
||||
t.Fatalf("first registration failed: %v", err)
|
||||
}
|
||||
err := RegisterContextManager("dup_cm", factory)
|
||||
if err == nil {
|
||||
t.Fatal("expected error for duplicate registration")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "already registered") {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLookupContextManager_Unknown(t *testing.T) {
|
||||
cleanup := resetCMRegistry()
|
||||
defer cleanup()
|
||||
|
||||
_, ok := lookupContextManager("nonexistent")
|
||||
if ok {
|
||||
t.Fatal("expected lookup to fail for unknown name")
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// resolveContextManager tests
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestResolveContextManager_Default(t *testing.T) {
|
||||
cleanup := resetCMRegistry()
|
||||
defer cleanup()
|
||||
|
||||
cfg := &config.Config{
|
||||
Agents: config.AgentsConfig{
|
||||
Defaults: config.AgentDefaults{
|
||||
Workspace: t.TempDir(),
|
||||
ModelName: "test-model",
|
||||
MaxTokens: 4096,
|
||||
MaxToolIterations: 10,
|
||||
ContextManager: "", // default → legacy
|
||||
},
|
||||
},
|
||||
}
|
||||
al := newCMTestAgentLoop(cfg)
|
||||
|
||||
cm := al.contextManager
|
||||
if cm == nil {
|
||||
t.Fatal("expected non-nil context manager")
|
||||
}
|
||||
if _, ok := cm.(*legacyContextManager); !ok {
|
||||
t.Fatalf("expected *legacyContextManager, got %T", cm)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveContextManager_ExplicitLegacy(t *testing.T) {
|
||||
cleanup := resetCMRegistry()
|
||||
defer cleanup()
|
||||
|
||||
cfg := &config.Config{
|
||||
Agents: config.AgentsConfig{
|
||||
Defaults: config.AgentDefaults{
|
||||
Workspace: t.TempDir(),
|
||||
ModelName: "test-model",
|
||||
MaxTokens: 4096,
|
||||
MaxToolIterations: 10,
|
||||
ContextManager: "legacy",
|
||||
},
|
||||
},
|
||||
}
|
||||
al := newCMTestAgentLoop(cfg)
|
||||
|
||||
if _, ok := al.contextManager.(*legacyContextManager); !ok {
|
||||
t.Fatalf("expected *legacyContextManager, got %T", al.contextManager)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveContextManager_UnknownFallsBackToLegacy(t *testing.T) {
|
||||
cleanup := resetCMRegistry()
|
||||
defer cleanup()
|
||||
|
||||
cfg := &config.Config{
|
||||
Agents: config.AgentsConfig{
|
||||
Defaults: config.AgentDefaults{
|
||||
Workspace: t.TempDir(),
|
||||
ModelName: "test-model",
|
||||
MaxTokens: 4096,
|
||||
MaxToolIterations: 10,
|
||||
ContextManager: "unknown_cm",
|
||||
},
|
||||
},
|
||||
}
|
||||
al := newCMTestAgentLoop(cfg)
|
||||
|
||||
if _, ok := al.contextManager.(*legacyContextManager); !ok {
|
||||
t.Fatalf("expected fallback to *legacyContextManager, got %T", al.contextManager)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveContextManager_RegisteredFactory(t *testing.T) {
|
||||
cleanup := resetCMRegistry()
|
||||
defer cleanup()
|
||||
|
||||
factory := func(cfg json.RawMessage, al *AgentLoop) (ContextManager, error) {
|
||||
return &noopContextManager{}, nil
|
||||
}
|
||||
if err := RegisterContextManager("custom_cm", factory); err != nil {
|
||||
t.Fatalf("register failed: %v", err)
|
||||
}
|
||||
|
||||
cfg := &config.Config{
|
||||
Agents: config.AgentsConfig{
|
||||
Defaults: config.AgentDefaults{
|
||||
Workspace: t.TempDir(),
|
||||
ModelName: "test-model",
|
||||
MaxTokens: 4096,
|
||||
MaxToolIterations: 10,
|
||||
ContextManager: "custom_cm",
|
||||
},
|
||||
},
|
||||
}
|
||||
al := newCMTestAgentLoop(cfg)
|
||||
|
||||
if _, ok := al.contextManager.(*noopContextManager); !ok {
|
||||
t.Fatalf("expected *noopContextManager, got %T", al.contextManager)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveContextManager_FactoryError(t *testing.T) {
|
||||
cleanup := resetCMRegistry()
|
||||
defer cleanup()
|
||||
|
||||
factory := func(cfg json.RawMessage, al *AgentLoop) (ContextManager, error) {
|
||||
return nil, os.ErrPermission
|
||||
}
|
||||
if err := RegisterContextManager("broken_cm", factory); err != nil {
|
||||
t.Fatalf("register failed: %v", err)
|
||||
}
|
||||
|
||||
cfg := &config.Config{
|
||||
Agents: config.AgentsConfig{
|
||||
Defaults: config.AgentDefaults{
|
||||
Workspace: t.TempDir(),
|
||||
ModelName: "test-model",
|
||||
MaxTokens: 4096,
|
||||
MaxToolIterations: 10,
|
||||
ContextManager: "broken_cm",
|
||||
},
|
||||
},
|
||||
}
|
||||
al := newCMTestAgentLoop(cfg)
|
||||
|
||||
// Should fall back to legacy when factory returns error
|
||||
if _, ok := al.contextManager.(*legacyContextManager); !ok {
|
||||
t.Fatalf("expected fallback to *legacyContextManager on factory error, got %T", al.contextManager)
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Legacy Assemble tests
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestLegacyAssemble_Passthrough(t *testing.T) {
|
||||
cfg := testConfig(t)
|
||||
al := newCMTestAgentLoop(cfg)
|
||||
|
||||
agent := al.registry.GetDefaultAgent()
|
||||
if agent == nil {
|
||||
t.Fatal("expected default agent")
|
||||
}
|
||||
|
||||
history := []providers.Message{
|
||||
{Role: "user", Content: "hello"},
|
||||
{Role: "assistant", Content: "hi there"},
|
||||
}
|
||||
agent.Sessions.SetHistory("test-session", history)
|
||||
|
||||
resp, err := al.contextManager.Assemble(context.Background(), &AssembleRequest{
|
||||
SessionKey: "test-session",
|
||||
Budget: 8000,
|
||||
MaxTokens: 4096,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if len(resp.History) != len(history) {
|
||||
t.Fatalf("expected %d messages, got %d", len(history), len(resp.History))
|
||||
}
|
||||
for i, msg := range resp.History {
|
||||
if msg.Content != history[i].Content || msg.Role != history[i].Role {
|
||||
t.Fatalf("message %d mismatch: want %+v, got %+v", i, history[i], msg)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestLegacyAssemble_EmptyHistory(t *testing.T) {
|
||||
cfg := testConfig(t)
|
||||
al := newCMTestAgentLoop(cfg)
|
||||
|
||||
resp, err := al.contextManager.Assemble(context.Background(), &AssembleRequest{
|
||||
SessionKey: "test-session",
|
||||
Budget: 8000,
|
||||
MaxTokens: 4096,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if len(resp.History) != 0 {
|
||||
t.Fatalf("expected empty messages, got %d", len(resp.History))
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Legacy Compact overflow tests
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestLegacyCompact_Overflow(t *testing.T) {
|
||||
cfg := testConfig(t)
|
||||
al := newCMTestAgentLoop(cfg)
|
||||
|
||||
defaultAgent := al.registry.GetDefaultAgent()
|
||||
if defaultAgent == nil {
|
||||
t.Fatal("expected default agent")
|
||||
}
|
||||
|
||||
history := []providers.Message{
|
||||
{Role: "user", Content: "msg 1"},
|
||||
{Role: "assistant", Content: "resp 1"},
|
||||
{Role: "user", Content: "msg 2"},
|
||||
{Role: "assistant", Content: "resp 2"},
|
||||
{Role: "user", Content: "msg 3"},
|
||||
}
|
||||
defaultAgent.Sessions.SetHistory("session-overflow", history)
|
||||
|
||||
sub := al.SubscribeEvents(16)
|
||||
defer al.UnsubscribeEvents(sub.ID)
|
||||
|
||||
err := al.contextManager.Compact(context.Background(), &CompactRequest{
|
||||
SessionKey: "session-overflow",
|
||||
Reason: ContextCompressReasonRetry,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
// After overflow compression, history should be shorter
|
||||
newHistory := defaultAgent.Sessions.GetHistory("session-overflow")
|
||||
if len(newHistory) >= len(history) {
|
||||
t.Fatalf("expected compressed history, got %d messages (was %d)", len(newHistory), len(history))
|
||||
}
|
||||
|
||||
// Summary should contain compression note
|
||||
summary := defaultAgent.Sessions.GetSummary("session-overflow")
|
||||
if !strings.Contains(summary, "Emergency compression") {
|
||||
t.Fatalf("expected compression note in summary, got %q", summary)
|
||||
}
|
||||
|
||||
// Event should carry the proactive reason
|
||||
events := collectEventStream(sub.C)
|
||||
compressEvt, ok := findEvent(events, EventKindContextCompress)
|
||||
if !ok {
|
||||
t.Fatal("expected context compress event")
|
||||
}
|
||||
payload, ok := compressEvt.Payload.(ContextCompressPayload)
|
||||
if !ok {
|
||||
t.Fatalf("expected ContextCompressPayload, got %T", compressEvt.Payload)
|
||||
}
|
||||
if payload.Reason != ContextCompressReasonRetry {
|
||||
t.Fatalf("expected retry reason, got %q", payload.Reason)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLegacyCompact_Overflow_ProactiveReason(t *testing.T) {
|
||||
cfg := testConfig(t)
|
||||
al := newCMTestAgentLoop(cfg)
|
||||
|
||||
defaultAgent := al.registry.GetDefaultAgent()
|
||||
if defaultAgent == nil {
|
||||
t.Fatal("expected default agent")
|
||||
}
|
||||
|
||||
history := []providers.Message{
|
||||
{Role: "user", Content: "msg 1"},
|
||||
{Role: "assistant", Content: "resp 1"},
|
||||
{Role: "user", Content: "msg 2"},
|
||||
{Role: "assistant", Content: "resp 2"},
|
||||
{Role: "user", Content: "msg 3"},
|
||||
}
|
||||
defaultAgent.Sessions.SetHistory("session-proactive", history)
|
||||
|
||||
sub := al.SubscribeEvents(16)
|
||||
defer al.UnsubscribeEvents(sub.ID)
|
||||
|
||||
err := al.contextManager.Compact(context.Background(), &CompactRequest{
|
||||
SessionKey: "session-proactive",
|
||||
Reason: ContextCompressReasonProactive,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
events := collectEventStream(sub.C)
|
||||
compressEvt, ok := findEvent(events, EventKindContextCompress)
|
||||
if !ok {
|
||||
t.Fatal("expected context compress event")
|
||||
}
|
||||
payload, ok := compressEvt.Payload.(ContextCompressPayload)
|
||||
if !ok {
|
||||
t.Fatalf("expected ContextCompressPayload, got %T", compressEvt.Payload)
|
||||
}
|
||||
if payload.Reason != ContextCompressReasonProactive {
|
||||
t.Fatalf("expected proactive reason, got %q", payload.Reason)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLegacyCompact_Overflow_TooShortToCompress(t *testing.T) {
|
||||
cfg := testConfig(t)
|
||||
al := newCMTestAgentLoop(cfg)
|
||||
|
||||
defaultAgent := al.registry.GetDefaultAgent()
|
||||
if defaultAgent == nil {
|
||||
t.Fatal("expected default agent")
|
||||
}
|
||||
|
||||
history := []providers.Message{
|
||||
{Role: "user", Content: "only one"},
|
||||
}
|
||||
defaultAgent.Sessions.SetHistory("session-tiny", history)
|
||||
|
||||
err := al.contextManager.Compact(context.Background(), &CompactRequest{
|
||||
SessionKey: "session-tiny",
|
||||
Reason: ContextCompressReasonRetry,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
// History should be unchanged (too short to compress)
|
||||
newHistory := defaultAgent.Sessions.GetHistory("session-tiny")
|
||||
if len(newHistory) != len(history) {
|
||||
t.Fatalf("expected history unchanged, got %d messages (was %d)", len(newHistory), len(history))
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Legacy Compact post-turn tests
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestLegacyCompact_PostTurn_BelowThreshold(t *testing.T) {
|
||||
cfg := testConfig(t)
|
||||
al := newCMTestAgentLoop(cfg)
|
||||
|
||||
defaultAgent := al.registry.GetDefaultAgent()
|
||||
if defaultAgent == nil {
|
||||
t.Fatal("expected default agent")
|
||||
}
|
||||
|
||||
// Small history, below summarization thresholds
|
||||
history := []providers.Message{
|
||||
{Role: "user", Content: "hi"},
|
||||
{Role: "assistant", Content: "hello"},
|
||||
}
|
||||
defaultAgent.Sessions.SetHistory("session-small", history)
|
||||
|
||||
err := al.contextManager.Compact(context.Background(), &CompactRequest{
|
||||
SessionKey: "session-small",
|
||||
Reason: ContextCompressReasonSummarize,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
// History should remain unchanged
|
||||
newHistory := defaultAgent.Sessions.GetHistory("session-small")
|
||||
if len(newHistory) != len(history) {
|
||||
t.Fatalf("expected unchanged history, got %d messages (was %d)", len(newHistory), len(history))
|
||||
}
|
||||
}
|
||||
|
||||
func TestLegacyCompact_PostTurn_ExceedsMessageThreshold(t *testing.T) {
|
||||
cfg := &config.Config{
|
||||
Agents: config.AgentsConfig{
|
||||
Defaults: config.AgentDefaults{
|
||||
Workspace: t.TempDir(),
|
||||
ModelName: "test-model",
|
||||
MaxTokens: 4096,
|
||||
MaxToolIterations: 10,
|
||||
ContextWindow: 8000,
|
||||
SummarizeMessageThreshold: 2,
|
||||
SummarizeTokenPercent: 75,
|
||||
},
|
||||
},
|
||||
}
|
||||
msgBus := bus.NewMessageBus()
|
||||
al := NewAgentLoop(cfg, msgBus, &simpleMockProvider{response: "summary"})
|
||||
|
||||
defaultAgent := al.registry.GetDefaultAgent()
|
||||
if defaultAgent == nil {
|
||||
t.Fatal("expected default agent")
|
||||
}
|
||||
|
||||
// 6 messages > threshold of 2
|
||||
history := []providers.Message{
|
||||
{Role: "user", Content: "q1"},
|
||||
{Role: "assistant", Content: "a1"},
|
||||
{Role: "user", Content: "q2"},
|
||||
{Role: "assistant", Content: "a2"},
|
||||
{Role: "user", Content: "q3"},
|
||||
{Role: "assistant", Content: "a3"},
|
||||
}
|
||||
defaultAgent.Sessions.SetHistory("session-threshold", history)
|
||||
|
||||
err := al.contextManager.Compact(context.Background(), &CompactRequest{
|
||||
SessionKey: "session-threshold",
|
||||
Reason: ContextCompressReasonSummarize,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
// Wait for async summarization to complete via event
|
||||
sub := al.SubscribeEvents(16)
|
||||
defer al.UnsubscribeEvents(sub.ID)
|
||||
|
||||
waitForEvent(t, sub.C, 5*time.Second, func(evt Event) bool {
|
||||
return evt.Kind == EventKindSessionSummarize
|
||||
})
|
||||
|
||||
newHistory := defaultAgent.Sessions.GetHistory("session-threshold")
|
||||
if len(newHistory) >= len(history) {
|
||||
t.Fatalf("expected summarization to reduce history from %d messages, got %d", len(history), len(newHistory))
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Legacy Ingest tests
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestLegacyIngest_NoOp(t *testing.T) {
|
||||
cfg := testConfig(t)
|
||||
al := newCMTestAgentLoop(cfg)
|
||||
|
||||
err := al.contextManager.Ingest(context.Background(), &IngestRequest{
|
||||
SessionKey: "session-ingest",
|
||||
Message: providers.Message{Role: "user", Content: "test"},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Mock ContextManager — verifies dispatch through AgentLoop
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestAgentLoop_UsesCustomContextManager(t *testing.T) {
|
||||
cleanup := resetCMRegistry()
|
||||
defer cleanup()
|
||||
|
||||
mock := &trackingContextManager{}
|
||||
factory := func(cfg json.RawMessage, al *AgentLoop) (ContextManager, error) {
|
||||
return mock, nil
|
||||
}
|
||||
if err := RegisterContextManager("tracking_cm", factory); err != nil {
|
||||
t.Fatalf("register failed: %v", err)
|
||||
}
|
||||
|
||||
cfg := &config.Config{
|
||||
Agents: config.AgentsConfig{
|
||||
Defaults: config.AgentDefaults{
|
||||
Workspace: t.TempDir(),
|
||||
ModelName: "test-model",
|
||||
MaxTokens: 4096,
|
||||
MaxToolIterations: 10,
|
||||
ContextManager: "tracking_cm",
|
||||
},
|
||||
},
|
||||
}
|
||||
al := newCMTestAgentLoop(cfg)
|
||||
|
||||
// Verify the mock was installed
|
||||
if al.contextManager != mock {
|
||||
t.Fatalf("expected mock context manager, got %T", al.contextManager)
|
||||
}
|
||||
|
||||
// Direct method calls
|
||||
_, err := mock.Assemble(context.Background(), &AssembleRequest{
|
||||
SessionKey: "s1",
|
||||
Budget: 8000,
|
||||
MaxTokens: 4096,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Assemble error: %v", err)
|
||||
}
|
||||
if mock.assembleCalls.Load() != 1 {
|
||||
t.Fatalf("expected 1 assemble call, got %d", mock.assembleCalls.Load())
|
||||
}
|
||||
|
||||
err = mock.Compact(context.Background(), &CompactRequest{
|
||||
SessionKey: "s1",
|
||||
Reason: ContextCompressReasonRetry,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Compact error: %v", err)
|
||||
}
|
||||
if mock.compactCalls.Load() != 1 {
|
||||
t.Fatalf("expected 1 compact call, got %d", mock.compactCalls.Load())
|
||||
}
|
||||
|
||||
err = mock.Ingest(context.Background(), &IngestRequest{
|
||||
SessionKey: "s1",
|
||||
Message: providers.Message{Role: "user", Content: "test"},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Ingest error: %v", err)
|
||||
}
|
||||
if mock.ingestCalls.Load() != 1 {
|
||||
t.Fatalf("expected 1 ingest call, got %d", mock.ingestCalls.Load())
|
||||
}
|
||||
}
|
||||
|
||||
func TestIngestCalledDuringTurn(t *testing.T) {
|
||||
cleanup := resetCMRegistry()
|
||||
defer cleanup()
|
||||
|
||||
mock := &trackingContextManager{}
|
||||
factory := func(cfg json.RawMessage, al *AgentLoop) (ContextManager, error) {
|
||||
return mock, nil
|
||||
}
|
||||
if err := RegisterContextManager("ingest_track_cm", factory); err != nil {
|
||||
t.Fatalf("register failed: %v", err)
|
||||
}
|
||||
|
||||
cfg := &config.Config{
|
||||
Agents: config.AgentsConfig{
|
||||
Defaults: config.AgentDefaults{
|
||||
Workspace: t.TempDir(),
|
||||
ModelName: "test-model",
|
||||
MaxTokens: 4096,
|
||||
MaxToolIterations: 10,
|
||||
ContextManager: "ingest_track_cm",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
msgBus := bus.NewMessageBus()
|
||||
al := NewAgentLoop(cfg, msgBus, &simpleMockProvider{response: "done"})
|
||||
defaultAgent := al.registry.GetDefaultAgent()
|
||||
if defaultAgent == nil {
|
||||
t.Fatal("expected default agent")
|
||||
}
|
||||
|
||||
// Run a turn — ingestMessage is called for user message and final assistant message
|
||||
_, err := al.runAgentLoop(context.Background(), defaultAgent, processOptions{
|
||||
SessionKey: "session-ingest-turn",
|
||||
Channel: "cli",
|
||||
ChatID: "direct",
|
||||
UserMessage: "test ingest",
|
||||
DefaultResponse: defaultResponse,
|
||||
EnableSummary: false,
|
||||
SendResponse: false,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("runAgentLoop failed: %v", err)
|
||||
}
|
||||
|
||||
// Should have at least 2 ingest calls: user message + final assistant message
|
||||
if mock.ingestCalls.Load() < 2 {
|
||||
t.Fatalf("expected >= 2 ingest calls during turn, got %d", mock.ingestCalls.Load())
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// forceCompression edge cases (via legacy Compact)
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestLegacyCompact_Overflow_SingleTurnKeepsLastUserMessage(t *testing.T) {
|
||||
cfg := testConfig(t)
|
||||
al := newCMTestAgentLoop(cfg)
|
||||
|
||||
defaultAgent := al.registry.GetDefaultAgent()
|
||||
if defaultAgent == nil {
|
||||
t.Fatal("expected default agent")
|
||||
}
|
||||
|
||||
// History with only 2 messages — forceCompression should still handle it
|
||||
history := []providers.Message{
|
||||
{Role: "user", Content: "first question"},
|
||||
{Role: "assistant", Content: "first answer"},
|
||||
}
|
||||
defaultAgent.Sessions.SetHistory("session-2msg", history)
|
||||
|
||||
err := al.contextManager.Compact(context.Background(), &CompactRequest{
|
||||
SessionKey: "session-2msg",
|
||||
Reason: ContextCompressReasonRetry,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
newHistory := defaultAgent.Sessions.GetHistory("session-2msg")
|
||||
// With 2 messages, forceCompression returns false (len <= 2), so no compression
|
||||
if len(newHistory) != len(history) {
|
||||
t.Fatalf("expected no compression for 2-message history, got %d", len(newHistory))
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Test helpers
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// noopContextManager is a minimal ContextManager that does nothing.
|
||||
type noopContextManager struct{}
|
||||
|
||||
func (m *noopContextManager) Assemble(_ context.Context, req *AssembleRequest) (*AssembleResponse, error) {
|
||||
return &AssembleResponse{}, nil
|
||||
}
|
||||
func (m *noopContextManager) Compact(_ context.Context, _ *CompactRequest) error { return nil }
|
||||
func (m *noopContextManager) Ingest(_ context.Context, _ *IngestRequest) error { return nil }
|
||||
|
||||
// trackingContextManager tracks call counts for each method.
|
||||
type trackingContextManager struct {
|
||||
assembleCalls atomic.Int64
|
||||
compactCalls atomic.Int64
|
||||
ingestCalls atomic.Int64
|
||||
mu sync.Mutex
|
||||
lastAssemble *AssembleRequest
|
||||
lastCompact *CompactRequest
|
||||
lastIngest *IngestRequest
|
||||
}
|
||||
|
||||
func (m *trackingContextManager) Assemble(_ context.Context, req *AssembleRequest) (*AssembleResponse, error) {
|
||||
m.assembleCalls.Add(1)
|
||||
m.mu.Lock()
|
||||
m.lastAssemble = req
|
||||
m.mu.Unlock()
|
||||
return &AssembleResponse{}, nil
|
||||
}
|
||||
|
||||
func (m *trackingContextManager) Compact(_ context.Context, req *CompactRequest) error {
|
||||
m.compactCalls.Add(1)
|
||||
m.mu.Lock()
|
||||
m.lastCompact = req
|
||||
m.mu.Unlock()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *trackingContextManager) Ingest(_ context.Context, req *IngestRequest) error {
|
||||
m.ingestCalls.Add(1)
|
||||
m.mu.Lock()
|
||||
m.lastIngest = req
|
||||
m.mu.Unlock()
|
||||
return nil
|
||||
}
|
||||
|
||||
// resetCMRegistry clears the global factory registry and returns a cleanup
|
||||
// function that restores the original state after the test.
|
||||
func resetCMRegistry() func() {
|
||||
cmRegistryMu.Lock()
|
||||
original := make(map[string]ContextManagerFactory, len(cmRegistry))
|
||||
for k, v := range cmRegistry {
|
||||
original[k] = v
|
||||
}
|
||||
cmRegistry = make(map[string]ContextManagerFactory)
|
||||
cmRegistryMu.Unlock()
|
||||
|
||||
return func() {
|
||||
cmRegistryMu.Lock()
|
||||
cmRegistry = original
|
||||
cmRegistryMu.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
func testConfig(t *testing.T) *config.Config {
|
||||
t.Helper()
|
||||
return &config.Config{
|
||||
Agents: config.AgentsConfig{
|
||||
Defaults: config.AgentDefaults{
|
||||
Workspace: t.TempDir(),
|
||||
ModelName: "test-model",
|
||||
MaxTokens: 4096,
|
||||
MaxToolIterations: 10,
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func newCMTestAgentLoop(cfg *config.Config) *AgentLoop {
|
||||
msgBus := bus.NewMessageBus()
|
||||
return NewAgentLoop(cfg, msgBus, &simpleMockProvider{response: "test"})
|
||||
}
|
||||
@@ -472,8 +472,9 @@ func TestAgentLoop_EmitsSessionSummarizeEvent(t *testing.T) {
|
||||
sub := al.SubscribeEvents(16)
|
||||
defer al.UnsubscribeEvents(sub.ID)
|
||||
|
||||
turnScope := al.newTurnEventScope(defaultAgent.ID, "session-1")
|
||||
al.summarizeSession(defaultAgent, "session-1", turnScope)
|
||||
// Use legacyContextManager's summarizeSession via contextManager interface
|
||||
lcm := &legacyContextManager{al: al}
|
||||
lcm.summarizeSession(defaultAgent, "session-1")
|
||||
|
||||
events := collectEventStream(sub.C)
|
||||
summaryEvt, ok := findEvent(events, EventKindSessionSummarize)
|
||||
|
||||
@@ -167,6 +167,8 @@ const (
|
||||
ContextCompressReasonProactive ContextCompressReason = "proactive_budget"
|
||||
// ContextCompressReasonRetry indicates compression during context-error retry handling.
|
||||
ContextCompressReasonRetry ContextCompressReason = "llm_retry"
|
||||
// ContextCompressReasonSummarize indicates post-turn async summarization.
|
||||
ContextCompressReasonSummarize ContextCompressReason = "summarize"
|
||||
)
|
||||
|
||||
// ContextCompressPayload describes a forced history compression.
|
||||
|
||||
+81
-363
@@ -48,7 +48,7 @@ type AgentLoop struct {
|
||||
|
||||
// Runtime state
|
||||
running atomic.Bool
|
||||
summarizing sync.Map
|
||||
contextManager ContextManager
|
||||
fallback *providers.FallbackChain
|
||||
channelManager *channels.Manager
|
||||
mediaStore media.MediaStore
|
||||
@@ -137,13 +137,13 @@ func NewAgentLoop(
|
||||
registry: registry,
|
||||
state: stateManager,
|
||||
eventBus: eventBus,
|
||||
summarizing: sync.Map{},
|
||||
fallback: fallbackChain,
|
||||
cmdRegistry: commands.NewRegistry(commands.BuiltinDefinitions()),
|
||||
steering: newSteeringQueue(parseSteeringMode(cfg.Agents.Defaults.SteeringMode)),
|
||||
}
|
||||
al.hooks = NewHookManager(eventBus)
|
||||
configureHookManagerFromConfig(al.hooks, cfg)
|
||||
al.contextManager = al.resolveContextManager()
|
||||
|
||||
// Register shared tools to all agents (now that al is created)
|
||||
registerSharedTools(al, cfg, msgBus, registry, provider)
|
||||
@@ -1690,8 +1690,15 @@ func (al *AgentLoop) runTurn(ctx context.Context, ts *turnState) (turnResult, er
|
||||
var history []providers.Message
|
||||
var summary string
|
||||
if !ts.opts.NoHistory {
|
||||
history = ts.agent.Sessions.GetHistory(ts.sessionKey)
|
||||
summary = ts.agent.Sessions.GetSummary(ts.sessionKey)
|
||||
// ContextManager assembles budget-aware history and summary.
|
||||
if resp, err := al.contextManager.Assemble(turnCtx, &AssembleRequest{
|
||||
SessionKey: ts.sessionKey,
|
||||
Budget: ts.agent.ContextWindow,
|
||||
MaxTokens: ts.agent.MaxTokens,
|
||||
}); err == nil && resp != nil {
|
||||
history = resp.History
|
||||
summary = resp.Summary
|
||||
}
|
||||
}
|
||||
ts.captureRestorePoint(history, summary)
|
||||
|
||||
@@ -1716,22 +1723,27 @@ func (al *AgentLoop) runTurn(ctx context.Context, ts *turnState) (turnResult, er
|
||||
if isOverContextBudget(ts.agent.ContextWindow, messages, toolDefs, ts.agent.MaxTokens) {
|
||||
logger.WarnCF("agent", "Proactive compression: context budget exceeded before LLM call",
|
||||
map[string]any{"session_key": ts.sessionKey})
|
||||
if compression, ok := al.forceCompression(ts.agent, ts.sessionKey); ok {
|
||||
al.emitEvent(
|
||||
EventKindContextCompress,
|
||||
ts.eventMeta("runTurn", "turn.context.compress"),
|
||||
ContextCompressPayload{
|
||||
Reason: ContextCompressReasonProactive,
|
||||
DroppedMessages: compression.DroppedMessages,
|
||||
RemainingMessages: compression.RemainingMessages,
|
||||
},
|
||||
)
|
||||
ts.refreshRestorePointFromSession(ts.agent)
|
||||
if err := al.contextManager.Compact(turnCtx, &CompactRequest{
|
||||
SessionKey: ts.sessionKey,
|
||||
Reason: ContextCompressReasonProactive,
|
||||
}); err != nil {
|
||||
logger.WarnCF("agent", "Proactive compact failed", map[string]any{
|
||||
"session_key": ts.sessionKey,
|
||||
"error": err.Error(),
|
||||
})
|
||||
}
|
||||
ts.refreshRestorePointFromSession(ts.agent)
|
||||
// Re-assemble from CM after compact.
|
||||
if resp, err := al.contextManager.Assemble(turnCtx, &AssembleRequest{
|
||||
SessionKey: ts.sessionKey,
|
||||
Budget: ts.agent.ContextWindow,
|
||||
MaxTokens: ts.agent.MaxTokens,
|
||||
}); err == nil && resp != nil {
|
||||
history = resp.History
|
||||
summary = resp.Summary
|
||||
}
|
||||
newHistory := ts.agent.Sessions.GetHistory(ts.sessionKey)
|
||||
newSummary := ts.agent.Sessions.GetSummary(ts.sessionKey)
|
||||
messages = ts.agent.ContextBuilder.BuildMessages(
|
||||
newHistory, newSummary, ts.userMessage,
|
||||
history, summary, ts.userMessage,
|
||||
ts.media, ts.channel, ts.chatID,
|
||||
ts.opts.SenderID, ts.opts.SenderDisplayName,
|
||||
activeSkillNames(ts.agent, ts.opts)...,
|
||||
@@ -1753,6 +1765,7 @@ func (al *AgentLoop) runTurn(ctx context.Context, ts *turnState) (turnResult, er
|
||||
ts.agent.Sessions.AddMessage(ts.sessionKey, rootMsg.Role, rootMsg.Content)
|
||||
}
|
||||
ts.recordPersistedMessage(rootMsg)
|
||||
ts.ingestMessage(turnCtx, al, rootMsg)
|
||||
}
|
||||
|
||||
activeCandidates, activeModel, usedLight := al.selectCandidates(ts.agent, ts.userMessage, messages)
|
||||
@@ -2096,23 +2109,27 @@ turnLoop:
|
||||
})
|
||||
}
|
||||
|
||||
if compression, ok := al.forceCompression(ts.agent, ts.sessionKey); ok {
|
||||
al.emitEvent(
|
||||
EventKindContextCompress,
|
||||
ts.eventMeta("runTurn", "turn.context.compress"),
|
||||
ContextCompressPayload{
|
||||
Reason: ContextCompressReasonRetry,
|
||||
DroppedMessages: compression.DroppedMessages,
|
||||
RemainingMessages: compression.RemainingMessages,
|
||||
},
|
||||
)
|
||||
ts.refreshRestorePointFromSession(ts.agent)
|
||||
if compactErr := al.contextManager.Compact(turnCtx, &CompactRequest{
|
||||
SessionKey: ts.sessionKey,
|
||||
Reason: ContextCompressReasonRetry,
|
||||
}); compactErr != nil {
|
||||
logger.WarnCF("agent", "Context overflow compact failed", map[string]any{
|
||||
"session_key": ts.sessionKey,
|
||||
"error": compactErr.Error(),
|
||||
})
|
||||
}
|
||||
ts.refreshRestorePointFromSession(ts.agent)
|
||||
// Re-assemble from CM after compact.
|
||||
if asmResp, asmErr := al.contextManager.Assemble(turnCtx, &AssembleRequest{
|
||||
SessionKey: ts.sessionKey,
|
||||
Budget: ts.agent.ContextWindow,
|
||||
MaxTokens: ts.agent.MaxTokens,
|
||||
}); asmErr == nil && asmResp != nil {
|
||||
history = asmResp.History
|
||||
summary = asmResp.Summary
|
||||
}
|
||||
|
||||
newHistory := ts.agent.Sessions.GetHistory(ts.sessionKey)
|
||||
newSummary := ts.agent.Sessions.GetSummary(ts.sessionKey)
|
||||
messages = ts.agent.ContextBuilder.BuildMessages(
|
||||
newHistory, newSummary, "",
|
||||
history, summary, "",
|
||||
nil, ts.channel, ts.chatID, ts.opts.SenderID, ts.opts.SenderDisplayName,
|
||||
activeSkillNames(ts.agent, ts.opts)...,
|
||||
)
|
||||
@@ -2285,6 +2302,7 @@ turnLoop:
|
||||
if !ts.opts.NoHistory {
|
||||
ts.agent.Sessions.AddFullMessage(ts.sessionKey, assistantMsg)
|
||||
ts.recordPersistedMessage(assistantMsg)
|
||||
ts.ingestMessage(turnCtx, al, assistantMsg)
|
||||
}
|
||||
|
||||
ts.setPhase(TurnPhaseTools)
|
||||
@@ -2624,6 +2642,7 @@ turnLoop:
|
||||
if !ts.opts.NoHistory {
|
||||
ts.agent.Sessions.AddFullMessage(ts.sessionKey, toolResultMsg)
|
||||
ts.recordPersistedMessage(toolResultMsg)
|
||||
ts.ingestMessage(turnCtx, al, toolResultMsg)
|
||||
}
|
||||
|
||||
if steerMsgs := al.dequeueSteeringMessagesForScope(ts.sessionKey); len(steerMsgs) > 0 {
|
||||
@@ -2723,6 +2742,7 @@ turnLoop:
|
||||
if !ts.opts.NoHistory {
|
||||
ts.agent.Sessions.AddMessage(ts.sessionKey, summaryMsg.Role, summaryMsg.Content)
|
||||
ts.recordPersistedMessage(summaryMsg)
|
||||
ts.ingestMessage(turnCtx, al, summaryMsg)
|
||||
if err := ts.agent.Sessions.Save(ts.sessionKey); err != nil {
|
||||
turnStatus = TurnEndStatusError
|
||||
al.emitEvent(
|
||||
@@ -2737,7 +2757,7 @@ turnLoop:
|
||||
}
|
||||
}
|
||||
if ts.opts.EnableSummary {
|
||||
al.maybeSummarize(ts.agent, ts.sessionKey, ts.scope)
|
||||
al.contextManager.Compact(turnCtx, &CompactRequest{SessionKey: ts.sessionKey, Reason: ContextCompressReasonSummarize})
|
||||
}
|
||||
|
||||
ts.setPhase(TurnPhaseCompleted)
|
||||
@@ -2792,6 +2812,7 @@ turnLoop:
|
||||
finalMsg := providers.Message{Role: "assistant", Content: finalContent}
|
||||
ts.agent.Sessions.AddMessage(ts.sessionKey, finalMsg.Role, finalMsg.Content)
|
||||
ts.recordPersistedMessage(finalMsg)
|
||||
ts.ingestMessage(turnCtx, al, finalMsg)
|
||||
if err := ts.agent.Sessions.Save(ts.sessionKey); err != nil {
|
||||
turnStatus = TurnEndStatusError
|
||||
al.emitEvent(
|
||||
@@ -2807,7 +2828,13 @@ turnLoop:
|
||||
}
|
||||
|
||||
if ts.opts.EnableSummary {
|
||||
al.maybeSummarize(ts.agent, ts.sessionKey, ts.scope)
|
||||
al.contextManager.Compact(
|
||||
turnCtx,
|
||||
&CompactRequest{
|
||||
SessionKey: ts.sessionKey,
|
||||
Reason: ContextCompressReasonSummarize,
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
ts.setPhase(TurnPhaseCompleted)
|
||||
@@ -2886,103 +2913,28 @@ func (al *AgentLoop) selectCandidates(
|
||||
return agent.LightCandidates, resolvedCandidateModel(agent.LightCandidates, agent.Router.LightModel()), true
|
||||
}
|
||||
|
||||
// maybeSummarize triggers summarization if the session history exceeds thresholds.
|
||||
func (al *AgentLoop) maybeSummarize(agent *AgentInstance, sessionKey string, turnScope turnEventScope) {
|
||||
newHistory := agent.Sessions.GetHistory(sessionKey)
|
||||
tokenEstimate := al.estimateTokens(newHistory)
|
||||
threshold := agent.ContextWindow * agent.SummarizeTokenPercent / 100
|
||||
|
||||
if len(newHistory) > agent.SummarizeMessageThreshold || tokenEstimate > threshold {
|
||||
summarizeKey := agent.ID + ":" + sessionKey
|
||||
if _, loading := al.summarizing.LoadOrStore(summarizeKey, true); !loading {
|
||||
go func() {
|
||||
defer al.summarizing.Delete(summarizeKey)
|
||||
logger.Debug("Memory threshold reached. Optimizing conversation history...")
|
||||
al.summarizeSession(agent, sessionKey, turnScope)
|
||||
}()
|
||||
}
|
||||
// resolveContextManager selects the ContextManager implementation based on config.
|
||||
func (al *AgentLoop) resolveContextManager() ContextManager {
|
||||
name := al.cfg.Agents.Defaults.ContextManager
|
||||
if name == "" || name == "legacy" {
|
||||
return &legacyContextManager{al: al}
|
||||
}
|
||||
}
|
||||
|
||||
type compressionResult struct {
|
||||
DroppedMessages int
|
||||
RemainingMessages int
|
||||
}
|
||||
|
||||
// forceCompression aggressively reduces context when the limit is hit.
|
||||
// It drops the oldest ~50% of Turns (a Turn is a complete user→LLM→response
|
||||
// cycle, as defined in #1316), so tool-call sequences are never split.
|
||||
//
|
||||
// If the history is a single Turn with no safe split point, the function
|
||||
// falls back to keeping only the most recent user message. This breaks
|
||||
// Turn atomicity as a last resort to avoid a context-exceeded loop.
|
||||
//
|
||||
// Session history contains only user/assistant/tool messages — the system
|
||||
// prompt is built dynamically by BuildMessages and is NOT stored here.
|
||||
// The compression note is recorded in the session summary so that
|
||||
// BuildMessages can include it in the next system prompt.
|
||||
func (al *AgentLoop) forceCompression(agent *AgentInstance, sessionKey string) (compressionResult, bool) {
|
||||
history := agent.Sessions.GetHistory(sessionKey)
|
||||
if len(history) <= 2 {
|
||||
return compressionResult{}, false
|
||||
factory, ok := lookupContextManager(name)
|
||||
if !ok {
|
||||
logger.WarnCF("agent", "Unknown context manager, falling back to legacy", map[string]any{
|
||||
"name": name,
|
||||
})
|
||||
return &legacyContextManager{al: al}
|
||||
}
|
||||
|
||||
// Split at a Turn boundary so no tool-call sequence is torn apart.
|
||||
// parseTurnBoundaries gives us the start of each Turn; we drop the
|
||||
// oldest half of Turns and keep the most recent ones.
|
||||
turns := parseTurnBoundaries(history)
|
||||
var mid int
|
||||
if len(turns) >= 2 {
|
||||
mid = turns[len(turns)/2]
|
||||
} else {
|
||||
// Fewer than 2 Turns — fall back to message-level midpoint
|
||||
// aligned to the nearest Turn boundary.
|
||||
mid = findSafeBoundary(history, len(history)/2)
|
||||
cm, err := factory(al.cfg.Agents.Defaults.ContextManagerConfig, al)
|
||||
if err != nil {
|
||||
logger.WarnCF("agent", "Failed to create context manager, falling back to legacy", map[string]any{
|
||||
"name": name,
|
||||
"error": err.Error(),
|
||||
})
|
||||
return &legacyContextManager{al: al}
|
||||
}
|
||||
var keptHistory []providers.Message
|
||||
if mid <= 0 {
|
||||
// No safe Turn boundary — the entire history is a single Turn
|
||||
// (e.g. one user message followed by a massive tool response).
|
||||
// Keeping everything would leave the agent stuck in a context-
|
||||
// exceeded loop, so fall back to keeping only the most recent
|
||||
// user message. This breaks Turn atomicity as a last resort.
|
||||
for i := len(history) - 1; i >= 0; i-- {
|
||||
if history[i].Role == "user" {
|
||||
keptHistory = []providers.Message{history[i]}
|
||||
break
|
||||
}
|
||||
}
|
||||
} else {
|
||||
keptHistory = history[mid:]
|
||||
}
|
||||
|
||||
droppedCount := len(history) - len(keptHistory)
|
||||
|
||||
// Record compression in the session summary so BuildMessages includes it
|
||||
// in the system prompt. We do not modify history messages themselves.
|
||||
existingSummary := agent.Sessions.GetSummary(sessionKey)
|
||||
compressionNote := fmt.Sprintf(
|
||||
"[Emergency compression dropped %d oldest messages due to context limit]",
|
||||
droppedCount,
|
||||
)
|
||||
if existingSummary != "" {
|
||||
compressionNote = existingSummary + "\n\n" + compressionNote
|
||||
}
|
||||
agent.Sessions.SetSummary(sessionKey, compressionNote)
|
||||
|
||||
agent.Sessions.SetHistory(sessionKey, keptHistory)
|
||||
agent.Sessions.Save(sessionKey)
|
||||
|
||||
logger.WarnCF("agent", "Forced compression executed", map[string]any{
|
||||
"session_key": sessionKey,
|
||||
"dropped_msgs": droppedCount,
|
||||
"new_count": len(keptHistory),
|
||||
})
|
||||
|
||||
return compressionResult{
|
||||
DroppedMessages: droppedCount,
|
||||
RemainingMessages: len(keptHistory),
|
||||
}, true
|
||||
return cm
|
||||
}
|
||||
|
||||
// GetStartupInfo returns information about loaded tools and skills for logging.
|
||||
@@ -3074,247 +3026,13 @@ func formatToolsForLog(toolDefs []providers.ToolDefinition) string {
|
||||
}
|
||||
|
||||
// summarizeSession summarizes the conversation history for a session.
|
||||
func (al *AgentLoop) summarizeSession(agent *AgentInstance, sessionKey string, turnScope turnEventScope) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
||||
defer cancel()
|
||||
|
||||
history := agent.Sessions.GetHistory(sessionKey)
|
||||
summary := agent.Sessions.GetSummary(sessionKey)
|
||||
|
||||
// Keep the most recent Turns for continuity, aligned to a Turn boundary
|
||||
// so that no tool-call sequence is split.
|
||||
if len(history) <= 4 {
|
||||
return
|
||||
}
|
||||
|
||||
safeCut := findSafeBoundary(history, len(history)-4)
|
||||
if safeCut <= 0 {
|
||||
return
|
||||
}
|
||||
keepCount := len(history) - safeCut
|
||||
toSummarize := history[:safeCut]
|
||||
|
||||
// Oversized Message Guard
|
||||
maxMessageTokens := agent.ContextWindow / 2
|
||||
validMessages := make([]providers.Message, 0)
|
||||
omitted := false
|
||||
|
||||
for _, m := range toSummarize {
|
||||
if m.Role != "user" && m.Role != "assistant" {
|
||||
continue
|
||||
}
|
||||
msgTokens := len(m.Content) / 2
|
||||
if msgTokens > maxMessageTokens {
|
||||
omitted = true
|
||||
continue
|
||||
}
|
||||
validMessages = append(validMessages, m)
|
||||
}
|
||||
|
||||
if len(validMessages) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
const (
|
||||
maxSummarizationMessages = 10
|
||||
llmMaxRetries = 3
|
||||
llmTemperature = 0.3
|
||||
fallbackMaxContentLength = 200
|
||||
)
|
||||
|
||||
// Multi-Part Summarization
|
||||
var finalSummary string
|
||||
if len(validMessages) > maxSummarizationMessages {
|
||||
mid := len(validMessages) / 2
|
||||
|
||||
mid = al.findNearestUserMessage(validMessages, mid)
|
||||
|
||||
part1 := validMessages[:mid]
|
||||
part2 := validMessages[mid:]
|
||||
|
||||
s1, _ := al.summarizeBatch(ctx, agent, part1, "")
|
||||
s2, _ := al.summarizeBatch(ctx, agent, part2, "")
|
||||
|
||||
mergePrompt := fmt.Sprintf(
|
||||
"Merge these two conversation summaries into one cohesive summary:\n\n1: %s\n\n2: %s",
|
||||
s1,
|
||||
s2,
|
||||
)
|
||||
|
||||
resp, err := al.retryLLMCall(ctx, agent, mergePrompt, llmMaxRetries)
|
||||
if err == nil && resp.Content != "" {
|
||||
finalSummary = resp.Content
|
||||
} else {
|
||||
finalSummary = s1 + " " + s2
|
||||
}
|
||||
} else {
|
||||
finalSummary, _ = al.summarizeBatch(ctx, agent, validMessages, summary)
|
||||
}
|
||||
|
||||
if omitted && finalSummary != "" {
|
||||
finalSummary += "\n[Note: Some oversized messages were omitted from this summary for efficiency.]"
|
||||
}
|
||||
|
||||
if finalSummary != "" {
|
||||
agent.Sessions.SetSummary(sessionKey, finalSummary)
|
||||
agent.Sessions.TruncateHistory(sessionKey, keepCount)
|
||||
agent.Sessions.Save(sessionKey)
|
||||
al.emitEvent(
|
||||
EventKindSessionSummarize,
|
||||
turnScope.meta(0, "summarizeSession", "turn.session.summarize"),
|
||||
SessionSummarizePayload{
|
||||
SummarizedMessages: len(validMessages),
|
||||
KeptMessages: keepCount,
|
||||
SummaryLen: len(finalSummary),
|
||||
OmittedOversized: omitted,
|
||||
},
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
// findNearestUserMessage finds the nearest user message to the given index.
|
||||
// It searches backward first, then forward if no user message is found.
|
||||
func (al *AgentLoop) findNearestUserMessage(messages []providers.Message, mid int) int {
|
||||
originalMid := mid
|
||||
|
||||
for mid > 0 && messages[mid].Role != "user" {
|
||||
mid--
|
||||
}
|
||||
|
||||
if messages[mid].Role == "user" {
|
||||
return mid
|
||||
}
|
||||
|
||||
mid = originalMid
|
||||
for mid < len(messages) && messages[mid].Role != "user" {
|
||||
mid++
|
||||
}
|
||||
|
||||
if mid < len(messages) {
|
||||
return mid
|
||||
}
|
||||
|
||||
return originalMid
|
||||
}
|
||||
|
||||
// retryLLMCall calls the LLM with retry logic.
|
||||
func (al *AgentLoop) retryLLMCall(
|
||||
ctx context.Context,
|
||||
agent *AgentInstance,
|
||||
prompt string,
|
||||
maxRetries int,
|
||||
) (*providers.LLMResponse, error) {
|
||||
const (
|
||||
llmTemperature = 0.3
|
||||
)
|
||||
|
||||
var resp *providers.LLMResponse
|
||||
var err error
|
||||
|
||||
for attempt := 0; attempt < maxRetries; attempt++ {
|
||||
al.activeRequests.Add(1)
|
||||
resp, err = func() (*providers.LLMResponse, error) {
|
||||
defer al.activeRequests.Done()
|
||||
return agent.Provider.Chat(
|
||||
ctx,
|
||||
[]providers.Message{{Role: "user", Content: prompt}},
|
||||
nil,
|
||||
agent.Model,
|
||||
map[string]any{
|
||||
"max_tokens": agent.MaxTokens,
|
||||
"temperature": llmTemperature,
|
||||
"prompt_cache_key": agent.ID,
|
||||
},
|
||||
)
|
||||
}()
|
||||
|
||||
if err == nil && resp != nil && resp.Content != "" {
|
||||
return resp, nil
|
||||
}
|
||||
if attempt < maxRetries-1 {
|
||||
time.Sleep(time.Duration(attempt+1) * 100 * time.Millisecond)
|
||||
}
|
||||
}
|
||||
|
||||
return resp, err
|
||||
}
|
||||
|
||||
// summarizeBatch summarizes a batch of messages.
|
||||
func (al *AgentLoop) summarizeBatch(
|
||||
ctx context.Context,
|
||||
agent *AgentInstance,
|
||||
batch []providers.Message,
|
||||
existingSummary string,
|
||||
) (string, error) {
|
||||
const (
|
||||
llmMaxRetries = 3
|
||||
llmTemperature = 0.3
|
||||
fallbackMinContentLength = 200
|
||||
fallbackMaxContentPercent = 10
|
||||
)
|
||||
|
||||
var sb strings.Builder
|
||||
sb.WriteString(
|
||||
"Provide a concise summary of this conversation segment, preserving core context and key points.\n",
|
||||
)
|
||||
if existingSummary != "" {
|
||||
sb.WriteString("Existing context: ")
|
||||
sb.WriteString(existingSummary)
|
||||
sb.WriteString("\n")
|
||||
}
|
||||
sb.WriteString("\nCONVERSATION:\n")
|
||||
for _, m := range batch {
|
||||
fmt.Fprintf(&sb, "%s: %s\n", m.Role, m.Content)
|
||||
}
|
||||
prompt := sb.String()
|
||||
|
||||
response, err := al.retryLLMCall(ctx, agent, prompt, llmMaxRetries)
|
||||
if err == nil && response.Content != "" {
|
||||
return strings.TrimSpace(response.Content), nil
|
||||
}
|
||||
|
||||
var fallback strings.Builder
|
||||
fallback.WriteString("Conversation summary: ")
|
||||
for i, m := range batch {
|
||||
if i > 0 {
|
||||
fallback.WriteString(" | ")
|
||||
}
|
||||
content := strings.TrimSpace(m.Content)
|
||||
runes := []rune(content)
|
||||
if len(runes) == 0 {
|
||||
fallback.WriteString(fmt.Sprintf("%s: ", m.Role))
|
||||
continue
|
||||
}
|
||||
|
||||
keepLength := len(runes) * fallbackMaxContentPercent / 100
|
||||
if keepLength < fallbackMinContentLength {
|
||||
keepLength = fallbackMinContentLength
|
||||
}
|
||||
|
||||
if keepLength > len(runes) {
|
||||
keepLength = len(runes)
|
||||
}
|
||||
|
||||
content = string(runes[:keepLength])
|
||||
if keepLength < len(runes) {
|
||||
content += "..."
|
||||
}
|
||||
fallback.WriteString(fmt.Sprintf("%s: %s", m.Role, content))
|
||||
}
|
||||
return fallback.String(), nil
|
||||
}
|
||||
|
||||
// estimateTokens estimates the number of tokens in a message list.
|
||||
// Counts Content, ToolCalls arguments, and ToolCallID metadata so that
|
||||
// tool-heavy conversations are not systematically undercounted.
|
||||
func (al *AgentLoop) estimateTokens(messages []providers.Message) int {
|
||||
total := 0
|
||||
for _, m := range messages {
|
||||
total += estimateMessageTokens(m)
|
||||
}
|
||||
return total
|
||||
}
|
||||
|
||||
func (al *AgentLoop) handleCommand(
|
||||
ctx context.Context,
|
||||
msg bus.InboundMessage,
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/bus"
|
||||
"github.com/sipeed/picoclaw/pkg/logger"
|
||||
"github.com/sipeed/picoclaw/pkg/providers"
|
||||
"github.com/sipeed/picoclaw/pkg/session"
|
||||
"github.com/sipeed/picoclaw/pkg/tools"
|
||||
@@ -338,6 +339,23 @@ func (ts *turnState) refreshRestorePointFromSession(agent *AgentInstance) {
|
||||
ts.captureRestorePoint(history, summary)
|
||||
}
|
||||
|
||||
// ingestMessage calls the ContextManager's Ingest method for a persisted message.
|
||||
// Errors are logged but never block the turn.
|
||||
func (ts *turnState) ingestMessage(ctx context.Context, al *AgentLoop, msg providers.Message) {
|
||||
if al.contextManager == nil {
|
||||
return
|
||||
}
|
||||
if err := al.contextManager.Ingest(ctx, &IngestRequest{
|
||||
SessionKey: ts.sessionKey,
|
||||
Message: msg,
|
||||
}); err != nil {
|
||||
logger.WarnCF("agent", "Context manager ingest failed", map[string]any{
|
||||
"session_key": ts.sessionKey,
|
||||
"error": err.Error(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (ts *turnState) restoreSession(agent *AgentInstance) error {
|
||||
ts.mu.RLock()
|
||||
history := append([]providers.Message(nil), ts.restorePointHistory...)
|
||||
|
||||
+18
-16
@@ -226,26 +226,28 @@ type ToolFeedbackConfig struct {
|
||||
}
|
||||
|
||||
type AgentDefaults struct {
|
||||
Workspace string `json:"workspace" env:"PICOCLAW_AGENTS_DEFAULTS_WORKSPACE"`
|
||||
RestrictToWorkspace bool `json:"restrict_to_workspace" env:"PICOCLAW_AGENTS_DEFAULTS_RESTRICT_TO_WORKSPACE"`
|
||||
AllowReadOutsideWorkspace bool `json:"allow_read_outside_workspace" env:"PICOCLAW_AGENTS_DEFAULTS_ALLOW_READ_OUTSIDE_WORKSPACE"`
|
||||
Provider string `json:"provider" env:"PICOCLAW_AGENTS_DEFAULTS_PROVIDER"`
|
||||
ModelName string `json:"model_name" env:"PICOCLAW_AGENTS_DEFAULTS_MODEL_NAME"`
|
||||
Workspace string `json:"workspace" env:"PICOCLAW_AGENTS_DEFAULTS_WORKSPACE"`
|
||||
RestrictToWorkspace bool `json:"restrict_to_workspace" env:"PICOCLAW_AGENTS_DEFAULTS_RESTRICT_TO_WORKSPACE"`
|
||||
AllowReadOutsideWorkspace bool `json:"allow_read_outside_workspace" env:"PICOCLAW_AGENTS_DEFAULTS_ALLOW_READ_OUTSIDE_WORKSPACE"`
|
||||
Provider string `json:"provider" env:"PICOCLAW_AGENTS_DEFAULTS_PROVIDER"`
|
||||
ModelName string `json:"model_name" env:"PICOCLAW_AGENTS_DEFAULTS_MODEL_NAME"`
|
||||
ModelFallbacks []string `json:"model_fallbacks,omitempty"`
|
||||
ImageModel string `json:"image_model,omitempty" env:"PICOCLAW_AGENTS_DEFAULTS_IMAGE_MODEL"`
|
||||
ImageModel string `json:"image_model,omitempty" env:"PICOCLAW_AGENTS_DEFAULTS_IMAGE_MODEL"`
|
||||
ImageModelFallbacks []string `json:"image_model_fallbacks,omitempty"`
|
||||
MaxTokens int `json:"max_tokens" env:"PICOCLAW_AGENTS_DEFAULTS_MAX_TOKENS"`
|
||||
ContextWindow int `json:"context_window,omitempty" env:"PICOCLAW_AGENTS_DEFAULTS_CONTEXT_WINDOW"`
|
||||
Temperature *float64 `json:"temperature,omitempty" env:"PICOCLAW_AGENTS_DEFAULTS_TEMPERATURE"`
|
||||
MaxToolIterations int `json:"max_tool_iterations" env:"PICOCLAW_AGENTS_DEFAULTS_MAX_TOOL_ITERATIONS"`
|
||||
SummarizeMessageThreshold int `json:"summarize_message_threshold" env:"PICOCLAW_AGENTS_DEFAULTS_SUMMARIZE_MESSAGE_THRESHOLD"`
|
||||
SummarizeTokenPercent int `json:"summarize_token_percent" env:"PICOCLAW_AGENTS_DEFAULTS_SUMMARIZE_TOKEN_PERCENT"`
|
||||
MaxMediaSize int `json:"max_media_size,omitempty" env:"PICOCLAW_AGENTS_DEFAULTS_MAX_MEDIA_SIZE"`
|
||||
MaxTokens int `json:"max_tokens" env:"PICOCLAW_AGENTS_DEFAULTS_MAX_TOKENS"`
|
||||
ContextWindow int `json:"context_window,omitempty" env:"PICOCLAW_AGENTS_DEFAULTS_CONTEXT_WINDOW"`
|
||||
Temperature *float64 `json:"temperature,omitempty" env:"PICOCLAW_AGENTS_DEFAULTS_TEMPERATURE"`
|
||||
MaxToolIterations int `json:"max_tool_iterations" env:"PICOCLAW_AGENTS_DEFAULTS_MAX_TOOL_ITERATIONS"`
|
||||
SummarizeMessageThreshold int `json:"summarize_message_threshold" env:"PICOCLAW_AGENTS_DEFAULTS_SUMMARIZE_MESSAGE_THRESHOLD"`
|
||||
SummarizeTokenPercent int `json:"summarize_token_percent" env:"PICOCLAW_AGENTS_DEFAULTS_SUMMARIZE_TOKEN_PERCENT"`
|
||||
MaxMediaSize int `json:"max_media_size,omitempty" env:"PICOCLAW_AGENTS_DEFAULTS_MAX_MEDIA_SIZE"`
|
||||
Routing *RoutingConfig `json:"routing,omitempty"`
|
||||
SteeringMode string `json:"steering_mode,omitempty" env:"PICOCLAW_AGENTS_DEFAULTS_STEERING_MODE"` // "one-at-a-time" (default) or "all"
|
||||
SubTurn SubTurnConfig `json:"subturn" envPrefix:"PICOCLAW_AGENTS_DEFAULTS_SUBTURN_"`
|
||||
SteeringMode string `json:"steering_mode,omitempty" env:"PICOCLAW_AGENTS_DEFAULTS_STEERING_MODE"` // "one-at-a-time" (default) or "all"
|
||||
SubTurn SubTurnConfig `json:"subturn" envPrefix:"PICOCLAW_AGENTS_DEFAULTS_SUBTURN_"`
|
||||
ToolFeedback ToolFeedbackConfig `json:"tool_feedback,omitempty"`
|
||||
SplitOnMarker bool `json:"split_on_marker" env:"PICOCLAW_AGENTS_DEFAULTS_SPLIT_ON_MARKER"` // split messages on <|[SPLIT]|> marker
|
||||
SplitOnMarker bool `json:"split_on_marker" env:"PICOCLAW_AGENTS_DEFAULTS_SPLIT_ON_MARKER"` // split messages on <|[SPLIT]|> marker
|
||||
ContextManager string `json:"context_manager,omitempty" env:"PICOCLAW_AGENTS_DEFAULTS_CONTEXT_MANAGER"`
|
||||
ContextManagerConfig json.RawMessage `json:"context_manager_config,omitempty" env:"PICOCLAW_AGENTS_DEFAULTS_CONTEXT_MANAGER_CONFIG"`
|
||||
}
|
||||
|
||||
const DefaultMaxMediaSize = 20 * 1024 * 1024 // 20 MB
|
||||
|
||||
Reference in New Issue
Block a user