Merge remote-tracking branch 'origin_picoclaw/main'

This commit is contained in:
zihan987
2026-03-04 09:58:20 -08:00
88 changed files with 9366 additions and 1329 deletions
+123 -58
View File
@@ -34,6 +34,11 @@ type ContextBuilder struct {
// created (didn't exist at cache time, now exist) or deleted (existed at
// cache time, now gone) — both of which should trigger a cache rebuild.
existedAtCache map[string]bool
// skillFilesAtCache snapshots the skill tree file set and mtimes at cache
// build time. This catches nested file creations/deletions/mtime changes
// that may not update the top-level skill root directory mtime.
skillFilesAtCache map[string]time.Time
}
func getGlobalConfigDir() string {
@@ -47,8 +52,11 @@ func getGlobalConfigDir() string {
func NewContextBuilder(workspace string) *ContextBuilder {
// builtin skills: skills directory in current project
// Use the skills/ directory under the current working directory
wd, _ := os.Getwd()
builtinSkillsDir := filepath.Join(wd, "skills")
builtinSkillsDir := strings.TrimSpace(os.Getenv("PICOCLAW_BUILTIN_SKILLS"))
if builtinSkillsDir == "" {
wd, _ := os.Getwd()
builtinSkillsDir = filepath.Join(wd, "skills")
}
globalSkillsDir := filepath.Join(getGlobalConfigDir(), "skills")
return &ContextBuilder{
@@ -148,6 +156,7 @@ func (cb *ContextBuilder) BuildSystemPromptWithCache() string {
cb.cachedSystemPrompt = prompt
cb.cachedAt = baseline.maxMtime
cb.existedAtCache = baseline.existed
cb.skillFilesAtCache = baseline.skillFiles
logger.DebugCF("agent", "System prompt cached",
map[string]any{
@@ -167,14 +176,14 @@ func (cb *ContextBuilder) InvalidateCache() {
cb.cachedSystemPrompt = ""
cb.cachedAt = time.Time{}
cb.existedAtCache = nil
cb.skillFilesAtCache = nil
logger.DebugCF("agent", "System prompt cache invalidated", nil)
}
// sourcePaths returns the workspace source file paths tracked for cache
// invalidation (bootstrap files + memory). The skills directory is handled
// separately in sourceFilesChangedLocked because it requires both directory-
// level and recursive file-level mtime checks.
// sourcePaths returns non-skill workspace source files tracked for cache
// invalidation (bootstrap files + memory). Skill roots are handled separately
// because they require both directory-level and recursive file-level checks.
func (cb *ContextBuilder) sourcePaths() []string {
return []string{
filepath.Join(cb.workspace, "AGENTS.md"),
@@ -185,23 +194,39 @@ func (cb *ContextBuilder) sourcePaths() []string {
}
}
// skillRoots returns all skill root directories that can affect
// BuildSkillsSummary output (workspace/global/builtin).
func (cb *ContextBuilder) skillRoots() []string {
if cb.skillsLoader == nil {
return []string{filepath.Join(cb.workspace, "skills")}
}
roots := cb.skillsLoader.SkillRoots()
if len(roots) == 0 {
return []string{filepath.Join(cb.workspace, "skills")}
}
return roots
}
// cacheBaseline holds the file existence snapshot and the latest observed
// mtime across all tracked paths. Used as the cache reference point.
type cacheBaseline struct {
existed map[string]bool
maxMtime time.Time
existed map[string]bool
skillFiles map[string]time.Time
maxMtime time.Time
}
// buildCacheBaseline records which tracked paths currently exist and computes
// the latest mtime across all tracked files + skills directory contents.
// Called under write lock when the cache is built.
func (cb *ContextBuilder) buildCacheBaseline() cacheBaseline {
skillsDir := filepath.Join(cb.workspace, "skills")
skillRoots := cb.skillRoots()
// All paths whose existence we track: source files + skills dir.
allPaths := append(cb.sourcePaths(), skillsDir)
// All paths whose existence we track: source files + all skill roots.
allPaths := append(cb.sourcePaths(), skillRoots...)
existed := make(map[string]bool, len(allPaths))
skillFiles := make(map[string]time.Time)
var maxMtime time.Time
for _, p := range allPaths {
@@ -212,17 +237,21 @@ func (cb *ContextBuilder) buildCacheBaseline() cacheBaseline {
}
}
// Walk skills files to capture their mtimes too.
// Use os.Stat (not d.Info) to match the stat method used in
// fileChangedSince / skillFilesModifiedSince for consistency.
_ = filepath.WalkDir(skillsDir, func(path string, d fs.DirEntry, walkErr error) error {
if walkErr == nil && !d.IsDir() {
if info, err := os.Stat(path); err == nil && info.ModTime().After(maxMtime) {
maxMtime = info.ModTime()
// Walk all skill roots recursively to snapshot skill files and mtimes.
// Use os.Stat (not d.Info) for consistency with sourceFilesChanged checks.
for _, root := range skillRoots {
_ = filepath.WalkDir(root, func(path string, d fs.DirEntry, walkErr error) error {
if walkErr == nil && !d.IsDir() {
if info, err := os.Stat(path); err == nil {
skillFiles[path] = info.ModTime()
if info.ModTime().After(maxMtime) {
maxMtime = info.ModTime()
}
}
}
}
return nil
})
return nil
})
}
// If no tracked files exist yet (empty workspace), maxMtime is zero.
// Use a very old non-zero time so that:
@@ -234,7 +263,7 @@ func (cb *ContextBuilder) buildCacheBaseline() cacheBaseline {
maxMtime = time.Unix(1, 0)
}
return cacheBaseline{existed: existed, maxMtime: maxMtime}
return cacheBaseline{existed: existed, skillFiles: skillFiles, maxMtime: maxMtime}
}
// sourceFilesChangedLocked checks whether any workspace source file has been
@@ -254,21 +283,17 @@ func (cb *ContextBuilder) sourceFilesChangedLocked() bool {
return true
}
// --- Skills directory (handled separately from sourcePaths) ---
// --- Skill roots (workspace/global/builtin) ---
//
// 1. Creation/deletion: tracked via existedAtCache, same as bootstrap files.
skillsDir := filepath.Join(cb.workspace, "skills")
if cb.fileChangedSince(skillsDir) {
return true
// For each root:
// 1. Creation/deletion and root directory mtime changes are tracked by fileChangedSince.
// 2. Nested file create/delete/mtime changes are tracked by the skill file snapshot.
for _, root := range cb.skillRoots() {
if cb.fileChangedSince(root) {
return true
}
}
// 2. Structural changes (add/remove entries inside the dir) are reflected
// in the directory's own mtime, which fileChangedSince already checks.
//
// 3. Content-only edits to files inside skills/ do NOT update the parent
// directory mtime on most filesystems, so we recursively walk to check
// individual file mtimes at any nesting depth.
if skillFilesModifiedSince(skillsDir, cb.cachedAt) {
if skillFilesChangedSince(cb.skillRoots(), cb.skillFilesAtCache) {
return true
}
@@ -309,28 +334,64 @@ func (cb *ContextBuilder) fileChangedSince(path string) bool {
// if the callback returned nil when its err parameter is non-nil.
var errWalkStop = errors.New("walk stop")
// skillFilesModifiedSince recursively walks the skills directory and checks
// whether any file was modified after t. This catches content-only edits at
// any nesting depth (e.g. skills/name/docs/extra.md) that don't update
// parent directory mtimes.
func skillFilesModifiedSince(skillsDir string, t time.Time) bool {
changed := false
err := filepath.WalkDir(skillsDir, func(path string, d fs.DirEntry, walkErr error) error {
if walkErr == nil && !d.IsDir() {
if info, statErr := os.Stat(path); statErr == nil && info.ModTime().After(t) {
changed = true
return errWalkStop // stop walking
}
}
return nil
})
// errWalkStop is expected (early exit on first changed file).
// os.IsNotExist means the skills dir doesn't exist yet — not an error.
// Any other error is unexpected and worth logging.
if err != nil && !errors.Is(err, errWalkStop) && !os.IsNotExist(err) {
logger.DebugCF("agent", "skills walk error", map[string]any{"error": err.Error()})
// skillFilesChangedSince compares the current recursive skill file tree
// against the cache-time snapshot. Any create/delete/mtime drift invalidates
// the cache.
func skillFilesChangedSince(skillRoots []string, filesAtCache map[string]time.Time) bool {
// Defensive: if the snapshot was never initialized, force rebuild.
if filesAtCache == nil {
return true
}
return changed
// Check cached files still exist and keep the same mtime.
for path, cachedMtime := range filesAtCache {
info, err := os.Stat(path)
if err != nil {
// A previously tracked file disappeared (or became inaccessible):
// either way, cached skill summary may now be stale.
return true
}
if !info.ModTime().Equal(cachedMtime) {
return true
}
}
// Check no new files appeared under any skill root.
changed := false
for _, root := range skillRoots {
if strings.TrimSpace(root) == "" {
continue
}
err := filepath.WalkDir(root, func(path string, d fs.DirEntry, walkErr error) error {
if walkErr != nil {
// Treat unexpected walk errors as changed to avoid stale cache.
if !os.IsNotExist(walkErr) {
changed = true
return errWalkStop
}
return nil
}
if d.IsDir() {
return nil
}
if _, ok := filesAtCache[path]; !ok {
changed = true
return errWalkStop
}
return nil
})
if changed {
return true
}
if err != nil && !errors.Is(err, errWalkStop) && !os.IsNotExist(err) {
logger.DebugCF("agent", "skills walk error", map[string]any{"error": err.Error()})
return true
}
}
return false
}
func (cb *ContextBuilder) LoadBootstrapFiles() string {
@@ -466,10 +527,14 @@ func (cb *ContextBuilder) BuildMessages(
// Add current user message
if strings.TrimSpace(currentMessage) != "" {
messages = append(messages, providers.Message{
msg := providers.Message{
Role: "user",
Content: currentMessage,
})
}
if len(media) > 0 {
msg.Media = media
}
messages = append(messages, msg)
}
return messages
+156
View File
@@ -383,6 +383,162 @@ Updated content.`
}
}
// TestGlobalSkillFileContentChange verifies that modifying a global skill
// (~/.picoclaw/skills) invalidates the cached system prompt.
func TestGlobalSkillFileContentChange(t *testing.T) {
tmpHome := t.TempDir()
t.Setenv("HOME", tmpHome)
tmpDir := setupWorkspace(t, nil)
defer os.RemoveAll(tmpDir)
globalSkillPath := filepath.Join(tmpHome, ".picoclaw", "skills", "global-skill", "SKILL.md")
if err := os.MkdirAll(filepath.Dir(globalSkillPath), 0o755); err != nil {
t.Fatal(err)
}
v1 := `---
name: global-skill
description: global-v1
---
# Global Skill v1`
if err := os.WriteFile(globalSkillPath, []byte(v1), 0o644); err != nil {
t.Fatal(err)
}
cb := NewContextBuilder(tmpDir)
sp1 := cb.BuildSystemPromptWithCache()
if !strings.Contains(sp1, "global-v1") {
t.Fatal("expected initial prompt to contain global skill description")
}
v2 := `---
name: global-skill
description: global-v2
---
# Global Skill v2`
if err := os.WriteFile(globalSkillPath, []byte(v2), 0o644); err != nil {
t.Fatal(err)
}
future := time.Now().Add(2 * time.Second)
if err := os.Chtimes(globalSkillPath, future, future); err != nil {
t.Fatalf("failed to update mtime for %s: %v", globalSkillPath, err)
}
cb.systemPromptMutex.RLock()
changed := cb.sourceFilesChangedLocked()
cb.systemPromptMutex.RUnlock()
if !changed {
t.Fatal("sourceFilesChangedLocked() should detect global skill file content change")
}
sp2 := cb.BuildSystemPromptWithCache()
if !strings.Contains(sp2, "global-v2") {
t.Error("rebuilt prompt should contain updated global skill description")
}
if sp1 == sp2 {
t.Error("cache should be invalidated when global skill file content changes")
}
}
// TestBuiltinSkillFileContentChange verifies that modifying a builtin skill
// invalidates the cached system prompt.
func TestBuiltinSkillFileContentChange(t *testing.T) {
tmpHome := t.TempDir()
t.Setenv("HOME", tmpHome)
tmpDir := setupWorkspace(t, nil)
defer os.RemoveAll(tmpDir)
builtinRoot := t.TempDir()
t.Setenv("PICOCLAW_BUILTIN_SKILLS", builtinRoot)
builtinSkillPath := filepath.Join(builtinRoot, "builtin-skill", "SKILL.md")
if err := os.MkdirAll(filepath.Dir(builtinSkillPath), 0o755); err != nil {
t.Fatal(err)
}
v1 := `---
name: builtin-skill
description: builtin-v1
---
# Builtin Skill v1`
if err := os.WriteFile(builtinSkillPath, []byte(v1), 0o644); err != nil {
t.Fatal(err)
}
cb := NewContextBuilder(tmpDir)
sp1 := cb.BuildSystemPromptWithCache()
if !strings.Contains(sp1, "builtin-v1") {
t.Fatal("expected initial prompt to contain builtin skill description")
}
v2 := `---
name: builtin-skill
description: builtin-v2
---
# Builtin Skill v2`
if err := os.WriteFile(builtinSkillPath, []byte(v2), 0o644); err != nil {
t.Fatal(err)
}
future := time.Now().Add(2 * time.Second)
if err := os.Chtimes(builtinSkillPath, future, future); err != nil {
t.Fatalf("failed to update mtime for %s: %v", builtinSkillPath, err)
}
cb.systemPromptMutex.RLock()
changed := cb.sourceFilesChangedLocked()
cb.systemPromptMutex.RUnlock()
if !changed {
t.Fatal("sourceFilesChangedLocked() should detect builtin skill file content change")
}
sp2 := cb.BuildSystemPromptWithCache()
if !strings.Contains(sp2, "builtin-v2") {
t.Error("rebuilt prompt should contain updated builtin skill description")
}
if sp1 == sp2 {
t.Error("cache should be invalidated when builtin skill file content changes")
}
}
// TestSkillFileDeletionInvalidatesCache verifies that deleting a nested skill
// file invalidates the cached system prompt.
func TestSkillFileDeletionInvalidatesCache(t *testing.T) {
tmpDir := setupWorkspace(t, map[string]string{
"skills/delete-me/SKILL.md": `---
name: delete-me
description: delete-me-v1
---
# Delete Me`,
})
defer os.RemoveAll(tmpDir)
cb := NewContextBuilder(tmpDir)
sp1 := cb.BuildSystemPromptWithCache()
if !strings.Contains(sp1, "delete-me-v1") {
t.Fatal("expected initial prompt to contain skill description")
}
skillPath := filepath.Join(tmpDir, "skills", "delete-me", "SKILL.md")
if err := os.Remove(skillPath); err != nil {
t.Fatal(err)
}
cb.systemPromptMutex.RLock()
changed := cb.sourceFilesChangedLocked()
cb.systemPromptMutex.RUnlock()
if !changed {
t.Fatal("sourceFilesChangedLocked() should detect deleted skill file")
}
sp2 := cb.BuildSystemPromptWithCache()
if strings.Contains(sp2, "delete-me-v1") {
t.Error("rebuilt prompt should not contain deleted skill description")
}
if sp1 == sp2 {
t.Error("cache should be invalidated when skill file is deleted")
}
}
// TestConcurrentBuildSystemPromptWithCache verifies that multiple goroutines
// can safely call BuildSystemPromptWithCache concurrently without producing
// empty results, panics, or data races.
+46 -32
View File
@@ -18,22 +18,24 @@ import (
// AgentInstance represents a fully configured agent with its own workspace,
// session manager, context builder, and tool registry.
type AgentInstance struct {
ID string
Name string
Model string
Fallbacks []string
Workspace string
MaxIterations int
MaxTokens int
Temperature float64
ContextWindow int
Provider providers.LLMProvider
Sessions *session.SessionManager
ContextBuilder *ContextBuilder
Tools *tools.ToolRegistry
Subagents *config.SubagentsConfig
SkillsFilter []string
Candidates []providers.FallbackCandidate
ID string
Name string
Model string
Fallbacks []string
Workspace string
MaxIterations int
MaxTokens int
Temperature float64
ContextWindow int
SummarizeMessageThreshold int
SummarizeTokenPercent int
Provider providers.LLMProvider
Sessions *session.SessionManager
ContextBuilder *ContextBuilder
Tools *tools.ToolRegistry
Subagents *config.SubagentsConfig
SkillsFilter []string
Candidates []providers.FallbackCandidate
}
// NewAgentInstance creates an agent instance from config.
@@ -101,6 +103,16 @@ func NewAgentInstance(
temperature = *defaults.Temperature
}
summarizeMessageThreshold := defaults.SummarizeMessageThreshold
if summarizeMessageThreshold == 0 {
summarizeMessageThreshold = 20
}
summarizeTokenPercent := defaults.SummarizeTokenPercent
if summarizeTokenPercent == 0 {
summarizeTokenPercent = 75
}
// Resolve fallback candidates
modelCfg := providers.ModelConfig{
Primary: model,
@@ -149,22 +161,24 @@ func NewAgentInstance(
candidates := providers.ResolveCandidatesWithLookup(modelCfg, defaults.Provider, resolveFromModelList)
return &AgentInstance{
ID: agentID,
Name: agentName,
Model: model,
Fallbacks: fallbacks,
Workspace: workspace,
MaxIterations: maxIter,
MaxTokens: maxTokens,
Temperature: temperature,
ContextWindow: maxTokens,
Provider: provider,
Sessions: sessionsManager,
ContextBuilder: contextBuilder,
Tools: toolsRegistry,
Subagents: subagents,
SkillsFilter: skillsFilter,
Candidates: candidates,
ID: agentID,
Name: agentName,
Model: model,
Fallbacks: fallbacks,
Workspace: workspace,
MaxIterations: maxIter,
MaxTokens: maxTokens,
Temperature: temperature,
ContextWindow: maxTokens,
SummarizeMessageThreshold: summarizeMessageThreshold,
SummarizeTokenPercent: summarizeTokenPercent,
Provider: provider,
Sessions: sessionsManager,
ContextBuilder: contextBuilder,
Tools: toolsRegistry,
Subagents: subagents,
SkillsFilter: skillsFilter,
Candidates: candidates,
}
}
+58 -65
View File
@@ -95,75 +95,68 @@ func TestNewAgentInstance_DefaultsTemperatureWhenUnset(t *testing.T) {
}
func TestNewAgentInstance_ResolveCandidatesFromModelListAlias(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "agent-instance-test-*")
if err != nil {
t.Fatalf("Failed to create temp dir: %v", err)
}
defer os.RemoveAll(tmpDir)
cfg := &config.Config{
Agents: config.AgentsConfig{
Defaults: config.AgentDefaults{
Workspace: tmpDir,
Model: "step-3.5-flash",
},
tests := []struct {
name string
aliasName string
modelName string
apiBase string
wantProvider string
wantModel string
}{
{
name: "alias with provider prefix",
aliasName: "step-3.5-flash",
modelName: "openrouter/stepfun/step-3.5-flash:free",
apiBase: "https://openrouter.ai/api/v1",
wantProvider: "openrouter",
wantModel: "stepfun/step-3.5-flash:free",
},
ModelList: []config.ModelConfig{
{
ModelName: "step-3.5-flash",
Model: "openrouter/stepfun/step-3.5-flash:free",
APIBase: "https://openrouter.ai/api/v1",
},
{
name: "alias without provider prefix",
aliasName: "glm-5",
modelName: "glm-5",
apiBase: "https://api.z.ai/api/coding/paas/v4",
wantProvider: "openai",
wantModel: "glm-5",
},
}
provider := &mockProvider{}
agent := NewAgentInstance(nil, &cfg.Agents.Defaults, cfg, provider)
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "agent-instance-test-*")
if err != nil {
t.Fatalf("Failed to create temp dir: %v", err)
}
defer os.RemoveAll(tmpDir)
if len(agent.Candidates) != 1 {
t.Fatalf("len(Candidates) = %d, want 1", len(agent.Candidates))
}
if agent.Candidates[0].Provider != "openrouter" {
t.Fatalf("candidate provider = %q, want %q", agent.Candidates[0].Provider, "openrouter")
}
if agent.Candidates[0].Model != "stepfun/step-3.5-flash:free" {
t.Fatalf("candidate model = %q, want %q", agent.Candidates[0].Model, "stepfun/step-3.5-flash:free")
}
}
func TestNewAgentInstance_ResolveCandidatesFromModelListAliasWithoutProtocol(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "agent-instance-test-*")
if err != nil {
t.Fatalf("Failed to create temp dir: %v", err)
}
defer os.RemoveAll(tmpDir)
cfg := &config.Config{
Agents: config.AgentsConfig{
Defaults: config.AgentDefaults{
Workspace: tmpDir,
Model: "glm-5",
},
},
ModelList: []config.ModelConfig{
{
ModelName: "glm-5",
Model: "glm-5",
APIBase: "https://api.z.ai/api/coding/paas/v4",
},
},
}
provider := &mockProvider{}
agent := NewAgentInstance(nil, &cfg.Agents.Defaults, cfg, provider)
if len(agent.Candidates) != 1 {
t.Fatalf("len(Candidates) = %d, want 1", len(agent.Candidates))
}
if agent.Candidates[0].Provider != "openai" {
t.Fatalf("candidate provider = %q, want %q", agent.Candidates[0].Provider, "openai")
}
if agent.Candidates[0].Model != "glm-5" {
t.Fatalf("candidate model = %q, want %q", agent.Candidates[0].Model, "glm-5")
cfg := &config.Config{
Agents: config.AgentsConfig{
Defaults: config.AgentDefaults{
Workspace: tmpDir,
Model: tt.aliasName,
},
},
ModelList: []config.ModelConfig{
{
ModelName: tt.aliasName,
Model: tt.modelName,
APIBase: tt.apiBase,
},
},
}
provider := &mockProvider{}
agent := NewAgentInstance(nil, &cfg.Agents.Defaults, cfg, provider)
if len(agent.Candidates) != 1 {
t.Fatalf("len(Candidates) = %d, want 1", len(agent.Candidates))
}
if agent.Candidates[0].Provider != tt.wantProvider {
t.Fatalf("candidate provider = %q, want %q", agent.Candidates[0].Provider, tt.wantProvider)
}
if agent.Candidates[0].Model != tt.wantModel {
t.Fatalf("candidate model = %q, want %q", agent.Candidates[0].Model, tt.wantModel)
}
})
}
}
+297 -82
View File
@@ -12,6 +12,7 @@ import (
"errors"
"fmt"
"path/filepath"
"regexp"
"strings"
"sync"
"sync/atomic"
@@ -23,6 +24,7 @@ import (
"github.com/sipeed/picoclaw/pkg/config"
"github.com/sipeed/picoclaw/pkg/constants"
"github.com/sipeed/picoclaw/pkg/logger"
"github.com/sipeed/picoclaw/pkg/mcp"
"github.com/sipeed/picoclaw/pkg/media"
"github.com/sipeed/picoclaw/pkg/providers"
"github.com/sipeed/picoclaw/pkg/routing"
@@ -30,6 +32,7 @@ import (
"github.com/sipeed/picoclaw/pkg/state"
"github.com/sipeed/picoclaw/pkg/tools"
"github.com/sipeed/picoclaw/pkg/utils"
"github.com/sipeed/picoclaw/pkg/voice"
)
type AgentLoop struct {
@@ -42,23 +45,29 @@ type AgentLoop struct {
fallback *providers.FallbackChain
channelManager *channels.Manager
mediaStore media.MediaStore
transcriber voice.Transcriber
}
// processOptions configures how a message is processed
type processOptions struct {
SessionKey string // Session identifier for history/context
Channel string // Target channel for tool execution
ChatID string // Target chat ID for tool execution
UserMessage string // User message content (may include prefix)
DefaultResponse string // Response when LLM returns empty
EnableSummary bool // Whether to trigger summarization
SendResponse bool // Whether to send response via bus
NoHistory bool // If true, don't load session history (for heartbeat)
SessionKey string // Session identifier for history/context
Channel string // Target channel for tool execution
ChatID string // Target chat ID for tool execution
UserMessage string // User message content (may include prefix)
Media []string // media:// refs from inbound message
DefaultResponse string // Response when LLM returns empty
EnableSummary bool // Whether to trigger summarization
SendResponse bool // Whether to send response via bus
NoHistory bool // If true, don't load session history (for heartbeat)
}
const defaultResponse = "I've completed processing but have no response to give. Increase `max_tool_iterations` in config.json."
func NewAgentLoop(cfg *config.Config, msgBus *bus.MessageBus, provider providers.LLMProvider) *AgentLoop {
func NewAgentLoop(
cfg *config.Config,
msgBus *bus.MessageBus,
provider providers.LLMProvider,
) *AgentLoop {
registry := NewAgentRegistry(cfg, provider)
// Register shared tools to all agents
@@ -112,6 +121,11 @@ func registerSharedTools(
PerplexityAPIKey: cfg.Tools.Web.Perplexity.APIKey,
PerplexityMaxResults: cfg.Tools.Web.Perplexity.MaxResults,
PerplexityEnabled: cfg.Tools.Web.Perplexity.Enabled,
GLMSearchAPIKey: cfg.Tools.Web.GLMSearch.APIKey,
GLMSearchBaseURL: cfg.Tools.Web.GLMSearch.BaseURL,
GLMSearchEngine: cfg.Tools.Web.GLMSearch.SearchEngine,
GLMSearchMaxResults: cfg.Tools.Web.GLMSearch.MaxResults,
GLMSearchEnabled: cfg.Tools.Web.GLMSearch.Enabled,
Proxy: cfg.Tools.Web.Proxy,
})
if err != nil {
@@ -170,6 +184,72 @@ func registerSharedTools(
func (al *AgentLoop) Run(ctx context.Context) error {
al.running.Store(true)
// Initialize MCP servers for all agents
if al.cfg.Tools.MCP.Enabled {
mcpManager := mcp.NewManager()
// Ensure MCP connections are cleaned up on exit, regardless of initialization success
// This fixes resource leak when LoadFromMCPConfig partially succeeds then fails
defer func() {
if err := mcpManager.Close(); err != nil {
logger.ErrorCF("agent", "Failed to close MCP manager",
map[string]any{
"error": err.Error(),
})
}
}()
defaultAgent := al.registry.GetDefaultAgent()
var workspacePath string
if defaultAgent != nil && defaultAgent.Workspace != "" {
workspacePath = defaultAgent.Workspace
} else {
workspacePath = al.cfg.WorkspacePath()
}
if err := mcpManager.LoadFromMCPConfig(ctx, al.cfg.Tools.MCP, workspacePath); err != nil {
logger.WarnCF("agent", "Failed to load MCP servers, MCP tools will not be available",
map[string]any{
"error": err.Error(),
})
} else {
// Register MCP tools for all agents
servers := mcpManager.GetServers()
uniqueTools := 0
totalRegistrations := 0
agentIDs := al.registry.ListAgentIDs()
agentCount := len(agentIDs)
for serverName, conn := range servers {
uniqueTools += len(conn.Tools)
for _, tool := range conn.Tools {
for _, agentID := range agentIDs {
agent, ok := al.registry.GetAgent(agentID)
if !ok {
continue
}
mcpTool := tools.NewMCPTool(mcpManager, serverName, tool)
agent.Tools.Register(mcpTool)
totalRegistrations++
logger.DebugCF("agent", "Registered MCP tool",
map[string]any{
"agent_id": agentID,
"server": serverName,
"tool": tool.Name,
"name": mcpTool.Name(),
})
}
}
}
logger.InfoCF("agent", "MCP tools registered successfully",
map[string]any{
"server_count": len(servers),
"unique_tools": uniqueTools,
"total_registrations": totalRegistrations,
"agent_count": agentCount,
})
}
}
for al.running.Load() {
select {
case <-ctx.Done():
@@ -262,6 +342,64 @@ func (al *AgentLoop) SetMediaStore(s media.MediaStore) {
al.mediaStore = s
}
// SetTranscriber injects a voice transcriber for agent-level audio transcription.
func (al *AgentLoop) SetTranscriber(t voice.Transcriber) {
al.transcriber = t
}
var audioAnnotationRe = regexp.MustCompile(`\[(voice|audio)(?::[^\]]*)?\]`)
// transcribeAudioInMessage resolves audio media refs, transcribes them, and
// replaces audio annotations in msg.Content with the transcribed text.
func (al *AgentLoop) transcribeAudioInMessage(ctx context.Context, msg bus.InboundMessage) bus.InboundMessage {
if al.transcriber == nil || al.mediaStore == nil || len(msg.Media) == 0 {
return msg
}
// Transcribe each audio media ref in order.
var transcriptions []string
for _, ref := range msg.Media {
path, meta, err := al.mediaStore.ResolveWithMeta(ref)
if err != nil {
logger.WarnCF("voice", "Failed to resolve media ref", map[string]any{"ref": ref, "error": err})
continue
}
if !utils.IsAudioFile(meta.Filename, meta.ContentType) {
continue
}
result, err := al.transcriber.Transcribe(ctx, path)
if err != nil {
logger.WarnCF("voice", "Transcription failed", map[string]any{"ref": ref, "error": err})
transcriptions = append(transcriptions, "")
continue
}
transcriptions = append(transcriptions, result.Text)
}
if len(transcriptions) == 0 {
return msg
}
// Replace audio annotations sequentially with transcriptions.
idx := 0
newContent := audioAnnotationRe.ReplaceAllStringFunc(msg.Content, func(match string) string {
if idx >= len(transcriptions) {
return match
}
text := transcriptions[idx]
idx++
return "[voice: " + text + "]"
})
// Append any remaining transcriptions not matched by an annotation.
for ; idx < len(transcriptions); idx++ {
newContent += "\n[voice: " + transcriptions[idx] + "]"
}
msg.Content = newContent
return msg
}
// inferMediaType determines the media type ("image", "audio", "video", "file")
// from a filename and MIME content type.
func inferMediaType(filename, contentType string) string {
@@ -310,7 +448,10 @@ func (al *AgentLoop) RecordLastChatID(chatID string) error {
return al.state.SetLastChatID(chatID)
}
func (al *AgentLoop) ProcessDirect(ctx context.Context, content, sessionKey string) (string, error) {
func (al *AgentLoop) ProcessDirect(
ctx context.Context,
content, sessionKey string,
) (string, error) {
return al.ProcessDirectWithChannel(ctx, content, sessionKey, "cli", "direct")
}
@@ -331,7 +472,10 @@ func (al *AgentLoop) ProcessDirectWithChannel(
// ProcessHeartbeat processes a heartbeat request without session history.
// Each heartbeat is independent and doesn't accumulate context.
func (al *AgentLoop) ProcessHeartbeat(ctx context.Context, content, channel, chatID string) (string, error) {
func (al *AgentLoop) ProcessHeartbeat(
ctx context.Context,
content, channel, chatID string,
) (string, error) {
agent := al.registry.GetDefaultAgent()
if agent == nil {
return "", fmt.Errorf("no default agent for heartbeat")
@@ -356,13 +500,18 @@ func (al *AgentLoop) processMessage(ctx context.Context, msg bus.InboundMessage)
} else {
logContent = utils.Truncate(msg.Content, 80)
}
logger.InfoCF("agent", fmt.Sprintf("Processing message from %s:%s: %s", msg.Channel, msg.SenderID, logContent),
logger.InfoCF(
"agent",
fmt.Sprintf("Processing message from %s:%s: %s", msg.Channel, msg.SenderID, logContent),
map[string]any{
"channel": msg.Channel,
"chat_id": msg.ChatID,
"sender_id": msg.SenderID,
"session_key": msg.SessionKey,
})
},
)
msg = al.transcribeAudioInMessage(ctx, msg)
// Route system messages to processSystemMessage
if msg.Channel == "system" {
@@ -417,15 +566,22 @@ func (al *AgentLoop) processMessage(ctx context.Context, msg bus.InboundMessage)
Channel: msg.Channel,
ChatID: msg.ChatID,
UserMessage: msg.Content,
Media: msg.Media,
DefaultResponse: defaultResponse,
EnableSummary: true,
SendResponse: false,
})
}
func (al *AgentLoop) processSystemMessage(ctx context.Context, msg bus.InboundMessage) (string, error) {
func (al *AgentLoop) processSystemMessage(
ctx context.Context,
msg bus.InboundMessage,
) (string, error) {
if msg.Channel != "system" {
return "", fmt.Errorf("processSystemMessage called with non-system message channel: %s", msg.Channel)
return "", fmt.Errorf(
"processSystemMessage called with non-system message channel: %s",
msg.Channel,
)
}
logger.InfoCF("agent", "Processing system message",
@@ -483,14 +639,22 @@ func (al *AgentLoop) processSystemMessage(ctx context.Context, msg bus.InboundMe
}
// runAgentLoop is the core message processing logic.
func (al *AgentLoop) runAgentLoop(ctx context.Context, agent *AgentInstance, opts processOptions) (string, error) {
func (al *AgentLoop) runAgentLoop(
ctx context.Context,
agent *AgentInstance,
opts processOptions,
) (string, error) {
// 0. Record last channel for heartbeat notifications (skip internal channels)
if opts.Channel != "" && opts.ChatID != "" {
// Don't record internal channels (cli, system, subagent)
if !constants.IsInternalChannel(opts.Channel) {
channelKey := fmt.Sprintf("%s:%s", opts.Channel, opts.ChatID)
if err := al.RecordLastChannel(channelKey); err != nil {
logger.WarnCF("agent", "Failed to record last channel", map[string]any{"error": err.Error()})
logger.WarnCF(
"agent",
"Failed to record last channel",
map[string]any{"error": err.Error()},
)
}
}
}
@@ -509,11 +673,15 @@ func (al *AgentLoop) runAgentLoop(ctx context.Context, agent *AgentInstance, opt
history,
summary,
opts.UserMessage,
nil,
opts.Media,
opts.Channel,
opts.ChatID,
)
// Resolve media:// refs to base64 data URLs (streaming)
maxMediaSize := al.cfg.Agents.Defaults.GetMaxMediaSize()
messages = resolveMediaRefs(messages, al.mediaStore, maxMediaSize)
// 3. Save user message to session
agent.Sessions.AddMessage(opts.SessionKey, "user", opts.UserMessage)
@@ -572,7 +740,10 @@ func (al *AgentLoop) targetReasoningChannelID(channelName string) (chatID string
return ""
}
func (al *AgentLoop) handleReasoning(ctx context.Context, reasoningContent, channelName, channelID string) {
func (al *AgentLoop) handleReasoning(
ctx context.Context,
reasoningContent, channelName, channelID string,
) {
if reasoningContent == "" || channelName == "" || channelID == "" {
return
}
@@ -665,22 +836,33 @@ func (al *AgentLoop) runLLMIteration(
callLLM := func() (*providers.LLMResponse, error) {
if len(agent.Candidates) > 1 && al.fallback != nil {
fbResult, fbErr := al.fallback.Execute(ctx, agent.Candidates,
fbResult, fbErr := al.fallback.Execute(
ctx,
agent.Candidates,
func(ctx context.Context, provider, model string) (*providers.LLMResponse, error) {
return agent.Provider.Chat(ctx, messages, providerToolDefs, model, map[string]any{
"max_tokens": agent.MaxTokens,
"temperature": agent.Temperature,
"prompt_cache_key": agent.ID,
})
return agent.Provider.Chat(
ctx,
messages,
providerToolDefs,
model,
map[string]any{
"max_tokens": agent.MaxTokens,
"temperature": agent.Temperature,
"prompt_cache_key": agent.ID,
},
)
},
)
if fbErr != nil {
return nil, fbErr
}
if fbResult.Provider != "" && len(fbResult.Attempts) > 0 {
logger.InfoCF("agent", fmt.Sprintf("Fallback: succeeded with %s/%s after %d attempts",
fbResult.Provider, fbResult.Model, len(fbResult.Attempts)+1),
map[string]any{"agent_id": agent.ID, "iteration": iteration})
logger.InfoCF(
"agent",
fmt.Sprintf("Fallback: succeeded with %s/%s after %d attempts",
fbResult.Provider, fbResult.Model, len(fbResult.Attempts)+1),
map[string]any{"agent_id": agent.ID, "iteration": iteration},
)
}
return fbResult.Response, nil
}
@@ -731,10 +913,14 @@ func (al *AgentLoop) runLLMIteration(
}
if isContextError && retry < maxRetries {
logger.WarnCF("agent", "Context window error detected, attempting compression", map[string]any{
"error": err.Error(),
"retry": retry,
})
logger.WarnCF(
"agent",
"Context window error detected, attempting compression",
map[string]any{
"error": err.Error(),
"retry": retry,
},
)
if retry == 0 && !constants.IsInternalChannel(opts.Channel) {
al.bus.PublishOutbound(ctx, bus.OutboundMessage{
@@ -766,7 +952,12 @@ func (al *AgentLoop) runLLMIteration(
return "", iteration, fmt.Errorf("LLM call failed after retries: %w", err)
}
go al.handleReasoning(ctx, response.Reasoning, opts.Channel, al.targetReasoningChannelID(opts.Channel))
go al.handleReasoning(
ctx,
response.Reasoning,
opts.Channel,
al.targetReasoningChannelID(opts.Channel),
)
logger.DebugCF("agent", "LLM response",
map[string]any{
@@ -841,62 +1032,76 @@ func (al *AgentLoop) runLLMIteration(
// Save assistant message with tool calls to session
agent.Sessions.AddFullMessage(opts.SessionKey, assistantMsg)
// Execute tool calls
for _, tc := range normalizedToolCalls {
argsJSON, _ := json.Marshal(tc.Arguments)
argsPreview := utils.Truncate(string(argsJSON), 200)
logger.InfoCF("agent", fmt.Sprintf("Tool call: %s(%s)", tc.Name, argsPreview),
map[string]any{
"agent_id": agent.ID,
"tool": tc.Name,
"iteration": iteration,
})
// Execute tool calls in parallel
type indexedAgentResult struct {
result *tools.ToolResult
tc providers.ToolCall
}
// Create async callback for tools that implement AsyncTool
// NOTE: Following openclaw's design, async tools do NOT send results directly to users.
// Instead, they notify the agent via PublishInbound, and the agent decides
// whether to forward the result to the user (in processSystemMessage).
asyncCallback := func(callbackCtx context.Context, result *tools.ToolResult) {
// Log the async completion but don't send directly to user
// The agent will handle user notification via processSystemMessage
if !result.Silent && result.ForUser != "" {
logger.InfoCF("agent", "Async tool completed, agent will handle notification",
map[string]any{
"tool": tc.Name,
"content_len": len(result.ForUser),
})
agentResults := make([]indexedAgentResult, len(normalizedToolCalls))
var wg sync.WaitGroup
for i, tc := range normalizedToolCalls {
agentResults[i].tc = tc
wg.Add(1)
go func(idx int, tc providers.ToolCall) {
defer wg.Done()
argsJSON, _ := json.Marshal(tc.Arguments)
argsPreview := utils.Truncate(string(argsJSON), 200)
logger.InfoCF("agent", fmt.Sprintf("Tool call: %s(%s)", tc.Name, argsPreview),
map[string]any{
"agent_id": agent.ID,
"tool": tc.Name,
"iteration": iteration,
})
// Create async callback for tools that implement AsyncTool
asyncCallback := func(callbackCtx context.Context, result *tools.ToolResult) {
if !result.Silent && result.ForUser != "" {
logger.InfoCF("agent", "Async tool completed, agent will handle notification",
map[string]any{
"tool": tc.Name,
"content_len": len(result.ForUser),
})
}
}
}
toolResult := agent.Tools.ExecuteWithContext(
ctx,
tc.Name,
tc.Arguments,
opts.Channel,
opts.ChatID,
asyncCallback,
)
toolResult := agent.Tools.ExecuteWithContext(
ctx,
tc.Name,
tc.Arguments,
opts.Channel,
opts.ChatID,
asyncCallback,
)
agentResults[idx].result = toolResult
}(i, tc)
}
wg.Wait()
// Process results in original order (send to user, save to session)
for _, r := range agentResults {
// Send ForUser content to user immediately if not Silent
if !toolResult.Silent && toolResult.ForUser != "" && opts.SendResponse {
if !r.result.Silent && r.result.ForUser != "" && opts.SendResponse {
al.bus.PublishOutbound(ctx, bus.OutboundMessage{
Channel: opts.Channel,
ChatID: opts.ChatID,
Content: toolResult.ForUser,
Content: r.result.ForUser,
})
logger.DebugCF("agent", "Sent tool result to user",
map[string]any{
"tool": tc.Name,
"content_len": len(toolResult.ForUser),
"tool": r.tc.Name,
"content_len": len(r.result.ForUser),
})
}
// If tool returned media refs, publish them as outbound media
if len(toolResult.Media) > 0 && opts.SendResponse {
parts := make([]bus.MediaPart, 0, len(toolResult.Media))
for _, ref := range toolResult.Media {
if len(r.result.Media) > 0 && opts.SendResponse {
parts := make([]bus.MediaPart, 0, len(r.result.Media))
for _, ref := range r.result.Media {
part := bus.MediaPart{Ref: ref}
// Populate metadata from MediaStore when available
if al.mediaStore != nil {
if _, meta, err := al.mediaStore.ResolveWithMeta(ref); err == nil {
part.Filename = meta.Filename
@@ -914,15 +1119,15 @@ func (al *AgentLoop) runLLMIteration(
}
// Determine content for LLM based on tool result
contentForLLM := toolResult.ForLLM
if contentForLLM == "" && toolResult.Err != nil {
contentForLLM = toolResult.Err.Error()
contentForLLM := r.result.ForLLM
if contentForLLM == "" && r.result.Err != nil {
contentForLLM = r.result.Err.Error()
}
toolResultMsg := providers.Message{
Role: "tool",
Content: contentForLLM,
ToolCallID: tc.ID,
ToolCallID: r.tc.ID,
}
messages = append(messages, toolResultMsg)
@@ -958,9 +1163,9 @@ func (al *AgentLoop) updateToolContexts(agent *AgentInstance, channel, chatID st
func (al *AgentLoop) maybeSummarize(agent *AgentInstance, sessionKey, channel, chatID string) {
newHistory := agent.Sessions.GetHistory(sessionKey)
tokenEstimate := al.estimateTokens(newHistory)
threshold := agent.ContextWindow * 75 / 100
threshold := agent.ContextWindow * agent.SummarizeTokenPercent / 100
if len(newHistory) > 20 || tokenEstimate > threshold {
if len(newHistory) > agent.SummarizeMessageThreshold || tokenEstimate > threshold {
summarizeKey := agent.ID + ":" + sessionKey
if _, loading := al.summarizing.LoadOrStore(summarizeKey, true); !loading {
go func() {
@@ -1068,7 +1273,11 @@ func formatMessagesForLog(messages []providers.Message) string {
for _, tc := range msg.ToolCalls {
fmt.Fprintf(&sb, " - ID: %s, Type: %s, Name: %s\n", tc.ID, tc.Type, tc.Name)
if tc.Function != nil {
fmt.Fprintf(&sb, " Arguments: %s\n", utils.Truncate(tc.Function.Arguments, 200))
fmt.Fprintf(
&sb,
" Arguments: %s\n",
utils.Truncate(tc.Function.Arguments, 200),
)
}
}
}
@@ -1097,7 +1306,11 @@ func formatToolsForLog(toolDefs []providers.ToolDefinition) string {
fmt.Fprintf(&sb, " [%d] Type: %s, Name: %s\n", i, tool.Type, tool.Function.Name)
fmt.Fprintf(&sb, " Description: %s\n", tool.Function.Description)
if len(tool.Function.Parameters) > 0 {
fmt.Fprintf(&sb, " Parameters: %s\n", utils.Truncate(fmt.Sprintf("%v", tool.Function.Parameters), 200))
fmt.Fprintf(
&sb,
" Parameters: %s\n",
utils.Truncate(fmt.Sprintf("%v", tool.Function.Parameters), 200),
)
}
}
sb.WriteString("]")
@@ -1194,7 +1407,9 @@ func (al *AgentLoop) summarizeBatch(
existingSummary string,
) (string, error) {
var sb strings.Builder
sb.WriteString("Provide a concise summary of this conversation segment, preserving core context and key points.\n")
sb.WriteString(
"Provide a concise summary of this conversation segment, preserving core context and key points.\n",
)
if existingSummary != "" {
sb.WriteString("Existing context: ")
sb.WriteString(existingSummary)
+122
View File
@@ -0,0 +1,122 @@
// PicoClaw - Ultra-lightweight personal AI agent
// Inspired by and based on nanobot: https://github.com/HKUDS/nanobot
// License: MIT
//
// Copyright (c) 2026 PicoClaw contributors
package agent
import (
"bytes"
"encoding/base64"
"io"
"os"
"strings"
"github.com/h2non/filetype"
"github.com/sipeed/picoclaw/pkg/logger"
"github.com/sipeed/picoclaw/pkg/media"
"github.com/sipeed/picoclaw/pkg/providers"
)
// resolveMediaRefs replaces media:// refs in message Media fields with base64 data URLs.
// Uses streaming base64 encoding (file handle → encoder → buffer) to avoid holding
// both raw bytes and encoded string in memory simultaneously.
// Returns a new slice; original messages are not mutated.
func resolveMediaRefs(messages []providers.Message, store media.MediaStore, maxSize int) []providers.Message {
if store == nil {
return messages
}
result := make([]providers.Message, len(messages))
copy(result, messages)
for i, m := range result {
if len(m.Media) == 0 {
continue
}
resolved := make([]string, 0, len(m.Media))
for _, ref := range m.Media {
if !strings.HasPrefix(ref, "media://") {
resolved = append(resolved, ref)
continue
}
localPath, meta, err := store.ResolveWithMeta(ref)
if err != nil {
logger.WarnCF("agent", "Failed to resolve media ref", map[string]any{
"ref": ref,
"error": err.Error(),
})
continue
}
info, err := os.Stat(localPath)
if err != nil {
logger.WarnCF("agent", "Failed to stat media file", map[string]any{
"path": localPath,
"error": err.Error(),
})
continue
}
if info.Size() > int64(maxSize) {
logger.WarnCF("agent", "Media file too large, skipping", map[string]any{
"path": localPath,
"size": info.Size(),
"max_size": maxSize,
})
continue
}
// Determine MIME type: prefer metadata, fallback to magic-bytes detection
mime := meta.ContentType
if mime == "" {
kind, ftErr := filetype.MatchFile(localPath)
if ftErr != nil || kind == filetype.Unknown {
logger.WarnCF("agent", "Unknown media type, skipping", map[string]any{
"path": localPath,
})
continue
}
mime = kind.MIME.Value
}
// Streaming base64: open file → base64 encoder → buffer
// Peak memory: ~1.33x file size (buffer only, no raw bytes copy)
f, err := os.Open(localPath)
if err != nil {
logger.WarnCF("agent", "Failed to open media file", map[string]any{
"path": localPath,
"error": err.Error(),
})
continue
}
prefix := "data:" + mime + ";base64,"
encodedLen := base64.StdEncoding.EncodedLen(int(info.Size()))
var buf bytes.Buffer
buf.Grow(len(prefix) + encodedLen)
buf.WriteString(prefix)
encoder := base64.NewEncoder(base64.StdEncoding, &buf)
if _, err := io.Copy(encoder, f); err != nil {
f.Close()
logger.WarnCF("agent", "Failed to encode media file", map[string]any{
"path": localPath,
"error": err.Error(),
})
continue
}
encoder.Close()
f.Close()
resolved = append(resolved, buf.String())
}
result[i].Media = resolved
}
return result
}
+166 -57
View File
@@ -6,12 +6,14 @@ import (
"os"
"path/filepath"
"slices"
"strings"
"testing"
"time"
"github.com/sipeed/picoclaw/pkg/bus"
"github.com/sipeed/picoclaw/pkg/channels"
"github.com/sipeed/picoclaw/pkg/config"
"github.com/sipeed/picoclaw/pkg/media"
"github.com/sipeed/picoclaw/pkg/providers"
"github.com/sipeed/picoclaw/pkg/tools"
)
@@ -27,16 +29,15 @@ func (f *fakeChannel) IsAllowed(string) bool {
func (f *fakeChannel) IsAllowedSender(sender bus.SenderInfo) bool { return true }
func (f *fakeChannel) ReasoningChannelID() string { return f.id }
func TestRecordLastChannel(t *testing.T) {
// Create temp workspace
func newTestAgentLoop(
t *testing.T,
) (al *AgentLoop, cfg *config.Config, msgBus *bus.MessageBus, provider *mockProvider, cleanup func()) {
t.Helper()
tmpDir, err := os.MkdirTemp("", "agent-test-*")
if err != nil {
t.Fatalf("Failed to create temp dir: %v", err)
}
defer os.RemoveAll(tmpDir)
// Create test config
cfg := &config.Config{
cfg = &config.Config{
Agents: config.AgentsConfig{
Defaults: config.AgentDefaults{
Workspace: tmpDir,
@@ -46,74 +47,43 @@ func TestRecordLastChannel(t *testing.T) {
},
},
}
msgBus = bus.NewMessageBus()
provider = &mockProvider{}
al = NewAgentLoop(cfg, msgBus, provider)
return al, cfg, msgBus, provider, func() { os.RemoveAll(tmpDir) }
}
// Create agent loop
msgBus := bus.NewMessageBus()
provider := &mockProvider{}
al := NewAgentLoop(cfg, msgBus, provider)
func TestRecordLastChannel(t *testing.T) {
al, cfg, msgBus, provider, cleanup := newTestAgentLoop(t)
defer cleanup()
// Test RecordLastChannel
testChannel := "test-channel"
err = al.RecordLastChannel(testChannel)
if err != nil {
if err := al.RecordLastChannel(testChannel); err != nil {
t.Fatalf("RecordLastChannel failed: %v", err)
}
// Verify channel was saved
lastChannel := al.state.GetLastChannel()
if lastChannel != testChannel {
t.Errorf("Expected channel '%s', got '%s'", testChannel, lastChannel)
if got := al.state.GetLastChannel(); got != testChannel {
t.Errorf("Expected channel '%s', got '%s'", testChannel, got)
}
// Verify persistence by creating a new agent loop
al2 := NewAgentLoop(cfg, msgBus, provider)
if al2.state.GetLastChannel() != testChannel {
t.Errorf("Expected persistent channel '%s', got '%s'", testChannel, al2.state.GetLastChannel())
if got := al2.state.GetLastChannel(); got != testChannel {
t.Errorf("Expected persistent channel '%s', got '%s'", testChannel, got)
}
}
func TestRecordLastChatID(t *testing.T) {
// Create temp workspace
tmpDir, err := os.MkdirTemp("", "agent-test-*")
if err != nil {
t.Fatalf("Failed to create temp dir: %v", err)
}
defer os.RemoveAll(tmpDir)
al, cfg, msgBus, provider, cleanup := newTestAgentLoop(t)
defer cleanup()
// Create test config
cfg := &config.Config{
Agents: config.AgentsConfig{
Defaults: config.AgentDefaults{
Workspace: tmpDir,
Model: "test-model",
MaxTokens: 4096,
MaxToolIterations: 10,
},
},
}
// Create agent loop
msgBus := bus.NewMessageBus()
provider := &mockProvider{}
al := NewAgentLoop(cfg, msgBus, provider)
// Test RecordLastChatID
testChatID := "test-chat-id-123"
err = al.RecordLastChatID(testChatID)
if err != nil {
if err := al.RecordLastChatID(testChatID); err != nil {
t.Fatalf("RecordLastChatID failed: %v", err)
}
// Verify chat ID was saved
lastChatID := al.state.GetLastChatID()
if lastChatID != testChatID {
t.Errorf("Expected chat ID '%s', got '%s'", testChatID, lastChatID)
if got := al.state.GetLastChatID(); got != testChatID {
t.Errorf("Expected chat ID '%s', got '%s'", testChatID, got)
}
// Verify persistence by creating a new agent loop
al2 := NewAgentLoop(cfg, msgBus, provider)
if al2.state.GetLastChatID() != testChatID {
t.Errorf("Expected persistent chat ID '%s', got '%s'", testChatID, al2.state.GetLastChatID())
if got := al2.state.GetLastChatID(); got != testChatID {
t.Errorf("Expected persistent chat ID '%s', got '%s'", testChatID, got)
}
}
@@ -840,3 +810,142 @@ func TestHandleReasoning(t *testing.T) {
}
})
}
func TestResolveMediaRefs_ResolvesToBase64(t *testing.T) {
store := media.NewFileMediaStore()
dir := t.TempDir()
// Create a minimal valid PNG (8-byte header is enough for filetype detection)
pngPath := filepath.Join(dir, "test.png")
// PNG magic: 0x89 P N G \r \n 0x1A \n + minimal IHDR
pngHeader := []byte{
0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A, // PNG signature
0x00, 0x00, 0x00, 0x0D, // IHDR length
0x49, 0x48, 0x44, 0x52, // "IHDR"
0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x01, 0x08, 0x02, // 1x1 RGB
0x00, 0x00, 0x00, // no interlace
0x90, 0x77, 0x53, 0xDE, // CRC
}
if err := os.WriteFile(pngPath, pngHeader, 0o644); err != nil {
t.Fatal(err)
}
ref, err := store.Store(pngPath, media.MediaMeta{}, "test")
if err != nil {
t.Fatal(err)
}
messages := []providers.Message{
{Role: "user", Content: "describe this", Media: []string{ref}},
}
result := resolveMediaRefs(messages, store, config.DefaultMaxMediaSize)
if len(result[0].Media) != 1 {
t.Fatalf("expected 1 resolved media, got %d", len(result[0].Media))
}
if !strings.HasPrefix(result[0].Media[0], "data:image/png;base64,") {
t.Fatalf("expected data:image/png;base64, prefix, got %q", result[0].Media[0][:40])
}
}
func TestResolveMediaRefs_SkipsOversizedFile(t *testing.T) {
store := media.NewFileMediaStore()
dir := t.TempDir()
bigPath := filepath.Join(dir, "big.png")
// Write PNG header + padding to exceed limit
data := make([]byte, 1024+1) // 1KB + 1 byte
copy(data, []byte{0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A})
if err := os.WriteFile(bigPath, data, 0o644); err != nil {
t.Fatal(err)
}
ref, _ := store.Store(bigPath, media.MediaMeta{}, "test")
messages := []providers.Message{
{Role: "user", Content: "hi", Media: []string{ref}},
}
// Use a tiny limit (1KB) so the file is oversized
result := resolveMediaRefs(messages, store, 1024)
if len(result[0].Media) != 0 {
t.Fatalf("expected 0 media (oversized), got %d", len(result[0].Media))
}
}
func TestResolveMediaRefs_SkipsUnknownType(t *testing.T) {
store := media.NewFileMediaStore()
dir := t.TempDir()
txtPath := filepath.Join(dir, "readme.txt")
if err := os.WriteFile(txtPath, []byte("hello world"), 0o644); err != nil {
t.Fatal(err)
}
ref, _ := store.Store(txtPath, media.MediaMeta{}, "test")
messages := []providers.Message{
{Role: "user", Content: "hi", Media: []string{ref}},
}
result := resolveMediaRefs(messages, store, config.DefaultMaxMediaSize)
if len(result[0].Media) != 0 {
t.Fatalf("expected 0 media (unknown type), got %d", len(result[0].Media))
}
}
func TestResolveMediaRefs_PassesThroughNonMediaRefs(t *testing.T) {
messages := []providers.Message{
{Role: "user", Content: "hi", Media: []string{"https://example.com/img.png"}},
}
result := resolveMediaRefs(messages, nil, config.DefaultMaxMediaSize)
if len(result[0].Media) != 1 || result[0].Media[0] != "https://example.com/img.png" {
t.Fatalf("expected passthrough of non-media:// URL, got %v", result[0].Media)
}
}
func TestResolveMediaRefs_DoesNotMutateOriginal(t *testing.T) {
store := media.NewFileMediaStore()
dir := t.TempDir()
pngPath := filepath.Join(dir, "test.png")
pngHeader := []byte{
0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A,
0x00, 0x00, 0x00, 0x0D, 0x49, 0x48, 0x44, 0x52,
0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x01, 0x08, 0x02,
0x00, 0x00, 0x00, 0x90, 0x77, 0x53, 0xDE,
}
os.WriteFile(pngPath, pngHeader, 0o644)
ref, _ := store.Store(pngPath, media.MediaMeta{}, "test")
original := []providers.Message{
{Role: "user", Content: "hi", Media: []string{ref}},
}
originalRef := original[0].Media[0]
resolveMediaRefs(original, store, config.DefaultMaxMediaSize)
if original[0].Media[0] != originalRef {
t.Fatal("resolveMediaRefs mutated original message slice")
}
}
func TestResolveMediaRefs_UsesMetaContentType(t *testing.T) {
store := media.NewFileMediaStore()
dir := t.TempDir()
// File with JPEG content but stored with explicit content type
jpegPath := filepath.Join(dir, "photo")
jpegHeader := []byte{0xFF, 0xD8, 0xFF, 0xE0} // JPEG magic bytes
os.WriteFile(jpegPath, jpegHeader, 0o644)
ref, _ := store.Store(jpegPath, media.MediaMeta{ContentType: "image/jpeg"}, "test")
messages := []providers.Message{
{Role: "user", Content: "hi", Media: []string{ref}},
}
result := resolveMediaRefs(messages, store, config.DefaultMaxMediaSize)
if len(result[0].Media) != 1 {
t.Fatalf("expected 1 media, got %d", len(result[0].Media))
}
if !strings.HasPrefix(result[0].Media[0], "data:image/jpeg;base64,") {
t.Fatalf("expected jpeg prefix, got %q", result[0].Media[0][:30])
}
}
+40
View File
@@ -3,12 +3,15 @@ package discord
import (
"context"
"fmt"
"net/http"
"net/url"
"os"
"strings"
"sync"
"time"
"github.com/bwmarrin/discordgo"
"github.com/gorilla/websocket"
"github.com/sipeed/picoclaw/pkg/bus"
"github.com/sipeed/picoclaw/pkg/channels"
@@ -40,6 +43,9 @@ func NewDiscordChannel(cfg config.DiscordConfig, bus *bus.MessageBus) (*DiscordC
return nil, fmt.Errorf("failed to create discord session: %w", err)
}
if err := applyDiscordProxy(session, cfg.Proxy); err != nil {
return nil, err
}
base := channels.NewBaseChannel("discord", cfg, bus, cfg.AllowFrom,
channels.WithMaxMessageLength(2000),
channels.WithGroupTrigger(cfg.GroupTrigger),
@@ -465,9 +471,43 @@ func (c *DiscordChannel) StartTyping(ctx context.Context, chatID string) (func()
func (c *DiscordChannel) downloadAttachment(url, filename string) string {
return utils.DownloadFile(url, filename, utils.DownloadOptions{
LoggerPrefix: "discord",
ProxyURL: c.config.Proxy,
})
}
func applyDiscordProxy(session *discordgo.Session, proxyAddr string) error {
var proxyFunc func(*http.Request) (*url.URL, error)
if proxyAddr != "" {
proxyURL, err := url.Parse(proxyAddr)
if err != nil {
return fmt.Errorf("invalid discord proxy URL %q: %w", proxyAddr, err)
}
proxyFunc = http.ProxyURL(proxyURL)
} else if os.Getenv("HTTP_PROXY") != "" || os.Getenv("HTTPS_PROXY") != "" {
proxyFunc = http.ProxyFromEnvironment
}
if proxyFunc == nil {
return nil
}
transport := &http.Transport{Proxy: proxyFunc}
session.Client = &http.Client{
Timeout: sendTimeout,
Transport: transport,
}
if session.Dialer != nil {
dialerCopy := *session.Dialer
dialerCopy.Proxy = proxyFunc
session.Dialer = &dialerCopy
} else {
session.Dialer = &websocket.Dialer{Proxy: proxyFunc}
}
return nil
}
// stripBotMention removes the bot mention from the message content.
// Discord mentions have the format <@USER_ID> or <@!USER_ID> (with nickname).
func (c *DiscordChannel) stripBotMention(text string) string {
+91
View File
@@ -0,0 +1,91 @@
package discord
import (
"net/http"
"net/url"
"testing"
"github.com/bwmarrin/discordgo"
)
func TestApplyDiscordProxy_CustomProxy(t *testing.T) {
session, err := discordgo.New("Bot test-token")
if err != nil {
t.Fatalf("discordgo.New() error: %v", err)
}
if err = applyDiscordProxy(session, "http://127.0.0.1:7890"); err != nil {
t.Fatalf("applyDiscordProxy() error: %v", err)
}
req, err := http.NewRequest("GET", "https://discord.com/api/v10/gateway", nil)
if err != nil {
t.Fatalf("http.NewRequest() error: %v", err)
}
restProxy := session.Client.Transport.(*http.Transport).Proxy
restProxyURL, err := restProxy(req)
if err != nil {
t.Fatalf("rest proxy func error: %v", err)
}
if got, want := restProxyURL.String(), "http://127.0.0.1:7890"; got != want {
t.Fatalf("REST proxy = %q, want %q", got, want)
}
wsProxyURL, err := session.Dialer.Proxy(req)
if err != nil {
t.Fatalf("ws proxy func error: %v", err)
}
if got, want := wsProxyURL.String(), "http://127.0.0.1:7890"; got != want {
t.Fatalf("WS proxy = %q, want %q", got, want)
}
}
func TestApplyDiscordProxy_FromEnvironment(t *testing.T) {
t.Setenv("HTTP_PROXY", "http://127.0.0.1:8888")
t.Setenv("http_proxy", "http://127.0.0.1:8888")
t.Setenv("HTTPS_PROXY", "http://127.0.0.1:8888")
t.Setenv("https_proxy", "http://127.0.0.1:8888")
t.Setenv("ALL_PROXY", "")
t.Setenv("all_proxy", "")
t.Setenv("NO_PROXY", "")
t.Setenv("no_proxy", "")
session, err := discordgo.New("Bot test-token")
if err != nil {
t.Fatalf("discordgo.New() error: %v", err)
}
if err = applyDiscordProxy(session, ""); err != nil {
t.Fatalf("applyDiscordProxy() error: %v", err)
}
req, err := http.NewRequest("GET", "https://discord.com/api/v10/gateway", nil)
if err != nil {
t.Fatalf("http.NewRequest() error: %v", err)
}
gotURL, err := session.Dialer.Proxy(req)
if err != nil {
t.Fatalf("ws proxy func error: %v", err)
}
wantURL, err := url.Parse("http://127.0.0.1:8888")
if err != nil {
t.Fatalf("url.Parse() error: %v", err)
}
if gotURL.String() != wantURL.String() {
t.Fatalf("WS proxy = %q, want %q", gotURL.String(), wantURL.String())
}
}
func TestApplyDiscordProxy_InvalidProxyURL(t *testing.T) {
session, err := discordgo.New("Bot test-token")
if err != nil {
t.Fatalf("discordgo.New() error: %v", err)
}
if err = applyDiscordProxy(session, "://bad-proxy"); err == nil {
t.Fatal("applyDiscordProxy() expected error for invalid proxy URL, got nil")
}
}
+77
View File
@@ -1,5 +1,16 @@
package feishu
import (
"encoding/json"
"regexp"
"strings"
larkim "github.com/larksuite/oapi-sdk-go/v3/service/im/v1"
)
// mentionPlaceholderRegex matches @_user_N placeholders inserted by Feishu for mentions.
var mentionPlaceholderRegex = regexp.MustCompile(`@_user_\d+`)
// stringValue safely dereferences a *string pointer.
func stringValue(v *string) string {
if v == nil {
@@ -7,3 +18,69 @@ func stringValue(v *string) string {
}
return *v
}
// buildMarkdownCard builds a Feishu Interactive Card JSON 2.0 string with markdown content.
// JSON 2.0 cards support full CommonMark standard markdown syntax.
func buildMarkdownCard(content string) (string, error) {
card := map[string]any{
"schema": "2.0",
"body": map[string]any{
"elements": []map[string]any{
{
"tag": "markdown",
"content": content,
},
},
},
}
data, err := json.Marshal(card)
if err != nil {
return "", err
}
return string(data), nil
}
// extractJSONStringField unmarshals content as JSON and returns the value of the given string field.
// Returns "" if the content is invalid JSON or the field is missing/empty.
func extractJSONStringField(content, field string) string {
var m map[string]json.RawMessage
if err := json.Unmarshal([]byte(content), &m); err != nil {
return ""
}
raw, ok := m[field]
if !ok {
return ""
}
var s string
if err := json.Unmarshal(raw, &s); err != nil {
return ""
}
return s
}
// extractImageKey extracts the image_key from a Feishu image message content JSON.
// Format: {"image_key": "img_xxx"}
func extractImageKey(content string) string { return extractJSONStringField(content, "image_key") }
// extractFileKey extracts the file_key from a Feishu file/audio message content JSON.
// Format: {"file_key": "file_xxx", "file_name": "...", ...}
func extractFileKey(content string) string { return extractJSONStringField(content, "file_key") }
// extractFileName extracts the file_name from a Feishu file message content JSON.
func extractFileName(content string) string { return extractJSONStringField(content, "file_name") }
// stripMentionPlaceholders removes @_user_N placeholders from the text content.
// These are inserted by Feishu when users @mention someone in a message.
func stripMentionPlaceholders(content string, mentions []*larkim.MentionEvent) string {
if len(mentions) == 0 {
return content
}
for _, m := range mentions {
if m.Key != nil && *m.Key != "" {
content = strings.ReplaceAll(content, *m.Key, "")
}
}
// Also clean up any remaining @_user_N patterns
content = mentionPlaceholderRegex.ReplaceAllString(content, "")
return strings.TrimSpace(content)
}
+292
View File
@@ -0,0 +1,292 @@
package feishu
import (
"encoding/json"
"testing"
larkim "github.com/larksuite/oapi-sdk-go/v3/service/im/v1"
)
func TestExtractJSONStringField(t *testing.T) {
tests := []struct {
name string
content string
field string
want string
}{
{
name: "valid field",
content: `{"image_key": "img_v2_xxx"}`,
field: "image_key",
want: "img_v2_xxx",
},
{
name: "missing field",
content: `{"image_key": "img_v2_xxx"}`,
field: "file_key",
want: "",
},
{
name: "invalid JSON",
content: `not json at all`,
field: "image_key",
want: "",
},
{
name: "empty content",
content: "",
field: "image_key",
want: "",
},
{
name: "non-string field value",
content: `{"count": 42}`,
field: "count",
want: "",
},
{
name: "empty string value",
content: `{"image_key": ""}`,
field: "image_key",
want: "",
},
{
name: "multiple fields",
content: `{"file_key": "file_xxx", "file_name": "test.pdf"}`,
field: "file_name",
want: "test.pdf",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := extractJSONStringField(tt.content, tt.field)
if got != tt.want {
t.Errorf("extractJSONStringField(%q, %q) = %q, want %q", tt.content, tt.field, got, tt.want)
}
})
}
}
func TestExtractImageKey(t *testing.T) {
tests := []struct {
name string
content string
want string
}{
{
name: "normal",
content: `{"image_key": "img_v2_abc123"}`,
want: "img_v2_abc123",
},
{
name: "missing key",
content: `{"file_key": "file_xxx"}`,
want: "",
},
{
name: "malformed JSON",
content: `{broken`,
want: "",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := extractImageKey(tt.content)
if got != tt.want {
t.Errorf("extractImageKey(%q) = %q, want %q", tt.content, got, tt.want)
}
})
}
}
func TestExtractFileKey(t *testing.T) {
tests := []struct {
name string
content string
want string
}{
{
name: "normal",
content: `{"file_key": "file_v2_abc123", "file_name": "test.doc"}`,
want: "file_v2_abc123",
},
{
name: "missing key",
content: `{"image_key": "img_xxx"}`,
want: "",
},
{
name: "malformed JSON",
content: `not json`,
want: "",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := extractFileKey(tt.content)
if got != tt.want {
t.Errorf("extractFileKey(%q) = %q, want %q", tt.content, got, tt.want)
}
})
}
}
func TestExtractFileName(t *testing.T) {
tests := []struct {
name string
content string
want string
}{
{
name: "normal",
content: `{"file_key": "file_xxx", "file_name": "report.pdf"}`,
want: "report.pdf",
},
{
name: "missing name",
content: `{"file_key": "file_xxx"}`,
want: "",
},
{
name: "malformed JSON",
content: `{bad`,
want: "",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := extractFileName(tt.content)
if got != tt.want {
t.Errorf("extractFileName(%q) = %q, want %q", tt.content, got, tt.want)
}
})
}
}
func TestBuildMarkdownCard(t *testing.T) {
tests := []struct {
name string
content string
}{
{
name: "normal content",
content: "Hello **world**",
},
{
name: "empty content",
content: "",
},
{
name: "special characters",
content: `Code: "foo" & <bar> 'baz'`,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result, err := buildMarkdownCard(tt.content)
if err != nil {
t.Fatalf("buildMarkdownCard(%q) unexpected error: %v", tt.content, err)
}
// Verify valid JSON
var parsed map[string]any
if err := json.Unmarshal([]byte(result), &parsed); err != nil {
t.Fatalf("buildMarkdownCard(%q) produced invalid JSON: %v", tt.content, err)
}
// Verify schema
if parsed["schema"] != "2.0" {
t.Errorf("schema = %v, want %q", parsed["schema"], "2.0")
}
// Verify body.elements[0].content == input
body, ok := parsed["body"].(map[string]any)
if !ok {
t.Fatal("missing body in card JSON")
}
elements, ok := body["elements"].([]any)
if !ok || len(elements) == 0 {
t.Fatal("missing or empty elements in card JSON")
}
elem, ok := elements[0].(map[string]any)
if !ok {
t.Fatal("first element is not an object")
}
if elem["tag"] != "markdown" {
t.Errorf("tag = %v, want %q", elem["tag"], "markdown")
}
if elem["content"] != tt.content {
t.Errorf("content = %v, want %q", elem["content"], tt.content)
}
})
}
}
func TestStripMentionPlaceholders(t *testing.T) {
strPtr := func(s string) *string { return &s }
tests := []struct {
name string
content string
mentions []*larkim.MentionEvent
want string
}{
{
name: "no mentions",
content: "Hello world",
mentions: nil,
want: "Hello world",
},
{
name: "single mention",
content: "@_user_1 hello",
mentions: []*larkim.MentionEvent{
{Key: strPtr("@_user_1")},
},
want: "hello",
},
{
name: "multiple mentions",
content: "@_user_1 @_user_2 hey",
mentions: []*larkim.MentionEvent{
{Key: strPtr("@_user_1")},
{Key: strPtr("@_user_2")},
},
want: "hey",
},
{
name: "empty content",
content: "",
mentions: []*larkim.MentionEvent{{Key: strPtr("@_user_1")}},
want: "",
},
{
name: "empty mentions slice",
content: "@_user_1 test",
mentions: []*larkim.MentionEvent{},
want: "@_user_1 test",
},
{
name: "mention with nil key",
content: "@_user_1 test",
mentions: []*larkim.MentionEvent{
{Key: nil},
},
want: "test",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := stripMentionPlaceholders(tt.content, tt.mentions)
if got != tt.want {
t.Errorf("stripMentionPlaceholders(%q, ...) = %q, want %q", tt.content, got, tt.want)
}
})
}
}
+25 -3
View File
@@ -16,6 +16,8 @@ type FeishuChannel struct {
*channels.BaseChannel
}
var errUnsupported = errors.New("feishu channel is not supported on 32-bit architectures")
// NewFeishuChannel returns an error on 32-bit architectures where the Feishu SDK is not supported
func NewFeishuChannel(cfg config.FeishuConfig, bus *bus.MessageBus) (*FeishuChannel, error) {
return nil, errors.New(
@@ -25,15 +27,35 @@ func NewFeishuChannel(cfg config.FeishuConfig, bus *bus.MessageBus) (*FeishuChan
// Start is a stub method to satisfy the Channel interface
func (c *FeishuChannel) Start(ctx context.Context) error {
return nil
return errUnsupported
}
// Stop is a stub method to satisfy the Channel interface
func (c *FeishuChannel) Stop(ctx context.Context) error {
return nil
return errUnsupported
}
// Send is a stub method to satisfy the Channel interface
func (c *FeishuChannel) Send(ctx context.Context, msg bus.OutboundMessage) error {
return errors.New("feishu channel is not supported on 32-bit architectures")
return errUnsupported
}
// EditMessage is a stub method to satisfy MessageEditor
func (c *FeishuChannel) EditMessage(ctx context.Context, chatID, messageID, content string) error {
return errUnsupported
}
// SendPlaceholder is a stub method to satisfy PlaceholderCapable
func (c *FeishuChannel) SendPlaceholder(ctx context.Context, chatID string) (string, error) {
return "", errUnsupported
}
// ReactToMessage is a stub method to satisfy ReactionCapable
func (c *FeishuChannel) ReactToMessage(ctx context.Context, chatID, messageID string) (func(), error) {
return func() {}, errUnsupported
}
// SendMedia is a stub method to satisfy MediaSender
func (c *FeishuChannel) SendMedia(ctx context.Context, msg bus.OutboundMediaMessage) error {
return errUnsupported
}
+628 -51
View File
@@ -6,10 +6,15 @@ import (
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"os"
"path/filepath"
"sync"
"time"
"sync/atomic"
lark "github.com/larksuite/oapi-sdk-go/v3"
larkcore "github.com/larksuite/oapi-sdk-go/v3/core"
larkdispatcher "github.com/larksuite/oapi-sdk-go/v3/event/dispatcher"
larkim "github.com/larksuite/oapi-sdk-go/v3/service/im/v1"
larkws "github.com/larksuite/oapi-sdk-go/v3/ws"
@@ -19,6 +24,7 @@ import (
"github.com/sipeed/picoclaw/pkg/config"
"github.com/sipeed/picoclaw/pkg/identity"
"github.com/sipeed/picoclaw/pkg/logger"
"github.com/sipeed/picoclaw/pkg/media"
"github.com/sipeed/picoclaw/pkg/utils"
)
@@ -28,6 +34,8 @@ type FeishuChannel struct {
client *lark.Client
wsClient *larkws.Client
botOpenID atomic.Value // stores string; populated lazily for @mention detection
mu sync.Mutex
cancel context.CancelFunc
}
@@ -38,11 +46,13 @@ func NewFeishuChannel(cfg config.FeishuConfig, bus *bus.MessageBus) (*FeishuChan
channels.WithReasoningChannelID(cfg.ReasoningChannelID),
)
return &FeishuChannel{
ch := &FeishuChannel{
BaseChannel: base,
config: cfg,
client: lark.NewClient(cfg.AppID, cfg.AppSecret),
}, nil
}
ch.SetOwner(ch)
return ch, nil
}
func (c *FeishuChannel) Start(ctx context.Context) error {
@@ -50,6 +60,13 @@ func (c *FeishuChannel) Start(ctx context.Context) error {
return fmt.Errorf("feishu app_id or app_secret is empty")
}
// Fetch bot open_id via API for reliable @mention detection.
if err := c.fetchBotOpenID(ctx); err != nil {
logger.ErrorCF("feishu", "Failed to fetch bot open_id, @mention detection may not work", map[string]any{
"error": err.Error(),
})
}
dispatcher := larkdispatcher.NewEventDispatcher(c.config.VerificationToken, c.config.EncryptKey).
OnP2MessageReceiveV1(c.handleMessageReceive)
@@ -93,46 +110,213 @@ func (c *FeishuChannel) Stop(ctx context.Context) error {
return nil
}
// Send sends a message using Interactive Card format for markdown rendering.
func (c *FeishuChannel) Send(ctx context.Context, msg bus.OutboundMessage) error {
if !c.IsRunning() {
return channels.ErrNotRunning
}
if msg.ChatID == "" {
return fmt.Errorf("chat ID is empty")
return fmt.Errorf("chat ID is empty: %w", channels.ErrSendFailed)
}
payload, err := json.Marshal(map[string]string{"text": msg.Content})
// Build interactive card with markdown content
cardContent, err := buildMarkdownCard(msg.Content)
if err != nil {
return fmt.Errorf("failed to marshal feishu content: %w", err)
return fmt.Errorf("feishu send: card build failed: %w", err)
}
return c.sendCard(ctx, msg.ChatID, cardContent)
}
// EditMessage implements channels.MessageEditor.
// Uses Message.Patch to update an interactive card message.
func (c *FeishuChannel) EditMessage(ctx context.Context, chatID, messageID, content string) error {
cardContent, err := buildMarkdownCard(content)
if err != nil {
return fmt.Errorf("feishu edit: card build failed: %w", err)
}
req := larkim.NewPatchMessageReqBuilder().
MessageId(messageID).
Body(larkim.NewPatchMessageReqBodyBuilder().Content(cardContent).Build()).
Build()
resp, err := c.client.Im.V1.Message.Patch(ctx, req)
if err != nil {
return fmt.Errorf("feishu edit: %w", err)
}
if !resp.Success() {
return fmt.Errorf("feishu edit api error (code=%d msg=%s)", resp.Code, resp.Msg)
}
return nil
}
// SendPlaceholder implements channels.PlaceholderCapable.
// Sends an interactive card with placeholder text and returns its message ID.
func (c *FeishuChannel) SendPlaceholder(ctx context.Context, chatID string) (string, error) {
if !c.config.Placeholder.Enabled {
logger.DebugCF("feishu", "Placeholder disabled, skipping", map[string]any{
"chat_id": chatID,
})
return "", nil
}
text := c.config.Placeholder.Text
if text == "" {
text = "Thinking..."
}
cardContent, err := buildMarkdownCard(text)
if err != nil {
return "", fmt.Errorf("feishu placeholder: card build failed: %w", err)
}
req := larkim.NewCreateMessageReqBuilder().
ReceiveIdType(larkim.ReceiveIdTypeChatId).
Body(larkim.NewCreateMessageReqBodyBuilder().
ReceiveId(msg.ChatID).
MsgType(larkim.MsgTypeText).
Content(string(payload)).
Uuid(fmt.Sprintf("picoclaw-%d", time.Now().UnixNano())).
ReceiveId(chatID).
MsgType(larkim.MsgTypeInteractive).
Content(cardContent).
Build()).
Build()
resp, err := c.client.Im.V1.Message.Create(ctx, req)
if err != nil {
return fmt.Errorf("feishu send: %w", channels.ErrTemporary)
return "", fmt.Errorf("feishu placeholder send: %w", err)
}
if !resp.Success() {
return fmt.Errorf("feishu api error (code=%d msg=%s): %w", resp.Code, resp.Msg, channels.ErrTemporary)
return "", fmt.Errorf("feishu placeholder api error (code=%d msg=%s)", resp.Code, resp.Msg)
}
logger.DebugCF("feishu", "Feishu message sent", map[string]any{
"chat_id": msg.ChatID,
})
if resp.Data != nil && resp.Data.MessageId != nil {
return *resp.Data.MessageId, nil
}
return "", nil
}
// ReactToMessage implements channels.ReactionCapable.
// Adds an "Pin" reaction and returns an undo function to remove it.
func (c *FeishuChannel) ReactToMessage(ctx context.Context, chatID, messageID string) (func(), error) {
req := larkim.NewCreateMessageReactionReqBuilder().
MessageId(messageID).
Body(larkim.NewCreateMessageReactionReqBodyBuilder().
ReactionType(larkim.NewEmojiBuilder().EmojiType("Pin").Build()).
Build()).
Build()
resp, err := c.client.Im.V1.MessageReaction.Create(ctx, req)
if err != nil {
logger.ErrorCF("feishu", "Failed to add reaction", map[string]any{
"message_id": messageID,
"error": err.Error(),
})
return func() {}, fmt.Errorf("feishu react: %w", err)
}
if !resp.Success() {
logger.ErrorCF("feishu", "Reaction API error", map[string]any{
"message_id": messageID,
"code": resp.Code,
"msg": resp.Msg,
})
return func() {}, fmt.Errorf("feishu react api error (code=%d msg=%s)", resp.Code, resp.Msg)
}
var reactionID string
if resp.Data != nil && resp.Data.ReactionId != nil {
reactionID = *resp.Data.ReactionId
}
if reactionID == "" {
return func() {}, nil
}
var undone atomic.Bool
undo := func() {
if !undone.CompareAndSwap(false, true) {
return
}
delReq := larkim.NewDeleteMessageReactionReqBuilder().
MessageId(messageID).
ReactionId(reactionID).
Build()
_, _ = c.client.Im.V1.MessageReaction.Delete(context.Background(), delReq)
}
return undo, nil
}
// SendMedia implements channels.MediaSender.
// Uploads images/files via Feishu API then sends as messages.
func (c *FeishuChannel) SendMedia(ctx context.Context, msg bus.OutboundMediaMessage) error {
if !c.IsRunning() {
return channels.ErrNotRunning
}
if msg.ChatID == "" {
return fmt.Errorf("chat ID is empty: %w", channels.ErrSendFailed)
}
store := c.GetMediaStore()
if store == nil {
return fmt.Errorf("no media store available: %w", channels.ErrSendFailed)
}
for _, part := range msg.Parts {
if err := c.sendMediaPart(ctx, msg.ChatID, part, store); err != nil {
return err
}
}
return nil
}
// sendMediaPart resolves and sends a single media part.
func (c *FeishuChannel) sendMediaPart(
ctx context.Context,
chatID string,
part bus.MediaPart,
store media.MediaStore,
) error {
localPath, err := store.Resolve(part.Ref)
if err != nil {
logger.ErrorCF("feishu", "Failed to resolve media ref", map[string]any{
"ref": part.Ref,
"error": err.Error(),
})
return nil // skip this part
}
file, err := os.Open(localPath)
if err != nil {
logger.ErrorCF("feishu", "Failed to open media file", map[string]any{
"path": localPath,
"error": err.Error(),
})
return nil // skip this part
}
defer file.Close()
switch part.Type {
case "image":
err = c.sendImage(ctx, chatID, file)
default:
filename := part.Filename
if filename == "" {
filename = "file"
}
err = c.sendFile(ctx, chatID, file, filename, part.Type)
}
if err != nil {
logger.ErrorCF("feishu", "Failed to send media", map[string]any{
"type": part.Type,
"error": err.Error(),
})
return fmt.Errorf("feishu send media: %w", channels.ErrTemporary)
}
return nil
}
// --- Inbound message handling ---
func (c *FeishuChannel) handleMessageReceive(ctx context.Context, event *larkim.P2MessageReceiveV1) error {
if event == nil || event.Event == nil || event.Event.Message == nil {
return nil
@@ -151,34 +335,68 @@ func (c *FeishuChannel) handleMessageReceive(ctx context.Context, event *larkim.
senderID = "unknown"
}
content := extractFeishuMessageContent(message)
messageType := stringValue(message.MessageType)
messageID := stringValue(message.MessageId)
rawContent := stringValue(message.Content)
// Check allowlist early to avoid downloading media for rejected senders.
// BaseChannel.HandleMessage will check again, but this avoids wasted network I/O.
senderInfo := bus.SenderInfo{
Platform: "feishu",
PlatformID: senderID,
CanonicalID: identity.BuildCanonicalID("feishu", senderID),
}
if !c.IsAllowedSender(senderInfo) {
return nil
}
// Extract content based on message type
content := extractContent(messageType, rawContent)
// Handle media messages (download and store)
var mediaRefs []string
if store := c.GetMediaStore(); store != nil && messageID != "" {
mediaRefs = c.downloadInboundMedia(ctx, chatID, messageID, messageType, rawContent, store)
}
// Append media tags to content (like Telegram does)
content = appendMediaTags(content, messageType, mediaRefs)
if content == "" {
content = "[empty message]"
}
metadata := map[string]string{}
messageID := ""
if mid := stringValue(message.MessageId); mid != "" {
messageID = mid
if messageID != "" {
metadata["message_id"] = messageID
}
if messageType := stringValue(message.MessageType); messageType != "" {
if messageType != "" {
metadata["message_type"] = messageType
}
if chatType := stringValue(message.ChatType); chatType != "" {
chatType := stringValue(message.ChatType)
if chatType != "" {
metadata["chat_type"] = chatType
}
if sender != nil && sender.TenantKey != nil {
metadata["tenant_key"] = *sender.TenantKey
}
chatType := stringValue(message.ChatType)
var peer bus.Peer
if chatType == "p2p" {
peer = bus.Peer{Kind: "direct", ID: senderID}
} else {
peer = bus.Peer{Kind: "group", ID: chatID}
// Check if bot was mentioned
isMentioned := c.isBotMentioned(message)
// Strip mention placeholders from content before group trigger check
if len(message.Mentions) > 0 {
content = stripMentionPlaceholders(content, message.Mentions)
}
// In group chats, apply unified group trigger filtering
respond, cleaned := c.ShouldRespondInGroup(false, content)
respond, cleaned := c.ShouldRespondInGroup(isMentioned, content)
if !respond {
return nil
}
@@ -186,22 +404,398 @@ func (c *FeishuChannel) handleMessageReceive(ctx context.Context, event *larkim.
}
logger.InfoCF("feishu", "Feishu message received", map[string]any{
"sender_id": senderID,
"chat_id": chatID,
"preview": utils.Truncate(content, 80),
"sender_id": senderID,
"chat_id": chatID,
"message_id": messageID,
"preview": utils.Truncate(content, 80),
})
senderInfo := bus.SenderInfo{
Platform: "feishu",
PlatformID: senderID,
CanonicalID: identity.BuildCanonicalID("feishu", senderID),
c.HandleMessage(ctx, peer, messageID, senderID, chatID, content, mediaRefs, metadata, senderInfo)
return nil
}
// --- Internal helpers ---
// fetchBotOpenID calls the Feishu bot info API to retrieve and store the bot's open_id.
func (c *FeishuChannel) fetchBotOpenID(ctx context.Context) error {
resp, err := c.client.Do(ctx, &larkcore.ApiReq{
HttpMethod: http.MethodGet,
ApiPath: "/open-apis/bot/v3/info",
SupportedAccessTokenTypes: []larkcore.AccessTokenType{larkcore.AccessTokenTypeTenant},
})
if err != nil {
return fmt.Errorf("bot info request: %w", err)
}
if !c.IsAllowedSender(senderInfo) {
return nil
var result struct {
Code int `json:"code"`
Bot struct {
OpenID string `json:"open_id"`
} `json:"bot"`
}
if err := json.Unmarshal(resp.RawBody, &result); err != nil {
return fmt.Errorf("bot info parse: %w", err)
}
if result.Code != 0 {
return fmt.Errorf("bot info api error (code=%d)", result.Code)
}
if result.Bot.OpenID == "" {
return fmt.Errorf("bot info: empty open_id")
}
c.HandleMessage(ctx, peer, messageID, senderID, chatID, content, nil, metadata, senderInfo)
c.botOpenID.Store(result.Bot.OpenID)
logger.InfoCF("feishu", "Fetched bot open_id from API", map[string]any{
"open_id": result.Bot.OpenID,
})
return nil
}
// isBotMentioned checks if the bot was @mentioned in the message.
func (c *FeishuChannel) isBotMentioned(message *larkim.EventMessage) bool {
if message.Mentions == nil {
return false
}
knownID, _ := c.botOpenID.Load().(string)
if knownID == "" {
logger.DebugCF("feishu", "Bot open_id unknown, cannot detect @mention", nil)
return false
}
for _, m := range message.Mentions {
if m.Id == nil {
continue
}
if m.Id.OpenId != nil && *m.Id.OpenId == knownID {
return true
}
}
return false
}
// extractContent extracts text content from different message types.
func extractContent(messageType, rawContent string) string {
if rawContent == "" {
return ""
}
switch messageType {
case larkim.MsgTypeText:
var textPayload struct {
Text string `json:"text"`
}
if err := json.Unmarshal([]byte(rawContent), &textPayload); err == nil {
return textPayload.Text
}
return rawContent
case larkim.MsgTypePost:
// Pass raw JSON to LLM — structured rich text is more informative than flattened plain text
return rawContent
case larkim.MsgTypeImage:
// Image messages don't have text content
return ""
case larkim.MsgTypeFile, larkim.MsgTypeAudio, larkim.MsgTypeMedia:
// File/audio/video messages may have a filename
name := extractFileName(rawContent)
if name != "" {
return name
}
return ""
default:
return rawContent
}
}
// downloadInboundMedia downloads media from inbound messages and stores in MediaStore.
func (c *FeishuChannel) downloadInboundMedia(
ctx context.Context,
chatID, messageID, messageType, rawContent string,
store media.MediaStore,
) []string {
var refs []string
scope := channels.BuildMediaScope("feishu", chatID, messageID)
switch messageType {
case larkim.MsgTypeImage:
imageKey := extractImageKey(rawContent)
if imageKey == "" {
return nil
}
ref := c.downloadResource(ctx, messageID, imageKey, "image", ".jpg", store, scope)
if ref != "" {
refs = append(refs, ref)
}
case larkim.MsgTypeFile, larkim.MsgTypeAudio, larkim.MsgTypeMedia:
fileKey := extractFileKey(rawContent)
if fileKey == "" {
return nil
}
// Derive a fallback extension from the message type.
var ext string
switch messageType {
case larkim.MsgTypeAudio:
ext = ".ogg"
case larkim.MsgTypeMedia:
ext = ".mp4"
default:
ext = "" // generic file — rely on resp.FileName
}
ref := c.downloadResource(ctx, messageID, fileKey, "file", ext, store, scope)
if ref != "" {
refs = append(refs, ref)
}
}
return refs
}
// downloadResource downloads a message resource (image/file) from Feishu,
// writes it to the project media directory, and stores the reference in MediaStore.
// fallbackExt (e.g. ".jpg") is appended when the resolved filename has no extension.
func (c *FeishuChannel) downloadResource(
ctx context.Context,
messageID, fileKey, resourceType, fallbackExt string,
store media.MediaStore,
scope string,
) string {
req := larkim.NewGetMessageResourceReqBuilder().
MessageId(messageID).
FileKey(fileKey).
Type(resourceType).
Build()
resp, err := c.client.Im.V1.MessageResource.Get(ctx, req)
if err != nil {
logger.ErrorCF("feishu", "Failed to download resource", map[string]any{
"message_id": messageID,
"file_key": fileKey,
"error": err.Error(),
})
return ""
}
if !resp.Success() {
logger.ErrorCF("feishu", "Resource download api error", map[string]any{
"code": resp.Code,
"msg": resp.Msg,
})
return ""
}
if resp.File == nil {
return ""
}
// Safely close the underlying reader if it implements io.Closer (e.g. HTTP response body).
if closer, ok := resp.File.(io.Closer); ok {
defer closer.Close()
}
filename := resp.FileName
if filename == "" {
filename = fileKey
}
// If filename still has no extension, append the fallback (like Telegram's ext parameter).
if filepath.Ext(filename) == "" && fallbackExt != "" {
filename += fallbackExt
}
// Write to the shared picoclaw_media directory using a unique name to avoid collisions.
mediaDir := filepath.Join(os.TempDir(), "picoclaw_media")
if mkdirErr := os.MkdirAll(mediaDir, 0o700); mkdirErr != nil {
logger.ErrorCF("feishu", "Failed to create media directory", map[string]any{
"error": mkdirErr.Error(),
})
return ""
}
ext := filepath.Ext(filename)
localPath := filepath.Join(mediaDir, utils.SanitizeFilename(messageID+"-"+fileKey+ext))
out, err := os.Create(localPath)
if err != nil {
logger.ErrorCF("feishu", "Failed to create local file for resource", map[string]any{
"error": err.Error(),
})
return ""
}
if _, copyErr := io.Copy(out, resp.File); copyErr != nil {
out.Close()
os.Remove(localPath)
logger.ErrorCF("feishu", "Failed to write resource to file", map[string]any{
"error": copyErr.Error(),
})
return ""
}
out.Close()
ref, err := store.Store(localPath, media.MediaMeta{
Filename: filename,
Source: "feishu",
}, scope)
if err != nil {
logger.ErrorCF("feishu", "Failed to store downloaded resource", map[string]any{
"file_key": fileKey,
"error": err.Error(),
})
os.Remove(localPath)
return ""
}
return ref
}
// appendMediaTags appends media type tags to content (like Telegram's "[image: photo]").
func appendMediaTags(content, messageType string, mediaRefs []string) string {
if len(mediaRefs) == 0 {
return content
}
var tag string
switch messageType {
case larkim.MsgTypeImage:
tag = "[image: photo]"
case larkim.MsgTypeAudio:
tag = "[audio]"
case larkim.MsgTypeMedia:
tag = "[video]"
case larkim.MsgTypeFile:
tag = "[file]"
default:
tag = "[attachment]"
}
if content == "" {
return tag
}
return content + " " + tag
}
// sendCard sends an interactive card message to a chat.
func (c *FeishuChannel) sendCard(ctx context.Context, chatID, cardContent string) error {
req := larkim.NewCreateMessageReqBuilder().
ReceiveIdType(larkim.ReceiveIdTypeChatId).
Body(larkim.NewCreateMessageReqBodyBuilder().
ReceiveId(chatID).
MsgType(larkim.MsgTypeInteractive).
Content(cardContent).
Build()).
Build()
resp, err := c.client.Im.V1.Message.Create(ctx, req)
if err != nil {
return fmt.Errorf("feishu send card: %w", channels.ErrTemporary)
}
if !resp.Success() {
return fmt.Errorf("feishu api error (code=%d msg=%s): %w", resp.Code, resp.Msg, channels.ErrTemporary)
}
logger.DebugCF("feishu", "Feishu card message sent", map[string]any{
"chat_id": chatID,
})
return nil
}
// sendImage uploads an image and sends it as a message.
func (c *FeishuChannel) sendImage(ctx context.Context, chatID string, file *os.File) error {
// Upload image to get image_key
uploadReq := larkim.NewCreateImageReqBuilder().
Body(larkim.NewCreateImageReqBodyBuilder().
ImageType("message").
Image(file).
Build()).
Build()
uploadResp, err := c.client.Im.V1.Image.Create(ctx, uploadReq)
if err != nil {
return fmt.Errorf("feishu image upload: %w", err)
}
if !uploadResp.Success() {
return fmt.Errorf("feishu image upload api error (code=%d msg=%s)", uploadResp.Code, uploadResp.Msg)
}
if uploadResp.Data == nil || uploadResp.Data.ImageKey == nil {
return fmt.Errorf("feishu image upload: no image_key returned")
}
imageKey := *uploadResp.Data.ImageKey
// Send image message
content, _ := json.Marshal(map[string]string{"image_key": imageKey})
req := larkim.NewCreateMessageReqBuilder().
ReceiveIdType(larkim.ReceiveIdTypeChatId).
Body(larkim.NewCreateMessageReqBodyBuilder().
ReceiveId(chatID).
MsgType(larkim.MsgTypeImage).
Content(string(content)).
Build()).
Build()
resp, err := c.client.Im.V1.Message.Create(ctx, req)
if err != nil {
return fmt.Errorf("feishu image send: %w", err)
}
if !resp.Success() {
return fmt.Errorf("feishu image send api error (code=%d msg=%s)", resp.Code, resp.Msg)
}
return nil
}
// sendFile uploads a file and sends it as a message.
func (c *FeishuChannel) sendFile(ctx context.Context, chatID string, file *os.File, filename, fileType string) error {
// Map part type to Feishu file type
feishuFileType := "stream"
switch fileType {
case "audio":
feishuFileType = "opus"
case "video":
feishuFileType = "mp4"
}
// Upload file to get file_key
uploadReq := larkim.NewCreateFileReqBuilder().
Body(larkim.NewCreateFileReqBodyBuilder().
FileType(feishuFileType).
FileName(filename).
File(file).
Build()).
Build()
uploadResp, err := c.client.Im.V1.File.Create(ctx, uploadReq)
if err != nil {
return fmt.Errorf("feishu file upload: %w", err)
}
if !uploadResp.Success() {
return fmt.Errorf("feishu file upload api error (code=%d msg=%s)", uploadResp.Code, uploadResp.Msg)
}
if uploadResp.Data == nil || uploadResp.Data.FileKey == nil {
return fmt.Errorf("feishu file upload: no file_key returned")
}
fileKey := *uploadResp.Data.FileKey
// Send file message
content, _ := json.Marshal(map[string]string{"file_key": fileKey})
req := larkim.NewCreateMessageReqBuilder().
ReceiveIdType(larkim.ReceiveIdTypeChatId).
Body(larkim.NewCreateMessageReqBodyBuilder().
ReceiveId(chatID).
MsgType(larkim.MsgTypeFile).
Content(string(content)).
Build()).
Build()
resp, err := c.client.Im.V1.Message.Create(ctx, req)
if err != nil {
return fmt.Errorf("feishu file send: %w", err)
}
if !resp.Success() {
return fmt.Errorf("feishu file send api error (code=%d msg=%s)", resp.Code, resp.Msg)
}
return nil
}
@@ -222,20 +816,3 @@ func extractFeishuSenderID(sender *larkim.EventSender) string {
return ""
}
func extractFeishuMessageContent(message *larkim.EventMessage) string {
if message == nil || message.Content == nil || *message.Content == "" {
return ""
}
if message.MessageType != nil && *message.MessageType == larkim.MsgTypeText {
var textPayload struct {
Text string `json:"text"`
}
if err := json.Unmarshal([]byte(*message.Content), &textPayload); err == nil {
return textPayload.Text
}
}
return *message.Content
}
+256
View File
@@ -0,0 +1,256 @@
//go:build amd64 || arm64 || riscv64 || mips64 || ppc64
package feishu
import (
"testing"
larkim "github.com/larksuite/oapi-sdk-go/v3/service/im/v1"
)
func TestExtractContent(t *testing.T) {
tests := []struct {
name string
messageType string
rawContent string
want string
}{
{
name: "text message",
messageType: "text",
rawContent: `{"text": "hello world"}`,
want: "hello world",
},
{
name: "text message invalid JSON",
messageType: "text",
rawContent: `not json`,
want: "not json",
},
{
name: "post message returns raw JSON",
messageType: "post",
rawContent: `{"title": "test post"}`,
want: `{"title": "test post"}`,
},
{
name: "image message returns empty",
messageType: "image",
rawContent: `{"image_key": "img_xxx"}`,
want: "",
},
{
name: "file message with filename",
messageType: "file",
rawContent: `{"file_key": "file_xxx", "file_name": "report.pdf"}`,
want: "report.pdf",
},
{
name: "file message without filename",
messageType: "file",
rawContent: `{"file_key": "file_xxx"}`,
want: "",
},
{
name: "audio message with filename",
messageType: "audio",
rawContent: `{"file_key": "file_xxx", "file_name": "recording.ogg"}`,
want: "recording.ogg",
},
{
name: "media message with filename",
messageType: "media",
rawContent: `{"file_key": "file_xxx", "file_name": "video.mp4"}`,
want: "video.mp4",
},
{
name: "unknown message type returns raw",
messageType: "sticker",
rawContent: `{"sticker_id": "sticker_xxx"}`,
want: `{"sticker_id": "sticker_xxx"}`,
},
{
name: "empty raw content",
messageType: "text",
rawContent: "",
want: "",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := extractContent(tt.messageType, tt.rawContent)
if got != tt.want {
t.Errorf("extractContent(%q, %q) = %q, want %q", tt.messageType, tt.rawContent, got, tt.want)
}
})
}
}
func TestAppendMediaTags(t *testing.T) {
tests := []struct {
name string
content string
messageType string
mediaRefs []string
want string
}{
{
name: "no refs returns content unchanged",
content: "hello",
messageType: "image",
mediaRefs: nil,
want: "hello",
},
{
name: "empty refs returns content unchanged",
content: "hello",
messageType: "image",
mediaRefs: []string{},
want: "hello",
},
{
name: "image with content",
content: "check this",
messageType: "image",
mediaRefs: []string{"ref1"},
want: "check this [image: photo]",
},
{
name: "image empty content",
content: "",
messageType: "image",
mediaRefs: []string{"ref1"},
want: "[image: photo]",
},
{
name: "audio",
content: "listen",
messageType: "audio",
mediaRefs: []string{"ref1"},
want: "listen [audio]",
},
{
name: "media/video",
content: "watch",
messageType: "media",
mediaRefs: []string{"ref1"},
want: "watch [video]",
},
{
name: "file",
content: "report.pdf",
messageType: "file",
mediaRefs: []string{"ref1"},
want: "report.pdf [file]",
},
{
name: "unknown type",
content: "something",
messageType: "sticker",
mediaRefs: []string{"ref1"},
want: "something [attachment]",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := appendMediaTags(tt.content, tt.messageType, tt.mediaRefs)
if got != tt.want {
t.Errorf(
"appendMediaTags(%q, %q, %v) = %q, want %q",
tt.content,
tt.messageType,
tt.mediaRefs,
got,
tt.want,
)
}
})
}
}
func TestExtractFeishuSenderID(t *testing.T) {
strPtr := func(s string) *string { return &s }
tests := []struct {
name string
sender *larkim.EventSender
want string
}{
{
name: "nil sender",
sender: nil,
want: "",
},
{
name: "nil sender ID",
sender: &larkim.EventSender{SenderId: nil},
want: "",
},
{
name: "userId preferred",
sender: &larkim.EventSender{
SenderId: &larkim.UserId{
UserId: strPtr("u_abc123"),
OpenId: strPtr("ou_def456"),
UnionId: strPtr("on_ghi789"),
},
},
want: "u_abc123",
},
{
name: "openId fallback",
sender: &larkim.EventSender{
SenderId: &larkim.UserId{
UserId: strPtr(""),
OpenId: strPtr("ou_def456"),
UnionId: strPtr("on_ghi789"),
},
},
want: "ou_def456",
},
{
name: "unionId fallback",
sender: &larkim.EventSender{
SenderId: &larkim.UserId{
UserId: strPtr(""),
OpenId: strPtr(""),
UnionId: strPtr("on_ghi789"),
},
},
want: "on_ghi789",
},
{
name: "all empty strings",
sender: &larkim.EventSender{
SenderId: &larkim.UserId{
UserId: strPtr(""),
OpenId: strPtr(""),
UnionId: strPtr(""),
},
},
want: "",
},
{
name: "nil userId pointer falls through",
sender: &larkim.EventSender{
SenderId: &larkim.UserId{
UserId: nil,
OpenId: strPtr("ou_def456"),
UnionId: nil,
},
},
want: "ou_def456",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := extractFeishuSenderID(tt.sender)
if got != tt.want {
t.Errorf("extractFeishuSenderID() = %q, want %q", got, tt.want)
}
})
}
}
+56 -50
View File
@@ -255,6 +255,10 @@ func (m *Manager) initChannels() error {
m.initChannel("wecom", "WeCom")
}
if m.config.Channels.WeComAIBot.Enabled && m.config.Channels.WeComAIBot.Token != "" {
m.initChannel("wecom_aibot", "WeCom AI Bot")
}
if m.config.Channels.WeComApp.Enabled && m.config.Channels.WeComApp.CorpID != "" {
m.initChannel("wecom_app", "WeCom App")
}
@@ -539,86 +543,88 @@ func (m *Manager) sendWithRetry(ctx context.Context, name string, w *channelWork
})
}
func (m *Manager) dispatchOutbound(ctx context.Context) {
logger.InfoC("channels", "Outbound dispatcher started")
func dispatchLoop[M any](
ctx context.Context,
m *Manager,
subscribe func(context.Context) (M, bool),
getChannel func(M) string,
enqueue func(context.Context, *channelWorker, M) bool,
startMsg, stopMsg, unknownMsg, noWorkerMsg string,
) {
logger.InfoC("channels", startMsg)
for {
msg, ok := m.bus.SubscribeOutbound(ctx)
msg, ok := subscribe(ctx)
if !ok {
logger.InfoC("channels", "Outbound dispatcher stopped")
logger.InfoC("channels", stopMsg)
return
}
channel := getChannel(msg)
// Silently skip internal channels
if constants.IsInternalChannel(msg.Channel) {
if constants.IsInternalChannel(channel) {
continue
}
m.mu.RLock()
_, exists := m.channels[msg.Channel]
w, wExists := m.workers[msg.Channel]
_, exists := m.channels[channel]
w, wExists := m.workers[channel]
m.mu.RUnlock()
if !exists {
logger.WarnCF("channels", "Unknown channel for outbound message", map[string]any{
"channel": msg.Channel,
})
logger.WarnCF("channels", unknownMsg, map[string]any{"channel": channel})
continue
}
if wExists && w != nil {
select {
case w.queue <- msg:
case <-ctx.Done():
if !enqueue(ctx, w, msg) {
return
}
} else if exists {
logger.WarnCF("channels", "Channel has no active worker, skipping message", map[string]any{
"channel": msg.Channel,
})
logger.WarnCF("channels", noWorkerMsg, map[string]any{"channel": channel})
}
}
}
func (m *Manager) dispatchOutbound(ctx context.Context) {
dispatchLoop(
ctx, m,
m.bus.SubscribeOutbound,
func(msg bus.OutboundMessage) string { return msg.Channel },
func(ctx context.Context, w *channelWorker, msg bus.OutboundMessage) bool {
select {
case w.queue <- msg:
return true
case <-ctx.Done():
return false
}
},
"Outbound dispatcher started",
"Outbound dispatcher stopped",
"Unknown channel for outbound message",
"Channel has no active worker, skipping message",
)
}
func (m *Manager) dispatchOutboundMedia(ctx context.Context) {
logger.InfoC("channels", "Outbound media dispatcher started")
for {
msg, ok := m.bus.SubscribeOutboundMedia(ctx)
if !ok {
logger.InfoC("channels", "Outbound media dispatcher stopped")
return
}
// Silently skip internal channels
if constants.IsInternalChannel(msg.Channel) {
continue
}
m.mu.RLock()
_, exists := m.channels[msg.Channel]
w, wExists := m.workers[msg.Channel]
m.mu.RUnlock()
if !exists {
logger.WarnCF("channels", "Unknown channel for outbound media message", map[string]any{
"channel": msg.Channel,
})
continue
}
if wExists && w != nil {
dispatchLoop(
ctx, m,
m.bus.SubscribeOutboundMedia,
func(msg bus.OutboundMediaMessage) string { return msg.Channel },
func(ctx context.Context, w *channelWorker, msg bus.OutboundMediaMessage) bool {
select {
case w.mediaQueue <- msg:
return true
case <-ctx.Done():
return
return false
}
} else if exists {
logger.WarnCF("channels", "Channel has no active worker, skipping media message", map[string]any{
"channel": msg.Channel,
})
}
}
},
"Outbound media dispatcher started",
"Outbound media dispatcher stopped",
"Unknown channel for outbound media message",
"Channel has no active worker, skipping media message",
)
}
// runMediaWorker processes outbound media messages for a single channel.
+69 -9
View File
@@ -7,12 +7,12 @@ import (
"net/url"
"os"
"regexp"
"slices"
"strconv"
"strings"
"time"
"github.com/mymmrac/telego"
"github.com/mymmrac/telego/telegohandler"
th "github.com/mymmrac/telego/telegohandler"
tu "github.com/mymmrac/telego/telegoutil"
@@ -41,7 +41,7 @@ var (
type TelegramChannel struct {
*channels.BaseChannel
bot *telego.Bot
bh *telegohandler.BotHandler
bh *th.BotHandler
commands TelegramCommander
config *config.Config
chatIDs map[string]int64
@@ -72,6 +72,10 @@ func NewTelegramChannel(cfg *config.Config, bus *bus.MessageBus) (*TelegramChann
}))
}
if baseURL := strings.TrimRight(strings.TrimSpace(telegramCfg.BaseURL), "/"); baseURL != "" {
opts = append(opts, telego.WithAPIServer(baseURL))
}
bot, err := telego.NewBot(telegramCfg.Token, opts...)
if err != nil {
return nil, fmt.Errorf("failed to create telegram bot: %w", err)
@@ -101,6 +105,12 @@ func (c *TelegramChannel) Start(ctx context.Context) error {
c.ctx, c.cancel = context.WithCancel(ctx)
if err := c.initBotCommands(c.ctx); err != nil {
logger.WarnCF("telegram", "Failed to initialize bot commands", map[string]any{
"error": err.Error(),
})
}
updates, err := c.bot.UpdatesViaLongPolling(c.ctx, &telego.GetUpdatesParams{
Timeout: 30,
})
@@ -109,20 +119,19 @@ func (c *TelegramChannel) Start(ctx context.Context) error {
return fmt.Errorf("failed to start long polling: %w", err)
}
bh, err := telegohandler.NewBotHandler(c.bot, updates)
bh, err := th.NewBotHandler(c.bot, updates)
if err != nil {
c.cancel()
return fmt.Errorf("failed to create bot handler: %w", err)
}
c.bh = bh
bh.HandleMessage(func(ctx *th.Context, message telego.Message) error {
c.commands.Help(ctx, message)
return nil
}, th.CommandEqual("help"))
bh.HandleMessage(func(ctx *th.Context, message telego.Message) error {
return c.commands.Start(ctx, message)
}, th.CommandEqual("start"))
bh.HandleMessage(func(ctx *th.Context, message telego.Message) error {
return c.commands.Help(ctx, message)
}, th.CommandEqual("help"))
bh.HandleMessage(func(ctx *th.Context, message telego.Message) error {
return c.commands.Show(ctx, message)
@@ -141,7 +150,13 @@ func (c *TelegramChannel) Start(ctx context.Context) error {
"username": c.bot.Username(),
})
go bh.Start()
go func() {
if err = bh.Start(); err != nil {
logger.ErrorCF("telegram", "Bot handler failed", map[string]any{
"error": err.Error(),
})
}
}()
return nil
}
@@ -152,7 +167,7 @@ func (c *TelegramChannel) Stop(ctx context.Context) error {
// Stop the bot handler
if c.bh != nil {
c.bh.Stop()
_ = c.bh.StopWithContext(ctx)
}
// Cancel our context (stops long polling)
@@ -163,6 +178,51 @@ func (c *TelegramChannel) Stop(ctx context.Context) error {
return nil
}
func (c *TelegramChannel) initBotCommands(ctx context.Context) error {
currentCommands, err := c.bot.GetMyCommands(ctx, &telego.GetMyCommandsParams{
Scope: tu.ScopeDefault(),
})
if err != nil {
return fmt.Errorf("get commands: %w", err)
}
commands := []telego.BotCommand{
{
Command: "start",
Description: "Start the bot",
},
{
Command: "help",
Description: "Show a help message",
},
{
Command: "show",
Description: "Show current configuration",
},
{
Command: "list",
Description: "List available options",
},
}
// Setting commands on each start will hit the rate limit very quickly, that's why we check if an update is needed
if !slices.Equal(currentCommands, commands) {
logger.InfoC("telegram", "Updating bot commands")
err = c.bot.SetMyCommands(ctx, &telego.SetMyCommandsParams{
Commands: commands,
Scope: tu.ScopeDefault(),
})
if err != nil {
return fmt.Errorf("set commands: %w", err)
}
} else {
logger.DebugC("telegram", "Bot commands are up to date")
}
return nil
}
func (c *TelegramChannel) Send(ctx context.Context, msg bus.OutboundMessage) error {
if !c.IsRunning() {
return channels.ErrNotRunning
File diff suppressed because it is too large Load Diff
+210
View File
@@ -0,0 +1,210 @@
package wecom
import (
"context"
"testing"
"github.com/sipeed/picoclaw/pkg/bus"
"github.com/sipeed/picoclaw/pkg/config"
)
func TestNewWeComAIBotChannel(t *testing.T) {
t.Run("success with valid config", func(t *testing.T) {
cfg := config.WeComAIBotConfig{
Enabled: true,
Token: "test_token",
EncodingAESKey: "testkey1234567890123456789012345678901234567",
WebhookPath: "/webhook/test",
}
messageBus := bus.NewMessageBus()
ch, err := NewWeComAIBotChannel(cfg, messageBus)
if err != nil {
t.Fatalf("Expected no error, got %v", err)
}
if ch == nil {
t.Fatal("Expected channel to be created")
}
if ch.Name() != "wecom_aibot" {
t.Errorf("Expected name 'wecom_aibot', got '%s'", ch.Name())
}
})
t.Run("error with missing token", func(t *testing.T) {
cfg := config.WeComAIBotConfig{
Enabled: true,
EncodingAESKey: "testkey1234567890123456789012345678901234567",
}
messageBus := bus.NewMessageBus()
_, err := NewWeComAIBotChannel(cfg, messageBus)
if err == nil {
t.Fatal("Expected error for missing token, got nil")
}
})
t.Run("error with missing encoding key", func(t *testing.T) {
cfg := config.WeComAIBotConfig{
Enabled: true,
Token: "test_token",
}
messageBus := bus.NewMessageBus()
_, err := NewWeComAIBotChannel(cfg, messageBus)
if err == nil {
t.Fatal("Expected error for missing encoding key, got nil")
}
})
}
func TestWeComAIBotChannelStartStop(t *testing.T) {
cfg := config.WeComAIBotConfig{
Enabled: true,
Token: "test_token",
EncodingAESKey: "testkey1234567890123456789012345678901234567",
}
messageBus := bus.NewMessageBus()
ch, err := NewWeComAIBotChannel(cfg, messageBus)
if err != nil {
t.Fatalf("Failed to create channel: %v", err)
}
ctx := context.Background()
// Test Start
if err := ch.Start(ctx); err != nil {
t.Fatalf("Failed to start channel: %v", err)
}
if !ch.IsRunning() {
t.Error("Expected channel to be running")
}
// Test Stop
if err := ch.Stop(ctx); err != nil {
t.Fatalf("Failed to stop channel: %v", err)
}
if ch.IsRunning() {
t.Error("Expected channel to be stopped")
}
}
func TestWeComAIBotChannelWebhookPath(t *testing.T) {
t.Run("default path", func(t *testing.T) {
cfg := config.WeComAIBotConfig{
Enabled: true,
Token: "test_token",
EncodingAESKey: "testkey1234567890123456789012345678901234567",
}
messageBus := bus.NewMessageBus()
ch, _ := NewWeComAIBotChannel(cfg, messageBus)
expectedPath := "/webhook/wecom-aibot"
if ch.WebhookPath() != expectedPath {
t.Errorf("Expected webhook path '%s', got '%s'", expectedPath, ch.WebhookPath())
}
})
t.Run("custom path", func(t *testing.T) {
customPath := "/custom/webhook"
cfg := config.WeComAIBotConfig{
Enabled: true,
Token: "test_token",
EncodingAESKey: "testkey1234567890123456789012345678901234567",
WebhookPath: customPath,
}
messageBus := bus.NewMessageBus()
ch, _ := NewWeComAIBotChannel(cfg, messageBus)
if ch.WebhookPath() != customPath {
t.Errorf("Expected webhook path '%s', got '%s'", customPath, ch.WebhookPath())
}
})
}
func TestGenerateStreamID(t *testing.T) {
cfg := config.WeComAIBotConfig{
Enabled: true,
Token: "test_token",
EncodingAESKey: "testkey1234567890123456789012345678901234567",
}
messageBus := bus.NewMessageBus()
ch, _ := NewWeComAIBotChannel(cfg, messageBus)
// Generate multiple IDs and check they are unique
ids := make(map[string]bool)
for i := 0; i < 100; i++ {
id := ch.generateStreamID()
if len(id) != 10 {
t.Errorf("Expected stream ID length 10, got %d", len(id))
}
if ids[id] {
t.Errorf("Duplicate stream ID generated: %s", id)
}
ids[id] = true
}
}
func TestEncryptDecrypt(t *testing.T) {
// Use a valid 43-character base64 key (企业微信标准格式)
cfg := config.WeComAIBotConfig{
Enabled: true,
Token: "test_token",
EncodingAESKey: "abcdefghijklmnopqrstuvwxyz0123456789ABCDEFG", // 43 characters
}
messageBus := bus.NewMessageBus()
ch, _ := NewWeComAIBotChannel(cfg, messageBus)
plaintext := "Hello, World!"
receiveid := ""
// Encrypt
encrypted, err := ch.encryptMessage(plaintext, receiveid)
if err != nil {
t.Fatalf("Failed to encrypt message: %v", err)
}
if encrypted == "" {
t.Fatal("Encrypted message is empty")
}
// Decrypt
decrypted, err := decryptMessageWithVerify(encrypted, cfg.EncodingAESKey, receiveid)
if err != nil {
t.Fatalf("Failed to decrypt message: %v", err)
}
if decrypted != plaintext {
t.Errorf("Expected decrypted message '%s', got '%s'", plaintext, decrypted)
}
}
func TestGenerateSignature(t *testing.T) {
token := "test_token"
timestamp := "1234567890"
nonce := "test_nonce"
encrypt := "encrypted_msg"
signature := computeSignature(token, timestamp, nonce, encrypt)
if signature == "" {
t.Error("Generated signature is empty")
}
// Verify signature using verifySignature function
if !verifySignature(token, signature, timestamp, nonce, encrypt) {
t.Error("Generated signature does not verify correctly")
}
}
+19 -75
View File
@@ -38,8 +38,7 @@ type WeComAppChannel struct {
tokenMu sync.RWMutex
ctx context.Context
cancel context.CancelFunc
processedMsgs map[string]bool // Message deduplication: msg_id -> processed
msgMu sync.RWMutex
processedMsgs *MessageDeduplicator
}
// WeComXMLMessage represents the XML message structure from WeCom
@@ -144,7 +143,7 @@ func NewWeComAppChannel(cfg config.WeComAppConfig, messageBus *bus.MessageBus) (
client: &http.Client{Timeout: clientTimeout},
ctx: ctx,
cancel: cancel,
processedMsgs: make(map[string]bool),
processedMsgs: NewMessageDeduplicator(wecomMaxProcessedMessages),
}, nil
}
@@ -342,18 +341,11 @@ func (c *WeComAppChannel) uploadMedia(ctx context.Context, accessToken, mediaTyp
return result.MediaID, nil
}
// sendImageMessage sends an image message using a media_id.
func (c *WeComAppChannel) sendImageMessage(ctx context.Context, accessToken, userID, mediaID string) error {
// sendWeComMessage marshals payload and POSTs it to the WeCom message API.
func (c *WeComAppChannel) sendWeComMessage(ctx context.Context, accessToken string, payload any) error {
apiURL := fmt.Sprintf("%s/cgi-bin/message/send?access_token=%s", wecomAPIBase, accessToken)
msg := WeComImageMessage{
ToUser: userID,
MsgType: "image",
AgentID: c.config.AgentID,
}
msg.Image.MediaID = mediaID
jsonData, err := json.Marshal(msg)
jsonData, err := json.Marshal(payload)
if err != nil {
return fmt.Errorf("failed to marshal message: %w", err)
}
@@ -400,6 +392,17 @@ func (c *WeComAppChannel) sendImageMessage(ctx context.Context, accessToken, use
return nil
}
// sendImageMessage sends an image message using a media_id.
func (c *WeComAppChannel) sendImageMessage(ctx context.Context, accessToken, userID, mediaID string) error {
msg := WeComImageMessage{
ToUser: userID,
MsgType: "image",
AgentID: c.config.AgentID,
}
msg.Image.MediaID = mediaID
return c.sendWeComMessage(ctx, accessToken, msg)
}
// WebhookPath returns the path for registering on the shared HTTP server.
func (c *WeComAppChannel) WebhookPath() string {
if c.config.WebhookPath != "" {
@@ -603,23 +606,12 @@ func (c *WeComAppChannel) processMessage(ctx context.Context, msg WeComXMLMessag
// Message deduplication: Use msg_id to prevent duplicate processing
// As per WeCom documentation, use msg_id for deduplication
msgID := fmt.Sprintf("%d", msg.MsgId)
c.msgMu.Lock()
if c.processedMsgs[msgID] {
c.msgMu.Unlock()
if !c.processedMsgs.MarkMessageProcessed(msgID) {
logger.DebugCF("wecom_app", "Skipping duplicate message", map[string]any{
"msg_id": msgID,
})
return
}
c.processedMsgs[msgID] = true
// Clean up old messages while still holding the lock to avoid a data race
// on len(). Reset the map but re-insert the current msgID so it remains
// deduplicated.
if len(c.processedMsgs) > 1000 {
c.processedMsgs = make(map[string]bool)
c.processedMsgs[msgID] = true
}
c.msgMu.Unlock()
senderID := msg.FromUserName
chatID := senderID // WeCom App uses user ID as chat ID for direct messages
@@ -722,63 +714,15 @@ func (c *WeComAppChannel) getAccessToken() string {
return c.accessToken
}
// sendTextMessage sends a text message to a user
// sendTextMessage sends a text message to a user.
func (c *WeComAppChannel) sendTextMessage(ctx context.Context, accessToken, userID, content string) error {
apiURL := fmt.Sprintf("%s/cgi-bin/message/send?access_token=%s", wecomAPIBase, accessToken)
msg := WeComTextMessage{
ToUser: userID,
MsgType: "text",
AgentID: c.config.AgentID,
}
msg.Text.Content = content
jsonData, err := json.Marshal(msg)
if err != nil {
return fmt.Errorf("failed to marshal message: %w", err)
}
// Use configurable timeout (default 5 seconds)
timeout := c.config.ReplyTimeout
if timeout <= 0 {
timeout = 5
}
reqCtx, cancel := context.WithTimeout(ctx, time.Duration(timeout)*time.Second)
defer cancel()
req, err := http.NewRequestWithContext(reqCtx, http.MethodPost, apiURL, bytes.NewBuffer(jsonData))
if err != nil {
return fmt.Errorf("failed to create request: %w", err)
}
req.Header.Set("Content-Type", "application/json")
resp, err := c.client.Do(req)
if err != nil {
return channels.ClassifyNetError(err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(resp.Body)
return channels.ClassifySendError(resp.StatusCode, fmt.Errorf("wecom_app API error: %s", string(body)))
}
body, err := io.ReadAll(resp.Body)
if err != nil {
return fmt.Errorf("failed to read response: %w", err)
}
var sendResp WeComSendMessageResponse
if err := json.Unmarshal(body, &sendResp); err != nil {
return fmt.Errorf("failed to parse response: %w", err)
}
if sendResp.ErrCode != 0 {
return fmt.Errorf("API error: %s (code: %d)", sendResp.ErrMsg, sendResp.ErrCode)
}
return nil
return c.sendWeComMessage(ctx, accessToken, msg)
}
// handleHealth handles health check requests
-54
View File
@@ -323,60 +323,6 @@ func TestWeComAppDecryptMessage(t *testing.T) {
})
}
func TestWeComAppPKCS7Unpad(t *testing.T) {
tests := []struct {
name string
input []byte
expected []byte
}{
{
name: "empty input",
input: []byte{},
expected: []byte{},
},
{
name: "valid padding 3 bytes",
input: append([]byte("hello"), bytes.Repeat([]byte{3}, 3)...),
expected: []byte("hello"),
},
{
name: "valid padding 16 bytes (full block)",
input: append([]byte("123456789012345"), bytes.Repeat([]byte{16}, 16)...),
expected: []byte("123456789012345"),
},
{
name: "invalid padding larger than data",
input: []byte{20},
expected: nil, // should return error
},
{
name: "invalid padding zero",
input: append([]byte("test"), byte(0)),
expected: nil, // should return error
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result, err := pkcs7Unpad(tt.input)
if tt.expected == nil {
// This case should return an error
if err == nil {
t.Errorf("pkcs7Unpad() expected error for invalid padding, got result: %v", result)
}
return
}
if err != nil {
t.Errorf("pkcs7Unpad() unexpected error: %v", err)
return
}
if !bytes.Equal(result, tt.expected) {
t.Errorf("pkcs7Unpad() = %v, want %v", result, tt.expected)
}
})
}
}
func TestWeComAppHandleVerification(t *testing.T) {
msgBus := bus.NewMessageBus()
aesKey := generateTestAESKeyApp()
+3 -16
View File
@@ -9,7 +9,6 @@ import (
"io"
"net/http"
"strings"
"sync"
"time"
"github.com/sipeed/picoclaw/pkg/bus"
@@ -28,8 +27,7 @@ type WeComBotChannel struct {
client *http.Client
ctx context.Context
cancel context.CancelFunc
processedMsgs map[string]bool // Message deduplication: msg_id -> processed
msgMu sync.RWMutex
processedMsgs *MessageDeduplicator
}
// WeComBotMessage represents the JSON message structure from WeCom Bot (AIBOT)
@@ -108,7 +106,7 @@ func NewWeComBotChannel(cfg config.WeComConfig, messageBus *bus.MessageBus) (*We
client: &http.Client{Timeout: clientTimeout},
ctx: ctx,
cancel: cancel,
processedMsgs: make(map[string]bool),
processedMsgs: NewMessageDeduplicator(wecomMaxProcessedMessages),
}, nil
}
@@ -330,23 +328,12 @@ func (c *WeComBotChannel) processMessage(ctx context.Context, msg WeComBotMessag
// Message deduplication: Use msg_id to prevent duplicate processing
msgID := msg.MsgID
c.msgMu.Lock()
if c.processedMsgs[msgID] {
c.msgMu.Unlock()
if !c.processedMsgs.MarkMessageProcessed(msgID) {
logger.DebugCF("wecom", "Skipping duplicate message", map[string]any{
"msg_id": msgID,
})
return
}
c.processedMsgs[msgID] = true
// Clean up old messages while still holding the lock to avoid a data race
// on len(). Reset the map but re-insert the current msgID so it remains
// deduplicated.
if len(c.processedMsgs) > 1000 {
c.processedMsgs = make(map[string]bool)
c.processedMsgs[msgID] = true
}
c.msgMu.Unlock()
senderID := msg.From.UserID
+16 -47
View File
@@ -412,22 +412,9 @@ func TestWeComBotHandleMessageCallback(t *testing.T) {
}
ch, _ := NewWeComBotChannel(cfg, msgBus)
t.Run("valid direct message callback", func(t *testing.T) {
// Create JSON message for direct chat (single)
jsonMsg := `{
"msgid": "test_msg_id_123",
"aibotid": "test_aibot_id",
"chattype": "single",
"from": {"userid": "user123"},
"response_url": "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test",
"msgtype": "text",
"text": {"content": "Hello World"}
}`
// Encrypt message
runBotMessageCallback := func(t *testing.T, jsonMsg string) *httptest.ResponseRecorder {
t.Helper()
encrypted, _ := encryptTestMessage(jsonMsg, aesKey)
// Create encrypted XML wrapper
encryptedWrapper := struct {
XMLName xml.Name `xml:"xml"`
Encrypt string `xml:"Encrypt"`
@@ -435,20 +422,29 @@ func TestWeComBotHandleMessageCallback(t *testing.T) {
Encrypt: encrypted,
}
wrapperData, _ := xml.Marshal(encryptedWrapper)
timestamp := "1234567890"
nonce := "test_nonce"
signature := generateSignature("test_token", timestamp, nonce, encrypted)
req := httptest.NewRequest(
http.MethodPost,
"/webhook/wecom?msg_signature="+signature+"&timestamp="+timestamp+"&nonce="+nonce,
bytes.NewReader(wrapperData),
)
w := httptest.NewRecorder()
ch.handleMessageCallback(context.Background(), w, req)
return w
}
t.Run("valid direct message callback", func(t *testing.T) {
w := runBotMessageCallback(t, `{
"msgid": "test_msg_id_123",
"aibotid": "test_aibot_id",
"chattype": "single",
"from": {"userid": "user123"},
"response_url": "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test",
"msgtype": "text",
"text": {"content": "Hello World"}
}`)
if w.Code != http.StatusOK {
t.Errorf("status code = %d, want %d", w.Code, http.StatusOK)
}
@@ -458,8 +454,7 @@ func TestWeComBotHandleMessageCallback(t *testing.T) {
})
t.Run("valid group message callback", func(t *testing.T) {
// Create JSON message for group chat
jsonMsg := `{
w := runBotMessageCallback(t, `{
"msgid": "test_msg_id_456",
"aibotid": "test_aibot_id",
"chatid": "group_chat_id_123",
@@ -468,33 +463,7 @@ func TestWeComBotHandleMessageCallback(t *testing.T) {
"response_url": "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test",
"msgtype": "text",
"text": {"content": "Hello Group"}
}`
// Encrypt message
encrypted, _ := encryptTestMessage(jsonMsg, aesKey)
// Create encrypted XML wrapper
encryptedWrapper := struct {
XMLName xml.Name `xml:"xml"`
Encrypt string `xml:"Encrypt"`
}{
Encrypt: encrypted,
}
wrapperData, _ := xml.Marshal(encryptedWrapper)
timestamp := "1234567890"
nonce := "test_nonce"
signature := generateSignature("test_token", timestamp, nonce, encrypted)
req := httptest.NewRequest(
http.MethodPost,
"/webhook/wecom?msg_signature="+signature+"&timestamp="+timestamp+"&nonce="+nonce,
bytes.NewReader(wrapperData),
)
w := httptest.NewRecorder()
ch.handleMessageCallback(context.Background(), w, req)
}`)
if w.Code != http.StatusOK {
t.Errorf("status code = %d, want %d", w.Code, http.StatusOK)
}
+112 -47
View File
@@ -1,12 +1,15 @@
package wecom
import (
"bytes"
"crypto/aes"
"crypto/cipher"
"crypto/rand"
"crypto/sha1"
"encoding/base64"
"encoding/binary"
"fmt"
"math/big"
"sort"
"strings"
)
@@ -14,25 +17,23 @@ import (
// blockSize is the PKCS7 block size used by WeCom (32)
const blockSize = 32
// computeSignature computes the WeCom message signature from the given parameters.
// It sorts [token, timestamp, nonce, encrypt], concatenates them and returns the SHA1 hex digest.
func computeSignature(token, timestamp, nonce, encrypt string) string {
params := []string{token, timestamp, nonce, encrypt}
sort.Strings(params)
str := strings.Join(params, "")
hash := sha1.Sum([]byte(str))
return fmt.Sprintf("%x", hash)
}
// verifySignature verifies the message signature for WeCom
// This is a common function used by both WeCom Bot and WeCom App
func verifySignature(token, msgSignature, timestamp, nonce, msgEncrypt string) bool {
if token == "" {
return true // Skip verification if token is not set
}
// Sort parameters
params := []string{token, timestamp, nonce, msgEncrypt}
sort.Strings(params)
// Concatenate
str := strings.Join(params, "")
// SHA1 hash
hash := sha1.Sum([]byte(str))
expectedSignature := fmt.Sprintf("%x", hash)
return expectedSignature == msgSignature
return computeSignature(token, timestamp, nonce, msgEncrypt) == msgSignature
}
// decryptMessage decrypts the encrypted message using AES
@@ -53,64 +54,128 @@ func decryptMessageWithVerify(encryptedMsg, encodingAESKey, receiveid string) (s
return string(decoded), nil
}
// Decode AES key (base64)
aesKey, err := base64.StdEncoding.DecodeString(encodingAESKey + "=")
aesKey, err := decodeWeComAESKey(encodingAESKey)
if err != nil {
return "", fmt.Errorf("failed to decode AES key: %w", err)
return "", err
}
// Decode encrypted message
cipherText, err := base64.StdEncoding.DecodeString(encryptedMsg)
if err != nil {
return "", fmt.Errorf("failed to decode message: %w", err)
}
// AES decrypt
plainText, err := decryptAESCBC(aesKey, cipherText)
if err != nil {
return "", err
}
return unpackWeComFrame(plainText, receiveid)
}
// decodeWeComAESKey base64-decodes the 43-character EncodingAESKey (trailing "=" is
// appended automatically) and validates that the result is exactly 32 bytes.
// It is the single place that handles this repeated pattern in both encrypt and decrypt paths.
func decodeWeComAESKey(encodingAESKey string) ([]byte, error) {
aesKey, err := base64.StdEncoding.DecodeString(encodingAESKey + "=")
if err != nil {
return nil, fmt.Errorf("failed to decode AES key: %w", err)
}
if len(aesKey) != 32 {
return nil, fmt.Errorf("invalid AES key length: %d", len(aesKey))
}
return aesKey, nil
}
// encryptAESCBC encrypts plaintext using AES-CBC with the given key, mirroring
// decryptAESCBC. IV = aesKey[:aes.BlockSize]. The caller must PKCS7-pad the
// plaintext to a multiple of aes.BlockSize before calling.
func encryptAESCBC(aesKey, plaintext []byte) ([]byte, error) {
block, err := aes.NewCipher(aesKey)
if err != nil {
return "", fmt.Errorf("failed to create cipher: %w", err)
return nil, fmt.Errorf("failed to create cipher: %w", err)
}
if len(cipherText) < aes.BlockSize {
return "", fmt.Errorf("ciphertext too short")
}
// IV is the first 16 bytes of AESKey
iv := aesKey[:aes.BlockSize]
mode := cipher.NewCBCDecrypter(block, iv)
plainText := make([]byte, len(cipherText))
mode.CryptBlocks(plainText, cipherText)
ciphertext := make([]byte, len(plaintext))
cipher.NewCBCEncrypter(block, iv).CryptBlocks(ciphertext, plaintext)
return ciphertext, nil
}
// Remove PKCS7 padding
plainText, err = pkcs7Unpad(plainText)
if err != nil {
return "", fmt.Errorf("failed to unpad: %w", err)
// packWeComFrame builds the WeCom wire format:
//
// random(16 ASCII digits) + msg_len(4, big-endian) + msg + receiveid
func packWeComFrame(msg, receiveid string) ([]byte, error) {
randomBytes := make([]byte, 16)
for i := range 16 {
n, err := rand.Int(rand.Reader, big.NewInt(10))
if err != nil {
return nil, fmt.Errorf("failed to generate random: %w", err)
}
randomBytes[i] = byte('0' + n.Int64())
}
msgBytes := []byte(msg)
msgLenBytes := make([]byte, 4)
binary.BigEndian.PutUint32(msgLenBytes, uint32(len(msgBytes)))
var buf bytes.Buffer
buf.Write(randomBytes)
buf.Write(msgLenBytes)
buf.Write(msgBytes)
buf.WriteString(receiveid)
return buf.Bytes(), nil
}
// Parse message structure
// Format: random(16) + msg_len(4) + msg + receiveid
if len(plainText) < 20 {
return "", fmt.Errorf("decrypted message too short")
// unpackWeComFrame parses the WeCom wire format produced by packWeComFrame.
// If receiveid is non-empty it verifies the frame's trailing receiveid field.
func unpackWeComFrame(data []byte, receiveid string) (string, error) {
if len(data) < 20 {
return "", fmt.Errorf("decrypted frame too short: %d bytes", len(data))
}
msgLen := binary.BigEndian.Uint32(plainText[16:20])
if int(msgLen) > len(plainText)-20 {
return "", fmt.Errorf("invalid message length")
msgLen := binary.BigEndian.Uint32(data[16:20])
if int(msgLen) > len(data)-20 {
return "", fmt.Errorf("invalid message length: %d", msgLen)
}
msg := plainText[20 : 20+msgLen]
// Verify receiveid if provided
if receiveid != "" && len(plainText) > 20+int(msgLen) {
actualReceiveID := string(plainText[20+msgLen:])
msg := data[20 : 20+msgLen]
if receiveid != "" && len(data) > 20+int(msgLen) {
actualReceiveID := string(data[20+msgLen:])
if actualReceiveID != receiveid {
return "", fmt.Errorf("receiveid mismatch: expected %s, got %s", receiveid, actualReceiveID)
}
}
return string(msg), nil
}
// decryptAESCBC decrypts ciphertext using AES-CBC with the given key.
// IV = aesKey[:aes.BlockSize]. PKCS7 padding is stripped from the returned plaintext.
func decryptAESCBC(aesKey, ciphertext []byte) ([]byte, error) {
if len(ciphertext) == 0 {
return nil, fmt.Errorf("ciphertext is empty")
}
if len(ciphertext)%aes.BlockSize != 0 {
return nil, fmt.Errorf("ciphertext length %d is not a multiple of block size", len(ciphertext))
}
block, err := aes.NewCipher(aesKey)
if err != nil {
return nil, fmt.Errorf("failed to create cipher: %w", err)
}
iv := aesKey[:aes.BlockSize]
plaintext := make([]byte, len(ciphertext))
cipher.NewCBCDecrypter(block, iv).CryptBlocks(plaintext, ciphertext)
plaintext, err = pkcs7Unpad(plaintext)
if err != nil {
return nil, fmt.Errorf("failed to unpad: %w", err)
}
return plaintext, nil
}
// pkcs7Pad adds PKCS7 padding
func pkcs7Pad(data []byte, blockSize int) []byte {
padding := blockSize - (len(data) % blockSize)
if padding == 0 {
padding = blockSize
}
padText := bytes.Repeat([]byte{byte(padding)}, padding)
return append(data, padText...)
}
// pkcs7Unpad removes PKCS7 padding with validation
func pkcs7Unpad(data []byte) ([]byte, error) {
if len(data) == 0 {
+54
View File
@@ -0,0 +1,54 @@
package wecom
import "sync"
const wecomMaxProcessedMessages = 1000
// MessageDeduplicator provides thread-safe message deduplication using a circular queue (ring buffer)
// combined with a hash map. This ensures fast O(1) lookups while naturally evicting the oldest
// messages without causing "amnesia cliffs" when the limit is reached.
type MessageDeduplicator struct {
mu sync.Mutex
msgs map[string]bool
ring []string
idx int
max int
}
// NewMessageDeduplicator creates a new deduplicator with the specified capacity.
func NewMessageDeduplicator(maxEntries int) *MessageDeduplicator {
if maxEntries <= 0 {
maxEntries = wecomMaxProcessedMessages
}
return &MessageDeduplicator{
msgs: make(map[string]bool, maxEntries),
ring: make([]string, maxEntries),
max: maxEntries,
}
}
// MarkMessageProcessed marks msgID as processed and returns false for duplicates.
func (d *MessageDeduplicator) MarkMessageProcessed(msgID string) bool {
d.mu.Lock()
defer d.mu.Unlock()
// 1. Check for duplicate
if d.msgs[msgID] {
return false
}
// 2. Evict the oldest message at our current ring position (if any)
oldestID := d.ring[d.idx]
if oldestID != "" {
delete(d.msgs, oldestID)
}
// 3. Store the new message
d.msgs[msgID] = true
d.ring[d.idx] = msgID
// 4. Advance the circle queue index
d.idx = (d.idx + 1) % d.max
return true
}
+83
View File
@@ -0,0 +1,83 @@
package wecom
import (
"sync"
"testing"
)
func TestMessageDeduplicator_DuplicateDetection(t *testing.T) {
d := NewMessageDeduplicator(wecomMaxProcessedMessages)
if ok := d.MarkMessageProcessed("msg-1"); !ok {
t.Fatalf("first message should be accepted")
}
if ok := d.MarkMessageProcessed("msg-1"); ok {
t.Fatalf("duplicate message should be rejected")
}
}
func TestMessageDeduplicator_ConcurrentSameMessage(t *testing.T) {
d := NewMessageDeduplicator(wecomMaxProcessedMessages)
const goroutines = 64
var wg sync.WaitGroup
wg.Add(goroutines)
results := make(chan bool, goroutines)
for i := 0; i < goroutines; i++ {
go func() {
defer wg.Done()
results <- d.MarkMessageProcessed("msg-concurrent")
}()
}
wg.Wait()
close(results)
successes := 0
for ok := range results {
if ok {
successes++
}
}
if successes != 1 {
t.Fatalf("expected exactly 1 successful mark, got %d", successes)
}
}
func TestMessageDeduplicator_CircularQueueEviction(t *testing.T) {
// Create a deduplicator with a very small capacity to test eviction easily.
capacity := 3
d := NewMessageDeduplicator(capacity)
// Fill the queue.
d.MarkMessageProcessed("msg-1")
d.MarkMessageProcessed("msg-2")
d.MarkMessageProcessed("msg-3")
// At this point, the queue is full. msg-1 is the oldest.
if len(d.msgs) != 3 {
t.Fatalf("expected map size to be 3, got %d", len(d.msgs))
}
// This should evict msg-1 and add msg-4.
if ok := d.MarkMessageProcessed("msg-4"); !ok {
t.Fatalf("msg-4 should be accepted")
}
if len(d.msgs) != 3 {
t.Fatalf("expected map size to remain at max capacity (3), got %d", len(d.msgs))
}
// msg-1 should now be forgotten (evicted).
if ok := d.MarkMessageProcessed("msg-1"); !ok {
t.Fatalf("msg-1 should be accepted again because it was evicted")
}
// msg-2 should have been evicted when we added msg-1 back.
if ok := d.MarkMessageProcessed("msg-2"); !ok {
t.Fatalf("msg-2 should be accepted again because it was evicted")
}
}
+3
View File
@@ -13,4 +13,7 @@ func init() {
channels.RegisterFactory("wecom_app", func(cfg *config.Config, b *bus.MessageBus) (channels.Channel, error) {
return NewWeComAppChannel(cfg.Channels.WeComApp, b)
})
channels.RegisterFactory("wecom_aibot", func(cfg *config.Config, b *bus.MessageBus) (channels.Channel, error) {
return NewWeComAIBotChannel(cfg.Channels.WeComAIBot, b)
})
}
+89 -14
View File
@@ -180,6 +180,18 @@ type AgentDefaults struct {
MaxTokens int `json:"max_tokens" env:"PICOCLAW_AGENTS_DEFAULTS_MAX_TOKENS"`
Temperature *float64 `json:"temperature,omitempty" env:"PICOCLAW_AGENTS_DEFAULTS_TEMPERATURE"`
MaxToolIterations int `json:"max_tool_iterations" env:"PICOCLAW_AGENTS_DEFAULTS_MAX_TOOL_ITERATIONS"`
SummarizeMessageThreshold int `json:"summarize_message_threshold" env:"PICOCLAW_AGENTS_DEFAULTS_SUMMARIZE_MESSAGE_THRESHOLD"`
SummarizeTokenPercent int `json:"summarize_token_percent" env:"PICOCLAW_AGENTS_DEFAULTS_SUMMARIZE_TOKEN_PERCENT"`
MaxMediaSize int `json:"max_media_size,omitempty" env:"PICOCLAW_AGENTS_DEFAULTS_MAX_MEDIA_SIZE"`
}
const DefaultMaxMediaSize = 20 * 1024 * 1024 // 20 MB
func (d *AgentDefaults) GetMaxMediaSize() int {
if d.MaxMediaSize > 0 {
return d.MaxMediaSize
}
return DefaultMaxMediaSize
}
// GetModelName returns the effective model name for the agent defaults.
@@ -192,19 +204,20 @@ func (d *AgentDefaults) GetModelName() string {
}
type ChannelsConfig struct {
WhatsApp WhatsAppConfig `json:"whatsapp"`
Telegram TelegramConfig `json:"telegram"`
Feishu FeishuConfig `json:"feishu"`
Discord DiscordConfig `json:"discord"`
MaixCam MaixCamConfig `json:"maixcam"`
QQ QQConfig `json:"qq"`
DingTalk DingTalkConfig `json:"dingtalk"`
Slack SlackConfig `json:"slack"`
LINE LINEConfig `json:"line"`
OneBot OneBotConfig `json:"onebot"`
WeCom WeComConfig `json:"wecom"`
WeComApp WeComAppConfig `json:"wecom_app"`
Pico PicoConfig `json:"pico"`
WhatsApp WhatsAppConfig `json:"whatsapp"`
Telegram TelegramConfig `json:"telegram"`
Feishu FeishuConfig `json:"feishu"`
Discord DiscordConfig `json:"discord"`
MaixCam MaixCamConfig `json:"maixcam"`
QQ QQConfig `json:"qq"`
DingTalk DingTalkConfig `json:"dingtalk"`
Slack SlackConfig `json:"slack"`
LINE LINEConfig `json:"line"`
OneBot OneBotConfig `json:"onebot"`
WeCom WeComConfig `json:"wecom"`
WeComApp WeComAppConfig `json:"wecom_app"`
WeComAIBot WeComAIBotConfig `json:"wecom_aibot"`
Pico PicoConfig `json:"pico"`
}
// GroupTriggerConfig controls when the bot responds in group chats.
@@ -236,6 +249,7 @@ type WhatsAppConfig struct {
type TelegramConfig struct {
Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_TELEGRAM_ENABLED"`
Token string `json:"token" env:"PICOCLAW_CHANNELS_TELEGRAM_TOKEN"`
BaseURL string `json:"base_url" env:"PICOCLAW_CHANNELS_TELEGRAM_BASE_URL"`
Proxy string `json:"proxy" env:"PICOCLAW_CHANNELS_TELEGRAM_PROXY"`
AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_TELEGRAM_ALLOW_FROM"`
GroupTrigger GroupTriggerConfig `json:"group_trigger,omitempty"`
@@ -252,12 +266,14 @@ type FeishuConfig struct {
VerificationToken string `json:"verification_token" env:"PICOCLAW_CHANNELS_FEISHU_VERIFICATION_TOKEN"`
AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_FEISHU_ALLOW_FROM"`
GroupTrigger GroupTriggerConfig `json:"group_trigger,omitempty"`
Placeholder PlaceholderConfig `json:"placeholder,omitempty"`
ReasoningChannelID string `json:"reasoning_channel_id" env:"PICOCLAW_CHANNELS_FEISHU_REASONING_CHANNEL_ID"`
}
type DiscordConfig struct {
Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_DISCORD_ENABLED"`
Token string `json:"token" env:"PICOCLAW_CHANNELS_DISCORD_TOKEN"`
Proxy string `json:"proxy" env:"PICOCLAW_CHANNELS_DISCORD_PROXY"`
AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_DISCORD_ALLOW_FROM"`
MentionOnly bool `json:"mention_only" env:"PICOCLAW_CHANNELS_DISCORD_MENTION_ONLY"`
GroupTrigger GroupTriggerConfig `json:"group_trigger,omitempty"`
@@ -360,6 +376,18 @@ type WeComAppConfig struct {
ReasoningChannelID string `json:"reasoning_channel_id" env:"PICOCLAW_CHANNELS_WECOM_APP_REASONING_CHANNEL_ID"`
}
type WeComAIBotConfig struct {
Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_WECOM_AIBOT_ENABLED"`
Token string `json:"token" env:"PICOCLAW_CHANNELS_WECOM_AIBOT_TOKEN"`
EncodingAESKey string `json:"encoding_aes_key" env:"PICOCLAW_CHANNELS_WECOM_AIBOT_ENCODING_AES_KEY"`
WebhookPath string `json:"webhook_path" env:"PICOCLAW_CHANNELS_WECOM_AIBOT_WEBHOOK_PATH"`
AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_WECOM_AIBOT_ALLOW_FROM"`
ReplyTimeout int `json:"reply_timeout" env:"PICOCLAW_CHANNELS_WECOM_AIBOT_REPLY_TIMEOUT"`
MaxSteps int `json:"max_steps" env:"PICOCLAW_CHANNELS_WECOM_AIBOT_MAX_STEPS"` // Maximum streaming steps
WelcomeMessage string `json:"welcome_message" env:"PICOCLAW_CHANNELS_WECOM_AIBOT_WELCOME_MESSAGE"` // Sent on enter_chat event; empty = no welcome
ReasoningChannelID string `json:"reasoning_channel_id" env:"PICOCLAW_CHANNELS_WECOM_AIBOT_REASONING_CHANNEL_ID"`
}
type PicoConfig struct {
Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_PICO_ENABLED"`
Token string `json:"token" env:"PICOCLAW_CHANNELS_PICO_TOKEN"`
@@ -386,6 +414,7 @@ type DevicesConfig struct {
type ProvidersConfig struct {
Anthropic ProviderConfig `json:"anthropic"`
OpenAI OpenAIProviderConfig `json:"openai"`
LiteLLM ProviderConfig `json:"litellm"`
OpenRouter ProviderConfig `json:"openrouter"`
Groq ProviderConfig `json:"groq"`
Zhipu ProviderConfig `json:"zhipu"`
@@ -410,6 +439,7 @@ type ProvidersConfig struct {
func (p ProvidersConfig) IsEmpty() bool {
return p.Anthropic.APIKey == "" && p.Anthropic.APIBase == "" &&
p.OpenAI.APIKey == "" && p.OpenAI.APIBase == "" &&
p.LiteLLM.APIKey == "" && p.LiteLLM.APIBase == "" &&
p.OpenRouter.APIKey == "" && p.OpenRouter.APIBase == "" &&
p.Groq.APIKey == "" && p.Groq.APIBase == "" &&
p.Zhipu.APIKey == "" && p.Zhipu.APIBase == "" &&
@@ -519,11 +549,22 @@ type PerplexityConfig struct {
MaxResults int `json:"max_results" env:"PICOCLAW_TOOLS_WEB_PERPLEXITY_MAX_RESULTS"`
}
type GLMSearchConfig struct {
Enabled bool `json:"enabled" env:"PICOCLAW_TOOLS_WEB_GLM_ENABLED"`
APIKey string `json:"api_key" env:"PICOCLAW_TOOLS_WEB_GLM_API_KEY"`
BaseURL string `json:"base_url" env:"PICOCLAW_TOOLS_WEB_GLM_BASE_URL"`
// SearchEngine specifies the search backend: "search_std" (default),
// "search_pro", "search_pro_sogou", or "search_pro_quark".
SearchEngine string `json:"search_engine" env:"PICOCLAW_TOOLS_WEB_GLM_SEARCH_ENGINE"`
MaxResults int `json:"max_results" env:"PICOCLAW_TOOLS_WEB_GLM_MAX_RESULTS"`
}
type WebToolsConfig struct {
Brave BraveConfig `json:"brave"`
Tavily TavilyConfig `json:"tavily"`
DuckDuckGo DuckDuckGoConfig `json:"duckduckgo"`
Perplexity PerplexityConfig `json:"perplexity"`
GLMSearch GLMSearchConfig `json:"glm_search"`
// Proxy is an optional proxy URL for web tools (http/https/socks5/socks5h).
// For authenticated proxies, prefer HTTP_PROXY/HTTPS_PROXY env vars instead of embedding credentials in config.
Proxy string `json:"proxy,omitempty" env:"PICOCLAW_TOOLS_WEB_PROXY"`
@@ -554,6 +595,7 @@ type ToolsConfig struct {
Exec ExecConfig `json:"exec"`
Skills SkillsToolsConfig `json:"skills"`
MediaCleanup MediaCleanupConfig `json:"media_cleanup"`
MCP MCPConfig `json:"mcp"`
}
type SkillsToolsConfig struct {
@@ -583,6 +625,34 @@ type ClawHubRegistryConfig struct {
MaxResponseSize int `json:"max_response_size" env:"PICOCLAW_SKILLS_REGISTRIES_CLAWHUB_MAX_RESPONSE_SIZE"`
}
// MCPServerConfig defines configuration for a single MCP server
type MCPServerConfig struct {
// Enabled indicates whether this MCP server is active
Enabled bool `json:"enabled"`
// Command is the executable to run (e.g., "npx", "python", "/path/to/server")
Command string `json:"command"`
// Args are the arguments to pass to the command
Args []string `json:"args,omitempty"`
// Env are environment variables to set for the server process (stdio only)
Env map[string]string `json:"env,omitempty"`
// EnvFile is the path to a file containing environment variables (stdio only)
EnvFile string `json:"env_file,omitempty"`
// Type is "stdio", "sse", or "http" (default: stdio if command is set, sse if url is set)
Type string `json:"type,omitempty"`
// URL is used for SSE/HTTP transport
URL string `json:"url,omitempty"`
// Headers are HTTP headers to send with requests (sse/http only)
Headers map[string]string `json:"headers,omitempty"`
}
// MCPConfig defines configuration for all MCP servers
type MCPConfig struct {
// Enabled globally enables/disables MCP integration
Enabled bool `json:"enabled" env:"PICOCLAW_TOOLS_MCP_ENABLED"`
// Servers is a map of server name to server configuration
Servers map[string]MCPServerConfig `json:"servers,omitempty"`
}
func LoadConfig(path string) (*Config, error) {
cfg := DefaultConfig()
@@ -639,7 +709,8 @@ func (c *Config) migrateChannelConfigs() {
}
// OneBot: group_trigger_prefix -> group_trigger.prefixes
if len(c.Channels.OneBot.GroupTriggerPrefix) > 0 && len(c.Channels.OneBot.GroupTrigger.Prefixes) == 0 {
if len(c.Channels.OneBot.GroupTriggerPrefix) > 0 &&
len(c.Channels.OneBot.GroupTrigger.Prefixes) == 0 {
c.Channels.OneBot.GroupTrigger.Prefixes = c.Channels.OneBot.GroupTriggerPrefix
}
}
@@ -749,6 +820,7 @@ func (c *Config) findMatches(modelName string) []ModelConfig {
// HasProvidersConfig checks if any provider in the old providers config has configuration.
func (c *Config) HasProvidersConfig() bool {
<<<<<<< HEAD
v := c.Providers
return v.Anthropic.APIKey != "" || v.Anthropic.APIBase != "" ||
v.OpenAI.APIKey != "" || v.OpenAI.APIBase != "" ||
@@ -769,6 +841,9 @@ func (c *Config) HasProvidersConfig() bool {
v.Antigravity.APIKey != "" || v.Antigravity.APIBase != "" ||
v.Qwen.APIKey != "" || v.Qwen.APIBase != "" ||
v.Mistral.APIKey != "" || v.Mistral.APIBase != ""
=======
return !c.Providers.IsEmpty()
>>>>>>> origin_picoclaw/main
}
// ValidateModelList validates all ModelConfig entries in the model_list.
+12
View File
@@ -435,6 +435,18 @@ func TestLoadConfig_WebToolsProxy(t *testing.T) {
}
// TestDefaultConfig_DMScope verifies the default dm_scope value
// TestDefaultConfig_SummarizationThresholds verifies summarization defaults
func TestDefaultConfig_SummarizationThresholds(t *testing.T) {
cfg := DefaultConfig()
if cfg.Agents.Defaults.SummarizeMessageThreshold != 20 {
t.Errorf("SummarizeMessageThreshold = %d, want 20", cfg.Agents.Defaults.SummarizeMessageThreshold)
}
if cfg.Agents.Defaults.SummarizeTokenPercent != 75 {
t.Errorf("SummarizeTokenPercent = %d, want 75", cfg.Agents.Defaults.SummarizeTokenPercent)
}
}
func TestDefaultConfig_DMScope(t *testing.T) {
cfg := DefaultConfig()
+30 -7
View File
@@ -26,13 +26,15 @@ func DefaultConfig() *Config {
return &Config{
Agents: AgentsConfig{
Defaults: AgentDefaults{
Workspace: workspacePath,
RestrictToWorkspace: true,
Provider: "",
Model: "",
MaxTokens: 32768,
Temperature: nil, // nil means use provider default
MaxToolIterations: 50,
Workspace: workspacePath,
RestrictToWorkspace: true,
Provider: "",
Model: "",
MaxTokens: 32768,
Temperature: nil, // nil means use provider default
MaxToolIterations: 50,
SummarizeMessageThreshold: 20,
SummarizeTokenPercent: 75,
},
},
Bindings: []AgentBinding{},
@@ -137,6 +139,16 @@ func DefaultConfig() *Config {
AllowFrom: FlexibleStringSlice{},
ReplyTimeout: 5,
},
WeComAIBot: WeComAIBotConfig{
Enabled: false,
Token: "",
EncodingAESKey: "",
WebhookPath: "/webhook/wecom-aibot",
AllowFrom: FlexibleStringSlice{},
ReplyTimeout: 5,
MaxSteps: 10,
WelcomeMessage: "Hello! I'm your AI assistant. How can I help you today?",
},
Pico: PicoConfig{
Enabled: false,
Token: "",
@@ -339,6 +351,13 @@ func DefaultConfig() *Config {
APIKey: "",
MaxResults: 5,
},
GLMSearch: GLMSearchConfig{
Enabled: false,
APIKey: "",
BaseURL: "https://open.bigmodel.cn/api/paas/v4/web_search",
SearchEngine: "search_std",
MaxResults: 5,
},
},
Cron: CronToolsConfig{
ExecTimeoutMinutes: 5,
@@ -359,6 +378,10 @@ func DefaultConfig() *Config {
TTLSeconds: 300,
},
},
MCP: MCPConfig{
Enabled: false,
Servers: map[string]MCPServerConfig{},
},
},
Heartbeat: HeartbeatConfig{
Enabled: true,
+17
View File
@@ -88,6 +88,23 @@ func ConvertProvidersToModelList(cfg *Config) []ModelConfig {
}, true
},
},
{
providerNames: []string{"litellm"},
protocol: "litellm",
buildConfig: func(p ProvidersConfig) (ModelConfig, bool) {
if p.LiteLLM.APIKey == "" && p.LiteLLM.APIBase == "" {
return ModelConfig{}, false
}
return ModelConfig{
ModelName: "litellm",
Model: "litellm/auto",
APIKey: p.LiteLLM.APIKey,
APIBase: p.LiteLLM.APIBase,
Proxy: p.LiteLLM.Proxy,
RequestTimeout: p.LiteLLM.RequestTimeout,
}, true
},
},
{
providerNames: []string{"openrouter"},
protocol: "openrouter",
+28
View File
@@ -63,6 +63,33 @@ func TestConvertProvidersToModelList_Anthropic(t *testing.T) {
}
}
func TestConvertProvidersToModelList_LiteLLM(t *testing.T) {
cfg := &Config{
Providers: ProvidersConfig{
LiteLLM: ProviderConfig{
APIKey: "litellm-key",
APIBase: "http://localhost:4000/v1",
},
},
}
result := ConvertProvidersToModelList(cfg)
if len(result) != 1 {
t.Fatalf("len(result) = %d, want 1", len(result))
}
if result[0].ModelName != "litellm" {
t.Errorf("ModelName = %q, want %q", result[0].ModelName, "litellm")
}
if result[0].Model != "litellm/auto" {
t.Errorf("Model = %q, want %q", result[0].Model, "litellm/auto")
}
if result[0].APIBase != "http://localhost:4000/v1" {
t.Errorf("APIBase = %q, want %q", result[0].APIBase, "http://localhost:4000/v1")
}
}
func TestConvertProvidersToModelList_Multiple(t *testing.T) {
cfg := &Config{
Providers: ProvidersConfig{
@@ -115,6 +142,7 @@ func TestConvertProvidersToModelList_AllProviders(t *testing.T) {
cfg := &Config{
Providers: ProvidersConfig{
OpenAI: OpenAIProviderConfig{ProviderConfig: ProviderConfig{APIKey: "key1"}},
LiteLLM: ProviderConfig{APIKey: "key-litellm", APIBase: "http://localhost:4000/v1"},
Anthropic: ProviderConfig{APIKey: "key2"},
OpenRouter: ProviderConfig{APIKey: "key3"},
Groq: ProviderConfig{APIKey: "key4"},
+51 -67
View File
@@ -47,79 +47,63 @@ func TestExecuteHeartbeat_Async(t *testing.T) {
}
}
func TestExecuteHeartbeat_Error(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "heartbeat-test-*")
if err != nil {
t.Fatalf("Failed to create temp dir: %v", err)
}
defer os.RemoveAll(tmpDir)
hs := NewHeartbeatService(tmpDir, 30, true)
hs.stopChan = make(chan struct{}) // Enable for testing
hs.SetHandler(func(prompt, channel, chatID string) *tools.ToolResult {
return &tools.ToolResult{
ForLLM: "Heartbeat failed: connection error",
ForUser: "",
Silent: false,
IsError: true,
Async: false,
}
})
// Create HEARTBEAT.md
os.WriteFile(filepath.Join(tmpDir, "HEARTBEAT.md"), []byte("Test task"), 0o644)
hs.executeHeartbeat()
// Check log file for error message
logFile := filepath.Join(tmpDir, "heartbeat.log")
data, err := os.ReadFile(logFile)
if err != nil {
t.Fatalf("Failed to read log file: %v", err)
func TestExecuteHeartbeat_ResultLogging(t *testing.T) {
tests := []struct {
name string
result *tools.ToolResult
wantLog string
}{
{
name: "error result",
result: &tools.ToolResult{
ForLLM: "Heartbeat failed: connection error",
ForUser: "",
Silent: false,
IsError: true,
Async: false,
},
wantLog: "error message",
},
{
name: "silent result",
result: &tools.ToolResult{
ForLLM: "Heartbeat completed successfully",
ForUser: "",
Silent: true,
IsError: false,
Async: false,
},
wantLog: "completion message",
},
}
logContent := string(data)
if logContent == "" {
t.Error("Expected log file to contain error message")
}
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "heartbeat-test-*")
if err != nil {
t.Fatalf("Failed to create temp dir: %v", err)
}
defer os.RemoveAll(tmpDir)
func TestExecuteHeartbeat_Silent(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "heartbeat-test-*")
if err != nil {
t.Fatalf("Failed to create temp dir: %v", err)
}
defer os.RemoveAll(tmpDir)
hs := NewHeartbeatService(tmpDir, 30, true)
hs.stopChan = make(chan struct{}) // Enable for testing
hs := NewHeartbeatService(tmpDir, 30, true)
hs.stopChan = make(chan struct{}) // Enable for testing
hs.SetHandler(func(prompt, channel, chatID string) *tools.ToolResult {
return tt.result
})
hs.SetHandler(func(prompt, channel, chatID string) *tools.ToolResult {
return &tools.ToolResult{
ForLLM: "Heartbeat completed successfully",
ForUser: "",
Silent: true,
IsError: false,
Async: false,
}
})
os.WriteFile(filepath.Join(tmpDir, "HEARTBEAT.md"), []byte("Test task"), 0o644)
hs.executeHeartbeat()
// Create HEARTBEAT.md
os.WriteFile(filepath.Join(tmpDir, "HEARTBEAT.md"), []byte("Test task"), 0o644)
hs.executeHeartbeat()
// Check log file for completion message
logFile := filepath.Join(tmpDir, "heartbeat.log")
data, err := os.ReadFile(logFile)
if err != nil {
t.Fatalf("Failed to read log file: %v", err)
}
logContent := string(data)
if logContent == "" {
t.Error("Expected log file to contain completion message")
logFile := filepath.Join(tmpDir, "heartbeat.log")
data, err := os.ReadFile(logFile)
if err != nil {
t.Fatalf("Failed to read log file: %v", err)
}
if string(data) == "" {
t.Errorf("Expected log file to contain %s", tt.wantLog)
}
})
}
}
+532
View File
@@ -0,0 +1,532 @@
package mcp
import (
"bufio"
"context"
"errors"
"fmt"
"net/http"
"os"
"os/exec"
"path/filepath"
"strings"
"sync"
"sync/atomic"
"github.com/modelcontextprotocol/go-sdk/mcp"
"github.com/sipeed/picoclaw/pkg/config"
"github.com/sipeed/picoclaw/pkg/logger"
)
// headerTransport is an http.RoundTripper that adds custom headers to requests
type headerTransport struct {
base http.RoundTripper
headers map[string]string
}
func (t *headerTransport) RoundTrip(req *http.Request) (*http.Response, error) {
// Clone the request to avoid modifying the original
req = req.Clone(req.Context())
// Add custom headers
for key, value := range t.headers {
req.Header.Set(key, value)
}
// Use the base transport
base := t.base
if base == nil {
base = http.DefaultTransport
}
return base.RoundTrip(req)
}
// loadEnvFile loads environment variables from a file in .env format
// Each line should be in the format: KEY=value
// Lines starting with # are comments
// Empty lines are ignored
func loadEnvFile(path string) (map[string]string, error) {
file, err := os.Open(path)
if err != nil {
return nil, fmt.Errorf("failed to open env file: %w", err)
}
defer file.Close()
envVars := make(map[string]string)
scanner := bufio.NewScanner(file)
lineNum := 0
for scanner.Scan() {
lineNum++
line := strings.TrimSpace(scanner.Text())
// Skip empty lines and comments
if line == "" || strings.HasPrefix(line, "#") {
continue
}
// Parse KEY=value
parts := strings.SplitN(line, "=", 2)
if len(parts) != 2 {
return nil, fmt.Errorf("invalid format at line %d: %s", lineNum, line)
}
key := strings.TrimSpace(parts[0])
value := strings.TrimSpace(parts[1])
if key == "" {
return nil, fmt.Errorf("invalid format at line %d: empty key", lineNum)
}
// Remove surrounding quotes if present
if len(value) >= 2 {
if (value[0] == '"' && value[len(value)-1] == '"') ||
(value[0] == '\'' && value[len(value)-1] == '\'') {
value = value[1 : len(value)-1]
}
}
envVars[key] = value
}
if err := scanner.Err(); err != nil {
return nil, fmt.Errorf("error reading env file: %w", err)
}
return envVars, nil
}
// ServerConnection represents a connection to an MCP server
type ServerConnection struct {
Name string
Client *mcp.Client
Session *mcp.ClientSession
Tools []*mcp.Tool
}
// Manager manages multiple MCP server connections
type Manager struct {
servers map[string]*ServerConnection
mu sync.RWMutex
closed atomic.Bool // changed from bool to atomic.Bool to avoid TOCTOU race
wg sync.WaitGroup // tracks in-flight CallTool calls
}
// NewManager creates a new MCP manager
func NewManager() *Manager {
return &Manager{
servers: make(map[string]*ServerConnection),
}
}
// LoadFromConfig loads MCP servers from configuration
func (m *Manager) LoadFromConfig(ctx context.Context, cfg *config.Config) error {
return m.LoadFromMCPConfig(ctx, cfg.Tools.MCP, cfg.WorkspacePath())
}
// LoadFromMCPConfig loads MCP servers from MCP configuration and workspace path.
// This is the minimal dependency version that doesn't require the full Config object.
func (m *Manager) LoadFromMCPConfig(
ctx context.Context,
mcpCfg config.MCPConfig,
workspacePath string,
) error {
if !mcpCfg.Enabled {
logger.InfoCF("mcp", "MCP integration is disabled", nil)
return nil
}
if len(mcpCfg.Servers) == 0 {
logger.InfoCF("mcp", "No MCP servers configured", nil)
return nil
}
logger.InfoCF("mcp", "Initializing MCP servers",
map[string]any{
"count": len(mcpCfg.Servers),
})
var wg sync.WaitGroup
errs := make(chan error, len(mcpCfg.Servers))
enabledCount := 0
for name, serverCfg := range mcpCfg.Servers {
if !serverCfg.Enabled {
logger.DebugCF("mcp", "Skipping disabled server",
map[string]any{
"server": name,
})
continue
}
enabledCount++
wg.Add(1)
go func(name string, serverCfg config.MCPServerConfig, workspace string) {
defer wg.Done()
// Resolve relative envFile paths relative to workspace
if serverCfg.EnvFile != "" && !filepath.IsAbs(serverCfg.EnvFile) {
if workspace == "" {
err := fmt.Errorf(
"workspace path is empty while resolving relative envFile %q for server %s",
serverCfg.EnvFile,
name,
)
logger.ErrorCF("mcp", "Invalid MCP server configuration",
map[string]any{
"server": name,
"env_file": serverCfg.EnvFile,
"error": err.Error(),
})
errs <- err
return
}
serverCfg.EnvFile = filepath.Join(workspace, serverCfg.EnvFile)
}
if err := m.ConnectServer(ctx, name, serverCfg); err != nil {
logger.ErrorCF("mcp", "Failed to connect to MCP server",
map[string]any{
"server": name,
"error": err.Error(),
})
errs <- fmt.Errorf("failed to connect to server %s: %w", name, err)
}
}(name, serverCfg, workspacePath)
}
wg.Wait()
close(errs)
// Collect errors
var allErrors []error
for err := range errs {
allErrors = append(allErrors, err)
}
connectedCount := len(m.GetServers())
// If all enabled servers failed to connect, return aggregated error
if enabledCount > 0 && connectedCount == 0 {
logger.ErrorCF("mcp", "All MCP servers failed to connect",
map[string]any{
"failed": len(allErrors),
"total": enabledCount,
})
return errors.Join(allErrors...)
}
if len(allErrors) > 0 {
logger.WarnCF("mcp", "Some MCP servers failed to connect",
map[string]any{
"failed": len(allErrors),
"connected": connectedCount,
"total": enabledCount,
})
// Don't fail completely if some servers successfully connected
}
logger.InfoCF("mcp", "MCP server initialization complete",
map[string]any{
"connected": connectedCount,
"total": enabledCount,
})
return nil
}
// ConnectServer connects to a single MCP server
func (m *Manager) ConnectServer(
ctx context.Context,
name string,
cfg config.MCPServerConfig,
) error {
logger.InfoCF("mcp", "Connecting to MCP server",
map[string]any{
"server": name,
"command": cfg.Command,
"args_count": len(cfg.Args),
})
// Create client
client := mcp.NewClient(&mcp.Implementation{
Name: "picoclaw",
Version: "1.0.0",
}, nil)
// Create transport based on configuration
// Auto-detect transport type if not explicitly specified
var transport mcp.Transport
transportType := cfg.Type
// Auto-detect: if URL is provided, use SSE; if command is provided, use stdio
if transportType == "" {
if cfg.URL != "" {
transportType = "sse"
} else if cfg.Command != "" {
transportType = "stdio"
} else {
return fmt.Errorf("either URL or command must be provided")
}
}
switch transportType {
case "sse", "http":
if cfg.URL == "" {
return fmt.Errorf("URL is required for SSE/HTTP transport")
}
logger.DebugCF("mcp", "Using SSE/HTTP transport",
map[string]any{
"server": name,
"url": cfg.URL,
})
sseTransport := &mcp.StreamableClientTransport{
Endpoint: cfg.URL,
}
// Add custom headers if provided
if len(cfg.Headers) > 0 {
// Create a custom HTTP client with header-injecting transport
sseTransport.HTTPClient = &http.Client{
Transport: &headerTransport{
base: http.DefaultTransport,
headers: cfg.Headers,
},
}
logger.DebugCF("mcp", "Added custom HTTP headers",
map[string]any{
"server": name,
"header_count": len(cfg.Headers),
})
}
transport = sseTransport
case "stdio":
if cfg.Command == "" {
return fmt.Errorf("command is required for stdio transport")
}
logger.DebugCF("mcp", "Using stdio transport",
map[string]any{
"server": name,
"command": cfg.Command,
})
// Create command with context
cmd := exec.CommandContext(ctx, cfg.Command, cfg.Args...)
// Build environment variables with proper override semantics
// Use a map to ensure config variables override file variables
envMap := make(map[string]string)
// Start with parent process environment
for _, e := range cmd.Environ() {
if idx := strings.Index(e, "="); idx > 0 {
envMap[e[:idx]] = e[idx+1:]
}
}
// Load environment variables from file if specified
if cfg.EnvFile != "" {
envVars, err := loadEnvFile(cfg.EnvFile)
if err != nil {
return fmt.Errorf("failed to load env file %s: %w", cfg.EnvFile, err)
}
for k, v := range envVars {
envMap[k] = v
}
logger.DebugCF("mcp", "Loaded environment variables from file",
map[string]any{
"server": name,
"envFile": cfg.EnvFile,
"var_count": len(envVars),
})
}
// Environment variables from config override those from file
for k, v := range cfg.Env {
envMap[k] = v
}
// Convert map to slice
env := make([]string, 0, len(envMap))
for k, v := range envMap {
env = append(env, fmt.Sprintf("%s=%s", k, v))
}
cmd.Env = env
transport = &mcp.CommandTransport{Command: cmd}
default:
return fmt.Errorf(
"unsupported transport type: %s (supported: stdio, sse, http)",
transportType,
)
}
// Connect to server
session, err := client.Connect(ctx, transport, nil)
if err != nil {
return fmt.Errorf("failed to connect: %w", err)
}
// Get server info
initResult := session.InitializeResult()
logger.InfoCF("mcp", "Connected to MCP server",
map[string]any{
"server": name,
"serverName": initResult.ServerInfo.Name,
"serverVersion": initResult.ServerInfo.Version,
"protocol": initResult.ProtocolVersion,
})
// List available tools if supported
var tools []*mcp.Tool
if initResult.Capabilities.Tools != nil {
for tool, err := range session.Tools(ctx, nil) {
if err != nil {
logger.WarnCF("mcp", "Error listing tool",
map[string]any{
"server": name,
"error": err.Error(),
})
continue
}
tools = append(tools, tool)
}
logger.InfoCF("mcp", "Listed tools from MCP server",
map[string]any{
"server": name,
"toolCount": len(tools),
})
}
// Store connection
m.mu.Lock()
m.servers[name] = &ServerConnection{
Name: name,
Client: client,
Session: session,
Tools: tools,
}
m.mu.Unlock()
return nil
}
// GetServers returns all connected servers
func (m *Manager) GetServers() map[string]*ServerConnection {
m.mu.RLock()
defer m.mu.RUnlock()
result := make(map[string]*ServerConnection, len(m.servers))
for k, v := range m.servers {
result[k] = v
}
return result
}
// GetServer returns a specific server connection
func (m *Manager) GetServer(name string) (*ServerConnection, bool) {
m.mu.RLock()
defer m.mu.RUnlock()
conn, ok := m.servers[name]
return conn, ok
}
// CallTool calls a tool on a specific server
func (m *Manager) CallTool(
ctx context.Context,
serverName, toolName string,
arguments map[string]any,
) (*mcp.CallToolResult, error) {
// Check if closed before acquiring lock (fast path)
if m.closed.Load() {
return nil, fmt.Errorf("manager is closed")
}
m.mu.RLock()
// Double-check after acquiring lock to prevent TOCTOU race
if m.closed.Load() {
m.mu.RUnlock()
return nil, fmt.Errorf("manager is closed")
}
conn, ok := m.servers[serverName]
if ok {
m.wg.Add(1) // Add to WaitGroup while holding the lock
}
m.mu.RUnlock()
if !ok {
return nil, fmt.Errorf("server %s not found", serverName)
}
defer m.wg.Done()
params := &mcp.CallToolParams{
Name: toolName,
Arguments: arguments,
}
result, err := conn.Session.CallTool(ctx, params)
if err != nil {
return nil, fmt.Errorf("failed to call tool: %w", err)
}
return result, nil
}
// Close closes all server connections
func (m *Manager) Close() error {
// Use Swap to atomically set closed=true and get the previous value
// This prevents TOCTOU race with CallTool's closed check
if m.closed.Swap(true) {
return nil // already closed
}
// Wait for all in-flight CallTool calls to finish before closing sessions
// After closed=true is set, no new CallTool can start (they check closed first)
m.wg.Wait()
m.mu.Lock()
defer m.mu.Unlock()
logger.InfoCF("mcp", "Closing all MCP server connections",
map[string]any{
"count": len(m.servers),
})
var errs []error
for name, conn := range m.servers {
if err := conn.Session.Close(); err != nil {
logger.ErrorCF("mcp", "Failed to close server connection",
map[string]any{
"server": name,
"error": err.Error(),
})
errs = append(errs, fmt.Errorf("server %s: %w", name, err))
}
}
m.servers = make(map[string]*ServerConnection)
if len(errs) > 0 {
return fmt.Errorf("failed to close %d server(s): %w", len(errs), errors.Join(errs...))
}
return nil
}
// GetAllTools returns all tools from all connected servers
func (m *Manager) GetAllTools() map[string][]*mcp.Tool {
m.mu.RLock()
defer m.mu.RUnlock()
result := make(map[string][]*mcp.Tool)
for name, conn := range m.servers {
if len(conn.Tools) > 0 {
result[name] = conn.Tools
}
}
return result
}
+298
View File
@@ -0,0 +1,298 @@
package mcp
import (
"context"
"os"
"path/filepath"
"strings"
"testing"
sdkmcp "github.com/modelcontextprotocol/go-sdk/mcp"
"github.com/sipeed/picoclaw/pkg/config"
)
func TestLoadEnvFile(t *testing.T) {
tests := []struct {
name string
content string
expected map[string]string
expectErr bool
}{
{
name: "basic env file",
content: `API_KEY=secret123
DATABASE_URL=postgres://localhost/db
PORT=8080`,
expected: map[string]string{
"API_KEY": "secret123",
"DATABASE_URL": "postgres://localhost/db",
"PORT": "8080",
},
expectErr: false,
},
{
name: "with comments and empty lines",
content: `# This is a comment
API_KEY=secret123
# Another comment
DATABASE_URL=postgres://localhost/db
PORT=8080`,
expected: map[string]string{
"API_KEY": "secret123",
"DATABASE_URL": "postgres://localhost/db",
"PORT": "8080",
},
expectErr: false,
},
{
name: "with quoted values",
content: `API_KEY="secret with spaces"
NAME='single quoted'
PLAIN=no-quotes`,
expected: map[string]string{
"API_KEY": "secret with spaces",
"NAME": "single quoted",
"PLAIN": "no-quotes",
},
expectErr: false,
},
{
name: "with spaces around equals",
content: `API_KEY = secret123
DATABASE_URL= postgres://localhost/db
PORT =8080`,
expected: map[string]string{
"API_KEY": "secret123",
"DATABASE_URL": "postgres://localhost/db",
"PORT": "8080",
},
expectErr: false,
},
{
name: "invalid format - no equals",
content: `INVALID_LINE`,
expectErr: true,
},
{
name: "empty file",
content: ``,
expected: map[string]string{},
expectErr: false,
},
{
name: "only comments",
content: `# Comment 1
# Comment 2`,
expected: map[string]string{},
expectErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
tmpDir := t.TempDir()
envFile := filepath.Join(tmpDir, ".env")
if err := os.WriteFile(envFile, []byte(tt.content), 0o644); err != nil {
t.Fatalf("Failed to create test file: %v", err)
}
result, err := loadEnvFile(envFile)
if tt.expectErr {
if err == nil {
t.Errorf("Expected error but got none")
}
return
}
if err != nil {
t.Errorf("Unexpected error: %v", err)
return
}
if len(result) != len(tt.expected) {
t.Errorf("Expected %d variables, got %d", len(tt.expected), len(result))
}
for key, expectedValue := range tt.expected {
if actualValue, ok := result[key]; !ok {
t.Errorf("Expected key %s not found", key)
} else if actualValue != expectedValue {
t.Errorf("For key %s: expected %q, got %q", key, expectedValue, actualValue)
}
}
})
}
}
func TestLoadEnvFileNotFound(t *testing.T) {
_, err := loadEnvFile("/nonexistent/file.env")
if err == nil {
t.Error("Expected error for nonexistent file")
}
}
func TestEnvFilePriority(t *testing.T) {
// Create a temporary .env file
tmpDir := t.TempDir()
envFile := filepath.Join(tmpDir, ".env")
envContent := `API_KEY=from_file
DATABASE_URL=from_file
SHARED_VAR=from_file`
if err := os.WriteFile(envFile, []byte(envContent), 0o644); err != nil {
t.Fatalf("Failed to create .env file: %v", err)
}
// Load envFile
envVars, err := loadEnvFile(envFile)
if err != nil {
t.Fatalf("Failed to load env file: %v", err)
}
// Verify envFile variables
if envVars["API_KEY"] != "from_file" {
t.Errorf("Expected API_KEY=from_file, got %s", envVars["API_KEY"])
}
// Simulate config.Env overriding envFile
configEnv := map[string]string{
"SHARED_VAR": "from_config",
"NEW_VAR": "from_config",
}
// Merge: envFile first, then config overrides
merged := make(map[string]string)
for k, v := range envVars {
merged[k] = v
}
for k, v := range configEnv {
merged[k] = v
}
// Verify priority: config.Env should override envFile
if merged["SHARED_VAR"] != "from_config" {
t.Errorf(
"Expected SHARED_VAR=from_config (config should override file), got %s",
merged["SHARED_VAR"],
)
}
if merged["API_KEY"] != "from_file" {
t.Errorf("Expected API_KEY=from_file, got %s", merged["API_KEY"])
}
if merged["NEW_VAR"] != "from_config" {
t.Errorf("Expected NEW_VAR=from_config, got %s", merged["NEW_VAR"])
}
}
func TestLoadFromMCPConfig_EmptyWorkspaceWithRelativeEnvFile(t *testing.T) {
mgr := NewManager()
mcpCfg := config.MCPConfig{
Enabled: true,
Servers: map[string]config.MCPServerConfig{
"test-server": {
Enabled: true,
Command: "echo",
Args: []string{"ok"},
EnvFile: ".env",
},
},
}
err := mgr.LoadFromMCPConfig(context.Background(), mcpCfg, "")
if err == nil {
t.Fatal("expected error for relative env_file with empty workspace path, got nil")
}
if !strings.Contains(err.Error(), "workspace path is empty") {
t.Fatalf("expected workspace path validation error, got: %v", err)
}
}
func TestNewManager_InitialState(t *testing.T) {
mgr := NewManager()
if mgr == nil {
t.Fatal("expected manager instance, got nil")
}
if len(mgr.GetServers()) != 0 {
t.Fatalf("expected no servers on new manager, got %d", len(mgr.GetServers()))
}
}
func TestLoadFromMCPConfig_DisabledOrEmptyServers(t *testing.T) {
mgr := NewManager()
err := mgr.LoadFromMCPConfig(context.Background(), config.MCPConfig{Enabled: false}, "/tmp")
if err != nil {
t.Fatalf("expected nil error when MCP disabled, got: %v", err)
}
err = mgr.LoadFromMCPConfig(context.Background(), config.MCPConfig{Enabled: true}, "/tmp")
if err != nil {
t.Fatalf("expected nil error when no servers configured, got: %v", err)
}
}
func TestGetServers_ReturnsCopy(t *testing.T) {
mgr := NewManager()
mgr.servers["s1"] = &ServerConnection{Name: "s1"}
servers := mgr.GetServers()
delete(servers, "s1")
if _, ok := mgr.GetServer("s1"); !ok {
t.Fatal("expected internal manager state to remain unchanged")
}
}
func TestGetAllTools_FiltersEmptyTools(t *testing.T) {
mgr := NewManager()
mgr.servers["empty"] = &ServerConnection{Name: "empty", Tools: nil}
mgr.servers["with-tools"] = &ServerConnection{Name: "with-tools", Tools: []*sdkmcp.Tool{{}}}
all := mgr.GetAllTools()
if _, ok := all["empty"]; ok {
t.Fatal("expected server without tools to be excluded")
}
if _, ok := all["with-tools"]; !ok {
t.Fatal("expected server with tools to be included")
}
}
func TestCallTool_ErrorsForClosedOrMissingServer(t *testing.T) {
t.Run("manager closed", func(t *testing.T) {
mgr := NewManager()
mgr.closed.Store(true)
_, err := mgr.CallTool(context.Background(), "s1", "tool", nil)
if err == nil || !strings.Contains(err.Error(), "manager is closed") {
t.Fatalf("expected manager closed error, got: %v", err)
}
})
t.Run("server missing", func(t *testing.T) {
mgr := NewManager()
_, err := mgr.CallTool(context.Background(), "missing", "tool", nil)
if err == nil || !strings.Contains(err.Error(), "not found") {
t.Fatalf("expected server not found error, got: %v", err)
}
})
}
func TestClose_IdempotentOnEmptyManager(t *testing.T) {
mgr := NewManager()
if err := mgr.Close(); err != nil {
t.Fatalf("first close should succeed, got: %v", err)
}
if err := mgr.Close(); err != nil {
t.Fatalf("second close should be idempotent, got: %v", err)
}
}
+460
View File
@@ -0,0 +1,460 @@
package memory
import (
"bufio"
"bytes"
"context"
"encoding/json"
"fmt"
"hash/fnv"
"log"
"os"
"path/filepath"
"strings"
"sync"
"time"
"github.com/sipeed/picoclaw/pkg/fileutil"
"github.com/sipeed/picoclaw/pkg/providers"
)
const (
// numLockShards is the fixed number of mutexes used to serialize
// per-session access. Using a sharded array instead of a map keeps
// memory bounded regardless of how many sessions are created over
// the lifetime of the process — important for a long-running daemon.
numLockShards = 64
// maxLineSize is the maximum size of a single JSON line in a .jsonl
// file. Tool results (read_file, web search, etc.) can be large, so
// we set a generous limit. The scanner starts at 64 KB and grows
// only as needed up to this cap.
maxLineSize = 10 * 1024 * 1024 // 10 MB
)
// sessionMeta holds per-session metadata stored in a .meta.json file.
type sessionMeta struct {
Key string `json:"key"`
Summary string `json:"summary"`
Skip int `json:"skip"`
Count int `json:"count"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
}
// JSONLStore implements Store using append-only JSONL files.
//
// Each session is stored as two files:
//
// {sanitized_key}.jsonl — one JSON-encoded message per line, append-only
// {sanitized_key}.meta.json — session metadata (summary, logical truncation offset)
//
// Messages are never physically deleted from the JSONL file. Instead,
// TruncateHistory records a "skip" offset in the metadata file and
// GetHistory ignores lines before that offset. This keeps all writes
// append-only, which is both fast and crash-safe.
type JSONLStore struct {
dir string
locks [numLockShards]sync.Mutex
}
// NewJSONLStore creates a new JSONL-backed store rooted at dir.
func NewJSONLStore(dir string) (*JSONLStore, error) {
err := os.MkdirAll(dir, 0o755)
if err != nil {
return nil, fmt.Errorf("memory: create directory: %w", err)
}
return &JSONLStore{dir: dir}, nil
}
// sessionLock returns a mutex for the given session key.
// Keys are mapped to a fixed pool of shards via FNV hash, so
// memory usage is O(1) regardless of total session count.
func (s *JSONLStore) sessionLock(key string) *sync.Mutex {
h := fnv.New32a()
h.Write([]byte(key))
return &s.locks[h.Sum32()%numLockShards]
}
func (s *JSONLStore) jsonlPath(key string) string {
return filepath.Join(s.dir, sanitizeKey(key)+".jsonl")
}
func (s *JSONLStore) metaPath(key string) string {
return filepath.Join(s.dir, sanitizeKey(key)+".meta.json")
}
// sanitizeKey converts a session key to a safe filename component.
// Mirrors pkg/session.sanitizeFilename so that migration paths match.
//
// Note: this is a lossy mapping — "telegram:123" and "telegram_123"
// both produce the same filename. This is an intentional tradeoff:
// keys with colons (e.g. from channels) are by far the common case,
// and a bidirectional encoding (like URL-encoding) would complicate
// file listings and debugging.
func sanitizeKey(key string) string {
return strings.ReplaceAll(key, ":", "_")
}
// readMeta loads the metadata file for a session.
// Returns a zero-value sessionMeta if the file does not exist.
func (s *JSONLStore) readMeta(key string) (sessionMeta, error) {
data, err := os.ReadFile(s.metaPath(key))
if os.IsNotExist(err) {
return sessionMeta{Key: key}, nil
}
if err != nil {
return sessionMeta{}, fmt.Errorf("memory: read meta: %w", err)
}
var meta sessionMeta
err = json.Unmarshal(data, &meta)
if err != nil {
return sessionMeta{}, fmt.Errorf("memory: decode meta: %w", err)
}
return meta, nil
}
// writeMeta atomically writes the metadata file using the project's
// standard WriteFileAtomic (temp + fsync + rename).
func (s *JSONLStore) writeMeta(key string, meta sessionMeta) error {
data, err := json.MarshalIndent(meta, "", " ")
if err != nil {
return fmt.Errorf("memory: encode meta: %w", err)
}
return fileutil.WriteFileAtomic(s.metaPath(key), data, 0o644)
}
// readMessages reads valid JSON lines from a .jsonl file, skipping
// the first `skip` lines without unmarshaling them. This avoids the
// cost of json.Unmarshal on logically truncated messages.
// Malformed trailing lines (e.g. from a crash) are silently skipped.
func readMessages(path string, skip int) ([]providers.Message, error) {
f, err := os.Open(path)
if os.IsNotExist(err) {
return []providers.Message{}, nil
}
if err != nil {
return nil, fmt.Errorf("memory: open jsonl: %w", err)
}
defer f.Close()
var msgs []providers.Message
scanner := bufio.NewScanner(f)
// Allow large lines for tool results (read_file, web search, etc.).
scanner.Buffer(make([]byte, 0, 64*1024), maxLineSize)
lineNum := 0
for scanner.Scan() {
line := scanner.Bytes()
if len(line) == 0 {
continue
}
lineNum++
if lineNum <= skip {
continue
}
var msg providers.Message
if err := json.Unmarshal(line, &msg); err != nil {
// Corrupt line — likely a partial write from a crash.
// Log so operators know data was skipped, but don't
// fail the entire read; this is the standard JSONL
// recovery pattern.
log.Printf("memory: skipping corrupt line %d in %s: %v",
lineNum, filepath.Base(path), err)
continue
}
msgs = append(msgs, msg)
}
if scanner.Err() != nil {
return nil, fmt.Errorf("memory: scan jsonl: %w", scanner.Err())
}
if msgs == nil {
msgs = []providers.Message{}
}
return msgs, nil
}
// countLines counts the total number of non-empty lines in a .jsonl file.
// Used by TruncateHistory to reconcile a stale meta.Count without
// the overhead of unmarshaling every message.
func countLines(path string) (int, error) {
f, err := os.Open(path)
if os.IsNotExist(err) {
return 0, nil
}
if err != nil {
return 0, fmt.Errorf("memory: open jsonl: %w", err)
}
defer f.Close()
n := 0
scanner := bufio.NewScanner(f)
scanner.Buffer(make([]byte, 0, 64*1024), maxLineSize)
for scanner.Scan() {
if len(scanner.Bytes()) > 0 {
n++
}
}
return n, scanner.Err()
}
func (s *JSONLStore) AddMessage(
_ context.Context, sessionKey, role, content string,
) error {
return s.addMsg(sessionKey, providers.Message{
Role: role,
Content: content,
})
}
func (s *JSONLStore) AddFullMessage(
_ context.Context, sessionKey string, msg providers.Message,
) error {
return s.addMsg(sessionKey, msg)
}
// addMsg is the shared implementation for AddMessage and AddFullMessage.
func (s *JSONLStore) addMsg(sessionKey string, msg providers.Message) error {
l := s.sessionLock(sessionKey)
l.Lock()
defer l.Unlock()
// Append the message as a single JSON line.
line, err := json.Marshal(msg)
if err != nil {
return fmt.Errorf("memory: marshal message: %w", err)
}
line = append(line, '\n')
f, err := os.OpenFile(
s.jsonlPath(sessionKey),
os.O_CREATE|os.O_WRONLY|os.O_APPEND,
0o644,
)
if err != nil {
return fmt.Errorf("memory: open jsonl for append: %w", err)
}
_, writeErr := f.Write(line)
if writeErr != nil {
f.Close()
return fmt.Errorf("memory: append message: %w", writeErr)
}
// Flush to physical storage before closing. This matches the
// durability guarantee of writeMeta and rewriteJSONL (which use
// WriteFileAtomic with fsync). Without Sync, a power loss could
// leave the append in the kernel page cache only — lost on reboot.
if syncErr := f.Sync(); syncErr != nil {
f.Close()
return fmt.Errorf("memory: sync jsonl: %w", syncErr)
}
if closeErr := f.Close(); closeErr != nil {
return fmt.Errorf("memory: close jsonl: %w", closeErr)
}
// Update metadata.
meta, err := s.readMeta(sessionKey)
if err != nil {
return err
}
now := time.Now()
if meta.Count == 0 && meta.CreatedAt.IsZero() {
meta.CreatedAt = now
}
meta.Count++
meta.UpdatedAt = now
return s.writeMeta(sessionKey, meta)
}
func (s *JSONLStore) GetHistory(
_ context.Context, sessionKey string,
) ([]providers.Message, error) {
l := s.sessionLock(sessionKey)
l.Lock()
defer l.Unlock()
meta, err := s.readMeta(sessionKey)
if err != nil {
return nil, err
}
// Pass meta.Skip so readMessages skips those lines without
// unmarshaling them — avoids wasted CPU on truncated messages.
msgs, err := readMessages(s.jsonlPath(sessionKey), meta.Skip)
if err != nil {
return nil, err
}
return msgs, nil
}
func (s *JSONLStore) GetSummary(
_ context.Context, sessionKey string,
) (string, error) {
l := s.sessionLock(sessionKey)
l.Lock()
defer l.Unlock()
meta, err := s.readMeta(sessionKey)
if err != nil {
return "", err
}
return meta.Summary, nil
}
func (s *JSONLStore) SetSummary(
_ context.Context, sessionKey, summary string,
) error {
l := s.sessionLock(sessionKey)
l.Lock()
defer l.Unlock()
meta, err := s.readMeta(sessionKey)
if err != nil {
return err
}
now := time.Now()
if meta.CreatedAt.IsZero() {
meta.CreatedAt = now
}
meta.Summary = summary
meta.UpdatedAt = now
return s.writeMeta(sessionKey, meta)
}
func (s *JSONLStore) TruncateHistory(
_ context.Context, sessionKey string, keepLast int,
) error {
l := s.sessionLock(sessionKey)
l.Lock()
defer l.Unlock()
meta, err := s.readMeta(sessionKey)
if err != nil {
return err
}
// Always reconcile meta.Count with the actual line count on disk.
// A crash between the JSONL append and the meta update in addMsg
// leaves meta.Count stale (e.g. file has 101 lines but meta says
// 100). Counting lines is cheap — no unmarshal, just a scan — and
// TruncateHistory is not a hot path, so always re-count.
n, countErr := countLines(s.jsonlPath(sessionKey))
if countErr != nil {
return countErr
}
meta.Count = n
if keepLast <= 0 {
meta.Skip = meta.Count
} else {
effective := meta.Count - meta.Skip
if keepLast < effective {
meta.Skip = meta.Count - keepLast
}
}
meta.UpdatedAt = time.Now()
return s.writeMeta(sessionKey, meta)
}
func (s *JSONLStore) SetHistory(
_ context.Context,
sessionKey string,
history []providers.Message,
) error {
l := s.sessionLock(sessionKey)
l.Lock()
defer l.Unlock()
meta, err := s.readMeta(sessionKey)
if err != nil {
return err
}
now := time.Now()
if meta.CreatedAt.IsZero() {
meta.CreatedAt = now
}
meta.Skip = 0
meta.Count = len(history)
meta.UpdatedAt = now
// Write meta BEFORE rewriting the JSONL file. If we crash between
// the two writes, meta has Skip=0 and the old file is still intact,
// so GetHistory reads from line 1 — returning "too many" messages
// rather than losing data. The next SetHistory call corrects this.
err = s.writeMeta(sessionKey, meta)
if err != nil {
return err
}
return s.rewriteJSONL(sessionKey, history)
}
// Compact physically rewrites the JSONL file, dropping all logically
// skipped lines. This reclaims disk space that accumulates after
// repeated TruncateHistory calls.
//
// It is safe to call at any time; if there is nothing to compact
// (skip == 0) the method returns immediately.
func (s *JSONLStore) Compact(
_ context.Context, sessionKey string,
) error {
l := s.sessionLock(sessionKey)
l.Lock()
defer l.Unlock()
meta, err := s.readMeta(sessionKey)
if err != nil {
return err
}
if meta.Skip == 0 {
return nil
}
// Read only the active messages, skipping truncated lines
// without unmarshaling them.
active, err := readMessages(s.jsonlPath(sessionKey), meta.Skip)
if err != nil {
return err
}
// Write meta BEFORE rewriting the JSONL file. If the process
// crashes between the two writes, meta has Skip=0 and the old
// (uncompacted) file is still intact, so GetHistory reads from
// line 1 — returning previously-truncated messages rather than
// losing data. The next Compact or TruncateHistory corrects this.
meta.Skip = 0
meta.Count = len(active)
meta.UpdatedAt = time.Now()
err = s.writeMeta(sessionKey, meta)
if err != nil {
return err
}
return s.rewriteJSONL(sessionKey, active)
}
// rewriteJSONL atomically replaces the JSONL file with the given messages
// using the project's standard WriteFileAtomic (temp + fsync + rename).
func (s *JSONLStore) rewriteJSONL(
sessionKey string, msgs []providers.Message,
) error {
var buf bytes.Buffer
for i, msg := range msgs {
line, err := json.Marshal(msg)
if err != nil {
return fmt.Errorf("memory: marshal message %d: %w", i, err)
}
buf.Write(line)
buf.WriteByte('\n')
}
return fileutil.WriteFileAtomic(s.jsonlPath(sessionKey), buf.Bytes(), 0o644)
}
func (s *JSONLStore) Close() error {
return nil
}
+835
View File
@@ -0,0 +1,835 @@
package memory
import (
"context"
"os"
"path/filepath"
"sync"
"testing"
"github.com/sipeed/picoclaw/pkg/providers"
)
func newTestStore(t *testing.T) *JSONLStore {
t.Helper()
store, err := NewJSONLStore(t.TempDir())
if err != nil {
t.Fatalf("NewJSONLStore: %v", err)
}
return store
}
func TestNewJSONLStore_CreatesDirectory(t *testing.T) {
dir := filepath.Join(t.TempDir(), "nested", "sessions")
store, err := NewJSONLStore(dir)
if err != nil {
t.Fatalf("NewJSONLStore: %v", err)
}
defer store.Close()
info, err := os.Stat(dir)
if err != nil {
t.Fatalf("Stat: %v", err)
}
if !info.IsDir() {
t.Errorf("expected directory, got file")
}
}
func TestAddMessage_BasicRoundtrip(t *testing.T) {
store := newTestStore(t)
ctx := context.Background()
err := store.AddMessage(ctx, "s1", "user", "hello")
if err != nil {
t.Fatalf("AddMessage: %v", err)
}
err = store.AddMessage(ctx, "s1", "assistant", "hi there")
if err != nil {
t.Fatalf("AddMessage: %v", err)
}
history, err := store.GetHistory(ctx, "s1")
if err != nil {
t.Fatalf("GetHistory: %v", err)
}
if len(history) != 2 {
t.Fatalf("expected 2 messages, got %d", len(history))
}
if history[0].Role != "user" || history[0].Content != "hello" {
t.Errorf("msg[0] = %+v", history[0])
}
if history[1].Role != "assistant" || history[1].Content != "hi there" {
t.Errorf("msg[1] = %+v", history[1])
}
}
func TestAddMessage_AutoCreatesSession(t *testing.T) {
store := newTestStore(t)
ctx := context.Background()
// Adding a message to a non-existent session should work.
err := store.AddMessage(ctx, "new-session", "user", "first message")
if err != nil {
t.Fatalf("AddMessage: %v", err)
}
history, err := store.GetHistory(ctx, "new-session")
if err != nil {
t.Fatalf("GetHistory: %v", err)
}
if len(history) != 1 {
t.Fatalf("expected 1 message, got %d", len(history))
}
}
func TestAddFullMessage_WithToolCalls(t *testing.T) {
store := newTestStore(t)
ctx := context.Background()
msg := providers.Message{
Role: "assistant",
Content: "Let me search that.",
ToolCalls: []providers.ToolCall{
{
ID: "call_abc",
Type: "function",
Function: &providers.FunctionCall{
Name: "web_search",
Arguments: `{"q":"golang jsonl"}`,
},
},
},
}
err := store.AddFullMessage(ctx, "tc", msg)
if err != nil {
t.Fatalf("AddFullMessage: %v", err)
}
history, err := store.GetHistory(ctx, "tc")
if err != nil {
t.Fatalf("GetHistory: %v", err)
}
if len(history) != 1 {
t.Fatalf("expected 1, got %d", len(history))
}
if len(history[0].ToolCalls) != 1 {
t.Fatalf("expected 1 tool call, got %d", len(history[0].ToolCalls))
}
tc := history[0].ToolCalls[0]
if tc.ID != "call_abc" {
t.Errorf("tool call ID = %q", tc.ID)
}
if tc.Function == nil || tc.Function.Name != "web_search" {
t.Errorf("tool call function = %+v", tc.Function)
}
}
func TestAddFullMessage_ToolCallID(t *testing.T) {
store := newTestStore(t)
ctx := context.Background()
msg := providers.Message{
Role: "tool",
Content: "search results here",
ToolCallID: "call_abc",
}
err := store.AddFullMessage(ctx, "tr", msg)
if err != nil {
t.Fatalf("AddFullMessage: %v", err)
}
history, err := store.GetHistory(ctx, "tr")
if err != nil {
t.Fatalf("GetHistory: %v", err)
}
if len(history) != 1 {
t.Fatalf("expected 1, got %d", len(history))
}
if history[0].ToolCallID != "call_abc" {
t.Errorf("ToolCallID = %q", history[0].ToolCallID)
}
}
func TestGetHistory_EmptySession(t *testing.T) {
store := newTestStore(t)
ctx := context.Background()
history, err := store.GetHistory(ctx, "nonexistent")
if err != nil {
t.Fatalf("GetHistory: %v", err)
}
if history == nil {
t.Fatal("expected non-nil empty slice")
}
if len(history) != 0 {
t.Errorf("expected 0 messages, got %d", len(history))
}
}
func TestGetHistory_Ordering(t *testing.T) {
store := newTestStore(t)
ctx := context.Background()
for i := 0; i < 5; i++ {
err := store.AddMessage(
ctx, "order",
"user",
string(rune('a'+i)),
)
if err != nil {
t.Fatalf("AddMessage(%d): %v", i, err)
}
}
history, err := store.GetHistory(ctx, "order")
if err != nil {
t.Fatalf("GetHistory: %v", err)
}
if len(history) != 5 {
t.Fatalf("expected 5, got %d", len(history))
}
for i := 0; i < 5; i++ {
expected := string(rune('a' + i))
if history[i].Content != expected {
t.Errorf("msg[%d].Content = %q, want %q", i, history[i].Content, expected)
}
}
}
func TestSetSummary_GetSummary(t *testing.T) {
store := newTestStore(t)
ctx := context.Background()
// No summary yet.
summary, err := store.GetSummary(ctx, "s1")
if err != nil {
t.Fatalf("GetSummary: %v", err)
}
if summary != "" {
t.Errorf("expected empty, got %q", summary)
}
// Set a summary.
err = store.SetSummary(ctx, "s1", "talked about Go")
if err != nil {
t.Fatalf("SetSummary: %v", err)
}
summary, err = store.GetSummary(ctx, "s1")
if err != nil {
t.Fatalf("GetSummary: %v", err)
}
if summary != "talked about Go" {
t.Errorf("summary = %q", summary)
}
// Update summary.
err = store.SetSummary(ctx, "s1", "updated summary")
if err != nil {
t.Fatalf("SetSummary: %v", err)
}
summary, err = store.GetSummary(ctx, "s1")
if err != nil {
t.Fatalf("GetSummary: %v", err)
}
if summary != "updated summary" {
t.Errorf("summary = %q", summary)
}
}
func TestTruncateHistory_KeepLast(t *testing.T) {
store := newTestStore(t)
ctx := context.Background()
for i := 0; i < 10; i++ {
err := store.AddMessage(
ctx, "trunc",
"user",
string(rune('a'+i)),
)
if err != nil {
t.Fatalf("AddMessage: %v", err)
}
}
err := store.TruncateHistory(ctx, "trunc", 4)
if err != nil {
t.Fatalf("TruncateHistory: %v", err)
}
history, err := store.GetHistory(ctx, "trunc")
if err != nil {
t.Fatalf("GetHistory: %v", err)
}
if len(history) != 4 {
t.Fatalf("expected 4, got %d", len(history))
}
// Should be the last 4: g, h, i, j
if history[0].Content != "g" {
t.Errorf("first kept = %q, want 'g'", history[0].Content)
}
if history[3].Content != "j" {
t.Errorf("last kept = %q, want 'j'", history[3].Content)
}
}
func TestTruncateHistory_KeepZero(t *testing.T) {
store := newTestStore(t)
ctx := context.Background()
for i := 0; i < 5; i++ {
err := store.AddMessage(ctx, "empty", "user", "msg")
if err != nil {
t.Fatalf("AddMessage: %v", err)
}
}
err := store.TruncateHistory(ctx, "empty", 0)
if err != nil {
t.Fatalf("TruncateHistory: %v", err)
}
history, err := store.GetHistory(ctx, "empty")
if err != nil {
t.Fatalf("GetHistory: %v", err)
}
if len(history) != 0 {
t.Errorf("expected 0, got %d", len(history))
}
}
func TestTruncateHistory_KeepMoreThanExists(t *testing.T) {
store := newTestStore(t)
ctx := context.Background()
for i := 0; i < 3; i++ {
err := store.AddMessage(ctx, "few", "user", "msg")
if err != nil {
t.Fatalf("AddMessage: %v", err)
}
}
// Keep 100, but only 3 exist — should keep all.
err := store.TruncateHistory(ctx, "few", 100)
if err != nil {
t.Fatalf("TruncateHistory: %v", err)
}
history, err := store.GetHistory(ctx, "few")
if err != nil {
t.Fatalf("GetHistory: %v", err)
}
if len(history) != 3 {
t.Errorf("expected 3, got %d", len(history))
}
}
func TestSetHistory_ReplacesAll(t *testing.T) {
store := newTestStore(t)
ctx := context.Background()
// Add some initial messages.
for i := 0; i < 5; i++ {
err := store.AddMessage(ctx, "replace", "user", "old")
if err != nil {
t.Fatalf("AddMessage: %v", err)
}
}
// Replace with new history.
newHistory := []providers.Message{
{Role: "user", Content: "new1"},
{Role: "assistant", Content: "new2"},
}
err := store.SetHistory(ctx, "replace", newHistory)
if err != nil {
t.Fatalf("SetHistory: %v", err)
}
history, err := store.GetHistory(ctx, "replace")
if err != nil {
t.Fatalf("GetHistory: %v", err)
}
if len(history) != 2 {
t.Fatalf("expected 2, got %d", len(history))
}
if history[0].Content != "new1" || history[1].Content != "new2" {
t.Errorf("history = %+v", history)
}
}
func TestSetHistory_ResetsSkip(t *testing.T) {
store := newTestStore(t)
ctx := context.Background()
// Add messages and truncate.
for i := 0; i < 10; i++ {
err := store.AddMessage(ctx, "skip-reset", "user", "old")
if err != nil {
t.Fatalf("AddMessage: %v", err)
}
}
err := store.TruncateHistory(ctx, "skip-reset", 3)
if err != nil {
t.Fatalf("TruncateHistory: %v", err)
}
// SetHistory should reset skip to 0.
newHistory := []providers.Message{
{Role: "user", Content: "fresh"},
}
err = store.SetHistory(ctx, "skip-reset", newHistory)
if err != nil {
t.Fatalf("SetHistory: %v", err)
}
history, err := store.GetHistory(ctx, "skip-reset")
if err != nil {
t.Fatalf("GetHistory: %v", err)
}
if len(history) != 1 {
t.Fatalf("expected 1, got %d", len(history))
}
if history[0].Content != "fresh" {
t.Errorf("content = %q", history[0].Content)
}
}
func TestColonInKey(t *testing.T) {
store := newTestStore(t)
ctx := context.Background()
err := store.AddMessage(ctx, "telegram:123", "user", "hi")
if err != nil {
t.Fatalf("AddMessage: %v", err)
}
history, err := store.GetHistory(ctx, "telegram:123")
if err != nil {
t.Fatalf("GetHistory: %v", err)
}
if len(history) != 1 {
t.Fatalf("expected 1, got %d", len(history))
}
// Verify the file is named with underscore.
jsonlFile := filepath.Join(store.dir, "telegram_123.jsonl")
if _, statErr := os.Stat(jsonlFile); statErr != nil {
t.Errorf("expected file %s to exist: %v", jsonlFile, statErr)
}
}
func TestCompact_RemovesSkippedMessages(t *testing.T) {
store := newTestStore(t)
ctx := context.Background()
// Write 10 messages, then truncate to keep last 3.
for i := 0; i < 10; i++ {
err := store.AddMessage(ctx, "compact", "user", string(rune('a'+i)))
if err != nil {
t.Fatalf("AddMessage: %v", err)
}
}
err := store.TruncateHistory(ctx, "compact", 3)
if err != nil {
t.Fatalf("TruncateHistory: %v", err)
}
// Before compact: file still has 10 lines.
allOnDisk, err := readMessages(store.jsonlPath("compact"), 0)
if err != nil {
t.Fatalf("readMessages: %v", err)
}
if len(allOnDisk) != 10 {
t.Fatalf("before compact: expected 10 on disk, got %d", len(allOnDisk))
}
// Compact.
err = store.Compact(ctx, "compact")
if err != nil {
t.Fatalf("Compact: %v", err)
}
// After compact: file should have only 3 lines.
allOnDisk, err = readMessages(store.jsonlPath("compact"), 0)
if err != nil {
t.Fatalf("readMessages: %v", err)
}
if len(allOnDisk) != 3 {
t.Fatalf("after compact: expected 3 on disk, got %d", len(allOnDisk))
}
// GetHistory should still return the same 3 messages.
history, err := store.GetHistory(ctx, "compact")
if err != nil {
t.Fatalf("GetHistory: %v", err)
}
if len(history) != 3 {
t.Fatalf("expected 3, got %d", len(history))
}
if history[0].Content != "h" || history[2].Content != "j" {
t.Errorf("wrong content: %+v", history)
}
}
func TestCompact_NoOpWhenNoSkip(t *testing.T) {
store := newTestStore(t)
ctx := context.Background()
for i := 0; i < 5; i++ {
err := store.AddMessage(ctx, "noop", "user", "msg")
if err != nil {
t.Fatalf("AddMessage: %v", err)
}
}
// Compact without prior truncation — should be a no-op.
err := store.Compact(ctx, "noop")
if err != nil {
t.Fatalf("Compact: %v", err)
}
history, err := store.GetHistory(ctx, "noop")
if err != nil {
t.Fatalf("GetHistory: %v", err)
}
if len(history) != 5 {
t.Errorf("expected 5, got %d", len(history))
}
}
func TestCompact_ThenAppend(t *testing.T) {
store := newTestStore(t)
ctx := context.Background()
for i := 0; i < 8; i++ {
err := store.AddMessage(ctx, "cap", "user", string(rune('a'+i)))
if err != nil {
t.Fatalf("AddMessage: %v", err)
}
}
err := store.TruncateHistory(ctx, "cap", 2)
if err != nil {
t.Fatalf("TruncateHistory: %v", err)
}
err = store.Compact(ctx, "cap")
if err != nil {
t.Fatalf("Compact: %v", err)
}
// Append after compaction should work correctly.
err = store.AddMessage(ctx, "cap", "user", "new")
if err != nil {
t.Fatalf("AddMessage after compact: %v", err)
}
history, err := store.GetHistory(ctx, "cap")
if err != nil {
t.Fatalf("GetHistory: %v", err)
}
if len(history) != 3 {
t.Fatalf("expected 3, got %d", len(history))
}
// g, h (kept from truncation), new (appended after compaction).
if history[0].Content != "g" {
t.Errorf("first = %q, want 'g'", history[0].Content)
}
if history[2].Content != "new" {
t.Errorf("last = %q, want 'new'", history[2].Content)
}
}
func TestTruncateHistory_StaleMetaCount(t *testing.T) {
// Simulates a crash between JSONL append and meta update in addMsg:
// file has N+1 lines but meta.Count is still N. TruncateHistory must
// reconcile with the real line count so that keepLast is accurate.
store := newTestStore(t)
ctx := context.Background()
// Write 10 messages normally (meta.Count = 10).
for i := 0; i < 10; i++ {
err := store.AddMessage(ctx, "stale", "user", string(rune('a'+i)))
if err != nil {
t.Fatalf("AddMessage: %v", err)
}
}
// Simulate crash: append a line to JSONL but do NOT update meta.
// This leaves meta.Count = 10 while the file has 11 lines.
jsonlPath := store.jsonlPath("stale")
f, err := os.OpenFile(jsonlPath, os.O_WRONLY|os.O_APPEND, 0o644)
if err != nil {
t.Fatalf("open for append: %v", err)
}
_, err = f.WriteString(`{"role":"user","content":"orphan"}` + "\n")
if err != nil {
t.Fatalf("write orphan: %v", err)
}
f.Close()
// TruncateHistory(keepLast=4) should keep the last 4 of 11 lines,
// not the last 4 of 10.
err = store.TruncateHistory(ctx, "stale", 4)
if err != nil {
t.Fatalf("TruncateHistory: %v", err)
}
history, err := store.GetHistory(ctx, "stale")
if err != nil {
t.Fatalf("GetHistory: %v", err)
}
if len(history) != 4 {
t.Fatalf("expected 4, got %d", len(history))
}
// Last 4 of [a,b,c,d,e,f,g,h,i,j,orphan] = [h,i,j,orphan]
if history[0].Content != "h" {
t.Errorf("first kept = %q, want 'h'", history[0].Content)
}
if history[3].Content != "orphan" {
t.Errorf("last kept = %q, want 'orphan'", history[3].Content)
}
}
func TestCrashRecovery_PartialLine(t *testing.T) {
store := newTestStore(t)
ctx := context.Background()
// Write a valid message first.
err := store.AddMessage(ctx, "crash", "user", "valid")
if err != nil {
t.Fatalf("AddMessage: %v", err)
}
// Simulate a crash by appending a partial JSON line directly.
jsonlPath := store.jsonlPath("crash")
f, err := os.OpenFile(jsonlPath, os.O_WRONLY|os.O_APPEND, 0o644)
if err != nil {
t.Fatalf("open for append: %v", err)
}
_, err = f.WriteString(`{"role":"user","content":"incomple`)
if err != nil {
t.Fatalf("write partial: %v", err)
}
f.Close()
// GetHistory should return only the valid message.
history, err := store.GetHistory(ctx, "crash")
if err != nil {
t.Fatalf("GetHistory: %v", err)
}
if len(history) != 1 {
t.Fatalf("expected 1 valid message, got %d", len(history))
}
if history[0].Content != "valid" {
t.Errorf("content = %q", history[0].Content)
}
}
func TestPersistence_AcrossInstances(t *testing.T) {
dir := t.TempDir()
ctx := context.Background()
// Write with first instance.
store1, err := NewJSONLStore(dir)
if err != nil {
t.Fatalf("NewJSONLStore: %v", err)
}
err = store1.AddMessage(ctx, "persist", "user", "remember me")
if err != nil {
t.Fatalf("AddMessage: %v", err)
}
err = store1.SetSummary(ctx, "persist", "a test session")
if err != nil {
t.Fatalf("SetSummary: %v", err)
}
store1.Close()
// Read with second instance.
store2, err := NewJSONLStore(dir)
if err != nil {
t.Fatalf("NewJSONLStore: %v", err)
}
defer store2.Close()
history, err := store2.GetHistory(ctx, "persist")
if err != nil {
t.Fatalf("GetHistory: %v", err)
}
if len(history) != 1 || history[0].Content != "remember me" {
t.Errorf("history = %+v", history)
}
summary, err := store2.GetSummary(ctx, "persist")
if err != nil {
t.Fatalf("GetSummary: %v", err)
}
if summary != "a test session" {
t.Errorf("summary = %q", summary)
}
}
func TestConcurrent_AddAndRead(t *testing.T) {
store := newTestStore(t)
ctx := context.Background()
var wg sync.WaitGroup
const goroutines = 10
const msgsPerGoroutine = 20
// Concurrent writes.
for g := 0; g < goroutines; g++ {
wg.Add(1)
go func() {
defer wg.Done()
for i := 0; i < msgsPerGoroutine; i++ {
_ = store.AddMessage(ctx, "concurrent", "user", "msg")
}
}()
}
wg.Wait()
history, err := store.GetHistory(ctx, "concurrent")
if err != nil {
t.Fatalf("GetHistory: %v", err)
}
expected := goroutines * msgsPerGoroutine
if len(history) != expected {
t.Errorf("expected %d messages, got %d", expected, len(history))
}
}
func TestConcurrent_SummarizeRace(t *testing.T) {
// Simulates the #704 race: one goroutine adds messages while
// another truncates + sets summary — like summarizeSession().
store := newTestStore(t)
ctx := context.Background()
// Seed with some messages.
for i := 0; i < 20; i++ {
err := store.AddMessage(ctx, "race", "user", "seed")
if err != nil {
t.Fatalf("AddMessage: %v", err)
}
}
var wg sync.WaitGroup
// Writer goroutine (main agent loop).
wg.Add(1)
go func() {
defer wg.Done()
for i := 0; i < 50; i++ {
_ = store.AddMessage(ctx, "race", "user", "new")
}
}()
// Summarizer goroutine (background task).
wg.Add(1)
go func() {
defer wg.Done()
for i := 0; i < 10; i++ {
_ = store.SetSummary(ctx, "race", "summary")
_ = store.TruncateHistory(ctx, "race", 5)
}
}()
wg.Wait()
// Verify the store is still in a consistent state.
_, err := store.GetHistory(ctx, "race")
if err != nil {
t.Fatalf("GetHistory after race: %v", err)
}
_, err = store.GetSummary(ctx, "race")
if err != nil {
t.Fatalf("GetSummary after race: %v", err)
}
}
func TestMultipleSessions_Isolation(t *testing.T) {
store := newTestStore(t)
ctx := context.Background()
err := store.AddMessage(ctx, "s1", "user", "msg for s1")
if err != nil {
t.Fatalf("AddMessage: %v", err)
}
err = store.AddMessage(ctx, "s2", "user", "msg for s2")
if err != nil {
t.Fatalf("AddMessage: %v", err)
}
h1, err := store.GetHistory(ctx, "s1")
if err != nil {
t.Fatalf("GetHistory s1: %v", err)
}
h2, err := store.GetHistory(ctx, "s2")
if err != nil {
t.Fatalf("GetHistory s2: %v", err)
}
if len(h1) != 1 || h1[0].Content != "msg for s1" {
t.Errorf("s1 history = %+v", h1)
}
if len(h2) != 1 || h2[0].Content != "msg for s2" {
t.Errorf("s2 history = %+v", h2)
}
}
func BenchmarkAddMessage(b *testing.B) {
dir := b.TempDir()
store, err := NewJSONLStore(dir)
if err != nil {
b.Fatalf("NewJSONLStore: %v", err)
}
defer store.Close()
ctx := context.Background()
b.ResetTimer()
for i := 0; i < b.N; i++ {
_ = store.AddMessage(ctx, "bench", "user", "benchmark message content")
}
}
func BenchmarkGetHistory_100(b *testing.B) {
dir := b.TempDir()
store, err := NewJSONLStore(dir)
if err != nil {
b.Fatalf("NewJSONLStore: %v", err)
}
defer store.Close()
ctx := context.Background()
for i := 0; i < 100; i++ {
_ = store.AddMessage(ctx, "bench", "user", "message content")
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
_, _ = store.GetHistory(ctx, "bench")
}
}
func BenchmarkGetHistory_1000(b *testing.B) {
dir := b.TempDir()
store, err := NewJSONLStore(dir)
if err != nil {
b.Fatalf("NewJSONLStore: %v", err)
}
defer store.Close()
ctx := context.Background()
for i := 0; i < 1000; i++ {
_ = store.AddMessage(ctx, "bench", "user", "message content")
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
_, _ = store.GetHistory(ctx, "bench")
}
}
+108
View File
@@ -0,0 +1,108 @@
package memory
import (
"context"
"encoding/json"
"fmt"
"log"
"os"
"path/filepath"
"strings"
"time"
"github.com/sipeed/picoclaw/pkg/providers"
)
// jsonSession mirrors pkg/session.Session for migration purposes.
type jsonSession struct {
Key string `json:"key"`
Messages []providers.Message `json:"messages"`
Summary string `json:"summary,omitempty"`
Created time.Time `json:"created"`
Updated time.Time `json:"updated"`
}
// MigrateFromJSON reads legacy sessions/*.json files from sessionsDir,
// writes them into the Store, and renames each migrated file to
// .json.migrated as a backup. Returns the number of sessions migrated.
//
// Files that fail to parse are logged and skipped. Already-migrated
// files (.json.migrated) are ignored, making the function idempotent.
func MigrateFromJSON(
ctx context.Context, sessionsDir string, store Store,
) (int, error) {
entries, err := os.ReadDir(sessionsDir)
if os.IsNotExist(err) {
return 0, nil
}
if err != nil {
return 0, fmt.Errorf("memory: read sessions dir: %w", err)
}
migrated := 0
for _, entry := range entries {
if entry.IsDir() {
continue
}
name := entry.Name()
if !strings.HasSuffix(name, ".json") {
continue
}
// Skip already-migrated files.
if strings.HasSuffix(name, ".migrated") {
continue
}
srcPath := filepath.Join(sessionsDir, name)
data, readErr := os.ReadFile(srcPath)
if readErr != nil {
log.Printf("memory: migrate: skip %s: %v", name, readErr)
continue
}
var sess jsonSession
if parseErr := json.Unmarshal(data, &sess); parseErr != nil {
log.Printf("memory: migrate: skip %s: %v", name, parseErr)
continue
}
// Use the key from the JSON content, not the filename.
// Filenames are sanitized (":" → "_") but keys are not.
key := sess.Key
if key == "" {
key = strings.TrimSuffix(name, ".json")
}
// Use SetHistory (atomic replace) instead of per-message
// AddFullMessage. This makes migration idempotent: if the
// process crashes after writing messages but before the
// rename below, a retry replaces the partial data cleanly
// instead of duplicating messages.
if setErr := store.SetHistory(ctx, key, sess.Messages); setErr != nil {
return migrated, fmt.Errorf(
"memory: migrate %s: set history: %w",
name, setErr,
)
}
if sess.Summary != "" {
if sumErr := store.SetSummary(ctx, key, sess.Summary); sumErr != nil {
return migrated, fmt.Errorf(
"memory: migrate %s: set summary: %w",
name, sumErr,
)
}
}
// Rename to .migrated as backup (not delete).
renameErr := os.Rename(srcPath, srcPath+".migrated")
if renameErr != nil {
log.Printf("memory: migrate: rename %s: %v", name, renameErr)
}
migrated++
}
return migrated, nil
}
+384
View File
@@ -0,0 +1,384 @@
package memory
import (
"context"
"encoding/json"
"os"
"path/filepath"
"testing"
"time"
"github.com/sipeed/picoclaw/pkg/providers"
)
func writeJSONSession(
t *testing.T, dir string, filename string, sess jsonSession,
) {
t.Helper()
data, err := json.MarshalIndent(sess, "", " ")
if err != nil {
t.Fatalf("marshal session: %v", err)
}
err = os.WriteFile(filepath.Join(dir, filename), data, 0o644)
if err != nil {
t.Fatalf("write session file: %v", err)
}
}
func TestMigrateFromJSON_Basic(t *testing.T) {
sessionsDir := t.TempDir()
store := newTestStore(t)
ctx := context.Background()
writeJSONSession(t, sessionsDir, "test.json", jsonSession{
Key: "test",
Messages: []providers.Message{
{Role: "user", Content: "hello"},
{Role: "assistant", Content: "hi"},
},
Summary: "A greeting.",
Created: time.Now(),
Updated: time.Now(),
})
count, err := MigrateFromJSON(ctx, sessionsDir, store)
if err != nil {
t.Fatalf("MigrateFromJSON: %v", err)
}
if count != 1 {
t.Errorf("expected 1 migrated, got %d", count)
}
history, err := store.GetHistory(ctx, "test")
if err != nil {
t.Fatalf("GetHistory: %v", err)
}
if len(history) != 2 {
t.Fatalf("expected 2 messages, got %d", len(history))
}
if history[0].Content != "hello" || history[1].Content != "hi" {
t.Errorf("unexpected messages: %+v", history)
}
summary, err := store.GetSummary(ctx, "test")
if err != nil {
t.Fatalf("GetSummary: %v", err)
}
if summary != "A greeting." {
t.Errorf("summary = %q", summary)
}
}
func TestMigrateFromJSON_WithToolCalls(t *testing.T) {
sessionsDir := t.TempDir()
store := newTestStore(t)
ctx := context.Background()
writeJSONSession(t, sessionsDir, "tools.json", jsonSession{
Key: "tools",
Messages: []providers.Message{
{
Role: "assistant",
Content: "Searching...",
ToolCalls: []providers.ToolCall{
{
ID: "call_1",
Type: "function",
Function: &providers.FunctionCall{
Name: "web_search",
Arguments: `{"q":"test"}`,
},
},
},
},
{
Role: "tool",
Content: "result",
ToolCallID: "call_1",
},
},
Created: time.Now(),
Updated: time.Now(),
})
count, err := MigrateFromJSON(ctx, sessionsDir, store)
if err != nil {
t.Fatalf("MigrateFromJSON: %v", err)
}
if count != 1 {
t.Errorf("expected 1, got %d", count)
}
history, err := store.GetHistory(ctx, "tools")
if err != nil {
t.Fatalf("GetHistory: %v", err)
}
if len(history) != 2 {
t.Fatalf("expected 2 messages, got %d", len(history))
}
if len(history[0].ToolCalls) != 1 {
t.Fatalf("expected 1 tool call, got %d", len(history[0].ToolCalls))
}
if history[0].ToolCalls[0].Function.Name != "web_search" {
t.Errorf("function = %q", history[0].ToolCalls[0].Function.Name)
}
if history[1].ToolCallID != "call_1" {
t.Errorf("ToolCallID = %q", history[1].ToolCallID)
}
}
func TestMigrateFromJSON_MultipleFiles(t *testing.T) {
sessionsDir := t.TempDir()
store := newTestStore(t)
ctx := context.Background()
for i := 0; i < 3; i++ {
key := string(rune('a' + i))
writeJSONSession(t, sessionsDir, key+".json", jsonSession{
Key: key,
Messages: []providers.Message{{Role: "user", Content: "msg " + key}},
Created: time.Now(),
Updated: time.Now(),
})
}
count, err := MigrateFromJSON(ctx, sessionsDir, store)
if err != nil {
t.Fatalf("MigrateFromJSON: %v", err)
}
if count != 3 {
t.Errorf("expected 3, got %d", count)
}
for i := 0; i < 3; i++ {
key := string(rune('a' + i))
history, histErr := store.GetHistory(ctx, key)
if histErr != nil {
t.Fatalf("GetHistory(%q): %v", key, histErr)
}
if len(history) != 1 {
t.Errorf("session %q: expected 1 msg, got %d", key, len(history))
}
}
}
func TestMigrateFromJSON_InvalidJSON(t *testing.T) {
sessionsDir := t.TempDir()
store := newTestStore(t)
ctx := context.Background()
// One valid, one invalid.
writeJSONSession(t, sessionsDir, "good.json", jsonSession{
Key: "good",
Messages: []providers.Message{{Role: "user", Content: "ok"}},
Created: time.Now(),
Updated: time.Now(),
})
err := os.WriteFile(
filepath.Join(sessionsDir, "bad.json"),
[]byte("{invalid json"),
0o644,
)
if err != nil {
t.Fatalf("write bad file: %v", err)
}
count, err := MigrateFromJSON(ctx, sessionsDir, store)
if err != nil {
t.Fatalf("MigrateFromJSON: %v", err)
}
if count != 1 {
t.Errorf("expected 1 (bad file skipped), got %d", count)
}
history, err := store.GetHistory(ctx, "good")
if err != nil {
t.Fatalf("GetHistory: %v", err)
}
if len(history) != 1 {
t.Errorf("expected 1 message, got %d", len(history))
}
}
func TestMigrateFromJSON_RenamesFiles(t *testing.T) {
sessionsDir := t.TempDir()
store := newTestStore(t)
ctx := context.Background()
writeJSONSession(t, sessionsDir, "rename.json", jsonSession{
Key: "rename",
Messages: []providers.Message{{Role: "user", Content: "hi"}},
Created: time.Now(),
Updated: time.Now(),
})
_, err := MigrateFromJSON(ctx, sessionsDir, store)
if err != nil {
t.Fatalf("MigrateFromJSON: %v", err)
}
// Original .json should not exist.
_, statErr := os.Stat(filepath.Join(sessionsDir, "rename.json"))
if !os.IsNotExist(statErr) {
t.Error("rename.json should have been renamed")
}
// .json.migrated should exist.
_, statErr = os.Stat(
filepath.Join(sessionsDir, "rename.json.migrated"),
)
if statErr != nil {
t.Errorf("rename.json.migrated should exist: %v", statErr)
}
}
func TestMigrateFromJSON_Idempotent(t *testing.T) {
sessionsDir := t.TempDir()
store := newTestStore(t)
ctx := context.Background()
writeJSONSession(t, sessionsDir, "idem.json", jsonSession{
Key: "idem",
Messages: []providers.Message{{Role: "user", Content: "once"}},
Created: time.Now(),
Updated: time.Now(),
})
count1, err := MigrateFromJSON(ctx, sessionsDir, store)
if err != nil {
t.Fatalf("first migration: %v", err)
}
if count1 != 1 {
t.Errorf("first run: expected 1, got %d", count1)
}
// Second run should find only .migrated files, skip them.
count2, err := MigrateFromJSON(ctx, sessionsDir, store)
if err != nil {
t.Fatalf("second migration: %v", err)
}
if count2 != 0 {
t.Errorf("second run: expected 0, got %d", count2)
}
history, err := store.GetHistory(ctx, "idem")
if err != nil {
t.Fatalf("GetHistory: %v", err)
}
if len(history) != 1 {
t.Errorf("expected 1 message, got %d", len(history))
}
}
func TestMigrateFromJSON_ColonInKey(t *testing.T) {
sessionsDir := t.TempDir()
store := newTestStore(t)
ctx := context.Background()
// File is named telegram_123 (sanitized), but the key inside is telegram:123.
writeJSONSession(t, sessionsDir, "telegram_123.json", jsonSession{
Key: "telegram:123",
Messages: []providers.Message{{Role: "user", Content: "from telegram"}},
Created: time.Now(),
Updated: time.Now(),
})
count, err := MigrateFromJSON(ctx, sessionsDir, store)
if err != nil {
t.Fatalf("MigrateFromJSON: %v", err)
}
if count != 1 {
t.Errorf("expected 1, got %d", count)
}
// Accessible via the original key "telegram:123".
history, err := store.GetHistory(ctx, "telegram:123")
if err != nil {
t.Fatalf("GetHistory: %v", err)
}
if len(history) != 1 {
t.Fatalf("expected 1 message, got %d", len(history))
}
if history[0].Content != "from telegram" {
t.Errorf("content = %q", history[0].Content)
}
// In the file-based store, "telegram:123" and "telegram_123" both
// sanitize to the same filename, so they share storage. This is
// expected — the colon-to-underscore mapping is a one-way function.
history2, err := store.GetHistory(ctx, "telegram_123")
if err != nil {
t.Fatalf("GetHistory: %v", err)
}
if len(history2) != 1 {
t.Errorf("expected 1 (same file), got %d", len(history2))
}
}
func TestMigrateFromJSON_RetryAfterCrash(t *testing.T) {
// Simulates a crash during migration: first run writes messages
// but doesn't rename the .json file. Second run must replace
// (not duplicate) the messages thanks to SetHistory semantics.
sessionsDir := t.TempDir()
store := newTestStore(t)
ctx := context.Background()
writeJSONSession(t, sessionsDir, "retry.json", jsonSession{
Key: "retry",
Messages: []providers.Message{
{Role: "user", Content: "one"},
{Role: "assistant", Content: "two"},
},
Created: time.Now(),
Updated: time.Now(),
})
// First migration succeeds — writes messages and renames file.
count, err := MigrateFromJSON(ctx, sessionsDir, store)
if err != nil {
t.Fatalf("first migration: %v", err)
}
if count != 1 {
t.Fatalf("expected 1, got %d", count)
}
// Simulate "crash before rename": restore the .json file.
src := filepath.Join(sessionsDir, "retry.json.migrated")
dst := filepath.Join(sessionsDir, "retry.json")
if renameErr := os.Rename(src, dst); renameErr != nil {
t.Fatalf("restore .json: %v", renameErr)
}
// Second migration should re-import without duplicating messages.
count, err = MigrateFromJSON(ctx, sessionsDir, store)
if err != nil {
t.Fatalf("second migration: %v", err)
}
if count != 1 {
t.Fatalf("expected 1, got %d", count)
}
history, err := store.GetHistory(ctx, "retry")
if err != nil {
t.Fatalf("GetHistory: %v", err)
}
// Must be exactly 2 messages (not 4 from duplication).
if len(history) != 2 {
t.Fatalf("expected 2 messages (no duplicates), got %d", len(history))
}
if history[0].Content != "one" || history[1].Content != "two" {
t.Errorf("unexpected messages: %+v", history)
}
}
func TestMigrateFromJSON_NonexistentDir(t *testing.T) {
store := newTestStore(t)
ctx := context.Background()
count, err := MigrateFromJSON(ctx, "/nonexistent/path", store)
if err != nil {
t.Fatalf("MigrateFromJSON: %v", err)
}
if count != 0 {
t.Errorf("expected 0, got %d", count)
}
}
+42
View File
@@ -0,0 +1,42 @@
package memory
import (
"context"
"github.com/sipeed/picoclaw/pkg/providers"
)
// Store defines an interface for persistent session storage.
// Each method is an atomic operation — there is no separate Save() call.
type Store interface {
// AddMessage appends a simple text message to a session.
AddMessage(ctx context.Context, sessionKey, role, content string) error
// AddFullMessage appends a complete message (with tool calls, etc.) to a session.
AddFullMessage(ctx context.Context, sessionKey string, msg providers.Message) error
// GetHistory returns all messages for a session in insertion order.
// Returns an empty slice (not nil) if the session does not exist.
GetHistory(ctx context.Context, sessionKey string) ([]providers.Message, error)
// GetSummary returns the conversation summary for a session.
// Returns an empty string if no summary exists.
GetSummary(ctx context.Context, sessionKey string) (string, error)
// SetSummary updates the conversation summary for a session.
SetSummary(ctx context.Context, sessionKey, summary string) error
// TruncateHistory removes all but the last keepLast messages from a session.
// If keepLast <= 0, all messages are removed.
TruncateHistory(ctx context.Context, sessionKey string, keepLast int) error
// SetHistory replaces all messages in a session with the provided history.
SetHistory(ctx context.Context, sessionKey string, history []providers.Message) error
// Compact reclaims storage by physically removing logically truncated
// data. Backends that do not accumulate dead data may return nil.
Compact(ctx context.Context, sessionKey string) error
// Close releases any resources held by the store.
Close() error
}
+42 -51
View File
@@ -118,64 +118,55 @@ func TestPlanWorkspaceMigration(t *testing.T) {
assert.GreaterOrEqual(t, len(actions), 1)
}
func TestPlanWorkspaceMigrationWithExistingDestination(t *testing.T) {
tmpDir := t.TempDir()
srcWorkspace := filepath.Join(tmpDir, "src", "workspace")
dstWorkspace := filepath.Join(tmpDir, "dst", "workspace")
func TestPlanWorkspaceMigrationExistingFile(t *testing.T) {
tests := []struct {
name string
force bool
wantActionType ActionType
}{
{
name: "backup when not forced",
force: false,
wantActionType: ActionBackup,
},
{
name: "copy when forced",
force: true,
wantActionType: ActionCopy,
},
}
err := os.MkdirAll(srcWorkspace, 0o755)
require.NoError(t, err)
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
tmpDir := t.TempDir()
srcWorkspace := filepath.Join(tmpDir, "src", "workspace")
dstWorkspace := filepath.Join(tmpDir, "dst", "workspace")
err = os.MkdirAll(dstWorkspace, 0o755)
require.NoError(t, err)
err := os.MkdirAll(srcWorkspace, 0o755)
require.NoError(t, err)
err = os.WriteFile(filepath.Join(srcWorkspace, "file1.txt"), []byte("source"), 0o644)
require.NoError(t, err)
err = os.MkdirAll(dstWorkspace, 0o755)
require.NoError(t, err)
err = os.WriteFile(filepath.Join(dstWorkspace, "file1.txt"), []byte("existing"), 0o644)
require.NoError(t, err)
err = os.WriteFile(filepath.Join(srcWorkspace, "file1.txt"), []byte("source"), 0o644)
require.NoError(t, err)
actions, err := PlanWorkspaceMigration(
srcWorkspace,
dstWorkspace,
[]string{"file1.txt"},
[]string{},
false,
)
require.NoError(t, err)
err = os.WriteFile(filepath.Join(dstWorkspace, "file1.txt"), []byte("existing"), 0o644)
require.NoError(t, err)
require.GreaterOrEqual(t, len(actions), 1)
assert.Equal(t, ActionBackup, actions[0].Type)
}
actions, err := PlanWorkspaceMigration(
srcWorkspace,
dstWorkspace,
[]string{"file1.txt"},
[]string{},
tt.force,
)
require.NoError(t, err)
func TestPlanWorkspaceMigrationForce(t *testing.T) {
tmpDir := t.TempDir()
srcWorkspace := filepath.Join(tmpDir, "src", "workspace")
dstWorkspace := filepath.Join(tmpDir, "dst", "workspace")
err := os.MkdirAll(srcWorkspace, 0o755)
require.NoError(t, err)
err = os.MkdirAll(dstWorkspace, 0o755)
require.NoError(t, err)
err = os.WriteFile(filepath.Join(srcWorkspace, "file1.txt"), []byte("source"), 0o644)
require.NoError(t, err)
err = os.WriteFile(filepath.Join(dstWorkspace, "file1.txt"), []byte("existing"), 0o644)
require.NoError(t, err)
actions, err := PlanWorkspaceMigration(
srcWorkspace,
dstWorkspace,
[]string{"file1.txt"},
[]string{},
true,
)
require.NoError(t, err)
require.GreaterOrEqual(t, len(actions), 1)
assert.Equal(t, ActionCopy, actions[0].Type)
require.GreaterOrEqual(t, len(actions), 1)
assert.Equal(t, tt.wantActionType, actions[0].Type)
})
}
}
func TestPlanWorkspaceMigrationNonExistentSource(t *testing.T) {
+1 -33
View File
@@ -100,44 +100,12 @@ func (p *ClaudeCliProvider) buildSystemPrompt(messages []Message, tools []ToolDe
}
if len(tools) > 0 {
parts = append(parts, p.buildToolsPrompt(tools))
parts = append(parts, buildCLIToolsPrompt(tools))
}
return strings.Join(parts, "\n\n")
}
// buildToolsPrompt creates the tool definitions section for the system prompt.
func (p *ClaudeCliProvider) buildToolsPrompt(tools []ToolDefinition) string {
var sb strings.Builder
sb.WriteString("## Available Tools\n\n")
sb.WriteString("When you need to use a tool, respond with ONLY a JSON object:\n\n")
sb.WriteString("```json\n")
sb.WriteString(
`{"tool_calls":[{"id":"call_xxx","type":"function","function":{"name":"tool_name","arguments":"{...}"}}]}`,
)
sb.WriteString("\n```\n\n")
sb.WriteString("CRITICAL: The 'arguments' field MUST be a JSON-encoded STRING.\n\n")
sb.WriteString("### Tool Definitions:\n\n")
for _, tool := range tools {
if tool.Type != "function" {
continue
}
sb.WriteString(fmt.Sprintf("#### %s\n", tool.Function.Name))
if tool.Function.Description != "" {
sb.WriteString(fmt.Sprintf("Description: %s\n", tool.Function.Description))
}
if len(tool.Function.Parameters) > 0 {
paramsJSON, _ := json.Marshal(tool.Function.Parameters)
sb.WriteString(fmt.Sprintf("Parameters:\n```json\n%s\n```\n", string(paramsJSON)))
}
sb.WriteString("\n")
}
return sb.String()
}
// parseClaudeCliResponse parses the JSON output from the claude CLI.
func (p *ClaudeCliProvider) parseClaudeCliResponse(output string) (*LLMResponse, error) {
var resp claudeCliJSONResponse
+3 -6
View File
@@ -660,12 +660,11 @@ func TestBuildSystemPrompt_ToolsOnlyNoSystem(t *testing.T) {
// --- buildToolsPrompt tests ---
func TestBuildToolsPrompt_SkipsNonFunction(t *testing.T) {
p := NewClaudeCliProvider("/workspace")
tools := []ToolDefinition{
{Type: "other", Function: ToolFunctionDefinition{Name: "skip_me"}},
{Type: "function", Function: ToolFunctionDefinition{Name: "include_me", Description: "Included"}},
}
got := p.buildToolsPrompt(tools)
got := buildCLIToolsPrompt(tools)
if strings.Contains(got, "skip_me") {
t.Error("buildToolsPrompt() should skip non-function tools")
}
@@ -675,11 +674,10 @@ func TestBuildToolsPrompt_SkipsNonFunction(t *testing.T) {
}
func TestBuildToolsPrompt_NoDescription(t *testing.T) {
p := NewClaudeCliProvider("/workspace")
tools := []ToolDefinition{
{Type: "function", Function: ToolFunctionDefinition{Name: "bare_tool"}},
}
got := p.buildToolsPrompt(tools)
got := buildCLIToolsPrompt(tools)
if !strings.Contains(got, "bare_tool") {
t.Error("should include tool name")
}
@@ -689,14 +687,13 @@ func TestBuildToolsPrompt_NoDescription(t *testing.T) {
}
func TestBuildToolsPrompt_NoParameters(t *testing.T) {
p := NewClaudeCliProvider("/workspace")
tools := []ToolDefinition{
{Type: "function", Function: ToolFunctionDefinition{
Name: "no_params_tool",
Description: "A tool with no parameters",
}},
}
got := p.buildToolsPrompt(tools)
got := buildCLIToolsPrompt(tools)
if strings.Contains(got, "Parameters:") {
t.Error("should not include Parameters: section when nil")
}
+1 -33
View File
@@ -115,7 +115,7 @@ func (p *CodexCliProvider) buildPrompt(messages []Message, tools []ToolDefinitio
}
if len(tools) > 0 {
sb.WriteString(p.buildToolsPrompt(tools))
sb.WriteString(buildCLIToolsPrompt(tools))
sb.WriteString("\n\n")
}
@@ -128,38 +128,6 @@ func (p *CodexCliProvider) buildPrompt(messages []Message, tools []ToolDefinitio
return sb.String()
}
// buildToolsPrompt creates a tool definitions section for the prompt.
func (p *CodexCliProvider) buildToolsPrompt(tools []ToolDefinition) string {
var sb strings.Builder
sb.WriteString("## Available Tools\n\n")
sb.WriteString("When you need to use a tool, respond with ONLY a JSON object:\n\n")
sb.WriteString("```json\n")
sb.WriteString(
`{"tool_calls":[{"id":"call_xxx","type":"function","function":{"name":"tool_name","arguments":"{...}"}}]}`,
)
sb.WriteString("\n```\n\n")
sb.WriteString("CRITICAL: The 'arguments' field MUST be a JSON-encoded STRING.\n\n")
sb.WriteString("### Tool Definitions:\n\n")
for _, tool := range tools {
if tool.Type != "function" {
continue
}
sb.WriteString(fmt.Sprintf("#### %s\n", tool.Function.Name))
if tool.Function.Description != "" {
sb.WriteString(fmt.Sprintf("Description: %s\n", tool.Function.Description))
}
if len(tool.Function.Parameters) > 0 {
paramsJSON, _ := json.Marshal(tool.Function.Parameters)
sb.WriteString(fmt.Sprintf("Parameters:\n```json\n%s\n```\n", string(paramsJSON)))
}
sb.WriteString("\n")
}
return sb.String()
}
// codexEvent represents a single JSONL event from `codex exec --json`.
type codexEvent struct {
Type string `json:"type"`
+9
View File
@@ -102,6 +102,15 @@ func resolveProviderSelection(cfg *config.Config) (providerSelection, error) {
sel.apiBase = "https://openrouter.ai/api/v1"
}
}
case "litellm":
if cfg.Providers.LiteLLM.APIKey != "" || cfg.Providers.LiteLLM.APIBase != "" {
sel.apiKey = cfg.Providers.LiteLLM.APIKey
sel.apiBase = cfg.Providers.LiteLLM.APIBase
sel.proxy = cfg.Providers.LiteLLM.Proxy
if sel.apiBase == "" {
sel.apiBase = "http://localhost:4000/v1"
}
}
case "zhipu", "glm":
if cfg.Providers.Zhipu.APIKey != "" {
sel.apiKey = cfg.Providers.Zhipu.APIKey
+4 -2
View File
@@ -53,7 +53,7 @@ func ExtractProtocol(model string) (protocol, modelID string) {
// CreateProviderFromConfig creates a provider based on the ModelConfig.
// It uses the protocol prefix in the Model field to determine which provider to create.
// Supported protocols: openai, anthropic, antigravity, claude-cli, codex-cli, github-copilot
// Supported protocols: openai, litellm, anthropic, antigravity, claude-cli, codex-cli, github-copilot
// Returns the provider, the model ID (without protocol prefix), and any error.
func CreateProviderFromConfig(cfg *config.ModelConfig) (LLMProvider, string, error) {
if cfg == nil {
@@ -92,7 +92,7 @@ func CreateProviderFromConfig(cfg *config.ModelConfig) (LLMProvider, string, err
cfg.RequestTimeout,
), modelID, nil
case "openrouter", "groq", "zhipu", "gemini", "nvidia",
case "litellm", "openrouter", "groq", "zhipu", "gemini", "nvidia",
"ollama", "moonshot", "shengsuanyun", "deepseek", "cerebras",
"vivgrid", "volcengine", "vllm", "qwen", "mistral":
// All other OpenAI-compatible HTTP providers
@@ -180,6 +180,8 @@ func getDefaultAPIBase(protocol string) string {
return "https://api.openai.com/v1"
case "openrouter":
return "https://openrouter.ai/api/v1"
case "litellm":
return "http://localhost:4000/v1"
case "groq":
return "https://api.groq.com/openai/v1"
case "zhipu":
+26
View File
@@ -136,6 +136,32 @@ func TestCreateProviderFromConfig_DefaultAPIBase(t *testing.T) {
}
}
func TestGetDefaultAPIBase_LiteLLM(t *testing.T) {
if got := getDefaultAPIBase("litellm"); got != "http://localhost:4000/v1" {
t.Fatalf("getDefaultAPIBase(%q) = %q, want %q", "litellm", got, "http://localhost:4000/v1")
}
}
func TestCreateProviderFromConfig_LiteLLM(t *testing.T) {
cfg := &config.ModelConfig{
ModelName: "test-litellm",
Model: "litellm/my-proxy-alias",
APIKey: "test-key",
APIBase: "http://localhost:4000/v1",
}
provider, modelID, err := CreateProviderFromConfig(cfg)
if err != nil {
t.Fatalf("CreateProviderFromConfig() error = %v", err)
}
if provider == nil {
t.Fatal("CreateProviderFromConfig() returned nil provider")
}
if modelID != "my-proxy-alias" {
t.Errorf("modelID = %q, want %q", modelID, "my-proxy-alias")
}
}
func TestCreateProviderFromConfig_Anthropic(t *testing.T) {
cfg := &config.ModelConfig{
ModelName: "test-anthropic",
+21
View File
@@ -17,6 +17,27 @@ func TestResolveProviderSelection(t *testing.T) {
wantProxy string
wantErrSubstr string
}{
{
name: "explicit litellm provider uses configured base",
setup: func(cfg *config.Config) {
cfg.Agents.Defaults.Provider = "litellm"
cfg.Providers.LiteLLM.APIKey = "litellm-key"
cfg.Providers.LiteLLM.APIBase = "http://localhost:4000/v1"
cfg.Providers.LiteLLM.Proxy = "http://127.0.0.1:7890"
},
wantType: providerTypeHTTPCompat,
wantAPIBase: "http://localhost:4000/v1",
wantProxy: "http://127.0.0.1:7890",
},
{
name: "explicit litellm provider defaults base when only key is configured",
setup: func(cfg *config.Config) {
cfg.Agents.Defaults.Provider = "litellm"
cfg.Providers.LiteLLM.APIKey = "litellm-key"
},
wantType: providerTypeHTTPCompat,
wantAPIBase: "http://localhost:4000/v1",
},
{
name: "explicit claude-cli provider routes to cli provider type",
setup: func(cfg *config.Config) {
+58 -16
View File
@@ -116,7 +116,7 @@ func (p *Provider) Chat(
requestBody := map[string]any{
"model": model,
"messages": stripSystemParts(messages),
"messages": serializeMessages(messages),
}
if len(tools) > 0 {
@@ -289,24 +289,62 @@ func parseResponse(body []byte) (*LLMResponse, error) {
// It mirrors protocoltypes.Message but omits SystemParts, which is an
// internal field that would be unknown to third-party endpoints.
type openaiMessage struct {
Role string `json:"role"`
Content string `json:"content"`
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
ToolCallID string `json:"tool_call_id,omitempty"`
Role string `json:"role"`
Content string `json:"content"`
ReasoningContent string `json:"reasoning_content,omitempty"`
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
ToolCallID string `json:"tool_call_id,omitempty"`
}
// stripSystemParts converts []Message to []openaiMessage, dropping the
// SystemParts field so it doesn't leak into the JSON payload sent to
// OpenAI-compatible APIs (some strict endpoints reject unknown fields).
func stripSystemParts(messages []Message) []openaiMessage {
out := make([]openaiMessage, len(messages))
for i, m := range messages {
out[i] = openaiMessage{
Role: m.Role,
Content: m.Content,
ToolCalls: m.ToolCalls,
ToolCallID: m.ToolCallID,
// serializeMessages converts internal Message structs to the OpenAI wire format.
// - Strips SystemParts (unknown to third-party endpoints)
// - Converts messages with Media to multipart content format (text + image_url parts)
// - Preserves ToolCallID, ToolCalls, and ReasoningContent for all messages
func serializeMessages(messages []Message) []any {
out := make([]any, 0, len(messages))
for _, m := range messages {
if len(m.Media) == 0 {
out = append(out, openaiMessage{
Role: m.Role,
Content: m.Content,
ReasoningContent: m.ReasoningContent,
ToolCalls: m.ToolCalls,
ToolCallID: m.ToolCallID,
})
continue
}
// Multipart content format for messages with media
parts := make([]map[string]any, 0, 1+len(m.Media))
if m.Content != "" {
parts = append(parts, map[string]any{
"type": "text",
"text": m.Content,
})
}
for _, mediaURL := range m.Media {
parts = append(parts, map[string]any{
"type": "image_url",
"image_url": map[string]any{
"url": mediaURL,
},
})
}
msg := map[string]any{
"role": m.Role,
"content": parts,
}
if m.ToolCallID != "" {
msg["tool_call_id"] = m.ToolCallID
}
if len(m.ToolCalls) > 0 {
msg["tool_calls"] = m.ToolCalls
}
if m.ReasoningContent != "" {
msg["reasoning_content"] = m.ReasoningContent
}
out = append(out, msg)
}
return out
}
@@ -323,7 +361,11 @@ func normalizeModel(model, apiBase string) string {
prefix := strings.ToLower(before)
switch prefix {
<<<<<<< HEAD
case "moonshot", "nvidia", "groq", "ollama", "deepseek", "google", "openrouter", "zhipu", "mistral", "vivgrid":
=======
case "litellm", "moonshot", "nvidia", "groq", "ollama", "deepseek", "google", "openrouter", "zhipu", "mistral":
>>>>>>> origin_picoclaw/main
return after
default:
return model
@@ -5,8 +5,11 @@ import (
"net/http"
"net/http/httptest"
"net/url"
"strings"
"testing"
"time"
"github.com/sipeed/picoclaw/pkg/providers/protocoltypes"
)
func TestProviderChat_UsesMaxCompletionTokensForGLM(t *testing.T) {
@@ -146,6 +149,56 @@ func TestProviderChat_ParsesReasoningContent(t *testing.T) {
}
}
func TestProviderChat_PreservesReasoningContentInHistory(t *testing.T) {
var requestBody map[string]any
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if err := json.NewDecoder(r.Body).Decode(&requestBody); err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
resp := map[string]any{
"choices": []map[string]any{
{
"message": map[string]any{"content": "ok"},
"finish_reason": "stop",
},
},
}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(resp)
}))
defer server.Close()
p := NewProvider("key", server.URL, "")
// Simulate a multi-turn conversation where the assistant's previous
// reply included reasoning_content (e.g. from kimi-k2.5).
messages := []Message{
{Role: "user", Content: "What is 1+1?"},
{Role: "assistant", Content: "2", ReasoningContent: "Let me think... 1+1=2"},
{Role: "user", Content: "What about 2+2?"},
}
_, err := p.Chat(t.Context(), messages, nil, "kimi-k2.5", nil)
if err != nil {
t.Fatalf("Chat() error = %v", err)
}
// Verify reasoning_content is preserved in the serialized request.
reqMessages, ok := requestBody["messages"].([]any)
if !ok {
t.Fatalf("messages is not []any: %T", requestBody["messages"])
}
assistantMsg, ok := reqMessages[1].(map[string]any)
if !ok {
t.Fatalf("assistant message is not map[string]any: %T", reqMessages[1])
}
if assistantMsg["reasoning_content"] != "Let me think... 1+1=2" {
t.Errorf("reasoning_content not preserved in request, got %v", assistantMsg["reasoning_content"])
}
}
func TestProviderChat_HTTPError(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
http.Error(w, "bad request", http.StatusBadRequest)
@@ -206,6 +259,11 @@ func TestProviderChat_StripsGroqOllamaDeepseekVivgridPrefixes(t *testing.T) {
input string
wantModel string
}{
{
name: "strips litellm prefix and preserves proxy model name",
input: "litellm/my-proxy-alias",
wantModel: "my-proxy-alias",
},
{
name: "strips groq prefix and keeps nested model",
input: "groq/openai/gpt-oss-120b",
@@ -372,3 +430,97 @@ func TestProvider_FunctionalOptionRequestTimeoutNonPositive(t *testing.T) {
t.Fatalf("http timeout = %v, want %v", p.httpClient.Timeout, defaultRequestTimeout)
}
}
func TestSerializeMessages_PlainText(t *testing.T) {
messages := []protocoltypes.Message{
{Role: "user", Content: "hello"},
{Role: "assistant", Content: "hi", ReasoningContent: "thinking..."},
}
result := serializeMessages(messages)
data, err := json.Marshal(result)
if err != nil {
t.Fatal(err)
}
var msgs []map[string]any
json.Unmarshal(data, &msgs)
if msgs[0]["content"] != "hello" {
t.Fatalf("expected plain string content, got %v", msgs[0]["content"])
}
if msgs[1]["reasoning_content"] != "thinking..." {
t.Fatalf("reasoning_content not preserved, got %v", msgs[1]["reasoning_content"])
}
}
func TestSerializeMessages_WithMedia(t *testing.T) {
messages := []protocoltypes.Message{
{Role: "user", Content: "describe this", Media: []string{"data:image/png;base64,abc123"}},
}
result := serializeMessages(messages)
data, _ := json.Marshal(result)
var msgs []map[string]any
json.Unmarshal(data, &msgs)
content, ok := msgs[0]["content"].([]any)
if !ok {
t.Fatalf("expected array content for media message, got %T", msgs[0]["content"])
}
if len(content) != 2 {
t.Fatalf("expected 2 content parts, got %d", len(content))
}
textPart := content[0].(map[string]any)
if textPart["type"] != "text" || textPart["text"] != "describe this" {
t.Fatalf("text part mismatch: %v", textPart)
}
imgPart := content[1].(map[string]any)
if imgPart["type"] != "image_url" {
t.Fatalf("expected image_url type, got %v", imgPart["type"])
}
imgURL := imgPart["image_url"].(map[string]any)
if imgURL["url"] != "data:image/png;base64,abc123" {
t.Fatalf("image url mismatch: %v", imgURL["url"])
}
}
func TestSerializeMessages_MediaWithToolCallID(t *testing.T) {
messages := []protocoltypes.Message{
{Role: "tool", Content: "image result", Media: []string{"data:image/png;base64,xyz"}, ToolCallID: "call_1"},
}
result := serializeMessages(messages)
data, _ := json.Marshal(result)
var msgs []map[string]any
json.Unmarshal(data, &msgs)
if msgs[0]["tool_call_id"] != "call_1" {
t.Fatalf("tool_call_id not preserved with media, got %v", msgs[0]["tool_call_id"])
}
// Content should be multipart array
if _, ok := msgs[0]["content"].([]any); !ok {
t.Fatalf("expected array content, got %T", msgs[0]["content"])
}
}
func TestSerializeMessages_StripsSystemParts(t *testing.T) {
messages := []protocoltypes.Message{
{
Role: "system",
Content: "you are helpful",
SystemParts: []protocoltypes.ContentBlock{
{Type: "text", Text: "you are helpful"},
},
},
}
result := serializeMessages(messages)
data, _ := json.Marshal(result)
raw := string(data)
if strings.Contains(raw, "system_parts") {
t.Fatal("system_parts should not appear in serialized output")
}
}
+1
View File
@@ -65,6 +65,7 @@ type ContentBlock struct {
type Message struct {
Role string `json:"role"`
Content string `json:"content"`
Media []string `json:"media,omitempty"`
ReasoningContent string `json:"reasoning_content,omitempty"`
SystemParts []ContentBlock `json:"system_parts,omitempty"` // structured system blocks for cache-aware adapters
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
+37 -1
View File
@@ -5,7 +5,43 @@
package providers
import "encoding/json"
import (
"encoding/json"
"fmt"
"strings"
)
// buildCLIToolsPrompt creates the tool definitions section for a CLI provider system prompt.
func buildCLIToolsPrompt(tools []ToolDefinition) string {
var sb strings.Builder
sb.WriteString("## Available Tools\n\n")
sb.WriteString("When you need to use a tool, respond with ONLY a JSON object:\n\n")
sb.WriteString("```json\n")
sb.WriteString(
`{"tool_calls":[{"id":"call_xxx","type":"function","function":{"name":"tool_name","arguments":"{...}"}}]}`,
)
sb.WriteString("\n```\n\n")
sb.WriteString("CRITICAL: The 'arguments' field MUST be a JSON-encoded STRING.\n\n")
sb.WriteString("### Tool Definitions:\n\n")
for _, tool := range tools {
if tool.Type != "function" {
continue
}
sb.WriteString(fmt.Sprintf("#### %s\n", tool.Function.Name))
if tool.Function.Description != "" {
sb.WriteString(fmt.Sprintf("Description: %s\n", tool.Function.Description))
}
if len(tool.Function.Parameters) > 0 {
paramsJSON, _ := json.Marshal(tool.Function.Parameters)
sb.WriteString(fmt.Sprintf("Parameters:\n```json\n%s\n```\n", string(paramsJSON)))
}
sb.WriteString("\n")
}
return sb.String()
}
// NormalizeToolCall normalizes a ToolCall to ensure all fields are properly populated.
// It handles cases where Name/Arguments might be in different locations (top-level vs Function)
+23
View File
@@ -64,6 +64,29 @@ type SkillsLoader struct {
builtinSkills string // builtin skills
}
// SkillRoots returns all unique skill root directories used by this loader.
// The order follows resolution priority: workspace > global > builtin.
func (sl *SkillsLoader) SkillRoots() []string {
roots := []string{sl.workspaceSkills, sl.globalSkills, sl.builtinSkills}
seen := make(map[string]struct{}, len(roots))
out := make([]string, 0, len(roots))
for _, root := range roots {
trimmed := strings.TrimSpace(root)
if trimmed == "" {
continue
}
clean := filepath.Clean(trimmed)
if _, ok := seen[clean]; ok {
continue
}
seen[clean] = struct{}{}
out = append(out, clean)
}
return out
}
func NewSkillsLoader(workspace string, globalSkills string, builtinSkills string) *SkillsLoader {
return &SkillsLoader{
workspace: workspace,
+16
View File
@@ -326,3 +326,19 @@ func TestStripFrontmatter(t *testing.T) {
})
}
}
func TestSkillRootsTrimsWhitespaceAndDedups(t *testing.T) {
tmp := t.TempDir()
workspace := filepath.Join(tmp, "workspace")
global := filepath.Join(tmp, "global")
builtin := filepath.Join(tmp, "builtin")
sl := NewSkillsLoader(workspace, " "+global+" ", "\t"+builtin+"\n")
roots := sl.SkillRoots()
assert.Equal(t, []string{
filepath.Join(workspace, "skills"),
global,
builtin,
}, roots)
}
+246
View File
@@ -0,0 +1,246 @@
package tools
import (
"context"
"encoding/json"
"fmt"
"hash/fnv"
"strings"
"github.com/modelcontextprotocol/go-sdk/mcp"
)
// MCPManager defines the interface for MCP manager operations
// This allows for easier testing with mock implementations
type MCPManager interface {
CallTool(
ctx context.Context,
serverName, toolName string,
arguments map[string]any,
) (*mcp.CallToolResult, error)
}
// MCPTool wraps an MCP tool to implement the Tool interface
type MCPTool struct {
manager MCPManager
serverName string
tool *mcp.Tool
}
// NewMCPTool creates a new MCP tool wrapper
func NewMCPTool(manager MCPManager, serverName string, tool *mcp.Tool) *MCPTool {
return &MCPTool{
manager: manager,
serverName: serverName,
tool: tool,
}
}
// sanitizeIdentifierComponent normalizes a string so it can be safely used
// as part of a tool/function identifier for downstream providers.
// It:
// - lowercases the string
// - replaces any character not in [a-z0-9_-] with '_'
// - collapses multiple consecutive '_' into a single '_'
// - trims leading/trailing '_'
// - falls back to "unnamed" if the result is empty
// - truncates overly long components to a reasonable length
func sanitizeIdentifierComponent(s string) string {
const maxLen = 64
s = strings.ToLower(s)
var b strings.Builder
b.Grow(len(s))
prevUnderscore := false
for _, r := range s {
isAllowed := (r >= 'a' && r <= 'z') ||
(r >= '0' && r <= '9') ||
r == '_' || r == '-'
if !isAllowed {
// Normalize any disallowed character to '_'
if !prevUnderscore {
b.WriteRune('_')
prevUnderscore = true
}
continue
}
if r == '_' {
if prevUnderscore {
continue
}
prevUnderscore = true
} else {
prevUnderscore = false
}
b.WriteRune(r)
}
result := strings.Trim(b.String(), "_")
if result == "" {
result = "unnamed"
}
if len(result) > maxLen {
result = result[:maxLen]
}
return result
}
// Name returns the tool name, prefixed with the server name.
// The total length is capped at 64 characters (OpenAI-compatible API limit).
// A short hash of the original (unsanitized) server and tool names is appended
// whenever sanitization is lossy or the name is truncated, ensuring that two
// names which differ only in disallowed characters remain distinct after sanitization.
func (t *MCPTool) Name() string {
// Prefix with server name to avoid conflicts, and sanitize components
sanitizedServer := sanitizeIdentifierComponent(t.serverName)
sanitizedTool := sanitizeIdentifierComponent(t.tool.Name)
full := fmt.Sprintf("mcp_%s_%s", sanitizedServer, sanitizedTool)
// Check if sanitization was lossless (only lowercasing, no char replacement/truncation)
lossless := strings.ToLower(t.serverName) == sanitizedServer &&
strings.ToLower(t.tool.Name) == sanitizedTool
const maxTotal = 64
if lossless && len(full) <= maxTotal {
return full
}
// Sanitization was lossy or name too long: append hash of the ORIGINAL names
// (not the sanitized names) so different originals always yield different hashes.
h := fnv.New32a()
_, _ = h.Write([]byte(t.serverName + "\x00" + t.tool.Name))
suffix := fmt.Sprintf("%08x", h.Sum32()) // 8 chars
base := full
if len(base) > maxTotal-9 {
base = strings.TrimRight(full[:maxTotal-9], "_")
}
return base + "_" + suffix
}
// Description returns the tool description
func (t *MCPTool) Description() string {
desc := t.tool.Description
if desc == "" {
desc = fmt.Sprintf("MCP tool from %s server", t.serverName)
}
// Add server info to description
return fmt.Sprintf("[MCP:%s] %s", t.serverName, desc)
}
// Parameters returns the tool parameters schema
func (t *MCPTool) Parameters() map[string]any {
// The InputSchema is already a JSON Schema object
schema := t.tool.InputSchema
// Handle nil schema
if schema == nil {
return map[string]any{
"type": "object",
"properties": map[string]any{},
"required": []string{},
}
}
// Try direct conversion first (fast path)
if schemaMap, ok := schema.(map[string]any); ok {
return schemaMap
}
// Handle json.RawMessage and []byte - unmarshal directly
var jsonData []byte
if rawMsg, ok := schema.(json.RawMessage); ok {
jsonData = rawMsg
} else if bytes, ok := schema.([]byte); ok {
jsonData = bytes
}
if jsonData != nil {
var result map[string]any
if err := json.Unmarshal(jsonData, &result); err == nil {
return result
}
// Fallback on error
return map[string]any{
"type": "object",
"properties": map[string]any{},
"required": []string{},
}
}
// For other types (structs, etc.), convert via JSON marshal/unmarshal
var err error
jsonData, err = json.Marshal(schema)
if err != nil {
// Fallback to empty schema if marshaling fails
return map[string]any{
"type": "object",
"properties": map[string]any{},
"required": []string{},
}
}
var result map[string]any
if err := json.Unmarshal(jsonData, &result); err != nil {
// Fallback to empty schema if unmarshaling fails
return map[string]any{
"type": "object",
"properties": map[string]any{},
"required": []string{},
}
}
return result
}
// Execute executes the MCP tool
func (t *MCPTool) Execute(ctx context.Context, args map[string]any) *ToolResult {
result, err := t.manager.CallTool(ctx, t.serverName, t.tool.Name, args)
if err != nil {
return ErrorResult(fmt.Sprintf("MCP tool execution failed: %v", err)).WithError(err)
}
if result == nil {
nilErr := fmt.Errorf("MCP tool returned nil result without error")
return ErrorResult("MCP tool execution failed: nil result").WithError(nilErr)
}
// Handle error result from server
if result.IsError {
errMsg := extractContentText(result.Content)
return ErrorResult(fmt.Sprintf("MCP tool returned error: %s", errMsg)).
WithError(fmt.Errorf("MCP tool error: %s", errMsg))
}
// Extract text content from result
output := extractContentText(result.Content)
return &ToolResult{
ForLLM: output,
IsError: false,
}
}
// extractContentText extracts text from MCP content array
func extractContentText(content []mcp.Content) string {
var parts []string
for _, c := range content {
switch v := c.(type) {
case *mcp.TextContent:
parts = append(parts, v.Text)
case *mcp.ImageContent:
// For images, just indicate that an image was returned
parts = append(parts, fmt.Sprintf("[Image: %s]", v.MIMEType))
default:
// For other content types, use string representation
parts = append(parts, fmt.Sprintf("[Content: %T]", v))
}
}
return strings.Join(parts, "\n")
}
+492
View File
@@ -0,0 +1,492 @@
package tools
import (
"context"
"fmt"
"strings"
"testing"
"github.com/modelcontextprotocol/go-sdk/mcp"
)
// MockMCPManager is a mock implementation of MCPManager interface for testing
type MockMCPManager struct {
callToolFunc func(ctx context.Context, serverName, toolName string, arguments map[string]any) (*mcp.CallToolResult, error)
}
func (m *MockMCPManager) CallTool(
ctx context.Context,
serverName, toolName string,
arguments map[string]any,
) (*mcp.CallToolResult, error) {
if m.callToolFunc != nil {
return m.callToolFunc(ctx, serverName, toolName, arguments)
}
return &mcp.CallToolResult{
Content: []mcp.Content{
&mcp.TextContent{Text: "mock result"},
},
IsError: false,
}, nil
}
// TestNewMCPTool verifies MCP tool creation
func TestNewMCPTool(t *testing.T) {
manager := &MockMCPManager{}
tool := &mcp.Tool{
Name: "test_tool",
Description: "A test tool",
InputSchema: map[string]any{
"type": "object",
"properties": map[string]any{
"input": map[string]any{
"type": "string",
"description": "Test input",
},
},
},
}
mcpTool := NewMCPTool(manager, "test_server", tool)
if mcpTool == nil {
t.Fatal("NewMCPTool should not return nil")
}
// Verify tool properties we can access
if mcpTool.Name() != "mcp_test_server_test_tool" {
t.Errorf("Expected tool name with prefix, got '%s'", mcpTool.Name())
}
}
// TestMCPTool_Name verifies tool name with server prefix
func TestMCPTool_Name(t *testing.T) {
tests := []struct {
name string
serverName string
toolName string
expected string
}{
{
name: "simple name",
serverName: "github",
toolName: "create_issue",
expected: "mcp_github_create_issue",
},
{
name: "filesystem server",
serverName: "filesystem",
toolName: "read_file",
expected: "mcp_filesystem_read_file",
},
{
name: "remote server",
serverName: "remote-api",
toolName: "fetch_data",
expected: "mcp_remote-api_fetch_data",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
manager := &MockMCPManager{}
tool := &mcp.Tool{Name: tt.toolName}
mcpTool := NewMCPTool(manager, tt.serverName, tool)
result := mcpTool.Name()
if result != tt.expected {
t.Errorf("Expected name '%s', got '%s'", tt.expected, result)
}
})
}
}
// TestMCPTool_Description verifies tool description generation
func TestMCPTool_Description(t *testing.T) {
tests := []struct {
name string
serverName string
toolDescription string
expectContains []string
}{
{
name: "with description",
serverName: "github",
toolDescription: "Create a GitHub issue",
expectContains: []string{"[MCP:github]", "Create a GitHub issue"},
},
{
name: "empty description",
serverName: "filesystem",
toolDescription: "",
expectContains: []string{"[MCP:filesystem]", "MCP tool from filesystem server"},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
manager := &MockMCPManager{}
tool := &mcp.Tool{
Name: "test_tool",
Description: tt.toolDescription,
}
mcpTool := NewMCPTool(manager, tt.serverName, tool)
result := mcpTool.Description()
for _, expected := range tt.expectContains {
if !strings.Contains(result, expected) {
t.Errorf("Description should contain '%s', got: %s", expected, result)
}
}
})
}
}
// TestMCPTool_Parameters verifies parameter schema conversion
func TestMCPTool_Parameters(t *testing.T) {
tests := []struct {
name string
inputSchema any
expectType string
checkProperty string
expectProperty bool
}{
{
name: "map schema",
inputSchema: map[string]any{
"type": "object",
"properties": map[string]any{
"query": map[string]any{
"type": "string",
"description": "Search query",
},
},
"required": []string{"query"},
},
expectType: "object",
checkProperty: "query",
expectProperty: true,
},
{
name: "nil schema",
inputSchema: nil,
expectType: "object",
expectProperty: false,
},
{
name: "json.RawMessage schema",
inputSchema: []byte(`{
"type": "object",
"properties": {
"repo": {
"type": "string",
"description": "Repository name"
},
"stars": {
"type": "integer",
"description": "Minimum stars"
}
},
"required": ["repo"]
}`),
expectType: "object",
checkProperty: "repo",
expectProperty: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
manager := &MockMCPManager{}
tool := &mcp.Tool{
Name: "test_tool",
InputSchema: tt.inputSchema,
}
mcpTool := NewMCPTool(manager, "test_server", tool)
params := mcpTool.Parameters()
if params == nil {
t.Fatal("Parameters should not be nil")
}
if params["type"] != tt.expectType {
t.Errorf("Expected type '%s', got '%v'", tt.expectType, params["type"])
}
// Check if property exists when expected
if tt.checkProperty != "" {
properties, ok := params["properties"].(map[string]any)
if !ok && tt.expectProperty {
t.Errorf("Expected properties to be a map")
return
}
if ok {
_, hasProperty := properties[tt.checkProperty]
if hasProperty != tt.expectProperty {
t.Errorf("Expected property '%s' existence: %v, got: %v",
tt.checkProperty, tt.expectProperty, hasProperty)
}
}
}
})
}
}
// TestMCPTool_Execute_Success tests successful tool execution
func TestMCPTool_Execute_Success(t *testing.T) {
manager := &MockMCPManager{
callToolFunc: func(ctx context.Context, serverName, toolName string, arguments map[string]any) (*mcp.CallToolResult, error) {
// Verify correct parameters passed
if serverName != "github" {
t.Errorf("Expected serverName 'github', got '%s'", serverName)
}
if toolName != "search_repos" {
t.Errorf("Expected toolName 'search_repos', got '%s'", toolName)
}
return &mcp.CallToolResult{
Content: []mcp.Content{
&mcp.TextContent{Text: "Found 3 repositories"},
},
IsError: false,
}, nil
},
}
tool := &mcp.Tool{
Name: "search_repos",
Description: "Search GitHub repositories",
}
mcpTool := NewMCPTool(manager, "github", tool)
ctx := context.Background()
args := map[string]any{
"query": "golang mcp",
}
result := mcpTool.Execute(ctx, args)
if result == nil {
t.Fatal("Result should not be nil")
}
if result.IsError {
t.Errorf("Expected no error, got error: %s", result.ForLLM)
}
if result.ForLLM != "Found 3 repositories" {
t.Errorf("Expected 'Found 3 repositories', got '%s'", result.ForLLM)
}
}
// TestMCPTool_Execute_ManagerError tests execution when manager returns error
func TestMCPTool_Execute_ManagerError(t *testing.T) {
manager := &MockMCPManager{
callToolFunc: func(ctx context.Context, serverName, toolName string, arguments map[string]any) (*mcp.CallToolResult, error) {
return nil, fmt.Errorf("connection failed")
},
}
tool := &mcp.Tool{Name: "test_tool"}
mcpTool := NewMCPTool(manager, "test_server", tool)
ctx := context.Background()
result := mcpTool.Execute(ctx, map[string]any{})
if result == nil {
t.Fatal("Result should not be nil")
}
if !result.IsError {
t.Error("Expected IsError to be true")
}
if !strings.Contains(result.ForLLM, "MCP tool execution failed") {
t.Errorf("Error message should mention execution failure, got: %s", result.ForLLM)
}
if !strings.Contains(result.ForLLM, "connection failed") {
t.Errorf("Error message should include original error, got: %s", result.ForLLM)
}
}
// TestMCPTool_Execute_ServerError tests execution when server returns error
func TestMCPTool_Execute_ServerError(t *testing.T) {
manager := &MockMCPManager{
callToolFunc: func(ctx context.Context, serverName, toolName string, arguments map[string]any) (*mcp.CallToolResult, error) {
return &mcp.CallToolResult{
Content: []mcp.Content{
&mcp.TextContent{Text: "Invalid API key"},
},
IsError: true,
}, nil
},
}
tool := &mcp.Tool{Name: "test_tool"}
mcpTool := NewMCPTool(manager, "test_server", tool)
ctx := context.Background()
result := mcpTool.Execute(ctx, map[string]any{})
if result == nil {
t.Fatal("Result should not be nil")
}
if !result.IsError {
t.Error("Expected IsError to be true")
}
if !strings.Contains(result.ForLLM, "MCP tool returned error") {
t.Errorf("Error message should mention server error, got: %s", result.ForLLM)
}
if !strings.Contains(result.ForLLM, "Invalid API key") {
t.Errorf("Error message should include server message, got: %s", result.ForLLM)
}
}
// TestMCPTool_Execute_MultipleContent tests execution with multiple content items
func TestMCPTool_Execute_MultipleContent(t *testing.T) {
manager := &MockMCPManager{
callToolFunc: func(ctx context.Context, serverName, toolName string, arguments map[string]any) (*mcp.CallToolResult, error) {
return &mcp.CallToolResult{
Content: []mcp.Content{
&mcp.TextContent{Text: "First line"},
&mcp.TextContent{Text: "Second line"},
&mcp.TextContent{Text: "Third line"},
},
IsError: false,
}, nil
},
}
tool := &mcp.Tool{Name: "multi_output"}
mcpTool := NewMCPTool(manager, "test_server", tool)
ctx := context.Background()
result := mcpTool.Execute(ctx, map[string]any{})
if result.IsError {
t.Errorf("Expected no error, got: %s", result.ForLLM)
}
expected := "First line\nSecond line\nThird line"
if result.ForLLM != expected {
t.Errorf("Expected '%s', got '%s'", expected, result.ForLLM)
}
}
// TestExtractContentText_TextContent tests text content extraction
func TestExtractContentText_TextContent(t *testing.T) {
content := []mcp.Content{
&mcp.TextContent{Text: "Hello World"},
&mcp.TextContent{Text: "Second message"},
}
result := extractContentText(content)
expected := "Hello World\nSecond message"
if result != expected {
t.Errorf("Expected '%s', got '%s'", expected, result)
}
}
// TestExtractContentText_ImageContent tests image content extraction
func TestExtractContentText_ImageContent(t *testing.T) {
content := []mcp.Content{
&mcp.ImageContent{
Data: []byte("base64data"),
MIMEType: "image/png",
},
}
result := extractContentText(content)
if !strings.Contains(result, "[Image:") {
t.Errorf("Expected image indicator, got: %s", result)
}
if !strings.Contains(result, "image/png") {
t.Errorf("Expected MIME type in output, got: %s", result)
}
}
// TestExtractContentText_MixedContent tests mixed content types
func TestExtractContentText_MixedContent(t *testing.T) {
content := []mcp.Content{
&mcp.TextContent{Text: "Description"},
&mcp.ImageContent{
Data: []byte("data"),
MIMEType: "image/jpeg",
},
&mcp.TextContent{Text: "More text"},
}
result := extractContentText(content)
if !strings.Contains(result, "Description") {
t.Errorf("Should contain text content, got: %s", result)
}
if !strings.Contains(result, "[Image:") {
t.Errorf("Should contain image indicator, got: %s", result)
}
if !strings.Contains(result, "More text") {
t.Errorf("Should contain second text, got: %s", result)
}
}
// TestExtractContentText_EmptyContent tests empty content array
func TestExtractContentText_EmptyContent(t *testing.T) {
content := []mcp.Content{}
result := extractContentText(content)
if result != "" {
t.Errorf("Expected empty string for empty content, got: %s", result)
}
}
// TestMCPTool_InterfaceCompliance verifies MCPTool implements Tool interface
func TestMCPTool_InterfaceCompliance(t *testing.T) {
manager := &MockMCPManager{}
tool := &mcp.Tool{Name: "test"}
mcpTool := NewMCPTool(manager, "test_server", tool)
// Verify it implements Tool interface
var _ Tool = mcpTool
}
// TestMCPTool_Parameters_MapSchema tests schema that's already a map
func TestMCPTool_Parameters_MapSchema(t *testing.T) {
manager := &MockMCPManager{}
schema := map[string]any{
"type": "object",
"properties": map[string]any{
"name": map[string]any{
"type": "string",
"description": "The name parameter",
},
},
"required": []string{"name"},
}
tool := &mcp.Tool{
Name: "test_tool",
InputSchema: schema,
}
mcpTool := NewMCPTool(manager, "test_server", tool)
params := mcpTool.Parameters()
// Should return the schema as-is when it's already a map
if params["type"] != "object" {
t.Errorf("Expected type 'object', got '%v'", params["type"])
}
props, ok := params["properties"].(map[string]any)
if !ok {
t.Error("Properties should be a map")
}
nameParam, ok := props["name"].(map[string]any)
if !ok {
t.Error("Name parameter should exist")
}
if nameParam["type"] != "string" {
t.Errorf("Name type should be 'string', got '%v'", nameParam["type"])
}
}
+5 -4
View File
@@ -3,6 +3,7 @@ package tools
import (
"context"
"fmt"
"sync/atomic"
)
type SendCallback func(channel, chatID, content string) error
@@ -11,7 +12,7 @@ type MessageTool struct {
sendCallback SendCallback
defaultChannel string
defaultChatID string
sentInRound bool // Tracks whether a message was sent in the current processing round
sentInRound atomic.Bool // Tracks whether a message was sent in the current processing round
}
func NewMessageTool() *MessageTool {
@@ -50,12 +51,12 @@ func (t *MessageTool) Parameters() map[string]any {
func (t *MessageTool) SetContext(channel, chatID string) {
t.defaultChannel = channel
t.defaultChatID = chatID
t.sentInRound = false // Reset send tracking for new processing round
t.sentInRound.Store(false) // Reset send tracking for new processing round
}
// HasSentInRound returns true if the message tool sent a message during the current round.
func (t *MessageTool) HasSentInRound() bool {
return t.sentInRound
return t.sentInRound.Load()
}
func (t *MessageTool) SetSendCallback(callback SendCallback) {
@@ -94,7 +95,7 @@ func (t *MessageTool) Execute(ctx context.Context, args map[string]any) *ToolRes
}
}
t.sentInRound = true
t.sentInRound.Store(true)
// Silent: user already received the message directly
return &ToolResult{
ForLLM: fmt.Sprintf("Message sent to %s:%s", channel, chatID),
+6 -1
View File
@@ -25,7 +25,12 @@ func NewToolRegistry() *ToolRegistry {
func (r *ToolRegistry) Register(tool Tool) {
r.mu.Lock()
defer r.mu.Unlock()
r.tools[tool.Name()] = tool
name := tool.Name()
if _, exists := r.tools[name]; exists {
logger.WarnCF("tools", "Tool registration overwrites existing tool",
map[string]any{"name": name})
}
r.tools[name] = tool
}
func (r *ToolRegistry) Get(name string) (Tool, bool) {
+43 -26
View File
@@ -10,6 +10,7 @@ import (
"context"
"encoding/json"
"fmt"
"sync"
"github.com/sipeed/picoclaw/pkg/logger"
"github.com/sipeed/picoclaw/pkg/providers"
@@ -121,37 +122,53 @@ func RunToolLoop(
}
messages = append(messages, assistantMsg)
// 7. Execute tool calls
for _, tc := range normalizedToolCalls {
argsJSON, _ := json.Marshal(tc.Arguments)
argsPreview := utils.Truncate(string(argsJSON), 200)
logger.InfoCF("toolloop", fmt.Sprintf("Tool call: %s(%s)", tc.Name, argsPreview),
map[string]any{
"tool": tc.Name,
"iteration": iteration,
})
// 7. Execute tool calls in parallel
type indexedResult struct {
result *ToolResult
tc providers.ToolCall
}
// Execute tool (no async callback for subagents - they run independently)
var toolResult *ToolResult
if config.Tools != nil {
toolResult = config.Tools.ExecuteWithContext(ctx, tc.Name, tc.Arguments, channel, chatID, nil)
} else {
toolResult = ErrorResult("No tools available")
results := make([]indexedResult, len(normalizedToolCalls))
var wg sync.WaitGroup
for i, tc := range normalizedToolCalls {
results[i].tc = tc
wg.Add(1)
go func(idx int, tc providers.ToolCall) {
defer wg.Done()
argsJSON, _ := json.Marshal(tc.Arguments)
argsPreview := utils.Truncate(string(argsJSON), 200)
logger.InfoCF("toolloop", fmt.Sprintf("Tool call: %s(%s)", tc.Name, argsPreview),
map[string]any{
"tool": tc.Name,
"iteration": iteration,
})
var toolResult *ToolResult
if config.Tools != nil {
toolResult = config.Tools.ExecuteWithContext(ctx, tc.Name, tc.Arguments, channel, chatID, nil)
} else {
toolResult = ErrorResult("No tools available")
}
results[idx].result = toolResult
}(i, tc)
}
wg.Wait()
// Append results in original order
for _, r := range results {
contentForLLM := r.result.ForLLM
if contentForLLM == "" && r.result.Err != nil {
contentForLLM = r.result.Err.Error()
}
// Determine content for LLM
contentForLLM := toolResult.ForLLM
if contentForLLM == "" && toolResult.Err != nil {
contentForLLM = toolResult.Err.Error()
}
// Add tool result message
toolResultMsg := providers.Message{
messages = append(messages, providers.Message{
Role: "tool",
Content: contentForLLM,
ToolCallID: tc.ID,
}
messages = append(messages, toolResultMsg)
ToolCallID: r.tc.ID,
})
}
}
+111 -1
View File
@@ -109,6 +109,10 @@ func (p *BraveSearchProvider) Search(ctx context.Context, query string, count in
return "", fmt.Errorf("failed to read response: %w", err)
}
if resp.StatusCode != http.StatusOK {
return "", fmt.Errorf("brave api error (status %d): %s", resp.StatusCode, string(body))
}
var searchResp struct {
Web struct {
Results []struct {
@@ -391,6 +395,88 @@ func (p *PerplexitySearchProvider) Search(ctx context.Context, query string, cou
return fmt.Sprintf("Results for: %s (via Perplexity)\n%s", query, searchResp.Choices[0].Message.Content), nil
}
type GLMSearchProvider struct {
apiKey string
baseURL string
searchEngine string
proxy string
client *http.Client
}
func (p *GLMSearchProvider) Search(ctx context.Context, query string, count int) (string, error) {
searchURL := p.baseURL
if searchURL == "" {
searchURL = "https://open.bigmodel.cn/api/paas/v4/web_search"
}
payload := map[string]any{
"search_query": query,
"search_engine": p.searchEngine,
"search_intent": false,
"count": count,
"content_size": "medium",
}
bodyBytes, err := json.Marshal(payload)
if err != nil {
return "", fmt.Errorf("failed to marshal payload: %w", err)
}
req, err := http.NewRequestWithContext(ctx, "POST", searchURL, bytes.NewReader(bodyBytes))
if err != nil {
return "", fmt.Errorf("failed to create request: %w", err)
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+p.apiKey)
resp, err := p.client.Do(req)
if err != nil {
return "", fmt.Errorf("request failed: %w", err)
}
defer resp.Body.Close()
body, err := io.ReadAll(io.LimitReader(resp.Body, 1<<20))
if err != nil {
return "", fmt.Errorf("failed to read response: %w", err)
}
if resp.StatusCode != http.StatusOK {
return "", fmt.Errorf("GLM Search API error (status %d): %s", resp.StatusCode, string(body))
}
var searchResp struct {
SearchResult []struct {
Title string `json:"title"`
Content string `json:"content"`
Link string `json:"link"`
} `json:"search_result"`
}
if err := json.Unmarshal(body, &searchResp); err != nil {
return "", fmt.Errorf("failed to parse response: %w", err)
}
results := searchResp.SearchResult
if len(results) == 0 {
return fmt.Sprintf("No results for: %s", query), nil
}
var lines []string
lines = append(lines, fmt.Sprintf("Results for: %s (via GLM Search)", query))
for i, item := range results {
if i >= count {
break
}
lines = append(lines, fmt.Sprintf("%d. %s\n %s", i+1, item.Title, item.Link))
if item.Content != "" {
lines = append(lines, fmt.Sprintf(" %s", item.Content))
}
}
return strings.Join(lines, "\n"), nil
}
type WebSearchTool struct {
provider SearchProvider
maxResults int
@@ -409,6 +495,11 @@ type WebSearchToolOptions struct {
PerplexityAPIKey string
PerplexityMaxResults int
PerplexityEnabled bool
GLMSearchAPIKey string
GLMSearchBaseURL string
GLMSearchEngine string
GLMSearchMaxResults int
GLMSearchEnabled bool
Proxy string
}
@@ -416,7 +507,7 @@ func NewWebSearchTool(opts WebSearchToolOptions) (*WebSearchTool, error) {
var provider SearchProvider
maxResults := 5
// Priority: Perplexity > Brave > Tavily > DuckDuckGo
// Priority: Perplexity > Brave > Tavily > DuckDuckGo > GLM Search
if opts.PerplexityEnabled && opts.PerplexityAPIKey != "" {
client, err := createHTTPClient(opts.Proxy, perplexityTimeout)
if err != nil {
@@ -458,6 +549,25 @@ func NewWebSearchTool(opts WebSearchToolOptions) (*WebSearchTool, error) {
if opts.DuckDuckGoMaxResults > 0 {
maxResults = opts.DuckDuckGoMaxResults
}
} else if opts.GLMSearchEnabled && opts.GLMSearchAPIKey != "" {
client, err := createHTTPClient(opts.Proxy, searchTimeout)
if err != nil {
return nil, fmt.Errorf("failed to create HTTP client for GLM Search: %w", err)
}
searchEngine := opts.GLMSearchEngine
if searchEngine == "" {
searchEngine = "search_std"
}
provider = &GLMSearchProvider{
apiKey: opts.GLMSearchAPIKey,
baseURL: opts.GLMSearchBaseURL,
searchEngine: searchEngine,
proxy: opts.Proxy,
client: client,
}
if opts.GLMSearchMaxResults > 0 {
maxResults = opts.GLMSearchMaxResults
}
} else {
return nil, nil
}
+132
View File
@@ -681,3 +681,135 @@ func TestWebTool_TavilySearch_Success(t *testing.T) {
t.Errorf("Expected 'via Tavily' in output, got: %s", result.ForUser)
}
}
func TestWebTool_GLMSearch_Success(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method != "POST" {
t.Errorf("Expected POST request, got %s", r.Method)
}
if r.Header.Get("Content-Type") != "application/json" {
t.Errorf("Expected Content-Type application/json, got %s", r.Header.Get("Content-Type"))
}
if r.Header.Get("Authorization") != "Bearer test-glm-key" {
t.Errorf("Expected Authorization Bearer test-glm-key, got %s", r.Header.Get("Authorization"))
}
var payload map[string]any
json.NewDecoder(r.Body).Decode(&payload)
if payload["search_query"] != "test query" {
t.Errorf("Expected search_query 'test query', got %v", payload["search_query"])
}
if payload["search_engine"] != "search_std" {
t.Errorf("Expected search_engine 'search_std', got %v", payload["search_engine"])
}
response := map[string]any{
"id": "web-search-test",
"created": 1709568000,
"search_result": []map[string]any{
{
"title": "Test GLM Result",
"content": "GLM search snippet",
"link": "https://example.com/glm",
"media": "Example",
"publish_date": "2026-03-04",
},
},
}
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
json.NewEncoder(w).Encode(response)
}))
defer server.Close()
tool, err := NewWebSearchTool(WebSearchToolOptions{
GLMSearchEnabled: true,
GLMSearchAPIKey: "test-glm-key",
GLMSearchBaseURL: server.URL,
GLMSearchEngine: "search_std",
})
if err != nil {
t.Fatalf("NewWebSearchTool() error: %v", err)
}
result := tool.Execute(context.Background(), map[string]any{
"query": "test query",
})
if result.IsError {
t.Errorf("Expected success, got IsError=true: %s", result.ForLLM)
}
if !strings.Contains(result.ForUser, "Test GLM Result") {
t.Errorf("Expected 'Test GLM Result' in output, got: %s", result.ForUser)
}
if !strings.Contains(result.ForUser, "https://example.com/glm") {
t.Errorf("Expected URL in output, got: %s", result.ForUser)
}
if !strings.Contains(result.ForUser, "via GLM Search") {
t.Errorf("Expected 'via GLM Search' in output, got: %s", result.ForUser)
}
}
func TestWebTool_GLMSearch_APIError(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(http.StatusUnauthorized)
w.Write([]byte(`{"error":"invalid api key"}`))
}))
defer server.Close()
tool, err := NewWebSearchTool(WebSearchToolOptions{
GLMSearchEnabled: true,
GLMSearchAPIKey: "bad-key",
GLMSearchBaseURL: server.URL,
GLMSearchEngine: "search_std",
})
if err != nil {
t.Fatalf("NewWebSearchTool() error: %v", err)
}
result := tool.Execute(context.Background(), map[string]any{
"query": "test query",
})
if !result.IsError {
t.Errorf("Expected IsError=true for 401 response")
}
if !strings.Contains(result.ForLLM, "status 401") {
t.Errorf("Expected status 401 in error, got: %s", result.ForLLM)
}
}
func TestWebTool_GLMSearch_Priority(t *testing.T) {
// GLM Search should only be selected when all other providers are disabled
tool, err := NewWebSearchTool(WebSearchToolOptions{
DuckDuckGoEnabled: true,
DuckDuckGoMaxResults: 5,
GLMSearchEnabled: true,
GLMSearchAPIKey: "test-key",
GLMSearchBaseURL: "https://example.com",
GLMSearchEngine: "search_std",
})
if err != nil {
t.Fatalf("NewWebSearchTool() error: %v", err)
}
// DuckDuckGo should win over GLM Search
if _, ok := tool.provider.(*DuckDuckGoSearchProvider); !ok {
t.Errorf("Expected DuckDuckGoSearchProvider when both enabled, got %T", tool.provider)
}
// With DuckDuckGo disabled, GLM Search should be selected
tool2, err := NewWebSearchTool(WebSearchToolOptions{
DuckDuckGoEnabled: false,
GLMSearchEnabled: true,
GLMSearchAPIKey: "test-key",
GLMSearchBaseURL: "https://example.com",
GLMSearchEngine: "search_std",
})
if err != nil {
t.Fatalf("NewWebSearchTool() error: %v", err)
}
if _, ok := tool2.provider.(*GLMSearchProvider); !ok {
t.Errorf("Expected GLMSearchProvider when only GLM enabled, got %T", tool2.provider)
}
}
+19 -4
View File
@@ -3,6 +3,7 @@ package utils
import (
"io"
"net/http"
"net/url"
"os"
"path/filepath"
"strings"
@@ -52,11 +53,12 @@ type DownloadOptions struct {
Timeout time.Duration
ExtraHeaders map[string]string
LoggerPrefix string
ProxyURL string
}
// DownloadFile downloads a file from URL to a local temp directory.
// Returns the local file path or empty string on error.
func DownloadFile(url, filename string, opts DownloadOptions) string {
func DownloadFile(urlStr, filename string, opts DownloadOptions) string {
// Set defaults
if opts.Timeout == 0 {
opts.Timeout = 60 * time.Second
@@ -78,7 +80,7 @@ func DownloadFile(url, filename string, opts DownloadOptions) string {
localPath := filepath.Join(mediaDir, uuid.New().String()[:8]+"_"+safeName)
// Create HTTP request
req, err := http.NewRequest("GET", url, nil)
req, err := http.NewRequest("GET", urlStr, nil)
if err != nil {
logger.ErrorCF(opts.LoggerPrefix, "Failed to create download request", map[string]any{
"error": err.Error(),
@@ -92,11 +94,24 @@ func DownloadFile(url, filename string, opts DownloadOptions) string {
}
client := &http.Client{Timeout: opts.Timeout}
if opts.ProxyURL != "" {
proxyURL, parseErr := url.Parse(opts.ProxyURL)
if parseErr != nil {
logger.ErrorCF(opts.LoggerPrefix, "Invalid proxy URL for download", map[string]any{
"error": parseErr.Error(),
"proxy": opts.ProxyURL,
})
return ""
}
client.Transport = &http.Transport{
Proxy: http.ProxyURL(proxyURL),
}
}
resp, err := client.Do(req)
if err != nil {
logger.ErrorCF(opts.LoggerPrefix, "Failed to download file", map[string]any{
"error": err.Error(),
"url": url,
"url": urlStr,
})
return ""
}
@@ -105,7 +120,7 @@ func DownloadFile(url, filename string, opts DownloadOptions) string {
if resp.StatusCode != http.StatusOK {
logger.ErrorCF(opts.LoggerPrefix, "File download returned non-200 status", map[string]any{
"status": resp.StatusCode,
"url": url,
"url": urlStr,
})
return ""
}
+25 -4
View File
@@ -10,12 +10,19 @@ import (
"net/http"
"os"
"path/filepath"
"strings"
"time"
"github.com/sipeed/picoclaw/pkg/config"
"github.com/sipeed/picoclaw/pkg/logger"
"github.com/sipeed/picoclaw/pkg/utils"
)
type Transcriber interface {
Name() string
Transcribe(ctx context.Context, audioFilePath string) (*TranscriptionResponse, error)
}
type GroqTranscriber struct {
apiKey string
apiBase string
@@ -152,8 +159,22 @@ func (t *GroqTranscriber) Transcribe(ctx context.Context, audioFilePath string)
return &result, nil
}
func (t *GroqTranscriber) IsAvailable() bool {
available := t.apiKey != ""
logger.DebugCF("voice", "Checking transcriber availability", map[string]any{"available": available})
return available
func (t *GroqTranscriber) Name() string {
return "groq"
}
// DetectTranscriber inspects cfg and returns the appropriate Transcriber, or
// nil if no supported transcription provider is configured.
func DetectTranscriber(cfg *config.Config) Transcriber {
// Direct Groq provider config takes priority.
if key := cfg.Providers.Groq.APIKey; key != "" {
return NewGroqTranscriber(key)
}
// Fall back to any model-list entry that uses the groq/ protocol.
for _, mc := range cfg.ModelList {
if strings.HasPrefix(mc.Model, "groq/") && mc.APIKey != "" {
return NewGroqTranscriber(mc.APIKey)
}
}
return nil
}
+160
View File
@@ -0,0 +1,160 @@
package voice
import (
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"testing"
"github.com/sipeed/picoclaw/pkg/config"
)
// Ensure GroqTranscriber satisfies the Transcriber interface at compile time.
var _ Transcriber = (*GroqTranscriber)(nil)
func TestGroqTranscriberName(t *testing.T) {
tr := NewGroqTranscriber("sk-test")
if got := tr.Name(); got != "groq" {
t.Errorf("Name() = %q, want %q", got, "groq")
}
}
func TestDetectTranscriber(t *testing.T) {
tests := []struct {
name string
cfg *config.Config
wantNil bool
wantName string
}{
{
name: "no config",
cfg: &config.Config{},
wantNil: true,
},
{
name: "groq provider key",
cfg: &config.Config{
Providers: config.ProvidersConfig{
Groq: config.ProviderConfig{APIKey: "sk-groq-direct"},
},
},
wantName: "groq",
},
{
name: "groq via model list",
cfg: &config.Config{
ModelList: []config.ModelConfig{
{Model: "openai/gpt-4o", APIKey: "sk-openai"},
{Model: "groq/llama-3.3-70b", APIKey: "sk-groq-model"},
},
},
wantName: "groq",
},
{
name: "groq model list entry without key is skipped",
cfg: &config.Config{
ModelList: []config.ModelConfig{
{Model: "groq/llama-3.3-70b", APIKey: ""},
},
},
wantNil: true,
},
{
name: "provider key takes priority over model list",
cfg: &config.Config{
Providers: config.ProvidersConfig{
Groq: config.ProviderConfig{APIKey: "sk-groq-direct"},
},
ModelList: []config.ModelConfig{
{Model: "groq/llama-3.3-70b", APIKey: "sk-groq-model"},
},
},
wantName: "groq",
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
tr := DetectTranscriber(tc.cfg)
if tc.wantNil {
if tr != nil {
t.Errorf("DetectTranscriber() = %v, want nil", tr)
}
return
}
if tr == nil {
t.Fatal("DetectTranscriber() = nil, want non-nil")
}
if got := tr.Name(); got != tc.wantName {
t.Errorf("Name() = %q, want %q", got, tc.wantName)
}
})
}
}
func TestTranscribe(t *testing.T) {
// Write a minimal fake audio file so the transcriber can open and send it.
tmpDir := t.TempDir()
audioPath := filepath.Join(tmpDir, "clip.ogg")
if err := os.WriteFile(audioPath, []byte("fake-audio-data"), 0o644); err != nil {
t.Fatalf("failed to write fake audio file: %v", err)
}
t.Run("success", func(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/audio/transcriptions" {
t.Errorf("unexpected path: %s", r.URL.Path)
}
if r.Header.Get("Authorization") != "Bearer sk-test" {
t.Errorf("unexpected Authorization header: %s", r.Header.Get("Authorization"))
}
w.Header().Set("Content-Type", "application/json")
_ = json.NewEncoder(w).Encode(TranscriptionResponse{
Text: "hello world",
Language: "en",
Duration: 1.5,
})
}))
defer srv.Close()
tr := NewGroqTranscriber("sk-test")
tr.apiBase = srv.URL
resp, err := tr.Transcribe(context.Background(), audioPath)
if err != nil {
t.Fatalf("Transcribe() error: %v", err)
}
if resp.Text != "hello world" {
t.Errorf("Text = %q, want %q", resp.Text, "hello world")
}
if resp.Language != "en" {
t.Errorf("Language = %q, want %q", resp.Language, "en")
}
})
t.Run("api error", func(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
http.Error(w, `{"error":"invalid_api_key"}`, http.StatusUnauthorized)
}))
defer srv.Close()
tr := NewGroqTranscriber("sk-bad")
tr.apiBase = srv.URL
_, err := tr.Transcribe(context.Background(), audioPath)
if err == nil {
t.Fatal("expected error for non-200 response, got nil")
}
})
t.Run("missing file", func(t *testing.T) {
tr := NewGroqTranscriber("sk-test")
_, err := tr.Transcribe(context.Background(), filepath.Join(tmpDir, "nonexistent.ogg"))
if err == nil {
t.Fatal("expected error for missing file, got nil")
}
})
}