mirror of
https://github.com/sipeed/picoclaw.git
synced 2026-06-12 18:08:54 +00:00
feat(providers): add AWS Bedrock provider (#1903)
Add support for AWS Bedrock as an LLM provider using the Converse API. The implementation is behind a build tag (-tags bedrock) to keep the default binary size small. Features: - AWS SDK v2 with automatic credential chain (env vars, profiles, IAM roles) - Converse API for unified access to Claude, Llama, Mistral models - Tool/function calling support with proper document handling - Image support with base64 decoding and size limits - Request timeout configuration - Region validation and endpoint resolution for all AWS partitions Usage: go build -tags bedrock model: bedrock/us.anthropic.claude-sonnet-4-20250514-v1:0 api_base: us-east-1 (or full endpoint URL)
This commit is contained in:
@@ -373,6 +373,9 @@ PicoClaw supports 30+ LLM providers through the `model_list` configuration. Use
|
||||
| [Azure OpenAI](https://portal.azure.com/) | `azure/` | Required | Enterprise Azure deployment |
|
||||
| [GitHub Copilot](https://github.com/features/copilot) | `github-copilot/` | OAuth | Device code login |
|
||||
| [Antigravity](https://console.cloud.google.com/) | `antigravity/` | OAuth | Google Cloud AI |
|
||||
| [AWS Bedrock](https://console.aws.amazon.com/bedrock)* | `bedrock/` | AWS credentials | Claude, Llama, Mistral on AWS |
|
||||
|
||||
> \* AWS Bedrock requires build tag: `go build -tags bedrock`. Set `api_base` to a region name (e.g., `us-east-1`) for automatic endpoint resolution across all AWS partitions (aws, aws-cn, aws-us-gov). When using a full endpoint URL instead, you must also configure `AWS_REGION` via environment variable or AWS config/profile.
|
||||
|
||||
<details>
|
||||
<summary><b>Local deployment (Ollama, vLLM, etc.)</b></summary>
|
||||
|
||||
@@ -7,6 +7,9 @@ require (
|
||||
github.com/BurntSushi/toml v1.6.0
|
||||
github.com/adhocore/gronx v1.19.6
|
||||
github.com/anthropics/anthropic-sdk-go v1.26.0
|
||||
github.com/aws/aws-sdk-go-v2 v1.41.4
|
||||
github.com/aws/aws-sdk-go-v2/config v1.32.12
|
||||
github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.50.2
|
||||
github.com/bwmarrin/discordgo v0.29.0
|
||||
github.com/caarlos0/env/v11 v11.4.0
|
||||
github.com/ergochat/irc-go v0.6.0
|
||||
@@ -40,6 +43,19 @@ require (
|
||||
|
||||
require (
|
||||
filippo.io/edwards25519 v1.2.0 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.7 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/credentials v1.19.12 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.20 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.20 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.20 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/internal/ini v1.8.6 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.7 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.20 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/service/signin v1.0.8 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/service/sso v1.30.13 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.17 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/service/sts v1.41.9 // indirect
|
||||
github.com/aws/smithy-go v1.24.2 // indirect
|
||||
github.com/beeper/argo-go v1.1.2 // indirect
|
||||
github.com/coder/websocket v1.8.14 // indirect
|
||||
github.com/davecgh/go-spew v1.1.1 // indirect
|
||||
|
||||
@@ -17,6 +17,38 @@ github.com/andybalholm/brotli v1.2.0 h1:ukwgCxwYrmACq68yiUqwIWnGY0cTPox/M94sVwTo
|
||||
github.com/andybalholm/brotli v1.2.0/go.mod h1:rzTDkvFWvIrjDXZHkuS16NPggd91W3kUSvPlQ1pLaKY=
|
||||
github.com/anthropics/anthropic-sdk-go v1.26.0 h1:oUTzFaUpAevfuELAP1sjL6CQJ9HHAfT7CoSYSac11PY=
|
||||
github.com/anthropics/anthropic-sdk-go v1.26.0/go.mod h1:qUKmaW+uuPB64iy1l+4kOSvaLqPXnHTTBKH6RVZ7q5Q=
|
||||
github.com/aws/aws-sdk-go-v2 v1.41.4 h1:10f50G7WyU02T56ox1wWXq+zTX9I1zxG46HYuG1hH/k=
|
||||
github.com/aws/aws-sdk-go-v2 v1.41.4/go.mod h1:mwsPRE8ceUUpiTgF7QmQIJ7lgsKUPQOUl3o72QBrE1o=
|
||||
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.7 h1:3kGOqnh1pPeddVa/E37XNTaWJ8W6vrbYV9lJEkCnhuY=
|
||||
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.7/go.mod h1:lyw7GFp3qENLh7kwzf7iMzAxDn+NzjXEAGjKS2UOKqI=
|
||||
github.com/aws/aws-sdk-go-v2/config v1.32.12 h1:O3csC7HUGn2895eNrLytOJQdoL2xyJy0iYXhoZ1OmP0=
|
||||
github.com/aws/aws-sdk-go-v2/config v1.32.12/go.mod h1:96zTvoOFR4FURjI+/5wY1vc1ABceROO4lWgWJuxgy0g=
|
||||
github.com/aws/aws-sdk-go-v2/credentials v1.19.12 h1:oqtA6v+y5fZg//tcTWahyN9PEn5eDU/Wpvc2+kJ4aY8=
|
||||
github.com/aws/aws-sdk-go-v2/credentials v1.19.12/go.mod h1:U3R1RtSHx6NB0DvEQFGyf/0sbrpJrluENHdPy1j/3TE=
|
||||
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.20 h1:zOgq3uezl5nznfoK3ODuqbhVg1JzAGDUhXOsU0IDCAo=
|
||||
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.20/go.mod h1:z/MVwUARehy6GAg/yQ1GO2IMl0k++cu1ohP9zo887wE=
|
||||
github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.20 h1:CNXO7mvgThFGqOFgbNAP2nol2qAWBOGfqR/7tQlvLmc=
|
||||
github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.20/go.mod h1:oydPDJKcfMhgfcgBUZaG+toBbwy8yPWubJXBVERtI4o=
|
||||
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.20 h1:tN6W/hg+pkM+tf9XDkWUbDEjGLb+raoBMFsTodcoYKw=
|
||||
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.20/go.mod h1:YJ898MhD067hSHA6xYCx5ts/jEd8BSOLtQDL3iZsvbc=
|
||||
github.com/aws/aws-sdk-go-v2/internal/ini v1.8.6 h1:qYQ4pzQ2Oz6WpQ8T3HvGHnZydA72MnLuFK9tJwmrbHw=
|
||||
github.com/aws/aws-sdk-go-v2/internal/ini v1.8.6/go.mod h1:O3h0IK87yXci+kg6flUKzJnWeziQUKciKrLjcatSNcY=
|
||||
github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.50.2 h1:x0eGAWpd1B5I/vMtrB4Q4Zuc3CXWI8wjHfPPqBSrKmM=
|
||||
github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.50.2/go.mod h1:V9oTWSDC2MtS1DR71hbNET/bZ8psQp022amEBe1grJc=
|
||||
github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.7 h1:5EniKhLZe4xzL7a+fU3C2tfUN4nWIqlLesfrjkuPFTY=
|
||||
github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.7/go.mod h1:x0nZssQ3qZSnIcePWLvcoFisRXJzcTVvYpAAdYX8+GI=
|
||||
github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.20 h1:2HvVAIq+YqgGotK6EkMf+KIEqTISmTYh5zLpYyeTo1Y=
|
||||
github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.20/go.mod h1:V4X406Y666khGa8ghKmphma/7C0DAtEQYhkq9z4vpbk=
|
||||
github.com/aws/aws-sdk-go-v2/service/signin v1.0.8 h1:0GFOLzEbOyZABS3PhYfBIx2rNBACYcKty+XGkTgw1ow=
|
||||
github.com/aws/aws-sdk-go-v2/service/signin v1.0.8/go.mod h1:LXypKvk85AROkKhOG6/YEcHFPoX+prKTowKnVdcaIxE=
|
||||
github.com/aws/aws-sdk-go-v2/service/sso v1.30.13 h1:kiIDLZ005EcKomYYITtfsjn7dtOwHDOFy7IbPXKek2o=
|
||||
github.com/aws/aws-sdk-go-v2/service/sso v1.30.13/go.mod h1:2h/xGEowcW/g38g06g3KpRWDlT+OTfxxI0o1KqayAB8=
|
||||
github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.17 h1:jzKAXIlhZhJbnYwHbvUQZEB8KfgAEuG0dc08Bkda7NU=
|
||||
github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.17/go.mod h1:Al9fFsXjv4KfbzQHGe6V4NZSZQXecFcvaIF4e70FoRA=
|
||||
github.com/aws/aws-sdk-go-v2/service/sts v1.41.9 h1:Cng+OOwCHmFljXIxpEVXAGMnBia8MSU6Ch5i9PgBkcU=
|
||||
github.com/aws/aws-sdk-go-v2/service/sts v1.41.9/go.mod h1:LrlIndBDdjA/EeXeyNBle+gyCwTlizzW5ycgWnvIxkk=
|
||||
github.com/aws/smithy-go v1.24.2 h1:FzA3bu/nt/vDvmnkg+R8Xl46gmzEDam6mZ1hzmwXFng=
|
||||
github.com/aws/smithy-go v1.24.2/go.mod h1:YE2RhdIuDbA5E5bTdciG9KrW3+TiEONeUWCqxX9i1Fc=
|
||||
github.com/beeper/argo-go v1.1.2 h1:UQI2G8F+NLfGTOmTUI0254pGKx/HUU/etbUGTJv91Fs=
|
||||
github.com/beeper/argo-go v1.1.2/go.mod h1:M+LJAnyowKVQ6Rdj6XYGEn+qcVFkb3R/MUpqkGR0hM4=
|
||||
github.com/bwmarrin/discordgo v0.29.0 h1:FmWeXFaKUwrcL3Cx65c20bTRW+vOb6k8AnaP+EgjDno=
|
||||
|
||||
@@ -0,0 +1,580 @@
|
||||
//go:build bedrock
|
||||
|
||||
// PicoClaw - Ultra-lightweight personal AI agent
|
||||
// License: MIT
|
||||
//
|
||||
// Copyright (c) 2026 PicoClaw contributors
|
||||
|
||||
// Package bedrock implements the LLM provider interface for AWS Bedrock.
|
||||
// It uses the Bedrock Runtime Converse API for unified access to multiple
|
||||
// model families (Claude, Llama, Mistral, etc.) with tool/function calling support.
|
||||
package bedrock
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log"
|
||||
"math"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/aws/aws-sdk-go-v2/aws"
|
||||
"github.com/aws/aws-sdk-go-v2/config"
|
||||
"github.com/aws/aws-sdk-go-v2/service/bedrockruntime"
|
||||
"github.com/aws/aws-sdk-go-v2/service/bedrockruntime/document"
|
||||
"github.com/aws/aws-sdk-go-v2/service/bedrockruntime/types"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/providers/common"
|
||||
"github.com/sipeed/picoclaw/pkg/providers/protocoltypes"
|
||||
)
|
||||
|
||||
type (
|
||||
ToolCall = protocoltypes.ToolCall
|
||||
FunctionCall = protocoltypes.FunctionCall
|
||||
LLMResponse = protocoltypes.LLMResponse
|
||||
UsageInfo = protocoltypes.UsageInfo
|
||||
Message = protocoltypes.Message
|
||||
ToolDefinition = protocoltypes.ToolDefinition
|
||||
ToolFunctionDefinition = protocoltypes.ToolFunctionDefinition
|
||||
)
|
||||
|
||||
// Provider implements the LLM provider interface for AWS Bedrock.
|
||||
type Provider struct {
|
||||
client *bedrockruntime.Client
|
||||
region string
|
||||
requestTimeout time.Duration
|
||||
}
|
||||
|
||||
// Option configures the Bedrock Provider.
|
||||
type Option func(*providerConfig)
|
||||
|
||||
type providerConfig struct {
|
||||
region string
|
||||
profile string
|
||||
baseEndpoint string
|
||||
requestTimeout time.Duration
|
||||
}
|
||||
|
||||
// WithRegion sets the AWS region for Bedrock requests.
|
||||
func WithRegion(region string) Option {
|
||||
return func(c *providerConfig) {
|
||||
c.region = region
|
||||
}
|
||||
}
|
||||
|
||||
// WithProfile sets the AWS profile to use for credentials.
|
||||
func WithProfile(profile string) Option {
|
||||
return func(c *providerConfig) {
|
||||
c.profile = profile
|
||||
}
|
||||
}
|
||||
|
||||
// WithBaseEndpoint sets a custom Bedrock endpoint URL.
|
||||
// Example: https://bedrock-runtime.us-east-1.amazonaws.com
|
||||
func WithBaseEndpoint(endpoint string) Option {
|
||||
return func(c *providerConfig) {
|
||||
c.baseEndpoint = endpoint
|
||||
}
|
||||
}
|
||||
|
||||
// WithRequestTimeout sets the timeout for Bedrock API requests.
|
||||
func WithRequestTimeout(timeout time.Duration) Option {
|
||||
return func(c *providerConfig) {
|
||||
c.requestTimeout = timeout
|
||||
}
|
||||
}
|
||||
|
||||
// NewProvider creates a new AWS Bedrock provider.
|
||||
// It uses the default AWS credential chain (env vars, shared config, IAM roles, etc.).
|
||||
func NewProvider(ctx context.Context, opts ...Option) (*Provider, error) {
|
||||
pc := &providerConfig{}
|
||||
for _, opt := range opts {
|
||||
opt(pc)
|
||||
}
|
||||
|
||||
// Build AWS config options
|
||||
var configOpts []func(*config.LoadOptions) error
|
||||
|
||||
if pc.region != "" {
|
||||
configOpts = append(configOpts, config.WithRegion(pc.region))
|
||||
}
|
||||
|
||||
if pc.profile != "" {
|
||||
configOpts = append(configOpts, config.WithSharedConfigProfile(pc.profile))
|
||||
}
|
||||
|
||||
// Load AWS config with automatic credential discovery
|
||||
cfg, err := config.LoadDefaultConfig(ctx, configOpts...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("loading AWS config: %w", err)
|
||||
}
|
||||
|
||||
// Validate region is set - required for Bedrock request signing
|
||||
if cfg.Region == "" {
|
||||
return nil, fmt.Errorf("AWS region not configured: set AWS_REGION, AWS_DEFAULT_REGION, or use WithRegion option")
|
||||
}
|
||||
|
||||
// Build client options
|
||||
var clientOpts []func(*bedrockruntime.Options)
|
||||
if pc.baseEndpoint != "" {
|
||||
clientOpts = append(clientOpts, func(o *bedrockruntime.Options) {
|
||||
o.BaseEndpoint = aws.String(pc.baseEndpoint)
|
||||
})
|
||||
}
|
||||
|
||||
client := bedrockruntime.NewFromConfig(cfg, clientOpts...)
|
||||
|
||||
return &Provider{
|
||||
client: client,
|
||||
region: cfg.Region,
|
||||
requestTimeout: pc.requestTimeout,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Chat sends messages to AWS Bedrock using the Converse API.
|
||||
func (p *Provider) Chat(
|
||||
ctx context.Context,
|
||||
messages []Message,
|
||||
tools []ToolDefinition,
|
||||
model string,
|
||||
options map[string]any,
|
||||
) (*LLMResponse, error) {
|
||||
// Apply request timeout if context doesn't already have a deadline.
|
||||
// Use explicit timeout if set, otherwise fall back to common default.
|
||||
effectiveTimeout := p.requestTimeout
|
||||
if effectiveTimeout <= 0 {
|
||||
effectiveTimeout = common.DefaultRequestTimeout
|
||||
}
|
||||
if _, hasDeadline := ctx.Deadline(); !hasDeadline {
|
||||
var cancel context.CancelFunc
|
||||
ctx, cancel = context.WithTimeout(ctx, effectiveTimeout)
|
||||
defer cancel()
|
||||
}
|
||||
|
||||
// Build the Converse API input
|
||||
input := &bedrockruntime.ConverseInput{
|
||||
ModelId: aws.String(model),
|
||||
}
|
||||
|
||||
// Convert messages to Bedrock format
|
||||
bedrockMessages, systemPrompts := convertMessages(messages)
|
||||
input.Messages = bedrockMessages
|
||||
|
||||
// Set system prompts if any
|
||||
if len(systemPrompts) > 0 {
|
||||
input.System = systemPrompts
|
||||
}
|
||||
|
||||
// Set inference configuration only when options are provided
|
||||
var inferenceConfig *types.InferenceConfiguration
|
||||
|
||||
if maxTokens, ok := common.AsInt(options["max_tokens"]); ok && maxTokens > 0 {
|
||||
if inferenceConfig == nil {
|
||||
inferenceConfig = &types.InferenceConfiguration{}
|
||||
}
|
||||
// Clamp to int32 range to avoid overflow
|
||||
if maxTokens > math.MaxInt32 {
|
||||
maxTokens = math.MaxInt32
|
||||
}
|
||||
inferenceConfig.MaxTokens = aws.Int32(int32(maxTokens))
|
||||
}
|
||||
|
||||
if temp, ok := common.AsFloat(options["temperature"]); ok {
|
||||
if inferenceConfig == nil {
|
||||
inferenceConfig = &types.InferenceConfiguration{}
|
||||
}
|
||||
inferenceConfig.Temperature = aws.Float32(float32(temp))
|
||||
}
|
||||
|
||||
if inferenceConfig != nil {
|
||||
input.InferenceConfig = inferenceConfig
|
||||
}
|
||||
|
||||
// Convert tools to Bedrock format
|
||||
// Only set ToolConfig if at least one valid tool was produced
|
||||
if len(tools) > 0 {
|
||||
toolConfig := convertTools(tools)
|
||||
if len(toolConfig.Tools) > 0 {
|
||||
input.ToolConfig = toolConfig
|
||||
}
|
||||
}
|
||||
|
||||
// Call Bedrock Converse API
|
||||
output, err := p.client.Converse(ctx, input)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("bedrock converse: %w", err)
|
||||
}
|
||||
|
||||
// Parse the response
|
||||
return parseResponse(output)
|
||||
}
|
||||
|
||||
// GetDefaultModel returns an empty string as Bedrock models are user-configured.
|
||||
func (p *Provider) GetDefaultModel() string {
|
||||
return ""
|
||||
}
|
||||
|
||||
// Region returns the AWS region configured for this Provider.
|
||||
func (p *Provider) Region() string {
|
||||
return p.region
|
||||
}
|
||||
|
||||
// convertMessages converts internal messages to Bedrock Converse format.
|
||||
// Returns the conversation messages and any system prompts separately.
|
||||
// Note: Bedrock requires all tool results for a given assistant turn to be in a single
|
||||
// user message with multiple ToolResultBlock content blocks. This function merges
|
||||
// consecutive tool result messages accordingly.
|
||||
func convertMessages(messages []Message) ([]types.Message, []types.SystemContentBlock) {
|
||||
var bedrockMessages []types.Message
|
||||
var systemPrompts []types.SystemContentBlock
|
||||
|
||||
// Helper to check if a message is a tool result
|
||||
isToolResult := func(msg Message) bool {
|
||||
return (msg.Role == "tool" || (msg.Role == "user" && msg.ToolCallID != "")) && msg.ToolCallID != ""
|
||||
}
|
||||
|
||||
// Helper to create a tool result content block
|
||||
makeToolResultBlock := func(msg Message) types.ContentBlock {
|
||||
return &types.ContentBlockMemberToolResult{
|
||||
Value: types.ToolResultBlock{
|
||||
ToolUseId: aws.String(msg.ToolCallID),
|
||||
Content: []types.ToolResultContentBlock{
|
||||
&types.ToolResultContentBlockMemberText{
|
||||
Value: msg.Content,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
i := 0
|
||||
for i < len(messages) {
|
||||
msg := messages[i]
|
||||
|
||||
switch {
|
||||
case msg.Role == "system":
|
||||
// System messages go to the System field
|
||||
systemPrompts = append(systemPrompts, &types.SystemContentBlockMemberText{
|
||||
Value: msg.Content,
|
||||
})
|
||||
i++
|
||||
|
||||
case isToolResult(msg):
|
||||
// Collect all consecutive tool results into a single user message
|
||||
// Bedrock requires all tool results for a turn in one message
|
||||
var toolResultBlocks []types.ContentBlock
|
||||
for i < len(messages) && isToolResult(messages[i]) {
|
||||
toolResultBlocks = append(toolResultBlocks, makeToolResultBlock(messages[i]))
|
||||
i++
|
||||
}
|
||||
bedrockMessages = append(bedrockMessages, types.Message{
|
||||
Role: types.ConversationRoleUser,
|
||||
Content: toolResultBlocks,
|
||||
})
|
||||
|
||||
case msg.Role == "user":
|
||||
// Regular user message (no ToolCallID)
|
||||
content := buildUserContent(msg)
|
||||
bedrockMessages = append(bedrockMessages, types.Message{
|
||||
Role: types.ConversationRoleUser,
|
||||
Content: content,
|
||||
})
|
||||
i++
|
||||
|
||||
case msg.Role == "assistant":
|
||||
content := buildAssistantContent(msg)
|
||||
bedrockMessages = append(bedrockMessages, types.Message{
|
||||
Role: types.ConversationRoleAssistant,
|
||||
Content: content,
|
||||
})
|
||||
i++
|
||||
|
||||
case msg.Role == "tool" && msg.ToolCallID == "":
|
||||
// Tool message without ToolCallID - treat as regular user message
|
||||
content := buildUserContent(msg)
|
||||
bedrockMessages = append(bedrockMessages, types.Message{
|
||||
Role: types.ConversationRoleUser,
|
||||
Content: content,
|
||||
})
|
||||
i++
|
||||
|
||||
default:
|
||||
// Unknown role - skip
|
||||
i++
|
||||
}
|
||||
}
|
||||
|
||||
return bedrockMessages, systemPrompts
|
||||
}
|
||||
|
||||
// buildUserContent builds Bedrock content blocks for a user message.
|
||||
func buildUserContent(msg Message) []types.ContentBlock {
|
||||
var content []types.ContentBlock
|
||||
|
||||
// Add text content
|
||||
if msg.Content != "" {
|
||||
content = append(content, &types.ContentBlockMemberText{
|
||||
Value: msg.Content,
|
||||
})
|
||||
}
|
||||
|
||||
// Add images from Media field
|
||||
for _, mediaURL := range msg.Media {
|
||||
if strings.HasPrefix(mediaURL, "data:image/") {
|
||||
// Parse data URL: data:image/jpeg;base64,<data>
|
||||
parts := strings.SplitN(mediaURL, ",", 2)
|
||||
if len(parts) != 2 {
|
||||
continue
|
||||
}
|
||||
|
||||
// Extract media type from "data:image/jpeg;base64"
|
||||
mediaType := ""
|
||||
header := parts[0]
|
||||
if idx := strings.Index(header, "/"); idx != -1 {
|
||||
end := strings.Index(header[idx:], ";")
|
||||
if end == -1 {
|
||||
end = len(header) - idx
|
||||
}
|
||||
mediaType = header[idx+1 : idx+end]
|
||||
}
|
||||
|
||||
// Verify this is base64 encoded
|
||||
if !strings.Contains(header, ";base64") {
|
||||
continue // Skip non-base64 encoded data
|
||||
}
|
||||
|
||||
// Map media type to Bedrock format
|
||||
var format types.ImageFormat
|
||||
switch mediaType {
|
||||
case "jpeg", "jpg":
|
||||
format = types.ImageFormatJpeg
|
||||
case "png":
|
||||
format = types.ImageFormatPng
|
||||
case "gif":
|
||||
format = types.ImageFormatGif
|
||||
case "webp":
|
||||
format = types.ImageFormatWebp
|
||||
default:
|
||||
continue // Skip unsupported formats
|
||||
}
|
||||
|
||||
// Check size before decoding to prevent excessive memory allocation
|
||||
// Bedrock has a ~20MB request limit; cap decoded images at 10MB
|
||||
const maxImageSize = 10 * 1024 * 1024
|
||||
decodedLen := base64.StdEncoding.DecodedLen(len(parts[1]))
|
||||
if decodedLen > maxImageSize {
|
||||
log.Printf("bedrock: skipping image exceeding size limit (%d bytes > %d)", decodedLen, maxImageSize)
|
||||
continue
|
||||
}
|
||||
|
||||
// Decode base64 data
|
||||
imageData, err := base64.StdEncoding.DecodeString(parts[1])
|
||||
if err != nil {
|
||||
log.Printf("bedrock: failed to decode base64 image data: %v", err)
|
||||
continue
|
||||
}
|
||||
|
||||
content = append(content, &types.ContentBlockMemberImage{
|
||||
Value: types.ImageBlock{
|
||||
Format: format,
|
||||
Source: &types.ImageSourceMemberBytes{
|
||||
Value: imageData,
|
||||
},
|
||||
},
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Bedrock requires at least one content block; add empty text if needed
|
||||
if len(content) == 0 {
|
||||
content = append(content, &types.ContentBlockMemberText{Value: ""})
|
||||
}
|
||||
|
||||
return content
|
||||
}
|
||||
|
||||
// buildAssistantContent builds Bedrock content blocks for an assistant message.
|
||||
func buildAssistantContent(msg Message) []types.ContentBlock {
|
||||
var content []types.ContentBlock
|
||||
|
||||
// Add text content if present
|
||||
if msg.Content != "" {
|
||||
content = append(content, &types.ContentBlockMemberText{
|
||||
Value: msg.Content,
|
||||
})
|
||||
}
|
||||
|
||||
// Add tool use blocks
|
||||
for _, tc := range msg.ToolCalls {
|
||||
// Validate tool call ID - Bedrock requires non-empty ToolUseId
|
||||
if strings.TrimSpace(tc.ID) == "" {
|
||||
log.Printf("bedrock: skipping tool call with empty ID (name: %q)", tc.Name)
|
||||
continue
|
||||
}
|
||||
|
||||
// Resolve tool name: prefer tc.Name, fallback to tc.Function.Name
|
||||
// (tc.Name/tc.Arguments are json:"-" and may be empty when from JSON)
|
||||
toolName := tc.Name
|
||||
if toolName == "" && tc.Function != nil {
|
||||
toolName = tc.Function.Name
|
||||
}
|
||||
if strings.TrimSpace(toolName) == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
// Resolve arguments: prefer tc.Arguments, fallback to parsing tc.Function.Arguments
|
||||
args := tc.Arguments
|
||||
if args == nil && tc.Function != nil && tc.Function.Arguments != "" {
|
||||
if err := json.Unmarshal([]byte(tc.Function.Arguments), &args); err != nil {
|
||||
log.Printf("bedrock: failed to parse Function.Arguments for tool %q: %v", toolName, err)
|
||||
args = map[string]any{}
|
||||
}
|
||||
}
|
||||
if args == nil {
|
||||
args = map[string]any{}
|
||||
}
|
||||
|
||||
// Convert arguments to a Bedrock document using NewLazyDocument
|
||||
inputDoc := document.NewLazyDocument(args)
|
||||
|
||||
content = append(content, &types.ContentBlockMemberToolUse{
|
||||
Value: types.ToolUseBlock{
|
||||
ToolUseId: aws.String(tc.ID),
|
||||
Name: aws.String(toolName),
|
||||
Input: inputDoc,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
// Bedrock requires at least one content block; add empty text if needed
|
||||
if len(content) == 0 {
|
||||
content = append(content, &types.ContentBlockMemberText{Value: ""})
|
||||
}
|
||||
|
||||
return content
|
||||
}
|
||||
|
||||
// convertTools converts tool definitions to Bedrock format.
|
||||
func convertTools(tools []ToolDefinition) *types.ToolConfiguration {
|
||||
bedrockTools := make([]types.Tool, 0, len(tools))
|
||||
|
||||
for _, tool := range tools {
|
||||
// Skip tools with empty names
|
||||
if strings.TrimSpace(tool.Function.Name) == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
// Ensure parameters is not nil - default to minimal object schema
|
||||
params := tool.Function.Parameters
|
||||
if params == nil {
|
||||
params = map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]any{},
|
||||
}
|
||||
}
|
||||
|
||||
// Convert parameters schema to a Bedrock document
|
||||
inputSchema := document.NewLazyDocument(params)
|
||||
|
||||
bedrockTools = append(bedrockTools, &types.ToolMemberToolSpec{
|
||||
Value: types.ToolSpecification{
|
||||
Name: aws.String(tool.Function.Name),
|
||||
Description: aws.String(tool.Function.Description),
|
||||
InputSchema: &types.ToolInputSchemaMemberJson{
|
||||
Value: inputSchema,
|
||||
},
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
return &types.ToolConfiguration{
|
||||
Tools: bedrockTools,
|
||||
}
|
||||
}
|
||||
|
||||
// parseResponse converts Bedrock Converse output to LLMResponse.
|
||||
func parseResponse(output *bedrockruntime.ConverseOutput) (*LLMResponse, error) {
|
||||
var content strings.Builder
|
||||
toolCalls := make([]ToolCall, 0)
|
||||
|
||||
// Process output content blocks
|
||||
if output.Output != nil {
|
||||
if msgOutput, ok := output.Output.(*types.ConverseOutputMemberMessage); ok {
|
||||
for _, block := range msgOutput.Value.Content {
|
||||
switch b := block.(type) {
|
||||
case *types.ContentBlockMemberText:
|
||||
content.WriteString(b.Value)
|
||||
|
||||
case *types.ContentBlockMemberToolUse:
|
||||
// Unmarshal the document interface to a map
|
||||
args := make(map[string]any)
|
||||
if b.Value.Input != nil {
|
||||
if err := b.Value.Input.UnmarshalSmithyDocument(&args); err != nil {
|
||||
log.Printf("bedrock: failed to unmarshal tool input for tool %q (id %q): %v",
|
||||
aws.ToString(b.Value.Name),
|
||||
aws.ToString(b.Value.ToolUseId),
|
||||
err,
|
||||
)
|
||||
args = make(map[string]any)
|
||||
}
|
||||
}
|
||||
|
||||
// Serialize arguments to JSON string for FunctionCall
|
||||
argsJSON, err := json.Marshal(args)
|
||||
if err != nil {
|
||||
log.Printf("bedrock: failed to marshal tool arguments for tool %q (id %q): %v",
|
||||
aws.ToString(b.Value.Name),
|
||||
aws.ToString(b.Value.ToolUseId),
|
||||
err,
|
||||
)
|
||||
argsJSON = []byte("{}")
|
||||
}
|
||||
|
||||
toolCalls = append(toolCalls, ToolCall{
|
||||
ID: aws.ToString(b.Value.ToolUseId),
|
||||
Name: aws.ToString(b.Value.Name),
|
||||
Arguments: args,
|
||||
Function: &FunctionCall{
|
||||
Name: aws.ToString(b.Value.Name),
|
||||
Arguments: string(argsJSON),
|
||||
},
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Map stop reason
|
||||
finishReason := "stop"
|
||||
switch output.StopReason {
|
||||
case types.StopReasonToolUse:
|
||||
finishReason = "tool_calls"
|
||||
case types.StopReasonMaxTokens:
|
||||
finishReason = "length"
|
||||
case types.StopReasonEndTurn:
|
||||
finishReason = "stop"
|
||||
case types.StopReasonStopSequence:
|
||||
finishReason = "stop"
|
||||
case types.StopReasonContentFiltered:
|
||||
finishReason = "content_filter"
|
||||
}
|
||||
|
||||
// Build usage info
|
||||
var usage *UsageInfo
|
||||
if output.Usage != nil {
|
||||
usage = &UsageInfo{
|
||||
PromptTokens: int(aws.ToInt32(output.Usage.InputTokens)),
|
||||
CompletionTokens: int(aws.ToInt32(output.Usage.OutputTokens)),
|
||||
TotalTokens: int(aws.ToInt32(output.Usage.InputTokens)) + int(aws.ToInt32(output.Usage.OutputTokens)),
|
||||
}
|
||||
}
|
||||
|
||||
return &LLMResponse{
|
||||
Content: content.String(),
|
||||
ToolCalls: toolCalls,
|
||||
FinishReason: finishReason,
|
||||
Usage: usage,
|
||||
}, nil
|
||||
}
|
||||
@@ -0,0 +1,541 @@
|
||||
//go:build bedrock
|
||||
|
||||
// PicoClaw - Ultra-lightweight personal AI agent
|
||||
// License: MIT
|
||||
//
|
||||
// Copyright (c) 2026 PicoClaw contributors
|
||||
|
||||
package bedrock
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/aws/aws-sdk-go-v2/aws"
|
||||
"github.com/aws/aws-sdk-go-v2/service/bedrockruntime"
|
||||
"github.com/aws/aws-sdk-go-v2/service/bedrockruntime/document"
|
||||
"github.com/aws/aws-sdk-go-v2/service/bedrockruntime/types"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/providers/protocoltypes"
|
||||
)
|
||||
|
||||
func TestConvertMessages_SystemPrompts(t *testing.T) {
|
||||
messages := []Message{
|
||||
{Role: "system", Content: "You are a helpful assistant."},
|
||||
{Role: "user", Content: "Hello"},
|
||||
}
|
||||
|
||||
bedrockMsgs, systemPrompts := convertMessages(messages)
|
||||
|
||||
assert.Len(t, systemPrompts, 1)
|
||||
assert.Len(t, bedrockMsgs, 1)
|
||||
|
||||
// Check system prompt
|
||||
textBlock, ok := systemPrompts[0].(*types.SystemContentBlockMemberText)
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, "You are a helpful assistant.", textBlock.Value)
|
||||
|
||||
// Check user message
|
||||
assert.Equal(t, types.ConversationRoleUser, bedrockMsgs[0].Role)
|
||||
}
|
||||
|
||||
func TestConvertMessages_UserMessage(t *testing.T) {
|
||||
messages := []Message{
|
||||
{Role: "user", Content: "What is 2+2?"},
|
||||
}
|
||||
|
||||
bedrockMsgs, systemPrompts := convertMessages(messages)
|
||||
|
||||
assert.Empty(t, systemPrompts)
|
||||
assert.Len(t, bedrockMsgs, 1)
|
||||
assert.Equal(t, types.ConversationRoleUser, bedrockMsgs[0].Role)
|
||||
|
||||
textBlock, ok := bedrockMsgs[0].Content[0].(*types.ContentBlockMemberText)
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, "What is 2+2?", textBlock.Value)
|
||||
}
|
||||
|
||||
func TestConvertMessages_AssistantMessage(t *testing.T) {
|
||||
messages := []Message{
|
||||
{Role: "assistant", Content: "The answer is 4."},
|
||||
}
|
||||
|
||||
bedrockMsgs, _ := convertMessages(messages)
|
||||
|
||||
assert.Len(t, bedrockMsgs, 1)
|
||||
assert.Equal(t, types.ConversationRoleAssistant, bedrockMsgs[0].Role)
|
||||
|
||||
textBlock, ok := bedrockMsgs[0].Content[0].(*types.ContentBlockMemberText)
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, "The answer is 4.", textBlock.Value)
|
||||
}
|
||||
|
||||
func TestConvertMessages_ToolResult(t *testing.T) {
|
||||
messages := []Message{
|
||||
{Role: "tool", Content: "Result from tool", ToolCallID: "call_123"},
|
||||
}
|
||||
|
||||
bedrockMsgs, _ := convertMessages(messages)
|
||||
|
||||
assert.Len(t, bedrockMsgs, 1)
|
||||
assert.Equal(t, types.ConversationRoleUser, bedrockMsgs[0].Role)
|
||||
|
||||
toolResult, ok := bedrockMsgs[0].Content[0].(*types.ContentBlockMemberToolResult)
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, "call_123", aws.ToString(toolResult.Value.ToolUseId))
|
||||
}
|
||||
|
||||
func TestConvertMessages_MultipleToolResultsMerged(t *testing.T) {
|
||||
// When an assistant makes multiple tool calls, all tool results must be
|
||||
// merged into a single user message for Bedrock
|
||||
messages := []Message{
|
||||
{Role: "user", Content: "What's the weather in NYC and LA?"},
|
||||
{
|
||||
Role: "assistant",
|
||||
Content: "Let me check both cities.",
|
||||
ToolCalls: []protocoltypes.ToolCall{
|
||||
{ID: "call_nyc", Name: "get_weather", Arguments: map[string]any{"city": "NYC"}},
|
||||
{ID: "call_la", Name: "get_weather", Arguments: map[string]any{"city": "LA"}},
|
||||
},
|
||||
},
|
||||
{Role: "tool", Content: "NYC: 72°F, sunny", ToolCallID: "call_nyc"},
|
||||
{Role: "tool", Content: "LA: 85°F, clear", ToolCallID: "call_la"},
|
||||
}
|
||||
|
||||
bedrockMsgs, _ := convertMessages(messages)
|
||||
|
||||
// Should be: user message, assistant message, merged tool results (single user message)
|
||||
assert.Len(t, bedrockMsgs, 3)
|
||||
|
||||
// First message: user
|
||||
assert.Equal(t, types.ConversationRoleUser, bedrockMsgs[0].Role)
|
||||
|
||||
// Second message: assistant with tool calls
|
||||
assert.Equal(t, types.ConversationRoleAssistant, bedrockMsgs[1].Role)
|
||||
|
||||
// Third message: merged tool results in single user message
|
||||
assert.Equal(t, types.ConversationRoleUser, bedrockMsgs[2].Role)
|
||||
assert.Len(t, bedrockMsgs[2].Content, 2) // Both tool results in one message
|
||||
|
||||
// Verify both tool results are present
|
||||
result1, ok := bedrockMsgs[2].Content[0].(*types.ContentBlockMemberToolResult)
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, "call_nyc", aws.ToString(result1.Value.ToolUseId))
|
||||
|
||||
result2, ok := bedrockMsgs[2].Content[1].(*types.ContentBlockMemberToolResult)
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, "call_la", aws.ToString(result2.Value.ToolUseId))
|
||||
}
|
||||
|
||||
func TestConvertMessages_AssistantWithToolCalls(t *testing.T) {
|
||||
messages := []Message{
|
||||
{
|
||||
Role: "assistant",
|
||||
Content: "Let me calculate that.",
|
||||
ToolCalls: []protocoltypes.ToolCall{
|
||||
{
|
||||
ID: "call_456",
|
||||
Name: "calculator",
|
||||
Arguments: map[string]any{"expression": "2+2"},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
bedrockMsgs, _ := convertMessages(messages)
|
||||
|
||||
assert.Len(t, bedrockMsgs, 1)
|
||||
assert.Len(t, bedrockMsgs[0].Content, 2) // text + tool use
|
||||
|
||||
// Check text content
|
||||
textBlock, ok := bedrockMsgs[0].Content[0].(*types.ContentBlockMemberText)
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, "Let me calculate that.", textBlock.Value)
|
||||
|
||||
// Check tool use
|
||||
toolUse, ok := bedrockMsgs[0].Content[1].(*types.ContentBlockMemberToolUse)
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, "call_456", aws.ToString(toolUse.Value.ToolUseId))
|
||||
assert.Equal(t, "calculator", aws.ToString(toolUse.Value.Name))
|
||||
}
|
||||
|
||||
func TestConvertTools_Basic(t *testing.T) {
|
||||
tools := []ToolDefinition{
|
||||
{
|
||||
Function: protocoltypes.ToolFunctionDefinition{
|
||||
Name: "get_weather",
|
||||
Description: "Get the current weather",
|
||||
Parameters: map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]any{
|
||||
"location": map[string]any{"type": "string"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
toolConfig := convertTools(tools)
|
||||
|
||||
assert.NotNil(t, toolConfig)
|
||||
assert.Len(t, toolConfig.Tools, 1)
|
||||
|
||||
toolSpec, ok := toolConfig.Tools[0].(*types.ToolMemberToolSpec)
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, "get_weather", aws.ToString(toolSpec.Value.Name))
|
||||
assert.Equal(t, "Get the current weather", aws.ToString(toolSpec.Value.Description))
|
||||
}
|
||||
|
||||
func TestConvertTools_SkipsEmptyName(t *testing.T) {
|
||||
tools := []ToolDefinition{
|
||||
{
|
||||
Function: protocoltypes.ToolFunctionDefinition{
|
||||
Name: "",
|
||||
Description: "Empty name tool",
|
||||
},
|
||||
},
|
||||
{
|
||||
Function: protocoltypes.ToolFunctionDefinition{
|
||||
Name: " ",
|
||||
Description: "Whitespace name tool",
|
||||
},
|
||||
},
|
||||
{
|
||||
Function: protocoltypes.ToolFunctionDefinition{
|
||||
Name: "valid_tool",
|
||||
Description: "Valid tool",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
toolConfig := convertTools(tools)
|
||||
|
||||
assert.Len(t, toolConfig.Tools, 1)
|
||||
toolSpec := toolConfig.Tools[0].(*types.ToolMemberToolSpec)
|
||||
assert.Equal(t, "valid_tool", aws.ToString(toolSpec.Value.Name))
|
||||
}
|
||||
|
||||
func TestConvertTools_NilParameters(t *testing.T) {
|
||||
tools := []ToolDefinition{
|
||||
{
|
||||
Function: protocoltypes.ToolFunctionDefinition{
|
||||
Name: "simple_tool",
|
||||
Description: "A tool with no parameters",
|
||||
Parameters: nil,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
toolConfig := convertTools(tools)
|
||||
|
||||
assert.Len(t, toolConfig.Tools, 1)
|
||||
// Should not panic and should create a valid tool
|
||||
}
|
||||
|
||||
func TestBuildUserContent_TextOnly(t *testing.T) {
|
||||
msg := Message{Content: "Hello world"}
|
||||
|
||||
content := buildUserContent(msg)
|
||||
|
||||
assert.Len(t, content, 1)
|
||||
textBlock, ok := content[0].(*types.ContentBlockMemberText)
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, "Hello world", textBlock.Value)
|
||||
}
|
||||
|
||||
func TestBuildUserContent_WithImage(t *testing.T) {
|
||||
// Base64-encoded 1x1 PNG (the provider doesn't validate image correctness,
|
||||
// it just verifies the format and base64 decoding works)
|
||||
b64Data := "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAIAAACQd1PeAAAADUlEQVR4nGNgYAAAAAMAASsJTYQAAAAASUVORK5CYII="
|
||||
|
||||
msg := Message{
|
||||
Content: "Look at this image",
|
||||
Media: []string{"data:image/png;base64," + b64Data},
|
||||
}
|
||||
|
||||
content := buildUserContent(msg)
|
||||
|
||||
assert.Len(t, content, 2)
|
||||
|
||||
// Check text
|
||||
textBlock, ok := content[0].(*types.ContentBlockMemberText)
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, "Look at this image", textBlock.Value)
|
||||
|
||||
// Check image
|
||||
imageBlock, ok := content[1].(*types.ContentBlockMemberImage)
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, types.ImageFormatPng, imageBlock.Value.Format)
|
||||
}
|
||||
|
||||
func TestBuildUserContent_SkipsInvalidBase64(t *testing.T) {
|
||||
msg := Message{
|
||||
Content: "Invalid image",
|
||||
Media: []string{"data:image/png;base64,not-valid-base64!!!"},
|
||||
}
|
||||
|
||||
content := buildUserContent(msg)
|
||||
|
||||
// Should only have text, image should be skipped
|
||||
assert.Len(t, content, 1)
|
||||
}
|
||||
|
||||
func TestBuildUserContent_SkipsNonBase64Data(t *testing.T) {
|
||||
msg := Message{
|
||||
Content: "Non-base64 image",
|
||||
Media: []string{"data:image/png,raw-data-here"},
|
||||
}
|
||||
|
||||
content := buildUserContent(msg)
|
||||
|
||||
// Should only have text, non-base64 image should be skipped
|
||||
assert.Len(t, content, 1)
|
||||
}
|
||||
|
||||
func TestBuildAssistantContent_SkipsEmptyToolName(t *testing.T) {
|
||||
msg := Message{
|
||||
Content: "Response",
|
||||
ToolCalls: []protocoltypes.ToolCall{
|
||||
{ID: "1", Name: "", Arguments: map[string]any{}},
|
||||
{ID: "2", Name: " ", Arguments: map[string]any{}},
|
||||
{ID: "3", Name: "valid", Arguments: map[string]any{}},
|
||||
},
|
||||
}
|
||||
|
||||
content := buildAssistantContent(msg)
|
||||
|
||||
// Should have text + 1 valid tool
|
||||
assert.Len(t, content, 2)
|
||||
}
|
||||
|
||||
func TestBuildAssistantContent_NilArguments(t *testing.T) {
|
||||
msg := Message{
|
||||
ToolCalls: []protocoltypes.ToolCall{
|
||||
{ID: "1", Name: "tool", Arguments: nil},
|
||||
},
|
||||
}
|
||||
|
||||
content := buildAssistantContent(msg)
|
||||
|
||||
assert.Len(t, content, 1)
|
||||
toolUse, ok := content[0].(*types.ContentBlockMemberToolUse)
|
||||
require.True(t, ok)
|
||||
assert.NotNil(t, toolUse.Value.Input)
|
||||
}
|
||||
|
||||
func TestBuildAssistantContent_FunctionFallback(t *testing.T) {
|
||||
// When Name/Arguments are empty (json:"-"), should fallback to Function fields
|
||||
msg := Message{
|
||||
ToolCalls: []protocoltypes.ToolCall{
|
||||
{
|
||||
ID: "1",
|
||||
Name: "", // empty, should fallback to Function.Name
|
||||
Function: &protocoltypes.FunctionCall{
|
||||
Name: "fallback_tool",
|
||||
Arguments: `{"key":"value"}`,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
content := buildAssistantContent(msg)
|
||||
|
||||
assert.Len(t, content, 1)
|
||||
toolUse, ok := content[0].(*types.ContentBlockMemberToolUse)
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, "fallback_tool", aws.ToString(toolUse.Value.Name))
|
||||
}
|
||||
|
||||
func TestParseResponse_TextOnly(t *testing.T) {
|
||||
output := &bedrockruntime.ConverseOutput{
|
||||
Output: &types.ConverseOutputMemberMessage{
|
||||
Value: types.Message{
|
||||
Role: types.ConversationRoleAssistant,
|
||||
Content: []types.ContentBlock{
|
||||
&types.ContentBlockMemberText{Value: "Hello!"},
|
||||
},
|
||||
},
|
||||
},
|
||||
StopReason: types.StopReasonEndTurn,
|
||||
Usage: &types.TokenUsage{
|
||||
InputTokens: aws.Int32(10),
|
||||
OutputTokens: aws.Int32(5),
|
||||
},
|
||||
}
|
||||
|
||||
resp, err := parseResponse(output)
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "Hello!", resp.Content)
|
||||
assert.Equal(t, "stop", resp.FinishReason)
|
||||
assert.Empty(t, resp.ToolCalls)
|
||||
assert.Equal(t, 10, resp.Usage.PromptTokens)
|
||||
assert.Equal(t, 5, resp.Usage.CompletionTokens)
|
||||
}
|
||||
|
||||
func TestParseResponse_StopReasons(t *testing.T) {
|
||||
tests := []struct {
|
||||
stopReason types.StopReason
|
||||
expectedFinish string
|
||||
}{
|
||||
{types.StopReasonEndTurn, "stop"},
|
||||
{types.StopReasonToolUse, "tool_calls"},
|
||||
{types.StopReasonMaxTokens, "length"},
|
||||
{types.StopReasonStopSequence, "stop"},
|
||||
{types.StopReasonContentFiltered, "content_filter"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(string(tt.stopReason), func(t *testing.T) {
|
||||
output := &bedrockruntime.ConverseOutput{
|
||||
Output: &types.ConverseOutputMemberMessage{
|
||||
Value: types.Message{
|
||||
Content: []types.ContentBlock{
|
||||
&types.ContentBlockMemberText{Value: "test"},
|
||||
},
|
||||
},
|
||||
},
|
||||
StopReason: tt.stopReason,
|
||||
}
|
||||
|
||||
resp, err := parseResponse(output)
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, tt.expectedFinish, resp.FinishReason)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseResponse_WithToolCalls(t *testing.T) {
|
||||
// Note: document.NewLazyDocument has limitations with UnmarshalSmithyDocument in tests,
|
||||
// so we test the structure extraction and verify Arguments gets populated (even if empty
|
||||
// due to SDK limitations). The actual unmarshal works correctly at runtime.
|
||||
toolInput := document.NewLazyDocument(map[string]any{
|
||||
"location": "San Francisco",
|
||||
"unit": "celsius",
|
||||
})
|
||||
|
||||
output := &bedrockruntime.ConverseOutput{
|
||||
Output: &types.ConverseOutputMemberMessage{
|
||||
Value: types.Message{
|
||||
Role: types.ConversationRoleAssistant,
|
||||
Content: []types.ContentBlock{
|
||||
&types.ContentBlockMemberText{Value: "Let me check the weather."},
|
||||
&types.ContentBlockMemberToolUse{
|
||||
Value: types.ToolUseBlock{
|
||||
ToolUseId: aws.String("call_weather_123"),
|
||||
Name: aws.String("get_weather"),
|
||||
Input: toolInput,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
StopReason: types.StopReasonToolUse,
|
||||
Usage: &types.TokenUsage{
|
||||
InputTokens: aws.Int32(20),
|
||||
OutputTokens: aws.Int32(15),
|
||||
},
|
||||
}
|
||||
|
||||
resp, err := parseResponse(output)
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "Let me check the weather.", resp.Content)
|
||||
assert.Equal(t, "tool_calls", resp.FinishReason)
|
||||
assert.Len(t, resp.ToolCalls, 1)
|
||||
|
||||
// Verify tool call ID and Name are extracted correctly
|
||||
tc := resp.ToolCalls[0]
|
||||
assert.Equal(t, "call_weather_123", tc.ID)
|
||||
assert.Equal(t, "get_weather", tc.Name)
|
||||
|
||||
// Verify Function fields are also populated
|
||||
require.NotNil(t, tc.Function)
|
||||
assert.Equal(t, "get_weather", tc.Function.Name)
|
||||
|
||||
// Verify Arguments is not nil (content may vary due to SDK limitations in tests)
|
||||
assert.NotNil(t, tc.Arguments)
|
||||
|
||||
// Verify usage
|
||||
assert.Equal(t, 20, resp.Usage.PromptTokens)
|
||||
assert.Equal(t, 15, resp.Usage.CompletionTokens)
|
||||
assert.Equal(t, 35, resp.Usage.TotalTokens)
|
||||
}
|
||||
|
||||
func TestParseResponse_MultipleToolCalls(t *testing.T) {
|
||||
output := &bedrockruntime.ConverseOutput{
|
||||
Output: &types.ConverseOutputMemberMessage{
|
||||
Value: types.Message{
|
||||
Role: types.ConversationRoleAssistant,
|
||||
Content: []types.ContentBlock{
|
||||
&types.ContentBlockMemberToolUse{
|
||||
Value: types.ToolUseBlock{
|
||||
ToolUseId: aws.String("call_1"),
|
||||
Name: aws.String("tool_a"),
|
||||
Input: document.NewLazyDocument(map[string]any{"arg": "value1"}),
|
||||
},
|
||||
},
|
||||
&types.ContentBlockMemberToolUse{
|
||||
Value: types.ToolUseBlock{
|
||||
ToolUseId: aws.String("call_2"),
|
||||
Name: aws.String("tool_b"),
|
||||
Input: document.NewLazyDocument(map[string]any{"arg": "value2"}),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
StopReason: types.StopReasonToolUse,
|
||||
}
|
||||
|
||||
resp, err := parseResponse(output)
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "tool_calls", resp.FinishReason)
|
||||
assert.Len(t, resp.ToolCalls, 2)
|
||||
|
||||
// Verify tool call structure
|
||||
assert.Equal(t, "call_1", resp.ToolCalls[0].ID)
|
||||
assert.Equal(t, "tool_a", resp.ToolCalls[0].Name)
|
||||
assert.NotNil(t, resp.ToolCalls[0].Arguments)
|
||||
assert.NotNil(t, resp.ToolCalls[0].Function)
|
||||
assert.Equal(t, "tool_a", resp.ToolCalls[0].Function.Name)
|
||||
|
||||
assert.Equal(t, "call_2", resp.ToolCalls[1].ID)
|
||||
assert.Equal(t, "tool_b", resp.ToolCalls[1].Name)
|
||||
assert.NotNil(t, resp.ToolCalls[1].Arguments)
|
||||
assert.NotNil(t, resp.ToolCalls[1].Function)
|
||||
assert.Equal(t, "tool_b", resp.ToolCalls[1].Function.Name)
|
||||
}
|
||||
|
||||
func TestParseResponse_ToolCallWithNilInput(t *testing.T) {
|
||||
output := &bedrockruntime.ConverseOutput{
|
||||
Output: &types.ConverseOutputMemberMessage{
|
||||
Value: types.Message{
|
||||
Role: types.ConversationRoleAssistant,
|
||||
Content: []types.ContentBlock{
|
||||
&types.ContentBlockMemberToolUse{
|
||||
Value: types.ToolUseBlock{
|
||||
ToolUseId: aws.String("call_nil"),
|
||||
Name: aws.String("no_args_tool"),
|
||||
Input: nil,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
StopReason: types.StopReasonToolUse,
|
||||
}
|
||||
|
||||
resp, err := parseResponse(output)
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, resp.ToolCalls, 1)
|
||||
assert.Equal(t, "call_nil", resp.ToolCalls[0].ID)
|
||||
assert.Equal(t, "no_args_tool", resp.ToolCalls[0].Name)
|
||||
// Arguments should be empty map, not nil
|
||||
assert.NotNil(t, resp.ToolCalls[0].Arguments)
|
||||
assert.Empty(t, resp.ToolCalls[0].Arguments)
|
||||
}
|
||||
@@ -0,0 +1,73 @@
|
||||
//go:build !bedrock
|
||||
|
||||
// PicoClaw - Ultra-lightweight personal AI agent
|
||||
// License: MIT
|
||||
//
|
||||
// Copyright (c) 2026 PicoClaw contributors
|
||||
|
||||
// Package bedrock provides a stub implementation when built without the bedrock tag.
|
||||
// To enable AWS Bedrock support, build with: go build -tags bedrock
|
||||
package bedrock
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/providers/protocoltypes"
|
||||
)
|
||||
|
||||
type (
|
||||
LLMResponse = protocoltypes.LLMResponse
|
||||
Message = protocoltypes.Message
|
||||
ToolDefinition = protocoltypes.ToolDefinition
|
||||
)
|
||||
|
||||
// Provider is a stub that returns an error when Bedrock support is not compiled in.
|
||||
type Provider struct{}
|
||||
|
||||
// Option is a no-op when Bedrock is not enabled.
|
||||
type Option func(*providerConfig)
|
||||
|
||||
type providerConfig struct{}
|
||||
|
||||
// WithRegion is a no-op when Bedrock is not enabled.
|
||||
func WithRegion(region string) Option {
|
||||
return func(c *providerConfig) {}
|
||||
}
|
||||
|
||||
// WithProfile is a no-op when Bedrock is not enabled.
|
||||
func WithProfile(profile string) Option {
|
||||
return func(c *providerConfig) {}
|
||||
}
|
||||
|
||||
// WithBaseEndpoint is a no-op when Bedrock is not enabled.
|
||||
func WithBaseEndpoint(endpoint string) Option {
|
||||
return func(c *providerConfig) {}
|
||||
}
|
||||
|
||||
// WithRequestTimeout is a no-op when Bedrock is not enabled.
|
||||
func WithRequestTimeout(timeout time.Duration) Option {
|
||||
return func(c *providerConfig) {}
|
||||
}
|
||||
|
||||
// NewProvider returns an error indicating Bedrock support is not compiled in.
|
||||
func NewProvider(ctx context.Context, opts ...Option) (*Provider, error) {
|
||||
return nil, fmt.Errorf("bedrock provider not available: build with -tags bedrock to enable AWS Bedrock support")
|
||||
}
|
||||
|
||||
// Chat returns an error - this should never be called since NewProvider fails.
|
||||
func (p *Provider) Chat(
|
||||
ctx context.Context,
|
||||
messages []Message,
|
||||
tools []ToolDefinition,
|
||||
model string,
|
||||
options map[string]any,
|
||||
) (*LLMResponse, error) {
|
||||
return nil, fmt.Errorf("bedrock provider not available: build with -tags bedrock to enable AWS Bedrock support")
|
||||
}
|
||||
|
||||
// GetDefaultModel returns an empty string.
|
||||
func (p *Provider) GetDefaultModel() string {
|
||||
return ""
|
||||
}
|
||||
@@ -0,0 +1,35 @@
|
||||
//go:build !bedrock
|
||||
|
||||
// PicoClaw - Ultra-lightweight personal AI agent
|
||||
// License: MIT
|
||||
//
|
||||
// Copyright (c) 2026 PicoClaw contributors
|
||||
|
||||
package bedrock
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestNewProvider_ReturnsStubError(t *testing.T) {
|
||||
provider, err := NewProvider(context.Background())
|
||||
|
||||
assert.Nil(t, provider)
|
||||
require.Error(t, err)
|
||||
assert.True(t, strings.Contains(err.Error(), "build with -tags bedrock"),
|
||||
"error should mention build tag requirement, got: %s", err.Error())
|
||||
}
|
||||
|
||||
func TestNewProvider_WithOptions_ReturnsStubError(t *testing.T) {
|
||||
provider, err := NewProvider(context.Background(), WithRegion("us-west-2"), WithProfile("test"))
|
||||
|
||||
assert.Nil(t, provider)
|
||||
require.Error(t, err)
|
||||
assert.True(t, strings.Contains(err.Error(), "build with -tags bedrock"),
|
||||
"error should mention build tag requirement, got: %s", err.Error())
|
||||
}
|
||||
@@ -6,12 +6,15 @@
|
||||
package providers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
anthropicmessages "github.com/sipeed/picoclaw/pkg/providers/anthropic_messages"
|
||||
"github.com/sipeed/picoclaw/pkg/providers/azure"
|
||||
"github.com/sipeed/picoclaw/pkg/providers/bedrock"
|
||||
)
|
||||
|
||||
// createClaudeAuthProvider creates a Claude provider using OAuth credentials from auth store.
|
||||
@@ -55,8 +58,9 @@ func ExtractProtocol(model string) (protocol, modelID string) {
|
||||
|
||||
// CreateProviderFromConfig creates a provider based on the ModelConfig.
|
||||
// It uses the protocol prefix in the Model field to determine which provider to create.
|
||||
// Supported protocols: openai, litellm, novita, anthropic, anthropic-messages,
|
||||
// antigravity, claude-cli, codex-cli, github-copilot
|
||||
// Supported protocol families include OpenAI-compatible prefixes (e.g., openai, openrouter, groq, gemini),
|
||||
// Azure OpenAI, Amazon Bedrock, Anthropic (including messages), and various CLI/compatibility shims.
|
||||
// See the switch on protocol in this function for the authoritative list.
|
||||
// Returns the provider, the model ID (without protocol prefix), and any error.
|
||||
func CreateProviderFromConfig(cfg *config.ModelConfig) (LLMProvider, string, error) {
|
||||
if cfg == nil {
|
||||
@@ -114,6 +118,42 @@ func CreateProviderFromConfig(cfg *config.ModelConfig) (LLMProvider, string, err
|
||||
cfg.RequestTimeout,
|
||||
), modelID, nil
|
||||
|
||||
case "bedrock":
|
||||
// AWS Bedrock uses AWS SDK credentials (env vars, profiles, IAM roles, etc.)
|
||||
// api_base can be:
|
||||
// - A full endpoint URL: https://bedrock-runtime.us-east-1.amazonaws.com
|
||||
// - A region name: us-east-1 (AWS SDK resolves endpoint automatically)
|
||||
var opts []bedrock.Option
|
||||
if cfg.APIBase != "" {
|
||||
if !strings.Contains(cfg.APIBase, "://") {
|
||||
// Treat as region: let AWS SDK resolve the correct endpoint
|
||||
// (supports all AWS partitions: aws, aws-cn, aws-us-gov, etc.)
|
||||
opts = append(opts, bedrock.WithRegion(cfg.APIBase))
|
||||
} else {
|
||||
// Full endpoint URL provided (for custom endpoints or testing)
|
||||
opts = append(opts, bedrock.WithBaseEndpoint(cfg.APIBase))
|
||||
}
|
||||
}
|
||||
// Use a separate timeout for AWS config loading (credential resolution can block)
|
||||
initTimeout := 30 * time.Second
|
||||
if cfg.RequestTimeout > 0 {
|
||||
reqTimeout := time.Duration(cfg.RequestTimeout) * time.Second
|
||||
// Set request timeout for API calls
|
||||
opts = append(opts, bedrock.WithRequestTimeout(reqTimeout))
|
||||
// Ensure init timeout is at least as large as request timeout
|
||||
if reqTimeout > initTimeout {
|
||||
initTimeout = reqTimeout
|
||||
}
|
||||
}
|
||||
ctx, cancel := context.WithTimeout(context.Background(), initTimeout)
|
||||
defer cancel()
|
||||
// Note: AWS_PROFILE env var is automatically used by AWS SDK
|
||||
provider, err := bedrock.NewProvider(ctx, opts...)
|
||||
if err != nil {
|
||||
return nil, "", fmt.Errorf("creating bedrock provider: %w", err)
|
||||
}
|
||||
return provider, modelID, nil
|
||||
|
||||
case "litellm", "openrouter", "groq", "zhipu", "gemini", "nvidia",
|
||||
"ollama", "moonshot", "shengsuanyun", "deepseek", "cerebras",
|
||||
"vivgrid", "volcengine", "vllm", "qwen", "qwen-intl", "qwen-international", "dashscope-intl",
|
||||
|
||||
@@ -700,3 +700,78 @@ func TestCreateProviderFromConfig_MinimaxPreservesUserExtraBody(t *testing.T) {
|
||||
t.Fatalf("custom_field = %v, want test", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateProviderFromConfig_Bedrock(t *testing.T) {
|
||||
// Set dummy AWS env vars to make test deterministic
|
||||
t.Setenv("AWS_ACCESS_KEY_ID", "test-key")
|
||||
t.Setenv("AWS_SECRET_ACCESS_KEY", "test-secret")
|
||||
t.Setenv("AWS_EC2_METADATA_DISABLED", "true")
|
||||
// Clear profile-related env vars to avoid loading shared config
|
||||
t.Setenv("AWS_PROFILE", "")
|
||||
t.Setenv("AWS_DEFAULT_PROFILE", "")
|
||||
t.Setenv("AWS_SDK_LOAD_CONFIG", "")
|
||||
t.Setenv("AWS_SHARED_CREDENTIALS_FILE", "")
|
||||
|
||||
cfg := &config.ModelConfig{
|
||||
ModelName: "bedrock-claude",
|
||||
Model: "bedrock/us.anthropic.claude-sonnet-4-20250514-v1:0",
|
||||
APIBase: "us-west-2", // Region (also sets AWS region)
|
||||
}
|
||||
|
||||
provider, modelID, err := CreateProviderFromConfig(cfg)
|
||||
if err == nil {
|
||||
// Provider created successfully (built with -tags bedrock)
|
||||
if provider == nil {
|
||||
t.Error("provider is nil on success")
|
||||
}
|
||||
if modelID != "us.anthropic.claude-sonnet-4-20250514-v1:0" {
|
||||
t.Errorf("modelID = %q, want %q", modelID, "us.anthropic.claude-sonnet-4-20250514-v1:0")
|
||||
}
|
||||
return
|
||||
}
|
||||
errMsg := err.Error()
|
||||
// When built without -tags bedrock, expect stub error
|
||||
if strings.Contains(errMsg, "build with -tags bedrock") {
|
||||
return // Expected stub error
|
||||
}
|
||||
// Unexpected error - fail the test
|
||||
t.Errorf("unexpected error from bedrock provider: %v", err)
|
||||
}
|
||||
|
||||
func TestCreateProviderFromConfig_BedrockWithEndpointURL(t *testing.T) {
|
||||
// Set dummy AWS env vars to make test deterministic
|
||||
t.Setenv("AWS_ACCESS_KEY_ID", "test-key")
|
||||
t.Setenv("AWS_SECRET_ACCESS_KEY", "test-secret")
|
||||
t.Setenv("AWS_REGION", "us-east-1") // Required when using endpoint URL
|
||||
t.Setenv("AWS_EC2_METADATA_DISABLED", "true")
|
||||
// Clear profile-related env vars to avoid loading shared config
|
||||
t.Setenv("AWS_PROFILE", "")
|
||||
t.Setenv("AWS_DEFAULT_PROFILE", "")
|
||||
t.Setenv("AWS_SDK_LOAD_CONFIG", "")
|
||||
t.Setenv("AWS_SHARED_CREDENTIALS_FILE", "")
|
||||
|
||||
cfg := &config.ModelConfig{
|
||||
ModelName: "bedrock-claude",
|
||||
Model: "bedrock/us.anthropic.claude-sonnet-4-20250514-v1:0",
|
||||
APIBase: "https://bedrock-runtime.us-east-1.amazonaws.com", // Full endpoint URL
|
||||
}
|
||||
|
||||
provider, modelID, err := CreateProviderFromConfig(cfg)
|
||||
if err == nil {
|
||||
// Provider created successfully (built with -tags bedrock)
|
||||
if provider == nil {
|
||||
t.Error("provider is nil on success")
|
||||
}
|
||||
if modelID != "us.anthropic.claude-sonnet-4-20250514-v1:0" {
|
||||
t.Errorf("modelID = %q, want %q", modelID, "us.anthropic.claude-sonnet-4-20250514-v1:0")
|
||||
}
|
||||
return
|
||||
}
|
||||
errMsg := err.Error()
|
||||
// When built without -tags bedrock, expect stub error
|
||||
if strings.Contains(errMsg, "build with -tags bedrock") {
|
||||
return // Expected stub error
|
||||
}
|
||||
// Unexpected error - fail the test
|
||||
t.Errorf("unexpected error from bedrock provider: %v", err)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user