diff --git a/pkg/agent/pipeline_llm.go b/pkg/agent/pipeline_llm.go index 3dda03b3e..4f8a06132 100644 --- a/pkg/agent/pipeline_llm.go +++ b/pkg/agent/pipeline_llm.go @@ -280,22 +280,8 @@ func (p *Pipeline) CallLLM( } errMsg := strings.ToLower(err.Error()) - isTimeoutError := errors.Is(err, context.DeadlineExceeded) || - strings.Contains(errMsg, "deadline exceeded") || - strings.Contains(errMsg, "client.timeout") || - strings.Contains(errMsg, "timed out") || - strings.Contains(errMsg, "timeout exceeded") - - isNetworkError := !isTimeoutError && (strings.Contains(errMsg, "connection reset") || - strings.Contains(errMsg, "connection refused") || - strings.Contains(errMsg, "broken pipe") || - strings.Contains(errMsg, "no such host") || - strings.Contains(errMsg, "network is unreachable") || - strings.Contains(errMsg, "read tcp") || - strings.Contains(errMsg, "write tcp") || - strings.Contains(errMsg, "eof")) - - isContextError := !isTimeoutError && (strings.Contains(errMsg, "context_length_exceeded") || + retryReason, isTransientError := transientLLMRetryReason(err) + isContextError := !isTransientError && (strings.Contains(errMsg, "context_length_exceeded") || strings.Contains(errMsg, "context window") || strings.Contains(errMsg, "context_window") || strings.Contains(errMsg, "maximum context length") || @@ -306,7 +292,7 @@ func (p *Pipeline) CallLLM( strings.Contains(errMsg, "prompt is too long") || strings.Contains(errMsg, "request too large")) - if isTimeoutError && retry < maxRetries { + if isTransientError && retry < maxRetries { backoff := time.Duration(retry+1) * time.Duration(backoffSecs) * time.Second al.emitEvent( runtimeevents.KindAgentLLMRetry, @@ -314,42 +300,14 @@ func (p *Pipeline) CallLLM( LLMRetryPayload{ Attempt: retry + 1, MaxRetries: maxRetries, - Reason: "timeout", + Reason: retryReason, Error: err.Error(), Backoff: backoff, }, ) - logger.WarnCF("agent", "Timeout error, retrying after backoff", map[string]any{ - "error": err.Error(), - "retry": retry, - "backoff": backoff.String(), - }) - if sleepErr := sleepWithContext(turnCtx, backoff); sleepErr != nil { - if ts.hardAbortRequested() { - _ = ts.requestHardAbort() - return ControlBreak, nil - } - err = sleepErr - break - } - continue - } - - if isNetworkError && retry < maxRetries { - backoff := time.Duration(retry+1) * time.Duration(backoffSecs) * time.Second - al.emitEvent( - runtimeevents.KindAgentLLMRetry, - ts.eventMeta("runTurn", "turn.llm.retry"), - LLMRetryPayload{ - Attempt: retry + 1, - MaxRetries: maxRetries, - Reason: "network", - Error: err.Error(), - Backoff: backoff, - }, - ) - logger.WarnCF("agent", "Network error, retrying after backoff", map[string]any{ + logger.WarnCF("agent", "Transient LLM error, retrying after backoff", map[string]any{ "error": err.Error(), + "reason": retryReason, "retry": retry, "backoff": backoff.String(), }) @@ -735,3 +693,45 @@ func providerForFallbackCandidate( } return activeProvider, nil } + +func transientLLMRetryReason(err error) (string, bool) { + if err == nil { + return "", false + } + + if failErr := providers.ClassifyError(err, "", ""); failErr != nil { + switch failErr.Reason { + case providers.FailoverTimeout: + if failErr.Status >= 500 { + return "server_error", true + } + return "timeout", true + case providers.FailoverNetwork: + return "network", true + case providers.FailoverRateLimit, providers.FailoverOverloaded: + return "rate_limit", true + } + } + + errMsg := strings.ToLower(err.Error()) + if errors.Is(err, context.DeadlineExceeded) || + strings.Contains(errMsg, "deadline exceeded") || + strings.Contains(errMsg, "client.timeout") || + strings.Contains(errMsg, "timed out") || + strings.Contains(errMsg, "timeout exceeded") { + return "timeout", true + } + + if strings.Contains(errMsg, "connection reset") || + strings.Contains(errMsg, "connection refused") || + strings.Contains(errMsg, "broken pipe") || + strings.Contains(errMsg, "no such host") || + strings.Contains(errMsg, "network is unreachable") || + strings.Contains(errMsg, "read tcp") || + strings.Contains(errMsg, "write tcp") || + strings.Contains(errMsg, "eof") { + return "network", true + } + + return "", false +} diff --git a/pkg/agent/turn_coord_test.go b/pkg/agent/turn_coord_test.go index dc2715af7..3aa35559d 100644 --- a/pkg/agent/turn_coord_test.go +++ b/pkg/agent/turn_coord_test.go @@ -193,6 +193,38 @@ func (p *errorProvider) GetDefaultModel() string { return "error-model" } +type failOnceLLMProvider struct { + err error + response string + callCount int + mu sync.Mutex +} + +func (p *failOnceLLMProvider) Chat( + ctx context.Context, + messages []providers.Message, + tools []providers.ToolDefinition, + model string, + opts map[string]any, +) (*providers.LLMResponse, error) { + p.mu.Lock() + p.callCount++ + callCount := p.callCount + p.mu.Unlock() + + if callCount == 1 { + return nil, p.err + } + return &providers.LLMResponse{ + Content: p.response, + FinishReason: "stop", + }, nil +} + +func (p *failOnceLLMProvider) GetDefaultModel() string { + return "fail-once-model" +} + // ============================================================================= // Test Helper Functions // ============================================================================= @@ -586,6 +618,59 @@ func TestPipeline_CallLLM_TimeoutRetry(t *testing.T) { } } +func TestPipeline_CallLLM_HTTP5xxRetry(t *testing.T) { + tmpDir := t.TempDir() + provider := &failOnceLLMProvider{ + err: errors.New("API request failed:\n Status: 500\n Body: internal server error"), + response: "Recovered from server error", + } + cfg := &config.Config{ + Agents: config.AgentsConfig{ + Defaults: config.AgentDefaults{ + Workspace: tmpDir, + ModelName: "test-model", + MaxTokens: 4096, + MaxToolIterations: 10, + MaxLLMRetries: 1, + LLMRetryBackoffSecs: 1, + }, + }, + } + + msgBus := bus.NewMessageBus() + al := NewAgentLoop(cfg, msgBus, provider) + defer al.Close() + agent := al.registry.GetDefaultAgent() + if agent == nil { + t.Fatal("expected default agent") + } + + pipeline := NewPipeline(al) + ts := newTurnState(agent, makeTestProcessOpts("test-session"), turnEventScope{ + turnID: "turn-1", + context: newTurnContext(nil, nil, nil), + }) + + exec, err := pipeline.SetupTurn(context.Background(), ts) + if err != nil { + t.Fatalf("SetupTurn failed: %v", err) + } + + ctrl, err := pipeline.CallLLM(context.Background(), context.Background(), ts, exec, 1) + if err != nil { + t.Fatalf("expected HTTP 500 retry to recover, got error: %v", err) + } + if ctrl != ControlBreak { + t.Fatalf("expected ControlBreak, got %v", ctrl) + } + if exec.finalContent != "Recovered from server error" { + t.Fatalf("finalContent = %q, want recovered response", exec.finalContent) + } + if provider.callCount != 2 { + t.Fatalf("callCount = %d, want 2", provider.callCount) + } +} + func TestPipeline_CallLLM_ContextLengthError(t *testing.T) { errorPrv := &errorProvider{errType: "context_length"} al, agent, cleanup := newTurnCoordTestLoop(t, errorPrv) diff --git a/pkg/tools/session.go b/pkg/tools/session.go index 8c7584254..2b03df350 100644 --- a/pkg/tools/session.go +++ b/pkg/tools/session.go @@ -171,25 +171,42 @@ func (s *ProcessSession) ToSessionInfo() SessionInfo { type SessionManager struct { mu sync.RWMutex sessions map[string]*ProcessSession + stopCh chan struct{} + stopOnce sync.Once } func NewSessionManager() *SessionManager { sm := &SessionManager{ sessions: make(map[string]*ProcessSession), + stopCh: make(chan struct{}), } // Start cleaner goroutine - runs every 5 minutes, cleans up sessions done for >30 minutes go func() { ticker := time.NewTicker(5 * time.Minute) defer ticker.Stop() - for range ticker.C { - sm.cleanupOldSessions() + for { + select { + case <-sm.stopCh: + return + case <-ticker.C: + sm.cleanupOldSessions() + } } }() return sm } +// Stop shuts down the background cleanup goroutine. Safe to call multiple +// times from concurrent goroutines. After Stop returns, the SessionManager +// is still usable — only the cleanup goroutine is terminated. +func (sm *SessionManager) Stop() { + sm.stopOnce.Do(func() { + close(sm.stopCh) + }) +} + // cleanupOldSessions removes sessions that are done and older than 30 minutes func (sm *SessionManager) cleanupOldSessions() { sm.mu.Lock() diff --git a/pkg/tools/session_test.go b/pkg/tools/session_test.go index 6cfe72a10..96e3410e7 100644 --- a/pkg/tools/session_test.go +++ b/pkg/tools/session_test.go @@ -8,6 +8,7 @@ import ( func TestSessionManager_AddGet(t *testing.T) { sm := NewSessionManager() + t.Cleanup(sm.Stop) session := &ProcessSession{ ID: "test-1", Command: "echo hello", @@ -24,6 +25,7 @@ func TestSessionManager_AddGet(t *testing.T) { func TestSessionManager_Remove(t *testing.T) { sm := NewSessionManager() + t.Cleanup(sm.Stop) session := &ProcessSession{ ID: "test-1", Command: "echo hello", @@ -39,6 +41,7 @@ func TestSessionManager_Remove(t *testing.T) { func TestSessionManager_List(t *testing.T) { sm := NewSessionManager() + t.Cleanup(sm.Stop) sm.Add(&ProcessSession{ ID: "test-1", Command: "echo hello", diff --git a/pkg/tools/shell_test.go b/pkg/tools/shell_test.go index 14d5a6697..83807dac6 100644 --- a/pkg/tools/shell_test.go +++ b/pkg/tools/shell_test.go @@ -907,6 +907,7 @@ func TestShellTool_List_Empty(t *testing.T) { require.NoError(t, err) sm := NewSessionManager() + t.Cleanup(sm.Stop) tool.sessionManager = sm ctx := context.Background() @@ -922,6 +923,7 @@ func TestShellTool_RunBackground_List(t *testing.T) { require.NoError(t, err) sm := NewSessionManager() + t.Cleanup(sm.Stop) tool.sessionManager = sm ctx := WithToolContext(context.Background(), "cli", "test") @@ -961,6 +963,7 @@ func TestShellTool_Read_Output(t *testing.T) { require.NoError(t, err) sm := NewSessionManager() + t.Cleanup(sm.Stop) tool.sessionManager = sm ctx := WithToolContext(context.Background(), "cli", "test") @@ -995,6 +998,7 @@ func TestShellTool_Kill(t *testing.T) { require.NoError(t, err) sm := NewSessionManager() + t.Cleanup(sm.Stop) tool.sessionManager = sm ctx := WithToolContext(context.Background(), "cli", "test") @@ -1034,6 +1038,7 @@ func TestShellTool_PTY_AllowedCommands(t *testing.T) { require.NoError(t, err) sm := NewSessionManager() + t.Cleanup(sm.Stop) tool.sessionManager = sm ctx := WithToolContext(context.Background(), "cli", "test") @@ -1069,6 +1074,7 @@ func TestShellTool_PTY_WriteRead(t *testing.T) { require.NoError(t, err) sm := NewSessionManager() + t.Cleanup(sm.Stop) tool.sessionManager = sm ctx := WithToolContext(context.Background(), "cli", "test") @@ -1128,6 +1134,7 @@ func TestShellTool_PTY_Poll(t *testing.T) { require.NoError(t, err) sm := NewSessionManager() + t.Cleanup(sm.Stop) tool.sessionManager = sm ctx := WithToolContext(context.Background(), "cli", "test") @@ -1181,6 +1188,7 @@ func TestShellTool_PTY_Kill(t *testing.T) { require.NoError(t, err) sm := NewSessionManager() + t.Cleanup(sm.Stop) tool.sessionManager = sm ctx := WithToolContext(context.Background(), "cli", "test") @@ -1226,6 +1234,7 @@ func TestShellTool_Write_Read_NonPTY(t *testing.T) { require.NoError(t, err) sm := NewSessionManager() + t.Cleanup(sm.Stop) tool.sessionManager = sm ctx := WithToolContext(context.Background(), "cli", "test") @@ -1279,6 +1288,7 @@ func TestShellTool_Read_NonPTY_Running(t *testing.T) { require.NoError(t, err) sm := NewSessionManager() + t.Cleanup(sm.Stop) tool.sessionManager = sm ctx := WithToolContext(context.Background(), "cli", "test") @@ -1347,6 +1357,7 @@ func TestShellTool_ProcessGroupKill(t *testing.T) { require.NoError(t, err) sm := NewSessionManager() + t.Cleanup(sm.Stop) tool.sessionManager = sm ctx := WithToolContext(context.Background(), "cli", "test") @@ -1407,6 +1418,7 @@ func TestShellTool_PTY_ProcessGroupKill(t *testing.T) { require.NoError(t, err) sm := NewSessionManager() + t.Cleanup(sm.Stop) tool.sessionManager = sm ctx := WithToolContext(context.Background(), "cli", "test") @@ -1459,6 +1471,7 @@ func TestShellTool_PTY_Background_Read(t *testing.T) { require.NoError(t, err) sm := NewSessionManager() + t.Cleanup(sm.Stop) tool.sessionManager = sm ctx := WithToolContext(context.Background(), "cli", "test") @@ -1499,6 +1512,7 @@ func TestShellTool_PTY_Background_ReadNoBlock(t *testing.T) { require.NoError(t, err) sm := NewSessionManager() + t.Cleanup(sm.Stop) tool.sessionManager = sm ctx := WithToolContext(context.Background(), "cli", "test") @@ -1543,6 +1557,7 @@ func TestShellTool_Poll_Status(t *testing.T) { require.NoError(t, err) sm := NewSessionManager() + t.Cleanup(sm.Stop) tool.sessionManager = sm ctx := WithToolContext(context.Background(), "cli", "test")