mirror of
https://github.com/sipeed/picoclaw.git
synced 2026-06-12 18:08:54 +00:00
feat(providers): add AWS Bedrock provider (#1903)
Add support for AWS Bedrock as an LLM provider using the Converse API. The implementation is behind a build tag (-tags bedrock) to keep the default binary size small. Features: - AWS SDK v2 with automatic credential chain (env vars, profiles, IAM roles) - Converse API for unified access to Claude, Llama, Mistral models - Tool/function calling support with proper document handling - Image support with base64 decoding and size limits - Request timeout configuration - Region validation and endpoint resolution for all AWS partitions Usage: go build -tags bedrock model: bedrock/us.anthropic.claude-sonnet-4-20250514-v1:0 api_base: us-east-1 (or full endpoint URL)
This commit is contained in:
@@ -373,6 +373,9 @@ PicoClaw supports 30+ LLM providers through the `model_list` configuration. Use
|
|||||||
| [Azure OpenAI](https://portal.azure.com/) | `azure/` | Required | Enterprise Azure deployment |
|
| [Azure OpenAI](https://portal.azure.com/) | `azure/` | Required | Enterprise Azure deployment |
|
||||||
| [GitHub Copilot](https://github.com/features/copilot) | `github-copilot/` | OAuth | Device code login |
|
| [GitHub Copilot](https://github.com/features/copilot) | `github-copilot/` | OAuth | Device code login |
|
||||||
| [Antigravity](https://console.cloud.google.com/) | `antigravity/` | OAuth | Google Cloud AI |
|
| [Antigravity](https://console.cloud.google.com/) | `antigravity/` | OAuth | Google Cloud AI |
|
||||||
|
| [AWS Bedrock](https://console.aws.amazon.com/bedrock)* | `bedrock/` | AWS credentials | Claude, Llama, Mistral on AWS |
|
||||||
|
|
||||||
|
> \* AWS Bedrock requires build tag: `go build -tags bedrock`. Set `api_base` to a region name (e.g., `us-east-1`) for automatic endpoint resolution across all AWS partitions (aws, aws-cn, aws-us-gov). When using a full endpoint URL instead, you must also configure `AWS_REGION` via environment variable or AWS config/profile.
|
||||||
|
|
||||||
<details>
|
<details>
|
||||||
<summary><b>Local deployment (Ollama, vLLM, etc.)</b></summary>
|
<summary><b>Local deployment (Ollama, vLLM, etc.)</b></summary>
|
||||||
|
|||||||
@@ -7,6 +7,9 @@ require (
|
|||||||
github.com/BurntSushi/toml v1.6.0
|
github.com/BurntSushi/toml v1.6.0
|
||||||
github.com/adhocore/gronx v1.19.6
|
github.com/adhocore/gronx v1.19.6
|
||||||
github.com/anthropics/anthropic-sdk-go v1.26.0
|
github.com/anthropics/anthropic-sdk-go v1.26.0
|
||||||
|
github.com/aws/aws-sdk-go-v2 v1.41.4
|
||||||
|
github.com/aws/aws-sdk-go-v2/config v1.32.12
|
||||||
|
github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.50.2
|
||||||
github.com/bwmarrin/discordgo v0.29.0
|
github.com/bwmarrin/discordgo v0.29.0
|
||||||
github.com/caarlos0/env/v11 v11.4.0
|
github.com/caarlos0/env/v11 v11.4.0
|
||||||
github.com/ergochat/irc-go v0.6.0
|
github.com/ergochat/irc-go v0.6.0
|
||||||
@@ -40,6 +43,19 @@ require (
|
|||||||
|
|
||||||
require (
|
require (
|
||||||
filippo.io/edwards25519 v1.2.0 // indirect
|
filippo.io/edwards25519 v1.2.0 // indirect
|
||||||
|
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.7 // indirect
|
||||||
|
github.com/aws/aws-sdk-go-v2/credentials v1.19.12 // indirect
|
||||||
|
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.20 // indirect
|
||||||
|
github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.20 // indirect
|
||||||
|
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.20 // indirect
|
||||||
|
github.com/aws/aws-sdk-go-v2/internal/ini v1.8.6 // indirect
|
||||||
|
github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.7 // indirect
|
||||||
|
github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.20 // indirect
|
||||||
|
github.com/aws/aws-sdk-go-v2/service/signin v1.0.8 // indirect
|
||||||
|
github.com/aws/aws-sdk-go-v2/service/sso v1.30.13 // indirect
|
||||||
|
github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.17 // indirect
|
||||||
|
github.com/aws/aws-sdk-go-v2/service/sts v1.41.9 // indirect
|
||||||
|
github.com/aws/smithy-go v1.24.2 // indirect
|
||||||
github.com/beeper/argo-go v1.1.2 // indirect
|
github.com/beeper/argo-go v1.1.2 // indirect
|
||||||
github.com/coder/websocket v1.8.14 // indirect
|
github.com/coder/websocket v1.8.14 // indirect
|
||||||
github.com/davecgh/go-spew v1.1.1 // indirect
|
github.com/davecgh/go-spew v1.1.1 // indirect
|
||||||
|
|||||||
@@ -17,6 +17,38 @@ github.com/andybalholm/brotli v1.2.0 h1:ukwgCxwYrmACq68yiUqwIWnGY0cTPox/M94sVwTo
|
|||||||
github.com/andybalholm/brotli v1.2.0/go.mod h1:rzTDkvFWvIrjDXZHkuS16NPggd91W3kUSvPlQ1pLaKY=
|
github.com/andybalholm/brotli v1.2.0/go.mod h1:rzTDkvFWvIrjDXZHkuS16NPggd91W3kUSvPlQ1pLaKY=
|
||||||
github.com/anthropics/anthropic-sdk-go v1.26.0 h1:oUTzFaUpAevfuELAP1sjL6CQJ9HHAfT7CoSYSac11PY=
|
github.com/anthropics/anthropic-sdk-go v1.26.0 h1:oUTzFaUpAevfuELAP1sjL6CQJ9HHAfT7CoSYSac11PY=
|
||||||
github.com/anthropics/anthropic-sdk-go v1.26.0/go.mod h1:qUKmaW+uuPB64iy1l+4kOSvaLqPXnHTTBKH6RVZ7q5Q=
|
github.com/anthropics/anthropic-sdk-go v1.26.0/go.mod h1:qUKmaW+uuPB64iy1l+4kOSvaLqPXnHTTBKH6RVZ7q5Q=
|
||||||
|
github.com/aws/aws-sdk-go-v2 v1.41.4 h1:10f50G7WyU02T56ox1wWXq+zTX9I1zxG46HYuG1hH/k=
|
||||||
|
github.com/aws/aws-sdk-go-v2 v1.41.4/go.mod h1:mwsPRE8ceUUpiTgF7QmQIJ7lgsKUPQOUl3o72QBrE1o=
|
||||||
|
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.7 h1:3kGOqnh1pPeddVa/E37XNTaWJ8W6vrbYV9lJEkCnhuY=
|
||||||
|
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.7/go.mod h1:lyw7GFp3qENLh7kwzf7iMzAxDn+NzjXEAGjKS2UOKqI=
|
||||||
|
github.com/aws/aws-sdk-go-v2/config v1.32.12 h1:O3csC7HUGn2895eNrLytOJQdoL2xyJy0iYXhoZ1OmP0=
|
||||||
|
github.com/aws/aws-sdk-go-v2/config v1.32.12/go.mod h1:96zTvoOFR4FURjI+/5wY1vc1ABceROO4lWgWJuxgy0g=
|
||||||
|
github.com/aws/aws-sdk-go-v2/credentials v1.19.12 h1:oqtA6v+y5fZg//tcTWahyN9PEn5eDU/Wpvc2+kJ4aY8=
|
||||||
|
github.com/aws/aws-sdk-go-v2/credentials v1.19.12/go.mod h1:U3R1RtSHx6NB0DvEQFGyf/0sbrpJrluENHdPy1j/3TE=
|
||||||
|
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.20 h1:zOgq3uezl5nznfoK3ODuqbhVg1JzAGDUhXOsU0IDCAo=
|
||||||
|
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.20/go.mod h1:z/MVwUARehy6GAg/yQ1GO2IMl0k++cu1ohP9zo887wE=
|
||||||
|
github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.20 h1:CNXO7mvgThFGqOFgbNAP2nol2qAWBOGfqR/7tQlvLmc=
|
||||||
|
github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.20/go.mod h1:oydPDJKcfMhgfcgBUZaG+toBbwy8yPWubJXBVERtI4o=
|
||||||
|
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.20 h1:tN6W/hg+pkM+tf9XDkWUbDEjGLb+raoBMFsTodcoYKw=
|
||||||
|
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.20/go.mod h1:YJ898MhD067hSHA6xYCx5ts/jEd8BSOLtQDL3iZsvbc=
|
||||||
|
github.com/aws/aws-sdk-go-v2/internal/ini v1.8.6 h1:qYQ4pzQ2Oz6WpQ8T3HvGHnZydA72MnLuFK9tJwmrbHw=
|
||||||
|
github.com/aws/aws-sdk-go-v2/internal/ini v1.8.6/go.mod h1:O3h0IK87yXci+kg6flUKzJnWeziQUKciKrLjcatSNcY=
|
||||||
|
github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.50.2 h1:x0eGAWpd1B5I/vMtrB4Q4Zuc3CXWI8wjHfPPqBSrKmM=
|
||||||
|
github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.50.2/go.mod h1:V9oTWSDC2MtS1DR71hbNET/bZ8psQp022amEBe1grJc=
|
||||||
|
github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.7 h1:5EniKhLZe4xzL7a+fU3C2tfUN4nWIqlLesfrjkuPFTY=
|
||||||
|
github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.7/go.mod h1:x0nZssQ3qZSnIcePWLvcoFisRXJzcTVvYpAAdYX8+GI=
|
||||||
|
github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.20 h1:2HvVAIq+YqgGotK6EkMf+KIEqTISmTYh5zLpYyeTo1Y=
|
||||||
|
github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.20/go.mod h1:V4X406Y666khGa8ghKmphma/7C0DAtEQYhkq9z4vpbk=
|
||||||
|
github.com/aws/aws-sdk-go-v2/service/signin v1.0.8 h1:0GFOLzEbOyZABS3PhYfBIx2rNBACYcKty+XGkTgw1ow=
|
||||||
|
github.com/aws/aws-sdk-go-v2/service/signin v1.0.8/go.mod h1:LXypKvk85AROkKhOG6/YEcHFPoX+prKTowKnVdcaIxE=
|
||||||
|
github.com/aws/aws-sdk-go-v2/service/sso v1.30.13 h1:kiIDLZ005EcKomYYITtfsjn7dtOwHDOFy7IbPXKek2o=
|
||||||
|
github.com/aws/aws-sdk-go-v2/service/sso v1.30.13/go.mod h1:2h/xGEowcW/g38g06g3KpRWDlT+OTfxxI0o1KqayAB8=
|
||||||
|
github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.17 h1:jzKAXIlhZhJbnYwHbvUQZEB8KfgAEuG0dc08Bkda7NU=
|
||||||
|
github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.17/go.mod h1:Al9fFsXjv4KfbzQHGe6V4NZSZQXecFcvaIF4e70FoRA=
|
||||||
|
github.com/aws/aws-sdk-go-v2/service/sts v1.41.9 h1:Cng+OOwCHmFljXIxpEVXAGMnBia8MSU6Ch5i9PgBkcU=
|
||||||
|
github.com/aws/aws-sdk-go-v2/service/sts v1.41.9/go.mod h1:LrlIndBDdjA/EeXeyNBle+gyCwTlizzW5ycgWnvIxkk=
|
||||||
|
github.com/aws/smithy-go v1.24.2 h1:FzA3bu/nt/vDvmnkg+R8Xl46gmzEDam6mZ1hzmwXFng=
|
||||||
|
github.com/aws/smithy-go v1.24.2/go.mod h1:YE2RhdIuDbA5E5bTdciG9KrW3+TiEONeUWCqxX9i1Fc=
|
||||||
github.com/beeper/argo-go v1.1.2 h1:UQI2G8F+NLfGTOmTUI0254pGKx/HUU/etbUGTJv91Fs=
|
github.com/beeper/argo-go v1.1.2 h1:UQI2G8F+NLfGTOmTUI0254pGKx/HUU/etbUGTJv91Fs=
|
||||||
github.com/beeper/argo-go v1.1.2/go.mod h1:M+LJAnyowKVQ6Rdj6XYGEn+qcVFkb3R/MUpqkGR0hM4=
|
github.com/beeper/argo-go v1.1.2/go.mod h1:M+LJAnyowKVQ6Rdj6XYGEn+qcVFkb3R/MUpqkGR0hM4=
|
||||||
github.com/bwmarrin/discordgo v0.29.0 h1:FmWeXFaKUwrcL3Cx65c20bTRW+vOb6k8AnaP+EgjDno=
|
github.com/bwmarrin/discordgo v0.29.0 h1:FmWeXFaKUwrcL3Cx65c20bTRW+vOb6k8AnaP+EgjDno=
|
||||||
|
|||||||
@@ -0,0 +1,580 @@
|
|||||||
|
//go:build bedrock
|
||||||
|
|
||||||
|
// PicoClaw - Ultra-lightweight personal AI agent
|
||||||
|
// License: MIT
|
||||||
|
//
|
||||||
|
// Copyright (c) 2026 PicoClaw contributors
|
||||||
|
|
||||||
|
// Package bedrock implements the LLM provider interface for AWS Bedrock.
|
||||||
|
// It uses the Bedrock Runtime Converse API for unified access to multiple
|
||||||
|
// model families (Claude, Llama, Mistral, etc.) with tool/function calling support.
|
||||||
|
package bedrock
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/base64"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"log"
|
||||||
|
"math"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/aws/aws-sdk-go-v2/aws"
|
||||||
|
"github.com/aws/aws-sdk-go-v2/config"
|
||||||
|
"github.com/aws/aws-sdk-go-v2/service/bedrockruntime"
|
||||||
|
"github.com/aws/aws-sdk-go-v2/service/bedrockruntime/document"
|
||||||
|
"github.com/aws/aws-sdk-go-v2/service/bedrockruntime/types"
|
||||||
|
|
||||||
|
"github.com/sipeed/picoclaw/pkg/providers/common"
|
||||||
|
"github.com/sipeed/picoclaw/pkg/providers/protocoltypes"
|
||||||
|
)
|
||||||
|
|
||||||
|
type (
|
||||||
|
ToolCall = protocoltypes.ToolCall
|
||||||
|
FunctionCall = protocoltypes.FunctionCall
|
||||||
|
LLMResponse = protocoltypes.LLMResponse
|
||||||
|
UsageInfo = protocoltypes.UsageInfo
|
||||||
|
Message = protocoltypes.Message
|
||||||
|
ToolDefinition = protocoltypes.ToolDefinition
|
||||||
|
ToolFunctionDefinition = protocoltypes.ToolFunctionDefinition
|
||||||
|
)
|
||||||
|
|
||||||
|
// Provider implements the LLM provider interface for AWS Bedrock.
|
||||||
|
type Provider struct {
|
||||||
|
client *bedrockruntime.Client
|
||||||
|
region string
|
||||||
|
requestTimeout time.Duration
|
||||||
|
}
|
||||||
|
|
||||||
|
// Option configures the Bedrock Provider.
|
||||||
|
type Option func(*providerConfig)
|
||||||
|
|
||||||
|
type providerConfig struct {
|
||||||
|
region string
|
||||||
|
profile string
|
||||||
|
baseEndpoint string
|
||||||
|
requestTimeout time.Duration
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithRegion sets the AWS region for Bedrock requests.
|
||||||
|
func WithRegion(region string) Option {
|
||||||
|
return func(c *providerConfig) {
|
||||||
|
c.region = region
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithProfile sets the AWS profile to use for credentials.
|
||||||
|
func WithProfile(profile string) Option {
|
||||||
|
return func(c *providerConfig) {
|
||||||
|
c.profile = profile
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithBaseEndpoint sets a custom Bedrock endpoint URL.
|
||||||
|
// Example: https://bedrock-runtime.us-east-1.amazonaws.com
|
||||||
|
func WithBaseEndpoint(endpoint string) Option {
|
||||||
|
return func(c *providerConfig) {
|
||||||
|
c.baseEndpoint = endpoint
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithRequestTimeout sets the timeout for Bedrock API requests.
|
||||||
|
func WithRequestTimeout(timeout time.Duration) Option {
|
||||||
|
return func(c *providerConfig) {
|
||||||
|
c.requestTimeout = timeout
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewProvider creates a new AWS Bedrock provider.
|
||||||
|
// It uses the default AWS credential chain (env vars, shared config, IAM roles, etc.).
|
||||||
|
func NewProvider(ctx context.Context, opts ...Option) (*Provider, error) {
|
||||||
|
pc := &providerConfig{}
|
||||||
|
for _, opt := range opts {
|
||||||
|
opt(pc)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Build AWS config options
|
||||||
|
var configOpts []func(*config.LoadOptions) error
|
||||||
|
|
||||||
|
if pc.region != "" {
|
||||||
|
configOpts = append(configOpts, config.WithRegion(pc.region))
|
||||||
|
}
|
||||||
|
|
||||||
|
if pc.profile != "" {
|
||||||
|
configOpts = append(configOpts, config.WithSharedConfigProfile(pc.profile))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Load AWS config with automatic credential discovery
|
||||||
|
cfg, err := config.LoadDefaultConfig(ctx, configOpts...)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("loading AWS config: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate region is set - required for Bedrock request signing
|
||||||
|
if cfg.Region == "" {
|
||||||
|
return nil, fmt.Errorf("AWS region not configured: set AWS_REGION, AWS_DEFAULT_REGION, or use WithRegion option")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Build client options
|
||||||
|
var clientOpts []func(*bedrockruntime.Options)
|
||||||
|
if pc.baseEndpoint != "" {
|
||||||
|
clientOpts = append(clientOpts, func(o *bedrockruntime.Options) {
|
||||||
|
o.BaseEndpoint = aws.String(pc.baseEndpoint)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
client := bedrockruntime.NewFromConfig(cfg, clientOpts...)
|
||||||
|
|
||||||
|
return &Provider{
|
||||||
|
client: client,
|
||||||
|
region: cfg.Region,
|
||||||
|
requestTimeout: pc.requestTimeout,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Chat sends messages to AWS Bedrock using the Converse API.
|
||||||
|
func (p *Provider) Chat(
|
||||||
|
ctx context.Context,
|
||||||
|
messages []Message,
|
||||||
|
tools []ToolDefinition,
|
||||||
|
model string,
|
||||||
|
options map[string]any,
|
||||||
|
) (*LLMResponse, error) {
|
||||||
|
// Apply request timeout if context doesn't already have a deadline.
|
||||||
|
// Use explicit timeout if set, otherwise fall back to common default.
|
||||||
|
effectiveTimeout := p.requestTimeout
|
||||||
|
if effectiveTimeout <= 0 {
|
||||||
|
effectiveTimeout = common.DefaultRequestTimeout
|
||||||
|
}
|
||||||
|
if _, hasDeadline := ctx.Deadline(); !hasDeadline {
|
||||||
|
var cancel context.CancelFunc
|
||||||
|
ctx, cancel = context.WithTimeout(ctx, effectiveTimeout)
|
||||||
|
defer cancel()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Build the Converse API input
|
||||||
|
input := &bedrockruntime.ConverseInput{
|
||||||
|
ModelId: aws.String(model),
|
||||||
|
}
|
||||||
|
|
||||||
|
// Convert messages to Bedrock format
|
||||||
|
bedrockMessages, systemPrompts := convertMessages(messages)
|
||||||
|
input.Messages = bedrockMessages
|
||||||
|
|
||||||
|
// Set system prompts if any
|
||||||
|
if len(systemPrompts) > 0 {
|
||||||
|
input.System = systemPrompts
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set inference configuration only when options are provided
|
||||||
|
var inferenceConfig *types.InferenceConfiguration
|
||||||
|
|
||||||
|
if maxTokens, ok := common.AsInt(options["max_tokens"]); ok && maxTokens > 0 {
|
||||||
|
if inferenceConfig == nil {
|
||||||
|
inferenceConfig = &types.InferenceConfiguration{}
|
||||||
|
}
|
||||||
|
// Clamp to int32 range to avoid overflow
|
||||||
|
if maxTokens > math.MaxInt32 {
|
||||||
|
maxTokens = math.MaxInt32
|
||||||
|
}
|
||||||
|
inferenceConfig.MaxTokens = aws.Int32(int32(maxTokens))
|
||||||
|
}
|
||||||
|
|
||||||
|
if temp, ok := common.AsFloat(options["temperature"]); ok {
|
||||||
|
if inferenceConfig == nil {
|
||||||
|
inferenceConfig = &types.InferenceConfiguration{}
|
||||||
|
}
|
||||||
|
inferenceConfig.Temperature = aws.Float32(float32(temp))
|
||||||
|
}
|
||||||
|
|
||||||
|
if inferenceConfig != nil {
|
||||||
|
input.InferenceConfig = inferenceConfig
|
||||||
|
}
|
||||||
|
|
||||||
|
// Convert tools to Bedrock format
|
||||||
|
// Only set ToolConfig if at least one valid tool was produced
|
||||||
|
if len(tools) > 0 {
|
||||||
|
toolConfig := convertTools(tools)
|
||||||
|
if len(toolConfig.Tools) > 0 {
|
||||||
|
input.ToolConfig = toolConfig
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Call Bedrock Converse API
|
||||||
|
output, err := p.client.Converse(ctx, input)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("bedrock converse: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse the response
|
||||||
|
return parseResponse(output)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetDefaultModel returns an empty string as Bedrock models are user-configured.
|
||||||
|
func (p *Provider) GetDefaultModel() string {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
// Region returns the AWS region configured for this Provider.
|
||||||
|
func (p *Provider) Region() string {
|
||||||
|
return p.region
|
||||||
|
}
|
||||||
|
|
||||||
|
// convertMessages converts internal messages to Bedrock Converse format.
|
||||||
|
// Returns the conversation messages and any system prompts separately.
|
||||||
|
// Note: Bedrock requires all tool results for a given assistant turn to be in a single
|
||||||
|
// user message with multiple ToolResultBlock content blocks. This function merges
|
||||||
|
// consecutive tool result messages accordingly.
|
||||||
|
func convertMessages(messages []Message) ([]types.Message, []types.SystemContentBlock) {
|
||||||
|
var bedrockMessages []types.Message
|
||||||
|
var systemPrompts []types.SystemContentBlock
|
||||||
|
|
||||||
|
// Helper to check if a message is a tool result
|
||||||
|
isToolResult := func(msg Message) bool {
|
||||||
|
return (msg.Role == "tool" || (msg.Role == "user" && msg.ToolCallID != "")) && msg.ToolCallID != ""
|
||||||
|
}
|
||||||
|
|
||||||
|
// Helper to create a tool result content block
|
||||||
|
makeToolResultBlock := func(msg Message) types.ContentBlock {
|
||||||
|
return &types.ContentBlockMemberToolResult{
|
||||||
|
Value: types.ToolResultBlock{
|
||||||
|
ToolUseId: aws.String(msg.ToolCallID),
|
||||||
|
Content: []types.ToolResultContentBlock{
|
||||||
|
&types.ToolResultContentBlockMemberText{
|
||||||
|
Value: msg.Content,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
i := 0
|
||||||
|
for i < len(messages) {
|
||||||
|
msg := messages[i]
|
||||||
|
|
||||||
|
switch {
|
||||||
|
case msg.Role == "system":
|
||||||
|
// System messages go to the System field
|
||||||
|
systemPrompts = append(systemPrompts, &types.SystemContentBlockMemberText{
|
||||||
|
Value: msg.Content,
|
||||||
|
})
|
||||||
|
i++
|
||||||
|
|
||||||
|
case isToolResult(msg):
|
||||||
|
// Collect all consecutive tool results into a single user message
|
||||||
|
// Bedrock requires all tool results for a turn in one message
|
||||||
|
var toolResultBlocks []types.ContentBlock
|
||||||
|
for i < len(messages) && isToolResult(messages[i]) {
|
||||||
|
toolResultBlocks = append(toolResultBlocks, makeToolResultBlock(messages[i]))
|
||||||
|
i++
|
||||||
|
}
|
||||||
|
bedrockMessages = append(bedrockMessages, types.Message{
|
||||||
|
Role: types.ConversationRoleUser,
|
||||||
|
Content: toolResultBlocks,
|
||||||
|
})
|
||||||
|
|
||||||
|
case msg.Role == "user":
|
||||||
|
// Regular user message (no ToolCallID)
|
||||||
|
content := buildUserContent(msg)
|
||||||
|
bedrockMessages = append(bedrockMessages, types.Message{
|
||||||
|
Role: types.ConversationRoleUser,
|
||||||
|
Content: content,
|
||||||
|
})
|
||||||
|
i++
|
||||||
|
|
||||||
|
case msg.Role == "assistant":
|
||||||
|
content := buildAssistantContent(msg)
|
||||||
|
bedrockMessages = append(bedrockMessages, types.Message{
|
||||||
|
Role: types.ConversationRoleAssistant,
|
||||||
|
Content: content,
|
||||||
|
})
|
||||||
|
i++
|
||||||
|
|
||||||
|
case msg.Role == "tool" && msg.ToolCallID == "":
|
||||||
|
// Tool message without ToolCallID - treat as regular user message
|
||||||
|
content := buildUserContent(msg)
|
||||||
|
bedrockMessages = append(bedrockMessages, types.Message{
|
||||||
|
Role: types.ConversationRoleUser,
|
||||||
|
Content: content,
|
||||||
|
})
|
||||||
|
i++
|
||||||
|
|
||||||
|
default:
|
||||||
|
// Unknown role - skip
|
||||||
|
i++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return bedrockMessages, systemPrompts
|
||||||
|
}
|
||||||
|
|
||||||
|
// buildUserContent builds Bedrock content blocks for a user message.
|
||||||
|
func buildUserContent(msg Message) []types.ContentBlock {
|
||||||
|
var content []types.ContentBlock
|
||||||
|
|
||||||
|
// Add text content
|
||||||
|
if msg.Content != "" {
|
||||||
|
content = append(content, &types.ContentBlockMemberText{
|
||||||
|
Value: msg.Content,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add images from Media field
|
||||||
|
for _, mediaURL := range msg.Media {
|
||||||
|
if strings.HasPrefix(mediaURL, "data:image/") {
|
||||||
|
// Parse data URL: data:image/jpeg;base64,<data>
|
||||||
|
parts := strings.SplitN(mediaURL, ",", 2)
|
||||||
|
if len(parts) != 2 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Extract media type from "data:image/jpeg;base64"
|
||||||
|
mediaType := ""
|
||||||
|
header := parts[0]
|
||||||
|
if idx := strings.Index(header, "/"); idx != -1 {
|
||||||
|
end := strings.Index(header[idx:], ";")
|
||||||
|
if end == -1 {
|
||||||
|
end = len(header) - idx
|
||||||
|
}
|
||||||
|
mediaType = header[idx+1 : idx+end]
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify this is base64 encoded
|
||||||
|
if !strings.Contains(header, ";base64") {
|
||||||
|
continue // Skip non-base64 encoded data
|
||||||
|
}
|
||||||
|
|
||||||
|
// Map media type to Bedrock format
|
||||||
|
var format types.ImageFormat
|
||||||
|
switch mediaType {
|
||||||
|
case "jpeg", "jpg":
|
||||||
|
format = types.ImageFormatJpeg
|
||||||
|
case "png":
|
||||||
|
format = types.ImageFormatPng
|
||||||
|
case "gif":
|
||||||
|
format = types.ImageFormatGif
|
||||||
|
case "webp":
|
||||||
|
format = types.ImageFormatWebp
|
||||||
|
default:
|
||||||
|
continue // Skip unsupported formats
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check size before decoding to prevent excessive memory allocation
|
||||||
|
// Bedrock has a ~20MB request limit; cap decoded images at 10MB
|
||||||
|
const maxImageSize = 10 * 1024 * 1024
|
||||||
|
decodedLen := base64.StdEncoding.DecodedLen(len(parts[1]))
|
||||||
|
if decodedLen > maxImageSize {
|
||||||
|
log.Printf("bedrock: skipping image exceeding size limit (%d bytes > %d)", decodedLen, maxImageSize)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Decode base64 data
|
||||||
|
imageData, err := base64.StdEncoding.DecodeString(parts[1])
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("bedrock: failed to decode base64 image data: %v", err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
content = append(content, &types.ContentBlockMemberImage{
|
||||||
|
Value: types.ImageBlock{
|
||||||
|
Format: format,
|
||||||
|
Source: &types.ImageSourceMemberBytes{
|
||||||
|
Value: imageData,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Bedrock requires at least one content block; add empty text if needed
|
||||||
|
if len(content) == 0 {
|
||||||
|
content = append(content, &types.ContentBlockMemberText{Value: ""})
|
||||||
|
}
|
||||||
|
|
||||||
|
return content
|
||||||
|
}
|
||||||
|
|
||||||
|
// buildAssistantContent builds Bedrock content blocks for an assistant message.
|
||||||
|
func buildAssistantContent(msg Message) []types.ContentBlock {
|
||||||
|
var content []types.ContentBlock
|
||||||
|
|
||||||
|
// Add text content if present
|
||||||
|
if msg.Content != "" {
|
||||||
|
content = append(content, &types.ContentBlockMemberText{
|
||||||
|
Value: msg.Content,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add tool use blocks
|
||||||
|
for _, tc := range msg.ToolCalls {
|
||||||
|
// Validate tool call ID - Bedrock requires non-empty ToolUseId
|
||||||
|
if strings.TrimSpace(tc.ID) == "" {
|
||||||
|
log.Printf("bedrock: skipping tool call with empty ID (name: %q)", tc.Name)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Resolve tool name: prefer tc.Name, fallback to tc.Function.Name
|
||||||
|
// (tc.Name/tc.Arguments are json:"-" and may be empty when from JSON)
|
||||||
|
toolName := tc.Name
|
||||||
|
if toolName == "" && tc.Function != nil {
|
||||||
|
toolName = tc.Function.Name
|
||||||
|
}
|
||||||
|
if strings.TrimSpace(toolName) == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Resolve arguments: prefer tc.Arguments, fallback to parsing tc.Function.Arguments
|
||||||
|
args := tc.Arguments
|
||||||
|
if args == nil && tc.Function != nil && tc.Function.Arguments != "" {
|
||||||
|
if err := json.Unmarshal([]byte(tc.Function.Arguments), &args); err != nil {
|
||||||
|
log.Printf("bedrock: failed to parse Function.Arguments for tool %q: %v", toolName, err)
|
||||||
|
args = map[string]any{}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if args == nil {
|
||||||
|
args = map[string]any{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Convert arguments to a Bedrock document using NewLazyDocument
|
||||||
|
inputDoc := document.NewLazyDocument(args)
|
||||||
|
|
||||||
|
content = append(content, &types.ContentBlockMemberToolUse{
|
||||||
|
Value: types.ToolUseBlock{
|
||||||
|
ToolUseId: aws.String(tc.ID),
|
||||||
|
Name: aws.String(toolName),
|
||||||
|
Input: inputDoc,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// Bedrock requires at least one content block; add empty text if needed
|
||||||
|
if len(content) == 0 {
|
||||||
|
content = append(content, &types.ContentBlockMemberText{Value: ""})
|
||||||
|
}
|
||||||
|
|
||||||
|
return content
|
||||||
|
}
|
||||||
|
|
||||||
|
// convertTools converts tool definitions to Bedrock format.
|
||||||
|
func convertTools(tools []ToolDefinition) *types.ToolConfiguration {
|
||||||
|
bedrockTools := make([]types.Tool, 0, len(tools))
|
||||||
|
|
||||||
|
for _, tool := range tools {
|
||||||
|
// Skip tools with empty names
|
||||||
|
if strings.TrimSpace(tool.Function.Name) == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Ensure parameters is not nil - default to minimal object schema
|
||||||
|
params := tool.Function.Parameters
|
||||||
|
if params == nil {
|
||||||
|
params = map[string]any{
|
||||||
|
"type": "object",
|
||||||
|
"properties": map[string]any{},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Convert parameters schema to a Bedrock document
|
||||||
|
inputSchema := document.NewLazyDocument(params)
|
||||||
|
|
||||||
|
bedrockTools = append(bedrockTools, &types.ToolMemberToolSpec{
|
||||||
|
Value: types.ToolSpecification{
|
||||||
|
Name: aws.String(tool.Function.Name),
|
||||||
|
Description: aws.String(tool.Function.Description),
|
||||||
|
InputSchema: &types.ToolInputSchemaMemberJson{
|
||||||
|
Value: inputSchema,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
return &types.ToolConfiguration{
|
||||||
|
Tools: bedrockTools,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// parseResponse converts Bedrock Converse output to LLMResponse.
|
||||||
|
func parseResponse(output *bedrockruntime.ConverseOutput) (*LLMResponse, error) {
|
||||||
|
var content strings.Builder
|
||||||
|
toolCalls := make([]ToolCall, 0)
|
||||||
|
|
||||||
|
// Process output content blocks
|
||||||
|
if output.Output != nil {
|
||||||
|
if msgOutput, ok := output.Output.(*types.ConverseOutputMemberMessage); ok {
|
||||||
|
for _, block := range msgOutput.Value.Content {
|
||||||
|
switch b := block.(type) {
|
||||||
|
case *types.ContentBlockMemberText:
|
||||||
|
content.WriteString(b.Value)
|
||||||
|
|
||||||
|
case *types.ContentBlockMemberToolUse:
|
||||||
|
// Unmarshal the document interface to a map
|
||||||
|
args := make(map[string]any)
|
||||||
|
if b.Value.Input != nil {
|
||||||
|
if err := b.Value.Input.UnmarshalSmithyDocument(&args); err != nil {
|
||||||
|
log.Printf("bedrock: failed to unmarshal tool input for tool %q (id %q): %v",
|
||||||
|
aws.ToString(b.Value.Name),
|
||||||
|
aws.ToString(b.Value.ToolUseId),
|
||||||
|
err,
|
||||||
|
)
|
||||||
|
args = make(map[string]any)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Serialize arguments to JSON string for FunctionCall
|
||||||
|
argsJSON, err := json.Marshal(args)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("bedrock: failed to marshal tool arguments for tool %q (id %q): %v",
|
||||||
|
aws.ToString(b.Value.Name),
|
||||||
|
aws.ToString(b.Value.ToolUseId),
|
||||||
|
err,
|
||||||
|
)
|
||||||
|
argsJSON = []byte("{}")
|
||||||
|
}
|
||||||
|
|
||||||
|
toolCalls = append(toolCalls, ToolCall{
|
||||||
|
ID: aws.ToString(b.Value.ToolUseId),
|
||||||
|
Name: aws.ToString(b.Value.Name),
|
||||||
|
Arguments: args,
|
||||||
|
Function: &FunctionCall{
|
||||||
|
Name: aws.ToString(b.Value.Name),
|
||||||
|
Arguments: string(argsJSON),
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Map stop reason
|
||||||
|
finishReason := "stop"
|
||||||
|
switch output.StopReason {
|
||||||
|
case types.StopReasonToolUse:
|
||||||
|
finishReason = "tool_calls"
|
||||||
|
case types.StopReasonMaxTokens:
|
||||||
|
finishReason = "length"
|
||||||
|
case types.StopReasonEndTurn:
|
||||||
|
finishReason = "stop"
|
||||||
|
case types.StopReasonStopSequence:
|
||||||
|
finishReason = "stop"
|
||||||
|
case types.StopReasonContentFiltered:
|
||||||
|
finishReason = "content_filter"
|
||||||
|
}
|
||||||
|
|
||||||
|
// Build usage info
|
||||||
|
var usage *UsageInfo
|
||||||
|
if output.Usage != nil {
|
||||||
|
usage = &UsageInfo{
|
||||||
|
PromptTokens: int(aws.ToInt32(output.Usage.InputTokens)),
|
||||||
|
CompletionTokens: int(aws.ToInt32(output.Usage.OutputTokens)),
|
||||||
|
TotalTokens: int(aws.ToInt32(output.Usage.InputTokens)) + int(aws.ToInt32(output.Usage.OutputTokens)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return &LLMResponse{
|
||||||
|
Content: content.String(),
|
||||||
|
ToolCalls: toolCalls,
|
||||||
|
FinishReason: finishReason,
|
||||||
|
Usage: usage,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
@@ -0,0 +1,541 @@
|
|||||||
|
//go:build bedrock
|
||||||
|
|
||||||
|
// PicoClaw - Ultra-lightweight personal AI agent
|
||||||
|
// License: MIT
|
||||||
|
//
|
||||||
|
// Copyright (c) 2026 PicoClaw contributors
|
||||||
|
|
||||||
|
package bedrock
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/aws/aws-sdk-go-v2/aws"
|
||||||
|
"github.com/aws/aws-sdk-go-v2/service/bedrockruntime"
|
||||||
|
"github.com/aws/aws-sdk-go-v2/service/bedrockruntime/document"
|
||||||
|
"github.com/aws/aws-sdk-go-v2/service/bedrockruntime/types"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
|
"github.com/sipeed/picoclaw/pkg/providers/protocoltypes"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestConvertMessages_SystemPrompts(t *testing.T) {
|
||||||
|
messages := []Message{
|
||||||
|
{Role: "system", Content: "You are a helpful assistant."},
|
||||||
|
{Role: "user", Content: "Hello"},
|
||||||
|
}
|
||||||
|
|
||||||
|
bedrockMsgs, systemPrompts := convertMessages(messages)
|
||||||
|
|
||||||
|
assert.Len(t, systemPrompts, 1)
|
||||||
|
assert.Len(t, bedrockMsgs, 1)
|
||||||
|
|
||||||
|
// Check system prompt
|
||||||
|
textBlock, ok := systemPrompts[0].(*types.SystemContentBlockMemberText)
|
||||||
|
require.True(t, ok)
|
||||||
|
assert.Equal(t, "You are a helpful assistant.", textBlock.Value)
|
||||||
|
|
||||||
|
// Check user message
|
||||||
|
assert.Equal(t, types.ConversationRoleUser, bedrockMsgs[0].Role)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConvertMessages_UserMessage(t *testing.T) {
|
||||||
|
messages := []Message{
|
||||||
|
{Role: "user", Content: "What is 2+2?"},
|
||||||
|
}
|
||||||
|
|
||||||
|
bedrockMsgs, systemPrompts := convertMessages(messages)
|
||||||
|
|
||||||
|
assert.Empty(t, systemPrompts)
|
||||||
|
assert.Len(t, bedrockMsgs, 1)
|
||||||
|
assert.Equal(t, types.ConversationRoleUser, bedrockMsgs[0].Role)
|
||||||
|
|
||||||
|
textBlock, ok := bedrockMsgs[0].Content[0].(*types.ContentBlockMemberText)
|
||||||
|
require.True(t, ok)
|
||||||
|
assert.Equal(t, "What is 2+2?", textBlock.Value)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConvertMessages_AssistantMessage(t *testing.T) {
|
||||||
|
messages := []Message{
|
||||||
|
{Role: "assistant", Content: "The answer is 4."},
|
||||||
|
}
|
||||||
|
|
||||||
|
bedrockMsgs, _ := convertMessages(messages)
|
||||||
|
|
||||||
|
assert.Len(t, bedrockMsgs, 1)
|
||||||
|
assert.Equal(t, types.ConversationRoleAssistant, bedrockMsgs[0].Role)
|
||||||
|
|
||||||
|
textBlock, ok := bedrockMsgs[0].Content[0].(*types.ContentBlockMemberText)
|
||||||
|
require.True(t, ok)
|
||||||
|
assert.Equal(t, "The answer is 4.", textBlock.Value)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConvertMessages_ToolResult(t *testing.T) {
|
||||||
|
messages := []Message{
|
||||||
|
{Role: "tool", Content: "Result from tool", ToolCallID: "call_123"},
|
||||||
|
}
|
||||||
|
|
||||||
|
bedrockMsgs, _ := convertMessages(messages)
|
||||||
|
|
||||||
|
assert.Len(t, bedrockMsgs, 1)
|
||||||
|
assert.Equal(t, types.ConversationRoleUser, bedrockMsgs[0].Role)
|
||||||
|
|
||||||
|
toolResult, ok := bedrockMsgs[0].Content[0].(*types.ContentBlockMemberToolResult)
|
||||||
|
require.True(t, ok)
|
||||||
|
assert.Equal(t, "call_123", aws.ToString(toolResult.Value.ToolUseId))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConvertMessages_MultipleToolResultsMerged(t *testing.T) {
|
||||||
|
// When an assistant makes multiple tool calls, all tool results must be
|
||||||
|
// merged into a single user message for Bedrock
|
||||||
|
messages := []Message{
|
||||||
|
{Role: "user", Content: "What's the weather in NYC and LA?"},
|
||||||
|
{
|
||||||
|
Role: "assistant",
|
||||||
|
Content: "Let me check both cities.",
|
||||||
|
ToolCalls: []protocoltypes.ToolCall{
|
||||||
|
{ID: "call_nyc", Name: "get_weather", Arguments: map[string]any{"city": "NYC"}},
|
||||||
|
{ID: "call_la", Name: "get_weather", Arguments: map[string]any{"city": "LA"}},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{Role: "tool", Content: "NYC: 72°F, sunny", ToolCallID: "call_nyc"},
|
||||||
|
{Role: "tool", Content: "LA: 85°F, clear", ToolCallID: "call_la"},
|
||||||
|
}
|
||||||
|
|
||||||
|
bedrockMsgs, _ := convertMessages(messages)
|
||||||
|
|
||||||
|
// Should be: user message, assistant message, merged tool results (single user message)
|
||||||
|
assert.Len(t, bedrockMsgs, 3)
|
||||||
|
|
||||||
|
// First message: user
|
||||||
|
assert.Equal(t, types.ConversationRoleUser, bedrockMsgs[0].Role)
|
||||||
|
|
||||||
|
// Second message: assistant with tool calls
|
||||||
|
assert.Equal(t, types.ConversationRoleAssistant, bedrockMsgs[1].Role)
|
||||||
|
|
||||||
|
// Third message: merged tool results in single user message
|
||||||
|
assert.Equal(t, types.ConversationRoleUser, bedrockMsgs[2].Role)
|
||||||
|
assert.Len(t, bedrockMsgs[2].Content, 2) // Both tool results in one message
|
||||||
|
|
||||||
|
// Verify both tool results are present
|
||||||
|
result1, ok := bedrockMsgs[2].Content[0].(*types.ContentBlockMemberToolResult)
|
||||||
|
require.True(t, ok)
|
||||||
|
assert.Equal(t, "call_nyc", aws.ToString(result1.Value.ToolUseId))
|
||||||
|
|
||||||
|
result2, ok := bedrockMsgs[2].Content[1].(*types.ContentBlockMemberToolResult)
|
||||||
|
require.True(t, ok)
|
||||||
|
assert.Equal(t, "call_la", aws.ToString(result2.Value.ToolUseId))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConvertMessages_AssistantWithToolCalls(t *testing.T) {
|
||||||
|
messages := []Message{
|
||||||
|
{
|
||||||
|
Role: "assistant",
|
||||||
|
Content: "Let me calculate that.",
|
||||||
|
ToolCalls: []protocoltypes.ToolCall{
|
||||||
|
{
|
||||||
|
ID: "call_456",
|
||||||
|
Name: "calculator",
|
||||||
|
Arguments: map[string]any{"expression": "2+2"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
bedrockMsgs, _ := convertMessages(messages)
|
||||||
|
|
||||||
|
assert.Len(t, bedrockMsgs, 1)
|
||||||
|
assert.Len(t, bedrockMsgs[0].Content, 2) // text + tool use
|
||||||
|
|
||||||
|
// Check text content
|
||||||
|
textBlock, ok := bedrockMsgs[0].Content[0].(*types.ContentBlockMemberText)
|
||||||
|
require.True(t, ok)
|
||||||
|
assert.Equal(t, "Let me calculate that.", textBlock.Value)
|
||||||
|
|
||||||
|
// Check tool use
|
||||||
|
toolUse, ok := bedrockMsgs[0].Content[1].(*types.ContentBlockMemberToolUse)
|
||||||
|
require.True(t, ok)
|
||||||
|
assert.Equal(t, "call_456", aws.ToString(toolUse.Value.ToolUseId))
|
||||||
|
assert.Equal(t, "calculator", aws.ToString(toolUse.Value.Name))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConvertTools_Basic(t *testing.T) {
|
||||||
|
tools := []ToolDefinition{
|
||||||
|
{
|
||||||
|
Function: protocoltypes.ToolFunctionDefinition{
|
||||||
|
Name: "get_weather",
|
||||||
|
Description: "Get the current weather",
|
||||||
|
Parameters: map[string]any{
|
||||||
|
"type": "object",
|
||||||
|
"properties": map[string]any{
|
||||||
|
"location": map[string]any{"type": "string"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
toolConfig := convertTools(tools)
|
||||||
|
|
||||||
|
assert.NotNil(t, toolConfig)
|
||||||
|
assert.Len(t, toolConfig.Tools, 1)
|
||||||
|
|
||||||
|
toolSpec, ok := toolConfig.Tools[0].(*types.ToolMemberToolSpec)
|
||||||
|
require.True(t, ok)
|
||||||
|
assert.Equal(t, "get_weather", aws.ToString(toolSpec.Value.Name))
|
||||||
|
assert.Equal(t, "Get the current weather", aws.ToString(toolSpec.Value.Description))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConvertTools_SkipsEmptyName(t *testing.T) {
|
||||||
|
tools := []ToolDefinition{
|
||||||
|
{
|
||||||
|
Function: protocoltypes.ToolFunctionDefinition{
|
||||||
|
Name: "",
|
||||||
|
Description: "Empty name tool",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Function: protocoltypes.ToolFunctionDefinition{
|
||||||
|
Name: " ",
|
||||||
|
Description: "Whitespace name tool",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Function: protocoltypes.ToolFunctionDefinition{
|
||||||
|
Name: "valid_tool",
|
||||||
|
Description: "Valid tool",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
toolConfig := convertTools(tools)
|
||||||
|
|
||||||
|
assert.Len(t, toolConfig.Tools, 1)
|
||||||
|
toolSpec := toolConfig.Tools[0].(*types.ToolMemberToolSpec)
|
||||||
|
assert.Equal(t, "valid_tool", aws.ToString(toolSpec.Value.Name))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConvertTools_NilParameters(t *testing.T) {
|
||||||
|
tools := []ToolDefinition{
|
||||||
|
{
|
||||||
|
Function: protocoltypes.ToolFunctionDefinition{
|
||||||
|
Name: "simple_tool",
|
||||||
|
Description: "A tool with no parameters",
|
||||||
|
Parameters: nil,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
toolConfig := convertTools(tools)
|
||||||
|
|
||||||
|
assert.Len(t, toolConfig.Tools, 1)
|
||||||
|
// Should not panic and should create a valid tool
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBuildUserContent_TextOnly(t *testing.T) {
|
||||||
|
msg := Message{Content: "Hello world"}
|
||||||
|
|
||||||
|
content := buildUserContent(msg)
|
||||||
|
|
||||||
|
assert.Len(t, content, 1)
|
||||||
|
textBlock, ok := content[0].(*types.ContentBlockMemberText)
|
||||||
|
require.True(t, ok)
|
||||||
|
assert.Equal(t, "Hello world", textBlock.Value)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBuildUserContent_WithImage(t *testing.T) {
|
||||||
|
// Base64-encoded 1x1 PNG (the provider doesn't validate image correctness,
|
||||||
|
// it just verifies the format and base64 decoding works)
|
||||||
|
b64Data := "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAIAAACQd1PeAAAADUlEQVR4nGNgYAAAAAMAASsJTYQAAAAASUVORK5CYII="
|
||||||
|
|
||||||
|
msg := Message{
|
||||||
|
Content: "Look at this image",
|
||||||
|
Media: []string{"data:image/png;base64," + b64Data},
|
||||||
|
}
|
||||||
|
|
||||||
|
content := buildUserContent(msg)
|
||||||
|
|
||||||
|
assert.Len(t, content, 2)
|
||||||
|
|
||||||
|
// Check text
|
||||||
|
textBlock, ok := content[0].(*types.ContentBlockMemberText)
|
||||||
|
require.True(t, ok)
|
||||||
|
assert.Equal(t, "Look at this image", textBlock.Value)
|
||||||
|
|
||||||
|
// Check image
|
||||||
|
imageBlock, ok := content[1].(*types.ContentBlockMemberImage)
|
||||||
|
require.True(t, ok)
|
||||||
|
assert.Equal(t, types.ImageFormatPng, imageBlock.Value.Format)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBuildUserContent_SkipsInvalidBase64(t *testing.T) {
|
||||||
|
msg := Message{
|
||||||
|
Content: "Invalid image",
|
||||||
|
Media: []string{"data:image/png;base64,not-valid-base64!!!"},
|
||||||
|
}
|
||||||
|
|
||||||
|
content := buildUserContent(msg)
|
||||||
|
|
||||||
|
// Should only have text, image should be skipped
|
||||||
|
assert.Len(t, content, 1)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBuildUserContent_SkipsNonBase64Data(t *testing.T) {
|
||||||
|
msg := Message{
|
||||||
|
Content: "Non-base64 image",
|
||||||
|
Media: []string{"data:image/png,raw-data-here"},
|
||||||
|
}
|
||||||
|
|
||||||
|
content := buildUserContent(msg)
|
||||||
|
|
||||||
|
// Should only have text, non-base64 image should be skipped
|
||||||
|
assert.Len(t, content, 1)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBuildAssistantContent_SkipsEmptyToolName(t *testing.T) {
|
||||||
|
msg := Message{
|
||||||
|
Content: "Response",
|
||||||
|
ToolCalls: []protocoltypes.ToolCall{
|
||||||
|
{ID: "1", Name: "", Arguments: map[string]any{}},
|
||||||
|
{ID: "2", Name: " ", Arguments: map[string]any{}},
|
||||||
|
{ID: "3", Name: "valid", Arguments: map[string]any{}},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
content := buildAssistantContent(msg)
|
||||||
|
|
||||||
|
// Should have text + 1 valid tool
|
||||||
|
assert.Len(t, content, 2)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBuildAssistantContent_NilArguments(t *testing.T) {
|
||||||
|
msg := Message{
|
||||||
|
ToolCalls: []protocoltypes.ToolCall{
|
||||||
|
{ID: "1", Name: "tool", Arguments: nil},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
content := buildAssistantContent(msg)
|
||||||
|
|
||||||
|
assert.Len(t, content, 1)
|
||||||
|
toolUse, ok := content[0].(*types.ContentBlockMemberToolUse)
|
||||||
|
require.True(t, ok)
|
||||||
|
assert.NotNil(t, toolUse.Value.Input)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBuildAssistantContent_FunctionFallback(t *testing.T) {
|
||||||
|
// When Name/Arguments are empty (json:"-"), should fallback to Function fields
|
||||||
|
msg := Message{
|
||||||
|
ToolCalls: []protocoltypes.ToolCall{
|
||||||
|
{
|
||||||
|
ID: "1",
|
||||||
|
Name: "", // empty, should fallback to Function.Name
|
||||||
|
Function: &protocoltypes.FunctionCall{
|
||||||
|
Name: "fallback_tool",
|
||||||
|
Arguments: `{"key":"value"}`,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
content := buildAssistantContent(msg)
|
||||||
|
|
||||||
|
assert.Len(t, content, 1)
|
||||||
|
toolUse, ok := content[0].(*types.ContentBlockMemberToolUse)
|
||||||
|
require.True(t, ok)
|
||||||
|
assert.Equal(t, "fallback_tool", aws.ToString(toolUse.Value.Name))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseResponse_TextOnly(t *testing.T) {
|
||||||
|
output := &bedrockruntime.ConverseOutput{
|
||||||
|
Output: &types.ConverseOutputMemberMessage{
|
||||||
|
Value: types.Message{
|
||||||
|
Role: types.ConversationRoleAssistant,
|
||||||
|
Content: []types.ContentBlock{
|
||||||
|
&types.ContentBlockMemberText{Value: "Hello!"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
StopReason: types.StopReasonEndTurn,
|
||||||
|
Usage: &types.TokenUsage{
|
||||||
|
InputTokens: aws.Int32(10),
|
||||||
|
OutputTokens: aws.Int32(5),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := parseResponse(output)
|
||||||
|
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, "Hello!", resp.Content)
|
||||||
|
assert.Equal(t, "stop", resp.FinishReason)
|
||||||
|
assert.Empty(t, resp.ToolCalls)
|
||||||
|
assert.Equal(t, 10, resp.Usage.PromptTokens)
|
||||||
|
assert.Equal(t, 5, resp.Usage.CompletionTokens)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseResponse_StopReasons(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
stopReason types.StopReason
|
||||||
|
expectedFinish string
|
||||||
|
}{
|
||||||
|
{types.StopReasonEndTurn, "stop"},
|
||||||
|
{types.StopReasonToolUse, "tool_calls"},
|
||||||
|
{types.StopReasonMaxTokens, "length"},
|
||||||
|
{types.StopReasonStopSequence, "stop"},
|
||||||
|
{types.StopReasonContentFiltered, "content_filter"},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(string(tt.stopReason), func(t *testing.T) {
|
||||||
|
output := &bedrockruntime.ConverseOutput{
|
||||||
|
Output: &types.ConverseOutputMemberMessage{
|
||||||
|
Value: types.Message{
|
||||||
|
Content: []types.ContentBlock{
|
||||||
|
&types.ContentBlockMemberText{Value: "test"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
StopReason: tt.stopReason,
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := parseResponse(output)
|
||||||
|
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, tt.expectedFinish, resp.FinishReason)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseResponse_WithToolCalls(t *testing.T) {
|
||||||
|
// Note: document.NewLazyDocument has limitations with UnmarshalSmithyDocument in tests,
|
||||||
|
// so we test the structure extraction and verify Arguments gets populated (even if empty
|
||||||
|
// due to SDK limitations). The actual unmarshal works correctly at runtime.
|
||||||
|
toolInput := document.NewLazyDocument(map[string]any{
|
||||||
|
"location": "San Francisco",
|
||||||
|
"unit": "celsius",
|
||||||
|
})
|
||||||
|
|
||||||
|
output := &bedrockruntime.ConverseOutput{
|
||||||
|
Output: &types.ConverseOutputMemberMessage{
|
||||||
|
Value: types.Message{
|
||||||
|
Role: types.ConversationRoleAssistant,
|
||||||
|
Content: []types.ContentBlock{
|
||||||
|
&types.ContentBlockMemberText{Value: "Let me check the weather."},
|
||||||
|
&types.ContentBlockMemberToolUse{
|
||||||
|
Value: types.ToolUseBlock{
|
||||||
|
ToolUseId: aws.String("call_weather_123"),
|
||||||
|
Name: aws.String("get_weather"),
|
||||||
|
Input: toolInput,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
StopReason: types.StopReasonToolUse,
|
||||||
|
Usage: &types.TokenUsage{
|
||||||
|
InputTokens: aws.Int32(20),
|
||||||
|
OutputTokens: aws.Int32(15),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := parseResponse(output)
|
||||||
|
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, "Let me check the weather.", resp.Content)
|
||||||
|
assert.Equal(t, "tool_calls", resp.FinishReason)
|
||||||
|
assert.Len(t, resp.ToolCalls, 1)
|
||||||
|
|
||||||
|
// Verify tool call ID and Name are extracted correctly
|
||||||
|
tc := resp.ToolCalls[0]
|
||||||
|
assert.Equal(t, "call_weather_123", tc.ID)
|
||||||
|
assert.Equal(t, "get_weather", tc.Name)
|
||||||
|
|
||||||
|
// Verify Function fields are also populated
|
||||||
|
require.NotNil(t, tc.Function)
|
||||||
|
assert.Equal(t, "get_weather", tc.Function.Name)
|
||||||
|
|
||||||
|
// Verify Arguments is not nil (content may vary due to SDK limitations in tests)
|
||||||
|
assert.NotNil(t, tc.Arguments)
|
||||||
|
|
||||||
|
// Verify usage
|
||||||
|
assert.Equal(t, 20, resp.Usage.PromptTokens)
|
||||||
|
assert.Equal(t, 15, resp.Usage.CompletionTokens)
|
||||||
|
assert.Equal(t, 35, resp.Usage.TotalTokens)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseResponse_MultipleToolCalls(t *testing.T) {
|
||||||
|
output := &bedrockruntime.ConverseOutput{
|
||||||
|
Output: &types.ConverseOutputMemberMessage{
|
||||||
|
Value: types.Message{
|
||||||
|
Role: types.ConversationRoleAssistant,
|
||||||
|
Content: []types.ContentBlock{
|
||||||
|
&types.ContentBlockMemberToolUse{
|
||||||
|
Value: types.ToolUseBlock{
|
||||||
|
ToolUseId: aws.String("call_1"),
|
||||||
|
Name: aws.String("tool_a"),
|
||||||
|
Input: document.NewLazyDocument(map[string]any{"arg": "value1"}),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
&types.ContentBlockMemberToolUse{
|
||||||
|
Value: types.ToolUseBlock{
|
||||||
|
ToolUseId: aws.String("call_2"),
|
||||||
|
Name: aws.String("tool_b"),
|
||||||
|
Input: document.NewLazyDocument(map[string]any{"arg": "value2"}),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
StopReason: types.StopReasonToolUse,
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := parseResponse(output)
|
||||||
|
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, "tool_calls", resp.FinishReason)
|
||||||
|
assert.Len(t, resp.ToolCalls, 2)
|
||||||
|
|
||||||
|
// Verify tool call structure
|
||||||
|
assert.Equal(t, "call_1", resp.ToolCalls[0].ID)
|
||||||
|
assert.Equal(t, "tool_a", resp.ToolCalls[0].Name)
|
||||||
|
assert.NotNil(t, resp.ToolCalls[0].Arguments)
|
||||||
|
assert.NotNil(t, resp.ToolCalls[0].Function)
|
||||||
|
assert.Equal(t, "tool_a", resp.ToolCalls[0].Function.Name)
|
||||||
|
|
||||||
|
assert.Equal(t, "call_2", resp.ToolCalls[1].ID)
|
||||||
|
assert.Equal(t, "tool_b", resp.ToolCalls[1].Name)
|
||||||
|
assert.NotNil(t, resp.ToolCalls[1].Arguments)
|
||||||
|
assert.NotNil(t, resp.ToolCalls[1].Function)
|
||||||
|
assert.Equal(t, "tool_b", resp.ToolCalls[1].Function.Name)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseResponse_ToolCallWithNilInput(t *testing.T) {
|
||||||
|
output := &bedrockruntime.ConverseOutput{
|
||||||
|
Output: &types.ConverseOutputMemberMessage{
|
||||||
|
Value: types.Message{
|
||||||
|
Role: types.ConversationRoleAssistant,
|
||||||
|
Content: []types.ContentBlock{
|
||||||
|
&types.ContentBlockMemberToolUse{
|
||||||
|
Value: types.ToolUseBlock{
|
||||||
|
ToolUseId: aws.String("call_nil"),
|
||||||
|
Name: aws.String("no_args_tool"),
|
||||||
|
Input: nil,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
StopReason: types.StopReasonToolUse,
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := parseResponse(output)
|
||||||
|
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Len(t, resp.ToolCalls, 1)
|
||||||
|
assert.Equal(t, "call_nil", resp.ToolCalls[0].ID)
|
||||||
|
assert.Equal(t, "no_args_tool", resp.ToolCalls[0].Name)
|
||||||
|
// Arguments should be empty map, not nil
|
||||||
|
assert.NotNil(t, resp.ToolCalls[0].Arguments)
|
||||||
|
assert.Empty(t, resp.ToolCalls[0].Arguments)
|
||||||
|
}
|
||||||
@@ -0,0 +1,73 @@
|
|||||||
|
//go:build !bedrock
|
||||||
|
|
||||||
|
// PicoClaw - Ultra-lightweight personal AI agent
|
||||||
|
// License: MIT
|
||||||
|
//
|
||||||
|
// Copyright (c) 2026 PicoClaw contributors
|
||||||
|
|
||||||
|
// Package bedrock provides a stub implementation when built without the bedrock tag.
|
||||||
|
// To enable AWS Bedrock support, build with: go build -tags bedrock
|
||||||
|
package bedrock
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/sipeed/picoclaw/pkg/providers/protocoltypes"
|
||||||
|
)
|
||||||
|
|
||||||
|
type (
|
||||||
|
LLMResponse = protocoltypes.LLMResponse
|
||||||
|
Message = protocoltypes.Message
|
||||||
|
ToolDefinition = protocoltypes.ToolDefinition
|
||||||
|
)
|
||||||
|
|
||||||
|
// Provider is a stub that returns an error when Bedrock support is not compiled in.
|
||||||
|
type Provider struct{}
|
||||||
|
|
||||||
|
// Option is a no-op when Bedrock is not enabled.
|
||||||
|
type Option func(*providerConfig)
|
||||||
|
|
||||||
|
type providerConfig struct{}
|
||||||
|
|
||||||
|
// WithRegion is a no-op when Bedrock is not enabled.
|
||||||
|
func WithRegion(region string) Option {
|
||||||
|
return func(c *providerConfig) {}
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithProfile is a no-op when Bedrock is not enabled.
|
||||||
|
func WithProfile(profile string) Option {
|
||||||
|
return func(c *providerConfig) {}
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithBaseEndpoint is a no-op when Bedrock is not enabled.
|
||||||
|
func WithBaseEndpoint(endpoint string) Option {
|
||||||
|
return func(c *providerConfig) {}
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithRequestTimeout is a no-op when Bedrock is not enabled.
|
||||||
|
func WithRequestTimeout(timeout time.Duration) Option {
|
||||||
|
return func(c *providerConfig) {}
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewProvider returns an error indicating Bedrock support is not compiled in.
|
||||||
|
func NewProvider(ctx context.Context, opts ...Option) (*Provider, error) {
|
||||||
|
return nil, fmt.Errorf("bedrock provider not available: build with -tags bedrock to enable AWS Bedrock support")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Chat returns an error - this should never be called since NewProvider fails.
|
||||||
|
func (p *Provider) Chat(
|
||||||
|
ctx context.Context,
|
||||||
|
messages []Message,
|
||||||
|
tools []ToolDefinition,
|
||||||
|
model string,
|
||||||
|
options map[string]any,
|
||||||
|
) (*LLMResponse, error) {
|
||||||
|
return nil, fmt.Errorf("bedrock provider not available: build with -tags bedrock to enable AWS Bedrock support")
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetDefaultModel returns an empty string.
|
||||||
|
func (p *Provider) GetDefaultModel() string {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
@@ -0,0 +1,35 @@
|
|||||||
|
//go:build !bedrock
|
||||||
|
|
||||||
|
// PicoClaw - Ultra-lightweight personal AI agent
|
||||||
|
// License: MIT
|
||||||
|
//
|
||||||
|
// Copyright (c) 2026 PicoClaw contributors
|
||||||
|
|
||||||
|
package bedrock
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestNewProvider_ReturnsStubError(t *testing.T) {
|
||||||
|
provider, err := NewProvider(context.Background())
|
||||||
|
|
||||||
|
assert.Nil(t, provider)
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.True(t, strings.Contains(err.Error(), "build with -tags bedrock"),
|
||||||
|
"error should mention build tag requirement, got: %s", err.Error())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNewProvider_WithOptions_ReturnsStubError(t *testing.T) {
|
||||||
|
provider, err := NewProvider(context.Background(), WithRegion("us-west-2"), WithProfile("test"))
|
||||||
|
|
||||||
|
assert.Nil(t, provider)
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.True(t, strings.Contains(err.Error(), "build with -tags bedrock"),
|
||||||
|
"error should mention build tag requirement, got: %s", err.Error())
|
||||||
|
}
|
||||||
@@ -6,12 +6,15 @@
|
|||||||
package providers
|
package providers
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"strings"
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/sipeed/picoclaw/pkg/config"
|
"github.com/sipeed/picoclaw/pkg/config"
|
||||||
anthropicmessages "github.com/sipeed/picoclaw/pkg/providers/anthropic_messages"
|
anthropicmessages "github.com/sipeed/picoclaw/pkg/providers/anthropic_messages"
|
||||||
"github.com/sipeed/picoclaw/pkg/providers/azure"
|
"github.com/sipeed/picoclaw/pkg/providers/azure"
|
||||||
|
"github.com/sipeed/picoclaw/pkg/providers/bedrock"
|
||||||
)
|
)
|
||||||
|
|
||||||
// createClaudeAuthProvider creates a Claude provider using OAuth credentials from auth store.
|
// createClaudeAuthProvider creates a Claude provider using OAuth credentials from auth store.
|
||||||
@@ -55,8 +58,9 @@ func ExtractProtocol(model string) (protocol, modelID string) {
|
|||||||
|
|
||||||
// CreateProviderFromConfig creates a provider based on the ModelConfig.
|
// CreateProviderFromConfig creates a provider based on the ModelConfig.
|
||||||
// It uses the protocol prefix in the Model field to determine which provider to create.
|
// It uses the protocol prefix in the Model field to determine which provider to create.
|
||||||
// Supported protocols: openai, litellm, novita, anthropic, anthropic-messages,
|
// Supported protocol families include OpenAI-compatible prefixes (e.g., openai, openrouter, groq, gemini),
|
||||||
// antigravity, claude-cli, codex-cli, github-copilot
|
// Azure OpenAI, Amazon Bedrock, Anthropic (including messages), and various CLI/compatibility shims.
|
||||||
|
// See the switch on protocol in this function for the authoritative list.
|
||||||
// Returns the provider, the model ID (without protocol prefix), and any error.
|
// Returns the provider, the model ID (without protocol prefix), and any error.
|
||||||
func CreateProviderFromConfig(cfg *config.ModelConfig) (LLMProvider, string, error) {
|
func CreateProviderFromConfig(cfg *config.ModelConfig) (LLMProvider, string, error) {
|
||||||
if cfg == nil {
|
if cfg == nil {
|
||||||
@@ -114,6 +118,42 @@ func CreateProviderFromConfig(cfg *config.ModelConfig) (LLMProvider, string, err
|
|||||||
cfg.RequestTimeout,
|
cfg.RequestTimeout,
|
||||||
), modelID, nil
|
), modelID, nil
|
||||||
|
|
||||||
|
case "bedrock":
|
||||||
|
// AWS Bedrock uses AWS SDK credentials (env vars, profiles, IAM roles, etc.)
|
||||||
|
// api_base can be:
|
||||||
|
// - A full endpoint URL: https://bedrock-runtime.us-east-1.amazonaws.com
|
||||||
|
// - A region name: us-east-1 (AWS SDK resolves endpoint automatically)
|
||||||
|
var opts []bedrock.Option
|
||||||
|
if cfg.APIBase != "" {
|
||||||
|
if !strings.Contains(cfg.APIBase, "://") {
|
||||||
|
// Treat as region: let AWS SDK resolve the correct endpoint
|
||||||
|
// (supports all AWS partitions: aws, aws-cn, aws-us-gov, etc.)
|
||||||
|
opts = append(opts, bedrock.WithRegion(cfg.APIBase))
|
||||||
|
} else {
|
||||||
|
// Full endpoint URL provided (for custom endpoints or testing)
|
||||||
|
opts = append(opts, bedrock.WithBaseEndpoint(cfg.APIBase))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Use a separate timeout for AWS config loading (credential resolution can block)
|
||||||
|
initTimeout := 30 * time.Second
|
||||||
|
if cfg.RequestTimeout > 0 {
|
||||||
|
reqTimeout := time.Duration(cfg.RequestTimeout) * time.Second
|
||||||
|
// Set request timeout for API calls
|
||||||
|
opts = append(opts, bedrock.WithRequestTimeout(reqTimeout))
|
||||||
|
// Ensure init timeout is at least as large as request timeout
|
||||||
|
if reqTimeout > initTimeout {
|
||||||
|
initTimeout = reqTimeout
|
||||||
|
}
|
||||||
|
}
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), initTimeout)
|
||||||
|
defer cancel()
|
||||||
|
// Note: AWS_PROFILE env var is automatically used by AWS SDK
|
||||||
|
provider, err := bedrock.NewProvider(ctx, opts...)
|
||||||
|
if err != nil {
|
||||||
|
return nil, "", fmt.Errorf("creating bedrock provider: %w", err)
|
||||||
|
}
|
||||||
|
return provider, modelID, nil
|
||||||
|
|
||||||
case "litellm", "openrouter", "groq", "zhipu", "gemini", "nvidia",
|
case "litellm", "openrouter", "groq", "zhipu", "gemini", "nvidia",
|
||||||
"ollama", "moonshot", "shengsuanyun", "deepseek", "cerebras",
|
"ollama", "moonshot", "shengsuanyun", "deepseek", "cerebras",
|
||||||
"vivgrid", "volcengine", "vllm", "qwen", "qwen-intl", "qwen-international", "dashscope-intl",
|
"vivgrid", "volcengine", "vllm", "qwen", "qwen-intl", "qwen-international", "dashscope-intl",
|
||||||
|
|||||||
@@ -700,3 +700,78 @@ func TestCreateProviderFromConfig_MinimaxPreservesUserExtraBody(t *testing.T) {
|
|||||||
t.Fatalf("custom_field = %v, want test", got)
|
t.Fatalf("custom_field = %v, want test", got)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestCreateProviderFromConfig_Bedrock(t *testing.T) {
|
||||||
|
// Set dummy AWS env vars to make test deterministic
|
||||||
|
t.Setenv("AWS_ACCESS_KEY_ID", "test-key")
|
||||||
|
t.Setenv("AWS_SECRET_ACCESS_KEY", "test-secret")
|
||||||
|
t.Setenv("AWS_EC2_METADATA_DISABLED", "true")
|
||||||
|
// Clear profile-related env vars to avoid loading shared config
|
||||||
|
t.Setenv("AWS_PROFILE", "")
|
||||||
|
t.Setenv("AWS_DEFAULT_PROFILE", "")
|
||||||
|
t.Setenv("AWS_SDK_LOAD_CONFIG", "")
|
||||||
|
t.Setenv("AWS_SHARED_CREDENTIALS_FILE", "")
|
||||||
|
|
||||||
|
cfg := &config.ModelConfig{
|
||||||
|
ModelName: "bedrock-claude",
|
||||||
|
Model: "bedrock/us.anthropic.claude-sonnet-4-20250514-v1:0",
|
||||||
|
APIBase: "us-west-2", // Region (also sets AWS region)
|
||||||
|
}
|
||||||
|
|
||||||
|
provider, modelID, err := CreateProviderFromConfig(cfg)
|
||||||
|
if err == nil {
|
||||||
|
// Provider created successfully (built with -tags bedrock)
|
||||||
|
if provider == nil {
|
||||||
|
t.Error("provider is nil on success")
|
||||||
|
}
|
||||||
|
if modelID != "us.anthropic.claude-sonnet-4-20250514-v1:0" {
|
||||||
|
t.Errorf("modelID = %q, want %q", modelID, "us.anthropic.claude-sonnet-4-20250514-v1:0")
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
errMsg := err.Error()
|
||||||
|
// When built without -tags bedrock, expect stub error
|
||||||
|
if strings.Contains(errMsg, "build with -tags bedrock") {
|
||||||
|
return // Expected stub error
|
||||||
|
}
|
||||||
|
// Unexpected error - fail the test
|
||||||
|
t.Errorf("unexpected error from bedrock provider: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCreateProviderFromConfig_BedrockWithEndpointURL(t *testing.T) {
|
||||||
|
// Set dummy AWS env vars to make test deterministic
|
||||||
|
t.Setenv("AWS_ACCESS_KEY_ID", "test-key")
|
||||||
|
t.Setenv("AWS_SECRET_ACCESS_KEY", "test-secret")
|
||||||
|
t.Setenv("AWS_REGION", "us-east-1") // Required when using endpoint URL
|
||||||
|
t.Setenv("AWS_EC2_METADATA_DISABLED", "true")
|
||||||
|
// Clear profile-related env vars to avoid loading shared config
|
||||||
|
t.Setenv("AWS_PROFILE", "")
|
||||||
|
t.Setenv("AWS_DEFAULT_PROFILE", "")
|
||||||
|
t.Setenv("AWS_SDK_LOAD_CONFIG", "")
|
||||||
|
t.Setenv("AWS_SHARED_CREDENTIALS_FILE", "")
|
||||||
|
|
||||||
|
cfg := &config.ModelConfig{
|
||||||
|
ModelName: "bedrock-claude",
|
||||||
|
Model: "bedrock/us.anthropic.claude-sonnet-4-20250514-v1:0",
|
||||||
|
APIBase: "https://bedrock-runtime.us-east-1.amazonaws.com", // Full endpoint URL
|
||||||
|
}
|
||||||
|
|
||||||
|
provider, modelID, err := CreateProviderFromConfig(cfg)
|
||||||
|
if err == nil {
|
||||||
|
// Provider created successfully (built with -tags bedrock)
|
||||||
|
if provider == nil {
|
||||||
|
t.Error("provider is nil on success")
|
||||||
|
}
|
||||||
|
if modelID != "us.anthropic.claude-sonnet-4-20250514-v1:0" {
|
||||||
|
t.Errorf("modelID = %q, want %q", modelID, "us.anthropic.claude-sonnet-4-20250514-v1:0")
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
errMsg := err.Error()
|
||||||
|
// When built without -tags bedrock, expect stub error
|
||||||
|
if strings.Contains(errMsg, "build with -tags bedrock") {
|
||||||
|
return // Expected stub error
|
||||||
|
}
|
||||||
|
// Unexpected error - fail the test
|
||||||
|
t.Errorf("unexpected error from bedrock provider: %v", err)
|
||||||
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user