feat(bedrock): implement StreamingProvider for real-time token streaming

Adds ConverseStream API support to the Bedrock provider, implementing
the StreamingProvider interface. After each text delta, the text
accumulated so far is passed to the onChunk callback for real-time
delivery to streaming-capable channels.

- Extract buildConverseParams to share request logic between Chat and ChatStream
- Add converseStreamReader interface for testability
- Preserve raw payload in Arguments on JSON parse failure
- Ensure Function.Arguments is always valid JSON
- Streaming timeout only applied when explicitly configured
- Capture stream Close() errors for diagnostics
- Consistent "bedrock conversestream" / "bedrock:" log prefixes

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
Andy Lo-A-Foe
2026-04-23 23:27:55 +02:00
parent a36472b55f
commit ad5232ade8
+239 -45
View File
@@ -135,48 +135,23 @@ func NewProvider(ctx context.Context, opts ...Option) (*Provider, error) {
}, nil
}
// Chat sends messages to AWS Bedrock using the Converse API.
func (p *Provider) Chat(
ctx context.Context,
messages []Message,
tools []ToolDefinition,
model string,
options map[string]any,
) (*LLMResponse, error) {
// Apply request timeout if context doesn't already have a deadline.
// Use explicit timeout if set, otherwise fall back to common default.
effectiveTimeout := p.requestTimeout
if effectiveTimeout <= 0 {
effectiveTimeout = common.DefaultRequestTimeout
}
if _, hasDeadline := ctx.Deadline(); !hasDeadline {
var cancel context.CancelFunc
ctx, cancel = context.WithTimeout(ctx, effectiveTimeout)
defer cancel()
}
// converseParams holds the shared request parameters for Converse and ConverseStream.
// It is produced by buildConverseParams and copied field by field into the
// ConverseInput / ConverseStreamInput structs by Chat and ChatStream.
type converseParams struct {
// messages is the conversation in Bedrock format, as returned by convertMessages.
messages []types.Message
// system holds the system prompt blocks returned by convertMessages; callers
// only set it on the request when it is non-empty.
system []types.SystemContentBlock
// inferenceConfig stays nil unless an inference option (such as max_tokens or
// temperature) was supplied.
inferenceConfig *types.InferenceConfiguration
// toolConfig stays nil unless convertTools produced at least one tool.
toolConfig *types.ToolConfiguration
}
// Build the Converse API input
input := &bedrockruntime.ConverseInput{
ModelId: aws.String(model),
}
// Convert messages to Bedrock format
func buildConverseParams(messages []Message, tools []ToolDefinition, options map[string]any) converseParams {
bedrockMessages, systemPrompts := convertMessages(messages)
input.Messages = bedrockMessages
// Set system prompts if any
if len(systemPrompts) > 0 {
input.System = systemPrompts
}
// Set inference configuration only when options are provided
var inferenceConfig *types.InferenceConfiguration
if maxTokens, ok := common.AsInt(options["max_tokens"]); ok && maxTokens > 0 {
if inferenceConfig == nil {
inferenceConfig = &types.InferenceConfiguration{}
}
// Clamp to int32 range to avoid overflow
if maxTokens > math.MaxInt32 {
maxTokens = math.MaxInt32
}
@@ -190,23 +165,53 @@ func (p *Provider) Chat(
inferenceConfig.Temperature = aws.Float32(float32(temp))
}
if inferenceConfig != nil {
input.InferenceConfig = inferenceConfig
}
// Convert tools to Bedrock format
// Only set ToolConfig if at least one valid tool was produced
var toolConfig *types.ToolConfiguration
if len(tools) > 0 {
toolConfig := convertTools(tools)
if len(toolConfig.Tools) > 0 {
input.ToolConfig = toolConfig
tc := convertTools(tools)
if len(tc.Tools) > 0 {
toolConfig = tc
}
}
// Call Bedrock Converse API
return converseParams{
messages: bedrockMessages,
system: systemPrompts,
inferenceConfig: inferenceConfig,
toolConfig: toolConfig,
}
}
// Chat sends messages to AWS Bedrock using the Converse API.
func (p *Provider) Chat(
ctx context.Context,
messages []Message,
tools []ToolDefinition,
model string,
options map[string]any,
) (*LLMResponse, error) {
effectiveTimeout := p.requestTimeout
if effectiveTimeout <= 0 {
effectiveTimeout = common.DefaultRequestTimeout
}
if _, hasDeadline := ctx.Deadline(); !hasDeadline {
var cancel context.CancelFunc
ctx, cancel = context.WithTimeout(ctx, effectiveTimeout)
defer cancel()
}
params := buildConverseParams(messages, tools, options)
input := &bedrockruntime.ConverseInput{
ModelId: aws.String(model),
Messages: params.messages,
InferenceConfig: params.inferenceConfig,
ToolConfig: params.toolConfig,
}
if len(params.system) > 0 {
input.System = params.system
}
output, err := p.client.Converse(ctx, input)
if err != nil {
// Check for SSO token expiration errors and provide actionable guidance
if isSSOTokenError(err) {
return nil, fmt.Errorf(
"bedrock converse: AWS credentials may have expired. If using AWS SSO, run 'aws sso login' to refresh: %w",
@@ -216,10 +221,199 @@ func (p *Provider) Chat(
return nil, fmt.Errorf("bedrock converse: %w", err)
}
// Parse the response
return parseResponse(output)
}
// ChatStream sends messages to AWS Bedrock using the ConverseStream API.
//
// After every text delta the text accumulated so far (not just the new
// fragment) is passed to onChunk, which may be nil. The complete response is
// returned once the stream ends.
//
// Unlike Chat, no default timeout is applied: a deadline is only added when
// the provider has an explicit request timeout configured and ctx does not
// already carry one, so long-running streams are not cut off by a default.
//
// Errors from the ConverseStream call are wrapped with a
// "bedrock conversestream:" prefix; SSO token errors additionally carry a hint
// to run 'aws sso login'.
func (p *Provider) ChatStream(
	ctx context.Context,
	messages []Message,
	tools []ToolDefinition,
	model string,
	options map[string]any,
	onChunk func(accumulated string),
) (*LLMResponse, error) {
	if p.requestTimeout > 0 {
		if _, hasDeadline := ctx.Deadline(); !hasDeadline {
			var cancel context.CancelFunc
			ctx, cancel = context.WithTimeout(ctx, p.requestTimeout)
			defer cancel()
		}
	}

	params := buildConverseParams(messages, tools, options)

	input := &bedrockruntime.ConverseStreamInput{
		ModelId:         aws.String(model),
		Messages:        params.messages,
		InferenceConfig: params.inferenceConfig,
		ToolConfig:      params.toolConfig,
	}
	if len(params.system) > 0 {
		input.System = params.system
	}

	output, err := p.client.ConverseStream(ctx, input)
	if err != nil {
		if isSSOTokenError(err) {
			return nil, fmt.Errorf(
				"bedrock conversestream: AWS credentials may have expired. If using AWS SSO, run 'aws sso login' to refresh: %w",
				err,
			)
		}
		return nil, fmt.Errorf("bedrock conversestream: %w", err)
	}
	if output == nil {
		return nil, fmt.Errorf("bedrock conversestream: nil response")
	}

	// Check the concrete stream value here, before it is converted to the
	// converseStreamReader interface: a nil pointer wrapped in an interface
	// compares non-nil, so the nil guard inside parseStreamResponse would not
	// catch it and the first method call would panic.
	stream := output.GetStream()
	if stream == nil {
		return nil, fmt.Errorf("bedrock conversestream: nil event stream")
	}

	return parseStreamResponse(ctx, stream, onChunk)
}
// converseStreamReader abstracts the Bedrock event stream so parseStreamResponse
// can be unit-tested with a mock event source.
type converseStreamReader interface {
// Events returns the channel of stream events; parseStreamResponse treats the
// channel being closed as the end of the stream.
Events() <-chan types.ConverseStreamOutput
// Err is checked by parseStreamResponse after the events channel closes; a
// non-nil value makes the whole response fail.
Err() error
// Close is always called by parseStreamResponse before it returns.
Close() error
}
// parseStreamResponse drains a ConverseStream event stream and assembles the
// final LLMResponse.
//
// Text deltas are appended to a running buffer and, when onChunk is non-nil,
// the full text accumulated so far is passed to it after every delta.
// Tool-use blocks are tracked by content block index: the start event records
// the tool ID and name, delta events append JSON argument fragments, and the
// stop event converts the accumulated fragments into a ToolCall. Deltas for an
// index with no recorded start are ignored, and a tool block that never
// receives its stop event is not emitted.
//
// FinishReason defaults to "stop" when no message-stop event is seen. Usage is
// nil unless a metadata event with usage information arrives.
//
// Errors:
//   - a nil stream yields an error immediately;
//   - cancellation of ctx returns ctx.Err() unwrapped;
//   - a non-nil stream.Err() after the events channel closes is wrapped with
//     a "bedrock conversestream:" prefix;
//   - the stream is always closed before returning; a Close failure becomes
//     the returned error when nothing else failed, otherwise it is only logged.
func parseStreamResponse(
	ctx context.Context,
	stream converseStreamReader,
	onChunk func(accumulated string),
) (resp *LLMResponse, err error) {
	if stream == nil {
		return nil, fmt.Errorf("bedrock conversestream: nil event stream")
	}

	// Named results let the deferred Close report its error without masking an
	// earlier, more relevant failure.
	defer func() {
		closeErr := stream.Close()
		if closeErr == nil {
			return
		}
		if err != nil {
			log.Printf("bedrock conversestream: close event stream: %v", closeErr)
			return
		}
		err = fmt.Errorf("bedrock conversestream: close event stream: %w", closeErr)
	}()

	// toolAccum collects the pieces of one in-flight tool-use content block.
	type toolAccum struct {
		id       string
		name     string
		argsJSON strings.Builder
	}

	var (
		textContent strings.Builder
		usage       *UsageInfo
	)
	finishReason := "stop"
	// Kept non-nil so an empty result is an empty slice rather than nil.
	toolCalls := make([]ToolCall, 0)
	// In-flight tool-use blocks keyed by content block index.
	activeTools := map[int]*toolAccum{}

	events := stream.Events()

drain:
	for {
		select {
		case <-ctx.Done():
			return nil, ctx.Err()

		case event, ok := <-events:
			if !ok {
				// Channel closed: the stream has ended (cleanly or not; Err
				// below tells which).
				break drain
			}

			switch e := event.(type) {
			case *types.ConverseStreamOutputMemberContentBlockStart:
				// Only tool-use blocks need state; text blocks are handled
				// purely through their deltas.
				if toolUse, isTool := e.Value.Start.(*types.ContentBlockStartMemberToolUse); isTool {
					idx := int(aws.ToInt32(e.Value.ContentBlockIndex))
					activeTools[idx] = &toolAccum{
						id:   aws.ToString(toolUse.Value.ToolUseId),
						name: aws.ToString(toolUse.Value.Name),
					}
				}

			case *types.ConverseStreamOutputMemberContentBlockDelta:
				switch delta := e.Value.Delta.(type) {
				case *types.ContentBlockDeltaMemberText:
					textContent.WriteString(delta.Value)
					if onChunk != nil {
						onChunk(textContent.String())
					}
				case *types.ContentBlockDeltaMemberToolUse:
					idx := int(aws.ToInt32(e.Value.ContentBlockIndex))
					if tool, exists := activeTools[idx]; exists {
						tool.argsJSON.WriteString(aws.ToString(delta.Value.Input))
					}
				}

			case *types.ConverseStreamOutputMemberContentBlockStop:
				// Finalize the block only if it was a tracked tool-use block.
				idx := int(aws.ToInt32(e.Value.ContentBlockIndex))
				if tool, exists := activeTools[idx]; exists {
					toolCalls = append(toolCalls, newStreamToolCall(tool.id, tool.name, tool.argsJSON.String()))
					delete(activeTools, idx)
				}

			case *types.ConverseStreamOutputMemberMessageStop:
				finishReason = streamFinishReason(e.Value.StopReason)

			case *types.ConverseStreamOutputMemberMetadata:
				if u := e.Value.Usage; u != nil {
					promptTokens := int(aws.ToInt32(u.InputTokens))
					completionTokens := int(aws.ToInt32(u.OutputTokens))
					usage = &UsageInfo{
						PromptTokens:     promptTokens,
						CompletionTokens: completionTokens,
						TotalTokens:      promptTokens + completionTokens,
					}
				}
			}
		}
	}

	if streamErr := stream.Err(); streamErr != nil {
		return nil, fmt.Errorf("bedrock conversestream: %w", streamErr)
	}

	return &LLMResponse{
		Content:      textContent.String(),
		ToolCalls:    toolCalls,
		FinishReason: finishReason,
		Usage:        usage,
	}, nil
}

// newStreamToolCall builds a ToolCall from the raw JSON argument text that was
// accumulated for one streamed tool-use block.
//
// Arguments is always a non-nil map and Function.Arguments is its JSON
// encoding: empty input and a literal JSON null both become an empty object,
// and text that does not parse as a JSON object is preserved under the "raw"
// key so the payload is not lost.
func newStreamToolCall(id, name, rawArgs string) ToolCall {
	args := make(map[string]any)
	if rawArgs != "" {
		if parseErr := json.Unmarshal([]byte(rawArgs), &args); parseErr != nil {
			log.Printf("bedrock: stream: failed to parse tool arguments for %q: %v", name, parseErr)
			args = map[string]any{"raw": rawArgs}
		} else if args == nil {
			// json.Unmarshal sets a map to nil for JSON null; restore an empty
			// object so callers never see a nil map or a "null" argument string.
			args = make(map[string]any)
		}
	}

	// Re-encode so Function.Arguments is valid JSON even when the raw payload
	// was not; fall back to the raw text only if encoding fails.
	funcArgs := rawArgs
	if encoded, marshalErr := json.Marshal(args); marshalErr == nil {
		funcArgs = string(encoded)
	}

	return ToolCall{
		ID:        id,
		Name:      name,
		Arguments: args,
		Function: &FunctionCall{
			Name:      name,
			Arguments: funcArgs,
		},
	}
}

// streamFinishReason maps a Bedrock stop reason to the provider-neutral finish
// reason string. End-of-turn, stop-sequence, and any unrecognized reason all
// map to "stop".
func streamFinishReason(reason types.StopReason) string {
	switch reason {
	case types.StopReasonToolUse:
		return "tool_calls"
	case types.StopReasonMaxTokens:
		return "length"
	case types.StopReasonContentFiltered:
		return "content_filter"
	default:
		return "stop"
	}
}
// GetDefaultModel returns an empty string as Bedrock models are user-configured.
func (p *Provider) GetDefaultModel() string {
return ""