Merge branch 'sipeed:main' into main

This commit is contained in:
Phạm Minh Đạt
2026-03-24 09:58:46 +07:00
committed by GitHub
238 changed files with 36794 additions and 5839 deletions
+28 -16
View File
@@ -188,17 +188,23 @@ func buildRequestBody(
case "user":
if msg.ToolCallID != "" {
// Tool result message
content := []map[string]any{
{
"type": "tool_result",
"tool_use_id": msg.ToolCallID,
"content": msg.Content,
},
// Tool result message — merge into previous user message if it contains tool_results
toolResultBlock := map[string]any{
"type": "tool_result",
"tool_use_id": msg.ToolCallID,
"content": msg.Content,
}
if len(apiMessages) > 0 {
if prev, ok := apiMessages[len(apiMessages)-1].(map[string]any); ok && prev["role"] == "user" {
if content, ok := prev["content"].([]map[string]any); ok {
prev["content"] = append(content, toolResultBlock)
continue
}
}
}
apiMessages = append(apiMessages, map[string]any{
"role": "user",
"content": content,
"content": []map[string]any{toolResultBlock},
})
} else {
// Regular user message
@@ -246,17 +252,23 @@ func buildRequestBody(
})
case "tool":
// Tool result (alternative format)
content := []map[string]any{
{
"type": "tool_result",
"tool_use_id": msg.ToolCallID,
"content": msg.Content,
},
// Tool result (alternative format) — merge into previous user message if it contains tool_results
toolResultBlock := map[string]any{
"type": "tool_result",
"tool_use_id": msg.ToolCallID,
"content": msg.Content,
}
if len(apiMessages) > 0 {
if prev, ok := apiMessages[len(apiMessages)-1].(map[string]any); ok && prev["role"] == "user" {
if content, ok := prev["content"].([]map[string]any); ok {
prev["content"] = append(content, toolResultBlock)
continue
}
}
}
apiMessages = append(apiMessages, map[string]any{
"role": "user",
"content": content,
"content": []map[string]any{toolResultBlock},
})
}
}
@@ -562,6 +562,96 @@ func TestBuildRequestBodyEdgeCases(t *testing.T) {
}
}
// TestBuildRequestBody_ConsecutiveToolResultsMerged checks that back-to-back
// role "tool" messages end up in one "user" API message that carries one
// tool_result block per result, in the order the results were supplied.
func TestBuildRequestBody_ConsecutiveToolResultsMerged(t *testing.T) {
	conversation := []Message{
		{Role: "user", Content: "Use tools"},
		{Role: "assistant", Content: "", ToolCalls: []ToolCall{
			{ID: "t1", Name: "tool_a", Arguments: map[string]any{"x": 1}},
			{ID: "t2", Name: "tool_b", Arguments: map[string]any{"y": 2}},
		}},
		{Role: "tool", ToolCallID: "t1", Content: "result1"},
		{Role: "tool", ToolCallID: "t2", Content: "result2"},
	}

	body, err := buildRequestBody(conversation, nil, "test-model", map[string]any{"max_tokens": 8192})
	if err != nil {
		t.Fatalf("buildRequestBody() error: %v", err)
	}

	sent, isSlice := body["messages"].([]any)
	if !isSlice {
		t.Fatalf("messages is not []any")
	}

	// Expected layout: user, assistant, user (holding both tool results).
	if len(sent) != 3 {
		// Dump what was produced to make a layout mismatch easy to diagnose.
		for idx, m := range sent {
			t.Logf("message[%d]: %+v", idx, m)
		}
		t.Fatalf("expected 3 API messages, got %d", len(sent))
	}

	merged, isMap := sent[2].(map[string]any)
	if !isMap {
		t.Fatalf("tool result message is not map[string]any")
	}
	if role := merged["role"]; role != "user" {
		t.Errorf("expected role 'user', got %v", role)
	}

	blocks, isBlockSlice := merged["content"].([]map[string]any)
	if !isBlockSlice {
		t.Fatalf("content is not []map[string]any: %T", merged["content"])
	}
	if len(blocks) != 2 {
		t.Fatalf("expected 2 tool_result blocks, got %d", len(blocks))
	}

	if id := blocks[0]["tool_use_id"]; id != "t1" {
		t.Errorf("first tool_result tool_use_id = %v, want t1", id)
	}
	if id := blocks[1]["tool_use_id"]; id != "t2" {
		t.Errorf("second tool_result tool_use_id = %v, want t2", id)
	}
}
// TestBuildRequestBody_UserToolResultsMerged checks that consecutive tool
// results expressed as role "user" messages with a ToolCallID are merged into
// a single "user" API message, preserving the order of the results.
func TestBuildRequestBody_UserToolResultsMerged(t *testing.T) {
	messages := []Message{
		{Role: "user", Content: "Use tools"},
		{Role: "assistant", Content: "", ToolCalls: []ToolCall{
			{ID: "t1", Name: "tool_a", Arguments: map[string]any{"x": 1}},
			{ID: "t2", Name: "tool_b", Arguments: map[string]any{"y": 2}},
		}},
		{Role: "user", ToolCallID: "t1", Content: "result1"},
		{Role: "user", ToolCallID: "t2", Content: "result2"},
	}

	got, err := buildRequestBody(messages, nil, "test-model", map[string]any{"max_tokens": 8192})
	if err != nil {
		t.Fatalf("buildRequestBody() error: %v", err)
	}

	apiMessages, ok := got["messages"].([]any)
	if !ok {
		t.Fatalf("messages is not []any")
	}

	// Expect: user, assistant, user (merged tool results)
	if len(apiMessages) != 3 {
		for i, m := range apiMessages {
			t.Logf("message[%d]: %+v", i, m)
		}
		t.Fatalf("expected 3 API messages, got %d", len(apiMessages))
	}

	// Checked assertion: an unexpected element type should fail the test with
	// a readable message rather than panic.
	toolResultMsg, ok := apiMessages[2].(map[string]any)
	if !ok {
		t.Fatalf("tool result message is not map[string]any: %T", apiMessages[2])
	}
	if toolResultMsg["role"] != "user" {
		t.Errorf("expected role 'user', got %v", toolResultMsg["role"])
	}

	content, ok := toolResultMsg["content"].([]map[string]any)
	if !ok {
		t.Fatalf("content is not []map[string]any: %T", toolResultMsg["content"])
	}
	if len(content) != 2 {
		t.Fatalf("expected 2 tool_result blocks, got %d", len(content))
	}

	// The merged blocks must keep the order in which the results arrived.
	for i, wantID := range []string{"t1", "t2"} {
		if content[i]["tool_use_id"] != wantID {
			t.Errorf("tool_result[%d] tool_use_id = %v, want %s", i, content[i]["tool_use_id"], wantID)
		}
	}
}
// TestParseResponseBodyEdgeCases tests edge cases for parseResponseBody.
func TestParseResponseBodyEdgeCases(t *testing.T) {
tests := []struct {
+582
View File
@@ -0,0 +1,582 @@
//go:build bedrock
// PicoClaw - Ultra-lightweight personal AI agent
// License: MIT
//
// Copyright (c) 2026 PicoClaw contributors
// Package bedrock implements the LLM provider interface for AWS Bedrock.
// It uses the Bedrock Runtime Converse API for unified access to multiple
// model families (Claude, Llama, Mistral, etc.) with tool/function calling support.
package bedrock
import (
"context"
"encoding/base64"
"encoding/json"
"fmt"
"log"
"math"
"strings"
"time"
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/config"
"github.com/aws/aws-sdk-go-v2/service/bedrockruntime"
"github.com/aws/aws-sdk-go-v2/service/bedrockruntime/document"
"github.com/aws/aws-sdk-go-v2/service/bedrockruntime/types"
"github.com/sipeed/picoclaw/pkg/providers/common"
"github.com/sipeed/picoclaw/pkg/providers/protocoltypes"
)
// Package-local aliases for the shared protocoltypes definitions, so the rest
// of this package can name them without the protocoltypes qualifier.
type (
ToolCall = protocoltypes.ToolCall
FunctionCall = protocoltypes.FunctionCall
LLMResponse = protocoltypes.LLMResponse
UsageInfo = protocoltypes.UsageInfo
Message = protocoltypes.Message
ToolDefinition = protocoltypes.ToolDefinition
ToolFunctionDefinition = protocoltypes.ToolFunctionDefinition
)
// Provider implements the LLM provider interface for AWS Bedrock.
// Instances are built by NewProvider.
type Provider struct {
// client is the Bedrock Runtime client that Chat uses to call Converse.
client *bedrockruntime.Client
// region is the region resolved from the loaded AWS config; exposed by Region.
region string
// requestTimeout is the per-request timeout; when it is <= 0, Chat falls
// back to common.DefaultRequestTimeout.
requestTimeout time.Duration
}
// Option configures the Bedrock Provider.
// Each Option mutates the providerConfig that NewProvider assembles before it
// loads the AWS configuration and creates the client.
type Option func(*providerConfig)
// providerConfig collects the settings supplied through Option values.
// Empty string fields mean "not set"; NewProvider only forwards non-empty ones.
type providerConfig struct {
// region is the AWS region passed to the AWS config loader (see WithRegion).
region string
// profile is the shared-config profile name (see WithProfile).
profile string
// baseEndpoint overrides the Bedrock Runtime endpoint (see WithBaseEndpoint).
baseEndpoint string
// requestTimeout is copied onto the Provider (see WithRequestTimeout).
requestTimeout time.Duration
}
// WithRegion returns an Option that records region as the AWS region to use
// for Bedrock requests.
func WithRegion(region string) Option {
	return func(cfg *providerConfig) {
		cfg.region = region
	}
}
// WithProfile returns an Option that selects the named AWS profile as the
// source of credentials.
func WithProfile(profile string) Option {
	return func(cfg *providerConfig) {
		cfg.profile = profile
	}
}
// WithBaseEndpoint returns an Option that overrides the Bedrock endpoint URL,
// for example https://bedrock-runtime.us-east-1.amazonaws.com.
func WithBaseEndpoint(endpoint string) Option {
	return func(cfg *providerConfig) {
		cfg.baseEndpoint = endpoint
	}
}
// WithRequestTimeout returns an Option that sets how long a single Bedrock
// API request may take.
func WithRequestTimeout(timeout time.Duration) Option {
	return func(cfg *providerConfig) {
		cfg.requestTimeout = timeout
	}
}
// NewProvider builds a Provider backed by the Bedrock Runtime client.
//
// Credentials and region come from the default AWS configuration chain
// (environment variables, shared config, IAM roles, etc.), optionally adjusted
// by the supplied options. An error is returned when the AWS configuration
// cannot be loaded or when no region could be resolved.
func NewProvider(ctx context.Context, opts ...Option) (*Provider, error) {
	settings := new(providerConfig)
	for _, apply := range opts {
		apply(settings)
	}

	// Only forward explicitly configured values to the AWS config loader so
	// that unset ones are still discovered from the environment.
	var loadOpts []func(*config.LoadOptions) error
	if settings.region != "" {
		loadOpts = append(loadOpts, config.WithRegion(settings.region))
	}
	if settings.profile != "" {
		loadOpts = append(loadOpts, config.WithSharedConfigProfile(settings.profile))
	}

	awsCfg, err := config.LoadDefaultConfig(ctx, loadOpts...)
	if err != nil {
		return nil, fmt.Errorf("loading AWS config: %w", err)
	}

	// Request signing needs a region, so fail early when none was found.
	if awsCfg.Region == "" {
		return nil, fmt.Errorf(
			"AWS region not configured: set AWS_REGION, AWS_DEFAULT_REGION, or use WithRegion option",
		)
	}

	var clientOpts []func(*bedrockruntime.Options)
	if settings.baseEndpoint != "" {
		clientOpts = append(clientOpts, func(o *bedrockruntime.Options) {
			o.BaseEndpoint = aws.String(settings.baseEndpoint)
		})
	}

	return &Provider{
		client:         bedrockruntime.NewFromConfig(awsCfg, clientOpts...),
		region:         awsCfg.Region,
		requestTimeout: settings.requestTimeout,
	}, nil
}
// Chat sends the conversation to AWS Bedrock through the Converse API and
// returns the parsed model response.
//
// System messages are passed in the request's System field; everything else
// becomes the message list. The options "max_tokens" and "temperature" are
// honoured when present. Tools are only attached when at least one of them
// converts to a valid Bedrock tool.
func (p *Provider) Chat(
	ctx context.Context,
	messages []Message,
	tools []ToolDefinition,
	model string,
	options map[string]any,
) (*LLMResponse, error) {
	// A deadline already set by the caller wins; otherwise apply the
	// configured timeout, or the shared default when none was configured.
	if _, deadlineSet := ctx.Deadline(); !deadlineSet {
		timeout := p.requestTimeout
		if timeout <= 0 {
			timeout = common.DefaultRequestTimeout
		}
		var cancel context.CancelFunc
		ctx, cancel = context.WithTimeout(ctx, timeout)
		defer cancel()
	}

	conversation, system := convertMessages(messages)
	request := &bedrockruntime.ConverseInput{
		ModelId:  aws.String(model),
		Messages: conversation,
	}
	if len(system) > 0 {
		request.System = system
	}

	// The inference configuration is created lazily so that it is omitted
	// entirely when the caller supplied neither option.
	var inference *types.InferenceConfiguration
	if limit, ok := common.AsInt(options["max_tokens"]); ok && limit > 0 {
		// Clamp to the int32 range before narrowing to avoid overflow.
		if limit > math.MaxInt32 {
			limit = math.MaxInt32
		}
		inference = &types.InferenceConfiguration{MaxTokens: aws.Int32(int32(limit))}
	}
	if temperature, ok := common.AsFloat(options["temperature"]); ok {
		if inference == nil {
			inference = &types.InferenceConfiguration{}
		}
		inference.Temperature = aws.Float32(float32(temperature))
	}
	if inference != nil {
		request.InferenceConfig = inference
	}

	if len(tools) > 0 {
		if converted := convertTools(tools); len(converted.Tools) > 0 {
			request.ToolConfig = converted
		}
	}

	response, err := p.client.Converse(ctx, request)
	if err != nil {
		return nil, fmt.Errorf("bedrock converse: %w", err)
	}
	return parseResponse(response)
}
// GetDefaultModel always returns the empty string: this provider has no
// built-in default because Bedrock models are user-configured.
func (p *Provider) GetDefaultModel() string {
	return ""
}
// Region reports the AWS region this Provider was created with.
func (p *Provider) Region() string {
	return p.region
}
// convertMessages translates internal messages into the Bedrock Converse
// representation. It returns the conversation messages and, separately, the
// system prompts (in the order they appeared).
//
// Runs of consecutive tool-result messages are folded into one user message
// holding one tool result block per message, because Bedrock expects all tool
// results for an assistant turn in a single message. Messages with an
// unrecognised role are dropped.
func convertMessages(messages []Message) ([]types.Message, []types.SystemContentBlock) {
	var (
		converted []types.Message
		system    []types.SystemContentBlock
	)

	// carriesToolResult reports whether m is a tool result: it must reference
	// a tool call and have the "tool" or "user" role.
	carriesToolResult := func(m Message) bool {
		if m.ToolCallID == "" {
			return false
		}
		return m.Role == "tool" || m.Role == "user"
	}

	// toolResultBlock wraps the message text as a tool result content block.
	toolResultBlock := func(m Message) types.ContentBlock {
		return &types.ContentBlockMemberToolResult{
			Value: types.ToolResultBlock{
				ToolUseId: aws.String(m.ToolCallID),
				Content: []types.ToolResultContentBlock{
					&types.ToolResultContentBlockMemberText{Value: m.Content},
				},
			},
		}
	}

	for idx := 0; idx < len(messages); {
		current := messages[idx]

		// System messages never enter the conversation list.
		if current.Role == "system" {
			system = append(system, &types.SystemContentBlockMemberText{Value: current.Content})
			idx++
			continue
		}

		// Consume the whole run of adjacent tool results in one go.
		if carriesToolResult(current) {
			var results []types.ContentBlock
			for ; idx < len(messages) && carriesToolResult(messages[idx]); idx++ {
				results = append(results, toolResultBlock(messages[idx]))
			}
			converted = append(converted, types.Message{
				Role:    types.ConversationRoleUser,
				Content: results,
			})
			continue
		}

		switch current.Role {
		case "user", "tool":
			// Reaching here means ToolCallID is empty, so a "tool" message
			// is treated like a regular user message.
			converted = append(converted, types.Message{
				Role:    types.ConversationRoleUser,
				Content: buildUserContent(current),
			})
		case "assistant":
			converted = append(converted, types.Message{
				Role:    types.ConversationRoleAssistant,
				Content: buildAssistantContent(current),
			})
		}
		idx++
	}

	return converted, system
}
// buildUserContent assembles the Bedrock content blocks for a user message:
// the text (when non-empty) followed by one image block per usable entry in
// msg.Media.
//
// Only base64 "data:image/..." URLs of type jpeg/jpg, png, gif or webp are
// used; other entries are skipped. Oversized or undecodable images are
// skipped with a log line. When nothing was produced, a single empty text
// block is returned so the result is never empty.
func buildUserContent(msg Message) []types.ContentBlock {
	// Upper bound on a decoded image; the original author notes Bedrock has a
	// ~20MB request limit and caps images at 10MB.
	const maxImageSize = 10 * 1024 * 1024

	var blocks []types.ContentBlock
	if msg.Content != "" {
		blocks = append(blocks, &types.ContentBlockMemberText{Value: msg.Content})
	}

	for _, dataURL := range msg.Media {
		if !strings.HasPrefix(dataURL, "data:image/") {
			continue
		}

		// Split "data:image/<subtype>[;params]" from the payload.
		header, payload, hasPayload := strings.Cut(dataURL, ",")
		if !hasPayload {
			continue
		}

		// The subtype is whatever follows "data:image/" up to the first ";".
		subtype, _, _ := strings.Cut(strings.TrimPrefix(header, "data:image/"), ";")

		// Only base64-encoded payloads are supported.
		if !strings.Contains(header, ";base64") {
			continue
		}

		var format types.ImageFormat
		switch subtype {
		case "jpeg", "jpg":
			format = types.ImageFormatJpeg
		case "png":
			format = types.ImageFormatPng
		case "gif":
			format = types.ImageFormatGif
		case "webp":
			format = types.ImageFormatWebp
		default:
			continue
		}

		// Check the size bound before decoding to avoid a large allocation.
		if size := base64.StdEncoding.DecodedLen(len(payload)); size > maxImageSize {
			log.Printf("bedrock: skipping image exceeding size limit (%d bytes > %d)", size, maxImageSize)
			continue
		}

		raw, err := base64.StdEncoding.DecodeString(payload)
		if err != nil {
			log.Printf("bedrock: failed to decode base64 image data: %v", err)
			continue
		}

		blocks = append(blocks, &types.ContentBlockMemberImage{
			Value: types.ImageBlock{
				Format: format,
				Source: &types.ImageSourceMemberBytes{Value: raw},
			},
		})
	}

	// Never return an empty slice: fall back to a single empty text block.
	if len(blocks) == 0 {
		blocks = append(blocks, &types.ContentBlockMemberText{Value: ""})
	}
	return blocks
}
// buildAssistantContent builds Bedrock content blocks for an assistant
// message: the text (when non-empty) followed by one tool-use block per valid
// tool call.
//
// A tool call is skipped, with a log line, when its ID or its resolved name is
// blank. The name and arguments fall back to tc.Function when tc.Name /
// tc.Arguments are unset; arguments that cannot be parsed are replaced by an
// empty object. When nothing was produced, a single empty text block is
// returned so the result is never empty.
func buildAssistantContent(msg Message) []types.ContentBlock {
	var content []types.ContentBlock

	// Add text content if present
	if msg.Content != "" {
		content = append(content, &types.ContentBlockMemberText{
			Value: msg.Content,
		})
	}

	// Add tool use blocks
	for _, tc := range msg.ToolCalls {
		// Validate tool call ID - Bedrock requires non-empty ToolUseId
		if strings.TrimSpace(tc.ID) == "" {
			log.Printf("bedrock: skipping tool call with empty ID (name: %q)", tc.Name)
			continue
		}

		// Resolve tool name: prefer tc.Name, fallback to tc.Function.Name
		// (tc.Name/tc.Arguments are json:"-" and may be empty when from JSON)
		toolName := tc.Name
		if toolName == "" && tc.Function != nil {
			toolName = tc.Function.Name
		}
		if strings.TrimSpace(toolName) == "" {
			// Log the drop, mirroring the empty-ID case above, so a missing
			// tool call can be traced instead of vanishing silently.
			log.Printf("bedrock: skipping tool call with empty name (id: %q)", tc.ID)
			continue
		}

		// Resolve arguments: prefer tc.Arguments, fallback to parsing tc.Function.Arguments
		args := tc.Arguments
		if args == nil && tc.Function != nil && tc.Function.Arguments != "" {
			if err := json.Unmarshal([]byte(tc.Function.Arguments), &args); err != nil {
				log.Printf("bedrock: failed to parse Function.Arguments for tool %q: %v", toolName, err)
				args = map[string]any{}
			}
		}
		// Covers both "no arguments supplied" and a JSON "null" payload.
		if args == nil {
			args = map[string]any{}
		}

		content = append(content, &types.ContentBlockMemberToolUse{
			Value: types.ToolUseBlock{
				ToolUseId: aws.String(tc.ID),
				Name:      aws.String(toolName),
				Input:     document.NewLazyDocument(args),
			},
		})
	}

	// Bedrock requires at least one content block; add empty text if needed
	if len(content) == 0 {
		content = append(content, &types.ContentBlockMemberText{Value: ""})
	}

	return content
}
// convertTools converts tool definitions to the Bedrock tool configuration.
//
// Tools whose name is blank are skipped. A tool without parameters gets a
// minimal empty object schema. The description is only set when it is
// non-empty: the Bedrock ToolSpecification description is optional but must
// have at least one character when present, so an empty description is
// omitted rather than sent as "".
//
// The returned configuration is never nil, but its Tools slice may be empty;
// callers should check that before attaching it to a request.
func convertTools(tools []ToolDefinition) *types.ToolConfiguration {
	bedrockTools := make([]types.Tool, 0, len(tools))

	for _, tool := range tools {
		// Skip tools with empty names
		if strings.TrimSpace(tool.Function.Name) == "" {
			continue
		}

		// Ensure parameters is not nil - default to minimal object schema
		params := tool.Function.Parameters
		if params == nil {
			params = map[string]any{
				"type":       "object",
				"properties": map[string]any{},
			}
		}

		spec := types.ToolSpecification{
			Name: aws.String(tool.Function.Name),
			InputSchema: &types.ToolInputSchemaMemberJson{
				Value: document.NewLazyDocument(params),
			},
		}
		// Leave Description nil when there is nothing to say.
		if tool.Function.Description != "" {
			spec.Description = aws.String(tool.Function.Description)
		}

		bedrockTools = append(bedrockTools, &types.ToolMemberToolSpec{Value: spec})
	}

	return &types.ToolConfiguration{
		Tools: bedrockTools,
	}
}
// parseResponse converts Bedrock Converse output to an LLMResponse.
//
// Text blocks are concatenated into Content and each tool-use block becomes a
// ToolCall with both the decoded Arguments map and the JSON-encoded
// Function.Arguments populated. Tool input that cannot be decoded is logged
// and replaced by an empty argument object. The stop reason is mapped to
// "stop", "tool_calls", "length" or "content_filter" ("stop" for any other
// value). Usage is nil when the output carries no usage data.
//
// An error is returned only when output itself is nil.
func parseResponse(output *bedrockruntime.ConverseOutput) (*LLMResponse, error) {
	// Guard against a nil output so a misbehaving caller or client yields an
	// error instead of a nil-pointer panic.
	if output == nil {
		return nil, fmt.Errorf("bedrock converse: nil response output")
	}

	var content strings.Builder
	toolCalls := make([]ToolCall, 0)

	// Process output content blocks; any other output variant (or a nil
	// Output) simply yields no content.
	if msgOutput, ok := output.Output.(*types.ConverseOutputMemberMessage); ok {
		for _, block := range msgOutput.Value.Content {
			switch b := block.(type) {
			case *types.ContentBlockMemberText:
				content.WriteString(b.Value)

			case *types.ContentBlockMemberToolUse:
				toolName := aws.ToString(b.Value.Name)
				toolID := aws.ToString(b.Value.ToolUseId)

				// Unmarshal the document interface to a map
				args := make(map[string]any)
				if b.Value.Input != nil {
					if err := b.Value.Input.UnmarshalSmithyDocument(&args); err != nil {
						log.Printf("bedrock: failed to unmarshal tool input for tool %q (id %q): %v",
							toolName,
							toolID,
							err,
						)
						// Discard any partially decoded state.
						args = make(map[string]any)
					}
				}

				// Serialize arguments to JSON string for FunctionCall
				argsJSON, err := json.Marshal(args)
				if err != nil {
					log.Printf("bedrock: failed to marshal tool arguments for tool %q (id %q): %v",
						toolName,
						toolID,
						err,
					)
					argsJSON = []byte("{}")
				}

				toolCalls = append(toolCalls, ToolCall{
					ID:        toolID,
					Name:      toolName,
					Arguments: args,
					Function: &FunctionCall{
						Name:      toolName,
						Arguments: string(argsJSON),
					},
				})
			}
		}
	}

	// Map stop reason; anything not listed keeps the "stop" default.
	finishReason := "stop"
	switch output.StopReason {
	case types.StopReasonToolUse:
		finishReason = "tool_calls"
	case types.StopReasonMaxTokens:
		finishReason = "length"
	case types.StopReasonEndTurn, types.StopReasonStopSequence:
		finishReason = "stop"
	case types.StopReasonContentFiltered:
		finishReason = "content_filter"
	}

	// Build usage info; the total is computed from the two counters.
	var usage *UsageInfo
	if output.Usage != nil {
		prompt := int(aws.ToInt32(output.Usage.InputTokens))
		completion := int(aws.ToInt32(output.Usage.OutputTokens))
		usage = &UsageInfo{
			PromptTokens:     prompt,
			CompletionTokens: completion,
			TotalTokens:      prompt + completion,
		}
	}

	return &LLMResponse{
		Content:      content.String(),
		ToolCalls:    toolCalls,
		FinishReason: finishReason,
		Usage:        usage,
	}, nil
}
@@ -0,0 +1,541 @@
//go:build bedrock
// PicoClaw - Ultra-lightweight personal AI agent
// License: MIT
//
// Copyright (c) 2026 PicoClaw contributors
package bedrock
import (
"testing"
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/service/bedrockruntime"
"github.com/aws/aws-sdk-go-v2/service/bedrockruntime/document"
"github.com/aws/aws-sdk-go-v2/service/bedrockruntime/types"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/sipeed/picoclaw/pkg/providers/protocoltypes"
)
// TestConvertMessages_SystemPrompts verifies that a system message is routed
// to the system prompt list and does not appear among the conversation messages.
func TestConvertMessages_SystemPrompts(t *testing.T) {
	messages := []Message{
		{Role: "system", Content: "You are a helpful assistant."},
		{Role: "user", Content: "Hello"},
	}

	bedrockMsgs, systemPrompts := convertMessages(messages)

	// require, not assert: a wrong length must stop the test before the
	// indexing below can panic.
	require.Len(t, systemPrompts, 1)
	require.Len(t, bedrockMsgs, 1)

	// Check system prompt
	textBlock, ok := systemPrompts[0].(*types.SystemContentBlockMemberText)
	require.True(t, ok)
	assert.Equal(t, "You are a helpful assistant.", textBlock.Value)

	// Check user message
	assert.Equal(t, types.ConversationRoleUser, bedrockMsgs[0].Role)
}
// TestConvertMessages_UserMessage verifies that a plain user message becomes
// one user-role Bedrock message whose first block carries the text.
func TestConvertMessages_UserMessage(t *testing.T) {
	messages := []Message{
		{Role: "user", Content: "What is 2+2?"},
	}

	bedrockMsgs, systemPrompts := convertMessages(messages)

	assert.Empty(t, systemPrompts)
	// require, not assert: stop before indexing if the shape is wrong.
	require.Len(t, bedrockMsgs, 1)
	assert.Equal(t, types.ConversationRoleUser, bedrockMsgs[0].Role)

	require.NotEmpty(t, bedrockMsgs[0].Content)
	textBlock, ok := bedrockMsgs[0].Content[0].(*types.ContentBlockMemberText)
	require.True(t, ok)
	assert.Equal(t, "What is 2+2?", textBlock.Value)
}
// TestConvertMessages_AssistantMessage verifies that an assistant message
// becomes one assistant-role Bedrock message whose first block carries the text.
func TestConvertMessages_AssistantMessage(t *testing.T) {
	messages := []Message{
		{Role: "assistant", Content: "The answer is 4."},
	}

	bedrockMsgs, _ := convertMessages(messages)

	// require, not assert: stop before indexing if the shape is wrong.
	require.Len(t, bedrockMsgs, 1)
	assert.Equal(t, types.ConversationRoleAssistant, bedrockMsgs[0].Role)

	require.NotEmpty(t, bedrockMsgs[0].Content)
	textBlock, ok := bedrockMsgs[0].Content[0].(*types.ContentBlockMemberText)
	require.True(t, ok)
	assert.Equal(t, "The answer is 4.", textBlock.Value)
}
// TestConvertMessages_ToolResult verifies that a single role "tool" message
// becomes a user-role Bedrock message holding a tool result block that
// references the original tool call ID.
func TestConvertMessages_ToolResult(t *testing.T) {
	messages := []Message{
		{Role: "tool", Content: "Result from tool", ToolCallID: "call_123"},
	}

	bedrockMsgs, _ := convertMessages(messages)

	// require, not assert: stop before indexing if the shape is wrong.
	require.Len(t, bedrockMsgs, 1)
	assert.Equal(t, types.ConversationRoleUser, bedrockMsgs[0].Role)

	require.NotEmpty(t, bedrockMsgs[0].Content)
	toolResult, ok := bedrockMsgs[0].Content[0].(*types.ContentBlockMemberToolResult)
	require.True(t, ok)
	assert.Equal(t, "call_123", aws.ToString(toolResult.Value.ToolUseId))
}
// TestConvertMessages_MultipleToolResultsMerged verifies that when an
// assistant makes several tool calls, the consecutive tool results are merged
// into a single user message, in order.
func TestConvertMessages_MultipleToolResultsMerged(t *testing.T) {
	messages := []Message{
		{Role: "user", Content: "What's the weather in NYC and LA?"},
		{
			Role:    "assistant",
			Content: "Let me check both cities.",
			ToolCalls: []protocoltypes.ToolCall{
				{ID: "call_nyc", Name: "get_weather", Arguments: map[string]any{"city": "NYC"}},
				{ID: "call_la", Name: "get_weather", Arguments: map[string]any{"city": "LA"}},
			},
		},
		{Role: "tool", Content: "NYC: 72°F, sunny", ToolCallID: "call_nyc"},
		{Role: "tool", Content: "LA: 85°F, clear", ToolCallID: "call_la"},
	}

	bedrockMsgs, _ := convertMessages(messages)

	// Should be: user message, assistant message, merged tool results (single
	// user message). require, not assert, so the indexing below cannot panic.
	require.Len(t, bedrockMsgs, 3)

	// First message: user
	assert.Equal(t, types.ConversationRoleUser, bedrockMsgs[0].Role)

	// Second message: assistant with tool calls
	assert.Equal(t, types.ConversationRoleAssistant, bedrockMsgs[1].Role)

	// Third message: merged tool results in single user message
	assert.Equal(t, types.ConversationRoleUser, bedrockMsgs[2].Role)
	require.Len(t, bedrockMsgs[2].Content, 2) // Both tool results in one message

	// Verify both tool results are present
	result1, ok := bedrockMsgs[2].Content[0].(*types.ContentBlockMemberToolResult)
	require.True(t, ok)
	assert.Equal(t, "call_nyc", aws.ToString(result1.Value.ToolUseId))

	result2, ok := bedrockMsgs[2].Content[1].(*types.ContentBlockMemberToolResult)
	require.True(t, ok)
	assert.Equal(t, "call_la", aws.ToString(result2.Value.ToolUseId))
}
// TestConvertMessages_AssistantWithToolCalls verifies that an assistant
// message with text and one tool call yields a text block followed by a
// tool-use block carrying the call's ID and name.
func TestConvertMessages_AssistantWithToolCalls(t *testing.T) {
	messages := []Message{
		{
			Role:    "assistant",
			Content: "Let me calculate that.",
			ToolCalls: []protocoltypes.ToolCall{
				{
					ID:        "call_456",
					Name:      "calculator",
					Arguments: map[string]any{"expression": "2+2"},
				},
			},
		},
	}

	bedrockMsgs, _ := convertMessages(messages)

	// require, not assert: stop before indexing if the shape is wrong.
	require.Len(t, bedrockMsgs, 1)
	require.Len(t, bedrockMsgs[0].Content, 2) // text + tool use

	// Check text content
	textBlock, ok := bedrockMsgs[0].Content[0].(*types.ContentBlockMemberText)
	require.True(t, ok)
	assert.Equal(t, "Let me calculate that.", textBlock.Value)

	// Check tool use
	toolUse, ok := bedrockMsgs[0].Content[1].(*types.ContentBlockMemberToolUse)
	require.True(t, ok)
	assert.Equal(t, "call_456", aws.ToString(toolUse.Value.ToolUseId))
	assert.Equal(t, "calculator", aws.ToString(toolUse.Value.Name))
}
// TestConvertTools_Basic verifies that a fully specified tool definition is
// converted to a Bedrock tool spec with the same name and description.
func TestConvertTools_Basic(t *testing.T) {
	tools := []ToolDefinition{
		{
			Function: protocoltypes.ToolFunctionDefinition{
				Name:        "get_weather",
				Description: "Get the current weather",
				Parameters: map[string]any{
					"type": "object",
					"properties": map[string]any{
						"location": map[string]any{"type": "string"},
					},
				},
			},
		},
	}

	toolConfig := convertTools(tools)

	// require, not assert: a nil config or wrong length must stop the test
	// before the dereference and indexing below can panic.
	require.NotNil(t, toolConfig)
	require.Len(t, toolConfig.Tools, 1)

	toolSpec, ok := toolConfig.Tools[0].(*types.ToolMemberToolSpec)
	require.True(t, ok)
	assert.Equal(t, "get_weather", aws.ToString(toolSpec.Value.Name))
	assert.Equal(t, "Get the current weather", aws.ToString(toolSpec.Value.Description))
}
// TestConvertTools_SkipsEmptyName verifies that tools with an empty or
// whitespace-only name are dropped and only the validly named tool remains.
func TestConvertTools_SkipsEmptyName(t *testing.T) {
	tools := []ToolDefinition{
		{
			Function: protocoltypes.ToolFunctionDefinition{
				Name:        "",
				Description: "Empty name tool",
			},
		},
		{
			Function: protocoltypes.ToolFunctionDefinition{
				Name:        " ",
				Description: "Whitespace name tool",
			},
		},
		{
			Function: protocoltypes.ToolFunctionDefinition{
				Name:        "valid_tool",
				Description: "Valid tool",
			},
		},
	}

	toolConfig := convertTools(tools)

	// require, not assert: stop before indexing if the length is wrong.
	require.Len(t, toolConfig.Tools, 1)

	// Checked assertion so an unexpected tool variant fails the test with a
	// message instead of panicking.
	toolSpec, ok := toolConfig.Tools[0].(*types.ToolMemberToolSpec)
	require.True(t, ok)
	assert.Equal(t, "valid_tool", aws.ToString(toolSpec.Value.Name))
}
// TestConvertTools_NilParameters checks that a tool definition without
// parameters converts without panicking and still yields exactly one tool.
func TestConvertTools_NilParameters(t *testing.T) {
	parameterless := ToolDefinition{
		Function: protocoltypes.ToolFunctionDefinition{
			Name:        "simple_tool",
			Description: "A tool with no parameters",
			Parameters:  nil,
		},
	}

	converted := convertTools([]ToolDefinition{parameterless})

	assert.Len(t, converted.Tools, 1)
}
// TestBuildUserContent_TextOnly verifies that a message with only text yields
// a single text block carrying that text.
func TestBuildUserContent_TextOnly(t *testing.T) {
	msg := Message{Content: "Hello world"}

	content := buildUserContent(msg)

	// require, not assert: stop before indexing if the length is wrong.
	require.Len(t, content, 1)
	textBlock, ok := content[0].(*types.ContentBlockMemberText)
	require.True(t, ok)
	assert.Equal(t, "Hello world", textBlock.Value)
}
// TestBuildUserContent_WithImage verifies that a base64 PNG data URL in Media
// produces an image block after the text block.
func TestBuildUserContent_WithImage(t *testing.T) {
	// Base64-encoded 1x1 PNG (the provider doesn't validate image correctness,
	// it just verifies the format and base64 decoding works)
	b64Data := "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAIAAACQd1PeAAAADUlEQVR4nGNgYAAAAAMAASsJTYQAAAAASUVORK5CYII="

	msg := Message{
		Content: "Look at this image",
		Media:   []string{"data:image/png;base64," + b64Data},
	}

	content := buildUserContent(msg)

	// require, not assert: stop before indexing if the length is wrong.
	require.Len(t, content, 2)

	// Check text
	textBlock, ok := content[0].(*types.ContentBlockMemberText)
	require.True(t, ok)
	assert.Equal(t, "Look at this image", textBlock.Value)

	// Check image
	imageBlock, ok := content[1].(*types.ContentBlockMemberImage)
	require.True(t, ok)
	assert.Equal(t, types.ImageFormatPng, imageBlock.Value.Format)
}
func TestBuildUserContent_SkipsInvalidBase64(t *testing.T) {
msg := Message{
Content: "Invalid image",
Media: []string{"data:image/png;base64,not-valid-base64!!!"},
}
content := buildUserContent(msg)
// Should only have text, image should be skipped
assert.Len(t, content, 1)
}
func TestBuildUserContent_SkipsNonBase64Data(t *testing.T) {
msg := Message{
Content: "Non-base64 image",
Media: []string{"data:image/png,raw-data-here"},
}
content := buildUserContent(msg)
// Should only have text, non-base64 image should be skipped
assert.Len(t, content, 1)
}
func TestBuildAssistantContent_SkipsEmptyToolName(t *testing.T) {
msg := Message{
Content: "Response",
ToolCalls: []protocoltypes.ToolCall{
{ID: "1", Name: "", Arguments: map[string]any{}},
{ID: "2", Name: " ", Arguments: map[string]any{}},
{ID: "3", Name: "valid", Arguments: map[string]any{}},
},
}
content := buildAssistantContent(msg)
// Should have text + 1 valid tool
assert.Len(t, content, 2)
}
// TestBuildAssistantContent_NilArguments verifies that a tool call with nil
// arguments still yields a tool-use block with a non-nil input document.
func TestBuildAssistantContent_NilArguments(t *testing.T) {
	msg := Message{
		ToolCalls: []protocoltypes.ToolCall{
			{ID: "1", Name: "tool", Arguments: nil},
		},
	}

	content := buildAssistantContent(msg)

	// require, not assert: stop before indexing if the length is wrong.
	require.Len(t, content, 1)
	toolUse, ok := content[0].(*types.ContentBlockMemberToolUse)
	require.True(t, ok)
	assert.NotNil(t, toolUse.Value.Input)
}
// TestBuildAssistantContent_FunctionFallback verifies that when Name and
// Arguments are empty (they are json:"-"), the tool name is taken from the
// Function field instead.
func TestBuildAssistantContent_FunctionFallback(t *testing.T) {
	msg := Message{
		ToolCalls: []protocoltypes.ToolCall{
			{
				ID:   "1",
				Name: "", // empty, should fallback to Function.Name
				Function: &protocoltypes.FunctionCall{
					Name:      "fallback_tool",
					Arguments: `{"key":"value"}`,
				},
			},
		},
	}

	content := buildAssistantContent(msg)

	// require, not assert: stop before indexing if the length is wrong.
	require.Len(t, content, 1)
	toolUse, ok := content[0].(*types.ContentBlockMemberToolUse)
	require.True(t, ok)
	assert.Equal(t, "fallback_tool", aws.ToString(toolUse.Value.Name))
}
// TestParseResponse_TextOnly verifies that a text-only Converse output maps
// to the response content, a "stop" finish reason, no tool calls, and usage
// counters copied from the output.
func TestParseResponse_TextOnly(t *testing.T) {
	output := &bedrockruntime.ConverseOutput{
		Output: &types.ConverseOutputMemberMessage{
			Value: types.Message{
				Role: types.ConversationRoleAssistant,
				Content: []types.ContentBlock{
					&types.ContentBlockMemberText{Value: "Hello!"},
				},
			},
		},
		StopReason: types.StopReasonEndTurn,
		Usage: &types.TokenUsage{
			InputTokens:  aws.Int32(10),
			OutputTokens: aws.Int32(5),
		},
	}

	resp, err := parseResponse(output)
	require.NoError(t, err)

	assert.Equal(t, "Hello!", resp.Content)
	assert.Equal(t, "stop", resp.FinishReason)
	assert.Empty(t, resp.ToolCalls)

	// Guard the dereferences below: a missing Usage should fail the test
	// rather than panic.
	require.NotNil(t, resp.Usage)
	assert.Equal(t, 10, resp.Usage.PromptTokens)
	assert.Equal(t, 5, resp.Usage.CompletionTokens)
	assert.Equal(t, 15, resp.Usage.TotalTokens)
}
// TestParseResponse_StopReasons checks, one subtest per Bedrock stop reason,
// that parseResponse reports the expected finish reason string.
func TestParseResponse_StopReasons(t *testing.T) {
	cases := []struct {
		stopReason     types.StopReason
		expectedFinish string
	}{
		{stopReason: types.StopReasonEndTurn, expectedFinish: "stop"},
		{stopReason: types.StopReasonToolUse, expectedFinish: "tool_calls"},
		{stopReason: types.StopReasonMaxTokens, expectedFinish: "length"},
		{stopReason: types.StopReasonStopSequence, expectedFinish: "stop"},
		{stopReason: types.StopReasonContentFiltered, expectedFinish: "content_filter"},
	}

	for _, tc := range cases {
		t.Run(string(tc.stopReason), func(t *testing.T) {
			// Minimal output: one text block plus the stop reason under test.
			message := types.Message{
				Content: []types.ContentBlock{
					&types.ContentBlockMemberText{Value: "test"},
				},
			}
			output := &bedrockruntime.ConverseOutput{
				Output:     &types.ConverseOutputMemberMessage{Value: message},
				StopReason: tc.stopReason,
			}

			resp, err := parseResponse(output)

			require.NoError(t, err)
			assert.Equal(t, tc.expectedFinish, resp.FinishReason)
		})
	}
}
// TestParseResponse_WithToolCalls verifies that a response mixing text and a
// tool-use block yields the text, a "tool_calls" finish reason, one tool call
// with its ID/name/Function populated, and summed usage.
func TestParseResponse_WithToolCalls(t *testing.T) {
	// Note: document.NewLazyDocument has limitations with UnmarshalSmithyDocument in tests,
	// so we test the structure extraction and verify Arguments gets populated (even if empty
	// due to SDK limitations). The actual unmarshal works correctly at runtime.
	toolInput := document.NewLazyDocument(map[string]any{
		"location": "San Francisco",
		"unit":     "celsius",
	})

	output := &bedrockruntime.ConverseOutput{
		Output: &types.ConverseOutputMemberMessage{
			Value: types.Message{
				Role: types.ConversationRoleAssistant,
				Content: []types.ContentBlock{
					&types.ContentBlockMemberText{Value: "Let me check the weather."},
					&types.ContentBlockMemberToolUse{
						Value: types.ToolUseBlock{
							ToolUseId: aws.String("call_weather_123"),
							Name:      aws.String("get_weather"),
							Input:     toolInput,
						},
					},
				},
			},
		},
		StopReason: types.StopReasonToolUse,
		Usage: &types.TokenUsage{
			InputTokens:  aws.Int32(20),
			OutputTokens: aws.Int32(15),
		},
	}

	resp, err := parseResponse(output)
	require.NoError(t, err)

	assert.Equal(t, "Let me check the weather.", resp.Content)
	assert.Equal(t, "tool_calls", resp.FinishReason)
	// require, not assert: stop before indexing if the length is wrong.
	require.Len(t, resp.ToolCalls, 1)

	// Verify tool call ID and Name are extracted correctly
	tc := resp.ToolCalls[0]
	assert.Equal(t, "call_weather_123", tc.ID)
	assert.Equal(t, "get_weather", tc.Name)

	// Verify Function fields are also populated
	require.NotNil(t, tc.Function)
	assert.Equal(t, "get_weather", tc.Function.Name)

	// Verify Arguments is not nil (content may vary due to SDK limitations in tests)
	assert.NotNil(t, tc.Arguments)

	// Verify usage; guard the dereferences so a missing Usage fails cleanly.
	require.NotNil(t, resp.Usage)
	assert.Equal(t, 20, resp.Usage.PromptTokens)
	assert.Equal(t, 15, resp.Usage.CompletionTokens)
	assert.Equal(t, 35, resp.Usage.TotalTokens)
}
// TestParseResponse_MultipleToolCalls verifies that two tool-use blocks in one
// response are returned as two tool calls, in order, each with ID, Name,
// Arguments and Function populated.
func TestParseResponse_MultipleToolCalls(t *testing.T) {
	output := &bedrockruntime.ConverseOutput{
		Output: &types.ConverseOutputMemberMessage{
			Value: types.Message{
				Role: types.ConversationRoleAssistant,
				Content: []types.ContentBlock{
					&types.ContentBlockMemberToolUse{
						Value: types.ToolUseBlock{
							ToolUseId: aws.String("call_1"),
							Name:      aws.String("tool_a"),
							Input:     document.NewLazyDocument(map[string]any{"arg": "value1"}),
						},
					},
					&types.ContentBlockMemberToolUse{
						Value: types.ToolUseBlock{
							ToolUseId: aws.String("call_2"),
							Name:      aws.String("tool_b"),
							Input:     document.NewLazyDocument(map[string]any{"arg": "value2"}),
						},
					},
				},
			},
		},
		StopReason: types.StopReasonToolUse,
	}

	resp, err := parseResponse(output)
	require.NoError(t, err)
	assert.Equal(t, "tool_calls", resp.FinishReason)
	// require (not assert): indexing below would panic on a wrong length.
	require.Len(t, resp.ToolCalls, 2)

	// Verify tool call structure
	first := resp.ToolCalls[0]
	assert.Equal(t, "call_1", first.ID)
	assert.Equal(t, "tool_a", first.Name)
	assert.NotNil(t, first.Arguments)
	// require before dereferencing Function so a nil value fails cleanly.
	require.NotNil(t, first.Function)
	assert.Equal(t, "tool_a", first.Function.Name)

	second := resp.ToolCalls[1]
	assert.Equal(t, "call_2", second.ID)
	assert.Equal(t, "tool_b", second.Name)
	assert.NotNil(t, second.Arguments)
	require.NotNil(t, second.Function)
	assert.Equal(t, "tool_b", second.Function.Name)
}
// TestParseResponse_ToolCallWithNilInput verifies that a tool-use block with a
// nil Input still produces a tool call whose Arguments is an empty, non-nil map.
func TestParseResponse_ToolCallWithNilInput(t *testing.T) {
	output := &bedrockruntime.ConverseOutput{
		Output: &types.ConverseOutputMemberMessage{
			Value: types.Message{
				Role: types.ConversationRoleAssistant,
				Content: []types.ContentBlock{
					&types.ContentBlockMemberToolUse{
						Value: types.ToolUseBlock{
							ToolUseId: aws.String("call_nil"),
							Name:      aws.String("no_args_tool"),
							Input:     nil,
						},
					},
				},
			},
		},
		StopReason: types.StopReasonToolUse,
	}

	resp, err := parseResponse(output)
	require.NoError(t, err)
	// require (not assert): indexing below would panic on a wrong length.
	require.Len(t, resp.ToolCalls, 1)

	call := resp.ToolCalls[0]
	assert.Equal(t, "call_nil", call.ID)
	assert.Equal(t, "no_args_tool", call.Name)
	// Arguments should be empty map, not nil
	assert.NotNil(t, call.Arguments)
	assert.Empty(t, call.Arguments)
}
+73
View File
@@ -0,0 +1,73 @@
//go:build !bedrock
// PicoClaw - Ultra-lightweight personal AI agent
// License: MIT
//
// Copyright (c) 2026 PicoClaw contributors
// Package bedrock provides a stub implementation when built without the bedrock tag.
// To enable AWS Bedrock support, build with: go build -tags bedrock
package bedrock
import (
"context"
"fmt"
"time"
"github.com/sipeed/picoclaw/pkg/providers/protocoltypes"
)
// Aliases re-exporting the shared protocol types under this package's own
// names; the stub's Chat signature is written in terms of them.
type (
	// LLMResponse is an alias for protocoltypes.LLMResponse.
	LLMResponse = protocoltypes.LLMResponse
	// Message is an alias for protocoltypes.Message.
	Message = protocoltypes.Message
	// ToolDefinition is an alias for protocoltypes.ToolDefinition.
	ToolDefinition = protocoltypes.ToolDefinition
)
// Provider is a stub that returns an error when Bedrock support is not compiled in.
// It carries no state; NewProvider in this build never returns a non-nil *Provider.
type Provider struct{}

// Option is a no-op when Bedrock is not enabled.
// It keeps the functional-option call shape (func(*providerConfig)) so callers
// can pass options to NewProvider in this build.
type Option func(*providerConfig)

// providerConfig is the empty configuration target that Option functions receive
// in the stub build; it has no fields to set.
type providerConfig struct{}
// WithRegion accepts a region and ignores it: in the stub build the returned
// Option does nothing.
func WithRegion(region string) Option {
	return Option(func(*providerConfig) {})
}
// WithProfile accepts a profile name and ignores it: in the stub build the
// returned Option does nothing.
func WithProfile(profile string) Option {
	return Option(func(*providerConfig) {})
}
// WithBaseEndpoint accepts an endpoint and ignores it: in the stub build the
// returned Option does nothing.
func WithBaseEndpoint(endpoint string) Option {
	return Option(func(*providerConfig) {})
}
// WithRequestTimeout accepts a timeout and ignores it: in the stub build the
// returned Option does nothing.
func WithRequestTimeout(timeout time.Duration) Option {
	return Option(func(*providerConfig) {})
}
// NewProvider always fails in the stub build: it ignores ctx and opts and
// returns a nil provider together with an error telling the user to rebuild
// with the bedrock build tag.
func NewProvider(ctx context.Context, opts ...Option) (*Provider, error) {
	err := fmt.Errorf("bedrock provider not available: build with -tags bedrock to enable AWS Bedrock support")
	return nil, err
}
// Chat always fails in the stub build, returning a nil response and the same
// "not available" error as NewProvider. It should be unreachable in practice
// because NewProvider never hands out a Provider.
func (p *Provider) Chat(
	ctx context.Context,
	messages []Message,
	tools []ToolDefinition,
	model string,
	options map[string]any,
) (*LLMResponse, error) {
	err := fmt.Errorf("bedrock provider not available: build with -tags bedrock to enable AWS Bedrock support")
	return nil, err
}
// GetDefaultModel returns an empty string.
// The stub build has no Bedrock backend and therefore no default model to report.
func (p *Provider) GetDefaultModel() string {
	return ""
}
@@ -0,0 +1,35 @@
//go:build !bedrock
// PicoClaw - Ultra-lightweight personal AI agent
// License: MIT
//
// Copyright (c) 2026 PicoClaw contributors
package bedrock
import (
"context"
"strings"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestNewProvider_ReturnsStubError checks that the stub constructor, called
// without options, returns no provider and an error that names the build tag.
func TestNewProvider_ReturnsStubError(t *testing.T) {
	p, err := NewProvider(context.Background())

	assert.Nil(t, p)
	require.Error(t, err)

	mentionsTag := strings.Contains(err.Error(), "build with -tags bedrock")
	assert.True(t, mentionsTag,
		"error should mention build tag requirement, got: %s", err.Error())
}
func TestNewProvider_WithOptions_ReturnsStubError(t *testing.T) {
provider, err := NewProvider(context.Background(), WithRegion("us-west-2"), WithProfile("test"))
assert.Nil(t, provider)
require.Error(t, err)
assert.True(t, strings.Contains(err.Error(), "build with -tags bedrock"),
"error should mention build tag requirement, got: %s", err.Error())
}
+8 -8
View File
@@ -413,10 +413,10 @@ func TestChat_EmptyWorkspaceDoesNotSetDir(t *testing.T) {
func TestCreateProvider_ClaudeCli(t *testing.T) {
cfg := config.DefaultConfig()
cfg.ModelList = []config.ModelConfig{
cfg.ModelList = []*config.ModelConfig{
{ModelName: "claude-sonnet-4.6", Model: "claude-cli/claude-sonnet-4.6", Workspace: "/test/ws"},
}
cfg.Agents.Defaults.Model = "claude-sonnet-4.6"
cfg.Agents.Defaults.ModelName = "claude-sonnet-4.6"
provider, _, err := CreateProvider(cfg)
if err != nil {
@@ -434,10 +434,10 @@ func TestCreateProvider_ClaudeCli(t *testing.T) {
func TestCreateProvider_ClaudeCode(t *testing.T) {
cfg := config.DefaultConfig()
cfg.ModelList = []config.ModelConfig{
cfg.ModelList = []*config.ModelConfig{
{ModelName: "claude-code", Model: "claude-cli/claude-code"},
}
cfg.Agents.Defaults.Model = "claude-code"
cfg.Agents.Defaults.ModelName = "claude-code"
provider, _, err := CreateProvider(cfg)
if err != nil {
@@ -450,10 +450,10 @@ func TestCreateProvider_ClaudeCode(t *testing.T) {
func TestCreateProvider_ClaudeCodec(t *testing.T) {
cfg := config.DefaultConfig()
cfg.ModelList = []config.ModelConfig{
cfg.ModelList = []*config.ModelConfig{
{ModelName: "claudecode", Model: "claude-cli/claudecode"},
}
cfg.Agents.Defaults.Model = "claudecode"
cfg.Agents.Defaults.ModelName = "claudecode"
provider, _, err := CreateProvider(cfg)
if err != nil {
@@ -466,10 +466,10 @@ func TestCreateProvider_ClaudeCodec(t *testing.T) {
func TestCreateProvider_ClaudeCliDefaultWorkspace(t *testing.T) {
cfg := config.DefaultConfig()
cfg.ModelList = []config.ModelConfig{
cfg.ModelList = []*config.ModelConfig{
{ModelName: "claude-cli", Model: "claude-cli/claude-sonnet"},
}
cfg.Agents.Defaults.Model = "claude-cli"
cfg.Agents.Defaults.ModelName = "claude-cli"
cfg.Agents.Defaults.Workspace = ""
provider, _, err := CreateProvider(cfg)
+41 -1
View File
@@ -111,6 +111,17 @@ func SerializeMessages(messages []Message) []any {
"url": mediaURL,
},
})
continue
}
if format, data, ok := parseDataAudioURL(mediaURL); ok {
parts = append(parts, map[string]any{
"type": "input_audio",
"input_audio": map[string]any{
"data": data,
"format": format,
},
})
}
}
@@ -132,6 +143,26 @@ func SerializeMessages(messages []Message) []any {
return out
}
// parseDataAudioURL splits a URL of the form
// "data:audio/<format>[;params],<data>" into its format and data parts.
//
// The format is the text between "data:audio/" and the first ';' (or the
// first ',' when there is no ';'); the data is everything after the first ','.
// Both are trimmed of surrounding whitespace. ok is false, and both strings
// are empty, when the prefix is missing, there is no ',', or either part is
// empty after trimming.
func parseDataAudioURL(mediaURL string) (format, data string, ok bool) {
	const audioPrefix = "data:audio/"
	if !strings.HasPrefix(mediaURL, audioPrefix) {
		return "", "", false
	}
	rest := mediaURL[len(audioPrefix):]

	// The first comma separates the media-type header from the payload.
	comma := strings.IndexByte(rest, ',')
	if comma < 0 {
		return "", "", false
	}
	header, body := rest[:comma], rest[comma+1:]

	// Drop any media-type parameters such as ";base64".
	if semi := strings.IndexByte(header, ';'); semi >= 0 {
		header = header[:semi]
	}

	format = strings.TrimSpace(header)
	data = strings.TrimSpace(body)
	if format == "" || data == "" {
		return "", "", false
	}
	return format, data, true
}
// --- Response parsing ---
// ParseResponse parses a JSON chat completion response body into an LLMResponse.
@@ -214,11 +245,20 @@ func ParseResponse(body io.Reader) (*LLMResponse, error) {
Reasoning: choice.Message.Reasoning,
ReasoningDetails: choice.Message.ReasoningDetails,
ToolCalls: toolCalls,
FinishReason: choice.FinishReason,
FinishReason: normalizeFinishReason(choice.FinishReason),
Usage: apiResponse.Usage,
}, nil
}
// normalizeFinishReason maps provider finish_reason values onto the names
// used internally. Only "length" is rewritten (to "truncated"); every other
// value, including the empty string, is returned unchanged.
func normalizeFinishReason(reason string) string {
	switch reason {
	case "length":
		return "truncated"
	default:
		return reason
	}
}
// DecodeToolCallArguments decodes a tool call's arguments from raw JSON.
func DecodeToolCallArguments(raw json.RawMessage, name string) map[string]any {
arguments := make(map[string]any)
+38
View File
@@ -91,6 +91,44 @@ func TestSerializeMessages_WithMedia(t *testing.T) {
}
}
// TestSerializeMessages_WithAudioMedia verifies that a user message carrying an
// audio data URL is serialized as two content parts, the second being an
// "input_audio" part whose format and data come from the URL.
func TestSerializeMessages_WithAudioMedia(t *testing.T) {
	messages := []Message{
		{Role: "user", Content: "transcribe this", Media: []string{"data:audio/ogg;base64,abc123"}},
	}
	result := SerializeMessages(messages)

	// Round-trip through JSON so the assertions see plain maps and slices.
	// Errors are checked so a failure is reported here rather than as an
	// index-out-of-range panic below.
	data, err := json.Marshal(result)
	if err != nil {
		t.Fatalf("marshal serialized messages: %v", err)
	}
	var msgs []map[string]any
	if err := json.Unmarshal(data, &msgs); err != nil {
		t.Fatalf("unmarshal serialized messages: %v", err)
	}
	if len(msgs) == 0 {
		t.Fatalf("expected at least one serialized message, got none")
	}

	content, ok := msgs[0]["content"].([]any)
	if !ok {
		t.Fatalf("expected array content for media message, got %T", msgs[0]["content"])
	}
	if len(content) != 2 {
		t.Fatalf("expected 2 content parts, got %d", len(content))
	}

	audioPart, ok := content[1].(map[string]any)
	if !ok {
		t.Fatalf("expected audio content part to be an object, got %T", content[1])
	}
	if audioPart["type"] != "input_audio" {
		t.Fatalf("audio part type = %v, want input_audio", audioPart["type"])
	}

	inputAudio, ok := audioPart["input_audio"].(map[string]any)
	if !ok {
		t.Fatalf("expected input_audio object, got %T", audioPart["input_audio"])
	}
	if inputAudio["format"] != "ogg" {
		t.Fatalf("audio format = %v, want ogg", inputAudio["format"])
	}
	if inputAudio["data"] != "abc123" {
		t.Fatalf("audio data = %v, want abc123", inputAudio["data"])
	}
}
func TestSerializeMessages_MediaWithToolCallID(t *testing.T) {
messages := []Message{
{Role: "tool", Content: "result", Media: []string{"data:image/png;base64,xyz"}, ToolCallID: "call_1"},
-393
View File
@@ -1,400 +1,7 @@
package providers
import (
"fmt"
"strings"
"github.com/sipeed/picoclaw/pkg/auth"
"github.com/sipeed/picoclaw/pkg/config"
)
// defaultAnthropicAPIBase is the Anthropic API base URL used by
// resolveProviderSelection when the configured Anthropic APIBase is empty.
const defaultAnthropicAPIBase = "https://api.anthropic.com/v1"

// getCredential is a package-level variable bound to auth.GetCredential.
// NOTE(review): presumably a variable so tests can substitute it — confirm.
var getCredential = auth.GetCredential
// providerType identifies which kind of provider resolveProviderSelection chose.
type providerType int

const (
	// providerTypeHTTPCompat is the default selection; it requires an API base
	// (and normally an API key) to have been resolved.
	providerTypeHTTPCompat providerType = iota
	// providerTypeClaudeAuth is selected when Anthropic auth_method is "oauth" or "token".
	providerTypeClaudeAuth
	// providerTypeCodexAuth is selected when OpenAI auth_method is "oauth" or "token".
	providerTypeCodexAuth
	// providerTypeCodexCLIToken is selected when OpenAI auth_method is "codex-cli".
	providerTypeCodexCLIToken
	// providerTypeClaudeCLI is selected for provider "claude-cli", "claude-code" or "claudecode".
	providerTypeClaudeCLI
	// providerTypeCodexCLI is selected for provider "codex-cli" or "codex-code".
	providerTypeCodexCLI
	// providerTypeGitHubCopilot is selected for provider "github_copilot" or "copilot".
	providerTypeGitHubCopilot
)
// providerSelection is the result of resolveProviderSelection: the chosen
// provider kind plus the connection settings resolved from configuration.
type providerSelection struct {
	// providerType is the kind of provider that was selected.
	providerType providerType
	// apiKey is the API key copied from the matched provider's configuration.
	apiKey string
	// apiBase is the configured API base, or a per-provider default when empty.
	apiBase string
	// proxy is the proxy setting copied from the matched provider's configuration.
	proxy string
	// model is the configured model name (overridden to "deepseek-chat" in one DeepSeek case).
	model string
	// workspace is set for the CLI provider types; "." when no workspace path is configured.
	workspace string
	// connectMode is copied from the GitHub Copilot configuration.
	connectMode string
	// enableWebSearch is copied from the OpenAI configuration's WebSearch flag.
	enableWebSearch bool
}
// resolveProviderSelection decides which provider to use for the default agent
// model and resolves its API key, API base, proxy and related settings.
//
// Resolution happens in three stages:
//  1. If agents.defaults.provider is set, the matching provider's settings are
//     used, but only when that provider has the required credentials configured.
//     CLI, Copilot and OAuth/token selections return immediately.
//  2. If stage 1 left both apiKey and apiBase empty, the provider is inferred
//     from the model name and from which providers have keys configured. The
//     first matching case wins, so case order is significant.
//  3. For the HTTP-compatible type, an API base is required, and an API key is
//     required unless the model name starts with "bedrock/".
//
// It returns an error when neither provider nor model is configured, when no
// usable provider can be inferred, or when stage 3 validation fails.
func resolveProviderSelection(cfg *config.Config) (providerSelection, error) {
	model := cfg.Agents.Defaults.GetModelName()
	providerName := strings.ToLower(cfg.Agents.Defaults.Provider)
	lowerModel := strings.ToLower(model)
	if providerName == "" && model == "" {
		return providerSelection{}, fmt.Errorf("no model configured: agents.defaults.model is empty")
	}
	// Start from the HTTP-compatible default; the branches below either fill in
	// credentials or switch providerType and return early.
	sel := providerSelection{
		providerType: providerTypeHTTPCompat,
		model: model,
	}
	// First, prefer explicit provider configuration.
	if providerName != "" {
		switch providerName {
		case "groq":
			if cfg.Providers.Groq.APIKey != "" {
				sel.apiKey = cfg.Providers.Groq.APIKey
				sel.apiBase = cfg.Providers.Groq.APIBase
				sel.proxy = cfg.Providers.Groq.Proxy
				if sel.apiBase == "" {
					sel.apiBase = "https://api.groq.com/openai/v1"
				}
			}
		case "openai", "gpt":
			if cfg.Providers.OpenAI.APIKey != "" || cfg.Providers.OpenAI.AuthMethod != "" {
				sel.enableWebSearch = cfg.Providers.OpenAI.WebSearch
				// Non-API-key auth methods select a dedicated provider type and skip
				// the HTTP-compatible validation at the end of the function.
				if cfg.Providers.OpenAI.AuthMethod == "codex-cli" {
					sel.providerType = providerTypeCodexCLIToken
					return sel, nil
				}
				if cfg.Providers.OpenAI.AuthMethod == "oauth" || cfg.Providers.OpenAI.AuthMethod == "token" {
					sel.providerType = providerTypeCodexAuth
					return sel, nil
				}
				sel.apiKey = cfg.Providers.OpenAI.APIKey
				sel.apiBase = cfg.Providers.OpenAI.APIBase
				sel.proxy = cfg.Providers.OpenAI.Proxy
				if sel.apiBase == "" {
					sel.apiBase = "https://api.openai.com/v1"
				}
			}
		case "anthropic", "claude":
			if cfg.Providers.Anthropic.APIKey != "" || cfg.Providers.Anthropic.AuthMethod != "" {
				if cfg.Providers.Anthropic.AuthMethod == "oauth" || cfg.Providers.Anthropic.AuthMethod == "token" {
					sel.apiBase = cfg.Providers.Anthropic.APIBase
					if sel.apiBase == "" {
						sel.apiBase = defaultAnthropicAPIBase
					}
					sel.providerType = providerTypeClaudeAuth
					return sel, nil
				}
				sel.apiKey = cfg.Providers.Anthropic.APIKey
				sel.apiBase = cfg.Providers.Anthropic.APIBase
				sel.proxy = cfg.Providers.Anthropic.Proxy
				if sel.apiBase == "" {
					sel.apiBase = defaultAnthropicAPIBase
				}
			}
		case "openrouter":
			if cfg.Providers.OpenRouter.APIKey != "" {
				sel.apiKey = cfg.Providers.OpenRouter.APIKey
				sel.proxy = cfg.Providers.OpenRouter.Proxy
				if cfg.Providers.OpenRouter.APIBase != "" {
					sel.apiBase = cfg.Providers.OpenRouter.APIBase
				} else {
					sel.apiBase = "https://openrouter.ai/api/v1"
				}
			}
		case "litellm":
			// LiteLLM is accepted with either a key or a base configured.
			if cfg.Providers.LiteLLM.APIKey != "" || cfg.Providers.LiteLLM.APIBase != "" {
				sel.apiKey = cfg.Providers.LiteLLM.APIKey
				sel.apiBase = cfg.Providers.LiteLLM.APIBase
				sel.proxy = cfg.Providers.LiteLLM.Proxy
				if sel.apiBase == "" {
					sel.apiBase = "http://localhost:4000/v1"
				}
			}
		case "zhipu", "glm":
			if cfg.Providers.Zhipu.APIKey != "" {
				sel.apiKey = cfg.Providers.Zhipu.APIKey
				sel.apiBase = cfg.Providers.Zhipu.APIBase
				sel.proxy = cfg.Providers.Zhipu.Proxy
				if sel.apiBase == "" {
					sel.apiBase = "https://open.bigmodel.cn/api/paas/v4"
				}
			}
		case "gemini", "google":
			if cfg.Providers.Gemini.APIKey != "" {
				sel.apiKey = cfg.Providers.Gemini.APIKey
				sel.apiBase = cfg.Providers.Gemini.APIBase
				sel.proxy = cfg.Providers.Gemini.Proxy
				if sel.apiBase == "" {
					sel.apiBase = "https://generativelanguage.googleapis.com/v1beta"
				}
			}
		case "vllm":
			// vLLM is keyed on the API base being set; there is no default base.
			if cfg.Providers.VLLM.APIBase != "" {
				sel.apiKey = cfg.Providers.VLLM.APIKey
				sel.apiBase = cfg.Providers.VLLM.APIBase
				sel.proxy = cfg.Providers.VLLM.Proxy
			}
		case "shengsuanyun":
			if cfg.Providers.ShengSuanYun.APIKey != "" {
				sel.apiKey = cfg.Providers.ShengSuanYun.APIKey
				sel.apiBase = cfg.Providers.ShengSuanYun.APIBase
				sel.proxy = cfg.Providers.ShengSuanYun.Proxy
				if sel.apiBase == "" {
					sel.apiBase = "https://router.shengsuanyun.com/api/v1"
				}
			}
		case "nvidia":
			if cfg.Providers.Nvidia.APIKey != "" {
				sel.apiKey = cfg.Providers.Nvidia.APIKey
				sel.apiBase = cfg.Providers.Nvidia.APIBase
				sel.proxy = cfg.Providers.Nvidia.Proxy
				if sel.apiBase == "" {
					sel.apiBase = "https://integrate.api.nvidia.com/v1"
				}
			}
		case "vivgrid":
			if cfg.Providers.Vivgrid.APIKey != "" {
				sel.apiKey = cfg.Providers.Vivgrid.APIKey
				sel.apiBase = cfg.Providers.Vivgrid.APIBase
				sel.proxy = cfg.Providers.Vivgrid.Proxy
				if sel.apiBase == "" {
					sel.apiBase = "https://api.vivgrid.com/v1"
				}
			}
		case "claude-cli", "claude-code", "claudecode":
			// CLI providers need only a workspace directory, defaulting to ".".
			workspace := cfg.WorkspacePath()
			if workspace == "" {
				workspace = "."
			}
			sel.providerType = providerTypeClaudeCLI
			sel.workspace = workspace
			return sel, nil
		case "codex-cli", "codex-code":
			workspace := cfg.WorkspacePath()
			if workspace == "" {
				workspace = "."
			}
			sel.providerType = providerTypeCodexCLI
			sel.workspace = workspace
			return sel, nil
		case "deepseek":
			if cfg.Providers.DeepSeek.APIKey != "" {
				sel.apiKey = cfg.Providers.DeepSeek.APIKey
				sel.apiBase = cfg.Providers.DeepSeek.APIBase
				sel.proxy = cfg.Providers.DeepSeek.Proxy
				if sel.apiBase == "" {
					sel.apiBase = "https://api.deepseek.com/v1"
				}
				// Any model other than the two named here is replaced with "deepseek-chat".
				if model != "deepseek-chat" && model != "deepseek-reasoner" {
					sel.model = "deepseek-chat"
				}
			}
		case "avian":
			if cfg.Providers.Avian.APIKey != "" {
				sel.apiKey = cfg.Providers.Avian.APIKey
				sel.apiBase = cfg.Providers.Avian.APIBase
				sel.proxy = cfg.Providers.Avian.Proxy
				if sel.apiBase == "" {
					sel.apiBase = "https://api.avian.io/v1"
				}
			}
		case "mistral":
			if cfg.Providers.Mistral.APIKey != "" {
				sel.apiKey = cfg.Providers.Mistral.APIKey
				sel.apiBase = cfg.Providers.Mistral.APIBase
				sel.proxy = cfg.Providers.Mistral.Proxy
				if sel.apiBase == "" {
					sel.apiBase = "https://api.mistral.ai/v1"
				}
			}
		case "minimax":
			if cfg.Providers.Minimax.APIKey != "" {
				sel.apiKey = cfg.Providers.Minimax.APIKey
				sel.apiBase = cfg.Providers.Minimax.APIBase
				sel.proxy = cfg.Providers.Minimax.Proxy
				if sel.apiBase == "" {
					sel.apiBase = "https://api.minimaxi.com/v1"
				}
			}
		case "longcat":
			if cfg.Providers.LongCat.APIKey != "" {
				sel.apiKey = cfg.Providers.LongCat.APIKey
				sel.apiBase = cfg.Providers.LongCat.APIBase
				sel.proxy = cfg.Providers.LongCat.Proxy
				if sel.apiBase == "" {
					sel.apiBase = "https://api.longcat.chat/openai"
				}
			}
		case "github_copilot", "copilot":
			// Copilot needs no API key; it is addressed by host:port and a connect mode.
			sel.providerType = providerTypeGitHubCopilot
			if cfg.Providers.GitHubCopilot.APIBase != "" {
				sel.apiBase = cfg.Providers.GitHubCopilot.APIBase
			} else {
				sel.apiBase = "localhost:4321"
			}
			sel.connectMode = cfg.Providers.GitHubCopilot.ConnectMode
			return sel, nil
		}
	}
	// Fallback: infer provider from model and configured keys.
	// This runs when no provider was named, when the named provider was not
	// recognized, or when it lacked the credentials checked above.
	if sel.apiKey == "" && sel.apiBase == "" {
		switch {
		case (strings.Contains(lowerModel, "kimi") || strings.Contains(lowerModel, "moonshot") || strings.HasPrefix(model, "moonshot/")) && cfg.Providers.Moonshot.APIKey != "":
			sel.apiKey = cfg.Providers.Moonshot.APIKey
			sel.apiBase = cfg.Providers.Moonshot.APIBase
			sel.proxy = cfg.Providers.Moonshot.Proxy
			if sel.apiBase == "" {
				sel.apiBase = "https://api.moonshot.cn/v1"
			}
		// Vendor-prefixed model names are routed to OpenRouter regardless of
		// whether an OpenRouter key is configured; an empty key is rejected by
		// the validation at the end of the function.
		// NOTE(review): because this case precedes the ones below, their
		// HasPrefix(model, "anthropic/"), "openai/" and "google/" alternatives
		// appear unreachable — confirm the intended precedence.
		case strings.HasPrefix(model, "openrouter/") ||
			strings.HasPrefix(model, "anthropic/") ||
			strings.HasPrefix(model, "openai/") ||
			strings.HasPrefix(model, "meta-llama/") ||
			strings.HasPrefix(model, "deepseek/") ||
			strings.HasPrefix(model, "google/"):
			sel.apiKey = cfg.Providers.OpenRouter.APIKey
			sel.proxy = cfg.Providers.OpenRouter.Proxy
			if cfg.Providers.OpenRouter.APIBase != "" {
				sel.apiBase = cfg.Providers.OpenRouter.APIBase
			} else {
				sel.apiBase = "https://openrouter.ai/api/v1"
			}
		case (strings.Contains(lowerModel, "claude") || strings.HasPrefix(model, "anthropic/")) &&
			(cfg.Providers.Anthropic.APIKey != "" || cfg.Providers.Anthropic.AuthMethod != ""):
			if cfg.Providers.Anthropic.AuthMethod == "oauth" || cfg.Providers.Anthropic.AuthMethod == "token" {
				sel.apiBase = cfg.Providers.Anthropic.APIBase
				if sel.apiBase == "" {
					sel.apiBase = defaultAnthropicAPIBase
				}
				sel.providerType = providerTypeClaudeAuth
				return sel, nil
			}
			sel.apiKey = cfg.Providers.Anthropic.APIKey
			sel.apiBase = cfg.Providers.Anthropic.APIBase
			sel.proxy = cfg.Providers.Anthropic.Proxy
			if sel.apiBase == "" {
				sel.apiBase = defaultAnthropicAPIBase
			}
		case (strings.Contains(lowerModel, "gpt") || strings.HasPrefix(model, "openai/")) &&
			(cfg.Providers.OpenAI.APIKey != "" || cfg.Providers.OpenAI.AuthMethod != ""):
			sel.enableWebSearch = cfg.Providers.OpenAI.WebSearch
			if cfg.Providers.OpenAI.AuthMethod == "codex-cli" {
				sel.providerType = providerTypeCodexCLIToken
				return sel, nil
			}
			if cfg.Providers.OpenAI.AuthMethod == "oauth" || cfg.Providers.OpenAI.AuthMethod == "token" {
				sel.providerType = providerTypeCodexAuth
				return sel, nil
			}
			sel.apiKey = cfg.Providers.OpenAI.APIKey
			sel.apiBase = cfg.Providers.OpenAI.APIBase
			sel.proxy = cfg.Providers.OpenAI.Proxy
			if sel.apiBase == "" {
				sel.apiBase = "https://api.openai.com/v1"
			}
		case (strings.Contains(lowerModel, "gemini") || strings.HasPrefix(model, "google/")) && cfg.Providers.Gemini.APIKey != "":
			sel.apiKey = cfg.Providers.Gemini.APIKey
			sel.apiBase = cfg.Providers.Gemini.APIBase
			sel.proxy = cfg.Providers.Gemini.Proxy
			if sel.apiBase == "" {
				sel.apiBase = "https://generativelanguage.googleapis.com/v1beta"
			}
		case (strings.Contains(lowerModel, "glm") || strings.Contains(lowerModel, "zhipu") || strings.Contains(lowerModel, "zai")) && cfg.Providers.Zhipu.APIKey != "":
			sel.apiKey = cfg.Providers.Zhipu.APIKey
			sel.apiBase = cfg.Providers.Zhipu.APIBase
			sel.proxy = cfg.Providers.Zhipu.Proxy
			if sel.apiBase == "" {
				sel.apiBase = "https://open.bigmodel.cn/api/paas/v4"
			}
		case (strings.Contains(lowerModel, "groq") || strings.HasPrefix(model, "groq/")) && cfg.Providers.Groq.APIKey != "":
			sel.apiKey = cfg.Providers.Groq.APIKey
			sel.apiBase = cfg.Providers.Groq.APIBase
			sel.proxy = cfg.Providers.Groq.Proxy
			if sel.apiBase == "" {
				sel.apiBase = "https://api.groq.com/openai/v1"
			}
		case (strings.Contains(lowerModel, "nvidia") || strings.HasPrefix(model, "nvidia/")) && cfg.Providers.Nvidia.APIKey != "":
			sel.apiKey = cfg.Providers.Nvidia.APIKey
			sel.apiBase = cfg.Providers.Nvidia.APIBase
			sel.proxy = cfg.Providers.Nvidia.Proxy
			if sel.apiBase == "" {
				sel.apiBase = "https://integrate.api.nvidia.com/v1"
			}
		case strings.HasPrefix(model, "vivgrid/") && cfg.Providers.Vivgrid.APIKey != "":
			sel.apiKey = cfg.Providers.Vivgrid.APIKey
			sel.apiBase = cfg.Providers.Vivgrid.APIBase
			sel.proxy = cfg.Providers.Vivgrid.Proxy
			if sel.apiBase == "" {
				sel.apiBase = "https://api.vivgrid.com/v1"
			}
		case (strings.Contains(lowerModel, "ollama") || strings.HasPrefix(model, "ollama/")) && cfg.Providers.Ollama.APIKey != "":
			sel.apiKey = cfg.Providers.Ollama.APIKey
			sel.apiBase = cfg.Providers.Ollama.APIBase
			sel.proxy = cfg.Providers.Ollama.Proxy
			if sel.apiBase == "" {
				sel.apiBase = "http://localhost:11434/v1"
			}
		case (strings.Contains(lowerModel, "mistral") || strings.HasPrefix(model, "mistral/")) && cfg.Providers.Mistral.APIKey != "":
			sel.apiKey = cfg.Providers.Mistral.APIKey
			sel.apiBase = cfg.Providers.Mistral.APIBase
			sel.proxy = cfg.Providers.Mistral.Proxy
			if sel.apiBase == "" {
				sel.apiBase = "https://api.mistral.ai/v1"
			}
		case (strings.Contains(lowerModel, "minimax") || strings.HasPrefix(model, "minimax/")) && cfg.Providers.Minimax.APIKey != "":
			sel.apiKey = cfg.Providers.Minimax.APIKey
			sel.apiBase = cfg.Providers.Minimax.APIBase
			sel.proxy = cfg.Providers.Minimax.Proxy
			if sel.apiBase == "" {
				sel.apiBase = "https://api.minimaxi.com/v1"
			}
		case strings.HasPrefix(model, "avian/") && cfg.Providers.Avian.APIKey != "":
			sel.apiKey = cfg.Providers.Avian.APIKey
			sel.apiBase = cfg.Providers.Avian.APIBase
			sel.proxy = cfg.Providers.Avian.Proxy
			if sel.apiBase == "" {
				sel.apiBase = "https://api.avian.io/v1"
			}
		case (strings.Contains(lowerModel, "longcat") || strings.HasPrefix(model, "longcat/")) && cfg.Providers.LongCat.APIKey != "":
			sel.apiKey = cfg.Providers.LongCat.APIKey
			sel.apiBase = cfg.Providers.LongCat.APIBase
			sel.proxy = cfg.Providers.LongCat.Proxy
			if sel.apiBase == "" {
				sel.apiBase = "https://api.longcat.chat/openai"
			}
		// A configured vLLM base catches any model not matched above.
		case cfg.Providers.VLLM.APIBase != "":
			sel.apiKey = cfg.Providers.VLLM.APIKey
			sel.apiBase = cfg.Providers.VLLM.APIBase
			sel.proxy = cfg.Providers.VLLM.Proxy
		default:
			// Last resort: OpenRouter if it has a key, otherwise give up.
			if cfg.Providers.OpenRouter.APIKey != "" {
				sel.apiKey = cfg.Providers.OpenRouter.APIKey
				sel.proxy = cfg.Providers.OpenRouter.Proxy
				if cfg.Providers.OpenRouter.APIBase != "" {
					sel.apiBase = cfg.Providers.OpenRouter.APIBase
				} else {
					sel.apiBase = "https://openrouter.ai/api/v1"
				}
			} else {
				return providerSelection{}, fmt.Errorf("no API key configured for model: %s", model)
			}
		}
	}
	// Final validation applies only to the HTTP-compatible type; the other
	// provider types returned early above.
	if sel.providerType == providerTypeHTTPCompat {
		// Models prefixed "bedrock/" are exempt from the API key requirement.
		if sel.apiKey == "" && !strings.HasPrefix(model, "bedrock/") {
			return providerSelection{}, fmt.Errorf("no API key configured for provider (model: %s)", model)
		}
		if sel.apiBase == "" {
			return providerSelection{}, fmt.Errorf("no API base configured for provider (model: %s)", model)
		}
	}
	return sel, nil
}
+83 -15
View File
@@ -6,12 +6,15 @@
package providers
import (
"context"
"fmt"
"strings"
"time"
"github.com/sipeed/picoclaw/pkg/config"
anthropicmessages "github.com/sipeed/picoclaw/pkg/providers/anthropic_messages"
"github.com/sipeed/picoclaw/pkg/providers/azure"
"github.com/sipeed/picoclaw/pkg/providers/bedrock"
)
// createClaudeAuthProvider creates a Claude provider using OAuth credentials from auth store.
@@ -55,8 +58,9 @@ func ExtractProtocol(model string) (protocol, modelID string) {
// CreateProviderFromConfig creates a provider based on the ModelConfig.
// It uses the protocol prefix in the Model field to determine which provider to create.
// Supported protocols: openai, litellm, novita, anthropic, anthropic-messages,
// antigravity, claude-cli, codex-cli, github-copilot
// Supported protocol families include OpenAI-compatible prefixes (e.g., openai, openrouter, groq, gemini),
// Azure OpenAI, Amazon Bedrock, Anthropic (including messages), and various CLI/compatibility shims.
// See the switch on protocol in this function for the authoritative list.
// Returns the provider, the model ID (without protocol prefix), and any error.
func CreateProviderFromConfig(cfg *config.ModelConfig) (LLMProvider, string, error) {
if cfg == nil {
@@ -80,7 +84,7 @@ func CreateProviderFromConfig(cfg *config.ModelConfig) (LLMProvider, string, err
return provider, modelID, nil
}
// OpenAI with API key
if cfg.APIKey == "" && cfg.APIBase == "" {
if cfg.APIKey() == "" && cfg.APIBase == "" {
return nil, "", fmt.Errorf("api_key or api_base is required for HTTP-based protocol %q", protocol)
}
apiBase := cfg.APIBase
@@ -88,17 +92,18 @@ func CreateProviderFromConfig(cfg *config.ModelConfig) (LLMProvider, string, err
apiBase = getDefaultAPIBase(protocol)
}
return NewHTTPProviderWithMaxTokensFieldAndRequestTimeout(
cfg.APIKey,
cfg.APIKey(),
apiBase,
cfg.Proxy,
cfg.MaxTokensField,
cfg.RequestTimeout,
cfg.ExtraBody,
), modelID, nil
case "azure", "azure-openai":
// Azure OpenAI uses deployment-based URLs, api-key header auth,
// and always sends max_completion_tokens.
if cfg.APIKey == "" {
if cfg.APIKey() == "" {
return nil, "", fmt.Errorf("api_key is required for azure protocol")
}
if cfg.APIBase == "" {
@@ -107,19 +112,55 @@ func CreateProviderFromConfig(cfg *config.ModelConfig) (LLMProvider, string, err
)
}
return azure.NewProviderWithTimeout(
cfg.APIKey,
cfg.APIKey(),
cfg.APIBase,
cfg.Proxy,
cfg.RequestTimeout,
), modelID, nil
case "bedrock":
// AWS Bedrock uses AWS SDK credentials (env vars, profiles, IAM roles, etc.)
// api_base can be:
// - A full endpoint URL: https://bedrock-runtime.us-east-1.amazonaws.com
// - A region name: us-east-1 (AWS SDK resolves endpoint automatically)
var opts []bedrock.Option
if cfg.APIBase != "" {
if !strings.Contains(cfg.APIBase, "://") {
// Treat as region: let AWS SDK resolve the correct endpoint
// (supports all AWS partitions: aws, aws-cn, aws-us-gov, etc.)
opts = append(opts, bedrock.WithRegion(cfg.APIBase))
} else {
// Full endpoint URL provided (for custom endpoints or testing)
opts = append(opts, bedrock.WithBaseEndpoint(cfg.APIBase))
}
}
// Use a separate timeout for AWS config loading (credential resolution can block)
initTimeout := 30 * time.Second
if cfg.RequestTimeout > 0 {
reqTimeout := time.Duration(cfg.RequestTimeout) * time.Second
// Set request timeout for API calls
opts = append(opts, bedrock.WithRequestTimeout(reqTimeout))
// Ensure init timeout is at least as large as request timeout
if reqTimeout > initTimeout {
initTimeout = reqTimeout
}
}
ctx, cancel := context.WithTimeout(context.Background(), initTimeout)
defer cancel()
// Note: AWS_PROFILE env var is automatically used by AWS SDK
provider, err := bedrock.NewProvider(ctx, opts...)
if err != nil {
return nil, "", fmt.Errorf("creating bedrock provider: %w", err)
}
return provider, modelID, nil
case "litellm", "openrouter", "groq", "zhipu", "gemini", "nvidia",
"ollama", "moonshot", "shengsuanyun", "deepseek", "cerebras",
"vivgrid", "volcengine", "vllm", "qwen", "qwen-intl", "qwen-international", "dashscope-intl",
"qwen-us", "dashscope-us", "mistral", "avian", "minimax", "longcat", "modelscope", "novita",
"qwen-us", "dashscope-us", "mistral", "avian", "longcat", "modelscope", "novita",
"coding-plan", "alibaba-coding", "qwen-coding":
// All other OpenAI-compatible HTTP providers
if cfg.APIKey == "" && cfg.APIBase == "" {
if cfg.APIKey() == "" && cfg.APIBase == "" {
return nil, "", fmt.Errorf("api_key or api_base is required for HTTP-based protocol %q", protocol)
}
apiBase := cfg.APIBase
@@ -127,11 +168,37 @@ func CreateProviderFromConfig(cfg *config.ModelConfig) (LLMProvider, string, err
apiBase = getDefaultAPIBase(protocol)
}
return NewHTTPProviderWithMaxTokensFieldAndRequestTimeout(
cfg.APIKey,
cfg.APIKey(),
apiBase,
cfg.Proxy,
cfg.MaxTokensField,
cfg.RequestTimeout,
cfg.ExtraBody,
), modelID, nil
case "minimax":
// Minimax requires reasoning_split: true in the request body
if cfg.APIKey() == "" && cfg.APIBase == "" {
return nil, "", fmt.Errorf("api_key or api_base is required for HTTP-based protocol %q", protocol)
}
apiBase := cfg.APIBase
if apiBase == "" {
apiBase = getDefaultAPIBase(protocol)
}
extraBody := cfg.ExtraBody
if extraBody == nil {
extraBody = make(map[string]any)
}
if _, ok := extraBody["reasoning_split"]; !ok {
extraBody["reasoning_split"] = true
}
return NewHTTPProviderWithMaxTokensFieldAndRequestTimeout(
cfg.APIKey(),
apiBase,
cfg.Proxy,
cfg.MaxTokensField,
cfg.RequestTimeout,
extraBody,
), modelID, nil
case "anthropic":
@@ -148,15 +215,16 @@ func CreateProviderFromConfig(cfg *config.ModelConfig) (LLMProvider, string, err
if apiBase == "" {
apiBase = "https://api.anthropic.com/v1"
}
if cfg.APIKey == "" {
if cfg.APIKey() == "" {
return nil, "", fmt.Errorf("api_key is required for anthropic protocol (model: %s)", cfg.Model)
}
return NewHTTPProviderWithMaxTokensFieldAndRequestTimeout(
cfg.APIKey,
cfg.APIKey(),
apiBase,
cfg.Proxy,
cfg.MaxTokensField,
cfg.RequestTimeout,
cfg.ExtraBody,
), modelID, nil
case "anthropic-messages":
@@ -165,11 +233,11 @@ func CreateProviderFromConfig(cfg *config.ModelConfig) (LLMProvider, string, err
if apiBase == "" {
apiBase = "https://api.anthropic.com/v1"
}
if cfg.APIKey == "" {
if cfg.APIKey() == "" {
return nil, "", fmt.Errorf("api_key is required for anthropic-messages protocol (model: %s)", cfg.Model)
}
return anthropicmessages.NewProviderWithTimeout(
cfg.APIKey,
cfg.APIKey(),
apiBase,
cfg.RequestTimeout,
), modelID, nil
@@ -180,11 +248,11 @@ func CreateProviderFromConfig(cfg *config.ModelConfig) (LLMProvider, string, err
if apiBase == "" {
apiBase = getDefaultAPIBase(protocol)
}
if cfg.APIKey == "" {
if cfg.APIKey() == "" {
return nil, "", fmt.Errorf("api_key is required for %q protocol (model: %s)", protocol, cfg.Model)
}
return anthropicmessages.NewProviderWithTimeout(
cfg.APIKey,
cfg.APIKey(),
apiBase,
cfg.RequestTimeout,
), modelID, nil
+186 -14
View File
@@ -6,6 +6,7 @@
package providers
import (
"encoding/json"
"net/http"
"net/http/httptest"
"strings"
@@ -89,9 +90,9 @@ func TestCreateProviderFromConfig_OpenAI(t *testing.T) {
cfg := &config.ModelConfig{
ModelName: "test-openai",
Model: "openai/gpt-4o",
APIKey: "test-key",
APIBase: "https://api.example.com/v1",
}
cfg.SetAPIKey("test-key")
provider, modelID, err := CreateProviderFromConfig(cfg)
if err != nil {
@@ -129,8 +130,8 @@ func TestCreateProviderFromConfig_DefaultAPIBase(t *testing.T) {
cfg := &config.ModelConfig{
ModelName: "test-" + tt.protocol,
Model: tt.protocol + "/test-model",
APIKey: "test-key",
}
cfg.SetAPIKey("test-key")
provider, _, err := CreateProviderFromConfig(cfg)
if err != nil {
@@ -155,9 +156,9 @@ func TestCreateProviderFromConfig_LiteLLM(t *testing.T) {
cfg := &config.ModelConfig{
ModelName: "test-litellm",
Model: "litellm/my-proxy-alias",
APIKey: "test-key",
APIBase: "http://localhost:4000/v1",
}
cfg.SetAPIKey("test-key")
provider, modelID, err := CreateProviderFromConfig(cfg)
if err != nil {
@@ -175,9 +176,9 @@ func TestCreateProviderFromConfig_LongCat(t *testing.T) {
cfg := &config.ModelConfig{
ModelName: "test-longcat",
Model: "longcat/LongCat-Flash-Thinking",
APIKey: "test-key",
APIBase: "https://api.longcat.chat/openai",
}
cfg.SetAPIKey("test-key")
provider, modelID, err := CreateProviderFromConfig(cfg)
if err != nil {
@@ -198,9 +199,9 @@ func TestCreateProviderFromConfig_ModelScope(t *testing.T) {
cfg := &config.ModelConfig{
ModelName: "test-modelscope",
Model: "modelscope/Qwen/Qwen3-235B-A22B-Instruct-2507",
APIKey: "test-key",
APIBase: "https://api-inference.modelscope.cn/v1",
}
cfg.SetAPIKey("test-key")
provider, modelID, err := CreateProviderFromConfig(cfg)
if err != nil {
@@ -227,8 +228,8 @@ func TestCreateProviderFromConfig_Novita(t *testing.T) {
cfg := &config.ModelConfig{
ModelName: "test-novita",
Model: "novita/deepseek/deepseek-v3.2",
APIKey: "test-key",
}
cfg.SetAPIKey("test-key")
provider, modelID, err := CreateProviderFromConfig(cfg)
if err != nil {
@@ -255,8 +256,8 @@ func TestCreateProviderFromConfig_Anthropic(t *testing.T) {
cfg := &config.ModelConfig{
ModelName: "test-anthropic",
Model: "anthropic/claude-sonnet-4.6",
APIKey: "test-key",
}
cfg.SetAPIKey("test-key")
provider, modelID, err := CreateProviderFromConfig(cfg)
if err != nil {
@@ -340,8 +341,8 @@ func TestCreateProviderFromConfig_UnknownProtocol(t *testing.T) {
cfg := &config.ModelConfig{
ModelName: "test-unknown",
Model: "unknown-protocol/model",
APIKey: "test-key",
}
cfg.SetAPIKey("test-key")
_, _, err := CreateProviderFromConfig(cfg)
if err == nil {
@@ -382,6 +383,7 @@ func TestCreateProviderFromConfig_RequestTimeoutPropagation(t *testing.T) {
APIBase: server.URL,
RequestTimeout: 1,
}
cfg.SetAPIKey("test-key")
provider, modelID, err := CreateProviderFromConfig(cfg)
if err != nil {
@@ -411,9 +413,9 @@ func TestCreateProviderFromConfig_Azure(t *testing.T) {
cfg := &config.ModelConfig{
ModelName: "azure-gpt5",
Model: "azure/my-gpt5-deployment",
APIKey: "test-azure-key",
APIBase: "https://my-resource.openai.azure.com",
}
cfg.SetAPIKey("test-azure-key")
provider, modelID, err := CreateProviderFromConfig(cfg)
if err != nil {
@@ -431,9 +433,9 @@ func TestCreateProviderFromConfig_AzureOpenAIAlias(t *testing.T) {
cfg := &config.ModelConfig{
ModelName: "azure-gpt4",
Model: "azure-openai/my-deployment",
APIKey: "test-azure-key",
APIBase: "https://my-resource.openai.azure.com",
}
cfg.SetAPIKey("test-azure-key")
provider, modelID, err := CreateProviderFromConfig(cfg)
if err != nil {
@@ -464,8 +466,8 @@ func TestCreateProviderFromConfig_AzureMissingAPIBase(t *testing.T) {
cfg := &config.ModelConfig{
ModelName: "azure-gpt5",
Model: "azure/my-gpt5-deployment",
APIKey: "test-azure-key",
}
cfg.SetAPIKey("test-azure-key")
_, _, err := CreateProviderFromConfig(cfg)
if err == nil {
@@ -488,8 +490,8 @@ func TestCreateProviderFromConfig_QwenInternationalAlias(t *testing.T) {
cfg := &config.ModelConfig{
ModelName: "test-" + tt.protocol,
Model: tt.protocol + "/qwen-max",
APIKey: "test-key",
}
cfg.SetAPIKey("test-key")
provider, modelID, err := CreateProviderFromConfig(cfg)
if err != nil {
@@ -522,8 +524,8 @@ func TestCreateProviderFromConfig_QwenUSAlias(t *testing.T) {
cfg := &config.ModelConfig{
ModelName: "test-" + tt.protocol,
Model: tt.protocol + "/qwen-max",
APIKey: "test-key",
}
cfg.SetAPIKey("test-key")
provider, modelID, err := CreateProviderFromConfig(cfg)
if err != nil {
@@ -556,8 +558,8 @@ func TestCreateProviderFromConfig_CodingPlanAnthropic(t *testing.T) {
cfg := &config.ModelConfig{
ModelName: "test-" + tt.protocol,
Model: tt.protocol + "/claude-sonnet-4-20250514",
APIKey: "test-key",
}
cfg.SetAPIKey("test-key")
provider, modelID, err := CreateProviderFromConfig(cfg)
if err != nil {
@@ -603,3 +605,173 @@ func TestGetDefaultAPIBase_QwenUSAliases(t *testing.T) {
}
}
}
// TestCreateProviderFromConfig_MinimaxInjectsReasoningSplit checks that a provider
// built from a "minimax/..." model config sends reasoning_split=true in the chat
// request body even though the config itself sets no ExtraBody.
func TestCreateProviderFromConfig_MinimaxInjectsReasoningSplit(t *testing.T) {
	// captured holds the decoded JSON body of the request seen by the stub server.
	var captured map[string]any
	handler := func(w http.ResponseWriter, r *http.Request) {
		decodeErr := json.NewDecoder(r.Body).Decode(&captured)
		if decodeErr != nil {
			http.Error(w, decodeErr.Error(), http.StatusBadRequest)
			return
		}
		w.Header().Set("Content-Type", "application/json")
		_, _ = w.Write([]byte(`{"choices":[{"message":{"content":"ok"},"finish_reason":"stop"}]}`))
	}
	server := httptest.NewServer(http.HandlerFunc(handler))
	defer server.Close()

	cfg := &config.ModelConfig{
		ModelName: "test-minimax",
		Model:     "minimax/MiniMax-M2.5",
		APIBase:   server.URL,
	}
	cfg.SetAPIKey("test-key")

	provider, modelID, err := CreateProviderFromConfig(cfg)
	if err != nil {
		t.Fatalf("CreateProviderFromConfig() error = %v", err)
	}
	if provider == nil {
		t.Fatal("CreateProviderFromConfig() returned nil provider")
	}
	if modelID != "MiniMax-M2.5" {
		t.Errorf("modelID = %q, want %q", modelID, "MiniMax-M2.5")
	}

	messages := []Message{{Role: "user", Content: "hi"}}
	if _, err = provider.Chat(t.Context(), messages, nil, modelID, nil); err != nil {
		t.Fatalf("Chat() error = %v", err)
	}

	// The flag must be present and true without any user-supplied extra body.
	got, ok := captured["reasoning_split"]
	if !(ok && got == true) {
		t.Fatalf("reasoning_split = %v, want true", got)
	}
}
// TestCreateProviderFromConfig_MinimaxPreservesUserExtraBody checks that, for a
// "minimax/..." model whose config carries its own ExtraBody, the chat request
// body contains both reasoning_split=true and the user-supplied custom field.
func TestCreateProviderFromConfig_MinimaxPreservesUserExtraBody(t *testing.T) {
	// captured holds the decoded JSON body of the request seen by the stub server.
	var captured map[string]any
	handler := func(w http.ResponseWriter, r *http.Request) {
		decodeErr := json.NewDecoder(r.Body).Decode(&captured)
		if decodeErr != nil {
			http.Error(w, decodeErr.Error(), http.StatusBadRequest)
			return
		}
		w.Header().Set("Content-Type", "application/json")
		_, _ = w.Write([]byte(`{"choices":[{"message":{"content":"ok"},"finish_reason":"stop"}]}`))
	}
	server := httptest.NewServer(http.HandlerFunc(handler))
	defer server.Close()

	cfg := &config.ModelConfig{
		ModelName: "test-minimax-custom",
		Model:     "minimax/MiniMax-M2.5",
		APIBase:   server.URL,
		ExtraBody: map[string]any{"custom_field": "test"},
	}
	cfg.SetAPIKey("test-key")

	provider, modelID, err := CreateProviderFromConfig(cfg)
	if err != nil {
		t.Fatalf("CreateProviderFromConfig() error = %v", err)
	}

	messages := []Message{{Role: "user", Content: "hi"}}
	if _, err = provider.Chat(t.Context(), messages, nil, modelID, nil); err != nil {
		t.Fatalf("Chat() error = %v", err)
	}

	// reasoning_split is expected even though the user did not configure it.
	split, hasSplit := captured["reasoning_split"]
	if !(hasSplit && split == true) {
		t.Fatalf("reasoning_split = %v, want true", split)
	}

	// The user's own extra-body entry must survive alongside it.
	custom, hasCustom := captured["custom_field"]
	if !(hasCustom && custom == "test") {
		t.Fatalf("custom_field = %v, want test", custom)
	}
}
// TestCreateProviderFromConfig_Bedrock exercises the "bedrock/" protocol with a
// region name in APIBase. The test accepts two outcomes: the provider is created
// and the model ID comes back without the protocol prefix, or creation fails with
// the "build with -tags bedrock" stub error. Any other error fails the test.
func TestCreateProviderFromConfig_Bedrock(t *testing.T) {
	const wantModelID = "us.anthropic.claude-sonnet-4-20250514-v1:0"

	// Dummy AWS values make the test deterministic; the empty values clear
	// profile-related variables so shared config is not loaded.
	env := [][2]string{
		{"AWS_ACCESS_KEY_ID", "test-key"},
		{"AWS_SECRET_ACCESS_KEY", "test-secret"},
		{"AWS_EC2_METADATA_DISABLED", "true"},
		{"AWS_PROFILE", ""},
		{"AWS_DEFAULT_PROFILE", ""},
		{"AWS_SDK_LOAD_CONFIG", ""},
		{"AWS_SHARED_CREDENTIALS_FILE", ""},
	}
	for _, kv := range env {
		t.Setenv(kv[0], kv[1])
	}

	cfg := &config.ModelConfig{
		ModelName: "bedrock-claude",
		Model:     "bedrock/us.anthropic.claude-sonnet-4-20250514-v1:0",
		APIBase:   "us-west-2", // Region (also sets AWS region)
	}

	provider, modelID, err := CreateProviderFromConfig(cfg)
	switch {
	case err == nil:
		// Provider created successfully (built with -tags bedrock).
		if provider == nil {
			t.Error("provider is nil on success")
		}
		if modelID != wantModelID {
			t.Errorf("modelID = %q, want %q", modelID, wantModelID)
		}
	case strings.Contains(err.Error(), "build with -tags bedrock"):
		// Built without -tags bedrock: the stub error is the expected result.
	default:
		t.Errorf("unexpected error from bedrock provider: %v", err)
	}
}
// TestCreateProviderFromConfig_BedrockWithEndpointURL exercises the "bedrock/"
// protocol with a full endpoint URL in APIBase and AWS_REGION set in the
// environment. The test accepts two outcomes: the provider is created and the
// model ID comes back without the protocol prefix, or creation fails with the
// "build with -tags bedrock" stub error. Any other error fails the test.
func TestCreateProviderFromConfig_BedrockWithEndpointURL(t *testing.T) {
	const wantModelID = "us.anthropic.claude-sonnet-4-20250514-v1:0"

	// Dummy AWS values make the test deterministic; AWS_REGION is required when
	// using an endpoint URL, and the empty values clear profile-related
	// variables so shared config is not loaded.
	env := [][2]string{
		{"AWS_ACCESS_KEY_ID", "test-key"},
		{"AWS_SECRET_ACCESS_KEY", "test-secret"},
		{"AWS_REGION", "us-east-1"},
		{"AWS_EC2_METADATA_DISABLED", "true"},
		{"AWS_PROFILE", ""},
		{"AWS_DEFAULT_PROFILE", ""},
		{"AWS_SDK_LOAD_CONFIG", ""},
		{"AWS_SHARED_CREDENTIALS_FILE", ""},
	}
	for _, kv := range env {
		t.Setenv(kv[0], kv[1])
	}

	cfg := &config.ModelConfig{
		ModelName: "bedrock-claude",
		Model:     "bedrock/us.anthropic.claude-sonnet-4-20250514-v1:0",
		APIBase:   "https://bedrock-runtime.us-east-1.amazonaws.com", // Full endpoint URL
	}

	provider, modelID, err := CreateProviderFromConfig(cfg)
	switch {
	case err == nil:
		// Provider created successfully (built with -tags bedrock).
		if provider == nil {
			t.Error("provider is nil on success")
		}
		if modelID != wantModelID {
			t.Errorf("modelID = %q, want %q", modelID, wantModelID)
		}
	case strings.Contains(err.Error(), "build with -tags bedrock"):
		// Built without -tags bedrock: the stub error is the expected result.
	default:
		t.Errorf("unexpected error from bedrock provider: %v", err)
	}
}
+13 -253
View File
@@ -1,262 +1,22 @@
package providers
import (
"strings"
"testing"
"github.com/sipeed/picoclaw/pkg/auth"
"github.com/sipeed/picoclaw/pkg/config"
)
// TestResolveProviderSelection is a table-driven test of resolveProviderSelection.
// Each case starts from config.DefaultConfig(), applies its setup function, and
// then checks either the resolved providerType/apiBase/proxy or, when
// wantErrSubstr is set, that an error containing that substring is returned.
func TestResolveProviderSelection(t *testing.T) {
tests := []struct {
name string
// setup mutates the default config to describe the scenario under test.
setup func(*config.Config)
wantType providerType
// wantAPIBase and wantProxy are only compared when non-empty.
wantAPIBase string
wantProxy string
// wantErrSubstr, when non-empty, marks the case as an expected failure.
wantErrSubstr string
}{
{
name: "explicit litellm provider uses configured base",
setup: func(cfg *config.Config) {
cfg.Agents.Defaults.Provider = "litellm"
cfg.Providers.LiteLLM.APIKey = "litellm-key"
cfg.Providers.LiteLLM.APIBase = "http://localhost:4000/v1"
cfg.Providers.LiteLLM.Proxy = "http://127.0.0.1:7890"
},
wantType: providerTypeHTTPCompat,
wantAPIBase: "http://localhost:4000/v1",
wantProxy: "http://127.0.0.1:7890",
},
{
name: "explicit litellm provider defaults base when only key is configured",
setup: func(cfg *config.Config) {
cfg.Agents.Defaults.Provider = "litellm"
cfg.Providers.LiteLLM.APIKey = "litellm-key"
},
wantType: providerTypeHTTPCompat,
wantAPIBase: "http://localhost:4000/v1",
},
{
name: "explicit claude-cli provider routes to cli provider type",
setup: func(cfg *config.Config) {
cfg.Agents.Defaults.Provider = "claude-cli"
cfg.Agents.Defaults.Workspace = "/tmp/ws"
},
wantType: providerTypeClaudeCLI,
},
{
name: "explicit copilot provider routes to github copilot type",
setup: func(cfg *config.Config) {
cfg.Agents.Defaults.Provider = "copilot"
},
wantType: providerTypeGitHubCopilot,
wantAPIBase: "localhost:4321",
},
{
name: "explicit deepseek provider uses deepseek defaults",
setup: func(cfg *config.Config) {
cfg.Agents.Defaults.Provider = "deepseek"
cfg.Agents.Defaults.Model = "deepseek/deepseek-chat"
cfg.Providers.DeepSeek.APIKey = "deepseek-key"
cfg.Providers.DeepSeek.Proxy = "http://127.0.0.1:7890"
},
wantType: providerTypeHTTPCompat,
wantAPIBase: "https://api.deepseek.com/v1",
wantProxy: "http://127.0.0.1:7890",
},
{
name: "explicit shengsuanyun provider uses defaults",
setup: func(cfg *config.Config) {
cfg.Agents.Defaults.Provider = "shengsuanyun"
cfg.Providers.ShengSuanYun.APIKey = "ssy-key"
cfg.Providers.ShengSuanYun.Proxy = "http://127.0.0.1:7890"
},
wantType: providerTypeHTTPCompat,
wantAPIBase: "https://router.shengsuanyun.com/api/v1",
wantProxy: "http://127.0.0.1:7890",
},
{
name: "explicit nvidia provider uses defaults",
setup: func(cfg *config.Config) {
cfg.Agents.Defaults.Provider = "nvidia"
cfg.Providers.Nvidia.APIKey = "nvapi-test"
cfg.Providers.Nvidia.Proxy = "http://127.0.0.1:7890"
},
wantType: providerTypeHTTPCompat,
wantAPIBase: "https://integrate.api.nvidia.com/v1",
wantProxy: "http://127.0.0.1:7890",
},
{
name: "explicit vivgrid provider uses defaults",
setup: func(cfg *config.Config) {
cfg.Agents.Defaults.Provider = "vivgrid"
cfg.Providers.Vivgrid.APIKey = "vivgrid-key"
cfg.Providers.Vivgrid.Proxy = "http://127.0.0.1:7890"
},
wantType: providerTypeHTTPCompat,
wantAPIBase: "https://api.vivgrid.com/v1",
wantProxy: "http://127.0.0.1:7890",
},
{
name: "openrouter model uses openrouter defaults",
setup: func(cfg *config.Config) {
cfg.Agents.Defaults.Model = "openrouter/auto"
cfg.Providers.OpenRouter.APIKey = "sk-or-test"
},
wantType: providerTypeHTTPCompat,
wantAPIBase: "https://openrouter.ai/api/v1",
},
{
name: "anthropic oauth routes to claude auth provider",
setup: func(cfg *config.Config) {
cfg.Agents.Defaults.Model = "claude-sonnet-4.6"
cfg.Providers.Anthropic.AuthMethod = "oauth"
},
wantType: providerTypeClaudeAuth,
},
{
name: "openai oauth routes to codex auth provider",
setup: func(cfg *config.Config) {
cfg.Agents.Defaults.Model = "gpt-4o"
cfg.Providers.OpenAI.AuthMethod = "oauth"
},
wantType: providerTypeCodexAuth,
},
{
name: "openai codex-cli auth routes to codex cli token provider",
setup: func(cfg *config.Config) {
cfg.Agents.Defaults.Model = "gpt-4o"
cfg.Providers.OpenAI.AuthMethod = "codex-cli"
},
wantType: providerTypeCodexCLIToken,
},
{
name: "explicit codex-code provider routes to codex cli provider type",
setup: func(cfg *config.Config) {
cfg.Agents.Defaults.Provider = "codex-code"
cfg.Agents.Defaults.Workspace = "/tmp/ws"
},
wantType: providerTypeCodexCLI,
},
{
name: "zhipu model uses zhipu base default",
setup: func(cfg *config.Config) {
cfg.Agents.Defaults.Model = "glm-4.7"
cfg.Providers.Zhipu.APIKey = "zhipu-key"
},
wantType: providerTypeHTTPCompat,
wantAPIBase: "https://open.bigmodel.cn/api/paas/v4",
},
{
name: "groq model uses groq base default",
setup: func(cfg *config.Config) {
cfg.Agents.Defaults.Model = "groq/llama-3.3-70b"
cfg.Providers.Groq.APIKey = "gsk-key"
},
wantType: providerTypeHTTPCompat,
wantAPIBase: "https://api.groq.com/openai/v1",
},
{
name: "ollama model uses ollama base default",
setup: func(cfg *config.Config) {
cfg.Agents.Defaults.Model = "ollama/qwen2.5:14b"
cfg.Providers.Ollama.APIKey = "ollama-key"
},
wantType: providerTypeHTTPCompat,
wantAPIBase: "http://localhost:11434/v1",
},
{
name: "moonshot model keeps proxy and default base",
setup: func(cfg *config.Config) {
cfg.Agents.Defaults.Model = "moonshot/kimi-k2.5"
cfg.Providers.Moonshot.APIKey = "moonshot-key"
cfg.Providers.Moonshot.Proxy = "http://127.0.0.1:7890"
},
wantType: providerTypeHTTPCompat,
wantAPIBase: "https://api.moonshot.cn/v1",
wantProxy: "http://127.0.0.1:7890",
},
{
name: "explicit longcat provider uses defaults",
setup: func(cfg *config.Config) {
cfg.Agents.Defaults.Provider = "longcat"
cfg.Providers.LongCat.APIKey = "longcat-key"
cfg.Providers.LongCat.Proxy = "http://127.0.0.1:7890"
},
wantType: providerTypeHTTPCompat,
wantAPIBase: "https://api.longcat.chat/openai",
wantProxy: "http://127.0.0.1:7890",
},
{
name: "longcat model fallback uses longcat base default",
setup: func(cfg *config.Config) {
cfg.Agents.Defaults.Model = "longcat/LongCat-Flash-Thinking"
cfg.Providers.LongCat.APIKey = "longcat-key"
},
wantType: providerTypeHTTPCompat,
wantAPIBase: "https://api.longcat.chat/openai",
},
{
name: "missing keys returns model config error",
setup: func(cfg *config.Config) {
cfg.Agents.Defaults.Model = "custom-model"
},
wantErrSubstr: "no API key configured for model",
},
{
name: "openrouter prefix without key returns provider key error",
setup: func(cfg *config.Config) {
cfg.Agents.Defaults.Model = "openrouter/auto"
},
wantErrSubstr: "no API key configured for provider",
},
}
// Run every case as its own subtest against a fresh default config.
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
cfg := config.DefaultConfig()
tt.setup(cfg)
got, err := resolveProviderSelection(cfg)
// Expected-failure cases only assert on the error text and stop here.
if tt.wantErrSubstr != "" {
if err == nil {
t.Fatalf("expected error containing %q, got nil", tt.wantErrSubstr)
}
if !strings.Contains(err.Error(), tt.wantErrSubstr) {
t.Fatalf("error = %q, want substring %q", err.Error(), tt.wantErrSubstr)
}
return
}
if err != nil {
t.Fatalf("resolveProviderSelection() error = %v", err)
}
if got.providerType != tt.wantType {
t.Fatalf("providerType = %v, want %v", got.providerType, tt.wantType)
}
// Empty expectations mean "not checked", so cases may leave these unset.
if tt.wantAPIBase != "" && got.apiBase != tt.wantAPIBase {
t.Fatalf("apiBase = %q, want %q", got.apiBase, tt.wantAPIBase)
}
if tt.wantProxy != "" && got.proxy != tt.wantProxy {
t.Fatalf("proxy = %q, want %q", got.proxy, tt.wantProxy)
}
})
}
}
func TestCreateProviderReturnsHTTPProviderForOpenRouter(t *testing.T) {
cfg := config.DefaultConfig()
cfg.Agents.Defaults.Model = "test-openrouter"
cfg.ModelList = []config.ModelConfig{
{
ModelName: "test-openrouter",
Model: "openrouter/auto",
APIKey: "sk-or-test",
APIBase: "https://openrouter.ai/api/v1",
},
cfg.Agents.Defaults.ModelName = "test-openrouter"
modelCfg := &config.ModelConfig{
ModelName: "test-openrouter",
Model: "openrouter/auto",
APIBase: "https://openrouter.ai/api/v1",
}
modelCfg.SetAPIKey("sk-or-test")
cfg.ModelList = []*config.ModelConfig{modelCfg}
provider, _, err := CreateProvider(cfg)
if err != nil {
@@ -270,8 +30,8 @@ func TestCreateProviderReturnsHTTPProviderForOpenRouter(t *testing.T) {
func TestCreateProviderReturnsCodexCliProviderForCodexCode(t *testing.T) {
cfg := config.DefaultConfig()
cfg.Agents.Defaults.Model = "test-codex"
cfg.ModelList = []config.ModelConfig{
cfg.Agents.Defaults.ModelName = "test-codex"
cfg.ModelList = []*config.ModelConfig{
{
ModelName: "test-codex",
Model: "codex-cli/codex-model",
@@ -291,8 +51,8 @@ func TestCreateProviderReturnsCodexCliProviderForCodexCode(t *testing.T) {
func TestCreateProviderReturnsClaudeCliProviderForClaudeCli(t *testing.T) {
cfg := config.DefaultConfig()
cfg.Agents.Defaults.Model = "test-claude-cli"
cfg.ModelList = []config.ModelConfig{
cfg.Agents.Defaults.ModelName = "test-claude-cli"
cfg.ModelList = []*config.ModelConfig{
{
ModelName: "test-claude-cli",
Model: "claude-cli/claude-sonnet",
@@ -324,8 +84,8 @@ func TestCreateProviderReturnsClaudeProviderForAnthropicOAuth(t *testing.T) {
}
cfg := config.DefaultConfig()
cfg.Agents.Defaults.Model = "test-claude-oauth"
cfg.ModelList = []config.ModelConfig{
cfg.Agents.Defaults.ModelName = "test-claude-oauth"
cfg.ModelList = []*config.ModelConfig{
{
ModelName: "test-claude-oauth",
Model: "anthropic/claude-sonnet-4.6",
+3 -1
View File
@@ -24,12 +24,13 @@ func NewHTTPProvider(apiKey, apiBase, proxy string) *HTTPProvider {
}
// NewHTTPProviderWithMaxTokensField returns an HTTPProvider that uses the given
// max-tokens field name, delegating to
// NewHTTPProviderWithMaxTokensFieldAndRequestTimeout with a request timeout of 0.
// NOTE(review): two consecutive return statements are present; the second is
// unreachable and is the only one that passes the trailing nil argument. This
// looks like the old and new forms of the same call; confirm which one is intended.
func NewHTTPProviderWithMaxTokensField(apiKey, apiBase, proxy, maxTokensField string) *HTTPProvider {
return NewHTTPProviderWithMaxTokensFieldAndRequestTimeout(apiKey, apiBase, proxy, maxTokensField, 0)
return NewHTTPProviderWithMaxTokensFieldAndRequestTimeout(apiKey, apiBase, proxy, maxTokensField, 0, nil)
}
func NewHTTPProviderWithMaxTokensFieldAndRequestTimeout(
apiKey, apiBase, proxy, maxTokensField string,
requestTimeoutSeconds int,
extraBody map[string]any,
) *HTTPProvider {
return &HTTPProvider{
delegate: openai_compat.NewProvider(
@@ -38,6 +39,7 @@ func NewHTTPProviderWithMaxTokensFieldAndRequestTimeout(
proxy,
openai_compat.WithMaxTokensField(maxTokensField),
openai_compat.WithRequestTimeout(time.Duration(requestTimeoutSeconds)*time.Second),
openai_compat.WithExtraBody(extraBody),
),
}
}
-17
View File
@@ -18,23 +18,6 @@ import (
func CreateProvider(cfg *config.Config) (LLMProvider, string, error) {
model := cfg.Agents.Defaults.GetModelName()
// Ensure model_list is populated from providers config if needed
// This handles two cases:
// 1. ModelList is empty - convert all providers
// 2. ModelList has some entries but not all providers - merge missing ones
if cfg.HasProvidersConfig() {
providerModels := config.ConvertProvidersToModelList(cfg)
existingModelNames := make(map[string]bool)
for _, m := range cfg.ModelList {
existingModelNames[m.ModelName] = true
}
for _, pm := range providerModels {
if !existingModelNames[pm.ModelName] {
cfg.ModelList = append(cfg.ModelList, pm)
}
}
}
// Must have model_list at this point
if len(cfg.ModelList) == 0 {
return nil, "", fmt.Errorf("no providers configured. Please add entries to model_list in your config")
+13
View File
@@ -35,6 +35,7 @@ type Provider struct {
apiBase string
maxTokensField string // Field name for max tokens (e.g., "max_completion_tokens" for o1/glm models)
httpClient *http.Client
extraBody map[string]any // Additional fields to inject into request body
}
// Option mutates a Provider's configuration; NewProvider accepts a variadic
// list of them.
type Option func(*Provider)
@@ -55,6 +56,12 @@ func WithRequestTimeout(timeout time.Duration) Option {
}
}
// WithExtraBody returns an Option that stores extraBody on the Provider as the
// set of additional fields to inject into the request body. The map is stored
// as given, not copied.
func WithExtraBody(extraBody map[string]any) Option {
	setExtraBody := func(provider *Provider) {
		provider.extraBody = extraBody
	}
	return setExtraBody
}
func NewProvider(apiKey, apiBase, proxy string, opts ...Option) *Provider {
p := &Provider{
apiKey: apiKey,
@@ -140,6 +147,12 @@ func (p *Provider) buildRequestBody(
}
}
// Merge extra body fields configured per-provider/model.
// These are injected last so they take precedence over defaults.
for k, v := range p.extraBody {
requestBody[k] = v
}
return requestBody
}
@@ -610,6 +610,90 @@ func TestProvider_RequestTimeoutOverride(t *testing.T) {
}
}
// TestProviderChat_ExtraBodyInjected verifies that every key supplied through
// WithExtraBody appears in the JSON request body that Chat sends.
func TestProviderChat_ExtraBodyInjected(t *testing.T) {
	// requestBody receives the decoded JSON payload seen by the stub server.
	var requestBody map[string]any
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if err := json.NewDecoder(r.Body).Decode(&requestBody); err != nil {
			http.Error(w, err.Error(), http.StatusBadRequest)
			return
		}
		resp := map[string]any{
			"choices": []map[string]any{
				{
					"message":       map[string]any{"content": "ok"},
					"finish_reason": "stop",
				},
			},
		}
		w.Header().Set("Content-Type", "application/json")
		// Surface a failed response write instead of dropping it. Errorf is used
		// because FailNow-style calls must not run off the test goroutine.
		if err := json.NewEncoder(w).Encode(resp); err != nil {
			t.Errorf("encoding stub response: %v", err)
		}
	}))
	defer server.Close()

	extraBody := map[string]any{"reasoning_split": true, "custom_field": "test"}
	p := NewProvider("key", server.URL, "", WithExtraBody(extraBody))

	_, err := p.Chat(
		t.Context(),
		[]Message{{Role: "user", Content: "hi"}},
		nil,
		"minimax/abab7",
		nil,
	)
	if err != nil {
		t.Fatalf("Chat() error = %v", err)
	}

	// Both extra-body entries must reach the server unchanged.
	if got, ok := requestBody["reasoning_split"]; !ok || got != true {
		t.Fatalf("reasoning_split = %v, want true", got)
	}
	if got, ok := requestBody["custom_field"]; !ok || got != "test" {
		t.Fatalf("custom_field = %v, want test", got)
	}
}
// TestProviderChat_ExtraBodyOverridesOptions verifies that when the same key
// ("temperature") is present both in the per-call options and in the extra body
// configured via WithExtraBody, the extra-body value is the one sent.
func TestProviderChat_ExtraBodyOverridesOptions(t *testing.T) {
	// requestBody receives the decoded JSON payload seen by the stub server.
	var requestBody map[string]any
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if err := json.NewDecoder(r.Body).Decode(&requestBody); err != nil {
			http.Error(w, err.Error(), http.StatusBadRequest)
			return
		}
		resp := map[string]any{
			"choices": []map[string]any{
				{
					"message":       map[string]any{"content": "ok"},
					"finish_reason": "stop",
				},
			},
		}
		w.Header().Set("Content-Type", "application/json")
		// Surface a failed response write instead of dropping it. Errorf is used
		// because FailNow-style calls must not run off the test goroutine.
		if err := json.NewEncoder(w).Encode(resp); err != nil {
			t.Errorf("encoding stub response: %v", err)
		}
	}))
	defer server.Close()

	extraBody := map[string]any{"temperature": 0.9}
	p := NewProvider("key", server.URL, "", WithExtraBody(extraBody))

	_, err := p.Chat(
		t.Context(),
		[]Message{{Role: "user", Content: "hi"}},
		nil,
		"gpt-4o",
		map[string]any{"temperature": 0.5},
	)
	if err != nil {
		t.Fatalf("Chat() error = %v", err)
	}

	// ExtraBody takes precedence over options since it is merged last.
	// JSON numbers decode into float64, hence the explicit conversion.
	if got := requestBody["temperature"]; got != float64(0.9) {
		t.Fatalf("temperature = %v, want 0.9 (from extraBody, overriding options)", got)
	}
}
// roundTripperFunc is a function type with a RoundTrip method, letting a plain
// function stand in for an http.RoundTripper.
type roundTripperFunc func(*http.Request) (*http.Response, error)
func (f roundTripperFunc) RoundTrip(r *http.Request) (*http.Response, error) {