mirror of
https://github.com/sipeed/picoclaw.git
synced 2026-05-25 16:00:35 +00:00
feat(bedrock): implement StreamingProvider for real-time token streaming
Adds ConverseStream API support to the Bedrock provider, implementing the StreamingProvider interface. Tokens flow via onChunk callback for real-time delivery to streaming-capable channels. - Extract buildConverseParams to share request logic between Chat and ChatStream - Add converseStreamReader interface for testability - Preserve raw payload in Arguments on JSON parse failure - Ensure Function.Arguments is always valid JSON - Streaming timeout only applied when explicitly configured - Capture stream Close() errors for diagnostics - Consistent "bedrock conversestream" / "bedrock:" log prefixes Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -135,48 +135,23 @@ func NewProvider(ctx context.Context, opts ...Option) (*Provider, error) {
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Chat sends messages to AWS Bedrock using the Converse API.
|
||||
func (p *Provider) Chat(
|
||||
ctx context.Context,
|
||||
messages []Message,
|
||||
tools []ToolDefinition,
|
||||
model string,
|
||||
options map[string]any,
|
||||
) (*LLMResponse, error) {
|
||||
// Apply request timeout if context doesn't already have a deadline.
|
||||
// Use explicit timeout if set, otherwise fall back to common default.
|
||||
effectiveTimeout := p.requestTimeout
|
||||
if effectiveTimeout <= 0 {
|
||||
effectiveTimeout = common.DefaultRequestTimeout
|
||||
}
|
||||
if _, hasDeadline := ctx.Deadline(); !hasDeadline {
|
||||
var cancel context.CancelFunc
|
||||
ctx, cancel = context.WithTimeout(ctx, effectiveTimeout)
|
||||
defer cancel()
|
||||
}
|
||||
// converseParams holds the shared request parameters for Converse and ConverseStream.
|
||||
type converseParams struct {
|
||||
messages []types.Message
|
||||
system []types.SystemContentBlock
|
||||
inferenceConfig *types.InferenceConfiguration
|
||||
toolConfig *types.ToolConfiguration
|
||||
}
|
||||
|
||||
// Build the Converse API input
|
||||
input := &bedrockruntime.ConverseInput{
|
||||
ModelId: aws.String(model),
|
||||
}
|
||||
|
||||
// Convert messages to Bedrock format
|
||||
func buildConverseParams(messages []Message, tools []ToolDefinition, options map[string]any) converseParams {
|
||||
bedrockMessages, systemPrompts := convertMessages(messages)
|
||||
input.Messages = bedrockMessages
|
||||
|
||||
// Set system prompts if any
|
||||
if len(systemPrompts) > 0 {
|
||||
input.System = systemPrompts
|
||||
}
|
||||
|
||||
// Set inference configuration only when options are provided
|
||||
var inferenceConfig *types.InferenceConfiguration
|
||||
|
||||
if maxTokens, ok := common.AsInt(options["max_tokens"]); ok && maxTokens > 0 {
|
||||
if inferenceConfig == nil {
|
||||
inferenceConfig = &types.InferenceConfiguration{}
|
||||
}
|
||||
// Clamp to int32 range to avoid overflow
|
||||
if maxTokens > math.MaxInt32 {
|
||||
maxTokens = math.MaxInt32
|
||||
}
|
||||
@@ -190,23 +165,53 @@ func (p *Provider) Chat(
|
||||
inferenceConfig.Temperature = aws.Float32(float32(temp))
|
||||
}
|
||||
|
||||
if inferenceConfig != nil {
|
||||
input.InferenceConfig = inferenceConfig
|
||||
}
|
||||
|
||||
// Convert tools to Bedrock format
|
||||
// Only set ToolConfig if at least one valid tool was produced
|
||||
var toolConfig *types.ToolConfiguration
|
||||
if len(tools) > 0 {
|
||||
toolConfig := convertTools(tools)
|
||||
if len(toolConfig.Tools) > 0 {
|
||||
input.ToolConfig = toolConfig
|
||||
tc := convertTools(tools)
|
||||
if len(tc.Tools) > 0 {
|
||||
toolConfig = tc
|
||||
}
|
||||
}
|
||||
|
||||
// Call Bedrock Converse API
|
||||
return converseParams{
|
||||
messages: bedrockMessages,
|
||||
system: systemPrompts,
|
||||
inferenceConfig: inferenceConfig,
|
||||
toolConfig: toolConfig,
|
||||
}
|
||||
}
|
||||
|
||||
// Chat sends messages to AWS Bedrock using the Converse API.
|
||||
func (p *Provider) Chat(
|
||||
ctx context.Context,
|
||||
messages []Message,
|
||||
tools []ToolDefinition,
|
||||
model string,
|
||||
options map[string]any,
|
||||
) (*LLMResponse, error) {
|
||||
effectiveTimeout := p.requestTimeout
|
||||
if effectiveTimeout <= 0 {
|
||||
effectiveTimeout = common.DefaultRequestTimeout
|
||||
}
|
||||
if _, hasDeadline := ctx.Deadline(); !hasDeadline {
|
||||
var cancel context.CancelFunc
|
||||
ctx, cancel = context.WithTimeout(ctx, effectiveTimeout)
|
||||
defer cancel()
|
||||
}
|
||||
|
||||
params := buildConverseParams(messages, tools, options)
|
||||
input := &bedrockruntime.ConverseInput{
|
||||
ModelId: aws.String(model),
|
||||
Messages: params.messages,
|
||||
InferenceConfig: params.inferenceConfig,
|
||||
ToolConfig: params.toolConfig,
|
||||
}
|
||||
if len(params.system) > 0 {
|
||||
input.System = params.system
|
||||
}
|
||||
|
||||
output, err := p.client.Converse(ctx, input)
|
||||
if err != nil {
|
||||
// Check for SSO token expiration errors and provide actionable guidance
|
||||
if isSSOTokenError(err) {
|
||||
return nil, fmt.Errorf(
|
||||
"bedrock converse: AWS credentials may have expired. If using AWS SSO, run 'aws sso login' to refresh: %w",
|
||||
@@ -216,10 +221,199 @@ func (p *Provider) Chat(
|
||||
return nil, fmt.Errorf("bedrock converse: %w", err)
|
||||
}
|
||||
|
||||
// Parse the response
|
||||
return parseResponse(output)
|
||||
}
|
||||
|
||||
// ChatStream sends messages to AWS Bedrock using the ConverseStream API.
|
||||
// It streams the accumulated text so far via the onChunk callback and returns the complete response.
|
||||
func (p *Provider) ChatStream(
|
||||
ctx context.Context,
|
||||
messages []Message,
|
||||
tools []ToolDefinition,
|
||||
model string,
|
||||
options map[string]any,
|
||||
onChunk func(accumulated string),
|
||||
) (*LLMResponse, error) {
|
||||
if p.requestTimeout > 0 {
|
||||
if _, hasDeadline := ctx.Deadline(); !hasDeadline {
|
||||
var cancel context.CancelFunc
|
||||
ctx, cancel = context.WithTimeout(ctx, p.requestTimeout)
|
||||
defer cancel()
|
||||
}
|
||||
}
|
||||
|
||||
params := buildConverseParams(messages, tools, options)
|
||||
input := &bedrockruntime.ConverseStreamInput{
|
||||
ModelId: aws.String(model),
|
||||
Messages: params.messages,
|
||||
InferenceConfig: params.inferenceConfig,
|
||||
ToolConfig: params.toolConfig,
|
||||
}
|
||||
if len(params.system) > 0 {
|
||||
input.System = params.system
|
||||
}
|
||||
|
||||
output, err := p.client.ConverseStream(ctx, input)
|
||||
if err != nil {
|
||||
if isSSOTokenError(err) {
|
||||
return nil, fmt.Errorf(
|
||||
"bedrock conversestream: AWS credentials may have expired. If using AWS SSO, run 'aws sso login' to refresh: %w",
|
||||
err,
|
||||
)
|
||||
}
|
||||
return nil, fmt.Errorf("bedrock conversestream: %w", err)
|
||||
}
|
||||
|
||||
return parseStreamResponse(ctx, output.GetStream(), onChunk)
|
||||
}
|
||||
|
||||
// converseStreamReader abstracts the Bedrock event stream so parseStreamResponse
|
||||
// can be unit-tested with a mock event source.
|
||||
type converseStreamReader interface {
|
||||
Events() <-chan types.ConverseStreamOutput
|
||||
Err() error
|
||||
Close() error
|
||||
}
|
||||
|
||||
// parseStreamResponse processes the ConverseStream event stream and accumulates the response.
|
||||
func parseStreamResponse(
|
||||
ctx context.Context,
|
||||
stream converseStreamReader,
|
||||
onChunk func(accumulated string),
|
||||
) (resp *LLMResponse, err error) {
|
||||
if stream == nil {
|
||||
return nil, fmt.Errorf("bedrock conversestream: nil event stream")
|
||||
}
|
||||
defer func() {
|
||||
if closeErr := stream.Close(); closeErr != nil {
|
||||
if err == nil {
|
||||
err = fmt.Errorf("bedrock conversestream: close event stream: %w", closeErr)
|
||||
} else {
|
||||
log.Printf("bedrock conversestream: close event stream: %v", closeErr)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
var textContent strings.Builder
|
||||
finishReason := "stop"
|
||||
var usage *UsageInfo
|
||||
toolCalls := make([]ToolCall, 0)
|
||||
|
||||
// Track active tool use blocks by index
|
||||
type toolAccum struct {
|
||||
id string
|
||||
name string
|
||||
argsJSON strings.Builder
|
||||
}
|
||||
activeTools := map[int]*toolAccum{}
|
||||
|
||||
events := stream.Events()
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
case event, ok := <-events:
|
||||
if !ok {
|
||||
// Stream closed
|
||||
goto done
|
||||
}
|
||||
|
||||
switch e := event.(type) {
|
||||
case *types.ConverseStreamOutputMemberContentBlockStart:
|
||||
// New content block starting
|
||||
if toolUse, ok := e.Value.Start.(*types.ContentBlockStartMemberToolUse); ok {
|
||||
activeTools[int(aws.ToInt32(e.Value.ContentBlockIndex))] = &toolAccum{
|
||||
id: aws.ToString(toolUse.Value.ToolUseId),
|
||||
name: aws.ToString(toolUse.Value.Name),
|
||||
}
|
||||
}
|
||||
|
||||
case *types.ConverseStreamOutputMemberContentBlockDelta:
|
||||
// Content delta
|
||||
switch delta := e.Value.Delta.(type) {
|
||||
case *types.ContentBlockDeltaMemberText:
|
||||
textContent.WriteString(delta.Value)
|
||||
if onChunk != nil {
|
||||
onChunk(textContent.String())
|
||||
}
|
||||
case *types.ContentBlockDeltaMemberToolUse:
|
||||
idx := int(aws.ToInt32(e.Value.ContentBlockIndex))
|
||||
if tool, exists := activeTools[idx]; exists {
|
||||
tool.argsJSON.WriteString(aws.ToString(delta.Value.Input))
|
||||
}
|
||||
}
|
||||
|
||||
case *types.ConverseStreamOutputMemberContentBlockStop:
|
||||
// Content block finished - finalize tool if it was a tool use
|
||||
idx := int(aws.ToInt32(e.Value.ContentBlockIndex))
|
||||
if tool, exists := activeTools[idx]; exists {
|
||||
args := make(map[string]any)
|
||||
argsStr := tool.argsJSON.String()
|
||||
if argsStr != "" {
|
||||
if err := json.Unmarshal([]byte(argsStr), &args); err != nil {
|
||||
log.Printf("bedrock: stream: failed to parse tool arguments for %q: %v", tool.name, err)
|
||||
args = map[string]any{"raw": argsStr}
|
||||
}
|
||||
}
|
||||
funcArgs := argsStr
|
||||
if argsJSON, marshalErr := json.Marshal(args); marshalErr == nil {
|
||||
funcArgs = string(argsJSON)
|
||||
}
|
||||
toolCalls = append(toolCalls, ToolCall{
|
||||
ID: tool.id,
|
||||
Name: tool.name,
|
||||
Arguments: args,
|
||||
Function: &FunctionCall{
|
||||
Name: tool.name,
|
||||
Arguments: funcArgs,
|
||||
},
|
||||
})
|
||||
delete(activeTools, idx)
|
||||
}
|
||||
|
||||
case *types.ConverseStreamOutputMemberMessageStop:
|
||||
// Message complete
|
||||
switch e.Value.StopReason {
|
||||
case types.StopReasonToolUse:
|
||||
finishReason = "tool_calls"
|
||||
case types.StopReasonMaxTokens:
|
||||
finishReason = "length"
|
||||
case types.StopReasonEndTurn:
|
||||
finishReason = "stop"
|
||||
case types.StopReasonStopSequence:
|
||||
finishReason = "stop"
|
||||
case types.StopReasonContentFiltered:
|
||||
finishReason = "content_filter"
|
||||
default:
|
||||
finishReason = "stop"
|
||||
}
|
||||
|
||||
case *types.ConverseStreamOutputMemberMetadata:
|
||||
// Usage metadata
|
||||
if e.Value.Usage != nil {
|
||||
usage = &UsageInfo{
|
||||
PromptTokens: int(aws.ToInt32(e.Value.Usage.InputTokens)),
|
||||
CompletionTokens: int(aws.ToInt32(e.Value.Usage.OutputTokens)),
|
||||
TotalTokens: int(aws.ToInt32(e.Value.Usage.InputTokens)) + int(aws.ToInt32(e.Value.Usage.OutputTokens)),
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
done:
|
||||
if err := stream.Err(); err != nil {
|
||||
return nil, fmt.Errorf("bedrock conversestream: %w", err)
|
||||
}
|
||||
|
||||
return &LLMResponse{
|
||||
Content: textContent.String(),
|
||||
ToolCalls: toolCalls,
|
||||
FinishReason: finishReason,
|
||||
Usage: usage,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// GetDefaultModel returns an empty string as Bedrock models are user-configured.
|
||||
func (p *Provider) GetDefaultModel() string {
|
||||
return ""
|
||||
|
||||
Reference in New Issue
Block a user