fix: improve openai compat HTML response handling

This commit is contained in:
amagi
2026-03-07 15:50:08 +08:00
parent c1a3876f7d
commit 6eaa49f7ab
2 changed files with 139 additions and 13 deletions
+52 -9
View File
@@ -1,6 +1,7 @@
package openai_compat
import (
"bufio"
"bytes"
"context"
"encoding/json"
@@ -185,28 +186,70 @@ func (p *Provider) Chat(
contentType := resp.Header.Get("Content-Type")
// check if there is an HTTP error (caused by proxy or gateway) or if the response is HTML
if resp.StatusCode != http.StatusOK || strings.Contains(strings.ToLower(contentType), "text/html") {
body, _ := io.ReadAll(io.LimitReader(resp.Body, 256))
return nil, wrapHTTPResponseError(resp.StatusCode, body, contentType, p.apiBase)
// Non-200: read a prefix to tell HTML error page apart from JSON error body.
if resp.StatusCode != http.StatusOK {
body, err := io.ReadAll(io.LimitReader(resp.Body, 256))
if err != nil {
return nil, fmt.Errorf("failed to read response: %w", err)
}
if looksLikeHTML(body, contentType) {
return nil, wrapHTMLResponseError(resp.StatusCode, body, contentType, p.apiBase)
}
return nil, fmt.Errorf("API request failed:\n Status: %d\n Body: %s", resp.StatusCode, responsePreview(body, 128))
}
// directly pass the stream (resp.Body) to the JSON parser without loading everything into memory
out, err := parseResponse(resp.Body)
// Peek without consuming so the full stream reaches the JSON decoder.
reader := bufio.NewReader(resp.Body)
prefix, err := reader.Peek(256) // io.EOF/ErrBufferFull are normal; only real errors abort
if err != nil && err != io.EOF && err != bufio.ErrBufferFull {
return nil, fmt.Errorf("failed to inspect response: %w", err)
}
if looksLikeHTML(prefix, contentType) {
return nil, wrapHTMLResponseError(resp.StatusCode, prefix, contentType, p.apiBase)
}
out, err := parseResponse(reader)
if err != nil {
// Note: by this point both the Content-Type and the peeked body prefix have been checked for HTML,
// so a failure here is genuinely a JSON parsing problem; the body is streamed and not kept in memory for further inspection.
return nil, fmt.Errorf("failed to parse JSON response: %w", err)
}
return out, nil
}
func wrapHTTPResponseError(statusCode int, body []byte, contentType, apiBase string) error {
// wrapHTMLResponseError builds the error reported when an endpoint answers
// with an HTML page instead of JSON. The message names apiBase and
// contentType, suggests checking api_base or the proxy configuration, and
// appends the HTTP status code plus a preview of body obtained from
// responsePreview (called with a limit of 128).
func wrapHTMLResponseError(statusCode int, body []byte, contentType, apiBase string) error {
	return fmt.Errorf(
		"API request failed: %s returned HTML instead of JSON (content-type: %s); check api_base or proxy configuration.\n Status: %d\n Body: %s",
		apiBase,
		contentType,
		statusCode,
		responsePreview(body, 128),
	)
}
// looksLikeHTML reports whether a response should be treated as an HTML page.
//
// It returns true when contentType (trimmed and lowercased) mentions
// "text/html" or "application/xhtml+xml". Otherwise it inspects the start of
// body: leading whitespace is skipped via leadingTrimmedPrefix, the result is
// lowercased, and the function returns true if it begins with one of the
// markers "<!doctype html", "<html", "<head" or "<body".
func looksLikeHTML(body []byte, contentType string) bool {
	ct := strings.ToLower(strings.TrimSpace(contentType))
	for _, mime := range []string{"text/html", "application/xhtml+xml"} {
		if strings.Contains(ct, mime) {
			return true
		}
	}

	// Content type was inconclusive; fall back to sniffing the body itself.
	head := bytes.ToLower(leadingTrimmedPrefix(body, 128))
	for _, marker := range []string{"<!doctype html", "<html", "<head", "<body"} {
		if bytes.HasPrefix(head, []byte(marker)) {
			return true
		}
	}
	return false
}
// leadingTrimmedPrefix skips leading ASCII whitespace (space, tab, newline,
// carriage return, form feed, vertical tab) in body and returns at most
// maxLen of the bytes that follow. The result aliases body; nothing is copied.
//
// It returns nil when body is empty or contains only whitespace. A negative
// maxLen is treated as zero. The length limit is applied by comparing lengths
// rather than computing start+maxLen, so a very large maxLen (for example the
// maximum int used as "no limit") cannot overflow into a negative slice bound.
func leadingTrimmedPrefix(body []byte, maxLen int) []byte {
	if maxLen < 0 {
		maxLen = 0
	}
	for start, b := range body {
		switch b {
		case ' ', '\t', '\n', '\r', '\f', '\v':
			// Still inside the leading whitespace run.
			continue
		}
		rest := body[start:]
		if len(rest) > maxLen {
			rest = rest[:maxLen]
		}
		return rest
	}
	return nil
}
func responsePreview(body []byte, maxLen int) string {
trimmed := bytes.TrimSpace(body)
if len(trimmed) == 0 {
+87 -4
View File
@@ -3,6 +3,7 @@ package openai_compat
import (
"bytes"
"encoding/json"
"io"
"net/http"
"net/http/httptest"
"net/url"
@@ -213,6 +214,27 @@ func TestProviderChat_HTTPError(t *testing.T) {
}
}
// TestProviderChat_JSONHTTPErrorDoesNotReportHTML serves a 400 response with
// a JSON body and an application/json content type, then checks that the
// error returned by Chat includes the status code and does not contain the
// "returned HTML instead of JSON" wording.
func TestProviderChat_JSONHTTPErrorDoesNotReportHTML(t *testing.T) {
	handler := func(w http.ResponseWriter, _ *http.Request) {
		w.Header().Set("Content-Type", "application/json")
		w.WriteHeader(http.StatusBadRequest)
		_, _ = w.Write([]byte(`{"error":"bad request"}`))
	}
	srv := httptest.NewServer(http.HandlerFunc(handler))
	defer srv.Close()

	provider := NewProvider("key", srv.URL, "")
	_, err := provider.Chat(t.Context(), []Message{{Role: "user", Content: "hi"}}, nil, "gpt-4o", nil)
	if err == nil {
		t.Fatal("expected error, got nil")
	}

	msg := err.Error()
	if !strings.Contains(msg, "Status: 400") {
		t.Fatalf("expected status code in error, got %v", err)
	}
	if strings.Contains(msg, "returned HTML instead of JSON") {
		t.Fatalf("expected non-HTML http error, got %v", err)
	}
}
func TestProviderChat_HTMLSuccessResponseReturnsHelpfulError(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/html; charset=utf-8")
@@ -226,7 +248,7 @@ func TestProviderChat_HTMLSuccessResponseReturnsHelpfulError(t *testing.T) {
if err == nil {
t.Fatal("expected error, got nil")
}
if !strings.Contains(err.Error(), "received HTML") {
if !strings.Contains(err.Error(), "returned HTML instead of JSON") {
t.Fatalf("expected helpful HTML error, got %v", err)
}
if !strings.Contains(err.Error(), "check api_base or proxy configuration") {
@@ -250,7 +272,7 @@ func TestProviderChat_HTMLErrorResponseReturnsHelpfulError(t *testing.T) {
if !strings.Contains(err.Error(), "Status: 502") {
t.Fatalf("expected status code in error, got %v", err)
}
if !strings.Contains(err.Error(), "received HTML") {
if !strings.Contains(err.Error(), "returned HTML instead of JSON") {
t.Fatalf("expected helpful HTML error, got %v", err)
}
if !strings.Contains(err.Error(), "check api_base or proxy configuration") {
@@ -271,7 +293,7 @@ func TestProviderChat_MislabeledHTMLSuccessResponseReturnsHelpfulError(t *testin
if err == nil {
t.Fatal("expected error, got nil")
}
if !strings.Contains(err.Error(), "received HTML") {
if !strings.Contains(err.Error(), "returned HTML instead of JSON") {
t.Fatalf("expected helpful HTML error, got %v", err)
}
if !strings.Contains(err.Error(), "check api_base or proxy configuration") {
@@ -279,6 +301,33 @@ func TestProviderChat_MislabeledHTMLSuccessResponseReturnsHelpfulError(t *testin
}
}
func TestProviderChat_SuccessResponseUsesStreamingDecoder(t *testing.T) {
content := strings.Repeat("a", 1024)
body := `{"choices":[{"message":{"content":"` + content + `"},"finish_reason":"stop"}]}`
p := NewProvider("key", "https://example.com/v1", "")
p.httpClient = &http.Client{
Transport: roundTripperFunc(func(r *http.Request) (*http.Response, error) {
return &http.Response{
StatusCode: http.StatusOK,
Header: http.Header{"Content-Type": []string{"application/json"}},
Body: &errAfterDataReadCloser{
data: []byte(body),
chunkSize: 64,
},
}, nil
}),
}
out, err := p.Chat(t.Context(), []Message{{Role: "user", Content: "hi"}}, nil, "gpt-4o", nil)
if err != nil {
t.Fatalf("Chat() error = %v", err)
}
if out.Content != content {
t.Fatalf("Content = %q, want %q", out.Content, content)
}
}
func TestProviderChat_LargeHTMLResponsePreviewIsTruncated(t *testing.T) {
body := append([]byte("<!DOCTYPE html><html><body>"), bytes.Repeat([]byte("A"), 2048)...)
body = append(body, []byte("</body></html>")...)
@@ -295,7 +344,7 @@ func TestProviderChat_LargeHTMLResponsePreviewIsTruncated(t *testing.T) {
if err == nil {
t.Fatal("expected error, got nil")
}
if !strings.Contains(err.Error(), "Response preview: <!DOCTYPE html><html><body>") {
if !strings.Contains(err.Error(), "Body: <!DOCTYPE html><html><body>") {
t.Fatalf("expected html preview in error, got %v", err)
}
if !strings.Contains(err.Error(), "...") {
@@ -490,6 +539,40 @@ func TestProvider_RequestTimeoutOverride(t *testing.T) {
}
}
// roundTripperFunc adapts an ordinary function to a type with a RoundTrip
// method, so tests can supply an inline function as an http.Client Transport.
type roundTripperFunc func(*http.Request) (*http.Response, error)

// RoundTrip invokes the underlying function with req and returns its result
// unchanged.
func (fn roundTripperFunc) RoundTrip(req *http.Request) (*http.Response, error) {
	return fn(req)
}
// errAfterDataReadCloser is a test double for a response body. It hands out
// data in bounded chunks and, once every byte has been served, reports
// io.ErrUnexpectedEOF instead of io.EOF.
type errAfterDataReadCloser struct {
	data      []byte // payload to serve
	chunkSize int    // cap on bytes returned per Read; <= 0 means only len(p) limits it
	offset    int    // number of bytes of data already served
}

// Read copies the next chunk of data into p and advances the offset. The
// chunk is limited by len(p), by chunkSize when it is positive, and by the
// number of bytes still unserved. After all data has been served, every call
// returns 0 and io.ErrUnexpectedEOF.
func (r *errAfterDataReadCloser) Read(p []byte) (int, error) {
	if r.offset >= len(r.data) {
		return 0, io.ErrUnexpectedEOF
	}

	limit := len(p)
	if r.chunkSize > 0 && r.chunkSize < limit {
		limit = r.chunkSize
	}

	pending := r.data[r.offset:]
	if len(pending) > limit {
		pending = pending[:limit]
	}

	copied := copy(p, pending)
	r.offset += copied
	return copied, nil
}

// Close does nothing and always returns nil.
func (r *errAfterDataReadCloser) Close() error {
	return nil
}
func TestProvider_FunctionalOptionMaxTokensField(t *testing.T) {
p := NewProvider("key", "https://example.com/v1", "", WithMaxTokensField("max_completion_tokens"))
if p.maxTokensField != "max_completion_tokens" {