mirror of
https://github.com/sipeed/picoclaw.git
synced 2026-06-12 18:08:54 +00:00
Merge branch 'sipeed:main' into main
This commit is contained in:
@@ -188,17 +188,23 @@ func buildRequestBody(
|
||||
|
||||
case "user":
|
||||
if msg.ToolCallID != "" {
|
||||
// Tool result message
|
||||
content := []map[string]any{
|
||||
{
|
||||
"type": "tool_result",
|
||||
"tool_use_id": msg.ToolCallID,
|
||||
"content": msg.Content,
|
||||
},
|
||||
// Tool result message — merge into previous user message if it contains tool_results
|
||||
toolResultBlock := map[string]any{
|
||||
"type": "tool_result",
|
||||
"tool_use_id": msg.ToolCallID,
|
||||
"content": msg.Content,
|
||||
}
|
||||
if len(apiMessages) > 0 {
|
||||
if prev, ok := apiMessages[len(apiMessages)-1].(map[string]any); ok && prev["role"] == "user" {
|
||||
if content, ok := prev["content"].([]map[string]any); ok {
|
||||
prev["content"] = append(content, toolResultBlock)
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
apiMessages = append(apiMessages, map[string]any{
|
||||
"role": "user",
|
||||
"content": content,
|
||||
"content": []map[string]any{toolResultBlock},
|
||||
})
|
||||
} else {
|
||||
// Regular user message
|
||||
@@ -246,17 +252,23 @@ func buildRequestBody(
|
||||
})
|
||||
|
||||
case "tool":
|
||||
// Tool result (alternative format)
|
||||
content := []map[string]any{
|
||||
{
|
||||
"type": "tool_result",
|
||||
"tool_use_id": msg.ToolCallID,
|
||||
"content": msg.Content,
|
||||
},
|
||||
// Tool result (alternative format) — merge into previous user message if it contains tool_results
|
||||
toolResultBlock := map[string]any{
|
||||
"type": "tool_result",
|
||||
"tool_use_id": msg.ToolCallID,
|
||||
"content": msg.Content,
|
||||
}
|
||||
if len(apiMessages) > 0 {
|
||||
if prev, ok := apiMessages[len(apiMessages)-1].(map[string]any); ok && prev["role"] == "user" {
|
||||
if content, ok := prev["content"].([]map[string]any); ok {
|
||||
prev["content"] = append(content, toolResultBlock)
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
apiMessages = append(apiMessages, map[string]any{
|
||||
"role": "user",
|
||||
"content": content,
|
||||
"content": []map[string]any{toolResultBlock},
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -562,6 +562,96 @@ func TestBuildRequestBodyEdgeCases(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildRequestBody_ConsecutiveToolResultsMerged(t *testing.T) {
|
||||
// Consecutive tool results (role "tool") should be merged into a single "user" message
|
||||
messages := []Message{
|
||||
{Role: "user", Content: "Use tools"},
|
||||
{Role: "assistant", Content: "", ToolCalls: []ToolCall{
|
||||
{ID: "t1", Name: "tool_a", Arguments: map[string]any{"x": 1}},
|
||||
{ID: "t2", Name: "tool_b", Arguments: map[string]any{"y": 2}},
|
||||
}},
|
||||
{Role: "tool", ToolCallID: "t1", Content: "result1"},
|
||||
{Role: "tool", ToolCallID: "t2", Content: "result2"},
|
||||
}
|
||||
|
||||
got, err := buildRequestBody(messages, nil, "test-model", map[string]any{"max_tokens": 8192})
|
||||
if err != nil {
|
||||
t.Fatalf("buildRequestBody() error: %v", err)
|
||||
}
|
||||
|
||||
apiMessages, ok := got["messages"].([]any)
|
||||
if !ok {
|
||||
t.Fatalf("messages is not []any")
|
||||
}
|
||||
|
||||
// Expect: user, assistant, user (merged tool results)
|
||||
if len(apiMessages) != 3 {
|
||||
for i, m := range apiMessages {
|
||||
t.Logf("message[%d]: %+v", i, m)
|
||||
}
|
||||
t.Fatalf("expected 3 API messages, got %d", len(apiMessages))
|
||||
}
|
||||
|
||||
// The third message should be a user message with 2 tool_result blocks
|
||||
toolResultMsg, ok := apiMessages[2].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatalf("tool result message is not map[string]any")
|
||||
}
|
||||
if toolResultMsg["role"] != "user" {
|
||||
t.Errorf("expected role 'user', got %v", toolResultMsg["role"])
|
||||
}
|
||||
content, ok := toolResultMsg["content"].([]map[string]any)
|
||||
if !ok {
|
||||
t.Fatalf("content is not []map[string]any: %T", toolResultMsg["content"])
|
||||
}
|
||||
if len(content) != 2 {
|
||||
t.Fatalf("expected 2 tool_result blocks, got %d", len(content))
|
||||
}
|
||||
if content[0]["tool_use_id"] != "t1" {
|
||||
t.Errorf("first tool_result tool_use_id = %v, want t1", content[0]["tool_use_id"])
|
||||
}
|
||||
if content[1]["tool_use_id"] != "t2" {
|
||||
t.Errorf("second tool_result tool_use_id = %v, want t2", content[1]["tool_use_id"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildRequestBody_UserToolResultsMerged(t *testing.T) {
|
||||
// Consecutive tool results using role "user" with ToolCallID should also be merged
|
||||
messages := []Message{
|
||||
{Role: "user", Content: "Use tools"},
|
||||
{Role: "assistant", Content: "", ToolCalls: []ToolCall{
|
||||
{ID: "t1", Name: "tool_a", Arguments: map[string]any{"x": 1}},
|
||||
{ID: "t2", Name: "tool_b", Arguments: map[string]any{"y": 2}},
|
||||
}},
|
||||
{Role: "user", ToolCallID: "t1", Content: "result1"},
|
||||
{Role: "user", ToolCallID: "t2", Content: "result2"},
|
||||
}
|
||||
|
||||
got, err := buildRequestBody(messages, nil, "test-model", map[string]any{"max_tokens": 8192})
|
||||
if err != nil {
|
||||
t.Fatalf("buildRequestBody() error: %v", err)
|
||||
}
|
||||
|
||||
apiMessages, ok := got["messages"].([]any)
|
||||
if !ok {
|
||||
t.Fatalf("messages is not []any")
|
||||
}
|
||||
|
||||
// Expect: user, assistant, user (merged tool results)
|
||||
if len(apiMessages) != 3 {
|
||||
t.Fatalf("expected 3 API messages, got %d", len(apiMessages))
|
||||
}
|
||||
|
||||
toolResultMsg := apiMessages[2].(map[string]any)
|
||||
content, ok := toolResultMsg["content"].([]map[string]any)
|
||||
if !ok {
|
||||
t.Fatalf("content is not []map[string]any: %T", toolResultMsg["content"])
|
||||
}
|
||||
if len(content) != 2 {
|
||||
t.Fatalf("expected 2 tool_result blocks, got %d", len(content))
|
||||
}
|
||||
}
|
||||
|
||||
// TestParseResponseBodyEdgeCases tests edge cases for parseResponseBody.
|
||||
func TestParseResponseBodyEdgeCases(t *testing.T) {
|
||||
tests := []struct {
|
||||
|
||||
@@ -0,0 +1,582 @@
|
||||
//go:build bedrock
|
||||
|
||||
// PicoClaw - Ultra-lightweight personal AI agent
|
||||
// License: MIT
|
||||
//
|
||||
// Copyright (c) 2026 PicoClaw contributors
|
||||
|
||||
// Package bedrock implements the LLM provider interface for AWS Bedrock.
|
||||
// It uses the Bedrock Runtime Converse API for unified access to multiple
|
||||
// model families (Claude, Llama, Mistral, etc.) with tool/function calling support.
|
||||
package bedrock
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log"
|
||||
"math"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/aws/aws-sdk-go-v2/aws"
|
||||
"github.com/aws/aws-sdk-go-v2/config"
|
||||
"github.com/aws/aws-sdk-go-v2/service/bedrockruntime"
|
||||
"github.com/aws/aws-sdk-go-v2/service/bedrockruntime/document"
|
||||
"github.com/aws/aws-sdk-go-v2/service/bedrockruntime/types"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/providers/common"
|
||||
"github.com/sipeed/picoclaw/pkg/providers/protocoltypes"
|
||||
)
|
||||
|
||||
type (
|
||||
ToolCall = protocoltypes.ToolCall
|
||||
FunctionCall = protocoltypes.FunctionCall
|
||||
LLMResponse = protocoltypes.LLMResponse
|
||||
UsageInfo = protocoltypes.UsageInfo
|
||||
Message = protocoltypes.Message
|
||||
ToolDefinition = protocoltypes.ToolDefinition
|
||||
ToolFunctionDefinition = protocoltypes.ToolFunctionDefinition
|
||||
)
|
||||
|
||||
// Provider implements the LLM provider interface for AWS Bedrock.
|
||||
type Provider struct {
|
||||
client *bedrockruntime.Client
|
||||
region string
|
||||
requestTimeout time.Duration
|
||||
}
|
||||
|
||||
// Option configures the Bedrock Provider.
|
||||
type Option func(*providerConfig)
|
||||
|
||||
type providerConfig struct {
|
||||
region string
|
||||
profile string
|
||||
baseEndpoint string
|
||||
requestTimeout time.Duration
|
||||
}
|
||||
|
||||
// WithRegion sets the AWS region for Bedrock requests.
|
||||
func WithRegion(region string) Option {
|
||||
return func(c *providerConfig) {
|
||||
c.region = region
|
||||
}
|
||||
}
|
||||
|
||||
// WithProfile sets the AWS profile to use for credentials.
|
||||
func WithProfile(profile string) Option {
|
||||
return func(c *providerConfig) {
|
||||
c.profile = profile
|
||||
}
|
||||
}
|
||||
|
||||
// WithBaseEndpoint sets a custom Bedrock endpoint URL.
|
||||
// Example: https://bedrock-runtime.us-east-1.amazonaws.com
|
||||
func WithBaseEndpoint(endpoint string) Option {
|
||||
return func(c *providerConfig) {
|
||||
c.baseEndpoint = endpoint
|
||||
}
|
||||
}
|
||||
|
||||
// WithRequestTimeout sets the timeout for Bedrock API requests.
|
||||
func WithRequestTimeout(timeout time.Duration) Option {
|
||||
return func(c *providerConfig) {
|
||||
c.requestTimeout = timeout
|
||||
}
|
||||
}
|
||||
|
||||
// NewProvider creates a new AWS Bedrock provider.
|
||||
// It uses the default AWS credential chain (env vars, shared config, IAM roles, etc.).
|
||||
func NewProvider(ctx context.Context, opts ...Option) (*Provider, error) {
|
||||
pc := &providerConfig{}
|
||||
for _, opt := range opts {
|
||||
opt(pc)
|
||||
}
|
||||
|
||||
// Build AWS config options
|
||||
var configOpts []func(*config.LoadOptions) error
|
||||
|
||||
if pc.region != "" {
|
||||
configOpts = append(configOpts, config.WithRegion(pc.region))
|
||||
}
|
||||
|
||||
if pc.profile != "" {
|
||||
configOpts = append(configOpts, config.WithSharedConfigProfile(pc.profile))
|
||||
}
|
||||
|
||||
// Load AWS config with automatic credential discovery
|
||||
cfg, err := config.LoadDefaultConfig(ctx, configOpts...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("loading AWS config: %w", err)
|
||||
}
|
||||
|
||||
// Validate region is set - required for Bedrock request signing
|
||||
if cfg.Region == "" {
|
||||
return nil, fmt.Errorf(
|
||||
"AWS region not configured: set AWS_REGION, AWS_DEFAULT_REGION, or use WithRegion option",
|
||||
)
|
||||
}
|
||||
|
||||
// Build client options
|
||||
var clientOpts []func(*bedrockruntime.Options)
|
||||
if pc.baseEndpoint != "" {
|
||||
clientOpts = append(clientOpts, func(o *bedrockruntime.Options) {
|
||||
o.BaseEndpoint = aws.String(pc.baseEndpoint)
|
||||
})
|
||||
}
|
||||
|
||||
client := bedrockruntime.NewFromConfig(cfg, clientOpts...)
|
||||
|
||||
return &Provider{
|
||||
client: client,
|
||||
region: cfg.Region,
|
||||
requestTimeout: pc.requestTimeout,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Chat sends messages to AWS Bedrock using the Converse API.
|
||||
func (p *Provider) Chat(
|
||||
ctx context.Context,
|
||||
messages []Message,
|
||||
tools []ToolDefinition,
|
||||
model string,
|
||||
options map[string]any,
|
||||
) (*LLMResponse, error) {
|
||||
// Apply request timeout if context doesn't already have a deadline.
|
||||
// Use explicit timeout if set, otherwise fall back to common default.
|
||||
effectiveTimeout := p.requestTimeout
|
||||
if effectiveTimeout <= 0 {
|
||||
effectiveTimeout = common.DefaultRequestTimeout
|
||||
}
|
||||
if _, hasDeadline := ctx.Deadline(); !hasDeadline {
|
||||
var cancel context.CancelFunc
|
||||
ctx, cancel = context.WithTimeout(ctx, effectiveTimeout)
|
||||
defer cancel()
|
||||
}
|
||||
|
||||
// Build the Converse API input
|
||||
input := &bedrockruntime.ConverseInput{
|
||||
ModelId: aws.String(model),
|
||||
}
|
||||
|
||||
// Convert messages to Bedrock format
|
||||
bedrockMessages, systemPrompts := convertMessages(messages)
|
||||
input.Messages = bedrockMessages
|
||||
|
||||
// Set system prompts if any
|
||||
if len(systemPrompts) > 0 {
|
||||
input.System = systemPrompts
|
||||
}
|
||||
|
||||
// Set inference configuration only when options are provided
|
||||
var inferenceConfig *types.InferenceConfiguration
|
||||
|
||||
if maxTokens, ok := common.AsInt(options["max_tokens"]); ok && maxTokens > 0 {
|
||||
if inferenceConfig == nil {
|
||||
inferenceConfig = &types.InferenceConfiguration{}
|
||||
}
|
||||
// Clamp to int32 range to avoid overflow
|
||||
if maxTokens > math.MaxInt32 {
|
||||
maxTokens = math.MaxInt32
|
||||
}
|
||||
inferenceConfig.MaxTokens = aws.Int32(int32(maxTokens))
|
||||
}
|
||||
|
||||
if temp, ok := common.AsFloat(options["temperature"]); ok {
|
||||
if inferenceConfig == nil {
|
||||
inferenceConfig = &types.InferenceConfiguration{}
|
||||
}
|
||||
inferenceConfig.Temperature = aws.Float32(float32(temp))
|
||||
}
|
||||
|
||||
if inferenceConfig != nil {
|
||||
input.InferenceConfig = inferenceConfig
|
||||
}
|
||||
|
||||
// Convert tools to Bedrock format
|
||||
// Only set ToolConfig if at least one valid tool was produced
|
||||
if len(tools) > 0 {
|
||||
toolConfig := convertTools(tools)
|
||||
if len(toolConfig.Tools) > 0 {
|
||||
input.ToolConfig = toolConfig
|
||||
}
|
||||
}
|
||||
|
||||
// Call Bedrock Converse API
|
||||
output, err := p.client.Converse(ctx, input)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("bedrock converse: %w", err)
|
||||
}
|
||||
|
||||
// Parse the response
|
||||
return parseResponse(output)
|
||||
}
|
||||
|
||||
// GetDefaultModel returns an empty string as Bedrock models are user-configured.
|
||||
func (p *Provider) GetDefaultModel() string {
|
||||
return ""
|
||||
}
|
||||
|
||||
// Region returns the AWS region configured for this Provider.
|
||||
func (p *Provider) Region() string {
|
||||
return p.region
|
||||
}
|
||||
|
||||
// convertMessages converts internal messages to Bedrock Converse format.
|
||||
// Returns the conversation messages and any system prompts separately.
|
||||
// Note: Bedrock requires all tool results for a given assistant turn to be in a single
|
||||
// user message with multiple ToolResultBlock content blocks. This function merges
|
||||
// consecutive tool result messages accordingly.
|
||||
func convertMessages(messages []Message) ([]types.Message, []types.SystemContentBlock) {
|
||||
var bedrockMessages []types.Message
|
||||
var systemPrompts []types.SystemContentBlock
|
||||
|
||||
// Helper to check if a message is a tool result
|
||||
isToolResult := func(msg Message) bool {
|
||||
return (msg.Role == "tool" || (msg.Role == "user" && msg.ToolCallID != "")) && msg.ToolCallID != ""
|
||||
}
|
||||
|
||||
// Helper to create a tool result content block
|
||||
makeToolResultBlock := func(msg Message) types.ContentBlock {
|
||||
return &types.ContentBlockMemberToolResult{
|
||||
Value: types.ToolResultBlock{
|
||||
ToolUseId: aws.String(msg.ToolCallID),
|
||||
Content: []types.ToolResultContentBlock{
|
||||
&types.ToolResultContentBlockMemberText{
|
||||
Value: msg.Content,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
i := 0
|
||||
for i < len(messages) {
|
||||
msg := messages[i]
|
||||
|
||||
switch {
|
||||
case msg.Role == "system":
|
||||
// System messages go to the System field
|
||||
systemPrompts = append(systemPrompts, &types.SystemContentBlockMemberText{
|
||||
Value: msg.Content,
|
||||
})
|
||||
i++
|
||||
|
||||
case isToolResult(msg):
|
||||
// Collect all consecutive tool results into a single user message
|
||||
// Bedrock requires all tool results for a turn in one message
|
||||
var toolResultBlocks []types.ContentBlock
|
||||
for i < len(messages) && isToolResult(messages[i]) {
|
||||
toolResultBlocks = append(toolResultBlocks, makeToolResultBlock(messages[i]))
|
||||
i++
|
||||
}
|
||||
bedrockMessages = append(bedrockMessages, types.Message{
|
||||
Role: types.ConversationRoleUser,
|
||||
Content: toolResultBlocks,
|
||||
})
|
||||
|
||||
case msg.Role == "user":
|
||||
// Regular user message (no ToolCallID)
|
||||
content := buildUserContent(msg)
|
||||
bedrockMessages = append(bedrockMessages, types.Message{
|
||||
Role: types.ConversationRoleUser,
|
||||
Content: content,
|
||||
})
|
||||
i++
|
||||
|
||||
case msg.Role == "assistant":
|
||||
content := buildAssistantContent(msg)
|
||||
bedrockMessages = append(bedrockMessages, types.Message{
|
||||
Role: types.ConversationRoleAssistant,
|
||||
Content: content,
|
||||
})
|
||||
i++
|
||||
|
||||
case msg.Role == "tool" && msg.ToolCallID == "":
|
||||
// Tool message without ToolCallID - treat as regular user message
|
||||
content := buildUserContent(msg)
|
||||
bedrockMessages = append(bedrockMessages, types.Message{
|
||||
Role: types.ConversationRoleUser,
|
||||
Content: content,
|
||||
})
|
||||
i++
|
||||
|
||||
default:
|
||||
// Unknown role - skip
|
||||
i++
|
||||
}
|
||||
}
|
||||
|
||||
return bedrockMessages, systemPrompts
|
||||
}
|
||||
|
||||
// buildUserContent builds Bedrock content blocks for a user message.
|
||||
func buildUserContent(msg Message) []types.ContentBlock {
|
||||
var content []types.ContentBlock
|
||||
|
||||
// Add text content
|
||||
if msg.Content != "" {
|
||||
content = append(content, &types.ContentBlockMemberText{
|
||||
Value: msg.Content,
|
||||
})
|
||||
}
|
||||
|
||||
// Add images from Media field
|
||||
for _, mediaURL := range msg.Media {
|
||||
if strings.HasPrefix(mediaURL, "data:image/") {
|
||||
// Parse data URL: data:image/jpeg;base64,<data>
|
||||
parts := strings.SplitN(mediaURL, ",", 2)
|
||||
if len(parts) != 2 {
|
||||
continue
|
||||
}
|
||||
|
||||
// Extract media type from "data:image/jpeg;base64"
|
||||
mediaType := ""
|
||||
header := parts[0]
|
||||
if idx := strings.Index(header, "/"); idx != -1 {
|
||||
end := strings.Index(header[idx:], ";")
|
||||
if end == -1 {
|
||||
end = len(header) - idx
|
||||
}
|
||||
mediaType = header[idx+1 : idx+end]
|
||||
}
|
||||
|
||||
// Verify this is base64 encoded
|
||||
if !strings.Contains(header, ";base64") {
|
||||
continue // Skip non-base64 encoded data
|
||||
}
|
||||
|
||||
// Map media type to Bedrock format
|
||||
var format types.ImageFormat
|
||||
switch mediaType {
|
||||
case "jpeg", "jpg":
|
||||
format = types.ImageFormatJpeg
|
||||
case "png":
|
||||
format = types.ImageFormatPng
|
||||
case "gif":
|
||||
format = types.ImageFormatGif
|
||||
case "webp":
|
||||
format = types.ImageFormatWebp
|
||||
default:
|
||||
continue // Skip unsupported formats
|
||||
}
|
||||
|
||||
// Check size before decoding to prevent excessive memory allocation
|
||||
// Bedrock has a ~20MB request limit; cap decoded images at 10MB
|
||||
const maxImageSize = 10 * 1024 * 1024
|
||||
decodedLen := base64.StdEncoding.DecodedLen(len(parts[1]))
|
||||
if decodedLen > maxImageSize {
|
||||
log.Printf("bedrock: skipping image exceeding size limit (%d bytes > %d)", decodedLen, maxImageSize)
|
||||
continue
|
||||
}
|
||||
|
||||
// Decode base64 data
|
||||
imageData, err := base64.StdEncoding.DecodeString(parts[1])
|
||||
if err != nil {
|
||||
log.Printf("bedrock: failed to decode base64 image data: %v", err)
|
||||
continue
|
||||
}
|
||||
|
||||
content = append(content, &types.ContentBlockMemberImage{
|
||||
Value: types.ImageBlock{
|
||||
Format: format,
|
||||
Source: &types.ImageSourceMemberBytes{
|
||||
Value: imageData,
|
||||
},
|
||||
},
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Bedrock requires at least one content block; add empty text if needed
|
||||
if len(content) == 0 {
|
||||
content = append(content, &types.ContentBlockMemberText{Value: ""})
|
||||
}
|
||||
|
||||
return content
|
||||
}
|
||||
|
||||
// buildAssistantContent builds Bedrock content blocks for an assistant message.
|
||||
func buildAssistantContent(msg Message) []types.ContentBlock {
|
||||
var content []types.ContentBlock
|
||||
|
||||
// Add text content if present
|
||||
if msg.Content != "" {
|
||||
content = append(content, &types.ContentBlockMemberText{
|
||||
Value: msg.Content,
|
||||
})
|
||||
}
|
||||
|
||||
// Add tool use blocks
|
||||
for _, tc := range msg.ToolCalls {
|
||||
// Validate tool call ID - Bedrock requires non-empty ToolUseId
|
||||
if strings.TrimSpace(tc.ID) == "" {
|
||||
log.Printf("bedrock: skipping tool call with empty ID (name: %q)", tc.Name)
|
||||
continue
|
||||
}
|
||||
|
||||
// Resolve tool name: prefer tc.Name, fallback to tc.Function.Name
|
||||
// (tc.Name/tc.Arguments are json:"-" and may be empty when from JSON)
|
||||
toolName := tc.Name
|
||||
if toolName == "" && tc.Function != nil {
|
||||
toolName = tc.Function.Name
|
||||
}
|
||||
if strings.TrimSpace(toolName) == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
// Resolve arguments: prefer tc.Arguments, fallback to parsing tc.Function.Arguments
|
||||
args := tc.Arguments
|
||||
if args == nil && tc.Function != nil && tc.Function.Arguments != "" {
|
||||
if err := json.Unmarshal([]byte(tc.Function.Arguments), &args); err != nil {
|
||||
log.Printf("bedrock: failed to parse Function.Arguments for tool %q: %v", toolName, err)
|
||||
args = map[string]any{}
|
||||
}
|
||||
}
|
||||
if args == nil {
|
||||
args = map[string]any{}
|
||||
}
|
||||
|
||||
// Convert arguments to a Bedrock document using NewLazyDocument
|
||||
inputDoc := document.NewLazyDocument(args)
|
||||
|
||||
content = append(content, &types.ContentBlockMemberToolUse{
|
||||
Value: types.ToolUseBlock{
|
||||
ToolUseId: aws.String(tc.ID),
|
||||
Name: aws.String(toolName),
|
||||
Input: inputDoc,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
// Bedrock requires at least one content block; add empty text if needed
|
||||
if len(content) == 0 {
|
||||
content = append(content, &types.ContentBlockMemberText{Value: ""})
|
||||
}
|
||||
|
||||
return content
|
||||
}
|
||||
|
||||
// convertTools converts tool definitions to Bedrock format.
|
||||
func convertTools(tools []ToolDefinition) *types.ToolConfiguration {
|
||||
bedrockTools := make([]types.Tool, 0, len(tools))
|
||||
|
||||
for _, tool := range tools {
|
||||
// Skip tools with empty names
|
||||
if strings.TrimSpace(tool.Function.Name) == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
// Ensure parameters is not nil - default to minimal object schema
|
||||
params := tool.Function.Parameters
|
||||
if params == nil {
|
||||
params = map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]any{},
|
||||
}
|
||||
}
|
||||
|
||||
// Convert parameters schema to a Bedrock document
|
||||
inputSchema := document.NewLazyDocument(params)
|
||||
|
||||
bedrockTools = append(bedrockTools, &types.ToolMemberToolSpec{
|
||||
Value: types.ToolSpecification{
|
||||
Name: aws.String(tool.Function.Name),
|
||||
Description: aws.String(tool.Function.Description),
|
||||
InputSchema: &types.ToolInputSchemaMemberJson{
|
||||
Value: inputSchema,
|
||||
},
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
return &types.ToolConfiguration{
|
||||
Tools: bedrockTools,
|
||||
}
|
||||
}
|
||||
|
||||
// parseResponse converts Bedrock Converse output to LLMResponse.
|
||||
func parseResponse(output *bedrockruntime.ConverseOutput) (*LLMResponse, error) {
|
||||
var content strings.Builder
|
||||
toolCalls := make([]ToolCall, 0)
|
||||
|
||||
// Process output content blocks
|
||||
if output.Output != nil {
|
||||
if msgOutput, ok := output.Output.(*types.ConverseOutputMemberMessage); ok {
|
||||
for _, block := range msgOutput.Value.Content {
|
||||
switch b := block.(type) {
|
||||
case *types.ContentBlockMemberText:
|
||||
content.WriteString(b.Value)
|
||||
|
||||
case *types.ContentBlockMemberToolUse:
|
||||
// Unmarshal the document interface to a map
|
||||
args := make(map[string]any)
|
||||
if b.Value.Input != nil {
|
||||
if err := b.Value.Input.UnmarshalSmithyDocument(&args); err != nil {
|
||||
log.Printf("bedrock: failed to unmarshal tool input for tool %q (id %q): %v",
|
||||
aws.ToString(b.Value.Name),
|
||||
aws.ToString(b.Value.ToolUseId),
|
||||
err,
|
||||
)
|
||||
args = make(map[string]any)
|
||||
}
|
||||
}
|
||||
|
||||
// Serialize arguments to JSON string for FunctionCall
|
||||
argsJSON, err := json.Marshal(args)
|
||||
if err != nil {
|
||||
log.Printf("bedrock: failed to marshal tool arguments for tool %q (id %q): %v",
|
||||
aws.ToString(b.Value.Name),
|
||||
aws.ToString(b.Value.ToolUseId),
|
||||
err,
|
||||
)
|
||||
argsJSON = []byte("{}")
|
||||
}
|
||||
|
||||
toolCalls = append(toolCalls, ToolCall{
|
||||
ID: aws.ToString(b.Value.ToolUseId),
|
||||
Name: aws.ToString(b.Value.Name),
|
||||
Arguments: args,
|
||||
Function: &FunctionCall{
|
||||
Name: aws.ToString(b.Value.Name),
|
||||
Arguments: string(argsJSON),
|
||||
},
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Map stop reason
|
||||
finishReason := "stop"
|
||||
switch output.StopReason {
|
||||
case types.StopReasonToolUse:
|
||||
finishReason = "tool_calls"
|
||||
case types.StopReasonMaxTokens:
|
||||
finishReason = "length"
|
||||
case types.StopReasonEndTurn:
|
||||
finishReason = "stop"
|
||||
case types.StopReasonStopSequence:
|
||||
finishReason = "stop"
|
||||
case types.StopReasonContentFiltered:
|
||||
finishReason = "content_filter"
|
||||
}
|
||||
|
||||
// Build usage info
|
||||
var usage *UsageInfo
|
||||
if output.Usage != nil {
|
||||
usage = &UsageInfo{
|
||||
PromptTokens: int(aws.ToInt32(output.Usage.InputTokens)),
|
||||
CompletionTokens: int(aws.ToInt32(output.Usage.OutputTokens)),
|
||||
TotalTokens: int(aws.ToInt32(output.Usage.InputTokens)) + int(aws.ToInt32(output.Usage.OutputTokens)),
|
||||
}
|
||||
}
|
||||
|
||||
return &LLMResponse{
|
||||
Content: content.String(),
|
||||
ToolCalls: toolCalls,
|
||||
FinishReason: finishReason,
|
||||
Usage: usage,
|
||||
}, nil
|
||||
}
|
||||
@@ -0,0 +1,541 @@
|
||||
//go:build bedrock
|
||||
|
||||
// PicoClaw - Ultra-lightweight personal AI agent
|
||||
// License: MIT
|
||||
//
|
||||
// Copyright (c) 2026 PicoClaw contributors
|
||||
|
||||
package bedrock
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/aws/aws-sdk-go-v2/aws"
|
||||
"github.com/aws/aws-sdk-go-v2/service/bedrockruntime"
|
||||
"github.com/aws/aws-sdk-go-v2/service/bedrockruntime/document"
|
||||
"github.com/aws/aws-sdk-go-v2/service/bedrockruntime/types"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/providers/protocoltypes"
|
||||
)
|
||||
|
||||
func TestConvertMessages_SystemPrompts(t *testing.T) {
|
||||
messages := []Message{
|
||||
{Role: "system", Content: "You are a helpful assistant."},
|
||||
{Role: "user", Content: "Hello"},
|
||||
}
|
||||
|
||||
bedrockMsgs, systemPrompts := convertMessages(messages)
|
||||
|
||||
assert.Len(t, systemPrompts, 1)
|
||||
assert.Len(t, bedrockMsgs, 1)
|
||||
|
||||
// Check system prompt
|
||||
textBlock, ok := systemPrompts[0].(*types.SystemContentBlockMemberText)
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, "You are a helpful assistant.", textBlock.Value)
|
||||
|
||||
// Check user message
|
||||
assert.Equal(t, types.ConversationRoleUser, bedrockMsgs[0].Role)
|
||||
}
|
||||
|
||||
func TestConvertMessages_UserMessage(t *testing.T) {
|
||||
messages := []Message{
|
||||
{Role: "user", Content: "What is 2+2?"},
|
||||
}
|
||||
|
||||
bedrockMsgs, systemPrompts := convertMessages(messages)
|
||||
|
||||
assert.Empty(t, systemPrompts)
|
||||
assert.Len(t, bedrockMsgs, 1)
|
||||
assert.Equal(t, types.ConversationRoleUser, bedrockMsgs[0].Role)
|
||||
|
||||
textBlock, ok := bedrockMsgs[0].Content[0].(*types.ContentBlockMemberText)
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, "What is 2+2?", textBlock.Value)
|
||||
}
|
||||
|
||||
func TestConvertMessages_AssistantMessage(t *testing.T) {
|
||||
messages := []Message{
|
||||
{Role: "assistant", Content: "The answer is 4."},
|
||||
}
|
||||
|
||||
bedrockMsgs, _ := convertMessages(messages)
|
||||
|
||||
assert.Len(t, bedrockMsgs, 1)
|
||||
assert.Equal(t, types.ConversationRoleAssistant, bedrockMsgs[0].Role)
|
||||
|
||||
textBlock, ok := bedrockMsgs[0].Content[0].(*types.ContentBlockMemberText)
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, "The answer is 4.", textBlock.Value)
|
||||
}
|
||||
|
||||
func TestConvertMessages_ToolResult(t *testing.T) {
|
||||
messages := []Message{
|
||||
{Role: "tool", Content: "Result from tool", ToolCallID: "call_123"},
|
||||
}
|
||||
|
||||
bedrockMsgs, _ := convertMessages(messages)
|
||||
|
||||
assert.Len(t, bedrockMsgs, 1)
|
||||
assert.Equal(t, types.ConversationRoleUser, bedrockMsgs[0].Role)
|
||||
|
||||
toolResult, ok := bedrockMsgs[0].Content[0].(*types.ContentBlockMemberToolResult)
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, "call_123", aws.ToString(toolResult.Value.ToolUseId))
|
||||
}
|
||||
|
||||
func TestConvertMessages_MultipleToolResultsMerged(t *testing.T) {
|
||||
// When an assistant makes multiple tool calls, all tool results must be
|
||||
// merged into a single user message for Bedrock
|
||||
messages := []Message{
|
||||
{Role: "user", Content: "What's the weather in NYC and LA?"},
|
||||
{
|
||||
Role: "assistant",
|
||||
Content: "Let me check both cities.",
|
||||
ToolCalls: []protocoltypes.ToolCall{
|
||||
{ID: "call_nyc", Name: "get_weather", Arguments: map[string]any{"city": "NYC"}},
|
||||
{ID: "call_la", Name: "get_weather", Arguments: map[string]any{"city": "LA"}},
|
||||
},
|
||||
},
|
||||
{Role: "tool", Content: "NYC: 72°F, sunny", ToolCallID: "call_nyc"},
|
||||
{Role: "tool", Content: "LA: 85°F, clear", ToolCallID: "call_la"},
|
||||
}
|
||||
|
||||
bedrockMsgs, _ := convertMessages(messages)
|
||||
|
||||
// Should be: user message, assistant message, merged tool results (single user message)
|
||||
assert.Len(t, bedrockMsgs, 3)
|
||||
|
||||
// First message: user
|
||||
assert.Equal(t, types.ConversationRoleUser, bedrockMsgs[0].Role)
|
||||
|
||||
// Second message: assistant with tool calls
|
||||
assert.Equal(t, types.ConversationRoleAssistant, bedrockMsgs[1].Role)
|
||||
|
||||
// Third message: merged tool results in single user message
|
||||
assert.Equal(t, types.ConversationRoleUser, bedrockMsgs[2].Role)
|
||||
assert.Len(t, bedrockMsgs[2].Content, 2) // Both tool results in one message
|
||||
|
||||
// Verify both tool results are present
|
||||
result1, ok := bedrockMsgs[2].Content[0].(*types.ContentBlockMemberToolResult)
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, "call_nyc", aws.ToString(result1.Value.ToolUseId))
|
||||
|
||||
result2, ok := bedrockMsgs[2].Content[1].(*types.ContentBlockMemberToolResult)
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, "call_la", aws.ToString(result2.Value.ToolUseId))
|
||||
}
|
||||
|
||||
func TestConvertMessages_AssistantWithToolCalls(t *testing.T) {
|
||||
messages := []Message{
|
||||
{
|
||||
Role: "assistant",
|
||||
Content: "Let me calculate that.",
|
||||
ToolCalls: []protocoltypes.ToolCall{
|
||||
{
|
||||
ID: "call_456",
|
||||
Name: "calculator",
|
||||
Arguments: map[string]any{"expression": "2+2"},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
bedrockMsgs, _ := convertMessages(messages)
|
||||
|
||||
assert.Len(t, bedrockMsgs, 1)
|
||||
assert.Len(t, bedrockMsgs[0].Content, 2) // text + tool use
|
||||
|
||||
// Check text content
|
||||
textBlock, ok := bedrockMsgs[0].Content[0].(*types.ContentBlockMemberText)
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, "Let me calculate that.", textBlock.Value)
|
||||
|
||||
// Check tool use
|
||||
toolUse, ok := bedrockMsgs[0].Content[1].(*types.ContentBlockMemberToolUse)
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, "call_456", aws.ToString(toolUse.Value.ToolUseId))
|
||||
assert.Equal(t, "calculator", aws.ToString(toolUse.Value.Name))
|
||||
}
|
||||
|
||||
func TestConvertTools_Basic(t *testing.T) {
|
||||
tools := []ToolDefinition{
|
||||
{
|
||||
Function: protocoltypes.ToolFunctionDefinition{
|
||||
Name: "get_weather",
|
||||
Description: "Get the current weather",
|
||||
Parameters: map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]any{
|
||||
"location": map[string]any{"type": "string"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
toolConfig := convertTools(tools)
|
||||
|
||||
assert.NotNil(t, toolConfig)
|
||||
assert.Len(t, toolConfig.Tools, 1)
|
||||
|
||||
toolSpec, ok := toolConfig.Tools[0].(*types.ToolMemberToolSpec)
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, "get_weather", aws.ToString(toolSpec.Value.Name))
|
||||
assert.Equal(t, "Get the current weather", aws.ToString(toolSpec.Value.Description))
|
||||
}
|
||||
|
||||
func TestConvertTools_SkipsEmptyName(t *testing.T) {
|
||||
tools := []ToolDefinition{
|
||||
{
|
||||
Function: protocoltypes.ToolFunctionDefinition{
|
||||
Name: "",
|
||||
Description: "Empty name tool",
|
||||
},
|
||||
},
|
||||
{
|
||||
Function: protocoltypes.ToolFunctionDefinition{
|
||||
Name: " ",
|
||||
Description: "Whitespace name tool",
|
||||
},
|
||||
},
|
||||
{
|
||||
Function: protocoltypes.ToolFunctionDefinition{
|
||||
Name: "valid_tool",
|
||||
Description: "Valid tool",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
toolConfig := convertTools(tools)
|
||||
|
||||
assert.Len(t, toolConfig.Tools, 1)
|
||||
toolSpec := toolConfig.Tools[0].(*types.ToolMemberToolSpec)
|
||||
assert.Equal(t, "valid_tool", aws.ToString(toolSpec.Value.Name))
|
||||
}
|
||||
|
||||
func TestConvertTools_NilParameters(t *testing.T) {
|
||||
tools := []ToolDefinition{
|
||||
{
|
||||
Function: protocoltypes.ToolFunctionDefinition{
|
||||
Name: "simple_tool",
|
||||
Description: "A tool with no parameters",
|
||||
Parameters: nil,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
toolConfig := convertTools(tools)
|
||||
|
||||
assert.Len(t, toolConfig.Tools, 1)
|
||||
// Should not panic and should create a valid tool
|
||||
}
|
||||
|
||||
func TestBuildUserContent_TextOnly(t *testing.T) {
|
||||
msg := Message{Content: "Hello world"}
|
||||
|
||||
content := buildUserContent(msg)
|
||||
|
||||
assert.Len(t, content, 1)
|
||||
textBlock, ok := content[0].(*types.ContentBlockMemberText)
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, "Hello world", textBlock.Value)
|
||||
}
|
||||
|
||||
func TestBuildUserContent_WithImage(t *testing.T) {
|
||||
// Base64-encoded 1x1 PNG (the provider doesn't validate image correctness,
|
||||
// it just verifies the format and base64 decoding works)
|
||||
b64Data := "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAIAAACQd1PeAAAADUlEQVR4nGNgYAAAAAMAASsJTYQAAAAASUVORK5CYII="
|
||||
|
||||
msg := Message{
|
||||
Content: "Look at this image",
|
||||
Media: []string{"data:image/png;base64," + b64Data},
|
||||
}
|
||||
|
||||
content := buildUserContent(msg)
|
||||
|
||||
assert.Len(t, content, 2)
|
||||
|
||||
// Check text
|
||||
textBlock, ok := content[0].(*types.ContentBlockMemberText)
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, "Look at this image", textBlock.Value)
|
||||
|
||||
// Check image
|
||||
imageBlock, ok := content[1].(*types.ContentBlockMemberImage)
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, types.ImageFormatPng, imageBlock.Value.Format)
|
||||
}
|
||||
|
||||
func TestBuildUserContent_SkipsInvalidBase64(t *testing.T) {
|
||||
msg := Message{
|
||||
Content: "Invalid image",
|
||||
Media: []string{"data:image/png;base64,not-valid-base64!!!"},
|
||||
}
|
||||
|
||||
content := buildUserContent(msg)
|
||||
|
||||
// Should only have text, image should be skipped
|
||||
assert.Len(t, content, 1)
|
||||
}
|
||||
|
||||
func TestBuildUserContent_SkipsNonBase64Data(t *testing.T) {
|
||||
msg := Message{
|
||||
Content: "Non-base64 image",
|
||||
Media: []string{"data:image/png,raw-data-here"},
|
||||
}
|
||||
|
||||
content := buildUserContent(msg)
|
||||
|
||||
// Should only have text, non-base64 image should be skipped
|
||||
assert.Len(t, content, 1)
|
||||
}
|
||||
|
||||
func TestBuildAssistantContent_SkipsEmptyToolName(t *testing.T) {
|
||||
msg := Message{
|
||||
Content: "Response",
|
||||
ToolCalls: []protocoltypes.ToolCall{
|
||||
{ID: "1", Name: "", Arguments: map[string]any{}},
|
||||
{ID: "2", Name: " ", Arguments: map[string]any{}},
|
||||
{ID: "3", Name: "valid", Arguments: map[string]any{}},
|
||||
},
|
||||
}
|
||||
|
||||
content := buildAssistantContent(msg)
|
||||
|
||||
// Should have text + 1 valid tool
|
||||
assert.Len(t, content, 2)
|
||||
}
|
||||
|
||||
func TestBuildAssistantContent_NilArguments(t *testing.T) {
|
||||
msg := Message{
|
||||
ToolCalls: []protocoltypes.ToolCall{
|
||||
{ID: "1", Name: "tool", Arguments: nil},
|
||||
},
|
||||
}
|
||||
|
||||
content := buildAssistantContent(msg)
|
||||
|
||||
assert.Len(t, content, 1)
|
||||
toolUse, ok := content[0].(*types.ContentBlockMemberToolUse)
|
||||
require.True(t, ok)
|
||||
assert.NotNil(t, toolUse.Value.Input)
|
||||
}
|
||||
|
||||
func TestBuildAssistantContent_FunctionFallback(t *testing.T) {
|
||||
// When Name/Arguments are empty (json:"-"), should fallback to Function fields
|
||||
msg := Message{
|
||||
ToolCalls: []protocoltypes.ToolCall{
|
||||
{
|
||||
ID: "1",
|
||||
Name: "", // empty, should fallback to Function.Name
|
||||
Function: &protocoltypes.FunctionCall{
|
||||
Name: "fallback_tool",
|
||||
Arguments: `{"key":"value"}`,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
content := buildAssistantContent(msg)
|
||||
|
||||
assert.Len(t, content, 1)
|
||||
toolUse, ok := content[0].(*types.ContentBlockMemberToolUse)
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, "fallback_tool", aws.ToString(toolUse.Value.Name))
|
||||
}
|
||||
|
||||
func TestParseResponse_TextOnly(t *testing.T) {
|
||||
output := &bedrockruntime.ConverseOutput{
|
||||
Output: &types.ConverseOutputMemberMessage{
|
||||
Value: types.Message{
|
||||
Role: types.ConversationRoleAssistant,
|
||||
Content: []types.ContentBlock{
|
||||
&types.ContentBlockMemberText{Value: "Hello!"},
|
||||
},
|
||||
},
|
||||
},
|
||||
StopReason: types.StopReasonEndTurn,
|
||||
Usage: &types.TokenUsage{
|
||||
InputTokens: aws.Int32(10),
|
||||
OutputTokens: aws.Int32(5),
|
||||
},
|
||||
}
|
||||
|
||||
resp, err := parseResponse(output)
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "Hello!", resp.Content)
|
||||
assert.Equal(t, "stop", resp.FinishReason)
|
||||
assert.Empty(t, resp.ToolCalls)
|
||||
assert.Equal(t, 10, resp.Usage.PromptTokens)
|
||||
assert.Equal(t, 5, resp.Usage.CompletionTokens)
|
||||
}
|
||||
|
||||
func TestParseResponse_StopReasons(t *testing.T) {
|
||||
tests := []struct {
|
||||
stopReason types.StopReason
|
||||
expectedFinish string
|
||||
}{
|
||||
{types.StopReasonEndTurn, "stop"},
|
||||
{types.StopReasonToolUse, "tool_calls"},
|
||||
{types.StopReasonMaxTokens, "length"},
|
||||
{types.StopReasonStopSequence, "stop"},
|
||||
{types.StopReasonContentFiltered, "content_filter"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(string(tt.stopReason), func(t *testing.T) {
|
||||
output := &bedrockruntime.ConverseOutput{
|
||||
Output: &types.ConverseOutputMemberMessage{
|
||||
Value: types.Message{
|
||||
Content: []types.ContentBlock{
|
||||
&types.ContentBlockMemberText{Value: "test"},
|
||||
},
|
||||
},
|
||||
},
|
||||
StopReason: tt.stopReason,
|
||||
}
|
||||
|
||||
resp, err := parseResponse(output)
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, tt.expectedFinish, resp.FinishReason)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseResponse_WithToolCalls(t *testing.T) {
|
||||
// Note: document.NewLazyDocument has limitations with UnmarshalSmithyDocument in tests,
|
||||
// so we test the structure extraction and verify Arguments gets populated (even if empty
|
||||
// due to SDK limitations). The actual unmarshal works correctly at runtime.
|
||||
toolInput := document.NewLazyDocument(map[string]any{
|
||||
"location": "San Francisco",
|
||||
"unit": "celsius",
|
||||
})
|
||||
|
||||
output := &bedrockruntime.ConverseOutput{
|
||||
Output: &types.ConverseOutputMemberMessage{
|
||||
Value: types.Message{
|
||||
Role: types.ConversationRoleAssistant,
|
||||
Content: []types.ContentBlock{
|
||||
&types.ContentBlockMemberText{Value: "Let me check the weather."},
|
||||
&types.ContentBlockMemberToolUse{
|
||||
Value: types.ToolUseBlock{
|
||||
ToolUseId: aws.String("call_weather_123"),
|
||||
Name: aws.String("get_weather"),
|
||||
Input: toolInput,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
StopReason: types.StopReasonToolUse,
|
||||
Usage: &types.TokenUsage{
|
||||
InputTokens: aws.Int32(20),
|
||||
OutputTokens: aws.Int32(15),
|
||||
},
|
||||
}
|
||||
|
||||
resp, err := parseResponse(output)
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "Let me check the weather.", resp.Content)
|
||||
assert.Equal(t, "tool_calls", resp.FinishReason)
|
||||
assert.Len(t, resp.ToolCalls, 1)
|
||||
|
||||
// Verify tool call ID and Name are extracted correctly
|
||||
tc := resp.ToolCalls[0]
|
||||
assert.Equal(t, "call_weather_123", tc.ID)
|
||||
assert.Equal(t, "get_weather", tc.Name)
|
||||
|
||||
// Verify Function fields are also populated
|
||||
require.NotNil(t, tc.Function)
|
||||
assert.Equal(t, "get_weather", tc.Function.Name)
|
||||
|
||||
// Verify Arguments is not nil (content may vary due to SDK limitations in tests)
|
||||
assert.NotNil(t, tc.Arguments)
|
||||
|
||||
// Verify usage
|
||||
assert.Equal(t, 20, resp.Usage.PromptTokens)
|
||||
assert.Equal(t, 15, resp.Usage.CompletionTokens)
|
||||
assert.Equal(t, 35, resp.Usage.TotalTokens)
|
||||
}
|
||||
|
||||
func TestParseResponse_MultipleToolCalls(t *testing.T) {
|
||||
output := &bedrockruntime.ConverseOutput{
|
||||
Output: &types.ConverseOutputMemberMessage{
|
||||
Value: types.Message{
|
||||
Role: types.ConversationRoleAssistant,
|
||||
Content: []types.ContentBlock{
|
||||
&types.ContentBlockMemberToolUse{
|
||||
Value: types.ToolUseBlock{
|
||||
ToolUseId: aws.String("call_1"),
|
||||
Name: aws.String("tool_a"),
|
||||
Input: document.NewLazyDocument(map[string]any{"arg": "value1"}),
|
||||
},
|
||||
},
|
||||
&types.ContentBlockMemberToolUse{
|
||||
Value: types.ToolUseBlock{
|
||||
ToolUseId: aws.String("call_2"),
|
||||
Name: aws.String("tool_b"),
|
||||
Input: document.NewLazyDocument(map[string]any{"arg": "value2"}),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
StopReason: types.StopReasonToolUse,
|
||||
}
|
||||
|
||||
resp, err := parseResponse(output)
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "tool_calls", resp.FinishReason)
|
||||
assert.Len(t, resp.ToolCalls, 2)
|
||||
|
||||
// Verify tool call structure
|
||||
assert.Equal(t, "call_1", resp.ToolCalls[0].ID)
|
||||
assert.Equal(t, "tool_a", resp.ToolCalls[0].Name)
|
||||
assert.NotNil(t, resp.ToolCalls[0].Arguments)
|
||||
assert.NotNil(t, resp.ToolCalls[0].Function)
|
||||
assert.Equal(t, "tool_a", resp.ToolCalls[0].Function.Name)
|
||||
|
||||
assert.Equal(t, "call_2", resp.ToolCalls[1].ID)
|
||||
assert.Equal(t, "tool_b", resp.ToolCalls[1].Name)
|
||||
assert.NotNil(t, resp.ToolCalls[1].Arguments)
|
||||
assert.NotNil(t, resp.ToolCalls[1].Function)
|
||||
assert.Equal(t, "tool_b", resp.ToolCalls[1].Function.Name)
|
||||
}
|
||||
|
||||
func TestParseResponse_ToolCallWithNilInput(t *testing.T) {
|
||||
output := &bedrockruntime.ConverseOutput{
|
||||
Output: &types.ConverseOutputMemberMessage{
|
||||
Value: types.Message{
|
||||
Role: types.ConversationRoleAssistant,
|
||||
Content: []types.ContentBlock{
|
||||
&types.ContentBlockMemberToolUse{
|
||||
Value: types.ToolUseBlock{
|
||||
ToolUseId: aws.String("call_nil"),
|
||||
Name: aws.String("no_args_tool"),
|
||||
Input: nil,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
StopReason: types.StopReasonToolUse,
|
||||
}
|
||||
|
||||
resp, err := parseResponse(output)
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, resp.ToolCalls, 1)
|
||||
assert.Equal(t, "call_nil", resp.ToolCalls[0].ID)
|
||||
assert.Equal(t, "no_args_tool", resp.ToolCalls[0].Name)
|
||||
// Arguments should be empty map, not nil
|
||||
assert.NotNil(t, resp.ToolCalls[0].Arguments)
|
||||
assert.Empty(t, resp.ToolCalls[0].Arguments)
|
||||
}
|
||||
@@ -0,0 +1,73 @@
|
||||
//go:build !bedrock
|
||||
|
||||
// PicoClaw - Ultra-lightweight personal AI agent
|
||||
// License: MIT
|
||||
//
|
||||
// Copyright (c) 2026 PicoClaw contributors
|
||||
|
||||
// Package bedrock provides a stub implementation when built without the bedrock tag.
|
||||
// To enable AWS Bedrock support, build with: go build -tags bedrock
|
||||
package bedrock
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/providers/protocoltypes"
|
||||
)
|
||||
|
||||
type (
|
||||
LLMResponse = protocoltypes.LLMResponse
|
||||
Message = protocoltypes.Message
|
||||
ToolDefinition = protocoltypes.ToolDefinition
|
||||
)
|
||||
|
||||
// Provider is a stub that returns an error when Bedrock support is not compiled in.
|
||||
type Provider struct{}
|
||||
|
||||
// Option is a no-op when Bedrock is not enabled.
|
||||
type Option func(*providerConfig)
|
||||
|
||||
type providerConfig struct{}
|
||||
|
||||
// WithRegion is a no-op when Bedrock is not enabled.
|
||||
func WithRegion(region string) Option {
|
||||
return func(c *providerConfig) {}
|
||||
}
|
||||
|
||||
// WithProfile is a no-op when Bedrock is not enabled.
|
||||
func WithProfile(profile string) Option {
|
||||
return func(c *providerConfig) {}
|
||||
}
|
||||
|
||||
// WithBaseEndpoint is a no-op when Bedrock is not enabled.
|
||||
func WithBaseEndpoint(endpoint string) Option {
|
||||
return func(c *providerConfig) {}
|
||||
}
|
||||
|
||||
// WithRequestTimeout is a no-op when Bedrock is not enabled.
|
||||
func WithRequestTimeout(timeout time.Duration) Option {
|
||||
return func(c *providerConfig) {}
|
||||
}
|
||||
|
||||
// NewProvider returns an error indicating Bedrock support is not compiled in.
|
||||
func NewProvider(ctx context.Context, opts ...Option) (*Provider, error) {
|
||||
return nil, fmt.Errorf("bedrock provider not available: build with -tags bedrock to enable AWS Bedrock support")
|
||||
}
|
||||
|
||||
// Chat returns an error - this should never be called since NewProvider fails.
|
||||
func (p *Provider) Chat(
|
||||
ctx context.Context,
|
||||
messages []Message,
|
||||
tools []ToolDefinition,
|
||||
model string,
|
||||
options map[string]any,
|
||||
) (*LLMResponse, error) {
|
||||
return nil, fmt.Errorf("bedrock provider not available: build with -tags bedrock to enable AWS Bedrock support")
|
||||
}
|
||||
|
||||
// GetDefaultModel returns an empty string.
|
||||
func (p *Provider) GetDefaultModel() string {
|
||||
return ""
|
||||
}
|
||||
@@ -0,0 +1,35 @@
|
||||
//go:build !bedrock
|
||||
|
||||
// PicoClaw - Ultra-lightweight personal AI agent
|
||||
// License: MIT
|
||||
//
|
||||
// Copyright (c) 2026 PicoClaw contributors
|
||||
|
||||
package bedrock
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestNewProvider_ReturnsStubError(t *testing.T) {
|
||||
provider, err := NewProvider(context.Background())
|
||||
|
||||
assert.Nil(t, provider)
|
||||
require.Error(t, err)
|
||||
assert.True(t, strings.Contains(err.Error(), "build with -tags bedrock"),
|
||||
"error should mention build tag requirement, got: %s", err.Error())
|
||||
}
|
||||
|
||||
func TestNewProvider_WithOptions_ReturnsStubError(t *testing.T) {
|
||||
provider, err := NewProvider(context.Background(), WithRegion("us-west-2"), WithProfile("test"))
|
||||
|
||||
assert.Nil(t, provider)
|
||||
require.Error(t, err)
|
||||
assert.True(t, strings.Contains(err.Error(), "build with -tags bedrock"),
|
||||
"error should mention build tag requirement, got: %s", err.Error())
|
||||
}
|
||||
@@ -413,10 +413,10 @@ func TestChat_EmptyWorkspaceDoesNotSetDir(t *testing.T) {
|
||||
|
||||
func TestCreateProvider_ClaudeCli(t *testing.T) {
|
||||
cfg := config.DefaultConfig()
|
||||
cfg.ModelList = []config.ModelConfig{
|
||||
cfg.ModelList = []*config.ModelConfig{
|
||||
{ModelName: "claude-sonnet-4.6", Model: "claude-cli/claude-sonnet-4.6", Workspace: "/test/ws"},
|
||||
}
|
||||
cfg.Agents.Defaults.Model = "claude-sonnet-4.6"
|
||||
cfg.Agents.Defaults.ModelName = "claude-sonnet-4.6"
|
||||
|
||||
provider, _, err := CreateProvider(cfg)
|
||||
if err != nil {
|
||||
@@ -434,10 +434,10 @@ func TestCreateProvider_ClaudeCli(t *testing.T) {
|
||||
|
||||
func TestCreateProvider_ClaudeCode(t *testing.T) {
|
||||
cfg := config.DefaultConfig()
|
||||
cfg.ModelList = []config.ModelConfig{
|
||||
cfg.ModelList = []*config.ModelConfig{
|
||||
{ModelName: "claude-code", Model: "claude-cli/claude-code"},
|
||||
}
|
||||
cfg.Agents.Defaults.Model = "claude-code"
|
||||
cfg.Agents.Defaults.ModelName = "claude-code"
|
||||
|
||||
provider, _, err := CreateProvider(cfg)
|
||||
if err != nil {
|
||||
@@ -450,10 +450,10 @@ func TestCreateProvider_ClaudeCode(t *testing.T) {
|
||||
|
||||
func TestCreateProvider_ClaudeCodec(t *testing.T) {
|
||||
cfg := config.DefaultConfig()
|
||||
cfg.ModelList = []config.ModelConfig{
|
||||
cfg.ModelList = []*config.ModelConfig{
|
||||
{ModelName: "claudecode", Model: "claude-cli/claudecode"},
|
||||
}
|
||||
cfg.Agents.Defaults.Model = "claudecode"
|
||||
cfg.Agents.Defaults.ModelName = "claudecode"
|
||||
|
||||
provider, _, err := CreateProvider(cfg)
|
||||
if err != nil {
|
||||
@@ -466,10 +466,10 @@ func TestCreateProvider_ClaudeCodec(t *testing.T) {
|
||||
|
||||
func TestCreateProvider_ClaudeCliDefaultWorkspace(t *testing.T) {
|
||||
cfg := config.DefaultConfig()
|
||||
cfg.ModelList = []config.ModelConfig{
|
||||
cfg.ModelList = []*config.ModelConfig{
|
||||
{ModelName: "claude-cli", Model: "claude-cli/claude-sonnet"},
|
||||
}
|
||||
cfg.Agents.Defaults.Model = "claude-cli"
|
||||
cfg.Agents.Defaults.ModelName = "claude-cli"
|
||||
cfg.Agents.Defaults.Workspace = ""
|
||||
|
||||
provider, _, err := CreateProvider(cfg)
|
||||
|
||||
@@ -111,6 +111,17 @@ func SerializeMessages(messages []Message) []any {
|
||||
"url": mediaURL,
|
||||
},
|
||||
})
|
||||
continue
|
||||
}
|
||||
|
||||
if format, data, ok := parseDataAudioURL(mediaURL); ok {
|
||||
parts = append(parts, map[string]any{
|
||||
"type": "input_audio",
|
||||
"input_audio": map[string]any{
|
||||
"data": data,
|
||||
"format": format,
|
||||
},
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -132,6 +143,26 @@ func SerializeMessages(messages []Message) []any {
|
||||
return out
|
||||
}
|
||||
|
||||
func parseDataAudioURL(mediaURL string) (format, data string, ok bool) {
|
||||
if !strings.HasPrefix(mediaURL, "data:audio/") {
|
||||
return "", "", false
|
||||
}
|
||||
|
||||
payload := strings.TrimPrefix(mediaURL, "data:audio/")
|
||||
meta, data, found := strings.Cut(payload, ",")
|
||||
if !found {
|
||||
return "", "", false
|
||||
}
|
||||
|
||||
format, _, _ = strings.Cut(meta, ";")
|
||||
format = strings.TrimSpace(format)
|
||||
data = strings.TrimSpace(data)
|
||||
if format == "" || data == "" {
|
||||
return "", "", false
|
||||
}
|
||||
return format, data, true
|
||||
}
|
||||
|
||||
// --- Response parsing ---
|
||||
|
||||
// ParseResponse parses a JSON chat completion response body into an LLMResponse.
|
||||
@@ -214,11 +245,20 @@ func ParseResponse(body io.Reader) (*LLMResponse, error) {
|
||||
Reasoning: choice.Message.Reasoning,
|
||||
ReasoningDetails: choice.Message.ReasoningDetails,
|
||||
ToolCalls: toolCalls,
|
||||
FinishReason: choice.FinishReason,
|
||||
FinishReason: normalizeFinishReason(choice.FinishReason),
|
||||
Usage: apiResponse.Usage,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// normalizeFinishReason normalizes finish_reason values across providers.
|
||||
// Converts "length" to "truncated" for consistent handling.
|
||||
func normalizeFinishReason(reason string) string {
|
||||
if reason == "length" {
|
||||
return "truncated"
|
||||
}
|
||||
return reason
|
||||
}
|
||||
|
||||
// DecodeToolCallArguments decodes a tool call's arguments from raw JSON.
|
||||
func DecodeToolCallArguments(raw json.RawMessage, name string) map[string]any {
|
||||
arguments := make(map[string]any)
|
||||
|
||||
@@ -91,6 +91,44 @@ func TestSerializeMessages_WithMedia(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestSerializeMessages_WithAudioMedia(t *testing.T) {
|
||||
messages := []Message{
|
||||
{Role: "user", Content: "transcribe this", Media: []string{"data:audio/ogg;base64,abc123"}},
|
||||
}
|
||||
result := SerializeMessages(messages)
|
||||
|
||||
data, _ := json.Marshal(result)
|
||||
var msgs []map[string]any
|
||||
json.Unmarshal(data, &msgs)
|
||||
|
||||
content, ok := msgs[0]["content"].([]any)
|
||||
if !ok {
|
||||
t.Fatalf("expected array content for media message, got %T", msgs[0]["content"])
|
||||
}
|
||||
if len(content) != 2 {
|
||||
t.Fatalf("expected 2 content parts, got %d", len(content))
|
||||
}
|
||||
|
||||
audioPart, ok := content[1].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatalf("expected audio content part to be an object, got %T", content[1])
|
||||
}
|
||||
if audioPart["type"] != "input_audio" {
|
||||
t.Fatalf("audio part type = %v, want input_audio", audioPart["type"])
|
||||
}
|
||||
|
||||
inputAudio, ok := audioPart["input_audio"].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatalf("expected input_audio object, got %T", audioPart["input_audio"])
|
||||
}
|
||||
if inputAudio["format"] != "ogg" {
|
||||
t.Fatalf("audio format = %v, want ogg", inputAudio["format"])
|
||||
}
|
||||
if inputAudio["data"] != "abc123" {
|
||||
t.Fatalf("audio data = %v, want abc123", inputAudio["data"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestSerializeMessages_MediaWithToolCallID(t *testing.T) {
|
||||
messages := []Message{
|
||||
{Role: "tool", Content: "result", Media: []string{"data:image/png;base64,xyz"}, ToolCallID: "call_1"},
|
||||
|
||||
@@ -1,400 +1,7 @@
|
||||
package providers
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/auth"
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
)
|
||||
|
||||
const defaultAnthropicAPIBase = "https://api.anthropic.com/v1"
|
||||
|
||||
var getCredential = auth.GetCredential
|
||||
|
||||
type providerType int
|
||||
|
||||
const (
|
||||
providerTypeHTTPCompat providerType = iota
|
||||
providerTypeClaudeAuth
|
||||
providerTypeCodexAuth
|
||||
providerTypeCodexCLIToken
|
||||
providerTypeClaudeCLI
|
||||
providerTypeCodexCLI
|
||||
providerTypeGitHubCopilot
|
||||
)
|
||||
|
||||
type providerSelection struct {
|
||||
providerType providerType
|
||||
apiKey string
|
||||
apiBase string
|
||||
proxy string
|
||||
model string
|
||||
workspace string
|
||||
connectMode string
|
||||
enableWebSearch bool
|
||||
}
|
||||
|
||||
func resolveProviderSelection(cfg *config.Config) (providerSelection, error) {
|
||||
model := cfg.Agents.Defaults.GetModelName()
|
||||
providerName := strings.ToLower(cfg.Agents.Defaults.Provider)
|
||||
lowerModel := strings.ToLower(model)
|
||||
|
||||
if providerName == "" && model == "" {
|
||||
return providerSelection{}, fmt.Errorf("no model configured: agents.defaults.model is empty")
|
||||
}
|
||||
|
||||
sel := providerSelection{
|
||||
providerType: providerTypeHTTPCompat,
|
||||
model: model,
|
||||
}
|
||||
|
||||
// First, prefer explicit provider configuration.
|
||||
if providerName != "" {
|
||||
switch providerName {
|
||||
case "groq":
|
||||
if cfg.Providers.Groq.APIKey != "" {
|
||||
sel.apiKey = cfg.Providers.Groq.APIKey
|
||||
sel.apiBase = cfg.Providers.Groq.APIBase
|
||||
sel.proxy = cfg.Providers.Groq.Proxy
|
||||
if sel.apiBase == "" {
|
||||
sel.apiBase = "https://api.groq.com/openai/v1"
|
||||
}
|
||||
}
|
||||
case "openai", "gpt":
|
||||
if cfg.Providers.OpenAI.APIKey != "" || cfg.Providers.OpenAI.AuthMethod != "" {
|
||||
sel.enableWebSearch = cfg.Providers.OpenAI.WebSearch
|
||||
if cfg.Providers.OpenAI.AuthMethod == "codex-cli" {
|
||||
sel.providerType = providerTypeCodexCLIToken
|
||||
return sel, nil
|
||||
}
|
||||
if cfg.Providers.OpenAI.AuthMethod == "oauth" || cfg.Providers.OpenAI.AuthMethod == "token" {
|
||||
sel.providerType = providerTypeCodexAuth
|
||||
return sel, nil
|
||||
}
|
||||
sel.apiKey = cfg.Providers.OpenAI.APIKey
|
||||
sel.apiBase = cfg.Providers.OpenAI.APIBase
|
||||
sel.proxy = cfg.Providers.OpenAI.Proxy
|
||||
if sel.apiBase == "" {
|
||||
sel.apiBase = "https://api.openai.com/v1"
|
||||
}
|
||||
}
|
||||
case "anthropic", "claude":
|
||||
if cfg.Providers.Anthropic.APIKey != "" || cfg.Providers.Anthropic.AuthMethod != "" {
|
||||
if cfg.Providers.Anthropic.AuthMethod == "oauth" || cfg.Providers.Anthropic.AuthMethod == "token" {
|
||||
sel.apiBase = cfg.Providers.Anthropic.APIBase
|
||||
if sel.apiBase == "" {
|
||||
sel.apiBase = defaultAnthropicAPIBase
|
||||
}
|
||||
sel.providerType = providerTypeClaudeAuth
|
||||
return sel, nil
|
||||
}
|
||||
sel.apiKey = cfg.Providers.Anthropic.APIKey
|
||||
sel.apiBase = cfg.Providers.Anthropic.APIBase
|
||||
sel.proxy = cfg.Providers.Anthropic.Proxy
|
||||
if sel.apiBase == "" {
|
||||
sel.apiBase = defaultAnthropicAPIBase
|
||||
}
|
||||
}
|
||||
case "openrouter":
|
||||
if cfg.Providers.OpenRouter.APIKey != "" {
|
||||
sel.apiKey = cfg.Providers.OpenRouter.APIKey
|
||||
sel.proxy = cfg.Providers.OpenRouter.Proxy
|
||||
if cfg.Providers.OpenRouter.APIBase != "" {
|
||||
sel.apiBase = cfg.Providers.OpenRouter.APIBase
|
||||
} else {
|
||||
sel.apiBase = "https://openrouter.ai/api/v1"
|
||||
}
|
||||
}
|
||||
case "litellm":
|
||||
if cfg.Providers.LiteLLM.APIKey != "" || cfg.Providers.LiteLLM.APIBase != "" {
|
||||
sel.apiKey = cfg.Providers.LiteLLM.APIKey
|
||||
sel.apiBase = cfg.Providers.LiteLLM.APIBase
|
||||
sel.proxy = cfg.Providers.LiteLLM.Proxy
|
||||
if sel.apiBase == "" {
|
||||
sel.apiBase = "http://localhost:4000/v1"
|
||||
}
|
||||
}
|
||||
case "zhipu", "glm":
|
||||
if cfg.Providers.Zhipu.APIKey != "" {
|
||||
sel.apiKey = cfg.Providers.Zhipu.APIKey
|
||||
sel.apiBase = cfg.Providers.Zhipu.APIBase
|
||||
sel.proxy = cfg.Providers.Zhipu.Proxy
|
||||
if sel.apiBase == "" {
|
||||
sel.apiBase = "https://open.bigmodel.cn/api/paas/v4"
|
||||
}
|
||||
}
|
||||
case "gemini", "google":
|
||||
if cfg.Providers.Gemini.APIKey != "" {
|
||||
sel.apiKey = cfg.Providers.Gemini.APIKey
|
||||
sel.apiBase = cfg.Providers.Gemini.APIBase
|
||||
sel.proxy = cfg.Providers.Gemini.Proxy
|
||||
if sel.apiBase == "" {
|
||||
sel.apiBase = "https://generativelanguage.googleapis.com/v1beta"
|
||||
}
|
||||
}
|
||||
case "vllm":
|
||||
if cfg.Providers.VLLM.APIBase != "" {
|
||||
sel.apiKey = cfg.Providers.VLLM.APIKey
|
||||
sel.apiBase = cfg.Providers.VLLM.APIBase
|
||||
sel.proxy = cfg.Providers.VLLM.Proxy
|
||||
}
|
||||
case "shengsuanyun":
|
||||
if cfg.Providers.ShengSuanYun.APIKey != "" {
|
||||
sel.apiKey = cfg.Providers.ShengSuanYun.APIKey
|
||||
sel.apiBase = cfg.Providers.ShengSuanYun.APIBase
|
||||
sel.proxy = cfg.Providers.ShengSuanYun.Proxy
|
||||
if sel.apiBase == "" {
|
||||
sel.apiBase = "https://router.shengsuanyun.com/api/v1"
|
||||
}
|
||||
}
|
||||
case "nvidia":
|
||||
if cfg.Providers.Nvidia.APIKey != "" {
|
||||
sel.apiKey = cfg.Providers.Nvidia.APIKey
|
||||
sel.apiBase = cfg.Providers.Nvidia.APIBase
|
||||
sel.proxy = cfg.Providers.Nvidia.Proxy
|
||||
if sel.apiBase == "" {
|
||||
sel.apiBase = "https://integrate.api.nvidia.com/v1"
|
||||
}
|
||||
}
|
||||
case "vivgrid":
|
||||
if cfg.Providers.Vivgrid.APIKey != "" {
|
||||
sel.apiKey = cfg.Providers.Vivgrid.APIKey
|
||||
sel.apiBase = cfg.Providers.Vivgrid.APIBase
|
||||
sel.proxy = cfg.Providers.Vivgrid.Proxy
|
||||
if sel.apiBase == "" {
|
||||
sel.apiBase = "https://api.vivgrid.com/v1"
|
||||
}
|
||||
}
|
||||
case "claude-cli", "claude-code", "claudecode":
|
||||
workspace := cfg.WorkspacePath()
|
||||
if workspace == "" {
|
||||
workspace = "."
|
||||
}
|
||||
sel.providerType = providerTypeClaudeCLI
|
||||
sel.workspace = workspace
|
||||
return sel, nil
|
||||
case "codex-cli", "codex-code":
|
||||
workspace := cfg.WorkspacePath()
|
||||
if workspace == "" {
|
||||
workspace = "."
|
||||
}
|
||||
sel.providerType = providerTypeCodexCLI
|
||||
sel.workspace = workspace
|
||||
return sel, nil
|
||||
case "deepseek":
|
||||
if cfg.Providers.DeepSeek.APIKey != "" {
|
||||
sel.apiKey = cfg.Providers.DeepSeek.APIKey
|
||||
sel.apiBase = cfg.Providers.DeepSeek.APIBase
|
||||
sel.proxy = cfg.Providers.DeepSeek.Proxy
|
||||
if sel.apiBase == "" {
|
||||
sel.apiBase = "https://api.deepseek.com/v1"
|
||||
}
|
||||
if model != "deepseek-chat" && model != "deepseek-reasoner" {
|
||||
sel.model = "deepseek-chat"
|
||||
}
|
||||
}
|
||||
case "avian":
|
||||
if cfg.Providers.Avian.APIKey != "" {
|
||||
sel.apiKey = cfg.Providers.Avian.APIKey
|
||||
sel.apiBase = cfg.Providers.Avian.APIBase
|
||||
sel.proxy = cfg.Providers.Avian.Proxy
|
||||
if sel.apiBase == "" {
|
||||
sel.apiBase = "https://api.avian.io/v1"
|
||||
}
|
||||
}
|
||||
case "mistral":
|
||||
if cfg.Providers.Mistral.APIKey != "" {
|
||||
sel.apiKey = cfg.Providers.Mistral.APIKey
|
||||
sel.apiBase = cfg.Providers.Mistral.APIBase
|
||||
sel.proxy = cfg.Providers.Mistral.Proxy
|
||||
if sel.apiBase == "" {
|
||||
sel.apiBase = "https://api.mistral.ai/v1"
|
||||
}
|
||||
}
|
||||
case "minimax":
|
||||
if cfg.Providers.Minimax.APIKey != "" {
|
||||
sel.apiKey = cfg.Providers.Minimax.APIKey
|
||||
sel.apiBase = cfg.Providers.Minimax.APIBase
|
||||
sel.proxy = cfg.Providers.Minimax.Proxy
|
||||
if sel.apiBase == "" {
|
||||
sel.apiBase = "https://api.minimaxi.com/v1"
|
||||
}
|
||||
}
|
||||
case "longcat":
|
||||
if cfg.Providers.LongCat.APIKey != "" {
|
||||
sel.apiKey = cfg.Providers.LongCat.APIKey
|
||||
sel.apiBase = cfg.Providers.LongCat.APIBase
|
||||
sel.proxy = cfg.Providers.LongCat.Proxy
|
||||
if sel.apiBase == "" {
|
||||
sel.apiBase = "https://api.longcat.chat/openai"
|
||||
}
|
||||
}
|
||||
case "github_copilot", "copilot":
|
||||
sel.providerType = providerTypeGitHubCopilot
|
||||
if cfg.Providers.GitHubCopilot.APIBase != "" {
|
||||
sel.apiBase = cfg.Providers.GitHubCopilot.APIBase
|
||||
} else {
|
||||
sel.apiBase = "localhost:4321"
|
||||
}
|
||||
sel.connectMode = cfg.Providers.GitHubCopilot.ConnectMode
|
||||
return sel, nil
|
||||
}
|
||||
}
|
||||
|
||||
// Fallback: infer provider from model and configured keys.
|
||||
if sel.apiKey == "" && sel.apiBase == "" {
|
||||
switch {
|
||||
case (strings.Contains(lowerModel, "kimi") || strings.Contains(lowerModel, "moonshot") || strings.HasPrefix(model, "moonshot/")) && cfg.Providers.Moonshot.APIKey != "":
|
||||
sel.apiKey = cfg.Providers.Moonshot.APIKey
|
||||
sel.apiBase = cfg.Providers.Moonshot.APIBase
|
||||
sel.proxy = cfg.Providers.Moonshot.Proxy
|
||||
if sel.apiBase == "" {
|
||||
sel.apiBase = "https://api.moonshot.cn/v1"
|
||||
}
|
||||
case strings.HasPrefix(model, "openrouter/") ||
|
||||
strings.HasPrefix(model, "anthropic/") ||
|
||||
strings.HasPrefix(model, "openai/") ||
|
||||
strings.HasPrefix(model, "meta-llama/") ||
|
||||
strings.HasPrefix(model, "deepseek/") ||
|
||||
strings.HasPrefix(model, "google/"):
|
||||
sel.apiKey = cfg.Providers.OpenRouter.APIKey
|
||||
sel.proxy = cfg.Providers.OpenRouter.Proxy
|
||||
if cfg.Providers.OpenRouter.APIBase != "" {
|
||||
sel.apiBase = cfg.Providers.OpenRouter.APIBase
|
||||
} else {
|
||||
sel.apiBase = "https://openrouter.ai/api/v1"
|
||||
}
|
||||
case (strings.Contains(lowerModel, "claude") || strings.HasPrefix(model, "anthropic/")) &&
|
||||
(cfg.Providers.Anthropic.APIKey != "" || cfg.Providers.Anthropic.AuthMethod != ""):
|
||||
if cfg.Providers.Anthropic.AuthMethod == "oauth" || cfg.Providers.Anthropic.AuthMethod == "token" {
|
||||
sel.apiBase = cfg.Providers.Anthropic.APIBase
|
||||
if sel.apiBase == "" {
|
||||
sel.apiBase = defaultAnthropicAPIBase
|
||||
}
|
||||
sel.providerType = providerTypeClaudeAuth
|
||||
return sel, nil
|
||||
}
|
||||
sel.apiKey = cfg.Providers.Anthropic.APIKey
|
||||
sel.apiBase = cfg.Providers.Anthropic.APIBase
|
||||
sel.proxy = cfg.Providers.Anthropic.Proxy
|
||||
if sel.apiBase == "" {
|
||||
sel.apiBase = defaultAnthropicAPIBase
|
||||
}
|
||||
case (strings.Contains(lowerModel, "gpt") || strings.HasPrefix(model, "openai/")) &&
|
||||
(cfg.Providers.OpenAI.APIKey != "" || cfg.Providers.OpenAI.AuthMethod != ""):
|
||||
sel.enableWebSearch = cfg.Providers.OpenAI.WebSearch
|
||||
if cfg.Providers.OpenAI.AuthMethod == "codex-cli" {
|
||||
sel.providerType = providerTypeCodexCLIToken
|
||||
return sel, nil
|
||||
}
|
||||
if cfg.Providers.OpenAI.AuthMethod == "oauth" || cfg.Providers.OpenAI.AuthMethod == "token" {
|
||||
sel.providerType = providerTypeCodexAuth
|
||||
return sel, nil
|
||||
}
|
||||
sel.apiKey = cfg.Providers.OpenAI.APIKey
|
||||
sel.apiBase = cfg.Providers.OpenAI.APIBase
|
||||
sel.proxy = cfg.Providers.OpenAI.Proxy
|
||||
if sel.apiBase == "" {
|
||||
sel.apiBase = "https://api.openai.com/v1"
|
||||
}
|
||||
case (strings.Contains(lowerModel, "gemini") || strings.HasPrefix(model, "google/")) && cfg.Providers.Gemini.APIKey != "":
|
||||
sel.apiKey = cfg.Providers.Gemini.APIKey
|
||||
sel.apiBase = cfg.Providers.Gemini.APIBase
|
||||
sel.proxy = cfg.Providers.Gemini.Proxy
|
||||
if sel.apiBase == "" {
|
||||
sel.apiBase = "https://generativelanguage.googleapis.com/v1beta"
|
||||
}
|
||||
case (strings.Contains(lowerModel, "glm") || strings.Contains(lowerModel, "zhipu") || strings.Contains(lowerModel, "zai")) && cfg.Providers.Zhipu.APIKey != "":
|
||||
sel.apiKey = cfg.Providers.Zhipu.APIKey
|
||||
sel.apiBase = cfg.Providers.Zhipu.APIBase
|
||||
sel.proxy = cfg.Providers.Zhipu.Proxy
|
||||
if sel.apiBase == "" {
|
||||
sel.apiBase = "https://open.bigmodel.cn/api/paas/v4"
|
||||
}
|
||||
case (strings.Contains(lowerModel, "groq") || strings.HasPrefix(model, "groq/")) && cfg.Providers.Groq.APIKey != "":
|
||||
sel.apiKey = cfg.Providers.Groq.APIKey
|
||||
sel.apiBase = cfg.Providers.Groq.APIBase
|
||||
sel.proxy = cfg.Providers.Groq.Proxy
|
||||
if sel.apiBase == "" {
|
||||
sel.apiBase = "https://api.groq.com/openai/v1"
|
||||
}
|
||||
case (strings.Contains(lowerModel, "nvidia") || strings.HasPrefix(model, "nvidia/")) && cfg.Providers.Nvidia.APIKey != "":
|
||||
sel.apiKey = cfg.Providers.Nvidia.APIKey
|
||||
sel.apiBase = cfg.Providers.Nvidia.APIBase
|
||||
sel.proxy = cfg.Providers.Nvidia.Proxy
|
||||
if sel.apiBase == "" {
|
||||
sel.apiBase = "https://integrate.api.nvidia.com/v1"
|
||||
}
|
||||
case strings.HasPrefix(model, "vivgrid/") && cfg.Providers.Vivgrid.APIKey != "":
|
||||
sel.apiKey = cfg.Providers.Vivgrid.APIKey
|
||||
sel.apiBase = cfg.Providers.Vivgrid.APIBase
|
||||
sel.proxy = cfg.Providers.Vivgrid.Proxy
|
||||
if sel.apiBase == "" {
|
||||
sel.apiBase = "https://api.vivgrid.com/v1"
|
||||
}
|
||||
case (strings.Contains(lowerModel, "ollama") || strings.HasPrefix(model, "ollama/")) && cfg.Providers.Ollama.APIKey != "":
|
||||
sel.apiKey = cfg.Providers.Ollama.APIKey
|
||||
sel.apiBase = cfg.Providers.Ollama.APIBase
|
||||
sel.proxy = cfg.Providers.Ollama.Proxy
|
||||
if sel.apiBase == "" {
|
||||
sel.apiBase = "http://localhost:11434/v1"
|
||||
}
|
||||
case (strings.Contains(lowerModel, "mistral") || strings.HasPrefix(model, "mistral/")) && cfg.Providers.Mistral.APIKey != "":
|
||||
sel.apiKey = cfg.Providers.Mistral.APIKey
|
||||
sel.apiBase = cfg.Providers.Mistral.APIBase
|
||||
sel.proxy = cfg.Providers.Mistral.Proxy
|
||||
if sel.apiBase == "" {
|
||||
sel.apiBase = "https://api.mistral.ai/v1"
|
||||
}
|
||||
case (strings.Contains(lowerModel, "minimax") || strings.HasPrefix(model, "minimax/")) && cfg.Providers.Minimax.APIKey != "":
|
||||
sel.apiKey = cfg.Providers.Minimax.APIKey
|
||||
sel.apiBase = cfg.Providers.Minimax.APIBase
|
||||
sel.proxy = cfg.Providers.Minimax.Proxy
|
||||
if sel.apiBase == "" {
|
||||
sel.apiBase = "https://api.minimaxi.com/v1"
|
||||
}
|
||||
case strings.HasPrefix(model, "avian/") && cfg.Providers.Avian.APIKey != "":
|
||||
sel.apiKey = cfg.Providers.Avian.APIKey
|
||||
sel.apiBase = cfg.Providers.Avian.APIBase
|
||||
sel.proxy = cfg.Providers.Avian.Proxy
|
||||
if sel.apiBase == "" {
|
||||
sel.apiBase = "https://api.avian.io/v1"
|
||||
}
|
||||
case (strings.Contains(lowerModel, "longcat") || strings.HasPrefix(model, "longcat/")) && cfg.Providers.LongCat.APIKey != "":
|
||||
sel.apiKey = cfg.Providers.LongCat.APIKey
|
||||
sel.apiBase = cfg.Providers.LongCat.APIBase
|
||||
sel.proxy = cfg.Providers.LongCat.Proxy
|
||||
if sel.apiBase == "" {
|
||||
sel.apiBase = "https://api.longcat.chat/openai"
|
||||
}
|
||||
case cfg.Providers.VLLM.APIBase != "":
|
||||
sel.apiKey = cfg.Providers.VLLM.APIKey
|
||||
sel.apiBase = cfg.Providers.VLLM.APIBase
|
||||
sel.proxy = cfg.Providers.VLLM.Proxy
|
||||
default:
|
||||
if cfg.Providers.OpenRouter.APIKey != "" {
|
||||
sel.apiKey = cfg.Providers.OpenRouter.APIKey
|
||||
sel.proxy = cfg.Providers.OpenRouter.Proxy
|
||||
if cfg.Providers.OpenRouter.APIBase != "" {
|
||||
sel.apiBase = cfg.Providers.OpenRouter.APIBase
|
||||
} else {
|
||||
sel.apiBase = "https://openrouter.ai/api/v1"
|
||||
}
|
||||
} else {
|
||||
return providerSelection{}, fmt.Errorf("no API key configured for model: %s", model)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if sel.providerType == providerTypeHTTPCompat {
|
||||
if sel.apiKey == "" && !strings.HasPrefix(model, "bedrock/") {
|
||||
return providerSelection{}, fmt.Errorf("no API key configured for provider (model: %s)", model)
|
||||
}
|
||||
if sel.apiBase == "" {
|
||||
return providerSelection{}, fmt.Errorf("no API base configured for provider (model: %s)", model)
|
||||
}
|
||||
}
|
||||
|
||||
return sel, nil
|
||||
}
|
||||
|
||||
@@ -6,12 +6,15 @@
|
||||
package providers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
anthropicmessages "github.com/sipeed/picoclaw/pkg/providers/anthropic_messages"
|
||||
"github.com/sipeed/picoclaw/pkg/providers/azure"
|
||||
"github.com/sipeed/picoclaw/pkg/providers/bedrock"
|
||||
)
|
||||
|
||||
// createClaudeAuthProvider creates a Claude provider using OAuth credentials from auth store.
|
||||
@@ -55,8 +58,9 @@ func ExtractProtocol(model string) (protocol, modelID string) {
|
||||
|
||||
// CreateProviderFromConfig creates a provider based on the ModelConfig.
|
||||
// It uses the protocol prefix in the Model field to determine which provider to create.
|
||||
// Supported protocols: openai, litellm, novita, anthropic, anthropic-messages,
|
||||
// antigravity, claude-cli, codex-cli, github-copilot
|
||||
// Supported protocol families include OpenAI-compatible prefixes (e.g., openai, openrouter, groq, gemini),
|
||||
// Azure OpenAI, Amazon Bedrock, Anthropic (including messages), and various CLI/compatibility shims.
|
||||
// See the switch on protocol in this function for the authoritative list.
|
||||
// Returns the provider, the model ID (without protocol prefix), and any error.
|
||||
func CreateProviderFromConfig(cfg *config.ModelConfig) (LLMProvider, string, error) {
|
||||
if cfg == nil {
|
||||
@@ -80,7 +84,7 @@ func CreateProviderFromConfig(cfg *config.ModelConfig) (LLMProvider, string, err
|
||||
return provider, modelID, nil
|
||||
}
|
||||
// OpenAI with API key
|
||||
if cfg.APIKey == "" && cfg.APIBase == "" {
|
||||
if cfg.APIKey() == "" && cfg.APIBase == "" {
|
||||
return nil, "", fmt.Errorf("api_key or api_base is required for HTTP-based protocol %q", protocol)
|
||||
}
|
||||
apiBase := cfg.APIBase
|
||||
@@ -88,17 +92,18 @@ func CreateProviderFromConfig(cfg *config.ModelConfig) (LLMProvider, string, err
|
||||
apiBase = getDefaultAPIBase(protocol)
|
||||
}
|
||||
return NewHTTPProviderWithMaxTokensFieldAndRequestTimeout(
|
||||
cfg.APIKey,
|
||||
cfg.APIKey(),
|
||||
apiBase,
|
||||
cfg.Proxy,
|
||||
cfg.MaxTokensField,
|
||||
cfg.RequestTimeout,
|
||||
cfg.ExtraBody,
|
||||
), modelID, nil
|
||||
|
||||
case "azure", "azure-openai":
|
||||
// Azure OpenAI uses deployment-based URLs, api-key header auth,
|
||||
// and always sends max_completion_tokens.
|
||||
if cfg.APIKey == "" {
|
||||
if cfg.APIKey() == "" {
|
||||
return nil, "", fmt.Errorf("api_key is required for azure protocol")
|
||||
}
|
||||
if cfg.APIBase == "" {
|
||||
@@ -107,19 +112,55 @@ func CreateProviderFromConfig(cfg *config.ModelConfig) (LLMProvider, string, err
|
||||
)
|
||||
}
|
||||
return azure.NewProviderWithTimeout(
|
||||
cfg.APIKey,
|
||||
cfg.APIKey(),
|
||||
cfg.APIBase,
|
||||
cfg.Proxy,
|
||||
cfg.RequestTimeout,
|
||||
), modelID, nil
|
||||
|
||||
case "bedrock":
|
||||
// AWS Bedrock uses AWS SDK credentials (env vars, profiles, IAM roles, etc.)
|
||||
// api_base can be:
|
||||
// - A full endpoint URL: https://bedrock-runtime.us-east-1.amazonaws.com
|
||||
// - A region name: us-east-1 (AWS SDK resolves endpoint automatically)
|
||||
var opts []bedrock.Option
|
||||
if cfg.APIBase != "" {
|
||||
if !strings.Contains(cfg.APIBase, "://") {
|
||||
// Treat as region: let AWS SDK resolve the correct endpoint
|
||||
// (supports all AWS partitions: aws, aws-cn, aws-us-gov, etc.)
|
||||
opts = append(opts, bedrock.WithRegion(cfg.APIBase))
|
||||
} else {
|
||||
// Full endpoint URL provided (for custom endpoints or testing)
|
||||
opts = append(opts, bedrock.WithBaseEndpoint(cfg.APIBase))
|
||||
}
|
||||
}
|
||||
// Use a separate timeout for AWS config loading (credential resolution can block)
|
||||
initTimeout := 30 * time.Second
|
||||
if cfg.RequestTimeout > 0 {
|
||||
reqTimeout := time.Duration(cfg.RequestTimeout) * time.Second
|
||||
// Set request timeout for API calls
|
||||
opts = append(opts, bedrock.WithRequestTimeout(reqTimeout))
|
||||
// Ensure init timeout is at least as large as request timeout
|
||||
if reqTimeout > initTimeout {
|
||||
initTimeout = reqTimeout
|
||||
}
|
||||
}
|
||||
ctx, cancel := context.WithTimeout(context.Background(), initTimeout)
|
||||
defer cancel()
|
||||
// Note: AWS_PROFILE env var is automatically used by AWS SDK
|
||||
provider, err := bedrock.NewProvider(ctx, opts...)
|
||||
if err != nil {
|
||||
return nil, "", fmt.Errorf("creating bedrock provider: %w", err)
|
||||
}
|
||||
return provider, modelID, nil
|
||||
|
||||
case "litellm", "openrouter", "groq", "zhipu", "gemini", "nvidia",
|
||||
"ollama", "moonshot", "shengsuanyun", "deepseek", "cerebras",
|
||||
"vivgrid", "volcengine", "vllm", "qwen", "qwen-intl", "qwen-international", "dashscope-intl",
|
||||
"qwen-us", "dashscope-us", "mistral", "avian", "minimax", "longcat", "modelscope", "novita",
|
||||
"qwen-us", "dashscope-us", "mistral", "avian", "longcat", "modelscope", "novita",
|
||||
"coding-plan", "alibaba-coding", "qwen-coding":
|
||||
// All other OpenAI-compatible HTTP providers
|
||||
if cfg.APIKey == "" && cfg.APIBase == "" {
|
||||
if cfg.APIKey() == "" && cfg.APIBase == "" {
|
||||
return nil, "", fmt.Errorf("api_key or api_base is required for HTTP-based protocol %q", protocol)
|
||||
}
|
||||
apiBase := cfg.APIBase
|
||||
@@ -127,11 +168,37 @@ func CreateProviderFromConfig(cfg *config.ModelConfig) (LLMProvider, string, err
|
||||
apiBase = getDefaultAPIBase(protocol)
|
||||
}
|
||||
return NewHTTPProviderWithMaxTokensFieldAndRequestTimeout(
|
||||
cfg.APIKey,
|
||||
cfg.APIKey(),
|
||||
apiBase,
|
||||
cfg.Proxy,
|
||||
cfg.MaxTokensField,
|
||||
cfg.RequestTimeout,
|
||||
cfg.ExtraBody,
|
||||
), modelID, nil
|
||||
|
||||
case "minimax":
|
||||
// Minimax requires reasoning_split: true in the request body
|
||||
if cfg.APIKey() == "" && cfg.APIBase == "" {
|
||||
return nil, "", fmt.Errorf("api_key or api_base is required for HTTP-based protocol %q", protocol)
|
||||
}
|
||||
apiBase := cfg.APIBase
|
||||
if apiBase == "" {
|
||||
apiBase = getDefaultAPIBase(protocol)
|
||||
}
|
||||
extraBody := cfg.ExtraBody
|
||||
if extraBody == nil {
|
||||
extraBody = make(map[string]any)
|
||||
}
|
||||
if _, ok := extraBody["reasoning_split"]; !ok {
|
||||
extraBody["reasoning_split"] = true
|
||||
}
|
||||
return NewHTTPProviderWithMaxTokensFieldAndRequestTimeout(
|
||||
cfg.APIKey(),
|
||||
apiBase,
|
||||
cfg.Proxy,
|
||||
cfg.MaxTokensField,
|
||||
cfg.RequestTimeout,
|
||||
extraBody,
|
||||
), modelID, nil
|
||||
|
||||
case "anthropic":
|
||||
@@ -148,15 +215,16 @@ func CreateProviderFromConfig(cfg *config.ModelConfig) (LLMProvider, string, err
|
||||
if apiBase == "" {
|
||||
apiBase = "https://api.anthropic.com/v1"
|
||||
}
|
||||
if cfg.APIKey == "" {
|
||||
if cfg.APIKey() == "" {
|
||||
return nil, "", fmt.Errorf("api_key is required for anthropic protocol (model: %s)", cfg.Model)
|
||||
}
|
||||
return NewHTTPProviderWithMaxTokensFieldAndRequestTimeout(
|
||||
cfg.APIKey,
|
||||
cfg.APIKey(),
|
||||
apiBase,
|
||||
cfg.Proxy,
|
||||
cfg.MaxTokensField,
|
||||
cfg.RequestTimeout,
|
||||
cfg.ExtraBody,
|
||||
), modelID, nil
|
||||
|
||||
case "anthropic-messages":
|
||||
@@ -165,11 +233,11 @@ func CreateProviderFromConfig(cfg *config.ModelConfig) (LLMProvider, string, err
|
||||
if apiBase == "" {
|
||||
apiBase = "https://api.anthropic.com/v1"
|
||||
}
|
||||
if cfg.APIKey == "" {
|
||||
if cfg.APIKey() == "" {
|
||||
return nil, "", fmt.Errorf("api_key is required for anthropic-messages protocol (model: %s)", cfg.Model)
|
||||
}
|
||||
return anthropicmessages.NewProviderWithTimeout(
|
||||
cfg.APIKey,
|
||||
cfg.APIKey(),
|
||||
apiBase,
|
||||
cfg.RequestTimeout,
|
||||
), modelID, nil
|
||||
@@ -180,11 +248,11 @@ func CreateProviderFromConfig(cfg *config.ModelConfig) (LLMProvider, string, err
|
||||
if apiBase == "" {
|
||||
apiBase = getDefaultAPIBase(protocol)
|
||||
}
|
||||
if cfg.APIKey == "" {
|
||||
if cfg.APIKey() == "" {
|
||||
return nil, "", fmt.Errorf("api_key is required for %q protocol (model: %s)", protocol, cfg.Model)
|
||||
}
|
||||
return anthropicmessages.NewProviderWithTimeout(
|
||||
cfg.APIKey,
|
||||
cfg.APIKey(),
|
||||
apiBase,
|
||||
cfg.RequestTimeout,
|
||||
), modelID, nil
|
||||
|
||||
@@ -6,6 +6,7 @@
|
||||
package providers
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
@@ -89,9 +90,9 @@ func TestCreateProviderFromConfig_OpenAI(t *testing.T) {
|
||||
cfg := &config.ModelConfig{
|
||||
ModelName: "test-openai",
|
||||
Model: "openai/gpt-4o",
|
||||
APIKey: "test-key",
|
||||
APIBase: "https://api.example.com/v1",
|
||||
}
|
||||
cfg.SetAPIKey("test-key")
|
||||
|
||||
provider, modelID, err := CreateProviderFromConfig(cfg)
|
||||
if err != nil {
|
||||
@@ -129,8 +130,8 @@ func TestCreateProviderFromConfig_DefaultAPIBase(t *testing.T) {
|
||||
cfg := &config.ModelConfig{
|
||||
ModelName: "test-" + tt.protocol,
|
||||
Model: tt.protocol + "/test-model",
|
||||
APIKey: "test-key",
|
||||
}
|
||||
cfg.SetAPIKey("test-key")
|
||||
|
||||
provider, _, err := CreateProviderFromConfig(cfg)
|
||||
if err != nil {
|
||||
@@ -155,9 +156,9 @@ func TestCreateProviderFromConfig_LiteLLM(t *testing.T) {
|
||||
cfg := &config.ModelConfig{
|
||||
ModelName: "test-litellm",
|
||||
Model: "litellm/my-proxy-alias",
|
||||
APIKey: "test-key",
|
||||
APIBase: "http://localhost:4000/v1",
|
||||
}
|
||||
cfg.SetAPIKey("test-key")
|
||||
|
||||
provider, modelID, err := CreateProviderFromConfig(cfg)
|
||||
if err != nil {
|
||||
@@ -175,9 +176,9 @@ func TestCreateProviderFromConfig_LongCat(t *testing.T) {
|
||||
cfg := &config.ModelConfig{
|
||||
ModelName: "test-longcat",
|
||||
Model: "longcat/LongCat-Flash-Thinking",
|
||||
APIKey: "test-key",
|
||||
APIBase: "https://api.longcat.chat/openai",
|
||||
}
|
||||
cfg.SetAPIKey("test-key")
|
||||
|
||||
provider, modelID, err := CreateProviderFromConfig(cfg)
|
||||
if err != nil {
|
||||
@@ -198,9 +199,9 @@ func TestCreateProviderFromConfig_ModelScope(t *testing.T) {
|
||||
cfg := &config.ModelConfig{
|
||||
ModelName: "test-modelscope",
|
||||
Model: "modelscope/Qwen/Qwen3-235B-A22B-Instruct-2507",
|
||||
APIKey: "test-key",
|
||||
APIBase: "https://api-inference.modelscope.cn/v1",
|
||||
}
|
||||
cfg.SetAPIKey("test-key")
|
||||
|
||||
provider, modelID, err := CreateProviderFromConfig(cfg)
|
||||
if err != nil {
|
||||
@@ -227,8 +228,8 @@ func TestCreateProviderFromConfig_Novita(t *testing.T) {
|
||||
cfg := &config.ModelConfig{
|
||||
ModelName: "test-novita",
|
||||
Model: "novita/deepseek/deepseek-v3.2",
|
||||
APIKey: "test-key",
|
||||
}
|
||||
cfg.SetAPIKey("test-key")
|
||||
|
||||
provider, modelID, err := CreateProviderFromConfig(cfg)
|
||||
if err != nil {
|
||||
@@ -255,8 +256,8 @@ func TestCreateProviderFromConfig_Anthropic(t *testing.T) {
|
||||
cfg := &config.ModelConfig{
|
||||
ModelName: "test-anthropic",
|
||||
Model: "anthropic/claude-sonnet-4.6",
|
||||
APIKey: "test-key",
|
||||
}
|
||||
cfg.SetAPIKey("test-key")
|
||||
|
||||
provider, modelID, err := CreateProviderFromConfig(cfg)
|
||||
if err != nil {
|
||||
@@ -340,8 +341,8 @@ func TestCreateProviderFromConfig_UnknownProtocol(t *testing.T) {
|
||||
cfg := &config.ModelConfig{
|
||||
ModelName: "test-unknown",
|
||||
Model: "unknown-protocol/model",
|
||||
APIKey: "test-key",
|
||||
}
|
||||
cfg.SetAPIKey("test-key")
|
||||
|
||||
_, _, err := CreateProviderFromConfig(cfg)
|
||||
if err == nil {
|
||||
@@ -382,6 +383,7 @@ func TestCreateProviderFromConfig_RequestTimeoutPropagation(t *testing.T) {
|
||||
APIBase: server.URL,
|
||||
RequestTimeout: 1,
|
||||
}
|
||||
cfg.SetAPIKey("test-key")
|
||||
|
||||
provider, modelID, err := CreateProviderFromConfig(cfg)
|
||||
if err != nil {
|
||||
@@ -411,9 +413,9 @@ func TestCreateProviderFromConfig_Azure(t *testing.T) {
|
||||
cfg := &config.ModelConfig{
|
||||
ModelName: "azure-gpt5",
|
||||
Model: "azure/my-gpt5-deployment",
|
||||
APIKey: "test-azure-key",
|
||||
APIBase: "https://my-resource.openai.azure.com",
|
||||
}
|
||||
cfg.SetAPIKey("test-azure-key")
|
||||
|
||||
provider, modelID, err := CreateProviderFromConfig(cfg)
|
||||
if err != nil {
|
||||
@@ -431,9 +433,9 @@ func TestCreateProviderFromConfig_AzureOpenAIAlias(t *testing.T) {
|
||||
cfg := &config.ModelConfig{
|
||||
ModelName: "azure-gpt4",
|
||||
Model: "azure-openai/my-deployment",
|
||||
APIKey: "test-azure-key",
|
||||
APIBase: "https://my-resource.openai.azure.com",
|
||||
}
|
||||
cfg.SetAPIKey("test-azure-key")
|
||||
|
||||
provider, modelID, err := CreateProviderFromConfig(cfg)
|
||||
if err != nil {
|
||||
@@ -464,8 +466,8 @@ func TestCreateProviderFromConfig_AzureMissingAPIBase(t *testing.T) {
|
||||
cfg := &config.ModelConfig{
|
||||
ModelName: "azure-gpt5",
|
||||
Model: "azure/my-gpt5-deployment",
|
||||
APIKey: "test-azure-key",
|
||||
}
|
||||
cfg.SetAPIKey("test-azure-key")
|
||||
|
||||
_, _, err := CreateProviderFromConfig(cfg)
|
||||
if err == nil {
|
||||
@@ -488,8 +490,8 @@ func TestCreateProviderFromConfig_QwenInternationalAlias(t *testing.T) {
|
||||
cfg := &config.ModelConfig{
|
||||
ModelName: "test-" + tt.protocol,
|
||||
Model: tt.protocol + "/qwen-max",
|
||||
APIKey: "test-key",
|
||||
}
|
||||
cfg.SetAPIKey("test-key")
|
||||
|
||||
provider, modelID, err := CreateProviderFromConfig(cfg)
|
||||
if err != nil {
|
||||
@@ -522,8 +524,8 @@ func TestCreateProviderFromConfig_QwenUSAlias(t *testing.T) {
|
||||
cfg := &config.ModelConfig{
|
||||
ModelName: "test-" + tt.protocol,
|
||||
Model: tt.protocol + "/qwen-max",
|
||||
APIKey: "test-key",
|
||||
}
|
||||
cfg.SetAPIKey("test-key")
|
||||
|
||||
provider, modelID, err := CreateProviderFromConfig(cfg)
|
||||
if err != nil {
|
||||
@@ -556,8 +558,8 @@ func TestCreateProviderFromConfig_CodingPlanAnthropic(t *testing.T) {
|
||||
cfg := &config.ModelConfig{
|
||||
ModelName: "test-" + tt.protocol,
|
||||
Model: tt.protocol + "/claude-sonnet-4-20250514",
|
||||
APIKey: "test-key",
|
||||
}
|
||||
cfg.SetAPIKey("test-key")
|
||||
|
||||
provider, modelID, err := CreateProviderFromConfig(cfg)
|
||||
if err != nil {
|
||||
@@ -603,3 +605,173 @@ func TestGetDefaultAPIBase_QwenUSAliases(t *testing.T) {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateProviderFromConfig_MinimaxInjectsReasoningSplit(t *testing.T) {
|
||||
var requestBody map[string]any
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if err := json.NewDecoder(r.Body).Decode(&requestBody); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = w.Write([]byte(`{"choices":[{"message":{"content":"ok"},"finish_reason":"stop"}]}`))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
cfg := &config.ModelConfig{
|
||||
ModelName: "test-minimax",
|
||||
Model: "minimax/MiniMax-M2.5",
|
||||
APIBase: server.URL,
|
||||
}
|
||||
cfg.SetAPIKey("test-key")
|
||||
|
||||
provider, modelID, err := CreateProviderFromConfig(cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("CreateProviderFromConfig() error = %v", err)
|
||||
}
|
||||
if provider == nil {
|
||||
t.Fatal("CreateProviderFromConfig() returned nil provider")
|
||||
}
|
||||
if modelID != "MiniMax-M2.5" {
|
||||
t.Errorf("modelID = %q, want %q", modelID, "MiniMax-M2.5")
|
||||
}
|
||||
|
||||
_, err = provider.Chat(
|
||||
t.Context(),
|
||||
[]Message{{Role: "user", Content: "hi"}},
|
||||
nil,
|
||||
modelID,
|
||||
nil,
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("Chat() error = %v", err)
|
||||
}
|
||||
|
||||
// Verify reasoning_split is automatically injected
|
||||
if got, ok := requestBody["reasoning_split"]; !ok || got != true {
|
||||
t.Fatalf("reasoning_split = %v, want true", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateProviderFromConfig_MinimaxPreservesUserExtraBody(t *testing.T) {
|
||||
var requestBody map[string]any
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if err := json.NewDecoder(r.Body).Decode(&requestBody); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = w.Write([]byte(`{"choices":[{"message":{"content":"ok"},"finish_reason":"stop"}]}`))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
cfg := &config.ModelConfig{
|
||||
ModelName: "test-minimax-custom",
|
||||
Model: "minimax/MiniMax-M2.5",
|
||||
APIBase: server.URL,
|
||||
ExtraBody: map[string]any{"custom_field": "test"},
|
||||
}
|
||||
cfg.SetAPIKey("test-key")
|
||||
|
||||
provider, modelID, err := CreateProviderFromConfig(cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("CreateProviderFromConfig() error = %v", err)
|
||||
}
|
||||
|
||||
_, err = provider.Chat(
|
||||
t.Context(),
|
||||
[]Message{{Role: "user", Content: "hi"}},
|
||||
nil,
|
||||
modelID,
|
||||
nil,
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("Chat() error = %v", err)
|
||||
}
|
||||
|
||||
// Verify reasoning_split is automatically injected
|
||||
if got, ok := requestBody["reasoning_split"]; !ok || got != true {
|
||||
t.Fatalf("reasoning_split = %v, want true", got)
|
||||
}
|
||||
// Verify user's custom field is preserved
|
||||
if got, ok := requestBody["custom_field"]; !ok || got != "test" {
|
||||
t.Fatalf("custom_field = %v, want test", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateProviderFromConfig_Bedrock(t *testing.T) {
|
||||
// Set dummy AWS env vars to make test deterministic
|
||||
t.Setenv("AWS_ACCESS_KEY_ID", "test-key")
|
||||
t.Setenv("AWS_SECRET_ACCESS_KEY", "test-secret")
|
||||
t.Setenv("AWS_EC2_METADATA_DISABLED", "true")
|
||||
// Clear profile-related env vars to avoid loading shared config
|
||||
t.Setenv("AWS_PROFILE", "")
|
||||
t.Setenv("AWS_DEFAULT_PROFILE", "")
|
||||
t.Setenv("AWS_SDK_LOAD_CONFIG", "")
|
||||
t.Setenv("AWS_SHARED_CREDENTIALS_FILE", "")
|
||||
|
||||
cfg := &config.ModelConfig{
|
||||
ModelName: "bedrock-claude",
|
||||
Model: "bedrock/us.anthropic.claude-sonnet-4-20250514-v1:0",
|
||||
APIBase: "us-west-2", // Region (also sets AWS region)
|
||||
}
|
||||
|
||||
provider, modelID, err := CreateProviderFromConfig(cfg)
|
||||
if err == nil {
|
||||
// Provider created successfully (built with -tags bedrock)
|
||||
if provider == nil {
|
||||
t.Error("provider is nil on success")
|
||||
}
|
||||
if modelID != "us.anthropic.claude-sonnet-4-20250514-v1:0" {
|
||||
t.Errorf("modelID = %q, want %q", modelID, "us.anthropic.claude-sonnet-4-20250514-v1:0")
|
||||
}
|
||||
return
|
||||
}
|
||||
errMsg := err.Error()
|
||||
// When built without -tags bedrock, expect stub error
|
||||
if strings.Contains(errMsg, "build with -tags bedrock") {
|
||||
return // Expected stub error
|
||||
}
|
||||
// Unexpected error - fail the test
|
||||
t.Errorf("unexpected error from bedrock provider: %v", err)
|
||||
}
|
||||
|
||||
func TestCreateProviderFromConfig_BedrockWithEndpointURL(t *testing.T) {
|
||||
// Set dummy AWS env vars to make test deterministic
|
||||
t.Setenv("AWS_ACCESS_KEY_ID", "test-key")
|
||||
t.Setenv("AWS_SECRET_ACCESS_KEY", "test-secret")
|
||||
t.Setenv("AWS_REGION", "us-east-1") // Required when using endpoint URL
|
||||
t.Setenv("AWS_EC2_METADATA_DISABLED", "true")
|
||||
// Clear profile-related env vars to avoid loading shared config
|
||||
t.Setenv("AWS_PROFILE", "")
|
||||
t.Setenv("AWS_DEFAULT_PROFILE", "")
|
||||
t.Setenv("AWS_SDK_LOAD_CONFIG", "")
|
||||
t.Setenv("AWS_SHARED_CREDENTIALS_FILE", "")
|
||||
|
||||
cfg := &config.ModelConfig{
|
||||
ModelName: "bedrock-claude",
|
||||
Model: "bedrock/us.anthropic.claude-sonnet-4-20250514-v1:0",
|
||||
APIBase: "https://bedrock-runtime.us-east-1.amazonaws.com", // Full endpoint URL
|
||||
}
|
||||
|
||||
provider, modelID, err := CreateProviderFromConfig(cfg)
|
||||
if err == nil {
|
||||
// Provider created successfully (built with -tags bedrock)
|
||||
if provider == nil {
|
||||
t.Error("provider is nil on success")
|
||||
}
|
||||
if modelID != "us.anthropic.claude-sonnet-4-20250514-v1:0" {
|
||||
t.Errorf("modelID = %q, want %q", modelID, "us.anthropic.claude-sonnet-4-20250514-v1:0")
|
||||
}
|
||||
return
|
||||
}
|
||||
errMsg := err.Error()
|
||||
// When built without -tags bedrock, expect stub error
|
||||
if strings.Contains(errMsg, "build with -tags bedrock") {
|
||||
return // Expected stub error
|
||||
}
|
||||
// Unexpected error - fail the test
|
||||
t.Errorf("unexpected error from bedrock provider: %v", err)
|
||||
}
|
||||
|
||||
+13
-253
@@ -1,262 +1,22 @@
|
||||
package providers
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/auth"
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
)
|
||||
|
||||
func TestResolveProviderSelection(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
setup func(*config.Config)
|
||||
wantType providerType
|
||||
wantAPIBase string
|
||||
wantProxy string
|
||||
wantErrSubstr string
|
||||
}{
|
||||
{
|
||||
name: "explicit litellm provider uses configured base",
|
||||
setup: func(cfg *config.Config) {
|
||||
cfg.Agents.Defaults.Provider = "litellm"
|
||||
cfg.Providers.LiteLLM.APIKey = "litellm-key"
|
||||
cfg.Providers.LiteLLM.APIBase = "http://localhost:4000/v1"
|
||||
cfg.Providers.LiteLLM.Proxy = "http://127.0.0.1:7890"
|
||||
},
|
||||
wantType: providerTypeHTTPCompat,
|
||||
wantAPIBase: "http://localhost:4000/v1",
|
||||
wantProxy: "http://127.0.0.1:7890",
|
||||
},
|
||||
{
|
||||
name: "explicit litellm provider defaults base when only key is configured",
|
||||
setup: func(cfg *config.Config) {
|
||||
cfg.Agents.Defaults.Provider = "litellm"
|
||||
cfg.Providers.LiteLLM.APIKey = "litellm-key"
|
||||
},
|
||||
wantType: providerTypeHTTPCompat,
|
||||
wantAPIBase: "http://localhost:4000/v1",
|
||||
},
|
||||
{
|
||||
name: "explicit claude-cli provider routes to cli provider type",
|
||||
setup: func(cfg *config.Config) {
|
||||
cfg.Agents.Defaults.Provider = "claude-cli"
|
||||
cfg.Agents.Defaults.Workspace = "/tmp/ws"
|
||||
},
|
||||
wantType: providerTypeClaudeCLI,
|
||||
},
|
||||
{
|
||||
name: "explicit copilot provider routes to github copilot type",
|
||||
setup: func(cfg *config.Config) {
|
||||
cfg.Agents.Defaults.Provider = "copilot"
|
||||
},
|
||||
wantType: providerTypeGitHubCopilot,
|
||||
wantAPIBase: "localhost:4321",
|
||||
},
|
||||
{
|
||||
name: "explicit deepseek provider uses deepseek defaults",
|
||||
setup: func(cfg *config.Config) {
|
||||
cfg.Agents.Defaults.Provider = "deepseek"
|
||||
cfg.Agents.Defaults.Model = "deepseek/deepseek-chat"
|
||||
cfg.Providers.DeepSeek.APIKey = "deepseek-key"
|
||||
cfg.Providers.DeepSeek.Proxy = "http://127.0.0.1:7890"
|
||||
},
|
||||
wantType: providerTypeHTTPCompat,
|
||||
wantAPIBase: "https://api.deepseek.com/v1",
|
||||
wantProxy: "http://127.0.0.1:7890",
|
||||
},
|
||||
{
|
||||
name: "explicit shengsuanyun provider uses defaults",
|
||||
setup: func(cfg *config.Config) {
|
||||
cfg.Agents.Defaults.Provider = "shengsuanyun"
|
||||
cfg.Providers.ShengSuanYun.APIKey = "ssy-key"
|
||||
cfg.Providers.ShengSuanYun.Proxy = "http://127.0.0.1:7890"
|
||||
},
|
||||
wantType: providerTypeHTTPCompat,
|
||||
wantAPIBase: "https://router.shengsuanyun.com/api/v1",
|
||||
wantProxy: "http://127.0.0.1:7890",
|
||||
},
|
||||
{
|
||||
name: "explicit nvidia provider uses defaults",
|
||||
setup: func(cfg *config.Config) {
|
||||
cfg.Agents.Defaults.Provider = "nvidia"
|
||||
cfg.Providers.Nvidia.APIKey = "nvapi-test"
|
||||
cfg.Providers.Nvidia.Proxy = "http://127.0.0.1:7890"
|
||||
},
|
||||
wantType: providerTypeHTTPCompat,
|
||||
wantAPIBase: "https://integrate.api.nvidia.com/v1",
|
||||
wantProxy: "http://127.0.0.1:7890",
|
||||
},
|
||||
{
|
||||
name: "explicit vivgrid provider uses defaults",
|
||||
setup: func(cfg *config.Config) {
|
||||
cfg.Agents.Defaults.Provider = "vivgrid"
|
||||
cfg.Providers.Vivgrid.APIKey = "vivgrid-key"
|
||||
cfg.Providers.Vivgrid.Proxy = "http://127.0.0.1:7890"
|
||||
},
|
||||
wantType: providerTypeHTTPCompat,
|
||||
wantAPIBase: "https://api.vivgrid.com/v1",
|
||||
wantProxy: "http://127.0.0.1:7890",
|
||||
},
|
||||
{
|
||||
name: "openrouter model uses openrouter defaults",
|
||||
setup: func(cfg *config.Config) {
|
||||
cfg.Agents.Defaults.Model = "openrouter/auto"
|
||||
cfg.Providers.OpenRouter.APIKey = "sk-or-test"
|
||||
},
|
||||
wantType: providerTypeHTTPCompat,
|
||||
wantAPIBase: "https://openrouter.ai/api/v1",
|
||||
},
|
||||
{
|
||||
name: "anthropic oauth routes to claude auth provider",
|
||||
setup: func(cfg *config.Config) {
|
||||
cfg.Agents.Defaults.Model = "claude-sonnet-4.6"
|
||||
cfg.Providers.Anthropic.AuthMethod = "oauth"
|
||||
},
|
||||
wantType: providerTypeClaudeAuth,
|
||||
},
|
||||
{
|
||||
name: "openai oauth routes to codex auth provider",
|
||||
setup: func(cfg *config.Config) {
|
||||
cfg.Agents.Defaults.Model = "gpt-4o"
|
||||
cfg.Providers.OpenAI.AuthMethod = "oauth"
|
||||
},
|
||||
wantType: providerTypeCodexAuth,
|
||||
},
|
||||
{
|
||||
name: "openai codex-cli auth routes to codex cli token provider",
|
||||
setup: func(cfg *config.Config) {
|
||||
cfg.Agents.Defaults.Model = "gpt-4o"
|
||||
cfg.Providers.OpenAI.AuthMethod = "codex-cli"
|
||||
},
|
||||
wantType: providerTypeCodexCLIToken,
|
||||
},
|
||||
{
|
||||
name: "explicit codex-code provider routes to codex cli provider type",
|
||||
setup: func(cfg *config.Config) {
|
||||
cfg.Agents.Defaults.Provider = "codex-code"
|
||||
cfg.Agents.Defaults.Workspace = "/tmp/ws"
|
||||
},
|
||||
wantType: providerTypeCodexCLI,
|
||||
},
|
||||
{
|
||||
name: "zhipu model uses zhipu base default",
|
||||
setup: func(cfg *config.Config) {
|
||||
cfg.Agents.Defaults.Model = "glm-4.7"
|
||||
cfg.Providers.Zhipu.APIKey = "zhipu-key"
|
||||
},
|
||||
wantType: providerTypeHTTPCompat,
|
||||
wantAPIBase: "https://open.bigmodel.cn/api/paas/v4",
|
||||
},
|
||||
{
|
||||
name: "groq model uses groq base default",
|
||||
setup: func(cfg *config.Config) {
|
||||
cfg.Agents.Defaults.Model = "groq/llama-3.3-70b"
|
||||
cfg.Providers.Groq.APIKey = "gsk-key"
|
||||
},
|
||||
wantType: providerTypeHTTPCompat,
|
||||
wantAPIBase: "https://api.groq.com/openai/v1",
|
||||
},
|
||||
{
|
||||
name: "ollama model uses ollama base default",
|
||||
setup: func(cfg *config.Config) {
|
||||
cfg.Agents.Defaults.Model = "ollama/qwen2.5:14b"
|
||||
cfg.Providers.Ollama.APIKey = "ollama-key"
|
||||
},
|
||||
wantType: providerTypeHTTPCompat,
|
||||
wantAPIBase: "http://localhost:11434/v1",
|
||||
},
|
||||
{
|
||||
name: "moonshot model keeps proxy and default base",
|
||||
setup: func(cfg *config.Config) {
|
||||
cfg.Agents.Defaults.Model = "moonshot/kimi-k2.5"
|
||||
cfg.Providers.Moonshot.APIKey = "moonshot-key"
|
||||
cfg.Providers.Moonshot.Proxy = "http://127.0.0.1:7890"
|
||||
},
|
||||
wantType: providerTypeHTTPCompat,
|
||||
wantAPIBase: "https://api.moonshot.cn/v1",
|
||||
wantProxy: "http://127.0.0.1:7890",
|
||||
},
|
||||
{
|
||||
name: "explicit longcat provider uses defaults",
|
||||
setup: func(cfg *config.Config) {
|
||||
cfg.Agents.Defaults.Provider = "longcat"
|
||||
cfg.Providers.LongCat.APIKey = "longcat-key"
|
||||
cfg.Providers.LongCat.Proxy = "http://127.0.0.1:7890"
|
||||
},
|
||||
wantType: providerTypeHTTPCompat,
|
||||
wantAPIBase: "https://api.longcat.chat/openai",
|
||||
wantProxy: "http://127.0.0.1:7890",
|
||||
},
|
||||
{
|
||||
name: "longcat model fallback uses longcat base default",
|
||||
setup: func(cfg *config.Config) {
|
||||
cfg.Agents.Defaults.Model = "longcat/LongCat-Flash-Thinking"
|
||||
cfg.Providers.LongCat.APIKey = "longcat-key"
|
||||
},
|
||||
wantType: providerTypeHTTPCompat,
|
||||
wantAPIBase: "https://api.longcat.chat/openai",
|
||||
},
|
||||
{
|
||||
name: "missing keys returns model config error",
|
||||
setup: func(cfg *config.Config) {
|
||||
cfg.Agents.Defaults.Model = "custom-model"
|
||||
},
|
||||
wantErrSubstr: "no API key configured for model",
|
||||
},
|
||||
{
|
||||
name: "openrouter prefix without key returns provider key error",
|
||||
setup: func(cfg *config.Config) {
|
||||
cfg.Agents.Defaults.Model = "openrouter/auto"
|
||||
},
|
||||
wantErrSubstr: "no API key configured for provider",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
cfg := config.DefaultConfig()
|
||||
tt.setup(cfg)
|
||||
|
||||
got, err := resolveProviderSelection(cfg)
|
||||
if tt.wantErrSubstr != "" {
|
||||
if err == nil {
|
||||
t.Fatalf("expected error containing %q, got nil", tt.wantErrSubstr)
|
||||
}
|
||||
if !strings.Contains(err.Error(), tt.wantErrSubstr) {
|
||||
t.Fatalf("error = %q, want substring %q", err.Error(), tt.wantErrSubstr)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("resolveProviderSelection() error = %v", err)
|
||||
}
|
||||
if got.providerType != tt.wantType {
|
||||
t.Fatalf("providerType = %v, want %v", got.providerType, tt.wantType)
|
||||
}
|
||||
if tt.wantAPIBase != "" && got.apiBase != tt.wantAPIBase {
|
||||
t.Fatalf("apiBase = %q, want %q", got.apiBase, tt.wantAPIBase)
|
||||
}
|
||||
if tt.wantProxy != "" && got.proxy != tt.wantProxy {
|
||||
t.Fatalf("proxy = %q, want %q", got.proxy, tt.wantProxy)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateProviderReturnsHTTPProviderForOpenRouter(t *testing.T) {
|
||||
cfg := config.DefaultConfig()
|
||||
cfg.Agents.Defaults.Model = "test-openrouter"
|
||||
cfg.ModelList = []config.ModelConfig{
|
||||
{
|
||||
ModelName: "test-openrouter",
|
||||
Model: "openrouter/auto",
|
||||
APIKey: "sk-or-test",
|
||||
APIBase: "https://openrouter.ai/api/v1",
|
||||
},
|
||||
cfg.Agents.Defaults.ModelName = "test-openrouter"
|
||||
modelCfg := &config.ModelConfig{
|
||||
ModelName: "test-openrouter",
|
||||
Model: "openrouter/auto",
|
||||
APIBase: "https://openrouter.ai/api/v1",
|
||||
}
|
||||
modelCfg.SetAPIKey("sk-or-test")
|
||||
cfg.ModelList = []*config.ModelConfig{modelCfg}
|
||||
|
||||
provider, _, err := CreateProvider(cfg)
|
||||
if err != nil {
|
||||
@@ -270,8 +30,8 @@ func TestCreateProviderReturnsHTTPProviderForOpenRouter(t *testing.T) {
|
||||
|
||||
func TestCreateProviderReturnsCodexCliProviderForCodexCode(t *testing.T) {
|
||||
cfg := config.DefaultConfig()
|
||||
cfg.Agents.Defaults.Model = "test-codex"
|
||||
cfg.ModelList = []config.ModelConfig{
|
||||
cfg.Agents.Defaults.ModelName = "test-codex"
|
||||
cfg.ModelList = []*config.ModelConfig{
|
||||
{
|
||||
ModelName: "test-codex",
|
||||
Model: "codex-cli/codex-model",
|
||||
@@ -291,8 +51,8 @@ func TestCreateProviderReturnsCodexCliProviderForCodexCode(t *testing.T) {
|
||||
|
||||
func TestCreateProviderReturnsClaudeCliProviderForClaudeCli(t *testing.T) {
|
||||
cfg := config.DefaultConfig()
|
||||
cfg.Agents.Defaults.Model = "test-claude-cli"
|
||||
cfg.ModelList = []config.ModelConfig{
|
||||
cfg.Agents.Defaults.ModelName = "test-claude-cli"
|
||||
cfg.ModelList = []*config.ModelConfig{
|
||||
{
|
||||
ModelName: "test-claude-cli",
|
||||
Model: "claude-cli/claude-sonnet",
|
||||
@@ -324,8 +84,8 @@ func TestCreateProviderReturnsClaudeProviderForAnthropicOAuth(t *testing.T) {
|
||||
}
|
||||
|
||||
cfg := config.DefaultConfig()
|
||||
cfg.Agents.Defaults.Model = "test-claude-oauth"
|
||||
cfg.ModelList = []config.ModelConfig{
|
||||
cfg.Agents.Defaults.ModelName = "test-claude-oauth"
|
||||
cfg.ModelList = []*config.ModelConfig{
|
||||
{
|
||||
ModelName: "test-claude-oauth",
|
||||
Model: "anthropic/claude-sonnet-4.6",
|
||||
|
||||
@@ -24,12 +24,13 @@ func NewHTTPProvider(apiKey, apiBase, proxy string) *HTTPProvider {
|
||||
}
|
||||
|
||||
func NewHTTPProviderWithMaxTokensField(apiKey, apiBase, proxy, maxTokensField string) *HTTPProvider {
|
||||
return NewHTTPProviderWithMaxTokensFieldAndRequestTimeout(apiKey, apiBase, proxy, maxTokensField, 0)
|
||||
return NewHTTPProviderWithMaxTokensFieldAndRequestTimeout(apiKey, apiBase, proxy, maxTokensField, 0, nil)
|
||||
}
|
||||
|
||||
func NewHTTPProviderWithMaxTokensFieldAndRequestTimeout(
|
||||
apiKey, apiBase, proxy, maxTokensField string,
|
||||
requestTimeoutSeconds int,
|
||||
extraBody map[string]any,
|
||||
) *HTTPProvider {
|
||||
return &HTTPProvider{
|
||||
delegate: openai_compat.NewProvider(
|
||||
@@ -38,6 +39,7 @@ func NewHTTPProviderWithMaxTokensFieldAndRequestTimeout(
|
||||
proxy,
|
||||
openai_compat.WithMaxTokensField(maxTokensField),
|
||||
openai_compat.WithRequestTimeout(time.Duration(requestTimeoutSeconds)*time.Second),
|
||||
openai_compat.WithExtraBody(extraBody),
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -18,23 +18,6 @@ import (
|
||||
func CreateProvider(cfg *config.Config) (LLMProvider, string, error) {
|
||||
model := cfg.Agents.Defaults.GetModelName()
|
||||
|
||||
// Ensure model_list is populated from providers config if needed
|
||||
// This handles two cases:
|
||||
// 1. ModelList is empty - convert all providers
|
||||
// 2. ModelList has some entries but not all providers - merge missing ones
|
||||
if cfg.HasProvidersConfig() {
|
||||
providerModels := config.ConvertProvidersToModelList(cfg)
|
||||
existingModelNames := make(map[string]bool)
|
||||
for _, m := range cfg.ModelList {
|
||||
existingModelNames[m.ModelName] = true
|
||||
}
|
||||
for _, pm := range providerModels {
|
||||
if !existingModelNames[pm.ModelName] {
|
||||
cfg.ModelList = append(cfg.ModelList, pm)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Must have model_list at this point
|
||||
if len(cfg.ModelList) == 0 {
|
||||
return nil, "", fmt.Errorf("no providers configured. Please add entries to model_list in your config")
|
||||
|
||||
@@ -35,6 +35,7 @@ type Provider struct {
|
||||
apiBase string
|
||||
maxTokensField string // Field name for max tokens (e.g., "max_completion_tokens" for o1/glm models)
|
||||
httpClient *http.Client
|
||||
extraBody map[string]any // Additional fields to inject into request body
|
||||
}
|
||||
|
||||
type Option func(*Provider)
|
||||
@@ -55,6 +56,12 @@ func WithRequestTimeout(timeout time.Duration) Option {
|
||||
}
|
||||
}
|
||||
|
||||
func WithExtraBody(extraBody map[string]any) Option {
|
||||
return func(p *Provider) {
|
||||
p.extraBody = extraBody
|
||||
}
|
||||
}
|
||||
|
||||
func NewProvider(apiKey, apiBase, proxy string, opts ...Option) *Provider {
|
||||
p := &Provider{
|
||||
apiKey: apiKey,
|
||||
@@ -140,6 +147,12 @@ func (p *Provider) buildRequestBody(
|
||||
}
|
||||
}
|
||||
|
||||
// Merge extra body fields configured per-provider/model.
|
||||
// These are injected last so they take precedence over defaults.
|
||||
for k, v := range p.extraBody {
|
||||
requestBody[k] = v
|
||||
}
|
||||
|
||||
return requestBody
|
||||
}
|
||||
|
||||
|
||||
@@ -610,6 +610,90 @@ func TestProvider_RequestTimeoutOverride(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestProviderChat_ExtraBodyInjected(t *testing.T) {
|
||||
var requestBody map[string]any
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if err := json.NewDecoder(r.Body).Decode(&requestBody); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
resp := map[string]any{
|
||||
"choices": []map[string]any{
|
||||
{
|
||||
"message": map[string]any{"content": "ok"},
|
||||
"finish_reason": "stop",
|
||||
},
|
||||
},
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(resp)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
extraBody := map[string]any{"reasoning_split": true, "custom_field": "test"}
|
||||
p := NewProvider("key", server.URL, "", WithExtraBody(extraBody))
|
||||
|
||||
_, err := p.Chat(
|
||||
t.Context(),
|
||||
[]Message{{Role: "user", Content: "hi"}},
|
||||
nil,
|
||||
"minimax/abab7",
|
||||
nil,
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("Chat() error = %v", err)
|
||||
}
|
||||
|
||||
if got, ok := requestBody["reasoning_split"]; !ok || got != true {
|
||||
t.Fatalf("reasoning_split = %v, want true", got)
|
||||
}
|
||||
if got, ok := requestBody["custom_field"]; !ok || got != "test" {
|
||||
t.Fatalf("custom_field = %v, want test", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestProviderChat_ExtraBodyOverridesOptions(t *testing.T) {
|
||||
var requestBody map[string]any
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if err := json.NewDecoder(r.Body).Decode(&requestBody); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
resp := map[string]any{
|
||||
"choices": []map[string]any{
|
||||
{
|
||||
"message": map[string]any{"content": "ok"},
|
||||
"finish_reason": "stop",
|
||||
},
|
||||
},
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(resp)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
extraBody := map[string]any{"temperature": 0.9}
|
||||
p := NewProvider("key", server.URL, "", WithExtraBody(extraBody))
|
||||
|
||||
_, err := p.Chat(
|
||||
t.Context(),
|
||||
[]Message{{Role: "user", Content: "hi"}},
|
||||
nil,
|
||||
"gpt-4o",
|
||||
map[string]any{"temperature": 0.5},
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("Chat() error = %v", err)
|
||||
}
|
||||
|
||||
// ExtraBody takes precedence over options since it is merged last.
|
||||
if got := requestBody["temperature"]; got != float64(0.9) {
|
||||
t.Fatalf("temperature = %v, want 0.9 (from extraBody, overriding options)", got)
|
||||
}
|
||||
}
|
||||
|
||||
type roundTripperFunc func(*http.Request) (*http.Response, error)
|
||||
|
||||
func (f roundTripperFunc) RoundTrip(r *http.Request) (*http.Response, error) {
|
||||
|
||||
Reference in New Issue
Block a user