package openai_compat import ( "bufio" "bytes" "context" "encoding/json" "fmt" "io" "log" "net/http" "net/url" "strings" "time" "github.com/sipeed/picoclaw/pkg/providers/common" "github.com/sipeed/picoclaw/pkg/providers/protocoltypes" ) type ( ToolCall = protocoltypes.ToolCall FunctionCall = protocoltypes.FunctionCall LLMResponse = protocoltypes.LLMResponse UsageInfo = protocoltypes.UsageInfo Message = protocoltypes.Message ToolDefinition = protocoltypes.ToolDefinition ToolFunctionDefinition = protocoltypes.ToolFunctionDefinition ExtraContent = protocoltypes.ExtraContent GoogleExtra = protocoltypes.GoogleExtra ReasoningDetail = protocoltypes.ReasoningDetail ) type Provider struct { apiKey string apiBase string maxTokensField string // Field name for max tokens (e.g., "max_completion_tokens" for o1/glm models) httpClient *http.Client } type Option func(*Provider) const defaultRequestTimeout = common.DefaultRequestTimeout func WithMaxTokensField(maxTokensField string) Option { return func(p *Provider) { p.maxTokensField = maxTokensField } } func WithRequestTimeout(timeout time.Duration) Option { return func(p *Provider) { if timeout > 0 { p.httpClient.Timeout = timeout } } } func NewProvider(apiKey, apiBase, proxy string, opts ...Option) *Provider { p := &Provider{ apiKey: apiKey, apiBase: strings.TrimRight(apiBase, "/"), httpClient: common.NewHTTPClient(proxy), } for _, opt := range opts { if opt != nil { opt(p) } } return p } func NewProviderWithMaxTokensField(apiKey, apiBase, proxy, maxTokensField string) *Provider { return NewProvider(apiKey, apiBase, proxy, WithMaxTokensField(maxTokensField)) } func NewProviderWithMaxTokensFieldAndTimeout( apiKey, apiBase, proxy, maxTokensField string, requestTimeoutSeconds int, ) *Provider { return NewProvider( apiKey, apiBase, proxy, WithMaxTokensField(maxTokensField), WithRequestTimeout(time.Duration(requestTimeoutSeconds)*time.Second), ) } // buildRequestBody constructs the common request body for Chat and ChatStream. 
func (p *Provider) buildRequestBody( messages []Message, tools []ToolDefinition, model string, options map[string]any, ) map[string]any { model = normalizeModel(model, p.apiBase) requestBody := map[string]any{ "model": model, "messages": common.SerializeMessages(messages), } // When fallback uses a different provider (e.g. DeepSeek), that provider must not inject web_search_preview. nativeSearch, _ := options["native_search"].(bool) nativeSearch = nativeSearch && isNativeSearchHost(p.apiBase) if len(tools) > 0 || nativeSearch { requestBody["tools"] = buildToolsList(tools, nativeSearch) requestBody["tool_choice"] = "auto" } if maxTokens, ok := common.AsInt(options["max_tokens"]); ok { fieldName := p.maxTokensField if fieldName == "" { lowerModel := strings.ToLower(model) if strings.Contains(lowerModel, "glm") || strings.Contains(lowerModel, "o1") || strings.Contains(lowerModel, "gpt-5") { fieldName = "max_completion_tokens" } else { fieldName = "max_tokens" } } requestBody[fieldName] = maxTokens } if temperature, ok := common.AsFloat(options["temperature"]); ok { lowerModel := strings.ToLower(model) if strings.Contains(lowerModel, "kimi") && strings.Contains(lowerModel, "k2") { requestBody["temperature"] = 1.0 } else { requestBody["temperature"] = temperature } } // Prompt caching: pass a stable cache key so OpenAI can bucket requests // with the same key and reuse prefix KV cache across calls. // Prompt caching is only supported by OpenAI-native endpoints. // Non-OpenAI providers reject unknown fields with 422 errors. 
if cacheKey, ok := options["prompt_cache_key"].(string); ok && cacheKey != "" { if supportsPromptCacheKey(p.apiBase) { requestBody["prompt_cache_key"] = cacheKey } } return requestBody } func (p *Provider) Chat( ctx context.Context, messages []Message, tools []ToolDefinition, model string, options map[string]any, ) (*LLMResponse, error) { if p.apiBase == "" { return nil, fmt.Errorf("API base not configured") } requestBody := p.buildRequestBody(messages, tools, model, options) jsonData, err := json.Marshal(requestBody) if err != nil { return nil, fmt.Errorf("failed to marshal request: %w", err) } req, err := http.NewRequestWithContext(ctx, "POST", p.apiBase+"/chat/completions", bytes.NewReader(jsonData)) if err != nil { return nil, fmt.Errorf("failed to create request: %w", err) } req.Header.Set("Content-Type", "application/json") if p.apiKey != "" { req.Header.Set("Authorization", "Bearer "+p.apiKey) } resp, err := p.httpClient.Do(req) if err != nil { return nil, fmt.Errorf("failed to send request: %w", err) } defer resp.Body.Close() if resp.StatusCode != http.StatusOK { return nil, common.HandleErrorResponse(resp, p.apiBase) } return common.ReadAndParseResponse(resp, p.apiBase) } // ChatStream implements streaming via OpenAI-compatible SSE (stream: true). // onChunk receives the accumulated text so far on each text delta. 
func (p *Provider) ChatStream( ctx context.Context, messages []Message, tools []ToolDefinition, model string, options map[string]any, onChunk func(accumulated string), ) (*LLMResponse, error) { if p.apiBase == "" { return nil, fmt.Errorf("API base not configured") } requestBody := p.buildRequestBody(messages, tools, model, options) requestBody["stream"] = true jsonData, err := json.Marshal(requestBody) if err != nil { return nil, fmt.Errorf("failed to marshal request: %w", err) } req, err := http.NewRequestWithContext(ctx, "POST", p.apiBase+"/chat/completions", bytes.NewReader(jsonData)) if err != nil { return nil, fmt.Errorf("failed to create request: %w", err) } req.Header.Set("Content-Type", "application/json") req.Header.Set("Accept", "text/event-stream") if p.apiKey != "" { req.Header.Set("Authorization", "Bearer "+p.apiKey) } // Use a client without Timeout for streaming — the http.Client.Timeout covers // the entire request lifecycle including body reads, which would kill long streams. // Context cancellation still provides the safety net. streamClient := &http.Client{Transport: p.httpClient.Transport} resp, err := streamClient.Do(req) if err != nil { return nil, fmt.Errorf("failed to send request: %w", err) } defer resp.Body.Close() if resp.StatusCode != http.StatusOK { return nil, common.HandleErrorResponse(resp, p.apiBase) } return parseStreamResponse(ctx, resp.Body, onChunk) } // parseStreamResponse parses an OpenAI-compatible SSE stream. 
func parseStreamResponse( ctx context.Context, reader io.Reader, onChunk func(accumulated string), ) (*LLMResponse, error) { var textContent strings.Builder var finishReason string var usage *UsageInfo // Tool call assembly: OpenAI streams tool calls as incremental deltas type toolAccum struct { id string name string argsJSON strings.Builder } activeTools := map[int]*toolAccum{} scanner := bufio.NewScanner(reader) scanner.Buffer(make([]byte, 0, 1024*1024), 10*1024*1024) // 1MB initial, 10MB max for scanner.Scan() { // Check for context cancellation between chunks if err := ctx.Err(); err != nil { return nil, err } line := scanner.Text() if !strings.HasPrefix(line, "data: ") { continue } data := strings.TrimPrefix(line, "data: ") if data == "[DONE]" { break } var chunk struct { Choices []struct { Delta struct { Content string `json:"content"` ToolCalls []struct { Index int `json:"index"` ID string `json:"id"` Function *struct { Name string `json:"name"` Arguments string `json:"arguments"` } `json:"function"` } `json:"tool_calls"` } `json:"delta"` FinishReason *string `json:"finish_reason"` } `json:"choices"` Usage *UsageInfo `json:"usage"` } if err := json.Unmarshal([]byte(data), &chunk); err != nil { continue // skip malformed chunks } if chunk.Usage != nil { usage = chunk.Usage } if len(chunk.Choices) == 0 { continue } choice := chunk.Choices[0] // Accumulate text content if choice.Delta.Content != "" { textContent.WriteString(choice.Delta.Content) if onChunk != nil { onChunk(textContent.String()) } } // Accumulate tool call deltas for _, tc := range choice.Delta.ToolCalls { acc, ok := activeTools[tc.Index] if !ok { acc = &toolAccum{} activeTools[tc.Index] = acc } if tc.ID != "" { acc.id = tc.ID } if tc.Function != nil { if tc.Function.Name != "" { acc.name = tc.Function.Name } if tc.Function.Arguments != "" { acc.argsJSON.WriteString(tc.Function.Arguments) } } } if choice.FinishReason != nil { finishReason = *choice.FinishReason } } if err := scanner.Err(); err 
!= nil { return nil, fmt.Errorf("streaming read error: %w", err) } // Assemble tool calls from accumulated deltas var toolCalls []ToolCall for i := 0; i < len(activeTools); i++ { acc, ok := activeTools[i] if !ok { continue } args := make(map[string]any) raw := acc.argsJSON.String() if raw != "" { if err := json.Unmarshal([]byte(raw), &args); err != nil { log.Printf("openai_compat stream: failed to decode tool call arguments for %q: %v", acc.name, err) args["raw"] = raw } } toolCalls = append(toolCalls, ToolCall{ ID: acc.id, Name: acc.name, Arguments: args, }) } if finishReason == "" { finishReason = "stop" } return &LLMResponse{ Content: textContent.String(), ToolCalls: toolCalls, FinishReason: finishReason, Usage: usage, }, nil } func normalizeModel(model, apiBase string) string { before, after, ok := strings.Cut(model, "/") if !ok { return model } if strings.Contains(strings.ToLower(apiBase), "openrouter.ai") { return model } prefix := strings.ToLower(before) switch prefix { case "litellm", "moonshot", "nvidia", "groq", "ollama", "deepseek", "google", "openrouter", "zhipu", "mistral", "vivgrid", "minimax", "novita": return after default: return model } } func buildToolsList(tools []ToolDefinition, nativeSearch bool) []any { result := make([]any, 0, len(tools)+1) for _, t := range tools { if nativeSearch && strings.EqualFold(t.Function.Name, "web_search") { continue } result = append(result, t) } if nativeSearch { result = append(result, map[string]any{"type": "web_search_preview"}) } return result } func (p *Provider) SupportsNativeSearch() bool { return isNativeSearchHost(p.apiBase) } func isNativeSearchHost(apiBase string) bool { u, err := url.Parse(apiBase) if err != nil { return false } host := u.Hostname() return host == "api.openai.com" || strings.HasSuffix(host, ".openai.azure.com") } // supportsPromptCacheKey reports whether the given API base is known to // support the prompt_cache_key request field. 
// Currently only OpenAI's own API and Azure OpenAI support this. All other
// OpenAI-compatible providers (Mistral, Gemini, DeepSeek, Groq, etc.) reject
// unknown fields with 422 errors.
//
// The host comparison is case-insensitive because DNS names are; an apiBase
// that does not parse as a URL is treated as unsupported.
func supportsPromptCacheKey(apiBase string) bool {
	u, err := url.Parse(apiBase)
	if err != nil {
		return false
	}
	host := strings.ToLower(u.Hostname())
	return host == "api.openai.com" || strings.HasSuffix(host, ".openai.azure.com")
}