mirror of
https://github.com/sipeed/picoclaw.git
synced 2026-06-12 18:08:54 +00:00
chore: revert unrelated golines formatting
This commit is contained in:
@@ -500,11 +500,8 @@ func TestEstimateMessageTokens_ReasoningContent(t *testing.T) {
|
||||
reasoningTokens := estimateMessageTokens(withReasoning)
|
||||
|
||||
if reasoningTokens <= plainTokens {
|
||||
t.Errorf(
|
||||
"message with ReasoningContent (%d tokens) should exceed plain message (%d tokens)",
|
||||
reasoningTokens,
|
||||
plainTokens,
|
||||
)
|
||||
t.Errorf("message with ReasoningContent (%d tokens) should exceed plain message (%d tokens)",
|
||||
reasoningTokens, plainTokens)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -767,11 +764,7 @@ func TestEstimateMessageTokens_WithReasoningAndMedia(t *testing.T) {
|
||||
tokensNoReasoning := estimateMessageTokens(msgNoReasoning)
|
||||
|
||||
if tokens <= tokensNoReasoning {
|
||||
t.Errorf(
|
||||
"reasoning content should add tokens: with=%d, without=%d",
|
||||
tokens,
|
||||
tokensNoReasoning,
|
||||
)
|
||||
t.Errorf("reasoning content should add tokens: with=%d, without=%d", tokens, tokensNoReasoning)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -82,16 +82,7 @@ func TestSingleSystemMessage(t *testing.T) {
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
msgs := cb.BuildMessages(
|
||||
tt.history,
|
||||
tt.summary,
|
||||
tt.message,
|
||||
nil,
|
||||
"test",
|
||||
"chat1",
|
||||
"",
|
||||
"",
|
||||
)
|
||||
msgs := cb.BuildMessages(tt.history, tt.summary, tt.message, nil, "test", "chat1", "", "")
|
||||
|
||||
systemCount := 0
|
||||
for _, m := range msgs {
|
||||
@@ -177,16 +168,7 @@ func TestBuildMessages_CurrentSenderDynamicContext(t *testing.T) {
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
msgs := cb.BuildMessages(
|
||||
nil,
|
||||
"",
|
||||
"hello",
|
||||
nil,
|
||||
"discord",
|
||||
"chat1",
|
||||
tt.senderID,
|
||||
tt.senderDisplayName,
|
||||
)
|
||||
msgs := cb.BuildMessages(nil, "", "hello", nil, "discord", "chat1", tt.senderID, tt.senderDisplayName)
|
||||
sys := msgs[0].Content
|
||||
|
||||
if tt.wantSection {
|
||||
@@ -400,10 +382,7 @@ func TestNewFileCreationInvalidatesCache(t *testing.T) {
|
||||
// Cache should auto-invalidate because file went from absent -> present
|
||||
sp2 := cb.BuildSystemPromptWithCache()
|
||||
if !strings.Contains(sp2, tt.checkField) {
|
||||
t.Errorf(
|
||||
"cache not invalidated on new file creation: expected %q in prompt",
|
||||
tt.checkField,
|
||||
)
|
||||
t.Errorf("cache not invalidated on new file creation: expected %q in prompt", tt.checkField)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
@@ -151,19 +151,7 @@ func TestSanitizeHistoryForProvider_MultiToolCallsThenNewRound(t *testing.T) {
|
||||
if len(result) != 9 {
|
||||
t.Fatalf("expected 9 messages, got %d: %+v", len(result), roles(result))
|
||||
}
|
||||
assertRoles(
|
||||
t,
|
||||
result,
|
||||
"user",
|
||||
"assistant",
|
||||
"tool",
|
||||
"tool",
|
||||
"assistant",
|
||||
"user",
|
||||
"assistant",
|
||||
"tool",
|
||||
"assistant",
|
||||
)
|
||||
assertRoles(t, result, "user", "assistant", "tool", "tool", "assistant", "user", "assistant", "tool", "assistant")
|
||||
}
|
||||
|
||||
func TestSanitizeHistoryForProvider_ConsecutiveMultiToolRounds(t *testing.T) {
|
||||
@@ -182,18 +170,7 @@ func TestSanitizeHistoryForProvider_ConsecutiveMultiToolRounds(t *testing.T) {
|
||||
if len(result) != 8 {
|
||||
t.Fatalf("expected 8 messages, got %d: %+v", len(result), roles(result))
|
||||
}
|
||||
assertRoles(
|
||||
t,
|
||||
result,
|
||||
"user",
|
||||
"assistant",
|
||||
"tool",
|
||||
"tool",
|
||||
"assistant",
|
||||
"tool",
|
||||
"tool",
|
||||
"assistant",
|
||||
)
|
||||
assertRoles(t, result, "user", "assistant", "tool", "tool", "assistant", "tool", "tool", "assistant")
|
||||
}
|
||||
|
||||
func TestSanitizeHistoryForProvider_PlainConversation(t *testing.T) {
|
||||
@@ -327,17 +304,5 @@ func TestSanitizeHistoryForProvider_PartialToolResultsInMiddle(t *testing.T) {
|
||||
if len(result) != 9 {
|
||||
t.Fatalf("expected 9 messages, got %d: %+v", len(result), roles(result))
|
||||
}
|
||||
assertRoles(
|
||||
t,
|
||||
result,
|
||||
"user",
|
||||
"assistant",
|
||||
"tool",
|
||||
"assistant",
|
||||
"user",
|
||||
"user",
|
||||
"assistant",
|
||||
"tool",
|
||||
"assistant",
|
||||
)
|
||||
assertRoles(t, result, "user", "assistant", "tool", "assistant", "user", "user", "assistant", "tool", "assistant")
|
||||
}
|
||||
|
||||
@@ -61,12 +61,8 @@ Act directly and use tools first.
|
||||
if len(definition.Agent.Frontmatter.Skills) != 2 {
|
||||
t.Fatalf("expected skills to be parsed, got %v", definition.Agent.Frontmatter.Skills)
|
||||
}
|
||||
if len(definition.Agent.Frontmatter.MCPServers) != 1 ||
|
||||
definition.Agent.Frontmatter.MCPServers[0] != "github" {
|
||||
t.Fatalf(
|
||||
"expected mcpServers to be parsed, got %v",
|
||||
definition.Agent.Frontmatter.MCPServers,
|
||||
)
|
||||
if len(definition.Agent.Frontmatter.MCPServers) != 1 || definition.Agent.Frontmatter.MCPServers[0] != "github" {
|
||||
t.Fatalf("expected mcpServers to be parsed, got %v", definition.Agent.Frontmatter.MCPServers)
|
||||
}
|
||||
if definition.Agent.Frontmatter.Fields["metadata"] == nil {
|
||||
t.Fatal("expected arbitrary frontmatter fields to remain available")
|
||||
@@ -100,10 +96,7 @@ func TestLoadAgentDefinitionFallsBackToLegacyAgentsMarkdown(t *testing.T) {
|
||||
t.Fatal("expected AGENTS.md to be loaded")
|
||||
}
|
||||
if definition.Agent.RawFrontmatter != "" {
|
||||
t.Fatalf(
|
||||
"legacy AGENTS.md should not have frontmatter, got %q",
|
||||
definition.Agent.RawFrontmatter,
|
||||
)
|
||||
t.Fatalf("legacy AGENTS.md should not have frontmatter, got %q", definition.Agent.RawFrontmatter)
|
||||
}
|
||||
if !strings.Contains(definition.Agent.Body, "Keep compatibility") {
|
||||
t.Fatalf("expected legacy body to be preserved, got %q", definition.Agent.Body)
|
||||
@@ -166,10 +159,7 @@ Keep going.
|
||||
len(definition.Agent.Frontmatter.Skills) != 0 ||
|
||||
len(definition.Agent.Frontmatter.MCPServers) != 0 ||
|
||||
len(definition.Agent.Frontmatter.Fields) != 0 {
|
||||
t.Fatalf(
|
||||
"expected invalid frontmatter to decode as empty struct, got %+v",
|
||||
definition.Agent.Frontmatter,
|
||||
)
|
||||
t.Fatalf("expected invalid frontmatter to decode as empty struct, got %+v", definition.Agent.Frontmatter)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -275,13 +275,7 @@ func TestAgentLoop_EmitsSteeringAndSkippedToolEvents(t *testing.T) {
|
||||
|
||||
resultCh := make(chan string, 1)
|
||||
go func() {
|
||||
resp, _ := al.ProcessDirectWithChannel(
|
||||
context.Background(),
|
||||
"do something",
|
||||
"test-session",
|
||||
"test",
|
||||
"chat1",
|
||||
)
|
||||
resp, _ := al.ProcessDirectWithChannel(context.Background(), "do something", "test-session", "test", "chat1")
|
||||
resultCh <- resp
|
||||
}()
|
||||
|
||||
@@ -344,11 +338,7 @@ func TestAgentLoop_EmitsSteeringAndSkippedToolEvents(t *testing.T) {
|
||||
t.Fatalf("expected steering interrupt kind, got %q", interruptPayload.Kind)
|
||||
}
|
||||
if interruptPayload.ContentLen != len("change course") {
|
||||
t.Fatalf(
|
||||
"expected interrupt content len %d, got %d",
|
||||
len("change course"),
|
||||
interruptPayload.ContentLen,
|
||||
)
|
||||
t.Fatalf("expected interrupt content len %d, got %d", len("change course"), interruptPayload.ContentLen)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -370,9 +360,7 @@ func TestAgentLoop_EmitsContextCompressEventOnRetry(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
contextErr := stringError(
|
||||
"InvalidParameter: Total tokens of image and text exceed max message tokens",
|
||||
)
|
||||
contextErr := stringError("InvalidParameter: Total tokens of image and text exceed max message tokens")
|
||||
provider := &failFirstMockProvider{
|
||||
failures: 1,
|
||||
failError: contextErr,
|
||||
@@ -615,12 +603,7 @@ func collectEventStream(ch <-chan Event) []Event {
|
||||
}
|
||||
}
|
||||
|
||||
func waitForEvent(
|
||||
t *testing.T,
|
||||
ch <-chan Event,
|
||||
timeout time.Duration,
|
||||
match func(Event) bool,
|
||||
) Event {
|
||||
func waitForEvent(t *testing.T, ch <-chan Event, timeout time.Duration, match func(Event) bool) Event {
|
||||
t.Helper()
|
||||
|
||||
timer := time.NewTimer(timeout)
|
||||
|
||||
@@ -40,11 +40,7 @@ func (h *builtinAutoHook) AfterLLM(
|
||||
return next, HookDecision{Action: HookActionModify}, nil
|
||||
}
|
||||
|
||||
func newConfiguredHookLoop(
|
||||
t *testing.T,
|
||||
provider *llmHookTestProvider,
|
||||
hooks config.HooksConfig,
|
||||
) *AgentLoop {
|
||||
func newConfiguredHookLoop(t *testing.T, provider *llmHookTestProvider, hooks config.HooksConfig) *AgentLoop {
|
||||
t.Helper()
|
||||
|
||||
cfg := &config.Config{
|
||||
@@ -106,13 +102,7 @@ func TestAgentLoop_ProcessDirectWithChannel_AutoMountsBuiltinHook(t *testing.T)
|
||||
})
|
||||
defer al.Close()
|
||||
|
||||
resp, err := al.ProcessDirectWithChannel(
|
||||
context.Background(),
|
||||
"hello",
|
||||
"session-1",
|
||||
"cli",
|
||||
"direct",
|
||||
)
|
||||
resp, err := al.ProcessDirectWithChannel(context.Background(), "hello", "session-1", "cli", "direct")
|
||||
if err != nil {
|
||||
t.Fatalf("ProcessDirectWithChannel failed: %v", err)
|
||||
}
|
||||
@@ -150,13 +140,7 @@ func TestAgentLoop_ProcessDirectWithChannel_AutoMountsProcessHook(t *testing.T)
|
||||
})
|
||||
defer al.Close()
|
||||
|
||||
resp, err := al.ProcessDirectWithChannel(
|
||||
context.Background(),
|
||||
"hello",
|
||||
"session-1",
|
||||
"cli",
|
||||
"direct",
|
||||
)
|
||||
resp, err := al.ProcessDirectWithChannel(context.Background(), "hello", "session-1", "cli", "direct")
|
||||
if err != nil {
|
||||
t.Fatalf("ProcessDirectWithChannel failed: %v", err)
|
||||
}
|
||||
@@ -188,13 +172,7 @@ func TestAgentLoop_ProcessDirectWithChannel_InvalidConfiguredHookFails(t *testin
|
||||
})
|
||||
defer al.Close()
|
||||
|
||||
_, err := al.ProcessDirectWithChannel(
|
||||
context.Background(),
|
||||
"hello",
|
||||
"session-1",
|
||||
"cli",
|
||||
"direct",
|
||||
)
|
||||
_, err := al.ProcessDirectWithChannel(context.Background(), "hello", "session-1", "cli", "direct")
|
||||
if err == nil {
|
||||
t.Fatal("expected invalid configured hook error")
|
||||
}
|
||||
|
||||
@@ -98,11 +98,7 @@ type processHookAfterToolResponse struct {
|
||||
Result *ToolResultHookResponse `json:"result,omitempty"`
|
||||
}
|
||||
|
||||
func NewProcessHook(
|
||||
ctx context.Context,
|
||||
name string,
|
||||
opts ProcessHookOptions,
|
||||
) (*ProcessHook, error) {
|
||||
func NewProcessHook(ctx context.Context, name string, opts ProcessHookOptions) (*ProcessHook, error) {
|
||||
if len(opts.Command) == 0 {
|
||||
return nil, fmt.Errorf("process hook command is required")
|
||||
}
|
||||
@@ -266,10 +262,7 @@ func (ph *ProcessHook) AfterTool(
|
||||
return resp.Result, HookDecision{Action: resp.Action, Reason: resp.Reason}, nil
|
||||
}
|
||||
|
||||
func (ph *ProcessHook) ApproveTool(
|
||||
ctx context.Context,
|
||||
req *ToolApprovalRequest,
|
||||
) (ApprovalDecision, error) {
|
||||
func (ph *ProcessHook) ApproveTool(ctx context.Context, req *ToolApprovalRequest) (ApprovalDecision, error) {
|
||||
if ph == nil || !ph.opts.ApproveTool {
|
||||
return ApprovalDecision{Approved: true}, nil
|
||||
}
|
||||
@@ -480,11 +473,7 @@ func (ph *ProcessHook) removePending(id uint64) {
|
||||
}
|
||||
}
|
||||
|
||||
func (al *AgentLoop) MountProcessHook(
|
||||
ctx context.Context,
|
||||
name string,
|
||||
opts ProcessHookOptions,
|
||||
) error {
|
||||
func (al *AgentLoop) MountProcessHook(ctx context.Context, name string, opts ProcessHookOptions) error {
|
||||
if al == nil {
|
||||
return fmt.Errorf("agent loop is nil")
|
||||
}
|
||||
|
||||
+4
-16
@@ -79,14 +79,8 @@ type LLMInterceptor interface {
|
||||
}
|
||||
|
||||
type ToolInterceptor interface {
|
||||
BeforeTool(
|
||||
ctx context.Context,
|
||||
call *ToolCallHookRequest,
|
||||
) (*ToolCallHookRequest, HookDecision, error)
|
||||
AfterTool(
|
||||
ctx context.Context,
|
||||
result *ToolResultHookResponse,
|
||||
) (*ToolResultHookResponse, HookDecision, error)
|
||||
BeforeTool(ctx context.Context, call *ToolCallHookRequest) (*ToolCallHookRequest, HookDecision, error)
|
||||
AfterTool(ctx context.Context, result *ToolResultHookResponse) (*ToolResultHookResponse, HookDecision, error)
|
||||
}
|
||||
|
||||
type ToolApprover interface {
|
||||
@@ -301,10 +295,7 @@ func (hm *HookManager) dispatchEvents() {
|
||||
}
|
||||
}
|
||||
|
||||
func (hm *HookManager) BeforeLLM(
|
||||
ctx context.Context,
|
||||
req *LLMHookRequest,
|
||||
) (*LLMHookRequest, HookDecision) {
|
||||
func (hm *HookManager) BeforeLLM(ctx context.Context, req *LLMHookRequest) (*LLMHookRequest, HookDecision) {
|
||||
if hm == nil || req == nil {
|
||||
return req, HookDecision{Action: HookActionContinue}
|
||||
}
|
||||
@@ -335,10 +326,7 @@ func (hm *HookManager) BeforeLLM(
|
||||
return current, HookDecision{Action: HookActionContinue}
|
||||
}
|
||||
|
||||
func (hm *HookManager) AfterLLM(
|
||||
ctx context.Context,
|
||||
resp *LLMHookResponse,
|
||||
) (*LLMHookResponse, HookDecision) {
|
||||
func (hm *HookManager) AfterLLM(ctx context.Context, resp *LLMHookResponse) (*LLMHookResponse, HookDecision) {
|
||||
if hm == nil || resp == nil {
|
||||
return resp, HookDecision{Action: HookActionContinue}
|
||||
}
|
||||
|
||||
@@ -293,10 +293,7 @@ func TestAgentLoop_Hooks_ToolInterceptorCanRewrite(t *testing.T) {
|
||||
|
||||
type denyApprovalHook struct{}
|
||||
|
||||
func (h *denyApprovalHook) ApproveTool(
|
||||
ctx context.Context,
|
||||
req *ToolApprovalRequest,
|
||||
) (ApprovalDecision, error) {
|
||||
func (h *denyApprovalHook) ApproveTool(ctx context.Context, req *ToolApprovalRequest) (ApprovalDecision, error) {
|
||||
return ApprovalDecision{
|
||||
Approved: false,
|
||||
Reason: "blocked",
|
||||
|
||||
@@ -156,11 +156,7 @@ func TestNewAgentInstance_ResolveCandidatesFromModelListAlias(t *testing.T) {
|
||||
t.Fatalf("len(Candidates) = %d, want 1", len(agent.Candidates))
|
||||
}
|
||||
if agent.Candidates[0].Provider != tt.wantProvider {
|
||||
t.Fatalf(
|
||||
"candidate provider = %q, want %q",
|
||||
agent.Candidates[0].Provider,
|
||||
tt.wantProvider,
|
||||
)
|
||||
t.Fatalf("candidate provider = %q, want %q", agent.Candidates[0].Provider, tt.wantProvider)
|
||||
}
|
||||
if agent.Candidates[0].Model != tt.wantModel {
|
||||
t.Fatalf("candidate model = %q, want %q", agent.Candidates[0].Model, tt.wantModel)
|
||||
|
||||
+28
-107
@@ -192,11 +192,7 @@ func registerSharedTools(
|
||||
Proxy: cfg.Tools.Web.Proxy,
|
||||
})
|
||||
if err != nil {
|
||||
logger.ErrorCF(
|
||||
"agent",
|
||||
"Failed to create web search tool",
|
||||
map[string]any{"error": err.Error()},
|
||||
)
|
||||
logger.ErrorCF("agent", "Failed to create web search tool", map[string]any{"error": err.Error()})
|
||||
} else if searchTool != nil {
|
||||
agent.Tools.Register(searchTool)
|
||||
}
|
||||
@@ -209,11 +205,7 @@ func registerSharedTools(
|
||||
cfg.Tools.Web.FetchLimitBytes,
|
||||
cfg.Tools.Web.PrivateHostWhitelist)
|
||||
if err != nil {
|
||||
logger.ErrorCF(
|
||||
"agent",
|
||||
"Failed to create web fetch tool",
|
||||
map[string]any{"error": err.Error()},
|
||||
)
|
||||
logger.ErrorCF("agent", "Failed to create web fetch tool", map[string]any{"error": err.Error()})
|
||||
} else {
|
||||
agent.Tools.Register(fetchTool)
|
||||
}
|
||||
@@ -483,12 +475,7 @@ func (al *AgentLoop) Run(ctx context.Context) error {
|
||||
"queue_depth": al.pendingSteeringCountForScope(target.SessionKey),
|
||||
})
|
||||
|
||||
continued, continueErr := al.Continue(
|
||||
ctx,
|
||||
target.SessionKey,
|
||||
target.Channel,
|
||||
target.ChatID,
|
||||
)
|
||||
continued, continueErr := al.Continue(ctx, target.SessionKey, target.Channel, target.ChatID)
|
||||
if continueErr != nil {
|
||||
logger.WarnCF("agent", "Failed to continue queued steering",
|
||||
map[string]any{
|
||||
@@ -516,22 +503,14 @@ func (al *AgentLoop) Run(ctx context.Context) error {
|
||||
"queue_depth": al.pendingSteeringCountForScope(target.SessionKey),
|
||||
})
|
||||
|
||||
continued, continueErr := al.Continue(
|
||||
ctx,
|
||||
target.SessionKey,
|
||||
target.Channel,
|
||||
target.ChatID,
|
||||
)
|
||||
continued, continueErr := al.Continue(ctx, target.SessionKey, target.Channel, target.ChatID)
|
||||
if continueErr != nil {
|
||||
logger.WarnCF(
|
||||
"agent",
|
||||
"Failed to continue queued steering after shutdown drain",
|
||||
logger.WarnCF("agent", "Failed to continue queued steering after shutdown drain",
|
||||
map[string]any{
|
||||
"channel": target.Channel,
|
||||
"chat_id": target.ChatID,
|
||||
"error": continueErr.Error(),
|
||||
},
|
||||
)
|
||||
})
|
||||
return
|
||||
}
|
||||
if continued == "" {
|
||||
@@ -586,15 +565,11 @@ func (al *AgentLoop) drainBusToSteering(ctx context.Context, activeScope, active
|
||||
msgScope, _, scopeOK := al.resolveSteeringTarget(msg)
|
||||
if !scopeOK || msgScope != activeScope {
|
||||
if err := al.requeueInboundMessage(msg); err != nil {
|
||||
logger.WarnCF(
|
||||
"agent",
|
||||
"Failed to requeue non-steering inbound message",
|
||||
map[string]any{
|
||||
"error": err.Error(),
|
||||
"channel": msg.Channel,
|
||||
"sender_id": msg.SenderID,
|
||||
},
|
||||
)
|
||||
logger.WarnCF("agent", "Failed to requeue non-steering inbound message", map[string]any{
|
||||
"error": err.Error(),
|
||||
"channel": msg.Channel,
|
||||
"sender_id": msg.SenderID,
|
||||
})
|
||||
}
|
||||
continue
|
||||
}
|
||||
@@ -628,10 +603,7 @@ func (al *AgentLoop) Stop() {
|
||||
al.running.Store(false)
|
||||
}
|
||||
|
||||
func (al *AgentLoop) PublishResponseIfNeeded(
|
||||
ctx context.Context,
|
||||
channel, chatID, response string,
|
||||
) {
|
||||
func (al *AgentLoop) PublishResponseIfNeeded(ctx context.Context, channel, chatID, response string) {
|
||||
if response == "" {
|
||||
return
|
||||
}
|
||||
@@ -1081,10 +1053,7 @@ var audioAnnotationRe = regexp.MustCompile(`\[(voice|audio)(?::[^\]]*)?\]`)
|
||||
// transcribeAudioInMessage resolves audio media refs, transcribes them, and
|
||||
// replaces audio annotations in msg.Content with the transcribed text.
|
||||
// Returns the (possibly modified) message and true if audio was transcribed.
|
||||
func (al *AgentLoop) transcribeAudioInMessage(
|
||||
ctx context.Context,
|
||||
msg bus.InboundMessage,
|
||||
) (bus.InboundMessage, bool) {
|
||||
func (al *AgentLoop) transcribeAudioInMessage(ctx context.Context, msg bus.InboundMessage) (bus.InboundMessage, bool) {
|
||||
if al.transcriber == nil || al.mediaStore == nil || len(msg.Media) == 0 {
|
||||
return msg, false
|
||||
}
|
||||
@@ -1094,11 +1063,7 @@ func (al *AgentLoop) transcribeAudioInMessage(
|
||||
for _, ref := range msg.Media {
|
||||
path, meta, err := al.mediaStore.ResolveWithMeta(ref)
|
||||
if err != nil {
|
||||
logger.WarnCF(
|
||||
"voice",
|
||||
"Failed to resolve media ref",
|
||||
map[string]any{"ref": ref, "error": err},
|
||||
)
|
||||
logger.WarnCF("voice", "Failed to resolve media ref", map[string]any{"ref": ref, "error": err})
|
||||
continue
|
||||
}
|
||||
if !utils.IsAudioFile(meta.Filename, meta.ContentType) {
|
||||
@@ -1176,11 +1141,7 @@ func (al *AgentLoop) sendTranscriptionFeedback(
|
||||
ReplyToMessageID: messageID,
|
||||
})
|
||||
if err != nil {
|
||||
logger.WarnCF(
|
||||
"voice",
|
||||
"Failed to send transcription feedback",
|
||||
map[string]any{"error": err.Error()},
|
||||
)
|
||||
logger.WarnCF("voice", "Failed to send transcription feedback", map[string]any{"error": err.Error()})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1381,9 +1342,7 @@ func (al *AgentLoop) processMessage(ctx context.Context, msg bus.InboundMessage)
|
||||
return al.runAgentLoop(ctx, agent, opts)
|
||||
}
|
||||
|
||||
func (al *AgentLoop) resolveMessageRoute(
|
||||
msg bus.InboundMessage,
|
||||
) (routing.ResolvedRoute, *AgentInstance, error) {
|
||||
func (al *AgentLoop) resolveMessageRoute(msg bus.InboundMessage) (routing.ResolvedRoute, *AgentInstance, error) {
|
||||
registry := al.GetRegistry()
|
||||
route := registry.ResolveRoute(routing.RouteInput{
|
||||
Channel: msg.Channel,
|
||||
@@ -1399,10 +1358,7 @@ func (al *AgentLoop) resolveMessageRoute(
|
||||
agent = registry.GetDefaultAgent()
|
||||
}
|
||||
if agent == nil {
|
||||
return routing.ResolvedRoute{}, nil, fmt.Errorf(
|
||||
"no agent available for route (agent_id=%s)",
|
||||
route.AgentID,
|
||||
)
|
||||
return routing.ResolvedRoute{}, nil, fmt.Errorf("no agent available for route (agent_id=%s)", route.AgentID)
|
||||
}
|
||||
|
||||
return route, agent, nil
|
||||
@@ -1727,11 +1683,7 @@ func (al *AgentLoop) runTurn(ctx context.Context, ts *turnState) (turnResult, er
|
||||
ts.recordPersistedMessage(rootMsg)
|
||||
}
|
||||
|
||||
activeCandidates, activeModel, usedLight := al.selectCandidates(
|
||||
ts.agent,
|
||||
ts.userMessage,
|
||||
messages,
|
||||
)
|
||||
activeCandidates, activeModel, usedLight := al.selectCandidates(ts.agent, ts.userMessage, messages)
|
||||
activeProvider := ts.agent.Provider
|
||||
if usedLight && ts.agent.LightProvider != nil {
|
||||
activeProvider = ts.agent.LightProvider
|
||||
@@ -2704,15 +2656,12 @@ turnLoop:
|
||||
}
|
||||
|
||||
if steerMsgs := al.dequeueSteeringMessagesForScope(ts.sessionKey); len(steerMsgs) > 0 {
|
||||
logger.InfoCF(
|
||||
"agent",
|
||||
"Steering arrived after turn completion; continuing turn before finalizing",
|
||||
logger.InfoCF("agent", "Steering arrived after turn completion; continuing turn before finalizing",
|
||||
map[string]any{
|
||||
"agent_id": ts.agent.ID,
|
||||
"steering_count": len(steerMsgs),
|
||||
"session_key": ts.sessionKey,
|
||||
},
|
||||
)
|
||||
})
|
||||
pendingMessages = append(pendingMessages, steerMsgs...)
|
||||
finalContent = ""
|
||||
goto turnLoop
|
||||
@@ -2828,18 +2777,11 @@ func (al *AgentLoop) selectCandidates(
|
||||
"score": score,
|
||||
"threshold": agent.Router.Threshold(),
|
||||
})
|
||||
return agent.LightCandidates, resolvedCandidateModel(
|
||||
agent.LightCandidates,
|
||||
agent.Router.LightModel(),
|
||||
), true
|
||||
return agent.LightCandidates, resolvedCandidateModel(agent.LightCandidates, agent.Router.LightModel()), true
|
||||
}
|
||||
|
||||
// maybeSummarize triggers summarization if the session history exceeds thresholds.
|
||||
func (al *AgentLoop) maybeSummarize(
|
||||
agent *AgentInstance,
|
||||
sessionKey string,
|
||||
turnScope turnEventScope,
|
||||
) {
|
||||
func (al *AgentLoop) maybeSummarize(agent *AgentInstance, sessionKey string, turnScope turnEventScope) {
|
||||
newHistory := agent.Sessions.GetHistory(sessionKey)
|
||||
tokenEstimate := al.estimateTokens(newHistory)
|
||||
threshold := agent.ContextWindow * agent.SummarizeTokenPercent / 100
|
||||
@@ -2873,10 +2815,7 @@ type compressionResult struct {
|
||||
// prompt is built dynamically by BuildMessages and is NOT stored here.
|
||||
// The compression note is recorded in the session summary so that
|
||||
// BuildMessages can include it in the next system prompt.
|
||||
func (al *AgentLoop) forceCompression(
|
||||
agent *AgentInstance,
|
||||
sessionKey string,
|
||||
) (compressionResult, bool) {
|
||||
func (al *AgentLoop) forceCompression(agent *AgentInstance, sessionKey string) (compressionResult, bool) {
|
||||
history := agent.Sessions.GetHistory(sessionKey)
|
||||
if len(history) <= 2 {
|
||||
return compressionResult{}, false
|
||||
@@ -3029,11 +2968,7 @@ func formatToolsForLog(toolDefs []providers.ToolDefinition) string {
|
||||
}
|
||||
|
||||
// summarizeSession summarizes the conversation history for a session.
|
||||
func (al *AgentLoop) summarizeSession(
|
||||
agent *AgentInstance,
|
||||
sessionKey string,
|
||||
turnScope turnEventScope,
|
||||
) {
|
||||
func (al *AgentLoop) summarizeSession(agent *AgentInstance, sessionKey string, turnScope turnEventScope) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
||||
defer cancel()
|
||||
|
||||
@@ -3385,10 +3320,7 @@ func (al *AgentLoop) applyExplicitSkillCommand(
|
||||
|
||||
skillName, ok := agent.ContextBuilder.ResolveSkillName(arg)
|
||||
if !ok {
|
||||
return true, true, fmt.Sprintf(
|
||||
"Unknown skill: %s\nUse /list skills to see installed skills.",
|
||||
arg,
|
||||
)
|
||||
return true, true, fmt.Sprintf("Unknown skill: %s\nUse /list skills to see installed skills.", arg)
|
||||
}
|
||||
|
||||
if len(parts) < 3 {
|
||||
@@ -3415,10 +3347,7 @@ func (al *AgentLoop) applyExplicitSkillCommand(
|
||||
return true, false, ""
|
||||
}
|
||||
|
||||
func (al *AgentLoop) buildCommandsRuntime(
|
||||
agent *AgentInstance,
|
||||
opts *processOptions,
|
||||
) *commands.Runtime {
|
||||
func (al *AgentLoop) buildCommandsRuntime(agent *AgentInstance, opts *processOptions) *commands.Runtime {
|
||||
registry := al.GetRegistry()
|
||||
cfg := al.GetConfig()
|
||||
rt := &commands.Runtime{
|
||||
@@ -3462,10 +3391,7 @@ func (al *AgentLoop) buildCommandsRuntime(
|
||||
rt.ListSkillNames = agent.ContextBuilder.ListSkillNames
|
||||
}
|
||||
rt.GetModelInfo = func() (string, string) {
|
||||
return agent.Model, resolvedCandidateProvider(
|
||||
agent.Candidates,
|
||||
cfg.Agents.Defaults.Provider,
|
||||
)
|
||||
return agent.Model, resolvedCandidateProvider(agent.Candidates, cfg.Agents.Defaults.Provider)
|
||||
}
|
||||
rt.SwitchModel = func(value string) (string, error) {
|
||||
value = strings.TrimSpace(value)
|
||||
@@ -3479,12 +3405,7 @@ func (al *AgentLoop) buildCommandsRuntime(
|
||||
return "", fmt.Errorf("failed to initialize model %q: %w", value, err)
|
||||
}
|
||||
|
||||
nextCandidates := resolveModelCandidates(
|
||||
cfg,
|
||||
cfg.Agents.Defaults.Provider,
|
||||
modelCfg.Model,
|
||||
agent.Fallbacks,
|
||||
)
|
||||
nextCandidates := resolveModelCandidates(cfg, cfg.Agents.Defaults.Provider, modelCfg.Model, agent.Fallbacks)
|
||||
if len(nextCandidates) == 0 {
|
||||
return "", fmt.Errorf("model %q did not resolve to any provider candidates", value)
|
||||
}
|
||||
|
||||
+4
-16
@@ -65,11 +65,7 @@ func (al *AgentLoop) ensureMCPInitialized(ctx context.Context) error {
|
||||
}
|
||||
|
||||
if al.cfg.Tools.MCP.Servers == nil || len(al.cfg.Tools.MCP.Servers) == 0 {
|
||||
logger.WarnCF(
|
||||
"agent",
|
||||
"MCP is enabled but no servers are configured, skipping MCP initialization",
|
||||
nil,
|
||||
)
|
||||
logger.WarnCF("agent", "MCP is enabled but no servers are configured, skipping MCP initialization", nil)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -80,11 +76,7 @@ func (al *AgentLoop) ensureMCPInitialized(ctx context.Context) error {
|
||||
}
|
||||
}
|
||||
if !findValidServer {
|
||||
logger.WarnCF(
|
||||
"agent",
|
||||
"MCP is enabled but no valid servers are configured, skipping MCP initialization",
|
||||
nil,
|
||||
)
|
||||
logger.WarnCF("agent", "MCP is enabled but no valid servers are configured, skipping MCP initialization", nil)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -201,14 +193,10 @@ func (al *AgentLoop) ensureMCPInitialized(ctx context.Context) error {
|
||||
}
|
||||
|
||||
if useRegex {
|
||||
agent.Tools.Register(
|
||||
tools.NewRegexSearchTool(agent.Tools, ttl, maxSearchResults),
|
||||
)
|
||||
agent.Tools.Register(tools.NewRegexSearchTool(agent.Tools, ttl, maxSearchResults))
|
||||
}
|
||||
if useBM25 {
|
||||
agent.Tools.Register(
|
||||
tools.NewBM25SearchTool(agent.Tools, ttl, maxSearchResults),
|
||||
)
|
||||
agent.Tools.Register(tools.NewBM25SearchTool(agent.Tools, ttl, maxSearchResults))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -25,11 +25,7 @@ import (
|
||||
// Non-image files (documents, audio, video) have their local path injected
|
||||
// into Content so the agent can access them via file tools like read_file.
|
||||
// Returns a new slice; original messages are not mutated.
|
||||
func resolveMediaRefs(
|
||||
messages []providers.Message,
|
||||
store media.MediaStore,
|
||||
maxSize int,
|
||||
) []providers.Message {
|
||||
func resolveMediaRefs(messages []providers.Message, store media.MediaStore, maxSize int) []providers.Message {
|
||||
if store == nil {
|
||||
return messages
|
||||
}
|
||||
|
||||
+24
-97
@@ -591,9 +591,7 @@ func TestProcessMessage_MediaToolHandledSkipsFollowUpLLMAndFinalText(t *testing.
|
||||
store := media.NewFileMediaStore()
|
||||
al.SetMediaStore(store)
|
||||
telegramChannel := &fakeMediaChannel{fakeChannel: fakeChannel{id: "rid-telegram"}}
|
||||
al.SetChannelManager(
|
||||
newStartedTestChannelManager(t, msgBus, store, "telegram", telegramChannel),
|
||||
)
|
||||
al.SetChannelManager(newStartedTestChannelManager(t, msgBus, store, "telegram", telegramChannel))
|
||||
|
||||
imagePath := filepath.Join(tmpDir, "screen.png")
|
||||
if err := os.WriteFile(imagePath, []byte("fake screenshot"), 0o644); err != nil {
|
||||
@@ -615,10 +613,7 @@ func TestProcessMessage_MediaToolHandledSkipsFollowUpLLMAndFinalText(t *testing.
|
||||
t.Fatalf("processMessage() error = %v", err)
|
||||
}
|
||||
if response != "" {
|
||||
t.Fatalf(
|
||||
"expected no final response when media tool already handled delivery, got %q",
|
||||
response,
|
||||
)
|
||||
t.Fatalf("expected no final response when media tool already handled delivery, got %q", response)
|
||||
}
|
||||
if provider.calls != 1 {
|
||||
t.Fatalf("expected exactly 1 LLM call, got %d", provider.calls)
|
||||
@@ -631,20 +626,13 @@ func TestProcessMessage_MediaToolHandledSkipsFollowUpLLMAndFinalText(t *testing.
|
||||
}
|
||||
|
||||
if len(telegramChannel.sentMedia) != 1 {
|
||||
t.Fatalf(
|
||||
"expected exactly 1 synchronously sent media message, got %d",
|
||||
len(telegramChannel.sentMedia),
|
||||
)
|
||||
t.Fatalf("expected exactly 1 synchronously sent media message, got %d", len(telegramChannel.sentMedia))
|
||||
}
|
||||
if telegramChannel.sentMedia[0].Channel != "telegram" ||
|
||||
telegramChannel.sentMedia[0].ChatID != "chat1" {
|
||||
if telegramChannel.sentMedia[0].Channel != "telegram" || telegramChannel.sentMedia[0].ChatID != "chat1" {
|
||||
t.Fatalf("unexpected sent media target: %+v", telegramChannel.sentMedia[0])
|
||||
}
|
||||
if len(telegramChannel.sentMedia[0].Parts) != 1 {
|
||||
t.Fatalf(
|
||||
"expected exactly 1 sent media part, got %d",
|
||||
len(telegramChannel.sentMedia[0].Parts),
|
||||
)
|
||||
t.Fatalf("expected exactly 1 sent media part, got %d", len(telegramChannel.sentMedia[0].Parts))
|
||||
}
|
||||
|
||||
select {
|
||||
@@ -672,8 +660,7 @@ func TestProcessMessage_MediaToolHandledSkipsFollowUpLLMAndFinalText(t *testing.
|
||||
t.Fatal("expected session history to be saved")
|
||||
}
|
||||
last := history[len(history)-1]
|
||||
if last.Role != "assistant" ||
|
||||
last.Content != "Requested output delivered via tool attachment." {
|
||||
if last.Role != "assistant" || last.Content != "Requested output delivered via tool attachment." {
|
||||
t.Fatalf("expected handled assistant summary in history, got %+v", last)
|
||||
}
|
||||
}
|
||||
@@ -698,9 +685,7 @@ func TestProcessMessage_HandledToolProcessesQueuedSteeringBeforeReturning(t *tes
|
||||
store := media.NewFileMediaStore()
|
||||
al.SetMediaStore(store)
|
||||
telegramChannel := &fakeMediaChannel{fakeChannel: fakeChannel{id: "rid-telegram"}}
|
||||
al.SetChannelManager(
|
||||
newStartedTestChannelManager(t, msgBus, store, "telegram", telegramChannel),
|
||||
)
|
||||
al.SetChannelManager(newStartedTestChannelManager(t, msgBus, store, "telegram", telegramChannel))
|
||||
|
||||
imagePath := filepath.Join(tmpDir, "screen-steering.png")
|
||||
if err := os.WriteFile(imagePath, []byte("fake screenshot"), 0o644); err != nil {
|
||||
@@ -729,10 +714,7 @@ func TestProcessMessage_HandledToolProcessesQueuedSteeringBeforeReturning(t *tes
|
||||
t.Fatalf("expected 2 LLM calls after queued steering, got %d", provider.calls)
|
||||
}
|
||||
if len(telegramChannel.sentMedia) != 1 {
|
||||
t.Fatalf(
|
||||
"expected exactly 1 synchronously sent media message, got %d",
|
||||
len(telegramChannel.sentMedia),
|
||||
)
|
||||
t.Fatalf("expected exactly 1 synchronously sent media message, got %d", len(telegramChannel.sentMedia))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -751,9 +733,7 @@ func TestProcessMessage_MediaArtifactCanBeForwardedBySendFile(t *testing.T) {
|
||||
store := media.NewFileMediaStore()
|
||||
al.SetMediaStore(store)
|
||||
telegramChannel := &fakeMediaChannel{fakeChannel: fakeChannel{id: "rid-telegram"}}
|
||||
al.SetChannelManager(
|
||||
newStartedTestChannelManager(t, msgBus, store, "telegram", telegramChannel),
|
||||
)
|
||||
al.SetChannelManager(newStartedTestChannelManager(t, msgBus, store, "telegram", telegramChannel))
|
||||
|
||||
mediaDir := media.TempDir()
|
||||
if err := os.MkdirAll(mediaDir, 0o700); err != nil {
|
||||
@@ -786,20 +766,13 @@ func TestProcessMessage_MediaArtifactCanBeForwardedBySendFile(t *testing.T) {
|
||||
}
|
||||
|
||||
if len(telegramChannel.sentMedia) != 1 {
|
||||
t.Fatalf(
|
||||
"expected exactly 1 synchronously sent media message, got %d",
|
||||
len(telegramChannel.sentMedia),
|
||||
)
|
||||
t.Fatalf("expected exactly 1 synchronously sent media message, got %d", len(telegramChannel.sentMedia))
|
||||
}
|
||||
if telegramChannel.sentMedia[0].Channel != "telegram" ||
|
||||
telegramChannel.sentMedia[0].ChatID != "chat1" {
|
||||
if telegramChannel.sentMedia[0].Channel != "telegram" || telegramChannel.sentMedia[0].ChatID != "chat1" {
|
||||
t.Fatalf("unexpected sent media target: %+v", telegramChannel.sentMedia[0])
|
||||
}
|
||||
if len(telegramChannel.sentMedia[0].Parts) != 1 {
|
||||
t.Fatalf(
|
||||
"expected exactly 1 sent media part, got %d",
|
||||
len(telegramChannel.sentMedia[0].Parts),
|
||||
)
|
||||
t.Fatalf("expected exactly 1 sent media part, got %d", len(telegramChannel.sentMedia[0].Parts))
|
||||
}
|
||||
|
||||
select {
|
||||
@@ -1210,10 +1183,7 @@ func (m *handledMediaWithSteeringTool) Parameters() map[string]any {
|
||||
}
|
||||
}
|
||||
|
||||
func (m *handledMediaWithSteeringTool) Execute(
|
||||
ctx context.Context,
|
||||
args map[string]any,
|
||||
) *tools.ToolResult {
|
||||
func (m *handledMediaWithSteeringTool) Execute(ctx context.Context, args map[string]any) *tools.ToolResult {
|
||||
if err := m.loop.Steer(providers.Message{Role: "user", Content: "what about this instead?"}); err != nil {
|
||||
return tools.ErrorResult(err.Error()).WithError(err)
|
||||
}
|
||||
@@ -1366,11 +1336,7 @@ func newStrictChatCompletionTestServer(
|
||||
}))
|
||||
}
|
||||
|
||||
func (h testHelper) executeAndGetResponse(
|
||||
tb testing.TB,
|
||||
ctx context.Context,
|
||||
msg bus.InboundMessage,
|
||||
) string {
|
||||
func (h testHelper) executeAndGetResponse(tb testing.TB, ctx context.Context, msg bus.InboundMessage) string {
|
||||
// Use a short timeout to avoid hanging
|
||||
timeoutCtx, cancel := context.WithTimeout(ctx, responseTimeout)
|
||||
defer cancel()
|
||||
@@ -1501,10 +1467,7 @@ func TestProcessMessage_CommandOutcomes(t *testing.T) {
|
||||
t.Fatalf("unexpected /foo reply: %q", fooResp)
|
||||
}
|
||||
if provider.calls != 1 {
|
||||
t.Fatalf(
|
||||
"LLM should be called exactly once after /foo passthrough, calls=%d",
|
||||
provider.calls,
|
||||
)
|
||||
t.Fatalf("LLM should be called exactly once after /foo passthrough, calls=%d", provider.calls)
|
||||
}
|
||||
|
||||
newResp := helper.executeAndGetResponse(t, context.Background(), bus.InboundMessage{
|
||||
@@ -1654,10 +1617,7 @@ func TestProcessMessage_SwitchModelRejectsUnknownAlias(t *testing.T) {
|
||||
}
|
||||
|
||||
if provider.calls != 0 {
|
||||
t.Fatalf(
|
||||
"LLM should not be called for rejected /switch and /show, calls=%d",
|
||||
provider.calls,
|
||||
)
|
||||
t.Fatalf("LLM should not be called for rejected /switch and /show, calls=%d", provider.calls)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1675,13 +1635,7 @@ func TestProcessMessage_SwitchModelRoutesSubsequentRequestsToSelectedProvider(t
|
||||
|
||||
remoteCalls := 0
|
||||
remoteModel := ""
|
||||
remoteServer := newChatCompletionTestServer(
|
||||
t,
|
||||
"remote",
|
||||
"remote reply",
|
||||
&remoteCalls,
|
||||
&remoteModel,
|
||||
)
|
||||
remoteServer := newChatCompletionTestServer(t, "remote", "remote reply", &remoteCalls, &remoteModel)
|
||||
defer remoteServer.Close()
|
||||
|
||||
cfg := &config.Config{
|
||||
@@ -2004,9 +1958,7 @@ func TestAgentLoop_ContextExhaustionRetry(t *testing.T) {
|
||||
msgBus := bus.NewMessageBus()
|
||||
|
||||
// Create a provider that fails once with a context error
|
||||
contextErr := fmt.Errorf(
|
||||
"InvalidParameter: Total tokens of image and text exceed max message tokens",
|
||||
)
|
||||
contextErr := fmt.Errorf("InvalidParameter: Total tokens of image and text exceed max message tokens")
|
||||
provider := &failFirstMockProvider{
|
||||
failures: 1,
|
||||
failError: contextErr,
|
||||
@@ -2087,13 +2039,7 @@ func TestAgentLoop_EmptyModelResponseUsesAccurateFallback(t *testing.T) {
|
||||
provider := &simpleMockProvider{response: ""}
|
||||
al := NewAgentLoop(cfg, msgBus, provider)
|
||||
|
||||
response, err := al.ProcessDirectWithChannel(
|
||||
context.Background(),
|
||||
"hello",
|
||||
"empty-response",
|
||||
"test",
|
||||
"chat1",
|
||||
)
|
||||
response, err := al.ProcessDirectWithChannel(context.Background(), "hello", "empty-response", "test", "chat1")
|
||||
if err != nil {
|
||||
t.Fatalf("ProcessDirectWithChannel failed: %v", err)
|
||||
}
|
||||
@@ -2125,13 +2071,7 @@ func TestAgentLoop_ToolLimitUsesDedicatedFallback(t *testing.T) {
|
||||
al := NewAgentLoop(cfg, msgBus, provider)
|
||||
al.RegisterTool(&toolLimitTestTool{})
|
||||
|
||||
response, err := al.ProcessDirectWithChannel(
|
||||
context.Background(),
|
||||
"hello",
|
||||
"tool-limit",
|
||||
"test",
|
||||
"chat1",
|
||||
)
|
||||
response, err := al.ProcessDirectWithChannel(context.Background(), "hello", "tool-limit", "test", "chat1")
|
||||
if err != nil {
|
||||
t.Fatalf("ProcessDirectWithChannel failed: %v", err)
|
||||
}
|
||||
@@ -2449,9 +2389,7 @@ func TestHandleReasoning(t *testing.T) {
|
||||
break
|
||||
}
|
||||
if msg.Content == "should timeout" {
|
||||
t.Fatal(
|
||||
"expected reasoning message to be dropped when bus is full, but it was published",
|
||||
)
|
||||
t.Fatal("expected reasoning message to be dropped when bus is full, but it was published")
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -2545,12 +2483,7 @@ func TestProcessHeartbeat_DoesNotPublishToolFeedback(t *testing.T) {
|
||||
provider := &toolFeedbackProvider{filePath: heartbeatFile}
|
||||
al := NewAgentLoop(cfg, msgBus, provider)
|
||||
|
||||
response, err := al.ProcessHeartbeat(
|
||||
context.Background(),
|
||||
"check heartbeat tasks",
|
||||
"telegram",
|
||||
"chat-1",
|
||||
)
|
||||
response, err := al.ProcessHeartbeat(context.Background(), "check heartbeat tasks", "telegram", "chat-1")
|
||||
if err != nil {
|
||||
t.Fatalf("ProcessHeartbeat() error = %v", err)
|
||||
}
|
||||
@@ -3035,14 +2968,8 @@ func TestProcessMessage_ContextOverflowRecovery(t *testing.T) {
|
||||
agent := al.GetRegistry().GetDefaultAgent()
|
||||
|
||||
for i := 0; i < 5; i++ {
|
||||
agent.Sessions.AddFullMessage(
|
||||
sessionKey,
|
||||
providers.Message{Role: "user", Content: "heavy message"},
|
||||
)
|
||||
agent.Sessions.AddFullMessage(
|
||||
sessionKey,
|
||||
providers.Message{Role: "assistant", Content: "response"},
|
||||
)
|
||||
agent.Sessions.AddFullMessage(sessionKey, providers.Message{Role: "user", Content: "heavy message"})
|
||||
agent.Sessions.AddFullMessage(sessionKey, providers.Message{Role: "assistant", Content: "response"})
|
||||
}
|
||||
|
||||
response, err := al.processMessage(context.Background(), bus.InboundMessage{
|
||||
|
||||
@@ -26,8 +26,7 @@ func buildModelListResolver(cfg *config.Config) func(raw string) (string, bool)
|
||||
return "", false
|
||||
}
|
||||
|
||||
if mc, err := cfg.GetModelConfig(raw); err == nil && mc != nil &&
|
||||
strings.TrimSpace(mc.Model) != "" {
|
||||
if mc, err := cfg.GetModelConfig(raw); err == nil && mc != nil && strings.TrimSpace(mc.Model) != "" {
|
||||
return ensureProtocol(mc.Model), true
|
||||
}
|
||||
|
||||
@@ -79,10 +78,7 @@ func resolvedCandidateProvider(candidates []providers.FallbackCandidate, fallbac
|
||||
return fallback
|
||||
}
|
||||
|
||||
func resolvedModelConfig(
|
||||
cfg *config.Config,
|
||||
modelName, workspace string,
|
||||
) (*config.ModelConfig, error) {
|
||||
func resolvedModelConfig(cfg *config.Config, modelName, workspace string) (*config.ModelConfig, error) {
|
||||
if cfg == nil {
|
||||
return nil, fmt.Errorf("config is nil")
|
||||
}
|
||||
|
||||
@@ -325,10 +325,7 @@ func (al *AgentLoop) agentForSession(sessionKey string) *AgentInstance {
|
||||
// user has since enqueued steering messages.
|
||||
//
|
||||
// If no steering messages are pending, it returns an empty string.
|
||||
func (al *AgentLoop) Continue(
|
||||
ctx context.Context,
|
||||
sessionKey, channel, chatID string,
|
||||
) (string, error) {
|
||||
func (al *AgentLoop) Continue(ctx context.Context, sessionKey, channel, chatID string) (string, error) {
|
||||
if active := al.GetActiveTurn(); active != nil {
|
||||
return "", fmt.Errorf("turn %s is still active", active.TurnID)
|
||||
}
|
||||
|
||||
@@ -896,10 +896,7 @@ func TestAgentLoop_Run_AutoContinuesLateSteeringMessage(t *testing.T) {
|
||||
defer cancelNoExtra()
|
||||
select {
|
||||
case out2 := <-msgBus.OutboundChan():
|
||||
t.Fatalf(
|
||||
"expected stale direct response to be suppressed, got extra outbound %q",
|
||||
out2.Content,
|
||||
)
|
||||
t.Fatalf("expected stale direct response to be suppressed, got extra outbound %q", out2.Content)
|
||||
case <-noExtraCtx.Done():
|
||||
}
|
||||
|
||||
@@ -1047,11 +1044,7 @@ func TestAgentLoop_Continue_PreservesSteeringMedia(t *testing.T) {
|
||||
if err = os.WriteFile(pngPath, pngHeader, 0o644); err != nil {
|
||||
t.Fatalf("WriteFile failed: %v", err)
|
||||
}
|
||||
ref, err := store.Store(
|
||||
pngPath,
|
||||
media.MediaMeta{Filename: "steer.png", ContentType: "image/png"},
|
||||
"test",
|
||||
)
|
||||
ref, err := store.Store(pngPath, media.MediaMeta{Filename: "steer.png", ContentType: "image/png"}, "test")
|
||||
if err != nil {
|
||||
t.Fatalf("Store failed: %v", err)
|
||||
}
|
||||
@@ -1243,10 +1236,7 @@ func TestAgentLoop_InterruptGraceful_UsesTerminalNoToolCall(t *testing.T) {
|
||||
t.Fatalf("expected 2 provider calls, got %d", calls)
|
||||
}
|
||||
if terminalToolsCount != 0 {
|
||||
t.Fatalf(
|
||||
"expected graceful terminal call to disable tools, got %d tool defs",
|
||||
terminalToolsCount,
|
||||
)
|
||||
t.Fatalf("expected graceful terminal call to disable tools, got %d tool defs", terminalToolsCount)
|
||||
}
|
||||
|
||||
foundHint := false
|
||||
@@ -1257,8 +1247,7 @@ func TestAgentLoop_InterruptGraceful_UsesTerminalNoToolCall(t *testing.T) {
|
||||
if msg.Role == "user" && msg.Content == expectedHint {
|
||||
foundHint = true
|
||||
}
|
||||
if msg.Role == "tool" && msg.ToolCallID == "call_2" &&
|
||||
msg.Content == "Skipped due to graceful interrupt." {
|
||||
if msg.Role == "tool" && msg.ToolCallID == "call_2" && msg.Content == "Skipped due to graceful interrupt." {
|
||||
foundSkipped = true
|
||||
}
|
||||
}
|
||||
@@ -1550,8 +1539,7 @@ func TestAgentLoop_Steering_SkippedToolsHaveErrorResults(t *testing.T) {
|
||||
|
||||
foundSkipped := false
|
||||
for _, m := range msgs {
|
||||
if m.Role == "tool" && m.ToolCallID == "call_2" &&
|
||||
m.Content == "Skipped due to queued user message." {
|
||||
if m.Role == "tool" && m.ToolCallID == "call_2" && m.Content == "Skipped due to queued user message." {
|
||||
foundSkipped = true
|
||||
break
|
||||
}
|
||||
@@ -1559,13 +1547,7 @@ func TestAgentLoop_Steering_SkippedToolsHaveErrorResults(t *testing.T) {
|
||||
if !foundSkipped {
|
||||
// Log what we actually got
|
||||
for i, m := range msgs {
|
||||
t.Logf(
|
||||
"msg[%d]: role=%s toolCallID=%s content=%s",
|
||||
i,
|
||||
m.Role,
|
||||
m.ToolCallID,
|
||||
truncate(m.Content, 80),
|
||||
)
|
||||
t.Logf("msg[%d]: role=%s toolCallID=%s content=%s", i, m.Role, m.ToolCallID, truncate(m.Content, 80))
|
||||
}
|
||||
t.Fatal("expected skipped tool result for call_2")
|
||||
}
|
||||
|
||||
+5
-20
@@ -505,12 +505,7 @@ func spawnSubTurn(
|
||||
// Event emissions:
|
||||
// - SubTurnResultDeliveredEvent: successful delivery to channel
|
||||
// - SubTurnOrphanResultEvent: delivery failed (parent finished or channel full)
|
||||
func deliverSubTurnResult(
|
||||
al *AgentLoop,
|
||||
parentTS *turnState,
|
||||
childID string,
|
||||
result *tools.ToolResult,
|
||||
) {
|
||||
func deliverSubTurnResult(al *AgentLoop, parentTS *turnState, childID string, result *tools.ToolResult) {
|
||||
// Let GC clean up the pendingResults channel; parent Finish will no longer close it.
|
||||
// We use defer/recover to catch any unlikely channel panics if it were ever closed.
|
||||
defer func() {
|
||||
@@ -521,14 +516,9 @@ func deliverSubTurnResult(
|
||||
"recover": r,
|
||||
})
|
||||
if result != nil && al != nil {
|
||||
al.emitEvent(
|
||||
EventKindSubTurnOrphan,
|
||||
al.emitEvent(EventKindSubTurnOrphan,
|
||||
parentTS.eventMeta("deliverSubTurnResult", "subturn.orphan"),
|
||||
SubTurnOrphanPayload{
|
||||
ParentTurnID: parentTS.turnID,
|
||||
ChildTurnID: childID,
|
||||
Reason: "panic",
|
||||
},
|
||||
SubTurnOrphanPayload{ParentTurnID: parentTS.turnID, ChildTurnID: childID, Reason: "panic"},
|
||||
)
|
||||
}
|
||||
}
|
||||
@@ -541,14 +531,9 @@ func deliverSubTurnResult(
|
||||
// If parent turn has already finished, treat this as an orphan result
|
||||
if isFinished || resultChan == nil {
|
||||
if result != nil && al != nil {
|
||||
al.emitEvent(
|
||||
EventKindSubTurnOrphan,
|
||||
al.emitEvent(EventKindSubTurnOrphan,
|
||||
parentTS.eventMeta("deliverSubTurnResult", "subturn.orphan"),
|
||||
SubTurnOrphanPayload{
|
||||
ParentTurnID: parentTS.turnID,
|
||||
ChildTurnID: childID,
|
||||
Reason: "parent_finished",
|
||||
},
|
||||
SubTurnOrphanPayload{ParentTurnID: parentTS.turnID, ChildTurnID: childID, Reason: "parent_finished"},
|
||||
)
|
||||
}
|
||||
return
|
||||
|
||||
@@ -571,8 +571,7 @@ func TestHardAbortSessionRollback(t *testing.T) {
|
||||
}
|
||||
|
||||
// Verify the content matches the initial state
|
||||
if finalHistory[0].Content != "initial message 1" ||
|
||||
finalHistory[1].Content != "initial response 1" {
|
||||
if finalHistory[0].Content != "initial message 1" || finalHistory[1].Content != "initial response 1" {
|
||||
t.Error("history content does not match initial state after rollback")
|
||||
}
|
||||
}
|
||||
@@ -1291,12 +1290,7 @@ func TestDeliverSubTurnResult_RaceWithFinish(t *testing.T) {
|
||||
finalOrphan := orphanCount
|
||||
mu.Unlock()
|
||||
|
||||
t.Logf(
|
||||
"Delivered: %d, Orphan: %d, Total: %d",
|
||||
finalDelivered,
|
||||
finalOrphan,
|
||||
finalDelivered+finalOrphan,
|
||||
)
|
||||
t.Logf("Delivered: %d, Orphan: %d, Total: %d", finalDelivered, finalOrphan, finalDelivered+finalOrphan)
|
||||
|
||||
// With the new drainPendingResults behavior, the total events may be >= numResults
|
||||
// because Finish() drains remaining results from the channel and emits them as orphans.
|
||||
|
||||
@@ -65,11 +65,7 @@ func DefaultConfig() *Config {
|
||||
Enabled: true,
|
||||
Text: FlexibleStringSlice{"Thinking... 💭"},
|
||||
},
|
||||
Streaming: StreamingConfig{
|
||||
Enabled: true,
|
||||
ThrottleSeconds: 3,
|
||||
MinGrowthChars: 200,
|
||||
},
|
||||
Streaming: StreamingConfig{Enabled: true, ThrottleSeconds: 3, MinGrowthChars: 200},
|
||||
UseMarkdownV2: false,
|
||||
},
|
||||
Feishu: FeishuConfig{
|
||||
|
||||
@@ -335,8 +335,7 @@ func v0ConvertProvidersToModelList(cfg *configV0) []modelConfigV0 {
|
||||
providerNames: []string{"github_copilot", "copilot"},
|
||||
protocol: "github-copilot",
|
||||
buildConfig: func(p providersConfigV0) (modelConfigV0, bool) {
|
||||
if p.GitHubCopilot.APIKey == "" && p.GitHubCopilot.APIBase == "" &&
|
||||
p.GitHubCopilot.ConnectMode == "" {
|
||||
if p.GitHubCopilot.APIKey == "" && p.GitHubCopilot.APIBase == "" && p.GitHubCopilot.ConnectMode == "" {
|
||||
return modelConfigV0{}, false
|
||||
}
|
||||
return modelConfigV0{
|
||||
|
||||
@@ -72,11 +72,7 @@ func TestMigration_Integration_LegacyConfigWithoutWorkspace(t *testing.T) {
|
||||
// CRITICAL: Verify that user's settings are preserved
|
||||
// This was the bug - these settings were lost when Workspace was empty
|
||||
if cfg.Agents.Defaults.Provider != "openai" {
|
||||
t.Errorf(
|
||||
"Provider = %q, want %q (user's setting should be preserved)",
|
||||
cfg.Agents.Defaults.Provider,
|
||||
"openai",
|
||||
)
|
||||
t.Errorf("Provider = %q, want %q (user's setting should be preserved)", cfg.Agents.Defaults.Provider, "openai")
|
||||
}
|
||||
// Old "model" field is migrated to "model_name" field
|
||||
if cfg.Agents.Defaults.ModelName != "gpt-4o" {
|
||||
@@ -303,11 +299,7 @@ func TestMigration_Integration_PreservesAllAgentsFields(t *testing.T) {
|
||||
t.Errorf("Agent.ID = %q, want %q", cfg.Agents.List[0].ID, "special-agent")
|
||||
}
|
||||
if cfg.Agents.List[0].Workspace != "/special/workspace" {
|
||||
t.Errorf(
|
||||
"Agent.Workspace = %q, want %q",
|
||||
cfg.Agents.List[0].Workspace,
|
||||
"/special/workspace",
|
||||
)
|
||||
t.Errorf("Agent.Workspace = %q, want %q", cfg.Agents.List[0].Workspace, "/special/workspace")
|
||||
}
|
||||
|
||||
// Workspace should have default since it was empty in legacy config
|
||||
@@ -370,10 +362,7 @@ func TestMigration_Integration_ChannelsConfigMigrated(t *testing.T) {
|
||||
|
||||
// OneBot: group_trigger_prefix should be migrated to group_trigger.prefixes
|
||||
if len(cfg.Channels.OneBot.GroupTrigger.Prefixes) != 2 {
|
||||
t.Errorf(
|
||||
"len(OneBot.GroupTrigger.Prefixes) = %d, want 2",
|
||||
len(cfg.Channels.OneBot.GroupTrigger.Prefixes),
|
||||
)
|
||||
t.Errorf("len(OneBot.GroupTrigger.Prefixes) = %d, want 2", len(cfg.Channels.OneBot.GroupTrigger.Prefixes))
|
||||
} else {
|
||||
if cfg.Channels.OneBot.GroupTrigger.Prefixes[0] != "/" {
|
||||
t.Errorf("Prefixes[0] = %q, want %q", cfg.Channels.OneBot.GroupTrigger.Prefixes[0], "/")
|
||||
@@ -454,25 +443,13 @@ func TestMigration_Integration_RoundTrip_SerializeAndLoad(t *testing.T) {
|
||||
|
||||
// Verify configs are identical
|
||||
if cfg2.Agents.Defaults.Provider != cfg1.Agents.Defaults.Provider {
|
||||
t.Errorf(
|
||||
"Provider changed from %q to %q",
|
||||
cfg1.Agents.Defaults.Provider,
|
||||
cfg2.Agents.Defaults.Provider,
|
||||
)
|
||||
t.Errorf("Provider changed from %q to %q", cfg1.Agents.Defaults.Provider, cfg2.Agents.Defaults.Provider)
|
||||
}
|
||||
if cfg2.Agents.Defaults.ModelName != cfg1.Agents.Defaults.ModelName {
|
||||
t.Errorf(
|
||||
"ModelName changed from %q to %q",
|
||||
cfg1.Agents.Defaults.ModelName,
|
||||
cfg2.Agents.Defaults.ModelName,
|
||||
)
|
||||
t.Errorf("ModelName changed from %q to %q", cfg1.Agents.Defaults.ModelName, cfg2.Agents.Defaults.ModelName)
|
||||
}
|
||||
if cfg2.Agents.Defaults.MaxTokens != cfg1.Agents.Defaults.MaxTokens {
|
||||
t.Errorf(
|
||||
"MaxTokens changed from %d to %d",
|
||||
cfg1.Agents.Defaults.MaxTokens,
|
||||
cfg2.Agents.Defaults.MaxTokens,
|
||||
)
|
||||
t.Errorf("MaxTokens changed from %d to %d", cfg1.Agents.Defaults.MaxTokens, cfg2.Agents.Defaults.MaxTokens)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -580,11 +557,7 @@ func TestMigration_Integration_ModelNameField(t *testing.T) {
|
||||
|
||||
// GetModelName() should return model_name, not model (deprecated)
|
||||
if cfg.Agents.Defaults.GetModelName() != "deepseek-reasoner" {
|
||||
t.Errorf(
|
||||
"GetModelName() = %q, want %q",
|
||||
cfg.Agents.Defaults.GetModelName(),
|
||||
"deepseek-reasoner",
|
||||
)
|
||||
t.Errorf("GetModelName() = %q, want %q", cfg.Agents.Defaults.GetModelName(), "deepseek-reasoner")
|
||||
}
|
||||
|
||||
if len(cfg.Agents.Defaults.ModelFallbacks) != 1 {
|
||||
|
||||
@@ -91,11 +91,9 @@ func TestConvertProvidersToModelList_LiteLLM(t *testing.T) {
|
||||
func TestConvertProvidersToModelList_Multiple(t *testing.T) {
|
||||
cfg := &configV0{
|
||||
Providers: providersConfigV0{
|
||||
OpenAI: openAIProviderConfigV0{
|
||||
providerConfigV0: providerConfigV0{APIKey: "openai-key"},
|
||||
},
|
||||
Groq: providerConfigV0{APIKey: "groq-key"},
|
||||
Zhipu: providerConfigV0{APIKey: "zhipu-key"},
|
||||
OpenAI: openAIProviderConfigV0{providerConfigV0: providerConfigV0{APIKey: "openai-key"}},
|
||||
Groq: providerConfigV0{APIKey: "groq-key"},
|
||||
Zhipu: providerConfigV0{APIKey: "zhipu-key"},
|
||||
},
|
||||
}
|
||||
|
||||
@@ -144,13 +142,8 @@ func TestConvertProvidersToModelList_AllProviders(t *testing.T) {
|
||||
// Other providers have no configuration, so they won't be converted.
|
||||
cfg := &configV0{
|
||||
Providers: providersConfigV0{
|
||||
OpenAI: openAIProviderConfigV0{
|
||||
providerConfigV0: providerConfigV0{APIKey: "key1"},
|
||||
},
|
||||
LiteLLM: providerConfigV0{
|
||||
APIKey: "key-litellm",
|
||||
APIBase: "http://localhost:4000/v1",
|
||||
},
|
||||
OpenAI: openAIProviderConfigV0{providerConfigV0: providerConfigV0{APIKey: "key1"}},
|
||||
LiteLLM: providerConfigV0{APIKey: "key-litellm", APIBase: "http://localhost:4000/v1"},
|
||||
Anthropic: providerConfigV0{APIKey: "key2"},
|
||||
OpenRouter: providerConfigV0{APIKey: "key3"},
|
||||
Groq: providerConfigV0{APIKey: "key4"},
|
||||
@@ -268,11 +261,7 @@ func TestConvertProvidersToModelList_PreservesUserModel_DeepSeek(t *testing.T) {
|
||||
|
||||
// Should use user's model, not default
|
||||
if result[0].Model != "deepseek/deepseek-reasoner" {
|
||||
t.Errorf(
|
||||
"Model = %q, want %q (user's configured model)",
|
||||
result[0].Model,
|
||||
"deepseek/deepseek-reasoner",
|
||||
)
|
||||
t.Errorf("Model = %q, want %q (user's configured model)", result[0].Model, "deepseek/deepseek-reasoner")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -382,9 +371,7 @@ func TestConvertProvidersToModelList_MultipleProviders_PreservesUserModel(t *tes
|
||||
},
|
||||
},
|
||||
Providers: providersConfigV0{
|
||||
OpenAI: openAIProviderConfigV0{
|
||||
providerConfigV0: providerConfigV0{APIKey: "sk-openai"},
|
||||
},
|
||||
OpenAI: openAIProviderConfigV0{providerConfigV0: providerConfigV0{APIKey: "sk-openai"}},
|
||||
DeepSeek: providerConfigV0{APIKey: "sk-deepseek"},
|
||||
},
|
||||
}
|
||||
@@ -404,11 +391,7 @@ func TestConvertProvidersToModelList_MultipleProviders_PreservesUserModel(t *tes
|
||||
}
|
||||
case "deepseek":
|
||||
if mc.Model != "deepseek/deepseek-reasoner" {
|
||||
t.Errorf(
|
||||
"DeepSeek Model = %q, want %q (user's)",
|
||||
mc.Model,
|
||||
"deepseek/deepseek-reasoner",
|
||||
)
|
||||
t.Errorf("DeepSeek Model = %q, want %q (user's)", mc.Model, "deepseek/deepseek-reasoner")
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -506,11 +489,7 @@ func TestConvertProvidersToModelList_NoProviderField_SingleProvider(t *testing.T
|
||||
|
||||
// ModelName should be the user's model value for backward compatibility
|
||||
if result[0].ModelName != "glm-4.7" {
|
||||
t.Errorf(
|
||||
"ModelName = %q, want %q (user's model for backward compatibility)",
|
||||
result[0].ModelName,
|
||||
"glm-4.7",
|
||||
)
|
||||
t.Errorf("ModelName = %q, want %q (user's model for backward compatibility)", result[0].ModelName, "glm-4.7")
|
||||
}
|
||||
|
||||
// Model should use the user's model with protocol prefix
|
||||
@@ -531,10 +510,8 @@ func TestConvertProvidersToModelList_NoProviderField_MultipleProviders(t *testin
|
||||
},
|
||||
},
|
||||
Providers: providersConfigV0{
|
||||
OpenAI: openAIProviderConfigV0{
|
||||
providerConfigV0: providerConfigV0{APIKey: "openai-key"},
|
||||
},
|
||||
Zhipu: providerConfigV0{APIKey: "zhipu-key"},
|
||||
OpenAI: openAIProviderConfigV0{providerConfigV0: providerConfigV0{APIKey: "openai-key"}},
|
||||
Zhipu: providerConfigV0{APIKey: "zhipu-key"},
|
||||
},
|
||||
}
|
||||
|
||||
@@ -594,11 +571,7 @@ func TestBuildModelWithProtocol_NoPrefix(t *testing.T) {
|
||||
func TestBuildModelWithProtocol_AlreadyHasPrefix(t *testing.T) {
|
||||
result := buildModelWithProtocol("openrouter", "openrouter/auto")
|
||||
if result != "openrouter/auto" {
|
||||
t.Errorf(
|
||||
"buildModelWithProtocol(openrouter, openrouter/auto) = %q, want %q",
|
||||
result,
|
||||
"openrouter/auto",
|
||||
)
|
||||
t.Errorf("buildModelWithProtocol(openrouter, openrouter/auto) = %q, want %q", result, "openrouter/auto")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -640,10 +613,6 @@ func TestConvertProvidersToModelList_LegacyModelWithProtocolPrefix(t *testing.T)
|
||||
|
||||
// Model should NOT have duplicated prefix
|
||||
if result[0].Model != "openrouter/auto" {
|
||||
t.Errorf(
|
||||
"Model = %q, want %q (should not duplicate prefix)",
|
||||
result[0].Model,
|
||||
"openrouter/auto",
|
||||
)
|
||||
t.Errorf("Model = %q, want %q (should not duplicate prefix)", result[0].Model, "openrouter/auto")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -17,11 +17,7 @@ func TestGetModelConfig_Found(t *testing.T) {
|
||||
Version: CurrentVersion,
|
||||
ModelList: []*ModelConfig{
|
||||
{ModelName: "test-model", Model: "openai/gpt-4o", APIKeys: SimpleSecureStrings("key1")},
|
||||
{
|
||||
ModelName: "other-model",
|
||||
Model: "anthropic/claude",
|
||||
APIKeys: SimpleSecureStrings("key2"),
|
||||
},
|
||||
{ModelName: "other-model", Model: "anthropic/claude", APIKeys: SimpleSecureStrings("key2")},
|
||||
},
|
||||
}
|
||||
|
||||
@@ -118,16 +114,8 @@ func TestGetModelConfig_RoundRobinStartsFromFirstMatch(t *testing.T) {
|
||||
func TestGetModelConfig_Concurrent(t *testing.T) {
|
||||
cfg := &Config{
|
||||
ModelList: []*ModelConfig{
|
||||
{
|
||||
ModelName: "concurrent-model",
|
||||
Model: "openai/gpt-4o-1",
|
||||
APIKeys: SimpleSecureStrings("key1"),
|
||||
},
|
||||
{
|
||||
ModelName: "concurrent-model",
|
||||
Model: "openai/gpt-4o-2",
|
||||
APIKeys: SimpleSecureStrings("key2"),
|
||||
},
|
||||
{ModelName: "concurrent-model", Model: "openai/gpt-4o-1", APIKeys: SimpleSecureStrings("key1")},
|
||||
{ModelName: "concurrent-model", Model: "openai/gpt-4o-2", APIKeys: SimpleSecureStrings("key2")},
|
||||
},
|
||||
}
|
||||
|
||||
@@ -302,11 +290,7 @@ func TestConfig_ValidateModelList(t *testing.T) {
|
||||
}
|
||||
if err != nil && tt.errMsg != "" {
|
||||
if !strings.Contains(err.Error(), tt.errMsg) {
|
||||
t.Errorf(
|
||||
"ValidateModelList() error = %v, want error containing %q",
|
||||
err,
|
||||
tt.errMsg,
|
||||
)
|
||||
t.Errorf("ValidateModelList() error = %v, want error containing %q", err, tt.errMsg)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
@@ -117,10 +117,7 @@ func TestExpandMultiKeyModels_WithExistingFallbacks(t *testing.T) {
|
||||
ModelName: "gpt-4",
|
||||
Model: "openai/gpt-4o",
|
||||
}
|
||||
modelCfg.APIKeys = SimpleSecureStrings(
|
||||
"key0",
|
||||
"key1",
|
||||
) // Use internal field for multi-key testing
|
||||
modelCfg.APIKeys = SimpleSecureStrings("key0", "key1") // Use internal field for multi-key testing
|
||||
modelCfg.Fallbacks = []string{"claude-3"}
|
||||
models := []*ModelConfig{modelCfg}
|
||||
|
||||
@@ -199,10 +196,7 @@ func TestExpandMultiKeyModels_PreservesOtherFields(t *testing.T) {
|
||||
RequestTimeout: 30,
|
||||
ThinkingLevel: "high",
|
||||
}
|
||||
modelCfg.APIKeys = SimpleSecureStrings(
|
||||
"key0",
|
||||
"key1",
|
||||
) // Use internal field for multi-key testing
|
||||
modelCfg.APIKeys = SimpleSecureStrings("key0", "key1") // Use internal field for multi-key testing
|
||||
models := []*ModelConfig{modelCfg}
|
||||
|
||||
result := expandMultiKeyModels(models)
|
||||
|
||||
@@ -304,13 +304,11 @@ func (s *SecureString) UnmarshalJSON(value []byte) error {
|
||||
|
||||
func (s SecureString) MarshalYAML() (any, error) {
|
||||
// Preserve raw value if it is already a reference (enc:// or file://)
|
||||
if strings.HasPrefix(s.raw, credential.EncScheme) ||
|
||||
strings.HasPrefix(s.raw, credential.FileScheme) {
|
||||
if strings.HasPrefix(s.raw, credential.EncScheme) || strings.HasPrefix(s.raw, credential.FileScheme) {
|
||||
return s.raw, nil
|
||||
}
|
||||
// If resolved is a reference format (e.g. set via Set), copy back to raw
|
||||
if strings.HasPrefix(s.resolved, credential.EncScheme) ||
|
||||
strings.HasPrefix(s.resolved, credential.FileScheme) {
|
||||
if strings.HasPrefix(s.resolved, credential.EncScheme) || strings.HasPrefix(s.resolved, credential.FileScheme) {
|
||||
s.raw = s.resolved
|
||||
return s.raw, nil
|
||||
}
|
||||
|
||||
@@ -35,10 +35,7 @@ func TestJSONUnmarshalPrivateFields(t *testing.T) {
|
||||
t.Errorf("PublicField = %q, want 'pub'", s.PublicField)
|
||||
}
|
||||
if s.privateField != "" {
|
||||
t.Errorf(
|
||||
"privateField = %q, want empty because unexported fields are ignored",
|
||||
s.privateField,
|
||||
)
|
||||
t.Errorf("privateField = %q, want empty because unexported fields are ignored", s.privateField)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -355,21 +352,13 @@ skills:
|
||||
|
||||
// Verify Channel tokens via Key() methods
|
||||
// Telegram
|
||||
assert.Equal(
|
||||
t,
|
||||
"123456789:ABCdefGHIjklMNOpqrsTUVwxyz",
|
||||
cfg.Channels.Telegram.Token.String(),
|
||||
)
|
||||
assert.Equal(t, "123456789:ABCdefGHIjklMNOpqrsTUVwxyz", cfg.Channels.Telegram.Token.String())
|
||||
t.Logf("Telegram Token(): %s", cfg.Channels.Telegram.Token.String())
|
||||
|
||||
// Feishu
|
||||
assert.Equal(t, "feishu_test_app_secret", cfg.Channels.Feishu.AppSecret.String())
|
||||
assert.Equal(t, "feishu_test_encrypt_key", cfg.Channels.Feishu.EncryptKey.String())
|
||||
assert.Equal(
|
||||
t,
|
||||
"feishu_test_verification_token",
|
||||
cfg.Channels.Feishu.VerificationToken.String(),
|
||||
)
|
||||
assert.Equal(t, "feishu_test_verification_token", cfg.Channels.Feishu.VerificationToken.String())
|
||||
t.Logf("Feishu AppSecret(): %s", cfg.Channels.Feishu.AppSecret.String())
|
||||
t.Logf("Feishu EncryptKey(): %s", cfg.Channels.Feishu.EncryptKey.String())
|
||||
t.Logf("Feishu VerificationToken(): %s", cfg.Channels.Feishu.VerificationToken.String())
|
||||
@@ -394,11 +383,7 @@ skills:
|
||||
|
||||
// LINE
|
||||
assert.Equal(t, "line_test_channel_secret", cfg.Channels.LINE.ChannelSecret.String())
|
||||
assert.Equal(
|
||||
t,
|
||||
"line_test_channel_access_token",
|
||||
cfg.Channels.LINE.ChannelAccessToken.String(),
|
||||
)
|
||||
assert.Equal(t, "line_test_channel_access_token", cfg.Channels.LINE.ChannelAccessToken.String())
|
||||
t.Logf("LINE ChannelSecret(): %s", cfg.Channels.LINE.ChannelSecret.String())
|
||||
t.Logf("LINE ChannelAccessToken(): %s", cfg.Channels.LINE.ChannelAccessToken.String())
|
||||
|
||||
@@ -446,11 +431,7 @@ skills:
|
||||
assert.Equal(t, "ghp-github-from-file-abc123", cfg.Tools.Skills.Github.Token.String())
|
||||
t.Logf("Github Token(): %s", cfg.Tools.Skills.Github.Token.String())
|
||||
|
||||
assert.Equal(
|
||||
t,
|
||||
"clawhub-auth-token-from-file",
|
||||
cfg.Tools.Skills.Registries.ClawHub.AuthToken.String(),
|
||||
)
|
||||
assert.Equal(t, "clawhub-auth-token-from-file", cfg.Tools.Skills.Registries.ClawHub.AuthToken.String())
|
||||
t.Logf("ClawHub AuthToken(): %s", cfg.Tools.Skills.Registries.ClawHub.AuthToken.String())
|
||||
|
||||
t.Log("All security keys are successfully accessible via their respective Key() methods")
|
||||
|
||||
+5
-17
@@ -15,10 +15,7 @@ import (
|
||||
|
||||
// JobExecutor is the interface for executing cron jobs through the agent
|
||||
type JobExecutor interface {
|
||||
ProcessDirectWithChannel(
|
||||
ctx context.Context,
|
||||
content, sessionKey, channel, chatID string,
|
||||
) (string, error)
|
||||
ProcessDirectWithChannel(ctx context.Context, content, sessionKey, channel, chatID string) (string, error)
|
||||
// PublishResponseIfNeeded sends response to the outbound bus only when the
|
||||
// agent did not already deliver content through the message tool in this round.
|
||||
PublishResponseIfNeeded(ctx context.Context, channel, chatID, response string)
|
||||
@@ -37,13 +34,8 @@ type CronTool struct {
|
||||
// NewCronTool creates a new CronTool
|
||||
// execTimeout: 0 means no timeout, >0 sets the timeout duration
|
||||
func NewCronTool(
|
||||
cronService *cron.CronService,
|
||||
executor JobExecutor,
|
||||
msgBus *bus.MessageBus,
|
||||
workspace string,
|
||||
restrict bool,
|
||||
execTimeout time.Duration,
|
||||
config *config.Config,
|
||||
cronService *cron.CronService, executor JobExecutor, msgBus *bus.MessageBus, workspace string, restrict bool,
|
||||
execTimeout time.Duration, config *config.Config,
|
||||
) (*CronTool, error) {
|
||||
allowCommand := true
|
||||
execEnabled := true
|
||||
@@ -164,9 +156,7 @@ func (t *CronTool) addJob(ctx context.Context, args map[string]any) *ToolResult
|
||||
chatID := ToolChatID(ctx)
|
||||
|
||||
if channel == "" || chatID == "" {
|
||||
return ErrorResult(
|
||||
"no session context (channel/chat_id not set). Use this tool in an active conversation.",
|
||||
)
|
||||
return ErrorResult("no session context (channel/chat_id not set). Use this tool in an active conversation.")
|
||||
}
|
||||
|
||||
message, ok := args["message"].(string)
|
||||
@@ -218,9 +208,7 @@ func (t *CronTool) addJob(ctx context.Context, args map[string]any) *ToolResult
|
||||
// Validate type parameter (server-side whitelist, not just LLM schema hint)
|
||||
msgType, _ := args["type"].(string)
|
||||
if msgType != "" && msgType != "message" && msgType != "directive" {
|
||||
return ErrorResult(
|
||||
fmt.Sprintf("invalid type %q, must be 'message' or 'directive'", msgType),
|
||||
)
|
||||
return ErrorResult(fmt.Sprintf("invalid type %q, must be 'message' or 'directive'", msgType))
|
||||
}
|
||||
|
||||
// GHSA-pv8c-p6jf-3fpp: command scheduling requires internal channel. When
|
||||
|
||||
+8
-34
@@ -49,11 +49,7 @@ func (s *stubJobExecutor) PublishResponseIfNeeded(
|
||||
s.publishedChatID = chatID
|
||||
}
|
||||
|
||||
func newTestCronToolWithExecutorAndConfig(
|
||||
t *testing.T,
|
||||
executor JobExecutor,
|
||||
cfg *config.Config,
|
||||
) *CronTool {
|
||||
func newTestCronToolWithExecutorAndConfig(t *testing.T, executor JobExecutor, cfg *config.Config) *CronTool {
|
||||
t.Helper()
|
||||
storePath := filepath.Join(t.TempDir(), "cron.json")
|
||||
cronService := cron.NewCronService(storePath, nil)
|
||||
@@ -106,10 +102,7 @@ func TestCronTool_CommandDoesNotRequireConfirmByDefault(t *testing.T) {
|
||||
})
|
||||
|
||||
if result.IsError {
|
||||
t.Fatalf(
|
||||
"expected command scheduling without confirm to succeed by default, got: %s",
|
||||
result.ForLLM,
|
||||
)
|
||||
t.Fatalf("expected command scheduling without confirm to succeed by default, got: %s", result.ForLLM)
|
||||
}
|
||||
if !strings.Contains(result.ForLLM, "Cron job added") {
|
||||
t.Errorf("expected 'Cron job added', got: %s", result.ForLLM)
|
||||
@@ -197,10 +190,7 @@ func TestCronTool_CommandAllowedFromInternalChannel(t *testing.T) {
|
||||
})
|
||||
|
||||
if result.IsError {
|
||||
t.Fatalf(
|
||||
"expected command scheduling to succeed from internal channel, got: %s",
|
||||
result.ForLLM,
|
||||
)
|
||||
t.Fatalf("expected command scheduling to succeed from internal channel, got: %s", result.ForLLM)
|
||||
}
|
||||
if !strings.Contains(result.ForLLM, "Cron job added") {
|
||||
t.Errorf("expected 'Cron job added', got: %s", result.ForLLM)
|
||||
@@ -235,10 +225,7 @@ func TestCronTool_NonCommandJobAllowedFromRemoteChannel(t *testing.T) {
|
||||
})
|
||||
|
||||
if result.IsError {
|
||||
t.Fatalf(
|
||||
"expected non-command reminder to succeed from remote channel, got: %s",
|
||||
result.ForLLM,
|
||||
)
|
||||
t.Fatalf("expected non-command reminder to succeed from remote channel, got: %s", result.ForLLM)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -310,11 +297,7 @@ func TestCronTool_ExecuteJobPublishesAgentResponse(t *testing.T) {
|
||||
t.Fatalf("sessionKey = %q, want cron-job-1", executor.lastKey)
|
||||
}
|
||||
if executor.lastChan != "telegram" || executor.lastChatID != "chat-1" {
|
||||
t.Fatalf(
|
||||
"executor target = %s/%s, want telegram/chat-1",
|
||||
executor.lastChan,
|
||||
executor.lastChatID,
|
||||
)
|
||||
t.Fatalf("executor target = %s/%s, want telegram/chat-1", executor.lastChan, executor.lastChatID)
|
||||
}
|
||||
if executor.lastPrompt != "send me a poem" {
|
||||
t.Fatalf("prompt = %q, want original message", executor.lastPrompt)
|
||||
@@ -323,11 +306,7 @@ func TestCronTool_ExecuteJobPublishesAgentResponse(t *testing.T) {
|
||||
t.Fatalf("published response = %q, want generated reply", executor.publishedResp)
|
||||
}
|
||||
if executor.publishedChan != "telegram" || executor.publishedChatID != "chat-1" {
|
||||
t.Fatalf(
|
||||
"published target = %s/%s, want telegram/chat-1",
|
||||
executor.publishedChan,
|
||||
executor.publishedChatID,
|
||||
)
|
||||
t.Fatalf("published target = %s/%s, want telegram/chat-1", executor.publishedChan, executor.publishedChatID)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -363,10 +342,7 @@ func TestCronTool_ExecuteJobSkipsWhenMessageToolAlreadySent(t *testing.T) {
|
||||
}
|
||||
|
||||
if executor.publishedResp != "" {
|
||||
t.Fatalf(
|
||||
"expected no published response when message tool already sent, got: %q",
|
||||
executor.publishedResp,
|
||||
)
|
||||
t.Fatalf("expected no published response when message tool already sent, got: %q", executor.publishedResp)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -410,9 +386,7 @@ func TestCronTool_ExecuteJobDirectiveWithDeliverRoutesToAgent(t *testing.T) {
|
||||
}
|
||||
|
||||
if executor.lastPrompt == "" {
|
||||
t.Fatal(
|
||||
"expected agent to be called for directive+deliver, but ProcessDirectWithChannel was not invoked",
|
||||
)
|
||||
t.Fatal("expected agent to be called for directive+deliver, but ProcessDirectWithChannel was not invoked")
|
||||
}
|
||||
if executor.publishedResp != "agent processed" {
|
||||
t.Fatalf("published response = %q, want %q", executor.publishedResp, "agent processed")
|
||||
|
||||
+3
-14
@@ -16,11 +16,7 @@ type EditFileTool struct {
|
||||
}
|
||||
|
||||
// NewEditFileTool creates a new EditFileTool with optional directory restriction.
|
||||
func NewEditFileTool(
|
||||
workspace string,
|
||||
restrict bool,
|
||||
allowPaths ...[]*regexp.Regexp,
|
||||
) *EditFileTool {
|
||||
func NewEditFileTool(workspace string, restrict bool, allowPaths ...[]*regexp.Regexp) *EditFileTool {
|
||||
var patterns []*regexp.Regexp
|
||||
if len(allowPaths) > 0 {
|
||||
patterns = allowPaths[0]
|
||||
@@ -83,11 +79,7 @@ type AppendFileTool struct {
|
||||
fs fileSystem
|
||||
}
|
||||
|
||||
func NewAppendFileTool(
|
||||
workspace string,
|
||||
restrict bool,
|
||||
allowPaths ...[]*regexp.Regexp,
|
||||
) *AppendFileTool {
|
||||
func NewAppendFileTool(workspace string, restrict bool, allowPaths ...[]*regexp.Regexp) *AppendFileTool {
|
||||
var patterns []*regexp.Regexp
|
||||
if len(allowPaths) > 0 {
|
||||
patterns = allowPaths[0]
|
||||
@@ -174,10 +166,7 @@ func replaceEditContent(content []byte, oldText, newText string) ([]byte, error)
|
||||
|
||||
count := strings.Count(contentStr, oldText)
|
||||
if count > 1 {
|
||||
return nil, fmt.Errorf(
|
||||
"old_text appears %d times. Please provide more context to make it unique",
|
||||
count,
|
||||
)
|
||||
return nil, fmt.Errorf("old_text appears %d times. Please provide more context to make it unique", count)
|
||||
}
|
||||
|
||||
newContent := strings.Replace(contentStr, oldText, newText, 1)
|
||||
|
||||
@@ -76,8 +76,7 @@ func TestEditTool_EditFile_NotFound(t *testing.T) {
|
||||
}
|
||||
|
||||
// Should mention file not found
|
||||
if !strings.Contains(result.ForLLM, "not found") &&
|
||||
!strings.Contains(result.ForUser, "not found") {
|
||||
if !strings.Contains(result.ForLLM, "not found") && !strings.Contains(result.ForUser, "not found") {
|
||||
t.Errorf("Expected 'file not found' message, got ForLLM: %s", result.ForLLM)
|
||||
}
|
||||
}
|
||||
@@ -104,8 +103,7 @@ func TestEditTool_EditFile_OldTextNotFound(t *testing.T) {
|
||||
}
|
||||
|
||||
// Should mention old_text not found
|
||||
if !strings.Contains(result.ForLLM, "not found") &&
|
||||
!strings.Contains(result.ForUser, "not found") {
|
||||
if !strings.Contains(result.ForLLM, "not found") && !strings.Contains(result.ForUser, "not found") {
|
||||
t.Errorf("Expected 'not found' message, got ForLLM: %s", result.ForLLM)
|
||||
}
|
||||
}
|
||||
|
||||
+3
-13
@@ -20,11 +20,7 @@ import (
|
||||
|
||||
const MaxReadFileSize = 64 * 1024 // 64KB limit to avoid context overflow
|
||||
|
||||
func validatePathWithAllowPaths(
|
||||
path, workspace string,
|
||||
restrict bool,
|
||||
patterns []*regexp.Regexp,
|
||||
) (string, error) {
|
||||
func validatePathWithAllowPaths(path, workspace string, restrict bool, patterns []*regexp.Regexp) (string, error) {
|
||||
if workspace == "" {
|
||||
return path, fmt.Errorf("workspace is not defined")
|
||||
}
|
||||
@@ -487,11 +483,7 @@ type WriteFileTool struct {
|
||||
fs fileSystem
|
||||
}
|
||||
|
||||
func NewWriteFileTool(
|
||||
workspace string,
|
||||
restrict bool,
|
||||
allowPaths ...[]*regexp.Regexp,
|
||||
) *WriteFileTool {
|
||||
func NewWriteFileTool(workspace string, restrict bool, allowPaths ...[]*regexp.Regexp) *WriteFileTool {
|
||||
var patterns []*regexp.Regexp
|
||||
if len(allowPaths) > 0 {
|
||||
patterns = allowPaths[0]
|
||||
@@ -544,9 +536,7 @@ func (t *WriteFileTool) Execute(ctx context.Context, args map[string]any) *ToolR
|
||||
|
||||
if !overwrite {
|
||||
if _, err := t.fs.Open(path); err == nil {
|
||||
return ErrorResult(
|
||||
fmt.Sprintf("file: %s already exists. Set overwrite=true to replace.", path),
|
||||
)
|
||||
return ErrorResult(fmt.Sprintf("file: %s already exists. Set overwrite=true to replace.", path))
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -59,13 +59,8 @@ func TestFilesystemTool_ReadFile_NotFound(t *testing.T) {
|
||||
}
|
||||
|
||||
// Should contain error message
|
||||
if !strings.Contains(result.ForLLM, "failed to open file") &&
|
||||
!strings.Contains(result.ForUser, "failed to read") {
|
||||
t.Errorf(
|
||||
"Expected error message, got ForLLM: %s, ForUser: %s",
|
||||
result.ForLLM,
|
||||
result.ForUser,
|
||||
)
|
||||
if !strings.Contains(result.ForLLM, "failed to open file") && !strings.Contains(result.ForUser, "failed to read") {
|
||||
t.Errorf("Expected error message, got ForLLM: %s, ForUser: %s", result.ForLLM, result.ForUser)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -83,8 +78,7 @@ func TestFilesystemTool_ReadFile_MissingPath(t *testing.T) {
|
||||
}
|
||||
|
||||
// Should mention required parameter
|
||||
if !strings.Contains(result.ForLLM, "path is required") &&
|
||||
!strings.Contains(result.ForUser, "path is required") {
|
||||
if !strings.Contains(result.ForLLM, "path is required") && !strings.Contains(result.ForUser, "path is required") {
|
||||
t.Errorf("Expected 'path is required' message, got ForLLM: %s", result.ForLLM)
|
||||
}
|
||||
}
|
||||
@@ -303,12 +297,7 @@ func TestFilesystemTool_WriteFile_OverwriteSandboxed(t *testing.T) {
|
||||
"content": "replaced in sandbox",
|
||||
"overwrite": true,
|
||||
})
|
||||
assert.False(
|
||||
t,
|
||||
result.IsError,
|
||||
"expected success in sandbox mode with overwrite=true, got: %s",
|
||||
result.ForLLM,
|
||||
)
|
||||
assert.False(t, result.IsError, "expected success in sandbox mode with overwrite=true, got: %s", result.ForLLM)
|
||||
|
||||
data, err := os.ReadFile(filepath.Join(workspace, testFile))
|
||||
assert.NoError(t, err)
|
||||
@@ -336,8 +325,7 @@ func TestFilesystemTool_ListDir_Success(t *testing.T) {
|
||||
}
|
||||
|
||||
// Should list files and directories
|
||||
if !strings.Contains(result.ForLLM, "file1.txt") ||
|
||||
!strings.Contains(result.ForLLM, "file2.txt") {
|
||||
if !strings.Contains(result.ForLLM, "file1.txt") || !strings.Contains(result.ForLLM, "file2.txt") {
|
||||
t.Errorf("Expected files in listing, got: %s", result.ForLLM)
|
||||
}
|
||||
if !strings.Contains(result.ForLLM, "subdir") {
|
||||
@@ -361,13 +349,8 @@ func TestFilesystemTool_ListDir_NotFound(t *testing.T) {
|
||||
}
|
||||
|
||||
// Should contain error message
|
||||
if !strings.Contains(result.ForLLM, "failed to read") &&
|
||||
!strings.Contains(result.ForUser, "failed to read") {
|
||||
t.Errorf(
|
||||
"Expected error message, got ForLLM: %s, ForUser: %s",
|
||||
result.ForLLM,
|
||||
result.ForUser,
|
||||
)
|
||||
if !strings.Contains(result.ForLLM, "failed to read") && !strings.Contains(result.ForUser, "failed to read") {
|
||||
t.Errorf("Expected error message, got ForLLM: %s, ForUser: %s", result.ForLLM, result.ForUser)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -414,8 +397,7 @@ func TestFilesystemTool_ReadFile_RejectsSymlinkEscape(t *testing.T) {
|
||||
// os.Root might return different errors depending on platform/implementation
|
||||
// but it definitely should error.
|
||||
// Our wrapper returns "access denied or file not found"
|
||||
if !strings.Contains(result.ForLLM, "access denied") &&
|
||||
!strings.Contains(result.ForLLM, "file not found") &&
|
||||
if !strings.Contains(result.ForLLM, "access denied") && !strings.Contains(result.ForLLM, "file not found") &&
|
||||
!strings.Contains(result.ForLLM, "no such file") {
|
||||
t.Fatalf("expected symlink escape error, got: %s", result.ForLLM)
|
||||
}
|
||||
@@ -434,20 +416,10 @@ func TestFilesystemTool_EmptyWorkspace_AccessDenied(t *testing.T) {
|
||||
})
|
||||
|
||||
// We EXPECT IsError=true (access blocked due to empty workspace)
|
||||
assert.True(
|
||||
t,
|
||||
result.IsError,
|
||||
"Security Regression: Empty workspace allowed access! content: %s",
|
||||
result.ForLLM,
|
||||
)
|
||||
assert.True(t, result.IsError, "Security Regression: Empty workspace allowed access! content: %s", result.ForLLM)
|
||||
|
||||
// Verify it failed for the right reason
|
||||
assert.Contains(
|
||||
t,
|
||||
result.ForLLM,
|
||||
"workspace is not defined",
|
||||
"Expected 'workspace is not defined' error",
|
||||
)
|
||||
assert.Contains(t, result.ForLLM, "workspace is not defined", "Expected 'workspace is not defined' error")
|
||||
}
|
||||
|
||||
// TestRootMkdirAll verifies that root.MkdirAll (used by atomicWriteFileInRoot) handles all cases:
|
||||
@@ -681,10 +653,7 @@ func TestWhitelistFs_BlocksSymlinkEscapeInAllowedDir(t *testing.T) {
|
||||
patterns := []*regexp.Regexp{regexp.MustCompile(`^` + regexp.QuoteMeta(allowedDir))}
|
||||
tool := NewReadFileTool(workspace, true, MaxReadFileSize, patterns)
|
||||
|
||||
result := tool.Execute(
|
||||
context.Background(),
|
||||
map[string]any{"path": filepath.Join(linkPath, "secret.txt")},
|
||||
)
|
||||
result := tool.Execute(context.Background(), map[string]any{"path": filepath.Join(linkPath, "secret.txt")})
|
||||
if !result.IsError {
|
||||
t.Fatalf("expected symlink escape from allowed dir to be blocked, got: %s", result.ForLLM)
|
||||
}
|
||||
|
||||
+2
-6
@@ -65,9 +65,7 @@ func (t *I2CTool) Parameters() map[string]any {
|
||||
|
||||
func (t *I2CTool) Execute(ctx context.Context, args map[string]any) *ToolResult {
|
||||
if runtime.GOOS != "linux" {
|
||||
return ErrorResult(
|
||||
"I2C is only supported on Linux. This tool requires /dev/i2c-* device files.",
|
||||
)
|
||||
return ErrorResult("I2C is only supported on Linux. This tool requires /dev/i2c-* device files.")
|
||||
}
|
||||
|
||||
action, ok := args["action"].(string)
|
||||
@@ -85,9 +83,7 @@ func (t *I2CTool) Execute(ctx context.Context, args map[string]any) *ToolResult
|
||||
case "write":
|
||||
return t.writeDevice(args)
|
||||
default:
|
||||
return ErrorResult(
|
||||
fmt.Sprintf("unknown action: %s (valid: detect, scan, read, write)", action),
|
||||
)
|
||||
return ErrorResult(fmt.Sprintf("unknown action: %s (valid: detect, scan, read, write)", action))
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
+7
-35
@@ -55,12 +55,7 @@ func smbusProbe(fd int, addr int, hasQuick bool) bool {
|
||||
size: i2cSmbusQuick,
|
||||
data: nil,
|
||||
}
|
||||
_, _, errno := syscall.Syscall(
|
||||
syscall.SYS_IOCTL,
|
||||
uintptr(fd),
|
||||
i2cSmbus,
|
||||
uintptr(unsafe.Pointer(&args)),
|
||||
)
|
||||
_, _, errno := syscall.Syscall(syscall.SYS_IOCTL, uintptr(fd), i2cSmbus, uintptr(unsafe.Pointer(&args)))
|
||||
return errno == 0
|
||||
}
|
||||
|
||||
@@ -72,12 +67,7 @@ func smbusProbe(fd int, addr int, hasQuick bool) bool {
|
||||
size: i2cSmbusByte,
|
||||
data: &data,
|
||||
}
|
||||
_, _, errno := syscall.Syscall(
|
||||
syscall.SYS_IOCTL,
|
||||
uintptr(fd),
|
||||
i2cSmbus,
|
||||
uintptr(unsafe.Pointer(&args)),
|
||||
)
|
||||
_, _, errno := syscall.Syscall(syscall.SYS_IOCTL, uintptr(fd), i2cSmbus, uintptr(unsafe.Pointer(&args)))
|
||||
return errno == 0
|
||||
}
|
||||
|
||||
@@ -93,29 +83,16 @@ func (t *I2CTool) scan(args map[string]any) *ToolResult {
|
||||
devPath := fmt.Sprintf("/dev/i2c-%s", bus)
|
||||
fd, err := syscall.Open(devPath, syscall.O_RDWR, 0)
|
||||
if err != nil {
|
||||
return ErrorResult(
|
||||
fmt.Sprintf(
|
||||
"failed to open %s: %v (check permissions and i2c-dev module)",
|
||||
devPath,
|
||||
err,
|
||||
),
|
||||
)
|
||||
return ErrorResult(fmt.Sprintf("failed to open %s: %v (check permissions and i2c-dev module)", devPath, err))
|
||||
}
|
||||
defer syscall.Close(fd)
|
||||
|
||||
// Query adapter capabilities to determine available probe methods.
|
||||
// I2C_FUNCS writes an unsigned long, which is word-sized on Linux.
|
||||
var funcs uintptr
|
||||
_, _, errno := syscall.Syscall(
|
||||
syscall.SYS_IOCTL,
|
||||
uintptr(fd),
|
||||
i2cFuncs,
|
||||
uintptr(unsafe.Pointer(&funcs)),
|
||||
)
|
||||
_, _, errno := syscall.Syscall(syscall.SYS_IOCTL, uintptr(fd), i2cFuncs, uintptr(unsafe.Pointer(&funcs)))
|
||||
if errno != 0 {
|
||||
return ErrorResult(
|
||||
fmt.Sprintf("failed to query I2C adapter capabilities on %s: %v", devPath, errno),
|
||||
)
|
||||
return ErrorResult(fmt.Sprintf("failed to query I2C adapter capabilities on %s: %v", devPath, errno))
|
||||
}
|
||||
|
||||
hasQuick := funcs&i2cFuncSmbusQuick != 0
|
||||
@@ -123,10 +100,7 @@ func (t *I2CTool) scan(args map[string]any) *ToolResult {
|
||||
|
||||
if !hasQuick && !hasReadByte {
|
||||
return ErrorResult(
|
||||
fmt.Sprintf(
|
||||
"I2C adapter %s supports neither SMBus Quick nor Read Byte — cannot probe safely",
|
||||
devPath,
|
||||
),
|
||||
fmt.Sprintf("I2C adapter %s supports neither SMBus Quick nor Read Byte — cannot probe safely", devPath),
|
||||
)
|
||||
}
|
||||
|
||||
@@ -158,9 +132,7 @@ func (t *I2CTool) scan(args map[string]any) *ToolResult {
|
||||
}
|
||||
|
||||
if len(found) == 0 {
|
||||
return SilentResult(
|
||||
fmt.Sprintf("No devices found on %s. Check wiring and pull-up resistors.", devPath),
|
||||
)
|
||||
return SilentResult(fmt.Sprintf("No devices found on %s. Check wiring and pull-up resistors.", devPath))
|
||||
}
|
||||
|
||||
result, _ := json.MarshalIndent(map[string]any{
|
||||
|
||||
+6
-28
@@ -314,10 +314,7 @@ func (t *MCPTool) normalizeResultContent(ctx context.Context, content []mcp.Cont
|
||||
return result
|
||||
}
|
||||
|
||||
func (t *MCPTool) storeEmbeddedResource(
|
||||
ctx context.Context,
|
||||
content *mcp.EmbeddedResource,
|
||||
) (string, string) {
|
||||
func (t *MCPTool) storeEmbeddedResource(ctx context.Context, content *mcp.EmbeddedResource) (string, string) {
|
||||
if content == nil || content.Resource == nil {
|
||||
return "", "[MCP returned an embedded resource without data.]"
|
||||
}
|
||||
@@ -377,39 +374,23 @@ func (t *MCPTool) storeBinaryContent(
|
||||
|
||||
dir := media.TempDir()
|
||||
if err := os.MkdirAll(dir, 0o700); err != nil {
|
||||
return "", fmt.Sprintf(
|
||||
"[MCP returned %s content (%s) but it could not be stored.]",
|
||||
kind,
|
||||
mimeType,
|
||||
)
|
||||
return "", fmt.Sprintf("[MCP returned %s content (%s) but it could not be stored.]", kind, mimeType)
|
||||
}
|
||||
|
||||
ext := extensionForMIMEType(mimeType)
|
||||
tmpFile, err := os.CreateTemp(dir, "mcp-*"+ext)
|
||||
if err != nil {
|
||||
return "", fmt.Sprintf(
|
||||
"[MCP returned %s content (%s) but it could not be stored.]",
|
||||
kind,
|
||||
mimeType,
|
||||
)
|
||||
return "", fmt.Sprintf("[MCP returned %s content (%s) but it could not be stored.]", kind, mimeType)
|
||||
}
|
||||
tmpPath := tmpFile.Name()
|
||||
if _, err = tmpFile.Write(data); err != nil {
|
||||
_ = tmpFile.Close()
|
||||
_ = os.Remove(tmpPath)
|
||||
return "", fmt.Sprintf(
|
||||
"[MCP returned %s content (%s) but it could not be stored.]",
|
||||
kind,
|
||||
mimeType,
|
||||
)
|
||||
return "", fmt.Sprintf("[MCP returned %s content (%s) but it could not be stored.]", kind, mimeType)
|
||||
}
|
||||
if err = tmpFile.Close(); err != nil {
|
||||
_ = os.Remove(tmpPath)
|
||||
return "", fmt.Sprintf(
|
||||
"[MCP returned %s content (%s) but it could not be stored.]",
|
||||
kind,
|
||||
mimeType,
|
||||
)
|
||||
return "", fmt.Sprintf("[MCP returned %s content (%s) but it could not be stored.]", kind, mimeType)
|
||||
}
|
||||
|
||||
scope := fmt.Sprintf(
|
||||
@@ -489,10 +470,7 @@ func summarizeEmbeddedResource(content *mcp.EmbeddedResource) string {
|
||||
normalizedMIMEType(resource.MIMEType),
|
||||
)
|
||||
}
|
||||
return fmt.Sprintf(
|
||||
"[MCP returned embedded resource (%s).]",
|
||||
normalizedMIMEType(resource.MIMEType),
|
||||
)
|
||||
return fmt.Sprintf("[MCP returned embedded resource (%s).]", normalizedMIMEType(resource.MIMEType))
|
||||
}
|
||||
|
||||
func annotationsAllowUser(annotations *mcp.Annotations) bool {
|
||||
|
||||
@@ -571,10 +571,7 @@ func TestMCPTool_Execute_EmbeddedResourceBlobStoredAsMedia(t *testing.T) {
|
||||
result := mcpTool.Execute(WithToolContext(context.Background(), "telegram", "chat-42"), nil)
|
||||
|
||||
if len(result.Media) != 1 {
|
||||
t.Fatalf(
|
||||
"expected embedded resource blob to be stored as media, got %d refs",
|
||||
len(result.Media),
|
||||
)
|
||||
t.Fatalf("expected embedded resource blob to be stored as media, got %d refs", len(result.Media))
|
||||
}
|
||||
path, _, err := store.ResolveWithMeta(result.Media[0])
|
||||
if err != nil {
|
||||
|
||||
@@ -43,10 +43,7 @@ func TestMessageTool_Execute_Success(t *testing.T) {
|
||||
|
||||
// - ForLLM contains send status description
|
||||
if result.ForLLM != "Message sent to test-channel:test-chat-id" {
|
||||
t.Errorf(
|
||||
"Expected ForLLM 'Message sent to test-channel:test-chat-id', got '%s'",
|
||||
result.ForLLM,
|
||||
)
|
||||
t.Errorf("Expected ForLLM 'Message sent to test-channel:test-chat-id', got '%s'", result.ForLLM)
|
||||
}
|
||||
|
||||
// - ForUser is empty (user already received message directly)
|
||||
@@ -91,10 +88,7 @@ func TestMessageTool_Execute_WithCustomChannel(t *testing.T) {
|
||||
t.Error("Expected Silent=true")
|
||||
}
|
||||
if result.ForLLM != "Message sent to custom-channel:custom-chat-id" {
|
||||
t.Errorf(
|
||||
"Expected ForLLM 'Message sent to custom-channel:custom-chat-id', got '%s'",
|
||||
result.ForLLM,
|
||||
)
|
||||
t.Errorf("Expected ForLLM 'Message sent to custom-channel:custom-chat-id', got '%s'", result.ForLLM)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -215,43 +215,28 @@ func storeInlineDataURL(
|
||||
payload = strings.NewReplacer("\n", "", "\r", "", "\t", "", " ", "").Replace(payload)
|
||||
decoded, err := base64.StdEncoding.DecodeString(payload)
|
||||
if err != nil {
|
||||
return "", fmt.Sprintf(
|
||||
"[Tool returned inline media content (%s) that could not be decoded.]",
|
||||
mimeType,
|
||||
)
|
||||
return "", fmt.Sprintf("[Tool returned inline media content (%s) that could not be decoded.]", mimeType)
|
||||
}
|
||||
|
||||
dir := media.TempDir()
|
||||
if err = os.MkdirAll(dir, 0o700); err != nil {
|
||||
return "", fmt.Sprintf(
|
||||
"[Tool returned inline media content (%s) but it could not be stored.]",
|
||||
mimeType,
|
||||
)
|
||||
return "", fmt.Sprintf("[Tool returned inline media content (%s) but it could not be stored.]", mimeType)
|
||||
}
|
||||
|
||||
ext := extensionForMIMEType(mimeType)
|
||||
tmpFile, err := os.CreateTemp(dir, "tool-inline-*"+ext)
|
||||
if err != nil {
|
||||
return "", fmt.Sprintf(
|
||||
"[Tool returned inline media content (%s) but it could not be stored.]",
|
||||
mimeType,
|
||||
)
|
||||
return "", fmt.Sprintf("[Tool returned inline media content (%s) but it could not be stored.]", mimeType)
|
||||
}
|
||||
tmpPath := tmpFile.Name()
|
||||
if _, err = tmpFile.Write(decoded); err != nil {
|
||||
tmpFile.Close()
|
||||
_ = os.Remove(tmpPath)
|
||||
return "", fmt.Sprintf(
|
||||
"[Tool returned inline media content (%s) but it could not be stored.]",
|
||||
mimeType,
|
||||
)
|
||||
return "", fmt.Sprintf("[Tool returned inline media content (%s) but it could not be stored.]", mimeType)
|
||||
}
|
||||
if err = tmpFile.Close(); err != nil {
|
||||
_ = os.Remove(tmpPath)
|
||||
return "", fmt.Sprintf(
|
||||
"[Tool returned inline media content (%s) but it could not be stored.]",
|
||||
mimeType,
|
||||
)
|
||||
return "", fmt.Sprintf("[Tool returned inline media content (%s) but it could not be stored.]", mimeType)
|
||||
}
|
||||
|
||||
filename := sanitizeIdentifierComponent(toolName) + ext
|
||||
@@ -270,10 +255,7 @@ func storeInlineDataURL(
|
||||
}, scope)
|
||||
if err != nil {
|
||||
_ = os.Remove(tmpPath)
|
||||
return "", fmt.Sprintf(
|
||||
"[Tool returned inline media content (%s) but it could not be registered.]",
|
||||
mimeType,
|
||||
)
|
||||
return "", fmt.Sprintf("[Tool returned inline media content (%s) but it could not be registered.]", mimeType)
|
||||
}
|
||||
|
||||
return ref, fmt.Sprintf(inlineMediaStoredMessage, mimeType)
|
||||
|
||||
+1
-4
@@ -80,10 +80,7 @@ func (tr *ToolResult) ContentForLLM() string {
|
||||
}
|
||||
}
|
||||
if len(tr.ArtifactTags) > 0 {
|
||||
artifactNote := "Local artifact paths: " + strings.Join(
|
||||
tr.ArtifactTags,
|
||||
" ",
|
||||
) + "\n" + artifactPathsLLMNote
|
||||
artifactNote := "Local artifact paths: " + strings.Join(tr.ArtifactTags, " ") + "\n" + artifactPathsLLMNote
|
||||
if content == "" {
|
||||
content = artifactNote
|
||||
} else if !strings.Contains(content, artifactNote) {
|
||||
|
||||
@@ -142,11 +142,7 @@ func TestToolResultJSONSerialization(t *testing.T) {
|
||||
t.Errorf("ForLLM mismatch: got '%s', want '%s'", decoded.ForLLM, tt.result.ForLLM)
|
||||
}
|
||||
if decoded.ForUser != tt.result.ForUser {
|
||||
t.Errorf(
|
||||
"ForUser mismatch: got '%s', want '%s'",
|
||||
decoded.ForUser,
|
||||
tt.result.ForUser,
|
||||
)
|
||||
t.Errorf("ForUser mismatch: got '%s', want '%s'", decoded.ForUser, tt.result.ForUser)
|
||||
}
|
||||
if decoded.Silent != tt.result.Silent {
|
||||
t.Errorf("Silent mismatch: got %v, want %v", decoded.Silent, tt.result.Silent)
|
||||
|
||||
@@ -56,38 +56,19 @@ func (t *RegexSearchTool) Execute(ctx context.Context, args map[string]any) *Too
|
||||
}
|
||||
|
||||
if len(pattern) > MaxRegexPatternLength {
|
||||
logger.WarnCF(
|
||||
"discovery",
|
||||
"Regex pattern rejected (too long)",
|
||||
map[string]any{"len": len(pattern)},
|
||||
)
|
||||
return ErrorResult(
|
||||
fmt.Sprintf("Pattern too long: max %d characters allowed", MaxRegexPatternLength),
|
||||
)
|
||||
logger.WarnCF("discovery", "Regex pattern rejected (too long)", map[string]any{"len": len(pattern)})
|
||||
return ErrorResult(fmt.Sprintf("Pattern too long: max %d characters allowed", MaxRegexPatternLength))
|
||||
}
|
||||
|
||||
logger.DebugCF("discovery", "Regex search", map[string]any{"pattern": pattern})
|
||||
|
||||
res, err := t.registry.SearchRegex(pattern, t.maxSearchResults)
|
||||
if err != nil {
|
||||
logger.WarnCF(
|
||||
"discovery",
|
||||
"Invalid regex pattern",
|
||||
map[string]any{"pattern": pattern, "error": err.Error()},
|
||||
)
|
||||
return ErrorResult(
|
||||
fmt.Sprintf(
|
||||
"Invalid regex pattern syntax: %v. Please fix your regex and try again.",
|
||||
err,
|
||||
),
|
||||
)
|
||||
logger.WarnCF("discovery", "Invalid regex pattern", map[string]any{"pattern": pattern, "error": err.Error()})
|
||||
return ErrorResult(fmt.Sprintf("Invalid regex pattern syntax: %v. Please fix your regex and try again.", err))
|
||||
}
|
||||
|
||||
logger.InfoCF(
|
||||
"discovery",
|
||||
"Regex search completed",
|
||||
map[string]any{"pattern": pattern, "results": len(res)},
|
||||
)
|
||||
logger.InfoCF("discovery", "Regex search completed", map[string]any{"pattern": pattern, "results": len(res)})
|
||||
return formatDiscoveryResponse(t.registry, res, t.ttl)
|
||||
}
|
||||
|
||||
@@ -157,11 +138,7 @@ func (t *BM25SearchTool) Execute(ctx context.Context, args map[string]any) *Tool
|
||||
}
|
||||
}
|
||||
|
||||
logger.InfoCF(
|
||||
"discovery",
|
||||
"BM25 search completed",
|
||||
map[string]any{"query": query, "results": len(results)},
|
||||
)
|
||||
logger.InfoCF("discovery", "BM25 search completed", map[string]any{"query": query, "results": len(results)})
|
||||
return formatDiscoveryResponse(t.registry, results, t.ttl)
|
||||
}
|
||||
|
||||
@@ -173,10 +150,7 @@ type ToolSearchResult struct {
|
||||
Description string `json:"description"`
|
||||
}
|
||||
|
||||
func (r *ToolRegistry) SearchRegex(
|
||||
pattern string,
|
||||
maxSearchResults int,
|
||||
) ([]ToolSearchResult, error) {
|
||||
func (r *ToolRegistry) SearchRegex(pattern string, maxSearchResults int) ([]ToolSearchResult, error) {
|
||||
if maxSearchResults <= 0 {
|
||||
return nil, nil
|
||||
}
|
||||
@@ -214,11 +188,7 @@ func (r *ToolRegistry) SearchRegex(
|
||||
return results, nil
|
||||
}
|
||||
|
||||
func formatDiscoveryResponse(
|
||||
registry *ToolRegistry,
|
||||
results []ToolSearchResult,
|
||||
ttl int,
|
||||
) *ToolResult {
|
||||
func formatDiscoveryResponse(registry *ToolRegistry, results []ToolSearchResult, ttl int) *ToolResult {
|
||||
if len(results) == 0 {
|
||||
return SilentResult("No tools found matching the query.")
|
||||
}
|
||||
@@ -304,11 +274,7 @@ func (t *BM25SearchTool) getOrBuildEngine() *bm25CachedEngine {
|
||||
cached := &bm25CachedEngine{engine: buildBM25Engine(docs)}
|
||||
t.cachedEngine = cached
|
||||
t.cacheVersion = snap.Version
|
||||
logger.DebugCF(
|
||||
"discovery",
|
||||
"BM25 engine rebuilt",
|
||||
map[string]any{"docs": len(docs), "version": snap.Version},
|
||||
)
|
||||
logger.DebugCF("discovery", "BM25 engine rebuilt", map[string]any{"docs": len(docs), "version": snap.Version})
|
||||
return cached
|
||||
}
|
||||
|
||||
|
||||
@@ -93,10 +93,7 @@ func TestRegexSearchTool_Execute(t *testing.T) {
|
||||
reg.mu.RLock()
|
||||
defer reg.mu.RUnlock()
|
||||
if reg.tools["mcp_read_file"].TTL != 5 {
|
||||
t.Errorf(
|
||||
"Expected TTL of 'mcp_read_file' to be promoted to 5, got %d",
|
||||
reg.tools["mcp_read_file"].TTL,
|
||||
)
|
||||
t.Errorf("Expected TTL of 'mcp_read_file' to be promoted to 5, got %d", reg.tools["mcp_read_file"].TTL)
|
||||
}
|
||||
if reg.tools["mcp_fetch_net"].TTL != 0 {
|
||||
t.Errorf("Expected 'mcp_fetch_net' to NOT be promoted (TTL=0)")
|
||||
|
||||
@@ -142,10 +142,7 @@ func (t *SendFileTool) Execute(ctx context.Context, args map[string]any) *ToolRe
|
||||
return ErrorResult(fmt.Sprintf("failed to register media: %v", err))
|
||||
}
|
||||
|
||||
return MediaResult(
|
||||
fmt.Sprintf("File %q sent to user", filename),
|
||||
[]string{ref},
|
||||
).WithResponseHandled()
|
||||
return MediaResult(fmt.Sprintf("File %q sent to user", filename), []string{ref}).WithResponseHandled()
|
||||
}
|
||||
|
||||
// detectMediaType determines the MIME type of a file.
|
||||
|
||||
@@ -79,11 +79,7 @@ func TestSendFileTool_FileTooLarge(t *testing.T) {
|
||||
func TestSendFileTool_DefaultMaxSize(t *testing.T) {
|
||||
tool := NewSendFileTool("/tmp", false, 0, nil)
|
||||
if tool.maxFileSize != config.DefaultMaxMediaSize {
|
||||
t.Errorf(
|
||||
"expected default max size %d, got %d",
|
||||
config.DefaultMaxMediaSize,
|
||||
tool.maxFileSize,
|
||||
)
|
||||
t.Errorf("expected default max size %d, got %d", config.DefaultMaxMediaSize, tool.maxFileSize)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -166,11 +162,7 @@ func TestSendFileTool_AllowsWhitelistedMediaTempPath(t *testing.T) {
|
||||
t.Cleanup(func() { _ = os.Remove(testPath) })
|
||||
|
||||
pattern := regexp.MustCompile(
|
||||
"^" + regexp.QuoteMeta(
|
||||
filepath.Clean(mediaDir),
|
||||
) + "(?:" + regexp.QuoteMeta(
|
||||
string(os.PathSeparator),
|
||||
) + "|$)",
|
||||
"^" + regexp.QuoteMeta(filepath.Clean(mediaDir)) + "(?:" + regexp.QuoteMeta(string(os.PathSeparator)) + "|$)",
|
||||
)
|
||||
|
||||
store := media.NewFileMediaStore()
|
||||
|
||||
+14
-58
@@ -113,11 +113,7 @@ var (
|
||||
}
|
||||
)
|
||||
|
||||
func NewExecTool(
|
||||
workingDir string,
|
||||
restrict bool,
|
||||
allowPaths ...[]*regexp.Regexp,
|
||||
) (*ExecTool, error) {
|
||||
func NewExecTool(workingDir string, restrict bool, allowPaths ...[]*regexp.Regexp) (*ExecTool, error) {
|
||||
return NewExecToolWithConfig(workingDir, restrict, nil, allowPaths...)
|
||||
}
|
||||
|
||||
@@ -197,16 +193,8 @@ func (t *ExecTool) Parameters() map[string]any {
|
||||
"type": "object",
|
||||
"properties": map[string]any{
|
||||
"action": map[string]any{
|
||||
"type": "string",
|
||||
"enum": []string{
|
||||
"run",
|
||||
"list",
|
||||
"poll",
|
||||
"read",
|
||||
"write",
|
||||
"kill",
|
||||
"send-keys",
|
||||
},
|
||||
"type": "string",
|
||||
"enum": []string{"run", "list", "poll", "read", "write", "kill", "send-keys"},
|
||||
"description": "Action: run (execute command), list (show sessions), poll (check status), read (get output), write (send input), kill (terminate), send-keys (send keys to PTY)",
|
||||
},
|
||||
"command": map[string]any{
|
||||
@@ -312,12 +300,7 @@ func (t *ExecTool) executeRun(ctx context.Context, args map[string]any) *ToolRes
|
||||
cwd := t.workingDir
|
||||
if wd, ok := args["cwd"].(string); ok && wd != "" {
|
||||
if t.restrictToWorkspace && t.workingDir != "" {
|
||||
resolvedWD, err := validatePathWithAllowPaths(
|
||||
wd,
|
||||
t.workingDir,
|
||||
true,
|
||||
t.allowedPathPatterns,
|
||||
)
|
||||
resolvedWD, err := validatePathWithAllowPaths(wd, t.workingDir, true, t.allowedPathPatterns)
|
||||
if err != nil {
|
||||
return ErrorResult("Command blocked by safety guard (" + err.Error() + ")")
|
||||
}
|
||||
@@ -343,9 +326,7 @@ func (t *ExecTool) executeRun(ctx context.Context, args map[string]any) *ToolRes
|
||||
if t.restrictToWorkspace && t.workingDir != "" && cwd != t.workingDir {
|
||||
resolved, err := filepath.EvalSymlinks(cwd)
|
||||
if err != nil {
|
||||
return ErrorResult(
|
||||
fmt.Sprintf("Command blocked by safety guard (path resolution failed: %v)", err),
|
||||
)
|
||||
return ErrorResult(fmt.Sprintf("Command blocked by safety guard (path resolution failed: %v)", err))
|
||||
}
|
||||
if isAllowedPath(resolved, t.allowedPathPatterns) {
|
||||
cwd = resolved
|
||||
@@ -383,14 +364,7 @@ func (t *ExecTool) runSync(ctx context.Context, command, cwd string) *ToolResult
|
||||
|
||||
var cmd *exec.Cmd
|
||||
if runtime.GOOS == "windows" {
|
||||
cmd = exec.CommandContext(
|
||||
cmdCtx,
|
||||
"powershell",
|
||||
"-NoProfile",
|
||||
"-NonInteractive",
|
||||
"-Command",
|
||||
command,
|
||||
)
|
||||
cmd = exec.CommandContext(cmdCtx, "powershell", "-NoProfile", "-NonInteractive", "-Command", command)
|
||||
} else {
|
||||
cmd = exec.CommandContext(cmdCtx, "sh", "-c", command)
|
||||
}
|
||||
@@ -468,10 +442,7 @@ func (t *ExecTool) runSync(ctx context.Context, command, cwd string) *ToolResult
|
||||
|
||||
maxLen := 10000
|
||||
if len(output) > maxLen {
|
||||
output = output[:maxLen] + fmt.Sprintf(
|
||||
"\n... (truncated, %d more chars)",
|
||||
len(output)-maxLen,
|
||||
)
|
||||
output = output[:maxLen] + fmt.Sprintf("\n... (truncated, %d more chars)", len(output)-maxLen)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
@@ -489,11 +460,7 @@ func (t *ExecTool) runSync(ctx context.Context, command, cwd string) *ToolResult
|
||||
}
|
||||
}
|
||||
|
||||
func (t *ExecTool) runBackground(
|
||||
ctx context.Context,
|
||||
command, cwd string,
|
||||
ptyEnabled bool,
|
||||
) *ToolResult {
|
||||
func (t *ExecTool) runBackground(ctx context.Context, command, cwd string, ptyEnabled bool) *ToolResult {
|
||||
sessionID := generateSessionID()
|
||||
session := &ProcessSession{
|
||||
ID: sessionID,
|
||||
@@ -586,8 +553,7 @@ func (t *ExecTool) runBackground(
|
||||
n, err := session.ptyMaster.Read(buf)
|
||||
if n > 0 {
|
||||
raw := string(buf[:n])
|
||||
if mode := detectPtyKeyMode(raw); mode != PtyKeyModeNotFound &&
|
||||
mode != session.GetPtyKeyMode() {
|
||||
if mode := detectPtyKeyMode(raw); mode != PtyKeyModeNotFound && mode != session.GetPtyKeyMode() {
|
||||
session.SetPtyKeyMode(mode)
|
||||
}
|
||||
|
||||
@@ -768,16 +734,12 @@ func (t *ExecTool) executeWrite(args map[string]any) *ToolResult {
|
||||
}
|
||||
|
||||
if session.IsDone() {
|
||||
return ErrorResult(
|
||||
fmt.Sprintf("process already exited with code %d", session.GetExitCode()),
|
||||
)
|
||||
return ErrorResult(fmt.Sprintf("process already exited with code %d", session.GetExitCode()))
|
||||
}
|
||||
|
||||
if err := session.Write(data); err != nil {
|
||||
if errors.Is(err, ErrSessionDone) {
|
||||
return ErrorResult(
|
||||
fmt.Sprintf("process already exited with code %d", session.GetExitCode()),
|
||||
)
|
||||
return ErrorResult(fmt.Sprintf("process already exited with code %d", session.GetExitCode()))
|
||||
}
|
||||
return ErrorResult(fmt.Sprintf("failed to write to session: %v", err))
|
||||
}
|
||||
@@ -808,9 +770,7 @@ func (t *ExecTool) executeKill(args map[string]any) *ToolResult {
|
||||
}
|
||||
|
||||
if session.IsDone() {
|
||||
return ErrorResult(
|
||||
fmt.Sprintf("process already exited with code %d", session.GetExitCode()),
|
||||
)
|
||||
return ErrorResult(fmt.Sprintf("process already exited with code %d", session.GetExitCode()))
|
||||
}
|
||||
|
||||
if err := session.Kill(); err != nil {
|
||||
@@ -1032,16 +992,12 @@ func (t *ExecTool) executeSendKeys(args map[string]any) *ToolResult {
|
||||
}
|
||||
|
||||
if session.IsDone() {
|
||||
return ErrorResult(
|
||||
fmt.Sprintf("process already exited with code %d", session.GetExitCode()),
|
||||
)
|
||||
return ErrorResult(fmt.Sprintf("process already exited with code %d", session.GetExitCode()))
|
||||
}
|
||||
|
||||
if err := session.Write(data); err != nil {
|
||||
if errors.Is(err, ErrSessionDone) {
|
||||
return ErrorResult(
|
||||
fmt.Sprintf("process already exited with code %d", session.GetExitCode()),
|
||||
)
|
||||
return ErrorResult(fmt.Sprintf("process already exited with code %d", session.GetExitCode()))
|
||||
}
|
||||
return ErrorResult(fmt.Sprintf("failed to send keys: %v", err))
|
||||
}
|
||||
|
||||
+18
-77
@@ -100,13 +100,8 @@ func TestShellTool_Timeout(t *testing.T) {
|
||||
}
|
||||
|
||||
// Should mention timeout
|
||||
if !strings.Contains(result.ForLLM, "timed out") &&
|
||||
!strings.Contains(result.ForUser, "timed out") {
|
||||
t.Errorf(
|
||||
"Expected timeout message, got ForLLM: %s, ForUser: %s",
|
||||
result.ForLLM,
|
||||
result.ForUser,
|
||||
)
|
||||
if !strings.Contains(result.ForLLM, "timed out") && !strings.Contains(result.ForUser, "timed out") {
|
||||
t.Errorf("Expected timeout message, got ForLLM: %s, ForUser: %s", result.ForLLM, result.ForUser)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -161,11 +156,7 @@ func TestShellTool_DangerousCommand(t *testing.T) {
|
||||
}
|
||||
|
||||
if !strings.Contains(result.ForLLM, "blocked") && !strings.Contains(result.ForUser, "blocked") {
|
||||
t.Errorf(
|
||||
"Expected 'blocked' message, got ForLLM: %s, ForUser: %s",
|
||||
result.ForLLM,
|
||||
result.ForUser,
|
||||
)
|
||||
t.Errorf("Expected 'blocked' message, got ForLLM: %s, ForUser: %s", result.ForLLM, result.ForUser)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -186,11 +177,7 @@ func TestShellTool_DangerousCommand_KillBlocked(t *testing.T) {
|
||||
t.Errorf("Expected kill command to be blocked")
|
||||
}
|
||||
if !strings.Contains(result.ForLLM, "blocked") && !strings.Contains(result.ForUser, "blocked") {
|
||||
t.Errorf(
|
||||
"Expected blocked message, got ForLLM: %s, ForUser: %s",
|
||||
result.ForLLM,
|
||||
result.ForUser,
|
||||
)
|
||||
t.Errorf("Expected blocked message, got ForLLM: %s, ForUser: %s", result.ForLLM, result.ForUser)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -282,10 +269,7 @@ func TestShellTool_WorkingDir_OutsideWorkspace(t *testing.T) {
|
||||
})
|
||||
|
||||
if !result.IsError {
|
||||
t.Fatalf(
|
||||
"expected working_dir outside workspace to be blocked, got output: %s",
|
||||
result.ForLLM,
|
||||
)
|
||||
t.Fatalf("expected working_dir outside workspace to be blocked, got output: %s", result.ForLLM)
|
||||
}
|
||||
if !strings.Contains(result.ForLLM, "blocked") {
|
||||
t.Errorf("expected 'blocked' in error, got: %s", result.ForLLM)
|
||||
@@ -460,10 +444,7 @@ func TestShellTool_DevNullAllowed(t *testing.T) {
|
||||
}
|
||||
|
||||
for _, cmd := range commands {
|
||||
result := tool.Execute(
|
||||
context.Background(),
|
||||
map[string]any{"action": "run", "command": cmd},
|
||||
)
|
||||
result := tool.Execute(context.Background(), map[string]any{"action": "run", "command": cmd})
|
||||
if result.IsError && strings.Contains(result.ForLLM, "blocked") {
|
||||
t.Errorf("command should not be blocked: %s\n error: %s", cmd, result.ForLLM)
|
||||
}
|
||||
@@ -492,10 +473,7 @@ func TestShellTool_BlockDevices(t *testing.T) {
|
||||
}
|
||||
|
||||
for _, cmd := range blocked {
|
||||
result := tool.Execute(
|
||||
context.Background(),
|
||||
map[string]any{"action": "run", "command": cmd},
|
||||
)
|
||||
result := tool.Execute(context.Background(), map[string]any{"action": "run", "command": cmd})
|
||||
if !result.IsError {
|
||||
t.Errorf("expected block device write to be blocked: %s", cmd)
|
||||
}
|
||||
@@ -519,16 +497,9 @@ func TestShellTool_SafePathsInWorkspaceRestriction(t *testing.T) {
|
||||
}
|
||||
|
||||
for _, cmd := range commands {
|
||||
result := tool.Execute(
|
||||
context.Background(),
|
||||
map[string]any{"action": "run", "command": cmd},
|
||||
)
|
||||
result := tool.Execute(context.Background(), map[string]any{"action": "run", "command": cmd})
|
||||
if result.IsError && strings.Contains(result.ForLLM, "path outside working dir") {
|
||||
t.Errorf(
|
||||
"safe path should not be blocked by workspace check: %s\n error: %s",
|
||||
cmd,
|
||||
result.ForLLM,
|
||||
)
|
||||
t.Errorf("safe path should not be blocked by workspace check: %s\n error: %s", cmd, result.ForLLM)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -620,10 +591,7 @@ func TestShellTool_CustomAllowPatterns(t *testing.T) {
|
||||
"command": "git push origin main",
|
||||
})
|
||||
if result.IsError && strings.Contains(result.ForLLM, "blocked") {
|
||||
t.Errorf(
|
||||
"custom allow pattern should exempt 'git push origin main', got: %s",
|
||||
result.ForLLM,
|
||||
)
|
||||
t.Errorf("custom allow pattern should exempt 'git push origin main', got: %s", result.ForLLM)
|
||||
}
|
||||
|
||||
// "git push upstream main" should still be blocked (does not match allow pattern).
|
||||
@@ -661,11 +629,7 @@ func TestShellTool_URLsNotBlocked(t *testing.T) {
|
||||
result := tool.Execute(ctx, map[string]any{"action": "run", "command": cmd})
|
||||
cancel()
|
||||
if result.IsError && strings.Contains(result.ForLLM, "path outside working dir") {
|
||||
t.Errorf(
|
||||
"command with URL should not be blocked by workspace check: %s\n error: %s",
|
||||
cmd,
|
||||
result.ForLLM,
|
||||
)
|
||||
t.Errorf("command with URL should not be blocked by workspace check: %s\n error: %s", cmd, result.ForLLM)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -688,10 +652,7 @@ func TestShellTool_FileURISandboxing(t *testing.T) {
|
||||
}
|
||||
|
||||
for _, cmd := range blockedCommands {
|
||||
result := tool.Execute(
|
||||
context.Background(),
|
||||
map[string]any{"action": "run", "command": cmd},
|
||||
)
|
||||
result := tool.Execute(context.Background(), map[string]any{"action": "run", "command": cmd})
|
||||
if !result.IsError || !strings.Contains(result.ForLLM, "path outside working dir") {
|
||||
t.Errorf("file:// URI outside workspace should be blocked: %s", cmd)
|
||||
}
|
||||
@@ -709,16 +670,9 @@ func TestShellTool_FileURISandboxing(t *testing.T) {
|
||||
}
|
||||
|
||||
for _, cmd := range allowedCommands {
|
||||
result := tool.Execute(
|
||||
context.Background(),
|
||||
map[string]any{"action": "run", "command": cmd},
|
||||
)
|
||||
result := tool.Execute(context.Background(), map[string]any{"action": "run", "command": cmd})
|
||||
if result.IsError && strings.Contains(result.ForLLM, "path outside working dir") {
|
||||
t.Errorf(
|
||||
"file:// URI inside workspace should be allowed: %s\n error: %s",
|
||||
cmd,
|
||||
result.ForLLM,
|
||||
)
|
||||
t.Errorf("file:// URI inside workspace should be allowed: %s\n error: %s", cmd, result.ForLLM)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -742,10 +696,7 @@ func TestShellTool_URLBypassPrevented(t *testing.T) {
|
||||
}
|
||||
|
||||
for _, cmd := range blockedCommands {
|
||||
result := tool.Execute(
|
||||
context.Background(),
|
||||
map[string]any{"action": "run", "command": cmd},
|
||||
)
|
||||
result := tool.Execute(context.Background(), map[string]any{"action": "run", "command": cmd})
|
||||
if !result.IsError || !strings.Contains(result.ForLLM, "path outside working dir") {
|
||||
t.Errorf("bypass attempt should be blocked: %q\n got: %s", cmd, result.ForLLM)
|
||||
}
|
||||
@@ -1270,9 +1221,7 @@ func TestShellTool_PTY_ProcessGroupKill(t *testing.T) {
|
||||
// The binary is created in /tmp/test_pgroup.c and compiled as part of test setup.
|
||||
testBinary := "/tmp/test_pgroup"
|
||||
if _, err := os.Stat(testBinary); os.IsNotExist(err) {
|
||||
t.Skip(
|
||||
"Test binary /tmp/test_pgroup not found - run: gcc -o /tmp/test_pgroup /tmp/test_pgroup.c",
|
||||
)
|
||||
t.Skip("Test binary /tmp/test_pgroup not found - run: gcc -o /tmp/test_pgroup /tmp/test_pgroup.c")
|
||||
}
|
||||
|
||||
tool, err := NewExecTool("", false)
|
||||
@@ -1606,16 +1555,8 @@ func TestDetectPtyKeyMode(t *testing.T) {
|
||||
{"rmkx only", "\x1b[?1l\x1b>", PtyKeyModeCSI},
|
||||
{"both smkx first", "\x1b[?1h\x1b=...\x1b[?1l\x1b>", PtyKeyModeCSI},
|
||||
{"both rmkx first", "\x1b[?1l\x1b>...\x1b[?1h\x1b=", PtyKeyModeSS3},
|
||||
{
|
||||
"multiple toggles smkx last",
|
||||
"\x1b[?1h\x1b=...\x1b[?1l\x1b>...\x1b[?1h\x1b=",
|
||||
PtyKeyModeSS3,
|
||||
},
|
||||
{
|
||||
"multiple toggles rmkx last",
|
||||
"\x1b[?1l\x1b>...\x1b[?1h\x1b=...\x1b[?1l\x1b>",
|
||||
PtyKeyModeCSI,
|
||||
},
|
||||
{"multiple toggles smkx last", "\x1b[?1h\x1b=...\x1b[?1l\x1b>...\x1b[?1h\x1b=", PtyKeyModeSS3},
|
||||
{"multiple toggles rmkx last", "\x1b[?1l\x1b>...\x1b[?1h\x1b=...\x1b[?1l\x1b>", PtyKeyModeCSI},
|
||||
{"partial smkx", "\x1b[?1h", PtyKeyModeSS3},
|
||||
{"partial rmkx", "\x1b[?1l", PtyKeyModeCSI},
|
||||
}
|
||||
|
||||
@@ -96,11 +96,7 @@ func (t *InstallSkillTool) Execute(ctx context.Context, args map[string]any) *To
|
||||
if !force {
|
||||
if _, err := os.Stat(targetDir); err == nil {
|
||||
return ErrorResult(
|
||||
fmt.Sprintf(
|
||||
"skill %q already installed at %s. Use force=true to reinstall.",
|
||||
slug,
|
||||
targetDir,
|
||||
),
|
||||
fmt.Sprintf("skill %q already installed at %s. Use force=true to reinstall.", slug, targetDir),
|
||||
)
|
||||
}
|
||||
} else {
|
||||
@@ -146,9 +142,7 @@ func (t *InstallSkillTool) Execute(ctx context.Context, args map[string]any) *To
|
||||
"error": rmErr.Error(),
|
||||
})
|
||||
}
|
||||
return ErrorResult(
|
||||
fmt.Sprintf("skill %q is flagged as malicious and cannot be installed", slug),
|
||||
)
|
||||
return ErrorResult(fmt.Sprintf("skill %q is flagged as malicious and cannot be installed", slug))
|
||||
}
|
||||
|
||||
// Write origin metadata.
|
||||
@@ -168,10 +162,7 @@ func (t *InstallSkillTool) Execute(ctx context.Context, args map[string]any) *To
|
||||
// Build result with moderation warning if suspicious.
|
||||
var output string
|
||||
if result.IsSuspicious {
|
||||
output = fmt.Sprintf(
|
||||
"⚠️ Warning: skill %q is flagged as suspicious (may contain risky patterns).\n\n",
|
||||
slug,
|
||||
)
|
||||
output = fmt.Sprintf("⚠️ Warning: skill %q is flagged as suspicious (may contain risky patterns).\n\n", slug)
|
||||
}
|
||||
output += fmt.Sprintf("Successfully installed skill %q v%s from %s registry.\nLocation: %s\n",
|
||||
slug, result.Version, registry.Name(), targetDir)
|
||||
|
||||
@@ -17,10 +17,7 @@ type FindSkillsTool struct {
|
||||
// NewFindSkillsTool creates a new FindSkillsTool.
|
||||
// registryMgr is the shared registry manager (built from config in createToolRegistry).
|
||||
// cache is the search cache for deduplicating similar queries.
|
||||
func NewFindSkillsTool(
|
||||
registryMgr *skills.RegistryManager,
|
||||
cache *skills.SearchCache,
|
||||
) *FindSkillsTool {
|
||||
func NewFindSkillsTool(registryMgr *skills.RegistryManager, cache *skills.SearchCache) *FindSkillsTool {
|
||||
return &FindSkillsTool{
|
||||
registryMgr: registryMgr,
|
||||
cache: cache,
|
||||
|
||||
@@ -77,12 +77,10 @@ func (t *SpawnStatusTool) Execute(ctx context.Context, args map[string]any) *Too
|
||||
}
|
||||
|
||||
// Restrict lookup to tasks that belong to this conversation.
|
||||
if callerChannel != "" && taskCopy.OriginChannel != "" &&
|
||||
taskCopy.OriginChannel != callerChannel {
|
||||
if callerChannel != "" && taskCopy.OriginChannel != "" && taskCopy.OriginChannel != callerChannel {
|
||||
return ErrorResult(fmt.Sprintf("No subagent found with task ID: %s", taskID))
|
||||
}
|
||||
if callerChatID != "" && taskCopy.OriginChatID != "" &&
|
||||
taskCopy.OriginChatID != callerChatID {
|
||||
if callerChatID != "" && taskCopy.OriginChatID != "" && taskCopy.OriginChatID != callerChatID {
|
||||
return ErrorResult(fmt.Sprintf("No subagent found with task ID: %s", taskID))
|
||||
}
|
||||
|
||||
|
||||
@@ -195,12 +195,7 @@ func TestSpawnStatusTool_TaskID_NonString(t *testing.T) {
|
||||
for _, badVal := range []any{42, 3.14, true, map[string]any{"x": 1}, []string{"a"}} {
|
||||
result := tool.Execute(context.Background(), map[string]any{"task_id": badVal})
|
||||
if !result.IsError {
|
||||
t.Errorf(
|
||||
"Expected error for task_id=%T(%v), got success: %s",
|
||||
badVal,
|
||||
badVal,
|
||||
result.ForLLM,
|
||||
)
|
||||
t.Errorf("Expected error for task_id=%T(%v), got success: %s", badVal, badVal, result.ForLLM)
|
||||
}
|
||||
if !strings.Contains(result.ForLLM, "task_id must be a string") {
|
||||
t.Errorf("Expected type-error message, got: %s", result.ForLLM)
|
||||
@@ -324,10 +319,7 @@ func TestSpawnStatusTool_SortByCreatedTimestamp(t *testing.T) {
|
||||
t.Fatalf("Both task IDs should appear in output:\n%s", result.ForLLM)
|
||||
}
|
||||
if pos2 > pos10 {
|
||||
t.Errorf(
|
||||
"Expected subagent-2 (created first) to appear before subagent-10, but got:\n%s",
|
||||
result.ForLLM,
|
||||
)
|
||||
t.Errorf("Expected subagent-2 (created first) to appear before subagent-10, but got:\n%s", result.ForLLM)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
+2
-6
@@ -69,9 +69,7 @@ func (t *SPITool) Parameters() map[string]any {
|
||||
|
||||
func (t *SPITool) Execute(ctx context.Context, args map[string]any) *ToolResult {
|
||||
if runtime.GOOS != "linux" {
|
||||
return ErrorResult(
|
||||
"SPI is only supported on Linux. This tool requires /dev/spidev* device files.",
|
||||
)
|
||||
return ErrorResult("SPI is only supported on Linux. This tool requires /dev/spidev* device files.")
|
||||
}
|
||||
|
||||
action, ok := args["action"].(string)
|
||||
@@ -126,9 +124,7 @@ func (t *SPITool) list() *ToolResult {
|
||||
// parseSPIArgs extracts and validates common SPI parameters
|
||||
//
|
||||
//nolint:unused // Used by spi_linux.go
|
||||
func parseSPIArgs(
|
||||
args map[string]any,
|
||||
) (device string, speed uint32, mode uint8, bits uint8, errMsg string) {
|
||||
func parseSPIArgs(args map[string]any) (device string, speed uint32, mode uint8, bits uint8, errMsg string) {
|
||||
dev, ok := args["device"].(string)
|
||||
if !ok || dev == "" {
|
||||
return "", 0, 0, 0, "device is required (e.g. \"2.0\" for /dev/spidev2.0)"
|
||||
|
||||
+6
-37
@@ -38,46 +38,25 @@ type spiTransfer struct {
|
||||
func configureSPI(devPath string, mode uint8, bits uint8, speed uint32) (int, *ToolResult) {
|
||||
fd, err := syscall.Open(devPath, syscall.O_RDWR, 0)
|
||||
if err != nil {
|
||||
return -1, ErrorResult(
|
||||
fmt.Sprintf(
|
||||
"failed to open %s: %v (check permissions and spidev module)",
|
||||
devPath,
|
||||
err,
|
||||
),
|
||||
)
|
||||
return -1, ErrorResult(fmt.Sprintf("failed to open %s: %v (check permissions and spidev module)", devPath, err))
|
||||
}
|
||||
|
||||
// Set SPI mode
|
||||
_, _, errno := syscall.Syscall(
|
||||
syscall.SYS_IOCTL,
|
||||
uintptr(fd),
|
||||
spiIocWrMode,
|
||||
uintptr(unsafe.Pointer(&mode)),
|
||||
)
|
||||
_, _, errno := syscall.Syscall(syscall.SYS_IOCTL, uintptr(fd), spiIocWrMode, uintptr(unsafe.Pointer(&mode)))
|
||||
if errno != 0 {
|
||||
syscall.Close(fd)
|
||||
return -1, ErrorResult(fmt.Sprintf("failed to set SPI mode %d: %v", mode, errno))
|
||||
}
|
||||
|
||||
// Set bits per word
|
||||
_, _, errno = syscall.Syscall(
|
||||
syscall.SYS_IOCTL,
|
||||
uintptr(fd),
|
||||
spiIocWrBitsPerWord,
|
||||
uintptr(unsafe.Pointer(&bits)),
|
||||
)
|
||||
_, _, errno = syscall.Syscall(syscall.SYS_IOCTL, uintptr(fd), spiIocWrBitsPerWord, uintptr(unsafe.Pointer(&bits)))
|
||||
if errno != 0 {
|
||||
syscall.Close(fd)
|
||||
return -1, ErrorResult(fmt.Sprintf("failed to set bits per word %d: %v", bits, errno))
|
||||
}
|
||||
|
||||
// Set max speed
|
||||
_, _, errno = syscall.Syscall(
|
||||
syscall.SYS_IOCTL,
|
||||
uintptr(fd),
|
||||
spiIocWrMaxSpeedHz,
|
||||
uintptr(unsafe.Pointer(&speed)),
|
||||
)
|
||||
_, _, errno = syscall.Syscall(syscall.SYS_IOCTL, uintptr(fd), spiIocWrMaxSpeedHz, uintptr(unsafe.Pointer(&speed)))
|
||||
if errno != 0 {
|
||||
syscall.Close(fd)
|
||||
return -1, ErrorResult(fmt.Sprintf("failed to set SPI speed %d Hz: %v", speed, errno))
|
||||
@@ -138,12 +117,7 @@ func (t *SPITool) transfer(args map[string]any) *ToolResult {
|
||||
bitsPerWord: bits,
|
||||
}
|
||||
|
||||
_, _, errno := syscall.Syscall(
|
||||
syscall.SYS_IOCTL,
|
||||
uintptr(fd),
|
||||
spiIocMessage1,
|
||||
uintptr(unsafe.Pointer(&xfer)),
|
||||
)
|
||||
_, _, errno := syscall.Syscall(syscall.SYS_IOCTL, uintptr(fd), spiIocMessage1, uintptr(unsafe.Pointer(&xfer)))
|
||||
runtime.KeepAlive(txBuf)
|
||||
runtime.KeepAlive(rxBuf)
|
||||
if errno != 0 {
|
||||
@@ -200,12 +174,7 @@ func (t *SPITool) readDevice(args map[string]any) *ToolResult {
|
||||
bitsPerWord: bits,
|
||||
}
|
||||
|
||||
_, _, errno := syscall.Syscall(
|
||||
syscall.SYS_IOCTL,
|
||||
uintptr(fd),
|
||||
spiIocMessage1,
|
||||
uintptr(unsafe.Pointer(&xfer)),
|
||||
)
|
||||
_, _, errno := syscall.Syscall(syscall.SYS_IOCTL, uintptr(fd), spiIocMessage1, uintptr(unsafe.Pointer(&xfer)))
|
||||
runtime.KeepAlive(txBuf)
|
||||
runtime.KeepAlive(rxBuf)
|
||||
if errno != 0 {
|
||||
|
||||
@@ -316,11 +316,7 @@ func TestSubagentTool_ForUserTruncation(t *testing.T) {
|
||||
// ForUser should be truncated to 500 chars + "..."
|
||||
maxUserLen := 500
|
||||
if len(result.ForUser) > maxUserLen+3 { // +3 for "..."
|
||||
t.Errorf(
|
||||
"ForUser should be truncated to ~%d chars, got: %d",
|
||||
maxUserLen,
|
||||
len(result.ForUser),
|
||||
)
|
||||
t.Errorf("ForUser should be truncated to ~%d chars, got: %d", maxUserLen, len(result.ForUser))
|
||||
}
|
||||
|
||||
// ForLLM should have full content
|
||||
|
||||
+2
-15
@@ -64,13 +64,7 @@ func RunToolLoop(
|
||||
llmOpts = map[string]any{}
|
||||
}
|
||||
// 3. Call LLM
|
||||
response, err := config.Provider.Chat(
|
||||
ctx,
|
||||
messages,
|
||||
providerToolDefs,
|
||||
config.Model,
|
||||
llmOpts,
|
||||
)
|
||||
response, err := config.Provider.Chat(ctx, messages, providerToolDefs, config.Model, llmOpts)
|
||||
if err != nil {
|
||||
logger.ErrorCF("toolloop", "LLM call failed",
|
||||
map[string]any{
|
||||
@@ -154,14 +148,7 @@ func RunToolLoop(
|
||||
|
||||
var toolResult *ToolResult
|
||||
if config.Tools != nil {
|
||||
toolResult = config.Tools.ExecuteWithContext(
|
||||
ctx,
|
||||
tc.Name,
|
||||
tc.Arguments,
|
||||
channel,
|
||||
chatID,
|
||||
nil,
|
||||
)
|
||||
toolResult = config.Tools.ExecuteWithContext(ctx, tc.Name, tc.Arguments, channel, chatID, nil)
|
||||
} else {
|
||||
toolResult = ErrorResult("No tools available")
|
||||
}
|
||||
|
||||
@@ -151,10 +151,7 @@ func TestValidateToolArgs(t *testing.T) {
|
||||
schema: map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]any{
|
||||
"color": map[string]any{
|
||||
"type": "string",
|
||||
"enum": []any{"red", "green", "blue"},
|
||||
},
|
||||
"color": map[string]any{"type": "string", "enum": []any{"red", "green", "blue"}},
|
||||
},
|
||||
},
|
||||
args: map[string]any{"color": "red"},
|
||||
@@ -164,10 +161,7 @@ func TestValidateToolArgs(t *testing.T) {
|
||||
schema: map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]any{
|
||||
"color": map[string]any{
|
||||
"type": "string",
|
||||
"enum": []any{"red", "green", "blue"},
|
||||
},
|
||||
"color": map[string]any{"type": "string", "enum": []any{"red", "green", "blue"}},
|
||||
},
|
||||
},
|
||||
args: map[string]any{"color": "yellow"},
|
||||
@@ -178,10 +172,7 @@ func TestValidateToolArgs(t *testing.T) {
|
||||
schema: map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]any{
|
||||
"color": map[string]any{
|
||||
"type": "string",
|
||||
"enum": []string{"red", "green", "blue"},
|
||||
},
|
||||
"color": map[string]any{"type": "string", "enum": []string{"red", "green", "blue"}},
|
||||
},
|
||||
},
|
||||
args: map[string]any{"color": "green"},
|
||||
@@ -191,10 +182,7 @@ func TestValidateToolArgs(t *testing.T) {
|
||||
schema: map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]any{
|
||||
"color": map[string]any{
|
||||
"type": "string",
|
||||
"enum": []string{"red", "green", "blue"},
|
||||
},
|
||||
"color": map[string]any{"type": "string", "enum": []string{"red", "green", "blue"}},
|
||||
},
|
||||
},
|
||||
args: map[string]any{"color": "yellow"},
|
||||
@@ -354,11 +342,7 @@ func TestValidateToolArgs_RegistryIntegration(t *testing.T) {
|
||||
}
|
||||
|
||||
// Extra property — should fail with validation error
|
||||
result = r.Execute(
|
||||
context.Background(),
|
||||
"read_file",
|
||||
map[string]any{"path": "/x", "__inject": true},
|
||||
)
|
||||
result = r.Execute(context.Background(), "read_file", map[string]any{"path": "/x", "__inject": true})
|
||||
if !result.IsError {
|
||||
t.Error("expected validation error for extra property")
|
||||
}
|
||||
|
||||
+32
-130
@@ -54,8 +54,7 @@ func TestWebTool_WebFetch_Success(t *testing.T) {
|
||||
}
|
||||
|
||||
// ForUser should contain summary
|
||||
if !strings.Contains(result.ForUser, "bytes") &&
|
||||
!strings.Contains(result.ForUser, "extractor") {
|
||||
if !strings.Contains(result.ForUser, "bytes") && !strings.Contains(result.ForUser, "extractor") {
|
||||
t.Errorf("Expected ForUser to contain summary, got: %s", result.ForUser)
|
||||
}
|
||||
}
|
||||
@@ -76,11 +75,7 @@ func TestWebTool_WebFetch_JSON(t *testing.T) {
|
||||
|
||||
tool, err := NewWebFetchTool(50000, format, testFetchLimit)
|
||||
if err != nil {
|
||||
logger.ErrorCF(
|
||||
"agent",
|
||||
"Failed to create web fetch tool",
|
||||
map[string]any{"error": err.Error()},
|
||||
)
|
||||
logger.ErrorCF("agent", "Failed to create web fetch tool", map[string]any{"error": err.Error()})
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
@@ -105,11 +100,7 @@ func TestWebTool_WebFetch_JSON(t *testing.T) {
|
||||
func TestWebTool_WebFetch_InvalidURL(t *testing.T) {
|
||||
tool, err := NewWebFetchTool(50000, format, testFetchLimit)
|
||||
if err != nil {
|
||||
logger.ErrorCF(
|
||||
"agent",
|
||||
"Failed to create web fetch tool",
|
||||
map[string]any{"error": err.Error()},
|
||||
)
|
||||
logger.ErrorCF("agent", "Failed to create web fetch tool", map[string]any{"error": err.Error()})
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
@@ -134,11 +125,7 @@ func TestWebTool_WebFetch_InvalidURL(t *testing.T) {
|
||||
func TestWebTool_WebFetch_UnsupportedScheme(t *testing.T) {
|
||||
tool, err := NewWebFetchTool(50000, format, testFetchLimit)
|
||||
if err != nil {
|
||||
logger.ErrorCF(
|
||||
"agent",
|
||||
"Failed to create web fetch tool",
|
||||
map[string]any{"error": err.Error()},
|
||||
)
|
||||
logger.ErrorCF("agent", "Failed to create web fetch tool", map[string]any{"error": err.Error()})
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
@@ -154,8 +141,7 @@ func TestWebTool_WebFetch_UnsupportedScheme(t *testing.T) {
|
||||
}
|
||||
|
||||
// Should mention only http/https allowed
|
||||
if !strings.Contains(result.ForLLM, "http/https") &&
|
||||
!strings.Contains(result.ForUser, "http/https") {
|
||||
if !strings.Contains(result.ForLLM, "http/https") && !strings.Contains(result.ForUser, "http/https") {
|
||||
t.Errorf("Expected scheme error message, got ForLLM: %s", result.ForLLM)
|
||||
}
|
||||
}
|
||||
@@ -164,11 +150,7 @@ func TestWebTool_WebFetch_UnsupportedScheme(t *testing.T) {
|
||||
func TestWebTool_WebFetch_MissingURL(t *testing.T) {
|
||||
tool, err := NewWebFetchTool(50000, format, testFetchLimit)
|
||||
if err != nil {
|
||||
logger.ErrorCF(
|
||||
"agent",
|
||||
"Failed to create web fetch tool",
|
||||
map[string]any{"error": err.Error()},
|
||||
)
|
||||
logger.ErrorCF("agent", "Failed to create web fetch tool", map[string]any{"error": err.Error()})
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
@@ -182,8 +164,7 @@ func TestWebTool_WebFetch_MissingURL(t *testing.T) {
|
||||
}
|
||||
|
||||
// Should mention URL is required
|
||||
if !strings.Contains(result.ForLLM, "url is required") &&
|
||||
!strings.Contains(result.ForUser, "url is required") {
|
||||
if !strings.Contains(result.ForLLM, "url is required") && !strings.Contains(result.ForUser, "url is required") {
|
||||
t.Errorf("Expected 'url is required' message, got ForLLM: %s", result.ForLLM)
|
||||
}
|
||||
}
|
||||
@@ -203,11 +184,7 @@ func TestWebTool_WebFetch_Truncation(t *testing.T) {
|
||||
|
||||
tool, err := NewWebFetchTool(1000, format, testFetchLimit) // Limit to 1000 chars
|
||||
if err != nil {
|
||||
logger.ErrorCF(
|
||||
"agent",
|
||||
"Failed to create web fetch tool",
|
||||
map[string]any{"error": err.Error()},
|
||||
)
|
||||
logger.ErrorCF("agent", "Failed to create web fetch tool", map[string]any{"error": err.Error()})
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
@@ -239,10 +216,7 @@ func TestWebTool_WebFetch_Truncation(t *testing.T) {
|
||||
// Text should end with the truncation notice
|
||||
if text, ok := resultMap["text"].(string); ok {
|
||||
if !strings.HasSuffix(text, "[Content truncated due to size limit]") {
|
||||
t.Errorf(
|
||||
"Expected text to end with truncation notice, got: %q",
|
||||
text[max(0, len(text)-60):],
|
||||
)
|
||||
t.Errorf("Expected text to end with truncation notice, got: %q", text[max(0, len(text)-60):])
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -289,13 +263,11 @@ func TestWebTool_WebFetch_TruncationNotice(t *testing.T) {
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
server := httptest.NewServer(
|
||||
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", tt.contentType)
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte(tt.body))
|
||||
}),
|
||||
)
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", tt.contentType)
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte(tt.body))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
tool, err := NewWebFetchTool(maxChars, tt.format, testFetchLimit)
|
||||
@@ -319,11 +291,7 @@ func TestWebTool_WebFetch_TruncationNotice(t *testing.T) {
|
||||
}
|
||||
|
||||
if !strings.HasSuffix(text, truncationNotice) {
|
||||
t.Errorf(
|
||||
"expected text to end with %q, got suffix: %q",
|
||||
truncationNotice,
|
||||
text[max(0, len(text)-60):],
|
||||
)
|
||||
t.Errorf("expected text to end with %q, got suffix: %q", truncationNotice, text[max(0, len(text)-60):])
|
||||
}
|
||||
|
||||
if truncated, ok := resultMap["truncated"].(bool); !ok || !truncated {
|
||||
@@ -392,11 +360,7 @@ func TestWebFetchTool_PayloadTooLarge(t *testing.T) {
|
||||
// Initialize the tool
|
||||
tool, err := NewWebFetchTool(50000, format, testFetchLimit)
|
||||
if err != nil {
|
||||
logger.ErrorCF(
|
||||
"agent",
|
||||
"Failed to create web fetch tool",
|
||||
map[string]any{"error": err.Error()},
|
||||
)
|
||||
logger.ErrorCF("agent", "Failed to create web fetch tool", map[string]any{"error": err.Error()})
|
||||
}
|
||||
|
||||
// Prepare the arguments pointing to the URL of our local mock server
|
||||
@@ -416,8 +380,7 @@ func TestWebFetchTool_PayloadTooLarge(t *testing.T) {
|
||||
// Search for the exact error string we set earlier in the Execute method
|
||||
expectedErrorMsg := fmt.Sprintf("size exceeded %d bytes limit", testFetchLimit)
|
||||
|
||||
if !strings.Contains(result.ForLLM, expectedErrorMsg) &&
|
||||
!strings.Contains(result.ForUser, expectedErrorMsg) {
|
||||
if !strings.Contains(result.ForLLM, expectedErrorMsg) && !strings.Contains(result.ForUser, expectedErrorMsg) {
|
||||
t.Errorf("test failed: expected error %q, but got: %+v", expectedErrorMsg, result)
|
||||
}
|
||||
}
|
||||
@@ -570,11 +533,7 @@ func TestWebTool_WebFetch_HTMLExtraction(t *testing.T) {
|
||||
|
||||
tool, err := NewWebFetchTool(50000, format, testFetchLimit)
|
||||
if err != nil {
|
||||
logger.ErrorCF(
|
||||
"agent",
|
||||
"Failed to create web fetch tool",
|
||||
map[string]any{"error": err.Error()},
|
||||
)
|
||||
logger.ErrorCF("agent", "Failed to create web fetch tool", map[string]any{"error": err.Error()})
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
@@ -759,13 +718,7 @@ func TestWebTool_WebFetch_PrivateHostAllowedByCIDRWhitelist(t *testing.T) {
|
||||
defer server.Close()
|
||||
|
||||
host, _ := serverHostAndPort(t, server.URL)
|
||||
tool, err := NewWebFetchToolWithConfig(
|
||||
50000,
|
||||
"",
|
||||
format,
|
||||
testFetchLimit,
|
||||
[]string{singleHostCIDR(t, host)},
|
||||
)
|
||||
tool, err := NewWebFetchToolWithConfig(50000, "", format, testFetchLimit, []string{singleHostCIDR(t, host)})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create web fetch tool: %v", err)
|
||||
}
|
||||
@@ -800,10 +753,7 @@ func TestWebTool_WebFetch_PrivateHostAllowedForTests(t *testing.T) {
|
||||
})
|
||||
|
||||
if result.IsError {
|
||||
t.Errorf(
|
||||
"expected success when private host access is allowed in tests, got %q",
|
||||
result.ForLLM,
|
||||
)
|
||||
t.Errorf("expected success when private host access is allowed in tests, got %q", result.ForLLM)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1023,11 +973,7 @@ func TestIsPrivateOrRestrictedIP_Table(t *testing.T) {
|
||||
func TestWebTool_WebFetch_MissingDomain(t *testing.T) {
|
||||
tool, err := NewWebFetchTool(50000, format, testFetchLimit)
|
||||
if err != nil {
|
||||
logger.ErrorCF(
|
||||
"agent",
|
||||
"Failed to create web fetch tool",
|
||||
map[string]any{"error": err.Error()},
|
||||
)
|
||||
logger.ErrorCF("agent", "Failed to create web fetch tool", map[string]any{"error": err.Error()})
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
@@ -1049,19 +995,9 @@ func TestWebTool_WebFetch_MissingDomain(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestNewWebFetchToolWithProxy(t *testing.T) {
|
||||
tool, err := NewWebFetchToolWithProxy(
|
||||
1024,
|
||||
"http://127.0.0.1:7890",
|
||||
format,
|
||||
testFetchLimit,
|
||||
nil,
|
||||
)
|
||||
tool, err := NewWebFetchToolWithProxy(1024, "http://127.0.0.1:7890", format, testFetchLimit, nil)
|
||||
if err != nil {
|
||||
logger.ErrorCF(
|
||||
"agent",
|
||||
"Failed to create web fetch tool",
|
||||
map[string]any{"error": err.Error()},
|
||||
)
|
||||
logger.ErrorCF("agent", "Failed to create web fetch tool", map[string]any{"error": err.Error()})
|
||||
} else if tool.maxChars != 1024 {
|
||||
t.Fatalf("maxChars = %d, want %d", tool.maxChars, 1024)
|
||||
}
|
||||
@@ -1072,11 +1008,7 @@ func TestNewWebFetchToolWithProxy(t *testing.T) {
|
||||
|
||||
tool, err = NewWebFetchToolWithProxy(0, "http://127.0.0.1:7890", format, testFetchLimit, nil)
|
||||
if err != nil {
|
||||
logger.ErrorCF(
|
||||
"agent",
|
||||
"Failed to create web fetch tool",
|
||||
map[string]any{"error": err.Error()},
|
||||
)
|
||||
logger.ErrorCF("agent", "Failed to create web fetch tool", map[string]any{"error": err.Error()})
|
||||
}
|
||||
|
||||
if tool.maxChars != 50000 {
|
||||
@@ -1085,13 +1017,7 @@ func TestNewWebFetchToolWithProxy(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestNewWebFetchToolWithConfig_InvalidPrivateHostWhitelist(t *testing.T) {
|
||||
_, err := NewWebFetchToolWithConfig(
|
||||
1024,
|
||||
"",
|
||||
format,
|
||||
testFetchLimit,
|
||||
[]string{"not-an-ip-or-cidr"},
|
||||
)
|
||||
_, err := NewWebFetchToolWithConfig(1024, "", format, testFetchLimit, []string{"not-an-ip-or-cidr"})
|
||||
if err == nil {
|
||||
t.Fatal("expected invalid whitelist entry to fail")
|
||||
}
|
||||
@@ -1247,11 +1173,7 @@ func TestWebTool_TavilySearch_RangeMapping(t *testing.T) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
json.NewEncoder(w).Encode(map[string]any{
|
||||
"results": []map[string]any{
|
||||
{
|
||||
"title": "Recent result",
|
||||
"url": "https://example.com/recent",
|
||||
"content": "snippet",
|
||||
},
|
||||
{"title": "Recent result", "url": "https://example.com/recent", "content": "snippet"},
|
||||
},
|
||||
})
|
||||
}))
|
||||
@@ -1381,10 +1303,7 @@ func TestWebFetchTool_CloudflareChallenge_RetryFailsToo(t *testing.T) {
|
||||
|
||||
// Should not be an error — the retry response is used as-is (403 is a valid HTTP response)
|
||||
if result.IsError {
|
||||
t.Fatalf(
|
||||
"expected non-error result even when retry is also blocked, got: %s",
|
||||
result.ForLLM,
|
||||
)
|
||||
t.Fatalf("expected non-error result even when retry is also blocked, got: %s", result.ForLLM)
|
||||
}
|
||||
// Status in the JSON result should reflect the 403
|
||||
if !strings.Contains(result.ForLLM, "403") {
|
||||
@@ -1549,10 +1468,7 @@ func TestWebTool_GLMSearch_Success(t *testing.T) {
|
||||
t.Errorf("Expected Content-Type application/json, got %s", r.Header.Get("Content-Type"))
|
||||
}
|
||||
if r.Header.Get("Authorization") != "Bearer test-glm-key" {
|
||||
t.Errorf(
|
||||
"Expected Authorization Bearer test-glm-key, got %s",
|
||||
r.Header.Get("Authorization"),
|
||||
)
|
||||
t.Errorf("Expected Authorization Bearer test-glm-key, got %s", r.Header.Get("Authorization"))
|
||||
}
|
||||
|
||||
var payload map[string]any
|
||||
@@ -1618,21 +1534,14 @@ func TestWebTool_GLMSearch_RangeMapping(t *testing.T) {
|
||||
t.Fatalf("failed to decode payload: %v", err)
|
||||
}
|
||||
if payload["search_recency_filter"] != "oneMonth" {
|
||||
t.Fatalf(
|
||||
"expected search_recency_filter=oneMonth, got %v",
|
||||
payload["search_recency_filter"],
|
||||
)
|
||||
t.Fatalf("expected search_recency_filter=oneMonth, got %v", payload["search_recency_filter"])
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
json.NewEncoder(w).Encode(map[string]any{
|
||||
"search_result": []map[string]any{
|
||||
{
|
||||
"title": "Recent GLM Result",
|
||||
"content": "snippet",
|
||||
"link": "https://example.com/glm-range",
|
||||
},
|
||||
{"title": "Recent GLM Result", "content": "snippet", "link": "https://example.com/glm-range"},
|
||||
},
|
||||
})
|
||||
}))
|
||||
@@ -1664,21 +1573,14 @@ func TestWebTool_BaiduSearch_RangeMapping(t *testing.T) {
|
||||
t.Fatalf("failed to decode payload: %v", err)
|
||||
}
|
||||
if payload["search_recency_filter"] != "week" {
|
||||
t.Fatalf(
|
||||
"expected search_recency_filter=week for day fallback, got %v",
|
||||
payload["search_recency_filter"],
|
||||
)
|
||||
t.Fatalf("expected search_recency_filter=week for day fallback, got %v", payload["search_recency_filter"])
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
json.NewEncoder(w).Encode(map[string]any{
|
||||
"references": []map[string]any{
|
||||
{
|
||||
"title": "Recent Baidu Result",
|
||||
"url": "https://example.com/baidu",
|
||||
"content": "snippet",
|
||||
},
|
||||
{"title": "Recent Baidu Result", "url": "https://example.com/baidu", "content": "snippet"},
|
||||
},
|
||||
})
|
||||
}))
|
||||
|
||||
Reference in New Issue
Block a user