refactor(tests): extract common logic for fallback error handling into a helper function

This commit is contained in:
lc6464
2026-04-16 22:45:31 +08:00
parent 7aa2d672ce
commit 2b844778ff
+36 -41
View File
@@ -268,12 +268,21 @@ func TestFallback_UnclassifiedError(t *testing.T) {
}
}
func TestFallback_NetworkErrorFallsBack(t *testing.T) {
func assertFallbackErrorFallsBack(
t *testing.T,
primaryProvider string,
primaryModel string,
initialErr error,
successContent string,
expectedReason FailoverReason,
) {
t.Helper()
ct := NewCooldownTracker()
fc := NewFallbackChain(ct, nil)
candidates := []FallbackCandidate{
makeCandidate("minimax", "minimax-m2.7"),
makeCandidate(primaryProvider, primaryModel),
makeCandidate("anthropic", "claude"),
}
@@ -281,11 +290,9 @@ func TestFallback_NetworkErrorFallsBack(t *testing.T) {
run := func(ctx context.Context, provider, model string) (*LLMResponse, error) {
attempt++
if attempt == 1 {
return nil, errors.New(
`failed to send request: Post "https://opencode.ai/zen/go/v1/chat/completions": tls: bad record MAC`,
)
return nil, initialErr
}
return &LLMResponse{Content: "fallback ok", FinishReason: "stop"}, nil
return &LLMResponse{Content: successContent, FinishReason: "stop"}, nil
}
result, err := fc.Execute(context.Background(), candidates, run)
@@ -301,45 +308,33 @@ func TestFallback_NetworkErrorFallsBack(t *testing.T) {
if len(result.Attempts) != 1 {
t.Fatalf("attempts = %d, want 1 failed attempt recorded", len(result.Attempts))
}
if result.Attempts[0].Reason != FailoverNetwork {
t.Fatalf("attempt reason = %q, want network", result.Attempts[0].Reason)
if result.Attempts[0].Reason != expectedReason {
t.Fatalf("attempt reason = %q, want %s", result.Attempts[0].Reason, expectedReason)
}
}
func TestFallback_NetworkErrorFallsBack(t *testing.T) {
assertFallbackErrorFallsBack(
t,
"minimax",
"minimax-m2.7",
errors.New(
`failed to send request: Post "https://opencode.ai/zen/go/v1/chat/completions": tls: bad record MAC`,
),
"fallback ok",
FailoverNetwork,
)
}
func TestFallback_TimeoutErrorFallsBack(t *testing.T) {
ct := NewCooldownTracker()
fc := NewFallbackChain(ct, nil)
candidates := []FallbackCandidate{
makeCandidate("openai", "gpt-4"),
makeCandidate("anthropic", "claude"),
}
attempt := 0
run := func(ctx context.Context, provider, model string) (*LLMResponse, error) {
attempt++
if attempt == 1 {
return nil, errors.New("failed to send request: Post \"https://example.com\": i/o timeout")
}
return &LLMResponse{Content: "timeout fallback ok", FinishReason: "stop"}, nil
}
result, err := fc.Execute(context.Background(), candidates, run)
if err != nil {
t.Fatalf("expected fallback success, got error: %v", err)
}
if attempt != 2 {
t.Fatalf("attempt = %d, want 2", attempt)
}
if result.Provider != "anthropic" || result.Model != "claude" {
t.Fatalf("result = %s/%s, want anthropic/claude", result.Provider, result.Model)
}
if len(result.Attempts) != 1 {
t.Fatalf("attempts = %d, want 1 failed attempt recorded", len(result.Attempts))
}
if result.Attempts[0].Reason != FailoverTimeout {
t.Fatalf("attempt reason = %q, want timeout", result.Attempts[0].Reason)
}
assertFallbackErrorFallsBack(
t,
"openai",
"gpt-4",
errors.New("failed to send request: Post \"https://example.com\": i/o timeout"),
"timeout fallback ok",
FailoverTimeout,
)
}
func TestFallback_SuccessResetsCooldown(t *testing.T) {