feat(providers): add AWS Bedrock provider (#1903)

Add support for AWS Bedrock as an LLM provider using the Converse API.
The implementation is behind a build tag (-tags bedrock) to keep the
default binary size small.

Features:
- AWS SDK v2 with automatic credential chain (env vars, profiles, IAM roles)
- Converse API for unified access to Claude, Llama, Mistral models
- Tool/function calling support with proper document handling
- Image support with base64 decoding and size limits
- Request timeout configuration
- Region validation and endpoint resolution for all AWS partitions

Usage:
  go build -tags bedrock
  model: bedrock/us.anthropic.claude-sonnet-4-20250514-v1:0
  api_base: us-east-1  (or full endpoint URL)
This commit is contained in:
Andy Lo-A-Foe
2026-03-23 18:10:56 +01:00
committed by GitHub
parent 40571996b1
commit b787131c82
9 changed files with 1397 additions and 2 deletions
+3
View File
@@ -373,6 +373,9 @@ PicoClaw supports 30+ LLM providers through the `model_list` configuration. Use
| [Azure OpenAI](https://portal.azure.com/) | `azure/` | Required | Enterprise Azure deployment |
| [GitHub Copilot](https://github.com/features/copilot) | `github-copilot/` | OAuth | Device code login |
| [Antigravity](https://console.cloud.google.com/) | `antigravity/` | OAuth | Google Cloud AI |
| [AWS Bedrock](https://console.aws.amazon.com/bedrock)* | `bedrock/` | AWS credentials | Claude, Llama, Mistral on AWS |
> \* AWS Bedrock requires build tag: `go build -tags bedrock`. Set `api_base` to a region name (e.g., `us-east-1`) for automatic endpoint resolution across all AWS partitions (aws, aws-cn, aws-us-gov). When using a full endpoint URL instead, you must also configure `AWS_REGION` via environment variable or AWS config/profile.
<details>
<summary><b>Local deployment (Ollama, vLLM, etc.)</b></summary>
+16
View File
@@ -7,6 +7,9 @@ require (
github.com/BurntSushi/toml v1.6.0
github.com/adhocore/gronx v1.19.6
github.com/anthropics/anthropic-sdk-go v1.26.0
github.com/aws/aws-sdk-go-v2 v1.41.4
github.com/aws/aws-sdk-go-v2/config v1.32.12
github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.50.2
github.com/bwmarrin/discordgo v0.29.0
github.com/caarlos0/env/v11 v11.4.0
github.com/ergochat/irc-go v0.6.0
@@ -40,6 +43,19 @@ require (
require (
filippo.io/edwards25519 v1.2.0 // indirect
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.7 // indirect
github.com/aws/aws-sdk-go-v2/credentials v1.19.12 // indirect
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.20 // indirect
github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.20 // indirect
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.20 // indirect
github.com/aws/aws-sdk-go-v2/internal/ini v1.8.6 // indirect
github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.7 // indirect
github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.20 // indirect
github.com/aws/aws-sdk-go-v2/service/signin v1.0.8 // indirect
github.com/aws/aws-sdk-go-v2/service/sso v1.30.13 // indirect
github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.17 // indirect
github.com/aws/aws-sdk-go-v2/service/sts v1.41.9 // indirect
github.com/aws/smithy-go v1.24.2 // indirect
github.com/beeper/argo-go v1.1.2 // indirect
github.com/coder/websocket v1.8.14 // indirect
github.com/davecgh/go-spew v1.1.1 // indirect
+32
View File
@@ -17,6 +17,38 @@ github.com/andybalholm/brotli v1.2.0 h1:ukwgCxwYrmACq68yiUqwIWnGY0cTPox/M94sVwTo
github.com/andybalholm/brotli v1.2.0/go.mod h1:rzTDkvFWvIrjDXZHkuS16NPggd91W3kUSvPlQ1pLaKY=
github.com/anthropics/anthropic-sdk-go v1.26.0 h1:oUTzFaUpAevfuELAP1sjL6CQJ9HHAfT7CoSYSac11PY=
github.com/anthropics/anthropic-sdk-go v1.26.0/go.mod h1:qUKmaW+uuPB64iy1l+4kOSvaLqPXnHTTBKH6RVZ7q5Q=
github.com/aws/aws-sdk-go-v2 v1.41.4 h1:10f50G7WyU02T56ox1wWXq+zTX9I1zxG46HYuG1hH/k=
github.com/aws/aws-sdk-go-v2 v1.41.4/go.mod h1:mwsPRE8ceUUpiTgF7QmQIJ7lgsKUPQOUl3o72QBrE1o=
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.7 h1:3kGOqnh1pPeddVa/E37XNTaWJ8W6vrbYV9lJEkCnhuY=
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.7/go.mod h1:lyw7GFp3qENLh7kwzf7iMzAxDn+NzjXEAGjKS2UOKqI=
github.com/aws/aws-sdk-go-v2/config v1.32.12 h1:O3csC7HUGn2895eNrLytOJQdoL2xyJy0iYXhoZ1OmP0=
github.com/aws/aws-sdk-go-v2/config v1.32.12/go.mod h1:96zTvoOFR4FURjI+/5wY1vc1ABceROO4lWgWJuxgy0g=
github.com/aws/aws-sdk-go-v2/credentials v1.19.12 h1:oqtA6v+y5fZg//tcTWahyN9PEn5eDU/Wpvc2+kJ4aY8=
github.com/aws/aws-sdk-go-v2/credentials v1.19.12/go.mod h1:U3R1RtSHx6NB0DvEQFGyf/0sbrpJrluENHdPy1j/3TE=
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.20 h1:zOgq3uezl5nznfoK3ODuqbhVg1JzAGDUhXOsU0IDCAo=
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.20/go.mod h1:z/MVwUARehy6GAg/yQ1GO2IMl0k++cu1ohP9zo887wE=
github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.20 h1:CNXO7mvgThFGqOFgbNAP2nol2qAWBOGfqR/7tQlvLmc=
github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.20/go.mod h1:oydPDJKcfMhgfcgBUZaG+toBbwy8yPWubJXBVERtI4o=
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.20 h1:tN6W/hg+pkM+tf9XDkWUbDEjGLb+raoBMFsTodcoYKw=
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.20/go.mod h1:YJ898MhD067hSHA6xYCx5ts/jEd8BSOLtQDL3iZsvbc=
github.com/aws/aws-sdk-go-v2/internal/ini v1.8.6 h1:qYQ4pzQ2Oz6WpQ8T3HvGHnZydA72MnLuFK9tJwmrbHw=
github.com/aws/aws-sdk-go-v2/internal/ini v1.8.6/go.mod h1:O3h0IK87yXci+kg6flUKzJnWeziQUKciKrLjcatSNcY=
github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.50.2 h1:x0eGAWpd1B5I/vMtrB4Q4Zuc3CXWI8wjHfPPqBSrKmM=
github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.50.2/go.mod h1:V9oTWSDC2MtS1DR71hbNET/bZ8psQp022amEBe1grJc=
github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.7 h1:5EniKhLZe4xzL7a+fU3C2tfUN4nWIqlLesfrjkuPFTY=
github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.7/go.mod h1:x0nZssQ3qZSnIcePWLvcoFisRXJzcTVvYpAAdYX8+GI=
github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.20 h1:2HvVAIq+YqgGotK6EkMf+KIEqTISmTYh5zLpYyeTo1Y=
github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.20/go.mod h1:V4X406Y666khGa8ghKmphma/7C0DAtEQYhkq9z4vpbk=
github.com/aws/aws-sdk-go-v2/service/signin v1.0.8 h1:0GFOLzEbOyZABS3PhYfBIx2rNBACYcKty+XGkTgw1ow=
github.com/aws/aws-sdk-go-v2/service/signin v1.0.8/go.mod h1:LXypKvk85AROkKhOG6/YEcHFPoX+prKTowKnVdcaIxE=
github.com/aws/aws-sdk-go-v2/service/sso v1.30.13 h1:kiIDLZ005EcKomYYITtfsjn7dtOwHDOFy7IbPXKek2o=
github.com/aws/aws-sdk-go-v2/service/sso v1.30.13/go.mod h1:2h/xGEowcW/g38g06g3KpRWDlT+OTfxxI0o1KqayAB8=
github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.17 h1:jzKAXIlhZhJbnYwHbvUQZEB8KfgAEuG0dc08Bkda7NU=
github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.17/go.mod h1:Al9fFsXjv4KfbzQHGe6V4NZSZQXecFcvaIF4e70FoRA=
github.com/aws/aws-sdk-go-v2/service/sts v1.41.9 h1:Cng+OOwCHmFljXIxpEVXAGMnBia8MSU6Ch5i9PgBkcU=
github.com/aws/aws-sdk-go-v2/service/sts v1.41.9/go.mod h1:LrlIndBDdjA/EeXeyNBle+gyCwTlizzW5ycgWnvIxkk=
github.com/aws/smithy-go v1.24.2 h1:FzA3bu/nt/vDvmnkg+R8Xl46gmzEDam6mZ1hzmwXFng=
github.com/aws/smithy-go v1.24.2/go.mod h1:YE2RhdIuDbA5E5bTdciG9KrW3+TiEONeUWCqxX9i1Fc=
github.com/beeper/argo-go v1.1.2 h1:UQI2G8F+NLfGTOmTUI0254pGKx/HUU/etbUGTJv91Fs=
github.com/beeper/argo-go v1.1.2/go.mod h1:M+LJAnyowKVQ6Rdj6XYGEn+qcVFkb3R/MUpqkGR0hM4=
github.com/bwmarrin/discordgo v0.29.0 h1:FmWeXFaKUwrcL3Cx65c20bTRW+vOb6k8AnaP+EgjDno=
+580
View File
@@ -0,0 +1,580 @@
//go:build bedrock
// PicoClaw - Ultra-lightweight personal AI agent
// License: MIT
//
// Copyright (c) 2026 PicoClaw contributors
// Package bedrock implements the LLM provider interface for AWS Bedrock.
// It uses the Bedrock Runtime Converse API for unified access to multiple
// model families (Claude, Llama, Mistral, etc.) with tool/function calling support.
package bedrock
import (
"context"
"encoding/base64"
"encoding/json"
"fmt"
"log"
"math"
"strings"
"time"
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/config"
"github.com/aws/aws-sdk-go-v2/service/bedrockruntime"
"github.com/aws/aws-sdk-go-v2/service/bedrockruntime/document"
"github.com/aws/aws-sdk-go-v2/service/bedrockruntime/types"
"github.com/sipeed/picoclaw/pkg/providers/common"
"github.com/sipeed/picoclaw/pkg/providers/protocoltypes"
)
type (
ToolCall = protocoltypes.ToolCall
FunctionCall = protocoltypes.FunctionCall
LLMResponse = protocoltypes.LLMResponse
UsageInfo = protocoltypes.UsageInfo
Message = protocoltypes.Message
ToolDefinition = protocoltypes.ToolDefinition
ToolFunctionDefinition = protocoltypes.ToolFunctionDefinition
)
// Provider implements the LLM provider interface for AWS Bedrock.
type Provider struct {
client *bedrockruntime.Client
region string
requestTimeout time.Duration
}
// Option configures the Bedrock Provider.
type Option func(*providerConfig)
type providerConfig struct {
region string
profile string
baseEndpoint string
requestTimeout time.Duration
}
// WithRegion sets the AWS region for Bedrock requests.
func WithRegion(region string) Option {
return func(c *providerConfig) {
c.region = region
}
}
// WithProfile sets the AWS profile to use for credentials.
func WithProfile(profile string) Option {
return func(c *providerConfig) {
c.profile = profile
}
}
// WithBaseEndpoint sets a custom Bedrock endpoint URL.
// Example: https://bedrock-runtime.us-east-1.amazonaws.com
func WithBaseEndpoint(endpoint string) Option {
return func(c *providerConfig) {
c.baseEndpoint = endpoint
}
}
// WithRequestTimeout sets the timeout for Bedrock API requests.
func WithRequestTimeout(timeout time.Duration) Option {
return func(c *providerConfig) {
c.requestTimeout = timeout
}
}
// NewProvider creates a new AWS Bedrock provider.
// It uses the default AWS credential chain (env vars, shared config, IAM roles, etc.).
func NewProvider(ctx context.Context, opts ...Option) (*Provider, error) {
pc := &providerConfig{}
for _, opt := range opts {
opt(pc)
}
// Build AWS config options
var configOpts []func(*config.LoadOptions) error
if pc.region != "" {
configOpts = append(configOpts, config.WithRegion(pc.region))
}
if pc.profile != "" {
configOpts = append(configOpts, config.WithSharedConfigProfile(pc.profile))
}
// Load AWS config with automatic credential discovery
cfg, err := config.LoadDefaultConfig(ctx, configOpts...)
if err != nil {
return nil, fmt.Errorf("loading AWS config: %w", err)
}
// Validate region is set - required for Bedrock request signing
if cfg.Region == "" {
return nil, fmt.Errorf("AWS region not configured: set AWS_REGION, AWS_DEFAULT_REGION, or use WithRegion option")
}
// Build client options
var clientOpts []func(*bedrockruntime.Options)
if pc.baseEndpoint != "" {
clientOpts = append(clientOpts, func(o *bedrockruntime.Options) {
o.BaseEndpoint = aws.String(pc.baseEndpoint)
})
}
client := bedrockruntime.NewFromConfig(cfg, clientOpts...)
return &Provider{
client: client,
region: cfg.Region,
requestTimeout: pc.requestTimeout,
}, nil
}
// Chat sends messages to AWS Bedrock using the Converse API.
func (p *Provider) Chat(
ctx context.Context,
messages []Message,
tools []ToolDefinition,
model string,
options map[string]any,
) (*LLMResponse, error) {
// Apply request timeout if context doesn't already have a deadline.
// Use explicit timeout if set, otherwise fall back to common default.
effectiveTimeout := p.requestTimeout
if effectiveTimeout <= 0 {
effectiveTimeout = common.DefaultRequestTimeout
}
if _, hasDeadline := ctx.Deadline(); !hasDeadline {
var cancel context.CancelFunc
ctx, cancel = context.WithTimeout(ctx, effectiveTimeout)
defer cancel()
}
// Build the Converse API input
input := &bedrockruntime.ConverseInput{
ModelId: aws.String(model),
}
// Convert messages to Bedrock format
bedrockMessages, systemPrompts := convertMessages(messages)
input.Messages = bedrockMessages
// Set system prompts if any
if len(systemPrompts) > 0 {
input.System = systemPrompts
}
// Set inference configuration only when options are provided
var inferenceConfig *types.InferenceConfiguration
if maxTokens, ok := common.AsInt(options["max_tokens"]); ok && maxTokens > 0 {
if inferenceConfig == nil {
inferenceConfig = &types.InferenceConfiguration{}
}
// Clamp to int32 range to avoid overflow
if maxTokens > math.MaxInt32 {
maxTokens = math.MaxInt32
}
inferenceConfig.MaxTokens = aws.Int32(int32(maxTokens))
}
if temp, ok := common.AsFloat(options["temperature"]); ok {
if inferenceConfig == nil {
inferenceConfig = &types.InferenceConfiguration{}
}
inferenceConfig.Temperature = aws.Float32(float32(temp))
}
if inferenceConfig != nil {
input.InferenceConfig = inferenceConfig
}
// Convert tools to Bedrock format
// Only set ToolConfig if at least one valid tool was produced
if len(tools) > 0 {
toolConfig := convertTools(tools)
if len(toolConfig.Tools) > 0 {
input.ToolConfig = toolConfig
}
}
// Call Bedrock Converse API
output, err := p.client.Converse(ctx, input)
if err != nil {
return nil, fmt.Errorf("bedrock converse: %w", err)
}
// Parse the response
return parseResponse(output)
}
// GetDefaultModel returns an empty string as Bedrock models are user-configured.
func (p *Provider) GetDefaultModel() string {
return ""
}
// Region returns the AWS region configured for this Provider.
func (p *Provider) Region() string {
return p.region
}
// convertMessages converts internal messages to Bedrock Converse format.
// Returns the conversation messages and any system prompts separately.
// Note: Bedrock requires all tool results for a given assistant turn to be in a single
// user message with multiple ToolResultBlock content blocks. This function merges
// consecutive tool result messages accordingly.
func convertMessages(messages []Message) ([]types.Message, []types.SystemContentBlock) {
var bedrockMessages []types.Message
var systemPrompts []types.SystemContentBlock
// Helper to check if a message is a tool result
isToolResult := func(msg Message) bool {
return (msg.Role == "tool" || (msg.Role == "user" && msg.ToolCallID != "")) && msg.ToolCallID != ""
}
// Helper to create a tool result content block
makeToolResultBlock := func(msg Message) types.ContentBlock {
return &types.ContentBlockMemberToolResult{
Value: types.ToolResultBlock{
ToolUseId: aws.String(msg.ToolCallID),
Content: []types.ToolResultContentBlock{
&types.ToolResultContentBlockMemberText{
Value: msg.Content,
},
},
},
}
}
i := 0
for i < len(messages) {
msg := messages[i]
switch {
case msg.Role == "system":
// System messages go to the System field
systemPrompts = append(systemPrompts, &types.SystemContentBlockMemberText{
Value: msg.Content,
})
i++
case isToolResult(msg):
// Collect all consecutive tool results into a single user message
// Bedrock requires all tool results for a turn in one message
var toolResultBlocks []types.ContentBlock
for i < len(messages) && isToolResult(messages[i]) {
toolResultBlocks = append(toolResultBlocks, makeToolResultBlock(messages[i]))
i++
}
bedrockMessages = append(bedrockMessages, types.Message{
Role: types.ConversationRoleUser,
Content: toolResultBlocks,
})
case msg.Role == "user":
// Regular user message (no ToolCallID)
content := buildUserContent(msg)
bedrockMessages = append(bedrockMessages, types.Message{
Role: types.ConversationRoleUser,
Content: content,
})
i++
case msg.Role == "assistant":
content := buildAssistantContent(msg)
bedrockMessages = append(bedrockMessages, types.Message{
Role: types.ConversationRoleAssistant,
Content: content,
})
i++
case msg.Role == "tool" && msg.ToolCallID == "":
// Tool message without ToolCallID - treat as regular user message
content := buildUserContent(msg)
bedrockMessages = append(bedrockMessages, types.Message{
Role: types.ConversationRoleUser,
Content: content,
})
i++
default:
// Unknown role - skip
i++
}
}
return bedrockMessages, systemPrompts
}
// buildUserContent builds Bedrock content blocks for a user message.
func buildUserContent(msg Message) []types.ContentBlock {
var content []types.ContentBlock
// Add text content
if msg.Content != "" {
content = append(content, &types.ContentBlockMemberText{
Value: msg.Content,
})
}
// Add images from Media field
for _, mediaURL := range msg.Media {
if strings.HasPrefix(mediaURL, "data:image/") {
// Parse data URL: data:image/jpeg;base64,<data>
parts := strings.SplitN(mediaURL, ",", 2)
if len(parts) != 2 {
continue
}
// Extract media type from "data:image/jpeg;base64"
mediaType := ""
header := parts[0]
if idx := strings.Index(header, "/"); idx != -1 {
end := strings.Index(header[idx:], ";")
if end == -1 {
end = len(header) - idx
}
mediaType = header[idx+1 : idx+end]
}
// Verify this is base64 encoded
if !strings.Contains(header, ";base64") {
continue // Skip non-base64 encoded data
}
// Map media type to Bedrock format
var format types.ImageFormat
switch mediaType {
case "jpeg", "jpg":
format = types.ImageFormatJpeg
case "png":
format = types.ImageFormatPng
case "gif":
format = types.ImageFormatGif
case "webp":
format = types.ImageFormatWebp
default:
continue // Skip unsupported formats
}
// Check size before decoding to prevent excessive memory allocation
// Bedrock has a ~20MB request limit; cap decoded images at 10MB
const maxImageSize = 10 * 1024 * 1024
decodedLen := base64.StdEncoding.DecodedLen(len(parts[1]))
if decodedLen > maxImageSize {
log.Printf("bedrock: skipping image exceeding size limit (%d bytes > %d)", decodedLen, maxImageSize)
continue
}
// Decode base64 data
imageData, err := base64.StdEncoding.DecodeString(parts[1])
if err != nil {
log.Printf("bedrock: failed to decode base64 image data: %v", err)
continue
}
content = append(content, &types.ContentBlockMemberImage{
Value: types.ImageBlock{
Format: format,
Source: &types.ImageSourceMemberBytes{
Value: imageData,
},
},
})
}
}
// Bedrock requires at least one content block; add empty text if needed
if len(content) == 0 {
content = append(content, &types.ContentBlockMemberText{Value: ""})
}
return content
}
// buildAssistantContent builds Bedrock content blocks for an assistant message.
func buildAssistantContent(msg Message) []types.ContentBlock {
var content []types.ContentBlock
// Add text content if present
if msg.Content != "" {
content = append(content, &types.ContentBlockMemberText{
Value: msg.Content,
})
}
// Add tool use blocks
for _, tc := range msg.ToolCalls {
// Validate tool call ID - Bedrock requires non-empty ToolUseId
if strings.TrimSpace(tc.ID) == "" {
log.Printf("bedrock: skipping tool call with empty ID (name: %q)", tc.Name)
continue
}
// Resolve tool name: prefer tc.Name, fallback to tc.Function.Name
// (tc.Name/tc.Arguments are json:"-" and may be empty when from JSON)
toolName := tc.Name
if toolName == "" && tc.Function != nil {
toolName = tc.Function.Name
}
if strings.TrimSpace(toolName) == "" {
continue
}
// Resolve arguments: prefer tc.Arguments, fallback to parsing tc.Function.Arguments
args := tc.Arguments
if args == nil && tc.Function != nil && tc.Function.Arguments != "" {
if err := json.Unmarshal([]byte(tc.Function.Arguments), &args); err != nil {
log.Printf("bedrock: failed to parse Function.Arguments for tool %q: %v", toolName, err)
args = map[string]any{}
}
}
if args == nil {
args = map[string]any{}
}
// Convert arguments to a Bedrock document using NewLazyDocument
inputDoc := document.NewLazyDocument(args)
content = append(content, &types.ContentBlockMemberToolUse{
Value: types.ToolUseBlock{
ToolUseId: aws.String(tc.ID),
Name: aws.String(toolName),
Input: inputDoc,
},
})
}
// Bedrock requires at least one content block; add empty text if needed
if len(content) == 0 {
content = append(content, &types.ContentBlockMemberText{Value: ""})
}
return content
}
// convertTools converts tool definitions to Bedrock format.
func convertTools(tools []ToolDefinition) *types.ToolConfiguration {
bedrockTools := make([]types.Tool, 0, len(tools))
for _, tool := range tools {
// Skip tools with empty names
if strings.TrimSpace(tool.Function.Name) == "" {
continue
}
// Ensure parameters is not nil - default to minimal object schema
params := tool.Function.Parameters
if params == nil {
params = map[string]any{
"type": "object",
"properties": map[string]any{},
}
}
// Convert parameters schema to a Bedrock document
inputSchema := document.NewLazyDocument(params)
bedrockTools = append(bedrockTools, &types.ToolMemberToolSpec{
Value: types.ToolSpecification{
Name: aws.String(tool.Function.Name),
Description: aws.String(tool.Function.Description),
InputSchema: &types.ToolInputSchemaMemberJson{
Value: inputSchema,
},
},
})
}
return &types.ToolConfiguration{
Tools: bedrockTools,
}
}
// parseResponse converts Bedrock Converse output to LLMResponse.
func parseResponse(output *bedrockruntime.ConverseOutput) (*LLMResponse, error) {
var content strings.Builder
toolCalls := make([]ToolCall, 0)
// Process output content blocks
if output.Output != nil {
if msgOutput, ok := output.Output.(*types.ConverseOutputMemberMessage); ok {
for _, block := range msgOutput.Value.Content {
switch b := block.(type) {
case *types.ContentBlockMemberText:
content.WriteString(b.Value)
case *types.ContentBlockMemberToolUse:
// Unmarshal the document interface to a map
args := make(map[string]any)
if b.Value.Input != nil {
if err := b.Value.Input.UnmarshalSmithyDocument(&args); err != nil {
log.Printf("bedrock: failed to unmarshal tool input for tool %q (id %q): %v",
aws.ToString(b.Value.Name),
aws.ToString(b.Value.ToolUseId),
err,
)
args = make(map[string]any)
}
}
// Serialize arguments to JSON string for FunctionCall
argsJSON, err := json.Marshal(args)
if err != nil {
log.Printf("bedrock: failed to marshal tool arguments for tool %q (id %q): %v",
aws.ToString(b.Value.Name),
aws.ToString(b.Value.ToolUseId),
err,
)
argsJSON = []byte("{}")
}
toolCalls = append(toolCalls, ToolCall{
ID: aws.ToString(b.Value.ToolUseId),
Name: aws.ToString(b.Value.Name),
Arguments: args,
Function: &FunctionCall{
Name: aws.ToString(b.Value.Name),
Arguments: string(argsJSON),
},
})
}
}
}
}
// Map stop reason
finishReason := "stop"
switch output.StopReason {
case types.StopReasonToolUse:
finishReason = "tool_calls"
case types.StopReasonMaxTokens:
finishReason = "length"
case types.StopReasonEndTurn:
finishReason = "stop"
case types.StopReasonStopSequence:
finishReason = "stop"
case types.StopReasonContentFiltered:
finishReason = "content_filter"
}
// Build usage info
var usage *UsageInfo
if output.Usage != nil {
usage = &UsageInfo{
PromptTokens: int(aws.ToInt32(output.Usage.InputTokens)),
CompletionTokens: int(aws.ToInt32(output.Usage.OutputTokens)),
TotalTokens: int(aws.ToInt32(output.Usage.InputTokens)) + int(aws.ToInt32(output.Usage.OutputTokens)),
}
}
return &LLMResponse{
Content: content.String(),
ToolCalls: toolCalls,
FinishReason: finishReason,
Usage: usage,
}, nil
}
@@ -0,0 +1,541 @@
//go:build bedrock
// PicoClaw - Ultra-lightweight personal AI agent
// License: MIT
//
// Copyright (c) 2026 PicoClaw contributors
package bedrock
import (
"testing"
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/service/bedrockruntime"
"github.com/aws/aws-sdk-go-v2/service/bedrockruntime/document"
"github.com/aws/aws-sdk-go-v2/service/bedrockruntime/types"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/sipeed/picoclaw/pkg/providers/protocoltypes"
)
func TestConvertMessages_SystemPrompts(t *testing.T) {
messages := []Message{
{Role: "system", Content: "You are a helpful assistant."},
{Role: "user", Content: "Hello"},
}
bedrockMsgs, systemPrompts := convertMessages(messages)
assert.Len(t, systemPrompts, 1)
assert.Len(t, bedrockMsgs, 1)
// Check system prompt
textBlock, ok := systemPrompts[0].(*types.SystemContentBlockMemberText)
require.True(t, ok)
assert.Equal(t, "You are a helpful assistant.", textBlock.Value)
// Check user message
assert.Equal(t, types.ConversationRoleUser, bedrockMsgs[0].Role)
}
func TestConvertMessages_UserMessage(t *testing.T) {
messages := []Message{
{Role: "user", Content: "What is 2+2?"},
}
bedrockMsgs, systemPrompts := convertMessages(messages)
assert.Empty(t, systemPrompts)
assert.Len(t, bedrockMsgs, 1)
assert.Equal(t, types.ConversationRoleUser, bedrockMsgs[0].Role)
textBlock, ok := bedrockMsgs[0].Content[0].(*types.ContentBlockMemberText)
require.True(t, ok)
assert.Equal(t, "What is 2+2?", textBlock.Value)
}
func TestConvertMessages_AssistantMessage(t *testing.T) {
messages := []Message{
{Role: "assistant", Content: "The answer is 4."},
}
bedrockMsgs, _ := convertMessages(messages)
assert.Len(t, bedrockMsgs, 1)
assert.Equal(t, types.ConversationRoleAssistant, bedrockMsgs[0].Role)
textBlock, ok := bedrockMsgs[0].Content[0].(*types.ContentBlockMemberText)
require.True(t, ok)
assert.Equal(t, "The answer is 4.", textBlock.Value)
}
func TestConvertMessages_ToolResult(t *testing.T) {
messages := []Message{
{Role: "tool", Content: "Result from tool", ToolCallID: "call_123"},
}
bedrockMsgs, _ := convertMessages(messages)
assert.Len(t, bedrockMsgs, 1)
assert.Equal(t, types.ConversationRoleUser, bedrockMsgs[0].Role)
toolResult, ok := bedrockMsgs[0].Content[0].(*types.ContentBlockMemberToolResult)
require.True(t, ok)
assert.Equal(t, "call_123", aws.ToString(toolResult.Value.ToolUseId))
}
func TestConvertMessages_MultipleToolResultsMerged(t *testing.T) {
// When an assistant makes multiple tool calls, all tool results must be
// merged into a single user message for Bedrock
messages := []Message{
{Role: "user", Content: "What's the weather in NYC and LA?"},
{
Role: "assistant",
Content: "Let me check both cities.",
ToolCalls: []protocoltypes.ToolCall{
{ID: "call_nyc", Name: "get_weather", Arguments: map[string]any{"city": "NYC"}},
{ID: "call_la", Name: "get_weather", Arguments: map[string]any{"city": "LA"}},
},
},
{Role: "tool", Content: "NYC: 72°F, sunny", ToolCallID: "call_nyc"},
{Role: "tool", Content: "LA: 85°F, clear", ToolCallID: "call_la"},
}
bedrockMsgs, _ := convertMessages(messages)
// Should be: user message, assistant message, merged tool results (single user message)
assert.Len(t, bedrockMsgs, 3)
// First message: user
assert.Equal(t, types.ConversationRoleUser, bedrockMsgs[0].Role)
// Second message: assistant with tool calls
assert.Equal(t, types.ConversationRoleAssistant, bedrockMsgs[1].Role)
// Third message: merged tool results in single user message
assert.Equal(t, types.ConversationRoleUser, bedrockMsgs[2].Role)
assert.Len(t, bedrockMsgs[2].Content, 2) // Both tool results in one message
// Verify both tool results are present
result1, ok := bedrockMsgs[2].Content[0].(*types.ContentBlockMemberToolResult)
require.True(t, ok)
assert.Equal(t, "call_nyc", aws.ToString(result1.Value.ToolUseId))
result2, ok := bedrockMsgs[2].Content[1].(*types.ContentBlockMemberToolResult)
require.True(t, ok)
assert.Equal(t, "call_la", aws.ToString(result2.Value.ToolUseId))
}
func TestConvertMessages_AssistantWithToolCalls(t *testing.T) {
messages := []Message{
{
Role: "assistant",
Content: "Let me calculate that.",
ToolCalls: []protocoltypes.ToolCall{
{
ID: "call_456",
Name: "calculator",
Arguments: map[string]any{"expression": "2+2"},
},
},
},
}
bedrockMsgs, _ := convertMessages(messages)
assert.Len(t, bedrockMsgs, 1)
assert.Len(t, bedrockMsgs[0].Content, 2) // text + tool use
// Check text content
textBlock, ok := bedrockMsgs[0].Content[0].(*types.ContentBlockMemberText)
require.True(t, ok)
assert.Equal(t, "Let me calculate that.", textBlock.Value)
// Check tool use
toolUse, ok := bedrockMsgs[0].Content[1].(*types.ContentBlockMemberToolUse)
require.True(t, ok)
assert.Equal(t, "call_456", aws.ToString(toolUse.Value.ToolUseId))
assert.Equal(t, "calculator", aws.ToString(toolUse.Value.Name))
}
func TestConvertTools_Basic(t *testing.T) {
tools := []ToolDefinition{
{
Function: protocoltypes.ToolFunctionDefinition{
Name: "get_weather",
Description: "Get the current weather",
Parameters: map[string]any{
"type": "object",
"properties": map[string]any{
"location": map[string]any{"type": "string"},
},
},
},
},
}
toolConfig := convertTools(tools)
assert.NotNil(t, toolConfig)
assert.Len(t, toolConfig.Tools, 1)
toolSpec, ok := toolConfig.Tools[0].(*types.ToolMemberToolSpec)
require.True(t, ok)
assert.Equal(t, "get_weather", aws.ToString(toolSpec.Value.Name))
assert.Equal(t, "Get the current weather", aws.ToString(toolSpec.Value.Description))
}
func TestConvertTools_SkipsEmptyName(t *testing.T) {
tools := []ToolDefinition{
{
Function: protocoltypes.ToolFunctionDefinition{
Name: "",
Description: "Empty name tool",
},
},
{
Function: protocoltypes.ToolFunctionDefinition{
Name: " ",
Description: "Whitespace name tool",
},
},
{
Function: protocoltypes.ToolFunctionDefinition{
Name: "valid_tool",
Description: "Valid tool",
},
},
}
toolConfig := convertTools(tools)
assert.Len(t, toolConfig.Tools, 1)
toolSpec := toolConfig.Tools[0].(*types.ToolMemberToolSpec)
assert.Equal(t, "valid_tool", aws.ToString(toolSpec.Value.Name))
}
func TestConvertTools_NilParameters(t *testing.T) {
tools := []ToolDefinition{
{
Function: protocoltypes.ToolFunctionDefinition{
Name: "simple_tool",
Description: "A tool with no parameters",
Parameters: nil,
},
},
}
toolConfig := convertTools(tools)
assert.Len(t, toolConfig.Tools, 1)
// Should not panic and should create a valid tool
}
func TestBuildUserContent_TextOnly(t *testing.T) {
msg := Message{Content: "Hello world"}
content := buildUserContent(msg)
assert.Len(t, content, 1)
textBlock, ok := content[0].(*types.ContentBlockMemberText)
require.True(t, ok)
assert.Equal(t, "Hello world", textBlock.Value)
}
func TestBuildUserContent_WithImage(t *testing.T) {
// Base64-encoded 1x1 PNG (the provider doesn't validate image correctness,
// it just verifies the format and base64 decoding works)
b64Data := "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAIAAACQd1PeAAAADUlEQVR4nGNgYAAAAAMAASsJTYQAAAAASUVORK5CYII="
msg := Message{
Content: "Look at this image",
Media: []string{"data:image/png;base64," + b64Data},
}
content := buildUserContent(msg)
assert.Len(t, content, 2)
// Check text
textBlock, ok := content[0].(*types.ContentBlockMemberText)
require.True(t, ok)
assert.Equal(t, "Look at this image", textBlock.Value)
// Check image
imageBlock, ok := content[1].(*types.ContentBlockMemberImage)
require.True(t, ok)
assert.Equal(t, types.ImageFormatPng, imageBlock.Value.Format)
}
func TestBuildUserContent_SkipsInvalidBase64(t *testing.T) {
msg := Message{
Content: "Invalid image",
Media: []string{"data:image/png;base64,not-valid-base64!!!"},
}
content := buildUserContent(msg)
// Should only have text, image should be skipped
assert.Len(t, content, 1)
}
func TestBuildUserContent_SkipsNonBase64Data(t *testing.T) {
msg := Message{
Content: "Non-base64 image",
Media: []string{"data:image/png,raw-data-here"},
}
content := buildUserContent(msg)
// Should only have text, non-base64 image should be skipped
assert.Len(t, content, 1)
}
func TestBuildAssistantContent_SkipsEmptyToolName(t *testing.T) {
msg := Message{
Content: "Response",
ToolCalls: []protocoltypes.ToolCall{
{ID: "1", Name: "", Arguments: map[string]any{}},
{ID: "2", Name: " ", Arguments: map[string]any{}},
{ID: "3", Name: "valid", Arguments: map[string]any{}},
},
}
content := buildAssistantContent(msg)
// Should have text + 1 valid tool
assert.Len(t, content, 2)
}
func TestBuildAssistantContent_NilArguments(t *testing.T) {
msg := Message{
ToolCalls: []protocoltypes.ToolCall{
{ID: "1", Name: "tool", Arguments: nil},
},
}
content := buildAssistantContent(msg)
assert.Len(t, content, 1)
toolUse, ok := content[0].(*types.ContentBlockMemberToolUse)
require.True(t, ok)
assert.NotNil(t, toolUse.Value.Input)
}
func TestBuildAssistantContent_FunctionFallback(t *testing.T) {
// When Name/Arguments are empty (json:"-"), should fallback to Function fields
msg := Message{
ToolCalls: []protocoltypes.ToolCall{
{
ID: "1",
Name: "", // empty, should fallback to Function.Name
Function: &protocoltypes.FunctionCall{
Name: "fallback_tool",
Arguments: `{"key":"value"}`,
},
},
},
}
content := buildAssistantContent(msg)
assert.Len(t, content, 1)
toolUse, ok := content[0].(*types.ContentBlockMemberToolUse)
require.True(t, ok)
assert.Equal(t, "fallback_tool", aws.ToString(toolUse.Value.Name))
}
func TestParseResponse_TextOnly(t *testing.T) {
output := &bedrockruntime.ConverseOutput{
Output: &types.ConverseOutputMemberMessage{
Value: types.Message{
Role: types.ConversationRoleAssistant,
Content: []types.ContentBlock{
&types.ContentBlockMemberText{Value: "Hello!"},
},
},
},
StopReason: types.StopReasonEndTurn,
Usage: &types.TokenUsage{
InputTokens: aws.Int32(10),
OutputTokens: aws.Int32(5),
},
}
resp, err := parseResponse(output)
require.NoError(t, err)
assert.Equal(t, "Hello!", resp.Content)
assert.Equal(t, "stop", resp.FinishReason)
assert.Empty(t, resp.ToolCalls)
assert.Equal(t, 10, resp.Usage.PromptTokens)
assert.Equal(t, 5, resp.Usage.CompletionTokens)
}
func TestParseResponse_StopReasons(t *testing.T) {
tests := []struct {
stopReason types.StopReason
expectedFinish string
}{
{types.StopReasonEndTurn, "stop"},
{types.StopReasonToolUse, "tool_calls"},
{types.StopReasonMaxTokens, "length"},
{types.StopReasonStopSequence, "stop"},
{types.StopReasonContentFiltered, "content_filter"},
}
for _, tt := range tests {
t.Run(string(tt.stopReason), func(t *testing.T) {
output := &bedrockruntime.ConverseOutput{
Output: &types.ConverseOutputMemberMessage{
Value: types.Message{
Content: []types.ContentBlock{
&types.ContentBlockMemberText{Value: "test"},
},
},
},
StopReason: tt.stopReason,
}
resp, err := parseResponse(output)
require.NoError(t, err)
assert.Equal(t, tt.expectedFinish, resp.FinishReason)
})
}
}
func TestParseResponse_WithToolCalls(t *testing.T) {
// Note: document.NewLazyDocument has limitations with UnmarshalSmithyDocument in tests,
// so we test the structure extraction and verify Arguments gets populated (even if empty
// due to SDK limitations). The actual unmarshal works correctly at runtime.
toolInput := document.NewLazyDocument(map[string]any{
"location": "San Francisco",
"unit": "celsius",
})
output := &bedrockruntime.ConverseOutput{
Output: &types.ConverseOutputMemberMessage{
Value: types.Message{
Role: types.ConversationRoleAssistant,
Content: []types.ContentBlock{
&types.ContentBlockMemberText{Value: "Let me check the weather."},
&types.ContentBlockMemberToolUse{
Value: types.ToolUseBlock{
ToolUseId: aws.String("call_weather_123"),
Name: aws.String("get_weather"),
Input: toolInput,
},
},
},
},
},
StopReason: types.StopReasonToolUse,
Usage: &types.TokenUsage{
InputTokens: aws.Int32(20),
OutputTokens: aws.Int32(15),
},
}
resp, err := parseResponse(output)
require.NoError(t, err)
assert.Equal(t, "Let me check the weather.", resp.Content)
assert.Equal(t, "tool_calls", resp.FinishReason)
assert.Len(t, resp.ToolCalls, 1)
// Verify tool call ID and Name are extracted correctly
tc := resp.ToolCalls[0]
assert.Equal(t, "call_weather_123", tc.ID)
assert.Equal(t, "get_weather", tc.Name)
// Verify Function fields are also populated
require.NotNil(t, tc.Function)
assert.Equal(t, "get_weather", tc.Function.Name)
// Verify Arguments is not nil (content may vary due to SDK limitations in tests)
assert.NotNil(t, tc.Arguments)
// Verify usage
assert.Equal(t, 20, resp.Usage.PromptTokens)
assert.Equal(t, 15, resp.Usage.CompletionTokens)
assert.Equal(t, 35, resp.Usage.TotalTokens)
}
func TestParseResponse_MultipleToolCalls(t *testing.T) {
output := &bedrockruntime.ConverseOutput{
Output: &types.ConverseOutputMemberMessage{
Value: types.Message{
Role: types.ConversationRoleAssistant,
Content: []types.ContentBlock{
&types.ContentBlockMemberToolUse{
Value: types.ToolUseBlock{
ToolUseId: aws.String("call_1"),
Name: aws.String("tool_a"),
Input: document.NewLazyDocument(map[string]any{"arg": "value1"}),
},
},
&types.ContentBlockMemberToolUse{
Value: types.ToolUseBlock{
ToolUseId: aws.String("call_2"),
Name: aws.String("tool_b"),
Input: document.NewLazyDocument(map[string]any{"arg": "value2"}),
},
},
},
},
},
StopReason: types.StopReasonToolUse,
}
resp, err := parseResponse(output)
require.NoError(t, err)
assert.Equal(t, "tool_calls", resp.FinishReason)
assert.Len(t, resp.ToolCalls, 2)
// Verify tool call structure
assert.Equal(t, "call_1", resp.ToolCalls[0].ID)
assert.Equal(t, "tool_a", resp.ToolCalls[0].Name)
assert.NotNil(t, resp.ToolCalls[0].Arguments)
assert.NotNil(t, resp.ToolCalls[0].Function)
assert.Equal(t, "tool_a", resp.ToolCalls[0].Function.Name)
assert.Equal(t, "call_2", resp.ToolCalls[1].ID)
assert.Equal(t, "tool_b", resp.ToolCalls[1].Name)
assert.NotNil(t, resp.ToolCalls[1].Arguments)
assert.NotNil(t, resp.ToolCalls[1].Function)
assert.Equal(t, "tool_b", resp.ToolCalls[1].Function.Name)
}
func TestParseResponse_ToolCallWithNilInput(t *testing.T) {
output := &bedrockruntime.ConverseOutput{
Output: &types.ConverseOutputMemberMessage{
Value: types.Message{
Role: types.ConversationRoleAssistant,
Content: []types.ContentBlock{
&types.ContentBlockMemberToolUse{
Value: types.ToolUseBlock{
ToolUseId: aws.String("call_nil"),
Name: aws.String("no_args_tool"),
Input: nil,
},
},
},
},
},
StopReason: types.StopReasonToolUse,
}
resp, err := parseResponse(output)
require.NoError(t, err)
assert.Len(t, resp.ToolCalls, 1)
assert.Equal(t, "call_nil", resp.ToolCalls[0].ID)
assert.Equal(t, "no_args_tool", resp.ToolCalls[0].Name)
// Arguments should be empty map, not nil
assert.NotNil(t, resp.ToolCalls[0].Arguments)
assert.Empty(t, resp.ToolCalls[0].Arguments)
}
+73
View File
@@ -0,0 +1,73 @@
//go:build !bedrock
// PicoClaw - Ultra-lightweight personal AI agent
// License: MIT
//
// Copyright (c) 2026 PicoClaw contributors
// Package bedrock provides a stub implementation when built without the bedrock tag.
// To enable AWS Bedrock support, build with: go build -tags bedrock
package bedrock
import (
"context"
"fmt"
"time"
"github.com/sipeed/picoclaw/pkg/providers/protocoltypes"
)
type (
LLMResponse = protocoltypes.LLMResponse
Message = protocoltypes.Message
ToolDefinition = protocoltypes.ToolDefinition
)
// Provider is a stub that returns an error when Bedrock support is not compiled in.
type Provider struct{}
// Option is a no-op when Bedrock is not enabled.
type Option func(*providerConfig)
type providerConfig struct{}
// WithRegion is a no-op when Bedrock is not enabled.
func WithRegion(region string) Option {
return func(c *providerConfig) {}
}
// WithProfile is a no-op when Bedrock is not enabled.
func WithProfile(profile string) Option {
return func(c *providerConfig) {}
}
// WithBaseEndpoint is a no-op when Bedrock is not enabled.
func WithBaseEndpoint(endpoint string) Option {
return func(c *providerConfig) {}
}
// WithRequestTimeout is a no-op when Bedrock is not enabled.
func WithRequestTimeout(timeout time.Duration) Option {
return func(c *providerConfig) {}
}
// NewProvider returns an error indicating Bedrock support is not compiled in.
func NewProvider(ctx context.Context, opts ...Option) (*Provider, error) {
return nil, fmt.Errorf("bedrock provider not available: build with -tags bedrock to enable AWS Bedrock support")
}
// Chat returns an error - this should never be called since NewProvider fails.
func (p *Provider) Chat(
ctx context.Context,
messages []Message,
tools []ToolDefinition,
model string,
options map[string]any,
) (*LLMResponse, error) {
return nil, fmt.Errorf("bedrock provider not available: build with -tags bedrock to enable AWS Bedrock support")
}
// GetDefaultModel returns an empty string.
func (p *Provider) GetDefaultModel() string {
return ""
}
@@ -0,0 +1,35 @@
//go:build !bedrock
// PicoClaw - Ultra-lightweight personal AI agent
// License: MIT
//
// Copyright (c) 2026 PicoClaw contributors
package bedrock
import (
"context"
"strings"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestNewProvider_ReturnsStubError(t *testing.T) {
provider, err := NewProvider(context.Background())
assert.Nil(t, provider)
require.Error(t, err)
assert.True(t, strings.Contains(err.Error(), "build with -tags bedrock"),
"error should mention build tag requirement, got: %s", err.Error())
}
func TestNewProvider_WithOptions_ReturnsStubError(t *testing.T) {
provider, err := NewProvider(context.Background(), WithRegion("us-west-2"), WithProfile("test"))
assert.Nil(t, provider)
require.Error(t, err)
assert.True(t, strings.Contains(err.Error(), "build with -tags bedrock"),
"error should mention build tag requirement, got: %s", err.Error())
}
+42 -2
View File
@@ -6,12 +6,15 @@
package providers
import (
"context"
"fmt"
"strings"
"time"
"github.com/sipeed/picoclaw/pkg/config"
anthropicmessages "github.com/sipeed/picoclaw/pkg/providers/anthropic_messages"
"github.com/sipeed/picoclaw/pkg/providers/azure"
"github.com/sipeed/picoclaw/pkg/providers/bedrock"
)
// createClaudeAuthProvider creates a Claude provider using OAuth credentials from auth store.
@@ -55,8 +58,9 @@ func ExtractProtocol(model string) (protocol, modelID string) {
// CreateProviderFromConfig creates a provider based on the ModelConfig.
// It uses the protocol prefix in the Model field to determine which provider to create.
// Supported protocols: openai, litellm, novita, anthropic, anthropic-messages,
// antigravity, claude-cli, codex-cli, github-copilot
// Supported protocol families include OpenAI-compatible prefixes (e.g., openai, openrouter, groq, gemini),
// Azure OpenAI, Amazon Bedrock, Anthropic (including messages), and various CLI/compatibility shims.
// See the switch on protocol in this function for the authoritative list.
// Returns the provider, the model ID (without protocol prefix), and any error.
func CreateProviderFromConfig(cfg *config.ModelConfig) (LLMProvider, string, error) {
if cfg == nil {
@@ -114,6 +118,42 @@ func CreateProviderFromConfig(cfg *config.ModelConfig) (LLMProvider, string, err
cfg.RequestTimeout,
), modelID, nil
case "bedrock":
// AWS Bedrock uses AWS SDK credentials (env vars, profiles, IAM roles, etc.)
// api_base can be:
// - A full endpoint URL: https://bedrock-runtime.us-east-1.amazonaws.com
// - A region name: us-east-1 (AWS SDK resolves endpoint automatically)
var opts []bedrock.Option
if cfg.APIBase != "" {
if !strings.Contains(cfg.APIBase, "://") {
// Treat as region: let AWS SDK resolve the correct endpoint
// (supports all AWS partitions: aws, aws-cn, aws-us-gov, etc.)
opts = append(opts, bedrock.WithRegion(cfg.APIBase))
} else {
// Full endpoint URL provided (for custom endpoints or testing)
opts = append(opts, bedrock.WithBaseEndpoint(cfg.APIBase))
}
}
// Use a separate timeout for AWS config loading (credential resolution can block)
initTimeout := 30 * time.Second
if cfg.RequestTimeout > 0 {
reqTimeout := time.Duration(cfg.RequestTimeout) * time.Second
// Set request timeout for API calls
opts = append(opts, bedrock.WithRequestTimeout(reqTimeout))
// Ensure init timeout is at least as large as request timeout
if reqTimeout > initTimeout {
initTimeout = reqTimeout
}
}
ctx, cancel := context.WithTimeout(context.Background(), initTimeout)
defer cancel()
// Note: AWS_PROFILE env var is automatically used by AWS SDK
provider, err := bedrock.NewProvider(ctx, opts...)
if err != nil {
return nil, "", fmt.Errorf("creating bedrock provider: %w", err)
}
return provider, modelID, nil
case "litellm", "openrouter", "groq", "zhipu", "gemini", "nvidia",
"ollama", "moonshot", "shengsuanyun", "deepseek", "cerebras",
"vivgrid", "volcengine", "vllm", "qwen", "qwen-intl", "qwen-international", "dashscope-intl",
+75
View File
@@ -700,3 +700,78 @@ func TestCreateProviderFromConfig_MinimaxPreservesUserExtraBody(t *testing.T) {
t.Fatalf("custom_field = %v, want test", got)
}
}
func TestCreateProviderFromConfig_Bedrock(t *testing.T) {
// Set dummy AWS env vars to make test deterministic
t.Setenv("AWS_ACCESS_KEY_ID", "test-key")
t.Setenv("AWS_SECRET_ACCESS_KEY", "test-secret")
t.Setenv("AWS_EC2_METADATA_DISABLED", "true")
// Clear profile-related env vars to avoid loading shared config
t.Setenv("AWS_PROFILE", "")
t.Setenv("AWS_DEFAULT_PROFILE", "")
t.Setenv("AWS_SDK_LOAD_CONFIG", "")
t.Setenv("AWS_SHARED_CREDENTIALS_FILE", "")
cfg := &config.ModelConfig{
ModelName: "bedrock-claude",
Model: "bedrock/us.anthropic.claude-sonnet-4-20250514-v1:0",
APIBase: "us-west-2", // Region (also sets AWS region)
}
provider, modelID, err := CreateProviderFromConfig(cfg)
if err == nil {
// Provider created successfully (built with -tags bedrock)
if provider == nil {
t.Error("provider is nil on success")
}
if modelID != "us.anthropic.claude-sonnet-4-20250514-v1:0" {
t.Errorf("modelID = %q, want %q", modelID, "us.anthropic.claude-sonnet-4-20250514-v1:0")
}
return
}
errMsg := err.Error()
// When built without -tags bedrock, expect stub error
if strings.Contains(errMsg, "build with -tags bedrock") {
return // Expected stub error
}
// Unexpected error - fail the test
t.Errorf("unexpected error from bedrock provider: %v", err)
}
func TestCreateProviderFromConfig_BedrockWithEndpointURL(t *testing.T) {
// Set dummy AWS env vars to make test deterministic
t.Setenv("AWS_ACCESS_KEY_ID", "test-key")
t.Setenv("AWS_SECRET_ACCESS_KEY", "test-secret")
t.Setenv("AWS_REGION", "us-east-1") // Required when using endpoint URL
t.Setenv("AWS_EC2_METADATA_DISABLED", "true")
// Clear profile-related env vars to avoid loading shared config
t.Setenv("AWS_PROFILE", "")
t.Setenv("AWS_DEFAULT_PROFILE", "")
t.Setenv("AWS_SDK_LOAD_CONFIG", "")
t.Setenv("AWS_SHARED_CREDENTIALS_FILE", "")
cfg := &config.ModelConfig{
ModelName: "bedrock-claude",
Model: "bedrock/us.anthropic.claude-sonnet-4-20250514-v1:0",
APIBase: "https://bedrock-runtime.us-east-1.amazonaws.com", // Full endpoint URL
}
provider, modelID, err := CreateProviderFromConfig(cfg)
if err == nil {
// Provider created successfully (built with -tags bedrock)
if provider == nil {
t.Error("provider is nil on success")
}
if modelID != "us.anthropic.claude-sonnet-4-20250514-v1:0" {
t.Errorf("modelID = %q, want %q", modelID, "us.anthropic.claude-sonnet-4-20250514-v1:0")
}
return
}
errMsg := err.Error()
// When built without -tags bedrock, expect stub error
if strings.Contains(errMsg, "build with -tags bedrock") {
return // Expected stub error
}
// Unexpected error - fail the test
t.Errorf("unexpected error from bedrock provider: %v", err)
}