Merge upstream/main into fix/1323-telegram-endless-typing

Made-with: Cursor
This commit is contained in:
kiannidev
2026-03-19 10:53:50 +02:00
270 changed files with 31493 additions and 9825 deletions
+24 -5
View File
@@ -52,7 +52,7 @@ func (cb *ContextBuilder) WithToolDiscovery(useBM25, useRegex bool) *ContextBuil
}
func getGlobalConfigDir() string {
if home := os.Getenv("PICOCLAW_HOME"); home != "" {
if home := os.Getenv(config.EnvHome); home != "" {
return home
}
home, err := os.UserHomeDir()
@@ -65,7 +65,7 @@ func getGlobalConfigDir() string {
func NewContextBuilder(workspace string) *ContextBuilder {
// builtin skills: skills directory in current project
// Use the skills/ directory under the current working directory
builtinSkillsDir := strings.TrimSpace(os.Getenv("PICOCLAW_BUILTIN_SKILLS"))
builtinSkillsDir := strings.TrimSpace(os.Getenv(config.EnvBuiltinSkills))
if builtinSkillsDir == "" {
wd, _ := os.Getwd()
builtinSkillsDir = filepath.Join(wd, "skills")
@@ -458,7 +458,23 @@ func (cb *ContextBuilder) LoadBootstrapFiles() string {
//
// See: https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching
// See: https://platform.openai.com/docs/guides/prompt-caching
func (cb *ContextBuilder) buildDynamicContext(channel, chatID string) string {
// formatCurrentSenderLine renders the "Current sender: ..." line for the
// dynamic context. Both inputs are whitespace-trimmed first. When both are
// present the display name is shown with the ID in parentheses; when only one
// is present it is shown alone; when neither is present the result is "".
func formatCurrentSenderLine(senderID, senderDisplayName string) string {
	id := strings.TrimSpace(senderID)
	name := strings.TrimSpace(senderDisplayName)

	// Nothing to report.
	if name == "" && id == "" {
		return ""
	}
	// Only one of the two is known: show whichever we have.
	if name == "" {
		return fmt.Sprintf("Current sender: %s", id)
	}
	if id == "" {
		return fmt.Sprintf("Current sender: %s", name)
	}
	return fmt.Sprintf("Current sender: %s (ID: %s)", name, id)
}
func (cb *ContextBuilder) buildDynamicContext(channel, chatID, senderID, senderDisplayName string) string {
now := time.Now().Format("2006-01-02 15:04 (Monday)")
rt := fmt.Sprintf("%s %s, Go %s", runtime.GOOS, runtime.GOARCH, runtime.Version())
@@ -468,6 +484,9 @@ func (cb *ContextBuilder) buildDynamicContext(channel, chatID string) string {
if channel != "" && chatID != "" {
fmt.Fprintf(&sb, "\n\n## Current Session\nChannel: %s\nChat ID: %s", channel, chatID)
}
if senderLine := formatCurrentSenderLine(senderID, senderDisplayName); senderLine != "" {
fmt.Fprintf(&sb, "\n\n## Current Sender\n%s", senderLine)
}
return sb.String()
}
@@ -477,7 +496,7 @@ func (cb *ContextBuilder) BuildMessages(
summary string,
currentMessage string,
media []string,
channel, chatID string,
channel, chatID, senderID, senderDisplayName string,
) []providers.Message {
messages := []providers.Message{}
@@ -493,7 +512,7 @@ func (cb *ContextBuilder) BuildMessages(
staticPrompt := cb.BuildSystemPromptWithCache()
// Build short dynamic context (time, runtime, session) — changes per request
dynamicCtx := cb.buildDynamicContext(channel, chatID)
dynamicCtx := cb.buildDynamicContext(channel, chatID, senderID, senderDisplayName)
// Compose a single system message: static (cached) + dynamic + optional summary.
// Keeping all system content in one message ensures every provider adapter can
+65 -3
View File
@@ -82,7 +82,7 @@ func TestSingleSystemMessage(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
msgs := cb.BuildMessages(tt.history, tt.summary, tt.message, nil, "test", "chat1")
msgs := cb.BuildMessages(tt.history, tt.summary, tt.message, nil, "test", "chat1", "", "")
systemCount := 0
for _, m := range msgs {
@@ -126,6 +126,68 @@ func TestSingleSystemMessage(t *testing.T) {
}
}
// TestBuildMessages_CurrentSenderDynamicContext checks that the first message
// returned by BuildMessages contains a "## Current Sender" section exactly
// when a sender ID and/or display name is supplied, and that the section
// carries the expected sender line.
func TestBuildMessages_CurrentSenderDynamicContext(t *testing.T) {
	workspaceDir := setupWorkspace(t, map[string]string{
		"IDENTITY.md": "# Identity\nTest agent.",
	})
	defer os.RemoveAll(workspaceDir)

	builder := NewContextBuilder(workspaceDir)

	const sectionHeader = "## Current Sender"

	cases := []struct {
		name              string
		senderID          string
		senderDisplayName string
		wantLine          string
		wantSection       bool
	}{
		{
			name:              "both id and display name",
			senderID:          "feishu:ou_xxx",
			senderDisplayName: "Zhang San",
			wantLine:          "Current sender: Zhang San (ID: feishu:ou_xxx)",
			wantSection:       true,
		},
		{
			name:              "display name only",
			senderDisplayName: "Alice",
			wantLine:          "Current sender: Alice",
			wantSection:       true,
		},
		{
			name:        "id only",
			senderID:    "discord:123",
			wantLine:    "Current sender: discord:123",
			wantSection: true,
		},
		{
			name:        "no sender info",
			wantSection: false,
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			msgs := builder.BuildMessages(nil, "", "hello", nil, "discord", "chat1", tc.senderID, tc.senderDisplayName)
			systemPrompt := msgs[0].Content

			// Negative case first: the section must be absent entirely.
			if !tc.wantSection {
				if strings.Contains(systemPrompt, sectionHeader) {
					t.Fatalf("system prompt should omit Current Sender section:\n%s", systemPrompt)
				}
				return
			}

			if !strings.Contains(systemPrompt, sectionHeader) {
				t.Fatalf("system prompt missing Current Sender section:\n%s", systemPrompt)
			}
			if !strings.Contains(systemPrompt, tc.wantLine) {
				t.Fatalf("system prompt missing sender line %q:\n%s", tc.wantLine, systemPrompt)
			}
		})
	}
}
// TestMtimeAutoInvalidation verifies that the cache detects source file changes
// via mtime without requiring explicit InvalidateCache().
// Fix: original implementation had no auto-invalidation — edits to bootstrap files,
@@ -576,7 +638,7 @@ func TestConcurrentBuildSystemPromptWithCache(t *testing.T) {
}
// Also exercise BuildMessages concurrently
msgs := cb.BuildMessages(nil, "", "hello", nil, "test", "chat")
msgs := cb.BuildMessages(nil, "", "hello", nil, "test", "chat", "", "")
if len(msgs) < 2 {
errs <- "BuildMessages returned fewer than 2 messages"
return
@@ -664,6 +726,6 @@ func BenchmarkBuildMessagesWithCache(b *testing.B) {
b.ResetTimer()
for i := 0; i < b.N; i++ {
_ = cb.BuildMessages(history, "summary", "new message", nil, "cli", "test")
_ = cb.BuildMessages(history, "summary", "new message", nil, "cli", "test", "", "")
}
}
+25 -2
View File
@@ -10,6 +10,7 @@ import (
"strings"
"github.com/sipeed/picoclaw/pkg/config"
"github.com/sipeed/picoclaw/pkg/media"
"github.com/sipeed/picoclaw/pkg/memory"
"github.com/sipeed/picoclaw/pkg/providers"
"github.com/sipeed/picoclaw/pkg/routing"
@@ -66,7 +67,7 @@ func NewAgentInstance(
readRestrict := restrict && !defaults.AllowReadOutsideWorkspace
// Compile path whitelist patterns from config.
allowReadPaths := compilePatterns(cfg.Tools.AllowReadPaths)
allowReadPaths := buildAllowReadPatterns(cfg)
allowWritePaths := compilePatterns(cfg.Tools.AllowWritePaths)
toolsRegistry := tools.NewToolRegistry()
@@ -82,7 +83,7 @@ func NewAgentInstance(
toolsRegistry.Register(tools.NewListDirTool(workspace, readRestrict, allowReadPaths))
}
if cfg.Tools.IsToolEnabled("exec") {
execTool, err := tools.NewExecToolWithConfig(workspace, restrict, cfg)
execTool, err := tools.NewExecToolWithConfig(workspace, restrict, cfg, allowReadPaths)
if err != nil {
log.Fatalf("Critical error: unable to initialize exec tool: %v", err)
}
@@ -282,6 +283,28 @@ func compilePatterns(patterns []string) []*regexp.Regexp {
return compiled
}
// buildAllowReadPatterns compiles cfg.Tools.AllowReadPaths and makes sure the
// media temp-directory pattern is part of the result. A nil cfg is treated as
// having no configured patterns. The media pattern is appended only when no
// compiled pattern already has the same source text.
func buildAllowReadPatterns(cfg *config.Config) []*regexp.Regexp {
	var fromConfig []string
	if cfg != nil {
		fromConfig = cfg.Tools.AllowReadPaths
	}
	patterns := compilePatterns(fromConfig)

	mediaRe := regexp.MustCompile(mediaTempDirPattern())
	alreadyListed := false
	for _, re := range patterns {
		// Compare by source text; compiled regexps are not comparable by value.
		if re.String() == mediaRe.String() {
			alreadyListed = true
			break
		}
	}
	if !alreadyListed {
		patterns = append(patterns, mediaRe)
	}
	return patterns
}
// mediaTempDirPattern returns regexp source text anchored at the start of the
// input that matches the cleaned media temp directory followed by either a
// path separator or the end of the input. The directory and the separator are
// both regexp-quoted so they match literally.
func mediaTempDirPattern() string {
	quotedSep := regexp.QuoteMeta(string(os.PathSeparator))
	quotedDir := regexp.QuoteMeta(filepath.Clean(media.TempDir()))
	return "^" + quotedDir + "(?:" + quotedSep + "|$)"
}
// Close releases resources held by the agent's session store.
func (a *AgentInstance) Close() error {
if a.Sessions != nil {
+86
View File
@@ -1,10 +1,14 @@
package agent
import (
"context"
"os"
"path/filepath"
"strings"
"testing"
"github.com/sipeed/picoclaw/pkg/config"
"github.com/sipeed/picoclaw/pkg/media"
)
func TestNewAgentInstance_UsesDefaultsTemperatureAndMaxTokens(t *testing.T) {
@@ -160,3 +164,85 @@ func TestNewAgentInstance_ResolveCandidatesFromModelListAlias(t *testing.T) {
})
}
}
// TestNewAgentInstance_AllowsMediaTempDirForReadListAndExec verifies that an
// agent created with RestrictToWorkspace enabled can still use the read_file,
// list_dir and exec tools against the media temp directory, which lives
// outside the workspace.
func TestNewAgentInstance_AllowsMediaTempDirForReadListAndExec(t *testing.T) {
// The workspace is a fresh temp dir; the media dir is whatever media.TempDir() reports.
workspace := t.TempDir()
mediaDir := media.TempDir()
if err := os.MkdirAll(mediaDir, 0o700); err != nil {
t.Fatalf("MkdirAll(mediaDir) error = %v", err)
}
// Create a fixture file inside the media dir with known content.
mediaFile, err := os.CreateTemp(mediaDir, "instance-tool-*.txt")
if err != nil {
t.Fatalf("CreateTemp(mediaDir) error = %v", err)
}
mediaPath := mediaFile.Name()
if _, err := mediaFile.WriteString("attachment content"); err != nil {
mediaFile.Close()
t.Fatalf("WriteString(mediaFile) error = %v", err)
}
if err := mediaFile.Close(); err != nil {
t.Fatalf("Close(mediaFile) error = %v", err)
}
// Only the fixture file is removed afterwards; the media dir itself is left in place.
t.Cleanup(func() { _ = os.Remove(mediaPath) })
// Workspace restriction is on and no AllowReadPaths are configured, so any
// access to the media dir must come from the agent's own defaults.
cfg := &config.Config{
Agents: config.AgentsConfig{
Defaults: config.AgentDefaults{
Workspace: workspace,
ModelName: "test-model",
RestrictToWorkspace: true,
},
},
Tools: config.ToolsConfig{
ReadFile: config.ReadFileToolConfig{Enabled: true},
ListDir: config.ToolConfig{Enabled: true},
Exec: config.ExecConfig{
ToolConfig: config.ToolConfig{Enabled: true},
EnableDenyPatterns: true,
AllowRemote: true,
},
},
}
agent := NewAgentInstance(nil, &cfg.Agents.Defaults, cfg, &mockProvider{})
// read_file: reading the fixture by its absolute path must succeed.
readTool, ok := agent.Tools.Get("read_file")
if !ok {
t.Fatal("read_file tool not registered")
}
readResult := readTool.Execute(context.Background(), map[string]any{"path": mediaPath})
if readResult.IsError {
t.Fatalf("read_file should allow media temp dir, got: %s", readResult.ForLLM)
}
if !strings.Contains(readResult.ForLLM, "attachment content") {
t.Fatalf("read_file output missing media content: %s", readResult.ForLLM)
}
// list_dir: listing the media dir must succeed and include the fixture's base name.
listTool, ok := agent.Tools.Get("list_dir")
if !ok {
t.Fatal("list_dir tool not registered")
}
listResult := listTool.Execute(context.Background(), map[string]any{"path": mediaDir})
if listResult.IsError {
t.Fatalf("list_dir should allow media temp dir, got: %s", listResult.ForLLM)
}
if !strings.Contains(listResult.ForLLM, filepath.Base(mediaPath)) {
t.Fatalf("list_dir output missing media file: %s", listResult.ForLLM)
}
// exec: running `cat` with the media dir as working directory must succeed.
// NOTE(review): relies on a `cat` binary being available to the exec tool — confirm for non-Unix CI.
execTool, ok := agent.Tools.Get("exec")
if !ok {
t.Fatal("exec tool not registered")
}
execResult := execTool.Execute(context.Background(), map[string]any{
"command": "cat " + filepath.Base(mediaPath),
"working_dir": mediaDir,
})
if execResult.IsError {
t.Fatalf("exec should allow media temp dir, got: %s", execResult.ForLLM)
}
if !strings.Contains(execResult.ForLLM, "attachment content") {
t.Fatalf("exec output missing media content: %s", execResult.ForLLM)
}
}
+278 -60
View File
@@ -48,19 +48,25 @@ type AgentLoop struct {
transcriber voice.Transcriber
cmdRegistry *commands.Registry
mcp mcpRuntime
mu sync.RWMutex
reloadFunc func() error
// Track active requests for safe provider cleanup
activeRequests sync.WaitGroup
}
// processOptions configures how a message is processed
type processOptions struct {
SessionKey string // Session identifier for history/context
Channel string // Target channel for tool execution
ChatID string // Target chat ID for tool execution
UserMessage string // User message content (may include prefix)
Media []string // media:// refs from inbound message
DefaultResponse string // Response when LLM returns empty
EnableSummary bool // Whether to trigger summarization
SendResponse bool // Whether to send response via bus
NoHistory bool // If true, don't load session history (for heartbeat)
SessionKey string // Session identifier for history/context
Channel string // Target channel for tool execution
ChatID string // Target chat ID for tool execution
SenderID string // Current sender ID for dynamic context
SenderDisplayName string // Current sender display name for dynamic context
UserMessage string // User message content (may include prefix)
Media []string // media:// refs from inbound message
DefaultResponse string // Response when LLM returns empty
EnableSummary bool // Whether to trigger summarization
SendResponse bool // Whether to send response via bus
NoHistory bool // If true, don't load session history (for heartbeat)
}
const (
@@ -114,6 +120,8 @@ func registerSharedTools(
registry *AgentRegistry,
provider providers.LLMProvider,
) {
allowReadPaths := buildAllowReadPatterns(cfg)
for _, agentID := range registry.ListAgentIDs() {
agent, ok := registry.GetAgent(agentID)
if !ok {
@@ -154,7 +162,12 @@ func registerSharedTools(
}
}
if cfg.Tools.IsToolEnabled("web_fetch") {
fetchTool, err := tools.NewWebFetchToolWithProxy(50000, cfg.Tools.Web.Proxy, cfg.Tools.Web.FetchLimitBytes)
fetchTool, err := tools.NewWebFetchToolWithProxy(
50000,
cfg.Tools.Web.Proxy,
cfg.Tools.Web.Format,
cfg.Tools.Web.FetchLimitBytes,
cfg.Tools.Web.PrivateHostWhitelist)
if err != nil {
logger.ErrorCF("agent", "Failed to create web fetch tool", map[string]any{"error": err.Error()})
} else {
@@ -192,6 +205,7 @@ func registerSharedTools(
cfg.Agents.Defaults.RestrictToWorkspace,
cfg.Agents.Defaults.GetMaxMediaSize(),
nil,
allowReadPaths,
)
agent.Tools.Register(sendFileTool)
}
@@ -219,26 +233,38 @@ func registerSharedTools(
}
}
// Spawn tool with allowlist checker
if cfg.Tools.IsToolEnabled("spawn") {
if cfg.Tools.IsToolEnabled("subagent") {
subagentManager := tools.NewSubagentManager(provider, agent.Model, agent.Workspace)
subagentManager.SetLLMOptions(agent.MaxTokens, agent.Temperature)
// Spawn and spawn_status tools share a SubagentManager.
// Construct it when either tool is enabled (both require subagent).
spawnEnabled := cfg.Tools.IsToolEnabled("spawn")
spawnStatusEnabled := cfg.Tools.IsToolEnabled("spawn_status")
if (spawnEnabled || spawnStatusEnabled) && cfg.Tools.IsToolEnabled("subagent") {
subagentManager := tools.NewSubagentManager(provider, agent.Model, agent.Workspace)
subagentManager.SetLLMOptions(agent.MaxTokens, agent.Temperature)
// Clone the parent's tool registry so subagents can use all
// tools registered so far (file, web, etc.) but NOT spawn/
// spawn_status which are added below — preventing recursive
// subagent spawning.
subagentManager.SetTools(agent.Tools.Clone())
if spawnEnabled {
spawnTool := tools.NewSpawnTool(subagentManager)
currentAgentID := agentID
spawnTool.SetAllowlistChecker(func(targetAgentID string) bool {
return registry.CanSpawnSubagent(currentAgentID, targetAgentID)
})
agent.Tools.Register(spawnTool)
} else {
logger.WarnCF("agent", "spawn tool requires subagent to be enabled", nil)
}
if spawnStatusEnabled {
agent.Tools.Register(tools.NewSpawnStatusTool(subagentManager))
}
} else if (spawnEnabled || spawnStatusEnabled) && !cfg.Tools.IsToolEnabled("subagent") {
logger.WarnCF("agent", "spawn/spawn_status tools require subagent to be enabled", nil)
}
}
}
func (al *AgentLoop) Run(ctx context.Context) error {
al.running.Store(true)
if err := al.ensureMCPInitialized(ctx); err != nil {
return err
}
@@ -247,12 +273,10 @@ func (al *AgentLoop) Run(ctx context.Context) error {
select {
case <-ctx.Done():
return nil
default:
msg, ok := al.bus.ConsumeInbound(ctx)
case msg, ok := <-al.bus.InboundChan():
if !ok {
continue
return nil
}
// Process message
func() {
defer func() {
@@ -283,7 +307,7 @@ func (al *AgentLoop) Run(ctx context.Context) error {
// If so, skip publishing to avoid duplicate messages to the user.
// Use default agent's tools to check (message tool is shared).
alreadySent := false
defaultAgent := al.registry.GetDefaultAgent()
defaultAgent := al.GetRegistry().GetDefaultAgent()
if defaultAgent != nil {
if tool, ok := defaultAgent.Tools.Get("message"); ok {
if mt, ok := tool.(*tools.MessageTool); ok {
@@ -291,7 +315,6 @@ func (al *AgentLoop) Run(ctx context.Context) error {
}
}
}
if !alreadySent {
al.bus.PublishOutbound(ctx, bus.OutboundMessage{
Channel: msg.Channel,
@@ -313,6 +336,8 @@ func (al *AgentLoop) Run(ctx context.Context) error {
}
}
}()
default:
time.Sleep(time.Microsecond * 200)
}
}
@@ -336,12 +361,13 @@ func (al *AgentLoop) Close() {
}
}
al.registry.Close()
al.GetRegistry().Close()
}
func (al *AgentLoop) RegisterTool(tool tools.Tool) {
for _, agentID := range al.registry.ListAgentIDs() {
if agent, ok := al.registry.GetAgent(agentID); ok {
registry := al.GetRegistry()
for _, agentID := range registry.ListAgentIDs() {
if agent, ok := registry.GetAgent(agentID); ok {
agent.Tools.Register(tool)
}
}
@@ -351,12 +377,123 @@ func (al *AgentLoop) SetChannelManager(cm *channels.Manager) {
al.channelManager = cm
}
// ReloadProviderAndConfig builds a new agent registry from cfg and provider,
// re-registers the shared tools on it, and then swaps al.cfg, al.registry and
// al.fallback in one critical section under al.mu's write lock.
//
// ctx bounds the wait for registry creation and the grace period before the
// old provider is closed. An error is returned when provider or cfg is nil,
// when registry creation panics or yields nil, or when ctx is done before the
// swap. Once the swap has happened the function always returns nil.
func (al *AgentLoop) ReloadProviderAndConfig(
ctx context.Context,
provider providers.LLMProvider,
cfg *config.Config,
) error {
// Validate inputs
if provider == nil {
return fmt.Errorf("provider cannot be nil")
}
if cfg == nil {
return fmt.Errorf("config cannot be nil")
}
// Create the new registry in a separate goroutine so the wait below can be
// abandoned when ctx is done. A panic inside NewAgentRegistry is recovered
// and reported through panicErr.
var registry *AgentRegistry
var panicErr error
done := make(chan struct{}, 1)
go func() {
defer func() {
if r := recover(); r != nil {
panicErr = fmt.Errorf("panic during registry creation: %v", r)
logger.ErrorCF("agent", "Panic during registry creation",
map[string]any{"panic": r})
}
// Closing done publishes registry/panicErr to the waiter below.
close(done)
}()
registry = NewAgentRegistry(cfg, provider)
}()
// Wait for completion or context cancellation
select {
case <-done:
if registry == nil {
if panicErr != nil {
return fmt.Errorf("registry creation failed: %w", panicErr)
}
return fmt.Errorf("registry creation failed (nil result)")
}
case <-ctx.Done():
// NOTE(review): the goroutine above is not stopped here; it runs to
// completion and whatever registry it builds is never used or closed by
// this function — confirm that is acceptable.
return fmt.Errorf("context canceled during registry creation: %w", ctx.Err())
}
// Check context again before proceeding
if err := ctx.Err(); err != nil {
return fmt.Errorf("context canceled after registry creation: %w", err)
}
// Ensure shared tools are re-registered on the new registry
registerSharedTools(cfg, al.bus, registry, provider)
// Swap config, registry and fallback chain inside one write-locked section
// so that no reader holding al.mu observes a half-updated state.
al.mu.Lock()
oldRegistry := al.registry
// Store new values
al.cfg = cfg
al.registry = registry
// Replace the fallback chain with a fresh one backed by a new cooldown tracker
al.fallback = providers.NewFallbackChain(providers.NewCooldownTracker())
al.mu.Unlock()
// Close old provider after releasing the lock
// This prevents blocking readers while closing
// NOTE(review): only the old default agent's provider is closed here; the
// old registry itself is not closed — verify whether it holds other
// resources that need releasing.
if oldProvider, ok := extractProvider(oldRegistry); ok {
if stateful, ok := oldProvider.(providers.StatefulProvider); ok {
// Give in-flight requests a moment to complete.
// This is a fixed 100ms delay, not a wait on any request counter.
select {
case <-time.After(100 * time.Millisecond):
stateful.Close()
case <-ctx.Done():
// Context canceled, close immediately but log warning
logger.WarnCF("agent", "Context canceled during provider cleanup, forcing close",
map[string]any{"error": ctx.Err()})
stateful.Close()
}
}
}
logger.InfoCF("agent", "Provider and config reloaded successfully",
map[string]any{
"model": cfg.Agents.Defaults.GetModelName(),
})
return nil
}
// GetRegistry returns the registry currently installed on the loop. The field
// is read while holding al.mu's read lock.
func (al *AgentLoop) GetRegistry() *AgentRegistry {
	al.mu.RLock()
	current := al.registry
	al.mu.RUnlock()
	return current
}
// GetConfig returns the config currently installed on the loop. The field is
// read while holding al.mu's read lock.
func (al *AgentLoop) GetConfig() *config.Config {
	al.mu.RLock()
	current := al.cfg
	al.mu.RUnlock()
	return current
}
// SetMediaStore injects a MediaStore for media lifecycle management.
func (al *AgentLoop) SetMediaStore(s media.MediaStore) {
al.mediaStore = s
// Propagate store to send_file tools in all agents.
al.registry.ForEachTool("send_file", func(t tools.Tool) {
registry := al.GetRegistry()
registry.ForEachTool("send_file", func(t tools.Tool) {
if sf, ok := t.(*tools.SendFileTool); ok {
sf.SetMediaStore(s)
}
@@ -368,6 +505,11 @@ func (al *AgentLoop) SetTranscriber(t voice.Transcriber) {
al.transcriber = t
}
// SetReloadFunc sets the callback function for triggering config reload.
// The function is stored as-is in al.reloadFunc; passing nil clears it.
// NOTE(review): the field is assigned without any locking — confirm callers
// set it before the loop is used concurrently.
func (al *AgentLoop) SetReloadFunc(fn func() error) {
al.reloadFunc = fn
}
var audioAnnotationRe = regexp.MustCompile(`\[(voice|audio)(?::[^\]]*)?\]`)
// transcribeAudioInMessage resolves audio media refs, transcribes them, and
@@ -545,7 +687,7 @@ func (al *AgentLoop) ProcessHeartbeat(
ctx context.Context,
content, channel, chatID string,
) (string, error) {
agent := al.registry.GetDefaultAgent()
agent := al.GetRegistry().GetDefaultAgent()
if agent == nil {
return "", fmt.Errorf("no default agent for heartbeat")
}
@@ -621,14 +763,16 @@ func (al *AgentLoop) processMessage(ctx context.Context, msg bus.InboundMessage)
})
opts := processOptions{
SessionKey: sessionKey,
Channel: msg.Channel,
ChatID: msg.ChatID,
UserMessage: msg.Content,
Media: msg.Media,
DefaultResponse: defaultResponse,
EnableSummary: true,
SendResponse: false,
SessionKey: sessionKey,
Channel: msg.Channel,
ChatID: msg.ChatID,
SenderID: msg.SenderID,
SenderDisplayName: msg.Sender.DisplayName,
UserMessage: msg.Content,
Media: msg.Media,
DefaultResponse: defaultResponse,
EnableSummary: true,
SendResponse: false,
}
// context-dependent commands check their own Runtime fields and report
@@ -641,7 +785,8 @@ func (al *AgentLoop) processMessage(ctx context.Context, msg bus.InboundMessage)
}
func (al *AgentLoop) resolveMessageRoute(msg bus.InboundMessage) (routing.ResolvedRoute, *AgentInstance, error) {
route := al.registry.ResolveRoute(routing.RouteInput{
registry := al.GetRegistry()
route := registry.ResolveRoute(routing.RouteInput{
Channel: msg.Channel,
AccountID: inboundMetadata(msg, metadataKeyAccountID),
Peer: extractPeer(msg),
@@ -650,9 +795,9 @@ func (al *AgentLoop) resolveMessageRoute(msg bus.InboundMessage) (routing.Resolv
TeamID: inboundMetadata(msg, metadataKeyTeamID),
})
agent, ok := al.registry.GetAgent(route.AgentID)
agent, ok := registry.GetAgent(route.AgentID)
if !ok {
agent = al.registry.GetDefaultAgent()
agent = registry.GetDefaultAgent()
}
if agent == nil {
return routing.ResolvedRoute{}, nil, fmt.Errorf("no agent available for route (agent_id=%s)", route.AgentID)
@@ -714,7 +859,7 @@ func (al *AgentLoop) processSystemMessage(
}
// Use default agent for system messages
agent := al.registry.GetDefaultAgent()
agent := al.GetRegistry().GetDefaultAgent()
if agent == nil {
return "", fmt.Errorf("no default agent for system message")
}
@@ -767,10 +912,13 @@ func (al *AgentLoop) runAgentLoop(
opts.Media,
opts.Channel,
opts.ChatID,
opts.SenderID,
opts.SenderDisplayName,
)
// Resolve media:// refs to base64 data URLs (streaming)
maxMediaSize := al.cfg.Agents.Defaults.GetMaxMediaSize()
// Resolve media:// refs: images→base64 data URLs, non-images→local paths in content
cfg := al.GetConfig()
maxMediaSize := cfg.Agents.Defaults.GetMaxMediaSize()
messages = resolveMediaRefs(messages, al.mediaStore, maxMediaSize)
// 2. Save user message to session
@@ -906,6 +1054,19 @@ func (al *AgentLoop) runLLMIteration(
// Build tool definitions
providerToolDefs := agent.Tools.ToProviderDefs()
// Determine whether the provider's native web search should replace
// the client-side web_search tool for this request. Only enable when web
// search is actually enabled and registered (so users who disabled web
// access do not get provider-side search or billing).
_, hasWebSearch := agent.Tools.Get("web_search")
useNativeSearch := al.cfg.Tools.Web.PreferNative &&
isNativeSearchProvider(agent.Provider) &&
hasWebSearch
if useNativeSearch {
providerToolDefs = filterClientWebSearch(providerToolDefs)
}
// Log LLM request details
logger.DebugCF("agent", "LLM request",
map[string]any{
@@ -914,6 +1075,7 @@ func (al *AgentLoop) runLLMIteration(
"model": activeModel,
"messages_count": len(messages),
"tools_count": len(providerToolDefs),
"native_search": useNativeSearch,
"max_tokens": agent.MaxTokens,
"temperature": agent.Temperature,
"system_prompt_len": len(messages[0].Content),
@@ -936,6 +1098,9 @@ func (al *AgentLoop) runLLMIteration(
"temperature": agent.Temperature,
"prompt_cache_key": agent.ID,
}
if useNativeSearch {
llmOpts["native_search"] = true
}
// parseThinkingLevel guarantees ThinkingOff for empty/unknown values,
// so checking != ThinkingOff is sufficient.
if agent.ThinkingLevel != ThinkingOff {
@@ -948,6 +1113,9 @@ func (al *AgentLoop) runLLMIteration(
}
callLLM := func() (*providers.LLMResponse, error) {
al.activeRequests.Add(1)
defer al.activeRequests.Done()
if len(activeCandidates) > 1 && al.fallback != nil {
fbResult, fbErr := al.fallback.Execute(
ctx,
@@ -1034,7 +1202,7 @@ func (al *AgentLoop) runLLMIteration(
newSummary := agent.Sessions.GetSummary(opts.SessionKey)
messages = agent.ContextBuilder.BuildMessages(
newHistory, newSummary, "",
nil, opts.Channel, opts.ChatID,
nil, opts.Channel, opts.ChatID, opts.SenderID, opts.SenderDisplayName,
)
continue
}
@@ -1046,6 +1214,7 @@ func (al *AgentLoop) runLLMIteration(
map[string]any{
"agent_id": agent.ID,
"iteration": iteration,
"model": activeModel,
"error": err.Error(),
})
return "", iteration, fmt.Errorf("LLM call failed after retries: %w", err)
@@ -1397,7 +1566,8 @@ func (al *AgentLoop) forceCompression(agent *AgentInstance, sessionKey string) {
func (al *AgentLoop) GetStartupInfo() map[string]any {
info := make(map[string]any)
agent := al.registry.GetDefaultAgent()
registry := al.GetRegistry()
agent := registry.GetDefaultAgent()
if agent == nil {
return info
}
@@ -1414,8 +1584,8 @@ func (al *AgentLoop) GetStartupInfo() map[string]any {
// Agents info
info["agents"] = map[string]any{
"count": len(al.registry.ListAgentIDs()),
"ids": al.registry.ListAgentIDs(),
"count": len(registry.ListAgentIDs()),
"ids": registry.ListAgentIDs(),
}
return info
@@ -1603,17 +1773,22 @@ func (al *AgentLoop) retryLLMCall(
var err error
for attempt := 0; attempt < maxRetries; attempt++ {
resp, err = agent.Provider.Chat(
ctx,
[]providers.Message{{Role: "user", Content: prompt}},
nil,
agent.Model,
map[string]any{
"max_tokens": agent.MaxTokens,
"temperature": llmTemperature,
"prompt_cache_key": agent.ID,
},
)
al.activeRequests.Add(1)
resp, err = func() (*providers.LLMResponse, error) {
defer al.activeRequests.Done()
return agent.Provider.Chat(
ctx,
[]providers.Message{{Role: "user", Content: prompt}},
nil,
agent.Model,
map[string]any{
"max_tokens": agent.MaxTokens,
"temperature": llmTemperature,
"prompt_cache_key": agent.ID,
},
)
}()
if err == nil && resp != nil && resp.Content != "" {
return resp, nil
}
@@ -1746,9 +1921,11 @@ func (al *AgentLoop) handleCommand(
}
func (al *AgentLoop) buildCommandsRuntime(agent *AgentInstance, opts *processOptions) *commands.Runtime {
registry := al.GetRegistry()
cfg := al.GetConfig()
rt := &commands.Runtime{
Config: al.cfg,
ListAgentIDs: al.registry.ListAgentIDs,
Config: cfg,
ListAgentIDs: registry.ListAgentIDs,
ListDefinitions: al.cmdRegistry.Definitions,
GetEnabledChannels: func() []string {
if al.channelManager == nil {
@@ -1766,9 +1943,15 @@ func (al *AgentLoop) buildCommandsRuntime(agent *AgentInstance, opts *processOpt
return nil
},
}
rt.ReloadConfig = func() error {
if al.reloadFunc == nil {
return fmt.Errorf("reload not configured")
}
return al.reloadFunc()
}
if agent != nil {
rt.GetModelInfo = func() (string, string) {
return agent.Model, al.cfg.Agents.Defaults.Provider
return agent.Model, cfg.Agents.Defaults.Provider
}
rt.SwitchModel = func(value string) (string, error) {
oldModel := agent.Model
@@ -1832,3 +2015,38 @@ func extractParentPeer(msg bus.InboundMessage) *routing.RoutePeer {
}
return &routing.RoutePeer{Kind: parentKind, ID: parentID}
}
// isNativeSearchProvider reports whether p implements
// providers.NativeSearchCapable and its SupportsNativeSearch method returns
// true. Providers that do not implement the interface yield false.
func isNativeSearchProvider(p providers.LLMProvider) bool {
	capable, implemented := p.(providers.NativeSearchCapable)
	return implemented && capable.SupportsNativeSearch()
}
// filterClientWebSearch builds a new slice holding every tool definition from
// tools except those whose function name equals "web_search" (compared
// case-insensitively). The input slice is not modified. Used when native
// provider search is preferred.
func filterClientWebSearch(tools []providers.ToolDefinition) []providers.ToolDefinition {
	kept := make([]providers.ToolDefinition, 0, len(tools))
	for i := range tools {
		if !strings.EqualFold(tools[i].Function.Name, "web_search") {
			kept = append(kept, tools[i])
		}
	}
	return kept
}
// extractProvider returns the provider of the registry's default agent, for
// use during cleanup. The boolean is false when registry is nil or it has no
// default agent; otherwise it is true, whatever the Provider field holds.
func extractProvider(registry *AgentRegistry) (providers.LLMProvider, bool) {
	if registry != nil {
		if agent := registry.GetDefaultAgent(); agent != nil {
			return agent.Provider, true
		}
	}
	return nil, false
}
+16
View File
@@ -63,6 +63,22 @@ func (al *AgentLoop) ensureMCPInitialized(ctx context.Context) error {
return nil
}
if al.cfg.Tools.MCP.Servers == nil || len(al.cfg.Tools.MCP.Servers) == 0 {
logger.WarnCF("agent", "MCP is enabled but no servers are configured, skipping MCP initialization", nil)
return nil
}
findValidServer := false
for _, serverCfg := range al.cfg.Tools.MCP.Servers {
if serverCfg.Enabled {
findValidServer = true
}
}
if !findValidServer {
logger.WarnCF("agent", "MCP is enabled but no valid servers are configured, skipping MCP initialization", nil)
return nil
}
al.mcp.initOnce.Do(func() {
mcpManager := mcp.NewManager()
+108 -50
View File
@@ -20,9 +20,10 @@ import (
"github.com/sipeed/picoclaw/pkg/providers"
)
// resolveMediaRefs replaces media:// refs in message Media fields with base64 data URLs.
// Uses streaming base64 encoding (file handle → encoder → buffer) to avoid holding
// both raw bytes and encoded string in memory simultaneously.
// resolveMediaRefs resolves media:// refs in messages.
// Images are base64-encoded into the Media array for multimodal LLMs.
// Non-image files (documents, audio, video) have their local path injected
// into Content so the agent can access them via file tools like read_file.
// Returns a new slice; original messages are not mutated.
func resolveMediaRefs(messages []providers.Message, store media.MediaStore, maxSize int) []providers.Message {
if store == nil {
@@ -38,6 +39,8 @@ func resolveMediaRefs(messages []providers.Message, store media.MediaStore, maxS
}
resolved := make([]string, 0, len(m.Media))
var pathTags []string
for _, ref := range m.Media {
if !strings.HasPrefix(ref, "media://") {
resolved = append(resolved, ref)
@@ -61,62 +64,117 @@ func resolveMediaRefs(messages []providers.Message, store media.MediaStore, maxS
})
continue
}
if info.Size() > int64(maxSize) {
logger.WarnCF("agent", "Media file too large, skipping", map[string]any{
"path": localPath,
"size": info.Size(),
"max_size": maxSize,
})
continue
}
// Determine MIME type: prefer metadata, fallback to magic-bytes detection
mime := meta.ContentType
if mime == "" {
kind, ftErr := filetype.MatchFile(localPath)
if ftErr != nil || kind == filetype.Unknown {
logger.WarnCF("agent", "Unknown media type, skipping", map[string]any{
"path": localPath,
})
continue
mime := detectMIME(localPath, meta)
if strings.HasPrefix(mime, "image/") {
dataURL := encodeImageToDataURL(localPath, mime, info, maxSize)
if dataURL != "" {
resolved = append(resolved, dataURL)
}
mime = kind.MIME.Value
}
// Streaming base64: open file → base64 encoder → buffer
// Peak memory: ~1.33x file size (buffer only, no raw bytes copy)
f, err := os.Open(localPath)
if err != nil {
logger.WarnCF("agent", "Failed to open media file", map[string]any{
"path": localPath,
"error": err.Error(),
})
continue
}
prefix := "data:" + mime + ";base64,"
encodedLen := base64.StdEncoding.EncodedLen(int(info.Size()))
var buf bytes.Buffer
buf.Grow(len(prefix) + encodedLen)
buf.WriteString(prefix)
encoder := base64.NewEncoder(base64.StdEncoding, &buf)
if _, err := io.Copy(encoder, f); err != nil {
f.Close()
logger.WarnCF("agent", "Failed to encode media file", map[string]any{
"path": localPath,
"error": err.Error(),
})
continue
}
encoder.Close()
f.Close()
resolved = append(resolved, buf.String())
pathTags = append(pathTags, buildPathTag(mime, localPath))
}
result[i].Media = resolved
if len(pathTags) > 0 {
result[i].Content = injectPathTags(result[i].Content, pathTags)
}
}
return result
}
// detectMIME resolves the MIME type for the file at localPath. A non-empty
// meta.ContentType is returned as-is; otherwise the type is detected with
// filetype.MatchFile. An empty string is returned when detection errors or
// the type is unknown.
func detectMIME(localPath string, meta media.MediaMeta) string {
	if declared := meta.ContentType; declared != "" {
		return declared
	}
	if kind, err := filetype.MatchFile(localPath); err == nil && kind != filetype.Unknown {
		return kind.MIME.Value
	}
	return ""
}
// encodeImageToDataURL turns the image at localPath into a
// "data:<mime>;base64,<payload>" URL.
//
// The empty string is returned (after logging a warning) when info reports a
// size larger than maxSize, when the file cannot be opened, or when copying
// the file through the base64 encoder fails.
func encodeImageToDataURL(localPath, mime string, info os.FileInfo, maxSize int) string {
	size := info.Size()
	if size > int64(maxSize) {
		logger.WarnCF("agent", "Media file too large, skipping", map[string]any{
			"path":     localPath,
			"size":     size,
			"max_size": maxSize,
		})
		return ""
	}

	src, openErr := os.Open(localPath)
	if openErr != nil {
		logger.WarnCF("agent", "Failed to open media file", map[string]any{
			"path":  localPath,
			"error": openErr.Error(),
		})
		return ""
	}
	defer src.Close()

	// Reserve room for the header plus the encoded payload up front so the
	// buffer does not have to grow while the file is streamed through.
	header := "data:" + mime + ";base64,"
	var out bytes.Buffer
	out.Grow(len(header) + base64.StdEncoding.EncodedLen(int(size)))
	out.WriteString(header)

	enc := base64.NewEncoder(base64.StdEncoding, &out)
	_, copyErr := io.Copy(enc, src)
	if copyErr != nil {
		logger.WarnCF("agent", "Failed to encode media file", map[string]any{
			"path":  localPath,
			"error": copyErr.Error(),
		})
		return ""
	}
	// Close flushes any partially written base64 block into the buffer.
	enc.Close()

	return out.String()
}
// buildPathTag wraps a local file path in a bracketed tag whose label is
// picked from the MIME type: "audio/..." yields [audio:/path], "video/..."
// yields [video:/path], and anything else yields [file:/path].
func buildPathTag(mime, localPath string) string {
	label := "file"
	if strings.HasPrefix(mime, "audio/") {
		label = "audio"
	} else if strings.HasPrefix(mime, "video/") {
		label = "video"
	}
	return "[" + label + ":" + localPath + "]"
}
// injectPathTags merges path-bearing tags into content.
//
// For each tag, the first occurrence of its generic placeholder ([audio],
// [video] or [file], chosen from the tag's prefix) is swapped for the tag.
// When no placeholder applies or none is present, the tag becomes the whole
// content if content is empty, and is otherwise appended after a space.
func injectPathTags(content string, tags []string) string {
	for _, pathTag := range tags {
		placeholder := ""
		for _, label := range [...]string{"audio", "video", "file"} {
			if strings.HasPrefix(pathTag, "["+label+":") {
				placeholder = "[" + label + "]"
				break
			}
		}

		if placeholder != "" {
			// Splice only the first occurrence so that further tags of the
			// same kind can claim later placeholders.
			if at := strings.Index(content, placeholder); at >= 0 {
				content = content[:at] + pathTag + content[at+len(placeholder):]
				continue
			}
		}

		if content == "" {
			content = pathTag
			continue
		}
		content += " " + pathTag
	}
	return content
}
+385 -44
View File
@@ -30,6 +30,28 @@ func (f *fakeChannel) IsAllowed(string) bool {
func (f *fakeChannel) IsAllowedSender(sender bus.SenderInfo) bool { return true }
func (f *fakeChannel) ReasoningChannelID() string { return f.id }
// recordingProvider is a test provider that keeps a copy of the messages
// passed to the most recent Chat call so tests can inspect them.
type recordingProvider struct {
	lastMessages []providers.Message
}

// Chat snapshots messages into lastMessages and returns a fixed
// "Mock response" with an empty tool-call list.
func (r *recordingProvider) Chat(
	ctx context.Context,
	messages []providers.Message,
	tools []providers.ToolDefinition,
	model string,
	opts map[string]any,
) (*providers.LLMResponse, error) {
	// Copy rather than alias the slice header handed in by the caller.
	var snapshot []providers.Message
	snapshot = append(snapshot, messages...)
	r.lastMessages = snapshot

	reply := &providers.LLMResponse{
		Content:   "Mock response",
		ToolCalls: []providers.ToolCall{},
	}
	return reply, nil
}

// GetDefaultModel returns the fixed model name "mock-model".
func (r *recordingProvider) GetDefaultModel() string {
	return "mock-model"
}
func newTestAgentLoop(
t *testing.T,
) (al *AgentLoop, cfg *config.Config, msgBus *bus.MessageBus, provider *mockProvider, cleanup func()) {
@@ -54,6 +76,59 @@ func newTestAgentLoop(
return al, cfg, msgBus, provider, func() { os.RemoveAll(tmpDir) }
}
// TestProcessMessage_IncludesCurrentSenderInDynamicContext checks that
// processMessage puts a "Current Sender" section, built from the inbound
// sender's display name and ID, into the first message sent to the provider,
// while the final (user) message is passed through unchanged.
func TestProcessMessage_IncludesCurrentSenderInDynamicContext(t *testing.T) {
	workspace, err := os.MkdirTemp("", "agent-test-*")
	if err != nil {
		t.Fatalf("Failed to create temp dir: %v", err)
	}
	defer os.RemoveAll(workspace)

	defaults := config.AgentDefaults{
		Workspace:         workspace,
		Model:             "test-model",
		MaxTokens:         4096,
		MaxToolIterations: 10,
	}
	cfg := &config.Config{
		Agents: config.AgentsConfig{Defaults: defaults},
	}

	msgBus := bus.NewMessageBus()
	recorder := &recordingProvider{}
	loop := NewAgentLoop(cfg, msgBus, recorder)

	inbound := bus.InboundMessage{
		Channel:  "discord",
		SenderID: "discord:123",
		Sender: bus.SenderInfo{
			DisplayName: "Alice",
		},
		ChatID:  "group-1",
		Content: "hello",
	}
	reply, err := loop.processMessage(context.Background(), inbound)
	if err != nil {
		t.Fatalf("processMessage() error = %v", err)
	}
	if reply != "Mock response" {
		t.Fatalf("processMessage() response = %q, want %q", reply, "Mock response")
	}

	sent := recorder.lastMessages
	if len(sent) == 0 {
		t.Fatal("provider did not receive any messages")
	}

	// The first message carries the system prompt with the dynamic context.
	wantSender := "## Current Sender\nCurrent sender: Alice (ID: discord:123)"
	if systemPrompt := sent[0].Content; !strings.Contains(systemPrompt, wantSender) {
		t.Fatalf("system prompt missing sender context %q:\n%s", wantSender, systemPrompt)
	}

	final := sent[len(sent)-1]
	if final.Role != "user" || final.Content != "hello" {
		t.Fatalf("last provider message = %+v, want unchanged user message", final)
	}
}
func TestRecordLastChannel(t *testing.T) {
al, cfg, msgBus, provider, cleanup := newTestAgentLoop(t)
defer cleanup()
@@ -770,13 +845,18 @@ func TestAgentLoop_ContextExhaustionRetry(t *testing.T) {
}
}
func TestProcessDirectWithChannel_InitializesMCPInAgentMode(t *testing.T) {
// TestProcessDirectWithChannel_TriggersMCPInitialization verifies that
// ProcessDirectWithChannel triggers MCP initialization when MCP is enabled.
// Note: Manager is only initialized when at least one MCP server is configured
// and successfully connected.
func TestProcessDirectWithChannel_TriggersMCPInitialization(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "agent-test-*")
if err != nil {
t.Fatalf("Failed to create temp dir: %v", err)
}
defer os.RemoveAll(tmpDir)
// Test with MCP enabled but no servers - should not initialize manager
cfg := &config.Config{
Agents: config.AgentsConfig{
Defaults: config.AgentDefaults{
@@ -791,6 +871,7 @@ func TestProcessDirectWithChannel_InitializesMCPInAgentMode(t *testing.T) {
ToolConfig: config.ToolConfig{
Enabled: true,
},
// No servers configured - manager should not be initialized
},
},
}
@@ -815,8 +896,9 @@ func TestProcessDirectWithChannel_InitializesMCPInAgentMode(t *testing.T) {
t.Fatalf("ProcessDirectWithChannel failed: %v", err)
}
if !al.mcp.hasManager() {
t.Fatal("expected MCP manager to be initialized in direct agent mode")
// Manager should not be initialized when no servers are configured
if al.mcp.hasManager() {
t.Fatal("expected MCP manager to be nil when no servers are configured")
}
}
@@ -915,10 +997,25 @@ func TestHandleReasoning(t *testing.T) {
al, msgBus := newLoop(t)
al.handleReasoning(context.Background(), "reasoning", "telegram", "")
ctx, cancel := context.WithTimeout(context.Background(), 20*time.Millisecond)
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
if msg, ok := msgBus.SubscribeOutbound(ctx); ok {
t.Fatalf("expected no outbound message, got %+v", msg)
for {
select {
case msg, ok := <-msgBus.OutboundChan():
if !ok {
t.Fatalf("expected no outbound message, got %+v", msg)
}
if msg.Content == "reasoning" {
t.Fatalf("expected no message for empty chatID, got %+v", msg)
}
return
case <-ctx.Done():
t.Log("expected an outbound message, got none within timeout")
return
default:
// Continue to check for message
time.Sleep(5 * time.Millisecond) // Avoid busy loop
}
}
})
@@ -926,9 +1023,7 @@ func TestHandleReasoning(t *testing.T) {
al, msgBus := newLoop(t)
al.handleReasoning(context.Background(), "hello reasoning", "slack", "channel-1")
ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
defer cancel()
msg, ok := msgBus.SubscribeOutbound(ctx)
msg, ok := <-msgBus.OutboundChan()
if !ok {
t.Fatal("expected an outbound message")
}
@@ -942,35 +1037,52 @@ func TestHandleReasoning(t *testing.T) {
reasoning := "hello telegram reasoning"
al.handleReasoning(context.Background(), reasoning, "telegram", "tg-chat")
ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
msg, ok := msgBus.SubscribeOutbound(ctx)
if !ok {
t.Fatal("expected outbound message")
}
for {
select {
case <-ctx.Done():
t.Fatal("expected an outbound message, got none within timeout")
return
case msg, ok := <-msgBus.OutboundChan():
if !ok {
t.Fatal("expected outbound message")
}
if msg.Channel != "telegram" {
t.Fatalf("expected telegram channel message, got %+v", msg)
}
if msg.ChatID != "tg-chat" {
t.Fatalf("expected chatID tg-chat, got %+v", msg)
}
if msg.Content != reasoning {
t.Fatalf("content mismatch: got %q want %q", msg.Content, reasoning)
if msg.Channel != "telegram" {
t.Fatalf("expected telegram channel message, got %+v", msg)
}
if msg.ChatID != "tg-chat" {
t.Fatalf("expected chatID tg-chat, got %+v", msg)
}
if msg.Content != reasoning {
t.Fatalf("content mismatch: got %q want %q", msg.Content, reasoning)
}
return
}
}
})
t.Run("expired ctx", func(t *testing.T) {
al, msgBus := newLoop(t)
reasoning := "hello telegram reasoning"
ctx, cancel := context.WithCancel(context.Background())
cancel()
al.handleReasoning(ctx, reasoning, "telegram", "tg-chat")
ctx, cancel = context.WithTimeout(context.Background(), 200*time.Millisecond)
defer cancel()
msg, ok := msgBus.SubscribeOutbound(ctx)
if ok {
t.Fatalf("expected no outbound message, got %+v", msg)
al.handleReasoning(context.Background(), reasoning, "telegram", "tg-chat")
consumeCtx, consumeCancel := context.WithTimeout(context.Background(), 2*time.Second)
defer consumeCancel()
for {
select {
case msg, ok := <-msgBus.OutboundChan():
if !ok {
t.Fatalf("expected no outbound message, but received: %+v", msg)
}
t.Logf("Received unexpected outbound message: %+v", msg)
return
case <-consumeCtx.Done():
t.Fatalf("failed: no message received within timeout")
return
}
}
})
@@ -1010,20 +1122,23 @@ func TestHandleReasoning(t *testing.T) {
// Drain the bus and verify the reasoning message was NOT published
// (it should have been dropped due to timeout).
drainCtx, drainCancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
defer drainCancel()
foundReasoning := false
timeer := time.After(1 * time.Second)
for {
msg, ok := msgBus.SubscribeOutbound(drainCtx)
if !ok {
break
select {
case <-timeer:
t.Logf(
"no reasoning message received after draining bus for 1s, as expected,length=%d",
len(msgBus.OutboundChan()),
)
return
case msg, ok := <-msgBus.OutboundChan():
if !ok {
break
}
if msg.Content == "should timeout" {
t.Fatal("expected reasoning message to be dropped when bus is full, but it was published")
}
}
if msg.Content == "should timeout" {
foundReasoning = true
}
}
if foundReasoning {
t.Fatal("expected reasoning message to be dropped when bus is full, but it was published")
}
})
}
@@ -1088,7 +1203,7 @@ func TestResolveMediaRefs_SkipsOversizedFile(t *testing.T) {
}
}
func TestResolveMediaRefs_SkipsUnknownType(t *testing.T) {
func TestResolveMediaRefs_UnknownTypeInjectsPath(t *testing.T) {
store := media.NewFileMediaStore()
dir := t.TempDir()
@@ -1104,7 +1219,11 @@ func TestResolveMediaRefs_SkipsUnknownType(t *testing.T) {
result := resolveMediaRefs(messages, store, config.DefaultMaxMediaSize)
if len(result[0].Media) != 0 {
t.Fatalf("expected 0 media (unknown type), got %d", len(result[0].Media))
t.Fatalf("expected 0 media entries, got %d", len(result[0].Media))
}
expected := "hi [file:" + txtPath + "]"
if result[0].Content != expected {
t.Fatalf("expected content %q, got %q", expected, result[0].Content)
}
}
@@ -1166,3 +1285,225 @@ func TestResolveMediaRefs_UsesMetaContentType(t *testing.T) {
t.Fatalf("expected jpeg prefix, got %q", result[0].Media[0][:30])
}
}
// TestResolveMediaRefs_PDFInjectsFilePath checks that a stored PDF is not
// emitted as a media entry and that the generic [file] tag in the message
// content is replaced with a [file:<path>] tag for the PDF.
func TestResolveMediaRefs_PDFInjectsFilePath(t *testing.T) {
	store := media.NewFileMediaStore()
	dir := t.TempDir()

	pdfPath := filepath.Join(dir, "report.pdf")
	// PDF magic bytes
	if err := os.WriteFile(pdfPath, []byte("%PDF-1.4 test content"), 0o644); err != nil {
		t.Fatalf("failed to write fixture %q: %v", pdfPath, err)
	}
	// Fail fast on setup errors so a broken fixture is not misreported as a
	// content mismatch further down.
	ref, err := store.Store(pdfPath, media.MediaMeta{ContentType: "application/pdf"}, "test")
	if err != nil {
		t.Fatalf("failed to store fixture %q: %v", pdfPath, err)
	}

	messages := []providers.Message{
		{Role: "user", Content: "report.pdf [file]", Media: []string{ref}},
	}
	result := resolveMediaRefs(messages, store, config.DefaultMaxMediaSize)

	if len(result[0].Media) != 0 {
		t.Fatalf("expected 0 media (non-image), got %d", len(result[0].Media))
	}
	expected := "report.pdf [file:" + pdfPath + "]"
	if result[0].Content != expected {
		t.Fatalf("expected content %q, got %q", expected, result[0].Content)
	}
}
// TestResolveMediaRefs_AudioInjectsAudioPath checks that a file stored with
// an audio content type produces no media entry and that the generic [audio]
// tag in the content is replaced with an [audio:<path>] tag.
func TestResolveMediaRefs_AudioInjectsAudioPath(t *testing.T) {
	store := media.NewFileMediaStore()
	dir := t.TempDir()

	oggPath := filepath.Join(dir, "voice.ogg")
	if err := os.WriteFile(oggPath, []byte("fake audio"), 0o644); err != nil {
		t.Fatalf("failed to write fixture %q: %v", oggPath, err)
	}
	// Fail fast on setup errors so a broken fixture is not misreported as a
	// content mismatch further down.
	ref, err := store.Store(oggPath, media.MediaMeta{ContentType: "audio/ogg"}, "test")
	if err != nil {
		t.Fatalf("failed to store fixture %q: %v", oggPath, err)
	}

	messages := []providers.Message{
		{Role: "user", Content: "voice.ogg [audio]", Media: []string{ref}},
	}
	result := resolveMediaRefs(messages, store, config.DefaultMaxMediaSize)

	if len(result[0].Media) != 0 {
		t.Fatalf("expected 0 media, got %d", len(result[0].Media))
	}
	expected := "voice.ogg [audio:" + oggPath + "]"
	if result[0].Content != expected {
		t.Fatalf("expected content %q, got %q", expected, result[0].Content)
	}
}
// TestResolveMediaRefs_VideoInjectsVideoPath checks that a file stored with a
// video content type produces no media entry and that the generic [video] tag
// in the content is replaced with a [video:<path>] tag.
func TestResolveMediaRefs_VideoInjectsVideoPath(t *testing.T) {
	store := media.NewFileMediaStore()
	dir := t.TempDir()

	mp4Path := filepath.Join(dir, "clip.mp4")
	if err := os.WriteFile(mp4Path, []byte("fake video"), 0o644); err != nil {
		t.Fatalf("failed to write fixture %q: %v", mp4Path, err)
	}
	// Fail fast on setup errors so a broken fixture is not misreported as a
	// content mismatch further down.
	ref, err := store.Store(mp4Path, media.MediaMeta{ContentType: "video/mp4"}, "test")
	if err != nil {
		t.Fatalf("failed to store fixture %q: %v", mp4Path, err)
	}

	messages := []providers.Message{
		{Role: "user", Content: "clip.mp4 [video]", Media: []string{ref}},
	}
	result := resolveMediaRefs(messages, store, config.DefaultMaxMediaSize)

	if len(result[0].Media) != 0 {
		t.Fatalf("expected 0 media, got %d", len(result[0].Media))
	}
	expected := "clip.mp4 [video:" + mp4Path + "]"
	if result[0].Content != expected {
		t.Fatalf("expected content %q, got %q", expected, result[0].Content)
	}
}
// TestResolveMediaRefs_NoGenericTagAppendsPath checks that when the message
// content contains no generic tag, the [file:<path>] tag is appended to the
// content after a single space.
func TestResolveMediaRefs_NoGenericTagAppendsPath(t *testing.T) {
	store := media.NewFileMediaStore()
	dir := t.TempDir()

	csvPath := filepath.Join(dir, "data.csv")
	if err := os.WriteFile(csvPath, []byte("a,b,c"), 0o644); err != nil {
		t.Fatalf("failed to write fixture %q: %v", csvPath, err)
	}
	// Fail fast on setup errors so a broken fixture is not misreported as a
	// content mismatch further down.
	ref, err := store.Store(csvPath, media.MediaMeta{ContentType: "text/csv"}, "test")
	if err != nil {
		t.Fatalf("failed to store fixture %q: %v", csvPath, err)
	}

	messages := []providers.Message{
		{Role: "user", Content: "here is my data", Media: []string{ref}},
	}
	result := resolveMediaRefs(messages, store, config.DefaultMaxMediaSize)

	expected := "here is my data [file:" + csvPath + "]"
	if result[0].Content != expected {
		t.Fatalf("expected content %q, got %q", expected, result[0].Content)
	}
}
// TestResolveMediaRefs_EmptyContentGetsPathTag checks that when the message
// content is empty, the content becomes exactly the [file:<path>] tag with no
// leading separator.
func TestResolveMediaRefs_EmptyContentGetsPathTag(t *testing.T) {
	store := media.NewFileMediaStore()
	dir := t.TempDir()

	docPath := filepath.Join(dir, "doc.docx")
	if err := os.WriteFile(docPath, []byte("fake docx"), 0o644); err != nil {
		t.Fatalf("failed to write fixture %q: %v", docPath, err)
	}
	docxMIME := "application/vnd.openxmlformats-officedocument.wordprocessingml.document"
	// Fail fast on setup errors so a broken fixture is not misreported as a
	// content mismatch further down.
	ref, err := store.Store(docPath, media.MediaMeta{ContentType: docxMIME}, "test")
	if err != nil {
		t.Fatalf("failed to store fixture %q: %v", docPath, err)
	}

	messages := []providers.Message{
		{Role: "user", Content: "", Media: []string{ref}},
	}
	result := resolveMediaRefs(messages, store, config.DefaultMaxMediaSize)

	expected := "[file:" + docPath + "]"
	if result[0].Content != expected {
		t.Fatalf("expected content %q, got %q", expected, result[0].Content)
	}
}
// TestResolveMediaRefs_MixedImageAndFile checks a message carrying both an
// image and a PDF: the image is expected as the only media entry, encoded as
// a data:image/png;base64 URL, while the PDF's path replaces the generic
// [file] tag in the content.
func TestResolveMediaRefs_MixedImageAndFile(t *testing.T) {
	store := media.NewFileMediaStore()
	dir := t.TempDir()

	pngPath := filepath.Join(dir, "photo.png")
	pngHeader := []byte{
		0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A,
		0x00, 0x00, 0x00, 0x0D, 0x49, 0x48, 0x44, 0x52,
		0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x01, 0x08, 0x02,
		0x00, 0x00, 0x00, 0x90, 0x77, 0x53, 0xDE,
	}
	if err := os.WriteFile(pngPath, pngHeader, 0o644); err != nil {
		t.Fatalf("failed to write fixture %q: %v", pngPath, err)
	}
	// Fail fast on setup errors so a broken fixture is not misreported as a
	// media-count or content mismatch further down.
	imgRef, err := store.Store(pngPath, media.MediaMeta{}, "test")
	if err != nil {
		t.Fatalf("failed to store fixture %q: %v", pngPath, err)
	}

	pdfPath := filepath.Join(dir, "report.pdf")
	if err := os.WriteFile(pdfPath, []byte("%PDF-1.4 test"), 0o644); err != nil {
		t.Fatalf("failed to write fixture %q: %v", pdfPath, err)
	}
	fileRef, err := store.Store(pdfPath, media.MediaMeta{ContentType: "application/pdf"}, "test")
	if err != nil {
		t.Fatalf("failed to store fixture %q: %v", pdfPath, err)
	}

	messages := []providers.Message{
		{Role: "user", Content: "check these [file]", Media: []string{imgRef, fileRef}},
	}
	result := resolveMediaRefs(messages, store, config.DefaultMaxMediaSize)

	if len(result[0].Media) != 1 {
		t.Fatalf("expected 1 media (image only), got %d", len(result[0].Media))
	}
	if !strings.HasPrefix(result[0].Media[0], "data:image/png;base64,") {
		t.Fatal("expected image to be base64 encoded")
	}
	expectedContent := "check these [file:" + pdfPath + "]"
	if result[0].Content != expectedContent {
		t.Fatalf("expected content %q, got %q", expectedContent, result[0].Content)
	}
}
// --- Native search helper tests ---
// nativeSearchProvider is a test provider whose SupportsNativeSearch answer
// is controlled by the supported field.
type nativeSearchProvider struct {
	supported bool
}

// Chat ignores its arguments and returns a fixed "ok" response.
func (p *nativeSearchProvider) Chat(
	ctx context.Context,
	msgs []providers.Message,
	tools []providers.ToolDefinition,
	model string,
	opts map[string]any,
) (*providers.LLMResponse, error) {
	reply := &providers.LLMResponse{Content: "ok"}
	return reply, nil
}

// GetDefaultModel returns the fixed model name "test-model".
func (p *nativeSearchProvider) GetDefaultModel() string {
	return "test-model"
}

// SupportsNativeSearch reports the value of the supported field.
func (p *nativeSearchProvider) SupportsNativeSearch() bool {
	return p.supported
}
// plainProvider is a test provider that offers only Chat and GetDefaultModel;
// it declares no SupportsNativeSearch method.
type plainProvider struct{}

// Chat ignores its arguments and returns a fixed "ok" response.
func (p *plainProvider) Chat(
	ctx context.Context,
	msgs []providers.Message,
	tools []providers.ToolDefinition,
	model string,
	opts map[string]any,
) (*providers.LLMResponse, error) {
	reply := &providers.LLMResponse{Content: "ok"}
	return reply, nil
}

// GetDefaultModel returns the fixed model name "test-model".
func (p *plainProvider) GetDefaultModel() string {
	return "test-model"
}
func TestIsNativeSearchProvider_Supported(t *testing.T) {
if !isNativeSearchProvider(&nativeSearchProvider{supported: true}) {
t.Fatal("expected true for provider that supports native search")
}
}
func TestIsNativeSearchProvider_NotSupported(t *testing.T) {
if isNativeSearchProvider(&nativeSearchProvider{supported: false}) {
t.Fatal("expected false for provider that does not support native search")
}
}
func TestIsNativeSearchProvider_NoInterface(t *testing.T) {
if isNativeSearchProvider(&plainProvider{}) {
t.Fatal("expected false for provider that does not implement NativeSearchCapable")
}
}
// TestFilterClientWebSearch_RemovesWebSearch feeds three tool definitions,
// one of them named "web_search", and expects exactly the other two back.
func TestFilterClientWebSearch_RemovesWebSearch(t *testing.T) {
	toolNames := [...]string{"web_search", "read_file", "exec"}
	defs := make([]providers.ToolDefinition, 0, len(toolNames))
	for _, name := range toolNames {
		defs = append(defs, providers.ToolDefinition{
			Type:     "function",
			Function: providers.ToolFunctionDefinition{Name: name},
		})
	}

	result := filterClientWebSearch(defs)

	if got := len(result); got != 2 {
		t.Fatalf("len(result) = %d, want 2", got)
	}
	for i := range result {
		if result[i].Function.Name == "web_search" {
			t.Fatal("web_search should be filtered out")
		}
	}
}
// TestFilterClientWebSearch_NoWebSearch feeds two tool definitions, neither
// named "web_search", and expects both to be returned.
func TestFilterClientWebSearch_NoWebSearch(t *testing.T) {
	toolNames := [...]string{"read_file", "exec"}
	defs := make([]providers.ToolDefinition, 0, len(toolNames))
	for _, name := range toolNames {
		defs = append(defs, providers.ToolDefinition{
			Type:     "function",
			Function: providers.ToolFunctionDefinition{Name: name},
		})
	}

	result := filterClientWebSearch(defs)

	if got := len(result); got != 2 {
		t.Fatalf("len(result) = %d, want 2", got)
	}
}
// TestFilterClientWebSearch_EmptyInput expects a nil input to yield a result
// of length zero.
func TestFilterClientWebSearch_EmptyInput(t *testing.T) {
	result := filterClientWebSearch(nil)
	if got := len(result); got != 0 {
		t.Fatalf("len(result) = %d, want 0", got)
	}
}