diff --git a/config/config.example.json b/config/config.example.json index 094aa46df..20c10e60d 100644 --- a/config/config.example.json +++ b/config/config.example.json @@ -5,6 +5,7 @@ "restrict_to_workspace": true, "model_name": "gpt-5.4", "max_tokens": 8192, + "context_window": 131072, "temperature": 0.7, "max_tool_iterations": 20, "summarize_message_threshold": 20, diff --git a/docs/agent-refactor/context.md b/docs/agent-refactor/context.md new file mode 100644 index 000000000..2269d9258 --- /dev/null +++ b/docs/agent-refactor/context.md @@ -0,0 +1,164 @@ +# Context + +## What this document covers + +This document makes explicit the boundaries of context management in the agent loop: + +- what fills the context window and how space is divided +- what is stored in session history vs. built at request time +- when and how context compression happens +- how token budgets are estimated + +These are existing concepts. This document clarifies their boundaries rather than introducing new ones. + +--- + +## Context window regions + +The context window is the model's total input capacity. Four regions fill it: + +| Region | Assembled by | Stored in session? | +|---|---|---| +| System prompt | `BuildMessages()` — static + dynamic parts | No | +| Summary | `SetSummary()` stores it; `BuildMessages()` injects it | Separate from history | +| Session history | User / assistant / tool messages | Yes | +| Tool definitions | Provider adapter injects at call time | No | + +`MaxTokens` (the output generation limit) must also be reserved from the total budget. + +The available space for history is therefore: + +``` +history_budget = ContextWindow - system_prompt - summary - tool_definitions - MaxTokens +``` + +--- + +## ContextWindow vs MaxTokens + +These serve different purposes: + +- **MaxTokens** — maximum tokens the LLM may generate in one response. Sent as the `max_tokens` request parameter. +- **ContextWindow** — the model's total input context capacity. + +These were previously set to the same value, which caused the summarization threshold to fire either far too early (at the default 32K) or not at all (when a user raised `max_tokens`). + +Current default when not explicitly configured: `ContextWindow = MaxTokens * 4`. + +--- + +## Session history + +Session history stores only conversation messages: + +- `user` — user input +- `assistant` — LLM response (may include `ToolCalls`) +- `tool` — tool execution results + +Session history does **not** contain: + +- System prompts — assembled at request time by `BuildMessages` +- Summary content — stored separately via `SetSummary`, injected by `BuildMessages` + +This distinction matters: any code that operates on session history — compression, boundary detection, token estimation — must not assume a system message is present. + +--- + +## Turn + +A **Turn** is one complete cycle: + +> user message -> LLM iterations (possibly including tool calls) -> final assistant response + +This definition comes from the agent loop design (#1316). In session history, Turn boundaries are identified by `user`-role messages. + +Turn is the atomic unit for compression. Cutting inside a Turn can orphan tool-call sequences — an assistant message with `ToolCalls` separated from its corresponding `tool` results. Compressing at Turn boundaries avoids this by construction. + +`parseTurnBoundaries(history)` returns the starting index of each Turn. +`findSafeBoundary(history, targetIndex)` snaps a target cut point to the nearest Turn boundary. + +--- + +## Compression paths + +Three compression paths exist, in order of preference: + +### 1. Async summarization + +`maybeSummarize` runs after each Turn completes. + +Triggers when message count exceeds a threshold, or when estimated history tokens exceed a percentage of `ContextWindow`. If triggered, a background goroutine calls the LLM to produce a summary of the oldest messages. The summary is stored via `SetSummary`; `BuildMessages` injects it into the system prompt on the next call. + +Cut point uses `findSafeBoundary` so no Turn is split. + +### 2. Proactive budget check + +`isOverContextBudget` runs before each LLM call. + +Uses the full budget formula: `message_tokens + tool_def_tokens + MaxTokens > ContextWindow`. If over budget, triggers `forceCompression` and rebuilds messages before calling the LLM. + +This prevents wasted (and billed) LLM calls that would otherwise fail with a context-window error. + +### 3. Emergency compression (reactive) + +`forceCompression` runs when the LLM returns a context-window error despite the proactive check. + +Drops the oldest ~50% of Turns. If the history is a single Turn with no safe split point (e.g. one user message followed by a massive tool response), falls back to keeping only the most recent user message — breaking Turn atomicity as a last resort to avoid a context-exceeded loop. + +Stores a compression note in the session summary (not in history messages) so `BuildMessages` can include it in the next system prompt. + +This is the fallback for when the token estimate undershoots reality. + +--- + +## Token estimation + +Estimation uses a heuristic of ~2.5 characters per token (`chars * 2 / 5`). + +`estimateMessageTokens` counts: + +- `Content` (rune count, for multibyte correctness) +- `ReasoningContent` (extended thinking / chain-of-thought) +- `ToolCalls` — ID, type, function name, arguments +- `ToolCallID` (tool result metadata) +- Per-message overhead (role label, JSON structure) +- `Media` items — flat per-item token estimate, added directly to the final count (not through the character heuristic, since actual cost depends on resolution and provider-specific image tokenization) + +`estimateToolDefsTokens` counts tool definition overhead: name, description, JSON schema of parameters. + +These are deliberately heuristic. The proactive check handles the common case; the reactive path catches estimation errors. + +--- + +## Interface boundaries + +Context budget functions (`parseTurnBoundaries`, `findSafeBoundary`, `estimateMessageTokens`, `isOverContextBudget`) are **pure functions**. They take `[]providers.Message` and integer parameters. They have no dependency on `AgentLoop` or any other runtime struct. + +`BuildMessages` is the sole assembler of the final message array sent to the LLM. Budget functions inform compression decisions but do not construct messages. + +`forceCompression` and `summarizeSession` mutate session state (history and summary). `BuildMessages` reads that state to construct context. The flow is: + +``` +budget check --> compression decision --> mutate session --> BuildMessages reads session --> LLM call +``` + +--- + +## Known gaps + +These are recognized limitations in the current implementation, documented here for visibility: + +- **Summarization trigger does not use the full budget formula.** `maybeSummarize` compares estimated history tokens against a percentage of `ContextWindow`. It does not account for system prompt size, tool definition overhead, or `MaxTokens` reserve. The proactive check covers the critical path (preventing 400 errors), but the summarization trigger could be aligned with the same budget model for more accurate early compression. + +- **Token estimation is heuristic.** It does not account for provider-specific tokenization, exact system prompt size (assembled separately), or variable image token costs. The two-path design (proactive + reactive) is intended to tolerate this imprecision. + +- **Reactive retry does not preserve media.** When the reactive path rebuilds context after compression, it currently passes empty values for media references. This is a pre-existing issue in the main loop, not introduced by the budget system. + +--- + +## What this document does not cover + +- How `AGENT.md` frontmatter configures context parameters — that is part of the Agent definition work +- How the context builder assembles context in the new architecture — that is upcoming work +- How compression events surface through the event system — that is part of the event model (#1316) +- Subagent context isolation — that is a separate track diff --git a/pkg/agent/context_budget.go b/pkg/agent/context_budget.go new file mode 100644 index 000000000..c87695c7a --- /dev/null +++ b/pkg/agent/context_budget.go @@ -0,0 +1,176 @@ +// PicoClaw - Ultra-lightweight personal AI agent +// License: MIT +// +// Copyright (c) 2026 PicoClaw contributors + +package agent + +import ( + "encoding/json" + "unicode/utf8" + + "github.com/sipeed/picoclaw/pkg/providers" +) + +// parseTurnBoundaries returns the starting index of each Turn in the history. +// A Turn is a complete "user input → LLM iterations → final response" cycle +// (as defined in #1316). Each Turn begins at a user message and extends +// through all subsequent assistant/tool messages until the next user message. +// +// Cutting at a Turn boundary guarantees that no tool-call sequence +// (assistant+ToolCalls → tool results) is split across the cut. +func parseTurnBoundaries(history []providers.Message) []int { + var starts []int + for i, msg := range history { + if msg.Role == "user" { + starts = append(starts, i) + } + } + return starts +} + +// isSafeBoundary reports whether index is a valid Turn boundary — i.e., +// a position where the kept portion (history[index:]) begins at a user +// message, so no tool-call sequence is torn apart. +func isSafeBoundary(history []providers.Message, index int) bool { + if index <= 0 || index >= len(history) { + return true + } + return history[index].Role == "user" +} + +// findSafeBoundary locates the nearest Turn boundary to targetIndex. +// It prefers the boundary at or before targetIndex (preserving more recent +// context). Falls back to the nearest boundary after targetIndex, and +// returns targetIndex unchanged only when no Turn boundary exists at all. +func findSafeBoundary(history []providers.Message, targetIndex int) int { + if len(history) == 0 { + return 0 + } + if targetIndex <= 0 { + return 0 + } + if targetIndex >= len(history) { + return len(history) + } + + turns := parseTurnBoundaries(history) + if len(turns) == 0 { + return targetIndex + } + + // Find the last Turn boundary at or before targetIndex. + // Prefer backward: keeps more recent messages. + backward := -1 + for _, t := range turns { + if t <= targetIndex { + backward = t + } + } + if backward > 0 { + return backward + } + + // No valid Turn boundary before target (or only at index 0 which + // would keep everything). Use the first Turn after targetIndex. + for _, t := range turns { + if t > targetIndex { + return t + } + } + + // No Turn boundary after targetIndex either. The only boundary is at + // index 0, meaning the entire history is a single Turn. Return 0 to + // signal that safe compression is not possible — callers check for + // mid <= 0 and skip compression in that case. + return 0 +} + +// estimateMessageTokens estimates the token count for a single message, +// including Content, ReasoningContent, ToolCalls arguments, ToolCallID +// metadata, and Media items. Uses a heuristic of 2.5 characters per token. +func estimateMessageTokens(msg providers.Message) int { + chars := utf8.RuneCountInString(msg.Content) + + // ReasoningContent (extended thinking / chain-of-thought) can be + // substantial and is stored in session history via AddFullMessage. + if msg.ReasoningContent != "" { + chars += utf8.RuneCountInString(msg.ReasoningContent) + } + + for _, tc := range msg.ToolCalls { + chars += len(tc.ID) + len(tc.Type) + if tc.Function != nil { + // Count function name + arguments (the wire format for most providers). + // tc.Name mirrors tc.Function.Name — count only once to avoid double-counting. + chars += len(tc.Function.Name) + len(tc.Function.Arguments) + } else { + // Fallback: some provider formats use top-level Name without Function. + chars += len(tc.Name) + } + } + + if msg.ToolCallID != "" { + chars += len(msg.ToolCallID) + } + + // Per-message overhead for role label, JSON structure, separators. + const messageOverhead = 12 + chars += messageOverhead + + tokens := chars * 2 / 5 + + // Media items (images, files) are serialized by provider adapters into + // multipart or image_url payloads. Add a fixed per-item token estimate + // directly (not through the chars heuristic) since actual cost depends + // on resolution and provider-specific image tokenization. + const mediaTokensPerItem = 256 + tokens += len(msg.Media) * mediaTokensPerItem + + return tokens +} + +// estimateToolDefsTokens estimates the total token cost of tool definitions +// as they appear in the LLM request. Each tool's name, description, and +// JSON schema parameters contribute to the context window budget. +func estimateToolDefsTokens(defs []providers.ToolDefinition) int { + if len(defs) == 0 { + return 0 + } + + totalChars := 0 + for _, d := range defs { + totalChars += len(d.Function.Name) + len(d.Function.Description) + + if d.Function.Parameters != nil { + if paramJSON, err := json.Marshal(d.Function.Parameters); err == nil { + totalChars += len(paramJSON) + } + } + + // Per-tool overhead: type field, JSON structure, separators. + totalChars += 20 + } + + return totalChars * 2 / 5 +} + +// isOverContextBudget checks whether the assembled messages plus tool definitions +// and output reserve would exceed the model's context window. This enables +// proactive compression before calling the LLM, rather than reacting to 400 errors. +func isOverContextBudget( + contextWindow int, + messages []providers.Message, + toolDefs []providers.ToolDefinition, + maxTokens int, +) bool { + msgTokens := 0 + for _, m := range messages { + msgTokens += estimateMessageTokens(m) + } + + toolTokens := estimateToolDefsTokens(toolDefs) + total := msgTokens + toolTokens + maxTokens + + return total > contextWindow +} diff --git a/pkg/agent/context_budget_test.go b/pkg/agent/context_budget_test.go new file mode 100644 index 000000000..870f0fbe6 --- /dev/null +++ b/pkg/agent/context_budget_test.go @@ -0,0 +1,826 @@ +package agent + +import ( + "fmt" + "strings" + "testing" + + "github.com/sipeed/picoclaw/pkg/providers" +) + +// msgUser creates a user message. +func msgUser(content string) providers.Message { + return providers.Message{Role: "user", Content: content} +} + +// msgAssistant creates a plain assistant message (no tool calls). +func msgAssistant(content string) providers.Message { + return providers.Message{Role: "assistant", Content: content} +} + +// msgAssistantTC creates an assistant message with tool calls. +func msgAssistantTC(toolIDs ...string) providers.Message { + tcs := make([]providers.ToolCall, len(toolIDs)) + for i, id := range toolIDs { + tcs[i] = providers.ToolCall{ + ID: id, + Type: "function", + Name: "tool_" + id, + Function: &providers.FunctionCall{ + Name: "tool_" + id, + Arguments: `{"key":"value"}`, + }, + } + } + return providers.Message{Role: "assistant", ToolCalls: tcs} +} + +// msgTool creates a tool result message. +func msgTool(callID, content string) providers.Message { + return providers.Message{Role: "tool", ToolCallID: callID, Content: content} +} + +func TestParseTurnBoundaries(t *testing.T) { + tests := []struct { + name string + history []providers.Message + want []int + }{ + { + name: "empty history", + history: nil, + want: nil, + }, + { + name: "simple exchange", + history: []providers.Message{ + msgUser("q1"), + msgAssistant("a1"), + msgUser("q2"), + msgAssistant("a2"), + }, + want: []int{0, 2}, + }, + { + name: "tool-call Turn", + history: []providers.Message{ + msgUser("search"), + msgAssistantTC("tc1"), + msgTool("tc1", "result"), + msgAssistant("found it"), + msgUser("thanks"), + msgAssistant("welcome"), + }, + want: []int{0, 4}, + }, + { + name: "chained tool calls in single Turn", + history: []providers.Message{ + msgUser("save and notify"), + msgAssistantTC("tc_save"), + msgTool("tc_save", "saved"), + msgAssistantTC("tc_notify"), + msgTool("tc_notify", "notified"), + msgAssistant("done"), + }, + want: []int{0}, + }, + { + name: "no user messages", + history: []providers.Message{ + msgAssistant("a1"), + msgAssistant("a2"), + }, + want: nil, + }, + { + name: "leading non-user messages", + history: []providers.Message{ + msgAssistantTC("tc1"), + msgTool("tc1", "r1"), + msgAssistant("greeting"), + msgUser("hello"), + msgAssistant("hi"), + }, + want: []int{3}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := parseTurnBoundaries(tt.history) + if len(got) != len(tt.want) { + t.Errorf("parseTurnBoundaries() = %v, want %v", got, tt.want) + return + } + for i := range got { + if got[i] != tt.want[i] { + t.Errorf("parseTurnBoundaries()[%d] = %d, want %d", i, got[i], tt.want[i]) + } + } + }) + } +} + +func TestIsSafeBoundary(t *testing.T) { + tests := []struct { + name string + history []providers.Message + index int + want bool + }{ + { + name: "empty history, index 0", + history: nil, + index: 0, + want: true, + }, + { + name: "single user message, index 0", + history: []providers.Message{msgUser("hi")}, + index: 0, + want: true, + }, + { + name: "single user message, index 1 (end)", + history: []providers.Message{msgUser("hi")}, + index: 1, + want: true, + }, + { + name: "at user message", + history: []providers.Message{ + msgAssistant("hello"), + msgUser("how are you"), + msgAssistant("fine"), + }, + index: 1, + want: true, + }, + { + name: "at assistant without tool calls", + history: []providers.Message{ + msgUser("hello"), + msgAssistant("response"), + msgUser("follow up"), + }, + index: 1, + want: false, + }, + { + name: "at assistant with tool calls", + history: []providers.Message{ + msgUser("search something"), + msgAssistantTC("tc1"), + msgTool("tc1", "result"), + msgAssistant("here is what I found"), + }, + index: 1, + want: false, + }, + { + name: "at tool result", + history: []providers.Message{ + msgUser("do something"), + msgAssistantTC("tc1"), + msgTool("tc1", "done"), + msgAssistant("completed"), + }, + index: 2, + want: false, + }, + { + name: "negative index", + history: []providers.Message{ + msgUser("hello"), + }, + index: -1, + want: true, + }, + { + name: "index beyond length", + history: []providers.Message{ + msgUser("hello"), + }, + index: 5, + want: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := isSafeBoundary(tt.history, tt.index) + if got != tt.want { + t.Errorf("isSafeBoundary(history, %d) = %v, want %v", tt.index, got, tt.want) + } + }) + } +} + +func TestFindSafeBoundary(t *testing.T) { + tests := []struct { + name string + history []providers.Message + targetIndex int + want int + }{ + { + name: "empty history", + history: nil, + targetIndex: 0, + want: 0, + }, + { + name: "target at 0", + history: []providers.Message{msgUser("hi")}, + targetIndex: 0, + want: 0, + }, + { + name: "target beyond length", + history: []providers.Message{msgUser("hi")}, + targetIndex: 5, + want: 1, + }, + { + name: "target already at user message", + history: []providers.Message{ + msgUser("q1"), + msgAssistant("a1"), + msgUser("q2"), + msgAssistant("a2"), + }, + targetIndex: 2, + want: 2, + }, + { + name: "target at assistant, scan backward finds user", + history: []providers.Message{ + msgUser("q1"), + msgAssistant("a1"), + msgUser("q2"), + msgAssistant("a2"), + msgUser("q3"), + }, + targetIndex: 3, // assistant "a2" + want: 2, // backward to user "q2" + }, + { + name: "target inside tool sequence, scan backward finds user", + history: []providers.Message{ + msgUser("q1"), + msgAssistant("a1"), + msgUser("q2"), + msgAssistantTC("tc1", "tc2"), + msgTool("tc1", "r1"), + msgTool("tc2", "r2"), + msgAssistant("summary"), + msgUser("q3"), + }, + targetIndex: 4, // tool result "r1" + want: 2, // backward: 3=assistant+TC (not safe), 2=user → safe + }, + { + name: "target inside tool sequence, backward finds user before chain", + history: []providers.Message{ + msgUser("q1"), + msgAssistant("a1"), + msgUser("q2"), + msgAssistantTC("tc1", "tc2"), + msgTool("tc1", "r1"), + msgTool("tc2", "r2"), + msgAssistant("summary"), + msgUser("q3"), + }, + targetIndex: 5, // tool result "r2" + want: 2, // backward: 4=tool, 3=assistant+TC, 2=user → safe + }, + { + name: "no backward user, scan forward finds one", + history: []providers.Message{ + msgAssistantTC("tc1"), + msgTool("tc1", "r1"), + msgAssistant("a1"), + msgUser("q1"), + }, + targetIndex: 1, // tool result + want: 3, // forward to user "q1" + }, + { + name: "multi-step tool chain preserves atomicity", + history: []providers.Message{ + msgUser("q1"), + msgAssistant("a1"), + msgUser("q2"), + msgAssistantTC("tc1"), + msgTool("tc1", "r1"), + msgAssistantTC("tc2"), + msgTool("tc2", "r2"), + msgAssistant("final"), + msgUser("q3"), + msgAssistant("a3"), + }, + targetIndex: 5, // second assistant+TC + want: 2, // backward: 4=tool, 3=assistant+TC, 2=user → safe + }, + { + name: "all non-user messages returns target unchanged", + history: []providers.Message{ + msgAssistant("a1"), + msgAssistant("a2"), + msgAssistant("a3"), + }, + targetIndex: 1, + want: 1, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := findSafeBoundary(tt.history, tt.targetIndex) + if got != tt.want { + t.Errorf("findSafeBoundary(history, %d) = %d, want %d", + tt.targetIndex, got, tt.want) + } + }) + } +} + +func TestFindSafeBoundary_SingleTurnReturnsZero(t *testing.T) { + // A single Turn with no subsequent user message. The only Turn boundary + // is at index 0; cutting anywhere else would split the Turn's tool + // sequence. findSafeBoundary must return 0 so callers skip compression. + history := []providers.Message{ + msgUser("do everything"), // 0 ← only Turn boundary + msgAssistantTC("tc1"), // 1 + msgTool("tc1", "result"), // 2 + msgAssistant("all done"), // 3 + } + + got := findSafeBoundary(history, 2) + if got != 0 { + t.Errorf("findSafeBoundary(single_turn, 2) = %d, want 0 (cannot split single Turn)", got) + } +} + +func TestFindSafeBoundary_BackwardScanSkipsToolSequence(t *testing.T) { + // A long tool-call chain: user → assistant+TC → tool → tool → ... → assistant → user + // Target is inside the chain; boundary should skip the entire chain backward. + history := []providers.Message{ + msgUser("start"), // 0 + msgAssistant("before chain"), // 1 + msgUser("trigger"), // 2 ← expected safe boundary + msgAssistantTC("t1", "t2", "t3"), // 3 + msgTool("t1", "r1"), // 4 + msgTool("t2", "r2"), // 5 + msgTool("t3", "r3"), // 6 + msgAssistantTC("t4"), // 7 + msgTool("t4", "r4"), // 8 + msgAssistant("chain done"), // 9 + msgUser("next"), // 10 + } + + // Target at index 6 (middle of tool results) + got := findSafeBoundary(history, 6) + if got != 2 { + t.Errorf("findSafeBoundary(history, 6) = %d, want 2 (user before chain)", got) + } +} + +func TestEstimateMessageTokens(t *testing.T) { + tests := []struct { + name string + msg providers.Message + want int // minimum expected tokens (exact value depends on overhead) + }{ + { + name: "plain user message", + msg: msgUser("Hello, world!"), + want: 1, // at least some tokens + }, + { + name: "empty message still has overhead", + msg: providers.Message{Role: "user"}, + want: 1, // message overhead alone + }, + { + name: "assistant with tool calls", + msg: msgAssistantTC("tc_123"), + want: 1, + }, + { + name: "tool result with ID", + msg: msgTool("call_abc", "Here is the search result with lots of content"), + want: 1, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := estimateMessageTokens(tt.msg) + if got < tt.want { + t.Errorf("estimateMessageTokens() = %d, want >= %d", got, tt.want) + } + }) + } +} + +func TestEstimateMessageTokens_ToolCallsContribute(t *testing.T) { + plain := msgAssistant("thinking") + withTC := providers.Message{ + Role: "assistant", + Content: "thinking", + ToolCalls: []providers.ToolCall{ + { + ID: "call_1", + Type: "function", + Name: "web_search", + Function: &providers.FunctionCall{ + Name: "web_search", + Arguments: `{"query":"picoclaw agent framework","max_results":5}`, + }, + }, + }, + } + + plainTokens := estimateMessageTokens(plain) + withTCTokens := estimateMessageTokens(withTC) + + if withTCTokens <= plainTokens { + t.Errorf("message with ToolCalls (%d tokens) should exceed plain message (%d tokens)", + withTCTokens, plainTokens) + } +} + +func TestEstimateMessageTokens_MultibyteContent(t *testing.T) { + // Multi-byte characters (e.g. emoji, accented letters) are single runes + // but may map to different token counts. The heuristic should still produce + // reasonable estimates via RuneCountInString. + msg := msgUser("caf\u00e9 na\u00efve r\u00e9sum\u00e9 \u00fcber stra\u00dfe") + tokens := estimateMessageTokens(msg) + if tokens <= 0 { + t.Errorf("multibyte message should produce positive token count, got %d", tokens) + } +} + +func TestEstimateMessageTokens_LargeArguments(t *testing.T) { + // Simulate a tool call with large JSON arguments. + largeArgs := fmt.Sprintf(`{"content":"%s"}`, strings.Repeat("x", 5000)) + msg := providers.Message{ + Role: "assistant", + ToolCalls: []providers.ToolCall{ + { + ID: "call_large", + Type: "function", + Name: "write_file", + Function: &providers.FunctionCall{ + Name: "write_file", + Arguments: largeArgs, + }, + }, + }, + } + + tokens := estimateMessageTokens(msg) + // 5000+ chars → at least 2000 tokens with the 2.5 char/token heuristic + if tokens < 2000 { + t.Errorf("large tool call arguments should produce significant token count, got %d", tokens) + } +} + +func TestEstimateMessageTokens_ReasoningContent(t *testing.T) { + plain := msgAssistant("result") + withReasoning := providers.Message{ + Role: "assistant", + Content: "result", + ReasoningContent: strings.Repeat("thinking step ", 200), + } + + plainTokens := estimateMessageTokens(plain) + reasoningTokens := estimateMessageTokens(withReasoning) + + if reasoningTokens <= plainTokens { + t.Errorf("message with ReasoningContent (%d tokens) should exceed plain message (%d tokens)", + reasoningTokens, plainTokens) + } +} + +func TestEstimateMessageTokens_MediaItems(t *testing.T) { + plain := msgUser("describe this") + withMedia := providers.Message{ + Role: "user", + Content: "describe this", + Media: []string{"media://img1.png", "media://img2.png"}, + } + + plainTokens := estimateMessageTokens(plain) + mediaTokens := estimateMessageTokens(withMedia) + + if mediaTokens <= plainTokens { + t.Errorf("message with Media (%d tokens) should exceed plain message (%d tokens)", + mediaTokens, plainTokens) + } + + // Each media item should add exactly 256 tokens (not run through chars*2/5). + expectedDelta := 256 * 2 + actualDelta := mediaTokens - plainTokens + if actualDelta != expectedDelta { + t.Errorf("2 media items should add %d tokens, got delta %d", expectedDelta, actualDelta) + } +} + +// --- estimateToolDefsTokens tests --- + +func TestEstimateToolDefsTokens(t *testing.T) { + tests := []struct { + name string + defs []providers.ToolDefinition + want int // minimum expected tokens + }{ + { + name: "empty tool list", + defs: nil, + want: 0, + }, + { + name: "single tool with params", + defs: []providers.ToolDefinition{ + { + Type: "function", + Function: providers.ToolFunctionDefinition{ + Name: "web_search", + Description: "Search the web for information", + Parameters: map[string]any{ + "type": "object", + "properties": map[string]any{ + "query": map[string]any{"type": "string"}, + }, + "required": []any{"query"}, + }, + }, + }, + }, + want: 1, + }, + { + name: "tool without params", + defs: []providers.ToolDefinition{ + { + Type: "function", + Function: providers.ToolFunctionDefinition{ + Name: "list_dir", + Description: "List directory contents", + }, + }, + }, + want: 1, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := estimateToolDefsTokens(tt.defs) + if got < tt.want { + t.Errorf("estimateToolDefsTokens() = %d, want >= %d", got, tt.want) + } + }) + } +} + +func TestEstimateToolDefsTokens_ScalesWithCount(t *testing.T) { + makeTool := func(name string) providers.ToolDefinition { + return providers.ToolDefinition{ + Type: "function", + Function: providers.ToolFunctionDefinition{ + Name: name, + Description: "A test tool that does something useful", + Parameters: map[string]any{ + "type": "object", + "properties": map[string]any{ + "input": map[string]any{"type": "string", "description": "Input value"}, + }, + }, + }, + } + } + + one := estimateToolDefsTokens([]providers.ToolDefinition{makeTool("tool_a")}) + three := estimateToolDefsTokens([]providers.ToolDefinition{ + makeTool("tool_a"), makeTool("tool_b"), makeTool("tool_c"), + }) + + if three <= one { + t.Errorf("3 tools (%d tokens) should exceed 1 tool (%d tokens)", three, one) + } +} + +// --- isOverContextBudget tests --- + +func TestIsOverContextBudget(t *testing.T) { + systemMsg := providers.Message{Role: "system", Content: strings.Repeat("x", 1000)} + userMsg := msgUser("hello") + smallHistory := []providers.Message{systemMsg, msgUser("q1"), msgAssistant("a1"), userMsg} + + tools := []providers.ToolDefinition{ + { + Type: "function", + Function: providers.ToolFunctionDefinition{ + Name: "test_tool", + Description: "A test tool", + Parameters: map[string]any{"type": "object"}, + }, + }, + } + + tests := []struct { + name string + contextWindow int + messages []providers.Message + toolDefs []providers.ToolDefinition + maxTokens int + want bool + }{ + { + name: "within budget", + contextWindow: 100000, + messages: smallHistory, + toolDefs: tools, + maxTokens: 4096, + want: false, + }, + { + name: "over budget with small window", + contextWindow: 100, // very small window + messages: smallHistory, + toolDefs: tools, + maxTokens: 4096, + want: true, + }, + { + name: "large max_tokens eats budget", + contextWindow: 2000, + messages: smallHistory, + toolDefs: tools, + maxTokens: 1800, // leaves almost no room + want: true, + }, + { + name: "empty messages within budget", + contextWindow: 10000, + messages: nil, + toolDefs: nil, + maxTokens: 4096, + want: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := isOverContextBudget(tt.contextWindow, tt.messages, tt.toolDefs, tt.maxTokens) + if got != tt.want { + t.Errorf("isOverContextBudget() = %v, want %v", got, tt.want) + } + }) + } +} + +// --- Tests reflecting actual session data shape --- +// Session history never contains system messages. The system prompt is +// built dynamically by BuildMessages. These tests use realistic history +// shapes: user/assistant/tool only, with tool chains and reasoning content. + +func TestFindSafeBoundary_SessionHistoryNoSystem(t *testing.T) { + // Real session history starts with a user message, not a system message. + history := []providers.Message{ + msgUser("hello"), // 0 + msgAssistant("hi there"), // 1 + msgUser("search for X"), // 2 + msgAssistantTC("tc1"), // 3 + msgTool("tc1", "found X"), // 4 + msgAssistant("here is X"), // 5 + msgUser("thanks"), // 6 + msgAssistant("you're welcome"), // 7 + } + + // Mid-point is 4 (tool result). Should snap backward to 2 (user). + got := findSafeBoundary(history, 4) + if got != 2 { + t.Errorf("findSafeBoundary(session_history, 4) = %d, want 2", got) + } +} + +func TestFindSafeBoundary_SessionWithChainedTools(t *testing.T) { + // Session with chained tool calls (save then notify). + history := []providers.Message{ + msgUser("save and notify"), // 0 + msgAssistantTC("tc_save"), // 1 + msgTool("tc_save", "saved"), // 2 + msgAssistantTC("tc_notify"), // 3 + msgTool("tc_notify", "notified"), // 4 + msgAssistant("done"), // 5 + msgUser("check status"), // 6 + msgAssistant("all good"), // 7 + } + + // Target at 3 (inside chain). Should find user at 0, but backward + // scan stops at i>0, so forward scan finds user at 6. + // Actually: backward from 3: 2=tool (no), 1=assistantTC (no). Forward: 4=tool, 5=asst, 6=user ✓ + got := findSafeBoundary(history, 3) + if got != 6 { + t.Errorf("findSafeBoundary(chained_tools, 3) = %d, want 6", got) + } +} + +func TestEstimateMessageTokens_WithReasoningAndMedia(t *testing.T) { + // Message with all fields populated — mirrors what AddFullMessage stores. + msg := providers.Message{ + Role: "assistant", + Content: "Here is the analysis.", + ReasoningContent: strings.Repeat("Let me think about this carefully. ", 50), + ToolCalls: []providers.ToolCall{ + { + ID: "call_1", + Type: "function", + Name: "analyze", + Function: &providers.FunctionCall{ + Name: "analyze", + Arguments: `{"data":"sample","depth":3}`, + }, + }, + }, + } + + tokens := estimateMessageTokens(msg) + + // ReasoningContent alone is ~1700 chars → ~680 tokens. + // Content + TC + overhead adds more. Should be well above 500. + if tokens < 500 { + t.Errorf("message with reasoning+toolcalls should have significant tokens, got %d", tokens) + } + + // Compare without reasoning to ensure it's counted. + msgNoReasoning := msg + msgNoReasoning.ReasoningContent = "" + tokensNoReasoning := estimateMessageTokens(msgNoReasoning) + + if tokens <= tokensNoReasoning { + t.Errorf("reasoning content should add tokens: with=%d, without=%d", tokens, tokensNoReasoning) + } +} + +func TestIsOverContextBudget_RealisticSession(t *testing.T) { + // Simulate what BuildMessages produces: system + session history + current user. + // System message is built by BuildMessages, not stored in session. + systemMsg := providers.Message{ + Role: "system", + Content: strings.Repeat("system prompt content ", 100), + } + sessionHistory := []providers.Message{ + msgUser("first question"), + msgAssistant("first answer"), + msgUser("use tool X"), + { + Role: "assistant", + Content: "I'll use tool X", + ToolCalls: []providers.ToolCall{ + { + ID: "tc1", Type: "function", Name: "tool_x", + Function: &providers.FunctionCall{ + Name: "tool_x", + Arguments: `{"query":"test","verbose":true}`, + }, + }, + }, + }, + {Role: "tool", Content: strings.Repeat("result data ", 200), ToolCallID: "tc1"}, + msgAssistant("Here are the results from tool X."), + } + currentUser := msgUser("follow up question") + + // Assemble as BuildMessages would. + messages := make([]providers.Message, 0, 1+len(sessionHistory)+1) + messages = append(messages, systemMsg) + messages = append(messages, sessionHistory...) + messages = append(messages, currentUser) + + tools := []providers.ToolDefinition{ + { + Type: "function", + Function: providers.ToolFunctionDefinition{ + Name: "tool_x", + Description: "A useful tool", + Parameters: map[string]any{"type": "object"}, + }, + }, + } + + // With a large context window, should be within budget. + if isOverContextBudget(131072, messages, tools, 32768) { + t.Error("realistic session should be within 131072 context window") + } + + // With a tiny context window, should exceed budget. + if !isOverContextBudget(500, messages, tools, 32768) { + t.Error("realistic session should exceed 500 context window") + } +} diff --git a/pkg/agent/instance.go b/pkg/agent/instance.go index 0c7baa1ee..c34f9b4a4 100644 --- a/pkg/agent/instance.go +++ b/pkg/agent/instance.go @@ -127,6 +127,17 @@ func NewAgentInstance( maxTokens = 8192 } + contextWindow := defaults.ContextWindow + if contextWindow == 0 { + // Default heuristic: 4x the output token limit. + // Most models have context windows well above their output limits + // (e.g., GPT-4o 128k ctx / 16k out, Claude 200k ctx / 8k out). + // 4x is a conservative lower bound that avoids premature + // summarization while remaining safe — the reactive + // forceCompression handles any overshoot. + contextWindow = maxTokens * 4 + } + temperature := 0.7 if defaults.Temperature != nil { temperature = *defaults.Temperature @@ -224,7 +235,7 @@ func NewAgentInstance( MaxTokens: maxTokens, Temperature: temperature, ThinkingLevel: thinkingLevel, - ContextWindow: maxTokens, + ContextWindow: contextWindow, SummarizeMessageThreshold: summarizeMessageThreshold, SummarizeTokenPercent: summarizeTokenPercent, Provider: provider, diff --git a/pkg/agent/loop.go b/pkg/agent/loop.go index 21516e7de..c583f5ca5 100644 --- a/pkg/agent/loop.go +++ b/pkg/agent/loop.go @@ -17,7 +17,6 @@ import ( "sync" "sync/atomic" "time" - "unicode/utf8" "github.com/sipeed/picoclaw/pkg/bus" "github.com/sipeed/picoclaw/pkg/channels" @@ -931,6 +930,24 @@ func (al *AgentLoop) runAgentLoop( maxMediaSize := cfg.Agents.Defaults.GetMaxMediaSize() messages = resolveMediaRefs(messages, al.mediaStore, maxMediaSize) + // 1.5. Proactive context budget check: compress before LLM call + // rather than waiting for a 400 context-length error. + if !opts.NoHistory { + toolDefs := agent.Tools.ToProviderDefs() + if isOverContextBudget(agent.ContextWindow, messages, toolDefs, agent.MaxTokens) { + logger.WarnCF("agent", "Proactive compression: context budget exceeded before LLM call", + map[string]any{"session_key": opts.SessionKey}) + al.forceCompression(agent, opts.SessionKey) + newHistory := agent.Sessions.GetHistory(opts.SessionKey) + newSummary := agent.Sessions.GetSummary(opts.SessionKey) + messages = agent.ContextBuilder.BuildMessages( + newHistory, newSummary, opts.UserMessage, + opts.Media, opts.Channel, opts.ChatID, + ) + messages = resolveMediaRefs(messages, al.mediaStore, maxMediaSize) + } + } + // 2. Save user message to session agent.Sessions.AddMessage(opts.SessionKey, "user", opts.UserMessage) @@ -1539,55 +1556,73 @@ func (al *AgentLoop) maybeSummarize(agent *AgentInstance, sessionKey, channel, c } // forceCompression aggressively reduces context when the limit is hit. -// It drops the oldest 50% of messages (keeping system prompt and last user message). +// It drops the oldest ~50% of Turns (a Turn is a complete user→LLM→response +// cycle, as defined in #1316), so tool-call sequences are never split. +// +// If the history is a single Turn with no safe split point, the function +// falls back to keeping only the most recent user message. This breaks +// Turn atomicity as a last resort to avoid a context-exceeded loop. +// +// Session history contains only user/assistant/tool messages — the system +// prompt is built dynamically by BuildMessages and is NOT stored here. +// The compression note is recorded in the session summary so that +// BuildMessages can include it in the next system prompt. func (al *AgentLoop) forceCompression(agent *AgentInstance, sessionKey string) { history := agent.Sessions.GetHistory(sessionKey) - if len(history) <= 4 { + if len(history) <= 2 { return } - // Keep system prompt (usually [0]) and the very last message (user's trigger) - // We want to drop the oldest half of the *conversation* - // Assuming [0] is system, [1:] is conversation - conversation := history[1 : len(history)-1] - if len(conversation) == 0 { - return + // Split at a Turn boundary so no tool-call sequence is torn apart. + // parseTurnBoundaries gives us the start of each Turn; we drop the + // oldest half of Turns and keep the most recent ones. + turns := parseTurnBoundaries(history) + var mid int + if len(turns) >= 2 { + mid = turns[len(turns)/2] + } else { + // Fewer than 2 Turns — fall back to message-level midpoint + // aligned to the nearest Turn boundary. + mid = findSafeBoundary(history, len(history)/2) + } + var keptHistory []providers.Message + if mid <= 0 { + // No safe Turn boundary — the entire history is a single Turn + // (e.g. one user message followed by a massive tool response). + // Keeping everything would leave the agent stuck in a context- + // exceeded loop, so fall back to keeping only the most recent + // user message. This breaks Turn atomicity as a last resort. + for i := len(history) - 1; i >= 0; i-- { + if history[i].Role == "user" { + keptHistory = []providers.Message{history[i]} + break + } + } + } else { + keptHistory = history[mid:] } - // Helper to find the mid-point of the conversation - mid := len(conversation) / 2 + droppedCount := len(history) - len(keptHistory) - // New history structure: - // 1. System Prompt (with compression note appended) - // 2. Second half of conversation - // 3. Last message - - droppedCount := mid - keptConversation := conversation[mid:] - - newHistory := make([]providers.Message, 0, 1+len(keptConversation)+1) - - // Append compression note to the original system prompt instead of adding a new system message - // This avoids having two consecutive system messages which some APIs (like Zhipu) reject + // Record compression in the session summary so BuildMessages includes it + // in the system prompt. We do not modify history messages themselves. + existingSummary := agent.Sessions.GetSummary(sessionKey) compressionNote := fmt.Sprintf( - "\n\n[System Note: Emergency compression dropped %d oldest messages due to context limit]", + "[Emergency compression dropped %d oldest messages due to context limit]", droppedCount, ) - enhancedSystemPrompt := history[0] - enhancedSystemPrompt.Content = enhancedSystemPrompt.Content + compressionNote - newHistory = append(newHistory, enhancedSystemPrompt) + if existingSummary != "" { + compressionNote = existingSummary + "\n\n" + compressionNote + } + agent.Sessions.SetSummary(sessionKey, compressionNote) - newHistory = append(newHistory, keptConversation...) - newHistory = append(newHistory, history[len(history)-1]) // Last message - - // Update session - agent.Sessions.SetHistory(sessionKey, newHistory) + agent.Sessions.SetHistory(sessionKey, keptHistory) agent.Sessions.Save(sessionKey) logger.WarnCF("agent", "Forced compression executed", map[string]any{ "session_key": sessionKey, "dropped_msgs": droppedCount, - "new_count": len(newHistory), + "new_count": len(keptHistory), }) } @@ -1687,12 +1722,18 @@ func (al *AgentLoop) summarizeSession(agent *AgentInstance, sessionKey string) { history := agent.Sessions.GetHistory(sessionKey) summary := agent.Sessions.GetSummary(sessionKey) - // Keep last 4 messages for continuity + // Keep the most recent Turns for continuity, aligned to a Turn boundary + // so that no tool-call sequence is split. if len(history) <= 4 { return } - toSummarize := history[:len(history)-4] + safeCut := findSafeBoundary(history, len(history)-4) + if safeCut <= 0 { + return + } + keepCount := len(history) - safeCut + toSummarize := history[:safeCut] // Oversized Message Guard maxMessageTokens := agent.ContextWindow / 2 @@ -1757,7 +1798,7 @@ func (al *AgentLoop) summarizeSession(agent *AgentInstance, sessionKey string) { if finalSummary != "" { agent.Sessions.SetSummary(sessionKey, finalSummary) - agent.Sessions.TruncateHistory(sessionKey, 4) + agent.Sessions.TruncateHistory(sessionKey, keepCount) agent.Sessions.Save(sessionKey) } } @@ -1895,15 +1936,14 @@ func (al *AgentLoop) summarizeBatch( } // estimateTokens estimates the number of tokens in a message list. -// Uses a safe heuristic of 2.5 characters per token to account for CJK and other -// overheads better than the previous 3 chars/token. +// Counts Content, ToolCalls arguments, and ToolCallID metadata so that +// tool-heavy conversations are not systematically undercounted. func (al *AgentLoop) estimateTokens(messages []providers.Message) int { - totalChars := 0 + total := 0 for _, m := range messages { - totalChars += utf8.RuneCountInString(m.Content) + total += estimateMessageTokens(m) } - // 2.5 chars per token = totalChars * 2 / 5 - return totalChars * 2 / 5 + return total } func (al *AgentLoop) handleCommand( diff --git a/pkg/agent/loop_test.go b/pkg/agent/loop_test.go index a6604e87f..b65c0e21c 100644 --- a/pkg/agent/loop_test.go +++ b/pkg/agent/loop_test.go @@ -719,11 +719,11 @@ func TestAgentLoop_ContextExhaustionRetry(t *testing.T) { al := NewAgentLoop(cfg, msgBus, provider) - // Inject some history to simulate a full context + // Inject some history to simulate a full context. + // Session history only stores user/assistant/tool messages — the system + // prompt is built dynamically by BuildMessages and is NOT stored here. sessionKey := "test-session-context" - // Create dummy history history := []providers.Message{ - {Role: "system", Content: "System prompt"}, {Role: "user", Content: "Old message 1"}, {Role: "assistant", Content: "Old response 1"}, {Role: "user", Content: "Old message 2"}, @@ -761,12 +761,11 @@ func TestAgentLoop_ContextExhaustionRetry(t *testing.T) { // Check final history length finalHistory := defaultAgent.Sessions.GetHistory(sessionKey) // We verify that the history has been modified (compressed) - // Original length: 6 - // Expected behavior: compression drops ~50% of history (mid slice) - // We can assert that the length is NOT what it would be without compression. - // Without compression: 6 + 1 (new user msg) + 1 (assistant msg) = 8 - if len(finalHistory) >= 8 { - t.Errorf("Expected history to be compressed (len < 8), got %d", len(finalHistory)) + // Original length: 5 + // Expected behavior: compression drops ~50% of Turns + // Without compression: 5 + 1 (new user msg) + 1 (assistant msg) = 7 + if len(finalHistory) >= 7 { + t.Errorf("Expected history to be compressed (len < 7), got %d", len(finalHistory)) } } diff --git a/pkg/config/config.go b/pkg/config/config.go index a8b8f337f..a3720b656 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -228,6 +228,7 @@ type AgentDefaults struct { ImageModel string `json:"image_model,omitempty" env:"PICOCLAW_AGENTS_DEFAULTS_IMAGE_MODEL"` ImageModelFallbacks []string `json:"image_model_fallbacks,omitempty"` MaxTokens int `json:"max_tokens" env:"PICOCLAW_AGENTS_DEFAULTS_MAX_TOKENS"` + ContextWindow int `json:"context_window,omitempty" env:"PICOCLAW_AGENTS_DEFAULTS_CONTEXT_WINDOW"` Temperature *float64 `json:"temperature,omitempty" env:"PICOCLAW_AGENTS_DEFAULTS_TEMPERATURE"` MaxToolIterations int `json:"max_tool_iterations" env:"PICOCLAW_AGENTS_DEFAULTS_MAX_TOOL_ITERATIONS"` SummarizeMessageThreshold int `json:"summarize_message_threshold" env:"PICOCLAW_AGENTS_DEFAULTS_SUMMARIZE_MESSAGE_THRESHOLD"` diff --git a/web/frontend/src/components/config/config-page.tsx b/web/frontend/src/components/config/config-page.tsx index cbce7d27e..dc6797749 100644 --- a/web/frontend/src/components/config/config-page.tsx +++ b/web/frontend/src/components/config/config-page.tsx @@ -144,6 +144,9 @@ export function ConfigPage() { const maxTokens = parseIntField(form.maxTokens, "Max tokens", { min: 1, }) + const contextWindow = form.contextWindow.trim() + ? parseIntField(form.contextWindow, "Context window", { min: 1 }) + : undefined const maxToolIterations = parseIntField( form.maxToolIterations, "Max tool iterations", @@ -171,6 +174,7 @@ export function ConfigPage() { workspace, restrict_to_workspace: form.restrictToWorkspace, max_tokens: maxTokens, + context_window: contextWindow, max_tool_iterations: maxToolIterations, summarize_message_threshold: summarizeMessageThreshold, summarize_token_percent: summarizeTokenPercent, diff --git a/web/frontend/src/components/config/config-sections.tsx b/web/frontend/src/components/config/config-sections.tsx index dfbe22fc3..825d882b7 100644 --- a/web/frontend/src/components/config/config-sections.tsx +++ b/web/frontend/src/components/config/config-sections.tsx @@ -114,6 +114,20 @@ export function AgentDefaultsSection({ /> + + onFieldChange("contextWindow", e.target.value)} + placeholder="131072" + /> + +