fix(tool): route binary outputs through the media pipeline.

This commit is contained in:
afjcjsbx
2026-03-22 12:05:28 +01:00
parent c0bb8d6df9
commit df4f322f09
14 changed files with 1462 additions and 64 deletions
+296
View File
@@ -298,6 +298,152 @@ func TestToolRegistry_GetDefinitions(t *testing.T) {
}
}
// TestProcessMessage_MediaToolHandledSkipsFollowUpLLMAndFinalText exercises the
// case where a tool returns media and marks the response as handled. It
// expects processMessage to return an empty reply, the provider to be called
// exactly once (with tools offered), a single-part outbound media message for
// the originating chat, and handledToolResponseSummary as the last assistant
// entry in the session history.
func TestProcessMessage_MediaToolHandledSkipsFollowUpLLMAndFinalText(t *testing.T) {
	workspace := t.TempDir()

	defaults := config.AgentDefaults{
		Workspace:         workspace,
		Model:             "test-model",
		MaxTokens:         4096,
		MaxToolIterations: 10,
	}
	cfg := &config.Config{Agents: config.AgentsConfig{Defaults: defaults}}

	outBus := bus.NewMessageBus()
	llm := &handledMediaProvider{}
	loop := NewAgentLoop(cfg, outBus, llm)

	mediaStore := media.NewFileMediaStore()
	loop.SetMediaStore(mediaStore)

	// The tool under test hands this file to the media store when executed.
	screenshot := filepath.Join(workspace, "screen.png")
	if writeErr := os.WriteFile(screenshot, []byte("fake screenshot"), 0o644); writeErr != nil {
		t.Fatalf("WriteFile(imagePath) error = %v", writeErr)
	}
	loop.RegisterTool(&handledMediaTool{store: mediaStore, path: screenshot})

	// The same inbound message is used to drive the loop and, afterwards, to
	// resolve the session key the loop saved its history under.
	inbound := bus.InboundMessage{
		Channel:  "telegram",
		ChatID:   "chat1",
		SenderID: "user1",
		Content:  "take a screenshot of the screen and send it to me",
	}

	reply, err := loop.processMessage(context.Background(), inbound)
	if err != nil {
		t.Fatalf("processMessage() error = %v", err)
	}
	if reply != "" {
		t.Fatalf("expected no final response when media tool already handled delivery, got %q", reply)
	}

	if got := llm.calls; got != 1 {
		t.Fatalf("expected exactly 1 LLM call, got %d", got)
	}
	if got := len(llm.toolCounts); got != 1 {
		t.Fatalf("expected tool counts for 1 provider call, got %d", got)
	}
	if llm.toolCounts[0] == 0 {
		t.Fatal("expected tools to be available on the first LLM call")
	}

	// Non-blocking receive: processMessage has already returned, so the media
	// message must be waiting on the channel by now.
	select {
	case out := <-outBus.OutboundMediaChan():
		if !(out.Channel == "telegram" && out.ChatID == "chat1") {
			t.Fatalf("unexpected outbound media target: %+v", out)
		}
		if got := len(out.Parts); got != 1 {
			t.Fatalf("expected exactly 1 outbound media part, got %d", got)
		}
	default:
		t.Fatal("expected outbound media message to be published")
	}

	agent := loop.GetRegistry().GetDefaultAgent()
	if agent == nil {
		t.Fatal("expected default agent")
	}

	route, _, routeErr := loop.resolveMessageRoute(inbound)
	if routeErr != nil {
		t.Fatalf("resolveMessageRoute() error = %v", routeErr)
	}

	history := agent.Sessions.GetHistory(resolveScopeKey(route, ""))
	if len(history) == 0 {
		t.Fatal("expected session history to be saved")
	}

	tail := history[len(history)-1]
	if !(tail.Role == "assistant" && tail.Content == handledToolResponseSummary) {
		t.Fatalf("expected handled assistant summary in history, got %+v", tail)
	}
}
// TestProcessMessage_MediaArtifactCanBeForwardedBySendFile exercises a
// two-step flow: media_artifact_tool returns a media result that is not marked
// as handled, then the provider's second reply requests send_file with the
// artifact path it found in the tool message. The test expects an empty final
// response, exactly two provider calls, and one single-part outbound media
// message addressed to the originating chat.
func TestProcessMessage_MediaArtifactCanBeForwardedBySendFile(t *testing.T) {
	tmpDir := t.TempDir()
	cfg := config.DefaultConfig()
	cfg.Agents.Defaults.Workspace = tmpDir
	cfg.Agents.Defaults.Model = "test-model"
	cfg.Agents.Defaults.MaxTokens = 4096
	cfg.Agents.Defaults.MaxToolIterations = 10

	msgBus := bus.NewMessageBus()
	provider := &artifactThenSendProvider{}
	al := NewAgentLoop(cfg, msgBus, provider)

	store := media.NewFileMediaStore()
	al.SetMediaStore(store)

	mediaDir := media.TempDir()
	if err := os.MkdirAll(mediaDir, 0o700); err != nil {
		t.Fatalf("MkdirAll(mediaDir) error = %v", err)
	}

	imagePath := filepath.Join(mediaDir, "artifact-screen.png")
	// The artifact is written under media.TempDir(), which the testing package
	// does not manage (unlike t.TempDir()), so remove it explicitly to avoid
	// leaving a stray file behind after the test.
	t.Cleanup(func() {
		// Best-effort: the file may never have been created or may already be gone.
		_ = os.Remove(imagePath)
	})
	if err := os.WriteFile(imagePath, []byte("fake screenshot"), 0o644); err != nil {
		t.Fatalf("WriteFile(imagePath) error = %v", err)
	}

	al.RegisterTool(&mediaArtifactTool{
		store: store,
		path:  imagePath,
	})

	response, err := al.processMessage(context.Background(), bus.InboundMessage{
		Channel:  "telegram",
		ChatID:   "chat1",
		SenderID: "user1",
		Content:  "take a screenshot of the screen and send it to me",
	})
	if err != nil {
		t.Fatalf("processMessage() error = %v", err)
	}
	if response != "" {
		t.Fatalf("expected no final response after send_file handled delivery, got %q", response)
	}
	if provider.calls != 2 {
		t.Fatalf("expected 2 LLM calls (artifact + send_file), got %d", provider.calls)
	}

	// Non-blocking receive: processMessage has already returned, so the media
	// message must be waiting on the channel by now.
	select {
	case mediaMsg := <-msgBus.OutboundMediaChan():
		if mediaMsg.Channel != "telegram" || mediaMsg.ChatID != "chat1" {
			t.Fatalf("unexpected outbound media target: %+v", mediaMsg)
		}
		if len(mediaMsg.Parts) != 1 {
			t.Fatalf("expected exactly 1 outbound media part, got %d", len(mediaMsg.Parts))
		}
	default:
		t.Fatal("expected outbound media from send_file")
	}
}
// TestAgentLoop_GetStartupInfo verifies startup info contains tools
func TestAgentLoop_GetStartupInfo(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "agent-test-*")
@@ -420,6 +566,98 @@ func (m *countingMockProvider) GetDefaultModel() string {
return "counting-mock-model"
}
// handledMediaProvider is a test provider stub that records how its Chat
// method is invoked.
type handledMediaProvider struct {
// calls counts Chat invocations.
calls int
// toolCounts holds len(tools) for each Chat invocation, in call order.
toolCounts []int
}
// Chat records the invocation (call counter and number of tool definitions
// offered) and returns a canned reply. The first invocation returns text plus
// a single tool call to "handled_media_tool"; every later invocation returns
// an empty response. The returned error is always nil.
func (m *handledMediaProvider) Chat(
	ctx context.Context,
	messages []providers.Message,
	tools []providers.ToolDefinition,
	model string,
	opts map[string]any,
) (*providers.LLMResponse, error) {
	m.toolCounts = append(m.toolCounts, len(tools))
	m.calls++

	// Anything other than the first call gets an empty reply.
	if m.calls != 1 {
		return &providers.LLMResponse{}, nil
	}

	toolCall := providers.ToolCall{
		ID:        "call_handled_media",
		Type:      "function",
		Name:      "handled_media_tool",
		Arguments: map[string]any{},
	}
	reply := &providers.LLMResponse{
		Content:   "Taking the screenshot now.",
		ToolCalls: []providers.ToolCall{toolCall},
	}
	return reply, nil
}
// GetDefaultModel returns the fixed model name reported by this stub.
func (m *handledMediaProvider) GetDefaultModel() string {
return "handled-media-model"
}
// artifactThenSendProvider is a test provider stub whose Chat method first
// requests "media_artifact_tool" and on later calls requests "send_file" with
// a path parsed from a tool message.
type artifactThenSendProvider struct {
// calls counts Chat invocations.
calls int
}
// Chat returns a canned reply that depends on how many times it has been
// called. The first call asks for "media_artifact_tool". Every later call
// walks the messages from newest to oldest, takes the first "tool" message
// containing a "[file:...]" marker, and asks for "send_file" with the text
// between "[file:" and the next "]" as the path. If no such marker yields a
// non-empty path, it returns an error instead of a response.
func (m *artifactThenSendProvider) Chat(
	ctx context.Context,
	messages []providers.Message,
	tools []providers.ToolDefinition,
	model string,
	opts map[string]any,
) (*providers.LLMResponse, error) {
	m.calls++

	if m.calls == 1 {
		artifactCall := providers.ToolCall{
			ID:        "call_artifact_media",
			Type:      "function",
			Name:      "media_artifact_tool",
			Arguments: map[string]any{},
		}
		return &providers.LLMResponse{
			Content:   "Taking the screenshot now.",
			ToolCalls: []providers.ToolCall{artifactCall},
		}, nil
	}

	// Search newest-first so the most recent tool result wins.
	artifactPath := ""
	for idx := len(messages) - 1; idx >= 0; idx-- {
		msg := &messages[idx]
		if msg.Role != "tool" {
			continue
		}
		_, afterMarker, hasMarker := strings.Cut(msg.Content, "[file:")
		if !hasMarker {
			continue
		}
		candidate, _, closed := strings.Cut(afterMarker, "]")
		if !closed {
			continue
		}
		artifactPath = candidate
		break
	}
	if artifactPath == "" {
		return nil, fmt.Errorf("provider did not receive artifact path in tool result")
	}

	sendCall := providers.ToolCall{
		ID:        "call_send_file",
		Type:      "function",
		Name:      "send_file",
		Arguments: map[string]any{"path": artifactPath},
	}
	return &providers.LLMResponse{ToolCalls: []providers.ToolCall{sendCall}}, nil
}
// GetDefaultModel returns the fixed model name reported by this stub.
func (m *artifactThenSendProvider) GetDefaultModel() string {
return "artifact-then-send-model"
}
// toolLimitOnlyProvider is a field-less type; it holds no per-instance state.
type toolLimitOnlyProvider struct{}
func (m *toolLimitOnlyProvider) Chat(
@@ -465,6 +703,64 @@ func (m *mockCustomTool) Execute(ctx context.Context, args map[string]any) *tool
return tools.SilentResult("Custom tool executed")
}
// handledMediaTool is a test tool whose Execute method stores a file in a
// media store and returns a media result marked as response-handled.
type handledMediaTool struct {
// store is the media store Execute calls Store on.
store media.MediaStore
// path is the file path Execute passes to the store.
path string
}
// Name returns the tool's identifier, "handled_media_tool".
func (m *handledMediaTool) Name() string { return "handled_media_tool" }
// Description returns a fixed one-sentence summary of the tool.
func (m *handledMediaTool) Description() string {
return "Returns a media attachment and fully handles the user response"
}
// Parameters returns a freshly built parameter schema: an "object" type with
// an empty "properties" map, i.e. the tool declares no arguments.
func (m *handledMediaTool) Parameters() map[string]any {
	noProperties := map[string]any{}
	schema := map[string]any{
		"properties": noProperties,
		"type":       "object",
	}
	return schema
}
// Execute hands m.path to the media store with PNG metadata and, on success,
// returns a media result carrying the stored reference and flagged via
// WithResponseHandled. If the store rejects the file, it returns an error
// result built from the store's error. ctx and args are not used.
func (m *handledMediaTool) Execute(ctx context.Context, args map[string]any) *tools.ToolResult {
	meta := media.MediaMeta{
		Filename:    filepath.Base(m.path),
		ContentType: "image/png",
		Source:      "test:handled_media_tool",
	}

	ref, storeErr := m.store.Store(m.path, meta, "test:handled_media")
	if storeErr != nil {
		failure := tools.ErrorResult(storeErr.Error())
		return failure.WithError(storeErr)
	}

	delivered := tools.MediaResult("Attachment delivered by tool.", []string{ref})
	return delivered.WithResponseHandled()
}
// mediaArtifactTool is a test tool whose Execute method stores a file in a
// media store and returns a plain media result (not marked as handled).
type mediaArtifactTool struct {
// store is the media store Execute calls Store on.
store media.MediaStore
// path is the file path Execute passes to the store.
path string
}
// Name returns the tool's identifier, "media_artifact_tool".
func (m *mediaArtifactTool) Name() string { return "media_artifact_tool" }
// Description returns a fixed one-sentence summary of the tool.
func (m *mediaArtifactTool) Description() string {
return "Returns a media artifact that the agent can forward or save later"
}
// Parameters returns a freshly built parameter schema: an "object" type with
// an empty "properties" map, i.e. the tool declares no arguments.
func (m *mediaArtifactTool) Parameters() map[string]any {
	noProperties := map[string]any{}
	schema := map[string]any{
		"properties": noProperties,
		"type":       "object",
	}
	return schema
}
// Execute hands m.path to the media store with PNG metadata and, on success,
// returns a media result carrying the stored reference. Unlike a handled
// result, it is returned without WithResponseHandled. If the store rejects
// the file, it returns an error result built from the store's error. ctx and
// args are not used.
func (m *mediaArtifactTool) Execute(ctx context.Context, args map[string]any) *tools.ToolResult {
	meta := media.MediaMeta{
		Filename:    filepath.Base(m.path),
		ContentType: "image/png",
		Source:      "test:media_artifact_tool",
	}

	ref, storeErr := m.store.Store(m.path, meta, "test:media_artifact")
	if storeErr != nil {
		failure := tools.ErrorResult(storeErr.Error())
		return failure.WithError(storeErr)
	}

	refs := []string{ref}
	return tools.MediaResult("Artifact created.", refs)
}
// toolLimitTestTool is a field-less type; it holds no per-instance state.
type toolLimitTestTool struct{}
func (m *toolLimitTestTool) Name() string {