merge: resolve conflicts with upstream/main

Accept upstream versions for all non-Telegram files to keep PR
scope focused on Telegram message chunking only.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
I Putu Eddy Irawan
2026-03-04 22:26:51 +07:00
33 changed files with 2406 additions and 1002 deletions
+46 -32
View File
@@ -18,22 +18,24 @@ import (
// AgentInstance represents a fully configured agent with its own workspace,
// session manager, context builder, and tool registry.
type AgentInstance struct {
ID string
Name string
Model string
Fallbacks []string
Workspace string
MaxIterations int
MaxTokens int
Temperature float64
ContextWindow int
Provider providers.LLMProvider
Sessions *session.SessionManager
ContextBuilder *ContextBuilder
Tools *tools.ToolRegistry
Subagents *config.SubagentsConfig
SkillsFilter []string
Candidates []providers.FallbackCandidate
ID string
Name string
Model string
Fallbacks []string
Workspace string
MaxIterations int
MaxTokens int
Temperature float64
ContextWindow int
SummarizeMessageThreshold int
SummarizeTokenPercent int
Provider providers.LLMProvider
Sessions *session.SessionManager
ContextBuilder *ContextBuilder
Tools *tools.ToolRegistry
Subagents *config.SubagentsConfig
SkillsFilter []string
Candidates []providers.FallbackCandidate
}
// NewAgentInstance creates an agent instance from config.
@@ -101,6 +103,16 @@ func NewAgentInstance(
temperature = *defaults.Temperature
}
summarizeMessageThreshold := defaults.SummarizeMessageThreshold
if summarizeMessageThreshold == 0 {
summarizeMessageThreshold = 20
}
summarizeTokenPercent := defaults.SummarizeTokenPercent
if summarizeTokenPercent == 0 {
summarizeTokenPercent = 75
}
// Resolve fallback candidates
modelCfg := providers.ModelConfig{
Primary: model,
@@ -149,22 +161,24 @@ func NewAgentInstance(
candidates := providers.ResolveCandidatesWithLookup(modelCfg, defaults.Provider, resolveFromModelList)
return &AgentInstance{
ID: agentID,
Name: agentName,
Model: model,
Fallbacks: fallbacks,
Workspace: workspace,
MaxIterations: maxIter,
MaxTokens: maxTokens,
Temperature: temperature,
ContextWindow: maxTokens,
Provider: provider,
Sessions: sessionsManager,
ContextBuilder: contextBuilder,
Tools: toolsRegistry,
Subagents: subagents,
SkillsFilter: skillsFilter,
Candidates: candidates,
ID: agentID,
Name: agentName,
Model: model,
Fallbacks: fallbacks,
Workspace: workspace,
MaxIterations: maxIter,
MaxTokens: maxTokens,
Temperature: temperature,
ContextWindow: maxTokens,
SummarizeMessageThreshold: summarizeMessageThreshold,
SummarizeTokenPercent: summarizeTokenPercent,
Provider: provider,
Sessions: sessionsManager,
ContextBuilder: contextBuilder,
Tools: toolsRegistry,
Subagents: subagents,
SkillsFilter: skillsFilter,
Candidates: candidates,
}
}
+66 -57
View File
@@ -118,9 +118,11 @@ func registerSharedTools(
PerplexityAPIKey: cfg.Tools.Web.Perplexity.APIKey,
PerplexityMaxResults: cfg.Tools.Web.Perplexity.MaxResults,
PerplexityEnabled: cfg.Tools.Web.Perplexity.Enabled,
ExaAPIKey: cfg.Tools.Web.Exa.APIKey,
ExaMaxResults: cfg.Tools.Web.Exa.MaxResults,
ExaEnabled: cfg.Tools.Web.Exa.Enabled,
GLMSearchAPIKey: cfg.Tools.Web.GLMSearch.APIKey,
GLMSearchBaseURL: cfg.Tools.Web.GLMSearch.BaseURL,
GLMSearchEngine: cfg.Tools.Web.GLMSearch.SearchEngine,
GLMSearchMaxResults: cfg.Tools.Web.GLMSearch.MaxResults,
GLMSearchEnabled: cfg.Tools.Web.GLMSearch.Enabled,
Proxy: cfg.Tools.Web.Proxy,
})
if err != nil {
@@ -967,62 +969,76 @@ func (al *AgentLoop) runLLMIteration(
// Save assistant message with tool calls to session
agent.Sessions.AddFullMessage(opts.SessionKey, assistantMsg)
// Execute tool calls
for _, tc := range normalizedToolCalls {
argsJSON, _ := json.Marshal(tc.Arguments)
argsPreview := utils.Truncate(string(argsJSON), 200)
logger.InfoCF("agent", fmt.Sprintf("Tool call: %s(%s)", tc.Name, argsPreview),
map[string]any{
"agent_id": agent.ID,
"tool": tc.Name,
"iteration": iteration,
})
// Execute tool calls in parallel
type indexedAgentResult struct {
result *tools.ToolResult
tc providers.ToolCall
}
// Create async callback for tools that implement AsyncTool
// NOTE: Following openclaw's design, async tools do NOT send results directly to users.
// Instead, they notify the agent via PublishInbound, and the agent decides
// whether to forward the result to the user (in processSystemMessage).
asyncCallback := func(callbackCtx context.Context, result *tools.ToolResult) {
// Log the async completion but don't send directly to user
// The agent will handle user notification via processSystemMessage
if !result.Silent && result.ForUser != "" {
logger.InfoCF("agent", "Async tool completed, agent will handle notification",
map[string]any{
"tool": tc.Name,
"content_len": len(result.ForUser),
})
agentResults := make([]indexedAgentResult, len(normalizedToolCalls))
var wg sync.WaitGroup
for i, tc := range normalizedToolCalls {
agentResults[i].tc = tc
wg.Add(1)
go func(idx int, tc providers.ToolCall) {
defer wg.Done()
argsJSON, _ := json.Marshal(tc.Arguments)
argsPreview := utils.Truncate(string(argsJSON), 200)
logger.InfoCF("agent", fmt.Sprintf("Tool call: %s(%s)", tc.Name, argsPreview),
map[string]any{
"agent_id": agent.ID,
"tool": tc.Name,
"iteration": iteration,
})
// Create async callback for tools that implement AsyncTool
asyncCallback := func(callbackCtx context.Context, result *tools.ToolResult) {
if !result.Silent && result.ForUser != "" {
logger.InfoCF("agent", "Async tool completed, agent will handle notification",
map[string]any{
"tool": tc.Name,
"content_len": len(result.ForUser),
})
}
}
}
toolResult := agent.Tools.ExecuteWithContext(
ctx,
tc.Name,
tc.Arguments,
opts.Channel,
opts.ChatID,
asyncCallback,
)
toolResult := agent.Tools.ExecuteWithContext(
ctx,
tc.Name,
tc.Arguments,
opts.Channel,
opts.ChatID,
asyncCallback,
)
agentResults[idx].result = toolResult
}(i, tc)
}
wg.Wait()
// Process results in original order (send to user, save to session)
for _, r := range agentResults {
// Send ForUser content to user immediately if not Silent
if !toolResult.Silent && toolResult.ForUser != "" && opts.SendResponse {
if !r.result.Silent && r.result.ForUser != "" && opts.SendResponse {
al.bus.PublishOutbound(ctx, bus.OutboundMessage{
Channel: opts.Channel,
ChatID: opts.ChatID,
Content: toolResult.ForUser,
Content: r.result.ForUser,
})
logger.DebugCF("agent", "Sent tool result to user",
map[string]any{
"tool": tc.Name,
"content_len": len(toolResult.ForUser),
"tool": r.tc.Name,
"content_len": len(r.result.ForUser),
})
}
// If tool returned media refs, publish them as outbound media
if len(toolResult.Media) > 0 && opts.SendResponse {
parts := make([]bus.MediaPart, 0, len(toolResult.Media))
for _, ref := range toolResult.Media {
if len(r.result.Media) > 0 && opts.SendResponse {
parts := make([]bus.MediaPart, 0, len(r.result.Media))
for _, ref := range r.result.Media {
part := bus.MediaPart{Ref: ref}
// Populate metadata from MediaStore when available
if al.mediaStore != nil {
if _, meta, err := al.mediaStore.ResolveWithMeta(ref); err == nil {
part.Filename = meta.Filename
@@ -1040,15 +1056,15 @@ func (al *AgentLoop) runLLMIteration(
}
// Determine content for LLM based on tool result
contentForLLM := toolResult.ForLLM
if contentForLLM == "" && toolResult.Err != nil {
contentForLLM = toolResult.Err.Error()
contentForLLM := r.result.ForLLM
if contentForLLM == "" && r.result.Err != nil {
contentForLLM = r.result.Err.Error()
}
toolResultMsg := providers.Message{
Role: "tool",
Content: contentForLLM,
ToolCallID: tc.ID,
ToolCallID: r.tc.ID,
}
messages = append(messages, toolResultMsg)
@@ -1084,9 +1100,9 @@ func (al *AgentLoop) updateToolContexts(agent *AgentInstance, channel, chatID st
func (al *AgentLoop) maybeSummarize(agent *AgentInstance, sessionKey, channel, chatID string) {
newHistory := agent.Sessions.GetHistory(sessionKey)
tokenEstimate := al.estimateTokens(newHistory)
threshold := agent.ContextWindow * 75 / 100
threshold := agent.ContextWindow * agent.SummarizeTokenPercent / 100
if len(newHistory) > 20 || tokenEstimate > threshold {
if len(newHistory) > agent.SummarizeMessageThreshold || tokenEstimate > threshold {
summarizeKey := agent.ID + ":" + sessionKey
if _, loading := al.summarizing.LoadOrStore(summarizeKey, true); !loading {
go func() {
@@ -1114,15 +1130,8 @@ func (al *AgentLoop) forceCompression(agent *AgentInstance, sessionKey string) {
return
}
// Find the mid-point of the conversation, avoiding splitting tool call/result pairs.
// A tool-call message (role=assistant with ToolCalls) must be followed by its
// tool-result message (role=tool). Splitting between them causes API errors.
// Helper to find the mid-point of the conversation
mid := len(conversation) / 2
if mid < len(conversation) && mid > 0 {
if conversation[mid].Role == "tool" {
mid++ // move past the tool result to keep the pair together
}
}
// New history structure:
// 1. System Prompt (with compression note appended)
-79
View File
@@ -603,85 +603,6 @@ func TestAgentLoop_ContextExhaustionRetry(t *testing.T) {
}
}
// TestForceCompression_ToolMessageBoundary verifies that forceCompression does not
// split a tool call/result pair when the midpoint falls on a "tool" role message.
// Regression test for: API errors when orphaned tool result messages appear
// without their preceding assistant tool-call message.
// TestForceCompression_ToolMessageBoundary checks that forceCompression never
// leaves a "tool" result message without the assistant tool-call message that
// produced it, for a history whose conversation midpoint lands on a tool
// result. It also checks that the compression note is added to the system
// prompt.
func TestForceCompression_ToolMessageBoundary(t *testing.T) {
	workDir, err := os.MkdirTemp("", "agent-test-*")
	if err != nil {
		t.Fatalf("Failed to create temp dir: %v", err)
	}
	defer os.RemoveAll(workDir)

	defaults := config.AgentDefaults{
		Workspace:         workDir,
		Model:             "test-model",
		MaxTokens:         4096,
		MaxToolIterations: 10,
	}
	cfg := &config.Config{
		Agents: config.AgentsConfig{Defaults: defaults},
	}

	msgBus := bus.NewMessageBus()
	provider := &mockProvider{}
	al := NewAgentLoop(cfg, msgBus, provider)

	const sessionKey = "test-session-tool-boundary"

	defaultAgent := al.registry.GetDefaultAgent()
	if defaultAgent == nil {
		t.Fatal("No default agent found")
	}

	// Layout chosen so that len(conversation)/2 lands exactly on the tool result:
	//   history      = [system, user, assistant(tool_call), tool, user, assistant, user_trigger]
	//   conversation = history[1:6] = [user, assistant(tool_call), tool, user, assistant]
	//   mid          = 5/2 = 2  => conversation[2].Role == "tool"
	toolCall := providers.ToolCall{
		ID:        "call_1",
		Name:      "exec",
		Arguments: map[string]any{"command": "ls"},
	}
	history := []providers.Message{
		{Role: "system", Content: "You are a helpful assistant."},
		{Role: "user", Content: "What files are in the current directory?"},
		{Role: "assistant", Content: "", ToolCalls: []providers.ToolCall{toolCall}},
		{Role: "tool", Content: "file1.txt\nfile2.txt", ToolCallID: "call_1"},
		{Role: "user", Content: "Tell me about file1.txt"},
		{Role: "assistant", Content: "file1.txt is a text file."},
		{Role: "user", Content: "Thanks"}, // trigger message
	}

	// AddMessage creates the session entry; SetHistory then replaces its
	// contents with the prepared history.
	defaultAgent.Sessions.AddMessage(sessionKey, "system", "init")
	defaultAgent.Sessions.SetHistory(sessionKey, history)

	al.forceCompression(defaultAgent, sessionKey)

	compressed := defaultAgent.Sessions.GetHistory(sessionKey)

	// Every tool result after the system prompt must directly follow an
	// assistant message that carries tool calls.
	for i := 1; i < len(compressed); i++ {
		if compressed[i].Role != "tool" {
			continue
		}
		if i == 1 {
			t.Errorf("Tool result message at position %d is orphaned (no preceding assistant with tool call)", i)
			continue
		}
		if compressed[i-1].Role == "assistant" && len(compressed[i-1].ToolCalls) != 0 {
			continue
		}
		t.Errorf("Tool result at position %d is not preceded by assistant with tool calls (preceded by role=%q)", i, compressed[i-1].Role)
	}

	// The system prompt must carry the compression note.
	if !strings.Contains(compressed[0].Content, "Emergency compression") {
		t.Errorf("Expected compression note in system prompt, got: %s", compressed[0].Content)
	}
}
func TestTargetReasoningChannelID_AllChannels(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "agent-test-*")
if err != nil {
+40
View File
@@ -3,12 +3,15 @@ package discord
import (
"context"
"fmt"
"net/http"
"net/url"
"os"
"strings"
"sync"
"time"
"github.com/bwmarrin/discordgo"
"github.com/gorilla/websocket"
"github.com/sipeed/picoclaw/pkg/bus"
"github.com/sipeed/picoclaw/pkg/channels"
@@ -40,6 +43,9 @@ func NewDiscordChannel(cfg config.DiscordConfig, bus *bus.MessageBus) (*DiscordC
return nil, fmt.Errorf("failed to create discord session: %w", err)
}
if err := applyDiscordProxy(session, cfg.Proxy); err != nil {
return nil, err
}
base := channels.NewBaseChannel("discord", cfg, bus, cfg.AllowFrom,
channels.WithMaxMessageLength(2000),
channels.WithGroupTrigger(cfg.GroupTrigger),
@@ -465,9 +471,43 @@ func (c *DiscordChannel) StartTyping(ctx context.Context, chatID string) (func()
// downloadAttachment hands the attachment URL and filename to
// utils.DownloadFile, using "discord" as the logger prefix and the channel's
// configured proxy, and returns that call's result unchanged.
func (c *DiscordChannel) downloadAttachment(url, filename string) string {
	opts := utils.DownloadOptions{
		LoggerPrefix: "discord",
		ProxyURL:     c.config.Proxy,
	}
	return utils.DownloadFile(url, filename, opts)
}
// applyDiscordProxy routes the session's REST client and its gateway
// websocket dialer through a proxy.
//
// If proxyAddr is non-empty it must be an absolute URL with both a scheme and
// a host (for example "http://127.0.0.1:7890"); anything else yields an error.
// If proxyAddr is empty, the proxy is taken from the environment when any of
// HTTP_PROXY, HTTPS_PROXY or their lowercase forms is set. When neither source
// provides a proxy the session is left untouched and nil is returned.
func applyDiscordProxy(session *discordgo.Session, proxyAddr string) error {
	var proxyFunc func(*http.Request) (*url.URL, error)
	if proxyAddr != "" {
		proxyURL, err := url.Parse(proxyAddr)
		if err != nil {
			return fmt.Errorf("invalid discord proxy URL %q: %w", proxyAddr, err)
		}
		// url.Parse accepts inputs such as "localhost:7890" (scheme "localhost",
		// no host). Reject them here instead of failing later on every request.
		if proxyURL.Scheme == "" || proxyURL.Host == "" {
			return fmt.Errorf("invalid discord proxy URL %q: scheme and host are required", proxyAddr)
		}
		proxyFunc = http.ProxyURL(proxyURL)
	} else if discordProxyEnvConfigured() {
		proxyFunc = http.ProxyFromEnvironment
	}

	if proxyFunc == nil {
		return nil
	}

	// Start from a copy of the default transport when it has the standard type
	// so the proxied client keeps its dial/TLS timeouts and pooling settings.
	transport := &http.Transport{Proxy: proxyFunc}
	if base, ok := http.DefaultTransport.(*http.Transport); ok {
		transport = base.Clone()
		transport.Proxy = proxyFunc
	}
	session.Client = &http.Client{
		Timeout:   sendTimeout,
		Transport: transport,
	}

	// Copy the existing dialer rather than mutating it in place, so a dialer
	// value shared with other code is not modified.
	if session.Dialer != nil {
		dialerCopy := *session.Dialer
		dialerCopy.Proxy = proxyFunc
		session.Dialer = &dialerCopy
	} else {
		session.Dialer = &websocket.Dialer{Proxy: proxyFunc}
	}

	return nil
}

// discordProxyEnvConfigured reports whether any of the HTTP(S) proxy
// environment variables is set to a non-empty value, in either letter case.
func discordProxyEnvConfigured() bool {
	for _, key := range []string{"HTTP_PROXY", "HTTPS_PROXY", "http_proxy", "https_proxy"} {
		if os.Getenv(key) != "" {
			return true
		}
	}
	return false
}
// stripBotMention removes the bot mention from the message content.
// Discord mentions have the format <@USER_ID> or <@!USER_ID> (with nickname).
func (c *DiscordChannel) stripBotMention(text string) string {
+91
View File
@@ -0,0 +1,91 @@
package discord
import (
"net/http"
"net/url"
"testing"
"github.com/bwmarrin/discordgo"
)
// TestApplyDiscordProxy_CustomProxy verifies that an explicit proxy address is
// installed on both the session's REST transport and its websocket dialer.
// Unexpected shapes (wrong transport type, missing proxy func, nil proxy URL)
// are reported as test failures rather than panics.
func TestApplyDiscordProxy_CustomProxy(t *testing.T) {
	const proxyAddr = "http://127.0.0.1:7890"

	session, err := discordgo.New("Bot test-token")
	if err != nil {
		t.Fatalf("discordgo.New() error: %v", err)
	}

	if err = applyDiscordProxy(session, proxyAddr); err != nil {
		t.Fatalf("applyDiscordProxy() error: %v", err)
	}

	req, err := http.NewRequest("GET", "https://discord.com/api/v10/gateway", nil)
	if err != nil {
		t.Fatalf("http.NewRequest() error: %v", err)
	}

	// REST side: the client must use an *http.Transport with a proxy func.
	if session.Client == nil {
		t.Fatal("session.Client is nil after applyDiscordProxy()")
	}
	transport, ok := session.Client.Transport.(*http.Transport)
	if !ok || transport.Proxy == nil {
		t.Fatalf("session.Client.Transport = %T, want *http.Transport with Proxy set", session.Client.Transport)
	}
	restProxyURL, err := transport.Proxy(req)
	if err != nil {
		t.Fatalf("rest proxy func error: %v", err)
	}
	if restProxyURL == nil {
		t.Fatalf("REST proxy = <nil>, want %q", proxyAddr)
	}
	if got, want := restProxyURL.String(), proxyAddr; got != want {
		t.Fatalf("REST proxy = %q, want %q", got, want)
	}

	// Websocket side: the dialer must carry the same proxy.
	if session.Dialer == nil || session.Dialer.Proxy == nil {
		t.Fatal("session.Dialer has no proxy func after applyDiscordProxy()")
	}
	wsProxyURL, err := session.Dialer.Proxy(req)
	if err != nil {
		t.Fatalf("ws proxy func error: %v", err)
	}
	if wsProxyURL == nil {
		t.Fatalf("WS proxy = <nil>, want %q", proxyAddr)
	}
	if got, want := wsProxyURL.String(), proxyAddr; got != want {
		t.Fatalf("WS proxy = %q, want %q", got, want)
	}
}
// TestApplyDiscordProxy_FromEnvironment verifies that, with no explicit proxy
// address, the websocket dialer resolves its proxy from the HTTP(S)_PROXY
// environment variables.
//
// NOTE(review): net/http appears to read the proxy environment only once per
// process; confirm no earlier test in this package triggers that lookup first.
func TestApplyDiscordProxy_FromEnvironment(t *testing.T) {
	const envProxy = "http://127.0.0.1:8888"

	// Pin every proxy-related variable so the host environment cannot leak in.
	for _, kv := range [][2]string{
		{"HTTP_PROXY", envProxy},
		{"http_proxy", envProxy},
		{"HTTPS_PROXY", envProxy},
		{"https_proxy", envProxy},
		{"ALL_PROXY", ""},
		{"all_proxy", ""},
		{"NO_PROXY", ""},
		{"no_proxy", ""},
	} {
		t.Setenv(kv[0], kv[1])
	}

	session, err := discordgo.New("Bot test-token")
	if err != nil {
		t.Fatalf("discordgo.New() error: %v", err)
	}

	err = applyDiscordProxy(session, "")
	if err != nil {
		t.Fatalf("applyDiscordProxy() error: %v", err)
	}

	req, err := http.NewRequest("GET", "https://discord.com/api/v10/gateway", nil)
	if err != nil {
		t.Fatalf("http.NewRequest() error: %v", err)
	}

	gotURL, err := session.Dialer.Proxy(req)
	if err != nil {
		t.Fatalf("ws proxy func error: %v", err)
	}

	wantURL, err := url.Parse(envProxy)
	if err != nil {
		t.Fatalf("url.Parse() error: %v", err)
	}

	got, want := gotURL.String(), wantURL.String()
	if got != want {
		t.Fatalf("WS proxy = %q, want %q", got, want)
	}
}
// TestApplyDiscordProxy_InvalidProxyURL verifies that an unparsable proxy
// address makes applyDiscordProxy return an error.
func TestApplyDiscordProxy_InvalidProxyURL(t *testing.T) {
	session, newErr := discordgo.New("Bot test-token")
	if newErr != nil {
		t.Fatalf("discordgo.New() error: %v", newErr)
	}

	applyErr := applyDiscordProxy(session, "://bad-proxy")
	if applyErr != nil {
		return
	}
	t.Fatal("applyDiscordProxy() expected error for invalid proxy URL, got nil")
}
+13 -85
View File
@@ -3,23 +3,14 @@ package config
import (
"encoding/json"
"fmt"
"log"
"os"
"path/filepath"
"sync"
"sync/atomic"
"github.com/caarlos0/env/v11"
"github.com/joho/godotenv"
"github.com/sipeed/picoclaw/pkg/fileutil"
)
// dotenvOnce ensures .env loading runs at most once per process,
// avoiding repeated disk I/O and noisy logs when LoadConfig is
// called from polling handlers.
var dotenvOnce sync.Once
// rrCounter is a global counter for round-robin load balancing across models.
var rrCounter atomic.Uint64
@@ -189,6 +180,8 @@ type AgentDefaults struct {
MaxTokens int `json:"max_tokens" env:"PICOCLAW_AGENTS_DEFAULTS_MAX_TOKENS"`
Temperature *float64 `json:"temperature,omitempty" env:"PICOCLAW_AGENTS_DEFAULTS_TEMPERATURE"`
MaxToolIterations int `json:"max_tool_iterations" env:"PICOCLAW_AGENTS_DEFAULTS_MAX_TOOL_ITERATIONS"`
SummarizeMessageThreshold int `json:"summarize_message_threshold" env:"PICOCLAW_AGENTS_DEFAULTS_SUMMARIZE_MESSAGE_THRESHOLD"`
SummarizeTokenPercent int `json:"summarize_token_percent" env:"PICOCLAW_AGENTS_DEFAULTS_SUMMARIZE_TOKEN_PERCENT"`
MaxMediaSize int `json:"max_media_size,omitempty" env:"PICOCLAW_AGENTS_DEFAULTS_MAX_MEDIA_SIZE"`
}
@@ -280,6 +273,7 @@ type FeishuConfig struct {
type DiscordConfig struct {
Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_DISCORD_ENABLED"`
Token string `json:"token" env:"PICOCLAW_CHANNELS_DISCORD_TOKEN"`
Proxy string `json:"proxy" env:"PICOCLAW_CHANNELS_DISCORD_PROXY"`
AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_DISCORD_ALLOW_FROM"`
MentionOnly bool `json:"mention_only" env:"PICOCLAW_CHANNELS_DISCORD_MENTION_ONLY"`
GroupTrigger GroupTriggerConfig `json:"group_trigger,omitempty"`
@@ -437,7 +431,6 @@ type ProvidersConfig struct {
Antigravity ProviderConfig `json:"antigravity"`
Qwen ProviderConfig `json:"qwen"`
Mistral ProviderConfig `json:"mistral"`
Opencode ProviderConfig `json:"opencode"`
}
// IsEmpty checks if all provider configs are empty (no API keys or API bases set)
@@ -461,8 +454,7 @@ func (p ProvidersConfig) IsEmpty() bool {
p.GitHubCopilot.APIKey == "" && p.GitHubCopilot.APIBase == "" &&
p.Antigravity.APIKey == "" && p.Antigravity.APIBase == "" &&
p.Qwen.APIKey == "" && p.Qwen.APIBase == "" &&
p.Mistral.APIKey == "" && p.Mistral.APIBase == "" &&
p.Opencode.APIKey == "" && p.Opencode.APIBase == ""
p.Mistral.APIKey == "" && p.Mistral.APIBase == ""
}
// MarshalJSON implements custom JSON marshaling for ProvidersConfig
@@ -555,10 +547,14 @@ type PerplexityConfig struct {
MaxResults int `json:"max_results" env:"PICOCLAW_TOOLS_WEB_PERPLEXITY_MAX_RESULTS"`
}
type ExaConfig struct {
Enabled bool `json:"enabled" env:"PICOCLAW_TOOLS_WEB_EXA_ENABLED"`
APIKey string `json:"api_key" env:"PICOCLAW_TOOLS_WEB_EXA_API_KEY"`
MaxResults int `json:"max_results" env:"PICOCLAW_TOOLS_WEB_EXA_MAX_RESULTS"`
// GLMSearchConfig holds the settings for the "glm_search" entry of the web
// tools configuration. Every field can be supplied through JSON or through
// the environment variable named in its `env` tag.
type GLMSearchConfig struct {
Enabled bool `json:"enabled" env:"PICOCLAW_TOOLS_WEB_GLM_ENABLED"`
APIKey string `json:"api_key" env:"PICOCLAW_TOOLS_WEB_GLM_API_KEY"`
// BaseURL is the search endpoint; DefaultConfig sets it to
// "https://open.bigmodel.cn/api/paas/v4/web_search".
BaseURL string `json:"base_url" env:"PICOCLAW_TOOLS_WEB_GLM_BASE_URL"`
// SearchEngine specifies the search backend: "search_std" (default),
// "search_pro", "search_pro_sogou", or "search_pro_quark".
SearchEngine string `json:"search_engine" env:"PICOCLAW_TOOLS_WEB_GLM_SEARCH_ENGINE"`
// MaxResults defaults to 5 in DefaultConfig.
MaxResults int `json:"max_results" env:"PICOCLAW_TOOLS_WEB_GLM_MAX_RESULTS"`
}
type WebToolsConfig struct {
@@ -566,7 +562,7 @@ type WebToolsConfig struct {
Tavily TavilyConfig `json:"tavily"`
DuckDuckGo DuckDuckGoConfig `json:"duckduckgo"`
Perplexity PerplexityConfig `json:"perplexity"`
Exa ExaConfig `json:"exa"`
GLMSearch GLMSearchConfig `json:"glm_search"`
// Proxy is an optional proxy URL for web tools (http/https/socks5/socks5h).
// For authenticated proxies, prefer HTTP_PROXY/HTTPS_PROXY env vars instead of embedding credentials in config.
Proxy string `json:"proxy,omitempty" env:"PICOCLAW_TOOLS_WEB_PROXY"`
@@ -658,35 +654,9 @@ type MCPConfig struct {
func LoadConfig(path string) (*Config, error) {
cfg := DefaultConfig()
// Load .env file from config directory (secrets, API keys, etc.)
// Guarded by sync.Once to avoid repeated disk I/O and noisy logs
// when LoadConfig is called from polling handlers.
dotenvOnce.Do(func() {
envFile := filepath.Join(filepath.Dir(path), ".env")
if err := godotenv.Load(envFile); err != nil {
if os.IsNotExist(err) {
log.Printf("[INFO] No .env file found at %s; skipping .env loading", envFile)
} else {
log.Printf("[WARN] Failed to load .env file from %s: %v", envFile, err)
}
}
})
data, err := os.ReadFile(path)
if err != nil {
if os.IsNotExist(err) {
// No config file — still apply env vars + overrides to default config
if err := env.Parse(cfg); err != nil {
return nil, err
}
loadProviderEnvOverrides(cfg)
cfg.migrateChannelConfigs()
if cfg.HasProvidersConfig() {
cfg.ModelList = ConvertProvidersToModelList(cfg)
}
if err := cfg.ValidateModelList(); err != nil {
return nil, err
}
return cfg, nil
}
return nil, err
@@ -714,9 +684,6 @@ func LoadConfig(path string) (*Config, error) {
return nil, err
}
// Load provider-specific env overrides (PICOCLAW_PROVIDERS_<NAME>_API_KEY, etc.)
loadProviderEnvOverrides(cfg)
// Migrate legacy channel config fields to new unified structures
cfg.migrateChannelConfigs()
@@ -865,42 +832,3 @@ func (c *Config) ValidateModelList() error {
}
return nil
}
// loadProviderEnvOverrides reads PICOCLAW_PROVIDERS_<NAME>_API_KEY and _API_BASE
// environment variables and sets them on the corresponding provider config fields.
// This enables storing provider secrets in .env files without using struct tags.
// loadProviderEnvOverrides copies PICOCLAW_PROVIDERS_<NAME>_API_KEY and
// PICOCLAW_PROVIDERS_<NAME>_API_BASE environment variables onto the matching
// provider fields of cfg. A variable that is set but empty still overrides the
// field (os.LookupEnv is used, not os.Getenv); an unset variable leaves the
// field as it was.
func loadProviderEnvOverrides(cfg *Config) {
	const envPrefix = "PICOCLAW_PROVIDERS_"

	// overrideFromEnv writes the variable's value into dst only when the
	// variable is present in the environment.
	overrideFromEnv := func(envName string, dst *string) {
		if value, present := os.LookupEnv(envName); present {
			*dst = value
		}
	}

	p := &cfg.Providers
	targets := []struct {
		tag     string
		keyDst  *string
		baseDst *string
	}{
		{"ANTHROPIC", &p.Anthropic.APIKey, &p.Anthropic.APIBase},
		{"OPENAI", &p.OpenAI.APIKey, &p.OpenAI.APIBase},
		{"LITELLM", &p.LiteLLM.APIKey, &p.LiteLLM.APIBase},
		{"OPENROUTER", &p.OpenRouter.APIKey, &p.OpenRouter.APIBase},
		{"GROQ", &p.Groq.APIKey, &p.Groq.APIBase},
		{"ZHIPU", &p.Zhipu.APIKey, &p.Zhipu.APIBase},
		{"GEMINI", &p.Gemini.APIKey, &p.Gemini.APIBase},
		{"NVIDIA", &p.Nvidia.APIKey, &p.Nvidia.APIBase},
		{"OLLAMA", &p.Ollama.APIKey, &p.Ollama.APIBase},
		{"MOONSHOT", &p.Moonshot.APIKey, &p.Moonshot.APIBase},
		{"SHENGSUANYUN", &p.ShengSuanYun.APIKey, &p.ShengSuanYun.APIBase},
		{"DEEPSEEK", &p.DeepSeek.APIKey, &p.DeepSeek.APIBase},
		{"MISTRAL", &p.Mistral.APIKey, &p.Mistral.APIBase},
		{"VLLM", &p.VLLM.APIKey, &p.VLLM.APIBase},
		{"CEREBRAS", &p.Cerebras.APIKey, &p.Cerebras.APIBase},
		{"VOLCENGINE", &p.VolcEngine.APIKey, &p.VolcEngine.APIBase},
		{"QWEN", &p.Qwen.APIKey, &p.Qwen.APIBase},
		// Note: GitHubCopilot and Antigravity use different auth patterns (ConnectMode/AuthMethod),
		// not standard APIKey/APIBase, so they are not included here.
	}

	for i := range targets {
		entry := targets[i]
		overrideFromEnv(envPrefix+entry.tag+"_API_KEY", entry.keyDst)
		overrideFromEnv(envPrefix+entry.tag+"_API_BASE", entry.baseDst)
	}
}
+12 -96
View File
@@ -6,7 +6,6 @@ import (
"path/filepath"
"runtime"
"strings"
"sync"
"testing"
)
@@ -436,6 +435,18 @@ func TestLoadConfig_WebToolsProxy(t *testing.T) {
}
// TestDefaultConfig_SummarizationThresholds verifies that DefaultConfig sets
// the summarization defaults: a message threshold of 20 and a token percent
// of 75.
func TestDefaultConfig_SummarizationThresholds(t *testing.T) {
	cfg := DefaultConfig()

	if cfg.Agents.Defaults.SummarizeMessageThreshold != 20 {
		t.Errorf("SummarizeMessageThreshold = %d, want 20", cfg.Agents.Defaults.SummarizeMessageThreshold)
	}
	if cfg.Agents.Defaults.SummarizeTokenPercent != 75 {
		t.Errorf("SummarizeTokenPercent = %d, want 75", cfg.Agents.Defaults.SummarizeTokenPercent)
	}
}

// TestDefaultConfig_DMScope verifies the default dm_scope value
func TestDefaultConfig_DMScope(t *testing.T) {
cfg := DefaultConfig()
@@ -468,98 +479,3 @@ func TestDefaultConfig_WorkspacePath_WithPicoclawHome(t *testing.T) {
t.Errorf("Workspace path with PICOCLAW_HOME = %q, want %q", cfg.Agents.Defaults.Workspace, want)
}
}
// TestLoadConfig_DotenvFileLoaded verifies that LoadConfig picks up a provider
// API key from a .env file placed next to config.json.
func TestLoadConfig_DotenvFileLoaded(t *testing.T) {
	const (
		envKey = "PICOCLAW_PROVIDERS_OPENAI_API_KEY"
		want   = "sk-from-dotenv"
	)

	// Reset the guard so the .env loading step runs again for this test.
	dotenvOnce = sync.Once{}

	tmp := t.TempDir()
	cfgPath := filepath.Join(tmp, "config.json")

	// Minimal config.json.
	if writeErr := os.WriteFile(cfgPath, []byte(`{}`), 0o600); writeErr != nil {
		t.Fatalf("WriteFile config: %v", writeErr)
	}

	// .env carrying the provider API key.
	dotenvPath := filepath.Join(tmp, ".env")
	if writeErr := os.WriteFile(dotenvPath, []byte(envKey+"="+want+"\n"), 0o600); writeErr != nil {
		t.Fatalf("WriteFile .env: %v", writeErr)
	}

	// Make sure the value can only come from .env: t.Setenv registers the
	// restore on cleanup, then the variable is removed for the test's duration.
	t.Setenv(envKey, "")
	os.Unsetenv(envKey)

	cfg, err := LoadConfig(cfgPath)
	if err != nil {
		t.Fatalf("LoadConfig() error: %v", err)
	}

	if got := cfg.Providers.OpenAI.APIKey; got != want {
		t.Errorf("OpenAI.APIKey = %q, want %q", got, want)
	}
}
// TestLoadConfig_MissingConfigJSON_AppliesEnvVars verifies that provider
// environment variables are still applied when config.json does not exist.
func TestLoadConfig_MissingConfigJSON_AppliesEnvVars(t *testing.T) {
	const want = "sk-anthropic-test"

	// Reset the guard so the .env loading step runs again for this test.
	dotenvOnce = sync.Once{}

	// The path points into a fresh temp dir; the file itself is never created.
	missingPath := filepath.Join(t.TempDir(), "config.json")

	t.Setenv("PICOCLAW_PROVIDERS_ANTHROPIC_API_KEY", want)

	cfg, err := LoadConfig(missingPath)
	if err != nil {
		t.Fatalf("LoadConfig() error: %v", err)
	}

	if got := cfg.Providers.Anthropic.APIKey; got != want {
		t.Errorf("Anthropic.APIKey = %q, want %q", got, want)
	}
}
// TestLoadConfig_MalformedDotenv_NonFatal verifies that LoadConfig still
// returns a usable config when the neighbouring .env file has malformed
// content.
func TestLoadConfig_MalformedDotenv_NonFatal(t *testing.T) {
	// Reset the guard so the .env loading step runs again for this test.
	dotenvOnce = sync.Once{}

	workDir := t.TempDir()
	cfgPath := filepath.Join(workDir, "config.json")

	// Minimal config.json.
	if writeErr := os.WriteFile(cfgPath, []byte(`{}`), 0o600); writeErr != nil {
		t.Fatalf("WriteFile config: %v", writeErr)
	}

	// A bare key without '=' mixed with a valid line.
	malformed := []byte("THIS_LINE_HAS_NO_EQUALS\nVALID_KEY=valid_value\n")
	if writeErr := os.WriteFile(filepath.Join(workDir, ".env"), malformed, 0o600); writeErr != nil {
		t.Fatalf("WriteFile .env: %v", writeErr)
	}

	// Problems with .env must not surface as a LoadConfig failure.
	cfg, err := LoadConfig(cfgPath)
	switch {
	case err != nil:
		t.Fatalf("LoadConfig() should not fail with .env issues, got error: %v", err)
	case cfg == nil:
		t.Fatal("LoadConfig() returned nil config")
	}
}
// TestLoadProviderEnvOverrides_LookupEnv verifies that an environment variable
// that is set but empty clears the corresponding provider field.
func TestLoadProviderEnvOverrides_LookupEnv(t *testing.T) {
	cfg := DefaultConfig()

	// Start from a non-empty value, then override it with an empty variable.
	apiBase := &cfg.Providers.OpenRouter.APIBase
	*apiBase = "https://original.com"
	t.Setenv("PICOCLAW_PROVIDERS_OPENROUTER_API_BASE", "")

	loadProviderEnvOverrides(cfg)

	// The set-but-empty variable must win over the original value.
	if got := *apiBase; got != "" {
		t.Errorf("OpenRouter.APIBase = %q, want empty (overridden by empty env var)", got)
	}
}
+15 -11
View File
@@ -26,13 +26,15 @@ func DefaultConfig() *Config {
return &Config{
Agents: AgentsConfig{
Defaults: AgentDefaults{
Workspace: workspacePath,
RestrictToWorkspace: true,
Provider: "",
Model: "",
MaxTokens: 32768,
Temperature: nil, // nil means use provider default
MaxToolIterations: 50,
Workspace: workspacePath,
RestrictToWorkspace: true,
Provider: "",
Model: "",
MaxTokens: 32768,
Temperature: nil, // nil means use provider default
MaxToolIterations: 50,
SummarizeMessageThreshold: 20,
SummarizeTokenPercent: 75,
},
},
Bindings: []AgentBinding{},
@@ -341,10 +343,12 @@ func DefaultConfig() *Config {
APIKey: "",
MaxResults: 5,
},
Exa: ExaConfig{
Enabled: false,
APIKey: "",
MaxResults: 5,
GLMSearch: GLMSearchConfig{
Enabled: false,
APIKey: "",
BaseURL: "https://open.bigmodel.cn/api/paas/v4/web_search",
SearchEngine: "search_std",
MaxResults: 5,
},
},
Cron: CronToolsConfig{
+2 -21
View File
@@ -225,7 +225,7 @@ func ConvertProvidersToModelList(cfg *Config) []ModelConfig {
},
},
{
providerNames: []string{"moonshot", "kimi", "kimi-code"},
providerNames: []string{"moonshot", "kimi"},
protocol: "moonshot",
buildConfig: func(p ProvidersConfig) (ModelConfig, bool) {
if p.Moonshot.APIKey == "" && p.Moonshot.APIBase == "" {
@@ -373,23 +373,6 @@ func ConvertProvidersToModelList(cfg *Config) []ModelConfig {
}, true
},
},
{
providerNames: []string{"opencode"},
protocol: "opencode",
buildConfig: func(p ProvidersConfig) (ModelConfig, bool) {
if p.Opencode.APIKey == "" && p.Opencode.APIBase == "" {
return ModelConfig{}, false
}
return ModelConfig{
ModelName: "opencode",
Model: "opencode/auto",
APIKey: p.Opencode.APIKey,
APIBase: p.Opencode.APIBase,
Proxy: p.Opencode.Proxy,
RequestTimeout: p.Opencode.RequestTimeout,
}, true
},
},
}
// Process each provider migration
@@ -401,9 +384,7 @@ func ConvertProvidersToModelList(cfg *Config) []ModelConfig {
// Check if this is the user's configured provider
if slices.Contains(m.providerNames, userProvider) && userModel != "" {
// Use the user's configured model instead of default.
// Also set ModelName so GetModelConfig(userModel) can find this entry.
mc.ModelName = userModel
// Use the user's configured model instead of default
mc.Model = buildModelWithProtocol(m.protocol, userModel)
} else if userProvider == "" && userModel != "" && !legacyModelNameApplied {
// Legacy config: no explicit provider field but model is specified
-129
View File
@@ -160,7 +160,6 @@ func TestConvertProvidersToModelList_AllProviders(t *testing.T) {
Antigravity: ProviderConfig{AuthMethod: "oauth"},
Qwen: ProviderConfig{APIKey: "key17"},
Mistral: ProviderConfig{APIKey: "key18"},
Opencode: ProviderConfig{APIKey: "key19"},
},
}
@@ -580,65 +579,6 @@ func TestBuildModelWithProtocol_DifferentPrefix(t *testing.T) {
}
}
func TestConvertProvidersToModelList_Opencode(t *testing.T) {
cfg := &Config{
Providers: ProvidersConfig{
Opencode: ProviderConfig{
APIKey: "oc-test-key",
APIBase: "https://custom.opencode.ai/v1",
Proxy: "http://proxy:9090",
RequestTimeout: 60,
},
},
}
result := ConvertProvidersToModelList(cfg)
if len(result) != 1 {
t.Fatalf("len(result) = %d, want 1", len(result))
}
mc := result[0]
if mc.ModelName != "opencode" {
t.Errorf("ModelName = %q, want %q", mc.ModelName, "opencode")
}
if mc.Model != "opencode/auto" {
t.Errorf("Model = %q, want %q", mc.Model, "opencode/auto")
}
if mc.APIKey != "oc-test-key" {
t.Errorf("APIKey = %q, want %q", mc.APIKey, "oc-test-key")
}
if mc.APIBase != "https://custom.opencode.ai/v1" {
t.Errorf("APIBase = %q, want %q", mc.APIBase, "https://custom.opencode.ai/v1")
}
if mc.Proxy != "http://proxy:9090" {
t.Errorf("Proxy = %q, want %q", mc.Proxy, "http://proxy:9090")
}
if mc.RequestTimeout != 60 {
t.Errorf("RequestTimeout = %d, want %d", mc.RequestTimeout, 60)
}
}
// TestConvertProvidersToModelList_Opencode_APIBaseOnly verifies that an
// Opencode provider with only an API base (no key) still produces one entry.
func TestConvertProvidersToModelList_Opencode_APIBaseOnly(t *testing.T) {
	cfg := &Config{}
	cfg.Providers.Opencode = ProviderConfig{
		APIBase: "https://custom.opencode.ai/v1",
	}

	result := ConvertProvidersToModelList(cfg)
	if n := len(result); n != 1 {
		t.Fatalf("len(result) = %d, want 1 (APIBase-only should create entry)", n)
	}

	if name := result[0].ModelName; name != "opencode" {
		t.Errorf("ModelName = %q, want %q", name, "opencode")
	}
}
// Test for legacy config with protocol prefix in model name
func TestConvertProvidersToModelList_LegacyModelWithProtocolPrefix(t *testing.T) {
cfg := &Config{
@@ -669,72 +609,3 @@ func TestConvertProvidersToModelList_LegacyModelWithProtocolPrefix(t *testing.T)
t.Errorf("Model = %q, want %q (should not duplicate prefix)", result[0].Model, "openrouter/auto")
}
}
// Test that ModelName is set to the user's configured model when provider matches.
// This ensures GetModelConfig(userModel) can find the migrated entry.
// Regression test for: gateway startup failure when user model differs from provider name.
func TestConvertProvidersToModelList_ModelNameMatchesUserModel(t *testing.T) {
cfg := &Config{
Agents: AgentsConfig{
Defaults: AgentDefaults{
Provider: "moonshot",
Model: "k2p5",
},
},
Providers: ProvidersConfig{
Moonshot: ProviderConfig{APIKey: "sk-kimi-test"},
},
}
result := ConvertProvidersToModelList(cfg)
if len(result) != 1 {
t.Fatalf("len(result) = %d, want 1", len(result))
}
// ModelName must match the user's configured model, not the provider name.
// Without this, GetModelConfig("k2p5") would fail because it would look
// for ModelName == "k2p5" but find ModelName == "moonshot".
if result[0].ModelName != "k2p5" {
t.Errorf("ModelName = %q, want %q (must match user's model for GetModelConfig lookup)", result[0].ModelName, "k2p5")
}
if result[0].Model != "moonshot/k2p5" {
t.Errorf("Model = %q, want %q", result[0].Model, "moonshot/k2p5")
}
// Other providers (not matching the user's configured provider) should keep their provider name
cfg2 := &Config{
Agents: AgentsConfig{
Defaults: AgentDefaults{
Provider: "moonshot",
Model: "k2p5",
},
},
Providers: ProvidersConfig{
OpenAI: OpenAIProviderConfig{ProviderConfig: ProviderConfig{APIKey: "sk-openai"}},
Moonshot: ProviderConfig{APIKey: "sk-kimi-test"},
},
}
result2 := ConvertProvidersToModelList(cfg2)
if len(result2) != 2 {
t.Fatalf("len(result2) = %d, want 2", len(result2))
}
for _, mc := range result2 {
switch {
case mc.APIKey == "sk-openai":
// OpenAI is not the user's provider, should keep default ModelName
if mc.ModelName != "openai" {
t.Errorf("OpenAI ModelName = %q, want %q (non-matching provider keeps default)", mc.ModelName, "openai")
}
case mc.APIKey == "sk-kimi-test":
// Moonshot is the user's provider, ModelName must be the user's model
if mc.ModelName != "k2p5" {
t.Errorf("Moonshot ModelName = %q, want %q (matching provider uses user model)", mc.ModelName, "k2p5")
}
}
}
}
+460
View File
@@ -0,0 +1,460 @@
package memory
import (
"bufio"
"bytes"
"context"
"encoding/json"
"fmt"
"hash/fnv"
"log"
"os"
"path/filepath"
"strings"
"sync"
"time"
"github.com/sipeed/picoclaw/pkg/fileutil"
"github.com/sipeed/picoclaw/pkg/providers"
)
const (
// numLockShards is the fixed number of mutexes used to serialize
// per-session access. Using a sharded array instead of a map keeps
// memory bounded regardless of how many sessions are created over
// the lifetime of the process — important for a long-running daemon.
// Distinct session keys that hash to the same shard share one mutex,
// so operations on those sessions are serialized with each other.
numLockShards = 64
// maxLineSize is the maximum size of a single JSON line in a .jsonl
// file. Tool results (read_file, web search, etc.) can be large, so
// we set a generous limit. The scanner starts at 64 KB and grows
// only as needed up to this cap.
// A line longer than the cap makes the scan stop with an error, which
// readMessages and countLines return to their callers.
maxLineSize = 10 * 1024 * 1024 // 10 MB
)
// sessionMeta holds per-session metadata stored in a .meta.json file.
// It is read with readMeta and persisted with writeMeta.
type sessionMeta struct {
// Key is the session key as passed by the caller (not the sanitized
// filename); readMeta fills it in when no meta file exists yet.
Key string `json:"key"`
// Summary is the text stored by SetSummary and returned by GetSummary.
Summary string `json:"summary"`
// Skip is the number of leading non-empty JSONL lines that GetHistory
// ignores (the logical truncation offset).
Skip int `json:"skip"`
// Count is the recorded number of lines in the JSONL file. addMsg
// increments it, TruncateHistory re-counts it from disk, and
// SetHistory/Compact set it to the number of messages they write.
Count int `json:"count"`
// CreatedAt is set the first time metadata is written for the session.
CreatedAt time.Time `json:"created_at"`
// UpdatedAt is refreshed by every operation that writes the metadata.
UpdatedAt time.Time `json:"updated_at"`
}
// JSONLStore implements Store using append-only JSONL files.
//
// Each session is stored as two files:
//
// {sanitized_key}.jsonl — one JSON-encoded message per line, append-only
// {sanitized_key}.meta.json — session metadata (summary, logical truncation offset)
//
// Messages are never physically deleted from the JSONL file. Instead,
// TruncateHistory records a "skip" offset in the metadata file and
// GetHistory ignores lines before that offset. This keeps all writes
// append-only, which is both fast and crash-safe.
//
// The exceptions are SetHistory and Compact, which replace the JSONL
// file wholesale via rewriteJSONL.
type JSONLStore struct {
// dir is the directory that holds every session's .jsonl and
// .meta.json file.
dir string
// locks is the fixed pool of mutexes handed out by sessionLock; a
// session key always maps to the same entry.
locks [numLockShards]sync.Mutex
}
// NewJSONLStore returns a JSONL-backed store that keeps its session files
// under dir. The directory, including any missing parents, is created with
// mode 0o755; a failure to do so is returned wrapped.
func NewJSONLStore(dir string) (*JSONLStore, error) {
	if mkErr := os.MkdirAll(dir, 0o755); mkErr != nil {
		return nil, fmt.Errorf("memory: create directory: %w", mkErr)
	}
	store := &JSONLStore{dir: dir}
	return store, nil
}
// sessionLock picks the mutex that guards the given session key.
// The key is hashed with 32-bit FNV-1a and reduced modulo numLockShards,
// so the same key always yields the same mutex and the pool never grows
// with the number of sessions.
func (s *JSONLStore) sessionLock(key string) *sync.Mutex {
	hasher := fnv.New32a()
	hasher.Write([]byte(key))
	shard := hasher.Sum32() % numLockShards
	return &s.locks[shard]
}
// jsonlPath returns the location of the session's message log:
// the store directory joined with the sanitized key plus ".jsonl".
func (s *JSONLStore) jsonlPath(key string) string {
	name := sanitizeKey(key) + ".jsonl"
	return filepath.Join(s.dir, name)
}
// metaPath returns the location of the session's metadata file:
// the store directory joined with the sanitized key plus ".meta.json".
func (s *JSONLStore) metaPath(key string) string {
	name := sanitizeKey(key) + ".meta.json"
	return filepath.Join(s.dir, name)
}
// sanitizeKey turns a session key into a filename component by replacing
// every ':' with '_'. It mirrors pkg/session.sanitizeFilename so that
// migration paths match.
//
// The mapping is deliberately lossy: "telegram:123" and "telegram_123"
// end up with the same filename. Colon-separated keys (e.g. from channels)
// are by far the common case, and a reversible encoding such as
// URL-encoding would make file listings and debugging harder.
func sanitizeKey(key string) string {
	if strings.IndexByte(key, ':') < 0 {
		return key
	}
	// ':' is a single ASCII byte, so a byte-wise swap cannot split a
	// multi-byte UTF-8 sequence.
	out := []byte(key)
	for i, c := range out {
		if c == ':' {
			out[i] = '_'
		}
	}
	return string(out)
}
// readMeta loads and decodes the session's .meta.json file.
// A missing file is not an error: a sessionMeta carrying only the key is
// returned instead. Read and decode failures are returned wrapped,
// together with a zero sessionMeta.
func (s *JSONLStore) readMeta(key string) (sessionMeta, error) {
	raw, readErr := os.ReadFile(s.metaPath(key))
	switch {
	case os.IsNotExist(readErr):
		return sessionMeta{Key: key}, nil
	case readErr != nil:
		return sessionMeta{}, fmt.Errorf("memory: read meta: %w", readErr)
	}

	var loaded sessionMeta
	if decErr := json.Unmarshal(raw, &loaded); decErr != nil {
		return sessionMeta{}, fmt.Errorf("memory: decode meta: %w", decErr)
	}
	return loaded, nil
}
// writeMeta encodes meta as indented JSON and stores it at the session's
// metadata path through the project's standard WriteFileAtomic
// (temp + fsync + rename), with file mode 0o644.
func (s *JSONLStore) writeMeta(key string, meta sessionMeta) error {
	encoded, encErr := json.MarshalIndent(meta, "", " ")
	if encErr != nil {
		return fmt.Errorf("memory: encode meta: %w", encErr)
	}
	target := s.metaPath(key)
	return fileutil.WriteFileAtomic(target, encoded, 0o644)
}
// readMessages returns the messages stored in the .jsonl file at path,
// ignoring the first `skip` non-empty lines without decoding them (so
// logically truncated messages cost no json.Unmarshal).
//
// Empty lines are not counted. Lines that fail to decode — typically a
// partial write left behind by a crash — are logged and skipped rather
// than failing the whole read. A missing file yields an empty, non-nil
// slice; open and scan errors are returned wrapped with a nil slice.
func readMessages(path string, skip int) ([]providers.Message, error) {
	file, openErr := os.Open(path)
	if openErr != nil {
		if os.IsNotExist(openErr) {
			return []providers.Message{}, nil
		}
		return nil, fmt.Errorf("memory: open jsonl: %w", openErr)
	}
	defer file.Close()

	sc := bufio.NewScanner(file)
	// Tool results (read_file, web search, etc.) can produce very long
	// lines, so let the buffer grow up to maxLineSize.
	sc.Buffer(make([]byte, 0, 64*1024), maxLineSize)

	out := []providers.Message{}
	for seen := 0; sc.Scan(); {
		raw := sc.Bytes()
		if len(raw) == 0 {
			continue
		}
		if seen++; seen <= skip {
			continue
		}

		var decoded providers.Message
		if jsonErr := json.Unmarshal(raw, &decoded); jsonErr != nil {
			// Standard JSONL recovery: tell the operator that data
			// was dropped, but keep reading the rest of the file.
			log.Printf("memory: skipping corrupt line %d in %s: %v",
				seen, filepath.Base(path), jsonErr)
			continue
		}
		out = append(out, decoded)
	}
	if scanErr := sc.Err(); scanErr != nil {
		return nil, fmt.Errorf("memory: scan jsonl: %w", scanErr)
	}
	return out, nil
}
// countLines counts the total number of non-empty lines in a .jsonl file.
// Used by TruncateHistory to reconcile a stale meta.Count without
// the overhead of unmarshaling every message.
//
// A missing file counts as zero lines. If the file cannot be opened or
// the scan fails (for example because a line exceeds maxLineSize), the
// count is 0 and the error is returned wrapped with the same
// "memory: …" context that readMessages uses, so callers never see a
// partial count next to a non-nil error.
func countLines(path string) (int, error) {
	f, err := os.Open(path)
	if os.IsNotExist(err) {
		return 0, nil
	}
	if err != nil {
		return 0, fmt.Errorf("memory: open jsonl: %w", err)
	}
	defer f.Close()

	n := 0
	scanner := bufio.NewScanner(f)
	scanner.Buffer(make([]byte, 0, 64*1024), maxLineSize)
	for scanner.Scan() {
		if len(scanner.Bytes()) > 0 {
			n++
		}
	}
	if err := scanner.Err(); err != nil {
		return 0, fmt.Errorf("memory: scan jsonl: %w", err)
	}
	return n, nil
}
// AddMessage appends a message consisting only of a role and text content
// to the given session. The context argument is not used.
func (s *JSONLStore) AddMessage(
	_ context.Context, sessionKey, role, content string,
) error {
	msg := providers.Message{Role: role, Content: content}
	return s.addMsg(sessionKey, msg)
}
// AddFullMessage appends msg, with all of its fields, to the given
// session by delegating to addMsg. The context argument is not used.
func (s *JSONLStore) AddFullMessage(
_ context.Context, sessionKey string, msg providers.Message,
) error {
return s.addMsg(sessionKey, msg)
}
// addMsg is the shared implementation for AddMessage and AddFullMessage.
//
// Under the session's shard lock it appends msg to the session's JSONL
// file as one JSON line, fsyncs the file, and then increments Count and
// refreshes UpdatedAt (and CreatedAt on first use) in the metadata file.
//
// If the JSONL file does not currently end in a newline — a previous
// append was cut short, e.g. by a crash or a full disk — a newline is
// written first. Without it the new record would be glued onto the
// partial line and then discarded as corrupt by readMessages, silently
// losing a perfectly valid message.
func (s *JSONLStore) addMsg(sessionKey string, msg providers.Message) error {
	l := s.sessionLock(sessionKey)
	l.Lock()
	defer l.Unlock()

	// Encode the message as a single JSON line.
	line, err := json.Marshal(msg)
	if err != nil {
		return fmt.Errorf("memory: marshal message: %w", err)
	}
	line = append(line, '\n')

	// O_RDWR rather than O_WRONLY: appendJSONLLine has to read the
	// file's last byte. O_APPEND still forces every write to the end.
	f, err := os.OpenFile(
		s.jsonlPath(sessionKey),
		os.O_CREATE|os.O_RDWR|os.O_APPEND,
		0o644,
	)
	if err != nil {
		return fmt.Errorf("memory: open jsonl for append: %w", err)
	}
	if err := appendJSONLLine(f, line); err != nil {
		f.Close()
		return err
	}
	if closeErr := f.Close(); closeErr != nil {
		return fmt.Errorf("memory: close jsonl: %w", closeErr)
	}

	// Update metadata. A crash between the append above and this write
	// leaves meta.Count one short; TruncateHistory re-counts from disk.
	meta, err := s.readMeta(sessionKey)
	if err != nil {
		return err
	}
	now := time.Now()
	if meta.Count == 0 && meta.CreatedAt.IsZero() {
		meta.CreatedAt = now
	}
	meta.Count++
	meta.UpdatedAt = now

	return s.writeMeta(sessionKey, meta)
}

// appendJSONLLine writes line (which must end in '\n') at the end of f and
// flushes it to physical storage. When f is non-empty and its last byte is
// not a newline, a newline is prepended so the record starts on a fresh line.
func appendJSONLLine(f *os.File, line []byte) error {
	info, err := f.Stat()
	if err != nil {
		return fmt.Errorf("memory: stat jsonl: %w", err)
	}
	if size := info.Size(); size > 0 {
		var last [1]byte
		if _, err := f.ReadAt(last[:], size-1); err != nil {
			return fmt.Errorf("memory: read jsonl tail: %w", err)
		}
		if last[0] != '\n' {
			// Keep separator and record in one Write call.
			line = append([]byte{'\n'}, line...)
		}
	}

	if _, err := f.Write(line); err != nil {
		return fmt.Errorf("memory: append message: %w", err)
	}
	// Flush to physical storage before closing. This matches the
	// durability guarantee of writeMeta and rewriteJSONL (which use
	// WriteFileAtomic with fsync). Without Sync, a power loss could
	// leave the append in the kernel page cache only — lost on reboot.
	if err := f.Sync(); err != nil {
		return fmt.Errorf("memory: sync jsonl: %w", err)
	}
	return nil
}
// GetHistory returns the session's active messages: everything in the
// JSONL file after the first meta.Skip lines. The skip offset is handed
// to readMessages so truncated lines are passed over without being
// decoded. The context argument is not used.
func (s *JSONLStore) GetHistory(
	_ context.Context, sessionKey string,
) ([]providers.Message, error) {
	mu := s.sessionLock(sessionKey)
	mu.Lock()
	defer mu.Unlock()

	meta, metaErr := s.readMeta(sessionKey)
	if metaErr != nil {
		return nil, metaErr
	}

	history, readErr := readMessages(s.jsonlPath(sessionKey), meta.Skip)
	if readErr != nil {
		return nil, readErr
	}
	return history, nil
}
// GetSummary returns the summary recorded in the session's metadata, or
// the empty string when none has been stored. The context argument is
// not used.
func (s *JSONLStore) GetSummary(
	_ context.Context, sessionKey string,
) (string, error) {
	mu := s.sessionLock(sessionKey)
	mu.Lock()
	defer mu.Unlock()

	meta, metaErr := s.readMeta(sessionKey)
	if metaErr != nil {
		return "", metaErr
	}
	return meta.Summary, nil
}
// SetSummary stores summary in the session's metadata, refreshing
// UpdatedAt and initializing CreatedAt if it was still unset.
// The context argument is not used.
func (s *JSONLStore) SetSummary(
	_ context.Context, sessionKey, summary string,
) error {
	mu := s.sessionLock(sessionKey)
	mu.Lock()
	defer mu.Unlock()

	meta, metaErr := s.readMeta(sessionKey)
	if metaErr != nil {
		return metaErr
	}

	stamp := time.Now()
	meta.Summary = summary
	meta.UpdatedAt = stamp
	if meta.CreatedAt.IsZero() {
		meta.CreatedAt = stamp
	}
	return s.writeMeta(sessionKey, meta)
}
// TruncateHistory logically drops all but the last keepLast messages of a
// session by moving the skip offset in the metadata; the JSONL file itself
// is not modified. keepLast <= 0 hides every message. If fewer than
// keepLast messages are currently visible, the offset is left unchanged.
// The context argument is not used.
func (s *JSONLStore) TruncateHistory(
	_ context.Context, sessionKey string, keepLast int,
) error {
	mu := s.sessionLock(sessionKey)
	mu.Lock()
	defer mu.Unlock()

	meta, metaErr := s.readMeta(sessionKey)
	if metaErr != nil {
		return metaErr
	}

	// meta.Count can lag behind the file: a crash in addMsg between the
	// JSONL append and the meta write leaves e.g. 101 lines on disk with
	// Count still 100. Re-counting is just a scan (no unmarshal) and this
	// is not a hot path, so the on-disk count is always taken as truth.
	onDisk, cntErr := countLines(s.jsonlPath(sessionKey))
	if cntErr != nil {
		return cntErr
	}
	meta.Count = onDisk

	switch visible := meta.Count - meta.Skip; {
	case keepLast <= 0:
		meta.Skip = meta.Count
	case keepLast < visible:
		meta.Skip = meta.Count - keepLast
	}
	meta.UpdatedAt = time.Now()

	return s.writeMeta(sessionKey, meta)
}
// SetHistory replaces the session's entire message log with history and
// resets the skip offset to zero. The context argument is not used.
func (s *JSONLStore) SetHistory(
	_ context.Context,
	sessionKey string,
	history []providers.Message,
) error {
	mu := s.sessionLock(sessionKey)
	mu.Lock()
	defer mu.Unlock()

	meta, metaErr := s.readMeta(sessionKey)
	if metaErr != nil {
		return metaErr
	}

	stamp := time.Now()
	if meta.CreatedAt.IsZero() {
		meta.CreatedAt = stamp
	}
	meta.UpdatedAt = stamp
	meta.Count = len(history)
	meta.Skip = 0

	// Ordering matters: meta goes first, the JSONL rewrite second. A
	// crash in between leaves Skip=0 alongside the untouched old file,
	// so GetHistory reads from line 1 and returns "too many" messages
	// instead of losing any. The next SetHistory call corrects this.
	if writeErr := s.writeMeta(sessionKey, meta); writeErr != nil {
		return writeErr
	}
	return s.rewriteJSONL(sessionKey, history)
}
// Compact rewrites the session's JSONL file so that it contains only the
// active messages, physically removing the lines hidden by earlier
// TruncateHistory calls and thereby reclaiming their disk space.
//
// Calling it is always safe; when the skip offset is already zero there
// is nothing to drop and the method returns at once. The context
// argument is not used.
func (s *JSONLStore) Compact(
	_ context.Context, sessionKey string,
) error {
	mu := s.sessionLock(sessionKey)
	mu.Lock()
	defer mu.Unlock()

	meta, metaErr := s.readMeta(sessionKey)
	if metaErr != nil {
		return metaErr
	}
	if meta.Skip == 0 {
		return nil
	}

	// Load just the visible tail; skipped lines are not decoded.
	kept, readErr := readMessages(s.jsonlPath(sessionKey), meta.Skip)
	if readErr != nil {
		return readErr
	}

	// Ordering matters: meta goes first, the JSONL rewrite second. If
	// the process dies in between, meta already says Skip=0 while the
	// old, uncompacted file is still intact, so GetHistory reads from
	// line 1 and returns previously-truncated messages rather than
	// losing data. The next Compact or TruncateHistory corrects this.
	meta.UpdatedAt = time.Now()
	meta.Count = len(kept)
	meta.Skip = 0
	if writeErr := s.writeMeta(sessionKey, meta); writeErr != nil {
		return writeErr
	}
	return s.rewriteJSONL(sessionKey, kept)
}
// rewriteJSONL serializes msgs as one JSON document per line and swaps the
// result in for the session's JSONL file through the project's standard
// WriteFileAtomic (temp + fsync + rename), with file mode 0o644.
// A message that cannot be marshaled aborts the rewrite before anything
// is written.
func (s *JSONLStore) rewriteJSONL(
	sessionKey string, msgs []providers.Message,
) error {
	var out bytes.Buffer
	for idx := range msgs {
		encoded, encErr := json.Marshal(msgs[idx])
		if encErr != nil {
			return fmt.Errorf("memory: marshal message %d: %w", idx, encErr)
		}
		out.Write(encoded)
		out.WriteByte('\n')
	}
	target := s.jsonlPath(sessionKey)
	return fileutil.WriteFileAtomic(target, out.Bytes(), 0o644)
}
// Close is a no-op and always returns nil: the store methods shown in
// this file open and close their files within each call, so nothing is
// held open in between.
// NOTE(review): presumably present to satisfy the Store interface — confirm.
func (s *JSONLStore) Close() error {
return nil
}
+835
View File
@@ -0,0 +1,835 @@
package memory
import (
"context"
"os"
"path/filepath"
"sync"
"testing"
"github.com/sipeed/picoclaw/pkg/providers"
)
// newTestStore returns a JSONLStore rooted in a fresh per-test temporary
// directory, aborting the test if the store cannot be created.
func newTestStore(t *testing.T) *JSONLStore {
	t.Helper()
	s, createErr := NewJSONLStore(t.TempDir())
	if createErr != nil {
		t.Fatalf("NewJSONLStore: %v", createErr)
	}
	return s
}
// TestNewJSONLStore_CreatesDirectory checks that NewJSONLStore creates a
// nested target directory that does not exist yet.
func TestNewJSONLStore_CreatesDirectory(t *testing.T) {
	target := filepath.Join(t.TempDir(), "nested", "sessions")

	s, err := NewJSONLStore(target)
	if err != nil {
		t.Fatalf("NewJSONLStore: %v", err)
	}
	defer s.Close()

	switch fi, statErr := os.Stat(target); {
	case statErr != nil:
		t.Fatalf("Stat: %v", statErr)
	case !fi.IsDir():
		t.Errorf("expected directory, got file")
	}
}
// TestAddMessage_BasicRoundtrip writes a user and an assistant message and
// expects GetHistory to return both, unchanged and in order.
func TestAddMessage_BasicRoundtrip(t *testing.T) {
	store := newTestStore(t)
	ctx := context.Background()

	turns := []struct{ role, content string }{
		{"user", "hello"},
		{"assistant", "hi there"},
	}
	for _, turn := range turns {
		if err := store.AddMessage(ctx, "s1", turn.role, turn.content); err != nil {
			t.Fatalf("AddMessage: %v", err)
		}
	}

	got, err := store.GetHistory(ctx, "s1")
	if err != nil {
		t.Fatalf("GetHistory: %v", err)
	}
	if len(got) != 2 {
		t.Fatalf("expected 2 messages, got %d", len(got))
	}
	if first := got[0]; first.Role != "user" || first.Content != "hello" {
		t.Errorf("msg[0] = %+v", first)
	}
	if second := got[1]; second.Role != "assistant" || second.Content != "hi there" {
		t.Errorf("msg[1] = %+v", second)
	}
}
// TestAddMessage_AutoCreatesSession checks that appending to a session
// that has never been seen before implicitly creates it.
func TestAddMessage_AutoCreatesSession(t *testing.T) {
	const key = "new-session"
	store := newTestStore(t)
	ctx := context.Background()

	// No prior setup for this key: the append alone must suffice.
	if err := store.AddMessage(ctx, key, "user", "first message"); err != nil {
		t.Fatalf("AddMessage: %v", err)
	}

	got, err := store.GetHistory(ctx, key)
	if err != nil {
		t.Fatalf("GetHistory: %v", err)
	}
	if n := len(got); n != 1 {
		t.Fatalf("expected 1 message, got %d", n)
	}
}
// TestAddFullMessage_WithToolCalls checks that an assistant message's tool
// calls survive a store/load round trip.
func TestAddFullMessage_WithToolCalls(t *testing.T) {
	store := newTestStore(t)
	ctx := context.Background()

	search := providers.ToolCall{
		ID:   "call_abc",
		Type: "function",
		Function: &providers.FunctionCall{
			Name:      "web_search",
			Arguments: `{"q":"golang jsonl"}`,
		},
	}
	original := providers.Message{
		Role:      "assistant",
		Content:   "Let me search that.",
		ToolCalls: []providers.ToolCall{search},
	}
	if err := store.AddFullMessage(ctx, "tc", original); err != nil {
		t.Fatalf("AddFullMessage: %v", err)
	}

	got, err := store.GetHistory(ctx, "tc")
	if err != nil {
		t.Fatalf("GetHistory: %v", err)
	}
	if len(got) != 1 {
		t.Fatalf("expected 1, got %d", len(got))
	}
	calls := got[0].ToolCalls
	if len(calls) != 1 {
		t.Fatalf("expected 1 tool call, got %d", len(calls))
	}

	loaded := calls[0]
	if loaded.ID != "call_abc" {
		t.Errorf("tool call ID = %q", loaded.ID)
	}
	if fn := loaded.Function; fn == nil || fn.Name != "web_search" {
		t.Errorf("tool call function = %+v", fn)
	}
}
// TestAddFullMessage_ToolCallID checks that the ToolCallID of a tool-result
// message survives a store/load round trip.
func TestAddFullMessage_ToolCallID(t *testing.T) {
	store := newTestStore(t)
	ctx := context.Background()

	toolResult := providers.Message{
		Role:       "tool",
		Content:    "search results here",
		ToolCallID: "call_abc",
	}
	if err := store.AddFullMessage(ctx, "tr", toolResult); err != nil {
		t.Fatalf("AddFullMessage: %v", err)
	}

	got, err := store.GetHistory(ctx, "tr")
	if err != nil {
		t.Fatalf("GetHistory: %v", err)
	}
	if len(got) != 1 {
		t.Fatalf("expected 1, got %d", len(got))
	}
	if id := got[0].ToolCallID; id != "call_abc" {
		t.Errorf("ToolCallID = %q", id)
	}
}
// TestGetHistory_EmptySession checks that an unknown session yields an
// empty but non-nil history and no error.
func TestGetHistory_EmptySession(t *testing.T) {
	store := newTestStore(t)

	got, err := store.GetHistory(context.Background(), "nonexistent")
	switch {
	case err != nil:
		t.Fatalf("GetHistory: %v", err)
	case got == nil:
		t.Fatal("expected non-nil empty slice")
	case len(got) != 0:
		t.Errorf("expected 0 messages, got %d", len(got))
	}
}
// TestGetHistory_Ordering appends the messages "a" through "e" and expects
// GetHistory to return them in insertion order.
func TestGetHistory_Ordering(t *testing.T) {
	const letters = "abcde"
	store := newTestStore(t)
	ctx := context.Background()

	for i, r := range letters {
		if err := store.AddMessage(ctx, "order", "user", string(r)); err != nil {
			t.Fatalf("AddMessage(%d): %v", i, err)
		}
	}

	got, err := store.GetHistory(ctx, "order")
	if err != nil {
		t.Fatalf("GetHistory: %v", err)
	}
	if len(got) != 5 {
		t.Fatalf("expected 5, got %d", len(got))
	}
	for i, r := range letters {
		if want := string(r); got[i].Content != want {
			t.Errorf("msg[%d].Content = %q, want %q", i, got[i].Content, want)
		}
	}
}
// TestSetSummary_GetSummary checks that a session starts without a summary
// and that each SetSummary call is reflected by the next GetSummary.
func TestSetSummary_GetSummary(t *testing.T) {
	store := newTestStore(t)
	ctx := context.Background()

	current := func() string {
		s, err := store.GetSummary(ctx, "s1")
		if err != nil {
			t.Fatalf("GetSummary: %v", err)
		}
		return s
	}

	// Nothing stored yet.
	if got := current(); got != "" {
		t.Errorf("expected empty, got %q", got)
	}

	// First value sets the summary, second one overwrites it.
	for _, want := range []string{"talked about Go", "updated summary"} {
		if err := store.SetSummary(ctx, "s1", want); err != nil {
			t.Fatalf("SetSummary: %v", err)
		}
		if got := current(); got != want {
			t.Errorf("summary = %q", got)
		}
	}
}
// TestTruncateHistory_KeepLast stores "a" through "j", truncates to the
// last four and expects exactly "g".."j" to remain visible.
func TestTruncateHistory_KeepLast(t *testing.T) {
	store := newTestStore(t)
	ctx := context.Background()

	for _, r := range "abcdefghij" {
		if err := store.AddMessage(ctx, "trunc", "user", string(r)); err != nil {
			t.Fatalf("AddMessage: %v", err)
		}
	}
	if err := store.TruncateHistory(ctx, "trunc", 4); err != nil {
		t.Fatalf("TruncateHistory: %v", err)
	}

	kept, err := store.GetHistory(ctx, "trunc")
	if err != nil {
		t.Fatalf("GetHistory: %v", err)
	}
	if len(kept) != 4 {
		t.Fatalf("expected 4, got %d", len(kept))
	}
	// The survivors are g, h, i, j.
	if first := kept[0].Content; first != "g" {
		t.Errorf("first kept = %q, want 'g'", first)
	}
	if last := kept[3].Content; last != "j" {
		t.Errorf("last kept = %q, want 'j'", last)
	}
}
// TestTruncateHistory_KeepZero checks that keepLast == 0 hides every
// message of the session.
func TestTruncateHistory_KeepZero(t *testing.T) {
	store := newTestStore(t)
	ctx := context.Background()

	for remaining := 5; remaining > 0; remaining-- {
		if err := store.AddMessage(ctx, "empty", "user", "msg"); err != nil {
			t.Fatalf("AddMessage: %v", err)
		}
	}
	if err := store.TruncateHistory(ctx, "empty", 0); err != nil {
		t.Fatalf("TruncateHistory: %v", err)
	}

	left, err := store.GetHistory(ctx, "empty")
	if err != nil {
		t.Fatalf("GetHistory: %v", err)
	}
	if n := len(left); n != 0 {
		t.Errorf("expected 0, got %d", n)
	}
}
// TestTruncateHistory_KeepMoreThanExists checks that asking to keep more
// messages than the session holds leaves the history untouched.
func TestTruncateHistory_KeepMoreThanExists(t *testing.T) {
	store := newTestStore(t)
	ctx := context.Background()

	for remaining := 3; remaining > 0; remaining-- {
		if err := store.AddMessage(ctx, "few", "user", "msg"); err != nil {
			t.Fatalf("AddMessage: %v", err)
		}
	}

	// Only 3 messages exist, so keeping 100 must keep them all.
	if err := store.TruncateHistory(ctx, "few", 100); err != nil {
		t.Fatalf("TruncateHistory: %v", err)
	}

	left, err := store.GetHistory(ctx, "few")
	if err != nil {
		t.Fatalf("GetHistory: %v", err)
	}
	if n := len(left); n != 3 {
		t.Errorf("expected 3, got %d", n)
	}
}
// TestSetHistory_ReplacesAll checks that SetHistory discards the existing
// messages and leaves exactly the supplied ones.
func TestSetHistory_ReplacesAll(t *testing.T) {
	store := newTestStore(t)
	ctx := context.Background()

	// Pre-existing content that must disappear.
	for remaining := 5; remaining > 0; remaining-- {
		if err := store.AddMessage(ctx, "replace", "user", "old"); err != nil {
			t.Fatalf("AddMessage: %v", err)
		}
	}

	replacement := []providers.Message{
		{Role: "user", Content: "new1"},
		{Role: "assistant", Content: "new2"},
	}
	if err := store.SetHistory(ctx, "replace", replacement); err != nil {
		t.Fatalf("SetHistory: %v", err)
	}

	got, err := store.GetHistory(ctx, "replace")
	if err != nil {
		t.Fatalf("GetHistory: %v", err)
	}
	if len(got) != 2 {
		t.Fatalf("expected 2, got %d", len(got))
	}
	if !(got[0].Content == "new1" && got[1].Content == "new2") {
		t.Errorf("history = %+v", got)
	}
}
// TestSetHistory_ResetsSkip checks that a skip offset left behind by
// TruncateHistory does not hide messages written by a later SetHistory.
func TestSetHistory_ResetsSkip(t *testing.T) {
	const key = "skip-reset"
	store := newTestStore(t)
	ctx := context.Background()

	// Build up a non-zero skip offset.
	for remaining := 10; remaining > 0; remaining-- {
		if err := store.AddMessage(ctx, key, "user", "old"); err != nil {
			t.Fatalf("AddMessage: %v", err)
		}
	}
	if err := store.TruncateHistory(ctx, key, 3); err != nil {
		t.Fatalf("TruncateHistory: %v", err)
	}

	// Replacing the history has to bring the offset back to 0.
	fresh := []providers.Message{{Role: "user", Content: "fresh"}}
	if err := store.SetHistory(ctx, key, fresh); err != nil {
		t.Fatalf("SetHistory: %v", err)
	}

	got, err := store.GetHistory(ctx, key)
	if err != nil {
		t.Fatalf("GetHistory: %v", err)
	}
	if len(got) != 1 {
		t.Fatalf("expected 1, got %d", len(got))
	}
	if c := got[0].Content; c != "fresh" {
		t.Errorf("content = %q", c)
	}
}
// TestColonInKey checks that a session key containing ':' works end to end
// and is stored in a file whose name uses '_' instead.
func TestColonInKey(t *testing.T) {
	const key = "telegram:123"
	store := newTestStore(t)
	ctx := context.Background()

	if err := store.AddMessage(ctx, key, "user", "hi"); err != nil {
		t.Fatalf("AddMessage: %v", err)
	}
	got, err := store.GetHistory(ctx, key)
	if err != nil {
		t.Fatalf("GetHistory: %v", err)
	}
	if len(got) != 1 {
		t.Fatalf("expected 1, got %d", len(got))
	}

	// On disk the colon must have become an underscore.
	onDisk := filepath.Join(store.dir, "telegram_123.jsonl")
	_, statErr := os.Stat(onDisk)
	if statErr != nil {
		t.Errorf("expected file %s to exist: %v", onDisk, statErr)
	}
}
// TestCompact_RemovesSkippedMessages checks that Compact physically removes
// logically truncated lines from the JSONL file while GetHistory keeps
// returning the same messages.
func TestCompact_RemovesSkippedMessages(t *testing.T) {
	store := newTestStore(t)
	ctx := context.Background()

	// linesOnDisk reports how many messages the raw file holds,
	// ignoring the skip offset.
	linesOnDisk := func() int {
		raw, err := readMessages(store.jsonlPath("compact"), 0)
		if err != nil {
			t.Fatalf("readMessages: %v", err)
		}
		return len(raw)
	}

	// Ten messages, of which only the last three stay visible.
	for _, r := range "abcdefghij" {
		if err := store.AddMessage(ctx, "compact", "user", string(r)); err != nil {
			t.Fatalf("AddMessage: %v", err)
		}
	}
	if err := store.TruncateHistory(ctx, "compact", 3); err != nil {
		t.Fatalf("TruncateHistory: %v", err)
	}

	// Truncation is logical only: all ten lines are still there.
	if n := linesOnDisk(); n != 10 {
		t.Fatalf("before compact: expected 10 on disk, got %d", n)
	}

	if err := store.Compact(ctx, "compact"); err != nil {
		t.Fatalf("Compact: %v", err)
	}

	// Compaction shrinks the file down to the three live lines.
	if n := linesOnDisk(); n != 3 {
		t.Fatalf("after compact: expected 3 on disk, got %d", n)
	}

	// The visible history is unaffected.
	got, err := store.GetHistory(ctx, "compact")
	if err != nil {
		t.Fatalf("GetHistory: %v", err)
	}
	if len(got) != 3 {
		t.Fatalf("expected 3, got %d", len(got))
	}
	if got[0].Content != "h" || got[2].Content != "j" {
		t.Errorf("wrong content: %+v", got)
	}
}
// TestCompact_NoOpWhenNoSkip checks that compacting a session that was
// never truncated leaves its history as it was.
func TestCompact_NoOpWhenNoSkip(t *testing.T) {
	store := newTestStore(t)
	ctx := context.Background()

	for remaining := 5; remaining > 0; remaining-- {
		if err := store.AddMessage(ctx, "noop", "user", "msg"); err != nil {
			t.Fatalf("AddMessage: %v", err)
		}
	}

	// There is no skip offset, so nothing should change.
	if err := store.Compact(ctx, "noop"); err != nil {
		t.Fatalf("Compact: %v", err)
	}

	got, err := store.GetHistory(ctx, "noop")
	if err != nil {
		t.Fatalf("GetHistory: %v", err)
	}
	if n := len(got); n != 5 {
		t.Errorf("expected 5, got %d", n)
	}
}
// TestCompact_ThenAppend checks that messages appended after a compaction
// land behind the messages that survived it.
func TestCompact_ThenAppend(t *testing.T) {
	store := newTestStore(t)
	ctx := context.Background()

	for _, r := range "abcdefgh" {
		if err := store.AddMessage(ctx, "cap", "user", string(r)); err != nil {
			t.Fatalf("AddMessage: %v", err)
		}
	}
	if err := store.TruncateHistory(ctx, "cap", 2); err != nil {
		t.Fatalf("TruncateHistory: %v", err)
	}
	if err := store.Compact(ctx, "cap"); err != nil {
		t.Fatalf("Compact: %v", err)
	}

	// The store must keep accepting appends once compacted.
	if err := store.AddMessage(ctx, "cap", "user", "new"); err != nil {
		t.Fatalf("AddMessage after compact: %v", err)
	}

	got, err := store.GetHistory(ctx, "cap")
	if err != nil {
		t.Fatalf("GetHistory: %v", err)
	}
	if len(got) != 3 {
		t.Fatalf("expected 3, got %d", len(got))
	}
	// Expected order: g, h (survived truncation), then the new message.
	if first := got[0].Content; first != "g" {
		t.Errorf("first = %q, want 'g'", first)
	}
	if last := got[2].Content; last != "new" {
		t.Errorf("last = %q, want 'new'", last)
	}
}
// TestTruncateHistory_StaleMetaCount reproduces a crash in addMsg between
// the JSONL append and the meta update: the file holds N+1 lines while
// meta.Count still says N. TruncateHistory has to work from the real line
// count so that keepLast selects the right messages.
func TestTruncateHistory_StaleMetaCount(t *testing.T) {
	store := newTestStore(t)
	ctx := context.Background()

	// Ten regular appends leave meta.Count at 10.
	for _, r := range "abcdefghij" {
		if err := store.AddMessage(ctx, "stale", "user", string(r)); err != nil {
			t.Fatalf("AddMessage: %v", err)
		}
	}

	// Fake the crash: an eleventh line goes straight into the file and
	// the metadata is left alone (Count stays 10).
	raw, err := os.OpenFile(store.jsonlPath("stale"), os.O_WRONLY|os.O_APPEND, 0o644)
	if err != nil {
		t.Fatalf("open for append: %v", err)
	}
	if _, err = raw.WriteString(`{"role":"user","content":"orphan"}` + "\n"); err != nil {
		t.Fatalf("write orphan: %v", err)
	}
	raw.Close()

	// Keeping 4 must mean the last 4 of 11 lines, not of 10.
	if err = store.TruncateHistory(ctx, "stale", 4); err != nil {
		t.Fatalf("TruncateHistory: %v", err)
	}

	kept, err := store.GetHistory(ctx, "stale")
	if err != nil {
		t.Fatalf("GetHistory: %v", err)
	}
	if len(kept) != 4 {
		t.Fatalf("expected 4, got %d", len(kept))
	}
	// Tail of [a..j, orphan] is [h, i, j, orphan].
	if first := kept[0].Content; first != "h" {
		t.Errorf("first kept = %q, want 'h'", first)
	}
	if last := kept[3].Content; last != "orphan" {
		t.Errorf("last kept = %q, want 'orphan'", last)
	}
}
// TestCrashRecovery_PartialLine checks that a truncated trailing JSON line,
// as a crash mid-write would leave behind, is skipped by GetHistory while
// the valid messages before it are still returned.
func TestCrashRecovery_PartialLine(t *testing.T) {
	store := newTestStore(t)
	ctx := context.Background()

	// One intact message to start with.
	if err := store.AddMessage(ctx, "crash", "user", "valid"); err != nil {
		t.Fatalf("AddMessage: %v", err)
	}

	// Fake the crash: an incomplete JSON document appended directly.
	raw, err := os.OpenFile(store.jsonlPath("crash"), os.O_WRONLY|os.O_APPEND, 0o644)
	if err != nil {
		t.Fatalf("open for append: %v", err)
	}
	if _, err = raw.WriteString(`{"role":"user","content":"incomple`); err != nil {
		t.Fatalf("write partial: %v", err)
	}
	raw.Close()

	// Only the intact message may come back.
	got, err := store.GetHistory(ctx, "crash")
	if err != nil {
		t.Fatalf("GetHistory: %v", err)
	}
	if len(got) != 1 {
		t.Fatalf("expected 1 valid message, got %d", len(got))
	}
	if c := got[0].Content; c != "valid" {
		t.Errorf("content = %q", c)
	}
}
// TestPersistence_AcrossInstances checks that messages and the summary
// written through one store instance are visible to a second instance
// opened on the same directory.
func TestPersistence_AcrossInstances(t *testing.T) {
	dir := t.TempDir()
	ctx := context.Background()

	// Writer instance.
	writer, err := NewJSONLStore(dir)
	if err != nil {
		t.Fatalf("NewJSONLStore: %v", err)
	}
	if err = writer.AddMessage(ctx, "persist", "user", "remember me"); err != nil {
		t.Fatalf("AddMessage: %v", err)
	}
	if err = writer.SetSummary(ctx, "persist", "a test session"); err != nil {
		t.Fatalf("SetSummary: %v", err)
	}
	writer.Close()

	// Reader instance on the same directory.
	reader, err := NewJSONLStore(dir)
	if err != nil {
		t.Fatalf("NewJSONLStore: %v", err)
	}
	defer reader.Close()

	got, err := reader.GetHistory(ctx, "persist")
	if err != nil {
		t.Fatalf("GetHistory: %v", err)
	}
	if !(len(got) == 1 && got[0].Content == "remember me") {
		t.Errorf("history = %+v", got)
	}

	gotSummary, err := reader.GetSummary(ctx, "persist")
	if err != nil {
		t.Fatalf("GetSummary: %v", err)
	}
	if gotSummary != "a test session" {
		t.Errorf("summary = %q", gotSummary)
	}
}
// TestConcurrent_AddAndRead appends to one session from several goroutines
// at once and then checks that every message is present in the history.
//
// Append errors are reported instead of being discarded, so a failing
// AddMessage shows up as the actual cause rather than only as a message
// count mismatch at the end.
func TestConcurrent_AddAndRead(t *testing.T) {
	store := newTestStore(t)
	ctx := context.Background()

	var wg sync.WaitGroup
	const goroutines = 10
	const msgsPerGoroutine = 20

	// Concurrent writes.
	for g := 0; g < goroutines; g++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			for i := 0; i < msgsPerGoroutine; i++ {
				// Errorf (not Fatalf): FailNow must only be
				// called from the test's own goroutine.
				if err := store.AddMessage(ctx, "concurrent", "user", "msg"); err != nil {
					t.Errorf("AddMessage: %v", err)
					return
				}
			}
		}()
	}
	wg.Wait()

	history, err := store.GetHistory(ctx, "concurrent")
	if err != nil {
		t.Fatalf("GetHistory: %v", err)
	}
	expected := goroutines * msgsPerGoroutine
	if len(history) != expected {
		t.Errorf("expected %d messages, got %d", expected, len(history))
	}
}
// TestConcurrent_SummarizeRace simulates the #704 race: one goroutine adds
// messages while another truncates + sets summary — like
// summarizeSession(). Afterwards the store must still be readable.
//
// Errors from the concurrent operations are reported rather than
// discarded, so a failure inside either goroutine fails the test.
func TestConcurrent_SummarizeRace(t *testing.T) {
	store := newTestStore(t)
	ctx := context.Background()

	// Seed with some messages.
	for i := 0; i < 20; i++ {
		err := store.AddMessage(ctx, "race", "user", "seed")
		if err != nil {
			t.Fatalf("AddMessage: %v", err)
		}
	}

	var wg sync.WaitGroup

	// Writer goroutine (main agent loop).
	wg.Add(1)
	go func() {
		defer wg.Done()
		for i := 0; i < 50; i++ {
			// Errorf (not Fatalf): FailNow must only be called
			// from the test's own goroutine.
			if err := store.AddMessage(ctx, "race", "user", "new"); err != nil {
				t.Errorf("AddMessage during race: %v", err)
				return
			}
		}
	}()

	// Summarizer goroutine (background task).
	wg.Add(1)
	go func() {
		defer wg.Done()
		for i := 0; i < 10; i++ {
			if err := store.SetSummary(ctx, "race", "summary"); err != nil {
				t.Errorf("SetSummary during race: %v", err)
				return
			}
			if err := store.TruncateHistory(ctx, "race", 5); err != nil {
				t.Errorf("TruncateHistory during race: %v", err)
				return
			}
		}
	}()

	wg.Wait()

	// Verify the store is still in a consistent state.
	_, err := store.GetHistory(ctx, "race")
	if err != nil {
		t.Fatalf("GetHistory after race: %v", err)
	}
	_, err = store.GetSummary(ctx, "race")
	if err != nil {
		t.Fatalf("GetSummary after race: %v", err)
	}
}
// TestMultipleSessions_Isolation checks that messages written to one
// session do not show up in another.
func TestMultipleSessions_Isolation(t *testing.T) {
	store := newTestStore(t)
	ctx := context.Background()

	keys := []string{"s1", "s2"}

	// Write one message per session.
	for _, key := range keys {
		if err := store.AddMessage(ctx, key, "user", "msg for "+key); err != nil {
			t.Fatalf("AddMessage: %v", err)
		}
	}

	// Read both histories back before checking either.
	histories := make([][]providers.Message, len(keys))
	for i, key := range keys {
		h, err := store.GetHistory(ctx, key)
		if err != nil {
			t.Fatalf("GetHistory %s: %v", key, err)
		}
		histories[i] = h
	}

	// Each session must hold exactly its own message.
	for i, key := range keys {
		h := histories[i]
		if len(h) != 1 || h[0].Content != "msg for "+key {
			t.Errorf("%s history = %+v", key, h)
		}
	}
}
// BenchmarkAddMessage measures appending a single short message to one
// session of a fresh JSONL store.
func BenchmarkAddMessage(b *testing.B) {
	store, err := NewJSONLStore(b.TempDir())
	if err != nil {
		b.Fatalf("NewJSONLStore: %v", err)
	}
	defer store.Close()

	ctx := context.Background()

	b.ResetTimer()
	for n := 0; n < b.N; n++ {
		_ = store.AddMessage(ctx, "bench", "user", "benchmark message content")
	}
}
// BenchmarkGetHistory_100 measures reading back a session that holds 100
// messages. Seeding happens before the timer is reset.
func BenchmarkGetHistory_100(b *testing.B) {
	const seeded = 100

	store, err := NewJSONLStore(b.TempDir())
	if err != nil {
		b.Fatalf("NewJSONLStore: %v", err)
	}
	defer store.Close()

	ctx := context.Background()
	for n := 0; n < seeded; n++ {
		_ = store.AddMessage(ctx, "bench", "user", "message content")
	}

	b.ResetTimer()
	for n := 0; n < b.N; n++ {
		_, _ = store.GetHistory(ctx, "bench")
	}
}
// BenchmarkGetHistory_1000 measures reading back a session that holds 1000
// messages. Seeding happens before the timer is reset.
func BenchmarkGetHistory_1000(b *testing.B) {
	const seeded = 1000

	store, err := NewJSONLStore(b.TempDir())
	if err != nil {
		b.Fatalf("NewJSONLStore: %v", err)
	}
	defer store.Close()

	ctx := context.Background()
	for n := 0; n < seeded; n++ {
		_ = store.AddMessage(ctx, "bench", "user", "message content")
	}

	b.ResetTimer()
	for n := 0; n < b.N; n++ {
		_, _ = store.GetHistory(ctx, "bench")
	}
}
+108
View File
@@ -0,0 +1,108 @@
package memory
import (
"context"
"encoding/json"
"fmt"
"log"
"os"
"path/filepath"
"strings"
"time"
"github.com/sipeed/picoclaw/pkg/providers"
)
// jsonSession mirrors pkg/session.Session for migration purposes.
// It is the decode target for one legacy sessions/*.json file.
type jsonSession struct {
	// Key is the session key stored inside the file. MigrateFromJSON
	// prefers it over the filename and only falls back to the filename
	// (minus ".json") when it is empty.
	Key string `json:"key"`
	// Messages is the full history; it is written with Store.SetHistory.
	Messages []providers.Message `json:"messages"`
	// Summary is written with Store.SetSummary only when non-empty.
	Summary string `json:"summary,omitempty"`
	// Created and Updated are decoded but not used by MigrateFromJSON.
	Created time.Time `json:"created"`
	Updated time.Time `json:"updated"`
}
// MigrateFromJSON reads legacy sessions/*.json files from sessionsDir,
// writes them into the Store, and renames each migrated file to
// .json.migrated as a backup. Returns the number of sessions migrated.
//
// A missing sessionsDir is not an error: the function returns (0, nil).
//
// Files that fail to read or parse are logged and skipped. Already-migrated
// backups end in ".json.migrated", so they never match the ".json" filter
// and are ignored, making the function idempotent.
//
// The context is checked before each file is processed; if it has been
// cancelled the function stops and returns the number of sessions migrated
// so far together with the context error. Store errors abort the migration
// the same way. A failed rename is only logged: the session is still counted
// and will be re-imported (replaced, not duplicated) on the next run.
func MigrateFromJSON(
	ctx context.Context, sessionsDir string, store Store,
) (int, error) {
	entries, err := os.ReadDir(sessionsDir)
	if os.IsNotExist(err) {
		return 0, nil
	}
	if err != nil {
		return 0, fmt.Errorf("memory: read sessions dir: %w", err)
	}

	migrated := 0
	for _, entry := range entries {
		if entry.IsDir() {
			continue
		}

		name := entry.Name()
		// Only legacy "*.json" files are candidates. This also excludes
		// "*.json.migrated" backups, so no separate check is needed.
		if !strings.HasSuffix(name, ".json") {
			continue
		}

		// Stop at a file boundary when the caller has cancelled. Each file
		// is migrated with replace semantics, so resuming later is safe.
		if ctxErr := ctx.Err(); ctxErr != nil {
			return migrated, fmt.Errorf("memory: migrate: %w", ctxErr)
		}

		srcPath := filepath.Join(sessionsDir, name)

		data, readErr := os.ReadFile(srcPath)
		if readErr != nil {
			log.Printf("memory: migrate: skip %s: %v", name, readErr)
			continue
		}

		var sess jsonSession
		if parseErr := json.Unmarshal(data, &sess); parseErr != nil {
			log.Printf("memory: migrate: skip %s: %v", name, parseErr)
			continue
		}

		// Use the key from the JSON content, not the filename.
		// Filenames are sanitized (":" → "_") but keys are not.
		key := sess.Key
		if key == "" {
			key = strings.TrimSuffix(name, ".json")
		}

		// Use SetHistory (atomic replace) instead of per-message
		// AddFullMessage. This makes migration idempotent: if the
		// process crashes after writing messages but before the
		// rename below, a retry replaces the partial data cleanly
		// instead of duplicating messages.
		if setErr := store.SetHistory(ctx, key, sess.Messages); setErr != nil {
			return migrated, fmt.Errorf(
				"memory: migrate %s: set history: %w",
				name, setErr,
			)
		}

		if sess.Summary != "" {
			if sumErr := store.SetSummary(ctx, key, sess.Summary); sumErr != nil {
				return migrated, fmt.Errorf(
					"memory: migrate %s: set summary: %w",
					name, sumErr,
				)
			}
		}

		// Rename to .migrated as backup (not delete). A failure here is
		// deliberately non-fatal: the data is already in the store.
		if renameErr := os.Rename(srcPath, srcPath+".migrated"); renameErr != nil {
			log.Printf("memory: migrate: rename %s: %v", name, renameErr)
		}

		migrated++
	}

	return migrated, nil
}
+384
View File
@@ -0,0 +1,384 @@
package memory
import (
"context"
"encoding/json"
"os"
"path/filepath"
"testing"
"time"
"github.com/sipeed/picoclaw/pkg/providers"
)
// writeJSONSession is a test helper that serializes sess as indented JSON
// and writes it to dir/filename, failing the test on any error.
func writeJSONSession(
	t *testing.T, dir string, filename string, sess jsonSession,
) {
	t.Helper()

	encoded, marshalErr := json.MarshalIndent(sess, "", " ")
	if marshalErr != nil {
		t.Fatalf("marshal session: %v", marshalErr)
	}

	target := filepath.Join(dir, filename)
	if writeErr := os.WriteFile(target, encoded, 0o644); writeErr != nil {
		t.Fatalf("write session file: %v", writeErr)
	}
}
// TestMigrateFromJSON_Basic migrates a single legacy session file and checks
// that its messages and summary can be read back from the store.
func TestMigrateFromJSON_Basic(t *testing.T) {
	sessionsDir := t.TempDir()
	store := newTestStore(t)
	ctx := context.Background()

	legacy := jsonSession{
		Key: "test",
		Messages: []providers.Message{
			{Role: "user", Content: "hello"},
			{Role: "assistant", Content: "hi"},
		},
		Summary: "A greeting.",
		Created: time.Now(),
		Updated: time.Now(),
	}
	writeJSONSession(t, sessionsDir, "test.json", legacy)

	migrated, err := MigrateFromJSON(ctx, sessionsDir, store)
	if err != nil {
		t.Fatalf("MigrateFromJSON: %v", err)
	}
	if migrated != 1 {
		t.Errorf("expected 1 migrated, got %d", migrated)
	}

	msgs, err := store.GetHistory(ctx, "test")
	if err != nil {
		t.Fatalf("GetHistory: %v", err)
	}
	if got := len(msgs); got != 2 {
		t.Fatalf("expected 2 messages, got %d", got)
	}
	if !(msgs[0].Content == "hello" && msgs[1].Content == "hi") {
		t.Errorf("unexpected messages: %+v", msgs)
	}

	gotSummary, err := store.GetSummary(ctx, "test")
	if err != nil {
		t.Fatalf("GetSummary: %v", err)
	}
	if gotSummary != "A greeting." {
		t.Errorf("summary = %q", gotSummary)
	}
}
// TestMigrateFromJSON_WithToolCalls checks that tool-call metadata (the
// assistant's ToolCalls and the tool message's ToolCallID) survives the
// migration round trip.
func TestMigrateFromJSON_WithToolCalls(t *testing.T) {
	sessionsDir := t.TempDir()
	store := newTestStore(t)
	ctx := context.Background()

	writeJSONSession(t, sessionsDir, "tools.json", jsonSession{
		Key: "tools",
		Messages: []providers.Message{
			{
				Role:    "assistant",
				Content: "Searching...",
				ToolCalls: []providers.ToolCall{
					{
						ID:   "call_1",
						Type: "function",
						Function: &providers.FunctionCall{
							Name:      "web_search",
							Arguments: `{"q":"test"}`,
						},
					},
				},
			},
			{
				Role:       "tool",
				Content:    "result",
				ToolCallID: "call_1",
			},
		},
		Created: time.Now(),
		Updated: time.Now(),
	})

	count, err := MigrateFromJSON(ctx, sessionsDir, store)
	if err != nil {
		t.Fatalf("MigrateFromJSON: %v", err)
	}
	if count != 1 {
		t.Errorf("expected 1, got %d", count)
	}

	history, err := store.GetHistory(ctx, "tools")
	if err != nil {
		t.Fatalf("GetHistory: %v", err)
	}
	if len(history) != 2 {
		t.Fatalf("expected 2 messages, got %d", len(history))
	}
	if len(history[0].ToolCalls) != 1 {
		t.Fatalf("expected 1 tool call, got %d", len(history[0].ToolCalls))
	}

	// Function is a pointer: report a lost function as a normal test
	// failure instead of panicking on a nil dereference.
	if fn := history[0].ToolCalls[0].Function; fn == nil {
		t.Errorf("function = <nil>, want %q", "web_search")
	} else if fn.Name != "web_search" {
		t.Errorf("function = %q", fn.Name)
	}

	if history[1].ToolCallID != "call_1" {
		t.Errorf("ToolCallID = %q", history[1].ToolCallID)
	}
}
// TestMigrateFromJSON_MultipleFiles migrates three single-message sessions
// ("a", "b", "c") in one call and checks each one landed in the store.
func TestMigrateFromJSON_MultipleFiles(t *testing.T) {
	sessionsDir := t.TempDir()
	store := newTestStore(t)
	ctx := context.Background()

	keys := []string{"a", "b", "c"}

	for _, key := range keys {
		writeJSONSession(t, sessionsDir, key+".json", jsonSession{
			Key:      key,
			Messages: []providers.Message{{Role: "user", Content: "msg " + key}},
			Created:  time.Now(),
			Updated:  time.Now(),
		})
	}

	migrated, err := MigrateFromJSON(ctx, sessionsDir, store)
	if err != nil {
		t.Fatalf("MigrateFromJSON: %v", err)
	}
	if migrated != 3 {
		t.Errorf("expected 3, got %d", migrated)
	}

	for _, key := range keys {
		msgs, histErr := store.GetHistory(ctx, key)
		if histErr != nil {
			t.Fatalf("GetHistory(%q): %v", key, histErr)
		}
		if got := len(msgs); got != 1 {
			t.Errorf("session %q: expected 1 msg, got %d", key, got)
		}
	}
}
// TestMigrateFromJSON_InvalidJSON checks that an unparsable file is skipped
// without failing the migration of a valid file in the same directory.
func TestMigrateFromJSON_InvalidJSON(t *testing.T) {
	sessionsDir := t.TempDir()
	store := newTestStore(t)
	ctx := context.Background()

	// A well-formed session...
	writeJSONSession(t, sessionsDir, "good.json", jsonSession{
		Key:      "good",
		Messages: []providers.Message{{Role: "user", Content: "ok"}},
		Created:  time.Now(),
		Updated:  time.Now(),
	})

	// ...next to a file that is not valid JSON.
	badPath := filepath.Join(sessionsDir, "bad.json")
	if err := os.WriteFile(badPath, []byte("{invalid json"), 0o644); err != nil {
		t.Fatalf("write bad file: %v", err)
	}

	migrated, err := MigrateFromJSON(ctx, sessionsDir, store)
	if err != nil {
		t.Fatalf("MigrateFromJSON: %v", err)
	}
	if migrated != 1 {
		t.Errorf("expected 1 (bad file skipped), got %d", migrated)
	}

	msgs, err := store.GetHistory(ctx, "good")
	if err != nil {
		t.Fatalf("GetHistory: %v", err)
	}
	if got := len(msgs); got != 1 {
		t.Errorf("expected 1 message, got %d", got)
	}
}
// TestMigrateFromJSON_RenamesFiles checks that a migrated file is moved from
// NAME.json to NAME.json.migrated.
func TestMigrateFromJSON_RenamesFiles(t *testing.T) {
	sessionsDir := t.TempDir()
	store := newTestStore(t)
	ctx := context.Background()

	writeJSONSession(t, sessionsDir, "rename.json", jsonSession{
		Key:      "rename",
		Messages: []providers.Message{{Role: "user", Content: "hi"}},
		Created:  time.Now(),
		Updated:  time.Now(),
	})

	if _, err := MigrateFromJSON(ctx, sessionsDir, store); err != nil {
		t.Fatalf("MigrateFromJSON: %v", err)
	}

	// The source file must be gone...
	original := filepath.Join(sessionsDir, "rename.json")
	if _, statErr := os.Stat(original); !os.IsNotExist(statErr) {
		t.Error("rename.json should have been renamed")
	}

	// ...and the backup must be present.
	backup := filepath.Join(sessionsDir, "rename.json.migrated")
	if _, statErr := os.Stat(backup); statErr != nil {
		t.Errorf("rename.json.migrated should exist: %v", statErr)
	}
}
// TestMigrateFromJSON_Idempotent runs the migration twice over the same
// directory: the second run must migrate nothing and leave the stored
// history unchanged.
func TestMigrateFromJSON_Idempotent(t *testing.T) {
	sessionsDir := t.TempDir()
	store := newTestStore(t)
	ctx := context.Background()

	writeJSONSession(t, sessionsDir, "idem.json", jsonSession{
		Key:      "idem",
		Messages: []providers.Message{{Role: "user", Content: "once"}},
		Created:  time.Now(),
		Updated:  time.Now(),
	})

	firstRun, err := MigrateFromJSON(ctx, sessionsDir, store)
	if err != nil {
		t.Fatalf("first migration: %v", err)
	}
	if firstRun != 1 {
		t.Errorf("first run: expected 1, got %d", firstRun)
	}

	// Only the .migrated backup is left, so nothing should be picked up.
	secondRun, err := MigrateFromJSON(ctx, sessionsDir, store)
	if err != nil {
		t.Fatalf("second migration: %v", err)
	}
	if secondRun != 0 {
		t.Errorf("second run: expected 0, got %d", secondRun)
	}

	msgs, err := store.GetHistory(ctx, "idem")
	if err != nil {
		t.Fatalf("GetHistory: %v", err)
	}
	if got := len(msgs); got != 1 {
		t.Errorf("expected 1 message, got %d", got)
	}
}
// TestMigrateFromJSON_ColonInKey checks that the session key is taken from
// the JSON content rather than from the (sanitized) filename.
func TestMigrateFromJSON_ColonInKey(t *testing.T) {
	sessionsDir := t.TempDir()
	store := newTestStore(t)
	ctx := context.Background()

	// The filename carries the sanitized form (telegram_123) while the key
	// stored inside the file is the real one (telegram:123).
	legacy := jsonSession{
		Key:      "telegram:123",
		Messages: []providers.Message{{Role: "user", Content: "from telegram"}},
		Created:  time.Now(),
		Updated:  time.Now(),
	}
	writeJSONSession(t, sessionsDir, "telegram_123.json", legacy)

	migrated, err := MigrateFromJSON(ctx, sessionsDir, store)
	if err != nil {
		t.Fatalf("MigrateFromJSON: %v", err)
	}
	if migrated != 1 {
		t.Errorf("expected 1, got %d", migrated)
	}

	// The session is reachable through its original key.
	byKey, err := store.GetHistory(ctx, "telegram:123")
	if err != nil {
		t.Fatalf("GetHistory: %v", err)
	}
	if got := len(byKey); got != 1 {
		t.Fatalf("expected 1 message, got %d", got)
	}
	if byKey[0].Content != "from telegram" {
		t.Errorf("content = %q", byKey[0].Content)
	}

	// In the file-based store, "telegram:123" and "telegram_123" both
	// sanitize to the same filename, so they share storage. This is
	// expected — the colon-to-underscore mapping is a one-way function.
	bySanitized, err := store.GetHistory(ctx, "telegram_123")
	if err != nil {
		t.Fatalf("GetHistory: %v", err)
	}
	if got := len(bySanitized); got != 1 {
		t.Errorf("expected 1 (same file), got %d", got)
	}
}
// TestMigrateFromJSON_RetryAfterCrash simulates a crash between writing the
// messages and renaming the source file: the .json file is restored after a
// successful run and the migration is repeated. Because MigrateFromJSON uses
// SetHistory, the second run must replace the messages, not duplicate them.
func TestMigrateFromJSON_RetryAfterCrash(t *testing.T) {
	sessionsDir := t.TempDir()
	store := newTestStore(t)
	ctx := context.Background()

	writeJSONSession(t, sessionsDir, "retry.json", jsonSession{
		Key: "retry",
		Messages: []providers.Message{
			{Role: "user", Content: "one"},
			{Role: "assistant", Content: "two"},
		},
		Created: time.Now(),
		Updated: time.Now(),
	})

	// A normal first run: messages are written and the file is renamed.
	migrated, err := MigrateFromJSON(ctx, sessionsDir, store)
	if err != nil {
		t.Fatalf("first migration: %v", err)
	}
	if migrated != 1 {
		t.Fatalf("expected 1, got %d", migrated)
	}

	// Undo the rename to mimic "crashed before rename".
	backup := filepath.Join(sessionsDir, "retry.json.migrated")
	restored := filepath.Join(sessionsDir, "retry.json")
	if renameErr := os.Rename(backup, restored); renameErr != nil {
		t.Fatalf("restore .json: %v", renameErr)
	}

	// The retry must import the file again without duplicating anything.
	migrated, err = MigrateFromJSON(ctx, sessionsDir, store)
	if err != nil {
		t.Fatalf("second migration: %v", err)
	}
	if migrated != 1 {
		t.Fatalf("expected 1, got %d", migrated)
	}

	msgs, err := store.GetHistory(ctx, "retry")
	if err != nil {
		t.Fatalf("GetHistory: %v", err)
	}

	// Exactly the two original messages; four would mean duplication.
	if got := len(msgs); got != 2 {
		t.Fatalf("expected 2 messages (no duplicates), got %d", got)
	}
	if !(msgs[0].Content == "one" && msgs[1].Content == "two") {
		t.Errorf("unexpected messages: %+v", msgs)
	}
}
// TestMigrateFromJSON_NonexistentDir checks that a missing sessions
// directory is treated as "nothing to migrate" rather than an error.
func TestMigrateFromJSON_NonexistentDir(t *testing.T) {
	store := newTestStore(t)
	ctx := context.Background()

	// A child of a fresh temp dir is guaranteed not to exist, unlike a
	// hard-coded absolute path whose presence depends on the host.
	missingDir := filepath.Join(t.TempDir(), "does-not-exist")

	count, err := MigrateFromJSON(ctx, missingDir, store)
	if err != nil {
		t.Fatalf("MigrateFromJSON: %v", err)
	}
	if count != 0 {
		t.Errorf("expected 0, got %d", count)
	}
}
+42
View File
@@ -0,0 +1,42 @@
package memory
import (
"context"
"github.com/sipeed/picoclaw/pkg/providers"
)
// Store defines an interface for persistent session storage.
// Each method is an atomic operation — there is no separate Save() call.
//
// Sessions are addressed by an opaque sessionKey string. Every method that
// touches a session takes a context as its first argument.
//
// NOTE(review): the tests in this package call a Store from several
// goroutines at once, so implementations are presumably expected to be safe
// for concurrent use — confirm for each backend.
type Store interface {
	// AddMessage appends a simple text message to a session.
	AddMessage(ctx context.Context, sessionKey, role, content string) error

	// AddFullMessage appends a complete message (with tool calls, etc.) to a session.
	AddFullMessage(ctx context.Context, sessionKey string, msg providers.Message) error

	// GetHistory returns all messages for a session in insertion order.
	// Returns an empty slice (not nil) if the session does not exist.
	GetHistory(ctx context.Context, sessionKey string) ([]providers.Message, error)

	// GetSummary returns the conversation summary for a session.
	// Returns an empty string if no summary exists.
	GetSummary(ctx context.Context, sessionKey string) (string, error)

	// SetSummary updates the conversation summary for a session.
	SetSummary(ctx context.Context, sessionKey, summary string) error

	// TruncateHistory removes all but the last keepLast messages from a session.
	// If keepLast <= 0, all messages are removed.
	TruncateHistory(ctx context.Context, sessionKey string, keepLast int) error

	// SetHistory replaces all messages in a session with the provided history.
	// NOTE(review): whether the existing summary is kept or cleared is not
	// specified here; MigrateFromJSON sets the summary in a separate call.
	SetHistory(ctx context.Context, sessionKey string, history []providers.Message) error

	// Compact reclaims storage by physically removing logically truncated
	// data. Backends that do not accumulate dead data may return nil.
	Compact(ctx context.Context, sessionKey string) error

	// Close releases any resources held by the store.
	Close() error
}
+1 -27
View File
@@ -190,28 +190,6 @@ func resolveProviderSelection(cfg *config.Config) (providerSelection, error) {
sel.apiBase = "https://api.mistral.ai/v1"
}
}
case "opencode":
if cfg.Providers.Opencode.APIKey != "" || cfg.Providers.Opencode.APIBase != "" {
sel.apiKey = cfg.Providers.Opencode.APIKey
sel.apiBase = cfg.Providers.Opencode.APIBase
sel.proxy = cfg.Providers.Opencode.Proxy
if sel.apiBase == "" {
sel.apiBase = "https://opencode.ai/zen/v1"
}
}
case "kimi", "kimi-code", "moonshot":
if cfg.Providers.Moonshot.APIKey != "" {
sel.apiKey = cfg.Providers.Moonshot.APIKey
sel.apiBase = cfg.Providers.Moonshot.APIBase
sel.proxy = cfg.Providers.Moonshot.Proxy
if sel.apiBase == "" {
if providerName == "moonshot" {
sel.apiBase = "https://api.moonshot.cn/v1"
} else {
sel.apiBase = "https://api.kimi.com/coding/v1"
}
}
}
case "github_copilot", "copilot":
sel.providerType = providerTypeGitHubCopilot
if cfg.Providers.GitHubCopilot.APIBase != "" {
@@ -232,11 +210,7 @@ func resolveProviderSelection(cfg *config.Config) (providerSelection, error) {
sel.apiBase = cfg.Providers.Moonshot.APIBase
sel.proxy = cfg.Providers.Moonshot.Proxy
if sel.apiBase == "" {
if strings.Contains(lowerModel, "moonshot") || strings.HasPrefix(model, "moonshot/") {
sel.apiBase = "https://api.moonshot.cn/v1"
} else {
sel.apiBase = "https://api.kimi.com/coding/v1"
}
sel.apiBase = "https://api.moonshot.cn/v1"
}
case strings.HasPrefix(model, "openrouter/") ||
strings.HasPrefix(model, "anthropic/") ||
+1 -3
View File
@@ -94,7 +94,7 @@ func CreateProviderFromConfig(cfg *config.ModelConfig) (LLMProvider, string, err
case "litellm", "openrouter", "groq", "zhipu", "gemini", "nvidia",
"ollama", "moonshot", "shengsuanyun", "deepseek", "cerebras",
"volcengine", "vllm", "qwen", "mistral", "opencode":
"volcengine", "vllm", "qwen", "mistral":
// All other OpenAI-compatible HTTP providers
if cfg.APIKey == "" && cfg.APIBase == "" {
return nil, "", fmt.Errorf("api_key or api_base is required for HTTP-based protocol %q", protocol)
@@ -208,8 +208,6 @@ func getDefaultAPIBase(protocol string) string {
return "http://localhost:8000/v1"
case "mistral":
return "https://api.mistral.ai/v1"
case "opencode":
return "https://opencode.ai/zen/v1"
default:
return ""
}
-1
View File
@@ -112,7 +112,6 @@ func TestCreateProviderFromConfig_DefaultAPIBase(t *testing.T) {
{"vllm", "vllm"},
{"deepseek", "deepseek"},
{"ollama", "ollama"},
{"opencode", "opencode"},
}
for _, tt := range tests {
+1 -15
View File
@@ -33,7 +33,6 @@ type Provider struct {
apiBase string
maxTokensField string // Field name for max tokens (e.g., "max_completion_tokens" for o1/glm models)
httpClient *http.Client
isKimiAPI bool // true when apiBase points to api.kimi.com
}
type Option func(*Provider)
@@ -70,17 +69,10 @@ func NewProvider(apiKey, apiBase, proxy string, opts ...Option) *Provider {
}
}
trimmedBase := strings.TrimRight(apiBase, "/")
var isKimi bool
if parsed, err := url.Parse(trimmedBase); err == nil {
isKimi = parsed.Hostname() == "api.kimi.com"
}
p := &Provider{
apiKey: apiKey,
apiBase: trimmedBase,
apiBase: strings.TrimRight(apiBase, "/"),
httpClient: client,
isKimiAPI: isKimi,
}
for _, opt := range opts {
@@ -184,12 +176,6 @@ func (p *Provider) Chat(
if p.apiKey != "" {
req.Header.Set("Authorization", "Bearer "+p.apiKey)
}
// Kimi Code API rejects requests without a recognized coding-agent
// User-Agent. "KimiCLI/0.77" is the minimum version string accepted
// by the api.kimi.com/coding/v1 endpoint (per Kimi's API docs).
if p.isKimiAPI {
req.Header.Set("User-Agent", "KimiCLI/0.77")
}
resp, err := p.httpClient.Do(req)
if err != nil {
@@ -2,7 +2,6 @@ package openai_compat
import (
"encoding/json"
"io"
"net/http"
"net/http/httptest"
"net/url"
@@ -421,82 +420,6 @@ func TestProvider_FunctionalOptionRequestTimeoutNonPositive(t *testing.T) {
}
}
// roundTripFunc lets a plain function satisfy http.RoundTripper so tests can
// intercept outgoing requests without a real network round trip.
type roundTripFunc func(*http.Request) (*http.Response, error)

// RoundTrip invokes the wrapped function with req and returns its result.
func (fn roundTripFunc) RoundTrip(req *http.Request) (*http.Response, error) {
	return fn(req)
}
func TestProviderChat_KimiCodeUserAgent(t *testing.T) {
okBody := `{"choices":[{"message":{"content":"ok"},"finish_reason":"stop"}]}`
tests := []struct {
name string
apiBase string
wantAgent string
}{
{
name: "sets KimiCLI User-Agent for api.kimi.com",
apiBase: "https://api.kimi.com/coding/v1",
wantAgent: "KimiCLI/0.77",
},
{
name: "does not set KimiCLI User-Agent for other hosts",
apiBase: "https://api.example.com/v1",
wantAgent: "",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var gotUserAgent string
p := NewProvider("key", tt.apiBase, "")
p.httpClient.Transport = roundTripFunc(
func(r *http.Request) (*http.Response, error) {
gotUserAgent = r.Header.Get("User-Agent")
return &http.Response{
StatusCode: http.StatusOK,
Body: io.NopCloser(
strings.NewReader(okBody),
),
Header: http.Header{
"Content-Type": {"application/json"},
},
}, nil
},
)
_, err := p.Chat(
t.Context(),
[]Message{{Role: "user", Content: "hi"}},
nil,
"kimi-k2.5",
nil,
)
if err != nil {
t.Fatalf("Chat() error = %v", err)
}
if tt.wantAgent != "" {
if gotUserAgent != tt.wantAgent {
t.Fatalf(
"User-Agent = %q, want %q",
gotUserAgent, tt.wantAgent,
)
}
} else {
if gotUserAgent == "KimiCLI/0.77" {
t.Fatalf(
"User-Agent should not be KimiCLI/0.77 for non-kimi host",
)
}
}
})
}
}
func TestSerializeMessages_PlainText(t *testing.T) {
messages := []protocoltypes.Message{
{Role: "user", Content: "hello"},
+5 -4
View File
@@ -3,6 +3,7 @@ package tools
import (
"context"
"fmt"
"sync/atomic"
)
type SendCallback func(channel, chatID, content string) error
@@ -11,7 +12,7 @@ type MessageTool struct {
sendCallback SendCallback
defaultChannel string
defaultChatID string
sentInRound bool // Tracks whether a message was sent in the current processing round
sentInRound atomic.Bool // Tracks whether a message was sent in the current processing round
}
func NewMessageTool() *MessageTool {
@@ -50,12 +51,12 @@ func (t *MessageTool) Parameters() map[string]any {
func (t *MessageTool) SetContext(channel, chatID string) {
t.defaultChannel = channel
t.defaultChatID = chatID
t.sentInRound = false // Reset send tracking for new processing round
t.sentInRound.Store(false) // Reset send tracking for new processing round
}
// HasSentInRound returns true if the message tool sent a message during the current round.
func (t *MessageTool) HasSentInRound() bool {
return t.sentInRound
return t.sentInRound.Load()
}
func (t *MessageTool) SetSendCallback(callback SendCallback) {
@@ -94,7 +95,7 @@ func (t *MessageTool) Execute(ctx context.Context, args map[string]any) *ToolRes
}
}
t.sentInRound = true
t.sentInRound.Store(true)
// Silent: user already received the message directly
return &ToolResult{
ForLLM: fmt.Sprintf("Message sent to %s:%s", channel, chatID),
+2 -2
View File
@@ -30,9 +30,9 @@ var (
regexp.MustCompile(`\brm\s+-[rf]{1,2}\b`),
regexp.MustCompile(`\bdel\s+/[fq]\b`),
regexp.MustCompile(`\brmdir\s+/s\b`),
// Match disk wiping commands, avoid matching --format flags
// Match disk wiping commands (must be followed by space/args)
regexp.MustCompile(
`(?:^|[;&|]\s*|\s+)(format|mkfs|diskpart)\s`,
`\b(format|mkfs|diskpart)\b\s`,
),
regexp.MustCompile(`\bdd\s+if=`),
// Block writes to block devices (all common naming schemes).
-50
View File
@@ -366,56 +366,6 @@ func TestShellTool_BlockDevices(t *testing.T) {
}
}
// TestShellTool_DenyPattern_DiskWiping exercises the deny pattern for the
// disk wiping commands format, mkfs and diskpart. Commands that invoke them
// (standalone or after a shell separator) must be blocked, while harmless
// look-alikes such as a --format flag must pass.
func TestShellTool_DenyPattern_DiskWiping(t *testing.T) {
	tool, err := NewExecTool("", false)
	if err != nil {
		t.Fatalf("unable to configure exec tool: %s", err)
	}

	type cmdCase struct {
		name string
		cmd  string
	}

	// Disk wiping commands: the guard must report them as blocked.
	mustBlock := []cmdCase{
		{name: "format with space", cmd: "format c:"},
		{name: "mkfs standalone", cmd: "mkfs /dev/sda"},
		{name: "semicolon format", cmd: "echo hello; format c:"},
		{name: "pipe format", cmd: "echo hello | format c:"},
		{name: "and format", cmd: "echo hello && format c:"},
		{name: "diskpart standalone", cmd: "diskpart /s script.txt"},
	}
	for _, tc := range mustBlock {
		t.Run("blocked_"+tc.name, func(t *testing.T) {
			if msg := tool.guardCommand(tc.cmd, ""); !strings.Contains(msg, "blocked") {
				t.Errorf("Expected %q to be blocked by safety guard, got: %q", tc.cmd, msg)
			}
		})
	}

	// Not disk wiping: the guard must return an empty message.
	mustAllow := []cmdCase{
		{name: "--format flag", cmd: "echo test --format json"},
		{name: "go fmt", cmd: "echo go fmt ./..."},
	}
	for _, tc := range mustAllow {
		t.Run("allowed_"+tc.name, func(t *testing.T) {
			if msg := tool.guardCommand(tc.cmd, ""); msg != "" {
				t.Errorf("Expected %q to be allowed, but it was blocked: %s", tc.cmd, msg)
			}
		})
	}
}
// TestShellTool_SafePathsInWorkspaceRestriction verifies that safe kernel pseudo-devices
// are allowed even when workspace restriction is active.
func TestShellTool_SafePathsInWorkspaceRestriction(t *testing.T) {
+43 -26
View File
@@ -10,6 +10,7 @@ import (
"context"
"encoding/json"
"fmt"
"sync"
"github.com/sipeed/picoclaw/pkg/logger"
"github.com/sipeed/picoclaw/pkg/providers"
@@ -121,37 +122,53 @@ func RunToolLoop(
}
messages = append(messages, assistantMsg)
// 7. Execute tool calls
for _, tc := range normalizedToolCalls {
argsJSON, _ := json.Marshal(tc.Arguments)
argsPreview := utils.Truncate(string(argsJSON), 200)
logger.InfoCF("toolloop", fmt.Sprintf("Tool call: %s(%s)", tc.Name, argsPreview),
map[string]any{
"tool": tc.Name,
"iteration": iteration,
})
// 7. Execute tool calls in parallel
type indexedResult struct {
result *ToolResult
tc providers.ToolCall
}
// Execute tool (no async callback for subagents - they run independently)
var toolResult *ToolResult
if config.Tools != nil {
toolResult = config.Tools.ExecuteWithContext(ctx, tc.Name, tc.Arguments, channel, chatID, nil)
} else {
toolResult = ErrorResult("No tools available")
results := make([]indexedResult, len(normalizedToolCalls))
var wg sync.WaitGroup
for i, tc := range normalizedToolCalls {
results[i].tc = tc
wg.Add(1)
go func(idx int, tc providers.ToolCall) {
defer wg.Done()
argsJSON, _ := json.Marshal(tc.Arguments)
argsPreview := utils.Truncate(string(argsJSON), 200)
logger.InfoCF("toolloop", fmt.Sprintf("Tool call: %s(%s)", tc.Name, argsPreview),
map[string]any{
"tool": tc.Name,
"iteration": iteration,
})
var toolResult *ToolResult
if config.Tools != nil {
toolResult = config.Tools.ExecuteWithContext(ctx, tc.Name, tc.Arguments, channel, chatID, nil)
} else {
toolResult = ErrorResult("No tools available")
}
results[idx].result = toolResult
}(i, tc)
}
wg.Wait()
// Append results in original order
for _, r := range results {
contentForLLM := r.result.ForLLM
if contentForLLM == "" && r.result.Err != nil {
contentForLLM = r.result.Err.Error()
}
// Determine content for LLM
contentForLLM := toolResult.ForLLM
if contentForLLM == "" && toolResult.Err != nil {
contentForLLM = toolResult.Err.Error()
}
// Add tool result message
toolResultMsg := providers.Message{
messages = append(messages, providers.Message{
Role: "tool",
Content: contentForLLM,
ToolCallID: tc.ID,
}
messages = append(messages, toolResultMsg)
ToolCallID: r.tc.ID,
})
}
}
+107 -87
View File
@@ -395,6 +395,88 @@ func (p *PerplexitySearchProvider) Search(ctx context.Context, query string, cou
return fmt.Sprintf("Results for: %s (via Perplexity)\n%s", query, searchResp.Choices[0].Message.Content), nil
}
// GLMSearchProvider is a search provider backed by the GLM web search HTTP
// API. Requests are sent with client and authenticated with apiKey as a
// bearer token.
type GLMSearchProvider struct {
	apiKey       string
	baseURL      string // endpoint override; empty selects the default URL
	searchEngine string // sent as "search_engine" in the request payload
	proxy        string
	client       *http.Client
}

// Search POSTs query to the GLM web search endpoint and formats at most
// count results as a numbered plain-text list (title, link, and content when
// present). It returns a "No results" line when the API returns no hits and
// an error for transport failures, non-200 responses or unparsable bodies.
// At most 1 MiB of the response body is read.
func (p *GLMSearchProvider) Search(ctx context.Context, query string, count int) (string, error) {
	endpoint := p.baseURL
	if endpoint == "" {
		endpoint = "https://open.bigmodel.cn/api/paas/v4/web_search"
	}

	encoded, err := json.Marshal(map[string]any{
		"search_query":  query,
		"search_engine": p.searchEngine,
		"search_intent": false,
		"count":         count,
		"content_size":  "medium",
	})
	if err != nil {
		return "", fmt.Errorf("failed to marshal payload: %w", err)
	}

	req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewReader(encoded))
	if err != nil {
		return "", fmt.Errorf("failed to create request: %w", err)
	}
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("Authorization", "Bearer "+p.apiKey)

	resp, err := p.client.Do(req)
	if err != nil {
		return "", fmt.Errorf("request failed: %w", err)
	}
	defer resp.Body.Close()

	// Cap the amount read so an oversized response cannot exhaust memory.
	raw, err := io.ReadAll(io.LimitReader(resp.Body, 1<<20))
	if err != nil {
		return "", fmt.Errorf("failed to read response: %w", err)
	}
	if resp.StatusCode != http.StatusOK {
		return "", fmt.Errorf("GLM Search API error (status %d): %s", resp.StatusCode, string(raw))
	}

	var parsed struct {
		SearchResult []struct {
			Title   string `json:"title"`
			Content string `json:"content"`
			Link    string `json:"link"`
		} `json:"search_result"`
	}
	if err := json.Unmarshal(raw, &parsed); err != nil {
		return "", fmt.Errorf("failed to parse response: %w", err)
	}

	hits := parsed.SearchResult
	if len(hits) == 0 {
		return fmt.Sprintf("No results for: %s", query), nil
	}

	lines := []string{fmt.Sprintf("Results for: %s (via GLM Search)", query)}
	for idx := 0; idx < len(hits) && idx < count; idx++ {
		hit := hits[idx]
		lines = append(lines, fmt.Sprintf("%d. %s\n %s", idx+1, hit.Title, hit.Link))
		if hit.Content != "" {
			lines = append(lines, fmt.Sprintf(" %s", hit.Content))
		}
	}

	return strings.Join(lines, "\n"), nil
}
type WebSearchTool struct {
provider SearchProvider
maxResults int
@@ -413,9 +495,11 @@ type WebSearchToolOptions struct {
PerplexityAPIKey string
PerplexityMaxResults int
PerplexityEnabled bool
ExaAPIKey string
ExaMaxResults int
ExaEnabled bool
GLMSearchAPIKey string
GLMSearchBaseURL string
GLMSearchEngine string
GLMSearchMaxResults int
GLMSearchEnabled bool
Proxy string
}
@@ -423,7 +507,7 @@ func NewWebSearchTool(opts WebSearchToolOptions) (*WebSearchTool, error) {
var provider SearchProvider
maxResults := 5
// Priority: Perplexity > Exa > Brave > Tavily > DuckDuckGo
// Priority: Perplexity > Brave > Tavily > DuckDuckGo > GLM Search
if opts.PerplexityEnabled && opts.PerplexityAPIKey != "" {
client, err := createHTTPClient(opts.Proxy, perplexityTimeout)
if err != nil {
@@ -433,15 +517,6 @@ func NewWebSearchTool(opts WebSearchToolOptions) (*WebSearchTool, error) {
if opts.PerplexityMaxResults > 0 {
maxResults = opts.PerplexityMaxResults
}
} else if opts.ExaEnabled && opts.ExaAPIKey != "" {
client, err := createHTTPClient(opts.Proxy, searchTimeout)
if err != nil {
return nil, fmt.Errorf("failed to create HTTP client for Exa: %w", err)
}
provider = &ExaSearchProvider{apiKey: opts.ExaAPIKey, proxy: opts.Proxy, client: client}
if opts.ExaMaxResults > 0 {
maxResults = opts.ExaMaxResults
}
} else if opts.BraveEnabled && opts.BraveAPIKey != "" {
client, err := createHTTPClient(opts.Proxy, searchTimeout)
if err != nil {
@@ -474,6 +549,25 @@ func NewWebSearchTool(opts WebSearchToolOptions) (*WebSearchTool, error) {
if opts.DuckDuckGoMaxResults > 0 {
maxResults = opts.DuckDuckGoMaxResults
}
} else if opts.GLMSearchEnabled && opts.GLMSearchAPIKey != "" {
client, err := createHTTPClient(opts.Proxy, searchTimeout)
if err != nil {
return nil, fmt.Errorf("failed to create HTTP client for GLM Search: %w", err)
}
searchEngine := opts.GLMSearchEngine
if searchEngine == "" {
searchEngine = "search_std"
}
provider = &GLMSearchProvider{
apiKey: opts.GLMSearchAPIKey,
baseURL: opts.GLMSearchBaseURL,
searchEngine: searchEngine,
proxy: opts.Proxy,
client: client,
}
if opts.GLMSearchMaxResults > 0 {
maxResults = opts.GLMSearchMaxResults
}
} else {
return nil, nil
}
@@ -721,77 +815,3 @@ func (t *WebFetchTool) extractText(htmlContent string) string {
return strings.Join(cleanLines, "\n")
}
// ExaSearchProvider uses the Exa AI search API (https://exa.ai).
// Requests are sent with client and authenticated with apiKey through the
// x-api-key header.
type ExaSearchProvider struct {
	apiKey string
	proxy  string
	client *http.Client
}

// Search POSTs query to the Exa search endpoint and formats at most count
// results as a numbered plain-text list (title, URL, and a text snippet of
// up to 200 bytes when present). It returns a "No results" line when the API
// returns no hits and an error for transport failures, non-200 responses or
// unparsable bodies.
//
// At most 1 MiB of the response body is read, matching the other search
// providers in this file. A negative count is treated as zero, and snippets
// are only ever cut on a UTF-8 rune boundary.
func (p *ExaSearchProvider) Search(ctx context.Context, query string, count int) (string, error) {
	reqBody := map[string]any{
		"query":       query,
		"num_results": count,
		"type":        "neural",
	}
	jsonData, err := json.Marshal(reqBody)
	if err != nil {
		return "", fmt.Errorf("exa: marshal error: %w", err)
	}

	req, err := http.NewRequestWithContext(ctx, "POST", "https://api.exa.ai/search", bytes.NewReader(jsonData))
	if err != nil {
		return "", fmt.Errorf("exa: request error: %w", err)
	}
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("x-api-key", p.apiKey)

	resp, err := p.client.Do(req)
	if err != nil {
		return "", fmt.Errorf("exa: search failed: %w", err)
	}
	defer resp.Body.Close()

	// Cap the amount read so an oversized response cannot exhaust memory.
	body, err := io.ReadAll(io.LimitReader(resp.Body, 1<<20))
	if err != nil {
		return "", fmt.Errorf("exa: read error: %w", err)
	}
	if resp.StatusCode != http.StatusOK {
		return "", fmt.Errorf("exa: API error %d: %s", resp.StatusCode, string(body))
	}

	var result struct {
		Results []struct {
			Title string `json:"title"`
			URL   string `json:"url"`
			Text  string `json:"text"`
		} `json:"results"`
	}
	if err := json.Unmarshal(body, &result); err != nil {
		return "", fmt.Errorf("exa: parse error: %w", err)
	}

	if len(result.Results) == 0 {
		return fmt.Sprintf("No results for: %s", query), nil
	}

	// Clamp to [0, len(results)]; a negative count would otherwise panic
	// when used as a slice bound below.
	maxResults := count
	if maxResults > len(result.Results) {
		maxResults = len(result.Results)
	}
	if maxResults < 0 {
		maxResults = 0
	}

	var lines []string
	lines = append(lines, fmt.Sprintf("Results for: %s (via Exa)", query))
	for i, r := range result.Results[:maxResults] {
		lines = append(lines, fmt.Sprintf("%d. %s\n %s", i+1, r.Title, r.URL))
		if r.Text == "" {
			continue
		}
		snippet := r.Text
		if len(snippet) > 200 {
			// Cut at the last rune start at or before byte 200 so a
			// multi-byte character is never split in half.
			cut := 0
			for idx := range snippet {
				if idx > 200 {
					break
				}
				cut = idx
			}
			snippet = snippet[:cut] + "..."
		}
		lines = append(lines, fmt.Sprintf(" %s", snippet))
	}

	return strings.Join(lines, "\n"), nil
}
+105 -189
View File
@@ -5,7 +5,6 @@ import (
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"net/http/httptest"
"strings"
@@ -683,86 +682,7 @@ func TestWebTool_TavilySearch_Success(t *testing.T) {
}
}
// TestNewWebSearchTool_ExaPriority exercises provider selection in
// NewWebSearchTool around Exa: Exa is chosen when enabled with a key, is
// skipped without a key, loses to Perplexity, and wins over Brave.
func TestNewWebSearchTool_ExaPriority(t *testing.T) {
	// Scenario 1: Exa enabled with an API key is selected and its
	// max-results setting is applied.
	exaOnly, err := NewWebSearchTool(WebSearchToolOptions{
		ExaEnabled:    true,
		ExaAPIKey:     "exa-key",
		ExaMaxResults: 3,
	})
	switch {
	case err != nil:
		t.Fatalf("NewWebSearchTool() error: %v", err)
	case exaOnly == nil:
		t.Fatal("Expected non-nil tool when Exa is enabled with API key")
	}
	if _, isExa := exaOnly.provider.(*ExaSearchProvider); !isExa {
		t.Fatalf("provider type = %T, want *ExaSearchProvider", exaOnly.provider)
	}
	if got := exaOnly.maxResults; got != 3 {
		t.Fatalf("maxResults = %d, want 3", got)
	}

	// Scenario 2: Exa enabled without a key falls through; with no other
	// provider enabled the constructor yields a nil tool.
	keyless, err := NewWebSearchTool(WebSearchToolOptions{
		ExaEnabled: true,
		ExaAPIKey:  "",
	})
	if err != nil {
		t.Fatalf("NewWebSearchTool() error: %v", err)
	}
	if keyless != nil {
		t.Errorf("Expected nil tool when Exa API key is empty and no other provider enabled")
	}

	// Scenario 3: Perplexity outranks Exa when both are configured.
	withPerplexity, err := NewWebSearchTool(WebSearchToolOptions{
		PerplexityEnabled: true,
		PerplexityAPIKey:  "perp-key",
		ExaEnabled:        true,
		ExaAPIKey:         "exa-key",
	})
	if err != nil {
		t.Fatalf("NewWebSearchTool() error: %v", err)
	}
	if _, isPerplexity := withPerplexity.provider.(*PerplexitySearchProvider); !isPerplexity {
		t.Fatalf("provider type = %T, want *PerplexitySearchProvider (Perplexity should outrank Exa)", withPerplexity.provider)
	}

	// Scenario 4: Exa outranks Brave when both are configured.
	withBrave, err := NewWebSearchTool(WebSearchToolOptions{
		ExaEnabled:   true,
		ExaAPIKey:    "exa-key",
		BraveEnabled: true,
		BraveAPIKey:  "brave-key",
	})
	if err != nil {
		t.Fatalf("NewWebSearchTool() error: %v", err)
	}
	if _, isExa := withBrave.provider.(*ExaSearchProvider); !isExa {
		t.Fatalf("provider type = %T, want *ExaSearchProvider (Exa should outrank Brave)", withBrave.provider)
	}
}
func TestNewWebSearchTool_ExaProxyPropagation(t *testing.T) {
tool, err := NewWebSearchTool(WebSearchToolOptions{
ExaEnabled: true,
ExaAPIKey: "k",
Proxy: "http://127.0.0.1:7890",
})
if err != nil {
t.Fatalf("NewWebSearchTool() error: %v", err)
}
p, ok := tool.provider.(*ExaSearchProvider)
if !ok {
t.Fatalf("provider type = %T, want *ExaSearchProvider", tool.provider)
}
if p.proxy != "http://127.0.0.1:7890" {
t.Fatalf("provider proxy = %q, want %q", p.proxy, "http://127.0.0.1:7890")
}
}
func TestExaSearchProvider_Success(t *testing.T) {
func TestWebTool_GLMSearch_Success(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method != "POST" {
t.Errorf("Expected POST request, got %s", r.Method)
@@ -770,130 +690,126 @@ func TestExaSearchProvider_Success(t *testing.T) {
if r.Header.Get("Content-Type") != "application/json" {
t.Errorf("Expected Content-Type application/json, got %s", r.Header.Get("Content-Type"))
}
if r.Header.Get("x-api-key") != "test-exa-key" {
t.Errorf("Expected x-api-key test-exa-key, got %s", r.Header.Get("x-api-key"))
if r.Header.Get("Authorization") != "Bearer test-glm-key" {
t.Errorf("Expected Authorization Bearer test-glm-key, got %s", r.Header.Get("Authorization"))
}
// Verify payload
body, _ := io.ReadAll(r.Body)
var payload map[string]any
json.Unmarshal(body, &payload)
if payload["query"] != "test query" {
t.Errorf("Expected query 'test query', got %v", payload["query"])
json.NewDecoder(r.Body).Decode(&payload)
if payload["search_query"] != "test query" {
t.Errorf("Expected search_query 'test query', got %v", payload["search_query"])
}
if payload["type"] != "neural" {
t.Errorf("Expected type 'neural', got %v", payload["type"])
if payload["search_engine"] != "search_std" {
t.Errorf("Expected search_engine 'search_std', got %v", payload["search_engine"])
}
response := map[string]any{
"results": []map[string]any{
{"title": "Exa Result 1", "url": "https://exa.ai/1", "text": "First result text"},
{"title": "Exa Result 2", "url": "https://exa.ai/2", "text": "Second result text"},
{"title": "Exa Result 3", "url": "https://exa.ai/3", "text": "Third result text"},
"id": "web-search-test",
"created": 1709568000,
"search_result": []map[string]any{
{
"title": "Test GLM Result",
"content": "GLM search snippet",
"link": "https://example.com/glm",
"media": "Example",
"publish_date": "2026-03-04",
},
},
}
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
json.NewEncoder(w).Encode(response)
}))
defer server.Close()
provider := &ExaSearchProvider{
apiKey: "test-exa-key",
client: &http.Client{},
}
// Temporarily override the API URL by using a custom transport
provider.client.Transport = rewriteHostTransport(server.URL)
result, err := provider.Search(context.Background(), "test query", 5)
if err != nil {
t.Fatalf("Search() error: %v", err)
}
if !strings.Contains(result, "via Exa") {
t.Errorf("Expected '(via Exa)' attribution, got: %s", result)
}
if !strings.Contains(result, "Exa Result 1") || !strings.Contains(result, "https://exa.ai/1") {
t.Errorf("Expected results in output, got: %s", result)
}
if !strings.Contains(result, "First result text") {
t.Errorf("Expected snippet text in output, got: %s", result)
}
}
func TestExaSearchProvider_EmptyResults(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
response := map[string]any{"results": []map[string]any{}}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(response)
}))
defer server.Close()
provider := &ExaSearchProvider{
apiKey: "test-key",
client: &http.Client{Transport: rewriteHostTransport(server.URL)},
}
result, err := provider.Search(context.Background(), "no results query", 5)
if err != nil {
t.Fatalf("Search() error: %v", err)
}
if !strings.Contains(result, "No results for: no results query") {
t.Errorf("Expected 'No results' message, got: %s", result)
}
}
// TestExaSearchProvider_MaxResultsCapping checks that Search renders no more
// results than the requested count, even when the API returns more.
func TestExaSearchProvider_MaxResultsCapping(t *testing.T) {
	// The fake API always serves five results, numbered from 1.
	const served = 5
	fiveResults := func(w http.ResponseWriter, _ *http.Request) {
		items := make([]map[string]any, 0, served)
		for n := 1; n <= served; n++ {
			items = append(items, map[string]any{
				"title": fmt.Sprintf("Result %d", n),
				"url":   fmt.Sprintf("https://exa.ai/%d", n),
				"text":  fmt.Sprintf("Text %d", n),
			})
		}
		w.Header().Set("Content-Type", "application/json")
		json.NewEncoder(w).Encode(map[string]any{"results": items})
	}
	srv := httptest.NewServer(http.HandlerFunc(fiveResults))
	defer srv.Close()

	exa := &ExaSearchProvider{
		apiKey: "test-key",
		client: &http.Client{Transport: rewriteHostTransport(srv.URL)},
	}

	// Ask for two results although five are available.
	out, err := exa.Search(context.Background(), "test", 2)
	if err != nil {
		t.Fatalf("Search() error: %v", err)
	}

	hasFirstTwo := strings.Contains(out, "Result 1") && strings.Contains(out, "Result 2")
	if !hasFirstTwo {
		t.Errorf("Expected first 2 results, got: %s", out)
	}
	if strings.Contains(out, "Result 3") {
		t.Errorf("Expected results capped at 2, but got Result 3 in output: %s", out)
	}
}
// rewriteHostTransport returns an http.RoundTripper that redirects all requests to the given target URL.
func rewriteHostTransport(target string) http.RoundTripper {
return roundTripFunc(func(req *http.Request) (*http.Response, error) {
newURL := target + req.URL.Path
newReq, err := http.NewRequestWithContext(req.Context(), req.Method, newURL, req.Body)
if err != nil {
return nil, err
}
newReq.Header = req.Header
return http.DefaultClient.Do(newReq)
tool, err := NewWebSearchTool(WebSearchToolOptions{
GLMSearchEnabled: true,
GLMSearchAPIKey: "test-glm-key",
GLMSearchBaseURL: server.URL,
GLMSearchEngine: "search_std",
})
if err != nil {
t.Fatalf("NewWebSearchTool() error: %v", err)
}
result := tool.Execute(context.Background(), map[string]any{
"query": "test query",
})
if result.IsError {
t.Errorf("Expected success, got IsError=true: %s", result.ForLLM)
}
if !strings.Contains(result.ForUser, "Test GLM Result") {
t.Errorf("Expected 'Test GLM Result' in output, got: %s", result.ForUser)
}
if !strings.Contains(result.ForUser, "https://example.com/glm") {
t.Errorf("Expected URL in output, got: %s", result.ForUser)
}
if !strings.Contains(result.ForUser, "via GLM Search") {
t.Errorf("Expected 'via GLM Search' in output, got: %s", result.ForUser)
}
}
type roundTripFunc func(*http.Request) (*http.Response, error)
func TestWebTool_GLMSearch_APIError(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(http.StatusUnauthorized)
w.Write([]byte(`{"error":"invalid api key"}`))
}))
defer server.Close()
func (f roundTripFunc) RoundTrip(req *http.Request) (*http.Response, error) {
return f(req)
tool, err := NewWebSearchTool(WebSearchToolOptions{
GLMSearchEnabled: true,
GLMSearchAPIKey: "bad-key",
GLMSearchBaseURL: server.URL,
GLMSearchEngine: "search_std",
})
if err != nil {
t.Fatalf("NewWebSearchTool() error: %v", err)
}
result := tool.Execute(context.Background(), map[string]any{
"query": "test query",
})
if !result.IsError {
t.Errorf("Expected IsError=true for 401 response")
}
if !strings.Contains(result.ForLLM, "status 401") {
t.Errorf("Expected status 401 in error, got: %s", result.ForLLM)
}
}
func TestWebTool_GLMSearch_Priority(t *testing.T) {
// GLM Search should only be selected when all other providers are disabled
tool, err := NewWebSearchTool(WebSearchToolOptions{
DuckDuckGoEnabled: true,
DuckDuckGoMaxResults: 5,
GLMSearchEnabled: true,
GLMSearchAPIKey: "test-key",
GLMSearchBaseURL: "https://example.com",
GLMSearchEngine: "search_std",
})
if err != nil {
t.Fatalf("NewWebSearchTool() error: %v", err)
}
// DuckDuckGo should win over GLM Search
if _, ok := tool.provider.(*DuckDuckGoSearchProvider); !ok {
t.Errorf("Expected DuckDuckGoSearchProvider when both enabled, got %T", tool.provider)
}
// With DuckDuckGo disabled, GLM Search should be selected
tool2, err := NewWebSearchTool(WebSearchToolOptions{
DuckDuckGoEnabled: false,
GLMSearchEnabled: true,
GLMSearchAPIKey: "test-key",
GLMSearchBaseURL: "https://example.com",
GLMSearchEngine: "search_std",
})
if err != nil {
t.Fatalf("NewWebSearchTool() error: %v", err)
}
if _, ok := tool2.provider.(*GLMSearchProvider); !ok {
t.Errorf("Expected GLMSearchProvider when only GLM enabled, got %T", tool2.provider)
}
}
+19 -4
View File
@@ -3,6 +3,7 @@ package utils
import (
"io"
"net/http"
"net/url"
"os"
"path/filepath"
"strings"
@@ -52,11 +53,12 @@ type DownloadOptions struct {
Timeout time.Duration
ExtraHeaders map[string]string
LoggerPrefix string
ProxyURL string
}
// DownloadFile downloads a file from URL to a local temp directory.
// Returns the local file path or empty string on error.
func DownloadFile(url, filename string, opts DownloadOptions) string {
func DownloadFile(urlStr, filename string, opts DownloadOptions) string {
// Set defaults
if opts.Timeout == 0 {
opts.Timeout = 60 * time.Second
@@ -78,7 +80,7 @@ func DownloadFile(url, filename string, opts DownloadOptions) string {
localPath := filepath.Join(mediaDir, uuid.New().String()[:8]+"_"+safeName)
// Create HTTP request
req, err := http.NewRequest("GET", url, nil)
req, err := http.NewRequest("GET", urlStr, nil)
if err != nil {
logger.ErrorCF(opts.LoggerPrefix, "Failed to create download request", map[string]any{
"error": err.Error(),
@@ -92,11 +94,24 @@ func DownloadFile(url, filename string, opts DownloadOptions) string {
}
client := &http.Client{Timeout: opts.Timeout}
if opts.ProxyURL != "" {
proxyURL, parseErr := url.Parse(opts.ProxyURL)
if parseErr != nil {
logger.ErrorCF(opts.LoggerPrefix, "Invalid proxy URL for download", map[string]any{
"error": parseErr.Error(),
"proxy": opts.ProxyURL,
})
return ""
}
client.Transport = &http.Transport{
Proxy: http.ProxyURL(proxyURL),
}
}
resp, err := client.Do(req)
if err != nil {
logger.ErrorCF(opts.LoggerPrefix, "Failed to download file", map[string]any{
"error": err.Error(),
"url": url,
"url": urlStr,
})
return ""
}
@@ -105,7 +120,7 @@ func DownloadFile(url, filename string, opts DownloadOptions) string {
if resp.StatusCode != http.StatusOK {
logger.ErrorCF(opts.LoggerPrefix, "File download returned non-200 status", map[string]any{
"status": resp.StatusCode,
"url": url,
"url": urlStr,
})
return ""
}