merge: resolve conflicts with upstream/main

Accept upstream versions for all non-Telegram files to keep PR
scope focused on Telegram message chunking only.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
I Putu Eddy Irawan
2026-03-04 22:26:51 +07:00
33 changed files with 2406 additions and 1002 deletions
+1 -1
View File
@@ -17,4 +17,4 @@
# BRAVE_SEARCH_API_KEY=BSA...
# ── Timezone ──────────────────────────────
TZ=Asia/Tokyo
TZ=Asia/Shanghai
+1 -1
View File
@@ -54,7 +54,7 @@
## 📢 News
2026-02-16 🎉 PicoClaw hit 12K stars in one week! Thank you all for your support! PicoClaw is growing faster than we ever imagined. Given the high volume of PRs, we urgently need community maintainers. Our volunteer roles and roadmap are officially posted [here](docs/ROADMAP.md) —we cant wait to have you on board!
2026-02-16 🎉 PicoClaw hit 12K stars in one week! Thank you all for your support! PicoClaw is growing faster than we ever imagined. Given the high volume of PRs, we urgently need community maintainers. Our volunteer roles and roadmap are officially posted [here](ROADMAP.md) —we cant wait to have you on board!
2026-02-13 🎉 PicoClaw hit 5000 stars in 4days! Thank you for the community! There are so many PRs & issues coming in (during Chinese New Year holidays), we are finalizing the Project Roadmap and setting up the Developer Group to accelerate PicoClaw's development.
🚀 Call to Action: Please submit your feature requests in GitHub Discussions. We will review and prioritize them during our upcoming weekly meeting.
BIN
View File
Binary file not shown.

Before

Width:  |  Height:  |  Size: 140 KiB

After

Width:  |  Height:  |  Size: 96 KiB

+5 -2
View File
@@ -6,7 +6,9 @@
"model_name": "gpt4",
"max_tokens": 8192,
"temperature": 0.7,
"max_tool_iterations": 20
"max_tool_iterations": 20,
"summarize_message_threshold": 20,
"summarize_token_percent": 75
}
},
"model_list": [
@@ -59,6 +61,7 @@
"discord": {
"enabled": false,
"token": "YOUR_DISCORD_BOT_TOKEN",
"proxy": "",
"allow_from": [],
"group_trigger": {
"mention_only": false
@@ -337,4 +340,4 @@
"host": "127.0.0.1",
"port": 18790
}
}
}
+1 -1
View File
@@ -11,7 +11,6 @@ require (
github.com/gdamore/tcell/v2 v2.13.8
github.com/google/uuid v1.6.0
github.com/gorilla/websocket v1.5.3
github.com/joho/godotenv v1.5.1
github.com/larksuite/oapi-sdk-go/v3 v3.5.3
github.com/mdp/qrterminal/v3 v3.2.1
github.com/modelcontextprotocol/go-sdk v1.3.0
@@ -38,6 +37,7 @@ require (
github.com/dustin/go-humanize v1.0.1 // indirect
github.com/elliotchance/orderedmap/v3 v3.1.0 // indirect
github.com/gdamore/encoding v1.0.1 // indirect
github.com/gdamore/tcell/v2 v2.13.8 // indirect
github.com/h2non/filetype v1.1.3 // indirect
github.com/inconshreveable/mousetrap v1.1.0 // indirect
github.com/lucasb-eyer/go-colorful v1.3.0 // indirect
-2
View File
@@ -105,8 +105,6 @@ github.com/hashicorp/golang-lru/v2 v2.0.7/go.mod h1:QeFd9opnmA6QUJc5vARoKUSoFhyf
github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU=
github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8=
github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw=
github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0=
github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4=
github.com/kisielk/errcheck v1.5.0/go.mod h1:pFxgyoBC7bSaBwPgfKdkLd5X25qrDl4LWUI2bnpBCr8=
github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck=
github.com/klauspost/compress v1.18.4 h1:RPhnKRAQ4Fh8zU2FY/6ZFDwTVTxgJ/EMydqSTzE9a2c=
+46 -32
View File
@@ -18,22 +18,24 @@ import (
// AgentInstance represents a fully configured agent with its own workspace,
// session manager, context builder, and tool registry.
type AgentInstance struct {
ID string
Name string
Model string
Fallbacks []string
Workspace string
MaxIterations int
MaxTokens int
Temperature float64
ContextWindow int
Provider providers.LLMProvider
Sessions *session.SessionManager
ContextBuilder *ContextBuilder
Tools *tools.ToolRegistry
Subagents *config.SubagentsConfig
SkillsFilter []string
Candidates []providers.FallbackCandidate
ID string
Name string
Model string
Fallbacks []string
Workspace string
MaxIterations int
MaxTokens int
Temperature float64
ContextWindow int
SummarizeMessageThreshold int
SummarizeTokenPercent int
Provider providers.LLMProvider
Sessions *session.SessionManager
ContextBuilder *ContextBuilder
Tools *tools.ToolRegistry
Subagents *config.SubagentsConfig
SkillsFilter []string
Candidates []providers.FallbackCandidate
}
// NewAgentInstance creates an agent instance from config.
@@ -101,6 +103,16 @@ func NewAgentInstance(
temperature = *defaults.Temperature
}
summarizeMessageThreshold := defaults.SummarizeMessageThreshold
if summarizeMessageThreshold == 0 {
summarizeMessageThreshold = 20
}
summarizeTokenPercent := defaults.SummarizeTokenPercent
if summarizeTokenPercent == 0 {
summarizeTokenPercent = 75
}
// Resolve fallback candidates
modelCfg := providers.ModelConfig{
Primary: model,
@@ -149,22 +161,24 @@ func NewAgentInstance(
candidates := providers.ResolveCandidatesWithLookup(modelCfg, defaults.Provider, resolveFromModelList)
return &AgentInstance{
ID: agentID,
Name: agentName,
Model: model,
Fallbacks: fallbacks,
Workspace: workspace,
MaxIterations: maxIter,
MaxTokens: maxTokens,
Temperature: temperature,
ContextWindow: maxTokens,
Provider: provider,
Sessions: sessionsManager,
ContextBuilder: contextBuilder,
Tools: toolsRegistry,
Subagents: subagents,
SkillsFilter: skillsFilter,
Candidates: candidates,
ID: agentID,
Name: agentName,
Model: model,
Fallbacks: fallbacks,
Workspace: workspace,
MaxIterations: maxIter,
MaxTokens: maxTokens,
Temperature: temperature,
ContextWindow: maxTokens,
SummarizeMessageThreshold: summarizeMessageThreshold,
SummarizeTokenPercent: summarizeTokenPercent,
Provider: provider,
Sessions: sessionsManager,
ContextBuilder: contextBuilder,
Tools: toolsRegistry,
Subagents: subagents,
SkillsFilter: skillsFilter,
Candidates: candidates,
}
}
+66 -57
View File
@@ -118,9 +118,11 @@ func registerSharedTools(
PerplexityAPIKey: cfg.Tools.Web.Perplexity.APIKey,
PerplexityMaxResults: cfg.Tools.Web.Perplexity.MaxResults,
PerplexityEnabled: cfg.Tools.Web.Perplexity.Enabled,
ExaAPIKey: cfg.Tools.Web.Exa.APIKey,
ExaMaxResults: cfg.Tools.Web.Exa.MaxResults,
ExaEnabled: cfg.Tools.Web.Exa.Enabled,
GLMSearchAPIKey: cfg.Tools.Web.GLMSearch.APIKey,
GLMSearchBaseURL: cfg.Tools.Web.GLMSearch.BaseURL,
GLMSearchEngine: cfg.Tools.Web.GLMSearch.SearchEngine,
GLMSearchMaxResults: cfg.Tools.Web.GLMSearch.MaxResults,
GLMSearchEnabled: cfg.Tools.Web.GLMSearch.Enabled,
Proxy: cfg.Tools.Web.Proxy,
})
if err != nil {
@@ -967,62 +969,76 @@ func (al *AgentLoop) runLLMIteration(
// Save assistant message with tool calls to session
agent.Sessions.AddFullMessage(opts.SessionKey, assistantMsg)
// Execute tool calls
for _, tc := range normalizedToolCalls {
argsJSON, _ := json.Marshal(tc.Arguments)
argsPreview := utils.Truncate(string(argsJSON), 200)
logger.InfoCF("agent", fmt.Sprintf("Tool call: %s(%s)", tc.Name, argsPreview),
map[string]any{
"agent_id": agent.ID,
"tool": tc.Name,
"iteration": iteration,
})
// Execute tool calls in parallel
type indexedAgentResult struct {
result *tools.ToolResult
tc providers.ToolCall
}
// Create async callback for tools that implement AsyncTool
// NOTE: Following openclaw's design, async tools do NOT send results directly to users.
// Instead, they notify the agent via PublishInbound, and the agent decides
// whether to forward the result to the user (in processSystemMessage).
asyncCallback := func(callbackCtx context.Context, result *tools.ToolResult) {
// Log the async completion but don't send directly to user
// The agent will handle user notification via processSystemMessage
if !result.Silent && result.ForUser != "" {
logger.InfoCF("agent", "Async tool completed, agent will handle notification",
map[string]any{
"tool": tc.Name,
"content_len": len(result.ForUser),
})
agentResults := make([]indexedAgentResult, len(normalizedToolCalls))
var wg sync.WaitGroup
for i, tc := range normalizedToolCalls {
agentResults[i].tc = tc
wg.Add(1)
go func(idx int, tc providers.ToolCall) {
defer wg.Done()
argsJSON, _ := json.Marshal(tc.Arguments)
argsPreview := utils.Truncate(string(argsJSON), 200)
logger.InfoCF("agent", fmt.Sprintf("Tool call: %s(%s)", tc.Name, argsPreview),
map[string]any{
"agent_id": agent.ID,
"tool": tc.Name,
"iteration": iteration,
})
// Create async callback for tools that implement AsyncTool
asyncCallback := func(callbackCtx context.Context, result *tools.ToolResult) {
if !result.Silent && result.ForUser != "" {
logger.InfoCF("agent", "Async tool completed, agent will handle notification",
map[string]any{
"tool": tc.Name,
"content_len": len(result.ForUser),
})
}
}
}
toolResult := agent.Tools.ExecuteWithContext(
ctx,
tc.Name,
tc.Arguments,
opts.Channel,
opts.ChatID,
asyncCallback,
)
toolResult := agent.Tools.ExecuteWithContext(
ctx,
tc.Name,
tc.Arguments,
opts.Channel,
opts.ChatID,
asyncCallback,
)
agentResults[idx].result = toolResult
}(i, tc)
}
wg.Wait()
// Process results in original order (send to user, save to session)
for _, r := range agentResults {
// Send ForUser content to user immediately if not Silent
if !toolResult.Silent && toolResult.ForUser != "" && opts.SendResponse {
if !r.result.Silent && r.result.ForUser != "" && opts.SendResponse {
al.bus.PublishOutbound(ctx, bus.OutboundMessage{
Channel: opts.Channel,
ChatID: opts.ChatID,
Content: toolResult.ForUser,
Content: r.result.ForUser,
})
logger.DebugCF("agent", "Sent tool result to user",
map[string]any{
"tool": tc.Name,
"content_len": len(toolResult.ForUser),
"tool": r.tc.Name,
"content_len": len(r.result.ForUser),
})
}
// If tool returned media refs, publish them as outbound media
if len(toolResult.Media) > 0 && opts.SendResponse {
parts := make([]bus.MediaPart, 0, len(toolResult.Media))
for _, ref := range toolResult.Media {
if len(r.result.Media) > 0 && opts.SendResponse {
parts := make([]bus.MediaPart, 0, len(r.result.Media))
for _, ref := range r.result.Media {
part := bus.MediaPart{Ref: ref}
// Populate metadata from MediaStore when available
if al.mediaStore != nil {
if _, meta, err := al.mediaStore.ResolveWithMeta(ref); err == nil {
part.Filename = meta.Filename
@@ -1040,15 +1056,15 @@ func (al *AgentLoop) runLLMIteration(
}
// Determine content for LLM based on tool result
contentForLLM := toolResult.ForLLM
if contentForLLM == "" && toolResult.Err != nil {
contentForLLM = toolResult.Err.Error()
contentForLLM := r.result.ForLLM
if contentForLLM == "" && r.result.Err != nil {
contentForLLM = r.result.Err.Error()
}
toolResultMsg := providers.Message{
Role: "tool",
Content: contentForLLM,
ToolCallID: tc.ID,
ToolCallID: r.tc.ID,
}
messages = append(messages, toolResultMsg)
@@ -1084,9 +1100,9 @@ func (al *AgentLoop) updateToolContexts(agent *AgentInstance, channel, chatID st
func (al *AgentLoop) maybeSummarize(agent *AgentInstance, sessionKey, channel, chatID string) {
newHistory := agent.Sessions.GetHistory(sessionKey)
tokenEstimate := al.estimateTokens(newHistory)
threshold := agent.ContextWindow * 75 / 100
threshold := agent.ContextWindow * agent.SummarizeTokenPercent / 100
if len(newHistory) > 20 || tokenEstimate > threshold {
if len(newHistory) > agent.SummarizeMessageThreshold || tokenEstimate > threshold {
summarizeKey := agent.ID + ":" + sessionKey
if _, loading := al.summarizing.LoadOrStore(summarizeKey, true); !loading {
go func() {
@@ -1114,15 +1130,8 @@ func (al *AgentLoop) forceCompression(agent *AgentInstance, sessionKey string) {
return
}
// Find the mid-point of the conversation, avoiding splitting tool call/result pairs.
// A tool-call message (role=assistant with ToolCalls) must be followed by its
// tool-result message (role=tool). Splitting between them causes API errors.
// Helper to find the mid-point of the conversation
mid := len(conversation) / 2
if mid < len(conversation) && mid > 0 {
if conversation[mid].Role == "tool" {
mid++ // move past the tool result to keep the pair together
}
}
// New history structure:
// 1. System Prompt (with compression note appended)
-79
View File
@@ -603,85 +603,6 @@ func TestAgentLoop_ContextExhaustionRetry(t *testing.T) {
}
}
// TestForceCompression_ToolMessageBoundary verifies that forceCompression does not
// split a tool call/result pair when the midpoint falls on a "tool" role message.
// Regression test for: API errors when orphaned tool result messages appear
// without their preceding assistant tool-call message.
func TestForceCompression_ToolMessageBoundary(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "agent-test-*")
if err != nil {
t.Fatalf("Failed to create temp dir: %v", err)
}
defer os.RemoveAll(tmpDir)
cfg := &config.Config{
Agents: config.AgentsConfig{
Defaults: config.AgentDefaults{
Workspace: tmpDir,
Model: "test-model",
MaxTokens: 4096,
MaxToolIterations: 10,
},
},
}
msgBus := bus.NewMessageBus()
provider := &mockProvider{}
al := NewAgentLoop(cfg, msgBus, provider)
sessionKey := "test-session-tool-boundary"
defaultAgent := al.registry.GetDefaultAgent()
if defaultAgent == nil {
t.Fatal("No default agent found")
}
// Construct a history where len(conversation)/2 falls exactly on a "tool" message.
// history = [system, user, assistant(tool_call), tool, user, assistant, user_trigger]
// conversation = history[1:6] = [user, assistant(tool_call), tool, user, assistant]
// len(conversation) = 5, mid = 5/2 = 2 => conversation[2].Role == "tool"
// Without the fix, this would split between assistant(tool_call) and tool result.
history := []providers.Message{
{Role: "system", Content: "You are a helpful assistant."},
{Role: "user", Content: "What files are in the current directory?"},
{Role: "assistant", Content: "", ToolCalls: []providers.ToolCall{
{ID: "call_1", Name: "exec", Arguments: map[string]any{"command": "ls"}},
}},
{Role: "tool", Content: "file1.txt\nfile2.txt", ToolCallID: "call_1"},
{Role: "user", Content: "Tell me about file1.txt"},
{Role: "assistant", Content: "file1.txt is a text file."},
{Role: "user", Content: "Thanks"}, // trigger message
}
// Create the session first (AddMessage creates the session entry),
// then overwrite with our full history via SetHistory.
defaultAgent.Sessions.AddMessage(sessionKey, "system", "init")
defaultAgent.Sessions.SetHistory(sessionKey, history)
// Call forceCompression
al.forceCompression(defaultAgent, sessionKey)
// Verify the result
compressed := defaultAgent.Sessions.GetHistory(sessionKey)
// Check that no message with role="tool" is the first conversation message
// (after the system prompt). If it is, it means the tool result was orphaned.
for i := 1; i < len(compressed); i++ {
if compressed[i].Role == "tool" {
// There must be an assistant message with tool calls before it
if i == 1 {
t.Errorf("Tool result message at position %d is orphaned (no preceding assistant with tool call)", i)
} else if compressed[i-1].Role != "assistant" || len(compressed[i-1].ToolCalls) == 0 {
t.Errorf("Tool result at position %d is not preceded by assistant with tool calls (preceded by role=%q)", i, compressed[i-1].Role)
}
}
}
// Verify the system prompt has the compression note
if !strings.Contains(compressed[0].Content, "Emergency compression") {
t.Errorf("Expected compression note in system prompt, got: %s", compressed[0].Content)
}
}
func TestTargetReasoningChannelID_AllChannels(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "agent-test-*")
if err != nil {
+40
View File
@@ -3,12 +3,15 @@ package discord
import (
"context"
"fmt"
"net/http"
"net/url"
"os"
"strings"
"sync"
"time"
"github.com/bwmarrin/discordgo"
"github.com/gorilla/websocket"
"github.com/sipeed/picoclaw/pkg/bus"
"github.com/sipeed/picoclaw/pkg/channels"
@@ -40,6 +43,9 @@ func NewDiscordChannel(cfg config.DiscordConfig, bus *bus.MessageBus) (*DiscordC
return nil, fmt.Errorf("failed to create discord session: %w", err)
}
if err := applyDiscordProxy(session, cfg.Proxy); err != nil {
return nil, err
}
base := channels.NewBaseChannel("discord", cfg, bus, cfg.AllowFrom,
channels.WithMaxMessageLength(2000),
channels.WithGroupTrigger(cfg.GroupTrigger),
@@ -465,9 +471,43 @@ func (c *DiscordChannel) StartTyping(ctx context.Context, chatID string) (func()
func (c *DiscordChannel) downloadAttachment(url, filename string) string {
return utils.DownloadFile(url, filename, utils.DownloadOptions{
LoggerPrefix: "discord",
ProxyURL: c.config.Proxy,
})
}
func applyDiscordProxy(session *discordgo.Session, proxyAddr string) error {
var proxyFunc func(*http.Request) (*url.URL, error)
if proxyAddr != "" {
proxyURL, err := url.Parse(proxyAddr)
if err != nil {
return fmt.Errorf("invalid discord proxy URL %q: %w", proxyAddr, err)
}
proxyFunc = http.ProxyURL(proxyURL)
} else if os.Getenv("HTTP_PROXY") != "" || os.Getenv("HTTPS_PROXY") != "" {
proxyFunc = http.ProxyFromEnvironment
}
if proxyFunc == nil {
return nil
}
transport := &http.Transport{Proxy: proxyFunc}
session.Client = &http.Client{
Timeout: sendTimeout,
Transport: transport,
}
if session.Dialer != nil {
dialerCopy := *session.Dialer
dialerCopy.Proxy = proxyFunc
session.Dialer = &dialerCopy
} else {
session.Dialer = &websocket.Dialer{Proxy: proxyFunc}
}
return nil
}
// stripBotMention removes the bot mention from the message content.
// Discord mentions have the format <@USER_ID> or <@!USER_ID> (with nickname).
func (c *DiscordChannel) stripBotMention(text string) string {
+91
View File
@@ -0,0 +1,91 @@
package discord
import (
"net/http"
"net/url"
"testing"
"github.com/bwmarrin/discordgo"
)
func TestApplyDiscordProxy_CustomProxy(t *testing.T) {
session, err := discordgo.New("Bot test-token")
if err != nil {
t.Fatalf("discordgo.New() error: %v", err)
}
if err = applyDiscordProxy(session, "http://127.0.0.1:7890"); err != nil {
t.Fatalf("applyDiscordProxy() error: %v", err)
}
req, err := http.NewRequest("GET", "https://discord.com/api/v10/gateway", nil)
if err != nil {
t.Fatalf("http.NewRequest() error: %v", err)
}
restProxy := session.Client.Transport.(*http.Transport).Proxy
restProxyURL, err := restProxy(req)
if err != nil {
t.Fatalf("rest proxy func error: %v", err)
}
if got, want := restProxyURL.String(), "http://127.0.0.1:7890"; got != want {
t.Fatalf("REST proxy = %q, want %q", got, want)
}
wsProxyURL, err := session.Dialer.Proxy(req)
if err != nil {
t.Fatalf("ws proxy func error: %v", err)
}
if got, want := wsProxyURL.String(), "http://127.0.0.1:7890"; got != want {
t.Fatalf("WS proxy = %q, want %q", got, want)
}
}
func TestApplyDiscordProxy_FromEnvironment(t *testing.T) {
t.Setenv("HTTP_PROXY", "http://127.0.0.1:8888")
t.Setenv("http_proxy", "http://127.0.0.1:8888")
t.Setenv("HTTPS_PROXY", "http://127.0.0.1:8888")
t.Setenv("https_proxy", "http://127.0.0.1:8888")
t.Setenv("ALL_PROXY", "")
t.Setenv("all_proxy", "")
t.Setenv("NO_PROXY", "")
t.Setenv("no_proxy", "")
session, err := discordgo.New("Bot test-token")
if err != nil {
t.Fatalf("discordgo.New() error: %v", err)
}
if err = applyDiscordProxy(session, ""); err != nil {
t.Fatalf("applyDiscordProxy() error: %v", err)
}
req, err := http.NewRequest("GET", "https://discord.com/api/v10/gateway", nil)
if err != nil {
t.Fatalf("http.NewRequest() error: %v", err)
}
gotURL, err := session.Dialer.Proxy(req)
if err != nil {
t.Fatalf("ws proxy func error: %v", err)
}
wantURL, err := url.Parse("http://127.0.0.1:8888")
if err != nil {
t.Fatalf("url.Parse() error: %v", err)
}
if gotURL.String() != wantURL.String() {
t.Fatalf("WS proxy = %q, want %q", gotURL.String(), wantURL.String())
}
}
func TestApplyDiscordProxy_InvalidProxyURL(t *testing.T) {
session, err := discordgo.New("Bot test-token")
if err != nil {
t.Fatalf("discordgo.New() error: %v", err)
}
if err = applyDiscordProxy(session, "://bad-proxy"); err == nil {
t.Fatal("applyDiscordProxy() expected error for invalid proxy URL, got nil")
}
}
+13 -85
View File
@@ -3,23 +3,14 @@ package config
import (
"encoding/json"
"fmt"
"log"
"os"
"path/filepath"
"sync"
"sync/atomic"
"github.com/caarlos0/env/v11"
"github.com/joho/godotenv"
"github.com/sipeed/picoclaw/pkg/fileutil"
)
// dotenvOnce ensures .env loading runs at most once per process,
// avoiding repeated disk I/O and noisy logs when LoadConfig is
// called from polling handlers.
var dotenvOnce sync.Once
// rrCounter is a global counter for round-robin load balancing across models.
var rrCounter atomic.Uint64
@@ -189,6 +180,8 @@ type AgentDefaults struct {
MaxTokens int `json:"max_tokens" env:"PICOCLAW_AGENTS_DEFAULTS_MAX_TOKENS"`
Temperature *float64 `json:"temperature,omitempty" env:"PICOCLAW_AGENTS_DEFAULTS_TEMPERATURE"`
MaxToolIterations int `json:"max_tool_iterations" env:"PICOCLAW_AGENTS_DEFAULTS_MAX_TOOL_ITERATIONS"`
SummarizeMessageThreshold int `json:"summarize_message_threshold" env:"PICOCLAW_AGENTS_DEFAULTS_SUMMARIZE_MESSAGE_THRESHOLD"`
SummarizeTokenPercent int `json:"summarize_token_percent" env:"PICOCLAW_AGENTS_DEFAULTS_SUMMARIZE_TOKEN_PERCENT"`
MaxMediaSize int `json:"max_media_size,omitempty" env:"PICOCLAW_AGENTS_DEFAULTS_MAX_MEDIA_SIZE"`
}
@@ -280,6 +273,7 @@ type FeishuConfig struct {
type DiscordConfig struct {
Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_DISCORD_ENABLED"`
Token string `json:"token" env:"PICOCLAW_CHANNELS_DISCORD_TOKEN"`
Proxy string `json:"proxy" env:"PICOCLAW_CHANNELS_DISCORD_PROXY"`
AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_DISCORD_ALLOW_FROM"`
MentionOnly bool `json:"mention_only" env:"PICOCLAW_CHANNELS_DISCORD_MENTION_ONLY"`
GroupTrigger GroupTriggerConfig `json:"group_trigger,omitempty"`
@@ -437,7 +431,6 @@ type ProvidersConfig struct {
Antigravity ProviderConfig `json:"antigravity"`
Qwen ProviderConfig `json:"qwen"`
Mistral ProviderConfig `json:"mistral"`
Opencode ProviderConfig `json:"opencode"`
}
// IsEmpty checks if all provider configs are empty (no API keys or API bases set)
@@ -461,8 +454,7 @@ func (p ProvidersConfig) IsEmpty() bool {
p.GitHubCopilot.APIKey == "" && p.GitHubCopilot.APIBase == "" &&
p.Antigravity.APIKey == "" && p.Antigravity.APIBase == "" &&
p.Qwen.APIKey == "" && p.Qwen.APIBase == "" &&
p.Mistral.APIKey == "" && p.Mistral.APIBase == "" &&
p.Opencode.APIKey == "" && p.Opencode.APIBase == ""
p.Mistral.APIKey == "" && p.Mistral.APIBase == ""
}
// MarshalJSON implements custom JSON marshaling for ProvidersConfig
@@ -555,10 +547,14 @@ type PerplexityConfig struct {
MaxResults int `json:"max_results" env:"PICOCLAW_TOOLS_WEB_PERPLEXITY_MAX_RESULTS"`
}
type ExaConfig struct {
Enabled bool `json:"enabled" env:"PICOCLAW_TOOLS_WEB_EXA_ENABLED"`
APIKey string `json:"api_key" env:"PICOCLAW_TOOLS_WEB_EXA_API_KEY"`
MaxResults int `json:"max_results" env:"PICOCLAW_TOOLS_WEB_EXA_MAX_RESULTS"`
type GLMSearchConfig struct {
Enabled bool `json:"enabled" env:"PICOCLAW_TOOLS_WEB_GLM_ENABLED"`
APIKey string `json:"api_key" env:"PICOCLAW_TOOLS_WEB_GLM_API_KEY"`
BaseURL string `json:"base_url" env:"PICOCLAW_TOOLS_WEB_GLM_BASE_URL"`
// SearchEngine specifies the search backend: "search_std" (default),
// "search_pro", "search_pro_sogou", or "search_pro_quark".
SearchEngine string `json:"search_engine" env:"PICOCLAW_TOOLS_WEB_GLM_SEARCH_ENGINE"`
MaxResults int `json:"max_results" env:"PICOCLAW_TOOLS_WEB_GLM_MAX_RESULTS"`
}
type WebToolsConfig struct {
@@ -566,7 +562,7 @@ type WebToolsConfig struct {
Tavily TavilyConfig `json:"tavily"`
DuckDuckGo DuckDuckGoConfig `json:"duckduckgo"`
Perplexity PerplexityConfig `json:"perplexity"`
Exa ExaConfig `json:"exa"`
GLMSearch GLMSearchConfig `json:"glm_search"`
// Proxy is an optional proxy URL for web tools (http/https/socks5/socks5h).
// For authenticated proxies, prefer HTTP_PROXY/HTTPS_PROXY env vars instead of embedding credentials in config.
Proxy string `json:"proxy,omitempty" env:"PICOCLAW_TOOLS_WEB_PROXY"`
@@ -658,35 +654,9 @@ type MCPConfig struct {
func LoadConfig(path string) (*Config, error) {
cfg := DefaultConfig()
// Load .env file from config directory (secrets, API keys, etc.)
// Guarded by sync.Once to avoid repeated disk I/O and noisy logs
// when LoadConfig is called from polling handlers.
dotenvOnce.Do(func() {
envFile := filepath.Join(filepath.Dir(path), ".env")
if err := godotenv.Load(envFile); err != nil {
if os.IsNotExist(err) {
log.Printf("[INFO] No .env file found at %s; skipping .env loading", envFile)
} else {
log.Printf("[WARN] Failed to load .env file from %s: %v", envFile, err)
}
}
})
data, err := os.ReadFile(path)
if err != nil {
if os.IsNotExist(err) {
// No config file — still apply env vars + overrides to default config
if err := env.Parse(cfg); err != nil {
return nil, err
}
loadProviderEnvOverrides(cfg)
cfg.migrateChannelConfigs()
if cfg.HasProvidersConfig() {
cfg.ModelList = ConvertProvidersToModelList(cfg)
}
if err := cfg.ValidateModelList(); err != nil {
return nil, err
}
return cfg, nil
}
return nil, err
@@ -714,9 +684,6 @@ func LoadConfig(path string) (*Config, error) {
return nil, err
}
// Load provider-specific env overrides (PICOCLAW_PROVIDERS_<NAME>_API_KEY, etc.)
loadProviderEnvOverrides(cfg)
// Migrate legacy channel config fields to new unified structures
cfg.migrateChannelConfigs()
@@ -865,42 +832,3 @@ func (c *Config) ValidateModelList() error {
}
return nil
}
// loadProviderEnvOverrides reads PICOCLAW_PROVIDERS_<NAME>_API_KEY and _API_BASE
// environment variables and sets them on the corresponding provider config fields.
// This enables storing provider secrets in .env files without using struct tags.
func loadProviderEnvOverrides(cfg *Config) {
providers := []struct {
name string
apiKey *string
base *string
}{
{"ANTHROPIC", &cfg.Providers.Anthropic.APIKey, &cfg.Providers.Anthropic.APIBase},
{"OPENAI", &cfg.Providers.OpenAI.APIKey, &cfg.Providers.OpenAI.APIBase},
{"LITELLM", &cfg.Providers.LiteLLM.APIKey, &cfg.Providers.LiteLLM.APIBase},
{"OPENROUTER", &cfg.Providers.OpenRouter.APIKey, &cfg.Providers.OpenRouter.APIBase},
{"GROQ", &cfg.Providers.Groq.APIKey, &cfg.Providers.Groq.APIBase},
{"ZHIPU", &cfg.Providers.Zhipu.APIKey, &cfg.Providers.Zhipu.APIBase},
{"GEMINI", &cfg.Providers.Gemini.APIKey, &cfg.Providers.Gemini.APIBase},
{"NVIDIA", &cfg.Providers.Nvidia.APIKey, &cfg.Providers.Nvidia.APIBase},
{"OLLAMA", &cfg.Providers.Ollama.APIKey, &cfg.Providers.Ollama.APIBase},
{"MOONSHOT", &cfg.Providers.Moonshot.APIKey, &cfg.Providers.Moonshot.APIBase},
{"SHENGSUANYUN", &cfg.Providers.ShengSuanYun.APIKey, &cfg.Providers.ShengSuanYun.APIBase},
{"DEEPSEEK", &cfg.Providers.DeepSeek.APIKey, &cfg.Providers.DeepSeek.APIBase},
{"MISTRAL", &cfg.Providers.Mistral.APIKey, &cfg.Providers.Mistral.APIBase},
{"VLLM", &cfg.Providers.VLLM.APIKey, &cfg.Providers.VLLM.APIBase},
{"CEREBRAS", &cfg.Providers.Cerebras.APIKey, &cfg.Providers.Cerebras.APIBase},
{"VOLCENGINE", &cfg.Providers.VolcEngine.APIKey, &cfg.Providers.VolcEngine.APIBase},
{"QWEN", &cfg.Providers.Qwen.APIKey, &cfg.Providers.Qwen.APIBase},
// Note: GitHubCopilot and Antigravity use different auth patterns (ConnectMode/AuthMethod),
// not standard APIKey/APIBase, so they are not included here.
}
for _, p := range providers {
if v, ok := os.LookupEnv("PICOCLAW_PROVIDERS_" + p.name + "_API_KEY"); ok {
*p.apiKey = v
}
if v, ok := os.LookupEnv("PICOCLAW_PROVIDERS_" + p.name + "_API_BASE"); ok {
*p.base = v
}
}
}
+12 -96
View File
@@ -6,7 +6,6 @@ import (
"path/filepath"
"runtime"
"strings"
"sync"
"testing"
)
@@ -436,6 +435,18 @@ func TestLoadConfig_WebToolsProxy(t *testing.T) {
}
// TestDefaultConfig_DMScope verifies the default dm_scope value
// TestDefaultConfig_SummarizationThresholds verifies summarization defaults
func TestDefaultConfig_SummarizationThresholds(t *testing.T) {
cfg := DefaultConfig()
if cfg.Agents.Defaults.SummarizeMessageThreshold != 20 {
t.Errorf("SummarizeMessageThreshold = %d, want 20", cfg.Agents.Defaults.SummarizeMessageThreshold)
}
if cfg.Agents.Defaults.SummarizeTokenPercent != 75 {
t.Errorf("SummarizeTokenPercent = %d, want 75", cfg.Agents.Defaults.SummarizeTokenPercent)
}
}
func TestDefaultConfig_DMScope(t *testing.T) {
cfg := DefaultConfig()
@@ -468,98 +479,3 @@ func TestDefaultConfig_WorkspacePath_WithPicoclawHome(t *testing.T) {
t.Errorf("Workspace path with PICOCLAW_HOME = %q, want %q", cfg.Agents.Defaults.Workspace, want)
}
}
func TestLoadConfig_DotenvFileLoaded(t *testing.T) {
// Reset sync.Once so .env loading runs for this test
dotenvOnce = sync.Once{}
dir := t.TempDir()
configPath := filepath.Join(dir, "config.json")
// Write a minimal config.json
if err := os.WriteFile(configPath, []byte(`{}`), 0o600); err != nil {
t.Fatalf("WriteFile config: %v", err)
}
// Write a .env file with a provider API key
envFile := filepath.Join(dir, ".env")
if err := os.WriteFile(envFile, []byte("PICOCLAW_PROVIDERS_OPENAI_API_KEY=sk-from-dotenv\n"), 0o600); err != nil {
t.Fatalf("WriteFile .env: %v", err)
}
// Clear the env var first to ensure it comes from .env
t.Setenv("PICOCLAW_PROVIDERS_OPENAI_API_KEY", "")
os.Unsetenv("PICOCLAW_PROVIDERS_OPENAI_API_KEY")
cfg, err := LoadConfig(configPath)
if err != nil {
t.Fatalf("LoadConfig() error: %v", err)
}
if cfg.Providers.OpenAI.APIKey != "sk-from-dotenv" {
t.Errorf("OpenAI.APIKey = %q, want %q", cfg.Providers.OpenAI.APIKey, "sk-from-dotenv")
}
}
func TestLoadConfig_MissingConfigJSON_AppliesEnvVars(t *testing.T) {
// Reset sync.Once so .env loading runs for this test
dotenvOnce = sync.Once{}
dir := t.TempDir()
configPath := filepath.Join(dir, "config.json") // does NOT exist
t.Setenv("PICOCLAW_PROVIDERS_ANTHROPIC_API_KEY", "sk-anthropic-test")
cfg, err := LoadConfig(configPath)
if err != nil {
t.Fatalf("LoadConfig() error: %v", err)
}
if cfg.Providers.Anthropic.APIKey != "sk-anthropic-test" {
t.Errorf("Anthropic.APIKey = %q, want %q", cfg.Providers.Anthropic.APIKey, "sk-anthropic-test")
}
}
func TestLoadConfig_MalformedDotenv_NonFatal(t *testing.T) {
// Reset sync.Once so .env loading runs for this test
dotenvOnce = sync.Once{}
dir := t.TempDir()
configPath := filepath.Join(dir, "config.json")
// Write a minimal config.json
if err := os.WriteFile(configPath, []byte(`{}`), 0o600); err != nil {
t.Fatalf("WriteFile config: %v", err)
}
// Write a .env file with genuinely malformed content (bare key without '=',
// mixed with a valid line) to verify godotenv.Load errors are non-fatal.
envFile := filepath.Join(dir, ".env")
if err := os.WriteFile(envFile, []byte("THIS_LINE_HAS_NO_EQUALS\nVALID_KEY=valid_value\n"), 0o600); err != nil {
t.Fatalf("WriteFile .env: %v", err)
}
// LoadConfig should not fail even with malformed .env content
cfg, err := LoadConfig(configPath)
if err != nil {
t.Fatalf("LoadConfig() should not fail with .env issues, got error: %v", err)
}
if cfg == nil {
t.Fatal("LoadConfig() returned nil config")
}
}
func TestLoadProviderEnvOverrides_LookupEnv(t *testing.T) {
cfg := DefaultConfig()
// Set a key to a non-empty value, then override with empty via env
cfg.Providers.OpenRouter.APIBase = "https://original.com"
t.Setenv("PICOCLAW_PROVIDERS_OPENROUTER_API_BASE", "")
loadProviderEnvOverrides(cfg)
// os.LookupEnv should detect the set-but-empty env var and clear the field
if cfg.Providers.OpenRouter.APIBase != "" {
t.Errorf("OpenRouter.APIBase = %q, want empty (overridden by empty env var)", cfg.Providers.OpenRouter.APIBase)
}
}
+15 -11
View File
@@ -26,13 +26,15 @@ func DefaultConfig() *Config {
return &Config{
Agents: AgentsConfig{
Defaults: AgentDefaults{
Workspace: workspacePath,
RestrictToWorkspace: true,
Provider: "",
Model: "",
MaxTokens: 32768,
Temperature: nil, // nil means use provider default
MaxToolIterations: 50,
Workspace: workspacePath,
RestrictToWorkspace: true,
Provider: "",
Model: "",
MaxTokens: 32768,
Temperature: nil, // nil means use provider default
MaxToolIterations: 50,
SummarizeMessageThreshold: 20,
SummarizeTokenPercent: 75,
},
},
Bindings: []AgentBinding{},
@@ -341,10 +343,12 @@ func DefaultConfig() *Config {
APIKey: "",
MaxResults: 5,
},
Exa: ExaConfig{
Enabled: false,
APIKey: "",
MaxResults: 5,
GLMSearch: GLMSearchConfig{
Enabled: false,
APIKey: "",
BaseURL: "https://open.bigmodel.cn/api/paas/v4/web_search",
SearchEngine: "search_std",
MaxResults: 5,
},
},
Cron: CronToolsConfig{
+2 -21
View File
@@ -225,7 +225,7 @@ func ConvertProvidersToModelList(cfg *Config) []ModelConfig {
},
},
{
providerNames: []string{"moonshot", "kimi", "kimi-code"},
providerNames: []string{"moonshot", "kimi"},
protocol: "moonshot",
buildConfig: func(p ProvidersConfig) (ModelConfig, bool) {
if p.Moonshot.APIKey == "" && p.Moonshot.APIBase == "" {
@@ -373,23 +373,6 @@ func ConvertProvidersToModelList(cfg *Config) []ModelConfig {
}, true
},
},
{
providerNames: []string{"opencode"},
protocol: "opencode",
buildConfig: func(p ProvidersConfig) (ModelConfig, bool) {
if p.Opencode.APIKey == "" && p.Opencode.APIBase == "" {
return ModelConfig{}, false
}
return ModelConfig{
ModelName: "opencode",
Model: "opencode/auto",
APIKey: p.Opencode.APIKey,
APIBase: p.Opencode.APIBase,
Proxy: p.Opencode.Proxy,
RequestTimeout: p.Opencode.RequestTimeout,
}, true
},
},
}
// Process each provider migration
@@ -401,9 +384,7 @@ func ConvertProvidersToModelList(cfg *Config) []ModelConfig {
// Check if this is the user's configured provider
if slices.Contains(m.providerNames, userProvider) && userModel != "" {
// Use the user's configured model instead of default.
// Also set ModelName so GetModelConfig(userModel) can find this entry.
mc.ModelName = userModel
// Use the user's configured model instead of default
mc.Model = buildModelWithProtocol(m.protocol, userModel)
} else if userProvider == "" && userModel != "" && !legacyModelNameApplied {
// Legacy config: no explicit provider field but model is specified
-129
View File
@@ -160,7 +160,6 @@ func TestConvertProvidersToModelList_AllProviders(t *testing.T) {
Antigravity: ProviderConfig{AuthMethod: "oauth"},
Qwen: ProviderConfig{APIKey: "key17"},
Mistral: ProviderConfig{APIKey: "key18"},
Opencode: ProviderConfig{APIKey: "key19"},
},
}
@@ -580,65 +579,6 @@ func TestBuildModelWithProtocol_DifferentPrefix(t *testing.T) {
}
}
func TestConvertProvidersToModelList_Opencode(t *testing.T) {
cfg := &Config{
Providers: ProvidersConfig{
Opencode: ProviderConfig{
APIKey: "oc-test-key",
APIBase: "https://custom.opencode.ai/v1",
Proxy: "http://proxy:9090",
RequestTimeout: 60,
},
},
}
result := ConvertProvidersToModelList(cfg)
if len(result) != 1 {
t.Fatalf("len(result) = %d, want 1", len(result))
}
mc := result[0]
if mc.ModelName != "opencode" {
t.Errorf("ModelName = %q, want %q", mc.ModelName, "opencode")
}
if mc.Model != "opencode/auto" {
t.Errorf("Model = %q, want %q", mc.Model, "opencode/auto")
}
if mc.APIKey != "oc-test-key" {
t.Errorf("APIKey = %q, want %q", mc.APIKey, "oc-test-key")
}
if mc.APIBase != "https://custom.opencode.ai/v1" {
t.Errorf("APIBase = %q, want %q", mc.APIBase, "https://custom.opencode.ai/v1")
}
if mc.Proxy != "http://proxy:9090" {
t.Errorf("Proxy = %q, want %q", mc.Proxy, "http://proxy:9090")
}
if mc.RequestTimeout != 60 {
t.Errorf("RequestTimeout = %d, want %d", mc.RequestTimeout, 60)
}
}
func TestConvertProvidersToModelList_Opencode_APIBaseOnly(t *testing.T) {
cfg := &Config{
Providers: ProvidersConfig{
Opencode: ProviderConfig{
APIBase: "https://custom.opencode.ai/v1",
},
},
}
result := ConvertProvidersToModelList(cfg)
if len(result) != 1 {
t.Fatalf("len(result) = %d, want 1 (APIBase-only should create entry)", len(result))
}
if result[0].ModelName != "opencode" {
t.Errorf("ModelName = %q, want %q", result[0].ModelName, "opencode")
}
}
// Test for legacy config with protocol prefix in model name
func TestConvertProvidersToModelList_LegacyModelWithProtocolPrefix(t *testing.T) {
cfg := &Config{
@@ -669,72 +609,3 @@ func TestConvertProvidersToModelList_LegacyModelWithProtocolPrefix(t *testing.T)
t.Errorf("Model = %q, want %q (should not duplicate prefix)", result[0].Model, "openrouter/auto")
}
}
// Test that ModelName is set to the user's configured model when provider matches.
// This ensures GetModelConfig(userModel) can find the migrated entry.
// Regression test for: gateway startup failure when user model differs from provider name.
func TestConvertProvidersToModelList_ModelNameMatchesUserModel(t *testing.T) {
cfg := &Config{
Agents: AgentsConfig{
Defaults: AgentDefaults{
Provider: "moonshot",
Model: "k2p5",
},
},
Providers: ProvidersConfig{
Moonshot: ProviderConfig{APIKey: "sk-kimi-test"},
},
}
result := ConvertProvidersToModelList(cfg)
if len(result) != 1 {
t.Fatalf("len(result) = %d, want 1", len(result))
}
// ModelName must match the user's configured model, not the provider name.
// Without this, GetModelConfig("k2p5") would fail because it would look
// for ModelName == "k2p5" but find ModelName == "moonshot".
if result[0].ModelName != "k2p5" {
t.Errorf("ModelName = %q, want %q (must match user's model for GetModelConfig lookup)", result[0].ModelName, "k2p5")
}
if result[0].Model != "moonshot/k2p5" {
t.Errorf("Model = %q, want %q", result[0].Model, "moonshot/k2p5")
}
// Other providers (not matching the user's configured provider) should keep their provider name
cfg2 := &Config{
Agents: AgentsConfig{
Defaults: AgentDefaults{
Provider: "moonshot",
Model: "k2p5",
},
},
Providers: ProvidersConfig{
OpenAI: OpenAIProviderConfig{ProviderConfig: ProviderConfig{APIKey: "sk-openai"}},
Moonshot: ProviderConfig{APIKey: "sk-kimi-test"},
},
}
result2 := ConvertProvidersToModelList(cfg2)
if len(result2) != 2 {
t.Fatalf("len(result2) = %d, want 2", len(result2))
}
for _, mc := range result2 {
switch {
case mc.APIKey == "sk-openai":
// OpenAI is not the user's provider, should keep default ModelName
if mc.ModelName != "openai" {
t.Errorf("OpenAI ModelName = %q, want %q (non-matching provider keeps default)", mc.ModelName, "openai")
}
case mc.APIKey == "sk-kimi-test":
// Moonshot is the user's provider, ModelName must be the user's model
if mc.ModelName != "k2p5" {
t.Errorf("Moonshot ModelName = %q, want %q (matching provider uses user model)", mc.ModelName, "k2p5")
}
}
}
}
+460
View File
@@ -0,0 +1,460 @@
package memory
import (
"bufio"
"bytes"
"context"
"encoding/json"
"fmt"
"hash/fnv"
"log"
"os"
"path/filepath"
"strings"
"sync"
"time"
"github.com/sipeed/picoclaw/pkg/fileutil"
"github.com/sipeed/picoclaw/pkg/providers"
)
const (
// numLockShards is the fixed number of mutexes used to serialize
// per-session access. Using a sharded array instead of a map keeps
// memory bounded regardless of how many sessions are created over
// the lifetime of the process — important for a long-running daemon.
numLockShards = 64
// maxLineSize is the maximum size of a single JSON line in a .jsonl
// file. Tool results (read_file, web search, etc.) can be large, so
// we set a generous limit. The scanner starts at 64 KB and grows
// only as needed up to this cap.
maxLineSize = 10 * 1024 * 1024 // 10 MB
)
// sessionMeta holds per-session metadata stored in a .meta.json file.
type sessionMeta struct {
Key string `json:"key"`
Summary string `json:"summary"`
Skip int `json:"skip"`
Count int `json:"count"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
}
// JSONLStore implements Store using append-only JSONL files.
//
// Each session is stored as two files:
//
// {sanitized_key}.jsonl — one JSON-encoded message per line, append-only
// {sanitized_key}.meta.json — session metadata (summary, logical truncation offset)
//
// Messages are never physically deleted from the JSONL file. Instead,
// TruncateHistory records a "skip" offset in the metadata file and
// GetHistory ignores lines before that offset. This keeps all writes
// append-only, which is both fast and crash-safe.
type JSONLStore struct {
dir string
locks [numLockShards]sync.Mutex
}
// NewJSONLStore creates a new JSONL-backed store rooted at dir.
func NewJSONLStore(dir string) (*JSONLStore, error) {
err := os.MkdirAll(dir, 0o755)
if err != nil {
return nil, fmt.Errorf("memory: create directory: %w", err)
}
return &JSONLStore{dir: dir}, nil
}
// sessionLock returns a mutex for the given session key.
// Keys are mapped to a fixed pool of shards via FNV hash, so
// memory usage is O(1) regardless of total session count.
func (s *JSONLStore) sessionLock(key string) *sync.Mutex {
h := fnv.New32a()
h.Write([]byte(key))
return &s.locks[h.Sum32()%numLockShards]
}
func (s *JSONLStore) jsonlPath(key string) string {
return filepath.Join(s.dir, sanitizeKey(key)+".jsonl")
}
func (s *JSONLStore) metaPath(key string) string {
return filepath.Join(s.dir, sanitizeKey(key)+".meta.json")
}
// sanitizeKey converts a session key to a safe filename component.
// Mirrors pkg/session.sanitizeFilename so that migration paths match.
//
// Note: this is a lossy mapping — "telegram:123" and "telegram_123"
// both produce the same filename. This is an intentional tradeoff:
// keys with colons (e.g. from channels) are by far the common case,
// and a bidirectional encoding (like URL-encoding) would complicate
// file listings and debugging.
func sanitizeKey(key string) string {
return strings.ReplaceAll(key, ":", "_")
}
// readMeta loads the metadata file for a session.
// Returns a zero-value sessionMeta if the file does not exist.
func (s *JSONLStore) readMeta(key string) (sessionMeta, error) {
data, err := os.ReadFile(s.metaPath(key))
if os.IsNotExist(err) {
return sessionMeta{Key: key}, nil
}
if err != nil {
return sessionMeta{}, fmt.Errorf("memory: read meta: %w", err)
}
var meta sessionMeta
err = json.Unmarshal(data, &meta)
if err != nil {
return sessionMeta{}, fmt.Errorf("memory: decode meta: %w", err)
}
return meta, nil
}
// writeMeta atomically writes the metadata file using the project's
// standard WriteFileAtomic (temp + fsync + rename).
func (s *JSONLStore) writeMeta(key string, meta sessionMeta) error {
data, err := json.MarshalIndent(meta, "", " ")
if err != nil {
return fmt.Errorf("memory: encode meta: %w", err)
}
return fileutil.WriteFileAtomic(s.metaPath(key), data, 0o644)
}
// readMessages reads valid JSON lines from a .jsonl file, skipping
// the first `skip` lines without unmarshaling them. This avoids the
// cost of json.Unmarshal on logically truncated messages.
// Malformed trailing lines (e.g. from a crash) are silently skipped.
func readMessages(path string, skip int) ([]providers.Message, error) {
f, err := os.Open(path)
if os.IsNotExist(err) {
return []providers.Message{}, nil
}
if err != nil {
return nil, fmt.Errorf("memory: open jsonl: %w", err)
}
defer f.Close()
var msgs []providers.Message
scanner := bufio.NewScanner(f)
// Allow large lines for tool results (read_file, web search, etc.).
scanner.Buffer(make([]byte, 0, 64*1024), maxLineSize)
lineNum := 0
for scanner.Scan() {
line := scanner.Bytes()
if len(line) == 0 {
continue
}
lineNum++
if lineNum <= skip {
continue
}
var msg providers.Message
if err := json.Unmarshal(line, &msg); err != nil {
// Corrupt line — likely a partial write from a crash.
// Log so operators know data was skipped, but don't
// fail the entire read; this is the standard JSONL
// recovery pattern.
log.Printf("memory: skipping corrupt line %d in %s: %v",
lineNum, filepath.Base(path), err)
continue
}
msgs = append(msgs, msg)
}
if scanner.Err() != nil {
return nil, fmt.Errorf("memory: scan jsonl: %w", scanner.Err())
}
if msgs == nil {
msgs = []providers.Message{}
}
return msgs, nil
}
// countLines counts the total number of non-empty lines in a .jsonl file.
// Used by TruncateHistory to reconcile a stale meta.Count without
// the overhead of unmarshaling every message.
func countLines(path string) (int, error) {
f, err := os.Open(path)
if os.IsNotExist(err) {
return 0, nil
}
if err != nil {
return 0, fmt.Errorf("memory: open jsonl: %w", err)
}
defer f.Close()
n := 0
scanner := bufio.NewScanner(f)
scanner.Buffer(make([]byte, 0, 64*1024), maxLineSize)
for scanner.Scan() {
if len(scanner.Bytes()) > 0 {
n++
}
}
return n, scanner.Err()
}
func (s *JSONLStore) AddMessage(
_ context.Context, sessionKey, role, content string,
) error {
return s.addMsg(sessionKey, providers.Message{
Role: role,
Content: content,
})
}
func (s *JSONLStore) AddFullMessage(
_ context.Context, sessionKey string, msg providers.Message,
) error {
return s.addMsg(sessionKey, msg)
}
// addMsg is the shared implementation for AddMessage and AddFullMessage.
func (s *JSONLStore) addMsg(sessionKey string, msg providers.Message) error {
l := s.sessionLock(sessionKey)
l.Lock()
defer l.Unlock()
// Append the message as a single JSON line.
line, err := json.Marshal(msg)
if err != nil {
return fmt.Errorf("memory: marshal message: %w", err)
}
line = append(line, '\n')
f, err := os.OpenFile(
s.jsonlPath(sessionKey),
os.O_CREATE|os.O_WRONLY|os.O_APPEND,
0o644,
)
if err != nil {
return fmt.Errorf("memory: open jsonl for append: %w", err)
}
_, writeErr := f.Write(line)
if writeErr != nil {
f.Close()
return fmt.Errorf("memory: append message: %w", writeErr)
}
// Flush to physical storage before closing. This matches the
// durability guarantee of writeMeta and rewriteJSONL (which use
// WriteFileAtomic with fsync). Without Sync, a power loss could
// leave the append in the kernel page cache only — lost on reboot.
if syncErr := f.Sync(); syncErr != nil {
f.Close()
return fmt.Errorf("memory: sync jsonl: %w", syncErr)
}
if closeErr := f.Close(); closeErr != nil {
return fmt.Errorf("memory: close jsonl: %w", closeErr)
}
// Update metadata.
meta, err := s.readMeta(sessionKey)
if err != nil {
return err
}
now := time.Now()
if meta.Count == 0 && meta.CreatedAt.IsZero() {
meta.CreatedAt = now
}
meta.Count++
meta.UpdatedAt = now
return s.writeMeta(sessionKey, meta)
}
func (s *JSONLStore) GetHistory(
_ context.Context, sessionKey string,
) ([]providers.Message, error) {
l := s.sessionLock(sessionKey)
l.Lock()
defer l.Unlock()
meta, err := s.readMeta(sessionKey)
if err != nil {
return nil, err
}
// Pass meta.Skip so readMessages skips those lines without
// unmarshaling them — avoids wasted CPU on truncated messages.
msgs, err := readMessages(s.jsonlPath(sessionKey), meta.Skip)
if err != nil {
return nil, err
}
return msgs, nil
}
func (s *JSONLStore) GetSummary(
_ context.Context, sessionKey string,
) (string, error) {
l := s.sessionLock(sessionKey)
l.Lock()
defer l.Unlock()
meta, err := s.readMeta(sessionKey)
if err != nil {
return "", err
}
return meta.Summary, nil
}
func (s *JSONLStore) SetSummary(
_ context.Context, sessionKey, summary string,
) error {
l := s.sessionLock(sessionKey)
l.Lock()
defer l.Unlock()
meta, err := s.readMeta(sessionKey)
if err != nil {
return err
}
now := time.Now()
if meta.CreatedAt.IsZero() {
meta.CreatedAt = now
}
meta.Summary = summary
meta.UpdatedAt = now
return s.writeMeta(sessionKey, meta)
}
func (s *JSONLStore) TruncateHistory(
_ context.Context, sessionKey string, keepLast int,
) error {
l := s.sessionLock(sessionKey)
l.Lock()
defer l.Unlock()
meta, err := s.readMeta(sessionKey)
if err != nil {
return err
}
// Always reconcile meta.Count with the actual line count on disk.
// A crash between the JSONL append and the meta update in addMsg
// leaves meta.Count stale (e.g. file has 101 lines but meta says
// 100). Counting lines is cheap — no unmarshal, just a scan — and
// TruncateHistory is not a hot path, so always re-count.
n, countErr := countLines(s.jsonlPath(sessionKey))
if countErr != nil {
return countErr
}
meta.Count = n
if keepLast <= 0 {
meta.Skip = meta.Count
} else {
effective := meta.Count - meta.Skip
if keepLast < effective {
meta.Skip = meta.Count - keepLast
}
}
meta.UpdatedAt = time.Now()
return s.writeMeta(sessionKey, meta)
}
func (s *JSONLStore) SetHistory(
_ context.Context,
sessionKey string,
history []providers.Message,
) error {
l := s.sessionLock(sessionKey)
l.Lock()
defer l.Unlock()
meta, err := s.readMeta(sessionKey)
if err != nil {
return err
}
now := time.Now()
if meta.CreatedAt.IsZero() {
meta.CreatedAt = now
}
meta.Skip = 0
meta.Count = len(history)
meta.UpdatedAt = now
// Write meta BEFORE rewriting the JSONL file. If we crash between
// the two writes, meta has Skip=0 and the old file is still intact,
// so GetHistory reads from line 1 — returning "too many" messages
// rather than losing data. The next SetHistory call corrects this.
err = s.writeMeta(sessionKey, meta)
if err != nil {
return err
}
return s.rewriteJSONL(sessionKey, history)
}
// Compact physically rewrites the JSONL file, dropping all logically
// skipped lines. This reclaims disk space that accumulates after
// repeated TruncateHistory calls.
//
// It is safe to call at any time; if there is nothing to compact
// (skip == 0) the method returns immediately.
func (s *JSONLStore) Compact(
_ context.Context, sessionKey string,
) error {
l := s.sessionLock(sessionKey)
l.Lock()
defer l.Unlock()
meta, err := s.readMeta(sessionKey)
if err != nil {
return err
}
if meta.Skip == 0 {
return nil
}
// Read only the active messages, skipping truncated lines
// without unmarshaling them.
active, err := readMessages(s.jsonlPath(sessionKey), meta.Skip)
if err != nil {
return err
}
// Write meta BEFORE rewriting the JSONL file. If the process
// crashes between the two writes, meta has Skip=0 and the old
// (uncompacted) file is still intact, so GetHistory reads from
// line 1 — returning previously-truncated messages rather than
// losing data. The next Compact or TruncateHistory corrects this.
meta.Skip = 0
meta.Count = len(active)
meta.UpdatedAt = time.Now()
err = s.writeMeta(sessionKey, meta)
if err != nil {
return err
}
return s.rewriteJSONL(sessionKey, active)
}
// rewriteJSONL atomically replaces the JSONL file with the given messages
// using the project's standard WriteFileAtomic (temp + fsync + rename).
func (s *JSONLStore) rewriteJSONL(
sessionKey string, msgs []providers.Message,
) error {
var buf bytes.Buffer
for i, msg := range msgs {
line, err := json.Marshal(msg)
if err != nil {
return fmt.Errorf("memory: marshal message %d: %w", i, err)
}
buf.Write(line)
buf.WriteByte('\n')
}
return fileutil.WriteFileAtomic(s.jsonlPath(sessionKey), buf.Bytes(), 0o644)
}
func (s *JSONLStore) Close() error {
return nil
}
+835
View File
@@ -0,0 +1,835 @@
package memory
import (
"context"
"os"
"path/filepath"
"sync"
"testing"
"github.com/sipeed/picoclaw/pkg/providers"
)
func newTestStore(t *testing.T) *JSONLStore {
t.Helper()
store, err := NewJSONLStore(t.TempDir())
if err != nil {
t.Fatalf("NewJSONLStore: %v", err)
}
return store
}
func TestNewJSONLStore_CreatesDirectory(t *testing.T) {
dir := filepath.Join(t.TempDir(), "nested", "sessions")
store, err := NewJSONLStore(dir)
if err != nil {
t.Fatalf("NewJSONLStore: %v", err)
}
defer store.Close()
info, err := os.Stat(dir)
if err != nil {
t.Fatalf("Stat: %v", err)
}
if !info.IsDir() {
t.Errorf("expected directory, got file")
}
}
func TestAddMessage_BasicRoundtrip(t *testing.T) {
store := newTestStore(t)
ctx := context.Background()
err := store.AddMessage(ctx, "s1", "user", "hello")
if err != nil {
t.Fatalf("AddMessage: %v", err)
}
err = store.AddMessage(ctx, "s1", "assistant", "hi there")
if err != nil {
t.Fatalf("AddMessage: %v", err)
}
history, err := store.GetHistory(ctx, "s1")
if err != nil {
t.Fatalf("GetHistory: %v", err)
}
if len(history) != 2 {
t.Fatalf("expected 2 messages, got %d", len(history))
}
if history[0].Role != "user" || history[0].Content != "hello" {
t.Errorf("msg[0] = %+v", history[0])
}
if history[1].Role != "assistant" || history[1].Content != "hi there" {
t.Errorf("msg[1] = %+v", history[1])
}
}
func TestAddMessage_AutoCreatesSession(t *testing.T) {
store := newTestStore(t)
ctx := context.Background()
// Adding a message to a non-existent session should work.
err := store.AddMessage(ctx, "new-session", "user", "first message")
if err != nil {
t.Fatalf("AddMessage: %v", err)
}
history, err := store.GetHistory(ctx, "new-session")
if err != nil {
t.Fatalf("GetHistory: %v", err)
}
if len(history) != 1 {
t.Fatalf("expected 1 message, got %d", len(history))
}
}
func TestAddFullMessage_WithToolCalls(t *testing.T) {
store := newTestStore(t)
ctx := context.Background()
msg := providers.Message{
Role: "assistant",
Content: "Let me search that.",
ToolCalls: []providers.ToolCall{
{
ID: "call_abc",
Type: "function",
Function: &providers.FunctionCall{
Name: "web_search",
Arguments: `{"q":"golang jsonl"}`,
},
},
},
}
err := store.AddFullMessage(ctx, "tc", msg)
if err != nil {
t.Fatalf("AddFullMessage: %v", err)
}
history, err := store.GetHistory(ctx, "tc")
if err != nil {
t.Fatalf("GetHistory: %v", err)
}
if len(history) != 1 {
t.Fatalf("expected 1, got %d", len(history))
}
if len(history[0].ToolCalls) != 1 {
t.Fatalf("expected 1 tool call, got %d", len(history[0].ToolCalls))
}
tc := history[0].ToolCalls[0]
if tc.ID != "call_abc" {
t.Errorf("tool call ID = %q", tc.ID)
}
if tc.Function == nil || tc.Function.Name != "web_search" {
t.Errorf("tool call function = %+v", tc.Function)
}
}
func TestAddFullMessage_ToolCallID(t *testing.T) {
store := newTestStore(t)
ctx := context.Background()
msg := providers.Message{
Role: "tool",
Content: "search results here",
ToolCallID: "call_abc",
}
err := store.AddFullMessage(ctx, "tr", msg)
if err != nil {
t.Fatalf("AddFullMessage: %v", err)
}
history, err := store.GetHistory(ctx, "tr")
if err != nil {
t.Fatalf("GetHistory: %v", err)
}
if len(history) != 1 {
t.Fatalf("expected 1, got %d", len(history))
}
if history[0].ToolCallID != "call_abc" {
t.Errorf("ToolCallID = %q", history[0].ToolCallID)
}
}
func TestGetHistory_EmptySession(t *testing.T) {
store := newTestStore(t)
ctx := context.Background()
history, err := store.GetHistory(ctx, "nonexistent")
if err != nil {
t.Fatalf("GetHistory: %v", err)
}
if history == nil {
t.Fatal("expected non-nil empty slice")
}
if len(history) != 0 {
t.Errorf("expected 0 messages, got %d", len(history))
}
}
func TestGetHistory_Ordering(t *testing.T) {
store := newTestStore(t)
ctx := context.Background()
for i := 0; i < 5; i++ {
err := store.AddMessage(
ctx, "order",
"user",
string(rune('a'+i)),
)
if err != nil {
t.Fatalf("AddMessage(%d): %v", i, err)
}
}
history, err := store.GetHistory(ctx, "order")
if err != nil {
t.Fatalf("GetHistory: %v", err)
}
if len(history) != 5 {
t.Fatalf("expected 5, got %d", len(history))
}
for i := 0; i < 5; i++ {
expected := string(rune('a' + i))
if history[i].Content != expected {
t.Errorf("msg[%d].Content = %q, want %q", i, history[i].Content, expected)
}
}
}
func TestSetSummary_GetSummary(t *testing.T) {
store := newTestStore(t)
ctx := context.Background()
// No summary yet.
summary, err := store.GetSummary(ctx, "s1")
if err != nil {
t.Fatalf("GetSummary: %v", err)
}
if summary != "" {
t.Errorf("expected empty, got %q", summary)
}
// Set a summary.
err = store.SetSummary(ctx, "s1", "talked about Go")
if err != nil {
t.Fatalf("SetSummary: %v", err)
}
summary, err = store.GetSummary(ctx, "s1")
if err != nil {
t.Fatalf("GetSummary: %v", err)
}
if summary != "talked about Go" {
t.Errorf("summary = %q", summary)
}
// Update summary.
err = store.SetSummary(ctx, "s1", "updated summary")
if err != nil {
t.Fatalf("SetSummary: %v", err)
}
summary, err = store.GetSummary(ctx, "s1")
if err != nil {
t.Fatalf("GetSummary: %v", err)
}
if summary != "updated summary" {
t.Errorf("summary = %q", summary)
}
}
func TestTruncateHistory_KeepLast(t *testing.T) {
store := newTestStore(t)
ctx := context.Background()
for i := 0; i < 10; i++ {
err := store.AddMessage(
ctx, "trunc",
"user",
string(rune('a'+i)),
)
if err != nil {
t.Fatalf("AddMessage: %v", err)
}
}
err := store.TruncateHistory(ctx, "trunc", 4)
if err != nil {
t.Fatalf("TruncateHistory: %v", err)
}
history, err := store.GetHistory(ctx, "trunc")
if err != nil {
t.Fatalf("GetHistory: %v", err)
}
if len(history) != 4 {
t.Fatalf("expected 4, got %d", len(history))
}
// Should be the last 4: g, h, i, j
if history[0].Content != "g" {
t.Errorf("first kept = %q, want 'g'", history[0].Content)
}
if history[3].Content != "j" {
t.Errorf("last kept = %q, want 'j'", history[3].Content)
}
}
func TestTruncateHistory_KeepZero(t *testing.T) {
store := newTestStore(t)
ctx := context.Background()
for i := 0; i < 5; i++ {
err := store.AddMessage(ctx, "empty", "user", "msg")
if err != nil {
t.Fatalf("AddMessage: %v", err)
}
}
err := store.TruncateHistory(ctx, "empty", 0)
if err != nil {
t.Fatalf("TruncateHistory: %v", err)
}
history, err := store.GetHistory(ctx, "empty")
if err != nil {
t.Fatalf("GetHistory: %v", err)
}
if len(history) != 0 {
t.Errorf("expected 0, got %d", len(history))
}
}
func TestTruncateHistory_KeepMoreThanExists(t *testing.T) {
store := newTestStore(t)
ctx := context.Background()
for i := 0; i < 3; i++ {
err := store.AddMessage(ctx, "few", "user", "msg")
if err != nil {
t.Fatalf("AddMessage: %v", err)
}
}
// Keep 100, but only 3 exist — should keep all.
err := store.TruncateHistory(ctx, "few", 100)
if err != nil {
t.Fatalf("TruncateHistory: %v", err)
}
history, err := store.GetHistory(ctx, "few")
if err != nil {
t.Fatalf("GetHistory: %v", err)
}
if len(history) != 3 {
t.Errorf("expected 3, got %d", len(history))
}
}
func TestSetHistory_ReplacesAll(t *testing.T) {
store := newTestStore(t)
ctx := context.Background()
// Add some initial messages.
for i := 0; i < 5; i++ {
err := store.AddMessage(ctx, "replace", "user", "old")
if err != nil {
t.Fatalf("AddMessage: %v", err)
}
}
// Replace with new history.
newHistory := []providers.Message{
{Role: "user", Content: "new1"},
{Role: "assistant", Content: "new2"},
}
err := store.SetHistory(ctx, "replace", newHistory)
if err != nil {
t.Fatalf("SetHistory: %v", err)
}
history, err := store.GetHistory(ctx, "replace")
if err != nil {
t.Fatalf("GetHistory: %v", err)
}
if len(history) != 2 {
t.Fatalf("expected 2, got %d", len(history))
}
if history[0].Content != "new1" || history[1].Content != "new2" {
t.Errorf("history = %+v", history)
}
}
func TestSetHistory_ResetsSkip(t *testing.T) {
store := newTestStore(t)
ctx := context.Background()
// Add messages and truncate.
for i := 0; i < 10; i++ {
err := store.AddMessage(ctx, "skip-reset", "user", "old")
if err != nil {
t.Fatalf("AddMessage: %v", err)
}
}
err := store.TruncateHistory(ctx, "skip-reset", 3)
if err != nil {
t.Fatalf("TruncateHistory: %v", err)
}
// SetHistory should reset skip to 0.
newHistory := []providers.Message{
{Role: "user", Content: "fresh"},
}
err = store.SetHistory(ctx, "skip-reset", newHistory)
if err != nil {
t.Fatalf("SetHistory: %v", err)
}
history, err := store.GetHistory(ctx, "skip-reset")
if err != nil {
t.Fatalf("GetHistory: %v", err)
}
if len(history) != 1 {
t.Fatalf("expected 1, got %d", len(history))
}
if history[0].Content != "fresh" {
t.Errorf("content = %q", history[0].Content)
}
}
func TestColonInKey(t *testing.T) {
store := newTestStore(t)
ctx := context.Background()
err := store.AddMessage(ctx, "telegram:123", "user", "hi")
if err != nil {
t.Fatalf("AddMessage: %v", err)
}
history, err := store.GetHistory(ctx, "telegram:123")
if err != nil {
t.Fatalf("GetHistory: %v", err)
}
if len(history) != 1 {
t.Fatalf("expected 1, got %d", len(history))
}
// Verify the file is named with underscore.
jsonlFile := filepath.Join(store.dir, "telegram_123.jsonl")
if _, statErr := os.Stat(jsonlFile); statErr != nil {
t.Errorf("expected file %s to exist: %v", jsonlFile, statErr)
}
}
func TestCompact_RemovesSkippedMessages(t *testing.T) {
store := newTestStore(t)
ctx := context.Background()
// Write 10 messages, then truncate to keep last 3.
for i := 0; i < 10; i++ {
err := store.AddMessage(ctx, "compact", "user", string(rune('a'+i)))
if err != nil {
t.Fatalf("AddMessage: %v", err)
}
}
err := store.TruncateHistory(ctx, "compact", 3)
if err != nil {
t.Fatalf("TruncateHistory: %v", err)
}
// Before compact: file still has 10 lines.
allOnDisk, err := readMessages(store.jsonlPath("compact"), 0)
if err != nil {
t.Fatalf("readMessages: %v", err)
}
if len(allOnDisk) != 10 {
t.Fatalf("before compact: expected 10 on disk, got %d", len(allOnDisk))
}
// Compact.
err = store.Compact(ctx, "compact")
if err != nil {
t.Fatalf("Compact: %v", err)
}
// After compact: file should have only 3 lines.
allOnDisk, err = readMessages(store.jsonlPath("compact"), 0)
if err != nil {
t.Fatalf("readMessages: %v", err)
}
if len(allOnDisk) != 3 {
t.Fatalf("after compact: expected 3 on disk, got %d", len(allOnDisk))
}
// GetHistory should still return the same 3 messages.
history, err := store.GetHistory(ctx, "compact")
if err != nil {
t.Fatalf("GetHistory: %v", err)
}
if len(history) != 3 {
t.Fatalf("expected 3, got %d", len(history))
}
if history[0].Content != "h" || history[2].Content != "j" {
t.Errorf("wrong content: %+v", history)
}
}
func TestCompact_NoOpWhenNoSkip(t *testing.T) {
store := newTestStore(t)
ctx := context.Background()
for i := 0; i < 5; i++ {
err := store.AddMessage(ctx, "noop", "user", "msg")
if err != nil {
t.Fatalf("AddMessage: %v", err)
}
}
// Compact without prior truncation — should be a no-op.
err := store.Compact(ctx, "noop")
if err != nil {
t.Fatalf("Compact: %v", err)
}
history, err := store.GetHistory(ctx, "noop")
if err != nil {
t.Fatalf("GetHistory: %v", err)
}
if len(history) != 5 {
t.Errorf("expected 5, got %d", len(history))
}
}
func TestCompact_ThenAppend(t *testing.T) {
store := newTestStore(t)
ctx := context.Background()
for i := 0; i < 8; i++ {
err := store.AddMessage(ctx, "cap", "user", string(rune('a'+i)))
if err != nil {
t.Fatalf("AddMessage: %v", err)
}
}
err := store.TruncateHistory(ctx, "cap", 2)
if err != nil {
t.Fatalf("TruncateHistory: %v", err)
}
err = store.Compact(ctx, "cap")
if err != nil {
t.Fatalf("Compact: %v", err)
}
// Append after compaction should work correctly.
err = store.AddMessage(ctx, "cap", "user", "new")
if err != nil {
t.Fatalf("AddMessage after compact: %v", err)
}
history, err := store.GetHistory(ctx, "cap")
if err != nil {
t.Fatalf("GetHistory: %v", err)
}
if len(history) != 3 {
t.Fatalf("expected 3, got %d", len(history))
}
// g, h (kept from truncation), new (appended after compaction).
if history[0].Content != "g" {
t.Errorf("first = %q, want 'g'", history[0].Content)
}
if history[2].Content != "new" {
t.Errorf("last = %q, want 'new'", history[2].Content)
}
}
func TestTruncateHistory_StaleMetaCount(t *testing.T) {
// Simulates a crash between JSONL append and meta update in addMsg:
// file has N+1 lines but meta.Count is still N. TruncateHistory must
// reconcile with the real line count so that keepLast is accurate.
store := newTestStore(t)
ctx := context.Background()
// Write 10 messages normally (meta.Count = 10).
for i := 0; i < 10; i++ {
err := store.AddMessage(ctx, "stale", "user", string(rune('a'+i)))
if err != nil {
t.Fatalf("AddMessage: %v", err)
}
}
// Simulate crash: append a line to JSONL but do NOT update meta.
// This leaves meta.Count = 10 while the file has 11 lines.
jsonlPath := store.jsonlPath("stale")
f, err := os.OpenFile(jsonlPath, os.O_WRONLY|os.O_APPEND, 0o644)
if err != nil {
t.Fatalf("open for append: %v", err)
}
_, err = f.WriteString(`{"role":"user","content":"orphan"}` + "\n")
if err != nil {
t.Fatalf("write orphan: %v", err)
}
f.Close()
// TruncateHistory(keepLast=4) should keep the last 4 of 11 lines,
// not the last 4 of 10.
err = store.TruncateHistory(ctx, "stale", 4)
if err != nil {
t.Fatalf("TruncateHistory: %v", err)
}
history, err := store.GetHistory(ctx, "stale")
if err != nil {
t.Fatalf("GetHistory: %v", err)
}
if len(history) != 4 {
t.Fatalf("expected 4, got %d", len(history))
}
// Last 4 of [a,b,c,d,e,f,g,h,i,j,orphan] = [h,i,j,orphan]
if history[0].Content != "h" {
t.Errorf("first kept = %q, want 'h'", history[0].Content)
}
if history[3].Content != "orphan" {
t.Errorf("last kept = %q, want 'orphan'", history[3].Content)
}
}
func TestCrashRecovery_PartialLine(t *testing.T) {
store := newTestStore(t)
ctx := context.Background()
// Write a valid message first.
err := store.AddMessage(ctx, "crash", "user", "valid")
if err != nil {
t.Fatalf("AddMessage: %v", err)
}
// Simulate a crash by appending a partial JSON line directly.
jsonlPath := store.jsonlPath("crash")
f, err := os.OpenFile(jsonlPath, os.O_WRONLY|os.O_APPEND, 0o644)
if err != nil {
t.Fatalf("open for append: %v", err)
}
_, err = f.WriteString(`{"role":"user","content":"incomple`)
if err != nil {
t.Fatalf("write partial: %v", err)
}
f.Close()
// GetHistory should return only the valid message.
history, err := store.GetHistory(ctx, "crash")
if err != nil {
t.Fatalf("GetHistory: %v", err)
}
if len(history) != 1 {
t.Fatalf("expected 1 valid message, got %d", len(history))
}
if history[0].Content != "valid" {
t.Errorf("content = %q", history[0].Content)
}
}
func TestPersistence_AcrossInstances(t *testing.T) {
dir := t.TempDir()
ctx := context.Background()
// Write with first instance.
store1, err := NewJSONLStore(dir)
if err != nil {
t.Fatalf("NewJSONLStore: %v", err)
}
err = store1.AddMessage(ctx, "persist", "user", "remember me")
if err != nil {
t.Fatalf("AddMessage: %v", err)
}
err = store1.SetSummary(ctx, "persist", "a test session")
if err != nil {
t.Fatalf("SetSummary: %v", err)
}
store1.Close()
// Read with second instance.
store2, err := NewJSONLStore(dir)
if err != nil {
t.Fatalf("NewJSONLStore: %v", err)
}
defer store2.Close()
history, err := store2.GetHistory(ctx, "persist")
if err != nil {
t.Fatalf("GetHistory: %v", err)
}
if len(history) != 1 || history[0].Content != "remember me" {
t.Errorf("history = %+v", history)
}
summary, err := store2.GetSummary(ctx, "persist")
if err != nil {
t.Fatalf("GetSummary: %v", err)
}
if summary != "a test session" {
t.Errorf("summary = %q", summary)
}
}
func TestConcurrent_AddAndRead(t *testing.T) {
store := newTestStore(t)
ctx := context.Background()
var wg sync.WaitGroup
const goroutines = 10
const msgsPerGoroutine = 20
// Concurrent writes.
for g := 0; g < goroutines; g++ {
wg.Add(1)
go func() {
defer wg.Done()
for i := 0; i < msgsPerGoroutine; i++ {
_ = store.AddMessage(ctx, "concurrent", "user", "msg")
}
}()
}
wg.Wait()
history, err := store.GetHistory(ctx, "concurrent")
if err != nil {
t.Fatalf("GetHistory: %v", err)
}
expected := goroutines * msgsPerGoroutine
if len(history) != expected {
t.Errorf("expected %d messages, got %d", expected, len(history))
}
}
func TestConcurrent_SummarizeRace(t *testing.T) {
// Simulates the #704 race: one goroutine adds messages while
// another truncates + sets summary — like summarizeSession().
store := newTestStore(t)
ctx := context.Background()
// Seed with some messages.
for i := 0; i < 20; i++ {
err := store.AddMessage(ctx, "race", "user", "seed")
if err != nil {
t.Fatalf("AddMessage: %v", err)
}
}
var wg sync.WaitGroup
// Writer goroutine (main agent loop).
wg.Add(1)
go func() {
defer wg.Done()
for i := 0; i < 50; i++ {
_ = store.AddMessage(ctx, "race", "user", "new")
}
}()
// Summarizer goroutine (background task).
wg.Add(1)
go func() {
defer wg.Done()
for i := 0; i < 10; i++ {
_ = store.SetSummary(ctx, "race", "summary")
_ = store.TruncateHistory(ctx, "race", 5)
}
}()
wg.Wait()
// Verify the store is still in a consistent state.
_, err := store.GetHistory(ctx, "race")
if err != nil {
t.Fatalf("GetHistory after race: %v", err)
}
_, err = store.GetSummary(ctx, "race")
if err != nil {
t.Fatalf("GetSummary after race: %v", err)
}
}
func TestMultipleSessions_Isolation(t *testing.T) {
store := newTestStore(t)
ctx := context.Background()
err := store.AddMessage(ctx, "s1", "user", "msg for s1")
if err != nil {
t.Fatalf("AddMessage: %v", err)
}
err = store.AddMessage(ctx, "s2", "user", "msg for s2")
if err != nil {
t.Fatalf("AddMessage: %v", err)
}
h1, err := store.GetHistory(ctx, "s1")
if err != nil {
t.Fatalf("GetHistory s1: %v", err)
}
h2, err := store.GetHistory(ctx, "s2")
if err != nil {
t.Fatalf("GetHistory s2: %v", err)
}
if len(h1) != 1 || h1[0].Content != "msg for s1" {
t.Errorf("s1 history = %+v", h1)
}
if len(h2) != 1 || h2[0].Content != "msg for s2" {
t.Errorf("s2 history = %+v", h2)
}
}
func BenchmarkAddMessage(b *testing.B) {
dir := b.TempDir()
store, err := NewJSONLStore(dir)
if err != nil {
b.Fatalf("NewJSONLStore: %v", err)
}
defer store.Close()
ctx := context.Background()
b.ResetTimer()
for i := 0; i < b.N; i++ {
_ = store.AddMessage(ctx, "bench", "user", "benchmark message content")
}
}
func BenchmarkGetHistory_100(b *testing.B) {
dir := b.TempDir()
store, err := NewJSONLStore(dir)
if err != nil {
b.Fatalf("NewJSONLStore: %v", err)
}
defer store.Close()
ctx := context.Background()
for i := 0; i < 100; i++ {
_ = store.AddMessage(ctx, "bench", "user", "message content")
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
_, _ = store.GetHistory(ctx, "bench")
}
}
func BenchmarkGetHistory_1000(b *testing.B) {
dir := b.TempDir()
store, err := NewJSONLStore(dir)
if err != nil {
b.Fatalf("NewJSONLStore: %v", err)
}
defer store.Close()
ctx := context.Background()
for i := 0; i < 1000; i++ {
_ = store.AddMessage(ctx, "bench", "user", "message content")
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
_, _ = store.GetHistory(ctx, "bench")
}
}
+108
View File
@@ -0,0 +1,108 @@
package memory
import (
"context"
"encoding/json"
"fmt"
"log"
"os"
"path/filepath"
"strings"
"time"
"github.com/sipeed/picoclaw/pkg/providers"
)
// jsonSession mirrors pkg/session.Session for migration purposes.
type jsonSession struct {
Key string `json:"key"`
Messages []providers.Message `json:"messages"`
Summary string `json:"summary,omitempty"`
Created time.Time `json:"created"`
Updated time.Time `json:"updated"`
}
// MigrateFromJSON reads legacy sessions/*.json files from sessionsDir,
// writes them into the Store, and renames each migrated file to
// .json.migrated as a backup. Returns the number of sessions migrated.
//
// Files that fail to parse are logged and skipped. Already-migrated
// files (.json.migrated) are ignored, making the function idempotent.
func MigrateFromJSON(
ctx context.Context, sessionsDir string, store Store,
) (int, error) {
entries, err := os.ReadDir(sessionsDir)
if os.IsNotExist(err) {
return 0, nil
}
if err != nil {
return 0, fmt.Errorf("memory: read sessions dir: %w", err)
}
migrated := 0
for _, entry := range entries {
if entry.IsDir() {
continue
}
name := entry.Name()
if !strings.HasSuffix(name, ".json") {
continue
}
// Skip already-migrated files.
if strings.HasSuffix(name, ".migrated") {
continue
}
srcPath := filepath.Join(sessionsDir, name)
data, readErr := os.ReadFile(srcPath)
if readErr != nil {
log.Printf("memory: migrate: skip %s: %v", name, readErr)
continue
}
var sess jsonSession
if parseErr := json.Unmarshal(data, &sess); parseErr != nil {
log.Printf("memory: migrate: skip %s: %v", name, parseErr)
continue
}
// Use the key from the JSON content, not the filename.
// Filenames are sanitized (":" → "_") but keys are not.
key := sess.Key
if key == "" {
key = strings.TrimSuffix(name, ".json")
}
// Use SetHistory (atomic replace) instead of per-message
// AddFullMessage. This makes migration idempotent: if the
// process crashes after writing messages but before the
// rename below, a retry replaces the partial data cleanly
// instead of duplicating messages.
if setErr := store.SetHistory(ctx, key, sess.Messages); setErr != nil {
return migrated, fmt.Errorf(
"memory: migrate %s: set history: %w",
name, setErr,
)
}
if sess.Summary != "" {
if sumErr := store.SetSummary(ctx, key, sess.Summary); sumErr != nil {
return migrated, fmt.Errorf(
"memory: migrate %s: set summary: %w",
name, sumErr,
)
}
}
// Rename to .migrated as backup (not delete).
renameErr := os.Rename(srcPath, srcPath+".migrated")
if renameErr != nil {
log.Printf("memory: migrate: rename %s: %v", name, renameErr)
}
migrated++
}
return migrated, nil
}
+384
View File
@@ -0,0 +1,384 @@
package memory
import (
"context"
"encoding/json"
"os"
"path/filepath"
"testing"
"time"
"github.com/sipeed/picoclaw/pkg/providers"
)
func writeJSONSession(
t *testing.T, dir string, filename string, sess jsonSession,
) {
t.Helper()
data, err := json.MarshalIndent(sess, "", " ")
if err != nil {
t.Fatalf("marshal session: %v", err)
}
err = os.WriteFile(filepath.Join(dir, filename), data, 0o644)
if err != nil {
t.Fatalf("write session file: %v", err)
}
}
func TestMigrateFromJSON_Basic(t *testing.T) {
sessionsDir := t.TempDir()
store := newTestStore(t)
ctx := context.Background()
writeJSONSession(t, sessionsDir, "test.json", jsonSession{
Key: "test",
Messages: []providers.Message{
{Role: "user", Content: "hello"},
{Role: "assistant", Content: "hi"},
},
Summary: "A greeting.",
Created: time.Now(),
Updated: time.Now(),
})
count, err := MigrateFromJSON(ctx, sessionsDir, store)
if err != nil {
t.Fatalf("MigrateFromJSON: %v", err)
}
if count != 1 {
t.Errorf("expected 1 migrated, got %d", count)
}
history, err := store.GetHistory(ctx, "test")
if err != nil {
t.Fatalf("GetHistory: %v", err)
}
if len(history) != 2 {
t.Fatalf("expected 2 messages, got %d", len(history))
}
if history[0].Content != "hello" || history[1].Content != "hi" {
t.Errorf("unexpected messages: %+v", history)
}
summary, err := store.GetSummary(ctx, "test")
if err != nil {
t.Fatalf("GetSummary: %v", err)
}
if summary != "A greeting." {
t.Errorf("summary = %q", summary)
}
}
func TestMigrateFromJSON_WithToolCalls(t *testing.T) {
sessionsDir := t.TempDir()
store := newTestStore(t)
ctx := context.Background()
writeJSONSession(t, sessionsDir, "tools.json", jsonSession{
Key: "tools",
Messages: []providers.Message{
{
Role: "assistant",
Content: "Searching...",
ToolCalls: []providers.ToolCall{
{
ID: "call_1",
Type: "function",
Function: &providers.FunctionCall{
Name: "web_search",
Arguments: `{"q":"test"}`,
},
},
},
},
{
Role: "tool",
Content: "result",
ToolCallID: "call_1",
},
},
Created: time.Now(),
Updated: time.Now(),
})
count, err := MigrateFromJSON(ctx, sessionsDir, store)
if err != nil {
t.Fatalf("MigrateFromJSON: %v", err)
}
if count != 1 {
t.Errorf("expected 1, got %d", count)
}
history, err := store.GetHistory(ctx, "tools")
if err != nil {
t.Fatalf("GetHistory: %v", err)
}
if len(history) != 2 {
t.Fatalf("expected 2 messages, got %d", len(history))
}
if len(history[0].ToolCalls) != 1 {
t.Fatalf("expected 1 tool call, got %d", len(history[0].ToolCalls))
}
if history[0].ToolCalls[0].Function.Name != "web_search" {
t.Errorf("function = %q", history[0].ToolCalls[0].Function.Name)
}
if history[1].ToolCallID != "call_1" {
t.Errorf("ToolCallID = %q", history[1].ToolCallID)
}
}
func TestMigrateFromJSON_MultipleFiles(t *testing.T) {
sessionsDir := t.TempDir()
store := newTestStore(t)
ctx := context.Background()
for i := 0; i < 3; i++ {
key := string(rune('a' + i))
writeJSONSession(t, sessionsDir, key+".json", jsonSession{
Key: key,
Messages: []providers.Message{{Role: "user", Content: "msg " + key}},
Created: time.Now(),
Updated: time.Now(),
})
}
count, err := MigrateFromJSON(ctx, sessionsDir, store)
if err != nil {
t.Fatalf("MigrateFromJSON: %v", err)
}
if count != 3 {
t.Errorf("expected 3, got %d", count)
}
for i := 0; i < 3; i++ {
key := string(rune('a' + i))
history, histErr := store.GetHistory(ctx, key)
if histErr != nil {
t.Fatalf("GetHistory(%q): %v", key, histErr)
}
if len(history) != 1 {
t.Errorf("session %q: expected 1 msg, got %d", key, len(history))
}
}
}
func TestMigrateFromJSON_InvalidJSON(t *testing.T) {
sessionsDir := t.TempDir()
store := newTestStore(t)
ctx := context.Background()
// One valid, one invalid.
writeJSONSession(t, sessionsDir, "good.json", jsonSession{
Key: "good",
Messages: []providers.Message{{Role: "user", Content: "ok"}},
Created: time.Now(),
Updated: time.Now(),
})
err := os.WriteFile(
filepath.Join(sessionsDir, "bad.json"),
[]byte("{invalid json"),
0o644,
)
if err != nil {
t.Fatalf("write bad file: %v", err)
}
count, err := MigrateFromJSON(ctx, sessionsDir, store)
if err != nil {
t.Fatalf("MigrateFromJSON: %v", err)
}
if count != 1 {
t.Errorf("expected 1 (bad file skipped), got %d", count)
}
history, err := store.GetHistory(ctx, "good")
if err != nil {
t.Fatalf("GetHistory: %v", err)
}
if len(history) != 1 {
t.Errorf("expected 1 message, got %d", len(history))
}
}
func TestMigrateFromJSON_RenamesFiles(t *testing.T) {
sessionsDir := t.TempDir()
store := newTestStore(t)
ctx := context.Background()
writeJSONSession(t, sessionsDir, "rename.json", jsonSession{
Key: "rename",
Messages: []providers.Message{{Role: "user", Content: "hi"}},
Created: time.Now(),
Updated: time.Now(),
})
_, err := MigrateFromJSON(ctx, sessionsDir, store)
if err != nil {
t.Fatalf("MigrateFromJSON: %v", err)
}
// Original .json should not exist.
_, statErr := os.Stat(filepath.Join(sessionsDir, "rename.json"))
if !os.IsNotExist(statErr) {
t.Error("rename.json should have been renamed")
}
// .json.migrated should exist.
_, statErr = os.Stat(
filepath.Join(sessionsDir, "rename.json.migrated"),
)
if statErr != nil {
t.Errorf("rename.json.migrated should exist: %v", statErr)
}
}
func TestMigrateFromJSON_Idempotent(t *testing.T) {
sessionsDir := t.TempDir()
store := newTestStore(t)
ctx := context.Background()
writeJSONSession(t, sessionsDir, "idem.json", jsonSession{
Key: "idem",
Messages: []providers.Message{{Role: "user", Content: "once"}},
Created: time.Now(),
Updated: time.Now(),
})
count1, err := MigrateFromJSON(ctx, sessionsDir, store)
if err != nil {
t.Fatalf("first migration: %v", err)
}
if count1 != 1 {
t.Errorf("first run: expected 1, got %d", count1)
}
// Second run should find only .migrated files, skip them.
count2, err := MigrateFromJSON(ctx, sessionsDir, store)
if err != nil {
t.Fatalf("second migration: %v", err)
}
if count2 != 0 {
t.Errorf("second run: expected 0, got %d", count2)
}
history, err := store.GetHistory(ctx, "idem")
if err != nil {
t.Fatalf("GetHistory: %v", err)
}
if len(history) != 1 {
t.Errorf("expected 1 message, got %d", len(history))
}
}
func TestMigrateFromJSON_ColonInKey(t *testing.T) {
sessionsDir := t.TempDir()
store := newTestStore(t)
ctx := context.Background()
// File is named telegram_123 (sanitized), but the key inside is telegram:123.
writeJSONSession(t, sessionsDir, "telegram_123.json", jsonSession{
Key: "telegram:123",
Messages: []providers.Message{{Role: "user", Content: "from telegram"}},
Created: time.Now(),
Updated: time.Now(),
})
count, err := MigrateFromJSON(ctx, sessionsDir, store)
if err != nil {
t.Fatalf("MigrateFromJSON: %v", err)
}
if count != 1 {
t.Errorf("expected 1, got %d", count)
}
// Accessible via the original key "telegram:123".
history, err := store.GetHistory(ctx, "telegram:123")
if err != nil {
t.Fatalf("GetHistory: %v", err)
}
if len(history) != 1 {
t.Fatalf("expected 1 message, got %d", len(history))
}
if history[0].Content != "from telegram" {
t.Errorf("content = %q", history[0].Content)
}
// In the file-based store, "telegram:123" and "telegram_123" both
// sanitize to the same filename, so they share storage. This is
// expected — the colon-to-underscore mapping is a one-way function.
history2, err := store.GetHistory(ctx, "telegram_123")
if err != nil {
t.Fatalf("GetHistory: %v", err)
}
if len(history2) != 1 {
t.Errorf("expected 1 (same file), got %d", len(history2))
}
}
func TestMigrateFromJSON_RetryAfterCrash(t *testing.T) {
// Simulates a crash during migration: first run writes messages
// but doesn't rename the .json file. Second run must replace
// (not duplicate) the messages thanks to SetHistory semantics.
sessionsDir := t.TempDir()
store := newTestStore(t)
ctx := context.Background()
writeJSONSession(t, sessionsDir, "retry.json", jsonSession{
Key: "retry",
Messages: []providers.Message{
{Role: "user", Content: "one"},
{Role: "assistant", Content: "two"},
},
Created: time.Now(),
Updated: time.Now(),
})
// First migration succeeds — writes messages and renames file.
count, err := MigrateFromJSON(ctx, sessionsDir, store)
if err != nil {
t.Fatalf("first migration: %v", err)
}
if count != 1 {
t.Fatalf("expected 1, got %d", count)
}
// Simulate "crash before rename": restore the .json file.
src := filepath.Join(sessionsDir, "retry.json.migrated")
dst := filepath.Join(sessionsDir, "retry.json")
if renameErr := os.Rename(src, dst); renameErr != nil {
t.Fatalf("restore .json: %v", renameErr)
}
// Second migration should re-import without duplicating messages.
count, err = MigrateFromJSON(ctx, sessionsDir, store)
if err != nil {
t.Fatalf("second migration: %v", err)
}
if count != 1 {
t.Fatalf("expected 1, got %d", count)
}
history, err := store.GetHistory(ctx, "retry")
if err != nil {
t.Fatalf("GetHistory: %v", err)
}
// Must be exactly 2 messages (not 4 from duplication).
if len(history) != 2 {
t.Fatalf("expected 2 messages (no duplicates), got %d", len(history))
}
if history[0].Content != "one" || history[1].Content != "two" {
t.Errorf("unexpected messages: %+v", history)
}
}
func TestMigrateFromJSON_NonexistentDir(t *testing.T) {
store := newTestStore(t)
ctx := context.Background()
count, err := MigrateFromJSON(ctx, "/nonexistent/path", store)
if err != nil {
t.Fatalf("MigrateFromJSON: %v", err)
}
if count != 0 {
t.Errorf("expected 0, got %d", count)
}
}
+42
View File
@@ -0,0 +1,42 @@
package memory
import (
"context"
"github.com/sipeed/picoclaw/pkg/providers"
)
// Store defines an interface for persistent session storage.
// Each method is an atomic operation — there is no separate Save() call.
type Store interface {
// AddMessage appends a simple text message to a session.
AddMessage(ctx context.Context, sessionKey, role, content string) error
// AddFullMessage appends a complete message (with tool calls, etc.) to a session.
AddFullMessage(ctx context.Context, sessionKey string, msg providers.Message) error
// GetHistory returns all messages for a session in insertion order.
// Returns an empty slice (not nil) if the session does not exist.
GetHistory(ctx context.Context, sessionKey string) ([]providers.Message, error)
// GetSummary returns the conversation summary for a session.
// Returns an empty string if no summary exists.
GetSummary(ctx context.Context, sessionKey string) (string, error)
// SetSummary updates the conversation summary for a session.
SetSummary(ctx context.Context, sessionKey, summary string) error
// TruncateHistory removes all but the last keepLast messages from a session.
// If keepLast <= 0, all messages are removed.
TruncateHistory(ctx context.Context, sessionKey string, keepLast int) error
// SetHistory replaces all messages in a session with the provided history.
SetHistory(ctx context.Context, sessionKey string, history []providers.Message) error
// Compact reclaims storage by physically removing logically truncated
// data. Backends that do not accumulate dead data may return nil.
Compact(ctx context.Context, sessionKey string) error
// Close releases any resources held by the store.
Close() error
}
+1 -27
View File
@@ -190,28 +190,6 @@ func resolveProviderSelection(cfg *config.Config) (providerSelection, error) {
sel.apiBase = "https://api.mistral.ai/v1"
}
}
case "opencode":
if cfg.Providers.Opencode.APIKey != "" || cfg.Providers.Opencode.APIBase != "" {
sel.apiKey = cfg.Providers.Opencode.APIKey
sel.apiBase = cfg.Providers.Opencode.APIBase
sel.proxy = cfg.Providers.Opencode.Proxy
if sel.apiBase == "" {
sel.apiBase = "https://opencode.ai/zen/v1"
}
}
case "kimi", "kimi-code", "moonshot":
if cfg.Providers.Moonshot.APIKey != "" {
sel.apiKey = cfg.Providers.Moonshot.APIKey
sel.apiBase = cfg.Providers.Moonshot.APIBase
sel.proxy = cfg.Providers.Moonshot.Proxy
if sel.apiBase == "" {
if providerName == "moonshot" {
sel.apiBase = "https://api.moonshot.cn/v1"
} else {
sel.apiBase = "https://api.kimi.com/coding/v1"
}
}
}
case "github_copilot", "copilot":
sel.providerType = providerTypeGitHubCopilot
if cfg.Providers.GitHubCopilot.APIBase != "" {
@@ -232,11 +210,7 @@ func resolveProviderSelection(cfg *config.Config) (providerSelection, error) {
sel.apiBase = cfg.Providers.Moonshot.APIBase
sel.proxy = cfg.Providers.Moonshot.Proxy
if sel.apiBase == "" {
if strings.Contains(lowerModel, "moonshot") || strings.HasPrefix(model, "moonshot/") {
sel.apiBase = "https://api.moonshot.cn/v1"
} else {
sel.apiBase = "https://api.kimi.com/coding/v1"
}
sel.apiBase = "https://api.moonshot.cn/v1"
}
case strings.HasPrefix(model, "openrouter/") ||
strings.HasPrefix(model, "anthropic/") ||
+1 -3
View File
@@ -94,7 +94,7 @@ func CreateProviderFromConfig(cfg *config.ModelConfig) (LLMProvider, string, err
case "litellm", "openrouter", "groq", "zhipu", "gemini", "nvidia",
"ollama", "moonshot", "shengsuanyun", "deepseek", "cerebras",
"volcengine", "vllm", "qwen", "mistral", "opencode":
"volcengine", "vllm", "qwen", "mistral":
// All other OpenAI-compatible HTTP providers
if cfg.APIKey == "" && cfg.APIBase == "" {
return nil, "", fmt.Errorf("api_key or api_base is required for HTTP-based protocol %q", protocol)
@@ -208,8 +208,6 @@ func getDefaultAPIBase(protocol string) string {
return "http://localhost:8000/v1"
case "mistral":
return "https://api.mistral.ai/v1"
case "opencode":
return "https://opencode.ai/zen/v1"
default:
return ""
}
-1
View File
@@ -112,7 +112,6 @@ func TestCreateProviderFromConfig_DefaultAPIBase(t *testing.T) {
{"vllm", "vllm"},
{"deepseek", "deepseek"},
{"ollama", "ollama"},
{"opencode", "opencode"},
}
for _, tt := range tests {
+1 -15
View File
@@ -33,7 +33,6 @@ type Provider struct {
apiBase string
maxTokensField string // Field name for max tokens (e.g., "max_completion_tokens" for o1/glm models)
httpClient *http.Client
isKimiAPI bool // true when apiBase points to api.kimi.com
}
type Option func(*Provider)
@@ -70,17 +69,10 @@ func NewProvider(apiKey, apiBase, proxy string, opts ...Option) *Provider {
}
}
trimmedBase := strings.TrimRight(apiBase, "/")
var isKimi bool
if parsed, err := url.Parse(trimmedBase); err == nil {
isKimi = parsed.Hostname() == "api.kimi.com"
}
p := &Provider{
apiKey: apiKey,
apiBase: trimmedBase,
apiBase: strings.TrimRight(apiBase, "/"),
httpClient: client,
isKimiAPI: isKimi,
}
for _, opt := range opts {
@@ -184,12 +176,6 @@ func (p *Provider) Chat(
if p.apiKey != "" {
req.Header.Set("Authorization", "Bearer "+p.apiKey)
}
// Kimi Code API rejects requests without a recognized coding-agent
// User-Agent. "KimiCLI/0.77" is the minimum version string accepted
// by the api.kimi.com/coding/v1 endpoint (per Kimi's API docs).
if p.isKimiAPI {
req.Header.Set("User-Agent", "KimiCLI/0.77")
}
resp, err := p.httpClient.Do(req)
if err != nil {
@@ -2,7 +2,6 @@ package openai_compat
import (
"encoding/json"
"io"
"net/http"
"net/http/httptest"
"net/url"
@@ -421,82 +420,6 @@ func TestProvider_FunctionalOptionRequestTimeoutNonPositive(t *testing.T) {
}
}
// roundTripFunc adapts a function to http.RoundTripper for test injection.
type roundTripFunc func(*http.Request) (*http.Response, error)
func (f roundTripFunc) RoundTrip(r *http.Request) (*http.Response, error) {
return f(r)
}
func TestProviderChat_KimiCodeUserAgent(t *testing.T) {
okBody := `{"choices":[{"message":{"content":"ok"},"finish_reason":"stop"}]}`
tests := []struct {
name string
apiBase string
wantAgent string
}{
{
name: "sets KimiCLI User-Agent for api.kimi.com",
apiBase: "https://api.kimi.com/coding/v1",
wantAgent: "KimiCLI/0.77",
},
{
name: "does not set KimiCLI User-Agent for other hosts",
apiBase: "https://api.example.com/v1",
wantAgent: "",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var gotUserAgent string
p := NewProvider("key", tt.apiBase, "")
p.httpClient.Transport = roundTripFunc(
func(r *http.Request) (*http.Response, error) {
gotUserAgent = r.Header.Get("User-Agent")
return &http.Response{
StatusCode: http.StatusOK,
Body: io.NopCloser(
strings.NewReader(okBody),
),
Header: http.Header{
"Content-Type": {"application/json"},
},
}, nil
},
)
_, err := p.Chat(
t.Context(),
[]Message{{Role: "user", Content: "hi"}},
nil,
"kimi-k2.5",
nil,
)
if err != nil {
t.Fatalf("Chat() error = %v", err)
}
if tt.wantAgent != "" {
if gotUserAgent != tt.wantAgent {
t.Fatalf(
"User-Agent = %q, want %q",
gotUserAgent, tt.wantAgent,
)
}
} else {
if gotUserAgent == "KimiCLI/0.77" {
t.Fatalf(
"User-Agent should not be KimiCLI/0.77 for non-kimi host",
)
}
}
})
}
}
func TestSerializeMessages_PlainText(t *testing.T) {
messages := []protocoltypes.Message{
{Role: "user", Content: "hello"},
+5 -4
View File
@@ -3,6 +3,7 @@ package tools
import (
"context"
"fmt"
"sync/atomic"
)
type SendCallback func(channel, chatID, content string) error
@@ -11,7 +12,7 @@ type MessageTool struct {
sendCallback SendCallback
defaultChannel string
defaultChatID string
sentInRound bool // Tracks whether a message was sent in the current processing round
sentInRound atomic.Bool // Tracks whether a message was sent in the current processing round
}
func NewMessageTool() *MessageTool {
@@ -50,12 +51,12 @@ func (t *MessageTool) Parameters() map[string]any {
func (t *MessageTool) SetContext(channel, chatID string) {
t.defaultChannel = channel
t.defaultChatID = chatID
t.sentInRound = false // Reset send tracking for new processing round
t.sentInRound.Store(false) // Reset send tracking for new processing round
}
// HasSentInRound returns true if the message tool sent a message during the current round.
func (t *MessageTool) HasSentInRound() bool {
return t.sentInRound
return t.sentInRound.Load()
}
func (t *MessageTool) SetSendCallback(callback SendCallback) {
@@ -94,7 +95,7 @@ func (t *MessageTool) Execute(ctx context.Context, args map[string]any) *ToolRes
}
}
t.sentInRound = true
t.sentInRound.Store(true)
// Silent: user already received the message directly
return &ToolResult{
ForLLM: fmt.Sprintf("Message sent to %s:%s", channel, chatID),
+2 -2
View File
@@ -30,9 +30,9 @@ var (
regexp.MustCompile(`\brm\s+-[rf]{1,2}\b`),
regexp.MustCompile(`\bdel\s+/[fq]\b`),
regexp.MustCompile(`\brmdir\s+/s\b`),
// Match disk wiping commands, avoid matching --format flags
// Match disk wiping commands (must be followed by space/args)
regexp.MustCompile(
`(?:^|[;&|]\s*|\s+)(format|mkfs|diskpart)\s`,
`\b(format|mkfs|diskpart)\b\s`,
),
regexp.MustCompile(`\bdd\s+if=`),
// Block writes to block devices (all common naming schemes).
-50
View File
@@ -366,56 +366,6 @@ func TestShellTool_BlockDevices(t *testing.T) {
}
}
// TestShellTool_DenyPattern_DiskWiping verifies the deny pattern for disk wiping
// commands (format, mkfs, diskpart) blocks them when preceded by shell separators
// but does NOT block legitimate uses like --format flags.
func TestShellTool_DenyPattern_DiskWiping(t *testing.T) {
tool, err := NewExecTool("", false)
if err != nil {
t.Fatalf("unable to configure exec tool: %s", err)
}
// These should be BLOCKED (disk wiping commands)
blockedCmds := []struct {
name string
cmd string
}{
{"format with space", "format c:"},
{"mkfs standalone", "mkfs /dev/sda"},
{"semicolon format", "echo hello; format c:"},
{"pipe format", "echo hello | format c:"},
{"and format", "echo hello && format c:"},
{"diskpart standalone", "diskpart /s script.txt"},
}
for _, tt := range blockedCmds {
t.Run("blocked_"+tt.name, func(t *testing.T) {
msg := tool.guardCommand(tt.cmd, "")
if !strings.Contains(msg, "blocked") {
t.Errorf("Expected %q to be blocked by safety guard, got: %q", tt.cmd, msg)
}
})
}
// These should be ALLOWED (not disk wiping)
allowed := []struct {
name string
cmd string
}{
{"--format flag", "echo test --format json"},
{"go fmt", "echo go fmt ./..."},
}
for _, tt := range allowed {
t.Run("allowed_"+tt.name, func(t *testing.T) {
msg := tool.guardCommand(tt.cmd, "")
if msg != "" {
t.Errorf("Expected %q to be allowed, but it was blocked: %s", tt.cmd, msg)
}
})
}
}
// TestShellTool_SafePathsInWorkspaceRestriction verifies that safe kernel pseudo-devices
// are allowed even when workspace restriction is active.
func TestShellTool_SafePathsInWorkspaceRestriction(t *testing.T) {
+43 -26
View File
@@ -10,6 +10,7 @@ import (
"context"
"encoding/json"
"fmt"
"sync"
"github.com/sipeed/picoclaw/pkg/logger"
"github.com/sipeed/picoclaw/pkg/providers"
@@ -121,37 +122,53 @@ func RunToolLoop(
}
messages = append(messages, assistantMsg)
// 7. Execute tool calls
for _, tc := range normalizedToolCalls {
argsJSON, _ := json.Marshal(tc.Arguments)
argsPreview := utils.Truncate(string(argsJSON), 200)
logger.InfoCF("toolloop", fmt.Sprintf("Tool call: %s(%s)", tc.Name, argsPreview),
map[string]any{
"tool": tc.Name,
"iteration": iteration,
})
// 7. Execute tool calls in parallel
type indexedResult struct {
result *ToolResult
tc providers.ToolCall
}
// Execute tool (no async callback for subagents - they run independently)
var toolResult *ToolResult
if config.Tools != nil {
toolResult = config.Tools.ExecuteWithContext(ctx, tc.Name, tc.Arguments, channel, chatID, nil)
} else {
toolResult = ErrorResult("No tools available")
results := make([]indexedResult, len(normalizedToolCalls))
var wg sync.WaitGroup
for i, tc := range normalizedToolCalls {
results[i].tc = tc
wg.Add(1)
go func(idx int, tc providers.ToolCall) {
defer wg.Done()
argsJSON, _ := json.Marshal(tc.Arguments)
argsPreview := utils.Truncate(string(argsJSON), 200)
logger.InfoCF("toolloop", fmt.Sprintf("Tool call: %s(%s)", tc.Name, argsPreview),
map[string]any{
"tool": tc.Name,
"iteration": iteration,
})
var toolResult *ToolResult
if config.Tools != nil {
toolResult = config.Tools.ExecuteWithContext(ctx, tc.Name, tc.Arguments, channel, chatID, nil)
} else {
toolResult = ErrorResult("No tools available")
}
results[idx].result = toolResult
}(i, tc)
}
wg.Wait()
// Append results in original order
for _, r := range results {
contentForLLM := r.result.ForLLM
if contentForLLM == "" && r.result.Err != nil {
contentForLLM = r.result.Err.Error()
}
// Determine content for LLM
contentForLLM := toolResult.ForLLM
if contentForLLM == "" && toolResult.Err != nil {
contentForLLM = toolResult.Err.Error()
}
// Add tool result message
toolResultMsg := providers.Message{
messages = append(messages, providers.Message{
Role: "tool",
Content: contentForLLM,
ToolCallID: tc.ID,
}
messages = append(messages, toolResultMsg)
ToolCallID: r.tc.ID,
})
}
}
+107 -87
View File
@@ -395,6 +395,88 @@ func (p *PerplexitySearchProvider) Search(ctx context.Context, query string, cou
return fmt.Sprintf("Results for: %s (via Perplexity)\n%s", query, searchResp.Choices[0].Message.Content), nil
}
type GLMSearchProvider struct {
apiKey string
baseURL string
searchEngine string
proxy string
client *http.Client
}
func (p *GLMSearchProvider) Search(ctx context.Context, query string, count int) (string, error) {
searchURL := p.baseURL
if searchURL == "" {
searchURL = "https://open.bigmodel.cn/api/paas/v4/web_search"
}
payload := map[string]any{
"search_query": query,
"search_engine": p.searchEngine,
"search_intent": false,
"count": count,
"content_size": "medium",
}
bodyBytes, err := json.Marshal(payload)
if err != nil {
return "", fmt.Errorf("failed to marshal payload: %w", err)
}
req, err := http.NewRequestWithContext(ctx, "POST", searchURL, bytes.NewReader(bodyBytes))
if err != nil {
return "", fmt.Errorf("failed to create request: %w", err)
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+p.apiKey)
resp, err := p.client.Do(req)
if err != nil {
return "", fmt.Errorf("request failed: %w", err)
}
defer resp.Body.Close()
body, err := io.ReadAll(io.LimitReader(resp.Body, 1<<20))
if err != nil {
return "", fmt.Errorf("failed to read response: %w", err)
}
if resp.StatusCode != http.StatusOK {
return "", fmt.Errorf("GLM Search API error (status %d): %s", resp.StatusCode, string(body))
}
var searchResp struct {
SearchResult []struct {
Title string `json:"title"`
Content string `json:"content"`
Link string `json:"link"`
} `json:"search_result"`
}
if err := json.Unmarshal(body, &searchResp); err != nil {
return "", fmt.Errorf("failed to parse response: %w", err)
}
results := searchResp.SearchResult
if len(results) == 0 {
return fmt.Sprintf("No results for: %s", query), nil
}
var lines []string
lines = append(lines, fmt.Sprintf("Results for: %s (via GLM Search)", query))
for i, item := range results {
if i >= count {
break
}
lines = append(lines, fmt.Sprintf("%d. %s\n %s", i+1, item.Title, item.Link))
if item.Content != "" {
lines = append(lines, fmt.Sprintf(" %s", item.Content))
}
}
return strings.Join(lines, "\n"), nil
}
type WebSearchTool struct {
provider SearchProvider
maxResults int
@@ -413,9 +495,11 @@ type WebSearchToolOptions struct {
PerplexityAPIKey string
PerplexityMaxResults int
PerplexityEnabled bool
ExaAPIKey string
ExaMaxResults int
ExaEnabled bool
GLMSearchAPIKey string
GLMSearchBaseURL string
GLMSearchEngine string
GLMSearchMaxResults int
GLMSearchEnabled bool
Proxy string
}
@@ -423,7 +507,7 @@ func NewWebSearchTool(opts WebSearchToolOptions) (*WebSearchTool, error) {
var provider SearchProvider
maxResults := 5
// Priority: Perplexity > Exa > Brave > Tavily > DuckDuckGo
// Priority: Perplexity > Brave > Tavily > DuckDuckGo > GLM Search
if opts.PerplexityEnabled && opts.PerplexityAPIKey != "" {
client, err := createHTTPClient(opts.Proxy, perplexityTimeout)
if err != nil {
@@ -433,15 +517,6 @@ func NewWebSearchTool(opts WebSearchToolOptions) (*WebSearchTool, error) {
if opts.PerplexityMaxResults > 0 {
maxResults = opts.PerplexityMaxResults
}
} else if opts.ExaEnabled && opts.ExaAPIKey != "" {
client, err := createHTTPClient(opts.Proxy, searchTimeout)
if err != nil {
return nil, fmt.Errorf("failed to create HTTP client for Exa: %w", err)
}
provider = &ExaSearchProvider{apiKey: opts.ExaAPIKey, proxy: opts.Proxy, client: client}
if opts.ExaMaxResults > 0 {
maxResults = opts.ExaMaxResults
}
} else if opts.BraveEnabled && opts.BraveAPIKey != "" {
client, err := createHTTPClient(opts.Proxy, searchTimeout)
if err != nil {
@@ -474,6 +549,25 @@ func NewWebSearchTool(opts WebSearchToolOptions) (*WebSearchTool, error) {
if opts.DuckDuckGoMaxResults > 0 {
maxResults = opts.DuckDuckGoMaxResults
}
} else if opts.GLMSearchEnabled && opts.GLMSearchAPIKey != "" {
client, err := createHTTPClient(opts.Proxy, searchTimeout)
if err != nil {
return nil, fmt.Errorf("failed to create HTTP client for GLM Search: %w", err)
}
searchEngine := opts.GLMSearchEngine
if searchEngine == "" {
searchEngine = "search_std"
}
provider = &GLMSearchProvider{
apiKey: opts.GLMSearchAPIKey,
baseURL: opts.GLMSearchBaseURL,
searchEngine: searchEngine,
proxy: opts.Proxy,
client: client,
}
if opts.GLMSearchMaxResults > 0 {
maxResults = opts.GLMSearchMaxResults
}
} else {
return nil, nil
}
@@ -721,77 +815,3 @@ func (t *WebFetchTool) extractText(htmlContent string) string {
return strings.Join(cleanLines, "\n")
}
// ExaSearchProvider uses the Exa AI search API (https://exa.ai).
type ExaSearchProvider struct {
apiKey string
proxy string
client *http.Client
}
func (p *ExaSearchProvider) Search(ctx context.Context, query string, count int) (string, error) {
reqBody := map[string]any{
"query": query,
"num_results": count,
"type": "neural",
}
jsonData, err := json.Marshal(reqBody)
if err != nil {
return "", fmt.Errorf("exa: marshal error: %w", err)
}
req, err := http.NewRequestWithContext(ctx, "POST", "https://api.exa.ai/search", bytes.NewReader(jsonData))
if err != nil {
return "", fmt.Errorf("exa: request error: %w", err)
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("x-api-key", p.apiKey)
resp, err := p.client.Do(req)
if err != nil {
return "", fmt.Errorf("exa: search failed: %w", err)
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
return "", fmt.Errorf("exa: read error: %w", err)
}
if resp.StatusCode != http.StatusOK {
return "", fmt.Errorf("exa: API error %d: %s", resp.StatusCode, string(body))
}
var result struct {
Results []struct {
Title string `json:"title"`
URL string `json:"url"`
Text string `json:"text"`
} `json:"results"`
}
if err := json.Unmarshal(body, &result); err != nil {
return "", fmt.Errorf("exa: parse error: %w", err)
}
if len(result.Results) == 0 {
return fmt.Sprintf("No results for: %s", query), nil
}
var lines []string
lines = append(lines, fmt.Sprintf("Results for: %s (via Exa)", query))
maxResults := count
if maxResults > len(result.Results) {
maxResults = len(result.Results)
}
for i, r := range result.Results[:maxResults] {
lines = append(lines, fmt.Sprintf("%d. %s\n %s", i+1, r.Title, r.URL))
if r.Text != "" {
snippet := r.Text
if len(snippet) > 200 {
snippet = snippet[:200] + "..."
}
lines = append(lines, fmt.Sprintf(" %s", snippet))
}
}
return strings.Join(lines, "\n"), nil
}
+105 -189
View File
@@ -5,7 +5,6 @@ import (
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"net/http/httptest"
"strings"
@@ -683,86 +682,7 @@ func TestWebTool_TavilySearch_Success(t *testing.T) {
}
}
func TestNewWebSearchTool_ExaPriority(t *testing.T) {
// Exa should be selected when enabled with API key
tool, err := NewWebSearchTool(WebSearchToolOptions{
ExaEnabled: true,
ExaAPIKey: "exa-key",
ExaMaxResults: 3,
})
if err != nil {
t.Fatalf("NewWebSearchTool() error: %v", err)
}
if tool == nil {
t.Fatal("Expected non-nil tool when Exa is enabled with API key")
}
if _, ok := tool.provider.(*ExaSearchProvider); !ok {
t.Fatalf("provider type = %T, want *ExaSearchProvider", tool.provider)
}
if tool.maxResults != 3 {
t.Fatalf("maxResults = %d, want 3", tool.maxResults)
}
// Exa enabled but no API key should fall through
tool, err = NewWebSearchTool(WebSearchToolOptions{
ExaEnabled: true,
ExaAPIKey: "",
})
if err != nil {
t.Fatalf("NewWebSearchTool() error: %v", err)
}
if tool != nil {
t.Errorf("Expected nil tool when Exa API key is empty and no other provider enabled")
}
// Perplexity should take priority over Exa
tool, err = NewWebSearchTool(WebSearchToolOptions{
PerplexityEnabled: true,
PerplexityAPIKey: "perp-key",
ExaEnabled: true,
ExaAPIKey: "exa-key",
})
if err != nil {
t.Fatalf("NewWebSearchTool() error: %v", err)
}
if _, ok := tool.provider.(*PerplexitySearchProvider); !ok {
t.Fatalf("provider type = %T, want *PerplexitySearchProvider (Perplexity should outrank Exa)", tool.provider)
}
// Exa should take priority over Brave
tool, err = NewWebSearchTool(WebSearchToolOptions{
ExaEnabled: true,
ExaAPIKey: "exa-key",
BraveEnabled: true,
BraveAPIKey: "brave-key",
})
if err != nil {
t.Fatalf("NewWebSearchTool() error: %v", err)
}
if _, ok := tool.provider.(*ExaSearchProvider); !ok {
t.Fatalf("provider type = %T, want *ExaSearchProvider (Exa should outrank Brave)", tool.provider)
}
}
func TestNewWebSearchTool_ExaProxyPropagation(t *testing.T) {
tool, err := NewWebSearchTool(WebSearchToolOptions{
ExaEnabled: true,
ExaAPIKey: "k",
Proxy: "http://127.0.0.1:7890",
})
if err != nil {
t.Fatalf("NewWebSearchTool() error: %v", err)
}
p, ok := tool.provider.(*ExaSearchProvider)
if !ok {
t.Fatalf("provider type = %T, want *ExaSearchProvider", tool.provider)
}
if p.proxy != "http://127.0.0.1:7890" {
t.Fatalf("provider proxy = %q, want %q", p.proxy, "http://127.0.0.1:7890")
}
}
func TestExaSearchProvider_Success(t *testing.T) {
func TestWebTool_GLMSearch_Success(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method != "POST" {
t.Errorf("Expected POST request, got %s", r.Method)
@@ -770,130 +690,126 @@ func TestExaSearchProvider_Success(t *testing.T) {
if r.Header.Get("Content-Type") != "application/json" {
t.Errorf("Expected Content-Type application/json, got %s", r.Header.Get("Content-Type"))
}
if r.Header.Get("x-api-key") != "test-exa-key" {
t.Errorf("Expected x-api-key test-exa-key, got %s", r.Header.Get("x-api-key"))
if r.Header.Get("Authorization") != "Bearer test-glm-key" {
t.Errorf("Expected Authorization Bearer test-glm-key, got %s", r.Header.Get("Authorization"))
}
// Verify payload
body, _ := io.ReadAll(r.Body)
var payload map[string]any
json.Unmarshal(body, &payload)
if payload["query"] != "test query" {
t.Errorf("Expected query 'test query', got %v", payload["query"])
json.NewDecoder(r.Body).Decode(&payload)
if payload["search_query"] != "test query" {
t.Errorf("Expected search_query 'test query', got %v", payload["search_query"])
}
if payload["type"] != "neural" {
t.Errorf("Expected type 'neural', got %v", payload["type"])
if payload["search_engine"] != "search_std" {
t.Errorf("Expected search_engine 'search_std', got %v", payload["search_engine"])
}
response := map[string]any{
"results": []map[string]any{
{"title": "Exa Result 1", "url": "https://exa.ai/1", "text": "First result text"},
{"title": "Exa Result 2", "url": "https://exa.ai/2", "text": "Second result text"},
{"title": "Exa Result 3", "url": "https://exa.ai/3", "text": "Third result text"},
"id": "web-search-test",
"created": 1709568000,
"search_result": []map[string]any{
{
"title": "Test GLM Result",
"content": "GLM search snippet",
"link": "https://example.com/glm",
"media": "Example",
"publish_date": "2026-03-04",
},
},
}
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
json.NewEncoder(w).Encode(response)
}))
defer server.Close()
provider := &ExaSearchProvider{
apiKey: "test-exa-key",
client: &http.Client{},
}
// Temporarily override the API URL by using a custom transport
provider.client.Transport = rewriteHostTransport(server.URL)
result, err := provider.Search(context.Background(), "test query", 5)
if err != nil {
t.Fatalf("Search() error: %v", err)
}
if !strings.Contains(result, "via Exa") {
t.Errorf("Expected '(via Exa)' attribution, got: %s", result)
}
if !strings.Contains(result, "Exa Result 1") || !strings.Contains(result, "https://exa.ai/1") {
t.Errorf("Expected results in output, got: %s", result)
}
if !strings.Contains(result, "First result text") {
t.Errorf("Expected snippet text in output, got: %s", result)
}
}
func TestExaSearchProvider_EmptyResults(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
response := map[string]any{"results": []map[string]any{}}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(response)
}))
defer server.Close()
provider := &ExaSearchProvider{
apiKey: "test-key",
client: &http.Client{Transport: rewriteHostTransport(server.URL)},
}
result, err := provider.Search(context.Background(), "no results query", 5)
if err != nil {
t.Fatalf("Search() error: %v", err)
}
if !strings.Contains(result, "No results for: no results query") {
t.Errorf("Expected 'No results' message, got: %s", result)
}
}
func TestExaSearchProvider_MaxResultsCapping(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Return 5 results
results := make([]map[string]any, 5)
for i := range results {
results[i] = map[string]any{
"title": fmt.Sprintf("Result %d", i+1),
"url": fmt.Sprintf("https://exa.ai/%d", i+1),
"text": fmt.Sprintf("Text %d", i+1),
}
}
response := map[string]any{"results": results}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(response)
}))
defer server.Close()
provider := &ExaSearchProvider{
apiKey: "test-key",
client: &http.Client{Transport: rewriteHostTransport(server.URL)},
}
// Request only 2 results even though API returns 5
result, err := provider.Search(context.Background(), "test", 2)
if err != nil {
t.Fatalf("Search() error: %v", err)
}
if !strings.Contains(result, "Result 1") || !strings.Contains(result, "Result 2") {
t.Errorf("Expected first 2 results, got: %s", result)
}
if strings.Contains(result, "Result 3") {
t.Errorf("Expected results capped at 2, but got Result 3 in output: %s", result)
}
}
// rewriteHostTransport returns an http.RoundTripper that redirects all requests to the given target URL.
func rewriteHostTransport(target string) http.RoundTripper {
return roundTripFunc(func(req *http.Request) (*http.Response, error) {
newURL := target + req.URL.Path
newReq, err := http.NewRequestWithContext(req.Context(), req.Method, newURL, req.Body)
if err != nil {
return nil, err
}
newReq.Header = req.Header
return http.DefaultClient.Do(newReq)
tool, err := NewWebSearchTool(WebSearchToolOptions{
GLMSearchEnabled: true,
GLMSearchAPIKey: "test-glm-key",
GLMSearchBaseURL: server.URL,
GLMSearchEngine: "search_std",
})
if err != nil {
t.Fatalf("NewWebSearchTool() error: %v", err)
}
result := tool.Execute(context.Background(), map[string]any{
"query": "test query",
})
if result.IsError {
t.Errorf("Expected success, got IsError=true: %s", result.ForLLM)
}
if !strings.Contains(result.ForUser, "Test GLM Result") {
t.Errorf("Expected 'Test GLM Result' in output, got: %s", result.ForUser)
}
if !strings.Contains(result.ForUser, "https://example.com/glm") {
t.Errorf("Expected URL in output, got: %s", result.ForUser)
}
if !strings.Contains(result.ForUser, "via GLM Search") {
t.Errorf("Expected 'via GLM Search' in output, got: %s", result.ForUser)
}
}
type roundTripFunc func(*http.Request) (*http.Response, error)
func TestWebTool_GLMSearch_APIError(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(http.StatusUnauthorized)
w.Write([]byte(`{"error":"invalid api key"}`))
}))
defer server.Close()
func (f roundTripFunc) RoundTrip(req *http.Request) (*http.Response, error) {
return f(req)
tool, err := NewWebSearchTool(WebSearchToolOptions{
GLMSearchEnabled: true,
GLMSearchAPIKey: "bad-key",
GLMSearchBaseURL: server.URL,
GLMSearchEngine: "search_std",
})
if err != nil {
t.Fatalf("NewWebSearchTool() error: %v", err)
}
result := tool.Execute(context.Background(), map[string]any{
"query": "test query",
})
if !result.IsError {
t.Errorf("Expected IsError=true for 401 response")
}
if !strings.Contains(result.ForLLM, "status 401") {
t.Errorf("Expected status 401 in error, got: %s", result.ForLLM)
}
}
func TestWebTool_GLMSearch_Priority(t *testing.T) {
// GLM Search should only be selected when all other providers are disabled
tool, err := NewWebSearchTool(WebSearchToolOptions{
DuckDuckGoEnabled: true,
DuckDuckGoMaxResults: 5,
GLMSearchEnabled: true,
GLMSearchAPIKey: "test-key",
GLMSearchBaseURL: "https://example.com",
GLMSearchEngine: "search_std",
})
if err != nil {
t.Fatalf("NewWebSearchTool() error: %v", err)
}
// DuckDuckGo should win over GLM Search
if _, ok := tool.provider.(*DuckDuckGoSearchProvider); !ok {
t.Errorf("Expected DuckDuckGoSearchProvider when both enabled, got %T", tool.provider)
}
// With DuckDuckGo disabled, GLM Search should be selected
tool2, err := NewWebSearchTool(WebSearchToolOptions{
DuckDuckGoEnabled: false,
GLMSearchEnabled: true,
GLMSearchAPIKey: "test-key",
GLMSearchBaseURL: "https://example.com",
GLMSearchEngine: "search_std",
})
if err != nil {
t.Fatalf("NewWebSearchTool() error: %v", err)
}
if _, ok := tool2.provider.(*GLMSearchProvider); !ok {
t.Errorf("Expected GLMSearchProvider when only GLM enabled, got %T", tool2.provider)
}
}
+19 -4
View File
@@ -3,6 +3,7 @@ package utils
import (
"io"
"net/http"
"net/url"
"os"
"path/filepath"
"strings"
@@ -52,11 +53,12 @@ type DownloadOptions struct {
Timeout time.Duration
ExtraHeaders map[string]string
LoggerPrefix string
ProxyURL string
}
// DownloadFile downloads a file from URL to a local temp directory.
// Returns the local file path or empty string on error.
func DownloadFile(url, filename string, opts DownloadOptions) string {
func DownloadFile(urlStr, filename string, opts DownloadOptions) string {
// Set defaults
if opts.Timeout == 0 {
opts.Timeout = 60 * time.Second
@@ -78,7 +80,7 @@ func DownloadFile(url, filename string, opts DownloadOptions) string {
localPath := filepath.Join(mediaDir, uuid.New().String()[:8]+"_"+safeName)
// Create HTTP request
req, err := http.NewRequest("GET", url, nil)
req, err := http.NewRequest("GET", urlStr, nil)
if err != nil {
logger.ErrorCF(opts.LoggerPrefix, "Failed to create download request", map[string]any{
"error": err.Error(),
@@ -92,11 +94,24 @@ func DownloadFile(url, filename string, opts DownloadOptions) string {
}
client := &http.Client{Timeout: opts.Timeout}
if opts.ProxyURL != "" {
proxyURL, parseErr := url.Parse(opts.ProxyURL)
if parseErr != nil {
logger.ErrorCF(opts.LoggerPrefix, "Invalid proxy URL for download", map[string]any{
"error": parseErr.Error(),
"proxy": opts.ProxyURL,
})
return ""
}
client.Transport = &http.Transport{
Proxy: http.ProxyURL(proxyURL),
}
}
resp, err := client.Do(req)
if err != nil {
logger.ErrorCF(opts.LoggerPrefix, "Failed to download file", map[string]any{
"error": err.Error(),
"url": url,
"url": urlStr,
})
return ""
}
@@ -105,7 +120,7 @@ func DownloadFile(url, filename string, opts DownloadOptions) string {
if resp.StatusCode != http.StatusOK {
logger.ErrorCF(opts.LoggerPrefix, "File download returned non-200 status", map[string]any{
"status": resp.StatusCode,
"url": url,
"url": urlStr,
})
return ""
}