mirror of
https://github.com/sipeed/picoclaw.git
synced 2026-06-12 18:08:54 +00:00
merge: resolve conflicts with upstream/main
Accept upstream versions for all non-Telegram files to keep PR scope focused on Telegram message chunking only. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
+46
-32
@@ -18,22 +18,24 @@ import (
|
||||
// AgentInstance represents a fully configured agent with its own workspace,
|
||||
// session manager, context builder, and tool registry.
|
||||
type AgentInstance struct {
|
||||
ID string
|
||||
Name string
|
||||
Model string
|
||||
Fallbacks []string
|
||||
Workspace string
|
||||
MaxIterations int
|
||||
MaxTokens int
|
||||
Temperature float64
|
||||
ContextWindow int
|
||||
Provider providers.LLMProvider
|
||||
Sessions *session.SessionManager
|
||||
ContextBuilder *ContextBuilder
|
||||
Tools *tools.ToolRegistry
|
||||
Subagents *config.SubagentsConfig
|
||||
SkillsFilter []string
|
||||
Candidates []providers.FallbackCandidate
|
||||
ID string
|
||||
Name string
|
||||
Model string
|
||||
Fallbacks []string
|
||||
Workspace string
|
||||
MaxIterations int
|
||||
MaxTokens int
|
||||
Temperature float64
|
||||
ContextWindow int
|
||||
SummarizeMessageThreshold int
|
||||
SummarizeTokenPercent int
|
||||
Provider providers.LLMProvider
|
||||
Sessions *session.SessionManager
|
||||
ContextBuilder *ContextBuilder
|
||||
Tools *tools.ToolRegistry
|
||||
Subagents *config.SubagentsConfig
|
||||
SkillsFilter []string
|
||||
Candidates []providers.FallbackCandidate
|
||||
}
|
||||
|
||||
// NewAgentInstance creates an agent instance from config.
|
||||
@@ -101,6 +103,16 @@ func NewAgentInstance(
|
||||
temperature = *defaults.Temperature
|
||||
}
|
||||
|
||||
summarizeMessageThreshold := defaults.SummarizeMessageThreshold
|
||||
if summarizeMessageThreshold == 0 {
|
||||
summarizeMessageThreshold = 20
|
||||
}
|
||||
|
||||
summarizeTokenPercent := defaults.SummarizeTokenPercent
|
||||
if summarizeTokenPercent == 0 {
|
||||
summarizeTokenPercent = 75
|
||||
}
|
||||
|
||||
// Resolve fallback candidates
|
||||
modelCfg := providers.ModelConfig{
|
||||
Primary: model,
|
||||
@@ -149,22 +161,24 @@ func NewAgentInstance(
|
||||
candidates := providers.ResolveCandidatesWithLookup(modelCfg, defaults.Provider, resolveFromModelList)
|
||||
|
||||
return &AgentInstance{
|
||||
ID: agentID,
|
||||
Name: agentName,
|
||||
Model: model,
|
||||
Fallbacks: fallbacks,
|
||||
Workspace: workspace,
|
||||
MaxIterations: maxIter,
|
||||
MaxTokens: maxTokens,
|
||||
Temperature: temperature,
|
||||
ContextWindow: maxTokens,
|
||||
Provider: provider,
|
||||
Sessions: sessionsManager,
|
||||
ContextBuilder: contextBuilder,
|
||||
Tools: toolsRegistry,
|
||||
Subagents: subagents,
|
||||
SkillsFilter: skillsFilter,
|
||||
Candidates: candidates,
|
||||
ID: agentID,
|
||||
Name: agentName,
|
||||
Model: model,
|
||||
Fallbacks: fallbacks,
|
||||
Workspace: workspace,
|
||||
MaxIterations: maxIter,
|
||||
MaxTokens: maxTokens,
|
||||
Temperature: temperature,
|
||||
ContextWindow: maxTokens,
|
||||
SummarizeMessageThreshold: summarizeMessageThreshold,
|
||||
SummarizeTokenPercent: summarizeTokenPercent,
|
||||
Provider: provider,
|
||||
Sessions: sessionsManager,
|
||||
ContextBuilder: contextBuilder,
|
||||
Tools: toolsRegistry,
|
||||
Subagents: subagents,
|
||||
SkillsFilter: skillsFilter,
|
||||
Candidates: candidates,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
+66
-57
@@ -118,9 +118,11 @@ func registerSharedTools(
|
||||
PerplexityAPIKey: cfg.Tools.Web.Perplexity.APIKey,
|
||||
PerplexityMaxResults: cfg.Tools.Web.Perplexity.MaxResults,
|
||||
PerplexityEnabled: cfg.Tools.Web.Perplexity.Enabled,
|
||||
ExaAPIKey: cfg.Tools.Web.Exa.APIKey,
|
||||
ExaMaxResults: cfg.Tools.Web.Exa.MaxResults,
|
||||
ExaEnabled: cfg.Tools.Web.Exa.Enabled,
|
||||
GLMSearchAPIKey: cfg.Tools.Web.GLMSearch.APIKey,
|
||||
GLMSearchBaseURL: cfg.Tools.Web.GLMSearch.BaseURL,
|
||||
GLMSearchEngine: cfg.Tools.Web.GLMSearch.SearchEngine,
|
||||
GLMSearchMaxResults: cfg.Tools.Web.GLMSearch.MaxResults,
|
||||
GLMSearchEnabled: cfg.Tools.Web.GLMSearch.Enabled,
|
||||
Proxy: cfg.Tools.Web.Proxy,
|
||||
})
|
||||
if err != nil {
|
||||
@@ -967,62 +969,76 @@ func (al *AgentLoop) runLLMIteration(
|
||||
// Save assistant message with tool calls to session
|
||||
agent.Sessions.AddFullMessage(opts.SessionKey, assistantMsg)
|
||||
|
||||
// Execute tool calls
|
||||
for _, tc := range normalizedToolCalls {
|
||||
argsJSON, _ := json.Marshal(tc.Arguments)
|
||||
argsPreview := utils.Truncate(string(argsJSON), 200)
|
||||
logger.InfoCF("agent", fmt.Sprintf("Tool call: %s(%s)", tc.Name, argsPreview),
|
||||
map[string]any{
|
||||
"agent_id": agent.ID,
|
||||
"tool": tc.Name,
|
||||
"iteration": iteration,
|
||||
})
|
||||
// Execute tool calls in parallel
|
||||
type indexedAgentResult struct {
|
||||
result *tools.ToolResult
|
||||
tc providers.ToolCall
|
||||
}
|
||||
|
||||
// Create async callback for tools that implement AsyncTool
|
||||
// NOTE: Following openclaw's design, async tools do NOT send results directly to users.
|
||||
// Instead, they notify the agent via PublishInbound, and the agent decides
|
||||
// whether to forward the result to the user (in processSystemMessage).
|
||||
asyncCallback := func(callbackCtx context.Context, result *tools.ToolResult) {
|
||||
// Log the async completion but don't send directly to user
|
||||
// The agent will handle user notification via processSystemMessage
|
||||
if !result.Silent && result.ForUser != "" {
|
||||
logger.InfoCF("agent", "Async tool completed, agent will handle notification",
|
||||
map[string]any{
|
||||
"tool": tc.Name,
|
||||
"content_len": len(result.ForUser),
|
||||
})
|
||||
agentResults := make([]indexedAgentResult, len(normalizedToolCalls))
|
||||
var wg sync.WaitGroup
|
||||
|
||||
for i, tc := range normalizedToolCalls {
|
||||
agentResults[i].tc = tc
|
||||
|
||||
wg.Add(1)
|
||||
go func(idx int, tc providers.ToolCall) {
|
||||
defer wg.Done()
|
||||
|
||||
argsJSON, _ := json.Marshal(tc.Arguments)
|
||||
argsPreview := utils.Truncate(string(argsJSON), 200)
|
||||
logger.InfoCF("agent", fmt.Sprintf("Tool call: %s(%s)", tc.Name, argsPreview),
|
||||
map[string]any{
|
||||
"agent_id": agent.ID,
|
||||
"tool": tc.Name,
|
||||
"iteration": iteration,
|
||||
})
|
||||
|
||||
// Create async callback for tools that implement AsyncTool
|
||||
asyncCallback := func(callbackCtx context.Context, result *tools.ToolResult) {
|
||||
if !result.Silent && result.ForUser != "" {
|
||||
logger.InfoCF("agent", "Async tool completed, agent will handle notification",
|
||||
map[string]any{
|
||||
"tool": tc.Name,
|
||||
"content_len": len(result.ForUser),
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
toolResult := agent.Tools.ExecuteWithContext(
|
||||
ctx,
|
||||
tc.Name,
|
||||
tc.Arguments,
|
||||
opts.Channel,
|
||||
opts.ChatID,
|
||||
asyncCallback,
|
||||
)
|
||||
toolResult := agent.Tools.ExecuteWithContext(
|
||||
ctx,
|
||||
tc.Name,
|
||||
tc.Arguments,
|
||||
opts.Channel,
|
||||
opts.ChatID,
|
||||
asyncCallback,
|
||||
)
|
||||
agentResults[idx].result = toolResult
|
||||
}(i, tc)
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
// Process results in original order (send to user, save to session)
|
||||
for _, r := range agentResults {
|
||||
// Send ForUser content to user immediately if not Silent
|
||||
if !toolResult.Silent && toolResult.ForUser != "" && opts.SendResponse {
|
||||
if !r.result.Silent && r.result.ForUser != "" && opts.SendResponse {
|
||||
al.bus.PublishOutbound(ctx, bus.OutboundMessage{
|
||||
Channel: opts.Channel,
|
||||
ChatID: opts.ChatID,
|
||||
Content: toolResult.ForUser,
|
||||
Content: r.result.ForUser,
|
||||
})
|
||||
logger.DebugCF("agent", "Sent tool result to user",
|
||||
map[string]any{
|
||||
"tool": tc.Name,
|
||||
"content_len": len(toolResult.ForUser),
|
||||
"tool": r.tc.Name,
|
||||
"content_len": len(r.result.ForUser),
|
||||
})
|
||||
}
|
||||
|
||||
// If tool returned media refs, publish them as outbound media
|
||||
if len(toolResult.Media) > 0 && opts.SendResponse {
|
||||
parts := make([]bus.MediaPart, 0, len(toolResult.Media))
|
||||
for _, ref := range toolResult.Media {
|
||||
if len(r.result.Media) > 0 && opts.SendResponse {
|
||||
parts := make([]bus.MediaPart, 0, len(r.result.Media))
|
||||
for _, ref := range r.result.Media {
|
||||
part := bus.MediaPart{Ref: ref}
|
||||
// Populate metadata from MediaStore when available
|
||||
if al.mediaStore != nil {
|
||||
if _, meta, err := al.mediaStore.ResolveWithMeta(ref); err == nil {
|
||||
part.Filename = meta.Filename
|
||||
@@ -1040,15 +1056,15 @@ func (al *AgentLoop) runLLMIteration(
|
||||
}
|
||||
|
||||
// Determine content for LLM based on tool result
|
||||
contentForLLM := toolResult.ForLLM
|
||||
if contentForLLM == "" && toolResult.Err != nil {
|
||||
contentForLLM = toolResult.Err.Error()
|
||||
contentForLLM := r.result.ForLLM
|
||||
if contentForLLM == "" && r.result.Err != nil {
|
||||
contentForLLM = r.result.Err.Error()
|
||||
}
|
||||
|
||||
toolResultMsg := providers.Message{
|
||||
Role: "tool",
|
||||
Content: contentForLLM,
|
||||
ToolCallID: tc.ID,
|
||||
ToolCallID: r.tc.ID,
|
||||
}
|
||||
messages = append(messages, toolResultMsg)
|
||||
|
||||
@@ -1084,9 +1100,9 @@ func (al *AgentLoop) updateToolContexts(agent *AgentInstance, channel, chatID st
|
||||
func (al *AgentLoop) maybeSummarize(agent *AgentInstance, sessionKey, channel, chatID string) {
|
||||
newHistory := agent.Sessions.GetHistory(sessionKey)
|
||||
tokenEstimate := al.estimateTokens(newHistory)
|
||||
threshold := agent.ContextWindow * 75 / 100
|
||||
threshold := agent.ContextWindow * agent.SummarizeTokenPercent / 100
|
||||
|
||||
if len(newHistory) > 20 || tokenEstimate > threshold {
|
||||
if len(newHistory) > agent.SummarizeMessageThreshold || tokenEstimate > threshold {
|
||||
summarizeKey := agent.ID + ":" + sessionKey
|
||||
if _, loading := al.summarizing.LoadOrStore(summarizeKey, true); !loading {
|
||||
go func() {
|
||||
@@ -1114,15 +1130,8 @@ func (al *AgentLoop) forceCompression(agent *AgentInstance, sessionKey string) {
|
||||
return
|
||||
}
|
||||
|
||||
// Find the mid-point of the conversation, avoiding splitting tool call/result pairs.
|
||||
// A tool-call message (role=assistant with ToolCalls) must be followed by its
|
||||
// tool-result message (role=tool). Splitting between them causes API errors.
|
||||
// Helper to find the mid-point of the conversation
|
||||
mid := len(conversation) / 2
|
||||
if mid < len(conversation) && mid > 0 {
|
||||
if conversation[mid].Role == "tool" {
|
||||
mid++ // move past the tool result to keep the pair together
|
||||
}
|
||||
}
|
||||
|
||||
// New history structure:
|
||||
// 1. System Prompt (with compression note appended)
|
||||
|
||||
@@ -603,85 +603,6 @@ func TestAgentLoop_ContextExhaustionRetry(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// TestForceCompression_ToolMessageBoundary verifies that forceCompression does not
|
||||
// split a tool call/result pair when the midpoint falls on a "tool" role message.
|
||||
// Regression test for: API errors when orphaned tool result messages appear
|
||||
// without their preceding assistant tool-call message.
|
||||
func TestForceCompression_ToolMessageBoundary(t *testing.T) {
|
||||
tmpDir, err := os.MkdirTemp("", "agent-test-*")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temp dir: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
cfg := &config.Config{
|
||||
Agents: config.AgentsConfig{
|
||||
Defaults: config.AgentDefaults{
|
||||
Workspace: tmpDir,
|
||||
Model: "test-model",
|
||||
MaxTokens: 4096,
|
||||
MaxToolIterations: 10,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
msgBus := bus.NewMessageBus()
|
||||
provider := &mockProvider{}
|
||||
al := NewAgentLoop(cfg, msgBus, provider)
|
||||
|
||||
sessionKey := "test-session-tool-boundary"
|
||||
defaultAgent := al.registry.GetDefaultAgent()
|
||||
if defaultAgent == nil {
|
||||
t.Fatal("No default agent found")
|
||||
}
|
||||
|
||||
// Construct a history where len(conversation)/2 falls exactly on a "tool" message.
|
||||
// history = [system, user, assistant(tool_call), tool, user, assistant, user_trigger]
|
||||
// conversation = history[1:6] = [user, assistant(tool_call), tool, user, assistant]
|
||||
// len(conversation) = 5, mid = 5/2 = 2 => conversation[2].Role == "tool"
|
||||
// Without the fix, this would split between assistant(tool_call) and tool result.
|
||||
history := []providers.Message{
|
||||
{Role: "system", Content: "You are a helpful assistant."},
|
||||
{Role: "user", Content: "What files are in the current directory?"},
|
||||
{Role: "assistant", Content: "", ToolCalls: []providers.ToolCall{
|
||||
{ID: "call_1", Name: "exec", Arguments: map[string]any{"command": "ls"}},
|
||||
}},
|
||||
{Role: "tool", Content: "file1.txt\nfile2.txt", ToolCallID: "call_1"},
|
||||
{Role: "user", Content: "Tell me about file1.txt"},
|
||||
{Role: "assistant", Content: "file1.txt is a text file."},
|
||||
{Role: "user", Content: "Thanks"}, // trigger message
|
||||
}
|
||||
|
||||
// Create the session first (AddMessage creates the session entry),
|
||||
// then overwrite with our full history via SetHistory.
|
||||
defaultAgent.Sessions.AddMessage(sessionKey, "system", "init")
|
||||
defaultAgent.Sessions.SetHistory(sessionKey, history)
|
||||
|
||||
// Call forceCompression
|
||||
al.forceCompression(defaultAgent, sessionKey)
|
||||
|
||||
// Verify the result
|
||||
compressed := defaultAgent.Sessions.GetHistory(sessionKey)
|
||||
|
||||
// Check that no message with role="tool" is the first conversation message
|
||||
// (after the system prompt). If it is, it means the tool result was orphaned.
|
||||
for i := 1; i < len(compressed); i++ {
|
||||
if compressed[i].Role == "tool" {
|
||||
// There must be an assistant message with tool calls before it
|
||||
if i == 1 {
|
||||
t.Errorf("Tool result message at position %d is orphaned (no preceding assistant with tool call)", i)
|
||||
} else if compressed[i-1].Role != "assistant" || len(compressed[i-1].ToolCalls) == 0 {
|
||||
t.Errorf("Tool result at position %d is not preceded by assistant with tool calls (preceded by role=%q)", i, compressed[i-1].Role)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Verify the system prompt has the compression note
|
||||
if !strings.Contains(compressed[0].Content, "Emergency compression") {
|
||||
t.Errorf("Expected compression note in system prompt, got: %s", compressed[0].Content)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTargetReasoningChannelID_AllChannels(t *testing.T) {
|
||||
tmpDir, err := os.MkdirTemp("", "agent-test-*")
|
||||
if err != nil {
|
||||
|
||||
@@ -3,12 +3,15 @@ package discord
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/bwmarrin/discordgo"
|
||||
"github.com/gorilla/websocket"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/bus"
|
||||
"github.com/sipeed/picoclaw/pkg/channels"
|
||||
@@ -40,6 +43,9 @@ func NewDiscordChannel(cfg config.DiscordConfig, bus *bus.MessageBus) (*DiscordC
|
||||
return nil, fmt.Errorf("failed to create discord session: %w", err)
|
||||
}
|
||||
|
||||
if err := applyDiscordProxy(session, cfg.Proxy); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
base := channels.NewBaseChannel("discord", cfg, bus, cfg.AllowFrom,
|
||||
channels.WithMaxMessageLength(2000),
|
||||
channels.WithGroupTrigger(cfg.GroupTrigger),
|
||||
@@ -465,9 +471,43 @@ func (c *DiscordChannel) StartTyping(ctx context.Context, chatID string) (func()
|
||||
func (c *DiscordChannel) downloadAttachment(url, filename string) string {
|
||||
return utils.DownloadFile(url, filename, utils.DownloadOptions{
|
||||
LoggerPrefix: "discord",
|
||||
ProxyURL: c.config.Proxy,
|
||||
})
|
||||
}
|
||||
|
||||
func applyDiscordProxy(session *discordgo.Session, proxyAddr string) error {
|
||||
var proxyFunc func(*http.Request) (*url.URL, error)
|
||||
if proxyAddr != "" {
|
||||
proxyURL, err := url.Parse(proxyAddr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid discord proxy URL %q: %w", proxyAddr, err)
|
||||
}
|
||||
proxyFunc = http.ProxyURL(proxyURL)
|
||||
} else if os.Getenv("HTTP_PROXY") != "" || os.Getenv("HTTPS_PROXY") != "" {
|
||||
proxyFunc = http.ProxyFromEnvironment
|
||||
}
|
||||
|
||||
if proxyFunc == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
transport := &http.Transport{Proxy: proxyFunc}
|
||||
session.Client = &http.Client{
|
||||
Timeout: sendTimeout,
|
||||
Transport: transport,
|
||||
}
|
||||
|
||||
if session.Dialer != nil {
|
||||
dialerCopy := *session.Dialer
|
||||
dialerCopy.Proxy = proxyFunc
|
||||
session.Dialer = &dialerCopy
|
||||
} else {
|
||||
session.Dialer = &websocket.Dialer{Proxy: proxyFunc}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// stripBotMention removes the bot mention from the message content.
|
||||
// Discord mentions have the format <@USER_ID> or <@!USER_ID> (with nickname).
|
||||
func (c *DiscordChannel) stripBotMention(text string) string {
|
||||
|
||||
@@ -0,0 +1,91 @@
|
||||
package discord
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/url"
|
||||
"testing"
|
||||
|
||||
"github.com/bwmarrin/discordgo"
|
||||
)
|
||||
|
||||
func TestApplyDiscordProxy_CustomProxy(t *testing.T) {
|
||||
session, err := discordgo.New("Bot test-token")
|
||||
if err != nil {
|
||||
t.Fatalf("discordgo.New() error: %v", err)
|
||||
}
|
||||
|
||||
if err = applyDiscordProxy(session, "http://127.0.0.1:7890"); err != nil {
|
||||
t.Fatalf("applyDiscordProxy() error: %v", err)
|
||||
}
|
||||
|
||||
req, err := http.NewRequest("GET", "https://discord.com/api/v10/gateway", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("http.NewRequest() error: %v", err)
|
||||
}
|
||||
|
||||
restProxy := session.Client.Transport.(*http.Transport).Proxy
|
||||
restProxyURL, err := restProxy(req)
|
||||
if err != nil {
|
||||
t.Fatalf("rest proxy func error: %v", err)
|
||||
}
|
||||
if got, want := restProxyURL.String(), "http://127.0.0.1:7890"; got != want {
|
||||
t.Fatalf("REST proxy = %q, want %q", got, want)
|
||||
}
|
||||
|
||||
wsProxyURL, err := session.Dialer.Proxy(req)
|
||||
if err != nil {
|
||||
t.Fatalf("ws proxy func error: %v", err)
|
||||
}
|
||||
if got, want := wsProxyURL.String(), "http://127.0.0.1:7890"; got != want {
|
||||
t.Fatalf("WS proxy = %q, want %q", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyDiscordProxy_FromEnvironment(t *testing.T) {
|
||||
t.Setenv("HTTP_PROXY", "http://127.0.0.1:8888")
|
||||
t.Setenv("http_proxy", "http://127.0.0.1:8888")
|
||||
t.Setenv("HTTPS_PROXY", "http://127.0.0.1:8888")
|
||||
t.Setenv("https_proxy", "http://127.0.0.1:8888")
|
||||
t.Setenv("ALL_PROXY", "")
|
||||
t.Setenv("all_proxy", "")
|
||||
t.Setenv("NO_PROXY", "")
|
||||
t.Setenv("no_proxy", "")
|
||||
|
||||
session, err := discordgo.New("Bot test-token")
|
||||
if err != nil {
|
||||
t.Fatalf("discordgo.New() error: %v", err)
|
||||
}
|
||||
|
||||
if err = applyDiscordProxy(session, ""); err != nil {
|
||||
t.Fatalf("applyDiscordProxy() error: %v", err)
|
||||
}
|
||||
|
||||
req, err := http.NewRequest("GET", "https://discord.com/api/v10/gateway", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("http.NewRequest() error: %v", err)
|
||||
}
|
||||
|
||||
gotURL, err := session.Dialer.Proxy(req)
|
||||
if err != nil {
|
||||
t.Fatalf("ws proxy func error: %v", err)
|
||||
}
|
||||
|
||||
wantURL, err := url.Parse("http://127.0.0.1:8888")
|
||||
if err != nil {
|
||||
t.Fatalf("url.Parse() error: %v", err)
|
||||
}
|
||||
if gotURL.String() != wantURL.String() {
|
||||
t.Fatalf("WS proxy = %q, want %q", gotURL.String(), wantURL.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyDiscordProxy_InvalidProxyURL(t *testing.T) {
|
||||
session, err := discordgo.New("Bot test-token")
|
||||
if err != nil {
|
||||
t.Fatalf("discordgo.New() error: %v", err)
|
||||
}
|
||||
|
||||
if err = applyDiscordProxy(session, "://bad-proxy"); err == nil {
|
||||
t.Fatal("applyDiscordProxy() expected error for invalid proxy URL, got nil")
|
||||
}
|
||||
}
|
||||
+13
-85
@@ -3,23 +3,14 @@ package config
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/caarlos0/env/v11"
|
||||
"github.com/joho/godotenv"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/fileutil"
|
||||
)
|
||||
|
||||
// dotenvOnce ensures .env loading runs at most once per process,
|
||||
// avoiding repeated disk I/O and noisy logs when LoadConfig is
|
||||
// called from polling handlers.
|
||||
var dotenvOnce sync.Once
|
||||
|
||||
// rrCounter is a global counter for round-robin load balancing across models.
|
||||
var rrCounter atomic.Uint64
|
||||
|
||||
@@ -189,6 +180,8 @@ type AgentDefaults struct {
|
||||
MaxTokens int `json:"max_tokens" env:"PICOCLAW_AGENTS_DEFAULTS_MAX_TOKENS"`
|
||||
Temperature *float64 `json:"temperature,omitempty" env:"PICOCLAW_AGENTS_DEFAULTS_TEMPERATURE"`
|
||||
MaxToolIterations int `json:"max_tool_iterations" env:"PICOCLAW_AGENTS_DEFAULTS_MAX_TOOL_ITERATIONS"`
|
||||
SummarizeMessageThreshold int `json:"summarize_message_threshold" env:"PICOCLAW_AGENTS_DEFAULTS_SUMMARIZE_MESSAGE_THRESHOLD"`
|
||||
SummarizeTokenPercent int `json:"summarize_token_percent" env:"PICOCLAW_AGENTS_DEFAULTS_SUMMARIZE_TOKEN_PERCENT"`
|
||||
MaxMediaSize int `json:"max_media_size,omitempty" env:"PICOCLAW_AGENTS_DEFAULTS_MAX_MEDIA_SIZE"`
|
||||
}
|
||||
|
||||
@@ -280,6 +273,7 @@ type FeishuConfig struct {
|
||||
type DiscordConfig struct {
|
||||
Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_DISCORD_ENABLED"`
|
||||
Token string `json:"token" env:"PICOCLAW_CHANNELS_DISCORD_TOKEN"`
|
||||
Proxy string `json:"proxy" env:"PICOCLAW_CHANNELS_DISCORD_PROXY"`
|
||||
AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_DISCORD_ALLOW_FROM"`
|
||||
MentionOnly bool `json:"mention_only" env:"PICOCLAW_CHANNELS_DISCORD_MENTION_ONLY"`
|
||||
GroupTrigger GroupTriggerConfig `json:"group_trigger,omitempty"`
|
||||
@@ -437,7 +431,6 @@ type ProvidersConfig struct {
|
||||
Antigravity ProviderConfig `json:"antigravity"`
|
||||
Qwen ProviderConfig `json:"qwen"`
|
||||
Mistral ProviderConfig `json:"mistral"`
|
||||
Opencode ProviderConfig `json:"opencode"`
|
||||
}
|
||||
|
||||
// IsEmpty checks if all provider configs are empty (no API keys or API bases set)
|
||||
@@ -461,8 +454,7 @@ func (p ProvidersConfig) IsEmpty() bool {
|
||||
p.GitHubCopilot.APIKey == "" && p.GitHubCopilot.APIBase == "" &&
|
||||
p.Antigravity.APIKey == "" && p.Antigravity.APIBase == "" &&
|
||||
p.Qwen.APIKey == "" && p.Qwen.APIBase == "" &&
|
||||
p.Mistral.APIKey == "" && p.Mistral.APIBase == "" &&
|
||||
p.Opencode.APIKey == "" && p.Opencode.APIBase == ""
|
||||
p.Mistral.APIKey == "" && p.Mistral.APIBase == ""
|
||||
}
|
||||
|
||||
// MarshalJSON implements custom JSON marshaling for ProvidersConfig
|
||||
@@ -555,10 +547,14 @@ type PerplexityConfig struct {
|
||||
MaxResults int `json:"max_results" env:"PICOCLAW_TOOLS_WEB_PERPLEXITY_MAX_RESULTS"`
|
||||
}
|
||||
|
||||
type ExaConfig struct {
|
||||
Enabled bool `json:"enabled" env:"PICOCLAW_TOOLS_WEB_EXA_ENABLED"`
|
||||
APIKey string `json:"api_key" env:"PICOCLAW_TOOLS_WEB_EXA_API_KEY"`
|
||||
MaxResults int `json:"max_results" env:"PICOCLAW_TOOLS_WEB_EXA_MAX_RESULTS"`
|
||||
type GLMSearchConfig struct {
|
||||
Enabled bool `json:"enabled" env:"PICOCLAW_TOOLS_WEB_GLM_ENABLED"`
|
||||
APIKey string `json:"api_key" env:"PICOCLAW_TOOLS_WEB_GLM_API_KEY"`
|
||||
BaseURL string `json:"base_url" env:"PICOCLAW_TOOLS_WEB_GLM_BASE_URL"`
|
||||
// SearchEngine specifies the search backend: "search_std" (default),
|
||||
// "search_pro", "search_pro_sogou", or "search_pro_quark".
|
||||
SearchEngine string `json:"search_engine" env:"PICOCLAW_TOOLS_WEB_GLM_SEARCH_ENGINE"`
|
||||
MaxResults int `json:"max_results" env:"PICOCLAW_TOOLS_WEB_GLM_MAX_RESULTS"`
|
||||
}
|
||||
|
||||
type WebToolsConfig struct {
|
||||
@@ -566,7 +562,7 @@ type WebToolsConfig struct {
|
||||
Tavily TavilyConfig `json:"tavily"`
|
||||
DuckDuckGo DuckDuckGoConfig `json:"duckduckgo"`
|
||||
Perplexity PerplexityConfig `json:"perplexity"`
|
||||
Exa ExaConfig `json:"exa"`
|
||||
GLMSearch GLMSearchConfig `json:"glm_search"`
|
||||
// Proxy is an optional proxy URL for web tools (http/https/socks5/socks5h).
|
||||
// For authenticated proxies, prefer HTTP_PROXY/HTTPS_PROXY env vars instead of embedding credentials in config.
|
||||
Proxy string `json:"proxy,omitempty" env:"PICOCLAW_TOOLS_WEB_PROXY"`
|
||||
@@ -658,35 +654,9 @@ type MCPConfig struct {
|
||||
func LoadConfig(path string) (*Config, error) {
|
||||
cfg := DefaultConfig()
|
||||
|
||||
// Load .env file from config directory (secrets, API keys, etc.)
|
||||
// Guarded by sync.Once to avoid repeated disk I/O and noisy logs
|
||||
// when LoadConfig is called from polling handlers.
|
||||
dotenvOnce.Do(func() {
|
||||
envFile := filepath.Join(filepath.Dir(path), ".env")
|
||||
if err := godotenv.Load(envFile); err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
log.Printf("[INFO] No .env file found at %s; skipping .env loading", envFile)
|
||||
} else {
|
||||
log.Printf("[WARN] Failed to load .env file from %s: %v", envFile, err)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
// No config file — still apply env vars + overrides to default config
|
||||
if err := env.Parse(cfg); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
loadProviderEnvOverrides(cfg)
|
||||
cfg.migrateChannelConfigs()
|
||||
if cfg.HasProvidersConfig() {
|
||||
cfg.ModelList = ConvertProvidersToModelList(cfg)
|
||||
}
|
||||
if err := cfg.ValidateModelList(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return cfg, nil
|
||||
}
|
||||
return nil, err
|
||||
@@ -714,9 +684,6 @@ func LoadConfig(path string) (*Config, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Load provider-specific env overrides (PICOCLAW_PROVIDERS_<NAME>_API_KEY, etc.)
|
||||
loadProviderEnvOverrides(cfg)
|
||||
|
||||
// Migrate legacy channel config fields to new unified structures
|
||||
cfg.migrateChannelConfigs()
|
||||
|
||||
@@ -865,42 +832,3 @@ func (c *Config) ValidateModelList() error {
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// loadProviderEnvOverrides reads PICOCLAW_PROVIDERS_<NAME>_API_KEY and _API_BASE
|
||||
// environment variables and sets them on the corresponding provider config fields.
|
||||
// This enables storing provider secrets in .env files without using struct tags.
|
||||
func loadProviderEnvOverrides(cfg *Config) {
|
||||
providers := []struct {
|
||||
name string
|
||||
apiKey *string
|
||||
base *string
|
||||
}{
|
||||
{"ANTHROPIC", &cfg.Providers.Anthropic.APIKey, &cfg.Providers.Anthropic.APIBase},
|
||||
{"OPENAI", &cfg.Providers.OpenAI.APIKey, &cfg.Providers.OpenAI.APIBase},
|
||||
{"LITELLM", &cfg.Providers.LiteLLM.APIKey, &cfg.Providers.LiteLLM.APIBase},
|
||||
{"OPENROUTER", &cfg.Providers.OpenRouter.APIKey, &cfg.Providers.OpenRouter.APIBase},
|
||||
{"GROQ", &cfg.Providers.Groq.APIKey, &cfg.Providers.Groq.APIBase},
|
||||
{"ZHIPU", &cfg.Providers.Zhipu.APIKey, &cfg.Providers.Zhipu.APIBase},
|
||||
{"GEMINI", &cfg.Providers.Gemini.APIKey, &cfg.Providers.Gemini.APIBase},
|
||||
{"NVIDIA", &cfg.Providers.Nvidia.APIKey, &cfg.Providers.Nvidia.APIBase},
|
||||
{"OLLAMA", &cfg.Providers.Ollama.APIKey, &cfg.Providers.Ollama.APIBase},
|
||||
{"MOONSHOT", &cfg.Providers.Moonshot.APIKey, &cfg.Providers.Moonshot.APIBase},
|
||||
{"SHENGSUANYUN", &cfg.Providers.ShengSuanYun.APIKey, &cfg.Providers.ShengSuanYun.APIBase},
|
||||
{"DEEPSEEK", &cfg.Providers.DeepSeek.APIKey, &cfg.Providers.DeepSeek.APIBase},
|
||||
{"MISTRAL", &cfg.Providers.Mistral.APIKey, &cfg.Providers.Mistral.APIBase},
|
||||
{"VLLM", &cfg.Providers.VLLM.APIKey, &cfg.Providers.VLLM.APIBase},
|
||||
{"CEREBRAS", &cfg.Providers.Cerebras.APIKey, &cfg.Providers.Cerebras.APIBase},
|
||||
{"VOLCENGINE", &cfg.Providers.VolcEngine.APIKey, &cfg.Providers.VolcEngine.APIBase},
|
||||
{"QWEN", &cfg.Providers.Qwen.APIKey, &cfg.Providers.Qwen.APIBase},
|
||||
// Note: GitHubCopilot and Antigravity use different auth patterns (ConnectMode/AuthMethod),
|
||||
// not standard APIKey/APIBase, so they are not included here.
|
||||
}
|
||||
for _, p := range providers {
|
||||
if v, ok := os.LookupEnv("PICOCLAW_PROVIDERS_" + p.name + "_API_KEY"); ok {
|
||||
*p.apiKey = v
|
||||
}
|
||||
if v, ok := os.LookupEnv("PICOCLAW_PROVIDERS_" + p.name + "_API_BASE"); ok {
|
||||
*p.base = v
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
+12
-96
@@ -6,7 +6,6 @@ import (
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
)
|
||||
|
||||
@@ -436,6 +435,18 @@ func TestLoadConfig_WebToolsProxy(t *testing.T) {
|
||||
}
|
||||
|
||||
// TestDefaultConfig_DMScope verifies the default dm_scope value
|
||||
// TestDefaultConfig_SummarizationThresholds verifies summarization defaults
|
||||
func TestDefaultConfig_SummarizationThresholds(t *testing.T) {
|
||||
cfg := DefaultConfig()
|
||||
|
||||
if cfg.Agents.Defaults.SummarizeMessageThreshold != 20 {
|
||||
t.Errorf("SummarizeMessageThreshold = %d, want 20", cfg.Agents.Defaults.SummarizeMessageThreshold)
|
||||
}
|
||||
if cfg.Agents.Defaults.SummarizeTokenPercent != 75 {
|
||||
t.Errorf("SummarizeTokenPercent = %d, want 75", cfg.Agents.Defaults.SummarizeTokenPercent)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDefaultConfig_DMScope(t *testing.T) {
|
||||
cfg := DefaultConfig()
|
||||
|
||||
@@ -468,98 +479,3 @@ func TestDefaultConfig_WorkspacePath_WithPicoclawHome(t *testing.T) {
|
||||
t.Errorf("Workspace path with PICOCLAW_HOME = %q, want %q", cfg.Agents.Defaults.Workspace, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadConfig_DotenvFileLoaded(t *testing.T) {
|
||||
// Reset sync.Once so .env loading runs for this test
|
||||
dotenvOnce = sync.Once{}
|
||||
|
||||
dir := t.TempDir()
|
||||
configPath := filepath.Join(dir, "config.json")
|
||||
|
||||
// Write a minimal config.json
|
||||
if err := os.WriteFile(configPath, []byte(`{}`), 0o600); err != nil {
|
||||
t.Fatalf("WriteFile config: %v", err)
|
||||
}
|
||||
|
||||
// Write a .env file with a provider API key
|
||||
envFile := filepath.Join(dir, ".env")
|
||||
if err := os.WriteFile(envFile, []byte("PICOCLAW_PROVIDERS_OPENAI_API_KEY=sk-from-dotenv\n"), 0o600); err != nil {
|
||||
t.Fatalf("WriteFile .env: %v", err)
|
||||
}
|
||||
|
||||
// Clear the env var first to ensure it comes from .env
|
||||
t.Setenv("PICOCLAW_PROVIDERS_OPENAI_API_KEY", "")
|
||||
os.Unsetenv("PICOCLAW_PROVIDERS_OPENAI_API_KEY")
|
||||
|
||||
cfg, err := LoadConfig(configPath)
|
||||
if err != nil {
|
||||
t.Fatalf("LoadConfig() error: %v", err)
|
||||
}
|
||||
|
||||
if cfg.Providers.OpenAI.APIKey != "sk-from-dotenv" {
|
||||
t.Errorf("OpenAI.APIKey = %q, want %q", cfg.Providers.OpenAI.APIKey, "sk-from-dotenv")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadConfig_MissingConfigJSON_AppliesEnvVars(t *testing.T) {
|
||||
// Reset sync.Once so .env loading runs for this test
|
||||
dotenvOnce = sync.Once{}
|
||||
|
||||
dir := t.TempDir()
|
||||
configPath := filepath.Join(dir, "config.json") // does NOT exist
|
||||
|
||||
t.Setenv("PICOCLAW_PROVIDERS_ANTHROPIC_API_KEY", "sk-anthropic-test")
|
||||
|
||||
cfg, err := LoadConfig(configPath)
|
||||
if err != nil {
|
||||
t.Fatalf("LoadConfig() error: %v", err)
|
||||
}
|
||||
|
||||
if cfg.Providers.Anthropic.APIKey != "sk-anthropic-test" {
|
||||
t.Errorf("Anthropic.APIKey = %q, want %q", cfg.Providers.Anthropic.APIKey, "sk-anthropic-test")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadConfig_MalformedDotenv_NonFatal(t *testing.T) {
|
||||
// Reset sync.Once so .env loading runs for this test
|
||||
dotenvOnce = sync.Once{}
|
||||
|
||||
dir := t.TempDir()
|
||||
configPath := filepath.Join(dir, "config.json")
|
||||
|
||||
// Write a minimal config.json
|
||||
if err := os.WriteFile(configPath, []byte(`{}`), 0o600); err != nil {
|
||||
t.Fatalf("WriteFile config: %v", err)
|
||||
}
|
||||
|
||||
// Write a .env file with genuinely malformed content (bare key without '=',
|
||||
// mixed with a valid line) to verify godotenv.Load errors are non-fatal.
|
||||
envFile := filepath.Join(dir, ".env")
|
||||
if err := os.WriteFile(envFile, []byte("THIS_LINE_HAS_NO_EQUALS\nVALID_KEY=valid_value\n"), 0o600); err != nil {
|
||||
t.Fatalf("WriteFile .env: %v", err)
|
||||
}
|
||||
|
||||
// LoadConfig should not fail even with malformed .env content
|
||||
cfg, err := LoadConfig(configPath)
|
||||
if err != nil {
|
||||
t.Fatalf("LoadConfig() should not fail with .env issues, got error: %v", err)
|
||||
}
|
||||
if cfg == nil {
|
||||
t.Fatal("LoadConfig() returned nil config")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadProviderEnvOverrides_LookupEnv(t *testing.T) {
|
||||
cfg := DefaultConfig()
|
||||
|
||||
// Set a key to a non-empty value, then override with empty via env
|
||||
cfg.Providers.OpenRouter.APIBase = "https://original.com"
|
||||
t.Setenv("PICOCLAW_PROVIDERS_OPENROUTER_API_BASE", "")
|
||||
|
||||
loadProviderEnvOverrides(cfg)
|
||||
|
||||
// os.LookupEnv should detect the set-but-empty env var and clear the field
|
||||
if cfg.Providers.OpenRouter.APIBase != "" {
|
||||
t.Errorf("OpenRouter.APIBase = %q, want empty (overridden by empty env var)", cfg.Providers.OpenRouter.APIBase)
|
||||
}
|
||||
}
|
||||
|
||||
+15
-11
@@ -26,13 +26,15 @@ func DefaultConfig() *Config {
|
||||
return &Config{
|
||||
Agents: AgentsConfig{
|
||||
Defaults: AgentDefaults{
|
||||
Workspace: workspacePath,
|
||||
RestrictToWorkspace: true,
|
||||
Provider: "",
|
||||
Model: "",
|
||||
MaxTokens: 32768,
|
||||
Temperature: nil, // nil means use provider default
|
||||
MaxToolIterations: 50,
|
||||
Workspace: workspacePath,
|
||||
RestrictToWorkspace: true,
|
||||
Provider: "",
|
||||
Model: "",
|
||||
MaxTokens: 32768,
|
||||
Temperature: nil, // nil means use provider default
|
||||
MaxToolIterations: 50,
|
||||
SummarizeMessageThreshold: 20,
|
||||
SummarizeTokenPercent: 75,
|
||||
},
|
||||
},
|
||||
Bindings: []AgentBinding{},
|
||||
@@ -341,10 +343,12 @@ func DefaultConfig() *Config {
|
||||
APIKey: "",
|
||||
MaxResults: 5,
|
||||
},
|
||||
Exa: ExaConfig{
|
||||
Enabled: false,
|
||||
APIKey: "",
|
||||
MaxResults: 5,
|
||||
GLMSearch: GLMSearchConfig{
|
||||
Enabled: false,
|
||||
APIKey: "",
|
||||
BaseURL: "https://open.bigmodel.cn/api/paas/v4/web_search",
|
||||
SearchEngine: "search_std",
|
||||
MaxResults: 5,
|
||||
},
|
||||
},
|
||||
Cron: CronToolsConfig{
|
||||
|
||||
+2
-21
@@ -225,7 +225,7 @@ func ConvertProvidersToModelList(cfg *Config) []ModelConfig {
|
||||
},
|
||||
},
|
||||
{
|
||||
providerNames: []string{"moonshot", "kimi", "kimi-code"},
|
||||
providerNames: []string{"moonshot", "kimi"},
|
||||
protocol: "moonshot",
|
||||
buildConfig: func(p ProvidersConfig) (ModelConfig, bool) {
|
||||
if p.Moonshot.APIKey == "" && p.Moonshot.APIBase == "" {
|
||||
@@ -373,23 +373,6 @@ func ConvertProvidersToModelList(cfg *Config) []ModelConfig {
|
||||
}, true
|
||||
},
|
||||
},
|
||||
{
|
||||
providerNames: []string{"opencode"},
|
||||
protocol: "opencode",
|
||||
buildConfig: func(p ProvidersConfig) (ModelConfig, bool) {
|
||||
if p.Opencode.APIKey == "" && p.Opencode.APIBase == "" {
|
||||
return ModelConfig{}, false
|
||||
}
|
||||
return ModelConfig{
|
||||
ModelName: "opencode",
|
||||
Model: "opencode/auto",
|
||||
APIKey: p.Opencode.APIKey,
|
||||
APIBase: p.Opencode.APIBase,
|
||||
Proxy: p.Opencode.Proxy,
|
||||
RequestTimeout: p.Opencode.RequestTimeout,
|
||||
}, true
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// Process each provider migration
|
||||
@@ -401,9 +384,7 @@ func ConvertProvidersToModelList(cfg *Config) []ModelConfig {
|
||||
|
||||
// Check if this is the user's configured provider
|
||||
if slices.Contains(m.providerNames, userProvider) && userModel != "" {
|
||||
// Use the user's configured model instead of default.
|
||||
// Also set ModelName so GetModelConfig(userModel) can find this entry.
|
||||
mc.ModelName = userModel
|
||||
// Use the user's configured model instead of default
|
||||
mc.Model = buildModelWithProtocol(m.protocol, userModel)
|
||||
} else if userProvider == "" && userModel != "" && !legacyModelNameApplied {
|
||||
// Legacy config: no explicit provider field but model is specified
|
||||
|
||||
@@ -160,7 +160,6 @@ func TestConvertProvidersToModelList_AllProviders(t *testing.T) {
|
||||
Antigravity: ProviderConfig{AuthMethod: "oauth"},
|
||||
Qwen: ProviderConfig{APIKey: "key17"},
|
||||
Mistral: ProviderConfig{APIKey: "key18"},
|
||||
Opencode: ProviderConfig{APIKey: "key19"},
|
||||
},
|
||||
}
|
||||
|
||||
@@ -580,65 +579,6 @@ func TestBuildModelWithProtocol_DifferentPrefix(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertProvidersToModelList_Opencode(t *testing.T) {
|
||||
cfg := &Config{
|
||||
Providers: ProvidersConfig{
|
||||
Opencode: ProviderConfig{
|
||||
APIKey: "oc-test-key",
|
||||
APIBase: "https://custom.opencode.ai/v1",
|
||||
Proxy: "http://proxy:9090",
|
||||
RequestTimeout: 60,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
result := ConvertProvidersToModelList(cfg)
|
||||
|
||||
if len(result) != 1 {
|
||||
t.Fatalf("len(result) = %d, want 1", len(result))
|
||||
}
|
||||
|
||||
mc := result[0]
|
||||
if mc.ModelName != "opencode" {
|
||||
t.Errorf("ModelName = %q, want %q", mc.ModelName, "opencode")
|
||||
}
|
||||
if mc.Model != "opencode/auto" {
|
||||
t.Errorf("Model = %q, want %q", mc.Model, "opencode/auto")
|
||||
}
|
||||
if mc.APIKey != "oc-test-key" {
|
||||
t.Errorf("APIKey = %q, want %q", mc.APIKey, "oc-test-key")
|
||||
}
|
||||
if mc.APIBase != "https://custom.opencode.ai/v1" {
|
||||
t.Errorf("APIBase = %q, want %q", mc.APIBase, "https://custom.opencode.ai/v1")
|
||||
}
|
||||
if mc.Proxy != "http://proxy:9090" {
|
||||
t.Errorf("Proxy = %q, want %q", mc.Proxy, "http://proxy:9090")
|
||||
}
|
||||
if mc.RequestTimeout != 60 {
|
||||
t.Errorf("RequestTimeout = %d, want %d", mc.RequestTimeout, 60)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertProvidersToModelList_Opencode_APIBaseOnly(t *testing.T) {
|
||||
cfg := &Config{
|
||||
Providers: ProvidersConfig{
|
||||
Opencode: ProviderConfig{
|
||||
APIBase: "https://custom.opencode.ai/v1",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
result := ConvertProvidersToModelList(cfg)
|
||||
|
||||
if len(result) != 1 {
|
||||
t.Fatalf("len(result) = %d, want 1 (APIBase-only should create entry)", len(result))
|
||||
}
|
||||
|
||||
if result[0].ModelName != "opencode" {
|
||||
t.Errorf("ModelName = %q, want %q", result[0].ModelName, "opencode")
|
||||
}
|
||||
}
|
||||
|
||||
// Test for legacy config with protocol prefix in model name
|
||||
func TestConvertProvidersToModelList_LegacyModelWithProtocolPrefix(t *testing.T) {
|
||||
cfg := &Config{
|
||||
@@ -669,72 +609,3 @@ func TestConvertProvidersToModelList_LegacyModelWithProtocolPrefix(t *testing.T)
|
||||
t.Errorf("Model = %q, want %q (should not duplicate prefix)", result[0].Model, "openrouter/auto")
|
||||
}
|
||||
}
|
||||
|
||||
// Test that ModelName is set to the user's configured model when provider matches.
|
||||
// This ensures GetModelConfig(userModel) can find the migrated entry.
|
||||
// Regression test for: gateway startup failure when user model differs from provider name.
|
||||
func TestConvertProvidersToModelList_ModelNameMatchesUserModel(t *testing.T) {
|
||||
cfg := &Config{
|
||||
Agents: AgentsConfig{
|
||||
Defaults: AgentDefaults{
|
||||
Provider: "moonshot",
|
||||
Model: "k2p5",
|
||||
},
|
||||
},
|
||||
Providers: ProvidersConfig{
|
||||
Moonshot: ProviderConfig{APIKey: "sk-kimi-test"},
|
||||
},
|
||||
}
|
||||
|
||||
result := ConvertProvidersToModelList(cfg)
|
||||
|
||||
if len(result) != 1 {
|
||||
t.Fatalf("len(result) = %d, want 1", len(result))
|
||||
}
|
||||
|
||||
// ModelName must match the user's configured model, not the provider name.
|
||||
// Without this, GetModelConfig("k2p5") would fail because it would look
|
||||
// for ModelName == "k2p5" but find ModelName == "moonshot".
|
||||
if result[0].ModelName != "k2p5" {
|
||||
t.Errorf("ModelName = %q, want %q (must match user's model for GetModelConfig lookup)", result[0].ModelName, "k2p5")
|
||||
}
|
||||
|
||||
if result[0].Model != "moonshot/k2p5" {
|
||||
t.Errorf("Model = %q, want %q", result[0].Model, "moonshot/k2p5")
|
||||
}
|
||||
|
||||
// Other providers (not matching the user's configured provider) should keep their provider name
|
||||
cfg2 := &Config{
|
||||
Agents: AgentsConfig{
|
||||
Defaults: AgentDefaults{
|
||||
Provider: "moonshot",
|
||||
Model: "k2p5",
|
||||
},
|
||||
},
|
||||
Providers: ProvidersConfig{
|
||||
OpenAI: OpenAIProviderConfig{ProviderConfig: ProviderConfig{APIKey: "sk-openai"}},
|
||||
Moonshot: ProviderConfig{APIKey: "sk-kimi-test"},
|
||||
},
|
||||
}
|
||||
|
||||
result2 := ConvertProvidersToModelList(cfg2)
|
||||
|
||||
if len(result2) != 2 {
|
||||
t.Fatalf("len(result2) = %d, want 2", len(result2))
|
||||
}
|
||||
|
||||
for _, mc := range result2 {
|
||||
switch {
|
||||
case mc.APIKey == "sk-openai":
|
||||
// OpenAI is not the user's provider, should keep default ModelName
|
||||
if mc.ModelName != "openai" {
|
||||
t.Errorf("OpenAI ModelName = %q, want %q (non-matching provider keeps default)", mc.ModelName, "openai")
|
||||
}
|
||||
case mc.APIKey == "sk-kimi-test":
|
||||
// Moonshot is the user's provider, ModelName must be the user's model
|
||||
if mc.ModelName != "k2p5" {
|
||||
t.Errorf("Moonshot ModelName = %q, want %q (matching provider uses user model)", mc.ModelName, "k2p5")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,460 @@
|
||||
package memory
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"hash/fnv"
|
||||
"log"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/fileutil"
|
||||
"github.com/sipeed/picoclaw/pkg/providers"
|
||||
)
|
||||
|
||||
const (
|
||||
// numLockShards is the fixed number of mutexes used to serialize
|
||||
// per-session access. Using a sharded array instead of a map keeps
|
||||
// memory bounded regardless of how many sessions are created over
|
||||
// the lifetime of the process — important for a long-running daemon.
|
||||
numLockShards = 64
|
||||
|
||||
// maxLineSize is the maximum size of a single JSON line in a .jsonl
|
||||
// file. Tool results (read_file, web search, etc.) can be large, so
|
||||
// we set a generous limit. The scanner starts at 64 KB and grows
|
||||
// only as needed up to this cap.
|
||||
maxLineSize = 10 * 1024 * 1024 // 10 MB
|
||||
)
|
||||
|
||||
// sessionMeta holds per-session metadata stored in a .meta.json file.
|
||||
type sessionMeta struct {
|
||||
Key string `json:"key"`
|
||||
Summary string `json:"summary"`
|
||||
Skip int `json:"skip"`
|
||||
Count int `json:"count"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
}
|
||||
|
||||
// JSONLStore implements Store using append-only JSONL files.
|
||||
//
|
||||
// Each session is stored as two files:
|
||||
//
|
||||
// {sanitized_key}.jsonl — one JSON-encoded message per line, append-only
|
||||
// {sanitized_key}.meta.json — session metadata (summary, logical truncation offset)
|
||||
//
|
||||
// Messages are never physically deleted from the JSONL file. Instead,
|
||||
// TruncateHistory records a "skip" offset in the metadata file and
|
||||
// GetHistory ignores lines before that offset. This keeps all writes
|
||||
// append-only, which is both fast and crash-safe.
|
||||
type JSONLStore struct {
|
||||
dir string
|
||||
locks [numLockShards]sync.Mutex
|
||||
}
|
||||
|
||||
// NewJSONLStore creates a new JSONL-backed store rooted at dir.
|
||||
func NewJSONLStore(dir string) (*JSONLStore, error) {
|
||||
err := os.MkdirAll(dir, 0o755)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("memory: create directory: %w", err)
|
||||
}
|
||||
return &JSONLStore{dir: dir}, nil
|
||||
}
|
||||
|
||||
// sessionLock returns a mutex for the given session key.
|
||||
// Keys are mapped to a fixed pool of shards via FNV hash, so
|
||||
// memory usage is O(1) regardless of total session count.
|
||||
func (s *JSONLStore) sessionLock(key string) *sync.Mutex {
|
||||
h := fnv.New32a()
|
||||
h.Write([]byte(key))
|
||||
return &s.locks[h.Sum32()%numLockShards]
|
||||
}
|
||||
|
||||
func (s *JSONLStore) jsonlPath(key string) string {
|
||||
return filepath.Join(s.dir, sanitizeKey(key)+".jsonl")
|
||||
}
|
||||
|
||||
func (s *JSONLStore) metaPath(key string) string {
|
||||
return filepath.Join(s.dir, sanitizeKey(key)+".meta.json")
|
||||
}
|
||||
|
||||
// sanitizeKey converts a session key to a safe filename component.
|
||||
// Mirrors pkg/session.sanitizeFilename so that migration paths match.
|
||||
//
|
||||
// Note: this is a lossy mapping — "telegram:123" and "telegram_123"
|
||||
// both produce the same filename. This is an intentional tradeoff:
|
||||
// keys with colons (e.g. from channels) are by far the common case,
|
||||
// and a bidirectional encoding (like URL-encoding) would complicate
|
||||
// file listings and debugging.
|
||||
func sanitizeKey(key string) string {
|
||||
return strings.ReplaceAll(key, ":", "_")
|
||||
}
|
||||
|
||||
// readMeta loads the metadata file for a session.
|
||||
// Returns a zero-value sessionMeta if the file does not exist.
|
||||
func (s *JSONLStore) readMeta(key string) (sessionMeta, error) {
|
||||
data, err := os.ReadFile(s.metaPath(key))
|
||||
if os.IsNotExist(err) {
|
||||
return sessionMeta{Key: key}, nil
|
||||
}
|
||||
if err != nil {
|
||||
return sessionMeta{}, fmt.Errorf("memory: read meta: %w", err)
|
||||
}
|
||||
var meta sessionMeta
|
||||
err = json.Unmarshal(data, &meta)
|
||||
if err != nil {
|
||||
return sessionMeta{}, fmt.Errorf("memory: decode meta: %w", err)
|
||||
}
|
||||
return meta, nil
|
||||
}
|
||||
|
||||
// writeMeta atomically writes the metadata file using the project's
|
||||
// standard WriteFileAtomic (temp + fsync + rename).
|
||||
func (s *JSONLStore) writeMeta(key string, meta sessionMeta) error {
|
||||
data, err := json.MarshalIndent(meta, "", " ")
|
||||
if err != nil {
|
||||
return fmt.Errorf("memory: encode meta: %w", err)
|
||||
}
|
||||
return fileutil.WriteFileAtomic(s.metaPath(key), data, 0o644)
|
||||
}
|
||||
|
||||
// readMessages reads valid JSON lines from a .jsonl file, skipping
|
||||
// the first `skip` lines without unmarshaling them. This avoids the
|
||||
// cost of json.Unmarshal on logically truncated messages.
|
||||
// Malformed trailing lines (e.g. from a crash) are silently skipped.
|
||||
func readMessages(path string, skip int) ([]providers.Message, error) {
|
||||
f, err := os.Open(path)
|
||||
if os.IsNotExist(err) {
|
||||
return []providers.Message{}, nil
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("memory: open jsonl: %w", err)
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
var msgs []providers.Message
|
||||
scanner := bufio.NewScanner(f)
|
||||
// Allow large lines for tool results (read_file, web search, etc.).
|
||||
scanner.Buffer(make([]byte, 0, 64*1024), maxLineSize)
|
||||
|
||||
lineNum := 0
|
||||
for scanner.Scan() {
|
||||
line := scanner.Bytes()
|
||||
if len(line) == 0 {
|
||||
continue
|
||||
}
|
||||
lineNum++
|
||||
if lineNum <= skip {
|
||||
continue
|
||||
}
|
||||
var msg providers.Message
|
||||
if err := json.Unmarshal(line, &msg); err != nil {
|
||||
// Corrupt line — likely a partial write from a crash.
|
||||
// Log so operators know data was skipped, but don't
|
||||
// fail the entire read; this is the standard JSONL
|
||||
// recovery pattern.
|
||||
log.Printf("memory: skipping corrupt line %d in %s: %v",
|
||||
lineNum, filepath.Base(path), err)
|
||||
continue
|
||||
}
|
||||
msgs = append(msgs, msg)
|
||||
}
|
||||
if scanner.Err() != nil {
|
||||
return nil, fmt.Errorf("memory: scan jsonl: %w", scanner.Err())
|
||||
}
|
||||
|
||||
if msgs == nil {
|
||||
msgs = []providers.Message{}
|
||||
}
|
||||
return msgs, nil
|
||||
}
|
||||
|
||||
// countLines counts the total number of non-empty lines in a .jsonl file.
|
||||
// Used by TruncateHistory to reconcile a stale meta.Count without
|
||||
// the overhead of unmarshaling every message.
|
||||
func countLines(path string) (int, error) {
|
||||
f, err := os.Open(path)
|
||||
if os.IsNotExist(err) {
|
||||
return 0, nil
|
||||
}
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("memory: open jsonl: %w", err)
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
n := 0
|
||||
scanner := bufio.NewScanner(f)
|
||||
scanner.Buffer(make([]byte, 0, 64*1024), maxLineSize)
|
||||
for scanner.Scan() {
|
||||
if len(scanner.Bytes()) > 0 {
|
||||
n++
|
||||
}
|
||||
}
|
||||
return n, scanner.Err()
|
||||
}
|
||||
|
||||
func (s *JSONLStore) AddMessage(
|
||||
_ context.Context, sessionKey, role, content string,
|
||||
) error {
|
||||
return s.addMsg(sessionKey, providers.Message{
|
||||
Role: role,
|
||||
Content: content,
|
||||
})
|
||||
}
|
||||
|
||||
func (s *JSONLStore) AddFullMessage(
|
||||
_ context.Context, sessionKey string, msg providers.Message,
|
||||
) error {
|
||||
return s.addMsg(sessionKey, msg)
|
||||
}
|
||||
|
||||
// addMsg is the shared implementation for AddMessage and AddFullMessage.
|
||||
func (s *JSONLStore) addMsg(sessionKey string, msg providers.Message) error {
|
||||
l := s.sessionLock(sessionKey)
|
||||
l.Lock()
|
||||
defer l.Unlock()
|
||||
|
||||
// Append the message as a single JSON line.
|
||||
line, err := json.Marshal(msg)
|
||||
if err != nil {
|
||||
return fmt.Errorf("memory: marshal message: %w", err)
|
||||
}
|
||||
line = append(line, '\n')
|
||||
|
||||
f, err := os.OpenFile(
|
||||
s.jsonlPath(sessionKey),
|
||||
os.O_CREATE|os.O_WRONLY|os.O_APPEND,
|
||||
0o644,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("memory: open jsonl for append: %w", err)
|
||||
}
|
||||
_, writeErr := f.Write(line)
|
||||
if writeErr != nil {
|
||||
f.Close()
|
||||
return fmt.Errorf("memory: append message: %w", writeErr)
|
||||
}
|
||||
// Flush to physical storage before closing. This matches the
|
||||
// durability guarantee of writeMeta and rewriteJSONL (which use
|
||||
// WriteFileAtomic with fsync). Without Sync, a power loss could
|
||||
// leave the append in the kernel page cache only — lost on reboot.
|
||||
if syncErr := f.Sync(); syncErr != nil {
|
||||
f.Close()
|
||||
return fmt.Errorf("memory: sync jsonl: %w", syncErr)
|
||||
}
|
||||
if closeErr := f.Close(); closeErr != nil {
|
||||
return fmt.Errorf("memory: close jsonl: %w", closeErr)
|
||||
}
|
||||
|
||||
// Update metadata.
|
||||
meta, err := s.readMeta(sessionKey)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
now := time.Now()
|
||||
if meta.Count == 0 && meta.CreatedAt.IsZero() {
|
||||
meta.CreatedAt = now
|
||||
}
|
||||
meta.Count++
|
||||
meta.UpdatedAt = now
|
||||
|
||||
return s.writeMeta(sessionKey, meta)
|
||||
}
|
||||
|
||||
func (s *JSONLStore) GetHistory(
|
||||
_ context.Context, sessionKey string,
|
||||
) ([]providers.Message, error) {
|
||||
l := s.sessionLock(sessionKey)
|
||||
l.Lock()
|
||||
defer l.Unlock()
|
||||
|
||||
meta, err := s.readMeta(sessionKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Pass meta.Skip so readMessages skips those lines without
|
||||
// unmarshaling them — avoids wasted CPU on truncated messages.
|
||||
msgs, err := readMessages(s.jsonlPath(sessionKey), meta.Skip)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return msgs, nil
|
||||
}
|
||||
|
||||
func (s *JSONLStore) GetSummary(
|
||||
_ context.Context, sessionKey string,
|
||||
) (string, error) {
|
||||
l := s.sessionLock(sessionKey)
|
||||
l.Lock()
|
||||
defer l.Unlock()
|
||||
|
||||
meta, err := s.readMeta(sessionKey)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return meta.Summary, nil
|
||||
}
|
||||
|
||||
func (s *JSONLStore) SetSummary(
|
||||
_ context.Context, sessionKey, summary string,
|
||||
) error {
|
||||
l := s.sessionLock(sessionKey)
|
||||
l.Lock()
|
||||
defer l.Unlock()
|
||||
|
||||
meta, err := s.readMeta(sessionKey)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
now := time.Now()
|
||||
if meta.CreatedAt.IsZero() {
|
||||
meta.CreatedAt = now
|
||||
}
|
||||
meta.Summary = summary
|
||||
meta.UpdatedAt = now
|
||||
|
||||
return s.writeMeta(sessionKey, meta)
|
||||
}
|
||||
|
||||
func (s *JSONLStore) TruncateHistory(
|
||||
_ context.Context, sessionKey string, keepLast int,
|
||||
) error {
|
||||
l := s.sessionLock(sessionKey)
|
||||
l.Lock()
|
||||
defer l.Unlock()
|
||||
|
||||
meta, err := s.readMeta(sessionKey)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Always reconcile meta.Count with the actual line count on disk.
|
||||
// A crash between the JSONL append and the meta update in addMsg
|
||||
// leaves meta.Count stale (e.g. file has 101 lines but meta says
|
||||
// 100). Counting lines is cheap — no unmarshal, just a scan — and
|
||||
// TruncateHistory is not a hot path, so always re-count.
|
||||
n, countErr := countLines(s.jsonlPath(sessionKey))
|
||||
if countErr != nil {
|
||||
return countErr
|
||||
}
|
||||
meta.Count = n
|
||||
|
||||
if keepLast <= 0 {
|
||||
meta.Skip = meta.Count
|
||||
} else {
|
||||
effective := meta.Count - meta.Skip
|
||||
if keepLast < effective {
|
||||
meta.Skip = meta.Count - keepLast
|
||||
}
|
||||
}
|
||||
meta.UpdatedAt = time.Now()
|
||||
|
||||
return s.writeMeta(sessionKey, meta)
|
||||
}
|
||||
|
||||
func (s *JSONLStore) SetHistory(
|
||||
_ context.Context,
|
||||
sessionKey string,
|
||||
history []providers.Message,
|
||||
) error {
|
||||
l := s.sessionLock(sessionKey)
|
||||
l.Lock()
|
||||
defer l.Unlock()
|
||||
|
||||
meta, err := s.readMeta(sessionKey)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
now := time.Now()
|
||||
if meta.CreatedAt.IsZero() {
|
||||
meta.CreatedAt = now
|
||||
}
|
||||
meta.Skip = 0
|
||||
meta.Count = len(history)
|
||||
meta.UpdatedAt = now
|
||||
|
||||
// Write meta BEFORE rewriting the JSONL file. If we crash between
|
||||
// the two writes, meta has Skip=0 and the old file is still intact,
|
||||
// so GetHistory reads from line 1 — returning "too many" messages
|
||||
// rather than losing data. The next SetHistory call corrects this.
|
||||
err = s.writeMeta(sessionKey, meta)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return s.rewriteJSONL(sessionKey, history)
|
||||
}
|
||||
|
||||
// Compact physically rewrites the JSONL file, dropping all logically
|
||||
// skipped lines. This reclaims disk space that accumulates after
|
||||
// repeated TruncateHistory calls.
|
||||
//
|
||||
// It is safe to call at any time; if there is nothing to compact
|
||||
// (skip == 0) the method returns immediately.
|
||||
func (s *JSONLStore) Compact(
|
||||
_ context.Context, sessionKey string,
|
||||
) error {
|
||||
l := s.sessionLock(sessionKey)
|
||||
l.Lock()
|
||||
defer l.Unlock()
|
||||
|
||||
meta, err := s.readMeta(sessionKey)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if meta.Skip == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Read only the active messages, skipping truncated lines
|
||||
// without unmarshaling them.
|
||||
active, err := readMessages(s.jsonlPath(sessionKey), meta.Skip)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Write meta BEFORE rewriting the JSONL file. If the process
|
||||
// crashes between the two writes, meta has Skip=0 and the old
|
||||
// (uncompacted) file is still intact, so GetHistory reads from
|
||||
// line 1 — returning previously-truncated messages rather than
|
||||
// losing data. The next Compact or TruncateHistory corrects this.
|
||||
meta.Skip = 0
|
||||
meta.Count = len(active)
|
||||
meta.UpdatedAt = time.Now()
|
||||
|
||||
err = s.writeMeta(sessionKey, meta)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return s.rewriteJSONL(sessionKey, active)
|
||||
}
|
||||
|
||||
// rewriteJSONL atomically replaces the JSONL file with the given messages
|
||||
// using the project's standard WriteFileAtomic (temp + fsync + rename).
|
||||
func (s *JSONLStore) rewriteJSONL(
|
||||
sessionKey string, msgs []providers.Message,
|
||||
) error {
|
||||
var buf bytes.Buffer
|
||||
for i, msg := range msgs {
|
||||
line, err := json.Marshal(msg)
|
||||
if err != nil {
|
||||
return fmt.Errorf("memory: marshal message %d: %w", i, err)
|
||||
}
|
||||
buf.Write(line)
|
||||
buf.WriteByte('\n')
|
||||
}
|
||||
return fileutil.WriteFileAtomic(s.jsonlPath(sessionKey), buf.Bytes(), 0o644)
|
||||
}
|
||||
|
||||
func (s *JSONLStore) Close() error {
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,835 @@
|
||||
package memory
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/providers"
|
||||
)
|
||||
|
||||
func newTestStore(t *testing.T) *JSONLStore {
|
||||
t.Helper()
|
||||
store, err := NewJSONLStore(t.TempDir())
|
||||
if err != nil {
|
||||
t.Fatalf("NewJSONLStore: %v", err)
|
||||
}
|
||||
return store
|
||||
}
|
||||
|
||||
func TestNewJSONLStore_CreatesDirectory(t *testing.T) {
|
||||
dir := filepath.Join(t.TempDir(), "nested", "sessions")
|
||||
store, err := NewJSONLStore(dir)
|
||||
if err != nil {
|
||||
t.Fatalf("NewJSONLStore: %v", err)
|
||||
}
|
||||
defer store.Close()
|
||||
|
||||
info, err := os.Stat(dir)
|
||||
if err != nil {
|
||||
t.Fatalf("Stat: %v", err)
|
||||
}
|
||||
if !info.IsDir() {
|
||||
t.Errorf("expected directory, got file")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAddMessage_BasicRoundtrip(t *testing.T) {
|
||||
store := newTestStore(t)
|
||||
ctx := context.Background()
|
||||
|
||||
err := store.AddMessage(ctx, "s1", "user", "hello")
|
||||
if err != nil {
|
||||
t.Fatalf("AddMessage: %v", err)
|
||||
}
|
||||
err = store.AddMessage(ctx, "s1", "assistant", "hi there")
|
||||
if err != nil {
|
||||
t.Fatalf("AddMessage: %v", err)
|
||||
}
|
||||
|
||||
history, err := store.GetHistory(ctx, "s1")
|
||||
if err != nil {
|
||||
t.Fatalf("GetHistory: %v", err)
|
||||
}
|
||||
if len(history) != 2 {
|
||||
t.Fatalf("expected 2 messages, got %d", len(history))
|
||||
}
|
||||
if history[0].Role != "user" || history[0].Content != "hello" {
|
||||
t.Errorf("msg[0] = %+v", history[0])
|
||||
}
|
||||
if history[1].Role != "assistant" || history[1].Content != "hi there" {
|
||||
t.Errorf("msg[1] = %+v", history[1])
|
||||
}
|
||||
}
|
||||
|
||||
func TestAddMessage_AutoCreatesSession(t *testing.T) {
|
||||
store := newTestStore(t)
|
||||
ctx := context.Background()
|
||||
|
||||
// Adding a message to a non-existent session should work.
|
||||
err := store.AddMessage(ctx, "new-session", "user", "first message")
|
||||
if err != nil {
|
||||
t.Fatalf("AddMessage: %v", err)
|
||||
}
|
||||
|
||||
history, err := store.GetHistory(ctx, "new-session")
|
||||
if err != nil {
|
||||
t.Fatalf("GetHistory: %v", err)
|
||||
}
|
||||
if len(history) != 1 {
|
||||
t.Fatalf("expected 1 message, got %d", len(history))
|
||||
}
|
||||
}
|
||||
|
||||
func TestAddFullMessage_WithToolCalls(t *testing.T) {
|
||||
store := newTestStore(t)
|
||||
ctx := context.Background()
|
||||
|
||||
msg := providers.Message{
|
||||
Role: "assistant",
|
||||
Content: "Let me search that.",
|
||||
ToolCalls: []providers.ToolCall{
|
||||
{
|
||||
ID: "call_abc",
|
||||
Type: "function",
|
||||
Function: &providers.FunctionCall{
|
||||
Name: "web_search",
|
||||
Arguments: `{"q":"golang jsonl"}`,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
err := store.AddFullMessage(ctx, "tc", msg)
|
||||
if err != nil {
|
||||
t.Fatalf("AddFullMessage: %v", err)
|
||||
}
|
||||
|
||||
history, err := store.GetHistory(ctx, "tc")
|
||||
if err != nil {
|
||||
t.Fatalf("GetHistory: %v", err)
|
||||
}
|
||||
if len(history) != 1 {
|
||||
t.Fatalf("expected 1, got %d", len(history))
|
||||
}
|
||||
if len(history[0].ToolCalls) != 1 {
|
||||
t.Fatalf("expected 1 tool call, got %d", len(history[0].ToolCalls))
|
||||
}
|
||||
tc := history[0].ToolCalls[0]
|
||||
if tc.ID != "call_abc" {
|
||||
t.Errorf("tool call ID = %q", tc.ID)
|
||||
}
|
||||
if tc.Function == nil || tc.Function.Name != "web_search" {
|
||||
t.Errorf("tool call function = %+v", tc.Function)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAddFullMessage_ToolCallID(t *testing.T) {
|
||||
store := newTestStore(t)
|
||||
ctx := context.Background()
|
||||
|
||||
msg := providers.Message{
|
||||
Role: "tool",
|
||||
Content: "search results here",
|
||||
ToolCallID: "call_abc",
|
||||
}
|
||||
|
||||
err := store.AddFullMessage(ctx, "tr", msg)
|
||||
if err != nil {
|
||||
t.Fatalf("AddFullMessage: %v", err)
|
||||
}
|
||||
|
||||
history, err := store.GetHistory(ctx, "tr")
|
||||
if err != nil {
|
||||
t.Fatalf("GetHistory: %v", err)
|
||||
}
|
||||
if len(history) != 1 {
|
||||
t.Fatalf("expected 1, got %d", len(history))
|
||||
}
|
||||
if history[0].ToolCallID != "call_abc" {
|
||||
t.Errorf("ToolCallID = %q", history[0].ToolCallID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetHistory_EmptySession(t *testing.T) {
|
||||
store := newTestStore(t)
|
||||
ctx := context.Background()
|
||||
|
||||
history, err := store.GetHistory(ctx, "nonexistent")
|
||||
if err != nil {
|
||||
t.Fatalf("GetHistory: %v", err)
|
||||
}
|
||||
if history == nil {
|
||||
t.Fatal("expected non-nil empty slice")
|
||||
}
|
||||
if len(history) != 0 {
|
||||
t.Errorf("expected 0 messages, got %d", len(history))
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetHistory_Ordering(t *testing.T) {
|
||||
store := newTestStore(t)
|
||||
ctx := context.Background()
|
||||
|
||||
for i := 0; i < 5; i++ {
|
||||
err := store.AddMessage(
|
||||
ctx, "order",
|
||||
"user",
|
||||
string(rune('a'+i)),
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("AddMessage(%d): %v", i, err)
|
||||
}
|
||||
}
|
||||
|
||||
history, err := store.GetHistory(ctx, "order")
|
||||
if err != nil {
|
||||
t.Fatalf("GetHistory: %v", err)
|
||||
}
|
||||
if len(history) != 5 {
|
||||
t.Fatalf("expected 5, got %d", len(history))
|
||||
}
|
||||
for i := 0; i < 5; i++ {
|
||||
expected := string(rune('a' + i))
|
||||
if history[i].Content != expected {
|
||||
t.Errorf("msg[%d].Content = %q, want %q", i, history[i].Content, expected)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestSetSummary_GetSummary(t *testing.T) {
|
||||
store := newTestStore(t)
|
||||
ctx := context.Background()
|
||||
|
||||
// No summary yet.
|
||||
summary, err := store.GetSummary(ctx, "s1")
|
||||
if err != nil {
|
||||
t.Fatalf("GetSummary: %v", err)
|
||||
}
|
||||
if summary != "" {
|
||||
t.Errorf("expected empty, got %q", summary)
|
||||
}
|
||||
|
||||
// Set a summary.
|
||||
err = store.SetSummary(ctx, "s1", "talked about Go")
|
||||
if err != nil {
|
||||
t.Fatalf("SetSummary: %v", err)
|
||||
}
|
||||
|
||||
summary, err = store.GetSummary(ctx, "s1")
|
||||
if err != nil {
|
||||
t.Fatalf("GetSummary: %v", err)
|
||||
}
|
||||
if summary != "talked about Go" {
|
||||
t.Errorf("summary = %q", summary)
|
||||
}
|
||||
|
||||
// Update summary.
|
||||
err = store.SetSummary(ctx, "s1", "updated summary")
|
||||
if err != nil {
|
||||
t.Fatalf("SetSummary: %v", err)
|
||||
}
|
||||
|
||||
summary, err = store.GetSummary(ctx, "s1")
|
||||
if err != nil {
|
||||
t.Fatalf("GetSummary: %v", err)
|
||||
}
|
||||
if summary != "updated summary" {
|
||||
t.Errorf("summary = %q", summary)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTruncateHistory_KeepLast(t *testing.T) {
|
||||
store := newTestStore(t)
|
||||
ctx := context.Background()
|
||||
|
||||
for i := 0; i < 10; i++ {
|
||||
err := store.AddMessage(
|
||||
ctx, "trunc",
|
||||
"user",
|
||||
string(rune('a'+i)),
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("AddMessage: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
err := store.TruncateHistory(ctx, "trunc", 4)
|
||||
if err != nil {
|
||||
t.Fatalf("TruncateHistory: %v", err)
|
||||
}
|
||||
|
||||
history, err := store.GetHistory(ctx, "trunc")
|
||||
if err != nil {
|
||||
t.Fatalf("GetHistory: %v", err)
|
||||
}
|
||||
if len(history) != 4 {
|
||||
t.Fatalf("expected 4, got %d", len(history))
|
||||
}
|
||||
// Should be the last 4: g, h, i, j
|
||||
if history[0].Content != "g" {
|
||||
t.Errorf("first kept = %q, want 'g'", history[0].Content)
|
||||
}
|
||||
if history[3].Content != "j" {
|
||||
t.Errorf("last kept = %q, want 'j'", history[3].Content)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTruncateHistory_KeepZero(t *testing.T) {
|
||||
store := newTestStore(t)
|
||||
ctx := context.Background()
|
||||
|
||||
for i := 0; i < 5; i++ {
|
||||
err := store.AddMessage(ctx, "empty", "user", "msg")
|
||||
if err != nil {
|
||||
t.Fatalf("AddMessage: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
err := store.TruncateHistory(ctx, "empty", 0)
|
||||
if err != nil {
|
||||
t.Fatalf("TruncateHistory: %v", err)
|
||||
}
|
||||
|
||||
history, err := store.GetHistory(ctx, "empty")
|
||||
if err != nil {
|
||||
t.Fatalf("GetHistory: %v", err)
|
||||
}
|
||||
if len(history) != 0 {
|
||||
t.Errorf("expected 0, got %d", len(history))
|
||||
}
|
||||
}
|
||||
|
||||
func TestTruncateHistory_KeepMoreThanExists(t *testing.T) {
|
||||
store := newTestStore(t)
|
||||
ctx := context.Background()
|
||||
|
||||
for i := 0; i < 3; i++ {
|
||||
err := store.AddMessage(ctx, "few", "user", "msg")
|
||||
if err != nil {
|
||||
t.Fatalf("AddMessage: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Keep 100, but only 3 exist — should keep all.
|
||||
err := store.TruncateHistory(ctx, "few", 100)
|
||||
if err != nil {
|
||||
t.Fatalf("TruncateHistory: %v", err)
|
||||
}
|
||||
|
||||
history, err := store.GetHistory(ctx, "few")
|
||||
if err != nil {
|
||||
t.Fatalf("GetHistory: %v", err)
|
||||
}
|
||||
if len(history) != 3 {
|
||||
t.Errorf("expected 3, got %d", len(history))
|
||||
}
|
||||
}
|
||||
|
||||
func TestSetHistory_ReplacesAll(t *testing.T) {
|
||||
store := newTestStore(t)
|
||||
ctx := context.Background()
|
||||
|
||||
// Add some initial messages.
|
||||
for i := 0; i < 5; i++ {
|
||||
err := store.AddMessage(ctx, "replace", "user", "old")
|
||||
if err != nil {
|
||||
t.Fatalf("AddMessage: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Replace with new history.
|
||||
newHistory := []providers.Message{
|
||||
{Role: "user", Content: "new1"},
|
||||
{Role: "assistant", Content: "new2"},
|
||||
}
|
||||
err := store.SetHistory(ctx, "replace", newHistory)
|
||||
if err != nil {
|
||||
t.Fatalf("SetHistory: %v", err)
|
||||
}
|
||||
|
||||
history, err := store.GetHistory(ctx, "replace")
|
||||
if err != nil {
|
||||
t.Fatalf("GetHistory: %v", err)
|
||||
}
|
||||
if len(history) != 2 {
|
||||
t.Fatalf("expected 2, got %d", len(history))
|
||||
}
|
||||
if history[0].Content != "new1" || history[1].Content != "new2" {
|
||||
t.Errorf("history = %+v", history)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSetHistory_ResetsSkip(t *testing.T) {
|
||||
store := newTestStore(t)
|
||||
ctx := context.Background()
|
||||
|
||||
// Add messages and truncate.
|
||||
for i := 0; i < 10; i++ {
|
||||
err := store.AddMessage(ctx, "skip-reset", "user", "old")
|
||||
if err != nil {
|
||||
t.Fatalf("AddMessage: %v", err)
|
||||
}
|
||||
}
|
||||
err := store.TruncateHistory(ctx, "skip-reset", 3)
|
||||
if err != nil {
|
||||
t.Fatalf("TruncateHistory: %v", err)
|
||||
}
|
||||
|
||||
// SetHistory should reset skip to 0.
|
||||
newHistory := []providers.Message{
|
||||
{Role: "user", Content: "fresh"},
|
||||
}
|
||||
err = store.SetHistory(ctx, "skip-reset", newHistory)
|
||||
if err != nil {
|
||||
t.Fatalf("SetHistory: %v", err)
|
||||
}
|
||||
|
||||
history, err := store.GetHistory(ctx, "skip-reset")
|
||||
if err != nil {
|
||||
t.Fatalf("GetHistory: %v", err)
|
||||
}
|
||||
if len(history) != 1 {
|
||||
t.Fatalf("expected 1, got %d", len(history))
|
||||
}
|
||||
if history[0].Content != "fresh" {
|
||||
t.Errorf("content = %q", history[0].Content)
|
||||
}
|
||||
}
|
||||
|
||||
func TestColonInKey(t *testing.T) {
|
||||
store := newTestStore(t)
|
||||
ctx := context.Background()
|
||||
|
||||
err := store.AddMessage(ctx, "telegram:123", "user", "hi")
|
||||
if err != nil {
|
||||
t.Fatalf("AddMessage: %v", err)
|
||||
}
|
||||
|
||||
history, err := store.GetHistory(ctx, "telegram:123")
|
||||
if err != nil {
|
||||
t.Fatalf("GetHistory: %v", err)
|
||||
}
|
||||
if len(history) != 1 {
|
||||
t.Fatalf("expected 1, got %d", len(history))
|
||||
}
|
||||
|
||||
// Verify the file is named with underscore.
|
||||
jsonlFile := filepath.Join(store.dir, "telegram_123.jsonl")
|
||||
if _, statErr := os.Stat(jsonlFile); statErr != nil {
|
||||
t.Errorf("expected file %s to exist: %v", jsonlFile, statErr)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCompact_RemovesSkippedMessages(t *testing.T) {
|
||||
store := newTestStore(t)
|
||||
ctx := context.Background()
|
||||
|
||||
// Write 10 messages, then truncate to keep last 3.
|
||||
for i := 0; i < 10; i++ {
|
||||
err := store.AddMessage(ctx, "compact", "user", string(rune('a'+i)))
|
||||
if err != nil {
|
||||
t.Fatalf("AddMessage: %v", err)
|
||||
}
|
||||
}
|
||||
err := store.TruncateHistory(ctx, "compact", 3)
|
||||
if err != nil {
|
||||
t.Fatalf("TruncateHistory: %v", err)
|
||||
}
|
||||
|
||||
// Before compact: file still has 10 lines.
|
||||
allOnDisk, err := readMessages(store.jsonlPath("compact"), 0)
|
||||
if err != nil {
|
||||
t.Fatalf("readMessages: %v", err)
|
||||
}
|
||||
if len(allOnDisk) != 10 {
|
||||
t.Fatalf("before compact: expected 10 on disk, got %d", len(allOnDisk))
|
||||
}
|
||||
|
||||
// Compact.
|
||||
err = store.Compact(ctx, "compact")
|
||||
if err != nil {
|
||||
t.Fatalf("Compact: %v", err)
|
||||
}
|
||||
|
||||
// After compact: file should have only 3 lines.
|
||||
allOnDisk, err = readMessages(store.jsonlPath("compact"), 0)
|
||||
if err != nil {
|
||||
t.Fatalf("readMessages: %v", err)
|
||||
}
|
||||
if len(allOnDisk) != 3 {
|
||||
t.Fatalf("after compact: expected 3 on disk, got %d", len(allOnDisk))
|
||||
}
|
||||
|
||||
// GetHistory should still return the same 3 messages.
|
||||
history, err := store.GetHistory(ctx, "compact")
|
||||
if err != nil {
|
||||
t.Fatalf("GetHistory: %v", err)
|
||||
}
|
||||
if len(history) != 3 {
|
||||
t.Fatalf("expected 3, got %d", len(history))
|
||||
}
|
||||
if history[0].Content != "h" || history[2].Content != "j" {
|
||||
t.Errorf("wrong content: %+v", history)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCompact_NoOpWhenNoSkip(t *testing.T) {
|
||||
store := newTestStore(t)
|
||||
ctx := context.Background()
|
||||
|
||||
for i := 0; i < 5; i++ {
|
||||
err := store.AddMessage(ctx, "noop", "user", "msg")
|
||||
if err != nil {
|
||||
t.Fatalf("AddMessage: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Compact without prior truncation — should be a no-op.
|
||||
err := store.Compact(ctx, "noop")
|
||||
if err != nil {
|
||||
t.Fatalf("Compact: %v", err)
|
||||
}
|
||||
|
||||
history, err := store.GetHistory(ctx, "noop")
|
||||
if err != nil {
|
||||
t.Fatalf("GetHistory: %v", err)
|
||||
}
|
||||
if len(history) != 5 {
|
||||
t.Errorf("expected 5, got %d", len(history))
|
||||
}
|
||||
}
|
||||
|
||||
func TestCompact_ThenAppend(t *testing.T) {
|
||||
store := newTestStore(t)
|
||||
ctx := context.Background()
|
||||
|
||||
for i := 0; i < 8; i++ {
|
||||
err := store.AddMessage(ctx, "cap", "user", string(rune('a'+i)))
|
||||
if err != nil {
|
||||
t.Fatalf("AddMessage: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
err := store.TruncateHistory(ctx, "cap", 2)
|
||||
if err != nil {
|
||||
t.Fatalf("TruncateHistory: %v", err)
|
||||
}
|
||||
err = store.Compact(ctx, "cap")
|
||||
if err != nil {
|
||||
t.Fatalf("Compact: %v", err)
|
||||
}
|
||||
|
||||
// Append after compaction should work correctly.
|
||||
err = store.AddMessage(ctx, "cap", "user", "new")
|
||||
if err != nil {
|
||||
t.Fatalf("AddMessage after compact: %v", err)
|
||||
}
|
||||
|
||||
history, err := store.GetHistory(ctx, "cap")
|
||||
if err != nil {
|
||||
t.Fatalf("GetHistory: %v", err)
|
||||
}
|
||||
if len(history) != 3 {
|
||||
t.Fatalf("expected 3, got %d", len(history))
|
||||
}
|
||||
// g, h (kept from truncation), new (appended after compaction).
|
||||
if history[0].Content != "g" {
|
||||
t.Errorf("first = %q, want 'g'", history[0].Content)
|
||||
}
|
||||
if history[2].Content != "new" {
|
||||
t.Errorf("last = %q, want 'new'", history[2].Content)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTruncateHistory_StaleMetaCount(t *testing.T) {
|
||||
// Simulates a crash between JSONL append and meta update in addMsg:
|
||||
// file has N+1 lines but meta.Count is still N. TruncateHistory must
|
||||
// reconcile with the real line count so that keepLast is accurate.
|
||||
store := newTestStore(t)
|
||||
ctx := context.Background()
|
||||
|
||||
// Write 10 messages normally (meta.Count = 10).
|
||||
for i := 0; i < 10; i++ {
|
||||
err := store.AddMessage(ctx, "stale", "user", string(rune('a'+i)))
|
||||
if err != nil {
|
||||
t.Fatalf("AddMessage: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Simulate crash: append a line to JSONL but do NOT update meta.
|
||||
// This leaves meta.Count = 10 while the file has 11 lines.
|
||||
jsonlPath := store.jsonlPath("stale")
|
||||
f, err := os.OpenFile(jsonlPath, os.O_WRONLY|os.O_APPEND, 0o644)
|
||||
if err != nil {
|
||||
t.Fatalf("open for append: %v", err)
|
||||
}
|
||||
_, err = f.WriteString(`{"role":"user","content":"orphan"}` + "\n")
|
||||
if err != nil {
|
||||
t.Fatalf("write orphan: %v", err)
|
||||
}
|
||||
f.Close()
|
||||
|
||||
// TruncateHistory(keepLast=4) should keep the last 4 of 11 lines,
|
||||
// not the last 4 of 10.
|
||||
err = store.TruncateHistory(ctx, "stale", 4)
|
||||
if err != nil {
|
||||
t.Fatalf("TruncateHistory: %v", err)
|
||||
}
|
||||
|
||||
history, err := store.GetHistory(ctx, "stale")
|
||||
if err != nil {
|
||||
t.Fatalf("GetHistory: %v", err)
|
||||
}
|
||||
if len(history) != 4 {
|
||||
t.Fatalf("expected 4, got %d", len(history))
|
||||
}
|
||||
// Last 4 of [a,b,c,d,e,f,g,h,i,j,orphan] = [h,i,j,orphan]
|
||||
if history[0].Content != "h" {
|
||||
t.Errorf("first kept = %q, want 'h'", history[0].Content)
|
||||
}
|
||||
if history[3].Content != "orphan" {
|
||||
t.Errorf("last kept = %q, want 'orphan'", history[3].Content)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCrashRecovery_PartialLine(t *testing.T) {
|
||||
store := newTestStore(t)
|
||||
ctx := context.Background()
|
||||
|
||||
// Write a valid message first.
|
||||
err := store.AddMessage(ctx, "crash", "user", "valid")
|
||||
if err != nil {
|
||||
t.Fatalf("AddMessage: %v", err)
|
||||
}
|
||||
|
||||
// Simulate a crash by appending a partial JSON line directly.
|
||||
jsonlPath := store.jsonlPath("crash")
|
||||
f, err := os.OpenFile(jsonlPath, os.O_WRONLY|os.O_APPEND, 0o644)
|
||||
if err != nil {
|
||||
t.Fatalf("open for append: %v", err)
|
||||
}
|
||||
_, err = f.WriteString(`{"role":"user","content":"incomple`)
|
||||
if err != nil {
|
||||
t.Fatalf("write partial: %v", err)
|
||||
}
|
||||
f.Close()
|
||||
|
||||
// GetHistory should return only the valid message.
|
||||
history, err := store.GetHistory(ctx, "crash")
|
||||
if err != nil {
|
||||
t.Fatalf("GetHistory: %v", err)
|
||||
}
|
||||
if len(history) != 1 {
|
||||
t.Fatalf("expected 1 valid message, got %d", len(history))
|
||||
}
|
||||
if history[0].Content != "valid" {
|
||||
t.Errorf("content = %q", history[0].Content)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPersistence_AcrossInstances(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
ctx := context.Background()
|
||||
|
||||
// Write with first instance.
|
||||
store1, err := NewJSONLStore(dir)
|
||||
if err != nil {
|
||||
t.Fatalf("NewJSONLStore: %v", err)
|
||||
}
|
||||
err = store1.AddMessage(ctx, "persist", "user", "remember me")
|
||||
if err != nil {
|
||||
t.Fatalf("AddMessage: %v", err)
|
||||
}
|
||||
err = store1.SetSummary(ctx, "persist", "a test session")
|
||||
if err != nil {
|
||||
t.Fatalf("SetSummary: %v", err)
|
||||
}
|
||||
store1.Close()
|
||||
|
||||
// Read with second instance.
|
||||
store2, err := NewJSONLStore(dir)
|
||||
if err != nil {
|
||||
t.Fatalf("NewJSONLStore: %v", err)
|
||||
}
|
||||
defer store2.Close()
|
||||
|
||||
history, err := store2.GetHistory(ctx, "persist")
|
||||
if err != nil {
|
||||
t.Fatalf("GetHistory: %v", err)
|
||||
}
|
||||
if len(history) != 1 || history[0].Content != "remember me" {
|
||||
t.Errorf("history = %+v", history)
|
||||
}
|
||||
|
||||
summary, err := store2.GetSummary(ctx, "persist")
|
||||
if err != nil {
|
||||
t.Fatalf("GetSummary: %v", err)
|
||||
}
|
||||
if summary != "a test session" {
|
||||
t.Errorf("summary = %q", summary)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConcurrent_AddAndRead(t *testing.T) {
|
||||
store := newTestStore(t)
|
||||
ctx := context.Background()
|
||||
|
||||
var wg sync.WaitGroup
|
||||
const goroutines = 10
|
||||
const msgsPerGoroutine = 20
|
||||
|
||||
// Concurrent writes.
|
||||
for g := 0; g < goroutines; g++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
for i := 0; i < msgsPerGoroutine; i++ {
|
||||
_ = store.AddMessage(ctx, "concurrent", "user", "msg")
|
||||
}
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
history, err := store.GetHistory(ctx, "concurrent")
|
||||
if err != nil {
|
||||
t.Fatalf("GetHistory: %v", err)
|
||||
}
|
||||
expected := goroutines * msgsPerGoroutine
|
||||
if len(history) != expected {
|
||||
t.Errorf("expected %d messages, got %d", expected, len(history))
|
||||
}
|
||||
}
|
||||
|
||||
func TestConcurrent_SummarizeRace(t *testing.T) {
|
||||
// Simulates the #704 race: one goroutine adds messages while
|
||||
// another truncates + sets summary — like summarizeSession().
|
||||
store := newTestStore(t)
|
||||
ctx := context.Background()
|
||||
|
||||
// Seed with some messages.
|
||||
for i := 0; i < 20; i++ {
|
||||
err := store.AddMessage(ctx, "race", "user", "seed")
|
||||
if err != nil {
|
||||
t.Fatalf("AddMessage: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
|
||||
// Writer goroutine (main agent loop).
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
for i := 0; i < 50; i++ {
|
||||
_ = store.AddMessage(ctx, "race", "user", "new")
|
||||
}
|
||||
}()
|
||||
|
||||
// Summarizer goroutine (background task).
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
for i := 0; i < 10; i++ {
|
||||
_ = store.SetSummary(ctx, "race", "summary")
|
||||
_ = store.TruncateHistory(ctx, "race", 5)
|
||||
}
|
||||
}()
|
||||
|
||||
wg.Wait()
|
||||
|
||||
// Verify the store is still in a consistent state.
|
||||
_, err := store.GetHistory(ctx, "race")
|
||||
if err != nil {
|
||||
t.Fatalf("GetHistory after race: %v", err)
|
||||
}
|
||||
_, err = store.GetSummary(ctx, "race")
|
||||
if err != nil {
|
||||
t.Fatalf("GetSummary after race: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMultipleSessions_Isolation(t *testing.T) {
|
||||
store := newTestStore(t)
|
||||
ctx := context.Background()
|
||||
|
||||
err := store.AddMessage(ctx, "s1", "user", "msg for s1")
|
||||
if err != nil {
|
||||
t.Fatalf("AddMessage: %v", err)
|
||||
}
|
||||
err = store.AddMessage(ctx, "s2", "user", "msg for s2")
|
||||
if err != nil {
|
||||
t.Fatalf("AddMessage: %v", err)
|
||||
}
|
||||
|
||||
h1, err := store.GetHistory(ctx, "s1")
|
||||
if err != nil {
|
||||
t.Fatalf("GetHistory s1: %v", err)
|
||||
}
|
||||
h2, err := store.GetHistory(ctx, "s2")
|
||||
if err != nil {
|
||||
t.Fatalf("GetHistory s2: %v", err)
|
||||
}
|
||||
|
||||
if len(h1) != 1 || h1[0].Content != "msg for s1" {
|
||||
t.Errorf("s1 history = %+v", h1)
|
||||
}
|
||||
if len(h2) != 1 || h2[0].Content != "msg for s2" {
|
||||
t.Errorf("s2 history = %+v", h2)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkAddMessage(b *testing.B) {
|
||||
dir := b.TempDir()
|
||||
store, err := NewJSONLStore(dir)
|
||||
if err != nil {
|
||||
b.Fatalf("NewJSONLStore: %v", err)
|
||||
}
|
||||
defer store.Close()
|
||||
ctx := context.Background()
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_ = store.AddMessage(ctx, "bench", "user", "benchmark message content")
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkGetHistory_100(b *testing.B) {
|
||||
dir := b.TempDir()
|
||||
store, err := NewJSONLStore(dir)
|
||||
if err != nil {
|
||||
b.Fatalf("NewJSONLStore: %v", err)
|
||||
}
|
||||
defer store.Close()
|
||||
ctx := context.Background()
|
||||
|
||||
for i := 0; i < 100; i++ {
|
||||
_ = store.AddMessage(ctx, "bench", "user", "message content")
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_, _ = store.GetHistory(ctx, "bench")
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkGetHistory_1000(b *testing.B) {
|
||||
dir := b.TempDir()
|
||||
store, err := NewJSONLStore(dir)
|
||||
if err != nil {
|
||||
b.Fatalf("NewJSONLStore: %v", err)
|
||||
}
|
||||
defer store.Close()
|
||||
ctx := context.Background()
|
||||
|
||||
for i := 0; i < 1000; i++ {
|
||||
_ = store.AddMessage(ctx, "bench", "user", "message content")
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_, _ = store.GetHistory(ctx, "bench")
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,108 @@
|
||||
package memory
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/providers"
|
||||
)
|
||||
|
||||
// jsonSession mirrors pkg/session.Session for migration purposes.
|
||||
type jsonSession struct {
|
||||
Key string `json:"key"`
|
||||
Messages []providers.Message `json:"messages"`
|
||||
Summary string `json:"summary,omitempty"`
|
||||
Created time.Time `json:"created"`
|
||||
Updated time.Time `json:"updated"`
|
||||
}
|
||||
|
||||
// MigrateFromJSON reads legacy sessions/*.json files from sessionsDir,
|
||||
// writes them into the Store, and renames each migrated file to
|
||||
// .json.migrated as a backup. Returns the number of sessions migrated.
|
||||
//
|
||||
// Files that fail to parse are logged and skipped. Already-migrated
|
||||
// files (.json.migrated) are ignored, making the function idempotent.
|
||||
func MigrateFromJSON(
|
||||
ctx context.Context, sessionsDir string, store Store,
|
||||
) (int, error) {
|
||||
entries, err := os.ReadDir(sessionsDir)
|
||||
if os.IsNotExist(err) {
|
||||
return 0, nil
|
||||
}
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("memory: read sessions dir: %w", err)
|
||||
}
|
||||
|
||||
migrated := 0
|
||||
for _, entry := range entries {
|
||||
if entry.IsDir() {
|
||||
continue
|
||||
}
|
||||
name := entry.Name()
|
||||
if !strings.HasSuffix(name, ".json") {
|
||||
continue
|
||||
}
|
||||
// Skip already-migrated files.
|
||||
if strings.HasSuffix(name, ".migrated") {
|
||||
continue
|
||||
}
|
||||
|
||||
srcPath := filepath.Join(sessionsDir, name)
|
||||
|
||||
data, readErr := os.ReadFile(srcPath)
|
||||
if readErr != nil {
|
||||
log.Printf("memory: migrate: skip %s: %v", name, readErr)
|
||||
continue
|
||||
}
|
||||
|
||||
var sess jsonSession
|
||||
if parseErr := json.Unmarshal(data, &sess); parseErr != nil {
|
||||
log.Printf("memory: migrate: skip %s: %v", name, parseErr)
|
||||
continue
|
||||
}
|
||||
|
||||
// Use the key from the JSON content, not the filename.
|
||||
// Filenames are sanitized (":" → "_") but keys are not.
|
||||
key := sess.Key
|
||||
if key == "" {
|
||||
key = strings.TrimSuffix(name, ".json")
|
||||
}
|
||||
|
||||
// Use SetHistory (atomic replace) instead of per-message
|
||||
// AddFullMessage. This makes migration idempotent: if the
|
||||
// process crashes after writing messages but before the
|
||||
// rename below, a retry replaces the partial data cleanly
|
||||
// instead of duplicating messages.
|
||||
if setErr := store.SetHistory(ctx, key, sess.Messages); setErr != nil {
|
||||
return migrated, fmt.Errorf(
|
||||
"memory: migrate %s: set history: %w",
|
||||
name, setErr,
|
||||
)
|
||||
}
|
||||
|
||||
if sess.Summary != "" {
|
||||
if sumErr := store.SetSummary(ctx, key, sess.Summary); sumErr != nil {
|
||||
return migrated, fmt.Errorf(
|
||||
"memory: migrate %s: set summary: %w",
|
||||
name, sumErr,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
// Rename to .migrated as backup (not delete).
|
||||
renameErr := os.Rename(srcPath, srcPath+".migrated")
|
||||
if renameErr != nil {
|
||||
log.Printf("memory: migrate: rename %s: %v", name, renameErr)
|
||||
}
|
||||
|
||||
migrated++
|
||||
}
|
||||
|
||||
return migrated, nil
|
||||
}
|
||||
@@ -0,0 +1,384 @@
|
||||
package memory
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/providers"
|
||||
)
|
||||
|
||||
func writeJSONSession(
|
||||
t *testing.T, dir string, filename string, sess jsonSession,
|
||||
) {
|
||||
t.Helper()
|
||||
data, err := json.MarshalIndent(sess, "", " ")
|
||||
if err != nil {
|
||||
t.Fatalf("marshal session: %v", err)
|
||||
}
|
||||
err = os.WriteFile(filepath.Join(dir, filename), data, 0o644)
|
||||
if err != nil {
|
||||
t.Fatalf("write session file: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMigrateFromJSON_Basic(t *testing.T) {
|
||||
sessionsDir := t.TempDir()
|
||||
store := newTestStore(t)
|
||||
ctx := context.Background()
|
||||
|
||||
writeJSONSession(t, sessionsDir, "test.json", jsonSession{
|
||||
Key: "test",
|
||||
Messages: []providers.Message{
|
||||
{Role: "user", Content: "hello"},
|
||||
{Role: "assistant", Content: "hi"},
|
||||
},
|
||||
Summary: "A greeting.",
|
||||
Created: time.Now(),
|
||||
Updated: time.Now(),
|
||||
})
|
||||
|
||||
count, err := MigrateFromJSON(ctx, sessionsDir, store)
|
||||
if err != nil {
|
||||
t.Fatalf("MigrateFromJSON: %v", err)
|
||||
}
|
||||
if count != 1 {
|
||||
t.Errorf("expected 1 migrated, got %d", count)
|
||||
}
|
||||
|
||||
history, err := store.GetHistory(ctx, "test")
|
||||
if err != nil {
|
||||
t.Fatalf("GetHistory: %v", err)
|
||||
}
|
||||
if len(history) != 2 {
|
||||
t.Fatalf("expected 2 messages, got %d", len(history))
|
||||
}
|
||||
if history[0].Content != "hello" || history[1].Content != "hi" {
|
||||
t.Errorf("unexpected messages: %+v", history)
|
||||
}
|
||||
|
||||
summary, err := store.GetSummary(ctx, "test")
|
||||
if err != nil {
|
||||
t.Fatalf("GetSummary: %v", err)
|
||||
}
|
||||
if summary != "A greeting." {
|
||||
t.Errorf("summary = %q", summary)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMigrateFromJSON_WithToolCalls(t *testing.T) {
|
||||
sessionsDir := t.TempDir()
|
||||
store := newTestStore(t)
|
||||
ctx := context.Background()
|
||||
|
||||
writeJSONSession(t, sessionsDir, "tools.json", jsonSession{
|
||||
Key: "tools",
|
||||
Messages: []providers.Message{
|
||||
{
|
||||
Role: "assistant",
|
||||
Content: "Searching...",
|
||||
ToolCalls: []providers.ToolCall{
|
||||
{
|
||||
ID: "call_1",
|
||||
Type: "function",
|
||||
Function: &providers.FunctionCall{
|
||||
Name: "web_search",
|
||||
Arguments: `{"q":"test"}`,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Role: "tool",
|
||||
Content: "result",
|
||||
ToolCallID: "call_1",
|
||||
},
|
||||
},
|
||||
Created: time.Now(),
|
||||
Updated: time.Now(),
|
||||
})
|
||||
|
||||
count, err := MigrateFromJSON(ctx, sessionsDir, store)
|
||||
if err != nil {
|
||||
t.Fatalf("MigrateFromJSON: %v", err)
|
||||
}
|
||||
if count != 1 {
|
||||
t.Errorf("expected 1, got %d", count)
|
||||
}
|
||||
|
||||
history, err := store.GetHistory(ctx, "tools")
|
||||
if err != nil {
|
||||
t.Fatalf("GetHistory: %v", err)
|
||||
}
|
||||
if len(history) != 2 {
|
||||
t.Fatalf("expected 2 messages, got %d", len(history))
|
||||
}
|
||||
if len(history[0].ToolCalls) != 1 {
|
||||
t.Fatalf("expected 1 tool call, got %d", len(history[0].ToolCalls))
|
||||
}
|
||||
if history[0].ToolCalls[0].Function.Name != "web_search" {
|
||||
t.Errorf("function = %q", history[0].ToolCalls[0].Function.Name)
|
||||
}
|
||||
if history[1].ToolCallID != "call_1" {
|
||||
t.Errorf("ToolCallID = %q", history[1].ToolCallID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMigrateFromJSON_MultipleFiles(t *testing.T) {
|
||||
sessionsDir := t.TempDir()
|
||||
store := newTestStore(t)
|
||||
ctx := context.Background()
|
||||
|
||||
for i := 0; i < 3; i++ {
|
||||
key := string(rune('a' + i))
|
||||
writeJSONSession(t, sessionsDir, key+".json", jsonSession{
|
||||
Key: key,
|
||||
Messages: []providers.Message{{Role: "user", Content: "msg " + key}},
|
||||
Created: time.Now(),
|
||||
Updated: time.Now(),
|
||||
})
|
||||
}
|
||||
|
||||
count, err := MigrateFromJSON(ctx, sessionsDir, store)
|
||||
if err != nil {
|
||||
t.Fatalf("MigrateFromJSON: %v", err)
|
||||
}
|
||||
if count != 3 {
|
||||
t.Errorf("expected 3, got %d", count)
|
||||
}
|
||||
|
||||
for i := 0; i < 3; i++ {
|
||||
key := string(rune('a' + i))
|
||||
history, histErr := store.GetHistory(ctx, key)
|
||||
if histErr != nil {
|
||||
t.Fatalf("GetHistory(%q): %v", key, histErr)
|
||||
}
|
||||
if len(history) != 1 {
|
||||
t.Errorf("session %q: expected 1 msg, got %d", key, len(history))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestMigrateFromJSON_InvalidJSON(t *testing.T) {
|
||||
sessionsDir := t.TempDir()
|
||||
store := newTestStore(t)
|
||||
ctx := context.Background()
|
||||
|
||||
// One valid, one invalid.
|
||||
writeJSONSession(t, sessionsDir, "good.json", jsonSession{
|
||||
Key: "good",
|
||||
Messages: []providers.Message{{Role: "user", Content: "ok"}},
|
||||
Created: time.Now(),
|
||||
Updated: time.Now(),
|
||||
})
|
||||
err := os.WriteFile(
|
||||
filepath.Join(sessionsDir, "bad.json"),
|
||||
[]byte("{invalid json"),
|
||||
0o644,
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("write bad file: %v", err)
|
||||
}
|
||||
|
||||
count, err := MigrateFromJSON(ctx, sessionsDir, store)
|
||||
if err != nil {
|
||||
t.Fatalf("MigrateFromJSON: %v", err)
|
||||
}
|
||||
if count != 1 {
|
||||
t.Errorf("expected 1 (bad file skipped), got %d", count)
|
||||
}
|
||||
|
||||
history, err := store.GetHistory(ctx, "good")
|
||||
if err != nil {
|
||||
t.Fatalf("GetHistory: %v", err)
|
||||
}
|
||||
if len(history) != 1 {
|
||||
t.Errorf("expected 1 message, got %d", len(history))
|
||||
}
|
||||
}
|
||||
|
||||
func TestMigrateFromJSON_RenamesFiles(t *testing.T) {
|
||||
sessionsDir := t.TempDir()
|
||||
store := newTestStore(t)
|
||||
ctx := context.Background()
|
||||
|
||||
writeJSONSession(t, sessionsDir, "rename.json", jsonSession{
|
||||
Key: "rename",
|
||||
Messages: []providers.Message{{Role: "user", Content: "hi"}},
|
||||
Created: time.Now(),
|
||||
Updated: time.Now(),
|
||||
})
|
||||
|
||||
_, err := MigrateFromJSON(ctx, sessionsDir, store)
|
||||
if err != nil {
|
||||
t.Fatalf("MigrateFromJSON: %v", err)
|
||||
}
|
||||
|
||||
// Original .json should not exist.
|
||||
_, statErr := os.Stat(filepath.Join(sessionsDir, "rename.json"))
|
||||
if !os.IsNotExist(statErr) {
|
||||
t.Error("rename.json should have been renamed")
|
||||
}
|
||||
// .json.migrated should exist.
|
||||
_, statErr = os.Stat(
|
||||
filepath.Join(sessionsDir, "rename.json.migrated"),
|
||||
)
|
||||
if statErr != nil {
|
||||
t.Errorf("rename.json.migrated should exist: %v", statErr)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMigrateFromJSON_Idempotent(t *testing.T) {
|
||||
sessionsDir := t.TempDir()
|
||||
store := newTestStore(t)
|
||||
ctx := context.Background()
|
||||
|
||||
writeJSONSession(t, sessionsDir, "idem.json", jsonSession{
|
||||
Key: "idem",
|
||||
Messages: []providers.Message{{Role: "user", Content: "once"}},
|
||||
Created: time.Now(),
|
||||
Updated: time.Now(),
|
||||
})
|
||||
|
||||
count1, err := MigrateFromJSON(ctx, sessionsDir, store)
|
||||
if err != nil {
|
||||
t.Fatalf("first migration: %v", err)
|
||||
}
|
||||
if count1 != 1 {
|
||||
t.Errorf("first run: expected 1, got %d", count1)
|
||||
}
|
||||
|
||||
// Second run should find only .migrated files, skip them.
|
||||
count2, err := MigrateFromJSON(ctx, sessionsDir, store)
|
||||
if err != nil {
|
||||
t.Fatalf("second migration: %v", err)
|
||||
}
|
||||
if count2 != 0 {
|
||||
t.Errorf("second run: expected 0, got %d", count2)
|
||||
}
|
||||
|
||||
history, err := store.GetHistory(ctx, "idem")
|
||||
if err != nil {
|
||||
t.Fatalf("GetHistory: %v", err)
|
||||
}
|
||||
if len(history) != 1 {
|
||||
t.Errorf("expected 1 message, got %d", len(history))
|
||||
}
|
||||
}
|
||||
|
||||
func TestMigrateFromJSON_ColonInKey(t *testing.T) {
|
||||
sessionsDir := t.TempDir()
|
||||
store := newTestStore(t)
|
||||
ctx := context.Background()
|
||||
|
||||
// File is named telegram_123 (sanitized), but the key inside is telegram:123.
|
||||
writeJSONSession(t, sessionsDir, "telegram_123.json", jsonSession{
|
||||
Key: "telegram:123",
|
||||
Messages: []providers.Message{{Role: "user", Content: "from telegram"}},
|
||||
Created: time.Now(),
|
||||
Updated: time.Now(),
|
||||
})
|
||||
|
||||
count, err := MigrateFromJSON(ctx, sessionsDir, store)
|
||||
if err != nil {
|
||||
t.Fatalf("MigrateFromJSON: %v", err)
|
||||
}
|
||||
if count != 1 {
|
||||
t.Errorf("expected 1, got %d", count)
|
||||
}
|
||||
|
||||
// Accessible via the original key "telegram:123".
|
||||
history, err := store.GetHistory(ctx, "telegram:123")
|
||||
if err != nil {
|
||||
t.Fatalf("GetHistory: %v", err)
|
||||
}
|
||||
if len(history) != 1 {
|
||||
t.Fatalf("expected 1 message, got %d", len(history))
|
||||
}
|
||||
if history[0].Content != "from telegram" {
|
||||
t.Errorf("content = %q", history[0].Content)
|
||||
}
|
||||
|
||||
// In the file-based store, "telegram:123" and "telegram_123" both
|
||||
// sanitize to the same filename, so they share storage. This is
|
||||
// expected — the colon-to-underscore mapping is a one-way function.
|
||||
history2, err := store.GetHistory(ctx, "telegram_123")
|
||||
if err != nil {
|
||||
t.Fatalf("GetHistory: %v", err)
|
||||
}
|
||||
if len(history2) != 1 {
|
||||
t.Errorf("expected 1 (same file), got %d", len(history2))
|
||||
}
|
||||
}
|
||||
|
||||
func TestMigrateFromJSON_RetryAfterCrash(t *testing.T) {
|
||||
// Simulates a crash during migration: first run writes messages
|
||||
// but doesn't rename the .json file. Second run must replace
|
||||
// (not duplicate) the messages thanks to SetHistory semantics.
|
||||
sessionsDir := t.TempDir()
|
||||
store := newTestStore(t)
|
||||
ctx := context.Background()
|
||||
|
||||
writeJSONSession(t, sessionsDir, "retry.json", jsonSession{
|
||||
Key: "retry",
|
||||
Messages: []providers.Message{
|
||||
{Role: "user", Content: "one"},
|
||||
{Role: "assistant", Content: "two"},
|
||||
},
|
||||
Created: time.Now(),
|
||||
Updated: time.Now(),
|
||||
})
|
||||
|
||||
// First migration succeeds — writes messages and renames file.
|
||||
count, err := MigrateFromJSON(ctx, sessionsDir, store)
|
||||
if err != nil {
|
||||
t.Fatalf("first migration: %v", err)
|
||||
}
|
||||
if count != 1 {
|
||||
t.Fatalf("expected 1, got %d", count)
|
||||
}
|
||||
|
||||
// Simulate "crash before rename": restore the .json file.
|
||||
src := filepath.Join(sessionsDir, "retry.json.migrated")
|
||||
dst := filepath.Join(sessionsDir, "retry.json")
|
||||
if renameErr := os.Rename(src, dst); renameErr != nil {
|
||||
t.Fatalf("restore .json: %v", renameErr)
|
||||
}
|
||||
|
||||
// Second migration should re-import without duplicating messages.
|
||||
count, err = MigrateFromJSON(ctx, sessionsDir, store)
|
||||
if err != nil {
|
||||
t.Fatalf("second migration: %v", err)
|
||||
}
|
||||
if count != 1 {
|
||||
t.Fatalf("expected 1, got %d", count)
|
||||
}
|
||||
|
||||
history, err := store.GetHistory(ctx, "retry")
|
||||
if err != nil {
|
||||
t.Fatalf("GetHistory: %v", err)
|
||||
}
|
||||
// Must be exactly 2 messages (not 4 from duplication).
|
||||
if len(history) != 2 {
|
||||
t.Fatalf("expected 2 messages (no duplicates), got %d", len(history))
|
||||
}
|
||||
if history[0].Content != "one" || history[1].Content != "two" {
|
||||
t.Errorf("unexpected messages: %+v", history)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMigrateFromJSON_NonexistentDir(t *testing.T) {
|
||||
store := newTestStore(t)
|
||||
ctx := context.Background()
|
||||
|
||||
count, err := MigrateFromJSON(ctx, "/nonexistent/path", store)
|
||||
if err != nil {
|
||||
t.Fatalf("MigrateFromJSON: %v", err)
|
||||
}
|
||||
if count != 0 {
|
||||
t.Errorf("expected 0, got %d", count)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,42 @@
|
||||
package memory
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/providers"
|
||||
)
|
||||
|
||||
// Store defines an interface for persistent session storage.
|
||||
// Each method is an atomic operation — there is no separate Save() call.
|
||||
type Store interface {
|
||||
// AddMessage appends a simple text message to a session.
|
||||
AddMessage(ctx context.Context, sessionKey, role, content string) error
|
||||
|
||||
// AddFullMessage appends a complete message (with tool calls, etc.) to a session.
|
||||
AddFullMessage(ctx context.Context, sessionKey string, msg providers.Message) error
|
||||
|
||||
// GetHistory returns all messages for a session in insertion order.
|
||||
// Returns an empty slice (not nil) if the session does not exist.
|
||||
GetHistory(ctx context.Context, sessionKey string) ([]providers.Message, error)
|
||||
|
||||
// GetSummary returns the conversation summary for a session.
|
||||
// Returns an empty string if no summary exists.
|
||||
GetSummary(ctx context.Context, sessionKey string) (string, error)
|
||||
|
||||
// SetSummary updates the conversation summary for a session.
|
||||
SetSummary(ctx context.Context, sessionKey, summary string) error
|
||||
|
||||
// TruncateHistory removes all but the last keepLast messages from a session.
|
||||
// If keepLast <= 0, all messages are removed.
|
||||
TruncateHistory(ctx context.Context, sessionKey string, keepLast int) error
|
||||
|
||||
// SetHistory replaces all messages in a session with the provided history.
|
||||
SetHistory(ctx context.Context, sessionKey string, history []providers.Message) error
|
||||
|
||||
// Compact reclaims storage by physically removing logically truncated
|
||||
// data. Backends that do not accumulate dead data may return nil.
|
||||
Compact(ctx context.Context, sessionKey string) error
|
||||
|
||||
// Close releases any resources held by the store.
|
||||
Close() error
|
||||
}
|
||||
@@ -190,28 +190,6 @@ func resolveProviderSelection(cfg *config.Config) (providerSelection, error) {
|
||||
sel.apiBase = "https://api.mistral.ai/v1"
|
||||
}
|
||||
}
|
||||
case "opencode":
|
||||
if cfg.Providers.Opencode.APIKey != "" || cfg.Providers.Opencode.APIBase != "" {
|
||||
sel.apiKey = cfg.Providers.Opencode.APIKey
|
||||
sel.apiBase = cfg.Providers.Opencode.APIBase
|
||||
sel.proxy = cfg.Providers.Opencode.Proxy
|
||||
if sel.apiBase == "" {
|
||||
sel.apiBase = "https://opencode.ai/zen/v1"
|
||||
}
|
||||
}
|
||||
case "kimi", "kimi-code", "moonshot":
|
||||
if cfg.Providers.Moonshot.APIKey != "" {
|
||||
sel.apiKey = cfg.Providers.Moonshot.APIKey
|
||||
sel.apiBase = cfg.Providers.Moonshot.APIBase
|
||||
sel.proxy = cfg.Providers.Moonshot.Proxy
|
||||
if sel.apiBase == "" {
|
||||
if providerName == "moonshot" {
|
||||
sel.apiBase = "https://api.moonshot.cn/v1"
|
||||
} else {
|
||||
sel.apiBase = "https://api.kimi.com/coding/v1"
|
||||
}
|
||||
}
|
||||
}
|
||||
case "github_copilot", "copilot":
|
||||
sel.providerType = providerTypeGitHubCopilot
|
||||
if cfg.Providers.GitHubCopilot.APIBase != "" {
|
||||
@@ -232,11 +210,7 @@ func resolveProviderSelection(cfg *config.Config) (providerSelection, error) {
|
||||
sel.apiBase = cfg.Providers.Moonshot.APIBase
|
||||
sel.proxy = cfg.Providers.Moonshot.Proxy
|
||||
if sel.apiBase == "" {
|
||||
if strings.Contains(lowerModel, "moonshot") || strings.HasPrefix(model, "moonshot/") {
|
||||
sel.apiBase = "https://api.moonshot.cn/v1"
|
||||
} else {
|
||||
sel.apiBase = "https://api.kimi.com/coding/v1"
|
||||
}
|
||||
sel.apiBase = "https://api.moonshot.cn/v1"
|
||||
}
|
||||
case strings.HasPrefix(model, "openrouter/") ||
|
||||
strings.HasPrefix(model, "anthropic/") ||
|
||||
|
||||
@@ -94,7 +94,7 @@ func CreateProviderFromConfig(cfg *config.ModelConfig) (LLMProvider, string, err
|
||||
|
||||
case "litellm", "openrouter", "groq", "zhipu", "gemini", "nvidia",
|
||||
"ollama", "moonshot", "shengsuanyun", "deepseek", "cerebras",
|
||||
"volcengine", "vllm", "qwen", "mistral", "opencode":
|
||||
"volcengine", "vllm", "qwen", "mistral":
|
||||
// All other OpenAI-compatible HTTP providers
|
||||
if cfg.APIKey == "" && cfg.APIBase == "" {
|
||||
return nil, "", fmt.Errorf("api_key or api_base is required for HTTP-based protocol %q", protocol)
|
||||
@@ -208,8 +208,6 @@ func getDefaultAPIBase(protocol string) string {
|
||||
return "http://localhost:8000/v1"
|
||||
case "mistral":
|
||||
return "https://api.mistral.ai/v1"
|
||||
case "opencode":
|
||||
return "https://opencode.ai/zen/v1"
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
|
||||
@@ -112,7 +112,6 @@ func TestCreateProviderFromConfig_DefaultAPIBase(t *testing.T) {
|
||||
{"vllm", "vllm"},
|
||||
{"deepseek", "deepseek"},
|
||||
{"ollama", "ollama"},
|
||||
{"opencode", "opencode"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
|
||||
@@ -33,7 +33,6 @@ type Provider struct {
|
||||
apiBase string
|
||||
maxTokensField string // Field name for max tokens (e.g., "max_completion_tokens" for o1/glm models)
|
||||
httpClient *http.Client
|
||||
isKimiAPI bool // true when apiBase points to api.kimi.com
|
||||
}
|
||||
|
||||
type Option func(*Provider)
|
||||
@@ -70,17 +69,10 @@ func NewProvider(apiKey, apiBase, proxy string, opts ...Option) *Provider {
|
||||
}
|
||||
}
|
||||
|
||||
trimmedBase := strings.TrimRight(apiBase, "/")
|
||||
var isKimi bool
|
||||
if parsed, err := url.Parse(trimmedBase); err == nil {
|
||||
isKimi = parsed.Hostname() == "api.kimi.com"
|
||||
}
|
||||
|
||||
p := &Provider{
|
||||
apiKey: apiKey,
|
||||
apiBase: trimmedBase,
|
||||
apiBase: strings.TrimRight(apiBase, "/"),
|
||||
httpClient: client,
|
||||
isKimiAPI: isKimi,
|
||||
}
|
||||
|
||||
for _, opt := range opts {
|
||||
@@ -184,12 +176,6 @@ func (p *Provider) Chat(
|
||||
if p.apiKey != "" {
|
||||
req.Header.Set("Authorization", "Bearer "+p.apiKey)
|
||||
}
|
||||
// Kimi Code API rejects requests without a recognized coding-agent
|
||||
// User-Agent. "KimiCLI/0.77" is the minimum version string accepted
|
||||
// by the api.kimi.com/coding/v1 endpoint (per Kimi's API docs).
|
||||
if p.isKimiAPI {
|
||||
req.Header.Set("User-Agent", "KimiCLI/0.77")
|
||||
}
|
||||
|
||||
resp, err := p.httpClient.Do(req)
|
||||
if err != nil {
|
||||
|
||||
@@ -2,7 +2,6 @@ package openai_compat
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
@@ -421,82 +420,6 @@ func TestProvider_FunctionalOptionRequestTimeoutNonPositive(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// roundTripFunc adapts a function to http.RoundTripper for test injection.
|
||||
type roundTripFunc func(*http.Request) (*http.Response, error)
|
||||
|
||||
func (f roundTripFunc) RoundTrip(r *http.Request) (*http.Response, error) {
|
||||
return f(r)
|
||||
}
|
||||
|
||||
func TestProviderChat_KimiCodeUserAgent(t *testing.T) {
|
||||
okBody := `{"choices":[{"message":{"content":"ok"},"finish_reason":"stop"}]}`
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
apiBase string
|
||||
wantAgent string
|
||||
}{
|
||||
{
|
||||
name: "sets KimiCLI User-Agent for api.kimi.com",
|
||||
apiBase: "https://api.kimi.com/coding/v1",
|
||||
wantAgent: "KimiCLI/0.77",
|
||||
},
|
||||
{
|
||||
name: "does not set KimiCLI User-Agent for other hosts",
|
||||
apiBase: "https://api.example.com/v1",
|
||||
wantAgent: "",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
var gotUserAgent string
|
||||
|
||||
p := NewProvider("key", tt.apiBase, "")
|
||||
p.httpClient.Transport = roundTripFunc(
|
||||
func(r *http.Request) (*http.Response, error) {
|
||||
gotUserAgent = r.Header.Get("User-Agent")
|
||||
return &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Body: io.NopCloser(
|
||||
strings.NewReader(okBody),
|
||||
),
|
||||
Header: http.Header{
|
||||
"Content-Type": {"application/json"},
|
||||
},
|
||||
}, nil
|
||||
},
|
||||
)
|
||||
|
||||
_, err := p.Chat(
|
||||
t.Context(),
|
||||
[]Message{{Role: "user", Content: "hi"}},
|
||||
nil,
|
||||
"kimi-k2.5",
|
||||
nil,
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("Chat() error = %v", err)
|
||||
}
|
||||
|
||||
if tt.wantAgent != "" {
|
||||
if gotUserAgent != tt.wantAgent {
|
||||
t.Fatalf(
|
||||
"User-Agent = %q, want %q",
|
||||
gotUserAgent, tt.wantAgent,
|
||||
)
|
||||
}
|
||||
} else {
|
||||
if gotUserAgent == "KimiCLI/0.77" {
|
||||
t.Fatalf(
|
||||
"User-Agent should not be KimiCLI/0.77 for non-kimi host",
|
||||
)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSerializeMessages_PlainText(t *testing.T) {
|
||||
messages := []protocoltypes.Message{
|
||||
{Role: "user", Content: "hello"},
|
||||
|
||||
@@ -3,6 +3,7 @@ package tools
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync/atomic"
|
||||
)
|
||||
|
||||
type SendCallback func(channel, chatID, content string) error
|
||||
@@ -11,7 +12,7 @@ type MessageTool struct {
|
||||
sendCallback SendCallback
|
||||
defaultChannel string
|
||||
defaultChatID string
|
||||
sentInRound bool // Tracks whether a message was sent in the current processing round
|
||||
sentInRound atomic.Bool // Tracks whether a message was sent in the current processing round
|
||||
}
|
||||
|
||||
func NewMessageTool() *MessageTool {
|
||||
@@ -50,12 +51,12 @@ func (t *MessageTool) Parameters() map[string]any {
|
||||
func (t *MessageTool) SetContext(channel, chatID string) {
|
||||
t.defaultChannel = channel
|
||||
t.defaultChatID = chatID
|
||||
t.sentInRound = false // Reset send tracking for new processing round
|
||||
t.sentInRound.Store(false) // Reset send tracking for new processing round
|
||||
}
|
||||
|
||||
// HasSentInRound returns true if the message tool sent a message during the current round.
|
||||
func (t *MessageTool) HasSentInRound() bool {
|
||||
return t.sentInRound
|
||||
return t.sentInRound.Load()
|
||||
}
|
||||
|
||||
func (t *MessageTool) SetSendCallback(callback SendCallback) {
|
||||
@@ -94,7 +95,7 @@ func (t *MessageTool) Execute(ctx context.Context, args map[string]any) *ToolRes
|
||||
}
|
||||
}
|
||||
|
||||
t.sentInRound = true
|
||||
t.sentInRound.Store(true)
|
||||
// Silent: user already received the message directly
|
||||
return &ToolResult{
|
||||
ForLLM: fmt.Sprintf("Message sent to %s:%s", channel, chatID),
|
||||
|
||||
+2
-2
@@ -30,9 +30,9 @@ var (
|
||||
regexp.MustCompile(`\brm\s+-[rf]{1,2}\b`),
|
||||
regexp.MustCompile(`\bdel\s+/[fq]\b`),
|
||||
regexp.MustCompile(`\brmdir\s+/s\b`),
|
||||
// Match disk wiping commands, avoid matching --format flags
|
||||
// Match disk wiping commands (must be followed by space/args)
|
||||
regexp.MustCompile(
|
||||
`(?:^|[;&|]\s*|\s+)(format|mkfs|diskpart)\s`,
|
||||
`\b(format|mkfs|diskpart)\b\s`,
|
||||
),
|
||||
regexp.MustCompile(`\bdd\s+if=`),
|
||||
// Block writes to block devices (all common naming schemes).
|
||||
|
||||
@@ -366,56 +366,6 @@ func TestShellTool_BlockDevices(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// TestShellTool_DenyPattern_DiskWiping verifies the deny pattern for disk wiping
|
||||
// commands (format, mkfs, diskpart) blocks them when preceded by shell separators
|
||||
// but does NOT block legitimate uses like --format flags.
|
||||
func TestShellTool_DenyPattern_DiskWiping(t *testing.T) {
|
||||
tool, err := NewExecTool("", false)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to configure exec tool: %s", err)
|
||||
}
|
||||
|
||||
// These should be BLOCKED (disk wiping commands)
|
||||
blockedCmds := []struct {
|
||||
name string
|
||||
cmd string
|
||||
}{
|
||||
{"format with space", "format c:"},
|
||||
{"mkfs standalone", "mkfs /dev/sda"},
|
||||
{"semicolon format", "echo hello; format c:"},
|
||||
{"pipe format", "echo hello | format c:"},
|
||||
{"and format", "echo hello && format c:"},
|
||||
{"diskpart standalone", "diskpart /s script.txt"},
|
||||
}
|
||||
|
||||
for _, tt := range blockedCmds {
|
||||
t.Run("blocked_"+tt.name, func(t *testing.T) {
|
||||
msg := tool.guardCommand(tt.cmd, "")
|
||||
if !strings.Contains(msg, "blocked") {
|
||||
t.Errorf("Expected %q to be blocked by safety guard, got: %q", tt.cmd, msg)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// These should be ALLOWED (not disk wiping)
|
||||
allowed := []struct {
|
||||
name string
|
||||
cmd string
|
||||
}{
|
||||
{"--format flag", "echo test --format json"},
|
||||
{"go fmt", "echo go fmt ./..."},
|
||||
}
|
||||
|
||||
for _, tt := range allowed {
|
||||
t.Run("allowed_"+tt.name, func(t *testing.T) {
|
||||
msg := tool.guardCommand(tt.cmd, "")
|
||||
if msg != "" {
|
||||
t.Errorf("Expected %q to be allowed, but it was blocked: %s", tt.cmd, msg)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestShellTool_SafePathsInWorkspaceRestriction verifies that safe kernel pseudo-devices
|
||||
// are allowed even when workspace restriction is active.
|
||||
func TestShellTool_SafePathsInWorkspaceRestriction(t *testing.T) {
|
||||
|
||||
+43
-26
@@ -10,6 +10,7 @@ import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"sync"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/logger"
|
||||
"github.com/sipeed/picoclaw/pkg/providers"
|
||||
@@ -121,37 +122,53 @@ func RunToolLoop(
|
||||
}
|
||||
messages = append(messages, assistantMsg)
|
||||
|
||||
// 7. Execute tool calls
|
||||
for _, tc := range normalizedToolCalls {
|
||||
argsJSON, _ := json.Marshal(tc.Arguments)
|
||||
argsPreview := utils.Truncate(string(argsJSON), 200)
|
||||
logger.InfoCF("toolloop", fmt.Sprintf("Tool call: %s(%s)", tc.Name, argsPreview),
|
||||
map[string]any{
|
||||
"tool": tc.Name,
|
||||
"iteration": iteration,
|
||||
})
|
||||
// 7. Execute tool calls in parallel
|
||||
type indexedResult struct {
|
||||
result *ToolResult
|
||||
tc providers.ToolCall
|
||||
}
|
||||
|
||||
// Execute tool (no async callback for subagents - they run independently)
|
||||
var toolResult *ToolResult
|
||||
if config.Tools != nil {
|
||||
toolResult = config.Tools.ExecuteWithContext(ctx, tc.Name, tc.Arguments, channel, chatID, nil)
|
||||
} else {
|
||||
toolResult = ErrorResult("No tools available")
|
||||
results := make([]indexedResult, len(normalizedToolCalls))
|
||||
var wg sync.WaitGroup
|
||||
|
||||
for i, tc := range normalizedToolCalls {
|
||||
results[i].tc = tc
|
||||
|
||||
wg.Add(1)
|
||||
go func(idx int, tc providers.ToolCall) {
|
||||
defer wg.Done()
|
||||
|
||||
argsJSON, _ := json.Marshal(tc.Arguments)
|
||||
argsPreview := utils.Truncate(string(argsJSON), 200)
|
||||
logger.InfoCF("toolloop", fmt.Sprintf("Tool call: %s(%s)", tc.Name, argsPreview),
|
||||
map[string]any{
|
||||
"tool": tc.Name,
|
||||
"iteration": iteration,
|
||||
})
|
||||
|
||||
var toolResult *ToolResult
|
||||
if config.Tools != nil {
|
||||
toolResult = config.Tools.ExecuteWithContext(ctx, tc.Name, tc.Arguments, channel, chatID, nil)
|
||||
} else {
|
||||
toolResult = ErrorResult("No tools available")
|
||||
}
|
||||
results[idx].result = toolResult
|
||||
}(i, tc)
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
// Append results in original order
|
||||
for _, r := range results {
|
||||
contentForLLM := r.result.ForLLM
|
||||
if contentForLLM == "" && r.result.Err != nil {
|
||||
contentForLLM = r.result.Err.Error()
|
||||
}
|
||||
|
||||
// Determine content for LLM
|
||||
contentForLLM := toolResult.ForLLM
|
||||
if contentForLLM == "" && toolResult.Err != nil {
|
||||
contentForLLM = toolResult.Err.Error()
|
||||
}
|
||||
|
||||
// Add tool result message
|
||||
toolResultMsg := providers.Message{
|
||||
messages = append(messages, providers.Message{
|
||||
Role: "tool",
|
||||
Content: contentForLLM,
|
||||
ToolCallID: tc.ID,
|
||||
}
|
||||
messages = append(messages, toolResultMsg)
|
||||
ToolCallID: r.tc.ID,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
+107
-87
@@ -395,6 +395,88 @@ func (p *PerplexitySearchProvider) Search(ctx context.Context, query string, cou
|
||||
return fmt.Sprintf("Results for: %s (via Perplexity)\n%s", query, searchResp.Choices[0].Message.Content), nil
|
||||
}
|
||||
|
||||
type GLMSearchProvider struct {
|
||||
apiKey string
|
||||
baseURL string
|
||||
searchEngine string
|
||||
proxy string
|
||||
client *http.Client
|
||||
}
|
||||
|
||||
func (p *GLMSearchProvider) Search(ctx context.Context, query string, count int) (string, error) {
|
||||
searchURL := p.baseURL
|
||||
if searchURL == "" {
|
||||
searchURL = "https://open.bigmodel.cn/api/paas/v4/web_search"
|
||||
}
|
||||
|
||||
payload := map[string]any{
|
||||
"search_query": query,
|
||||
"search_engine": p.searchEngine,
|
||||
"search_intent": false,
|
||||
"count": count,
|
||||
"content_size": "medium",
|
||||
}
|
||||
|
||||
bodyBytes, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to marshal payload: %w", err)
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", searchURL, bytes.NewReader(bodyBytes))
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", "Bearer "+p.apiKey)
|
||||
|
||||
resp, err := p.client.Do(req)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("request failed: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(io.LimitReader(resp.Body, 1<<20))
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to read response: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return "", fmt.Errorf("GLM Search API error (status %d): %s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
var searchResp struct {
|
||||
SearchResult []struct {
|
||||
Title string `json:"title"`
|
||||
Content string `json:"content"`
|
||||
Link string `json:"link"`
|
||||
} `json:"search_result"`
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(body, &searchResp); err != nil {
|
||||
return "", fmt.Errorf("failed to parse response: %w", err)
|
||||
}
|
||||
|
||||
results := searchResp.SearchResult
|
||||
if len(results) == 0 {
|
||||
return fmt.Sprintf("No results for: %s", query), nil
|
||||
}
|
||||
|
||||
var lines []string
|
||||
lines = append(lines, fmt.Sprintf("Results for: %s (via GLM Search)", query))
|
||||
for i, item := range results {
|
||||
if i >= count {
|
||||
break
|
||||
}
|
||||
lines = append(lines, fmt.Sprintf("%d. %s\n %s", i+1, item.Title, item.Link))
|
||||
if item.Content != "" {
|
||||
lines = append(lines, fmt.Sprintf(" %s", item.Content))
|
||||
}
|
||||
}
|
||||
|
||||
return strings.Join(lines, "\n"), nil
|
||||
}
|
||||
|
||||
type WebSearchTool struct {
|
||||
provider SearchProvider
|
||||
maxResults int
|
||||
@@ -413,9 +495,11 @@ type WebSearchToolOptions struct {
|
||||
PerplexityAPIKey string
|
||||
PerplexityMaxResults int
|
||||
PerplexityEnabled bool
|
||||
ExaAPIKey string
|
||||
ExaMaxResults int
|
||||
ExaEnabled bool
|
||||
GLMSearchAPIKey string
|
||||
GLMSearchBaseURL string
|
||||
GLMSearchEngine string
|
||||
GLMSearchMaxResults int
|
||||
GLMSearchEnabled bool
|
||||
Proxy string
|
||||
}
|
||||
|
||||
@@ -423,7 +507,7 @@ func NewWebSearchTool(opts WebSearchToolOptions) (*WebSearchTool, error) {
|
||||
var provider SearchProvider
|
||||
maxResults := 5
|
||||
|
||||
// Priority: Perplexity > Exa > Brave > Tavily > DuckDuckGo
|
||||
// Priority: Perplexity > Brave > Tavily > DuckDuckGo > GLM Search
|
||||
if opts.PerplexityEnabled && opts.PerplexityAPIKey != "" {
|
||||
client, err := createHTTPClient(opts.Proxy, perplexityTimeout)
|
||||
if err != nil {
|
||||
@@ -433,15 +517,6 @@ func NewWebSearchTool(opts WebSearchToolOptions) (*WebSearchTool, error) {
|
||||
if opts.PerplexityMaxResults > 0 {
|
||||
maxResults = opts.PerplexityMaxResults
|
||||
}
|
||||
} else if opts.ExaEnabled && opts.ExaAPIKey != "" {
|
||||
client, err := createHTTPClient(opts.Proxy, searchTimeout)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create HTTP client for Exa: %w", err)
|
||||
}
|
||||
provider = &ExaSearchProvider{apiKey: opts.ExaAPIKey, proxy: opts.Proxy, client: client}
|
||||
if opts.ExaMaxResults > 0 {
|
||||
maxResults = opts.ExaMaxResults
|
||||
}
|
||||
} else if opts.BraveEnabled && opts.BraveAPIKey != "" {
|
||||
client, err := createHTTPClient(opts.Proxy, searchTimeout)
|
||||
if err != nil {
|
||||
@@ -474,6 +549,25 @@ func NewWebSearchTool(opts WebSearchToolOptions) (*WebSearchTool, error) {
|
||||
if opts.DuckDuckGoMaxResults > 0 {
|
||||
maxResults = opts.DuckDuckGoMaxResults
|
||||
}
|
||||
} else if opts.GLMSearchEnabled && opts.GLMSearchAPIKey != "" {
|
||||
client, err := createHTTPClient(opts.Proxy, searchTimeout)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create HTTP client for GLM Search: %w", err)
|
||||
}
|
||||
searchEngine := opts.GLMSearchEngine
|
||||
if searchEngine == "" {
|
||||
searchEngine = "search_std"
|
||||
}
|
||||
provider = &GLMSearchProvider{
|
||||
apiKey: opts.GLMSearchAPIKey,
|
||||
baseURL: opts.GLMSearchBaseURL,
|
||||
searchEngine: searchEngine,
|
||||
proxy: opts.Proxy,
|
||||
client: client,
|
||||
}
|
||||
if opts.GLMSearchMaxResults > 0 {
|
||||
maxResults = opts.GLMSearchMaxResults
|
||||
}
|
||||
} else {
|
||||
return nil, nil
|
||||
}
|
||||
@@ -721,77 +815,3 @@ func (t *WebFetchTool) extractText(htmlContent string) string {
|
||||
|
||||
return strings.Join(cleanLines, "\n")
|
||||
}
|
||||
|
||||
// ExaSearchProvider uses the Exa AI search API (https://exa.ai).
|
||||
type ExaSearchProvider struct {
|
||||
apiKey string
|
||||
proxy string
|
||||
client *http.Client
|
||||
}
|
||||
|
||||
func (p *ExaSearchProvider) Search(ctx context.Context, query string, count int) (string, error) {
|
||||
reqBody := map[string]any{
|
||||
"query": query,
|
||||
"num_results": count,
|
||||
"type": "neural",
|
||||
}
|
||||
jsonData, err := json.Marshal(reqBody)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("exa: marshal error: %w", err)
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", "https://api.exa.ai/search", bytes.NewReader(jsonData))
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("exa: request error: %w", err)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("x-api-key", p.apiKey)
|
||||
|
||||
resp, err := p.client.Do(req)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("exa: search failed: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("exa: read error: %w", err)
|
||||
}
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return "", fmt.Errorf("exa: API error %d: %s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
var result struct {
|
||||
Results []struct {
|
||||
Title string `json:"title"`
|
||||
URL string `json:"url"`
|
||||
Text string `json:"text"`
|
||||
} `json:"results"`
|
||||
}
|
||||
if err := json.Unmarshal(body, &result); err != nil {
|
||||
return "", fmt.Errorf("exa: parse error: %w", err)
|
||||
}
|
||||
|
||||
if len(result.Results) == 0 {
|
||||
return fmt.Sprintf("No results for: %s", query), nil
|
||||
}
|
||||
|
||||
var lines []string
|
||||
lines = append(lines, fmt.Sprintf("Results for: %s (via Exa)", query))
|
||||
maxResults := count
|
||||
if maxResults > len(result.Results) {
|
||||
maxResults = len(result.Results)
|
||||
}
|
||||
for i, r := range result.Results[:maxResults] {
|
||||
lines = append(lines, fmt.Sprintf("%d. %s\n %s", i+1, r.Title, r.URL))
|
||||
if r.Text != "" {
|
||||
snippet := r.Text
|
||||
if len(snippet) > 200 {
|
||||
snippet = snippet[:200] + "..."
|
||||
}
|
||||
lines = append(lines, fmt.Sprintf(" %s", snippet))
|
||||
}
|
||||
}
|
||||
|
||||
return strings.Join(lines, "\n"), nil
|
||||
}
|
||||
|
||||
+105
-189
@@ -5,7 +5,6 @@ import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
@@ -683,86 +682,7 @@ func TestWebTool_TavilySearch_Success(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewWebSearchTool_ExaPriority(t *testing.T) {
|
||||
// Exa should be selected when enabled with API key
|
||||
tool, err := NewWebSearchTool(WebSearchToolOptions{
|
||||
ExaEnabled: true,
|
||||
ExaAPIKey: "exa-key",
|
||||
ExaMaxResults: 3,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("NewWebSearchTool() error: %v", err)
|
||||
}
|
||||
if tool == nil {
|
||||
t.Fatal("Expected non-nil tool when Exa is enabled with API key")
|
||||
}
|
||||
if _, ok := tool.provider.(*ExaSearchProvider); !ok {
|
||||
t.Fatalf("provider type = %T, want *ExaSearchProvider", tool.provider)
|
||||
}
|
||||
if tool.maxResults != 3 {
|
||||
t.Fatalf("maxResults = %d, want 3", tool.maxResults)
|
||||
}
|
||||
|
||||
// Exa enabled but no API key should fall through
|
||||
tool, err = NewWebSearchTool(WebSearchToolOptions{
|
||||
ExaEnabled: true,
|
||||
ExaAPIKey: "",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("NewWebSearchTool() error: %v", err)
|
||||
}
|
||||
if tool != nil {
|
||||
t.Errorf("Expected nil tool when Exa API key is empty and no other provider enabled")
|
||||
}
|
||||
|
||||
// Perplexity should take priority over Exa
|
||||
tool, err = NewWebSearchTool(WebSearchToolOptions{
|
||||
PerplexityEnabled: true,
|
||||
PerplexityAPIKey: "perp-key",
|
||||
ExaEnabled: true,
|
||||
ExaAPIKey: "exa-key",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("NewWebSearchTool() error: %v", err)
|
||||
}
|
||||
if _, ok := tool.provider.(*PerplexitySearchProvider); !ok {
|
||||
t.Fatalf("provider type = %T, want *PerplexitySearchProvider (Perplexity should outrank Exa)", tool.provider)
|
||||
}
|
||||
|
||||
// Exa should take priority over Brave
|
||||
tool, err = NewWebSearchTool(WebSearchToolOptions{
|
||||
ExaEnabled: true,
|
||||
ExaAPIKey: "exa-key",
|
||||
BraveEnabled: true,
|
||||
BraveAPIKey: "brave-key",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("NewWebSearchTool() error: %v", err)
|
||||
}
|
||||
if _, ok := tool.provider.(*ExaSearchProvider); !ok {
|
||||
t.Fatalf("provider type = %T, want *ExaSearchProvider (Exa should outrank Brave)", tool.provider)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewWebSearchTool_ExaProxyPropagation(t *testing.T) {
|
||||
tool, err := NewWebSearchTool(WebSearchToolOptions{
|
||||
ExaEnabled: true,
|
||||
ExaAPIKey: "k",
|
||||
Proxy: "http://127.0.0.1:7890",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("NewWebSearchTool() error: %v", err)
|
||||
}
|
||||
p, ok := tool.provider.(*ExaSearchProvider)
|
||||
if !ok {
|
||||
t.Fatalf("provider type = %T, want *ExaSearchProvider", tool.provider)
|
||||
}
|
||||
if p.proxy != "http://127.0.0.1:7890" {
|
||||
t.Fatalf("provider proxy = %q, want %q", p.proxy, "http://127.0.0.1:7890")
|
||||
}
|
||||
}
|
||||
|
||||
func TestExaSearchProvider_Success(t *testing.T) {
|
||||
func TestWebTool_GLMSearch_Success(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != "POST" {
|
||||
t.Errorf("Expected POST request, got %s", r.Method)
|
||||
@@ -770,130 +690,126 @@ func TestExaSearchProvider_Success(t *testing.T) {
|
||||
if r.Header.Get("Content-Type") != "application/json" {
|
||||
t.Errorf("Expected Content-Type application/json, got %s", r.Header.Get("Content-Type"))
|
||||
}
|
||||
if r.Header.Get("x-api-key") != "test-exa-key" {
|
||||
t.Errorf("Expected x-api-key test-exa-key, got %s", r.Header.Get("x-api-key"))
|
||||
if r.Header.Get("Authorization") != "Bearer test-glm-key" {
|
||||
t.Errorf("Expected Authorization Bearer test-glm-key, got %s", r.Header.Get("Authorization"))
|
||||
}
|
||||
|
||||
// Verify payload
|
||||
body, _ := io.ReadAll(r.Body)
|
||||
var payload map[string]any
|
||||
json.Unmarshal(body, &payload)
|
||||
if payload["query"] != "test query" {
|
||||
t.Errorf("Expected query 'test query', got %v", payload["query"])
|
||||
json.NewDecoder(r.Body).Decode(&payload)
|
||||
if payload["search_query"] != "test query" {
|
||||
t.Errorf("Expected search_query 'test query', got %v", payload["search_query"])
|
||||
}
|
||||
if payload["type"] != "neural" {
|
||||
t.Errorf("Expected type 'neural', got %v", payload["type"])
|
||||
if payload["search_engine"] != "search_std" {
|
||||
t.Errorf("Expected search_engine 'search_std', got %v", payload["search_engine"])
|
||||
}
|
||||
|
||||
response := map[string]any{
|
||||
"results": []map[string]any{
|
||||
{"title": "Exa Result 1", "url": "https://exa.ai/1", "text": "First result text"},
|
||||
{"title": "Exa Result 2", "url": "https://exa.ai/2", "text": "Second result text"},
|
||||
{"title": "Exa Result 3", "url": "https://exa.ai/3", "text": "Third result text"},
|
||||
"id": "web-search-test",
|
||||
"created": 1709568000,
|
||||
"search_result": []map[string]any{
|
||||
{
|
||||
"title": "Test GLM Result",
|
||||
"content": "GLM search snippet",
|
||||
"link": "https://example.com/glm",
|
||||
"media": "Example",
|
||||
"publish_date": "2026-03-04",
|
||||
},
|
||||
},
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
json.NewEncoder(w).Encode(response)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
provider := &ExaSearchProvider{
|
||||
apiKey: "test-exa-key",
|
||||
client: &http.Client{},
|
||||
}
|
||||
|
||||
// Temporarily override the API URL by using a custom transport
|
||||
provider.client.Transport = rewriteHostTransport(server.URL)
|
||||
|
||||
result, err := provider.Search(context.Background(), "test query", 5)
|
||||
if err != nil {
|
||||
t.Fatalf("Search() error: %v", err)
|
||||
}
|
||||
|
||||
if !strings.Contains(result, "via Exa") {
|
||||
t.Errorf("Expected '(via Exa)' attribution, got: %s", result)
|
||||
}
|
||||
if !strings.Contains(result, "Exa Result 1") || !strings.Contains(result, "https://exa.ai/1") {
|
||||
t.Errorf("Expected results in output, got: %s", result)
|
||||
}
|
||||
if !strings.Contains(result, "First result text") {
|
||||
t.Errorf("Expected snippet text in output, got: %s", result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExaSearchProvider_EmptyResults(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
response := map[string]any{"results": []map[string]any{}}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(response)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
provider := &ExaSearchProvider{
|
||||
apiKey: "test-key",
|
||||
client: &http.Client{Transport: rewriteHostTransport(server.URL)},
|
||||
}
|
||||
|
||||
result, err := provider.Search(context.Background(), "no results query", 5)
|
||||
if err != nil {
|
||||
t.Fatalf("Search() error: %v", err)
|
||||
}
|
||||
if !strings.Contains(result, "No results for: no results query") {
|
||||
t.Errorf("Expected 'No results' message, got: %s", result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExaSearchProvider_MaxResultsCapping(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Return 5 results
|
||||
results := make([]map[string]any, 5)
|
||||
for i := range results {
|
||||
results[i] = map[string]any{
|
||||
"title": fmt.Sprintf("Result %d", i+1),
|
||||
"url": fmt.Sprintf("https://exa.ai/%d", i+1),
|
||||
"text": fmt.Sprintf("Text %d", i+1),
|
||||
}
|
||||
}
|
||||
response := map[string]any{"results": results}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(response)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
provider := &ExaSearchProvider{
|
||||
apiKey: "test-key",
|
||||
client: &http.Client{Transport: rewriteHostTransport(server.URL)},
|
||||
}
|
||||
|
||||
// Request only 2 results even though API returns 5
|
||||
result, err := provider.Search(context.Background(), "test", 2)
|
||||
if err != nil {
|
||||
t.Fatalf("Search() error: %v", err)
|
||||
}
|
||||
|
||||
if !strings.Contains(result, "Result 1") || !strings.Contains(result, "Result 2") {
|
||||
t.Errorf("Expected first 2 results, got: %s", result)
|
||||
}
|
||||
if strings.Contains(result, "Result 3") {
|
||||
t.Errorf("Expected results capped at 2, but got Result 3 in output: %s", result)
|
||||
}
|
||||
}
|
||||
|
||||
// rewriteHostTransport returns an http.RoundTripper that redirects all requests to the given target URL.
|
||||
func rewriteHostTransport(target string) http.RoundTripper {
|
||||
return roundTripFunc(func(req *http.Request) (*http.Response, error) {
|
||||
newURL := target + req.URL.Path
|
||||
newReq, err := http.NewRequestWithContext(req.Context(), req.Method, newURL, req.Body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
newReq.Header = req.Header
|
||||
return http.DefaultClient.Do(newReq)
|
||||
tool, err := NewWebSearchTool(WebSearchToolOptions{
|
||||
GLMSearchEnabled: true,
|
||||
GLMSearchAPIKey: "test-glm-key",
|
||||
GLMSearchBaseURL: server.URL,
|
||||
GLMSearchEngine: "search_std",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("NewWebSearchTool() error: %v", err)
|
||||
}
|
||||
|
||||
result := tool.Execute(context.Background(), map[string]any{
|
||||
"query": "test query",
|
||||
})
|
||||
|
||||
if result.IsError {
|
||||
t.Errorf("Expected success, got IsError=true: %s", result.ForLLM)
|
||||
}
|
||||
if !strings.Contains(result.ForUser, "Test GLM Result") {
|
||||
t.Errorf("Expected 'Test GLM Result' in output, got: %s", result.ForUser)
|
||||
}
|
||||
if !strings.Contains(result.ForUser, "https://example.com/glm") {
|
||||
t.Errorf("Expected URL in output, got: %s", result.ForUser)
|
||||
}
|
||||
if !strings.Contains(result.ForUser, "via GLM Search") {
|
||||
t.Errorf("Expected 'via GLM Search' in output, got: %s", result.ForUser)
|
||||
}
|
||||
}
|
||||
|
||||
type roundTripFunc func(*http.Request) (*http.Response, error)
|
||||
func TestWebTool_GLMSearch_APIError(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
w.Write([]byte(`{"error":"invalid api key"}`))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
func (f roundTripFunc) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
return f(req)
|
||||
tool, err := NewWebSearchTool(WebSearchToolOptions{
|
||||
GLMSearchEnabled: true,
|
||||
GLMSearchAPIKey: "bad-key",
|
||||
GLMSearchBaseURL: server.URL,
|
||||
GLMSearchEngine: "search_std",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("NewWebSearchTool() error: %v", err)
|
||||
}
|
||||
|
||||
result := tool.Execute(context.Background(), map[string]any{
|
||||
"query": "test query",
|
||||
})
|
||||
|
||||
if !result.IsError {
|
||||
t.Errorf("Expected IsError=true for 401 response")
|
||||
}
|
||||
if !strings.Contains(result.ForLLM, "status 401") {
|
||||
t.Errorf("Expected status 401 in error, got: %s", result.ForLLM)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWebTool_GLMSearch_Priority(t *testing.T) {
|
||||
// GLM Search should only be selected when all other providers are disabled
|
||||
tool, err := NewWebSearchTool(WebSearchToolOptions{
|
||||
DuckDuckGoEnabled: true,
|
||||
DuckDuckGoMaxResults: 5,
|
||||
GLMSearchEnabled: true,
|
||||
GLMSearchAPIKey: "test-key",
|
||||
GLMSearchBaseURL: "https://example.com",
|
||||
GLMSearchEngine: "search_std",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("NewWebSearchTool() error: %v", err)
|
||||
}
|
||||
|
||||
// DuckDuckGo should win over GLM Search
|
||||
if _, ok := tool.provider.(*DuckDuckGoSearchProvider); !ok {
|
||||
t.Errorf("Expected DuckDuckGoSearchProvider when both enabled, got %T", tool.provider)
|
||||
}
|
||||
|
||||
// With DuckDuckGo disabled, GLM Search should be selected
|
||||
tool2, err := NewWebSearchTool(WebSearchToolOptions{
|
||||
DuckDuckGoEnabled: false,
|
||||
GLMSearchEnabled: true,
|
||||
GLMSearchAPIKey: "test-key",
|
||||
GLMSearchBaseURL: "https://example.com",
|
||||
GLMSearchEngine: "search_std",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("NewWebSearchTool() error: %v", err)
|
||||
}
|
||||
if _, ok := tool2.provider.(*GLMSearchProvider); !ok {
|
||||
t.Errorf("Expected GLMSearchProvider when only GLM enabled, got %T", tool2.provider)
|
||||
}
|
||||
}
|
||||
|
||||
+19
-4
@@ -3,6 +3,7 @@ package utils
|
||||
import (
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
@@ -52,11 +53,12 @@ type DownloadOptions struct {
|
||||
Timeout time.Duration
|
||||
ExtraHeaders map[string]string
|
||||
LoggerPrefix string
|
||||
ProxyURL string
|
||||
}
|
||||
|
||||
// DownloadFile downloads a file from URL to a local temp directory.
|
||||
// Returns the local file path or empty string on error.
|
||||
func DownloadFile(url, filename string, opts DownloadOptions) string {
|
||||
func DownloadFile(urlStr, filename string, opts DownloadOptions) string {
|
||||
// Set defaults
|
||||
if opts.Timeout == 0 {
|
||||
opts.Timeout = 60 * time.Second
|
||||
@@ -78,7 +80,7 @@ func DownloadFile(url, filename string, opts DownloadOptions) string {
|
||||
localPath := filepath.Join(mediaDir, uuid.New().String()[:8]+"_"+safeName)
|
||||
|
||||
// Create HTTP request
|
||||
req, err := http.NewRequest("GET", url, nil)
|
||||
req, err := http.NewRequest("GET", urlStr, nil)
|
||||
if err != nil {
|
||||
logger.ErrorCF(opts.LoggerPrefix, "Failed to create download request", map[string]any{
|
||||
"error": err.Error(),
|
||||
@@ -92,11 +94,24 @@ func DownloadFile(url, filename string, opts DownloadOptions) string {
|
||||
}
|
||||
|
||||
client := &http.Client{Timeout: opts.Timeout}
|
||||
if opts.ProxyURL != "" {
|
||||
proxyURL, parseErr := url.Parse(opts.ProxyURL)
|
||||
if parseErr != nil {
|
||||
logger.ErrorCF(opts.LoggerPrefix, "Invalid proxy URL for download", map[string]any{
|
||||
"error": parseErr.Error(),
|
||||
"proxy": opts.ProxyURL,
|
||||
})
|
||||
return ""
|
||||
}
|
||||
client.Transport = &http.Transport{
|
||||
Proxy: http.ProxyURL(proxyURL),
|
||||
}
|
||||
}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
logger.ErrorCF(opts.LoggerPrefix, "Failed to download file", map[string]any{
|
||||
"error": err.Error(),
|
||||
"url": url,
|
||||
"url": urlStr,
|
||||
})
|
||||
return ""
|
||||
}
|
||||
@@ -105,7 +120,7 @@ func DownloadFile(url, filename string, opts DownloadOptions) string {
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
logger.ErrorCF(opts.LoggerPrefix, "File download returned non-200 status", map[string]any{
|
||||
"status": resp.StatusCode,
|
||||
"url": url,
|
||||
"url": urlStr,
|
||||
})
|
||||
return ""
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user