diff --git a/README.md b/README.md index e25366ef8..72d38103c 100644 --- a/README.md +++ b/README.md @@ -373,6 +373,9 @@ PicoClaw supports 30+ LLM providers through the `model_list` configuration. Use | [Azure OpenAI](https://portal.azure.com/) | `azure/` | Required | Enterprise Azure deployment | | [GitHub Copilot](https://github.com/features/copilot) | `github-copilot/` | OAuth | Device code login | | [Antigravity](https://console.cloud.google.com/) | `antigravity/` | OAuth | Google Cloud AI | +| [AWS Bedrock](https://console.aws.amazon.com/bedrock)* | `bedrock/` | AWS credentials | Claude, Llama, Mistral on AWS | + +> \* AWS Bedrock requires build tag: `go build -tags bedrock`. Set `api_base` to a region name (e.g., `us-east-1`) for automatic endpoint resolution across all AWS partitions (aws, aws-cn, aws-us-gov). When using a full endpoint URL instead, you must also configure `AWS_REGION` via environment variable or AWS config/profile.
Local deployment (Ollama, vLLM, etc.) diff --git a/go.mod b/go.mod index e4b6f37fd..bce41d0d3 100644 --- a/go.mod +++ b/go.mod @@ -7,6 +7,9 @@ require ( github.com/BurntSushi/toml v1.6.0 github.com/adhocore/gronx v1.19.6 github.com/anthropics/anthropic-sdk-go v1.26.0 + github.com/aws/aws-sdk-go-v2 v1.41.4 + github.com/aws/aws-sdk-go-v2/config v1.32.12 + github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.50.2 github.com/bwmarrin/discordgo v0.29.0 github.com/caarlos0/env/v11 v11.4.0 github.com/ergochat/irc-go v0.6.0 @@ -40,6 +43,19 @@ require ( require ( filippo.io/edwards25519 v1.2.0 // indirect + github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.7 // indirect + github.com/aws/aws-sdk-go-v2/credentials v1.19.12 // indirect + github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.20 // indirect + github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.20 // indirect + github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.20 // indirect + github.com/aws/aws-sdk-go-v2/internal/ini v1.8.6 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.7 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.20 // indirect + github.com/aws/aws-sdk-go-v2/service/signin v1.0.8 // indirect + github.com/aws/aws-sdk-go-v2/service/sso v1.30.13 // indirect + github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.17 // indirect + github.com/aws/aws-sdk-go-v2/service/sts v1.41.9 // indirect + github.com/aws/smithy-go v1.24.2 // indirect github.com/beeper/argo-go v1.1.2 // indirect github.com/coder/websocket v1.8.14 // indirect github.com/davecgh/go-spew v1.1.1 // indirect diff --git a/go.sum b/go.sum index f24b997d4..87117bc98 100644 --- a/go.sum +++ b/go.sum @@ -17,6 +17,38 @@ github.com/andybalholm/brotli v1.2.0 h1:ukwgCxwYrmACq68yiUqwIWnGY0cTPox/M94sVwTo github.com/andybalholm/brotli v1.2.0/go.mod h1:rzTDkvFWvIrjDXZHkuS16NPggd91W3kUSvPlQ1pLaKY= github.com/anthropics/anthropic-sdk-go v1.26.0 h1:oUTzFaUpAevfuELAP1sjL6CQJ9HHAfT7CoSYSac11PY= github.com/anthropics/anthropic-sdk-go v1.26.0/go.mod h1:qUKmaW+uuPB64iy1l+4kOSvaLqPXnHTTBKH6RVZ7q5Q= +github.com/aws/aws-sdk-go-v2 v1.41.4 h1:10f50G7WyU02T56ox1wWXq+zTX9I1zxG46HYuG1hH/k= +github.com/aws/aws-sdk-go-v2 v1.41.4/go.mod h1:mwsPRE8ceUUpiTgF7QmQIJ7lgsKUPQOUl3o72QBrE1o= +github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.7 h1:3kGOqnh1pPeddVa/E37XNTaWJ8W6vrbYV9lJEkCnhuY= +github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.7/go.mod h1:lyw7GFp3qENLh7kwzf7iMzAxDn+NzjXEAGjKS2UOKqI= +github.com/aws/aws-sdk-go-v2/config v1.32.12 h1:O3csC7HUGn2895eNrLytOJQdoL2xyJy0iYXhoZ1OmP0= +github.com/aws/aws-sdk-go-v2/config v1.32.12/go.mod h1:96zTvoOFR4FURjI+/5wY1vc1ABceROO4lWgWJuxgy0g= +github.com/aws/aws-sdk-go-v2/credentials v1.19.12 h1:oqtA6v+y5fZg//tcTWahyN9PEn5eDU/Wpvc2+kJ4aY8= +github.com/aws/aws-sdk-go-v2/credentials v1.19.12/go.mod h1:U3R1RtSHx6NB0DvEQFGyf/0sbrpJrluENHdPy1j/3TE= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.20 h1:zOgq3uezl5nznfoK3ODuqbhVg1JzAGDUhXOsU0IDCAo= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.20/go.mod h1:z/MVwUARehy6GAg/yQ1GO2IMl0k++cu1ohP9zo887wE= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.20 h1:CNXO7mvgThFGqOFgbNAP2nol2qAWBOGfqR/7tQlvLmc= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.20/go.mod h1:oydPDJKcfMhgfcgBUZaG+toBbwy8yPWubJXBVERtI4o= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.20 h1:tN6W/hg+pkM+tf9XDkWUbDEjGLb+raoBMFsTodcoYKw= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.20/go.mod h1:YJ898MhD067hSHA6xYCx5ts/jEd8BSOLtQDL3iZsvbc= +github.com/aws/aws-sdk-go-v2/internal/ini v1.8.6 h1:qYQ4pzQ2Oz6WpQ8T3HvGHnZydA72MnLuFK9tJwmrbHw= +github.com/aws/aws-sdk-go-v2/internal/ini v1.8.6/go.mod h1:O3h0IK87yXci+kg6flUKzJnWeziQUKciKrLjcatSNcY= +github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.50.2 h1:x0eGAWpd1B5I/vMtrB4Q4Zuc3CXWI8wjHfPPqBSrKmM= +github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.50.2/go.mod h1:V9oTWSDC2MtS1DR71hbNET/bZ8psQp022amEBe1grJc= +github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.7 h1:5EniKhLZe4xzL7a+fU3C2tfUN4nWIqlLesfrjkuPFTY= +github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.7/go.mod h1:x0nZssQ3qZSnIcePWLvcoFisRXJzcTVvYpAAdYX8+GI= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.20 h1:2HvVAIq+YqgGotK6EkMf+KIEqTISmTYh5zLpYyeTo1Y= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.20/go.mod h1:V4X406Y666khGa8ghKmphma/7C0DAtEQYhkq9z4vpbk= +github.com/aws/aws-sdk-go-v2/service/signin v1.0.8 h1:0GFOLzEbOyZABS3PhYfBIx2rNBACYcKty+XGkTgw1ow= +github.com/aws/aws-sdk-go-v2/service/signin v1.0.8/go.mod h1:LXypKvk85AROkKhOG6/YEcHFPoX+prKTowKnVdcaIxE= +github.com/aws/aws-sdk-go-v2/service/sso v1.30.13 h1:kiIDLZ005EcKomYYITtfsjn7dtOwHDOFy7IbPXKek2o= +github.com/aws/aws-sdk-go-v2/service/sso v1.30.13/go.mod h1:2h/xGEowcW/g38g06g3KpRWDlT+OTfxxI0o1KqayAB8= +github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.17 h1:jzKAXIlhZhJbnYwHbvUQZEB8KfgAEuG0dc08Bkda7NU= +github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.17/go.mod h1:Al9fFsXjv4KfbzQHGe6V4NZSZQXecFcvaIF4e70FoRA= +github.com/aws/aws-sdk-go-v2/service/sts v1.41.9 h1:Cng+OOwCHmFljXIxpEVXAGMnBia8MSU6Ch5i9PgBkcU= +github.com/aws/aws-sdk-go-v2/service/sts v1.41.9/go.mod h1:LrlIndBDdjA/EeXeyNBle+gyCwTlizzW5ycgWnvIxkk= +github.com/aws/smithy-go v1.24.2 h1:FzA3bu/nt/vDvmnkg+R8Xl46gmzEDam6mZ1hzmwXFng= +github.com/aws/smithy-go v1.24.2/go.mod h1:YE2RhdIuDbA5E5bTdciG9KrW3+TiEONeUWCqxX9i1Fc= github.com/beeper/argo-go v1.1.2 h1:UQI2G8F+NLfGTOmTUI0254pGKx/HUU/etbUGTJv91Fs= github.com/beeper/argo-go v1.1.2/go.mod h1:M+LJAnyowKVQ6Rdj6XYGEn+qcVFkb3R/MUpqkGR0hM4= github.com/bwmarrin/discordgo v0.29.0 h1:FmWeXFaKUwrcL3Cx65c20bTRW+vOb6k8AnaP+EgjDno= diff --git a/pkg/providers/bedrock/provider_bedrock.go b/pkg/providers/bedrock/provider_bedrock.go new file mode 100644 index 000000000..838beab70 --- /dev/null +++ b/pkg/providers/bedrock/provider_bedrock.go @@ -0,0 +1,580 @@ +//go:build bedrock + +// PicoClaw - Ultra-lightweight personal AI agent +// License: MIT +// +// Copyright (c) 2026 PicoClaw contributors + +// Package bedrock implements the LLM provider interface for AWS Bedrock. +// It uses the Bedrock Runtime Converse API for unified access to multiple +// model families (Claude, Llama, Mistral, etc.) with tool/function calling support. +package bedrock + +import ( + "context" + "encoding/base64" + "encoding/json" + "fmt" + "log" + "math" + "strings" + "time" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/config" + "github.com/aws/aws-sdk-go-v2/service/bedrockruntime" + "github.com/aws/aws-sdk-go-v2/service/bedrockruntime/document" + "github.com/aws/aws-sdk-go-v2/service/bedrockruntime/types" + + "github.com/sipeed/picoclaw/pkg/providers/common" + "github.com/sipeed/picoclaw/pkg/providers/protocoltypes" +) + +type ( + ToolCall = protocoltypes.ToolCall + FunctionCall = protocoltypes.FunctionCall + LLMResponse = protocoltypes.LLMResponse + UsageInfo = protocoltypes.UsageInfo + Message = protocoltypes.Message + ToolDefinition = protocoltypes.ToolDefinition + ToolFunctionDefinition = protocoltypes.ToolFunctionDefinition +) + +// Provider implements the LLM provider interface for AWS Bedrock. +type Provider struct { + client *bedrockruntime.Client + region string + requestTimeout time.Duration +} + +// Option configures the Bedrock Provider. +type Option func(*providerConfig) + +type providerConfig struct { + region string + profile string + baseEndpoint string + requestTimeout time.Duration +} + +// WithRegion sets the AWS region for Bedrock requests. +func WithRegion(region string) Option { + return func(c *providerConfig) { + c.region = region + } +} + +// WithProfile sets the AWS profile to use for credentials. +func WithProfile(profile string) Option { + return func(c *providerConfig) { + c.profile = profile + } +} + +// WithBaseEndpoint sets a custom Bedrock endpoint URL. +// Example: https://bedrock-runtime.us-east-1.amazonaws.com +func WithBaseEndpoint(endpoint string) Option { + return func(c *providerConfig) { + c.baseEndpoint = endpoint + } +} + +// WithRequestTimeout sets the timeout for Bedrock API requests. +func WithRequestTimeout(timeout time.Duration) Option { + return func(c *providerConfig) { + c.requestTimeout = timeout + } +} + +// NewProvider creates a new AWS Bedrock provider. +// It uses the default AWS credential chain (env vars, shared config, IAM roles, etc.). +func NewProvider(ctx context.Context, opts ...Option) (*Provider, error) { + pc := &providerConfig{} + for _, opt := range opts { + opt(pc) + } + + // Build AWS config options + var configOpts []func(*config.LoadOptions) error + + if pc.region != "" { + configOpts = append(configOpts, config.WithRegion(pc.region)) + } + + if pc.profile != "" { + configOpts = append(configOpts, config.WithSharedConfigProfile(pc.profile)) + } + + // Load AWS config with automatic credential discovery + cfg, err := config.LoadDefaultConfig(ctx, configOpts...) + if err != nil { + return nil, fmt.Errorf("loading AWS config: %w", err) + } + + // Validate region is set - required for Bedrock request signing + if cfg.Region == "" { + return nil, fmt.Errorf("AWS region not configured: set AWS_REGION, AWS_DEFAULT_REGION, or use WithRegion option") + } + + // Build client options + var clientOpts []func(*bedrockruntime.Options) + if pc.baseEndpoint != "" { + clientOpts = append(clientOpts, func(o *bedrockruntime.Options) { + o.BaseEndpoint = aws.String(pc.baseEndpoint) + }) + } + + client := bedrockruntime.NewFromConfig(cfg, clientOpts...) + + return &Provider{ + client: client, + region: cfg.Region, + requestTimeout: pc.requestTimeout, + }, nil +} + +// Chat sends messages to AWS Bedrock using the Converse API. +func (p *Provider) Chat( + ctx context.Context, + messages []Message, + tools []ToolDefinition, + model string, + options map[string]any, +) (*LLMResponse, error) { + // Apply request timeout if context doesn't already have a deadline. + // Use explicit timeout if set, otherwise fall back to common default. + effectiveTimeout := p.requestTimeout + if effectiveTimeout <= 0 { + effectiveTimeout = common.DefaultRequestTimeout + } + if _, hasDeadline := ctx.Deadline(); !hasDeadline { + var cancel context.CancelFunc + ctx, cancel = context.WithTimeout(ctx, effectiveTimeout) + defer cancel() + } + + // Build the Converse API input + input := &bedrockruntime.ConverseInput{ + ModelId: aws.String(model), + } + + // Convert messages to Bedrock format + bedrockMessages, systemPrompts := convertMessages(messages) + input.Messages = bedrockMessages + + // Set system prompts if any + if len(systemPrompts) > 0 { + input.System = systemPrompts + } + + // Set inference configuration only when options are provided + var inferenceConfig *types.InferenceConfiguration + + if maxTokens, ok := common.AsInt(options["max_tokens"]); ok && maxTokens > 0 { + if inferenceConfig == nil { + inferenceConfig = &types.InferenceConfiguration{} + } + // Clamp to int32 range to avoid overflow + if maxTokens > math.MaxInt32 { + maxTokens = math.MaxInt32 + } + inferenceConfig.MaxTokens = aws.Int32(int32(maxTokens)) + } + + if temp, ok := common.AsFloat(options["temperature"]); ok { + if inferenceConfig == nil { + inferenceConfig = &types.InferenceConfiguration{} + } + inferenceConfig.Temperature = aws.Float32(float32(temp)) + } + + if inferenceConfig != nil { + input.InferenceConfig = inferenceConfig + } + + // Convert tools to Bedrock format + // Only set ToolConfig if at least one valid tool was produced + if len(tools) > 0 { + toolConfig := convertTools(tools) + if len(toolConfig.Tools) > 0 { + input.ToolConfig = toolConfig + } + } + + // Call Bedrock Converse API + output, err := p.client.Converse(ctx, input) + if err != nil { + return nil, fmt.Errorf("bedrock converse: %w", err) + } + + // Parse the response + return parseResponse(output) +} + +// GetDefaultModel returns an empty string as Bedrock models are user-configured. +func (p *Provider) GetDefaultModel() string { + return "" +} + +// Region returns the AWS region configured for this Provider. +func (p *Provider) Region() string { + return p.region +} + +// convertMessages converts internal messages to Bedrock Converse format. +// Returns the conversation messages and any system prompts separately. +// Note: Bedrock requires all tool results for a given assistant turn to be in a single +// user message with multiple ToolResultBlock content blocks. This function merges +// consecutive tool result messages accordingly. +func convertMessages(messages []Message) ([]types.Message, []types.SystemContentBlock) { + var bedrockMessages []types.Message + var systemPrompts []types.SystemContentBlock + + // Helper to check if a message is a tool result + isToolResult := func(msg Message) bool { + return (msg.Role == "tool" || (msg.Role == "user" && msg.ToolCallID != "")) && msg.ToolCallID != "" + } + + // Helper to create a tool result content block + makeToolResultBlock := func(msg Message) types.ContentBlock { + return &types.ContentBlockMemberToolResult{ + Value: types.ToolResultBlock{ + ToolUseId: aws.String(msg.ToolCallID), + Content: []types.ToolResultContentBlock{ + &types.ToolResultContentBlockMemberText{ + Value: msg.Content, + }, + }, + }, + } + } + + i := 0 + for i < len(messages) { + msg := messages[i] + + switch { + case msg.Role == "system": + // System messages go to the System field + systemPrompts = append(systemPrompts, &types.SystemContentBlockMemberText{ + Value: msg.Content, + }) + i++ + + case isToolResult(msg): + // Collect all consecutive tool results into a single user message + // Bedrock requires all tool results for a turn in one message + var toolResultBlocks []types.ContentBlock + for i < len(messages) && isToolResult(messages[i]) { + toolResultBlocks = append(toolResultBlocks, makeToolResultBlock(messages[i])) + i++ + } + bedrockMessages = append(bedrockMessages, types.Message{ + Role: types.ConversationRoleUser, + Content: toolResultBlocks, + }) + + case msg.Role == "user": + // Regular user message (no ToolCallID) + content := buildUserContent(msg) + bedrockMessages = append(bedrockMessages, types.Message{ + Role: types.ConversationRoleUser, + Content: content, + }) + i++ + + case msg.Role == "assistant": + content := buildAssistantContent(msg) + bedrockMessages = append(bedrockMessages, types.Message{ + Role: types.ConversationRoleAssistant, + Content: content, + }) + i++ + + case msg.Role == "tool" && msg.ToolCallID == "": + // Tool message without ToolCallID - treat as regular user message + content := buildUserContent(msg) + bedrockMessages = append(bedrockMessages, types.Message{ + Role: types.ConversationRoleUser, + Content: content, + }) + i++ + + default: + // Unknown role - skip + i++ + } + } + + return bedrockMessages, systemPrompts +} + +// buildUserContent builds Bedrock content blocks for a user message. +func buildUserContent(msg Message) []types.ContentBlock { + var content []types.ContentBlock + + // Add text content + if msg.Content != "" { + content = append(content, &types.ContentBlockMemberText{ + Value: msg.Content, + }) + } + + // Add images from Media field + for _, mediaURL := range msg.Media { + if strings.HasPrefix(mediaURL, "data:image/") { + // Parse data URL: data:image/jpeg;base64, + parts := strings.SplitN(mediaURL, ",", 2) + if len(parts) != 2 { + continue + } + + // Extract media type from "data:image/jpeg;base64" + mediaType := "" + header := parts[0] + if idx := strings.Index(header, "/"); idx != -1 { + end := strings.Index(header[idx:], ";") + if end == -1 { + end = len(header) - idx + } + mediaType = header[idx+1 : idx+end] + } + + // Verify this is base64 encoded + if !strings.Contains(header, ";base64") { + continue // Skip non-base64 encoded data + } + + // Map media type to Bedrock format + var format types.ImageFormat + switch mediaType { + case "jpeg", "jpg": + format = types.ImageFormatJpeg + case "png": + format = types.ImageFormatPng + case "gif": + format = types.ImageFormatGif + case "webp": + format = types.ImageFormatWebp + default: + continue // Skip unsupported formats + } + + // Check size before decoding to prevent excessive memory allocation + // Bedrock has a ~20MB request limit; cap decoded images at 10MB + const maxImageSize = 10 * 1024 * 1024 + decodedLen := base64.StdEncoding.DecodedLen(len(parts[1])) + if decodedLen > maxImageSize { + log.Printf("bedrock: skipping image exceeding size limit (%d bytes > %d)", decodedLen, maxImageSize) + continue + } + + // Decode base64 data + imageData, err := base64.StdEncoding.DecodeString(parts[1]) + if err != nil { + log.Printf("bedrock: failed to decode base64 image data: %v", err) + continue + } + + content = append(content, &types.ContentBlockMemberImage{ + Value: types.ImageBlock{ + Format: format, + Source: &types.ImageSourceMemberBytes{ + Value: imageData, + }, + }, + }) + } + } + + // Bedrock requires at least one content block; add empty text if needed + if len(content) == 0 { + content = append(content, &types.ContentBlockMemberText{Value: ""}) + } + + return content +} + +// buildAssistantContent builds Bedrock content blocks for an assistant message. +func buildAssistantContent(msg Message) []types.ContentBlock { + var content []types.ContentBlock + + // Add text content if present + if msg.Content != "" { + content = append(content, &types.ContentBlockMemberText{ + Value: msg.Content, + }) + } + + // Add tool use blocks + for _, tc := range msg.ToolCalls { + // Validate tool call ID - Bedrock requires non-empty ToolUseId + if strings.TrimSpace(tc.ID) == "" { + log.Printf("bedrock: skipping tool call with empty ID (name: %q)", tc.Name) + continue + } + + // Resolve tool name: prefer tc.Name, fallback to tc.Function.Name + // (tc.Name/tc.Arguments are json:"-" and may be empty when from JSON) + toolName := tc.Name + if toolName == "" && tc.Function != nil { + toolName = tc.Function.Name + } + if strings.TrimSpace(toolName) == "" { + continue + } + + // Resolve arguments: prefer tc.Arguments, fallback to parsing tc.Function.Arguments + args := tc.Arguments + if args == nil && tc.Function != nil && tc.Function.Arguments != "" { + if err := json.Unmarshal([]byte(tc.Function.Arguments), &args); err != nil { + log.Printf("bedrock: failed to parse Function.Arguments for tool %q: %v", toolName, err) + args = map[string]any{} + } + } + if args == nil { + args = map[string]any{} + } + + // Convert arguments to a Bedrock document using NewLazyDocument + inputDoc := document.NewLazyDocument(args) + + content = append(content, &types.ContentBlockMemberToolUse{ + Value: types.ToolUseBlock{ + ToolUseId: aws.String(tc.ID), + Name: aws.String(toolName), + Input: inputDoc, + }, + }) + } + + // Bedrock requires at least one content block; add empty text if needed + if len(content) == 0 { + content = append(content, &types.ContentBlockMemberText{Value: ""}) + } + + return content +} + +// convertTools converts tool definitions to Bedrock format. +func convertTools(tools []ToolDefinition) *types.ToolConfiguration { + bedrockTools := make([]types.Tool, 0, len(tools)) + + for _, tool := range tools { + // Skip tools with empty names + if strings.TrimSpace(tool.Function.Name) == "" { + continue + } + + // Ensure parameters is not nil - default to minimal object schema + params := tool.Function.Parameters + if params == nil { + params = map[string]any{ + "type": "object", + "properties": map[string]any{}, + } + } + + // Convert parameters schema to a Bedrock document + inputSchema := document.NewLazyDocument(params) + + bedrockTools = append(bedrockTools, &types.ToolMemberToolSpec{ + Value: types.ToolSpecification{ + Name: aws.String(tool.Function.Name), + Description: aws.String(tool.Function.Description), + InputSchema: &types.ToolInputSchemaMemberJson{ + Value: inputSchema, + }, + }, + }) + } + + return &types.ToolConfiguration{ + Tools: bedrockTools, + } +} + +// parseResponse converts Bedrock Converse output to LLMResponse. +func parseResponse(output *bedrockruntime.ConverseOutput) (*LLMResponse, error) { + var content strings.Builder + toolCalls := make([]ToolCall, 0) + + // Process output content blocks + if output.Output != nil { + if msgOutput, ok := output.Output.(*types.ConverseOutputMemberMessage); ok { + for _, block := range msgOutput.Value.Content { + switch b := block.(type) { + case *types.ContentBlockMemberText: + content.WriteString(b.Value) + + case *types.ContentBlockMemberToolUse: + // Unmarshal the document interface to a map + args := make(map[string]any) + if b.Value.Input != nil { + if err := b.Value.Input.UnmarshalSmithyDocument(&args); err != nil { + log.Printf("bedrock: failed to unmarshal tool input for tool %q (id %q): %v", + aws.ToString(b.Value.Name), + aws.ToString(b.Value.ToolUseId), + err, + ) + args = make(map[string]any) + } + } + + // Serialize arguments to JSON string for FunctionCall + argsJSON, err := json.Marshal(args) + if err != nil { + log.Printf("bedrock: failed to marshal tool arguments for tool %q (id %q): %v", + aws.ToString(b.Value.Name), + aws.ToString(b.Value.ToolUseId), + err, + ) + argsJSON = []byte("{}") + } + + toolCalls = append(toolCalls, ToolCall{ + ID: aws.ToString(b.Value.ToolUseId), + Name: aws.ToString(b.Value.Name), + Arguments: args, + Function: &FunctionCall{ + Name: aws.ToString(b.Value.Name), + Arguments: string(argsJSON), + }, + }) + } + } + } + } + + // Map stop reason + finishReason := "stop" + switch output.StopReason { + case types.StopReasonToolUse: + finishReason = "tool_calls" + case types.StopReasonMaxTokens: + finishReason = "length" + case types.StopReasonEndTurn: + finishReason = "stop" + case types.StopReasonStopSequence: + finishReason = "stop" + case types.StopReasonContentFiltered: + finishReason = "content_filter" + } + + // Build usage info + var usage *UsageInfo + if output.Usage != nil { + usage = &UsageInfo{ + PromptTokens: int(aws.ToInt32(output.Usage.InputTokens)), + CompletionTokens: int(aws.ToInt32(output.Usage.OutputTokens)), + TotalTokens: int(aws.ToInt32(output.Usage.InputTokens)) + int(aws.ToInt32(output.Usage.OutputTokens)), + } + } + + return &LLMResponse{ + Content: content.String(), + ToolCalls: toolCalls, + FinishReason: finishReason, + Usage: usage, + }, nil +} diff --git a/pkg/providers/bedrock/provider_bedrock_test.go b/pkg/providers/bedrock/provider_bedrock_test.go new file mode 100644 index 000000000..754d112ee --- /dev/null +++ b/pkg/providers/bedrock/provider_bedrock_test.go @@ -0,0 +1,541 @@ +//go:build bedrock + +// PicoClaw - Ultra-lightweight personal AI agent +// License: MIT +// +// Copyright (c) 2026 PicoClaw contributors + +package bedrock + +import ( + "testing" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/bedrockruntime" + "github.com/aws/aws-sdk-go-v2/service/bedrockruntime/document" + "github.com/aws/aws-sdk-go-v2/service/bedrockruntime/types" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/sipeed/picoclaw/pkg/providers/protocoltypes" +) + +func TestConvertMessages_SystemPrompts(t *testing.T) { + messages := []Message{ + {Role: "system", Content: "You are a helpful assistant."}, + {Role: "user", Content: "Hello"}, + } + + bedrockMsgs, systemPrompts := convertMessages(messages) + + assert.Len(t, systemPrompts, 1) + assert.Len(t, bedrockMsgs, 1) + + // Check system prompt + textBlock, ok := systemPrompts[0].(*types.SystemContentBlockMemberText) + require.True(t, ok) + assert.Equal(t, "You are a helpful assistant.", textBlock.Value) + + // Check user message + assert.Equal(t, types.ConversationRoleUser, bedrockMsgs[0].Role) +} + +func TestConvertMessages_UserMessage(t *testing.T) { + messages := []Message{ + {Role: "user", Content: "What is 2+2?"}, + } + + bedrockMsgs, systemPrompts := convertMessages(messages) + + assert.Empty(t, systemPrompts) + assert.Len(t, bedrockMsgs, 1) + assert.Equal(t, types.ConversationRoleUser, bedrockMsgs[0].Role) + + textBlock, ok := bedrockMsgs[0].Content[0].(*types.ContentBlockMemberText) + require.True(t, ok) + assert.Equal(t, "What is 2+2?", textBlock.Value) +} + +func TestConvertMessages_AssistantMessage(t *testing.T) { + messages := []Message{ + {Role: "assistant", Content: "The answer is 4."}, + } + + bedrockMsgs, _ := convertMessages(messages) + + assert.Len(t, bedrockMsgs, 1) + assert.Equal(t, types.ConversationRoleAssistant, bedrockMsgs[0].Role) + + textBlock, ok := bedrockMsgs[0].Content[0].(*types.ContentBlockMemberText) + require.True(t, ok) + assert.Equal(t, "The answer is 4.", textBlock.Value) +} + +func TestConvertMessages_ToolResult(t *testing.T) { + messages := []Message{ + {Role: "tool", Content: "Result from tool", ToolCallID: "call_123"}, + } + + bedrockMsgs, _ := convertMessages(messages) + + assert.Len(t, bedrockMsgs, 1) + assert.Equal(t, types.ConversationRoleUser, bedrockMsgs[0].Role) + + toolResult, ok := bedrockMsgs[0].Content[0].(*types.ContentBlockMemberToolResult) + require.True(t, ok) + assert.Equal(t, "call_123", aws.ToString(toolResult.Value.ToolUseId)) +} + +func TestConvertMessages_MultipleToolResultsMerged(t *testing.T) { + // When an assistant makes multiple tool calls, all tool results must be + // merged into a single user message for Bedrock + messages := []Message{ + {Role: "user", Content: "What's the weather in NYC and LA?"}, + { + Role: "assistant", + Content: "Let me check both cities.", + ToolCalls: []protocoltypes.ToolCall{ + {ID: "call_nyc", Name: "get_weather", Arguments: map[string]any{"city": "NYC"}}, + {ID: "call_la", Name: "get_weather", Arguments: map[string]any{"city": "LA"}}, + }, + }, + {Role: "tool", Content: "NYC: 72°F, sunny", ToolCallID: "call_nyc"}, + {Role: "tool", Content: "LA: 85°F, clear", ToolCallID: "call_la"}, + } + + bedrockMsgs, _ := convertMessages(messages) + + // Should be: user message, assistant message, merged tool results (single user message) + assert.Len(t, bedrockMsgs, 3) + + // First message: user + assert.Equal(t, types.ConversationRoleUser, bedrockMsgs[0].Role) + + // Second message: assistant with tool calls + assert.Equal(t, types.ConversationRoleAssistant, bedrockMsgs[1].Role) + + // Third message: merged tool results in single user message + assert.Equal(t, types.ConversationRoleUser, bedrockMsgs[2].Role) + assert.Len(t, bedrockMsgs[2].Content, 2) // Both tool results in one message + + // Verify both tool results are present + result1, ok := bedrockMsgs[2].Content[0].(*types.ContentBlockMemberToolResult) + require.True(t, ok) + assert.Equal(t, "call_nyc", aws.ToString(result1.Value.ToolUseId)) + + result2, ok := bedrockMsgs[2].Content[1].(*types.ContentBlockMemberToolResult) + require.True(t, ok) + assert.Equal(t, "call_la", aws.ToString(result2.Value.ToolUseId)) +} + +func TestConvertMessages_AssistantWithToolCalls(t *testing.T) { + messages := []Message{ + { + Role: "assistant", + Content: "Let me calculate that.", + ToolCalls: []protocoltypes.ToolCall{ + { + ID: "call_456", + Name: "calculator", + Arguments: map[string]any{"expression": "2+2"}, + }, + }, + }, + } + + bedrockMsgs, _ := convertMessages(messages) + + assert.Len(t, bedrockMsgs, 1) + assert.Len(t, bedrockMsgs[0].Content, 2) // text + tool use + + // Check text content + textBlock, ok := bedrockMsgs[0].Content[0].(*types.ContentBlockMemberText) + require.True(t, ok) + assert.Equal(t, "Let me calculate that.", textBlock.Value) + + // Check tool use + toolUse, ok := bedrockMsgs[0].Content[1].(*types.ContentBlockMemberToolUse) + require.True(t, ok) + assert.Equal(t, "call_456", aws.ToString(toolUse.Value.ToolUseId)) + assert.Equal(t, "calculator", aws.ToString(toolUse.Value.Name)) +} + +func TestConvertTools_Basic(t *testing.T) { + tools := []ToolDefinition{ + { + Function: protocoltypes.ToolFunctionDefinition{ + Name: "get_weather", + Description: "Get the current weather", + Parameters: map[string]any{ + "type": "object", + "properties": map[string]any{ + "location": map[string]any{"type": "string"}, + }, + }, + }, + }, + } + + toolConfig := convertTools(tools) + + assert.NotNil(t, toolConfig) + assert.Len(t, toolConfig.Tools, 1) + + toolSpec, ok := toolConfig.Tools[0].(*types.ToolMemberToolSpec) + require.True(t, ok) + assert.Equal(t, "get_weather", aws.ToString(toolSpec.Value.Name)) + assert.Equal(t, "Get the current weather", aws.ToString(toolSpec.Value.Description)) +} + +func TestConvertTools_SkipsEmptyName(t *testing.T) { + tools := []ToolDefinition{ + { + Function: protocoltypes.ToolFunctionDefinition{ + Name: "", + Description: "Empty name tool", + }, + }, + { + Function: protocoltypes.ToolFunctionDefinition{ + Name: " ", + Description: "Whitespace name tool", + }, + }, + { + Function: protocoltypes.ToolFunctionDefinition{ + Name: "valid_tool", + Description: "Valid tool", + }, + }, + } + + toolConfig := convertTools(tools) + + assert.Len(t, toolConfig.Tools, 1) + toolSpec := toolConfig.Tools[0].(*types.ToolMemberToolSpec) + assert.Equal(t, "valid_tool", aws.ToString(toolSpec.Value.Name)) +} + +func TestConvertTools_NilParameters(t *testing.T) { + tools := []ToolDefinition{ + { + Function: protocoltypes.ToolFunctionDefinition{ + Name: "simple_tool", + Description: "A tool with no parameters", + Parameters: nil, + }, + }, + } + + toolConfig := convertTools(tools) + + assert.Len(t, toolConfig.Tools, 1) + // Should not panic and should create a valid tool +} + +func TestBuildUserContent_TextOnly(t *testing.T) { + msg := Message{Content: "Hello world"} + + content := buildUserContent(msg) + + assert.Len(t, content, 1) + textBlock, ok := content[0].(*types.ContentBlockMemberText) + require.True(t, ok) + assert.Equal(t, "Hello world", textBlock.Value) +} + +func TestBuildUserContent_WithImage(t *testing.T) { + // Base64-encoded 1x1 PNG (the provider doesn't validate image correctness, + // it just verifies the format and base64 decoding works) + b64Data := "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAIAAACQd1PeAAAADUlEQVR4nGNgYAAAAAMAASsJTYQAAAAASUVORK5CYII=" + + msg := Message{ + Content: "Look at this image", + Media: []string{"data:image/png;base64," + b64Data}, + } + + content := buildUserContent(msg) + + assert.Len(t, content, 2) + + // Check text + textBlock, ok := content[0].(*types.ContentBlockMemberText) + require.True(t, ok) + assert.Equal(t, "Look at this image", textBlock.Value) + + // Check image + imageBlock, ok := content[1].(*types.ContentBlockMemberImage) + require.True(t, ok) + assert.Equal(t, types.ImageFormatPng, imageBlock.Value.Format) +} + +func TestBuildUserContent_SkipsInvalidBase64(t *testing.T) { + msg := Message{ + Content: "Invalid image", + Media: []string{"data:image/png;base64,not-valid-base64!!!"}, + } + + content := buildUserContent(msg) + + // Should only have text, image should be skipped + assert.Len(t, content, 1) +} + +func TestBuildUserContent_SkipsNonBase64Data(t *testing.T) { + msg := Message{ + Content: "Non-base64 image", + Media: []string{"data:image/png,raw-data-here"}, + } + + content := buildUserContent(msg) + + // Should only have text, non-base64 image should be skipped + assert.Len(t, content, 1) +} + +func TestBuildAssistantContent_SkipsEmptyToolName(t *testing.T) { + msg := Message{ + Content: "Response", + ToolCalls: []protocoltypes.ToolCall{ + {ID: "1", Name: "", Arguments: map[string]any{}}, + {ID: "2", Name: " ", Arguments: map[string]any{}}, + {ID: "3", Name: "valid", Arguments: map[string]any{}}, + }, + } + + content := buildAssistantContent(msg) + + // Should have text + 1 valid tool + assert.Len(t, content, 2) +} + +func TestBuildAssistantContent_NilArguments(t *testing.T) { + msg := Message{ + ToolCalls: []protocoltypes.ToolCall{ + {ID: "1", Name: "tool", Arguments: nil}, + }, + } + + content := buildAssistantContent(msg) + + assert.Len(t, content, 1) + toolUse, ok := content[0].(*types.ContentBlockMemberToolUse) + require.True(t, ok) + assert.NotNil(t, toolUse.Value.Input) +} + +func TestBuildAssistantContent_FunctionFallback(t *testing.T) { + // When Name/Arguments are empty (json:"-"), should fallback to Function fields + msg := Message{ + ToolCalls: []protocoltypes.ToolCall{ + { + ID: "1", + Name: "", // empty, should fallback to Function.Name + Function: &protocoltypes.FunctionCall{ + Name: "fallback_tool", + Arguments: `{"key":"value"}`, + }, + }, + }, + } + + content := buildAssistantContent(msg) + + assert.Len(t, content, 1) + toolUse, ok := content[0].(*types.ContentBlockMemberToolUse) + require.True(t, ok) + assert.Equal(t, "fallback_tool", aws.ToString(toolUse.Value.Name)) +} + +func TestParseResponse_TextOnly(t *testing.T) { + output := &bedrockruntime.ConverseOutput{ + Output: &types.ConverseOutputMemberMessage{ + Value: types.Message{ + Role: types.ConversationRoleAssistant, + Content: []types.ContentBlock{ + &types.ContentBlockMemberText{Value: "Hello!"}, + }, + }, + }, + StopReason: types.StopReasonEndTurn, + Usage: &types.TokenUsage{ + InputTokens: aws.Int32(10), + OutputTokens: aws.Int32(5), + }, + } + + resp, err := parseResponse(output) + + require.NoError(t, err) + assert.Equal(t, "Hello!", resp.Content) + assert.Equal(t, "stop", resp.FinishReason) + assert.Empty(t, resp.ToolCalls) + assert.Equal(t, 10, resp.Usage.PromptTokens) + assert.Equal(t, 5, resp.Usage.CompletionTokens) +} + +func TestParseResponse_StopReasons(t *testing.T) { + tests := []struct { + stopReason types.StopReason + expectedFinish string + }{ + {types.StopReasonEndTurn, "stop"}, + {types.StopReasonToolUse, "tool_calls"}, + {types.StopReasonMaxTokens, "length"}, + {types.StopReasonStopSequence, "stop"}, + {types.StopReasonContentFiltered, "content_filter"}, + } + + for _, tt := range tests { + t.Run(string(tt.stopReason), func(t *testing.T) { + output := &bedrockruntime.ConverseOutput{ + Output: &types.ConverseOutputMemberMessage{ + Value: types.Message{ + Content: []types.ContentBlock{ + &types.ContentBlockMemberText{Value: "test"}, + }, + }, + }, + StopReason: tt.stopReason, + } + + resp, err := parseResponse(output) + + require.NoError(t, err) + assert.Equal(t, tt.expectedFinish, resp.FinishReason) + }) + } +} + +func TestParseResponse_WithToolCalls(t *testing.T) { + // Note: document.NewLazyDocument has limitations with UnmarshalSmithyDocument in tests, + // so we test the structure extraction and verify Arguments gets populated (even if empty + // due to SDK limitations). The actual unmarshal works correctly at runtime. + toolInput := document.NewLazyDocument(map[string]any{ + "location": "San Francisco", + "unit": "celsius", + }) + + output := &bedrockruntime.ConverseOutput{ + Output: &types.ConverseOutputMemberMessage{ + Value: types.Message{ + Role: types.ConversationRoleAssistant, + Content: []types.ContentBlock{ + &types.ContentBlockMemberText{Value: "Let me check the weather."}, + &types.ContentBlockMemberToolUse{ + Value: types.ToolUseBlock{ + ToolUseId: aws.String("call_weather_123"), + Name: aws.String("get_weather"), + Input: toolInput, + }, + }, + }, + }, + }, + StopReason: types.StopReasonToolUse, + Usage: &types.TokenUsage{ + InputTokens: aws.Int32(20), + OutputTokens: aws.Int32(15), + }, + } + + resp, err := parseResponse(output) + + require.NoError(t, err) + assert.Equal(t, "Let me check the weather.", resp.Content) + assert.Equal(t, "tool_calls", resp.FinishReason) + assert.Len(t, resp.ToolCalls, 1) + + // Verify tool call ID and Name are extracted correctly + tc := resp.ToolCalls[0] + assert.Equal(t, "call_weather_123", tc.ID) + assert.Equal(t, "get_weather", tc.Name) + + // Verify Function fields are also populated + require.NotNil(t, tc.Function) + assert.Equal(t, "get_weather", tc.Function.Name) + + // Verify Arguments is not nil (content may vary due to SDK limitations in tests) + assert.NotNil(t, tc.Arguments) + + // Verify usage + assert.Equal(t, 20, resp.Usage.PromptTokens) + assert.Equal(t, 15, resp.Usage.CompletionTokens) + assert.Equal(t, 35, resp.Usage.TotalTokens) +} + +func TestParseResponse_MultipleToolCalls(t *testing.T) { + output := &bedrockruntime.ConverseOutput{ + Output: &types.ConverseOutputMemberMessage{ + Value: types.Message{ + Role: types.ConversationRoleAssistant, + Content: []types.ContentBlock{ + &types.ContentBlockMemberToolUse{ + Value: types.ToolUseBlock{ + ToolUseId: aws.String("call_1"), + Name: aws.String("tool_a"), + Input: document.NewLazyDocument(map[string]any{"arg": "value1"}), + }, + }, + &types.ContentBlockMemberToolUse{ + Value: types.ToolUseBlock{ + ToolUseId: aws.String("call_2"), + Name: aws.String("tool_b"), + Input: document.NewLazyDocument(map[string]any{"arg": "value2"}), + }, + }, + }, + }, + }, + StopReason: types.StopReasonToolUse, + } + + resp, err := parseResponse(output) + + require.NoError(t, err) + assert.Equal(t, "tool_calls", resp.FinishReason) + assert.Len(t, resp.ToolCalls, 2) + + // Verify tool call structure + assert.Equal(t, "call_1", resp.ToolCalls[0].ID) + assert.Equal(t, "tool_a", resp.ToolCalls[0].Name) + assert.NotNil(t, resp.ToolCalls[0].Arguments) + assert.NotNil(t, resp.ToolCalls[0].Function) + assert.Equal(t, "tool_a", resp.ToolCalls[0].Function.Name) + + assert.Equal(t, "call_2", resp.ToolCalls[1].ID) + assert.Equal(t, "tool_b", resp.ToolCalls[1].Name) + assert.NotNil(t, resp.ToolCalls[1].Arguments) + assert.NotNil(t, resp.ToolCalls[1].Function) + assert.Equal(t, "tool_b", resp.ToolCalls[1].Function.Name) +} + +func TestParseResponse_ToolCallWithNilInput(t *testing.T) { + output := &bedrockruntime.ConverseOutput{ + Output: &types.ConverseOutputMemberMessage{ + Value: types.Message{ + Role: types.ConversationRoleAssistant, + Content: []types.ContentBlock{ + &types.ContentBlockMemberToolUse{ + Value: types.ToolUseBlock{ + ToolUseId: aws.String("call_nil"), + Name: aws.String("no_args_tool"), + Input: nil, + }, + }, + }, + }, + }, + StopReason: types.StopReasonToolUse, + } + + resp, err := parseResponse(output) + + require.NoError(t, err) + assert.Len(t, resp.ToolCalls, 1) + assert.Equal(t, "call_nil", resp.ToolCalls[0].ID) + assert.Equal(t, "no_args_tool", resp.ToolCalls[0].Name) + // Arguments should be empty map, not nil + assert.NotNil(t, resp.ToolCalls[0].Arguments) + assert.Empty(t, resp.ToolCalls[0].Arguments) +} diff --git a/pkg/providers/bedrock/provider_stub.go b/pkg/providers/bedrock/provider_stub.go new file mode 100644 index 000000000..894d9f2ca --- /dev/null +++ b/pkg/providers/bedrock/provider_stub.go @@ -0,0 +1,73 @@ +//go:build !bedrock + +// PicoClaw - Ultra-lightweight personal AI agent +// License: MIT +// +// Copyright (c) 2026 PicoClaw contributors + +// Package bedrock provides a stub implementation when built without the bedrock tag. +// To enable AWS Bedrock support, build with: go build -tags bedrock +package bedrock + +import ( + "context" + "fmt" + "time" + + "github.com/sipeed/picoclaw/pkg/providers/protocoltypes" +) + +type ( + LLMResponse = protocoltypes.LLMResponse + Message = protocoltypes.Message + ToolDefinition = protocoltypes.ToolDefinition +) + +// Provider is a stub that returns an error when Bedrock support is not compiled in. +type Provider struct{} + +// Option is a no-op when Bedrock is not enabled. +type Option func(*providerConfig) + +type providerConfig struct{} + +// WithRegion is a no-op when Bedrock is not enabled. +func WithRegion(region string) Option { + return func(c *providerConfig) {} +} + +// WithProfile is a no-op when Bedrock is not enabled. +func WithProfile(profile string) Option { + return func(c *providerConfig) {} +} + +// WithBaseEndpoint is a no-op when Bedrock is not enabled. +func WithBaseEndpoint(endpoint string) Option { + return func(c *providerConfig) {} +} + +// WithRequestTimeout is a no-op when Bedrock is not enabled. +func WithRequestTimeout(timeout time.Duration) Option { + return func(c *providerConfig) {} +} + +// NewProvider returns an error indicating Bedrock support is not compiled in. +func NewProvider(ctx context.Context, opts ...Option) (*Provider, error) { + return nil, fmt.Errorf("bedrock provider not available: build with -tags bedrock to enable AWS Bedrock support") +} + +// Chat returns an error - this should never be called since NewProvider fails. +func (p *Provider) Chat( + ctx context.Context, + messages []Message, + tools []ToolDefinition, + model string, + options map[string]any, +) (*LLMResponse, error) { + return nil, fmt.Errorf("bedrock provider not available: build with -tags bedrock to enable AWS Bedrock support") +} + +// GetDefaultModel returns an empty string. +func (p *Provider) GetDefaultModel() string { + return "" +} diff --git a/pkg/providers/bedrock/provider_stub_test.go b/pkg/providers/bedrock/provider_stub_test.go new file mode 100644 index 000000000..50ec8340f --- /dev/null +++ b/pkg/providers/bedrock/provider_stub_test.go @@ -0,0 +1,35 @@ +//go:build !bedrock + +// PicoClaw - Ultra-lightweight personal AI agent +// License: MIT +// +// Copyright (c) 2026 PicoClaw contributors + +package bedrock + +import ( + "context" + "strings" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestNewProvider_ReturnsStubError(t *testing.T) { + provider, err := NewProvider(context.Background()) + + assert.Nil(t, provider) + require.Error(t, err) + assert.True(t, strings.Contains(err.Error(), "build with -tags bedrock"), + "error should mention build tag requirement, got: %s", err.Error()) +} + +func TestNewProvider_WithOptions_ReturnsStubError(t *testing.T) { + provider, err := NewProvider(context.Background(), WithRegion("us-west-2"), WithProfile("test")) + + assert.Nil(t, provider) + require.Error(t, err) + assert.True(t, strings.Contains(err.Error(), "build with -tags bedrock"), + "error should mention build tag requirement, got: %s", err.Error()) +} diff --git a/pkg/providers/factory_provider.go b/pkg/providers/factory_provider.go index bc7c2ff70..1128fc042 100644 --- a/pkg/providers/factory_provider.go +++ b/pkg/providers/factory_provider.go @@ -6,12 +6,15 @@ package providers import ( + "context" "fmt" "strings" + "time" "github.com/sipeed/picoclaw/pkg/config" anthropicmessages "github.com/sipeed/picoclaw/pkg/providers/anthropic_messages" "github.com/sipeed/picoclaw/pkg/providers/azure" + "github.com/sipeed/picoclaw/pkg/providers/bedrock" ) // createClaudeAuthProvider creates a Claude provider using OAuth credentials from auth store. @@ -55,8 +58,9 @@ func ExtractProtocol(model string) (protocol, modelID string) { // CreateProviderFromConfig creates a provider based on the ModelConfig. // It uses the protocol prefix in the Model field to determine which provider to create. -// Supported protocols: openai, litellm, novita, anthropic, anthropic-messages, -// antigravity, claude-cli, codex-cli, github-copilot +// Supported protocol families include OpenAI-compatible prefixes (e.g., openai, openrouter, groq, gemini), +// Azure OpenAI, Amazon Bedrock, Anthropic (including messages), and various CLI/compatibility shims. +// See the switch on protocol in this function for the authoritative list. // Returns the provider, the model ID (without protocol prefix), and any error. func CreateProviderFromConfig(cfg *config.ModelConfig) (LLMProvider, string, error) { if cfg == nil { @@ -114,6 +118,42 @@ func CreateProviderFromConfig(cfg *config.ModelConfig) (LLMProvider, string, err cfg.RequestTimeout, ), modelID, nil + case "bedrock": + // AWS Bedrock uses AWS SDK credentials (env vars, profiles, IAM roles, etc.) + // api_base can be: + // - A full endpoint URL: https://bedrock-runtime.us-east-1.amazonaws.com + // - A region name: us-east-1 (AWS SDK resolves endpoint automatically) + var opts []bedrock.Option + if cfg.APIBase != "" { + if !strings.Contains(cfg.APIBase, "://") { + // Treat as region: let AWS SDK resolve the correct endpoint + // (supports all AWS partitions: aws, aws-cn, aws-us-gov, etc.) + opts = append(opts, bedrock.WithRegion(cfg.APIBase)) + } else { + // Full endpoint URL provided (for custom endpoints or testing) + opts = append(opts, bedrock.WithBaseEndpoint(cfg.APIBase)) + } + } + // Use a separate timeout for AWS config loading (credential resolution can block) + initTimeout := 30 * time.Second + if cfg.RequestTimeout > 0 { + reqTimeout := time.Duration(cfg.RequestTimeout) * time.Second + // Set request timeout for API calls + opts = append(opts, bedrock.WithRequestTimeout(reqTimeout)) + // Ensure init timeout is at least as large as request timeout + if reqTimeout > initTimeout { + initTimeout = reqTimeout + } + } + ctx, cancel := context.WithTimeout(context.Background(), initTimeout) + defer cancel() + // Note: AWS_PROFILE env var is automatically used by AWS SDK + provider, err := bedrock.NewProvider(ctx, opts...) + if err != nil { + return nil, "", fmt.Errorf("creating bedrock provider: %w", err) + } + return provider, modelID, nil + case "litellm", "openrouter", "groq", "zhipu", "gemini", "nvidia", "ollama", "moonshot", "shengsuanyun", "deepseek", "cerebras", "vivgrid", "volcengine", "vllm", "qwen", "qwen-intl", "qwen-international", "dashscope-intl", diff --git a/pkg/providers/factory_provider_test.go b/pkg/providers/factory_provider_test.go index 1bff0419d..2fed18c35 100644 --- a/pkg/providers/factory_provider_test.go +++ b/pkg/providers/factory_provider_test.go @@ -700,3 +700,78 @@ func TestCreateProviderFromConfig_MinimaxPreservesUserExtraBody(t *testing.T) { t.Fatalf("custom_field = %v, want test", got) } } + +func TestCreateProviderFromConfig_Bedrock(t *testing.T) { + // Set dummy AWS env vars to make test deterministic + t.Setenv("AWS_ACCESS_KEY_ID", "test-key") + t.Setenv("AWS_SECRET_ACCESS_KEY", "test-secret") + t.Setenv("AWS_EC2_METADATA_DISABLED", "true") + // Clear profile-related env vars to avoid loading shared config + t.Setenv("AWS_PROFILE", "") + t.Setenv("AWS_DEFAULT_PROFILE", "") + t.Setenv("AWS_SDK_LOAD_CONFIG", "") + t.Setenv("AWS_SHARED_CREDENTIALS_FILE", "") + + cfg := &config.ModelConfig{ + ModelName: "bedrock-claude", + Model: "bedrock/us.anthropic.claude-sonnet-4-20250514-v1:0", + APIBase: "us-west-2", // Region (also sets AWS region) + } + + provider, modelID, err := CreateProviderFromConfig(cfg) + if err == nil { + // Provider created successfully (built with -tags bedrock) + if provider == nil { + t.Error("provider is nil on success") + } + if modelID != "us.anthropic.claude-sonnet-4-20250514-v1:0" { + t.Errorf("modelID = %q, want %q", modelID, "us.anthropic.claude-sonnet-4-20250514-v1:0") + } + return + } + errMsg := err.Error() + // When built without -tags bedrock, expect stub error + if strings.Contains(errMsg, "build with -tags bedrock") { + return // Expected stub error + } + // Unexpected error - fail the test + t.Errorf("unexpected error from bedrock provider: %v", err) +} + +func TestCreateProviderFromConfig_BedrockWithEndpointURL(t *testing.T) { + // Set dummy AWS env vars to make test deterministic + t.Setenv("AWS_ACCESS_KEY_ID", "test-key") + t.Setenv("AWS_SECRET_ACCESS_KEY", "test-secret") + t.Setenv("AWS_REGION", "us-east-1") // Required when using endpoint URL + t.Setenv("AWS_EC2_METADATA_DISABLED", "true") + // Clear profile-related env vars to avoid loading shared config + t.Setenv("AWS_PROFILE", "") + t.Setenv("AWS_DEFAULT_PROFILE", "") + t.Setenv("AWS_SDK_LOAD_CONFIG", "") + t.Setenv("AWS_SHARED_CREDENTIALS_FILE", "") + + cfg := &config.ModelConfig{ + ModelName: "bedrock-claude", + Model: "bedrock/us.anthropic.claude-sonnet-4-20250514-v1:0", + APIBase: "https://bedrock-runtime.us-east-1.amazonaws.com", // Full endpoint URL + } + + provider, modelID, err := CreateProviderFromConfig(cfg) + if err == nil { + // Provider created successfully (built with -tags bedrock) + if provider == nil { + t.Error("provider is nil on success") + } + if modelID != "us.anthropic.claude-sonnet-4-20250514-v1:0" { + t.Errorf("modelID = %q, want %q", modelID, "us.anthropic.claude-sonnet-4-20250514-v1:0") + } + return + } + errMsg := err.Error() + // When built without -tags bedrock, expect stub error + if strings.Contains(errMsg, "build with -tags bedrock") { + return // Expected stub error + } + // Unexpected error - fail the test + t.Errorf("unexpected error from bedrock provider: %v", err) +}