Files
picoclaw/pkg/providers/openai_compat/provider.go
T
Zhaoyikaiii 1f7cbd9164 fix: cache system prompt with mtime-based auto-invalidation (#607)
Avoid rebuilding the entire system prompt on every BuildMessages() call
by caching the static portion (identity, bootstrap, skills summary,
memory) and only recomputing it when workspace source files change.

Key changes:

- ContextBuilder caches the static prompt behind an RWMutex with
  double-checked locking. Source file changes are detected via cheap
  os.Stat mtime checks so no explicit invalidation is needed.

- Track file existence at cache time (existedAtCache map) so that
  newly created or deleted bootstrap/memory files also trigger a
  rebuild — the old modifiedSince() silently returned false on
  os.IsNotExist.

- Walk the skills directory recursively with filepath.WalkDir to
  catch content-only edits at any nesting depth; directory mtime
  alone misses in-place file modifications on most filesystems.

- ToolRegistry.sortedToolNames() sorts tool names before iteration,
  ensuring deterministic tool definition order across calls — a
  prerequisite for LLM-side prefix/KV cache reuse.

- Merge all context (static + dynamic + summary) into a single
  system message for provider compatibility: the Anthropic adapter
  extracts messages[0] as the top-level system parameter, and Codex
  reads only the first system message as instructions.

- Fix a data race in BuildMessages() where cachedSystemPrompt was
  read without holding the lock in a debug log statement.

- Add tests: single system message invariant, mtime auto-invalidation,
  new-file creation detection, skill file content change, explicit
  InvalidateCache, cache stability, concurrent access (20 goroutines
  x 50 iterations, passes go test -race), and a benchmark.
2026-02-25 15:27:45 +08:00

315 lines
8.3 KiB
Go

package openai_compat
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"log"
"net/http"
"net/url"
"strings"
"time"
"github.com/sipeed/picoclaw/pkg/providers/protocoltypes"
)
// Package-local aliases for the shared protocol types, so the rest of this
// package can name them without the protocoltypes qualifier.
type (
	ExtraContent           = protocoltypes.ExtraContent
	FunctionCall           = protocoltypes.FunctionCall
	GoogleExtra            = protocoltypes.GoogleExtra
	LLMResponse            = protocoltypes.LLMResponse
	Message                = protocoltypes.Message
	ToolCall               = protocoltypes.ToolCall
	ToolDefinition         = protocoltypes.ToolDefinition
	ToolFunctionDefinition = protocoltypes.ToolFunctionDefinition
	UsageInfo              = protocoltypes.UsageInfo
)
// Provider is a client for an OpenAI-compatible chat-completions endpoint.
type Provider struct {
	// apiKey is sent as a Bearer token when non-empty; apiBase is the
	// endpoint root to which "/chat/completions" is appended.
	apiKey, apiBase string

	// maxTokensField is the field name for max tokens
	// (e.g., "max_completion_tokens" for o1/glm models). When empty, Chat
	// picks the field from the model name.
	maxTokensField string

	// httpClient performs the outgoing HTTP requests.
	httpClient *http.Client
}
// NewProvider builds a Provider for the given API key, base URL and optional
// proxy URL. It delegates to NewProviderWithMaxTokensField with an empty
// max-tokens field name.
func NewProvider(apiKey, apiBase, proxy string) *Provider {
	const unsetMaxTokensField = ""
	return NewProviderWithMaxTokensField(apiKey, apiBase, proxy, unsetMaxTokensField)
}
// NewProviderWithMaxTokensField builds a Provider whose requests use
// maxTokensField as the name of the token-limit field.
//
// Trailing slashes are trimmed from apiBase. The HTTP client has a 120 second
// overall timeout. When proxy is non-empty it is parsed as a URL and used for
// all requests; if it cannot be parsed, the problem is logged and the client
// is left without an explicit proxy transport.
func NewProviderWithMaxTokensField(apiKey, apiBase, proxy, maxTokensField string) *Provider {
	client := &http.Client{
		Timeout: 120 * time.Second,
	}

	if proxy != "" {
		parsed, err := url.Parse(proxy)
		if err != nil {
			log.Printf("openai_compat: invalid proxy URL %q: %v", proxy, err)
		} else {
			// Start from the default transport's settings (dial and TLS
			// handshake timeouts, idle-connection limits) instead of a bare
			// Transport, which would have none of them; only Proxy is
			// overridden.
			transport := &http.Transport{}
			if base, ok := http.DefaultTransport.(*http.Transport); ok {
				transport = base.Clone()
			}
			transport.Proxy = http.ProxyURL(parsed)
			client.Transport = transport
		}
	}

	return &Provider{
		apiKey:         apiKey,
		apiBase:        strings.TrimRight(apiBase, "/"),
		maxTokensField: maxTokensField,
		httpClient:     client,
	}
}
// Chat sends one chat-completion request and returns the parsed reply.
//
// The request is POSTed as JSON to <apiBase>/chat/completions. The model name
// is first passed through normalizeModel and the messages through
// stripSystemParts. When tools is non-empty the definitions are attached with
// tool_choice "auto". Recognized entries in options:
//
//   - "max_tokens": numeric token limit, written under the provider's
//     configured field name, or under a name chosen from the model name when
//     none is configured.
//   - "temperature": numeric sampling temperature.
//   - "prompt_cache_key": non-empty string forwarded unchanged.
//
// An error is returned when no API base is configured, when the request cannot
// be encoded, built or sent, when the response body cannot be read, or when
// the response status is not 200 OK.
func (p *Provider) Chat(
	ctx context.Context,
	messages []Message,
	tools []ToolDefinition,
	model string,
	options map[string]any,
) (*LLMResponse, error) {
	if p.apiBase == "" {
		return nil, fmt.Errorf("API base not configured")
	}

	model = normalizeModel(model, p.apiBase)
	modelLower := strings.ToLower(model)

	payload := make(map[string]any)
	payload["model"] = model
	payload["messages"] = stripSystemParts(messages)

	if len(tools) != 0 {
		payload["tools"] = tools
		payload["tool_choice"] = "auto"
	}

	if limit, ok := asInt(options["max_tokens"]); ok {
		// An explicitly configured field name wins; otherwise fall back to
		// detection from the model name for backward compatibility.
		field := p.maxTokensField
		if field == "" {
			field = "max_tokens"
			for _, marker := range []string{"glm", "o1", "gpt-5"} {
				if strings.Contains(modelLower, marker) {
					field = "max_completion_tokens"
					break
				}
			}
		}
		payload[field] = limit
	}

	if temp, ok := asFloat(options["temperature"]); ok {
		// Kimi k2 models only support temperature=1.
		if strings.Contains(modelLower, "kimi") && strings.Contains(modelLower, "k2") {
			temp = 1.0
		}
		payload["temperature"] = temp
	}

	// Prompt caching: a stable key lets OpenAI bucket requests that share it
	// and reuse the prefix KV cache across calls. The key is typically the
	// agent ID — stable per agent, shared across requests.
	// See: https://platform.openai.com/docs/guides/prompt-caching
	if key, ok := options["prompt_cache_key"].(string); ok && key != "" {
		payload["prompt_cache_key"] = key
	}

	encoded, err := json.Marshal(payload)
	if err != nil {
		return nil, fmt.Errorf("failed to marshal request: %w", err)
	}

	endpoint := p.apiBase + "/chat/completions"
	req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewReader(encoded))
	if err != nil {
		return nil, fmt.Errorf("failed to create request: %w", err)
	}
	req.Header.Set("Content-Type", "application/json")
	if p.apiKey != "" {
		req.Header.Set("Authorization", "Bearer "+p.apiKey)
	}

	resp, err := p.httpClient.Do(req)
	if err != nil {
		return nil, fmt.Errorf("failed to send request: %w", err)
	}
	defer resp.Body.Close()

	raw, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, fmt.Errorf("failed to read response: %w", err)
	}
	if resp.StatusCode != http.StatusOK {
		return nil, fmt.Errorf("API request failed:\n Status: %d\n Body: %s", resp.StatusCode, string(raw))
	}

	return parseResponse(raw)
}
// parseResponse decodes a chat-completions response body into an LLMResponse.
//
// Only the first entry of "choices" is used. A body with no choices produces
// an empty response with FinishReason "stop" instead of an error. Each tool
// call's "arguments" string is decoded as a JSON object; if decoding fails the
// failure is logged and the original string is kept under the "raw" key.
// Arguments is always a non-nil map, including when the arguments string is
// empty or the JSON literal null.
//
// An error is returned only when body itself is not valid JSON for the
// expected response shape.
func parseResponse(body []byte) (*LLMResponse, error) {
	var apiResponse struct {
		Choices []struct {
			Message struct {
				Content          string `json:"content"`
				ReasoningContent string `json:"reasoning_content"`
				ToolCalls        []struct {
					ID       string `json:"id"`
					Type     string `json:"type"`
					Function *struct {
						Name      string `json:"name"`
						Arguments string `json:"arguments"`
					} `json:"function"`
					ExtraContent *struct {
						Google *struct {
							ThoughtSignature string `json:"thought_signature"`
						} `json:"google"`
					} `json:"extra_content"`
				} `json:"tool_calls"`
			} `json:"message"`
			FinishReason string `json:"finish_reason"`
		} `json:"choices"`
		Usage *UsageInfo `json:"usage"`
	}

	if err := json.Unmarshal(body, &apiResponse); err != nil {
		return nil, fmt.Errorf("failed to unmarshal response: %w", err)
	}

	if len(apiResponse.Choices) == 0 {
		return &LLMResponse{
			Content:      "",
			FinishReason: "stop",
		}, nil
	}

	choice := apiResponse.Choices[0]
	toolCalls := make([]ToolCall, 0, len(choice.Message.ToolCalls))
	for _, tc := range choice.Message.ToolCalls {
		arguments := make(map[string]any)
		name := ""

		// Extract thought_signature from Gemini/Google-specific extra content.
		thoughtSignature := ""
		if tc.ExtraContent != nil && tc.ExtraContent.Google != nil {
			thoughtSignature = tc.ExtraContent.Google.ThoughtSignature
		}

		if tc.Function != nil {
			name = tc.Function.Name
			if tc.Function.Arguments != "" {
				if err := json.Unmarshal([]byte(tc.Function.Arguments), &arguments); err != nil {
					log.Printf("openai_compat: failed to decode tool call arguments for %q: %v", name, err)
					arguments["raw"] = tc.Function.Arguments
				}
				// Decoding the JSON literal null succeeds but resets the map
				// to nil; restore an empty map so callers always receive a
				// usable, writable Arguments value.
				if arguments == nil {
					arguments = make(map[string]any)
				}
			}
		}

		// Build ToolCall with ExtraContent for Gemini 3 thought_signature persistence.
		toolCall := ToolCall{
			ID:               tc.ID,
			Name:             name,
			Arguments:        arguments,
			ThoughtSignature: thoughtSignature,
		}

		if thoughtSignature != "" {
			toolCall.ExtraContent = &ExtraContent{
				Google: &GoogleExtra{
					ThoughtSignature: thoughtSignature,
				},
			}
		}

		toolCalls = append(toolCalls, toolCall)
	}

	return &LLMResponse{
		Content:          choice.Message.Content,
		ReasoningContent: choice.Message.ReasoningContent,
		ToolCalls:        toolCalls,
		FinishReason:     choice.FinishReason,
		Usage:            apiResponse.Usage,
	}, nil
}
// openaiMessage is the wire-format message for OpenAI-compatible APIs.
// It mirrors protocoltypes.Message but omits SystemParts, which is an
// internal field that would be unknown to third-party endpoints.
type openaiMessage struct {
// Role is serialized as "role" and is always emitted.
Role string `json:"role"`
// Content is serialized as "content" and is always emitted, even when empty.
Content string `json:"content"`
// ToolCalls is serialized as "tool_calls" and left out of the JSON when empty.
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
// ToolCallID is serialized as "tool_call_id" and left out of the JSON when empty.
ToolCallID string `json:"tool_call_id,omitempty"`
}
// stripSystemParts builds the wire-format copy of messages. Each Message is
// reduced to an openaiMessage carrying only Role, Content, ToolCalls and
// ToolCallID, so the SystemParts field never reaches the JSON payload sent to
// OpenAI-compatible APIs (some strict endpoints reject unknown fields).
// The result is always non-nil and has the same length and order as messages.
func stripSystemParts(messages []Message) []openaiMessage {
	converted := make([]openaiMessage, 0, len(messages))
	for idx := range messages {
		src := messages[idx]
		converted = append(converted, openaiMessage{
			Role:       src.Role,
			Content:    src.Content,
			ToolCalls:  src.ToolCalls,
			ToolCallID: src.ToolCallID,
		})
	}
	return converted
}
// normalizeModel removes a recognized vendor prefix ("vendor/") from model.
//
// The model is returned unchanged when it contains no "/", when apiBase
// contains "openrouter.ai" (compared case-insensitively), or when the text
// before the first "/" is not one of the recognized vendors. Vendor matching
// ignores case; everything after the first "/" is returned as-is.
func normalizeModel(model, apiBase string) string {
	vendor, rest, hasPrefix := strings.Cut(model, "/")
	if !hasPrefix {
		return model
	}
	if strings.Contains(strings.ToLower(apiBase), "openrouter.ai") {
		return model
	}

	vendor = strings.ToLower(vendor)
	strippable := [...]string{
		"moonshot", "nvidia", "groq", "ollama", "deepseek",
		"google", "openrouter", "zhipu", "mistral",
	}
	for _, known := range strippable {
		if vendor == known {
			return rest
		}
	}
	return model
}
// asInt converts v to an int when its dynamic type is int, int64, float64 or
// float32. Floating-point values are converted with Go's int() conversion,
// which discards the fractional part. For any other type, including nil, it
// returns (0, false).
func asInt(v any) (int, bool) {
	if n, ok := v.(int); ok {
		return n, true
	}
	if n, ok := v.(int64); ok {
		return int(n), true
	}
	if f, ok := v.(float64); ok {
		return int(f), true
	}
	if f, ok := v.(float32); ok {
		return int(f), true
	}
	return 0, false
}
// asFloat converts v to a float64 when its dynamic type is float64, float32,
// int or int64. For any other type, including nil, it returns (0, false).
func asFloat(v any) (float64, bool) {
	var out float64
	switch num := v.(type) {
	case int:
		out = float64(num)
	case int64:
		out = float64(num)
	case float32:
		out = float64(num)
	case float64:
		out = num
	default:
		return 0, false
	}
	return out, true
}