mirror of
https://github.com/sipeed/picoclaw.git
synced 2026-06-12 18:08:54 +00:00
Merge upstream/main into fix/bugfixes
Resolve conflicts: - provider.go: keep upstream's serializeMessages (supersedes stripSystemParts) - provider_test.go: keep upstream's serializeMessages tests - loop_test.go: add slices import needed by upstream tests - shell.go: merge PR's --format deny fix with upstream's block device pattern, safePaths, and absolutePathPattern - shell_test.go: include tests from both branches Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
+126
-62
@@ -7,6 +7,7 @@ import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"slices"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
@@ -33,6 +34,11 @@ type ContextBuilder struct {
|
||||
// created (didn't exist at cache time, now exist) or deleted (existed at
|
||||
// cache time, now gone) — both of which should trigger a cache rebuild.
|
||||
existedAtCache map[string]bool
|
||||
|
||||
// skillFilesAtCache snapshots the skill tree file set and mtimes at cache
|
||||
// build time. This catches nested file creations/deletions/mtime changes
|
||||
// that may not update the top-level skill root directory mtime.
|
||||
skillFilesAtCache map[string]time.Time
|
||||
}
|
||||
|
||||
func getGlobalConfigDir() string {
|
||||
@@ -46,8 +52,11 @@ func getGlobalConfigDir() string {
|
||||
func NewContextBuilder(workspace string) *ContextBuilder {
|
||||
// builtin skills: skills directory in current project
|
||||
// Use the skills/ directory under the current working directory
|
||||
wd, _ := os.Getwd()
|
||||
builtinSkillsDir := filepath.Join(wd, "skills")
|
||||
builtinSkillsDir := strings.TrimSpace(os.Getenv("PICOCLAW_BUILTIN_SKILLS"))
|
||||
if builtinSkillsDir == "" {
|
||||
wd, _ := os.Getwd()
|
||||
builtinSkillsDir = filepath.Join(wd, "skills")
|
||||
}
|
||||
globalSkillsDir := filepath.Join(getGlobalConfigDir(), "skills")
|
||||
|
||||
return &ContextBuilder{
|
||||
@@ -147,6 +156,7 @@ func (cb *ContextBuilder) BuildSystemPromptWithCache() string {
|
||||
cb.cachedSystemPrompt = prompt
|
||||
cb.cachedAt = baseline.maxMtime
|
||||
cb.existedAtCache = baseline.existed
|
||||
cb.skillFilesAtCache = baseline.skillFiles
|
||||
|
||||
logger.DebugCF("agent", "System prompt cached",
|
||||
map[string]any{
|
||||
@@ -166,14 +176,14 @@ func (cb *ContextBuilder) InvalidateCache() {
|
||||
cb.cachedSystemPrompt = ""
|
||||
cb.cachedAt = time.Time{}
|
||||
cb.existedAtCache = nil
|
||||
cb.skillFilesAtCache = nil
|
||||
|
||||
logger.DebugCF("agent", "System prompt cache invalidated", nil)
|
||||
}
|
||||
|
||||
// sourcePaths returns the workspace source file paths tracked for cache
|
||||
// invalidation (bootstrap files + memory). The skills directory is handled
|
||||
// separately in sourceFilesChangedLocked because it requires both directory-
|
||||
// level and recursive file-level mtime checks.
|
||||
// sourcePaths returns non-skill workspace source files tracked for cache
|
||||
// invalidation (bootstrap files + memory). Skill roots are handled separately
|
||||
// because they require both directory-level and recursive file-level checks.
|
||||
func (cb *ContextBuilder) sourcePaths() []string {
|
||||
return []string{
|
||||
filepath.Join(cb.workspace, "AGENTS.md"),
|
||||
@@ -184,23 +194,39 @@ func (cb *ContextBuilder) sourcePaths() []string {
|
||||
}
|
||||
}
|
||||
|
||||
// skillRoots returns all skill root directories that can affect
|
||||
// BuildSkillsSummary output (workspace/global/builtin).
|
||||
func (cb *ContextBuilder) skillRoots() []string {
|
||||
if cb.skillsLoader == nil {
|
||||
return []string{filepath.Join(cb.workspace, "skills")}
|
||||
}
|
||||
|
||||
roots := cb.skillsLoader.SkillRoots()
|
||||
if len(roots) == 0 {
|
||||
return []string{filepath.Join(cb.workspace, "skills")}
|
||||
}
|
||||
return roots
|
||||
}
|
||||
|
||||
// cacheBaseline holds the file existence snapshot and the latest observed
|
||||
// mtime across all tracked paths. Used as the cache reference point.
|
||||
type cacheBaseline struct {
|
||||
existed map[string]bool
|
||||
maxMtime time.Time
|
||||
existed map[string]bool
|
||||
skillFiles map[string]time.Time
|
||||
maxMtime time.Time
|
||||
}
|
||||
|
||||
// buildCacheBaseline records which tracked paths currently exist and computes
|
||||
// the latest mtime across all tracked files + skills directory contents.
|
||||
// Called under write lock when the cache is built.
|
||||
func (cb *ContextBuilder) buildCacheBaseline() cacheBaseline {
|
||||
skillsDir := filepath.Join(cb.workspace, "skills")
|
||||
skillRoots := cb.skillRoots()
|
||||
|
||||
// All paths whose existence we track: source files + skills dir.
|
||||
allPaths := append(cb.sourcePaths(), skillsDir)
|
||||
// All paths whose existence we track: source files + all skill roots.
|
||||
allPaths := append(cb.sourcePaths(), skillRoots...)
|
||||
|
||||
existed := make(map[string]bool, len(allPaths))
|
||||
skillFiles := make(map[string]time.Time)
|
||||
var maxMtime time.Time
|
||||
|
||||
for _, p := range allPaths {
|
||||
@@ -211,17 +237,21 @@ func (cb *ContextBuilder) buildCacheBaseline() cacheBaseline {
|
||||
}
|
||||
}
|
||||
|
||||
// Walk skills files to capture their mtimes too.
|
||||
// Use os.Stat (not d.Info) to match the stat method used in
|
||||
// fileChangedSince / skillFilesModifiedSince for consistency.
|
||||
_ = filepath.WalkDir(skillsDir, func(path string, d fs.DirEntry, walkErr error) error {
|
||||
if walkErr == nil && !d.IsDir() {
|
||||
if info, err := os.Stat(path); err == nil && info.ModTime().After(maxMtime) {
|
||||
maxMtime = info.ModTime()
|
||||
// Walk all skill roots recursively to snapshot skill files and mtimes.
|
||||
// Use os.Stat (not d.Info) for consistency with sourceFilesChanged checks.
|
||||
for _, root := range skillRoots {
|
||||
_ = filepath.WalkDir(root, func(path string, d fs.DirEntry, walkErr error) error {
|
||||
if walkErr == nil && !d.IsDir() {
|
||||
if info, err := os.Stat(path); err == nil {
|
||||
skillFiles[path] = info.ModTime()
|
||||
if info.ModTime().After(maxMtime) {
|
||||
maxMtime = info.ModTime()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
// If no tracked files exist yet (empty workspace), maxMtime is zero.
|
||||
// Use a very old non-zero time so that:
|
||||
@@ -233,7 +263,7 @@ func (cb *ContextBuilder) buildCacheBaseline() cacheBaseline {
|
||||
maxMtime = time.Unix(1, 0)
|
||||
}
|
||||
|
||||
return cacheBaseline{existed: existed, maxMtime: maxMtime}
|
||||
return cacheBaseline{existed: existed, skillFiles: skillFiles, maxMtime: maxMtime}
|
||||
}
|
||||
|
||||
// sourceFilesChangedLocked checks whether any workspace source file has been
|
||||
@@ -249,27 +279,21 @@ func (cb *ContextBuilder) sourceFilesChangedLocked() bool {
|
||||
}
|
||||
|
||||
// Check tracked source files (bootstrap + memory).
|
||||
for _, p := range cb.sourcePaths() {
|
||||
if cb.fileChangedSince(p) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
// --- Skills directory (handled separately from sourcePaths) ---
|
||||
//
|
||||
// 1. Creation/deletion: tracked via existedAtCache, same as bootstrap files.
|
||||
skillsDir := filepath.Join(cb.workspace, "skills")
|
||||
if cb.fileChangedSince(skillsDir) {
|
||||
if slices.ContainsFunc(cb.sourcePaths(), cb.fileChangedSince) {
|
||||
return true
|
||||
}
|
||||
|
||||
// 2. Structural changes (add/remove entries inside the dir) are reflected
|
||||
// in the directory's own mtime, which fileChangedSince already checks.
|
||||
// --- Skill roots (workspace/global/builtin) ---
|
||||
//
|
||||
// 3. Content-only edits to files inside skills/ do NOT update the parent
|
||||
// directory mtime on most filesystems, so we recursively walk to check
|
||||
// individual file mtimes at any nesting depth.
|
||||
if skillFilesModifiedSince(skillsDir, cb.cachedAt) {
|
||||
// For each root:
|
||||
// 1. Creation/deletion and root directory mtime changes are tracked by fileChangedSince.
|
||||
// 2. Nested file create/delete/mtime changes are tracked by the skill file snapshot.
|
||||
for _, root := range cb.skillRoots() {
|
||||
if cb.fileChangedSince(root) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
if skillFilesChangedSince(cb.skillRoots(), cb.skillFilesAtCache) {
|
||||
return true
|
||||
}
|
||||
|
||||
@@ -310,28 +334,64 @@ func (cb *ContextBuilder) fileChangedSince(path string) bool {
|
||||
// if the callback returned nil when its err parameter is non-nil.
|
||||
var errWalkStop = errors.New("walk stop")
|
||||
|
||||
// skillFilesModifiedSince recursively walks the skills directory and checks
|
||||
// whether any file was modified after t. This catches content-only edits at
|
||||
// any nesting depth (e.g. skills/name/docs/extra.md) that don't update
|
||||
// parent directory mtimes.
|
||||
func skillFilesModifiedSince(skillsDir string, t time.Time) bool {
|
||||
changed := false
|
||||
err := filepath.WalkDir(skillsDir, func(path string, d fs.DirEntry, walkErr error) error {
|
||||
if walkErr == nil && !d.IsDir() {
|
||||
if info, statErr := os.Stat(path); statErr == nil && info.ModTime().After(t) {
|
||||
changed = true
|
||||
return errWalkStop // stop walking
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
// errWalkStop is expected (early exit on first changed file).
|
||||
// os.IsNotExist means the skills dir doesn't exist yet — not an error.
|
||||
// Any other error is unexpected and worth logging.
|
||||
if err != nil && !errors.Is(err, errWalkStop) && !os.IsNotExist(err) {
|
||||
logger.DebugCF("agent", "skills walk error", map[string]any{"error": err.Error()})
|
||||
// skillFilesChangedSince compares the current recursive skill file tree
|
||||
// against the cache-time snapshot. Any create/delete/mtime drift invalidates
|
||||
// the cache.
|
||||
func skillFilesChangedSince(skillRoots []string, filesAtCache map[string]time.Time) bool {
|
||||
// Defensive: if the snapshot was never initialized, force rebuild.
|
||||
if filesAtCache == nil {
|
||||
return true
|
||||
}
|
||||
return changed
|
||||
|
||||
// Check cached files still exist and keep the same mtime.
|
||||
for path, cachedMtime := range filesAtCache {
|
||||
info, err := os.Stat(path)
|
||||
if err != nil {
|
||||
// A previously tracked file disappeared (or became inaccessible):
|
||||
// either way, cached skill summary may now be stale.
|
||||
return true
|
||||
}
|
||||
if !info.ModTime().Equal(cachedMtime) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
// Check no new files appeared under any skill root.
|
||||
changed := false
|
||||
for _, root := range skillRoots {
|
||||
if strings.TrimSpace(root) == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
err := filepath.WalkDir(root, func(path string, d fs.DirEntry, walkErr error) error {
|
||||
if walkErr != nil {
|
||||
// Treat unexpected walk errors as changed to avoid stale cache.
|
||||
if !os.IsNotExist(walkErr) {
|
||||
changed = true
|
||||
return errWalkStop
|
||||
}
|
||||
return nil
|
||||
}
|
||||
if d.IsDir() {
|
||||
return nil
|
||||
}
|
||||
if _, ok := filesAtCache[path]; !ok {
|
||||
changed = true
|
||||
return errWalkStop
|
||||
}
|
||||
return nil
|
||||
})
|
||||
|
||||
if changed {
|
||||
return true
|
||||
}
|
||||
if err != nil && !errors.Is(err, errWalkStop) && !os.IsNotExist(err) {
|
||||
logger.DebugCF("agent", "skills walk error", map[string]any{"error": err.Error()})
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func (cb *ContextBuilder) LoadBootstrapFiles() string {
|
||||
@@ -467,10 +527,14 @@ func (cb *ContextBuilder) BuildMessages(
|
||||
|
||||
// Add current user message
|
||||
if strings.TrimSpace(currentMessage) != "" {
|
||||
messages = append(messages, providers.Message{
|
||||
msg := providers.Message{
|
||||
Role: "user",
|
||||
Content: currentMessage,
|
||||
})
|
||||
}
|
||||
if len(media) > 0 {
|
||||
msg.Media = media
|
||||
}
|
||||
messages = append(messages, msg)
|
||||
}
|
||||
|
||||
return messages
|
||||
|
||||
@@ -383,6 +383,162 @@ Updated content.`
|
||||
}
|
||||
}
|
||||
|
||||
// TestGlobalSkillFileContentChange verifies that modifying a global skill
|
||||
// (~/.picoclaw/skills) invalidates the cached system prompt.
|
||||
func TestGlobalSkillFileContentChange(t *testing.T) {
|
||||
tmpHome := t.TempDir()
|
||||
t.Setenv("HOME", tmpHome)
|
||||
|
||||
tmpDir := setupWorkspace(t, nil)
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
globalSkillPath := filepath.Join(tmpHome, ".picoclaw", "skills", "global-skill", "SKILL.md")
|
||||
if err := os.MkdirAll(filepath.Dir(globalSkillPath), 0o755); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
v1 := `---
|
||||
name: global-skill
|
||||
description: global-v1
|
||||
---
|
||||
# Global Skill v1`
|
||||
if err := os.WriteFile(globalSkillPath, []byte(v1), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
cb := NewContextBuilder(tmpDir)
|
||||
sp1 := cb.BuildSystemPromptWithCache()
|
||||
if !strings.Contains(sp1, "global-v1") {
|
||||
t.Fatal("expected initial prompt to contain global skill description")
|
||||
}
|
||||
|
||||
v2 := `---
|
||||
name: global-skill
|
||||
description: global-v2
|
||||
---
|
||||
# Global Skill v2`
|
||||
if err := os.WriteFile(globalSkillPath, []byte(v2), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
future := time.Now().Add(2 * time.Second)
|
||||
if err := os.Chtimes(globalSkillPath, future, future); err != nil {
|
||||
t.Fatalf("failed to update mtime for %s: %v", globalSkillPath, err)
|
||||
}
|
||||
|
||||
cb.systemPromptMutex.RLock()
|
||||
changed := cb.sourceFilesChangedLocked()
|
||||
cb.systemPromptMutex.RUnlock()
|
||||
if !changed {
|
||||
t.Fatal("sourceFilesChangedLocked() should detect global skill file content change")
|
||||
}
|
||||
|
||||
sp2 := cb.BuildSystemPromptWithCache()
|
||||
if !strings.Contains(sp2, "global-v2") {
|
||||
t.Error("rebuilt prompt should contain updated global skill description")
|
||||
}
|
||||
if sp1 == sp2 {
|
||||
t.Error("cache should be invalidated when global skill file content changes")
|
||||
}
|
||||
}
|
||||
|
||||
// TestBuiltinSkillFileContentChange verifies that modifying a builtin skill
|
||||
// invalidates the cached system prompt.
|
||||
func TestBuiltinSkillFileContentChange(t *testing.T) {
|
||||
tmpHome := t.TempDir()
|
||||
t.Setenv("HOME", tmpHome)
|
||||
|
||||
tmpDir := setupWorkspace(t, nil)
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
builtinRoot := t.TempDir()
|
||||
t.Setenv("PICOCLAW_BUILTIN_SKILLS", builtinRoot)
|
||||
|
||||
builtinSkillPath := filepath.Join(builtinRoot, "builtin-skill", "SKILL.md")
|
||||
if err := os.MkdirAll(filepath.Dir(builtinSkillPath), 0o755); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
v1 := `---
|
||||
name: builtin-skill
|
||||
description: builtin-v1
|
||||
---
|
||||
# Builtin Skill v1`
|
||||
if err := os.WriteFile(builtinSkillPath, []byte(v1), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
cb := NewContextBuilder(tmpDir)
|
||||
sp1 := cb.BuildSystemPromptWithCache()
|
||||
if !strings.Contains(sp1, "builtin-v1") {
|
||||
t.Fatal("expected initial prompt to contain builtin skill description")
|
||||
}
|
||||
|
||||
v2 := `---
|
||||
name: builtin-skill
|
||||
description: builtin-v2
|
||||
---
|
||||
# Builtin Skill v2`
|
||||
if err := os.WriteFile(builtinSkillPath, []byte(v2), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
future := time.Now().Add(2 * time.Second)
|
||||
if err := os.Chtimes(builtinSkillPath, future, future); err != nil {
|
||||
t.Fatalf("failed to update mtime for %s: %v", builtinSkillPath, err)
|
||||
}
|
||||
|
||||
cb.systemPromptMutex.RLock()
|
||||
changed := cb.sourceFilesChangedLocked()
|
||||
cb.systemPromptMutex.RUnlock()
|
||||
if !changed {
|
||||
t.Fatal("sourceFilesChangedLocked() should detect builtin skill file content change")
|
||||
}
|
||||
|
||||
sp2 := cb.BuildSystemPromptWithCache()
|
||||
if !strings.Contains(sp2, "builtin-v2") {
|
||||
t.Error("rebuilt prompt should contain updated builtin skill description")
|
||||
}
|
||||
if sp1 == sp2 {
|
||||
t.Error("cache should be invalidated when builtin skill file content changes")
|
||||
}
|
||||
}
|
||||
|
||||
// TestSkillFileDeletionInvalidatesCache verifies that deleting a nested skill
|
||||
// file invalidates the cached system prompt.
|
||||
func TestSkillFileDeletionInvalidatesCache(t *testing.T) {
|
||||
tmpDir := setupWorkspace(t, map[string]string{
|
||||
"skills/delete-me/SKILL.md": `---
|
||||
name: delete-me
|
||||
description: delete-me-v1
|
||||
---
|
||||
# Delete Me`,
|
||||
})
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
cb := NewContextBuilder(tmpDir)
|
||||
sp1 := cb.BuildSystemPromptWithCache()
|
||||
if !strings.Contains(sp1, "delete-me-v1") {
|
||||
t.Fatal("expected initial prompt to contain skill description")
|
||||
}
|
||||
|
||||
skillPath := filepath.Join(tmpDir, "skills", "delete-me", "SKILL.md")
|
||||
if err := os.Remove(skillPath); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
cb.systemPromptMutex.RLock()
|
||||
changed := cb.sourceFilesChangedLocked()
|
||||
cb.systemPromptMutex.RUnlock()
|
||||
if !changed {
|
||||
t.Fatal("sourceFilesChangedLocked() should detect deleted skill file")
|
||||
}
|
||||
|
||||
sp2 := cb.BuildSystemPromptWithCache()
|
||||
if strings.Contains(sp2, "delete-me-v1") {
|
||||
t.Error("rebuilt prompt should not contain deleted skill description")
|
||||
}
|
||||
if sp1 == sp2 {
|
||||
t.Error("cache should be invalidated when skill file is deleted")
|
||||
}
|
||||
}
|
||||
|
||||
// TestConcurrentBuildSystemPromptWithCache verifies that multiple goroutines
|
||||
// can safely call BuildSystemPromptWithCache concurrently without producing
|
||||
// empty results, panics, or data races.
|
||||
@@ -404,11 +560,11 @@ func TestConcurrentBuildSystemPromptWithCache(t *testing.T) {
|
||||
var wg sync.WaitGroup
|
||||
errs := make(chan string, goroutines*iterations)
|
||||
|
||||
for g := 0; g < goroutines; g++ {
|
||||
for g := range goroutines {
|
||||
wg.Add(1)
|
||||
go func(id int) {
|
||||
defer wg.Done()
|
||||
for i := 0; i < iterations; i++ {
|
||||
for i := range iterations {
|
||||
result := cb.BuildSystemPromptWithCache()
|
||||
if result == "" {
|
||||
errs <- "empty prompt returned"
|
||||
|
||||
+26
-5
@@ -1,9 +1,11 @@
|
||||
package agent
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
@@ -48,18 +50,24 @@ func NewAgentInstance(
|
||||
fallbacks := resolveAgentFallbacks(agentCfg, defaults)
|
||||
|
||||
restrict := defaults.RestrictToWorkspace
|
||||
readRestrict := restrict && !defaults.AllowReadOutsideWorkspace
|
||||
|
||||
// Compile path whitelist patterns from config.
|
||||
allowReadPaths := compilePatterns(cfg.Tools.AllowReadPaths)
|
||||
allowWritePaths := compilePatterns(cfg.Tools.AllowWritePaths)
|
||||
|
||||
toolsRegistry := tools.NewToolRegistry()
|
||||
toolsRegistry.Register(tools.NewReadFileTool(workspace, restrict))
|
||||
toolsRegistry.Register(tools.NewWriteFileTool(workspace, restrict))
|
||||
toolsRegistry.Register(tools.NewListDirTool(workspace, restrict))
|
||||
toolsRegistry.Register(tools.NewReadFileTool(workspace, readRestrict, allowReadPaths))
|
||||
toolsRegistry.Register(tools.NewWriteFileTool(workspace, restrict, allowWritePaths))
|
||||
toolsRegistry.Register(tools.NewListDirTool(workspace, readRestrict, allowReadPaths))
|
||||
execTool, err := tools.NewExecToolWithConfig(workspace, restrict, cfg)
|
||||
if err != nil {
|
||||
log.Fatalf("Critical error: unable to initialize exec tool: %v", err)
|
||||
}
|
||||
toolsRegistry.Register(execTool)
|
||||
|
||||
toolsRegistry.Register(tools.NewEditFileTool(workspace, restrict))
|
||||
toolsRegistry.Register(tools.NewAppendFileTool(workspace, restrict))
|
||||
toolsRegistry.Register(tools.NewEditFileTool(workspace, restrict, allowWritePaths))
|
||||
toolsRegistry.Register(tools.NewAppendFileTool(workspace, restrict, allowWritePaths))
|
||||
|
||||
sessionsDir := filepath.Join(workspace, "sessions")
|
||||
sessionsManager := session.NewSessionManager(sessionsDir)
|
||||
@@ -189,6 +197,19 @@ func resolveAgentFallbacks(agentCfg *config.AgentConfig, defaults *config.AgentD
|
||||
return defaults.ModelFallbacks
|
||||
}
|
||||
|
||||
func compilePatterns(patterns []string) []*regexp.Regexp {
|
||||
compiled := make([]*regexp.Regexp, 0, len(patterns))
|
||||
for _, p := range patterns {
|
||||
re, err := regexp.Compile(p)
|
||||
if err != nil {
|
||||
fmt.Printf("Warning: invalid path pattern %q: %v\n", p, err)
|
||||
continue
|
||||
}
|
||||
compiled = append(compiled, re)
|
||||
}
|
||||
return compiled
|
||||
}
|
||||
|
||||
func expandHome(path string) string {
|
||||
if path == "" {
|
||||
return path
|
||||
|
||||
+58
-65
@@ -95,75 +95,68 @@ func TestNewAgentInstance_DefaultsTemperatureWhenUnset(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestNewAgentInstance_ResolveCandidatesFromModelListAlias(t *testing.T) {
|
||||
tmpDir, err := os.MkdirTemp("", "agent-instance-test-*")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temp dir: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
cfg := &config.Config{
|
||||
Agents: config.AgentsConfig{
|
||||
Defaults: config.AgentDefaults{
|
||||
Workspace: tmpDir,
|
||||
Model: "step-3.5-flash",
|
||||
},
|
||||
tests := []struct {
|
||||
name string
|
||||
aliasName string
|
||||
modelName string
|
||||
apiBase string
|
||||
wantProvider string
|
||||
wantModel string
|
||||
}{
|
||||
{
|
||||
name: "alias with provider prefix",
|
||||
aliasName: "step-3.5-flash",
|
||||
modelName: "openrouter/stepfun/step-3.5-flash:free",
|
||||
apiBase: "https://openrouter.ai/api/v1",
|
||||
wantProvider: "openrouter",
|
||||
wantModel: "stepfun/step-3.5-flash:free",
|
||||
},
|
||||
ModelList: []config.ModelConfig{
|
||||
{
|
||||
ModelName: "step-3.5-flash",
|
||||
Model: "openrouter/stepfun/step-3.5-flash:free",
|
||||
APIBase: "https://openrouter.ai/api/v1",
|
||||
},
|
||||
{
|
||||
name: "alias without provider prefix",
|
||||
aliasName: "glm-5",
|
||||
modelName: "glm-5",
|
||||
apiBase: "https://api.z.ai/api/coding/paas/v4",
|
||||
wantProvider: "openai",
|
||||
wantModel: "glm-5",
|
||||
},
|
||||
}
|
||||
|
||||
provider := &mockProvider{}
|
||||
agent := NewAgentInstance(nil, &cfg.Agents.Defaults, cfg, provider)
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
tmpDir, err := os.MkdirTemp("", "agent-instance-test-*")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temp dir: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
if len(agent.Candidates) != 1 {
|
||||
t.Fatalf("len(Candidates) = %d, want 1", len(agent.Candidates))
|
||||
}
|
||||
if agent.Candidates[0].Provider != "openrouter" {
|
||||
t.Fatalf("candidate provider = %q, want %q", agent.Candidates[0].Provider, "openrouter")
|
||||
}
|
||||
if agent.Candidates[0].Model != "stepfun/step-3.5-flash:free" {
|
||||
t.Fatalf("candidate model = %q, want %q", agent.Candidates[0].Model, "stepfun/step-3.5-flash:free")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewAgentInstance_ResolveCandidatesFromModelListAliasWithoutProtocol(t *testing.T) {
|
||||
tmpDir, err := os.MkdirTemp("", "agent-instance-test-*")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temp dir: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
cfg := &config.Config{
|
||||
Agents: config.AgentsConfig{
|
||||
Defaults: config.AgentDefaults{
|
||||
Workspace: tmpDir,
|
||||
Model: "glm-5",
|
||||
},
|
||||
},
|
||||
ModelList: []config.ModelConfig{
|
||||
{
|
||||
ModelName: "glm-5",
|
||||
Model: "glm-5",
|
||||
APIBase: "https://api.z.ai/api/coding/paas/v4",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
provider := &mockProvider{}
|
||||
agent := NewAgentInstance(nil, &cfg.Agents.Defaults, cfg, provider)
|
||||
|
||||
if len(agent.Candidates) != 1 {
|
||||
t.Fatalf("len(Candidates) = %d, want 1", len(agent.Candidates))
|
||||
}
|
||||
if agent.Candidates[0].Provider != "openai" {
|
||||
t.Fatalf("candidate provider = %q, want %q", agent.Candidates[0].Provider, "openai")
|
||||
}
|
||||
if agent.Candidates[0].Model != "glm-5" {
|
||||
t.Fatalf("candidate model = %q, want %q", agent.Candidates[0].Model, "glm-5")
|
||||
cfg := &config.Config{
|
||||
Agents: config.AgentsConfig{
|
||||
Defaults: config.AgentDefaults{
|
||||
Workspace: tmpDir,
|
||||
Model: tt.aliasName,
|
||||
},
|
||||
},
|
||||
ModelList: []config.ModelConfig{
|
||||
{
|
||||
ModelName: tt.aliasName,
|
||||
Model: tt.modelName,
|
||||
APIBase: tt.apiBase,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
provider := &mockProvider{}
|
||||
agent := NewAgentInstance(nil, &cfg.Agents.Defaults, cfg, provider)
|
||||
|
||||
if len(agent.Candidates) != 1 {
|
||||
t.Fatalf("len(Candidates) = %d, want 1", len(agent.Candidates))
|
||||
}
|
||||
if agent.Candidates[0].Provider != tt.wantProvider {
|
||||
t.Fatalf("candidate provider = %q, want %q", agent.Candidates[0].Provider, tt.wantProvider)
|
||||
}
|
||||
if agent.Candidates[0].Model != tt.wantModel {
|
||||
t.Fatalf("candidate model = %q, want %q", agent.Candidates[0].Model, tt.wantModel)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
+180
-39
@@ -23,6 +23,7 @@ import (
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
"github.com/sipeed/picoclaw/pkg/constants"
|
||||
"github.com/sipeed/picoclaw/pkg/logger"
|
||||
"github.com/sipeed/picoclaw/pkg/mcp"
|
||||
"github.com/sipeed/picoclaw/pkg/media"
|
||||
"github.com/sipeed/picoclaw/pkg/providers"
|
||||
"github.com/sipeed/picoclaw/pkg/routing"
|
||||
@@ -46,19 +47,24 @@ type AgentLoop struct {
|
||||
|
||||
// processOptions configures how a message is processed
|
||||
type processOptions struct {
|
||||
SessionKey string // Session identifier for history/context
|
||||
Channel string // Target channel for tool execution
|
||||
ChatID string // Target chat ID for tool execution
|
||||
UserMessage string // User message content (may include prefix)
|
||||
DefaultResponse string // Response when LLM returns empty
|
||||
EnableSummary bool // Whether to trigger summarization
|
||||
SendResponse bool // Whether to send response via bus
|
||||
NoHistory bool // If true, don't load session history (for heartbeat)
|
||||
SessionKey string // Session identifier for history/context
|
||||
Channel string // Target channel for tool execution
|
||||
ChatID string // Target chat ID for tool execution
|
||||
UserMessage string // User message content (may include prefix)
|
||||
Media []string // media:// refs from inbound message
|
||||
DefaultResponse string // Response when LLM returns empty
|
||||
EnableSummary bool // Whether to trigger summarization
|
||||
SendResponse bool // Whether to send response via bus
|
||||
NoHistory bool // If true, don't load session history (for heartbeat)
|
||||
}
|
||||
|
||||
const defaultResponse = "I've completed processing but have no response to give. Increase `max_tool_iterations` in config.json."
|
||||
|
||||
func NewAgentLoop(cfg *config.Config, msgBus *bus.MessageBus, provider providers.LLMProvider) *AgentLoop {
|
||||
func NewAgentLoop(
|
||||
cfg *config.Config,
|
||||
msgBus *bus.MessageBus,
|
||||
provider providers.LLMProvider,
|
||||
) *AgentLoop {
|
||||
registry := NewAgentRegistry(cfg, provider)
|
||||
|
||||
// Register shared tools to all agents
|
||||
@@ -99,7 +105,7 @@ func registerSharedTools(
|
||||
}
|
||||
|
||||
// Web tools
|
||||
if searchTool := tools.NewWebSearchTool(tools.WebSearchToolOptions{
|
||||
searchTool, err := tools.NewWebSearchTool(tools.WebSearchToolOptions{
|
||||
BraveAPIKey: cfg.Tools.Web.Brave.APIKey,
|
||||
BraveMaxResults: cfg.Tools.Web.Brave.MaxResults,
|
||||
BraveEnabled: cfg.Tools.Web.Brave.Enabled,
|
||||
@@ -113,10 +119,18 @@ func registerSharedTools(
|
||||
PerplexityMaxResults: cfg.Tools.Web.Perplexity.MaxResults,
|
||||
PerplexityEnabled: cfg.Tools.Web.Perplexity.Enabled,
|
||||
Proxy: cfg.Tools.Web.Proxy,
|
||||
}); searchTool != nil {
|
||||
})
|
||||
if err != nil {
|
||||
logger.ErrorCF("agent", "Failed to create web search tool", map[string]any{"error": err.Error()})
|
||||
} else if searchTool != nil {
|
||||
agent.Tools.Register(searchTool)
|
||||
}
|
||||
agent.Tools.Register(tools.NewWebFetchToolWithProxy(50000, cfg.Tools.Web.Proxy))
|
||||
fetchTool, err := tools.NewWebFetchToolWithProxy(50000, cfg.Tools.Web.Proxy, cfg.Tools.Web.FetchLimitBytes)
|
||||
if err != nil {
|
||||
logger.ErrorCF("agent", "Failed to create web fetch tool", map[string]any{"error": err.Error()})
|
||||
} else {
|
||||
agent.Tools.Register(fetchTool)
|
||||
}
|
||||
|
||||
// Hardware tools (I2C, SPI) - Linux only, returns error on other platforms
|
||||
agent.Tools.Register(tools.NewI2CTool())
|
||||
@@ -162,6 +176,72 @@ func registerSharedTools(
|
||||
func (al *AgentLoop) Run(ctx context.Context) error {
|
||||
al.running.Store(true)
|
||||
|
||||
// Initialize MCP servers for all agents
|
||||
if al.cfg.Tools.MCP.Enabled {
|
||||
mcpManager := mcp.NewManager()
|
||||
// Ensure MCP connections are cleaned up on exit, regardless of initialization success
|
||||
// This fixes resource leak when LoadFromMCPConfig partially succeeds then fails
|
||||
defer func() {
|
||||
if err := mcpManager.Close(); err != nil {
|
||||
logger.ErrorCF("agent", "Failed to close MCP manager",
|
||||
map[string]any{
|
||||
"error": err.Error(),
|
||||
})
|
||||
}
|
||||
}()
|
||||
|
||||
defaultAgent := al.registry.GetDefaultAgent()
|
||||
var workspacePath string
|
||||
if defaultAgent != nil && defaultAgent.Workspace != "" {
|
||||
workspacePath = defaultAgent.Workspace
|
||||
} else {
|
||||
workspacePath = al.cfg.WorkspacePath()
|
||||
}
|
||||
|
||||
if err := mcpManager.LoadFromMCPConfig(ctx, al.cfg.Tools.MCP, workspacePath); err != nil {
|
||||
logger.WarnCF("agent", "Failed to load MCP servers, MCP tools will not be available",
|
||||
map[string]any{
|
||||
"error": err.Error(),
|
||||
})
|
||||
} else {
|
||||
// Register MCP tools for all agents
|
||||
servers := mcpManager.GetServers()
|
||||
uniqueTools := 0
|
||||
totalRegistrations := 0
|
||||
agentIDs := al.registry.ListAgentIDs()
|
||||
agentCount := len(agentIDs)
|
||||
|
||||
for serverName, conn := range servers {
|
||||
uniqueTools += len(conn.Tools)
|
||||
for _, tool := range conn.Tools {
|
||||
for _, agentID := range agentIDs {
|
||||
agent, ok := al.registry.GetAgent(agentID)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
mcpTool := tools.NewMCPTool(mcpManager, serverName, tool)
|
||||
agent.Tools.Register(mcpTool)
|
||||
totalRegistrations++
|
||||
logger.DebugCF("agent", "Registered MCP tool",
|
||||
map[string]any{
|
||||
"agent_id": agentID,
|
||||
"server": serverName,
|
||||
"tool": tool.Name,
|
||||
"name": mcpTool.Name(),
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
logger.InfoCF("agent", "MCP tools registered successfully",
|
||||
map[string]any{
|
||||
"server_count": len(servers),
|
||||
"unique_tools": uniqueTools,
|
||||
"total_registrations": totalRegistrations,
|
||||
"agent_count": agentCount,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
for al.running.Load() {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
@@ -302,7 +382,10 @@ func (al *AgentLoop) RecordLastChatID(chatID string) error {
|
||||
return al.state.SetLastChatID(chatID)
|
||||
}
|
||||
|
||||
func (al *AgentLoop) ProcessDirect(ctx context.Context, content, sessionKey string) (string, error) {
|
||||
func (al *AgentLoop) ProcessDirect(
|
||||
ctx context.Context,
|
||||
content, sessionKey string,
|
||||
) (string, error) {
|
||||
return al.ProcessDirectWithChannel(ctx, content, sessionKey, "cli", "direct")
|
||||
}
|
||||
|
||||
@@ -323,7 +406,10 @@ func (al *AgentLoop) ProcessDirectWithChannel(
|
||||
|
||||
// ProcessHeartbeat processes a heartbeat request without session history.
|
||||
// Each heartbeat is independent and doesn't accumulate context.
|
||||
func (al *AgentLoop) ProcessHeartbeat(ctx context.Context, content, channel, chatID string) (string, error) {
|
||||
func (al *AgentLoop) ProcessHeartbeat(
|
||||
ctx context.Context,
|
||||
content, channel, chatID string,
|
||||
) (string, error) {
|
||||
agent := al.registry.GetDefaultAgent()
|
||||
if agent == nil {
|
||||
return "", fmt.Errorf("no default agent for heartbeat")
|
||||
@@ -348,13 +434,16 @@ func (al *AgentLoop) processMessage(ctx context.Context, msg bus.InboundMessage)
|
||||
} else {
|
||||
logContent = utils.Truncate(msg.Content, 80)
|
||||
}
|
||||
logger.InfoCF("agent", fmt.Sprintf("Processing message from %s:%s: %s", msg.Channel, msg.SenderID, logContent),
|
||||
logger.InfoCF(
|
||||
"agent",
|
||||
fmt.Sprintf("Processing message from %s:%s: %s", msg.Channel, msg.SenderID, logContent),
|
||||
map[string]any{
|
||||
"channel": msg.Channel,
|
||||
"chat_id": msg.ChatID,
|
||||
"sender_id": msg.SenderID,
|
||||
"session_key": msg.SessionKey,
|
||||
})
|
||||
},
|
||||
)
|
||||
|
||||
// Route system messages to processSystemMessage
|
||||
if msg.Channel == "system" {
|
||||
@@ -409,15 +498,22 @@ func (al *AgentLoop) processMessage(ctx context.Context, msg bus.InboundMessage)
|
||||
Channel: msg.Channel,
|
||||
ChatID: msg.ChatID,
|
||||
UserMessage: msg.Content,
|
||||
Media: msg.Media,
|
||||
DefaultResponse: defaultResponse,
|
||||
EnableSummary: true,
|
||||
SendResponse: false,
|
||||
})
|
||||
}
|
||||
|
||||
func (al *AgentLoop) processSystemMessage(ctx context.Context, msg bus.InboundMessage) (string, error) {
|
||||
func (al *AgentLoop) processSystemMessage(
|
||||
ctx context.Context,
|
||||
msg bus.InboundMessage,
|
||||
) (string, error) {
|
||||
if msg.Channel != "system" {
|
||||
return "", fmt.Errorf("processSystemMessage called with non-system message channel: %s", msg.Channel)
|
||||
return "", fmt.Errorf(
|
||||
"processSystemMessage called with non-system message channel: %s",
|
||||
msg.Channel,
|
||||
)
|
||||
}
|
||||
|
||||
logger.InfoCF("agent", "Processing system message",
|
||||
@@ -475,14 +571,22 @@ func (al *AgentLoop) processSystemMessage(ctx context.Context, msg bus.InboundMe
|
||||
}
|
||||
|
||||
// runAgentLoop is the core message processing logic.
|
||||
func (al *AgentLoop) runAgentLoop(ctx context.Context, agent *AgentInstance, opts processOptions) (string, error) {
|
||||
func (al *AgentLoop) runAgentLoop(
|
||||
ctx context.Context,
|
||||
agent *AgentInstance,
|
||||
opts processOptions,
|
||||
) (string, error) {
|
||||
// 0. Record last channel for heartbeat notifications (skip internal channels)
|
||||
if opts.Channel != "" && opts.ChatID != "" {
|
||||
// Don't record internal channels (cli, system, subagent)
|
||||
if !constants.IsInternalChannel(opts.Channel) {
|
||||
channelKey := fmt.Sprintf("%s:%s", opts.Channel, opts.ChatID)
|
||||
if err := al.RecordLastChannel(channelKey); err != nil {
|
||||
logger.WarnCF("agent", "Failed to record last channel", map[string]any{"error": err.Error()})
|
||||
logger.WarnCF(
|
||||
"agent",
|
||||
"Failed to record last channel",
|
||||
map[string]any{"error": err.Error()},
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -501,11 +605,15 @@ func (al *AgentLoop) runAgentLoop(ctx context.Context, agent *AgentInstance, opt
|
||||
history,
|
||||
summary,
|
||||
opts.UserMessage,
|
||||
nil,
|
||||
opts.Media,
|
||||
opts.Channel,
|
||||
opts.ChatID,
|
||||
)
|
||||
|
||||
// Resolve media:// refs to base64 data URLs (streaming)
|
||||
maxMediaSize := al.cfg.Agents.Defaults.GetMaxMediaSize()
|
||||
messages = resolveMediaRefs(messages, al.mediaStore, maxMediaSize)
|
||||
|
||||
// 3. Save user message to session
|
||||
agent.Sessions.AddMessage(opts.SessionKey, "user", opts.UserMessage)
|
||||
|
||||
@@ -564,7 +672,10 @@ func (al *AgentLoop) targetReasoningChannelID(channelName string) (chatID string
|
||||
return ""
|
||||
}
|
||||
|
||||
func (al *AgentLoop) handleReasoning(ctx context.Context, reasoningContent, channelName, channelID string) {
|
||||
func (al *AgentLoop) handleReasoning(
|
||||
ctx context.Context,
|
||||
reasoningContent, channelName, channelID string,
|
||||
) {
|
||||
if reasoningContent == "" || channelName == "" || channelID == "" {
|
||||
return
|
||||
}
|
||||
@@ -657,22 +768,33 @@ func (al *AgentLoop) runLLMIteration(
|
||||
|
||||
callLLM := func() (*providers.LLMResponse, error) {
|
||||
if len(agent.Candidates) > 1 && al.fallback != nil {
|
||||
fbResult, fbErr := al.fallback.Execute(ctx, agent.Candidates,
|
||||
fbResult, fbErr := al.fallback.Execute(
|
||||
ctx,
|
||||
agent.Candidates,
|
||||
func(ctx context.Context, provider, model string) (*providers.LLMResponse, error) {
|
||||
return agent.Provider.Chat(ctx, messages, providerToolDefs, model, map[string]any{
|
||||
"max_tokens": agent.MaxTokens,
|
||||
"temperature": agent.Temperature,
|
||||
"prompt_cache_key": agent.ID,
|
||||
})
|
||||
return agent.Provider.Chat(
|
||||
ctx,
|
||||
messages,
|
||||
providerToolDefs,
|
||||
model,
|
||||
map[string]any{
|
||||
"max_tokens": agent.MaxTokens,
|
||||
"temperature": agent.Temperature,
|
||||
"prompt_cache_key": agent.ID,
|
||||
},
|
||||
)
|
||||
},
|
||||
)
|
||||
if fbErr != nil {
|
||||
return nil, fbErr
|
||||
}
|
||||
if fbResult.Provider != "" && len(fbResult.Attempts) > 0 {
|
||||
logger.InfoCF("agent", fmt.Sprintf("Fallback: succeeded with %s/%s after %d attempts",
|
||||
fbResult.Provider, fbResult.Model, len(fbResult.Attempts)+1),
|
||||
map[string]any{"agent_id": agent.ID, "iteration": iteration})
|
||||
logger.InfoCF(
|
||||
"agent",
|
||||
fmt.Sprintf("Fallback: succeeded with %s/%s after %d attempts",
|
||||
fbResult.Provider, fbResult.Model, len(fbResult.Attempts)+1),
|
||||
map[string]any{"agent_id": agent.ID, "iteration": iteration},
|
||||
)
|
||||
}
|
||||
return fbResult.Response, nil
|
||||
}
|
||||
@@ -723,10 +845,14 @@ func (al *AgentLoop) runLLMIteration(
|
||||
}
|
||||
|
||||
if isContextError && retry < maxRetries {
|
||||
logger.WarnCF("agent", "Context window error detected, attempting compression", map[string]any{
|
||||
"error": err.Error(),
|
||||
"retry": retry,
|
||||
})
|
||||
logger.WarnCF(
|
||||
"agent",
|
||||
"Context window error detected, attempting compression",
|
||||
map[string]any{
|
||||
"error": err.Error(),
|
||||
"retry": retry,
|
||||
},
|
||||
)
|
||||
|
||||
if retry == 0 && !constants.IsInternalChannel(opts.Channel) {
|
||||
al.bus.PublishOutbound(ctx, bus.OutboundMessage{
|
||||
@@ -758,7 +884,12 @@ func (al *AgentLoop) runLLMIteration(
|
||||
return "", iteration, fmt.Errorf("LLM call failed after retries: %w", err)
|
||||
}
|
||||
|
||||
go al.handleReasoning(ctx, response.Reasoning, opts.Channel, al.targetReasoningChannelID(opts.Channel))
|
||||
go al.handleReasoning(
|
||||
ctx,
|
||||
response.Reasoning,
|
||||
opts.Channel,
|
||||
al.targetReasoningChannelID(opts.Channel),
|
||||
)
|
||||
|
||||
logger.DebugCF("agent", "LLM response",
|
||||
map[string]any{
|
||||
@@ -1067,7 +1198,11 @@ func formatMessagesForLog(messages []providers.Message) string {
|
||||
for _, tc := range msg.ToolCalls {
|
||||
fmt.Fprintf(&sb, " - ID: %s, Type: %s, Name: %s\n", tc.ID, tc.Type, tc.Name)
|
||||
if tc.Function != nil {
|
||||
fmt.Fprintf(&sb, " Arguments: %s\n", utils.Truncate(tc.Function.Arguments, 200))
|
||||
fmt.Fprintf(
|
||||
&sb,
|
||||
" Arguments: %s\n",
|
||||
utils.Truncate(tc.Function.Arguments, 200),
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1096,7 +1231,11 @@ func formatToolsForLog(toolDefs []providers.ToolDefinition) string {
|
||||
fmt.Fprintf(&sb, " [%d] Type: %s, Name: %s\n", i, tool.Type, tool.Function.Name)
|
||||
fmt.Fprintf(&sb, " Description: %s\n", tool.Function.Description)
|
||||
if len(tool.Function.Parameters) > 0 {
|
||||
fmt.Fprintf(&sb, " Parameters: %s\n", utils.Truncate(fmt.Sprintf("%v", tool.Function.Parameters), 200))
|
||||
fmt.Fprintf(
|
||||
&sb,
|
||||
" Parameters: %s\n",
|
||||
utils.Truncate(fmt.Sprintf("%v", tool.Function.Parameters), 200),
|
||||
)
|
||||
}
|
||||
}
|
||||
sb.WriteString("]")
|
||||
@@ -1193,7 +1332,9 @@ func (al *AgentLoop) summarizeBatch(
|
||||
existingSummary string,
|
||||
) (string, error) {
|
||||
var sb strings.Builder
|
||||
sb.WriteString("Provide a concise summary of this conversation segment, preserving core context and key points.\n")
|
||||
sb.WriteString(
|
||||
"Provide a concise summary of this conversation segment, preserving core context and key points.\n",
|
||||
)
|
||||
if existingSummary != "" {
|
||||
sb.WriteString("Existing context: ")
|
||||
sb.WriteString(existingSummary)
|
||||
|
||||
@@ -0,0 +1,122 @@
|
||||
// PicoClaw - Ultra-lightweight personal AI agent
|
||||
// Inspired by and based on nanobot: https://github.com/HKUDS/nanobot
|
||||
// License: MIT
|
||||
//
|
||||
// Copyright (c) 2026 PicoClaw contributors
|
||||
|
||||
package agent
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/base64"
|
||||
"io"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"github.com/h2non/filetype"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/logger"
|
||||
"github.com/sipeed/picoclaw/pkg/media"
|
||||
"github.com/sipeed/picoclaw/pkg/providers"
|
||||
)
|
||||
|
||||
// resolveMediaRefs replaces media:// refs in message Media fields with base64 data URLs.
|
||||
// Uses streaming base64 encoding (file handle → encoder → buffer) to avoid holding
|
||||
// both raw bytes and encoded string in memory simultaneously.
|
||||
// Returns a new slice; original messages are not mutated.
|
||||
func resolveMediaRefs(messages []providers.Message, store media.MediaStore, maxSize int) []providers.Message {
|
||||
if store == nil {
|
||||
return messages
|
||||
}
|
||||
|
||||
result := make([]providers.Message, len(messages))
|
||||
copy(result, messages)
|
||||
|
||||
for i, m := range result {
|
||||
if len(m.Media) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
resolved := make([]string, 0, len(m.Media))
|
||||
for _, ref := range m.Media {
|
||||
if !strings.HasPrefix(ref, "media://") {
|
||||
resolved = append(resolved, ref)
|
||||
continue
|
||||
}
|
||||
|
||||
localPath, meta, err := store.ResolveWithMeta(ref)
|
||||
if err != nil {
|
||||
logger.WarnCF("agent", "Failed to resolve media ref", map[string]any{
|
||||
"ref": ref,
|
||||
"error": err.Error(),
|
||||
})
|
||||
continue
|
||||
}
|
||||
|
||||
info, err := os.Stat(localPath)
|
||||
if err != nil {
|
||||
logger.WarnCF("agent", "Failed to stat media file", map[string]any{
|
||||
"path": localPath,
|
||||
"error": err.Error(),
|
||||
})
|
||||
continue
|
||||
}
|
||||
if info.Size() > int64(maxSize) {
|
||||
logger.WarnCF("agent", "Media file too large, skipping", map[string]any{
|
||||
"path": localPath,
|
||||
"size": info.Size(),
|
||||
"max_size": maxSize,
|
||||
})
|
||||
continue
|
||||
}
|
||||
|
||||
// Determine MIME type: prefer metadata, fallback to magic-bytes detection
|
||||
mime := meta.ContentType
|
||||
if mime == "" {
|
||||
kind, ftErr := filetype.MatchFile(localPath)
|
||||
if ftErr != nil || kind == filetype.Unknown {
|
||||
logger.WarnCF("agent", "Unknown media type, skipping", map[string]any{
|
||||
"path": localPath,
|
||||
})
|
||||
continue
|
||||
}
|
||||
mime = kind.MIME.Value
|
||||
}
|
||||
|
||||
// Streaming base64: open file → base64 encoder → buffer
|
||||
// Peak memory: ~1.33x file size (buffer only, no raw bytes copy)
|
||||
f, err := os.Open(localPath)
|
||||
if err != nil {
|
||||
logger.WarnCF("agent", "Failed to open media file", map[string]any{
|
||||
"path": localPath,
|
||||
"error": err.Error(),
|
||||
})
|
||||
continue
|
||||
}
|
||||
|
||||
prefix := "data:" + mime + ";base64,"
|
||||
encodedLen := base64.StdEncoding.EncodedLen(int(info.Size()))
|
||||
var buf bytes.Buffer
|
||||
buf.Grow(len(prefix) + encodedLen)
|
||||
buf.WriteString(prefix)
|
||||
|
||||
encoder := base64.NewEncoder(base64.StdEncoding, &buf)
|
||||
if _, err := io.Copy(encoder, f); err != nil {
|
||||
f.Close()
|
||||
logger.WarnCF("agent", "Failed to encode media file", map[string]any{
|
||||
"path": localPath,
|
||||
"error": err.Error(),
|
||||
})
|
||||
continue
|
||||
}
|
||||
encoder.Close()
|
||||
f.Close()
|
||||
|
||||
resolved = append(resolved, buf.String())
|
||||
}
|
||||
|
||||
result[i].Media = resolved
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
+168
-71
@@ -5,6 +5,7 @@ import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"slices"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
@@ -12,6 +13,7 @@ import (
|
||||
"github.com/sipeed/picoclaw/pkg/bus"
|
||||
"github.com/sipeed/picoclaw/pkg/channels"
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
"github.com/sipeed/picoclaw/pkg/media"
|
||||
"github.com/sipeed/picoclaw/pkg/providers"
|
||||
"github.com/sipeed/picoclaw/pkg/tools"
|
||||
)
|
||||
@@ -27,16 +29,15 @@ func (f *fakeChannel) IsAllowed(string) bool {
|
||||
func (f *fakeChannel) IsAllowedSender(sender bus.SenderInfo) bool { return true }
|
||||
func (f *fakeChannel) ReasoningChannelID() string { return f.id }
|
||||
|
||||
func TestRecordLastChannel(t *testing.T) {
|
||||
// Create temp workspace
|
||||
func newTestAgentLoop(
|
||||
t *testing.T,
|
||||
) (al *AgentLoop, cfg *config.Config, msgBus *bus.MessageBus, provider *mockProvider, cleanup func()) {
|
||||
t.Helper()
|
||||
tmpDir, err := os.MkdirTemp("", "agent-test-*")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temp dir: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
// Create test config
|
||||
cfg := &config.Config{
|
||||
cfg = &config.Config{
|
||||
Agents: config.AgentsConfig{
|
||||
Defaults: config.AgentDefaults{
|
||||
Workspace: tmpDir,
|
||||
@@ -46,74 +47,43 @@ func TestRecordLastChannel(t *testing.T) {
|
||||
},
|
||||
},
|
||||
}
|
||||
msgBus = bus.NewMessageBus()
|
||||
provider = &mockProvider{}
|
||||
al = NewAgentLoop(cfg, msgBus, provider)
|
||||
return al, cfg, msgBus, provider, func() { os.RemoveAll(tmpDir) }
|
||||
}
|
||||
|
||||
// Create agent loop
|
||||
msgBus := bus.NewMessageBus()
|
||||
provider := &mockProvider{}
|
||||
al := NewAgentLoop(cfg, msgBus, provider)
|
||||
func TestRecordLastChannel(t *testing.T) {
|
||||
al, cfg, msgBus, provider, cleanup := newTestAgentLoop(t)
|
||||
defer cleanup()
|
||||
|
||||
// Test RecordLastChannel
|
||||
testChannel := "test-channel"
|
||||
err = al.RecordLastChannel(testChannel)
|
||||
if err != nil {
|
||||
if err := al.RecordLastChannel(testChannel); err != nil {
|
||||
t.Fatalf("RecordLastChannel failed: %v", err)
|
||||
}
|
||||
|
||||
// Verify channel was saved
|
||||
lastChannel := al.state.GetLastChannel()
|
||||
if lastChannel != testChannel {
|
||||
t.Errorf("Expected channel '%s', got '%s'", testChannel, lastChannel)
|
||||
if got := al.state.GetLastChannel(); got != testChannel {
|
||||
t.Errorf("Expected channel '%s', got '%s'", testChannel, got)
|
||||
}
|
||||
|
||||
// Verify persistence by creating a new agent loop
|
||||
al2 := NewAgentLoop(cfg, msgBus, provider)
|
||||
if al2.state.GetLastChannel() != testChannel {
|
||||
t.Errorf("Expected persistent channel '%s', got '%s'", testChannel, al2.state.GetLastChannel())
|
||||
if got := al2.state.GetLastChannel(); got != testChannel {
|
||||
t.Errorf("Expected persistent channel '%s', got '%s'", testChannel, got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRecordLastChatID(t *testing.T) {
|
||||
// Create temp workspace
|
||||
tmpDir, err := os.MkdirTemp("", "agent-test-*")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temp dir: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tmpDir)
|
||||
al, cfg, msgBus, provider, cleanup := newTestAgentLoop(t)
|
||||
defer cleanup()
|
||||
|
||||
// Create test config
|
||||
cfg := &config.Config{
|
||||
Agents: config.AgentsConfig{
|
||||
Defaults: config.AgentDefaults{
|
||||
Workspace: tmpDir,
|
||||
Model: "test-model",
|
||||
MaxTokens: 4096,
|
||||
MaxToolIterations: 10,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// Create agent loop
|
||||
msgBus := bus.NewMessageBus()
|
||||
provider := &mockProvider{}
|
||||
al := NewAgentLoop(cfg, msgBus, provider)
|
||||
|
||||
// Test RecordLastChatID
|
||||
testChatID := "test-chat-id-123"
|
||||
err = al.RecordLastChatID(testChatID)
|
||||
if err != nil {
|
||||
if err := al.RecordLastChatID(testChatID); err != nil {
|
||||
t.Fatalf("RecordLastChatID failed: %v", err)
|
||||
}
|
||||
|
||||
// Verify chat ID was saved
|
||||
lastChatID := al.state.GetLastChatID()
|
||||
if lastChatID != testChatID {
|
||||
t.Errorf("Expected chat ID '%s', got '%s'", testChatID, lastChatID)
|
||||
if got := al.state.GetLastChatID(); got != testChatID {
|
||||
t.Errorf("Expected chat ID '%s', got '%s'", testChatID, got)
|
||||
}
|
||||
|
||||
// Verify persistence by creating a new agent loop
|
||||
al2 := NewAgentLoop(cfg, msgBus, provider)
|
||||
if al2.state.GetLastChatID() != testChatID {
|
||||
t.Errorf("Expected persistent chat ID '%s', got '%s'", testChatID, al2.state.GetLastChatID())
|
||||
if got := al2.state.GetLastChatID(); got != testChatID {
|
||||
t.Errorf("Expected persistent chat ID '%s', got '%s'", testChatID, got)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -188,13 +158,7 @@ func TestToolRegistry_ToolRegistration(t *testing.T) {
|
||||
toolsList := toolsInfo["names"].([]string)
|
||||
|
||||
// Check that our custom tool name is in the list
|
||||
found := false
|
||||
for _, name := range toolsList {
|
||||
if name == "mock_custom" {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
found := slices.Contains(toolsList, "mock_custom")
|
||||
if !found {
|
||||
t.Error("Expected custom tool to be registered")
|
||||
}
|
||||
@@ -263,13 +227,7 @@ func TestToolRegistry_GetDefinitions(t *testing.T) {
|
||||
toolsList := toolsInfo["names"].([]string)
|
||||
|
||||
// Check that our custom tool name is in the list
|
||||
found := false
|
||||
for _, name := range toolsList {
|
||||
if name == "mock_custom" {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
found := slices.Contains(toolsList, "mock_custom")
|
||||
if !found {
|
||||
t.Error("Expected custom tool to be registered")
|
||||
}
|
||||
@@ -931,3 +889,142 @@ func TestHandleReasoning(t *testing.T) {
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestResolveMediaRefs_ResolvesToBase64(t *testing.T) {
|
||||
store := media.NewFileMediaStore()
|
||||
dir := t.TempDir()
|
||||
|
||||
// Create a minimal valid PNG (8-byte header is enough for filetype detection)
|
||||
pngPath := filepath.Join(dir, "test.png")
|
||||
// PNG magic: 0x89 P N G \r \n 0x1A \n + minimal IHDR
|
||||
pngHeader := []byte{
|
||||
0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A, // PNG signature
|
||||
0x00, 0x00, 0x00, 0x0D, // IHDR length
|
||||
0x49, 0x48, 0x44, 0x52, // "IHDR"
|
||||
0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x01, 0x08, 0x02, // 1x1 RGB
|
||||
0x00, 0x00, 0x00, // no interlace
|
||||
0x90, 0x77, 0x53, 0xDE, // CRC
|
||||
}
|
||||
if err := os.WriteFile(pngPath, pngHeader, 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
ref, err := store.Store(pngPath, media.MediaMeta{}, "test")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
messages := []providers.Message{
|
||||
{Role: "user", Content: "describe this", Media: []string{ref}},
|
||||
}
|
||||
result := resolveMediaRefs(messages, store, config.DefaultMaxMediaSize)
|
||||
|
||||
if len(result[0].Media) != 1 {
|
||||
t.Fatalf("expected 1 resolved media, got %d", len(result[0].Media))
|
||||
}
|
||||
if !strings.HasPrefix(result[0].Media[0], "data:image/png;base64,") {
|
||||
t.Fatalf("expected data:image/png;base64, prefix, got %q", result[0].Media[0][:40])
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveMediaRefs_SkipsOversizedFile(t *testing.T) {
|
||||
store := media.NewFileMediaStore()
|
||||
dir := t.TempDir()
|
||||
|
||||
bigPath := filepath.Join(dir, "big.png")
|
||||
// Write PNG header + padding to exceed limit
|
||||
data := make([]byte, 1024+1) // 1KB + 1 byte
|
||||
copy(data, []byte{0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A})
|
||||
if err := os.WriteFile(bigPath, data, 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
ref, _ := store.Store(bigPath, media.MediaMeta{}, "test")
|
||||
|
||||
messages := []providers.Message{
|
||||
{Role: "user", Content: "hi", Media: []string{ref}},
|
||||
}
|
||||
// Use a tiny limit (1KB) so the file is oversized
|
||||
result := resolveMediaRefs(messages, store, 1024)
|
||||
|
||||
if len(result[0].Media) != 0 {
|
||||
t.Fatalf("expected 0 media (oversized), got %d", len(result[0].Media))
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveMediaRefs_SkipsUnknownType(t *testing.T) {
|
||||
store := media.NewFileMediaStore()
|
||||
dir := t.TempDir()
|
||||
|
||||
txtPath := filepath.Join(dir, "readme.txt")
|
||||
if err := os.WriteFile(txtPath, []byte("hello world"), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
ref, _ := store.Store(txtPath, media.MediaMeta{}, "test")
|
||||
|
||||
messages := []providers.Message{
|
||||
{Role: "user", Content: "hi", Media: []string{ref}},
|
||||
}
|
||||
result := resolveMediaRefs(messages, store, config.DefaultMaxMediaSize)
|
||||
|
||||
if len(result[0].Media) != 0 {
|
||||
t.Fatalf("expected 0 media (unknown type), got %d", len(result[0].Media))
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveMediaRefs_PassesThroughNonMediaRefs(t *testing.T) {
|
||||
messages := []providers.Message{
|
||||
{Role: "user", Content: "hi", Media: []string{"https://example.com/img.png"}},
|
||||
}
|
||||
result := resolveMediaRefs(messages, nil, config.DefaultMaxMediaSize)
|
||||
|
||||
if len(result[0].Media) != 1 || result[0].Media[0] != "https://example.com/img.png" {
|
||||
t.Fatalf("expected passthrough of non-media:// URL, got %v", result[0].Media)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveMediaRefs_DoesNotMutateOriginal(t *testing.T) {
|
||||
store := media.NewFileMediaStore()
|
||||
dir := t.TempDir()
|
||||
pngPath := filepath.Join(dir, "test.png")
|
||||
pngHeader := []byte{
|
||||
0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A,
|
||||
0x00, 0x00, 0x00, 0x0D, 0x49, 0x48, 0x44, 0x52,
|
||||
0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x01, 0x08, 0x02,
|
||||
0x00, 0x00, 0x00, 0x90, 0x77, 0x53, 0xDE,
|
||||
}
|
||||
os.WriteFile(pngPath, pngHeader, 0o644)
|
||||
ref, _ := store.Store(pngPath, media.MediaMeta{}, "test")
|
||||
|
||||
original := []providers.Message{
|
||||
{Role: "user", Content: "hi", Media: []string{ref}},
|
||||
}
|
||||
originalRef := original[0].Media[0]
|
||||
|
||||
resolveMediaRefs(original, store, config.DefaultMaxMediaSize)
|
||||
|
||||
if original[0].Media[0] != originalRef {
|
||||
t.Fatal("resolveMediaRefs mutated original message slice")
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveMediaRefs_UsesMetaContentType(t *testing.T) {
|
||||
store := media.NewFileMediaStore()
|
||||
dir := t.TempDir()
|
||||
|
||||
// File with JPEG content but stored with explicit content type
|
||||
jpegPath := filepath.Join(dir, "photo")
|
||||
jpegHeader := []byte{0xFF, 0xD8, 0xFF, 0xE0} // JPEG magic bytes
|
||||
os.WriteFile(jpegPath, jpegHeader, 0o644)
|
||||
ref, _ := store.Store(jpegPath, media.MediaMeta{ContentType: "image/jpeg"}, "test")
|
||||
|
||||
messages := []providers.Message{
|
||||
{Role: "user", Content: "hi", Media: []string{ref}},
|
||||
}
|
||||
result := resolveMediaRefs(messages, store, config.DefaultMaxMediaSize)
|
||||
|
||||
if len(result[0].Media) != 1 {
|
||||
t.Fatalf("expected 1 media, got %d", len(result[0].Media))
|
||||
}
|
||||
if !strings.HasPrefix(result[0].Media[0], "data:image/jpeg;base64,") {
|
||||
t.Fatalf("expected jpeg prefix, got %q", result[0].Media[0][:30])
|
||||
}
|
||||
}
|
||||
|
||||
+1
-1
@@ -111,7 +111,7 @@ func (ms *MemoryStore) GetRecentDailyNotes(days int) string {
|
||||
var sb strings.Builder
|
||||
first := true
|
||||
|
||||
for i := 0; i < days; i++ {
|
||||
for i := range days {
|
||||
date := time.Now().AddDate(0, 0, -i)
|
||||
dateStr := date.Format("20060102") // YYYYMMDD
|
||||
monthDir := dateStr[:6] // YYYYMM
|
||||
|
||||
Reference in New Issue
Block a user