fix(agent,openai_compat): address review feedback on vision pipeline

- serializeMessages: preserve ToolCallID/ToolCalls when Media is present
- resolveMediaRefs: add 20MB file size limit to prevent OOM
- mimeFromExtension: return empty string for unknown extensions
- Add 11 unit tests for serializeMessages, resolveMediaRefs, mimeFromExtension

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
shikihane
2026-03-03 11:13:22 +08:00
parent 18b36af934
commit 8ebeefc59f
4 changed files with 229 additions and 1 deletions
+30 -1
View File
@@ -1356,6 +1356,10 @@ func extractParentPeer(msg bus.InboundMessage) *routing.RoutePeer {
return &routing.RoutePeer{Kind: parentKind, ID: parentID}
}
// maxMediaFileSize caps the size, in bytes, of a media file considered for
// resolution (20 MB). Anything larger is skipped to prevent OOM under
// concurrent load.
const maxMediaFileSize = 20 << 20
// resolveMediaRefs replaces media:// refs in message Media fields with base64 data URLs.
// Returns a new slice with resolved URLs; original messages are not mutated.
func resolveMediaRefs(messages []providers.Message, store media.MediaStore) []providers.Message {
@@ -1387,6 +1391,23 @@ func resolveMediaRefs(messages []providers.Message, store media.MediaStore) []pr
continue
}
info, err := os.Stat(localPath)
if err != nil {
logger.WarnCF("agent", "Failed to stat media file", map[string]any{
"path": localPath,
"error": err.Error(),
})
continue
}
if info.Size() > maxMediaFileSize {
logger.WarnCF("agent", "Media file too large, skipping", map[string]any{
"path": localPath,
"size": info.Size(),
"max_size": maxMediaFileSize,
})
continue
}
data, err := os.ReadFile(localPath)
if err != nil {
logger.WarnCF("agent", "Failed to read media file", map[string]any{
@@ -1400,6 +1421,13 @@ func resolveMediaRefs(messages []providers.Message, store media.MediaStore) []pr
if mime == "" {
mime = mimeFromExtension(filepath.Ext(localPath))
}
if mime == "" {
logger.WarnCF("agent", "Unknown media type, skipping", map[string]any{
"path": localPath,
"ext": filepath.Ext(localPath),
})
continue
}
dataURL := "data:" + mime + ";base64," + base64.StdEncoding.EncodeToString(data)
resolved = append(resolved, dataURL)
@@ -1412,6 +1440,7 @@ func resolveMediaRefs(messages []providers.Message, store media.MediaStore) []pr
}
// mimeFromExtension returns a MIME type for common image extensions.
// Returns empty string for unrecognized extensions.
func mimeFromExtension(ext string) string {
switch strings.ToLower(ext) {
case ".jpg", ".jpeg":
@@ -1425,6 +1454,6 @@ func mimeFromExtension(ext string) string {
case ".bmp":
return "image/bmp"
default:
return "image/jpeg"
return ""
}
}
+123
View File
@@ -6,12 +6,14 @@ import (
"os"
"path/filepath"
"slices"
"strings"
"testing"
"time"
"github.com/sipeed/picoclaw/pkg/bus"
"github.com/sipeed/picoclaw/pkg/channels"
"github.com/sipeed/picoclaw/pkg/config"
"github.com/sipeed/picoclaw/pkg/media"
"github.com/sipeed/picoclaw/pkg/providers"
"github.com/sipeed/picoclaw/pkg/tools"
)
@@ -840,3 +842,124 @@ func TestHandleReasoning(t *testing.T) {
}
})
}
// TestMimeFromExtension checks the extension-to-MIME mapping: known image
// extensions (including an upper-case spelling) map to their MIME type, while
// non-image and empty extensions yield the empty string.
func TestMimeFromExtension(t *testing.T) {
	type mimeCase struct {
		ext  string
		want string
	}
	cases := []mimeCase{
		{ext: ".jpg", want: "image/jpeg"},
		{ext: ".JPEG", want: "image/jpeg"},
		{ext: ".png", want: "image/png"},
		{ext: ".gif", want: "image/gif"},
		{ext: ".webp", want: "image/webp"},
		{ext: ".bmp", want: "image/bmp"},
		{ext: ".txt", want: ""},
		{ext: ".pdf", want: ""},
		{ext: "", want: ""},
	}
	for i := range cases {
		tc := cases[i]
		got := mimeFromExtension(tc.ext)
		if got == tc.want {
			continue
		}
		t.Errorf("mimeFromExtension(%q) = %q, want %q", tc.ext, got, tc.want)
	}
}
// TestResolveMediaRefs_NilStore verifies that a media:// ref is left as-is
// when resolveMediaRefs is called with a nil store.
func TestResolveMediaRefs_NilStore(t *testing.T) {
	const ref = "media://abc"
	input := []providers.Message{
		{Role: "user", Content: "hi", Media: []string{ref}},
	}
	out := resolveMediaRefs(input, nil)
	if got := out[0].Media[0]; got != ref {
		t.Error("nil store should return messages unchanged")
	}
}
// TestResolveMediaRefs_NonMediaRef verifies that an entry that is not a
// media:// ref (here an https URL) comes back unchanged.
func TestResolveMediaRefs_NonMediaRef(t *testing.T) {
	const url = "https://example.com/img.png"
	input := []providers.Message{
		{Role: "user", Content: "hi", Media: []string{url}},
	}
	out := resolveMediaRefs(input, media.NewFileMediaStore())
	if got := out[0].Media[0]; got != url {
		t.Error("non-media:// refs should be passed through unchanged")
	}
}
// TestResolveMediaRefs_ResolvesToBase64 stores a small .png file, resolves the
// returned ref, and expects exactly one Media entry that is a
// "data:image/png;base64," URL.
func TestResolveMediaRefs_ResolvesToBase64(t *testing.T) {
	store := media.NewFileMediaStore()
	imgPath := filepath.Join(t.TempDir(), "test.png")
	if err := os.WriteFile(imgPath, []byte("fake-png-data"), 0o644); err != nil {
		t.Fatal(err)
	}
	ref, err := store.Store(imgPath, media.MediaMeta{ContentType: "image/png"}, "test")
	if err != nil {
		t.Fatal(err)
	}
	msgs := []providers.Message{{Role: "user", Content: "describe", Media: []string{ref}}}
	result := resolveMediaRefs(msgs, store)
	if len(result[0].Media) != 1 {
		t.Fatalf("expected 1 resolved media, got %d", len(result[0].Media))
	}
	got := result[0].Media[0]
	if !strings.HasPrefix(got, "data:image/png;base64,") {
		// Truncate only when the value is long enough; an unconditional
		// [:40] would panic on a short string and hide the real failure.
		preview := got
		if len(preview) > 40 {
			preview = preview[:40]
		}
		t.Errorf("expected data URL, got %s", preview)
	}
}
// TestResolveMediaRefs_SkipsOversizedFile stores a file one byte larger than
// maxMediaFileSize and expects resolveMediaRefs to drop it from Media.
func TestResolveMediaRefs_SkipsOversizedFile(t *testing.T) {
	store := media.NewFileMediaStore()
	bigPath := filepath.Join(t.TempDir(), "big.jpg")
	// Only the reported file size matters here, so create an empty file and
	// extend it with Truncate instead of allocating and writing a 20 MB+ buffer.
	if err := os.WriteFile(bigPath, nil, 0o644); err != nil {
		t.Fatal(err)
	}
	if err := os.Truncate(bigPath, maxMediaFileSize+1); err != nil {
		t.Fatal(err)
	}
	ref, err := store.Store(bigPath, media.MediaMeta{ContentType: "image/jpeg"}, "test")
	if err != nil {
		t.Fatal(err)
	}
	msgs := []providers.Message{{Role: "user", Content: "hi", Media: []string{ref}}}
	result := resolveMediaRefs(msgs, store)
	if len(result[0].Media) != 0 {
		t.Error("oversized file should be skipped")
	}
}
// TestResolveMediaRefs_SkipsUnknownExtension stores a .txt file with an empty
// MediaMeta (no ContentType) and expects the ref to be dropped from Media.
func TestResolveMediaRefs_SkipsUnknownExtension(t *testing.T) {
	store := media.NewFileMediaStore()
	txtPath := filepath.Join(t.TempDir(), "readme.txt")

	writeErr := os.WriteFile(txtPath, []byte("hello"), 0o644)
	if writeErr != nil {
		t.Fatal(writeErr)
	}

	ref, storeErr := store.Store(txtPath, media.MediaMeta{}, "test")
	if storeErr != nil {
		t.Fatal(storeErr)
	}

	input := []providers.Message{
		{Role: "user", Content: "hi", Media: []string{ref}},
	}
	out := resolveMediaRefs(input, store)
	if remaining := len(out[0].Media); remaining != 0 {
		t.Error("unknown extension with no ContentType should be skipped")
	}
}
// TestResolveMediaRefs_DoesNotMutateOriginal resolves a stored ref and then
// checks that the input slice still holds the original media:// ref.
func TestResolveMediaRefs_DoesNotMutateOriginal(t *testing.T) {
	store := media.NewFileMediaStore()
	imgPath := filepath.Join(t.TempDir(), "test.jpg")
	if err := os.WriteFile(imgPath, []byte("data"), 0o644); err != nil {
		t.Fatal(err)
	}
	// Fail fast if the fixture cannot be stored; otherwise the assertion below
	// would run against an empty ref and report a misleading failure.
	ref, err := store.Store(imgPath, media.MediaMeta{ContentType: "image/jpeg"}, "test")
	if err != nil {
		t.Fatal(err)
	}
	original := []providers.Message{{Role: "user", Content: "hi", Media: []string{ref}}}
	resolveMediaRefs(original, store)
	if !strings.HasPrefix(original[0].Media[0], "media://") {
		t.Error("original message should not be mutated")
	}
}
+6
View File
@@ -235,6 +235,12 @@ func serializeMessages(messages []Message) []map[string]interface{} {
"role": m.Role,
"content": parts,
}
if m.ToolCallID != "" {
msg["tool_call_id"] = m.ToolCallID
}
if len(m.ToolCalls) > 0 {
msg["tool_calls"] = m.ToolCalls
}
if m.ReasoningContent != "" {
msg["reasoning_content"] = m.ReasoningContent
}
@@ -411,3 +411,73 @@ func TestProvider_FunctionalOptionRequestTimeoutNonPositive(t *testing.T) {
t.Fatalf("http timeout = %v, want %v", p.httpClient.Timeout, defaultRequestTimeout)
}
}
// TestSerializeMessages_PlainText serializes two messages without Media and
// expects two entries, the first carrying its content as a plain string.
func TestSerializeMessages_PlainText(t *testing.T) {
	input := []Message{
		{Role: "user", Content: "hello"},
		{Role: "assistant", Content: "hi"},
	}
	out := serializeMessages(input)
	if n := len(out); n != 2 {
		t.Fatalf("expected 2 messages, got %d", n)
	}
	if first := out[0]["content"]; first != "hello" {
		t.Errorf("expected plain string content, got %v", first)
	}
}
func TestSerializeMessages_WithMedia(t *testing.T) {
msgs := []Message{
{Role: "user", Content: "describe this", Media: []string{"data:image/png;base64,abc123"}},
}
result := serializeMessages(msgs)
if len(result) != 1 {
t.Fatalf("expected 1 message, got %d", len(result))
}
parts, ok := result[0]["content"].([]map[string]interface{})
if !ok {
t.Fatalf("expected content to be []map, got %T", result[0]["content"])
}
if len(parts) != 2 {
t.Fatalf("expected 2 parts (text + image), got %d", len(parts))
}
if parts[0]["type"] != "text" {
t.Errorf("expected first part type=text, got %v", parts[0]["type"])
}
if parts[1]["type"] != "image_url" {
t.Errorf("expected second part type=image_url, got %v", parts[1]["type"])
}
}
// TestSerializeMessages_WithMediaPreservesToolFields serializes a message that
// has Media together with ToolCallID, ToolCalls and ReasoningContent, and
// expects all three fields to be present in the serialized map.
func TestSerializeMessages_WithMediaPreservesToolFields(t *testing.T) {
	msgs := []Message{
		{
			Role:             "assistant",
			Content:          "result",
			Media:            []string{"data:image/png;base64,abc"},
			ToolCallID:       "call_123",
			ToolCalls:        []ToolCall{{ID: "tc_1", Type: "function", Function: &FunctionCall{Name: "test", Arguments: "{}"}}},
			ReasoningContent: "thinking...",
		},
	}
	result := serializeMessages(msgs)
	// Guard the index below so an empty result fails with a clear message
	// instead of an index-out-of-range panic.
	if len(result) != 1 {
		t.Fatalf("expected 1 message, got %d", len(result))
	}
	if result[0]["tool_call_id"] != "call_123" {
		t.Errorf("expected tool_call_id=call_123, got %v", result[0]["tool_call_id"])
	}
	if result[0]["tool_calls"] == nil {
		t.Error("expected tool_calls to be present")
	}
	if result[0]["reasoning_content"] != "thinking..." {
		t.Errorf("expected reasoning_content, got %v", result[0]["reasoning_content"])
	}
}
// TestSerializeMessages_EmptyMediaUsesPlainFormat serializes a message whose
// Media is a non-nil empty slice and expects its content to be a plain string.
func TestSerializeMessages_EmptyMediaUsesPlainFormat(t *testing.T) {
	msgs := []Message{
		{Role: "user", Content: "hello", Media: []string{}},
	}
	result := serializeMessages(msgs)
	// Guard the index below so an empty result fails with a clear message
	// instead of an index-out-of-range panic.
	if len(result) != 1 {
		t.Fatalf("expected 1 message, got %d", len(result))
	}
	if _, ok := result[0]["content"].(string); !ok {
		t.Errorf("empty Media should use plain string format, got %T", result[0]["content"])
	}
}