diff --git a/pkg/providers/fallback_test.go b/pkg/providers/fallback_test.go index 7b0652683..07cc01baa 100644 --- a/pkg/providers/fallback_test.go +++ b/pkg/providers/fallback_test.go @@ -268,12 +268,21 @@ func TestFallback_UnclassifiedError(t *testing.T) { } } -func TestFallback_NetworkErrorFallsBack(t *testing.T) { +func assertFallbackErrorFallsBack( + t *testing.T, + primaryProvider string, + primaryModel string, + initialErr error, + successContent string, + expectedReason FailoverReason, +) { + t.Helper() + ct := NewCooldownTracker() fc := NewFallbackChain(ct, nil) candidates := []FallbackCandidate{ - makeCandidate("minimax", "minimax-m2.7"), + makeCandidate(primaryProvider, primaryModel), makeCandidate("anthropic", "claude"), } @@ -281,11 +290,9 @@ func TestFallback_NetworkErrorFallsBack(t *testing.T) { run := func(ctx context.Context, provider, model string) (*LLMResponse, error) { attempt++ if attempt == 1 { - return nil, errors.New( - `failed to send request: Post "https://opencode.ai/zen/go/v1/chat/completions": tls: bad record MAC`, - ) + return nil, initialErr } - return &LLMResponse{Content: "fallback ok", FinishReason: "stop"}, nil + return &LLMResponse{Content: successContent, FinishReason: "stop"}, nil } result, err := fc.Execute(context.Background(), candidates, run) @@ -301,45 +308,33 @@ func TestFallback_NetworkErrorFallsBack(t *testing.T) { if len(result.Attempts) != 1 { t.Fatalf("attempts = %d, want 1 failed attempt recorded", len(result.Attempts)) } - if result.Attempts[0].Reason != FailoverNetwork { - t.Fatalf("attempt reason = %q, want network", result.Attempts[0].Reason) + if result.Attempts[0].Reason != expectedReason { + t.Fatalf("attempt reason = %q, want %s", result.Attempts[0].Reason, expectedReason) } } +func TestFallback_NetworkErrorFallsBack(t *testing.T) { + assertFallbackErrorFallsBack( + t, + "minimax", + "minimax-m2.7", + errors.New( + `failed to send request: Post "https://opencode.ai/zen/go/v1/chat/completions": tls: bad record MAC`, + ), + "fallback ok", + FailoverNetwork, + ) +} + func TestFallback_TimeoutErrorFallsBack(t *testing.T) { - ct := NewCooldownTracker() - fc := NewFallbackChain(ct, nil) - - candidates := []FallbackCandidate{ - makeCandidate("openai", "gpt-4"), - makeCandidate("anthropic", "claude"), - } - - attempt := 0 - run := func(ctx context.Context, provider, model string) (*LLMResponse, error) { - attempt++ - if attempt == 1 { - return nil, errors.New("failed to send request: Post \"https://example.com\": i/o timeout") - } - return &LLMResponse{Content: "timeout fallback ok", FinishReason: "stop"}, nil - } - - result, err := fc.Execute(context.Background(), candidates, run) - if err != nil { - t.Fatalf("expected fallback success, got error: %v", err) - } - if attempt != 2 { - t.Fatalf("attempt = %d, want 2", attempt) - } - if result.Provider != "anthropic" || result.Model != "claude" { - t.Fatalf("result = %s/%s, want anthropic/claude", result.Provider, result.Model) - } - if len(result.Attempts) != 1 { - t.Fatalf("attempts = %d, want 1 failed attempt recorded", len(result.Attempts)) - } - if result.Attempts[0].Reason != FailoverTimeout { - t.Fatalf("attempt reason = %q, want timeout", result.Attempts[0].Reason) - } + assertFallbackErrorFallsBack( + t, + "openai", + "gpt-4", + errors.New("failed to send request: Post \"https://example.com\": i/o timeout"), + "timeout fallback ok", + FailoverTimeout, + ) } func TestFallback_SuccessResetsCooldown(t *testing.T) {