mirror of
https://github.com/sipeed/picoclaw.git
synced 2026-06-12 18:08:54 +00:00
Merge upstream/main into fix/bugfixes
Resolve conflicts: - provider.go: keep upstream's serializeMessages (supersedes stripSystemParts) - provider_test.go: keep upstream's serializeMessages tests - loop_test.go: add slices import needed by upstream tests - shell.go: merge PR's --format deny fix with upstream's block device pattern, safePaths, and absolutePathPattern - shell_test.go: include tests from both branches Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
+126
-62
@@ -7,6 +7,7 @@ import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"slices"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
@@ -33,6 +34,11 @@ type ContextBuilder struct {
|
||||
// created (didn't exist at cache time, now exist) or deleted (existed at
|
||||
// cache time, now gone) — both of which should trigger a cache rebuild.
|
||||
existedAtCache map[string]bool
|
||||
|
||||
// skillFilesAtCache snapshots the skill tree file set and mtimes at cache
|
||||
// build time. This catches nested file creations/deletions/mtime changes
|
||||
// that may not update the top-level skill root directory mtime.
|
||||
skillFilesAtCache map[string]time.Time
|
||||
}
|
||||
|
||||
func getGlobalConfigDir() string {
|
||||
@@ -46,8 +52,11 @@ func getGlobalConfigDir() string {
|
||||
func NewContextBuilder(workspace string) *ContextBuilder {
|
||||
// builtin skills: skills directory in current project
|
||||
// Use the skills/ directory under the current working directory
|
||||
wd, _ := os.Getwd()
|
||||
builtinSkillsDir := filepath.Join(wd, "skills")
|
||||
builtinSkillsDir := strings.TrimSpace(os.Getenv("PICOCLAW_BUILTIN_SKILLS"))
|
||||
if builtinSkillsDir == "" {
|
||||
wd, _ := os.Getwd()
|
||||
builtinSkillsDir = filepath.Join(wd, "skills")
|
||||
}
|
||||
globalSkillsDir := filepath.Join(getGlobalConfigDir(), "skills")
|
||||
|
||||
return &ContextBuilder{
|
||||
@@ -147,6 +156,7 @@ func (cb *ContextBuilder) BuildSystemPromptWithCache() string {
|
||||
cb.cachedSystemPrompt = prompt
|
||||
cb.cachedAt = baseline.maxMtime
|
||||
cb.existedAtCache = baseline.existed
|
||||
cb.skillFilesAtCache = baseline.skillFiles
|
||||
|
||||
logger.DebugCF("agent", "System prompt cached",
|
||||
map[string]any{
|
||||
@@ -166,14 +176,14 @@ func (cb *ContextBuilder) InvalidateCache() {
|
||||
cb.cachedSystemPrompt = ""
|
||||
cb.cachedAt = time.Time{}
|
||||
cb.existedAtCache = nil
|
||||
cb.skillFilesAtCache = nil
|
||||
|
||||
logger.DebugCF("agent", "System prompt cache invalidated", nil)
|
||||
}
|
||||
|
||||
// sourcePaths returns the workspace source file paths tracked for cache
|
||||
// invalidation (bootstrap files + memory). The skills directory is handled
|
||||
// separately in sourceFilesChangedLocked because it requires both directory-
|
||||
// level and recursive file-level mtime checks.
|
||||
// sourcePaths returns non-skill workspace source files tracked for cache
|
||||
// invalidation (bootstrap files + memory). Skill roots are handled separately
|
||||
// because they require both directory-level and recursive file-level checks.
|
||||
func (cb *ContextBuilder) sourcePaths() []string {
|
||||
return []string{
|
||||
filepath.Join(cb.workspace, "AGENTS.md"),
|
||||
@@ -184,23 +194,39 @@ func (cb *ContextBuilder) sourcePaths() []string {
|
||||
}
|
||||
}
|
||||
|
||||
// skillRoots returns all skill root directories that can affect
|
||||
// BuildSkillsSummary output (workspace/global/builtin).
|
||||
func (cb *ContextBuilder) skillRoots() []string {
|
||||
if cb.skillsLoader == nil {
|
||||
return []string{filepath.Join(cb.workspace, "skills")}
|
||||
}
|
||||
|
||||
roots := cb.skillsLoader.SkillRoots()
|
||||
if len(roots) == 0 {
|
||||
return []string{filepath.Join(cb.workspace, "skills")}
|
||||
}
|
||||
return roots
|
||||
}
|
||||
|
||||
// cacheBaseline holds the file existence snapshot and the latest observed
|
||||
// mtime across all tracked paths. Used as the cache reference point.
|
||||
type cacheBaseline struct {
|
||||
existed map[string]bool
|
||||
maxMtime time.Time
|
||||
existed map[string]bool
|
||||
skillFiles map[string]time.Time
|
||||
maxMtime time.Time
|
||||
}
|
||||
|
||||
// buildCacheBaseline records which tracked paths currently exist and computes
|
||||
// the latest mtime across all tracked files + skills directory contents.
|
||||
// Called under write lock when the cache is built.
|
||||
func (cb *ContextBuilder) buildCacheBaseline() cacheBaseline {
|
||||
skillsDir := filepath.Join(cb.workspace, "skills")
|
||||
skillRoots := cb.skillRoots()
|
||||
|
||||
// All paths whose existence we track: source files + skills dir.
|
||||
allPaths := append(cb.sourcePaths(), skillsDir)
|
||||
// All paths whose existence we track: source files + all skill roots.
|
||||
allPaths := append(cb.sourcePaths(), skillRoots...)
|
||||
|
||||
existed := make(map[string]bool, len(allPaths))
|
||||
skillFiles := make(map[string]time.Time)
|
||||
var maxMtime time.Time
|
||||
|
||||
for _, p := range allPaths {
|
||||
@@ -211,17 +237,21 @@ func (cb *ContextBuilder) buildCacheBaseline() cacheBaseline {
|
||||
}
|
||||
}
|
||||
|
||||
// Walk skills files to capture their mtimes too.
|
||||
// Use os.Stat (not d.Info) to match the stat method used in
|
||||
// fileChangedSince / skillFilesModifiedSince for consistency.
|
||||
_ = filepath.WalkDir(skillsDir, func(path string, d fs.DirEntry, walkErr error) error {
|
||||
if walkErr == nil && !d.IsDir() {
|
||||
if info, err := os.Stat(path); err == nil && info.ModTime().After(maxMtime) {
|
||||
maxMtime = info.ModTime()
|
||||
// Walk all skill roots recursively to snapshot skill files and mtimes.
|
||||
// Use os.Stat (not d.Info) for consistency with sourceFilesChanged checks.
|
||||
for _, root := range skillRoots {
|
||||
_ = filepath.WalkDir(root, func(path string, d fs.DirEntry, walkErr error) error {
|
||||
if walkErr == nil && !d.IsDir() {
|
||||
if info, err := os.Stat(path); err == nil {
|
||||
skillFiles[path] = info.ModTime()
|
||||
if info.ModTime().After(maxMtime) {
|
||||
maxMtime = info.ModTime()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
// If no tracked files exist yet (empty workspace), maxMtime is zero.
|
||||
// Use a very old non-zero time so that:
|
||||
@@ -233,7 +263,7 @@ func (cb *ContextBuilder) buildCacheBaseline() cacheBaseline {
|
||||
maxMtime = time.Unix(1, 0)
|
||||
}
|
||||
|
||||
return cacheBaseline{existed: existed, maxMtime: maxMtime}
|
||||
return cacheBaseline{existed: existed, skillFiles: skillFiles, maxMtime: maxMtime}
|
||||
}
|
||||
|
||||
// sourceFilesChangedLocked checks whether any workspace source file has been
|
||||
@@ -249,27 +279,21 @@ func (cb *ContextBuilder) sourceFilesChangedLocked() bool {
|
||||
}
|
||||
|
||||
// Check tracked source files (bootstrap + memory).
|
||||
for _, p := range cb.sourcePaths() {
|
||||
if cb.fileChangedSince(p) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
// --- Skills directory (handled separately from sourcePaths) ---
|
||||
//
|
||||
// 1. Creation/deletion: tracked via existedAtCache, same as bootstrap files.
|
||||
skillsDir := filepath.Join(cb.workspace, "skills")
|
||||
if cb.fileChangedSince(skillsDir) {
|
||||
if slices.ContainsFunc(cb.sourcePaths(), cb.fileChangedSince) {
|
||||
return true
|
||||
}
|
||||
|
||||
// 2. Structural changes (add/remove entries inside the dir) are reflected
|
||||
// in the directory's own mtime, which fileChangedSince already checks.
|
||||
// --- Skill roots (workspace/global/builtin) ---
|
||||
//
|
||||
// 3. Content-only edits to files inside skills/ do NOT update the parent
|
||||
// directory mtime on most filesystems, so we recursively walk to check
|
||||
// individual file mtimes at any nesting depth.
|
||||
if skillFilesModifiedSince(skillsDir, cb.cachedAt) {
|
||||
// For each root:
|
||||
// 1. Creation/deletion and root directory mtime changes are tracked by fileChangedSince.
|
||||
// 2. Nested file create/delete/mtime changes are tracked by the skill file snapshot.
|
||||
for _, root := range cb.skillRoots() {
|
||||
if cb.fileChangedSince(root) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
if skillFilesChangedSince(cb.skillRoots(), cb.skillFilesAtCache) {
|
||||
return true
|
||||
}
|
||||
|
||||
@@ -310,28 +334,64 @@ func (cb *ContextBuilder) fileChangedSince(path string) bool {
|
||||
// if the callback returned nil when its err parameter is non-nil.
|
||||
var errWalkStop = errors.New("walk stop")
|
||||
|
||||
// skillFilesModifiedSince recursively walks the skills directory and checks
|
||||
// whether any file was modified after t. This catches content-only edits at
|
||||
// any nesting depth (e.g. skills/name/docs/extra.md) that don't update
|
||||
// parent directory mtimes.
|
||||
func skillFilesModifiedSince(skillsDir string, t time.Time) bool {
|
||||
changed := false
|
||||
err := filepath.WalkDir(skillsDir, func(path string, d fs.DirEntry, walkErr error) error {
|
||||
if walkErr == nil && !d.IsDir() {
|
||||
if info, statErr := os.Stat(path); statErr == nil && info.ModTime().After(t) {
|
||||
changed = true
|
||||
return errWalkStop // stop walking
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
// errWalkStop is expected (early exit on first changed file).
|
||||
// os.IsNotExist means the skills dir doesn't exist yet — not an error.
|
||||
// Any other error is unexpected and worth logging.
|
||||
if err != nil && !errors.Is(err, errWalkStop) && !os.IsNotExist(err) {
|
||||
logger.DebugCF("agent", "skills walk error", map[string]any{"error": err.Error()})
|
||||
// skillFilesChangedSince compares the current recursive skill file tree
|
||||
// against the cache-time snapshot. Any create/delete/mtime drift invalidates
|
||||
// the cache.
|
||||
func skillFilesChangedSince(skillRoots []string, filesAtCache map[string]time.Time) bool {
|
||||
// Defensive: if the snapshot was never initialized, force rebuild.
|
||||
if filesAtCache == nil {
|
||||
return true
|
||||
}
|
||||
return changed
|
||||
|
||||
// Check cached files still exist and keep the same mtime.
|
||||
for path, cachedMtime := range filesAtCache {
|
||||
info, err := os.Stat(path)
|
||||
if err != nil {
|
||||
// A previously tracked file disappeared (or became inaccessible):
|
||||
// either way, cached skill summary may now be stale.
|
||||
return true
|
||||
}
|
||||
if !info.ModTime().Equal(cachedMtime) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
// Check no new files appeared under any skill root.
|
||||
changed := false
|
||||
for _, root := range skillRoots {
|
||||
if strings.TrimSpace(root) == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
err := filepath.WalkDir(root, func(path string, d fs.DirEntry, walkErr error) error {
|
||||
if walkErr != nil {
|
||||
// Treat unexpected walk errors as changed to avoid stale cache.
|
||||
if !os.IsNotExist(walkErr) {
|
||||
changed = true
|
||||
return errWalkStop
|
||||
}
|
||||
return nil
|
||||
}
|
||||
if d.IsDir() {
|
||||
return nil
|
||||
}
|
||||
if _, ok := filesAtCache[path]; !ok {
|
||||
changed = true
|
||||
return errWalkStop
|
||||
}
|
||||
return nil
|
||||
})
|
||||
|
||||
if changed {
|
||||
return true
|
||||
}
|
||||
if err != nil && !errors.Is(err, errWalkStop) && !os.IsNotExist(err) {
|
||||
logger.DebugCF("agent", "skills walk error", map[string]any{"error": err.Error()})
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func (cb *ContextBuilder) LoadBootstrapFiles() string {
|
||||
@@ -467,10 +527,14 @@ func (cb *ContextBuilder) BuildMessages(
|
||||
|
||||
// Add current user message
|
||||
if strings.TrimSpace(currentMessage) != "" {
|
||||
messages = append(messages, providers.Message{
|
||||
msg := providers.Message{
|
||||
Role: "user",
|
||||
Content: currentMessage,
|
||||
})
|
||||
}
|
||||
if len(media) > 0 {
|
||||
msg.Media = media
|
||||
}
|
||||
messages = append(messages, msg)
|
||||
}
|
||||
|
||||
return messages
|
||||
|
||||
@@ -383,6 +383,162 @@ Updated content.`
|
||||
}
|
||||
}
|
||||
|
||||
// TestGlobalSkillFileContentChange verifies that modifying a global skill
|
||||
// (~/.picoclaw/skills) invalidates the cached system prompt.
|
||||
func TestGlobalSkillFileContentChange(t *testing.T) {
|
||||
tmpHome := t.TempDir()
|
||||
t.Setenv("HOME", tmpHome)
|
||||
|
||||
tmpDir := setupWorkspace(t, nil)
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
globalSkillPath := filepath.Join(tmpHome, ".picoclaw", "skills", "global-skill", "SKILL.md")
|
||||
if err := os.MkdirAll(filepath.Dir(globalSkillPath), 0o755); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
v1 := `---
|
||||
name: global-skill
|
||||
description: global-v1
|
||||
---
|
||||
# Global Skill v1`
|
||||
if err := os.WriteFile(globalSkillPath, []byte(v1), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
cb := NewContextBuilder(tmpDir)
|
||||
sp1 := cb.BuildSystemPromptWithCache()
|
||||
if !strings.Contains(sp1, "global-v1") {
|
||||
t.Fatal("expected initial prompt to contain global skill description")
|
||||
}
|
||||
|
||||
v2 := `---
|
||||
name: global-skill
|
||||
description: global-v2
|
||||
---
|
||||
# Global Skill v2`
|
||||
if err := os.WriteFile(globalSkillPath, []byte(v2), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
future := time.Now().Add(2 * time.Second)
|
||||
if err := os.Chtimes(globalSkillPath, future, future); err != nil {
|
||||
t.Fatalf("failed to update mtime for %s: %v", globalSkillPath, err)
|
||||
}
|
||||
|
||||
cb.systemPromptMutex.RLock()
|
||||
changed := cb.sourceFilesChangedLocked()
|
||||
cb.systemPromptMutex.RUnlock()
|
||||
if !changed {
|
||||
t.Fatal("sourceFilesChangedLocked() should detect global skill file content change")
|
||||
}
|
||||
|
||||
sp2 := cb.BuildSystemPromptWithCache()
|
||||
if !strings.Contains(sp2, "global-v2") {
|
||||
t.Error("rebuilt prompt should contain updated global skill description")
|
||||
}
|
||||
if sp1 == sp2 {
|
||||
t.Error("cache should be invalidated when global skill file content changes")
|
||||
}
|
||||
}
|
||||
|
||||
// TestBuiltinSkillFileContentChange verifies that modifying a builtin skill
|
||||
// invalidates the cached system prompt.
|
||||
func TestBuiltinSkillFileContentChange(t *testing.T) {
|
||||
tmpHome := t.TempDir()
|
||||
t.Setenv("HOME", tmpHome)
|
||||
|
||||
tmpDir := setupWorkspace(t, nil)
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
builtinRoot := t.TempDir()
|
||||
t.Setenv("PICOCLAW_BUILTIN_SKILLS", builtinRoot)
|
||||
|
||||
builtinSkillPath := filepath.Join(builtinRoot, "builtin-skill", "SKILL.md")
|
||||
if err := os.MkdirAll(filepath.Dir(builtinSkillPath), 0o755); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
v1 := `---
|
||||
name: builtin-skill
|
||||
description: builtin-v1
|
||||
---
|
||||
# Builtin Skill v1`
|
||||
if err := os.WriteFile(builtinSkillPath, []byte(v1), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
cb := NewContextBuilder(tmpDir)
|
||||
sp1 := cb.BuildSystemPromptWithCache()
|
||||
if !strings.Contains(sp1, "builtin-v1") {
|
||||
t.Fatal("expected initial prompt to contain builtin skill description")
|
||||
}
|
||||
|
||||
v2 := `---
|
||||
name: builtin-skill
|
||||
description: builtin-v2
|
||||
---
|
||||
# Builtin Skill v2`
|
||||
if err := os.WriteFile(builtinSkillPath, []byte(v2), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
future := time.Now().Add(2 * time.Second)
|
||||
if err := os.Chtimes(builtinSkillPath, future, future); err != nil {
|
||||
t.Fatalf("failed to update mtime for %s: %v", builtinSkillPath, err)
|
||||
}
|
||||
|
||||
cb.systemPromptMutex.RLock()
|
||||
changed := cb.sourceFilesChangedLocked()
|
||||
cb.systemPromptMutex.RUnlock()
|
||||
if !changed {
|
||||
t.Fatal("sourceFilesChangedLocked() should detect builtin skill file content change")
|
||||
}
|
||||
|
||||
sp2 := cb.BuildSystemPromptWithCache()
|
||||
if !strings.Contains(sp2, "builtin-v2") {
|
||||
t.Error("rebuilt prompt should contain updated builtin skill description")
|
||||
}
|
||||
if sp1 == sp2 {
|
||||
t.Error("cache should be invalidated when builtin skill file content changes")
|
||||
}
|
||||
}
|
||||
|
||||
// TestSkillFileDeletionInvalidatesCache verifies that deleting a nested skill
|
||||
// file invalidates the cached system prompt.
|
||||
func TestSkillFileDeletionInvalidatesCache(t *testing.T) {
|
||||
tmpDir := setupWorkspace(t, map[string]string{
|
||||
"skills/delete-me/SKILL.md": `---
|
||||
name: delete-me
|
||||
description: delete-me-v1
|
||||
---
|
||||
# Delete Me`,
|
||||
})
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
cb := NewContextBuilder(tmpDir)
|
||||
sp1 := cb.BuildSystemPromptWithCache()
|
||||
if !strings.Contains(sp1, "delete-me-v1") {
|
||||
t.Fatal("expected initial prompt to contain skill description")
|
||||
}
|
||||
|
||||
skillPath := filepath.Join(tmpDir, "skills", "delete-me", "SKILL.md")
|
||||
if err := os.Remove(skillPath); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
cb.systemPromptMutex.RLock()
|
||||
changed := cb.sourceFilesChangedLocked()
|
||||
cb.systemPromptMutex.RUnlock()
|
||||
if !changed {
|
||||
t.Fatal("sourceFilesChangedLocked() should detect deleted skill file")
|
||||
}
|
||||
|
||||
sp2 := cb.BuildSystemPromptWithCache()
|
||||
if strings.Contains(sp2, "delete-me-v1") {
|
||||
t.Error("rebuilt prompt should not contain deleted skill description")
|
||||
}
|
||||
if sp1 == sp2 {
|
||||
t.Error("cache should be invalidated when skill file is deleted")
|
||||
}
|
||||
}
|
||||
|
||||
// TestConcurrentBuildSystemPromptWithCache verifies that multiple goroutines
|
||||
// can safely call BuildSystemPromptWithCache concurrently without producing
|
||||
// empty results, panics, or data races.
|
||||
@@ -404,11 +560,11 @@ func TestConcurrentBuildSystemPromptWithCache(t *testing.T) {
|
||||
var wg sync.WaitGroup
|
||||
errs := make(chan string, goroutines*iterations)
|
||||
|
||||
for g := 0; g < goroutines; g++ {
|
||||
for g := range goroutines {
|
||||
wg.Add(1)
|
||||
go func(id int) {
|
||||
defer wg.Done()
|
||||
for i := 0; i < iterations; i++ {
|
||||
for i := range iterations {
|
||||
result := cb.BuildSystemPromptWithCache()
|
||||
if result == "" {
|
||||
errs <- "empty prompt returned"
|
||||
|
||||
+26
-5
@@ -1,9 +1,11 @@
|
||||
package agent
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
@@ -48,18 +50,24 @@ func NewAgentInstance(
|
||||
fallbacks := resolveAgentFallbacks(agentCfg, defaults)
|
||||
|
||||
restrict := defaults.RestrictToWorkspace
|
||||
readRestrict := restrict && !defaults.AllowReadOutsideWorkspace
|
||||
|
||||
// Compile path whitelist patterns from config.
|
||||
allowReadPaths := compilePatterns(cfg.Tools.AllowReadPaths)
|
||||
allowWritePaths := compilePatterns(cfg.Tools.AllowWritePaths)
|
||||
|
||||
toolsRegistry := tools.NewToolRegistry()
|
||||
toolsRegistry.Register(tools.NewReadFileTool(workspace, restrict))
|
||||
toolsRegistry.Register(tools.NewWriteFileTool(workspace, restrict))
|
||||
toolsRegistry.Register(tools.NewListDirTool(workspace, restrict))
|
||||
toolsRegistry.Register(tools.NewReadFileTool(workspace, readRestrict, allowReadPaths))
|
||||
toolsRegistry.Register(tools.NewWriteFileTool(workspace, restrict, allowWritePaths))
|
||||
toolsRegistry.Register(tools.NewListDirTool(workspace, readRestrict, allowReadPaths))
|
||||
execTool, err := tools.NewExecToolWithConfig(workspace, restrict, cfg)
|
||||
if err != nil {
|
||||
log.Fatalf("Critical error: unable to initialize exec tool: %v", err)
|
||||
}
|
||||
toolsRegistry.Register(execTool)
|
||||
|
||||
toolsRegistry.Register(tools.NewEditFileTool(workspace, restrict))
|
||||
toolsRegistry.Register(tools.NewAppendFileTool(workspace, restrict))
|
||||
toolsRegistry.Register(tools.NewEditFileTool(workspace, restrict, allowWritePaths))
|
||||
toolsRegistry.Register(tools.NewAppendFileTool(workspace, restrict, allowWritePaths))
|
||||
|
||||
sessionsDir := filepath.Join(workspace, "sessions")
|
||||
sessionsManager := session.NewSessionManager(sessionsDir)
|
||||
@@ -189,6 +197,19 @@ func resolveAgentFallbacks(agentCfg *config.AgentConfig, defaults *config.AgentD
|
||||
return defaults.ModelFallbacks
|
||||
}
|
||||
|
||||
func compilePatterns(patterns []string) []*regexp.Regexp {
|
||||
compiled := make([]*regexp.Regexp, 0, len(patterns))
|
||||
for _, p := range patterns {
|
||||
re, err := regexp.Compile(p)
|
||||
if err != nil {
|
||||
fmt.Printf("Warning: invalid path pattern %q: %v\n", p, err)
|
||||
continue
|
||||
}
|
||||
compiled = append(compiled, re)
|
||||
}
|
||||
return compiled
|
||||
}
|
||||
|
||||
func expandHome(path string) string {
|
||||
if path == "" {
|
||||
return path
|
||||
|
||||
+58
-65
@@ -95,75 +95,68 @@ func TestNewAgentInstance_DefaultsTemperatureWhenUnset(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestNewAgentInstance_ResolveCandidatesFromModelListAlias(t *testing.T) {
|
||||
tmpDir, err := os.MkdirTemp("", "agent-instance-test-*")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temp dir: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
cfg := &config.Config{
|
||||
Agents: config.AgentsConfig{
|
||||
Defaults: config.AgentDefaults{
|
||||
Workspace: tmpDir,
|
||||
Model: "step-3.5-flash",
|
||||
},
|
||||
tests := []struct {
|
||||
name string
|
||||
aliasName string
|
||||
modelName string
|
||||
apiBase string
|
||||
wantProvider string
|
||||
wantModel string
|
||||
}{
|
||||
{
|
||||
name: "alias with provider prefix",
|
||||
aliasName: "step-3.5-flash",
|
||||
modelName: "openrouter/stepfun/step-3.5-flash:free",
|
||||
apiBase: "https://openrouter.ai/api/v1",
|
||||
wantProvider: "openrouter",
|
||||
wantModel: "stepfun/step-3.5-flash:free",
|
||||
},
|
||||
ModelList: []config.ModelConfig{
|
||||
{
|
||||
ModelName: "step-3.5-flash",
|
||||
Model: "openrouter/stepfun/step-3.5-flash:free",
|
||||
APIBase: "https://openrouter.ai/api/v1",
|
||||
},
|
||||
{
|
||||
name: "alias without provider prefix",
|
||||
aliasName: "glm-5",
|
||||
modelName: "glm-5",
|
||||
apiBase: "https://api.z.ai/api/coding/paas/v4",
|
||||
wantProvider: "openai",
|
||||
wantModel: "glm-5",
|
||||
},
|
||||
}
|
||||
|
||||
provider := &mockProvider{}
|
||||
agent := NewAgentInstance(nil, &cfg.Agents.Defaults, cfg, provider)
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
tmpDir, err := os.MkdirTemp("", "agent-instance-test-*")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temp dir: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
if len(agent.Candidates) != 1 {
|
||||
t.Fatalf("len(Candidates) = %d, want 1", len(agent.Candidates))
|
||||
}
|
||||
if agent.Candidates[0].Provider != "openrouter" {
|
||||
t.Fatalf("candidate provider = %q, want %q", agent.Candidates[0].Provider, "openrouter")
|
||||
}
|
||||
if agent.Candidates[0].Model != "stepfun/step-3.5-flash:free" {
|
||||
t.Fatalf("candidate model = %q, want %q", agent.Candidates[0].Model, "stepfun/step-3.5-flash:free")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewAgentInstance_ResolveCandidatesFromModelListAliasWithoutProtocol(t *testing.T) {
|
||||
tmpDir, err := os.MkdirTemp("", "agent-instance-test-*")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temp dir: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
cfg := &config.Config{
|
||||
Agents: config.AgentsConfig{
|
||||
Defaults: config.AgentDefaults{
|
||||
Workspace: tmpDir,
|
||||
Model: "glm-5",
|
||||
},
|
||||
},
|
||||
ModelList: []config.ModelConfig{
|
||||
{
|
||||
ModelName: "glm-5",
|
||||
Model: "glm-5",
|
||||
APIBase: "https://api.z.ai/api/coding/paas/v4",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
provider := &mockProvider{}
|
||||
agent := NewAgentInstance(nil, &cfg.Agents.Defaults, cfg, provider)
|
||||
|
||||
if len(agent.Candidates) != 1 {
|
||||
t.Fatalf("len(Candidates) = %d, want 1", len(agent.Candidates))
|
||||
}
|
||||
if agent.Candidates[0].Provider != "openai" {
|
||||
t.Fatalf("candidate provider = %q, want %q", agent.Candidates[0].Provider, "openai")
|
||||
}
|
||||
if agent.Candidates[0].Model != "glm-5" {
|
||||
t.Fatalf("candidate model = %q, want %q", agent.Candidates[0].Model, "glm-5")
|
||||
cfg := &config.Config{
|
||||
Agents: config.AgentsConfig{
|
||||
Defaults: config.AgentDefaults{
|
||||
Workspace: tmpDir,
|
||||
Model: tt.aliasName,
|
||||
},
|
||||
},
|
||||
ModelList: []config.ModelConfig{
|
||||
{
|
||||
ModelName: tt.aliasName,
|
||||
Model: tt.modelName,
|
||||
APIBase: tt.apiBase,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
provider := &mockProvider{}
|
||||
agent := NewAgentInstance(nil, &cfg.Agents.Defaults, cfg, provider)
|
||||
|
||||
if len(agent.Candidates) != 1 {
|
||||
t.Fatalf("len(Candidates) = %d, want 1", len(agent.Candidates))
|
||||
}
|
||||
if agent.Candidates[0].Provider != tt.wantProvider {
|
||||
t.Fatalf("candidate provider = %q, want %q", agent.Candidates[0].Provider, tt.wantProvider)
|
||||
}
|
||||
if agent.Candidates[0].Model != tt.wantModel {
|
||||
t.Fatalf("candidate model = %q, want %q", agent.Candidates[0].Model, tt.wantModel)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
+180
-39
@@ -23,6 +23,7 @@ import (
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
"github.com/sipeed/picoclaw/pkg/constants"
|
||||
"github.com/sipeed/picoclaw/pkg/logger"
|
||||
"github.com/sipeed/picoclaw/pkg/mcp"
|
||||
"github.com/sipeed/picoclaw/pkg/media"
|
||||
"github.com/sipeed/picoclaw/pkg/providers"
|
||||
"github.com/sipeed/picoclaw/pkg/routing"
|
||||
@@ -46,19 +47,24 @@ type AgentLoop struct {
|
||||
|
||||
// processOptions configures how a message is processed
|
||||
type processOptions struct {
|
||||
SessionKey string // Session identifier for history/context
|
||||
Channel string // Target channel for tool execution
|
||||
ChatID string // Target chat ID for tool execution
|
||||
UserMessage string // User message content (may include prefix)
|
||||
DefaultResponse string // Response when LLM returns empty
|
||||
EnableSummary bool // Whether to trigger summarization
|
||||
SendResponse bool // Whether to send response via bus
|
||||
NoHistory bool // If true, don't load session history (for heartbeat)
|
||||
SessionKey string // Session identifier for history/context
|
||||
Channel string // Target channel for tool execution
|
||||
ChatID string // Target chat ID for tool execution
|
||||
UserMessage string // User message content (may include prefix)
|
||||
Media []string // media:// refs from inbound message
|
||||
DefaultResponse string // Response when LLM returns empty
|
||||
EnableSummary bool // Whether to trigger summarization
|
||||
SendResponse bool // Whether to send response via bus
|
||||
NoHistory bool // If true, don't load session history (for heartbeat)
|
||||
}
|
||||
|
||||
const defaultResponse = "I've completed processing but have no response to give. Increase `max_tool_iterations` in config.json."
|
||||
|
||||
func NewAgentLoop(cfg *config.Config, msgBus *bus.MessageBus, provider providers.LLMProvider) *AgentLoop {
|
||||
func NewAgentLoop(
|
||||
cfg *config.Config,
|
||||
msgBus *bus.MessageBus,
|
||||
provider providers.LLMProvider,
|
||||
) *AgentLoop {
|
||||
registry := NewAgentRegistry(cfg, provider)
|
||||
|
||||
// Register shared tools to all agents
|
||||
@@ -99,7 +105,7 @@ func registerSharedTools(
|
||||
}
|
||||
|
||||
// Web tools
|
||||
if searchTool := tools.NewWebSearchTool(tools.WebSearchToolOptions{
|
||||
searchTool, err := tools.NewWebSearchTool(tools.WebSearchToolOptions{
|
||||
BraveAPIKey: cfg.Tools.Web.Brave.APIKey,
|
||||
BraveMaxResults: cfg.Tools.Web.Brave.MaxResults,
|
||||
BraveEnabled: cfg.Tools.Web.Brave.Enabled,
|
||||
@@ -113,10 +119,18 @@ func registerSharedTools(
|
||||
PerplexityMaxResults: cfg.Tools.Web.Perplexity.MaxResults,
|
||||
PerplexityEnabled: cfg.Tools.Web.Perplexity.Enabled,
|
||||
Proxy: cfg.Tools.Web.Proxy,
|
||||
}); searchTool != nil {
|
||||
})
|
||||
if err != nil {
|
||||
logger.ErrorCF("agent", "Failed to create web search tool", map[string]any{"error": err.Error()})
|
||||
} else if searchTool != nil {
|
||||
agent.Tools.Register(searchTool)
|
||||
}
|
||||
agent.Tools.Register(tools.NewWebFetchToolWithProxy(50000, cfg.Tools.Web.Proxy))
|
||||
fetchTool, err := tools.NewWebFetchToolWithProxy(50000, cfg.Tools.Web.Proxy, cfg.Tools.Web.FetchLimitBytes)
|
||||
if err != nil {
|
||||
logger.ErrorCF("agent", "Failed to create web fetch tool", map[string]any{"error": err.Error()})
|
||||
} else {
|
||||
agent.Tools.Register(fetchTool)
|
||||
}
|
||||
|
||||
// Hardware tools (I2C, SPI) - Linux only, returns error on other platforms
|
||||
agent.Tools.Register(tools.NewI2CTool())
|
||||
@@ -162,6 +176,72 @@ func registerSharedTools(
|
||||
func (al *AgentLoop) Run(ctx context.Context) error {
|
||||
al.running.Store(true)
|
||||
|
||||
// Initialize MCP servers for all agents
|
||||
if al.cfg.Tools.MCP.Enabled {
|
||||
mcpManager := mcp.NewManager()
|
||||
// Ensure MCP connections are cleaned up on exit, regardless of initialization success
|
||||
// This fixes resource leak when LoadFromMCPConfig partially succeeds then fails
|
||||
defer func() {
|
||||
if err := mcpManager.Close(); err != nil {
|
||||
logger.ErrorCF("agent", "Failed to close MCP manager",
|
||||
map[string]any{
|
||||
"error": err.Error(),
|
||||
})
|
||||
}
|
||||
}()
|
||||
|
||||
defaultAgent := al.registry.GetDefaultAgent()
|
||||
var workspacePath string
|
||||
if defaultAgent != nil && defaultAgent.Workspace != "" {
|
||||
workspacePath = defaultAgent.Workspace
|
||||
} else {
|
||||
workspacePath = al.cfg.WorkspacePath()
|
||||
}
|
||||
|
||||
if err := mcpManager.LoadFromMCPConfig(ctx, al.cfg.Tools.MCP, workspacePath); err != nil {
|
||||
logger.WarnCF("agent", "Failed to load MCP servers, MCP tools will not be available",
|
||||
map[string]any{
|
||||
"error": err.Error(),
|
||||
})
|
||||
} else {
|
||||
// Register MCP tools for all agents
|
||||
servers := mcpManager.GetServers()
|
||||
uniqueTools := 0
|
||||
totalRegistrations := 0
|
||||
agentIDs := al.registry.ListAgentIDs()
|
||||
agentCount := len(agentIDs)
|
||||
|
||||
for serverName, conn := range servers {
|
||||
uniqueTools += len(conn.Tools)
|
||||
for _, tool := range conn.Tools {
|
||||
for _, agentID := range agentIDs {
|
||||
agent, ok := al.registry.GetAgent(agentID)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
mcpTool := tools.NewMCPTool(mcpManager, serverName, tool)
|
||||
agent.Tools.Register(mcpTool)
|
||||
totalRegistrations++
|
||||
logger.DebugCF("agent", "Registered MCP tool",
|
||||
map[string]any{
|
||||
"agent_id": agentID,
|
||||
"server": serverName,
|
||||
"tool": tool.Name,
|
||||
"name": mcpTool.Name(),
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
logger.InfoCF("agent", "MCP tools registered successfully",
|
||||
map[string]any{
|
||||
"server_count": len(servers),
|
||||
"unique_tools": uniqueTools,
|
||||
"total_registrations": totalRegistrations,
|
||||
"agent_count": agentCount,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
for al.running.Load() {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
@@ -302,7 +382,10 @@ func (al *AgentLoop) RecordLastChatID(chatID string) error {
|
||||
return al.state.SetLastChatID(chatID)
|
||||
}
|
||||
|
||||
func (al *AgentLoop) ProcessDirect(ctx context.Context, content, sessionKey string) (string, error) {
|
||||
func (al *AgentLoop) ProcessDirect(
|
||||
ctx context.Context,
|
||||
content, sessionKey string,
|
||||
) (string, error) {
|
||||
return al.ProcessDirectWithChannel(ctx, content, sessionKey, "cli", "direct")
|
||||
}
|
||||
|
||||
@@ -323,7 +406,10 @@ func (al *AgentLoop) ProcessDirectWithChannel(
|
||||
|
||||
// ProcessHeartbeat processes a heartbeat request without session history.
|
||||
// Each heartbeat is independent and doesn't accumulate context.
|
||||
func (al *AgentLoop) ProcessHeartbeat(ctx context.Context, content, channel, chatID string) (string, error) {
|
||||
func (al *AgentLoop) ProcessHeartbeat(
|
||||
ctx context.Context,
|
||||
content, channel, chatID string,
|
||||
) (string, error) {
|
||||
agent := al.registry.GetDefaultAgent()
|
||||
if agent == nil {
|
||||
return "", fmt.Errorf("no default agent for heartbeat")
|
||||
@@ -348,13 +434,16 @@ func (al *AgentLoop) processMessage(ctx context.Context, msg bus.InboundMessage)
|
||||
} else {
|
||||
logContent = utils.Truncate(msg.Content, 80)
|
||||
}
|
||||
logger.InfoCF("agent", fmt.Sprintf("Processing message from %s:%s: %s", msg.Channel, msg.SenderID, logContent),
|
||||
logger.InfoCF(
|
||||
"agent",
|
||||
fmt.Sprintf("Processing message from %s:%s: %s", msg.Channel, msg.SenderID, logContent),
|
||||
map[string]any{
|
||||
"channel": msg.Channel,
|
||||
"chat_id": msg.ChatID,
|
||||
"sender_id": msg.SenderID,
|
||||
"session_key": msg.SessionKey,
|
||||
})
|
||||
},
|
||||
)
|
||||
|
||||
// Route system messages to processSystemMessage
|
||||
if msg.Channel == "system" {
|
||||
@@ -409,15 +498,22 @@ func (al *AgentLoop) processMessage(ctx context.Context, msg bus.InboundMessage)
|
||||
Channel: msg.Channel,
|
||||
ChatID: msg.ChatID,
|
||||
UserMessage: msg.Content,
|
||||
Media: msg.Media,
|
||||
DefaultResponse: defaultResponse,
|
||||
EnableSummary: true,
|
||||
SendResponse: false,
|
||||
})
|
||||
}
|
||||
|
||||
func (al *AgentLoop) processSystemMessage(ctx context.Context, msg bus.InboundMessage) (string, error) {
|
||||
func (al *AgentLoop) processSystemMessage(
|
||||
ctx context.Context,
|
||||
msg bus.InboundMessage,
|
||||
) (string, error) {
|
||||
if msg.Channel != "system" {
|
||||
return "", fmt.Errorf("processSystemMessage called with non-system message channel: %s", msg.Channel)
|
||||
return "", fmt.Errorf(
|
||||
"processSystemMessage called with non-system message channel: %s",
|
||||
msg.Channel,
|
||||
)
|
||||
}
|
||||
|
||||
logger.InfoCF("agent", "Processing system message",
|
||||
@@ -475,14 +571,22 @@ func (al *AgentLoop) processSystemMessage(ctx context.Context, msg bus.InboundMe
|
||||
}
|
||||
|
||||
// runAgentLoop is the core message processing logic.
|
||||
func (al *AgentLoop) runAgentLoop(ctx context.Context, agent *AgentInstance, opts processOptions) (string, error) {
|
||||
func (al *AgentLoop) runAgentLoop(
|
||||
ctx context.Context,
|
||||
agent *AgentInstance,
|
||||
opts processOptions,
|
||||
) (string, error) {
|
||||
// 0. Record last channel for heartbeat notifications (skip internal channels)
|
||||
if opts.Channel != "" && opts.ChatID != "" {
|
||||
// Don't record internal channels (cli, system, subagent)
|
||||
if !constants.IsInternalChannel(opts.Channel) {
|
||||
channelKey := fmt.Sprintf("%s:%s", opts.Channel, opts.ChatID)
|
||||
if err := al.RecordLastChannel(channelKey); err != nil {
|
||||
logger.WarnCF("agent", "Failed to record last channel", map[string]any{"error": err.Error()})
|
||||
logger.WarnCF(
|
||||
"agent",
|
||||
"Failed to record last channel",
|
||||
map[string]any{"error": err.Error()},
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -501,11 +605,15 @@ func (al *AgentLoop) runAgentLoop(ctx context.Context, agent *AgentInstance, opt
|
||||
history,
|
||||
summary,
|
||||
opts.UserMessage,
|
||||
nil,
|
||||
opts.Media,
|
||||
opts.Channel,
|
||||
opts.ChatID,
|
||||
)
|
||||
|
||||
// Resolve media:// refs to base64 data URLs (streaming)
|
||||
maxMediaSize := al.cfg.Agents.Defaults.GetMaxMediaSize()
|
||||
messages = resolveMediaRefs(messages, al.mediaStore, maxMediaSize)
|
||||
|
||||
// 3. Save user message to session
|
||||
agent.Sessions.AddMessage(opts.SessionKey, "user", opts.UserMessage)
|
||||
|
||||
@@ -564,7 +672,10 @@ func (al *AgentLoop) targetReasoningChannelID(channelName string) (chatID string
|
||||
return ""
|
||||
}
|
||||
|
||||
func (al *AgentLoop) handleReasoning(ctx context.Context, reasoningContent, channelName, channelID string) {
|
||||
func (al *AgentLoop) handleReasoning(
|
||||
ctx context.Context,
|
||||
reasoningContent, channelName, channelID string,
|
||||
) {
|
||||
if reasoningContent == "" || channelName == "" || channelID == "" {
|
||||
return
|
||||
}
|
||||
@@ -657,22 +768,33 @@ func (al *AgentLoop) runLLMIteration(
|
||||
|
||||
callLLM := func() (*providers.LLMResponse, error) {
|
||||
if len(agent.Candidates) > 1 && al.fallback != nil {
|
||||
fbResult, fbErr := al.fallback.Execute(ctx, agent.Candidates,
|
||||
fbResult, fbErr := al.fallback.Execute(
|
||||
ctx,
|
||||
agent.Candidates,
|
||||
func(ctx context.Context, provider, model string) (*providers.LLMResponse, error) {
|
||||
return agent.Provider.Chat(ctx, messages, providerToolDefs, model, map[string]any{
|
||||
"max_tokens": agent.MaxTokens,
|
||||
"temperature": agent.Temperature,
|
||||
"prompt_cache_key": agent.ID,
|
||||
})
|
||||
return agent.Provider.Chat(
|
||||
ctx,
|
||||
messages,
|
||||
providerToolDefs,
|
||||
model,
|
||||
map[string]any{
|
||||
"max_tokens": agent.MaxTokens,
|
||||
"temperature": agent.Temperature,
|
||||
"prompt_cache_key": agent.ID,
|
||||
},
|
||||
)
|
||||
},
|
||||
)
|
||||
if fbErr != nil {
|
||||
return nil, fbErr
|
||||
}
|
||||
if fbResult.Provider != "" && len(fbResult.Attempts) > 0 {
|
||||
logger.InfoCF("agent", fmt.Sprintf("Fallback: succeeded with %s/%s after %d attempts",
|
||||
fbResult.Provider, fbResult.Model, len(fbResult.Attempts)+1),
|
||||
map[string]any{"agent_id": agent.ID, "iteration": iteration})
|
||||
logger.InfoCF(
|
||||
"agent",
|
||||
fmt.Sprintf("Fallback: succeeded with %s/%s after %d attempts",
|
||||
fbResult.Provider, fbResult.Model, len(fbResult.Attempts)+1),
|
||||
map[string]any{"agent_id": agent.ID, "iteration": iteration},
|
||||
)
|
||||
}
|
||||
return fbResult.Response, nil
|
||||
}
|
||||
@@ -723,10 +845,14 @@ func (al *AgentLoop) runLLMIteration(
|
||||
}
|
||||
|
||||
if isContextError && retry < maxRetries {
|
||||
logger.WarnCF("agent", "Context window error detected, attempting compression", map[string]any{
|
||||
"error": err.Error(),
|
||||
"retry": retry,
|
||||
})
|
||||
logger.WarnCF(
|
||||
"agent",
|
||||
"Context window error detected, attempting compression",
|
||||
map[string]any{
|
||||
"error": err.Error(),
|
||||
"retry": retry,
|
||||
},
|
||||
)
|
||||
|
||||
if retry == 0 && !constants.IsInternalChannel(opts.Channel) {
|
||||
al.bus.PublishOutbound(ctx, bus.OutboundMessage{
|
||||
@@ -758,7 +884,12 @@ func (al *AgentLoop) runLLMIteration(
|
||||
return "", iteration, fmt.Errorf("LLM call failed after retries: %w", err)
|
||||
}
|
||||
|
||||
go al.handleReasoning(ctx, response.Reasoning, opts.Channel, al.targetReasoningChannelID(opts.Channel))
|
||||
go al.handleReasoning(
|
||||
ctx,
|
||||
response.Reasoning,
|
||||
opts.Channel,
|
||||
al.targetReasoningChannelID(opts.Channel),
|
||||
)
|
||||
|
||||
logger.DebugCF("agent", "LLM response",
|
||||
map[string]any{
|
||||
@@ -1067,7 +1198,11 @@ func formatMessagesForLog(messages []providers.Message) string {
|
||||
for _, tc := range msg.ToolCalls {
|
||||
fmt.Fprintf(&sb, " - ID: %s, Type: %s, Name: %s\n", tc.ID, tc.Type, tc.Name)
|
||||
if tc.Function != nil {
|
||||
fmt.Fprintf(&sb, " Arguments: %s\n", utils.Truncate(tc.Function.Arguments, 200))
|
||||
fmt.Fprintf(
|
||||
&sb,
|
||||
" Arguments: %s\n",
|
||||
utils.Truncate(tc.Function.Arguments, 200),
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1096,7 +1231,11 @@ func formatToolsForLog(toolDefs []providers.ToolDefinition) string {
|
||||
fmt.Fprintf(&sb, " [%d] Type: %s, Name: %s\n", i, tool.Type, tool.Function.Name)
|
||||
fmt.Fprintf(&sb, " Description: %s\n", tool.Function.Description)
|
||||
if len(tool.Function.Parameters) > 0 {
|
||||
fmt.Fprintf(&sb, " Parameters: %s\n", utils.Truncate(fmt.Sprintf("%v", tool.Function.Parameters), 200))
|
||||
fmt.Fprintf(
|
||||
&sb,
|
||||
" Parameters: %s\n",
|
||||
utils.Truncate(fmt.Sprintf("%v", tool.Function.Parameters), 200),
|
||||
)
|
||||
}
|
||||
}
|
||||
sb.WriteString("]")
|
||||
@@ -1193,7 +1332,9 @@ func (al *AgentLoop) summarizeBatch(
|
||||
existingSummary string,
|
||||
) (string, error) {
|
||||
var sb strings.Builder
|
||||
sb.WriteString("Provide a concise summary of this conversation segment, preserving core context and key points.\n")
|
||||
sb.WriteString(
|
||||
"Provide a concise summary of this conversation segment, preserving core context and key points.\n",
|
||||
)
|
||||
if existingSummary != "" {
|
||||
sb.WriteString("Existing context: ")
|
||||
sb.WriteString(existingSummary)
|
||||
|
||||
@@ -0,0 +1,122 @@
|
||||
// PicoClaw - Ultra-lightweight personal AI agent
|
||||
// Inspired by and based on nanobot: https://github.com/HKUDS/nanobot
|
||||
// License: MIT
|
||||
//
|
||||
// Copyright (c) 2026 PicoClaw contributors
|
||||
|
||||
package agent
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/base64"
|
||||
"io"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"github.com/h2non/filetype"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/logger"
|
||||
"github.com/sipeed/picoclaw/pkg/media"
|
||||
"github.com/sipeed/picoclaw/pkg/providers"
|
||||
)
|
||||
|
||||
// resolveMediaRefs replaces media:// refs in message Media fields with base64 data URLs.
|
||||
// Uses streaming base64 encoding (file handle → encoder → buffer) to avoid holding
|
||||
// both raw bytes and encoded string in memory simultaneously.
|
||||
// Returns a new slice; original messages are not mutated.
|
||||
func resolveMediaRefs(messages []providers.Message, store media.MediaStore, maxSize int) []providers.Message {
|
||||
if store == nil {
|
||||
return messages
|
||||
}
|
||||
|
||||
result := make([]providers.Message, len(messages))
|
||||
copy(result, messages)
|
||||
|
||||
for i, m := range result {
|
||||
if len(m.Media) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
resolved := make([]string, 0, len(m.Media))
|
||||
for _, ref := range m.Media {
|
||||
if !strings.HasPrefix(ref, "media://") {
|
||||
resolved = append(resolved, ref)
|
||||
continue
|
||||
}
|
||||
|
||||
localPath, meta, err := store.ResolveWithMeta(ref)
|
||||
if err != nil {
|
||||
logger.WarnCF("agent", "Failed to resolve media ref", map[string]any{
|
||||
"ref": ref,
|
||||
"error": err.Error(),
|
||||
})
|
||||
continue
|
||||
}
|
||||
|
||||
info, err := os.Stat(localPath)
|
||||
if err != nil {
|
||||
logger.WarnCF("agent", "Failed to stat media file", map[string]any{
|
||||
"path": localPath,
|
||||
"error": err.Error(),
|
||||
})
|
||||
continue
|
||||
}
|
||||
if info.Size() > int64(maxSize) {
|
||||
logger.WarnCF("agent", "Media file too large, skipping", map[string]any{
|
||||
"path": localPath,
|
||||
"size": info.Size(),
|
||||
"max_size": maxSize,
|
||||
})
|
||||
continue
|
||||
}
|
||||
|
||||
// Determine MIME type: prefer metadata, fallback to magic-bytes detection
|
||||
mime := meta.ContentType
|
||||
if mime == "" {
|
||||
kind, ftErr := filetype.MatchFile(localPath)
|
||||
if ftErr != nil || kind == filetype.Unknown {
|
||||
logger.WarnCF("agent", "Unknown media type, skipping", map[string]any{
|
||||
"path": localPath,
|
||||
})
|
||||
continue
|
||||
}
|
||||
mime = kind.MIME.Value
|
||||
}
|
||||
|
||||
// Streaming base64: open file → base64 encoder → buffer
|
||||
// Peak memory: ~1.33x file size (buffer only, no raw bytes copy)
|
||||
f, err := os.Open(localPath)
|
||||
if err != nil {
|
||||
logger.WarnCF("agent", "Failed to open media file", map[string]any{
|
||||
"path": localPath,
|
||||
"error": err.Error(),
|
||||
})
|
||||
continue
|
||||
}
|
||||
|
||||
prefix := "data:" + mime + ";base64,"
|
||||
encodedLen := base64.StdEncoding.EncodedLen(int(info.Size()))
|
||||
var buf bytes.Buffer
|
||||
buf.Grow(len(prefix) + encodedLen)
|
||||
buf.WriteString(prefix)
|
||||
|
||||
encoder := base64.NewEncoder(base64.StdEncoding, &buf)
|
||||
if _, err := io.Copy(encoder, f); err != nil {
|
||||
f.Close()
|
||||
logger.WarnCF("agent", "Failed to encode media file", map[string]any{
|
||||
"path": localPath,
|
||||
"error": err.Error(),
|
||||
})
|
||||
continue
|
||||
}
|
||||
encoder.Close()
|
||||
f.Close()
|
||||
|
||||
resolved = append(resolved, buf.String())
|
||||
}
|
||||
|
||||
result[i].Media = resolved
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
+168
-71
@@ -5,6 +5,7 @@ import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"slices"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
@@ -12,6 +13,7 @@ import (
|
||||
"github.com/sipeed/picoclaw/pkg/bus"
|
||||
"github.com/sipeed/picoclaw/pkg/channels"
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
"github.com/sipeed/picoclaw/pkg/media"
|
||||
"github.com/sipeed/picoclaw/pkg/providers"
|
||||
"github.com/sipeed/picoclaw/pkg/tools"
|
||||
)
|
||||
@@ -27,16 +29,15 @@ func (f *fakeChannel) IsAllowed(string) bool {
|
||||
func (f *fakeChannel) IsAllowedSender(sender bus.SenderInfo) bool { return true }
|
||||
func (f *fakeChannel) ReasoningChannelID() string { return f.id }
|
||||
|
||||
func TestRecordLastChannel(t *testing.T) {
|
||||
// Create temp workspace
|
||||
func newTestAgentLoop(
|
||||
t *testing.T,
|
||||
) (al *AgentLoop, cfg *config.Config, msgBus *bus.MessageBus, provider *mockProvider, cleanup func()) {
|
||||
t.Helper()
|
||||
tmpDir, err := os.MkdirTemp("", "agent-test-*")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temp dir: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
// Create test config
|
||||
cfg := &config.Config{
|
||||
cfg = &config.Config{
|
||||
Agents: config.AgentsConfig{
|
||||
Defaults: config.AgentDefaults{
|
||||
Workspace: tmpDir,
|
||||
@@ -46,74 +47,43 @@ func TestRecordLastChannel(t *testing.T) {
|
||||
},
|
||||
},
|
||||
}
|
||||
msgBus = bus.NewMessageBus()
|
||||
provider = &mockProvider{}
|
||||
al = NewAgentLoop(cfg, msgBus, provider)
|
||||
return al, cfg, msgBus, provider, func() { os.RemoveAll(tmpDir) }
|
||||
}
|
||||
|
||||
// Create agent loop
|
||||
msgBus := bus.NewMessageBus()
|
||||
provider := &mockProvider{}
|
||||
al := NewAgentLoop(cfg, msgBus, provider)
|
||||
func TestRecordLastChannel(t *testing.T) {
|
||||
al, cfg, msgBus, provider, cleanup := newTestAgentLoop(t)
|
||||
defer cleanup()
|
||||
|
||||
// Test RecordLastChannel
|
||||
testChannel := "test-channel"
|
||||
err = al.RecordLastChannel(testChannel)
|
||||
if err != nil {
|
||||
if err := al.RecordLastChannel(testChannel); err != nil {
|
||||
t.Fatalf("RecordLastChannel failed: %v", err)
|
||||
}
|
||||
|
||||
// Verify channel was saved
|
||||
lastChannel := al.state.GetLastChannel()
|
||||
if lastChannel != testChannel {
|
||||
t.Errorf("Expected channel '%s', got '%s'", testChannel, lastChannel)
|
||||
if got := al.state.GetLastChannel(); got != testChannel {
|
||||
t.Errorf("Expected channel '%s', got '%s'", testChannel, got)
|
||||
}
|
||||
|
||||
// Verify persistence by creating a new agent loop
|
||||
al2 := NewAgentLoop(cfg, msgBus, provider)
|
||||
if al2.state.GetLastChannel() != testChannel {
|
||||
t.Errorf("Expected persistent channel '%s', got '%s'", testChannel, al2.state.GetLastChannel())
|
||||
if got := al2.state.GetLastChannel(); got != testChannel {
|
||||
t.Errorf("Expected persistent channel '%s', got '%s'", testChannel, got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRecordLastChatID(t *testing.T) {
|
||||
// Create temp workspace
|
||||
tmpDir, err := os.MkdirTemp("", "agent-test-*")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temp dir: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tmpDir)
|
||||
al, cfg, msgBus, provider, cleanup := newTestAgentLoop(t)
|
||||
defer cleanup()
|
||||
|
||||
// Create test config
|
||||
cfg := &config.Config{
|
||||
Agents: config.AgentsConfig{
|
||||
Defaults: config.AgentDefaults{
|
||||
Workspace: tmpDir,
|
||||
Model: "test-model",
|
||||
MaxTokens: 4096,
|
||||
MaxToolIterations: 10,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// Create agent loop
|
||||
msgBus := bus.NewMessageBus()
|
||||
provider := &mockProvider{}
|
||||
al := NewAgentLoop(cfg, msgBus, provider)
|
||||
|
||||
// Test RecordLastChatID
|
||||
testChatID := "test-chat-id-123"
|
||||
err = al.RecordLastChatID(testChatID)
|
||||
if err != nil {
|
||||
if err := al.RecordLastChatID(testChatID); err != nil {
|
||||
t.Fatalf("RecordLastChatID failed: %v", err)
|
||||
}
|
||||
|
||||
// Verify chat ID was saved
|
||||
lastChatID := al.state.GetLastChatID()
|
||||
if lastChatID != testChatID {
|
||||
t.Errorf("Expected chat ID '%s', got '%s'", testChatID, lastChatID)
|
||||
if got := al.state.GetLastChatID(); got != testChatID {
|
||||
t.Errorf("Expected chat ID '%s', got '%s'", testChatID, got)
|
||||
}
|
||||
|
||||
// Verify persistence by creating a new agent loop
|
||||
al2 := NewAgentLoop(cfg, msgBus, provider)
|
||||
if al2.state.GetLastChatID() != testChatID {
|
||||
t.Errorf("Expected persistent chat ID '%s', got '%s'", testChatID, al2.state.GetLastChatID())
|
||||
if got := al2.state.GetLastChatID(); got != testChatID {
|
||||
t.Errorf("Expected persistent chat ID '%s', got '%s'", testChatID, got)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -188,13 +158,7 @@ func TestToolRegistry_ToolRegistration(t *testing.T) {
|
||||
toolsList := toolsInfo["names"].([]string)
|
||||
|
||||
// Check that our custom tool name is in the list
|
||||
found := false
|
||||
for _, name := range toolsList {
|
||||
if name == "mock_custom" {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
found := slices.Contains(toolsList, "mock_custom")
|
||||
if !found {
|
||||
t.Error("Expected custom tool to be registered")
|
||||
}
|
||||
@@ -263,13 +227,7 @@ func TestToolRegistry_GetDefinitions(t *testing.T) {
|
||||
toolsList := toolsInfo["names"].([]string)
|
||||
|
||||
// Check that our custom tool name is in the list
|
||||
found := false
|
||||
for _, name := range toolsList {
|
||||
if name == "mock_custom" {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
found := slices.Contains(toolsList, "mock_custom")
|
||||
if !found {
|
||||
t.Error("Expected custom tool to be registered")
|
||||
}
|
||||
@@ -931,3 +889,142 @@ func TestHandleReasoning(t *testing.T) {
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestResolveMediaRefs_ResolvesToBase64(t *testing.T) {
|
||||
store := media.NewFileMediaStore()
|
||||
dir := t.TempDir()
|
||||
|
||||
// Create a minimal valid PNG (8-byte header is enough for filetype detection)
|
||||
pngPath := filepath.Join(dir, "test.png")
|
||||
// PNG magic: 0x89 P N G \r \n 0x1A \n + minimal IHDR
|
||||
pngHeader := []byte{
|
||||
0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A, // PNG signature
|
||||
0x00, 0x00, 0x00, 0x0D, // IHDR length
|
||||
0x49, 0x48, 0x44, 0x52, // "IHDR"
|
||||
0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x01, 0x08, 0x02, // 1x1 RGB
|
||||
0x00, 0x00, 0x00, // no interlace
|
||||
0x90, 0x77, 0x53, 0xDE, // CRC
|
||||
}
|
||||
if err := os.WriteFile(pngPath, pngHeader, 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
ref, err := store.Store(pngPath, media.MediaMeta{}, "test")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
messages := []providers.Message{
|
||||
{Role: "user", Content: "describe this", Media: []string{ref}},
|
||||
}
|
||||
result := resolveMediaRefs(messages, store, config.DefaultMaxMediaSize)
|
||||
|
||||
if len(result[0].Media) != 1 {
|
||||
t.Fatalf("expected 1 resolved media, got %d", len(result[0].Media))
|
||||
}
|
||||
if !strings.HasPrefix(result[0].Media[0], "data:image/png;base64,") {
|
||||
t.Fatalf("expected data:image/png;base64, prefix, got %q", result[0].Media[0][:40])
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveMediaRefs_SkipsOversizedFile(t *testing.T) {
|
||||
store := media.NewFileMediaStore()
|
||||
dir := t.TempDir()
|
||||
|
||||
bigPath := filepath.Join(dir, "big.png")
|
||||
// Write PNG header + padding to exceed limit
|
||||
data := make([]byte, 1024+1) // 1KB + 1 byte
|
||||
copy(data, []byte{0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A})
|
||||
if err := os.WriteFile(bigPath, data, 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
ref, _ := store.Store(bigPath, media.MediaMeta{}, "test")
|
||||
|
||||
messages := []providers.Message{
|
||||
{Role: "user", Content: "hi", Media: []string{ref}},
|
||||
}
|
||||
// Use a tiny limit (1KB) so the file is oversized
|
||||
result := resolveMediaRefs(messages, store, 1024)
|
||||
|
||||
if len(result[0].Media) != 0 {
|
||||
t.Fatalf("expected 0 media (oversized), got %d", len(result[0].Media))
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveMediaRefs_SkipsUnknownType(t *testing.T) {
|
||||
store := media.NewFileMediaStore()
|
||||
dir := t.TempDir()
|
||||
|
||||
txtPath := filepath.Join(dir, "readme.txt")
|
||||
if err := os.WriteFile(txtPath, []byte("hello world"), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
ref, _ := store.Store(txtPath, media.MediaMeta{}, "test")
|
||||
|
||||
messages := []providers.Message{
|
||||
{Role: "user", Content: "hi", Media: []string{ref}},
|
||||
}
|
||||
result := resolveMediaRefs(messages, store, config.DefaultMaxMediaSize)
|
||||
|
||||
if len(result[0].Media) != 0 {
|
||||
t.Fatalf("expected 0 media (unknown type), got %d", len(result[0].Media))
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveMediaRefs_PassesThroughNonMediaRefs(t *testing.T) {
|
||||
messages := []providers.Message{
|
||||
{Role: "user", Content: "hi", Media: []string{"https://example.com/img.png"}},
|
||||
}
|
||||
result := resolveMediaRefs(messages, nil, config.DefaultMaxMediaSize)
|
||||
|
||||
if len(result[0].Media) != 1 || result[0].Media[0] != "https://example.com/img.png" {
|
||||
t.Fatalf("expected passthrough of non-media:// URL, got %v", result[0].Media)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveMediaRefs_DoesNotMutateOriginal(t *testing.T) {
|
||||
store := media.NewFileMediaStore()
|
||||
dir := t.TempDir()
|
||||
pngPath := filepath.Join(dir, "test.png")
|
||||
pngHeader := []byte{
|
||||
0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A,
|
||||
0x00, 0x00, 0x00, 0x0D, 0x49, 0x48, 0x44, 0x52,
|
||||
0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x01, 0x08, 0x02,
|
||||
0x00, 0x00, 0x00, 0x90, 0x77, 0x53, 0xDE,
|
||||
}
|
||||
os.WriteFile(pngPath, pngHeader, 0o644)
|
||||
ref, _ := store.Store(pngPath, media.MediaMeta{}, "test")
|
||||
|
||||
original := []providers.Message{
|
||||
{Role: "user", Content: "hi", Media: []string{ref}},
|
||||
}
|
||||
originalRef := original[0].Media[0]
|
||||
|
||||
resolveMediaRefs(original, store, config.DefaultMaxMediaSize)
|
||||
|
||||
if original[0].Media[0] != originalRef {
|
||||
t.Fatal("resolveMediaRefs mutated original message slice")
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveMediaRefs_UsesMetaContentType(t *testing.T) {
|
||||
store := media.NewFileMediaStore()
|
||||
dir := t.TempDir()
|
||||
|
||||
// File with JPEG content but stored with explicit content type
|
||||
jpegPath := filepath.Join(dir, "photo")
|
||||
jpegHeader := []byte{0xFF, 0xD8, 0xFF, 0xE0} // JPEG magic bytes
|
||||
os.WriteFile(jpegPath, jpegHeader, 0o644)
|
||||
ref, _ := store.Store(jpegPath, media.MediaMeta{ContentType: "image/jpeg"}, "test")
|
||||
|
||||
messages := []providers.Message{
|
||||
{Role: "user", Content: "hi", Media: []string{ref}},
|
||||
}
|
||||
result := resolveMediaRefs(messages, store, config.DefaultMaxMediaSize)
|
||||
|
||||
if len(result[0].Media) != 1 {
|
||||
t.Fatalf("expected 1 media, got %d", len(result[0].Media))
|
||||
}
|
||||
if !strings.HasPrefix(result[0].Media[0], "data:image/jpeg;base64,") {
|
||||
t.Fatalf("expected jpeg prefix, got %q", result[0].Media[0][:30])
|
||||
}
|
||||
}
|
||||
|
||||
+1
-1
@@ -111,7 +111,7 @@ func (ms *MemoryStore) GetRecentDailyNotes(days int) string {
|
||||
var sb strings.Builder
|
||||
first := true
|
||||
|
||||
for i := 0; i < days; i++ {
|
||||
for i := range days {
|
||||
date := time.Now().AddDate(0, 0, -i)
|
||||
dateStr := date.Format("20060102") // YYYYMMDD
|
||||
monthDir := dateStr[:6] // YYYYMM
|
||||
|
||||
+3
-3
@@ -67,7 +67,7 @@ func TestPublishInbound_ContextCancel(t *testing.T) {
|
||||
|
||||
// Fill the buffer
|
||||
ctx := context.Background()
|
||||
for i := 0; i < defaultBusBufferSize; i++ {
|
||||
for i := range defaultBusBufferSize {
|
||||
if err := mb.PublishInbound(ctx, InboundMessage{Content: "fill"}); err != nil {
|
||||
t.Fatalf("fill failed at %d: %v", i, err)
|
||||
}
|
||||
@@ -154,7 +154,7 @@ func TestConcurrentPublishClose(t *testing.T) {
|
||||
wg.Add(numGoroutines + 1)
|
||||
|
||||
// Spawn many goroutines trying to publish
|
||||
for i := 0; i < numGoroutines; i++ {
|
||||
for range numGoroutines {
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
// Use a short timeout context so we don't block forever after close
|
||||
@@ -194,7 +194,7 @@ func TestPublishInbound_FullBuffer(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
// Fill the buffer
|
||||
for i := 0; i < defaultBusBufferSize; i++ {
|
||||
for i := range defaultBusBufferSize {
|
||||
if err := mb.PublishInbound(ctx, InboundMessage{Content: "fill"}); err != nil {
|
||||
t.Fatalf("fill failed at %d: %v", i, err)
|
||||
}
|
||||
|
||||
+80
-27
@@ -1,7 +1,5 @@
|
||||
# PicoClaw Channel System Refactor: Complete Development Guide
|
||||
# PicoClaw Channel System: Complete Development Guide
|
||||
|
||||
> **Branch**: `refactor/channel-system`
|
||||
> **Status**: Active development (~40 commits)
|
||||
> **Scope**: `pkg/channels/`, `pkg/bus/`, `pkg/media/`, `pkg/identity/`, `cmd/picoclaw/internal/gateway/`
|
||||
|
||||
---
|
||||
@@ -46,6 +44,8 @@ pkg/channels/
|
||||
pkg/channels/
|
||||
├── base.go # BaseChannel shared abstraction layer
|
||||
├── interfaces.go # Optional capability interfaces (TypingCapable, MessageEditor, ReactionCapable, PlaceholderCapable, PlaceholderRecorder)
|
||||
├── README.md # English documentation
|
||||
├── README.zh.md # Chinese documentation
|
||||
├── media.go # MediaSender optional interface
|
||||
├── webhook.go # WebhookHandler, HealthChecker optional interfaces
|
||||
├── errors.go # Sentinel errors (ErrNotRunning, ErrRateLimit, ErrTemporary, ErrSendFailed)
|
||||
@@ -60,7 +60,7 @@ pkg/channels/
|
||||
├── discord/
|
||||
│ ├── init.go
|
||||
│ └── discord.go
|
||||
├── slack/ line/ onebot/ dingtalk/ feishu/ wecom/ qq/ whatsapp/ maixcam/ pico/
|
||||
├── slack/ line/ onebot/ dingtalk/ feishu/ wecom/ qq/ whatsapp/ whatsapp_native/ maixcam/ pico/
|
||||
│ └── ...
|
||||
|
||||
pkg/bus/
|
||||
@@ -111,7 +111,7 @@ pkg/identity/
|
||||
|-----------|-------------|
|
||||
| **Sub-package Isolation** | Each channel is a standalone Go sub-package, depending on `BaseChannel` and interfaces from the `channels` parent package |
|
||||
| **Factory Registration** | Sub-packages self-register via `init()`, Manager looks up factories by name, eliminating import coupling |
|
||||
| **Capability Discovery** | Optional capabilities are declared via interfaces (`MediaSender`, `TypingCapable`, `ReactionCapable`, `PlaceholderCapable`, `MessageEditor`, `WebhookHandler`), discovered by Manager via runtime type assertions |
|
||||
| **Capability Discovery** | Optional capabilities are declared via interfaces (`MediaSender`, `TypingCapable`, `ReactionCapable`, `PlaceholderCapable`, `MessageEditor`, `WebhookHandler`, `HealthChecker`), discovered by Manager via runtime type assertions |
|
||||
| **Structured Messages** | Peer, MessageID, and SenderInfo promoted from Metadata to first-class fields on InboundMessage |
|
||||
| **Error Classification** | Channels return sentinel errors (`ErrRateLimit`, `ErrTemporary`, etc.), Manager uses these to determine retry strategy |
|
||||
| **Centralized Orchestration** | Rate limiting, message splitting, retries, and Typing/Reaction/Placeholder management are all handled by Manager and BaseChannel; channels only need to implement Send |
|
||||
@@ -145,6 +145,7 @@ After refactoring, these files have been removed and code moved to corresponding
|
||||
| _(did not exist)_ | `pkg/channels/interfaces.go` | New optional capability interfaces |
|
||||
| _(did not exist)_ | `pkg/channels/media.go` | New MediaSender interface |
|
||||
| _(did not exist)_ | `pkg/channels/webhook.go` | New WebhookHandler/HealthChecker |
|
||||
| _(did not exist)_ | `pkg/channels/whatsapp_native/` | New WhatsApp native mode (whatsmeow) |
|
||||
| _(did not exist)_ | `pkg/channels/split.go` | New message splitting (migrated from utils) |
|
||||
| _(did not exist)_ | `pkg/bus/types.go` | New structured message types |
|
||||
| _(did not exist)_ | `pkg/media/store.go` | New media file lifecycle management |
|
||||
@@ -220,6 +221,7 @@ func NewTelegramChannel(cfg *config.Config, bus *bus.MessageBus) (*TelegramChann
|
||||
cfg.Channels.Telegram.AllowFrom, // Allow list
|
||||
channels.WithMaxMessageLength(4096), // Platform message length limit
|
||||
channels.WithGroupTrigger(cfg.Channels.Telegram.GroupTrigger), // Group trigger config
|
||||
channels.WithReasoningChannelID(cfg.Channels.Telegram.ReasoningChannelID), // Reasoning chain routing
|
||||
)
|
||||
return &TelegramChannel{
|
||||
BaseChannel: base,
|
||||
@@ -466,6 +468,7 @@ func NewMatrixChannel(cfg *config.Config, msgBus *bus.MessageBus) (*MatrixChanne
|
||||
matrixCfg.AllowFrom, // Allow list
|
||||
channels.WithMaxMessageLength(65536), // Matrix message length limit
|
||||
channels.WithGroupTrigger(matrixCfg.GroupTrigger),
|
||||
channels.WithReasoningChannelID(matrixCfg.ReasoningChannelID), // Reasoning chain routing (optional)
|
||||
)
|
||||
|
||||
return &MatrixChannel{
|
||||
@@ -666,6 +669,32 @@ func (c *MatrixChannel) EditMessage(ctx context.Context, chatID, messageID, cont
|
||||
}
|
||||
```
|
||||
|
||||
#### PlaceholderCapable — Placeholder Messages
|
||||
|
||||
```go
|
||||
// If the platform supports sending placeholder messages (e.g. "Thinking... 💭"),
|
||||
// and the channel also implements MessageEditor, then Manager's preSend will
|
||||
// automatically edit the placeholder into the final response on outbound.
|
||||
// SendPlaceholder checks PlaceholderConfig.Enabled internally;
|
||||
// returning ("", nil) means skip.
|
||||
func (c *MatrixChannel) SendPlaceholder(ctx context.Context, chatID string) (string, error) {
|
||||
cfg := c.config.Channels.Matrix.Placeholder
|
||||
if !cfg.Enabled {
|
||||
return "", nil
|
||||
}
|
||||
text := cfg.Text
|
||||
if text == "" {
|
||||
text = "Thinking... 💭"
|
||||
}
|
||||
// Call Matrix API to send placeholder message
|
||||
msg, err := c.sendText(ctx, chatID, text)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return msg.ID, nil
|
||||
}
|
||||
```
|
||||
|
||||
#### WebhookHandler — HTTP Webhook Reception
|
||||
|
||||
```go
|
||||
@@ -746,15 +775,17 @@ When the Agent finishes processing a message, Manager's `preSend` automatically:
|
||||
```go
|
||||
type ChannelsConfig struct {
|
||||
// ... existing channels
|
||||
Matrix MatrixChannelConfig `yaml:"matrix" json:"matrix"`
|
||||
Matrix MatrixChannelConfig `json:"matrix"`
|
||||
}
|
||||
|
||||
type MatrixChannelConfig struct {
|
||||
Enabled bool `yaml:"enabled" json:"enabled"`
|
||||
HomeServer string `yaml:"home_server" json:"home_server"`
|
||||
Token string `yaml:"token" json:"token"`
|
||||
AllowFrom []string `yaml:"allow_from" json:"allow_from"`
|
||||
GroupTrigger GroupTriggerConfig `yaml:"group_trigger" json:"group_trigger"`
|
||||
Enabled bool `json:"enabled"`
|
||||
HomeServer string `json:"home_server"`
|
||||
Token string `json:"token"`
|
||||
AllowFrom []string `json:"allow_from"`
|
||||
GroupTrigger GroupTriggerConfig `json:"group_trigger"`
|
||||
Placeholder PlaceholderConfig `json:"placeholder"`
|
||||
ReasoningChannelID string `json:"reasoning_channel_id"`
|
||||
}
|
||||
```
|
||||
|
||||
@@ -767,6 +798,15 @@ if m.config.Channels.Matrix.Enabled && m.config.Channels.Matrix.Token != "" {
|
||||
}
|
||||
```
|
||||
|
||||
> **Note**: If your channel has multiple modes (like WhatsApp Bridge vs Native), branch in initChannels based on config:
|
||||
> ```go
|
||||
> if cfg.UseNative {
|
||||
> m.initChannel("whatsapp_native", "WhatsApp Native")
|
||||
> } else {
|
||||
> m.initChannel("whatsapp", "WhatsApp")
|
||||
> }
|
||||
> ```
|
||||
|
||||
#### Add blank import in Gateway
|
||||
|
||||
```go
|
||||
@@ -882,19 +922,21 @@ BaseChannel is the shared abstraction layer for all channels, providing the foll
|
||||
| `IsRunning() bool` | Atomically read running state |
|
||||
| `SetRunning(bool)` | Atomically set running state |
|
||||
| `MaxMessageLength() int` | Message length limit (rune count), 0 = unlimited |
|
||||
| `ReasoningChannelID() string` | Reasoning chain routing target channel ID (empty = no routing) |
|
||||
| `IsAllowed(senderID string) bool` | Legacy allow-list check (supports `"id\|username"` and `"@username"` formats) |
|
||||
| `IsAllowedSender(sender SenderInfo) bool` | New allow-list check (delegates to `identity.MatchAllowed`) |
|
||||
| `ShouldRespondInGroup(isMentioned, content) (bool, string)` | Unified group chat trigger filtering logic |
|
||||
| `HandleMessage(...)` | Unified inbound message handling: permission check → build MediaScope → auto-trigger Typing/Reaction → publish to Bus |
|
||||
| `HandleMessage(...)` | Unified inbound message handling: permission check → build MediaScope → auto-trigger Typing/Reaction/Placeholder → publish to Bus |
|
||||
| `SetMediaStore(s) / GetMediaStore()` | MediaStore injected by Manager |
|
||||
| `SetPlaceholderRecorder(r) / GetPlaceholderRecorder()` | PlaceholderRecorder injected by Manager |
|
||||
| `SetOwner(ch)` | Concrete channel reference injected by Manager (used for Typing/Reaction type assertions in HandleMessage) |
|
||||
| `SetOwner(ch)` | Concrete channel reference injected by Manager (used for Typing/Reaction/Placeholder type assertions in HandleMessage) |
|
||||
|
||||
**Functional Options**:
|
||||
|
||||
```go
|
||||
channels.WithMaxMessageLength(4096) // Set platform message length limit
|
||||
channels.WithGroupTrigger(groupTriggerCfg) // Set group trigger configuration
|
||||
channels.WithReasoningChannelID(id) // Set reasoning chain routing target channel
|
||||
```
|
||||
|
||||
### 4.4 Factory Registry
|
||||
@@ -998,7 +1040,7 @@ StartAll:
|
||||
- runMediaWorker (per-channel outbound media)
|
||||
- dispatchOutbound (route from bus to worker queues)
|
||||
- dispatchOutboundMedia (route from bus to media worker queues)
|
||||
- runTTLJanitor (every 10s clean up expired typing/placeholder)
|
||||
- runTTLJanitor (every 10s clean up expired typing/reaction/placeholder)
|
||||
4. Start shared HTTP server (if configured)
|
||||
|
||||
StopAll:
|
||||
@@ -1206,18 +1248,20 @@ make test # Full test suite
|
||||
|
||||
| Sub-package | Registered Name | Optional Interfaces |
|
||||
|-------------|----------------|-------------------|
|
||||
| `pkg/channels/telegram/` | `"telegram"` | MessageEditor, MediaSender, TypingCapable, PlaceholderCapable |
|
||||
| `pkg/channels/discord/` | `"discord"` | MessageEditor, TypingCapable, PlaceholderCapable |
|
||||
| `pkg/channels/slack/` | `"slack"` | ReactionCapable |
|
||||
| `pkg/channels/line/` | `"line"` | WebhookHandler, HealthChecker, TypingCapable |
|
||||
| `pkg/channels/onebot/` | `"onebot"` | ReactionCapable |
|
||||
| `pkg/channels/dingtalk/` | `"dingtalk"` | WebhookHandler |
|
||||
| `pkg/channels/feishu/` | `"feishu"` | WebhookHandler (architecture-specific build tags) |
|
||||
| `pkg/channels/wecom/` | `"wecom"` + `"wecom_app"` | WebhookHandler |
|
||||
| `pkg/channels/telegram/` | `"telegram"` | TypingCapable, PlaceholderCapable, MessageEditor, MediaSender |
|
||||
| `pkg/channels/discord/` | `"discord"` | TypingCapable, PlaceholderCapable, MessageEditor, MediaSender |
|
||||
| `pkg/channels/slack/` | `"slack"` | ReactionCapable, MediaSender |
|
||||
| `pkg/channels/line/` | `"line"` | TypingCapable, MediaSender, WebhookHandler |
|
||||
| `pkg/channels/onebot/` | `"onebot"` | ReactionCapable, MediaSender |
|
||||
| `pkg/channels/dingtalk/` | `"dingtalk"` | — |
|
||||
| `pkg/channels/feishu/` | `"feishu"` | — (architecture-specific build tags: `feishu_32.go` / `feishu_64.go`) |
|
||||
| `pkg/channels/wecom/` | `"wecom"` | WebhookHandler, HealthChecker |
|
||||
| `pkg/channels/wecom/` | `"wecom_app"` | MediaSender, WebhookHandler, HealthChecker |
|
||||
| `pkg/channels/qq/` | `"qq"` | — |
|
||||
| `pkg/channels/whatsapp/` | `"whatsapp"` | — |
|
||||
| `pkg/channels/whatsapp/` | `"whatsapp"` | — (Bridge mode) |
|
||||
| `pkg/channels/whatsapp_native/` | `"whatsapp_native"` | — (Native whatsmeow mode) |
|
||||
| `pkg/channels/maixcam/` | `"maixcam"` | — |
|
||||
| `pkg/channels/pico/` | `"pico"` | WebhookHandler (Pico Protocol), TypingCapable, PlaceholderCapable |
|
||||
| `pkg/channels/pico/` | `"pico"` | TypingCapable, PlaceholderCapable, MessageEditor, WebhookHandler |
|
||||
|
||||
### A.3 Interface Quick Reference
|
||||
|
||||
@@ -1231,6 +1275,7 @@ type Channel interface {
|
||||
IsRunning() bool
|
||||
IsAllowed(senderID string) bool
|
||||
IsAllowedSender(sender bus.SenderInfo) bool
|
||||
ReasoningChannelID() string
|
||||
}
|
||||
|
||||
// ===== Optional =====
|
||||
@@ -1324,8 +1369,16 @@ agentLoop.Stop() // Stop Agent
|
||||
|
||||
1. **Media cleanup temporarily disabled**: The `ReleaseAll` call in the Agent loop is commented out (`refactor(loop): disable media cleanup to prevent premature file deletion`) because session boundaries are not yet clearly defined. TTL cleanup remains active.
|
||||
|
||||
2. **Feishu architecture-specific compilation**: The Feishu channel uses build tags to distinguish 32-bit and 64-bit architectures (`feishu_32.go` / `feishu_64.go`).
|
||||
2. **Feishu architecture-specific compilation**: The Feishu channel uses build tags to distinguish 32-bit and 64-bit architectures (`feishu_32.go` / `feishu_64.go`). Feishu uses the SDK's WebSocket mode (not HTTP webhook), so it does not implement `WebhookHandler`.
|
||||
|
||||
3. **WeCom has two factories**: `"wecom"` (Bot mode) and `"wecom_app"` (App mode) are registered separately.
|
||||
3. **WeCom has two factories**: `"wecom"` (Bot mode, webhook only) and `"wecom_app"` (App mode, supports MediaSender) are registered separately. Both implement `WebhookHandler` and `HealthChecker`.
|
||||
|
||||
4. **Pico Protocol**: `pkg/channels/pico/` implements a custom PicoClaw native protocol channel that receives messages via webhook.
|
||||
4. **Pico Protocol**: `pkg/channels/pico/` implements a custom PicoClaw native protocol channel that receives messages via WebSocket webhook (`/pico/ws`).
|
||||
|
||||
5. **WhatsApp has two modes**: `"whatsapp"` (Bridge mode, communicates via external bridge URL) and `"whatsapp_native"` (native whatsmeow mode, connects directly to WhatsApp). Manager selects which to initialize based on `WhatsAppConfig.UseNative`.
|
||||
|
||||
6. **DingTalk uses Stream mode**: DingTalk uses the SDK's Stream/WebSocket mode (not HTTP webhook), so it does not implement `WebhookHandler`.
|
||||
|
||||
7. **PlaceholderConfig vs implementation**: `PlaceholderConfig` appears in 6 channel configs (Telegram, Discord, Slack, LINE, OneBot, Pico), but only channels that implement both `PlaceholderCapable` + `MessageEditor` (Telegram, Discord, Pico) can actually use placeholder message editing. The rest are reserved fields.
|
||||
|
||||
8. **ReasoningChannelID**: Most channel configs include a `reasoning_channel_id` field to route LLM reasoning/thinking output to a designated channel (WhatsApp, Telegram, Feishu, Discord, MaixCam, QQ, DingTalk, Slack, LINE, OneBot, WeCom, WeComApp). Note: `PicoConfig` does not currently expose this field. `BaseChannel` exposes this via the `WithReasoningChannelID` option and `ReasoningChannelID()` method.
|
||||
+79
-27
@@ -1,7 +1,5 @@
|
||||
# PicoClaw Channel System 重构:完整开发指南
|
||||
# PicoClaw Channel System:完整开发指南
|
||||
|
||||
> **分支**: `refactor/channel-system`
|
||||
> **状态**: 活跃开发中(约 40 commits)
|
||||
> **影响范围**: `pkg/channels/`, `pkg/bus/`, `pkg/media/`, `pkg/identity/`, `cmd/picoclaw/internal/gateway/`
|
||||
|
||||
---
|
||||
@@ -46,6 +44,8 @@ pkg/channels/
|
||||
pkg/channels/
|
||||
├── base.go # BaseChannel 共享抽象层
|
||||
├── interfaces.go # 可选能力接口(TypingCapable, MessageEditor, ReactionCapable, PlaceholderCapable, PlaceholderRecorder)
|
||||
├── README.md # 英文文档
|
||||
├── README.zh.md # 中文文档
|
||||
├── media.go # MediaSender 可选接口
|
||||
├── webhook.go # WebhookHandler, HealthChecker 可选接口
|
||||
├── errors.go # 错误哨兵值(ErrNotRunning, ErrRateLimit, ErrTemporary, ErrSendFailed)
|
||||
@@ -60,7 +60,7 @@ pkg/channels/
|
||||
├── discord/
|
||||
│ ├── init.go
|
||||
│ └── discord.go
|
||||
├── slack/ line/ onebot/ dingtalk/ feishu/ wecom/ qq/ whatsapp/ maixcam/ pico/
|
||||
├── slack/ line/ onebot/ dingtalk/ feishu/ wecom/ qq/ whatsapp/ whatsapp_native/ maixcam/ pico/
|
||||
│ └── ...
|
||||
|
||||
pkg/bus/
|
||||
@@ -111,7 +111,7 @@ pkg/identity/
|
||||
|------|------|
|
||||
| **子包隔离** | 每个 channel 一个独立 Go 子包,依赖 `channels` 父包提供的 `BaseChannel` 和接口 |
|
||||
| **工厂注册** | 各子包通过 `init()` 自注册,Manager 通过名字查找工厂,消除 import 耦合 |
|
||||
| **能力发现** | 可选能力通过接口(`MediaSender`, `TypingCapable`, `ReactionCapable`, `PlaceholderCapable`, `MessageEditor`, `WebhookHandler`)声明,Manager 运行时类型断言发现 |
|
||||
| **能力发现** | 可选能力通过接口(`MediaSender`, `TypingCapable`, `ReactionCapable`, `PlaceholderCapable`, `MessageEditor`, `WebhookHandler`, `HealthChecker`)声明,Manager 运行时类型断言发现 |
|
||||
| **结构化消息** | Peer、MessageID、SenderInfo 从 Metadata 提升为 InboundMessage 的一等字段 |
|
||||
| **错误分类** | Channel 返回哨兵错误(`ErrRateLimit`, `ErrTemporary` 等),Manager 据此决定重试策略 |
|
||||
| **集中编排** | 速率限制、消息分割、重试、Typing/Reaction/Placeholder 全部由 Manager 和 BaseChannel 统一处理,Channel 只负责 Send |
|
||||
@@ -145,6 +145,7 @@ pkg/identity/
|
||||
| _(不存在)_ | `pkg/channels/interfaces.go` | 新增可选能力接口 |
|
||||
| _(不存在)_ | `pkg/channels/media.go` | 新增 MediaSender 接口 |
|
||||
| _(不存在)_ | `pkg/channels/webhook.go` | 新增 WebhookHandler/HealthChecker |
|
||||
| _(不存在)_ | `pkg/channels/whatsapp_native/` | 新增 WhatsApp 原生模式(whatsmeow) |
|
||||
| _(不存在)_ | `pkg/channels/split.go` | 新增消息分割(从 utils 迁入) |
|
||||
| _(不存在)_ | `pkg/bus/types.go` | 新增结构化消息类型 |
|
||||
| _(不存在)_ | `pkg/media/store.go` | 新增媒体文件生命周期管理 |
|
||||
@@ -220,6 +221,7 @@ func NewTelegramChannel(cfg *config.Config, bus *bus.MessageBus) (*TelegramChann
|
||||
cfg.Channels.Telegram.AllowFrom, // 允许列表
|
||||
channels.WithMaxMessageLength(4096), // 平台消息长度上限
|
||||
channels.WithGroupTrigger(cfg.Channels.Telegram.GroupTrigger), // 群聊触发配置
|
||||
channels.WithReasoningChannelID(cfg.Channels.Telegram.ReasoningChannelID), // 思维链路由
|
||||
)
|
||||
return &TelegramChannel{
|
||||
BaseChannel: base,
|
||||
@@ -466,6 +468,7 @@ func NewMatrixChannel(cfg *config.Config, msgBus *bus.MessageBus) (*MatrixChanne
|
||||
matrixCfg.AllowFrom, // 允许列表
|
||||
channels.WithMaxMessageLength(65536), // Matrix 消息长度限制
|
||||
channels.WithGroupTrigger(matrixCfg.GroupTrigger),
|
||||
channels.WithReasoningChannelID(matrixCfg.ReasoningChannelID), // 思维链路由(可选)
|
||||
)
|
||||
|
||||
return &MatrixChannel{
|
||||
@@ -666,6 +669,31 @@ func (c *MatrixChannel) EditMessage(ctx context.Context, chatID, messageID, cont
|
||||
}
|
||||
```
|
||||
|
||||
#### PlaceholderCapable — 占位消息
|
||||
|
||||
```go
|
||||
// 如果平台支持发送占位消息(如 "Thinking... 💭"),并且实现了 MessageEditor,
|
||||
// 则 Manager 的 preSend 会在出站时自动将占位消息编辑为最终回复。
|
||||
// SendPlaceholder 内部根据 PlaceholderConfig.Enabled 决定是否发送;
|
||||
// 返回 ("", nil) 表示跳过。
|
||||
func (c *MatrixChannel) SendPlaceholder(ctx context.Context, chatID string) (string, error) {
|
||||
cfg := c.config.Channels.Matrix.Placeholder
|
||||
if !cfg.Enabled {
|
||||
return "", nil
|
||||
}
|
||||
text := cfg.Text
|
||||
if text == "" {
|
||||
text = "Thinking... 💭"
|
||||
}
|
||||
// 调用 Matrix API 发送占位消息
|
||||
msg, err := c.sendText(ctx, chatID, text)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return msg.ID, nil
|
||||
}
|
||||
```
|
||||
|
||||
#### WebhookHandler — HTTP Webhook 接收
|
||||
|
||||
```go
|
||||
@@ -746,15 +774,17 @@ if c.owner != nil && c.placeholderRecorder != nil {
|
||||
```go
|
||||
type ChannelsConfig struct {
|
||||
// ... 现有 channels
|
||||
Matrix MatrixChannelConfig `yaml:"matrix" json:"matrix"`
|
||||
Matrix MatrixChannelConfig `json:"matrix"`
|
||||
}
|
||||
|
||||
type MatrixChannelConfig struct {
|
||||
Enabled bool `yaml:"enabled" json:"enabled"`
|
||||
HomeServer string `yaml:"home_server" json:"home_server"`
|
||||
Token string `yaml:"token" json:"token"`
|
||||
AllowFrom []string `yaml:"allow_from" json:"allow_from"`
|
||||
GroupTrigger GroupTriggerConfig `yaml:"group_trigger" json:"group_trigger"`
|
||||
Enabled bool `json:"enabled"`
|
||||
HomeServer string `json:"home_server"`
|
||||
Token string `json:"token"`
|
||||
AllowFrom []string `json:"allow_from"`
|
||||
GroupTrigger GroupTriggerConfig `json:"group_trigger"`
|
||||
Placeholder PlaceholderConfig `json:"placeholder"`
|
||||
ReasoningChannelID string `json:"reasoning_channel_id"`
|
||||
}
|
||||
```
|
||||
|
||||
@@ -767,6 +797,15 @@ if m.config.Channels.Matrix.Enabled && m.config.Channels.Matrix.Token != "" {
|
||||
}
|
||||
```
|
||||
|
||||
> **注意**:如果你的 channel 有多种模式(如 WhatsApp Bridge vs Native),需要在 initChannels 中根据配置分支:
|
||||
> ```go
|
||||
> if cfg.UseNative {
|
||||
> m.initChannel("whatsapp_native", "WhatsApp Native")
|
||||
> } else {
|
||||
> m.initChannel("whatsapp", "WhatsApp")
|
||||
> }
|
||||
> ```
|
||||
|
||||
#### 在 Gateway 中添加 blank import
|
||||
|
||||
```go
|
||||
@@ -882,19 +921,21 @@ BaseChannel 是所有 channel 的共享抽象层,提供以下能力:
|
||||
| `IsRunning() bool` | 原子读取运行状态 |
|
||||
| `SetRunning(bool)` | 原子设置运行状态 |
|
||||
| `MaxMessageLength() int` | 消息长度限制(rune 计数),0 = 无限制 |
|
||||
| `ReasoningChannelID() string` | 思维链路由目标 channel ID(空 = 不路由) |
|
||||
| `IsAllowed(senderID string) bool` | 旧格式允许列表检查(支持 `"id\|username"` 和 `"@username"` 格式) |
|
||||
| `IsAllowedSender(sender SenderInfo) bool` | 新格式允许列表检查(委托给 `identity.MatchAllowed`) |
|
||||
| `ShouldRespondInGroup(isMentioned, content) (bool, string)` | 统一群聊触发过滤逻辑 |
|
||||
| `HandleMessage(...)` | 统一入站消息处理:权限检查 → 构建 MediaScope → 自动触发 Typing/Reaction → 发布到 Bus |
|
||||
| `HandleMessage(...)` | 统一入站消息处理:权限检查 → 构建 MediaScope → 自动触发 Typing/Reaction/Placeholder → 发布到 Bus |
|
||||
| `SetMediaStore(s) / GetMediaStore()` | Manager 注入的媒体存储 |
|
||||
| `SetPlaceholderRecorder(r) / GetPlaceholderRecorder()` | Manager 注入的占位符记录器 |
|
||||
| `SetOwner(ch) ` | Manager 注入的具体 channel 引用(用于 HandleMessage 内部的 Typing/Reaction 类型断言) |
|
||||
| `SetOwner(ch) ` | Manager 注入的具体 channel 引用(用于 HandleMessage 内部的 Typing/Reaction/Placeholder 类型断言) |
|
||||
|
||||
**功能选项**:
|
||||
|
||||
```go
|
||||
channels.WithMaxMessageLength(4096) // 设置平台消息长度限制
|
||||
channels.WithGroupTrigger(groupTriggerCfg) // 设置群聊触发配置
|
||||
channels.WithReasoningChannelID(id) // 设置思维链路由目标 channel
|
||||
```
|
||||
|
||||
### 4.4 工厂注册表
|
||||
@@ -998,7 +1039,7 @@ StartAll:
|
||||
- runMediaWorker (per-channel 出站媒体)
|
||||
- dispatchOutbound (从 bus 路由到 worker 队列)
|
||||
- dispatchOutboundMedia (从 bus 路由到 media worker 队列)
|
||||
- runTTLJanitor (每 10s 清理过期 typing/placeholder)
|
||||
- runTTLJanitor (每 10s 清理过期 typing/reaction/placeholder)
|
||||
4. 启动共享 HTTP 服务器(如已配置)
|
||||
|
||||
StopAll:
|
||||
@@ -1206,18 +1247,20 @@ make test # 全量测试
|
||||
|
||||
| 子包 | 注册名 | 可选接口 |
|
||||
|------|--------|----------|
|
||||
| `pkg/channels/telegram/` | `"telegram"` | MessageEditor, MediaSender, TypingCapable, PlaceholderCapable |
|
||||
| `pkg/channels/discord/` | `"discord"` | MessageEditor, TypingCapable, PlaceholderCapable |
|
||||
| `pkg/channels/slack/` | `"slack"` | ReactionCapable |
|
||||
| `pkg/channels/line/` | `"line"` | WebhookHandler, HealthChecker, TypingCapable |
|
||||
| `pkg/channels/onebot/` | `"onebot"` | ReactionCapable |
|
||||
| `pkg/channels/dingtalk/` | `"dingtalk"` | WebhookHandler |
|
||||
| `pkg/channels/feishu/` | `"feishu"` | WebhookHandler (架构特定 build tags) |
|
||||
| `pkg/channels/wecom/` | `"wecom"` + `"wecom_app"` | WebhookHandler |
|
||||
| `pkg/channels/telegram/` | `"telegram"` | TypingCapable, PlaceholderCapable, MessageEditor, MediaSender |
|
||||
| `pkg/channels/discord/` | `"discord"` | TypingCapable, PlaceholderCapable, MessageEditor, MediaSender |
|
||||
| `pkg/channels/slack/` | `"slack"` | ReactionCapable, MediaSender |
|
||||
| `pkg/channels/line/` | `"line"` | TypingCapable, MediaSender, WebhookHandler |
|
||||
| `pkg/channels/onebot/` | `"onebot"` | ReactionCapable, MediaSender |
|
||||
| `pkg/channels/dingtalk/` | `"dingtalk"` | — |
|
||||
| `pkg/channels/feishu/` | `"feishu"` | — (架构特定 build tags: `feishu_32.go` / `feishu_64.go`) |
|
||||
| `pkg/channels/wecom/` | `"wecom"` | WebhookHandler, HealthChecker |
|
||||
| `pkg/channels/wecom/` | `"wecom_app"` | MediaSender, WebhookHandler, HealthChecker |
|
||||
| `pkg/channels/qq/` | `"qq"` | — |
|
||||
| `pkg/channels/whatsapp/` | `"whatsapp"` | — |
|
||||
| `pkg/channels/whatsapp/` | `"whatsapp"` | — (Bridge 模式) |
|
||||
| `pkg/channels/whatsapp_native/` | `"whatsapp_native"` | — (原生 whatsmeow 模式) |
|
||||
| `pkg/channels/maixcam/` | `"maixcam"` | — |
|
||||
| `pkg/channels/pico/` | `"pico"` | WebhookHandler (Pico Protocol), TypingCapable, PlaceholderCapable |
|
||||
| `pkg/channels/pico/` | `"pico"` | TypingCapable, PlaceholderCapable, MessageEditor, WebhookHandler |
|
||||
|
||||
### A.3 接口速查表
|
||||
|
||||
@@ -1231,6 +1274,7 @@ type Channel interface {
|
||||
IsRunning() bool
|
||||
IsAllowed(senderID string) bool
|
||||
IsAllowedSender(sender bus.SenderInfo) bool
|
||||
ReasoningChannelID() string
|
||||
}
|
||||
|
||||
// ===== 可选实现 =====
|
||||
@@ -1324,8 +1368,16 @@ agentLoop.Stop() // 停止 Agent
|
||||
|
||||
1. **媒体清理暂时禁用**:Agent loop 中的 `ReleaseAll` 调用被注释掉了(`refactor(loop): disable media cleanup to prevent premature file deletion`),因为会话边界尚未明确定义。TTL 清理仍然有效。
|
||||
|
||||
2. **Feishu 架构特定编译**:Feishu channel 使用 build tags 区分 32 位和 64 位架构(`feishu_32.go` / `feishu_64.go`)。
|
||||
2. **Feishu 架构特定编译**:Feishu channel 使用 build tags 区分 32 位和 64 位架构(`feishu_32.go` / `feishu_64.go`)。Feishu 使用 SDK 的 WebSocket 模式(非 HTTP webhook),因此不实现 `WebhookHandler`。
|
||||
|
||||
3. **WeCom 有两个工厂**:`"wecom"`(Bot 模式)和 `"wecom_app"`(应用模式)分别注册。
|
||||
3. **WeCom 有两个工厂**:`"wecom"`(Bot 模式,纯 webhook)和 `"wecom_app"`(应用模式,支持 MediaSender)分别注册。两者都实现了 `WebhookHandler` 和 `HealthChecker`。
|
||||
|
||||
4. **Pico Protocol**:`pkg/channels/pico/` 实现了一个自定义的 PicoClaw 原生协议 channel,通过 webhook 接收消息。
|
||||
4. **Pico Protocol**:`pkg/channels/pico/` 实现了一个自定义的 PicoClaw 原生协议 channel,通过 WebSocket webhook (`/pico/ws`) 接收消息。
|
||||
|
||||
5. **WhatsApp 有两种模式**:`"whatsapp"`(Bridge 模式,通过外部 bridge URL 通信)和 `"whatsapp_native"`(原生 whatsmeow 模式,直接连接 WhatsApp)。Manager 根据 `WhatsAppConfig.UseNative` 决定初始化哪个。
|
||||
|
||||
6. **DingTalk 使用 Stream 模式**:DingTalk 使用 SDK 的 Stream/WebSocket 模式(非 HTTP webhook),因此不实现 `WebhookHandler`。
|
||||
|
||||
7. **PlaceholderConfig 的配置与实现**:`PlaceholderConfig` 出现在 6 个 channel config 中(Telegram、Discord、Slack、LINE、OneBot、Pico),但只有实现了 `PlaceholderCapable` + `MessageEditor` 的 channel(Telegram、Discord、Pico)能真正使用占位消息编辑功能。其余 channel 的 `PlaceholderConfig` 为预留字段。
|
||||
|
||||
8. **ReasoningChannelID**:大多数 channel config 都包含 `reasoning_channel_id` 字段,用于将 LLM 的思维链(reasoning/thinking)路由到指定 channel(WhatsApp、Telegram、Feishu、Discord、MaixCam、QQ、DingTalk、Slack、LINE、OneBot、WeCom、WeComApp)。注意:`PicoConfig` 目前不包含该字段。`BaseChannel` 通过 `WithReasoningChannelID` 选项和 `ReasoningChannelID()` 方法暴露此配置。
|
||||
@@ -1,5 +1,16 @@
|
||||
package feishu
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
larkim "github.com/larksuite/oapi-sdk-go/v3/service/im/v1"
|
||||
)
|
||||
|
||||
// mentionPlaceholderRegex matches @_user_N placeholders inserted by Feishu for mentions.
|
||||
var mentionPlaceholderRegex = regexp.MustCompile(`@_user_\d+`)
|
||||
|
||||
// stringValue safely dereferences a *string pointer.
|
||||
func stringValue(v *string) string {
|
||||
if v == nil {
|
||||
@@ -7,3 +18,69 @@ func stringValue(v *string) string {
|
||||
}
|
||||
return *v
|
||||
}
|
||||
|
||||
// buildMarkdownCard builds a Feishu Interactive Card JSON 2.0 string with markdown content.
|
||||
// JSON 2.0 cards support full CommonMark standard markdown syntax.
|
||||
func buildMarkdownCard(content string) (string, error) {
|
||||
card := map[string]any{
|
||||
"schema": "2.0",
|
||||
"body": map[string]any{
|
||||
"elements": []map[string]any{
|
||||
{
|
||||
"tag": "markdown",
|
||||
"content": content,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
data, err := json.Marshal(card)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return string(data), nil
|
||||
}
|
||||
|
||||
// extractJSONStringField unmarshals content as JSON and returns the value of the given string field.
|
||||
// Returns "" if the content is invalid JSON or the field is missing/empty.
|
||||
func extractJSONStringField(content, field string) string {
|
||||
var m map[string]json.RawMessage
|
||||
if err := json.Unmarshal([]byte(content), &m); err != nil {
|
||||
return ""
|
||||
}
|
||||
raw, ok := m[field]
|
||||
if !ok {
|
||||
return ""
|
||||
}
|
||||
var s string
|
||||
if err := json.Unmarshal(raw, &s); err != nil {
|
||||
return ""
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
// extractImageKey extracts the image_key from a Feishu image message content JSON.
|
||||
// Format: {"image_key": "img_xxx"}
|
||||
func extractImageKey(content string) string { return extractJSONStringField(content, "image_key") }
|
||||
|
||||
// extractFileKey extracts the file_key from a Feishu file/audio message content JSON.
|
||||
// Format: {"file_key": "file_xxx", "file_name": "...", ...}
|
||||
func extractFileKey(content string) string { return extractJSONStringField(content, "file_key") }
|
||||
|
||||
// extractFileName extracts the file_name from a Feishu file message content JSON.
|
||||
func extractFileName(content string) string { return extractJSONStringField(content, "file_name") }
|
||||
|
||||
// stripMentionPlaceholders removes @_user_N placeholders from the text content.
|
||||
// These are inserted by Feishu when users @mention someone in a message.
|
||||
func stripMentionPlaceholders(content string, mentions []*larkim.MentionEvent) string {
|
||||
if len(mentions) == 0 {
|
||||
return content
|
||||
}
|
||||
for _, m := range mentions {
|
||||
if m.Key != nil && *m.Key != "" {
|
||||
content = strings.ReplaceAll(content, *m.Key, "")
|
||||
}
|
||||
}
|
||||
// Also clean up any remaining @_user_N patterns
|
||||
content = mentionPlaceholderRegex.ReplaceAllString(content, "")
|
||||
return strings.TrimSpace(content)
|
||||
}
|
||||
|
||||
@@ -0,0 +1,292 @@
|
||||
package feishu
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
larkim "github.com/larksuite/oapi-sdk-go/v3/service/im/v1"
|
||||
)
|
||||
|
||||
func TestExtractJSONStringField(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
content string
|
||||
field string
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "valid field",
|
||||
content: `{"image_key": "img_v2_xxx"}`,
|
||||
field: "image_key",
|
||||
want: "img_v2_xxx",
|
||||
},
|
||||
{
|
||||
name: "missing field",
|
||||
content: `{"image_key": "img_v2_xxx"}`,
|
||||
field: "file_key",
|
||||
want: "",
|
||||
},
|
||||
{
|
||||
name: "invalid JSON",
|
||||
content: `not json at all`,
|
||||
field: "image_key",
|
||||
want: "",
|
||||
},
|
||||
{
|
||||
name: "empty content",
|
||||
content: "",
|
||||
field: "image_key",
|
||||
want: "",
|
||||
},
|
||||
{
|
||||
name: "non-string field value",
|
||||
content: `{"count": 42}`,
|
||||
field: "count",
|
||||
want: "",
|
||||
},
|
||||
{
|
||||
name: "empty string value",
|
||||
content: `{"image_key": ""}`,
|
||||
field: "image_key",
|
||||
want: "",
|
||||
},
|
||||
{
|
||||
name: "multiple fields",
|
||||
content: `{"file_key": "file_xxx", "file_name": "test.pdf"}`,
|
||||
field: "file_name",
|
||||
want: "test.pdf",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := extractJSONStringField(tt.content, tt.field)
|
||||
if got != tt.want {
|
||||
t.Errorf("extractJSONStringField(%q, %q) = %q, want %q", tt.content, tt.field, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractImageKey(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
content string
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "normal",
|
||||
content: `{"image_key": "img_v2_abc123"}`,
|
||||
want: "img_v2_abc123",
|
||||
},
|
||||
{
|
||||
name: "missing key",
|
||||
content: `{"file_key": "file_xxx"}`,
|
||||
want: "",
|
||||
},
|
||||
{
|
||||
name: "malformed JSON",
|
||||
content: `{broken`,
|
||||
want: "",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := extractImageKey(tt.content)
|
||||
if got != tt.want {
|
||||
t.Errorf("extractImageKey(%q) = %q, want %q", tt.content, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractFileKey(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
content string
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "normal",
|
||||
content: `{"file_key": "file_v2_abc123", "file_name": "test.doc"}`,
|
||||
want: "file_v2_abc123",
|
||||
},
|
||||
{
|
||||
name: "missing key",
|
||||
content: `{"image_key": "img_xxx"}`,
|
||||
want: "",
|
||||
},
|
||||
{
|
||||
name: "malformed JSON",
|
||||
content: `not json`,
|
||||
want: "",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := extractFileKey(tt.content)
|
||||
if got != tt.want {
|
||||
t.Errorf("extractFileKey(%q) = %q, want %q", tt.content, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractFileName(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
content string
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "normal",
|
||||
content: `{"file_key": "file_xxx", "file_name": "report.pdf"}`,
|
||||
want: "report.pdf",
|
||||
},
|
||||
{
|
||||
name: "missing name",
|
||||
content: `{"file_key": "file_xxx"}`,
|
||||
want: "",
|
||||
},
|
||||
{
|
||||
name: "malformed JSON",
|
||||
content: `{bad`,
|
||||
want: "",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := extractFileName(tt.content)
|
||||
if got != tt.want {
|
||||
t.Errorf("extractFileName(%q) = %q, want %q", tt.content, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildMarkdownCard(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
content string
|
||||
}{
|
||||
{
|
||||
name: "normal content",
|
||||
content: "Hello **world**",
|
||||
},
|
||||
{
|
||||
name: "empty content",
|
||||
content: "",
|
||||
},
|
||||
{
|
||||
name: "special characters",
|
||||
content: `Code: "foo" & <bar> 'baz'`,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result, err := buildMarkdownCard(tt.content)
|
||||
if err != nil {
|
||||
t.Fatalf("buildMarkdownCard(%q) unexpected error: %v", tt.content, err)
|
||||
}
|
||||
|
||||
// Verify valid JSON
|
||||
var parsed map[string]any
|
||||
if err := json.Unmarshal([]byte(result), &parsed); err != nil {
|
||||
t.Fatalf("buildMarkdownCard(%q) produced invalid JSON: %v", tt.content, err)
|
||||
}
|
||||
|
||||
// Verify schema
|
||||
if parsed["schema"] != "2.0" {
|
||||
t.Errorf("schema = %v, want %q", parsed["schema"], "2.0")
|
||||
}
|
||||
|
||||
// Verify body.elements[0].content == input
|
||||
body, ok := parsed["body"].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatal("missing body in card JSON")
|
||||
}
|
||||
elements, ok := body["elements"].([]any)
|
||||
if !ok || len(elements) == 0 {
|
||||
t.Fatal("missing or empty elements in card JSON")
|
||||
}
|
||||
elem, ok := elements[0].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatal("first element is not an object")
|
||||
}
|
||||
if elem["tag"] != "markdown" {
|
||||
t.Errorf("tag = %v, want %q", elem["tag"], "markdown")
|
||||
}
|
||||
if elem["content"] != tt.content {
|
||||
t.Errorf("content = %v, want %q", elem["content"], tt.content)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestStripMentionPlaceholders(t *testing.T) {
|
||||
strPtr := func(s string) *string { return &s }
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
content string
|
||||
mentions []*larkim.MentionEvent
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "no mentions",
|
||||
content: "Hello world",
|
||||
mentions: nil,
|
||||
want: "Hello world",
|
||||
},
|
||||
{
|
||||
name: "single mention",
|
||||
content: "@_user_1 hello",
|
||||
mentions: []*larkim.MentionEvent{
|
||||
{Key: strPtr("@_user_1")},
|
||||
},
|
||||
want: "hello",
|
||||
},
|
||||
{
|
||||
name: "multiple mentions",
|
||||
content: "@_user_1 @_user_2 hey",
|
||||
mentions: []*larkim.MentionEvent{
|
||||
{Key: strPtr("@_user_1")},
|
||||
{Key: strPtr("@_user_2")},
|
||||
},
|
||||
want: "hey",
|
||||
},
|
||||
{
|
||||
name: "empty content",
|
||||
content: "",
|
||||
mentions: []*larkim.MentionEvent{{Key: strPtr("@_user_1")}},
|
||||
want: "",
|
||||
},
|
||||
{
|
||||
name: "empty mentions slice",
|
||||
content: "@_user_1 test",
|
||||
mentions: []*larkim.MentionEvent{},
|
||||
want: "@_user_1 test",
|
||||
},
|
||||
{
|
||||
name: "mention with nil key",
|
||||
content: "@_user_1 test",
|
||||
mentions: []*larkim.MentionEvent{
|
||||
{Key: nil},
|
||||
},
|
||||
want: "test",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := stripMentionPlaceholders(tt.content, tt.mentions)
|
||||
if got != tt.want {
|
||||
t.Errorf("stripMentionPlaceholders(%q, ...) = %q, want %q", tt.content, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -16,6 +16,8 @@ type FeishuChannel struct {
|
||||
*channels.BaseChannel
|
||||
}
|
||||
|
||||
var errUnsupported = errors.New("feishu channel is not supported on 32-bit architectures")
|
||||
|
||||
// NewFeishuChannel returns an error on 32-bit architectures where the Feishu SDK is not supported
|
||||
func NewFeishuChannel(cfg config.FeishuConfig, bus *bus.MessageBus) (*FeishuChannel, error) {
|
||||
return nil, errors.New(
|
||||
@@ -25,15 +27,35 @@ func NewFeishuChannel(cfg config.FeishuConfig, bus *bus.MessageBus) (*FeishuChan
|
||||
|
||||
// Start is a stub method to satisfy the Channel interface
|
||||
func (c *FeishuChannel) Start(ctx context.Context) error {
|
||||
return nil
|
||||
return errUnsupported
|
||||
}
|
||||
|
||||
// Stop is a stub method to satisfy the Channel interface
|
||||
func (c *FeishuChannel) Stop(ctx context.Context) error {
|
||||
return nil
|
||||
return errUnsupported
|
||||
}
|
||||
|
||||
// Send is a stub method to satisfy the Channel interface
|
||||
func (c *FeishuChannel) Send(ctx context.Context, msg bus.OutboundMessage) error {
|
||||
return errors.New("feishu channel is not supported on 32-bit architectures")
|
||||
return errUnsupported
|
||||
}
|
||||
|
||||
// EditMessage is a stub method to satisfy MessageEditor
|
||||
func (c *FeishuChannel) EditMessage(ctx context.Context, chatID, messageID, content string) error {
|
||||
return errUnsupported
|
||||
}
|
||||
|
||||
// SendPlaceholder is a stub method to satisfy PlaceholderCapable
|
||||
func (c *FeishuChannel) SendPlaceholder(ctx context.Context, chatID string) (string, error) {
|
||||
return "", errUnsupported
|
||||
}
|
||||
|
||||
// ReactToMessage is a stub method to satisfy ReactionCapable
|
||||
func (c *FeishuChannel) ReactToMessage(ctx context.Context, chatID, messageID string) (func(), error) {
|
||||
return func() {}, errUnsupported
|
||||
}
|
||||
|
||||
// SendMedia is a stub method to satisfy MediaSender
|
||||
func (c *FeishuChannel) SendMedia(ctx context.Context, msg bus.OutboundMediaMessage) error {
|
||||
return errUnsupported
|
||||
}
|
||||
|
||||
@@ -6,10 +6,15 @@ import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sync"
|
||||
"time"
|
||||
"sync/atomic"
|
||||
|
||||
lark "github.com/larksuite/oapi-sdk-go/v3"
|
||||
larkcore "github.com/larksuite/oapi-sdk-go/v3/core"
|
||||
larkdispatcher "github.com/larksuite/oapi-sdk-go/v3/event/dispatcher"
|
||||
larkim "github.com/larksuite/oapi-sdk-go/v3/service/im/v1"
|
||||
larkws "github.com/larksuite/oapi-sdk-go/v3/ws"
|
||||
@@ -19,6 +24,7 @@ import (
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
"github.com/sipeed/picoclaw/pkg/identity"
|
||||
"github.com/sipeed/picoclaw/pkg/logger"
|
||||
"github.com/sipeed/picoclaw/pkg/media"
|
||||
"github.com/sipeed/picoclaw/pkg/utils"
|
||||
)
|
||||
|
||||
@@ -28,6 +34,8 @@ type FeishuChannel struct {
|
||||
client *lark.Client
|
||||
wsClient *larkws.Client
|
||||
|
||||
botOpenID atomic.Value // stores string; populated lazily for @mention detection
|
||||
|
||||
mu sync.Mutex
|
||||
cancel context.CancelFunc
|
||||
}
|
||||
@@ -38,11 +46,13 @@ func NewFeishuChannel(cfg config.FeishuConfig, bus *bus.MessageBus) (*FeishuChan
|
||||
channels.WithReasoningChannelID(cfg.ReasoningChannelID),
|
||||
)
|
||||
|
||||
return &FeishuChannel{
|
||||
ch := &FeishuChannel{
|
||||
BaseChannel: base,
|
||||
config: cfg,
|
||||
client: lark.NewClient(cfg.AppID, cfg.AppSecret),
|
||||
}, nil
|
||||
}
|
||||
ch.SetOwner(ch)
|
||||
return ch, nil
|
||||
}
|
||||
|
||||
func (c *FeishuChannel) Start(ctx context.Context) error {
|
||||
@@ -50,6 +60,13 @@ func (c *FeishuChannel) Start(ctx context.Context) error {
|
||||
return fmt.Errorf("feishu app_id or app_secret is empty")
|
||||
}
|
||||
|
||||
// Fetch bot open_id via API for reliable @mention detection.
|
||||
if err := c.fetchBotOpenID(ctx); err != nil {
|
||||
logger.ErrorCF("feishu", "Failed to fetch bot open_id, @mention detection may not work", map[string]any{
|
||||
"error": err.Error(),
|
||||
})
|
||||
}
|
||||
|
||||
dispatcher := larkdispatcher.NewEventDispatcher(c.config.VerificationToken, c.config.EncryptKey).
|
||||
OnP2MessageReceiveV1(c.handleMessageReceive)
|
||||
|
||||
@@ -93,46 +110,213 @@ func (c *FeishuChannel) Stop(ctx context.Context) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Send sends a message using Interactive Card format for markdown rendering.
|
||||
func (c *FeishuChannel) Send(ctx context.Context, msg bus.OutboundMessage) error {
|
||||
if !c.IsRunning() {
|
||||
return channels.ErrNotRunning
|
||||
}
|
||||
|
||||
if msg.ChatID == "" {
|
||||
return fmt.Errorf("chat ID is empty")
|
||||
return fmt.Errorf("chat ID is empty: %w", channels.ErrSendFailed)
|
||||
}
|
||||
|
||||
payload, err := json.Marshal(map[string]string{"text": msg.Content})
|
||||
// Build interactive card with markdown content
|
||||
cardContent, err := buildMarkdownCard(msg.Content)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal feishu content: %w", err)
|
||||
return fmt.Errorf("feishu send: card build failed: %w", err)
|
||||
}
|
||||
return c.sendCard(ctx, msg.ChatID, cardContent)
|
||||
}
|
||||
|
||||
// EditMessage implements channels.MessageEditor.
|
||||
// Uses Message.Patch to update an interactive card message.
|
||||
func (c *FeishuChannel) EditMessage(ctx context.Context, chatID, messageID, content string) error {
|
||||
cardContent, err := buildMarkdownCard(content)
|
||||
if err != nil {
|
||||
return fmt.Errorf("feishu edit: card build failed: %w", err)
|
||||
}
|
||||
|
||||
req := larkim.NewPatchMessageReqBuilder().
|
||||
MessageId(messageID).
|
||||
Body(larkim.NewPatchMessageReqBodyBuilder().Content(cardContent).Build()).
|
||||
Build()
|
||||
|
||||
resp, err := c.client.Im.V1.Message.Patch(ctx, req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("feishu edit: %w", err)
|
||||
}
|
||||
if !resp.Success() {
|
||||
return fmt.Errorf("feishu edit api error (code=%d msg=%s)", resp.Code, resp.Msg)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// SendPlaceholder implements channels.PlaceholderCapable.
|
||||
// Sends an interactive card with placeholder text and returns its message ID.
|
||||
func (c *FeishuChannel) SendPlaceholder(ctx context.Context, chatID string) (string, error) {
|
||||
if !c.config.Placeholder.Enabled {
|
||||
logger.DebugCF("feishu", "Placeholder disabled, skipping", map[string]any{
|
||||
"chat_id": chatID,
|
||||
})
|
||||
return "", nil
|
||||
}
|
||||
|
||||
text := c.config.Placeholder.Text
|
||||
if text == "" {
|
||||
text = "Thinking..."
|
||||
}
|
||||
|
||||
cardContent, err := buildMarkdownCard(text)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("feishu placeholder: card build failed: %w", err)
|
||||
}
|
||||
|
||||
req := larkim.NewCreateMessageReqBuilder().
|
||||
ReceiveIdType(larkim.ReceiveIdTypeChatId).
|
||||
Body(larkim.NewCreateMessageReqBodyBuilder().
|
||||
ReceiveId(msg.ChatID).
|
||||
MsgType(larkim.MsgTypeText).
|
||||
Content(string(payload)).
|
||||
Uuid(fmt.Sprintf("picoclaw-%d", time.Now().UnixNano())).
|
||||
ReceiveId(chatID).
|
||||
MsgType(larkim.MsgTypeInteractive).
|
||||
Content(cardContent).
|
||||
Build()).
|
||||
Build()
|
||||
|
||||
resp, err := c.client.Im.V1.Message.Create(ctx, req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("feishu send: %w", channels.ErrTemporary)
|
||||
return "", fmt.Errorf("feishu placeholder send: %w", err)
|
||||
}
|
||||
|
||||
if !resp.Success() {
|
||||
return fmt.Errorf("feishu api error (code=%d msg=%s): %w", resp.Code, resp.Msg, channels.ErrTemporary)
|
||||
return "", fmt.Errorf("feishu placeholder api error (code=%d msg=%s)", resp.Code, resp.Msg)
|
||||
}
|
||||
|
||||
logger.DebugCF("feishu", "Feishu message sent", map[string]any{
|
||||
"chat_id": msg.ChatID,
|
||||
})
|
||||
if resp.Data != nil && resp.Data.MessageId != nil {
|
||||
return *resp.Data.MessageId, nil
|
||||
}
|
||||
return "", nil
|
||||
}
|
||||
|
||||
// ReactToMessage implements channels.ReactionCapable.
|
||||
// Adds an "Pin" reaction and returns an undo function to remove it.
|
||||
func (c *FeishuChannel) ReactToMessage(ctx context.Context, chatID, messageID string) (func(), error) {
|
||||
req := larkim.NewCreateMessageReactionReqBuilder().
|
||||
MessageId(messageID).
|
||||
Body(larkim.NewCreateMessageReactionReqBodyBuilder().
|
||||
ReactionType(larkim.NewEmojiBuilder().EmojiType("Pin").Build()).
|
||||
Build()).
|
||||
Build()
|
||||
|
||||
resp, err := c.client.Im.V1.MessageReaction.Create(ctx, req)
|
||||
if err != nil {
|
||||
logger.ErrorCF("feishu", "Failed to add reaction", map[string]any{
|
||||
"message_id": messageID,
|
||||
"error": err.Error(),
|
||||
})
|
||||
return func() {}, fmt.Errorf("feishu react: %w", err)
|
||||
}
|
||||
if !resp.Success() {
|
||||
logger.ErrorCF("feishu", "Reaction API error", map[string]any{
|
||||
"message_id": messageID,
|
||||
"code": resp.Code,
|
||||
"msg": resp.Msg,
|
||||
})
|
||||
return func() {}, fmt.Errorf("feishu react api error (code=%d msg=%s)", resp.Code, resp.Msg)
|
||||
}
|
||||
|
||||
var reactionID string
|
||||
if resp.Data != nil && resp.Data.ReactionId != nil {
|
||||
reactionID = *resp.Data.ReactionId
|
||||
}
|
||||
if reactionID == "" {
|
||||
return func() {}, nil
|
||||
}
|
||||
|
||||
var undone atomic.Bool
|
||||
undo := func() {
|
||||
if !undone.CompareAndSwap(false, true) {
|
||||
return
|
||||
}
|
||||
delReq := larkim.NewDeleteMessageReactionReqBuilder().
|
||||
MessageId(messageID).
|
||||
ReactionId(reactionID).
|
||||
Build()
|
||||
_, _ = c.client.Im.V1.MessageReaction.Delete(context.Background(), delReq)
|
||||
}
|
||||
return undo, nil
|
||||
}
|
||||
|
||||
// SendMedia implements channels.MediaSender.
|
||||
// Uploads images/files via Feishu API then sends as messages.
|
||||
func (c *FeishuChannel) SendMedia(ctx context.Context, msg bus.OutboundMediaMessage) error {
|
||||
if !c.IsRunning() {
|
||||
return channels.ErrNotRunning
|
||||
}
|
||||
|
||||
if msg.ChatID == "" {
|
||||
return fmt.Errorf("chat ID is empty: %w", channels.ErrSendFailed)
|
||||
}
|
||||
|
||||
store := c.GetMediaStore()
|
||||
if store == nil {
|
||||
return fmt.Errorf("no media store available: %w", channels.ErrSendFailed)
|
||||
}
|
||||
|
||||
for _, part := range msg.Parts {
|
||||
if err := c.sendMediaPart(ctx, msg.ChatID, part, store); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// sendMediaPart resolves and sends a single media part.
|
||||
func (c *FeishuChannel) sendMediaPart(
|
||||
ctx context.Context,
|
||||
chatID string,
|
||||
part bus.MediaPart,
|
||||
store media.MediaStore,
|
||||
) error {
|
||||
localPath, err := store.Resolve(part.Ref)
|
||||
if err != nil {
|
||||
logger.ErrorCF("feishu", "Failed to resolve media ref", map[string]any{
|
||||
"ref": part.Ref,
|
||||
"error": err.Error(),
|
||||
})
|
||||
return nil // skip this part
|
||||
}
|
||||
|
||||
file, err := os.Open(localPath)
|
||||
if err != nil {
|
||||
logger.ErrorCF("feishu", "Failed to open media file", map[string]any{
|
||||
"path": localPath,
|
||||
"error": err.Error(),
|
||||
})
|
||||
return nil // skip this part
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
switch part.Type {
|
||||
case "image":
|
||||
err = c.sendImage(ctx, chatID, file)
|
||||
default:
|
||||
filename := part.Filename
|
||||
if filename == "" {
|
||||
filename = "file"
|
||||
}
|
||||
err = c.sendFile(ctx, chatID, file, filename, part.Type)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
logger.ErrorCF("feishu", "Failed to send media", map[string]any{
|
||||
"type": part.Type,
|
||||
"error": err.Error(),
|
||||
})
|
||||
return fmt.Errorf("feishu send media: %w", channels.ErrTemporary)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// --- Inbound message handling ---
|
||||
|
||||
func (c *FeishuChannel) handleMessageReceive(ctx context.Context, event *larkim.P2MessageReceiveV1) error {
|
||||
if event == nil || event.Event == nil || event.Event.Message == nil {
|
||||
return nil
|
||||
@@ -151,34 +335,68 @@ func (c *FeishuChannel) handleMessageReceive(ctx context.Context, event *larkim.
|
||||
senderID = "unknown"
|
||||
}
|
||||
|
||||
content := extractFeishuMessageContent(message)
|
||||
messageType := stringValue(message.MessageType)
|
||||
messageID := stringValue(message.MessageId)
|
||||
rawContent := stringValue(message.Content)
|
||||
|
||||
// Check allowlist early to avoid downloading media for rejected senders.
|
||||
// BaseChannel.HandleMessage will check again, but this avoids wasted network I/O.
|
||||
senderInfo := bus.SenderInfo{
|
||||
Platform: "feishu",
|
||||
PlatformID: senderID,
|
||||
CanonicalID: identity.BuildCanonicalID("feishu", senderID),
|
||||
}
|
||||
if !c.IsAllowedSender(senderInfo) {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Extract content based on message type
|
||||
content := extractContent(messageType, rawContent)
|
||||
|
||||
// Handle media messages (download and store)
|
||||
var mediaRefs []string
|
||||
if store := c.GetMediaStore(); store != nil && messageID != "" {
|
||||
mediaRefs = c.downloadInboundMedia(ctx, chatID, messageID, messageType, rawContent, store)
|
||||
}
|
||||
|
||||
// Append media tags to content (like Telegram does)
|
||||
content = appendMediaTags(content, messageType, mediaRefs)
|
||||
|
||||
if content == "" {
|
||||
content = "[empty message]"
|
||||
}
|
||||
|
||||
metadata := map[string]string{}
|
||||
messageID := ""
|
||||
if mid := stringValue(message.MessageId); mid != "" {
|
||||
messageID = mid
|
||||
if messageID != "" {
|
||||
metadata["message_id"] = messageID
|
||||
}
|
||||
if messageType := stringValue(message.MessageType); messageType != "" {
|
||||
if messageType != "" {
|
||||
metadata["message_type"] = messageType
|
||||
}
|
||||
if chatType := stringValue(message.ChatType); chatType != "" {
|
||||
chatType := stringValue(message.ChatType)
|
||||
if chatType != "" {
|
||||
metadata["chat_type"] = chatType
|
||||
}
|
||||
if sender != nil && sender.TenantKey != nil {
|
||||
metadata["tenant_key"] = *sender.TenantKey
|
||||
}
|
||||
|
||||
chatType := stringValue(message.ChatType)
|
||||
var peer bus.Peer
|
||||
if chatType == "p2p" {
|
||||
peer = bus.Peer{Kind: "direct", ID: senderID}
|
||||
} else {
|
||||
peer = bus.Peer{Kind: "group", ID: chatID}
|
||||
|
||||
// Check if bot was mentioned
|
||||
isMentioned := c.isBotMentioned(message)
|
||||
|
||||
// Strip mention placeholders from content before group trigger check
|
||||
if len(message.Mentions) > 0 {
|
||||
content = stripMentionPlaceholders(content, message.Mentions)
|
||||
}
|
||||
|
||||
// In group chats, apply unified group trigger filtering
|
||||
respond, cleaned := c.ShouldRespondInGroup(false, content)
|
||||
respond, cleaned := c.ShouldRespondInGroup(isMentioned, content)
|
||||
if !respond {
|
||||
return nil
|
||||
}
|
||||
@@ -186,22 +404,398 @@ func (c *FeishuChannel) handleMessageReceive(ctx context.Context, event *larkim.
|
||||
}
|
||||
|
||||
logger.InfoCF("feishu", "Feishu message received", map[string]any{
|
||||
"sender_id": senderID,
|
||||
"chat_id": chatID,
|
||||
"preview": utils.Truncate(content, 80),
|
||||
"sender_id": senderID,
|
||||
"chat_id": chatID,
|
||||
"message_id": messageID,
|
||||
"preview": utils.Truncate(content, 80),
|
||||
})
|
||||
|
||||
senderInfo := bus.SenderInfo{
|
||||
Platform: "feishu",
|
||||
PlatformID: senderID,
|
||||
CanonicalID: identity.BuildCanonicalID("feishu", senderID),
|
||||
c.HandleMessage(ctx, peer, messageID, senderID, chatID, content, mediaRefs, metadata, senderInfo)
|
||||
return nil
|
||||
}
|
||||
|
||||
// --- Internal helpers ---
|
||||
|
||||
// fetchBotOpenID calls the Feishu bot info API to retrieve and store the bot's open_id.
|
||||
func (c *FeishuChannel) fetchBotOpenID(ctx context.Context) error {
|
||||
resp, err := c.client.Do(ctx, &larkcore.ApiReq{
|
||||
HttpMethod: http.MethodGet,
|
||||
ApiPath: "/open-apis/bot/v3/info",
|
||||
SupportedAccessTokenTypes: []larkcore.AccessTokenType{larkcore.AccessTokenTypeTenant},
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("bot info request: %w", err)
|
||||
}
|
||||
|
||||
if !c.IsAllowedSender(senderInfo) {
|
||||
return nil
|
||||
var result struct {
|
||||
Code int `json:"code"`
|
||||
Bot struct {
|
||||
OpenID string `json:"open_id"`
|
||||
} `json:"bot"`
|
||||
}
|
||||
if err := json.Unmarshal(resp.RawBody, &result); err != nil {
|
||||
return fmt.Errorf("bot info parse: %w", err)
|
||||
}
|
||||
if result.Code != 0 {
|
||||
return fmt.Errorf("bot info api error (code=%d)", result.Code)
|
||||
}
|
||||
if result.Bot.OpenID == "" {
|
||||
return fmt.Errorf("bot info: empty open_id")
|
||||
}
|
||||
|
||||
c.HandleMessage(ctx, peer, messageID, senderID, chatID, content, nil, metadata, senderInfo)
|
||||
c.botOpenID.Store(result.Bot.OpenID)
|
||||
logger.InfoCF("feishu", "Fetched bot open_id from API", map[string]any{
|
||||
"open_id": result.Bot.OpenID,
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
// isBotMentioned checks if the bot was @mentioned in the message.
|
||||
func (c *FeishuChannel) isBotMentioned(message *larkim.EventMessage) bool {
|
||||
if message.Mentions == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
knownID, _ := c.botOpenID.Load().(string)
|
||||
if knownID == "" {
|
||||
logger.DebugCF("feishu", "Bot open_id unknown, cannot detect @mention", nil)
|
||||
return false
|
||||
}
|
||||
|
||||
for _, m := range message.Mentions {
|
||||
if m.Id == nil {
|
||||
continue
|
||||
}
|
||||
if m.Id.OpenId != nil && *m.Id.OpenId == knownID {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// extractContent extracts text content from different message types.
|
||||
func extractContent(messageType, rawContent string) string {
|
||||
if rawContent == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
switch messageType {
|
||||
case larkim.MsgTypeText:
|
||||
var textPayload struct {
|
||||
Text string `json:"text"`
|
||||
}
|
||||
if err := json.Unmarshal([]byte(rawContent), &textPayload); err == nil {
|
||||
return textPayload.Text
|
||||
}
|
||||
return rawContent
|
||||
|
||||
case larkim.MsgTypePost:
|
||||
// Pass raw JSON to LLM — structured rich text is more informative than flattened plain text
|
||||
return rawContent
|
||||
|
||||
case larkim.MsgTypeImage:
|
||||
// Image messages don't have text content
|
||||
return ""
|
||||
|
||||
case larkim.MsgTypeFile, larkim.MsgTypeAudio, larkim.MsgTypeMedia:
|
||||
// File/audio/video messages may have a filename
|
||||
name := extractFileName(rawContent)
|
||||
if name != "" {
|
||||
return name
|
||||
}
|
||||
return ""
|
||||
|
||||
default:
|
||||
return rawContent
|
||||
}
|
||||
}
|
||||
|
||||
// downloadInboundMedia downloads media from inbound messages and stores in MediaStore.
|
||||
func (c *FeishuChannel) downloadInboundMedia(
|
||||
ctx context.Context,
|
||||
chatID, messageID, messageType, rawContent string,
|
||||
store media.MediaStore,
|
||||
) []string {
|
||||
var refs []string
|
||||
scope := channels.BuildMediaScope("feishu", chatID, messageID)
|
||||
|
||||
switch messageType {
|
||||
case larkim.MsgTypeImage:
|
||||
imageKey := extractImageKey(rawContent)
|
||||
if imageKey == "" {
|
||||
return nil
|
||||
}
|
||||
ref := c.downloadResource(ctx, messageID, imageKey, "image", ".jpg", store, scope)
|
||||
if ref != "" {
|
||||
refs = append(refs, ref)
|
||||
}
|
||||
|
||||
case larkim.MsgTypeFile, larkim.MsgTypeAudio, larkim.MsgTypeMedia:
|
||||
fileKey := extractFileKey(rawContent)
|
||||
if fileKey == "" {
|
||||
return nil
|
||||
}
|
||||
// Derive a fallback extension from the message type.
|
||||
var ext string
|
||||
switch messageType {
|
||||
case larkim.MsgTypeAudio:
|
||||
ext = ".ogg"
|
||||
case larkim.MsgTypeMedia:
|
||||
ext = ".mp4"
|
||||
default:
|
||||
ext = "" // generic file — rely on resp.FileName
|
||||
}
|
||||
ref := c.downloadResource(ctx, messageID, fileKey, "file", ext, store, scope)
|
||||
if ref != "" {
|
||||
refs = append(refs, ref)
|
||||
}
|
||||
}
|
||||
|
||||
return refs
|
||||
}
|
||||
|
||||
// downloadResource downloads a message resource (image/file) from Feishu,
|
||||
// writes it to the project media directory, and stores the reference in MediaStore.
|
||||
// fallbackExt (e.g. ".jpg") is appended when the resolved filename has no extension.
|
||||
func (c *FeishuChannel) downloadResource(
|
||||
ctx context.Context,
|
||||
messageID, fileKey, resourceType, fallbackExt string,
|
||||
store media.MediaStore,
|
||||
scope string,
|
||||
) string {
|
||||
req := larkim.NewGetMessageResourceReqBuilder().
|
||||
MessageId(messageID).
|
||||
FileKey(fileKey).
|
||||
Type(resourceType).
|
||||
Build()
|
||||
|
||||
resp, err := c.client.Im.V1.MessageResource.Get(ctx, req)
|
||||
if err != nil {
|
||||
logger.ErrorCF("feishu", "Failed to download resource", map[string]any{
|
||||
"message_id": messageID,
|
||||
"file_key": fileKey,
|
||||
"error": err.Error(),
|
||||
})
|
||||
return ""
|
||||
}
|
||||
if !resp.Success() {
|
||||
logger.ErrorCF("feishu", "Resource download api error", map[string]any{
|
||||
"code": resp.Code,
|
||||
"msg": resp.Msg,
|
||||
})
|
||||
return ""
|
||||
}
|
||||
|
||||
if resp.File == nil {
|
||||
return ""
|
||||
}
|
||||
// Safely close the underlying reader if it implements io.Closer (e.g. HTTP response body).
|
||||
if closer, ok := resp.File.(io.Closer); ok {
|
||||
defer closer.Close()
|
||||
}
|
||||
|
||||
filename := resp.FileName
|
||||
if filename == "" {
|
||||
filename = fileKey
|
||||
}
|
||||
// If filename still has no extension, append the fallback (like Telegram's ext parameter).
|
||||
if filepath.Ext(filename) == "" && fallbackExt != "" {
|
||||
filename += fallbackExt
|
||||
}
|
||||
|
||||
// Write to the shared picoclaw_media directory using a unique name to avoid collisions.
|
||||
mediaDir := filepath.Join(os.TempDir(), "picoclaw_media")
|
||||
if mkdirErr := os.MkdirAll(mediaDir, 0o700); mkdirErr != nil {
|
||||
logger.ErrorCF("feishu", "Failed to create media directory", map[string]any{
|
||||
"error": mkdirErr.Error(),
|
||||
})
|
||||
return ""
|
||||
}
|
||||
ext := filepath.Ext(filename)
|
||||
localPath := filepath.Join(mediaDir, utils.SanitizeFilename(messageID+"-"+fileKey+ext))
|
||||
|
||||
out, err := os.Create(localPath)
|
||||
if err != nil {
|
||||
logger.ErrorCF("feishu", "Failed to create local file for resource", map[string]any{
|
||||
"error": err.Error(),
|
||||
})
|
||||
return ""
|
||||
}
|
||||
|
||||
if _, copyErr := io.Copy(out, resp.File); copyErr != nil {
|
||||
out.Close()
|
||||
os.Remove(localPath)
|
||||
logger.ErrorCF("feishu", "Failed to write resource to file", map[string]any{
|
||||
"error": copyErr.Error(),
|
||||
})
|
||||
return ""
|
||||
}
|
||||
out.Close()
|
||||
|
||||
ref, err := store.Store(localPath, media.MediaMeta{
|
||||
Filename: filename,
|
||||
Source: "feishu",
|
||||
}, scope)
|
||||
if err != nil {
|
||||
logger.ErrorCF("feishu", "Failed to store downloaded resource", map[string]any{
|
||||
"file_key": fileKey,
|
||||
"error": err.Error(),
|
||||
})
|
||||
os.Remove(localPath)
|
||||
return ""
|
||||
}
|
||||
|
||||
return ref
|
||||
}
|
||||
|
||||
// appendMediaTags appends media type tags to content (like Telegram's "[image: photo]").
|
||||
func appendMediaTags(content, messageType string, mediaRefs []string) string {
|
||||
if len(mediaRefs) == 0 {
|
||||
return content
|
||||
}
|
||||
|
||||
var tag string
|
||||
switch messageType {
|
||||
case larkim.MsgTypeImage:
|
||||
tag = "[image: photo]"
|
||||
case larkim.MsgTypeAudio:
|
||||
tag = "[audio]"
|
||||
case larkim.MsgTypeMedia:
|
||||
tag = "[video]"
|
||||
case larkim.MsgTypeFile:
|
||||
tag = "[file]"
|
||||
default:
|
||||
tag = "[attachment]"
|
||||
}
|
||||
|
||||
if content == "" {
|
||||
return tag
|
||||
}
|
||||
return content + " " + tag
|
||||
}
|
||||
|
||||
// sendCard sends an interactive card message to a chat.
|
||||
func (c *FeishuChannel) sendCard(ctx context.Context, chatID, cardContent string) error {
|
||||
req := larkim.NewCreateMessageReqBuilder().
|
||||
ReceiveIdType(larkim.ReceiveIdTypeChatId).
|
||||
Body(larkim.NewCreateMessageReqBodyBuilder().
|
||||
ReceiveId(chatID).
|
||||
MsgType(larkim.MsgTypeInteractive).
|
||||
Content(cardContent).
|
||||
Build()).
|
||||
Build()
|
||||
|
||||
resp, err := c.client.Im.V1.Message.Create(ctx, req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("feishu send card: %w", channels.ErrTemporary)
|
||||
}
|
||||
|
||||
if !resp.Success() {
|
||||
return fmt.Errorf("feishu api error (code=%d msg=%s): %w", resp.Code, resp.Msg, channels.ErrTemporary)
|
||||
}
|
||||
|
||||
logger.DebugCF("feishu", "Feishu card message sent", map[string]any{
|
||||
"chat_id": chatID,
|
||||
})
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// sendImage uploads an image and sends it as a message.
|
||||
func (c *FeishuChannel) sendImage(ctx context.Context, chatID string, file *os.File) error {
|
||||
// Upload image to get image_key
|
||||
uploadReq := larkim.NewCreateImageReqBuilder().
|
||||
Body(larkim.NewCreateImageReqBodyBuilder().
|
||||
ImageType("message").
|
||||
Image(file).
|
||||
Build()).
|
||||
Build()
|
||||
|
||||
uploadResp, err := c.client.Im.V1.Image.Create(ctx, uploadReq)
|
||||
if err != nil {
|
||||
return fmt.Errorf("feishu image upload: %w", err)
|
||||
}
|
||||
if !uploadResp.Success() {
|
||||
return fmt.Errorf("feishu image upload api error (code=%d msg=%s)", uploadResp.Code, uploadResp.Msg)
|
||||
}
|
||||
if uploadResp.Data == nil || uploadResp.Data.ImageKey == nil {
|
||||
return fmt.Errorf("feishu image upload: no image_key returned")
|
||||
}
|
||||
|
||||
imageKey := *uploadResp.Data.ImageKey
|
||||
|
||||
// Send image message
|
||||
content, _ := json.Marshal(map[string]string{"image_key": imageKey})
|
||||
req := larkim.NewCreateMessageReqBuilder().
|
||||
ReceiveIdType(larkim.ReceiveIdTypeChatId).
|
||||
Body(larkim.NewCreateMessageReqBodyBuilder().
|
||||
ReceiveId(chatID).
|
||||
MsgType(larkim.MsgTypeImage).
|
||||
Content(string(content)).
|
||||
Build()).
|
||||
Build()
|
||||
|
||||
resp, err := c.client.Im.V1.Message.Create(ctx, req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("feishu image send: %w", err)
|
||||
}
|
||||
if !resp.Success() {
|
||||
return fmt.Errorf("feishu image send api error (code=%d msg=%s)", resp.Code, resp.Msg)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// sendFile uploads a file and sends it as a message.
|
||||
func (c *FeishuChannel) sendFile(ctx context.Context, chatID string, file *os.File, filename, fileType string) error {
|
||||
// Map part type to Feishu file type
|
||||
feishuFileType := "stream"
|
||||
switch fileType {
|
||||
case "audio":
|
||||
feishuFileType = "opus"
|
||||
case "video":
|
||||
feishuFileType = "mp4"
|
||||
}
|
||||
|
||||
// Upload file to get file_key
|
||||
uploadReq := larkim.NewCreateFileReqBuilder().
|
||||
Body(larkim.NewCreateFileReqBodyBuilder().
|
||||
FileType(feishuFileType).
|
||||
FileName(filename).
|
||||
File(file).
|
||||
Build()).
|
||||
Build()
|
||||
|
||||
uploadResp, err := c.client.Im.V1.File.Create(ctx, uploadReq)
|
||||
if err != nil {
|
||||
return fmt.Errorf("feishu file upload: %w", err)
|
||||
}
|
||||
if !uploadResp.Success() {
|
||||
return fmt.Errorf("feishu file upload api error (code=%d msg=%s)", uploadResp.Code, uploadResp.Msg)
|
||||
}
|
||||
if uploadResp.Data == nil || uploadResp.Data.FileKey == nil {
|
||||
return fmt.Errorf("feishu file upload: no file_key returned")
|
||||
}
|
||||
|
||||
fileKey := *uploadResp.Data.FileKey
|
||||
|
||||
// Send file message
|
||||
content, _ := json.Marshal(map[string]string{"file_key": fileKey})
|
||||
req := larkim.NewCreateMessageReqBuilder().
|
||||
ReceiveIdType(larkim.ReceiveIdTypeChatId).
|
||||
Body(larkim.NewCreateMessageReqBodyBuilder().
|
||||
ReceiveId(chatID).
|
||||
MsgType(larkim.MsgTypeFile).
|
||||
Content(string(content)).
|
||||
Build()).
|
||||
Build()
|
||||
|
||||
resp, err := c.client.Im.V1.Message.Create(ctx, req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("feishu file send: %w", err)
|
||||
}
|
||||
if !resp.Success() {
|
||||
return fmt.Errorf("feishu file send api error (code=%d msg=%s)", resp.Code, resp.Msg)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -222,20 +816,3 @@ func extractFeishuSenderID(sender *larkim.EventSender) string {
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
func extractFeishuMessageContent(message *larkim.EventMessage) string {
|
||||
if message == nil || message.Content == nil || *message.Content == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
if message.MessageType != nil && *message.MessageType == larkim.MsgTypeText {
|
||||
var textPayload struct {
|
||||
Text string `json:"text"`
|
||||
}
|
||||
if err := json.Unmarshal([]byte(*message.Content), &textPayload); err == nil {
|
||||
return textPayload.Text
|
||||
}
|
||||
}
|
||||
|
||||
return *message.Content
|
||||
}
|
||||
|
||||
@@ -0,0 +1,256 @@
|
||||
//go:build amd64 || arm64 || riscv64 || mips64 || ppc64
|
||||
|
||||
package feishu
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
larkim "github.com/larksuite/oapi-sdk-go/v3/service/im/v1"
|
||||
)
|
||||
|
||||
func TestExtractContent(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
messageType string
|
||||
rawContent string
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "text message",
|
||||
messageType: "text",
|
||||
rawContent: `{"text": "hello world"}`,
|
||||
want: "hello world",
|
||||
},
|
||||
{
|
||||
name: "text message invalid JSON",
|
||||
messageType: "text",
|
||||
rawContent: `not json`,
|
||||
want: "not json",
|
||||
},
|
||||
{
|
||||
name: "post message returns raw JSON",
|
||||
messageType: "post",
|
||||
rawContent: `{"title": "test post"}`,
|
||||
want: `{"title": "test post"}`,
|
||||
},
|
||||
{
|
||||
name: "image message returns empty",
|
||||
messageType: "image",
|
||||
rawContent: `{"image_key": "img_xxx"}`,
|
||||
want: "",
|
||||
},
|
||||
{
|
||||
name: "file message with filename",
|
||||
messageType: "file",
|
||||
rawContent: `{"file_key": "file_xxx", "file_name": "report.pdf"}`,
|
||||
want: "report.pdf",
|
||||
},
|
||||
{
|
||||
name: "file message without filename",
|
||||
messageType: "file",
|
||||
rawContent: `{"file_key": "file_xxx"}`,
|
||||
want: "",
|
||||
},
|
||||
{
|
||||
name: "audio message with filename",
|
||||
messageType: "audio",
|
||||
rawContent: `{"file_key": "file_xxx", "file_name": "recording.ogg"}`,
|
||||
want: "recording.ogg",
|
||||
},
|
||||
{
|
||||
name: "media message with filename",
|
||||
messageType: "media",
|
||||
rawContent: `{"file_key": "file_xxx", "file_name": "video.mp4"}`,
|
||||
want: "video.mp4",
|
||||
},
|
||||
{
|
||||
name: "unknown message type returns raw",
|
||||
messageType: "sticker",
|
||||
rawContent: `{"sticker_id": "sticker_xxx"}`,
|
||||
want: `{"sticker_id": "sticker_xxx"}`,
|
||||
},
|
||||
{
|
||||
name: "empty raw content",
|
||||
messageType: "text",
|
||||
rawContent: "",
|
||||
want: "",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := extractContent(tt.messageType, tt.rawContent)
|
||||
if got != tt.want {
|
||||
t.Errorf("extractContent(%q, %q) = %q, want %q", tt.messageType, tt.rawContent, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAppendMediaTags(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
content string
|
||||
messageType string
|
||||
mediaRefs []string
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "no refs returns content unchanged",
|
||||
content: "hello",
|
||||
messageType: "image",
|
||||
mediaRefs: nil,
|
||||
want: "hello",
|
||||
},
|
||||
{
|
||||
name: "empty refs returns content unchanged",
|
||||
content: "hello",
|
||||
messageType: "image",
|
||||
mediaRefs: []string{},
|
||||
want: "hello",
|
||||
},
|
||||
{
|
||||
name: "image with content",
|
||||
content: "check this",
|
||||
messageType: "image",
|
||||
mediaRefs: []string{"ref1"},
|
||||
want: "check this [image: photo]",
|
||||
},
|
||||
{
|
||||
name: "image empty content",
|
||||
content: "",
|
||||
messageType: "image",
|
||||
mediaRefs: []string{"ref1"},
|
||||
want: "[image: photo]",
|
||||
},
|
||||
{
|
||||
name: "audio",
|
||||
content: "listen",
|
||||
messageType: "audio",
|
||||
mediaRefs: []string{"ref1"},
|
||||
want: "listen [audio]",
|
||||
},
|
||||
{
|
||||
name: "media/video",
|
||||
content: "watch",
|
||||
messageType: "media",
|
||||
mediaRefs: []string{"ref1"},
|
||||
want: "watch [video]",
|
||||
},
|
||||
{
|
||||
name: "file",
|
||||
content: "report.pdf",
|
||||
messageType: "file",
|
||||
mediaRefs: []string{"ref1"},
|
||||
want: "report.pdf [file]",
|
||||
},
|
||||
{
|
||||
name: "unknown type",
|
||||
content: "something",
|
||||
messageType: "sticker",
|
||||
mediaRefs: []string{"ref1"},
|
||||
want: "something [attachment]",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := appendMediaTags(tt.content, tt.messageType, tt.mediaRefs)
|
||||
if got != tt.want {
|
||||
t.Errorf(
|
||||
"appendMediaTags(%q, %q, %v) = %q, want %q",
|
||||
tt.content,
|
||||
tt.messageType,
|
||||
tt.mediaRefs,
|
||||
got,
|
||||
tt.want,
|
||||
)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractFeishuSenderID(t *testing.T) {
|
||||
strPtr := func(s string) *string { return &s }
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
sender *larkim.EventSender
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "nil sender",
|
||||
sender: nil,
|
||||
want: "",
|
||||
},
|
||||
{
|
||||
name: "nil sender ID",
|
||||
sender: &larkim.EventSender{SenderId: nil},
|
||||
want: "",
|
||||
},
|
||||
{
|
||||
name: "userId preferred",
|
||||
sender: &larkim.EventSender{
|
||||
SenderId: &larkim.UserId{
|
||||
UserId: strPtr("u_abc123"),
|
||||
OpenId: strPtr("ou_def456"),
|
||||
UnionId: strPtr("on_ghi789"),
|
||||
},
|
||||
},
|
||||
want: "u_abc123",
|
||||
},
|
||||
{
|
||||
name: "openId fallback",
|
||||
sender: &larkim.EventSender{
|
||||
SenderId: &larkim.UserId{
|
||||
UserId: strPtr(""),
|
||||
OpenId: strPtr("ou_def456"),
|
||||
UnionId: strPtr("on_ghi789"),
|
||||
},
|
||||
},
|
||||
want: "ou_def456",
|
||||
},
|
||||
{
|
||||
name: "unionId fallback",
|
||||
sender: &larkim.EventSender{
|
||||
SenderId: &larkim.UserId{
|
||||
UserId: strPtr(""),
|
||||
OpenId: strPtr(""),
|
||||
UnionId: strPtr("on_ghi789"),
|
||||
},
|
||||
},
|
||||
want: "on_ghi789",
|
||||
},
|
||||
{
|
||||
name: "all empty strings",
|
||||
sender: &larkim.EventSender{
|
||||
SenderId: &larkim.UserId{
|
||||
UserId: strPtr(""),
|
||||
OpenId: strPtr(""),
|
||||
UnionId: strPtr(""),
|
||||
},
|
||||
},
|
||||
want: "",
|
||||
},
|
||||
{
|
||||
name: "nil userId pointer falls through",
|
||||
sender: &larkim.EventSender{
|
||||
SenderId: &larkim.UserId{
|
||||
UserId: nil,
|
||||
OpenId: strPtr("ou_def456"),
|
||||
UnionId: nil,
|
||||
},
|
||||
},
|
||||
want: "ou_def456",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := extractFeishuSenderID(tt.sender)
|
||||
if got != tt.want {
|
||||
t.Errorf("extractFeishuSenderID() = %q, want %q", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -45,11 +45,13 @@ type replyTokenEntry struct {
|
||||
type LINEChannel struct {
|
||||
*channels.BaseChannel
|
||||
config config.LINEConfig
|
||||
botUserID string // Bot's user ID
|
||||
botBasicID string // Bot's basic ID (e.g. @216ru...)
|
||||
botDisplayName string // Bot's display name for text-based mention detection
|
||||
replyTokens sync.Map // chatID -> replyTokenEntry
|
||||
quoteTokens sync.Map // chatID -> quoteToken (string)
|
||||
infoClient *http.Client // for bot info lookups (short timeout)
|
||||
apiClient *http.Client // for messaging API calls
|
||||
botUserID string // Bot's user ID
|
||||
botBasicID string // Bot's basic ID (e.g. @216ru...)
|
||||
botDisplayName string // Bot's display name for text-based mention detection
|
||||
replyTokens sync.Map // chatID -> replyTokenEntry
|
||||
quoteTokens sync.Map // chatID -> quoteToken (string)
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
}
|
||||
@@ -69,6 +71,8 @@ func NewLINEChannel(cfg config.LINEConfig, messageBus *bus.MessageBus) (*LINECha
|
||||
return &LINEChannel{
|
||||
BaseChannel: base,
|
||||
config: cfg,
|
||||
infoClient: &http.Client{Timeout: 10 * time.Second},
|
||||
apiClient: &http.Client{Timeout: 30 * time.Second},
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -104,8 +108,7 @@ func (c *LINEChannel) fetchBotInfo() error {
|
||||
}
|
||||
req.Header.Set("Authorization", "Bearer "+c.config.ChannelAccessToken)
|
||||
|
||||
client := &http.Client{Timeout: 10 * time.Second}
|
||||
resp, err := client.Do(req)
|
||||
resp, err := c.infoClient.Do(req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -644,8 +647,7 @@ func (c *LINEChannel) callAPI(ctx context.Context, endpoint string, payload any)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", "Bearer "+c.config.ChannelAccessToken)
|
||||
|
||||
client := &http.Client{Timeout: 30 * time.Second}
|
||||
resp, err := client.Do(req)
|
||||
resp, err := c.apiClient.Do(req)
|
||||
if err != nil {
|
||||
return channels.ClassifyNetError(err)
|
||||
}
|
||||
|
||||
+56
-50
@@ -255,6 +255,10 @@ func (m *Manager) initChannels() error {
|
||||
m.initChannel("wecom", "WeCom")
|
||||
}
|
||||
|
||||
if m.config.Channels.WeComAIBot.Enabled && m.config.Channels.WeComAIBot.Token != "" {
|
||||
m.initChannel("wecom_aibot", "WeCom AI Bot")
|
||||
}
|
||||
|
||||
if m.config.Channels.WeComApp.Enabled && m.config.Channels.WeComApp.CorpID != "" {
|
||||
m.initChannel("wecom_app", "WeCom App")
|
||||
}
|
||||
@@ -539,86 +543,88 @@ func (m *Manager) sendWithRetry(ctx context.Context, name string, w *channelWork
|
||||
})
|
||||
}
|
||||
|
||||
func (m *Manager) dispatchOutbound(ctx context.Context) {
|
||||
logger.InfoC("channels", "Outbound dispatcher started")
|
||||
func dispatchLoop[M any](
|
||||
ctx context.Context,
|
||||
m *Manager,
|
||||
subscribe func(context.Context) (M, bool),
|
||||
getChannel func(M) string,
|
||||
enqueue func(context.Context, *channelWorker, M) bool,
|
||||
startMsg, stopMsg, unknownMsg, noWorkerMsg string,
|
||||
) {
|
||||
logger.InfoC("channels", startMsg)
|
||||
|
||||
for {
|
||||
msg, ok := m.bus.SubscribeOutbound(ctx)
|
||||
msg, ok := subscribe(ctx)
|
||||
if !ok {
|
||||
logger.InfoC("channels", "Outbound dispatcher stopped")
|
||||
logger.InfoC("channels", stopMsg)
|
||||
return
|
||||
}
|
||||
|
||||
channel := getChannel(msg)
|
||||
|
||||
// Silently skip internal channels
|
||||
if constants.IsInternalChannel(msg.Channel) {
|
||||
if constants.IsInternalChannel(channel) {
|
||||
continue
|
||||
}
|
||||
|
||||
m.mu.RLock()
|
||||
_, exists := m.channels[msg.Channel]
|
||||
w, wExists := m.workers[msg.Channel]
|
||||
_, exists := m.channels[channel]
|
||||
w, wExists := m.workers[channel]
|
||||
m.mu.RUnlock()
|
||||
|
||||
if !exists {
|
||||
logger.WarnCF("channels", "Unknown channel for outbound message", map[string]any{
|
||||
"channel": msg.Channel,
|
||||
})
|
||||
logger.WarnCF("channels", unknownMsg, map[string]any{"channel": channel})
|
||||
continue
|
||||
}
|
||||
|
||||
if wExists && w != nil {
|
||||
select {
|
||||
case w.queue <- msg:
|
||||
case <-ctx.Done():
|
||||
if !enqueue(ctx, w, msg) {
|
||||
return
|
||||
}
|
||||
} else if exists {
|
||||
logger.WarnCF("channels", "Channel has no active worker, skipping message", map[string]any{
|
||||
"channel": msg.Channel,
|
||||
})
|
||||
logger.WarnCF("channels", noWorkerMsg, map[string]any{"channel": channel})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Manager) dispatchOutbound(ctx context.Context) {
|
||||
dispatchLoop(
|
||||
ctx, m,
|
||||
m.bus.SubscribeOutbound,
|
||||
func(msg bus.OutboundMessage) string { return msg.Channel },
|
||||
func(ctx context.Context, w *channelWorker, msg bus.OutboundMessage) bool {
|
||||
select {
|
||||
case w.queue <- msg:
|
||||
return true
|
||||
case <-ctx.Done():
|
||||
return false
|
||||
}
|
||||
},
|
||||
"Outbound dispatcher started",
|
||||
"Outbound dispatcher stopped",
|
||||
"Unknown channel for outbound message",
|
||||
"Channel has no active worker, skipping message",
|
||||
)
|
||||
}
|
||||
|
||||
func (m *Manager) dispatchOutboundMedia(ctx context.Context) {
|
||||
logger.InfoC("channels", "Outbound media dispatcher started")
|
||||
|
||||
for {
|
||||
msg, ok := m.bus.SubscribeOutboundMedia(ctx)
|
||||
if !ok {
|
||||
logger.InfoC("channels", "Outbound media dispatcher stopped")
|
||||
return
|
||||
}
|
||||
|
||||
// Silently skip internal channels
|
||||
if constants.IsInternalChannel(msg.Channel) {
|
||||
continue
|
||||
}
|
||||
|
||||
m.mu.RLock()
|
||||
_, exists := m.channels[msg.Channel]
|
||||
w, wExists := m.workers[msg.Channel]
|
||||
m.mu.RUnlock()
|
||||
|
||||
if !exists {
|
||||
logger.WarnCF("channels", "Unknown channel for outbound media message", map[string]any{
|
||||
"channel": msg.Channel,
|
||||
})
|
||||
continue
|
||||
}
|
||||
|
||||
if wExists && w != nil {
|
||||
dispatchLoop(
|
||||
ctx, m,
|
||||
m.bus.SubscribeOutboundMedia,
|
||||
func(msg bus.OutboundMediaMessage) string { return msg.Channel },
|
||||
func(ctx context.Context, w *channelWorker, msg bus.OutboundMediaMessage) bool {
|
||||
select {
|
||||
case w.mediaQueue <- msg:
|
||||
return true
|
||||
case <-ctx.Done():
|
||||
return
|
||||
return false
|
||||
}
|
||||
} else if exists {
|
||||
logger.WarnCF("channels", "Channel has no active worker, skipping media message", map[string]any{
|
||||
"channel": msg.Channel,
|
||||
})
|
||||
}
|
||||
}
|
||||
},
|
||||
"Outbound media dispatcher started",
|
||||
"Outbound media dispatcher stopped",
|
||||
"Unknown channel for outbound media message",
|
||||
"Channel has no active worker, skipping media message",
|
||||
)
|
||||
}
|
||||
|
||||
// runMediaWorker processes outbound media messages for a single channel.
|
||||
|
||||
@@ -274,13 +274,12 @@ func TestWorkerRateLimiter(t *testing.T) {
|
||||
limiter: rate.NewLimiter(2, 1),
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
ctx := t.Context()
|
||||
|
||||
go m.runWorker(ctx, "test", w)
|
||||
|
||||
// Enqueue 4 messages
|
||||
for i := 0; i < 4; i++ {
|
||||
for i := range 4 {
|
||||
w.queue <- bus.OutboundMessage{Channel: "test", ChatID: "1", Content: fmt.Sprintf("msg%d", i)}
|
||||
}
|
||||
|
||||
@@ -352,8 +351,7 @@ func TestRunWorker_MessageSplitting(t *testing.T) {
|
||||
limiter: rate.NewLimiter(rate.Inf, 1),
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
ctx := t.Context()
|
||||
|
||||
go m.runWorker(ctx, "test", w)
|
||||
|
||||
@@ -576,7 +574,7 @@ func TestRecordPlaceholder_ConcurrentSafe(t *testing.T) {
|
||||
m := newTestManager()
|
||||
|
||||
var wg sync.WaitGroup
|
||||
for i := 0; i < 100; i++ {
|
||||
for i := range 100 {
|
||||
wg.Add(1)
|
||||
go func(i int) {
|
||||
defer wg.Done()
|
||||
@@ -591,7 +589,7 @@ func TestRecordTypingStop_ConcurrentSafe(t *testing.T) {
|
||||
m := newTestManager()
|
||||
|
||||
var wg sync.WaitGroup
|
||||
for i := 0; i < 100; i++ {
|
||||
for i := range 100 {
|
||||
wg.Add(1)
|
||||
go func(i int) {
|
||||
defer wg.Done()
|
||||
@@ -834,7 +832,7 @@ func TestLazyWorkerCreation(t *testing.T) {
|
||||
func TestBuildMediaScope_FastIDUniqueness(t *testing.T) {
|
||||
seen := make(map[string]bool)
|
||||
|
||||
for i := 0; i < 1000; i++ {
|
||||
for range 1000 {
|
||||
scope := BuildMediaScope("test", "chat1", "")
|
||||
if seen[scope] {
|
||||
t.Fatalf("duplicate scope generated: %s", scope)
|
||||
|
||||
@@ -337,10 +337,7 @@ func (c *OneBotChannel) sendAPIRequest(action string, params any, timeout time.D
|
||||
}
|
||||
|
||||
func (c *OneBotChannel) reconnectLoop() {
|
||||
interval := time.Duration(c.config.ReconnectInterval) * time.Second
|
||||
if interval < 5*time.Second {
|
||||
interval = 5 * time.Second
|
||||
}
|
||||
interval := max(time.Duration(c.config.ReconnectInterval)*time.Second, 5*time.Second)
|
||||
|
||||
for {
|
||||
select {
|
||||
|
||||
@@ -292,8 +292,8 @@ func (c *PicoChannel) authenticate(r *http.Request) bool {
|
||||
|
||||
// Check Authorization header
|
||||
auth := r.Header.Get("Authorization")
|
||||
if strings.HasPrefix(auth, "Bearer ") {
|
||||
if strings.TrimPrefix(auth, "Bearer ") == token {
|
||||
if after, ok := strings.CutPrefix(auth, "Bearer "); ok {
|
||||
if after == token {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
+8
-24
@@ -23,10 +23,7 @@ func SplitMessage(content string, maxLen int) []string {
|
||||
var messages []string
|
||||
|
||||
// Dynamic buffer: 10% of maxLen, but at least 50 chars if possible
|
||||
codeBlockBuffer := maxLen / 10
|
||||
if codeBlockBuffer < 50 {
|
||||
codeBlockBuffer = 50
|
||||
}
|
||||
codeBlockBuffer := max(maxLen/10, 50)
|
||||
if codeBlockBuffer > maxLen/2 {
|
||||
codeBlockBuffer = maxLen / 2
|
||||
}
|
||||
@@ -40,10 +37,7 @@ func SplitMessage(content string, maxLen int) []string {
|
||||
}
|
||||
|
||||
// Effective split point: maxLen minus buffer, to leave room for code blocks
|
||||
effectiveLimit := maxLen - codeBlockBuffer
|
||||
if effectiveLimit < maxLen/2 {
|
||||
effectiveLimit = maxLen / 2
|
||||
}
|
||||
effectiveLimit := max(maxLen-codeBlockBuffer, maxLen/2)
|
||||
|
||||
end := start + effectiveLimit
|
||||
|
||||
@@ -85,10 +79,9 @@ func SplitMessage(content string, maxLen int) []string {
|
||||
// If we have a reasonable amount of content after the header, split inside
|
||||
if msgEnd > headerEndIdx+20 {
|
||||
// Find a better split point closer to maxLen
|
||||
innerLimit := start + maxLen - 5 // Leave room for "\n```"
|
||||
if innerLimit > totalLen {
|
||||
innerLimit = totalLen
|
||||
}
|
||||
innerLimit := min(
|
||||
// Leave room for "\n```"
|
||||
start+maxLen-5, totalLen)
|
||||
betterEnd := findLastNewlineInRange(runes, start, innerLimit, 200)
|
||||
if betterEnd > headerEndIdx {
|
||||
msgEnd = betterEnd
|
||||
@@ -117,10 +110,7 @@ func SplitMessage(content string, maxLen int) []string {
|
||||
if unclosedIdx-start > 20 {
|
||||
msgEnd = unclosedIdx
|
||||
} else {
|
||||
splitAt := start + maxLen - 5
|
||||
if splitAt > totalLen {
|
||||
splitAt = totalLen
|
||||
}
|
||||
splitAt := min(start+maxLen-5, totalLen)
|
||||
chunk := strings.TrimRight(string(runes[start:splitAt]), " \t\n\r") + "\n```"
|
||||
messages = append(messages, chunk)
|
||||
remaining := strings.TrimSpace(header + "\n" + string(runes[splitAt:totalLen]))
|
||||
@@ -196,10 +186,7 @@ func findNewlineFrom(runes []rune, from int) int {
|
||||
// findLastNewlineInRange finds the last newline within the last searchWindow runes
|
||||
// of the range runes[start:end]. Returns the absolute index or start-1 (indicating not found).
|
||||
func findLastNewlineInRange(runes []rune, start, end, searchWindow int) int {
|
||||
searchStart := end - searchWindow
|
||||
if searchStart < start {
|
||||
searchStart = start
|
||||
}
|
||||
searchStart := max(end-searchWindow, start)
|
||||
for i := end - 1; i >= searchStart; i-- {
|
||||
if runes[i] == '\n' {
|
||||
return i
|
||||
@@ -211,10 +198,7 @@ func findLastNewlineInRange(runes []rune, start, end, searchWindow int) int {
|
||||
// findLastSpaceInRange finds the last space/tab within the last searchWindow runes
|
||||
// of the range runes[start:end]. Returns the absolute index or start-1 (indicating not found).
|
||||
func findLastSpaceInRange(runes []rune, start, end, searchWindow int) int {
|
||||
searchStart := end - searchWindow
|
||||
if searchStart < start {
|
||||
searchStart = start
|
||||
}
|
||||
searchStart := max(end-searchWindow, start)
|
||||
for i := end - 1; i >= searchStart; i-- {
|
||||
if runes[i] == ' ' || runes[i] == '\t' {
|
||||
return i
|
||||
|
||||
@@ -7,12 +7,12 @@ import (
|
||||
"net/url"
|
||||
"os"
|
||||
"regexp"
|
||||
"slices"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/mymmrac/telego"
|
||||
"github.com/mymmrac/telego/telegohandler"
|
||||
th "github.com/mymmrac/telego/telegohandler"
|
||||
tu "github.com/mymmrac/telego/telegoutil"
|
||||
|
||||
@@ -41,7 +41,7 @@ var (
|
||||
type TelegramChannel struct {
|
||||
*channels.BaseChannel
|
||||
bot *telego.Bot
|
||||
bh *telegohandler.BotHandler
|
||||
bh *th.BotHandler
|
||||
commands TelegramCommander
|
||||
config *config.Config
|
||||
chatIDs map[string]int64
|
||||
@@ -72,6 +72,10 @@ func NewTelegramChannel(cfg *config.Config, bus *bus.MessageBus) (*TelegramChann
|
||||
}))
|
||||
}
|
||||
|
||||
if baseURL := strings.TrimRight(strings.TrimSpace(telegramCfg.BaseURL), "/"); baseURL != "" {
|
||||
opts = append(opts, telego.WithAPIServer(baseURL))
|
||||
}
|
||||
|
||||
bot, err := telego.NewBot(telegramCfg.Token, opts...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create telegram bot: %w", err)
|
||||
@@ -101,6 +105,12 @@ func (c *TelegramChannel) Start(ctx context.Context) error {
|
||||
|
||||
c.ctx, c.cancel = context.WithCancel(ctx)
|
||||
|
||||
if err := c.initBotCommands(c.ctx); err != nil {
|
||||
logger.WarnCF("telegram", "Failed to initialize bot commands", map[string]any{
|
||||
"error": err.Error(),
|
||||
})
|
||||
}
|
||||
|
||||
updates, err := c.bot.UpdatesViaLongPolling(c.ctx, &telego.GetUpdatesParams{
|
||||
Timeout: 30,
|
||||
})
|
||||
@@ -109,20 +119,19 @@ func (c *TelegramChannel) Start(ctx context.Context) error {
|
||||
return fmt.Errorf("failed to start long polling: %w", err)
|
||||
}
|
||||
|
||||
bh, err := telegohandler.NewBotHandler(c.bot, updates)
|
||||
bh, err := th.NewBotHandler(c.bot, updates)
|
||||
if err != nil {
|
||||
c.cancel()
|
||||
return fmt.Errorf("failed to create bot handler: %w", err)
|
||||
}
|
||||
c.bh = bh
|
||||
|
||||
bh.HandleMessage(func(ctx *th.Context, message telego.Message) error {
|
||||
c.commands.Help(ctx, message)
|
||||
return nil
|
||||
}, th.CommandEqual("help"))
|
||||
bh.HandleMessage(func(ctx *th.Context, message telego.Message) error {
|
||||
return c.commands.Start(ctx, message)
|
||||
}, th.CommandEqual("start"))
|
||||
bh.HandleMessage(func(ctx *th.Context, message telego.Message) error {
|
||||
return c.commands.Help(ctx, message)
|
||||
}, th.CommandEqual("help"))
|
||||
|
||||
bh.HandleMessage(func(ctx *th.Context, message telego.Message) error {
|
||||
return c.commands.Show(ctx, message)
|
||||
@@ -141,7 +150,13 @@ func (c *TelegramChannel) Start(ctx context.Context) error {
|
||||
"username": c.bot.Username(),
|
||||
})
|
||||
|
||||
go bh.Start()
|
||||
go func() {
|
||||
if err = bh.Start(); err != nil {
|
||||
logger.ErrorCF("telegram", "Bot handler failed", map[string]any{
|
||||
"error": err.Error(),
|
||||
})
|
||||
}
|
||||
}()
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -152,7 +167,7 @@ func (c *TelegramChannel) Stop(ctx context.Context) error {
|
||||
|
||||
// Stop the bot handler
|
||||
if c.bh != nil {
|
||||
c.bh.Stop()
|
||||
_ = c.bh.StopWithContext(ctx)
|
||||
}
|
||||
|
||||
// Cancel our context (stops long polling)
|
||||
@@ -163,6 +178,51 @@ func (c *TelegramChannel) Stop(ctx context.Context) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *TelegramChannel) initBotCommands(ctx context.Context) error {
|
||||
currentCommands, err := c.bot.GetMyCommands(ctx, &telego.GetMyCommandsParams{
|
||||
Scope: tu.ScopeDefault(),
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("get commands: %w", err)
|
||||
}
|
||||
|
||||
commands := []telego.BotCommand{
|
||||
{
|
||||
Command: "start",
|
||||
Description: "Start the bot",
|
||||
},
|
||||
{
|
||||
Command: "help",
|
||||
Description: "Show a help message",
|
||||
},
|
||||
{
|
||||
Command: "show",
|
||||
Description: "Show current configuration",
|
||||
},
|
||||
{
|
||||
Command: "list",
|
||||
Description: "List available options",
|
||||
},
|
||||
}
|
||||
|
||||
// Setting commands on each start will hit the rate limit very quickly, that's why we check if an update is needed
|
||||
if !slices.Equal(currentCommands, commands) {
|
||||
logger.InfoC("telegram", "Updating bot commands")
|
||||
|
||||
err = c.bot.SetMyCommands(ctx, &telego.SetMyCommandsParams{
|
||||
Commands: commands,
|
||||
Scope: tu.ScopeDefault(),
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("set commands: %w", err)
|
||||
}
|
||||
} else {
|
||||
logger.DebugC("telegram", "Bot commands are up to date")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *TelegramChannel) Send(ctx context.Context, msg bus.OutboundMessage) error {
|
||||
if !c.IsRunning() {
|
||||
return channels.ErrNotRunning
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,210 @@
|
||||
package wecom
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/bus"
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
)
|
||||
|
||||
func TestNewWeComAIBotChannel(t *testing.T) {
|
||||
t.Run("success with valid config", func(t *testing.T) {
|
||||
cfg := config.WeComAIBotConfig{
|
||||
Enabled: true,
|
||||
Token: "test_token",
|
||||
EncodingAESKey: "testkey1234567890123456789012345678901234567",
|
||||
WebhookPath: "/webhook/test",
|
||||
}
|
||||
|
||||
messageBus := bus.NewMessageBus()
|
||||
ch, err := NewWeComAIBotChannel(cfg, messageBus)
|
||||
if err != nil {
|
||||
t.Fatalf("Expected no error, got %v", err)
|
||||
}
|
||||
|
||||
if ch == nil {
|
||||
t.Fatal("Expected channel to be created")
|
||||
}
|
||||
|
||||
if ch.Name() != "wecom_aibot" {
|
||||
t.Errorf("Expected name 'wecom_aibot', got '%s'", ch.Name())
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("error with missing token", func(t *testing.T) {
|
||||
cfg := config.WeComAIBotConfig{
|
||||
Enabled: true,
|
||||
EncodingAESKey: "testkey1234567890123456789012345678901234567",
|
||||
}
|
||||
|
||||
messageBus := bus.NewMessageBus()
|
||||
_, err := NewWeComAIBotChannel(cfg, messageBus)
|
||||
|
||||
if err == nil {
|
||||
t.Fatal("Expected error for missing token, got nil")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("error with missing encoding key", func(t *testing.T) {
|
||||
cfg := config.WeComAIBotConfig{
|
||||
Enabled: true,
|
||||
Token: "test_token",
|
||||
}
|
||||
|
||||
messageBus := bus.NewMessageBus()
|
||||
_, err := NewWeComAIBotChannel(cfg, messageBus)
|
||||
|
||||
if err == nil {
|
||||
t.Fatal("Expected error for missing encoding key, got nil")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestWeComAIBotChannelStartStop(t *testing.T) {
|
||||
cfg := config.WeComAIBotConfig{
|
||||
Enabled: true,
|
||||
Token: "test_token",
|
||||
EncodingAESKey: "testkey1234567890123456789012345678901234567",
|
||||
}
|
||||
|
||||
messageBus := bus.NewMessageBus()
|
||||
ch, err := NewWeComAIBotChannel(cfg, messageBus)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create channel: %v", err)
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Test Start
|
||||
if err := ch.Start(ctx); err != nil {
|
||||
t.Fatalf("Failed to start channel: %v", err)
|
||||
}
|
||||
|
||||
if !ch.IsRunning() {
|
||||
t.Error("Expected channel to be running")
|
||||
}
|
||||
|
||||
// Test Stop
|
||||
if err := ch.Stop(ctx); err != nil {
|
||||
t.Fatalf("Failed to stop channel: %v", err)
|
||||
}
|
||||
|
||||
if ch.IsRunning() {
|
||||
t.Error("Expected channel to be stopped")
|
||||
}
|
||||
}
|
||||
|
||||
func TestWeComAIBotChannelWebhookPath(t *testing.T) {
|
||||
t.Run("default path", func(t *testing.T) {
|
||||
cfg := config.WeComAIBotConfig{
|
||||
Enabled: true,
|
||||
Token: "test_token",
|
||||
EncodingAESKey: "testkey1234567890123456789012345678901234567",
|
||||
}
|
||||
|
||||
messageBus := bus.NewMessageBus()
|
||||
ch, _ := NewWeComAIBotChannel(cfg, messageBus)
|
||||
|
||||
expectedPath := "/webhook/wecom-aibot"
|
||||
if ch.WebhookPath() != expectedPath {
|
||||
t.Errorf("Expected webhook path '%s', got '%s'", expectedPath, ch.WebhookPath())
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("custom path", func(t *testing.T) {
|
||||
customPath := "/custom/webhook"
|
||||
cfg := config.WeComAIBotConfig{
|
||||
Enabled: true,
|
||||
Token: "test_token",
|
||||
EncodingAESKey: "testkey1234567890123456789012345678901234567",
|
||||
WebhookPath: customPath,
|
||||
}
|
||||
|
||||
messageBus := bus.NewMessageBus()
|
||||
ch, _ := NewWeComAIBotChannel(cfg, messageBus)
|
||||
|
||||
if ch.WebhookPath() != customPath {
|
||||
t.Errorf("Expected webhook path '%s', got '%s'", customPath, ch.WebhookPath())
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestGenerateStreamID(t *testing.T) {
|
||||
cfg := config.WeComAIBotConfig{
|
||||
Enabled: true,
|
||||
Token: "test_token",
|
||||
EncodingAESKey: "testkey1234567890123456789012345678901234567",
|
||||
}
|
||||
|
||||
messageBus := bus.NewMessageBus()
|
||||
ch, _ := NewWeComAIBotChannel(cfg, messageBus)
|
||||
|
||||
// Generate multiple IDs and check they are unique
|
||||
ids := make(map[string]bool)
|
||||
for i := 0; i < 100; i++ {
|
||||
id := ch.generateStreamID()
|
||||
|
||||
if len(id) != 10 {
|
||||
t.Errorf("Expected stream ID length 10, got %d", len(id))
|
||||
}
|
||||
|
||||
if ids[id] {
|
||||
t.Errorf("Duplicate stream ID generated: %s", id)
|
||||
}
|
||||
ids[id] = true
|
||||
}
|
||||
}
|
||||
|
||||
func TestEncryptDecrypt(t *testing.T) {
|
||||
// Use a valid 43-character base64 key (企业微信标准格式)
|
||||
cfg := config.WeComAIBotConfig{
|
||||
Enabled: true,
|
||||
Token: "test_token",
|
||||
EncodingAESKey: "abcdefghijklmnopqrstuvwxyz0123456789ABCDEFG", // 43 characters
|
||||
}
|
||||
|
||||
messageBus := bus.NewMessageBus()
|
||||
ch, _ := NewWeComAIBotChannel(cfg, messageBus)
|
||||
|
||||
plaintext := "Hello, World!"
|
||||
receiveid := ""
|
||||
|
||||
// Encrypt
|
||||
encrypted, err := ch.encryptMessage(plaintext, receiveid)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to encrypt message: %v", err)
|
||||
}
|
||||
|
||||
if encrypted == "" {
|
||||
t.Fatal("Encrypted message is empty")
|
||||
}
|
||||
|
||||
// Decrypt
|
||||
decrypted, err := decryptMessageWithVerify(encrypted, cfg.EncodingAESKey, receiveid)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to decrypt message: %v", err)
|
||||
}
|
||||
|
||||
if decrypted != plaintext {
|
||||
t.Errorf("Expected decrypted message '%s', got '%s'", plaintext, decrypted)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateSignature(t *testing.T) {
|
||||
token := "test_token"
|
||||
timestamp := "1234567890"
|
||||
nonce := "test_nonce"
|
||||
encrypt := "encrypted_msg"
|
||||
|
||||
signature := computeSignature(token, timestamp, nonce, encrypt)
|
||||
|
||||
if signature == "" {
|
||||
t.Error("Generated signature is empty")
|
||||
}
|
||||
|
||||
// Verify signature using verifySignature function
|
||||
if !verifySignature(token, signature, timestamp, nonce, encrypt) {
|
||||
t.Error("Generated signature does not verify correctly")
|
||||
}
|
||||
}
|
||||
+34
-80
@@ -32,13 +32,13 @@ const (
|
||||
type WeComAppChannel struct {
|
||||
*channels.BaseChannel
|
||||
config config.WeComAppConfig
|
||||
client *http.Client
|
||||
accessToken string
|
||||
tokenExpiry time.Time
|
||||
tokenMu sync.RWMutex
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
processedMsgs map[string]bool // Message deduplication: msg_id -> processed
|
||||
msgMu sync.RWMutex
|
||||
processedMsgs *MessageDeduplicator
|
||||
}
|
||||
|
||||
// WeComXMLMessage represents the XML message structure from WeCom
|
||||
@@ -129,13 +129,21 @@ func NewWeComAppChannel(cfg config.WeComAppConfig, messageBus *bus.MessageBus) (
|
||||
channels.WithReasoningChannelID(cfg.ReasoningChannelID),
|
||||
)
|
||||
|
||||
// Client timeout must be >= the configured ReplyTimeout so the
|
||||
// per-request context deadline is always the effective limit.
|
||||
clientTimeout := 30 * time.Second
|
||||
if d := time.Duration(cfg.ReplyTimeout) * time.Second; d > clientTimeout {
|
||||
clientTimeout = d
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
return &WeComAppChannel{
|
||||
BaseChannel: base,
|
||||
config: cfg,
|
||||
client: &http.Client{Timeout: clientTimeout},
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
processedMsgs: make(map[string]bool),
|
||||
processedMsgs: NewMessageDeduplicator(wecomMaxProcessedMessages),
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -148,6 +156,10 @@ func (c *WeComAppChannel) Name() string {
|
||||
func (c *WeComAppChannel) Start(ctx context.Context) error {
|
||||
logger.InfoC("wecom_app", "Starting WeCom App channel...")
|
||||
|
||||
// Cancel the context created in the constructor to avoid a resource leak.
|
||||
if c.cancel != nil {
|
||||
c.cancel()
|
||||
}
|
||||
c.ctx, c.cancel = context.WithCancel(ctx)
|
||||
|
||||
// Get initial access token
|
||||
@@ -302,8 +314,7 @@ func (c *WeComAppChannel) uploadMedia(ctx context.Context, accessToken, mediaTyp
|
||||
}
|
||||
req.Header.Set("Content-Type", writer.FormDataContentType())
|
||||
|
||||
client := &http.Client{Timeout: 30 * time.Second}
|
||||
resp, err := client.Do(req)
|
||||
resp, err := c.client.Do(req)
|
||||
if err != nil {
|
||||
return "", channels.ClassifyNetError(err)
|
||||
}
|
||||
@@ -330,18 +341,11 @@ func (c *WeComAppChannel) uploadMedia(ctx context.Context, accessToken, mediaTyp
|
||||
return result.MediaID, nil
|
||||
}
|
||||
|
||||
// sendImageMessage sends an image message using a media_id.
|
||||
func (c *WeComAppChannel) sendImageMessage(ctx context.Context, accessToken, userID, mediaID string) error {
|
||||
// sendWeComMessage marshals payload and POSTs it to the WeCom message API.
|
||||
func (c *WeComAppChannel) sendWeComMessage(ctx context.Context, accessToken string, payload any) error {
|
||||
apiURL := fmt.Sprintf("%s/cgi-bin/message/send?access_token=%s", wecomAPIBase, accessToken)
|
||||
|
||||
msg := WeComImageMessage{
|
||||
ToUser: userID,
|
||||
MsgType: "image",
|
||||
AgentID: c.config.AgentID,
|
||||
}
|
||||
msg.Image.MediaID = mediaID
|
||||
|
||||
jsonData, err := json.Marshal(msg)
|
||||
jsonData, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal message: %w", err)
|
||||
}
|
||||
@@ -360,8 +364,7 @@ func (c *WeComAppChannel) sendImageMessage(ctx context.Context, accessToken, use
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
client := &http.Client{Timeout: time.Duration(timeout) * time.Second}
|
||||
resp, err := client.Do(req)
|
||||
resp, err := c.client.Do(req)
|
||||
if err != nil {
|
||||
return channels.ClassifyNetError(err)
|
||||
}
|
||||
@@ -389,6 +392,17 @@ func (c *WeComAppChannel) sendImageMessage(ctx context.Context, accessToken, use
|
||||
return nil
|
||||
}
|
||||
|
||||
// sendImageMessage sends an image message using a media_id.
|
||||
func (c *WeComAppChannel) sendImageMessage(ctx context.Context, accessToken, userID, mediaID string) error {
|
||||
msg := WeComImageMessage{
|
||||
ToUser: userID,
|
||||
MsgType: "image",
|
||||
AgentID: c.config.AgentID,
|
||||
}
|
||||
msg.Image.MediaID = mediaID
|
||||
return c.sendWeComMessage(ctx, accessToken, msg)
|
||||
}
|
||||
|
||||
// WebhookPath returns the path for registering on the shared HTTP server.
|
||||
func (c *WeComAppChannel) WebhookPath() string {
|
||||
if c.config.WebhookPath != "" {
|
||||
@@ -592,23 +606,12 @@ func (c *WeComAppChannel) processMessage(ctx context.Context, msg WeComXMLMessag
|
||||
// Message deduplication: Use msg_id to prevent duplicate processing
|
||||
// As per WeCom documentation, use msg_id for deduplication
|
||||
msgID := fmt.Sprintf("%d", msg.MsgId)
|
||||
c.msgMu.Lock()
|
||||
if c.processedMsgs[msgID] {
|
||||
c.msgMu.Unlock()
|
||||
if !c.processedMsgs.MarkMessageProcessed(msgID) {
|
||||
logger.DebugCF("wecom_app", "Skipping duplicate message", map[string]any{
|
||||
"msg_id": msgID,
|
||||
})
|
||||
return
|
||||
}
|
||||
c.processedMsgs[msgID] = true
|
||||
c.msgMu.Unlock()
|
||||
|
||||
// Clean up old messages periodically (keep last 1000)
|
||||
if len(c.processedMsgs) > 1000 {
|
||||
c.msgMu.Lock()
|
||||
c.processedMsgs = make(map[string]bool)
|
||||
c.msgMu.Unlock()
|
||||
}
|
||||
|
||||
senderID := msg.FromUserName
|
||||
chatID := senderID // WeCom App uses user ID as chat ID for direct messages
|
||||
@@ -711,64 +714,15 @@ func (c *WeComAppChannel) getAccessToken() string {
|
||||
return c.accessToken
|
||||
}
|
||||
|
||||
// sendTextMessage sends a text message to a user
|
||||
// sendTextMessage sends a text message to a user.
|
||||
func (c *WeComAppChannel) sendTextMessage(ctx context.Context, accessToken, userID, content string) error {
|
||||
apiURL := fmt.Sprintf("%s/cgi-bin/message/send?access_token=%s", wecomAPIBase, accessToken)
|
||||
|
||||
msg := WeComTextMessage{
|
||||
ToUser: userID,
|
||||
MsgType: "text",
|
||||
AgentID: c.config.AgentID,
|
||||
}
|
||||
msg.Text.Content = content
|
||||
|
||||
jsonData, err := json.Marshal(msg)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal message: %w", err)
|
||||
}
|
||||
|
||||
// Use configurable timeout (default 5 seconds)
|
||||
timeout := c.config.ReplyTimeout
|
||||
if timeout <= 0 {
|
||||
timeout = 5
|
||||
}
|
||||
|
||||
reqCtx, cancel := context.WithTimeout(ctx, time.Duration(timeout)*time.Second)
|
||||
defer cancel()
|
||||
|
||||
req, err := http.NewRequestWithContext(reqCtx, http.MethodPost, apiURL, bytes.NewBuffer(jsonData))
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
client := &http.Client{Timeout: time.Duration(timeout) * time.Second}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return channels.ClassifyNetError(err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
return channels.ClassifySendError(resp.StatusCode, fmt.Errorf("wecom_app API error: %s", string(body)))
|
||||
}
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to read response: %w", err)
|
||||
}
|
||||
|
||||
var sendResp WeComSendMessageResponse
|
||||
if err := json.Unmarshal(body, &sendResp); err != nil {
|
||||
return fmt.Errorf("failed to parse response: %w", err)
|
||||
}
|
||||
|
||||
if sendResp.ErrCode != 0 {
|
||||
return fmt.Errorf("API error: %s (code: %d)", sendResp.ErrMsg, sendResp.ErrCode)
|
||||
}
|
||||
|
||||
return nil
|
||||
return c.sendWeComMessage(ctx, accessToken, msg)
|
||||
}
|
||||
|
||||
// handleHealth handles health check requests
|
||||
|
||||
@@ -43,7 +43,7 @@ func encryptTestMessageApp(message, aesKey string) (string, error) {
|
||||
|
||||
// Prepare message: random(16) + msg_len(4) + msg + corp_id
|
||||
random := make([]byte, 0, 16)
|
||||
for i := 0; i < 16; i++ {
|
||||
for i := range 16 {
|
||||
random = append(random, byte(i+1))
|
||||
}
|
||||
|
||||
@@ -323,60 +323,6 @@ func TestWeComAppDecryptMessage(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestWeComAppPKCS7Unpad(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input []byte
|
||||
expected []byte
|
||||
}{
|
||||
{
|
||||
name: "empty input",
|
||||
input: []byte{},
|
||||
expected: []byte{},
|
||||
},
|
||||
{
|
||||
name: "valid padding 3 bytes",
|
||||
input: append([]byte("hello"), bytes.Repeat([]byte{3}, 3)...),
|
||||
expected: []byte("hello"),
|
||||
},
|
||||
{
|
||||
name: "valid padding 16 bytes (full block)",
|
||||
input: append([]byte("123456789012345"), bytes.Repeat([]byte{16}, 16)...),
|
||||
expected: []byte("123456789012345"),
|
||||
},
|
||||
{
|
||||
name: "invalid padding larger than data",
|
||||
input: []byte{20},
|
||||
expected: nil, // should return error
|
||||
},
|
||||
{
|
||||
name: "invalid padding zero",
|
||||
input: append([]byte("test"), byte(0)),
|
||||
expected: nil, // should return error
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result, err := pkcs7Unpad(tt.input)
|
||||
if tt.expected == nil {
|
||||
// This case should return an error
|
||||
if err == nil {
|
||||
t.Errorf("pkcs7Unpad() expected error for invalid padding, got result: %v", result)
|
||||
}
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
t.Errorf("pkcs7Unpad() unexpected error: %v", err)
|
||||
return
|
||||
}
|
||||
if !bytes.Equal(result, tt.expected) {
|
||||
t.Errorf("pkcs7Unpad() = %v, want %v", result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestWeComAppHandleVerification(t *testing.T) {
|
||||
msgBus := bus.NewMessageBus()
|
||||
aesKey := generateTestAESKeyApp()
|
||||
|
||||
+17
-18
@@ -9,7 +9,6 @@ import (
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/bus"
|
||||
@@ -25,10 +24,10 @@ import (
|
||||
type WeComBotChannel struct {
|
||||
*channels.BaseChannel
|
||||
config config.WeComConfig
|
||||
client *http.Client
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
processedMsgs map[string]bool // Message deduplication: msg_id -> processed
|
||||
msgMu sync.RWMutex
|
||||
processedMsgs *MessageDeduplicator
|
||||
}
|
||||
|
||||
// WeComBotMessage represents the JSON message structure from WeCom Bot (AIBOT)
|
||||
@@ -93,13 +92,21 @@ func NewWeComBotChannel(cfg config.WeComConfig, messageBus *bus.MessageBus) (*We
|
||||
channels.WithReasoningChannelID(cfg.ReasoningChannelID),
|
||||
)
|
||||
|
||||
// Client timeout must be >= the configured ReplyTimeout so the
|
||||
// per-request context deadline is always the effective limit.
|
||||
clientTimeout := 30 * time.Second
|
||||
if d := time.Duration(cfg.ReplyTimeout) * time.Second; d > clientTimeout {
|
||||
clientTimeout = d
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
return &WeComBotChannel{
|
||||
BaseChannel: base,
|
||||
config: cfg,
|
||||
client: &http.Client{Timeout: clientTimeout},
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
processedMsgs: make(map[string]bool),
|
||||
processedMsgs: NewMessageDeduplicator(wecomMaxProcessedMessages),
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -112,6 +119,10 @@ func (c *WeComBotChannel) Name() string {
|
||||
func (c *WeComBotChannel) Start(ctx context.Context) error {
|
||||
logger.InfoC("wecom", "Starting WeCom Bot channel...")
|
||||
|
||||
// Cancel the context created in the constructor to avoid a resource leak.
|
||||
if c.cancel != nil {
|
||||
c.cancel()
|
||||
}
|
||||
c.ctx, c.cancel = context.WithCancel(ctx)
|
||||
|
||||
c.SetRunning(true)
|
||||
@@ -317,23 +328,12 @@ func (c *WeComBotChannel) processMessage(ctx context.Context, msg WeComBotMessag
|
||||
|
||||
// Message deduplication: Use msg_id to prevent duplicate processing
|
||||
msgID := msg.MsgID
|
||||
c.msgMu.Lock()
|
||||
if c.processedMsgs[msgID] {
|
||||
c.msgMu.Unlock()
|
||||
if !c.processedMsgs.MarkMessageProcessed(msgID) {
|
||||
logger.DebugCF("wecom", "Skipping duplicate message", map[string]any{
|
||||
"msg_id": msgID,
|
||||
})
|
||||
return
|
||||
}
|
||||
c.processedMsgs[msgID] = true
|
||||
c.msgMu.Unlock()
|
||||
|
||||
// Clean up old messages periodically (keep last 1000)
|
||||
if len(c.processedMsgs) > 1000 {
|
||||
c.msgMu.Lock()
|
||||
c.processedMsgs = make(map[string]bool)
|
||||
c.msgMu.Unlock()
|
||||
}
|
||||
|
||||
senderID := msg.From.UserID
|
||||
|
||||
@@ -446,8 +446,7 @@ func (c *WeComBotChannel) sendWebhookReply(ctx context.Context, userID, content
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
client := &http.Client{Timeout: time.Duration(timeout) * time.Second}
|
||||
resp, err := client.Do(req)
|
||||
resp, err := c.client.Do(req)
|
||||
if err != nil {
|
||||
return channels.ClassifyNetError(err)
|
||||
}
|
||||
|
||||
@@ -42,7 +42,7 @@ func encryptTestMessage(message, aesKey string) (string, error) {
|
||||
|
||||
// Prepare message: random(16) + msg_len(4) + msg + receiveid
|
||||
random := make([]byte, 0, 16)
|
||||
for i := 0; i < 16; i++ {
|
||||
for i := range 16 {
|
||||
random = append(random, byte(i))
|
||||
}
|
||||
|
||||
@@ -412,22 +412,9 @@ func TestWeComBotHandleMessageCallback(t *testing.T) {
|
||||
}
|
||||
ch, _ := NewWeComBotChannel(cfg, msgBus)
|
||||
|
||||
t.Run("valid direct message callback", func(t *testing.T) {
|
||||
// Create JSON message for direct chat (single)
|
||||
jsonMsg := `{
|
||||
"msgid": "test_msg_id_123",
|
||||
"aibotid": "test_aibot_id",
|
||||
"chattype": "single",
|
||||
"from": {"userid": "user123"},
|
||||
"response_url": "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test",
|
||||
"msgtype": "text",
|
||||
"text": {"content": "Hello World"}
|
||||
}`
|
||||
|
||||
// Encrypt message
|
||||
runBotMessageCallback := func(t *testing.T, jsonMsg string) *httptest.ResponseRecorder {
|
||||
t.Helper()
|
||||
encrypted, _ := encryptTestMessage(jsonMsg, aesKey)
|
||||
|
||||
// Create encrypted XML wrapper
|
||||
encryptedWrapper := struct {
|
||||
XMLName xml.Name `xml:"xml"`
|
||||
Encrypt string `xml:"Encrypt"`
|
||||
@@ -435,20 +422,29 @@ func TestWeComBotHandleMessageCallback(t *testing.T) {
|
||||
Encrypt: encrypted,
|
||||
}
|
||||
wrapperData, _ := xml.Marshal(encryptedWrapper)
|
||||
|
||||
timestamp := "1234567890"
|
||||
nonce := "test_nonce"
|
||||
signature := generateSignature("test_token", timestamp, nonce, encrypted)
|
||||
|
||||
req := httptest.NewRequest(
|
||||
http.MethodPost,
|
||||
"/webhook/wecom?msg_signature="+signature+"×tamp="+timestamp+"&nonce="+nonce,
|
||||
bytes.NewReader(wrapperData),
|
||||
)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
ch.handleMessageCallback(context.Background(), w, req)
|
||||
return w
|
||||
}
|
||||
|
||||
t.Run("valid direct message callback", func(t *testing.T) {
|
||||
w := runBotMessageCallback(t, `{
|
||||
"msgid": "test_msg_id_123",
|
||||
"aibotid": "test_aibot_id",
|
||||
"chattype": "single",
|
||||
"from": {"userid": "user123"},
|
||||
"response_url": "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test",
|
||||
"msgtype": "text",
|
||||
"text": {"content": "Hello World"}
|
||||
}`)
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("status code = %d, want %d", w.Code, http.StatusOK)
|
||||
}
|
||||
@@ -458,8 +454,7 @@ func TestWeComBotHandleMessageCallback(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("valid group message callback", func(t *testing.T) {
|
||||
// Create JSON message for group chat
|
||||
jsonMsg := `{
|
||||
w := runBotMessageCallback(t, `{
|
||||
"msgid": "test_msg_id_456",
|
||||
"aibotid": "test_aibot_id",
|
||||
"chatid": "group_chat_id_123",
|
||||
@@ -468,33 +463,7 @@ func TestWeComBotHandleMessageCallback(t *testing.T) {
|
||||
"response_url": "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test",
|
||||
"msgtype": "text",
|
||||
"text": {"content": "Hello Group"}
|
||||
}`
|
||||
|
||||
// Encrypt message
|
||||
encrypted, _ := encryptTestMessage(jsonMsg, aesKey)
|
||||
|
||||
// Create encrypted XML wrapper
|
||||
encryptedWrapper := struct {
|
||||
XMLName xml.Name `xml:"xml"`
|
||||
Encrypt string `xml:"Encrypt"`
|
||||
}{
|
||||
Encrypt: encrypted,
|
||||
}
|
||||
wrapperData, _ := xml.Marshal(encryptedWrapper)
|
||||
|
||||
timestamp := "1234567890"
|
||||
nonce := "test_nonce"
|
||||
signature := generateSignature("test_token", timestamp, nonce, encrypted)
|
||||
|
||||
req := httptest.NewRequest(
|
||||
http.MethodPost,
|
||||
"/webhook/wecom?msg_signature="+signature+"×tamp="+timestamp+"&nonce="+nonce,
|
||||
bytes.NewReader(wrapperData),
|
||||
)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
ch.handleMessageCallback(context.Background(), w, req)
|
||||
|
||||
}`)
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("status code = %d, want %d", w.Code, http.StatusOK)
|
||||
}
|
||||
|
||||
+113
-48
@@ -1,12 +1,15 @@
|
||||
package wecom
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
"crypto/rand"
|
||||
"crypto/sha1"
|
||||
"encoding/base64"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"math/big"
|
||||
"sort"
|
||||
"strings"
|
||||
)
|
||||
@@ -14,25 +17,23 @@ import (
|
||||
// blockSize is the PKCS7 block size used by WeCom (32)
|
||||
const blockSize = 32
|
||||
|
||||
// computeSignature computes the WeCom message signature from the given parameters.
|
||||
// It sorts [token, timestamp, nonce, encrypt], concatenates them and returns the SHA1 hex digest.
|
||||
func computeSignature(token, timestamp, nonce, encrypt string) string {
|
||||
params := []string{token, timestamp, nonce, encrypt}
|
||||
sort.Strings(params)
|
||||
str := strings.Join(params, "")
|
||||
hash := sha1.Sum([]byte(str))
|
||||
return fmt.Sprintf("%x", hash)
|
||||
}
|
||||
|
||||
// verifySignature verifies the message signature for WeCom
|
||||
// This is a common function used by both WeCom Bot and WeCom App
|
||||
func verifySignature(token, msgSignature, timestamp, nonce, msgEncrypt string) bool {
|
||||
if token == "" {
|
||||
return true // Skip verification if token is not set
|
||||
}
|
||||
|
||||
// Sort parameters
|
||||
params := []string{token, timestamp, nonce, msgEncrypt}
|
||||
sort.Strings(params)
|
||||
|
||||
// Concatenate
|
||||
str := strings.Join(params, "")
|
||||
|
||||
// SHA1 hash
|
||||
hash := sha1.Sum([]byte(str))
|
||||
expectedSignature := fmt.Sprintf("%x", hash)
|
||||
|
||||
return expectedSignature == msgSignature
|
||||
return computeSignature(token, timestamp, nonce, msgEncrypt) == msgSignature
|
||||
}
|
||||
|
||||
// decryptMessage decrypts the encrypted message using AES
|
||||
@@ -53,64 +54,128 @@ func decryptMessageWithVerify(encryptedMsg, encodingAESKey, receiveid string) (s
|
||||
return string(decoded), nil
|
||||
}
|
||||
|
||||
// Decode AES key (base64)
|
||||
aesKey, err := base64.StdEncoding.DecodeString(encodingAESKey + "=")
|
||||
aesKey, err := decodeWeComAESKey(encodingAESKey)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to decode AES key: %w", err)
|
||||
return "", err
|
||||
}
|
||||
|
||||
// Decode encrypted message
|
||||
cipherText, err := base64.StdEncoding.DecodeString(encryptedMsg)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to decode message: %w", err)
|
||||
}
|
||||
|
||||
// AES decrypt
|
||||
plainText, err := decryptAESCBC(aesKey, cipherText)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return unpackWeComFrame(plainText, receiveid)
|
||||
}
|
||||
|
||||
// decodeWeComAESKey base64-decodes the 43-character EncodingAESKey (trailing "=" is
|
||||
// appended automatically) and validates that the result is exactly 32 bytes.
|
||||
// It is the single place that handles this repeated pattern in both encrypt and decrypt paths.
|
||||
func decodeWeComAESKey(encodingAESKey string) ([]byte, error) {
|
||||
aesKey, err := base64.StdEncoding.DecodeString(encodingAESKey + "=")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to decode AES key: %w", err)
|
||||
}
|
||||
if len(aesKey) != 32 {
|
||||
return nil, fmt.Errorf("invalid AES key length: %d", len(aesKey))
|
||||
}
|
||||
return aesKey, nil
|
||||
}
|
||||
|
||||
// encryptAESCBC encrypts plaintext using AES-CBC with the given key, mirroring
|
||||
// decryptAESCBC. IV = aesKey[:aes.BlockSize]. The caller must PKCS7-pad the
|
||||
// plaintext to a multiple of aes.BlockSize before calling.
|
||||
func encryptAESCBC(aesKey, plaintext []byte) ([]byte, error) {
|
||||
block, err := aes.NewCipher(aesKey)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to create cipher: %w", err)
|
||||
return nil, fmt.Errorf("failed to create cipher: %w", err)
|
||||
}
|
||||
|
||||
if len(cipherText) < aes.BlockSize {
|
||||
return "", fmt.Errorf("ciphertext too short")
|
||||
}
|
||||
|
||||
// IV is the first 16 bytes of AESKey
|
||||
iv := aesKey[:aes.BlockSize]
|
||||
mode := cipher.NewCBCDecrypter(block, iv)
|
||||
plainText := make([]byte, len(cipherText))
|
||||
mode.CryptBlocks(plainText, cipherText)
|
||||
ciphertext := make([]byte, len(plaintext))
|
||||
cipher.NewCBCEncrypter(block, iv).CryptBlocks(ciphertext, plaintext)
|
||||
return ciphertext, nil
|
||||
}
|
||||
|
||||
// Remove PKCS7 padding
|
||||
plainText, err = pkcs7Unpad(plainText)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to unpad: %w", err)
|
||||
// packWeComFrame builds the WeCom wire format:
|
||||
//
|
||||
// random(16 ASCII digits) + msg_len(4, big-endian) + msg + receiveid
|
||||
func packWeComFrame(msg, receiveid string) ([]byte, error) {
|
||||
randomBytes := make([]byte, 16)
|
||||
for i := range 16 {
|
||||
n, err := rand.Int(rand.Reader, big.NewInt(10))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to generate random: %w", err)
|
||||
}
|
||||
randomBytes[i] = byte('0' + n.Int64())
|
||||
}
|
||||
msgBytes := []byte(msg)
|
||||
msgLenBytes := make([]byte, 4)
|
||||
binary.BigEndian.PutUint32(msgLenBytes, uint32(len(msgBytes)))
|
||||
var buf bytes.Buffer
|
||||
buf.Write(randomBytes)
|
||||
buf.Write(msgLenBytes)
|
||||
buf.Write(msgBytes)
|
||||
buf.WriteString(receiveid)
|
||||
return buf.Bytes(), nil
|
||||
}
|
||||
|
||||
// Parse message structure
|
||||
// Format: random(16) + msg_len(4) + msg + receiveid
|
||||
if len(plainText) < 20 {
|
||||
return "", fmt.Errorf("decrypted message too short")
|
||||
// unpackWeComFrame parses the WeCom wire format produced by packWeComFrame.
|
||||
// If receiveid is non-empty it verifies the frame's trailing receiveid field.
|
||||
func unpackWeComFrame(data []byte, receiveid string) (string, error) {
|
||||
if len(data) < 20 {
|
||||
return "", fmt.Errorf("decrypted frame too short: %d bytes", len(data))
|
||||
}
|
||||
|
||||
msgLen := binary.BigEndian.Uint32(plainText[16:20])
|
||||
if int(msgLen) > len(plainText)-20 {
|
||||
return "", fmt.Errorf("invalid message length")
|
||||
msgLen := binary.BigEndian.Uint32(data[16:20])
|
||||
if int(msgLen) > len(data)-20 {
|
||||
return "", fmt.Errorf("invalid message length: %d", msgLen)
|
||||
}
|
||||
|
||||
msg := plainText[20 : 20+msgLen]
|
||||
|
||||
// Verify receiveid if provided
|
||||
if receiveid != "" && len(plainText) > 20+int(msgLen) {
|
||||
actualReceiveID := string(plainText[20+msgLen:])
|
||||
msg := data[20 : 20+msgLen]
|
||||
if receiveid != "" && len(data) > 20+int(msgLen) {
|
||||
actualReceiveID := string(data[20+msgLen:])
|
||||
if actualReceiveID != receiveid {
|
||||
return "", fmt.Errorf("receiveid mismatch: expected %s, got %s", receiveid, actualReceiveID)
|
||||
}
|
||||
}
|
||||
|
||||
return string(msg), nil
|
||||
}
|
||||
|
||||
// decryptAESCBC decrypts ciphertext using AES-CBC with the given key.
|
||||
// IV = aesKey[:aes.BlockSize]. PKCS7 padding is stripped from the returned plaintext.
|
||||
func decryptAESCBC(aesKey, ciphertext []byte) ([]byte, error) {
|
||||
if len(ciphertext) == 0 {
|
||||
return nil, fmt.Errorf("ciphertext is empty")
|
||||
}
|
||||
if len(ciphertext)%aes.BlockSize != 0 {
|
||||
return nil, fmt.Errorf("ciphertext length %d is not a multiple of block size", len(ciphertext))
|
||||
}
|
||||
block, err := aes.NewCipher(aesKey)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create cipher: %w", err)
|
||||
}
|
||||
iv := aesKey[:aes.BlockSize]
|
||||
plaintext := make([]byte, len(ciphertext))
|
||||
cipher.NewCBCDecrypter(block, iv).CryptBlocks(plaintext, ciphertext)
|
||||
plaintext, err = pkcs7Unpad(plaintext)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to unpad: %w", err)
|
||||
}
|
||||
return plaintext, nil
|
||||
}
|
||||
|
||||
// pkcs7Pad adds PKCS7 padding
|
||||
func pkcs7Pad(data []byte, blockSize int) []byte {
|
||||
padding := blockSize - (len(data) % blockSize)
|
||||
if padding == 0 {
|
||||
padding = blockSize
|
||||
}
|
||||
padText := bytes.Repeat([]byte{byte(padding)}, padding)
|
||||
return append(data, padText...)
|
||||
}
|
||||
|
||||
// pkcs7Unpad removes PKCS7 padding with validation
|
||||
func pkcs7Unpad(data []byte) ([]byte, error) {
|
||||
if len(data) == 0 {
|
||||
@@ -125,7 +190,7 @@ func pkcs7Unpad(data []byte) ([]byte, error) {
|
||||
return nil, fmt.Errorf("padding size larger than data")
|
||||
}
|
||||
// Verify all padding bytes
|
||||
for i := 0; i < padding; i++ {
|
||||
for i := range padding {
|
||||
if data[len(data)-1-i] != byte(padding) {
|
||||
return nil, fmt.Errorf("invalid padding byte at position %d", i)
|
||||
}
|
||||
|
||||
@@ -0,0 +1,54 @@
|
||||
package wecom
|
||||
|
||||
import "sync"
|
||||
|
||||
const wecomMaxProcessedMessages = 1000
|
||||
|
||||
// MessageDeduplicator provides thread-safe message deduplication using a circular queue (ring buffer)
|
||||
// combined with a hash map. This ensures fast O(1) lookups while naturally evicting the oldest
|
||||
// messages without causing "amnesia cliffs" when the limit is reached.
|
||||
type MessageDeduplicator struct {
|
||||
mu sync.Mutex
|
||||
msgs map[string]bool
|
||||
ring []string
|
||||
idx int
|
||||
max int
|
||||
}
|
||||
|
||||
// NewMessageDeduplicator creates a new deduplicator with the specified capacity.
|
||||
func NewMessageDeduplicator(maxEntries int) *MessageDeduplicator {
|
||||
if maxEntries <= 0 {
|
||||
maxEntries = wecomMaxProcessedMessages
|
||||
}
|
||||
return &MessageDeduplicator{
|
||||
msgs: make(map[string]bool, maxEntries),
|
||||
ring: make([]string, maxEntries),
|
||||
max: maxEntries,
|
||||
}
|
||||
}
|
||||
|
||||
// MarkMessageProcessed marks msgID as processed and returns false for duplicates.
|
||||
func (d *MessageDeduplicator) MarkMessageProcessed(msgID string) bool {
|
||||
d.mu.Lock()
|
||||
defer d.mu.Unlock()
|
||||
|
||||
// 1. Check for duplicate
|
||||
if d.msgs[msgID] {
|
||||
return false
|
||||
}
|
||||
|
||||
// 2. Evict the oldest message at our current ring position (if any)
|
||||
oldestID := d.ring[d.idx]
|
||||
if oldestID != "" {
|
||||
delete(d.msgs, oldestID)
|
||||
}
|
||||
|
||||
// 3. Store the new message
|
||||
d.msgs[msgID] = true
|
||||
d.ring[d.idx] = msgID
|
||||
|
||||
// 4. Advance the circle queue index
|
||||
d.idx = (d.idx + 1) % d.max
|
||||
|
||||
return true
|
||||
}
|
||||
@@ -0,0 +1,83 @@
|
||||
package wecom
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestMessageDeduplicator_DuplicateDetection(t *testing.T) {
|
||||
d := NewMessageDeduplicator(wecomMaxProcessedMessages)
|
||||
|
||||
if ok := d.MarkMessageProcessed("msg-1"); !ok {
|
||||
t.Fatalf("first message should be accepted")
|
||||
}
|
||||
|
||||
if ok := d.MarkMessageProcessed("msg-1"); ok {
|
||||
t.Fatalf("duplicate message should be rejected")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMessageDeduplicator_ConcurrentSameMessage(t *testing.T) {
|
||||
d := NewMessageDeduplicator(wecomMaxProcessedMessages)
|
||||
|
||||
const goroutines = 64
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(goroutines)
|
||||
|
||||
results := make(chan bool, goroutines)
|
||||
for i := 0; i < goroutines; i++ {
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
results <- d.MarkMessageProcessed("msg-concurrent")
|
||||
}()
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
close(results)
|
||||
|
||||
successes := 0
|
||||
for ok := range results {
|
||||
if ok {
|
||||
successes++
|
||||
}
|
||||
}
|
||||
|
||||
if successes != 1 {
|
||||
t.Fatalf("expected exactly 1 successful mark, got %d", successes)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMessageDeduplicator_CircularQueueEviction(t *testing.T) {
|
||||
// Create a deduplicator with a very small capacity to test eviction easily.
|
||||
capacity := 3
|
||||
d := NewMessageDeduplicator(capacity)
|
||||
|
||||
// Fill the queue.
|
||||
d.MarkMessageProcessed("msg-1")
|
||||
d.MarkMessageProcessed("msg-2")
|
||||
d.MarkMessageProcessed("msg-3")
|
||||
|
||||
// At this point, the queue is full. msg-1 is the oldest.
|
||||
if len(d.msgs) != 3 {
|
||||
t.Fatalf("expected map size to be 3, got %d", len(d.msgs))
|
||||
}
|
||||
|
||||
// This should evict msg-1 and add msg-4.
|
||||
if ok := d.MarkMessageProcessed("msg-4"); !ok {
|
||||
t.Fatalf("msg-4 should be accepted")
|
||||
}
|
||||
|
||||
if len(d.msgs) != 3 {
|
||||
t.Fatalf("expected map size to remain at max capacity (3), got %d", len(d.msgs))
|
||||
}
|
||||
|
||||
// msg-1 should now be forgotten (evicted).
|
||||
if ok := d.MarkMessageProcessed("msg-1"); !ok {
|
||||
t.Fatalf("msg-1 should be accepted again because it was evicted")
|
||||
}
|
||||
|
||||
// msg-2 should have been evicted when we added msg-1 back.
|
||||
if ok := d.MarkMessageProcessed("msg-2"); !ok {
|
||||
t.Fatalf("msg-2 should be accepted again because it was evicted")
|
||||
}
|
||||
}
|
||||
@@ -13,4 +13,7 @@ func init() {
|
||||
channels.RegisterFactory("wecom_app", func(cfg *config.Config, b *bus.MessageBus) (channels.Channel, error) {
|
||||
return NewWeComAppChannel(cfg.Channels.WeComApp, b)
|
||||
})
|
||||
channels.RegisterFactory("wecom_aibot", func(cfg *config.Config, b *bus.MessageBus) (channels.Channel, error) {
|
||||
return NewWeComAIBotChannel(cfg.Channels.WeComAIBot, b)
|
||||
})
|
||||
}
|
||||
|
||||
+96
-52
@@ -168,17 +168,28 @@ type SessionConfig struct {
|
||||
}
|
||||
|
||||
type AgentDefaults struct {
|
||||
Workspace string `json:"workspace" env:"PICOCLAW_AGENTS_DEFAULTS_WORKSPACE"`
|
||||
RestrictToWorkspace bool `json:"restrict_to_workspace" env:"PICOCLAW_AGENTS_DEFAULTS_RESTRICT_TO_WORKSPACE"`
|
||||
Provider string `json:"provider" env:"PICOCLAW_AGENTS_DEFAULTS_PROVIDER"`
|
||||
ModelName string `json:"model_name,omitempty" env:"PICOCLAW_AGENTS_DEFAULTS_MODEL_NAME"`
|
||||
Model string `json:"model" env:"PICOCLAW_AGENTS_DEFAULTS_MODEL"` // Deprecated: use model_name instead
|
||||
ModelFallbacks []string `json:"model_fallbacks,omitempty"`
|
||||
ImageModel string `json:"image_model,omitempty" env:"PICOCLAW_AGENTS_DEFAULTS_IMAGE_MODEL"`
|
||||
ImageModelFallbacks []string `json:"image_model_fallbacks,omitempty"`
|
||||
MaxTokens int `json:"max_tokens" env:"PICOCLAW_AGENTS_DEFAULTS_MAX_TOKENS"`
|
||||
Temperature *float64 `json:"temperature,omitempty" env:"PICOCLAW_AGENTS_DEFAULTS_TEMPERATURE"`
|
||||
MaxToolIterations int `json:"max_tool_iterations" env:"PICOCLAW_AGENTS_DEFAULTS_MAX_TOOL_ITERATIONS"`
|
||||
Workspace string `json:"workspace" env:"PICOCLAW_AGENTS_DEFAULTS_WORKSPACE"`
|
||||
RestrictToWorkspace bool `json:"restrict_to_workspace" env:"PICOCLAW_AGENTS_DEFAULTS_RESTRICT_TO_WORKSPACE"`
|
||||
AllowReadOutsideWorkspace bool `json:"allow_read_outside_workspace" env:"PICOCLAW_AGENTS_DEFAULTS_ALLOW_READ_OUTSIDE_WORKSPACE"`
|
||||
Provider string `json:"provider" env:"PICOCLAW_AGENTS_DEFAULTS_PROVIDER"`
|
||||
ModelName string `json:"model_name,omitempty" env:"PICOCLAW_AGENTS_DEFAULTS_MODEL_NAME"`
|
||||
Model string `json:"model" env:"PICOCLAW_AGENTS_DEFAULTS_MODEL"` // Deprecated: use model_name instead
|
||||
ModelFallbacks []string `json:"model_fallbacks,omitempty"`
|
||||
ImageModel string `json:"image_model,omitempty" env:"PICOCLAW_AGENTS_DEFAULTS_IMAGE_MODEL"`
|
||||
ImageModelFallbacks []string `json:"image_model_fallbacks,omitempty"`
|
||||
MaxTokens int `json:"max_tokens" env:"PICOCLAW_AGENTS_DEFAULTS_MAX_TOKENS"`
|
||||
Temperature *float64 `json:"temperature,omitempty" env:"PICOCLAW_AGENTS_DEFAULTS_TEMPERATURE"`
|
||||
MaxToolIterations int `json:"max_tool_iterations" env:"PICOCLAW_AGENTS_DEFAULTS_MAX_TOOL_ITERATIONS"`
|
||||
MaxMediaSize int `json:"max_media_size,omitempty" env:"PICOCLAW_AGENTS_DEFAULTS_MAX_MEDIA_SIZE"`
|
||||
}
|
||||
|
||||
const DefaultMaxMediaSize = 20 * 1024 * 1024 // 20 MB
|
||||
|
||||
func (d *AgentDefaults) GetMaxMediaSize() int {
|
||||
if d.MaxMediaSize > 0 {
|
||||
return d.MaxMediaSize
|
||||
}
|
||||
return DefaultMaxMediaSize
|
||||
}
|
||||
|
||||
// GetModelName returns the effective model name for the agent defaults.
|
||||
@@ -191,19 +202,20 @@ func (d *AgentDefaults) GetModelName() string {
|
||||
}
|
||||
|
||||
type ChannelsConfig struct {
|
||||
WhatsApp WhatsAppConfig `json:"whatsapp"`
|
||||
Telegram TelegramConfig `json:"telegram"`
|
||||
Feishu FeishuConfig `json:"feishu"`
|
||||
Discord DiscordConfig `json:"discord"`
|
||||
MaixCam MaixCamConfig `json:"maixcam"`
|
||||
QQ QQConfig `json:"qq"`
|
||||
DingTalk DingTalkConfig `json:"dingtalk"`
|
||||
Slack SlackConfig `json:"slack"`
|
||||
LINE LINEConfig `json:"line"`
|
||||
OneBot OneBotConfig `json:"onebot"`
|
||||
WeCom WeComConfig `json:"wecom"`
|
||||
WeComApp WeComAppConfig `json:"wecom_app"`
|
||||
Pico PicoConfig `json:"pico"`
|
||||
WhatsApp WhatsAppConfig `json:"whatsapp"`
|
||||
Telegram TelegramConfig `json:"telegram"`
|
||||
Feishu FeishuConfig `json:"feishu"`
|
||||
Discord DiscordConfig `json:"discord"`
|
||||
MaixCam MaixCamConfig `json:"maixcam"`
|
||||
QQ QQConfig `json:"qq"`
|
||||
DingTalk DingTalkConfig `json:"dingtalk"`
|
||||
Slack SlackConfig `json:"slack"`
|
||||
LINE LINEConfig `json:"line"`
|
||||
OneBot OneBotConfig `json:"onebot"`
|
||||
WeCom WeComConfig `json:"wecom"`
|
||||
WeComApp WeComAppConfig `json:"wecom_app"`
|
||||
WeComAIBot WeComAIBotConfig `json:"wecom_aibot"`
|
||||
Pico PicoConfig `json:"pico"`
|
||||
}
|
||||
|
||||
// GroupTriggerConfig controls when the bot responds in group chats.
|
||||
@@ -235,6 +247,7 @@ type WhatsAppConfig struct {
|
||||
type TelegramConfig struct {
|
||||
Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_TELEGRAM_ENABLED"`
|
||||
Token string `json:"token" env:"PICOCLAW_CHANNELS_TELEGRAM_TOKEN"`
|
||||
BaseURL string `json:"base_url" env:"PICOCLAW_CHANNELS_TELEGRAM_BASE_URL"`
|
||||
Proxy string `json:"proxy" env:"PICOCLAW_CHANNELS_TELEGRAM_PROXY"`
|
||||
AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_TELEGRAM_ALLOW_FROM"`
|
||||
GroupTrigger GroupTriggerConfig `json:"group_trigger,omitempty"`
|
||||
@@ -251,6 +264,7 @@ type FeishuConfig struct {
|
||||
VerificationToken string `json:"verification_token" env:"PICOCLAW_CHANNELS_FEISHU_VERIFICATION_TOKEN"`
|
||||
AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_FEISHU_ALLOW_FROM"`
|
||||
GroupTrigger GroupTriggerConfig `json:"group_trigger,omitempty"`
|
||||
Placeholder PlaceholderConfig `json:"placeholder,omitempty"`
|
||||
ReasoningChannelID string `json:"reasoning_channel_id" env:"PICOCLAW_CHANNELS_FEISHU_REASONING_CHANNEL_ID"`
|
||||
}
|
||||
|
||||
@@ -359,6 +373,18 @@ type WeComAppConfig struct {
|
||||
ReasoningChannelID string `json:"reasoning_channel_id" env:"PICOCLAW_CHANNELS_WECOM_APP_REASONING_CHANNEL_ID"`
|
||||
}
|
||||
|
||||
type WeComAIBotConfig struct {
|
||||
Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_WECOM_AIBOT_ENABLED"`
|
||||
Token string `json:"token" env:"PICOCLAW_CHANNELS_WECOM_AIBOT_TOKEN"`
|
||||
EncodingAESKey string `json:"encoding_aes_key" env:"PICOCLAW_CHANNELS_WECOM_AIBOT_ENCODING_AES_KEY"`
|
||||
WebhookPath string `json:"webhook_path" env:"PICOCLAW_CHANNELS_WECOM_AIBOT_WEBHOOK_PATH"`
|
||||
AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_WECOM_AIBOT_ALLOW_FROM"`
|
||||
ReplyTimeout int `json:"reply_timeout" env:"PICOCLAW_CHANNELS_WECOM_AIBOT_REPLY_TIMEOUT"`
|
||||
MaxSteps int `json:"max_steps" env:"PICOCLAW_CHANNELS_WECOM_AIBOT_MAX_STEPS"` // Maximum streaming steps
|
||||
WelcomeMessage string `json:"welcome_message" env:"PICOCLAW_CHANNELS_WECOM_AIBOT_WELCOME_MESSAGE"` // Sent on enter_chat event; empty = no welcome
|
||||
ReasoningChannelID string `json:"reasoning_channel_id" env:"PICOCLAW_CHANNELS_WECOM_AIBOT_REASONING_CHANNEL_ID"`
|
||||
}
|
||||
|
||||
type PicoConfig struct {
|
||||
Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_PICO_ENABLED"`
|
||||
Token string `json:"token" env:"PICOCLAW_CHANNELS_PICO_TOKEN"`
|
||||
@@ -385,6 +411,7 @@ type DevicesConfig struct {
|
||||
type ProvidersConfig struct {
|
||||
Anthropic ProviderConfig `json:"anthropic"`
|
||||
OpenAI OpenAIProviderConfig `json:"openai"`
|
||||
LiteLLM ProviderConfig `json:"litellm"`
|
||||
OpenRouter ProviderConfig `json:"openrouter"`
|
||||
Groq ProviderConfig `json:"groq"`
|
||||
Zhipu ProviderConfig `json:"zhipu"`
|
||||
@@ -408,6 +435,7 @@ type ProvidersConfig struct {
|
||||
func (p ProvidersConfig) IsEmpty() bool {
|
||||
return p.Anthropic.APIKey == "" && p.Anthropic.APIBase == "" &&
|
||||
p.OpenAI.APIKey == "" && p.OpenAI.APIBase == "" &&
|
||||
p.LiteLLM.APIKey == "" && p.LiteLLM.APIBase == "" &&
|
||||
p.OpenRouter.APIKey == "" && p.OpenRouter.APIBase == "" &&
|
||||
p.Groq.APIKey == "" && p.Groq.APIBase == "" &&
|
||||
p.Zhipu.APIKey == "" && p.Zhipu.APIBase == "" &&
|
||||
@@ -523,7 +551,8 @@ type WebToolsConfig struct {
|
||||
Perplexity PerplexityConfig `json:"perplexity"`
|
||||
// Proxy is an optional proxy URL for web tools (http/https/socks5/socks5h).
|
||||
// For authenticated proxies, prefer HTTP_PROXY/HTTPS_PROXY env vars instead of embedding credentials in config.
|
||||
Proxy string `json:"proxy,omitempty" env:"PICOCLAW_TOOLS_WEB_PROXY"`
|
||||
Proxy string `json:"proxy,omitempty" env:"PICOCLAW_TOOLS_WEB_PROXY"`
|
||||
FetchLimitBytes int64 `json:"fetch_limit_bytes,omitempty" env:"PICOCLAW_TOOLS_WEB_FETCH_LIMIT_BYTES"`
|
||||
}
|
||||
|
||||
type CronToolsConfig struct {
|
||||
@@ -531,8 +560,9 @@ type CronToolsConfig struct {
|
||||
}
|
||||
|
||||
type ExecConfig struct {
|
||||
EnableDenyPatterns bool `json:"enable_deny_patterns" env:"PICOCLAW_TOOLS_EXEC_ENABLE_DENY_PATTERNS"`
|
||||
CustomDenyPatterns []string `json:"custom_deny_patterns" env:"PICOCLAW_TOOLS_EXEC_CUSTOM_DENY_PATTERNS"`
|
||||
EnableDenyPatterns bool `json:"enable_deny_patterns" env:"PICOCLAW_TOOLS_EXEC_ENABLE_DENY_PATTERNS"`
|
||||
CustomDenyPatterns []string `json:"custom_deny_patterns" env:"PICOCLAW_TOOLS_EXEC_CUSTOM_DENY_PATTERNS"`
|
||||
CustomAllowPatterns []string `json:"custom_allow_patterns" env:"PICOCLAW_TOOLS_EXEC_CUSTOM_ALLOW_PATTERNS"`
|
||||
}
|
||||
|
||||
type MediaCleanupConfig struct {
|
||||
@@ -542,11 +572,14 @@ type MediaCleanupConfig struct {
|
||||
}
|
||||
|
||||
type ToolsConfig struct {
|
||||
Web WebToolsConfig `json:"web"`
|
||||
Cron CronToolsConfig `json:"cron"`
|
||||
Exec ExecConfig `json:"exec"`
|
||||
Skills SkillsToolsConfig `json:"skills"`
|
||||
MediaCleanup MediaCleanupConfig `json:"media_cleanup"`
|
||||
AllowReadPaths []string `json:"allow_read_paths" env:"PICOCLAW_TOOLS_ALLOW_READ_PATHS"`
|
||||
AllowWritePaths []string `json:"allow_write_paths" env:"PICOCLAW_TOOLS_ALLOW_WRITE_PATHS"`
|
||||
Web WebToolsConfig `json:"web"`
|
||||
Cron CronToolsConfig `json:"cron"`
|
||||
Exec ExecConfig `json:"exec"`
|
||||
Skills SkillsToolsConfig `json:"skills"`
|
||||
MediaCleanup MediaCleanupConfig `json:"media_cleanup"`
|
||||
MCP MCPConfig `json:"mcp"`
|
||||
}
|
||||
|
||||
type SkillsToolsConfig struct {
|
||||
@@ -576,6 +609,34 @@ type ClawHubRegistryConfig struct {
|
||||
MaxResponseSize int `json:"max_response_size" env:"PICOCLAW_SKILLS_REGISTRIES_CLAWHUB_MAX_RESPONSE_SIZE"`
|
||||
}
|
||||
|
||||
// MCPServerConfig defines configuration for a single MCP server
|
||||
type MCPServerConfig struct {
|
||||
// Enabled indicates whether this MCP server is active
|
||||
Enabled bool `json:"enabled"`
|
||||
// Command is the executable to run (e.g., "npx", "python", "/path/to/server")
|
||||
Command string `json:"command"`
|
||||
// Args are the arguments to pass to the command
|
||||
Args []string `json:"args,omitempty"`
|
||||
// Env are environment variables to set for the server process (stdio only)
|
||||
Env map[string]string `json:"env,omitempty"`
|
||||
// EnvFile is the path to a file containing environment variables (stdio only)
|
||||
EnvFile string `json:"env_file,omitempty"`
|
||||
// Type is "stdio", "sse", or "http" (default: stdio if command is set, sse if url is set)
|
||||
Type string `json:"type,omitempty"`
|
||||
// URL is used for SSE/HTTP transport
|
||||
URL string `json:"url,omitempty"`
|
||||
// Headers are HTTP headers to send with requests (sse/http only)
|
||||
Headers map[string]string `json:"headers,omitempty"`
|
||||
}
|
||||
|
||||
// MCPConfig defines configuration for all MCP servers
|
||||
type MCPConfig struct {
|
||||
// Enabled globally enables/disables MCP integration
|
||||
Enabled bool `json:"enabled" env:"PICOCLAW_TOOLS_MCP_ENABLED"`
|
||||
// Servers is a map of server name to server configuration
|
||||
Servers map[string]MCPServerConfig `json:"servers,omitempty"`
|
||||
}
|
||||
|
||||
func LoadConfig(path string) (*Config, error) {
|
||||
cfg := DefaultConfig()
|
||||
|
||||
@@ -632,7 +693,8 @@ func (c *Config) migrateChannelConfigs() {
|
||||
}
|
||||
|
||||
// OneBot: group_trigger_prefix -> group_trigger.prefixes
|
||||
if len(c.Channels.OneBot.GroupTriggerPrefix) > 0 && len(c.Channels.OneBot.GroupTrigger.Prefixes) == 0 {
|
||||
if len(c.Channels.OneBot.GroupTriggerPrefix) > 0 &&
|
||||
len(c.Channels.OneBot.GroupTrigger.Prefixes) == 0 {
|
||||
c.Channels.OneBot.GroupTrigger.Prefixes = c.Channels.OneBot.GroupTriggerPrefix
|
||||
}
|
||||
}
|
||||
@@ -742,25 +804,7 @@ func (c *Config) findMatches(modelName string) []ModelConfig {
|
||||
|
||||
// HasProvidersConfig checks if any provider in the old providers config has configuration.
|
||||
func (c *Config) HasProvidersConfig() bool {
|
||||
v := c.Providers
|
||||
return v.Anthropic.APIKey != "" || v.Anthropic.APIBase != "" ||
|
||||
v.OpenAI.APIKey != "" || v.OpenAI.APIBase != "" ||
|
||||
v.OpenRouter.APIKey != "" || v.OpenRouter.APIBase != "" ||
|
||||
v.Groq.APIKey != "" || v.Groq.APIBase != "" ||
|
||||
v.Zhipu.APIKey != "" || v.Zhipu.APIBase != "" ||
|
||||
v.VLLM.APIKey != "" || v.VLLM.APIBase != "" ||
|
||||
v.Gemini.APIKey != "" || v.Gemini.APIBase != "" ||
|
||||
v.Nvidia.APIKey != "" || v.Nvidia.APIBase != "" ||
|
||||
v.Ollama.APIKey != "" || v.Ollama.APIBase != "" ||
|
||||
v.Moonshot.APIKey != "" || v.Moonshot.APIBase != "" ||
|
||||
v.ShengSuanYun.APIKey != "" || v.ShengSuanYun.APIBase != "" ||
|
||||
v.DeepSeek.APIKey != "" || v.DeepSeek.APIBase != "" ||
|
||||
v.Cerebras.APIKey != "" || v.Cerebras.APIBase != "" ||
|
||||
v.VolcEngine.APIKey != "" || v.VolcEngine.APIBase != "" ||
|
||||
v.GitHubCopilot.APIKey != "" || v.GitHubCopilot.APIBase != "" ||
|
||||
v.Antigravity.APIKey != "" || v.Antigravity.APIBase != "" ||
|
||||
v.Qwen.APIKey != "" || v.Qwen.APIBase != "" ||
|
||||
v.Mistral.APIKey != "" || v.Mistral.APIBase != ""
|
||||
return !c.Providers.IsEmpty()
|
||||
}
|
||||
|
||||
// ValidateModelList validates all ModelConfig entries in the model_list.
|
||||
|
||||
@@ -442,3 +442,28 @@ func TestDefaultConfig_DMScope(t *testing.T) {
|
||||
t.Errorf("Session.DMScope = %q, want 'per-channel-peer'", cfg.Session.DMScope)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDefaultConfig_WorkspacePath_Default(t *testing.T) {
|
||||
// Unset to ensure we test the default
|
||||
t.Setenv("PICOCLAW_HOME", "")
|
||||
// Set a known home for consistent test results
|
||||
t.Setenv("HOME", "/tmp/home")
|
||||
|
||||
cfg := DefaultConfig()
|
||||
want := filepath.Join("/tmp/home", ".picoclaw", "workspace")
|
||||
|
||||
if cfg.Agents.Defaults.Workspace != want {
|
||||
t.Errorf("Default workspace path = %q, want %q", cfg.Agents.Defaults.Workspace, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDefaultConfig_WorkspacePath_WithPicoclawHome(t *testing.T) {
|
||||
t.Setenv("PICOCLAW_HOME", "/custom/picoclaw/home")
|
||||
|
||||
cfg := DefaultConfig()
|
||||
want := "/custom/picoclaw/home/workspace"
|
||||
|
||||
if cfg.Agents.Defaults.Workspace != want {
|
||||
t.Errorf("Workspace path with PICOCLAW_HOME = %q, want %q", cfg.Agents.Defaults.Workspace, want)
|
||||
}
|
||||
}
|
||||
|
||||
+33
-2
@@ -5,12 +5,28 @@
|
||||
|
||||
package config
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
)
|
||||
|
||||
// DefaultConfig returns the default configuration for PicoClaw.
|
||||
func DefaultConfig() *Config {
|
||||
// Determine the base path for the workspace.
|
||||
// Priority: $PICOCLAW_HOME > ~/.picoclaw
|
||||
var homePath string
|
||||
if picoclawHome := os.Getenv("PICOCLAW_HOME"); picoclawHome != "" {
|
||||
homePath = picoclawHome
|
||||
} else {
|
||||
userHome, _ := os.UserHomeDir()
|
||||
homePath = filepath.Join(userHome, ".picoclaw")
|
||||
}
|
||||
workspacePath := filepath.Join(homePath, "workspace")
|
||||
|
||||
return &Config{
|
||||
Agents: AgentsConfig{
|
||||
Defaults: AgentDefaults{
|
||||
Workspace: "~/.picoclaw/workspace",
|
||||
Workspace: workspacePath,
|
||||
RestrictToWorkspace: true,
|
||||
Provider: "",
|
||||
Model: "",
|
||||
@@ -121,6 +137,16 @@ func DefaultConfig() *Config {
|
||||
AllowFrom: FlexibleStringSlice{},
|
||||
ReplyTimeout: 5,
|
||||
},
|
||||
WeComAIBot: WeComAIBotConfig{
|
||||
Enabled: false,
|
||||
Token: "",
|
||||
EncodingAESKey: "",
|
||||
WebhookPath: "/webhook/wecom-aibot",
|
||||
AllowFrom: FlexibleStringSlice{},
|
||||
ReplyTimeout: 5,
|
||||
MaxSteps: 10,
|
||||
WelcomeMessage: "Hello! I'm your AI assistant. How can I help you today?",
|
||||
},
|
||||
Pico: PicoConfig{
|
||||
Enabled: false,
|
||||
Token: "",
|
||||
@@ -299,7 +325,8 @@ func DefaultConfig() *Config {
|
||||
Interval: 5,
|
||||
},
|
||||
Web: WebToolsConfig{
|
||||
Proxy: "",
|
||||
Proxy: "",
|
||||
FetchLimitBytes: 10 * 1024 * 1024, // 10MB by default
|
||||
Brave: BraveConfig{
|
||||
Enabled: false,
|
||||
APIKey: "",
|
||||
@@ -334,6 +361,10 @@ func DefaultConfig() *Config {
|
||||
TTLSeconds: 300,
|
||||
},
|
||||
},
|
||||
MCP: MCPConfig{
|
||||
Enabled: false,
|
||||
Servers: map[string]MCPServerConfig{},
|
||||
},
|
||||
},
|
||||
Heartbeat: HeartbeatConfig{
|
||||
Enabled: true,
|
||||
|
||||
@@ -88,6 +88,23 @@ func ConvertProvidersToModelList(cfg *Config) []ModelConfig {
|
||||
}, true
|
||||
},
|
||||
},
|
||||
{
|
||||
providerNames: []string{"litellm"},
|
||||
protocol: "litellm",
|
||||
buildConfig: func(p ProvidersConfig) (ModelConfig, bool) {
|
||||
if p.LiteLLM.APIKey == "" && p.LiteLLM.APIBase == "" {
|
||||
return ModelConfig{}, false
|
||||
}
|
||||
return ModelConfig{
|
||||
ModelName: "litellm",
|
||||
Model: "litellm/auto",
|
||||
APIKey: p.LiteLLM.APIKey,
|
||||
APIBase: p.LiteLLM.APIBase,
|
||||
Proxy: p.LiteLLM.Proxy,
|
||||
RequestTimeout: p.LiteLLM.RequestTimeout,
|
||||
}, true
|
||||
},
|
||||
},
|
||||
{
|
||||
providerNames: []string{"openrouter"},
|
||||
protocol: "openrouter",
|
||||
|
||||
@@ -63,6 +63,33 @@ func TestConvertProvidersToModelList_Anthropic(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertProvidersToModelList_LiteLLM(t *testing.T) {
|
||||
cfg := &Config{
|
||||
Providers: ProvidersConfig{
|
||||
LiteLLM: ProviderConfig{
|
||||
APIKey: "litellm-key",
|
||||
APIBase: "http://localhost:4000/v1",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
result := ConvertProvidersToModelList(cfg)
|
||||
|
||||
if len(result) != 1 {
|
||||
t.Fatalf("len(result) = %d, want 1", len(result))
|
||||
}
|
||||
|
||||
if result[0].ModelName != "litellm" {
|
||||
t.Errorf("ModelName = %q, want %q", result[0].ModelName, "litellm")
|
||||
}
|
||||
if result[0].Model != "litellm/auto" {
|
||||
t.Errorf("Model = %q, want %q", result[0].Model, "litellm/auto")
|
||||
}
|
||||
if result[0].APIBase != "http://localhost:4000/v1" {
|
||||
t.Errorf("APIBase = %q, want %q", result[0].APIBase, "http://localhost:4000/v1")
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertProvidersToModelList_Multiple(t *testing.T) {
|
||||
cfg := &Config{
|
||||
Providers: ProvidersConfig{
|
||||
@@ -115,6 +142,7 @@ func TestConvertProvidersToModelList_AllProviders(t *testing.T) {
|
||||
cfg := &Config{
|
||||
Providers: ProvidersConfig{
|
||||
OpenAI: OpenAIProviderConfig{ProviderConfig: ProviderConfig{APIKey: "key1"}},
|
||||
LiteLLM: ProviderConfig{APIKey: "key-litellm", APIBase: "http://localhost:4000/v1"},
|
||||
Anthropic: ProviderConfig{APIKey: "key2"},
|
||||
OpenRouter: ProviderConfig{APIKey: "key3"},
|
||||
Groq: ProviderConfig{APIKey: "key4"},
|
||||
@@ -137,9 +165,9 @@ func TestConvertProvidersToModelList_AllProviders(t *testing.T) {
|
||||
|
||||
result := ConvertProvidersToModelList(cfg)
|
||||
|
||||
// All 18 providers should be converted
|
||||
if len(result) != 18 {
|
||||
t.Errorf("len(result) = %d, want 18", len(result))
|
||||
// All 19 providers should be converted
|
||||
if len(result) != 19 {
|
||||
t.Errorf("len(result) = %d, want 19", len(result))
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -64,7 +64,7 @@ func TestGetModelConfig_RoundRobin(t *testing.T) {
|
||||
|
||||
// Test round-robin distribution
|
||||
results := make(map[string]int)
|
||||
for i := 0; i < 30; i++ {
|
||||
for range 30 {
|
||||
result, err := cfg.GetModelConfig("lb-model")
|
||||
if err != nil {
|
||||
t.Fatalf("GetModelConfig() error = %v", err)
|
||||
@@ -94,17 +94,15 @@ func TestGetModelConfig_Concurrent(t *testing.T) {
|
||||
var wg sync.WaitGroup
|
||||
errors := make(chan error, goroutines*iterations)
|
||||
|
||||
for i := 0; i < goroutines; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
for j := 0; j < iterations; j++ {
|
||||
for range goroutines {
|
||||
wg.Go(func() {
|
||||
for range iterations {
|
||||
_, err := cfg.GetModelConfig("concurrent-model")
|
||||
if err != nil {
|
||||
errors <- err
|
||||
}
|
||||
}
|
||||
}()
|
||||
})
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"maps"
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
@@ -122,9 +123,7 @@ func (s *Server) readyHandler(w http.ResponseWriter, r *http.Request) {
|
||||
s.mu.RLock()
|
||||
ready := s.ready
|
||||
checks := make(map[string]Check)
|
||||
for k, v := range s.checks {
|
||||
checks[k] = v
|
||||
}
|
||||
maps.Copy(checks, s.checks)
|
||||
s.mu.RUnlock()
|
||||
|
||||
if !ready {
|
||||
|
||||
@@ -47,79 +47,63 @@ func TestExecuteHeartbeat_Async(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestExecuteHeartbeat_Error(t *testing.T) {
|
||||
tmpDir, err := os.MkdirTemp("", "heartbeat-test-*")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temp dir: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
hs := NewHeartbeatService(tmpDir, 30, true)
|
||||
hs.stopChan = make(chan struct{}) // Enable for testing
|
||||
|
||||
hs.SetHandler(func(prompt, channel, chatID string) *tools.ToolResult {
|
||||
return &tools.ToolResult{
|
||||
ForLLM: "Heartbeat failed: connection error",
|
||||
ForUser: "",
|
||||
Silent: false,
|
||||
IsError: true,
|
||||
Async: false,
|
||||
}
|
||||
})
|
||||
|
||||
// Create HEARTBEAT.md
|
||||
os.WriteFile(filepath.Join(tmpDir, "HEARTBEAT.md"), []byte("Test task"), 0o644)
|
||||
|
||||
hs.executeHeartbeat()
|
||||
|
||||
// Check log file for error message
|
||||
logFile := filepath.Join(tmpDir, "heartbeat.log")
|
||||
data, err := os.ReadFile(logFile)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to read log file: %v", err)
|
||||
func TestExecuteHeartbeat_ResultLogging(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
result *tools.ToolResult
|
||||
wantLog string
|
||||
}{
|
||||
{
|
||||
name: "error result",
|
||||
result: &tools.ToolResult{
|
||||
ForLLM: "Heartbeat failed: connection error",
|
||||
ForUser: "",
|
||||
Silent: false,
|
||||
IsError: true,
|
||||
Async: false,
|
||||
},
|
||||
wantLog: "error message",
|
||||
},
|
||||
{
|
||||
name: "silent result",
|
||||
result: &tools.ToolResult{
|
||||
ForLLM: "Heartbeat completed successfully",
|
||||
ForUser: "",
|
||||
Silent: true,
|
||||
IsError: false,
|
||||
Async: false,
|
||||
},
|
||||
wantLog: "completion message",
|
||||
},
|
||||
}
|
||||
|
||||
logContent := string(data)
|
||||
if logContent == "" {
|
||||
t.Error("Expected log file to contain error message")
|
||||
}
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
tmpDir, err := os.MkdirTemp("", "heartbeat-test-*")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temp dir: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
func TestExecuteHeartbeat_Silent(t *testing.T) {
|
||||
tmpDir, err := os.MkdirTemp("", "heartbeat-test-*")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temp dir: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tmpDir)
|
||||
hs := NewHeartbeatService(tmpDir, 30, true)
|
||||
hs.stopChan = make(chan struct{}) // Enable for testing
|
||||
|
||||
hs := NewHeartbeatService(tmpDir, 30, true)
|
||||
hs.stopChan = make(chan struct{}) // Enable for testing
|
||||
hs.SetHandler(func(prompt, channel, chatID string) *tools.ToolResult {
|
||||
return tt.result
|
||||
})
|
||||
|
||||
hs.SetHandler(func(prompt, channel, chatID string) *tools.ToolResult {
|
||||
return &tools.ToolResult{
|
||||
ForLLM: "Heartbeat completed successfully",
|
||||
ForUser: "",
|
||||
Silent: true,
|
||||
IsError: false,
|
||||
Async: false,
|
||||
}
|
||||
})
|
||||
os.WriteFile(filepath.Join(tmpDir, "HEARTBEAT.md"), []byte("Test task"), 0o644)
|
||||
hs.executeHeartbeat()
|
||||
|
||||
// Create HEARTBEAT.md
|
||||
os.WriteFile(filepath.Join(tmpDir, "HEARTBEAT.md"), []byte("Test task"), 0o644)
|
||||
|
||||
hs.executeHeartbeat()
|
||||
|
||||
// Check log file for completion message
|
||||
logFile := filepath.Join(tmpDir, "heartbeat.log")
|
||||
data, err := os.ReadFile(logFile)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to read log file: %v", err)
|
||||
}
|
||||
|
||||
logContent := string(data)
|
||||
if logContent == "" {
|
||||
t.Error("Expected log file to contain completion message")
|
||||
logFile := filepath.Join(tmpDir, "heartbeat.log")
|
||||
data, err := os.ReadFile(logFile)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to read log file: %v", err)
|
||||
}
|
||||
if string(data) == "" {
|
||||
t.Errorf("Expected log file to contain %s", tt.wantLog)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -0,0 +1,532 @@
|
||||
package mcp
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/modelcontextprotocol/go-sdk/mcp"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
"github.com/sipeed/picoclaw/pkg/logger"
|
||||
)
|
||||
|
||||
// headerTransport is an http.RoundTripper that adds custom headers to requests
|
||||
type headerTransport struct {
|
||||
base http.RoundTripper
|
||||
headers map[string]string
|
||||
}
|
||||
|
||||
func (t *headerTransport) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
// Clone the request to avoid modifying the original
|
||||
req = req.Clone(req.Context())
|
||||
|
||||
// Add custom headers
|
||||
for key, value := range t.headers {
|
||||
req.Header.Set(key, value)
|
||||
}
|
||||
|
||||
// Use the base transport
|
||||
base := t.base
|
||||
if base == nil {
|
||||
base = http.DefaultTransport
|
||||
}
|
||||
return base.RoundTrip(req)
|
||||
}
|
||||
|
||||
// loadEnvFile loads environment variables from a file in .env format
|
||||
// Each line should be in the format: KEY=value
|
||||
// Lines starting with # are comments
|
||||
// Empty lines are ignored
|
||||
func loadEnvFile(path string) (map[string]string, error) {
|
||||
file, err := os.Open(path)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to open env file: %w", err)
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
envVars := make(map[string]string)
|
||||
scanner := bufio.NewScanner(file)
|
||||
lineNum := 0
|
||||
|
||||
for scanner.Scan() {
|
||||
lineNum++
|
||||
line := strings.TrimSpace(scanner.Text())
|
||||
|
||||
// Skip empty lines and comments
|
||||
if line == "" || strings.HasPrefix(line, "#") {
|
||||
continue
|
||||
}
|
||||
|
||||
// Parse KEY=value
|
||||
parts := strings.SplitN(line, "=", 2)
|
||||
if len(parts) != 2 {
|
||||
return nil, fmt.Errorf("invalid format at line %d: %s", lineNum, line)
|
||||
}
|
||||
|
||||
key := strings.TrimSpace(parts[0])
|
||||
value := strings.TrimSpace(parts[1])
|
||||
|
||||
if key == "" {
|
||||
return nil, fmt.Errorf("invalid format at line %d: empty key", lineNum)
|
||||
}
|
||||
|
||||
// Remove surrounding quotes if present
|
||||
if len(value) >= 2 {
|
||||
if (value[0] == '"' && value[len(value)-1] == '"') ||
|
||||
(value[0] == '\'' && value[len(value)-1] == '\'') {
|
||||
value = value[1 : len(value)-1]
|
||||
}
|
||||
}
|
||||
|
||||
envVars[key] = value
|
||||
}
|
||||
|
||||
if err := scanner.Err(); err != nil {
|
||||
return nil, fmt.Errorf("error reading env file: %w", err)
|
||||
}
|
||||
|
||||
return envVars, nil
|
||||
}
|
||||
|
||||
// ServerConnection represents a connection to an MCP server
|
||||
type ServerConnection struct {
|
||||
Name string
|
||||
Client *mcp.Client
|
||||
Session *mcp.ClientSession
|
||||
Tools []*mcp.Tool
|
||||
}
|
||||
|
||||
// Manager manages multiple MCP server connections
|
||||
type Manager struct {
|
||||
servers map[string]*ServerConnection
|
||||
mu sync.RWMutex
|
||||
closed atomic.Bool // changed from bool to atomic.Bool to avoid TOCTOU race
|
||||
wg sync.WaitGroup // tracks in-flight CallTool calls
|
||||
}
|
||||
|
||||
// NewManager creates a new MCP manager
|
||||
func NewManager() *Manager {
|
||||
return &Manager{
|
||||
servers: make(map[string]*ServerConnection),
|
||||
}
|
||||
}
|
||||
|
||||
// LoadFromConfig loads MCP servers from configuration
|
||||
func (m *Manager) LoadFromConfig(ctx context.Context, cfg *config.Config) error {
|
||||
return m.LoadFromMCPConfig(ctx, cfg.Tools.MCP, cfg.WorkspacePath())
|
||||
}
|
||||
|
||||
// LoadFromMCPConfig loads MCP servers from MCP configuration and workspace path.
|
||||
// This is the minimal dependency version that doesn't require the full Config object.
|
||||
func (m *Manager) LoadFromMCPConfig(
|
||||
ctx context.Context,
|
||||
mcpCfg config.MCPConfig,
|
||||
workspacePath string,
|
||||
) error {
|
||||
if !mcpCfg.Enabled {
|
||||
logger.InfoCF("mcp", "MCP integration is disabled", nil)
|
||||
return nil
|
||||
}
|
||||
|
||||
if len(mcpCfg.Servers) == 0 {
|
||||
logger.InfoCF("mcp", "No MCP servers configured", nil)
|
||||
return nil
|
||||
}
|
||||
|
||||
logger.InfoCF("mcp", "Initializing MCP servers",
|
||||
map[string]any{
|
||||
"count": len(mcpCfg.Servers),
|
||||
})
|
||||
|
||||
var wg sync.WaitGroup
|
||||
errs := make(chan error, len(mcpCfg.Servers))
|
||||
enabledCount := 0
|
||||
|
||||
for name, serverCfg := range mcpCfg.Servers {
|
||||
if !serverCfg.Enabled {
|
||||
logger.DebugCF("mcp", "Skipping disabled server",
|
||||
map[string]any{
|
||||
"server": name,
|
||||
})
|
||||
continue
|
||||
}
|
||||
|
||||
enabledCount++
|
||||
wg.Add(1)
|
||||
go func(name string, serverCfg config.MCPServerConfig, workspace string) {
|
||||
defer wg.Done()
|
||||
|
||||
// Resolve relative envFile paths relative to workspace
|
||||
if serverCfg.EnvFile != "" && !filepath.IsAbs(serverCfg.EnvFile) {
|
||||
if workspace == "" {
|
||||
err := fmt.Errorf(
|
||||
"workspace path is empty while resolving relative envFile %q for server %s",
|
||||
serverCfg.EnvFile,
|
||||
name,
|
||||
)
|
||||
logger.ErrorCF("mcp", "Invalid MCP server configuration",
|
||||
map[string]any{
|
||||
"server": name,
|
||||
"env_file": serverCfg.EnvFile,
|
||||
"error": err.Error(),
|
||||
})
|
||||
errs <- err
|
||||
return
|
||||
}
|
||||
serverCfg.EnvFile = filepath.Join(workspace, serverCfg.EnvFile)
|
||||
}
|
||||
|
||||
if err := m.ConnectServer(ctx, name, serverCfg); err != nil {
|
||||
logger.ErrorCF("mcp", "Failed to connect to MCP server",
|
||||
map[string]any{
|
||||
"server": name,
|
||||
"error": err.Error(),
|
||||
})
|
||||
errs <- fmt.Errorf("failed to connect to server %s: %w", name, err)
|
||||
}
|
||||
}(name, serverCfg, workspacePath)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
close(errs)
|
||||
|
||||
// Collect errors
|
||||
var allErrors []error
|
||||
for err := range errs {
|
||||
allErrors = append(allErrors, err)
|
||||
}
|
||||
|
||||
connectedCount := len(m.GetServers())
|
||||
|
||||
// If all enabled servers failed to connect, return aggregated error
|
||||
if enabledCount > 0 && connectedCount == 0 {
|
||||
logger.ErrorCF("mcp", "All MCP servers failed to connect",
|
||||
map[string]any{
|
||||
"failed": len(allErrors),
|
||||
"total": enabledCount,
|
||||
})
|
||||
return errors.Join(allErrors...)
|
||||
}
|
||||
|
||||
if len(allErrors) > 0 {
|
||||
logger.WarnCF("mcp", "Some MCP servers failed to connect",
|
||||
map[string]any{
|
||||
"failed": len(allErrors),
|
||||
"connected": connectedCount,
|
||||
"total": enabledCount,
|
||||
})
|
||||
// Don't fail completely if some servers successfully connected
|
||||
}
|
||||
|
||||
logger.InfoCF("mcp", "MCP server initialization complete",
|
||||
map[string]any{
|
||||
"connected": connectedCount,
|
||||
"total": enabledCount,
|
||||
})
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ConnectServer connects to a single MCP server
|
||||
func (m *Manager) ConnectServer(
|
||||
ctx context.Context,
|
||||
name string,
|
||||
cfg config.MCPServerConfig,
|
||||
) error {
|
||||
logger.InfoCF("mcp", "Connecting to MCP server",
|
||||
map[string]any{
|
||||
"server": name,
|
||||
"command": cfg.Command,
|
||||
"args_count": len(cfg.Args),
|
||||
})
|
||||
|
||||
// Create client
|
||||
client := mcp.NewClient(&mcp.Implementation{
|
||||
Name: "picoclaw",
|
||||
Version: "1.0.0",
|
||||
}, nil)
|
||||
|
||||
// Create transport based on configuration
|
||||
// Auto-detect transport type if not explicitly specified
|
||||
var transport mcp.Transport
|
||||
transportType := cfg.Type
|
||||
|
||||
// Auto-detect: if URL is provided, use SSE; if command is provided, use stdio
|
||||
if transportType == "" {
|
||||
if cfg.URL != "" {
|
||||
transportType = "sse"
|
||||
} else if cfg.Command != "" {
|
||||
transportType = "stdio"
|
||||
} else {
|
||||
return fmt.Errorf("either URL or command must be provided")
|
||||
}
|
||||
}
|
||||
|
||||
switch transportType {
|
||||
case "sse", "http":
|
||||
if cfg.URL == "" {
|
||||
return fmt.Errorf("URL is required for SSE/HTTP transport")
|
||||
}
|
||||
logger.DebugCF("mcp", "Using SSE/HTTP transport",
|
||||
map[string]any{
|
||||
"server": name,
|
||||
"url": cfg.URL,
|
||||
})
|
||||
|
||||
sseTransport := &mcp.StreamableClientTransport{
|
||||
Endpoint: cfg.URL,
|
||||
}
|
||||
|
||||
// Add custom headers if provided
|
||||
if len(cfg.Headers) > 0 {
|
||||
// Create a custom HTTP client with header-injecting transport
|
||||
sseTransport.HTTPClient = &http.Client{
|
||||
Transport: &headerTransport{
|
||||
base: http.DefaultTransport,
|
||||
headers: cfg.Headers,
|
||||
},
|
||||
}
|
||||
logger.DebugCF("mcp", "Added custom HTTP headers",
|
||||
map[string]any{
|
||||
"server": name,
|
||||
"header_count": len(cfg.Headers),
|
||||
})
|
||||
}
|
||||
|
||||
transport = sseTransport
|
||||
case "stdio":
|
||||
if cfg.Command == "" {
|
||||
return fmt.Errorf("command is required for stdio transport")
|
||||
}
|
||||
logger.DebugCF("mcp", "Using stdio transport",
|
||||
map[string]any{
|
||||
"server": name,
|
||||
"command": cfg.Command,
|
||||
})
|
||||
// Create command with context
|
||||
cmd := exec.CommandContext(ctx, cfg.Command, cfg.Args...)
|
||||
|
||||
// Build environment variables with proper override semantics
|
||||
// Use a map to ensure config variables override file variables
|
||||
envMap := make(map[string]string)
|
||||
|
||||
// Start with parent process environment
|
||||
for _, e := range cmd.Environ() {
|
||||
if idx := strings.Index(e, "="); idx > 0 {
|
||||
envMap[e[:idx]] = e[idx+1:]
|
||||
}
|
||||
}
|
||||
|
||||
// Load environment variables from file if specified
|
||||
if cfg.EnvFile != "" {
|
||||
envVars, err := loadEnvFile(cfg.EnvFile)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to load env file %s: %w", cfg.EnvFile, err)
|
||||
}
|
||||
for k, v := range envVars {
|
||||
envMap[k] = v
|
||||
}
|
||||
logger.DebugCF("mcp", "Loaded environment variables from file",
|
||||
map[string]any{
|
||||
"server": name,
|
||||
"envFile": cfg.EnvFile,
|
||||
"var_count": len(envVars),
|
||||
})
|
||||
}
|
||||
|
||||
// Environment variables from config override those from file
|
||||
for k, v := range cfg.Env {
|
||||
envMap[k] = v
|
||||
}
|
||||
|
||||
// Convert map to slice
|
||||
env := make([]string, 0, len(envMap))
|
||||
for k, v := range envMap {
|
||||
env = append(env, fmt.Sprintf("%s=%s", k, v))
|
||||
}
|
||||
cmd.Env = env
|
||||
|
||||
transport = &mcp.CommandTransport{Command: cmd}
|
||||
default:
|
||||
return fmt.Errorf(
|
||||
"unsupported transport type: %s (supported: stdio, sse, http)",
|
||||
transportType,
|
||||
)
|
||||
}
|
||||
|
||||
// Connect to server
|
||||
session, err := client.Connect(ctx, transport, nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to connect: %w", err)
|
||||
}
|
||||
|
||||
// Get server info
|
||||
initResult := session.InitializeResult()
|
||||
logger.InfoCF("mcp", "Connected to MCP server",
|
||||
map[string]any{
|
||||
"server": name,
|
||||
"serverName": initResult.ServerInfo.Name,
|
||||
"serverVersion": initResult.ServerInfo.Version,
|
||||
"protocol": initResult.ProtocolVersion,
|
||||
})
|
||||
|
||||
// List available tools if supported
|
||||
var tools []*mcp.Tool
|
||||
if initResult.Capabilities.Tools != nil {
|
||||
for tool, err := range session.Tools(ctx, nil) {
|
||||
if err != nil {
|
||||
logger.WarnCF("mcp", "Error listing tool",
|
||||
map[string]any{
|
||||
"server": name,
|
||||
"error": err.Error(),
|
||||
})
|
||||
continue
|
||||
}
|
||||
tools = append(tools, tool)
|
||||
}
|
||||
|
||||
logger.InfoCF("mcp", "Listed tools from MCP server",
|
||||
map[string]any{
|
||||
"server": name,
|
||||
"toolCount": len(tools),
|
||||
})
|
||||
}
|
||||
|
||||
// Store connection
|
||||
m.mu.Lock()
|
||||
m.servers[name] = &ServerConnection{
|
||||
Name: name,
|
||||
Client: client,
|
||||
Session: session,
|
||||
Tools: tools,
|
||||
}
|
||||
m.mu.Unlock()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetServers returns all connected servers
|
||||
func (m *Manager) GetServers() map[string]*ServerConnection {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
|
||||
result := make(map[string]*ServerConnection, len(m.servers))
|
||||
for k, v := range m.servers {
|
||||
result[k] = v
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// GetServer returns a specific server connection
|
||||
func (m *Manager) GetServer(name string) (*ServerConnection, bool) {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
|
||||
conn, ok := m.servers[name]
|
||||
return conn, ok
|
||||
}
|
||||
|
||||
// CallTool calls a tool on a specific server
|
||||
func (m *Manager) CallTool(
|
||||
ctx context.Context,
|
||||
serverName, toolName string,
|
||||
arguments map[string]any,
|
||||
) (*mcp.CallToolResult, error) {
|
||||
// Check if closed before acquiring lock (fast path)
|
||||
if m.closed.Load() {
|
||||
return nil, fmt.Errorf("manager is closed")
|
||||
}
|
||||
|
||||
m.mu.RLock()
|
||||
// Double-check after acquiring lock to prevent TOCTOU race
|
||||
if m.closed.Load() {
|
||||
m.mu.RUnlock()
|
||||
return nil, fmt.Errorf("manager is closed")
|
||||
}
|
||||
conn, ok := m.servers[serverName]
|
||||
if ok {
|
||||
m.wg.Add(1) // Add to WaitGroup while holding the lock
|
||||
}
|
||||
m.mu.RUnlock()
|
||||
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("server %s not found", serverName)
|
||||
}
|
||||
defer m.wg.Done()
|
||||
|
||||
params := &mcp.CallToolParams{
|
||||
Name: toolName,
|
||||
Arguments: arguments,
|
||||
}
|
||||
|
||||
result, err := conn.Session.CallTool(ctx, params)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to call tool: %w", err)
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// Close closes all server connections
|
||||
func (m *Manager) Close() error {
|
||||
// Use Swap to atomically set closed=true and get the previous value
|
||||
// This prevents TOCTOU race with CallTool's closed check
|
||||
if m.closed.Swap(true) {
|
||||
return nil // already closed
|
||||
}
|
||||
|
||||
// Wait for all in-flight CallTool calls to finish before closing sessions
|
||||
// After closed=true is set, no new CallTool can start (they check closed first)
|
||||
m.wg.Wait()
|
||||
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
logger.InfoCF("mcp", "Closing all MCP server connections",
|
||||
map[string]any{
|
||||
"count": len(m.servers),
|
||||
})
|
||||
|
||||
var errs []error
|
||||
for name, conn := range m.servers {
|
||||
if err := conn.Session.Close(); err != nil {
|
||||
logger.ErrorCF("mcp", "Failed to close server connection",
|
||||
map[string]any{
|
||||
"server": name,
|
||||
"error": err.Error(),
|
||||
})
|
||||
errs = append(errs, fmt.Errorf("server %s: %w", name, err))
|
||||
}
|
||||
}
|
||||
|
||||
m.servers = make(map[string]*ServerConnection)
|
||||
|
||||
if len(errs) > 0 {
|
||||
return fmt.Errorf("failed to close %d server(s): %w", len(errs), errors.Join(errs...))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetAllTools returns all tools from all connected servers
|
||||
func (m *Manager) GetAllTools() map[string][]*mcp.Tool {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
|
||||
result := make(map[string][]*mcp.Tool)
|
||||
for name, conn := range m.servers {
|
||||
if len(conn.Tools) > 0 {
|
||||
result[name] = conn.Tools
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
@@ -0,0 +1,298 @@
|
||||
package mcp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
sdkmcp "github.com/modelcontextprotocol/go-sdk/mcp"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
)
|
||||
|
||||
func TestLoadEnvFile(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
content string
|
||||
expected map[string]string
|
||||
expectErr bool
|
||||
}{
|
||||
{
|
||||
name: "basic env file",
|
||||
content: `API_KEY=secret123
|
||||
DATABASE_URL=postgres://localhost/db
|
||||
PORT=8080`,
|
||||
expected: map[string]string{
|
||||
"API_KEY": "secret123",
|
||||
"DATABASE_URL": "postgres://localhost/db",
|
||||
"PORT": "8080",
|
||||
},
|
||||
expectErr: false,
|
||||
},
|
||||
{
|
||||
name: "with comments and empty lines",
|
||||
content: `# This is a comment
|
||||
API_KEY=secret123
|
||||
|
||||
# Another comment
|
||||
DATABASE_URL=postgres://localhost/db
|
||||
|
||||
PORT=8080`,
|
||||
expected: map[string]string{
|
||||
"API_KEY": "secret123",
|
||||
"DATABASE_URL": "postgres://localhost/db",
|
||||
"PORT": "8080",
|
||||
},
|
||||
expectErr: false,
|
||||
},
|
||||
{
|
||||
name: "with quoted values",
|
||||
content: `API_KEY="secret with spaces"
|
||||
NAME='single quoted'
|
||||
PLAIN=no-quotes`,
|
||||
expected: map[string]string{
|
||||
"API_KEY": "secret with spaces",
|
||||
"NAME": "single quoted",
|
||||
"PLAIN": "no-quotes",
|
||||
},
|
||||
expectErr: false,
|
||||
},
|
||||
{
|
||||
name: "with spaces around equals",
|
||||
content: `API_KEY = secret123
|
||||
DATABASE_URL= postgres://localhost/db
|
||||
PORT =8080`,
|
||||
expected: map[string]string{
|
||||
"API_KEY": "secret123",
|
||||
"DATABASE_URL": "postgres://localhost/db",
|
||||
"PORT": "8080",
|
||||
},
|
||||
expectErr: false,
|
||||
},
|
||||
{
|
||||
name: "invalid format - no equals",
|
||||
content: `INVALID_LINE`,
|
||||
expectErr: true,
|
||||
},
|
||||
{
|
||||
name: "empty file",
|
||||
content: ``,
|
||||
expected: map[string]string{},
|
||||
expectErr: false,
|
||||
},
|
||||
{
|
||||
name: "only comments",
|
||||
content: `# Comment 1
|
||||
# Comment 2`,
|
||||
expected: map[string]string{},
|
||||
expectErr: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
envFile := filepath.Join(tmpDir, ".env")
|
||||
|
||||
if err := os.WriteFile(envFile, []byte(tt.content), 0o644); err != nil {
|
||||
t.Fatalf("Failed to create test file: %v", err)
|
||||
}
|
||||
|
||||
result, err := loadEnvFile(envFile)
|
||||
|
||||
if tt.expectErr {
|
||||
if err == nil {
|
||||
t.Errorf("Expected error but got none")
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
if len(result) != len(tt.expected) {
|
||||
t.Errorf("Expected %d variables, got %d", len(tt.expected), len(result))
|
||||
}
|
||||
|
||||
for key, expectedValue := range tt.expected {
|
||||
if actualValue, ok := result[key]; !ok {
|
||||
t.Errorf("Expected key %s not found", key)
|
||||
} else if actualValue != expectedValue {
|
||||
t.Errorf("For key %s: expected %q, got %q", key, expectedValue, actualValue)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadEnvFileNotFound(t *testing.T) {
|
||||
_, err := loadEnvFile("/nonexistent/file.env")
|
||||
if err == nil {
|
||||
t.Error("Expected error for nonexistent file")
|
||||
}
|
||||
}
|
||||
|
||||
func TestEnvFilePriority(t *testing.T) {
|
||||
// Create a temporary .env file
|
||||
tmpDir := t.TempDir()
|
||||
envFile := filepath.Join(tmpDir, ".env")
|
||||
|
||||
envContent := `API_KEY=from_file
|
||||
DATABASE_URL=from_file
|
||||
SHARED_VAR=from_file`
|
||||
|
||||
if err := os.WriteFile(envFile, []byte(envContent), 0o644); err != nil {
|
||||
t.Fatalf("Failed to create .env file: %v", err)
|
||||
}
|
||||
|
||||
// Load envFile
|
||||
envVars, err := loadEnvFile(envFile)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to load env file: %v", err)
|
||||
}
|
||||
|
||||
// Verify envFile variables
|
||||
if envVars["API_KEY"] != "from_file" {
|
||||
t.Errorf("Expected API_KEY=from_file, got %s", envVars["API_KEY"])
|
||||
}
|
||||
|
||||
// Simulate config.Env overriding envFile
|
||||
configEnv := map[string]string{
|
||||
"SHARED_VAR": "from_config",
|
||||
"NEW_VAR": "from_config",
|
||||
}
|
||||
|
||||
// Merge: envFile first, then config overrides
|
||||
merged := make(map[string]string)
|
||||
for k, v := range envVars {
|
||||
merged[k] = v
|
||||
}
|
||||
for k, v := range configEnv {
|
||||
merged[k] = v
|
||||
}
|
||||
|
||||
// Verify priority: config.Env should override envFile
|
||||
if merged["SHARED_VAR"] != "from_config" {
|
||||
t.Errorf(
|
||||
"Expected SHARED_VAR=from_config (config should override file), got %s",
|
||||
merged["SHARED_VAR"],
|
||||
)
|
||||
}
|
||||
if merged["API_KEY"] != "from_file" {
|
||||
t.Errorf("Expected API_KEY=from_file, got %s", merged["API_KEY"])
|
||||
}
|
||||
if merged["NEW_VAR"] != "from_config" {
|
||||
t.Errorf("Expected NEW_VAR=from_config, got %s", merged["NEW_VAR"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadFromMCPConfig_EmptyWorkspaceWithRelativeEnvFile(t *testing.T) {
|
||||
mgr := NewManager()
|
||||
|
||||
mcpCfg := config.MCPConfig{
|
||||
Enabled: true,
|
||||
Servers: map[string]config.MCPServerConfig{
|
||||
"test-server": {
|
||||
Enabled: true,
|
||||
Command: "echo",
|
||||
Args: []string{"ok"},
|
||||
EnvFile: ".env",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
err := mgr.LoadFromMCPConfig(context.Background(), mcpCfg, "")
|
||||
if err == nil {
|
||||
t.Fatal("expected error for relative env_file with empty workspace path, got nil")
|
||||
}
|
||||
|
||||
if !strings.Contains(err.Error(), "workspace path is empty") {
|
||||
t.Fatalf("expected workspace path validation error, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewManager_InitialState(t *testing.T) {
|
||||
mgr := NewManager()
|
||||
if mgr == nil {
|
||||
t.Fatal("expected manager instance, got nil")
|
||||
}
|
||||
if len(mgr.GetServers()) != 0 {
|
||||
t.Fatalf("expected no servers on new manager, got %d", len(mgr.GetServers()))
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadFromMCPConfig_DisabledOrEmptyServers(t *testing.T) {
|
||||
mgr := NewManager()
|
||||
|
||||
err := mgr.LoadFromMCPConfig(context.Background(), config.MCPConfig{Enabled: false}, "/tmp")
|
||||
if err != nil {
|
||||
t.Fatalf("expected nil error when MCP disabled, got: %v", err)
|
||||
}
|
||||
|
||||
err = mgr.LoadFromMCPConfig(context.Background(), config.MCPConfig{Enabled: true}, "/tmp")
|
||||
if err != nil {
|
||||
t.Fatalf("expected nil error when no servers configured, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetServers_ReturnsCopy(t *testing.T) {
|
||||
mgr := NewManager()
|
||||
mgr.servers["s1"] = &ServerConnection{Name: "s1"}
|
||||
|
||||
servers := mgr.GetServers()
|
||||
delete(servers, "s1")
|
||||
|
||||
if _, ok := mgr.GetServer("s1"); !ok {
|
||||
t.Fatal("expected internal manager state to remain unchanged")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetAllTools_FiltersEmptyTools(t *testing.T) {
|
||||
mgr := NewManager()
|
||||
mgr.servers["empty"] = &ServerConnection{Name: "empty", Tools: nil}
|
||||
mgr.servers["with-tools"] = &ServerConnection{Name: "with-tools", Tools: []*sdkmcp.Tool{{}}}
|
||||
|
||||
all := mgr.GetAllTools()
|
||||
if _, ok := all["empty"]; ok {
|
||||
t.Fatal("expected server without tools to be excluded")
|
||||
}
|
||||
if _, ok := all["with-tools"]; !ok {
|
||||
t.Fatal("expected server with tools to be included")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCallTool_ErrorsForClosedOrMissingServer(t *testing.T) {
|
||||
t.Run("manager closed", func(t *testing.T) {
|
||||
mgr := NewManager()
|
||||
mgr.closed.Store(true)
|
||||
|
||||
_, err := mgr.CallTool(context.Background(), "s1", "tool", nil)
|
||||
if err == nil || !strings.Contains(err.Error(), "manager is closed") {
|
||||
t.Fatalf("expected manager closed error, got: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("server missing", func(t *testing.T) {
|
||||
mgr := NewManager()
|
||||
|
||||
_, err := mgr.CallTool(context.Background(), "missing", "tool", nil)
|
||||
if err == nil || !strings.Contains(err.Error(), "not found") {
|
||||
t.Fatalf("expected server not found error, got: %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestClose_IdempotentOnEmptyManager(t *testing.T) {
|
||||
mgr := NewManager()
|
||||
|
||||
if err := mgr.Close(); err != nil {
|
||||
t.Fatalf("first close should succeed, got: %v", err)
|
||||
}
|
||||
if err := mgr.Close(); err != nil {
|
||||
t.Fatalf("second close should be idempotent, got: %v", err)
|
||||
}
|
||||
}
|
||||
+11
-11
@@ -49,7 +49,7 @@ func TestReleaseAll(t *testing.T) {
|
||||
|
||||
paths := make([]string, 3)
|
||||
refs := make([]string, 3)
|
||||
for i := 0; i < 3; i++ {
|
||||
for i := range 3 {
|
||||
paths[i] = createTempFile(t, dir, strings.Repeat("a", i+1)+".jpg")
|
||||
var err error
|
||||
refs[i], err = store.Store(paths[i], MediaMeta{Source: "test"}, "scope1")
|
||||
@@ -228,12 +228,12 @@ func TestConcurrentSafety(t *testing.T) {
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(goroutines)
|
||||
|
||||
for g := 0; g < goroutines; g++ {
|
||||
for g := range goroutines {
|
||||
go func(gIdx int) {
|
||||
defer wg.Done()
|
||||
scope := strings.Repeat("s", gIdx+1)
|
||||
|
||||
for i := 0; i < filesPerGoroutine; i++ {
|
||||
for i := range filesPerGoroutine {
|
||||
path := createTempFile(t, dir, strings.Repeat("f", gIdx*filesPerGoroutine+i+1)+".tmp")
|
||||
ref, err := store.Store(path, MediaMeta{Source: "test"}, scope)
|
||||
if err != nil {
|
||||
@@ -448,11 +448,11 @@ func TestConcurrentCleanupSafety(t *testing.T) {
|
||||
wg.Add(workers * 4)
|
||||
|
||||
// Store workers
|
||||
for w := 0; w < workers; w++ {
|
||||
for w := range workers {
|
||||
go func(wIdx int) {
|
||||
defer wg.Done()
|
||||
scope := fmt.Sprintf("scope-%d", wIdx)
|
||||
for i := 0; i < ops; i++ {
|
||||
for i := range ops {
|
||||
p := createTempFile(t, dir, fmt.Sprintf("w%d-f%d.tmp", wIdx, i))
|
||||
store.Store(p, MediaMeta{Source: "test"}, scope)
|
||||
}
|
||||
@@ -460,30 +460,30 @@ func TestConcurrentCleanupSafety(t *testing.T) {
|
||||
}
|
||||
|
||||
// Resolve workers
|
||||
for w := 0; w < workers; w++ {
|
||||
for range workers {
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
for i := 0; i < ops; i++ {
|
||||
for range ops {
|
||||
store.Resolve("media://nonexistent")
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// ReleaseAll workers
|
||||
for w := 0; w < workers; w++ {
|
||||
for w := range workers {
|
||||
go func(wIdx int) {
|
||||
defer wg.Done()
|
||||
for i := 0; i < ops; i++ {
|
||||
for range ops {
|
||||
store.ReleaseAll(fmt.Sprintf("scope-%d", wIdx))
|
||||
}
|
||||
}(w)
|
||||
}
|
||||
|
||||
// CleanExpired workers
|
||||
for w := 0; w < workers; w++ {
|
||||
for range workers {
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
for i := 0; i < ops; i++ {
|
||||
for range ops {
|
||||
store.CleanExpired()
|
||||
}
|
||||
}()
|
||||
|
||||
@@ -118,64 +118,55 @@ func TestPlanWorkspaceMigration(t *testing.T) {
|
||||
assert.GreaterOrEqual(t, len(actions), 1)
|
||||
}
|
||||
|
||||
func TestPlanWorkspaceMigrationWithExistingDestination(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
srcWorkspace := filepath.Join(tmpDir, "src", "workspace")
|
||||
dstWorkspace := filepath.Join(tmpDir, "dst", "workspace")
|
||||
func TestPlanWorkspaceMigrationExistingFile(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
force bool
|
||||
wantActionType ActionType
|
||||
}{
|
||||
{
|
||||
name: "backup when not forced",
|
||||
force: false,
|
||||
wantActionType: ActionBackup,
|
||||
},
|
||||
{
|
||||
name: "copy when forced",
|
||||
force: true,
|
||||
wantActionType: ActionCopy,
|
||||
},
|
||||
}
|
||||
|
||||
err := os.MkdirAll(srcWorkspace, 0o755)
|
||||
require.NoError(t, err)
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
srcWorkspace := filepath.Join(tmpDir, "src", "workspace")
|
||||
dstWorkspace := filepath.Join(tmpDir, "dst", "workspace")
|
||||
|
||||
err = os.MkdirAll(dstWorkspace, 0o755)
|
||||
require.NoError(t, err)
|
||||
err := os.MkdirAll(srcWorkspace, 0o755)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = os.WriteFile(filepath.Join(srcWorkspace, "file1.txt"), []byte("source"), 0o644)
|
||||
require.NoError(t, err)
|
||||
err = os.MkdirAll(dstWorkspace, 0o755)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = os.WriteFile(filepath.Join(dstWorkspace, "file1.txt"), []byte("existing"), 0o644)
|
||||
require.NoError(t, err)
|
||||
err = os.WriteFile(filepath.Join(srcWorkspace, "file1.txt"), []byte("source"), 0o644)
|
||||
require.NoError(t, err)
|
||||
|
||||
actions, err := PlanWorkspaceMigration(
|
||||
srcWorkspace,
|
||||
dstWorkspace,
|
||||
[]string{"file1.txt"},
|
||||
[]string{},
|
||||
false,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
err = os.WriteFile(filepath.Join(dstWorkspace, "file1.txt"), []byte("existing"), 0o644)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.GreaterOrEqual(t, len(actions), 1)
|
||||
assert.Equal(t, ActionBackup, actions[0].Type)
|
||||
}
|
||||
actions, err := PlanWorkspaceMigration(
|
||||
srcWorkspace,
|
||||
dstWorkspace,
|
||||
[]string{"file1.txt"},
|
||||
[]string{},
|
||||
tt.force,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
func TestPlanWorkspaceMigrationForce(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
srcWorkspace := filepath.Join(tmpDir, "src", "workspace")
|
||||
dstWorkspace := filepath.Join(tmpDir, "dst", "workspace")
|
||||
|
||||
err := os.MkdirAll(srcWorkspace, 0o755)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = os.MkdirAll(dstWorkspace, 0o755)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = os.WriteFile(filepath.Join(srcWorkspace, "file1.txt"), []byte("source"), 0o644)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = os.WriteFile(filepath.Join(dstWorkspace, "file1.txt"), []byte("existing"), 0o644)
|
||||
require.NoError(t, err)
|
||||
|
||||
actions, err := PlanWorkspaceMigration(
|
||||
srcWorkspace,
|
||||
dstWorkspace,
|
||||
[]string{"file1.txt"},
|
||||
[]string{},
|
||||
true,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.GreaterOrEqual(t, len(actions), 1)
|
||||
assert.Equal(t, ActionCopy, actions[0].Type)
|
||||
require.GreaterOrEqual(t, len(actions), 1)
|
||||
assert.Equal(t, tt.wantActionType, actions[0].Type)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestPlanWorkspaceMigrationNonExistentSource(t *testing.T) {
|
||||
|
||||
@@ -212,14 +212,14 @@ func translateTools(tools []ToolDefinition) []anthropic.ToolUnionParam {
|
||||
}
|
||||
|
||||
func parseResponse(resp *anthropic.Message) *LLMResponse {
|
||||
var content string
|
||||
var content strings.Builder
|
||||
var toolCalls []ToolCall
|
||||
|
||||
for _, block := range resp.Content {
|
||||
switch block.Type {
|
||||
case "text":
|
||||
tb := block.AsText()
|
||||
content += tb.Text
|
||||
content.WriteString(tb.Text)
|
||||
case "tool_use":
|
||||
tu := block.AsToolUse()
|
||||
var args map[string]any
|
||||
@@ -246,7 +246,7 @@ func parseResponse(resp *anthropic.Message) *LLMResponse {
|
||||
}
|
||||
|
||||
return &LLMResponse{
|
||||
Content: content,
|
||||
Content: content.String(),
|
||||
ToolCalls: toolCalls,
|
||||
FinishReason: finishReason,
|
||||
Usage: &UsageInfo{
|
||||
@@ -264,8 +264,8 @@ func normalizeBaseURL(apiBase string) string {
|
||||
}
|
||||
|
||||
base = strings.TrimRight(base, "/")
|
||||
if strings.HasSuffix(base, "/v1") {
|
||||
base = strings.TrimSuffix(base, "/v1")
|
||||
if before, ok := strings.CutSuffix(base, "/v1"); ok {
|
||||
base = before
|
||||
}
|
||||
if base == "" {
|
||||
return defaultBaseURL
|
||||
|
||||
@@ -100,44 +100,12 @@ func (p *ClaudeCliProvider) buildSystemPrompt(messages []Message, tools []ToolDe
|
||||
}
|
||||
|
||||
if len(tools) > 0 {
|
||||
parts = append(parts, p.buildToolsPrompt(tools))
|
||||
parts = append(parts, buildCLIToolsPrompt(tools))
|
||||
}
|
||||
|
||||
return strings.Join(parts, "\n\n")
|
||||
}
|
||||
|
||||
// buildToolsPrompt creates the tool definitions section for the system prompt.
|
||||
func (p *ClaudeCliProvider) buildToolsPrompt(tools []ToolDefinition) string {
|
||||
var sb strings.Builder
|
||||
|
||||
sb.WriteString("## Available Tools\n\n")
|
||||
sb.WriteString("When you need to use a tool, respond with ONLY a JSON object:\n\n")
|
||||
sb.WriteString("```json\n")
|
||||
sb.WriteString(
|
||||
`{"tool_calls":[{"id":"call_xxx","type":"function","function":{"name":"tool_name","arguments":"{...}"}}]}`,
|
||||
)
|
||||
sb.WriteString("\n```\n\n")
|
||||
sb.WriteString("CRITICAL: The 'arguments' field MUST be a JSON-encoded STRING.\n\n")
|
||||
sb.WriteString("### Tool Definitions:\n\n")
|
||||
|
||||
for _, tool := range tools {
|
||||
if tool.Type != "function" {
|
||||
continue
|
||||
}
|
||||
sb.WriteString(fmt.Sprintf("#### %s\n", tool.Function.Name))
|
||||
if tool.Function.Description != "" {
|
||||
sb.WriteString(fmt.Sprintf("Description: %s\n", tool.Function.Description))
|
||||
}
|
||||
if len(tool.Function.Parameters) > 0 {
|
||||
paramsJSON, _ := json.Marshal(tool.Function.Parameters)
|
||||
sb.WriteString(fmt.Sprintf("Parameters:\n```json\n%s\n```\n", string(paramsJSON)))
|
||||
}
|
||||
sb.WriteString("\n")
|
||||
}
|
||||
|
||||
return sb.String()
|
||||
}
|
||||
|
||||
// parseClaudeCliResponse parses the JSON output from the claude CLI.
|
||||
func (p *ClaudeCliProvider) parseClaudeCliResponse(output string) (*LLMResponse, error) {
|
||||
var resp claudeCliJSONResponse
|
||||
|
||||
@@ -660,12 +660,11 @@ func TestBuildSystemPrompt_ToolsOnlyNoSystem(t *testing.T) {
|
||||
// --- buildToolsPrompt tests ---
|
||||
|
||||
func TestBuildToolsPrompt_SkipsNonFunction(t *testing.T) {
|
||||
p := NewClaudeCliProvider("/workspace")
|
||||
tools := []ToolDefinition{
|
||||
{Type: "other", Function: ToolFunctionDefinition{Name: "skip_me"}},
|
||||
{Type: "function", Function: ToolFunctionDefinition{Name: "include_me", Description: "Included"}},
|
||||
}
|
||||
got := p.buildToolsPrompt(tools)
|
||||
got := buildCLIToolsPrompt(tools)
|
||||
if strings.Contains(got, "skip_me") {
|
||||
t.Error("buildToolsPrompt() should skip non-function tools")
|
||||
}
|
||||
@@ -675,11 +674,10 @@ func TestBuildToolsPrompt_SkipsNonFunction(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestBuildToolsPrompt_NoDescription(t *testing.T) {
|
||||
p := NewClaudeCliProvider("/workspace")
|
||||
tools := []ToolDefinition{
|
||||
{Type: "function", Function: ToolFunctionDefinition{Name: "bare_tool"}},
|
||||
}
|
||||
got := p.buildToolsPrompt(tools)
|
||||
got := buildCLIToolsPrompt(tools)
|
||||
if !strings.Contains(got, "bare_tool") {
|
||||
t.Error("should include tool name")
|
||||
}
|
||||
@@ -689,14 +687,13 @@ func TestBuildToolsPrompt_NoDescription(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestBuildToolsPrompt_NoParameters(t *testing.T) {
|
||||
p := NewClaudeCliProvider("/workspace")
|
||||
tools := []ToolDefinition{
|
||||
{Type: "function", Function: ToolFunctionDefinition{
|
||||
Name: "no_params_tool",
|
||||
Description: "A tool with no parameters",
|
||||
}},
|
||||
}
|
||||
got := p.buildToolsPrompt(tools)
|
||||
got := buildCLIToolsPrompt(tools)
|
||||
if strings.Contains(got, "Parameters:") {
|
||||
t.Error("should not include Parameters: section when nil")
|
||||
}
|
||||
|
||||
@@ -115,7 +115,7 @@ func (p *CodexCliProvider) buildPrompt(messages []Message, tools []ToolDefinitio
|
||||
}
|
||||
|
||||
if len(tools) > 0 {
|
||||
sb.WriteString(p.buildToolsPrompt(tools))
|
||||
sb.WriteString(buildCLIToolsPrompt(tools))
|
||||
sb.WriteString("\n\n")
|
||||
}
|
||||
|
||||
@@ -128,38 +128,6 @@ func (p *CodexCliProvider) buildPrompt(messages []Message, tools []ToolDefinitio
|
||||
return sb.String()
|
||||
}
|
||||
|
||||
// buildToolsPrompt creates a tool definitions section for the prompt.
|
||||
func (p *CodexCliProvider) buildToolsPrompt(tools []ToolDefinition) string {
|
||||
var sb strings.Builder
|
||||
|
||||
sb.WriteString("## Available Tools\n\n")
|
||||
sb.WriteString("When you need to use a tool, respond with ONLY a JSON object:\n\n")
|
||||
sb.WriteString("```json\n")
|
||||
sb.WriteString(
|
||||
`{"tool_calls":[{"id":"call_xxx","type":"function","function":{"name":"tool_name","arguments":"{...}"}}]}`,
|
||||
)
|
||||
sb.WriteString("\n```\n\n")
|
||||
sb.WriteString("CRITICAL: The 'arguments' field MUST be a JSON-encoded STRING.\n\n")
|
||||
sb.WriteString("### Tool Definitions:\n\n")
|
||||
|
||||
for _, tool := range tools {
|
||||
if tool.Type != "function" {
|
||||
continue
|
||||
}
|
||||
sb.WriteString(fmt.Sprintf("#### %s\n", tool.Function.Name))
|
||||
if tool.Function.Description != "" {
|
||||
sb.WriteString(fmt.Sprintf("Description: %s\n", tool.Function.Description))
|
||||
}
|
||||
if len(tool.Function.Parameters) > 0 {
|
||||
paramsJSON, _ := json.Marshal(tool.Function.Parameters)
|
||||
sb.WriteString(fmt.Sprintf("Parameters:\n```json\n%s\n```\n", string(paramsJSON)))
|
||||
}
|
||||
sb.WriteString("\n")
|
||||
}
|
||||
|
||||
return sb.String()
|
||||
}
|
||||
|
||||
// codexEvent represents a single JSONL event from `codex exec --json`.
|
||||
type codexEvent struct {
|
||||
Type string `json:"type"`
|
||||
|
||||
@@ -163,8 +163,8 @@ func resolveCodexModel(model string) (string, string) {
|
||||
return codexDefaultModel, "empty model"
|
||||
}
|
||||
|
||||
if strings.HasPrefix(m, "openai/") {
|
||||
m = strings.TrimPrefix(m, "openai/")
|
||||
if after, ok := strings.CutPrefix(m, "openai/"); ok {
|
||||
m = after
|
||||
} else if strings.Contains(m, "/") {
|
||||
return codexDefaultModel, "non-openai model namespace"
|
||||
}
|
||||
|
||||
@@ -138,7 +138,7 @@ func TestCooldown_FailureWindowReset(t *testing.T) {
|
||||
ct, current := newTestTracker(now)
|
||||
|
||||
// 4 errors → 1h cooldown
|
||||
for i := 0; i < 4; i++ {
|
||||
for range 4 {
|
||||
ct.MarkFailure("openai", FailoverRateLimit)
|
||||
*current = current.Add(2 * time.Second) // small advance between errors
|
||||
}
|
||||
@@ -230,7 +230,7 @@ func TestCooldown_ConcurrentAccess(t *testing.T) {
|
||||
ct := NewCooldownTracker()
|
||||
var wg sync.WaitGroup
|
||||
|
||||
for i := 0; i < 100; i++ {
|
||||
for range 100 {
|
||||
wg.Add(3)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
|
||||
@@ -6,6 +6,13 @@ import (
|
||||
"strings"
|
||||
)
|
||||
|
||||
// Common patterns in Go HTTP error messages
|
||||
var httpStatusPatterns = []*regexp.Regexp{
|
||||
regexp.MustCompile(`status[:\s]+(\d{3})`),
|
||||
regexp.MustCompile(`http[/\s]+\d*\.?\d*\s+(\d{3})`),
|
||||
regexp.MustCompile(`\b([3-5]\d{2})\b`),
|
||||
}
|
||||
|
||||
// errorPattern defines a single pattern (string or regex) for error classification.
|
||||
type errorPattern struct {
|
||||
substring string
|
||||
@@ -198,20 +205,13 @@ func classifyByMessage(msg string) FailoverReason {
|
||||
}
|
||||
|
||||
// extractHTTPStatus extracts an HTTP status code from an error message.
|
||||
// Looks for patterns like "status: 429", "status 429", "HTTP 429", or standalone "429".
|
||||
// Looks for patterns like "status: 429", "status 429", "http/1.1 429", "http 429", or standalone "429".
|
||||
func extractHTTPStatus(msg string) int {
|
||||
// Common patterns in Go HTTP error messages
|
||||
patterns := []*regexp.Regexp{
|
||||
regexp.MustCompile(`status[:\s]+(\d{3})`),
|
||||
regexp.MustCompile(`HTTP[/\s]+\d*\.?\d*\s+(\d{3})`),
|
||||
}
|
||||
|
||||
for _, p := range patterns {
|
||||
for _, p := range httpStatusPatterns {
|
||||
if m := p.FindStringSubmatch(msg); len(m) > 1 {
|
||||
return parseDigits(m[1])
|
||||
}
|
||||
}
|
||||
|
||||
return 0
|
||||
}
|
||||
|
||||
|
||||
@@ -305,7 +305,8 @@ func TestExtractHTTPStatus(t *testing.T) {
|
||||
}{
|
||||
{"status: 429 rate limited", 429},
|
||||
{"status 401 unauthorized", 401},
|
||||
{"HTTP/1.1 502 Bad Gateway", 502},
|
||||
{"http/1.1 502 bad gateway", 502},
|
||||
{"error 429", 429},
|
||||
{"no status code here", 0},
|
||||
{"random number 12345", 0},
|
||||
}
|
||||
|
||||
@@ -102,6 +102,15 @@ func resolveProviderSelection(cfg *config.Config) (providerSelection, error) {
|
||||
sel.apiBase = "https://openrouter.ai/api/v1"
|
||||
}
|
||||
}
|
||||
case "litellm":
|
||||
if cfg.Providers.LiteLLM.APIKey != "" || cfg.Providers.LiteLLM.APIBase != "" {
|
||||
sel.apiKey = cfg.Providers.LiteLLM.APIKey
|
||||
sel.apiBase = cfg.Providers.LiteLLM.APIBase
|
||||
sel.proxy = cfg.Providers.LiteLLM.Proxy
|
||||
if sel.apiBase == "" {
|
||||
sel.apiBase = "http://localhost:4000/v1"
|
||||
}
|
||||
}
|
||||
case "zhipu", "glm":
|
||||
if cfg.Providers.Zhipu.APIKey != "" {
|
||||
sel.apiKey = cfg.Providers.Zhipu.APIKey
|
||||
|
||||
@@ -53,7 +53,7 @@ func ExtractProtocol(model string) (protocol, modelID string) {
|
||||
|
||||
// CreateProviderFromConfig creates a provider based on the ModelConfig.
|
||||
// It uses the protocol prefix in the Model field to determine which provider to create.
|
||||
// Supported protocols: openai, anthropic, antigravity, claude-cli, codex-cli, github-copilot
|
||||
// Supported protocols: openai, litellm, anthropic, antigravity, claude-cli, codex-cli, github-copilot
|
||||
// Returns the provider, the model ID (without protocol prefix), and any error.
|
||||
func CreateProviderFromConfig(cfg *config.ModelConfig) (LLMProvider, string, error) {
|
||||
if cfg == nil {
|
||||
@@ -92,7 +92,7 @@ func CreateProviderFromConfig(cfg *config.ModelConfig) (LLMProvider, string, err
|
||||
cfg.RequestTimeout,
|
||||
), modelID, nil
|
||||
|
||||
case "openrouter", "groq", "zhipu", "gemini", "nvidia",
|
||||
case "litellm", "openrouter", "groq", "zhipu", "gemini", "nvidia",
|
||||
"ollama", "moonshot", "shengsuanyun", "deepseek", "cerebras",
|
||||
"volcengine", "vllm", "qwen", "mistral":
|
||||
// All other OpenAI-compatible HTTP providers
|
||||
@@ -180,6 +180,8 @@ func getDefaultAPIBase(protocol string) string {
|
||||
return "https://api.openai.com/v1"
|
||||
case "openrouter":
|
||||
return "https://openrouter.ai/api/v1"
|
||||
case "litellm":
|
||||
return "http://localhost:4000/v1"
|
||||
case "groq":
|
||||
return "https://api.groq.com/openai/v1"
|
||||
case "zhipu":
|
||||
|
||||
@@ -135,6 +135,32 @@ func TestCreateProviderFromConfig_DefaultAPIBase(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetDefaultAPIBase_LiteLLM(t *testing.T) {
|
||||
if got := getDefaultAPIBase("litellm"); got != "http://localhost:4000/v1" {
|
||||
t.Fatalf("getDefaultAPIBase(%q) = %q, want %q", "litellm", got, "http://localhost:4000/v1")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateProviderFromConfig_LiteLLM(t *testing.T) {
|
||||
cfg := &config.ModelConfig{
|
||||
ModelName: "test-litellm",
|
||||
Model: "litellm/my-proxy-alias",
|
||||
APIKey: "test-key",
|
||||
APIBase: "http://localhost:4000/v1",
|
||||
}
|
||||
|
||||
provider, modelID, err := CreateProviderFromConfig(cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("CreateProviderFromConfig() error = %v", err)
|
||||
}
|
||||
if provider == nil {
|
||||
t.Fatal("CreateProviderFromConfig() returned nil provider")
|
||||
}
|
||||
if modelID != "my-proxy-alias" {
|
||||
t.Errorf("modelID = %q, want %q", modelID, "my-proxy-alias")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateProviderFromConfig_Anthropic(t *testing.T) {
|
||||
cfg := &config.ModelConfig{
|
||||
ModelName: "test-anthropic",
|
||||
|
||||
@@ -17,6 +17,27 @@ func TestResolveProviderSelection(t *testing.T) {
|
||||
wantProxy string
|
||||
wantErrSubstr string
|
||||
}{
|
||||
{
|
||||
name: "explicit litellm provider uses configured base",
|
||||
setup: func(cfg *config.Config) {
|
||||
cfg.Agents.Defaults.Provider = "litellm"
|
||||
cfg.Providers.LiteLLM.APIKey = "litellm-key"
|
||||
cfg.Providers.LiteLLM.APIBase = "http://localhost:4000/v1"
|
||||
cfg.Providers.LiteLLM.Proxy = "http://127.0.0.1:7890"
|
||||
},
|
||||
wantType: providerTypeHTTPCompat,
|
||||
wantAPIBase: "http://localhost:4000/v1",
|
||||
wantProxy: "http://127.0.0.1:7890",
|
||||
},
|
||||
{
|
||||
name: "explicit litellm provider defaults base when only key is configured",
|
||||
setup: func(cfg *config.Config) {
|
||||
cfg.Agents.Defaults.Provider = "litellm"
|
||||
cfg.Providers.LiteLLM.APIKey = "litellm-key"
|
||||
},
|
||||
wantType: providerTypeHTTPCompat,
|
||||
wantAPIBase: "http://localhost:4000/v1",
|
||||
},
|
||||
{
|
||||
name: "explicit claude-cli provider routes to cli provider type",
|
||||
setup: func(cfg *config.Config) {
|
||||
|
||||
@@ -26,8 +26,9 @@ func NewGitHubCopilotProvider(uri string, connectMode string, model string) (*Gi
|
||||
|
||||
switch connectMode {
|
||||
case "stdio":
|
||||
// TODO:
|
||||
return nil, fmt.Errorf("stdio mode not implemented")
|
||||
// TODO: Implement stdio mode for GitHub Copilot provider
|
||||
// See https://github.com/github/copilot-sdk/blob/main/docs/getting-started.md for details
|
||||
return nil, fmt.Errorf("stdio mode not implemented for GitHub Copilot provider; please use 'grpc' mode instead")
|
||||
case "grpc":
|
||||
client := copilot.NewClient(&copilot.ClientOptions{
|
||||
CLIUrl: uri,
|
||||
@@ -100,9 +101,12 @@ func (p *GitHubCopilotProvider) Chat(
|
||||
return nil, fmt.Errorf("provider closed")
|
||||
}
|
||||
|
||||
resp, _ := session.SendAndWait(ctx, copilot.MessageOptions{
|
||||
resp, err := session.SendAndWait(ctx, copilot.MessageOptions{
|
||||
Prompt: string(fullcontent),
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to send message to copilot: %w", err)
|
||||
}
|
||||
|
||||
if resp == nil {
|
||||
return nil, fmt.Errorf("empty response from copilot")
|
||||
|
||||
@@ -116,7 +116,7 @@ func (p *Provider) Chat(
|
||||
|
||||
requestBody := map[string]any{
|
||||
"model": model,
|
||||
"messages": stripSystemParts(messages),
|
||||
"messages": serializeMessages(messages),
|
||||
}
|
||||
|
||||
if len(tools) > 0 {
|
||||
@@ -296,26 +296,62 @@ type openaiMessage struct {
|
||||
ToolCallID string `json:"tool_call_id,omitempty"`
|
||||
}
|
||||
|
||||
// stripSystemParts converts []Message to []openaiMessage, dropping the
|
||||
// SystemParts field so it doesn't leak into the JSON payload sent to
|
||||
// OpenAI-compatible APIs (some strict endpoints reject unknown fields).
|
||||
func stripSystemParts(messages []Message) []openaiMessage {
|
||||
out := make([]openaiMessage, len(messages))
|
||||
for i, m := range messages {
|
||||
out[i] = openaiMessage{
|
||||
Role: m.Role,
|
||||
Content: m.Content,
|
||||
ReasoningContent: m.ReasoningContent,
|
||||
ToolCalls: m.ToolCalls,
|
||||
ToolCallID: m.ToolCallID,
|
||||
// serializeMessages converts internal Message structs to the OpenAI wire format.
|
||||
// - Strips SystemParts (unknown to third-party endpoints)
|
||||
// - Converts messages with Media to multipart content format (text + image_url parts)
|
||||
// - Preserves ToolCallID, ToolCalls, and ReasoningContent for all messages
|
||||
func serializeMessages(messages []Message) []any {
|
||||
out := make([]any, 0, len(messages))
|
||||
for _, m := range messages {
|
||||
if len(m.Media) == 0 {
|
||||
out = append(out, openaiMessage{
|
||||
Role: m.Role,
|
||||
Content: m.Content,
|
||||
ReasoningContent: m.ReasoningContent,
|
||||
ToolCalls: m.ToolCalls,
|
||||
ToolCallID: m.ToolCallID,
|
||||
})
|
||||
continue
|
||||
}
|
||||
|
||||
// Multipart content format for messages with media
|
||||
parts := make([]map[string]any, 0, 1+len(m.Media))
|
||||
if m.Content != "" {
|
||||
parts = append(parts, map[string]any{
|
||||
"type": "text",
|
||||
"text": m.Content,
|
||||
})
|
||||
}
|
||||
for _, mediaURL := range m.Media {
|
||||
parts = append(parts, map[string]any{
|
||||
"type": "image_url",
|
||||
"image_url": map[string]any{
|
||||
"url": mediaURL,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
msg := map[string]any{
|
||||
"role": m.Role,
|
||||
"content": parts,
|
||||
}
|
||||
if m.ToolCallID != "" {
|
||||
msg["tool_call_id"] = m.ToolCallID
|
||||
}
|
||||
if len(m.ToolCalls) > 0 {
|
||||
msg["tool_calls"] = m.ToolCalls
|
||||
}
|
||||
if m.ReasoningContent != "" {
|
||||
msg["reasoning_content"] = m.ReasoningContent
|
||||
}
|
||||
out = append(out, msg)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func normalizeModel(model, apiBase string) string {
|
||||
idx := strings.Index(model, "/")
|
||||
if idx == -1 {
|
||||
before, after, ok := strings.Cut(model, "/")
|
||||
if !ok {
|
||||
return model
|
||||
}
|
||||
|
||||
@@ -323,10 +359,10 @@ func normalizeModel(model, apiBase string) string {
|
||||
return model
|
||||
}
|
||||
|
||||
prefix := strings.ToLower(model[:idx])
|
||||
prefix := strings.ToLower(before)
|
||||
switch prefix {
|
||||
case "moonshot", "nvidia", "groq", "ollama", "deepseek", "google", "openrouter", "zhipu", "mistral":
|
||||
return model[idx+1:]
|
||||
case "litellm", "moonshot", "nvidia", "groq", "ollama", "deepseek", "google", "openrouter", "zhipu", "mistral":
|
||||
return after
|
||||
default:
|
||||
return model
|
||||
}
|
||||
|
||||
@@ -5,8 +5,11 @@ import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/providers/protocoltypes"
|
||||
)
|
||||
|
||||
func TestProviderChat_UsesMaxCompletionTokensForGLM(t *testing.T) {
|
||||
@@ -146,6 +149,56 @@ func TestProviderChat_ParsesReasoningContent(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestProviderChat_PreservesReasoningContentInHistory(t *testing.T) {
|
||||
var requestBody map[string]any
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if err := json.NewDecoder(r.Body).Decode(&requestBody); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
resp := map[string]any{
|
||||
"choices": []map[string]any{
|
||||
{
|
||||
"message": map[string]any{"content": "ok"},
|
||||
"finish_reason": "stop",
|
||||
},
|
||||
},
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(resp)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
p := NewProvider("key", server.URL, "")
|
||||
|
||||
// Simulate a multi-turn conversation where the assistant's previous
|
||||
// reply included reasoning_content (e.g. from kimi-k2.5).
|
||||
messages := []Message{
|
||||
{Role: "user", Content: "What is 1+1?"},
|
||||
{Role: "assistant", Content: "2", ReasoningContent: "Let me think... 1+1=2"},
|
||||
{Role: "user", Content: "What about 2+2?"},
|
||||
}
|
||||
|
||||
_, err := p.Chat(t.Context(), messages, nil, "kimi-k2.5", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Chat() error = %v", err)
|
||||
}
|
||||
|
||||
// Verify reasoning_content is preserved in the serialized request.
|
||||
reqMessages, ok := requestBody["messages"].([]any)
|
||||
if !ok {
|
||||
t.Fatalf("messages is not []any: %T", requestBody["messages"])
|
||||
}
|
||||
assistantMsg, ok := reqMessages[1].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatalf("assistant message is not map[string]any: %T", reqMessages[1])
|
||||
}
|
||||
if assistantMsg["reasoning_content"] != "Let me think... 1+1=2" {
|
||||
t.Errorf("reasoning_content not preserved in request, got %v", assistantMsg["reasoning_content"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestProviderChat_HTTPError(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
http.Error(w, "bad request", http.StatusBadRequest)
|
||||
@@ -206,6 +259,11 @@ func TestProviderChat_StripsGroqAndOllamaPrefixes(t *testing.T) {
|
||||
input string
|
||||
wantModel string
|
||||
}{
|
||||
{
|
||||
name: "strips litellm prefix and preserves proxy model name",
|
||||
input: "litellm/my-proxy-alias",
|
||||
wantModel: "my-proxy-alias",
|
||||
},
|
||||
{
|
||||
name: "strips groq prefix and keeps nested model",
|
||||
input: "groq/openai/gpt-oss-120b",
|
||||
@@ -362,61 +420,96 @@ func TestProvider_FunctionalOptionRequestTimeoutNonPositive(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// TestStripSystemParts_PreservesReasoningContent verifies that reasoning_content
|
||||
// is preserved in the wire message format when present, and omitted when empty.
|
||||
// Regression test for: Kimi K2 API returning 400 "reasoning_content is missing".
|
||||
func TestStripSystemParts_PreservesReasoningContent(t *testing.T) {
|
||||
messages := []Message{
|
||||
{Role: "user", Content: "What is 1+1?"},
|
||||
func TestSerializeMessages_PlainText(t *testing.T) {
|
||||
messages := []protocoltypes.Message{
|
||||
{Role: "user", Content: "hello"},
|
||||
{Role: "assistant", Content: "hi", ReasoningContent: "thinking..."},
|
||||
}
|
||||
result := serializeMessages(messages)
|
||||
|
||||
data, err := json.Marshal(result)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
var msgs []map[string]any
|
||||
json.Unmarshal(data, &msgs)
|
||||
|
||||
if msgs[0]["content"] != "hello" {
|
||||
t.Fatalf("expected plain string content, got %v", msgs[0]["content"])
|
||||
}
|
||||
if msgs[1]["reasoning_content"] != "thinking..." {
|
||||
t.Fatalf("reasoning_content not preserved, got %v", msgs[1]["reasoning_content"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestSerializeMessages_WithMedia(t *testing.T) {
|
||||
messages := []protocoltypes.Message{
|
||||
{Role: "user", Content: "describe this", Media: []string{"data:image/png;base64,abc123"}},
|
||||
}
|
||||
result := serializeMessages(messages)
|
||||
|
||||
data, _ := json.Marshal(result)
|
||||
var msgs []map[string]any
|
||||
json.Unmarshal(data, &msgs)
|
||||
|
||||
content, ok := msgs[0]["content"].([]any)
|
||||
if !ok {
|
||||
t.Fatalf("expected array content for media message, got %T", msgs[0]["content"])
|
||||
}
|
||||
if len(content) != 2 {
|
||||
t.Fatalf("expected 2 content parts, got %d", len(content))
|
||||
}
|
||||
|
||||
textPart := content[0].(map[string]any)
|
||||
if textPart["type"] != "text" || textPart["text"] != "describe this" {
|
||||
t.Fatalf("text part mismatch: %v", textPart)
|
||||
}
|
||||
|
||||
imgPart := content[1].(map[string]any)
|
||||
if imgPart["type"] != "image_url" {
|
||||
t.Fatalf("expected image_url type, got %v", imgPart["type"])
|
||||
}
|
||||
imgURL := imgPart["image_url"].(map[string]any)
|
||||
if imgURL["url"] != "data:image/png;base64,abc123" {
|
||||
t.Fatalf("image url mismatch: %v", imgURL["url"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestSerializeMessages_MediaWithToolCallID(t *testing.T) {
|
||||
messages := []protocoltypes.Message{
|
||||
{Role: "tool", Content: "image result", Media: []string{"data:image/png;base64,xyz"}, ToolCallID: "call_1"},
|
||||
}
|
||||
result := serializeMessages(messages)
|
||||
|
||||
data, _ := json.Marshal(result)
|
||||
var msgs []map[string]any
|
||||
json.Unmarshal(data, &msgs)
|
||||
|
||||
if msgs[0]["tool_call_id"] != "call_1" {
|
||||
t.Fatalf("tool_call_id not preserved with media, got %v", msgs[0]["tool_call_id"])
|
||||
}
|
||||
// Content should be multipart array
|
||||
if _, ok := msgs[0]["content"].([]any); !ok {
|
||||
t.Fatalf("expected array content, got %T", msgs[0]["content"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestSerializeMessages_StripsSystemParts(t *testing.T) {
|
||||
messages := []protocoltypes.Message{
|
||||
{
|
||||
Role: "assistant",
|
||||
Content: "The answer is 2",
|
||||
ReasoningContent: "Let me think step by step... 1+1=2",
|
||||
Role: "system",
|
||||
Content: "you are helpful",
|
||||
SystemParts: []protocoltypes.ContentBlock{
|
||||
{Type: "text", Text: "you are helpful"},
|
||||
},
|
||||
},
|
||||
{Role: "user", Content: "Thanks"},
|
||||
}
|
||||
result := serializeMessages(messages)
|
||||
|
||||
result := stripSystemParts(messages)
|
||||
|
||||
if len(result) != 3 {
|
||||
t.Fatalf("len(result) = %d, want 3", len(result))
|
||||
}
|
||||
|
||||
// Assistant message should preserve reasoning_content
|
||||
if result[1].ReasoningContent != "Let me think step by step... 1+1=2" {
|
||||
t.Errorf("ReasoningContent = %q, want %q", result[1].ReasoningContent, "Let me think step by step... 1+1=2")
|
||||
}
|
||||
|
||||
// Verify it serializes to JSON correctly
|
||||
data, err := json.Marshal(result[1])
|
||||
if err != nil {
|
||||
t.Fatalf("json.Marshal error: %v", err)
|
||||
}
|
||||
|
||||
jsonStr := string(data)
|
||||
if !contains(jsonStr, `"reasoning_content"`) {
|
||||
t.Errorf("JSON should contain reasoning_content field, got: %s", jsonStr)
|
||||
}
|
||||
|
||||
// User message should have empty reasoning_content (omitted via omitempty)
|
||||
data2, err := json.Marshal(result[0])
|
||||
if err != nil {
|
||||
t.Fatalf("json.Marshal error: %v", err)
|
||||
}
|
||||
if contains(string(data2), `"reasoning_content"`) {
|
||||
t.Errorf("JSON should omit empty reasoning_content, got: %s", string(data2))
|
||||
data, _ := json.Marshal(result)
|
||||
raw := string(data)
|
||||
if strings.Contains(raw, "system_parts") {
|
||||
t.Fatal("system_parts should not appear in serialized output")
|
||||
}
|
||||
}
|
||||
|
||||
func contains(s, substr string) bool {
|
||||
return len(s) >= len(substr) && searchString(s, substr)
|
||||
}
|
||||
|
||||
func searchString(s, substr string) bool {
|
||||
for i := 0; i+len(substr) <= len(s); i++ {
|
||||
if s[i:i+len(substr)] == substr {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -65,6 +65,7 @@ type ContentBlock struct {
|
||||
type Message struct {
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content"`
|
||||
Media []string `json:"media,omitempty"`
|
||||
ReasoningContent string `json:"reasoning_content,omitempty"`
|
||||
SystemParts []ContentBlock `json:"system_parts,omitempty"` // structured system blocks for cache-aware adapters
|
||||
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
|
||||
|
||||
@@ -5,7 +5,43 @@
|
||||
|
||||
package providers
|
||||
|
||||
import "encoding/json"
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// buildCLIToolsPrompt creates the tool definitions section for a CLI provider system prompt.
|
||||
func buildCLIToolsPrompt(tools []ToolDefinition) string {
|
||||
var sb strings.Builder
|
||||
|
||||
sb.WriteString("## Available Tools\n\n")
|
||||
sb.WriteString("When you need to use a tool, respond with ONLY a JSON object:\n\n")
|
||||
sb.WriteString("```json\n")
|
||||
sb.WriteString(
|
||||
`{"tool_calls":[{"id":"call_xxx","type":"function","function":{"name":"tool_name","arguments":"{...}"}}]}`,
|
||||
)
|
||||
sb.WriteString("\n```\n\n")
|
||||
sb.WriteString("CRITICAL: The 'arguments' field MUST be a JSON-encoded STRING.\n\n")
|
||||
sb.WriteString("### Tool Definitions:\n\n")
|
||||
|
||||
for _, tool := range tools {
|
||||
if tool.Type != "function" {
|
||||
continue
|
||||
}
|
||||
sb.WriteString(fmt.Sprintf("#### %s\n", tool.Function.Name))
|
||||
if tool.Function.Description != "" {
|
||||
sb.WriteString(fmt.Sprintf("Description: %s\n", tool.Function.Description))
|
||||
}
|
||||
if len(tool.Function.Parameters) > 0 {
|
||||
paramsJSON, _ := json.Marshal(tool.Function.Parameters)
|
||||
sb.WriteString(fmt.Sprintf("Parameters:\n```json\n%s\n```\n", string(paramsJSON)))
|
||||
}
|
||||
sb.WriteString("\n")
|
||||
}
|
||||
|
||||
return sb.String()
|
||||
}
|
||||
|
||||
// NormalizeToolCall normalizes a ToolCall to ensure all fields are properly populated.
|
||||
// It handles cases where Name/Arguments might be in different locations (top-level vs Function)
|
||||
|
||||
@@ -1,6 +1,9 @@
|
||||
package routing
|
||||
|
||||
import "testing"
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestNormalizeAgentID_Empty(t *testing.T) {
|
||||
if got := NormalizeAgentID(""); got != DefaultAgentID {
|
||||
@@ -57,11 +60,11 @@ func TestNormalizeAgentID_AllInvalid(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestNormalizeAgentID_TruncatesAt64(t *testing.T) {
|
||||
long := ""
|
||||
for i := 0; i < 100; i++ {
|
||||
long += "a"
|
||||
var long strings.Builder
|
||||
for range 100 {
|
||||
long.WriteString("a")
|
||||
}
|
||||
got := NormalizeAgentID(long)
|
||||
got := NormalizeAgentID(long.String())
|
||||
if len(got) > MaxAgentIDLength {
|
||||
t.Errorf("length = %d, want <= %d", len(got), MaxAgentIDLength)
|
||||
}
|
||||
|
||||
@@ -2,7 +2,6 @@ package skills
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
@@ -18,14 +17,6 @@ type SkillInstaller struct {
|
||||
workspace string
|
||||
}
|
||||
|
||||
type AvailableSkill struct {
|
||||
Name string `json:"name"`
|
||||
Repository string `json:"repository"`
|
||||
Description string `json:"description"`
|
||||
Author string `json:"author"`
|
||||
Tags []string `json:"tags"`
|
||||
}
|
||||
|
||||
func NewSkillInstaller(workspace string) *SkillInstaller {
|
||||
return &SkillInstaller{
|
||||
workspace: workspace,
|
||||
@@ -89,35 +80,3 @@ func (si *SkillInstaller) Uninstall(skillName string) error {
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (si *SkillInstaller) ListAvailableSkills(ctx context.Context) ([]AvailableSkill, error) {
|
||||
url := "https://raw.githubusercontent.com/sipeed/picoclaw-skills/main/skills.json"
|
||||
|
||||
client := &http.Client{Timeout: 15 * time.Second}
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
|
||||
resp, err := utils.DoRequestWithRetry(client, req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to fetch skills list: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != 200 {
|
||||
return nil, fmt.Errorf("failed to fetch skills list: HTTP %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read response: %w", err)
|
||||
}
|
||||
|
||||
var skills []AvailableSkill
|
||||
if err := json.Unmarshal(body, &skills); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse skills list: %w", err)
|
||||
}
|
||||
|
||||
return skills, nil
|
||||
}
|
||||
|
||||
+24
-1
@@ -64,6 +64,29 @@ type SkillsLoader struct {
|
||||
builtinSkills string // builtin skills
|
||||
}
|
||||
|
||||
// SkillRoots returns all unique skill root directories used by this loader.
|
||||
// The order follows resolution priority: workspace > global > builtin.
|
||||
func (sl *SkillsLoader) SkillRoots() []string {
|
||||
roots := []string{sl.workspaceSkills, sl.globalSkills, sl.builtinSkills}
|
||||
seen := make(map[string]struct{}, len(roots))
|
||||
out := make([]string, 0, len(roots))
|
||||
|
||||
for _, root := range roots {
|
||||
trimmed := strings.TrimSpace(root)
|
||||
if trimmed == "" {
|
||||
continue
|
||||
}
|
||||
clean := filepath.Clean(trimmed)
|
||||
if _, ok := seen[clean]; ok {
|
||||
continue
|
||||
}
|
||||
seen[clean] = struct{}{}
|
||||
out = append(out, clean)
|
||||
}
|
||||
|
||||
return out
|
||||
}
|
||||
|
||||
func NewSkillsLoader(workspace string, globalSkills string, builtinSkills string) *SkillsLoader {
|
||||
return &SkillsLoader{
|
||||
workspace: workspace,
|
||||
@@ -240,7 +263,7 @@ func (sl *SkillsLoader) parseSimpleYAML(content string) map[string]string {
|
||||
normalized := strings.ReplaceAll(content, "\r\n", "\n")
|
||||
normalized = strings.ReplaceAll(normalized, "\r", "\n")
|
||||
|
||||
for _, line := range strings.Split(normalized, "\n") {
|
||||
for line := range strings.SplitSeq(normalized, "\n") {
|
||||
line = strings.TrimSpace(line)
|
||||
if line == "" || strings.HasPrefix(line, "#") {
|
||||
continue
|
||||
|
||||
@@ -326,3 +326,19 @@ func TestStripFrontmatter(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSkillRootsTrimsWhitespaceAndDedups(t *testing.T) {
|
||||
tmp := t.TempDir()
|
||||
workspace := filepath.Join(tmp, "workspace")
|
||||
global := filepath.Join(tmp, "global")
|
||||
builtin := filepath.Join(tmp, "builtin")
|
||||
|
||||
sl := NewSkillsLoader(workspace, " "+global+" ", "\t"+builtin+"\n")
|
||||
roots := sl.SkillRoots()
|
||||
|
||||
assert.Equal(t, []string{
|
||||
filepath.Join(workspace, "skills"),
|
||||
global,
|
||||
builtin,
|
||||
}, roots)
|
||||
}
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
package skills
|
||||
|
||||
import (
|
||||
"sort"
|
||||
"slices"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
@@ -183,7 +183,7 @@ func buildTrigrams(s string) []uint32 {
|
||||
}
|
||||
|
||||
// Sort and Deduplication
|
||||
sort.Slice(trigrams, func(i, j int) bool { return trigrams[i] < trigrams[j] })
|
||||
slices.Sort(trigrams)
|
||||
n := 1
|
||||
for i := 1; i < len(trigrams); i++ {
|
||||
if trigrams[i] != trigrams[i-1] {
|
||||
|
||||
@@ -153,7 +153,7 @@ func TestSearchCacheConcurrency(t *testing.T) {
|
||||
|
||||
// Concurrent writes
|
||||
go func() {
|
||||
for i := 0; i < 100; i++ {
|
||||
for i := range 100 {
|
||||
cache.Put("query-write-"+string(rune('a'+i%26)), []SearchResult{{Slug: "x"}})
|
||||
}
|
||||
done <- struct{}{}
|
||||
@@ -161,7 +161,7 @@ func TestSearchCacheConcurrency(t *testing.T) {
|
||||
|
||||
// Concurrent reads
|
||||
go func() {
|
||||
for i := 0; i < 100; i++ {
|
||||
for range 100 {
|
||||
cache.Get("query-write-a")
|
||||
}
|
||||
done <- struct{}{}
|
||||
|
||||
+9
-3
@@ -40,7 +40,9 @@ func NewManager(workspace string) *Manager {
|
||||
oldStateFile := filepath.Join(workspace, "state.json")
|
||||
|
||||
// Create state directory if it doesn't exist
|
||||
os.MkdirAll(stateDir, 0o755)
|
||||
if err := os.MkdirAll(stateDir, 0o755); err != nil {
|
||||
log.Fatalf("[FATAL] state: failed to create state directory: %v", err)
|
||||
}
|
||||
|
||||
sm := &Manager{
|
||||
workspace: workspace,
|
||||
@@ -54,13 +56,17 @@ func NewManager(workspace string) *Manager {
|
||||
if data, err := os.ReadFile(oldStateFile); err == nil {
|
||||
if err := json.Unmarshal(data, sm.state); err == nil {
|
||||
// Migrate to new location
|
||||
sm.saveAtomic()
|
||||
if err := sm.saveAtomic(); err != nil {
|
||||
log.Printf("[WARN] state: failed to save state: %v", err)
|
||||
}
|
||||
log.Printf("[INFO] state: migrated state from %s to %s", oldStateFile, stateFile)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Load from new location
|
||||
sm.load()
|
||||
if err := sm.load(); err != nil {
|
||||
log.Printf("[WARN] state: failed to load state: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
return sm
|
||||
|
||||
+40
-2
@@ -2,8 +2,10 @@ package state
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
)
|
||||
@@ -135,7 +137,7 @@ func TestConcurrentAccess(t *testing.T) {
|
||||
|
||||
// Test concurrent writes
|
||||
done := make(chan bool, 10)
|
||||
for i := 0; i < 10; i++ {
|
||||
for i := range 10 {
|
||||
go func(idx int) {
|
||||
channel := fmt.Sprintf("channel-%d", idx)
|
||||
sm.SetLastChannel(channel)
|
||||
@@ -144,7 +146,7 @@ func TestConcurrentAccess(t *testing.T) {
|
||||
}
|
||||
|
||||
// Wait for all goroutines to complete
|
||||
for i := 0; i < 10; i++ {
|
||||
for range 10 {
|
||||
<-done
|
||||
}
|
||||
|
||||
@@ -214,3 +216,39 @@ func TestNewManager_EmptyWorkspace(t *testing.T) {
|
||||
t.Error("Expected zero timestamp for new state")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewManager_MkdirFailureCrashes(t *testing.T) {
|
||||
// Since log.Fatalf calls os.Exit(1), we cannot test it normally
|
||||
// Otherwise, the test suite would stop altogether.
|
||||
// We use the standard pattern of Go: rerun this test in a subprocess.
|
||||
if os.Getenv("BE_CRASHER") == "1" {
|
||||
tmpDir := os.Getenv("CRASH_DIR")
|
||||
|
||||
statePath := filepath.Join(tmpDir, "state")
|
||||
if err := os.WriteFile(statePath, []byte("I'm a file, not a folder"), 0o644); err != nil {
|
||||
fmt.Printf("setup failed: %v", err)
|
||||
os.Exit(0)
|
||||
}
|
||||
|
||||
NewManager(tmpDir)
|
||||
os.Exit(0)
|
||||
}
|
||||
|
||||
tmpDir, err := os.MkdirTemp("", "state-crash-test-*")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temp dir: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
cmd := exec.Command(os.Args[0], "-test.run=TestNewManager_MkdirFailureCrashes")
|
||||
cmd.Env = append(os.Environ(), "BE_CRASHER=1", "CRASH_DIR="+tmpDir)
|
||||
|
||||
err = cmd.Run()
|
||||
|
||||
var e *exec.ExitError
|
||||
if errors.As(err, &e) && !e.Success() {
|
||||
return
|
||||
}
|
||||
|
||||
t.Fatalf("The process ended without error, a crash was expected via os.Exit(1). Err: %v", err)
|
||||
}
|
||||
|
||||
+5
-3
@@ -3,6 +3,7 @@ package tools
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
@@ -222,7 +223,8 @@ func (t *CronTool) listJobs() *ToolResult {
|
||||
return SilentResult("No scheduled jobs")
|
||||
}
|
||||
|
||||
result := "Scheduled jobs:\n"
|
||||
var result strings.Builder
|
||||
result.WriteString("Scheduled jobs:\n")
|
||||
for _, j := range jobs {
|
||||
var scheduleInfo string
|
||||
if j.Schedule.Kind == "every" && j.Schedule.EveryMS != nil {
|
||||
@@ -234,10 +236,10 @@ func (t *CronTool) listJobs() *ToolResult {
|
||||
} else {
|
||||
scheduleInfo = "unknown"
|
||||
}
|
||||
result += fmt.Sprintf("- %s (id: %s, %s)\n", j.Name, j.ID, scheduleInfo)
|
||||
result.WriteString(fmt.Sprintf("- %s (id: %s, %s)\n", j.Name, j.ID, scheduleInfo))
|
||||
}
|
||||
|
||||
return SilentResult(result)
|
||||
return SilentResult(result.String())
|
||||
}
|
||||
|
||||
func (t *CronTool) removeJob(args map[string]any) *ToolResult {
|
||||
|
||||
+11
-14
@@ -5,6 +5,7 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io/fs"
|
||||
"regexp"
|
||||
"strings"
|
||||
)
|
||||
|
||||
@@ -15,14 +16,12 @@ type EditFileTool struct {
|
||||
}
|
||||
|
||||
// NewEditFileTool creates a new EditFileTool with optional directory restriction.
|
||||
func NewEditFileTool(workspace string, restrict bool) *EditFileTool {
|
||||
var fs fileSystem
|
||||
if restrict {
|
||||
fs = &sandboxFs{workspace: workspace}
|
||||
} else {
|
||||
fs = &hostFs{}
|
||||
func NewEditFileTool(workspace string, restrict bool, allowPaths ...[]*regexp.Regexp) *EditFileTool {
|
||||
var patterns []*regexp.Regexp
|
||||
if len(allowPaths) > 0 {
|
||||
patterns = allowPaths[0]
|
||||
}
|
||||
return &EditFileTool{fs: fs}
|
||||
return &EditFileTool{fs: buildFs(workspace, restrict, patterns)}
|
||||
}
|
||||
|
||||
func (t *EditFileTool) Name() string {
|
||||
@@ -80,14 +79,12 @@ type AppendFileTool struct {
|
||||
fs fileSystem
|
||||
}
|
||||
|
||||
func NewAppendFileTool(workspace string, restrict bool) *AppendFileTool {
|
||||
var fs fileSystem
|
||||
if restrict {
|
||||
fs = &sandboxFs{workspace: workspace}
|
||||
} else {
|
||||
fs = &hostFs{}
|
||||
func NewAppendFileTool(workspace string, restrict bool, allowPaths ...[]*regexp.Regexp) *AppendFileTool {
|
||||
var patterns []*regexp.Regexp
|
||||
if len(allowPaths) > 0 {
|
||||
patterns = allowPaths[0]
|
||||
}
|
||||
return &AppendFileTool{fs: fs}
|
||||
return &AppendFileTool{fs: buildFs(workspace, restrict, patterns)}
|
||||
}
|
||||
|
||||
func (t *AppendFileTool) Name() string {
|
||||
|
||||
+67
-21
@@ -6,6 +6,7 @@ import (
|
||||
"io/fs"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
@@ -87,14 +88,12 @@ type ReadFileTool struct {
|
||||
fs fileSystem
|
||||
}
|
||||
|
||||
func NewReadFileTool(workspace string, restrict bool) *ReadFileTool {
|
||||
var fs fileSystem
|
||||
if restrict {
|
||||
fs = &sandboxFs{workspace: workspace}
|
||||
} else {
|
||||
fs = &hostFs{}
|
||||
func NewReadFileTool(workspace string, restrict bool, allowPaths ...[]*regexp.Regexp) *ReadFileTool {
|
||||
var patterns []*regexp.Regexp
|
||||
if len(allowPaths) > 0 {
|
||||
patterns = allowPaths[0]
|
||||
}
|
||||
return &ReadFileTool{fs: fs}
|
||||
return &ReadFileTool{fs: buildFs(workspace, restrict, patterns)}
|
||||
}
|
||||
|
||||
func (t *ReadFileTool) Name() string {
|
||||
@@ -135,14 +134,12 @@ type WriteFileTool struct {
|
||||
fs fileSystem
|
||||
}
|
||||
|
||||
func NewWriteFileTool(workspace string, restrict bool) *WriteFileTool {
|
||||
var fs fileSystem
|
||||
if restrict {
|
||||
fs = &sandboxFs{workspace: workspace}
|
||||
} else {
|
||||
fs = &hostFs{}
|
||||
func NewWriteFileTool(workspace string, restrict bool, allowPaths ...[]*regexp.Regexp) *WriteFileTool {
|
||||
var patterns []*regexp.Regexp
|
||||
if len(allowPaths) > 0 {
|
||||
patterns = allowPaths[0]
|
||||
}
|
||||
return &WriteFileTool{fs: fs}
|
||||
return &WriteFileTool{fs: buildFs(workspace, restrict, patterns)}
|
||||
}
|
||||
|
||||
func (t *WriteFileTool) Name() string {
|
||||
@@ -192,14 +189,12 @@ type ListDirTool struct {
|
||||
fs fileSystem
|
||||
}
|
||||
|
||||
func NewListDirTool(workspace string, restrict bool) *ListDirTool {
|
||||
var fs fileSystem
|
||||
if restrict {
|
||||
fs = &sandboxFs{workspace: workspace}
|
||||
} else {
|
||||
fs = &hostFs{}
|
||||
func NewListDirTool(workspace string, restrict bool, allowPaths ...[]*regexp.Regexp) *ListDirTool {
|
||||
var patterns []*regexp.Regexp
|
||||
if len(allowPaths) > 0 {
|
||||
patterns = allowPaths[0]
|
||||
}
|
||||
return &ListDirTool{fs: fs}
|
||||
return &ListDirTool{fs: buildFs(workspace, restrict, patterns)}
|
||||
}
|
||||
|
||||
func (t *ListDirTool) Name() string {
|
||||
@@ -394,6 +389,57 @@ func (r *sandboxFs) ReadDir(path string) ([]os.DirEntry, error) {
|
||||
return entries, err
|
||||
}
|
||||
|
||||
// whitelistFs wraps a sandboxFs and allows access to specific paths outside
|
||||
// the workspace when they match any of the provided patterns.
|
||||
type whitelistFs struct {
|
||||
sandbox *sandboxFs
|
||||
host hostFs
|
||||
patterns []*regexp.Regexp
|
||||
}
|
||||
|
||||
func (w *whitelistFs) matches(path string) bool {
|
||||
for _, p := range w.patterns {
|
||||
if p.MatchString(path) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (w *whitelistFs) ReadFile(path string) ([]byte, error) {
|
||||
if w.matches(path) {
|
||||
return w.host.ReadFile(path)
|
||||
}
|
||||
return w.sandbox.ReadFile(path)
|
||||
}
|
||||
|
||||
func (w *whitelistFs) WriteFile(path string, data []byte) error {
|
||||
if w.matches(path) {
|
||||
return w.host.WriteFile(path, data)
|
||||
}
|
||||
return w.sandbox.WriteFile(path, data)
|
||||
}
|
||||
|
||||
func (w *whitelistFs) ReadDir(path string) ([]os.DirEntry, error) {
|
||||
if w.matches(path) {
|
||||
return w.host.ReadDir(path)
|
||||
}
|
||||
return w.sandbox.ReadDir(path)
|
||||
}
|
||||
|
||||
// buildFs returns the appropriate fileSystem implementation based on restriction
|
||||
// settings and optional path whitelist patterns.
|
||||
func buildFs(workspace string, restrict bool, patterns []*regexp.Regexp) fileSystem {
|
||||
if !restrict {
|
||||
return &hostFs{}
|
||||
}
|
||||
sandbox := &sandboxFs{workspace: workspace}
|
||||
if len(patterns) > 0 {
|
||||
return &whitelistFs{sandbox: sandbox, patterns: patterns}
|
||||
}
|
||||
return sandbox
|
||||
}
|
||||
|
||||
// Helper to get a safe relative path for os.Root usage
|
||||
func getSafeRelPath(workspace, path string) (string, error) {
|
||||
if workspace == "" {
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"io"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
@@ -486,3 +487,36 @@ func TestRootRW_Write(t *testing.T) {
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, newData, content)
|
||||
}
|
||||
|
||||
// TestWhitelistFs_AllowsMatchingPaths verifies that whitelistFs allows access to
|
||||
// paths matching the whitelist patterns while blocking non-matching paths.
|
||||
func TestWhitelistFs_AllowsMatchingPaths(t *testing.T) {
|
||||
workspace := t.TempDir()
|
||||
outsideDir := t.TempDir()
|
||||
outsideFile := filepath.Join(outsideDir, "allowed.txt")
|
||||
os.WriteFile(outsideFile, []byte("outside content"), 0o644)
|
||||
|
||||
// Pattern allows access to the outsideDir.
|
||||
patterns := []*regexp.Regexp{regexp.MustCompile(`^` + regexp.QuoteMeta(outsideDir))}
|
||||
|
||||
tool := NewReadFileTool(workspace, true, patterns)
|
||||
|
||||
// Read from whitelisted path should succeed.
|
||||
result := tool.Execute(context.Background(), map[string]any{"path": outsideFile})
|
||||
if result.IsError {
|
||||
t.Errorf("expected whitelisted path to be readable, got: %s", result.ForLLM)
|
||||
}
|
||||
if !strings.Contains(result.ForLLM, "outside content") {
|
||||
t.Errorf("expected file content, got: %s", result.ForLLM)
|
||||
}
|
||||
|
||||
// Read from non-whitelisted path outside workspace should fail.
|
||||
otherDir := t.TempDir()
|
||||
otherFile := filepath.Join(otherDir, "blocked.txt")
|
||||
os.WriteFile(otherFile, []byte("blocked"), 0o644)
|
||||
|
||||
result = tool.Execute(context.Background(), map[string]any{"path": otherFile})
|
||||
if !result.IsError {
|
||||
t.Errorf("expected non-whitelisted path to be blocked, got: %s", result.ForLLM)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,246 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"hash/fnv"
|
||||
"strings"
|
||||
|
||||
"github.com/modelcontextprotocol/go-sdk/mcp"
|
||||
)
|
||||
|
||||
// MCPManager defines the interface for MCP manager operations
|
||||
// This allows for easier testing with mock implementations
|
||||
type MCPManager interface {
|
||||
CallTool(
|
||||
ctx context.Context,
|
||||
serverName, toolName string,
|
||||
arguments map[string]any,
|
||||
) (*mcp.CallToolResult, error)
|
||||
}
|
||||
|
||||
// MCPTool wraps an MCP tool to implement the Tool interface
|
||||
type MCPTool struct {
|
||||
manager MCPManager
|
||||
serverName string
|
||||
tool *mcp.Tool
|
||||
}
|
||||
|
||||
// NewMCPTool creates a new MCP tool wrapper
|
||||
func NewMCPTool(manager MCPManager, serverName string, tool *mcp.Tool) *MCPTool {
|
||||
return &MCPTool{
|
||||
manager: manager,
|
||||
serverName: serverName,
|
||||
tool: tool,
|
||||
}
|
||||
}
|
||||
|
||||
// sanitizeIdentifierComponent normalizes a string so it can be safely used
|
||||
// as part of a tool/function identifier for downstream providers.
|
||||
// It:
|
||||
// - lowercases the string
|
||||
// - replaces any character not in [a-z0-9_-] with '_'
|
||||
// - collapses multiple consecutive '_' into a single '_'
|
||||
// - trims leading/trailing '_'
|
||||
// - falls back to "unnamed" if the result is empty
|
||||
// - truncates overly long components to a reasonable length
|
||||
func sanitizeIdentifierComponent(s string) string {
|
||||
const maxLen = 64
|
||||
|
||||
s = strings.ToLower(s)
|
||||
var b strings.Builder
|
||||
b.Grow(len(s))
|
||||
|
||||
prevUnderscore := false
|
||||
for _, r := range s {
|
||||
isAllowed := (r >= 'a' && r <= 'z') ||
|
||||
(r >= '0' && r <= '9') ||
|
||||
r == '_' || r == '-'
|
||||
|
||||
if !isAllowed {
|
||||
// Normalize any disallowed character to '_'
|
||||
if !prevUnderscore {
|
||||
b.WriteRune('_')
|
||||
prevUnderscore = true
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
if r == '_' {
|
||||
if prevUnderscore {
|
||||
continue
|
||||
}
|
||||
prevUnderscore = true
|
||||
} else {
|
||||
prevUnderscore = false
|
||||
}
|
||||
|
||||
b.WriteRune(r)
|
||||
}
|
||||
|
||||
result := strings.Trim(b.String(), "_")
|
||||
if result == "" {
|
||||
result = "unnamed"
|
||||
}
|
||||
|
||||
if len(result) > maxLen {
|
||||
result = result[:maxLen]
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// Name returns the tool name, prefixed with the server name.
|
||||
// The total length is capped at 64 characters (OpenAI-compatible API limit).
|
||||
// A short hash of the original (unsanitized) server and tool names is appended
|
||||
// whenever sanitization is lossy or the name is truncated, ensuring that two
|
||||
// names which differ only in disallowed characters remain distinct after sanitization.
|
||||
func (t *MCPTool) Name() string {
|
||||
// Prefix with server name to avoid conflicts, and sanitize components
|
||||
sanitizedServer := sanitizeIdentifierComponent(t.serverName)
|
||||
sanitizedTool := sanitizeIdentifierComponent(t.tool.Name)
|
||||
full := fmt.Sprintf("mcp_%s_%s", sanitizedServer, sanitizedTool)
|
||||
|
||||
// Check if sanitization was lossless (only lowercasing, no char replacement/truncation)
|
||||
lossless := strings.ToLower(t.serverName) == sanitizedServer &&
|
||||
strings.ToLower(t.tool.Name) == sanitizedTool
|
||||
|
||||
const maxTotal = 64
|
||||
if lossless && len(full) <= maxTotal {
|
||||
return full
|
||||
}
|
||||
|
||||
// Sanitization was lossy or name too long: append hash of the ORIGINAL names
|
||||
// (not the sanitized names) so different originals always yield different hashes.
|
||||
h := fnv.New32a()
|
||||
_, _ = h.Write([]byte(t.serverName + "\x00" + t.tool.Name))
|
||||
suffix := fmt.Sprintf("%08x", h.Sum32()) // 8 chars
|
||||
|
||||
base := full
|
||||
if len(base) > maxTotal-9 {
|
||||
base = strings.TrimRight(full[:maxTotal-9], "_")
|
||||
}
|
||||
return base + "_" + suffix
|
||||
}
|
||||
|
||||
// Description returns the tool description
|
||||
func (t *MCPTool) Description() string {
|
||||
desc := t.tool.Description
|
||||
if desc == "" {
|
||||
desc = fmt.Sprintf("MCP tool from %s server", t.serverName)
|
||||
}
|
||||
// Add server info to description
|
||||
return fmt.Sprintf("[MCP:%s] %s", t.serverName, desc)
|
||||
}
|
||||
|
||||
// Parameters returns the tool parameters schema
|
||||
func (t *MCPTool) Parameters() map[string]any {
|
||||
// The InputSchema is already a JSON Schema object
|
||||
schema := t.tool.InputSchema
|
||||
|
||||
// Handle nil schema
|
||||
if schema == nil {
|
||||
return map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]any{},
|
||||
"required": []string{},
|
||||
}
|
||||
}
|
||||
|
||||
// Try direct conversion first (fast path)
|
||||
if schemaMap, ok := schema.(map[string]any); ok {
|
||||
return schemaMap
|
||||
}
|
||||
|
||||
// Handle json.RawMessage and []byte - unmarshal directly
|
||||
var jsonData []byte
|
||||
if rawMsg, ok := schema.(json.RawMessage); ok {
|
||||
jsonData = rawMsg
|
||||
} else if bytes, ok := schema.([]byte); ok {
|
||||
jsonData = bytes
|
||||
}
|
||||
|
||||
if jsonData != nil {
|
||||
var result map[string]any
|
||||
if err := json.Unmarshal(jsonData, &result); err == nil {
|
||||
return result
|
||||
}
|
||||
// Fallback on error
|
||||
return map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]any{},
|
||||
"required": []string{},
|
||||
}
|
||||
}
|
||||
|
||||
// For other types (structs, etc.), convert via JSON marshal/unmarshal
|
||||
var err error
|
||||
jsonData, err = json.Marshal(schema)
|
||||
if err != nil {
|
||||
// Fallback to empty schema if marshaling fails
|
||||
return map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]any{},
|
||||
"required": []string{},
|
||||
}
|
||||
}
|
||||
|
||||
var result map[string]any
|
||||
if err := json.Unmarshal(jsonData, &result); err != nil {
|
||||
// Fallback to empty schema if unmarshaling fails
|
||||
return map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]any{},
|
||||
"required": []string{},
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// Execute executes the MCP tool
|
||||
func (t *MCPTool) Execute(ctx context.Context, args map[string]any) *ToolResult {
|
||||
result, err := t.manager.CallTool(ctx, t.serverName, t.tool.Name, args)
|
||||
if err != nil {
|
||||
return ErrorResult(fmt.Sprintf("MCP tool execution failed: %v", err)).WithError(err)
|
||||
}
|
||||
|
||||
if result == nil {
|
||||
nilErr := fmt.Errorf("MCP tool returned nil result without error")
|
||||
return ErrorResult("MCP tool execution failed: nil result").WithError(nilErr)
|
||||
}
|
||||
|
||||
// Handle error result from server
|
||||
if result.IsError {
|
||||
errMsg := extractContentText(result.Content)
|
||||
return ErrorResult(fmt.Sprintf("MCP tool returned error: %s", errMsg)).
|
||||
WithError(fmt.Errorf("MCP tool error: %s", errMsg))
|
||||
}
|
||||
|
||||
// Extract text content from result
|
||||
output := extractContentText(result.Content)
|
||||
|
||||
return &ToolResult{
|
||||
ForLLM: output,
|
||||
IsError: false,
|
||||
}
|
||||
}
|
||||
|
||||
// extractContentText extracts text from MCP content array
|
||||
func extractContentText(content []mcp.Content) string {
|
||||
var parts []string
|
||||
for _, c := range content {
|
||||
switch v := c.(type) {
|
||||
case *mcp.TextContent:
|
||||
parts = append(parts, v.Text)
|
||||
case *mcp.ImageContent:
|
||||
// For images, just indicate that an image was returned
|
||||
parts = append(parts, fmt.Sprintf("[Image: %s]", v.MIMEType))
|
||||
default:
|
||||
// For other content types, use string representation
|
||||
parts = append(parts, fmt.Sprintf("[Content: %T]", v))
|
||||
}
|
||||
}
|
||||
return strings.Join(parts, "\n")
|
||||
}
|
||||
@@ -0,0 +1,492 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/modelcontextprotocol/go-sdk/mcp"
|
||||
)
|
||||
|
||||
// MockMCPManager is a mock implementation of MCPManager interface for testing
|
||||
type MockMCPManager struct {
|
||||
callToolFunc func(ctx context.Context, serverName, toolName string, arguments map[string]any) (*mcp.CallToolResult, error)
|
||||
}
|
||||
|
||||
func (m *MockMCPManager) CallTool(
|
||||
ctx context.Context,
|
||||
serverName, toolName string,
|
||||
arguments map[string]any,
|
||||
) (*mcp.CallToolResult, error) {
|
||||
if m.callToolFunc != nil {
|
||||
return m.callToolFunc(ctx, serverName, toolName, arguments)
|
||||
}
|
||||
return &mcp.CallToolResult{
|
||||
Content: []mcp.Content{
|
||||
&mcp.TextContent{Text: "mock result"},
|
||||
},
|
||||
IsError: false,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// TestNewMCPTool verifies MCP tool creation
|
||||
func TestNewMCPTool(t *testing.T) {
|
||||
manager := &MockMCPManager{}
|
||||
tool := &mcp.Tool{
|
||||
Name: "test_tool",
|
||||
Description: "A test tool",
|
||||
InputSchema: map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]any{
|
||||
"input": map[string]any{
|
||||
"type": "string",
|
||||
"description": "Test input",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
mcpTool := NewMCPTool(manager, "test_server", tool)
|
||||
|
||||
if mcpTool == nil {
|
||||
t.Fatal("NewMCPTool should not return nil")
|
||||
}
|
||||
// Verify tool properties we can access
|
||||
if mcpTool.Name() != "mcp_test_server_test_tool" {
|
||||
t.Errorf("Expected tool name with prefix, got '%s'", mcpTool.Name())
|
||||
}
|
||||
}
|
||||
|
||||
// TestMCPTool_Name verifies tool name with server prefix
|
||||
func TestMCPTool_Name(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
serverName string
|
||||
toolName string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "simple name",
|
||||
serverName: "github",
|
||||
toolName: "create_issue",
|
||||
expected: "mcp_github_create_issue",
|
||||
},
|
||||
{
|
||||
name: "filesystem server",
|
||||
serverName: "filesystem",
|
||||
toolName: "read_file",
|
||||
expected: "mcp_filesystem_read_file",
|
||||
},
|
||||
{
|
||||
name: "remote server",
|
||||
serverName: "remote-api",
|
||||
toolName: "fetch_data",
|
||||
expected: "mcp_remote-api_fetch_data",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
manager := &MockMCPManager{}
|
||||
tool := &mcp.Tool{Name: tt.toolName}
|
||||
mcpTool := NewMCPTool(manager, tt.serverName, tool)
|
||||
|
||||
result := mcpTool.Name()
|
||||
if result != tt.expected {
|
||||
t.Errorf("Expected name '%s', got '%s'", tt.expected, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestMCPTool_Description verifies tool description generation
|
||||
func TestMCPTool_Description(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
serverName string
|
||||
toolDescription string
|
||||
expectContains []string
|
||||
}{
|
||||
{
|
||||
name: "with description",
|
||||
serverName: "github",
|
||||
toolDescription: "Create a GitHub issue",
|
||||
expectContains: []string{"[MCP:github]", "Create a GitHub issue"},
|
||||
},
|
||||
{
|
||||
name: "empty description",
|
||||
serverName: "filesystem",
|
||||
toolDescription: "",
|
||||
expectContains: []string{"[MCP:filesystem]", "MCP tool from filesystem server"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
manager := &MockMCPManager{}
|
||||
tool := &mcp.Tool{
|
||||
Name: "test_tool",
|
||||
Description: tt.toolDescription,
|
||||
}
|
||||
mcpTool := NewMCPTool(manager, tt.serverName, tool)
|
||||
|
||||
result := mcpTool.Description()
|
||||
|
||||
for _, expected := range tt.expectContains {
|
||||
if !strings.Contains(result, expected) {
|
||||
t.Errorf("Description should contain '%s', got: %s", expected, result)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestMCPTool_Parameters verifies parameter schema conversion
|
||||
func TestMCPTool_Parameters(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
inputSchema any
|
||||
expectType string
|
||||
checkProperty string
|
||||
expectProperty bool
|
||||
}{
|
||||
{
|
||||
name: "map schema",
|
||||
inputSchema: map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]any{
|
||||
"query": map[string]any{
|
||||
"type": "string",
|
||||
"description": "Search query",
|
||||
},
|
||||
},
|
||||
"required": []string{"query"},
|
||||
},
|
||||
expectType: "object",
|
||||
checkProperty: "query",
|
||||
expectProperty: true,
|
||||
},
|
||||
{
|
||||
name: "nil schema",
|
||||
inputSchema: nil,
|
||||
expectType: "object",
|
||||
expectProperty: false,
|
||||
},
|
||||
{
|
||||
name: "json.RawMessage schema",
|
||||
inputSchema: []byte(`{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"repo": {
|
||||
"type": "string",
|
||||
"description": "Repository name"
|
||||
},
|
||||
"stars": {
|
||||
"type": "integer",
|
||||
"description": "Minimum stars"
|
||||
}
|
||||
},
|
||||
"required": ["repo"]
|
||||
}`),
|
||||
expectType: "object",
|
||||
checkProperty: "repo",
|
||||
expectProperty: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
manager := &MockMCPManager{}
|
||||
tool := &mcp.Tool{
|
||||
Name: "test_tool",
|
||||
InputSchema: tt.inputSchema,
|
||||
}
|
||||
mcpTool := NewMCPTool(manager, "test_server", tool)
|
||||
|
||||
params := mcpTool.Parameters()
|
||||
|
||||
if params == nil {
|
||||
t.Fatal("Parameters should not be nil")
|
||||
}
|
||||
|
||||
if params["type"] != tt.expectType {
|
||||
t.Errorf("Expected type '%s', got '%v'", tt.expectType, params["type"])
|
||||
}
|
||||
|
||||
// Check if property exists when expected
|
||||
if tt.checkProperty != "" {
|
||||
properties, ok := params["properties"].(map[string]any)
|
||||
if !ok && tt.expectProperty {
|
||||
t.Errorf("Expected properties to be a map")
|
||||
return
|
||||
}
|
||||
if ok {
|
||||
_, hasProperty := properties[tt.checkProperty]
|
||||
if hasProperty != tt.expectProperty {
|
||||
t.Errorf("Expected property '%s' existence: %v, got: %v",
|
||||
tt.checkProperty, tt.expectProperty, hasProperty)
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestMCPTool_Execute_Success tests successful tool execution
|
||||
func TestMCPTool_Execute_Success(t *testing.T) {
|
||||
manager := &MockMCPManager{
|
||||
callToolFunc: func(ctx context.Context, serverName, toolName string, arguments map[string]any) (*mcp.CallToolResult, error) {
|
||||
// Verify correct parameters passed
|
||||
if serverName != "github" {
|
||||
t.Errorf("Expected serverName 'github', got '%s'", serverName)
|
||||
}
|
||||
if toolName != "search_repos" {
|
||||
t.Errorf("Expected toolName 'search_repos', got '%s'", toolName)
|
||||
}
|
||||
|
||||
return &mcp.CallToolResult{
|
||||
Content: []mcp.Content{
|
||||
&mcp.TextContent{Text: "Found 3 repositories"},
|
||||
},
|
||||
IsError: false,
|
||||
}, nil
|
||||
},
|
||||
}
|
||||
|
||||
tool := &mcp.Tool{
|
||||
Name: "search_repos",
|
||||
Description: "Search GitHub repositories",
|
||||
}
|
||||
mcpTool := NewMCPTool(manager, "github", tool)
|
||||
|
||||
ctx := context.Background()
|
||||
args := map[string]any{
|
||||
"query": "golang mcp",
|
||||
}
|
||||
|
||||
result := mcpTool.Execute(ctx, args)
|
||||
|
||||
if result == nil {
|
||||
t.Fatal("Result should not be nil")
|
||||
}
|
||||
if result.IsError {
|
||||
t.Errorf("Expected no error, got error: %s", result.ForLLM)
|
||||
}
|
||||
if result.ForLLM != "Found 3 repositories" {
|
||||
t.Errorf("Expected 'Found 3 repositories', got '%s'", result.ForLLM)
|
||||
}
|
||||
}
|
||||
|
||||
// TestMCPTool_Execute_ManagerError tests execution when manager returns error
|
||||
func TestMCPTool_Execute_ManagerError(t *testing.T) {
|
||||
manager := &MockMCPManager{
|
||||
callToolFunc: func(ctx context.Context, serverName, toolName string, arguments map[string]any) (*mcp.CallToolResult, error) {
|
||||
return nil, fmt.Errorf("connection failed")
|
||||
},
|
||||
}
|
||||
|
||||
tool := &mcp.Tool{Name: "test_tool"}
|
||||
mcpTool := NewMCPTool(manager, "test_server", tool)
|
||||
|
||||
ctx := context.Background()
|
||||
result := mcpTool.Execute(ctx, map[string]any{})
|
||||
|
||||
if result == nil {
|
||||
t.Fatal("Result should not be nil")
|
||||
}
|
||||
if !result.IsError {
|
||||
t.Error("Expected IsError to be true")
|
||||
}
|
||||
if !strings.Contains(result.ForLLM, "MCP tool execution failed") {
|
||||
t.Errorf("Error message should mention execution failure, got: %s", result.ForLLM)
|
||||
}
|
||||
if !strings.Contains(result.ForLLM, "connection failed") {
|
||||
t.Errorf("Error message should include original error, got: %s", result.ForLLM)
|
||||
}
|
||||
}
|
||||
|
||||
// TestMCPTool_Execute_ServerError tests execution when server returns error
|
||||
func TestMCPTool_Execute_ServerError(t *testing.T) {
|
||||
manager := &MockMCPManager{
|
||||
callToolFunc: func(ctx context.Context, serverName, toolName string, arguments map[string]any) (*mcp.CallToolResult, error) {
|
||||
return &mcp.CallToolResult{
|
||||
Content: []mcp.Content{
|
||||
&mcp.TextContent{Text: "Invalid API key"},
|
||||
},
|
||||
IsError: true,
|
||||
}, nil
|
||||
},
|
||||
}
|
||||
|
||||
tool := &mcp.Tool{Name: "test_tool"}
|
||||
mcpTool := NewMCPTool(manager, "test_server", tool)
|
||||
|
||||
ctx := context.Background()
|
||||
result := mcpTool.Execute(ctx, map[string]any{})
|
||||
|
||||
if result == nil {
|
||||
t.Fatal("Result should not be nil")
|
||||
}
|
||||
if !result.IsError {
|
||||
t.Error("Expected IsError to be true")
|
||||
}
|
||||
if !strings.Contains(result.ForLLM, "MCP tool returned error") {
|
||||
t.Errorf("Error message should mention server error, got: %s", result.ForLLM)
|
||||
}
|
||||
if !strings.Contains(result.ForLLM, "Invalid API key") {
|
||||
t.Errorf("Error message should include server message, got: %s", result.ForLLM)
|
||||
}
|
||||
}
|
||||
|
||||
// TestMCPTool_Execute_MultipleContent tests execution with multiple content items
|
||||
func TestMCPTool_Execute_MultipleContent(t *testing.T) {
|
||||
manager := &MockMCPManager{
|
||||
callToolFunc: func(ctx context.Context, serverName, toolName string, arguments map[string]any) (*mcp.CallToolResult, error) {
|
||||
return &mcp.CallToolResult{
|
||||
Content: []mcp.Content{
|
||||
&mcp.TextContent{Text: "First line"},
|
||||
&mcp.TextContent{Text: "Second line"},
|
||||
&mcp.TextContent{Text: "Third line"},
|
||||
},
|
||||
IsError: false,
|
||||
}, nil
|
||||
},
|
||||
}
|
||||
|
||||
tool := &mcp.Tool{Name: "multi_output"}
|
||||
mcpTool := NewMCPTool(manager, "test_server", tool)
|
||||
|
||||
ctx := context.Background()
|
||||
result := mcpTool.Execute(ctx, map[string]any{})
|
||||
|
||||
if result.IsError {
|
||||
t.Errorf("Expected no error, got: %s", result.ForLLM)
|
||||
}
|
||||
|
||||
expected := "First line\nSecond line\nThird line"
|
||||
if result.ForLLM != expected {
|
||||
t.Errorf("Expected '%s', got '%s'", expected, result.ForLLM)
|
||||
}
|
||||
}
|
||||
|
||||
// TestExtractContentText_TextContent tests text content extraction
|
||||
func TestExtractContentText_TextContent(t *testing.T) {
|
||||
content := []mcp.Content{
|
||||
&mcp.TextContent{Text: "Hello World"},
|
||||
&mcp.TextContent{Text: "Second message"},
|
||||
}
|
||||
|
||||
result := extractContentText(content)
|
||||
expected := "Hello World\nSecond message"
|
||||
|
||||
if result != expected {
|
||||
t.Errorf("Expected '%s', got '%s'", expected, result)
|
||||
}
|
||||
}
|
||||
|
||||
// TestExtractContentText_ImageContent tests image content extraction
|
||||
func TestExtractContentText_ImageContent(t *testing.T) {
|
||||
content := []mcp.Content{
|
||||
&mcp.ImageContent{
|
||||
Data: []byte("base64data"),
|
||||
MIMEType: "image/png",
|
||||
},
|
||||
}
|
||||
|
||||
result := extractContentText(content)
|
||||
|
||||
if !strings.Contains(result, "[Image:") {
|
||||
t.Errorf("Expected image indicator, got: %s", result)
|
||||
}
|
||||
if !strings.Contains(result, "image/png") {
|
||||
t.Errorf("Expected MIME type in output, got: %s", result)
|
||||
}
|
||||
}
|
||||
|
||||
// TestExtractContentText_MixedContent tests mixed content types
|
||||
func TestExtractContentText_MixedContent(t *testing.T) {
|
||||
content := []mcp.Content{
|
||||
&mcp.TextContent{Text: "Description"},
|
||||
&mcp.ImageContent{
|
||||
Data: []byte("data"),
|
||||
MIMEType: "image/jpeg",
|
||||
},
|
||||
&mcp.TextContent{Text: "More text"},
|
||||
}
|
||||
|
||||
result := extractContentText(content)
|
||||
|
||||
if !strings.Contains(result, "Description") {
|
||||
t.Errorf("Should contain text content, got: %s", result)
|
||||
}
|
||||
if !strings.Contains(result, "[Image:") {
|
||||
t.Errorf("Should contain image indicator, got: %s", result)
|
||||
}
|
||||
if !strings.Contains(result, "More text") {
|
||||
t.Errorf("Should contain second text, got: %s", result)
|
||||
}
|
||||
}
|
||||
|
||||
// TestExtractContentText_EmptyContent tests empty content array
|
||||
func TestExtractContentText_EmptyContent(t *testing.T) {
|
||||
content := []mcp.Content{}
|
||||
|
||||
result := extractContentText(content)
|
||||
|
||||
if result != "" {
|
||||
t.Errorf("Expected empty string for empty content, got: %s", result)
|
||||
}
|
||||
}
|
||||
|
||||
// TestMCPTool_InterfaceCompliance verifies MCPTool implements Tool interface
|
||||
func TestMCPTool_InterfaceCompliance(t *testing.T) {
|
||||
manager := &MockMCPManager{}
|
||||
tool := &mcp.Tool{Name: "test"}
|
||||
mcpTool := NewMCPTool(manager, "test_server", tool)
|
||||
|
||||
// Verify it implements Tool interface
|
||||
var _ Tool = mcpTool
|
||||
}
|
||||
|
||||
// TestMCPTool_Parameters_MapSchema tests schema that's already a map
|
||||
func TestMCPTool_Parameters_MapSchema(t *testing.T) {
|
||||
manager := &MockMCPManager{}
|
||||
schema := map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]any{
|
||||
"name": map[string]any{
|
||||
"type": "string",
|
||||
"description": "The name parameter",
|
||||
},
|
||||
},
|
||||
"required": []string{"name"},
|
||||
}
|
||||
|
||||
tool := &mcp.Tool{
|
||||
Name: "test_tool",
|
||||
InputSchema: schema,
|
||||
}
|
||||
mcpTool := NewMCPTool(manager, "test_server", tool)
|
||||
|
||||
params := mcpTool.Parameters()
|
||||
|
||||
// Should return the schema as-is when it's already a map
|
||||
if params["type"] != "object" {
|
||||
t.Errorf("Expected type 'object', got '%v'", params["type"])
|
||||
}
|
||||
|
||||
props, ok := params["properties"].(map[string]any)
|
||||
if !ok {
|
||||
t.Error("Properties should be a map")
|
||||
}
|
||||
|
||||
nameParam, ok := props["name"].(map[string]any)
|
||||
if !ok {
|
||||
t.Error("Name parameter should exist")
|
||||
}
|
||||
|
||||
if nameParam["type"] != "string" {
|
||||
t.Errorf("Name type should be 'string', got '%v'", nameParam["type"])
|
||||
}
|
||||
}
|
||||
@@ -25,7 +25,12 @@ func NewToolRegistry() *ToolRegistry {
|
||||
func (r *ToolRegistry) Register(tool Tool) {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
r.tools[tool.Name()] = tool
|
||||
name := tool.Name()
|
||||
if _, exists := r.tools[name]; exists {
|
||||
logger.WarnCF("tools", "Tool registration overwrites existing tool",
|
||||
map[string]any{"name": name})
|
||||
}
|
||||
r.tools[name] = tool
|
||||
}
|
||||
|
||||
func (r *ToolRegistry) Get(name string) (Tool, bool) {
|
||||
|
||||
@@ -329,7 +329,7 @@ func TestToolRegistry_ConcurrentAccess(t *testing.T) {
|
||||
r := NewToolRegistry()
|
||||
var wg sync.WaitGroup
|
||||
|
||||
for i := 0; i < 50; i++ {
|
||||
for i := range 50 {
|
||||
wg.Add(1)
|
||||
go func(n int) {
|
||||
defer wg.Done()
|
||||
|
||||
+96
-50
@@ -21,53 +21,77 @@ type ExecTool struct {
|
||||
timeout time.Duration
|
||||
denyPatterns []*regexp.Regexp
|
||||
allowPatterns []*regexp.Regexp
|
||||
customAllowPatterns []*regexp.Regexp
|
||||
restrictToWorkspace bool
|
||||
}
|
||||
|
||||
var defaultDenyPatterns = []*regexp.Regexp{
|
||||
regexp.MustCompile(`\brm\s+-[rf]{1,2}\b`),
|
||||
regexp.MustCompile(`\bdel\s+/[fq]\b`),
|
||||
regexp.MustCompile(`\brmdir\s+/s\b`),
|
||||
regexp.MustCompile(`(?:^|[;&|]\s*|\s+)(format|mkfs|diskpart)\s`), // Match disk wiping commands, avoid matching --format flags
|
||||
regexp.MustCompile(`\bdd\s+if=`),
|
||||
regexp.MustCompile(`>\s*/dev/sd[a-z]\b`), // Block writes to disk devices (but allow /dev/null)
|
||||
regexp.MustCompile(`\b(shutdown|reboot|poweroff)\b`),
|
||||
regexp.MustCompile(`:\(\)\s*\{.*\};\s*:`),
|
||||
regexp.MustCompile(`\$\([^)]+\)`),
|
||||
regexp.MustCompile(`\$\{[^}]+\}`),
|
||||
regexp.MustCompile("`[^`]+`"),
|
||||
regexp.MustCompile(`\|\s*sh\b`),
|
||||
regexp.MustCompile(`\|\s*bash\b`),
|
||||
regexp.MustCompile(`;\s*rm\s+-[rf]`),
|
||||
regexp.MustCompile(`&&\s*rm\s+-[rf]`),
|
||||
regexp.MustCompile(`\|\|\s*rm\s+-[rf]`),
|
||||
regexp.MustCompile(`>\s*/dev/null\s*>&?\s*\d?`),
|
||||
regexp.MustCompile(`<<\s*EOF`),
|
||||
regexp.MustCompile(`\$\(\s*cat\s+`),
|
||||
regexp.MustCompile(`\$\(\s*curl\s+`),
|
||||
regexp.MustCompile(`\$\(\s*wget\s+`),
|
||||
regexp.MustCompile(`\$\(\s*which\s+`),
|
||||
regexp.MustCompile(`\bsudo\b`),
|
||||
regexp.MustCompile(`\bchmod\s+[0-7]{3,4}\b`),
|
||||
regexp.MustCompile(`\bchown\b`),
|
||||
regexp.MustCompile(`\bpkill\b`),
|
||||
regexp.MustCompile(`\bkillall\b`),
|
||||
regexp.MustCompile(`\bkill\s+-[9]\b`),
|
||||
regexp.MustCompile(`\bcurl\b.*\|\s*(sh|bash)`),
|
||||
regexp.MustCompile(`\bwget\b.*\|\s*(sh|bash)`),
|
||||
regexp.MustCompile(`\bnpm\s+install\s+-g\b`),
|
||||
regexp.MustCompile(`\bpip\s+install\s+--user\b`),
|
||||
regexp.MustCompile(`\bapt\s+(install|remove|purge)\b`),
|
||||
regexp.MustCompile(`\byum\s+(install|remove)\b`),
|
||||
regexp.MustCompile(`\bdnf\s+(install|remove)\b`),
|
||||
regexp.MustCompile(`\bdocker\s+run\b`),
|
||||
regexp.MustCompile(`\bdocker\s+exec\b`),
|
||||
regexp.MustCompile(`\bgit\s+push\b`),
|
||||
regexp.MustCompile(`\bgit\s+force\b`),
|
||||
regexp.MustCompile(`\bssh\b.*@`),
|
||||
regexp.MustCompile(`\beval\b`),
|
||||
regexp.MustCompile(`\bsource\s+.*\.sh\b`),
|
||||
}
|
||||
var (
|
||||
defaultDenyPatterns = []*regexp.Regexp{
|
||||
regexp.MustCompile(`\brm\s+-[rf]{1,2}\b`),
|
||||
regexp.MustCompile(`\bdel\s+/[fq]\b`),
|
||||
regexp.MustCompile(`\brmdir\s+/s\b`),
|
||||
// Match disk wiping commands, avoid matching --format flags
|
||||
regexp.MustCompile(
|
||||
`(?:^|[;&|]\s*|\s+)(format|mkfs|diskpart)\s`,
|
||||
),
|
||||
regexp.MustCompile(`\bdd\s+if=`),
|
||||
// Block writes to block devices (all common naming schemes).
|
||||
regexp.MustCompile(
|
||||
`>\s*/dev/(sd[a-z]|hd[a-z]|vd[a-z]|xvd[a-z]|nvme\d|mmcblk\d|loop\d|dm-\d|md\d|sr\d|nbd\d)`,
|
||||
),
|
||||
regexp.MustCompile(`\b(shutdown|reboot|poweroff)\b`),
|
||||
regexp.MustCompile(`:\(\)\s*\{.*\};\s*:`),
|
||||
regexp.MustCompile(`\$\([^)]+\)`),
|
||||
regexp.MustCompile(`\$\{[^}]+\}`),
|
||||
regexp.MustCompile("`[^`]+`"),
|
||||
regexp.MustCompile(`\|\s*sh\b`),
|
||||
regexp.MustCompile(`\|\s*bash\b`),
|
||||
regexp.MustCompile(`;\s*rm\s+-[rf]`),
|
||||
regexp.MustCompile(`&&\s*rm\s+-[rf]`),
|
||||
regexp.MustCompile(`\|\|\s*rm\s+-[rf]`),
|
||||
regexp.MustCompile(`<<\s*EOF`),
|
||||
regexp.MustCompile(`\$\(\s*cat\s+`),
|
||||
regexp.MustCompile(`\$\(\s*curl\s+`),
|
||||
regexp.MustCompile(`\$\(\s*wget\s+`),
|
||||
regexp.MustCompile(`\$\(\s*which\s+`),
|
||||
regexp.MustCompile(`\bsudo\b`),
|
||||
regexp.MustCompile(`\bchmod\s+[0-7]{3,4}\b`),
|
||||
regexp.MustCompile(`\bchown\b`),
|
||||
regexp.MustCompile(`\bpkill\b`),
|
||||
regexp.MustCompile(`\bkillall\b`),
|
||||
regexp.MustCompile(`\bkill\s+-[9]\b`),
|
||||
regexp.MustCompile(`\bcurl\b.*\|\s*(sh|bash)`),
|
||||
regexp.MustCompile(`\bwget\b.*\|\s*(sh|bash)`),
|
||||
regexp.MustCompile(`\bnpm\s+install\s+-g\b`),
|
||||
regexp.MustCompile(`\bpip\s+install\s+--user\b`),
|
||||
regexp.MustCompile(`\bapt\s+(install|remove|purge)\b`),
|
||||
regexp.MustCompile(`\byum\s+(install|remove)\b`),
|
||||
regexp.MustCompile(`\bdnf\s+(install|remove)\b`),
|
||||
regexp.MustCompile(`\bdocker\s+run\b`),
|
||||
regexp.MustCompile(`\bdocker\s+exec\b`),
|
||||
regexp.MustCompile(`\bgit\s+push\b`),
|
||||
regexp.MustCompile(`\bgit\s+force\b`),
|
||||
regexp.MustCompile(`\bssh\b.*@`),
|
||||
regexp.MustCompile(`\beval\b`),
|
||||
regexp.MustCompile(`\bsource\s+.*\.sh\b`),
|
||||
}
|
||||
|
||||
// absolutePathPattern matches absolute file paths in commands (Unix and Windows).
|
||||
absolutePathPattern = regexp.MustCompile(`[A-Za-z]:\\[^\\\"']+|/[^\s\"']+`)
|
||||
|
||||
// safePaths are kernel pseudo-devices that are always safe to reference in
|
||||
// commands, regardless of workspace restriction. They contain no user data
|
||||
// and cannot cause destructive writes.
|
||||
safePaths = map[string]bool{
|
||||
"/dev/null": true,
|
||||
"/dev/zero": true,
|
||||
"/dev/random": true,
|
||||
"/dev/urandom": true,
|
||||
"/dev/stdin": true,
|
||||
"/dev/stdout": true,
|
||||
"/dev/stderr": true,
|
||||
}
|
||||
)
|
||||
|
||||
func NewExecTool(workingDir string, restrict bool) (*ExecTool, error) {
|
||||
return NewExecToolWithConfig(workingDir, restrict, nil)
|
||||
@@ -75,6 +99,7 @@ func NewExecTool(workingDir string, restrict bool) (*ExecTool, error) {
|
||||
|
||||
func NewExecToolWithConfig(workingDir string, restrict bool, config *config.Config) (*ExecTool, error) {
|
||||
denyPatterns := make([]*regexp.Regexp, 0)
|
||||
customAllowPatterns := make([]*regexp.Regexp, 0)
|
||||
|
||||
if config != nil {
|
||||
execConfig := config.Tools.Exec
|
||||
@@ -95,6 +120,13 @@ func NewExecToolWithConfig(workingDir string, restrict bool, config *config.Conf
|
||||
// If deny patterns are disabled, we won't add any patterns, allowing all commands.
|
||||
fmt.Println("Warning: deny patterns are disabled. All commands will be allowed.")
|
||||
}
|
||||
for _, pattern := range execConfig.CustomAllowPatterns {
|
||||
re, err := regexp.Compile(pattern)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid custom allow pattern %q: %w", pattern, err)
|
||||
}
|
||||
customAllowPatterns = append(customAllowPatterns, re)
|
||||
}
|
||||
} else {
|
||||
denyPatterns = append(denyPatterns, defaultDenyPatterns...)
|
||||
}
|
||||
@@ -104,6 +136,7 @@ func NewExecToolWithConfig(workingDir string, restrict bool, config *config.Conf
|
||||
timeout: 60 * time.Second,
|
||||
denyPatterns: denyPatterns,
|
||||
allowPatterns: nil,
|
||||
customAllowPatterns: customAllowPatterns,
|
||||
restrictToWorkspace: restrict,
|
||||
}, nil
|
||||
}
|
||||
@@ -258,9 +291,20 @@ func (t *ExecTool) guardCommand(command, cwd string) string {
|
||||
cmd := strings.TrimSpace(command)
|
||||
lower := strings.ToLower(cmd)
|
||||
|
||||
for _, pattern := range t.denyPatterns {
|
||||
// Custom allow patterns exempt a command from deny checks.
|
||||
explicitlyAllowed := false
|
||||
for _, pattern := range t.customAllowPatterns {
|
||||
if pattern.MatchString(lower) {
|
||||
return "Command blocked by safety guard (dangerous pattern detected)"
|
||||
explicitlyAllowed = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if !explicitlyAllowed {
|
||||
for _, pattern := range t.denyPatterns {
|
||||
if pattern.MatchString(lower) {
|
||||
return "Command blocked by safety guard (dangerous pattern detected)"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -287,16 +331,18 @@ func (t *ExecTool) guardCommand(command, cwd string) string {
|
||||
return ""
|
||||
}
|
||||
|
||||
pathPattern := regexp.MustCompile(`(?:^|\s|=)([A-Za-z]:\\[^\\"']+|/[a-zA-Z.][^\s"']*)`)
|
||||
matches := pathPattern.FindAllStringSubmatch(cmd, -1)
|
||||
matches := absolutePathPattern.FindAllString(cmd, -1)
|
||||
|
||||
for _, match := range matches {
|
||||
raw := match[1]
|
||||
for _, raw := range matches {
|
||||
p, err := filepath.Abs(raw)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
if safePaths[p] {
|
||||
continue
|
||||
}
|
||||
|
||||
rel, err := filepath.Rel(cwdPath, p)
|
||||
if err != nil {
|
||||
continue
|
||||
|
||||
+106
-25
@@ -7,6 +7,8 @@ import (
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
)
|
||||
|
||||
// TestShellTool_Success verifies successful command execution
|
||||
@@ -310,6 +312,60 @@ func TestShellTool_RestrictToWorkspace(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// TestShellTool_DevNullAllowed verifies that /dev/null redirections are not blocked (issue #964).
|
||||
func TestShellTool_DevNullAllowed(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
tool, err := NewExecTool(tmpDir, true)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to configure exec tool: %s", err)
|
||||
}
|
||||
|
||||
commands := []string{
|
||||
"echo hello 2>/dev/null",
|
||||
"echo hello >/dev/null",
|
||||
"echo hello > /dev/null",
|
||||
"echo hello 2> /dev/null",
|
||||
"echo hello >/dev/null 2>&1",
|
||||
"find " + tmpDir + " -name '*.go' 2>/dev/null",
|
||||
}
|
||||
|
||||
for _, cmd := range commands {
|
||||
result := tool.Execute(context.Background(), map[string]any{"command": cmd})
|
||||
if result.IsError && strings.Contains(result.ForLLM, "blocked") {
|
||||
t.Errorf("command should not be blocked: %s\n error: %s", cmd, result.ForLLM)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestShellTool_BlockDevices verifies that writes to block devices are blocked (issue #965).
|
||||
func TestShellTool_BlockDevices(t *testing.T) {
|
||||
tool, err := NewExecTool("", false)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to configure exec tool: %s", err)
|
||||
}
|
||||
|
||||
blocked := []string{
|
||||
"echo x > /dev/sda",
|
||||
"echo x > /dev/hda",
|
||||
"echo x > /dev/vda",
|
||||
"echo x > /dev/xvda",
|
||||
"echo x > /dev/nvme0n1",
|
||||
"echo x > /dev/mmcblk0",
|
||||
"echo x > /dev/loop0",
|
||||
"echo x > /dev/dm-0",
|
||||
"echo x > /dev/md0",
|
||||
"echo x > /dev/sr0",
|
||||
"echo x > /dev/nbd0",
|
||||
}
|
||||
|
||||
for _, cmd := range blocked {
|
||||
result := tool.Execute(context.Background(), map[string]any{"command": cmd})
|
||||
if !result.IsError {
|
||||
t.Errorf("expected block device write to be blocked: %s", cmd)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestShellTool_DenyPattern_DiskWiping verifies the deny pattern for disk wiping
|
||||
// commands (format, mkfs, diskpart) blocks them when preceded by shell separators
|
||||
// but does NOT block legitimate uses like --format flags.
|
||||
@@ -322,7 +378,7 @@ func TestShellTool_DenyPattern_DiskWiping(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
// These should be BLOCKED (disk wiping commands)
|
||||
blocked := []struct {
|
||||
blockedCmds := []struct {
|
||||
name string
|
||||
cmd string
|
||||
}{
|
||||
@@ -334,7 +390,7 @@ func TestShellTool_DenyPattern_DiskWiping(t *testing.T) {
|
||||
{"diskpart standalone", "diskpart /s script.txt"},
|
||||
}
|
||||
|
||||
for _, tt := range blocked {
|
||||
for _, tt := range blockedCmds {
|
||||
t.Run("blocked_"+tt.name, func(t *testing.T) {
|
||||
result := tool.Execute(ctx, map[string]any{"command": tt.cmd})
|
||||
if !result.IsError {
|
||||
@@ -362,35 +418,60 @@ func TestShellTool_DenyPattern_DiskWiping(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// TestShellTool_RestrictToWorkspace_HiddenDirs verifies that hidden directory
|
||||
// paths (starting with .) are properly detected by the workspace guard.
|
||||
func TestShellTool_RestrictToWorkspace_HiddenDirs(t *testing.T) {
|
||||
// TestShellTool_SafePathsInWorkspaceRestriction verifies that safe kernel pseudo-devices
|
||||
// are allowed even when workspace restriction is active.
|
||||
func TestShellTool_SafePathsInWorkspaceRestriction(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
tool, err := NewExecTool(tmpDir, false)
|
||||
tool, err := NewExecTool(tmpDir, true)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to configure exec tool: %s", err)
|
||||
}
|
||||
tool.SetRestrictToWorkspace(true)
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Reading a hidden dir outside workspace should be blocked
|
||||
result := tool.Execute(ctx, map[string]any{
|
||||
"command": "cat /.ssh/config",
|
||||
})
|
||||
if !result.IsError {
|
||||
t.Errorf("Expected /.ssh/config to be blocked with restrictToWorkspace=true")
|
||||
// These reference paths outside workspace but should be allowed via safePaths.
|
||||
commands := []string{
|
||||
"cat /dev/urandom | head -c 16 | od",
|
||||
"echo test > /dev/null",
|
||||
"dd if=/dev/zero bs=1 count=1",
|
||||
}
|
||||
|
||||
// Flag-attached paths outside workspace should be blocked
|
||||
result2 := tool.Execute(ctx, map[string]any{
|
||||
"command": "grep --include=/etc/passwd pattern",
|
||||
})
|
||||
if !result2.IsError {
|
||||
// This tests the = delimiter fix; --include=/etc/passwd uses = in real
|
||||
// usage but --include /etc/passwd uses space. Both patterns should catch it.
|
||||
// If this specific form isn't blocked, it's acceptable since the primary
|
||||
// concern is the = form (--file=/etc/passwd).
|
||||
_ = result2 // acceptable either way for this pattern variant
|
||||
for _, cmd := range commands {
|
||||
result := tool.Execute(context.Background(), map[string]any{"command": cmd})
|
||||
if result.IsError && strings.Contains(result.ForLLM, "path outside working dir") {
|
||||
t.Errorf("safe path should not be blocked by workspace check: %s\n error: %s", cmd, result.ForLLM)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestShellTool_CustomAllowPatterns verifies that custom allow patterns exempt
|
||||
// commands from deny pattern checks.
|
||||
func TestShellTool_CustomAllowPatterns(t *testing.T) {
|
||||
cfg := &config.Config{
|
||||
Tools: config.ToolsConfig{
|
||||
Exec: config.ExecConfig{
|
||||
EnableDenyPatterns: true,
|
||||
CustomAllowPatterns: []string{`\bgit\s+push\s+origin\b`},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
tool, err := NewExecToolWithConfig("", false, cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to configure exec tool: %s", err)
|
||||
}
|
||||
|
||||
// "git push origin main" should be allowed by custom allow pattern.
|
||||
result := tool.Execute(context.Background(), map[string]any{
|
||||
"command": "git push origin main",
|
||||
})
|
||||
if result.IsError && strings.Contains(result.ForLLM, "blocked") {
|
||||
t.Errorf("custom allow pattern should exempt 'git push origin main', got: %s", result.ForLLM)
|
||||
}
|
||||
|
||||
// "git push upstream main" should still be blocked (does not match allow pattern).
|
||||
result = tool.Execute(context.Background(), map[string]any{
|
||||
"command": "git push upstream main",
|
||||
})
|
||||
if !result.IsError {
|
||||
t.Errorf("'git push upstream main' should still be blocked by deny pattern")
|
||||
}
|
||||
}
|
||||
|
||||
+86
-61
@@ -4,6 +4,7 @@ import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
@@ -15,6 +16,14 @@ import (
|
||||
|
||||
const (
|
||||
userAgent = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36"
|
||||
|
||||
// HTTP client timeouts for web tool providers.
|
||||
searchTimeout = 10 * time.Second // Brave, Tavily, DuckDuckGo
|
||||
perplexityTimeout = 30 * time.Second // Perplexity (LLM-based, slower)
|
||||
fetchTimeout = 60 * time.Second // WebFetchTool
|
||||
|
||||
defaultMaxChars = 50000
|
||||
maxRedirects = 5
|
||||
)
|
||||
|
||||
// Pre-compiled regexes for HTML text extraction
|
||||
@@ -74,6 +83,7 @@ type SearchProvider interface {
|
||||
type BraveSearchProvider struct {
|
||||
apiKey string
|
||||
proxy string
|
||||
client *http.Client
|
||||
}
|
||||
|
||||
func (p *BraveSearchProvider) Search(ctx context.Context, query string, count int) (string, error) {
|
||||
@@ -88,11 +98,7 @@ func (p *BraveSearchProvider) Search(ctx context.Context, query string, count in
|
||||
req.Header.Set("Accept", "application/json")
|
||||
req.Header.Set("X-Subscription-Token", p.apiKey)
|
||||
|
||||
client, err := createHTTPClient(p.proxy, 10*time.Second)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to create HTTP client: %w", err)
|
||||
}
|
||||
resp, err := client.Do(req)
|
||||
resp, err := p.client.Do(req)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("request failed: %w", err)
|
||||
}
|
||||
@@ -103,6 +109,10 @@ func (p *BraveSearchProvider) Search(ctx context.Context, query string, count in
|
||||
return "", fmt.Errorf("failed to read response: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return "", fmt.Errorf("brave api error (status %d): %s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
var searchResp struct {
|
||||
Web struct {
|
||||
Results []struct {
|
||||
@@ -143,6 +153,7 @@ type TavilySearchProvider struct {
|
||||
apiKey string
|
||||
baseURL string
|
||||
proxy string
|
||||
client *http.Client
|
||||
}
|
||||
|
||||
func (p *TavilySearchProvider) Search(ctx context.Context, query string, count int) (string, error) {
|
||||
@@ -174,11 +185,7 @@ func (p *TavilySearchProvider) Search(ctx context.Context, query string, count i
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("User-Agent", userAgent)
|
||||
|
||||
client, err := createHTTPClient(p.proxy, 10*time.Second)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to create HTTP client: %w", err)
|
||||
}
|
||||
resp, err := client.Do(req)
|
||||
resp, err := p.client.Do(req)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("request failed: %w", err)
|
||||
}
|
||||
@@ -226,7 +233,8 @@ func (p *TavilySearchProvider) Search(ctx context.Context, query string, count i
|
||||
}
|
||||
|
||||
type DuckDuckGoSearchProvider struct {
|
||||
proxy string
|
||||
proxy string
|
||||
client *http.Client
|
||||
}
|
||||
|
||||
func (p *DuckDuckGoSearchProvider) Search(ctx context.Context, query string, count int) (string, error) {
|
||||
@@ -239,11 +247,7 @@ func (p *DuckDuckGoSearchProvider) Search(ctx context.Context, query string, cou
|
||||
|
||||
req.Header.Set("User-Agent", userAgent)
|
||||
|
||||
client, err := createHTTPClient(p.proxy, 10*time.Second)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to create HTTP client: %w", err)
|
||||
}
|
||||
resp, err := client.Do(req)
|
||||
resp, err := p.client.Do(req)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("request failed: %w", err)
|
||||
}
|
||||
@@ -285,7 +289,7 @@ func (p *DuckDuckGoSearchProvider) extractResults(html string, count int, query
|
||||
|
||||
maxItems := min(len(matches), count)
|
||||
|
||||
for i := 0; i < maxItems; i++ {
|
||||
for i := range maxItems {
|
||||
urlStr := matches[i][1]
|
||||
title := stripTags(matches[i][2])
|
||||
title = strings.TrimSpace(title)
|
||||
@@ -293,9 +297,9 @@ func (p *DuckDuckGoSearchProvider) extractResults(html string, count int, query
|
||||
// URL decoding if needed
|
||||
if strings.Contains(urlStr, "uddg=") {
|
||||
if u, err := url.QueryUnescape(urlStr); err == nil {
|
||||
idx := strings.Index(u, "uddg=")
|
||||
if idx != -1 {
|
||||
urlStr = u[idx+5:]
|
||||
_, after, ok := strings.Cut(u, "uddg=")
|
||||
if ok {
|
||||
urlStr = after
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -322,6 +326,7 @@ func stripTags(content string) string {
|
||||
type PerplexitySearchProvider struct {
|
||||
apiKey string
|
||||
proxy string
|
||||
client *http.Client
|
||||
}
|
||||
|
||||
func (p *PerplexitySearchProvider) Search(ctx context.Context, query string, count int) (string, error) {
|
||||
@@ -356,11 +361,7 @@ func (p *PerplexitySearchProvider) Search(ctx context.Context, query string, cou
|
||||
req.Header.Set("Authorization", "Bearer "+p.apiKey)
|
||||
req.Header.Set("User-Agent", userAgent)
|
||||
|
||||
client, err := createHTTPClient(p.proxy, 30*time.Second)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to create HTTP client: %w", err)
|
||||
}
|
||||
resp, err := client.Do(req)
|
||||
resp, err := p.client.Do(req)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("request failed: %w", err)
|
||||
}
|
||||
@@ -415,43 +416,60 @@ type WebSearchToolOptions struct {
|
||||
Proxy string
|
||||
}
|
||||
|
||||
func NewWebSearchTool(opts WebSearchToolOptions) *WebSearchTool {
|
||||
func NewWebSearchTool(opts WebSearchToolOptions) (*WebSearchTool, error) {
|
||||
var provider SearchProvider
|
||||
maxResults := 5
|
||||
|
||||
// Priority: Perplexity > Brave > Tavily > DuckDuckGo
|
||||
if opts.PerplexityEnabled && opts.PerplexityAPIKey != "" {
|
||||
provider = &PerplexitySearchProvider{apiKey: opts.PerplexityAPIKey, proxy: opts.Proxy}
|
||||
client, err := createHTTPClient(opts.Proxy, perplexityTimeout)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create HTTP client for Perplexity: %w", err)
|
||||
}
|
||||
provider = &PerplexitySearchProvider{apiKey: opts.PerplexityAPIKey, proxy: opts.Proxy, client: client}
|
||||
if opts.PerplexityMaxResults > 0 {
|
||||
maxResults = opts.PerplexityMaxResults
|
||||
}
|
||||
} else if opts.BraveEnabled && opts.BraveAPIKey != "" {
|
||||
provider = &BraveSearchProvider{apiKey: opts.BraveAPIKey, proxy: opts.Proxy}
|
||||
client, err := createHTTPClient(opts.Proxy, searchTimeout)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create HTTP client for Brave: %w", err)
|
||||
}
|
||||
provider = &BraveSearchProvider{apiKey: opts.BraveAPIKey, proxy: opts.Proxy, client: client}
|
||||
if opts.BraveMaxResults > 0 {
|
||||
maxResults = opts.BraveMaxResults
|
||||
}
|
||||
} else if opts.TavilyEnabled && opts.TavilyAPIKey != "" {
|
||||
client, err := createHTTPClient(opts.Proxy, searchTimeout)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create HTTP client for Tavily: %w", err)
|
||||
}
|
||||
provider = &TavilySearchProvider{
|
||||
apiKey: opts.TavilyAPIKey,
|
||||
baseURL: opts.TavilyBaseURL,
|
||||
proxy: opts.Proxy,
|
||||
client: client,
|
||||
}
|
||||
if opts.TavilyMaxResults > 0 {
|
||||
maxResults = opts.TavilyMaxResults
|
||||
}
|
||||
} else if opts.DuckDuckGoEnabled {
|
||||
provider = &DuckDuckGoSearchProvider{proxy: opts.Proxy}
|
||||
client, err := createHTTPClient(opts.Proxy, searchTimeout)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create HTTP client for DuckDuckGo: %w", err)
|
||||
}
|
||||
provider = &DuckDuckGoSearchProvider{proxy: opts.Proxy, client: client}
|
||||
if opts.DuckDuckGoMaxResults > 0 {
|
||||
maxResults = opts.DuckDuckGoMaxResults
|
||||
}
|
||||
} else {
|
||||
return nil
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
return &WebSearchTool{
|
||||
provider: provider,
|
||||
maxResults: maxResults,
|
||||
}
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (t *WebSearchTool) Name() string {
|
||||
@@ -506,27 +524,40 @@ func (t *WebSearchTool) Execute(ctx context.Context, args map[string]any) *ToolR
|
||||
}
|
||||
|
||||
type WebFetchTool struct {
|
||||
maxChars int
|
||||
proxy string
|
||||
maxChars int
|
||||
proxy string
|
||||
client *http.Client
|
||||
fetchLimitBytes int64
|
||||
}
|
||||
|
||||
func NewWebFetchTool(maxChars int) *WebFetchTool {
|
||||
if maxChars <= 0 {
|
||||
maxChars = 50000
|
||||
}
|
||||
return &WebFetchTool{
|
||||
maxChars: maxChars,
|
||||
}
|
||||
func NewWebFetchTool(maxChars int, fetchLimitBytes int64) (*WebFetchTool, error) {
|
||||
// createHTTPClient cannot fail with an empty proxy string.
|
||||
return NewWebFetchToolWithProxy(maxChars, "", fetchLimitBytes)
|
||||
}
|
||||
|
||||
func NewWebFetchToolWithProxy(maxChars int, proxy string) *WebFetchTool {
|
||||
func NewWebFetchToolWithProxy(maxChars int, proxy string, fetchLimitBytes int64) (*WebFetchTool, error) {
|
||||
if maxChars <= 0 {
|
||||
maxChars = 50000
|
||||
maxChars = defaultMaxChars
|
||||
}
|
||||
client, err := createHTTPClient(proxy, fetchTimeout)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create HTTP client for web fetch: %w", err)
|
||||
}
|
||||
client.CheckRedirect = func(req *http.Request, via []*http.Request) error {
|
||||
if len(via) >= maxRedirects {
|
||||
return fmt.Errorf("stopped after %d redirects", maxRedirects)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
if fetchLimitBytes <= 0 {
|
||||
fetchLimitBytes = 10 * 1024 * 1024 // Security Fallback
|
||||
}
|
||||
return &WebFetchTool{
|
||||
maxChars: maxChars,
|
||||
proxy: proxy,
|
||||
}
|
||||
maxChars: maxChars,
|
||||
proxy: proxy,
|
||||
client: client,
|
||||
fetchLimitBytes: fetchLimitBytes,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (t *WebFetchTool) Name() string {
|
||||
@@ -588,27 +619,21 @@ func (t *WebFetchTool) Execute(ctx context.Context, args map[string]any) *ToolRe
|
||||
|
||||
req.Header.Set("User-Agent", userAgent)
|
||||
|
||||
client, err := createHTTPClient(t.proxy, 60*time.Second)
|
||||
if err != nil {
|
||||
return ErrorResult(fmt.Sprintf("failed to create HTTP client: %v", err))
|
||||
}
|
||||
|
||||
// Configure redirect handling
|
||||
client.CheckRedirect = func(req *http.Request, via []*http.Request) error {
|
||||
if len(via) >= 5 {
|
||||
return fmt.Errorf("stopped after 5 redirects")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
resp, err := client.Do(req)
|
||||
resp, err := t.client.Do(req)
|
||||
if err != nil {
|
||||
return ErrorResult(fmt.Sprintf("request failed: %v", err))
|
||||
}
|
||||
|
||||
resp.Body = http.MaxBytesReader(nil, resp.Body, t.fetchLimitBytes)
|
||||
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
var maxBytesErr *http.MaxBytesError
|
||||
if errors.As(err, &maxBytesErr) {
|
||||
return ErrorResult(fmt.Sprintf("failed to read response: size exceeded %d bytes limit", t.fetchLimitBytes))
|
||||
}
|
||||
return ErrorResult(fmt.Sprintf("failed to read response: %v", err))
|
||||
}
|
||||
|
||||
@@ -652,14 +677,14 @@ func (t *WebFetchTool) Execute(ctx context.Context, args map[string]any) *ToolRe
|
||||
resultJSON, _ := json.MarshalIndent(result, "", " ")
|
||||
|
||||
return &ToolResult{
|
||||
ForLLM: fmt.Sprintf(
|
||||
ForLLM: string(resultJSON),
|
||||
ForUser: fmt.Sprintf(
|
||||
"Fetched %d bytes from %s (extractor: %s, truncated: %v)",
|
||||
len(text),
|
||||
urlStr,
|
||||
extractor,
|
||||
truncated,
|
||||
),
|
||||
ForUser: string(resultJSON),
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
+144
-35
@@ -1,15 +1,21 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/logger"
|
||||
)
|
||||
|
||||
const testFetchLimit = int64(10 * 1024 * 1024)
|
||||
|
||||
// TestWebTool_WebFetch_Success verifies successful URL fetching
|
||||
func TestWebTool_WebFetch_Success(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
@@ -19,7 +25,11 @@ func TestWebTool_WebFetch_Success(t *testing.T) {
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
tool := NewWebFetchTool(50000)
|
||||
tool, err := NewWebFetchTool(50000, testFetchLimit)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create web fetch tool: %v", err)
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
args := map[string]any{
|
||||
"url": server.URL,
|
||||
@@ -32,14 +42,14 @@ func TestWebTool_WebFetch_Success(t *testing.T) {
|
||||
t.Errorf("Expected success, got IsError=true: %s", result.ForLLM)
|
||||
}
|
||||
|
||||
// ForUser should contain the fetched content
|
||||
if !strings.Contains(result.ForUser, "Test Page") {
|
||||
t.Errorf("Expected ForUser to contain 'Test Page', got: %s", result.ForUser)
|
||||
// ForLLM should contain the fetched content (full JSON result)
|
||||
if !strings.Contains(result.ForLLM, "Test Page") {
|
||||
t.Errorf("Expected ForLLM to contain 'Test Page', got: %s", result.ForLLM)
|
||||
}
|
||||
|
||||
// ForLLM should contain summary
|
||||
if !strings.Contains(result.ForLLM, "bytes") && !strings.Contains(result.ForLLM, "extractor") {
|
||||
t.Errorf("Expected ForLLM to contain summary, got: %s", result.ForLLM)
|
||||
// ForUser should contain summary
|
||||
if !strings.Contains(result.ForUser, "bytes") && !strings.Contains(result.ForUser, "extractor") {
|
||||
t.Errorf("Expected ForUser to contain summary, got: %s", result.ForUser)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -55,7 +65,11 @@ func TestWebTool_WebFetch_JSON(t *testing.T) {
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
tool := NewWebFetchTool(50000)
|
||||
tool, err := NewWebFetchTool(50000, testFetchLimit)
|
||||
if err != nil {
|
||||
logger.ErrorCF("agent", "Failed to create web fetch tool", map[string]any{"error": err.Error()})
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
args := map[string]any{
|
||||
"url": server.URL,
|
||||
@@ -68,15 +82,19 @@ func TestWebTool_WebFetch_JSON(t *testing.T) {
|
||||
t.Errorf("Expected success, got IsError=true: %s", result.ForLLM)
|
||||
}
|
||||
|
||||
// ForUser should contain formatted JSON
|
||||
if !strings.Contains(result.ForUser, "key") && !strings.Contains(result.ForUser, "value") {
|
||||
t.Errorf("Expected ForUser to contain JSON data, got: %s", result.ForUser)
|
||||
// ForLLM should contain formatted JSON
|
||||
if !strings.Contains(result.ForLLM, "key") && !strings.Contains(result.ForLLM, "value") {
|
||||
t.Errorf("Expected ForLLM to contain JSON data, got: %s", result.ForLLM)
|
||||
}
|
||||
}
|
||||
|
||||
// TestWebTool_WebFetch_InvalidURL verifies error handling for invalid URL
|
||||
func TestWebTool_WebFetch_InvalidURL(t *testing.T) {
|
||||
tool := NewWebFetchTool(50000)
|
||||
tool, err := NewWebFetchTool(50000, testFetchLimit)
|
||||
if err != nil {
|
||||
logger.ErrorCF("agent", "Failed to create web fetch tool", map[string]any{"error": err.Error()})
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
args := map[string]any{
|
||||
"url": "not-a-valid-url",
|
||||
@@ -97,7 +115,11 @@ func TestWebTool_WebFetch_InvalidURL(t *testing.T) {
|
||||
|
||||
// TestWebTool_WebFetch_UnsupportedScheme verifies error handling for non-http URLs
|
||||
func TestWebTool_WebFetch_UnsupportedScheme(t *testing.T) {
|
||||
tool := NewWebFetchTool(50000)
|
||||
tool, err := NewWebFetchTool(50000, testFetchLimit)
|
||||
if err != nil {
|
||||
logger.ErrorCF("agent", "Failed to create web fetch tool", map[string]any{"error": err.Error()})
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
args := map[string]any{
|
||||
"url": "ftp://example.com/file.txt",
|
||||
@@ -118,7 +140,11 @@ func TestWebTool_WebFetch_UnsupportedScheme(t *testing.T) {
|
||||
|
||||
// TestWebTool_WebFetch_MissingURL verifies error handling for missing URL
|
||||
func TestWebTool_WebFetch_MissingURL(t *testing.T) {
|
||||
tool := NewWebFetchTool(50000)
|
||||
tool, err := NewWebFetchTool(50000, testFetchLimit)
|
||||
if err != nil {
|
||||
logger.ErrorCF("agent", "Failed to create web fetch tool", map[string]any{"error": err.Error()})
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
args := map[string]any{}
|
||||
|
||||
@@ -146,7 +172,11 @@ func TestWebTool_WebFetch_Truncation(t *testing.T) {
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
tool := NewWebFetchTool(1000) // Limit to 1000 chars
|
||||
tool, err := NewWebFetchTool(1000, testFetchLimit) // Limit to 1000 chars
|
||||
if err != nil {
|
||||
logger.ErrorCF("agent", "Failed to create web fetch tool", map[string]any{"error": err.Error()})
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
args := map[string]any{
|
||||
"url": server.URL,
|
||||
@@ -159,9 +189,9 @@ func TestWebTool_WebFetch_Truncation(t *testing.T) {
|
||||
t.Errorf("Expected success, got IsError=true: %s", result.ForLLM)
|
||||
}
|
||||
|
||||
// ForUser should contain truncated content (not the full 20000 chars)
|
||||
// ForLLM should contain truncated content (not the full 20000 chars)
|
||||
resultMap := make(map[string]any)
|
||||
json.Unmarshal([]byte(result.ForUser), &resultMap)
|
||||
json.Unmarshal([]byte(result.ForLLM), &resultMap)
|
||||
if text, ok := resultMap["text"].(string); ok {
|
||||
if len(text) > 1100 { // Allow some margin
|
||||
t.Errorf("Expected content to be truncated to ~1000 chars, got: %d", len(text))
|
||||
@@ -174,15 +204,64 @@ func TestWebTool_WebFetch_Truncation(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestWebFetchTool_PayloadTooLarge(t *testing.T) {
|
||||
// Create a mock HTTP server
|
||||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "text/html")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
|
||||
// Generate a payload intentionally larger than our limit.
|
||||
// Limit: 10 * 1024 * 1024 (10MB). We generate 10MB + 100 bytes of the letter 'A'.
|
||||
largeData := bytes.Repeat([]byte("A"), int(testFetchLimit)+100)
|
||||
|
||||
w.Write(largeData)
|
||||
}))
|
||||
// Ensure the server is shut down at the end of the test
|
||||
defer ts.Close()
|
||||
|
||||
// Initialize the tool
|
||||
tool, err := NewWebFetchTool(50000, testFetchLimit)
|
||||
if err != nil {
|
||||
logger.ErrorCF("agent", "Failed to create web fetch tool", map[string]any{"error": err.Error()})
|
||||
}
|
||||
|
||||
// Prepare the arguments pointing to the URL of our local mock server
|
||||
args := map[string]any{
|
||||
"url": ts.URL,
|
||||
}
|
||||
|
||||
// Execute the tool
|
||||
ctx := context.Background()
|
||||
result := tool.Execute(ctx, args)
|
||||
|
||||
// Assuming ErrorResult sets the ForLLM field with the error text.
|
||||
if result == nil {
|
||||
t.Fatal("expected a ToolResult, got nil")
|
||||
}
|
||||
|
||||
// Search for the exact error string we set earlier in the Execute method
|
||||
expectedErrorMsg := fmt.Sprintf("size exceeded %d bytes limit", testFetchLimit)
|
||||
|
||||
if !strings.Contains(result.ForLLM, expectedErrorMsg) && !strings.Contains(result.ForUser, expectedErrorMsg) {
|
||||
t.Errorf("test failed: expected error %q, but got: %+v", expectedErrorMsg, result)
|
||||
}
|
||||
}
|
||||
|
||||
// TestWebTool_WebSearch_NoApiKey verifies that no tool is created when API key is missing
|
||||
func TestWebTool_WebSearch_NoApiKey(t *testing.T) {
|
||||
tool := NewWebSearchTool(WebSearchToolOptions{BraveEnabled: true, BraveAPIKey: ""})
|
||||
tool, err := NewWebSearchTool(WebSearchToolOptions{BraveEnabled: true, BraveAPIKey: ""})
|
||||
if err != nil {
|
||||
t.Fatalf("Unexpected error: %v", err)
|
||||
}
|
||||
if tool != nil {
|
||||
t.Errorf("Expected nil tool when Brave API key is empty")
|
||||
}
|
||||
|
||||
// Also nil when nothing is enabled
|
||||
tool = NewWebSearchTool(WebSearchToolOptions{})
|
||||
tool, err = NewWebSearchTool(WebSearchToolOptions{})
|
||||
if err != nil {
|
||||
t.Fatalf("Unexpected error: %v", err)
|
||||
}
|
||||
if tool != nil {
|
||||
t.Errorf("Expected nil tool when no provider is enabled")
|
||||
}
|
||||
@@ -190,7 +269,10 @@ func TestWebTool_WebSearch_NoApiKey(t *testing.T) {
|
||||
|
||||
// TestWebTool_WebSearch_MissingQuery verifies error handling for missing query
|
||||
func TestWebTool_WebSearch_MissingQuery(t *testing.T) {
|
||||
tool := NewWebSearchTool(WebSearchToolOptions{BraveEnabled: true, BraveAPIKey: "test-key", BraveMaxResults: 5})
|
||||
tool, err := NewWebSearchTool(WebSearchToolOptions{BraveEnabled: true, BraveAPIKey: "test-key", BraveMaxResults: 5})
|
||||
if err != nil {
|
||||
t.Fatalf("Unexpected error: %v", err)
|
||||
}
|
||||
ctx := context.Background()
|
||||
args := map[string]any{}
|
||||
|
||||
@@ -215,7 +297,11 @@ func TestWebTool_WebFetch_HTMLExtraction(t *testing.T) {
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
tool := NewWebFetchTool(50000)
|
||||
tool, err := NewWebFetchTool(50000, testFetchLimit)
|
||||
if err != nil {
|
||||
logger.ErrorCF("agent", "Failed to create web fetch tool", map[string]any{"error": err.Error()})
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
args := map[string]any{
|
||||
"url": server.URL,
|
||||
@@ -228,14 +314,14 @@ func TestWebTool_WebFetch_HTMLExtraction(t *testing.T) {
|
||||
t.Errorf("Expected success, got IsError=true: %s", result.ForLLM)
|
||||
}
|
||||
|
||||
// ForUser should contain extracted text (without script/style tags)
|
||||
if !strings.Contains(result.ForUser, "Title") && !strings.Contains(result.ForUser, "Content") {
|
||||
t.Errorf("Expected ForUser to contain extracted text, got: %s", result.ForUser)
|
||||
// ForLLM should contain extracted text (without script/style tags)
|
||||
if !strings.Contains(result.ForLLM, "Title") && !strings.Contains(result.ForLLM, "Content") {
|
||||
t.Errorf("Expected ForLLM to contain extracted text, got: %s", result.ForLLM)
|
||||
}
|
||||
|
||||
// Should NOT contain script or style tags
|
||||
if strings.Contains(result.ForUser, "<script>") || strings.Contains(result.ForUser, "<style>") {
|
||||
t.Errorf("Expected script/style tags to be removed, got: %s", result.ForUser)
|
||||
// Should NOT contain script or style tags in ForLLM
|
||||
if strings.Contains(result.ForLLM, "<script>") || strings.Contains(result.ForLLM, "<style>") {
|
||||
t.Errorf("Expected script/style tags to be removed, got: %s", result.ForLLM)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -316,7 +402,11 @@ func TestWebFetchTool_extractText(t *testing.T) {
|
||||
|
||||
// TestWebTool_WebFetch_MissingDomain verifies error handling for URL without domain
|
||||
func TestWebTool_WebFetch_MissingDomain(t *testing.T) {
|
||||
tool := NewWebFetchTool(50000)
|
||||
tool, err := NewWebFetchTool(50000, testFetchLimit)
|
||||
if err != nil {
|
||||
logger.ErrorCF("agent", "Failed to create web fetch tool", map[string]any{"error": err.Error()})
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
args := map[string]any{
|
||||
"url": "https://",
|
||||
@@ -438,15 +528,22 @@ func TestCreateHTTPClient_ProxyFromEnvironmentWhenConfigEmpty(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestNewWebFetchToolWithProxy(t *testing.T) {
|
||||
tool := NewWebFetchToolWithProxy(1024, "http://127.0.0.1:7890")
|
||||
if tool.maxChars != 1024 {
|
||||
tool, err := NewWebFetchToolWithProxy(1024, "http://127.0.0.1:7890", testFetchLimit)
|
||||
if err != nil {
|
||||
logger.ErrorCF("agent", "Failed to create web fetch tool", map[string]any{"error": err.Error()})
|
||||
} else if tool.maxChars != 1024 {
|
||||
t.Fatalf("maxChars = %d, want %d", tool.maxChars, 1024)
|
||||
}
|
||||
|
||||
if tool.proxy != "http://127.0.0.1:7890" {
|
||||
t.Fatalf("proxy = %q, want %q", tool.proxy, "http://127.0.0.1:7890")
|
||||
}
|
||||
|
||||
tool = NewWebFetchToolWithProxy(0, "http://127.0.0.1:7890")
|
||||
tool, err = NewWebFetchToolWithProxy(0, "http://127.0.0.1:7890", testFetchLimit)
|
||||
if err != nil {
|
||||
logger.ErrorCF("agent", "Failed to create web fetch tool", map[string]any{"error": err.Error()})
|
||||
}
|
||||
|
||||
if tool.maxChars != 50000 {
|
||||
t.Fatalf("default maxChars = %d, want %d", tool.maxChars, 50000)
|
||||
}
|
||||
@@ -454,12 +551,15 @@ func TestNewWebFetchToolWithProxy(t *testing.T) {
|
||||
|
||||
func TestNewWebSearchTool_PropagatesProxy(t *testing.T) {
|
||||
t.Run("perplexity", func(t *testing.T) {
|
||||
tool := NewWebSearchTool(WebSearchToolOptions{
|
||||
tool, err := NewWebSearchTool(WebSearchToolOptions{
|
||||
PerplexityEnabled: true,
|
||||
PerplexityAPIKey: "k",
|
||||
PerplexityMaxResults: 3,
|
||||
Proxy: "http://127.0.0.1:7890",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("NewWebSearchTool() error: %v", err)
|
||||
}
|
||||
p, ok := tool.provider.(*PerplexitySearchProvider)
|
||||
if !ok {
|
||||
t.Fatalf("provider type = %T, want *PerplexitySearchProvider", tool.provider)
|
||||
@@ -470,12 +570,15 @@ func TestNewWebSearchTool_PropagatesProxy(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("brave", func(t *testing.T) {
|
||||
tool := NewWebSearchTool(WebSearchToolOptions{
|
||||
tool, err := NewWebSearchTool(WebSearchToolOptions{
|
||||
BraveEnabled: true,
|
||||
BraveAPIKey: "k",
|
||||
BraveMaxResults: 3,
|
||||
Proxy: "http://127.0.0.1:7890",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("NewWebSearchTool() error: %v", err)
|
||||
}
|
||||
p, ok := tool.provider.(*BraveSearchProvider)
|
||||
if !ok {
|
||||
t.Fatalf("provider type = %T, want *BraveSearchProvider", tool.provider)
|
||||
@@ -486,11 +589,14 @@ func TestNewWebSearchTool_PropagatesProxy(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("duckduckgo", func(t *testing.T) {
|
||||
tool := NewWebSearchTool(WebSearchToolOptions{
|
||||
tool, err := NewWebSearchTool(WebSearchToolOptions{
|
||||
DuckDuckGoEnabled: true,
|
||||
DuckDuckGoMaxResults: 3,
|
||||
Proxy: "http://127.0.0.1:7890",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("NewWebSearchTool() error: %v", err)
|
||||
}
|
||||
p, ok := tool.provider.(*DuckDuckGoSearchProvider)
|
||||
if !ok {
|
||||
t.Fatalf("provider type = %T, want *DuckDuckGoSearchProvider", tool.provider)
|
||||
@@ -542,12 +648,15 @@ func TestWebTool_TavilySearch_Success(t *testing.T) {
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
tool := NewWebSearchTool(WebSearchToolOptions{
|
||||
tool, err := NewWebSearchTool(WebSearchToolOptions{
|
||||
TavilyEnabled: true,
|
||||
TavilyAPIKey: "test-key",
|
||||
TavilyBaseURL: server.URL,
|
||||
TavilyMaxResults: 5,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("NewWebSearchTool() error: %v", err)
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
args := map[string]any{
|
||||
|
||||
@@ -37,6 +37,9 @@ func DoRequestWithRetry(client *http.Client, req *http.Request) (*http.Response,
|
||||
|
||||
if i < maxRetries-1 {
|
||||
if err = sleepWithCtx(req.Context(), retryDelayUnit*time.Duration(i+1)); err != nil {
|
||||
if resp != nil {
|
||||
resp.Body.Close()
|
||||
}
|
||||
return nil, fmt.Errorf("failed to sleep: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,8 +1,11 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@@ -77,6 +80,91 @@ func TestDoRequestWithRetry(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestDoRequestWithRetry_ContextCancel(t *testing.T) {
|
||||
// Use a long retry delay so cancellation always hits during sleepWithCtx.
|
||||
retryDelayUnit = 10 * time.Second
|
||||
t.Cleanup(func() { retryDelayUnit = time.Second })
|
||||
|
||||
bodyClosed := false
|
||||
firstRoundTripDone := make(chan struct{}, 1)
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
w.Write([]byte("error"))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := server.Client()
|
||||
client.Timeout = 30 * time.Second
|
||||
client.Transport = &bodyCloseTracker{
|
||||
rt: client.Transport,
|
||||
onClose: func() { bodyClosed = true },
|
||||
// Signal after the first round-trip response is fully constructed on the client side.
|
||||
onRoundTrip: func() {
|
||||
select {
|
||||
case firstRoundTripDone <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
},
|
||||
trackURL: server.URL,
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
// Cancel the context after the first round-trip completes on the client side.
|
||||
// This ensures client.Do has returned a valid resp (with body) and the retry
|
||||
// loop is about to enter sleepWithCtx, where the cancel will be detected.
|
||||
go func() {
|
||||
<-firstRoundTripDone
|
||||
cancel()
|
||||
}()
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, server.URL, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
resp, err := DoRequestWithRetry(client, req)
|
||||
if resp != nil {
|
||||
resp.Body.Close()
|
||||
}
|
||||
require.Error(t, err, "expected error from context cancellation")
|
||||
assert.Nil(t, resp, "expected nil response when context is canceled")
|
||||
assert.True(t, bodyClosed, "expected resp.Body to be closed on context cancellation")
|
||||
}
|
||||
|
||||
// bodyCloseTracker wraps an http.RoundTripper and records when response bodies are closed.
|
||||
type bodyCloseTracker struct {
|
||||
rt http.RoundTripper
|
||||
onClose func()
|
||||
onRoundTrip func() // called after each successful round-trip
|
||||
trackURL string
|
||||
}
|
||||
|
||||
func (t *bodyCloseTracker) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
resp, err := t.rt.RoundTrip(req)
|
||||
if err != nil {
|
||||
return resp, err
|
||||
}
|
||||
if strings.HasPrefix(req.URL.String(), t.trackURL) {
|
||||
resp.Body = &closeNotifier{ReadCloser: resp.Body, onClose: t.onClose}
|
||||
if t.onRoundTrip != nil {
|
||||
t.onRoundTrip()
|
||||
}
|
||||
}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
// closeNotifier wraps an io.ReadCloser to detect Close calls.
|
||||
type closeNotifier struct {
|
||||
io.ReadCloser
|
||||
onClose func()
|
||||
}
|
||||
|
||||
func (c *closeNotifier) Close() error {
|
||||
c.onClose()
|
||||
return c.ReadCloser.Close()
|
||||
}
|
||||
|
||||
func TestDoRequestWithRetry_Delay(t *testing.T) {
|
||||
retryDelayUnit = time.Millisecond
|
||||
t.Cleanup(func() { retryDelayUnit = time.Second })
|
||||
|
||||
Reference in New Issue
Block a user