diff --git a/pkg/providers/openai_compat/provider.go b/pkg/providers/openai_compat/provider.go
index 3a18b8b16..ff9109e96 100644
--- a/pkg/providers/openai_compat/provider.go
+++ b/pkg/providers/openai_compat/provider.go
@@ -116,7 +116,7 @@ func (p *Provider) Chat(
 
 	requestBody := map[string]any{
 		"model":    model,
-		"messages": stripSystemParts(messages),
+		"messages": serializeMessages(messages),
 	}
 
 	if len(tools) > 0 {
@@ -296,19 +296,58 @@ type openaiMessage struct {
 	ToolCallID       string     `json:"tool_call_id,omitempty"`
 }
 
-// stripSystemParts converts []Message to []openaiMessage, dropping the
-// SystemParts field so it doesn't leak into the JSON payload sent to
-// OpenAI-compatible APIs (some strict endpoints reject unknown fields).
-func stripSystemParts(messages []Message) []openaiMessage {
-	out := make([]openaiMessage, len(messages))
-	for i, m := range messages {
-		out[i] = openaiMessage{
-			Role:             m.Role,
-			Content:          m.Content,
-			ReasoningContent: m.ReasoningContent,
-			ToolCalls:        m.ToolCalls,
-			ToolCallID:       m.ToolCallID,
+// serializeMessages converts internal Message values to the OpenAI wire format.
+// SystemParts is never emitted, since third-party endpoints do not know the field.
+//
+//   - Messages without Media are sent as openaiMessage with plain string content.
+//   - Messages with Media use multipart content: a text part (omitted when
+//     Content is empty) followed by one image_url part per Media entry.
+//   - ToolCallID, ToolCalls, and ReasoningContent are carried over in both forms.
+func serializeMessages(messages []Message) []any {
+	out := make([]any, 0, len(messages))
+	for _, m := range messages {
+		if len(m.Media) == 0 {
+			out = append(out, openaiMessage{
+				Role:             m.Role,
+				Content:          m.Content,
+				ReasoningContent: m.ReasoningContent,
+				ToolCalls:        m.ToolCalls,
+				ToolCallID:       m.ToolCallID,
+			})
+			continue
 		}
+
+		// Multipart content format for messages with media
+		parts := make([]map[string]any, 0, 1+len(m.Media))
+		if m.Content != "" {
+			parts = append(parts, map[string]any{
+				"type": "text",
+				"text": m.Content,
+			})
+		}
+		for _, mediaURL := range m.Media {
+			parts = append(parts, map[string]any{
+				"type": "image_url",
+				"image_url": map[string]any{
+					"url": mediaURL,
+				},
+			})
+		}
+
+		msg := map[string]any{
+			"role":    m.Role,
+			"content": parts,
+		}
+		if m.ToolCallID != "" {
+			msg["tool_call_id"] = m.ToolCallID
+		}
+		if len(m.ToolCalls) > 0 {
+			msg["tool_calls"] = m.ToolCalls
+		}
+		if m.ReasoningContent != "" {
+			msg["reasoning_content"] = m.ReasoningContent
+		}
+		out = append(out, msg)
 	}
 	return out
 }
diff --git a/pkg/providers/openai_compat/provider_test.go b/pkg/providers/openai_compat/provider_test.go
index 53b9e75ee..9d3b91a1a 100644
--- a/pkg/providers/openai_compat/provider_test.go
+++ b/pkg/providers/openai_compat/provider_test.go
@@ -5,8 +5,11 @@ import (
 	"net/http"
 	"net/http/httptest"
 	"net/url"
+	"strings"
 	"testing"
 	"time"
+
+	"github.com/sipeed/picoclaw/pkg/providers/protocoltypes"
 )
 
 func TestProviderChat_UsesMaxCompletionTokensForGLM(t *testing.T) {
@@ -416,3 +419,111 @@ func TestProvider_FunctionalOptionRequestTimeoutNonPositive(t *testing.T) {
 		t.Fatalf("http timeout = %v, want %v", p.httpClient.Timeout, defaultRequestTimeout)
 	}
 }
+
+// serializeToJSON runs serializeMessages, marshals the result, and decodes it
+// into generic maps so tests can assert on the exact wire format. It fails the
+// test on any encoding error or when the message count changes, so callers can
+// index the returned slice without risking a panic.
+func serializeToJSON(t *testing.T, messages []protocoltypes.Message) ([]byte, []map[string]any) {
+	t.Helper()
+
+	data, err := json.Marshal(serializeMessages(messages))
+	if err != nil {
+		t.Fatalf("marshal serialized messages: %v", err)
+	}
+
+	var msgs []map[string]any
+	if err := json.Unmarshal(data, &msgs); err != nil {
+		t.Fatalf("unmarshal serialized messages: %v", err)
+	}
+	if len(msgs) != len(messages) {
+		t.Fatalf("got %d serialized messages, want %d", len(msgs), len(messages))
+	}
+	return data, msgs
+}
+
+func TestSerializeMessages_PlainText(t *testing.T) {
+	messages := []protocoltypes.Message{
+		{Role: "user", Content: "hello"},
+		{Role: "assistant", Content: "hi", ReasoningContent: "thinking..."},
+	}
+	_, msgs := serializeToJSON(t, messages)
+
+	if msgs[0]["content"] != "hello" {
+		t.Fatalf("expected plain string content, got %v", msgs[0]["content"])
+	}
+	if msgs[1]["reasoning_content"] != "thinking..." {
+		t.Fatalf("reasoning_content not preserved, got %v", msgs[1]["reasoning_content"])
+	}
+}
+
+func TestSerializeMessages_WithMedia(t *testing.T) {
+	messages := []protocoltypes.Message{
+		{Role: "user", Content: "describe this", Media: []string{"data:image/png;base64,abc123"}},
+	}
+	_, msgs := serializeToJSON(t, messages)
+
+	content, ok := msgs[0]["content"].([]any)
+	if !ok {
+		t.Fatalf("expected array content for media message, got %T", msgs[0]["content"])
+	}
+	if len(content) != 2 {
+		t.Fatalf("expected 2 content parts, got %d", len(content))
+	}
+
+	// Checked assertions keep a shape regression a test failure, not a panic.
+	textPart, ok := content[0].(map[string]any)
+	if !ok {
+		t.Fatalf("expected object for text part, got %T", content[0])
+	}
+	if textPart["type"] != "text" || textPart["text"] != "describe this" {
+		t.Fatalf("text part mismatch: %v", textPart)
+	}
+
+	imgPart, ok := content[1].(map[string]any)
+	if !ok {
+		t.Fatalf("expected object for image part, got %T", content[1])
+	}
+	if imgPart["type"] != "image_url" {
+		t.Fatalf("expected image_url type, got %v", imgPart["type"])
+	}
+	imgURL, ok := imgPart["image_url"].(map[string]any)
+	if !ok {
+		t.Fatalf("expected object for image_url, got %T", imgPart["image_url"])
+	}
+	if imgURL["url"] != "data:image/png;base64,abc123" {
+		t.Fatalf("image url mismatch: %v", imgURL["url"])
+	}
+}
+
+func TestSerializeMessages_MediaWithToolCallID(t *testing.T) {
+	messages := []protocoltypes.Message{
+		{Role: "tool", Content: "image result", Media: []string{"data:image/png;base64,xyz"}, ToolCallID: "call_1"},
+	}
+	_, msgs := serializeToJSON(t, messages)
+
+	if msgs[0]["tool_call_id"] != "call_1" {
+		t.Fatalf("tool_call_id not preserved with media, got %v", msgs[0]["tool_call_id"])
+	}
+	// Content should be multipart array
+	if _, ok := msgs[0]["content"].([]any); !ok {
+		t.Fatalf("expected array content, got %T", msgs[0]["content"])
+	}
+}
+
+func TestSerializeMessages_StripsSystemParts(t *testing.T) {
+	messages := []protocoltypes.Message{
+		{
+			Role:    "system",
+			Content: "you are helpful",
+			SystemParts: []protocoltypes.ContentBlock{
+				{Type: "text", Text: "you are helpful"},
+			},
+		},
+	}
+	data, _ := serializeToJSON(t, messages)
+
+	if strings.Contains(string(data), "system_parts") {
+		t.Fatal("system_parts should not appear in serialized output")
+	}
+}