mirror of
https://github.com/sipeed/picoclaw.git
synced 2026-06-12 18:08:54 +00:00
Merge remote-tracking branch 'origin/main' into feat/echo-voice-audio-transcription
# Conflicts: # pkg/channels/telegram/telegram.go # pkg/config/config.go # pkg/config/defaults.go
This commit is contained in:
+41
-10
@@ -12,15 +12,19 @@ import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
"github.com/sipeed/picoclaw/pkg/logger"
|
||||
"github.com/sipeed/picoclaw/pkg/providers"
|
||||
"github.com/sipeed/picoclaw/pkg/skills"
|
||||
"github.com/sipeed/picoclaw/pkg/utils"
|
||||
)
|
||||
|
||||
type ContextBuilder struct {
|
||||
workspace string
|
||||
skillsLoader *skills.SkillsLoader
|
||||
memory *MemoryStore
|
||||
workspace string
|
||||
skillsLoader *skills.SkillsLoader
|
||||
memory *MemoryStore
|
||||
toolDiscoveryBM25 bool
|
||||
toolDiscoveryRegex bool
|
||||
|
||||
// Cache for system prompt to avoid rebuilding on every call.
|
||||
// This fixes issue #607: repeated reprocessing of the entire context.
|
||||
@@ -41,6 +45,12 @@ type ContextBuilder struct {
|
||||
skillFilesAtCache map[string]time.Time
|
||||
}
|
||||
|
||||
func (cb *ContextBuilder) WithToolDiscovery(useBM25, useRegex bool) *ContextBuilder {
|
||||
cb.toolDiscoveryBM25 = useBM25
|
||||
cb.toolDiscoveryRegex = useRegex
|
||||
return cb
|
||||
}
|
||||
|
||||
func getGlobalConfigDir() string {
|
||||
if home := os.Getenv("PICOCLAW_HOME"); home != "" {
|
||||
return home
|
||||
@@ -71,8 +81,11 @@ func NewContextBuilder(workspace string) *ContextBuilder {
|
||||
|
||||
func (cb *ContextBuilder) getIdentity() string {
|
||||
workspacePath, _ := filepath.Abs(filepath.Join(cb.workspace))
|
||||
toolDiscovery := cb.getDiscoveryRule()
|
||||
version := config.FormatVersion()
|
||||
|
||||
return fmt.Sprintf(`# picoclaw 🦞
|
||||
return fmt.Sprintf(
|
||||
`# picoclaw 🦞 (%s)
|
||||
|
||||
You are picoclaw, a helpful AI assistant.
|
||||
|
||||
@@ -90,8 +103,29 @@ Your workspace is at: %s
|
||||
|
||||
3. **Memory** - When interacting with me if something seems memorable, update %s/memory/MEMORY.md
|
||||
|
||||
4. **Context summaries** - Conversation summaries provided as context are approximate references only. They may be incomplete or outdated. Always defer to explicit user instructions over summary content.`,
|
||||
workspacePath, workspacePath, workspacePath, workspacePath, workspacePath)
|
||||
4. **Context summaries** - Conversation summaries provided as context are approximate references only. They may be incomplete or outdated. Always defer to explicit user instructions over summary content.
|
||||
|
||||
%s`,
|
||||
version, workspacePath, workspacePath, workspacePath, workspacePath, workspacePath, toolDiscovery)
|
||||
}
|
||||
|
||||
func (cb *ContextBuilder) getDiscoveryRule() string {
|
||||
if !cb.toolDiscoveryBM25 && !cb.toolDiscoveryRegex {
|
||||
return ""
|
||||
}
|
||||
|
||||
var toolNames []string
|
||||
if cb.toolDiscoveryBM25 {
|
||||
toolNames = append(toolNames, `"tool_search_tool_bm25"`)
|
||||
}
|
||||
if cb.toolDiscoveryRegex {
|
||||
toolNames = append(toolNames, `"tool_search_tool_regex"`)
|
||||
}
|
||||
|
||||
return fmt.Sprintf(
|
||||
`5. **Tool Discovery** - Your visible tools are limited to save memory, but a vast hidden library exists. If you lack the right tool for a task, BEFORE giving up, you MUST search using the %s tool. Do not refuse a request unless the search returns nothing. Found tools will temporarily unlock for your next turn.`,
|
||||
strings.Join(toolNames, " or "),
|
||||
)
|
||||
}
|
||||
|
||||
func (cb *ContextBuilder) BuildSystemPrompt() string {
|
||||
@@ -505,10 +539,7 @@ func (cb *ContextBuilder) BuildMessages(
|
||||
})
|
||||
|
||||
// Log preview of system prompt (avoid logging huge content)
|
||||
preview := fullSystemPrompt
|
||||
if len(preview) > 500 {
|
||||
preview = preview[:500] + "... (truncated)"
|
||||
}
|
||||
preview := utils.Truncate(fullSystemPrompt, 500)
|
||||
logger.DebugCF("agent", "System prompt preview",
|
||||
map[string]any{
|
||||
"preview": preview,
|
||||
|
||||
+45
-5
@@ -1,6 +1,7 @@
|
||||
package agent
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
@@ -9,6 +10,7 @@ import (
|
||||
"strings"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
"github.com/sipeed/picoclaw/pkg/memory"
|
||||
"github.com/sipeed/picoclaw/pkg/providers"
|
||||
"github.com/sipeed/picoclaw/pkg/routing"
|
||||
"github.com/sipeed/picoclaw/pkg/session"
|
||||
@@ -31,7 +33,7 @@ type AgentInstance struct {
|
||||
SummarizeMessageThreshold int
|
||||
SummarizeTokenPercent int
|
||||
Provider providers.LLMProvider
|
||||
Sessions *session.SessionManager
|
||||
Sessions session.SessionStore
|
||||
ContextBuilder *ContextBuilder
|
||||
Tools *tools.ToolRegistry
|
||||
Subagents *config.SubagentsConfig
|
||||
@@ -70,7 +72,8 @@ func NewAgentInstance(
|
||||
toolsRegistry := tools.NewToolRegistry()
|
||||
|
||||
if cfg.Tools.IsToolEnabled("read_file") {
|
||||
toolsRegistry.Register(tools.NewReadFileTool(workspace, readRestrict, allowReadPaths))
|
||||
maxReadFileSize := cfg.Tools.ReadFile.MaxReadFileSize
|
||||
toolsRegistry.Register(tools.NewReadFileTool(workspace, readRestrict, maxReadFileSize, allowReadPaths))
|
||||
}
|
||||
if cfg.Tools.IsToolEnabled("write_file") {
|
||||
toolsRegistry.Register(tools.NewWriteFileTool(workspace, restrict, allowWritePaths))
|
||||
@@ -94,9 +97,13 @@ func NewAgentInstance(
|
||||
}
|
||||
|
||||
sessionsDir := filepath.Join(workspace, "sessions")
|
||||
sessionsManager := session.NewSessionManager(sessionsDir)
|
||||
sessions := initSessionStore(sessionsDir)
|
||||
|
||||
contextBuilder := NewContextBuilder(workspace)
|
||||
mcpDiscoveryActive := cfg.Tools.MCP.Enabled && cfg.Tools.MCP.Discovery.Enabled
|
||||
contextBuilder := NewContextBuilder(workspace).WithToolDiscovery(
|
||||
mcpDiscoveryActive && cfg.Tools.MCP.Discovery.UseBM25,
|
||||
mcpDiscoveryActive && cfg.Tools.MCP.Discovery.UseRegex,
|
||||
)
|
||||
|
||||
agentID := routing.DefaultAgentID
|
||||
agentName := ""
|
||||
@@ -221,7 +228,7 @@ func NewAgentInstance(
|
||||
SummarizeMessageThreshold: summarizeMessageThreshold,
|
||||
SummarizeTokenPercent: summarizeTokenPercent,
|
||||
Provider: provider,
|
||||
Sessions: sessionsManager,
|
||||
Sessions: sessions,
|
||||
ContextBuilder: contextBuilder,
|
||||
Tools: toolsRegistry,
|
||||
Subagents: subagents,
|
||||
@@ -275,6 +282,39 @@ func compilePatterns(patterns []string) []*regexp.Regexp {
|
||||
return compiled
|
||||
}
|
||||
|
||||
// Close releases resources held by the agent's session store.
|
||||
func (a *AgentInstance) Close() error {
|
||||
if a.Sessions != nil {
|
||||
return a.Sessions.Close()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// initSessionStore creates the session persistence backend.
|
||||
// It uses the JSONL store by default and auto-migrates legacy JSON sessions.
|
||||
// Falls back to SessionManager if the JSONL store cannot be initialized or
|
||||
// if migration fails (which indicates the store cannot write reliably).
|
||||
func initSessionStore(dir string) session.SessionStore {
|
||||
store, err := memory.NewJSONLStore(dir)
|
||||
if err != nil {
|
||||
log.Printf("memory: init store: %v; using json sessions", err)
|
||||
return session.NewSessionManager(dir)
|
||||
}
|
||||
|
||||
if n, merr := memory.MigrateFromJSON(context.Background(), dir, store); merr != nil {
|
||||
// Migration failure means the store could not write data.
|
||||
// Fall back to SessionManager to avoid a split state where
|
||||
// some sessions are in JSONL and others remain in JSON.
|
||||
log.Printf("memory: migration failed: %v; falling back to json sessions", merr)
|
||||
store.Close()
|
||||
return session.NewSessionManager(dir)
|
||||
} else if n > 0 {
|
||||
log.Printf("memory: migrated %d session(s) to jsonl", n)
|
||||
}
|
||||
|
||||
return session.NewJSONLBackend(store)
|
||||
}
|
||||
|
||||
func expandHome(path string) string {
|
||||
if path == "" {
|
||||
return path
|
||||
|
||||
+213
-45
@@ -120,19 +120,21 @@ func registerSharedTools(
|
||||
continue
|
||||
}
|
||||
|
||||
// Web tools
|
||||
if cfg.Tools.IsToolEnabled("web") {
|
||||
searchTool, err := tools.NewWebSearchTool(tools.WebSearchToolOptions{
|
||||
BraveAPIKey: cfg.Tools.Web.Brave.APIKey,
|
||||
BraveAPIKeys: config.MergeAPIKeys(cfg.Tools.Web.Brave.APIKey, cfg.Tools.Web.Brave.APIKeys),
|
||||
BraveMaxResults: cfg.Tools.Web.Brave.MaxResults,
|
||||
BraveEnabled: cfg.Tools.Web.Brave.Enabled,
|
||||
TavilyAPIKey: cfg.Tools.Web.Tavily.APIKey,
|
||||
TavilyAPIKeys: config.MergeAPIKeys(cfg.Tools.Web.Tavily.APIKey, cfg.Tools.Web.Tavily.APIKeys),
|
||||
TavilyBaseURL: cfg.Tools.Web.Tavily.BaseURL,
|
||||
TavilyMaxResults: cfg.Tools.Web.Tavily.MaxResults,
|
||||
TavilyEnabled: cfg.Tools.Web.Tavily.Enabled,
|
||||
DuckDuckGoMaxResults: cfg.Tools.Web.DuckDuckGo.MaxResults,
|
||||
DuckDuckGoEnabled: cfg.Tools.Web.DuckDuckGo.Enabled,
|
||||
PerplexityAPIKey: cfg.Tools.Web.Perplexity.APIKey,
|
||||
PerplexityAPIKeys: config.MergeAPIKeys(
|
||||
cfg.Tools.Web.Perplexity.APIKey,
|
||||
cfg.Tools.Web.Perplexity.APIKeys,
|
||||
),
|
||||
PerplexityMaxResults: cfg.Tools.Web.Perplexity.MaxResults,
|
||||
PerplexityEnabled: cfg.Tools.Web.Perplexity.Enabled,
|
||||
SearXNGBaseURL: cfg.Tools.Web.SearXNG.BaseURL,
|
||||
@@ -283,7 +285,13 @@ func (al *AgentLoop) Run(ctx context.Context) error {
|
||||
}
|
||||
|
||||
mcpTool := tools.NewMCPTool(mcpManager, serverName, tool)
|
||||
agent.Tools.Register(mcpTool)
|
||||
|
||||
if al.cfg.Tools.MCP.Discovery.Enabled {
|
||||
agent.Tools.RegisterHidden(mcpTool)
|
||||
} else {
|
||||
agent.Tools.Register(mcpTool)
|
||||
}
|
||||
|
||||
totalRegistrations++
|
||||
logger.DebugCF("agent", "Registered MCP tool",
|
||||
map[string]any{
|
||||
@@ -302,6 +310,47 @@ func (al *AgentLoop) Run(ctx context.Context) error {
|
||||
"total_registrations": totalRegistrations,
|
||||
"agent_count": agentCount,
|
||||
})
|
||||
|
||||
// Initializes Discovery Tools only if enabled by configuration
|
||||
if al.cfg.Tools.MCP.Enabled && al.cfg.Tools.MCP.Discovery.Enabled {
|
||||
useBM25 := al.cfg.Tools.MCP.Discovery.UseBM25
|
||||
useRegex := al.cfg.Tools.MCP.Discovery.UseRegex
|
||||
|
||||
// Fail fast: If discovery is enabled but no search method is turned on
|
||||
if !useBM25 && !useRegex {
|
||||
return fmt.Errorf(
|
||||
"tool discovery is enabled but neither 'use_bm25' nor 'use_regex' is set to true in the configuration",
|
||||
)
|
||||
}
|
||||
|
||||
ttl := al.cfg.Tools.MCP.Discovery.TTL
|
||||
if ttl <= 0 {
|
||||
ttl = 5 // Default value
|
||||
}
|
||||
|
||||
maxSearchResults := al.cfg.Tools.MCP.Discovery.MaxSearchResults
|
||||
if maxSearchResults <= 0 {
|
||||
maxSearchResults = 5 // Default value
|
||||
}
|
||||
|
||||
logger.InfoCF("agent", "Initializing tool discovery", map[string]any{
|
||||
"bm25": useBM25, "regex": useRegex, "ttl": ttl, "max_results": maxSearchResults,
|
||||
})
|
||||
|
||||
for _, agentID := range agentIDs {
|
||||
agent, ok := al.registry.GetAgent(agentID)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
if useRegex {
|
||||
agent.Tools.Register(tools.NewRegexSearchTool(agent.Tools, ttl, maxSearchResults))
|
||||
}
|
||||
if useBM25 {
|
||||
agent.Tools.Register(tools.NewBM25SearchTool(agent.Tools, ttl, maxSearchResults))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -380,6 +429,11 @@ func (al *AgentLoop) Stop() {
|
||||
al.running.Store(false)
|
||||
}
|
||||
|
||||
// Close releases resources held by agent session stores. Call after Stop.
|
||||
func (al *AgentLoop) Close() {
|
||||
al.registry.Close()
|
||||
}
|
||||
|
||||
func (al *AgentLoop) RegisterTool(tool tools.Tool) {
|
||||
for _, agentID := range al.registry.ListAgentIDs() {
|
||||
if agent, ok := al.registry.GetAgent(agentID); ok {
|
||||
@@ -632,15 +686,6 @@ func (al *AgentLoop) processMessage(ctx context.Context, msg bus.InboundMessage)
|
||||
}
|
||||
|
||||
route, agent, routeErr := al.resolveMessageRoute(msg)
|
||||
|
||||
// Commands are checked before requiring a successful route.
|
||||
// Global commands (/help, /show, /switch) work even when routing fails;
|
||||
// context-dependent commands check their own Runtime fields and report
|
||||
// "unavailable" when the required capability is nil.
|
||||
if response, handled := al.handleCommand(ctx, msg, agent); handled {
|
||||
return response, nil
|
||||
}
|
||||
|
||||
if routeErr != nil {
|
||||
return "", routeErr
|
||||
}
|
||||
@@ -666,7 +711,7 @@ func (al *AgentLoop) processMessage(ctx context.Context, msg bus.InboundMessage)
|
||||
"route_channel": route.Channel,
|
||||
})
|
||||
|
||||
return al.runAgentLoop(ctx, agent, processOptions{
|
||||
opts := processOptions{
|
||||
SessionKey: sessionKey,
|
||||
Channel: msg.Channel,
|
||||
ChatID: msg.ChatID,
|
||||
@@ -675,7 +720,15 @@ func (al *AgentLoop) processMessage(ctx context.Context, msg bus.InboundMessage)
|
||||
DefaultResponse: defaultResponse,
|
||||
EnableSummary: true,
|
||||
SendResponse: false,
|
||||
})
|
||||
}
|
||||
|
||||
// context-dependent commands check their own Runtime fields and report
|
||||
// "unavailable" when the required capability is nil.
|
||||
if response, handled := al.handleCommand(ctx, msg, agent, &opts); handled {
|
||||
return response, nil
|
||||
}
|
||||
|
||||
return al.runAgentLoop(ctx, agent, opts)
|
||||
}
|
||||
|
||||
func (al *AgentLoop) resolveMessageRoute(msg bus.InboundMessage) (routing.ResolvedRoute, *AgentInstance, error) {
|
||||
@@ -1306,6 +1359,17 @@ func (al *AgentLoop) runLLMIteration(
|
||||
// Save tool result message to session
|
||||
agent.Sessions.AddFullMessage(opts.SessionKey, toolResultMsg)
|
||||
}
|
||||
|
||||
// Tick down TTL of discovered tools after processing tool results.
|
||||
// Only reached when tool calls were made (the loop continues);
|
||||
// the break on no-tool-call responses skips this.
|
||||
// NOTE: This is safe because processMessage is sequential per agent.
|
||||
// If per-agent concurrency is added, TTL consistency between
|
||||
// ToProviderDefs and Get must be re-evaluated.
|
||||
agent.Tools.TickTTL()
|
||||
logger.DebugCF("agent", "TTL tick after tool execution", map[string]any{
|
||||
"agent_id": agent.ID, "iteration": iteration,
|
||||
})
|
||||
}
|
||||
|
||||
return finalContent, iteration, nil
|
||||
@@ -1543,10 +1607,20 @@ func (al *AgentLoop) summarizeSession(agent *AgentInstance, sessionKey string) {
|
||||
return
|
||||
}
|
||||
|
||||
const (
|
||||
maxSummarizationMessages = 10
|
||||
llmMaxRetries = 3
|
||||
llmTemperature = 0.3
|
||||
fallbackMaxContentLength = 200
|
||||
)
|
||||
|
||||
// Multi-Part Summarization
|
||||
var finalSummary string
|
||||
if len(validMessages) > 10 {
|
||||
if len(validMessages) > maxSummarizationMessages {
|
||||
mid := len(validMessages) / 2
|
||||
|
||||
mid = al.findNearestUserMessage(validMessages, mid)
|
||||
|
||||
part1 := validMessages[:mid]
|
||||
part2 := validMessages[mid:]
|
||||
|
||||
@@ -1558,18 +1632,9 @@ func (al *AgentLoop) summarizeSession(agent *AgentInstance, sessionKey string) {
|
||||
s1,
|
||||
s2,
|
||||
)
|
||||
resp, err := agent.Provider.Chat(
|
||||
ctx,
|
||||
[]providers.Message{{Role: "user", Content: mergePrompt}},
|
||||
nil,
|
||||
agent.Model,
|
||||
map[string]any{
|
||||
"max_tokens": 1024,
|
||||
"temperature": 0.3,
|
||||
"prompt_cache_key": agent.ID,
|
||||
},
|
||||
)
|
||||
if err == nil {
|
||||
|
||||
resp, err := al.retryLLMCall(ctx, agent, mergePrompt, llmMaxRetries)
|
||||
if err == nil && resp.Content != "" {
|
||||
finalSummary = resp.Content
|
||||
} else {
|
||||
finalSummary = s1 + " " + s2
|
||||
@@ -1589,6 +1654,68 @@ func (al *AgentLoop) summarizeSession(agent *AgentInstance, sessionKey string) {
|
||||
}
|
||||
}
|
||||
|
||||
// findNearestUserMessage finds the nearest user message to the given index.
|
||||
// It searches backward first, then forward if no user message is found.
|
||||
func (al *AgentLoop) findNearestUserMessage(messages []providers.Message, mid int) int {
|
||||
originalMid := mid
|
||||
|
||||
for mid > 0 && messages[mid].Role != "user" {
|
||||
mid--
|
||||
}
|
||||
|
||||
if messages[mid].Role == "user" {
|
||||
return mid
|
||||
}
|
||||
|
||||
mid = originalMid
|
||||
for mid < len(messages) && messages[mid].Role != "user" {
|
||||
mid++
|
||||
}
|
||||
|
||||
if mid < len(messages) {
|
||||
return mid
|
||||
}
|
||||
|
||||
return originalMid
|
||||
}
|
||||
|
||||
// retryLLMCall calls the LLM with retry logic.
|
||||
func (al *AgentLoop) retryLLMCall(
|
||||
ctx context.Context,
|
||||
agent *AgentInstance,
|
||||
prompt string,
|
||||
maxRetries int,
|
||||
) (*providers.LLMResponse, error) {
|
||||
const (
|
||||
llmTemperature = 0.3
|
||||
)
|
||||
|
||||
var resp *providers.LLMResponse
|
||||
var err error
|
||||
|
||||
for attempt := 0; attempt < maxRetries; attempt++ {
|
||||
resp, err = agent.Provider.Chat(
|
||||
ctx,
|
||||
[]providers.Message{{Role: "user", Content: prompt}},
|
||||
nil,
|
||||
agent.Model,
|
||||
map[string]any{
|
||||
"max_tokens": agent.MaxTokens,
|
||||
"temperature": llmTemperature,
|
||||
"prompt_cache_key": agent.ID,
|
||||
},
|
||||
)
|
||||
if err == nil && resp != nil && resp.Content != "" {
|
||||
return resp, nil
|
||||
}
|
||||
if attempt < maxRetries-1 {
|
||||
time.Sleep(time.Duration(attempt+1) * 100 * time.Millisecond)
|
||||
}
|
||||
}
|
||||
|
||||
return resp, err
|
||||
}
|
||||
|
||||
// summarizeBatch summarizes a batch of messages.
|
||||
func (al *AgentLoop) summarizeBatch(
|
||||
ctx context.Context,
|
||||
@@ -1596,6 +1723,13 @@ func (al *AgentLoop) summarizeBatch(
|
||||
batch []providers.Message,
|
||||
existingSummary string,
|
||||
) (string, error) {
|
||||
const (
|
||||
llmMaxRetries = 3
|
||||
llmTemperature = 0.3
|
||||
fallbackMinContentLength = 200
|
||||
fallbackMaxContentPercent = 10
|
||||
)
|
||||
|
||||
var sb strings.Builder
|
||||
sb.WriteString(
|
||||
"Provide a concise summary of this conversation segment, preserving core context and key points.\n",
|
||||
@@ -1611,21 +1745,40 @@ func (al *AgentLoop) summarizeBatch(
|
||||
}
|
||||
prompt := sb.String()
|
||||
|
||||
response, err := agent.Provider.Chat(
|
||||
ctx,
|
||||
[]providers.Message{{Role: "user", Content: prompt}},
|
||||
nil,
|
||||
agent.Model,
|
||||
map[string]any{
|
||||
"max_tokens": 1024,
|
||||
"temperature": 0.3,
|
||||
"prompt_cache_key": agent.ID,
|
||||
},
|
||||
)
|
||||
if err != nil {
|
||||
return "", err
|
||||
response, err := al.retryLLMCall(ctx, agent, prompt, llmMaxRetries)
|
||||
if err == nil && response.Content != "" {
|
||||
return strings.TrimSpace(response.Content), nil
|
||||
}
|
||||
return response.Content, nil
|
||||
|
||||
var fallback strings.Builder
|
||||
fallback.WriteString("Conversation summary: ")
|
||||
for i, m := range batch {
|
||||
if i > 0 {
|
||||
fallback.WriteString(" | ")
|
||||
}
|
||||
content := strings.TrimSpace(m.Content)
|
||||
runes := []rune(content)
|
||||
if len(runes) == 0 {
|
||||
fallback.WriteString(fmt.Sprintf("%s: ", m.Role))
|
||||
continue
|
||||
}
|
||||
|
||||
keepLength := len(runes) * fallbackMaxContentPercent / 100
|
||||
if keepLength < fallbackMinContentLength {
|
||||
keepLength = fallbackMinContentLength
|
||||
}
|
||||
|
||||
if keepLength > len(runes) {
|
||||
keepLength = len(runes)
|
||||
}
|
||||
|
||||
content = string(runes[:keepLength])
|
||||
if keepLength < len(runes) {
|
||||
content += "..."
|
||||
}
|
||||
fallback.WriteString(fmt.Sprintf("%s: %s", m.Role, content))
|
||||
}
|
||||
return fallback.String(), nil
|
||||
}
|
||||
|
||||
// estimateTokens estimates the number of tokens in a message list.
|
||||
@@ -1644,6 +1797,7 @@ func (al *AgentLoop) handleCommand(
|
||||
ctx context.Context,
|
||||
msg bus.InboundMessage,
|
||||
agent *AgentInstance,
|
||||
opts *processOptions,
|
||||
) (string, bool) {
|
||||
if !commands.HasCommandPrefix(msg.Content) {
|
||||
return "", false
|
||||
@@ -1653,7 +1807,7 @@ func (al *AgentLoop) handleCommand(
|
||||
return "", false
|
||||
}
|
||||
|
||||
rt := al.buildCommandsRuntime(agent)
|
||||
rt := al.buildCommandsRuntime(agent, opts)
|
||||
executor := commands.NewExecutor(al.cmdRegistry, rt)
|
||||
|
||||
var commandReply string
|
||||
@@ -1682,7 +1836,7 @@ func (al *AgentLoop) handleCommand(
|
||||
}
|
||||
}
|
||||
|
||||
func (al *AgentLoop) buildCommandsRuntime(agent *AgentInstance) *commands.Runtime {
|
||||
func (al *AgentLoop) buildCommandsRuntime(agent *AgentInstance, opts *processOptions) *commands.Runtime {
|
||||
rt := &commands.Runtime{
|
||||
Config: al.cfg,
|
||||
ListAgentIDs: al.registry.ListAgentIDs,
|
||||
@@ -1712,6 +1866,20 @@ func (al *AgentLoop) buildCommandsRuntime(agent *AgentInstance) *commands.Runtim
|
||||
agent.Model = value
|
||||
return oldModel, nil
|
||||
}
|
||||
|
||||
rt.ClearHistory = func() error {
|
||||
if opts == nil {
|
||||
return fmt.Errorf("process options not available")
|
||||
}
|
||||
if agent.Sessions == nil {
|
||||
return fmt.Errorf("sessions not initialized for agent")
|
||||
}
|
||||
|
||||
agent.Sessions.SetHistory(opts.SessionKey, make([]providers.Message, 0))
|
||||
agent.Sessions.SetSummary(opts.SessionKey, "")
|
||||
agent.Sessions.Save(opts.SessionKey)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
return rt
|
||||
}
|
||||
|
||||
@@ -114,6 +114,18 @@ func (r *AgentRegistry) ForEachTool(name string, fn func(tools.Tool)) {
|
||||
}
|
||||
}
|
||||
|
||||
// Close releases resources held by all registered agents.
|
||||
func (r *AgentRegistry) Close() {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
for _, agent := range r.agents {
|
||||
if err := agent.Close(); err != nil {
|
||||
logger.WarnCF("agent", "Failed to close agent",
|
||||
map[string]any{"agent_id": agent.ID, "error": err.Error()})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// GetDefaultAgent returns the default agent instance.
|
||||
func (r *AgentRegistry) GetDefaultAgent() *AgentInstance {
|
||||
r.mu.RLock()
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"math/rand"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
@@ -195,18 +196,30 @@ func (c *FeishuChannel) SendPlaceholder(ctx context.Context, chatID string) (str
|
||||
}
|
||||
|
||||
// ReactToMessage implements channels.ReactionCapable.
|
||||
// Adds an "Pin" reaction and returns an undo function to remove it.
|
||||
// Adds a reaction (randomly chosen from config) and returns an undo function to remove it.
|
||||
func (c *FeishuChannel) ReactToMessage(ctx context.Context, chatID, messageID string) (func(), error) {
|
||||
// Get emoji list from config
|
||||
emojiList := c.config.RandomReactionEmoji
|
||||
var chosenEmoji string
|
||||
if len(emojiList) == 0 {
|
||||
// Default to "Pin" if no config
|
||||
chosenEmoji = "Pin"
|
||||
} else {
|
||||
idx := rand.Intn(len(emojiList))
|
||||
chosenEmoji = emojiList[idx]
|
||||
}
|
||||
|
||||
req := larkim.NewCreateMessageReactionReqBuilder().
|
||||
MessageId(messageID).
|
||||
Body(larkim.NewCreateMessageReactionReqBodyBuilder().
|
||||
ReactionType(larkim.NewEmojiBuilder().EmojiType("Pin").Build()).
|
||||
ReactionType(larkim.NewEmojiBuilder().EmojiType(chosenEmoji).Build()).
|
||||
Build()).
|
||||
Build()
|
||||
|
||||
resp, err := c.client.Im.V1.MessageReaction.Create(ctx, req)
|
||||
if err != nil {
|
||||
logger.ErrorCF("feishu", "Failed to add reaction", map[string]any{
|
||||
"emoji": chosenEmoji,
|
||||
"message_id": messageID,
|
||||
"error": err.Error(),
|
||||
})
|
||||
@@ -214,6 +227,7 @@ func (c *FeishuChannel) ReactToMessage(ctx context.Context, chatID, messageID st
|
||||
}
|
||||
if !resp.Success() {
|
||||
logger.ErrorCF("feishu", "Reaction API error", map[string]any{
|
||||
"emoji": chosenEmoji,
|
||||
"message_id": messageID,
|
||||
"code": resp.Code,
|
||||
"msg": resp.Msg,
|
||||
|
||||
@@ -61,7 +61,9 @@ var channelRateConfig = map[string]float64{
|
||||
"telegram": 20,
|
||||
"discord": 1,
|
||||
"slack": 1,
|
||||
"matrix": 2,
|
||||
"line": 10,
|
||||
"qq": 5,
|
||||
"irc": 2,
|
||||
}
|
||||
|
||||
@@ -265,6 +267,13 @@ func (m *Manager) initChannels() error {
|
||||
m.initChannel("slack", "Slack")
|
||||
}
|
||||
|
||||
if m.config.Channels.Matrix.Enabled &&
|
||||
m.config.Channels.Matrix.Homeserver != "" &&
|
||||
m.config.Channels.Matrix.UserID != "" &&
|
||||
m.config.Channels.Matrix.AccessToken != "" {
|
||||
m.initChannel("matrix", "Matrix")
|
||||
}
|
||||
|
||||
if m.config.Channels.LINE.Enabled && m.config.Channels.LINE.ChannelAccessToken != "" {
|
||||
m.initChannel("line", "LINE")
|
||||
}
|
||||
|
||||
@@ -0,0 +1,13 @@
|
||||
package matrix
|
||||
|
||||
import (
|
||||
"github.com/sipeed/picoclaw/pkg/bus"
|
||||
"github.com/sipeed/picoclaw/pkg/channels"
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
)
|
||||
|
||||
func init() {
|
||||
channels.RegisterFactory("matrix", func(cfg *config.Config, b *bus.MessageBus) (channels.Channel, error) {
|
||||
return NewMatrixChannel(cfg.Channels.Matrix, b)
|
||||
})
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,291 @@
|
||||
package matrix
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"maunium.net/go/mautrix"
|
||||
"maunium.net/go/mautrix/event"
|
||||
"maunium.net/go/mautrix/id"
|
||||
)
|
||||
|
||||
func TestMatrixLocalpartMentionRegexp(t *testing.T) {
|
||||
re := localpartMentionRegexp("picoclaw")
|
||||
|
||||
cases := []struct {
|
||||
text string
|
||||
want bool
|
||||
}{
|
||||
{text: "@picoclaw hello", want: true},
|
||||
{text: "hi @picoclaw:matrix.org", want: true},
|
||||
{
|
||||
text: "\u6b22\u8fce\u4e00\u4e0bpicoclaw\u5c0f\u9f99\u867e",
|
||||
want: false, // historical false-positive case in PR #356
|
||||
},
|
||||
{text: "mail test@example.com", want: false},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
if got := re.MatchString(tc.text); got != tc.want {
|
||||
t.Fatalf("text=%q match=%v want=%v", tc.text, got, tc.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestStripUserMention(t *testing.T) {
|
||||
userID := id.UserID("@picoclaw:matrix.org")
|
||||
|
||||
cases := []struct {
|
||||
in string
|
||||
want string
|
||||
}{
|
||||
{in: "@picoclaw:matrix.org hello", want: "hello"},
|
||||
{in: "@picoclaw, hello", want: "hello"},
|
||||
{in: "no mention here", want: "no mention here"},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
if got := stripUserMention(tc.in, userID); got != tc.want {
|
||||
t.Fatalf("stripUserMention(%q)=%q want=%q", tc.in, got, tc.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsBotMentioned(t *testing.T) {
|
||||
ch := &MatrixChannel{
|
||||
client: &mautrix.Client{
|
||||
UserID: id.UserID("@picoclaw:matrix.org"),
|
||||
},
|
||||
}
|
||||
|
||||
cases := []struct {
|
||||
name string
|
||||
msg event.MessageEventContent
|
||||
want bool
|
||||
}{
|
||||
{
|
||||
name: "mentions field",
|
||||
msg: event.MessageEventContent{
|
||||
Body: "hello",
|
||||
Mentions: &event.Mentions{
|
||||
UserIDs: []id.UserID{id.UserID("@picoclaw:matrix.org")},
|
||||
},
|
||||
},
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "full user id in body",
|
||||
msg: event.MessageEventContent{
|
||||
Body: "@picoclaw:matrix.org hello",
|
||||
},
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "localpart with at sign",
|
||||
msg: event.MessageEventContent{
|
||||
Body: "@picoclaw hello",
|
||||
},
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "localpart without at sign should not match",
|
||||
msg: event.MessageEventContent{
|
||||
Body: "\u6b22\u8fce\u4e00\u4e0bpicoclaw\u5c0f\u9f99\u867e",
|
||||
},
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "formatted mention href matrix.to plain",
|
||||
msg: event.MessageEventContent{
|
||||
Body: "hello bot",
|
||||
FormattedBody: `<a href="https://matrix.to/#/@picoclaw:matrix.org">PicoClaw</a> hello`,
|
||||
},
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "formatted mention href matrix.to encoded",
|
||||
msg: event.MessageEventContent{
|
||||
Body: "hello bot",
|
||||
FormattedBody: `<a href="https://matrix.to/#/%40picoclaw%3Amatrix.org">PicoClaw</a> hello`,
|
||||
},
|
||||
want: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
if got := ch.isBotMentioned(&tc.msg); got != tc.want {
|
||||
t.Fatalf("%s: got=%v want=%v", tc.name, got, tc.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestRoomKindCache_ExpiresEntries(t *testing.T) {
|
||||
cache := newRoomKindCache(4, 5*time.Second)
|
||||
now := time.Unix(100, 0)
|
||||
cache.set("!room:matrix.org", true, now)
|
||||
|
||||
if got, ok := cache.get("!room:matrix.org", now.Add(2*time.Second)); !ok || !got {
|
||||
t.Fatalf("expected cached group room before ttl, got ok=%v group=%v", ok, got)
|
||||
}
|
||||
|
||||
if _, ok := cache.get("!room:matrix.org", now.Add(6*time.Second)); ok {
|
||||
t.Fatal("expected cache miss after ttl expiry")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRoomKindCache_EvictsOldestWhenFull(t *testing.T) {
|
||||
cache := newRoomKindCache(2, time.Minute)
|
||||
now := time.Unix(200, 0)
|
||||
|
||||
cache.set("!room1:matrix.org", false, now)
|
||||
cache.set("!room2:matrix.org", false, now.Add(1*time.Second))
|
||||
cache.set("!room3:matrix.org", true, now.Add(2*time.Second))
|
||||
|
||||
if _, ok := cache.get("!room1:matrix.org", now.Add(2*time.Second)); ok {
|
||||
t.Fatal("expected oldest cache entry to be evicted")
|
||||
}
|
||||
if got, ok := cache.get("!room2:matrix.org", now.Add(2*time.Second)); !ok || got {
|
||||
t.Fatalf("expected room2 to remain and be direct, got ok=%v group=%v", ok, got)
|
||||
}
|
||||
if got, ok := cache.get("!room3:matrix.org", now.Add(2*time.Second)); !ok || !got {
|
||||
t.Fatalf("expected room3 to remain and be group, got ok=%v group=%v", ok, got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMatrixMediaTempDir(t *testing.T) {
|
||||
dir, err := matrixMediaTempDir()
|
||||
if err != nil {
|
||||
t.Fatalf("matrixMediaTempDir failed: %v", err)
|
||||
}
|
||||
if filepath.Base(dir) != matrixMediaTempDirName {
|
||||
t.Fatalf("unexpected media dir base: %q", filepath.Base(dir))
|
||||
}
|
||||
|
||||
info, err := os.Stat(dir)
|
||||
if err != nil {
|
||||
t.Fatalf("media dir not created: %v", err)
|
||||
}
|
||||
if !info.IsDir() {
|
||||
t.Fatalf("expected directory, got mode=%v", info.Mode())
|
||||
}
|
||||
}
|
||||
|
||||
func TestMatrixMediaExt(t *testing.T) {
|
||||
if got := matrixMediaExt("photo.png", "", "image"); got != ".png" {
|
||||
t.Fatalf("filename extension mismatch: got=%q", got)
|
||||
}
|
||||
if got := matrixMediaExt("", "image/webp", "image"); got != ".webp" {
|
||||
t.Fatalf("content-type extension mismatch: got=%q", got)
|
||||
}
|
||||
if got := matrixMediaExt("", "", "image"); got != ".jpg" {
|
||||
t.Fatalf("default image extension mismatch: got=%q", got)
|
||||
}
|
||||
if got := matrixMediaExt("", "", "audio"); got != ".ogg" {
|
||||
t.Fatalf("default audio extension mismatch: got=%q", got)
|
||||
}
|
||||
if got := matrixMediaExt("", "", "video"); got != ".mp4" {
|
||||
t.Fatalf("default video extension mismatch: got=%q", got)
|
||||
}
|
||||
if got := matrixMediaExt("", "", "file"); got != ".bin" {
|
||||
t.Fatalf("default file extension mismatch: got=%q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractInboundContent_ImageNoURLFallback(t *testing.T) {
|
||||
ch := &MatrixChannel{}
|
||||
msg := &event.MessageEventContent{
|
||||
MsgType: event.MsgImage,
|
||||
Body: "test.png",
|
||||
}
|
||||
|
||||
content, mediaRefs, ok := ch.extractInboundContent(context.Background(), msg, "matrix:room:event")
|
||||
if !ok {
|
||||
t.Fatal("expected ok for image fallback")
|
||||
}
|
||||
if content != "[image: test.png]" {
|
||||
t.Fatalf("unexpected content: %q", content)
|
||||
}
|
||||
if len(mediaRefs) != 0 {
|
||||
t.Fatalf("expected no media refs, got %d", len(mediaRefs))
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractInboundContent_AudioNoURLFallback(t *testing.T) {
|
||||
ch := &MatrixChannel{}
|
||||
msg := &event.MessageEventContent{
|
||||
MsgType: event.MsgAudio,
|
||||
FileName: "voice.ogg",
|
||||
Body: "please transcribe",
|
||||
}
|
||||
|
||||
content, mediaRefs, ok := ch.extractInboundContent(context.Background(), msg, "matrix:room:event")
|
||||
if !ok {
|
||||
t.Fatal("expected ok for audio fallback")
|
||||
}
|
||||
if content != "please transcribe\n[audio: voice.ogg]" {
|
||||
t.Fatalf("unexpected content: %q", content)
|
||||
}
|
||||
if len(mediaRefs) != 0 {
|
||||
t.Fatalf("expected no media refs, got %d", len(mediaRefs))
|
||||
}
|
||||
}
|
||||
|
||||
func TestMatrixOutboundMsgType(t *testing.T) {
|
||||
cases := []struct {
|
||||
name string
|
||||
partType string
|
||||
filename string
|
||||
contentType string
|
||||
want event.MessageType
|
||||
}{
|
||||
{name: "explicit image", partType: "image", want: event.MsgImage},
|
||||
{name: "explicit audio", partType: "audio", want: event.MsgAudio},
|
||||
{name: "mime fallback video", contentType: "video/mp4", want: event.MsgVideo},
|
||||
{name: "extension fallback audio", filename: "voice.ogg", want: event.MsgAudio},
|
||||
{name: "unknown defaults file", filename: "report.txt", want: event.MsgFile},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
if got := matrixOutboundMsgType(tc.partType, tc.filename, tc.contentType); got != tc.want {
|
||||
t.Fatalf("%s: got=%q want=%q", tc.name, got, tc.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestMatrixOutboundContent(t *testing.T) {
|
||||
content := matrixOutboundContent(
|
||||
"please review",
|
||||
"voice.ogg",
|
||||
event.MsgAudio,
|
||||
"audio/ogg",
|
||||
1234,
|
||||
id.ContentURIString("mxc://matrix.org/abc"),
|
||||
)
|
||||
if content.Body != "please review" {
|
||||
t.Fatalf("unexpected body: %q", content.Body)
|
||||
}
|
||||
if content.FileName != "voice.ogg" {
|
||||
t.Fatalf("unexpected filename: %q", content.FileName)
|
||||
}
|
||||
if content.Info == nil || content.Info.MimeType != "audio/ogg" {
|
||||
t.Fatalf("unexpected content type: %+v", content.Info)
|
||||
}
|
||||
if content.Info == nil || content.Info.Size != 1234 {
|
||||
t.Fatalf("unexpected size: %+v", content.Info)
|
||||
}
|
||||
|
||||
noCaption := matrixOutboundContent(
|
||||
"",
|
||||
"image.png",
|
||||
event.MsgImage,
|
||||
"image/png",
|
||||
0,
|
||||
id.ContentURIString("mxc://matrix.org/def"),
|
||||
)
|
||||
if noCaption.Body != "image.png" {
|
||||
t.Fatalf("unexpected fallback body: %q", noCaption.Body)
|
||||
}
|
||||
}
|
||||
+362
-29
@@ -3,7 +3,10 @@ package qq
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"regexp"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/tencent-connect/botgo"
|
||||
@@ -20,6 +23,14 @@ import (
|
||||
"github.com/sipeed/picoclaw/pkg/logger"
|
||||
)
|
||||
|
||||
// Tuning knobs for inbound message de-duplication and the typing indicator.
const (
	// dedupTTL is how long a message ID counts as "already seen".
	dedupTTL = 5 * time.Minute
	// dedupInterval is the period of the janitor that drops expired IDs.
	dedupInterval = time.Minute
	// dedupMaxSize is the hard cap on dedup map entries.
	dedupMaxSize = 10_000

	// typingResend is the delay between repeated typing notifications; it is
	// shorter than typingSeconds so a new notification is sent before the
	// previously requested duration has elapsed.
	typingResend = 8 * time.Second
	// typingSeconds is the duration, in seconds, requested per notification.
	typingSeconds = 10
)
|
||||
|
||||
type QQChannel struct {
|
||||
*channels.BaseChannel
|
||||
config config.QQConfig
|
||||
@@ -28,20 +39,37 @@ type QQChannel struct {
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
sessionManager botgo.SessionManager
|
||||
processedIDs map[string]bool
|
||||
mu sync.RWMutex
|
||||
|
||||
// Chat routing: track whether a chatID is group or direct.
|
||||
chatType sync.Map // chatID → "group" | "direct"
|
||||
|
||||
// Passive reply: store last inbound message ID per chat.
|
||||
lastMsgID sync.Map // chatID → string
|
||||
|
||||
// msg_seq: per-chat atomic counter for multi-part replies.
|
||||
msgSeqCounters sync.Map // chatID → *atomic.Uint64
|
||||
|
||||
// Time-based dedup replacing the unbounded map.
|
||||
dedup map[string]time.Time
|
||||
muDedup sync.Mutex
|
||||
|
||||
// done is closed on Stop to shut down the dedup janitor.
|
||||
done chan struct{}
|
||||
stopOnce sync.Once
|
||||
}
|
||||
|
||||
func NewQQChannel(cfg config.QQConfig, messageBus *bus.MessageBus) (*QQChannel, error) {
|
||||
base := channels.NewBaseChannel("qq", cfg, messageBus, cfg.AllowFrom,
|
||||
channels.WithMaxMessageLength(cfg.MaxMessageLength),
|
||||
channels.WithGroupTrigger(cfg.GroupTrigger),
|
||||
channels.WithReasoningChannelID(cfg.ReasoningChannelID),
|
||||
)
|
||||
|
||||
return &QQChannel{
|
||||
BaseChannel: base,
|
||||
config: cfg,
|
||||
processedIDs: make(map[string]bool),
|
||||
BaseChannel: base,
|
||||
config: cfg,
|
||||
dedup: make(map[string]time.Time),
|
||||
done: make(chan struct{}),
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -52,6 +80,10 @@ func (c *QQChannel) Start(ctx context.Context) error {
|
||||
|
||||
logger.InfoC("qq", "Starting QQ bot (WebSocket mode)")
|
||||
|
||||
// Reinitialize shutdown signal for clean restart.
|
||||
c.done = make(chan struct{})
|
||||
c.stopOnce = sync.Once{}
|
||||
|
||||
// create token source
|
||||
credentials := &token.QQBotCredentials{
|
||||
AppID: c.config.AppID,
|
||||
@@ -99,6 +131,15 @@ func (c *QQChannel) Start(ctx context.Context) error {
|
||||
}
|
||||
}()
|
||||
|
||||
// start dedup janitor goroutine
|
||||
go c.dedupJanitor()
|
||||
|
||||
// Pre-register reasoning_channel_id as group chat if configured,
|
||||
// so outbound-only destinations are routed correctly.
|
||||
if c.config.ReasoningChannelID != "" {
|
||||
c.chatType.Store(c.config.ReasoningChannelID, "group")
|
||||
}
|
||||
|
||||
c.SetRunning(true)
|
||||
logger.InfoC("qq", "QQ bot started successfully")
|
||||
|
||||
@@ -109,6 +150,9 @@ func (c *QQChannel) Stop(ctx context.Context) error {
|
||||
logger.InfoC("qq", "Stopping QQ bot")
|
||||
c.SetRunning(false)
|
||||
|
||||
// Signal the dedup janitor to stop (idempotent).
|
||||
c.stopOnce.Do(func() { close(c.done) })
|
||||
|
||||
if c.cancel != nil {
|
||||
c.cancel()
|
||||
}
|
||||
@@ -116,21 +160,82 @@ func (c *QQChannel) Stop(ctx context.Context) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// getChatKind returns the chat type for a given chatID ("group" or "direct").
|
||||
// Unknown chatIDs default to "group" and log a warning, since QQ group IDs are
|
||||
// more common as outbound-only destinations (e.g. reasoning_channel_id).
|
||||
func (c *QQChannel) getChatKind(chatID string) string {
|
||||
if v, ok := c.chatType.Load(chatID); ok {
|
||||
if k, ok := v.(string); ok {
|
||||
return k
|
||||
}
|
||||
}
|
||||
logger.DebugCF("qq", "Unknown chat type for chatID, defaulting to group", map[string]any{
|
||||
"chat_id": chatID,
|
||||
})
|
||||
return "group"
|
||||
}
|
||||
|
||||
func (c *QQChannel) Send(ctx context.Context, msg bus.OutboundMessage) error {
|
||||
if !c.IsRunning() {
|
||||
return channels.ErrNotRunning
|
||||
}
|
||||
|
||||
// construct message
|
||||
chatKind := c.getChatKind(msg.ChatID)
|
||||
|
||||
// Build message with content.
|
||||
msgToCreate := &dto.MessageToCreate{
|
||||
Content: msg.Content,
|
||||
MsgType: dto.TextMsg,
|
||||
}
|
||||
|
||||
// Use Markdown message type if enabled in config.
|
||||
if c.config.SendMarkdown {
|
||||
msgToCreate.MsgType = dto.MarkdownMsg
|
||||
msgToCreate.Markdown = &dto.Markdown{
|
||||
Content: msg.Content,
|
||||
}
|
||||
// Clear plain content to avoid sending duplicate text.
|
||||
msgToCreate.Content = ""
|
||||
}
|
||||
|
||||
// Attach passive reply msg_id and msg_seq if available.
|
||||
if v, ok := c.lastMsgID.Load(msg.ChatID); ok {
|
||||
if msgID, ok := v.(string); ok && msgID != "" {
|
||||
msgToCreate.MsgID = msgID
|
||||
|
||||
// Increment msg_seq atomically for multi-part replies.
|
||||
if counterVal, ok := c.msgSeqCounters.Load(msg.ChatID); ok {
|
||||
if counter, ok := counterVal.(*atomic.Uint64); ok {
|
||||
seq := counter.Add(1)
|
||||
msgToCreate.MsgSeq = uint32(seq)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Sanitize URLs in group messages to avoid QQ's URL blacklist rejection.
|
||||
if chatKind == "group" {
|
||||
if msgToCreate.Content != "" {
|
||||
msgToCreate.Content = sanitizeURLs(msgToCreate.Content)
|
||||
}
|
||||
if msgToCreate.Markdown != nil && msgToCreate.Markdown.Content != "" {
|
||||
msgToCreate.Markdown.Content = sanitizeURLs(msgToCreate.Markdown.Content)
|
||||
}
|
||||
}
|
||||
|
||||
// Route to group or C2C.
|
||||
var err error
|
||||
if chatKind == "group" {
|
||||
_, err = c.api.PostGroupMessage(ctx, msg.ChatID, msgToCreate)
|
||||
} else {
|
||||
_, err = c.api.PostC2CMessage(ctx, msg.ChatID, msgToCreate)
|
||||
}
|
||||
|
||||
// send C2C message
|
||||
_, err := c.api.PostC2CMessage(ctx, msg.ChatID, msgToCreate)
|
||||
if err != nil {
|
||||
logger.ErrorCF("qq", "Failed to send C2C message", map[string]any{
|
||||
"error": err.Error(),
|
||||
logger.ErrorCF("qq", "Failed to send message", map[string]any{
|
||||
"chat_id": msg.ChatID,
|
||||
"chat_kind": chatKind,
|
||||
"error": err.Error(),
|
||||
})
|
||||
return fmt.Errorf("qq send: %w", channels.ErrTemporary)
|
||||
}
|
||||
@@ -138,7 +243,150 @@ func (c *QQChannel) Send(ctx context.Context, msg bus.OutboundMessage) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// handleC2CMessage handles QQ private messages
|
||||
// StartTyping implements channels.TypingCapable.
|
||||
// It sends an InputNotify (msg_type=6) immediately and re-sends every 8 seconds.
|
||||
// The returned stop function is idempotent and cancels the goroutine.
|
||||
func (c *QQChannel) StartTyping(ctx context.Context, chatID string) (func(), error) {
|
||||
// We need a stored msg_id for passive InputNotify; skip if none available.
|
||||
v, ok := c.lastMsgID.Load(chatID)
|
||||
if !ok {
|
||||
return func() {}, nil
|
||||
}
|
||||
msgID, ok := v.(string)
|
||||
if !ok || msgID == "" {
|
||||
return func() {}, nil
|
||||
}
|
||||
|
||||
chatKind := c.getChatKind(chatID)
|
||||
|
||||
sendTyping := func(sendCtx context.Context) {
|
||||
typingMsg := &dto.MessageToCreate{
|
||||
MsgType: dto.InputNotifyMsg,
|
||||
MsgID: msgID,
|
||||
InputNotify: &dto.InputNotify{
|
||||
InputType: 1,
|
||||
InputSecond: typingSeconds,
|
||||
},
|
||||
}
|
||||
|
||||
var err error
|
||||
if chatKind == "group" {
|
||||
_, err = c.api.PostGroupMessage(sendCtx, chatID, typingMsg)
|
||||
} else {
|
||||
_, err = c.api.PostC2CMessage(sendCtx, chatID, typingMsg)
|
||||
}
|
||||
if err != nil {
|
||||
logger.DebugCF("qq", "Failed to send typing indicator", map[string]any{
|
||||
"chat_id": chatID,
|
||||
"error": err.Error(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Send immediately.
|
||||
sendTyping(c.ctx)
|
||||
|
||||
typingCtx, cancel := context.WithCancel(c.ctx)
|
||||
go func() {
|
||||
ticker := time.NewTicker(typingResend)
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-typingCtx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
sendTyping(typingCtx)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
return cancel, nil
|
||||
}
|
||||
|
||||
// SendMedia implements the channels.MediaSender interface.
|
||||
// QQ RichMediaMessage requires an HTTP/HTTPS URL — local file paths are not supported.
|
||||
// If part.Ref is already an http(s) URL it is used directly; otherwise we try
|
||||
// the media store, and skip with a warning if the resolved path is not an HTTP URL.
|
||||
func (c *QQChannel) SendMedia(ctx context.Context, msg bus.OutboundMediaMessage) error {
|
||||
if !c.IsRunning() {
|
||||
return channels.ErrNotRunning
|
||||
}
|
||||
|
||||
chatKind := c.getChatKind(msg.ChatID)
|
||||
|
||||
for _, part := range msg.Parts {
|
||||
// If the ref is already an HTTP(S) URL, use it directly.
|
||||
mediaURL := part.Ref
|
||||
if !isHTTPURL(mediaURL) {
|
||||
// Try resolving through media store.
|
||||
store := c.GetMediaStore()
|
||||
if store == nil {
|
||||
logger.WarnCF("qq", "QQ media requires HTTP/HTTPS URL, no media store available", map[string]any{
|
||||
"ref": part.Ref,
|
||||
})
|
||||
continue
|
||||
}
|
||||
|
||||
resolved, err := store.Resolve(part.Ref)
|
||||
if err != nil {
|
||||
logger.ErrorCF("qq", "Failed to resolve media ref", map[string]any{
|
||||
"ref": part.Ref,
|
||||
"error": err.Error(),
|
||||
})
|
||||
continue
|
||||
}
|
||||
|
||||
if !isHTTPURL(resolved) {
|
||||
logger.WarnCF("qq", "QQ media requires HTTP/HTTPS URL, local files not supported", map[string]any{
|
||||
"ref": part.Ref,
|
||||
"resolved": resolved,
|
||||
})
|
||||
continue
|
||||
}
|
||||
|
||||
mediaURL = resolved
|
||||
}
|
||||
|
||||
// Map part type to QQ file type: 1=image, 2=video, 3=audio, 4=file.
|
||||
var fileType uint64
|
||||
switch part.Type {
|
||||
case "image":
|
||||
fileType = 1
|
||||
case "video":
|
||||
fileType = 2
|
||||
case "audio":
|
||||
fileType = 3
|
||||
default:
|
||||
fileType = 4 // file
|
||||
}
|
||||
|
||||
richMedia := &dto.RichMediaMessage{
|
||||
FileType: fileType,
|
||||
URL: mediaURL,
|
||||
SrvSendMsg: true,
|
||||
}
|
||||
|
||||
var sendErr error
|
||||
if chatKind == "group" {
|
||||
_, sendErr = c.api.PostGroupMessage(ctx, msg.ChatID, richMedia)
|
||||
} else {
|
||||
_, sendErr = c.api.PostC2CMessage(ctx, msg.ChatID, richMedia)
|
||||
}
|
||||
|
||||
if sendErr != nil {
|
||||
logger.ErrorCF("qq", "Failed to send media", map[string]any{
|
||||
"type": part.Type,
|
||||
"chat_id": msg.ChatID,
|
||||
"error": sendErr.Error(),
|
||||
})
|
||||
return fmt.Errorf("qq send media: %w", channels.ErrTemporary)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// handleC2CMessage handles QQ private messages.
|
||||
func (c *QQChannel) handleC2CMessage() event.C2CMessageEventHandler {
|
||||
return func(event *dto.WSPayload, data *dto.WSC2CMessageData) error {
|
||||
// deduplication check
|
||||
@@ -167,7 +415,13 @@ func (c *QQChannel) handleC2CMessage() event.C2CMessageEventHandler {
|
||||
"length": len(content),
|
||||
})
|
||||
|
||||
// 转发到消息总线
|
||||
// Store chat routing context.
|
||||
c.chatType.Store(senderID, "direct")
|
||||
c.lastMsgID.Store(senderID, data.ID)
|
||||
|
||||
// Reset msg_seq counter for new inbound message.
|
||||
c.msgSeqCounters.Store(senderID, new(atomic.Uint64))
|
||||
|
||||
metadata := map[string]string{}
|
||||
|
||||
sender := bus.SenderInfo{
|
||||
@@ -195,7 +449,7 @@ func (c *QQChannel) handleC2CMessage() event.C2CMessageEventHandler {
|
||||
}
|
||||
}
|
||||
|
||||
// handleGroupATMessage handles QQ group @ messages
|
||||
// handleGroupATMessage handles QQ group @ messages.
|
||||
func (c *QQChannel) handleGroupATMessage() event.GroupATMessageEventHandler {
|
||||
return func(event *dto.WSPayload, data *dto.WSGroupATMessageData) error {
|
||||
// deduplication check
|
||||
@@ -232,7 +486,13 @@ func (c *QQChannel) handleGroupATMessage() event.GroupATMessageEventHandler {
|
||||
"length": len(content),
|
||||
})
|
||||
|
||||
// 转发到消息总线(使用 GroupID 作为 ChatID)
|
||||
// Store chat routing context using GroupID as chatID.
|
||||
c.chatType.Store(data.GroupID, "group")
|
||||
c.lastMsgID.Store(data.GroupID, data.ID)
|
||||
|
||||
// Reset msg_seq counter for new inbound message.
|
||||
c.msgSeqCounters.Store(data.GroupID, new(atomic.Uint64))
|
||||
|
||||
metadata := map[string]string{
|
||||
"group_id": data.GroupID,
|
||||
}
|
||||
@@ -262,29 +522,102 @@ func (c *QQChannel) handleGroupATMessage() event.GroupATMessageEventHandler {
|
||||
}
|
||||
}
|
||||
|
||||
// isDuplicate 检查消息是否重复
|
||||
// isDuplicate checks whether a message has been seen within the TTL window.
|
||||
// It also enforces a hard cap on map size by evicting oldest entries.
|
||||
func (c *QQChannel) isDuplicate(messageID string) bool {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
c.muDedup.Lock()
|
||||
defer c.muDedup.Unlock()
|
||||
|
||||
if c.processedIDs[messageID] {
|
||||
if ts, exists := c.dedup[messageID]; exists && time.Since(ts) < dedupTTL {
|
||||
return true
|
||||
}
|
||||
|
||||
c.processedIDs[messageID] = true
|
||||
|
||||
// 简单清理:限制 map 大小
|
||||
if len(c.processedIDs) > 10000 {
|
||||
// 清空一半
|
||||
count := 0
|
||||
for id := range c.processedIDs {
|
||||
if count >= 5000 {
|
||||
break
|
||||
// Enforce hard cap: evict oldest entries when at capacity.
|
||||
if len(c.dedup) >= dedupMaxSize {
|
||||
var oldestID string
|
||||
var oldestTS time.Time
|
||||
for id, ts := range c.dedup {
|
||||
if oldestID == "" || ts.Before(oldestTS) {
|
||||
oldestID = id
|
||||
oldestTS = ts
|
||||
}
|
||||
delete(c.processedIDs, id)
|
||||
count++
|
||||
}
|
||||
if oldestID != "" {
|
||||
delete(c.dedup, oldestID)
|
||||
}
|
||||
}
|
||||
|
||||
c.dedup[messageID] = time.Now()
|
||||
return false
|
||||
}
|
||||
|
||||
// dedupJanitor periodically evicts expired entries from the dedup map.
|
||||
func (c *QQChannel) dedupJanitor() {
|
||||
ticker := time.NewTicker(dedupInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-c.done:
|
||||
return
|
||||
case <-ticker.C:
|
||||
// Collect expired keys under read-like scan.
|
||||
c.muDedup.Lock()
|
||||
now := time.Now()
|
||||
var expired []string
|
||||
for id, ts := range c.dedup {
|
||||
if now.Sub(ts) >= dedupTTL {
|
||||
expired = append(expired, id)
|
||||
}
|
||||
}
|
||||
for _, id := range expired {
|
||||
delete(c.dedup, id)
|
||||
}
|
||||
c.muDedup.Unlock()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// isHTTPURL reports whether s begins with an "http://" or "https://" scheme.
// The comparison ignores ASCII case ("HTTP://", "Https://", ...) because URL
// schemes are case-insensitive; this also keeps the check in line with
// urlPattern, which matches schemes case-insensitively.
func isHTTPURL(s string) bool {
	for _, scheme := range [...]string{"http://", "https://"} {
		// Length guard first so the prefix slice below cannot go out of range.
		if len(s) >= len(scheme) && strings.EqualFold(s[:len(scheme)], scheme) {
			return true
		}
	}
	return false
}
|
||||
|
||||
// urlPattern matches URLs with an explicit http(s):// scheme.
// Only scheme-prefixed URLs are matched to avoid false positives on bare text
// like version numbers (e.g., "1.2.3") or domain-like fragments.
var urlPattern = regexp.MustCompile(
	`(?i)` +
		`https?://` + // required scheme
		`(?:[a-zA-Z0-9](?:[a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?\.)+` + // domain parts
		`[a-zA-Z]{2,}` + // TLD
		`(?:[/?#]\S*)?`, // optional path/query/fragment
)

// sanitizeURLs rewrites every URL matched by urlPattern so that the dots in
// its domain become "。" (U+3002, ideographic full stop), which keeps QQ's
// URL blacklist from rejecting the message. The scheme and anything from the
// first '/', '?' or '#' onward (path, query, fragment) are left untouched, as
// is all text outside matched URLs.
func sanitizeURLs(text string) string {
	return urlPattern.ReplaceAllStringFunc(text, func(match string) string {
		scheme, rest, ok := strings.Cut(match, "://")
		if !ok {
			// urlPattern always includes a scheme; never slice on a miss.
			return match
		}

		// The domain runs up to the first path/query/fragment delimiter.
		host, tail := rest, ""
		if end := strings.IndexAny(rest, "/?#"); end >= 0 {
			host, tail = rest[:end], rest[end:]
		}

		return scheme + "://" + strings.ReplaceAll(host, ".", "。") + tail
	})
}
|
||||
|
||||
@@ -168,7 +168,7 @@ func (c *TelegramChannel) Send(ctx context.Context, msg bus.OutboundMessage) err
|
||||
return channels.ErrNotRunning
|
||||
}
|
||||
|
||||
chatID, err := parseChatID(msg.ChatID)
|
||||
chatID, threadID, err := parseTelegramChatID(msg.ChatID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid chat ID %s: %w", msg.ChatID, channels.ErrSendFailed)
|
||||
}
|
||||
@@ -201,7 +201,7 @@ func (c *TelegramChannel) Send(ctx context.Context, msg bus.OutboundMessage) err
|
||||
continue
|
||||
}
|
||||
|
||||
if err := c.sendHTMLChunk(ctx, chatID, htmlContent, chunk, replyToID); err != nil {
|
||||
if err := c.sendHTMLChunk(ctx, chatID, threadID, htmlContent, chunk, replyToID); err != nil {
|
||||
return err
|
||||
}
|
||||
// Only the first chunk should be a reply; subsequent chunks are normal messages.
|
||||
@@ -214,12 +214,11 @@ func (c *TelegramChannel) Send(ctx context.Context, msg bus.OutboundMessage) err
|
||||
// sendHTMLChunk sends a single HTML message, falling back to the original
|
||||
// markdown as plain text on parse failure so users never see raw HTML tags.
|
||||
func (c *TelegramChannel) sendHTMLChunk(
|
||||
ctx context.Context,
|
||||
chatID int64,
|
||||
htmlContent, mdFallback, replyToID string,
|
||||
ctx context.Context, chatID int64, threadID int, htmlContent, mdFallback string, replyToID string
|
||||
) error {
|
||||
tgMsg := tu.Message(tu.ID(chatID), htmlContent)
|
||||
tgMsg.ParseMode = telego.ModeHTML
|
||||
tgMsg.MessageThreadID = threadID
|
||||
|
||||
if replyToID != "" {
|
||||
if mid, parseErr := strconv.Atoi(replyToID); parseErr == nil {
|
||||
@@ -247,13 +246,16 @@ func (c *TelegramChannel) sendHTMLChunk(
|
||||
// (Telegram's typing indicator expires after ~5s) in a background goroutine.
|
||||
// The returned stop function is idempotent and cancels the goroutine.
|
||||
func (c *TelegramChannel) StartTyping(ctx context.Context, chatID string) (func(), error) {
|
||||
cid, err := parseChatID(chatID)
|
||||
cid, threadID, err := parseTelegramChatID(chatID)
|
||||
if err != nil {
|
||||
return func() {}, err
|
||||
}
|
||||
|
||||
action := tu.ChatAction(tu.ID(cid), telego.ChatActionTyping)
|
||||
action.MessageThreadID = threadID
|
||||
|
||||
// Send the first typing action immediately
|
||||
_ = c.bot.SendChatAction(ctx, tu.ChatAction(tu.ID(cid), telego.ChatActionTyping))
|
||||
_ = c.bot.SendChatAction(ctx, action)
|
||||
|
||||
typingCtx, cancel := context.WithCancel(ctx)
|
||||
go func() {
|
||||
@@ -264,7 +266,9 @@ func (c *TelegramChannel) StartTyping(ctx context.Context, chatID string) (func(
|
||||
case <-typingCtx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
_ = c.bot.SendChatAction(typingCtx, tu.ChatAction(tu.ID(cid), telego.ChatActionTyping))
|
||||
a := tu.ChatAction(tu.ID(cid), telego.ChatActionTyping)
|
||||
a.MessageThreadID = threadID
|
||||
_ = c.bot.SendChatAction(typingCtx, a)
|
||||
}
|
||||
}
|
||||
}()
|
||||
@@ -274,7 +278,7 @@ func (c *TelegramChannel) StartTyping(ctx context.Context, chatID string) (func(
|
||||
|
||||
// EditMessage implements channels.MessageEditor.
|
||||
func (c *TelegramChannel) EditMessage(ctx context.Context, chatID string, messageID string, content string) error {
|
||||
cid, err := parseChatID(chatID)
|
||||
cid, _, err := parseTelegramChatID(chatID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -303,12 +307,14 @@ func (c *TelegramChannel) SendPlaceholder(ctx context.Context, chatID string) (s
|
||||
text = "Thinking... 💭"
|
||||
}
|
||||
|
||||
cid, err := parseChatID(chatID)
|
||||
cid, threadID, err := parseTelegramChatID(chatID)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
pMsg, err := c.bot.SendMessage(ctx, tu.Message(tu.ID(cid), text))
|
||||
phMsg := tu.Message(tu.ID(cid), text)
|
||||
phMsg.MessageThreadID = threadID
|
||||
pMsg, err := c.bot.SendMessage(ctx, phMsg)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
@@ -322,7 +328,7 @@ func (c *TelegramChannel) SendMedia(ctx context.Context, msg bus.OutboundMediaMe
|
||||
return channels.ErrNotRunning
|
||||
}
|
||||
|
||||
chatID, err := parseChatID(msg.ChatID)
|
||||
chatID, threadID, err := parseTelegramChatID(msg.ChatID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid chat ID %s: %w", msg.ChatID, channels.ErrSendFailed)
|
||||
}
|
||||
@@ -354,30 +360,34 @@ func (c *TelegramChannel) SendMedia(ctx context.Context, msg bus.OutboundMediaMe
|
||||
switch part.Type {
|
||||
case "image":
|
||||
params := &telego.SendPhotoParams{
|
||||
ChatID: tu.ID(chatID),
|
||||
Photo: telego.InputFile{File: file},
|
||||
Caption: part.Caption,
|
||||
ChatID: tu.ID(chatID),
|
||||
MessageThreadID: threadID,
|
||||
Photo: telego.InputFile{File: file},
|
||||
Caption: part.Caption,
|
||||
}
|
||||
_, err = c.bot.SendPhoto(ctx, params)
|
||||
case "audio":
|
||||
params := &telego.SendAudioParams{
|
||||
ChatID: tu.ID(chatID),
|
||||
Audio: telego.InputFile{File: file},
|
||||
Caption: part.Caption,
|
||||
ChatID: tu.ID(chatID),
|
||||
MessageThreadID: threadID,
|
||||
Audio: telego.InputFile{File: file},
|
||||
Caption: part.Caption,
|
||||
}
|
||||
_, err = c.bot.SendAudio(ctx, params)
|
||||
case "video":
|
||||
params := &telego.SendVideoParams{
|
||||
ChatID: tu.ID(chatID),
|
||||
Video: telego.InputFile{File: file},
|
||||
Caption: part.Caption,
|
||||
ChatID: tu.ID(chatID),
|
||||
MessageThreadID: threadID,
|
||||
Video: telego.InputFile{File: file},
|
||||
Caption: part.Caption,
|
||||
}
|
||||
_, err = c.bot.SendVideo(ctx, params)
|
||||
default: // "file" or unknown types
|
||||
params := &telego.SendDocumentParams{
|
||||
ChatID: tu.ID(chatID),
|
||||
Document: telego.InputFile{File: file},
|
||||
Caption: part.Caption,
|
||||
ChatID: tu.ID(chatID),
|
||||
MessageThreadID: threadID,
|
||||
Document: telego.InputFile{File: file},
|
||||
Caption: part.Caption,
|
||||
}
|
||||
_, err = c.bot.SendDocument(ctx, params)
|
||||
}
|
||||
@@ -521,19 +531,28 @@ func (c *TelegramChannel) handleMessage(ctx context.Context, message *telego.Mes
|
||||
content = cleaned
|
||||
}
|
||||
|
||||
// For forum topics, embed the thread ID as "chatID/threadID" so replies
|
||||
// route to the correct topic and each topic gets its own session.
|
||||
// Only forum groups (IsForum) are handled; regular group reply threads
|
||||
// must share one session per group.
|
||||
compositeChatID := fmt.Sprintf("%d", chatID)
|
||||
threadID := message.MessageThreadID
|
||||
if message.Chat.IsForum && threadID != 0 {
|
||||
compositeChatID = fmt.Sprintf("%d/%d", chatID, threadID)
|
||||
}
|
||||
|
||||
logger.DebugCF("telegram", "Received message", map[string]any{
|
||||
"sender_id": sender.CanonicalID,
|
||||
"chat_id": fmt.Sprintf("%d", chatID),
|
||||
"chat_id": compositeChatID,
|
||||
"thread_id": threadID,
|
||||
"preview": utils.Truncate(content, 50),
|
||||
})
|
||||
|
||||
// Placeholder is now auto-triggered by BaseChannel.HandleMessage via PlaceholderCapable
|
||||
|
||||
peerKind := "direct"
|
||||
peerID := fmt.Sprintf("%d", user.ID)
|
||||
if message.Chat.Type != "private" {
|
||||
peerKind = "group"
|
||||
peerID = fmt.Sprintf("%d", chatID)
|
||||
peerID = compositeChatID
|
||||
}
|
||||
|
||||
peer := bus.Peer{Kind: peerKind, ID: peerID}
|
||||
@@ -546,11 +565,17 @@ func (c *TelegramChannel) handleMessage(ctx context.Context, message *telego.Mes
|
||||
"is_group": fmt.Sprintf("%t", message.Chat.Type != "private"),
|
||||
}
|
||||
|
||||
// Set parent_peer metadata for per-topic agent binding.
|
||||
if message.Chat.IsForum && threadID != 0 {
|
||||
metadata["parent_peer_kind"] = "topic"
|
||||
metadata["parent_peer_id"] = fmt.Sprintf("%d", threadID)
|
||||
}
|
||||
|
||||
c.HandleMessage(c.ctx,
|
||||
peer,
|
||||
messageID,
|
||||
platformID,
|
||||
fmt.Sprintf("%d", chatID),
|
||||
compositeChatID,
|
||||
content,
|
||||
mediaPaths,
|
||||
metadata,
|
||||
@@ -598,10 +623,23 @@ func (c *TelegramChannel) downloadFile(ctx context.Context, fileID, ext string)
|
||||
return c.downloadFileWithInfo(file, ext)
|
||||
}
|
||||
|
||||
func parseChatID(chatIDStr string) (int64, error) {
|
||||
var id int64
|
||||
_, err := fmt.Sscanf(chatIDStr, "%d", &id)
|
||||
return id, err
|
||||
// parseTelegramChatID splits a chat identifier of the form "chatID" or
// "chatID/threadID" into its numeric components.
//
// The thread ID is 0 when no "/" is present (non-forum messages). On any parse
// failure both numbers are returned as 0 together with an error that wraps the
// underlying strconv error and names the offending input.
func parseTelegramChatID(chatID string) (int64, int, error) {
	chatPart, threadPart, hasThread := strings.Cut(chatID, "/")

	cid, err := strconv.ParseInt(chatPart, 10, 64)
	if err != nil {
		// ParseInt may return a clamped value on range errors; do not leak it.
		return 0, 0, fmt.Errorf("invalid chat ID %q: %w", chatID, err)
	}
	if !hasThread {
		return cid, 0, nil
	}

	tid, err := strconv.Atoi(threadPart)
	if err != nil {
		return 0, 0, fmt.Errorf("invalid thread ID in chat ID %q: %w", chatID, err)
	}
	return cid, tid, nil
}
|
||||
|
||||
func markdownToTelegramHTML(text string) string {
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"errors"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/mymmrac/telego"
|
||||
ta "github.com/mymmrac/telego/telegoapi"
|
||||
@@ -271,3 +272,191 @@ func TestSend_InvalidChatID(t *testing.T) {
|
||||
assert.True(t, errors.Is(err, channels.ErrSendFailed), "error should wrap ErrSendFailed")
|
||||
assert.Empty(t, caller.calls)
|
||||
}
|
||||
|
||||
func TestParseTelegramChatID_Plain(t *testing.T) {
|
||||
cid, tid, err := parseTelegramChatID("12345")
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, int64(12345), cid)
|
||||
assert.Equal(t, 0, tid)
|
||||
}
|
||||
|
||||
func TestParseTelegramChatID_NegativeGroup(t *testing.T) {
|
||||
cid, tid, err := parseTelegramChatID("-1001234567890")
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, int64(-1001234567890), cid)
|
||||
assert.Equal(t, 0, tid)
|
||||
}
|
||||
|
||||
func TestParseTelegramChatID_WithThreadID(t *testing.T) {
|
||||
cid, tid, err := parseTelegramChatID("-1001234567890/42")
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, int64(-1001234567890), cid)
|
||||
assert.Equal(t, 42, tid)
|
||||
}
|
||||
|
||||
func TestParseTelegramChatID_GeneralTopic(t *testing.T) {
|
||||
cid, tid, err := parseTelegramChatID("-100123/1")
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, int64(-100123), cid)
|
||||
assert.Equal(t, 1, tid)
|
||||
}
|
||||
|
||||
func TestParseTelegramChatID_Invalid(t *testing.T) {
|
||||
_, _, err := parseTelegramChatID("not-a-number")
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestParseTelegramChatID_InvalidThreadID(t *testing.T) {
|
||||
_, _, err := parseTelegramChatID("-100123/not-a-thread")
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "invalid thread ID")
|
||||
}
|
||||
|
||||
func TestSend_WithForumThreadID(t *testing.T) {
|
||||
caller := &stubCaller{
|
||||
callFn: func(ctx context.Context, url string, data *ta.RequestData) (*ta.Response, error) {
|
||||
return successResponse(t), nil
|
||||
},
|
||||
}
|
||||
ch := newTestChannel(t, caller)
|
||||
|
||||
err := ch.Send(context.Background(), bus.OutboundMessage{
|
||||
ChatID: "-1001234567890/42",
|
||||
Content: "Hello from topic",
|
||||
})
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, caller.calls, 1)
|
||||
}
|
||||
|
||||
func TestHandleMessage_ForumTopic_SetsMetadata(t *testing.T) {
|
||||
messageBus := bus.NewMessageBus()
|
||||
ch := &TelegramChannel{
|
||||
BaseChannel: channels.NewBaseChannel("telegram", nil, messageBus, nil),
|
||||
chatIDs: make(map[string]int64),
|
||||
ctx: context.Background(),
|
||||
}
|
||||
|
||||
msg := &telego.Message{
|
||||
Text: "hello from topic",
|
||||
MessageID: 10,
|
||||
MessageThreadID: 42,
|
||||
Chat: telego.Chat{
|
||||
ID: -1001234567890,
|
||||
Type: "supergroup",
|
||||
IsForum: true,
|
||||
},
|
||||
From: &telego.User{
|
||||
ID: 7,
|
||||
FirstName: "Alice",
|
||||
},
|
||||
}
|
||||
|
||||
err := ch.handleMessage(context.Background(), msg)
|
||||
require.NoError(t, err)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
||||
defer cancel()
|
||||
|
||||
inbound, ok := messageBus.ConsumeInbound(ctx)
|
||||
require.True(t, ok, "expected inbound message")
|
||||
|
||||
// Composite chatID should include thread ID
|
||||
assert.Equal(t, "-1001234567890/42", inbound.ChatID)
|
||||
|
||||
// Peer ID should include thread ID for session key isolation
|
||||
assert.Equal(t, "group", inbound.Peer.Kind)
|
||||
assert.Equal(t, "-1001234567890/42", inbound.Peer.ID)
|
||||
|
||||
// Parent peer metadata should be set for agent binding
|
||||
assert.Equal(t, "topic", inbound.Metadata["parent_peer_kind"])
|
||||
assert.Equal(t, "42", inbound.Metadata["parent_peer_id"])
|
||||
}
|
||||
|
||||
func TestHandleMessage_NoForum_NoThreadMetadata(t *testing.T) {
|
||||
messageBus := bus.NewMessageBus()
|
||||
ch := &TelegramChannel{
|
||||
BaseChannel: channels.NewBaseChannel("telegram", nil, messageBus, nil),
|
||||
chatIDs: make(map[string]int64),
|
||||
ctx: context.Background(),
|
||||
}
|
||||
|
||||
msg := &telego.Message{
|
||||
Text: "regular group message",
|
||||
MessageID: 11,
|
||||
Chat: telego.Chat{
|
||||
ID: -100999,
|
||||
Type: "group",
|
||||
},
|
||||
From: &telego.User{
|
||||
ID: 8,
|
||||
FirstName: "Bob",
|
||||
},
|
||||
}
|
||||
|
||||
err := ch.handleMessage(context.Background(), msg)
|
||||
require.NoError(t, err)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
||||
defer cancel()
|
||||
|
||||
inbound, ok := messageBus.ConsumeInbound(ctx)
|
||||
require.True(t, ok)
|
||||
|
||||
// Plain chatID without thread suffix
|
||||
assert.Equal(t, "-100999", inbound.ChatID)
|
||||
|
||||
// Peer ID should be raw chat ID (no thread suffix)
|
||||
assert.Equal(t, "group", inbound.Peer.Kind)
|
||||
assert.Equal(t, "-100999", inbound.Peer.ID)
|
||||
|
||||
// No parent peer metadata
|
||||
assert.Empty(t, inbound.Metadata["parent_peer_kind"])
|
||||
assert.Empty(t, inbound.Metadata["parent_peer_id"])
|
||||
}
|
||||
|
||||
func TestHandleMessage_ReplyThread_NonForum_NoIsolation(t *testing.T) {
|
||||
messageBus := bus.NewMessageBus()
|
||||
ch := &TelegramChannel{
|
||||
BaseChannel: channels.NewBaseChannel("telegram", nil, messageBus, nil),
|
||||
chatIDs: make(map[string]int64),
|
||||
ctx: context.Background(),
|
||||
}
|
||||
|
||||
// In regular groups, reply threads set MessageThreadID to the original
|
||||
// message ID. This should NOT trigger per-thread session isolation.
|
||||
msg := &telego.Message{
|
||||
Text: "reply in thread",
|
||||
MessageID: 20,
|
||||
MessageThreadID: 15,
|
||||
Chat: telego.Chat{
|
||||
ID: -100999,
|
||||
Type: "supergroup",
|
||||
IsForum: false,
|
||||
},
|
||||
From: &telego.User{
|
||||
ID: 9,
|
||||
FirstName: "Carol",
|
||||
},
|
||||
}
|
||||
|
||||
err := ch.handleMessage(context.Background(), msg)
|
||||
require.NoError(t, err)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
||||
defer cancel()
|
||||
|
||||
inbound, ok := messageBus.ConsumeInbound(ctx)
|
||||
require.True(t, ok)
|
||||
|
||||
// chatID should NOT include thread suffix for non-forum groups
|
||||
assert.Equal(t, "-100999", inbound.ChatID)
|
||||
|
||||
// Peer ID should be raw chat ID (shared session for whole group)
|
||||
assert.Equal(t, "group", inbound.Peer.Kind)
|
||||
assert.Equal(t, "-100999", inbound.Peer.ID)
|
||||
|
||||
// No parent peer metadata
|
||||
assert.Empty(t, inbound.Metadata["parent_peer_kind"])
|
||||
assert.Empty(t, inbound.Metadata["parent_peer_id"])
|
||||
}
|
||||
|
||||
@@ -12,5 +12,6 @@ func BuiltinDefinitions() []Definition {
|
||||
listCommand(),
|
||||
switchCommand(),
|
||||
checkCommand(),
|
||||
clearCommand(),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,20 @@
|
||||
package commands
|
||||
|
||||
import "context"
|
||||
|
||||
func clearCommand() Definition {
|
||||
return Definition{
|
||||
Name: "clear",
|
||||
Description: "Clear the chat history",
|
||||
Usage: "/clear",
|
||||
Handler: func(_ context.Context, req Request, rt *Runtime) error {
|
||||
if rt == nil || rt.ClearHistory == nil {
|
||||
return req.Reply(unavailableMsg)
|
||||
}
|
||||
if err := rt.ClearHistory(); err != nil {
|
||||
return req.Reply("Failed to clear chat history: " + err.Error())
|
||||
}
|
||||
return req.Reply("Chat history cleared!")
|
||||
},
|
||||
}
|
||||
}
|
||||
@@ -13,4 +13,5 @@ type Runtime struct {
|
||||
GetEnabledChannels func() []string
|
||||
SwitchModel func(value string) (oldModel string, err error)
|
||||
SwitchChannel func(value string) error
|
||||
ClearHistory func() error
|
||||
}
|
||||
|
||||
+94
-27
@@ -4,6 +4,7 @@ import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/caarlos0/env/v11"
|
||||
@@ -58,7 +59,16 @@ type Config struct {
|
||||
Tools ToolsConfig `json:"tools"`
|
||||
Heartbeat HeartbeatConfig `json:"heartbeat"`
|
||||
Devices DevicesConfig `json:"devices"`
|
||||
Voice VoiceConfig `json:"voice"`
|
||||
// BuildInfo contains build-time version information
|
||||
BuildInfo BuildInfo `json:"build_info,omitempty"`
|
||||
}
|
||||
|
||||
// BuildInfo contains build-time version information.
type BuildInfo struct {
	// Version is the release version string.
	Version string `json:"version"`
	// GitCommit is the commit the binary was built from.
	GitCommit string `json:"git_commit"`
	// BuildTime is the build timestamp.
	BuildTime string `json:"build_time"`
	// GoVersion is the Go toolchain version used for the build.
	GoVersion string `json:"go_version"`
}
|
||||
|
||||
// MarshalJSON implements custom JSON marshaling for Config
|
||||
@@ -226,6 +236,7 @@ type ChannelsConfig struct {
|
||||
QQ QQConfig `json:"qq"`
|
||||
DingTalk DingTalkConfig `json:"dingtalk"`
|
||||
Slack SlackConfig `json:"slack"`
|
||||
Matrix MatrixConfig `json:"matrix"`
|
||||
LINE LINEConfig `json:"line"`
|
||||
OneBot OneBotConfig `json:"onebot"`
|
||||
WeCom WeComConfig `json:"wecom"`
|
||||
@@ -274,15 +285,16 @@ type TelegramConfig struct {
|
||||
}
|
||||
|
||||
type FeishuConfig struct {
|
||||
Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_FEISHU_ENABLED"`
|
||||
AppID string `json:"app_id" env:"PICOCLAW_CHANNELS_FEISHU_APP_ID"`
|
||||
AppSecret string `json:"app_secret" env:"PICOCLAW_CHANNELS_FEISHU_APP_SECRET"`
|
||||
EncryptKey string `json:"encrypt_key" env:"PICOCLAW_CHANNELS_FEISHU_ENCRYPT_KEY"`
|
||||
VerificationToken string `json:"verification_token" env:"PICOCLAW_CHANNELS_FEISHU_VERIFICATION_TOKEN"`
|
||||
AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_FEISHU_ALLOW_FROM"`
|
||||
GroupTrigger GroupTriggerConfig `json:"group_trigger,omitempty"`
|
||||
Placeholder PlaceholderConfig `json:"placeholder,omitempty"`
|
||||
ReasoningChannelID string `json:"reasoning_channel_id" env:"PICOCLAW_CHANNELS_FEISHU_REASONING_CHANNEL_ID"`
|
||||
Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_FEISHU_ENABLED"`
|
||||
AppID string `json:"app_id" env:"PICOCLAW_CHANNELS_FEISHU_APP_ID"`
|
||||
AppSecret string `json:"app_secret" env:"PICOCLAW_CHANNELS_FEISHU_APP_SECRET"`
|
||||
EncryptKey string `json:"encrypt_key" env:"PICOCLAW_CHANNELS_FEISHU_ENCRYPT_KEY"`
|
||||
VerificationToken string `json:"verification_token" env:"PICOCLAW_CHANNELS_FEISHU_VERIFICATION_TOKEN"`
|
||||
AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_FEISHU_ALLOW_FROM"`
|
||||
GroupTrigger GroupTriggerConfig `json:"group_trigger,omitempty"`
|
||||
Placeholder PlaceholderConfig `json:"placeholder,omitempty"`
|
||||
ReasoningChannelID string `json:"reasoning_channel_id" env:"PICOCLAW_CHANNELS_FEISHU_REASONING_CHANNEL_ID"`
|
||||
RandomReactionEmoji FlexibleStringSlice `json:"random_reaction_emoji" env:"PICOCLAW_CHANNELS_FEISHU_RANDOM_REACTION_EMOJI"`
|
||||
}
|
||||
|
||||
type DiscordConfig struct {
|
||||
@@ -311,6 +323,8 @@ type QQConfig struct {
|
||||
AppSecret string `json:"app_secret" env:"PICOCLAW_CHANNELS_QQ_APP_SECRET"`
|
||||
AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_QQ_ALLOW_FROM"`
|
||||
GroupTrigger GroupTriggerConfig `json:"group_trigger,omitempty"`
|
||||
MaxMessageLength int `json:"max_message_length" env:"PICOCLAW_CHANNELS_QQ_MAX_MESSAGE_LENGTH"`
|
||||
SendMarkdown bool `json:"send_markdown" env:"PICOCLAW_CHANNELS_QQ_SEND_MARKDOWN"`
|
||||
ReasoningChannelID string `json:"reasoning_channel_id" env:"PICOCLAW_CHANNELS_QQ_REASONING_CHANNEL_ID"`
|
||||
}
|
||||
|
||||
@@ -334,6 +348,19 @@ type SlackConfig struct {
|
||||
ReasoningChannelID string `json:"reasoning_channel_id" env:"PICOCLAW_CHANNELS_SLACK_REASONING_CHANNEL_ID"`
|
||||
}
|
||||
|
||||
type MatrixConfig struct {
|
||||
Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_MATRIX_ENABLED"`
|
||||
Homeserver string `json:"homeserver" env:"PICOCLAW_CHANNELS_MATRIX_HOMESERVER"`
|
||||
UserID string `json:"user_id" env:"PICOCLAW_CHANNELS_MATRIX_USER_ID"`
|
||||
AccessToken string `json:"access_token" env:"PICOCLAW_CHANNELS_MATRIX_ACCESS_TOKEN"`
|
||||
DeviceID string `json:"device_id,omitempty" env:"PICOCLAW_CHANNELS_MATRIX_DEVICE_ID"`
|
||||
JoinOnInvite bool `json:"join_on_invite" env:"PICOCLAW_CHANNELS_MATRIX_JOIN_ON_INVITE"`
|
||||
AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_MATRIX_ALLOW_FROM"`
|
||||
GroupTrigger GroupTriggerConfig `json:"group_trigger,omitempty"`
|
||||
Placeholder PlaceholderConfig `json:"placeholder,omitempty"`
|
||||
ReasoningChannelID string `json:"reasoning_channel_id" env:"PICOCLAW_CHANNELS_MATRIX_REASONING_CHANNEL_ID"`
|
||||
}
|
||||
|
||||
type LINEConfig struct {
|
||||
Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_LINE_ENABLED"`
|
||||
ChannelSecret string `json:"channel_secret" env:"PICOCLAW_CHANNELS_LINE_CHANNEL_SECRET"`
|
||||
@@ -445,10 +472,6 @@ type DevicesConfig struct {
|
||||
MonitorUSB bool `json:"monitor_usb" env:"PICOCLAW_DEVICES_MONITOR_USB"`
|
||||
}
|
||||
|
||||
type VoiceConfig struct {
|
||||
EchoTranscription bool `json:"echo_transcription" env:"PICOCLAW_VOICE_ECHO_TRANSCRIPTION"`
|
||||
}
|
||||
|
||||
type ProvidersConfig struct {
|
||||
Anthropic ProviderConfig `json:"anthropic"`
|
||||
OpenAI OpenAIProviderConfig `json:"openai"`
|
||||
@@ -464,12 +487,14 @@ type ProvidersConfig struct {
|
||||
ShengSuanYun ProviderConfig `json:"shengsuanyun"`
|
||||
DeepSeek ProviderConfig `json:"deepseek"`
|
||||
Cerebras ProviderConfig `json:"cerebras"`
|
||||
Vivgrid ProviderConfig `json:"vivgrid"`
|
||||
VolcEngine ProviderConfig `json:"volcengine"`
|
||||
GitHubCopilot ProviderConfig `json:"github_copilot"`
|
||||
Antigravity ProviderConfig `json:"antigravity"`
|
||||
Qwen ProviderConfig `json:"qwen"`
|
||||
Mistral ProviderConfig `json:"mistral"`
|
||||
Avian ProviderConfig `json:"avian"`
|
||||
Minimax ProviderConfig `json:"minimax"`
|
||||
}
|
||||
|
||||
// IsEmpty checks if all provider configs are empty (no API keys or API bases set)
|
||||
@@ -489,12 +514,14 @@ func (p ProvidersConfig) IsEmpty() bool {
|
||||
p.ShengSuanYun.APIKey == "" && p.ShengSuanYun.APIBase == "" &&
|
||||
p.DeepSeek.APIKey == "" && p.DeepSeek.APIBase == "" &&
|
||||
p.Cerebras.APIKey == "" && p.Cerebras.APIBase == "" &&
|
||||
p.Vivgrid.APIKey == "" && p.Vivgrid.APIBase == "" &&
|
||||
p.VolcEngine.APIKey == "" && p.VolcEngine.APIBase == "" &&
|
||||
p.GitHubCopilot.APIKey == "" && p.GitHubCopilot.APIBase == "" &&
|
||||
p.Antigravity.APIKey == "" && p.Antigravity.APIBase == "" &&
|
||||
p.Qwen.APIKey == "" && p.Qwen.APIBase == "" &&
|
||||
p.Mistral.APIKey == "" && p.Mistral.APIBase == "" &&
|
||||
p.Avian.APIKey == "" && p.Avian.APIBase == ""
|
||||
p.Avian.APIKey == "" && p.Avian.APIBase == "" &&
|
||||
p.Minimax.APIKey == "" && p.Minimax.APIBase == ""
|
||||
}
|
||||
|
||||
// MarshalJSON implements custom JSON marshaling for ProvidersConfig
|
||||
@@ -564,21 +591,31 @@ type GatewayConfig struct {
|
||||
Port int `json:"port" env:"PICOCLAW_GATEWAY_PORT"`
|
||||
}
|
||||
|
||||
// ToolDiscoveryConfig configures tool discovery. UseBM25 and UseRegex select
// the search strategies; TTL and MaxSearchResults are plain integers whose
// units are not visible in this file.
//
// NOTE(review): MaxSearchResults binds PICOCLAW_MAX_SEARCH_RESULTS, which
// does not follow the PICOCLAW_TOOLS_DISCOVERY_* naming of its siblings. It
// is left unchanged here so existing deployments keep working; confirm
// whether the name is intentional.
type ToolDiscoveryConfig struct {
	Enabled          bool `json:"enabled" env:"PICOCLAW_TOOLS_DISCOVERY_ENABLED"`
	TTL              int  `json:"ttl" env:"PICOCLAW_TOOLS_DISCOVERY_TTL"`
	MaxSearchResults int  `json:"max_search_results" env:"PICOCLAW_MAX_SEARCH_RESULTS"`
	UseBM25          bool `json:"use_bm25" env:"PICOCLAW_TOOLS_DISCOVERY_USE_BM25"`
	UseRegex         bool `json:"use_regex" env:"PICOCLAW_TOOLS_DISCOVERY_USE_REGEX"`
}
|
||||
|
||||
// ToolConfig is the minimal per-tool configuration: a single on/off switch.
// The env tag is the bare suffix "ENABLED"; the full variable name comes from
// the envPrefix declared on the field that embeds or holds this struct.
type ToolConfig struct {
	Enabled bool `json:"enabled" env:"ENABLED"`
}
|
||||
|
||||
type BraveConfig struct {
|
||||
Enabled bool `json:"enabled" env:"PICOCLAW_TOOLS_WEB_BRAVE_ENABLED"`
|
||||
APIKey string `json:"api_key" env:"PICOCLAW_TOOLS_WEB_BRAVE_API_KEY"`
|
||||
MaxResults int `json:"max_results" env:"PICOCLAW_TOOLS_WEB_BRAVE_MAX_RESULTS"`
|
||||
Enabled bool `json:"enabled" env:"PICOCLAW_TOOLS_WEB_BRAVE_ENABLED"`
|
||||
APIKey string `json:"api_key" env:"PICOCLAW_TOOLS_WEB_BRAVE_API_KEY"`
|
||||
APIKeys []string `json:"api_keys" env:"PICOCLAW_TOOLS_WEB_BRAVE_API_KEYS"`
|
||||
MaxResults int `json:"max_results" env:"PICOCLAW_TOOLS_WEB_BRAVE_MAX_RESULTS"`
|
||||
}
|
||||
|
||||
type TavilyConfig struct {
|
||||
Enabled bool `json:"enabled" env:"PICOCLAW_TOOLS_WEB_TAVILY_ENABLED"`
|
||||
APIKey string `json:"api_key" env:"PICOCLAW_TOOLS_WEB_TAVILY_API_KEY"`
|
||||
BaseURL string `json:"base_url" env:"PICOCLAW_TOOLS_WEB_TAVILY_BASE_URL"`
|
||||
MaxResults int `json:"max_results" env:"PICOCLAW_TOOLS_WEB_TAVILY_MAX_RESULTS"`
|
||||
Enabled bool `json:"enabled" env:"PICOCLAW_TOOLS_WEB_TAVILY_ENABLED"`
|
||||
APIKey string `json:"api_key" env:"PICOCLAW_TOOLS_WEB_TAVILY_API_KEY"`
|
||||
APIKeys []string `json:"api_keys" env:"PICOCLAW_TOOLS_WEB_TAVILY_API_KEYS"`
|
||||
BaseURL string `json:"base_url" env:"PICOCLAW_TOOLS_WEB_TAVILY_BASE_URL"`
|
||||
MaxResults int `json:"max_results" env:"PICOCLAW_TOOLS_WEB_TAVILY_MAX_RESULTS"`
|
||||
}
|
||||
|
||||
type DuckDuckGoConfig struct {
|
||||
@@ -587,9 +624,10 @@ type DuckDuckGoConfig struct {
|
||||
}
|
||||
|
||||
type PerplexityConfig struct {
|
||||
Enabled bool `json:"enabled" env:"PICOCLAW_TOOLS_WEB_PERPLEXITY_ENABLED"`
|
||||
APIKey string `json:"api_key" env:"PICOCLAW_TOOLS_WEB_PERPLEXITY_API_KEY"`
|
||||
MaxResults int `json:"max_results" env:"PICOCLAW_TOOLS_WEB_PERPLEXITY_MAX_RESULTS"`
|
||||
Enabled bool `json:"enabled" env:"PICOCLAW_TOOLS_WEB_PERPLEXITY_ENABLED"`
|
||||
APIKey string `json:"api_key" env:"PICOCLAW_TOOLS_WEB_PERPLEXITY_API_KEY"`
|
||||
APIKeys []string `json:"api_keys" env:"PICOCLAW_TOOLS_WEB_PERPLEXITY_API_KEYS"`
|
||||
MaxResults int `json:"max_results" env:"PICOCLAW_TOOLS_WEB_PERPLEXITY_MAX_RESULTS"`
|
||||
}
|
||||
|
||||
type SearXNGConfig struct {
|
||||
@@ -648,6 +686,11 @@ type MediaCleanupConfig struct {
|
||||
Interval int ` env:"PICOCLAW_MEDIA_CLEANUP_INTERVAL" json:"interval_minutes"`
|
||||
}
|
||||
|
||||
// ReadFileToolConfig configures the read_file tool.
//
// The env tags are bare suffixes, matching ToolConfig: the holding field in
// ToolsConfig supplies the "PICOCLAW_TOOLS_READ_FILE_" prefix. Without these
// tags the prefix binds nothing, and PICOCLAW_TOOLS_READ_FILE_ENABLED (which
// worked while the field was a ToolConfig) would be silently ignored.
type ReadFileToolConfig struct {
	// Enabled switches the tool on or off.
	Enabled bool `json:"enabled" env:"ENABLED"`
	// MaxReadFileSize is the read size limit; the default config sets it to
	// 64 * 1024.
	MaxReadFileSize int `json:"max_read_file_size" env:"MAX_READ_FILE_SIZE"`
}
|
||||
|
||||
type ToolsConfig struct {
|
||||
AllowReadPaths []string `json:"allow_read_paths" env:"PICOCLAW_TOOLS_ALLOW_READ_PATHS"`
|
||||
AllowWritePaths []string `json:"allow_write_paths" env:"PICOCLAW_TOOLS_ALLOW_WRITE_PATHS"`
|
||||
@@ -664,7 +707,7 @@ type ToolsConfig struct {
|
||||
InstallSkill ToolConfig `json:"install_skill" envPrefix:"PICOCLAW_TOOLS_INSTALL_SKILL_"`
|
||||
ListDir ToolConfig `json:"list_dir" envPrefix:"PICOCLAW_TOOLS_LIST_DIR_"`
|
||||
Message ToolConfig `json:"message" envPrefix:"PICOCLAW_TOOLS_MESSAGE_"`
|
||||
ReadFile ToolConfig `json:"read_file" envPrefix:"PICOCLAW_TOOLS_READ_FILE_"`
|
||||
ReadFile ReadFileToolConfig `json:"read_file" envPrefix:"PICOCLAW_TOOLS_READ_FILE_"`
|
||||
SendFile ToolConfig `json:"send_file" envPrefix:"PICOCLAW_TOOLS_SEND_FILE_"`
|
||||
Spawn ToolConfig `json:"spawn" envPrefix:"PICOCLAW_TOOLS_SPAWN_"`
|
||||
SPI ToolConfig `json:"spi" envPrefix:"PICOCLAW_TOOLS_SPI_"`
|
||||
@@ -716,7 +759,8 @@ type MCPServerConfig struct {
|
||||
|
||||
// MCPConfig defines configuration for all MCP servers
|
||||
type MCPConfig struct {
|
||||
ToolConfig `envPrefix:"PICOCLAW_TOOLS_MCP_"`
|
||||
ToolConfig ` envPrefix:"PICOCLAW_TOOLS_MCP_"`
|
||||
Discovery ToolDiscoveryConfig ` json:"discovery"`
|
||||
// Servers is a map of server name to server configuration
|
||||
Servers map[string]MCPServerConfig `json:"servers,omitempty"`
|
||||
}
|
||||
@@ -903,6 +947,29 @@ func (c *Config) ValidateModelList() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// MergeAPIKeys combines a single API key and a list of API keys into one
// de-duplicated slice.
//
// Every key is trimmed of surrounding whitespace and keys that are empty
// after trimming are dropped. Order is first-seen, with apiKey (when usable)
// ahead of the entries of apiKeys. The result is nil when no usable key
// remains.
func MergeAPIKeys(apiKey string, apiKeys []string) []string {
	seen := make(map[string]struct{}, len(apiKeys)+1)
	var all []string // stays nil when nothing is added

	// add trims raw and appends it unless it is empty or already present.
	add := func(raw string) {
		key := strings.TrimSpace(raw)
		if key == "" {
			return
		}
		if _, dup := seen[key]; dup {
			return
		}
		seen[key] = struct{}{}
		all = append(all, key)
	}

	add(apiKey)
	for _, k := range apiKeys {
		add(k)
	}
	return all
}
|
||||
|
||||
func (t *ToolsConfig) IsToolEnabled(name string) bool {
|
||||
switch name {
|
||||
case "web":
|
||||
|
||||
@@ -283,6 +283,9 @@ func TestDefaultConfig_Channels(t *testing.T) {
|
||||
if cfg.Channels.Slack.Enabled {
|
||||
t.Error("Slack should be disabled by default")
|
||||
}
|
||||
if cfg.Channels.Matrix.Enabled {
|
||||
t.Error("Matrix should be disabled by default")
|
||||
}
|
||||
}
|
||||
|
||||
// TestDefaultConfig_WebTools verifies web tools config
|
||||
@@ -293,7 +296,7 @@ func TestDefaultConfig_WebTools(t *testing.T) {
|
||||
if cfg.Tools.Web.Brave.MaxResults != 5 {
|
||||
t.Error("Expected Brave MaxResults 5, got ", cfg.Tools.Web.Brave.MaxResults)
|
||||
}
|
||||
if cfg.Tools.Web.Brave.APIKey != "" {
|
||||
if len(cfg.Tools.Web.Brave.APIKeys) != 0 {
|
||||
t.Error("Brave API key should be empty by default")
|
||||
}
|
||||
if cfg.Tools.Web.DuckDuckGo.MaxResults != 5 {
|
||||
|
||||
+60
-8
@@ -80,10 +80,11 @@ func DefaultConfig() *Config {
|
||||
AllowFrom: FlexibleStringSlice{},
|
||||
},
|
||||
QQ: QQConfig{
|
||||
Enabled: false,
|
||||
AppID: "",
|
||||
AppSecret: "",
|
||||
AllowFrom: FlexibleStringSlice{},
|
||||
Enabled: false,
|
||||
AppID: "",
|
||||
AppSecret: "",
|
||||
AllowFrom: FlexibleStringSlice{},
|
||||
MaxMessageLength: 2000,
|
||||
},
|
||||
DingTalk: DingTalkConfig{
|
||||
Enabled: false,
|
||||
@@ -97,6 +98,22 @@ func DefaultConfig() *Config {
|
||||
AppToken: "",
|
||||
AllowFrom: FlexibleStringSlice{},
|
||||
},
|
||||
Matrix: MatrixConfig{
|
||||
Enabled: false,
|
||||
Homeserver: "https://matrix.org",
|
||||
UserID: "",
|
||||
AccessToken: "",
|
||||
DeviceID: "",
|
||||
JoinOnInvite: true,
|
||||
AllowFrom: FlexibleStringSlice{},
|
||||
GroupTrigger: GroupTriggerConfig{
|
||||
MentionOnly: true,
|
||||
},
|
||||
Placeholder: PlaceholderConfig{
|
||||
Enabled: true,
|
||||
Text: "Thinking... 💭",
|
||||
},
|
||||
},
|
||||
LINE: LINEConfig{
|
||||
Enabled: false,
|
||||
ChannelSecret: "",
|
||||
@@ -261,6 +278,14 @@ func DefaultConfig() *Config {
|
||||
APIKey: "",
|
||||
},
|
||||
|
||||
// Vivgrid - https://vivgrid.com
|
||||
{
|
||||
ModelName: "vivgrid-auto",
|
||||
Model: "vivgrid/auto",
|
||||
APIBase: "https://api.vivgrid.com/v1",
|
||||
APIKey: "",
|
||||
},
|
||||
|
||||
// Volcengine (火山引擎) - https://console.volcengine.com/ark
|
||||
{
|
||||
ModelName: "doubao-pro",
|
||||
@@ -322,6 +347,14 @@ func DefaultConfig() *Config {
|
||||
APIKey: "",
|
||||
},
|
||||
|
||||
// Minimax - https://api.minimaxi.com/
|
||||
{
|
||||
ModelName: "MiniMax-M2.5",
|
||||
Model: "minimax/MiniMax-M2.5",
|
||||
APIBase: "https://api.minimaxi.com/v1",
|
||||
APIKey: "",
|
||||
},
|
||||
|
||||
// VLLM (local) - http://localhost:8000
|
||||
{
|
||||
ModelName: "local-model",
|
||||
@@ -351,6 +384,13 @@ func DefaultConfig() *Config {
|
||||
Brave: BraveConfig{
|
||||
Enabled: false,
|
||||
APIKey: "",
|
||||
APIKeys: nil,
|
||||
MaxResults: 5,
|
||||
},
|
||||
Tavily: TavilyConfig{
|
||||
Enabled: false,
|
||||
APIKey: "",
|
||||
APIKeys: nil,
|
||||
MaxResults: 5,
|
||||
},
|
||||
DuckDuckGo: DuckDuckGoConfig{
|
||||
@@ -360,6 +400,7 @@ func DefaultConfig() *Config {
|
||||
Perplexity: PerplexityConfig{
|
||||
Enabled: false,
|
||||
APIKey: "",
|
||||
APIKeys: nil,
|
||||
MaxResults: 5,
|
||||
},
|
||||
SearXNG: SearXNGConfig{
|
||||
@@ -411,6 +452,13 @@ func DefaultConfig() *Config {
|
||||
ToolConfig: ToolConfig{
|
||||
Enabled: false,
|
||||
},
|
||||
Discovery: ToolDiscoveryConfig{
|
||||
Enabled: false,
|
||||
TTL: 5,
|
||||
MaxSearchResults: 5,
|
||||
UseBM25: true,
|
||||
UseRegex: false,
|
||||
},
|
||||
Servers: map[string]MCPServerConfig{},
|
||||
},
|
||||
AppendFile: ToolConfig{
|
||||
@@ -434,8 +482,9 @@ func DefaultConfig() *Config {
|
||||
Message: ToolConfig{
|
||||
Enabled: true,
|
||||
},
|
||||
ReadFile: ToolConfig{
|
||||
Enabled: true,
|
||||
ReadFile: ReadFileToolConfig{
|
||||
Enabled: true,
|
||||
MaxReadFileSize: 64 * 1024, // 64KB
|
||||
},
|
||||
Spawn: ToolConfig{
|
||||
Enabled: true,
|
||||
@@ -461,8 +510,11 @@ func DefaultConfig() *Config {
|
||||
Enabled: false,
|
||||
MonitorUSB: true,
|
||||
},
|
||||
Voice: VoiceConfig{
|
||||
EchoTranscription: false,
|
||||
BuildInfo: BuildInfo{
|
||||
Version: Version,
|
||||
GitCommit: GitCommit,
|
||||
BuildTime: BuildTime,
|
||||
GoVersion: GoVersion,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
@@ -292,6 +292,23 @@ func ConvertProvidersToModelList(cfg *Config) []ModelConfig {
|
||||
}, true
|
||||
},
|
||||
},
|
||||
{
|
||||
providerNames: []string{"vivgrid"},
|
||||
protocol: "vivgrid",
|
||||
buildConfig: func(p ProvidersConfig) (ModelConfig, bool) {
|
||||
if p.Vivgrid.APIKey == "" && p.Vivgrid.APIBase == "" {
|
||||
return ModelConfig{}, false
|
||||
}
|
||||
return ModelConfig{
|
||||
ModelName: "vivgrid",
|
||||
Model: "vivgrid/auto",
|
||||
APIKey: p.Vivgrid.APIKey,
|
||||
APIBase: p.Vivgrid.APIBase,
|
||||
Proxy: p.Vivgrid.Proxy,
|
||||
RequestTimeout: p.Vivgrid.RequestTimeout,
|
||||
}, true
|
||||
},
|
||||
},
|
||||
{
|
||||
providerNames: []string{"volcengine", "doubao"},
|
||||
protocol: "volcengine",
|
||||
|
||||
@@ -155,7 +155,8 @@ func TestConvertProvidersToModelList_AllProviders(t *testing.T) {
|
||||
ShengSuanYun: ProviderConfig{APIKey: "key11"},
|
||||
DeepSeek: ProviderConfig{APIKey: "key12"},
|
||||
Cerebras: ProviderConfig{APIKey: "key13"},
|
||||
VolcEngine: ProviderConfig{APIKey: "key14"},
|
||||
Vivgrid: ProviderConfig{APIKey: "key14"},
|
||||
VolcEngine: ProviderConfig{APIKey: "key15"},
|
||||
GitHubCopilot: ProviderConfig{ConnectMode: "grpc"},
|
||||
Antigravity: ProviderConfig{AuthMethod: "oauth"},
|
||||
Qwen: ProviderConfig{APIKey: "key17"},
|
||||
@@ -166,9 +167,9 @@ func TestConvertProvidersToModelList_AllProviders(t *testing.T) {
|
||||
|
||||
result := ConvertProvidersToModelList(cfg)
|
||||
|
||||
// All 20 providers should be converted
|
||||
if len(result) != 20 {
|
||||
t.Errorf("len(result) = %d, want 20", len(result))
|
||||
// All 21 providers should be converted
|
||||
if len(result) != 21 {
|
||||
t.Errorf("len(result) = %d, want 21", len(result))
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -0,0 +1,44 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"runtime"
|
||||
)
|
||||
|
||||
// Build-time variables injected via ldflags during the build process.
// These are set by the Makefile or .goreleaser.yaml using the -X flag:
//
//	-X github.com/sipeed/picoclaw/pkg/config.Version=<version>
//	-X github.com/sipeed/picoclaw/pkg/config.GitCommit=<commit>
//	-X github.com/sipeed/picoclaw/pkg/config.BuildTime=<timestamp>
//	-X github.com/sipeed/picoclaw/pkg/config.GoVersion=<go-version>
var (
	Version   = "dev" // Default value when not built with ldflags
	GitCommit string  // Git commit SHA (short)
	BuildTime string  // Build timestamp in RFC3339 format
	GoVersion string  // Go version used for building
)
|
||||
|
||||
// FormatVersion returns the version string with optional git commit
|
||||
func FormatVersion() string {
|
||||
v := Version
|
||||
if GitCommit != "" {
|
||||
v += fmt.Sprintf(" (git: %s)", GitCommit)
|
||||
}
|
||||
return v
|
||||
}
|
||||
|
||||
// FormatBuildInfo returns build time and go version info
|
||||
func FormatBuildInfo() (string, string) {
|
||||
build := BuildTime
|
||||
goVer := GoVersion
|
||||
if goVer == "" {
|
||||
goVer = runtime.Version()
|
||||
}
|
||||
return build, goVer
|
||||
}
|
||||
|
||||
// GetVersion returns the version string
|
||||
func GetVersion() string {
|
||||
return Version
|
||||
}
|
||||
@@ -0,0 +1,92 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"runtime"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestFormatVersion_NoGitCommit(t *testing.T) {
|
||||
oldVersion, oldGit := Version, GitCommit
|
||||
t.Cleanup(func() { Version, GitCommit = oldVersion, oldGit })
|
||||
|
||||
Version = "1.2.3"
|
||||
GitCommit = ""
|
||||
|
||||
assert.Equal(t, "1.2.3", FormatVersion())
|
||||
}
|
||||
|
||||
func TestFormatVersion_WithGitCommit(t *testing.T) {
|
||||
oldVersion, oldGit := Version, GitCommit
|
||||
t.Cleanup(func() { Version, GitCommit = oldVersion, oldGit })
|
||||
|
||||
Version = "1.2.3"
|
||||
GitCommit = "abc123"
|
||||
|
||||
assert.Equal(t, "1.2.3 (git: abc123)", FormatVersion())
|
||||
}
|
||||
|
||||
func TestFormatBuildInfo_UsesBuildTimeAndGoVersion_WhenSet(t *testing.T) {
|
||||
oldBuildTime, oldGoVersion := BuildTime, GoVersion
|
||||
t.Cleanup(func() { BuildTime, GoVersion = oldBuildTime, oldGoVersion })
|
||||
|
||||
BuildTime = "2026-02-20T00:00:00Z"
|
||||
GoVersion = "go1.23.0"
|
||||
|
||||
build, goVer := FormatBuildInfo()
|
||||
|
||||
assert.Equal(t, BuildTime, build)
|
||||
assert.Equal(t, GoVersion, goVer)
|
||||
}
|
||||
|
||||
func TestFormatBuildInfo_EmptyBuildTime_ReturnsEmptyBuild(t *testing.T) {
|
||||
oldBuildTime, oldGoVersion := BuildTime, GoVersion
|
||||
t.Cleanup(func() { BuildTime, GoVersion = oldBuildTime, oldGoVersion })
|
||||
|
||||
BuildTime = ""
|
||||
GoVersion = "go1.23.0"
|
||||
|
||||
build, goVer := FormatBuildInfo()
|
||||
|
||||
assert.Empty(t, build)
|
||||
assert.Equal(t, GoVersion, goVer)
|
||||
}
|
||||
|
||||
func TestFormatBuildInfo_EmptyGoVersion_FallsBackToRuntimeVersion(t *testing.T) {
|
||||
oldBuildTime, oldGoVersion := BuildTime, GoVersion
|
||||
t.Cleanup(func() { BuildTime, GoVersion = oldBuildTime, oldGoVersion })
|
||||
|
||||
BuildTime = "x"
|
||||
GoVersion = ""
|
||||
|
||||
build, goVer := FormatBuildInfo()
|
||||
|
||||
assert.Equal(t, "x", build)
|
||||
assert.Equal(t, runtime.Version(), goVer)
|
||||
}
|
||||
|
||||
func TestGetVersion(t *testing.T) {
|
||||
oldVersion := Version
|
||||
t.Cleanup(func() { Version = oldVersion })
|
||||
|
||||
Version = "dev"
|
||||
assert.Equal(t, "dev", GetVersion())
|
||||
}
|
||||
|
||||
func TestGetVersion_Custom(t *testing.T) {
|
||||
oldVersion := Version
|
||||
t.Cleanup(func() { Version = oldVersion })
|
||||
|
||||
Version = "v1.0.0"
|
||||
assert.Equal(t, "v1.0.0", GetVersion())
|
||||
}
|
||||
|
||||
func TestVersion_DefaultIsDev(t *testing.T) {
|
||||
// Reset to default values
|
||||
oldVersion := Version
|
||||
Version = "dev"
|
||||
t.Cleanup(func() { Version = oldVersion })
|
||||
|
||||
assert.Equal(t, "dev", Version)
|
||||
}
|
||||
@@ -22,6 +22,7 @@ var supportedChannels = map[string]bool{
|
||||
"qq": true,
|
||||
"dingtalk": true,
|
||||
"slack": true,
|
||||
"matrix": true,
|
||||
"line": true,
|
||||
"onebot": true,
|
||||
"wecom": true,
|
||||
|
||||
@@ -371,6 +371,8 @@ func (c *OpenClawConfig) IsChannelEnabled(name string) bool {
|
||||
return c.Channels.Discord == nil || c.Channels.Discord.Enabled == nil || *c.Channels.Discord.Enabled
|
||||
case "slack":
|
||||
return c.Channels.Slack == nil || c.Channels.Slack.Enabled == nil || *c.Channels.Slack.Enabled
|
||||
case "matrix":
|
||||
return c.Channels.Matrix == nil || c.Channels.Matrix.Enabled == nil || *c.Channels.Matrix.Enabled
|
||||
case "whatsapp":
|
||||
return c.Channels.WhatsApp == nil || c.Channels.WhatsApp.Enabled == nil || *c.Channels.WhatsApp.Enabled
|
||||
case "feishu":
|
||||
@@ -397,6 +399,11 @@ func GetChannelAllowFrom(ch any) []string {
|
||||
return nil
|
||||
}
|
||||
return c.AllowFrom
|
||||
case *OpenClawMatrixConfig:
|
||||
if c == nil {
|
||||
return nil
|
||||
}
|
||||
return c.AllowFrom
|
||||
case *OpenClawWhatsAppConfig:
|
||||
if c == nil {
|
||||
return nil
|
||||
@@ -627,6 +634,7 @@ type ChannelsConfig struct {
|
||||
QQ QQConfig `json:"qq"`
|
||||
DingTalk DingTalkConfig `json:"dingtalk"`
|
||||
Slack SlackConfig `json:"slack"`
|
||||
Matrix MatrixConfig `json:"matrix"`
|
||||
LINE LINEConfig `json:"line"`
|
||||
}
|
||||
|
||||
@@ -687,6 +695,14 @@ type SlackConfig struct {
|
||||
AllowFrom []string `json:"allow_from"`
|
||||
}
|
||||
|
||||
// MatrixConfig is the migration-side representation of the Matrix channel
// settings: the on/off switch, homeserver, user ID, access token and the
// allow-list of senders.
type MatrixConfig struct {
	Enabled     bool     `json:"enabled"`
	Homeserver  string   `json:"homeserver"`
	UserID      string   `json:"user_id"`
	AccessToken string   `json:"access_token"`
	AllowFrom   []string `json:"allow_from"`
}
|
||||
|
||||
type LINEConfig struct {
|
||||
Enabled bool `json:"enabled"`
|
||||
ChannelSecret string `json:"channel_secret"`
|
||||
@@ -717,16 +733,18 @@ type WebToolsConfig struct {
|
||||
}
|
||||
|
||||
type BraveConfig struct {
|
||||
Enabled bool `json:"enabled"`
|
||||
APIKey string `json:"api_key"`
|
||||
MaxResults int `json:"max_results"`
|
||||
Enabled bool `json:"enabled"`
|
||||
APIKey string `json:"api_key"`
|
||||
APIKeys []string `json:"api_keys"`
|
||||
MaxResults int `json:"max_results"`
|
||||
}
|
||||
|
||||
type TavilyConfig struct {
|
||||
Enabled bool `json:"enabled"`
|
||||
APIKey string `json:"api_key"`
|
||||
BaseURL string `json:"base_url"`
|
||||
MaxResults int `json:"max_results"`
|
||||
Enabled bool `json:"enabled"`
|
||||
APIKey string `json:"api_key"`
|
||||
APIKeys []string `json:"api_keys"`
|
||||
BaseURL string `json:"base_url"`
|
||||
MaxResults int `json:"max_results"`
|
||||
}
|
||||
|
||||
type DuckDuckGoConfig struct {
|
||||
@@ -735,9 +753,10 @@ type DuckDuckGoConfig struct {
|
||||
}
|
||||
|
||||
type PerplexityConfig struct {
|
||||
Enabled bool `json:"enabled"`
|
||||
APIKey string `json:"api_key"`
|
||||
MaxResults int `json:"max_results"`
|
||||
Enabled bool `json:"enabled"`
|
||||
APIKey string `json:"api_key"`
|
||||
APIKeys []string `json:"api_keys"`
|
||||
MaxResults int `json:"max_results"`
|
||||
}
|
||||
|
||||
type CronConfig struct {
|
||||
@@ -862,12 +881,26 @@ func (c *OpenClawConfig) convertChannels(warnings *[]string) ChannelsConfig {
|
||||
}
|
||||
}
|
||||
|
||||
if c.Channels.Matrix != nil && supportedChannels["matrix"] {
|
||||
enabled := c.Channels.Matrix.Enabled == nil || *c.Channels.Matrix.Enabled
|
||||
channels.Matrix = MatrixConfig{
|
||||
Enabled: enabled,
|
||||
AllowFrom: c.Channels.Matrix.AllowFrom,
|
||||
}
|
||||
if c.Channels.Matrix.Homeserver != nil {
|
||||
channels.Matrix.Homeserver = *c.Channels.Matrix.Homeserver
|
||||
}
|
||||
if c.Channels.Matrix.UserID != nil {
|
||||
channels.Matrix.UserID = *c.Channels.Matrix.UserID
|
||||
}
|
||||
if c.Channels.Matrix.AccessToken != nil {
|
||||
channels.Matrix.AccessToken = *c.Channels.Matrix.AccessToken
|
||||
}
|
||||
}
|
||||
|
||||
if c.Channels.Signal != nil {
|
||||
*warnings = append(*warnings, "Channel 'signal': No PicoClaw adapter available")
|
||||
}
|
||||
if c.Channels.Matrix != nil {
|
||||
*warnings = append(*warnings, "Channel 'matrix': No PicoClaw adapter available")
|
||||
}
|
||||
if c.Channels.IRC != nil {
|
||||
*warnings = append(*warnings, "Channel 'irc': No PicoClaw adapter available")
|
||||
}
|
||||
@@ -1020,6 +1053,14 @@ func (c ChannelsConfig) ToStandardChannels() config.ChannelsConfig {
|
||||
BotToken: c.Slack.BotToken,
|
||||
AppToken: c.Slack.AppToken,
|
||||
},
|
||||
Matrix: config.MatrixConfig{
|
||||
Enabled: c.Matrix.Enabled,
|
||||
Homeserver: c.Matrix.Homeserver,
|
||||
UserID: c.Matrix.UserID,
|
||||
AccessToken: c.Matrix.AccessToken,
|
||||
AllowFrom: c.Matrix.AllowFrom,
|
||||
JoinOnInvite: true,
|
||||
},
|
||||
LINE: config.LINEConfig{
|
||||
Enabled: c.LINE.Enabled,
|
||||
ChannelSecret: c.LINE.ChannelSecret,
|
||||
@@ -1044,6 +1085,7 @@ func (c ToolsConfig) ToStandardTools() config.ToolsConfig {
|
||||
Brave: config.BraveConfig{
|
||||
Enabled: c.Web.Brave.Enabled,
|
||||
APIKey: c.Web.Brave.APIKey,
|
||||
APIKeys: c.Web.Brave.APIKeys,
|
||||
MaxResults: c.Web.Brave.MaxResults,
|
||||
},
|
||||
Tavily: config.TavilyConfig{
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"encoding/json"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
@@ -375,6 +376,96 @@ func TestConvertToPicoClawWithQQAndDingTalk(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertToPicoClawWithMatrix(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
configPath := filepath.Join(tmpDir, "openclaw.json")
|
||||
|
||||
testConfig := `{
|
||||
"channels": {
|
||||
"matrix": {
|
||||
"enabled": true,
|
||||
"homeserver": "https://matrix.example.com",
|
||||
"userId": "@bot:matrix.example.com",
|
||||
"accessToken": "syt_test_token",
|
||||
"allowFrom": ["@alice:matrix.example.com"]
|
||||
}
|
||||
}
|
||||
}`
|
||||
|
||||
err := os.WriteFile(configPath, []byte(testConfig), 0o644)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to write test config: %v", err)
|
||||
}
|
||||
|
||||
cfg, err := LoadOpenClawConfig(configPath)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to load config: %v", err)
|
||||
}
|
||||
|
||||
picoCfg, warnings, err := cfg.ConvertToPicoClaw("")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to convert config: %v", err)
|
||||
}
|
||||
|
||||
if !picoCfg.Channels.Matrix.Enabled {
|
||||
t.Error("matrix should be enabled")
|
||||
}
|
||||
if picoCfg.Channels.Matrix.Homeserver != "https://matrix.example.com" {
|
||||
t.Errorf("expected matrix homeserver, got %q", picoCfg.Channels.Matrix.Homeserver)
|
||||
}
|
||||
if picoCfg.Channels.Matrix.UserID != "@bot:matrix.example.com" {
|
||||
t.Errorf("expected matrix user_id, got %q", picoCfg.Channels.Matrix.UserID)
|
||||
}
|
||||
if picoCfg.Channels.Matrix.AccessToken != "syt_test_token" {
|
||||
t.Errorf("expected matrix access_token, got %q", picoCfg.Channels.Matrix.AccessToken)
|
||||
}
|
||||
if len(picoCfg.Channels.Matrix.AllowFrom) != 1 ||
|
||||
picoCfg.Channels.Matrix.AllowFrom[0] != "@alice:matrix.example.com" {
|
||||
t.Errorf("unexpected matrix allow_from: %#v", picoCfg.Channels.Matrix.AllowFrom)
|
||||
}
|
||||
|
||||
for _, w := range warnings {
|
||||
if strings.Contains(w, "Channel 'matrix'") {
|
||||
t.Fatalf("matrix should no longer be reported as unsupported, warning=%q", w)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertToPicoClawWithMatrixDisabled(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
configPath := filepath.Join(tmpDir, "openclaw.json")
|
||||
|
||||
testConfig := `{
|
||||
"channels": {
|
||||
"matrix": {
|
||||
"enabled": false,
|
||||
"homeserver": "https://matrix.example.com",
|
||||
"userId": "@bot:matrix.example.com",
|
||||
"accessToken": "syt_test_token"
|
||||
}
|
||||
}
|
||||
}`
|
||||
|
||||
err := os.WriteFile(configPath, []byte(testConfig), 0o644)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to write test config: %v", err)
|
||||
}
|
||||
|
||||
cfg, err := LoadOpenClawConfig(configPath)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to load config: %v", err)
|
||||
}
|
||||
|
||||
picoCfg, _, err := cfg.ConvertToPicoClaw("")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to convert config: %v", err)
|
||||
}
|
||||
|
||||
if picoCfg.Channels.Matrix.Enabled {
|
||||
t.Error("matrix should respect enabled=false from source config")
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenClawAgentModel(t *testing.T) {
|
||||
model := &OpenClawAgentModel{
|
||||
Primary: strPtr("anthropic/claude-3-opus"),
|
||||
@@ -425,6 +516,9 @@ func TestChannelEnabled(t *testing.T) {
|
||||
if !cfg.IsChannelEnabled("slack") {
|
||||
t.Error("slack should be enabled (explicitly set)")
|
||||
}
|
||||
if !cfg.IsChannelEnabled("matrix") {
|
||||
t.Error("matrix should be enabled (nil config defaults to enabled)")
|
||||
}
|
||||
if cfg.IsChannelEnabled("line") {
|
||||
t.Error("line should return false (not in switch cases)")
|
||||
}
|
||||
|
||||
@@ -153,6 +153,15 @@ func resolveProviderSelection(cfg *config.Config) (providerSelection, error) {
|
||||
sel.apiBase = "https://integrate.api.nvidia.com/v1"
|
||||
}
|
||||
}
|
||||
case "vivgrid":
|
||||
if cfg.Providers.Vivgrid.APIKey != "" {
|
||||
sel.apiKey = cfg.Providers.Vivgrid.APIKey
|
||||
sel.apiBase = cfg.Providers.Vivgrid.APIBase
|
||||
sel.proxy = cfg.Providers.Vivgrid.Proxy
|
||||
if sel.apiBase == "" {
|
||||
sel.apiBase = "https://api.vivgrid.com/v1"
|
||||
}
|
||||
}
|
||||
case "claude-cli", "claude-code", "claudecode":
|
||||
workspace := cfg.WorkspacePath()
|
||||
if workspace == "" {
|
||||
@@ -199,6 +208,15 @@ func resolveProviderSelection(cfg *config.Config) (providerSelection, error) {
|
||||
sel.apiBase = "https://api.mistral.ai/v1"
|
||||
}
|
||||
}
|
||||
case "minimax":
|
||||
if cfg.Providers.Minimax.APIKey != "" {
|
||||
sel.apiKey = cfg.Providers.Minimax.APIKey
|
||||
sel.apiBase = cfg.Providers.Minimax.APIBase
|
||||
sel.proxy = cfg.Providers.Minimax.Proxy
|
||||
if sel.apiBase == "" {
|
||||
sel.apiBase = "https://api.minimaxi.com/v1"
|
||||
}
|
||||
}
|
||||
case "github_copilot", "copilot":
|
||||
sel.providerType = providerTypeGitHubCopilot
|
||||
if cfg.Providers.GitHubCopilot.APIBase != "" {
|
||||
@@ -295,6 +313,13 @@ func resolveProviderSelection(cfg *config.Config) (providerSelection, error) {
|
||||
if sel.apiBase == "" {
|
||||
sel.apiBase = "https://integrate.api.nvidia.com/v1"
|
||||
}
|
||||
case strings.HasPrefix(model, "vivgrid/") && cfg.Providers.Vivgrid.APIKey != "":
|
||||
sel.apiKey = cfg.Providers.Vivgrid.APIKey
|
||||
sel.apiBase = cfg.Providers.Vivgrid.APIBase
|
||||
sel.proxy = cfg.Providers.Vivgrid.Proxy
|
||||
if sel.apiBase == "" {
|
||||
sel.apiBase = "https://api.vivgrid.com/v1"
|
||||
}
|
||||
case (strings.Contains(lowerModel, "ollama") || strings.HasPrefix(model, "ollama/")) && cfg.Providers.Ollama.APIKey != "":
|
||||
sel.apiKey = cfg.Providers.Ollama.APIKey
|
||||
sel.apiBase = cfg.Providers.Ollama.APIBase
|
||||
@@ -309,6 +334,13 @@ func resolveProviderSelection(cfg *config.Config) (providerSelection, error) {
|
||||
if sel.apiBase == "" {
|
||||
sel.apiBase = "https://api.mistral.ai/v1"
|
||||
}
|
||||
case (strings.Contains(lowerModel, "minimax") || strings.HasPrefix(model, "minimax/")) && cfg.Providers.Minimax.APIKey != "":
|
||||
sel.apiKey = cfg.Providers.Minimax.APIKey
|
||||
sel.apiBase = cfg.Providers.Minimax.APIBase
|
||||
sel.proxy = cfg.Providers.Minimax.Proxy
|
||||
if sel.apiBase == "" {
|
||||
sel.apiBase = "https://api.minimaxi.com/v1"
|
||||
}
|
||||
case strings.HasPrefix(model, "avian/") && cfg.Providers.Avian.APIKey != "":
|
||||
sel.apiKey = cfg.Providers.Avian.APIKey
|
||||
sel.apiBase = cfg.Providers.Avian.APIBase
|
||||
|
||||
@@ -94,7 +94,8 @@ func CreateProviderFromConfig(cfg *config.ModelConfig) (LLMProvider, string, err
|
||||
|
||||
case "litellm", "openrouter", "groq", "zhipu", "gemini", "nvidia",
|
||||
"ollama", "moonshot", "shengsuanyun", "deepseek", "cerebras",
|
||||
"volcengine", "vllm", "qwen", "mistral", "avian":
|
||||
"vivgrid", "volcengine", "vllm", "qwen", "mistral", "avian",
|
||||
"minimax":
|
||||
// All other OpenAI-compatible HTTP providers
|
||||
if cfg.APIKey == "" && cfg.APIBase == "" {
|
||||
return nil, "", fmt.Errorf("api_key or api_base is required for HTTP-based protocol %q", protocol)
|
||||
@@ -200,6 +201,8 @@ func getDefaultAPIBase(protocol string) string {
|
||||
return "https://api.deepseek.com/v1"
|
||||
case "cerebras":
|
||||
return "https://api.cerebras.ai/v1"
|
||||
case "vivgrid":
|
||||
return "https://api.vivgrid.com/v1"
|
||||
case "volcengine":
|
||||
return "https://ark.cn-beijing.volces.com/api/v3"
|
||||
case "qwen":
|
||||
@@ -210,6 +213,8 @@ func getDefaultAPIBase(protocol string) string {
|
||||
return "https://api.mistral.ai/v1"
|
||||
case "avian":
|
||||
return "https://api.avian.io/v1"
|
||||
case "minimax":
|
||||
return "https://api.minimaxi.com/v1"
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
|
||||
@@ -108,6 +108,7 @@ func TestCreateProviderFromConfig_DefaultAPIBase(t *testing.T) {
|
||||
{"groq", "groq"},
|
||||
{"openrouter", "openrouter"},
|
||||
{"cerebras", "cerebras"},
|
||||
{"vivgrid", "vivgrid"},
|
||||
{"qwen", "qwen"},
|
||||
{"vllm", "vllm"},
|
||||
{"deepseek", "deepseek"},
|
||||
|
||||
@@ -88,6 +88,17 @@ func TestResolveProviderSelection(t *testing.T) {
|
||||
wantAPIBase: "https://integrate.api.nvidia.com/v1",
|
||||
wantProxy: "http://127.0.0.1:7890",
|
||||
},
|
||||
{
|
||||
name: "explicit vivgrid provider uses defaults",
|
||||
setup: func(cfg *config.Config) {
|
||||
cfg.Agents.Defaults.Provider = "vivgrid"
|
||||
cfg.Providers.Vivgrid.APIKey = "vivgrid-key"
|
||||
cfg.Providers.Vivgrid.Proxy = "http://127.0.0.1:7890"
|
||||
},
|
||||
wantType: providerTypeHTTPCompat,
|
||||
wantAPIBase: "https://api.vivgrid.com/v1",
|
||||
wantProxy: "http://127.0.0.1:7890",
|
||||
},
|
||||
{
|
||||
name: "openrouter model uses openrouter defaults",
|
||||
setup: func(cfg *config.Config) {
|
||||
|
||||
@@ -439,7 +439,8 @@ func normalizeModel(model, apiBase string) string {
|
||||
|
||||
prefix := strings.ToLower(before)
|
||||
switch prefix {
|
||||
case "litellm", "moonshot", "nvidia", "groq", "ollama", "deepseek", "google", "openrouter", "zhipu", "mistral":
|
||||
case "litellm", "moonshot", "nvidia", "groq", "ollama", "deepseek", "google",
|
||||
"openrouter", "zhipu", "mistral", "vivgrid", "minimax":
|
||||
return after
|
||||
default:
|
||||
return model
|
||||
|
||||
@@ -382,7 +382,7 @@ func TestProviderChat_StripsMoonshotPrefixAndNormalizesKimiTemperature(t *testin
|
||||
}
|
||||
}
|
||||
|
||||
func TestProviderChat_StripsGroqAndOllamaPrefixes(t *testing.T) {
|
||||
func TestProviderChat_StripsGroqOllamaDeepseekVivgridPrefixes(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
@@ -408,6 +408,11 @@ func TestProviderChat_StripsGroqAndOllamaPrefixes(t *testing.T) {
|
||||
input: "deepseek/deepseek-chat",
|
||||
wantModel: "deepseek-chat",
|
||||
},
|
||||
{
|
||||
name: "strips vivgrid prefix",
|
||||
input: "vivgrid/auto",
|
||||
wantModel: "auto",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
@@ -512,6 +517,12 @@ func TestNormalizeModel_UsesAPIBase(t *testing.T) {
|
||||
if got := normalizeModel("openrouter/auto", "https://openrouter.ai/api/v1"); got != "openrouter/auto" {
|
||||
t.Fatalf("normalizeModel(openrouter) = %q, want %q", got, "openrouter/auto")
|
||||
}
|
||||
if got := normalizeModel("vivgrid/managed", "https://api.vivgrid.com/v1"); got != "managed" {
|
||||
t.Fatalf("normalizeModel(vivgrid) = %q, want %q", got, "managed")
|
||||
}
|
||||
if got := normalizeModel("vivgrid/auto", "https://api.vivgrid.com/v1"); got != "auto" {
|
||||
t.Fatalf("normalizeModel(vivgrid auto) = %q, want %q", got, "auto")
|
||||
}
|
||||
}
|
||||
|
||||
func TestProvider_RequestTimeoutDefault(t *testing.T) {
|
||||
|
||||
@@ -0,0 +1,81 @@
|
||||
package session
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/memory"
|
||||
"github.com/sipeed/picoclaw/pkg/providers"
|
||||
)
|
||||
|
||||
// JSONLBackend adapts a memory.Store into the SessionStore interface.
|
||||
// Write errors are logged rather than returned, matching the fire-and-forget
|
||||
// contract of SessionManager that the agent loop relies on.
|
||||
type JSONLBackend struct {
|
||||
store memory.Store
|
||||
}
|
||||
|
||||
// NewJSONLBackend wraps a memory.Store for use as a SessionStore.
|
||||
func NewJSONLBackend(store memory.Store) *JSONLBackend {
|
||||
return &JSONLBackend{store: store}
|
||||
}
|
||||
|
||||
func (b *JSONLBackend) AddMessage(sessionKey, role, content string) {
|
||||
if err := b.store.AddMessage(context.Background(), sessionKey, role, content); err != nil {
|
||||
log.Printf("session: add message: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (b *JSONLBackend) AddFullMessage(sessionKey string, msg providers.Message) {
|
||||
if err := b.store.AddFullMessage(context.Background(), sessionKey, msg); err != nil {
|
||||
log.Printf("session: add full message: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (b *JSONLBackend) GetHistory(key string) []providers.Message {
|
||||
msgs, err := b.store.GetHistory(context.Background(), key)
|
||||
if err != nil {
|
||||
log.Printf("session: get history: %v", err)
|
||||
return []providers.Message{}
|
||||
}
|
||||
return msgs
|
||||
}
|
||||
|
||||
func (b *JSONLBackend) GetSummary(key string) string {
|
||||
summary, err := b.store.GetSummary(context.Background(), key)
|
||||
if err != nil {
|
||||
log.Printf("session: get summary: %v", err)
|
||||
return ""
|
||||
}
|
||||
return summary
|
||||
}
|
||||
|
||||
func (b *JSONLBackend) SetSummary(key, summary string) {
|
||||
if err := b.store.SetSummary(context.Background(), key, summary); err != nil {
|
||||
log.Printf("session: set summary: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (b *JSONLBackend) SetHistory(key string, history []providers.Message) {
|
||||
if err := b.store.SetHistory(context.Background(), key, history); err != nil {
|
||||
log.Printf("session: set history: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (b *JSONLBackend) TruncateHistory(key string, keepLast int) {
|
||||
if err := b.store.TruncateHistory(context.Background(), key, keepLast); err != nil {
|
||||
log.Printf("session: truncate history: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Save persists session state. Since the JSONL store fsyncs every write
|
||||
// immediately, the data is already durable. Save runs compaction to reclaim
|
||||
// space from logically truncated messages (no-op when there are none).
|
||||
func (b *JSONLBackend) Save(key string) error {
|
||||
return b.store.Compact(context.Background(), key)
|
||||
}
|
||||
|
||||
// Close releases resources held by the underlying store.
|
||||
func (b *JSONLBackend) Close() error {
|
||||
return b.store.Close()
|
||||
}
|
||||
@@ -0,0 +1,179 @@
|
||||
package session_test
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/memory"
|
||||
"github.com/sipeed/picoclaw/pkg/providers"
|
||||
"github.com/sipeed/picoclaw/pkg/session"
|
||||
)
|
||||
|
||||
// Compile-time interface satisfaction checks.
|
||||
var (
|
||||
_ session.SessionStore = (*session.SessionManager)(nil)
|
||||
_ session.SessionStore = (*session.JSONLBackend)(nil)
|
||||
)
|
||||
|
||||
func newBackend(t *testing.T) *session.JSONLBackend {
|
||||
t.Helper()
|
||||
store, err := memory.NewJSONLStore(t.TempDir())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Cleanup(func() { store.Close() })
|
||||
return session.NewJSONLBackend(store)
|
||||
}
|
||||
|
||||
func TestJSONLBackend_AddAndGetHistory(t *testing.T) {
|
||||
b := newBackend(t)
|
||||
|
||||
b.AddMessage("s1", "user", "hello")
|
||||
b.AddMessage("s1", "assistant", "hi")
|
||||
|
||||
history := b.GetHistory("s1")
|
||||
if len(history) != 2 {
|
||||
t.Fatalf("got %d messages, want 2", len(history))
|
||||
}
|
||||
if history[0].Role != "user" || history[0].Content != "hello" {
|
||||
t.Errorf("msg[0] = %+v", history[0])
|
||||
}
|
||||
if history[1].Role != "assistant" || history[1].Content != "hi" {
|
||||
t.Errorf("msg[1] = %+v", history[1])
|
||||
}
|
||||
}
|
||||
|
||||
func TestJSONLBackend_AddFullMessage(t *testing.T) {
|
||||
b := newBackend(t)
|
||||
|
||||
msg := providers.Message{
|
||||
Role: "assistant",
|
||||
Content: "done",
|
||||
ToolCalls: []providers.ToolCall{
|
||||
{ID: "tc1", Function: &providers.FunctionCall{Name: "read_file", Arguments: `{"path":"x"}`}},
|
||||
},
|
||||
}
|
||||
b.AddFullMessage("s1", msg)
|
||||
|
||||
history := b.GetHistory("s1")
|
||||
if len(history) != 1 {
|
||||
t.Fatalf("got %d, want 1", len(history))
|
||||
}
|
||||
if len(history[0].ToolCalls) != 1 || history[0].ToolCalls[0].ID != "tc1" {
|
||||
t.Errorf("tool calls = %+v", history[0].ToolCalls)
|
||||
}
|
||||
}
|
||||
|
||||
func TestJSONLBackend_Summary(t *testing.T) {
|
||||
b := newBackend(t)
|
||||
|
||||
if got := b.GetSummary("s1"); got != "" {
|
||||
t.Errorf("got %q, want empty", got)
|
||||
}
|
||||
|
||||
b.SetSummary("s1", "test summary")
|
||||
if got := b.GetSummary("s1"); got != "test summary" {
|
||||
t.Errorf("got %q, want %q", got, "test summary")
|
||||
}
|
||||
}
|
||||
|
||||
func TestJSONLBackend_TruncateAndSave(t *testing.T) {
|
||||
b := newBackend(t)
|
||||
|
||||
for i := 0; i < 10; i++ {
|
||||
b.AddMessage("s1", "user", fmt.Sprintf("msg %d", i))
|
||||
}
|
||||
b.TruncateHistory("s1", 3)
|
||||
|
||||
history := b.GetHistory("s1")
|
||||
if len(history) != 3 {
|
||||
t.Fatalf("got %d, want 3", len(history))
|
||||
}
|
||||
if history[0].Content != "msg 7" {
|
||||
t.Errorf("got %q, want %q", history[0].Content, "msg 7")
|
||||
}
|
||||
|
||||
// Save triggers compaction.
|
||||
if err := b.Save("s1"); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Messages still accessible after compaction.
|
||||
history = b.GetHistory("s1")
|
||||
if len(history) != 3 {
|
||||
t.Fatalf("after save: got %d, want 3", len(history))
|
||||
}
|
||||
}
|
||||
|
||||
func TestJSONLBackend_SetHistory(t *testing.T) {
|
||||
b := newBackend(t)
|
||||
b.AddMessage("s1", "user", "old")
|
||||
|
||||
b.SetHistory("s1", []providers.Message{
|
||||
{Role: "user", Content: "new1"},
|
||||
{Role: "assistant", Content: "new2"},
|
||||
})
|
||||
|
||||
history := b.GetHistory("s1")
|
||||
if len(history) != 2 {
|
||||
t.Fatalf("got %d, want 2", len(history))
|
||||
}
|
||||
if history[0].Content != "new1" {
|
||||
t.Errorf("got %q, want %q", history[0].Content, "new1")
|
||||
}
|
||||
}
|
||||
|
||||
func TestJSONLBackend_EmptySession(t *testing.T) {
|
||||
b := newBackend(t)
|
||||
|
||||
history := b.GetHistory("nonexistent")
|
||||
if history == nil {
|
||||
t.Fatal("got nil, want empty slice")
|
||||
}
|
||||
if len(history) != 0 {
|
||||
t.Errorf("got %d, want 0", len(history))
|
||||
}
|
||||
}
|
||||
|
||||
func TestJSONLBackend_SessionIsolation(t *testing.T) {
|
||||
b := newBackend(t)
|
||||
b.AddMessage("s1", "user", "session1")
|
||||
b.AddMessage("s2", "user", "session2")
|
||||
|
||||
h1 := b.GetHistory("s1")
|
||||
h2 := b.GetHistory("s2")
|
||||
|
||||
if len(h1) != 1 || h1[0].Content != "session1" {
|
||||
t.Errorf("s1: %+v", h1)
|
||||
}
|
||||
if len(h2) != 1 || h2[0].Content != "session2" {
|
||||
t.Errorf("s2: %+v", h2)
|
||||
}
|
||||
}
|
||||
|
||||
func TestJSONLBackend_SummarizeFlow(t *testing.T) {
|
||||
// Simulates the real summarization flow in the agent loop:
|
||||
// SetSummary → TruncateHistory → Save
|
||||
b := newBackend(t)
|
||||
|
||||
for i := 0; i < 20; i++ {
|
||||
b.AddMessage("s1", "user", fmt.Sprintf("msg %d", i))
|
||||
}
|
||||
|
||||
b.SetSummary("s1", "conversation about testing")
|
||||
b.TruncateHistory("s1", 4)
|
||||
if err := b.Save("s1"); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if got := b.GetSummary("s1"); got != "conversation about testing" {
|
||||
t.Errorf("summary = %q", got)
|
||||
}
|
||||
history := b.GetHistory("s1")
|
||||
if len(history) != 4 {
|
||||
t.Fatalf("got %d messages, want 4", len(history))
|
||||
}
|
||||
if history[0].Content != "msg 16" {
|
||||
t.Errorf("first message = %q, want %q", history[0].Content, "msg 16")
|
||||
}
|
||||
}
|
||||
@@ -265,6 +265,12 @@ func (sm *SessionManager) loadSessions() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Close is a no-op for the in-memory SessionManager; it satisfies the
|
||||
// SessionStore interface so callers can release resources uniformly.
|
||||
func (sm *SessionManager) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetHistory updates the messages of a session.
|
||||
func (sm *SessionManager) SetHistory(key string, history []providers.Message) {
|
||||
sm.mu.Lock()
|
||||
|
||||
@@ -0,0 +1,32 @@
|
||||
package session
|
||||
|
||||
import "github.com/sipeed/picoclaw/pkg/providers"
|
||||
|
||||
// SessionStore defines the persistence operations used by the agent loop.
|
||||
// Both SessionManager (legacy JSON backend) and JSONLBackend satisfy this
|
||||
// interface, allowing the storage layer to be swapped without touching the
|
||||
// agent loop code.
|
||||
//
|
||||
// Write methods (Add*, Set*, Truncate*) are fire-and-forget: they do not
|
||||
// return errors. Implementations should log failures internally. This
|
||||
// matches the original SessionManager contract that the agent loop relies on.
|
||||
type SessionStore interface {
|
||||
// AddMessage appends a simple role/content message to the session.
|
||||
AddMessage(sessionKey, role, content string)
|
||||
// AddFullMessage appends a complete message including tool calls.
|
||||
AddFullMessage(sessionKey string, msg providers.Message)
|
||||
// GetHistory returns the full message history for the session.
|
||||
GetHistory(key string) []providers.Message
|
||||
// GetSummary returns the conversation summary, or "" if none.
|
||||
GetSummary(key string) string
|
||||
// SetSummary replaces the conversation summary.
|
||||
SetSummary(key, summary string)
|
||||
// SetHistory replaces the full message history.
|
||||
SetHistory(key string, history []providers.Message)
|
||||
// TruncateHistory keeps only the last keepLast messages.
|
||||
TruncateHistory(key string, keepLast int)
|
||||
// Save persists any pending state to durable storage.
|
||||
Save(key string) error
|
||||
// Close releases resources held by the store.
|
||||
Close() error
|
||||
}
|
||||
+241
-7
@@ -2,17 +2,24 @@ package tools
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/fs"
|
||||
"math"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/fileutil"
|
||||
"github.com/sipeed/picoclaw/pkg/logger"
|
||||
)
|
||||
|
||||
const MaxReadFileSize = 64 * 1024 // 64KB limit to avoid context overflow
|
||||
|
||||
// validatePath ensures the given path is within the workspace if restrict is true.
|
||||
func validatePath(path, workspace string, restrict bool) (string, error) {
|
||||
if workspace == "" {
|
||||
@@ -85,15 +92,30 @@ func isWithinWorkspace(candidate, workspace string) bool {
|
||||
}
|
||||
|
||||
type ReadFileTool struct {
|
||||
fs fileSystem
|
||||
fs fileSystem
|
||||
maxSize int64
|
||||
}
|
||||
|
||||
func NewReadFileTool(workspace string, restrict bool, allowPaths ...[]*regexp.Regexp) *ReadFileTool {
|
||||
func NewReadFileTool(
|
||||
workspace string,
|
||||
restrict bool,
|
||||
maxReadFileSize int,
|
||||
allowPaths ...[]*regexp.Regexp,
|
||||
) *ReadFileTool {
|
||||
var patterns []*regexp.Regexp
|
||||
if len(allowPaths) > 0 {
|
||||
patterns = allowPaths[0]
|
||||
}
|
||||
return &ReadFileTool{fs: buildFs(workspace, restrict, patterns)}
|
||||
|
||||
maxSize := int64(maxReadFileSize)
|
||||
if maxSize <= 0 {
|
||||
maxSize = MaxReadFileSize
|
||||
}
|
||||
|
||||
return &ReadFileTool{
|
||||
fs: buildFs(workspace, restrict, patterns),
|
||||
maxSize: maxSize,
|
||||
}
|
||||
}
|
||||
|
||||
func (t *ReadFileTool) Name() string {
|
||||
@@ -101,7 +123,7 @@ func (t *ReadFileTool) Name() string {
|
||||
}
|
||||
|
||||
func (t *ReadFileTool) Description() string {
|
||||
return "Read the contents of a file"
|
||||
return "Read the contents of a file. Supports pagination via `offset` and `length`."
|
||||
}
|
||||
|
||||
func (t *ReadFileTool) Parameters() map[string]any {
|
||||
@@ -110,7 +132,17 @@ func (t *ReadFileTool) Parameters() map[string]any {
|
||||
"properties": map[string]any{
|
||||
"path": map[string]any{
|
||||
"type": "string",
|
||||
"description": "Path to the file to read",
|
||||
"description": "Path to the file to read.",
|
||||
},
|
||||
"offset": map[string]any{
|
||||
"type": "integer",
|
||||
"description": "Byte offset to start reading from.",
|
||||
"default": 0,
|
||||
},
|
||||
"length": map[string]any{
|
||||
"type": "integer",
|
||||
"description": "Maximum number of bytes to read.",
|
||||
"default": t.maxSize,
|
||||
},
|
||||
},
|
||||
"required": []string{"path"},
|
||||
@@ -123,11 +155,171 @@ func (t *ReadFileTool) Execute(ctx context.Context, args map[string]any) *ToolRe
|
||||
return ErrorResult("path is required")
|
||||
}
|
||||
|
||||
content, err := t.fs.ReadFile(path)
|
||||
// offset (optional, default 0)
|
||||
offset, err := getInt64Arg(args, "offset", 0)
|
||||
if err != nil {
|
||||
return ErrorResult(err.Error())
|
||||
}
|
||||
return NewToolResult(string(content))
|
||||
if offset < 0 {
|
||||
return ErrorResult("offset must be >= 0")
|
||||
}
|
||||
|
||||
// length (optional, capped at MaxReadFileSize)
|
||||
length, err := getInt64Arg(args, "length", t.maxSize)
|
||||
if err != nil {
|
||||
return ErrorResult(err.Error())
|
||||
}
|
||||
if length <= 0 {
|
||||
return ErrorResult("length must be > 0")
|
||||
}
|
||||
if length > t.maxSize {
|
||||
length = t.maxSize
|
||||
}
|
||||
|
||||
file, err := t.fs.Open(path)
|
||||
if err != nil {
|
||||
return ErrorResult(err.Error())
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
// measure total size
|
||||
totalSize := int64(-1) // -1 means unknown
|
||||
if info, statErr := file.Stat(); statErr == nil {
|
||||
totalSize = info.Size()
|
||||
}
|
||||
|
||||
// sniff the first 512 bytes to detect binary content before loading
|
||||
// it into the LLM context. Seeking back to 0 afterwards restores state.
|
||||
sniff := make([]byte, 512)
|
||||
sniffN, _ := file.Read(sniff)
|
||||
|
||||
// Reset read position to beginning before applying the caller's offset.
|
||||
if seeker, ok := file.(io.Seeker); ok {
|
||||
_, err = seeker.Seek(0, io.SeekStart)
|
||||
if err != nil {
|
||||
return ErrorResult(fmt.Sprintf("failed to reset file position after sniff: %v", err))
|
||||
}
|
||||
} else {
|
||||
// Non-seekable: we consumed sniffN bytes above; account for them when
|
||||
// discarding to reach the requested offset below.
|
||||
// If offset < sniffN the data we already read covers it, which we
|
||||
// cannot replay on a non-seekable stream — return a clear error.
|
||||
if offset < int64(sniffN) && offset > 0 {
|
||||
return ErrorResult(
|
||||
"non-seekable file: cannot seek to an offset within the first 512 bytes after binary detection",
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
// Seek to the requested offset.
|
||||
if seeker, ok := file.(io.Seeker); ok {
|
||||
_, err = seeker.Seek(offset, io.SeekStart)
|
||||
if err != nil {
|
||||
return ErrorResult(fmt.Sprintf("failed to seek to offset %d: %v", offset, err))
|
||||
}
|
||||
} else if offset > 0 {
|
||||
// Fallback for non-seekable streams: discard leading bytes.
|
||||
// sniffN bytes were already consumed above, so subtract them.
|
||||
remaining := offset - int64(sniffN)
|
||||
if remaining > 0 {
|
||||
_, err = io.CopyN(io.Discard, file, remaining)
|
||||
if err != nil {
|
||||
return ErrorResult(fmt.Sprintf("failed to advance to offset %d: %v", offset, err))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// read length+1 bytes to reliably detect whether more content exists
|
||||
// without relying on totalSize (which may be -1 for non-seekable streams).
|
||||
// This avoids the false-positive TRUNCATED message on the last page.
|
||||
probe := make([]byte, length+1)
|
||||
n, err := io.ReadFull(file, probe)
|
||||
// FIX: io.ReadFull returns io.ErrUnexpectedEOF for partial reads (0 < n < len),
|
||||
// and io.EOF only when n == 0. Both are normal terminal conditions — only
|
||||
// other errors are genuine failures.
|
||||
if err != nil && err != io.EOF && !errors.Is(err, io.ErrUnexpectedEOF) {
|
||||
return ErrorResult(fmt.Sprintf("failed to read file content: %v", err))
|
||||
}
|
||||
|
||||
// hasMore is true only when we actually got the extra probe byte.
|
||||
hasMore := int64(n) > length
|
||||
data := probe[:min(int64(n), length)]
|
||||
|
||||
if len(data) == 0 {
|
||||
return NewToolResult("[END OF FILE - no content at this offset]")
|
||||
}
|
||||
|
||||
// Build metadata header.
|
||||
// use filepath.Base(path) instead of the raw path to avoid leaking
|
||||
// internal filesystem structure into the LLM context.
|
||||
readEnd := offset + int64(len(data))
|
||||
// use ASCII hyphen-minus instead of en-dash (U+2013) to keep the
|
||||
// header parseable by downstream tools and log processors.
|
||||
readRange := fmt.Sprintf("bytes %d-%d", offset, readEnd-1)
|
||||
|
||||
displayPath := filepath.Base(path)
|
||||
var header string
|
||||
if totalSize >= 0 {
|
||||
header = fmt.Sprintf(
|
||||
"[file: %s | total: %d bytes | read: %s]",
|
||||
displayPath, totalSize, readRange,
|
||||
)
|
||||
} else {
|
||||
header = fmt.Sprintf(
|
||||
"[file: %s | read: %s | total size unknown]",
|
||||
displayPath, readRange,
|
||||
)
|
||||
}
|
||||
|
||||
if hasMore {
|
||||
header += fmt.Sprintf(
|
||||
"\n[TRUNCATED - file has more content. Call read_file again with offset=%d to continue.]",
|
||||
readEnd,
|
||||
)
|
||||
} else {
|
||||
header += "\n[END OF FILE - no further content.]"
|
||||
}
|
||||
|
||||
logger.DebugCF("tool", "ReadFileTool execution completed successfully",
|
||||
map[string]any{
|
||||
"path": path,
|
||||
"bytes_read": len(data),
|
||||
"has_more": hasMore,
|
||||
})
|
||||
|
||||
return NewToolResult(header + "\n\n" + string(data))
|
||||
}
|
||||
|
||||
// getInt64Arg extracts an optional integer argument from args.
//
// If key is absent, defaultVal is returned with a nil error. Otherwise the
// value is converted according to its dynamic type:
//
//   - float64: must be a whole number strictly inside the int64 range
//     (NaN, fractional and out-of-range values are rejected)
//   - int, int64: returned directly
//   - string: parsed as a base-10 integer
//
// Any other type yields an error. On error the returned integer is 0.
func getInt64Arg(args map[string]any, key string, defaultVal int64) (int64, error) {
	raw, exists := args[key]
	if !exists {
		return defaultVal, nil
	}

	switch v := raw.(type) {
	case float64:
		// NaN fails this comparison too, so it is reported as non-integer.
		if v != math.Trunc(v) {
			return 0, fmt.Errorf("%s must be an integer, got float %v", key, v)
		}
		// math.MaxInt64 is not representable as a float64; it rounds up to
		// 2^63, which is already out of range. The upper bound therefore has
		// to be exclusive, otherwise v == 2^63 slips through and int64(v)
		// produces an implementation-defined value. The lower bound -2^63 is
		// exactly representable and is a valid int64.
		if v >= 1<<63 || v < -(1<<63) {
			return 0, fmt.Errorf("%s value %v overflows int64", key, v)
		}
		return int64(v), nil
	case int:
		return int64(v), nil
	case int64:
		return v, nil
	case string:
		parsed, err := strconv.ParseInt(v, 10, 64)
		if err != nil {
			return 0, fmt.Errorf("invalid integer format for %s parameter: %w", key, err)
		}
		return parsed, nil
	default:
		return 0, fmt.Errorf("unsupported type %T for %s parameter", raw, key)
	}
}
|
||||
|
||||
type WriteFileTool struct {
|
||||
@@ -249,6 +441,7 @@ type fileSystem interface {
|
||||
ReadFile(path string) ([]byte, error)
|
||||
WriteFile(path string, data []byte) error
|
||||
ReadDir(path string) ([]os.DirEntry, error)
|
||||
Open(path string) (fs.File, error)
|
||||
}
|
||||
|
||||
// hostFs is an unrestricted fileReadWriter that operates directly on the host filesystem.
|
||||
@@ -278,6 +471,20 @@ func (h *hostFs) WriteFile(path string, data []byte) error {
|
||||
return fileutil.WriteFileAtomic(path, data, 0o600)
|
||||
}
|
||||
|
||||
func (h *hostFs) Open(path string) (fs.File, error) {
|
||||
f, err := os.Open(path)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return nil, fmt.Errorf("failed to open file: file not found: %w", err)
|
||||
}
|
||||
if os.IsPermission(err) {
|
||||
return nil, fmt.Errorf("failed to open file: access denied: %w", err)
|
||||
}
|
||||
return nil, fmt.Errorf("failed to open file: %w", err)
|
||||
}
|
||||
return f, nil
|
||||
}
|
||||
|
||||
// sandboxFs is a sandboxed fileSystem that operates within a strictly defined workspace using os.Root.
|
||||
type sandboxFs struct {
|
||||
workspace string
|
||||
@@ -389,6 +596,26 @@ func (r *sandboxFs) ReadDir(path string) ([]os.DirEntry, error) {
|
||||
return entries, err
|
||||
}
|
||||
|
||||
func (r *sandboxFs) Open(path string) (fs.File, error) {
|
||||
var f fs.File
|
||||
err := r.execute(path, func(root *os.Root, relPath string) error {
|
||||
file, err := root.Open(relPath)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return fmt.Errorf("failed to open file: file not found: %w", err)
|
||||
}
|
||||
if os.IsPermission(err) || strings.Contains(err.Error(), "escapes from parent") ||
|
||||
strings.Contains(err.Error(), "permission denied") {
|
||||
return fmt.Errorf("failed to open file: access denied: %w", err)
|
||||
}
|
||||
return fmt.Errorf("failed to open file: %w", err)
|
||||
}
|
||||
f = file
|
||||
return nil
|
||||
})
|
||||
return f, err
|
||||
}
|
||||
|
||||
// whitelistFs wraps a sandboxFs and allows access to specific paths outside
|
||||
// the workspace when they match any of the provided patterns.
|
||||
type whitelistFs struct {
|
||||
@@ -427,6 +654,13 @@ func (w *whitelistFs) ReadDir(path string) ([]os.DirEntry, error) {
|
||||
return w.sandbox.ReadDir(path)
|
||||
}
|
||||
|
||||
func (w *whitelistFs) Open(path string) (fs.File, error) {
|
||||
if w.matches(path) {
|
||||
return w.host.Open(path)
|
||||
}
|
||||
return w.sandbox.Open(path)
|
||||
}
|
||||
|
||||
// buildFs returns the appropriate fileSystem implementation based on restriction
|
||||
// settings and optional path whitelist patterns.
|
||||
func buildFs(workspace string, restrict bool, patterns []*regexp.Regexp) fileSystem {
|
||||
|
||||
@@ -18,7 +18,7 @@ func TestFilesystemTool_ReadFile_Success(t *testing.T) {
|
||||
testFile := filepath.Join(tmpDir, "test.txt")
|
||||
os.WriteFile(testFile, []byte("test content"), 0o644)
|
||||
|
||||
tool := NewReadFileTool("", false)
|
||||
tool := NewReadFileTool("", false, MaxReadFileSize)
|
||||
ctx := context.Background()
|
||||
args := map[string]any{
|
||||
"path": testFile,
|
||||
@@ -45,7 +45,7 @@ func TestFilesystemTool_ReadFile_Success(t *testing.T) {
|
||||
|
||||
// TestFilesystemTool_ReadFile_NotFound verifies error handling for missing file
|
||||
func TestFilesystemTool_ReadFile_NotFound(t *testing.T) {
|
||||
tool := NewReadFileTool("", false)
|
||||
tool := NewReadFileTool("", false, MaxReadFileSize)
|
||||
ctx := context.Background()
|
||||
args := map[string]any{
|
||||
"path": "/nonexistent_file_12345.txt",
|
||||
@@ -59,7 +59,7 @@ func TestFilesystemTool_ReadFile_NotFound(t *testing.T) {
|
||||
}
|
||||
|
||||
// Should contain error message
|
||||
if !strings.Contains(result.ForLLM, "failed to read") && !strings.Contains(result.ForUser, "failed to read") {
|
||||
if !strings.Contains(result.ForLLM, "failed to open file") && !strings.Contains(result.ForUser, "failed to read") {
|
||||
t.Errorf("Expected error message, got ForLLM: %s, ForUser: %s", result.ForLLM, result.ForUser)
|
||||
}
|
||||
}
|
||||
@@ -271,7 +271,7 @@ func TestFilesystemTool_ReadFile_RejectsSymlinkEscape(t *testing.T) {
|
||||
t.Skipf("symlink not supported in this environment: %v", err)
|
||||
}
|
||||
|
||||
tool := NewReadFileTool(workspace, true)
|
||||
tool := NewReadFileTool(workspace, true, MaxReadFileSize)
|
||||
result := tool.Execute(context.Background(), map[string]any{
|
||||
"path": link,
|
||||
})
|
||||
@@ -289,7 +289,7 @@ func TestFilesystemTool_ReadFile_RejectsSymlinkEscape(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestFilesystemTool_EmptyWorkspace_AccessDenied(t *testing.T) {
|
||||
tool := NewReadFileTool("", true) // restrict=true but workspace=""
|
||||
tool := NewReadFileTool("", true, MaxReadFileSize) // restrict=true but workspace=""
|
||||
|
||||
// Try to read a sensitive file (simulated by a temp file outside workspace)
|
||||
tmpDir := t.TempDir()
|
||||
@@ -499,7 +499,7 @@ func TestWhitelistFs_AllowsMatchingPaths(t *testing.T) {
|
||||
// Pattern allows access to the outsideDir.
|
||||
patterns := []*regexp.Regexp{regexp.MustCompile(`^` + regexp.QuoteMeta(outsideDir))}
|
||||
|
||||
tool := NewReadFileTool(workspace, true, patterns)
|
||||
tool := NewReadFileTool(workspace, true, MaxReadFileSize, patterns)
|
||||
|
||||
// Read from whitelisted path should succeed.
|
||||
result := tool.Execute(context.Background(), map[string]any{"path": outsideFile})
|
||||
@@ -520,3 +520,127 @@ func TestWhitelistFs_AllowsMatchingPaths(t *testing.T) {
|
||||
t.Errorf("expected non-whitelisted path to be blocked, got: %s", result.ForLLM)
|
||||
}
|
||||
}
|
||||
|
||||
// TestReadFileTool_ChunkedReading verifies the pagination logic of the tool
|
||||
// by reading a file in multiple chunks using 'offset' and 'length'.
|
||||
func TestReadFileTool_ChunkedReading(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
testFile := filepath.Join(tmpDir, "pagination_test.txt")
|
||||
|
||||
// Create a test file with exactly 26 bytes of content
|
||||
fullContent := "abcdefghijklmnopqrstuvwxyz"
|
||||
err := os.WriteFile(testFile, []byte(fullContent), 0o644)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to write test file: %v", err)
|
||||
}
|
||||
|
||||
tool := NewReadFileTool(tmpDir, false, MaxReadFileSize)
|
||||
ctx := context.Background()
|
||||
|
||||
// --- Step 1: Read the first chunk (10 bytes) ---
|
||||
args1 := map[string]any{
|
||||
"path": testFile,
|
||||
"offset": 0,
|
||||
"length": 10,
|
||||
}
|
||||
result1 := tool.Execute(ctx, args1)
|
||||
|
||||
if result1.IsError {
|
||||
t.Fatalf("Chunk 1 failed: %s", result1.ForLLM)
|
||||
}
|
||||
|
||||
// Expect the first 10 characters
|
||||
if !strings.Contains(result1.ForLLM, "abcdefghij") {
|
||||
t.Errorf("Chunk 1 should contain 'abcdefghij', got: %s", result1.ForLLM)
|
||||
}
|
||||
// Expect the header to indicate the file is truncated
|
||||
if !strings.Contains(result1.ForLLM, "[TRUNCATED") {
|
||||
t.Errorf("Chunk 1 header should indicate truncation, got: %s", result1.ForLLM)
|
||||
}
|
||||
// Expect the header to suggest the next offset (10)
|
||||
if !strings.Contains(result1.ForLLM, "offset=10") {
|
||||
t.Errorf("Chunk 1 header should suggest next offset=10, got: %s", result1.ForLLM)
|
||||
}
|
||||
|
||||
// Step 2: Read the second chunk (10 bytes) ---
|
||||
args2 := map[string]any{
|
||||
"path": testFile,
|
||||
"offset": 10,
|
||||
"length": 10,
|
||||
}
|
||||
result2 := tool.Execute(ctx, args2)
|
||||
|
||||
if result2.IsError {
|
||||
t.Fatalf("Chunk 2 failed: %s", result2.ForLLM)
|
||||
}
|
||||
|
||||
// Expect the next 10 characters
|
||||
if !strings.Contains(result2.ForLLM, "klmnopqrst") {
|
||||
t.Errorf("Chunk 2 should contain 'klmnopqrst', got: %s", result2.ForLLM)
|
||||
}
|
||||
// Expect the header to suggest the next offset (20)
|
||||
if !strings.Contains(result2.ForLLM, "offset=20") {
|
||||
t.Errorf("Chunk 2 header should suggest next offset=20, got: %s", result2.ForLLM)
|
||||
}
|
||||
|
||||
// Step 3: Read the final chunk (remaining 6 bytes) ---
|
||||
// We ask for 10 bytes, but only 6 are left in the file
|
||||
args3 := map[string]any{
|
||||
"path": testFile,
|
||||
"offset": 20,
|
||||
"length": 10,
|
||||
}
|
||||
result3 := tool.Execute(ctx, args3)
|
||||
|
||||
if result3.IsError {
|
||||
t.Fatalf("Chunk 3 failed: %s", result3.ForLLM)
|
||||
}
|
||||
|
||||
// Expect the last 6 characters
|
||||
if !strings.Contains(result3.ForLLM, "uvwxyz") {
|
||||
t.Errorf("Chunk 3 should contain 'uvwxyz', got: %s", result3.ForLLM)
|
||||
}
|
||||
// Expect the header to indicate the end of the file
|
||||
if !strings.Contains(result3.ForLLM, "[END OF FILE") {
|
||||
t.Errorf("Chunk 3 header should indicate end of file, got: %s", result3.ForLLM)
|
||||
}
|
||||
|
||||
// Ensure no TRUNCATED message is present in the final chunk
|
||||
if strings.Contains(result3.ForLLM, "[TRUNCATED") {
|
||||
t.Errorf("Chunk 3 header should NOT indicate truncation, got: %s", result3.ForLLM)
|
||||
}
|
||||
}
|
||||
|
||||
// TestReadFileTool_OffsetBeyondEOF checks the behavior when requesting
|
||||
// An offset that exceeds the total file size.
|
||||
func TestReadFileTool_OffsetBeyondEOF(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
testFile := filepath.Join(tmpDir, "short.txt")
|
||||
|
||||
// create a file of only 5 bytes
|
||||
err := os.WriteFile(testFile, []byte("12345"), 0o644)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to write test file: %v", err)
|
||||
}
|
||||
|
||||
tool := NewReadFileTool(tmpDir, false, MaxReadFileSize)
|
||||
ctx := context.Background()
|
||||
|
||||
args := map[string]any{
|
||||
"path": testFile,
|
||||
"offset": int64(100), // Offset beyond the end of the file
|
||||
}
|
||||
|
||||
result := tool.Execute(ctx, args)
|
||||
|
||||
// It should not be classified as a tool execution error
|
||||
if result.IsError {
|
||||
t.Errorf("A mistake was not expected, obtained IsError=true: %s", result.ForLLM)
|
||||
}
|
||||
|
||||
// Must return EXACTLY the string provided in the code
|
||||
expectedMsg := "[END OF FILE - no content at this offset]"
|
||||
if result.ForLLM != expectedMsg {
|
||||
t.Errorf("The message %q was expected, obtained: %q", expectedMsg, result.ForLLM)
|
||||
}
|
||||
}
|
||||
|
||||
+137
-11
@@ -5,20 +5,28 @@ import (
|
||||
"fmt"
|
||||
"sort"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/logger"
|
||||
"github.com/sipeed/picoclaw/pkg/providers"
|
||||
)
|
||||
|
||||
type ToolEntry struct {
|
||||
Tool Tool
|
||||
IsCore bool
|
||||
TTL int
|
||||
}
|
||||
|
||||
type ToolRegistry struct {
|
||||
tools map[string]Tool
|
||||
mu sync.RWMutex
|
||||
tools map[string]*ToolEntry
|
||||
mu sync.RWMutex
|
||||
version atomic.Uint64 // incremented on Register/RegisterHidden for cache invalidation
|
||||
}
|
||||
|
||||
func NewToolRegistry() *ToolRegistry {
|
||||
return &ToolRegistry{
|
||||
tools: make(map[string]Tool),
|
||||
tools: make(map[string]*ToolEntry),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -30,14 +38,116 @@ func (r *ToolRegistry) Register(tool Tool) {
|
||||
logger.WarnCF("tools", "Tool registration overwrites existing tool",
|
||||
map[string]any{"name": name})
|
||||
}
|
||||
r.tools[name] = tool
|
||||
r.tools[name] = &ToolEntry{
|
||||
Tool: tool,
|
||||
IsCore: true,
|
||||
TTL: 0, // Core tools do not use TTL
|
||||
}
|
||||
r.version.Add(1)
|
||||
logger.DebugCF("tools", "Registered core tool", map[string]any{"name": name})
|
||||
}
|
||||
|
||||
// RegisterHidden saves hidden tools (visible only via TTL)
|
||||
func (r *ToolRegistry) RegisterHidden(tool Tool) {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
name := tool.Name()
|
||||
if _, exists := r.tools[name]; exists {
|
||||
logger.WarnCF("tools", "Hidden tool registration overwrites existing tool",
|
||||
map[string]any{"name": name})
|
||||
}
|
||||
r.tools[name] = &ToolEntry{
|
||||
Tool: tool,
|
||||
IsCore: false,
|
||||
TTL: 0,
|
||||
}
|
||||
r.version.Add(1)
|
||||
logger.DebugCF("tools", "Registered hidden tool", map[string]any{"name": name})
|
||||
}
|
||||
|
||||
// PromoteTools atomically sets the TTL for multiple non-core tools.
|
||||
// This prevents a concurrent TickTTL from decrementing between promotions.
|
||||
func (r *ToolRegistry) PromoteTools(names []string, ttl int) {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
promoted := 0
|
||||
for _, name := range names {
|
||||
if entry, exists := r.tools[name]; exists {
|
||||
if !entry.IsCore {
|
||||
entry.TTL = ttl
|
||||
promoted++
|
||||
}
|
||||
}
|
||||
}
|
||||
logger.DebugCF(
|
||||
"tools",
|
||||
"PromoteTools completed",
|
||||
map[string]any{"requested": len(names), "promoted": promoted, "ttl": ttl},
|
||||
)
|
||||
}
|
||||
|
||||
// TickTTL decreases TTL only for non-core tools
|
||||
func (r *ToolRegistry) TickTTL() {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
for _, entry := range r.tools {
|
||||
if !entry.IsCore && entry.TTL > 0 {
|
||||
entry.TTL--
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Version returns the current registry version (atomically).
|
||||
func (r *ToolRegistry) Version() uint64 {
|
||||
return r.version.Load()
|
||||
}
|
||||
|
||||
// HiddenToolSnapshot holds a consistent snapshot of hidden tools and the
|
||||
// registry version at which it was taken. Used by BM25SearchTool cache.
|
||||
type HiddenToolSnapshot struct {
|
||||
Docs []HiddenToolDoc
|
||||
Version uint64
|
||||
}
|
||||
|
||||
// HiddenToolDoc is the minimal view of a hidden tool used for search
// indexing: just its name and description.
type HiddenToolDoc struct {
	// Name is the key the tool is registered under.
	Name string
	// Description is the text returned by the tool's Description method.
	Description string
}
|
||||
|
||||
// SnapshotHiddenTools returns all non-core tools and the current registry
|
||||
// version under a single read-lock, guaranteeing consistency between the
|
||||
// two values.
|
||||
func (r *ToolRegistry) SnapshotHiddenTools() HiddenToolSnapshot {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
docs := make([]HiddenToolDoc, 0, len(r.tools))
|
||||
for name, entry := range r.tools {
|
||||
if !entry.IsCore {
|
||||
docs = append(docs, HiddenToolDoc{
|
||||
Name: name,
|
||||
Description: entry.Tool.Description(),
|
||||
})
|
||||
}
|
||||
}
|
||||
return HiddenToolSnapshot{
|
||||
Docs: docs,
|
||||
Version: r.version.Load(),
|
||||
}
|
||||
}
|
||||
|
||||
func (r *ToolRegistry) Get(name string) (Tool, bool) {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
tool, ok := r.tools[name]
|
||||
return tool, ok
|
||||
entry, ok := r.tools[name]
|
||||
if !ok {
|
||||
return nil, false
|
||||
}
|
||||
// Hidden tools with expired TTL are not callable.
|
||||
if !entry.IsCore && entry.TTL <= 0 {
|
||||
return nil, false
|
||||
}
|
||||
return entry.Tool, true
|
||||
}
|
||||
|
||||
func (r *ToolRegistry) Execute(ctx context.Context, name string, args map[string]any) *ToolResult {
|
||||
@@ -135,7 +245,13 @@ func (r *ToolRegistry) GetDefinitions() []map[string]any {
|
||||
sorted := r.sortedToolNames()
|
||||
definitions := make([]map[string]any, 0, len(sorted))
|
||||
for _, name := range sorted {
|
||||
definitions = append(definitions, ToolToSchema(r.tools[name]))
|
||||
entry := r.tools[name]
|
||||
|
||||
if !entry.IsCore && entry.TTL <= 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
definitions = append(definitions, ToolToSchema(r.tools[name].Tool))
|
||||
}
|
||||
return definitions
|
||||
}
|
||||
@@ -149,8 +265,13 @@ func (r *ToolRegistry) ToProviderDefs() []providers.ToolDefinition {
|
||||
sorted := r.sortedToolNames()
|
||||
definitions := make([]providers.ToolDefinition, 0, len(sorted))
|
||||
for _, name := range sorted {
|
||||
tool := r.tools[name]
|
||||
schema := ToolToSchema(tool)
|
||||
entry := r.tools[name]
|
||||
|
||||
if !entry.IsCore && entry.TTL <= 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
schema := ToolToSchema(entry.Tool)
|
||||
|
||||
// Safely extract nested values with type checks
|
||||
fn, ok := schema["function"].(map[string]any)
|
||||
@@ -198,8 +319,13 @@ func (r *ToolRegistry) GetSummaries() []string {
|
||||
sorted := r.sortedToolNames()
|
||||
summaries := make([]string, 0, len(sorted))
|
||||
for _, name := range sorted {
|
||||
tool := r.tools[name]
|
||||
summaries = append(summaries, fmt.Sprintf("- `%s` - %s", tool.Name(), tool.Description()))
|
||||
entry := r.tools[name]
|
||||
|
||||
if !entry.IsCore && entry.TTL <= 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
summaries = append(summaries, fmt.Sprintf("- `%s` - %s", entry.Tool.Name(), entry.Tool.Description()))
|
||||
}
|
||||
return summaries
|
||||
}
|
||||
|
||||
@@ -0,0 +1,304 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"regexp"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/logger"
|
||||
"github.com/sipeed/picoclaw/pkg/utils"
|
||||
)
|
||||
|
||||
// MaxRegexPatternLength is the longest pattern, measured with len, that
// RegexSearchTool.Execute accepts before rejecting the request.
const MaxRegexPatternLength = 200
|
||||
|
||||
type RegexSearchTool struct {
|
||||
registry *ToolRegistry
|
||||
ttl int
|
||||
maxSearchResults int
|
||||
}
|
||||
|
||||
func NewRegexSearchTool(r *ToolRegistry, ttl int, maxSearchResults int) *RegexSearchTool {
|
||||
return &RegexSearchTool{registry: r, ttl: ttl, maxSearchResults: maxSearchResults}
|
||||
}
|
||||
|
||||
func (t *RegexSearchTool) Name() string {
|
||||
return "tool_search_tool_regex"
|
||||
}
|
||||
|
||||
func (t *RegexSearchTool) Description() string {
|
||||
return "Search available hidden tools on-demand using a regex pattern. Returns JSON schemas of discovered tools."
|
||||
}
|
||||
|
||||
func (t *RegexSearchTool) Parameters() map[string]any {
|
||||
return map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]any{
|
||||
"pattern": map[string]any{
|
||||
"type": "string",
|
||||
"description": "Regex pattern to match tool name or description",
|
||||
},
|
||||
},
|
||||
"required": []string{"pattern"},
|
||||
}
|
||||
}
|
||||
|
||||
func (t *RegexSearchTool) Execute(ctx context.Context, args map[string]any) *ToolResult {
|
||||
pattern, ok := args["pattern"].(string)
|
||||
if !ok || strings.TrimSpace(pattern) == "" {
|
||||
// An empty string regex (?i) will match every hidden tool,
|
||||
// dumping massive payloads into the context and burning tokens.
|
||||
return ErrorResult("Missing or invalid 'pattern' argument. Must be a non-empty string.")
|
||||
}
|
||||
|
||||
if len(pattern) > MaxRegexPatternLength {
|
||||
logger.WarnCF("discovery", "Regex pattern rejected (too long)", map[string]any{"len": len(pattern)})
|
||||
return ErrorResult(fmt.Sprintf("Pattern too long: max %d characters allowed", MaxRegexPatternLength))
|
||||
}
|
||||
|
||||
logger.DebugCF("discovery", "Regex search", map[string]any{"pattern": pattern})
|
||||
|
||||
res, err := t.registry.SearchRegex(pattern, t.maxSearchResults)
|
||||
if err != nil {
|
||||
logger.WarnCF("discovery", "Invalid regex pattern", map[string]any{"pattern": pattern, "error": err.Error()})
|
||||
return ErrorResult(fmt.Sprintf("Invalid regex pattern syntax: %v. Please fix your regex and try again.", err))
|
||||
}
|
||||
|
||||
logger.InfoCF("discovery", "Regex search completed", map[string]any{"pattern": pattern, "results": len(res)})
|
||||
return formatDiscoveryResponse(t.registry, res, t.ttl)
|
||||
}
|
||||
|
||||
type BM25SearchTool struct {
|
||||
registry *ToolRegistry
|
||||
ttl int
|
||||
maxSearchResults int
|
||||
|
||||
// Cache: rebuilt only when the registry version changes.
|
||||
cacheMu sync.Mutex
|
||||
cachedEngine *bm25CachedEngine
|
||||
cacheVersion uint64
|
||||
}
|
||||
|
||||
func NewBM25SearchTool(r *ToolRegistry, ttl int, maxSearchResults int) *BM25SearchTool {
|
||||
return &BM25SearchTool{registry: r, ttl: ttl, maxSearchResults: maxSearchResults}
|
||||
}
|
||||
|
||||
func (t *BM25SearchTool) Name() string {
|
||||
return "tool_search_tool_bm25"
|
||||
}
|
||||
|
||||
func (t *BM25SearchTool) Description() string {
|
||||
return "Search available hidden tools on-demand using natural language query describing the action you need to perform. Returns JSON schemas of discovered tools."
|
||||
}
|
||||
|
||||
func (t *BM25SearchTool) Parameters() map[string]any {
|
||||
return map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]any{
|
||||
"query": map[string]any{
|
||||
"type": "string",
|
||||
"description": "Search query",
|
||||
},
|
||||
},
|
||||
"required": []string{"query"},
|
||||
}
|
||||
}
|
||||
|
||||
func (t *BM25SearchTool) Execute(ctx context.Context, args map[string]any) *ToolResult {
|
||||
query, ok := args["query"].(string)
|
||||
if !ok || strings.TrimSpace(query) == "" {
|
||||
// An empty string query will match every hidden tool,
|
||||
// dumping massive payloads into the context and burning tokens.
|
||||
return ErrorResult("Missing or invalid 'query' argument. Must be a non-empty string.")
|
||||
}
|
||||
|
||||
logger.DebugCF("discovery", "BM25 search", map[string]any{"query": query})
|
||||
|
||||
cached := t.getOrBuildEngine()
|
||||
if cached == nil {
|
||||
logger.DebugCF("discovery", "BM25 search: no hidden tools available", nil)
|
||||
return SilentResult("No tools found matching the query.")
|
||||
}
|
||||
|
||||
ranked := cached.engine.Search(query, t.maxSearchResults)
|
||||
if len(ranked) == 0 {
|
||||
logger.DebugCF("discovery", "BM25 search: no matches", map[string]any{"query": query})
|
||||
return SilentResult("No tools found matching the query.")
|
||||
}
|
||||
|
||||
results := make([]ToolSearchResult, len(ranked))
|
||||
for i, r := range ranked {
|
||||
results[i] = ToolSearchResult{
|
||||
Name: r.Document.Name,
|
||||
Description: r.Document.Description,
|
||||
}
|
||||
}
|
||||
|
||||
logger.InfoCF("discovery", "BM25 search completed", map[string]any{"query": query, "results": len(results)})
|
||||
return formatDiscoveryResponse(t.registry, results, t.ttl)
|
||||
}
|
||||
|
||||
// ToolSearchResult is one discovery hit as reported back to the LLM.
// It deliberately carries no parameter schema, to save context tokens; the
// LLM will see full schemas via ToProviderDefs after promotion.
type ToolSearchResult struct {
	// Name is the tool's registered name (JSON key "name").
	Name string `json:"name"`
	// Description is the tool's description text (JSON key "description").
	Description string `json:"description"`
}
|
||||
|
||||
func (r *ToolRegistry) SearchRegex(pattern string, maxSearchResults int) ([]ToolSearchResult, error) {
|
||||
if maxSearchResults <= 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
regex, err := regexp.Compile("(?i)" + pattern)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to compile regex pattern %q: %w", pattern, err)
|
||||
}
|
||||
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
|
||||
var results []ToolSearchResult
|
||||
|
||||
// Iterate in sorted order for deterministic results across calls.
|
||||
for _, name := range r.sortedToolNames() {
|
||||
entry := r.tools[name]
|
||||
// Search only among the hidden tools (Core tools are already visible)
|
||||
if !entry.IsCore {
|
||||
// Directly call interface methods! No reflection/unmarshalling needed.
|
||||
desc := entry.Tool.Description()
|
||||
|
||||
if regex.MatchString(name) || regex.MatchString(desc) {
|
||||
results = append(results, ToolSearchResult{
|
||||
Name: name,
|
||||
Description: desc,
|
||||
})
|
||||
if len(results) >= maxSearchResults {
|
||||
break // Stop searching once we hit the max! Saves CPU.
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return results, nil
|
||||
}
|
||||
|
||||
func formatDiscoveryResponse(registry *ToolRegistry, results []ToolSearchResult, ttl int) *ToolResult {
|
||||
if len(results) == 0 {
|
||||
return SilentResult("No tools found matching the query.")
|
||||
}
|
||||
|
||||
names := make([]string, len(results))
|
||||
for i, r := range results {
|
||||
names[i] = r.Name
|
||||
}
|
||||
registry.PromoteTools(names, ttl)
|
||||
logger.InfoCF("discovery", "Promoted tools", map[string]any{"tools": names, "ttl": ttl})
|
||||
|
||||
b, err := json.Marshal(results)
|
||||
if err != nil {
|
||||
return ErrorResult("Failed to format search results: " + err.Error())
|
||||
}
|
||||
|
||||
msg := fmt.Sprintf(
|
||||
"Found %d tools:\n%s\n\nSUCCESS: These tools have been temporarily UNLOCKED as native tools! In your next response, you can call them directly just like any normal tool",
|
||||
len(results),
|
||||
string(b),
|
||||
)
|
||||
|
||||
return SilentResult(msg)
|
||||
}
|
||||
|
||||
// searchDoc is the internal corpus document fed to the BM25 engine: a tool's
// name and description, nothing more.
type searchDoc struct {
	// Name is the tool name.
	Name string
	// Description is the tool's description text.
	Description string
}
|
||||
|
||||
// bm25CachedEngine wraps a BM25Engine with its corpus snapshot.
|
||||
type bm25CachedEngine struct {
|
||||
engine *utils.BM25Engine[searchDoc]
|
||||
}
|
||||
|
||||
// snapshotToSearchDocs converts a HiddenToolSnapshot to BM25 searchDoc slice.
|
||||
func snapshotToSearchDocs(snap HiddenToolSnapshot) []searchDoc {
|
||||
docs := make([]searchDoc, len(snap.Docs))
|
||||
for i, d := range snap.Docs {
|
||||
docs[i] = searchDoc{Name: d.Name, Description: d.Description}
|
||||
}
|
||||
return docs
|
||||
}
|
||||
|
||||
// buildBM25Engine creates a BM25Engine from a slice of searchDocs.
|
||||
func buildBM25Engine(docs []searchDoc) *utils.BM25Engine[searchDoc] {
|
||||
return utils.NewBM25Engine(
|
||||
docs,
|
||||
func(doc searchDoc) string {
|
||||
return doc.Name + " " + doc.Description
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
// getOrBuildEngine returns a cached BM25 engine, rebuilding it only when
|
||||
// the registry version has changed (new tools registered).
|
||||
func (t *BM25SearchTool) getOrBuildEngine() *bm25CachedEngine {
|
||||
// Fast path: optimistic check without locking.
|
||||
if t.cachedEngine != nil && t.cacheVersion == t.registry.Version() {
|
||||
return t.cachedEngine
|
||||
}
|
||||
|
||||
t.cacheMu.Lock()
|
||||
defer t.cacheMu.Unlock()
|
||||
|
||||
// Snapshot + version are read under a single registry RLock,
|
||||
// guaranteeing consistency (no TOCTOU).
|
||||
snap := t.registry.SnapshotHiddenTools()
|
||||
|
||||
// Re-check: another goroutine may have rebuilt while we waited for cacheMu.
|
||||
if t.cachedEngine != nil && t.cacheVersion == snap.Version {
|
||||
return t.cachedEngine
|
||||
}
|
||||
|
||||
docs := snapshotToSearchDocs(snap)
|
||||
if len(docs) == 0 {
|
||||
t.cachedEngine = nil
|
||||
t.cacheVersion = snap.Version
|
||||
return nil
|
||||
}
|
||||
|
||||
cached := &bm25CachedEngine{engine: buildBM25Engine(docs)}
|
||||
t.cachedEngine = cached
|
||||
t.cacheVersion = snap.Version
|
||||
logger.DebugCF("discovery", "BM25 engine rebuilt", map[string]any{"docs": len(docs), "version": snap.Version})
|
||||
return cached
|
||||
}
|
||||
|
||||
// SearchBM25 ranks hidden tools against query using BM25 via utils.BM25Engine.
|
||||
// This non-cached variant rebuilds the engine on every call. Used by tests
|
||||
// and any code that doesn't hold a BM25SearchTool instance.
|
||||
func (r *ToolRegistry) SearchBM25(query string, maxSearchResults int) []ToolSearchResult {
|
||||
snap := r.SnapshotHiddenTools()
|
||||
docs := snapshotToSearchDocs(snap)
|
||||
if len(docs) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
ranked := buildBM25Engine(docs).Search(query, maxSearchResults)
|
||||
if len(ranked) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
out := make([]ToolSearchResult, len(ranked))
|
||||
for i, r := range ranked {
|
||||
out[i] = ToolSearchResult{
|
||||
Name: r.Document.Name,
|
||||
Description: r.Document.Description,
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
@@ -0,0 +1,339 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// Dummy tool to fill the registry in our tests.
|
||||
type mockSearchableTool struct {
|
||||
name string
|
||||
desc string
|
||||
}
|
||||
|
||||
func (m *mockSearchableTool) Name() string { return m.name }
|
||||
func (m *mockSearchableTool) Description() string { return m.desc }
|
||||
func (m *mockSearchableTool) Parameters() map[string]any {
|
||||
return map[string]any{"type": "object"}
|
||||
}
|
||||
|
||||
func (m *mockSearchableTool) Execute(ctx context.Context, args map[string]any) *ToolResult {
|
||||
return SilentResult("mock executed: " + m.name)
|
||||
}
|
||||
|
||||
// Helper to initialize a populated ToolRegistry
|
||||
func setupPopulatedRegistry() *ToolRegistry {
|
||||
reg := NewToolRegistry()
|
||||
|
||||
// A core tool (NOT to be found by searches)
|
||||
reg.Register(&mockSearchableTool{
|
||||
name: "core_search",
|
||||
desc: "I am a visible core tool for searching files",
|
||||
})
|
||||
|
||||
// Hidden tools (must be found by searches)
|
||||
reg.RegisterHidden(&mockSearchableTool{
|
||||
name: "mcp_read_file",
|
||||
desc: "Read the contents of a system file",
|
||||
})
|
||||
reg.RegisterHidden(&mockSearchableTool{
|
||||
name: "mcp_list_dir",
|
||||
desc: "List directories and files in the system",
|
||||
})
|
||||
reg.RegisterHidden(&mockSearchableTool{
|
||||
name: "mcp_fetch_net",
|
||||
desc: "Fetch data from a network database",
|
||||
})
|
||||
|
||||
return reg
|
||||
}
|
||||
|
||||
func TestRegexSearchTool_Execute(t *testing.T) {
|
||||
reg := setupPopulatedRegistry()
|
||||
tool := NewRegexSearchTool(reg, 5, 10)
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("Empty Pattern Error", func(t *testing.T) {
|
||||
res := tool.Execute(ctx, map[string]any{})
|
||||
if !res.IsError || !strings.Contains(res.ForLLM, "Missing or invalid 'pattern'") {
|
||||
t.Errorf("Expected missing pattern error, got: %v", res.ForLLM)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Invalid Regex Syntax", func(t *testing.T) {
|
||||
res := tool.Execute(ctx, map[string]any{"pattern": "[unclosed"})
|
||||
if !res.IsError || !strings.Contains(res.ForLLM, "Invalid regex pattern syntax") {
|
||||
t.Errorf("Expected regex syntax error, got: %v", res.ForLLM)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("No Match Found", func(t *testing.T) {
|
||||
res := tool.Execute(ctx, map[string]any{"pattern": "alien"})
|
||||
if res.IsError || !strings.Contains(res.ForLLM, "No tools found matching") {
|
||||
t.Errorf("Expected 'no tools found' message, got: %v", res.ForLLM)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Successful Match & Promotion", func(t *testing.T) {
|
||||
res := tool.Execute(ctx, map[string]any{"pattern": "system"})
|
||||
|
||||
if res.IsError {
|
||||
t.Fatalf("Unexpected error: %v", res.ForLLM)
|
||||
}
|
||||
if !strings.Contains(res.ForLLM, "SUCCESS: These tools have been temporarily UNLOCKED") {
|
||||
t.Errorf("Expected success string, got: %v", res.ForLLM)
|
||||
}
|
||||
if !strings.Contains(res.ForLLM, "mcp_read_file") {
|
||||
t.Errorf("Expected 'mcp_read_file' in results")
|
||||
}
|
||||
|
||||
// Verify that the TTL has been updated for the tools found
|
||||
reg.mu.RLock()
|
||||
defer reg.mu.RUnlock()
|
||||
if reg.tools["mcp_read_file"].TTL != 5 {
|
||||
t.Errorf("Expected TTL of 'mcp_read_file' to be promoted to 5, got %d", reg.tools["mcp_read_file"].TTL)
|
||||
}
|
||||
if reg.tools["mcp_fetch_net"].TTL != 0 {
|
||||
t.Errorf("Expected 'mcp_fetch_net' to NOT be promoted (TTL=0)")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestBM25SearchTool_Execute(t *testing.T) {
|
||||
reg := setupPopulatedRegistry()
|
||||
tool := NewBM25SearchTool(reg, 3, 10)
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("Empty Query Error", func(t *testing.T) {
|
||||
res := tool.Execute(ctx, map[string]any{"query": " "})
|
||||
if !res.IsError || !strings.Contains(res.ForLLM, "Missing or invalid 'query'") {
|
||||
t.Errorf("Expected missing query error, got: %v", res.ForLLM)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("No Match Found", func(t *testing.T) {
|
||||
res := tool.Execute(ctx, map[string]any{"query": "aliens spaceships"})
|
||||
if res.IsError || !strings.Contains(res.ForLLM, "No tools found matching") {
|
||||
t.Errorf("Expected 'no tools found', got: %v", res.ForLLM)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Successful Match & Promotion", func(t *testing.T) {
|
||||
res := tool.Execute(ctx, map[string]any{"query": "read files"})
|
||||
|
||||
if res.IsError {
|
||||
t.Fatalf("Unexpected error: %v", res.ForLLM)
|
||||
}
|
||||
if !strings.Contains(res.ForLLM, "mcp_read_file") {
|
||||
t.Errorf("Expected 'mcp_read_file' in BM25 results")
|
||||
}
|
||||
|
||||
reg.mu.RLock()
|
||||
defer reg.mu.RUnlock()
|
||||
if reg.tools["mcp_read_file"].TTL != 3 {
|
||||
t.Errorf("Expected TTL of 'mcp_read_file' to be promoted to 3")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestRegexSearchTool_PatternTooLong(t *testing.T) {
|
||||
reg := setupPopulatedRegistry()
|
||||
tool := NewRegexSearchTool(reg, 5, 10)
|
||||
ctx := context.Background()
|
||||
|
||||
longPattern := strings.Repeat("a", MaxRegexPatternLength+1)
|
||||
res := tool.Execute(ctx, map[string]any{"pattern": longPattern})
|
||||
if !res.IsError || !strings.Contains(res.ForLLM, "Pattern too long") {
|
||||
t.Errorf("Expected pattern too long error, got: %v", res.ForLLM)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSearchRegex_ZeroMaxResults(t *testing.T) {
|
||||
reg := setupPopulatedRegistry()
|
||||
|
||||
res, err := reg.SearchRegex("mcp", 0)
|
||||
if err != nil {
|
||||
t.Fatalf("SearchRegex failed: %v", err)
|
||||
}
|
||||
if len(res) != 0 {
|
||||
t.Errorf("Expected 0 results with maxSearchResults=0, got %d", len(res))
|
||||
}
|
||||
}
|
||||
|
||||
func TestSearchBM25_ZeroMaxResults(t *testing.T) {
|
||||
reg := setupPopulatedRegistry()
|
||||
|
||||
res := reg.SearchBM25("read file", 0)
|
||||
if len(res) != 0 {
|
||||
t.Errorf("Expected 0 results with maxSearchResults=0, got %d", len(res))
|
||||
}
|
||||
}
|
||||
|
||||
func TestSearchRegex_DeterministicOrder(t *testing.T) {
|
||||
reg := NewToolRegistry()
|
||||
for i := 0; i < 20; i++ {
|
||||
reg.RegisterHidden(&mockSearchableTool{
|
||||
name: fmt.Sprintf("tool_%02d", i),
|
||||
desc: "searchable tool",
|
||||
})
|
||||
}
|
||||
|
||||
// Run the same search multiple times and verify order is stable
|
||||
var firstRun []string
|
||||
for attempt := 0; attempt < 10; attempt++ {
|
||||
res, err := reg.SearchRegex("searchable", 20)
|
||||
if err != nil {
|
||||
t.Fatalf("SearchRegex failed: %v", err)
|
||||
}
|
||||
|
||||
names := make([]string, len(res))
|
||||
for i, r := range res {
|
||||
names[i] = r.Name
|
||||
}
|
||||
|
||||
if attempt == 0 {
|
||||
firstRun = names
|
||||
} else {
|
||||
for i, name := range names {
|
||||
if name != firstRun[i] {
|
||||
t.Fatalf("Non-deterministic order at attempt %d, index %d: got %q, want %q",
|
||||
attempt, i, name, firstRun[i])
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestToolRegistry_SearchLimitsAndCoreFiltering(t *testing.T) {
|
||||
reg := NewToolRegistry()
|
||||
|
||||
// Add 1 Core and 10 Hidden, all containing the word "match"
|
||||
reg.Register(&mockSearchableTool{"core_match", "I am core with match"})
|
||||
for i := 0; i < 10; i++ {
|
||||
reg.RegisterHidden(&mockSearchableTool{
|
||||
name: fmt.Sprintf("hidden_match_%d", i),
|
||||
desc: "this has a match",
|
||||
})
|
||||
}
|
||||
|
||||
t.Run("Regex limits and core filtering", func(t *testing.T) {
|
||||
// Search with Regex and a limit of maxSearchResults = 4
|
||||
res, err := reg.SearchRegex("match", 4)
|
||||
if err != nil {
|
||||
t.Fatalf("SearchRegex failed: %v", err)
|
||||
}
|
||||
|
||||
if len(res) != 4 {
|
||||
t.Errorf("Expected exactly 4 results due to limit, got %d", len(res))
|
||||
}
|
||||
|
||||
for _, r := range res {
|
||||
if r.Name == "core_match" {
|
||||
t.Errorf("SearchRegex returned a Core tool, which should be excluded")
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("BM25 limits and core filtering", func(t *testing.T) {
|
||||
// Search with BM25 and a limit of maxSearchResults = 3
|
||||
res := reg.SearchBM25("match", 3)
|
||||
|
||||
if len(res) != 3 {
|
||||
t.Errorf("Expected exactly 3 results due to limit, got %d", len(res))
|
||||
}
|
||||
|
||||
for _, r := range res {
|
||||
if r.Name == "core_match" {
|
||||
t.Errorf("SearchBM25 returned a Core tool, which should be excluded")
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestGet_HiddenToolTTLLifecycle(t *testing.T) {
|
||||
reg := NewToolRegistry()
|
||||
reg.RegisterHidden(&mockSearchableTool{name: "hidden_tool", desc: "test"})
|
||||
|
||||
// TTL=0 at registration → not gettable
|
||||
_, ok := reg.Get("hidden_tool")
|
||||
if ok {
|
||||
t.Error("Expected hidden tool with TTL=0 to NOT be gettable")
|
||||
}
|
||||
|
||||
// Promote → gettable
|
||||
reg.PromoteTools([]string{"hidden_tool"}, 3)
|
||||
_, ok = reg.Get("hidden_tool")
|
||||
if !ok {
|
||||
t.Error("Expected promoted hidden tool to be gettable")
|
||||
}
|
||||
|
||||
// Tick down to 0 → not gettable again
|
||||
reg.TickTTL() // 3→2
|
||||
reg.TickTTL() // 2→1
|
||||
reg.TickTTL() // 1→0
|
||||
_, ok = reg.Get("hidden_tool")
|
||||
if ok {
|
||||
t.Error("Expected hidden tool with TTL ticked to 0 to NOT be gettable")
|
||||
}
|
||||
|
||||
// Core tools remain always gettable
|
||||
reg.Register(&mockSearchableTool{name: "core_tool", desc: "core"})
|
||||
_, ok = reg.Get("core_tool")
|
||||
if !ok {
|
||||
t.Error("Expected core tool to always be gettable")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBM25CacheInvalidation(t *testing.T) {
|
||||
reg := NewToolRegistry()
|
||||
reg.RegisterHidden(&mockSearchableTool{name: "tool_alpha", desc: "alpha functionality"})
|
||||
|
||||
tool := NewBM25SearchTool(reg, 5, 10)
|
||||
ctx := context.Background()
|
||||
|
||||
// First search should find tool_alpha
|
||||
res := tool.Execute(ctx, map[string]any{"query": "alpha"})
|
||||
if !strings.Contains(res.ForLLM, "tool_alpha") {
|
||||
t.Fatalf("Expected 'tool_alpha' in first search, got: %v", res.ForLLM)
|
||||
}
|
||||
|
||||
// Register a new hidden tool
|
||||
reg.RegisterHidden(&mockSearchableTool{name: "tool_beta", desc: "beta functionality"})
|
||||
|
||||
// Cache should be invalidated; new tool should be findable
|
||||
res = tool.Execute(ctx, map[string]any{"query": "beta"})
|
||||
if !strings.Contains(res.ForLLM, "tool_beta") {
|
||||
t.Errorf("Expected 'tool_beta' after cache invalidation, got: %v", res.ForLLM)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPromoteTools_ConcurrentWithTickTTL(t *testing.T) {
|
||||
reg := NewToolRegistry()
|
||||
for i := 0; i < 20; i++ {
|
||||
reg.RegisterHidden(&mockSearchableTool{
|
||||
name: fmt.Sprintf("concurrent_tool_%d", i),
|
||||
desc: "concurrent test tool",
|
||||
})
|
||||
}
|
||||
|
||||
names := make([]string, 20)
|
||||
for i := 0; i < 20; i++ {
|
||||
names[i] = fmt.Sprintf("concurrent_tool_%d", i)
|
||||
}
|
||||
|
||||
// Hammer PromoteTools and TickTTL concurrently to detect races
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
for i := 0; i < 1000; i++ {
|
||||
reg.PromoteTools(names, 5)
|
||||
}
|
||||
close(done)
|
||||
}()
|
||||
|
||||
for i := 0; i < 1000; i++ {
|
||||
reg.TickTTL()
|
||||
}
|
||||
<-done
|
||||
}
|
||||
+295
-188
@@ -11,6 +11,7 @@ import (
|
||||
"net/url"
|
||||
"regexp"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
@@ -76,81 +77,140 @@ func createHTTPClient(proxyURL string, timeout time.Duration) (*http.Client, err
|
||||
return client, nil
|
||||
}
|
||||
|
||||
// APIKeyPool holds a fixed list of API keys plus a shared counter that rotates
// the starting key between successive iterators (round-robin load balancing).
type APIKeyPool struct {
	keys    []string
	current uint32 // advanced atomically by NewIterator; wraps at 2^32
}

// NewAPIKeyPool returns a pool over keys. The slice is stored as given, not
// copied.
func NewAPIKeyPool(keys []string) *APIKeyPool {
	return &APIKeyPool{
		keys: keys,
	}
}

// APIKeyIterator walks every key of a pool exactly once, beginning at the
// position it was assigned when created.
type APIKeyIterator struct {
	pool     *APIKeyPool
	startIdx uint32 // index of the first key to hand out, always < len(pool.keys)
	attempt  uint32 // number of keys already handed out
}

// NewIterator returns an iterator over all keys of the pool. Each call
// advances the shared counter, so consecutive iterators start on consecutive
// keys. For an empty pool the returned iterator is immediately exhausted.
func (p *APIKeyPool) NewIterator() *APIKeyIterator {
	length := uint32(len(p.keys))
	if length == 0 {
		return &APIKeyIterator{pool: p}
	}
	idx := atomic.AddUint32(&p.current, 1) - 1
	return &APIKeyIterator{
		pool: p,
		// Reduce to a real index right away so that later additions in Next
		// start from a small number instead of a raw, possibly huge, counter.
		startIdx: idx % length,
	}
}

// Next returns the next key and true, or "" and false once every key has been
// returned (or when the iterator has no pool or the pool has no keys).
func (it *APIKeyIterator) Next() (string, bool) {
	if it.pool == nil {
		return "", false
	}
	length := uint32(len(it.pool.keys))
	if length == 0 || it.attempt >= length {
		return "", false
	}
	// Add in 64 bits: a uint32 sum can wrap around, which would repeat one key
	// and skip another within a single pass.
	pos := (uint64(it.startIdx) + uint64(it.attempt)) % uint64(length)
	it.attempt++
	return it.pool.keys[pos], true
}
|
||||
|
||||
// SearchProvider is the interface NewWebSearchTool selects an implementation
// of; the Brave, Tavily and Perplexity providers in this file are examples.
type SearchProvider interface {
|
||||
// Search runs query against the backend, asking for up to count results,
// and returns the results formatted as text.
Search(ctx context.Context, query string, count int) (string, error)
|
||||
}
|
||||
|
||||
type BraveSearchProvider struct {
|
||||
apiKey string
|
||||
proxy string
|
||||
client *http.Client
|
||||
keyPool *APIKeyPool
|
||||
proxy string
|
||||
client *http.Client
|
||||
}
|
||||
|
||||
func (p *BraveSearchProvider) Search(ctx context.Context, query string, count int) (string, error) {
|
||||
searchURL := fmt.Sprintf("https://api.search.brave.com/res/v1/web/search?q=%s&count=%d",
|
||||
url.QueryEscape(query), count)
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", searchURL, nil)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
var lastErr error
|
||||
iter := p.keyPool.NewIterator()
|
||||
|
||||
req.Header.Set("Accept", "application/json")
|
||||
req.Header.Set("X-Subscription-Token", p.apiKey)
|
||||
|
||||
resp, err := p.client.Do(req)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("request failed: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to read response: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return "", fmt.Errorf("brave api error (status %d): %s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
var searchResp struct {
|
||||
Web struct {
|
||||
Results []struct {
|
||||
Title string `json:"title"`
|
||||
URL string `json:"url"`
|
||||
Description string `json:"description"`
|
||||
} `json:"results"`
|
||||
} `json:"web"`
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(body, &searchResp); err != nil {
|
||||
// Log error body for debugging
|
||||
fmt.Printf("Brave API Error Body: %s\n", string(body))
|
||||
return "", fmt.Errorf("failed to parse response: %w", err)
|
||||
}
|
||||
|
||||
results := searchResp.Web.Results
|
||||
if len(results) == 0 {
|
||||
return fmt.Sprintf("No results for: %s", query), nil
|
||||
}
|
||||
|
||||
var lines []string
|
||||
lines = append(lines, fmt.Sprintf("Results for: %s", query))
|
||||
for i, item := range results {
|
||||
if i >= count {
|
||||
for {
|
||||
apiKey, ok := iter.Next()
|
||||
if !ok {
|
||||
break
|
||||
}
|
||||
lines = append(lines, fmt.Sprintf("%d. %s\n %s", i+1, item.Title, item.URL))
|
||||
if item.Description != "" {
|
||||
lines = append(lines, fmt.Sprintf(" %s", item.Description))
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", searchURL, nil)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
|
||||
req.Header.Set("Accept", "application/json")
|
||||
req.Header.Set("X-Subscription-Token", apiKey)
|
||||
|
||||
resp, err := p.client.Do(req)
|
||||
if err != nil {
|
||||
lastErr = fmt.Errorf("request failed: %w", err)
|
||||
continue
|
||||
}
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
resp.Body.Close()
|
||||
|
||||
if err != nil {
|
||||
lastErr = fmt.Errorf("failed to read response: %w", err)
|
||||
continue
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
lastErr = fmt.Errorf("API error (status %d): %s", resp.StatusCode, string(body))
|
||||
if resp.StatusCode == http.StatusTooManyRequests ||
|
||||
resp.StatusCode == http.StatusUnauthorized ||
|
||||
resp.StatusCode == http.StatusForbidden ||
|
||||
resp.StatusCode >= 500 {
|
||||
continue
|
||||
}
|
||||
return "", lastErr
|
||||
}
|
||||
|
||||
var searchResp struct {
|
||||
Web struct {
|
||||
Results []struct {
|
||||
Title string `json:"title"`
|
||||
URL string `json:"url"`
|
||||
Description string `json:"description"`
|
||||
} `json:"results"`
|
||||
} `json:"web"`
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(body, &searchResp); err != nil {
|
||||
// Log error body for debugging
|
||||
return "", fmt.Errorf("failed to parse response: %w", err)
|
||||
}
|
||||
|
||||
results := searchResp.Web.Results
|
||||
if len(results) == 0 {
|
||||
return fmt.Sprintf("No results for: %s", query), nil
|
||||
}
|
||||
|
||||
var lines []string
|
||||
lines = append(lines, fmt.Sprintf("Results for: %s", query))
|
||||
for i, item := range results {
|
||||
if i >= count {
|
||||
break
|
||||
}
|
||||
lines = append(lines, fmt.Sprintf("%d. %s\n %s", i+1, item.Title, item.URL))
|
||||
if item.Description != "" {
|
||||
lines = append(lines, fmt.Sprintf(" %s", item.Description))
|
||||
}
|
||||
}
|
||||
|
||||
return strings.Join(lines, "\n"), nil
|
||||
}
|
||||
|
||||
return strings.Join(lines, "\n"), nil
|
||||
return "", fmt.Errorf("all api keys failed, last error: %w", lastErr)
|
||||
}
|
||||
|
||||
type TavilySearchProvider struct {
|
||||
apiKey string
|
||||
keyPool *APIKeyPool
|
||||
baseURL string
|
||||
proxy string
|
||||
client *http.Client
|
||||
@@ -162,74 +222,96 @@ func (p *TavilySearchProvider) Search(ctx context.Context, query string, count i
|
||||
searchURL = "https://api.tavily.com/search"
|
||||
}
|
||||
|
||||
payload := map[string]any{
|
||||
"api_key": p.apiKey,
|
||||
"query": query,
|
||||
"search_depth": "advanced",
|
||||
"include_answer": false,
|
||||
"include_images": false,
|
||||
"include_raw_content": false,
|
||||
"max_results": count,
|
||||
}
|
||||
var lastErr error
|
||||
iter := p.keyPool.NewIterator()
|
||||
|
||||
bodyBytes, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to marshal payload: %w", err)
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", searchURL, bytes.NewBuffer(bodyBytes))
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("User-Agent", userAgent)
|
||||
|
||||
resp, err := p.client.Do(req)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("request failed: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to read response: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return "", fmt.Errorf("tavily api error (status %d): %s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
var searchResp struct {
|
||||
Results []struct {
|
||||
Title string `json:"title"`
|
||||
URL string `json:"url"`
|
||||
Content string `json:"content"`
|
||||
} `json:"results"`
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(body, &searchResp); err != nil {
|
||||
return "", fmt.Errorf("failed to parse response: %w", err)
|
||||
}
|
||||
|
||||
results := searchResp.Results
|
||||
if len(results) == 0 {
|
||||
return fmt.Sprintf("No results for: %s", query), nil
|
||||
}
|
||||
|
||||
var lines []string
|
||||
lines = append(lines, fmt.Sprintf("Results for: %s (via Tavily)", query))
|
||||
for i, item := range results {
|
||||
if i >= count {
|
||||
for {
|
||||
apiKey, ok := iter.Next()
|
||||
if !ok {
|
||||
break
|
||||
}
|
||||
lines = append(lines, fmt.Sprintf("%d. %s\n %s", i+1, item.Title, item.URL))
|
||||
if item.Content != "" {
|
||||
lines = append(lines, fmt.Sprintf(" %s", item.Content))
|
||||
|
||||
payload := map[string]any{
|
||||
"api_key": apiKey,
|
||||
"query": query,
|
||||
"search_depth": "advanced",
|
||||
"include_answer": false,
|
||||
"include_images": false,
|
||||
"include_raw_content": false,
|
||||
"max_results": count,
|
||||
}
|
||||
|
||||
bodyBytes, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to marshal payload: %w", err)
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", searchURL, bytes.NewBuffer(bodyBytes))
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("User-Agent", userAgent)
|
||||
|
||||
resp, err := p.client.Do(req)
|
||||
if err != nil {
|
||||
lastErr = fmt.Errorf("request failed: %w", err)
|
||||
continue
|
||||
}
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
resp.Body.Close()
|
||||
|
||||
if err != nil {
|
||||
lastErr = fmt.Errorf("failed to read response: %w", err)
|
||||
continue
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
lastErr = fmt.Errorf("tavily api error (status %d): %s", resp.StatusCode, string(body))
|
||||
if resp.StatusCode == http.StatusTooManyRequests ||
|
||||
resp.StatusCode == http.StatusUnauthorized ||
|
||||
resp.StatusCode == http.StatusForbidden ||
|
||||
resp.StatusCode >= 500 {
|
||||
continue
|
||||
}
|
||||
return "", lastErr
|
||||
}
|
||||
|
||||
var searchResp struct {
|
||||
Results []struct {
|
||||
Title string `json:"title"`
|
||||
URL string `json:"url"`
|
||||
Content string `json:"content"`
|
||||
} `json:"results"`
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(body, &searchResp); err != nil {
|
||||
return "", fmt.Errorf("failed to parse response: %w", err)
|
||||
}
|
||||
|
||||
results := searchResp.Results
|
||||
if len(results) == 0 {
|
||||
return fmt.Sprintf("No results for: %s", query), nil
|
||||
}
|
||||
|
||||
var lines []string
|
||||
lines = append(lines, fmt.Sprintf("Results for: %s (via Tavily)", query))
|
||||
for i, item := range results {
|
||||
if i >= count {
|
||||
break
|
||||
}
|
||||
lines = append(lines, fmt.Sprintf("%d. %s\n %s", i+1, item.Title, item.URL))
|
||||
if item.Content != "" {
|
||||
lines = append(lines, fmt.Sprintf(" %s", item.Content))
|
||||
}
|
||||
}
|
||||
|
||||
return strings.Join(lines, "\n"), nil
|
||||
}
|
||||
|
||||
return strings.Join(lines, "\n"), nil
|
||||
return "", fmt.Errorf("all api keys failed, last error: %w", lastErr)
|
||||
}
|
||||
|
||||
type DuckDuckGoSearchProvider struct {
|
||||
@@ -324,75 +406,97 @@ func stripTags(content string) string {
|
||||
}
|
||||
|
||||
type PerplexitySearchProvider struct {
|
||||
apiKey string
|
||||
proxy string
|
||||
client *http.Client
|
||||
keyPool *APIKeyPool
|
||||
proxy string
|
||||
client *http.Client
|
||||
}
|
||||
|
||||
func (p *PerplexitySearchProvider) Search(ctx context.Context, query string, count int) (string, error) {
|
||||
searchURL := "https://api.perplexity.ai/chat/completions"
|
||||
|
||||
payload := map[string]any{
|
||||
"model": "sonar",
|
||||
"messages": []map[string]string{
|
||||
{
|
||||
"role": "system",
|
||||
"content": "You are a search assistant. Provide concise search results with titles, URLs, and brief descriptions in the following format:\n1. Title\n URL\n Description\n\nDo not add extra commentary.",
|
||||
var lastErr error
|
||||
iter := p.keyPool.NewIterator()
|
||||
|
||||
for {
|
||||
apiKey, ok := iter.Next()
|
||||
if !ok {
|
||||
break
|
||||
}
|
||||
|
||||
payload := map[string]any{
|
||||
"model": "sonar",
|
||||
"messages": []map[string]string{
|
||||
{
|
||||
"role": "system",
|
||||
"content": "You are a search assistant. Provide concise search results with titles, URLs, and brief descriptions in the following format:\n1. Title\n URL\n Description\n\nDo not add extra commentary.",
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": fmt.Sprintf("Search for: %s. Provide up to %d relevant results.", query, count),
|
||||
},
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": fmt.Sprintf("Search for: %s. Provide up to %d relevant results.", query, count),
|
||||
},
|
||||
},
|
||||
"max_tokens": 1000,
|
||||
"max_tokens": 1000,
|
||||
}
|
||||
|
||||
payloadBytes, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to marshal request: %w", err)
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", searchURL, strings.NewReader(string(payloadBytes)))
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", "Bearer "+apiKey)
|
||||
req.Header.Set("User-Agent", userAgent)
|
||||
|
||||
resp, err := p.client.Do(req)
|
||||
if err != nil {
|
||||
lastErr = fmt.Errorf("request failed: %w", err)
|
||||
continue
|
||||
}
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
resp.Body.Close()
|
||||
|
||||
if err != nil {
|
||||
lastErr = fmt.Errorf("failed to read response: %w", err)
|
||||
continue
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
lastErr = fmt.Errorf("Perplexity API error: %s", string(body))
|
||||
if resp.StatusCode == http.StatusTooManyRequests ||
|
||||
resp.StatusCode == http.StatusUnauthorized ||
|
||||
resp.StatusCode == http.StatusForbidden ||
|
||||
resp.StatusCode >= 500 {
|
||||
continue
|
||||
}
|
||||
return "", lastErr
|
||||
}
|
||||
|
||||
var searchResp struct {
|
||||
Choices []struct {
|
||||
Message struct {
|
||||
Content string `json:"content"`
|
||||
} `json:"message"`
|
||||
} `json:"choices"`
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(body, &searchResp); err != nil {
|
||||
return "", fmt.Errorf("failed to parse response: %w", err)
|
||||
}
|
||||
|
||||
if len(searchResp.Choices) == 0 {
|
||||
return fmt.Sprintf("No results for: %s", query), nil
|
||||
}
|
||||
|
||||
return fmt.Sprintf("Results for: %s (via Perplexity)\n%s", query, searchResp.Choices[0].Message.Content), nil
|
||||
}
|
||||
|
||||
payloadBytes, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to marshal request: %w", err)
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", searchURL, strings.NewReader(string(payloadBytes)))
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", "Bearer "+p.apiKey)
|
||||
req.Header.Set("User-Agent", userAgent)
|
||||
|
||||
resp, err := p.client.Do(req)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("request failed: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to read response: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return "", fmt.Errorf("Perplexity API error: %s", string(body))
|
||||
}
|
||||
|
||||
var searchResp struct {
|
||||
Choices []struct {
|
||||
Message struct {
|
||||
Content string `json:"content"`
|
||||
} `json:"message"`
|
||||
} `json:"choices"`
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(body, &searchResp); err != nil {
|
||||
return "", fmt.Errorf("failed to parse response: %w", err)
|
||||
}
|
||||
|
||||
if len(searchResp.Choices) == 0 {
|
||||
return fmt.Sprintf("No results for: %s", query), nil
|
||||
}
|
||||
|
||||
return fmt.Sprintf("Results for: %s (via Perplexity)\n%s", query, searchResp.Choices[0].Message.Content), nil
|
||||
return "", fmt.Errorf("all api keys failed, last error: %w", lastErr)
|
||||
}
|
||||
|
||||
type SearXNGSearchProvider struct {
|
||||
@@ -545,16 +649,16 @@ type WebSearchTool struct {
|
||||
}
|
||||
|
||||
type WebSearchToolOptions struct {
|
||||
BraveAPIKey string
|
||||
BraveAPIKeys []string
|
||||
BraveMaxResults int
|
||||
BraveEnabled bool
|
||||
TavilyAPIKey string
|
||||
TavilyAPIKeys []string
|
||||
TavilyBaseURL string
|
||||
TavilyMaxResults int
|
||||
TavilyEnabled bool
|
||||
DuckDuckGoMaxResults int
|
||||
DuckDuckGoEnabled bool
|
||||
PerplexityAPIKey string
|
||||
PerplexityAPIKeys []string
|
||||
PerplexityMaxResults int
|
||||
PerplexityEnabled bool
|
||||
SearXNGBaseURL string
|
||||
@@ -571,23 +675,26 @@ type WebSearchToolOptions struct {
|
||||
func NewWebSearchTool(opts WebSearchToolOptions) (*WebSearchTool, error) {
|
||||
var provider SearchProvider
|
||||
maxResults := 5
|
||||
|
||||
// Priority: Perplexity > Brave > SearXNG > Tavily > DuckDuckGo > GLM Search
|
||||
if opts.PerplexityEnabled && opts.PerplexityAPIKey != "" {
|
||||
if opts.PerplexityEnabled && len(opts.PerplexityAPIKeys) > 0 {
|
||||
client, err := createHTTPClient(opts.Proxy, perplexityTimeout)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create HTTP client for Perplexity: %w", err)
|
||||
}
|
||||
provider = &PerplexitySearchProvider{apiKey: opts.PerplexityAPIKey, proxy: opts.Proxy, client: client}
|
||||
provider = &PerplexitySearchProvider{
|
||||
keyPool: NewAPIKeyPool(opts.PerplexityAPIKeys),
|
||||
proxy: opts.Proxy,
|
||||
client: client,
|
||||
}
|
||||
if opts.PerplexityMaxResults > 0 {
|
||||
maxResults = opts.PerplexityMaxResults
|
||||
}
|
||||
} else if opts.BraveEnabled && opts.BraveAPIKey != "" {
|
||||
} else if opts.BraveEnabled && len(opts.BraveAPIKeys) > 0 {
|
||||
client, err := createHTTPClient(opts.Proxy, searchTimeout)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create HTTP client for Brave: %w", err)
|
||||
}
|
||||
provider = &BraveSearchProvider{apiKey: opts.BraveAPIKey, proxy: opts.Proxy, client: client}
|
||||
provider = &BraveSearchProvider{keyPool: NewAPIKeyPool(opts.BraveAPIKeys), proxy: opts.Proxy, client: client}
|
||||
if opts.BraveMaxResults > 0 {
|
||||
maxResults = opts.BraveMaxResults
|
||||
}
|
||||
@@ -596,13 +703,13 @@ func NewWebSearchTool(opts WebSearchToolOptions) (*WebSearchTool, error) {
|
||||
if opts.SearXNGMaxResults > 0 {
|
||||
maxResults = opts.SearXNGMaxResults
|
||||
}
|
||||
} else if opts.TavilyEnabled && opts.TavilyAPIKey != "" {
|
||||
} else if opts.TavilyEnabled && len(opts.TavilyAPIKeys) > 0 {
|
||||
client, err := createHTTPClient(opts.Proxy, searchTimeout)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create HTTP client for Tavily: %w", err)
|
||||
}
|
||||
provider = &TavilySearchProvider{
|
||||
apiKey: opts.TavilyAPIKey,
|
||||
keyPool: NewAPIKeyPool(opts.TavilyAPIKeys),
|
||||
baseURL: opts.TavilyBaseURL,
|
||||
proxy: opts.Proxy,
|
||||
client: client,
|
||||
|
||||
+124
-5
@@ -249,7 +249,7 @@ func TestWebFetchTool_PayloadTooLarge(t *testing.T) {
|
||||
|
||||
// TestWebTool_WebSearch_NoApiKey verifies that no tool is created when API key is missing
|
||||
func TestWebTool_WebSearch_NoApiKey(t *testing.T) {
|
||||
tool, err := NewWebSearchTool(WebSearchToolOptions{BraveEnabled: true, BraveAPIKey: ""})
|
||||
tool, err := NewWebSearchTool(WebSearchToolOptions{BraveEnabled: true, BraveAPIKeys: nil})
|
||||
if err != nil {
|
||||
t.Fatalf("Unexpected error: %v", err)
|
||||
}
|
||||
@@ -269,7 +269,11 @@ func TestWebTool_WebSearch_NoApiKey(t *testing.T) {
|
||||
|
||||
// TestWebTool_WebSearch_MissingQuery verifies error handling for missing query
|
||||
func TestWebTool_WebSearch_MissingQuery(t *testing.T) {
|
||||
tool, err := NewWebSearchTool(WebSearchToolOptions{BraveEnabled: true, BraveAPIKey: "test-key", BraveMaxResults: 5})
|
||||
tool, err := NewWebSearchTool(WebSearchToolOptions{
|
||||
BraveEnabled: true,
|
||||
BraveAPIKeys: []string{"test-key"},
|
||||
BraveMaxResults: 5,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Unexpected error: %v", err)
|
||||
}
|
||||
@@ -553,7 +557,7 @@ func TestNewWebSearchTool_PropagatesProxy(t *testing.T) {
|
||||
t.Run("perplexity", func(t *testing.T) {
|
||||
tool, err := NewWebSearchTool(WebSearchToolOptions{
|
||||
PerplexityEnabled: true,
|
||||
PerplexityAPIKey: "k",
|
||||
PerplexityAPIKeys: []string{"k"},
|
||||
PerplexityMaxResults: 3,
|
||||
Proxy: "http://127.0.0.1:7890",
|
||||
})
|
||||
@@ -572,7 +576,7 @@ func TestNewWebSearchTool_PropagatesProxy(t *testing.T) {
|
||||
t.Run("brave", func(t *testing.T) {
|
||||
tool, err := NewWebSearchTool(WebSearchToolOptions{
|
||||
BraveEnabled: true,
|
||||
BraveAPIKey: "k",
|
||||
BraveAPIKeys: []string{"k"},
|
||||
BraveMaxResults: 3,
|
||||
Proxy: "http://127.0.0.1:7890",
|
||||
})
|
||||
@@ -650,7 +654,7 @@ func TestWebTool_TavilySearch_Success(t *testing.T) {
|
||||
|
||||
tool, err := NewWebSearchTool(WebSearchToolOptions{
|
||||
TavilyEnabled: true,
|
||||
TavilyAPIKey: "test-key",
|
||||
TavilyAPIKeys: []string{"test-key"},
|
||||
TavilyBaseURL: server.URL,
|
||||
TavilyMaxResults: 5,
|
||||
})
|
||||
@@ -682,6 +686,121 @@ func TestWebTool_TavilySearch_Success(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestAPIKeyPool(t *testing.T) {
|
||||
pool := NewAPIKeyPool([]string{"key1", "key2", "key3"})
|
||||
if len(pool.keys) != 3 {
|
||||
t.Fatalf("expected 3 keys, got %d", len(pool.keys))
|
||||
}
|
||||
if pool.keys[0] != "key1" || pool.keys[1] != "key2" || pool.keys[2] != "key3" {
|
||||
t.Fatalf("unexpected keys: %v", pool.keys)
|
||||
}
|
||||
|
||||
// Test Iterator: each iterator should cover all keys exactly once
|
||||
iter := pool.NewIterator()
|
||||
expected := []string{"key1", "key2", "key3"}
|
||||
for i, want := range expected {
|
||||
k, ok := iter.Next()
|
||||
if !ok {
|
||||
t.Fatalf("iter.Next() returned false at step %d", i)
|
||||
}
|
||||
if k != want {
|
||||
t.Errorf("step %d: expected %s, got %s", i, want, k)
|
||||
}
|
||||
}
|
||||
// Should be exhausted
|
||||
if _, ok := iter.Next(); ok {
|
||||
t.Errorf("expected iterator exhausted after all keys")
|
||||
}
|
||||
|
||||
// Second iterator starts at next position (load balancing)
|
||||
iter2 := pool.NewIterator()
|
||||
k, ok := iter2.Next()
|
||||
if !ok {
|
||||
t.Fatal("iter2.Next() returned false")
|
||||
}
|
||||
if k != "key2" {
|
||||
t.Errorf("expected key2 (round-robin), got %s", k)
|
||||
}
|
||||
|
||||
// Empty pool
|
||||
emptyPool := NewAPIKeyPool([]string{})
|
||||
emptyIter := emptyPool.NewIterator()
|
||||
if _, ok := emptyIter.Next(); ok {
|
||||
t.Errorf("expected false for empty pool")
|
||||
}
|
||||
|
||||
// Single key pool
|
||||
singlePool := NewAPIKeyPool([]string{"single"})
|
||||
singleIter := singlePool.NewIterator()
|
||||
if k, ok := singleIter.Next(); !ok || k != "single" {
|
||||
t.Errorf("expected single, got %s (ok=%v)", k, ok)
|
||||
}
|
||||
if _, ok := singleIter.Next(); ok {
|
||||
t.Errorf("expected exhausted after single key")
|
||||
}
|
||||
}
|
||||
|
||||
func TestWebTool_TavilySearch_Failover(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
var payload map[string]any
|
||||
if err := json.NewDecoder(r.Body).Decode(&payload); err != nil {
|
||||
t.Fatalf("failed to decode payload: %v", err)
|
||||
}
|
||||
|
||||
apiKey := payload["api_key"].(string)
|
||||
|
||||
if apiKey == "key1" {
|
||||
w.WriteHeader(http.StatusTooManyRequests)
|
||||
w.Write([]byte("Rate limited"))
|
||||
return
|
||||
}
|
||||
|
||||
if apiKey == "key2" {
|
||||
// Success
|
||||
response := map[string]any{
|
||||
"results": []map[string]any{
|
||||
{
|
||||
"title": "Success Result",
|
||||
"url": "https://example.com/success",
|
||||
"content": "Success content",
|
||||
},
|
||||
},
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
json.NewEncoder(w).Encode(response)
|
||||
return
|
||||
}
|
||||
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
tool, err := NewWebSearchTool(WebSearchToolOptions{
|
||||
TavilyEnabled: true,
|
||||
TavilyAPIKeys: []string{"key1", "key2"},
|
||||
TavilyBaseURL: server.URL,
|
||||
TavilyMaxResults: 5,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("NewWebSearchTool() error: %v", err)
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
args := map[string]any{
|
||||
"query": "test query",
|
||||
}
|
||||
|
||||
result := tool.Execute(ctx, args)
|
||||
|
||||
if result.IsError {
|
||||
t.Errorf("Expected success, got Error: %s", result.ForLLM)
|
||||
}
|
||||
if !strings.Contains(result.ForUser, "Success Result") {
|
||||
t.Errorf("Expected failover to second key and success result, got: %s", result.ForUser)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWebTool_GLMSearch_Success(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != "POST" {
|
||||
|
||||
@@ -0,0 +1,272 @@
|
||||
// Package utils provides shared, reusable algorithms.
|
||||
// This file implements a generic BM25 search engine.
|
||||
//
|
||||
// Usage:
|
||||
//
|
||||
// type MyDoc struct { ID string; Body string }
|
||||
//
|
||||
// corpus := []MyDoc{...}
|
||||
// engine := utils.NewBM25Engine(corpus, func(d MyDoc) string {
|
||||
// return d.ID + " " + d.Body
|
||||
// })
|
||||
// results := engine.Search("my query", 5)
|
||||
package utils
|
||||
|
||||
import (
|
||||
"math"
|
||||
"sort"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// ── Tuning defaults ───────────────────────────────────────────────────────────
|
||||
|
||||
// Default BM25 parameters. NewBM25Engine starts from these values and then
// applies any BM25Option passed to it (WithK1, WithB).
const (
|
||||
// DefaultBM25K1 is the term-frequency saturation factor (typical range 1.2–2.0).
|
||||
// Higher values give more weight to repeated terms.
|
||||
DefaultBM25K1 = 1.2
|
||||
|
||||
// DefaultBM25B is the document-length normalization factor (0 = none, 1 = full).
|
||||
DefaultBM25B = 0.75
|
||||
)
|
||||
|
||||
// BM25Engine ranks a corpus of documents of type T against free-text queries
// using BM25. The caller supplies the function that extracts the searchable
// text from a document.
//
// The engine keeps no index between queries: everything a query needs is
// rebuilt inside Search on each call, so there is no cache to invalidate when
// the corpus changes between calls.
type BM25Engine[T any] struct {
	corpus   []T            // documents to rank
	textFunc func(T) string // returns the searchable text of one document
	k1, b    float64        // k1: term-frequency saturation; b: length normalization
}
|
||||
|
||||
type (
	// BM25Option adjusts the tuning parameters of a BM25Engine while it is
	// being constructed (functional-option pattern).
	BM25Option func(*bm25Config)

	// bm25Config gathers the tuning parameters that a BM25Option may change.
	bm25Config struct {
		k1 float64 // term-frequency saturation factor
		b  float64 // document-length normalization factor
	}
)
|
||||
|
||||
// WithK1 overrides the term-frequency saturation constant (default 1.2).
|
||||
func WithK1(k1 float64) BM25Option {
|
||||
return func(c *bm25Config) { c.k1 = k1 }
|
||||
}
|
||||
|
||||
// WithB overrides the document-length normalization factor (default 0.75).
|
||||
func WithB(b float64) BM25Option {
|
||||
return func(c *bm25Config) { c.b = b }
|
||||
}
|
||||
|
||||
// NewBM25Engine creates a BM25Engine for the given corpus.
|
||||
//
|
||||
// - corpus : slice of documents of any type T.
|
||||
// - textFunc : function that returns the searchable text for a document.
|
||||
// - opts : optional tuning (WithK1, WithB).
|
||||
//
|
||||
// The corpus slice is referenced, not copied. Callers must not mutate it
|
||||
// concurrently with Search().
|
||||
func NewBM25Engine[T any](corpus []T, textFunc func(T) string, opts ...BM25Option) *BM25Engine[T] {
|
||||
cfg := bm25Config{k1: DefaultBM25K1, b: DefaultBM25B}
|
||||
for _, o := range opts {
|
||||
o(&cfg)
|
||||
}
|
||||
return &BM25Engine[T]{
|
||||
corpus: corpus,
|
||||
textFunc: textFunc,
|
||||
k1: cfg.k1,
|
||||
b: cfg.b,
|
||||
}
|
||||
}
|
||||
|
||||
// BM25Result pairs one corpus document with the relevance score it received
// from a Search call.
type BM25Result[T any] struct {
	Document T       // the matched document, taken from the corpus
	Score    float32 // BM25 score; Search orders results by descending Score
}
|
||||
|
||||
// Search ranks the corpus against query and returns the top-k results.
|
||||
// Returns an empty slice (not nil) when there are no matches.
|
||||
//
|
||||
// Complexity: O(N×L) for indexing + O(|Q|×avgPostingLen) for scoring,
|
||||
// where N = corpus size, L = average document length, Q = query terms.
|
||||
// Top-k extraction uses a fixed-size min-heap: O(candidates × log k).
|
||||
func (e *BM25Engine[T]) Search(query string, topK int) []BM25Result[T] {
|
||||
if topK <= 0 {
|
||||
return []BM25Result[T]{}
|
||||
}
|
||||
|
||||
queryTerms := bm25Tokenize(query)
|
||||
if len(queryTerms) == 0 {
|
||||
return []BM25Result[T]{}
|
||||
}
|
||||
|
||||
N := len(e.corpus)
|
||||
if N == 0 {
|
||||
return []BM25Result[T]{}
|
||||
}
|
||||
|
||||
// Step 1: build per-document tf + raw doc lengths
|
||||
type docEntry struct {
|
||||
tf map[string]uint32
|
||||
rawLen int
|
||||
}
|
||||
|
||||
entries := make([]docEntry, N)
|
||||
df := make(map[string]int, 64)
|
||||
totalLen := 0
|
||||
|
||||
for i, doc := range e.corpus {
|
||||
tokens := bm25Tokenize(e.textFunc(doc))
|
||||
totalLen += len(tokens)
|
||||
|
||||
tf := make(map[string]uint32, len(tokens))
|
||||
for _, t := range tokens {
|
||||
tf[t]++
|
||||
}
|
||||
// df: each term counts once per document (iterate the map, keys are unique)
|
||||
for t := range tf {
|
||||
df[t]++
|
||||
}
|
||||
|
||||
entries[i] = docEntry{tf: tf, rawLen: len(tokens)}
|
||||
}
|
||||
|
||||
avgDocLen := float64(totalLen) / float64(N)
|
||||
|
||||
// Step 2: pre-compute IDF and per-doc length normalization
|
||||
// IDF (Robertson smoothing): log( (N - df(t) + 0.5) / (df(t) + 0.5) + 1 )
|
||||
idf := make(map[string]float32, len(df))
|
||||
for term, freq := range df {
|
||||
idf[term] = float32(math.Log(
|
||||
(float64(N)-float64(freq)+0.5)/(float64(freq)+0.5) + 1,
|
||||
))
|
||||
}
|
||||
|
||||
// docLenNorm[i] = k1 * (1 - b + b * |doc_i| / avgDocLen)
|
||||
// Stored as float32 — sufficient precision for ranking.
|
||||
docLenNorm := make([]float32, N)
|
||||
for i, entry := range entries {
|
||||
docLenNorm[i] = float32(e.k1 * (1 - e.b + e.b*float64(entry.rawLen)/avgDocLen))
|
||||
}
|
||||
|
||||
// Step 3: build inverted index (posting lists)
|
||||
// Iterate the tf map directly — map keys are already unique, no seen-set needed.
|
||||
posting := make(map[string][]int32, len(df))
|
||||
for i, entry := range entries {
|
||||
for term := range entry.tf {
|
||||
posting[term] = append(posting[term], int32(i))
|
||||
}
|
||||
}
|
||||
|
||||
// Step 4: score via posting lists
|
||||
// Deduplicate query terms to avoid double-weighting the same term.
|
||||
unique := bm25Dedupe(queryTerms)
|
||||
|
||||
scores := make(map[int32]float32)
|
||||
for _, term := range unique {
|
||||
termIDF, ok := idf[term]
|
||||
if !ok {
|
||||
continue // term not in vocabulary → zero contribution
|
||||
}
|
||||
for _, docID := range posting[term] {
|
||||
freq := float32(entries[docID].tf[term])
|
||||
// TF_norm = freq * (k1+1) / (freq + docLenNorm)
|
||||
tfNorm := freq * float32(e.k1+1) / (freq + docLenNorm[docID])
|
||||
scores[docID] += termIDF * tfNorm
|
||||
}
|
||||
}
|
||||
|
||||
if len(scores) == 0 {
|
||||
return []BM25Result[T]{}
|
||||
}
|
||||
|
||||
// Step 5: top-K via fixed-size min-heap
|
||||
heap := make([]bm25ScoredDoc, 0, topK)
|
||||
|
||||
for docID, sc := range scores {
|
||||
switch {
|
||||
case len(heap) < topK:
|
||||
heap = append(heap, bm25ScoredDoc{docID: docID, score: sc})
|
||||
if len(heap) == topK {
|
||||
bm25MinHeapify(heap)
|
||||
}
|
||||
case sc > heap[0].score:
|
||||
heap[0] = bm25ScoredDoc{docID: docID, score: sc}
|
||||
bm25SiftDown(heap, 0)
|
||||
}
|
||||
}
|
||||
|
||||
sort.Slice(heap, func(i, j int) bool { return heap[i].score > heap[j].score })
|
||||
|
||||
out := make([]BM25Result[T], len(heap))
|
||||
for i, h := range heap {
|
||||
out[i] = BM25Result[T]{
|
||||
Document: e.corpus[h.docID],
|
||||
Score: h.score,
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// bm25Tokenize lower-cases s, splits it on whitespace and strips the
// punctuation characters . , ; : ! ? " ' ( ) / \ - _ from both ends of every
// field. Fields that are empty after stripping are dropped; punctuation
// inside a token (e.g. an internal hyphen) is kept.
func bm25Tokenize(s string) []string {
	const edgePunct = ".,;:!?\"'()/\\-_"

	// fields is freshly allocated by strings.Fields, so it is compacted in
	// place instead of allocating a second slice.
	fields := strings.Fields(strings.ToLower(s))
	kept := 0
	for _, field := range fields {
		if tok := strings.Trim(field, edgePunct); tok != "" {
			fields[kept] = tok
			kept++
		}
	}
	return fields[:kept]
}
|
||||
|
||||
// bm25Dedupe copies tokens into a new slice, keeping only the first
// occurrence of each token and preserving the original order. The input
// slice is not modified.
func bm25Dedupe(tokens []string) []string {
	unique := make([]string, 0, len(tokens))
	seen := make(map[string]bool, len(tokens))
	for _, tok := range tokens {
		if seen[tok] {
			continue
		}
		seen[tok] = true
		unique = append(unique, tok)
	}
	return unique
}
|
||||
|
||||
// bm25ScoredDoc is a heap entry used during top-K selection: a document,
// identified by its index in the corpus, together with its BM25 score.
type bm25ScoredDoc struct {
	docID int32   // index of the document in the corpus slice
	score float32 // accumulated BM25 score
}
|
||||
|
||||
// bm25MinHeapify builds a min-heap in-place using Floyd's algorithm: O(k).
|
||||
func bm25MinHeapify(h []bm25ScoredDoc) {
|
||||
for i := len(h)/2 - 1; i >= 0; i-- {
|
||||
bm25SiftDown(h, i)
|
||||
}
|
||||
}
|
||||
|
||||
// bm25SiftDown restores the min-heap property starting at node i: O(log k).
|
||||
func bm25SiftDown(h []bm25ScoredDoc, i int) {
|
||||
n := len(h)
|
||||
for {
|
||||
smallest := i
|
||||
l, r := 2*i+1, 2*i+2
|
||||
if l < n && h[l].score < h[smallest].score {
|
||||
smallest = l
|
||||
}
|
||||
if r < n && h[r].score < h[smallest].score {
|
||||
smallest = r
|
||||
}
|
||||
if smallest == i {
|
||||
break
|
||||
}
|
||||
h[i], h[smallest] = h[smallest], h[i]
|
||||
i = smallest
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,175 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// testDoc is the minimal document type used as the corpus element in the
// BM25 tests.
type testDoc struct {
	ID   int    // identifier the ranking tests assert on
	Text string // searchable text of the document
}
|
||||
|
||||
func extractText(d testDoc) string {
|
||||
return d.Text
|
||||
}
|
||||
|
||||
func TestBM25Search_EdgeCases(t *testing.T) {
|
||||
corpus := []testDoc{
|
||||
{1, "hello world"},
|
||||
{2, "foo bar"},
|
||||
}
|
||||
engine := NewBM25Engine(corpus, extractText)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
query string
|
||||
topK int
|
||||
}{
|
||||
{"Zero topK", "hello", 0},
|
||||
{"Negative topK", "hello", -1},
|
||||
{"Empty query", "", 5},
|
||||
{"Query with only punctuation", "...,,,!!!", 5},
|
||||
{"No matches found", "golang", 5},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
results := engine.Search(tt.query, tt.topK)
|
||||
if len(results) != 0 {
|
||||
t.Errorf("expected 0 results, got %d", len(results))
|
||||
}
|
||||
// Check that it never returns nil, but an empty slice
|
||||
if results == nil {
|
||||
t.Errorf("expected empty slice, got nil")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestBM25Search_EmptyCorpus(t *testing.T) {
|
||||
engine := NewBM25Engine([]testDoc{}, extractText)
|
||||
results := engine.Search("hello", 5)
|
||||
if len(results) != 0 || results == nil {
|
||||
t.Errorf("expected empty slice from empty corpus, got %v", results)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBM25Search_RankingLogic(t *testing.T) {
|
||||
corpus := []testDoc{
|
||||
{1, "the quick brown fox jumps over the lazy dog"},
|
||||
{2, "quick fox"},
|
||||
{3, "quick quick quick fox"}, // High Term Frequency (TF)
|
||||
{4, "completely irrelevant document here"},
|
||||
}
|
||||
engine := NewBM25Engine(corpus, extractText)
|
||||
|
||||
t.Run("Term Frequency (TF) boosts score", func(t *testing.T) {
|
||||
results := engine.Search("quick", 5)
|
||||
if len(results) < 3 {
|
||||
t.Fatalf("expected at least 3 results, got %d", len(results))
|
||||
}
|
||||
// Doc 3 has the word "quick" repeated 3 times, it should beat Doc 2
|
||||
if results[0].Document.ID != 3 {
|
||||
t.Errorf("expected doc 3 to rank first due to high TF, got doc %d", results[0].Document.ID)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Document Length penalty", func(t *testing.T) {
|
||||
results := engine.Search("fox", 5)
|
||||
if len(results) < 3 {
|
||||
t.Fatalf("expected at least 3 results, got %d", len(results))
|
||||
}
|
||||
// Doc 2 ("quick fox") is much shorter than Doc 1 ("the quick brown fox..."),
|
||||
// so, with equal Term Frequency for the word "fox" (1 time), Doc 2 wins.
|
||||
if results[0].Document.ID != 2 {
|
||||
t.Errorf("expected doc 2 to rank first due to shorter length, got doc %d", results[0].Document.ID)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("TopK limits results", func(t *testing.T) {
|
||||
results := engine.Search("quick", 2)
|
||||
if len(results) != 2 {
|
||||
t.Errorf("expected exactly 2 results, got %d", len(results))
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestBM25Tokenize(t *testing.T) {
|
||||
tests := []struct {
|
||||
input string
|
||||
expected []string
|
||||
}{
|
||||
{"Hello World", []string{"hello", "world"}},
|
||||
{" spaces everywhere ", []string{"spaces", "everywhere"}},
|
||||
{"punctuation... test!!!", []string{"punctuation", "test"}},
|
||||
{"(parentheses) and-hyphens", []string{"parentheses", "and-hyphens"}}, // hyphens trimmed from edges
|
||||
{"internal-hyphen is kept", []string{"internal-hyphen", "is", "kept"}},
|
||||
{".,;?!", []string{}}, // Becomes empty after trim
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.input, func(t *testing.T) {
|
||||
got := bm25Tokenize(tt.input)
|
||||
if len(got) == 0 && len(tt.expected) == 0 {
|
||||
return // Both empty
|
||||
}
|
||||
if !reflect.DeepEqual(got, tt.expected) {
|
||||
t.Errorf("bm25Tokenize(%q) = %v, want %v", tt.input, got, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestBM25Dedupe(t *testing.T) {
|
||||
input := []string{"apple", "banana", "apple", "orange", "banana"}
|
||||
expected := []string{"apple", "banana", "orange"}
|
||||
|
||||
got := bm25Dedupe(input)
|
||||
if !reflect.DeepEqual(got, expected) {
|
||||
t.Errorf("bm25Dedupe() = %v, want %v", got, expected)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBM25Options(t *testing.T) {
|
||||
corpus := []testDoc{{1, "test"}}
|
||||
|
||||
engine := NewBM25Engine(
|
||||
corpus,
|
||||
extractText,
|
||||
WithK1(2.5),
|
||||
WithB(0.9),
|
||||
)
|
||||
|
||||
if engine.k1 != 2.5 {
|
||||
t.Errorf("expected k1 to be 2.5, got %v", engine.k1)
|
||||
}
|
||||
if engine.b != 0.9 {
|
||||
t.Errorf("expected b to be 0.9, got %v", engine.b)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBM25Search_SortingStability(t *testing.T) {
|
||||
// Ensure that sorting by heap returns in correct descending order
|
||||
corpus := []testDoc{
|
||||
{1, "golang is good"},
|
||||
{2, "golang golang"},
|
||||
{3, "golang golang golang"},
|
||||
{4, "golang golang golang golang"},
|
||||
}
|
||||
engine := NewBM25Engine(corpus, extractText)
|
||||
results := engine.Search("golang", 10)
|
||||
|
||||
if len(results) != 4 {
|
||||
t.Fatalf("expected 4 results, got %d", len(results))
|
||||
}
|
||||
|
||||
// Score should be strictly decreasing
|
||||
for i := 1; i < len(results); i++ {
|
||||
if results[i].Score > results[i-1].Score {
|
||||
t.Errorf("results not sorted correctly: result %d score (%v) > result %d score (%v)",
|
||||
i, results[i].Score, i-1, results[i-1].Score)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -2,9 +2,18 @@ package utils
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"unicode"
|
||||
)
|
||||
|
||||
// disableTruncation is a package-wide flag. While it is set, Truncate returns
// its input unchanged instead of shortening it.
var disableTruncation atomic.Bool

// SetDisableTruncation sets the global no-truncation flag. Passing true turns
// truncation off (Truncate returns strings in full); passing false restores
// normal truncation.
func SetDisableTruncation(enabled bool) {
	disableTruncation.Store(enabled)
}
|
||||
|
||||
// SanitizeMessageContent removes Unicode control characters, format characters (RTL overrides,
|
||||
// zero-width characters), and other non-graphic characters that could confuse an LLM
|
||||
// or cause display issues in the agent UI.
|
||||
@@ -30,6 +39,10 @@ func SanitizeMessageContent(input string) string {
|
||||
// Handles multi-byte Unicode characters properly.
|
||||
// If the string is truncated, "..." is appended to indicate truncation.
|
||||
func Truncate(s string, maxLen int) string {
|
||||
// If the no-truncate flag is active, it returns the full string
|
||||
if disableTruncation.Load() {
|
||||
return s
|
||||
}
|
||||
if maxLen <= 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user