mirror of
https://github.com/sipeed/picoclaw.git
synced 2026-06-12 18:08:54 +00:00
refactor(agent): context boundary detection, proactive budget check, and safe compression
Separate context_window from max_tokens — they serve different purposes (input capacity vs output generation limit). The previous conflation caused premature summarization or missed compression triggers. Changes: - Add context_window field to AgentDefaults config (default: 4x max_tokens) - Extract boundary-safe truncation helpers (isSafeBoundary, findSafeBoundary) into context_budget.go — pure functions with no AgentLoop dependency - forceCompression: align split to safe boundary so tool-call sequences (assistant+ToolCalls → tool results) are never torn apart - summarizeSession: use findSafeBoundary instead of hardcoded keep-last-4 - estimateTokens: count ToolCalls arguments and ToolCallID metadata, not just Content — fixes systematic undercounting in tool-heavy sessions - Add proactive context budget check before LLM call in runAgentLoop, preventing 400 context-length errors instead of reacting to them - Add estimateToolDefsTokens for tool definition token cost Closes #556, closes #665 Ref #1439
This commit is contained in:
@@ -0,0 +1,133 @@
|
||||
// PicoClaw - Ultra-lightweight personal AI agent
|
||||
// License: MIT
|
||||
//
|
||||
// Copyright (c) 2026 PicoClaw contributors
|
||||
|
||||
package agent
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"unicode/utf8"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/providers"
|
||||
)
|
||||
|
||||
// isSafeBoundary reports whether index is a valid position to split a message
|
||||
// history for truncation or compression. Splitting at index means:
|
||||
// - history[:index] is dropped or summarized
|
||||
// - history[index:] is kept
|
||||
//
|
||||
// A boundary is safe when the kept portion begins at a "user" message,
|
||||
// ensuring no tool-call sequence (assistant+ToolCalls → tool results)
|
||||
// is torn apart across the split.
|
||||
func isSafeBoundary(history []providers.Message, index int) bool {
|
||||
if index <= 0 || index >= len(history) {
|
||||
return true
|
||||
}
|
||||
return history[index].Role == "user"
|
||||
}
|
||||
|
||||
// findSafeBoundary locates the nearest safe split point to targetIndex.
|
||||
// It scans backward first (preserving more context), then forward.
|
||||
// Returns targetIndex unchanged only when no safe boundary exists.
|
||||
func findSafeBoundary(history []providers.Message, targetIndex int) int {
|
||||
if len(history) == 0 {
|
||||
return 0
|
||||
}
|
||||
if targetIndex <= 0 {
|
||||
return 0
|
||||
}
|
||||
if targetIndex >= len(history) {
|
||||
return len(history)
|
||||
}
|
||||
|
||||
if isSafeBoundary(history, targetIndex) {
|
||||
return targetIndex
|
||||
}
|
||||
|
||||
// Backward scan: prefer keeping more messages.
|
||||
for i := targetIndex - 1; i > 0; i-- {
|
||||
if isSafeBoundary(history, i) {
|
||||
return i
|
||||
}
|
||||
}
|
||||
|
||||
// Forward scan: fall back to keeping fewer messages.
|
||||
for i := targetIndex + 1; i < len(history); i++ {
|
||||
if isSafeBoundary(history, i) {
|
||||
return i
|
||||
}
|
||||
}
|
||||
|
||||
return targetIndex
|
||||
}
|
||||
|
||||
// estimateMessageTokens estimates the token count for a single message,
|
||||
// including Content, ToolCalls arguments, and ToolCallID metadata.
|
||||
// Uses a heuristic of 2.5 characters per token.
|
||||
func estimateMessageTokens(msg providers.Message) int {
|
||||
chars := utf8.RuneCountInString(msg.Content)
|
||||
|
||||
for _, tc := range msg.ToolCalls {
|
||||
// Count tool call metadata: ID, type, function name
|
||||
chars += len(tc.ID) + len(tc.Type) + len(tc.Name)
|
||||
if tc.Function != nil {
|
||||
chars += len(tc.Function.Name) + len(tc.Function.Arguments)
|
||||
}
|
||||
}
|
||||
|
||||
if msg.ToolCallID != "" {
|
||||
chars += len(msg.ToolCallID)
|
||||
}
|
||||
|
||||
// Per-message overhead for role label, JSON structure, separators.
|
||||
const messageOverhead = 12
|
||||
chars += messageOverhead
|
||||
|
||||
return chars * 2 / 5
|
||||
}
|
||||
|
||||
// estimateToolDefsTokens estimates the total token cost of tool definitions
|
||||
// as they appear in the LLM request. Each tool's name, description, and
|
||||
// JSON schema parameters contribute to the context window budget.
|
||||
func estimateToolDefsTokens(defs []providers.ToolDefinition) int {
|
||||
if len(defs) == 0 {
|
||||
return 0
|
||||
}
|
||||
|
||||
totalChars := 0
|
||||
for _, d := range defs {
|
||||
totalChars += len(d.Function.Name) + len(d.Function.Description)
|
||||
|
||||
if d.Function.Parameters != nil {
|
||||
if paramJSON, err := json.Marshal(d.Function.Parameters); err == nil {
|
||||
totalChars += len(paramJSON)
|
||||
}
|
||||
}
|
||||
|
||||
// Per-tool overhead: type field, JSON structure, separators.
|
||||
totalChars += 20
|
||||
}
|
||||
|
||||
return totalChars * 2 / 5
|
||||
}
|
||||
|
||||
// isOverContextBudget checks whether the assembled messages plus tool definitions
|
||||
// and output reserve would exceed the model's context window. This enables
|
||||
// proactive compression before calling the LLM, rather than reacting to 400 errors.
|
||||
func isOverContextBudget(
|
||||
contextWindow int,
|
||||
messages []providers.Message,
|
||||
toolDefs []providers.ToolDefinition,
|
||||
maxTokens int,
|
||||
) bool {
|
||||
msgTokens := 0
|
||||
for _, m := range messages {
|
||||
msgTokens += estimateMessageTokens(m)
|
||||
}
|
||||
|
||||
toolTokens := estimateToolDefsTokens(toolDefs)
|
||||
total := msgTokens + toolTokens + maxTokens
|
||||
|
||||
return total > contextWindow
|
||||
}
|
||||
@@ -0,0 +1,545 @@
|
||||
package agent
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/providers"
|
||||
)
|
||||
|
||||
// msgUser creates a user message.
|
||||
func msgUser(content string) providers.Message {
|
||||
return providers.Message{Role: "user", Content: content}
|
||||
}
|
||||
|
||||
// msgAssistant creates a plain assistant message (no tool calls).
|
||||
func msgAssistant(content string) providers.Message {
|
||||
return providers.Message{Role: "assistant", Content: content}
|
||||
}
|
||||
|
||||
// msgAssistantTC creates an assistant message with tool calls.
|
||||
func msgAssistantTC(toolIDs ...string) providers.Message {
|
||||
tcs := make([]providers.ToolCall, len(toolIDs))
|
||||
for i, id := range toolIDs {
|
||||
tcs[i] = providers.ToolCall{
|
||||
ID: id,
|
||||
Type: "function",
|
||||
Name: "tool_" + id,
|
||||
Function: &providers.FunctionCall{
|
||||
Name: "tool_" + id,
|
||||
Arguments: `{"key":"value"}`,
|
||||
},
|
||||
}
|
||||
}
|
||||
return providers.Message{Role: "assistant", ToolCalls: tcs}
|
||||
}
|
||||
|
||||
// msgTool creates a tool result message.
|
||||
func msgTool(callID, content string) providers.Message {
|
||||
return providers.Message{Role: "tool", ToolCallID: callID, Content: content}
|
||||
}
|
||||
|
||||
func TestIsSafeBoundary(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
history []providers.Message
|
||||
index int
|
||||
want bool
|
||||
}{
|
||||
{
|
||||
name: "empty history, index 0",
|
||||
history: nil,
|
||||
index: 0,
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "single user message, index 0",
|
||||
history: []providers.Message{msgUser("hi")},
|
||||
index: 0,
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "single user message, index 1 (end)",
|
||||
history: []providers.Message{msgUser("hi")},
|
||||
index: 1,
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "at user message",
|
||||
history: []providers.Message{
|
||||
msgAssistant("hello"),
|
||||
msgUser("how are you"),
|
||||
msgAssistant("fine"),
|
||||
},
|
||||
index: 1,
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "at assistant without tool calls",
|
||||
history: []providers.Message{
|
||||
msgUser("hello"),
|
||||
msgAssistant("response"),
|
||||
msgUser("follow up"),
|
||||
},
|
||||
index: 1,
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "at assistant with tool calls",
|
||||
history: []providers.Message{
|
||||
msgUser("search something"),
|
||||
msgAssistantTC("tc1"),
|
||||
msgTool("tc1", "result"),
|
||||
msgAssistant("here is what I found"),
|
||||
},
|
||||
index: 1,
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "at tool result",
|
||||
history: []providers.Message{
|
||||
msgUser("do something"),
|
||||
msgAssistantTC("tc1"),
|
||||
msgTool("tc1", "done"),
|
||||
msgAssistant("completed"),
|
||||
},
|
||||
index: 2,
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "negative index",
|
||||
history: []providers.Message{
|
||||
msgUser("hello"),
|
||||
},
|
||||
index: -1,
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "index beyond length",
|
||||
history: []providers.Message{
|
||||
msgUser("hello"),
|
||||
},
|
||||
index: 5,
|
||||
want: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := isSafeBoundary(tt.history, tt.index)
|
||||
if got != tt.want {
|
||||
t.Errorf("isSafeBoundary(history, %d) = %v, want %v", tt.index, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestFindSafeBoundary(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
history []providers.Message
|
||||
targetIndex int
|
||||
want int
|
||||
}{
|
||||
{
|
||||
name: "empty history",
|
||||
history: nil,
|
||||
targetIndex: 0,
|
||||
want: 0,
|
||||
},
|
||||
{
|
||||
name: "target at 0",
|
||||
history: []providers.Message{msgUser("hi")},
|
||||
targetIndex: 0,
|
||||
want: 0,
|
||||
},
|
||||
{
|
||||
name: "target beyond length",
|
||||
history: []providers.Message{msgUser("hi")},
|
||||
targetIndex: 5,
|
||||
want: 1,
|
||||
},
|
||||
{
|
||||
name: "target already at user message",
|
||||
history: []providers.Message{
|
||||
msgUser("q1"),
|
||||
msgAssistant("a1"),
|
||||
msgUser("q2"),
|
||||
msgAssistant("a2"),
|
||||
},
|
||||
targetIndex: 2,
|
||||
want: 2,
|
||||
},
|
||||
{
|
||||
name: "target at assistant, scan backward finds user",
|
||||
history: []providers.Message{
|
||||
msgUser("q1"),
|
||||
msgAssistant("a1"),
|
||||
msgUser("q2"),
|
||||
msgAssistant("a2"),
|
||||
msgUser("q3"),
|
||||
},
|
||||
targetIndex: 3, // assistant "a2"
|
||||
want: 2, // backward to user "q2"
|
||||
},
|
||||
{
|
||||
name: "target inside tool sequence, scan backward finds user",
|
||||
history: []providers.Message{
|
||||
msgUser("q1"),
|
||||
msgAssistant("a1"),
|
||||
msgUser("q2"),
|
||||
msgAssistantTC("tc1", "tc2"),
|
||||
msgTool("tc1", "r1"),
|
||||
msgTool("tc2", "r2"),
|
||||
msgAssistant("summary"),
|
||||
msgUser("q3"),
|
||||
},
|
||||
targetIndex: 4, // tool result "r1"
|
||||
want: 2, // backward: 3=assistant+TC (not safe), 2=user → safe
|
||||
},
|
||||
{
|
||||
name: "target inside tool sequence, backward finds user before chain",
|
||||
history: []providers.Message{
|
||||
msgUser("q1"),
|
||||
msgAssistant("a1"),
|
||||
msgUser("q2"),
|
||||
msgAssistantTC("tc1", "tc2"),
|
||||
msgTool("tc1", "r1"),
|
||||
msgTool("tc2", "r2"),
|
||||
msgAssistant("summary"),
|
||||
msgUser("q3"),
|
||||
},
|
||||
targetIndex: 5, // tool result "r2"
|
||||
want: 2, // backward: 4=tool, 3=assistant+TC, 2=user → safe
|
||||
},
|
||||
{
|
||||
name: "no backward user, scan forward finds one",
|
||||
history: []providers.Message{
|
||||
msgAssistantTC("tc1"),
|
||||
msgTool("tc1", "r1"),
|
||||
msgAssistant("a1"),
|
||||
msgUser("q1"),
|
||||
},
|
||||
targetIndex: 1, // tool result
|
||||
want: 3, // forward to user "q1"
|
||||
},
|
||||
{
|
||||
name: "multi-step tool chain preserves atomicity",
|
||||
history: []providers.Message{
|
||||
msgUser("q1"),
|
||||
msgAssistant("a1"),
|
||||
msgUser("q2"),
|
||||
msgAssistantTC("tc1"),
|
||||
msgTool("tc1", "r1"),
|
||||
msgAssistantTC("tc2"),
|
||||
msgTool("tc2", "r2"),
|
||||
msgAssistant("final"),
|
||||
msgUser("q3"),
|
||||
msgAssistant("a3"),
|
||||
},
|
||||
targetIndex: 5, // second assistant+TC
|
||||
want: 2, // backward: 4=tool, 3=assistant+TC, 2=user → safe
|
||||
},
|
||||
{
|
||||
name: "all non-user messages returns target unchanged",
|
||||
history: []providers.Message{
|
||||
msgAssistant("a1"),
|
||||
msgAssistant("a2"),
|
||||
msgAssistant("a3"),
|
||||
},
|
||||
targetIndex: 1,
|
||||
want: 1,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := findSafeBoundary(tt.history, tt.targetIndex)
|
||||
if got != tt.want {
|
||||
t.Errorf("findSafeBoundary(history, %d) = %d, want %d",
|
||||
tt.targetIndex, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestFindSafeBoundary_BackwardScanSkipsToolSequence(t *testing.T) {
|
||||
// A long tool-call chain: user → assistant+TC → tool → tool → ... → assistant → user
|
||||
// Target is inside the chain; boundary should skip the entire chain backward.
|
||||
history := []providers.Message{
|
||||
msgUser("start"), // 0
|
||||
msgAssistant("before chain"), // 1
|
||||
msgUser("trigger"), // 2 ← expected safe boundary
|
||||
msgAssistantTC("t1", "t2", "t3"), // 3
|
||||
msgTool("t1", "r1"), // 4
|
||||
msgTool("t2", "r2"), // 5
|
||||
msgTool("t3", "r3"), // 6
|
||||
msgAssistantTC("t4"), // 7
|
||||
msgTool("t4", "r4"), // 8
|
||||
msgAssistant("chain done"), // 9
|
||||
msgUser("next"), // 10
|
||||
}
|
||||
|
||||
// Target at index 6 (middle of tool results)
|
||||
got := findSafeBoundary(history, 6)
|
||||
if got != 2 {
|
||||
t.Errorf("findSafeBoundary(history, 6) = %d, want 2 (user before chain)", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEstimateMessageTokens(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
msg providers.Message
|
||||
want int // minimum expected tokens (exact value depends on overhead)
|
||||
}{
|
||||
{
|
||||
name: "plain user message",
|
||||
msg: msgUser("Hello, world!"),
|
||||
want: 1, // at least some tokens
|
||||
},
|
||||
{
|
||||
name: "empty message still has overhead",
|
||||
msg: providers.Message{Role: "user"},
|
||||
want: 1, // message overhead alone
|
||||
},
|
||||
{
|
||||
name: "assistant with tool calls",
|
||||
msg: msgAssistantTC("tc_123"),
|
||||
want: 1,
|
||||
},
|
||||
{
|
||||
name: "tool result with ID",
|
||||
msg: msgTool("call_abc", "Here is the search result with lots of content"),
|
||||
want: 1,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := estimateMessageTokens(tt.msg)
|
||||
if got < tt.want {
|
||||
t.Errorf("estimateMessageTokens() = %d, want >= %d", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestEstimateMessageTokens_ToolCallsContribute(t *testing.T) {
|
||||
plain := msgAssistant("thinking")
|
||||
withTC := providers.Message{
|
||||
Role: "assistant",
|
||||
Content: "thinking",
|
||||
ToolCalls: []providers.ToolCall{
|
||||
{
|
||||
ID: "call_1",
|
||||
Type: "function",
|
||||
Name: "web_search",
|
||||
Function: &providers.FunctionCall{
|
||||
Name: "web_search",
|
||||
Arguments: `{"query":"picoclaw agent framework","max_results":5}`,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
plainTokens := estimateMessageTokens(plain)
|
||||
withTCTokens := estimateMessageTokens(withTC)
|
||||
|
||||
if withTCTokens <= plainTokens {
|
||||
t.Errorf("message with ToolCalls (%d tokens) should exceed plain message (%d tokens)",
|
||||
withTCTokens, plainTokens)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEstimateMessageTokens_MultibyteContent(t *testing.T) {
|
||||
// Multi-byte characters (e.g. emoji, accented letters) are single runes
|
||||
// but may map to different token counts. The heuristic should still produce
|
||||
// reasonable estimates via RuneCountInString.
|
||||
msg := msgUser("caf\u00e9 na\u00efve r\u00e9sum\u00e9 \u00fcber stra\u00dfe")
|
||||
tokens := estimateMessageTokens(msg)
|
||||
if tokens <= 0 {
|
||||
t.Errorf("multibyte message should produce positive token count, got %d", tokens)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEstimateMessageTokens_LargeArguments(t *testing.T) {
|
||||
// Simulate a tool call with large JSON arguments.
|
||||
largeArgs := fmt.Sprintf(`{"content":"%s"}`, strings.Repeat("x", 5000))
|
||||
msg := providers.Message{
|
||||
Role: "assistant",
|
||||
ToolCalls: []providers.ToolCall{
|
||||
{
|
||||
ID: "call_large",
|
||||
Type: "function",
|
||||
Name: "write_file",
|
||||
Function: &providers.FunctionCall{
|
||||
Name: "write_file",
|
||||
Arguments: largeArgs,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
tokens := estimateMessageTokens(msg)
|
||||
// 5000+ chars → at least 2000 tokens with the 2.5 char/token heuristic
|
||||
if tokens < 2000 {
|
||||
t.Errorf("large tool call arguments should produce significant token count, got %d", tokens)
|
||||
}
|
||||
}
|
||||
|
||||
// --- estimateToolDefsTokens tests ---
|
||||
|
||||
func TestEstimateToolDefsTokens(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
defs []providers.ToolDefinition
|
||||
want int // minimum expected tokens
|
||||
}{
|
||||
{
|
||||
name: "empty tool list",
|
||||
defs: nil,
|
||||
want: 0,
|
||||
},
|
||||
{
|
||||
name: "single tool with params",
|
||||
defs: []providers.ToolDefinition{
|
||||
{
|
||||
Type: "function",
|
||||
Function: providers.ToolFunctionDefinition{
|
||||
Name: "web_search",
|
||||
Description: "Search the web for information",
|
||||
Parameters: map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]any{
|
||||
"query": map[string]any{"type": "string"},
|
||||
},
|
||||
"required": []any{"query"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
want: 1,
|
||||
},
|
||||
{
|
||||
name: "tool without params",
|
||||
defs: []providers.ToolDefinition{
|
||||
{
|
||||
Type: "function",
|
||||
Function: providers.ToolFunctionDefinition{
|
||||
Name: "list_dir",
|
||||
Description: "List directory contents",
|
||||
},
|
||||
},
|
||||
},
|
||||
want: 1,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := estimateToolDefsTokens(tt.defs)
|
||||
if got < tt.want {
|
||||
t.Errorf("estimateToolDefsTokens() = %d, want >= %d", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestEstimateToolDefsTokens_ScalesWithCount(t *testing.T) {
|
||||
makeTool := func(name string) providers.ToolDefinition {
|
||||
return providers.ToolDefinition{
|
||||
Type: "function",
|
||||
Function: providers.ToolFunctionDefinition{
|
||||
Name: name,
|
||||
Description: "A test tool that does something useful",
|
||||
Parameters: map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]any{
|
||||
"input": map[string]any{"type": "string", "description": "Input value"},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
one := estimateToolDefsTokens([]providers.ToolDefinition{makeTool("tool_a")})
|
||||
three := estimateToolDefsTokens([]providers.ToolDefinition{
|
||||
makeTool("tool_a"), makeTool("tool_b"), makeTool("tool_c"),
|
||||
})
|
||||
|
||||
if three <= one {
|
||||
t.Errorf("3 tools (%d tokens) should exceed 1 tool (%d tokens)", three, one)
|
||||
}
|
||||
}
|
||||
|
||||
// --- isOverContextBudget tests ---
|
||||
|
||||
func TestIsOverContextBudget(t *testing.T) {
|
||||
systemMsg := providers.Message{Role: "system", Content: strings.Repeat("x", 1000)}
|
||||
userMsg := msgUser("hello")
|
||||
smallHistory := []providers.Message{systemMsg, msgUser("q1"), msgAssistant("a1"), userMsg}
|
||||
|
||||
tools := []providers.ToolDefinition{
|
||||
{
|
||||
Type: "function",
|
||||
Function: providers.ToolFunctionDefinition{
|
||||
Name: "test_tool",
|
||||
Description: "A test tool",
|
||||
Parameters: map[string]any{"type": "object"},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
contextWindow int
|
||||
messages []providers.Message
|
||||
toolDefs []providers.ToolDefinition
|
||||
maxTokens int
|
||||
want bool
|
||||
}{
|
||||
{
|
||||
name: "within budget",
|
||||
contextWindow: 100000,
|
||||
messages: smallHistory,
|
||||
toolDefs: tools,
|
||||
maxTokens: 4096,
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "over budget with small window",
|
||||
contextWindow: 100, // very small window
|
||||
messages: smallHistory,
|
||||
toolDefs: tools,
|
||||
maxTokens: 4096,
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "large max_tokens eats budget",
|
||||
contextWindow: 2000,
|
||||
messages: smallHistory,
|
||||
toolDefs: tools,
|
||||
maxTokens: 1800, // leaves almost no room
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "empty messages within budget",
|
||||
contextWindow: 10000,
|
||||
messages: nil,
|
||||
toolDefs: nil,
|
||||
maxTokens: 4096,
|
||||
want: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := isOverContextBudget(tt.contextWindow, tt.messages, tt.toolDefs, tt.maxTokens)
|
||||
if got != tt.want {
|
||||
t.Errorf("isOverContextBudget() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
+12
-1
@@ -127,6 +127,17 @@ func NewAgentInstance(
|
||||
maxTokens = 8192
|
||||
}
|
||||
|
||||
contextWindow := defaults.ContextWindow
|
||||
if contextWindow == 0 {
|
||||
// Default heuristic: 4x the output token limit.
|
||||
// Most models have context windows well above their output limits
|
||||
// (e.g., GPT-4o 128k ctx / 16k out, Claude 200k ctx / 8k out).
|
||||
// 4x is a conservative lower bound that avoids premature
|
||||
// summarization while remaining safe — the reactive
|
||||
// forceCompression handles any overshoot.
|
||||
contextWindow = maxTokens * 4
|
||||
}
|
||||
|
||||
temperature := 0.7
|
||||
if defaults.Temperature != nil {
|
||||
temperature = *defaults.Temperature
|
||||
@@ -224,7 +235,7 @@ func NewAgentInstance(
|
||||
MaxTokens: maxTokens,
|
||||
Temperature: temperature,
|
||||
ThinkingLevel: thinkingLevel,
|
||||
ContextWindow: maxTokens,
|
||||
ContextWindow: contextWindow,
|
||||
SummarizeMessageThreshold: summarizeMessageThreshold,
|
||||
SummarizeTokenPercent: summarizeTokenPercent,
|
||||
Provider: provider,
|
||||
|
||||
+36
-13
@@ -17,7 +17,6 @@ import (
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
"unicode/utf8"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/bus"
|
||||
"github.com/sipeed/picoclaw/pkg/channels"
|
||||
@@ -931,6 +930,24 @@ func (al *AgentLoop) runAgentLoop(
|
||||
maxMediaSize := cfg.Agents.Defaults.GetMaxMediaSize()
|
||||
messages = resolveMediaRefs(messages, al.mediaStore, maxMediaSize)
|
||||
|
||||
// 1.5. Proactive context budget check: compress before LLM call
|
||||
// rather than waiting for a 400 context-length error.
|
||||
if !opts.NoHistory {
|
||||
toolDefs := agent.Tools.ToProviderDefs()
|
||||
if isOverContextBudget(agent.ContextWindow, messages, toolDefs, agent.MaxTokens) {
|
||||
logger.WarnCF("agent", "Proactive compression: context budget exceeded before LLM call",
|
||||
map[string]any{"session_key": opts.SessionKey})
|
||||
al.forceCompression(agent, opts.SessionKey)
|
||||
newHistory := agent.Sessions.GetHistory(opts.SessionKey)
|
||||
newSummary := agent.Sessions.GetSummary(opts.SessionKey)
|
||||
messages = agent.ContextBuilder.BuildMessages(
|
||||
newHistory, newSummary, opts.UserMessage,
|
||||
opts.Media, opts.Channel, opts.ChatID,
|
||||
)
|
||||
messages = resolveMediaRefs(messages, al.mediaStore, maxMediaSize)
|
||||
}
|
||||
}
|
||||
|
||||
// 2. Save user message to session
|
||||
agent.Sessions.AddMessage(opts.SessionKey, "user", opts.UserMessage)
|
||||
|
||||
@@ -1539,7 +1556,8 @@ func (al *AgentLoop) maybeSummarize(agent *AgentInstance, sessionKey, channel, c
|
||||
}
|
||||
|
||||
// forceCompression aggressively reduces context when the limit is hit.
|
||||
// It drops the oldest 50% of messages (keeping system prompt and last user message).
|
||||
// It drops the oldest ~50% of messages (keeping system prompt and last user message),
|
||||
// aligning the split to a safe boundary so tool-call sequences stay intact.
|
||||
func (al *AgentLoop) forceCompression(agent *AgentInstance, sessionKey string) {
|
||||
history := agent.Sessions.GetHistory(sessionKey)
|
||||
if len(history) <= 4 {
|
||||
@@ -1554,8 +1572,8 @@ func (al *AgentLoop) forceCompression(agent *AgentInstance, sessionKey string) {
|
||||
return
|
||||
}
|
||||
|
||||
// Helper to find the mid-point of the conversation
|
||||
mid := len(conversation) / 2
|
||||
// Find a safe mid-point that does not split a tool-call sequence.
|
||||
mid := findSafeBoundary(conversation, len(conversation)/2)
|
||||
|
||||
// New history structure:
|
||||
// 1. System Prompt (with compression note appended)
|
||||
@@ -1687,12 +1705,18 @@ func (al *AgentLoop) summarizeSession(agent *AgentInstance, sessionKey string) {
|
||||
history := agent.Sessions.GetHistory(sessionKey)
|
||||
summary := agent.Sessions.GetSummary(sessionKey)
|
||||
|
||||
// Keep last 4 messages for continuity
|
||||
// Keep last few messages for continuity, aligned to a safe boundary
|
||||
// so that no tool-call sequence is split.
|
||||
if len(history) <= 4 {
|
||||
return
|
||||
}
|
||||
|
||||
toSummarize := history[:len(history)-4]
|
||||
safeCut := findSafeBoundary(history, len(history)-4)
|
||||
if safeCut <= 0 {
|
||||
return
|
||||
}
|
||||
keepCount := len(history) - safeCut
|
||||
toSummarize := history[:safeCut]
|
||||
|
||||
// Oversized Message Guard
|
||||
maxMessageTokens := agent.ContextWindow / 2
|
||||
@@ -1757,7 +1781,7 @@ func (al *AgentLoop) summarizeSession(agent *AgentInstance, sessionKey string) {
|
||||
|
||||
if finalSummary != "" {
|
||||
agent.Sessions.SetSummary(sessionKey, finalSummary)
|
||||
agent.Sessions.TruncateHistory(sessionKey, 4)
|
||||
agent.Sessions.TruncateHistory(sessionKey, keepCount)
|
||||
agent.Sessions.Save(sessionKey)
|
||||
}
|
||||
}
|
||||
@@ -1895,15 +1919,14 @@ func (al *AgentLoop) summarizeBatch(
|
||||
}
|
||||
|
||||
// estimateTokens estimates the number of tokens in a message list.
|
||||
// Uses a safe heuristic of 2.5 characters per token to account for CJK and other
|
||||
// overheads better than the previous 3 chars/token.
|
||||
// Counts Content, ToolCalls arguments, and ToolCallID metadata so that
|
||||
// tool-heavy conversations are not systematically undercounted.
|
||||
func (al *AgentLoop) estimateTokens(messages []providers.Message) int {
|
||||
totalChars := 0
|
||||
total := 0
|
||||
for _, m := range messages {
|
||||
totalChars += utf8.RuneCountInString(m.Content)
|
||||
total += estimateMessageTokens(m)
|
||||
}
|
||||
// 2.5 chars per token = totalChars * 2 / 5
|
||||
return totalChars * 2 / 5
|
||||
return total
|
||||
}
|
||||
|
||||
func (al *AgentLoop) handleCommand(
|
||||
|
||||
@@ -228,6 +228,7 @@ type AgentDefaults struct {
|
||||
ImageModel string `json:"image_model,omitempty" env:"PICOCLAW_AGENTS_DEFAULTS_IMAGE_MODEL"`
|
||||
ImageModelFallbacks []string `json:"image_model_fallbacks,omitempty"`
|
||||
MaxTokens int `json:"max_tokens" env:"PICOCLAW_AGENTS_DEFAULTS_MAX_TOKENS"`
|
||||
ContextWindow int `json:"context_window,omitempty" env:"PICOCLAW_AGENTS_DEFAULTS_CONTEXT_WINDOW"`
|
||||
Temperature *float64 `json:"temperature,omitempty" env:"PICOCLAW_AGENTS_DEFAULTS_TEMPERATURE"`
|
||||
MaxToolIterations int `json:"max_tool_iterations" env:"PICOCLAW_AGENTS_DEFAULTS_MAX_TOOL_ITERATIONS"`
|
||||
SummarizeMessageThreshold int `json:"summarize_message_threshold" env:"PICOCLAW_AGENTS_DEFAULTS_SUMMARIZE_MESSAGE_THRESHOLD"`
|
||||
|
||||
Reference in New Issue
Block a user