Merge pull request #2991 from chengzhichao-xydt/codex/retry-transient-llm-errors

fix(agent): retry transient LLM HTTP errors using provider error classifier
This commit is contained in:
Mauro
2026-06-02 18:45:35 +02:00
committed by GitHub
2 changed files with 133 additions and 48 deletions
+48 -48
View File
@@ -280,22 +280,8 @@ func (p *Pipeline) CallLLM(
}
errMsg := strings.ToLower(err.Error())
isTimeoutError := errors.Is(err, context.DeadlineExceeded) ||
strings.Contains(errMsg, "deadline exceeded") ||
strings.Contains(errMsg, "client.timeout") ||
strings.Contains(errMsg, "timed out") ||
strings.Contains(errMsg, "timeout exceeded")
isNetworkError := !isTimeoutError && (strings.Contains(errMsg, "connection reset") ||
strings.Contains(errMsg, "connection refused") ||
strings.Contains(errMsg, "broken pipe") ||
strings.Contains(errMsg, "no such host") ||
strings.Contains(errMsg, "network is unreachable") ||
strings.Contains(errMsg, "read tcp") ||
strings.Contains(errMsg, "write tcp") ||
strings.Contains(errMsg, "eof"))
isContextError := !isTimeoutError && (strings.Contains(errMsg, "context_length_exceeded") ||
retryReason, isTransientError := transientLLMRetryReason(err)
isContextError := !isTransientError && (strings.Contains(errMsg, "context_length_exceeded") ||
strings.Contains(errMsg, "context window") ||
strings.Contains(errMsg, "context_window") ||
strings.Contains(errMsg, "maximum context length") ||
@@ -306,7 +292,7 @@ func (p *Pipeline) CallLLM(
strings.Contains(errMsg, "prompt is too long") ||
strings.Contains(errMsg, "request too large"))
if isTimeoutError && retry < maxRetries {
if isTransientError && retry < maxRetries {
backoff := time.Duration(retry+1) * time.Duration(backoffSecs) * time.Second
al.emitEvent(
runtimeevents.KindAgentLLMRetry,
@@ -314,42 +300,14 @@ func (p *Pipeline) CallLLM(
LLMRetryPayload{
Attempt: retry + 1,
MaxRetries: maxRetries,
Reason: "timeout",
Reason: retryReason,
Error: err.Error(),
Backoff: backoff,
},
)
logger.WarnCF("agent", "Timeout error, retrying after backoff", map[string]any{
"error": err.Error(),
"retry": retry,
"backoff": backoff.String(),
})
if sleepErr := sleepWithContext(turnCtx, backoff); sleepErr != nil {
if ts.hardAbortRequested() {
_ = ts.requestHardAbort()
return ControlBreak, nil
}
err = sleepErr
break
}
continue
}
if isNetworkError && retry < maxRetries {
backoff := time.Duration(retry+1) * time.Duration(backoffSecs) * time.Second
al.emitEvent(
runtimeevents.KindAgentLLMRetry,
ts.eventMeta("runTurn", "turn.llm.retry"),
LLMRetryPayload{
Attempt: retry + 1,
MaxRetries: maxRetries,
Reason: "network",
Error: err.Error(),
Backoff: backoff,
},
)
logger.WarnCF("agent", "Network error, retrying after backoff", map[string]any{
logger.WarnCF("agent", "Transient LLM error, retrying after backoff", map[string]any{
"error": err.Error(),
"reason": retryReason,
"retry": retry,
"backoff": backoff.String(),
})
@@ -735,3 +693,45 @@ func providerForFallbackCandidate(
}
return activeProvider, nil
}
// transientLLMRetryReason decides whether err is a transient LLM failure that
// is worth retrying. It returns a short reason label and true when the error
// is retryable, or "" and false otherwise (including when err is nil).
//
// The provider error classifier is consulted first. A timeout classification
// carrying a status of 500 or above is labelled "server_error", other timeout
// classifications "timeout", network classifications "network", and both
// rate-limit and overloaded classifications "rate_limit". Any other
// classification (or none) falls through to substring matching on the
// lower-cased error text, which can yield "timeout" or "network".
func transientLLMRetryReason(err error) (string, bool) {
	if err == nil {
		return "", false
	}

	classified := providers.ClassifyError(err, "", "")
	if classified != nil {
		reason := classified.Reason
		if reason == providers.FailoverTimeout {
			if classified.Status >= 500 {
				return "server_error", true
			}
			return "timeout", true
		}
		if reason == providers.FailoverNetwork {
			return "network", true
		}
		if reason == providers.FailoverRateLimit || reason == providers.FailoverOverloaded {
			return "rate_limit", true
		}
	}

	// Fallback heuristics: match well-known fragments in the error message.
	lowered := strings.ToLower(err.Error())
	mentionsAny := func(fragments ...string) bool {
		for _, fragment := range fragments {
			if strings.Contains(lowered, fragment) {
				return true
			}
		}
		return false
	}

	switch {
	case errors.Is(err, context.DeadlineExceeded),
		mentionsAny(
			"deadline exceeded",
			"client.timeout",
			"timed out",
			"timeout exceeded",
		):
		return "timeout", true
	case mentionsAny(
		"connection reset",
		"connection refused",
		"broken pipe",
		"no such host",
		"network is unreachable",
		"read tcp",
		"write tcp",
		"eof",
	):
		return "network", true
	default:
		return "", false
	}
}
+85
View File
@@ -193,6 +193,38 @@ func (p *errorProvider) GetDefaultModel() string {
return "error-model"
}
// failOnceLLMProvider is a test double whose Chat method returns err on its
// first call and a successful response carrying response on every later call.
type failOnceLLMProvider struct {
	err      error  // returned by the first Chat call
	response string // content of the response returned by later Chat calls

	mu        sync.Mutex // held by Chat while it updates callCount
	callCount int        // number of Chat calls made so far
}
// Chat counts the call under p.mu, then returns (nil, p.err) for the very
// first call and a response with Content p.response and FinishReason "stop"
// for every call after that. The request arguments are not inspected.
func (p *failOnceLLMProvider) Chat(
	ctx context.Context,
	messages []providers.Message,
	tools []providers.ToolDefinition,
	model string,
	opts map[string]any,
) (*providers.LLMResponse, error) {
	p.mu.Lock()
	p.callCount++
	isFirstCall := p.callCount == 1
	p.mu.Unlock()

	if !isFirstCall {
		recovered := &providers.LLMResponse{
			Content:      p.response,
			FinishReason: "stop",
		}
		return recovered, nil
	}
	return nil, p.err
}
// GetDefaultModel returns the fixed model name reported by this test double.
func (p *failOnceLLMProvider) GetDefaultModel() string {
	const defaultModel = "fail-once-model"
	return defaultModel
}
// =============================================================================
// Test Helper Functions
// =============================================================================
@@ -586,6 +618,59 @@ func TestPipeline_CallLLM_TimeoutRetry(t *testing.T) {
}
}
// TestPipeline_CallLLM_HTTP5xxRetry verifies that CallLLM recovers when the
// provider fails once with an error whose text reports "Status: 500" and then
// succeeds: the call must return ControlBreak with no error, the recovered
// content must be stored in exec.finalContent, and the provider must have been
// invoked exactly twice (the failed attempt plus one retry).
func TestPipeline_CallLLM_HTTP5xxRetry(t *testing.T) {
	tmpDir := t.TempDir()
	provider := &failOnceLLMProvider{
		err:      errors.New("API request failed:\n Status: 500\n Body: internal server error"),
		response: "Recovered from server error",
	}
	cfg := &config.Config{
		Agents: config.AgentsConfig{
			Defaults: config.AgentDefaults{
				Workspace:         tmpDir,
				ModelName:         "test-model",
				MaxTokens:         4096,
				MaxToolIterations: 10,
				// A single retry is enough: the provider only fails its first call.
				MaxLLMRetries:       1,
				LLMRetryBackoffSecs: 1,
			},
		},
	}

	msgBus := bus.NewMessageBus()
	al := NewAgentLoop(cfg, msgBus, provider)
	defer al.Close()

	agent := al.registry.GetDefaultAgent()
	if agent == nil {
		t.Fatal("expected default agent")
	}

	pipeline := NewPipeline(al)
	ts := newTurnState(agent, makeTestProcessOpts("test-session"), turnEventScope{
		turnID:  "turn-1",
		context: newTurnContext(nil, nil, nil),
	})

	exec, err := pipeline.SetupTurn(context.Background(), ts)
	if err != nil {
		t.Fatalf("SetupTurn failed: %v", err)
	}

	ctrl, err := pipeline.CallLLM(context.Background(), context.Background(), ts, exec, 1)
	if err != nil {
		t.Fatalf("expected HTTP 500 retry to recover, got error: %v", err)
	}
	if ctrl != ControlBreak {
		t.Fatalf("expected ControlBreak, got %v", ctrl)
	}
	if exec.finalContent != "Recovered from server error" {
		t.Fatalf("finalContent = %q, want recovered response", exec.finalContent)
	}

	// callCount is written under provider.mu in Chat, so read it under the
	// same lock to stay consistent and clean under the race detector.
	provider.mu.Lock()
	calls := provider.callCount
	provider.mu.Unlock()
	if calls != 2 {
		t.Fatalf("callCount = %d, want 2", calls)
	}
}
func TestPipeline_CallLLM_ContextLengthError(t *testing.T) {
errorPrv := &errorProvider{errType: "context_length"}
al, agent, cleanup := newTurnCoordTestLoop(t, errorPrv)