chore: revert unrelated golines formatting

This commit is contained in:
afjcjsbx
2026-03-29 14:06:19 +02:00
parent 3b173c0bee
commit 07748bf076
57 changed files with 297 additions and 1278 deletions
+3 -10
View File
@@ -500,11 +500,8 @@ func TestEstimateMessageTokens_ReasoningContent(t *testing.T) {
reasoningTokens := estimateMessageTokens(withReasoning)
if reasoningTokens <= plainTokens {
t.Errorf(
"message with ReasoningContent (%d tokens) should exceed plain message (%d tokens)",
reasoningTokens,
plainTokens,
)
t.Errorf("message with ReasoningContent (%d tokens) should exceed plain message (%d tokens)",
reasoningTokens, plainTokens)
}
}
@@ -767,11 +764,7 @@ func TestEstimateMessageTokens_WithReasoningAndMedia(t *testing.T) {
tokensNoReasoning := estimateMessageTokens(msgNoReasoning)
if tokens <= tokensNoReasoning {
t.Errorf(
"reasoning content should add tokens: with=%d, without=%d",
tokens,
tokensNoReasoning,
)
t.Errorf("reasoning content should add tokens: with=%d, without=%d", tokens, tokensNoReasoning)
}
}
+3 -24
View File
@@ -82,16 +82,7 @@ func TestSingleSystemMessage(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
msgs := cb.BuildMessages(
tt.history,
tt.summary,
tt.message,
nil,
"test",
"chat1",
"",
"",
)
msgs := cb.BuildMessages(tt.history, tt.summary, tt.message, nil, "test", "chat1", "", "")
systemCount := 0
for _, m := range msgs {
@@ -177,16 +168,7 @@ func TestBuildMessages_CurrentSenderDynamicContext(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
msgs := cb.BuildMessages(
nil,
"",
"hello",
nil,
"discord",
"chat1",
tt.senderID,
tt.senderDisplayName,
)
msgs := cb.BuildMessages(nil, "", "hello", nil, "discord", "chat1", tt.senderID, tt.senderDisplayName)
sys := msgs[0].Content
if tt.wantSection {
@@ -400,10 +382,7 @@ func TestNewFileCreationInvalidatesCache(t *testing.T) {
// Cache should auto-invalidate because file went from absent -> present
sp2 := cb.BuildSystemPromptWithCache()
if !strings.Contains(sp2, tt.checkField) {
t.Errorf(
"cache not invalidated on new file creation: expected %q in prompt",
tt.checkField,
)
t.Errorf("cache not invalidated on new file creation: expected %q in prompt", tt.checkField)
}
})
}
+3 -38
View File
@@ -151,19 +151,7 @@ func TestSanitizeHistoryForProvider_MultiToolCallsThenNewRound(t *testing.T) {
if len(result) != 9 {
t.Fatalf("expected 9 messages, got %d: %+v", len(result), roles(result))
}
assertRoles(
t,
result,
"user",
"assistant",
"tool",
"tool",
"assistant",
"user",
"assistant",
"tool",
"assistant",
)
assertRoles(t, result, "user", "assistant", "tool", "tool", "assistant", "user", "assistant", "tool", "assistant")
}
func TestSanitizeHistoryForProvider_ConsecutiveMultiToolRounds(t *testing.T) {
@@ -182,18 +170,7 @@ func TestSanitizeHistoryForProvider_ConsecutiveMultiToolRounds(t *testing.T) {
if len(result) != 8 {
t.Fatalf("expected 8 messages, got %d: %+v", len(result), roles(result))
}
assertRoles(
t,
result,
"user",
"assistant",
"tool",
"tool",
"assistant",
"tool",
"tool",
"assistant",
)
assertRoles(t, result, "user", "assistant", "tool", "tool", "assistant", "tool", "tool", "assistant")
}
func TestSanitizeHistoryForProvider_PlainConversation(t *testing.T) {
@@ -327,17 +304,5 @@ func TestSanitizeHistoryForProvider_PartialToolResultsInMiddle(t *testing.T) {
if len(result) != 9 {
t.Fatalf("expected 9 messages, got %d: %+v", len(result), roles(result))
}
assertRoles(
t,
result,
"user",
"assistant",
"tool",
"assistant",
"user",
"user",
"assistant",
"tool",
"assistant",
)
assertRoles(t, result, "user", "assistant", "tool", "assistant", "user", "user", "assistant", "tool", "assistant")
}
+4 -14
View File
@@ -61,12 +61,8 @@ Act directly and use tools first.
if len(definition.Agent.Frontmatter.Skills) != 2 {
t.Fatalf("expected skills to be parsed, got %v", definition.Agent.Frontmatter.Skills)
}
if len(definition.Agent.Frontmatter.MCPServers) != 1 ||
definition.Agent.Frontmatter.MCPServers[0] != "github" {
t.Fatalf(
"expected mcpServers to be parsed, got %v",
definition.Agent.Frontmatter.MCPServers,
)
if len(definition.Agent.Frontmatter.MCPServers) != 1 || definition.Agent.Frontmatter.MCPServers[0] != "github" {
t.Fatalf("expected mcpServers to be parsed, got %v", definition.Agent.Frontmatter.MCPServers)
}
if definition.Agent.Frontmatter.Fields["metadata"] == nil {
t.Fatal("expected arbitrary frontmatter fields to remain available")
@@ -100,10 +96,7 @@ func TestLoadAgentDefinitionFallsBackToLegacyAgentsMarkdown(t *testing.T) {
t.Fatal("expected AGENTS.md to be loaded")
}
if definition.Agent.RawFrontmatter != "" {
t.Fatalf(
"legacy AGENTS.md should not have frontmatter, got %q",
definition.Agent.RawFrontmatter,
)
t.Fatalf("legacy AGENTS.md should not have frontmatter, got %q", definition.Agent.RawFrontmatter)
}
if !strings.Contains(definition.Agent.Body, "Keep compatibility") {
t.Fatalf("expected legacy body to be preserved, got %q", definition.Agent.Body)
@@ -166,10 +159,7 @@ Keep going.
len(definition.Agent.Frontmatter.Skills) != 0 ||
len(definition.Agent.Frontmatter.MCPServers) != 0 ||
len(definition.Agent.Frontmatter.Fields) != 0 {
t.Fatalf(
"expected invalid frontmatter to decode as empty struct, got %+v",
definition.Agent.Frontmatter,
)
t.Fatalf("expected invalid frontmatter to decode as empty struct, got %+v", definition.Agent.Frontmatter)
}
}
+4 -21
View File
@@ -275,13 +275,7 @@ func TestAgentLoop_EmitsSteeringAndSkippedToolEvents(t *testing.T) {
resultCh := make(chan string, 1)
go func() {
resp, _ := al.ProcessDirectWithChannel(
context.Background(),
"do something",
"test-session",
"test",
"chat1",
)
resp, _ := al.ProcessDirectWithChannel(context.Background(), "do something", "test-session", "test", "chat1")
resultCh <- resp
}()
@@ -344,11 +338,7 @@ func TestAgentLoop_EmitsSteeringAndSkippedToolEvents(t *testing.T) {
t.Fatalf("expected steering interrupt kind, got %q", interruptPayload.Kind)
}
if interruptPayload.ContentLen != len("change course") {
t.Fatalf(
"expected interrupt content len %d, got %d",
len("change course"),
interruptPayload.ContentLen,
)
t.Fatalf("expected interrupt content len %d, got %d", len("change course"), interruptPayload.ContentLen)
}
}
@@ -370,9 +360,7 @@ func TestAgentLoop_EmitsContextCompressEventOnRetry(t *testing.T) {
},
}
contextErr := stringError(
"InvalidParameter: Total tokens of image and text exceed max message tokens",
)
contextErr := stringError("InvalidParameter: Total tokens of image and text exceed max message tokens")
provider := &failFirstMockProvider{
failures: 1,
failError: contextErr,
@@ -615,12 +603,7 @@ func collectEventStream(ch <-chan Event) []Event {
}
}
func waitForEvent(
t *testing.T,
ch <-chan Event,
timeout time.Duration,
match func(Event) bool,
) Event {
func waitForEvent(t *testing.T, ch <-chan Event, timeout time.Duration, match func(Event) bool) Event {
t.Helper()
timer := time.NewTimer(timeout)
+4 -26
View File
@@ -40,11 +40,7 @@ func (h *builtinAutoHook) AfterLLM(
return next, HookDecision{Action: HookActionModify}, nil
}
func newConfiguredHookLoop(
t *testing.T,
provider *llmHookTestProvider,
hooks config.HooksConfig,
) *AgentLoop {
func newConfiguredHookLoop(t *testing.T, provider *llmHookTestProvider, hooks config.HooksConfig) *AgentLoop {
t.Helper()
cfg := &config.Config{
@@ -106,13 +102,7 @@ func TestAgentLoop_ProcessDirectWithChannel_AutoMountsBuiltinHook(t *testing.T)
})
defer al.Close()
resp, err := al.ProcessDirectWithChannel(
context.Background(),
"hello",
"session-1",
"cli",
"direct",
)
resp, err := al.ProcessDirectWithChannel(context.Background(), "hello", "session-1", "cli", "direct")
if err != nil {
t.Fatalf("ProcessDirectWithChannel failed: %v", err)
}
@@ -150,13 +140,7 @@ func TestAgentLoop_ProcessDirectWithChannel_AutoMountsProcessHook(t *testing.T)
})
defer al.Close()
resp, err := al.ProcessDirectWithChannel(
context.Background(),
"hello",
"session-1",
"cli",
"direct",
)
resp, err := al.ProcessDirectWithChannel(context.Background(), "hello", "session-1", "cli", "direct")
if err != nil {
t.Fatalf("ProcessDirectWithChannel failed: %v", err)
}
@@ -188,13 +172,7 @@ func TestAgentLoop_ProcessDirectWithChannel_InvalidConfiguredHookFails(t *testin
})
defer al.Close()
_, err := al.ProcessDirectWithChannel(
context.Background(),
"hello",
"session-1",
"cli",
"direct",
)
_, err := al.ProcessDirectWithChannel(context.Background(), "hello", "session-1", "cli", "direct")
if err == nil {
t.Fatal("expected invalid configured hook error")
}
+3 -14
View File
@@ -98,11 +98,7 @@ type processHookAfterToolResponse struct {
Result *ToolResultHookResponse `json:"result,omitempty"`
}
func NewProcessHook(
ctx context.Context,
name string,
opts ProcessHookOptions,
) (*ProcessHook, error) {
func NewProcessHook(ctx context.Context, name string, opts ProcessHookOptions) (*ProcessHook, error) {
if len(opts.Command) == 0 {
return nil, fmt.Errorf("process hook command is required")
}
@@ -266,10 +262,7 @@ func (ph *ProcessHook) AfterTool(
return resp.Result, HookDecision{Action: resp.Action, Reason: resp.Reason}, nil
}
func (ph *ProcessHook) ApproveTool(
ctx context.Context,
req *ToolApprovalRequest,
) (ApprovalDecision, error) {
func (ph *ProcessHook) ApproveTool(ctx context.Context, req *ToolApprovalRequest) (ApprovalDecision, error) {
if ph == nil || !ph.opts.ApproveTool {
return ApprovalDecision{Approved: true}, nil
}
@@ -480,11 +473,7 @@ func (ph *ProcessHook) removePending(id uint64) {
}
}
func (al *AgentLoop) MountProcessHook(
ctx context.Context,
name string,
opts ProcessHookOptions,
) error {
func (al *AgentLoop) MountProcessHook(ctx context.Context, name string, opts ProcessHookOptions) error {
if al == nil {
return fmt.Errorf("agent loop is nil")
}
+4 -16
View File
@@ -79,14 +79,8 @@ type LLMInterceptor interface {
}
type ToolInterceptor interface {
BeforeTool(
ctx context.Context,
call *ToolCallHookRequest,
) (*ToolCallHookRequest, HookDecision, error)
AfterTool(
ctx context.Context,
result *ToolResultHookResponse,
) (*ToolResultHookResponse, HookDecision, error)
BeforeTool(ctx context.Context, call *ToolCallHookRequest) (*ToolCallHookRequest, HookDecision, error)
AfterTool(ctx context.Context, result *ToolResultHookResponse) (*ToolResultHookResponse, HookDecision, error)
}
type ToolApprover interface {
@@ -301,10 +295,7 @@ func (hm *HookManager) dispatchEvents() {
}
}
func (hm *HookManager) BeforeLLM(
ctx context.Context,
req *LLMHookRequest,
) (*LLMHookRequest, HookDecision) {
func (hm *HookManager) BeforeLLM(ctx context.Context, req *LLMHookRequest) (*LLMHookRequest, HookDecision) {
if hm == nil || req == nil {
return req, HookDecision{Action: HookActionContinue}
}
@@ -335,10 +326,7 @@ func (hm *HookManager) BeforeLLM(
return current, HookDecision{Action: HookActionContinue}
}
func (hm *HookManager) AfterLLM(
ctx context.Context,
resp *LLMHookResponse,
) (*LLMHookResponse, HookDecision) {
func (hm *HookManager) AfterLLM(ctx context.Context, resp *LLMHookResponse) (*LLMHookResponse, HookDecision) {
if hm == nil || resp == nil {
return resp, HookDecision{Action: HookActionContinue}
}
+1 -4
View File
@@ -293,10 +293,7 @@ func TestAgentLoop_Hooks_ToolInterceptorCanRewrite(t *testing.T) {
type denyApprovalHook struct{}
func (h *denyApprovalHook) ApproveTool(
ctx context.Context,
req *ToolApprovalRequest,
) (ApprovalDecision, error) {
func (h *denyApprovalHook) ApproveTool(ctx context.Context, req *ToolApprovalRequest) (ApprovalDecision, error) {
return ApprovalDecision{
Approved: false,
Reason: "blocked",
+1 -5
View File
@@ -156,11 +156,7 @@ func TestNewAgentInstance_ResolveCandidatesFromModelListAlias(t *testing.T) {
t.Fatalf("len(Candidates) = %d, want 1", len(agent.Candidates))
}
if agent.Candidates[0].Provider != tt.wantProvider {
t.Fatalf(
"candidate provider = %q, want %q",
agent.Candidates[0].Provider,
tt.wantProvider,
)
t.Fatalf("candidate provider = %q, want %q", agent.Candidates[0].Provider, tt.wantProvider)
}
if agent.Candidates[0].Model != tt.wantModel {
t.Fatalf("candidate model = %q, want %q", agent.Candidates[0].Model, tt.wantModel)
+28 -107
View File
@@ -192,11 +192,7 @@ func registerSharedTools(
Proxy: cfg.Tools.Web.Proxy,
})
if err != nil {
logger.ErrorCF(
"agent",
"Failed to create web search tool",
map[string]any{"error": err.Error()},
)
logger.ErrorCF("agent", "Failed to create web search tool", map[string]any{"error": err.Error()})
} else if searchTool != nil {
agent.Tools.Register(searchTool)
}
@@ -209,11 +205,7 @@ func registerSharedTools(
cfg.Tools.Web.FetchLimitBytes,
cfg.Tools.Web.PrivateHostWhitelist)
if err != nil {
logger.ErrorCF(
"agent",
"Failed to create web fetch tool",
map[string]any{"error": err.Error()},
)
logger.ErrorCF("agent", "Failed to create web fetch tool", map[string]any{"error": err.Error()})
} else {
agent.Tools.Register(fetchTool)
}
@@ -483,12 +475,7 @@ func (al *AgentLoop) Run(ctx context.Context) error {
"queue_depth": al.pendingSteeringCountForScope(target.SessionKey),
})
continued, continueErr := al.Continue(
ctx,
target.SessionKey,
target.Channel,
target.ChatID,
)
continued, continueErr := al.Continue(ctx, target.SessionKey, target.Channel, target.ChatID)
if continueErr != nil {
logger.WarnCF("agent", "Failed to continue queued steering",
map[string]any{
@@ -516,22 +503,14 @@ func (al *AgentLoop) Run(ctx context.Context) error {
"queue_depth": al.pendingSteeringCountForScope(target.SessionKey),
})
continued, continueErr := al.Continue(
ctx,
target.SessionKey,
target.Channel,
target.ChatID,
)
continued, continueErr := al.Continue(ctx, target.SessionKey, target.Channel, target.ChatID)
if continueErr != nil {
logger.WarnCF(
"agent",
"Failed to continue queued steering after shutdown drain",
logger.WarnCF("agent", "Failed to continue queued steering after shutdown drain",
map[string]any{
"channel": target.Channel,
"chat_id": target.ChatID,
"error": continueErr.Error(),
},
)
})
return
}
if continued == "" {
@@ -586,15 +565,11 @@ func (al *AgentLoop) drainBusToSteering(ctx context.Context, activeScope, active
msgScope, _, scopeOK := al.resolveSteeringTarget(msg)
if !scopeOK || msgScope != activeScope {
if err := al.requeueInboundMessage(msg); err != nil {
logger.WarnCF(
"agent",
"Failed to requeue non-steering inbound message",
map[string]any{
"error": err.Error(),
"channel": msg.Channel,
"sender_id": msg.SenderID,
},
)
logger.WarnCF("agent", "Failed to requeue non-steering inbound message", map[string]any{
"error": err.Error(),
"channel": msg.Channel,
"sender_id": msg.SenderID,
})
}
continue
}
@@ -628,10 +603,7 @@ func (al *AgentLoop) Stop() {
al.running.Store(false)
}
func (al *AgentLoop) PublishResponseIfNeeded(
ctx context.Context,
channel, chatID, response string,
) {
func (al *AgentLoop) PublishResponseIfNeeded(ctx context.Context, channel, chatID, response string) {
if response == "" {
return
}
@@ -1081,10 +1053,7 @@ var audioAnnotationRe = regexp.MustCompile(`\[(voice|audio)(?::[^\]]*)?\]`)
// transcribeAudioInMessage resolves audio media refs, transcribes them, and
// replaces audio annotations in msg.Content with the transcribed text.
// Returns the (possibly modified) message and true if audio was transcribed.
func (al *AgentLoop) transcribeAudioInMessage(
ctx context.Context,
msg bus.InboundMessage,
) (bus.InboundMessage, bool) {
func (al *AgentLoop) transcribeAudioInMessage(ctx context.Context, msg bus.InboundMessage) (bus.InboundMessage, bool) {
if al.transcriber == nil || al.mediaStore == nil || len(msg.Media) == 0 {
return msg, false
}
@@ -1094,11 +1063,7 @@ func (al *AgentLoop) transcribeAudioInMessage(
for _, ref := range msg.Media {
path, meta, err := al.mediaStore.ResolveWithMeta(ref)
if err != nil {
logger.WarnCF(
"voice",
"Failed to resolve media ref",
map[string]any{"ref": ref, "error": err},
)
logger.WarnCF("voice", "Failed to resolve media ref", map[string]any{"ref": ref, "error": err})
continue
}
if !utils.IsAudioFile(meta.Filename, meta.ContentType) {
@@ -1176,11 +1141,7 @@ func (al *AgentLoop) sendTranscriptionFeedback(
ReplyToMessageID: messageID,
})
if err != nil {
logger.WarnCF(
"voice",
"Failed to send transcription feedback",
map[string]any{"error": err.Error()},
)
logger.WarnCF("voice", "Failed to send transcription feedback", map[string]any{"error": err.Error()})
}
}
@@ -1381,9 +1342,7 @@ func (al *AgentLoop) processMessage(ctx context.Context, msg bus.InboundMessage)
return al.runAgentLoop(ctx, agent, opts)
}
func (al *AgentLoop) resolveMessageRoute(
msg bus.InboundMessage,
) (routing.ResolvedRoute, *AgentInstance, error) {
func (al *AgentLoop) resolveMessageRoute(msg bus.InboundMessage) (routing.ResolvedRoute, *AgentInstance, error) {
registry := al.GetRegistry()
route := registry.ResolveRoute(routing.RouteInput{
Channel: msg.Channel,
@@ -1399,10 +1358,7 @@ func (al *AgentLoop) resolveMessageRoute(
agent = registry.GetDefaultAgent()
}
if agent == nil {
return routing.ResolvedRoute{}, nil, fmt.Errorf(
"no agent available for route (agent_id=%s)",
route.AgentID,
)
return routing.ResolvedRoute{}, nil, fmt.Errorf("no agent available for route (agent_id=%s)", route.AgentID)
}
return route, agent, nil
@@ -1727,11 +1683,7 @@ func (al *AgentLoop) runTurn(ctx context.Context, ts *turnState) (turnResult, er
ts.recordPersistedMessage(rootMsg)
}
activeCandidates, activeModel, usedLight := al.selectCandidates(
ts.agent,
ts.userMessage,
messages,
)
activeCandidates, activeModel, usedLight := al.selectCandidates(ts.agent, ts.userMessage, messages)
activeProvider := ts.agent.Provider
if usedLight && ts.agent.LightProvider != nil {
activeProvider = ts.agent.LightProvider
@@ -2704,15 +2656,12 @@ turnLoop:
}
if steerMsgs := al.dequeueSteeringMessagesForScope(ts.sessionKey); len(steerMsgs) > 0 {
logger.InfoCF(
"agent",
"Steering arrived after turn completion; continuing turn before finalizing",
logger.InfoCF("agent", "Steering arrived after turn completion; continuing turn before finalizing",
map[string]any{
"agent_id": ts.agent.ID,
"steering_count": len(steerMsgs),
"session_key": ts.sessionKey,
},
)
})
pendingMessages = append(pendingMessages, steerMsgs...)
finalContent = ""
goto turnLoop
@@ -2828,18 +2777,11 @@ func (al *AgentLoop) selectCandidates(
"score": score,
"threshold": agent.Router.Threshold(),
})
return agent.LightCandidates, resolvedCandidateModel(
agent.LightCandidates,
agent.Router.LightModel(),
), true
return agent.LightCandidates, resolvedCandidateModel(agent.LightCandidates, agent.Router.LightModel()), true
}
// maybeSummarize triggers summarization if the session history exceeds thresholds.
func (al *AgentLoop) maybeSummarize(
agent *AgentInstance,
sessionKey string,
turnScope turnEventScope,
) {
func (al *AgentLoop) maybeSummarize(agent *AgentInstance, sessionKey string, turnScope turnEventScope) {
newHistory := agent.Sessions.GetHistory(sessionKey)
tokenEstimate := al.estimateTokens(newHistory)
threshold := agent.ContextWindow * agent.SummarizeTokenPercent / 100
@@ -2873,10 +2815,7 @@ type compressionResult struct {
// prompt is built dynamically by BuildMessages and is NOT stored here.
// The compression note is recorded in the session summary so that
// BuildMessages can include it in the next system prompt.
func (al *AgentLoop) forceCompression(
agent *AgentInstance,
sessionKey string,
) (compressionResult, bool) {
func (al *AgentLoop) forceCompression(agent *AgentInstance, sessionKey string) (compressionResult, bool) {
history := agent.Sessions.GetHistory(sessionKey)
if len(history) <= 2 {
return compressionResult{}, false
@@ -3029,11 +2968,7 @@ func formatToolsForLog(toolDefs []providers.ToolDefinition) string {
}
// summarizeSession summarizes the conversation history for a session.
func (al *AgentLoop) summarizeSession(
agent *AgentInstance,
sessionKey string,
turnScope turnEventScope,
) {
func (al *AgentLoop) summarizeSession(agent *AgentInstance, sessionKey string, turnScope turnEventScope) {
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
defer cancel()
@@ -3385,10 +3320,7 @@ func (al *AgentLoop) applyExplicitSkillCommand(
skillName, ok := agent.ContextBuilder.ResolveSkillName(arg)
if !ok {
return true, true, fmt.Sprintf(
"Unknown skill: %s\nUse /list skills to see installed skills.",
arg,
)
return true, true, fmt.Sprintf("Unknown skill: %s\nUse /list skills to see installed skills.", arg)
}
if len(parts) < 3 {
@@ -3415,10 +3347,7 @@ func (al *AgentLoop) applyExplicitSkillCommand(
return true, false, ""
}
func (al *AgentLoop) buildCommandsRuntime(
agent *AgentInstance,
opts *processOptions,
) *commands.Runtime {
func (al *AgentLoop) buildCommandsRuntime(agent *AgentInstance, opts *processOptions) *commands.Runtime {
registry := al.GetRegistry()
cfg := al.GetConfig()
rt := &commands.Runtime{
@@ -3462,10 +3391,7 @@ func (al *AgentLoop) buildCommandsRuntime(
rt.ListSkillNames = agent.ContextBuilder.ListSkillNames
}
rt.GetModelInfo = func() (string, string) {
return agent.Model, resolvedCandidateProvider(
agent.Candidates,
cfg.Agents.Defaults.Provider,
)
return agent.Model, resolvedCandidateProvider(agent.Candidates, cfg.Agents.Defaults.Provider)
}
rt.SwitchModel = func(value string) (string, error) {
value = strings.TrimSpace(value)
@@ -3479,12 +3405,7 @@ func (al *AgentLoop) buildCommandsRuntime(
return "", fmt.Errorf("failed to initialize model %q: %w", value, err)
}
nextCandidates := resolveModelCandidates(
cfg,
cfg.Agents.Defaults.Provider,
modelCfg.Model,
agent.Fallbacks,
)
nextCandidates := resolveModelCandidates(cfg, cfg.Agents.Defaults.Provider, modelCfg.Model, agent.Fallbacks)
if len(nextCandidates) == 0 {
return "", fmt.Errorf("model %q did not resolve to any provider candidates", value)
}
+4 -16
View File
@@ -65,11 +65,7 @@ func (al *AgentLoop) ensureMCPInitialized(ctx context.Context) error {
}
if al.cfg.Tools.MCP.Servers == nil || len(al.cfg.Tools.MCP.Servers) == 0 {
logger.WarnCF(
"agent",
"MCP is enabled but no servers are configured, skipping MCP initialization",
nil,
)
logger.WarnCF("agent", "MCP is enabled but no servers are configured, skipping MCP initialization", nil)
return nil
}
@@ -80,11 +76,7 @@ func (al *AgentLoop) ensureMCPInitialized(ctx context.Context) error {
}
}
if !findValidServer {
logger.WarnCF(
"agent",
"MCP is enabled but no valid servers are configured, skipping MCP initialization",
nil,
)
logger.WarnCF("agent", "MCP is enabled but no valid servers are configured, skipping MCP initialization", nil)
return nil
}
@@ -201,14 +193,10 @@ func (al *AgentLoop) ensureMCPInitialized(ctx context.Context) error {
}
if useRegex {
agent.Tools.Register(
tools.NewRegexSearchTool(agent.Tools, ttl, maxSearchResults),
)
agent.Tools.Register(tools.NewRegexSearchTool(agent.Tools, ttl, maxSearchResults))
}
if useBM25 {
agent.Tools.Register(
tools.NewBM25SearchTool(agent.Tools, ttl, maxSearchResults),
)
agent.Tools.Register(tools.NewBM25SearchTool(agent.Tools, ttl, maxSearchResults))
}
}
}
+1 -5
View File
@@ -25,11 +25,7 @@ import (
// Non-image files (documents, audio, video) have their local path injected
// into Content so the agent can access them via file tools like read_file.
// Returns a new slice; original messages are not mutated.
func resolveMediaRefs(
messages []providers.Message,
store media.MediaStore,
maxSize int,
) []providers.Message {
func resolveMediaRefs(messages []providers.Message, store media.MediaStore, maxSize int) []providers.Message {
if store == nil {
return messages
}
+24 -97
View File
@@ -591,9 +591,7 @@ func TestProcessMessage_MediaToolHandledSkipsFollowUpLLMAndFinalText(t *testing.
store := media.NewFileMediaStore()
al.SetMediaStore(store)
telegramChannel := &fakeMediaChannel{fakeChannel: fakeChannel{id: "rid-telegram"}}
al.SetChannelManager(
newStartedTestChannelManager(t, msgBus, store, "telegram", telegramChannel),
)
al.SetChannelManager(newStartedTestChannelManager(t, msgBus, store, "telegram", telegramChannel))
imagePath := filepath.Join(tmpDir, "screen.png")
if err := os.WriteFile(imagePath, []byte("fake screenshot"), 0o644); err != nil {
@@ -615,10 +613,7 @@ func TestProcessMessage_MediaToolHandledSkipsFollowUpLLMAndFinalText(t *testing.
t.Fatalf("processMessage() error = %v", err)
}
if response != "" {
t.Fatalf(
"expected no final response when media tool already handled delivery, got %q",
response,
)
t.Fatalf("expected no final response when media tool already handled delivery, got %q", response)
}
if provider.calls != 1 {
t.Fatalf("expected exactly 1 LLM call, got %d", provider.calls)
@@ -631,20 +626,13 @@ func TestProcessMessage_MediaToolHandledSkipsFollowUpLLMAndFinalText(t *testing.
}
if len(telegramChannel.sentMedia) != 1 {
t.Fatalf(
"expected exactly 1 synchronously sent media message, got %d",
len(telegramChannel.sentMedia),
)
t.Fatalf("expected exactly 1 synchronously sent media message, got %d", len(telegramChannel.sentMedia))
}
if telegramChannel.sentMedia[0].Channel != "telegram" ||
telegramChannel.sentMedia[0].ChatID != "chat1" {
if telegramChannel.sentMedia[0].Channel != "telegram" || telegramChannel.sentMedia[0].ChatID != "chat1" {
t.Fatalf("unexpected sent media target: %+v", telegramChannel.sentMedia[0])
}
if len(telegramChannel.sentMedia[0].Parts) != 1 {
t.Fatalf(
"expected exactly 1 sent media part, got %d",
len(telegramChannel.sentMedia[0].Parts),
)
t.Fatalf("expected exactly 1 sent media part, got %d", len(telegramChannel.sentMedia[0].Parts))
}
select {
@@ -672,8 +660,7 @@ func TestProcessMessage_MediaToolHandledSkipsFollowUpLLMAndFinalText(t *testing.
t.Fatal("expected session history to be saved")
}
last := history[len(history)-1]
if last.Role != "assistant" ||
last.Content != "Requested output delivered via tool attachment." {
if last.Role != "assistant" || last.Content != "Requested output delivered via tool attachment." {
t.Fatalf("expected handled assistant summary in history, got %+v", last)
}
}
@@ -698,9 +685,7 @@ func TestProcessMessage_HandledToolProcessesQueuedSteeringBeforeReturning(t *tes
store := media.NewFileMediaStore()
al.SetMediaStore(store)
telegramChannel := &fakeMediaChannel{fakeChannel: fakeChannel{id: "rid-telegram"}}
al.SetChannelManager(
newStartedTestChannelManager(t, msgBus, store, "telegram", telegramChannel),
)
al.SetChannelManager(newStartedTestChannelManager(t, msgBus, store, "telegram", telegramChannel))
imagePath := filepath.Join(tmpDir, "screen-steering.png")
if err := os.WriteFile(imagePath, []byte("fake screenshot"), 0o644); err != nil {
@@ -729,10 +714,7 @@ func TestProcessMessage_HandledToolProcessesQueuedSteeringBeforeReturning(t *tes
t.Fatalf("expected 2 LLM calls after queued steering, got %d", provider.calls)
}
if len(telegramChannel.sentMedia) != 1 {
t.Fatalf(
"expected exactly 1 synchronously sent media message, got %d",
len(telegramChannel.sentMedia),
)
t.Fatalf("expected exactly 1 synchronously sent media message, got %d", len(telegramChannel.sentMedia))
}
}
@@ -751,9 +733,7 @@ func TestProcessMessage_MediaArtifactCanBeForwardedBySendFile(t *testing.T) {
store := media.NewFileMediaStore()
al.SetMediaStore(store)
telegramChannel := &fakeMediaChannel{fakeChannel: fakeChannel{id: "rid-telegram"}}
al.SetChannelManager(
newStartedTestChannelManager(t, msgBus, store, "telegram", telegramChannel),
)
al.SetChannelManager(newStartedTestChannelManager(t, msgBus, store, "telegram", telegramChannel))
mediaDir := media.TempDir()
if err := os.MkdirAll(mediaDir, 0o700); err != nil {
@@ -786,20 +766,13 @@ func TestProcessMessage_MediaArtifactCanBeForwardedBySendFile(t *testing.T) {
}
if len(telegramChannel.sentMedia) != 1 {
t.Fatalf(
"expected exactly 1 synchronously sent media message, got %d",
len(telegramChannel.sentMedia),
)
t.Fatalf("expected exactly 1 synchronously sent media message, got %d", len(telegramChannel.sentMedia))
}
if telegramChannel.sentMedia[0].Channel != "telegram" ||
telegramChannel.sentMedia[0].ChatID != "chat1" {
if telegramChannel.sentMedia[0].Channel != "telegram" || telegramChannel.sentMedia[0].ChatID != "chat1" {
t.Fatalf("unexpected sent media target: %+v", telegramChannel.sentMedia[0])
}
if len(telegramChannel.sentMedia[0].Parts) != 1 {
t.Fatalf(
"expected exactly 1 sent media part, got %d",
len(telegramChannel.sentMedia[0].Parts),
)
t.Fatalf("expected exactly 1 sent media part, got %d", len(telegramChannel.sentMedia[0].Parts))
}
select {
@@ -1210,10 +1183,7 @@ func (m *handledMediaWithSteeringTool) Parameters() map[string]any {
}
}
func (m *handledMediaWithSteeringTool) Execute(
ctx context.Context,
args map[string]any,
) *tools.ToolResult {
func (m *handledMediaWithSteeringTool) Execute(ctx context.Context, args map[string]any) *tools.ToolResult {
if err := m.loop.Steer(providers.Message{Role: "user", Content: "what about this instead?"}); err != nil {
return tools.ErrorResult(err.Error()).WithError(err)
}
@@ -1366,11 +1336,7 @@ func newStrictChatCompletionTestServer(
}))
}
func (h testHelper) executeAndGetResponse(
tb testing.TB,
ctx context.Context,
msg bus.InboundMessage,
) string {
func (h testHelper) executeAndGetResponse(tb testing.TB, ctx context.Context, msg bus.InboundMessage) string {
// Use a short timeout to avoid hanging
timeoutCtx, cancel := context.WithTimeout(ctx, responseTimeout)
defer cancel()
@@ -1501,10 +1467,7 @@ func TestProcessMessage_CommandOutcomes(t *testing.T) {
t.Fatalf("unexpected /foo reply: %q", fooResp)
}
if provider.calls != 1 {
t.Fatalf(
"LLM should be called exactly once after /foo passthrough, calls=%d",
provider.calls,
)
t.Fatalf("LLM should be called exactly once after /foo passthrough, calls=%d", provider.calls)
}
newResp := helper.executeAndGetResponse(t, context.Background(), bus.InboundMessage{
@@ -1654,10 +1617,7 @@ func TestProcessMessage_SwitchModelRejectsUnknownAlias(t *testing.T) {
}
if provider.calls != 0 {
t.Fatalf(
"LLM should not be called for rejected /switch and /show, calls=%d",
provider.calls,
)
t.Fatalf("LLM should not be called for rejected /switch and /show, calls=%d", provider.calls)
}
}
@@ -1675,13 +1635,7 @@ func TestProcessMessage_SwitchModelRoutesSubsequentRequestsToSelectedProvider(t
remoteCalls := 0
remoteModel := ""
remoteServer := newChatCompletionTestServer(
t,
"remote",
"remote reply",
&remoteCalls,
&remoteModel,
)
remoteServer := newChatCompletionTestServer(t, "remote", "remote reply", &remoteCalls, &remoteModel)
defer remoteServer.Close()
cfg := &config.Config{
@@ -2004,9 +1958,7 @@ func TestAgentLoop_ContextExhaustionRetry(t *testing.T) {
msgBus := bus.NewMessageBus()
// Create a provider that fails once with a context error
contextErr := fmt.Errorf(
"InvalidParameter: Total tokens of image and text exceed max message tokens",
)
contextErr := fmt.Errorf("InvalidParameter: Total tokens of image and text exceed max message tokens")
provider := &failFirstMockProvider{
failures: 1,
failError: contextErr,
@@ -2087,13 +2039,7 @@ func TestAgentLoop_EmptyModelResponseUsesAccurateFallback(t *testing.T) {
provider := &simpleMockProvider{response: ""}
al := NewAgentLoop(cfg, msgBus, provider)
response, err := al.ProcessDirectWithChannel(
context.Background(),
"hello",
"empty-response",
"test",
"chat1",
)
response, err := al.ProcessDirectWithChannel(context.Background(), "hello", "empty-response", "test", "chat1")
if err != nil {
t.Fatalf("ProcessDirectWithChannel failed: %v", err)
}
@@ -2125,13 +2071,7 @@ func TestAgentLoop_ToolLimitUsesDedicatedFallback(t *testing.T) {
al := NewAgentLoop(cfg, msgBus, provider)
al.RegisterTool(&toolLimitTestTool{})
response, err := al.ProcessDirectWithChannel(
context.Background(),
"hello",
"tool-limit",
"test",
"chat1",
)
response, err := al.ProcessDirectWithChannel(context.Background(), "hello", "tool-limit", "test", "chat1")
if err != nil {
t.Fatalf("ProcessDirectWithChannel failed: %v", err)
}
@@ -2449,9 +2389,7 @@ func TestHandleReasoning(t *testing.T) {
break
}
if msg.Content == "should timeout" {
t.Fatal(
"expected reasoning message to be dropped when bus is full, but it was published",
)
t.Fatal("expected reasoning message to be dropped when bus is full, but it was published")
}
}
}
@@ -2545,12 +2483,7 @@ func TestProcessHeartbeat_DoesNotPublishToolFeedback(t *testing.T) {
provider := &toolFeedbackProvider{filePath: heartbeatFile}
al := NewAgentLoop(cfg, msgBus, provider)
response, err := al.ProcessHeartbeat(
context.Background(),
"check heartbeat tasks",
"telegram",
"chat-1",
)
response, err := al.ProcessHeartbeat(context.Background(), "check heartbeat tasks", "telegram", "chat-1")
if err != nil {
t.Fatalf("ProcessHeartbeat() error = %v", err)
}
@@ -3035,14 +2968,8 @@ func TestProcessMessage_ContextOverflowRecovery(t *testing.T) {
agent := al.GetRegistry().GetDefaultAgent()
for i := 0; i < 5; i++ {
agent.Sessions.AddFullMessage(
sessionKey,
providers.Message{Role: "user", Content: "heavy message"},
)
agent.Sessions.AddFullMessage(
sessionKey,
providers.Message{Role: "assistant", Content: "response"},
)
agent.Sessions.AddFullMessage(sessionKey, providers.Message{Role: "user", Content: "heavy message"})
agent.Sessions.AddFullMessage(sessionKey, providers.Message{Role: "assistant", Content: "response"})
}
response, err := al.processMessage(context.Background(), bus.InboundMessage{
+2 -6
View File
@@ -26,8 +26,7 @@ func buildModelListResolver(cfg *config.Config) func(raw string) (string, bool)
return "", false
}
if mc, err := cfg.GetModelConfig(raw); err == nil && mc != nil &&
strings.TrimSpace(mc.Model) != "" {
if mc, err := cfg.GetModelConfig(raw); err == nil && mc != nil && strings.TrimSpace(mc.Model) != "" {
return ensureProtocol(mc.Model), true
}
@@ -79,10 +78,7 @@ func resolvedCandidateProvider(candidates []providers.FallbackCandidate, fallbac
return fallback
}
func resolvedModelConfig(
cfg *config.Config,
modelName, workspace string,
) (*config.ModelConfig, error) {
func resolvedModelConfig(cfg *config.Config, modelName, workspace string) (*config.ModelConfig, error) {
if cfg == nil {
return nil, fmt.Errorf("config is nil")
}
+1 -4
View File
@@ -325,10 +325,7 @@ func (al *AgentLoop) agentForSession(sessionKey string) *AgentInstance {
// user has since enqueued steering messages.
//
// If no steering messages are pending, it returns an empty string.
func (al *AgentLoop) Continue(
ctx context.Context,
sessionKey, channel, chatID string,
) (string, error) {
func (al *AgentLoop) Continue(ctx context.Context, sessionKey, channel, chatID string) (string, error) {
if active := al.GetActiveTurn(); active != nil {
return "", fmt.Errorf("turn %s is still active", active.TurnID)
}
+6 -24
View File
@@ -896,10 +896,7 @@ func TestAgentLoop_Run_AutoContinuesLateSteeringMessage(t *testing.T) {
defer cancelNoExtra()
select {
case out2 := <-msgBus.OutboundChan():
t.Fatalf(
"expected stale direct response to be suppressed, got extra outbound %q",
out2.Content,
)
t.Fatalf("expected stale direct response to be suppressed, got extra outbound %q", out2.Content)
case <-noExtraCtx.Done():
}
@@ -1047,11 +1044,7 @@ func TestAgentLoop_Continue_PreservesSteeringMedia(t *testing.T) {
if err = os.WriteFile(pngPath, pngHeader, 0o644); err != nil {
t.Fatalf("WriteFile failed: %v", err)
}
ref, err := store.Store(
pngPath,
media.MediaMeta{Filename: "steer.png", ContentType: "image/png"},
"test",
)
ref, err := store.Store(pngPath, media.MediaMeta{Filename: "steer.png", ContentType: "image/png"}, "test")
if err != nil {
t.Fatalf("Store failed: %v", err)
}
@@ -1243,10 +1236,7 @@ func TestAgentLoop_InterruptGraceful_UsesTerminalNoToolCall(t *testing.T) {
t.Fatalf("expected 2 provider calls, got %d", calls)
}
if terminalToolsCount != 0 {
t.Fatalf(
"expected graceful terminal call to disable tools, got %d tool defs",
terminalToolsCount,
)
t.Fatalf("expected graceful terminal call to disable tools, got %d tool defs", terminalToolsCount)
}
foundHint := false
@@ -1257,8 +1247,7 @@ func TestAgentLoop_InterruptGraceful_UsesTerminalNoToolCall(t *testing.T) {
if msg.Role == "user" && msg.Content == expectedHint {
foundHint = true
}
if msg.Role == "tool" && msg.ToolCallID == "call_2" &&
msg.Content == "Skipped due to graceful interrupt." {
if msg.Role == "tool" && msg.ToolCallID == "call_2" && msg.Content == "Skipped due to graceful interrupt." {
foundSkipped = true
}
}
@@ -1550,8 +1539,7 @@ func TestAgentLoop_Steering_SkippedToolsHaveErrorResults(t *testing.T) {
foundSkipped := false
for _, m := range msgs {
if m.Role == "tool" && m.ToolCallID == "call_2" &&
m.Content == "Skipped due to queued user message." {
if m.Role == "tool" && m.ToolCallID == "call_2" && m.Content == "Skipped due to queued user message." {
foundSkipped = true
break
}
@@ -1559,13 +1547,7 @@ func TestAgentLoop_Steering_SkippedToolsHaveErrorResults(t *testing.T) {
if !foundSkipped {
// Log what we actually got
for i, m := range msgs {
t.Logf(
"msg[%d]: role=%s toolCallID=%s content=%s",
i,
m.Role,
m.ToolCallID,
truncate(m.Content, 80),
)
t.Logf("msg[%d]: role=%s toolCallID=%s content=%s", i, m.Role, m.ToolCallID, truncate(m.Content, 80))
}
t.Fatal("expected skipped tool result for call_2")
}
+5 -20
View File
@@ -505,12 +505,7 @@ func spawnSubTurn(
// Event emissions:
// - SubTurnResultDeliveredEvent: successful delivery to channel
// - SubTurnOrphanResultEvent: delivery failed (parent finished or channel full)
func deliverSubTurnResult(
al *AgentLoop,
parentTS *turnState,
childID string,
result *tools.ToolResult,
) {
func deliverSubTurnResult(al *AgentLoop, parentTS *turnState, childID string, result *tools.ToolResult) {
// Let GC clean up the pendingResults channel; parent Finish will no longer close it.
// We use defer/recover to catch any unlikely channel panics if it were ever closed.
defer func() {
@@ -521,14 +516,9 @@ func deliverSubTurnResult(
"recover": r,
})
if result != nil && al != nil {
al.emitEvent(
EventKindSubTurnOrphan,
al.emitEvent(EventKindSubTurnOrphan,
parentTS.eventMeta("deliverSubTurnResult", "subturn.orphan"),
SubTurnOrphanPayload{
ParentTurnID: parentTS.turnID,
ChildTurnID: childID,
Reason: "panic",
},
SubTurnOrphanPayload{ParentTurnID: parentTS.turnID, ChildTurnID: childID, Reason: "panic"},
)
}
}
@@ -541,14 +531,9 @@ func deliverSubTurnResult(
// If parent turn has already finished, treat this as an orphan result
if isFinished || resultChan == nil {
if result != nil && al != nil {
al.emitEvent(
EventKindSubTurnOrphan,
al.emitEvent(EventKindSubTurnOrphan,
parentTS.eventMeta("deliverSubTurnResult", "subturn.orphan"),
SubTurnOrphanPayload{
ParentTurnID: parentTS.turnID,
ChildTurnID: childID,
Reason: "parent_finished",
},
SubTurnOrphanPayload{ParentTurnID: parentTS.turnID, ChildTurnID: childID, Reason: "parent_finished"},
)
}
return
+2 -8
View File
@@ -571,8 +571,7 @@ func TestHardAbortSessionRollback(t *testing.T) {
}
// Verify the content matches the initial state
if finalHistory[0].Content != "initial message 1" ||
finalHistory[1].Content != "initial response 1" {
if finalHistory[0].Content != "initial message 1" || finalHistory[1].Content != "initial response 1" {
t.Error("history content does not match initial state after rollback")
}
}
@@ -1291,12 +1290,7 @@ func TestDeliverSubTurnResult_RaceWithFinish(t *testing.T) {
finalOrphan := orphanCount
mu.Unlock()
t.Logf(
"Delivered: %d, Orphan: %d, Total: %d",
finalDelivered,
finalOrphan,
finalDelivered+finalOrphan,
)
t.Logf("Delivered: %d, Orphan: %d, Total: %d", finalDelivered, finalOrphan, finalDelivered+finalOrphan)
// With the new drainPendingResults behavior, the total events may be >= numResults
// because Finish() drains remaining results from the channel and emits them as orphans.
+1 -5
View File
@@ -65,11 +65,7 @@ func DefaultConfig() *Config {
Enabled: true,
Text: FlexibleStringSlice{"Thinking... 💭"},
},
Streaming: StreamingConfig{
Enabled: true,
ThrottleSeconds: 3,
MinGrowthChars: 200,
},
Streaming: StreamingConfig{Enabled: true, ThrottleSeconds: 3, MinGrowthChars: 200},
UseMarkdownV2: false,
},
Feishu: FeishuConfig{
+1 -2
View File
@@ -335,8 +335,7 @@ func v0ConvertProvidersToModelList(cfg *configV0) []modelConfigV0 {
providerNames: []string{"github_copilot", "copilot"},
protocol: "github-copilot",
buildConfig: func(p providersConfigV0) (modelConfigV0, bool) {
if p.GitHubCopilot.APIKey == "" && p.GitHubCopilot.APIBase == "" &&
p.GitHubCopilot.ConnectMode == "" {
if p.GitHubCopilot.APIKey == "" && p.GitHubCopilot.APIBase == "" && p.GitHubCopilot.ConnectMode == "" {
return modelConfigV0{}, false
}
return modelConfigV0{
+7 -34
View File
@@ -72,11 +72,7 @@ func TestMigration_Integration_LegacyConfigWithoutWorkspace(t *testing.T) {
// CRITICAL: Verify that user's settings are preserved
// This was the bug - these settings were lost when Workspace was empty
if cfg.Agents.Defaults.Provider != "openai" {
t.Errorf(
"Provider = %q, want %q (user's setting should be preserved)",
cfg.Agents.Defaults.Provider,
"openai",
)
t.Errorf("Provider = %q, want %q (user's setting should be preserved)", cfg.Agents.Defaults.Provider, "openai")
}
// Old "model" field is migrated to "model_name" field
if cfg.Agents.Defaults.ModelName != "gpt-4o" {
@@ -303,11 +299,7 @@ func TestMigration_Integration_PreservesAllAgentsFields(t *testing.T) {
t.Errorf("Agent.ID = %q, want %q", cfg.Agents.List[0].ID, "special-agent")
}
if cfg.Agents.List[0].Workspace != "/special/workspace" {
t.Errorf(
"Agent.Workspace = %q, want %q",
cfg.Agents.List[0].Workspace,
"/special/workspace",
)
t.Errorf("Agent.Workspace = %q, want %q", cfg.Agents.List[0].Workspace, "/special/workspace")
}
// Workspace should have default since it was empty in legacy config
@@ -370,10 +362,7 @@ func TestMigration_Integration_ChannelsConfigMigrated(t *testing.T) {
// OneBot: group_trigger_prefix should be migrated to group_trigger.prefixes
if len(cfg.Channels.OneBot.GroupTrigger.Prefixes) != 2 {
t.Errorf(
"len(OneBot.GroupTrigger.Prefixes) = %d, want 2",
len(cfg.Channels.OneBot.GroupTrigger.Prefixes),
)
t.Errorf("len(OneBot.GroupTrigger.Prefixes) = %d, want 2", len(cfg.Channels.OneBot.GroupTrigger.Prefixes))
} else {
if cfg.Channels.OneBot.GroupTrigger.Prefixes[0] != "/" {
t.Errorf("Prefixes[0] = %q, want %q", cfg.Channels.OneBot.GroupTrigger.Prefixes[0], "/")
@@ -454,25 +443,13 @@ func TestMigration_Integration_RoundTrip_SerializeAndLoad(t *testing.T) {
// Verify configs are identical
if cfg2.Agents.Defaults.Provider != cfg1.Agents.Defaults.Provider {
t.Errorf(
"Provider changed from %q to %q",
cfg1.Agents.Defaults.Provider,
cfg2.Agents.Defaults.Provider,
)
t.Errorf("Provider changed from %q to %q", cfg1.Agents.Defaults.Provider, cfg2.Agents.Defaults.Provider)
}
if cfg2.Agents.Defaults.ModelName != cfg1.Agents.Defaults.ModelName {
t.Errorf(
"ModelName changed from %q to %q",
cfg1.Agents.Defaults.ModelName,
cfg2.Agents.Defaults.ModelName,
)
t.Errorf("ModelName changed from %q to %q", cfg1.Agents.Defaults.ModelName, cfg2.Agents.Defaults.ModelName)
}
if cfg2.Agents.Defaults.MaxTokens != cfg1.Agents.Defaults.MaxTokens {
t.Errorf(
"MaxTokens changed from %d to %d",
cfg1.Agents.Defaults.MaxTokens,
cfg2.Agents.Defaults.MaxTokens,
)
t.Errorf("MaxTokens changed from %d to %d", cfg1.Agents.Defaults.MaxTokens, cfg2.Agents.Defaults.MaxTokens)
}
}
@@ -580,11 +557,7 @@ func TestMigration_Integration_ModelNameField(t *testing.T) {
// GetModelName() should return model_name, not model (deprecated)
if cfg.Agents.Defaults.GetModelName() != "deepseek-reasoner" {
t.Errorf(
"GetModelName() = %q, want %q",
cfg.Agents.Defaults.GetModelName(),
"deepseek-reasoner",
)
t.Errorf("GetModelName() = %q, want %q", cfg.Agents.Defaults.GetModelName(), "deepseek-reasoner")
}
if len(cfg.Agents.Defaults.ModelFallbacks) != 1 {
+13 -44
View File
@@ -91,11 +91,9 @@ func TestConvertProvidersToModelList_LiteLLM(t *testing.T) {
func TestConvertProvidersToModelList_Multiple(t *testing.T) {
cfg := &configV0{
Providers: providersConfigV0{
OpenAI: openAIProviderConfigV0{
providerConfigV0: providerConfigV0{APIKey: "openai-key"},
},
Groq: providerConfigV0{APIKey: "groq-key"},
Zhipu: providerConfigV0{APIKey: "zhipu-key"},
OpenAI: openAIProviderConfigV0{providerConfigV0: providerConfigV0{APIKey: "openai-key"}},
Groq: providerConfigV0{APIKey: "groq-key"},
Zhipu: providerConfigV0{APIKey: "zhipu-key"},
},
}
@@ -144,13 +142,8 @@ func TestConvertProvidersToModelList_AllProviders(t *testing.T) {
// Other providers have no configuration, so they won't be converted.
cfg := &configV0{
Providers: providersConfigV0{
OpenAI: openAIProviderConfigV0{
providerConfigV0: providerConfigV0{APIKey: "key1"},
},
LiteLLM: providerConfigV0{
APIKey: "key-litellm",
APIBase: "http://localhost:4000/v1",
},
OpenAI: openAIProviderConfigV0{providerConfigV0: providerConfigV0{APIKey: "key1"}},
LiteLLM: providerConfigV0{APIKey: "key-litellm", APIBase: "http://localhost:4000/v1"},
Anthropic: providerConfigV0{APIKey: "key2"},
OpenRouter: providerConfigV0{APIKey: "key3"},
Groq: providerConfigV0{APIKey: "key4"},
@@ -268,11 +261,7 @@ func TestConvertProvidersToModelList_PreservesUserModel_DeepSeek(t *testing.T) {
// Should use user's model, not default
if result[0].Model != "deepseek/deepseek-reasoner" {
t.Errorf(
"Model = %q, want %q (user's configured model)",
result[0].Model,
"deepseek/deepseek-reasoner",
)
t.Errorf("Model = %q, want %q (user's configured model)", result[0].Model, "deepseek/deepseek-reasoner")
}
}
@@ -382,9 +371,7 @@ func TestConvertProvidersToModelList_MultipleProviders_PreservesUserModel(t *tes
},
},
Providers: providersConfigV0{
OpenAI: openAIProviderConfigV0{
providerConfigV0: providerConfigV0{APIKey: "sk-openai"},
},
OpenAI: openAIProviderConfigV0{providerConfigV0: providerConfigV0{APIKey: "sk-openai"}},
DeepSeek: providerConfigV0{APIKey: "sk-deepseek"},
},
}
@@ -404,11 +391,7 @@ func TestConvertProvidersToModelList_MultipleProviders_PreservesUserModel(t *tes
}
case "deepseek":
if mc.Model != "deepseek/deepseek-reasoner" {
t.Errorf(
"DeepSeek Model = %q, want %q (user's)",
mc.Model,
"deepseek/deepseek-reasoner",
)
t.Errorf("DeepSeek Model = %q, want %q (user's)", mc.Model, "deepseek/deepseek-reasoner")
}
}
}
@@ -506,11 +489,7 @@ func TestConvertProvidersToModelList_NoProviderField_SingleProvider(t *testing.T
// ModelName should be the user's model value for backward compatibility
if result[0].ModelName != "glm-4.7" {
t.Errorf(
"ModelName = %q, want %q (user's model for backward compatibility)",
result[0].ModelName,
"glm-4.7",
)
t.Errorf("ModelName = %q, want %q (user's model for backward compatibility)", result[0].ModelName, "glm-4.7")
}
// Model should use the user's model with protocol prefix
@@ -531,10 +510,8 @@ func TestConvertProvidersToModelList_NoProviderField_MultipleProviders(t *testin
},
},
Providers: providersConfigV0{
OpenAI: openAIProviderConfigV0{
providerConfigV0: providerConfigV0{APIKey: "openai-key"},
},
Zhipu: providerConfigV0{APIKey: "zhipu-key"},
OpenAI: openAIProviderConfigV0{providerConfigV0: providerConfigV0{APIKey: "openai-key"}},
Zhipu: providerConfigV0{APIKey: "zhipu-key"},
},
}
@@ -594,11 +571,7 @@ func TestBuildModelWithProtocol_NoPrefix(t *testing.T) {
func TestBuildModelWithProtocol_AlreadyHasPrefix(t *testing.T) {
result := buildModelWithProtocol("openrouter", "openrouter/auto")
if result != "openrouter/auto" {
t.Errorf(
"buildModelWithProtocol(openrouter, openrouter/auto) = %q, want %q",
result,
"openrouter/auto",
)
t.Errorf("buildModelWithProtocol(openrouter, openrouter/auto) = %q, want %q", result, "openrouter/auto")
}
}
@@ -640,10 +613,6 @@ func TestConvertProvidersToModelList_LegacyModelWithProtocolPrefix(t *testing.T)
// Model should NOT have duplicated prefix
if result[0].Model != "openrouter/auto" {
t.Errorf(
"Model = %q, want %q (should not duplicate prefix)",
result[0].Model,
"openrouter/auto",
)
t.Errorf("Model = %q, want %q (should not duplicate prefix)", result[0].Model, "openrouter/auto")
}
}
+4 -20
View File
@@ -17,11 +17,7 @@ func TestGetModelConfig_Found(t *testing.T) {
Version: CurrentVersion,
ModelList: []*ModelConfig{
{ModelName: "test-model", Model: "openai/gpt-4o", APIKeys: SimpleSecureStrings("key1")},
{
ModelName: "other-model",
Model: "anthropic/claude",
APIKeys: SimpleSecureStrings("key2"),
},
{ModelName: "other-model", Model: "anthropic/claude", APIKeys: SimpleSecureStrings("key2")},
},
}
@@ -118,16 +114,8 @@ func TestGetModelConfig_RoundRobinStartsFromFirstMatch(t *testing.T) {
func TestGetModelConfig_Concurrent(t *testing.T) {
cfg := &Config{
ModelList: []*ModelConfig{
{
ModelName: "concurrent-model",
Model: "openai/gpt-4o-1",
APIKeys: SimpleSecureStrings("key1"),
},
{
ModelName: "concurrent-model",
Model: "openai/gpt-4o-2",
APIKeys: SimpleSecureStrings("key2"),
},
{ModelName: "concurrent-model", Model: "openai/gpt-4o-1", APIKeys: SimpleSecureStrings("key1")},
{ModelName: "concurrent-model", Model: "openai/gpt-4o-2", APIKeys: SimpleSecureStrings("key2")},
},
}
@@ -302,11 +290,7 @@ func TestConfig_ValidateModelList(t *testing.T) {
}
if err != nil && tt.errMsg != "" {
if !strings.Contains(err.Error(), tt.errMsg) {
t.Errorf(
"ValidateModelList() error = %v, want error containing %q",
err,
tt.errMsg,
)
t.Errorf("ValidateModelList() error = %v, want error containing %q", err, tt.errMsg)
}
}
})
+2 -8
View File
@@ -117,10 +117,7 @@ func TestExpandMultiKeyModels_WithExistingFallbacks(t *testing.T) {
ModelName: "gpt-4",
Model: "openai/gpt-4o",
}
modelCfg.APIKeys = SimpleSecureStrings(
"key0",
"key1",
) // Use internal field for multi-key testing
modelCfg.APIKeys = SimpleSecureStrings("key0", "key1") // Use internal field for multi-key testing
modelCfg.Fallbacks = []string{"claude-3"}
models := []*ModelConfig{modelCfg}
@@ -199,10 +196,7 @@ func TestExpandMultiKeyModels_PreservesOtherFields(t *testing.T) {
RequestTimeout: 30,
ThinkingLevel: "high",
}
modelCfg.APIKeys = SimpleSecureStrings(
"key0",
"key1",
) // Use internal field for multi-key testing
modelCfg.APIKeys = SimpleSecureStrings("key0", "key1") // Use internal field for multi-key testing
models := []*ModelConfig{modelCfg}
result := expandMultiKeyModels(models)
+2 -4
View File
@@ -304,13 +304,11 @@ func (s *SecureString) UnmarshalJSON(value []byte) error {
func (s SecureString) MarshalYAML() (any, error) {
// Preserve raw value if it is already a reference (enc:// or file://)
if strings.HasPrefix(s.raw, credential.EncScheme) ||
strings.HasPrefix(s.raw, credential.FileScheme) {
if strings.HasPrefix(s.raw, credential.EncScheme) || strings.HasPrefix(s.raw, credential.FileScheme) {
return s.raw, nil
}
// If resolved is a reference format (e.g. set via Set), copy back to raw
if strings.HasPrefix(s.resolved, credential.EncScheme) ||
strings.HasPrefix(s.resolved, credential.FileScheme) {
if strings.HasPrefix(s.resolved, credential.EncScheme) || strings.HasPrefix(s.resolved, credential.FileScheme) {
s.raw = s.resolved
return s.raw, nil
}
+5 -24
View File
@@ -35,10 +35,7 @@ func TestJSONUnmarshalPrivateFields(t *testing.T) {
t.Errorf("PublicField = %q, want 'pub'", s.PublicField)
}
if s.privateField != "" {
t.Errorf(
"privateField = %q, want empty because unexported fields are ignored",
s.privateField,
)
t.Errorf("privateField = %q, want empty because unexported fields are ignored", s.privateField)
}
}
@@ -355,21 +352,13 @@ skills:
// Verify Channel tokens via Key() methods
// Telegram
assert.Equal(
t,
"123456789:ABCdefGHIjklMNOpqrsTUVwxyz",
cfg.Channels.Telegram.Token.String(),
)
assert.Equal(t, "123456789:ABCdefGHIjklMNOpqrsTUVwxyz", cfg.Channels.Telegram.Token.String())
t.Logf("Telegram Token(): %s", cfg.Channels.Telegram.Token.String())
// Feishu
assert.Equal(t, "feishu_test_app_secret", cfg.Channels.Feishu.AppSecret.String())
assert.Equal(t, "feishu_test_encrypt_key", cfg.Channels.Feishu.EncryptKey.String())
assert.Equal(
t,
"feishu_test_verification_token",
cfg.Channels.Feishu.VerificationToken.String(),
)
assert.Equal(t, "feishu_test_verification_token", cfg.Channels.Feishu.VerificationToken.String())
t.Logf("Feishu AppSecret(): %s", cfg.Channels.Feishu.AppSecret.String())
t.Logf("Feishu EncryptKey(): %s", cfg.Channels.Feishu.EncryptKey.String())
t.Logf("Feishu VerificationToken(): %s", cfg.Channels.Feishu.VerificationToken.String())
@@ -394,11 +383,7 @@ skills:
// LINE
assert.Equal(t, "line_test_channel_secret", cfg.Channels.LINE.ChannelSecret.String())
assert.Equal(
t,
"line_test_channel_access_token",
cfg.Channels.LINE.ChannelAccessToken.String(),
)
assert.Equal(t, "line_test_channel_access_token", cfg.Channels.LINE.ChannelAccessToken.String())
t.Logf("LINE ChannelSecret(): %s", cfg.Channels.LINE.ChannelSecret.String())
t.Logf("LINE ChannelAccessToken(): %s", cfg.Channels.LINE.ChannelAccessToken.String())
@@ -446,11 +431,7 @@ skills:
assert.Equal(t, "ghp-github-from-file-abc123", cfg.Tools.Skills.Github.Token.String())
t.Logf("Github Token(): %s", cfg.Tools.Skills.Github.Token.String())
assert.Equal(
t,
"clawhub-auth-token-from-file",
cfg.Tools.Skills.Registries.ClawHub.AuthToken.String(),
)
assert.Equal(t, "clawhub-auth-token-from-file", cfg.Tools.Skills.Registries.ClawHub.AuthToken.String())
t.Logf("ClawHub AuthToken(): %s", cfg.Tools.Skills.Registries.ClawHub.AuthToken.String())
t.Log("All security keys are successfully accessible via their respective Key() methods")
+5 -17
View File
@@ -15,10 +15,7 @@ import (
// JobExecutor is the interface for executing cron jobs through the agent
type JobExecutor interface {
ProcessDirectWithChannel(
ctx context.Context,
content, sessionKey, channel, chatID string,
) (string, error)
ProcessDirectWithChannel(ctx context.Context, content, sessionKey, channel, chatID string) (string, error)
// PublishResponseIfNeeded sends response to the outbound bus only when the
// agent did not already deliver content through the message tool in this round.
PublishResponseIfNeeded(ctx context.Context, channel, chatID, response string)
@@ -37,13 +34,8 @@ type CronTool struct {
// NewCronTool creates a new CronTool
// execTimeout: 0 means no timeout, >0 sets the timeout duration
func NewCronTool(
cronService *cron.CronService,
executor JobExecutor,
msgBus *bus.MessageBus,
workspace string,
restrict bool,
execTimeout time.Duration,
config *config.Config,
cronService *cron.CronService, executor JobExecutor, msgBus *bus.MessageBus, workspace string, restrict bool,
execTimeout time.Duration, config *config.Config,
) (*CronTool, error) {
allowCommand := true
execEnabled := true
@@ -164,9 +156,7 @@ func (t *CronTool) addJob(ctx context.Context, args map[string]any) *ToolResult
chatID := ToolChatID(ctx)
if channel == "" || chatID == "" {
return ErrorResult(
"no session context (channel/chat_id not set). Use this tool in an active conversation.",
)
return ErrorResult("no session context (channel/chat_id not set). Use this tool in an active conversation.")
}
message, ok := args["message"].(string)
@@ -218,9 +208,7 @@ func (t *CronTool) addJob(ctx context.Context, args map[string]any) *ToolResult
// Validate type parameter (server-side whitelist, not just LLM schema hint)
msgType, _ := args["type"].(string)
if msgType != "" && msgType != "message" && msgType != "directive" {
return ErrorResult(
fmt.Sprintf("invalid type %q, must be 'message' or 'directive'", msgType),
)
return ErrorResult(fmt.Sprintf("invalid type %q, must be 'message' or 'directive'", msgType))
}
// GHSA-pv8c-p6jf-3fpp: command scheduling requires internal channel. When
+8 -34
View File
@@ -49,11 +49,7 @@ func (s *stubJobExecutor) PublishResponseIfNeeded(
s.publishedChatID = chatID
}
func newTestCronToolWithExecutorAndConfig(
t *testing.T,
executor JobExecutor,
cfg *config.Config,
) *CronTool {
func newTestCronToolWithExecutorAndConfig(t *testing.T, executor JobExecutor, cfg *config.Config) *CronTool {
t.Helper()
storePath := filepath.Join(t.TempDir(), "cron.json")
cronService := cron.NewCronService(storePath, nil)
@@ -106,10 +102,7 @@ func TestCronTool_CommandDoesNotRequireConfirmByDefault(t *testing.T) {
})
if result.IsError {
t.Fatalf(
"expected command scheduling without confirm to succeed by default, got: %s",
result.ForLLM,
)
t.Fatalf("expected command scheduling without confirm to succeed by default, got: %s", result.ForLLM)
}
if !strings.Contains(result.ForLLM, "Cron job added") {
t.Errorf("expected 'Cron job added', got: %s", result.ForLLM)
@@ -197,10 +190,7 @@ func TestCronTool_CommandAllowedFromInternalChannel(t *testing.T) {
})
if result.IsError {
t.Fatalf(
"expected command scheduling to succeed from internal channel, got: %s",
result.ForLLM,
)
t.Fatalf("expected command scheduling to succeed from internal channel, got: %s", result.ForLLM)
}
if !strings.Contains(result.ForLLM, "Cron job added") {
t.Errorf("expected 'Cron job added', got: %s", result.ForLLM)
@@ -235,10 +225,7 @@ func TestCronTool_NonCommandJobAllowedFromRemoteChannel(t *testing.T) {
})
if result.IsError {
t.Fatalf(
"expected non-command reminder to succeed from remote channel, got: %s",
result.ForLLM,
)
t.Fatalf("expected non-command reminder to succeed from remote channel, got: %s", result.ForLLM)
}
}
@@ -310,11 +297,7 @@ func TestCronTool_ExecuteJobPublishesAgentResponse(t *testing.T) {
t.Fatalf("sessionKey = %q, want cron-job-1", executor.lastKey)
}
if executor.lastChan != "telegram" || executor.lastChatID != "chat-1" {
t.Fatalf(
"executor target = %s/%s, want telegram/chat-1",
executor.lastChan,
executor.lastChatID,
)
t.Fatalf("executor target = %s/%s, want telegram/chat-1", executor.lastChan, executor.lastChatID)
}
if executor.lastPrompt != "send me a poem" {
t.Fatalf("prompt = %q, want original message", executor.lastPrompt)
@@ -323,11 +306,7 @@ func TestCronTool_ExecuteJobPublishesAgentResponse(t *testing.T) {
t.Fatalf("published response = %q, want generated reply", executor.publishedResp)
}
if executor.publishedChan != "telegram" || executor.publishedChatID != "chat-1" {
t.Fatalf(
"published target = %s/%s, want telegram/chat-1",
executor.publishedChan,
executor.publishedChatID,
)
t.Fatalf("published target = %s/%s, want telegram/chat-1", executor.publishedChan, executor.publishedChatID)
}
}
@@ -363,10 +342,7 @@ func TestCronTool_ExecuteJobSkipsWhenMessageToolAlreadySent(t *testing.T) {
}
if executor.publishedResp != "" {
t.Fatalf(
"expected no published response when message tool already sent, got: %q",
executor.publishedResp,
)
t.Fatalf("expected no published response when message tool already sent, got: %q", executor.publishedResp)
}
}
@@ -410,9 +386,7 @@ func TestCronTool_ExecuteJobDirectiveWithDeliverRoutesToAgent(t *testing.T) {
}
if executor.lastPrompt == "" {
t.Fatal(
"expected agent to be called for directive+deliver, but ProcessDirectWithChannel was not invoked",
)
t.Fatal("expected agent to be called for directive+deliver, but ProcessDirectWithChannel was not invoked")
}
if executor.publishedResp != "agent processed" {
t.Fatalf("published response = %q, want %q", executor.publishedResp, "agent processed")
+3 -14
View File
@@ -16,11 +16,7 @@ type EditFileTool struct {
}
// NewEditFileTool creates a new EditFileTool with optional directory restriction.
func NewEditFileTool(
workspace string,
restrict bool,
allowPaths ...[]*regexp.Regexp,
) *EditFileTool {
func NewEditFileTool(workspace string, restrict bool, allowPaths ...[]*regexp.Regexp) *EditFileTool {
var patterns []*regexp.Regexp
if len(allowPaths) > 0 {
patterns = allowPaths[0]
@@ -83,11 +79,7 @@ type AppendFileTool struct {
fs fileSystem
}
func NewAppendFileTool(
workspace string,
restrict bool,
allowPaths ...[]*regexp.Regexp,
) *AppendFileTool {
func NewAppendFileTool(workspace string, restrict bool, allowPaths ...[]*regexp.Regexp) *AppendFileTool {
var patterns []*regexp.Regexp
if len(allowPaths) > 0 {
patterns = allowPaths[0]
@@ -174,10 +166,7 @@ func replaceEditContent(content []byte, oldText, newText string) ([]byte, error)
count := strings.Count(contentStr, oldText)
if count > 1 {
return nil, fmt.Errorf(
"old_text appears %d times. Please provide more context to make it unique",
count,
)
return nil, fmt.Errorf("old_text appears %d times. Please provide more context to make it unique", count)
}
newContent := strings.Replace(contentStr, oldText, newText, 1)
+2 -4
View File
@@ -76,8 +76,7 @@ func TestEditTool_EditFile_NotFound(t *testing.T) {
}
// Should mention file not found
if !strings.Contains(result.ForLLM, "not found") &&
!strings.Contains(result.ForUser, "not found") {
if !strings.Contains(result.ForLLM, "not found") && !strings.Contains(result.ForUser, "not found") {
t.Errorf("Expected 'file not found' message, got ForLLM: %s", result.ForLLM)
}
}
@@ -104,8 +103,7 @@ func TestEditTool_EditFile_OldTextNotFound(t *testing.T) {
}
// Should mention old_text not found
if !strings.Contains(result.ForLLM, "not found") &&
!strings.Contains(result.ForUser, "not found") {
if !strings.Contains(result.ForLLM, "not found") && !strings.Contains(result.ForUser, "not found") {
t.Errorf("Expected 'not found' message, got ForLLM: %s", result.ForLLM)
}
}
+3 -13
View File
@@ -20,11 +20,7 @@ import (
const MaxReadFileSize = 64 * 1024 // 64KB limit to avoid context overflow
func validatePathWithAllowPaths(
path, workspace string,
restrict bool,
patterns []*regexp.Regexp,
) (string, error) {
func validatePathWithAllowPaths(path, workspace string, restrict bool, patterns []*regexp.Regexp) (string, error) {
if workspace == "" {
return path, fmt.Errorf("workspace is not defined")
}
@@ -487,11 +483,7 @@ type WriteFileTool struct {
fs fileSystem
}
func NewWriteFileTool(
workspace string,
restrict bool,
allowPaths ...[]*regexp.Regexp,
) *WriteFileTool {
func NewWriteFileTool(workspace string, restrict bool, allowPaths ...[]*regexp.Regexp) *WriteFileTool {
var patterns []*regexp.Regexp
if len(allowPaths) > 0 {
patterns = allowPaths[0]
@@ -544,9 +536,7 @@ func (t *WriteFileTool) Execute(ctx context.Context, args map[string]any) *ToolR
if !overwrite {
if _, err := t.fs.Open(path); err == nil {
return ErrorResult(
fmt.Sprintf("file: %s already exists. Set overwrite=true to replace.", path),
)
return ErrorResult(fmt.Sprintf("file: %s already exists. Set overwrite=true to replace.", path))
}
}
+11 -42
View File
@@ -59,13 +59,8 @@ func TestFilesystemTool_ReadFile_NotFound(t *testing.T) {
}
// Should contain error message
if !strings.Contains(result.ForLLM, "failed to open file") &&
!strings.Contains(result.ForUser, "failed to read") {
t.Errorf(
"Expected error message, got ForLLM: %s, ForUser: %s",
result.ForLLM,
result.ForUser,
)
if !strings.Contains(result.ForLLM, "failed to open file") && !strings.Contains(result.ForUser, "failed to read") {
t.Errorf("Expected error message, got ForLLM: %s, ForUser: %s", result.ForLLM, result.ForUser)
}
}
@@ -83,8 +78,7 @@ func TestFilesystemTool_ReadFile_MissingPath(t *testing.T) {
}
// Should mention required parameter
if !strings.Contains(result.ForLLM, "path is required") &&
!strings.Contains(result.ForUser, "path is required") {
if !strings.Contains(result.ForLLM, "path is required") && !strings.Contains(result.ForUser, "path is required") {
t.Errorf("Expected 'path is required' message, got ForLLM: %s", result.ForLLM)
}
}
@@ -303,12 +297,7 @@ func TestFilesystemTool_WriteFile_OverwriteSandboxed(t *testing.T) {
"content": "replaced in sandbox",
"overwrite": true,
})
assert.False(
t,
result.IsError,
"expected success in sandbox mode with overwrite=true, got: %s",
result.ForLLM,
)
assert.False(t, result.IsError, "expected success in sandbox mode with overwrite=true, got: %s", result.ForLLM)
data, err := os.ReadFile(filepath.Join(workspace, testFile))
assert.NoError(t, err)
@@ -336,8 +325,7 @@ func TestFilesystemTool_ListDir_Success(t *testing.T) {
}
// Should list files and directories
if !strings.Contains(result.ForLLM, "file1.txt") ||
!strings.Contains(result.ForLLM, "file2.txt") {
if !strings.Contains(result.ForLLM, "file1.txt") || !strings.Contains(result.ForLLM, "file2.txt") {
t.Errorf("Expected files in listing, got: %s", result.ForLLM)
}
if !strings.Contains(result.ForLLM, "subdir") {
@@ -361,13 +349,8 @@ func TestFilesystemTool_ListDir_NotFound(t *testing.T) {
}
// Should contain error message
if !strings.Contains(result.ForLLM, "failed to read") &&
!strings.Contains(result.ForUser, "failed to read") {
t.Errorf(
"Expected error message, got ForLLM: %s, ForUser: %s",
result.ForLLM,
result.ForUser,
)
if !strings.Contains(result.ForLLM, "failed to read") && !strings.Contains(result.ForUser, "failed to read") {
t.Errorf("Expected error message, got ForLLM: %s, ForUser: %s", result.ForLLM, result.ForUser)
}
}
@@ -414,8 +397,7 @@ func TestFilesystemTool_ReadFile_RejectsSymlinkEscape(t *testing.T) {
// os.Root might return different errors depending on platform/implementation
// but it definitely should error.
// Our wrapper returns "access denied or file not found"
if !strings.Contains(result.ForLLM, "access denied") &&
!strings.Contains(result.ForLLM, "file not found") &&
if !strings.Contains(result.ForLLM, "access denied") && !strings.Contains(result.ForLLM, "file not found") &&
!strings.Contains(result.ForLLM, "no such file") {
t.Fatalf("expected symlink escape error, got: %s", result.ForLLM)
}
@@ -434,20 +416,10 @@ func TestFilesystemTool_EmptyWorkspace_AccessDenied(t *testing.T) {
})
// We EXPECT IsError=true (access blocked due to empty workspace)
assert.True(
t,
result.IsError,
"Security Regression: Empty workspace allowed access! content: %s",
result.ForLLM,
)
assert.True(t, result.IsError, "Security Regression: Empty workspace allowed access! content: %s", result.ForLLM)
// Verify it failed for the right reason
assert.Contains(
t,
result.ForLLM,
"workspace is not defined",
"Expected 'workspace is not defined' error",
)
assert.Contains(t, result.ForLLM, "workspace is not defined", "Expected 'workspace is not defined' error")
}
// TestRootMkdirAll verifies that root.MkdirAll (used by atomicWriteFileInRoot) handles all cases:
@@ -681,10 +653,7 @@ func TestWhitelistFs_BlocksSymlinkEscapeInAllowedDir(t *testing.T) {
patterns := []*regexp.Regexp{regexp.MustCompile(`^` + regexp.QuoteMeta(allowedDir))}
tool := NewReadFileTool(workspace, true, MaxReadFileSize, patterns)
result := tool.Execute(
context.Background(),
map[string]any{"path": filepath.Join(linkPath, "secret.txt")},
)
result := tool.Execute(context.Background(), map[string]any{"path": filepath.Join(linkPath, "secret.txt")})
if !result.IsError {
t.Fatalf("expected symlink escape from allowed dir to be blocked, got: %s", result.ForLLM)
}
+2 -6
View File
@@ -65,9 +65,7 @@ func (t *I2CTool) Parameters() map[string]any {
func (t *I2CTool) Execute(ctx context.Context, args map[string]any) *ToolResult {
if runtime.GOOS != "linux" {
return ErrorResult(
"I2C is only supported on Linux. This tool requires /dev/i2c-* device files.",
)
return ErrorResult("I2C is only supported on Linux. This tool requires /dev/i2c-* device files.")
}
action, ok := args["action"].(string)
@@ -85,9 +83,7 @@ func (t *I2CTool) Execute(ctx context.Context, args map[string]any) *ToolResult
case "write":
return t.writeDevice(args)
default:
return ErrorResult(
fmt.Sprintf("unknown action: %s (valid: detect, scan, read, write)", action),
)
return ErrorResult(fmt.Sprintf("unknown action: %s (valid: detect, scan, read, write)", action))
}
}
+7 -35
View File
@@ -55,12 +55,7 @@ func smbusProbe(fd int, addr int, hasQuick bool) bool {
size: i2cSmbusQuick,
data: nil,
}
_, _, errno := syscall.Syscall(
syscall.SYS_IOCTL,
uintptr(fd),
i2cSmbus,
uintptr(unsafe.Pointer(&args)),
)
_, _, errno := syscall.Syscall(syscall.SYS_IOCTL, uintptr(fd), i2cSmbus, uintptr(unsafe.Pointer(&args)))
return errno == 0
}
@@ -72,12 +67,7 @@ func smbusProbe(fd int, addr int, hasQuick bool) bool {
size: i2cSmbusByte,
data: &data,
}
_, _, errno := syscall.Syscall(
syscall.SYS_IOCTL,
uintptr(fd),
i2cSmbus,
uintptr(unsafe.Pointer(&args)),
)
_, _, errno := syscall.Syscall(syscall.SYS_IOCTL, uintptr(fd), i2cSmbus, uintptr(unsafe.Pointer(&args)))
return errno == 0
}
@@ -93,29 +83,16 @@ func (t *I2CTool) scan(args map[string]any) *ToolResult {
devPath := fmt.Sprintf("/dev/i2c-%s", bus)
fd, err := syscall.Open(devPath, syscall.O_RDWR, 0)
if err != nil {
return ErrorResult(
fmt.Sprintf(
"failed to open %s: %v (check permissions and i2c-dev module)",
devPath,
err,
),
)
return ErrorResult(fmt.Sprintf("failed to open %s: %v (check permissions and i2c-dev module)", devPath, err))
}
defer syscall.Close(fd)
// Query adapter capabilities to determine available probe methods.
// I2C_FUNCS writes an unsigned long, which is word-sized on Linux.
var funcs uintptr
_, _, errno := syscall.Syscall(
syscall.SYS_IOCTL,
uintptr(fd),
i2cFuncs,
uintptr(unsafe.Pointer(&funcs)),
)
_, _, errno := syscall.Syscall(syscall.SYS_IOCTL, uintptr(fd), i2cFuncs, uintptr(unsafe.Pointer(&funcs)))
if errno != 0 {
return ErrorResult(
fmt.Sprintf("failed to query I2C adapter capabilities on %s: %v", devPath, errno),
)
return ErrorResult(fmt.Sprintf("failed to query I2C adapter capabilities on %s: %v", devPath, errno))
}
hasQuick := funcs&i2cFuncSmbusQuick != 0
@@ -123,10 +100,7 @@ func (t *I2CTool) scan(args map[string]any) *ToolResult {
if !hasQuick && !hasReadByte {
return ErrorResult(
fmt.Sprintf(
"I2C adapter %s supports neither SMBus Quick nor Read Byte — cannot probe safely",
devPath,
),
fmt.Sprintf("I2C adapter %s supports neither SMBus Quick nor Read Byte — cannot probe safely", devPath),
)
}
@@ -158,9 +132,7 @@ func (t *I2CTool) scan(args map[string]any) *ToolResult {
}
if len(found) == 0 {
return SilentResult(
fmt.Sprintf("No devices found on %s. Check wiring and pull-up resistors.", devPath),
)
return SilentResult(fmt.Sprintf("No devices found on %s. Check wiring and pull-up resistors.", devPath))
}
result, _ := json.MarshalIndent(map[string]any{
+6 -28
View File
@@ -314,10 +314,7 @@ func (t *MCPTool) normalizeResultContent(ctx context.Context, content []mcp.Cont
return result
}
func (t *MCPTool) storeEmbeddedResource(
ctx context.Context,
content *mcp.EmbeddedResource,
) (string, string) {
func (t *MCPTool) storeEmbeddedResource(ctx context.Context, content *mcp.EmbeddedResource) (string, string) {
if content == nil || content.Resource == nil {
return "", "[MCP returned an embedded resource without data.]"
}
@@ -377,39 +374,23 @@ func (t *MCPTool) storeBinaryContent(
dir := media.TempDir()
if err := os.MkdirAll(dir, 0o700); err != nil {
return "", fmt.Sprintf(
"[MCP returned %s content (%s) but it could not be stored.]",
kind,
mimeType,
)
return "", fmt.Sprintf("[MCP returned %s content (%s) but it could not be stored.]", kind, mimeType)
}
ext := extensionForMIMEType(mimeType)
tmpFile, err := os.CreateTemp(dir, "mcp-*"+ext)
if err != nil {
return "", fmt.Sprintf(
"[MCP returned %s content (%s) but it could not be stored.]",
kind,
mimeType,
)
return "", fmt.Sprintf("[MCP returned %s content (%s) but it could not be stored.]", kind, mimeType)
}
tmpPath := tmpFile.Name()
if _, err = tmpFile.Write(data); err != nil {
_ = tmpFile.Close()
_ = os.Remove(tmpPath)
return "", fmt.Sprintf(
"[MCP returned %s content (%s) but it could not be stored.]",
kind,
mimeType,
)
return "", fmt.Sprintf("[MCP returned %s content (%s) but it could not be stored.]", kind, mimeType)
}
if err = tmpFile.Close(); err != nil {
_ = os.Remove(tmpPath)
return "", fmt.Sprintf(
"[MCP returned %s content (%s) but it could not be stored.]",
kind,
mimeType,
)
return "", fmt.Sprintf("[MCP returned %s content (%s) but it could not be stored.]", kind, mimeType)
}
scope := fmt.Sprintf(
@@ -489,10 +470,7 @@ func summarizeEmbeddedResource(content *mcp.EmbeddedResource) string {
normalizedMIMEType(resource.MIMEType),
)
}
return fmt.Sprintf(
"[MCP returned embedded resource (%s).]",
normalizedMIMEType(resource.MIMEType),
)
return fmt.Sprintf("[MCP returned embedded resource (%s).]", normalizedMIMEType(resource.MIMEType))
}
func annotationsAllowUser(annotations *mcp.Annotations) bool {
+1 -4
View File
@@ -571,10 +571,7 @@ func TestMCPTool_Execute_EmbeddedResourceBlobStoredAsMedia(t *testing.T) {
result := mcpTool.Execute(WithToolContext(context.Background(), "telegram", "chat-42"), nil)
if len(result.Media) != 1 {
t.Fatalf(
"expected embedded resource blob to be stored as media, got %d refs",
len(result.Media),
)
t.Fatalf("expected embedded resource blob to be stored as media, got %d refs", len(result.Media))
}
path, _, err := store.ResolveWithMeta(result.Media[0])
if err != nil {
+2 -8
View File
@@ -43,10 +43,7 @@ func TestMessageTool_Execute_Success(t *testing.T) {
// - ForLLM contains send status description
if result.ForLLM != "Message sent to test-channel:test-chat-id" {
t.Errorf(
"Expected ForLLM 'Message sent to test-channel:test-chat-id', got '%s'",
result.ForLLM,
)
t.Errorf("Expected ForLLM 'Message sent to test-channel:test-chat-id', got '%s'", result.ForLLM)
}
// - ForUser is empty (user already received message directly)
@@ -91,10 +88,7 @@ func TestMessageTool_Execute_WithCustomChannel(t *testing.T) {
t.Error("Expected Silent=true")
}
if result.ForLLM != "Message sent to custom-channel:custom-chat-id" {
t.Errorf(
"Expected ForLLM 'Message sent to custom-channel:custom-chat-id', got '%s'",
result.ForLLM,
)
t.Errorf("Expected ForLLM 'Message sent to custom-channel:custom-chat-id', got '%s'", result.ForLLM)
}
}
+6 -24
View File
@@ -215,43 +215,28 @@ func storeInlineDataURL(
payload = strings.NewReplacer("\n", "", "\r", "", "\t", "", " ", "").Replace(payload)
decoded, err := base64.StdEncoding.DecodeString(payload)
if err != nil {
return "", fmt.Sprintf(
"[Tool returned inline media content (%s) that could not be decoded.]",
mimeType,
)
return "", fmt.Sprintf("[Tool returned inline media content (%s) that could not be decoded.]", mimeType)
}
dir := media.TempDir()
if err = os.MkdirAll(dir, 0o700); err != nil {
return "", fmt.Sprintf(
"[Tool returned inline media content (%s) but it could not be stored.]",
mimeType,
)
return "", fmt.Sprintf("[Tool returned inline media content (%s) but it could not be stored.]", mimeType)
}
ext := extensionForMIMEType(mimeType)
tmpFile, err := os.CreateTemp(dir, "tool-inline-*"+ext)
if err != nil {
return "", fmt.Sprintf(
"[Tool returned inline media content (%s) but it could not be stored.]",
mimeType,
)
return "", fmt.Sprintf("[Tool returned inline media content (%s) but it could not be stored.]", mimeType)
}
tmpPath := tmpFile.Name()
if _, err = tmpFile.Write(decoded); err != nil {
tmpFile.Close()
_ = os.Remove(tmpPath)
return "", fmt.Sprintf(
"[Tool returned inline media content (%s) but it could not be stored.]",
mimeType,
)
return "", fmt.Sprintf("[Tool returned inline media content (%s) but it could not be stored.]", mimeType)
}
if err = tmpFile.Close(); err != nil {
_ = os.Remove(tmpPath)
return "", fmt.Sprintf(
"[Tool returned inline media content (%s) but it could not be stored.]",
mimeType,
)
return "", fmt.Sprintf("[Tool returned inline media content (%s) but it could not be stored.]", mimeType)
}
filename := sanitizeIdentifierComponent(toolName) + ext
@@ -270,10 +255,7 @@ func storeInlineDataURL(
}, scope)
if err != nil {
_ = os.Remove(tmpPath)
return "", fmt.Sprintf(
"[Tool returned inline media content (%s) but it could not be registered.]",
mimeType,
)
return "", fmt.Sprintf("[Tool returned inline media content (%s) but it could not be registered.]", mimeType)
}
return ref, fmt.Sprintf(inlineMediaStoredMessage, mimeType)
+1 -4
View File
@@ -80,10 +80,7 @@ func (tr *ToolResult) ContentForLLM() string {
}
}
if len(tr.ArtifactTags) > 0 {
artifactNote := "Local artifact paths: " + strings.Join(
tr.ArtifactTags,
" ",
) + "\n" + artifactPathsLLMNote
artifactNote := "Local artifact paths: " + strings.Join(tr.ArtifactTags, " ") + "\n" + artifactPathsLLMNote
if content == "" {
content = artifactNote
} else if !strings.Contains(content, artifactNote) {
+1 -5
View File
@@ -142,11 +142,7 @@ func TestToolResultJSONSerialization(t *testing.T) {
t.Errorf("ForLLM mismatch: got '%s', want '%s'", decoded.ForLLM, tt.result.ForLLM)
}
if decoded.ForUser != tt.result.ForUser {
t.Errorf(
"ForUser mismatch: got '%s', want '%s'",
decoded.ForUser,
tt.result.ForUser,
)
t.Errorf("ForUser mismatch: got '%s', want '%s'", decoded.ForUser, tt.result.ForUser)
}
if decoded.Silent != tt.result.Silent {
t.Errorf("Silent mismatch: got %v, want %v", decoded.Silent, tt.result.Silent)
+9 -43
View File
@@ -56,38 +56,19 @@ func (t *RegexSearchTool) Execute(ctx context.Context, args map[string]any) *Too
}
if len(pattern) > MaxRegexPatternLength {
logger.WarnCF(
"discovery",
"Regex pattern rejected (too long)",
map[string]any{"len": len(pattern)},
)
return ErrorResult(
fmt.Sprintf("Pattern too long: max %d characters allowed", MaxRegexPatternLength),
)
logger.WarnCF("discovery", "Regex pattern rejected (too long)", map[string]any{"len": len(pattern)})
return ErrorResult(fmt.Sprintf("Pattern too long: max %d characters allowed", MaxRegexPatternLength))
}
logger.DebugCF("discovery", "Regex search", map[string]any{"pattern": pattern})
res, err := t.registry.SearchRegex(pattern, t.maxSearchResults)
if err != nil {
logger.WarnCF(
"discovery",
"Invalid regex pattern",
map[string]any{"pattern": pattern, "error": err.Error()},
)
return ErrorResult(
fmt.Sprintf(
"Invalid regex pattern syntax: %v. Please fix your regex and try again.",
err,
),
)
logger.WarnCF("discovery", "Invalid regex pattern", map[string]any{"pattern": pattern, "error": err.Error()})
return ErrorResult(fmt.Sprintf("Invalid regex pattern syntax: %v. Please fix your regex and try again.", err))
}
logger.InfoCF(
"discovery",
"Regex search completed",
map[string]any{"pattern": pattern, "results": len(res)},
)
logger.InfoCF("discovery", "Regex search completed", map[string]any{"pattern": pattern, "results": len(res)})
return formatDiscoveryResponse(t.registry, res, t.ttl)
}
@@ -157,11 +138,7 @@ func (t *BM25SearchTool) Execute(ctx context.Context, args map[string]any) *Tool
}
}
logger.InfoCF(
"discovery",
"BM25 search completed",
map[string]any{"query": query, "results": len(results)},
)
logger.InfoCF("discovery", "BM25 search completed", map[string]any{"query": query, "results": len(results)})
return formatDiscoveryResponse(t.registry, results, t.ttl)
}
@@ -173,10 +150,7 @@ type ToolSearchResult struct {
Description string `json:"description"`
}
func (r *ToolRegistry) SearchRegex(
pattern string,
maxSearchResults int,
) ([]ToolSearchResult, error) {
func (r *ToolRegistry) SearchRegex(pattern string, maxSearchResults int) ([]ToolSearchResult, error) {
if maxSearchResults <= 0 {
return nil, nil
}
@@ -214,11 +188,7 @@ func (r *ToolRegistry) SearchRegex(
return results, nil
}
func formatDiscoveryResponse(
registry *ToolRegistry,
results []ToolSearchResult,
ttl int,
) *ToolResult {
func formatDiscoveryResponse(registry *ToolRegistry, results []ToolSearchResult, ttl int) *ToolResult {
if len(results) == 0 {
return SilentResult("No tools found matching the query.")
}
@@ -304,11 +274,7 @@ func (t *BM25SearchTool) getOrBuildEngine() *bm25CachedEngine {
cached := &bm25CachedEngine{engine: buildBM25Engine(docs)}
t.cachedEngine = cached
t.cacheVersion = snap.Version
logger.DebugCF(
"discovery",
"BM25 engine rebuilt",
map[string]any{"docs": len(docs), "version": snap.Version},
)
logger.DebugCF("discovery", "BM25 engine rebuilt", map[string]any{"docs": len(docs), "version": snap.Version})
return cached
}
+1 -4
View File
@@ -93,10 +93,7 @@ func TestRegexSearchTool_Execute(t *testing.T) {
reg.mu.RLock()
defer reg.mu.RUnlock()
if reg.tools["mcp_read_file"].TTL != 5 {
t.Errorf(
"Expected TTL of 'mcp_read_file' to be promoted to 5, got %d",
reg.tools["mcp_read_file"].TTL,
)
t.Errorf("Expected TTL of 'mcp_read_file' to be promoted to 5, got %d", reg.tools["mcp_read_file"].TTL)
}
if reg.tools["mcp_fetch_net"].TTL != 0 {
t.Errorf("Expected 'mcp_fetch_net' to NOT be promoted (TTL=0)")
+1 -4
View File
@@ -142,10 +142,7 @@ func (t *SendFileTool) Execute(ctx context.Context, args map[string]any) *ToolRe
return ErrorResult(fmt.Sprintf("failed to register media: %v", err))
}
return MediaResult(
fmt.Sprintf("File %q sent to user", filename),
[]string{ref},
).WithResponseHandled()
return MediaResult(fmt.Sprintf("File %q sent to user", filename), []string{ref}).WithResponseHandled()
}
// detectMediaType determines the MIME type of a file.
+2 -10
View File
@@ -79,11 +79,7 @@ func TestSendFileTool_FileTooLarge(t *testing.T) {
func TestSendFileTool_DefaultMaxSize(t *testing.T) {
tool := NewSendFileTool("/tmp", false, 0, nil)
if tool.maxFileSize != config.DefaultMaxMediaSize {
t.Errorf(
"expected default max size %d, got %d",
config.DefaultMaxMediaSize,
tool.maxFileSize,
)
t.Errorf("expected default max size %d, got %d", config.DefaultMaxMediaSize, tool.maxFileSize)
}
}
@@ -166,11 +162,7 @@ func TestSendFileTool_AllowsWhitelistedMediaTempPath(t *testing.T) {
t.Cleanup(func() { _ = os.Remove(testPath) })
pattern := regexp.MustCompile(
"^" + regexp.QuoteMeta(
filepath.Clean(mediaDir),
) + "(?:" + regexp.QuoteMeta(
string(os.PathSeparator),
) + "|$)",
"^" + regexp.QuoteMeta(filepath.Clean(mediaDir)) + "(?:" + regexp.QuoteMeta(string(os.PathSeparator)) + "|$)",
)
store := media.NewFileMediaStore()
+14 -58
View File
@@ -113,11 +113,7 @@ var (
}
)
func NewExecTool(
workingDir string,
restrict bool,
allowPaths ...[]*regexp.Regexp,
) (*ExecTool, error) {
func NewExecTool(workingDir string, restrict bool, allowPaths ...[]*regexp.Regexp) (*ExecTool, error) {
return NewExecToolWithConfig(workingDir, restrict, nil, allowPaths...)
}
@@ -197,16 +193,8 @@ func (t *ExecTool) Parameters() map[string]any {
"type": "object",
"properties": map[string]any{
"action": map[string]any{
"type": "string",
"enum": []string{
"run",
"list",
"poll",
"read",
"write",
"kill",
"send-keys",
},
"type": "string",
"enum": []string{"run", "list", "poll", "read", "write", "kill", "send-keys"},
"description": "Action: run (execute command), list (show sessions), poll (check status), read (get output), write (send input), kill (terminate), send-keys (send keys to PTY)",
},
"command": map[string]any{
@@ -312,12 +300,7 @@ func (t *ExecTool) executeRun(ctx context.Context, args map[string]any) *ToolRes
cwd := t.workingDir
if wd, ok := args["cwd"].(string); ok && wd != "" {
if t.restrictToWorkspace && t.workingDir != "" {
resolvedWD, err := validatePathWithAllowPaths(
wd,
t.workingDir,
true,
t.allowedPathPatterns,
)
resolvedWD, err := validatePathWithAllowPaths(wd, t.workingDir, true, t.allowedPathPatterns)
if err != nil {
return ErrorResult("Command blocked by safety guard (" + err.Error() + ")")
}
@@ -343,9 +326,7 @@ func (t *ExecTool) executeRun(ctx context.Context, args map[string]any) *ToolRes
if t.restrictToWorkspace && t.workingDir != "" && cwd != t.workingDir {
resolved, err := filepath.EvalSymlinks(cwd)
if err != nil {
return ErrorResult(
fmt.Sprintf("Command blocked by safety guard (path resolution failed: %v)", err),
)
return ErrorResult(fmt.Sprintf("Command blocked by safety guard (path resolution failed: %v)", err))
}
if isAllowedPath(resolved, t.allowedPathPatterns) {
cwd = resolved
@@ -383,14 +364,7 @@ func (t *ExecTool) runSync(ctx context.Context, command, cwd string) *ToolResult
var cmd *exec.Cmd
if runtime.GOOS == "windows" {
cmd = exec.CommandContext(
cmdCtx,
"powershell",
"-NoProfile",
"-NonInteractive",
"-Command",
command,
)
cmd = exec.CommandContext(cmdCtx, "powershell", "-NoProfile", "-NonInteractive", "-Command", command)
} else {
cmd = exec.CommandContext(cmdCtx, "sh", "-c", command)
}
@@ -468,10 +442,7 @@ func (t *ExecTool) runSync(ctx context.Context, command, cwd string) *ToolResult
maxLen := 10000
if len(output) > maxLen {
output = output[:maxLen] + fmt.Sprintf(
"\n... (truncated, %d more chars)",
len(output)-maxLen,
)
output = output[:maxLen] + fmt.Sprintf("\n... (truncated, %d more chars)", len(output)-maxLen)
}
if err != nil {
@@ -489,11 +460,7 @@ func (t *ExecTool) runSync(ctx context.Context, command, cwd string) *ToolResult
}
}
func (t *ExecTool) runBackground(
ctx context.Context,
command, cwd string,
ptyEnabled bool,
) *ToolResult {
func (t *ExecTool) runBackground(ctx context.Context, command, cwd string, ptyEnabled bool) *ToolResult {
sessionID := generateSessionID()
session := &ProcessSession{
ID: sessionID,
@@ -586,8 +553,7 @@ func (t *ExecTool) runBackground(
n, err := session.ptyMaster.Read(buf)
if n > 0 {
raw := string(buf[:n])
if mode := detectPtyKeyMode(raw); mode != PtyKeyModeNotFound &&
mode != session.GetPtyKeyMode() {
if mode := detectPtyKeyMode(raw); mode != PtyKeyModeNotFound && mode != session.GetPtyKeyMode() {
session.SetPtyKeyMode(mode)
}
@@ -768,16 +734,12 @@ func (t *ExecTool) executeWrite(args map[string]any) *ToolResult {
}
if session.IsDone() {
return ErrorResult(
fmt.Sprintf("process already exited with code %d", session.GetExitCode()),
)
return ErrorResult(fmt.Sprintf("process already exited with code %d", session.GetExitCode()))
}
if err := session.Write(data); err != nil {
if errors.Is(err, ErrSessionDone) {
return ErrorResult(
fmt.Sprintf("process already exited with code %d", session.GetExitCode()),
)
return ErrorResult(fmt.Sprintf("process already exited with code %d", session.GetExitCode()))
}
return ErrorResult(fmt.Sprintf("failed to write to session: %v", err))
}
@@ -808,9 +770,7 @@ func (t *ExecTool) executeKill(args map[string]any) *ToolResult {
}
if session.IsDone() {
return ErrorResult(
fmt.Sprintf("process already exited with code %d", session.GetExitCode()),
)
return ErrorResult(fmt.Sprintf("process already exited with code %d", session.GetExitCode()))
}
if err := session.Kill(); err != nil {
@@ -1032,16 +992,12 @@ func (t *ExecTool) executeSendKeys(args map[string]any) *ToolResult {
}
if session.IsDone() {
return ErrorResult(
fmt.Sprintf("process already exited with code %d", session.GetExitCode()),
)
return ErrorResult(fmt.Sprintf("process already exited with code %d", session.GetExitCode()))
}
if err := session.Write(data); err != nil {
if errors.Is(err, ErrSessionDone) {
return ErrorResult(
fmt.Sprintf("process already exited with code %d", session.GetExitCode()),
)
return ErrorResult(fmt.Sprintf("process already exited with code %d", session.GetExitCode()))
}
return ErrorResult(fmt.Sprintf("failed to send keys: %v", err))
}
+18 -77
View File
@@ -100,13 +100,8 @@ func TestShellTool_Timeout(t *testing.T) {
}
// Should mention timeout
if !strings.Contains(result.ForLLM, "timed out") &&
!strings.Contains(result.ForUser, "timed out") {
t.Errorf(
"Expected timeout message, got ForLLM: %s, ForUser: %s",
result.ForLLM,
result.ForUser,
)
if !strings.Contains(result.ForLLM, "timed out") && !strings.Contains(result.ForUser, "timed out") {
t.Errorf("Expected timeout message, got ForLLM: %s, ForUser: %s", result.ForLLM, result.ForUser)
}
}
@@ -161,11 +156,7 @@ func TestShellTool_DangerousCommand(t *testing.T) {
}
if !strings.Contains(result.ForLLM, "blocked") && !strings.Contains(result.ForUser, "blocked") {
t.Errorf(
"Expected 'blocked' message, got ForLLM: %s, ForUser: %s",
result.ForLLM,
result.ForUser,
)
t.Errorf("Expected 'blocked' message, got ForLLM: %s, ForUser: %s", result.ForLLM, result.ForUser)
}
}
@@ -186,11 +177,7 @@ func TestShellTool_DangerousCommand_KillBlocked(t *testing.T) {
t.Errorf("Expected kill command to be blocked")
}
if !strings.Contains(result.ForLLM, "blocked") && !strings.Contains(result.ForUser, "blocked") {
t.Errorf(
"Expected blocked message, got ForLLM: %s, ForUser: %s",
result.ForLLM,
result.ForUser,
)
t.Errorf("Expected blocked message, got ForLLM: %s, ForUser: %s", result.ForLLM, result.ForUser)
}
}
@@ -282,10 +269,7 @@ func TestShellTool_WorkingDir_OutsideWorkspace(t *testing.T) {
})
if !result.IsError {
t.Fatalf(
"expected working_dir outside workspace to be blocked, got output: %s",
result.ForLLM,
)
t.Fatalf("expected working_dir outside workspace to be blocked, got output: %s", result.ForLLM)
}
if !strings.Contains(result.ForLLM, "blocked") {
t.Errorf("expected 'blocked' in error, got: %s", result.ForLLM)
@@ -460,10 +444,7 @@ func TestShellTool_DevNullAllowed(t *testing.T) {
}
for _, cmd := range commands {
result := tool.Execute(
context.Background(),
map[string]any{"action": "run", "command": cmd},
)
result := tool.Execute(context.Background(), map[string]any{"action": "run", "command": cmd})
if result.IsError && strings.Contains(result.ForLLM, "blocked") {
t.Errorf("command should not be blocked: %s\n error: %s", cmd, result.ForLLM)
}
@@ -492,10 +473,7 @@ func TestShellTool_BlockDevices(t *testing.T) {
}
for _, cmd := range blocked {
result := tool.Execute(
context.Background(),
map[string]any{"action": "run", "command": cmd},
)
result := tool.Execute(context.Background(), map[string]any{"action": "run", "command": cmd})
if !result.IsError {
t.Errorf("expected block device write to be blocked: %s", cmd)
}
@@ -519,16 +497,9 @@ func TestShellTool_SafePathsInWorkspaceRestriction(t *testing.T) {
}
for _, cmd := range commands {
result := tool.Execute(
context.Background(),
map[string]any{"action": "run", "command": cmd},
)
result := tool.Execute(context.Background(), map[string]any{"action": "run", "command": cmd})
if result.IsError && strings.Contains(result.ForLLM, "path outside working dir") {
t.Errorf(
"safe path should not be blocked by workspace check: %s\n error: %s",
cmd,
result.ForLLM,
)
t.Errorf("safe path should not be blocked by workspace check: %s\n error: %s", cmd, result.ForLLM)
}
}
}
@@ -620,10 +591,7 @@ func TestShellTool_CustomAllowPatterns(t *testing.T) {
"command": "git push origin main",
})
if result.IsError && strings.Contains(result.ForLLM, "blocked") {
t.Errorf(
"custom allow pattern should exempt 'git push origin main', got: %s",
result.ForLLM,
)
t.Errorf("custom allow pattern should exempt 'git push origin main', got: %s", result.ForLLM)
}
// "git push upstream main" should still be blocked (does not match allow pattern).
@@ -661,11 +629,7 @@ func TestShellTool_URLsNotBlocked(t *testing.T) {
result := tool.Execute(ctx, map[string]any{"action": "run", "command": cmd})
cancel()
if result.IsError && strings.Contains(result.ForLLM, "path outside working dir") {
t.Errorf(
"command with URL should not be blocked by workspace check: %s\n error: %s",
cmd,
result.ForLLM,
)
t.Errorf("command with URL should not be blocked by workspace check: %s\n error: %s", cmd, result.ForLLM)
}
}
}
@@ -688,10 +652,7 @@ func TestShellTool_FileURISandboxing(t *testing.T) {
}
for _, cmd := range blockedCommands {
result := tool.Execute(
context.Background(),
map[string]any{"action": "run", "command": cmd},
)
result := tool.Execute(context.Background(), map[string]any{"action": "run", "command": cmd})
if !result.IsError || !strings.Contains(result.ForLLM, "path outside working dir") {
t.Errorf("file:// URI outside workspace should be blocked: %s", cmd)
}
@@ -709,16 +670,9 @@ func TestShellTool_FileURISandboxing(t *testing.T) {
}
for _, cmd := range allowedCommands {
result := tool.Execute(
context.Background(),
map[string]any{"action": "run", "command": cmd},
)
result := tool.Execute(context.Background(), map[string]any{"action": "run", "command": cmd})
if result.IsError && strings.Contains(result.ForLLM, "path outside working dir") {
t.Errorf(
"file:// URI inside workspace should be allowed: %s\n error: %s",
cmd,
result.ForLLM,
)
t.Errorf("file:// URI inside workspace should be allowed: %s\n error: %s", cmd, result.ForLLM)
}
}
}
@@ -742,10 +696,7 @@ func TestShellTool_URLBypassPrevented(t *testing.T) {
}
for _, cmd := range blockedCommands {
result := tool.Execute(
context.Background(),
map[string]any{"action": "run", "command": cmd},
)
result := tool.Execute(context.Background(), map[string]any{"action": "run", "command": cmd})
if !result.IsError || !strings.Contains(result.ForLLM, "path outside working dir") {
t.Errorf("bypass attempt should be blocked: %q\n got: %s", cmd, result.ForLLM)
}
@@ -1270,9 +1221,7 @@ func TestShellTool_PTY_ProcessGroupKill(t *testing.T) {
// The binary is created in /tmp/test_pgroup.c and compiled as part of test setup.
testBinary := "/tmp/test_pgroup"
if _, err := os.Stat(testBinary); os.IsNotExist(err) {
t.Skip(
"Test binary /tmp/test_pgroup not found - run: gcc -o /tmp/test_pgroup /tmp/test_pgroup.c",
)
t.Skip("Test binary /tmp/test_pgroup not found - run: gcc -o /tmp/test_pgroup /tmp/test_pgroup.c")
}
tool, err := NewExecTool("", false)
@@ -1606,16 +1555,8 @@ func TestDetectPtyKeyMode(t *testing.T) {
{"rmkx only", "\x1b[?1l\x1b>", PtyKeyModeCSI},
{"both smkx first", "\x1b[?1h\x1b=...\x1b[?1l\x1b>", PtyKeyModeCSI},
{"both rmkx first", "\x1b[?1l\x1b>...\x1b[?1h\x1b=", PtyKeyModeSS3},
{
"multiple toggles smkx last",
"\x1b[?1h\x1b=...\x1b[?1l\x1b>...\x1b[?1h\x1b=",
PtyKeyModeSS3,
},
{
"multiple toggles rmkx last",
"\x1b[?1l\x1b>...\x1b[?1h\x1b=...\x1b[?1l\x1b>",
PtyKeyModeCSI,
},
{"multiple toggles smkx last", "\x1b[?1h\x1b=...\x1b[?1l\x1b>...\x1b[?1h\x1b=", PtyKeyModeSS3},
{"multiple toggles rmkx last", "\x1b[?1l\x1b>...\x1b[?1h\x1b=...\x1b[?1l\x1b>", PtyKeyModeCSI},
{"partial smkx", "\x1b[?1h", PtyKeyModeSS3},
{"partial rmkx", "\x1b[?1l", PtyKeyModeCSI},
}
+3 -12
View File
@@ -96,11 +96,7 @@ func (t *InstallSkillTool) Execute(ctx context.Context, args map[string]any) *To
if !force {
if _, err := os.Stat(targetDir); err == nil {
return ErrorResult(
fmt.Sprintf(
"skill %q already installed at %s. Use force=true to reinstall.",
slug,
targetDir,
),
fmt.Sprintf("skill %q already installed at %s. Use force=true to reinstall.", slug, targetDir),
)
}
} else {
@@ -146,9 +142,7 @@ func (t *InstallSkillTool) Execute(ctx context.Context, args map[string]any) *To
"error": rmErr.Error(),
})
}
return ErrorResult(
fmt.Sprintf("skill %q is flagged as malicious and cannot be installed", slug),
)
return ErrorResult(fmt.Sprintf("skill %q is flagged as malicious and cannot be installed", slug))
}
// Write origin metadata.
@@ -168,10 +162,7 @@ func (t *InstallSkillTool) Execute(ctx context.Context, args map[string]any) *To
// Build result with moderation warning if suspicious.
var output string
if result.IsSuspicious {
output = fmt.Sprintf(
"⚠️ Warning: skill %q is flagged as suspicious (may contain risky patterns).\n\n",
slug,
)
output = fmt.Sprintf("⚠️ Warning: skill %q is flagged as suspicious (may contain risky patterns).\n\n", slug)
}
output += fmt.Sprintf("Successfully installed skill %q v%s from %s registry.\nLocation: %s\n",
slug, result.Version, registry.Name(), targetDir)
+1 -4
View File
@@ -17,10 +17,7 @@ type FindSkillsTool struct {
// NewFindSkillsTool creates a new FindSkillsTool.
// registryMgr is the shared registry manager (built from config in createToolRegistry).
// cache is the search cache for deduplicating similar queries.
func NewFindSkillsTool(
registryMgr *skills.RegistryManager,
cache *skills.SearchCache,
) *FindSkillsTool {
func NewFindSkillsTool(registryMgr *skills.RegistryManager, cache *skills.SearchCache) *FindSkillsTool {
return &FindSkillsTool{
registryMgr: registryMgr,
cache: cache,
+2 -4
View File
@@ -77,12 +77,10 @@ func (t *SpawnStatusTool) Execute(ctx context.Context, args map[string]any) *Too
}
// Restrict lookup to tasks that belong to this conversation.
if callerChannel != "" && taskCopy.OriginChannel != "" &&
taskCopy.OriginChannel != callerChannel {
if callerChannel != "" && taskCopy.OriginChannel != "" && taskCopy.OriginChannel != callerChannel {
return ErrorResult(fmt.Sprintf("No subagent found with task ID: %s", taskID))
}
if callerChatID != "" && taskCopy.OriginChatID != "" &&
taskCopy.OriginChatID != callerChatID {
if callerChatID != "" && taskCopy.OriginChatID != "" && taskCopy.OriginChatID != callerChatID {
return ErrorResult(fmt.Sprintf("No subagent found with task ID: %s", taskID))
}
+2 -10
View File
@@ -195,12 +195,7 @@ func TestSpawnStatusTool_TaskID_NonString(t *testing.T) {
for _, badVal := range []any{42, 3.14, true, map[string]any{"x": 1}, []string{"a"}} {
result := tool.Execute(context.Background(), map[string]any{"task_id": badVal})
if !result.IsError {
t.Errorf(
"Expected error for task_id=%T(%v), got success: %s",
badVal,
badVal,
result.ForLLM,
)
t.Errorf("Expected error for task_id=%T(%v), got success: %s", badVal, badVal, result.ForLLM)
}
if !strings.Contains(result.ForLLM, "task_id must be a string") {
t.Errorf("Expected type-error message, got: %s", result.ForLLM)
@@ -324,10 +319,7 @@ func TestSpawnStatusTool_SortByCreatedTimestamp(t *testing.T) {
t.Fatalf("Both task IDs should appear in output:\n%s", result.ForLLM)
}
if pos2 > pos10 {
t.Errorf(
"Expected subagent-2 (created first) to appear before subagent-10, but got:\n%s",
result.ForLLM,
)
t.Errorf("Expected subagent-2 (created first) to appear before subagent-10, but got:\n%s", result.ForLLM)
}
}
+2 -6
View File
@@ -69,9 +69,7 @@ func (t *SPITool) Parameters() map[string]any {
func (t *SPITool) Execute(ctx context.Context, args map[string]any) *ToolResult {
if runtime.GOOS != "linux" {
return ErrorResult(
"SPI is only supported on Linux. This tool requires /dev/spidev* device files.",
)
return ErrorResult("SPI is only supported on Linux. This tool requires /dev/spidev* device files.")
}
action, ok := args["action"].(string)
@@ -126,9 +124,7 @@ func (t *SPITool) list() *ToolResult {
// parseSPIArgs extracts and validates common SPI parameters
//
//nolint:unused // Used by spi_linux.go
func parseSPIArgs(
args map[string]any,
) (device string, speed uint32, mode uint8, bits uint8, errMsg string) {
func parseSPIArgs(args map[string]any) (device string, speed uint32, mode uint8, bits uint8, errMsg string) {
dev, ok := args["device"].(string)
if !ok || dev == "" {
return "", 0, 0, 0, "device is required (e.g. \"2.0\" for /dev/spidev2.0)"
+6 -37
View File
@@ -38,46 +38,25 @@ type spiTransfer struct {
func configureSPI(devPath string, mode uint8, bits uint8, speed uint32) (int, *ToolResult) {
fd, err := syscall.Open(devPath, syscall.O_RDWR, 0)
if err != nil {
return -1, ErrorResult(
fmt.Sprintf(
"failed to open %s: %v (check permissions and spidev module)",
devPath,
err,
),
)
return -1, ErrorResult(fmt.Sprintf("failed to open %s: %v (check permissions and spidev module)", devPath, err))
}
// Set SPI mode
_, _, errno := syscall.Syscall(
syscall.SYS_IOCTL,
uintptr(fd),
spiIocWrMode,
uintptr(unsafe.Pointer(&mode)),
)
_, _, errno := syscall.Syscall(syscall.SYS_IOCTL, uintptr(fd), spiIocWrMode, uintptr(unsafe.Pointer(&mode)))
if errno != 0 {
syscall.Close(fd)
return -1, ErrorResult(fmt.Sprintf("failed to set SPI mode %d: %v", mode, errno))
}
// Set bits per word
_, _, errno = syscall.Syscall(
syscall.SYS_IOCTL,
uintptr(fd),
spiIocWrBitsPerWord,
uintptr(unsafe.Pointer(&bits)),
)
_, _, errno = syscall.Syscall(syscall.SYS_IOCTL, uintptr(fd), spiIocWrBitsPerWord, uintptr(unsafe.Pointer(&bits)))
if errno != 0 {
syscall.Close(fd)
return -1, ErrorResult(fmt.Sprintf("failed to set bits per word %d: %v", bits, errno))
}
// Set max speed
_, _, errno = syscall.Syscall(
syscall.SYS_IOCTL,
uintptr(fd),
spiIocWrMaxSpeedHz,
uintptr(unsafe.Pointer(&speed)),
)
_, _, errno = syscall.Syscall(syscall.SYS_IOCTL, uintptr(fd), spiIocWrMaxSpeedHz, uintptr(unsafe.Pointer(&speed)))
if errno != 0 {
syscall.Close(fd)
return -1, ErrorResult(fmt.Sprintf("failed to set SPI speed %d Hz: %v", speed, errno))
@@ -138,12 +117,7 @@ func (t *SPITool) transfer(args map[string]any) *ToolResult {
bitsPerWord: bits,
}
_, _, errno := syscall.Syscall(
syscall.SYS_IOCTL,
uintptr(fd),
spiIocMessage1,
uintptr(unsafe.Pointer(&xfer)),
)
_, _, errno := syscall.Syscall(syscall.SYS_IOCTL, uintptr(fd), spiIocMessage1, uintptr(unsafe.Pointer(&xfer)))
runtime.KeepAlive(txBuf)
runtime.KeepAlive(rxBuf)
if errno != 0 {
@@ -200,12 +174,7 @@ func (t *SPITool) readDevice(args map[string]any) *ToolResult {
bitsPerWord: bits,
}
_, _, errno := syscall.Syscall(
syscall.SYS_IOCTL,
uintptr(fd),
spiIocMessage1,
uintptr(unsafe.Pointer(&xfer)),
)
_, _, errno := syscall.Syscall(syscall.SYS_IOCTL, uintptr(fd), spiIocMessage1, uintptr(unsafe.Pointer(&xfer)))
runtime.KeepAlive(txBuf)
runtime.KeepAlive(rxBuf)
if errno != 0 {
+1 -5
View File
@@ -316,11 +316,7 @@ func TestSubagentTool_ForUserTruncation(t *testing.T) {
// ForUser should be truncated to 500 chars + "..."
maxUserLen := 500
if len(result.ForUser) > maxUserLen+3 { // +3 for "..."
t.Errorf(
"ForUser should be truncated to ~%d chars, got: %d",
maxUserLen,
len(result.ForUser),
)
t.Errorf("ForUser should be truncated to ~%d chars, got: %d", maxUserLen, len(result.ForUser))
}
// ForLLM should have full content
+2 -15
View File
@@ -64,13 +64,7 @@ func RunToolLoop(
llmOpts = map[string]any{}
}
// 3. Call LLM
response, err := config.Provider.Chat(
ctx,
messages,
providerToolDefs,
config.Model,
llmOpts,
)
response, err := config.Provider.Chat(ctx, messages, providerToolDefs, config.Model, llmOpts)
if err != nil {
logger.ErrorCF("toolloop", "LLM call failed",
map[string]any{
@@ -154,14 +148,7 @@ func RunToolLoop(
var toolResult *ToolResult
if config.Tools != nil {
toolResult = config.Tools.ExecuteWithContext(
ctx,
tc.Name,
tc.Arguments,
channel,
chatID,
nil,
)
toolResult = config.Tools.ExecuteWithContext(ctx, tc.Name, tc.Arguments, channel, chatID, nil)
} else {
toolResult = ErrorResult("No tools available")
}
+5 -21
View File
@@ -151,10 +151,7 @@ func TestValidateToolArgs(t *testing.T) {
schema: map[string]any{
"type": "object",
"properties": map[string]any{
"color": map[string]any{
"type": "string",
"enum": []any{"red", "green", "blue"},
},
"color": map[string]any{"type": "string", "enum": []any{"red", "green", "blue"}},
},
},
args: map[string]any{"color": "red"},
@@ -164,10 +161,7 @@ func TestValidateToolArgs(t *testing.T) {
schema: map[string]any{
"type": "object",
"properties": map[string]any{
"color": map[string]any{
"type": "string",
"enum": []any{"red", "green", "blue"},
},
"color": map[string]any{"type": "string", "enum": []any{"red", "green", "blue"}},
},
},
args: map[string]any{"color": "yellow"},
@@ -178,10 +172,7 @@ func TestValidateToolArgs(t *testing.T) {
schema: map[string]any{
"type": "object",
"properties": map[string]any{
"color": map[string]any{
"type": "string",
"enum": []string{"red", "green", "blue"},
},
"color": map[string]any{"type": "string", "enum": []string{"red", "green", "blue"}},
},
},
args: map[string]any{"color": "green"},
@@ -191,10 +182,7 @@ func TestValidateToolArgs(t *testing.T) {
schema: map[string]any{
"type": "object",
"properties": map[string]any{
"color": map[string]any{
"type": "string",
"enum": []string{"red", "green", "blue"},
},
"color": map[string]any{"type": "string", "enum": []string{"red", "green", "blue"}},
},
},
args: map[string]any{"color": "yellow"},
@@ -354,11 +342,7 @@ func TestValidateToolArgs_RegistryIntegration(t *testing.T) {
}
// Extra property — should fail with validation error
result = r.Execute(
context.Background(),
"read_file",
map[string]any{"path": "/x", "__inject": true},
)
result = r.Execute(context.Background(), "read_file", map[string]any{"path": "/x", "__inject": true})
if !result.IsError {
t.Error("expected validation error for extra property")
}
+32 -130
View File
@@ -54,8 +54,7 @@ func TestWebTool_WebFetch_Success(t *testing.T) {
}
// ForUser should contain summary
if !strings.Contains(result.ForUser, "bytes") &&
!strings.Contains(result.ForUser, "extractor") {
if !strings.Contains(result.ForUser, "bytes") && !strings.Contains(result.ForUser, "extractor") {
t.Errorf("Expected ForUser to contain summary, got: %s", result.ForUser)
}
}
@@ -76,11 +75,7 @@ func TestWebTool_WebFetch_JSON(t *testing.T) {
tool, err := NewWebFetchTool(50000, format, testFetchLimit)
if err != nil {
logger.ErrorCF(
"agent",
"Failed to create web fetch tool",
map[string]any{"error": err.Error()},
)
logger.ErrorCF("agent", "Failed to create web fetch tool", map[string]any{"error": err.Error()})
}
ctx := context.Background()
@@ -105,11 +100,7 @@ func TestWebTool_WebFetch_JSON(t *testing.T) {
func TestWebTool_WebFetch_InvalidURL(t *testing.T) {
tool, err := NewWebFetchTool(50000, format, testFetchLimit)
if err != nil {
logger.ErrorCF(
"agent",
"Failed to create web fetch tool",
map[string]any{"error": err.Error()},
)
logger.ErrorCF("agent", "Failed to create web fetch tool", map[string]any{"error": err.Error()})
}
ctx := context.Background()
@@ -134,11 +125,7 @@ func TestWebTool_WebFetch_InvalidURL(t *testing.T) {
func TestWebTool_WebFetch_UnsupportedScheme(t *testing.T) {
tool, err := NewWebFetchTool(50000, format, testFetchLimit)
if err != nil {
logger.ErrorCF(
"agent",
"Failed to create web fetch tool",
map[string]any{"error": err.Error()},
)
logger.ErrorCF("agent", "Failed to create web fetch tool", map[string]any{"error": err.Error()})
}
ctx := context.Background()
@@ -154,8 +141,7 @@ func TestWebTool_WebFetch_UnsupportedScheme(t *testing.T) {
}
// Should mention only http/https allowed
if !strings.Contains(result.ForLLM, "http/https") &&
!strings.Contains(result.ForUser, "http/https") {
if !strings.Contains(result.ForLLM, "http/https") && !strings.Contains(result.ForUser, "http/https") {
t.Errorf("Expected scheme error message, got ForLLM: %s", result.ForLLM)
}
}
@@ -164,11 +150,7 @@ func TestWebTool_WebFetch_UnsupportedScheme(t *testing.T) {
func TestWebTool_WebFetch_MissingURL(t *testing.T) {
tool, err := NewWebFetchTool(50000, format, testFetchLimit)
if err != nil {
logger.ErrorCF(
"agent",
"Failed to create web fetch tool",
map[string]any{"error": err.Error()},
)
logger.ErrorCF("agent", "Failed to create web fetch tool", map[string]any{"error": err.Error()})
}
ctx := context.Background()
@@ -182,8 +164,7 @@ func TestWebTool_WebFetch_MissingURL(t *testing.T) {
}
// Should mention URL is required
if !strings.Contains(result.ForLLM, "url is required") &&
!strings.Contains(result.ForUser, "url is required") {
if !strings.Contains(result.ForLLM, "url is required") && !strings.Contains(result.ForUser, "url is required") {
t.Errorf("Expected 'url is required' message, got ForLLM: %s", result.ForLLM)
}
}
@@ -203,11 +184,7 @@ func TestWebTool_WebFetch_Truncation(t *testing.T) {
tool, err := NewWebFetchTool(1000, format, testFetchLimit) // Limit to 1000 chars
if err != nil {
logger.ErrorCF(
"agent",
"Failed to create web fetch tool",
map[string]any{"error": err.Error()},
)
logger.ErrorCF("agent", "Failed to create web fetch tool", map[string]any{"error": err.Error()})
}
ctx := context.Background()
@@ -239,10 +216,7 @@ func TestWebTool_WebFetch_Truncation(t *testing.T) {
// Text should end with the truncation notice
if text, ok := resultMap["text"].(string); ok {
if !strings.HasSuffix(text, "[Content truncated due to size limit]") {
t.Errorf(
"Expected text to end with truncation notice, got: %q",
text[max(0, len(text)-60):],
)
t.Errorf("Expected text to end with truncation notice, got: %q", text[max(0, len(text)-60):])
}
}
}
@@ -289,13 +263,11 @@ func TestWebTool_WebFetch_TruncationNotice(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
server := httptest.NewServer(
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", tt.contentType)
w.WriteHeader(http.StatusOK)
w.Write([]byte(tt.body))
}),
)
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", tt.contentType)
w.WriteHeader(http.StatusOK)
w.Write([]byte(tt.body))
}))
defer server.Close()
tool, err := NewWebFetchTool(maxChars, tt.format, testFetchLimit)
@@ -319,11 +291,7 @@ func TestWebTool_WebFetch_TruncationNotice(t *testing.T) {
}
if !strings.HasSuffix(text, truncationNotice) {
t.Errorf(
"expected text to end with %q, got suffix: %q",
truncationNotice,
text[max(0, len(text)-60):],
)
t.Errorf("expected text to end with %q, got suffix: %q", truncationNotice, text[max(0, len(text)-60):])
}
if truncated, ok := resultMap["truncated"].(bool); !ok || !truncated {
@@ -392,11 +360,7 @@ func TestWebFetchTool_PayloadTooLarge(t *testing.T) {
// Initialize the tool
tool, err := NewWebFetchTool(50000, format, testFetchLimit)
if err != nil {
logger.ErrorCF(
"agent",
"Failed to create web fetch tool",
map[string]any{"error": err.Error()},
)
logger.ErrorCF("agent", "Failed to create web fetch tool", map[string]any{"error": err.Error()})
}
// Prepare the arguments pointing to the URL of our local mock server
@@ -416,8 +380,7 @@ func TestWebFetchTool_PayloadTooLarge(t *testing.T) {
// Search for the exact error string we set earlier in the Execute method
expectedErrorMsg := fmt.Sprintf("size exceeded %d bytes limit", testFetchLimit)
if !strings.Contains(result.ForLLM, expectedErrorMsg) &&
!strings.Contains(result.ForUser, expectedErrorMsg) {
if !strings.Contains(result.ForLLM, expectedErrorMsg) && !strings.Contains(result.ForUser, expectedErrorMsg) {
t.Errorf("test failed: expected error %q, but got: %+v", expectedErrorMsg, result)
}
}
@@ -570,11 +533,7 @@ func TestWebTool_WebFetch_HTMLExtraction(t *testing.T) {
tool, err := NewWebFetchTool(50000, format, testFetchLimit)
if err != nil {
logger.ErrorCF(
"agent",
"Failed to create web fetch tool",
map[string]any{"error": err.Error()},
)
logger.ErrorCF("agent", "Failed to create web fetch tool", map[string]any{"error": err.Error()})
}
ctx := context.Background()
@@ -759,13 +718,7 @@ func TestWebTool_WebFetch_PrivateHostAllowedByCIDRWhitelist(t *testing.T) {
defer server.Close()
host, _ := serverHostAndPort(t, server.URL)
tool, err := NewWebFetchToolWithConfig(
50000,
"",
format,
testFetchLimit,
[]string{singleHostCIDR(t, host)},
)
tool, err := NewWebFetchToolWithConfig(50000, "", format, testFetchLimit, []string{singleHostCIDR(t, host)})
if err != nil {
t.Fatalf("Failed to create web fetch tool: %v", err)
}
@@ -800,10 +753,7 @@ func TestWebTool_WebFetch_PrivateHostAllowedForTests(t *testing.T) {
})
if result.IsError {
t.Errorf(
"expected success when private host access is allowed in tests, got %q",
result.ForLLM,
)
t.Errorf("expected success when private host access is allowed in tests, got %q", result.ForLLM)
}
}
@@ -1023,11 +973,7 @@ func TestIsPrivateOrRestrictedIP_Table(t *testing.T) {
func TestWebTool_WebFetch_MissingDomain(t *testing.T) {
tool, err := NewWebFetchTool(50000, format, testFetchLimit)
if err != nil {
logger.ErrorCF(
"agent",
"Failed to create web fetch tool",
map[string]any{"error": err.Error()},
)
logger.ErrorCF("agent", "Failed to create web fetch tool", map[string]any{"error": err.Error()})
}
ctx := context.Background()
@@ -1049,19 +995,9 @@ func TestWebTool_WebFetch_MissingDomain(t *testing.T) {
}
func TestNewWebFetchToolWithProxy(t *testing.T) {
tool, err := NewWebFetchToolWithProxy(
1024,
"http://127.0.0.1:7890",
format,
testFetchLimit,
nil,
)
tool, err := NewWebFetchToolWithProxy(1024, "http://127.0.0.1:7890", format, testFetchLimit, nil)
if err != nil {
logger.ErrorCF(
"agent",
"Failed to create web fetch tool",
map[string]any{"error": err.Error()},
)
logger.ErrorCF("agent", "Failed to create web fetch tool", map[string]any{"error": err.Error()})
} else if tool.maxChars != 1024 {
t.Fatalf("maxChars = %d, want %d", tool.maxChars, 1024)
}
@@ -1072,11 +1008,7 @@ func TestNewWebFetchToolWithProxy(t *testing.T) {
tool, err = NewWebFetchToolWithProxy(0, "http://127.0.0.1:7890", format, testFetchLimit, nil)
if err != nil {
logger.ErrorCF(
"agent",
"Failed to create web fetch tool",
map[string]any{"error": err.Error()},
)
logger.ErrorCF("agent", "Failed to create web fetch tool", map[string]any{"error": err.Error()})
}
if tool.maxChars != 50000 {
@@ -1085,13 +1017,7 @@ func TestNewWebFetchToolWithProxy(t *testing.T) {
}
func TestNewWebFetchToolWithConfig_InvalidPrivateHostWhitelist(t *testing.T) {
_, err := NewWebFetchToolWithConfig(
1024,
"",
format,
testFetchLimit,
[]string{"not-an-ip-or-cidr"},
)
_, err := NewWebFetchToolWithConfig(1024, "", format, testFetchLimit, []string{"not-an-ip-or-cidr"})
if err == nil {
t.Fatal("expected invalid whitelist entry to fail")
}
@@ -1247,11 +1173,7 @@ func TestWebTool_TavilySearch_RangeMapping(t *testing.T) {
w.WriteHeader(http.StatusOK)
json.NewEncoder(w).Encode(map[string]any{
"results": []map[string]any{
{
"title": "Recent result",
"url": "https://example.com/recent",
"content": "snippet",
},
{"title": "Recent result", "url": "https://example.com/recent", "content": "snippet"},
},
})
}))
@@ -1381,10 +1303,7 @@ func TestWebFetchTool_CloudflareChallenge_RetryFailsToo(t *testing.T) {
// Should not be an error — the retry response is used as-is (403 is a valid HTTP response)
if result.IsError {
t.Fatalf(
"expected non-error result even when retry is also blocked, got: %s",
result.ForLLM,
)
t.Fatalf("expected non-error result even when retry is also blocked, got: %s", result.ForLLM)
}
// Status in the JSON result should reflect the 403
if !strings.Contains(result.ForLLM, "403") {
@@ -1549,10 +1468,7 @@ func TestWebTool_GLMSearch_Success(t *testing.T) {
t.Errorf("Expected Content-Type application/json, got %s", r.Header.Get("Content-Type"))
}
if r.Header.Get("Authorization") != "Bearer test-glm-key" {
t.Errorf(
"Expected Authorization Bearer test-glm-key, got %s",
r.Header.Get("Authorization"),
)
t.Errorf("Expected Authorization Bearer test-glm-key, got %s", r.Header.Get("Authorization"))
}
var payload map[string]any
@@ -1618,21 +1534,14 @@ func TestWebTool_GLMSearch_RangeMapping(t *testing.T) {
t.Fatalf("failed to decode payload: %v", err)
}
if payload["search_recency_filter"] != "oneMonth" {
t.Fatalf(
"expected search_recency_filter=oneMonth, got %v",
payload["search_recency_filter"],
)
t.Fatalf("expected search_recency_filter=oneMonth, got %v", payload["search_recency_filter"])
}
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
json.NewEncoder(w).Encode(map[string]any{
"search_result": []map[string]any{
{
"title": "Recent GLM Result",
"content": "snippet",
"link": "https://example.com/glm-range",
},
{"title": "Recent GLM Result", "content": "snippet", "link": "https://example.com/glm-range"},
},
})
}))
@@ -1664,21 +1573,14 @@ func TestWebTool_BaiduSearch_RangeMapping(t *testing.T) {
t.Fatalf("failed to decode payload: %v", err)
}
if payload["search_recency_filter"] != "week" {
t.Fatalf(
"expected search_recency_filter=week for day fallback, got %v",
payload["search_recency_filter"],
)
t.Fatalf("expected search_recency_filter=week for day fallback, got %v", payload["search_recency_filter"])
}
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
json.NewEncoder(w).Encode(map[string]any{
"references": []map[string]any{
{
"title": "Recent Baidu Result",
"url": "https://example.com/baidu",
"content": "snippet",
},
{"title": "Recent Baidu Result", "url": "https://example.com/baidu", "content": "snippet"},
},
})
}))