From ad5232ade8035012facaab0f3fae5bb2406d8599 Mon Sep 17 00:00:00 2001 From: Andy Lo-A-Foe Date: Thu, 23 Apr 2026 23:27:55 +0200 Subject: [PATCH] feat(bedrock): implement StreamingProvider for real-time token streaming Adds ConverseStream API support to the Bedrock provider, implementing the StreamingProvider interface. Tokens flow via onChunk callback for real-time delivery to streaming-capable channels. - Extract buildConverseParams to share request logic between Chat and ChatStream - Add converseStreamReader interface for testability - Preserve raw payload in Arguments on JSON parse failure - Ensure Function.Arguments is always valid JSON - Streaming timeout only applied when explicitly configured - Capture stream Close() errors for diagnostics - Consistent "bedrock conversestream" / "bedrock:" log prefixes Co-Authored-By: Claude Opus 4.6 --- pkg/providers/bedrock/provider_bedrock.go | 284 ++++++++++++++++++---- 1 file changed, 239 insertions(+), 45 deletions(-) diff --git a/pkg/providers/bedrock/provider_bedrock.go b/pkg/providers/bedrock/provider_bedrock.go index 3798c5fd8..ee0ac75a0 100644 --- a/pkg/providers/bedrock/provider_bedrock.go +++ b/pkg/providers/bedrock/provider_bedrock.go @@ -135,48 +135,23 @@ func NewProvider(ctx context.Context, opts ...Option) (*Provider, error) { }, nil } -// Chat sends messages to AWS Bedrock using the Converse API. -func (p *Provider) Chat( - ctx context.Context, - messages []Message, - tools []ToolDefinition, - model string, - options map[string]any, -) (*LLMResponse, error) { - // Apply request timeout if context doesn't already have a deadline. - // Use explicit timeout if set, otherwise fall back to common default. - effectiveTimeout := p.requestTimeout - if effectiveTimeout <= 0 { - effectiveTimeout = common.DefaultRequestTimeout - } - if _, hasDeadline := ctx.Deadline(); !hasDeadline { - var cancel context.CancelFunc - ctx, cancel = context.WithTimeout(ctx, effectiveTimeout) - defer cancel() - } +// converseParams holds the shared request parameters for Converse and ConverseStream. +type converseParams struct { + messages []types.Message + system []types.SystemContentBlock + inferenceConfig *types.InferenceConfiguration + toolConfig *types.ToolConfiguration +} - // Build the Converse API input - input := &bedrockruntime.ConverseInput{ - ModelId: aws.String(model), - } - - // Convert messages to Bedrock format +func buildConverseParams(messages []Message, tools []ToolDefinition, options map[string]any) converseParams { bedrockMessages, systemPrompts := convertMessages(messages) - input.Messages = bedrockMessages - // Set system prompts if any - if len(systemPrompts) > 0 { - input.System = systemPrompts - } - - // Set inference configuration only when options are provided var inferenceConfig *types.InferenceConfiguration if maxTokens, ok := common.AsInt(options["max_tokens"]); ok && maxTokens > 0 { if inferenceConfig == nil { inferenceConfig = &types.InferenceConfiguration{} } - // Clamp to int32 range to avoid overflow if maxTokens > math.MaxInt32 { maxTokens = math.MaxInt32 } @@ -190,23 +165,53 @@ func (p *Provider) Chat( inferenceConfig.Temperature = aws.Float32(float32(temp)) } - if inferenceConfig != nil { - input.InferenceConfig = inferenceConfig - } - - // Convert tools to Bedrock format - // Only set ToolConfig if at least one valid tool was produced + var toolConfig *types.ToolConfiguration if len(tools) > 0 { - toolConfig := convertTools(tools) - if len(toolConfig.Tools) > 0 { - input.ToolConfig = toolConfig + tc := convertTools(tools) + if len(tc.Tools) > 0 { + toolConfig = tc } } - // Call Bedrock Converse API + return converseParams{ + messages: bedrockMessages, + system: systemPrompts, + inferenceConfig: inferenceConfig, + toolConfig: toolConfig, + } +} + +// Chat sends messages to AWS Bedrock using the Converse API. +func (p *Provider) Chat( + ctx context.Context, + messages []Message, + tools []ToolDefinition, + model string, + options map[string]any, +) (*LLMResponse, error) { + effectiveTimeout := p.requestTimeout + if effectiveTimeout <= 0 { + effectiveTimeout = common.DefaultRequestTimeout + } + if _, hasDeadline := ctx.Deadline(); !hasDeadline { + var cancel context.CancelFunc + ctx, cancel = context.WithTimeout(ctx, effectiveTimeout) + defer cancel() + } + + params := buildConverseParams(messages, tools, options) + input := &bedrockruntime.ConverseInput{ + ModelId: aws.String(model), + Messages: params.messages, + InferenceConfig: params.inferenceConfig, + ToolConfig: params.toolConfig, + } + if len(params.system) > 0 { + input.System = params.system + } + output, err := p.client.Converse(ctx, input) if err != nil { - // Check for SSO token expiration errors and provide actionable guidance if isSSOTokenError(err) { return nil, fmt.Errorf( "bedrock converse: AWS credentials may have expired. If using AWS SSO, run 'aws sso login' to refresh: %w", @@ -216,10 +221,199 @@ func (p *Provider) Chat( return nil, fmt.Errorf("bedrock converse: %w", err) } - // Parse the response return parseResponse(output) } +// ChatStream sends messages to AWS Bedrock using the ConverseStream API. +// It streams the accumulated text so far via the onChunk callback and returns the complete response. +func (p *Provider) ChatStream( + ctx context.Context, + messages []Message, + tools []ToolDefinition, + model string, + options map[string]any, + onChunk func(accumulated string), +) (*LLMResponse, error) { + if p.requestTimeout > 0 { + if _, hasDeadline := ctx.Deadline(); !hasDeadline { + var cancel context.CancelFunc + ctx, cancel = context.WithTimeout(ctx, p.requestTimeout) + defer cancel() + } + } + + params := buildConverseParams(messages, tools, options) + input := &bedrockruntime.ConverseStreamInput{ + ModelId: aws.String(model), + Messages: params.messages, + InferenceConfig: params.inferenceConfig, + ToolConfig: params.toolConfig, + } + if len(params.system) > 0 { + input.System = params.system + } + + output, err := p.client.ConverseStream(ctx, input) + if err != nil { + if isSSOTokenError(err) { + return nil, fmt.Errorf( + "bedrock conversestream: AWS credentials may have expired. If using AWS SSO, run 'aws sso login' to refresh: %w", + err, + ) + } + return nil, fmt.Errorf("bedrock conversestream: %w", err) + } + + return parseStreamResponse(ctx, output.GetStream(), onChunk) +} + +// converseStreamReader abstracts the Bedrock event stream so parseStreamResponse +// can be unit-tested with a mock event source. +type converseStreamReader interface { + Events() <-chan types.ConverseStreamOutput + Err() error + Close() error +} + +// parseStreamResponse processes the ConverseStream event stream and accumulates the response. +func parseStreamResponse( + ctx context.Context, + stream converseStreamReader, + onChunk func(accumulated string), +) (resp *LLMResponse, err error) { + if stream == nil { + return nil, fmt.Errorf("bedrock conversestream: nil event stream") + } + defer func() { + if closeErr := stream.Close(); closeErr != nil { + if err == nil { + err = fmt.Errorf("bedrock conversestream: close event stream: %w", closeErr) + } else { + log.Printf("bedrock conversestream: close event stream: %v", closeErr) + } + } + }() + + var textContent strings.Builder + finishReason := "stop" + var usage *UsageInfo + toolCalls := make([]ToolCall, 0) + + // Track active tool use blocks by index + type toolAccum struct { + id string + name string + argsJSON strings.Builder + } + activeTools := map[int]*toolAccum{} + + events := stream.Events() + for { + select { + case <-ctx.Done(): + return nil, ctx.Err() + case event, ok := <-events: + if !ok { + // Stream closed + goto done + } + + switch e := event.(type) { + case *types.ConverseStreamOutputMemberContentBlockStart: + // New content block starting + if toolUse, ok := e.Value.Start.(*types.ContentBlockStartMemberToolUse); ok { + activeTools[int(aws.ToInt32(e.Value.ContentBlockIndex))] = &toolAccum{ + id: aws.ToString(toolUse.Value.ToolUseId), + name: aws.ToString(toolUse.Value.Name), + } + } + + case *types.ConverseStreamOutputMemberContentBlockDelta: + // Content delta + switch delta := e.Value.Delta.(type) { + case *types.ContentBlockDeltaMemberText: + textContent.WriteString(delta.Value) + if onChunk != nil { + onChunk(textContent.String()) + } + case *types.ContentBlockDeltaMemberToolUse: + idx := int(aws.ToInt32(e.Value.ContentBlockIndex)) + if tool, exists := activeTools[idx]; exists { + tool.argsJSON.WriteString(aws.ToString(delta.Value.Input)) + } + } + + case *types.ConverseStreamOutputMemberContentBlockStop: + // Content block finished - finalize tool if it was a tool use + idx := int(aws.ToInt32(e.Value.ContentBlockIndex)) + if tool, exists := activeTools[idx]; exists { + args := make(map[string]any) + argsStr := tool.argsJSON.String() + if argsStr != "" { + if err := json.Unmarshal([]byte(argsStr), &args); err != nil { + log.Printf("bedrock: stream: failed to parse tool arguments for %q: %v", tool.name, err) + args = map[string]any{"raw": argsStr} + } + } + funcArgs := argsStr + if argsJSON, marshalErr := json.Marshal(args); marshalErr == nil { + funcArgs = string(argsJSON) + } + toolCalls = append(toolCalls, ToolCall{ + ID: tool.id, + Name: tool.name, + Arguments: args, + Function: &FunctionCall{ + Name: tool.name, + Arguments: funcArgs, + }, + }) + delete(activeTools, idx) + } + + case *types.ConverseStreamOutputMemberMessageStop: + // Message complete + switch e.Value.StopReason { + case types.StopReasonToolUse: + finishReason = "tool_calls" + case types.StopReasonMaxTokens: + finishReason = "length" + case types.StopReasonEndTurn: + finishReason = "stop" + case types.StopReasonStopSequence: + finishReason = "stop" + case types.StopReasonContentFiltered: + finishReason = "content_filter" + default: + finishReason = "stop" + } + + case *types.ConverseStreamOutputMemberMetadata: + // Usage metadata + if e.Value.Usage != nil { + usage = &UsageInfo{ + PromptTokens: int(aws.ToInt32(e.Value.Usage.InputTokens)), + CompletionTokens: int(aws.ToInt32(e.Value.Usage.OutputTokens)), + TotalTokens: int(aws.ToInt32(e.Value.Usage.InputTokens)) + int(aws.ToInt32(e.Value.Usage.OutputTokens)), + } + } + } + } + } + +done: + if err := stream.Err(); err != nil { + return nil, fmt.Errorf("bedrock conversestream: %w", err) + } + + return &LLMResponse{ + Content: textContent.String(), + ToolCalls: toolCalls, + FinishReason: finishReason, + Usage: usage, + }, nil +} + // GetDefaultModel returns an empty string as Bedrock models are user-configured. func (p *Provider) GetDefaultModel() string { return ""