From 4f90909af34cbb9d67ca21e7eb20861c293eb7c7 Mon Sep 17 00:00:00 2001 From: Andy Lo-A-Foe Date: Wed, 25 Mar 2026 07:02:23 +0100 Subject: [PATCH] feat(bedrock): detect SSO token expiration and provide actionable error When AWS SSO credentials expire, provide a clear error message instructing the user to run 'aws sso login' to refresh their session. --- pkg/providers/bedrock/provider_bedrock.go | 31 ++++++++++ .../bedrock/provider_bedrock_test.go | 62 +++++++++++++++++++ 2 files changed, 93 insertions(+) diff --git a/pkg/providers/bedrock/provider_bedrock.go b/pkg/providers/bedrock/provider_bedrock.go index 15c4f664e..9ca29455f 100644 --- a/pkg/providers/bedrock/provider_bedrock.go +++ b/pkg/providers/bedrock/provider_bedrock.go @@ -206,6 +206,10 @@ func (p *Provider) Chat( // Call Bedrock Converse API output, err := p.client.Converse(ctx, input) if err != nil { + // Check for SSO token expiration errors and provide actionable guidance + if isSSOTokenError(err) { + return nil, fmt.Errorf("bedrock converse: AWS credentials may have expired. If using AWS SSO, run 'aws sso login' to refresh: %w", err) + } return nil, fmt.Errorf("bedrock converse: %w", err) } @@ -580,3 +584,30 @@ func parseResponse(output *bedrockruntime.ConverseOutput) (*LLMResponse, error) Usage: usage, }, nil } + +// isSSOTokenError checks if the error is related to expired or invalid AWS SSO tokens. +// This helps provide actionable guidance when SSO credentials need to be refreshed. +// Only matches SSO-specific error patterns to avoid misclassifying other AWS credential errors. +func isSSOTokenError(err error) bool { + if err == nil { + return false + } + lower := strings.ToLower(err.Error()) + + // Check for specific SSO token expiration/refresh-related error patterns (case-insensitive) + // Avoid matching generic patterns that could match non-SSO AWS errors (e.g., STS ExpiredToken) + if strings.Contains(lower, "refresh cached sso token") { + return true + } + if strings.Contains(lower, "read cached sso token") { + return true + } + if strings.Contains(lower, "sso oidc") { + return true + } + if strings.Contains(lower, "invalidgrantexception") { + return true + } + + return false +} diff --git a/pkg/providers/bedrock/provider_bedrock_test.go b/pkg/providers/bedrock/provider_bedrock_test.go index 754d112ee..882c2971c 100644 --- a/pkg/providers/bedrock/provider_bedrock_test.go +++ b/pkg/providers/bedrock/provider_bedrock_test.go @@ -8,6 +8,7 @@ package bedrock import ( + "fmt" "testing" "github.com/aws/aws-sdk-go-v2/aws" @@ -539,3 +540,64 @@ func TestParseResponse_ToolCallWithNilInput(t *testing.T) { assert.NotNil(t, resp.ToolCalls[0].Arguments) assert.Empty(t, resp.ToolCalls[0].Arguments) } + +func TestIsSSOTokenError(t *testing.T) { + tests := []struct { + name string + err error + expected bool + }{ + { + name: "nil error", + err: nil, + expected: false, + }, + { + name: "generic error", + err: fmt.Errorf("connection refused"), + expected: false, + }, + { + name: "SSO config error not expiration", + err: fmt.Errorf("failed to load SSO profile: invalid SSO session"), + expected: false, + }, + { + name: "STS ExpiredToken error", + err: fmt.Errorf("ExpiredToken: The security token included in the request is expired"), + expected: false, + }, + { + name: "SSO token refresh error", + err: fmt.Errorf("refresh cached SSO token failed"), + expected: true, + }, + { + name: "InvalidGrantException", + err: fmt.Errorf("operation error SSO OIDC: CreateToken, InvalidGrantException"), + expected: true, + }, + { + name: "SSO OIDC error", + err: fmt.Errorf("operation error SSO OIDC: CreateToken, failed"), + expected: true, + }, + { + name: "full SSO error message", + err: fmt.Errorf("get identity: get credentials: failed to refresh cached credentials, refresh cached SSO token failed, unable to refresh SSO token"), + expected: true, + }, + { + name: "SSO token file missing", + err: fmt.Errorf("get identity: get credentials: failed to refresh cached credentials, failed to read cached SSO token file, open ~/.aws/sso/cache/abc123.json: no such file or directory"), + expected: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := isSSOTokenError(tt.err) + assert.Equal(t, tt.expected, result) + }) + } +}