Merge upstream/main into fix/bugfixes

Resolve conflicts:
- provider.go: keep upstream's serializeMessages (supersedes stripSystemParts)
- provider_test.go: keep upstream's serializeMessages tests
- loop_test.go: add slices import needed by upstream tests
- shell.go: merge PR's --format deny fix with upstream's block device
  pattern, safePaths, and absolutePathPattern
- shell_test.go: include tests from both branches

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
I Putu Eddy Irawan
2026-03-03 21:55:26 +07:00
119 changed files with 8055 additions and 1855 deletions
+126 -62
View File
@@ -7,6 +7,7 @@ import (
"os"
"path/filepath"
"runtime"
"slices"
"strings"
"sync"
"time"
@@ -33,6 +34,11 @@ type ContextBuilder struct {
// created (didn't exist at cache time, now exist) or deleted (existed at
// cache time, now gone) — both of which should trigger a cache rebuild.
existedAtCache map[string]bool
// skillFilesAtCache snapshots the skill tree file set and mtimes at cache
// build time. This catches nested file creations/deletions/mtime changes
// that may not update the top-level skill root directory mtime.
skillFilesAtCache map[string]time.Time
}
func getGlobalConfigDir() string {
@@ -46,8 +52,11 @@ func getGlobalConfigDir() string {
func NewContextBuilder(workspace string) *ContextBuilder {
// builtin skills: skills directory in current project
// Use the skills/ directory under the current working directory
wd, _ := os.Getwd()
builtinSkillsDir := filepath.Join(wd, "skills")
builtinSkillsDir := strings.TrimSpace(os.Getenv("PICOCLAW_BUILTIN_SKILLS"))
if builtinSkillsDir == "" {
wd, _ := os.Getwd()
builtinSkillsDir = filepath.Join(wd, "skills")
}
globalSkillsDir := filepath.Join(getGlobalConfigDir(), "skills")
return &ContextBuilder{
@@ -147,6 +156,7 @@ func (cb *ContextBuilder) BuildSystemPromptWithCache() string {
cb.cachedSystemPrompt = prompt
cb.cachedAt = baseline.maxMtime
cb.existedAtCache = baseline.existed
cb.skillFilesAtCache = baseline.skillFiles
logger.DebugCF("agent", "System prompt cached",
map[string]any{
@@ -166,14 +176,14 @@ func (cb *ContextBuilder) InvalidateCache() {
cb.cachedSystemPrompt = ""
cb.cachedAt = time.Time{}
cb.existedAtCache = nil
cb.skillFilesAtCache = nil
logger.DebugCF("agent", "System prompt cache invalidated", nil)
}
// sourcePaths returns the workspace source file paths tracked for cache
// invalidation (bootstrap files + memory). The skills directory is handled
// separately in sourceFilesChangedLocked because it requires both directory-
// level and recursive file-level mtime checks.
// sourcePaths returns non-skill workspace source files tracked for cache
// invalidation (bootstrap files + memory). Skill roots are handled separately
// because they require both directory-level and recursive file-level checks.
func (cb *ContextBuilder) sourcePaths() []string {
return []string{
filepath.Join(cb.workspace, "AGENTS.md"),
@@ -184,23 +194,39 @@ func (cb *ContextBuilder) sourcePaths() []string {
}
}
// skillRoots returns all skill root directories that can affect
// BuildSkillsSummary output (workspace/global/builtin).
func (cb *ContextBuilder) skillRoots() []string {
if cb.skillsLoader == nil {
return []string{filepath.Join(cb.workspace, "skills")}
}
roots := cb.skillsLoader.SkillRoots()
if len(roots) == 0 {
return []string{filepath.Join(cb.workspace, "skills")}
}
return roots
}
// cacheBaseline holds the file existence snapshot and the latest observed
// mtime across all tracked paths. Used as the cache reference point.
type cacheBaseline struct {
existed map[string]bool
maxMtime time.Time
existed map[string]bool
skillFiles map[string]time.Time
maxMtime time.Time
}
// buildCacheBaseline records which tracked paths currently exist and computes
// the latest mtime across all tracked files + skills directory contents.
// Called under write lock when the cache is built.
func (cb *ContextBuilder) buildCacheBaseline() cacheBaseline {
skillsDir := filepath.Join(cb.workspace, "skills")
skillRoots := cb.skillRoots()
// All paths whose existence we track: source files + skills dir.
allPaths := append(cb.sourcePaths(), skillsDir)
// All paths whose existence we track: source files + all skill roots.
allPaths := append(cb.sourcePaths(), skillRoots...)
existed := make(map[string]bool, len(allPaths))
skillFiles := make(map[string]time.Time)
var maxMtime time.Time
for _, p := range allPaths {
@@ -211,17 +237,21 @@ func (cb *ContextBuilder) buildCacheBaseline() cacheBaseline {
}
}
// Walk skills files to capture their mtimes too.
// Use os.Stat (not d.Info) to match the stat method used in
// fileChangedSince / skillFilesModifiedSince for consistency.
_ = filepath.WalkDir(skillsDir, func(path string, d fs.DirEntry, walkErr error) error {
if walkErr == nil && !d.IsDir() {
if info, err := os.Stat(path); err == nil && info.ModTime().After(maxMtime) {
maxMtime = info.ModTime()
// Walk all skill roots recursively to snapshot skill files and mtimes.
// Use os.Stat (not d.Info) for consistency with sourceFilesChanged checks.
for _, root := range skillRoots {
_ = filepath.WalkDir(root, func(path string, d fs.DirEntry, walkErr error) error {
if walkErr == nil && !d.IsDir() {
if info, err := os.Stat(path); err == nil {
skillFiles[path] = info.ModTime()
if info.ModTime().After(maxMtime) {
maxMtime = info.ModTime()
}
}
}
}
return nil
})
return nil
})
}
// If no tracked files exist yet (empty workspace), maxMtime is zero.
// Use a very old non-zero time so that:
@@ -233,7 +263,7 @@ func (cb *ContextBuilder) buildCacheBaseline() cacheBaseline {
maxMtime = time.Unix(1, 0)
}
return cacheBaseline{existed: existed, maxMtime: maxMtime}
return cacheBaseline{existed: existed, skillFiles: skillFiles, maxMtime: maxMtime}
}
// sourceFilesChangedLocked checks whether any workspace source file has been
@@ -249,27 +279,21 @@ func (cb *ContextBuilder) sourceFilesChangedLocked() bool {
}
// Check tracked source files (bootstrap + memory).
for _, p := range cb.sourcePaths() {
if cb.fileChangedSince(p) {
return true
}
}
// --- Skills directory (handled separately from sourcePaths) ---
//
// 1. Creation/deletion: tracked via existedAtCache, same as bootstrap files.
skillsDir := filepath.Join(cb.workspace, "skills")
if cb.fileChangedSince(skillsDir) {
if slices.ContainsFunc(cb.sourcePaths(), cb.fileChangedSince) {
return true
}
// 2. Structural changes (add/remove entries inside the dir) are reflected
// in the directory's own mtime, which fileChangedSince already checks.
// --- Skill roots (workspace/global/builtin) ---
//
// 3. Content-only edits to files inside skills/ do NOT update the parent
// directory mtime on most filesystems, so we recursively walk to check
// individual file mtimes at any nesting depth.
if skillFilesModifiedSince(skillsDir, cb.cachedAt) {
// For each root:
// 1. Creation/deletion and root directory mtime changes are tracked by fileChangedSince.
// 2. Nested file create/delete/mtime changes are tracked by the skill file snapshot.
for _, root := range cb.skillRoots() {
if cb.fileChangedSince(root) {
return true
}
}
if skillFilesChangedSince(cb.skillRoots(), cb.skillFilesAtCache) {
return true
}
@@ -310,28 +334,64 @@ func (cb *ContextBuilder) fileChangedSince(path string) bool {
// if the callback returned nil when its err parameter is non-nil.
var errWalkStop = errors.New("walk stop")
// skillFilesModifiedSince recursively walks the skills directory and checks
// whether any file was modified after t. This catches content-only edits at
// any nesting depth (e.g. skills/name/docs/extra.md) that don't update
// parent directory mtimes.
func skillFilesModifiedSince(skillsDir string, t time.Time) bool {
changed := false
err := filepath.WalkDir(skillsDir, func(path string, d fs.DirEntry, walkErr error) error {
if walkErr == nil && !d.IsDir() {
if info, statErr := os.Stat(path); statErr == nil && info.ModTime().After(t) {
changed = true
return errWalkStop // stop walking
}
}
return nil
})
// errWalkStop is expected (early exit on first changed file).
// os.IsNotExist means the skills dir doesn't exist yet — not an error.
// Any other error is unexpected and worth logging.
if err != nil && !errors.Is(err, errWalkStop) && !os.IsNotExist(err) {
logger.DebugCF("agent", "skills walk error", map[string]any{"error": err.Error()})
// skillFilesChangedSince compares the current recursive skill file tree
// against the cache-time snapshot. Any create/delete/mtime drift invalidates
// the cache.
func skillFilesChangedSince(skillRoots []string, filesAtCache map[string]time.Time) bool {
// Defensive: if the snapshot was never initialized, force rebuild.
if filesAtCache == nil {
return true
}
return changed
// Check cached files still exist and keep the same mtime.
for path, cachedMtime := range filesAtCache {
info, err := os.Stat(path)
if err != nil {
// A previously tracked file disappeared (or became inaccessible):
// either way, cached skill summary may now be stale.
return true
}
if !info.ModTime().Equal(cachedMtime) {
return true
}
}
// Check no new files appeared under any skill root.
changed := false
for _, root := range skillRoots {
if strings.TrimSpace(root) == "" {
continue
}
err := filepath.WalkDir(root, func(path string, d fs.DirEntry, walkErr error) error {
if walkErr != nil {
// Treat unexpected walk errors as changed to avoid stale cache.
if !os.IsNotExist(walkErr) {
changed = true
return errWalkStop
}
return nil
}
if d.IsDir() {
return nil
}
if _, ok := filesAtCache[path]; !ok {
changed = true
return errWalkStop
}
return nil
})
if changed {
return true
}
if err != nil && !errors.Is(err, errWalkStop) && !os.IsNotExist(err) {
logger.DebugCF("agent", "skills walk error", map[string]any{"error": err.Error()})
return true
}
}
return false
}
func (cb *ContextBuilder) LoadBootstrapFiles() string {
@@ -467,10 +527,14 @@ func (cb *ContextBuilder) BuildMessages(
// Add current user message
if strings.TrimSpace(currentMessage) != "" {
messages = append(messages, providers.Message{
msg := providers.Message{
Role: "user",
Content: currentMessage,
})
}
if len(media) > 0 {
msg.Media = media
}
messages = append(messages, msg)
}
return messages
+158 -2
View File
@@ -383,6 +383,162 @@ Updated content.`
}
}
// TestGlobalSkillFileContentChange verifies that modifying a global skill
// (~/.picoclaw/skills) invalidates the cached system prompt.
func TestGlobalSkillFileContentChange(t *testing.T) {
tmpHome := t.TempDir()
t.Setenv("HOME", tmpHome)
tmpDir := setupWorkspace(t, nil)
defer os.RemoveAll(tmpDir)
globalSkillPath := filepath.Join(tmpHome, ".picoclaw", "skills", "global-skill", "SKILL.md")
if err := os.MkdirAll(filepath.Dir(globalSkillPath), 0o755); err != nil {
t.Fatal(err)
}
v1 := `---
name: global-skill
description: global-v1
---
# Global Skill v1`
if err := os.WriteFile(globalSkillPath, []byte(v1), 0o644); err != nil {
t.Fatal(err)
}
cb := NewContextBuilder(tmpDir)
sp1 := cb.BuildSystemPromptWithCache()
if !strings.Contains(sp1, "global-v1") {
t.Fatal("expected initial prompt to contain global skill description")
}
v2 := `---
name: global-skill
description: global-v2
---
# Global Skill v2`
if err := os.WriteFile(globalSkillPath, []byte(v2), 0o644); err != nil {
t.Fatal(err)
}
future := time.Now().Add(2 * time.Second)
if err := os.Chtimes(globalSkillPath, future, future); err != nil {
t.Fatalf("failed to update mtime for %s: %v", globalSkillPath, err)
}
cb.systemPromptMutex.RLock()
changed := cb.sourceFilesChangedLocked()
cb.systemPromptMutex.RUnlock()
if !changed {
t.Fatal("sourceFilesChangedLocked() should detect global skill file content change")
}
sp2 := cb.BuildSystemPromptWithCache()
if !strings.Contains(sp2, "global-v2") {
t.Error("rebuilt prompt should contain updated global skill description")
}
if sp1 == sp2 {
t.Error("cache should be invalidated when global skill file content changes")
}
}
// TestBuiltinSkillFileContentChange verifies that modifying a builtin skill
// invalidates the cached system prompt.
func TestBuiltinSkillFileContentChange(t *testing.T) {
tmpHome := t.TempDir()
t.Setenv("HOME", tmpHome)
tmpDir := setupWorkspace(t, nil)
defer os.RemoveAll(tmpDir)
builtinRoot := t.TempDir()
t.Setenv("PICOCLAW_BUILTIN_SKILLS", builtinRoot)
builtinSkillPath := filepath.Join(builtinRoot, "builtin-skill", "SKILL.md")
if err := os.MkdirAll(filepath.Dir(builtinSkillPath), 0o755); err != nil {
t.Fatal(err)
}
v1 := `---
name: builtin-skill
description: builtin-v1
---
# Builtin Skill v1`
if err := os.WriteFile(builtinSkillPath, []byte(v1), 0o644); err != nil {
t.Fatal(err)
}
cb := NewContextBuilder(tmpDir)
sp1 := cb.BuildSystemPromptWithCache()
if !strings.Contains(sp1, "builtin-v1") {
t.Fatal("expected initial prompt to contain builtin skill description")
}
v2 := `---
name: builtin-skill
description: builtin-v2
---
# Builtin Skill v2`
if err := os.WriteFile(builtinSkillPath, []byte(v2), 0o644); err != nil {
t.Fatal(err)
}
future := time.Now().Add(2 * time.Second)
if err := os.Chtimes(builtinSkillPath, future, future); err != nil {
t.Fatalf("failed to update mtime for %s: %v", builtinSkillPath, err)
}
cb.systemPromptMutex.RLock()
changed := cb.sourceFilesChangedLocked()
cb.systemPromptMutex.RUnlock()
if !changed {
t.Fatal("sourceFilesChangedLocked() should detect builtin skill file content change")
}
sp2 := cb.BuildSystemPromptWithCache()
if !strings.Contains(sp2, "builtin-v2") {
t.Error("rebuilt prompt should contain updated builtin skill description")
}
if sp1 == sp2 {
t.Error("cache should be invalidated when builtin skill file content changes")
}
}
// TestSkillFileDeletionInvalidatesCache verifies that deleting a nested skill
// file invalidates the cached system prompt.
func TestSkillFileDeletionInvalidatesCache(t *testing.T) {
tmpDir := setupWorkspace(t, map[string]string{
"skills/delete-me/SKILL.md": `---
name: delete-me
description: delete-me-v1
---
# Delete Me`,
})
defer os.RemoveAll(tmpDir)
cb := NewContextBuilder(tmpDir)
sp1 := cb.BuildSystemPromptWithCache()
if !strings.Contains(sp1, "delete-me-v1") {
t.Fatal("expected initial prompt to contain skill description")
}
skillPath := filepath.Join(tmpDir, "skills", "delete-me", "SKILL.md")
if err := os.Remove(skillPath); err != nil {
t.Fatal(err)
}
cb.systemPromptMutex.RLock()
changed := cb.sourceFilesChangedLocked()
cb.systemPromptMutex.RUnlock()
if !changed {
t.Fatal("sourceFilesChangedLocked() should detect deleted skill file")
}
sp2 := cb.BuildSystemPromptWithCache()
if strings.Contains(sp2, "delete-me-v1") {
t.Error("rebuilt prompt should not contain deleted skill description")
}
if sp1 == sp2 {
t.Error("cache should be invalidated when skill file is deleted")
}
}
// TestConcurrentBuildSystemPromptWithCache verifies that multiple goroutines
// can safely call BuildSystemPromptWithCache concurrently without producing
// empty results, panics, or data races.
@@ -404,11 +560,11 @@ func TestConcurrentBuildSystemPromptWithCache(t *testing.T) {
var wg sync.WaitGroup
errs := make(chan string, goroutines*iterations)
for g := 0; g < goroutines; g++ {
for g := range goroutines {
wg.Add(1)
go func(id int) {
defer wg.Done()
for i := 0; i < iterations; i++ {
for i := range iterations {
result := cb.BuildSystemPromptWithCache()
if result == "" {
errs <- "empty prompt returned"
+26 -5
View File
@@ -1,9 +1,11 @@
package agent
import (
"fmt"
"log"
"os"
"path/filepath"
"regexp"
"strings"
"github.com/sipeed/picoclaw/pkg/config"
@@ -48,18 +50,24 @@ func NewAgentInstance(
fallbacks := resolveAgentFallbacks(agentCfg, defaults)
restrict := defaults.RestrictToWorkspace
readRestrict := restrict && !defaults.AllowReadOutsideWorkspace
// Compile path whitelist patterns from config.
allowReadPaths := compilePatterns(cfg.Tools.AllowReadPaths)
allowWritePaths := compilePatterns(cfg.Tools.AllowWritePaths)
toolsRegistry := tools.NewToolRegistry()
toolsRegistry.Register(tools.NewReadFileTool(workspace, restrict))
toolsRegistry.Register(tools.NewWriteFileTool(workspace, restrict))
toolsRegistry.Register(tools.NewListDirTool(workspace, restrict))
toolsRegistry.Register(tools.NewReadFileTool(workspace, readRestrict, allowReadPaths))
toolsRegistry.Register(tools.NewWriteFileTool(workspace, restrict, allowWritePaths))
toolsRegistry.Register(tools.NewListDirTool(workspace, readRestrict, allowReadPaths))
execTool, err := tools.NewExecToolWithConfig(workspace, restrict, cfg)
if err != nil {
log.Fatalf("Critical error: unable to initialize exec tool: %v", err)
}
toolsRegistry.Register(execTool)
toolsRegistry.Register(tools.NewEditFileTool(workspace, restrict))
toolsRegistry.Register(tools.NewAppendFileTool(workspace, restrict))
toolsRegistry.Register(tools.NewEditFileTool(workspace, restrict, allowWritePaths))
toolsRegistry.Register(tools.NewAppendFileTool(workspace, restrict, allowWritePaths))
sessionsDir := filepath.Join(workspace, "sessions")
sessionsManager := session.NewSessionManager(sessionsDir)
@@ -189,6 +197,19 @@ func resolveAgentFallbacks(agentCfg *config.AgentConfig, defaults *config.AgentD
return defaults.ModelFallbacks
}
func compilePatterns(patterns []string) []*regexp.Regexp {
compiled := make([]*regexp.Regexp, 0, len(patterns))
for _, p := range patterns {
re, err := regexp.Compile(p)
if err != nil {
fmt.Printf("Warning: invalid path pattern %q: %v\n", p, err)
continue
}
compiled = append(compiled, re)
}
return compiled
}
func expandHome(path string) string {
if path == "" {
return path
+58 -65
View File
@@ -95,75 +95,68 @@ func TestNewAgentInstance_DefaultsTemperatureWhenUnset(t *testing.T) {
}
func TestNewAgentInstance_ResolveCandidatesFromModelListAlias(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "agent-instance-test-*")
if err != nil {
t.Fatalf("Failed to create temp dir: %v", err)
}
defer os.RemoveAll(tmpDir)
cfg := &config.Config{
Agents: config.AgentsConfig{
Defaults: config.AgentDefaults{
Workspace: tmpDir,
Model: "step-3.5-flash",
},
tests := []struct {
name string
aliasName string
modelName string
apiBase string
wantProvider string
wantModel string
}{
{
name: "alias with provider prefix",
aliasName: "step-3.5-flash",
modelName: "openrouter/stepfun/step-3.5-flash:free",
apiBase: "https://openrouter.ai/api/v1",
wantProvider: "openrouter",
wantModel: "stepfun/step-3.5-flash:free",
},
ModelList: []config.ModelConfig{
{
ModelName: "step-3.5-flash",
Model: "openrouter/stepfun/step-3.5-flash:free",
APIBase: "https://openrouter.ai/api/v1",
},
{
name: "alias without provider prefix",
aliasName: "glm-5",
modelName: "glm-5",
apiBase: "https://api.z.ai/api/coding/paas/v4",
wantProvider: "openai",
wantModel: "glm-5",
},
}
provider := &mockProvider{}
agent := NewAgentInstance(nil, &cfg.Agents.Defaults, cfg, provider)
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "agent-instance-test-*")
if err != nil {
t.Fatalf("Failed to create temp dir: %v", err)
}
defer os.RemoveAll(tmpDir)
if len(agent.Candidates) != 1 {
t.Fatalf("len(Candidates) = %d, want 1", len(agent.Candidates))
}
if agent.Candidates[0].Provider != "openrouter" {
t.Fatalf("candidate provider = %q, want %q", agent.Candidates[0].Provider, "openrouter")
}
if agent.Candidates[0].Model != "stepfun/step-3.5-flash:free" {
t.Fatalf("candidate model = %q, want %q", agent.Candidates[0].Model, "stepfun/step-3.5-flash:free")
}
}
func TestNewAgentInstance_ResolveCandidatesFromModelListAliasWithoutProtocol(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "agent-instance-test-*")
if err != nil {
t.Fatalf("Failed to create temp dir: %v", err)
}
defer os.RemoveAll(tmpDir)
cfg := &config.Config{
Agents: config.AgentsConfig{
Defaults: config.AgentDefaults{
Workspace: tmpDir,
Model: "glm-5",
},
},
ModelList: []config.ModelConfig{
{
ModelName: "glm-5",
Model: "glm-5",
APIBase: "https://api.z.ai/api/coding/paas/v4",
},
},
}
provider := &mockProvider{}
agent := NewAgentInstance(nil, &cfg.Agents.Defaults, cfg, provider)
if len(agent.Candidates) != 1 {
t.Fatalf("len(Candidates) = %d, want 1", len(agent.Candidates))
}
if agent.Candidates[0].Provider != "openai" {
t.Fatalf("candidate provider = %q, want %q", agent.Candidates[0].Provider, "openai")
}
if agent.Candidates[0].Model != "glm-5" {
t.Fatalf("candidate model = %q, want %q", agent.Candidates[0].Model, "glm-5")
cfg := &config.Config{
Agents: config.AgentsConfig{
Defaults: config.AgentDefaults{
Workspace: tmpDir,
Model: tt.aliasName,
},
},
ModelList: []config.ModelConfig{
{
ModelName: tt.aliasName,
Model: tt.modelName,
APIBase: tt.apiBase,
},
},
}
provider := &mockProvider{}
agent := NewAgentInstance(nil, &cfg.Agents.Defaults, cfg, provider)
if len(agent.Candidates) != 1 {
t.Fatalf("len(Candidates) = %d, want 1", len(agent.Candidates))
}
if agent.Candidates[0].Provider != tt.wantProvider {
t.Fatalf("candidate provider = %q, want %q", agent.Candidates[0].Provider, tt.wantProvider)
}
if agent.Candidates[0].Model != tt.wantModel {
t.Fatalf("candidate model = %q, want %q", agent.Candidates[0].Model, tt.wantModel)
}
})
}
}
+180 -39
View File
@@ -23,6 +23,7 @@ import (
"github.com/sipeed/picoclaw/pkg/config"
"github.com/sipeed/picoclaw/pkg/constants"
"github.com/sipeed/picoclaw/pkg/logger"
"github.com/sipeed/picoclaw/pkg/mcp"
"github.com/sipeed/picoclaw/pkg/media"
"github.com/sipeed/picoclaw/pkg/providers"
"github.com/sipeed/picoclaw/pkg/routing"
@@ -46,19 +47,24 @@ type AgentLoop struct {
// processOptions configures how a message is processed
type processOptions struct {
SessionKey string // Session identifier for history/context
Channel string // Target channel for tool execution
ChatID string // Target chat ID for tool execution
UserMessage string // User message content (may include prefix)
DefaultResponse string // Response when LLM returns empty
EnableSummary bool // Whether to trigger summarization
SendResponse bool // Whether to send response via bus
NoHistory bool // If true, don't load session history (for heartbeat)
SessionKey string // Session identifier for history/context
Channel string // Target channel for tool execution
ChatID string // Target chat ID for tool execution
UserMessage string // User message content (may include prefix)
Media []string // media:// refs from inbound message
DefaultResponse string // Response when LLM returns empty
EnableSummary bool // Whether to trigger summarization
SendResponse bool // Whether to send response via bus
NoHistory bool // If true, don't load session history (for heartbeat)
}
const defaultResponse = "I've completed processing but have no response to give. Increase `max_tool_iterations` in config.json."
func NewAgentLoop(cfg *config.Config, msgBus *bus.MessageBus, provider providers.LLMProvider) *AgentLoop {
func NewAgentLoop(
cfg *config.Config,
msgBus *bus.MessageBus,
provider providers.LLMProvider,
) *AgentLoop {
registry := NewAgentRegistry(cfg, provider)
// Register shared tools to all agents
@@ -99,7 +105,7 @@ func registerSharedTools(
}
// Web tools
if searchTool := tools.NewWebSearchTool(tools.WebSearchToolOptions{
searchTool, err := tools.NewWebSearchTool(tools.WebSearchToolOptions{
BraveAPIKey: cfg.Tools.Web.Brave.APIKey,
BraveMaxResults: cfg.Tools.Web.Brave.MaxResults,
BraveEnabled: cfg.Tools.Web.Brave.Enabled,
@@ -113,10 +119,18 @@ func registerSharedTools(
PerplexityMaxResults: cfg.Tools.Web.Perplexity.MaxResults,
PerplexityEnabled: cfg.Tools.Web.Perplexity.Enabled,
Proxy: cfg.Tools.Web.Proxy,
}); searchTool != nil {
})
if err != nil {
logger.ErrorCF("agent", "Failed to create web search tool", map[string]any{"error": err.Error()})
} else if searchTool != nil {
agent.Tools.Register(searchTool)
}
agent.Tools.Register(tools.NewWebFetchToolWithProxy(50000, cfg.Tools.Web.Proxy))
fetchTool, err := tools.NewWebFetchToolWithProxy(50000, cfg.Tools.Web.Proxy, cfg.Tools.Web.FetchLimitBytes)
if err != nil {
logger.ErrorCF("agent", "Failed to create web fetch tool", map[string]any{"error": err.Error()})
} else {
agent.Tools.Register(fetchTool)
}
// Hardware tools (I2C, SPI) - Linux only, returns error on other platforms
agent.Tools.Register(tools.NewI2CTool())
@@ -162,6 +176,72 @@ func registerSharedTools(
func (al *AgentLoop) Run(ctx context.Context) error {
al.running.Store(true)
// Initialize MCP servers for all agents
if al.cfg.Tools.MCP.Enabled {
mcpManager := mcp.NewManager()
// Ensure MCP connections are cleaned up on exit, regardless of initialization success
// This fixes resource leak when LoadFromMCPConfig partially succeeds then fails
defer func() {
if err := mcpManager.Close(); err != nil {
logger.ErrorCF("agent", "Failed to close MCP manager",
map[string]any{
"error": err.Error(),
})
}
}()
defaultAgent := al.registry.GetDefaultAgent()
var workspacePath string
if defaultAgent != nil && defaultAgent.Workspace != "" {
workspacePath = defaultAgent.Workspace
} else {
workspacePath = al.cfg.WorkspacePath()
}
if err := mcpManager.LoadFromMCPConfig(ctx, al.cfg.Tools.MCP, workspacePath); err != nil {
logger.WarnCF("agent", "Failed to load MCP servers, MCP tools will not be available",
map[string]any{
"error": err.Error(),
})
} else {
// Register MCP tools for all agents
servers := mcpManager.GetServers()
uniqueTools := 0
totalRegistrations := 0
agentIDs := al.registry.ListAgentIDs()
agentCount := len(agentIDs)
for serverName, conn := range servers {
uniqueTools += len(conn.Tools)
for _, tool := range conn.Tools {
for _, agentID := range agentIDs {
agent, ok := al.registry.GetAgent(agentID)
if !ok {
continue
}
mcpTool := tools.NewMCPTool(mcpManager, serverName, tool)
agent.Tools.Register(mcpTool)
totalRegistrations++
logger.DebugCF("agent", "Registered MCP tool",
map[string]any{
"agent_id": agentID,
"server": serverName,
"tool": tool.Name,
"name": mcpTool.Name(),
})
}
}
}
logger.InfoCF("agent", "MCP tools registered successfully",
map[string]any{
"server_count": len(servers),
"unique_tools": uniqueTools,
"total_registrations": totalRegistrations,
"agent_count": agentCount,
})
}
}
for al.running.Load() {
select {
case <-ctx.Done():
@@ -302,7 +382,10 @@ func (al *AgentLoop) RecordLastChatID(chatID string) error {
return al.state.SetLastChatID(chatID)
}
func (al *AgentLoop) ProcessDirect(ctx context.Context, content, sessionKey string) (string, error) {
func (al *AgentLoop) ProcessDirect(
ctx context.Context,
content, sessionKey string,
) (string, error) {
return al.ProcessDirectWithChannel(ctx, content, sessionKey, "cli", "direct")
}
@@ -323,7 +406,10 @@ func (al *AgentLoop) ProcessDirectWithChannel(
// ProcessHeartbeat processes a heartbeat request without session history.
// Each heartbeat is independent and doesn't accumulate context.
func (al *AgentLoop) ProcessHeartbeat(ctx context.Context, content, channel, chatID string) (string, error) {
func (al *AgentLoop) ProcessHeartbeat(
ctx context.Context,
content, channel, chatID string,
) (string, error) {
agent := al.registry.GetDefaultAgent()
if agent == nil {
return "", fmt.Errorf("no default agent for heartbeat")
@@ -348,13 +434,16 @@ func (al *AgentLoop) processMessage(ctx context.Context, msg bus.InboundMessage)
} else {
logContent = utils.Truncate(msg.Content, 80)
}
logger.InfoCF("agent", fmt.Sprintf("Processing message from %s:%s: %s", msg.Channel, msg.SenderID, logContent),
logger.InfoCF(
"agent",
fmt.Sprintf("Processing message from %s:%s: %s", msg.Channel, msg.SenderID, logContent),
map[string]any{
"channel": msg.Channel,
"chat_id": msg.ChatID,
"sender_id": msg.SenderID,
"session_key": msg.SessionKey,
})
},
)
// Route system messages to processSystemMessage
if msg.Channel == "system" {
@@ -409,15 +498,22 @@ func (al *AgentLoop) processMessage(ctx context.Context, msg bus.InboundMessage)
Channel: msg.Channel,
ChatID: msg.ChatID,
UserMessage: msg.Content,
Media: msg.Media,
DefaultResponse: defaultResponse,
EnableSummary: true,
SendResponse: false,
})
}
func (al *AgentLoop) processSystemMessage(ctx context.Context, msg bus.InboundMessage) (string, error) {
func (al *AgentLoop) processSystemMessage(
ctx context.Context,
msg bus.InboundMessage,
) (string, error) {
if msg.Channel != "system" {
return "", fmt.Errorf("processSystemMessage called with non-system message channel: %s", msg.Channel)
return "", fmt.Errorf(
"processSystemMessage called with non-system message channel: %s",
msg.Channel,
)
}
logger.InfoCF("agent", "Processing system message",
@@ -475,14 +571,22 @@ func (al *AgentLoop) processSystemMessage(ctx context.Context, msg bus.InboundMe
}
// runAgentLoop is the core message processing logic.
func (al *AgentLoop) runAgentLoop(ctx context.Context, agent *AgentInstance, opts processOptions) (string, error) {
func (al *AgentLoop) runAgentLoop(
ctx context.Context,
agent *AgentInstance,
opts processOptions,
) (string, error) {
// 0. Record last channel for heartbeat notifications (skip internal channels)
if opts.Channel != "" && opts.ChatID != "" {
// Don't record internal channels (cli, system, subagent)
if !constants.IsInternalChannel(opts.Channel) {
channelKey := fmt.Sprintf("%s:%s", opts.Channel, opts.ChatID)
if err := al.RecordLastChannel(channelKey); err != nil {
logger.WarnCF("agent", "Failed to record last channel", map[string]any{"error": err.Error()})
logger.WarnCF(
"agent",
"Failed to record last channel",
map[string]any{"error": err.Error()},
)
}
}
}
@@ -501,11 +605,15 @@ func (al *AgentLoop) runAgentLoop(ctx context.Context, agent *AgentInstance, opt
history,
summary,
opts.UserMessage,
nil,
opts.Media,
opts.Channel,
opts.ChatID,
)
// Resolve media:// refs to base64 data URLs (streaming)
maxMediaSize := al.cfg.Agents.Defaults.GetMaxMediaSize()
messages = resolveMediaRefs(messages, al.mediaStore, maxMediaSize)
// 3. Save user message to session
agent.Sessions.AddMessage(opts.SessionKey, "user", opts.UserMessage)
@@ -564,7 +672,10 @@ func (al *AgentLoop) targetReasoningChannelID(channelName string) (chatID string
return ""
}
func (al *AgentLoop) handleReasoning(ctx context.Context, reasoningContent, channelName, channelID string) {
func (al *AgentLoop) handleReasoning(
ctx context.Context,
reasoningContent, channelName, channelID string,
) {
if reasoningContent == "" || channelName == "" || channelID == "" {
return
}
@@ -657,22 +768,33 @@ func (al *AgentLoop) runLLMIteration(
callLLM := func() (*providers.LLMResponse, error) {
if len(agent.Candidates) > 1 && al.fallback != nil {
fbResult, fbErr := al.fallback.Execute(ctx, agent.Candidates,
fbResult, fbErr := al.fallback.Execute(
ctx,
agent.Candidates,
func(ctx context.Context, provider, model string) (*providers.LLMResponse, error) {
return agent.Provider.Chat(ctx, messages, providerToolDefs, model, map[string]any{
"max_tokens": agent.MaxTokens,
"temperature": agent.Temperature,
"prompt_cache_key": agent.ID,
})
return agent.Provider.Chat(
ctx,
messages,
providerToolDefs,
model,
map[string]any{
"max_tokens": agent.MaxTokens,
"temperature": agent.Temperature,
"prompt_cache_key": agent.ID,
},
)
},
)
if fbErr != nil {
return nil, fbErr
}
if fbResult.Provider != "" && len(fbResult.Attempts) > 0 {
logger.InfoCF("agent", fmt.Sprintf("Fallback: succeeded with %s/%s after %d attempts",
fbResult.Provider, fbResult.Model, len(fbResult.Attempts)+1),
map[string]any{"agent_id": agent.ID, "iteration": iteration})
logger.InfoCF(
"agent",
fmt.Sprintf("Fallback: succeeded with %s/%s after %d attempts",
fbResult.Provider, fbResult.Model, len(fbResult.Attempts)+1),
map[string]any{"agent_id": agent.ID, "iteration": iteration},
)
}
return fbResult.Response, nil
}
@@ -723,10 +845,14 @@ func (al *AgentLoop) runLLMIteration(
}
if isContextError && retry < maxRetries {
logger.WarnCF("agent", "Context window error detected, attempting compression", map[string]any{
"error": err.Error(),
"retry": retry,
})
logger.WarnCF(
"agent",
"Context window error detected, attempting compression",
map[string]any{
"error": err.Error(),
"retry": retry,
},
)
if retry == 0 && !constants.IsInternalChannel(opts.Channel) {
al.bus.PublishOutbound(ctx, bus.OutboundMessage{
@@ -758,7 +884,12 @@ func (al *AgentLoop) runLLMIteration(
return "", iteration, fmt.Errorf("LLM call failed after retries: %w", err)
}
go al.handleReasoning(ctx, response.Reasoning, opts.Channel, al.targetReasoningChannelID(opts.Channel))
go al.handleReasoning(
ctx,
response.Reasoning,
opts.Channel,
al.targetReasoningChannelID(opts.Channel),
)
logger.DebugCF("agent", "LLM response",
map[string]any{
@@ -1067,7 +1198,11 @@ func formatMessagesForLog(messages []providers.Message) string {
for _, tc := range msg.ToolCalls {
fmt.Fprintf(&sb, " - ID: %s, Type: %s, Name: %s\n", tc.ID, tc.Type, tc.Name)
if tc.Function != nil {
fmt.Fprintf(&sb, " Arguments: %s\n", utils.Truncate(tc.Function.Arguments, 200))
fmt.Fprintf(
&sb,
" Arguments: %s\n",
utils.Truncate(tc.Function.Arguments, 200),
)
}
}
}
@@ -1096,7 +1231,11 @@ func formatToolsForLog(toolDefs []providers.ToolDefinition) string {
fmt.Fprintf(&sb, " [%d] Type: %s, Name: %s\n", i, tool.Type, tool.Function.Name)
fmt.Fprintf(&sb, " Description: %s\n", tool.Function.Description)
if len(tool.Function.Parameters) > 0 {
fmt.Fprintf(&sb, " Parameters: %s\n", utils.Truncate(fmt.Sprintf("%v", tool.Function.Parameters), 200))
fmt.Fprintf(
&sb,
" Parameters: %s\n",
utils.Truncate(fmt.Sprintf("%v", tool.Function.Parameters), 200),
)
}
}
sb.WriteString("]")
@@ -1193,7 +1332,9 @@ func (al *AgentLoop) summarizeBatch(
existingSummary string,
) (string, error) {
var sb strings.Builder
sb.WriteString("Provide a concise summary of this conversation segment, preserving core context and key points.\n")
sb.WriteString(
"Provide a concise summary of this conversation segment, preserving core context and key points.\n",
)
if existingSummary != "" {
sb.WriteString("Existing context: ")
sb.WriteString(existingSummary)
+122
View File
@@ -0,0 +1,122 @@
// PicoClaw - Ultra-lightweight personal AI agent
// Inspired by and based on nanobot: https://github.com/HKUDS/nanobot
// License: MIT
//
// Copyright (c) 2026 PicoClaw contributors
package agent
import (
"bytes"
"encoding/base64"
"io"
"os"
"strings"
"github.com/h2non/filetype"
"github.com/sipeed/picoclaw/pkg/logger"
"github.com/sipeed/picoclaw/pkg/media"
"github.com/sipeed/picoclaw/pkg/providers"
)
// resolveMediaRefs replaces media:// refs in message Media fields with base64 data URLs.
// Uses streaming base64 encoding (file handle → encoder → buffer) to avoid holding
// both raw bytes and encoded string in memory simultaneously.
// Returns a new slice; original messages are not mutated.
func resolveMediaRefs(messages []providers.Message, store media.MediaStore, maxSize int) []providers.Message {
if store == nil {
return messages
}
result := make([]providers.Message, len(messages))
copy(result, messages)
for i, m := range result {
if len(m.Media) == 0 {
continue
}
resolved := make([]string, 0, len(m.Media))
for _, ref := range m.Media {
if !strings.HasPrefix(ref, "media://") {
resolved = append(resolved, ref)
continue
}
localPath, meta, err := store.ResolveWithMeta(ref)
if err != nil {
logger.WarnCF("agent", "Failed to resolve media ref", map[string]any{
"ref": ref,
"error": err.Error(),
})
continue
}
info, err := os.Stat(localPath)
if err != nil {
logger.WarnCF("agent", "Failed to stat media file", map[string]any{
"path": localPath,
"error": err.Error(),
})
continue
}
if info.Size() > int64(maxSize) {
logger.WarnCF("agent", "Media file too large, skipping", map[string]any{
"path": localPath,
"size": info.Size(),
"max_size": maxSize,
})
continue
}
// Determine MIME type: prefer metadata, fallback to magic-bytes detection
mime := meta.ContentType
if mime == "" {
kind, ftErr := filetype.MatchFile(localPath)
if ftErr != nil || kind == filetype.Unknown {
logger.WarnCF("agent", "Unknown media type, skipping", map[string]any{
"path": localPath,
})
continue
}
mime = kind.MIME.Value
}
// Streaming base64: open file → base64 encoder → buffer
// Peak memory: ~1.33x file size (buffer only, no raw bytes copy)
f, err := os.Open(localPath)
if err != nil {
logger.WarnCF("agent", "Failed to open media file", map[string]any{
"path": localPath,
"error": err.Error(),
})
continue
}
prefix := "data:" + mime + ";base64,"
encodedLen := base64.StdEncoding.EncodedLen(int(info.Size()))
var buf bytes.Buffer
buf.Grow(len(prefix) + encodedLen)
buf.WriteString(prefix)
encoder := base64.NewEncoder(base64.StdEncoding, &buf)
if _, err := io.Copy(encoder, f); err != nil {
f.Close()
logger.WarnCF("agent", "Failed to encode media file", map[string]any{
"path": localPath,
"error": err.Error(),
})
continue
}
encoder.Close()
f.Close()
resolved = append(resolved, buf.String())
}
result[i].Media = resolved
}
return result
}
+168 -71
View File
@@ -5,6 +5,7 @@ import (
"fmt"
"os"
"path/filepath"
"slices"
"strings"
"testing"
"time"
@@ -12,6 +13,7 @@ import (
"github.com/sipeed/picoclaw/pkg/bus"
"github.com/sipeed/picoclaw/pkg/channels"
"github.com/sipeed/picoclaw/pkg/config"
"github.com/sipeed/picoclaw/pkg/media"
"github.com/sipeed/picoclaw/pkg/providers"
"github.com/sipeed/picoclaw/pkg/tools"
)
@@ -27,16 +29,15 @@ func (f *fakeChannel) IsAllowed(string) bool {
func (f *fakeChannel) IsAllowedSender(sender bus.SenderInfo) bool { return true }
func (f *fakeChannel) ReasoningChannelID() string { return f.id }
func TestRecordLastChannel(t *testing.T) {
// Create temp workspace
func newTestAgentLoop(
t *testing.T,
) (al *AgentLoop, cfg *config.Config, msgBus *bus.MessageBus, provider *mockProvider, cleanup func()) {
t.Helper()
tmpDir, err := os.MkdirTemp("", "agent-test-*")
if err != nil {
t.Fatalf("Failed to create temp dir: %v", err)
}
defer os.RemoveAll(tmpDir)
// Create test config
cfg := &config.Config{
cfg = &config.Config{
Agents: config.AgentsConfig{
Defaults: config.AgentDefaults{
Workspace: tmpDir,
@@ -46,74 +47,43 @@ func TestRecordLastChannel(t *testing.T) {
},
},
}
msgBus = bus.NewMessageBus()
provider = &mockProvider{}
al = NewAgentLoop(cfg, msgBus, provider)
return al, cfg, msgBus, provider, func() { os.RemoveAll(tmpDir) }
}
// Create agent loop
msgBus := bus.NewMessageBus()
provider := &mockProvider{}
al := NewAgentLoop(cfg, msgBus, provider)
func TestRecordLastChannel(t *testing.T) {
al, cfg, msgBus, provider, cleanup := newTestAgentLoop(t)
defer cleanup()
// Test RecordLastChannel
testChannel := "test-channel"
err = al.RecordLastChannel(testChannel)
if err != nil {
if err := al.RecordLastChannel(testChannel); err != nil {
t.Fatalf("RecordLastChannel failed: %v", err)
}
// Verify channel was saved
lastChannel := al.state.GetLastChannel()
if lastChannel != testChannel {
t.Errorf("Expected channel '%s', got '%s'", testChannel, lastChannel)
if got := al.state.GetLastChannel(); got != testChannel {
t.Errorf("Expected channel '%s', got '%s'", testChannel, got)
}
// Verify persistence by creating a new agent loop
al2 := NewAgentLoop(cfg, msgBus, provider)
if al2.state.GetLastChannel() != testChannel {
t.Errorf("Expected persistent channel '%s', got '%s'", testChannel, al2.state.GetLastChannel())
if got := al2.state.GetLastChannel(); got != testChannel {
t.Errorf("Expected persistent channel '%s', got '%s'", testChannel, got)
}
}
func TestRecordLastChatID(t *testing.T) {
// Create temp workspace
tmpDir, err := os.MkdirTemp("", "agent-test-*")
if err != nil {
t.Fatalf("Failed to create temp dir: %v", err)
}
defer os.RemoveAll(tmpDir)
al, cfg, msgBus, provider, cleanup := newTestAgentLoop(t)
defer cleanup()
// Create test config
cfg := &config.Config{
Agents: config.AgentsConfig{
Defaults: config.AgentDefaults{
Workspace: tmpDir,
Model: "test-model",
MaxTokens: 4096,
MaxToolIterations: 10,
},
},
}
// Create agent loop
msgBus := bus.NewMessageBus()
provider := &mockProvider{}
al := NewAgentLoop(cfg, msgBus, provider)
// Test RecordLastChatID
testChatID := "test-chat-id-123"
err = al.RecordLastChatID(testChatID)
if err != nil {
if err := al.RecordLastChatID(testChatID); err != nil {
t.Fatalf("RecordLastChatID failed: %v", err)
}
// Verify chat ID was saved
lastChatID := al.state.GetLastChatID()
if lastChatID != testChatID {
t.Errorf("Expected chat ID '%s', got '%s'", testChatID, lastChatID)
if got := al.state.GetLastChatID(); got != testChatID {
t.Errorf("Expected chat ID '%s', got '%s'", testChatID, got)
}
// Verify persistence by creating a new agent loop
al2 := NewAgentLoop(cfg, msgBus, provider)
if al2.state.GetLastChatID() != testChatID {
t.Errorf("Expected persistent chat ID '%s', got '%s'", testChatID, al2.state.GetLastChatID())
if got := al2.state.GetLastChatID(); got != testChatID {
t.Errorf("Expected persistent chat ID '%s', got '%s'", testChatID, got)
}
}
@@ -188,13 +158,7 @@ func TestToolRegistry_ToolRegistration(t *testing.T) {
toolsList := toolsInfo["names"].([]string)
// Check that our custom tool name is in the list
found := false
for _, name := range toolsList {
if name == "mock_custom" {
found = true
break
}
}
found := slices.Contains(toolsList, "mock_custom")
if !found {
t.Error("Expected custom tool to be registered")
}
@@ -263,13 +227,7 @@ func TestToolRegistry_GetDefinitions(t *testing.T) {
toolsList := toolsInfo["names"].([]string)
// Check that our custom tool name is in the list
found := false
for _, name := range toolsList {
if name == "mock_custom" {
found = true
break
}
}
found := slices.Contains(toolsList, "mock_custom")
if !found {
t.Error("Expected custom tool to be registered")
}
@@ -931,3 +889,142 @@ func TestHandleReasoning(t *testing.T) {
}
})
}
func TestResolveMediaRefs_ResolvesToBase64(t *testing.T) {
store := media.NewFileMediaStore()
dir := t.TempDir()
// Create a minimal valid PNG (8-byte header is enough for filetype detection)
pngPath := filepath.Join(dir, "test.png")
// PNG magic: 0x89 P N G \r \n 0x1A \n + minimal IHDR
pngHeader := []byte{
0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A, // PNG signature
0x00, 0x00, 0x00, 0x0D, // IHDR length
0x49, 0x48, 0x44, 0x52, // "IHDR"
0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x01, 0x08, 0x02, // 1x1 RGB
0x00, 0x00, 0x00, // no interlace
0x90, 0x77, 0x53, 0xDE, // CRC
}
if err := os.WriteFile(pngPath, pngHeader, 0o644); err != nil {
t.Fatal(err)
}
ref, err := store.Store(pngPath, media.MediaMeta{}, "test")
if err != nil {
t.Fatal(err)
}
messages := []providers.Message{
{Role: "user", Content: "describe this", Media: []string{ref}},
}
result := resolveMediaRefs(messages, store, config.DefaultMaxMediaSize)
if len(result[0].Media) != 1 {
t.Fatalf("expected 1 resolved media, got %d", len(result[0].Media))
}
if !strings.HasPrefix(result[0].Media[0], "data:image/png;base64,") {
t.Fatalf("expected data:image/png;base64, prefix, got %q", result[0].Media[0][:40])
}
}
func TestResolveMediaRefs_SkipsOversizedFile(t *testing.T) {
store := media.NewFileMediaStore()
dir := t.TempDir()
bigPath := filepath.Join(dir, "big.png")
// Write PNG header + padding to exceed limit
data := make([]byte, 1024+1) // 1KB + 1 byte
copy(data, []byte{0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A})
if err := os.WriteFile(bigPath, data, 0o644); err != nil {
t.Fatal(err)
}
ref, _ := store.Store(bigPath, media.MediaMeta{}, "test")
messages := []providers.Message{
{Role: "user", Content: "hi", Media: []string{ref}},
}
// Use a tiny limit (1KB) so the file is oversized
result := resolveMediaRefs(messages, store, 1024)
if len(result[0].Media) != 0 {
t.Fatalf("expected 0 media (oversized), got %d", len(result[0].Media))
}
}
func TestResolveMediaRefs_SkipsUnknownType(t *testing.T) {
store := media.NewFileMediaStore()
dir := t.TempDir()
txtPath := filepath.Join(dir, "readme.txt")
if err := os.WriteFile(txtPath, []byte("hello world"), 0o644); err != nil {
t.Fatal(err)
}
ref, _ := store.Store(txtPath, media.MediaMeta{}, "test")
messages := []providers.Message{
{Role: "user", Content: "hi", Media: []string{ref}},
}
result := resolveMediaRefs(messages, store, config.DefaultMaxMediaSize)
if len(result[0].Media) != 0 {
t.Fatalf("expected 0 media (unknown type), got %d", len(result[0].Media))
}
}
func TestResolveMediaRefs_PassesThroughNonMediaRefs(t *testing.T) {
messages := []providers.Message{
{Role: "user", Content: "hi", Media: []string{"https://example.com/img.png"}},
}
result := resolveMediaRefs(messages, nil, config.DefaultMaxMediaSize)
if len(result[0].Media) != 1 || result[0].Media[0] != "https://example.com/img.png" {
t.Fatalf("expected passthrough of non-media:// URL, got %v", result[0].Media)
}
}
func TestResolveMediaRefs_DoesNotMutateOriginal(t *testing.T) {
store := media.NewFileMediaStore()
dir := t.TempDir()
pngPath := filepath.Join(dir, "test.png")
pngHeader := []byte{
0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A,
0x00, 0x00, 0x00, 0x0D, 0x49, 0x48, 0x44, 0x52,
0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x01, 0x08, 0x02,
0x00, 0x00, 0x00, 0x90, 0x77, 0x53, 0xDE,
}
os.WriteFile(pngPath, pngHeader, 0o644)
ref, _ := store.Store(pngPath, media.MediaMeta{}, "test")
original := []providers.Message{
{Role: "user", Content: "hi", Media: []string{ref}},
}
originalRef := original[0].Media[0]
resolveMediaRefs(original, store, config.DefaultMaxMediaSize)
if original[0].Media[0] != originalRef {
t.Fatal("resolveMediaRefs mutated original message slice")
}
}
func TestResolveMediaRefs_UsesMetaContentType(t *testing.T) {
store := media.NewFileMediaStore()
dir := t.TempDir()
// File with JPEG content but stored with explicit content type
jpegPath := filepath.Join(dir, "photo")
jpegHeader := []byte{0xFF, 0xD8, 0xFF, 0xE0} // JPEG magic bytes
os.WriteFile(jpegPath, jpegHeader, 0o644)
ref, _ := store.Store(jpegPath, media.MediaMeta{ContentType: "image/jpeg"}, "test")
messages := []providers.Message{
{Role: "user", Content: "hi", Media: []string{ref}},
}
result := resolveMediaRefs(messages, store, config.DefaultMaxMediaSize)
if len(result[0].Media) != 1 {
t.Fatalf("expected 1 media, got %d", len(result[0].Media))
}
if !strings.HasPrefix(result[0].Media[0], "data:image/jpeg;base64,") {
t.Fatalf("expected jpeg prefix, got %q", result[0].Media[0][:30])
}
}
+1 -1
View File
@@ -111,7 +111,7 @@ func (ms *MemoryStore) GetRecentDailyNotes(days int) string {
var sb strings.Builder
first := true
for i := 0; i < days; i++ {
for i := range days {
date := time.Now().AddDate(0, 0, -i)
dateStr := date.Format("20060102") // YYYYMMDD
monthDir := dateStr[:6] // YYYYMM