Files
picoclaw/pkg/providers/openai_compat/provider.go
T
winterfx d224397f40 fix: preserve reasoning_content for OpenAI-compatible reasoning models
Models like Moonshot kimi-k2.5 and DeepSeek-R1 return a
reasoning_content field in assistant messages. When thinking is enabled,
the API requires this field to be echoed back in subsequent requests.
PicoClaw was silently dropping it, causing 400 errors on tool-call
round-trips.

- Add ReasoningContent to Message and LLMResponse types
- Parse reasoning_content in openai_compat parseResponse()
- Carry reasoning_content through assistant tool-call messages
- Add unit test for reasoning_content parsing

Fixes #588
2026-02-21 23:29:40 +08:00

281 lines
7.0 KiB
Go

package openai_compat
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"log"
"net/http"
"net/url"
"strings"
"time"
"github.com/sipeed/picoclaw/pkg/providers/protocoltypes"
)
// Package-local aliases for the shared protocoltypes definitions. They are
// true aliases (declared with "="), so values are interchangeable with the
// protocoltypes originals and code in this package can name them unqualified.
type (
ToolCall = protocoltypes.ToolCall
FunctionCall = protocoltypes.FunctionCall
LLMResponse = protocoltypes.LLMResponse
UsageInfo = protocoltypes.UsageInfo
Message = protocoltypes.Message
ToolDefinition = protocoltypes.ToolDefinition
ToolFunctionDefinition = protocoltypes.ToolFunctionDefinition
ExtraContent = protocoltypes.ExtraContent
GoogleExtra = protocoltypes.GoogleExtra
)
// Provider is a client for an OpenAI-compatible chat-completions API.
// Construct it with NewProvider or NewProviderWithMaxTokensField.
type Provider struct {
// apiKey is sent as "Authorization: Bearer <apiKey>" when it is non-empty.
apiKey string
// apiBase is the endpoint base URL; Chat appends "/chat/completions" to it
// and refuses to run when it is empty.
apiBase string
// When maxTokensField is empty, Chat picks the field name from the model name.
maxTokensField string // Field name for max tokens (e.g., "max_completion_tokens" for o1/glm models)
// httpClient performs every request issued by this provider.
httpClient *http.Client
}
// NewProvider returns a Provider for the given API key, base URL and optional
// proxy URL. It leaves the max-tokens field name unset, so Chat chooses the
// request field from the model name.
func NewProvider(apiKey, apiBase, proxy string) *Provider {
	// An empty field name selects model-based detection in Chat.
	const autoDetectMaxTokensField = ""
	return NewProviderWithMaxTokensField(apiKey, apiBase, proxy, autoDetectMaxTokensField)
}
// NewProviderWithMaxTokensField returns a Provider that talks to apiBase
// (trailing slashes are trimmed) using apiKey for authentication.
//
// proxy, when non-empty, is parsed as a URL and every request is routed through
// it. If it cannot be parsed the problem is logged and the provider falls back
// to the HTTP client's default transport.
//
// maxTokensField names the request field that carries the token limit
// (e.g. "max_completion_tokens"); pass "" to let Chat detect it from the model
// name.
func NewProviderWithMaxTokensField(apiKey, apiBase, proxy, maxTokensField string) *Provider {
	client := &http.Client{
		Timeout: 120 * time.Second,
	}

	if proxy != "" {
		parsed, err := url.Parse(proxy)
		if err != nil {
			log.Printf("openai_compat: invalid proxy URL %q: %v", proxy, err)
		} else {
			// Start from the default transport so the proxied client keeps its
			// dial/TLS-handshake timeouts and idle-connection limits; a bare
			// http.Transport has none of them. The checked assertion guards
			// against a program that replaced http.DefaultTransport with
			// another RoundTripper implementation.
			transport := &http.Transport{}
			if base, ok := http.DefaultTransport.(*http.Transport); ok {
				transport = base.Clone()
			}
			transport.Proxy = http.ProxyURL(parsed)
			client.Transport = transport
		}
	}

	return &Provider{
		apiKey:         apiKey,
		apiBase:        strings.TrimRight(apiBase, "/"),
		maxTokensField: maxTokensField,
		httpClient:     client,
	}
}
// Chat posts messages (plus tools, when any are given) to the provider's
// "/chat/completions" endpoint and returns the parsed reply.
//
// Recognised options:
//   - "max_tokens": numeric token limit. It is sent under the provider's
//     configured max-tokens field, or, when none is configured, under
//     "max_completion_tokens" for model names containing "glm", "o1" or
//     "gpt-5" and under "max_tokens" otherwise.
//   - "temperature": numeric sampling temperature. Model names containing both
//     "kimi" and "k2" always get 1.0 because those models accept nothing else.
//
// An error is returned when no API base is configured, when the request cannot
// be built or sent, when the body cannot be read, when the status is not
// 200 OK, or when the body cannot be parsed.
func (p *Provider) Chat(
	ctx context.Context,
	messages []Message,
	tools []ToolDefinition,
	model string,
	options map[string]any,
) (*LLMResponse, error) {
	if p.apiBase == "" {
		return nil, fmt.Errorf("API base not configured")
	}

	model = normalizeModel(model, p.apiBase)
	modelLower := strings.ToLower(model)

	payload := map[string]any{
		"model":    model,
		"messages": messages,
	}
	if len(tools) > 0 {
		payload["tools"] = tools
		payload["tool_choice"] = "auto"
	}

	if limit, ok := asInt(options["max_tokens"]); ok {
		// An explicitly configured field wins; otherwise fall back to
		// model-name detection for backward compatibility.
		key := p.maxTokensField
		if key == "" {
			switch {
			case strings.Contains(modelLower, "glm"),
				strings.Contains(modelLower, "o1"),
				strings.Contains(modelLower, "gpt-5"):
				key = "max_completion_tokens"
			default:
				key = "max_tokens"
			}
		}
		payload[key] = limit
	}

	if temp, ok := asFloat(options["temperature"]); ok {
		// Kimi k2 models only support temperature=1.
		if strings.Contains(modelLower, "kimi") && strings.Contains(modelLower, "k2") {
			temp = 1.0
		}
		payload["temperature"] = temp
	}

	encoded, err := json.Marshal(payload)
	if err != nil {
		return nil, fmt.Errorf("failed to marshal request: %w", err)
	}

	endpoint := p.apiBase + "/chat/completions"
	req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewReader(encoded))
	if err != nil {
		return nil, fmt.Errorf("failed to create request: %w", err)
	}
	req.Header.Set("Content-Type", "application/json")
	if p.apiKey != "" {
		req.Header.Set("Authorization", "Bearer "+p.apiKey)
	}

	resp, err := p.httpClient.Do(req)
	if err != nil {
		return nil, fmt.Errorf("failed to send request: %w", err)
	}
	defer resp.Body.Close()

	// The body is read before the status check so it can be quoted in the
	// error message for non-200 replies.
	raw, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, fmt.Errorf("failed to read response: %w", err)
	}
	if resp.StatusCode == http.StatusOK {
		return parseResponse(raw)
	}
	return nil, fmt.Errorf("API request failed:\n Status: %d\n Body: %s", resp.StatusCode, string(raw))
}
// parseResponse decodes an OpenAI-style chat-completions response body into an
// LLMResponse.
//
// Only the first choice is used. A body with no choices is not an error: it
// yields an empty reply with FinishReason "stop". The message's
// reasoning_content is copied through unchanged so callers can send it back on
// later requests.
//
// Each tool call's "arguments" string is decoded as a JSON object. If decoding
// fails the failure is logged and the original string is kept under the "raw"
// key. Arguments is never nil: an empty string or the JSON literal null both
// produce an empty map.
//
// It returns an error only when body itself is not valid JSON for the expected
// shape.
func parseResponse(body []byte) (*LLMResponse, error) {
	var apiResponse struct {
		Choices []struct {
			Message struct {
				Content          string `json:"content"`
				ReasoningContent string `json:"reasoning_content"`
				ToolCalls        []struct {
					ID       string `json:"id"`
					Type     string `json:"type"`
					Function *struct {
						Name      string `json:"name"`
						Arguments string `json:"arguments"`
					} `json:"function"`
					ExtraContent *struct {
						Google *struct {
							ThoughtSignature string `json:"thought_signature"`
						} `json:"google"`
					} `json:"extra_content"`
				} `json:"tool_calls"`
			} `json:"message"`
			FinishReason string `json:"finish_reason"`
		} `json:"choices"`
		Usage *UsageInfo `json:"usage"`
	}

	if err := json.Unmarshal(body, &apiResponse); err != nil {
		return nil, fmt.Errorf("failed to unmarshal response: %w", err)
	}

	if len(apiResponse.Choices) == 0 {
		return &LLMResponse{
			Content:      "",
			FinishReason: "stop",
		}, nil
	}

	choice := apiResponse.Choices[0]
	toolCalls := make([]ToolCall, 0, len(choice.Message.ToolCalls))
	for _, tc := range choice.Message.ToolCalls {
		arguments := make(map[string]any)
		name := ""

		// Extract thought_signature from Gemini/Google-specific extra content.
		thoughtSignature := ""
		if tc.ExtraContent != nil && tc.ExtraContent.Google != nil {
			thoughtSignature = tc.ExtraContent.Google.ThoughtSignature
		}

		if tc.Function != nil {
			name = tc.Function.Name
			if tc.Function.Arguments != "" {
				// Decode into a scratch map: json.Unmarshal sets its target to
				// nil for the JSON literal null, which would otherwise replace
				// the empty map with a nil one.
				var decoded map[string]any
				if err := json.Unmarshal([]byte(tc.Function.Arguments), &decoded); err != nil {
					log.Printf("openai_compat: failed to decode tool call arguments for %q: %v", name, err)
					arguments["raw"] = tc.Function.Arguments
				} else if decoded != nil {
					arguments = decoded
				}
			}
		}

		// Build ToolCall with ExtraContent for Gemini 3 thought_signature persistence.
		toolCall := ToolCall{
			ID:               tc.ID,
			Name:             name,
			Arguments:        arguments,
			ThoughtSignature: thoughtSignature,
		}
		if thoughtSignature != "" {
			toolCall.ExtraContent = &ExtraContent{
				Google: &GoogleExtra{
					ThoughtSignature: thoughtSignature,
				},
			}
		}

		toolCalls = append(toolCalls, toolCall)
	}

	return &LLMResponse{
		Content:          choice.Message.Content,
		ReasoningContent: choice.Message.ReasoningContent,
		ToolCalls:        toolCalls,
		FinishReason:     choice.FinishReason,
		Usage:            apiResponse.Usage,
	}, nil
}
// normalizeModel strips a leading "<vendor>/" routing prefix from model when
// the vendor (compared case-insensitively) is one of the known prefixes listed
// below.
//
// The model is returned unchanged when it contains no "/", when apiBase points
// at openrouter.ai (the model is passed through unmodified there), or when the
// prefix is not in the list. Only the first "/" is considered, so
// "google/a/b" becomes "a/b".
func normalizeModel(model, apiBase string) string {
	vendor, rest, hasPrefix := strings.Cut(model, "/")
	if !hasPrefix || strings.Contains(strings.ToLower(apiBase), "openrouter.ai") {
		return model
	}

	switch strings.ToLower(vendor) {
	case "moonshot", "nvidia", "groq", "ollama", "deepseek", "google", "openrouter", "zhipu":
		return rest
	}
	return model
}
// asInt converts a loosely typed numeric option to int.
//
// int, int64, float32 and float64 values are accepted; floating-point values
// are converted with Go's int() conversion, which drops the fractional part.
// Any other dynamic type (including nil) yields (0, false).
func asInt(v any) (int, bool) {
	switch n := v.(type) {
	case int:
		return n, true
	case int64:
		return int(n), true
	case float32:
		return int(n), true
	case float64:
		return int(n), true
	}
	return 0, false
}
// asFloat converts a loosely typed numeric option to float64.
//
// float64, float32, int and int64 values are accepted. Any other dynamic type
// (including nil) yields (0, false).
func asFloat(v any) (float64, bool) {
	switch f := v.(type) {
	case float64:
		return f, true
	case float32:
		return float64(f), true
	case int64:
		return float64(f), true
	case int:
		return float64(f), true
	}
	return 0, false
}