feat(providers): add AWS Bedrock provider (#1903)

Add support for AWS Bedrock as an LLM provider using the Converse API.
The implementation is behind a build tag (-tags bedrock) to keep the
default binary size small.

Features:
- AWS SDK v2 with automatic credential chain (env vars, profiles, IAM roles)
- Converse API for unified access to Claude, Llama, Mistral models
- Tool/function calling support with proper document handling
- Image support with base64 decoding and size limits
- Request timeout configuration
- Region validation and endpoint resolution for all AWS partitions

Usage:
  go build -tags bedrock
  model: bedrock/us.anthropic.claude-sonnet-4-20250514-v1:0
  api_base: us-east-1  (or full endpoint URL)
This commit is contained in:
Andy Lo-A-Foe
2026-03-23 18:10:56 +01:00
committed by GitHub
parent 40571996b1
commit b787131c82
9 changed files with 1397 additions and 2 deletions
+3
View File
@@ -373,6 +373,9 @@ PicoClaw supports 30+ LLM providers through the `model_list` configuration. Use
| [Azure OpenAI](https://portal.azure.com/) | `azure/` | Required | Enterprise Azure deployment | | [Azure OpenAI](https://portal.azure.com/) | `azure/` | Required | Enterprise Azure deployment |
| [GitHub Copilot](https://github.com/features/copilot) | `github-copilot/` | OAuth | Device code login | | [GitHub Copilot](https://github.com/features/copilot) | `github-copilot/` | OAuth | Device code login |
| [Antigravity](https://console.cloud.google.com/) | `antigravity/` | OAuth | Google Cloud AI | | [Antigravity](https://console.cloud.google.com/) | `antigravity/` | OAuth | Google Cloud AI |
| [AWS Bedrock](https://console.aws.amazon.com/bedrock)* | `bedrock/` | AWS credentials | Claude, Llama, Mistral on AWS |
> \* AWS Bedrock requires build tag: `go build -tags bedrock`. Set `api_base` to a region name (e.g., `us-east-1`) for automatic endpoint resolution across all AWS partitions (aws, aws-cn, aws-us-gov). When using a full endpoint URL instead, you must also configure `AWS_REGION` via environment variable or AWS config/profile.
<details> <details>
<summary><b>Local deployment (Ollama, vLLM, etc.)</b></summary> <summary><b>Local deployment (Ollama, vLLM, etc.)</b></summary>
+16
View File
@@ -7,6 +7,9 @@ require (
github.com/BurntSushi/toml v1.6.0 github.com/BurntSushi/toml v1.6.0
github.com/adhocore/gronx v1.19.6 github.com/adhocore/gronx v1.19.6
github.com/anthropics/anthropic-sdk-go v1.26.0 github.com/anthropics/anthropic-sdk-go v1.26.0
github.com/aws/aws-sdk-go-v2 v1.41.4
github.com/aws/aws-sdk-go-v2/config v1.32.12
github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.50.2
github.com/bwmarrin/discordgo v0.29.0 github.com/bwmarrin/discordgo v0.29.0
github.com/caarlos0/env/v11 v11.4.0 github.com/caarlos0/env/v11 v11.4.0
github.com/ergochat/irc-go v0.6.0 github.com/ergochat/irc-go v0.6.0
@@ -40,6 +43,19 @@ require (
require ( require (
filippo.io/edwards25519 v1.2.0 // indirect filippo.io/edwards25519 v1.2.0 // indirect
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.7 // indirect
github.com/aws/aws-sdk-go-v2/credentials v1.19.12 // indirect
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.20 // indirect
github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.20 // indirect
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.20 // indirect
github.com/aws/aws-sdk-go-v2/internal/ini v1.8.6 // indirect
github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.7 // indirect
github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.20 // indirect
github.com/aws/aws-sdk-go-v2/service/signin v1.0.8 // indirect
github.com/aws/aws-sdk-go-v2/service/sso v1.30.13 // indirect
github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.17 // indirect
github.com/aws/aws-sdk-go-v2/service/sts v1.41.9 // indirect
github.com/aws/smithy-go v1.24.2 // indirect
github.com/beeper/argo-go v1.1.2 // indirect github.com/beeper/argo-go v1.1.2 // indirect
github.com/coder/websocket v1.8.14 // indirect github.com/coder/websocket v1.8.14 // indirect
github.com/davecgh/go-spew v1.1.1 // indirect github.com/davecgh/go-spew v1.1.1 // indirect
+32
View File
@@ -17,6 +17,38 @@ github.com/andybalholm/brotli v1.2.0 h1:ukwgCxwYrmACq68yiUqwIWnGY0cTPox/M94sVwTo
github.com/andybalholm/brotli v1.2.0/go.mod h1:rzTDkvFWvIrjDXZHkuS16NPggd91W3kUSvPlQ1pLaKY= github.com/andybalholm/brotli v1.2.0/go.mod h1:rzTDkvFWvIrjDXZHkuS16NPggd91W3kUSvPlQ1pLaKY=
github.com/anthropics/anthropic-sdk-go v1.26.0 h1:oUTzFaUpAevfuELAP1sjL6CQJ9HHAfT7CoSYSac11PY= github.com/anthropics/anthropic-sdk-go v1.26.0 h1:oUTzFaUpAevfuELAP1sjL6CQJ9HHAfT7CoSYSac11PY=
github.com/anthropics/anthropic-sdk-go v1.26.0/go.mod h1:qUKmaW+uuPB64iy1l+4kOSvaLqPXnHTTBKH6RVZ7q5Q= github.com/anthropics/anthropic-sdk-go v1.26.0/go.mod h1:qUKmaW+uuPB64iy1l+4kOSvaLqPXnHTTBKH6RVZ7q5Q=
github.com/aws/aws-sdk-go-v2 v1.41.4 h1:10f50G7WyU02T56ox1wWXq+zTX9I1zxG46HYuG1hH/k=
github.com/aws/aws-sdk-go-v2 v1.41.4/go.mod h1:mwsPRE8ceUUpiTgF7QmQIJ7lgsKUPQOUl3o72QBrE1o=
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.7 h1:3kGOqnh1pPeddVa/E37XNTaWJ8W6vrbYV9lJEkCnhuY=
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.7/go.mod h1:lyw7GFp3qENLh7kwzf7iMzAxDn+NzjXEAGjKS2UOKqI=
github.com/aws/aws-sdk-go-v2/config v1.32.12 h1:O3csC7HUGn2895eNrLytOJQdoL2xyJy0iYXhoZ1OmP0=
github.com/aws/aws-sdk-go-v2/config v1.32.12/go.mod h1:96zTvoOFR4FURjI+/5wY1vc1ABceROO4lWgWJuxgy0g=
github.com/aws/aws-sdk-go-v2/credentials v1.19.12 h1:oqtA6v+y5fZg//tcTWahyN9PEn5eDU/Wpvc2+kJ4aY8=
github.com/aws/aws-sdk-go-v2/credentials v1.19.12/go.mod h1:U3R1RtSHx6NB0DvEQFGyf/0sbrpJrluENHdPy1j/3TE=
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.20 h1:zOgq3uezl5nznfoK3ODuqbhVg1JzAGDUhXOsU0IDCAo=
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.20/go.mod h1:z/MVwUARehy6GAg/yQ1GO2IMl0k++cu1ohP9zo887wE=
github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.20 h1:CNXO7mvgThFGqOFgbNAP2nol2qAWBOGfqR/7tQlvLmc=
github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.20/go.mod h1:oydPDJKcfMhgfcgBUZaG+toBbwy8yPWubJXBVERtI4o=
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.20 h1:tN6W/hg+pkM+tf9XDkWUbDEjGLb+raoBMFsTodcoYKw=
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.20/go.mod h1:YJ898MhD067hSHA6xYCx5ts/jEd8BSOLtQDL3iZsvbc=
github.com/aws/aws-sdk-go-v2/internal/ini v1.8.6 h1:qYQ4pzQ2Oz6WpQ8T3HvGHnZydA72MnLuFK9tJwmrbHw=
github.com/aws/aws-sdk-go-v2/internal/ini v1.8.6/go.mod h1:O3h0IK87yXci+kg6flUKzJnWeziQUKciKrLjcatSNcY=
github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.50.2 h1:x0eGAWpd1B5I/vMtrB4Q4Zuc3CXWI8wjHfPPqBSrKmM=
github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.50.2/go.mod h1:V9oTWSDC2MtS1DR71hbNET/bZ8psQp022amEBe1grJc=
github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.7 h1:5EniKhLZe4xzL7a+fU3C2tfUN4nWIqlLesfrjkuPFTY=
github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.7/go.mod h1:x0nZssQ3qZSnIcePWLvcoFisRXJzcTVvYpAAdYX8+GI=
github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.20 h1:2HvVAIq+YqgGotK6EkMf+KIEqTISmTYh5zLpYyeTo1Y=
github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.20/go.mod h1:V4X406Y666khGa8ghKmphma/7C0DAtEQYhkq9z4vpbk=
github.com/aws/aws-sdk-go-v2/service/signin v1.0.8 h1:0GFOLzEbOyZABS3PhYfBIx2rNBACYcKty+XGkTgw1ow=
github.com/aws/aws-sdk-go-v2/service/signin v1.0.8/go.mod h1:LXypKvk85AROkKhOG6/YEcHFPoX+prKTowKnVdcaIxE=
github.com/aws/aws-sdk-go-v2/service/sso v1.30.13 h1:kiIDLZ005EcKomYYITtfsjn7dtOwHDOFy7IbPXKek2o=
github.com/aws/aws-sdk-go-v2/service/sso v1.30.13/go.mod h1:2h/xGEowcW/g38g06g3KpRWDlT+OTfxxI0o1KqayAB8=
github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.17 h1:jzKAXIlhZhJbnYwHbvUQZEB8KfgAEuG0dc08Bkda7NU=
github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.17/go.mod h1:Al9fFsXjv4KfbzQHGe6V4NZSZQXecFcvaIF4e70FoRA=
github.com/aws/aws-sdk-go-v2/service/sts v1.41.9 h1:Cng+OOwCHmFljXIxpEVXAGMnBia8MSU6Ch5i9PgBkcU=
github.com/aws/aws-sdk-go-v2/service/sts v1.41.9/go.mod h1:LrlIndBDdjA/EeXeyNBle+gyCwTlizzW5ycgWnvIxkk=
github.com/aws/smithy-go v1.24.2 h1:FzA3bu/nt/vDvmnkg+R8Xl46gmzEDam6mZ1hzmwXFng=
github.com/aws/smithy-go v1.24.2/go.mod h1:YE2RhdIuDbA5E5bTdciG9KrW3+TiEONeUWCqxX9i1Fc=
github.com/beeper/argo-go v1.1.2 h1:UQI2G8F+NLfGTOmTUI0254pGKx/HUU/etbUGTJv91Fs= github.com/beeper/argo-go v1.1.2 h1:UQI2G8F+NLfGTOmTUI0254pGKx/HUU/etbUGTJv91Fs=
github.com/beeper/argo-go v1.1.2/go.mod h1:M+LJAnyowKVQ6Rdj6XYGEn+qcVFkb3R/MUpqkGR0hM4= github.com/beeper/argo-go v1.1.2/go.mod h1:M+LJAnyowKVQ6Rdj6XYGEn+qcVFkb3R/MUpqkGR0hM4=
github.com/bwmarrin/discordgo v0.29.0 h1:FmWeXFaKUwrcL3Cx65c20bTRW+vOb6k8AnaP+EgjDno= github.com/bwmarrin/discordgo v0.29.0 h1:FmWeXFaKUwrcL3Cx65c20bTRW+vOb6k8AnaP+EgjDno=
+580
View File
@@ -0,0 +1,580 @@
//go:build bedrock
// PicoClaw - Ultra-lightweight personal AI agent
// License: MIT
//
// Copyright (c) 2026 PicoClaw contributors
// Package bedrock implements the LLM provider interface for AWS Bedrock.
// It uses the Bedrock Runtime Converse API for unified access to multiple
// model families (Claude, Llama, Mistral, etc.) with tool/function calling support.
package bedrock
import (
"context"
"encoding/base64"
"encoding/json"
"fmt"
"log"
"math"
"strings"
"time"
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/config"
"github.com/aws/aws-sdk-go-v2/service/bedrockruntime"
"github.com/aws/aws-sdk-go-v2/service/bedrockruntime/document"
"github.com/aws/aws-sdk-go-v2/service/bedrockruntime/types"
"github.com/sipeed/picoclaw/pkg/providers/common"
"github.com/sipeed/picoclaw/pkg/providers/protocoltypes"
)
type (
ToolCall = protocoltypes.ToolCall
FunctionCall = protocoltypes.FunctionCall
LLMResponse = protocoltypes.LLMResponse
UsageInfo = protocoltypes.UsageInfo
Message = protocoltypes.Message
ToolDefinition = protocoltypes.ToolDefinition
ToolFunctionDefinition = protocoltypes.ToolFunctionDefinition
)
// Provider implements the LLM provider interface for AWS Bedrock.
type Provider struct {
client *bedrockruntime.Client
region string
requestTimeout time.Duration
}
// Option configures the Bedrock Provider.
type Option func(*providerConfig)
type providerConfig struct {
region string
profile string
baseEndpoint string
requestTimeout time.Duration
}
// WithRegion sets the AWS region for Bedrock requests.
func WithRegion(region string) Option {
return func(c *providerConfig) {
c.region = region
}
}
// WithProfile sets the AWS profile to use for credentials.
func WithProfile(profile string) Option {
return func(c *providerConfig) {
c.profile = profile
}
}
// WithBaseEndpoint sets a custom Bedrock endpoint URL.
// Example: https://bedrock-runtime.us-east-1.amazonaws.com
func WithBaseEndpoint(endpoint string) Option {
return func(c *providerConfig) {
c.baseEndpoint = endpoint
}
}
// WithRequestTimeout sets the timeout for Bedrock API requests.
func WithRequestTimeout(timeout time.Duration) Option {
return func(c *providerConfig) {
c.requestTimeout = timeout
}
}
// NewProvider creates a new AWS Bedrock provider.
// It uses the default AWS credential chain (env vars, shared config, IAM roles, etc.).
func NewProvider(ctx context.Context, opts ...Option) (*Provider, error) {
pc := &providerConfig{}
for _, opt := range opts {
opt(pc)
}
// Build AWS config options
var configOpts []func(*config.LoadOptions) error
if pc.region != "" {
configOpts = append(configOpts, config.WithRegion(pc.region))
}
if pc.profile != "" {
configOpts = append(configOpts, config.WithSharedConfigProfile(pc.profile))
}
// Load AWS config with automatic credential discovery
cfg, err := config.LoadDefaultConfig(ctx, configOpts...)
if err != nil {
return nil, fmt.Errorf("loading AWS config: %w", err)
}
// Validate region is set - required for Bedrock request signing
if cfg.Region == "" {
return nil, fmt.Errorf("AWS region not configured: set AWS_REGION, AWS_DEFAULT_REGION, or use WithRegion option")
}
// Build client options
var clientOpts []func(*bedrockruntime.Options)
if pc.baseEndpoint != "" {
clientOpts = append(clientOpts, func(o *bedrockruntime.Options) {
o.BaseEndpoint = aws.String(pc.baseEndpoint)
})
}
client := bedrockruntime.NewFromConfig(cfg, clientOpts...)
return &Provider{
client: client,
region: cfg.Region,
requestTimeout: pc.requestTimeout,
}, nil
}
// Chat sends messages to AWS Bedrock using the Converse API.
func (p *Provider) Chat(
ctx context.Context,
messages []Message,
tools []ToolDefinition,
model string,
options map[string]any,
) (*LLMResponse, error) {
// Apply request timeout if context doesn't already have a deadline.
// Use explicit timeout if set, otherwise fall back to common default.
effectiveTimeout := p.requestTimeout
if effectiveTimeout <= 0 {
effectiveTimeout = common.DefaultRequestTimeout
}
if _, hasDeadline := ctx.Deadline(); !hasDeadline {
var cancel context.CancelFunc
ctx, cancel = context.WithTimeout(ctx, effectiveTimeout)
defer cancel()
}
// Build the Converse API input
input := &bedrockruntime.ConverseInput{
ModelId: aws.String(model),
}
// Convert messages to Bedrock format
bedrockMessages, systemPrompts := convertMessages(messages)
input.Messages = bedrockMessages
// Set system prompts if any
if len(systemPrompts) > 0 {
input.System = systemPrompts
}
// Set inference configuration only when options are provided
var inferenceConfig *types.InferenceConfiguration
if maxTokens, ok := common.AsInt(options["max_tokens"]); ok && maxTokens > 0 {
if inferenceConfig == nil {
inferenceConfig = &types.InferenceConfiguration{}
}
// Clamp to int32 range to avoid overflow
if maxTokens > math.MaxInt32 {
maxTokens = math.MaxInt32
}
inferenceConfig.MaxTokens = aws.Int32(int32(maxTokens))
}
if temp, ok := common.AsFloat(options["temperature"]); ok {
if inferenceConfig == nil {
inferenceConfig = &types.InferenceConfiguration{}
}
inferenceConfig.Temperature = aws.Float32(float32(temp))
}
if inferenceConfig != nil {
input.InferenceConfig = inferenceConfig
}
// Convert tools to Bedrock format
// Only set ToolConfig if at least one valid tool was produced
if len(tools) > 0 {
toolConfig := convertTools(tools)
if len(toolConfig.Tools) > 0 {
input.ToolConfig = toolConfig
}
}
// Call Bedrock Converse API
output, err := p.client.Converse(ctx, input)
if err != nil {
return nil, fmt.Errorf("bedrock converse: %w", err)
}
// Parse the response
return parseResponse(output)
}
// GetDefaultModel returns an empty string as Bedrock models are user-configured.
func (p *Provider) GetDefaultModel() string {
return ""
}
// Region returns the AWS region configured for this Provider.
func (p *Provider) Region() string {
return p.region
}
// convertMessages converts internal messages to Bedrock Converse format.
// Returns the conversation messages and any system prompts separately.
// Note: Bedrock requires all tool results for a given assistant turn to be in a single
// user message with multiple ToolResultBlock content blocks. This function merges
// consecutive tool result messages accordingly.
func convertMessages(messages []Message) ([]types.Message, []types.SystemContentBlock) {
var bedrockMessages []types.Message
var systemPrompts []types.SystemContentBlock
// Helper to check if a message is a tool result
isToolResult := func(msg Message) bool {
return (msg.Role == "tool" || (msg.Role == "user" && msg.ToolCallID != "")) && msg.ToolCallID != ""
}
// Helper to create a tool result content block
makeToolResultBlock := func(msg Message) types.ContentBlock {
return &types.ContentBlockMemberToolResult{
Value: types.ToolResultBlock{
ToolUseId: aws.String(msg.ToolCallID),
Content: []types.ToolResultContentBlock{
&types.ToolResultContentBlockMemberText{
Value: msg.Content,
},
},
},
}
}
i := 0
for i < len(messages) {
msg := messages[i]
switch {
case msg.Role == "system":
// System messages go to the System field
systemPrompts = append(systemPrompts, &types.SystemContentBlockMemberText{
Value: msg.Content,
})
i++
case isToolResult(msg):
// Collect all consecutive tool results into a single user message
// Bedrock requires all tool results for a turn in one message
var toolResultBlocks []types.ContentBlock
for i < len(messages) && isToolResult(messages[i]) {
toolResultBlocks = append(toolResultBlocks, makeToolResultBlock(messages[i]))
i++
}
bedrockMessages = append(bedrockMessages, types.Message{
Role: types.ConversationRoleUser,
Content: toolResultBlocks,
})
case msg.Role == "user":
// Regular user message (no ToolCallID)
content := buildUserContent(msg)
bedrockMessages = append(bedrockMessages, types.Message{
Role: types.ConversationRoleUser,
Content: content,
})
i++
case msg.Role == "assistant":
content := buildAssistantContent(msg)
bedrockMessages = append(bedrockMessages, types.Message{
Role: types.ConversationRoleAssistant,
Content: content,
})
i++
case msg.Role == "tool" && msg.ToolCallID == "":
// Tool message without ToolCallID - treat as regular user message
content := buildUserContent(msg)
bedrockMessages = append(bedrockMessages, types.Message{
Role: types.ConversationRoleUser,
Content: content,
})
i++
default:
// Unknown role - skip
i++
}
}
return bedrockMessages, systemPrompts
}
// buildUserContent builds Bedrock content blocks for a user message.
func buildUserContent(msg Message) []types.ContentBlock {
var content []types.ContentBlock
// Add text content
if msg.Content != "" {
content = append(content, &types.ContentBlockMemberText{
Value: msg.Content,
})
}
// Add images from Media field
for _, mediaURL := range msg.Media {
if strings.HasPrefix(mediaURL, "data:image/") {
// Parse data URL: data:image/jpeg;base64,<data>
parts := strings.SplitN(mediaURL, ",", 2)
if len(parts) != 2 {
continue
}
// Extract media type from "data:image/jpeg;base64"
mediaType := ""
header := parts[0]
if idx := strings.Index(header, "/"); idx != -1 {
end := strings.Index(header[idx:], ";")
if end == -1 {
end = len(header) - idx
}
mediaType = header[idx+1 : idx+end]
}
// Verify this is base64 encoded
if !strings.Contains(header, ";base64") {
continue // Skip non-base64 encoded data
}
// Map media type to Bedrock format
var format types.ImageFormat
switch mediaType {
case "jpeg", "jpg":
format = types.ImageFormatJpeg
case "png":
format = types.ImageFormatPng
case "gif":
format = types.ImageFormatGif
case "webp":
format = types.ImageFormatWebp
default:
continue // Skip unsupported formats
}
// Check size before decoding to prevent excessive memory allocation
// Bedrock has a ~20MB request limit; cap decoded images at 10MB
const maxImageSize = 10 * 1024 * 1024
decodedLen := base64.StdEncoding.DecodedLen(len(parts[1]))
if decodedLen > maxImageSize {
log.Printf("bedrock: skipping image exceeding size limit (%d bytes > %d)", decodedLen, maxImageSize)
continue
}
// Decode base64 data
imageData, err := base64.StdEncoding.DecodeString(parts[1])
if err != nil {
log.Printf("bedrock: failed to decode base64 image data: %v", err)
continue
}
content = append(content, &types.ContentBlockMemberImage{
Value: types.ImageBlock{
Format: format,
Source: &types.ImageSourceMemberBytes{
Value: imageData,
},
},
})
}
}
// Bedrock requires at least one content block; add empty text if needed
if len(content) == 0 {
content = append(content, &types.ContentBlockMemberText{Value: ""})
}
return content
}
// buildAssistantContent builds Bedrock content blocks for an assistant message.
func buildAssistantContent(msg Message) []types.ContentBlock {
var content []types.ContentBlock
// Add text content if present
if msg.Content != "" {
content = append(content, &types.ContentBlockMemberText{
Value: msg.Content,
})
}
// Add tool use blocks
for _, tc := range msg.ToolCalls {
// Validate tool call ID - Bedrock requires non-empty ToolUseId
if strings.TrimSpace(tc.ID) == "" {
log.Printf("bedrock: skipping tool call with empty ID (name: %q)", tc.Name)
continue
}
// Resolve tool name: prefer tc.Name, fallback to tc.Function.Name
// (tc.Name/tc.Arguments are json:"-" and may be empty when from JSON)
toolName := tc.Name
if toolName == "" && tc.Function != nil {
toolName = tc.Function.Name
}
if strings.TrimSpace(toolName) == "" {
continue
}
// Resolve arguments: prefer tc.Arguments, fallback to parsing tc.Function.Arguments
args := tc.Arguments
if args == nil && tc.Function != nil && tc.Function.Arguments != "" {
if err := json.Unmarshal([]byte(tc.Function.Arguments), &args); err != nil {
log.Printf("bedrock: failed to parse Function.Arguments for tool %q: %v", toolName, err)
args = map[string]any{}
}
}
if args == nil {
args = map[string]any{}
}
// Convert arguments to a Bedrock document using NewLazyDocument
inputDoc := document.NewLazyDocument(args)
content = append(content, &types.ContentBlockMemberToolUse{
Value: types.ToolUseBlock{
ToolUseId: aws.String(tc.ID),
Name: aws.String(toolName),
Input: inputDoc,
},
})
}
// Bedrock requires at least one content block; add empty text if needed
if len(content) == 0 {
content = append(content, &types.ContentBlockMemberText{Value: ""})
}
return content
}
// convertTools converts tool definitions to Bedrock format.
func convertTools(tools []ToolDefinition) *types.ToolConfiguration {
bedrockTools := make([]types.Tool, 0, len(tools))
for _, tool := range tools {
// Skip tools with empty names
if strings.TrimSpace(tool.Function.Name) == "" {
continue
}
// Ensure parameters is not nil - default to minimal object schema
params := tool.Function.Parameters
if params == nil {
params = map[string]any{
"type": "object",
"properties": map[string]any{},
}
}
// Convert parameters schema to a Bedrock document
inputSchema := document.NewLazyDocument(params)
bedrockTools = append(bedrockTools, &types.ToolMemberToolSpec{
Value: types.ToolSpecification{
Name: aws.String(tool.Function.Name),
Description: aws.String(tool.Function.Description),
InputSchema: &types.ToolInputSchemaMemberJson{
Value: inputSchema,
},
},
})
}
return &types.ToolConfiguration{
Tools: bedrockTools,
}
}
// parseResponse converts Bedrock Converse output to LLMResponse.
func parseResponse(output *bedrockruntime.ConverseOutput) (*LLMResponse, error) {
var content strings.Builder
toolCalls := make([]ToolCall, 0)
// Process output content blocks
if output.Output != nil {
if msgOutput, ok := output.Output.(*types.ConverseOutputMemberMessage); ok {
for _, block := range msgOutput.Value.Content {
switch b := block.(type) {
case *types.ContentBlockMemberText:
content.WriteString(b.Value)
case *types.ContentBlockMemberToolUse:
// Unmarshal the document interface to a map
args := make(map[string]any)
if b.Value.Input != nil {
if err := b.Value.Input.UnmarshalSmithyDocument(&args); err != nil {
log.Printf("bedrock: failed to unmarshal tool input for tool %q (id %q): %v",
aws.ToString(b.Value.Name),
aws.ToString(b.Value.ToolUseId),
err,
)
args = make(map[string]any)
}
}
// Serialize arguments to JSON string for FunctionCall
argsJSON, err := json.Marshal(args)
if err != nil {
log.Printf("bedrock: failed to marshal tool arguments for tool %q (id %q): %v",
aws.ToString(b.Value.Name),
aws.ToString(b.Value.ToolUseId),
err,
)
argsJSON = []byte("{}")
}
toolCalls = append(toolCalls, ToolCall{
ID: aws.ToString(b.Value.ToolUseId),
Name: aws.ToString(b.Value.Name),
Arguments: args,
Function: &FunctionCall{
Name: aws.ToString(b.Value.Name),
Arguments: string(argsJSON),
},
})
}
}
}
}
// Map stop reason
finishReason := "stop"
switch output.StopReason {
case types.StopReasonToolUse:
finishReason = "tool_calls"
case types.StopReasonMaxTokens:
finishReason = "length"
case types.StopReasonEndTurn:
finishReason = "stop"
case types.StopReasonStopSequence:
finishReason = "stop"
case types.StopReasonContentFiltered:
finishReason = "content_filter"
}
// Build usage info
var usage *UsageInfo
if output.Usage != nil {
usage = &UsageInfo{
PromptTokens: int(aws.ToInt32(output.Usage.InputTokens)),
CompletionTokens: int(aws.ToInt32(output.Usage.OutputTokens)),
TotalTokens: int(aws.ToInt32(output.Usage.InputTokens)) + int(aws.ToInt32(output.Usage.OutputTokens)),
}
}
return &LLMResponse{
Content: content.String(),
ToolCalls: toolCalls,
FinishReason: finishReason,
Usage: usage,
}, nil
}
@@ -0,0 +1,541 @@
//go:build bedrock
// PicoClaw - Ultra-lightweight personal AI agent
// License: MIT
//
// Copyright (c) 2026 PicoClaw contributors
package bedrock
import (
"testing"
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/service/bedrockruntime"
"github.com/aws/aws-sdk-go-v2/service/bedrockruntime/document"
"github.com/aws/aws-sdk-go-v2/service/bedrockruntime/types"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/sipeed/picoclaw/pkg/providers/protocoltypes"
)
func TestConvertMessages_SystemPrompts(t *testing.T) {
messages := []Message{
{Role: "system", Content: "You are a helpful assistant."},
{Role: "user", Content: "Hello"},
}
bedrockMsgs, systemPrompts := convertMessages(messages)
assert.Len(t, systemPrompts, 1)
assert.Len(t, bedrockMsgs, 1)
// Check system prompt
textBlock, ok := systemPrompts[0].(*types.SystemContentBlockMemberText)
require.True(t, ok)
assert.Equal(t, "You are a helpful assistant.", textBlock.Value)
// Check user message
assert.Equal(t, types.ConversationRoleUser, bedrockMsgs[0].Role)
}
func TestConvertMessages_UserMessage(t *testing.T) {
messages := []Message{
{Role: "user", Content: "What is 2+2?"},
}
bedrockMsgs, systemPrompts := convertMessages(messages)
assert.Empty(t, systemPrompts)
assert.Len(t, bedrockMsgs, 1)
assert.Equal(t, types.ConversationRoleUser, bedrockMsgs[0].Role)
textBlock, ok := bedrockMsgs[0].Content[0].(*types.ContentBlockMemberText)
require.True(t, ok)
assert.Equal(t, "What is 2+2?", textBlock.Value)
}
func TestConvertMessages_AssistantMessage(t *testing.T) {
messages := []Message{
{Role: "assistant", Content: "The answer is 4."},
}
bedrockMsgs, _ := convertMessages(messages)
assert.Len(t, bedrockMsgs, 1)
assert.Equal(t, types.ConversationRoleAssistant, bedrockMsgs[0].Role)
textBlock, ok := bedrockMsgs[0].Content[0].(*types.ContentBlockMemberText)
require.True(t, ok)
assert.Equal(t, "The answer is 4.", textBlock.Value)
}
func TestConvertMessages_ToolResult(t *testing.T) {
messages := []Message{
{Role: "tool", Content: "Result from tool", ToolCallID: "call_123"},
}
bedrockMsgs, _ := convertMessages(messages)
assert.Len(t, bedrockMsgs, 1)
assert.Equal(t, types.ConversationRoleUser, bedrockMsgs[0].Role)
toolResult, ok := bedrockMsgs[0].Content[0].(*types.ContentBlockMemberToolResult)
require.True(t, ok)
assert.Equal(t, "call_123", aws.ToString(toolResult.Value.ToolUseId))
}
func TestConvertMessages_MultipleToolResultsMerged(t *testing.T) {
// When an assistant makes multiple tool calls, all tool results must be
// merged into a single user message for Bedrock
messages := []Message{
{Role: "user", Content: "What's the weather in NYC and LA?"},
{
Role: "assistant",
Content: "Let me check both cities.",
ToolCalls: []protocoltypes.ToolCall{
{ID: "call_nyc", Name: "get_weather", Arguments: map[string]any{"city": "NYC"}},
{ID: "call_la", Name: "get_weather", Arguments: map[string]any{"city": "LA"}},
},
},
{Role: "tool", Content: "NYC: 72°F, sunny", ToolCallID: "call_nyc"},
{Role: "tool", Content: "LA: 85°F, clear", ToolCallID: "call_la"},
}
bedrockMsgs, _ := convertMessages(messages)
// Should be: user message, assistant message, merged tool results (single user message)
assert.Len(t, bedrockMsgs, 3)
// First message: user
assert.Equal(t, types.ConversationRoleUser, bedrockMsgs[0].Role)
// Second message: assistant with tool calls
assert.Equal(t, types.ConversationRoleAssistant, bedrockMsgs[1].Role)
// Third message: merged tool results in single user message
assert.Equal(t, types.ConversationRoleUser, bedrockMsgs[2].Role)
assert.Len(t, bedrockMsgs[2].Content, 2) // Both tool results in one message
// Verify both tool results are present
result1, ok := bedrockMsgs[2].Content[0].(*types.ContentBlockMemberToolResult)
require.True(t, ok)
assert.Equal(t, "call_nyc", aws.ToString(result1.Value.ToolUseId))
result2, ok := bedrockMsgs[2].Content[1].(*types.ContentBlockMemberToolResult)
require.True(t, ok)
assert.Equal(t, "call_la", aws.ToString(result2.Value.ToolUseId))
}
func TestConvertMessages_AssistantWithToolCalls(t *testing.T) {
messages := []Message{
{
Role: "assistant",
Content: "Let me calculate that.",
ToolCalls: []protocoltypes.ToolCall{
{
ID: "call_456",
Name: "calculator",
Arguments: map[string]any{"expression": "2+2"},
},
},
},
}
bedrockMsgs, _ := convertMessages(messages)
assert.Len(t, bedrockMsgs, 1)
assert.Len(t, bedrockMsgs[0].Content, 2) // text + tool use
// Check text content
textBlock, ok := bedrockMsgs[0].Content[0].(*types.ContentBlockMemberText)
require.True(t, ok)
assert.Equal(t, "Let me calculate that.", textBlock.Value)
// Check tool use
toolUse, ok := bedrockMsgs[0].Content[1].(*types.ContentBlockMemberToolUse)
require.True(t, ok)
assert.Equal(t, "call_456", aws.ToString(toolUse.Value.ToolUseId))
assert.Equal(t, "calculator", aws.ToString(toolUse.Value.Name))
}
func TestConvertTools_Basic(t *testing.T) {
tools := []ToolDefinition{
{
Function: protocoltypes.ToolFunctionDefinition{
Name: "get_weather",
Description: "Get the current weather",
Parameters: map[string]any{
"type": "object",
"properties": map[string]any{
"location": map[string]any{"type": "string"},
},
},
},
},
}
toolConfig := convertTools(tools)
assert.NotNil(t, toolConfig)
assert.Len(t, toolConfig.Tools, 1)
toolSpec, ok := toolConfig.Tools[0].(*types.ToolMemberToolSpec)
require.True(t, ok)
assert.Equal(t, "get_weather", aws.ToString(toolSpec.Value.Name))
assert.Equal(t, "Get the current weather", aws.ToString(toolSpec.Value.Description))
}
func TestConvertTools_SkipsEmptyName(t *testing.T) {
tools := []ToolDefinition{
{
Function: protocoltypes.ToolFunctionDefinition{
Name: "",
Description: "Empty name tool",
},
},
{
Function: protocoltypes.ToolFunctionDefinition{
Name: " ",
Description: "Whitespace name tool",
},
},
{
Function: protocoltypes.ToolFunctionDefinition{
Name: "valid_tool",
Description: "Valid tool",
},
},
}
toolConfig := convertTools(tools)
assert.Len(t, toolConfig.Tools, 1)
toolSpec := toolConfig.Tools[0].(*types.ToolMemberToolSpec)
assert.Equal(t, "valid_tool", aws.ToString(toolSpec.Value.Name))
}
func TestConvertTools_NilParameters(t *testing.T) {
tools := []ToolDefinition{
{
Function: protocoltypes.ToolFunctionDefinition{
Name: "simple_tool",
Description: "A tool with no parameters",
Parameters: nil,
},
},
}
toolConfig := convertTools(tools)
assert.Len(t, toolConfig.Tools, 1)
// Should not panic and should create a valid tool
}
func TestBuildUserContent_TextOnly(t *testing.T) {
msg := Message{Content: "Hello world"}
content := buildUserContent(msg)
assert.Len(t, content, 1)
textBlock, ok := content[0].(*types.ContentBlockMemberText)
require.True(t, ok)
assert.Equal(t, "Hello world", textBlock.Value)
}
func TestBuildUserContent_WithImage(t *testing.T) {
// Base64-encoded 1x1 PNG (the provider doesn't validate image correctness,
// it just verifies the format and base64 decoding works)
b64Data := "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAIAAACQd1PeAAAADUlEQVR4nGNgYAAAAAMAASsJTYQAAAAASUVORK5CYII="
msg := Message{
Content: "Look at this image",
Media: []string{"data:image/png;base64," + b64Data},
}
content := buildUserContent(msg)
assert.Len(t, content, 2)
// Check text
textBlock, ok := content[0].(*types.ContentBlockMemberText)
require.True(t, ok)
assert.Equal(t, "Look at this image", textBlock.Value)
// Check image
imageBlock, ok := content[1].(*types.ContentBlockMemberImage)
require.True(t, ok)
assert.Equal(t, types.ImageFormatPng, imageBlock.Value.Format)
}
func TestBuildUserContent_SkipsInvalidBase64(t *testing.T) {
msg := Message{
Content: "Invalid image",
Media: []string{"data:image/png;base64,not-valid-base64!!!"},
}
content := buildUserContent(msg)
// Should only have text, image should be skipped
assert.Len(t, content, 1)
}
func TestBuildUserContent_SkipsNonBase64Data(t *testing.T) {
msg := Message{
Content: "Non-base64 image",
Media: []string{"data:image/png,raw-data-here"},
}
content := buildUserContent(msg)
// Should only have text, non-base64 image should be skipped
assert.Len(t, content, 1)
}
func TestBuildAssistantContent_SkipsEmptyToolName(t *testing.T) {
msg := Message{
Content: "Response",
ToolCalls: []protocoltypes.ToolCall{
{ID: "1", Name: "", Arguments: map[string]any{}},
{ID: "2", Name: " ", Arguments: map[string]any{}},
{ID: "3", Name: "valid", Arguments: map[string]any{}},
},
}
content := buildAssistantContent(msg)
// Should have text + 1 valid tool
assert.Len(t, content, 2)
}
func TestBuildAssistantContent_NilArguments(t *testing.T) {
msg := Message{
ToolCalls: []protocoltypes.ToolCall{
{ID: "1", Name: "tool", Arguments: nil},
},
}
content := buildAssistantContent(msg)
assert.Len(t, content, 1)
toolUse, ok := content[0].(*types.ContentBlockMemberToolUse)
require.True(t, ok)
assert.NotNil(t, toolUse.Value.Input)
}
func TestBuildAssistantContent_FunctionFallback(t *testing.T) {
// When Name/Arguments are empty (json:"-"), should fallback to Function fields
msg := Message{
ToolCalls: []protocoltypes.ToolCall{
{
ID: "1",
Name: "", // empty, should fallback to Function.Name
Function: &protocoltypes.FunctionCall{
Name: "fallback_tool",
Arguments: `{"key":"value"}`,
},
},
},
}
content := buildAssistantContent(msg)
assert.Len(t, content, 1)
toolUse, ok := content[0].(*types.ContentBlockMemberToolUse)
require.True(t, ok)
assert.Equal(t, "fallback_tool", aws.ToString(toolUse.Value.Name))
}
func TestParseResponse_TextOnly(t *testing.T) {
output := &bedrockruntime.ConverseOutput{
Output: &types.ConverseOutputMemberMessage{
Value: types.Message{
Role: types.ConversationRoleAssistant,
Content: []types.ContentBlock{
&types.ContentBlockMemberText{Value: "Hello!"},
},
},
},
StopReason: types.StopReasonEndTurn,
Usage: &types.TokenUsage{
InputTokens: aws.Int32(10),
OutputTokens: aws.Int32(5),
},
}
resp, err := parseResponse(output)
require.NoError(t, err)
assert.Equal(t, "Hello!", resp.Content)
assert.Equal(t, "stop", resp.FinishReason)
assert.Empty(t, resp.ToolCalls)
assert.Equal(t, 10, resp.Usage.PromptTokens)
assert.Equal(t, 5, resp.Usage.CompletionTokens)
}
func TestParseResponse_StopReasons(t *testing.T) {
tests := []struct {
stopReason types.StopReason
expectedFinish string
}{
{types.StopReasonEndTurn, "stop"},
{types.StopReasonToolUse, "tool_calls"},
{types.StopReasonMaxTokens, "length"},
{types.StopReasonStopSequence, "stop"},
{types.StopReasonContentFiltered, "content_filter"},
}
for _, tt := range tests {
t.Run(string(tt.stopReason), func(t *testing.T) {
output := &bedrockruntime.ConverseOutput{
Output: &types.ConverseOutputMemberMessage{
Value: types.Message{
Content: []types.ContentBlock{
&types.ContentBlockMemberText{Value: "test"},
},
},
},
StopReason: tt.stopReason,
}
resp, err := parseResponse(output)
require.NoError(t, err)
assert.Equal(t, tt.expectedFinish, resp.FinishReason)
})
}
}
func TestParseResponse_WithToolCalls(t *testing.T) {
// Note: document.NewLazyDocument has limitations with UnmarshalSmithyDocument in tests,
// so we test the structure extraction and verify Arguments gets populated (even if empty
// due to SDK limitations). The actual unmarshal works correctly at runtime.
toolInput := document.NewLazyDocument(map[string]any{
"location": "San Francisco",
"unit": "celsius",
})
output := &bedrockruntime.ConverseOutput{
Output: &types.ConverseOutputMemberMessage{
Value: types.Message{
Role: types.ConversationRoleAssistant,
Content: []types.ContentBlock{
&types.ContentBlockMemberText{Value: "Let me check the weather."},
&types.ContentBlockMemberToolUse{
Value: types.ToolUseBlock{
ToolUseId: aws.String("call_weather_123"),
Name: aws.String("get_weather"),
Input: toolInput,
},
},
},
},
},
StopReason: types.StopReasonToolUse,
Usage: &types.TokenUsage{
InputTokens: aws.Int32(20),
OutputTokens: aws.Int32(15),
},
}
resp, err := parseResponse(output)
require.NoError(t, err)
assert.Equal(t, "Let me check the weather.", resp.Content)
assert.Equal(t, "tool_calls", resp.FinishReason)
assert.Len(t, resp.ToolCalls, 1)
// Verify tool call ID and Name are extracted correctly
tc := resp.ToolCalls[0]
assert.Equal(t, "call_weather_123", tc.ID)
assert.Equal(t, "get_weather", tc.Name)
// Verify Function fields are also populated
require.NotNil(t, tc.Function)
assert.Equal(t, "get_weather", tc.Function.Name)
// Verify Arguments is not nil (content may vary due to SDK limitations in tests)
assert.NotNil(t, tc.Arguments)
// Verify usage
assert.Equal(t, 20, resp.Usage.PromptTokens)
assert.Equal(t, 15, resp.Usage.CompletionTokens)
assert.Equal(t, 35, resp.Usage.TotalTokens)
}
func TestParseResponse_MultipleToolCalls(t *testing.T) {
output := &bedrockruntime.ConverseOutput{
Output: &types.ConverseOutputMemberMessage{
Value: types.Message{
Role: types.ConversationRoleAssistant,
Content: []types.ContentBlock{
&types.ContentBlockMemberToolUse{
Value: types.ToolUseBlock{
ToolUseId: aws.String("call_1"),
Name: aws.String("tool_a"),
Input: document.NewLazyDocument(map[string]any{"arg": "value1"}),
},
},
&types.ContentBlockMemberToolUse{
Value: types.ToolUseBlock{
ToolUseId: aws.String("call_2"),
Name: aws.String("tool_b"),
Input: document.NewLazyDocument(map[string]any{"arg": "value2"}),
},
},
},
},
},
StopReason: types.StopReasonToolUse,
}
resp, err := parseResponse(output)
require.NoError(t, err)
assert.Equal(t, "tool_calls", resp.FinishReason)
assert.Len(t, resp.ToolCalls, 2)
// Verify tool call structure
assert.Equal(t, "call_1", resp.ToolCalls[0].ID)
assert.Equal(t, "tool_a", resp.ToolCalls[0].Name)
assert.NotNil(t, resp.ToolCalls[0].Arguments)
assert.NotNil(t, resp.ToolCalls[0].Function)
assert.Equal(t, "tool_a", resp.ToolCalls[0].Function.Name)
assert.Equal(t, "call_2", resp.ToolCalls[1].ID)
assert.Equal(t, "tool_b", resp.ToolCalls[1].Name)
assert.NotNil(t, resp.ToolCalls[1].Arguments)
assert.NotNil(t, resp.ToolCalls[1].Function)
assert.Equal(t, "tool_b", resp.ToolCalls[1].Function.Name)
}
func TestParseResponse_ToolCallWithNilInput(t *testing.T) {
output := &bedrockruntime.ConverseOutput{
Output: &types.ConverseOutputMemberMessage{
Value: types.Message{
Role: types.ConversationRoleAssistant,
Content: []types.ContentBlock{
&types.ContentBlockMemberToolUse{
Value: types.ToolUseBlock{
ToolUseId: aws.String("call_nil"),
Name: aws.String("no_args_tool"),
Input: nil,
},
},
},
},
},
StopReason: types.StopReasonToolUse,
}
resp, err := parseResponse(output)
require.NoError(t, err)
assert.Len(t, resp.ToolCalls, 1)
assert.Equal(t, "call_nil", resp.ToolCalls[0].ID)
assert.Equal(t, "no_args_tool", resp.ToolCalls[0].Name)
// Arguments should be empty map, not nil
assert.NotNil(t, resp.ToolCalls[0].Arguments)
assert.Empty(t, resp.ToolCalls[0].Arguments)
}
+73
View File
@@ -0,0 +1,73 @@
//go:build !bedrock
// PicoClaw - Ultra-lightweight personal AI agent
// License: MIT
//
// Copyright (c) 2026 PicoClaw contributors
// Package bedrock provides a stub implementation when built without the bedrock tag.
// To enable AWS Bedrock support, build with: go build -tags bedrock
package bedrock
import (
"context"
"fmt"
"time"
"github.com/sipeed/picoclaw/pkg/providers/protocoltypes"
)
type (
LLMResponse = protocoltypes.LLMResponse
Message = protocoltypes.Message
ToolDefinition = protocoltypes.ToolDefinition
)
// Provider is a stub that returns an error when Bedrock support is not compiled in.
type Provider struct{}
// Option is a no-op when Bedrock is not enabled.
type Option func(*providerConfig)
type providerConfig struct{}
// WithRegion is a no-op when Bedrock is not enabled.
func WithRegion(region string) Option {
return func(c *providerConfig) {}
}
// WithProfile is a no-op when Bedrock is not enabled.
func WithProfile(profile string) Option {
return func(c *providerConfig) {}
}
// WithBaseEndpoint is a no-op when Bedrock is not enabled.
func WithBaseEndpoint(endpoint string) Option {
return func(c *providerConfig) {}
}
// WithRequestTimeout is a no-op when Bedrock is not enabled.
func WithRequestTimeout(timeout time.Duration) Option {
return func(c *providerConfig) {}
}
// NewProvider returns an error indicating Bedrock support is not compiled in.
func NewProvider(ctx context.Context, opts ...Option) (*Provider, error) {
return nil, fmt.Errorf("bedrock provider not available: build with -tags bedrock to enable AWS Bedrock support")
}
// Chat returns an error - this should never be called since NewProvider fails.
func (p *Provider) Chat(
ctx context.Context,
messages []Message,
tools []ToolDefinition,
model string,
options map[string]any,
) (*LLMResponse, error) {
return nil, fmt.Errorf("bedrock provider not available: build with -tags bedrock to enable AWS Bedrock support")
}
// GetDefaultModel returns an empty string.
func (p *Provider) GetDefaultModel() string {
return ""
}
@@ -0,0 +1,35 @@
//go:build !bedrock
// PicoClaw - Ultra-lightweight personal AI agent
// License: MIT
//
// Copyright (c) 2026 PicoClaw contributors
package bedrock
import (
"context"
"strings"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestNewProvider_ReturnsStubError(t *testing.T) {
provider, err := NewProvider(context.Background())
assert.Nil(t, provider)
require.Error(t, err)
assert.True(t, strings.Contains(err.Error(), "build with -tags bedrock"),
"error should mention build tag requirement, got: %s", err.Error())
}
func TestNewProvider_WithOptions_ReturnsStubError(t *testing.T) {
provider, err := NewProvider(context.Background(), WithRegion("us-west-2"), WithProfile("test"))
assert.Nil(t, provider)
require.Error(t, err)
assert.True(t, strings.Contains(err.Error(), "build with -tags bedrock"),
"error should mention build tag requirement, got: %s", err.Error())
}
+42 -2
View File
@@ -6,12 +6,15 @@
package providers package providers
import ( import (
"context"
"fmt" "fmt"
"strings" "strings"
"time"
"github.com/sipeed/picoclaw/pkg/config" "github.com/sipeed/picoclaw/pkg/config"
anthropicmessages "github.com/sipeed/picoclaw/pkg/providers/anthropic_messages" anthropicmessages "github.com/sipeed/picoclaw/pkg/providers/anthropic_messages"
"github.com/sipeed/picoclaw/pkg/providers/azure" "github.com/sipeed/picoclaw/pkg/providers/azure"
"github.com/sipeed/picoclaw/pkg/providers/bedrock"
) )
// createClaudeAuthProvider creates a Claude provider using OAuth credentials from auth store. // createClaudeAuthProvider creates a Claude provider using OAuth credentials from auth store.
@@ -55,8 +58,9 @@ func ExtractProtocol(model string) (protocol, modelID string) {
// CreateProviderFromConfig creates a provider based on the ModelConfig. // CreateProviderFromConfig creates a provider based on the ModelConfig.
// It uses the protocol prefix in the Model field to determine which provider to create. // It uses the protocol prefix in the Model field to determine which provider to create.
// Supported protocols: openai, litellm, novita, anthropic, anthropic-messages, // Supported protocol families include OpenAI-compatible prefixes (e.g., openai, openrouter, groq, gemini),
// antigravity, claude-cli, codex-cli, github-copilot // Azure OpenAI, Amazon Bedrock, Anthropic (including messages), and various CLI/compatibility shims.
// See the switch on protocol in this function for the authoritative list.
// Returns the provider, the model ID (without protocol prefix), and any error. // Returns the provider, the model ID (without protocol prefix), and any error.
func CreateProviderFromConfig(cfg *config.ModelConfig) (LLMProvider, string, error) { func CreateProviderFromConfig(cfg *config.ModelConfig) (LLMProvider, string, error) {
if cfg == nil { if cfg == nil {
@@ -114,6 +118,42 @@ func CreateProviderFromConfig(cfg *config.ModelConfig) (LLMProvider, string, err
cfg.RequestTimeout, cfg.RequestTimeout,
), modelID, nil ), modelID, nil
case "bedrock":
// AWS Bedrock uses AWS SDK credentials (env vars, profiles, IAM roles, etc.)
// api_base can be:
// - A full endpoint URL: https://bedrock-runtime.us-east-1.amazonaws.com
// - A region name: us-east-1 (AWS SDK resolves endpoint automatically)
var opts []bedrock.Option
if cfg.APIBase != "" {
if !strings.Contains(cfg.APIBase, "://") {
// Treat as region: let AWS SDK resolve the correct endpoint
// (supports all AWS partitions: aws, aws-cn, aws-us-gov, etc.)
opts = append(opts, bedrock.WithRegion(cfg.APIBase))
} else {
// Full endpoint URL provided (for custom endpoints or testing)
opts = append(opts, bedrock.WithBaseEndpoint(cfg.APIBase))
}
}
// Use a separate timeout for AWS config loading (credential resolution can block)
initTimeout := 30 * time.Second
if cfg.RequestTimeout > 0 {
reqTimeout := time.Duration(cfg.RequestTimeout) * time.Second
// Set request timeout for API calls
opts = append(opts, bedrock.WithRequestTimeout(reqTimeout))
// Ensure init timeout is at least as large as request timeout
if reqTimeout > initTimeout {
initTimeout = reqTimeout
}
}
ctx, cancel := context.WithTimeout(context.Background(), initTimeout)
defer cancel()
// Note: AWS_PROFILE env var is automatically used by AWS SDK
provider, err := bedrock.NewProvider(ctx, opts...)
if err != nil {
return nil, "", fmt.Errorf("creating bedrock provider: %w", err)
}
return provider, modelID, nil
case "litellm", "openrouter", "groq", "zhipu", "gemini", "nvidia", case "litellm", "openrouter", "groq", "zhipu", "gemini", "nvidia",
"ollama", "moonshot", "shengsuanyun", "deepseek", "cerebras", "ollama", "moonshot", "shengsuanyun", "deepseek", "cerebras",
"vivgrid", "volcengine", "vllm", "qwen", "qwen-intl", "qwen-international", "dashscope-intl", "vivgrid", "volcengine", "vllm", "qwen", "qwen-intl", "qwen-international", "dashscope-intl",
+75
View File
@@ -700,3 +700,78 @@ func TestCreateProviderFromConfig_MinimaxPreservesUserExtraBody(t *testing.T) {
t.Fatalf("custom_field = %v, want test", got) t.Fatalf("custom_field = %v, want test", got)
} }
} }
func TestCreateProviderFromConfig_Bedrock(t *testing.T) {
// Set dummy AWS env vars to make test deterministic
t.Setenv("AWS_ACCESS_KEY_ID", "test-key")
t.Setenv("AWS_SECRET_ACCESS_KEY", "test-secret")
t.Setenv("AWS_EC2_METADATA_DISABLED", "true")
// Clear profile-related env vars to avoid loading shared config
t.Setenv("AWS_PROFILE", "")
t.Setenv("AWS_DEFAULT_PROFILE", "")
t.Setenv("AWS_SDK_LOAD_CONFIG", "")
t.Setenv("AWS_SHARED_CREDENTIALS_FILE", "")
cfg := &config.ModelConfig{
ModelName: "bedrock-claude",
Model: "bedrock/us.anthropic.claude-sonnet-4-20250514-v1:0",
APIBase: "us-west-2", // Region (also sets AWS region)
}
provider, modelID, err := CreateProviderFromConfig(cfg)
if err == nil {
// Provider created successfully (built with -tags bedrock)
if provider == nil {
t.Error("provider is nil on success")
}
if modelID != "us.anthropic.claude-sonnet-4-20250514-v1:0" {
t.Errorf("modelID = %q, want %q", modelID, "us.anthropic.claude-sonnet-4-20250514-v1:0")
}
return
}
errMsg := err.Error()
// When built without -tags bedrock, expect stub error
if strings.Contains(errMsg, "build with -tags bedrock") {
return // Expected stub error
}
// Unexpected error - fail the test
t.Errorf("unexpected error from bedrock provider: %v", err)
}
func TestCreateProviderFromConfig_BedrockWithEndpointURL(t *testing.T) {
// Set dummy AWS env vars to make test deterministic
t.Setenv("AWS_ACCESS_KEY_ID", "test-key")
t.Setenv("AWS_SECRET_ACCESS_KEY", "test-secret")
t.Setenv("AWS_REGION", "us-east-1") // Required when using endpoint URL
t.Setenv("AWS_EC2_METADATA_DISABLED", "true")
// Clear profile-related env vars to avoid loading shared config
t.Setenv("AWS_PROFILE", "")
t.Setenv("AWS_DEFAULT_PROFILE", "")
t.Setenv("AWS_SDK_LOAD_CONFIG", "")
t.Setenv("AWS_SHARED_CREDENTIALS_FILE", "")
cfg := &config.ModelConfig{
ModelName: "bedrock-claude",
Model: "bedrock/us.anthropic.claude-sonnet-4-20250514-v1:0",
APIBase: "https://bedrock-runtime.us-east-1.amazonaws.com", // Full endpoint URL
}
provider, modelID, err := CreateProviderFromConfig(cfg)
if err == nil {
// Provider created successfully (built with -tags bedrock)
if provider == nil {
t.Error("provider is nil on success")
}
if modelID != "us.anthropic.claude-sonnet-4-20250514-v1:0" {
t.Errorf("modelID = %q, want %q", modelID, "us.anthropic.claude-sonnet-4-20250514-v1:0")
}
return
}
errMsg := err.Error()
// When built without -tags bedrock, expect stub error
if strings.Contains(errMsg, "build with -tags bedrock") {
return // Expected stub error
}
// Unexpected error - fail the test
t.Errorf("unexpected error from bedrock provider: %v", err)
}