feat(bedrock): detect SSO token expiration and provide actionable error

When AWS SSO credentials expire, provide a clear error message instructing
the user to run 'aws sso login' to refresh their session.
This commit is contained in:
Andy Lo-A-Foe
2026-03-25 07:02:23 +01:00
parent 27f638e909
commit 4f90909af3
2 changed files with 93 additions and 0 deletions
+31
View File
@@ -206,6 +206,10 @@ func (p *Provider) Chat(
// Call Bedrock Converse API
output, err := p.client.Converse(ctx, input)
if err != nil {
// Check for SSO token expiration errors and provide actionable guidance
if isSSOTokenError(err) {
return nil, fmt.Errorf("bedrock converse: AWS credentials may have expired. If using AWS SSO, run 'aws sso login' to refresh: %w", err)
}
return nil, fmt.Errorf("bedrock converse: %w", err)
}
@@ -580,3 +584,30 @@ func parseResponse(output *bedrockruntime.ConverseOutput) (*LLMResponse, error)
Usage: usage,
}, nil
}
// isSSOTokenError checks if the error is related to expired or invalid AWS SSO tokens.
// This helps provide actionable guidance when SSO credentials need to be refreshed.
// Only matches SSO-specific error patterns to avoid misclassifying other AWS credential errors.
func isSSOTokenError(err error) bool {
if err == nil {
return false
}
lower := strings.ToLower(err.Error())
// Check for specific SSO token expiration/refresh-related error patterns (case-insensitive)
// Avoid matching generic patterns that could match non-SSO AWS errors (e.g., STS ExpiredToken)
if strings.Contains(lower, "refresh cached sso token") {
return true
}
if strings.Contains(lower, "read cached sso token") {
return true
}
if strings.Contains(lower, "sso oidc") {
return true
}
if strings.Contains(lower, "invalidgrantexception") {
return true
}
return false
}
@@ -8,6 +8,7 @@
package bedrock
import (
"fmt"
"testing"
"github.com/aws/aws-sdk-go-v2/aws"
@@ -539,3 +540,64 @@ func TestParseResponse_ToolCallWithNilInput(t *testing.T) {
assert.NotNil(t, resp.ToolCalls[0].Arguments)
assert.Empty(t, resp.ToolCalls[0].Arguments)
}
func TestIsSSOTokenError(t *testing.T) {
tests := []struct {
name string
err error
expected bool
}{
{
name: "nil error",
err: nil,
expected: false,
},
{
name: "generic error",
err: fmt.Errorf("connection refused"),
expected: false,
},
{
name: "SSO config error not expiration",
err: fmt.Errorf("failed to load SSO profile: invalid SSO session"),
expected: false,
},
{
name: "STS ExpiredToken error",
err: fmt.Errorf("ExpiredToken: The security token included in the request is expired"),
expected: false,
},
{
name: "SSO token refresh error",
err: fmt.Errorf("refresh cached SSO token failed"),
expected: true,
},
{
name: "InvalidGrantException",
err: fmt.Errorf("operation error SSO OIDC: CreateToken, InvalidGrantException"),
expected: true,
},
{
name: "SSO OIDC error",
err: fmt.Errorf("operation error SSO OIDC: CreateToken, failed"),
expected: true,
},
{
name: "full SSO error message",
err: fmt.Errorf("get identity: get credentials: failed to refresh cached credentials, refresh cached SSO token failed, unable to refresh SSO token"),
expected: true,
},
{
name: "SSO token file missing",
err: fmt.Errorf("get identity: get credentials: failed to refresh cached credentials, failed to read cached SSO token file, open ~/.aws/sso/cache/abc123.json: no such file or directory"),
expected: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := isSSOTokenError(tt.err)
assert.Equal(t, tt.expected, result)
})
}
}