mirror of
https://github.com/sipeed/picoclaw.git
synced 2026-06-12 18:08:54 +00:00
merge: resolve conflicts with upstream/main
Accept upstream versions for all non-Telegram files to keep PR scope focused on Telegram message chunking only. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
+1
-1
@@ -17,4 +17,4 @@
|
||||
# BRAVE_SEARCH_API_KEY=BSA...
|
||||
|
||||
# ── Timezone ──────────────────────────────
|
||||
TZ=Asia/Tokyo
|
||||
TZ=Asia/Shanghai
|
||||
|
||||
@@ -54,7 +54,7 @@
|
||||
|
||||
## 📢 News
|
||||
|
||||
2026-02-16 🎉 PicoClaw hit 12K stars in one week! Thank you all for your support! PicoClaw is growing faster than we ever imagined. Given the high volume of PRs, we urgently need community maintainers. Our volunteer roles and roadmap are officially posted [here](docs/ROADMAP.md) —we can’t wait to have you on board!
|
||||
2026-02-16 🎉 PicoClaw hit 12K stars in one week! Thank you all for your support! PicoClaw is growing faster than we ever imagined. Given the high volume of PRs, we urgently need community maintainers. Our volunteer roles and roadmap are officially posted [here](ROADMAP.md) —we can’t wait to have you on board!
|
||||
|
||||
2026-02-13 🎉 PicoClaw hit 5000 stars in 4days! Thank you for the community! There are so many PRs & issues coming in (during Chinese New Year holidays), we are finalizing the Project Roadmap and setting up the Developer Group to accelerate PicoClaw's development.
|
||||
🚀 Call to Action: Please submit your feature requests in GitHub Discussions. We will review and prioritize them during our upcoming weekly meeting.
|
||||
|
||||
Binary file not shown.
|
Before Width: | Height: | Size: 140 KiB After Width: | Height: | Size: 96 KiB |
@@ -6,7 +6,9 @@
|
||||
"model_name": "gpt4",
|
||||
"max_tokens": 8192,
|
||||
"temperature": 0.7,
|
||||
"max_tool_iterations": 20
|
||||
"max_tool_iterations": 20,
|
||||
"summarize_message_threshold": 20,
|
||||
"summarize_token_percent": 75
|
||||
}
|
||||
},
|
||||
"model_list": [
|
||||
@@ -59,6 +61,7 @@
|
||||
"discord": {
|
||||
"enabled": false,
|
||||
"token": "YOUR_DISCORD_BOT_TOKEN",
|
||||
"proxy": "",
|
||||
"allow_from": [],
|
||||
"group_trigger": {
|
||||
"mention_only": false
|
||||
@@ -337,4 +340,4 @@
|
||||
"host": "127.0.0.1",
|
||||
"port": 18790
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -11,7 +11,6 @@ require (
|
||||
github.com/gdamore/tcell/v2 v2.13.8
|
||||
github.com/google/uuid v1.6.0
|
||||
github.com/gorilla/websocket v1.5.3
|
||||
github.com/joho/godotenv v1.5.1
|
||||
github.com/larksuite/oapi-sdk-go/v3 v3.5.3
|
||||
github.com/mdp/qrterminal/v3 v3.2.1
|
||||
github.com/modelcontextprotocol/go-sdk v1.3.0
|
||||
@@ -38,6 +37,7 @@ require (
|
||||
github.com/dustin/go-humanize v1.0.1 // indirect
|
||||
github.com/elliotchance/orderedmap/v3 v3.1.0 // indirect
|
||||
github.com/gdamore/encoding v1.0.1 // indirect
|
||||
github.com/gdamore/tcell/v2 v2.13.8 // indirect
|
||||
github.com/h2non/filetype v1.1.3 // indirect
|
||||
github.com/inconshreveable/mousetrap v1.1.0 // indirect
|
||||
github.com/lucasb-eyer/go-colorful v1.3.0 // indirect
|
||||
|
||||
@@ -105,8 +105,6 @@ github.com/hashicorp/golang-lru/v2 v2.0.7/go.mod h1:QeFd9opnmA6QUJc5vARoKUSoFhyf
|
||||
github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU=
|
||||
github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8=
|
||||
github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw=
|
||||
github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0=
|
||||
github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4=
|
||||
github.com/kisielk/errcheck v1.5.0/go.mod h1:pFxgyoBC7bSaBwPgfKdkLd5X25qrDl4LWUI2bnpBCr8=
|
||||
github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck=
|
||||
github.com/klauspost/compress v1.18.4 h1:RPhnKRAQ4Fh8zU2FY/6ZFDwTVTxgJ/EMydqSTzE9a2c=
|
||||
|
||||
+46
-32
@@ -18,22 +18,24 @@ import (
|
||||
// AgentInstance represents a fully configured agent with its own workspace,
|
||||
// session manager, context builder, and tool registry.
|
||||
type AgentInstance struct {
|
||||
ID string
|
||||
Name string
|
||||
Model string
|
||||
Fallbacks []string
|
||||
Workspace string
|
||||
MaxIterations int
|
||||
MaxTokens int
|
||||
Temperature float64
|
||||
ContextWindow int
|
||||
Provider providers.LLMProvider
|
||||
Sessions *session.SessionManager
|
||||
ContextBuilder *ContextBuilder
|
||||
Tools *tools.ToolRegistry
|
||||
Subagents *config.SubagentsConfig
|
||||
SkillsFilter []string
|
||||
Candidates []providers.FallbackCandidate
|
||||
ID string
|
||||
Name string
|
||||
Model string
|
||||
Fallbacks []string
|
||||
Workspace string
|
||||
MaxIterations int
|
||||
MaxTokens int
|
||||
Temperature float64
|
||||
ContextWindow int
|
||||
SummarizeMessageThreshold int
|
||||
SummarizeTokenPercent int
|
||||
Provider providers.LLMProvider
|
||||
Sessions *session.SessionManager
|
||||
ContextBuilder *ContextBuilder
|
||||
Tools *tools.ToolRegistry
|
||||
Subagents *config.SubagentsConfig
|
||||
SkillsFilter []string
|
||||
Candidates []providers.FallbackCandidate
|
||||
}
|
||||
|
||||
// NewAgentInstance creates an agent instance from config.
|
||||
@@ -101,6 +103,16 @@ func NewAgentInstance(
|
||||
temperature = *defaults.Temperature
|
||||
}
|
||||
|
||||
summarizeMessageThreshold := defaults.SummarizeMessageThreshold
|
||||
if summarizeMessageThreshold == 0 {
|
||||
summarizeMessageThreshold = 20
|
||||
}
|
||||
|
||||
summarizeTokenPercent := defaults.SummarizeTokenPercent
|
||||
if summarizeTokenPercent == 0 {
|
||||
summarizeTokenPercent = 75
|
||||
}
|
||||
|
||||
// Resolve fallback candidates
|
||||
modelCfg := providers.ModelConfig{
|
||||
Primary: model,
|
||||
@@ -149,22 +161,24 @@ func NewAgentInstance(
|
||||
candidates := providers.ResolveCandidatesWithLookup(modelCfg, defaults.Provider, resolveFromModelList)
|
||||
|
||||
return &AgentInstance{
|
||||
ID: agentID,
|
||||
Name: agentName,
|
||||
Model: model,
|
||||
Fallbacks: fallbacks,
|
||||
Workspace: workspace,
|
||||
MaxIterations: maxIter,
|
||||
MaxTokens: maxTokens,
|
||||
Temperature: temperature,
|
||||
ContextWindow: maxTokens,
|
||||
Provider: provider,
|
||||
Sessions: sessionsManager,
|
||||
ContextBuilder: contextBuilder,
|
||||
Tools: toolsRegistry,
|
||||
Subagents: subagents,
|
||||
SkillsFilter: skillsFilter,
|
||||
Candidates: candidates,
|
||||
ID: agentID,
|
||||
Name: agentName,
|
||||
Model: model,
|
||||
Fallbacks: fallbacks,
|
||||
Workspace: workspace,
|
||||
MaxIterations: maxIter,
|
||||
MaxTokens: maxTokens,
|
||||
Temperature: temperature,
|
||||
ContextWindow: maxTokens,
|
||||
SummarizeMessageThreshold: summarizeMessageThreshold,
|
||||
SummarizeTokenPercent: summarizeTokenPercent,
|
||||
Provider: provider,
|
||||
Sessions: sessionsManager,
|
||||
ContextBuilder: contextBuilder,
|
||||
Tools: toolsRegistry,
|
||||
Subagents: subagents,
|
||||
SkillsFilter: skillsFilter,
|
||||
Candidates: candidates,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
+66
-57
@@ -118,9 +118,11 @@ func registerSharedTools(
|
||||
PerplexityAPIKey: cfg.Tools.Web.Perplexity.APIKey,
|
||||
PerplexityMaxResults: cfg.Tools.Web.Perplexity.MaxResults,
|
||||
PerplexityEnabled: cfg.Tools.Web.Perplexity.Enabled,
|
||||
ExaAPIKey: cfg.Tools.Web.Exa.APIKey,
|
||||
ExaMaxResults: cfg.Tools.Web.Exa.MaxResults,
|
||||
ExaEnabled: cfg.Tools.Web.Exa.Enabled,
|
||||
GLMSearchAPIKey: cfg.Tools.Web.GLMSearch.APIKey,
|
||||
GLMSearchBaseURL: cfg.Tools.Web.GLMSearch.BaseURL,
|
||||
GLMSearchEngine: cfg.Tools.Web.GLMSearch.SearchEngine,
|
||||
GLMSearchMaxResults: cfg.Tools.Web.GLMSearch.MaxResults,
|
||||
GLMSearchEnabled: cfg.Tools.Web.GLMSearch.Enabled,
|
||||
Proxy: cfg.Tools.Web.Proxy,
|
||||
})
|
||||
if err != nil {
|
||||
@@ -967,62 +969,76 @@ func (al *AgentLoop) runLLMIteration(
|
||||
// Save assistant message with tool calls to session
|
||||
agent.Sessions.AddFullMessage(opts.SessionKey, assistantMsg)
|
||||
|
||||
// Execute tool calls
|
||||
for _, tc := range normalizedToolCalls {
|
||||
argsJSON, _ := json.Marshal(tc.Arguments)
|
||||
argsPreview := utils.Truncate(string(argsJSON), 200)
|
||||
logger.InfoCF("agent", fmt.Sprintf("Tool call: %s(%s)", tc.Name, argsPreview),
|
||||
map[string]any{
|
||||
"agent_id": agent.ID,
|
||||
"tool": tc.Name,
|
||||
"iteration": iteration,
|
||||
})
|
||||
// Execute tool calls in parallel
|
||||
type indexedAgentResult struct {
|
||||
result *tools.ToolResult
|
||||
tc providers.ToolCall
|
||||
}
|
||||
|
||||
// Create async callback for tools that implement AsyncTool
|
||||
// NOTE: Following openclaw's design, async tools do NOT send results directly to users.
|
||||
// Instead, they notify the agent via PublishInbound, and the agent decides
|
||||
// whether to forward the result to the user (in processSystemMessage).
|
||||
asyncCallback := func(callbackCtx context.Context, result *tools.ToolResult) {
|
||||
// Log the async completion but don't send directly to user
|
||||
// The agent will handle user notification via processSystemMessage
|
||||
if !result.Silent && result.ForUser != "" {
|
||||
logger.InfoCF("agent", "Async tool completed, agent will handle notification",
|
||||
map[string]any{
|
||||
"tool": tc.Name,
|
||||
"content_len": len(result.ForUser),
|
||||
})
|
||||
agentResults := make([]indexedAgentResult, len(normalizedToolCalls))
|
||||
var wg sync.WaitGroup
|
||||
|
||||
for i, tc := range normalizedToolCalls {
|
||||
agentResults[i].tc = tc
|
||||
|
||||
wg.Add(1)
|
||||
go func(idx int, tc providers.ToolCall) {
|
||||
defer wg.Done()
|
||||
|
||||
argsJSON, _ := json.Marshal(tc.Arguments)
|
||||
argsPreview := utils.Truncate(string(argsJSON), 200)
|
||||
logger.InfoCF("agent", fmt.Sprintf("Tool call: %s(%s)", tc.Name, argsPreview),
|
||||
map[string]any{
|
||||
"agent_id": agent.ID,
|
||||
"tool": tc.Name,
|
||||
"iteration": iteration,
|
||||
})
|
||||
|
||||
// Create async callback for tools that implement AsyncTool
|
||||
asyncCallback := func(callbackCtx context.Context, result *tools.ToolResult) {
|
||||
if !result.Silent && result.ForUser != "" {
|
||||
logger.InfoCF("agent", "Async tool completed, agent will handle notification",
|
||||
map[string]any{
|
||||
"tool": tc.Name,
|
||||
"content_len": len(result.ForUser),
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
toolResult := agent.Tools.ExecuteWithContext(
|
||||
ctx,
|
||||
tc.Name,
|
||||
tc.Arguments,
|
||||
opts.Channel,
|
||||
opts.ChatID,
|
||||
asyncCallback,
|
||||
)
|
||||
toolResult := agent.Tools.ExecuteWithContext(
|
||||
ctx,
|
||||
tc.Name,
|
||||
tc.Arguments,
|
||||
opts.Channel,
|
||||
opts.ChatID,
|
||||
asyncCallback,
|
||||
)
|
||||
agentResults[idx].result = toolResult
|
||||
}(i, tc)
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
// Process results in original order (send to user, save to session)
|
||||
for _, r := range agentResults {
|
||||
// Send ForUser content to user immediately if not Silent
|
||||
if !toolResult.Silent && toolResult.ForUser != "" && opts.SendResponse {
|
||||
if !r.result.Silent && r.result.ForUser != "" && opts.SendResponse {
|
||||
al.bus.PublishOutbound(ctx, bus.OutboundMessage{
|
||||
Channel: opts.Channel,
|
||||
ChatID: opts.ChatID,
|
||||
Content: toolResult.ForUser,
|
||||
Content: r.result.ForUser,
|
||||
})
|
||||
logger.DebugCF("agent", "Sent tool result to user",
|
||||
map[string]any{
|
||||
"tool": tc.Name,
|
||||
"content_len": len(toolResult.ForUser),
|
||||
"tool": r.tc.Name,
|
||||
"content_len": len(r.result.ForUser),
|
||||
})
|
||||
}
|
||||
|
||||
// If tool returned media refs, publish them as outbound media
|
||||
if len(toolResult.Media) > 0 && opts.SendResponse {
|
||||
parts := make([]bus.MediaPart, 0, len(toolResult.Media))
|
||||
for _, ref := range toolResult.Media {
|
||||
if len(r.result.Media) > 0 && opts.SendResponse {
|
||||
parts := make([]bus.MediaPart, 0, len(r.result.Media))
|
||||
for _, ref := range r.result.Media {
|
||||
part := bus.MediaPart{Ref: ref}
|
||||
// Populate metadata from MediaStore when available
|
||||
if al.mediaStore != nil {
|
||||
if _, meta, err := al.mediaStore.ResolveWithMeta(ref); err == nil {
|
||||
part.Filename = meta.Filename
|
||||
@@ -1040,15 +1056,15 @@ func (al *AgentLoop) runLLMIteration(
|
||||
}
|
||||
|
||||
// Determine content for LLM based on tool result
|
||||
contentForLLM := toolResult.ForLLM
|
||||
if contentForLLM == "" && toolResult.Err != nil {
|
||||
contentForLLM = toolResult.Err.Error()
|
||||
contentForLLM := r.result.ForLLM
|
||||
if contentForLLM == "" && r.result.Err != nil {
|
||||
contentForLLM = r.result.Err.Error()
|
||||
}
|
||||
|
||||
toolResultMsg := providers.Message{
|
||||
Role: "tool",
|
||||
Content: contentForLLM,
|
||||
ToolCallID: tc.ID,
|
||||
ToolCallID: r.tc.ID,
|
||||
}
|
||||
messages = append(messages, toolResultMsg)
|
||||
|
||||
@@ -1084,9 +1100,9 @@ func (al *AgentLoop) updateToolContexts(agent *AgentInstance, channel, chatID st
|
||||
func (al *AgentLoop) maybeSummarize(agent *AgentInstance, sessionKey, channel, chatID string) {
|
||||
newHistory := agent.Sessions.GetHistory(sessionKey)
|
||||
tokenEstimate := al.estimateTokens(newHistory)
|
||||
threshold := agent.ContextWindow * 75 / 100
|
||||
threshold := agent.ContextWindow * agent.SummarizeTokenPercent / 100
|
||||
|
||||
if len(newHistory) > 20 || tokenEstimate > threshold {
|
||||
if len(newHistory) > agent.SummarizeMessageThreshold || tokenEstimate > threshold {
|
||||
summarizeKey := agent.ID + ":" + sessionKey
|
||||
if _, loading := al.summarizing.LoadOrStore(summarizeKey, true); !loading {
|
||||
go func() {
|
||||
@@ -1114,15 +1130,8 @@ func (al *AgentLoop) forceCompression(agent *AgentInstance, sessionKey string) {
|
||||
return
|
||||
}
|
||||
|
||||
// Find the mid-point of the conversation, avoiding splitting tool call/result pairs.
|
||||
// A tool-call message (role=assistant with ToolCalls) must be followed by its
|
||||
// tool-result message (role=tool). Splitting between them causes API errors.
|
||||
// Helper to find the mid-point of the conversation
|
||||
mid := len(conversation) / 2
|
||||
if mid < len(conversation) && mid > 0 {
|
||||
if conversation[mid].Role == "tool" {
|
||||
mid++ // move past the tool result to keep the pair together
|
||||
}
|
||||
}
|
||||
|
||||
// New history structure:
|
||||
// 1. System Prompt (with compression note appended)
|
||||
|
||||
@@ -603,85 +603,6 @@ func TestAgentLoop_ContextExhaustionRetry(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// TestForceCompression_ToolMessageBoundary verifies that forceCompression does not
|
||||
// split a tool call/result pair when the midpoint falls on a "tool" role message.
|
||||
// Regression test for: API errors when orphaned tool result messages appear
|
||||
// without their preceding assistant tool-call message.
|
||||
func TestForceCompression_ToolMessageBoundary(t *testing.T) {
|
||||
tmpDir, err := os.MkdirTemp("", "agent-test-*")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temp dir: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
cfg := &config.Config{
|
||||
Agents: config.AgentsConfig{
|
||||
Defaults: config.AgentDefaults{
|
||||
Workspace: tmpDir,
|
||||
Model: "test-model",
|
||||
MaxTokens: 4096,
|
||||
MaxToolIterations: 10,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
msgBus := bus.NewMessageBus()
|
||||
provider := &mockProvider{}
|
||||
al := NewAgentLoop(cfg, msgBus, provider)
|
||||
|
||||
sessionKey := "test-session-tool-boundary"
|
||||
defaultAgent := al.registry.GetDefaultAgent()
|
||||
if defaultAgent == nil {
|
||||
t.Fatal("No default agent found")
|
||||
}
|
||||
|
||||
// Construct a history where len(conversation)/2 falls exactly on a "tool" message.
|
||||
// history = [system, user, assistant(tool_call), tool, user, assistant, user_trigger]
|
||||
// conversation = history[1:6] = [user, assistant(tool_call), tool, user, assistant]
|
||||
// len(conversation) = 5, mid = 5/2 = 2 => conversation[2].Role == "tool"
|
||||
// Without the fix, this would split between assistant(tool_call) and tool result.
|
||||
history := []providers.Message{
|
||||
{Role: "system", Content: "You are a helpful assistant."},
|
||||
{Role: "user", Content: "What files are in the current directory?"},
|
||||
{Role: "assistant", Content: "", ToolCalls: []providers.ToolCall{
|
||||
{ID: "call_1", Name: "exec", Arguments: map[string]any{"command": "ls"}},
|
||||
}},
|
||||
{Role: "tool", Content: "file1.txt\nfile2.txt", ToolCallID: "call_1"},
|
||||
{Role: "user", Content: "Tell me about file1.txt"},
|
||||
{Role: "assistant", Content: "file1.txt is a text file."},
|
||||
{Role: "user", Content: "Thanks"}, // trigger message
|
||||
}
|
||||
|
||||
// Create the session first (AddMessage creates the session entry),
|
||||
// then overwrite with our full history via SetHistory.
|
||||
defaultAgent.Sessions.AddMessage(sessionKey, "system", "init")
|
||||
defaultAgent.Sessions.SetHistory(sessionKey, history)
|
||||
|
||||
// Call forceCompression
|
||||
al.forceCompression(defaultAgent, sessionKey)
|
||||
|
||||
// Verify the result
|
||||
compressed := defaultAgent.Sessions.GetHistory(sessionKey)
|
||||
|
||||
// Check that no message with role="tool" is the first conversation message
|
||||
// (after the system prompt). If it is, it means the tool result was orphaned.
|
||||
for i := 1; i < len(compressed); i++ {
|
||||
if compressed[i].Role == "tool" {
|
||||
// There must be an assistant message with tool calls before it
|
||||
if i == 1 {
|
||||
t.Errorf("Tool result message at position %d is orphaned (no preceding assistant with tool call)", i)
|
||||
} else if compressed[i-1].Role != "assistant" || len(compressed[i-1].ToolCalls) == 0 {
|
||||
t.Errorf("Tool result at position %d is not preceded by assistant with tool calls (preceded by role=%q)", i, compressed[i-1].Role)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Verify the system prompt has the compression note
|
||||
if !strings.Contains(compressed[0].Content, "Emergency compression") {
|
||||
t.Errorf("Expected compression note in system prompt, got: %s", compressed[0].Content)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTargetReasoningChannelID_AllChannels(t *testing.T) {
|
||||
tmpDir, err := os.MkdirTemp("", "agent-test-*")
|
||||
if err != nil {
|
||||
|
||||
@@ -3,12 +3,15 @@ package discord
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/bwmarrin/discordgo"
|
||||
"github.com/gorilla/websocket"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/bus"
|
||||
"github.com/sipeed/picoclaw/pkg/channels"
|
||||
@@ -40,6 +43,9 @@ func NewDiscordChannel(cfg config.DiscordConfig, bus *bus.MessageBus) (*DiscordC
|
||||
return nil, fmt.Errorf("failed to create discord session: %w", err)
|
||||
}
|
||||
|
||||
if err := applyDiscordProxy(session, cfg.Proxy); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
base := channels.NewBaseChannel("discord", cfg, bus, cfg.AllowFrom,
|
||||
channels.WithMaxMessageLength(2000),
|
||||
channels.WithGroupTrigger(cfg.GroupTrigger),
|
||||
@@ -465,9 +471,43 @@ func (c *DiscordChannel) StartTyping(ctx context.Context, chatID string) (func()
|
||||
func (c *DiscordChannel) downloadAttachment(url, filename string) string {
|
||||
return utils.DownloadFile(url, filename, utils.DownloadOptions{
|
||||
LoggerPrefix: "discord",
|
||||
ProxyURL: c.config.Proxy,
|
||||
})
|
||||
}
|
||||
|
||||
func applyDiscordProxy(session *discordgo.Session, proxyAddr string) error {
|
||||
var proxyFunc func(*http.Request) (*url.URL, error)
|
||||
if proxyAddr != "" {
|
||||
proxyURL, err := url.Parse(proxyAddr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid discord proxy URL %q: %w", proxyAddr, err)
|
||||
}
|
||||
proxyFunc = http.ProxyURL(proxyURL)
|
||||
} else if os.Getenv("HTTP_PROXY") != "" || os.Getenv("HTTPS_PROXY") != "" {
|
||||
proxyFunc = http.ProxyFromEnvironment
|
||||
}
|
||||
|
||||
if proxyFunc == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
transport := &http.Transport{Proxy: proxyFunc}
|
||||
session.Client = &http.Client{
|
||||
Timeout: sendTimeout,
|
||||
Transport: transport,
|
||||
}
|
||||
|
||||
if session.Dialer != nil {
|
||||
dialerCopy := *session.Dialer
|
||||
dialerCopy.Proxy = proxyFunc
|
||||
session.Dialer = &dialerCopy
|
||||
} else {
|
||||
session.Dialer = &websocket.Dialer{Proxy: proxyFunc}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// stripBotMention removes the bot mention from the message content.
|
||||
// Discord mentions have the format <@USER_ID> or <@!USER_ID> (with nickname).
|
||||
func (c *DiscordChannel) stripBotMention(text string) string {
|
||||
|
||||
@@ -0,0 +1,91 @@
|
||||
package discord
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/url"
|
||||
"testing"
|
||||
|
||||
"github.com/bwmarrin/discordgo"
|
||||
)
|
||||
|
||||
func TestApplyDiscordProxy_CustomProxy(t *testing.T) {
|
||||
session, err := discordgo.New("Bot test-token")
|
||||
if err != nil {
|
||||
t.Fatalf("discordgo.New() error: %v", err)
|
||||
}
|
||||
|
||||
if err = applyDiscordProxy(session, "http://127.0.0.1:7890"); err != nil {
|
||||
t.Fatalf("applyDiscordProxy() error: %v", err)
|
||||
}
|
||||
|
||||
req, err := http.NewRequest("GET", "https://discord.com/api/v10/gateway", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("http.NewRequest() error: %v", err)
|
||||
}
|
||||
|
||||
restProxy := session.Client.Transport.(*http.Transport).Proxy
|
||||
restProxyURL, err := restProxy(req)
|
||||
if err != nil {
|
||||
t.Fatalf("rest proxy func error: %v", err)
|
||||
}
|
||||
if got, want := restProxyURL.String(), "http://127.0.0.1:7890"; got != want {
|
||||
t.Fatalf("REST proxy = %q, want %q", got, want)
|
||||
}
|
||||
|
||||
wsProxyURL, err := session.Dialer.Proxy(req)
|
||||
if err != nil {
|
||||
t.Fatalf("ws proxy func error: %v", err)
|
||||
}
|
||||
if got, want := wsProxyURL.String(), "http://127.0.0.1:7890"; got != want {
|
||||
t.Fatalf("WS proxy = %q, want %q", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyDiscordProxy_FromEnvironment(t *testing.T) {
|
||||
t.Setenv("HTTP_PROXY", "http://127.0.0.1:8888")
|
||||
t.Setenv("http_proxy", "http://127.0.0.1:8888")
|
||||
t.Setenv("HTTPS_PROXY", "http://127.0.0.1:8888")
|
||||
t.Setenv("https_proxy", "http://127.0.0.1:8888")
|
||||
t.Setenv("ALL_PROXY", "")
|
||||
t.Setenv("all_proxy", "")
|
||||
t.Setenv("NO_PROXY", "")
|
||||
t.Setenv("no_proxy", "")
|
||||
|
||||
session, err := discordgo.New("Bot test-token")
|
||||
if err != nil {
|
||||
t.Fatalf("discordgo.New() error: %v", err)
|
||||
}
|
||||
|
||||
if err = applyDiscordProxy(session, ""); err != nil {
|
||||
t.Fatalf("applyDiscordProxy() error: %v", err)
|
||||
}
|
||||
|
||||
req, err := http.NewRequest("GET", "https://discord.com/api/v10/gateway", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("http.NewRequest() error: %v", err)
|
||||
}
|
||||
|
||||
gotURL, err := session.Dialer.Proxy(req)
|
||||
if err != nil {
|
||||
t.Fatalf("ws proxy func error: %v", err)
|
||||
}
|
||||
|
||||
wantURL, err := url.Parse("http://127.0.0.1:8888")
|
||||
if err != nil {
|
||||
t.Fatalf("url.Parse() error: %v", err)
|
||||
}
|
||||
if gotURL.String() != wantURL.String() {
|
||||
t.Fatalf("WS proxy = %q, want %q", gotURL.String(), wantURL.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyDiscordProxy_InvalidProxyURL(t *testing.T) {
|
||||
session, err := discordgo.New("Bot test-token")
|
||||
if err != nil {
|
||||
t.Fatalf("discordgo.New() error: %v", err)
|
||||
}
|
||||
|
||||
if err = applyDiscordProxy(session, "://bad-proxy"); err == nil {
|
||||
t.Fatal("applyDiscordProxy() expected error for invalid proxy URL, got nil")
|
||||
}
|
||||
}
|
||||
+13
-85
@@ -3,23 +3,14 @@ package config
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/caarlos0/env/v11"
|
||||
"github.com/joho/godotenv"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/fileutil"
|
||||
)
|
||||
|
||||
// dotenvOnce ensures .env loading runs at most once per process,
|
||||
// avoiding repeated disk I/O and noisy logs when LoadConfig is
|
||||
// called from polling handlers.
|
||||
var dotenvOnce sync.Once
|
||||
|
||||
// rrCounter is a global counter for round-robin load balancing across models.
|
||||
var rrCounter atomic.Uint64
|
||||
|
||||
@@ -189,6 +180,8 @@ type AgentDefaults struct {
|
||||
MaxTokens int `json:"max_tokens" env:"PICOCLAW_AGENTS_DEFAULTS_MAX_TOKENS"`
|
||||
Temperature *float64 `json:"temperature,omitempty" env:"PICOCLAW_AGENTS_DEFAULTS_TEMPERATURE"`
|
||||
MaxToolIterations int `json:"max_tool_iterations" env:"PICOCLAW_AGENTS_DEFAULTS_MAX_TOOL_ITERATIONS"`
|
||||
SummarizeMessageThreshold int `json:"summarize_message_threshold" env:"PICOCLAW_AGENTS_DEFAULTS_SUMMARIZE_MESSAGE_THRESHOLD"`
|
||||
SummarizeTokenPercent int `json:"summarize_token_percent" env:"PICOCLAW_AGENTS_DEFAULTS_SUMMARIZE_TOKEN_PERCENT"`
|
||||
MaxMediaSize int `json:"max_media_size,omitempty" env:"PICOCLAW_AGENTS_DEFAULTS_MAX_MEDIA_SIZE"`
|
||||
}
|
||||
|
||||
@@ -280,6 +273,7 @@ type FeishuConfig struct {
|
||||
type DiscordConfig struct {
|
||||
Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_DISCORD_ENABLED"`
|
||||
Token string `json:"token" env:"PICOCLAW_CHANNELS_DISCORD_TOKEN"`
|
||||
Proxy string `json:"proxy" env:"PICOCLAW_CHANNELS_DISCORD_PROXY"`
|
||||
AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_DISCORD_ALLOW_FROM"`
|
||||
MentionOnly bool `json:"mention_only" env:"PICOCLAW_CHANNELS_DISCORD_MENTION_ONLY"`
|
||||
GroupTrigger GroupTriggerConfig `json:"group_trigger,omitempty"`
|
||||
@@ -437,7 +431,6 @@ type ProvidersConfig struct {
|
||||
Antigravity ProviderConfig `json:"antigravity"`
|
||||
Qwen ProviderConfig `json:"qwen"`
|
||||
Mistral ProviderConfig `json:"mistral"`
|
||||
Opencode ProviderConfig `json:"opencode"`
|
||||
}
|
||||
|
||||
// IsEmpty checks if all provider configs are empty (no API keys or API bases set)
|
||||
@@ -461,8 +454,7 @@ func (p ProvidersConfig) IsEmpty() bool {
|
||||
p.GitHubCopilot.APIKey == "" && p.GitHubCopilot.APIBase == "" &&
|
||||
p.Antigravity.APIKey == "" && p.Antigravity.APIBase == "" &&
|
||||
p.Qwen.APIKey == "" && p.Qwen.APIBase == "" &&
|
||||
p.Mistral.APIKey == "" && p.Mistral.APIBase == "" &&
|
||||
p.Opencode.APIKey == "" && p.Opencode.APIBase == ""
|
||||
p.Mistral.APIKey == "" && p.Mistral.APIBase == ""
|
||||
}
|
||||
|
||||
// MarshalJSON implements custom JSON marshaling for ProvidersConfig
|
||||
@@ -555,10 +547,14 @@ type PerplexityConfig struct {
|
||||
MaxResults int `json:"max_results" env:"PICOCLAW_TOOLS_WEB_PERPLEXITY_MAX_RESULTS"`
|
||||
}
|
||||
|
||||
type ExaConfig struct {
|
||||
Enabled bool `json:"enabled" env:"PICOCLAW_TOOLS_WEB_EXA_ENABLED"`
|
||||
APIKey string `json:"api_key" env:"PICOCLAW_TOOLS_WEB_EXA_API_KEY"`
|
||||
MaxResults int `json:"max_results" env:"PICOCLAW_TOOLS_WEB_EXA_MAX_RESULTS"`
|
||||
type GLMSearchConfig struct {
|
||||
Enabled bool `json:"enabled" env:"PICOCLAW_TOOLS_WEB_GLM_ENABLED"`
|
||||
APIKey string `json:"api_key" env:"PICOCLAW_TOOLS_WEB_GLM_API_KEY"`
|
||||
BaseURL string `json:"base_url" env:"PICOCLAW_TOOLS_WEB_GLM_BASE_URL"`
|
||||
// SearchEngine specifies the search backend: "search_std" (default),
|
||||
// "search_pro", "search_pro_sogou", or "search_pro_quark".
|
||||
SearchEngine string `json:"search_engine" env:"PICOCLAW_TOOLS_WEB_GLM_SEARCH_ENGINE"`
|
||||
MaxResults int `json:"max_results" env:"PICOCLAW_TOOLS_WEB_GLM_MAX_RESULTS"`
|
||||
}
|
||||
|
||||
type WebToolsConfig struct {
|
||||
@@ -566,7 +562,7 @@ type WebToolsConfig struct {
|
||||
Tavily TavilyConfig `json:"tavily"`
|
||||
DuckDuckGo DuckDuckGoConfig `json:"duckduckgo"`
|
||||
Perplexity PerplexityConfig `json:"perplexity"`
|
||||
Exa ExaConfig `json:"exa"`
|
||||
GLMSearch GLMSearchConfig `json:"glm_search"`
|
||||
// Proxy is an optional proxy URL for web tools (http/https/socks5/socks5h).
|
||||
// For authenticated proxies, prefer HTTP_PROXY/HTTPS_PROXY env vars instead of embedding credentials in config.
|
||||
Proxy string `json:"proxy,omitempty" env:"PICOCLAW_TOOLS_WEB_PROXY"`
|
||||
@@ -658,35 +654,9 @@ type MCPConfig struct {
|
||||
func LoadConfig(path string) (*Config, error) {
|
||||
cfg := DefaultConfig()
|
||||
|
||||
// Load .env file from config directory (secrets, API keys, etc.)
|
||||
// Guarded by sync.Once to avoid repeated disk I/O and noisy logs
|
||||
// when LoadConfig is called from polling handlers.
|
||||
dotenvOnce.Do(func() {
|
||||
envFile := filepath.Join(filepath.Dir(path), ".env")
|
||||
if err := godotenv.Load(envFile); err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
log.Printf("[INFO] No .env file found at %s; skipping .env loading", envFile)
|
||||
} else {
|
||||
log.Printf("[WARN] Failed to load .env file from %s: %v", envFile, err)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
// No config file — still apply env vars + overrides to default config
|
||||
if err := env.Parse(cfg); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
loadProviderEnvOverrides(cfg)
|
||||
cfg.migrateChannelConfigs()
|
||||
if cfg.HasProvidersConfig() {
|
||||
cfg.ModelList = ConvertProvidersToModelList(cfg)
|
||||
}
|
||||
if err := cfg.ValidateModelList(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return cfg, nil
|
||||
}
|
||||
return nil, err
|
||||
@@ -714,9 +684,6 @@ func LoadConfig(path string) (*Config, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Load provider-specific env overrides (PICOCLAW_PROVIDERS_<NAME>_API_KEY, etc.)
|
||||
loadProviderEnvOverrides(cfg)
|
||||
|
||||
// Migrate legacy channel config fields to new unified structures
|
||||
cfg.migrateChannelConfigs()
|
||||
|
||||
@@ -865,42 +832,3 @@ func (c *Config) ValidateModelList() error {
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// loadProviderEnvOverrides reads PICOCLAW_PROVIDERS_<NAME>_API_KEY and _API_BASE
|
||||
// environment variables and sets them on the corresponding provider config fields.
|
||||
// This enables storing provider secrets in .env files without using struct tags.
|
||||
func loadProviderEnvOverrides(cfg *Config) {
|
||||
providers := []struct {
|
||||
name string
|
||||
apiKey *string
|
||||
base *string
|
||||
}{
|
||||
{"ANTHROPIC", &cfg.Providers.Anthropic.APIKey, &cfg.Providers.Anthropic.APIBase},
|
||||
{"OPENAI", &cfg.Providers.OpenAI.APIKey, &cfg.Providers.OpenAI.APIBase},
|
||||
{"LITELLM", &cfg.Providers.LiteLLM.APIKey, &cfg.Providers.LiteLLM.APIBase},
|
||||
{"OPENROUTER", &cfg.Providers.OpenRouter.APIKey, &cfg.Providers.OpenRouter.APIBase},
|
||||
{"GROQ", &cfg.Providers.Groq.APIKey, &cfg.Providers.Groq.APIBase},
|
||||
{"ZHIPU", &cfg.Providers.Zhipu.APIKey, &cfg.Providers.Zhipu.APIBase},
|
||||
{"GEMINI", &cfg.Providers.Gemini.APIKey, &cfg.Providers.Gemini.APIBase},
|
||||
{"NVIDIA", &cfg.Providers.Nvidia.APIKey, &cfg.Providers.Nvidia.APIBase},
|
||||
{"OLLAMA", &cfg.Providers.Ollama.APIKey, &cfg.Providers.Ollama.APIBase},
|
||||
{"MOONSHOT", &cfg.Providers.Moonshot.APIKey, &cfg.Providers.Moonshot.APIBase},
|
||||
{"SHENGSUANYUN", &cfg.Providers.ShengSuanYun.APIKey, &cfg.Providers.ShengSuanYun.APIBase},
|
||||
{"DEEPSEEK", &cfg.Providers.DeepSeek.APIKey, &cfg.Providers.DeepSeek.APIBase},
|
||||
{"MISTRAL", &cfg.Providers.Mistral.APIKey, &cfg.Providers.Mistral.APIBase},
|
||||
{"VLLM", &cfg.Providers.VLLM.APIKey, &cfg.Providers.VLLM.APIBase},
|
||||
{"CEREBRAS", &cfg.Providers.Cerebras.APIKey, &cfg.Providers.Cerebras.APIBase},
|
||||
{"VOLCENGINE", &cfg.Providers.VolcEngine.APIKey, &cfg.Providers.VolcEngine.APIBase},
|
||||
{"QWEN", &cfg.Providers.Qwen.APIKey, &cfg.Providers.Qwen.APIBase},
|
||||
// Note: GitHubCopilot and Antigravity use different auth patterns (ConnectMode/AuthMethod),
|
||||
// not standard APIKey/APIBase, so they are not included here.
|
||||
}
|
||||
for _, p := range providers {
|
||||
if v, ok := os.LookupEnv("PICOCLAW_PROVIDERS_" + p.name + "_API_KEY"); ok {
|
||||
*p.apiKey = v
|
||||
}
|
||||
if v, ok := os.LookupEnv("PICOCLAW_PROVIDERS_" + p.name + "_API_BASE"); ok {
|
||||
*p.base = v
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
+12
-96
@@ -6,7 +6,6 @@ import (
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
)
|
||||
|
||||
@@ -436,6 +435,18 @@ func TestLoadConfig_WebToolsProxy(t *testing.T) {
|
||||
}
|
||||
|
||||
// TestDefaultConfig_DMScope verifies the default dm_scope value
|
||||
// TestDefaultConfig_SummarizationThresholds verifies summarization defaults
|
||||
func TestDefaultConfig_SummarizationThresholds(t *testing.T) {
|
||||
cfg := DefaultConfig()
|
||||
|
||||
if cfg.Agents.Defaults.SummarizeMessageThreshold != 20 {
|
||||
t.Errorf("SummarizeMessageThreshold = %d, want 20", cfg.Agents.Defaults.SummarizeMessageThreshold)
|
||||
}
|
||||
if cfg.Agents.Defaults.SummarizeTokenPercent != 75 {
|
||||
t.Errorf("SummarizeTokenPercent = %d, want 75", cfg.Agents.Defaults.SummarizeTokenPercent)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDefaultConfig_DMScope(t *testing.T) {
|
||||
cfg := DefaultConfig()
|
||||
|
||||
@@ -468,98 +479,3 @@ func TestDefaultConfig_WorkspacePath_WithPicoclawHome(t *testing.T) {
|
||||
t.Errorf("Workspace path with PICOCLAW_HOME = %q, want %q", cfg.Agents.Defaults.Workspace, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadConfig_DotenvFileLoaded(t *testing.T) {
|
||||
// Reset sync.Once so .env loading runs for this test
|
||||
dotenvOnce = sync.Once{}
|
||||
|
||||
dir := t.TempDir()
|
||||
configPath := filepath.Join(dir, "config.json")
|
||||
|
||||
// Write a minimal config.json
|
||||
if err := os.WriteFile(configPath, []byte(`{}`), 0o600); err != nil {
|
||||
t.Fatalf("WriteFile config: %v", err)
|
||||
}
|
||||
|
||||
// Write a .env file with a provider API key
|
||||
envFile := filepath.Join(dir, ".env")
|
||||
if err := os.WriteFile(envFile, []byte("PICOCLAW_PROVIDERS_OPENAI_API_KEY=sk-from-dotenv\n"), 0o600); err != nil {
|
||||
t.Fatalf("WriteFile .env: %v", err)
|
||||
}
|
||||
|
||||
// Clear the env var first to ensure it comes from .env
|
||||
t.Setenv("PICOCLAW_PROVIDERS_OPENAI_API_KEY", "")
|
||||
os.Unsetenv("PICOCLAW_PROVIDERS_OPENAI_API_KEY")
|
||||
|
||||
cfg, err := LoadConfig(configPath)
|
||||
if err != nil {
|
||||
t.Fatalf("LoadConfig() error: %v", err)
|
||||
}
|
||||
|
||||
if cfg.Providers.OpenAI.APIKey != "sk-from-dotenv" {
|
||||
t.Errorf("OpenAI.APIKey = %q, want %q", cfg.Providers.OpenAI.APIKey, "sk-from-dotenv")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadConfig_MissingConfigJSON_AppliesEnvVars(t *testing.T) {
|
||||
// Reset sync.Once so .env loading runs for this test
|
||||
dotenvOnce = sync.Once{}
|
||||
|
||||
dir := t.TempDir()
|
||||
configPath := filepath.Join(dir, "config.json") // does NOT exist
|
||||
|
||||
t.Setenv("PICOCLAW_PROVIDERS_ANTHROPIC_API_KEY", "sk-anthropic-test")
|
||||
|
||||
cfg, err := LoadConfig(configPath)
|
||||
if err != nil {
|
||||
t.Fatalf("LoadConfig() error: %v", err)
|
||||
}
|
||||
|
||||
if cfg.Providers.Anthropic.APIKey != "sk-anthropic-test" {
|
||||
t.Errorf("Anthropic.APIKey = %q, want %q", cfg.Providers.Anthropic.APIKey, "sk-anthropic-test")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadConfig_MalformedDotenv_NonFatal(t *testing.T) {
|
||||
// Reset sync.Once so .env loading runs for this test
|
||||
dotenvOnce = sync.Once{}
|
||||
|
||||
dir := t.TempDir()
|
||||
configPath := filepath.Join(dir, "config.json")
|
||||
|
||||
// Write a minimal config.json
|
||||
if err := os.WriteFile(configPath, []byte(`{}`), 0o600); err != nil {
|
||||
t.Fatalf("WriteFile config: %v", err)
|
||||
}
|
||||
|
||||
// Write a .env file with genuinely malformed content (bare key without '=',
|
||||
// mixed with a valid line) to verify godotenv.Load errors are non-fatal.
|
||||
envFile := filepath.Join(dir, ".env")
|
||||
if err := os.WriteFile(envFile, []byte("THIS_LINE_HAS_NO_EQUALS\nVALID_KEY=valid_value\n"), 0o600); err != nil {
|
||||
t.Fatalf("WriteFile .env: %v", err)
|
||||
}
|
||||
|
||||
// LoadConfig should not fail even with malformed .env content
|
||||
cfg, err := LoadConfig(configPath)
|
||||
if err != nil {
|
||||
t.Fatalf("LoadConfig() should not fail with .env issues, got error: %v", err)
|
||||
}
|
||||
if cfg == nil {
|
||||
t.Fatal("LoadConfig() returned nil config")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadProviderEnvOverrides_LookupEnv(t *testing.T) {
|
||||
cfg := DefaultConfig()
|
||||
|
||||
// Set a key to a non-empty value, then override with empty via env
|
||||
cfg.Providers.OpenRouter.APIBase = "https://original.com"
|
||||
t.Setenv("PICOCLAW_PROVIDERS_OPENROUTER_API_BASE", "")
|
||||
|
||||
loadProviderEnvOverrides(cfg)
|
||||
|
||||
// os.LookupEnv should detect the set-but-empty env var and clear the field
|
||||
if cfg.Providers.OpenRouter.APIBase != "" {
|
||||
t.Errorf("OpenRouter.APIBase = %q, want empty (overridden by empty env var)", cfg.Providers.OpenRouter.APIBase)
|
||||
}
|
||||
}
|
||||
|
||||
+15
-11
@@ -26,13 +26,15 @@ func DefaultConfig() *Config {
|
||||
return &Config{
|
||||
Agents: AgentsConfig{
|
||||
Defaults: AgentDefaults{
|
||||
Workspace: workspacePath,
|
||||
RestrictToWorkspace: true,
|
||||
Provider: "",
|
||||
Model: "",
|
||||
MaxTokens: 32768,
|
||||
Temperature: nil, // nil means use provider default
|
||||
MaxToolIterations: 50,
|
||||
Workspace: workspacePath,
|
||||
RestrictToWorkspace: true,
|
||||
Provider: "",
|
||||
Model: "",
|
||||
MaxTokens: 32768,
|
||||
Temperature: nil, // nil means use provider default
|
||||
MaxToolIterations: 50,
|
||||
SummarizeMessageThreshold: 20,
|
||||
SummarizeTokenPercent: 75,
|
||||
},
|
||||
},
|
||||
Bindings: []AgentBinding{},
|
||||
@@ -341,10 +343,12 @@ func DefaultConfig() *Config {
|
||||
APIKey: "",
|
||||
MaxResults: 5,
|
||||
},
|
||||
Exa: ExaConfig{
|
||||
Enabled: false,
|
||||
APIKey: "",
|
||||
MaxResults: 5,
|
||||
GLMSearch: GLMSearchConfig{
|
||||
Enabled: false,
|
||||
APIKey: "",
|
||||
BaseURL: "https://open.bigmodel.cn/api/paas/v4/web_search",
|
||||
SearchEngine: "search_std",
|
||||
MaxResults: 5,
|
||||
},
|
||||
},
|
||||
Cron: CronToolsConfig{
|
||||
|
||||
+2
-21
@@ -225,7 +225,7 @@ func ConvertProvidersToModelList(cfg *Config) []ModelConfig {
|
||||
},
|
||||
},
|
||||
{
|
||||
providerNames: []string{"moonshot", "kimi", "kimi-code"},
|
||||
providerNames: []string{"moonshot", "kimi"},
|
||||
protocol: "moonshot",
|
||||
buildConfig: func(p ProvidersConfig) (ModelConfig, bool) {
|
||||
if p.Moonshot.APIKey == "" && p.Moonshot.APIBase == "" {
|
||||
@@ -373,23 +373,6 @@ func ConvertProvidersToModelList(cfg *Config) []ModelConfig {
|
||||
}, true
|
||||
},
|
||||
},
|
||||
{
|
||||
providerNames: []string{"opencode"},
|
||||
protocol: "opencode",
|
||||
buildConfig: func(p ProvidersConfig) (ModelConfig, bool) {
|
||||
if p.Opencode.APIKey == "" && p.Opencode.APIBase == "" {
|
||||
return ModelConfig{}, false
|
||||
}
|
||||
return ModelConfig{
|
||||
ModelName: "opencode",
|
||||
Model: "opencode/auto",
|
||||
APIKey: p.Opencode.APIKey,
|
||||
APIBase: p.Opencode.APIBase,
|
||||
Proxy: p.Opencode.Proxy,
|
||||
RequestTimeout: p.Opencode.RequestTimeout,
|
||||
}, true
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// Process each provider migration
|
||||
@@ -401,9 +384,7 @@ func ConvertProvidersToModelList(cfg *Config) []ModelConfig {
|
||||
|
||||
// Check if this is the user's configured provider
|
||||
if slices.Contains(m.providerNames, userProvider) && userModel != "" {
|
||||
// Use the user's configured model instead of default.
|
||||
// Also set ModelName so GetModelConfig(userModel) can find this entry.
|
||||
mc.ModelName = userModel
|
||||
// Use the user's configured model instead of default
|
||||
mc.Model = buildModelWithProtocol(m.protocol, userModel)
|
||||
} else if userProvider == "" && userModel != "" && !legacyModelNameApplied {
|
||||
// Legacy config: no explicit provider field but model is specified
|
||||
|
||||
@@ -160,7 +160,6 @@ func TestConvertProvidersToModelList_AllProviders(t *testing.T) {
|
||||
Antigravity: ProviderConfig{AuthMethod: "oauth"},
|
||||
Qwen: ProviderConfig{APIKey: "key17"},
|
||||
Mistral: ProviderConfig{APIKey: "key18"},
|
||||
Opencode: ProviderConfig{APIKey: "key19"},
|
||||
},
|
||||
}
|
||||
|
||||
@@ -580,65 +579,6 @@ func TestBuildModelWithProtocol_DifferentPrefix(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertProvidersToModelList_Opencode(t *testing.T) {
|
||||
cfg := &Config{
|
||||
Providers: ProvidersConfig{
|
||||
Opencode: ProviderConfig{
|
||||
APIKey: "oc-test-key",
|
||||
APIBase: "https://custom.opencode.ai/v1",
|
||||
Proxy: "http://proxy:9090",
|
||||
RequestTimeout: 60,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
result := ConvertProvidersToModelList(cfg)
|
||||
|
||||
if len(result) != 1 {
|
||||
t.Fatalf("len(result) = %d, want 1", len(result))
|
||||
}
|
||||
|
||||
mc := result[0]
|
||||
if mc.ModelName != "opencode" {
|
||||
t.Errorf("ModelName = %q, want %q", mc.ModelName, "opencode")
|
||||
}
|
||||
if mc.Model != "opencode/auto" {
|
||||
t.Errorf("Model = %q, want %q", mc.Model, "opencode/auto")
|
||||
}
|
||||
if mc.APIKey != "oc-test-key" {
|
||||
t.Errorf("APIKey = %q, want %q", mc.APIKey, "oc-test-key")
|
||||
}
|
||||
if mc.APIBase != "https://custom.opencode.ai/v1" {
|
||||
t.Errorf("APIBase = %q, want %q", mc.APIBase, "https://custom.opencode.ai/v1")
|
||||
}
|
||||
if mc.Proxy != "http://proxy:9090" {
|
||||
t.Errorf("Proxy = %q, want %q", mc.Proxy, "http://proxy:9090")
|
||||
}
|
||||
if mc.RequestTimeout != 60 {
|
||||
t.Errorf("RequestTimeout = %d, want %d", mc.RequestTimeout, 60)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertProvidersToModelList_Opencode_APIBaseOnly(t *testing.T) {
|
||||
cfg := &Config{
|
||||
Providers: ProvidersConfig{
|
||||
Opencode: ProviderConfig{
|
||||
APIBase: "https://custom.opencode.ai/v1",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
result := ConvertProvidersToModelList(cfg)
|
||||
|
||||
if len(result) != 1 {
|
||||
t.Fatalf("len(result) = %d, want 1 (APIBase-only should create entry)", len(result))
|
||||
}
|
||||
|
||||
if result[0].ModelName != "opencode" {
|
||||
t.Errorf("ModelName = %q, want %q", result[0].ModelName, "opencode")
|
||||
}
|
||||
}
|
||||
|
||||
// Test for legacy config with protocol prefix in model name
|
||||
func TestConvertProvidersToModelList_LegacyModelWithProtocolPrefix(t *testing.T) {
|
||||
cfg := &Config{
|
||||
@@ -669,72 +609,3 @@ func TestConvertProvidersToModelList_LegacyModelWithProtocolPrefix(t *testing.T)
|
||||
t.Errorf("Model = %q, want %q (should not duplicate prefix)", result[0].Model, "openrouter/auto")
|
||||
}
|
||||
}
|
||||
|
||||
// Test that ModelName is set to the user's configured model when provider matches.
|
||||
// This ensures GetModelConfig(userModel) can find the migrated entry.
|
||||
// Regression test for: gateway startup failure when user model differs from provider name.
|
||||
func TestConvertProvidersToModelList_ModelNameMatchesUserModel(t *testing.T) {
|
||||
cfg := &Config{
|
||||
Agents: AgentsConfig{
|
||||
Defaults: AgentDefaults{
|
||||
Provider: "moonshot",
|
||||
Model: "k2p5",
|
||||
},
|
||||
},
|
||||
Providers: ProvidersConfig{
|
||||
Moonshot: ProviderConfig{APIKey: "sk-kimi-test"},
|
||||
},
|
||||
}
|
||||
|
||||
result := ConvertProvidersToModelList(cfg)
|
||||
|
||||
if len(result) != 1 {
|
||||
t.Fatalf("len(result) = %d, want 1", len(result))
|
||||
}
|
||||
|
||||
// ModelName must match the user's configured model, not the provider name.
|
||||
// Without this, GetModelConfig("k2p5") would fail because it would look
|
||||
// for ModelName == "k2p5" but find ModelName == "moonshot".
|
||||
if result[0].ModelName != "k2p5" {
|
||||
t.Errorf("ModelName = %q, want %q (must match user's model for GetModelConfig lookup)", result[0].ModelName, "k2p5")
|
||||
}
|
||||
|
||||
if result[0].Model != "moonshot/k2p5" {
|
||||
t.Errorf("Model = %q, want %q", result[0].Model, "moonshot/k2p5")
|
||||
}
|
||||
|
||||
// Other providers (not matching the user's configured provider) should keep their provider name
|
||||
cfg2 := &Config{
|
||||
Agents: AgentsConfig{
|
||||
Defaults: AgentDefaults{
|
||||
Provider: "moonshot",
|
||||
Model: "k2p5",
|
||||
},
|
||||
},
|
||||
Providers: ProvidersConfig{
|
||||
OpenAI: OpenAIProviderConfig{ProviderConfig: ProviderConfig{APIKey: "sk-openai"}},
|
||||
Moonshot: ProviderConfig{APIKey: "sk-kimi-test"},
|
||||
},
|
||||
}
|
||||
|
||||
result2 := ConvertProvidersToModelList(cfg2)
|
||||
|
||||
if len(result2) != 2 {
|
||||
t.Fatalf("len(result2) = %d, want 2", len(result2))
|
||||
}
|
||||
|
||||
for _, mc := range result2 {
|
||||
switch {
|
||||
case mc.APIKey == "sk-openai":
|
||||
// OpenAI is not the user's provider, should keep default ModelName
|
||||
if mc.ModelName != "openai" {
|
||||
t.Errorf("OpenAI ModelName = %q, want %q (non-matching provider keeps default)", mc.ModelName, "openai")
|
||||
}
|
||||
case mc.APIKey == "sk-kimi-test":
|
||||
// Moonshot is the user's provider, ModelName must be the user's model
|
||||
if mc.ModelName != "k2p5" {
|
||||
t.Errorf("Moonshot ModelName = %q, want %q (matching provider uses user model)", mc.ModelName, "k2p5")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,460 @@
|
||||
package memory
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"hash/fnv"
|
||||
"log"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/fileutil"
|
||||
"github.com/sipeed/picoclaw/pkg/providers"
|
||||
)
|
||||
|
||||
const (
|
||||
// numLockShards is the fixed number of mutexes used to serialize
|
||||
// per-session access. Using a sharded array instead of a map keeps
|
||||
// memory bounded regardless of how many sessions are created over
|
||||
// the lifetime of the process — important for a long-running daemon.
|
||||
numLockShards = 64
|
||||
|
||||
// maxLineSize is the maximum size of a single JSON line in a .jsonl
|
||||
// file. Tool results (read_file, web search, etc.) can be large, so
|
||||
// we set a generous limit. The scanner starts at 64 KB and grows
|
||||
// only as needed up to this cap.
|
||||
maxLineSize = 10 * 1024 * 1024 // 10 MB
|
||||
)
|
||||
|
||||
// sessionMeta holds per-session metadata stored in a .meta.json file.
|
||||
type sessionMeta struct {
|
||||
Key string `json:"key"`
|
||||
Summary string `json:"summary"`
|
||||
Skip int `json:"skip"`
|
||||
Count int `json:"count"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
}
|
||||
|
||||
// JSONLStore implements Store using append-only JSONL files.
|
||||
//
|
||||
// Each session is stored as two files:
|
||||
//
|
||||
// {sanitized_key}.jsonl — one JSON-encoded message per line, append-only
|
||||
// {sanitized_key}.meta.json — session metadata (summary, logical truncation offset)
|
||||
//
|
||||
// Messages are never physically deleted from the JSONL file. Instead,
|
||||
// TruncateHistory records a "skip" offset in the metadata file and
|
||||
// GetHistory ignores lines before that offset. This keeps all writes
|
||||
// append-only, which is both fast and crash-safe.
|
||||
type JSONLStore struct {
|
||||
dir string
|
||||
locks [numLockShards]sync.Mutex
|
||||
}
|
||||
|
||||
// NewJSONLStore creates a new JSONL-backed store rooted at dir.
|
||||
func NewJSONLStore(dir string) (*JSONLStore, error) {
|
||||
err := os.MkdirAll(dir, 0o755)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("memory: create directory: %w", err)
|
||||
}
|
||||
return &JSONLStore{dir: dir}, nil
|
||||
}
|
||||
|
||||
// sessionLock returns a mutex for the given session key.
|
||||
// Keys are mapped to a fixed pool of shards via FNV hash, so
|
||||
// memory usage is O(1) regardless of total session count.
|
||||
func (s *JSONLStore) sessionLock(key string) *sync.Mutex {
|
||||
h := fnv.New32a()
|
||||
h.Write([]byte(key))
|
||||
return &s.locks[h.Sum32()%numLockShards]
|
||||
}
|
||||
|
||||
func (s *JSONLStore) jsonlPath(key string) string {
|
||||
return filepath.Join(s.dir, sanitizeKey(key)+".jsonl")
|
||||
}
|
||||
|
||||
func (s *JSONLStore) metaPath(key string) string {
|
||||
return filepath.Join(s.dir, sanitizeKey(key)+".meta.json")
|
||||
}
|
||||
|
||||
// sanitizeKey converts a session key to a safe filename component.
|
||||
// Mirrors pkg/session.sanitizeFilename so that migration paths match.
|
||||
//
|
||||
// Note: this is a lossy mapping — "telegram:123" and "telegram_123"
|
||||
// both produce the same filename. This is an intentional tradeoff:
|
||||
// keys with colons (e.g. from channels) are by far the common case,
|
||||
// and a bidirectional encoding (like URL-encoding) would complicate
|
||||
// file listings and debugging.
|
||||
func sanitizeKey(key string) string {
|
||||
return strings.ReplaceAll(key, ":", "_")
|
||||
}
|
||||
|
||||
// readMeta loads the metadata file for a session.
|
||||
// Returns a zero-value sessionMeta if the file does not exist.
|
||||
func (s *JSONLStore) readMeta(key string) (sessionMeta, error) {
|
||||
data, err := os.ReadFile(s.metaPath(key))
|
||||
if os.IsNotExist(err) {
|
||||
return sessionMeta{Key: key}, nil
|
||||
}
|
||||
if err != nil {
|
||||
return sessionMeta{}, fmt.Errorf("memory: read meta: %w", err)
|
||||
}
|
||||
var meta sessionMeta
|
||||
err = json.Unmarshal(data, &meta)
|
||||
if err != nil {
|
||||
return sessionMeta{}, fmt.Errorf("memory: decode meta: %w", err)
|
||||
}
|
||||
return meta, nil
|
||||
}
|
||||
|
||||
// writeMeta atomically writes the metadata file using the project's
|
||||
// standard WriteFileAtomic (temp + fsync + rename).
|
||||
func (s *JSONLStore) writeMeta(key string, meta sessionMeta) error {
|
||||
data, err := json.MarshalIndent(meta, "", " ")
|
||||
if err != nil {
|
||||
return fmt.Errorf("memory: encode meta: %w", err)
|
||||
}
|
||||
return fileutil.WriteFileAtomic(s.metaPath(key), data, 0o644)
|
||||
}
|
||||
|
||||
// readMessages reads valid JSON lines from a .jsonl file, skipping
|
||||
// the first `skip` lines without unmarshaling them. This avoids the
|
||||
// cost of json.Unmarshal on logically truncated messages.
|
||||
// Malformed trailing lines (e.g. from a crash) are silently skipped.
|
||||
func readMessages(path string, skip int) ([]providers.Message, error) {
|
||||
f, err := os.Open(path)
|
||||
if os.IsNotExist(err) {
|
||||
return []providers.Message{}, nil
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("memory: open jsonl: %w", err)
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
var msgs []providers.Message
|
||||
scanner := bufio.NewScanner(f)
|
||||
// Allow large lines for tool results (read_file, web search, etc.).
|
||||
scanner.Buffer(make([]byte, 0, 64*1024), maxLineSize)
|
||||
|
||||
lineNum := 0
|
||||
for scanner.Scan() {
|
||||
line := scanner.Bytes()
|
||||
if len(line) == 0 {
|
||||
continue
|
||||
}
|
||||
lineNum++
|
||||
if lineNum <= skip {
|
||||
continue
|
||||
}
|
||||
var msg providers.Message
|
||||
if err := json.Unmarshal(line, &msg); err != nil {
|
||||
// Corrupt line — likely a partial write from a crash.
|
||||
// Log so operators know data was skipped, but don't
|
||||
// fail the entire read; this is the standard JSONL
|
||||
// recovery pattern.
|
||||
log.Printf("memory: skipping corrupt line %d in %s: %v",
|
||||
lineNum, filepath.Base(path), err)
|
||||
continue
|
||||
}
|
||||
msgs = append(msgs, msg)
|
||||
}
|
||||
if scanner.Err() != nil {
|
||||
return nil, fmt.Errorf("memory: scan jsonl: %w", scanner.Err())
|
||||
}
|
||||
|
||||
if msgs == nil {
|
||||
msgs = []providers.Message{}
|
||||
}
|
||||
return msgs, nil
|
||||
}
|
||||
|
||||
// countLines counts the total number of non-empty lines in a .jsonl file.
|
||||
// Used by TruncateHistory to reconcile a stale meta.Count without
|
||||
// the overhead of unmarshaling every message.
|
||||
func countLines(path string) (int, error) {
|
||||
f, err := os.Open(path)
|
||||
if os.IsNotExist(err) {
|
||||
return 0, nil
|
||||
}
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("memory: open jsonl: %w", err)
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
n := 0
|
||||
scanner := bufio.NewScanner(f)
|
||||
scanner.Buffer(make([]byte, 0, 64*1024), maxLineSize)
|
||||
for scanner.Scan() {
|
||||
if len(scanner.Bytes()) > 0 {
|
||||
n++
|
||||
}
|
||||
}
|
||||
return n, scanner.Err()
|
||||
}
|
||||
|
||||
func (s *JSONLStore) AddMessage(
|
||||
_ context.Context, sessionKey, role, content string,
|
||||
) error {
|
||||
return s.addMsg(sessionKey, providers.Message{
|
||||
Role: role,
|
||||
Content: content,
|
||||
})
|
||||
}
|
||||
|
||||
func (s *JSONLStore) AddFullMessage(
|
||||
_ context.Context, sessionKey string, msg providers.Message,
|
||||
) error {
|
||||
return s.addMsg(sessionKey, msg)
|
||||
}
|
||||
|
||||
// addMsg is the shared implementation for AddMessage and AddFullMessage.
|
||||
func (s *JSONLStore) addMsg(sessionKey string, msg providers.Message) error {
|
||||
l := s.sessionLock(sessionKey)
|
||||
l.Lock()
|
||||
defer l.Unlock()
|
||||
|
||||
// Append the message as a single JSON line.
|
||||
line, err := json.Marshal(msg)
|
||||
if err != nil {
|
||||
return fmt.Errorf("memory: marshal message: %w", err)
|
||||
}
|
||||
line = append(line, '\n')
|
||||
|
||||
f, err := os.OpenFile(
|
||||
s.jsonlPath(sessionKey),
|
||||
os.O_CREATE|os.O_WRONLY|os.O_APPEND,
|
||||
0o644,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("memory: open jsonl for append: %w", err)
|
||||
}
|
||||
_, writeErr := f.Write(line)
|
||||
if writeErr != nil {
|
||||
f.Close()
|
||||
return fmt.Errorf("memory: append message: %w", writeErr)
|
||||
}
|
||||
// Flush to physical storage before closing. This matches the
|
||||
// durability guarantee of writeMeta and rewriteJSONL (which use
|
||||
// WriteFileAtomic with fsync). Without Sync, a power loss could
|
||||
// leave the append in the kernel page cache only — lost on reboot.
|
||||
if syncErr := f.Sync(); syncErr != nil {
|
||||
f.Close()
|
||||
return fmt.Errorf("memory: sync jsonl: %w", syncErr)
|
||||
}
|
||||
if closeErr := f.Close(); closeErr != nil {
|
||||
return fmt.Errorf("memory: close jsonl: %w", closeErr)
|
||||
}
|
||||
|
||||
// Update metadata.
|
||||
meta, err := s.readMeta(sessionKey)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
now := time.Now()
|
||||
if meta.Count == 0 && meta.CreatedAt.IsZero() {
|
||||
meta.CreatedAt = now
|
||||
}
|
||||
meta.Count++
|
||||
meta.UpdatedAt = now
|
||||
|
||||
return s.writeMeta(sessionKey, meta)
|
||||
}
|
||||
|
||||
func (s *JSONLStore) GetHistory(
|
||||
_ context.Context, sessionKey string,
|
||||
) ([]providers.Message, error) {
|
||||
l := s.sessionLock(sessionKey)
|
||||
l.Lock()
|
||||
defer l.Unlock()
|
||||
|
||||
meta, err := s.readMeta(sessionKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Pass meta.Skip so readMessages skips those lines without
|
||||
// unmarshaling them — avoids wasted CPU on truncated messages.
|
||||
msgs, err := readMessages(s.jsonlPath(sessionKey), meta.Skip)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return msgs, nil
|
||||
}
|
||||
|
||||
func (s *JSONLStore) GetSummary(
|
||||
_ context.Context, sessionKey string,
|
||||
) (string, error) {
|
||||
l := s.sessionLock(sessionKey)
|
||||
l.Lock()
|
||||
defer l.Unlock()
|
||||
|
||||
meta, err := s.readMeta(sessionKey)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return meta.Summary, nil
|
||||
}
|
||||
|
||||
func (s *JSONLStore) SetSummary(
|
||||
_ context.Context, sessionKey, summary string,
|
||||
) error {
|
||||
l := s.sessionLock(sessionKey)
|
||||
l.Lock()
|
||||
defer l.Unlock()
|
||||
|
||||
meta, err := s.readMeta(sessionKey)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
now := time.Now()
|
||||
if meta.CreatedAt.IsZero() {
|
||||
meta.CreatedAt = now
|
||||
}
|
||||
meta.Summary = summary
|
||||
meta.UpdatedAt = now
|
||||
|
||||
return s.writeMeta(sessionKey, meta)
|
||||
}
|
||||
|
||||
func (s *JSONLStore) TruncateHistory(
|
||||
_ context.Context, sessionKey string, keepLast int,
|
||||
) error {
|
||||
l := s.sessionLock(sessionKey)
|
||||
l.Lock()
|
||||
defer l.Unlock()
|
||||
|
||||
meta, err := s.readMeta(sessionKey)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Always reconcile meta.Count with the actual line count on disk.
|
||||
// A crash between the JSONL append and the meta update in addMsg
|
||||
// leaves meta.Count stale (e.g. file has 101 lines but meta says
|
||||
// 100). Counting lines is cheap — no unmarshal, just a scan — and
|
||||
// TruncateHistory is not a hot path, so always re-count.
|
||||
n, countErr := countLines(s.jsonlPath(sessionKey))
|
||||
if countErr != nil {
|
||||
return countErr
|
||||
}
|
||||
meta.Count = n
|
||||
|
||||
if keepLast <= 0 {
|
||||
meta.Skip = meta.Count
|
||||
} else {
|
||||
effective := meta.Count - meta.Skip
|
||||
if keepLast < effective {
|
||||
meta.Skip = meta.Count - keepLast
|
||||
}
|
||||
}
|
||||
meta.UpdatedAt = time.Now()
|
||||
|
||||
return s.writeMeta(sessionKey, meta)
|
||||
}
|
||||
|
||||
func (s *JSONLStore) SetHistory(
|
||||
_ context.Context,
|
||||
sessionKey string,
|
||||
history []providers.Message,
|
||||
) error {
|
||||
l := s.sessionLock(sessionKey)
|
||||
l.Lock()
|
||||
defer l.Unlock()
|
||||
|
||||
meta, err := s.readMeta(sessionKey)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
now := time.Now()
|
||||
if meta.CreatedAt.IsZero() {
|
||||
meta.CreatedAt = now
|
||||
}
|
||||
meta.Skip = 0
|
||||
meta.Count = len(history)
|
||||
meta.UpdatedAt = now
|
||||
|
||||
// Write meta BEFORE rewriting the JSONL file. If we crash between
|
||||
// the two writes, meta has Skip=0 and the old file is still intact,
|
||||
// so GetHistory reads from line 1 — returning "too many" messages
|
||||
// rather than losing data. The next SetHistory call corrects this.
|
||||
err = s.writeMeta(sessionKey, meta)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return s.rewriteJSONL(sessionKey, history)
|
||||
}
|
||||
|
||||
// Compact physically rewrites the JSONL file, dropping all logically
|
||||
// skipped lines. This reclaims disk space that accumulates after
|
||||
// repeated TruncateHistory calls.
|
||||
//
|
||||
// It is safe to call at any time; if there is nothing to compact
|
||||
// (skip == 0) the method returns immediately.
|
||||
func (s *JSONLStore) Compact(
|
||||
_ context.Context, sessionKey string,
|
||||
) error {
|
||||
l := s.sessionLock(sessionKey)
|
||||
l.Lock()
|
||||
defer l.Unlock()
|
||||
|
||||
meta, err := s.readMeta(sessionKey)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if meta.Skip == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Read only the active messages, skipping truncated lines
|
||||
// without unmarshaling them.
|
||||
active, err := readMessages(s.jsonlPath(sessionKey), meta.Skip)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Write meta BEFORE rewriting the JSONL file. If the process
|
||||
// crashes between the two writes, meta has Skip=0 and the old
|
||||
// (uncompacted) file is still intact, so GetHistory reads from
|
||||
// line 1 — returning previously-truncated messages rather than
|
||||
// losing data. The next Compact or TruncateHistory corrects this.
|
||||
meta.Skip = 0
|
||||
meta.Count = len(active)
|
||||
meta.UpdatedAt = time.Now()
|
||||
|
||||
err = s.writeMeta(sessionKey, meta)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return s.rewriteJSONL(sessionKey, active)
|
||||
}
|
||||
|
||||
// rewriteJSONL atomically replaces the JSONL file with the given messages
|
||||
// using the project's standard WriteFileAtomic (temp + fsync + rename).
|
||||
func (s *JSONLStore) rewriteJSONL(
|
||||
sessionKey string, msgs []providers.Message,
|
||||
) error {
|
||||
var buf bytes.Buffer
|
||||
for i, msg := range msgs {
|
||||
line, err := json.Marshal(msg)
|
||||
if err != nil {
|
||||
return fmt.Errorf("memory: marshal message %d: %w", i, err)
|
||||
}
|
||||
buf.Write(line)
|
||||
buf.WriteByte('\n')
|
||||
}
|
||||
return fileutil.WriteFileAtomic(s.jsonlPath(sessionKey), buf.Bytes(), 0o644)
|
||||
}
|
||||
|
||||
func (s *JSONLStore) Close() error {
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,835 @@
|
||||
package memory
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/providers"
|
||||
)
|
||||
|
||||
func newTestStore(t *testing.T) *JSONLStore {
|
||||
t.Helper()
|
||||
store, err := NewJSONLStore(t.TempDir())
|
||||
if err != nil {
|
||||
t.Fatalf("NewJSONLStore: %v", err)
|
||||
}
|
||||
return store
|
||||
}
|
||||
|
||||
func TestNewJSONLStore_CreatesDirectory(t *testing.T) {
|
||||
dir := filepath.Join(t.TempDir(), "nested", "sessions")
|
||||
store, err := NewJSONLStore(dir)
|
||||
if err != nil {
|
||||
t.Fatalf("NewJSONLStore: %v", err)
|
||||
}
|
||||
defer store.Close()
|
||||
|
||||
info, err := os.Stat(dir)
|
||||
if err != nil {
|
||||
t.Fatalf("Stat: %v", err)
|
||||
}
|
||||
if !info.IsDir() {
|
||||
t.Errorf("expected directory, got file")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAddMessage_BasicRoundtrip(t *testing.T) {
|
||||
store := newTestStore(t)
|
||||
ctx := context.Background()
|
||||
|
||||
err := store.AddMessage(ctx, "s1", "user", "hello")
|
||||
if err != nil {
|
||||
t.Fatalf("AddMessage: %v", err)
|
||||
}
|
||||
err = store.AddMessage(ctx, "s1", "assistant", "hi there")
|
||||
if err != nil {
|
||||
t.Fatalf("AddMessage: %v", err)
|
||||
}
|
||||
|
||||
history, err := store.GetHistory(ctx, "s1")
|
||||
if err != nil {
|
||||
t.Fatalf("GetHistory: %v", err)
|
||||
}
|
||||
if len(history) != 2 {
|
||||
t.Fatalf("expected 2 messages, got %d", len(history))
|
||||
}
|
||||
if history[0].Role != "user" || history[0].Content != "hello" {
|
||||
t.Errorf("msg[0] = %+v", history[0])
|
||||
}
|
||||
if history[1].Role != "assistant" || history[1].Content != "hi there" {
|
||||
t.Errorf("msg[1] = %+v", history[1])
|
||||
}
|
||||
}
|
||||
|
||||
func TestAddMessage_AutoCreatesSession(t *testing.T) {
|
||||
store := newTestStore(t)
|
||||
ctx := context.Background()
|
||||
|
||||
// Adding a message to a non-existent session should work.
|
||||
err := store.AddMessage(ctx, "new-session", "user", "first message")
|
||||
if err != nil {
|
||||
t.Fatalf("AddMessage: %v", err)
|
||||
}
|
||||
|
||||
history, err := store.GetHistory(ctx, "new-session")
|
||||
if err != nil {
|
||||
t.Fatalf("GetHistory: %v", err)
|
||||
}
|
||||
if len(history) != 1 {
|
||||
t.Fatalf("expected 1 message, got %d", len(history))
|
||||
}
|
||||
}
|
||||
|
||||
func TestAddFullMessage_WithToolCalls(t *testing.T) {
|
||||
store := newTestStore(t)
|
||||
ctx := context.Background()
|
||||
|
||||
msg := providers.Message{
|
||||
Role: "assistant",
|
||||
Content: "Let me search that.",
|
||||
ToolCalls: []providers.ToolCall{
|
||||
{
|
||||
ID: "call_abc",
|
||||
Type: "function",
|
||||
Function: &providers.FunctionCall{
|
||||
Name: "web_search",
|
||||
Arguments: `{"q":"golang jsonl"}`,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
err := store.AddFullMessage(ctx, "tc", msg)
|
||||
if err != nil {
|
||||
t.Fatalf("AddFullMessage: %v", err)
|
||||
}
|
||||
|
||||
history, err := store.GetHistory(ctx, "tc")
|
||||
if err != nil {
|
||||
t.Fatalf("GetHistory: %v", err)
|
||||
}
|
||||
if len(history) != 1 {
|
||||
t.Fatalf("expected 1, got %d", len(history))
|
||||
}
|
||||
if len(history[0].ToolCalls) != 1 {
|
||||
t.Fatalf("expected 1 tool call, got %d", len(history[0].ToolCalls))
|
||||
}
|
||||
tc := history[0].ToolCalls[0]
|
||||
if tc.ID != "call_abc" {
|
||||
t.Errorf("tool call ID = %q", tc.ID)
|
||||
}
|
||||
if tc.Function == nil || tc.Function.Name != "web_search" {
|
||||
t.Errorf("tool call function = %+v", tc.Function)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAddFullMessage_ToolCallID(t *testing.T) {
|
||||
store := newTestStore(t)
|
||||
ctx := context.Background()
|
||||
|
||||
msg := providers.Message{
|
||||
Role: "tool",
|
||||
Content: "search results here",
|
||||
ToolCallID: "call_abc",
|
||||
}
|
||||
|
||||
err := store.AddFullMessage(ctx, "tr", msg)
|
||||
if err != nil {
|
||||
t.Fatalf("AddFullMessage: %v", err)
|
||||
}
|
||||
|
||||
history, err := store.GetHistory(ctx, "tr")
|
||||
if err != nil {
|
||||
t.Fatalf("GetHistory: %v", err)
|
||||
}
|
||||
if len(history) != 1 {
|
||||
t.Fatalf("expected 1, got %d", len(history))
|
||||
}
|
||||
if history[0].ToolCallID != "call_abc" {
|
||||
t.Errorf("ToolCallID = %q", history[0].ToolCallID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetHistory_EmptySession(t *testing.T) {
|
||||
store := newTestStore(t)
|
||||
ctx := context.Background()
|
||||
|
||||
history, err := store.GetHistory(ctx, "nonexistent")
|
||||
if err != nil {
|
||||
t.Fatalf("GetHistory: %v", err)
|
||||
}
|
||||
if history == nil {
|
||||
t.Fatal("expected non-nil empty slice")
|
||||
}
|
||||
if len(history) != 0 {
|
||||
t.Errorf("expected 0 messages, got %d", len(history))
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetHistory_Ordering(t *testing.T) {
|
||||
store := newTestStore(t)
|
||||
ctx := context.Background()
|
||||
|
||||
for i := 0; i < 5; i++ {
|
||||
err := store.AddMessage(
|
||||
ctx, "order",
|
||||
"user",
|
||||
string(rune('a'+i)),
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("AddMessage(%d): %v", i, err)
|
||||
}
|
||||
}
|
||||
|
||||
history, err := store.GetHistory(ctx, "order")
|
||||
if err != nil {
|
||||
t.Fatalf("GetHistory: %v", err)
|
||||
}
|
||||
if len(history) != 5 {
|
||||
t.Fatalf("expected 5, got %d", len(history))
|
||||
}
|
||||
for i := 0; i < 5; i++ {
|
||||
expected := string(rune('a' + i))
|
||||
if history[i].Content != expected {
|
||||
t.Errorf("msg[%d].Content = %q, want %q", i, history[i].Content, expected)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestSetSummary_GetSummary(t *testing.T) {
|
||||
store := newTestStore(t)
|
||||
ctx := context.Background()
|
||||
|
||||
// No summary yet.
|
||||
summary, err := store.GetSummary(ctx, "s1")
|
||||
if err != nil {
|
||||
t.Fatalf("GetSummary: %v", err)
|
||||
}
|
||||
if summary != "" {
|
||||
t.Errorf("expected empty, got %q", summary)
|
||||
}
|
||||
|
||||
// Set a summary.
|
||||
err = store.SetSummary(ctx, "s1", "talked about Go")
|
||||
if err != nil {
|
||||
t.Fatalf("SetSummary: %v", err)
|
||||
}
|
||||
|
||||
summary, err = store.GetSummary(ctx, "s1")
|
||||
if err != nil {
|
||||
t.Fatalf("GetSummary: %v", err)
|
||||
}
|
||||
if summary != "talked about Go" {
|
||||
t.Errorf("summary = %q", summary)
|
||||
}
|
||||
|
||||
// Update summary.
|
||||
err = store.SetSummary(ctx, "s1", "updated summary")
|
||||
if err != nil {
|
||||
t.Fatalf("SetSummary: %v", err)
|
||||
}
|
||||
|
||||
summary, err = store.GetSummary(ctx, "s1")
|
||||
if err != nil {
|
||||
t.Fatalf("GetSummary: %v", err)
|
||||
}
|
||||
if summary != "updated summary" {
|
||||
t.Errorf("summary = %q", summary)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTruncateHistory_KeepLast(t *testing.T) {
|
||||
store := newTestStore(t)
|
||||
ctx := context.Background()
|
||||
|
||||
for i := 0; i < 10; i++ {
|
||||
err := store.AddMessage(
|
||||
ctx, "trunc",
|
||||
"user",
|
||||
string(rune('a'+i)),
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("AddMessage: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
err := store.TruncateHistory(ctx, "trunc", 4)
|
||||
if err != nil {
|
||||
t.Fatalf("TruncateHistory: %v", err)
|
||||
}
|
||||
|
||||
history, err := store.GetHistory(ctx, "trunc")
|
||||
if err != nil {
|
||||
t.Fatalf("GetHistory: %v", err)
|
||||
}
|
||||
if len(history) != 4 {
|
||||
t.Fatalf("expected 4, got %d", len(history))
|
||||
}
|
||||
// Should be the last 4: g, h, i, j
|
||||
if history[0].Content != "g" {
|
||||
t.Errorf("first kept = %q, want 'g'", history[0].Content)
|
||||
}
|
||||
if history[3].Content != "j" {
|
||||
t.Errorf("last kept = %q, want 'j'", history[3].Content)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTruncateHistory_KeepZero(t *testing.T) {
|
||||
store := newTestStore(t)
|
||||
ctx := context.Background()
|
||||
|
||||
for i := 0; i < 5; i++ {
|
||||
err := store.AddMessage(ctx, "empty", "user", "msg")
|
||||
if err != nil {
|
||||
t.Fatalf("AddMessage: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
err := store.TruncateHistory(ctx, "empty", 0)
|
||||
if err != nil {
|
||||
t.Fatalf("TruncateHistory: %v", err)
|
||||
}
|
||||
|
||||
history, err := store.GetHistory(ctx, "empty")
|
||||
if err != nil {
|
||||
t.Fatalf("GetHistory: %v", err)
|
||||
}
|
||||
if len(history) != 0 {
|
||||
t.Errorf("expected 0, got %d", len(history))
|
||||
}
|
||||
}
|
||||
|
||||
func TestTruncateHistory_KeepMoreThanExists(t *testing.T) {
|
||||
store := newTestStore(t)
|
||||
ctx := context.Background()
|
||||
|
||||
for i := 0; i < 3; i++ {
|
||||
err := store.AddMessage(ctx, "few", "user", "msg")
|
||||
if err != nil {
|
||||
t.Fatalf("AddMessage: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Keep 100, but only 3 exist — should keep all.
|
||||
err := store.TruncateHistory(ctx, "few", 100)
|
||||
if err != nil {
|
||||
t.Fatalf("TruncateHistory: %v", err)
|
||||
}
|
||||
|
||||
history, err := store.GetHistory(ctx, "few")
|
||||
if err != nil {
|
||||
t.Fatalf("GetHistory: %v", err)
|
||||
}
|
||||
if len(history) != 3 {
|
||||
t.Errorf("expected 3, got %d", len(history))
|
||||
}
|
||||
}
|
||||
|
||||
func TestSetHistory_ReplacesAll(t *testing.T) {
|
||||
store := newTestStore(t)
|
||||
ctx := context.Background()
|
||||
|
||||
// Add some initial messages.
|
||||
for i := 0; i < 5; i++ {
|
||||
err := store.AddMessage(ctx, "replace", "user", "old")
|
||||
if err != nil {
|
||||
t.Fatalf("AddMessage: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Replace with new history.
|
||||
newHistory := []providers.Message{
|
||||
{Role: "user", Content: "new1"},
|
||||
{Role: "assistant", Content: "new2"},
|
||||
}
|
||||
err := store.SetHistory(ctx, "replace", newHistory)
|
||||
if err != nil {
|
||||
t.Fatalf("SetHistory: %v", err)
|
||||
}
|
||||
|
||||
history, err := store.GetHistory(ctx, "replace")
|
||||
if err != nil {
|
||||
t.Fatalf("GetHistory: %v", err)
|
||||
}
|
||||
if len(history) != 2 {
|
||||
t.Fatalf("expected 2, got %d", len(history))
|
||||
}
|
||||
if history[0].Content != "new1" || history[1].Content != "new2" {
|
||||
t.Errorf("history = %+v", history)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSetHistory_ResetsSkip(t *testing.T) {
|
||||
store := newTestStore(t)
|
||||
ctx := context.Background()
|
||||
|
||||
// Add messages and truncate.
|
||||
for i := 0; i < 10; i++ {
|
||||
err := store.AddMessage(ctx, "skip-reset", "user", "old")
|
||||
if err != nil {
|
||||
t.Fatalf("AddMessage: %v", err)
|
||||
}
|
||||
}
|
||||
err := store.TruncateHistory(ctx, "skip-reset", 3)
|
||||
if err != nil {
|
||||
t.Fatalf("TruncateHistory: %v", err)
|
||||
}
|
||||
|
||||
// SetHistory should reset skip to 0.
|
||||
newHistory := []providers.Message{
|
||||
{Role: "user", Content: "fresh"},
|
||||
}
|
||||
err = store.SetHistory(ctx, "skip-reset", newHistory)
|
||||
if err != nil {
|
||||
t.Fatalf("SetHistory: %v", err)
|
||||
}
|
||||
|
||||
history, err := store.GetHistory(ctx, "skip-reset")
|
||||
if err != nil {
|
||||
t.Fatalf("GetHistory: %v", err)
|
||||
}
|
||||
if len(history) != 1 {
|
||||
t.Fatalf("expected 1, got %d", len(history))
|
||||
}
|
||||
if history[0].Content != "fresh" {
|
||||
t.Errorf("content = %q", history[0].Content)
|
||||
}
|
||||
}
|
||||
|
||||
func TestColonInKey(t *testing.T) {
|
||||
store := newTestStore(t)
|
||||
ctx := context.Background()
|
||||
|
||||
err := store.AddMessage(ctx, "telegram:123", "user", "hi")
|
||||
if err != nil {
|
||||
t.Fatalf("AddMessage: %v", err)
|
||||
}
|
||||
|
||||
history, err := store.GetHistory(ctx, "telegram:123")
|
||||
if err != nil {
|
||||
t.Fatalf("GetHistory: %v", err)
|
||||
}
|
||||
if len(history) != 1 {
|
||||
t.Fatalf("expected 1, got %d", len(history))
|
||||
}
|
||||
|
||||
// Verify the file is named with underscore.
|
||||
jsonlFile := filepath.Join(store.dir, "telegram_123.jsonl")
|
||||
if _, statErr := os.Stat(jsonlFile); statErr != nil {
|
||||
t.Errorf("expected file %s to exist: %v", jsonlFile, statErr)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCompact_RemovesSkippedMessages(t *testing.T) {
|
||||
store := newTestStore(t)
|
||||
ctx := context.Background()
|
||||
|
||||
// Write 10 messages, then truncate to keep last 3.
|
||||
for i := 0; i < 10; i++ {
|
||||
err := store.AddMessage(ctx, "compact", "user", string(rune('a'+i)))
|
||||
if err != nil {
|
||||
t.Fatalf("AddMessage: %v", err)
|
||||
}
|
||||
}
|
||||
err := store.TruncateHistory(ctx, "compact", 3)
|
||||
if err != nil {
|
||||
t.Fatalf("TruncateHistory: %v", err)
|
||||
}
|
||||
|
||||
// Before compact: file still has 10 lines.
|
||||
allOnDisk, err := readMessages(store.jsonlPath("compact"), 0)
|
||||
if err != nil {
|
||||
t.Fatalf("readMessages: %v", err)
|
||||
}
|
||||
if len(allOnDisk) != 10 {
|
||||
t.Fatalf("before compact: expected 10 on disk, got %d", len(allOnDisk))
|
||||
}
|
||||
|
||||
// Compact.
|
||||
err = store.Compact(ctx, "compact")
|
||||
if err != nil {
|
||||
t.Fatalf("Compact: %v", err)
|
||||
}
|
||||
|
||||
// After compact: file should have only 3 lines.
|
||||
allOnDisk, err = readMessages(store.jsonlPath("compact"), 0)
|
||||
if err != nil {
|
||||
t.Fatalf("readMessages: %v", err)
|
||||
}
|
||||
if len(allOnDisk) != 3 {
|
||||
t.Fatalf("after compact: expected 3 on disk, got %d", len(allOnDisk))
|
||||
}
|
||||
|
||||
// GetHistory should still return the same 3 messages.
|
||||
history, err := store.GetHistory(ctx, "compact")
|
||||
if err != nil {
|
||||
t.Fatalf("GetHistory: %v", err)
|
||||
}
|
||||
if len(history) != 3 {
|
||||
t.Fatalf("expected 3, got %d", len(history))
|
||||
}
|
||||
if history[0].Content != "h" || history[2].Content != "j" {
|
||||
t.Errorf("wrong content: %+v", history)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCompact_NoOpWhenNoSkip(t *testing.T) {
|
||||
store := newTestStore(t)
|
||||
ctx := context.Background()
|
||||
|
||||
for i := 0; i < 5; i++ {
|
||||
err := store.AddMessage(ctx, "noop", "user", "msg")
|
||||
if err != nil {
|
||||
t.Fatalf("AddMessage: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Compact without prior truncation — should be a no-op.
|
||||
err := store.Compact(ctx, "noop")
|
||||
if err != nil {
|
||||
t.Fatalf("Compact: %v", err)
|
||||
}
|
||||
|
||||
history, err := store.GetHistory(ctx, "noop")
|
||||
if err != nil {
|
||||
t.Fatalf("GetHistory: %v", err)
|
||||
}
|
||||
if len(history) != 5 {
|
||||
t.Errorf("expected 5, got %d", len(history))
|
||||
}
|
||||
}
|
||||
|
||||
func TestCompact_ThenAppend(t *testing.T) {
|
||||
store := newTestStore(t)
|
||||
ctx := context.Background()
|
||||
|
||||
for i := 0; i < 8; i++ {
|
||||
err := store.AddMessage(ctx, "cap", "user", string(rune('a'+i)))
|
||||
if err != nil {
|
||||
t.Fatalf("AddMessage: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
err := store.TruncateHistory(ctx, "cap", 2)
|
||||
if err != nil {
|
||||
t.Fatalf("TruncateHistory: %v", err)
|
||||
}
|
||||
err = store.Compact(ctx, "cap")
|
||||
if err != nil {
|
||||
t.Fatalf("Compact: %v", err)
|
||||
}
|
||||
|
||||
// Append after compaction should work correctly.
|
||||
err = store.AddMessage(ctx, "cap", "user", "new")
|
||||
if err != nil {
|
||||
t.Fatalf("AddMessage after compact: %v", err)
|
||||
}
|
||||
|
||||
history, err := store.GetHistory(ctx, "cap")
|
||||
if err != nil {
|
||||
t.Fatalf("GetHistory: %v", err)
|
||||
}
|
||||
if len(history) != 3 {
|
||||
t.Fatalf("expected 3, got %d", len(history))
|
||||
}
|
||||
// g, h (kept from truncation), new (appended after compaction).
|
||||
if history[0].Content != "g" {
|
||||
t.Errorf("first = %q, want 'g'", history[0].Content)
|
||||
}
|
||||
if history[2].Content != "new" {
|
||||
t.Errorf("last = %q, want 'new'", history[2].Content)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTruncateHistory_StaleMetaCount(t *testing.T) {
|
||||
// Simulates a crash between JSONL append and meta update in addMsg:
|
||||
// file has N+1 lines but meta.Count is still N. TruncateHistory must
|
||||
// reconcile with the real line count so that keepLast is accurate.
|
||||
store := newTestStore(t)
|
||||
ctx := context.Background()
|
||||
|
||||
// Write 10 messages normally (meta.Count = 10).
|
||||
for i := 0; i < 10; i++ {
|
||||
err := store.AddMessage(ctx, "stale", "user", string(rune('a'+i)))
|
||||
if err != nil {
|
||||
t.Fatalf("AddMessage: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Simulate crash: append a line to JSONL but do NOT update meta.
|
||||
// This leaves meta.Count = 10 while the file has 11 lines.
|
||||
jsonlPath := store.jsonlPath("stale")
|
||||
f, err := os.OpenFile(jsonlPath, os.O_WRONLY|os.O_APPEND, 0o644)
|
||||
if err != nil {
|
||||
t.Fatalf("open for append: %v", err)
|
||||
}
|
||||
_, err = f.WriteString(`{"role":"user","content":"orphan"}` + "\n")
|
||||
if err != nil {
|
||||
t.Fatalf("write orphan: %v", err)
|
||||
}
|
||||
f.Close()
|
||||
|
||||
// TruncateHistory(keepLast=4) should keep the last 4 of 11 lines,
|
||||
// not the last 4 of 10.
|
||||
err = store.TruncateHistory(ctx, "stale", 4)
|
||||
if err != nil {
|
||||
t.Fatalf("TruncateHistory: %v", err)
|
||||
}
|
||||
|
||||
history, err := store.GetHistory(ctx, "stale")
|
||||
if err != nil {
|
||||
t.Fatalf("GetHistory: %v", err)
|
||||
}
|
||||
if len(history) != 4 {
|
||||
t.Fatalf("expected 4, got %d", len(history))
|
||||
}
|
||||
// Last 4 of [a,b,c,d,e,f,g,h,i,j,orphan] = [h,i,j,orphan]
|
||||
if history[0].Content != "h" {
|
||||
t.Errorf("first kept = %q, want 'h'", history[0].Content)
|
||||
}
|
||||
if history[3].Content != "orphan" {
|
||||
t.Errorf("last kept = %q, want 'orphan'", history[3].Content)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCrashRecovery_PartialLine(t *testing.T) {
|
||||
store := newTestStore(t)
|
||||
ctx := context.Background()
|
||||
|
||||
// Write a valid message first.
|
||||
err := store.AddMessage(ctx, "crash", "user", "valid")
|
||||
if err != nil {
|
||||
t.Fatalf("AddMessage: %v", err)
|
||||
}
|
||||
|
||||
// Simulate a crash by appending a partial JSON line directly.
|
||||
jsonlPath := store.jsonlPath("crash")
|
||||
f, err := os.OpenFile(jsonlPath, os.O_WRONLY|os.O_APPEND, 0o644)
|
||||
if err != nil {
|
||||
t.Fatalf("open for append: %v", err)
|
||||
}
|
||||
_, err = f.WriteString(`{"role":"user","content":"incomple`)
|
||||
if err != nil {
|
||||
t.Fatalf("write partial: %v", err)
|
||||
}
|
||||
f.Close()
|
||||
|
||||
// GetHistory should return only the valid message.
|
||||
history, err := store.GetHistory(ctx, "crash")
|
||||
if err != nil {
|
||||
t.Fatalf("GetHistory: %v", err)
|
||||
}
|
||||
if len(history) != 1 {
|
||||
t.Fatalf("expected 1 valid message, got %d", len(history))
|
||||
}
|
||||
if history[0].Content != "valid" {
|
||||
t.Errorf("content = %q", history[0].Content)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPersistence_AcrossInstances(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
ctx := context.Background()
|
||||
|
||||
// Write with first instance.
|
||||
store1, err := NewJSONLStore(dir)
|
||||
if err != nil {
|
||||
t.Fatalf("NewJSONLStore: %v", err)
|
||||
}
|
||||
err = store1.AddMessage(ctx, "persist", "user", "remember me")
|
||||
if err != nil {
|
||||
t.Fatalf("AddMessage: %v", err)
|
||||
}
|
||||
err = store1.SetSummary(ctx, "persist", "a test session")
|
||||
if err != nil {
|
||||
t.Fatalf("SetSummary: %v", err)
|
||||
}
|
||||
store1.Close()
|
||||
|
||||
// Read with second instance.
|
||||
store2, err := NewJSONLStore(dir)
|
||||
if err != nil {
|
||||
t.Fatalf("NewJSONLStore: %v", err)
|
||||
}
|
||||
defer store2.Close()
|
||||
|
||||
history, err := store2.GetHistory(ctx, "persist")
|
||||
if err != nil {
|
||||
t.Fatalf("GetHistory: %v", err)
|
||||
}
|
||||
if len(history) != 1 || history[0].Content != "remember me" {
|
||||
t.Errorf("history = %+v", history)
|
||||
}
|
||||
|
||||
summary, err := store2.GetSummary(ctx, "persist")
|
||||
if err != nil {
|
||||
t.Fatalf("GetSummary: %v", err)
|
||||
}
|
||||
if summary != "a test session" {
|
||||
t.Errorf("summary = %q", summary)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConcurrent_AddAndRead(t *testing.T) {
|
||||
store := newTestStore(t)
|
||||
ctx := context.Background()
|
||||
|
||||
var wg sync.WaitGroup
|
||||
const goroutines = 10
|
||||
const msgsPerGoroutine = 20
|
||||
|
||||
// Concurrent writes.
|
||||
for g := 0; g < goroutines; g++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
for i := 0; i < msgsPerGoroutine; i++ {
|
||||
_ = store.AddMessage(ctx, "concurrent", "user", "msg")
|
||||
}
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
history, err := store.GetHistory(ctx, "concurrent")
|
||||
if err != nil {
|
||||
t.Fatalf("GetHistory: %v", err)
|
||||
}
|
||||
expected := goroutines * msgsPerGoroutine
|
||||
if len(history) != expected {
|
||||
t.Errorf("expected %d messages, got %d", expected, len(history))
|
||||
}
|
||||
}
|
||||
|
||||
func TestConcurrent_SummarizeRace(t *testing.T) {
|
||||
// Simulates the #704 race: one goroutine adds messages while
|
||||
// another truncates + sets summary — like summarizeSession().
|
||||
store := newTestStore(t)
|
||||
ctx := context.Background()
|
||||
|
||||
// Seed with some messages.
|
||||
for i := 0; i < 20; i++ {
|
||||
err := store.AddMessage(ctx, "race", "user", "seed")
|
||||
if err != nil {
|
||||
t.Fatalf("AddMessage: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
|
||||
// Writer goroutine (main agent loop).
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
for i := 0; i < 50; i++ {
|
||||
_ = store.AddMessage(ctx, "race", "user", "new")
|
||||
}
|
||||
}()
|
||||
|
||||
// Summarizer goroutine (background task).
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
for i := 0; i < 10; i++ {
|
||||
_ = store.SetSummary(ctx, "race", "summary")
|
||||
_ = store.TruncateHistory(ctx, "race", 5)
|
||||
}
|
||||
}()
|
||||
|
||||
wg.Wait()
|
||||
|
||||
// Verify the store is still in a consistent state.
|
||||
_, err := store.GetHistory(ctx, "race")
|
||||
if err != nil {
|
||||
t.Fatalf("GetHistory after race: %v", err)
|
||||
}
|
||||
_, err = store.GetSummary(ctx, "race")
|
||||
if err != nil {
|
||||
t.Fatalf("GetSummary after race: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMultipleSessions_Isolation(t *testing.T) {
|
||||
store := newTestStore(t)
|
||||
ctx := context.Background()
|
||||
|
||||
err := store.AddMessage(ctx, "s1", "user", "msg for s1")
|
||||
if err != nil {
|
||||
t.Fatalf("AddMessage: %v", err)
|
||||
}
|
||||
err = store.AddMessage(ctx, "s2", "user", "msg for s2")
|
||||
if err != nil {
|
||||
t.Fatalf("AddMessage: %v", err)
|
||||
}
|
||||
|
||||
h1, err := store.GetHistory(ctx, "s1")
|
||||
if err != nil {
|
||||
t.Fatalf("GetHistory s1: %v", err)
|
||||
}
|
||||
h2, err := store.GetHistory(ctx, "s2")
|
||||
if err != nil {
|
||||
t.Fatalf("GetHistory s2: %v", err)
|
||||
}
|
||||
|
||||
if len(h1) != 1 || h1[0].Content != "msg for s1" {
|
||||
t.Errorf("s1 history = %+v", h1)
|
||||
}
|
||||
if len(h2) != 1 || h2[0].Content != "msg for s2" {
|
||||
t.Errorf("s2 history = %+v", h2)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkAddMessage(b *testing.B) {
|
||||
dir := b.TempDir()
|
||||
store, err := NewJSONLStore(dir)
|
||||
if err != nil {
|
||||
b.Fatalf("NewJSONLStore: %v", err)
|
||||
}
|
||||
defer store.Close()
|
||||
ctx := context.Background()
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_ = store.AddMessage(ctx, "bench", "user", "benchmark message content")
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkGetHistory_100(b *testing.B) {
|
||||
dir := b.TempDir()
|
||||
store, err := NewJSONLStore(dir)
|
||||
if err != nil {
|
||||
b.Fatalf("NewJSONLStore: %v", err)
|
||||
}
|
||||
defer store.Close()
|
||||
ctx := context.Background()
|
||||
|
||||
for i := 0; i < 100; i++ {
|
||||
_ = store.AddMessage(ctx, "bench", "user", "message content")
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_, _ = store.GetHistory(ctx, "bench")
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkGetHistory_1000(b *testing.B) {
|
||||
dir := b.TempDir()
|
||||
store, err := NewJSONLStore(dir)
|
||||
if err != nil {
|
||||
b.Fatalf("NewJSONLStore: %v", err)
|
||||
}
|
||||
defer store.Close()
|
||||
ctx := context.Background()
|
||||
|
||||
for i := 0; i < 1000; i++ {
|
||||
_ = store.AddMessage(ctx, "bench", "user", "message content")
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_, _ = store.GetHistory(ctx, "bench")
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,108 @@
|
||||
package memory
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/providers"
|
||||
)
|
||||
|
||||
// jsonSession mirrors pkg/session.Session for migration purposes.
|
||||
type jsonSession struct {
|
||||
Key string `json:"key"`
|
||||
Messages []providers.Message `json:"messages"`
|
||||
Summary string `json:"summary,omitempty"`
|
||||
Created time.Time `json:"created"`
|
||||
Updated time.Time `json:"updated"`
|
||||
}
|
||||
|
||||
// MigrateFromJSON reads legacy sessions/*.json files from sessionsDir,
|
||||
// writes them into the Store, and renames each migrated file to
|
||||
// .json.migrated as a backup. Returns the number of sessions migrated.
|
||||
//
|
||||
// Files that fail to parse are logged and skipped. Already-migrated
|
||||
// files (.json.migrated) are ignored, making the function idempotent.
|
||||
func MigrateFromJSON(
|
||||
ctx context.Context, sessionsDir string, store Store,
|
||||
) (int, error) {
|
||||
entries, err := os.ReadDir(sessionsDir)
|
||||
if os.IsNotExist(err) {
|
||||
return 0, nil
|
||||
}
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("memory: read sessions dir: %w", err)
|
||||
}
|
||||
|
||||
migrated := 0
|
||||
for _, entry := range entries {
|
||||
if entry.IsDir() {
|
||||
continue
|
||||
}
|
||||
name := entry.Name()
|
||||
if !strings.HasSuffix(name, ".json") {
|
||||
continue
|
||||
}
|
||||
// Skip already-migrated files.
|
||||
if strings.HasSuffix(name, ".migrated") {
|
||||
continue
|
||||
}
|
||||
|
||||
srcPath := filepath.Join(sessionsDir, name)
|
||||
|
||||
data, readErr := os.ReadFile(srcPath)
|
||||
if readErr != nil {
|
||||
log.Printf("memory: migrate: skip %s: %v", name, readErr)
|
||||
continue
|
||||
}
|
||||
|
||||
var sess jsonSession
|
||||
if parseErr := json.Unmarshal(data, &sess); parseErr != nil {
|
||||
log.Printf("memory: migrate: skip %s: %v", name, parseErr)
|
||||
continue
|
||||
}
|
||||
|
||||
// Use the key from the JSON content, not the filename.
|
||||
// Filenames are sanitized (":" → "_") but keys are not.
|
||||
key := sess.Key
|
||||
if key == "" {
|
||||
key = strings.TrimSuffix(name, ".json")
|
||||
}
|
||||
|
||||
// Use SetHistory (atomic replace) instead of per-message
|
||||
// AddFullMessage. This makes migration idempotent: if the
|
||||
// process crashes after writing messages but before the
|
||||
// rename below, a retry replaces the partial data cleanly
|
||||
// instead of duplicating messages.
|
||||
if setErr := store.SetHistory(ctx, key, sess.Messages); setErr != nil {
|
||||
return migrated, fmt.Errorf(
|
||||
"memory: migrate %s: set history: %w",
|
||||
name, setErr,
|
||||
)
|
||||
}
|
||||
|
||||
if sess.Summary != "" {
|
||||
if sumErr := store.SetSummary(ctx, key, sess.Summary); sumErr != nil {
|
||||
return migrated, fmt.Errorf(
|
||||
"memory: migrate %s: set summary: %w",
|
||||
name, sumErr,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
// Rename to .migrated as backup (not delete).
|
||||
renameErr := os.Rename(srcPath, srcPath+".migrated")
|
||||
if renameErr != nil {
|
||||
log.Printf("memory: migrate: rename %s: %v", name, renameErr)
|
||||
}
|
||||
|
||||
migrated++
|
||||
}
|
||||
|
||||
return migrated, nil
|
||||
}
|
||||
@@ -0,0 +1,384 @@
|
||||
package memory
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/providers"
|
||||
)
|
||||
|
||||
func writeJSONSession(
|
||||
t *testing.T, dir string, filename string, sess jsonSession,
|
||||
) {
|
||||
t.Helper()
|
||||
data, err := json.MarshalIndent(sess, "", " ")
|
||||
if err != nil {
|
||||
t.Fatalf("marshal session: %v", err)
|
||||
}
|
||||
err = os.WriteFile(filepath.Join(dir, filename), data, 0o644)
|
||||
if err != nil {
|
||||
t.Fatalf("write session file: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMigrateFromJSON_Basic(t *testing.T) {
|
||||
sessionsDir := t.TempDir()
|
||||
store := newTestStore(t)
|
||||
ctx := context.Background()
|
||||
|
||||
writeJSONSession(t, sessionsDir, "test.json", jsonSession{
|
||||
Key: "test",
|
||||
Messages: []providers.Message{
|
||||
{Role: "user", Content: "hello"},
|
||||
{Role: "assistant", Content: "hi"},
|
||||
},
|
||||
Summary: "A greeting.",
|
||||
Created: time.Now(),
|
||||
Updated: time.Now(),
|
||||
})
|
||||
|
||||
count, err := MigrateFromJSON(ctx, sessionsDir, store)
|
||||
if err != nil {
|
||||
t.Fatalf("MigrateFromJSON: %v", err)
|
||||
}
|
||||
if count != 1 {
|
||||
t.Errorf("expected 1 migrated, got %d", count)
|
||||
}
|
||||
|
||||
history, err := store.GetHistory(ctx, "test")
|
||||
if err != nil {
|
||||
t.Fatalf("GetHistory: %v", err)
|
||||
}
|
||||
if len(history) != 2 {
|
||||
t.Fatalf("expected 2 messages, got %d", len(history))
|
||||
}
|
||||
if history[0].Content != "hello" || history[1].Content != "hi" {
|
||||
t.Errorf("unexpected messages: %+v", history)
|
||||
}
|
||||
|
||||
summary, err := store.GetSummary(ctx, "test")
|
||||
if err != nil {
|
||||
t.Fatalf("GetSummary: %v", err)
|
||||
}
|
||||
if summary != "A greeting." {
|
||||
t.Errorf("summary = %q", summary)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMigrateFromJSON_WithToolCalls(t *testing.T) {
|
||||
sessionsDir := t.TempDir()
|
||||
store := newTestStore(t)
|
||||
ctx := context.Background()
|
||||
|
||||
writeJSONSession(t, sessionsDir, "tools.json", jsonSession{
|
||||
Key: "tools",
|
||||
Messages: []providers.Message{
|
||||
{
|
||||
Role: "assistant",
|
||||
Content: "Searching...",
|
||||
ToolCalls: []providers.ToolCall{
|
||||
{
|
||||
ID: "call_1",
|
||||
Type: "function",
|
||||
Function: &providers.FunctionCall{
|
||||
Name: "web_search",
|
||||
Arguments: `{"q":"test"}`,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Role: "tool",
|
||||
Content: "result",
|
||||
ToolCallID: "call_1",
|
||||
},
|
||||
},
|
||||
Created: time.Now(),
|
||||
Updated: time.Now(),
|
||||
})
|
||||
|
||||
count, err := MigrateFromJSON(ctx, sessionsDir, store)
|
||||
if err != nil {
|
||||
t.Fatalf("MigrateFromJSON: %v", err)
|
||||
}
|
||||
if count != 1 {
|
||||
t.Errorf("expected 1, got %d", count)
|
||||
}
|
||||
|
||||
history, err := store.GetHistory(ctx, "tools")
|
||||
if err != nil {
|
||||
t.Fatalf("GetHistory: %v", err)
|
||||
}
|
||||
if len(history) != 2 {
|
||||
t.Fatalf("expected 2 messages, got %d", len(history))
|
||||
}
|
||||
if len(history[0].ToolCalls) != 1 {
|
||||
t.Fatalf("expected 1 tool call, got %d", len(history[0].ToolCalls))
|
||||
}
|
||||
if history[0].ToolCalls[0].Function.Name != "web_search" {
|
||||
t.Errorf("function = %q", history[0].ToolCalls[0].Function.Name)
|
||||
}
|
||||
if history[1].ToolCallID != "call_1" {
|
||||
t.Errorf("ToolCallID = %q", history[1].ToolCallID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMigrateFromJSON_MultipleFiles(t *testing.T) {
|
||||
sessionsDir := t.TempDir()
|
||||
store := newTestStore(t)
|
||||
ctx := context.Background()
|
||||
|
||||
for i := 0; i < 3; i++ {
|
||||
key := string(rune('a' + i))
|
||||
writeJSONSession(t, sessionsDir, key+".json", jsonSession{
|
||||
Key: key,
|
||||
Messages: []providers.Message{{Role: "user", Content: "msg " + key}},
|
||||
Created: time.Now(),
|
||||
Updated: time.Now(),
|
||||
})
|
||||
}
|
||||
|
||||
count, err := MigrateFromJSON(ctx, sessionsDir, store)
|
||||
if err != nil {
|
||||
t.Fatalf("MigrateFromJSON: %v", err)
|
||||
}
|
||||
if count != 3 {
|
||||
t.Errorf("expected 3, got %d", count)
|
||||
}
|
||||
|
||||
for i := 0; i < 3; i++ {
|
||||
key := string(rune('a' + i))
|
||||
history, histErr := store.GetHistory(ctx, key)
|
||||
if histErr != nil {
|
||||
t.Fatalf("GetHistory(%q): %v", key, histErr)
|
||||
}
|
||||
if len(history) != 1 {
|
||||
t.Errorf("session %q: expected 1 msg, got %d", key, len(history))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestMigrateFromJSON_InvalidJSON(t *testing.T) {
|
||||
sessionsDir := t.TempDir()
|
||||
store := newTestStore(t)
|
||||
ctx := context.Background()
|
||||
|
||||
// One valid, one invalid.
|
||||
writeJSONSession(t, sessionsDir, "good.json", jsonSession{
|
||||
Key: "good",
|
||||
Messages: []providers.Message{{Role: "user", Content: "ok"}},
|
||||
Created: time.Now(),
|
||||
Updated: time.Now(),
|
||||
})
|
||||
err := os.WriteFile(
|
||||
filepath.Join(sessionsDir, "bad.json"),
|
||||
[]byte("{invalid json"),
|
||||
0o644,
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("write bad file: %v", err)
|
||||
}
|
||||
|
||||
count, err := MigrateFromJSON(ctx, sessionsDir, store)
|
||||
if err != nil {
|
||||
t.Fatalf("MigrateFromJSON: %v", err)
|
||||
}
|
||||
if count != 1 {
|
||||
t.Errorf("expected 1 (bad file skipped), got %d", count)
|
||||
}
|
||||
|
||||
history, err := store.GetHistory(ctx, "good")
|
||||
if err != nil {
|
||||
t.Fatalf("GetHistory: %v", err)
|
||||
}
|
||||
if len(history) != 1 {
|
||||
t.Errorf("expected 1 message, got %d", len(history))
|
||||
}
|
||||
}
|
||||
|
||||
func TestMigrateFromJSON_RenamesFiles(t *testing.T) {
|
||||
sessionsDir := t.TempDir()
|
||||
store := newTestStore(t)
|
||||
ctx := context.Background()
|
||||
|
||||
writeJSONSession(t, sessionsDir, "rename.json", jsonSession{
|
||||
Key: "rename",
|
||||
Messages: []providers.Message{{Role: "user", Content: "hi"}},
|
||||
Created: time.Now(),
|
||||
Updated: time.Now(),
|
||||
})
|
||||
|
||||
_, err := MigrateFromJSON(ctx, sessionsDir, store)
|
||||
if err != nil {
|
||||
t.Fatalf("MigrateFromJSON: %v", err)
|
||||
}
|
||||
|
||||
// Original .json should not exist.
|
||||
_, statErr := os.Stat(filepath.Join(sessionsDir, "rename.json"))
|
||||
if !os.IsNotExist(statErr) {
|
||||
t.Error("rename.json should have been renamed")
|
||||
}
|
||||
// .json.migrated should exist.
|
||||
_, statErr = os.Stat(
|
||||
filepath.Join(sessionsDir, "rename.json.migrated"),
|
||||
)
|
||||
if statErr != nil {
|
||||
t.Errorf("rename.json.migrated should exist: %v", statErr)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMigrateFromJSON_Idempotent(t *testing.T) {
|
||||
sessionsDir := t.TempDir()
|
||||
store := newTestStore(t)
|
||||
ctx := context.Background()
|
||||
|
||||
writeJSONSession(t, sessionsDir, "idem.json", jsonSession{
|
||||
Key: "idem",
|
||||
Messages: []providers.Message{{Role: "user", Content: "once"}},
|
||||
Created: time.Now(),
|
||||
Updated: time.Now(),
|
||||
})
|
||||
|
||||
count1, err := MigrateFromJSON(ctx, sessionsDir, store)
|
||||
if err != nil {
|
||||
t.Fatalf("first migration: %v", err)
|
||||
}
|
||||
if count1 != 1 {
|
||||
t.Errorf("first run: expected 1, got %d", count1)
|
||||
}
|
||||
|
||||
// Second run should find only .migrated files, skip them.
|
||||
count2, err := MigrateFromJSON(ctx, sessionsDir, store)
|
||||
if err != nil {
|
||||
t.Fatalf("second migration: %v", err)
|
||||
}
|
||||
if count2 != 0 {
|
||||
t.Errorf("second run: expected 0, got %d", count2)
|
||||
}
|
||||
|
||||
history, err := store.GetHistory(ctx, "idem")
|
||||
if err != nil {
|
||||
t.Fatalf("GetHistory: %v", err)
|
||||
}
|
||||
if len(history) != 1 {
|
||||
t.Errorf("expected 1 message, got %d", len(history))
|
||||
}
|
||||
}
|
||||
|
||||
func TestMigrateFromJSON_ColonInKey(t *testing.T) {
|
||||
sessionsDir := t.TempDir()
|
||||
store := newTestStore(t)
|
||||
ctx := context.Background()
|
||||
|
||||
// File is named telegram_123 (sanitized), but the key inside is telegram:123.
|
||||
writeJSONSession(t, sessionsDir, "telegram_123.json", jsonSession{
|
||||
Key: "telegram:123",
|
||||
Messages: []providers.Message{{Role: "user", Content: "from telegram"}},
|
||||
Created: time.Now(),
|
||||
Updated: time.Now(),
|
||||
})
|
||||
|
||||
count, err := MigrateFromJSON(ctx, sessionsDir, store)
|
||||
if err != nil {
|
||||
t.Fatalf("MigrateFromJSON: %v", err)
|
||||
}
|
||||
if count != 1 {
|
||||
t.Errorf("expected 1, got %d", count)
|
||||
}
|
||||
|
||||
// Accessible via the original key "telegram:123".
|
||||
history, err := store.GetHistory(ctx, "telegram:123")
|
||||
if err != nil {
|
||||
t.Fatalf("GetHistory: %v", err)
|
||||
}
|
||||
if len(history) != 1 {
|
||||
t.Fatalf("expected 1 message, got %d", len(history))
|
||||
}
|
||||
if history[0].Content != "from telegram" {
|
||||
t.Errorf("content = %q", history[0].Content)
|
||||
}
|
||||
|
||||
// In the file-based store, "telegram:123" and "telegram_123" both
|
||||
// sanitize to the same filename, so they share storage. This is
|
||||
// expected — the colon-to-underscore mapping is a one-way function.
|
||||
history2, err := store.GetHistory(ctx, "telegram_123")
|
||||
if err != nil {
|
||||
t.Fatalf("GetHistory: %v", err)
|
||||
}
|
||||
if len(history2) != 1 {
|
||||
t.Errorf("expected 1 (same file), got %d", len(history2))
|
||||
}
|
||||
}
|
||||
|
||||
func TestMigrateFromJSON_RetryAfterCrash(t *testing.T) {
|
||||
// Simulates a crash during migration: first run writes messages
|
||||
// but doesn't rename the .json file. Second run must replace
|
||||
// (not duplicate) the messages thanks to SetHistory semantics.
|
||||
sessionsDir := t.TempDir()
|
||||
store := newTestStore(t)
|
||||
ctx := context.Background()
|
||||
|
||||
writeJSONSession(t, sessionsDir, "retry.json", jsonSession{
|
||||
Key: "retry",
|
||||
Messages: []providers.Message{
|
||||
{Role: "user", Content: "one"},
|
||||
{Role: "assistant", Content: "two"},
|
||||
},
|
||||
Created: time.Now(),
|
||||
Updated: time.Now(),
|
||||
})
|
||||
|
||||
// First migration succeeds — writes messages and renames file.
|
||||
count, err := MigrateFromJSON(ctx, sessionsDir, store)
|
||||
if err != nil {
|
||||
t.Fatalf("first migration: %v", err)
|
||||
}
|
||||
if count != 1 {
|
||||
t.Fatalf("expected 1, got %d", count)
|
||||
}
|
||||
|
||||
// Simulate "crash before rename": restore the .json file.
|
||||
src := filepath.Join(sessionsDir, "retry.json.migrated")
|
||||
dst := filepath.Join(sessionsDir, "retry.json")
|
||||
if renameErr := os.Rename(src, dst); renameErr != nil {
|
||||
t.Fatalf("restore .json: %v", renameErr)
|
||||
}
|
||||
|
||||
// Second migration should re-import without duplicating messages.
|
||||
count, err = MigrateFromJSON(ctx, sessionsDir, store)
|
||||
if err != nil {
|
||||
t.Fatalf("second migration: %v", err)
|
||||
}
|
||||
if count != 1 {
|
||||
t.Fatalf("expected 1, got %d", count)
|
||||
}
|
||||
|
||||
history, err := store.GetHistory(ctx, "retry")
|
||||
if err != nil {
|
||||
t.Fatalf("GetHistory: %v", err)
|
||||
}
|
||||
// Must be exactly 2 messages (not 4 from duplication).
|
||||
if len(history) != 2 {
|
||||
t.Fatalf("expected 2 messages (no duplicates), got %d", len(history))
|
||||
}
|
||||
if history[0].Content != "one" || history[1].Content != "two" {
|
||||
t.Errorf("unexpected messages: %+v", history)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMigrateFromJSON_NonexistentDir(t *testing.T) {
|
||||
store := newTestStore(t)
|
||||
ctx := context.Background()
|
||||
|
||||
count, err := MigrateFromJSON(ctx, "/nonexistent/path", store)
|
||||
if err != nil {
|
||||
t.Fatalf("MigrateFromJSON: %v", err)
|
||||
}
|
||||
if count != 0 {
|
||||
t.Errorf("expected 0, got %d", count)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,42 @@
|
||||
package memory
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/providers"
|
||||
)
|
||||
|
||||
// Store defines an interface for persistent session storage.
|
||||
// Each method is an atomic operation — there is no separate Save() call.
|
||||
type Store interface {
|
||||
// AddMessage appends a simple text message to a session.
|
||||
AddMessage(ctx context.Context, sessionKey, role, content string) error
|
||||
|
||||
// AddFullMessage appends a complete message (with tool calls, etc.) to a session.
|
||||
AddFullMessage(ctx context.Context, sessionKey string, msg providers.Message) error
|
||||
|
||||
// GetHistory returns all messages for a session in insertion order.
|
||||
// Returns an empty slice (not nil) if the session does not exist.
|
||||
GetHistory(ctx context.Context, sessionKey string) ([]providers.Message, error)
|
||||
|
||||
// GetSummary returns the conversation summary for a session.
|
||||
// Returns an empty string if no summary exists.
|
||||
GetSummary(ctx context.Context, sessionKey string) (string, error)
|
||||
|
||||
// SetSummary updates the conversation summary for a session.
|
||||
SetSummary(ctx context.Context, sessionKey, summary string) error
|
||||
|
||||
// TruncateHistory removes all but the last keepLast messages from a session.
|
||||
// If keepLast <= 0, all messages are removed.
|
||||
TruncateHistory(ctx context.Context, sessionKey string, keepLast int) error
|
||||
|
||||
// SetHistory replaces all messages in a session with the provided history.
|
||||
SetHistory(ctx context.Context, sessionKey string, history []providers.Message) error
|
||||
|
||||
// Compact reclaims storage by physically removing logically truncated
|
||||
// data. Backends that do not accumulate dead data may return nil.
|
||||
Compact(ctx context.Context, sessionKey string) error
|
||||
|
||||
// Close releases any resources held by the store.
|
||||
Close() error
|
||||
}
|
||||
@@ -190,28 +190,6 @@ func resolveProviderSelection(cfg *config.Config) (providerSelection, error) {
|
||||
sel.apiBase = "https://api.mistral.ai/v1"
|
||||
}
|
||||
}
|
||||
case "opencode":
|
||||
if cfg.Providers.Opencode.APIKey != "" || cfg.Providers.Opencode.APIBase != "" {
|
||||
sel.apiKey = cfg.Providers.Opencode.APIKey
|
||||
sel.apiBase = cfg.Providers.Opencode.APIBase
|
||||
sel.proxy = cfg.Providers.Opencode.Proxy
|
||||
if sel.apiBase == "" {
|
||||
sel.apiBase = "https://opencode.ai/zen/v1"
|
||||
}
|
||||
}
|
||||
case "kimi", "kimi-code", "moonshot":
|
||||
if cfg.Providers.Moonshot.APIKey != "" {
|
||||
sel.apiKey = cfg.Providers.Moonshot.APIKey
|
||||
sel.apiBase = cfg.Providers.Moonshot.APIBase
|
||||
sel.proxy = cfg.Providers.Moonshot.Proxy
|
||||
if sel.apiBase == "" {
|
||||
if providerName == "moonshot" {
|
||||
sel.apiBase = "https://api.moonshot.cn/v1"
|
||||
} else {
|
||||
sel.apiBase = "https://api.kimi.com/coding/v1"
|
||||
}
|
||||
}
|
||||
}
|
||||
case "github_copilot", "copilot":
|
||||
sel.providerType = providerTypeGitHubCopilot
|
||||
if cfg.Providers.GitHubCopilot.APIBase != "" {
|
||||
@@ -232,11 +210,7 @@ func resolveProviderSelection(cfg *config.Config) (providerSelection, error) {
|
||||
sel.apiBase = cfg.Providers.Moonshot.APIBase
|
||||
sel.proxy = cfg.Providers.Moonshot.Proxy
|
||||
if sel.apiBase == "" {
|
||||
if strings.Contains(lowerModel, "moonshot") || strings.HasPrefix(model, "moonshot/") {
|
||||
sel.apiBase = "https://api.moonshot.cn/v1"
|
||||
} else {
|
||||
sel.apiBase = "https://api.kimi.com/coding/v1"
|
||||
}
|
||||
sel.apiBase = "https://api.moonshot.cn/v1"
|
||||
}
|
||||
case strings.HasPrefix(model, "openrouter/") ||
|
||||
strings.HasPrefix(model, "anthropic/") ||
|
||||
|
||||
@@ -94,7 +94,7 @@ func CreateProviderFromConfig(cfg *config.ModelConfig) (LLMProvider, string, err
|
||||
|
||||
case "litellm", "openrouter", "groq", "zhipu", "gemini", "nvidia",
|
||||
"ollama", "moonshot", "shengsuanyun", "deepseek", "cerebras",
|
||||
"volcengine", "vllm", "qwen", "mistral", "opencode":
|
||||
"volcengine", "vllm", "qwen", "mistral":
|
||||
// All other OpenAI-compatible HTTP providers
|
||||
if cfg.APIKey == "" && cfg.APIBase == "" {
|
||||
return nil, "", fmt.Errorf("api_key or api_base is required for HTTP-based protocol %q", protocol)
|
||||
@@ -208,8 +208,6 @@ func getDefaultAPIBase(protocol string) string {
|
||||
return "http://localhost:8000/v1"
|
||||
case "mistral":
|
||||
return "https://api.mistral.ai/v1"
|
||||
case "opencode":
|
||||
return "https://opencode.ai/zen/v1"
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
|
||||
@@ -112,7 +112,6 @@ func TestCreateProviderFromConfig_DefaultAPIBase(t *testing.T) {
|
||||
{"vllm", "vllm"},
|
||||
{"deepseek", "deepseek"},
|
||||
{"ollama", "ollama"},
|
||||
{"opencode", "opencode"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
|
||||
@@ -33,7 +33,6 @@ type Provider struct {
|
||||
apiBase string
|
||||
maxTokensField string // Field name for max tokens (e.g., "max_completion_tokens" for o1/glm models)
|
||||
httpClient *http.Client
|
||||
isKimiAPI bool // true when apiBase points to api.kimi.com
|
||||
}
|
||||
|
||||
type Option func(*Provider)
|
||||
@@ -70,17 +69,10 @@ func NewProvider(apiKey, apiBase, proxy string, opts ...Option) *Provider {
|
||||
}
|
||||
}
|
||||
|
||||
trimmedBase := strings.TrimRight(apiBase, "/")
|
||||
var isKimi bool
|
||||
if parsed, err := url.Parse(trimmedBase); err == nil {
|
||||
isKimi = parsed.Hostname() == "api.kimi.com"
|
||||
}
|
||||
|
||||
p := &Provider{
|
||||
apiKey: apiKey,
|
||||
apiBase: trimmedBase,
|
||||
apiBase: strings.TrimRight(apiBase, "/"),
|
||||
httpClient: client,
|
||||
isKimiAPI: isKimi,
|
||||
}
|
||||
|
||||
for _, opt := range opts {
|
||||
@@ -184,12 +176,6 @@ func (p *Provider) Chat(
|
||||
if p.apiKey != "" {
|
||||
req.Header.Set("Authorization", "Bearer "+p.apiKey)
|
||||
}
|
||||
// Kimi Code API rejects requests without a recognized coding-agent
|
||||
// User-Agent. "KimiCLI/0.77" is the minimum version string accepted
|
||||
// by the api.kimi.com/coding/v1 endpoint (per Kimi's API docs).
|
||||
if p.isKimiAPI {
|
||||
req.Header.Set("User-Agent", "KimiCLI/0.77")
|
||||
}
|
||||
|
||||
resp, err := p.httpClient.Do(req)
|
||||
if err != nil {
|
||||
|
||||
@@ -2,7 +2,6 @@ package openai_compat
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
@@ -421,82 +420,6 @@ func TestProvider_FunctionalOptionRequestTimeoutNonPositive(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// roundTripFunc adapts a function to http.RoundTripper for test injection.
|
||||
type roundTripFunc func(*http.Request) (*http.Response, error)
|
||||
|
||||
func (f roundTripFunc) RoundTrip(r *http.Request) (*http.Response, error) {
|
||||
return f(r)
|
||||
}
|
||||
|
||||
func TestProviderChat_KimiCodeUserAgent(t *testing.T) {
|
||||
okBody := `{"choices":[{"message":{"content":"ok"},"finish_reason":"stop"}]}`
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
apiBase string
|
||||
wantAgent string
|
||||
}{
|
||||
{
|
||||
name: "sets KimiCLI User-Agent for api.kimi.com",
|
||||
apiBase: "https://api.kimi.com/coding/v1",
|
||||
wantAgent: "KimiCLI/0.77",
|
||||
},
|
||||
{
|
||||
name: "does not set KimiCLI User-Agent for other hosts",
|
||||
apiBase: "https://api.example.com/v1",
|
||||
wantAgent: "",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
var gotUserAgent string
|
||||
|
||||
p := NewProvider("key", tt.apiBase, "")
|
||||
p.httpClient.Transport = roundTripFunc(
|
||||
func(r *http.Request) (*http.Response, error) {
|
||||
gotUserAgent = r.Header.Get("User-Agent")
|
||||
return &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Body: io.NopCloser(
|
||||
strings.NewReader(okBody),
|
||||
),
|
||||
Header: http.Header{
|
||||
"Content-Type": {"application/json"},
|
||||
},
|
||||
}, nil
|
||||
},
|
||||
)
|
||||
|
||||
_, err := p.Chat(
|
||||
t.Context(),
|
||||
[]Message{{Role: "user", Content: "hi"}},
|
||||
nil,
|
||||
"kimi-k2.5",
|
||||
nil,
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("Chat() error = %v", err)
|
||||
}
|
||||
|
||||
if tt.wantAgent != "" {
|
||||
if gotUserAgent != tt.wantAgent {
|
||||
t.Fatalf(
|
||||
"User-Agent = %q, want %q",
|
||||
gotUserAgent, tt.wantAgent,
|
||||
)
|
||||
}
|
||||
} else {
|
||||
if gotUserAgent == "KimiCLI/0.77" {
|
||||
t.Fatalf(
|
||||
"User-Agent should not be KimiCLI/0.77 for non-kimi host",
|
||||
)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSerializeMessages_PlainText(t *testing.T) {
|
||||
messages := []protocoltypes.Message{
|
||||
{Role: "user", Content: "hello"},
|
||||
|
||||
@@ -3,6 +3,7 @@ package tools
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync/atomic"
|
||||
)
|
||||
|
||||
type SendCallback func(channel, chatID, content string) error
|
||||
@@ -11,7 +12,7 @@ type MessageTool struct {
|
||||
sendCallback SendCallback
|
||||
defaultChannel string
|
||||
defaultChatID string
|
||||
sentInRound bool // Tracks whether a message was sent in the current processing round
|
||||
sentInRound atomic.Bool // Tracks whether a message was sent in the current processing round
|
||||
}
|
||||
|
||||
func NewMessageTool() *MessageTool {
|
||||
@@ -50,12 +51,12 @@ func (t *MessageTool) Parameters() map[string]any {
|
||||
func (t *MessageTool) SetContext(channel, chatID string) {
|
||||
t.defaultChannel = channel
|
||||
t.defaultChatID = chatID
|
||||
t.sentInRound = false // Reset send tracking for new processing round
|
||||
t.sentInRound.Store(false) // Reset send tracking for new processing round
|
||||
}
|
||||
|
||||
// HasSentInRound returns true if the message tool sent a message during the current round.
|
||||
func (t *MessageTool) HasSentInRound() bool {
|
||||
return t.sentInRound
|
||||
return t.sentInRound.Load()
|
||||
}
|
||||
|
||||
func (t *MessageTool) SetSendCallback(callback SendCallback) {
|
||||
@@ -94,7 +95,7 @@ func (t *MessageTool) Execute(ctx context.Context, args map[string]any) *ToolRes
|
||||
}
|
||||
}
|
||||
|
||||
t.sentInRound = true
|
||||
t.sentInRound.Store(true)
|
||||
// Silent: user already received the message directly
|
||||
return &ToolResult{
|
||||
ForLLM: fmt.Sprintf("Message sent to %s:%s", channel, chatID),
|
||||
|
||||
+2
-2
@@ -30,9 +30,9 @@ var (
|
||||
regexp.MustCompile(`\brm\s+-[rf]{1,2}\b`),
|
||||
regexp.MustCompile(`\bdel\s+/[fq]\b`),
|
||||
regexp.MustCompile(`\brmdir\s+/s\b`),
|
||||
// Match disk wiping commands, avoid matching --format flags
|
||||
// Match disk wiping commands (must be followed by space/args)
|
||||
regexp.MustCompile(
|
||||
`(?:^|[;&|]\s*|\s+)(format|mkfs|diskpart)\s`,
|
||||
`\b(format|mkfs|diskpart)\b\s`,
|
||||
),
|
||||
regexp.MustCompile(`\bdd\s+if=`),
|
||||
// Block writes to block devices (all common naming schemes).
|
||||
|
||||
@@ -366,56 +366,6 @@ func TestShellTool_BlockDevices(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// TestShellTool_DenyPattern_DiskWiping verifies the deny pattern for disk wiping
|
||||
// commands (format, mkfs, diskpart) blocks them when preceded by shell separators
|
||||
// but does NOT block legitimate uses like --format flags.
|
||||
func TestShellTool_DenyPattern_DiskWiping(t *testing.T) {
|
||||
tool, err := NewExecTool("", false)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to configure exec tool: %s", err)
|
||||
}
|
||||
|
||||
// These should be BLOCKED (disk wiping commands)
|
||||
blockedCmds := []struct {
|
||||
name string
|
||||
cmd string
|
||||
}{
|
||||
{"format with space", "format c:"},
|
||||
{"mkfs standalone", "mkfs /dev/sda"},
|
||||
{"semicolon format", "echo hello; format c:"},
|
||||
{"pipe format", "echo hello | format c:"},
|
||||
{"and format", "echo hello && format c:"},
|
||||
{"diskpart standalone", "diskpart /s script.txt"},
|
||||
}
|
||||
|
||||
for _, tt := range blockedCmds {
|
||||
t.Run("blocked_"+tt.name, func(t *testing.T) {
|
||||
msg := tool.guardCommand(tt.cmd, "")
|
||||
if !strings.Contains(msg, "blocked") {
|
||||
t.Errorf("Expected %q to be blocked by safety guard, got: %q", tt.cmd, msg)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// These should be ALLOWED (not disk wiping)
|
||||
allowed := []struct {
|
||||
name string
|
||||
cmd string
|
||||
}{
|
||||
{"--format flag", "echo test --format json"},
|
||||
{"go fmt", "echo go fmt ./..."},
|
||||
}
|
||||
|
||||
for _, tt := range allowed {
|
||||
t.Run("allowed_"+tt.name, func(t *testing.T) {
|
||||
msg := tool.guardCommand(tt.cmd, "")
|
||||
if msg != "" {
|
||||
t.Errorf("Expected %q to be allowed, but it was blocked: %s", tt.cmd, msg)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestShellTool_SafePathsInWorkspaceRestriction verifies that safe kernel pseudo-devices
|
||||
// are allowed even when workspace restriction is active.
|
||||
func TestShellTool_SafePathsInWorkspaceRestriction(t *testing.T) {
|
||||
|
||||
+43
-26
@@ -10,6 +10,7 @@ import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"sync"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/logger"
|
||||
"github.com/sipeed/picoclaw/pkg/providers"
|
||||
@@ -121,37 +122,53 @@ func RunToolLoop(
|
||||
}
|
||||
messages = append(messages, assistantMsg)
|
||||
|
||||
// 7. Execute tool calls
|
||||
for _, tc := range normalizedToolCalls {
|
||||
argsJSON, _ := json.Marshal(tc.Arguments)
|
||||
argsPreview := utils.Truncate(string(argsJSON), 200)
|
||||
logger.InfoCF("toolloop", fmt.Sprintf("Tool call: %s(%s)", tc.Name, argsPreview),
|
||||
map[string]any{
|
||||
"tool": tc.Name,
|
||||
"iteration": iteration,
|
||||
})
|
||||
// 7. Execute tool calls in parallel
|
||||
type indexedResult struct {
|
||||
result *ToolResult
|
||||
tc providers.ToolCall
|
||||
}
|
||||
|
||||
// Execute tool (no async callback for subagents - they run independently)
|
||||
var toolResult *ToolResult
|
||||
if config.Tools != nil {
|
||||
toolResult = config.Tools.ExecuteWithContext(ctx, tc.Name, tc.Arguments, channel, chatID, nil)
|
||||
} else {
|
||||
toolResult = ErrorResult("No tools available")
|
||||
results := make([]indexedResult, len(normalizedToolCalls))
|
||||
var wg sync.WaitGroup
|
||||
|
||||
for i, tc := range normalizedToolCalls {
|
||||
results[i].tc = tc
|
||||
|
||||
wg.Add(1)
|
||||
go func(idx int, tc providers.ToolCall) {
|
||||
defer wg.Done()
|
||||
|
||||
argsJSON, _ := json.Marshal(tc.Arguments)
|
||||
argsPreview := utils.Truncate(string(argsJSON), 200)
|
||||
logger.InfoCF("toolloop", fmt.Sprintf("Tool call: %s(%s)", tc.Name, argsPreview),
|
||||
map[string]any{
|
||||
"tool": tc.Name,
|
||||
"iteration": iteration,
|
||||
})
|
||||
|
||||
var toolResult *ToolResult
|
||||
if config.Tools != nil {
|
||||
toolResult = config.Tools.ExecuteWithContext(ctx, tc.Name, tc.Arguments, channel, chatID, nil)
|
||||
} else {
|
||||
toolResult = ErrorResult("No tools available")
|
||||
}
|
||||
results[idx].result = toolResult
|
||||
}(i, tc)
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
// Append results in original order
|
||||
for _, r := range results {
|
||||
contentForLLM := r.result.ForLLM
|
||||
if contentForLLM == "" && r.result.Err != nil {
|
||||
contentForLLM = r.result.Err.Error()
|
||||
}
|
||||
|
||||
// Determine content for LLM
|
||||
contentForLLM := toolResult.ForLLM
|
||||
if contentForLLM == "" && toolResult.Err != nil {
|
||||
contentForLLM = toolResult.Err.Error()
|
||||
}
|
||||
|
||||
// Add tool result message
|
||||
toolResultMsg := providers.Message{
|
||||
messages = append(messages, providers.Message{
|
||||
Role: "tool",
|
||||
Content: contentForLLM,
|
||||
ToolCallID: tc.ID,
|
||||
}
|
||||
messages = append(messages, toolResultMsg)
|
||||
ToolCallID: r.tc.ID,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
+107
-87
@@ -395,6 +395,88 @@ func (p *PerplexitySearchProvider) Search(ctx context.Context, query string, cou
|
||||
return fmt.Sprintf("Results for: %s (via Perplexity)\n%s", query, searchResp.Choices[0].Message.Content), nil
|
||||
}
|
||||
|
||||
type GLMSearchProvider struct {
|
||||
apiKey string
|
||||
baseURL string
|
||||
searchEngine string
|
||||
proxy string
|
||||
client *http.Client
|
||||
}
|
||||
|
||||
func (p *GLMSearchProvider) Search(ctx context.Context, query string, count int) (string, error) {
|
||||
searchURL := p.baseURL
|
||||
if searchURL == "" {
|
||||
searchURL = "https://open.bigmodel.cn/api/paas/v4/web_search"
|
||||
}
|
||||
|
||||
payload := map[string]any{
|
||||
"search_query": query,
|
||||
"search_engine": p.searchEngine,
|
||||
"search_intent": false,
|
||||
"count": count,
|
||||
"content_size": "medium",
|
||||
}
|
||||
|
||||
bodyBytes, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to marshal payload: %w", err)
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", searchURL, bytes.NewReader(bodyBytes))
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", "Bearer "+p.apiKey)
|
||||
|
||||
resp, err := p.client.Do(req)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("request failed: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(io.LimitReader(resp.Body, 1<<20))
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to read response: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return "", fmt.Errorf("GLM Search API error (status %d): %s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
var searchResp struct {
|
||||
SearchResult []struct {
|
||||
Title string `json:"title"`
|
||||
Content string `json:"content"`
|
||||
Link string `json:"link"`
|
||||
} `json:"search_result"`
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(body, &searchResp); err != nil {
|
||||
return "", fmt.Errorf("failed to parse response: %w", err)
|
||||
}
|
||||
|
||||
results := searchResp.SearchResult
|
||||
if len(results) == 0 {
|
||||
return fmt.Sprintf("No results for: %s", query), nil
|
||||
}
|
||||
|
||||
var lines []string
|
||||
lines = append(lines, fmt.Sprintf("Results for: %s (via GLM Search)", query))
|
||||
for i, item := range results {
|
||||
if i >= count {
|
||||
break
|
||||
}
|
||||
lines = append(lines, fmt.Sprintf("%d. %s\n %s", i+1, item.Title, item.Link))
|
||||
if item.Content != "" {
|
||||
lines = append(lines, fmt.Sprintf(" %s", item.Content))
|
||||
}
|
||||
}
|
||||
|
||||
return strings.Join(lines, "\n"), nil
|
||||
}
|
||||
|
||||
type WebSearchTool struct {
|
||||
provider SearchProvider
|
||||
maxResults int
|
||||
@@ -413,9 +495,11 @@ type WebSearchToolOptions struct {
|
||||
PerplexityAPIKey string
|
||||
PerplexityMaxResults int
|
||||
PerplexityEnabled bool
|
||||
ExaAPIKey string
|
||||
ExaMaxResults int
|
||||
ExaEnabled bool
|
||||
GLMSearchAPIKey string
|
||||
GLMSearchBaseURL string
|
||||
GLMSearchEngine string
|
||||
GLMSearchMaxResults int
|
||||
GLMSearchEnabled bool
|
||||
Proxy string
|
||||
}
|
||||
|
||||
@@ -423,7 +507,7 @@ func NewWebSearchTool(opts WebSearchToolOptions) (*WebSearchTool, error) {
|
||||
var provider SearchProvider
|
||||
maxResults := 5
|
||||
|
||||
// Priority: Perplexity > Exa > Brave > Tavily > DuckDuckGo
|
||||
// Priority: Perplexity > Brave > Tavily > DuckDuckGo > GLM Search
|
||||
if opts.PerplexityEnabled && opts.PerplexityAPIKey != "" {
|
||||
client, err := createHTTPClient(opts.Proxy, perplexityTimeout)
|
||||
if err != nil {
|
||||
@@ -433,15 +517,6 @@ func NewWebSearchTool(opts WebSearchToolOptions) (*WebSearchTool, error) {
|
||||
if opts.PerplexityMaxResults > 0 {
|
||||
maxResults = opts.PerplexityMaxResults
|
||||
}
|
||||
} else if opts.ExaEnabled && opts.ExaAPIKey != "" {
|
||||
client, err := createHTTPClient(opts.Proxy, searchTimeout)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create HTTP client for Exa: %w", err)
|
||||
}
|
||||
provider = &ExaSearchProvider{apiKey: opts.ExaAPIKey, proxy: opts.Proxy, client: client}
|
||||
if opts.ExaMaxResults > 0 {
|
||||
maxResults = opts.ExaMaxResults
|
||||
}
|
||||
} else if opts.BraveEnabled && opts.BraveAPIKey != "" {
|
||||
client, err := createHTTPClient(opts.Proxy, searchTimeout)
|
||||
if err != nil {
|
||||
@@ -474,6 +549,25 @@ func NewWebSearchTool(opts WebSearchToolOptions) (*WebSearchTool, error) {
|
||||
if opts.DuckDuckGoMaxResults > 0 {
|
||||
maxResults = opts.DuckDuckGoMaxResults
|
||||
}
|
||||
} else if opts.GLMSearchEnabled && opts.GLMSearchAPIKey != "" {
|
||||
client, err := createHTTPClient(opts.Proxy, searchTimeout)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create HTTP client for GLM Search: %w", err)
|
||||
}
|
||||
searchEngine := opts.GLMSearchEngine
|
||||
if searchEngine == "" {
|
||||
searchEngine = "search_std"
|
||||
}
|
||||
provider = &GLMSearchProvider{
|
||||
apiKey: opts.GLMSearchAPIKey,
|
||||
baseURL: opts.GLMSearchBaseURL,
|
||||
searchEngine: searchEngine,
|
||||
proxy: opts.Proxy,
|
||||
client: client,
|
||||
}
|
||||
if opts.GLMSearchMaxResults > 0 {
|
||||
maxResults = opts.GLMSearchMaxResults
|
||||
}
|
||||
} else {
|
||||
return nil, nil
|
||||
}
|
||||
@@ -721,77 +815,3 @@ func (t *WebFetchTool) extractText(htmlContent string) string {
|
||||
|
||||
return strings.Join(cleanLines, "\n")
|
||||
}
|
||||
|
||||
// ExaSearchProvider uses the Exa AI search API (https://exa.ai).
|
||||
type ExaSearchProvider struct {
|
||||
apiKey string
|
||||
proxy string
|
||||
client *http.Client
|
||||
}
|
||||
|
||||
func (p *ExaSearchProvider) Search(ctx context.Context, query string, count int) (string, error) {
|
||||
reqBody := map[string]any{
|
||||
"query": query,
|
||||
"num_results": count,
|
||||
"type": "neural",
|
||||
}
|
||||
jsonData, err := json.Marshal(reqBody)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("exa: marshal error: %w", err)
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", "https://api.exa.ai/search", bytes.NewReader(jsonData))
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("exa: request error: %w", err)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("x-api-key", p.apiKey)
|
||||
|
||||
resp, err := p.client.Do(req)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("exa: search failed: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("exa: read error: %w", err)
|
||||
}
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return "", fmt.Errorf("exa: API error %d: %s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
var result struct {
|
||||
Results []struct {
|
||||
Title string `json:"title"`
|
||||
URL string `json:"url"`
|
||||
Text string `json:"text"`
|
||||
} `json:"results"`
|
||||
}
|
||||
if err := json.Unmarshal(body, &result); err != nil {
|
||||
return "", fmt.Errorf("exa: parse error: %w", err)
|
||||
}
|
||||
|
||||
if len(result.Results) == 0 {
|
||||
return fmt.Sprintf("No results for: %s", query), nil
|
||||
}
|
||||
|
||||
var lines []string
|
||||
lines = append(lines, fmt.Sprintf("Results for: %s (via Exa)", query))
|
||||
maxResults := count
|
||||
if maxResults > len(result.Results) {
|
||||
maxResults = len(result.Results)
|
||||
}
|
||||
for i, r := range result.Results[:maxResults] {
|
||||
lines = append(lines, fmt.Sprintf("%d. %s\n %s", i+1, r.Title, r.URL))
|
||||
if r.Text != "" {
|
||||
snippet := r.Text
|
||||
if len(snippet) > 200 {
|
||||
snippet = snippet[:200] + "..."
|
||||
}
|
||||
lines = append(lines, fmt.Sprintf(" %s", snippet))
|
||||
}
|
||||
}
|
||||
|
||||
return strings.Join(lines, "\n"), nil
|
||||
}
|
||||
|
||||
+105
-189
@@ -5,7 +5,6 @@ import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
@@ -683,86 +682,7 @@ func TestWebTool_TavilySearch_Success(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewWebSearchTool_ExaPriority(t *testing.T) {
|
||||
// Exa should be selected when enabled with API key
|
||||
tool, err := NewWebSearchTool(WebSearchToolOptions{
|
||||
ExaEnabled: true,
|
||||
ExaAPIKey: "exa-key",
|
||||
ExaMaxResults: 3,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("NewWebSearchTool() error: %v", err)
|
||||
}
|
||||
if tool == nil {
|
||||
t.Fatal("Expected non-nil tool when Exa is enabled with API key")
|
||||
}
|
||||
if _, ok := tool.provider.(*ExaSearchProvider); !ok {
|
||||
t.Fatalf("provider type = %T, want *ExaSearchProvider", tool.provider)
|
||||
}
|
||||
if tool.maxResults != 3 {
|
||||
t.Fatalf("maxResults = %d, want 3", tool.maxResults)
|
||||
}
|
||||
|
||||
// Exa enabled but no API key should fall through
|
||||
tool, err = NewWebSearchTool(WebSearchToolOptions{
|
||||
ExaEnabled: true,
|
||||
ExaAPIKey: "",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("NewWebSearchTool() error: %v", err)
|
||||
}
|
||||
if tool != nil {
|
||||
t.Errorf("Expected nil tool when Exa API key is empty and no other provider enabled")
|
||||
}
|
||||
|
||||
// Perplexity should take priority over Exa
|
||||
tool, err = NewWebSearchTool(WebSearchToolOptions{
|
||||
PerplexityEnabled: true,
|
||||
PerplexityAPIKey: "perp-key",
|
||||
ExaEnabled: true,
|
||||
ExaAPIKey: "exa-key",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("NewWebSearchTool() error: %v", err)
|
||||
}
|
||||
if _, ok := tool.provider.(*PerplexitySearchProvider); !ok {
|
||||
t.Fatalf("provider type = %T, want *PerplexitySearchProvider (Perplexity should outrank Exa)", tool.provider)
|
||||
}
|
||||
|
||||
// Exa should take priority over Brave
|
||||
tool, err = NewWebSearchTool(WebSearchToolOptions{
|
||||
ExaEnabled: true,
|
||||
ExaAPIKey: "exa-key",
|
||||
BraveEnabled: true,
|
||||
BraveAPIKey: "brave-key",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("NewWebSearchTool() error: %v", err)
|
||||
}
|
||||
if _, ok := tool.provider.(*ExaSearchProvider); !ok {
|
||||
t.Fatalf("provider type = %T, want *ExaSearchProvider (Exa should outrank Brave)", tool.provider)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewWebSearchTool_ExaProxyPropagation(t *testing.T) {
|
||||
tool, err := NewWebSearchTool(WebSearchToolOptions{
|
||||
ExaEnabled: true,
|
||||
ExaAPIKey: "k",
|
||||
Proxy: "http://127.0.0.1:7890",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("NewWebSearchTool() error: %v", err)
|
||||
}
|
||||
p, ok := tool.provider.(*ExaSearchProvider)
|
||||
if !ok {
|
||||
t.Fatalf("provider type = %T, want *ExaSearchProvider", tool.provider)
|
||||
}
|
||||
if p.proxy != "http://127.0.0.1:7890" {
|
||||
t.Fatalf("provider proxy = %q, want %q", p.proxy, "http://127.0.0.1:7890")
|
||||
}
|
||||
}
|
||||
|
||||
func TestExaSearchProvider_Success(t *testing.T) {
|
||||
func TestWebTool_GLMSearch_Success(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != "POST" {
|
||||
t.Errorf("Expected POST request, got %s", r.Method)
|
||||
@@ -770,130 +690,126 @@ func TestExaSearchProvider_Success(t *testing.T) {
|
||||
if r.Header.Get("Content-Type") != "application/json" {
|
||||
t.Errorf("Expected Content-Type application/json, got %s", r.Header.Get("Content-Type"))
|
||||
}
|
||||
if r.Header.Get("x-api-key") != "test-exa-key" {
|
||||
t.Errorf("Expected x-api-key test-exa-key, got %s", r.Header.Get("x-api-key"))
|
||||
if r.Header.Get("Authorization") != "Bearer test-glm-key" {
|
||||
t.Errorf("Expected Authorization Bearer test-glm-key, got %s", r.Header.Get("Authorization"))
|
||||
}
|
||||
|
||||
// Verify payload
|
||||
body, _ := io.ReadAll(r.Body)
|
||||
var payload map[string]any
|
||||
json.Unmarshal(body, &payload)
|
||||
if payload["query"] != "test query" {
|
||||
t.Errorf("Expected query 'test query', got %v", payload["query"])
|
||||
json.NewDecoder(r.Body).Decode(&payload)
|
||||
if payload["search_query"] != "test query" {
|
||||
t.Errorf("Expected search_query 'test query', got %v", payload["search_query"])
|
||||
}
|
||||
if payload["type"] != "neural" {
|
||||
t.Errorf("Expected type 'neural', got %v", payload["type"])
|
||||
if payload["search_engine"] != "search_std" {
|
||||
t.Errorf("Expected search_engine 'search_std', got %v", payload["search_engine"])
|
||||
}
|
||||
|
||||
response := map[string]any{
|
||||
"results": []map[string]any{
|
||||
{"title": "Exa Result 1", "url": "https://exa.ai/1", "text": "First result text"},
|
||||
{"title": "Exa Result 2", "url": "https://exa.ai/2", "text": "Second result text"},
|
||||
{"title": "Exa Result 3", "url": "https://exa.ai/3", "text": "Third result text"},
|
||||
"id": "web-search-test",
|
||||
"created": 1709568000,
|
||||
"search_result": []map[string]any{
|
||||
{
|
||||
"title": "Test GLM Result",
|
||||
"content": "GLM search snippet",
|
||||
"link": "https://example.com/glm",
|
||||
"media": "Example",
|
||||
"publish_date": "2026-03-04",
|
||||
},
|
||||
},
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
json.NewEncoder(w).Encode(response)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
provider := &ExaSearchProvider{
|
||||
apiKey: "test-exa-key",
|
||||
client: &http.Client{},
|
||||
}
|
||||
|
||||
// Temporarily override the API URL by using a custom transport
|
||||
provider.client.Transport = rewriteHostTransport(server.URL)
|
||||
|
||||
result, err := provider.Search(context.Background(), "test query", 5)
|
||||
if err != nil {
|
||||
t.Fatalf("Search() error: %v", err)
|
||||
}
|
||||
|
||||
if !strings.Contains(result, "via Exa") {
|
||||
t.Errorf("Expected '(via Exa)' attribution, got: %s", result)
|
||||
}
|
||||
if !strings.Contains(result, "Exa Result 1") || !strings.Contains(result, "https://exa.ai/1") {
|
||||
t.Errorf("Expected results in output, got: %s", result)
|
||||
}
|
||||
if !strings.Contains(result, "First result text") {
|
||||
t.Errorf("Expected snippet text in output, got: %s", result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExaSearchProvider_EmptyResults(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
response := map[string]any{"results": []map[string]any{}}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(response)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
provider := &ExaSearchProvider{
|
||||
apiKey: "test-key",
|
||||
client: &http.Client{Transport: rewriteHostTransport(server.URL)},
|
||||
}
|
||||
|
||||
result, err := provider.Search(context.Background(), "no results query", 5)
|
||||
if err != nil {
|
||||
t.Fatalf("Search() error: %v", err)
|
||||
}
|
||||
if !strings.Contains(result, "No results for: no results query") {
|
||||
t.Errorf("Expected 'No results' message, got: %s", result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExaSearchProvider_MaxResultsCapping(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Return 5 results
|
||||
results := make([]map[string]any, 5)
|
||||
for i := range results {
|
||||
results[i] = map[string]any{
|
||||
"title": fmt.Sprintf("Result %d", i+1),
|
||||
"url": fmt.Sprintf("https://exa.ai/%d", i+1),
|
||||
"text": fmt.Sprintf("Text %d", i+1),
|
||||
}
|
||||
}
|
||||
response := map[string]any{"results": results}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(response)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
provider := &ExaSearchProvider{
|
||||
apiKey: "test-key",
|
||||
client: &http.Client{Transport: rewriteHostTransport(server.URL)},
|
||||
}
|
||||
|
||||
// Request only 2 results even though API returns 5
|
||||
result, err := provider.Search(context.Background(), "test", 2)
|
||||
if err != nil {
|
||||
t.Fatalf("Search() error: %v", err)
|
||||
}
|
||||
|
||||
if !strings.Contains(result, "Result 1") || !strings.Contains(result, "Result 2") {
|
||||
t.Errorf("Expected first 2 results, got: %s", result)
|
||||
}
|
||||
if strings.Contains(result, "Result 3") {
|
||||
t.Errorf("Expected results capped at 2, but got Result 3 in output: %s", result)
|
||||
}
|
||||
}
|
||||
|
||||
// rewriteHostTransport returns an http.RoundTripper that redirects all requests to the given target URL.
|
||||
func rewriteHostTransport(target string) http.RoundTripper {
|
||||
return roundTripFunc(func(req *http.Request) (*http.Response, error) {
|
||||
newURL := target + req.URL.Path
|
||||
newReq, err := http.NewRequestWithContext(req.Context(), req.Method, newURL, req.Body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
newReq.Header = req.Header
|
||||
return http.DefaultClient.Do(newReq)
|
||||
tool, err := NewWebSearchTool(WebSearchToolOptions{
|
||||
GLMSearchEnabled: true,
|
||||
GLMSearchAPIKey: "test-glm-key",
|
||||
GLMSearchBaseURL: server.URL,
|
||||
GLMSearchEngine: "search_std",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("NewWebSearchTool() error: %v", err)
|
||||
}
|
||||
|
||||
result := tool.Execute(context.Background(), map[string]any{
|
||||
"query": "test query",
|
||||
})
|
||||
|
||||
if result.IsError {
|
||||
t.Errorf("Expected success, got IsError=true: %s", result.ForLLM)
|
||||
}
|
||||
if !strings.Contains(result.ForUser, "Test GLM Result") {
|
||||
t.Errorf("Expected 'Test GLM Result' in output, got: %s", result.ForUser)
|
||||
}
|
||||
if !strings.Contains(result.ForUser, "https://example.com/glm") {
|
||||
t.Errorf("Expected URL in output, got: %s", result.ForUser)
|
||||
}
|
||||
if !strings.Contains(result.ForUser, "via GLM Search") {
|
||||
t.Errorf("Expected 'via GLM Search' in output, got: %s", result.ForUser)
|
||||
}
|
||||
}
|
||||
|
||||
type roundTripFunc func(*http.Request) (*http.Response, error)
|
||||
func TestWebTool_GLMSearch_APIError(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
w.Write([]byte(`{"error":"invalid api key"}`))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
func (f roundTripFunc) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
return f(req)
|
||||
tool, err := NewWebSearchTool(WebSearchToolOptions{
|
||||
GLMSearchEnabled: true,
|
||||
GLMSearchAPIKey: "bad-key",
|
||||
GLMSearchBaseURL: server.URL,
|
||||
GLMSearchEngine: "search_std",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("NewWebSearchTool() error: %v", err)
|
||||
}
|
||||
|
||||
result := tool.Execute(context.Background(), map[string]any{
|
||||
"query": "test query",
|
||||
})
|
||||
|
||||
if !result.IsError {
|
||||
t.Errorf("Expected IsError=true for 401 response")
|
||||
}
|
||||
if !strings.Contains(result.ForLLM, "status 401") {
|
||||
t.Errorf("Expected status 401 in error, got: %s", result.ForLLM)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWebTool_GLMSearch_Priority(t *testing.T) {
|
||||
// GLM Search should only be selected when all other providers are disabled
|
||||
tool, err := NewWebSearchTool(WebSearchToolOptions{
|
||||
DuckDuckGoEnabled: true,
|
||||
DuckDuckGoMaxResults: 5,
|
||||
GLMSearchEnabled: true,
|
||||
GLMSearchAPIKey: "test-key",
|
||||
GLMSearchBaseURL: "https://example.com",
|
||||
GLMSearchEngine: "search_std",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("NewWebSearchTool() error: %v", err)
|
||||
}
|
||||
|
||||
// DuckDuckGo should win over GLM Search
|
||||
if _, ok := tool.provider.(*DuckDuckGoSearchProvider); !ok {
|
||||
t.Errorf("Expected DuckDuckGoSearchProvider when both enabled, got %T", tool.provider)
|
||||
}
|
||||
|
||||
// With DuckDuckGo disabled, GLM Search should be selected
|
||||
tool2, err := NewWebSearchTool(WebSearchToolOptions{
|
||||
DuckDuckGoEnabled: false,
|
||||
GLMSearchEnabled: true,
|
||||
GLMSearchAPIKey: "test-key",
|
||||
GLMSearchBaseURL: "https://example.com",
|
||||
GLMSearchEngine: "search_std",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("NewWebSearchTool() error: %v", err)
|
||||
}
|
||||
if _, ok := tool2.provider.(*GLMSearchProvider); !ok {
|
||||
t.Errorf("Expected GLMSearchProvider when only GLM enabled, got %T", tool2.provider)
|
||||
}
|
||||
}
|
||||
|
||||
+19
-4
@@ -3,6 +3,7 @@ package utils
|
||||
import (
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
@@ -52,11 +53,12 @@ type DownloadOptions struct {
|
||||
Timeout time.Duration
|
||||
ExtraHeaders map[string]string
|
||||
LoggerPrefix string
|
||||
ProxyURL string
|
||||
}
|
||||
|
||||
// DownloadFile downloads a file from URL to a local temp directory.
|
||||
// Returns the local file path or empty string on error.
|
||||
func DownloadFile(url, filename string, opts DownloadOptions) string {
|
||||
func DownloadFile(urlStr, filename string, opts DownloadOptions) string {
|
||||
// Set defaults
|
||||
if opts.Timeout == 0 {
|
||||
opts.Timeout = 60 * time.Second
|
||||
@@ -78,7 +80,7 @@ func DownloadFile(url, filename string, opts DownloadOptions) string {
|
||||
localPath := filepath.Join(mediaDir, uuid.New().String()[:8]+"_"+safeName)
|
||||
|
||||
// Create HTTP request
|
||||
req, err := http.NewRequest("GET", url, nil)
|
||||
req, err := http.NewRequest("GET", urlStr, nil)
|
||||
if err != nil {
|
||||
logger.ErrorCF(opts.LoggerPrefix, "Failed to create download request", map[string]any{
|
||||
"error": err.Error(),
|
||||
@@ -92,11 +94,24 @@ func DownloadFile(url, filename string, opts DownloadOptions) string {
|
||||
}
|
||||
|
||||
client := &http.Client{Timeout: opts.Timeout}
|
||||
if opts.ProxyURL != "" {
|
||||
proxyURL, parseErr := url.Parse(opts.ProxyURL)
|
||||
if parseErr != nil {
|
||||
logger.ErrorCF(opts.LoggerPrefix, "Invalid proxy URL for download", map[string]any{
|
||||
"error": parseErr.Error(),
|
||||
"proxy": opts.ProxyURL,
|
||||
})
|
||||
return ""
|
||||
}
|
||||
client.Transport = &http.Transport{
|
||||
Proxy: http.ProxyURL(proxyURL),
|
||||
}
|
||||
}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
logger.ErrorCF(opts.LoggerPrefix, "Failed to download file", map[string]any{
|
||||
"error": err.Error(),
|
||||
"url": url,
|
||||
"url": urlStr,
|
||||
})
|
||||
return ""
|
||||
}
|
||||
@@ -105,7 +120,7 @@ func DownloadFile(url, filename string, opts DownloadOptions) string {
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
logger.ErrorCF(opts.LoggerPrefix, "File download returned non-200 status", map[string]any{
|
||||
"status": resp.StatusCode,
|
||||
"url": url,
|
||||
"url": urlStr,
|
||||
})
|
||||
return ""
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user