Merge remote-tracking branch 'origin/main' into feat/echo-voice-audio-transcription

# Conflicts:
#	pkg/channels/telegram/telegram.go
#	pkg/config/config.go
#	pkg/config/defaults.go
This commit is contained in:
afjcjsbx
2026-03-11 00:06:37 +01:00
238 changed files with 30227 additions and 4832 deletions
+41 -10
View File
@@ -12,15 +12,19 @@ import (
"sync"
"time"
"github.com/sipeed/picoclaw/pkg/config"
"github.com/sipeed/picoclaw/pkg/logger"
"github.com/sipeed/picoclaw/pkg/providers"
"github.com/sipeed/picoclaw/pkg/skills"
"github.com/sipeed/picoclaw/pkg/utils"
)
type ContextBuilder struct {
workspace string
skillsLoader *skills.SkillsLoader
memory *MemoryStore
workspace string
skillsLoader *skills.SkillsLoader
memory *MemoryStore
toolDiscoveryBM25 bool
toolDiscoveryRegex bool
// Cache for system prompt to avoid rebuilding on every call.
// This fixes issue #607: repeated reprocessing of the entire context.
@@ -41,6 +45,12 @@ type ContextBuilder struct {
skillFilesAtCache map[string]time.Time
}
// WithToolDiscovery records which tool-discovery search tools (BM25 and/or
// regex) are active, so that the identity prompt can include the matching
// discovery rule. It mutates the receiver and returns it, which lets the call
// be chained directly onto NewContextBuilder.
func (cb *ContextBuilder) WithToolDiscovery(useBM25, useRegex bool) *ContextBuilder {
	cb.toolDiscoveryBM25, cb.toolDiscoveryRegex = useBM25, useRegex
	return cb
}
func getGlobalConfigDir() string {
if home := os.Getenv("PICOCLAW_HOME"); home != "" {
return home
@@ -71,8 +81,11 @@ func NewContextBuilder(workspace string) *ContextBuilder {
func (cb *ContextBuilder) getIdentity() string {
workspacePath, _ := filepath.Abs(filepath.Join(cb.workspace))
toolDiscovery := cb.getDiscoveryRule()
version := config.FormatVersion()
return fmt.Sprintf(`# picoclaw 🦞
return fmt.Sprintf(
`# picoclaw 🦞 (%s)
You are picoclaw, a helpful AI assistant.
@@ -90,8 +103,29 @@ Your workspace is at: %s
3. **Memory** - When interacting with me if something seems memorable, update %s/memory/MEMORY.md
4. **Context summaries** - Conversation summaries provided as context are approximate references only. They may be incomplete or outdated. Always defer to explicit user instructions over summary content.`,
workspacePath, workspacePath, workspacePath, workspacePath, workspacePath)
4. **Context summaries** - Conversation summaries provided as context are approximate references only. They may be incomplete or outdated. Always defer to explicit user instructions over summary content.
%s`,
version, workspacePath, workspacePath, workspacePath, workspacePath, workspacePath, toolDiscovery)
}
// getDiscoveryRule builds the "Tool Discovery" rule that getIdentity appends
// to the identity prompt. It returns an empty string when neither search tool
// is enabled; otherwise the rule names every enabled search tool, joined by
// " or ".
func (cb *ContextBuilder) getDiscoveryRule() string {
	useBM25, useRegex := cb.toolDiscoveryBM25, cb.toolDiscoveryRegex
	if !(useBM25 || useRegex) {
		return ""
	}

	const ruleFormat = `5. **Tool Discovery** - Your visible tools are limited to save memory, but a vast hidden library exists. If you lack the right tool for a task, BEFORE giving up, you MUST search using the %s tool. Do not refuse a request unless the search returns nothing. Found tools will temporarily unlock for your next turn.`

	// At most two search tools exist, so size the slice up front.
	searchTools := make([]string, 0, 2)
	if useBM25 {
		searchTools = append(searchTools, `"tool_search_tool_bm25"`)
	}
	if useRegex {
		searchTools = append(searchTools, `"tool_search_tool_regex"`)
	}

	return fmt.Sprintf(ruleFormat, strings.Join(searchTools, " or "))
}
func (cb *ContextBuilder) BuildSystemPrompt() string {
@@ -505,10 +539,7 @@ func (cb *ContextBuilder) BuildMessages(
})
// Log preview of system prompt (avoid logging huge content)
preview := fullSystemPrompt
if len(preview) > 500 {
preview = preview[:500] + "... (truncated)"
}
preview := utils.Truncate(fullSystemPrompt, 500)
logger.DebugCF("agent", "System prompt preview",
map[string]any{
"preview": preview,
+45 -5
View File
@@ -1,6 +1,7 @@
package agent
import (
"context"
"fmt"
"log"
"os"
@@ -9,6 +10,7 @@ import (
"strings"
"github.com/sipeed/picoclaw/pkg/config"
"github.com/sipeed/picoclaw/pkg/memory"
"github.com/sipeed/picoclaw/pkg/providers"
"github.com/sipeed/picoclaw/pkg/routing"
"github.com/sipeed/picoclaw/pkg/session"
@@ -31,7 +33,7 @@ type AgentInstance struct {
SummarizeMessageThreshold int
SummarizeTokenPercent int
Provider providers.LLMProvider
Sessions *session.SessionManager
Sessions session.SessionStore
ContextBuilder *ContextBuilder
Tools *tools.ToolRegistry
Subagents *config.SubagentsConfig
@@ -70,7 +72,8 @@ func NewAgentInstance(
toolsRegistry := tools.NewToolRegistry()
if cfg.Tools.IsToolEnabled("read_file") {
toolsRegistry.Register(tools.NewReadFileTool(workspace, readRestrict, allowReadPaths))
maxReadFileSize := cfg.Tools.ReadFile.MaxReadFileSize
toolsRegistry.Register(tools.NewReadFileTool(workspace, readRestrict, maxReadFileSize, allowReadPaths))
}
if cfg.Tools.IsToolEnabled("write_file") {
toolsRegistry.Register(tools.NewWriteFileTool(workspace, restrict, allowWritePaths))
@@ -94,9 +97,13 @@ func NewAgentInstance(
}
sessionsDir := filepath.Join(workspace, "sessions")
sessionsManager := session.NewSessionManager(sessionsDir)
sessions := initSessionStore(sessionsDir)
contextBuilder := NewContextBuilder(workspace)
mcpDiscoveryActive := cfg.Tools.MCP.Enabled && cfg.Tools.MCP.Discovery.Enabled
contextBuilder := NewContextBuilder(workspace).WithToolDiscovery(
mcpDiscoveryActive && cfg.Tools.MCP.Discovery.UseBM25,
mcpDiscoveryActive && cfg.Tools.MCP.Discovery.UseRegex,
)
agentID := routing.DefaultAgentID
agentName := ""
@@ -221,7 +228,7 @@ func NewAgentInstance(
SummarizeMessageThreshold: summarizeMessageThreshold,
SummarizeTokenPercent: summarizeTokenPercent,
Provider: provider,
Sessions: sessionsManager,
Sessions: sessions,
ContextBuilder: contextBuilder,
Tools: toolsRegistry,
Subagents: subagents,
@@ -275,6 +282,39 @@ func compilePatterns(patterns []string) []*regexp.Regexp {
return compiled
}
// Close shuts down the agent's session store, if one is set, and returns
// whatever error the store's Close reports. An agent without a session store
// closes successfully.
func (a *AgentInstance) Close() error {
	if a.Sessions == nil {
		return nil
	}
	return a.Sessions.Close()
}
// initSessionStore builds the session persistence backend rooted at dir.
//
// The JSONL store is preferred: it is opened first and any legacy JSON
// sessions found in dir are migrated into it. If the JSONL store cannot be
// opened, or if the migration fails, the function logs the reason and returns
// a SessionManager over the same directory instead.
func initSessionStore(dir string) session.SessionStore {
	store, err := memory.NewJSONLStore(dir)
	if err != nil {
		log.Printf("memory: init store: %v; using json sessions", err)
		return session.NewSessionManager(dir)
	}

	migrated, merr := memory.MigrateFromJSON(context.Background(), dir, store)
	switch {
	case merr != nil:
		// A failed migration means the store could not write data. Falling
		// back to SessionManager avoids a split state in which some sessions
		// live in JSONL while others remain in JSON.
		log.Printf("memory: migration failed: %v; falling back to json sessions", merr)
		store.Close()
		return session.NewSessionManager(dir)
	case migrated > 0:
		log.Printf("memory: migrated %d session(s) to jsonl", migrated)
	}

	return session.NewJSONLBackend(store)
}
func expandHome(path string) string {
if path == "" {
return path
+213 -45
View File
@@ -120,19 +120,21 @@ func registerSharedTools(
continue
}
// Web tools
if cfg.Tools.IsToolEnabled("web") {
searchTool, err := tools.NewWebSearchTool(tools.WebSearchToolOptions{
BraveAPIKey: cfg.Tools.Web.Brave.APIKey,
BraveAPIKeys: config.MergeAPIKeys(cfg.Tools.Web.Brave.APIKey, cfg.Tools.Web.Brave.APIKeys),
BraveMaxResults: cfg.Tools.Web.Brave.MaxResults,
BraveEnabled: cfg.Tools.Web.Brave.Enabled,
TavilyAPIKey: cfg.Tools.Web.Tavily.APIKey,
TavilyAPIKeys: config.MergeAPIKeys(cfg.Tools.Web.Tavily.APIKey, cfg.Tools.Web.Tavily.APIKeys),
TavilyBaseURL: cfg.Tools.Web.Tavily.BaseURL,
TavilyMaxResults: cfg.Tools.Web.Tavily.MaxResults,
TavilyEnabled: cfg.Tools.Web.Tavily.Enabled,
DuckDuckGoMaxResults: cfg.Tools.Web.DuckDuckGo.MaxResults,
DuckDuckGoEnabled: cfg.Tools.Web.DuckDuckGo.Enabled,
PerplexityAPIKey: cfg.Tools.Web.Perplexity.APIKey,
PerplexityAPIKeys: config.MergeAPIKeys(
cfg.Tools.Web.Perplexity.APIKey,
cfg.Tools.Web.Perplexity.APIKeys,
),
PerplexityMaxResults: cfg.Tools.Web.Perplexity.MaxResults,
PerplexityEnabled: cfg.Tools.Web.Perplexity.Enabled,
SearXNGBaseURL: cfg.Tools.Web.SearXNG.BaseURL,
@@ -283,7 +285,13 @@ func (al *AgentLoop) Run(ctx context.Context) error {
}
mcpTool := tools.NewMCPTool(mcpManager, serverName, tool)
agent.Tools.Register(mcpTool)
if al.cfg.Tools.MCP.Discovery.Enabled {
agent.Tools.RegisterHidden(mcpTool)
} else {
agent.Tools.Register(mcpTool)
}
totalRegistrations++
logger.DebugCF("agent", "Registered MCP tool",
map[string]any{
@@ -302,6 +310,47 @@ func (al *AgentLoop) Run(ctx context.Context) error {
"total_registrations": totalRegistrations,
"agent_count": agentCount,
})
// Initializes Discovery Tools only if enabled by configuration
if al.cfg.Tools.MCP.Enabled && al.cfg.Tools.MCP.Discovery.Enabled {
useBM25 := al.cfg.Tools.MCP.Discovery.UseBM25
useRegex := al.cfg.Tools.MCP.Discovery.UseRegex
// Fail fast: If discovery is enabled but no search method is turned on
if !useBM25 && !useRegex {
return fmt.Errorf(
"tool discovery is enabled but neither 'use_bm25' nor 'use_regex' is set to true in the configuration",
)
}
ttl := al.cfg.Tools.MCP.Discovery.TTL
if ttl <= 0 {
ttl = 5 // Default value
}
maxSearchResults := al.cfg.Tools.MCP.Discovery.MaxSearchResults
if maxSearchResults <= 0 {
maxSearchResults = 5 // Default value
}
logger.InfoCF("agent", "Initializing tool discovery", map[string]any{
"bm25": useBM25, "regex": useRegex, "ttl": ttl, "max_results": maxSearchResults,
})
for _, agentID := range agentIDs {
agent, ok := al.registry.GetAgent(agentID)
if !ok {
continue
}
if useRegex {
agent.Tools.Register(tools.NewRegexSearchTool(agent.Tools, ttl, maxSearchResults))
}
if useBM25 {
agent.Tools.Register(tools.NewBM25SearchTool(agent.Tools, ttl, maxSearchResults))
}
}
}
}
}
@@ -380,6 +429,11 @@ func (al *AgentLoop) Stop() {
al.running.Store(false)
}
// Close releases resources held by agent session stores by delegating to the
// agent registry's Close method. It returns nothing, so any close errors are
// not surfaced to the caller here. Call after Stop.
func (al *AgentLoop) Close() {
al.registry.Close()
}
func (al *AgentLoop) RegisterTool(tool tools.Tool) {
for _, agentID := range al.registry.ListAgentIDs() {
if agent, ok := al.registry.GetAgent(agentID); ok {
@@ -632,15 +686,6 @@ func (al *AgentLoop) processMessage(ctx context.Context, msg bus.InboundMessage)
}
route, agent, routeErr := al.resolveMessageRoute(msg)
// Commands are checked before requiring a successful route.
// Global commands (/help, /show, /switch) work even when routing fails;
// context-dependent commands check their own Runtime fields and report
// "unavailable" when the required capability is nil.
if response, handled := al.handleCommand(ctx, msg, agent); handled {
return response, nil
}
if routeErr != nil {
return "", routeErr
}
@@ -666,7 +711,7 @@ func (al *AgentLoop) processMessage(ctx context.Context, msg bus.InboundMessage)
"route_channel": route.Channel,
})
return al.runAgentLoop(ctx, agent, processOptions{
opts := processOptions{
SessionKey: sessionKey,
Channel: msg.Channel,
ChatID: msg.ChatID,
@@ -675,7 +720,15 @@ func (al *AgentLoop) processMessage(ctx context.Context, msg bus.InboundMessage)
DefaultResponse: defaultResponse,
EnableSummary: true,
SendResponse: false,
})
}
// context-dependent commands check their own Runtime fields and report
// "unavailable" when the required capability is nil.
if response, handled := al.handleCommand(ctx, msg, agent, &opts); handled {
return response, nil
}
return al.runAgentLoop(ctx, agent, opts)
}
func (al *AgentLoop) resolveMessageRoute(msg bus.InboundMessage) (routing.ResolvedRoute, *AgentInstance, error) {
@@ -1306,6 +1359,17 @@ func (al *AgentLoop) runLLMIteration(
// Save tool result message to session
agent.Sessions.AddFullMessage(opts.SessionKey, toolResultMsg)
}
// Tick down TTL of discovered tools after processing tool results.
// Only reached when tool calls were made (the loop continues);
// the break on no-tool-call responses skips this.
// NOTE: This is safe because processMessage is sequential per agent.
// If per-agent concurrency is added, TTL consistency between
// ToProviderDefs and Get must be re-evaluated.
agent.Tools.TickTTL()
logger.DebugCF("agent", "TTL tick after tool execution", map[string]any{
"agent_id": agent.ID, "iteration": iteration,
})
}
return finalContent, iteration, nil
@@ -1543,10 +1607,20 @@ func (al *AgentLoop) summarizeSession(agent *AgentInstance, sessionKey string) {
return
}
const (
maxSummarizationMessages = 10
llmMaxRetries = 3
llmTemperature = 0.3
fallbackMaxContentLength = 200
)
// Multi-Part Summarization
var finalSummary string
if len(validMessages) > 10 {
if len(validMessages) > maxSummarizationMessages {
mid := len(validMessages) / 2
mid = al.findNearestUserMessage(validMessages, mid)
part1 := validMessages[:mid]
part2 := validMessages[mid:]
@@ -1558,18 +1632,9 @@ func (al *AgentLoop) summarizeSession(agent *AgentInstance, sessionKey string) {
s1,
s2,
)
resp, err := agent.Provider.Chat(
ctx,
[]providers.Message{{Role: "user", Content: mergePrompt}},
nil,
agent.Model,
map[string]any{
"max_tokens": 1024,
"temperature": 0.3,
"prompt_cache_key": agent.ID,
},
)
if err == nil {
resp, err := al.retryLLMCall(ctx, agent, mergePrompt, llmMaxRetries)
if err == nil && resp.Content != "" {
finalSummary = resp.Content
} else {
finalSummary = s1 + " " + s2
@@ -1589,6 +1654,68 @@ func (al *AgentLoop) summarizeSession(agent *AgentInstance, sessionKey string) {
}
}
// findNearestUserMessage returns the index of the message with Role "user"
// that is nearest to mid, preferring earlier messages.
//
// The search walks backward from mid down to index 0 first; only if no user
// message is found there does it walk forward from mid to the end. When the
// slice holds no user message at all, mid is returned unchanged. The result
// may be 0 when the first message is the only user message at or before mid.
//
// An out-of-range mid (including any mid for an empty slice) is returned
// unchanged instead of indexing the slice, so the function never panics.
func (al *AgentLoop) findNearestUserMessage(messages []providers.Message, mid int) int {
	if mid < 0 || mid >= len(messages) {
		return mid
	}

	// Backward pass, inclusive of mid itself and of index 0.
	for i := mid; i >= 0; i-- {
		if messages[i].Role == "user" {
			return i
		}
	}

	// Forward pass; mid was already rejected by the backward pass.
	for i := mid + 1; i < len(messages); i++ {
		if messages[i].Role == "user" {
			return i
		}
	}

	return mid
}
// retryLLMCall sends prompt to the agent's provider as a single user message
// and retries up to maxRetries times until a response with non-empty content
// is returned.
//
// Between attempts it waits (attempt+1) * 100ms. The wait is abandoned as
// soon as ctx is cancelled, in which case ctx.Err() is returned.
//
// Return contract:
//   - (resp, nil) with non-empty resp.Content on success;
//   - (resp, nil) with empty resp.Content when every attempt returned an
//     empty but non-nil response (callers already check for empty content);
//   - (resp, err) when the last attempt failed with a provider error;
//   - (nil, err) when no response was obtained at all, including the case
//     maxRetries <= 0. The function never returns (nil, nil), so callers may
//     dereference the response whenever err is nil.
func (al *AgentLoop) retryLLMCall(
	ctx context.Context,
	agent *AgentInstance,
	prompt string,
	maxRetries int,
) (*providers.LLMResponse, error) {
	const (
		llmTemperature   = 0.3
		retryBackoffStep = 100 * time.Millisecond
	)

	var (
		resp *providers.LLMResponse
		err  error
	)

	for attempt := 0; attempt < maxRetries; attempt++ {
		resp, err = agent.Provider.Chat(
			ctx,
			[]providers.Message{{Role: "user", Content: prompt}},
			nil,
			agent.Model,
			map[string]any{
				"max_tokens":       agent.MaxTokens,
				"temperature":      llmTemperature,
				"prompt_cache_key": agent.ID,
			},
		)
		if err == nil && resp != nil && resp.Content != "" {
			return resp, nil
		}

		// No backoff after the final attempt.
		if attempt == maxRetries-1 {
			break
		}

		// Linear backoff that stops early when the caller cancels.
		timer := time.NewTimer(time.Duration(attempt+1) * retryBackoffStep)
		select {
		case <-ctx.Done():
			timer.Stop()
			return nil, ctx.Err()
		case <-timer.C:
		}
	}

	// Guard against handing callers a nil response together with a nil error.
	if err == nil && resp == nil {
		return nil, fmt.Errorf("llm call returned no response after %d attempt(s)", maxRetries)
	}
	return resp, err
}
// summarizeBatch summarizes a batch of messages.
func (al *AgentLoop) summarizeBatch(
ctx context.Context,
@@ -1596,6 +1723,13 @@ func (al *AgentLoop) summarizeBatch(
batch []providers.Message,
existingSummary string,
) (string, error) {
const (
llmMaxRetries = 3
llmTemperature = 0.3
fallbackMinContentLength = 200
fallbackMaxContentPercent = 10
)
var sb strings.Builder
sb.WriteString(
"Provide a concise summary of this conversation segment, preserving core context and key points.\n",
@@ -1611,21 +1745,40 @@ func (al *AgentLoop) summarizeBatch(
}
prompt := sb.String()
response, err := agent.Provider.Chat(
ctx,
[]providers.Message{{Role: "user", Content: prompt}},
nil,
agent.Model,
map[string]any{
"max_tokens": 1024,
"temperature": 0.3,
"prompt_cache_key": agent.ID,
},
)
if err != nil {
return "", err
response, err := al.retryLLMCall(ctx, agent, prompt, llmMaxRetries)
if err == nil && response.Content != "" {
return strings.TrimSpace(response.Content), nil
}
return response.Content, nil
var fallback strings.Builder
fallback.WriteString("Conversation summary: ")
for i, m := range batch {
if i > 0 {
fallback.WriteString(" | ")
}
content := strings.TrimSpace(m.Content)
runes := []rune(content)
if len(runes) == 0 {
fallback.WriteString(fmt.Sprintf("%s: ", m.Role))
continue
}
keepLength := len(runes) * fallbackMaxContentPercent / 100
if keepLength < fallbackMinContentLength {
keepLength = fallbackMinContentLength
}
if keepLength > len(runes) {
keepLength = len(runes)
}
content = string(runes[:keepLength])
if keepLength < len(runes) {
content += "..."
}
fallback.WriteString(fmt.Sprintf("%s: %s", m.Role, content))
}
return fallback.String(), nil
}
// estimateTokens estimates the number of tokens in a message list.
@@ -1644,6 +1797,7 @@ func (al *AgentLoop) handleCommand(
ctx context.Context,
msg bus.InboundMessage,
agent *AgentInstance,
opts *processOptions,
) (string, bool) {
if !commands.HasCommandPrefix(msg.Content) {
return "", false
@@ -1653,7 +1807,7 @@ func (al *AgentLoop) handleCommand(
return "", false
}
rt := al.buildCommandsRuntime(agent)
rt := al.buildCommandsRuntime(agent, opts)
executor := commands.NewExecutor(al.cmdRegistry, rt)
var commandReply string
@@ -1682,7 +1836,7 @@ func (al *AgentLoop) handleCommand(
}
}
func (al *AgentLoop) buildCommandsRuntime(agent *AgentInstance) *commands.Runtime {
func (al *AgentLoop) buildCommandsRuntime(agent *AgentInstance, opts *processOptions) *commands.Runtime {
rt := &commands.Runtime{
Config: al.cfg,
ListAgentIDs: al.registry.ListAgentIDs,
@@ -1712,6 +1866,20 @@ func (al *AgentLoop) buildCommandsRuntime(agent *AgentInstance) *commands.Runtim
agent.Model = value
return oldModel, nil
}
rt.ClearHistory = func() error {
if opts == nil {
return fmt.Errorf("process options not available")
}
if agent.Sessions == nil {
return fmt.Errorf("sessions not initialized for agent")
}
agent.Sessions.SetHistory(opts.SessionKey, make([]providers.Message, 0))
agent.Sessions.SetSummary(opts.SessionKey, "")
agent.Sessions.Save(opts.SessionKey)
return nil
}
}
return rt
}
+12
View File
@@ -114,6 +114,18 @@ func (r *AgentRegistry) ForEachTool(name string, fn func(tools.Tool)) {
}
}
// Close calls Close on every registered agent while holding the registry's
// read lock. A failure to close one agent is logged as a warning and does not
// stop the remaining agents from being closed.
func (r *AgentRegistry) Close() {
	r.mu.RLock()
	defer r.mu.RUnlock()

	for _, inst := range r.agents {
		err := inst.Close()
		if err == nil {
			continue
		}
		logger.WarnCF("agent", "Failed to close agent", map[string]any{
			"agent_id": inst.ID,
			"error":    err.Error(),
		})
	}
}
// GetDefaultAgent returns the default agent instance.
func (r *AgentRegistry) GetDefaultAgent() *AgentInstance {
r.mu.RLock()
+16 -2
View File
@@ -7,6 +7,7 @@ import (
"encoding/json"
"fmt"
"io"
"math/rand"
"net/http"
"os"
"path/filepath"
@@ -195,18 +196,30 @@ func (c *FeishuChannel) SendPlaceholder(ctx context.Context, chatID string) (str
}
// ReactToMessage implements channels.ReactionCapable.
// Adds an "Pin" reaction and returns an undo function to remove it.
// Adds a reaction (randomly chosen from config) and returns an undo function to remove it.
func (c *FeishuChannel) ReactToMessage(ctx context.Context, chatID, messageID string) (func(), error) {
// Get emoji list from config
emojiList := c.config.RandomReactionEmoji
var chosenEmoji string
if len(emojiList) == 0 {
// Default to "Pin" if no config
chosenEmoji = "Pin"
} else {
idx := rand.Intn(len(emojiList))
chosenEmoji = emojiList[idx]
}
req := larkim.NewCreateMessageReactionReqBuilder().
MessageId(messageID).
Body(larkim.NewCreateMessageReactionReqBodyBuilder().
ReactionType(larkim.NewEmojiBuilder().EmojiType("Pin").Build()).
ReactionType(larkim.NewEmojiBuilder().EmojiType(chosenEmoji).Build()).
Build()).
Build()
resp, err := c.client.Im.V1.MessageReaction.Create(ctx, req)
if err != nil {
logger.ErrorCF("feishu", "Failed to add reaction", map[string]any{
"emoji": chosenEmoji,
"message_id": messageID,
"error": err.Error(),
})
@@ -214,6 +227,7 @@ func (c *FeishuChannel) ReactToMessage(ctx context.Context, chatID, messageID st
}
if !resp.Success() {
logger.ErrorCF("feishu", "Reaction API error", map[string]any{
"emoji": chosenEmoji,
"message_id": messageID,
"code": resp.Code,
"msg": resp.Msg,
+9
View File
@@ -61,7 +61,9 @@ var channelRateConfig = map[string]float64{
"telegram": 20,
"discord": 1,
"slack": 1,
"matrix": 2,
"line": 10,
"qq": 5,
"irc": 2,
}
@@ -265,6 +267,13 @@ func (m *Manager) initChannels() error {
m.initChannel("slack", "Slack")
}
if m.config.Channels.Matrix.Enabled &&
m.config.Channels.Matrix.Homeserver != "" &&
m.config.Channels.Matrix.UserID != "" &&
m.config.Channels.Matrix.AccessToken != "" {
m.initChannel("matrix", "Matrix")
}
if m.config.Channels.LINE.Enabled && m.config.Channels.LINE.ChannelAccessToken != "" {
m.initChannel("line", "LINE")
}
+13
View File
@@ -0,0 +1,13 @@
package matrix
import (
"github.com/sipeed/picoclaw/pkg/bus"
"github.com/sipeed/picoclaw/pkg/channels"
"github.com/sipeed/picoclaw/pkg/config"
)
func init() {
channels.RegisterFactory("matrix", func(cfg *config.Config, b *bus.MessageBus) (channels.Channel, error) {
return NewMatrixChannel(cfg.Channels.Matrix, b)
})
}
File diff suppressed because it is too large Load Diff
+291
View File
@@ -0,0 +1,291 @@
package matrix
import (
"context"
"os"
"path/filepath"
"testing"
"time"
"maunium.net/go/mautrix"
"maunium.net/go/mautrix/event"
"maunium.net/go/mautrix/id"
)
// TestMatrixLocalpartMentionRegexp checks that the regexp built by
// localpartMentionRegexp matches "@localpart" mentions and rejects text in
// which the localpart only appears as a plain substring or the "@" belongs to
// an unrelated address.
func TestMatrixLocalpartMentionRegexp(t *testing.T) {
	re := localpartMentionRegexp("picoclaw")

	cases := []struct {
		text string
		want bool
	}{
		{text: "@picoclaw hello", want: true},
		{text: "hi @picoclaw:matrix.org", want: true},
		{
			text: "\u6b22\u8fce\u4e00\u4e0bpicoclaw\u5c0f\u9f99\u867e",
			want: false, // historical false-positive case in PR #356
		},
		{text: "mail test@example.com", want: false},
	}

	for _, tc := range cases {
		// Errorf, not Fatalf: one failing input must not hide the others.
		if got := re.MatchString(tc.text); got != tc.want {
			t.Errorf("text=%q match=%v want=%v", tc.text, got, tc.want)
		}
	}
}
// TestStripUserMention checks that stripUserMention removes a leading mention
// of the bot (full user ID or "@localpart," form) and leaves text without a
// mention untouched.
func TestStripUserMention(t *testing.T) {
	userID := id.UserID("@picoclaw:matrix.org")

	cases := []struct {
		in   string
		want string
	}{
		{in: "@picoclaw:matrix.org hello", want: "hello"},
		{in: "@picoclaw, hello", want: "hello"},
		{in: "no mention here", want: "no mention here"},
	}

	for _, tc := range cases {
		// Errorf, not Fatalf: report every failing case in a single run.
		if got := stripUserMention(tc.in, userID); got != tc.want {
			t.Errorf("stripUserMention(%q)=%q want=%q", tc.in, got, tc.want)
		}
	}
}
// TestIsBotMentioned exercises MatrixChannel.isBotMentioned against the ways
// a message can reference the bot: the structured Mentions field, the full
// user ID or "@localpart" in the body, and matrix.to links (plain and
// percent-encoded) in the formatted body. A bare localpart without "@" must
// not count as a mention.
//
// Each case runs as its own subtest so a single failure neither hides the
// remaining cases nor prevents selecting one case with -run.
func TestIsBotMentioned(t *testing.T) {
	ch := &MatrixChannel{
		client: &mautrix.Client{
			UserID: id.UserID("@picoclaw:matrix.org"),
		},
	}

	cases := []struct {
		name string
		msg  event.MessageEventContent
		want bool
	}{
		{
			name: "mentions field",
			msg: event.MessageEventContent{
				Body: "hello",
				Mentions: &event.Mentions{
					UserIDs: []id.UserID{id.UserID("@picoclaw:matrix.org")},
				},
			},
			want: true,
		},
		{
			name: "full user id in body",
			msg: event.MessageEventContent{
				Body: "@picoclaw:matrix.org hello",
			},
			want: true,
		},
		{
			name: "localpart with at sign",
			msg: event.MessageEventContent{
				Body: "@picoclaw hello",
			},
			want: true,
		},
		{
			name: "localpart without at sign should not match",
			msg: event.MessageEventContent{
				Body: "\u6b22\u8fce\u4e00\u4e0bpicoclaw\u5c0f\u9f99\u867e",
			},
			want: false,
		},
		{
			name: "formatted mention href matrix.to plain",
			msg: event.MessageEventContent{
				Body:          "hello bot",
				FormattedBody: `<a href="https://matrix.to/#/@picoclaw:matrix.org">PicoClaw</a> hello`,
			},
			want: true,
		},
		{
			name: "formatted mention href matrix.to encoded",
			msg: event.MessageEventContent{
				Body:          "hello bot",
				FormattedBody: `<a href="https://matrix.to/#/%40picoclaw%3Amatrix.org">PicoClaw</a> hello`,
			},
			want: true,
		},
	}

	for i := range cases {
		tc := &cases[i]
		t.Run(tc.name, func(t *testing.T) {
			if got := ch.isBotMentioned(&tc.msg); got != tc.want {
				t.Errorf("%s: got=%v want=%v", tc.name, got, tc.want)
			}
		})
	}
}
// TestRoomKindCache_ExpiresEntries verifies that an entry is served while it
// is younger than the cache TTL and reported as a miss once the TTL has
// elapsed.
func TestRoomKindCache_ExpiresEntries(t *testing.T) {
	const roomID = "!room:matrix.org"

	insertedAt := time.Unix(100, 0)
	cache := newRoomKindCache(4, 5*time.Second)
	cache.set(roomID, true, insertedAt)

	// Two seconds in: still inside the five-second TTL.
	isGroup, hit := cache.get(roomID, insertedAt.Add(2*time.Second))
	if !hit || !isGroup {
		t.Fatalf("expected cached group room before ttl, got ok=%v group=%v", hit, isGroup)
	}

	// Six seconds in: past the TTL, so the entry must be gone.
	_, hit = cache.get(roomID, insertedAt.Add(6*time.Second))
	if hit {
		t.Fatal("expected cache miss after ttl expiry")
	}
}
// TestRoomKindCache_EvictsOldestWhenFull verifies that inserting a third room
// into a cache of capacity two drops the oldest entry and keeps the two most
// recent ones with their stored kinds.
func TestRoomKindCache_EvictsOldestWhenFull(t *testing.T) {
	base := time.Unix(200, 0)
	lookupAt := base.Add(2 * time.Second)

	cache := newRoomKindCache(2, time.Minute)
	cache.set("!room1:matrix.org", false, base)
	cache.set("!room2:matrix.org", false, base.Add(1*time.Second))
	cache.set("!room3:matrix.org", true, lookupAt)

	_, hit := cache.get("!room1:matrix.org", lookupAt)
	if hit {
		t.Fatal("expected oldest cache entry to be evicted")
	}

	isGroup, hit := cache.get("!room2:matrix.org", lookupAt)
	if !hit || isGroup {
		t.Fatalf("expected room2 to remain and be direct, got ok=%v group=%v", hit, isGroup)
	}

	isGroup, hit = cache.get("!room3:matrix.org", lookupAt)
	if !hit || !isGroup {
		t.Fatalf("expected room3 to remain and be group, got ok=%v group=%v", hit, isGroup)
	}
}
// TestMatrixMediaTempDir verifies that matrixMediaTempDir returns a path whose
// last element is matrixMediaTempDirName and that the path exists on disk as
// a directory.
func TestMatrixMediaTempDir(t *testing.T) {
	dir, err := matrixMediaTempDir()
	if err != nil {
		t.Fatalf("matrixMediaTempDir failed: %v", err)
	}

	if base := filepath.Base(dir); base != matrixMediaTempDirName {
		t.Fatalf("unexpected media dir base: %q", base)
	}

	info, statErr := os.Stat(dir)
	if statErr != nil {
		t.Fatalf("media dir not created: %v", statErr)
	}
	if !info.IsDir() {
		t.Fatalf("expected directory, got mode=%v", info.Mode())
	}
}
// TestMatrixMediaExt pins the extension matrixMediaExt picks for a media
// item: the filename's extension first, then one derived from the content
// type, then a per-kind default.
func TestMatrixMediaExt(t *testing.T) {
	// The filename carries an extension.
	ext := matrixMediaExt("photo.png", "", "image")
	if ext != ".png" {
		t.Fatalf("filename extension mismatch: got=%q", ext)
	}

	// No filename; only the content type is available.
	ext = matrixMediaExt("", "image/webp", "image")
	if ext != ".webp" {
		t.Fatalf("content-type extension mismatch: got=%q", ext)
	}

	// Neither filename nor content type: per-kind defaults.
	ext = matrixMediaExt("", "", "image")
	if ext != ".jpg" {
		t.Fatalf("default image extension mismatch: got=%q", ext)
	}

	ext = matrixMediaExt("", "", "audio")
	if ext != ".ogg" {
		t.Fatalf("default audio extension mismatch: got=%q", ext)
	}

	ext = matrixMediaExt("", "", "video")
	if ext != ".mp4" {
		t.Fatalf("default video extension mismatch: got=%q", ext)
	}

	ext = matrixMediaExt("", "", "file")
	if ext != ".bin" {
		t.Fatalf("default file extension mismatch: got=%q", ext)
	}
}
// TestExtractInboundContent_ImageNoURLFallback verifies that an image message
// without a media URL is still accepted and rendered as a textual
// "[image: <name>]" placeholder with no media references.
func TestExtractInboundContent_ImageNoURLFallback(t *testing.T) {
	ch := new(MatrixChannel)
	imageMsg := &event.MessageEventContent{
		MsgType: event.MsgImage,
		Body:    "test.png",
	}

	text, refs, accepted := ch.extractInboundContent(context.Background(), imageMsg, "matrix:room:event")
	if !accepted {
		t.Fatal("expected ok for image fallback")
	}
	if text != "[image: test.png]" {
		t.Fatalf("unexpected content: %q", text)
	}
	if n := len(refs); n != 0 {
		t.Fatalf("expected no media refs, got %d", n)
	}
}
// TestExtractInboundContent_AudioNoURLFallback verifies that an audio message
// without a media URL keeps its caption and gains an "[audio: <filename>]"
// placeholder on the following line, with no media references.
func TestExtractInboundContent_AudioNoURLFallback(t *testing.T) {
	ch := new(MatrixChannel)
	audioMsg := &event.MessageEventContent{
		MsgType:  event.MsgAudio,
		FileName: "voice.ogg",
		Body:     "please transcribe",
	}

	text, refs, accepted := ch.extractInboundContent(context.Background(), audioMsg, "matrix:room:event")
	if !accepted {
		t.Fatal("expected ok for audio fallback")
	}
	if text != "please transcribe\n[audio: voice.ogg]" {
		t.Fatalf("unexpected content: %q", text)
	}
	if n := len(refs); n != 0 {
		t.Fatalf("expected no media refs, got %d", n)
	}
}
// TestMatrixOutboundMsgType checks how matrixOutboundMsgType picks the Matrix
// message type for an outbound media part: an explicit part type is used when
// given, otherwise the content type, then the filename extension, and finally
// the generic file type.
//
// Each case runs as its own subtest so a single failure neither hides the
// remaining cases nor prevents selecting one case with -run.
func TestMatrixOutboundMsgType(t *testing.T) {
	cases := []struct {
		name        string
		partType    string
		filename    string
		contentType string
		want        event.MessageType
	}{
		{name: "explicit image", partType: "image", want: event.MsgImage},
		{name: "explicit audio", partType: "audio", want: event.MsgAudio},
		{name: "mime fallback video", contentType: "video/mp4", want: event.MsgVideo},
		{name: "extension fallback audio", filename: "voice.ogg", want: event.MsgAudio},
		{name: "unknown defaults file", filename: "report.txt", want: event.MsgFile},
	}

	for _, tc := range cases {
		tc := tc // per-iteration copy for toolchains older than Go 1.22
		t.Run(tc.name, func(t *testing.T) {
			if got := matrixOutboundMsgType(tc.partType, tc.filename, tc.contentType); got != tc.want {
				t.Errorf("%s: got=%q want=%q", tc.name, got, tc.want)
			}
		})
	}
}
func TestMatrixOutboundContent(t *testing.T) {
content := matrixOutboundContent(
"please review",
"voice.ogg",
event.MsgAudio,
"audio/ogg",
1234,
id.ContentURIString("mxc://matrix.org/abc"),
)
if content.Body != "please review" {
t.Fatalf("unexpected body: %q", content.Body)
}
if content.FileName != "voice.ogg" {
t.Fatalf("unexpected filename: %q", content.FileName)
}
if content.Info == nil || content.Info.MimeType != "audio/ogg" {
t.Fatalf("unexpected content type: %+v", content.Info)
}
if content.Info == nil || content.Info.Size != 1234 {
t.Fatalf("unexpected size: %+v", content.Info)
}
noCaption := matrixOutboundContent(
"",
"image.png",
event.MsgImage,
"image/png",
0,
id.ContentURIString("mxc://matrix.org/def"),
)
if noCaption.Body != "image.png" {
t.Fatalf("unexpected fallback body: %q", noCaption.Body)
}
}
+362 -29
View File
@@ -3,7 +3,10 @@ package qq
import (
"context"
"fmt"
"regexp"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/tencent-connect/botgo"
@@ -20,6 +23,14 @@ import (
"github.com/sipeed/picoclaw/pkg/logger"
)
// Tuning constants for inbound-message deduplication and typing indicators.
const (
// NOTE(review): presumably how long a processed message ID stays in the
// dedup map before it may be dropped — confirm against the dedup janitor.
dedupTTL = 5 * time.Minute
// NOTE(review): presumably the period of the dedup janitor's sweep — confirm.
dedupInterval = 60 * time.Second
dedupMaxSize = 10000 // hard cap on dedup map entries
// NOTE(review): matches the "re-sends every 8 seconds" interval described on
// StartTyping — confirm it is the value used there.
typingResend = 8 * time.Second
// NOTE(review): presumably the duration, in seconds, carried by each typing
// notification — confirm against StartTyping.
typingSeconds = 10
)
type QQChannel struct {
*channels.BaseChannel
config config.QQConfig
@@ -28,20 +39,37 @@ type QQChannel struct {
ctx context.Context
cancel context.CancelFunc
sessionManager botgo.SessionManager
processedIDs map[string]bool
mu sync.RWMutex
// Chat routing: track whether a chatID is group or direct.
chatType sync.Map // chatID → "group" | "direct"
// Passive reply: store last inbound message ID per chat.
lastMsgID sync.Map // chatID → string
// msg_seq: per-chat atomic counter for multi-part replies.
msgSeqCounters sync.Map // chatID → *atomic.Uint64
// Time-based dedup replacing the unbounded map.
dedup map[string]time.Time
muDedup sync.Mutex
// done is closed on Stop to shut down the dedup janitor.
done chan struct{}
stopOnce sync.Once
}
func NewQQChannel(cfg config.QQConfig, messageBus *bus.MessageBus) (*QQChannel, error) {
base := channels.NewBaseChannel("qq", cfg, messageBus, cfg.AllowFrom,
channels.WithMaxMessageLength(cfg.MaxMessageLength),
channels.WithGroupTrigger(cfg.GroupTrigger),
channels.WithReasoningChannelID(cfg.ReasoningChannelID),
)
return &QQChannel{
BaseChannel: base,
config: cfg,
processedIDs: make(map[string]bool),
BaseChannel: base,
config: cfg,
dedup: make(map[string]time.Time),
done: make(chan struct{}),
}, nil
}
@@ -52,6 +80,10 @@ func (c *QQChannel) Start(ctx context.Context) error {
logger.InfoC("qq", "Starting QQ bot (WebSocket mode)")
// Reinitialize shutdown signal for clean restart.
c.done = make(chan struct{})
c.stopOnce = sync.Once{}
// create token source
credentials := &token.QQBotCredentials{
AppID: c.config.AppID,
@@ -99,6 +131,15 @@ func (c *QQChannel) Start(ctx context.Context) error {
}
}()
// start dedup janitor goroutine
go c.dedupJanitor()
// Pre-register reasoning_channel_id as group chat if configured,
// so outbound-only destinations are routed correctly.
if c.config.ReasoningChannelID != "" {
c.chatType.Store(c.config.ReasoningChannelID, "group")
}
c.SetRunning(true)
logger.InfoC("qq", "QQ bot started successfully")
@@ -109,6 +150,9 @@ func (c *QQChannel) Stop(ctx context.Context) error {
logger.InfoC("qq", "Stopping QQ bot")
c.SetRunning(false)
// Signal the dedup janitor to stop (idempotent).
c.stopOnce.Do(func() { close(c.done) })
if c.cancel != nil {
c.cancel()
}
@@ -116,21 +160,82 @@ func (c *QQChannel) Stop(ctx context.Context) error {
return nil
}
// getChatKind reports whether chatID is a "group" or "direct" chat, based on
// the kind recorded in c.chatType. When no kind (or a non-string value) is
// stored for chatID, it logs a debug message and returns "group".
func (c *QQChannel) getChatKind(chatID string) string {
	stored, found := c.chatType.Load(chatID)
	// The comma-ok assertion is safe on a nil value, so a missing entry
	// simply falls through to the default below.
	if kind, isString := stored.(string); found && isString {
		return kind
	}

	logger.DebugCF("qq", "Unknown chat type for chatID, defaulting to group", map[string]any{
		"chat_id": chatID,
	})
	return "group"
}
func (c *QQChannel) Send(ctx context.Context, msg bus.OutboundMessage) error {
if !c.IsRunning() {
return channels.ErrNotRunning
}
// construct message
chatKind := c.getChatKind(msg.ChatID)
// Build message with content.
msgToCreate := &dto.MessageToCreate{
Content: msg.Content,
MsgType: dto.TextMsg,
}
// Use Markdown message type if enabled in config.
if c.config.SendMarkdown {
msgToCreate.MsgType = dto.MarkdownMsg
msgToCreate.Markdown = &dto.Markdown{
Content: msg.Content,
}
// Clear plain content to avoid sending duplicate text.
msgToCreate.Content = ""
}
// Attach passive reply msg_id and msg_seq if available.
if v, ok := c.lastMsgID.Load(msg.ChatID); ok {
if msgID, ok := v.(string); ok && msgID != "" {
msgToCreate.MsgID = msgID
// Increment msg_seq atomically for multi-part replies.
if counterVal, ok := c.msgSeqCounters.Load(msg.ChatID); ok {
if counter, ok := counterVal.(*atomic.Uint64); ok {
seq := counter.Add(1)
msgToCreate.MsgSeq = uint32(seq)
}
}
}
}
// Sanitize URLs in group messages to avoid QQ's URL blacklist rejection.
if chatKind == "group" {
if msgToCreate.Content != "" {
msgToCreate.Content = sanitizeURLs(msgToCreate.Content)
}
if msgToCreate.Markdown != nil && msgToCreate.Markdown.Content != "" {
msgToCreate.Markdown.Content = sanitizeURLs(msgToCreate.Markdown.Content)
}
}
// Route to group or C2C.
var err error
if chatKind == "group" {
_, err = c.api.PostGroupMessage(ctx, msg.ChatID, msgToCreate)
} else {
_, err = c.api.PostC2CMessage(ctx, msg.ChatID, msgToCreate)
}
// send C2C message
_, err := c.api.PostC2CMessage(ctx, msg.ChatID, msgToCreate)
if err != nil {
logger.ErrorCF("qq", "Failed to send C2C message", map[string]any{
"error": err.Error(),
logger.ErrorCF("qq", "Failed to send message", map[string]any{
"chat_id": msg.ChatID,
"chat_kind": chatKind,
"error": err.Error(),
})
return fmt.Errorf("qq send: %w", channels.ErrTemporary)
}
@@ -138,7 +243,150 @@ func (c *QQChannel) Send(ctx context.Context, msg bus.OutboundMessage) error {
return nil
}
// handleC2CMessage handles QQ private messages
// StartTyping implements channels.TypingCapable.
// It sends an InputNotify (msg_type=6) for chatID immediately and then re-sends
// it on every typingResend tick from a background goroutine.
// If no non-empty msg_id is stored for chatID, nothing is sent and a no-op stop
// function is returned. The error result is always nil.
// The returned stop function is idempotent and cancels the goroutine.
// NOTE(review): the ctx parameter is not used; every send runs on c.ctx —
// confirm this is intentional.
func (c *QQChannel) StartTyping(ctx context.Context, chatID string) (func(), error) {
// We need a stored msg_id for passive InputNotify; skip if none available.
v, ok := c.lastMsgID.Load(chatID)
if !ok {
return func() {}, nil
}
msgID, ok := v.(string)
if !ok || msgID == "" {
return func() {}, nil
}
// Resolve the routing kind once; the closure below reuses it for every send.
chatKind := c.getChatKind(chatID)
// sendTyping posts one typing indicator; failures are only logged at debug
// level and never propagated.
sendTyping := func(sendCtx context.Context) {
typingMsg := &dto.MessageToCreate{
MsgType: dto.InputNotifyMsg,
MsgID: msgID,
InputNotify: &dto.InputNotify{
InputType: 1,
InputSecond: typingSeconds,
},
}
var err error
if chatKind == "group" {
_, err = c.api.PostGroupMessage(sendCtx, chatID, typingMsg)
} else {
_, err = c.api.PostC2CMessage(sendCtx, chatID, typingMsg)
}
if err != nil {
logger.DebugCF("qq", "Failed to send typing indicator", map[string]any{
"chat_id": chatID,
"error": err.Error(),
})
}
}
// Send immediately.
sendTyping(c.ctx)
// The resend loop stops when cancel is called or when c.ctx is done.
typingCtx, cancel := context.WithCancel(c.ctx)
go func() {
ticker := time.NewTicker(typingResend)
defer ticker.Stop()
for {
select {
case <-typingCtx.Done():
return
case <-ticker.C:
sendTyping(typingCtx)
}
}
}()
return cancel, nil
}
// SendMedia implements the channels.MediaSender interface.
// Each part is posted as a QQ RichMediaMessage, which needs an HTTP/HTTPS URL.
// A part whose Ref is already such a URL is used as-is; any other Ref is
// resolved through the media store, and the part is skipped (with a log entry)
// when there is no store, resolution fails, or the result is not an HTTP(S) URL.
// The first failed post aborts the remaining parts and returns an error
// wrapping channels.ErrTemporary. Returns channels.ErrNotRunning when the
// channel is not running.
func (c *QQChannel) SendMedia(ctx context.Context, msg bus.OutboundMediaMessage) error {
	if !c.IsRunning() {
		return channels.ErrNotRunning
	}

	// The destination kind is the same for every part, so look it up once.
	toGroup := c.getChatKind(msg.ChatID) == "group"

	for _, part := range msg.Parts {
		target := part.Ref
		if !isHTTPURL(target) {
			// Not a URL yet: fall back to the media store.
			mediaStore := c.GetMediaStore()
			if mediaStore == nil {
				logger.WarnCF("qq", "QQ media requires HTTP/HTTPS URL, no media store available", map[string]any{
					"ref": part.Ref,
				})
				continue
			}
			located, resolveErr := mediaStore.Resolve(part.Ref)
			if resolveErr != nil {
				logger.ErrorCF("qq", "Failed to resolve media ref", map[string]any{
					"ref":   part.Ref,
					"error": resolveErr.Error(),
				})
				continue
			}
			if !isHTTPURL(located) {
				logger.WarnCF("qq", "QQ media requires HTTP/HTTPS URL, local files not supported", map[string]any{
					"ref":      part.Ref,
					"resolved": located,
				})
				continue
			}
			target = located
		}

		// QQ file types: 1=image, 2=video, 3=audio, 4=file (the fallback).
		qqType := uint64(4)
		switch part.Type {
		case "image":
			qqType = 1
		case "video":
			qqType = 2
		case "audio":
			qqType = 3
		}

		payload := &dto.RichMediaMessage{
			FileType:   qqType,
			URL:        target,
			SrvSendMsg: true,
		}

		var postErr error
		if toGroup {
			_, postErr = c.api.PostGroupMessage(ctx, msg.ChatID, payload)
		} else {
			_, postErr = c.api.PostC2CMessage(ctx, msg.ChatID, payload)
		}
		if postErr == nil {
			continue
		}

		logger.ErrorCF("qq", "Failed to send media", map[string]any{
			"type":    part.Type,
			"chat_id": msg.ChatID,
			"error":   postErr.Error(),
		})
		return fmt.Errorf("qq send media: %w", channels.ErrTemporary)
	}

	return nil
}
// handleC2CMessage handles QQ private messages.
func (c *QQChannel) handleC2CMessage() event.C2CMessageEventHandler {
return func(event *dto.WSPayload, data *dto.WSC2CMessageData) error {
// deduplication check
@@ -167,7 +415,13 @@ func (c *QQChannel) handleC2CMessage() event.C2CMessageEventHandler {
"length": len(content),
})
// 转发到消息总线
// Store chat routing context.
c.chatType.Store(senderID, "direct")
c.lastMsgID.Store(senderID, data.ID)
// Reset msg_seq counter for new inbound message.
c.msgSeqCounters.Store(senderID, new(atomic.Uint64))
metadata := map[string]string{}
sender := bus.SenderInfo{
@@ -195,7 +449,7 @@ func (c *QQChannel) handleC2CMessage() event.C2CMessageEventHandler {
}
}
// handleGroupATMessage handles QQ group @ messages
// handleGroupATMessage handles QQ group @ messages.
func (c *QQChannel) handleGroupATMessage() event.GroupATMessageEventHandler {
return func(event *dto.WSPayload, data *dto.WSGroupATMessageData) error {
// deduplication check
@@ -232,7 +486,13 @@ func (c *QQChannel) handleGroupATMessage() event.GroupATMessageEventHandler {
"length": len(content),
})
// 转发到消息总线(使用 GroupID 作为 ChatID
// Store chat routing context using GroupID as chatID.
c.chatType.Store(data.GroupID, "group")
c.lastMsgID.Store(data.GroupID, data.ID)
// Reset msg_seq counter for new inbound message.
c.msgSeqCounters.Store(data.GroupID, new(atomic.Uint64))
metadata := map[string]string{
"group_id": data.GroupID,
}
@@ -262,29 +522,102 @@ func (c *QQChannel) handleGroupATMessage() event.GroupATMessageEventHandler {
}
}
// isDuplicate 检查消息是否重复
// isDuplicate checks whether a message has been seen within the TTL window.
// It also enforces a hard cap on map size by evicting oldest entries.
func (c *QQChannel) isDuplicate(messageID string) bool {
c.mu.Lock()
defer c.mu.Unlock()
c.muDedup.Lock()
defer c.muDedup.Unlock()
if c.processedIDs[messageID] {
if ts, exists := c.dedup[messageID]; exists && time.Since(ts) < dedupTTL {
return true
}
c.processedIDs[messageID] = true
// 简单清理:限制 map 大小
if len(c.processedIDs) > 10000 {
// 清空一半
count := 0
for id := range c.processedIDs {
if count >= 5000 {
break
// Enforce hard cap: evict oldest entries when at capacity.
if len(c.dedup) >= dedupMaxSize {
var oldestID string
var oldestTS time.Time
for id, ts := range c.dedup {
if oldestID == "" || ts.Before(oldestTS) {
oldestID = id
oldestTS = ts
}
delete(c.processedIDs, id)
count++
}
if oldestID != "" {
delete(c.dedup, oldestID)
}
}
c.dedup[messageID] = time.Now()
return false
}
// dedupJanitor runs until c.done is closed. On every dedupInterval tick it
// takes c.muDedup and drops each entry of c.dedup whose timestamp is at least
// dedupTTL old.
func (c *QQChannel) dedupJanitor() {
	sweep := time.NewTicker(dedupInterval)
	defer sweep.Stop()

	for {
		select {
		case <-c.done:
			return
		case <-sweep.C:
		}

		c.muDedup.Lock()
		cutoff := time.Now()
		// Deleting the current key while ranging over a map is safe in Go.
		for id, seenAt := range c.dedup {
			if cutoff.Sub(seenAt) >= dedupTTL {
				delete(c.dedup, id)
			}
		}
		c.muDedup.Unlock()
	}
}
// isHTTPURL reports whether s begins with the literal prefix "http://" or
// "https://". The comparison is case-sensitive and nothing beyond the prefix
// is validated.
func isHTTPURL(s string) bool {
	for _, scheme := range [...]string{"http://", "https://"} {
		if strings.HasPrefix(s, scheme) {
			return true
		}
	}
	return false
}
// urlPattern recognizes URLs that carry an explicit http:// or https:// scheme
// (matched case-insensitively). Requiring the scheme keeps bare text such as
// version numbers ("1.2.3") or domain-like fragments from being treated as links.
var urlPattern = regexp.MustCompile(
	`(?i)` +
		`https?://` + // mandatory scheme
		`(?:[a-zA-Z0-9](?:[a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?\.)+` + // dot-terminated labels
		`[a-zA-Z]{2,}` + // final (top-level) label
		`(?:[/?#]\S*)?`, // optional path, query or fragment
)

// sanitizeURLs rewrites every URL matched by urlPattern so that each "." in its
// host part becomes "。". The scheme and everything from the first '/', '?' or
// '#' onward are left untouched, as is all text outside a match. The purpose is
// to keep QQ's URL blacklist from rejecting the message.
func sanitizeURLs(text string) string {
	return urlPattern.ReplaceAllStringFunc(text, func(link string) string {
		// Every match contains "://", so the cut always succeeds.
		scheme, remainder, _ := strings.Cut(link, "://")

		// The host runs up to the first path, query or fragment delimiter.
		host, tail := remainder, ""
		if cut := strings.IndexAny(remainder, "/?#"); cut >= 0 {
			host, tail = remainder[:cut], remainder[cut:]
		}

		return scheme + "://" + strings.ReplaceAll(host, ".", "。") + tail
	})
}
+71 -33
View File
@@ -168,7 +168,7 @@ func (c *TelegramChannel) Send(ctx context.Context, msg bus.OutboundMessage) err
return channels.ErrNotRunning
}
chatID, err := parseChatID(msg.ChatID)
chatID, threadID, err := parseTelegramChatID(msg.ChatID)
if err != nil {
return fmt.Errorf("invalid chat ID %s: %w", msg.ChatID, channels.ErrSendFailed)
}
@@ -201,7 +201,7 @@ func (c *TelegramChannel) Send(ctx context.Context, msg bus.OutboundMessage) err
continue
}
if err := c.sendHTMLChunk(ctx, chatID, htmlContent, chunk, replyToID); err != nil {
if err := c.sendHTMLChunk(ctx, chatID, threadID, htmlContent, chunk, replyToID); err != nil {
return err
}
// Only the first chunk should be a reply; subsequent chunks are normal messages.
@@ -214,12 +214,11 @@ func (c *TelegramChannel) Send(ctx context.Context, msg bus.OutboundMessage) err
// sendHTMLChunk sends a single HTML message, falling back to the original
// markdown as plain text on parse failure so users never see raw HTML tags.
func (c *TelegramChannel) sendHTMLChunk(
ctx context.Context,
chatID int64,
htmlContent, mdFallback, replyToID string,
ctx context.Context, chatID int64, threadID int, htmlContent, mdFallback string, replyToID string
) error {
tgMsg := tu.Message(tu.ID(chatID), htmlContent)
tgMsg.ParseMode = telego.ModeHTML
tgMsg.MessageThreadID = threadID
if replyToID != "" {
if mid, parseErr := strconv.Atoi(replyToID); parseErr == nil {
@@ -247,13 +246,16 @@ func (c *TelegramChannel) sendHTMLChunk(
// (Telegram's typing indicator expires after ~5s) in a background goroutine.
// The returned stop function is idempotent and cancels the goroutine.
func (c *TelegramChannel) StartTyping(ctx context.Context, chatID string) (func(), error) {
cid, err := parseChatID(chatID)
cid, threadID, err := parseTelegramChatID(chatID)
if err != nil {
return func() {}, err
}
action := tu.ChatAction(tu.ID(cid), telego.ChatActionTyping)
action.MessageThreadID = threadID
// Send the first typing action immediately
_ = c.bot.SendChatAction(ctx, tu.ChatAction(tu.ID(cid), telego.ChatActionTyping))
_ = c.bot.SendChatAction(ctx, action)
typingCtx, cancel := context.WithCancel(ctx)
go func() {
@@ -264,7 +266,9 @@ func (c *TelegramChannel) StartTyping(ctx context.Context, chatID string) (func(
case <-typingCtx.Done():
return
case <-ticker.C:
_ = c.bot.SendChatAction(typingCtx, tu.ChatAction(tu.ID(cid), telego.ChatActionTyping))
a := tu.ChatAction(tu.ID(cid), telego.ChatActionTyping)
a.MessageThreadID = threadID
_ = c.bot.SendChatAction(typingCtx, a)
}
}
}()
@@ -274,7 +278,7 @@ func (c *TelegramChannel) StartTyping(ctx context.Context, chatID string) (func(
// EditMessage implements channels.MessageEditor.
func (c *TelegramChannel) EditMessage(ctx context.Context, chatID string, messageID string, content string) error {
cid, err := parseChatID(chatID)
cid, _, err := parseTelegramChatID(chatID)
if err != nil {
return err
}
@@ -303,12 +307,14 @@ func (c *TelegramChannel) SendPlaceholder(ctx context.Context, chatID string) (s
text = "Thinking... 💭"
}
cid, err := parseChatID(chatID)
cid, threadID, err := parseTelegramChatID(chatID)
if err != nil {
return "", err
}
pMsg, err := c.bot.SendMessage(ctx, tu.Message(tu.ID(cid), text))
phMsg := tu.Message(tu.ID(cid), text)
phMsg.MessageThreadID = threadID
pMsg, err := c.bot.SendMessage(ctx, phMsg)
if err != nil {
return "", err
}
@@ -322,7 +328,7 @@ func (c *TelegramChannel) SendMedia(ctx context.Context, msg bus.OutboundMediaMe
return channels.ErrNotRunning
}
chatID, err := parseChatID(msg.ChatID)
chatID, threadID, err := parseTelegramChatID(msg.ChatID)
if err != nil {
return fmt.Errorf("invalid chat ID %s: %w", msg.ChatID, channels.ErrSendFailed)
}
@@ -354,30 +360,34 @@ func (c *TelegramChannel) SendMedia(ctx context.Context, msg bus.OutboundMediaMe
switch part.Type {
case "image":
params := &telego.SendPhotoParams{
ChatID: tu.ID(chatID),
Photo: telego.InputFile{File: file},
Caption: part.Caption,
ChatID: tu.ID(chatID),
MessageThreadID: threadID,
Photo: telego.InputFile{File: file},
Caption: part.Caption,
}
_, err = c.bot.SendPhoto(ctx, params)
case "audio":
params := &telego.SendAudioParams{
ChatID: tu.ID(chatID),
Audio: telego.InputFile{File: file},
Caption: part.Caption,
ChatID: tu.ID(chatID),
MessageThreadID: threadID,
Audio: telego.InputFile{File: file},
Caption: part.Caption,
}
_, err = c.bot.SendAudio(ctx, params)
case "video":
params := &telego.SendVideoParams{
ChatID: tu.ID(chatID),
Video: telego.InputFile{File: file},
Caption: part.Caption,
ChatID: tu.ID(chatID),
MessageThreadID: threadID,
Video: telego.InputFile{File: file},
Caption: part.Caption,
}
_, err = c.bot.SendVideo(ctx, params)
default: // "file" or unknown types
params := &telego.SendDocumentParams{
ChatID: tu.ID(chatID),
Document: telego.InputFile{File: file},
Caption: part.Caption,
ChatID: tu.ID(chatID),
MessageThreadID: threadID,
Document: telego.InputFile{File: file},
Caption: part.Caption,
}
_, err = c.bot.SendDocument(ctx, params)
}
@@ -521,19 +531,28 @@ func (c *TelegramChannel) handleMessage(ctx context.Context, message *telego.Mes
content = cleaned
}
// For forum topics, embed the thread ID as "chatID/threadID" so replies
// route to the correct topic and each topic gets its own session.
// Only forum groups (IsForum) are handled; regular group reply threads
// must share one session per group.
compositeChatID := fmt.Sprintf("%d", chatID)
threadID := message.MessageThreadID
if message.Chat.IsForum && threadID != 0 {
compositeChatID = fmt.Sprintf("%d/%d", chatID, threadID)
}
logger.DebugCF("telegram", "Received message", map[string]any{
"sender_id": sender.CanonicalID,
"chat_id": fmt.Sprintf("%d", chatID),
"chat_id": compositeChatID,
"thread_id": threadID,
"preview": utils.Truncate(content, 50),
})
// Placeholder is now auto-triggered by BaseChannel.HandleMessage via PlaceholderCapable
peerKind := "direct"
peerID := fmt.Sprintf("%d", user.ID)
if message.Chat.Type != "private" {
peerKind = "group"
peerID = fmt.Sprintf("%d", chatID)
peerID = compositeChatID
}
peer := bus.Peer{Kind: peerKind, ID: peerID}
@@ -546,11 +565,17 @@ func (c *TelegramChannel) handleMessage(ctx context.Context, message *telego.Mes
"is_group": fmt.Sprintf("%t", message.Chat.Type != "private"),
}
// Set parent_peer metadata for per-topic agent binding.
if message.Chat.IsForum && threadID != 0 {
metadata["parent_peer_kind"] = "topic"
metadata["parent_peer_id"] = fmt.Sprintf("%d", threadID)
}
c.HandleMessage(c.ctx,
peer,
messageID,
platformID,
fmt.Sprintf("%d", chatID),
compositeChatID,
content,
mediaPaths,
metadata,
@@ -598,10 +623,23 @@ func (c *TelegramChannel) downloadFile(ctx context.Context, fileID, ext string)
return c.downloadFileWithInfo(file, ext)
}
func parseChatID(chatIDStr string) (int64, error) {
var id int64
_, err := fmt.Sscanf(chatIDStr, "%d", &id)
return id, err
// parseTelegramChatID decodes a chat identifier of the form "chatID" or
// "chatID/threadID". Only the first "/" separates the two parts. Without a
// "/" the thread ID is 0 and the strconv result for the chat ID is returned
// as-is. With a "/", a bad chat ID yields the raw strconv error and a bad
// thread ID yields an error mentioning the full input; both return zero IDs.
func parseTelegramChatID(chatID string) (int64, int, error) {
	chatPart, threadPart, hasThread := strings.Cut(chatID, "/")

	cid, err := strconv.ParseInt(chatPart, 10, 64)
	if !hasThread {
		return cid, 0, err
	}
	if err != nil {
		return 0, 0, err
	}

	tid, err := strconv.Atoi(threadPart)
	if err != nil {
		return 0, 0, fmt.Errorf("invalid thread ID in chat ID %q: %w", chatID, err)
	}
	return cid, tid, nil
}
func markdownToTelegramHTML(text string) string {
+189
View File
@@ -6,6 +6,7 @@ import (
"errors"
"strings"
"testing"
"time"
"github.com/mymmrac/telego"
ta "github.com/mymmrac/telego/telegoapi"
@@ -271,3 +272,191 @@ func TestSend_InvalidChatID(t *testing.T) {
assert.True(t, errors.Is(err, channels.ErrSendFailed), "error should wrap ErrSendFailed")
assert.Empty(t, caller.calls)
}
// TestParseTelegramChatID_Plain checks that a bare positive chat ID parses
// without error and yields thread ID 0.
func TestParseTelegramChatID_Plain(t *testing.T) {
cid, tid, err := parseTelegramChatID("12345")
assert.NoError(t, err)
assert.Equal(t, int64(12345), cid)
assert.Equal(t, 0, tid)
}
// TestParseTelegramChatID_NegativeGroup checks that a negative chat ID with no
// thread suffix parses without error and yields thread ID 0.
func TestParseTelegramChatID_NegativeGroup(t *testing.T) {
cid, tid, err := parseTelegramChatID("-1001234567890")
assert.NoError(t, err)
assert.Equal(t, int64(-1001234567890), cid)
assert.Equal(t, 0, tid)
}
// TestParseTelegramChatID_WithThreadID checks that "chatID/threadID" is split
// into both numeric components.
func TestParseTelegramChatID_WithThreadID(t *testing.T) {
cid, tid, err := parseTelegramChatID("-1001234567890/42")
assert.NoError(t, err)
assert.Equal(t, int64(-1001234567890), cid)
assert.Equal(t, 42, tid)
}
// TestParseTelegramChatID_GeneralTopic checks that thread ID 1 is preserved
// rather than being collapsed to 0.
func TestParseTelegramChatID_GeneralTopic(t *testing.T) {
cid, tid, err := parseTelegramChatID("-100123/1")
assert.NoError(t, err)
assert.Equal(t, int64(-100123), cid)
assert.Equal(t, 1, tid)
}
// TestParseTelegramChatID_Invalid checks that a non-numeric chat ID is rejected
// with an error.
func TestParseTelegramChatID_Invalid(t *testing.T) {
_, _, err := parseTelegramChatID("not-a-number")
assert.Error(t, err)
}
// TestParseTelegramChatID_InvalidThreadID checks that a non-numeric thread part
// is rejected and that the error message names the thread ID as the culprit.
func TestParseTelegramChatID_InvalidThreadID(t *testing.T) {
_, _, err := parseTelegramChatID("-100123/not-a-thread")
assert.Error(t, err)
assert.Contains(t, err.Error(), "invalid thread ID")
}
// TestSend_WithForumThreadID checks that Send accepts a composite
// "chatID/threadID" destination: it must succeed and record exactly one call on
// the stub caller.
// NOTE(review): the test does not inspect the request payload, so it does not
// by itself prove that the thread ID reaches the API — confirm whether that
// should be asserted too.
func TestSend_WithForumThreadID(t *testing.T) {
caller := &stubCaller{
callFn: func(ctx context.Context, url string, data *ta.RequestData) (*ta.Response, error) {
return successResponse(t), nil
},
}
ch := newTestChannel(t, caller)
err := ch.Send(context.Background(), bus.OutboundMessage{
ChatID: "-1001234567890/42",
Content: "Hello from topic",
})
assert.NoError(t, err)
assert.Len(t, caller.calls, 1)
}
// TestHandleMessage_ForumTopic_SetsMetadata feeds handleMessage a message from
// a forum supergroup (IsForum=true, MessageThreadID=42) and checks the inbound
// bus message: ChatID and Peer.ID are the composite "chatID/threadID", the peer
// kind is "group", and parent_peer_kind/parent_peer_id metadata describe the
// topic.
func TestHandleMessage_ForumTopic_SetsMetadata(t *testing.T) {
messageBus := bus.NewMessageBus()
ch := &TelegramChannel{
BaseChannel: channels.NewBaseChannel("telegram", nil, messageBus, nil),
chatIDs: make(map[string]int64),
ctx: context.Background(),
}
msg := &telego.Message{
Text: "hello from topic",
MessageID: 10,
MessageThreadID: 42,
Chat: telego.Chat{
ID: -1001234567890,
Type: "supergroup",
IsForum: true,
},
From: &telego.User{
ID: 7,
FirstName: "Alice",
},
}
err := ch.handleMessage(context.Background(), msg)
require.NoError(t, err)
// Bound the wait so a missing inbound message fails the test instead of hanging.
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
inbound, ok := messageBus.ConsumeInbound(ctx)
require.True(t, ok, "expected inbound message")
// Composite chatID should include thread ID
assert.Equal(t, "-1001234567890/42", inbound.ChatID)
// Peer ID should include thread ID for session key isolation
assert.Equal(t, "group", inbound.Peer.Kind)
assert.Equal(t, "-1001234567890/42", inbound.Peer.ID)
// Parent peer metadata should be set for agent binding
assert.Equal(t, "topic", inbound.Metadata["parent_peer_kind"])
assert.Equal(t, "42", inbound.Metadata["parent_peer_id"])
}
// TestHandleMessage_NoForum_NoThreadMetadata feeds handleMessage a plain group
// message (no thread ID, not a forum) and checks that ChatID and Peer.ID are
// the bare chat ID and that no parent_peer_* metadata is set.
func TestHandleMessage_NoForum_NoThreadMetadata(t *testing.T) {
messageBus := bus.NewMessageBus()
ch := &TelegramChannel{
BaseChannel: channels.NewBaseChannel("telegram", nil, messageBus, nil),
chatIDs: make(map[string]int64),
ctx: context.Background(),
}
msg := &telego.Message{
Text: "regular group message",
MessageID: 11,
Chat: telego.Chat{
ID: -100999,
Type: "group",
},
From: &telego.User{
ID: 8,
FirstName: "Bob",
},
}
err := ch.handleMessage(context.Background(), msg)
require.NoError(t, err)
// Bound the wait so a missing inbound message fails the test instead of hanging.
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
inbound, ok := messageBus.ConsumeInbound(ctx)
require.True(t, ok)
// Plain chatID without thread suffix
assert.Equal(t, "-100999", inbound.ChatID)
// Peer ID should be raw chat ID (no thread suffix)
assert.Equal(t, "group", inbound.Peer.Kind)
assert.Equal(t, "-100999", inbound.Peer.ID)
// No parent peer metadata
assert.Empty(t, inbound.Metadata["parent_peer_kind"])
assert.Empty(t, inbound.Metadata["parent_peer_id"])
}
// TestHandleMessage_ReplyThread_NonForum_NoIsolation feeds handleMessage a
// supergroup message that carries a MessageThreadID but has IsForum=false, and
// checks that the thread ID is ignored: ChatID and Peer.ID stay the bare chat
// ID and no parent_peer_* metadata is set.
func TestHandleMessage_ReplyThread_NonForum_NoIsolation(t *testing.T) {
messageBus := bus.NewMessageBus()
ch := &TelegramChannel{
BaseChannel: channels.NewBaseChannel("telegram", nil, messageBus, nil),
chatIDs: make(map[string]int64),
ctx: context.Background(),
}
// In regular groups, reply threads set MessageThreadID to the original
// message ID. This should NOT trigger per-thread session isolation.
msg := &telego.Message{
Text: "reply in thread",
MessageID: 20,
MessageThreadID: 15,
Chat: telego.Chat{
ID: -100999,
Type: "supergroup",
IsForum: false,
},
From: &telego.User{
ID: 9,
FirstName: "Carol",
},
}
err := ch.handleMessage(context.Background(), msg)
require.NoError(t, err)
// Bound the wait so a missing inbound message fails the test instead of hanging.
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
inbound, ok := messageBus.ConsumeInbound(ctx)
require.True(t, ok)
// chatID should NOT include thread suffix for non-forum groups
assert.Equal(t, "-100999", inbound.ChatID)
// Peer ID should be raw chat ID (shared session for whole group)
assert.Equal(t, "group", inbound.Peer.Kind)
assert.Equal(t, "-100999", inbound.Peer.ID)
// No parent peer metadata
assert.Empty(t, inbound.Metadata["parent_peer_kind"])
assert.Empty(t, inbound.Metadata["parent_peer_id"])
}
+1
View File
@@ -12,5 +12,6 @@ func BuiltinDefinitions() []Definition {
listCommand(),
switchCommand(),
checkCommand(),
clearCommand(),
}
}
+20
View File
@@ -0,0 +1,20 @@
package commands
import "context"
// clearCommand builds the definition of the "/clear" command. Its handler
// replies with unavailableMsg when the runtime or its ClearHistory hook is
// missing; otherwise it invokes the hook and replies with either a success
// message or the failure text. The handler returns whatever req.Reply returns.
func clearCommand() Definition {
	handle := func(_ context.Context, req Request, rt *Runtime) error {
		if rt == nil || rt.ClearHistory == nil {
			return req.Reply(unavailableMsg)
		}

		reply := "Chat history cleared!"
		if err := rt.ClearHistory(); err != nil {
			reply = "Failed to clear chat history: " + err.Error()
		}
		return req.Reply(reply)
	}

	return Definition{
		Name:        "clear",
		Description: "Clear the chat history",
		Usage:       "/clear",
		Handler:     handle,
	}
}
+1
View File
@@ -13,4 +13,5 @@ type Runtime struct {
GetEnabledChannels func() []string
SwitchModel func(value string) (oldModel string, err error)
SwitchChannel func(value string) error
ClearHistory func() error
}
+94 -27
View File
@@ -4,6 +4,7 @@ import (
"encoding/json"
"fmt"
"os"
"strings"
"sync/atomic"
"github.com/caarlos0/env/v11"
@@ -58,7 +59,16 @@ type Config struct {
Tools ToolsConfig `json:"tools"`
Heartbeat HeartbeatConfig `json:"heartbeat"`
Devices DevicesConfig `json:"devices"`
Voice VoiceConfig `json:"voice"`
// BuildInfo contains build-time version information
BuildInfo BuildInfo `json:"build_info,omitempty"`
}
// BuildInfo contains build-time version information. All fields are plain
// strings serialized under the JSON keys given in their tags.
// NOTE(review): presumably these are filled in by the build/startup code rather
// than read from the user's config file — confirm against the code that sets them.
type BuildInfo struct {
Version string `json:"version"`
GitCommit string `json:"git_commit"`
BuildTime string `json:"build_time"`
GoVersion string `json:"go_version"`
}
// MarshalJSON implements custom JSON marshaling for Config
@@ -226,6 +236,7 @@ type ChannelsConfig struct {
QQ QQConfig `json:"qq"`
DingTalk DingTalkConfig `json:"dingtalk"`
Slack SlackConfig `json:"slack"`
Matrix MatrixConfig `json:"matrix"`
LINE LINEConfig `json:"line"`
OneBot OneBotConfig `json:"onebot"`
WeCom WeComConfig `json:"wecom"`
@@ -274,15 +285,16 @@ type TelegramConfig struct {
}
type FeishuConfig struct {
Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_FEISHU_ENABLED"`
AppID string `json:"app_id" env:"PICOCLAW_CHANNELS_FEISHU_APP_ID"`
AppSecret string `json:"app_secret" env:"PICOCLAW_CHANNELS_FEISHU_APP_SECRET"`
EncryptKey string `json:"encrypt_key" env:"PICOCLAW_CHANNELS_FEISHU_ENCRYPT_KEY"`
VerificationToken string `json:"verification_token" env:"PICOCLAW_CHANNELS_FEISHU_VERIFICATION_TOKEN"`
AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_FEISHU_ALLOW_FROM"`
GroupTrigger GroupTriggerConfig `json:"group_trigger,omitempty"`
Placeholder PlaceholderConfig `json:"placeholder,omitempty"`
ReasoningChannelID string `json:"reasoning_channel_id" env:"PICOCLAW_CHANNELS_FEISHU_REASONING_CHANNEL_ID"`
Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_FEISHU_ENABLED"`
AppID string `json:"app_id" env:"PICOCLAW_CHANNELS_FEISHU_APP_ID"`
AppSecret string `json:"app_secret" env:"PICOCLAW_CHANNELS_FEISHU_APP_SECRET"`
EncryptKey string `json:"encrypt_key" env:"PICOCLAW_CHANNELS_FEISHU_ENCRYPT_KEY"`
VerificationToken string `json:"verification_token" env:"PICOCLAW_CHANNELS_FEISHU_VERIFICATION_TOKEN"`
AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_FEISHU_ALLOW_FROM"`
GroupTrigger GroupTriggerConfig `json:"group_trigger,omitempty"`
Placeholder PlaceholderConfig `json:"placeholder,omitempty"`
ReasoningChannelID string `json:"reasoning_channel_id" env:"PICOCLAW_CHANNELS_FEISHU_REASONING_CHANNEL_ID"`
RandomReactionEmoji FlexibleStringSlice `json:"random_reaction_emoji" env:"PICOCLAW_CHANNELS_FEISHU_RANDOM_REACTION_EMOJI"`
}
type DiscordConfig struct {
@@ -311,6 +323,8 @@ type QQConfig struct {
AppSecret string `json:"app_secret" env:"PICOCLAW_CHANNELS_QQ_APP_SECRET"`
AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_QQ_ALLOW_FROM"`
GroupTrigger GroupTriggerConfig `json:"group_trigger,omitempty"`
MaxMessageLength int `json:"max_message_length" env:"PICOCLAW_CHANNELS_QQ_MAX_MESSAGE_LENGTH"`
SendMarkdown bool `json:"send_markdown" env:"PICOCLAW_CHANNELS_QQ_SEND_MARKDOWN"`
ReasoningChannelID string `json:"reasoning_channel_id" env:"PICOCLAW_CHANNELS_QQ_REASONING_CHANNEL_ID"`
}
@@ -334,6 +348,19 @@ type SlackConfig struct {
ReasoningChannelID string `json:"reasoning_channel_id" env:"PICOCLAW_CHANNELS_SLACK_REASONING_CHANNEL_ID"`
}
// MatrixConfig holds the settings for the Matrix channel. Scalar fields map to
// the JSON keys and PICOCLAW_CHANNELS_MATRIX_* environment variables named in
// their tags; GroupTrigger and Placeholder carry JSON tags only.
type MatrixConfig struct {
Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_MATRIX_ENABLED"`
Homeserver string `json:"homeserver" env:"PICOCLAW_CHANNELS_MATRIX_HOMESERVER"`
UserID string `json:"user_id" env:"PICOCLAW_CHANNELS_MATRIX_USER_ID"`
AccessToken string `json:"access_token" env:"PICOCLAW_CHANNELS_MATRIX_ACCESS_TOKEN"`
DeviceID string `json:"device_id,omitempty" env:"PICOCLAW_CHANNELS_MATRIX_DEVICE_ID"`
JoinOnInvite bool `json:"join_on_invite" env:"PICOCLAW_CHANNELS_MATRIX_JOIN_ON_INVITE"`
AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_MATRIX_ALLOW_FROM"`
GroupTrigger GroupTriggerConfig `json:"group_trigger,omitempty"`
Placeholder PlaceholderConfig `json:"placeholder,omitempty"`
ReasoningChannelID string `json:"reasoning_channel_id" env:"PICOCLAW_CHANNELS_MATRIX_REASONING_CHANNEL_ID"`
}
type LINEConfig struct {
Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_LINE_ENABLED"`
ChannelSecret string `json:"channel_secret" env:"PICOCLAW_CHANNELS_LINE_CHANNEL_SECRET"`
@@ -445,10 +472,6 @@ type DevicesConfig struct {
MonitorUSB bool `json:"monitor_usb" env:"PICOCLAW_DEVICES_MONITOR_USB"`
}
type VoiceConfig struct {
EchoTranscription bool `json:"echo_transcription" env:"PICOCLAW_VOICE_ECHO_TRANSCRIPTION"`
}
type ProvidersConfig struct {
Anthropic ProviderConfig `json:"anthropic"`
OpenAI OpenAIProviderConfig `json:"openai"`
@@ -464,12 +487,14 @@ type ProvidersConfig struct {
ShengSuanYun ProviderConfig `json:"shengsuanyun"`
DeepSeek ProviderConfig `json:"deepseek"`
Cerebras ProviderConfig `json:"cerebras"`
Vivgrid ProviderConfig `json:"vivgrid"`
VolcEngine ProviderConfig `json:"volcengine"`
GitHubCopilot ProviderConfig `json:"github_copilot"`
Antigravity ProviderConfig `json:"antigravity"`
Qwen ProviderConfig `json:"qwen"`
Mistral ProviderConfig `json:"mistral"`
Avian ProviderConfig `json:"avian"`
Minimax ProviderConfig `json:"minimax"`
}
// IsEmpty checks if all provider configs are empty (no API keys or API bases set)
@@ -489,12 +514,14 @@ func (p ProvidersConfig) IsEmpty() bool {
p.ShengSuanYun.APIKey == "" && p.ShengSuanYun.APIBase == "" &&
p.DeepSeek.APIKey == "" && p.DeepSeek.APIBase == "" &&
p.Cerebras.APIKey == "" && p.Cerebras.APIBase == "" &&
p.Vivgrid.APIKey == "" && p.Vivgrid.APIBase == "" &&
p.VolcEngine.APIKey == "" && p.VolcEngine.APIBase == "" &&
p.GitHubCopilot.APIKey == "" && p.GitHubCopilot.APIBase == "" &&
p.Antigravity.APIKey == "" && p.Antigravity.APIBase == "" &&
p.Qwen.APIKey == "" && p.Qwen.APIBase == "" &&
p.Mistral.APIKey == "" && p.Mistral.APIBase == "" &&
p.Avian.APIKey == "" && p.Avian.APIBase == ""
p.Avian.APIKey == "" && p.Avian.APIBase == "" &&
p.Minimax.APIKey == "" && p.Minimax.APIBase == ""
}
// MarshalJSON implements custom JSON marshaling for ProvidersConfig
@@ -564,21 +591,31 @@ type GatewayConfig struct {
Port int `json:"port" env:"PICOCLAW_GATEWAY_PORT"`
}
// ToolDiscoveryConfig holds the tool-discovery settings: an on/off switch, a
// TTL, a cap on search results, and flags selecting BM25 and/or regex search.
// NOTE(review): the unit of TTL is not visible here — confirm with the consumer.
// NOTE(review): MaxSearchResults reads PICOCLAW_MAX_SEARCH_RESULTS, which lacks
// the PICOCLAW_TOOLS_DISCOVERY_ prefix used by the sibling fields — confirm
// whether that name is intentional.
type ToolDiscoveryConfig struct {
Enabled bool `json:"enabled" env:"PICOCLAW_TOOLS_DISCOVERY_ENABLED"`
TTL int `json:"ttl" env:"PICOCLAW_TOOLS_DISCOVERY_TTL"`
MaxSearchResults int `json:"max_search_results" env:"PICOCLAW_MAX_SEARCH_RESULTS"`
UseBM25 bool `json:"use_bm25" env:"PICOCLAW_TOOLS_DISCOVERY_USE_BM25"`
UseRegex bool `json:"use_regex" env:"PICOCLAW_TOOLS_DISCOVERY_USE_REGEX"`
}
// ToolConfig is the minimal per-tool setting: a single enabled flag. Its env
// tag is the bare suffix "ENABLED"; the full variable name comes from the
// envPrefix declared on the field that embeds or holds this struct.
type ToolConfig struct {
Enabled bool `json:"enabled" env:"ENABLED"`
}
type BraveConfig struct {
Enabled bool `json:"enabled" env:"PICOCLAW_TOOLS_WEB_BRAVE_ENABLED"`
APIKey string `json:"api_key" env:"PICOCLAW_TOOLS_WEB_BRAVE_API_KEY"`
MaxResults int `json:"max_results" env:"PICOCLAW_TOOLS_WEB_BRAVE_MAX_RESULTS"`
Enabled bool `json:"enabled" env:"PICOCLAW_TOOLS_WEB_BRAVE_ENABLED"`
APIKey string `json:"api_key" env:"PICOCLAW_TOOLS_WEB_BRAVE_API_KEY"`
APIKeys []string `json:"api_keys" env:"PICOCLAW_TOOLS_WEB_BRAVE_API_KEYS"`
MaxResults int `json:"max_results" env:"PICOCLAW_TOOLS_WEB_BRAVE_MAX_RESULTS"`
}
type TavilyConfig struct {
Enabled bool `json:"enabled" env:"PICOCLAW_TOOLS_WEB_TAVILY_ENABLED"`
APIKey string `json:"api_key" env:"PICOCLAW_TOOLS_WEB_TAVILY_API_KEY"`
BaseURL string `json:"base_url" env:"PICOCLAW_TOOLS_WEB_TAVILY_BASE_URL"`
MaxResults int `json:"max_results" env:"PICOCLAW_TOOLS_WEB_TAVILY_MAX_RESULTS"`
Enabled bool `json:"enabled" env:"PICOCLAW_TOOLS_WEB_TAVILY_ENABLED"`
APIKey string `json:"api_key" env:"PICOCLAW_TOOLS_WEB_TAVILY_API_KEY"`
APIKeys []string `json:"api_keys" env:"PICOCLAW_TOOLS_WEB_TAVILY_API_KEYS"`
BaseURL string `json:"base_url" env:"PICOCLAW_TOOLS_WEB_TAVILY_BASE_URL"`
MaxResults int `json:"max_results" env:"PICOCLAW_TOOLS_WEB_TAVILY_MAX_RESULTS"`
}
type DuckDuckGoConfig struct {
@@ -587,9 +624,10 @@ type DuckDuckGoConfig struct {
}
type PerplexityConfig struct {
Enabled bool `json:"enabled" env:"PICOCLAW_TOOLS_WEB_PERPLEXITY_ENABLED"`
APIKey string `json:"api_key" env:"PICOCLAW_TOOLS_WEB_PERPLEXITY_API_KEY"`
MaxResults int `json:"max_results" env:"PICOCLAW_TOOLS_WEB_PERPLEXITY_MAX_RESULTS"`
Enabled bool `json:"enabled" env:"PICOCLAW_TOOLS_WEB_PERPLEXITY_ENABLED"`
APIKey string `json:"api_key" env:"PICOCLAW_TOOLS_WEB_PERPLEXITY_API_KEY"`
APIKeys []string `json:"api_keys" env:"PICOCLAW_TOOLS_WEB_PERPLEXITY_API_KEYS"`
MaxResults int `json:"max_results" env:"PICOCLAW_TOOLS_WEB_PERPLEXITY_MAX_RESULTS"`
}
type SearXNGConfig struct {
@@ -648,6 +686,11 @@ type MediaCleanupConfig struct {
Interval int ` env:"PICOCLAW_MEDIA_CLEANUP_INTERVAL" json:"interval_minutes"`
}
// ReadFileToolConfig configures the read_file tool: whether it is enabled and
// the maximum file size it may read.
// NOTE(review): the unit of MaxReadFileSize is not visible here — confirm with
// the consumer.
//
// The env tags are bare suffixes, following the ToolConfig pattern; the full
// variable names come from the envPrefix on the field holding this struct
// (e.g. PICOCLAW_TOOLS_READ_FILE_ENABLED). Without them the prefix would have
// nothing to apply to and the environment override that ToolConfig provided
// for read_file would silently stop working.
type ReadFileToolConfig struct {
	Enabled         bool `json:"enabled" env:"ENABLED"`
	MaxReadFileSize int  `json:"max_read_file_size" env:"MAX_READ_FILE_SIZE"`
}
type ToolsConfig struct {
AllowReadPaths []string `json:"allow_read_paths" env:"PICOCLAW_TOOLS_ALLOW_READ_PATHS"`
AllowWritePaths []string `json:"allow_write_paths" env:"PICOCLAW_TOOLS_ALLOW_WRITE_PATHS"`
@@ -664,7 +707,7 @@ type ToolsConfig struct {
InstallSkill ToolConfig `json:"install_skill" envPrefix:"PICOCLAW_TOOLS_INSTALL_SKILL_"`
ListDir ToolConfig `json:"list_dir" envPrefix:"PICOCLAW_TOOLS_LIST_DIR_"`
Message ToolConfig `json:"message" envPrefix:"PICOCLAW_TOOLS_MESSAGE_"`
ReadFile ToolConfig `json:"read_file" envPrefix:"PICOCLAW_TOOLS_READ_FILE_"`
ReadFile ReadFileToolConfig `json:"read_file" envPrefix:"PICOCLAW_TOOLS_READ_FILE_"`
SendFile ToolConfig `json:"send_file" envPrefix:"PICOCLAW_TOOLS_SEND_FILE_"`
Spawn ToolConfig `json:"spawn" envPrefix:"PICOCLAW_TOOLS_SPAWN_"`
SPI ToolConfig `json:"spi" envPrefix:"PICOCLAW_TOOLS_SPI_"`
@@ -716,7 +759,8 @@ type MCPServerConfig struct {
// MCPConfig defines configuration for all MCP servers
type MCPConfig struct {
ToolConfig `envPrefix:"PICOCLAW_TOOLS_MCP_"`
ToolConfig ` envPrefix:"PICOCLAW_TOOLS_MCP_"`
Discovery ToolDiscoveryConfig ` json:"discovery"`
// Servers is a map of server name to server configuration
Servers map[string]MCPServerConfig `json:"servers,omitempty"`
}
@@ -903,6 +947,29 @@ func (c *Config) ValidateModelList() error {
return nil
}
// MergeAPIKeys combines the single apiKey with the apiKeys list into one
// ordered, de-duplicated slice. Every entry is whitespace-trimmed; blank
// entries are dropped. apiKey (when non-blank) always comes first, followed by
// the remaining keys in their original order. The result is nil when no
// usable key is present.
func MergeAPIKeys(apiKey string, apiKeys []string) []string {
	var merged []string
	known := make(map[string]struct{})

	// add records one candidate key, skipping blanks and repeats.
	add := func(raw string) {
		key := strings.TrimSpace(raw)
		if key == "" {
			return
		}
		if _, dup := known[key]; dup {
			return
		}
		known[key] = struct{}{}
		merged = append(merged, key)
	}

	add(apiKey)
	for _, raw := range apiKeys {
		add(raw)
	}
	return merged
}
func (t *ToolsConfig) IsToolEnabled(name string) bool {
switch name {
case "web":
+4 -1
View File
@@ -283,6 +283,9 @@ func TestDefaultConfig_Channels(t *testing.T) {
if cfg.Channels.Slack.Enabled {
t.Error("Slack should be disabled by default")
}
if cfg.Channels.Matrix.Enabled {
t.Error("Matrix should be disabled by default")
}
}
// TestDefaultConfig_WebTools verifies web tools config
@@ -293,7 +296,7 @@ func TestDefaultConfig_WebTools(t *testing.T) {
if cfg.Tools.Web.Brave.MaxResults != 5 {
t.Error("Expected Brave MaxResults 5, got ", cfg.Tools.Web.Brave.MaxResults)
}
if cfg.Tools.Web.Brave.APIKey != "" {
if len(cfg.Tools.Web.Brave.APIKeys) != 0 {
t.Error("Brave API key should be empty by default")
}
if cfg.Tools.Web.DuckDuckGo.MaxResults != 5 {
+60 -8
View File
@@ -80,10 +80,11 @@ func DefaultConfig() *Config {
AllowFrom: FlexibleStringSlice{},
},
QQ: QQConfig{
Enabled: false,
AppID: "",
AppSecret: "",
AllowFrom: FlexibleStringSlice{},
Enabled: false,
AppID: "",
AppSecret: "",
AllowFrom: FlexibleStringSlice{},
MaxMessageLength: 2000,
},
DingTalk: DingTalkConfig{
Enabled: false,
@@ -97,6 +98,22 @@ func DefaultConfig() *Config {
AppToken: "",
AllowFrom: FlexibleStringSlice{},
},
Matrix: MatrixConfig{
Enabled: false,
Homeserver: "https://matrix.org",
UserID: "",
AccessToken: "",
DeviceID: "",
JoinOnInvite: true,
AllowFrom: FlexibleStringSlice{},
GroupTrigger: GroupTriggerConfig{
MentionOnly: true,
},
Placeholder: PlaceholderConfig{
Enabled: true,
Text: "Thinking... 💭",
},
},
LINE: LINEConfig{
Enabled: false,
ChannelSecret: "",
@@ -261,6 +278,14 @@ func DefaultConfig() *Config {
APIKey: "",
},
// Vivgrid - https://vivgrid.com
{
ModelName: "vivgrid-auto",
Model: "vivgrid/auto",
APIBase: "https://api.vivgrid.com/v1",
APIKey: "",
},
// Volcengine (火山引擎) - https://console.volcengine.com/ark
{
ModelName: "doubao-pro",
@@ -322,6 +347,14 @@ func DefaultConfig() *Config {
APIKey: "",
},
// Minimax - https://api.minimaxi.com/
{
ModelName: "MiniMax-M2.5",
Model: "minimax/MiniMax-M2.5",
APIBase: "https://api.minimaxi.com/v1",
APIKey: "",
},
// VLLM (local) - http://localhost:8000
{
ModelName: "local-model",
@@ -351,6 +384,13 @@ func DefaultConfig() *Config {
Brave: BraveConfig{
Enabled: false,
APIKey: "",
APIKeys: nil,
MaxResults: 5,
},
Tavily: TavilyConfig{
Enabled: false,
APIKey: "",
APIKeys: nil,
MaxResults: 5,
},
DuckDuckGo: DuckDuckGoConfig{
@@ -360,6 +400,7 @@ func DefaultConfig() *Config {
Perplexity: PerplexityConfig{
Enabled: false,
APIKey: "",
APIKeys: nil,
MaxResults: 5,
},
SearXNG: SearXNGConfig{
@@ -411,6 +452,13 @@ func DefaultConfig() *Config {
ToolConfig: ToolConfig{
Enabled: false,
},
Discovery: ToolDiscoveryConfig{
Enabled: false,
TTL: 5,
MaxSearchResults: 5,
UseBM25: true,
UseRegex: false,
},
Servers: map[string]MCPServerConfig{},
},
AppendFile: ToolConfig{
@@ -434,8 +482,9 @@ func DefaultConfig() *Config {
Message: ToolConfig{
Enabled: true,
},
ReadFile: ToolConfig{
Enabled: true,
ReadFile: ReadFileToolConfig{
Enabled: true,
MaxReadFileSize: 64 * 1024, // 64KB
},
Spawn: ToolConfig{
Enabled: true,
@@ -461,8 +510,11 @@ func DefaultConfig() *Config {
Enabled: false,
MonitorUSB: true,
},
Voice: VoiceConfig{
EchoTranscription: false,
BuildInfo: BuildInfo{
Version: Version,
GitCommit: GitCommit,
BuildTime: BuildTime,
GoVersion: GoVersion,
},
}
}
+17
View File
@@ -292,6 +292,23 @@ func ConvertProvidersToModelList(cfg *Config) []ModelConfig {
}, true
},
},
{
providerNames: []string{"vivgrid"},
protocol: "vivgrid",
buildConfig: func(p ProvidersConfig) (ModelConfig, bool) {
if p.Vivgrid.APIKey == "" && p.Vivgrid.APIBase == "" {
return ModelConfig{}, false
}
return ModelConfig{
ModelName: "vivgrid",
Model: "vivgrid/auto",
APIKey: p.Vivgrid.APIKey,
APIBase: p.Vivgrid.APIBase,
Proxy: p.Vivgrid.Proxy,
RequestTimeout: p.Vivgrid.RequestTimeout,
}, true
},
},
{
providerNames: []string{"volcengine", "doubao"},
protocol: "volcengine",
+5 -4
View File
@@ -155,7 +155,8 @@ func TestConvertProvidersToModelList_AllProviders(t *testing.T) {
ShengSuanYun: ProviderConfig{APIKey: "key11"},
DeepSeek: ProviderConfig{APIKey: "key12"},
Cerebras: ProviderConfig{APIKey: "key13"},
VolcEngine: ProviderConfig{APIKey: "key14"},
Vivgrid: ProviderConfig{APIKey: "key14"},
VolcEngine: ProviderConfig{APIKey: "key15"},
GitHubCopilot: ProviderConfig{ConnectMode: "grpc"},
Antigravity: ProviderConfig{AuthMethod: "oauth"},
Qwen: ProviderConfig{APIKey: "key17"},
@@ -166,9 +167,9 @@ func TestConvertProvidersToModelList_AllProviders(t *testing.T) {
result := ConvertProvidersToModelList(cfg)
// All 20 providers should be converted
if len(result) != 20 {
t.Errorf("len(result) = %d, want 20", len(result))
// All 21 providers should be converted
if len(result) != 21 {
t.Errorf("len(result) = %d, want 21", len(result))
}
}
+44
View File
@@ -0,0 +1,44 @@
package config
import (
"fmt"
"runtime"
)
// Build-time variables injected via ldflags during the build process.
// These are set by the Makefile or .goreleaser.yaml using the -X flag:
//
// -X github.com/sipeed/picoclaw/pkg/config.Version=<version>
// -X github.com/sipeed/picoclaw/pkg/config.GitCommit=<commit>
// -X github.com/sipeed/picoclaw/pkg/config.BuildTime=<timestamp>
// -X github.com/sipeed/picoclaw/pkg/config.GoVersion=<go-version>
//
// Only Version has a compiled-in default; the others are empty strings unless
// injected.
var (
Version = "dev" // Default value when not built with ldflags
GitCommit string // Git commit SHA (short); empty means FormatVersion omits the git suffix
BuildTime string // Build timestamp in RFC3339 format
GoVersion string // Go version used for building; FormatBuildInfo falls back to runtime.Version() when empty
)
// FormatVersion returns the version string with optional git commit
func FormatVersion() string {
v := Version
if GitCommit != "" {
v += fmt.Sprintf(" (git: %s)", GitCommit)
}
return v
}
// FormatBuildInfo returns build time and go version info
func FormatBuildInfo() (string, string) {
build := BuildTime
goVer := GoVersion
if goVer == "" {
goVer = runtime.Version()
}
return build, goVer
}
// GetVersion returns the bare Version string, without the git commit suffix
// that FormatVersion appends.
func GetVersion() string {
return Version
}
+92
View File
@@ -0,0 +1,92 @@
package config
import (
"runtime"
"testing"
"github.com/stretchr/testify/assert"
)
// TestFormatVersion_NoGitCommit checks that FormatVersion returns Version
// unchanged when GitCommit is empty.
func TestFormatVersion_NoGitCommit(t *testing.T) {
oldVersion, oldGit := Version, GitCommit
// Restore the package-level build variables once the test finishes.
t.Cleanup(func() { Version, GitCommit = oldVersion, oldGit })
Version = "1.2.3"
GitCommit = ""
assert.Equal(t, "1.2.3", FormatVersion())
}
// TestFormatVersion_WithGitCommit checks that FormatVersion appends
// " (git: <commit>)" when GitCommit is set.
func TestFormatVersion_WithGitCommit(t *testing.T) {
oldVersion, oldGit := Version, GitCommit
// Restore the package-level build variables once the test finishes.
t.Cleanup(func() { Version, GitCommit = oldVersion, oldGit })
Version = "1.2.3"
GitCommit = "abc123"
assert.Equal(t, "1.2.3 (git: abc123)", FormatVersion())
}
// TestFormatBuildInfo_UsesBuildTimeAndGoVersion_WhenSet checks that both
// values are returned verbatim when BuildTime and GoVersion are non-empty.
func TestFormatBuildInfo_UsesBuildTimeAndGoVersion_WhenSet(t *testing.T) {
oldBuildTime, oldGoVersion := BuildTime, GoVersion
// Restore the package-level build variables once the test finishes.
t.Cleanup(func() { BuildTime, GoVersion = oldBuildTime, oldGoVersion })
BuildTime = "2026-02-20T00:00:00Z"
GoVersion = "go1.23.0"
build, goVer := FormatBuildInfo()
assert.Equal(t, BuildTime, build)
assert.Equal(t, GoVersion, goVer)
}
// TestFormatBuildInfo_EmptyBuildTime_ReturnsEmptyBuild checks that an empty
// BuildTime is passed through as an empty string (no fallback is applied).
func TestFormatBuildInfo_EmptyBuildTime_ReturnsEmptyBuild(t *testing.T) {
oldBuildTime, oldGoVersion := BuildTime, GoVersion
// Restore the package-level build variables once the test finishes.
t.Cleanup(func() { BuildTime, GoVersion = oldBuildTime, oldGoVersion })
BuildTime = ""
GoVersion = "go1.23.0"
build, goVer := FormatBuildInfo()
assert.Empty(t, build)
assert.Equal(t, GoVersion, goVer)
}
// TestFormatBuildInfo_EmptyGoVersion_FallsBackToRuntimeVersion checks that an
// empty GoVersion is replaced by runtime.Version() while BuildTime is returned
// as-is.
func TestFormatBuildInfo_EmptyGoVersion_FallsBackToRuntimeVersion(t *testing.T) {
oldBuildTime, oldGoVersion := BuildTime, GoVersion
// Restore the package-level build variables once the test finishes.
t.Cleanup(func() { BuildTime, GoVersion = oldBuildTime, oldGoVersion })
BuildTime = "x"
GoVersion = ""
build, goVer := FormatBuildInfo()
assert.Equal(t, "x", build)
assert.Equal(t, runtime.Version(), goVer)
}
// TestGetVersion checks that GetVersion returns the current value of Version.
func TestGetVersion(t *testing.T) {
oldVersion := Version
// Restore Version once the test finishes.
t.Cleanup(func() { Version = oldVersion })
Version = "dev"
assert.Equal(t, "dev", GetVersion())
}
// TestGetVersion_Custom checks that GetVersion reflects a Version value other
// than the "dev" default.
func TestGetVersion_Custom(t *testing.T) {
oldVersion := Version
// Restore Version once the test finishes.
t.Cleanup(func() { Version = oldVersion })
Version = "v1.0.0"
assert.Equal(t, "v1.0.0", GetVersion())
}
// TestVersion_DefaultIsDev assigns "dev" to Version and asserts that Version
// equals "dev".
//
// NOTE(review): the assertion compares Version with the value written two
// statements earlier, so it cannot fail and does not observe the compiled-in
// default. Confirm whether the intent was to check the declaration's default.
func TestVersion_DefaultIsDev(t *testing.T) {
// Reset to default values
oldVersion := Version
Version = "dev"
t.Cleanup(func() { Version = oldVersion })
assert.Equal(t, "dev", Version)
}
+1
View File
@@ -22,6 +22,7 @@ var supportedChannels = map[string]bool{
"qq": true,
"dingtalk": true,
"slack": true,
"matrix": true,
"line": true,
"onebot": true,
"wecom": true,
+55 -13
View File
@@ -371,6 +371,8 @@ func (c *OpenClawConfig) IsChannelEnabled(name string) bool {
return c.Channels.Discord == nil || c.Channels.Discord.Enabled == nil || *c.Channels.Discord.Enabled
case "slack":
return c.Channels.Slack == nil || c.Channels.Slack.Enabled == nil || *c.Channels.Slack.Enabled
case "matrix":
return c.Channels.Matrix == nil || c.Channels.Matrix.Enabled == nil || *c.Channels.Matrix.Enabled
case "whatsapp":
return c.Channels.WhatsApp == nil || c.Channels.WhatsApp.Enabled == nil || *c.Channels.WhatsApp.Enabled
case "feishu":
@@ -397,6 +399,11 @@ func GetChannelAllowFrom(ch any) []string {
return nil
}
return c.AllowFrom
case *OpenClawMatrixConfig:
if c == nil {
return nil
}
return c.AllowFrom
case *OpenClawWhatsAppConfig:
if c == nil {
return nil
@@ -627,6 +634,7 @@ type ChannelsConfig struct {
QQ QQConfig `json:"qq"`
DingTalk DingTalkConfig `json:"dingtalk"`
Slack SlackConfig `json:"slack"`
Matrix MatrixConfig `json:"matrix"`
LINE LINEConfig `json:"line"`
}
@@ -687,6 +695,14 @@ type SlackConfig struct {
AllowFrom []string `json:"allow_from"`
}
// MatrixConfig is the Matrix channel section of the migrated configuration.
// convertChannels populates it from the OpenClaw source config, and
// ToStandardChannels copies each field into config.MatrixConfig.
type MatrixConfig struct {
Enabled bool `json:"enabled"`
Homeserver string `json:"homeserver"`
UserID string `json:"user_id"`
AccessToken string `json:"access_token"`
AllowFrom []string `json:"allow_from"`
}
type LINEConfig struct {
Enabled bool `json:"enabled"`
ChannelSecret string `json:"channel_secret"`
@@ -717,16 +733,18 @@ type WebToolsConfig struct {
}
type BraveConfig struct {
Enabled bool `json:"enabled"`
APIKey string `json:"api_key"`
MaxResults int `json:"max_results"`
Enabled bool `json:"enabled"`
APIKey string `json:"api_key"`
APIKeys []string `json:"api_keys"`
MaxResults int `json:"max_results"`
}
type TavilyConfig struct {
Enabled bool `json:"enabled"`
APIKey string `json:"api_key"`
BaseURL string `json:"base_url"`
MaxResults int `json:"max_results"`
Enabled bool `json:"enabled"`
APIKey string `json:"api_key"`
APIKeys []string `json:"api_keys"`
BaseURL string `json:"base_url"`
MaxResults int `json:"max_results"`
}
type DuckDuckGoConfig struct {
@@ -735,9 +753,10 @@ type DuckDuckGoConfig struct {
}
type PerplexityConfig struct {
Enabled bool `json:"enabled"`
APIKey string `json:"api_key"`
MaxResults int `json:"max_results"`
Enabled bool `json:"enabled"`
APIKey string `json:"api_key"`
APIKeys []string `json:"api_keys"`
MaxResults int `json:"max_results"`
}
type CronConfig struct {
@@ -862,12 +881,26 @@ func (c *OpenClawConfig) convertChannels(warnings *[]string) ChannelsConfig {
}
}
if c.Channels.Matrix != nil && supportedChannels["matrix"] {
enabled := c.Channels.Matrix.Enabled == nil || *c.Channels.Matrix.Enabled
channels.Matrix = MatrixConfig{
Enabled: enabled,
AllowFrom: c.Channels.Matrix.AllowFrom,
}
if c.Channels.Matrix.Homeserver != nil {
channels.Matrix.Homeserver = *c.Channels.Matrix.Homeserver
}
if c.Channels.Matrix.UserID != nil {
channels.Matrix.UserID = *c.Channels.Matrix.UserID
}
if c.Channels.Matrix.AccessToken != nil {
channels.Matrix.AccessToken = *c.Channels.Matrix.AccessToken
}
}
if c.Channels.Signal != nil {
*warnings = append(*warnings, "Channel 'signal': No PicoClaw adapter available")
}
if c.Channels.Matrix != nil {
*warnings = append(*warnings, "Channel 'matrix': No PicoClaw adapter available")
}
if c.Channels.IRC != nil {
*warnings = append(*warnings, "Channel 'irc': No PicoClaw adapter available")
}
@@ -1020,6 +1053,14 @@ func (c ChannelsConfig) ToStandardChannels() config.ChannelsConfig {
BotToken: c.Slack.BotToken,
AppToken: c.Slack.AppToken,
},
Matrix: config.MatrixConfig{
Enabled: c.Matrix.Enabled,
Homeserver: c.Matrix.Homeserver,
UserID: c.Matrix.UserID,
AccessToken: c.Matrix.AccessToken,
AllowFrom: c.Matrix.AllowFrom,
JoinOnInvite: true,
},
LINE: config.LINEConfig{
Enabled: c.LINE.Enabled,
ChannelSecret: c.LINE.ChannelSecret,
@@ -1044,6 +1085,7 @@ func (c ToolsConfig) ToStandardTools() config.ToolsConfig {
Brave: config.BraveConfig{
Enabled: c.Web.Brave.Enabled,
APIKey: c.Web.Brave.APIKey,
APIKeys: c.Web.Brave.APIKeys,
MaxResults: c.Web.Brave.MaxResults,
},
Tavily: config.TavilyConfig{
@@ -4,6 +4,7 @@ import (
"encoding/json"
"os"
"path/filepath"
"strings"
"testing"
)
@@ -375,6 +376,96 @@ func TestConvertToPicoClawWithQQAndDingTalk(t *testing.T) {
}
}
// TestConvertToPicoClawWithMatrix writes an OpenClaw config containing an
// enabled Matrix channel, converts it, and checks that every Matrix field is
// carried over and that no "Channel 'matrix'" warning is emitted.
func TestConvertToPicoClawWithMatrix(t *testing.T) {
tmpDir := t.TempDir()
configPath := filepath.Join(tmpDir, "openclaw.json")
testConfig := `{
"channels": {
"matrix": {
"enabled": true,
"homeserver": "https://matrix.example.com",
"userId": "@bot:matrix.example.com",
"accessToken": "syt_test_token",
"allowFrom": ["@alice:matrix.example.com"]
}
}
}`
err := os.WriteFile(configPath, []byte(testConfig), 0o644)
if err != nil {
t.Fatalf("failed to write test config: %v", err)
}
cfg, err := LoadOpenClawConfig(configPath)
if err != nil {
t.Fatalf("failed to load config: %v", err)
}
picoCfg, warnings, err := cfg.ConvertToPicoClaw("")
if err != nil {
t.Fatalf("failed to convert config: %v", err)
}
if !picoCfg.Channels.Matrix.Enabled {
t.Error("matrix should be enabled")
}
if picoCfg.Channels.Matrix.Homeserver != "https://matrix.example.com" {
t.Errorf("expected matrix homeserver, got %q", picoCfg.Channels.Matrix.Homeserver)
}
if picoCfg.Channels.Matrix.UserID != "@bot:matrix.example.com" {
t.Errorf("expected matrix user_id, got %q", picoCfg.Channels.Matrix.UserID)
}
if picoCfg.Channels.Matrix.AccessToken != "syt_test_token" {
t.Errorf("expected matrix access_token, got %q", picoCfg.Channels.Matrix.AccessToken)
}
if len(picoCfg.Channels.Matrix.AllowFrom) != 1 ||
picoCfg.Channels.Matrix.AllowFrom[0] != "@alice:matrix.example.com" {
t.Errorf("unexpected matrix allow_from: %#v", picoCfg.Channels.Matrix.AllowFrom)
}
// Matrix used to be reported as lacking an adapter; no such warning may remain.
for _, w := range warnings {
if strings.Contains(w, "Channel 'matrix'") {
t.Fatalf("matrix should no longer be reported as unsupported, warning=%q", w)
}
}
}
// TestConvertToPicoClawWithMatrixDisabled checks that an explicit
// "enabled": false on the source Matrix channel survives conversion.
func TestConvertToPicoClawWithMatrixDisabled(t *testing.T) {
tmpDir := t.TempDir()
configPath := filepath.Join(tmpDir, "openclaw.json")
testConfig := `{
"channels": {
"matrix": {
"enabled": false,
"homeserver": "https://matrix.example.com",
"userId": "@bot:matrix.example.com",
"accessToken": "syt_test_token"
}
}
}`
err := os.WriteFile(configPath, []byte(testConfig), 0o644)
if err != nil {
t.Fatalf("failed to write test config: %v", err)
}
cfg, err := LoadOpenClawConfig(configPath)
if err != nil {
t.Fatalf("failed to load config: %v", err)
}
picoCfg, _, err := cfg.ConvertToPicoClaw("")
if err != nil {
t.Fatalf("failed to convert config: %v", err)
}
if picoCfg.Channels.Matrix.Enabled {
t.Error("matrix should respect enabled=false from source config")
}
}
func TestOpenClawAgentModel(t *testing.T) {
model := &OpenClawAgentModel{
Primary: strPtr("anthropic/claude-3-opus"),
@@ -425,6 +516,9 @@ func TestChannelEnabled(t *testing.T) {
if !cfg.IsChannelEnabled("slack") {
t.Error("slack should be enabled (explicitly set)")
}
if !cfg.IsChannelEnabled("matrix") {
t.Error("matrix should be enabled (nil config defaults to enabled)")
}
if cfg.IsChannelEnabled("line") {
t.Error("line should return false (not in switch cases)")
}
+32
View File
@@ -153,6 +153,15 @@ func resolveProviderSelection(cfg *config.Config) (providerSelection, error) {
sel.apiBase = "https://integrate.api.nvidia.com/v1"
}
}
case "vivgrid":
if cfg.Providers.Vivgrid.APIKey != "" {
sel.apiKey = cfg.Providers.Vivgrid.APIKey
sel.apiBase = cfg.Providers.Vivgrid.APIBase
sel.proxy = cfg.Providers.Vivgrid.Proxy
if sel.apiBase == "" {
sel.apiBase = "https://api.vivgrid.com/v1"
}
}
case "claude-cli", "claude-code", "claudecode":
workspace := cfg.WorkspacePath()
if workspace == "" {
@@ -199,6 +208,15 @@ func resolveProviderSelection(cfg *config.Config) (providerSelection, error) {
sel.apiBase = "https://api.mistral.ai/v1"
}
}
case "minimax":
if cfg.Providers.Minimax.APIKey != "" {
sel.apiKey = cfg.Providers.Minimax.APIKey
sel.apiBase = cfg.Providers.Minimax.APIBase
sel.proxy = cfg.Providers.Minimax.Proxy
if sel.apiBase == "" {
sel.apiBase = "https://api.minimaxi.com/v1"
}
}
case "github_copilot", "copilot":
sel.providerType = providerTypeGitHubCopilot
if cfg.Providers.GitHubCopilot.APIBase != "" {
@@ -295,6 +313,13 @@ func resolveProviderSelection(cfg *config.Config) (providerSelection, error) {
if sel.apiBase == "" {
sel.apiBase = "https://integrate.api.nvidia.com/v1"
}
case strings.HasPrefix(model, "vivgrid/") && cfg.Providers.Vivgrid.APIKey != "":
sel.apiKey = cfg.Providers.Vivgrid.APIKey
sel.apiBase = cfg.Providers.Vivgrid.APIBase
sel.proxy = cfg.Providers.Vivgrid.Proxy
if sel.apiBase == "" {
sel.apiBase = "https://api.vivgrid.com/v1"
}
case (strings.Contains(lowerModel, "ollama") || strings.HasPrefix(model, "ollama/")) && cfg.Providers.Ollama.APIKey != "":
sel.apiKey = cfg.Providers.Ollama.APIKey
sel.apiBase = cfg.Providers.Ollama.APIBase
@@ -309,6 +334,13 @@ func resolveProviderSelection(cfg *config.Config) (providerSelection, error) {
if sel.apiBase == "" {
sel.apiBase = "https://api.mistral.ai/v1"
}
case (strings.Contains(lowerModel, "minimax") || strings.HasPrefix(model, "minimax/")) && cfg.Providers.Minimax.APIKey != "":
sel.apiKey = cfg.Providers.Minimax.APIKey
sel.apiBase = cfg.Providers.Minimax.APIBase
sel.proxy = cfg.Providers.Minimax.Proxy
if sel.apiBase == "" {
sel.apiBase = "https://api.minimaxi.com/v1"
}
case strings.HasPrefix(model, "avian/") && cfg.Providers.Avian.APIKey != "":
sel.apiKey = cfg.Providers.Avian.APIKey
sel.apiBase = cfg.Providers.Avian.APIBase
+6 -1
View File
@@ -94,7 +94,8 @@ func CreateProviderFromConfig(cfg *config.ModelConfig) (LLMProvider, string, err
case "litellm", "openrouter", "groq", "zhipu", "gemini", "nvidia",
"ollama", "moonshot", "shengsuanyun", "deepseek", "cerebras",
"volcengine", "vllm", "qwen", "mistral", "avian":
"vivgrid", "volcengine", "vllm", "qwen", "mistral", "avian",
"minimax":
// All other OpenAI-compatible HTTP providers
if cfg.APIKey == "" && cfg.APIBase == "" {
return nil, "", fmt.Errorf("api_key or api_base is required for HTTP-based protocol %q", protocol)
@@ -200,6 +201,8 @@ func getDefaultAPIBase(protocol string) string {
return "https://api.deepseek.com/v1"
case "cerebras":
return "https://api.cerebras.ai/v1"
case "vivgrid":
return "https://api.vivgrid.com/v1"
case "volcengine":
return "https://ark.cn-beijing.volces.com/api/v3"
case "qwen":
@@ -210,6 +213,8 @@ func getDefaultAPIBase(protocol string) string {
return "https://api.mistral.ai/v1"
case "avian":
return "https://api.avian.io/v1"
case "minimax":
return "https://api.minimaxi.com/v1"
default:
return ""
}
+1
View File
@@ -108,6 +108,7 @@ func TestCreateProviderFromConfig_DefaultAPIBase(t *testing.T) {
{"groq", "groq"},
{"openrouter", "openrouter"},
{"cerebras", "cerebras"},
{"vivgrid", "vivgrid"},
{"qwen", "qwen"},
{"vllm", "vllm"},
{"deepseek", "deepseek"},
+11
View File
@@ -88,6 +88,17 @@ func TestResolveProviderSelection(t *testing.T) {
wantAPIBase: "https://integrate.api.nvidia.com/v1",
wantProxy: "http://127.0.0.1:7890",
},
{
name: "explicit vivgrid provider uses defaults",
setup: func(cfg *config.Config) {
cfg.Agents.Defaults.Provider = "vivgrid"
cfg.Providers.Vivgrid.APIKey = "vivgrid-key"
cfg.Providers.Vivgrid.Proxy = "http://127.0.0.1:7890"
},
wantType: providerTypeHTTPCompat,
wantAPIBase: "https://api.vivgrid.com/v1",
wantProxy: "http://127.0.0.1:7890",
},
{
name: "openrouter model uses openrouter defaults",
setup: func(cfg *config.Config) {
+2 -1
View File
@@ -439,7 +439,8 @@ func normalizeModel(model, apiBase string) string {
prefix := strings.ToLower(before)
switch prefix {
case "litellm", "moonshot", "nvidia", "groq", "ollama", "deepseek", "google", "openrouter", "zhipu", "mistral":
case "litellm", "moonshot", "nvidia", "groq", "ollama", "deepseek", "google",
"openrouter", "zhipu", "mistral", "vivgrid", "minimax":
return after
default:
return model
+12 -1
View File
@@ -382,7 +382,7 @@ func TestProviderChat_StripsMoonshotPrefixAndNormalizesKimiTemperature(t *testin
}
}
func TestProviderChat_StripsGroqAndOllamaPrefixes(t *testing.T) {
func TestProviderChat_StripsGroqOllamaDeepseekVivgridPrefixes(t *testing.T) {
tests := []struct {
name string
input string
@@ -408,6 +408,11 @@ func TestProviderChat_StripsGroqAndOllamaPrefixes(t *testing.T) {
input: "deepseek/deepseek-chat",
wantModel: "deepseek-chat",
},
{
name: "strips vivgrid prefix",
input: "vivgrid/auto",
wantModel: "auto",
},
}
for _, tt := range tests {
@@ -512,6 +517,12 @@ func TestNormalizeModel_UsesAPIBase(t *testing.T) {
if got := normalizeModel("openrouter/auto", "https://openrouter.ai/api/v1"); got != "openrouter/auto" {
t.Fatalf("normalizeModel(openrouter) = %q, want %q", got, "openrouter/auto")
}
if got := normalizeModel("vivgrid/managed", "https://api.vivgrid.com/v1"); got != "managed" {
t.Fatalf("normalizeModel(vivgrid) = %q, want %q", got, "managed")
}
if got := normalizeModel("vivgrid/auto", "https://api.vivgrid.com/v1"); got != "auto" {
t.Fatalf("normalizeModel(vivgrid auto) = %q, want %q", got, "auto")
}
}
func TestProvider_RequestTimeoutDefault(t *testing.T) {
+81
View File
@@ -0,0 +1,81 @@
package session
import (
"context"
"log"
"github.com/sipeed/picoclaw/pkg/memory"
"github.com/sipeed/picoclaw/pkg/providers"
)
// JSONLBackend adapts a memory.Store into the SessionStore interface.
// Write errors are logged rather than returned, matching the fire-and-forget
// contract of SessionManager that the agent loop relies on.
// Every call into the store is made with context.Background().
type JSONLBackend struct {
// store is the underlying store that every method delegates to.
store memory.Store
}
// NewJSONLBackend returns a JSONLBackend that delegates every operation to
// the given memory.Store.
func NewJSONLBackend(store memory.Store) *JSONLBackend {
	backend := new(JSONLBackend)
	backend.store = store
	return backend
}
// AddMessage appends a role/content message to the session. A store failure
// is logged and otherwise ignored.
func (b *JSONLBackend) AddMessage(sessionKey, role, content string) {
	err := b.store.AddMessage(context.Background(), sessionKey, role, content)
	if err == nil {
		return
	}
	log.Printf("session: add message: %v", err)
}
// AddFullMessage appends a complete providers.Message to the session. A store
// failure is logged and otherwise ignored.
func (b *JSONLBackend) AddFullMessage(sessionKey string, msg providers.Message) {
	err := b.store.AddFullMessage(context.Background(), sessionKey, msg)
	if err == nil {
		return
	}
	log.Printf("session: add full message: %v", err)
}
// GetHistory returns the stored messages for key. If the store reports an
// error, the error is logged and an empty, non-nil slice is returned.
func (b *JSONLBackend) GetHistory(key string) []providers.Message {
	history, err := b.store.GetHistory(context.Background(), key)
	if err == nil {
		return history
	}
	log.Printf("session: get history: %v", err)
	return []providers.Message{}
}
// GetSummary returns the stored summary for key. If the store reports an
// error, the error is logged and "" is returned.
func (b *JSONLBackend) GetSummary(key string) string {
	text, err := b.store.GetSummary(context.Background(), key)
	if err == nil {
		return text
	}
	log.Printf("session: get summary: %v", err)
	return ""
}
// SetSummary replaces the summary stored for key. A store failure is logged
// and otherwise ignored.
func (b *JSONLBackend) SetSummary(key, summary string) {
	err := b.store.SetSummary(context.Background(), key, summary)
	if err == nil {
		return
	}
	log.Printf("session: set summary: %v", err)
}
// SetHistory replaces the message history stored for key. A store failure is
// logged and otherwise ignored.
func (b *JSONLBackend) SetHistory(key string, history []providers.Message) {
	err := b.store.SetHistory(context.Background(), key, history)
	if err == nil {
		return
	}
	log.Printf("session: set history: %v", err)
}
// TruncateHistory asks the store to keep only the last keepLast messages of
// the session. A store failure is logged and otherwise ignored.
func (b *JSONLBackend) TruncateHistory(key string, keepLast int) {
	err := b.store.TruncateHistory(context.Background(), key, keepLast)
	if err == nil {
		return
	}
	log.Printf("session: truncate history: %v", err)
}
// Save satisfies the SessionStore persistence hook. The JSONL store fsyncs
// each write as it happens, so nothing is pending here; Save instead runs the
// store's compaction for key, reclaiming space left by logically truncated
// messages (a no-op when there are none). The compaction error, if any, is
// returned to the caller.
func (b *JSONLBackend) Save(key string) error {
	ctx := context.Background()
	return b.store.Compact(ctx, key)
}
// Close releases resources held by the underlying store by delegating to its
// Close method, and returns that method's error.
func (b *JSONLBackend) Close() error {
return b.store.Close()
}
+179
View File
@@ -0,0 +1,179 @@
package session_test
import (
"fmt"
"testing"
"github.com/sipeed/picoclaw/pkg/memory"
"github.com/sipeed/picoclaw/pkg/providers"
"github.com/sipeed/picoclaw/pkg/session"
)
// Compile-time interface satisfaction checks.
var (
_ session.SessionStore = (*session.SessionManager)(nil)
_ session.SessionStore = (*session.JSONLBackend)(nil)
)
// newBackend returns a JSONLBackend over a JSONL store rooted in a fresh
// per-test temp directory. The store is closed when the test finishes.
func newBackend(t *testing.T) *session.JSONLBackend {
t.Helper()
store, err := memory.NewJSONLStore(t.TempDir())
if err != nil {
t.Fatal(err)
}
t.Cleanup(func() { store.Close() })
return session.NewJSONLBackend(store)
}
// TestJSONLBackend_AddAndGetHistory checks that messages added with
// AddMessage come back from GetHistory in insertion order with their role and
// content intact.
func TestJSONLBackend_AddAndGetHistory(t *testing.T) {
b := newBackend(t)
b.AddMessage("s1", "user", "hello")
b.AddMessage("s1", "assistant", "hi")
history := b.GetHistory("s1")
if len(history) != 2 {
t.Fatalf("got %d messages, want 2", len(history))
}
if history[0].Role != "user" || history[0].Content != "hello" {
t.Errorf("msg[0] = %+v", history[0])
}
if history[1].Role != "assistant" || history[1].Content != "hi" {
t.Errorf("msg[1] = %+v", history[1])
}
}
// TestJSONLBackend_AddFullMessage checks that a message carrying tool calls
// round-trips through AddFullMessage and GetHistory with the tool call ID
// preserved.
func TestJSONLBackend_AddFullMessage(t *testing.T) {
b := newBackend(t)
msg := providers.Message{
Role: "assistant",
Content: "done",
ToolCalls: []providers.ToolCall{
{ID: "tc1", Function: &providers.FunctionCall{Name: "read_file", Arguments: `{"path":"x"}`}},
},
}
b.AddFullMessage("s1", msg)
history := b.GetHistory("s1")
if len(history) != 1 {
t.Fatalf("got %d, want 1", len(history))
}
if len(history[0].ToolCalls) != 1 || history[0].ToolCalls[0].ID != "tc1" {
t.Errorf("tool calls = %+v", history[0].ToolCalls)
}
}
// TestJSONLBackend_Summary checks that GetSummary is empty for a session with
// no summary and returns the value written by SetSummary afterwards.
func TestJSONLBackend_Summary(t *testing.T) {
b := newBackend(t)
if got := b.GetSummary("s1"); got != "" {
t.Errorf("got %q, want empty", got)
}
b.SetSummary("s1", "test summary")
if got := b.GetSummary("s1"); got != "test summary" {
t.Errorf("got %q, want %q", got, "test summary")
}
}
// TestJSONLBackend_TruncateAndSave adds ten messages, truncates to the last
// three, and checks that exactly those three (starting at "msg 7") are
// visible both before and after Save.
func TestJSONLBackend_TruncateAndSave(t *testing.T) {
b := newBackend(t)
for i := 0; i < 10; i++ {
b.AddMessage("s1", "user", fmt.Sprintf("msg %d", i))
}
b.TruncateHistory("s1", 3)
history := b.GetHistory("s1")
if len(history) != 3 {
t.Fatalf("got %d, want 3", len(history))
}
if history[0].Content != "msg 7" {
t.Errorf("got %q, want %q", history[0].Content, "msg 7")
}
// Save triggers compaction.
if err := b.Save("s1"); err != nil {
t.Fatal(err)
}
// Messages still accessible after compaction.
history = b.GetHistory("s1")
if len(history) != 3 {
t.Fatalf("after save: got %d, want 3", len(history))
}
}
// TestJSONLBackend_SetHistory checks that SetHistory replaces, rather than
// appends to, the existing messages of a session. Only the length and the
// first message's content are asserted.
func TestJSONLBackend_SetHistory(t *testing.T) {
b := newBackend(t)
b.AddMessage("s1", "user", "old")
b.SetHistory("s1", []providers.Message{
{Role: "user", Content: "new1"},
{Role: "assistant", Content: "new2"},
})
history := b.GetHistory("s1")
if len(history) != 2 {
t.Fatalf("got %d, want 2", len(history))
}
if history[0].Content != "new1" {
t.Errorf("got %q, want %q", history[0].Content, "new1")
}
}
// TestJSONLBackend_EmptySession checks that GetHistory for an unknown session
// returns an empty, non-nil slice.
func TestJSONLBackend_EmptySession(t *testing.T) {
b := newBackend(t)
history := b.GetHistory("nonexistent")
if history == nil {
t.Fatal("got nil, want empty slice")
}
if len(history) != 0 {
t.Errorf("got %d, want 0", len(history))
}
}
// TestJSONLBackend_SessionIsolation checks that messages written under one
// session key are not visible under another.
func TestJSONLBackend_SessionIsolation(t *testing.T) {
b := newBackend(t)
b.AddMessage("s1", "user", "session1")
b.AddMessage("s2", "user", "session2")
h1 := b.GetHistory("s1")
h2 := b.GetHistory("s2")
if len(h1) != 1 || h1[0].Content != "session1" {
t.Errorf("s1: %+v", h1)
}
if len(h2) != 1 || h2[0].Content != "session2" {
t.Errorf("s2: %+v", h2)
}
}
// TestJSONLBackend_SummarizeFlow runs SetSummary, TruncateHistory and Save in
// sequence over twenty messages and checks that the summary and the last four
// messages (starting at "msg 16") survive.
func TestJSONLBackend_SummarizeFlow(t *testing.T) {
// Simulates the real summarization flow in the agent loop:
// SetSummary → TruncateHistory → Save
b := newBackend(t)
for i := 0; i < 20; i++ {
b.AddMessage("s1", "user", fmt.Sprintf("msg %d", i))
}
b.SetSummary("s1", "conversation about testing")
b.TruncateHistory("s1", 4)
if err := b.Save("s1"); err != nil {
t.Fatal(err)
}
if got := b.GetSummary("s1"); got != "conversation about testing" {
t.Errorf("summary = %q", got)
}
history := b.GetHistory("s1")
if len(history) != 4 {
t.Fatalf("got %d messages, want 4", len(history))
}
if history[0].Content != "msg 16" {
t.Errorf("first message = %q, want %q", history[0].Content, "msg 16")
}
}
+6
View File
@@ -265,6 +265,12 @@ func (sm *SessionManager) loadSessions() error {
return nil
}
// Close is a no-op for the in-memory SessionManager; it satisfies the
// SessionStore interface so callers can release resources uniformly.
// It always returns nil.
func (sm *SessionManager) Close() error {
return nil
}
// SetHistory updates the messages of a session.
func (sm *SessionManager) SetHistory(key string, history []providers.Message) {
sm.mu.Lock()
+32
View File
@@ -0,0 +1,32 @@
package session
import "github.com/sipeed/picoclaw/pkg/providers"
// SessionStore defines the persistence operations used by the agent loop.
// Both SessionManager (legacy JSON backend) and JSONLBackend satisfy this
// interface, allowing the storage layer to be swapped without touching the
// agent loop code.
//
// Write methods (Add*, Set*, Truncate*) are fire-and-forget: they do not
// return errors. Implementations should log failures internally. This
// matches the original SessionManager contract that the agent loop relies on.
// Only Save and Close report errors to the caller.
type SessionStore interface {
// AddMessage appends a simple role/content message to the session.
AddMessage(sessionKey, role, content string)
// AddFullMessage appends a complete message including tool calls.
AddFullMessage(sessionKey string, msg providers.Message)
// GetHistory returns the full message history for the session.
// It has no error return, so implementations must handle read failures
// themselves.
GetHistory(key string) []providers.Message
// GetSummary returns the conversation summary, or "" if none.
GetSummary(key string) string
// SetSummary replaces the conversation summary.
SetSummary(key, summary string)
// SetHistory replaces the full message history.
SetHistory(key string, history []providers.Message)
// TruncateHistory keeps only the last keepLast messages.
TruncateHistory(key string, keepLast int)
// Save persists any pending state to durable storage.
Save(key string) error
// Close releases resources held by the store.
Close() error
}
+241 -7
View File
@@ -2,17 +2,24 @@ package tools
import (
"context"
"errors"
"fmt"
"io"
"io/fs"
"math"
"os"
"path/filepath"
"regexp"
"strconv"
"strings"
"time"
"github.com/sipeed/picoclaw/pkg/fileutil"
"github.com/sipeed/picoclaw/pkg/logger"
)
const MaxReadFileSize = 64 * 1024 // 64KB limit to avoid context overflow
// validatePath ensures the given path is within the workspace if restrict is true.
func validatePath(path, workspace string, restrict bool) (string, error) {
if workspace == "" {
@@ -85,15 +92,30 @@ func isWithinWorkspace(candidate, workspace string) bool {
}
type ReadFileTool struct {
fs fileSystem
fs fileSystem
maxSize int64
}
func NewReadFileTool(workspace string, restrict bool, allowPaths ...[]*regexp.Regexp) *ReadFileTool {
func NewReadFileTool(
workspace string,
restrict bool,
maxReadFileSize int,
allowPaths ...[]*regexp.Regexp,
) *ReadFileTool {
var patterns []*regexp.Regexp
if len(allowPaths) > 0 {
patterns = allowPaths[0]
}
return &ReadFileTool{fs: buildFs(workspace, restrict, patterns)}
maxSize := int64(maxReadFileSize)
if maxSize <= 0 {
maxSize = MaxReadFileSize
}
return &ReadFileTool{
fs: buildFs(workspace, restrict, patterns),
maxSize: maxSize,
}
}
func (t *ReadFileTool) Name() string {
@@ -101,7 +123,7 @@ func (t *ReadFileTool) Name() string {
}
func (t *ReadFileTool) Description() string {
return "Read the contents of a file"
return "Read the contents of a file. Supports pagination via `offset` and `length`."
}
func (t *ReadFileTool) Parameters() map[string]any {
@@ -110,7 +132,17 @@ func (t *ReadFileTool) Parameters() map[string]any {
"properties": map[string]any{
"path": map[string]any{
"type": "string",
"description": "Path to the file to read",
"description": "Path to the file to read.",
},
"offset": map[string]any{
"type": "integer",
"description": "Byte offset to start reading from.",
"default": 0,
},
"length": map[string]any{
"type": "integer",
"description": "Maximum number of bytes to read.",
"default": t.maxSize,
},
},
"required": []string{"path"},
@@ -123,11 +155,171 @@ func (t *ReadFileTool) Execute(ctx context.Context, args map[string]any) *ToolRe
return ErrorResult("path is required")
}
content, err := t.fs.ReadFile(path)
// offset (optional, default 0)
offset, err := getInt64Arg(args, "offset", 0)
if err != nil {
return ErrorResult(err.Error())
}
return NewToolResult(string(content))
if offset < 0 {
return ErrorResult("offset must be >= 0")
}
// length (optional, capped at MaxReadFileSize)
length, err := getInt64Arg(args, "length", t.maxSize)
if err != nil {
return ErrorResult(err.Error())
}
if length <= 0 {
return ErrorResult("length must be > 0")
}
if length > t.maxSize {
length = t.maxSize
}
file, err := t.fs.Open(path)
if err != nil {
return ErrorResult(err.Error())
}
defer file.Close()
// measure total size
totalSize := int64(-1) // -1 means unknown
if info, statErr := file.Stat(); statErr == nil {
totalSize = info.Size()
}
// sniff the first 512 bytes to detect binary content before loading
// it into the LLM context. Seeking back to 0 afterwards restores state.
sniff := make([]byte, 512)
sniffN, _ := file.Read(sniff)
// Reset read position to beginning before applying the caller's offset.
if seeker, ok := file.(io.Seeker); ok {
_, err = seeker.Seek(0, io.SeekStart)
if err != nil {
return ErrorResult(fmt.Sprintf("failed to reset file position after sniff: %v", err))
}
} else {
// Non-seekable: we consumed sniffN bytes above; account for them when
// discarding to reach the requested offset below.
// If offset < sniffN the data we already read covers it, which we
// cannot replay on a non-seekable stream — return a clear error.
if offset < int64(sniffN) && offset > 0 {
return ErrorResult(
"non-seekable file: cannot seek to an offset within the first 512 bytes after binary detection",
)
}
}
// Seek to the requested offset.
if seeker, ok := file.(io.Seeker); ok {
_, err = seeker.Seek(offset, io.SeekStart)
if err != nil {
return ErrorResult(fmt.Sprintf("failed to seek to offset %d: %v", offset, err))
}
} else if offset > 0 {
// Fallback for non-seekable streams: discard leading bytes.
// sniffN bytes were already consumed above, so subtract them.
remaining := offset - int64(sniffN)
if remaining > 0 {
_, err = io.CopyN(io.Discard, file, remaining)
if err != nil {
return ErrorResult(fmt.Sprintf("failed to advance to offset %d: %v", offset, err))
}
}
}
// read length+1 bytes to reliably detect whether more content exists
// without relying on totalSize (which may be -1 for non-seekable streams).
// This avoids the false-positive TRUNCATED message on the last page.
probe := make([]byte, length+1)
n, err := io.ReadFull(file, probe)
// FIX: io.ReadFull returns io.ErrUnexpectedEOF for partial reads (0 < n < len),
// and io.EOF only when n == 0. Both are normal terminal conditions — only
// other errors are genuine failures.
if err != nil && err != io.EOF && !errors.Is(err, io.ErrUnexpectedEOF) {
return ErrorResult(fmt.Sprintf("failed to read file content: %v", err))
}
// hasMore is true only when we actually got the extra probe byte.
hasMore := int64(n) > length
data := probe[:min(int64(n), length)]
if len(data) == 0 {
return NewToolResult("[END OF FILE - no content at this offset]")
}
// Build metadata header.
// use filepath.Base(path) instead of the raw path to avoid leaking
// internal filesystem structure into the LLM context.
readEnd := offset + int64(len(data))
// use ASCII hyphen-minus instead of en-dash (U+2013) to keep the
// header parseable by downstream tools and log processors.
readRange := fmt.Sprintf("bytes %d-%d", offset, readEnd-1)
displayPath := filepath.Base(path)
var header string
if totalSize >= 0 {
header = fmt.Sprintf(
"[file: %s | total: %d bytes | read: %s]",
displayPath, totalSize, readRange,
)
} else {
header = fmt.Sprintf(
"[file: %s | read: %s | total size unknown]",
displayPath, readRange,
)
}
if hasMore {
header += fmt.Sprintf(
"\n[TRUNCATED - file has more content. Call read_file again with offset=%d to continue.]",
readEnd,
)
} else {
header += "\n[END OF FILE - no further content.]"
}
logger.DebugCF("tool", "ReadFileTool execution completed successfully",
map[string]any{
"path": path,
"bytes_read": len(data),
"has_more": hasMore,
})
return NewToolResult(header + "\n\n" + string(data))
}
// getInt64Arg extracts an integer argument from the args map.
//
// It returns defaultVal when key is absent. Numbers decoded from JSON arrive
// as float64, so a float is accepted only when it is integral and fits in an
// int64. Values of type int and int64 are accepted as-is, and strings are
// parsed as base-10 integers. Any other type, a fractional or NaN float, or a
// value outside the int64 range produces an error.
func getInt64Arg(args map[string]any, key string, defaultVal int64) (int64, error) {
	raw, exists := args[key]
	if !exists {
		return defaultVal, nil
	}

	switch v := raw.(type) {
	case float64:
		// NaN also lands here because NaN != NaN.
		if v != math.Trunc(v) {
			return 0, fmt.Errorf("%s must be an integer, got float %v", key, v)
		}
		// As a float64, math.MaxInt64 rounds up to exactly 2^63, which is one
		// past the largest int64. The upper bound must therefore be exclusive
		// (>=); otherwise 2^63 slips through and int64(v) is out of range.
		// math.MinInt64 (-2^63) is exactly representable and is a valid value.
		if v >= math.MaxInt64 || v < math.MinInt64 {
			return 0, fmt.Errorf("%s value %v overflows int64", key, v)
		}
		return int64(v), nil
	case int:
		return int64(v), nil
	case int64:
		return v, nil
	case string:
		parsed, err := strconv.ParseInt(v, 10, 64)
		if err != nil {
			return 0, fmt.Errorf("invalid integer format for %s parameter: %w", key, err)
		}
		return parsed, nil
	default:
		return 0, fmt.Errorf("unsupported type %T for %s parameter", raw, key)
	}
}
type WriteFileTool struct {
@@ -249,6 +441,7 @@ type fileSystem interface {
ReadFile(path string) ([]byte, error)
WriteFile(path string, data []byte) error
ReadDir(path string) ([]os.DirEntry, error)
Open(path string) (fs.File, error)
}
// hostFs is an unrestricted fileReadWriter that operates directly on the host filesystem.
@@ -278,6 +471,20 @@ func (h *hostFs) WriteFile(path string, data []byte) error {
return fileutil.WriteFileAtomic(path, data, 0o600)
}
// Open opens path on the host filesystem via os.Open. Failures are wrapped
// with a "failed to open file" prefix, with "file not found" or
// "access denied" added when the underlying error indicates that cause.
func (h *hostFs) Open(path string) (fs.File, error) {
	file, openErr := os.Open(path)
	switch {
	case openErr == nil:
		return file, nil
	case os.IsNotExist(openErr):
		return nil, fmt.Errorf("failed to open file: file not found: %w", openErr)
	case os.IsPermission(openErr):
		return nil, fmt.Errorf("failed to open file: access denied: %w", openErr)
	default:
		return nil, fmt.Errorf("failed to open file: %w", openErr)
	}
}
// sandboxFs is a sandboxed fileSystem that operates within a strictly defined workspace using os.Root.
type sandboxFs struct {
workspace string
@@ -389,6 +596,26 @@ func (r *sandboxFs) ReadDir(path string) ([]os.DirEntry, error) {
return entries, err
}
// Open opens path through r.execute, which supplies the os.Root and the
// path relative to it. Failures are wrapped with a "failed to open file"
// prefix; not-found and permission/escape errors get a more specific message.
func (r *sandboxFs) Open(path string) (fs.File, error) {
	var opened fs.File

	openInRoot := func(root *os.Root, relPath string) error {
		handle, openErr := root.Open(relPath)
		if openErr == nil {
			opened = handle
			return nil
		}

		switch {
		case os.IsNotExist(openErr):
			return fmt.Errorf("failed to open file: file not found: %w", openErr)
		case os.IsPermission(openErr),
			strings.Contains(openErr.Error(), "escapes from parent"),
			strings.Contains(openErr.Error(), "permission denied"):
			return fmt.Errorf("failed to open file: access denied: %w", openErr)
		default:
			return fmt.Errorf("failed to open file: %w", openErr)
		}
	}

	// Run the callback first, then read opened: the closure assigns it.
	err := r.execute(path, openInRoot)
	return opened, err
}
// whitelistFs wraps a sandboxFs and allows access to specific paths outside
// the workspace when they match any of the provided patterns.
type whitelistFs struct {
@@ -427,6 +654,13 @@ func (w *whitelistFs) ReadDir(path string) ([]os.DirEntry, error) {
return w.sandbox.ReadDir(path)
}
// Open opens path through the host filesystem when it matches one of the
// whitelist patterns, and through the sandbox otherwise.
func (w *whitelistFs) Open(path string) (fs.File, error) {
	if !w.matches(path) {
		return w.sandbox.Open(path)
	}
	return w.host.Open(path)
}
// buildFs returns the appropriate fileSystem implementation based on restriction
// settings and optional path whitelist patterns.
func buildFs(workspace string, restrict bool, patterns []*regexp.Regexp) fileSystem {
+130 -6
View File
@@ -18,7 +18,7 @@ func TestFilesystemTool_ReadFile_Success(t *testing.T) {
testFile := filepath.Join(tmpDir, "test.txt")
os.WriteFile(testFile, []byte("test content"), 0o644)
tool := NewReadFileTool("", false)
tool := NewReadFileTool("", false, MaxReadFileSize)
ctx := context.Background()
args := map[string]any{
"path": testFile,
@@ -45,7 +45,7 @@ func TestFilesystemTool_ReadFile_Success(t *testing.T) {
// TestFilesystemTool_ReadFile_NotFound verifies error handling for missing file
func TestFilesystemTool_ReadFile_NotFound(t *testing.T) {
tool := NewReadFileTool("", false)
tool := NewReadFileTool("", false, MaxReadFileSize)
ctx := context.Background()
args := map[string]any{
"path": "/nonexistent_file_12345.txt",
@@ -59,7 +59,7 @@ func TestFilesystemTool_ReadFile_NotFound(t *testing.T) {
}
// Should contain error message
if !strings.Contains(result.ForLLM, "failed to read") && !strings.Contains(result.ForUser, "failed to read") {
if !strings.Contains(result.ForLLM, "failed to open file") && !strings.Contains(result.ForUser, "failed to read") {
t.Errorf("Expected error message, got ForLLM: %s, ForUser: %s", result.ForLLM, result.ForUser)
}
}
@@ -271,7 +271,7 @@ func TestFilesystemTool_ReadFile_RejectsSymlinkEscape(t *testing.T) {
t.Skipf("symlink not supported in this environment: %v", err)
}
tool := NewReadFileTool(workspace, true)
tool := NewReadFileTool(workspace, true, MaxReadFileSize)
result := tool.Execute(context.Background(), map[string]any{
"path": link,
})
@@ -289,7 +289,7 @@ func TestFilesystemTool_ReadFile_RejectsSymlinkEscape(t *testing.T) {
}
func TestFilesystemTool_EmptyWorkspace_AccessDenied(t *testing.T) {
tool := NewReadFileTool("", true) // restrict=true but workspace=""
tool := NewReadFileTool("", true, MaxReadFileSize) // restrict=true but workspace=""
// Try to read a sensitive file (simulated by a temp file outside workspace)
tmpDir := t.TempDir()
@@ -499,7 +499,7 @@ func TestWhitelistFs_AllowsMatchingPaths(t *testing.T) {
// Pattern allows access to the outsideDir.
patterns := []*regexp.Regexp{regexp.MustCompile(`^` + regexp.QuoteMeta(outsideDir))}
tool := NewReadFileTool(workspace, true, patterns)
tool := NewReadFileTool(workspace, true, MaxReadFileSize, patterns)
// Read from whitelisted path should succeed.
result := tool.Execute(context.Background(), map[string]any{"path": outsideFile})
@@ -520,3 +520,127 @@ func TestWhitelistFs_AllowsMatchingPaths(t *testing.T) {
t.Errorf("expected non-whitelisted path to be blocked, got: %s", result.ForLLM)
}
}
// TestReadFileTool_ChunkedReading walks a 26-byte file in 10-byte pages using
// the 'offset' and 'length' arguments and checks both the returned content
// and the pagination hints in the header of every page.
func TestReadFileTool_ChunkedReading(t *testing.T) {
	dir := t.TempDir()
	target := filepath.Join(dir, "pagination_test.txt")

	// Exactly 26 bytes of content.
	const alphabet = "abcdefghijklmnopqrstuvwxyz"
	if writeErr := os.WriteFile(target, []byte(alphabet), 0o644); writeErr != nil {
		t.Fatalf("Failed to write test file: %v", writeErr)
	}

	reader := NewReadFileTool(dir, false, MaxReadFileSize)
	// readPage requests up to 10 bytes starting at the given offset.
	readPage := func(offset int) *ToolResult {
		return reader.Execute(context.Background(), map[string]any{
			"path":   target,
			"offset": offset,
			"length": 10,
		})
	}

	// Page 1: bytes 0-9, more content follows.
	first := readPage(0)
	if first.IsError {
		t.Fatalf("Chunk 1 failed: %s", first.ForLLM)
	}
	if !strings.Contains(first.ForLLM, "abcdefghij") {
		t.Errorf("Chunk 1 should contain 'abcdefghij', got: %s", first.ForLLM)
	}
	if !strings.Contains(first.ForLLM, "[TRUNCATED") {
		t.Errorf("Chunk 1 header should indicate truncation, got: %s", first.ForLLM)
	}
	if !strings.Contains(first.ForLLM, "offset=10") {
		t.Errorf("Chunk 1 header should suggest next offset=10, got: %s", first.ForLLM)
	}

	// Page 2: bytes 10-19, more content follows.
	second := readPage(10)
	if second.IsError {
		t.Fatalf("Chunk 2 failed: %s", second.ForLLM)
	}
	if !strings.Contains(second.ForLLM, "klmnopqrst") {
		t.Errorf("Chunk 2 should contain 'klmnopqrst', got: %s", second.ForLLM)
	}
	if !strings.Contains(second.ForLLM, "offset=20") {
		t.Errorf("Chunk 2 header should suggest next offset=20, got: %s", second.ForLLM)
	}

	// Page 3: 10 bytes requested but only the last 6 remain.
	last := readPage(20)
	if last.IsError {
		t.Fatalf("Chunk 3 failed: %s", last.ForLLM)
	}
	if !strings.Contains(last.ForLLM, "uvwxyz") {
		t.Errorf("Chunk 3 should contain 'uvwxyz', got: %s", last.ForLLM)
	}
	if !strings.Contains(last.ForLLM, "[END OF FILE") {
		t.Errorf("Chunk 3 header should indicate end of file, got: %s", last.ForLLM)
	}
	// The final page must not claim that more content is available.
	if strings.Contains(last.ForLLM, "[TRUNCATED") {
		t.Errorf("Chunk 3 header should NOT indicate truncation, got: %s", last.ForLLM)
	}
}
// TestReadFileTool_OffsetBeyondEOF checks that requesting an offset past the
// end of the file yields the end-of-file notice instead of a tool error.
func TestReadFileTool_OffsetBeyondEOF(t *testing.T) {
	dir := t.TempDir()
	shortFile := filepath.Join(dir, "short.txt")

	// The file holds only 5 bytes.
	if writeErr := os.WriteFile(shortFile, []byte("12345"), 0o644); writeErr != nil {
		t.Fatalf("Failed to write test file: %v", writeErr)
	}

	reader := NewReadFileTool(dir, false, MaxReadFileSize)
	got := reader.Execute(context.Background(), map[string]any{
		"path":   shortFile,
		"offset": int64(100), // far beyond the 5-byte file
	})

	// Reading past EOF is not a tool execution error.
	if got.IsError {
		t.Errorf("A mistake was not expected, obtained IsError=true: %s", got.ForLLM)
	}

	// The result must be exactly the notice emitted by the tool.
	const want = "[END OF FILE - no content at this offset]"
	if got.ForLLM != want {
		t.Errorf("The message %q was expected, obtained: %q", want, got.ForLLM)
	}
}
+137 -11
View File
@@ -5,20 +5,28 @@ import (
"fmt"
"sort"
"sync"
"sync/atomic"
"time"
"github.com/sipeed/picoclaw/pkg/logger"
"github.com/sipeed/picoclaw/pkg/providers"
)
// ToolEntry is the registry's record for one registered tool.
type ToolEntry struct {
	// Tool is the registered tool implementation.
	Tool Tool
	// IsCore is true for tools added with Register and false for tools added
	// with RegisterHidden.
	IsCore bool
	// TTL is set by PromoteTools and decremented by TickTTL for non-core
	// tools; a non-core tool with TTL <= 0 is filtered out by the registry's
	// lookup and listing methods. Core tools leave it at 0.
	TTL int
}
type ToolRegistry struct {
tools map[string]Tool
mu sync.RWMutex
tools map[string]*ToolEntry
mu sync.RWMutex
version atomic.Uint64 // incremented on Register/RegisterHidden for cache invalidation
}
func NewToolRegistry() *ToolRegistry {
return &ToolRegistry{
tools: make(map[string]Tool),
tools: make(map[string]*ToolEntry),
}
}
@@ -30,14 +38,116 @@ func (r *ToolRegistry) Register(tool Tool) {
logger.WarnCF("tools", "Tool registration overwrites existing tool",
map[string]any{"name": name})
}
r.tools[name] = tool
r.tools[name] = &ToolEntry{
Tool: tool,
IsCore: true,
TTL: 0, // Core tools do not use TTL
}
r.version.Add(1)
logger.DebugCF("tools", "Registered core tool", map[string]any{"name": name})
}
// RegisterHidden stores tool as a non-core entry with TTL 0 and bumps the
// registry version. An existing entry with the same name is overwritten
// after a warning is logged.
func (r *ToolRegistry) RegisterHidden(tool Tool) {
	r.mu.Lock()
	defer r.mu.Unlock()

	toolName := tool.Name()
	_, alreadyRegistered := r.tools[toolName]
	if alreadyRegistered {
		logger.WarnCF("tools", "Hidden tool registration overwrites existing tool",
			map[string]any{"name": toolName})
	}

	hidden := &ToolEntry{
		Tool:   tool,
		IsCore: false,
		TTL:    0,
	}
	r.tools[toolName] = hidden

	r.version.Add(1)
	logger.DebugCF("tools", "Registered hidden tool", map[string]any{"name": toolName})
}
// PromoteTools sets the TTL of every named non-core tool to ttl while holding
// the write lock once, so a concurrent TickTTL cannot run between individual
// promotions. Unknown names and core tools are skipped.
func (r *ToolRegistry) PromoteTools(names []string, ttl int) {
	r.mu.Lock()
	defer r.mu.Unlock()

	var promoted int
	for _, toolName := range names {
		entry, known := r.tools[toolName]
		if !known || entry.IsCore {
			continue
		}
		entry.TTL = ttl
		promoted++
	}

	fields := map[string]any{"requested": len(names), "promoted": promoted, "ttl": ttl}
	logger.DebugCF("tools", "PromoteTools completed", fields)
}
// TickTTL decrements the TTL of every non-core tool whose TTL is still
// positive. Core tools and entries already at zero are left untouched.
func (r *ToolRegistry) TickTTL() {
	r.mu.Lock()
	defer r.mu.Unlock()

	for _, entry := range r.tools {
		if entry.IsCore || entry.TTL <= 0 {
			continue
		}
		entry.TTL--
	}
}
// Version returns the current registry version via an atomic load; it does
// not take the registry mutex. RegisterHidden increments the version.
func (r *ToolRegistry) Version() uint64 {
	return r.version.Load()
}
// HiddenToolSnapshot holds a consistent snapshot of hidden tools and the
// registry version at which it was taken. Used by BM25SearchTool cache.
type HiddenToolSnapshot struct {
	// Docs lists the non-core tools present when the snapshot was taken.
	Docs []HiddenToolDoc
	// Version is the registry version read under the same lock as Docs.
	Version uint64
}
// HiddenToolDoc is a lightweight representation of a hidden tool for search indexing.
type HiddenToolDoc struct {
	// Name is the key under which the tool is stored in the registry.
	Name string
	// Description is the value returned by the tool's Description method.
	Description string
}
// SnapshotHiddenTools returns all non-core tools and the current registry
// version under a single read-lock, guaranteeing consistency between the
// two values.
//
// The returned docs are sorted by tool name. Map iteration order is random,
// so without sorting the corpus handed to the search index would be ordered
// differently on every snapshot.
func (r *ToolRegistry) SnapshotHiddenTools() HiddenToolSnapshot {
	r.mu.RLock()
	defer r.mu.RUnlock()

	docs := make([]HiddenToolDoc, 0, len(r.tools))
	for name, entry := range r.tools {
		if entry.IsCore {
			continue
		}
		docs = append(docs, HiddenToolDoc{
			Name:        name,
			Description: entry.Tool.Description(),
		})
	}

	// Names are unique map keys, so this ordering is total and stable.
	sort.Slice(docs, func(i, j int) bool {
		return docs[i].Name < docs[j].Name
	})

	return HiddenToolSnapshot{
		Docs:    docs,
		Version: r.version.Load(),
	}
}
func (r *ToolRegistry) Get(name string) (Tool, bool) {
r.mu.RLock()
defer r.mu.RUnlock()
tool, ok := r.tools[name]
return tool, ok
entry, ok := r.tools[name]
if !ok {
return nil, false
}
// Hidden tools with expired TTL are not callable.
if !entry.IsCore && entry.TTL <= 0 {
return nil, false
}
return entry.Tool, true
}
func (r *ToolRegistry) Execute(ctx context.Context, name string, args map[string]any) *ToolResult {
@@ -135,7 +245,13 @@ func (r *ToolRegistry) GetDefinitions() []map[string]any {
sorted := r.sortedToolNames()
definitions := make([]map[string]any, 0, len(sorted))
for _, name := range sorted {
definitions = append(definitions, ToolToSchema(r.tools[name]))
entry := r.tools[name]
if !entry.IsCore && entry.TTL <= 0 {
continue
}
definitions = append(definitions, ToolToSchema(r.tools[name].Tool))
}
return definitions
}
@@ -149,8 +265,13 @@ func (r *ToolRegistry) ToProviderDefs() []providers.ToolDefinition {
sorted := r.sortedToolNames()
definitions := make([]providers.ToolDefinition, 0, len(sorted))
for _, name := range sorted {
tool := r.tools[name]
schema := ToolToSchema(tool)
entry := r.tools[name]
if !entry.IsCore && entry.TTL <= 0 {
continue
}
schema := ToolToSchema(entry.Tool)
// Safely extract nested values with type checks
fn, ok := schema["function"].(map[string]any)
@@ -198,8 +319,13 @@ func (r *ToolRegistry) GetSummaries() []string {
sorted := r.sortedToolNames()
summaries := make([]string, 0, len(sorted))
for _, name := range sorted {
tool := r.tools[name]
summaries = append(summaries, fmt.Sprintf("- `%s` - %s", tool.Name(), tool.Description()))
entry := r.tools[name]
if !entry.IsCore && entry.TTL <= 0 {
continue
}
summaries = append(summaries, fmt.Sprintf("- `%s` - %s", entry.Tool.Name(), entry.Tool.Description()))
}
return summaries
}
+304
View File
@@ -0,0 +1,304 @@
package tools
import (
"context"
"encoding/json"
"fmt"
"regexp"
"strings"
"sync"
"github.com/sipeed/picoclaw/pkg/logger"
"github.com/sipeed/picoclaw/pkg/utils"
)
const (
	// MaxRegexPatternLength is the largest pattern length, measured with
	// len (bytes), that RegexSearchTool.Execute accepts.
	MaxRegexPatternLength = 200
)
// RegexSearchTool is a tool that searches the registry's hidden tools with a
// regular expression and promotes the matches.
type RegexSearchTool struct {
	// registry is searched for hidden tools and receives the promotions.
	registry *ToolRegistry
	// ttl is the TTL assigned to every tool promoted by a search.
	ttl int
	// maxSearchResults caps the number of matches returned per search.
	maxSearchResults int
}
// NewRegexSearchTool builds a RegexSearchTool that searches r, promotes
// matches with the given ttl, and returns at most maxSearchResults tools.
func NewRegexSearchTool(r *ToolRegistry, ttl int, maxSearchResults int) *RegexSearchTool {
	searcher := new(RegexSearchTool)
	searcher.registry = r
	searcher.ttl = ttl
	searcher.maxSearchResults = maxSearchResults
	return searcher
}
// Name returns the fixed identifier of the regex discovery tool.
func (t *RegexSearchTool) Name() string {
	return "tool_search_tool_regex"
}
// Description returns the human-readable description of the regex discovery tool.
func (t *RegexSearchTool) Description() string {
	return "Search available hidden tools on-demand using a regex pattern. Returns JSON schemas of discovered tools."
}
// Parameters returns the argument schema of the tool: an object with a
// single required string property, "pattern".
func (t *RegexSearchTool) Parameters() map[string]any {
	return map[string]any{
		"type": "object",
		"properties": map[string]any{
			"pattern": map[string]any{
				"type":        "string",
				"description": "Regex pattern to match tool name or description",
			},
		},
		"required": []string{"pattern"},
	}
}
// Execute validates the "pattern" argument, searches the hidden tools with
// it, and promotes every match. It returns an error result when the pattern
// is missing, blank, longer than MaxRegexPatternLength, or not a valid regex.
func (t *RegexSearchTool) Execute(ctx context.Context, args map[string]any) *ToolResult {
	pattern, isString := args["pattern"].(string)
	if !isString || strings.TrimSpace(pattern) == "" {
		// An empty string regex (?i) will match every hidden tool,
		// dumping massive payloads into the context and burning tokens.
		return ErrorResult("Missing or invalid 'pattern' argument. Must be a non-empty string.")
	}

	if patternLen := len(pattern); patternLen > MaxRegexPatternLength {
		logger.WarnCF("discovery", "Regex pattern rejected (too long)", map[string]any{"len": patternLen})
		return ErrorResult(fmt.Sprintf("Pattern too long: max %d characters allowed", MaxRegexPatternLength))
	}

	logger.DebugCF("discovery", "Regex search", map[string]any{"pattern": pattern})

	matches, searchErr := t.registry.SearchRegex(pattern, t.maxSearchResults)
	if searchErr != nil {
		logger.WarnCF(
			"discovery",
			"Invalid regex pattern",
			map[string]any{"pattern": pattern, "error": searchErr.Error()},
		)
		return ErrorResult(
			fmt.Sprintf("Invalid regex pattern syntax: %v. Please fix your regex and try again.", searchErr),
		)
	}

	logger.InfoCF(
		"discovery",
		"Regex search completed",
		map[string]any{"pattern": pattern, "results": len(matches)},
	)
	return formatDiscoveryResponse(t.registry, matches, t.ttl)
}
// BM25SearchTool is a tool that ranks the registry's hidden tools against a
// natural-language query using BM25 and promotes the top matches.
type BM25SearchTool struct {
	// registry is searched for hidden tools and receives the promotions.
	registry *ToolRegistry
	// ttl is the TTL assigned to every tool promoted by a search.
	ttl int
	// maxSearchResults caps the number of ranked results requested.
	maxSearchResults int

	// Cache: rebuilt only when the registry version changes.
	// cacheMu serializes rebuilds of the cached engine.
	cacheMu sync.Mutex
	// cachedEngine is the most recently built engine, or nil when there were
	// no hidden tools at the last rebuild.
	cachedEngine *bm25CachedEngine
	// cacheVersion is the registry version cachedEngine was built from.
	cacheVersion uint64
}
// NewBM25SearchTool builds a BM25SearchTool that searches r, promotes
// matches with the given ttl, and requests at most maxSearchResults tools.
// The engine cache starts empty and is filled on first use.
func NewBM25SearchTool(r *ToolRegistry, ttl int, maxSearchResults int) *BM25SearchTool {
	searcher := new(BM25SearchTool)
	searcher.registry = r
	searcher.ttl = ttl
	searcher.maxSearchResults = maxSearchResults
	return searcher
}
// Name returns the fixed identifier of the BM25 discovery tool.
func (t *BM25SearchTool) Name() string {
	return "tool_search_tool_bm25"
}
// Description returns the human-readable description of the BM25 discovery tool.
func (t *BM25SearchTool) Description() string {
	return "Search available hidden tools on-demand using natural language query describing the action you need to perform. Returns JSON schemas of discovered tools."
}
// Parameters returns the argument schema of the tool: an object with a
// single required string property, "query".
func (t *BM25SearchTool) Parameters() map[string]any {
	return map[string]any{
		"type": "object",
		"properties": map[string]any{
			"query": map[string]any{
				"type":        "string",
				"description": "Search query",
			},
		},
		"required": []string{"query"},
	}
}
// Execute validates the "query" argument, ranks the hidden tools against it
// with the cached BM25 engine, and promotes the ranked tools. A missing or
// blank query is an error; no hidden tools or no matches yields a silent
// "No tools found" result.
func (t *BM25SearchTool) Execute(ctx context.Context, args map[string]any) *ToolResult {
	query, isString := args["query"].(string)
	if !isString || strings.TrimSpace(query) == "" {
		// An empty string query will match every hidden tool,
		// dumping massive payloads into the context and burning tokens.
		return ErrorResult("Missing or invalid 'query' argument. Must be a non-empty string.")
	}

	logger.DebugCF("discovery", "BM25 search", map[string]any{"query": query})

	cache := t.getOrBuildEngine()
	if cache == nil {
		logger.DebugCF("discovery", "BM25 search: no hidden tools available", nil)
		return SilentResult("No tools found matching the query.")
	}

	ranked := cache.engine.Search(query, t.maxSearchResults)
	if len(ranked) == 0 {
		logger.DebugCF("discovery", "BM25 search: no matches", map[string]any{"query": query})
		return SilentResult("No tools found matching the query.")
	}

	found := make([]ToolSearchResult, 0, len(ranked))
	for _, hit := range ranked {
		found = append(found, ToolSearchResult{
			Name:        hit.Document.Name,
			Description: hit.Document.Description,
		})
	}

	logger.InfoCF(
		"discovery",
		"BM25 search completed",
		map[string]any{"query": query, "results": len(found)},
	)
	return formatDiscoveryResponse(t.registry, found, t.ttl)
}
// ToolSearchResult represents the result returned to the LLM.
// Parameters are omitted from the JSON response to save context tokens;
// the LLM will see full schemas via ToProviderDefs after promotion.
type ToolSearchResult struct {
	// Name is the registry name of the matched tool.
	Name string `json:"name"`
	// Description is the matched tool's description text.
	Description string `json:"description"`
}
// SearchRegex returns up to maxSearchResults hidden tools whose name or
// description matches pattern, compiled case-insensitively. Core tools are
// never returned. It returns (nil, nil) when maxSearchResults <= 0 and an
// error when the pattern does not compile.
func (r *ToolRegistry) SearchRegex(pattern string, maxSearchResults int) ([]ToolSearchResult, error) {
	if maxSearchResults <= 0 {
		return nil, nil
	}

	matcher, compileErr := regexp.Compile("(?i)" + pattern)
	if compileErr != nil {
		return nil, fmt.Errorf("failed to compile regex pattern %q: %w", pattern, compileErr)
	}

	r.mu.RLock()
	defer r.mu.RUnlock()

	var found []ToolSearchResult

	// Walk the names in sorted order so repeated calls return the same list.
	for _, toolName := range r.sortedToolNames() {
		entry := r.tools[toolName]
		// Core tools are already visible; only hidden tools are searchable.
		if entry.IsCore {
			continue
		}

		description := entry.Tool.Description()
		if !matcher.MatchString(toolName) && !matcher.MatchString(description) {
			continue
		}

		found = append(found, ToolSearchResult{
			Name:        toolName,
			Description: description,
		})
		// Stop as soon as the cap is reached.
		if len(found) >= maxSearchResults {
			break
		}
	}

	return found, nil
}
// formatDiscoveryResponse promotes the tools named in results with the given
// ttl and returns a silent result listing them as JSON. An empty results
// slice yields the "No tools found" message without promoting anything.
func formatDiscoveryResponse(registry *ToolRegistry, results []ToolSearchResult, ttl int) *ToolResult {
	if len(results) == 0 {
		return SilentResult("No tools found matching the query.")
	}

	names := make([]string, 0, len(results))
	for _, found := range results {
		names = append(names, found.Name)
	}

	registry.PromoteTools(names, ttl)
	logger.InfoCF("discovery", "Promoted tools", map[string]any{"tools": names, "ttl": ttl})

	payload, marshalErr := json.Marshal(results)
	if marshalErr != nil {
		return ErrorResult("Failed to format search results: " + marshalErr.Error())
	}

	return SilentResult(fmt.Sprintf(
		"Found %d tools:\n%s\n\nSUCCESS: These tools have been temporarily UNLOCKED as native tools! In your next response, you can call them directly just like any normal tool",
		len(results),
		string(payload),
	))
}
// searchDoc is the lightweight internal type used as the corpus document for
// BM25. buildBM25Engine indexes the text "Name Description" for each one.
type searchDoc struct {
	Name        string // tool name
	Description string // tool description
}
// bm25CachedEngine wraps the BM25Engine that BM25SearchTool keeps between
// calls.
type bm25CachedEngine struct {
	// engine is the index built over the hidden-tool corpus.
	engine *utils.BM25Engine[searchDoc]
}
// snapshotToSearchDocs copies the tools of a HiddenToolSnapshot into a
// searchDoc slice of the same length and order.
func snapshotToSearchDocs(snap HiddenToolSnapshot) []searchDoc {
	corpus := make([]searchDoc, 0, len(snap.Docs))
	for _, hidden := range snap.Docs {
		corpus = append(corpus, searchDoc{
			Name:        hidden.Name,
			Description: hidden.Description,
		})
	}
	return corpus
}
// buildBM25Engine builds a BM25Engine over docs, indexing each document by
// its name followed by its description.
func buildBM25Engine(docs []searchDoc) *utils.BM25Engine[searchDoc] {
	indexedText := func(doc searchDoc) string {
		return doc.Name + " " + doc.Description
	}
	return utils.NewBM25Engine(docs, indexedText)
}
// getOrBuildEngine returns a cached BM25 engine, rebuilding it only when
// the registry version has changed (new tools registered). It returns nil
// when the registry holds no hidden tools.
//
// Every read and write of cachedEngine and cacheVersion happens under
// cacheMu. An unlocked "optimistic" check would race with a rebuild running
// in another goroutine, and an uncontended mutex is cheap enough that the
// cache-hit path does not need to avoid it.
func (t *BM25SearchTool) getOrBuildEngine() *bm25CachedEngine {
	t.cacheMu.Lock()
	defer t.cacheMu.Unlock()

	// Cache hit: Version is an atomic load, so this avoids taking the
	// registry lock and copying the tool list.
	if t.cachedEngine != nil && t.cacheVersion == t.registry.Version() {
		return t.cachedEngine
	}

	// Snapshot + version are read under a single registry RLock,
	// guaranteeing consistency (no TOCTOU).
	snap := t.registry.SnapshotHiddenTools()

	docs := snapshotToSearchDocs(snap)
	if len(docs) == 0 {
		t.cachedEngine = nil
		t.cacheVersion = snap.Version
		return nil
	}

	cached := &bm25CachedEngine{engine: buildBM25Engine(docs)}
	t.cachedEngine = cached
	t.cacheVersion = snap.Version
	logger.DebugCF("discovery", "BM25 engine rebuilt", map[string]any{"docs": len(docs), "version": snap.Version})
	return cached
}
// SearchBM25 ranks the hidden tools against query with a freshly built BM25
// engine and returns up to maxSearchResults of them. Nothing is cached, so
// the engine is rebuilt on every call. It returns nil when there are no
// hidden tools or the search yields no results.
func (r *ToolRegistry) SearchBM25(query string, maxSearchResults int) []ToolSearchResult {
	corpus := snapshotToSearchDocs(r.SnapshotHiddenTools())
	if len(corpus) == 0 {
		return nil
	}

	engine := buildBM25Engine(corpus)
	ranked := engine.Search(query, maxSearchResults)
	if len(ranked) == 0 {
		return nil
	}

	found := make([]ToolSearchResult, 0, len(ranked))
	for _, hit := range ranked {
		found = append(found, ToolSearchResult{
			Name:        hit.Document.Name,
			Description: hit.Document.Description,
		})
	}
	return found
}
+339
View File
@@ -0,0 +1,339 @@
package tools
import (
"context"
"fmt"
"strings"
"testing"
)
// mockSearchableTool is a dummy tool used to fill the registry in tests.
// Some tests construct it with a positional literal, so the field order
// (name, then desc) must be kept.
type mockSearchableTool struct {
	name string // value returned by Name
	desc string // value returned by Description
}
// Name returns the configured tool name.
func (m *mockSearchableTool) Name() string { return m.name }
// Description returns the configured tool description.
func (m *mockSearchableTool) Description() string { return m.desc }
// Parameters returns a minimal schema: an object with no declared properties.
func (m *mockSearchableTool) Parameters() map[string]any {
	return map[string]any{"type": "object"}
}
// Execute ignores its arguments and returns a silent result naming the tool.
func (m *mockSearchableTool) Execute(ctx context.Context, args map[string]any) *ToolResult {
	return SilentResult("mock executed: " + m.name)
}
// setupPopulatedRegistry returns a registry holding one core tool, which
// searches must never return, and three hidden tools, which they must find.
func setupPopulatedRegistry() *ToolRegistry {
	registry := NewToolRegistry()

	// The core tool is registered first.
	registry.Register(&mockSearchableTool{
		name: "core_search",
		desc: "I am a visible core tool for searching files",
	})

	// Hidden tools, registered in this order.
	hiddenTools := []*mockSearchableTool{
		{name: "mcp_read_file", desc: "Read the contents of a system file"},
		{name: "mcp_list_dir", desc: "List directories and files in the system"},
		{name: "mcp_fetch_net", desc: "Fetch data from a network database"},
	}
	for _, hidden := range hiddenTools {
		registry.RegisterHidden(hidden)
	}

	return registry
}
// TestRegexSearchTool_Execute covers argument validation, the no-match
// response, and promotion of matched hidden tools by the regex search tool.
func TestRegexSearchTool_Execute(t *testing.T) {
	registry := setupPopulatedRegistry()
	searcher := NewRegexSearchTool(registry, 5, 10)
	background := context.Background()

	t.Run("Empty Pattern Error", func(t *testing.T) {
		out := searcher.Execute(background, map[string]any{})
		rejected := out.IsError && strings.Contains(out.ForLLM, "Missing or invalid 'pattern'")
		if !rejected {
			t.Errorf("Expected missing pattern error, got: %v", out.ForLLM)
		}
	})

	t.Run("Invalid Regex Syntax", func(t *testing.T) {
		out := searcher.Execute(background, map[string]any{"pattern": "[unclosed"})
		rejected := out.IsError && strings.Contains(out.ForLLM, "Invalid regex pattern syntax")
		if !rejected {
			t.Errorf("Expected regex syntax error, got: %v", out.ForLLM)
		}
	})

	t.Run("No Match Found", func(t *testing.T) {
		out := searcher.Execute(background, map[string]any{"pattern": "alien"})
		if out.IsError || !strings.Contains(out.ForLLM, "No tools found matching") {
			t.Errorf("Expected 'no tools found' message, got: %v", out.ForLLM)
		}
	})

	t.Run("Successful Match & Promotion", func(t *testing.T) {
		out := searcher.Execute(background, map[string]any{"pattern": "system"})
		if out.IsError {
			t.Fatalf("Unexpected error: %v", out.ForLLM)
		}

		if !strings.Contains(out.ForLLM, "SUCCESS: These tools have been temporarily UNLOCKED") {
			t.Errorf("Expected success string, got: %v", out.ForLLM)
		}
		if !strings.Contains(out.ForLLM, "mcp_read_file") {
			t.Errorf("Expected 'mcp_read_file' in results")
		}

		// The matched tool must carry the tool's TTL; the unmatched one must not.
		registry.mu.RLock()
		defer registry.mu.RUnlock()

		if promotedTTL := registry.tools["mcp_read_file"].TTL; promotedTTL != 5 {
			t.Errorf("Expected TTL of 'mcp_read_file' to be promoted to 5, got %d", promotedTTL)
		}
		if registry.tools["mcp_fetch_net"].TTL != 0 {
			t.Errorf("Expected 'mcp_fetch_net' to NOT be promoted (TTL=0)")
		}
	})
}
// TestBM25SearchTool_Execute covers query validation, the no-match response,
// and promotion of ranked hidden tools by the BM25 search tool.
func TestBM25SearchTool_Execute(t *testing.T) {
	registry := setupPopulatedRegistry()
	searcher := NewBM25SearchTool(registry, 3, 10)
	background := context.Background()

	t.Run("Empty Query Error", func(t *testing.T) {
		out := searcher.Execute(background, map[string]any{"query": "   "})
		rejected := out.IsError && strings.Contains(out.ForLLM, "Missing or invalid 'query'")
		if !rejected {
			t.Errorf("Expected missing query error, got: %v", out.ForLLM)
		}
	})

	t.Run("No Match Found", func(t *testing.T) {
		out := searcher.Execute(background, map[string]any{"query": "aliens spaceships"})
		if out.IsError || !strings.Contains(out.ForLLM, "No tools found matching") {
			t.Errorf("Expected 'no tools found', got: %v", out.ForLLM)
		}
	})

	t.Run("Successful Match & Promotion", func(t *testing.T) {
		out := searcher.Execute(background, map[string]any{"query": "read files"})
		if out.IsError {
			t.Fatalf("Unexpected error: %v", out.ForLLM)
		}

		if !strings.Contains(out.ForLLM, "mcp_read_file") {
			t.Errorf("Expected 'mcp_read_file' in BM25 results")
		}

		// The ranked tool must carry the tool's TTL.
		registry.mu.RLock()
		defer registry.mu.RUnlock()

		if registry.tools["mcp_read_file"].TTL != 3 {
			t.Errorf("Expected TTL of 'mcp_read_file' to be promoted to 3")
		}
	})
}
// TestRegexSearchTool_PatternTooLong checks that a pattern one character
// longer than MaxRegexPatternLength yields a "Pattern too long" error result.
func TestRegexSearchTool_PatternTooLong(t *testing.T) {
	registry := setupPopulatedRegistry()
	searchTool := NewRegexSearchTool(registry, 5, 10)
	bg := context.Background()

	oversized := strings.Repeat("a", MaxRegexPatternLength+1)
	out := searchTool.Execute(bg, map[string]any{"pattern": oversized})

	if out.IsError && strings.Contains(out.ForLLM, "Pattern too long") {
		return
	}
	t.Errorf("Expected pattern too long error, got: %v", out.ForLLM)
}
// TestSearchRegex_ZeroMaxResults verifies that SearchRegex with a result
// limit of 0 succeeds and returns nothing, even for a pattern that matches.
func TestSearchRegex_ZeroMaxResults(t *testing.T) {
	registry := setupPopulatedRegistry()

	matches, err := registry.SearchRegex("mcp", 0)
	if err != nil {
		t.Fatalf("SearchRegex failed: %v", err)
	}

	if n := len(matches); n != 0 {
		t.Errorf("Expected 0 results with maxSearchResults=0, got %d", n)
	}
}
// TestSearchBM25_ZeroMaxResults verifies that SearchBM25 with a result limit
// of 0 returns no results.
func TestSearchBM25_ZeroMaxResults(t *testing.T) {
	registry := setupPopulatedRegistry()

	matches := registry.SearchBM25("read file", 0)
	if n := len(matches); n != 0 {
		t.Errorf("Expected 0 results with maxSearchResults=0, got %d", n)
	}
}
// TestSearchRegex_DeterministicOrder registers 20 hidden tools that share the
// same description and checks that repeated SearchRegex calls return them in
// the same order every time.
//
// The first run is the baseline. It must be non-empty, otherwise the order
// comparison would pass vacuously. Every later run must have the same length
// as the baseline before the element-wise comparison, so a shorter or longer
// result is reported as a failure instead of being ignored or panicking on an
// out-of-range index.
func TestSearchRegex_DeterministicOrder(t *testing.T) {
	reg := NewToolRegistry()
	for i := 0; i < 20; i++ {
		reg.RegisterHidden(&mockSearchableTool{
			name: fmt.Sprintf("tool_%02d", i),
			desc: "searchable tool",
		})
	}

	// Run the same search multiple times and verify order is stable.
	var firstRun []string
	for attempt := 0; attempt < 10; attempt++ {
		res, err := reg.SearchRegex("searchable", 20)
		if err != nil {
			t.Fatalf("SearchRegex failed: %v", err)
		}

		names := make([]string, len(res))
		for i, r := range res {
			names[i] = r.Name
		}

		if attempt == 0 {
			if len(names) == 0 {
				t.Fatal("SearchRegex returned no results; order check would be vacuous")
			}
			firstRun = names
			continue
		}

		if len(names) != len(firstRun) {
			t.Fatalf("Result count changed at attempt %d: got %d, want %d",
				attempt, len(names), len(firstRun))
		}
		for i, name := range names {
			if name != firstRun[i] {
				t.Fatalf("Non-deterministic order at attempt %d, index %d: got %q, want %q",
					attempt, i, name, firstRun[i])
			}
		}
	}
}
// TestToolRegistry_SearchLimitsAndCoreFiltering registers one core tool and
// ten hidden tools that all mention "match", then checks that both search
// flavours honour the result limit and never return the core tool.
func TestToolRegistry_SearchLimitsAndCoreFiltering(t *testing.T) {
	const hiddenCount = 10

	registry := NewToolRegistry()
	// One core tool plus hiddenCount hidden tools, every one containing "match".
	registry.Register(&mockSearchableTool{"core_match", "I am core with match"})
	for n := 0; n < hiddenCount; n++ {
		hidden := &mockSearchableTool{
			name: fmt.Sprintf("hidden_match_%d", n),
			desc: "this has a match",
		}
		registry.RegisterHidden(hidden)
	}

	t.Run("Regex limits and core filtering", func(t *testing.T) {
		// Regex search capped at maxSearchResults = 4.
		hits, err := registry.SearchRegex("match", 4)
		if err != nil {
			t.Fatalf("SearchRegex failed: %v", err)
		}
		if got := len(hits); got != 4 {
			t.Errorf("Expected exactly 4 results due to limit, got %d", got)
		}
		for i := range hits {
			if hits[i].Name != "core_match" {
				continue
			}
			t.Errorf("SearchRegex returned a Core tool, which should be excluded")
		}
	})

	t.Run("BM25 limits and core filtering", func(t *testing.T) {
		// BM25 search capped at maxSearchResults = 3.
		hits := registry.SearchBM25("match", 3)
		if got := len(hits); got != 3 {
			t.Errorf("Expected exactly 3 results due to limit, got %d", got)
		}
		for i := range hits {
			if hits[i].Name != "core_match" {
				continue
			}
			t.Errorf("SearchBM25 returned a Core tool, which should be excluded")
		}
	})
}
// TestGet_HiddenToolTTLLifecycle walks a hidden tool through its visibility
// cycle as observed via Get: invisible right after registration, visible once
// promoted, invisible again after the TTL is ticked back down. It finishes by
// checking that a core tool is gettable immediately after registration.
func TestGet_HiddenToolTTLLifecycle(t *testing.T) {
	registry := NewToolRegistry()
	registry.RegisterHidden(&mockSearchableTool{name: "hidden_tool", desc: "test"})

	// Freshly registered hidden tool: expected not to be gettable.
	if _, visible := registry.Get("hidden_tool"); visible {
		t.Error("Expected hidden tool with TTL=0 to NOT be gettable")
	}

	// After promotion with TTL 3 the tool must become gettable.
	registry.PromoteTools([]string{"hidden_tool"}, 3)
	if _, visible := registry.Get("hidden_tool"); !visible {
		t.Error("Expected promoted hidden tool to be gettable")
	}

	// Three ticks take the TTL 3 -> 2 -> 1 -> 0, hiding the tool again.
	for tick := 0; tick < 3; tick++ {
		registry.TickTTL()
	}
	if _, visible := registry.Get("hidden_tool"); visible {
		t.Error("Expected hidden tool with TTL ticked to 0 to NOT be gettable")
	}

	// Core tools are expected to be gettable without any promotion.
	registry.Register(&mockSearchableTool{name: "core_tool", desc: "core"})
	if _, visible := registry.Get("core_tool"); !visible {
		t.Error("Expected core tool to always be gettable")
	}
}
// TestBM25CacheInvalidation checks that a hidden tool registered after the
// BM25 search tool has already served a query is still found by a later
// query, i.e. that no stale index is reused.
func TestBM25CacheInvalidation(t *testing.T) {
	registry := NewToolRegistry()
	registry.RegisterHidden(&mockSearchableTool{name: "tool_alpha", desc: "alpha functionality"})

	searchTool := NewBM25SearchTool(registry, 5, 10)
	bg := context.Background()

	// Baseline query: the only registered tool must be found.
	first := searchTool.Execute(bg, map[string]any{"query": "alpha"})
	if found := strings.Contains(first.ForLLM, "tool_alpha"); !found {
		t.Fatalf("Expected 'tool_alpha' in first search, got: %v", first.ForLLM)
	}

	// Add a second hidden tool after the first search has run.
	registry.RegisterHidden(&mockSearchableTool{name: "tool_beta", desc: "beta functionality"})

	// The newly registered tool must be discoverable right away.
	second := searchTool.Execute(bg, map[string]any{"query": "beta"})
	if found := strings.Contains(second.ForLLM, "tool_beta"); !found {
		t.Errorf("Expected 'tool_beta' after cache invalidation, got: %v", second.ForLLM)
	}
}
// TestPromoteTools_ConcurrentWithTickTTL runs PromoteTools in a goroutine
// while the test goroutine calls TickTTL, 1000 times each. It has no
// assertions of its own; it exists so the race detector can flag
// unsynchronized access inside the registry.
func TestPromoteTools_ConcurrentWithTickTTL(t *testing.T) {
	const (
		toolCount  = 20
		iterations = 1000
	)

	registry := NewToolRegistry()
	toolNames := make([]string, toolCount)
	for idx := range toolNames {
		toolNames[idx] = fmt.Sprintf("concurrent_tool_%d", idx)
		registry.RegisterHidden(&mockSearchableTool{
			name: toolNames[idx],
			desc: "concurrent test tool",
		})
	}

	// Promote on one goroutine and tick on the other so the two overlap.
	finished := make(chan struct{})
	go func() {
		defer close(finished)
		for n := 0; n < iterations; n++ {
			registry.PromoteTools(toolNames, 5)
		}
	}()

	for n := 0; n < iterations; n++ {
		registry.TickTTL()
	}
	<-finished
}
+295 -188
View File
@@ -11,6 +11,7 @@ import (
"net/url"
"regexp"
"strings"
"sync/atomic"
"time"
)
@@ -76,81 +77,140 @@ func createHTTPClient(proxyURL string, timeout time.Duration) (*http.Client, err
return client, nil
}
// APIKeyPool holds a fixed list of API keys and hands out iterators whose
// starting positions rotate round-robin, spreading load across the keys.
type APIKeyPool struct {
	keys []string
	// current counts the iterators handed out so far. It is only modified
	// through sync/atomic and is allowed to wrap around.
	current uint32
}

// NewAPIKeyPool creates a pool over keys. The slice is copied, so later
// changes to the caller's slice do not affect the pool.
func NewAPIKeyPool(keys []string) *APIKeyPool {
	return &APIKeyPool{
		keys: append([]string(nil), keys...),
	}
}

// APIKeyIterator walks every key of a pool exactly once, beginning at the
// position assigned when it was created. It is meant for a single request's
// failover loop and must not be shared between goroutines.
type APIKeyIterator struct {
	pool     *APIKeyPool
	startIdx uint32 // raw counter value; reduced modulo the key count in Next
	attempt  uint32 // number of keys already returned
}

// NewIterator returns an iterator over all keys of the pool. Each call
// advances the pool's shared counter, so consecutive iterators start on
// consecutive keys. For an empty pool the counter is left untouched and the
// returned iterator is immediately exhausted.
func (p *APIKeyPool) NewIterator() *APIKeyIterator {
	if len(p.keys) == 0 {
		return &APIKeyIterator{pool: p}
	}
	idx := atomic.AddUint32(&p.current, 1) - 1
	return &APIKeyIterator{
		pool:     p,
		startIdx: idx,
	}
}

// Next returns the next key and true, or "" and false once every key has
// been returned (or the pool is empty).
func (it *APIKeyIterator) Next() (string, bool) {
	length := uint32(len(it.pool.keys))
	if length == 0 || it.attempt >= length {
		return "", false
	}
	// Reduce startIdx before adding attempt: the raw counter can sit close to
	// the uint32 maximum, and startIdx+attempt would then wrap around and, for
	// key counts that do not divide 2^32, repeat one key while skipping another.
	key := it.pool.keys[(it.startIdx%length+it.attempt)%length]
	it.attempt++
	return key, true
}
// SearchProvider is the interface implemented by each web search backend in
// this file.
type SearchProvider interface {
	// Search runs query against the backend, asking for up to count results,
	// and returns them rendered as a single text block.
	Search(ctx context.Context, query string, count int) (string, error)
}
type BraveSearchProvider struct {
apiKey string
proxy string
client *http.Client
keyPool *APIKeyPool
proxy string
client *http.Client
}
func (p *BraveSearchProvider) Search(ctx context.Context, query string, count int) (string, error) {
searchURL := fmt.Sprintf("https://api.search.brave.com/res/v1/web/search?q=%s&count=%d",
url.QueryEscape(query), count)
req, err := http.NewRequestWithContext(ctx, "GET", searchURL, nil)
if err != nil {
return "", fmt.Errorf("failed to create request: %w", err)
}
var lastErr error
iter := p.keyPool.NewIterator()
req.Header.Set("Accept", "application/json")
req.Header.Set("X-Subscription-Token", p.apiKey)
resp, err := p.client.Do(req)
if err != nil {
return "", fmt.Errorf("request failed: %w", err)
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
return "", fmt.Errorf("failed to read response: %w", err)
}
if resp.StatusCode != http.StatusOK {
return "", fmt.Errorf("brave api error (status %d): %s", resp.StatusCode, string(body))
}
var searchResp struct {
Web struct {
Results []struct {
Title string `json:"title"`
URL string `json:"url"`
Description string `json:"description"`
} `json:"results"`
} `json:"web"`
}
if err := json.Unmarshal(body, &searchResp); err != nil {
// Log error body for debugging
fmt.Printf("Brave API Error Body: %s\n", string(body))
return "", fmt.Errorf("failed to parse response: %w", err)
}
results := searchResp.Web.Results
if len(results) == 0 {
return fmt.Sprintf("No results for: %s", query), nil
}
var lines []string
lines = append(lines, fmt.Sprintf("Results for: %s", query))
for i, item := range results {
if i >= count {
for {
apiKey, ok := iter.Next()
if !ok {
break
}
lines = append(lines, fmt.Sprintf("%d. %s\n %s", i+1, item.Title, item.URL))
if item.Description != "" {
lines = append(lines, fmt.Sprintf(" %s", item.Description))
req, err := http.NewRequestWithContext(ctx, "GET", searchURL, nil)
if err != nil {
return "", fmt.Errorf("failed to create request: %w", err)
}
req.Header.Set("Accept", "application/json")
req.Header.Set("X-Subscription-Token", apiKey)
resp, err := p.client.Do(req)
if err != nil {
lastErr = fmt.Errorf("request failed: %w", err)
continue
}
body, err := io.ReadAll(resp.Body)
resp.Body.Close()
if err != nil {
lastErr = fmt.Errorf("failed to read response: %w", err)
continue
}
if resp.StatusCode != http.StatusOK {
lastErr = fmt.Errorf("API error (status %d): %s", resp.StatusCode, string(body))
if resp.StatusCode == http.StatusTooManyRequests ||
resp.StatusCode == http.StatusUnauthorized ||
resp.StatusCode == http.StatusForbidden ||
resp.StatusCode >= 500 {
continue
}
return "", lastErr
}
var searchResp struct {
Web struct {
Results []struct {
Title string `json:"title"`
URL string `json:"url"`
Description string `json:"description"`
} `json:"results"`
} `json:"web"`
}
if err := json.Unmarshal(body, &searchResp); err != nil {
// Log error body for debugging
return "", fmt.Errorf("failed to parse response: %w", err)
}
results := searchResp.Web.Results
if len(results) == 0 {
return fmt.Sprintf("No results for: %s", query), nil
}
var lines []string
lines = append(lines, fmt.Sprintf("Results for: %s", query))
for i, item := range results {
if i >= count {
break
}
lines = append(lines, fmt.Sprintf("%d. %s\n %s", i+1, item.Title, item.URL))
if item.Description != "" {
lines = append(lines, fmt.Sprintf(" %s", item.Description))
}
}
return strings.Join(lines, "\n"), nil
}
return strings.Join(lines, "\n"), nil
return "", fmt.Errorf("all api keys failed, last error: %w", lastErr)
}
type TavilySearchProvider struct {
apiKey string
keyPool *APIKeyPool
baseURL string
proxy string
client *http.Client
@@ -162,74 +222,96 @@ func (p *TavilySearchProvider) Search(ctx context.Context, query string, count i
searchURL = "https://api.tavily.com/search"
}
payload := map[string]any{
"api_key": p.apiKey,
"query": query,
"search_depth": "advanced",
"include_answer": false,
"include_images": false,
"include_raw_content": false,
"max_results": count,
}
var lastErr error
iter := p.keyPool.NewIterator()
bodyBytes, err := json.Marshal(payload)
if err != nil {
return "", fmt.Errorf("failed to marshal payload: %w", err)
}
req, err := http.NewRequestWithContext(ctx, "POST", searchURL, bytes.NewBuffer(bodyBytes))
if err != nil {
return "", fmt.Errorf("failed to create request: %w", err)
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("User-Agent", userAgent)
resp, err := p.client.Do(req)
if err != nil {
return "", fmt.Errorf("request failed: %w", err)
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
return "", fmt.Errorf("failed to read response: %w", err)
}
if resp.StatusCode != http.StatusOK {
return "", fmt.Errorf("tavily api error (status %d): %s", resp.StatusCode, string(body))
}
var searchResp struct {
Results []struct {
Title string `json:"title"`
URL string `json:"url"`
Content string `json:"content"`
} `json:"results"`
}
if err := json.Unmarshal(body, &searchResp); err != nil {
return "", fmt.Errorf("failed to parse response: %w", err)
}
results := searchResp.Results
if len(results) == 0 {
return fmt.Sprintf("No results for: %s", query), nil
}
var lines []string
lines = append(lines, fmt.Sprintf("Results for: %s (via Tavily)", query))
for i, item := range results {
if i >= count {
for {
apiKey, ok := iter.Next()
if !ok {
break
}
lines = append(lines, fmt.Sprintf("%d. %s\n %s", i+1, item.Title, item.URL))
if item.Content != "" {
lines = append(lines, fmt.Sprintf(" %s", item.Content))
payload := map[string]any{
"api_key": apiKey,
"query": query,
"search_depth": "advanced",
"include_answer": false,
"include_images": false,
"include_raw_content": false,
"max_results": count,
}
bodyBytes, err := json.Marshal(payload)
if err != nil {
return "", fmt.Errorf("failed to marshal payload: %w", err)
}
req, err := http.NewRequestWithContext(ctx, "POST", searchURL, bytes.NewBuffer(bodyBytes))
if err != nil {
return "", fmt.Errorf("failed to create request: %w", err)
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("User-Agent", userAgent)
resp, err := p.client.Do(req)
if err != nil {
lastErr = fmt.Errorf("request failed: %w", err)
continue
}
body, err := io.ReadAll(resp.Body)
resp.Body.Close()
if err != nil {
lastErr = fmt.Errorf("failed to read response: %w", err)
continue
}
if resp.StatusCode != http.StatusOK {
lastErr = fmt.Errorf("tavily api error (status %d): %s", resp.StatusCode, string(body))
if resp.StatusCode == http.StatusTooManyRequests ||
resp.StatusCode == http.StatusUnauthorized ||
resp.StatusCode == http.StatusForbidden ||
resp.StatusCode >= 500 {
continue
}
return "", lastErr
}
var searchResp struct {
Results []struct {
Title string `json:"title"`
URL string `json:"url"`
Content string `json:"content"`
} `json:"results"`
}
if err := json.Unmarshal(body, &searchResp); err != nil {
return "", fmt.Errorf("failed to parse response: %w", err)
}
results := searchResp.Results
if len(results) == 0 {
return fmt.Sprintf("No results for: %s", query), nil
}
var lines []string
lines = append(lines, fmt.Sprintf("Results for: %s (via Tavily)", query))
for i, item := range results {
if i >= count {
break
}
lines = append(lines, fmt.Sprintf("%d. %s\n %s", i+1, item.Title, item.URL))
if item.Content != "" {
lines = append(lines, fmt.Sprintf(" %s", item.Content))
}
}
return strings.Join(lines, "\n"), nil
}
return strings.Join(lines, "\n"), nil
return "", fmt.Errorf("all api keys failed, last error: %w", lastErr)
}
type DuckDuckGoSearchProvider struct {
@@ -324,75 +406,97 @@ func stripTags(content string) string {
}
type PerplexitySearchProvider struct {
apiKey string
proxy string
client *http.Client
keyPool *APIKeyPool
proxy string
client *http.Client
}
func (p *PerplexitySearchProvider) Search(ctx context.Context, query string, count int) (string, error) {
searchURL := "https://api.perplexity.ai/chat/completions"
payload := map[string]any{
"model": "sonar",
"messages": []map[string]string{
{
"role": "system",
"content": "You are a search assistant. Provide concise search results with titles, URLs, and brief descriptions in the following format:\n1. Title\n URL\n Description\n\nDo not add extra commentary.",
var lastErr error
iter := p.keyPool.NewIterator()
for {
apiKey, ok := iter.Next()
if !ok {
break
}
payload := map[string]any{
"model": "sonar",
"messages": []map[string]string{
{
"role": "system",
"content": "You are a search assistant. Provide concise search results with titles, URLs, and brief descriptions in the following format:\n1. Title\n URL\n Description\n\nDo not add extra commentary.",
},
{
"role": "user",
"content": fmt.Sprintf("Search for: %s. Provide up to %d relevant results.", query, count),
},
},
{
"role": "user",
"content": fmt.Sprintf("Search for: %s. Provide up to %d relevant results.", query, count),
},
},
"max_tokens": 1000,
"max_tokens": 1000,
}
payloadBytes, err := json.Marshal(payload)
if err != nil {
return "", fmt.Errorf("failed to marshal request: %w", err)
}
req, err := http.NewRequestWithContext(ctx, "POST", searchURL, strings.NewReader(string(payloadBytes)))
if err != nil {
return "", fmt.Errorf("failed to create request: %w", err)
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+apiKey)
req.Header.Set("User-Agent", userAgent)
resp, err := p.client.Do(req)
if err != nil {
lastErr = fmt.Errorf("request failed: %w", err)
continue
}
body, err := io.ReadAll(resp.Body)
resp.Body.Close()
if err != nil {
lastErr = fmt.Errorf("failed to read response: %w", err)
continue
}
if resp.StatusCode != http.StatusOK {
lastErr = fmt.Errorf("Perplexity API error: %s", string(body))
if resp.StatusCode == http.StatusTooManyRequests ||
resp.StatusCode == http.StatusUnauthorized ||
resp.StatusCode == http.StatusForbidden ||
resp.StatusCode >= 500 {
continue
}
return "", lastErr
}
var searchResp struct {
Choices []struct {
Message struct {
Content string `json:"content"`
} `json:"message"`
} `json:"choices"`
}
if err := json.Unmarshal(body, &searchResp); err != nil {
return "", fmt.Errorf("failed to parse response: %w", err)
}
if len(searchResp.Choices) == 0 {
return fmt.Sprintf("No results for: %s", query), nil
}
return fmt.Sprintf("Results for: %s (via Perplexity)\n%s", query, searchResp.Choices[0].Message.Content), nil
}
payloadBytes, err := json.Marshal(payload)
if err != nil {
return "", fmt.Errorf("failed to marshal request: %w", err)
}
req, err := http.NewRequestWithContext(ctx, "POST", searchURL, strings.NewReader(string(payloadBytes)))
if err != nil {
return "", fmt.Errorf("failed to create request: %w", err)
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+p.apiKey)
req.Header.Set("User-Agent", userAgent)
resp, err := p.client.Do(req)
if err != nil {
return "", fmt.Errorf("request failed: %w", err)
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
return "", fmt.Errorf("failed to read response: %w", err)
}
if resp.StatusCode != http.StatusOK {
return "", fmt.Errorf("Perplexity API error: %s", string(body))
}
var searchResp struct {
Choices []struct {
Message struct {
Content string `json:"content"`
} `json:"message"`
} `json:"choices"`
}
if err := json.Unmarshal(body, &searchResp); err != nil {
return "", fmt.Errorf("failed to parse response: %w", err)
}
if len(searchResp.Choices) == 0 {
return fmt.Sprintf("No results for: %s", query), nil
}
return fmt.Sprintf("Results for: %s (via Perplexity)\n%s", query, searchResp.Choices[0].Message.Content), nil
return "", fmt.Errorf("all api keys failed, last error: %w", lastErr)
}
type SearXNGSearchProvider struct {
@@ -545,16 +649,16 @@ type WebSearchTool struct {
}
type WebSearchToolOptions struct {
BraveAPIKey string
BraveAPIKeys []string
BraveMaxResults int
BraveEnabled bool
TavilyAPIKey string
TavilyAPIKeys []string
TavilyBaseURL string
TavilyMaxResults int
TavilyEnabled bool
DuckDuckGoMaxResults int
DuckDuckGoEnabled bool
PerplexityAPIKey string
PerplexityAPIKeys []string
PerplexityMaxResults int
PerplexityEnabled bool
SearXNGBaseURL string
@@ -571,23 +675,26 @@ type WebSearchToolOptions struct {
func NewWebSearchTool(opts WebSearchToolOptions) (*WebSearchTool, error) {
var provider SearchProvider
maxResults := 5
// Priority: Perplexity > Brave > SearXNG > Tavily > DuckDuckGo > GLM Search
if opts.PerplexityEnabled && opts.PerplexityAPIKey != "" {
if opts.PerplexityEnabled && len(opts.PerplexityAPIKeys) > 0 {
client, err := createHTTPClient(opts.Proxy, perplexityTimeout)
if err != nil {
return nil, fmt.Errorf("failed to create HTTP client for Perplexity: %w", err)
}
provider = &PerplexitySearchProvider{apiKey: opts.PerplexityAPIKey, proxy: opts.Proxy, client: client}
provider = &PerplexitySearchProvider{
keyPool: NewAPIKeyPool(opts.PerplexityAPIKeys),
proxy: opts.Proxy,
client: client,
}
if opts.PerplexityMaxResults > 0 {
maxResults = opts.PerplexityMaxResults
}
} else if opts.BraveEnabled && opts.BraveAPIKey != "" {
} else if opts.BraveEnabled && len(opts.BraveAPIKeys) > 0 {
client, err := createHTTPClient(opts.Proxy, searchTimeout)
if err != nil {
return nil, fmt.Errorf("failed to create HTTP client for Brave: %w", err)
}
provider = &BraveSearchProvider{apiKey: opts.BraveAPIKey, proxy: opts.Proxy, client: client}
provider = &BraveSearchProvider{keyPool: NewAPIKeyPool(opts.BraveAPIKeys), proxy: opts.Proxy, client: client}
if opts.BraveMaxResults > 0 {
maxResults = opts.BraveMaxResults
}
@@ -596,13 +703,13 @@ func NewWebSearchTool(opts WebSearchToolOptions) (*WebSearchTool, error) {
if opts.SearXNGMaxResults > 0 {
maxResults = opts.SearXNGMaxResults
}
} else if opts.TavilyEnabled && opts.TavilyAPIKey != "" {
} else if opts.TavilyEnabled && len(opts.TavilyAPIKeys) > 0 {
client, err := createHTTPClient(opts.Proxy, searchTimeout)
if err != nil {
return nil, fmt.Errorf("failed to create HTTP client for Tavily: %w", err)
}
provider = &TavilySearchProvider{
apiKey: opts.TavilyAPIKey,
keyPool: NewAPIKeyPool(opts.TavilyAPIKeys),
baseURL: opts.TavilyBaseURL,
proxy: opts.Proxy,
client: client,
+124 -5
View File
@@ -249,7 +249,7 @@ func TestWebFetchTool_PayloadTooLarge(t *testing.T) {
// TestWebTool_WebSearch_NoApiKey verifies that no tool is created when API key is missing
func TestWebTool_WebSearch_NoApiKey(t *testing.T) {
tool, err := NewWebSearchTool(WebSearchToolOptions{BraveEnabled: true, BraveAPIKey: ""})
tool, err := NewWebSearchTool(WebSearchToolOptions{BraveEnabled: true, BraveAPIKeys: nil})
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
@@ -269,7 +269,11 @@ func TestWebTool_WebSearch_NoApiKey(t *testing.T) {
// TestWebTool_WebSearch_MissingQuery verifies error handling for missing query
func TestWebTool_WebSearch_MissingQuery(t *testing.T) {
tool, err := NewWebSearchTool(WebSearchToolOptions{BraveEnabled: true, BraveAPIKey: "test-key", BraveMaxResults: 5})
tool, err := NewWebSearchTool(WebSearchToolOptions{
BraveEnabled: true,
BraveAPIKeys: []string{"test-key"},
BraveMaxResults: 5,
})
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
@@ -553,7 +557,7 @@ func TestNewWebSearchTool_PropagatesProxy(t *testing.T) {
t.Run("perplexity", func(t *testing.T) {
tool, err := NewWebSearchTool(WebSearchToolOptions{
PerplexityEnabled: true,
PerplexityAPIKey: "k",
PerplexityAPIKeys: []string{"k"},
PerplexityMaxResults: 3,
Proxy: "http://127.0.0.1:7890",
})
@@ -572,7 +576,7 @@ func TestNewWebSearchTool_PropagatesProxy(t *testing.T) {
t.Run("brave", func(t *testing.T) {
tool, err := NewWebSearchTool(WebSearchToolOptions{
BraveEnabled: true,
BraveAPIKey: "k",
BraveAPIKeys: []string{"k"},
BraveMaxResults: 3,
Proxy: "http://127.0.0.1:7890",
})
@@ -650,7 +654,7 @@ func TestWebTool_TavilySearch_Success(t *testing.T) {
tool, err := NewWebSearchTool(WebSearchToolOptions{
TavilyEnabled: true,
TavilyAPIKey: "test-key",
TavilyAPIKeys: []string{"test-key"},
TavilyBaseURL: server.URL,
TavilyMaxResults: 5,
})
@@ -682,6 +686,121 @@ func TestWebTool_TavilySearch_Success(t *testing.T) {
}
}
// TestAPIKeyPool covers APIKeyPool and its iterator: key storage, a full
// pass over every key, exhaustion, the round-robin start of a second
// iterator, and the empty and single-key pools.
func TestAPIKeyPool(t *testing.T) {
	pool := NewAPIKeyPool([]string{"key1", "key2", "key3"})
	if n := len(pool.keys); n != 3 {
		t.Fatalf("expected 3 keys, got %d", n)
	}
	stored := pool.keys[0] == "key1" && pool.keys[1] == "key2" && pool.keys[2] == "key3"
	if !stored {
		t.Fatalf("unexpected keys: %v", pool.keys)
	}

	// A fresh iterator must visit each key exactly once, in order.
	first := pool.NewIterator()
	wantOrder := []string{"key1", "key2", "key3"}
	for step := 0; step < len(wantOrder); step++ {
		got, more := first.Next()
		if !more {
			t.Fatalf("iter.Next() returned false at step %d", step)
		}
		if got != wantOrder[step] {
			t.Errorf("step %d: expected %s, got %s", step, wantOrder[step], got)
		}
	}

	// After the full pass the iterator reports exhaustion.
	if _, more := first.Next(); more {
		t.Errorf("expected iterator exhausted after all keys")
	}

	// The next iterator starts one position later (load balancing).
	second := pool.NewIterator()
	head, more := second.Next()
	if !more {
		t.Fatal("iter2.Next() returned false")
	}
	if head != "key2" {
		t.Errorf("expected key2 (round-robin), got %s", head)
	}

	// An empty pool yields nothing.
	noKeys := NewAPIKeyPool([]string{}).NewIterator()
	if _, more := noKeys.Next(); more {
		t.Errorf("expected false for empty pool")
	}

	// A single-key pool yields that key once and is then exhausted.
	lone := NewAPIKeyPool([]string{"single"}).NewIterator()
	if only, more := lone.Next(); !more || only != "single" {
		t.Errorf("expected single, got %s (ok=%v)", only, more)
	}
	if _, more := lone.Next(); more {
		t.Errorf("expected exhausted after single key")
	}
}
// TestWebTool_TavilySearch_Failover verifies key failover for the Tavily
// provider: the fake server answers 429 for "key1" and a valid result set for
// "key2", so the search must succeed with the second key's result.
//
// The handler runs on the HTTP server's goroutine, so it reports problems
// with t.Errorf and an HTTP error response. FailNow/Fatalf may only be called
// from the goroutine running the test function.
func TestWebTool_TavilySearch_Failover(t *testing.T) {
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		var payload map[string]any
		if err := json.NewDecoder(r.Body).Decode(&payload); err != nil {
			t.Errorf("failed to decode payload: %v", err)
			http.Error(w, "invalid payload", http.StatusBadRequest)
			return
		}

		// A missing or non-string api_key yields "" and falls through to the
		// 400 branch instead of panicking inside the handler.
		apiKey, _ := payload["api_key"].(string)

		switch apiKey {
		case "key1":
			w.WriteHeader(http.StatusTooManyRequests)
			w.Write([]byte("Rate limited"))
		case "key2":
			// Success
			response := map[string]any{
				"results": []map[string]any{
					{
						"title":   "Success Result",
						"url":     "https://example.com/success",
						"content": "Success content",
					},
				},
			}
			w.Header().Set("Content-Type", "application/json")
			w.WriteHeader(http.StatusOK)
			json.NewEncoder(w).Encode(response)
		default:
			w.WriteHeader(http.StatusBadRequest)
		}
	}))
	defer server.Close()

	tool, err := NewWebSearchTool(WebSearchToolOptions{
		TavilyEnabled:    true,
		TavilyAPIKeys:    []string{"key1", "key2"},
		TavilyBaseURL:    server.URL,
		TavilyMaxResults: 5,
	})
	if err != nil {
		t.Fatalf("NewWebSearchTool() error: %v", err)
	}

	ctx := context.Background()
	args := map[string]any{
		"query": "test query",
	}

	result := tool.Execute(ctx, args)
	if result.IsError {
		t.Errorf("Expected success, got Error: %s", result.ForLLM)
	}
	if !strings.Contains(result.ForUser, "Success Result") {
		t.Errorf("Expected failover to second key and success result, got: %s", result.ForUser)
	}
}
func TestWebTool_GLMSearch_Success(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method != "POST" {
+272
View File
@@ -0,0 +1,272 @@
// Package utils provides shared, reusable algorithms.
// This file implements a generic BM25 search engine.
//
// Usage:
//
// type MyDoc struct { ID string; Body string }
//
// corpus := []MyDoc{...}
// engine := utils.NewBM25Engine(corpus, func(d MyDoc) string {
// return d.ID + " " + d.Body
// })
// results := engine.Search("my query", 5)
package utils
import (
"math"
"sort"
"strings"
)
// ── Tuning defaults ───────────────────────────────────────────────────────────
const (
	// DefaultBM25K1 is the term-frequency saturation factor (typical range 1.2 to 2.0).
	// Higher values give more weight to repeated terms.
	DefaultBM25K1 = 1.2
	// DefaultBM25B is the document-length normalization factor (0 = none, 1 = full).
	DefaultBM25B = 0.75
)
// BM25Engine is a query-time BM25 search engine over a generic corpus.
// T is the document type; the caller supplies a TextFunc that extracts the
// searchable text from each document.
//
// The engine is stateless between queries: no caching, no invalidation logic.
// All indexing work is performed inside Search() on every call, making it
// safe to use on corpora that change frequently.
type BM25Engine[T any] struct {
	// corpus is the caller's document slice; it is referenced, not copied.
	corpus []T
	// textFunc extracts the searchable text of one document.
	textFunc func(T) string
	// k1 is the term-frequency saturation constant.
	k1 float64
	// b is the document-length normalization factor.
	b float64
}
// BM25Option is a functional option to configure a BM25Engine.
type BM25Option func(*bm25Config)

// bm25Config collects the tunable parameters that BM25Option values modify
// before NewBM25Engine copies them into the engine.
type bm25Config struct {
	// k1 is the term-frequency saturation constant.
	k1 float64
	// b is the document-length normalization factor.
	b float64
}
// WithK1 returns an option that replaces the term-frequency saturation
// constant, which otherwise defaults to 1.2.
func WithK1(k1 float64) BM25Option {
	apply := func(cfg *bm25Config) {
		cfg.k1 = k1
	}
	return apply
}
// WithB returns an option that replaces the document-length normalization
// factor, which otherwise defaults to 0.75.
func WithB(b float64) BM25Option {
	apply := func(cfg *bm25Config) {
		cfg.b = b
	}
	return apply
}
// NewBM25Engine builds a BM25Engine over corpus.
//
//   - corpus   : the documents to search, of any type T.
//   - textFunc : returns the searchable text of a single document.
//   - opts     : optional tuning overrides (WithK1, WithB), applied in order
//     on top of DefaultBM25K1 and DefaultBM25B.
//
// The engine keeps a reference to the corpus slice rather than a copy, so
// callers must not mutate it while a Search() is running.
func NewBM25Engine[T any](corpus []T, textFunc func(T) string, opts ...BM25Option) *BM25Engine[T] {
	settings := bm25Config{
		k1: DefaultBM25K1,
		b:  DefaultBM25B,
	}
	for i := range opts {
		opts[i](&settings)
	}

	engine := new(BM25Engine[T])
	engine.corpus = corpus
	engine.textFunc = textFunc
	engine.k1 = settings.k1
	engine.b = settings.b
	return engine
}
// BM25Result is a single ranked result from a Search call.
type BM25Result[T any] struct {
	// Document is the matching corpus element.
	Document T
	// Score is the document's BM25 score for the query; Search returns
	// results ordered by descending Score.
	Score float32
}
// Search ranks the corpus against query and returns the top-k results,
// ordered by descending score. Returns an empty slice (not nil) when there
// are no matches, when topK <= 0, when the query has no tokens, or when the
// corpus is empty.
//
// The result is deterministic for a given corpus and query. Candidates are
// visited in corpus order, and documents with equal scores are returned in
// ascending corpus index.
//
// Complexity: O(N×L) for indexing + O(|Q|×avgPostingLen) for scoring,
// where N = corpus size, L = average document length, Q = query terms.
// Top-k extraction uses a fixed-size min-heap: O(candidates × log k).
func (e *BM25Engine[T]) Search(query string, topK int) []BM25Result[T] {
	if topK <= 0 {
		return []BM25Result[T]{}
	}

	queryTerms := bm25Tokenize(query)
	if len(queryTerms) == 0 {
		return []BM25Result[T]{}
	}

	N := len(e.corpus)
	if N == 0 {
		return []BM25Result[T]{}
	}

	// Step 1: build per-document tf + raw doc lengths
	type docEntry struct {
		tf     map[string]uint32
		rawLen int
	}

	entries := make([]docEntry, N)
	df := make(map[string]int, 64)
	totalLen := 0

	for i, doc := range e.corpus {
		tokens := bm25Tokenize(e.textFunc(doc))
		totalLen += len(tokens)

		tf := make(map[string]uint32, len(tokens))
		for _, t := range tokens {
			tf[t]++
		}
		// df: each term counts once per document (iterate the map, keys are unique)
		for t := range tf {
			df[t]++
		}
		entries[i] = docEntry{tf: tf, rawLen: len(tokens)}
	}

	avgDocLen := float64(totalLen) / float64(N)

	// Step 2: pre-compute IDF and per-doc length normalization
	// IDF (Robertson smoothing): log( (N - df(t) + 0.5) / (df(t) + 0.5) + 1 )
	idf := make(map[string]float32, len(df))
	for term, freq := range df {
		idf[term] = float32(math.Log(
			(float64(N)-float64(freq)+0.5)/(float64(freq)+0.5) + 1,
		))
	}

	// docLenNorm[i] = k1 * (1 - b + b * |doc_i| / avgDocLen)
	// Stored as float32 — sufficient precision for ranking.
	docLenNorm := make([]float32, N)
	for i, entry := range entries {
		docLenNorm[i] = float32(e.k1 * (1 - e.b + e.b*float64(entry.rawLen)/avgDocLen))
	}

	// Step 3: build inverted index (posting lists)
	// Iterate the tf map directly — map keys are already unique, no seen-set needed.
	posting := make(map[string][]int32, len(df))
	for i, entry := range entries {
		for term := range entry.tf {
			posting[term] = append(posting[term], int32(i))
		}
	}

	// Step 4: score via posting lists
	// Deduplicate query terms to avoid double-weighting the same term.
	// Scores live in a dense slice indexed by corpus position (with a parallel
	// "matched" flag) so that candidates can later be visited in corpus order
	// rather than in random map-iteration order.
	unique := bm25Dedupe(queryTerms)
	scores := make([]float32, N)
	matched := make([]bool, N)
	hits := 0

	for _, term := range unique {
		termIDF, ok := idf[term]
		if !ok {
			continue // term not in vocabulary → zero contribution
		}
		for _, docID := range posting[term] {
			freq := float32(entries[docID].tf[term])
			// TF_norm = freq * (k1+1) / (freq + docLenNorm)
			tfNorm := freq * float32(e.k1+1) / (freq + docLenNorm[docID])
			scores[docID] += termIDF * tfNorm
			if !matched[docID] {
				matched[docID] = true
				hits++
			}
		}
	}

	if hits == 0 {
		return []BM25Result[T]{}
	}

	// Step 5: top-K via fixed-size min-heap
	// The heap never needs more slots than there are candidates, so a very
	// large topK does not translate into a very large allocation.
	k := topK
	if hits < k {
		k = hits
	}
	heap := make([]bm25ScoredDoc, 0, k)

	for i := 0; i < N; i++ {
		if !matched[i] {
			continue
		}
		cand := bm25ScoredDoc{docID: int32(i), score: scores[i]}
		switch {
		case len(heap) < k:
			heap = append(heap, cand)
			if len(heap) == k {
				bm25MinHeapify(heap)
			}
		case cand.score > heap[0].score:
			heap[0] = cand
			bm25SiftDown(heap, 0)
		}
	}

	// Highest score first; equal scores fall back to corpus order so the
	// output does not depend on the (unstable) sort's handling of ties.
	sort.Slice(heap, func(i, j int) bool {
		if heap[i].score != heap[j].score {
			return heap[i].score > heap[j].score
		}
		return heap[i].docID < heap[j].docID
	})

	out := make([]BM25Result[T], len(heap))
	for i, h := range heap {
		out[i] = BM25Result[T]{
			Document: e.corpus[h.docID],
			Score:    h.score,
		}
	}
	return out
}
// bm25Tokenize lowercases s, splits it on whitespace and trims punctuation
// from both ends of every field. Fields that are nothing but punctuation are
// dropped; punctuation inside a token is left alone.
func bm25Tokenize(s string) []string {
	fields := strings.Fields(strings.ToLower(s))

	// Compact in place: kept is the write position inside fields.
	kept := 0
	for _, field := range fields {
		if trimmed := strings.Trim(field, ".,;:!?\"'()/\\-_"); trimmed != "" {
			fields[kept] = trimmed
			kept++
		}
	}
	return fields[:kept]
}
// bm25Dedupe copies tokens into a fresh slice, keeping only the first
// occurrence of each token and preserving the original order.
func bm25Dedupe(tokens []string) []string {
	uniq := make([]string, 0, len(tokens))
	known := make(map[string]struct{}, len(tokens))
	for i := range tokens {
		tok := tokens[i]
		if _, dup := known[tok]; dup {
			continue
		}
		known[tok] = struct{}{}
		uniq = append(uniq, tok)
	}
	return uniq
}
// bm25ScoredDoc pairs a corpus index with its accumulated BM25 score; it is
// the element type of the top-k min-heap used by Search.
type bm25ScoredDoc struct {
	// docID is the document's index in the engine's corpus slice.
	docID int32
	// score is the document's BM25 score for the current query.
	score float32
}
// bm25MinHeapify rearranges h in place into a min-heap keyed on score by
// sifting down every internal node, from the last one back to the root.
func bm25MinHeapify(h []bm25ScoredDoc) {
	node := len(h) / 2
	for node > 0 {
		node--
		bm25SiftDown(h, node)
	}
}
// bm25SiftDown moves the element at index i down the heap h until neither of
// its children has a strictly smaller score, restoring the min-heap property
// for the subtree rooted at i.
func bm25SiftDown(h []bm25ScoredDoc, i int) {
	size := len(h)
	for {
		lowest := i
		if left := 2*i + 1; left < size && h[left].score < h[lowest].score {
			lowest = left
		}
		if right := 2*i + 2; right < size && h[right].score < h[lowest].score {
			lowest = right
		}
		if lowest == i {
			return
		}
		h[lowest], h[i] = h[i], h[lowest]
		i = lowest
	}
}
+175
View File
@@ -0,0 +1,175 @@
package utils
import (
"reflect"
"testing"
)
// testDoc is a minimal document type used as the corpus element in these
// tests. The tests build it with positional literals, so the field order
// (ID, then Text) must not change.
type testDoc struct {
	ID int
	Text string
}
// extractText returns the Text field of d. The tests pass it to
// NewBM25Engine as the second argument (presumably the function the engine
// uses to obtain the text of each document).
func extractText(d testDoc) string {
	return d.Text
}
// TestBM25Search_EdgeCases verifies that degenerate searches (non-positive
// topK, empty or punctuation-only queries, and queries that match nothing)
// all yield an empty, non-nil result slice.
func TestBM25Search_EdgeCases(t *testing.T) {
	docs := []testDoc{
		{1, "hello world"},
		{2, "foo bar"},
	}
	engine := NewBM25Engine(docs, extractText)

	type edgeCase struct {
		name  string
		query string
		topK  int
	}
	cases := []edgeCase{
		{name: "Zero topK", query: "hello", topK: 0},
		{name: "Negative topK", query: "hello", topK: -1},
		{name: "Empty query", query: "", topK: 5},
		{name: "Query with only punctuation", query: "...,,,!!!", topK: 5},
		{name: "No matches found", query: "golang", topK: 5},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			got := engine.Search(tc.query, tc.topK)
			if n := len(got); n != 0 {
				t.Errorf("expected 0 results, got %d", n)
			}
			// The contract under test is "empty slice", never nil.
			if got == nil {
				t.Errorf("expected empty slice, got nil")
			}
		})
	}
}
func TestBM25Search_EmptyCorpus(t *testing.T) {
engine := NewBM25Engine([]testDoc{}, extractText)
results := engine.Search("hello", 5)
if len(results) != 0 || results == nil {
t.Errorf("expected empty slice from empty corpus, got %v", results)
}
}
// TestBM25Search_RankingLogic checks three ranking properties on a small
// corpus: a higher term frequency ranks a document first, a shorter document
// wins when term frequency is equal, and topK caps the number of results.
func TestBM25Search_RankingLogic(t *testing.T) {
	docs := []testDoc{
		{1, "the quick brown fox jumps over the lazy dog"},
		{2, "quick fox"},
		{3, "quick quick quick fox"}, // "quick" three times: highest term frequency
		{4, "completely irrelevant document here"},
	}
	engine := NewBM25Engine(docs, extractText)

	t.Run("Term Frequency (TF) boosts score", func(t *testing.T) {
		hits := engine.Search("quick", 5)
		if n := len(hits); n < 3 {
			t.Fatalf("expected at least 3 results, got %d", n)
		}
		// Doc 3 repeats "quick" three times, so it is expected to outrank doc 2.
		if top := hits[0].Document.ID; top != 3 {
			t.Errorf("expected doc 3 to rank first due to high TF, got doc %d", top)
		}
	})

	t.Run("Document Length penalty", func(t *testing.T) {
		hits := engine.Search("fox", 5)
		if n := len(hits); n < 3 {
			t.Fatalf("expected at least 3 results, got %d", n)
		}
		// "fox" occurs once in both doc 1 and doc 2; doc 2 is far shorter,
		// so it is expected to come out on top.
		if top := hits[0].Document.ID; top != 2 {
			t.Errorf("expected doc 2 to rank first due to shorter length, got doc %d", top)
		}
	})

	t.Run("TopK limits results", func(t *testing.T) {
		hits := engine.Search("quick", 2)
		if n := len(hits); n != 2 {
			t.Errorf("expected exactly 2 results, got %d", n)
		}
	})
}
// TestBM25Tokenize is a table-driven check of bm25Tokenize covering
// lowercasing, whitespace splitting, edge-punctuation trimming, preservation
// of interior hyphens, and inputs that reduce to no tokens.
func TestBM25Tokenize(t *testing.T) {
	type tokenizeCase struct {
		input    string
		expected []string
	}
	cases := []tokenizeCase{
		{input: "Hello World", expected: []string{"hello", "world"}},
		{input: "  spaces   everywhere  ", expected: []string{"spaces", "everywhere"}},
		{input: "punctuation... test!!!", expected: []string{"punctuation", "test"}},
		// Parentheses are trimmed from the edges; the hyphen here is interior and stays.
		{input: "(parentheses) and-hyphens", expected: []string{"parentheses", "and-hyphens"}},
		{input: "internal-hyphen is kept", expected: []string{"internal-hyphen", "is", "kept"}},
		// Nothing is left once the punctuation has been trimmed.
		{input: ".,;?!", expected: []string{}},
	}

	for _, tc := range cases {
		t.Run(tc.input, func(t *testing.T) {
			got := bm25Tokenize(tc.input)
			// Two empty results count as equal even if one is nil and the
			// other is an empty slice, which reflect.DeepEqual would reject.
			bothEmpty := len(got) == 0 && len(tc.expected) == 0
			if !bothEmpty && !reflect.DeepEqual(got, tc.expected) {
				t.Errorf("bm25Tokenize(%q) = %v, want %v", tc.input, got, tc.expected)
			}
		})
	}
}
// TestBM25Dedupe verifies that bm25Dedupe drops repeated tokens while keeping
// the first occurrence of each in its original order.
func TestBM25Dedupe(t *testing.T) {
	want := []string{"apple", "banana", "orange"}

	got := bm25Dedupe([]string{"apple", "banana", "apple", "orange", "banana"})
	if reflect.DeepEqual(got, want) {
		return
	}
	t.Errorf("bm25Dedupe() = %v, want %v", got, want)
}
// TestBM25Options verifies that the WithK1 and WithB options passed to
// NewBM25Engine end up in the engine's k1 and b fields.
func TestBM25Options(t *testing.T) {
	// Untyped constants, so they take on whatever numeric type the options
	// and fields use, exactly like the bare literals would.
	const (
		wantK1 = 2.5
		wantB  = 0.9
	)

	docs := []testDoc{{1, "test"}}
	engine := NewBM25Engine(docs, extractText, WithK1(wantK1), WithB(wantB))

	if got := engine.k1; got != wantK1 {
		t.Errorf("expected k1 to be 2.5, got %v", got)
	}
	if got := engine.b; got != wantB {
		t.Errorf("expected b to be 0.9, got %v", got)
	}
}
// TestBM25Search_SortingStability verifies that Search returns every matching
// document and that the results come back ordered by descending score.
func TestBM25Search_SortingStability(t *testing.T) {
	docs := []testDoc{
		{1, "golang is good"},
		{2, "golang golang"},
		{3, "golang golang golang"},
		{4, "golang golang golang golang"},
	}
	engine := NewBM25Engine(docs, extractText)

	results := engine.Search("golang", 10)
	if n := len(results); n != 4 {
		t.Fatalf("expected 4 results, got %d", n)
	}

	// Scores must be non-increasing: no result may score higher than the one
	// immediately before it (equal neighbouring scores are accepted).
	for prev := 0; prev+1 < len(results); prev++ {
		cur := prev + 1
		if results[cur].Score > results[prev].Score {
			t.Errorf("results not sorted correctly: result %d score (%v) > result %d score (%v)",
				cur, results[cur].Score, prev, results[prev].Score)
		}
	}
}
+13
View File
@@ -2,9 +2,18 @@ package utils
import (
"strings"
"sync/atomic"
"unicode"
)
// disableTruncation is a package-level flag. While it is true, Truncate
// returns its input unchanged instead of shortening it. It is written through
// SetDisableTruncation and accessed atomically.
var disableTruncation atomic.Bool
// SetDisableTruncation sets the package-wide flag that switches string
// truncation off. Passing true disables truncation, so Truncate returns its
// input unchanged; passing false lets Truncate shorten strings again.
func SetDisableTruncation(enabled bool) {
	// Note the polarity: enabled == true means truncation is DISABLED.
	disableTruncation.Store(enabled)
}
// SanitizeMessageContent removes Unicode control characters, format characters (RTL overrides,
// zero-width characters), and other non-graphic characters that could confuse an LLM
// or cause display issues in the agent UI.
@@ -30,6 +39,10 @@ func SanitizeMessageContent(input string) string {
// Handles multi-byte Unicode characters properly.
// If the string is truncated, "..." is appended to indicate truncation.
func Truncate(s string, maxLen int) string {
// If the no-truncate flag is active, it returns the full string
if disableTruncation.Load() {
return s
}
if maxLen <= 0 {
return ""
}