mirror of
https://github.com/sipeed/picoclaw.git
synced 2026-06-12 18:08:54 +00:00
1f7cbd9164
Avoid rebuilding the entire system prompt on every BuildMessages() call by caching the static portion (identity, bootstrap, skills summary, memory) and only recomputing it when workspace source files change. Key changes: - ContextBuilder caches the static prompt behind an RWMutex with double-checked locking. Source file changes are detected via cheap os.Stat mtime checks so no explicit invalidation is needed. - Track file existence at cache time (existedAtCache map) so that newly created or deleted bootstrap/memory files also trigger a rebuild — the old modifiedSince() silently returned false on os.IsNotExist. - Walk the skills directory recursively with filepath.WalkDir to catch content-only edits at any nesting depth; directory mtime alone misses in-place file modifications on most filesystems. - ToolRegistry.sortedToolNames() sorts tool names before iteration, ensuring deterministic tool definition order across calls — a prerequisite for LLM-side prefix/KV cache reuse. - Merge all context (static + dynamic + summary) into a single system message for provider compatibility: the Anthropic adapter extracts messages[0] as the top-level system parameter, and Codex reads only the first system message as instructions. - Fix a data race in BuildMessages() where cachedSystemPrompt was read without holding the lock in a debug log statement. - Add tests: single system message invariant, mtime auto-invalidation, new-file creation detection, skill file content change, explicit InvalidateCache, cache stability, concurrent access (20 goroutines x 50 iterations, passes go test -race), and a benchmark.
315 lines
8.3 KiB
Go
315 lines
8.3 KiB
Go
package openai_compat
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"encoding/json"
|
|
"fmt"
|
|
"io"
|
|
"log"
|
|
"net/http"
|
|
"net/url"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/sipeed/picoclaw/pkg/providers/protocoltypes"
|
|
)
|
|
|
|
type (
|
|
ToolCall = protocoltypes.ToolCall
|
|
FunctionCall = protocoltypes.FunctionCall
|
|
LLMResponse = protocoltypes.LLMResponse
|
|
UsageInfo = protocoltypes.UsageInfo
|
|
Message = protocoltypes.Message
|
|
ToolDefinition = protocoltypes.ToolDefinition
|
|
ToolFunctionDefinition = protocoltypes.ToolFunctionDefinition
|
|
ExtraContent = protocoltypes.ExtraContent
|
|
GoogleExtra = protocoltypes.GoogleExtra
|
|
)
|
|
|
|
// Provider is a client for an OpenAI-compatible chat completions endpoint.
// Construct it with NewProvider or NewProviderWithMaxTokensField.
type Provider struct {
	// apiKey is sent as a Bearer token in the Authorization header when it
	// is non-empty.
	apiKey string

	// apiBase is the endpoint root; "/chat/completions" is appended to it
	// when a request is made. An empty value makes Chat fail.
	apiBase string

	// maxTokensField is the request field name used for the token limit
	// (e.g., "max_completion_tokens" for o1/glm models). When empty, the
	// field name is chosen from the model name instead.
	maxTokensField string

	// httpClient performs the HTTP requests.
	httpClient *http.Client
}
|
|
|
|
func NewProvider(apiKey, apiBase, proxy string) *Provider {
|
|
return NewProviderWithMaxTokensField(apiKey, apiBase, proxy, "")
|
|
}
|
|
|
|
func NewProviderWithMaxTokensField(apiKey, apiBase, proxy, maxTokensField string) *Provider {
|
|
client := &http.Client{
|
|
Timeout: 120 * time.Second,
|
|
}
|
|
|
|
if proxy != "" {
|
|
parsed, err := url.Parse(proxy)
|
|
if err == nil {
|
|
client.Transport = &http.Transport{
|
|
Proxy: http.ProxyURL(parsed),
|
|
}
|
|
} else {
|
|
log.Printf("openai_compat: invalid proxy URL %q: %v", proxy, err)
|
|
}
|
|
}
|
|
|
|
return &Provider{
|
|
apiKey: apiKey,
|
|
apiBase: strings.TrimRight(apiBase, "/"),
|
|
maxTokensField: maxTokensField,
|
|
httpClient: client,
|
|
}
|
|
}
|
|
|
|
func (p *Provider) Chat(
|
|
ctx context.Context,
|
|
messages []Message,
|
|
tools []ToolDefinition,
|
|
model string,
|
|
options map[string]any,
|
|
) (*LLMResponse, error) {
|
|
if p.apiBase == "" {
|
|
return nil, fmt.Errorf("API base not configured")
|
|
}
|
|
|
|
model = normalizeModel(model, p.apiBase)
|
|
|
|
requestBody := map[string]any{
|
|
"model": model,
|
|
"messages": stripSystemParts(messages),
|
|
}
|
|
|
|
if len(tools) > 0 {
|
|
requestBody["tools"] = tools
|
|
requestBody["tool_choice"] = "auto"
|
|
}
|
|
|
|
if maxTokens, ok := asInt(options["max_tokens"]); ok {
|
|
// Use configured maxTokensField if specified, otherwise fallback to model-based detection
|
|
fieldName := p.maxTokensField
|
|
if fieldName == "" {
|
|
// Fallback: detect from model name for backward compatibility
|
|
lowerModel := strings.ToLower(model)
|
|
if strings.Contains(lowerModel, "glm") || strings.Contains(lowerModel, "o1") ||
|
|
strings.Contains(lowerModel, "gpt-5") {
|
|
fieldName = "max_completion_tokens"
|
|
} else {
|
|
fieldName = "max_tokens"
|
|
}
|
|
}
|
|
requestBody[fieldName] = maxTokens
|
|
}
|
|
|
|
if temperature, ok := asFloat(options["temperature"]); ok {
|
|
lowerModel := strings.ToLower(model)
|
|
// Kimi k2 models only support temperature=1.
|
|
if strings.Contains(lowerModel, "kimi") && strings.Contains(lowerModel, "k2") {
|
|
requestBody["temperature"] = 1.0
|
|
} else {
|
|
requestBody["temperature"] = temperature
|
|
}
|
|
}
|
|
|
|
// Prompt caching: pass a stable cache key so OpenAI can bucket requests
|
|
// with the same key and reuse prefix KV cache across calls.
|
|
// The key is typically the agent ID — stable per agent, shared across requests.
|
|
// See: https://platform.openai.com/docs/guides/prompt-caching
|
|
if cacheKey, ok := options["prompt_cache_key"].(string); ok && cacheKey != "" {
|
|
requestBody["prompt_cache_key"] = cacheKey
|
|
}
|
|
|
|
jsonData, err := json.Marshal(requestBody)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to marshal request: %w", err)
|
|
}
|
|
|
|
req, err := http.NewRequestWithContext(ctx, "POST", p.apiBase+"/chat/completions", bytes.NewReader(jsonData))
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to create request: %w", err)
|
|
}
|
|
|
|
req.Header.Set("Content-Type", "application/json")
|
|
if p.apiKey != "" {
|
|
req.Header.Set("Authorization", "Bearer "+p.apiKey)
|
|
}
|
|
|
|
resp, err := p.httpClient.Do(req)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to send request: %w", err)
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
body, err := io.ReadAll(resp.Body)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to read response: %w", err)
|
|
}
|
|
|
|
if resp.StatusCode != http.StatusOK {
|
|
return nil, fmt.Errorf("API request failed:\n Status: %d\n Body: %s", resp.StatusCode, string(body))
|
|
}
|
|
|
|
return parseResponse(body)
|
|
}
|
|
|
|
func parseResponse(body []byte) (*LLMResponse, error) {
|
|
var apiResponse struct {
|
|
Choices []struct {
|
|
Message struct {
|
|
Content string `json:"content"`
|
|
ReasoningContent string `json:"reasoning_content"`
|
|
ToolCalls []struct {
|
|
ID string `json:"id"`
|
|
Type string `json:"type"`
|
|
Function *struct {
|
|
Name string `json:"name"`
|
|
Arguments string `json:"arguments"`
|
|
} `json:"function"`
|
|
ExtraContent *struct {
|
|
Google *struct {
|
|
ThoughtSignature string `json:"thought_signature"`
|
|
} `json:"google"`
|
|
} `json:"extra_content"`
|
|
} `json:"tool_calls"`
|
|
} `json:"message"`
|
|
FinishReason string `json:"finish_reason"`
|
|
} `json:"choices"`
|
|
Usage *UsageInfo `json:"usage"`
|
|
}
|
|
|
|
if err := json.Unmarshal(body, &apiResponse); err != nil {
|
|
return nil, fmt.Errorf("failed to unmarshal response: %w", err)
|
|
}
|
|
|
|
if len(apiResponse.Choices) == 0 {
|
|
return &LLMResponse{
|
|
Content: "",
|
|
FinishReason: "stop",
|
|
}, nil
|
|
}
|
|
|
|
choice := apiResponse.Choices[0]
|
|
toolCalls := make([]ToolCall, 0, len(choice.Message.ToolCalls))
|
|
for _, tc := range choice.Message.ToolCalls {
|
|
arguments := make(map[string]any)
|
|
name := ""
|
|
|
|
// Extract thought_signature from Gemini/Google-specific extra content
|
|
thoughtSignature := ""
|
|
if tc.ExtraContent != nil && tc.ExtraContent.Google != nil {
|
|
thoughtSignature = tc.ExtraContent.Google.ThoughtSignature
|
|
}
|
|
|
|
if tc.Function != nil {
|
|
name = tc.Function.Name
|
|
if tc.Function.Arguments != "" {
|
|
if err := json.Unmarshal([]byte(tc.Function.Arguments), &arguments); err != nil {
|
|
log.Printf("openai_compat: failed to decode tool call arguments for %q: %v", name, err)
|
|
arguments["raw"] = tc.Function.Arguments
|
|
}
|
|
}
|
|
}
|
|
|
|
// Build ToolCall with ExtraContent for Gemini 3 thought_signature persistence
|
|
toolCall := ToolCall{
|
|
ID: tc.ID,
|
|
Name: name,
|
|
Arguments: arguments,
|
|
ThoughtSignature: thoughtSignature,
|
|
}
|
|
|
|
if thoughtSignature != "" {
|
|
toolCall.ExtraContent = &ExtraContent{
|
|
Google: &GoogleExtra{
|
|
ThoughtSignature: thoughtSignature,
|
|
},
|
|
}
|
|
}
|
|
|
|
toolCalls = append(toolCalls, toolCall)
|
|
}
|
|
|
|
return &LLMResponse{
|
|
Content: choice.Message.Content,
|
|
ReasoningContent: choice.Message.ReasoningContent,
|
|
ToolCalls: toolCalls,
|
|
FinishReason: choice.FinishReason,
|
|
Usage: apiResponse.Usage,
|
|
}, nil
|
|
}
|
|
|
|
// openaiMessage is the wire-format message for OpenAI-compatible APIs.
|
|
// It mirrors protocoltypes.Message but omits SystemParts, which is an
|
|
// internal field that would be unknown to third-party endpoints.
|
|
type openaiMessage struct {
|
|
Role string `json:"role"`
|
|
Content string `json:"content"`
|
|
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
|
|
ToolCallID string `json:"tool_call_id,omitempty"`
|
|
}
|
|
|
|
// stripSystemParts converts []Message to []openaiMessage, dropping the
|
|
// SystemParts field so it doesn't leak into the JSON payload sent to
|
|
// OpenAI-compatible APIs (some strict endpoints reject unknown fields).
|
|
func stripSystemParts(messages []Message) []openaiMessage {
|
|
out := make([]openaiMessage, len(messages))
|
|
for i, m := range messages {
|
|
out[i] = openaiMessage{
|
|
Role: m.Role,
|
|
Content: m.Content,
|
|
ToolCalls: m.ToolCalls,
|
|
ToolCallID: m.ToolCallID,
|
|
}
|
|
}
|
|
return out
|
|
}
|
|
|
|
// normalizeModel strips a known vendor prefix ("vendor/") from a model name.
//
// The name is returned unchanged when it contains no "/", when apiBase points
// at openrouter.ai (matched case-insensitively), or when the text before the
// first "/" is not one of the recognised vendors. Otherwise everything after
// the first "/" is returned. The vendor comparison is case-insensitive.
func normalizeModel(model, apiBase string) string {
	vendor, remainder, hasSlash := strings.Cut(model, "/")
	if !hasSlash {
		return model
	}

	// When talking to OpenRouter the full "vendor/model" form is kept.
	if strings.Contains(strings.ToLower(apiBase), "openrouter.ai") {
		return model
	}

	switch strings.ToLower(vendor) {
	case "moonshot", "nvidia", "groq", "ollama", "deepseek", "google", "openrouter", "zhipu", "mistral":
		return remainder
	}
	return model
}
|
|
|
|
// asInt converts a loosely typed numeric option value to an int.
//
// It accepts int, int64, float64 and float32. Floating-point values are
// truncated toward zero. The second result is false when v has any other
// type, or when the value cannot be represented as an int: NaN, ±Inf, a float
// outside the int range, or an int64 that does not fit in int on this
// platform. Converting such values with a plain int(...) would produce an
// implementation-defined or wrapped result.
func asInt(v any) (int, bool) {
	const (
		maxInt = int(^uint(0) >> 1)
		minInt = -maxInt - 1
		// lowest is exactly representable as a float64; highest is
		// maxInt + 1, used as an exclusive bound because float64(maxInt)
		// itself rounds up on 64-bit platforms.
		lowest  = float64(minInt)
		highest = -float64(minInt)
	)

	var f float64
	switch val := v.(type) {
	case int:
		return val, true
	case int64:
		if int64(int(val)) != val {
			return 0, false
		}
		return int(val), true
	case float64:
		f = val
	case float32:
		f = float64(val)
	default:
		return 0, false
	}

	// Written as a positive range test so that NaN (for which every
	// comparison is false) is rejected along with ±Inf and huge values.
	if !(f >= lowest && f < highest) {
		return 0, false
	}
	return int(f), true
}
|
|
|
|
// asFloat converts a loosely typed numeric option value to a float64.
// It accepts float64, float32, int and int64; for any other type (including
// nil) it returns 0 and false.
func asFloat(v any) (float64, bool) {
	if f, ok := v.(float64); ok {
		return f, true
	}
	if f, ok := v.(float32); ok {
		return float64(f), true
	}
	if n, ok := v.(int); ok {
		return float64(n), true
	}
	if n, ok := v.(int64); ok {
		return float64(n), true
	}
	return 0, false
}
|