diff --git a/.env.example b/.env.example index 06d43070c..bc68456d6 100644 --- a/.env.example +++ b/.env.example @@ -17,4 +17,4 @@ # BRAVE_SEARCH_API_KEY=BSA... # ── Timezone ────────────────────────────── -TZ=Asia/Tokyo +TZ=Asia/Shanghai diff --git a/README.md b/README.md index c5b38e222..759ebbb82 100644 --- a/README.md +++ b/README.md @@ -54,7 +54,7 @@ ## 📢 News -2026-02-16 🎉 PicoClaw hit 12K stars in one week! Thank you all for your support! PicoClaw is growing faster than we ever imagined. Given the high volume of PRs, we urgently need community maintainers. Our volunteer roles and roadmap are officially posted [here](docs/ROADMAP.md) —we can’t wait to have you on board! +2026-02-16 🎉 PicoClaw hit 12K stars in one week! Thank you all for your support! PicoClaw is growing faster than we ever imagined. Given the high volume of PRs, we urgently need community maintainers. Our volunteer roles and roadmap are officially posted [here](ROADMAP.md) —we can’t wait to have you on board! 2026-02-13 🎉 PicoClaw hit 5000 stars in 4days! Thank you for the community! There are so many PRs & issues coming in (during Chinese New Year holidays), we are finalizing the Project Roadmap and setting up the Developer Group to accelerate PicoClaw's development. 🚀 Call to Action: Please submit your feature requests in GitHub Discussions. We will review and prioritize them during our upcoming weekly meeting. diff --git a/assets/wechat.png b/assets/wechat.png index 1c0b88295..32998c122 100644 Binary files a/assets/wechat.png and b/assets/wechat.png differ diff --git a/config/config.example.json b/config/config.example.json index 3c84cfa9f..f46f6a670 100644 --- a/config/config.example.json +++ b/config/config.example.json @@ -6,7 +6,9 @@ "model_name": "gpt4", "max_tokens": 8192, "temperature": 0.7, - "max_tool_iterations": 20 + "max_tool_iterations": 20, + "summarize_message_threshold": 20, + "summarize_token_percent": 75 } }, "model_list": [ @@ -59,6 +61,7 @@ "discord": { "enabled": false, "token": "YOUR_DISCORD_BOT_TOKEN", + "proxy": "", "allow_from": [], "group_trigger": { "mention_only": false @@ -337,4 +340,4 @@ "host": "127.0.0.1", "port": 18790 } -} \ No newline at end of file +} diff --git a/go.mod b/go.mod index a024c2023..c1172937c 100644 --- a/go.mod +++ b/go.mod @@ -11,7 +11,6 @@ require ( github.com/gdamore/tcell/v2 v2.13.8 github.com/google/uuid v1.6.0 github.com/gorilla/websocket v1.5.3 - github.com/joho/godotenv v1.5.1 github.com/larksuite/oapi-sdk-go/v3 v3.5.3 github.com/mdp/qrterminal/v3 v3.2.1 github.com/modelcontextprotocol/go-sdk v1.3.0 @@ -38,6 +37,7 @@ require ( github.com/dustin/go-humanize v1.0.1 // indirect github.com/elliotchance/orderedmap/v3 v3.1.0 // indirect github.com/gdamore/encoding v1.0.1 // indirect + github.com/gdamore/tcell/v2 v2.13.8 // indirect github.com/h2non/filetype v1.1.3 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect github.com/lucasb-eyer/go-colorful v1.3.0 // indirect diff --git a/go.sum b/go.sum index fc4892027..060594d06 100644 --- a/go.sum +++ b/go.sum @@ -105,8 +105,6 @@ github.com/hashicorp/golang-lru/v2 v2.0.7/go.mod h1:QeFd9opnmA6QUJc5vARoKUSoFhyf github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU= github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8= github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw= -github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0= -github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4= github.com/kisielk/errcheck v1.5.0/go.mod h1:pFxgyoBC7bSaBwPgfKdkLd5X25qrDl4LWUI2bnpBCr8= github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= github.com/klauspost/compress v1.18.4 h1:RPhnKRAQ4Fh8zU2FY/6ZFDwTVTxgJ/EMydqSTzE9a2c= diff --git a/pkg/agent/instance.go b/pkg/agent/instance.go index ed438059f..ed25f537f 100644 --- a/pkg/agent/instance.go +++ b/pkg/agent/instance.go @@ -18,22 +18,24 @@ import ( // AgentInstance represents a fully configured agent with its own workspace, // session manager, context builder, and tool registry. type AgentInstance struct { - ID string - Name string - Model string - Fallbacks []string - Workspace string - MaxIterations int - MaxTokens int - Temperature float64 - ContextWindow int - Provider providers.LLMProvider - Sessions *session.SessionManager - ContextBuilder *ContextBuilder - Tools *tools.ToolRegistry - Subagents *config.SubagentsConfig - SkillsFilter []string - Candidates []providers.FallbackCandidate + ID string + Name string + Model string + Fallbacks []string + Workspace string + MaxIterations int + MaxTokens int + Temperature float64 + ContextWindow int + SummarizeMessageThreshold int + SummarizeTokenPercent int + Provider providers.LLMProvider + Sessions *session.SessionManager + ContextBuilder *ContextBuilder + Tools *tools.ToolRegistry + Subagents *config.SubagentsConfig + SkillsFilter []string + Candidates []providers.FallbackCandidate } // NewAgentInstance creates an agent instance from config. @@ -101,6 +103,16 @@ func NewAgentInstance( temperature = *defaults.Temperature } + summarizeMessageThreshold := defaults.SummarizeMessageThreshold + if summarizeMessageThreshold == 0 { + summarizeMessageThreshold = 20 + } + + summarizeTokenPercent := defaults.SummarizeTokenPercent + if summarizeTokenPercent == 0 { + summarizeTokenPercent = 75 + } + // Resolve fallback candidates modelCfg := providers.ModelConfig{ Primary: model, @@ -149,22 +161,24 @@ func NewAgentInstance( candidates := providers.ResolveCandidatesWithLookup(modelCfg, defaults.Provider, resolveFromModelList) return &AgentInstance{ - ID: agentID, - Name: agentName, - Model: model, - Fallbacks: fallbacks, - Workspace: workspace, - MaxIterations: maxIter, - MaxTokens: maxTokens, - Temperature: temperature, - ContextWindow: maxTokens, - Provider: provider, - Sessions: sessionsManager, - ContextBuilder: contextBuilder, - Tools: toolsRegistry, - Subagents: subagents, - SkillsFilter: skillsFilter, - Candidates: candidates, + ID: agentID, + Name: agentName, + Model: model, + Fallbacks: fallbacks, + Workspace: workspace, + MaxIterations: maxIter, + MaxTokens: maxTokens, + Temperature: temperature, + ContextWindow: maxTokens, + SummarizeMessageThreshold: summarizeMessageThreshold, + SummarizeTokenPercent: summarizeTokenPercent, + Provider: provider, + Sessions: sessionsManager, + ContextBuilder: contextBuilder, + Tools: toolsRegistry, + Subagents: subagents, + SkillsFilter: skillsFilter, + Candidates: candidates, } } diff --git a/pkg/agent/loop.go b/pkg/agent/loop.go index 89d42069d..db9efa2cf 100644 --- a/pkg/agent/loop.go +++ b/pkg/agent/loop.go @@ -118,9 +118,11 @@ func registerSharedTools( PerplexityAPIKey: cfg.Tools.Web.Perplexity.APIKey, PerplexityMaxResults: cfg.Tools.Web.Perplexity.MaxResults, PerplexityEnabled: cfg.Tools.Web.Perplexity.Enabled, - ExaAPIKey: cfg.Tools.Web.Exa.APIKey, - ExaMaxResults: cfg.Tools.Web.Exa.MaxResults, - ExaEnabled: cfg.Tools.Web.Exa.Enabled, + GLMSearchAPIKey: cfg.Tools.Web.GLMSearch.APIKey, + GLMSearchBaseURL: cfg.Tools.Web.GLMSearch.BaseURL, + GLMSearchEngine: cfg.Tools.Web.GLMSearch.SearchEngine, + GLMSearchMaxResults: cfg.Tools.Web.GLMSearch.MaxResults, + GLMSearchEnabled: cfg.Tools.Web.GLMSearch.Enabled, Proxy: cfg.Tools.Web.Proxy, }) if err != nil { @@ -967,62 +969,76 @@ func (al *AgentLoop) runLLMIteration( // Save assistant message with tool calls to session agent.Sessions.AddFullMessage(opts.SessionKey, assistantMsg) - // Execute tool calls - for _, tc := range normalizedToolCalls { - argsJSON, _ := json.Marshal(tc.Arguments) - argsPreview := utils.Truncate(string(argsJSON), 200) - logger.InfoCF("agent", fmt.Sprintf("Tool call: %s(%s)", tc.Name, argsPreview), - map[string]any{ - "agent_id": agent.ID, - "tool": tc.Name, - "iteration": iteration, - }) + // Execute tool calls in parallel + type indexedAgentResult struct { + result *tools.ToolResult + tc providers.ToolCall + } - // Create async callback for tools that implement AsyncTool - // NOTE: Following openclaw's design, async tools do NOT send results directly to users. - // Instead, they notify the agent via PublishInbound, and the agent decides - // whether to forward the result to the user (in processSystemMessage). - asyncCallback := func(callbackCtx context.Context, result *tools.ToolResult) { - // Log the async completion but don't send directly to user - // The agent will handle user notification via processSystemMessage - if !result.Silent && result.ForUser != "" { - logger.InfoCF("agent", "Async tool completed, agent will handle notification", - map[string]any{ - "tool": tc.Name, - "content_len": len(result.ForUser), - }) + agentResults := make([]indexedAgentResult, len(normalizedToolCalls)) + var wg sync.WaitGroup + + for i, tc := range normalizedToolCalls { + agentResults[i].tc = tc + + wg.Add(1) + go func(idx int, tc providers.ToolCall) { + defer wg.Done() + + argsJSON, _ := json.Marshal(tc.Arguments) + argsPreview := utils.Truncate(string(argsJSON), 200) + logger.InfoCF("agent", fmt.Sprintf("Tool call: %s(%s)", tc.Name, argsPreview), + map[string]any{ + "agent_id": agent.ID, + "tool": tc.Name, + "iteration": iteration, + }) + + // Create async callback for tools that implement AsyncTool + asyncCallback := func(callbackCtx context.Context, result *tools.ToolResult) { + if !result.Silent && result.ForUser != "" { + logger.InfoCF("agent", "Async tool completed, agent will handle notification", + map[string]any{ + "tool": tc.Name, + "content_len": len(result.ForUser), + }) + } } - } - toolResult := agent.Tools.ExecuteWithContext( - ctx, - tc.Name, - tc.Arguments, - opts.Channel, - opts.ChatID, - asyncCallback, - ) + toolResult := agent.Tools.ExecuteWithContext( + ctx, + tc.Name, + tc.Arguments, + opts.Channel, + opts.ChatID, + asyncCallback, + ) + agentResults[idx].result = toolResult + }(i, tc) + } + wg.Wait() + // Process results in original order (send to user, save to session) + for _, r := range agentResults { // Send ForUser content to user immediately if not Silent - if !toolResult.Silent && toolResult.ForUser != "" && opts.SendResponse { + if !r.result.Silent && r.result.ForUser != "" && opts.SendResponse { al.bus.PublishOutbound(ctx, bus.OutboundMessage{ Channel: opts.Channel, ChatID: opts.ChatID, - Content: toolResult.ForUser, + Content: r.result.ForUser, }) logger.DebugCF("agent", "Sent tool result to user", map[string]any{ - "tool": tc.Name, - "content_len": len(toolResult.ForUser), + "tool": r.tc.Name, + "content_len": len(r.result.ForUser), }) } // If tool returned media refs, publish them as outbound media - if len(toolResult.Media) > 0 && opts.SendResponse { - parts := make([]bus.MediaPart, 0, len(toolResult.Media)) - for _, ref := range toolResult.Media { + if len(r.result.Media) > 0 && opts.SendResponse { + parts := make([]bus.MediaPart, 0, len(r.result.Media)) + for _, ref := range r.result.Media { part := bus.MediaPart{Ref: ref} - // Populate metadata from MediaStore when available if al.mediaStore != nil { if _, meta, err := al.mediaStore.ResolveWithMeta(ref); err == nil { part.Filename = meta.Filename @@ -1040,15 +1056,15 @@ func (al *AgentLoop) runLLMIteration( } // Determine content for LLM based on tool result - contentForLLM := toolResult.ForLLM - if contentForLLM == "" && toolResult.Err != nil { - contentForLLM = toolResult.Err.Error() + contentForLLM := r.result.ForLLM + if contentForLLM == "" && r.result.Err != nil { + contentForLLM = r.result.Err.Error() } toolResultMsg := providers.Message{ Role: "tool", Content: contentForLLM, - ToolCallID: tc.ID, + ToolCallID: r.tc.ID, } messages = append(messages, toolResultMsg) @@ -1084,9 +1100,9 @@ func (al *AgentLoop) updateToolContexts(agent *AgentInstance, channel, chatID st func (al *AgentLoop) maybeSummarize(agent *AgentInstance, sessionKey, channel, chatID string) { newHistory := agent.Sessions.GetHistory(sessionKey) tokenEstimate := al.estimateTokens(newHistory) - threshold := agent.ContextWindow * 75 / 100 + threshold := agent.ContextWindow * agent.SummarizeTokenPercent / 100 - if len(newHistory) > 20 || tokenEstimate > threshold { + if len(newHistory) > agent.SummarizeMessageThreshold || tokenEstimate > threshold { summarizeKey := agent.ID + ":" + sessionKey if _, loading := al.summarizing.LoadOrStore(summarizeKey, true); !loading { go func() { @@ -1114,15 +1130,8 @@ func (al *AgentLoop) forceCompression(agent *AgentInstance, sessionKey string) { return } - // Find the mid-point of the conversation, avoiding splitting tool call/result pairs. - // A tool-call message (role=assistant with ToolCalls) must be followed by its - // tool-result message (role=tool). Splitting between them causes API errors. + // Helper to find the mid-point of the conversation mid := len(conversation) / 2 - if mid < len(conversation) && mid > 0 { - if conversation[mid].Role == "tool" { - mid++ // move past the tool result to keep the pair together - } - } // New history structure: // 1. System Prompt (with compression note appended) diff --git a/pkg/agent/loop_test.go b/pkg/agent/loop_test.go index 55098fa61..023286f02 100644 --- a/pkg/agent/loop_test.go +++ b/pkg/agent/loop_test.go @@ -603,85 +603,6 @@ func TestAgentLoop_ContextExhaustionRetry(t *testing.T) { } } -// TestForceCompression_ToolMessageBoundary verifies that forceCompression does not -// split a tool call/result pair when the midpoint falls on a "tool" role message. -// Regression test for: API errors when orphaned tool result messages appear -// without their preceding assistant tool-call message. -func TestForceCompression_ToolMessageBoundary(t *testing.T) { - tmpDir, err := os.MkdirTemp("", "agent-test-*") - if err != nil { - t.Fatalf("Failed to create temp dir: %v", err) - } - defer os.RemoveAll(tmpDir) - - cfg := &config.Config{ - Agents: config.AgentsConfig{ - Defaults: config.AgentDefaults{ - Workspace: tmpDir, - Model: "test-model", - MaxTokens: 4096, - MaxToolIterations: 10, - }, - }, - } - - msgBus := bus.NewMessageBus() - provider := &mockProvider{} - al := NewAgentLoop(cfg, msgBus, provider) - - sessionKey := "test-session-tool-boundary" - defaultAgent := al.registry.GetDefaultAgent() - if defaultAgent == nil { - t.Fatal("No default agent found") - } - - // Construct a history where len(conversation)/2 falls exactly on a "tool" message. - // history = [system, user, assistant(tool_call), tool, user, assistant, user_trigger] - // conversation = history[1:6] = [user, assistant(tool_call), tool, user, assistant] - // len(conversation) = 5, mid = 5/2 = 2 => conversation[2].Role == "tool" - // Without the fix, this would split between assistant(tool_call) and tool result. - history := []providers.Message{ - {Role: "system", Content: "You are a helpful assistant."}, - {Role: "user", Content: "What files are in the current directory?"}, - {Role: "assistant", Content: "", ToolCalls: []providers.ToolCall{ - {ID: "call_1", Name: "exec", Arguments: map[string]any{"command": "ls"}}, - }}, - {Role: "tool", Content: "file1.txt\nfile2.txt", ToolCallID: "call_1"}, - {Role: "user", Content: "Tell me about file1.txt"}, - {Role: "assistant", Content: "file1.txt is a text file."}, - {Role: "user", Content: "Thanks"}, // trigger message - } - - // Create the session first (AddMessage creates the session entry), - // then overwrite with our full history via SetHistory. - defaultAgent.Sessions.AddMessage(sessionKey, "system", "init") - defaultAgent.Sessions.SetHistory(sessionKey, history) - - // Call forceCompression - al.forceCompression(defaultAgent, sessionKey) - - // Verify the result - compressed := defaultAgent.Sessions.GetHistory(sessionKey) - - // Check that no message with role="tool" is the first conversation message - // (after the system prompt). If it is, it means the tool result was orphaned. - for i := 1; i < len(compressed); i++ { - if compressed[i].Role == "tool" { - // There must be an assistant message with tool calls before it - if i == 1 { - t.Errorf("Tool result message at position %d is orphaned (no preceding assistant with tool call)", i) - } else if compressed[i-1].Role != "assistant" || len(compressed[i-1].ToolCalls) == 0 { - t.Errorf("Tool result at position %d is not preceded by assistant with tool calls (preceded by role=%q)", i, compressed[i-1].Role) - } - } - } - - // Verify the system prompt has the compression note - if !strings.Contains(compressed[0].Content, "Emergency compression") { - t.Errorf("Expected compression note in system prompt, got: %s", compressed[0].Content) - } -} - func TestTargetReasoningChannelID_AllChannels(t *testing.T) { tmpDir, err := os.MkdirTemp("", "agent-test-*") if err != nil { diff --git a/pkg/channels/discord/discord.go b/pkg/channels/discord/discord.go index cd6a2560f..1de910c83 100644 --- a/pkg/channels/discord/discord.go +++ b/pkg/channels/discord/discord.go @@ -3,12 +3,15 @@ package discord import ( "context" "fmt" + "net/http" + "net/url" "os" "strings" "sync" "time" "github.com/bwmarrin/discordgo" + "github.com/gorilla/websocket" "github.com/sipeed/picoclaw/pkg/bus" "github.com/sipeed/picoclaw/pkg/channels" @@ -40,6 +43,9 @@ func NewDiscordChannel(cfg config.DiscordConfig, bus *bus.MessageBus) (*DiscordC return nil, fmt.Errorf("failed to create discord session: %w", err) } + if err := applyDiscordProxy(session, cfg.Proxy); err != nil { + return nil, err + } base := channels.NewBaseChannel("discord", cfg, bus, cfg.AllowFrom, channels.WithMaxMessageLength(2000), channels.WithGroupTrigger(cfg.GroupTrigger), @@ -465,9 +471,43 @@ func (c *DiscordChannel) StartTyping(ctx context.Context, chatID string) (func() func (c *DiscordChannel) downloadAttachment(url, filename string) string { return utils.DownloadFile(url, filename, utils.DownloadOptions{ LoggerPrefix: "discord", + ProxyURL: c.config.Proxy, }) } +func applyDiscordProxy(session *discordgo.Session, proxyAddr string) error { + var proxyFunc func(*http.Request) (*url.URL, error) + if proxyAddr != "" { + proxyURL, err := url.Parse(proxyAddr) + if err != nil { + return fmt.Errorf("invalid discord proxy URL %q: %w", proxyAddr, err) + } + proxyFunc = http.ProxyURL(proxyURL) + } else if os.Getenv("HTTP_PROXY") != "" || os.Getenv("HTTPS_PROXY") != "" { + proxyFunc = http.ProxyFromEnvironment + } + + if proxyFunc == nil { + return nil + } + + transport := &http.Transport{Proxy: proxyFunc} + session.Client = &http.Client{ + Timeout: sendTimeout, + Transport: transport, + } + + if session.Dialer != nil { + dialerCopy := *session.Dialer + dialerCopy.Proxy = proxyFunc + session.Dialer = &dialerCopy + } else { + session.Dialer = &websocket.Dialer{Proxy: proxyFunc} + } + + return nil +} + // stripBotMention removes the bot mention from the message content. // Discord mentions have the format <@USER_ID> or <@!USER_ID> (with nickname). func (c *DiscordChannel) stripBotMention(text string) string { diff --git a/pkg/channels/discord/discord_test.go b/pkg/channels/discord/discord_test.go new file mode 100644 index 000000000..0cd5328f4 --- /dev/null +++ b/pkg/channels/discord/discord_test.go @@ -0,0 +1,91 @@ +package discord + +import ( + "net/http" + "net/url" + "testing" + + "github.com/bwmarrin/discordgo" +) + +func TestApplyDiscordProxy_CustomProxy(t *testing.T) { + session, err := discordgo.New("Bot test-token") + if err != nil { + t.Fatalf("discordgo.New() error: %v", err) + } + + if err = applyDiscordProxy(session, "http://127.0.0.1:7890"); err != nil { + t.Fatalf("applyDiscordProxy() error: %v", err) + } + + req, err := http.NewRequest("GET", "https://discord.com/api/v10/gateway", nil) + if err != nil { + t.Fatalf("http.NewRequest() error: %v", err) + } + + restProxy := session.Client.Transport.(*http.Transport).Proxy + restProxyURL, err := restProxy(req) + if err != nil { + t.Fatalf("rest proxy func error: %v", err) + } + if got, want := restProxyURL.String(), "http://127.0.0.1:7890"; got != want { + t.Fatalf("REST proxy = %q, want %q", got, want) + } + + wsProxyURL, err := session.Dialer.Proxy(req) + if err != nil { + t.Fatalf("ws proxy func error: %v", err) + } + if got, want := wsProxyURL.String(), "http://127.0.0.1:7890"; got != want { + t.Fatalf("WS proxy = %q, want %q", got, want) + } +} + +func TestApplyDiscordProxy_FromEnvironment(t *testing.T) { + t.Setenv("HTTP_PROXY", "http://127.0.0.1:8888") + t.Setenv("http_proxy", "http://127.0.0.1:8888") + t.Setenv("HTTPS_PROXY", "http://127.0.0.1:8888") + t.Setenv("https_proxy", "http://127.0.0.1:8888") + t.Setenv("ALL_PROXY", "") + t.Setenv("all_proxy", "") + t.Setenv("NO_PROXY", "") + t.Setenv("no_proxy", "") + + session, err := discordgo.New("Bot test-token") + if err != nil { + t.Fatalf("discordgo.New() error: %v", err) + } + + if err = applyDiscordProxy(session, ""); err != nil { + t.Fatalf("applyDiscordProxy() error: %v", err) + } + + req, err := http.NewRequest("GET", "https://discord.com/api/v10/gateway", nil) + if err != nil { + t.Fatalf("http.NewRequest() error: %v", err) + } + + gotURL, err := session.Dialer.Proxy(req) + if err != nil { + t.Fatalf("ws proxy func error: %v", err) + } + + wantURL, err := url.Parse("http://127.0.0.1:8888") + if err != nil { + t.Fatalf("url.Parse() error: %v", err) + } + if gotURL.String() != wantURL.String() { + t.Fatalf("WS proxy = %q, want %q", gotURL.String(), wantURL.String()) + } +} + +func TestApplyDiscordProxy_InvalidProxyURL(t *testing.T) { + session, err := discordgo.New("Bot test-token") + if err != nil { + t.Fatalf("discordgo.New() error: %v", err) + } + + if err = applyDiscordProxy(session, "://bad-proxy"); err == nil { + t.Fatal("applyDiscordProxy() expected error for invalid proxy URL, got nil") + } +} diff --git a/pkg/config/config.go b/pkg/config/config.go index d892b64a5..f40e05e1c 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -3,23 +3,14 @@ package config import ( "encoding/json" "fmt" - "log" "os" - "path/filepath" - "sync" "sync/atomic" "github.com/caarlos0/env/v11" - "github.com/joho/godotenv" "github.com/sipeed/picoclaw/pkg/fileutil" ) -// dotenvOnce ensures .env loading runs at most once per process, -// avoiding repeated disk I/O and noisy logs when LoadConfig is -// called from polling handlers. -var dotenvOnce sync.Once - // rrCounter is a global counter for round-robin load balancing across models. var rrCounter atomic.Uint64 @@ -189,6 +180,8 @@ type AgentDefaults struct { MaxTokens int `json:"max_tokens" env:"PICOCLAW_AGENTS_DEFAULTS_MAX_TOKENS"` Temperature *float64 `json:"temperature,omitempty" env:"PICOCLAW_AGENTS_DEFAULTS_TEMPERATURE"` MaxToolIterations int `json:"max_tool_iterations" env:"PICOCLAW_AGENTS_DEFAULTS_MAX_TOOL_ITERATIONS"` + SummarizeMessageThreshold int `json:"summarize_message_threshold" env:"PICOCLAW_AGENTS_DEFAULTS_SUMMARIZE_MESSAGE_THRESHOLD"` + SummarizeTokenPercent int `json:"summarize_token_percent" env:"PICOCLAW_AGENTS_DEFAULTS_SUMMARIZE_TOKEN_PERCENT"` MaxMediaSize int `json:"max_media_size,omitempty" env:"PICOCLAW_AGENTS_DEFAULTS_MAX_MEDIA_SIZE"` } @@ -280,6 +273,7 @@ type FeishuConfig struct { type DiscordConfig struct { Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_DISCORD_ENABLED"` Token string `json:"token" env:"PICOCLAW_CHANNELS_DISCORD_TOKEN"` + Proxy string `json:"proxy" env:"PICOCLAW_CHANNELS_DISCORD_PROXY"` AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_DISCORD_ALLOW_FROM"` MentionOnly bool `json:"mention_only" env:"PICOCLAW_CHANNELS_DISCORD_MENTION_ONLY"` GroupTrigger GroupTriggerConfig `json:"group_trigger,omitempty"` @@ -437,7 +431,6 @@ type ProvidersConfig struct { Antigravity ProviderConfig `json:"antigravity"` Qwen ProviderConfig `json:"qwen"` Mistral ProviderConfig `json:"mistral"` - Opencode ProviderConfig `json:"opencode"` } // IsEmpty checks if all provider configs are empty (no API keys or API bases set) @@ -461,8 +454,7 @@ func (p ProvidersConfig) IsEmpty() bool { p.GitHubCopilot.APIKey == "" && p.GitHubCopilot.APIBase == "" && p.Antigravity.APIKey == "" && p.Antigravity.APIBase == "" && p.Qwen.APIKey == "" && p.Qwen.APIBase == "" && - p.Mistral.APIKey == "" && p.Mistral.APIBase == "" && - p.Opencode.APIKey == "" && p.Opencode.APIBase == "" + p.Mistral.APIKey == "" && p.Mistral.APIBase == "" } // MarshalJSON implements custom JSON marshaling for ProvidersConfig @@ -555,10 +547,14 @@ type PerplexityConfig struct { MaxResults int `json:"max_results" env:"PICOCLAW_TOOLS_WEB_PERPLEXITY_MAX_RESULTS"` } -type ExaConfig struct { - Enabled bool `json:"enabled" env:"PICOCLAW_TOOLS_WEB_EXA_ENABLED"` - APIKey string `json:"api_key" env:"PICOCLAW_TOOLS_WEB_EXA_API_KEY"` - MaxResults int `json:"max_results" env:"PICOCLAW_TOOLS_WEB_EXA_MAX_RESULTS"` +type GLMSearchConfig struct { + Enabled bool `json:"enabled" env:"PICOCLAW_TOOLS_WEB_GLM_ENABLED"` + APIKey string `json:"api_key" env:"PICOCLAW_TOOLS_WEB_GLM_API_KEY"` + BaseURL string `json:"base_url" env:"PICOCLAW_TOOLS_WEB_GLM_BASE_URL"` + // SearchEngine specifies the search backend: "search_std" (default), + // "search_pro", "search_pro_sogou", or "search_pro_quark". + SearchEngine string `json:"search_engine" env:"PICOCLAW_TOOLS_WEB_GLM_SEARCH_ENGINE"` + MaxResults int `json:"max_results" env:"PICOCLAW_TOOLS_WEB_GLM_MAX_RESULTS"` } type WebToolsConfig struct { @@ -566,7 +562,7 @@ type WebToolsConfig struct { Tavily TavilyConfig `json:"tavily"` DuckDuckGo DuckDuckGoConfig `json:"duckduckgo"` Perplexity PerplexityConfig `json:"perplexity"` - Exa ExaConfig `json:"exa"` + GLMSearch GLMSearchConfig `json:"glm_search"` // Proxy is an optional proxy URL for web tools (http/https/socks5/socks5h). // For authenticated proxies, prefer HTTP_PROXY/HTTPS_PROXY env vars instead of embedding credentials in config. Proxy string `json:"proxy,omitempty" env:"PICOCLAW_TOOLS_WEB_PROXY"` @@ -658,35 +654,9 @@ type MCPConfig struct { func LoadConfig(path string) (*Config, error) { cfg := DefaultConfig() - // Load .env file from config directory (secrets, API keys, etc.) - // Guarded by sync.Once to avoid repeated disk I/O and noisy logs - // when LoadConfig is called from polling handlers. - dotenvOnce.Do(func() { - envFile := filepath.Join(filepath.Dir(path), ".env") - if err := godotenv.Load(envFile); err != nil { - if os.IsNotExist(err) { - log.Printf("[INFO] No .env file found at %s; skipping .env loading", envFile) - } else { - log.Printf("[WARN] Failed to load .env file from %s: %v", envFile, err) - } - } - }) - data, err := os.ReadFile(path) if err != nil { if os.IsNotExist(err) { - // No config file — still apply env vars + overrides to default config - if err := env.Parse(cfg); err != nil { - return nil, err - } - loadProviderEnvOverrides(cfg) - cfg.migrateChannelConfigs() - if cfg.HasProvidersConfig() { - cfg.ModelList = ConvertProvidersToModelList(cfg) - } - if err := cfg.ValidateModelList(); err != nil { - return nil, err - } return cfg, nil } return nil, err @@ -714,9 +684,6 @@ func LoadConfig(path string) (*Config, error) { return nil, err } - // Load provider-specific env overrides (PICOCLAW_PROVIDERS__API_KEY, etc.) - loadProviderEnvOverrides(cfg) - // Migrate legacy channel config fields to new unified structures cfg.migrateChannelConfigs() @@ -865,42 +832,3 @@ func (c *Config) ValidateModelList() error { } return nil } - -// loadProviderEnvOverrides reads PICOCLAW_PROVIDERS__API_KEY and _API_BASE -// environment variables and sets them on the corresponding provider config fields. -// This enables storing provider secrets in .env files without using struct tags. -func loadProviderEnvOverrides(cfg *Config) { - providers := []struct { - name string - apiKey *string - base *string - }{ - {"ANTHROPIC", &cfg.Providers.Anthropic.APIKey, &cfg.Providers.Anthropic.APIBase}, - {"OPENAI", &cfg.Providers.OpenAI.APIKey, &cfg.Providers.OpenAI.APIBase}, - {"LITELLM", &cfg.Providers.LiteLLM.APIKey, &cfg.Providers.LiteLLM.APIBase}, - {"OPENROUTER", &cfg.Providers.OpenRouter.APIKey, &cfg.Providers.OpenRouter.APIBase}, - {"GROQ", &cfg.Providers.Groq.APIKey, &cfg.Providers.Groq.APIBase}, - {"ZHIPU", &cfg.Providers.Zhipu.APIKey, &cfg.Providers.Zhipu.APIBase}, - {"GEMINI", &cfg.Providers.Gemini.APIKey, &cfg.Providers.Gemini.APIBase}, - {"NVIDIA", &cfg.Providers.Nvidia.APIKey, &cfg.Providers.Nvidia.APIBase}, - {"OLLAMA", &cfg.Providers.Ollama.APIKey, &cfg.Providers.Ollama.APIBase}, - {"MOONSHOT", &cfg.Providers.Moonshot.APIKey, &cfg.Providers.Moonshot.APIBase}, - {"SHENGSUANYUN", &cfg.Providers.ShengSuanYun.APIKey, &cfg.Providers.ShengSuanYun.APIBase}, - {"DEEPSEEK", &cfg.Providers.DeepSeek.APIKey, &cfg.Providers.DeepSeek.APIBase}, - {"MISTRAL", &cfg.Providers.Mistral.APIKey, &cfg.Providers.Mistral.APIBase}, - {"VLLM", &cfg.Providers.VLLM.APIKey, &cfg.Providers.VLLM.APIBase}, - {"CEREBRAS", &cfg.Providers.Cerebras.APIKey, &cfg.Providers.Cerebras.APIBase}, - {"VOLCENGINE", &cfg.Providers.VolcEngine.APIKey, &cfg.Providers.VolcEngine.APIBase}, - {"QWEN", &cfg.Providers.Qwen.APIKey, &cfg.Providers.Qwen.APIBase}, - // Note: GitHubCopilot and Antigravity use different auth patterns (ConnectMode/AuthMethod), - // not standard APIKey/APIBase, so they are not included here. - } - for _, p := range providers { - if v, ok := os.LookupEnv("PICOCLAW_PROVIDERS_" + p.name + "_API_KEY"); ok { - *p.apiKey = v - } - if v, ok := os.LookupEnv("PICOCLAW_PROVIDERS_" + p.name + "_API_BASE"); ok { - *p.base = v - } - } -} diff --git a/pkg/config/config_test.go b/pkg/config/config_test.go index fb11799d4..10ebc7c90 100644 --- a/pkg/config/config_test.go +++ b/pkg/config/config_test.go @@ -6,7 +6,6 @@ import ( "path/filepath" "runtime" "strings" - "sync" "testing" ) @@ -436,6 +435,18 @@ func TestLoadConfig_WebToolsProxy(t *testing.T) { } // TestDefaultConfig_DMScope verifies the default dm_scope value +// TestDefaultConfig_SummarizationThresholds verifies summarization defaults +func TestDefaultConfig_SummarizationThresholds(t *testing.T) { + cfg := DefaultConfig() + + if cfg.Agents.Defaults.SummarizeMessageThreshold != 20 { + t.Errorf("SummarizeMessageThreshold = %d, want 20", cfg.Agents.Defaults.SummarizeMessageThreshold) + } + if cfg.Agents.Defaults.SummarizeTokenPercent != 75 { + t.Errorf("SummarizeTokenPercent = %d, want 75", cfg.Agents.Defaults.SummarizeTokenPercent) + } +} + func TestDefaultConfig_DMScope(t *testing.T) { cfg := DefaultConfig() @@ -468,98 +479,3 @@ func TestDefaultConfig_WorkspacePath_WithPicoclawHome(t *testing.T) { t.Errorf("Workspace path with PICOCLAW_HOME = %q, want %q", cfg.Agents.Defaults.Workspace, want) } } - -func TestLoadConfig_DotenvFileLoaded(t *testing.T) { - // Reset sync.Once so .env loading runs for this test - dotenvOnce = sync.Once{} - - dir := t.TempDir() - configPath := filepath.Join(dir, "config.json") - - // Write a minimal config.json - if err := os.WriteFile(configPath, []byte(`{}`), 0o600); err != nil { - t.Fatalf("WriteFile config: %v", err) - } - - // Write a .env file with a provider API key - envFile := filepath.Join(dir, ".env") - if err := os.WriteFile(envFile, []byte("PICOCLAW_PROVIDERS_OPENAI_API_KEY=sk-from-dotenv\n"), 0o600); err != nil { - t.Fatalf("WriteFile .env: %v", err) - } - - // Clear the env var first to ensure it comes from .env - t.Setenv("PICOCLAW_PROVIDERS_OPENAI_API_KEY", "") - os.Unsetenv("PICOCLAW_PROVIDERS_OPENAI_API_KEY") - - cfg, err := LoadConfig(configPath) - if err != nil { - t.Fatalf("LoadConfig() error: %v", err) - } - - if cfg.Providers.OpenAI.APIKey != "sk-from-dotenv" { - t.Errorf("OpenAI.APIKey = %q, want %q", cfg.Providers.OpenAI.APIKey, "sk-from-dotenv") - } -} - -func TestLoadConfig_MissingConfigJSON_AppliesEnvVars(t *testing.T) { - // Reset sync.Once so .env loading runs for this test - dotenvOnce = sync.Once{} - - dir := t.TempDir() - configPath := filepath.Join(dir, "config.json") // does NOT exist - - t.Setenv("PICOCLAW_PROVIDERS_ANTHROPIC_API_KEY", "sk-anthropic-test") - - cfg, err := LoadConfig(configPath) - if err != nil { - t.Fatalf("LoadConfig() error: %v", err) - } - - if cfg.Providers.Anthropic.APIKey != "sk-anthropic-test" { - t.Errorf("Anthropic.APIKey = %q, want %q", cfg.Providers.Anthropic.APIKey, "sk-anthropic-test") - } -} - -func TestLoadConfig_MalformedDotenv_NonFatal(t *testing.T) { - // Reset sync.Once so .env loading runs for this test - dotenvOnce = sync.Once{} - - dir := t.TempDir() - configPath := filepath.Join(dir, "config.json") - - // Write a minimal config.json - if err := os.WriteFile(configPath, []byte(`{}`), 0o600); err != nil { - t.Fatalf("WriteFile config: %v", err) - } - - // Write a .env file with genuinely malformed content (bare key without '=', - // mixed with a valid line) to verify godotenv.Load errors are non-fatal. - envFile := filepath.Join(dir, ".env") - if err := os.WriteFile(envFile, []byte("THIS_LINE_HAS_NO_EQUALS\nVALID_KEY=valid_value\n"), 0o600); err != nil { - t.Fatalf("WriteFile .env: %v", err) - } - - // LoadConfig should not fail even with malformed .env content - cfg, err := LoadConfig(configPath) - if err != nil { - t.Fatalf("LoadConfig() should not fail with .env issues, got error: %v", err) - } - if cfg == nil { - t.Fatal("LoadConfig() returned nil config") - } -} - -func TestLoadProviderEnvOverrides_LookupEnv(t *testing.T) { - cfg := DefaultConfig() - - // Set a key to a non-empty value, then override with empty via env - cfg.Providers.OpenRouter.APIBase = "https://original.com" - t.Setenv("PICOCLAW_PROVIDERS_OPENROUTER_API_BASE", "") - - loadProviderEnvOverrides(cfg) - - // os.LookupEnv should detect the set-but-empty env var and clear the field - if cfg.Providers.OpenRouter.APIBase != "" { - t.Errorf("OpenRouter.APIBase = %q, want empty (overridden by empty env var)", cfg.Providers.OpenRouter.APIBase) - } -} diff --git a/pkg/config/defaults.go b/pkg/config/defaults.go index 7f6dd3ca5..6f65dd469 100644 --- a/pkg/config/defaults.go +++ b/pkg/config/defaults.go @@ -26,13 +26,15 @@ func DefaultConfig() *Config { return &Config{ Agents: AgentsConfig{ Defaults: AgentDefaults{ - Workspace: workspacePath, - RestrictToWorkspace: true, - Provider: "", - Model: "", - MaxTokens: 32768, - Temperature: nil, // nil means use provider default - MaxToolIterations: 50, + Workspace: workspacePath, + RestrictToWorkspace: true, + Provider: "", + Model: "", + MaxTokens: 32768, + Temperature: nil, // nil means use provider default + MaxToolIterations: 50, + SummarizeMessageThreshold: 20, + SummarizeTokenPercent: 75, }, }, Bindings: []AgentBinding{}, @@ -341,10 +343,12 @@ func DefaultConfig() *Config { APIKey: "", MaxResults: 5, }, - Exa: ExaConfig{ - Enabled: false, - APIKey: "", - MaxResults: 5, + GLMSearch: GLMSearchConfig{ + Enabled: false, + APIKey: "", + BaseURL: "https://open.bigmodel.cn/api/paas/v4/web_search", + SearchEngine: "search_std", + MaxResults: 5, }, }, Cron: CronToolsConfig{ diff --git a/pkg/config/migration.go b/pkg/config/migration.go index b7ca6dd85..772f714fd 100644 --- a/pkg/config/migration.go +++ b/pkg/config/migration.go @@ -225,7 +225,7 @@ func ConvertProvidersToModelList(cfg *Config) []ModelConfig { }, }, { - providerNames: []string{"moonshot", "kimi", "kimi-code"}, + providerNames: []string{"moonshot", "kimi"}, protocol: "moonshot", buildConfig: func(p ProvidersConfig) (ModelConfig, bool) { if p.Moonshot.APIKey == "" && p.Moonshot.APIBase == "" { @@ -373,23 +373,6 @@ func ConvertProvidersToModelList(cfg *Config) []ModelConfig { }, true }, }, - { - providerNames: []string{"opencode"}, - protocol: "opencode", - buildConfig: func(p ProvidersConfig) (ModelConfig, bool) { - if p.Opencode.APIKey == "" && p.Opencode.APIBase == "" { - return ModelConfig{}, false - } - return ModelConfig{ - ModelName: "opencode", - Model: "opencode/auto", - APIKey: p.Opencode.APIKey, - APIBase: p.Opencode.APIBase, - Proxy: p.Opencode.Proxy, - RequestTimeout: p.Opencode.RequestTimeout, - }, true - }, - }, } // Process each provider migration @@ -401,9 +384,7 @@ func ConvertProvidersToModelList(cfg *Config) []ModelConfig { // Check if this is the user's configured provider if slices.Contains(m.providerNames, userProvider) && userModel != "" { - // Use the user's configured model instead of default. - // Also set ModelName so GetModelConfig(userModel) can find this entry. - mc.ModelName = userModel + // Use the user's configured model instead of default mc.Model = buildModelWithProtocol(m.protocol, userModel) } else if userProvider == "" && userModel != "" && !legacyModelNameApplied { // Legacy config: no explicit provider field but model is specified diff --git a/pkg/config/migration_test.go b/pkg/config/migration_test.go index 841ba8a9c..e24e9fa1d 100644 --- a/pkg/config/migration_test.go +++ b/pkg/config/migration_test.go @@ -160,7 +160,6 @@ func TestConvertProvidersToModelList_AllProviders(t *testing.T) { Antigravity: ProviderConfig{AuthMethod: "oauth"}, Qwen: ProviderConfig{APIKey: "key17"}, Mistral: ProviderConfig{APIKey: "key18"}, - Opencode: ProviderConfig{APIKey: "key19"}, }, } @@ -580,65 +579,6 @@ func TestBuildModelWithProtocol_DifferentPrefix(t *testing.T) { } } -func TestConvertProvidersToModelList_Opencode(t *testing.T) { - cfg := &Config{ - Providers: ProvidersConfig{ - Opencode: ProviderConfig{ - APIKey: "oc-test-key", - APIBase: "https://custom.opencode.ai/v1", - Proxy: "http://proxy:9090", - RequestTimeout: 60, - }, - }, - } - - result := ConvertProvidersToModelList(cfg) - - if len(result) != 1 { - t.Fatalf("len(result) = %d, want 1", len(result)) - } - - mc := result[0] - if mc.ModelName != "opencode" { - t.Errorf("ModelName = %q, want %q", mc.ModelName, "opencode") - } - if mc.Model != "opencode/auto" { - t.Errorf("Model = %q, want %q", mc.Model, "opencode/auto") - } - if mc.APIKey != "oc-test-key" { - t.Errorf("APIKey = %q, want %q", mc.APIKey, "oc-test-key") - } - if mc.APIBase != "https://custom.opencode.ai/v1" { - t.Errorf("APIBase = %q, want %q", mc.APIBase, "https://custom.opencode.ai/v1") - } - if mc.Proxy != "http://proxy:9090" { - t.Errorf("Proxy = %q, want %q", mc.Proxy, "http://proxy:9090") - } - if mc.RequestTimeout != 60 { - t.Errorf("RequestTimeout = %d, want %d", mc.RequestTimeout, 60) - } -} - -func TestConvertProvidersToModelList_Opencode_APIBaseOnly(t *testing.T) { - cfg := &Config{ - Providers: ProvidersConfig{ - Opencode: ProviderConfig{ - APIBase: "https://custom.opencode.ai/v1", - }, - }, - } - - result := ConvertProvidersToModelList(cfg) - - if len(result) != 1 { - t.Fatalf("len(result) = %d, want 1 (APIBase-only should create entry)", len(result)) - } - - if result[0].ModelName != "opencode" { - t.Errorf("ModelName = %q, want %q", result[0].ModelName, "opencode") - } -} - // Test for legacy config with protocol prefix in model name func TestConvertProvidersToModelList_LegacyModelWithProtocolPrefix(t *testing.T) { cfg := &Config{ @@ -669,72 +609,3 @@ func TestConvertProvidersToModelList_LegacyModelWithProtocolPrefix(t *testing.T) t.Errorf("Model = %q, want %q (should not duplicate prefix)", result[0].Model, "openrouter/auto") } } - -// Test that ModelName is set to the user's configured model when provider matches. -// This ensures GetModelConfig(userModel) can find the migrated entry. -// Regression test for: gateway startup failure when user model differs from provider name. -func TestConvertProvidersToModelList_ModelNameMatchesUserModel(t *testing.T) { - cfg := &Config{ - Agents: AgentsConfig{ - Defaults: AgentDefaults{ - Provider: "moonshot", - Model: "k2p5", - }, - }, - Providers: ProvidersConfig{ - Moonshot: ProviderConfig{APIKey: "sk-kimi-test"}, - }, - } - - result := ConvertProvidersToModelList(cfg) - - if len(result) != 1 { - t.Fatalf("len(result) = %d, want 1", len(result)) - } - - // ModelName must match the user's configured model, not the provider name. - // Without this, GetModelConfig("k2p5") would fail because it would look - // for ModelName == "k2p5" but find ModelName == "moonshot". - if result[0].ModelName != "k2p5" { - t.Errorf("ModelName = %q, want %q (must match user's model for GetModelConfig lookup)", result[0].ModelName, "k2p5") - } - - if result[0].Model != "moonshot/k2p5" { - t.Errorf("Model = %q, want %q", result[0].Model, "moonshot/k2p5") - } - - // Other providers (not matching the user's configured provider) should keep their provider name - cfg2 := &Config{ - Agents: AgentsConfig{ - Defaults: AgentDefaults{ - Provider: "moonshot", - Model: "k2p5", - }, - }, - Providers: ProvidersConfig{ - OpenAI: OpenAIProviderConfig{ProviderConfig: ProviderConfig{APIKey: "sk-openai"}}, - Moonshot: ProviderConfig{APIKey: "sk-kimi-test"}, - }, - } - - result2 := ConvertProvidersToModelList(cfg2) - - if len(result2) != 2 { - t.Fatalf("len(result2) = %d, want 2", len(result2)) - } - - for _, mc := range result2 { - switch { - case mc.APIKey == "sk-openai": - // OpenAI is not the user's provider, should keep default ModelName - if mc.ModelName != "openai" { - t.Errorf("OpenAI ModelName = %q, want %q (non-matching provider keeps default)", mc.ModelName, "openai") - } - case mc.APIKey == "sk-kimi-test": - // Moonshot is the user's provider, ModelName must be the user's model - if mc.ModelName != "k2p5" { - t.Errorf("Moonshot ModelName = %q, want %q (matching provider uses user model)", mc.ModelName, "k2p5") - } - } - } -} diff --git a/pkg/memory/jsonl.go b/pkg/memory/jsonl.go new file mode 100644 index 000000000..e12e2c5ab --- /dev/null +++ b/pkg/memory/jsonl.go @@ -0,0 +1,460 @@ +package memory + +import ( + "bufio" + "bytes" + "context" + "encoding/json" + "fmt" + "hash/fnv" + "log" + "os" + "path/filepath" + "strings" + "sync" + "time" + + "github.com/sipeed/picoclaw/pkg/fileutil" + "github.com/sipeed/picoclaw/pkg/providers" +) + +const ( + // numLockShards is the fixed number of mutexes used to serialize + // per-session access. Using a sharded array instead of a map keeps + // memory bounded regardless of how many sessions are created over + // the lifetime of the process — important for a long-running daemon. + numLockShards = 64 + + // maxLineSize is the maximum size of a single JSON line in a .jsonl + // file. Tool results (read_file, web search, etc.) can be large, so + // we set a generous limit. The scanner starts at 64 KB and grows + // only as needed up to this cap. + maxLineSize = 10 * 1024 * 1024 // 10 MB +) + +// sessionMeta holds per-session metadata stored in a .meta.json file. +type sessionMeta struct { + Key string `json:"key"` + Summary string `json:"summary"` + Skip int `json:"skip"` + Count int `json:"count"` + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at"` +} + +// JSONLStore implements Store using append-only JSONL files. +// +// Each session is stored as two files: +// +// {sanitized_key}.jsonl — one JSON-encoded message per line, append-only +// {sanitized_key}.meta.json — session metadata (summary, logical truncation offset) +// +// Messages are never physically deleted from the JSONL file. Instead, +// TruncateHistory records a "skip" offset in the metadata file and +// GetHistory ignores lines before that offset. This keeps all writes +// append-only, which is both fast and crash-safe. +type JSONLStore struct { + dir string + locks [numLockShards]sync.Mutex +} + +// NewJSONLStore creates a new JSONL-backed store rooted at dir. +func NewJSONLStore(dir string) (*JSONLStore, error) { + err := os.MkdirAll(dir, 0o755) + if err != nil { + return nil, fmt.Errorf("memory: create directory: %w", err) + } + return &JSONLStore{dir: dir}, nil +} + +// sessionLock returns a mutex for the given session key. +// Keys are mapped to a fixed pool of shards via FNV hash, so +// memory usage is O(1) regardless of total session count. +func (s *JSONLStore) sessionLock(key string) *sync.Mutex { + h := fnv.New32a() + h.Write([]byte(key)) + return &s.locks[h.Sum32()%numLockShards] +} + +func (s *JSONLStore) jsonlPath(key string) string { + return filepath.Join(s.dir, sanitizeKey(key)+".jsonl") +} + +func (s *JSONLStore) metaPath(key string) string { + return filepath.Join(s.dir, sanitizeKey(key)+".meta.json") +} + +// sanitizeKey converts a session key to a safe filename component. +// Mirrors pkg/session.sanitizeFilename so that migration paths match. +// +// Note: this is a lossy mapping — "telegram:123" and "telegram_123" +// both produce the same filename. This is an intentional tradeoff: +// keys with colons (e.g. from channels) are by far the common case, +// and a bidirectional encoding (like URL-encoding) would complicate +// file listings and debugging. +func sanitizeKey(key string) string { + return strings.ReplaceAll(key, ":", "_") +} + +// readMeta loads the metadata file for a session. +// Returns a zero-value sessionMeta if the file does not exist. +func (s *JSONLStore) readMeta(key string) (sessionMeta, error) { + data, err := os.ReadFile(s.metaPath(key)) + if os.IsNotExist(err) { + return sessionMeta{Key: key}, nil + } + if err != nil { + return sessionMeta{}, fmt.Errorf("memory: read meta: %w", err) + } + var meta sessionMeta + err = json.Unmarshal(data, &meta) + if err != nil { + return sessionMeta{}, fmt.Errorf("memory: decode meta: %w", err) + } + return meta, nil +} + +// writeMeta atomically writes the metadata file using the project's +// standard WriteFileAtomic (temp + fsync + rename). +func (s *JSONLStore) writeMeta(key string, meta sessionMeta) error { + data, err := json.MarshalIndent(meta, "", " ") + if err != nil { + return fmt.Errorf("memory: encode meta: %w", err) + } + return fileutil.WriteFileAtomic(s.metaPath(key), data, 0o644) +} + +// readMessages reads valid JSON lines from a .jsonl file, skipping +// the first `skip` lines without unmarshaling them. This avoids the +// cost of json.Unmarshal on logically truncated messages. +// Malformed trailing lines (e.g. from a crash) are silently skipped. +func readMessages(path string, skip int) ([]providers.Message, error) { + f, err := os.Open(path) + if os.IsNotExist(err) { + return []providers.Message{}, nil + } + if err != nil { + return nil, fmt.Errorf("memory: open jsonl: %w", err) + } + defer f.Close() + + var msgs []providers.Message + scanner := bufio.NewScanner(f) + // Allow large lines for tool results (read_file, web search, etc.). + scanner.Buffer(make([]byte, 0, 64*1024), maxLineSize) + + lineNum := 0 + for scanner.Scan() { + line := scanner.Bytes() + if len(line) == 0 { + continue + } + lineNum++ + if lineNum <= skip { + continue + } + var msg providers.Message + if err := json.Unmarshal(line, &msg); err != nil { + // Corrupt line — likely a partial write from a crash. + // Log so operators know data was skipped, but don't + // fail the entire read; this is the standard JSONL + // recovery pattern. + log.Printf("memory: skipping corrupt line %d in %s: %v", + lineNum, filepath.Base(path), err) + continue + } + msgs = append(msgs, msg) + } + if scanner.Err() != nil { + return nil, fmt.Errorf("memory: scan jsonl: %w", scanner.Err()) + } + + if msgs == nil { + msgs = []providers.Message{} + } + return msgs, nil +} + +// countLines counts the total number of non-empty lines in a .jsonl file. +// Used by TruncateHistory to reconcile a stale meta.Count without +// the overhead of unmarshaling every message. +func countLines(path string) (int, error) { + f, err := os.Open(path) + if os.IsNotExist(err) { + return 0, nil + } + if err != nil { + return 0, fmt.Errorf("memory: open jsonl: %w", err) + } + defer f.Close() + + n := 0 + scanner := bufio.NewScanner(f) + scanner.Buffer(make([]byte, 0, 64*1024), maxLineSize) + for scanner.Scan() { + if len(scanner.Bytes()) > 0 { + n++ + } + } + return n, scanner.Err() +} + +func (s *JSONLStore) AddMessage( + _ context.Context, sessionKey, role, content string, +) error { + return s.addMsg(sessionKey, providers.Message{ + Role: role, + Content: content, + }) +} + +func (s *JSONLStore) AddFullMessage( + _ context.Context, sessionKey string, msg providers.Message, +) error { + return s.addMsg(sessionKey, msg) +} + +// addMsg is the shared implementation for AddMessage and AddFullMessage. +func (s *JSONLStore) addMsg(sessionKey string, msg providers.Message) error { + l := s.sessionLock(sessionKey) + l.Lock() + defer l.Unlock() + + // Append the message as a single JSON line. + line, err := json.Marshal(msg) + if err != nil { + return fmt.Errorf("memory: marshal message: %w", err) + } + line = append(line, '\n') + + f, err := os.OpenFile( + s.jsonlPath(sessionKey), + os.O_CREATE|os.O_WRONLY|os.O_APPEND, + 0o644, + ) + if err != nil { + return fmt.Errorf("memory: open jsonl for append: %w", err) + } + _, writeErr := f.Write(line) + if writeErr != nil { + f.Close() + return fmt.Errorf("memory: append message: %w", writeErr) + } + // Flush to physical storage before closing. This matches the + // durability guarantee of writeMeta and rewriteJSONL (which use + // WriteFileAtomic with fsync). Without Sync, a power loss could + // leave the append in the kernel page cache only — lost on reboot. + if syncErr := f.Sync(); syncErr != nil { + f.Close() + return fmt.Errorf("memory: sync jsonl: %w", syncErr) + } + if closeErr := f.Close(); closeErr != nil { + return fmt.Errorf("memory: close jsonl: %w", closeErr) + } + + // Update metadata. + meta, err := s.readMeta(sessionKey) + if err != nil { + return err + } + now := time.Now() + if meta.Count == 0 && meta.CreatedAt.IsZero() { + meta.CreatedAt = now + } + meta.Count++ + meta.UpdatedAt = now + + return s.writeMeta(sessionKey, meta) +} + +func (s *JSONLStore) GetHistory( + _ context.Context, sessionKey string, +) ([]providers.Message, error) { + l := s.sessionLock(sessionKey) + l.Lock() + defer l.Unlock() + + meta, err := s.readMeta(sessionKey) + if err != nil { + return nil, err + } + + // Pass meta.Skip so readMessages skips those lines without + // unmarshaling them — avoids wasted CPU on truncated messages. + msgs, err := readMessages(s.jsonlPath(sessionKey), meta.Skip) + if err != nil { + return nil, err + } + + return msgs, nil +} + +func (s *JSONLStore) GetSummary( + _ context.Context, sessionKey string, +) (string, error) { + l := s.sessionLock(sessionKey) + l.Lock() + defer l.Unlock() + + meta, err := s.readMeta(sessionKey) + if err != nil { + return "", err + } + return meta.Summary, nil +} + +func (s *JSONLStore) SetSummary( + _ context.Context, sessionKey, summary string, +) error { + l := s.sessionLock(sessionKey) + l.Lock() + defer l.Unlock() + + meta, err := s.readMeta(sessionKey) + if err != nil { + return err + } + now := time.Now() + if meta.CreatedAt.IsZero() { + meta.CreatedAt = now + } + meta.Summary = summary + meta.UpdatedAt = now + + return s.writeMeta(sessionKey, meta) +} + +func (s *JSONLStore) TruncateHistory( + _ context.Context, sessionKey string, keepLast int, +) error { + l := s.sessionLock(sessionKey) + l.Lock() + defer l.Unlock() + + meta, err := s.readMeta(sessionKey) + if err != nil { + return err + } + + // Always reconcile meta.Count with the actual line count on disk. + // A crash between the JSONL append and the meta update in addMsg + // leaves meta.Count stale (e.g. file has 101 lines but meta says + // 100). Counting lines is cheap — no unmarshal, just a scan — and + // TruncateHistory is not a hot path, so always re-count. + n, countErr := countLines(s.jsonlPath(sessionKey)) + if countErr != nil { + return countErr + } + meta.Count = n + + if keepLast <= 0 { + meta.Skip = meta.Count + } else { + effective := meta.Count - meta.Skip + if keepLast < effective { + meta.Skip = meta.Count - keepLast + } + } + meta.UpdatedAt = time.Now() + + return s.writeMeta(sessionKey, meta) +} + +func (s *JSONLStore) SetHistory( + _ context.Context, + sessionKey string, + history []providers.Message, +) error { + l := s.sessionLock(sessionKey) + l.Lock() + defer l.Unlock() + + meta, err := s.readMeta(sessionKey) + if err != nil { + return err + } + now := time.Now() + if meta.CreatedAt.IsZero() { + meta.CreatedAt = now + } + meta.Skip = 0 + meta.Count = len(history) + meta.UpdatedAt = now + + // Write meta BEFORE rewriting the JSONL file. If we crash between + // the two writes, meta has Skip=0 and the old file is still intact, + // so GetHistory reads from line 1 — returning "too many" messages + // rather than losing data. The next SetHistory call corrects this. + err = s.writeMeta(sessionKey, meta) + if err != nil { + return err + } + + return s.rewriteJSONL(sessionKey, history) +} + +// Compact physically rewrites the JSONL file, dropping all logically +// skipped lines. This reclaims disk space that accumulates after +// repeated TruncateHistory calls. +// +// It is safe to call at any time; if there is nothing to compact +// (skip == 0) the method returns immediately. +func (s *JSONLStore) Compact( + _ context.Context, sessionKey string, +) error { + l := s.sessionLock(sessionKey) + l.Lock() + defer l.Unlock() + + meta, err := s.readMeta(sessionKey) + if err != nil { + return err + } + if meta.Skip == 0 { + return nil + } + + // Read only the active messages, skipping truncated lines + // without unmarshaling them. + active, err := readMessages(s.jsonlPath(sessionKey), meta.Skip) + if err != nil { + return err + } + + // Write meta BEFORE rewriting the JSONL file. If the process + // crashes between the two writes, meta has Skip=0 and the old + // (uncompacted) file is still intact, so GetHistory reads from + // line 1 — returning previously-truncated messages rather than + // losing data. The next Compact or TruncateHistory corrects this. + meta.Skip = 0 + meta.Count = len(active) + meta.UpdatedAt = time.Now() + + err = s.writeMeta(sessionKey, meta) + if err != nil { + return err + } + + return s.rewriteJSONL(sessionKey, active) +} + +// rewriteJSONL atomically replaces the JSONL file with the given messages +// using the project's standard WriteFileAtomic (temp + fsync + rename). +func (s *JSONLStore) rewriteJSONL( + sessionKey string, msgs []providers.Message, +) error { + var buf bytes.Buffer + for i, msg := range msgs { + line, err := json.Marshal(msg) + if err != nil { + return fmt.Errorf("memory: marshal message %d: %w", i, err) + } + buf.Write(line) + buf.WriteByte('\n') + } + return fileutil.WriteFileAtomic(s.jsonlPath(sessionKey), buf.Bytes(), 0o644) +} + +func (s *JSONLStore) Close() error { + return nil +} diff --git a/pkg/memory/jsonl_test.go b/pkg/memory/jsonl_test.go new file mode 100644 index 000000000..356ff14ff --- /dev/null +++ b/pkg/memory/jsonl_test.go @@ -0,0 +1,835 @@ +package memory + +import ( + "context" + "os" + "path/filepath" + "sync" + "testing" + + "github.com/sipeed/picoclaw/pkg/providers" +) + +func newTestStore(t *testing.T) *JSONLStore { + t.Helper() + store, err := NewJSONLStore(t.TempDir()) + if err != nil { + t.Fatalf("NewJSONLStore: %v", err) + } + return store +} + +func TestNewJSONLStore_CreatesDirectory(t *testing.T) { + dir := filepath.Join(t.TempDir(), "nested", "sessions") + store, err := NewJSONLStore(dir) + if err != nil { + t.Fatalf("NewJSONLStore: %v", err) + } + defer store.Close() + + info, err := os.Stat(dir) + if err != nil { + t.Fatalf("Stat: %v", err) + } + if !info.IsDir() { + t.Errorf("expected directory, got file") + } +} + +func TestAddMessage_BasicRoundtrip(t *testing.T) { + store := newTestStore(t) + ctx := context.Background() + + err := store.AddMessage(ctx, "s1", "user", "hello") + if err != nil { + t.Fatalf("AddMessage: %v", err) + } + err = store.AddMessage(ctx, "s1", "assistant", "hi there") + if err != nil { + t.Fatalf("AddMessage: %v", err) + } + + history, err := store.GetHistory(ctx, "s1") + if err != nil { + t.Fatalf("GetHistory: %v", err) + } + if len(history) != 2 { + t.Fatalf("expected 2 messages, got %d", len(history)) + } + if history[0].Role != "user" || history[0].Content != "hello" { + t.Errorf("msg[0] = %+v", history[0]) + } + if history[1].Role != "assistant" || history[1].Content != "hi there" { + t.Errorf("msg[1] = %+v", history[1]) + } +} + +func TestAddMessage_AutoCreatesSession(t *testing.T) { + store := newTestStore(t) + ctx := context.Background() + + // Adding a message to a non-existent session should work. + err := store.AddMessage(ctx, "new-session", "user", "first message") + if err != nil { + t.Fatalf("AddMessage: %v", err) + } + + history, err := store.GetHistory(ctx, "new-session") + if err != nil { + t.Fatalf("GetHistory: %v", err) + } + if len(history) != 1 { + t.Fatalf("expected 1 message, got %d", len(history)) + } +} + +func TestAddFullMessage_WithToolCalls(t *testing.T) { + store := newTestStore(t) + ctx := context.Background() + + msg := providers.Message{ + Role: "assistant", + Content: "Let me search that.", + ToolCalls: []providers.ToolCall{ + { + ID: "call_abc", + Type: "function", + Function: &providers.FunctionCall{ + Name: "web_search", + Arguments: `{"q":"golang jsonl"}`, + }, + }, + }, + } + + err := store.AddFullMessage(ctx, "tc", msg) + if err != nil { + t.Fatalf("AddFullMessage: %v", err) + } + + history, err := store.GetHistory(ctx, "tc") + if err != nil { + t.Fatalf("GetHistory: %v", err) + } + if len(history) != 1 { + t.Fatalf("expected 1, got %d", len(history)) + } + if len(history[0].ToolCalls) != 1 { + t.Fatalf("expected 1 tool call, got %d", len(history[0].ToolCalls)) + } + tc := history[0].ToolCalls[0] + if tc.ID != "call_abc" { + t.Errorf("tool call ID = %q", tc.ID) + } + if tc.Function == nil || tc.Function.Name != "web_search" { + t.Errorf("tool call function = %+v", tc.Function) + } +} + +func TestAddFullMessage_ToolCallID(t *testing.T) { + store := newTestStore(t) + ctx := context.Background() + + msg := providers.Message{ + Role: "tool", + Content: "search results here", + ToolCallID: "call_abc", + } + + err := store.AddFullMessage(ctx, "tr", msg) + if err != nil { + t.Fatalf("AddFullMessage: %v", err) + } + + history, err := store.GetHistory(ctx, "tr") + if err != nil { + t.Fatalf("GetHistory: %v", err) + } + if len(history) != 1 { + t.Fatalf("expected 1, got %d", len(history)) + } + if history[0].ToolCallID != "call_abc" { + t.Errorf("ToolCallID = %q", history[0].ToolCallID) + } +} + +func TestGetHistory_EmptySession(t *testing.T) { + store := newTestStore(t) + ctx := context.Background() + + history, err := store.GetHistory(ctx, "nonexistent") + if err != nil { + t.Fatalf("GetHistory: %v", err) + } + if history == nil { + t.Fatal("expected non-nil empty slice") + } + if len(history) != 0 { + t.Errorf("expected 0 messages, got %d", len(history)) + } +} + +func TestGetHistory_Ordering(t *testing.T) { + store := newTestStore(t) + ctx := context.Background() + + for i := 0; i < 5; i++ { + err := store.AddMessage( + ctx, "order", + "user", + string(rune('a'+i)), + ) + if err != nil { + t.Fatalf("AddMessage(%d): %v", i, err) + } + } + + history, err := store.GetHistory(ctx, "order") + if err != nil { + t.Fatalf("GetHistory: %v", err) + } + if len(history) != 5 { + t.Fatalf("expected 5, got %d", len(history)) + } + for i := 0; i < 5; i++ { + expected := string(rune('a' + i)) + if history[i].Content != expected { + t.Errorf("msg[%d].Content = %q, want %q", i, history[i].Content, expected) + } + } +} + +func TestSetSummary_GetSummary(t *testing.T) { + store := newTestStore(t) + ctx := context.Background() + + // No summary yet. + summary, err := store.GetSummary(ctx, "s1") + if err != nil { + t.Fatalf("GetSummary: %v", err) + } + if summary != "" { + t.Errorf("expected empty, got %q", summary) + } + + // Set a summary. + err = store.SetSummary(ctx, "s1", "talked about Go") + if err != nil { + t.Fatalf("SetSummary: %v", err) + } + + summary, err = store.GetSummary(ctx, "s1") + if err != nil { + t.Fatalf("GetSummary: %v", err) + } + if summary != "talked about Go" { + t.Errorf("summary = %q", summary) + } + + // Update summary. + err = store.SetSummary(ctx, "s1", "updated summary") + if err != nil { + t.Fatalf("SetSummary: %v", err) + } + + summary, err = store.GetSummary(ctx, "s1") + if err != nil { + t.Fatalf("GetSummary: %v", err) + } + if summary != "updated summary" { + t.Errorf("summary = %q", summary) + } +} + +func TestTruncateHistory_KeepLast(t *testing.T) { + store := newTestStore(t) + ctx := context.Background() + + for i := 0; i < 10; i++ { + err := store.AddMessage( + ctx, "trunc", + "user", + string(rune('a'+i)), + ) + if err != nil { + t.Fatalf("AddMessage: %v", err) + } + } + + err := store.TruncateHistory(ctx, "trunc", 4) + if err != nil { + t.Fatalf("TruncateHistory: %v", err) + } + + history, err := store.GetHistory(ctx, "trunc") + if err != nil { + t.Fatalf("GetHistory: %v", err) + } + if len(history) != 4 { + t.Fatalf("expected 4, got %d", len(history)) + } + // Should be the last 4: g, h, i, j + if history[0].Content != "g" { + t.Errorf("first kept = %q, want 'g'", history[0].Content) + } + if history[3].Content != "j" { + t.Errorf("last kept = %q, want 'j'", history[3].Content) + } +} + +func TestTruncateHistory_KeepZero(t *testing.T) { + store := newTestStore(t) + ctx := context.Background() + + for i := 0; i < 5; i++ { + err := store.AddMessage(ctx, "empty", "user", "msg") + if err != nil { + t.Fatalf("AddMessage: %v", err) + } + } + + err := store.TruncateHistory(ctx, "empty", 0) + if err != nil { + t.Fatalf("TruncateHistory: %v", err) + } + + history, err := store.GetHistory(ctx, "empty") + if err != nil { + t.Fatalf("GetHistory: %v", err) + } + if len(history) != 0 { + t.Errorf("expected 0, got %d", len(history)) + } +} + +func TestTruncateHistory_KeepMoreThanExists(t *testing.T) { + store := newTestStore(t) + ctx := context.Background() + + for i := 0; i < 3; i++ { + err := store.AddMessage(ctx, "few", "user", "msg") + if err != nil { + t.Fatalf("AddMessage: %v", err) + } + } + + // Keep 100, but only 3 exist — should keep all. + err := store.TruncateHistory(ctx, "few", 100) + if err != nil { + t.Fatalf("TruncateHistory: %v", err) + } + + history, err := store.GetHistory(ctx, "few") + if err != nil { + t.Fatalf("GetHistory: %v", err) + } + if len(history) != 3 { + t.Errorf("expected 3, got %d", len(history)) + } +} + +func TestSetHistory_ReplacesAll(t *testing.T) { + store := newTestStore(t) + ctx := context.Background() + + // Add some initial messages. + for i := 0; i < 5; i++ { + err := store.AddMessage(ctx, "replace", "user", "old") + if err != nil { + t.Fatalf("AddMessage: %v", err) + } + } + + // Replace with new history. + newHistory := []providers.Message{ + {Role: "user", Content: "new1"}, + {Role: "assistant", Content: "new2"}, + } + err := store.SetHistory(ctx, "replace", newHistory) + if err != nil { + t.Fatalf("SetHistory: %v", err) + } + + history, err := store.GetHistory(ctx, "replace") + if err != nil { + t.Fatalf("GetHistory: %v", err) + } + if len(history) != 2 { + t.Fatalf("expected 2, got %d", len(history)) + } + if history[0].Content != "new1" || history[1].Content != "new2" { + t.Errorf("history = %+v", history) + } +} + +func TestSetHistory_ResetsSkip(t *testing.T) { + store := newTestStore(t) + ctx := context.Background() + + // Add messages and truncate. + for i := 0; i < 10; i++ { + err := store.AddMessage(ctx, "skip-reset", "user", "old") + if err != nil { + t.Fatalf("AddMessage: %v", err) + } + } + err := store.TruncateHistory(ctx, "skip-reset", 3) + if err != nil { + t.Fatalf("TruncateHistory: %v", err) + } + + // SetHistory should reset skip to 0. + newHistory := []providers.Message{ + {Role: "user", Content: "fresh"}, + } + err = store.SetHistory(ctx, "skip-reset", newHistory) + if err != nil { + t.Fatalf("SetHistory: %v", err) + } + + history, err := store.GetHistory(ctx, "skip-reset") + if err != nil { + t.Fatalf("GetHistory: %v", err) + } + if len(history) != 1 { + t.Fatalf("expected 1, got %d", len(history)) + } + if history[0].Content != "fresh" { + t.Errorf("content = %q", history[0].Content) + } +} + +func TestColonInKey(t *testing.T) { + store := newTestStore(t) + ctx := context.Background() + + err := store.AddMessage(ctx, "telegram:123", "user", "hi") + if err != nil { + t.Fatalf("AddMessage: %v", err) + } + + history, err := store.GetHistory(ctx, "telegram:123") + if err != nil { + t.Fatalf("GetHistory: %v", err) + } + if len(history) != 1 { + t.Fatalf("expected 1, got %d", len(history)) + } + + // Verify the file is named with underscore. + jsonlFile := filepath.Join(store.dir, "telegram_123.jsonl") + if _, statErr := os.Stat(jsonlFile); statErr != nil { + t.Errorf("expected file %s to exist: %v", jsonlFile, statErr) + } +} + +func TestCompact_RemovesSkippedMessages(t *testing.T) { + store := newTestStore(t) + ctx := context.Background() + + // Write 10 messages, then truncate to keep last 3. + for i := 0; i < 10; i++ { + err := store.AddMessage(ctx, "compact", "user", string(rune('a'+i))) + if err != nil { + t.Fatalf("AddMessage: %v", err) + } + } + err := store.TruncateHistory(ctx, "compact", 3) + if err != nil { + t.Fatalf("TruncateHistory: %v", err) + } + + // Before compact: file still has 10 lines. + allOnDisk, err := readMessages(store.jsonlPath("compact"), 0) + if err != nil { + t.Fatalf("readMessages: %v", err) + } + if len(allOnDisk) != 10 { + t.Fatalf("before compact: expected 10 on disk, got %d", len(allOnDisk)) + } + + // Compact. + err = store.Compact(ctx, "compact") + if err != nil { + t.Fatalf("Compact: %v", err) + } + + // After compact: file should have only 3 lines. + allOnDisk, err = readMessages(store.jsonlPath("compact"), 0) + if err != nil { + t.Fatalf("readMessages: %v", err) + } + if len(allOnDisk) != 3 { + t.Fatalf("after compact: expected 3 on disk, got %d", len(allOnDisk)) + } + + // GetHistory should still return the same 3 messages. + history, err := store.GetHistory(ctx, "compact") + if err != nil { + t.Fatalf("GetHistory: %v", err) + } + if len(history) != 3 { + t.Fatalf("expected 3, got %d", len(history)) + } + if history[0].Content != "h" || history[2].Content != "j" { + t.Errorf("wrong content: %+v", history) + } +} + +func TestCompact_NoOpWhenNoSkip(t *testing.T) { + store := newTestStore(t) + ctx := context.Background() + + for i := 0; i < 5; i++ { + err := store.AddMessage(ctx, "noop", "user", "msg") + if err != nil { + t.Fatalf("AddMessage: %v", err) + } + } + + // Compact without prior truncation — should be a no-op. + err := store.Compact(ctx, "noop") + if err != nil { + t.Fatalf("Compact: %v", err) + } + + history, err := store.GetHistory(ctx, "noop") + if err != nil { + t.Fatalf("GetHistory: %v", err) + } + if len(history) != 5 { + t.Errorf("expected 5, got %d", len(history)) + } +} + +func TestCompact_ThenAppend(t *testing.T) { + store := newTestStore(t) + ctx := context.Background() + + for i := 0; i < 8; i++ { + err := store.AddMessage(ctx, "cap", "user", string(rune('a'+i))) + if err != nil { + t.Fatalf("AddMessage: %v", err) + } + } + + err := store.TruncateHistory(ctx, "cap", 2) + if err != nil { + t.Fatalf("TruncateHistory: %v", err) + } + err = store.Compact(ctx, "cap") + if err != nil { + t.Fatalf("Compact: %v", err) + } + + // Append after compaction should work correctly. + err = store.AddMessage(ctx, "cap", "user", "new") + if err != nil { + t.Fatalf("AddMessage after compact: %v", err) + } + + history, err := store.GetHistory(ctx, "cap") + if err != nil { + t.Fatalf("GetHistory: %v", err) + } + if len(history) != 3 { + t.Fatalf("expected 3, got %d", len(history)) + } + // g, h (kept from truncation), new (appended after compaction). + if history[0].Content != "g" { + t.Errorf("first = %q, want 'g'", history[0].Content) + } + if history[2].Content != "new" { + t.Errorf("last = %q, want 'new'", history[2].Content) + } +} + +func TestTruncateHistory_StaleMetaCount(t *testing.T) { + // Simulates a crash between JSONL append and meta update in addMsg: + // file has N+1 lines but meta.Count is still N. TruncateHistory must + // reconcile with the real line count so that keepLast is accurate. + store := newTestStore(t) + ctx := context.Background() + + // Write 10 messages normally (meta.Count = 10). + for i := 0; i < 10; i++ { + err := store.AddMessage(ctx, "stale", "user", string(rune('a'+i))) + if err != nil { + t.Fatalf("AddMessage: %v", err) + } + } + + // Simulate crash: append a line to JSONL but do NOT update meta. + // This leaves meta.Count = 10 while the file has 11 lines. + jsonlPath := store.jsonlPath("stale") + f, err := os.OpenFile(jsonlPath, os.O_WRONLY|os.O_APPEND, 0o644) + if err != nil { + t.Fatalf("open for append: %v", err) + } + _, err = f.WriteString(`{"role":"user","content":"orphan"}` + "\n") + if err != nil { + t.Fatalf("write orphan: %v", err) + } + f.Close() + + // TruncateHistory(keepLast=4) should keep the last 4 of 11 lines, + // not the last 4 of 10. + err = store.TruncateHistory(ctx, "stale", 4) + if err != nil { + t.Fatalf("TruncateHistory: %v", err) + } + + history, err := store.GetHistory(ctx, "stale") + if err != nil { + t.Fatalf("GetHistory: %v", err) + } + if len(history) != 4 { + t.Fatalf("expected 4, got %d", len(history)) + } + // Last 4 of [a,b,c,d,e,f,g,h,i,j,orphan] = [h,i,j,orphan] + if history[0].Content != "h" { + t.Errorf("first kept = %q, want 'h'", history[0].Content) + } + if history[3].Content != "orphan" { + t.Errorf("last kept = %q, want 'orphan'", history[3].Content) + } +} + +func TestCrashRecovery_PartialLine(t *testing.T) { + store := newTestStore(t) + ctx := context.Background() + + // Write a valid message first. + err := store.AddMessage(ctx, "crash", "user", "valid") + if err != nil { + t.Fatalf("AddMessage: %v", err) + } + + // Simulate a crash by appending a partial JSON line directly. + jsonlPath := store.jsonlPath("crash") + f, err := os.OpenFile(jsonlPath, os.O_WRONLY|os.O_APPEND, 0o644) + if err != nil { + t.Fatalf("open for append: %v", err) + } + _, err = f.WriteString(`{"role":"user","content":"incomple`) + if err != nil { + t.Fatalf("write partial: %v", err) + } + f.Close() + + // GetHistory should return only the valid message. + history, err := store.GetHistory(ctx, "crash") + if err != nil { + t.Fatalf("GetHistory: %v", err) + } + if len(history) != 1 { + t.Fatalf("expected 1 valid message, got %d", len(history)) + } + if history[0].Content != "valid" { + t.Errorf("content = %q", history[0].Content) + } +} + +func TestPersistence_AcrossInstances(t *testing.T) { + dir := t.TempDir() + ctx := context.Background() + + // Write with first instance. + store1, err := NewJSONLStore(dir) + if err != nil { + t.Fatalf("NewJSONLStore: %v", err) + } + err = store1.AddMessage(ctx, "persist", "user", "remember me") + if err != nil { + t.Fatalf("AddMessage: %v", err) + } + err = store1.SetSummary(ctx, "persist", "a test session") + if err != nil { + t.Fatalf("SetSummary: %v", err) + } + store1.Close() + + // Read with second instance. + store2, err := NewJSONLStore(dir) + if err != nil { + t.Fatalf("NewJSONLStore: %v", err) + } + defer store2.Close() + + history, err := store2.GetHistory(ctx, "persist") + if err != nil { + t.Fatalf("GetHistory: %v", err) + } + if len(history) != 1 || history[0].Content != "remember me" { + t.Errorf("history = %+v", history) + } + + summary, err := store2.GetSummary(ctx, "persist") + if err != nil { + t.Fatalf("GetSummary: %v", err) + } + if summary != "a test session" { + t.Errorf("summary = %q", summary) + } +} + +func TestConcurrent_AddAndRead(t *testing.T) { + store := newTestStore(t) + ctx := context.Background() + + var wg sync.WaitGroup + const goroutines = 10 + const msgsPerGoroutine = 20 + + // Concurrent writes. + for g := 0; g < goroutines; g++ { + wg.Add(1) + go func() { + defer wg.Done() + for i := 0; i < msgsPerGoroutine; i++ { + _ = store.AddMessage(ctx, "concurrent", "user", "msg") + } + }() + } + wg.Wait() + + history, err := store.GetHistory(ctx, "concurrent") + if err != nil { + t.Fatalf("GetHistory: %v", err) + } + expected := goroutines * msgsPerGoroutine + if len(history) != expected { + t.Errorf("expected %d messages, got %d", expected, len(history)) + } +} + +func TestConcurrent_SummarizeRace(t *testing.T) { + // Simulates the #704 race: one goroutine adds messages while + // another truncates + sets summary — like summarizeSession(). + store := newTestStore(t) + ctx := context.Background() + + // Seed with some messages. + for i := 0; i < 20; i++ { + err := store.AddMessage(ctx, "race", "user", "seed") + if err != nil { + t.Fatalf("AddMessage: %v", err) + } + } + + var wg sync.WaitGroup + + // Writer goroutine (main agent loop). + wg.Add(1) + go func() { + defer wg.Done() + for i := 0; i < 50; i++ { + _ = store.AddMessage(ctx, "race", "user", "new") + } + }() + + // Summarizer goroutine (background task). + wg.Add(1) + go func() { + defer wg.Done() + for i := 0; i < 10; i++ { + _ = store.SetSummary(ctx, "race", "summary") + _ = store.TruncateHistory(ctx, "race", 5) + } + }() + + wg.Wait() + + // Verify the store is still in a consistent state. + _, err := store.GetHistory(ctx, "race") + if err != nil { + t.Fatalf("GetHistory after race: %v", err) + } + _, err = store.GetSummary(ctx, "race") + if err != nil { + t.Fatalf("GetSummary after race: %v", err) + } +} + +func TestMultipleSessions_Isolation(t *testing.T) { + store := newTestStore(t) + ctx := context.Background() + + err := store.AddMessage(ctx, "s1", "user", "msg for s1") + if err != nil { + t.Fatalf("AddMessage: %v", err) + } + err = store.AddMessage(ctx, "s2", "user", "msg for s2") + if err != nil { + t.Fatalf("AddMessage: %v", err) + } + + h1, err := store.GetHistory(ctx, "s1") + if err != nil { + t.Fatalf("GetHistory s1: %v", err) + } + h2, err := store.GetHistory(ctx, "s2") + if err != nil { + t.Fatalf("GetHistory s2: %v", err) + } + + if len(h1) != 1 || h1[0].Content != "msg for s1" { + t.Errorf("s1 history = %+v", h1) + } + if len(h2) != 1 || h2[0].Content != "msg for s2" { + t.Errorf("s2 history = %+v", h2) + } +} + +func BenchmarkAddMessage(b *testing.B) { + dir := b.TempDir() + store, err := NewJSONLStore(dir) + if err != nil { + b.Fatalf("NewJSONLStore: %v", err) + } + defer store.Close() + ctx := context.Background() + + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = store.AddMessage(ctx, "bench", "user", "benchmark message content") + } +} + +func BenchmarkGetHistory_100(b *testing.B) { + dir := b.TempDir() + store, err := NewJSONLStore(dir) + if err != nil { + b.Fatalf("NewJSONLStore: %v", err) + } + defer store.Close() + ctx := context.Background() + + for i := 0; i < 100; i++ { + _ = store.AddMessage(ctx, "bench", "user", "message content") + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, _ = store.GetHistory(ctx, "bench") + } +} + +func BenchmarkGetHistory_1000(b *testing.B) { + dir := b.TempDir() + store, err := NewJSONLStore(dir) + if err != nil { + b.Fatalf("NewJSONLStore: %v", err) + } + defer store.Close() + ctx := context.Background() + + for i := 0; i < 1000; i++ { + _ = store.AddMessage(ctx, "bench", "user", "message content") + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, _ = store.GetHistory(ctx, "bench") + } +} diff --git a/pkg/memory/migration.go b/pkg/memory/migration.go new file mode 100644 index 000000000..c9d5176ab --- /dev/null +++ b/pkg/memory/migration.go @@ -0,0 +1,108 @@ +package memory + +import ( + "context" + "encoding/json" + "fmt" + "log" + "os" + "path/filepath" + "strings" + "time" + + "github.com/sipeed/picoclaw/pkg/providers" +) + +// jsonSession mirrors pkg/session.Session for migration purposes. +type jsonSession struct { + Key string `json:"key"` + Messages []providers.Message `json:"messages"` + Summary string `json:"summary,omitempty"` + Created time.Time `json:"created"` + Updated time.Time `json:"updated"` +} + +// MigrateFromJSON reads legacy sessions/*.json files from sessionsDir, +// writes them into the Store, and renames each migrated file to +// .json.migrated as a backup. Returns the number of sessions migrated. +// +// Files that fail to parse are logged and skipped. Already-migrated +// files (.json.migrated) are ignored, making the function idempotent. +func MigrateFromJSON( + ctx context.Context, sessionsDir string, store Store, +) (int, error) { + entries, err := os.ReadDir(sessionsDir) + if os.IsNotExist(err) { + return 0, nil + } + if err != nil { + return 0, fmt.Errorf("memory: read sessions dir: %w", err) + } + + migrated := 0 + for _, entry := range entries { + if entry.IsDir() { + continue + } + name := entry.Name() + if !strings.HasSuffix(name, ".json") { + continue + } + // Skip already-migrated files. + if strings.HasSuffix(name, ".migrated") { + continue + } + + srcPath := filepath.Join(sessionsDir, name) + + data, readErr := os.ReadFile(srcPath) + if readErr != nil { + log.Printf("memory: migrate: skip %s: %v", name, readErr) + continue + } + + var sess jsonSession + if parseErr := json.Unmarshal(data, &sess); parseErr != nil { + log.Printf("memory: migrate: skip %s: %v", name, parseErr) + continue + } + + // Use the key from the JSON content, not the filename. + // Filenames are sanitized (":" → "_") but keys are not. + key := sess.Key + if key == "" { + key = strings.TrimSuffix(name, ".json") + } + + // Use SetHistory (atomic replace) instead of per-message + // AddFullMessage. This makes migration idempotent: if the + // process crashes after writing messages but before the + // rename below, a retry replaces the partial data cleanly + // instead of duplicating messages. + if setErr := store.SetHistory(ctx, key, sess.Messages); setErr != nil { + return migrated, fmt.Errorf( + "memory: migrate %s: set history: %w", + name, setErr, + ) + } + + if sess.Summary != "" { + if sumErr := store.SetSummary(ctx, key, sess.Summary); sumErr != nil { + return migrated, fmt.Errorf( + "memory: migrate %s: set summary: %w", + name, sumErr, + ) + } + } + + // Rename to .migrated as backup (not delete). + renameErr := os.Rename(srcPath, srcPath+".migrated") + if renameErr != nil { + log.Printf("memory: migrate: rename %s: %v", name, renameErr) + } + + migrated++ + } + + return migrated, nil +} diff --git a/pkg/memory/migration_test.go b/pkg/memory/migration_test.go new file mode 100644 index 000000000..3170758b7 --- /dev/null +++ b/pkg/memory/migration_test.go @@ -0,0 +1,384 @@ +package memory + +import ( + "context" + "encoding/json" + "os" + "path/filepath" + "testing" + "time" + + "github.com/sipeed/picoclaw/pkg/providers" +) + +func writeJSONSession( + t *testing.T, dir string, filename string, sess jsonSession, +) { + t.Helper() + data, err := json.MarshalIndent(sess, "", " ") + if err != nil { + t.Fatalf("marshal session: %v", err) + } + err = os.WriteFile(filepath.Join(dir, filename), data, 0o644) + if err != nil { + t.Fatalf("write session file: %v", err) + } +} + +func TestMigrateFromJSON_Basic(t *testing.T) { + sessionsDir := t.TempDir() + store := newTestStore(t) + ctx := context.Background() + + writeJSONSession(t, sessionsDir, "test.json", jsonSession{ + Key: "test", + Messages: []providers.Message{ + {Role: "user", Content: "hello"}, + {Role: "assistant", Content: "hi"}, + }, + Summary: "A greeting.", + Created: time.Now(), + Updated: time.Now(), + }) + + count, err := MigrateFromJSON(ctx, sessionsDir, store) + if err != nil { + t.Fatalf("MigrateFromJSON: %v", err) + } + if count != 1 { + t.Errorf("expected 1 migrated, got %d", count) + } + + history, err := store.GetHistory(ctx, "test") + if err != nil { + t.Fatalf("GetHistory: %v", err) + } + if len(history) != 2 { + t.Fatalf("expected 2 messages, got %d", len(history)) + } + if history[0].Content != "hello" || history[1].Content != "hi" { + t.Errorf("unexpected messages: %+v", history) + } + + summary, err := store.GetSummary(ctx, "test") + if err != nil { + t.Fatalf("GetSummary: %v", err) + } + if summary != "A greeting." { + t.Errorf("summary = %q", summary) + } +} + +func TestMigrateFromJSON_WithToolCalls(t *testing.T) { + sessionsDir := t.TempDir() + store := newTestStore(t) + ctx := context.Background() + + writeJSONSession(t, sessionsDir, "tools.json", jsonSession{ + Key: "tools", + Messages: []providers.Message{ + { + Role: "assistant", + Content: "Searching...", + ToolCalls: []providers.ToolCall{ + { + ID: "call_1", + Type: "function", + Function: &providers.FunctionCall{ + Name: "web_search", + Arguments: `{"q":"test"}`, + }, + }, + }, + }, + { + Role: "tool", + Content: "result", + ToolCallID: "call_1", + }, + }, + Created: time.Now(), + Updated: time.Now(), + }) + + count, err := MigrateFromJSON(ctx, sessionsDir, store) + if err != nil { + t.Fatalf("MigrateFromJSON: %v", err) + } + if count != 1 { + t.Errorf("expected 1, got %d", count) + } + + history, err := store.GetHistory(ctx, "tools") + if err != nil { + t.Fatalf("GetHistory: %v", err) + } + if len(history) != 2 { + t.Fatalf("expected 2 messages, got %d", len(history)) + } + if len(history[0].ToolCalls) != 1 { + t.Fatalf("expected 1 tool call, got %d", len(history[0].ToolCalls)) + } + if history[0].ToolCalls[0].Function.Name != "web_search" { + t.Errorf("function = %q", history[0].ToolCalls[0].Function.Name) + } + if history[1].ToolCallID != "call_1" { + t.Errorf("ToolCallID = %q", history[1].ToolCallID) + } +} + +func TestMigrateFromJSON_MultipleFiles(t *testing.T) { + sessionsDir := t.TempDir() + store := newTestStore(t) + ctx := context.Background() + + for i := 0; i < 3; i++ { + key := string(rune('a' + i)) + writeJSONSession(t, sessionsDir, key+".json", jsonSession{ + Key: key, + Messages: []providers.Message{{Role: "user", Content: "msg " + key}}, + Created: time.Now(), + Updated: time.Now(), + }) + } + + count, err := MigrateFromJSON(ctx, sessionsDir, store) + if err != nil { + t.Fatalf("MigrateFromJSON: %v", err) + } + if count != 3 { + t.Errorf("expected 3, got %d", count) + } + + for i := 0; i < 3; i++ { + key := string(rune('a' + i)) + history, histErr := store.GetHistory(ctx, key) + if histErr != nil { + t.Fatalf("GetHistory(%q): %v", key, histErr) + } + if len(history) != 1 { + t.Errorf("session %q: expected 1 msg, got %d", key, len(history)) + } + } +} + +func TestMigrateFromJSON_InvalidJSON(t *testing.T) { + sessionsDir := t.TempDir() + store := newTestStore(t) + ctx := context.Background() + + // One valid, one invalid. + writeJSONSession(t, sessionsDir, "good.json", jsonSession{ + Key: "good", + Messages: []providers.Message{{Role: "user", Content: "ok"}}, + Created: time.Now(), + Updated: time.Now(), + }) + err := os.WriteFile( + filepath.Join(sessionsDir, "bad.json"), + []byte("{invalid json"), + 0o644, + ) + if err != nil { + t.Fatalf("write bad file: %v", err) + } + + count, err := MigrateFromJSON(ctx, sessionsDir, store) + if err != nil { + t.Fatalf("MigrateFromJSON: %v", err) + } + if count != 1 { + t.Errorf("expected 1 (bad file skipped), got %d", count) + } + + history, err := store.GetHistory(ctx, "good") + if err != nil { + t.Fatalf("GetHistory: %v", err) + } + if len(history) != 1 { + t.Errorf("expected 1 message, got %d", len(history)) + } +} + +func TestMigrateFromJSON_RenamesFiles(t *testing.T) { + sessionsDir := t.TempDir() + store := newTestStore(t) + ctx := context.Background() + + writeJSONSession(t, sessionsDir, "rename.json", jsonSession{ + Key: "rename", + Messages: []providers.Message{{Role: "user", Content: "hi"}}, + Created: time.Now(), + Updated: time.Now(), + }) + + _, err := MigrateFromJSON(ctx, sessionsDir, store) + if err != nil { + t.Fatalf("MigrateFromJSON: %v", err) + } + + // Original .json should not exist. + _, statErr := os.Stat(filepath.Join(sessionsDir, "rename.json")) + if !os.IsNotExist(statErr) { + t.Error("rename.json should have been renamed") + } + // .json.migrated should exist. + _, statErr = os.Stat( + filepath.Join(sessionsDir, "rename.json.migrated"), + ) + if statErr != nil { + t.Errorf("rename.json.migrated should exist: %v", statErr) + } +} + +func TestMigrateFromJSON_Idempotent(t *testing.T) { + sessionsDir := t.TempDir() + store := newTestStore(t) + ctx := context.Background() + + writeJSONSession(t, sessionsDir, "idem.json", jsonSession{ + Key: "idem", + Messages: []providers.Message{{Role: "user", Content: "once"}}, + Created: time.Now(), + Updated: time.Now(), + }) + + count1, err := MigrateFromJSON(ctx, sessionsDir, store) + if err != nil { + t.Fatalf("first migration: %v", err) + } + if count1 != 1 { + t.Errorf("first run: expected 1, got %d", count1) + } + + // Second run should find only .migrated files, skip them. + count2, err := MigrateFromJSON(ctx, sessionsDir, store) + if err != nil { + t.Fatalf("second migration: %v", err) + } + if count2 != 0 { + t.Errorf("second run: expected 0, got %d", count2) + } + + history, err := store.GetHistory(ctx, "idem") + if err != nil { + t.Fatalf("GetHistory: %v", err) + } + if len(history) != 1 { + t.Errorf("expected 1 message, got %d", len(history)) + } +} + +func TestMigrateFromJSON_ColonInKey(t *testing.T) { + sessionsDir := t.TempDir() + store := newTestStore(t) + ctx := context.Background() + + // File is named telegram_123 (sanitized), but the key inside is telegram:123. + writeJSONSession(t, sessionsDir, "telegram_123.json", jsonSession{ + Key: "telegram:123", + Messages: []providers.Message{{Role: "user", Content: "from telegram"}}, + Created: time.Now(), + Updated: time.Now(), + }) + + count, err := MigrateFromJSON(ctx, sessionsDir, store) + if err != nil { + t.Fatalf("MigrateFromJSON: %v", err) + } + if count != 1 { + t.Errorf("expected 1, got %d", count) + } + + // Accessible via the original key "telegram:123". + history, err := store.GetHistory(ctx, "telegram:123") + if err != nil { + t.Fatalf("GetHistory: %v", err) + } + if len(history) != 1 { + t.Fatalf("expected 1 message, got %d", len(history)) + } + if history[0].Content != "from telegram" { + t.Errorf("content = %q", history[0].Content) + } + + // In the file-based store, "telegram:123" and "telegram_123" both + // sanitize to the same filename, so they share storage. This is + // expected — the colon-to-underscore mapping is a one-way function. + history2, err := store.GetHistory(ctx, "telegram_123") + if err != nil { + t.Fatalf("GetHistory: %v", err) + } + if len(history2) != 1 { + t.Errorf("expected 1 (same file), got %d", len(history2)) + } +} + +func TestMigrateFromJSON_RetryAfterCrash(t *testing.T) { + // Simulates a crash during migration: first run writes messages + // but doesn't rename the .json file. Second run must replace + // (not duplicate) the messages thanks to SetHistory semantics. + sessionsDir := t.TempDir() + store := newTestStore(t) + ctx := context.Background() + + writeJSONSession(t, sessionsDir, "retry.json", jsonSession{ + Key: "retry", + Messages: []providers.Message{ + {Role: "user", Content: "one"}, + {Role: "assistant", Content: "two"}, + }, + Created: time.Now(), + Updated: time.Now(), + }) + + // First migration succeeds — writes messages and renames file. + count, err := MigrateFromJSON(ctx, sessionsDir, store) + if err != nil { + t.Fatalf("first migration: %v", err) + } + if count != 1 { + t.Fatalf("expected 1, got %d", count) + } + + // Simulate "crash before rename": restore the .json file. + src := filepath.Join(sessionsDir, "retry.json.migrated") + dst := filepath.Join(sessionsDir, "retry.json") + if renameErr := os.Rename(src, dst); renameErr != nil { + t.Fatalf("restore .json: %v", renameErr) + } + + // Second migration should re-import without duplicating messages. + count, err = MigrateFromJSON(ctx, sessionsDir, store) + if err != nil { + t.Fatalf("second migration: %v", err) + } + if count != 1 { + t.Fatalf("expected 1, got %d", count) + } + + history, err := store.GetHistory(ctx, "retry") + if err != nil { + t.Fatalf("GetHistory: %v", err) + } + // Must be exactly 2 messages (not 4 from duplication). + if len(history) != 2 { + t.Fatalf("expected 2 messages (no duplicates), got %d", len(history)) + } + if history[0].Content != "one" || history[1].Content != "two" { + t.Errorf("unexpected messages: %+v", history) + } +} + +func TestMigrateFromJSON_NonexistentDir(t *testing.T) { + store := newTestStore(t) + ctx := context.Background() + + count, err := MigrateFromJSON(ctx, "/nonexistent/path", store) + if err != nil { + t.Fatalf("MigrateFromJSON: %v", err) + } + if count != 0 { + t.Errorf("expected 0, got %d", count) + } +} diff --git a/pkg/memory/store.go b/pkg/memory/store.go new file mode 100644 index 000000000..b6e11707d --- /dev/null +++ b/pkg/memory/store.go @@ -0,0 +1,42 @@ +package memory + +import ( + "context" + + "github.com/sipeed/picoclaw/pkg/providers" +) + +// Store defines an interface for persistent session storage. +// Each method is an atomic operation — there is no separate Save() call. +type Store interface { + // AddMessage appends a simple text message to a session. + AddMessage(ctx context.Context, sessionKey, role, content string) error + + // AddFullMessage appends a complete message (with tool calls, etc.) to a session. + AddFullMessage(ctx context.Context, sessionKey string, msg providers.Message) error + + // GetHistory returns all messages for a session in insertion order. + // Returns an empty slice (not nil) if the session does not exist. + GetHistory(ctx context.Context, sessionKey string) ([]providers.Message, error) + + // GetSummary returns the conversation summary for a session. + // Returns an empty string if no summary exists. + GetSummary(ctx context.Context, sessionKey string) (string, error) + + // SetSummary updates the conversation summary for a session. + SetSummary(ctx context.Context, sessionKey, summary string) error + + // TruncateHistory removes all but the last keepLast messages from a session. + // If keepLast <= 0, all messages are removed. + TruncateHistory(ctx context.Context, sessionKey string, keepLast int) error + + // SetHistory replaces all messages in a session with the provided history. + SetHistory(ctx context.Context, sessionKey string, history []providers.Message) error + + // Compact reclaims storage by physically removing logically truncated + // data. Backends that do not accumulate dead data may return nil. + Compact(ctx context.Context, sessionKey string) error + + // Close releases any resources held by the store. + Close() error +} diff --git a/pkg/providers/factory.go b/pkg/providers/factory.go index 9d53dca1c..5b3e42b9e 100644 --- a/pkg/providers/factory.go +++ b/pkg/providers/factory.go @@ -190,28 +190,6 @@ func resolveProviderSelection(cfg *config.Config) (providerSelection, error) { sel.apiBase = "https://api.mistral.ai/v1" } } - case "opencode": - if cfg.Providers.Opencode.APIKey != "" || cfg.Providers.Opencode.APIBase != "" { - sel.apiKey = cfg.Providers.Opencode.APIKey - sel.apiBase = cfg.Providers.Opencode.APIBase - sel.proxy = cfg.Providers.Opencode.Proxy - if sel.apiBase == "" { - sel.apiBase = "https://opencode.ai/zen/v1" - } - } - case "kimi", "kimi-code", "moonshot": - if cfg.Providers.Moonshot.APIKey != "" { - sel.apiKey = cfg.Providers.Moonshot.APIKey - sel.apiBase = cfg.Providers.Moonshot.APIBase - sel.proxy = cfg.Providers.Moonshot.Proxy - if sel.apiBase == "" { - if providerName == "moonshot" { - sel.apiBase = "https://api.moonshot.cn/v1" - } else { - sel.apiBase = "https://api.kimi.com/coding/v1" - } - } - } case "github_copilot", "copilot": sel.providerType = providerTypeGitHubCopilot if cfg.Providers.GitHubCopilot.APIBase != "" { @@ -232,11 +210,7 @@ func resolveProviderSelection(cfg *config.Config) (providerSelection, error) { sel.apiBase = cfg.Providers.Moonshot.APIBase sel.proxy = cfg.Providers.Moonshot.Proxy if sel.apiBase == "" { - if strings.Contains(lowerModel, "moonshot") || strings.HasPrefix(model, "moonshot/") { - sel.apiBase = "https://api.moonshot.cn/v1" - } else { - sel.apiBase = "https://api.kimi.com/coding/v1" - } + sel.apiBase = "https://api.moonshot.cn/v1" } case strings.HasPrefix(model, "openrouter/") || strings.HasPrefix(model, "anthropic/") || diff --git a/pkg/providers/factory_provider.go b/pkg/providers/factory_provider.go index 4d2949c91..155317a3b 100644 --- a/pkg/providers/factory_provider.go +++ b/pkg/providers/factory_provider.go @@ -94,7 +94,7 @@ func CreateProviderFromConfig(cfg *config.ModelConfig) (LLMProvider, string, err case "litellm", "openrouter", "groq", "zhipu", "gemini", "nvidia", "ollama", "moonshot", "shengsuanyun", "deepseek", "cerebras", - "volcengine", "vllm", "qwen", "mistral", "opencode": + "volcengine", "vllm", "qwen", "mistral": // All other OpenAI-compatible HTTP providers if cfg.APIKey == "" && cfg.APIBase == "" { return nil, "", fmt.Errorf("api_key or api_base is required for HTTP-based protocol %q", protocol) @@ -208,8 +208,6 @@ func getDefaultAPIBase(protocol string) string { return "http://localhost:8000/v1" case "mistral": return "https://api.mistral.ai/v1" - case "opencode": - return "https://opencode.ai/zen/v1" default: return "" } diff --git a/pkg/providers/factory_provider_test.go b/pkg/providers/factory_provider_test.go index 7d0ea1e32..78389f331 100644 --- a/pkg/providers/factory_provider_test.go +++ b/pkg/providers/factory_provider_test.go @@ -112,7 +112,6 @@ func TestCreateProviderFromConfig_DefaultAPIBase(t *testing.T) { {"vllm", "vllm"}, {"deepseek", "deepseek"}, {"ollama", "ollama"}, - {"opencode", "opencode"}, } for _, tt := range tests { diff --git a/pkg/providers/openai_compat/provider.go b/pkg/providers/openai_compat/provider.go index 6bed72456..ff9109e96 100644 --- a/pkg/providers/openai_compat/provider.go +++ b/pkg/providers/openai_compat/provider.go @@ -33,7 +33,6 @@ type Provider struct { apiBase string maxTokensField string // Field name for max tokens (e.g., "max_completion_tokens" for o1/glm models) httpClient *http.Client - isKimiAPI bool // true when apiBase points to api.kimi.com } type Option func(*Provider) @@ -70,17 +69,10 @@ func NewProvider(apiKey, apiBase, proxy string, opts ...Option) *Provider { } } - trimmedBase := strings.TrimRight(apiBase, "/") - var isKimi bool - if parsed, err := url.Parse(trimmedBase); err == nil { - isKimi = parsed.Hostname() == "api.kimi.com" - } - p := &Provider{ apiKey: apiKey, - apiBase: trimmedBase, + apiBase: strings.TrimRight(apiBase, "/"), httpClient: client, - isKimiAPI: isKimi, } for _, opt := range opts { @@ -184,12 +176,6 @@ func (p *Provider) Chat( if p.apiKey != "" { req.Header.Set("Authorization", "Bearer "+p.apiKey) } - // Kimi Code API rejects requests without a recognized coding-agent - // User-Agent. "KimiCLI/0.77" is the minimum version string accepted - // by the api.kimi.com/coding/v1 endpoint (per Kimi's API docs). - if p.isKimiAPI { - req.Header.Set("User-Agent", "KimiCLI/0.77") - } resp, err := p.httpClient.Do(req) if err != nil { diff --git a/pkg/providers/openai_compat/provider_test.go b/pkg/providers/openai_compat/provider_test.go index f08b24f17..174bcf00d 100644 --- a/pkg/providers/openai_compat/provider_test.go +++ b/pkg/providers/openai_compat/provider_test.go @@ -2,7 +2,6 @@ package openai_compat import ( "encoding/json" - "io" "net/http" "net/http/httptest" "net/url" @@ -421,82 +420,6 @@ func TestProvider_FunctionalOptionRequestTimeoutNonPositive(t *testing.T) { } } -// roundTripFunc adapts a function to http.RoundTripper for test injection. -type roundTripFunc func(*http.Request) (*http.Response, error) - -func (f roundTripFunc) RoundTrip(r *http.Request) (*http.Response, error) { - return f(r) -} - -func TestProviderChat_KimiCodeUserAgent(t *testing.T) { - okBody := `{"choices":[{"message":{"content":"ok"},"finish_reason":"stop"}]}` - - tests := []struct { - name string - apiBase string - wantAgent string - }{ - { - name: "sets KimiCLI User-Agent for api.kimi.com", - apiBase: "https://api.kimi.com/coding/v1", - wantAgent: "KimiCLI/0.77", - }, - { - name: "does not set KimiCLI User-Agent for other hosts", - apiBase: "https://api.example.com/v1", - wantAgent: "", - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - var gotUserAgent string - - p := NewProvider("key", tt.apiBase, "") - p.httpClient.Transport = roundTripFunc( - func(r *http.Request) (*http.Response, error) { - gotUserAgent = r.Header.Get("User-Agent") - return &http.Response{ - StatusCode: http.StatusOK, - Body: io.NopCloser( - strings.NewReader(okBody), - ), - Header: http.Header{ - "Content-Type": {"application/json"}, - }, - }, nil - }, - ) - - _, err := p.Chat( - t.Context(), - []Message{{Role: "user", Content: "hi"}}, - nil, - "kimi-k2.5", - nil, - ) - if err != nil { - t.Fatalf("Chat() error = %v", err) - } - - if tt.wantAgent != "" { - if gotUserAgent != tt.wantAgent { - t.Fatalf( - "User-Agent = %q, want %q", - gotUserAgent, tt.wantAgent, - ) - } - } else { - if gotUserAgent == "KimiCLI/0.77" { - t.Fatalf( - "User-Agent should not be KimiCLI/0.77 for non-kimi host", - ) - } - } - }) - } -} - func TestSerializeMessages_PlainText(t *testing.T) { messages := []protocoltypes.Message{ {Role: "user", Content: "hello"}, diff --git a/pkg/tools/message.go b/pkg/tools/message.go index 15ef4ff73..d1e4a373e 100644 --- a/pkg/tools/message.go +++ b/pkg/tools/message.go @@ -3,6 +3,7 @@ package tools import ( "context" "fmt" + "sync/atomic" ) type SendCallback func(channel, chatID, content string) error @@ -11,7 +12,7 @@ type MessageTool struct { sendCallback SendCallback defaultChannel string defaultChatID string - sentInRound bool // Tracks whether a message was sent in the current processing round + sentInRound atomic.Bool // Tracks whether a message was sent in the current processing round } func NewMessageTool() *MessageTool { @@ -50,12 +51,12 @@ func (t *MessageTool) Parameters() map[string]any { func (t *MessageTool) SetContext(channel, chatID string) { t.defaultChannel = channel t.defaultChatID = chatID - t.sentInRound = false // Reset send tracking for new processing round + t.sentInRound.Store(false) // Reset send tracking for new processing round } // HasSentInRound returns true if the message tool sent a message during the current round. func (t *MessageTool) HasSentInRound() bool { - return t.sentInRound + return t.sentInRound.Load() } func (t *MessageTool) SetSendCallback(callback SendCallback) { @@ -94,7 +95,7 @@ func (t *MessageTool) Execute(ctx context.Context, args map[string]any) *ToolRes } } - t.sentInRound = true + t.sentInRound.Store(true) // Silent: user already received the message directly return &ToolResult{ ForLLM: fmt.Sprintf("Message sent to %s:%s", channel, chatID), diff --git a/pkg/tools/shell.go b/pkg/tools/shell.go index 08711ae14..a0c83eb1e 100644 --- a/pkg/tools/shell.go +++ b/pkg/tools/shell.go @@ -30,9 +30,9 @@ var ( regexp.MustCompile(`\brm\s+-[rf]{1,2}\b`), regexp.MustCompile(`\bdel\s+/[fq]\b`), regexp.MustCompile(`\brmdir\s+/s\b`), - // Match disk wiping commands, avoid matching --format flags + // Match disk wiping commands (must be followed by space/args) regexp.MustCompile( - `(?:^|[;&|]\s*|\s+)(format|mkfs|diskpart)\s`, + `\b(format|mkfs|diskpart)\b\s`, ), regexp.MustCompile(`\bdd\s+if=`), // Block writes to block devices (all common naming schemes). diff --git a/pkg/tools/shell_test.go b/pkg/tools/shell_test.go index 955acb36a..a6abca8ea 100644 --- a/pkg/tools/shell_test.go +++ b/pkg/tools/shell_test.go @@ -366,56 +366,6 @@ func TestShellTool_BlockDevices(t *testing.T) { } } -// TestShellTool_DenyPattern_DiskWiping verifies the deny pattern for disk wiping -// commands (format, mkfs, diskpart) blocks them when preceded by shell separators -// but does NOT block legitimate uses like --format flags. -func TestShellTool_DenyPattern_DiskWiping(t *testing.T) { - tool, err := NewExecTool("", false) - if err != nil { - t.Fatalf("unable to configure exec tool: %s", err) - } - - // These should be BLOCKED (disk wiping commands) - blockedCmds := []struct { - name string - cmd string - }{ - {"format with space", "format c:"}, - {"mkfs standalone", "mkfs /dev/sda"}, - {"semicolon format", "echo hello; format c:"}, - {"pipe format", "echo hello | format c:"}, - {"and format", "echo hello && format c:"}, - {"diskpart standalone", "diskpart /s script.txt"}, - } - - for _, tt := range blockedCmds { - t.Run("blocked_"+tt.name, func(t *testing.T) { - msg := tool.guardCommand(tt.cmd, "") - if !strings.Contains(msg, "blocked") { - t.Errorf("Expected %q to be blocked by safety guard, got: %q", tt.cmd, msg) - } - }) - } - - // These should be ALLOWED (not disk wiping) - allowed := []struct { - name string - cmd string - }{ - {"--format flag", "echo test --format json"}, - {"go fmt", "echo go fmt ./..."}, - } - - for _, tt := range allowed { - t.Run("allowed_"+tt.name, func(t *testing.T) { - msg := tool.guardCommand(tt.cmd, "") - if msg != "" { - t.Errorf("Expected %q to be allowed, but it was blocked: %s", tt.cmd, msg) - } - }) - } -} - // TestShellTool_SafePathsInWorkspaceRestriction verifies that safe kernel pseudo-devices // are allowed even when workspace restriction is active. func TestShellTool_SafePathsInWorkspaceRestriction(t *testing.T) { diff --git a/pkg/tools/toolloop.go b/pkg/tools/toolloop.go index cdfe0d6ce..244f0d4a2 100644 --- a/pkg/tools/toolloop.go +++ b/pkg/tools/toolloop.go @@ -10,6 +10,7 @@ import ( "context" "encoding/json" "fmt" + "sync" "github.com/sipeed/picoclaw/pkg/logger" "github.com/sipeed/picoclaw/pkg/providers" @@ -121,37 +122,53 @@ func RunToolLoop( } messages = append(messages, assistantMsg) - // 7. Execute tool calls - for _, tc := range normalizedToolCalls { - argsJSON, _ := json.Marshal(tc.Arguments) - argsPreview := utils.Truncate(string(argsJSON), 200) - logger.InfoCF("toolloop", fmt.Sprintf("Tool call: %s(%s)", tc.Name, argsPreview), - map[string]any{ - "tool": tc.Name, - "iteration": iteration, - }) + // 7. Execute tool calls in parallel + type indexedResult struct { + result *ToolResult + tc providers.ToolCall + } - // Execute tool (no async callback for subagents - they run independently) - var toolResult *ToolResult - if config.Tools != nil { - toolResult = config.Tools.ExecuteWithContext(ctx, tc.Name, tc.Arguments, channel, chatID, nil) - } else { - toolResult = ErrorResult("No tools available") + results := make([]indexedResult, len(normalizedToolCalls)) + var wg sync.WaitGroup + + for i, tc := range normalizedToolCalls { + results[i].tc = tc + + wg.Add(1) + go func(idx int, tc providers.ToolCall) { + defer wg.Done() + + argsJSON, _ := json.Marshal(tc.Arguments) + argsPreview := utils.Truncate(string(argsJSON), 200) + logger.InfoCF("toolloop", fmt.Sprintf("Tool call: %s(%s)", tc.Name, argsPreview), + map[string]any{ + "tool": tc.Name, + "iteration": iteration, + }) + + var toolResult *ToolResult + if config.Tools != nil { + toolResult = config.Tools.ExecuteWithContext(ctx, tc.Name, tc.Arguments, channel, chatID, nil) + } else { + toolResult = ErrorResult("No tools available") + } + results[idx].result = toolResult + }(i, tc) + } + wg.Wait() + + // Append results in original order + for _, r := range results { + contentForLLM := r.result.ForLLM + if contentForLLM == "" && r.result.Err != nil { + contentForLLM = r.result.Err.Error() } - // Determine content for LLM - contentForLLM := toolResult.ForLLM - if contentForLLM == "" && toolResult.Err != nil { - contentForLLM = toolResult.Err.Error() - } - - // Add tool result message - toolResultMsg := providers.Message{ + messages = append(messages, providers.Message{ Role: "tool", Content: contentForLLM, - ToolCallID: tc.ID, - } - messages = append(messages, toolResultMsg) + ToolCallID: r.tc.ID, + }) } } diff --git a/pkg/tools/web.go b/pkg/tools/web.go index b3b127bd0..7b14686c9 100644 --- a/pkg/tools/web.go +++ b/pkg/tools/web.go @@ -395,6 +395,88 @@ func (p *PerplexitySearchProvider) Search(ctx context.Context, query string, cou return fmt.Sprintf("Results for: %s (via Perplexity)\n%s", query, searchResp.Choices[0].Message.Content), nil } +type GLMSearchProvider struct { + apiKey string + baseURL string + searchEngine string + proxy string + client *http.Client +} + +func (p *GLMSearchProvider) Search(ctx context.Context, query string, count int) (string, error) { + searchURL := p.baseURL + if searchURL == "" { + searchURL = "https://open.bigmodel.cn/api/paas/v4/web_search" + } + + payload := map[string]any{ + "search_query": query, + "search_engine": p.searchEngine, + "search_intent": false, + "count": count, + "content_size": "medium", + } + + bodyBytes, err := json.Marshal(payload) + if err != nil { + return "", fmt.Errorf("failed to marshal payload: %w", err) + } + + req, err := http.NewRequestWithContext(ctx, "POST", searchURL, bytes.NewReader(bodyBytes)) + if err != nil { + return "", fmt.Errorf("failed to create request: %w", err) + } + + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", "Bearer "+p.apiKey) + + resp, err := p.client.Do(req) + if err != nil { + return "", fmt.Errorf("request failed: %w", err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(io.LimitReader(resp.Body, 1<<20)) + if err != nil { + return "", fmt.Errorf("failed to read response: %w", err) + } + + if resp.StatusCode != http.StatusOK { + return "", fmt.Errorf("GLM Search API error (status %d): %s", resp.StatusCode, string(body)) + } + + var searchResp struct { + SearchResult []struct { + Title string `json:"title"` + Content string `json:"content"` + Link string `json:"link"` + } `json:"search_result"` + } + + if err := json.Unmarshal(body, &searchResp); err != nil { + return "", fmt.Errorf("failed to parse response: %w", err) + } + + results := searchResp.SearchResult + if len(results) == 0 { + return fmt.Sprintf("No results for: %s", query), nil + } + + var lines []string + lines = append(lines, fmt.Sprintf("Results for: %s (via GLM Search)", query)) + for i, item := range results { + if i >= count { + break + } + lines = append(lines, fmt.Sprintf("%d. %s\n %s", i+1, item.Title, item.Link)) + if item.Content != "" { + lines = append(lines, fmt.Sprintf(" %s", item.Content)) + } + } + + return strings.Join(lines, "\n"), nil +} + type WebSearchTool struct { provider SearchProvider maxResults int @@ -413,9 +495,11 @@ type WebSearchToolOptions struct { PerplexityAPIKey string PerplexityMaxResults int PerplexityEnabled bool - ExaAPIKey string - ExaMaxResults int - ExaEnabled bool + GLMSearchAPIKey string + GLMSearchBaseURL string + GLMSearchEngine string + GLMSearchMaxResults int + GLMSearchEnabled bool Proxy string } @@ -423,7 +507,7 @@ func NewWebSearchTool(opts WebSearchToolOptions) (*WebSearchTool, error) { var provider SearchProvider maxResults := 5 - // Priority: Perplexity > Exa > Brave > Tavily > DuckDuckGo + // Priority: Perplexity > Brave > Tavily > DuckDuckGo > GLM Search if opts.PerplexityEnabled && opts.PerplexityAPIKey != "" { client, err := createHTTPClient(opts.Proxy, perplexityTimeout) if err != nil { @@ -433,15 +517,6 @@ func NewWebSearchTool(opts WebSearchToolOptions) (*WebSearchTool, error) { if opts.PerplexityMaxResults > 0 { maxResults = opts.PerplexityMaxResults } - } else if opts.ExaEnabled && opts.ExaAPIKey != "" { - client, err := createHTTPClient(opts.Proxy, searchTimeout) - if err != nil { - return nil, fmt.Errorf("failed to create HTTP client for Exa: %w", err) - } - provider = &ExaSearchProvider{apiKey: opts.ExaAPIKey, proxy: opts.Proxy, client: client} - if opts.ExaMaxResults > 0 { - maxResults = opts.ExaMaxResults - } } else if opts.BraveEnabled && opts.BraveAPIKey != "" { client, err := createHTTPClient(opts.Proxy, searchTimeout) if err != nil { @@ -474,6 +549,25 @@ func NewWebSearchTool(opts WebSearchToolOptions) (*WebSearchTool, error) { if opts.DuckDuckGoMaxResults > 0 { maxResults = opts.DuckDuckGoMaxResults } + } else if opts.GLMSearchEnabled && opts.GLMSearchAPIKey != "" { + client, err := createHTTPClient(opts.Proxy, searchTimeout) + if err != nil { + return nil, fmt.Errorf("failed to create HTTP client for GLM Search: %w", err) + } + searchEngine := opts.GLMSearchEngine + if searchEngine == "" { + searchEngine = "search_std" + } + provider = &GLMSearchProvider{ + apiKey: opts.GLMSearchAPIKey, + baseURL: opts.GLMSearchBaseURL, + searchEngine: searchEngine, + proxy: opts.Proxy, + client: client, + } + if opts.GLMSearchMaxResults > 0 { + maxResults = opts.GLMSearchMaxResults + } } else { return nil, nil } @@ -721,77 +815,3 @@ func (t *WebFetchTool) extractText(htmlContent string) string { return strings.Join(cleanLines, "\n") } - -// ExaSearchProvider uses the Exa AI search API (https://exa.ai). -type ExaSearchProvider struct { - apiKey string - proxy string - client *http.Client -} - -func (p *ExaSearchProvider) Search(ctx context.Context, query string, count int) (string, error) { - reqBody := map[string]any{ - "query": query, - "num_results": count, - "type": "neural", - } - jsonData, err := json.Marshal(reqBody) - if err != nil { - return "", fmt.Errorf("exa: marshal error: %w", err) - } - - req, err := http.NewRequestWithContext(ctx, "POST", "https://api.exa.ai/search", bytes.NewReader(jsonData)) - if err != nil { - return "", fmt.Errorf("exa: request error: %w", err) - } - req.Header.Set("Content-Type", "application/json") - req.Header.Set("x-api-key", p.apiKey) - - resp, err := p.client.Do(req) - if err != nil { - return "", fmt.Errorf("exa: search failed: %w", err) - } - defer resp.Body.Close() - - body, err := io.ReadAll(resp.Body) - if err != nil { - return "", fmt.Errorf("exa: read error: %w", err) - } - if resp.StatusCode != http.StatusOK { - return "", fmt.Errorf("exa: API error %d: %s", resp.StatusCode, string(body)) - } - - var result struct { - Results []struct { - Title string `json:"title"` - URL string `json:"url"` - Text string `json:"text"` - } `json:"results"` - } - if err := json.Unmarshal(body, &result); err != nil { - return "", fmt.Errorf("exa: parse error: %w", err) - } - - if len(result.Results) == 0 { - return fmt.Sprintf("No results for: %s", query), nil - } - - var lines []string - lines = append(lines, fmt.Sprintf("Results for: %s (via Exa)", query)) - maxResults := count - if maxResults > len(result.Results) { - maxResults = len(result.Results) - } - for i, r := range result.Results[:maxResults] { - lines = append(lines, fmt.Sprintf("%d. %s\n %s", i+1, r.Title, r.URL)) - if r.Text != "" { - snippet := r.Text - if len(snippet) > 200 { - snippet = snippet[:200] + "..." - } - lines = append(lines, fmt.Sprintf(" %s", snippet)) - } - } - - return strings.Join(lines, "\n"), nil -} diff --git a/pkg/tools/web_test.go b/pkg/tools/web_test.go index 896b39a33..bdd30d385 100644 --- a/pkg/tools/web_test.go +++ b/pkg/tools/web_test.go @@ -5,7 +5,6 @@ import ( "context" "encoding/json" "fmt" - "io" "net/http" "net/http/httptest" "strings" @@ -683,86 +682,7 @@ func TestWebTool_TavilySearch_Success(t *testing.T) { } } -func TestNewWebSearchTool_ExaPriority(t *testing.T) { - // Exa should be selected when enabled with API key - tool, err := NewWebSearchTool(WebSearchToolOptions{ - ExaEnabled: true, - ExaAPIKey: "exa-key", - ExaMaxResults: 3, - }) - if err != nil { - t.Fatalf("NewWebSearchTool() error: %v", err) - } - if tool == nil { - t.Fatal("Expected non-nil tool when Exa is enabled with API key") - } - if _, ok := tool.provider.(*ExaSearchProvider); !ok { - t.Fatalf("provider type = %T, want *ExaSearchProvider", tool.provider) - } - if tool.maxResults != 3 { - t.Fatalf("maxResults = %d, want 3", tool.maxResults) - } - - // Exa enabled but no API key should fall through - tool, err = NewWebSearchTool(WebSearchToolOptions{ - ExaEnabled: true, - ExaAPIKey: "", - }) - if err != nil { - t.Fatalf("NewWebSearchTool() error: %v", err) - } - if tool != nil { - t.Errorf("Expected nil tool when Exa API key is empty and no other provider enabled") - } - - // Perplexity should take priority over Exa - tool, err = NewWebSearchTool(WebSearchToolOptions{ - PerplexityEnabled: true, - PerplexityAPIKey: "perp-key", - ExaEnabled: true, - ExaAPIKey: "exa-key", - }) - if err != nil { - t.Fatalf("NewWebSearchTool() error: %v", err) - } - if _, ok := tool.provider.(*PerplexitySearchProvider); !ok { - t.Fatalf("provider type = %T, want *PerplexitySearchProvider (Perplexity should outrank Exa)", tool.provider) - } - - // Exa should take priority over Brave - tool, err = NewWebSearchTool(WebSearchToolOptions{ - ExaEnabled: true, - ExaAPIKey: "exa-key", - BraveEnabled: true, - BraveAPIKey: "brave-key", - }) - if err != nil { - t.Fatalf("NewWebSearchTool() error: %v", err) - } - if _, ok := tool.provider.(*ExaSearchProvider); !ok { - t.Fatalf("provider type = %T, want *ExaSearchProvider (Exa should outrank Brave)", tool.provider) - } -} - -func TestNewWebSearchTool_ExaProxyPropagation(t *testing.T) { - tool, err := NewWebSearchTool(WebSearchToolOptions{ - ExaEnabled: true, - ExaAPIKey: "k", - Proxy: "http://127.0.0.1:7890", - }) - if err != nil { - t.Fatalf("NewWebSearchTool() error: %v", err) - } - p, ok := tool.provider.(*ExaSearchProvider) - if !ok { - t.Fatalf("provider type = %T, want *ExaSearchProvider", tool.provider) - } - if p.proxy != "http://127.0.0.1:7890" { - t.Fatalf("provider proxy = %q, want %q", p.proxy, "http://127.0.0.1:7890") - } -} - -func TestExaSearchProvider_Success(t *testing.T) { +func TestWebTool_GLMSearch_Success(t *testing.T) { server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.Method != "POST" { t.Errorf("Expected POST request, got %s", r.Method) @@ -770,130 +690,126 @@ func TestExaSearchProvider_Success(t *testing.T) { if r.Header.Get("Content-Type") != "application/json" { t.Errorf("Expected Content-Type application/json, got %s", r.Header.Get("Content-Type")) } - if r.Header.Get("x-api-key") != "test-exa-key" { - t.Errorf("Expected x-api-key test-exa-key, got %s", r.Header.Get("x-api-key")) + if r.Header.Get("Authorization") != "Bearer test-glm-key" { + t.Errorf("Expected Authorization Bearer test-glm-key, got %s", r.Header.Get("Authorization")) } - // Verify payload - body, _ := io.ReadAll(r.Body) var payload map[string]any - json.Unmarshal(body, &payload) - if payload["query"] != "test query" { - t.Errorf("Expected query 'test query', got %v", payload["query"]) + json.NewDecoder(r.Body).Decode(&payload) + if payload["search_query"] != "test query" { + t.Errorf("Expected search_query 'test query', got %v", payload["search_query"]) } - if payload["type"] != "neural" { - t.Errorf("Expected type 'neural', got %v", payload["type"]) + if payload["search_engine"] != "search_std" { + t.Errorf("Expected search_engine 'search_std', got %v", payload["search_engine"]) } response := map[string]any{ - "results": []map[string]any{ - {"title": "Exa Result 1", "url": "https://exa.ai/1", "text": "First result text"}, - {"title": "Exa Result 2", "url": "https://exa.ai/2", "text": "Second result text"}, - {"title": "Exa Result 3", "url": "https://exa.ai/3", "text": "Third result text"}, + "id": "web-search-test", + "created": 1709568000, + "search_result": []map[string]any{ + { + "title": "Test GLM Result", + "content": "GLM search snippet", + "link": "https://example.com/glm", + "media": "Example", + "publish_date": "2026-03-04", + }, }, } w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) json.NewEncoder(w).Encode(response) })) defer server.Close() - provider := &ExaSearchProvider{ - apiKey: "test-exa-key", - client: &http.Client{}, - } - - // Temporarily override the API URL by using a custom transport - provider.client.Transport = rewriteHostTransport(server.URL) - - result, err := provider.Search(context.Background(), "test query", 5) - if err != nil { - t.Fatalf("Search() error: %v", err) - } - - if !strings.Contains(result, "via Exa") { - t.Errorf("Expected '(via Exa)' attribution, got: %s", result) - } - if !strings.Contains(result, "Exa Result 1") || !strings.Contains(result, "https://exa.ai/1") { - t.Errorf("Expected results in output, got: %s", result) - } - if !strings.Contains(result, "First result text") { - t.Errorf("Expected snippet text in output, got: %s", result) - } -} - -func TestExaSearchProvider_EmptyResults(t *testing.T) { - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - response := map[string]any{"results": []map[string]any{}} - w.Header().Set("Content-Type", "application/json") - json.NewEncoder(w).Encode(response) - })) - defer server.Close() - - provider := &ExaSearchProvider{ - apiKey: "test-key", - client: &http.Client{Transport: rewriteHostTransport(server.URL)}, - } - - result, err := provider.Search(context.Background(), "no results query", 5) - if err != nil { - t.Fatalf("Search() error: %v", err) - } - if !strings.Contains(result, "No results for: no results query") { - t.Errorf("Expected 'No results' message, got: %s", result) - } -} - -func TestExaSearchProvider_MaxResultsCapping(t *testing.T) { - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - // Return 5 results - results := make([]map[string]any, 5) - for i := range results { - results[i] = map[string]any{ - "title": fmt.Sprintf("Result %d", i+1), - "url": fmt.Sprintf("https://exa.ai/%d", i+1), - "text": fmt.Sprintf("Text %d", i+1), - } - } - response := map[string]any{"results": results} - w.Header().Set("Content-Type", "application/json") - json.NewEncoder(w).Encode(response) - })) - defer server.Close() - - provider := &ExaSearchProvider{ - apiKey: "test-key", - client: &http.Client{Transport: rewriteHostTransport(server.URL)}, - } - - // Request only 2 results even though API returns 5 - result, err := provider.Search(context.Background(), "test", 2) - if err != nil { - t.Fatalf("Search() error: %v", err) - } - - if !strings.Contains(result, "Result 1") || !strings.Contains(result, "Result 2") { - t.Errorf("Expected first 2 results, got: %s", result) - } - if strings.Contains(result, "Result 3") { - t.Errorf("Expected results capped at 2, but got Result 3 in output: %s", result) - } -} - -// rewriteHostTransport returns an http.RoundTripper that redirects all requests to the given target URL. -func rewriteHostTransport(target string) http.RoundTripper { - return roundTripFunc(func(req *http.Request) (*http.Response, error) { - newURL := target + req.URL.Path - newReq, err := http.NewRequestWithContext(req.Context(), req.Method, newURL, req.Body) - if err != nil { - return nil, err - } - newReq.Header = req.Header - return http.DefaultClient.Do(newReq) + tool, err := NewWebSearchTool(WebSearchToolOptions{ + GLMSearchEnabled: true, + GLMSearchAPIKey: "test-glm-key", + GLMSearchBaseURL: server.URL, + GLMSearchEngine: "search_std", }) + if err != nil { + t.Fatalf("NewWebSearchTool() error: %v", err) + } + + result := tool.Execute(context.Background(), map[string]any{ + "query": "test query", + }) + + if result.IsError { + t.Errorf("Expected success, got IsError=true: %s", result.ForLLM) + } + if !strings.Contains(result.ForUser, "Test GLM Result") { + t.Errorf("Expected 'Test GLM Result' in output, got: %s", result.ForUser) + } + if !strings.Contains(result.ForUser, "https://example.com/glm") { + t.Errorf("Expected URL in output, got: %s", result.ForUser) + } + if !strings.Contains(result.ForUser, "via GLM Search") { + t.Errorf("Expected 'via GLM Search' in output, got: %s", result.ForUser) + } } -type roundTripFunc func(*http.Request) (*http.Response, error) +func TestWebTool_GLMSearch_APIError(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusUnauthorized) + w.Write([]byte(`{"error":"invalid api key"}`)) + })) + defer server.Close() -func (f roundTripFunc) RoundTrip(req *http.Request) (*http.Response, error) { - return f(req) + tool, err := NewWebSearchTool(WebSearchToolOptions{ + GLMSearchEnabled: true, + GLMSearchAPIKey: "bad-key", + GLMSearchBaseURL: server.URL, + GLMSearchEngine: "search_std", + }) + if err != nil { + t.Fatalf("NewWebSearchTool() error: %v", err) + } + + result := tool.Execute(context.Background(), map[string]any{ + "query": "test query", + }) + + if !result.IsError { + t.Errorf("Expected IsError=true for 401 response") + } + if !strings.Contains(result.ForLLM, "status 401") { + t.Errorf("Expected status 401 in error, got: %s", result.ForLLM) + } +} + +func TestWebTool_GLMSearch_Priority(t *testing.T) { + // GLM Search should only be selected when all other providers are disabled + tool, err := NewWebSearchTool(WebSearchToolOptions{ + DuckDuckGoEnabled: true, + DuckDuckGoMaxResults: 5, + GLMSearchEnabled: true, + GLMSearchAPIKey: "test-key", + GLMSearchBaseURL: "https://example.com", + GLMSearchEngine: "search_std", + }) + if err != nil { + t.Fatalf("NewWebSearchTool() error: %v", err) + } + + // DuckDuckGo should win over GLM Search + if _, ok := tool.provider.(*DuckDuckGoSearchProvider); !ok { + t.Errorf("Expected DuckDuckGoSearchProvider when both enabled, got %T", tool.provider) + } + + // With DuckDuckGo disabled, GLM Search should be selected + tool2, err := NewWebSearchTool(WebSearchToolOptions{ + DuckDuckGoEnabled: false, + GLMSearchEnabled: true, + GLMSearchAPIKey: "test-key", + GLMSearchBaseURL: "https://example.com", + GLMSearchEngine: "search_std", + }) + if err != nil { + t.Fatalf("NewWebSearchTool() error: %v", err) + } + if _, ok := tool2.provider.(*GLMSearchProvider); !ok { + t.Errorf("Expected GLMSearchProvider when only GLM enabled, got %T", tool2.provider) + } } diff --git a/pkg/utils/media.go b/pkg/utils/media.go index a34889fb8..3e1c5d88e 100644 --- a/pkg/utils/media.go +++ b/pkg/utils/media.go @@ -3,6 +3,7 @@ package utils import ( "io" "net/http" + "net/url" "os" "path/filepath" "strings" @@ -52,11 +53,12 @@ type DownloadOptions struct { Timeout time.Duration ExtraHeaders map[string]string LoggerPrefix string + ProxyURL string } // DownloadFile downloads a file from URL to a local temp directory. // Returns the local file path or empty string on error. -func DownloadFile(url, filename string, opts DownloadOptions) string { +func DownloadFile(urlStr, filename string, opts DownloadOptions) string { // Set defaults if opts.Timeout == 0 { opts.Timeout = 60 * time.Second @@ -78,7 +80,7 @@ func DownloadFile(url, filename string, opts DownloadOptions) string { localPath := filepath.Join(mediaDir, uuid.New().String()[:8]+"_"+safeName) // Create HTTP request - req, err := http.NewRequest("GET", url, nil) + req, err := http.NewRequest("GET", urlStr, nil) if err != nil { logger.ErrorCF(opts.LoggerPrefix, "Failed to create download request", map[string]any{ "error": err.Error(), @@ -92,11 +94,24 @@ func DownloadFile(url, filename string, opts DownloadOptions) string { } client := &http.Client{Timeout: opts.Timeout} + if opts.ProxyURL != "" { + proxyURL, parseErr := url.Parse(opts.ProxyURL) + if parseErr != nil { + logger.ErrorCF(opts.LoggerPrefix, "Invalid proxy URL for download", map[string]any{ + "error": parseErr.Error(), + "proxy": opts.ProxyURL, + }) + return "" + } + client.Transport = &http.Transport{ + Proxy: http.ProxyURL(proxyURL), + } + } resp, err := client.Do(req) if err != nil { logger.ErrorCF(opts.LoggerPrefix, "Failed to download file", map[string]any{ "error": err.Error(), - "url": url, + "url": urlStr, }) return "" } @@ -105,7 +120,7 @@ func DownloadFile(url, filename string, opts DownloadOptions) string { if resp.StatusCode != http.StatusOK { logger.ErrorCF(opts.LoggerPrefix, "File download returned non-200 status", map[string]any{ "status": resp.StatusCode, - "url": url, + "url": urlStr, }) return "" }