refactor(providers): reorganize provider packages and facades

This commit is contained in:
lc6464
2026-04-17 12:42:03 +08:00
parent 72f30c58e9
commit ee634dc8db
29 changed files with 573 additions and 102 deletions
+139
View File
@@ -0,0 +1,139 @@
package httpapi
import (
"encoding/json"
"strings"
)
// normalizeStoredToolCall flattens a stored ToolCall into the tool name, its
// argument map, and the thought signature carried on the nested Function.
//
// The top-level Name and Arguments win when they are set. The nested Function
// supplies the name when Name is empty, and the arguments (decoded from its
// JSON string) when Arguments is empty. The returned map is never nil.
func normalizeStoredToolCall(tc ToolCall) (string, map[string]any, string) {
	toolName := tc.Name
	signature := ""
	if fn := tc.Function; fn != nil {
		signature = fn.ThoughtSignature
		if toolName == "" {
			toolName = fn.Name
		}
	}

	toolArgs := tc.Arguments
	if toolArgs == nil {
		toolArgs = map[string]any{}
	}
	if len(toolArgs) > 0 || tc.Function == nil || tc.Function.Arguments == "" {
		return toolName, toolArgs, signature
	}

	// Best effort: malformed JSON or a JSON null leaves the empty map in place.
	var decoded map[string]any
	if err := json.Unmarshal([]byte(tc.Function.Arguments), &decoded); err == nil && decoded != nil {
		toolArgs = decoded
	}
	return toolName, toolArgs, signature
}
// resolveToolResponseName returns the function name to report for a tool
// response. An empty toolCallID yields "". Otherwise a non-empty name recorded
// in toolCallNames is preferred, and inferToolNameFromCallID is the fallback.
func resolveToolResponseName(toolCallID string, toolCallNames map[string]string) string {
	if toolCallID == "" {
		return ""
	}
	// A missing key reads as "", so a single emptiness check covers both
	// "not recorded" and "recorded with an empty name".
	if recorded := toolCallNames[toolCallID]; recorded != "" {
		return recorded
	}
	return inferToolNameFromCallID(toolCallID)
}
// inferToolNameFromCallID recovers a tool name from an ID shaped like
// "call_<name>_<suffix>": it strips the "call_" prefix and everything from the
// last underscore onward. IDs that do not fit that shape (no prefix, no
// underscore after the prefix, or an empty name part) are returned unchanged.
func inferToolNameFromCallID(toolCallID string) string {
	const prefix = "call_"
	if !strings.HasPrefix(toolCallID, prefix) {
		return toolCallID
	}
	remainder := toolCallID[len(prefix):]
	sep := strings.LastIndex(remainder, "_")
	if sep <= 0 {
		// No separator, or the name part before it would be empty.
		return toolCallID
	}
	return remainder[:sep]
}
// extractPartThoughtSignature picks the thought signature for a part: the
// camelCase value when it is non-empty, otherwise the snake_case value (which
// may itself be empty).
func extractPartThoughtSignature(thoughtSignature string, thoughtSignatureSnake string) string {
	if thoughtSignature == "" {
		return thoughtSignatureSnake
	}
	return thoughtSignature
}
// geminiUnsupportedKeywords lists JSON Schema keywords that
// sanitizeSchemaForGemini removes from every schema object it visits.
var geminiUnsupportedKeywords = map[string]bool{
	"patternProperties":    true,
	"additionalProperties": true,
	"$schema":              true,
	"$id":                  true,
	"$ref":                 true,
	"$defs":                true,
	"definitions":          true,
	"examples":             true,
	"minLength":            true,
	"maxLength":            true,
	"minimum":              true,
	"maximum":              true,
	"multipleOf":           true,
	"pattern":              true,
	"format":               true,
	"minItems":             true,
	"maxItems":             true,
	"uniqueItems":          true,
	"minProperties":        true,
	"maxProperties":        true,
}

// sanitizeSchemaForGemini returns a deep-sanitized copy of a JSON Schema object
// with every keyword in geminiUnsupportedKeywords removed, recursing into
// nested schema objects and into arrays of schema objects. A schema that has
// "properties" but no "type" gets "type": "object" added. A nil schema yields
// nil. The input map is never modified.
//
// The keys of a "properties" map are property names, not schema keywords, so
// they are preserved verbatim even when they collide with a removed keyword
// (for example a tool parameter called "pattern" or "format"); only the
// schemas they map to are sanitized.
func sanitizeSchemaForGemini(schema map[string]any) map[string]any {
	if schema == nil {
		return nil
	}
	result := make(map[string]any, len(schema))
	for k, v := range schema {
		if geminiUnsupportedKeywords[k] {
			continue
		}
		if k == "properties" {
			if props, ok := v.(map[string]any); ok {
				result[k] = sanitizeGeminiProperties(props)
				continue
			}
		}
		result[k] = sanitizeGeminiSchemaValue(v)
	}
	if _, hasProps := result["properties"]; hasProps {
		if _, hasType := result["type"]; !hasType {
			result["type"] = "object"
		}
	}
	return result
}

// sanitizeGeminiProperties sanitizes each property's schema while keeping every
// property name, regardless of whether the name matches an unsupported keyword.
func sanitizeGeminiProperties(props map[string]any) map[string]any {
	out := make(map[string]any, len(props))
	for name, propSchema := range props {
		if m, ok := propSchema.(map[string]any); ok {
			out[name] = sanitizeSchemaForGemini(m)
		} else {
			out[name] = propSchema
		}
	}
	return out
}

// sanitizeGeminiSchemaValue sanitizes a keyword's value: nested objects are
// treated as schemas, arrays have their object elements sanitized, and any
// other value is passed through unchanged.
func sanitizeGeminiSchemaValue(v any) any {
	switch val := v.(type) {
	case map[string]any:
		return sanitizeSchemaForGemini(val)
	case []any:
		sanitized := make([]any, len(val))
		for i, item := range val {
			if m, ok := item.(map[string]any); ok {
				sanitized[i] = sanitizeSchemaForGemini(m)
			} else {
				sanitized[i] = item
			}
		}
		return sanitized
	default:
		return v
	}
}
// extractProtocol splits a "protocol/model" identifier at the first slash after
// trimming surrounding whitespace. When there is no slash the protocol
// defaults to "openai" and the whole trimmed string is the model ID.
func extractProtocol(model string) (protocol, modelID string) {
	trimmed := strings.TrimSpace(model)
	if slash := strings.IndexByte(trimmed, '/'); slash >= 0 {
		return trimmed[:slash], trimmed[slash+1:]
	}
	return "openai", trimmed
}
+796
View File
@@ -0,0 +1,796 @@
package httpapi
import (
"bufio"
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"strings"
"time"
"github.com/sipeed/picoclaw/pkg/providers/common"
)
// Defaults applied when the caller leaves the API base or model unspecified.
const (
	// geminiDefaultAPIBase is used by NewGeminiProvider when apiBase is blank.
	geminiDefaultAPIBase = "https://generativelanguage.googleapis.com/v1beta"
	// geminiDefaultModel is returned by GetDefaultModel and substituted by
	// normalizeGeminiModel when the model name is empty.
	geminiDefaultModel = "gemini-2.0-flash"
)
// GeminiProvider sends chat requests to the Gemini generateContent and
// streamGenerateContent endpoints. Build it with NewGeminiProvider.
type GeminiProvider struct {
	apiKey        string            // sent as X-Goog-Api-Key when non-empty
	apiBase       string            // base URL without a trailing slash
	httpClient    *http.Client      // client used for non-streaming requests
	extraBody     map[string]any    // top-level request-body keys applied last, overriding built ones
	customHeaders map[string]string // extra request headers applied after the built-in ones
	userAgent     string            // sent as User-Agent when non-empty
}
// NewGeminiProvider builds a GeminiProvider.
//
// A blank apiBase falls back to geminiDefaultAPIBase; surrounding whitespace
// and trailing slashes are removed from the base. apiKey and userAgent are
// whitespace-trimmed. The HTTP client comes from common.NewHTTPClient(proxy)
// and receives a whole-request timeout only when requestTimeoutSeconds is
// positive. extraBody and customHeaders are shallow-copied.
func NewGeminiProvider(
	apiKey string,
	apiBase string,
	proxy string,
	userAgent string,
	requestTimeoutSeconds int,
	extraBody map[string]any,
	customHeaders map[string]string,
) *GeminiProvider {
	base := strings.TrimSpace(apiBase)
	if base == "" {
		base = geminiDefaultAPIBase
	}
	base = strings.TrimRight(base, "/")

	httpClient := common.NewHTTPClient(proxy)
	if requestTimeoutSeconds > 0 {
		httpClient.Timeout = time.Duration(requestTimeoutSeconds) * time.Second
	}

	provider := &GeminiProvider{
		apiKey:     strings.TrimSpace(apiKey),
		apiBase:    base,
		httpClient: httpClient,
		userAgent:  strings.TrimSpace(userAgent),
	}
	provider.extraBody = cloneAnyMap(extraBody)
	provider.customHeaders = cloneStringMap(customHeaders)
	return provider
}
// GetDefaultModel returns geminiDefaultModel.
func (p *GeminiProvider) GetDefaultModel() string {
	return geminiDefaultModel
}
// SupportsThinking always reports true for this provider.
func (p *GeminiProvider) SupportsThinking() bool {
	return true
}
// Chat performs one non-streaming generateContent call and converts the reply
// into an LLMResponse.
//
// The model name is normalized with normalizeGeminiModel and the body is built
// by buildRequestBody. Errors are returned when the API base is empty, when the
// request cannot be marshalled, created or sent, when the status is not
// 200 OK (via common.HandleErrorResponse), or when the response body cannot be
// decoded.
func (p *GeminiProvider) Chat(
	ctx context.Context,
	messages []Message,
	tools []ToolDefinition,
	model string,
	options map[string]any,
) (*LLMResponse, error) {
	if p.apiBase == "" {
		return nil, fmt.Errorf("API base not configured")
	}

	modelID := normalizeGeminiModel(model)
	payload, err := json.Marshal(p.buildRequestBody(messages, tools, modelID, options))
	if err != nil {
		return nil, fmt.Errorf("failed to marshal request: %w", err)
	}

	endpoint := p.apiBase + "/models/" + modelID + ":generateContent"
	httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewReader(payload))
	if err != nil {
		return nil, fmt.Errorf("failed to create request: %w", err)
	}
	p.applyHeaders(httpReq)

	httpResp, err := p.httpClient.Do(httpReq)
	if err != nil {
		return nil, fmt.Errorf("failed to send request: %w", err)
	}
	defer httpResp.Body.Close()

	if httpResp.StatusCode != http.StatusOK {
		return nil, common.HandleErrorResponse(httpResp, p.apiBase)
	}

	decoded := new(geminiGenerateContentResponse)
	if err := json.NewDecoder(httpResp.Body).Decode(decoded); err != nil {
		return nil, fmt.Errorf("failed to decode response: %w", err)
	}
	return parseGeminiResponse(decoded), nil
}
// ChatStream performs a streamGenerateContent call with alt=sse and parses the
// server-sent events with parseGeminiStreamResponse. onChunk, when non-nil,
// receives the accumulated visible text each time it grows.
//
// The request is sent without a whole-request timeout; cancel ctx to abort a
// stream. Errors are returned when the API base is empty, when the request
// cannot be marshalled, created or sent, when the status is not 200 OK (via
// common.HandleErrorResponse), or when the stream cannot be parsed.
func (p *GeminiProvider) ChatStream(
	ctx context.Context,
	messages []Message,
	tools []ToolDefinition,
	model string,
	options map[string]any,
	onChunk func(accumulated string),
) (*LLMResponse, error) {
	if p.apiBase == "" {
		return nil, fmt.Errorf("API base not configured")
	}
	model = normalizeGeminiModel(model)
	requestBody := p.buildRequestBody(messages, tools, model, options)
	jsonData, err := json.Marshal(requestBody)
	if err != nil {
		return nil, fmt.Errorf("failed to marshal request: %w", err)
	}
	url := fmt.Sprintf("%s/models/%s:streamGenerateContent?alt=sse", p.apiBase, model)
	req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(jsonData))
	if err != nil {
		return nil, fmt.Errorf("failed to create request: %w", err)
	}
	p.applyHeaders(req)
	req.Header.Set("Accept", "text/event-stream")

	// Streaming should not use a whole-request timeout; context cancellation is the guard.
	// Copy the configured client instead of rebuilding it from Transport alone so
	// that its redirect policy and cookie jar (if any) still apply to streams.
	streamClient := *p.httpClient
	streamClient.Timeout = 0
	resp, err := streamClient.Do(req)
	if err != nil {
		return nil, fmt.Errorf("failed to send request: %w", err)
	}
	defer resp.Body.Close()
	if resp.StatusCode != http.StatusOK {
		return nil, common.HandleErrorResponse(resp, p.apiBase)
	}
	return parseGeminiStreamResponse(ctx, resp.Body, onChunk)
}
// applyHeaders sets Content-Type, then the API key and User-Agent headers when
// configured, and finally every custom header whose name is not blank. Custom
// headers are applied last, so they replace built-in headers of the same name.
func (p *GeminiProvider) applyHeaders(req *http.Request) {
	header := req.Header
	header.Set("Content-Type", "application/json")
	if key := p.apiKey; key != "" {
		header.Set("X-Goog-Api-Key", key)
	}
	if agent := p.userAgent; agent != "" {
		header.Set("User-Agent", agent)
	}
	for name, value := range p.customHeaders {
		if strings.TrimSpace(name) != "" {
			header.Set(name, value)
		}
	}
}
// buildRequestBody converts the chat history, tool definitions and options
// into the request body map for generateContent / streamGenerateContent.
//
// Message mapping:
//   - "system": non-blank contents are collected into systemInstruction.
//   - "user": becomes a user content with text and inline media parts, or a
//     functionResponse part when the message carries a ToolCallID.
//   - "assistant": becomes a "model" content with text and functionCall parts.
//   - "tool": becomes a user content holding a single functionResponse part.
//
// Messages with any other role are ignored. Keys from p.extraBody are written
// last and replace same-named top-level keys.
func (p *GeminiProvider) buildRequestBody(
	messages []Message,
	tools []ToolDefinition,
	model string,
	options map[string]any,
) map[string]any {
	contents := make([]geminiContent, 0, len(messages))
	// toolCallNames maps tool call IDs seen on assistant messages to their
	// function names so that later tool responses can be labelled.
	toolCallNames := make(map[string]string)
	systemPrompts := make([]string, 0, 1)
	for _, msg := range messages {
		switch msg.Role {
		case "system":
			if strings.TrimSpace(msg.Content) != "" {
				systemPrompts = append(systemPrompts, msg.Content)
			}
		case "user":
			// A user message with a ToolCallID is treated as a tool result.
			if msg.ToolCallID != "" {
				toolName := resolveToolResponseName(msg.ToolCallID, toolCallNames)
				contents = append(contents, geminiContent{
					Role: "user",
					Parts: []geminiPart{{
						FunctionResponse: buildGeminiFunctionResponse(toolName, msg.ToolCallID, msg.Content, msg.Media),
					}},
				})
				continue
			}
			parts := make([]geminiPart, 0, 1+len(msg.Media))
			if strings.TrimSpace(msg.Content) != "" {
				parts = append(parts, geminiPart{Text: msg.Content})
			}
			parts = append(parts, buildInlineMediaParts(msg.Media)...)
			// Messages with neither text nor usable media are dropped.
			if len(parts) > 0 {
				contents = append(contents, geminiContent{Role: "user", Parts: parts})
			}
		case "assistant":
			content := geminiContent{Role: "model"}
			if strings.TrimSpace(msg.Content) != "" {
				content.Parts = append(content.Parts, geminiPart{Text: msg.Content})
			}
			for _, tc := range msg.ToolCalls {
				toolName, toolArgs, thoughtSignature := normalizeStoredToolCall(tc)
				// Tool calls without a resolvable name are skipped.
				if toolName == "" {
					continue
				}
				if tc.ID != "" {
					toolCallNames[tc.ID] = toolName
				}
				part := geminiPart{
					FunctionCall: &geminiFunctionCall{
						Name: toolName,
						Args: toolArgs,
						ID:   tc.ID,
					},
				}
				if thoughtSignature != "" {
					part.ThoughtSignature = thoughtSignature
				}
				content.Parts = append(content.Parts, part)
			}
			if len(content.Parts) > 0 {
				contents = append(contents, content)
			}
		case "tool":
			toolName := resolveToolResponseName(msg.ToolCallID, toolCallNames)
			contents = append(contents, geminiContent{
				Role: "user",
				Parts: []geminiPart{{
					FunctionResponse: buildGeminiFunctionResponse(toolName, msg.ToolCallID, msg.Content, msg.Media),
				}},
			})
		}
	}
	body := map[string]any{
		"contents": contents,
	}
	if len(systemPrompts) > 0 {
		systemParts := make([]geminiPart, 0, len(systemPrompts))
		for _, prompt := range systemPrompts {
			systemParts = append(systemParts, geminiPart{Text: prompt})
		}
		body["systemInstruction"] = &geminiContent{Parts: systemParts}
	}
	if len(tools) > 0 {
		funcDecls := make([]geminiFunctionDeclaration, 0, len(tools))
		for _, t := range tools {
			// Only "function" tools are forwarded.
			if t.Type != "function" {
				continue
			}
			funcDecls = append(funcDecls, geminiFunctionDeclaration{
				Name:        t.Function.Name,
				Description: t.Function.Description,
				Parameters:  sanitizeSchemaForGemini(t.Function.Parameters),
			})
		}
		if len(funcDecls) > 0 {
			body["tools"] = []geminiTool{{FunctionDeclarations: funcDecls}}
		}
	}
	generationConfig := make(map[string]any)
	// max_tokens is accepted as int or float64; other types and non-positive
	// values are ignored.
	if val, ok := options["max_tokens"]; ok {
		if maxTokens, ok := val.(int); ok && maxTokens > 0 {
			generationConfig["maxOutputTokens"] = maxTokens
		} else if maxTokens, ok := val.(float64); ok && maxTokens > 0 {
			generationConfig["maxOutputTokens"] = int(maxTokens)
		}
	}
	// temperature is forwarded only when it is a float64.
	if temp, ok := options["temperature"].(float64); ok {
		generationConfig["temperature"] = temp
	}
	if thinkingConfig := buildGeminiThinkingConfig(model, options); len(thinkingConfig) > 0 {
		generationConfig["thinkingConfig"] = thinkingConfig
	}
	if len(generationConfig) > 0 {
		body["generationConfig"] = generationConfig
	}
	// extraBody is a shallow override: a key such as "generationConfig"
	// replaces the value built above rather than being merged into it.
	for k, v := range p.extraBody {
		body[k] = v
	}
	return body
}
// normalizeGeminiModel turns a user-supplied model name into the bare model ID
// used in request URLs: it trims whitespace, removes a leading "models/",
// strips a "protocol/" prefix via extractProtocol when a slash remains and the
// part after it is non-empty, and substitutes geminiDefaultModel for an empty
// name.
func normalizeGeminiModel(model string) string {
	name := strings.TrimPrefix(strings.TrimSpace(model), "models/")
	if strings.Contains(name, "/") {
		if _, id := extractProtocol(name); id != "" {
			return id
		}
	}
	if name == "" {
		return geminiDefaultModel
	}
	return name
}
// mapGeminiThinkingLevel maps a thinking level name (case-insensitive,
// whitespace-trimmed) to a thinkingLevel value: "off" and "minimal" become
// "minimal"; "low" and "medium" map to themselves; "high", "xhigh" and
// "adaptive" become "high". Unknown levels yield "".
func mapGeminiThinkingLevel(level string) string {
	switch normalized := strings.ToLower(strings.TrimSpace(level)); normalized {
	case "low", "medium":
		return normalized
	case "off", "minimal":
		return "minimal"
	case "adaptive", "xhigh", "high":
		return "high"
	}
	return ""
}
// buildGeminiThinkingConfig builds the generationConfig.thinkingConfig map for
// the model, or returns nil when geminiModelSupportsThinkingConfig says the
// model takes none.
//
// options["thinking_level"] (string) selects the level; a missing or blank
// value is treated as "off". includeThoughts is true unless the level is "off"
// or "minimal". Gemini 2.5 models get a thinkingBudget from
// mapGeminiThinkingBudget; other supported models get a thinkingLevel from
// mapGeminiThinkingLevel. For 2.5 Pro and 3.x Pro models an "off"/"minimal"
// level sends neither, leaving the model default in place.
func buildGeminiThinkingConfig(model string, options map[string]any) map[string]any {
	if !geminiModelSupportsThinkingConfig(model) {
		return nil
	}

	requested, _ := options["thinking_level"].(string)
	requested = strings.ToLower(strings.TrimSpace(requested))
	if requested == "" {
		// Align with agent-level default: unset means ThinkingOff.
		requested = "off"
	}
	thinkingDisabled := requested == "off" || requested == "minimal"

	config := map[string]any{"includeThoughts": !thinkingDisabled}
	switch {
	case isGemini25Model(model):
		// Gemini 2.5 Pro cannot disable thinking; keep model-default thinking.
		if isGemini25ProModel(model) && thinkingDisabled {
			break
		}
		if budget, ok := mapGeminiThinkingBudget(requested); ok {
			config["thinkingBudget"] = budget
		}
	case isGemini3ProModel(model) && thinkingDisabled:
		// Gemini 3.x Pro does not support minimal thinking level.
	default:
		if mapped := mapGeminiThinkingLevel(requested); mapped != "" {
			config["thinkingLevel"] = mapped
		}
	}
	return config
}
// geminiModelSupportsThinkingConfig reports whether the model name
// (case-insensitive) contains "gemini-3" or is a Gemini 2.5 model according to
// isGemini25Model.
func geminiModelSupportsThinkingConfig(model string) bool {
	normalized := strings.ToLower(strings.TrimSpace(model))
	if isGemini25Model(normalized) {
		return true
	}
	return strings.Contains(normalized, "gemini-3")
}
// isGemini25Model reports whether the model name (case-insensitive) contains
// "gemini-2.5" or "gemini-25".
func isGemini25Model(model string) bool {
	normalized := strings.ToLower(strings.TrimSpace(model))
	for _, marker := range [...]string{"gemini-2.5", "gemini-25"} {
		if strings.Contains(normalized, marker) {
			return true
		}
	}
	return false
}
// isGemini25ProModel reports whether the model is a Gemini 2.5 model (per
// isGemini25Model) whose name also contains "pro", case-insensitively.
func isGemini25ProModel(model string) bool {
	normalized := strings.ToLower(strings.TrimSpace(model))
	if !strings.Contains(normalized, "pro") {
		return false
	}
	return isGemini25Model(normalized)
}
// isGemini3ProModel reports whether the model name (case-insensitive) contains
// both "gemini-3" and "pro".
func isGemini3ProModel(model string) bool {
	normalized := strings.ToLower(strings.TrimSpace(model))
	inFamily := strings.Contains(normalized, "gemini-3")
	isPro := strings.Contains(normalized, "pro")
	return inFamily && isPro
}
// mapGeminiThinkingBudget maps a thinking level name (case-insensitive,
// whitespace-trimmed) to a thinkingBudget value. The boolean is false, with a
// zero budget, for empty or unknown levels.
//
//	adaptive       -> -1
//	off, minimal   -> 0
//	low            -> 1024
//	medium         -> 4096
//	high           -> 8192
//	xhigh          -> 16384
func mapGeminiThinkingBudget(level string) (int, bool) {
	switch strings.ToLower(strings.TrimSpace(level)) {
	case "off", "minimal":
		return 0, true
	case "low":
		return 1024, true
	case "medium":
		return 4096, true
	case "high":
		return 8192, true
	case "xhigh":
		return 16384, true
	case "adaptive":
		return -1, true
	}
	return 0, false
}
// parseGeminiResponse converts a decoded generateContent response into an
// LLMResponse.
//
// Text parts are concatenated across all candidates: parts flagged as thought
// go to ReasoningContent, the rest to Content. Every functionCall part becomes
// a ToolCall. The last non-empty finishReason is normalized with
// normalizeGeminiFinishReason. Usage is set only when the reported total token
// count is positive.
func parseGeminiResponse(resp *geminiGenerateContentResponse) *LLMResponse {
	var visible, reasoning strings.Builder
	calls := []ToolCall{}
	lastFinish := ""

	for i := range resp.Candidates {
		candidate := &resp.Candidates[i]
		for _, part := range candidate.Content.Parts {
			switch {
			case part.Text == "":
				// No text on this part.
			case part.Thought:
				reasoning.WriteString(part.Text)
			default:
				visible.WriteString(part.Text)
			}
			if part.FunctionCall != nil {
				calls = append(calls, buildGeminiToolCall(part))
			}
		}
		if candidate.FinishReason != "" {
			lastFinish = candidate.FinishReason
		}
	}

	var usage *UsageInfo
	if meta := resp.UsageMetadata; meta.TotalTokenCount > 0 {
		usage = &UsageInfo{
			PromptTokens:     meta.PromptTokenCount,
			CompletionTokens: meta.CandidatesTokenCount,
			TotalTokens:      meta.TotalTokenCount,
		}
	}

	return &LLMResponse{
		Content:          visible.String(),
		ReasoningContent: reasoning.String(),
		ToolCalls:        calls,
		FinishReason:     normalizeGeminiFinishReason(lastFinish, len(calls)),
		Usage:            usage,
	}
}
// parseGeminiStreamResponse reads a server-sent-event stream of
// generateContent chunks from reader and folds them into one LLMResponse.
//
// Only lines starting with "data: " are considered; blank payloads are
// skipped and a "[DONE]" payload ends the stream. Visible text is accumulated
// and, when onChunk is non-nil, the full accumulated text is passed to it
// after each addition. Thought text is accumulated separately. Tool calls are
// keyed by their wire ID and returned in first-seen order; a later call with
// the same key replaces the earlier one.
//
// It returns ctx.Err() if the context is done when a line is processed, an
// error for a payload that is not valid JSON, and an error if the scanner
// fails.
func parseGeminiStreamResponse(
	ctx context.Context,
	reader io.Reader,
	onChunk func(accumulated string),
) (*LLMResponse, error) {
	var contentBuilder strings.Builder
	var reasoningBuilder strings.Builder
	var finishReason string
	var usage *UsageInfo
	// toolCallsByID holds the latest tool call per key; toolCallOrder records
	// the order in which keys first appeared.
	toolCallsByID := make(map[string]ToolCall)
	toolCallOrder := make([]string, 0)
	// fallbackIndex numbers synthetic keys for calls that arrive without an ID.
	fallbackIndex := 0
	scanner := bufio.NewScanner(reader)
	// Start with a 1 MiB buffer and allow single lines of up to 10 MiB.
	scanner.Buffer(make([]byte, 0, 1024*1024), 10*1024*1024)
	for scanner.Scan() {
		if err := ctx.Err(); err != nil {
			return nil, err
		}
		line := scanner.Text()
		if !strings.HasPrefix(line, "data: ") {
			continue
		}
		data := strings.TrimSpace(strings.TrimPrefix(line, "data: "))
		if data == "" {
			continue
		}
		if data == "[DONE]" {
			break
		}
		var chunk geminiGenerateContentResponse
		if err := json.Unmarshal([]byte(data), &chunk); err != nil {
			return nil, fmt.Errorf("invalid gemini stream chunk: %w", err)
		}
		for _, candidate := range chunk.Candidates {
			for _, part := range candidate.Content.Parts {
				if part.Text != "" {
					if part.Thought {
						reasoningBuilder.WriteString(part.Text)
					} else {
						contentBuilder.WriteString(part.Text)
						if onChunk != nil {
							onChunk(contentBuilder.String())
						}
					}
				}
				if part.FunctionCall != nil {
					tc := buildGeminiToolCall(part)
					// Calls without a name are ignored.
					if strings.TrimSpace(tc.Name) == "" {
						continue
					}
					key := strings.TrimSpace(part.FunctionCall.ID)
					if key == "" {
						// No wire ID: if the most recently added call has the
						// same name, reuse its key so this call replaces it.
						if len(toolCallOrder) > 0 {
							lastKey := toolCallOrder[len(toolCallOrder)-1]
							if lastTC, exists := toolCallsByID[lastKey]; exists && lastTC.Name == tc.Name {
								key = lastKey
							}
						}
						// Otherwise mint a synthetic "<name>#<n>" key.
						if key == "" {
							fallbackIndex++
							key = fmt.Sprintf("%s#%d", tc.Name, fallbackIndex)
						}
					}
					// The key also becomes the call's ID, overriding whatever
					// buildGeminiToolCall assigned.
					tc.ID = key
					if _, exists := toolCallsByID[key]; !exists {
						toolCallOrder = append(toolCallOrder, key)
					}
					toolCallsByID[key] = tc
				}
			}
			if candidate.FinishReason != "" {
				finishReason = candidate.FinishReason
			}
		}
		// Keep the most recent usage block that reports a positive total.
		if chunk.UsageMetadata.TotalTokenCount > 0 {
			usage = &UsageInfo{
				PromptTokens:     chunk.UsageMetadata.PromptTokenCount,
				CompletionTokens: chunk.UsageMetadata.CandidatesTokenCount,
				TotalTokens:      chunk.UsageMetadata.TotalTokenCount,
			}
		}
	}
	if err := scanner.Err(); err != nil {
		return nil, fmt.Errorf("streaming read error: %w", err)
	}
	toolCalls := make([]ToolCall, 0, len(toolCallOrder))
	for _, key := range toolCallOrder {
		toolCalls = append(toolCalls, toolCallsByID[key])
	}
	return &LLMResponse{
		Content:          contentBuilder.String(),
		ReasoningContent: reasoningBuilder.String(),
		ToolCalls:        toolCalls,
		FinishReason:     normalizeGeminiFinishReason(finishReason, len(toolCalls)),
		Usage:            usage,
	}, nil
}
// normalizeGeminiFinishReason maps a Gemini finishReason to the value reported
// on LLMResponse. Any positive toolCalls count yields "tool_calls" regardless
// of reason. Otherwise "" and "STOP" become "stop", "MAX_TOKENS" becomes
// "length", and any other reason is returned trimmed and lower-cased. Matching
// is case-insensitive.
func normalizeGeminiFinishReason(reason string, toolCalls int) string {
	if toolCalls > 0 {
		return "tool_calls"
	}
	trimmed := strings.TrimSpace(reason)
	upper := strings.ToUpper(trimmed)
	if upper == "" || upper == "STOP" {
		return "stop"
	}
	if upper == "MAX_TOKENS" {
		return "length"
	}
	return strings.ToLower(trimmed)
}
// buildGeminiToolCall converts a functionCall part into a ToolCall. A part
// without a functionCall yields the zero ToolCall.
//
// The arguments appear both as a map and, JSON-encoded, on the nested Function
// (nil args become an empty map). The thought signature, taken from either
// spelling on the part, is copied to the ToolCall, the Function and, when
// non-empty, ExtraContent.Google. A blank wire ID is replaced by
// "call_<name>_<unix-nanoseconds>".
func buildGeminiToolCall(part geminiPart) ToolCall {
	fc := part.FunctionCall
	if fc == nil {
		return ToolCall{}
	}

	callArgs := fc.Args
	if callArgs == nil {
		callArgs = make(map[string]any)
	}
	encodedArgs, _ := json.Marshal(callArgs)
	signature := extractPartThoughtSignature(part.ThoughtSignature, part.ThoughtSignatureSnake)

	call := ToolCall{
		ID:               fc.ID,
		Name:             fc.Name,
		Arguments:        callArgs,
		ThoughtSignature: signature,
		Function: &FunctionCall{
			Name:             fc.Name,
			Arguments:        string(encodedArgs),
			ThoughtSignature: signature,
		},
	}
	if signature != "" {
		call.ExtraContent = &ExtraContent{
			Google: &GoogleExtra{ThoughtSignature: signature},
		}
	}
	if strings.TrimSpace(call.ID) == "" {
		call.ID = fmt.Sprintf("call_%s_%d", call.Name, time.Now().UnixNano())
	}
	return call
}
// buildInlineMediaParts converts base64 data URLs into inlineData parts.
// Entries that parseBase64DataURL rejects are skipped. The result is never
// nil.
func buildInlineMediaParts(media []string) []geminiPart {
	parts := make([]geminiPart, 0, len(media))
	for i := range media {
		mime, payload, ok := parseBase64DataURL(media[i])
		if ok {
			inline := &geminiInlineData{MIMEType: mime, Data: payload}
			parts = append(parts, geminiPart{InlineData: inline})
		}
	}
	return parts
}
// buildGeminiFunctionResponse wraps a tool result in a functionResponse. The
// result text is stored under the "result" key of Response. Media entries that
// parse as base64 data URLs are attached as Parts; Parts stays nil when there
// are none.
func buildGeminiFunctionResponse(
	toolName string,
	toolCallID string,
	result string,
	media []string,
) *geminiFunctionResponse {
	attachments := buildFunctionResponseMediaParts(media)
	if len(attachments) == 0 {
		attachments = nil
	}
	return &geminiFunctionResponse{
		ID:       toolCallID,
		Name:     toolName,
		Response: map[string]any{"result": result},
		Parts:    attachments,
	}
}
// buildFunctionResponseMediaParts converts base64 data URLs into
// functionResponse parts, each with a generated display name. Entries that
// parseBase64DataURL rejects are skipped. The number in the display name is
// the entry's 1-based position in media, so skipped entries leave gaps.
func buildFunctionResponseMediaParts(media []string) []geminiFunctionResponsePart {
	parts := make([]geminiFunctionResponsePart, 0, len(media))
	for position := range media {
		mime, payload, ok := parseBase64DataURL(media[position])
		if !ok {
			continue
		}
		inline := &geminiInlineData{
			MIMEType:    mime,
			Data:        payload,
			DisplayName: defaultFunctionResponseDisplayName(mime, position+1),
		}
		parts = append(parts, geminiFunctionResponsePart{InlineData: inline})
	}
	return parts
}
// defaultFunctionResponseDisplayName returns "attachment-<index>.<ext>", where
// ext is chosen from the MIME type (case-insensitive, whitespace-trimmed):
// png, jpg, webp, pdf or txt for the recognised types, and "bin" otherwise.
func defaultFunctionResponseDisplayName(mimeType string, index int) string {
	var ext string
	switch strings.ToLower(strings.TrimSpace(mimeType)) {
	case "text/plain":
		ext = "txt"
	case "application/pdf":
		ext = "pdf"
	case "image/webp":
		ext = "webp"
	case "image/jpeg":
		ext = "jpg"
	case "image/png":
		ext = "png"
	default:
		ext = "bin"
	}
	return fmt.Sprintf("attachment-%d.%s", index, ext)
}
// parseBase64DataURL splits a "data:<mime>[;params];base64,<payload>" URL into
// its MIME type and payload, both whitespace-trimmed. ok is false when the
// "data:" prefix or the comma is missing, when the MIME type or payload is
// empty, or when the parameters after the first ';' do not mention "base64"
// (case-insensitive). The payload is not decoded or validated.
func parseBase64DataURL(mediaURL string) (mimeType string, data string, ok bool) {
	const scheme = "data:"
	if !strings.HasPrefix(mediaURL, scheme) {
		return "", "", false
	}
	comma := strings.Index(mediaURL, ",")
	if comma < 0 {
		return "", "", false
	}
	header := mediaURL[len(scheme):comma]
	payload := strings.TrimSpace(mediaURL[comma+1:])

	mediaType, params := header, ""
	if semi := strings.Index(header, ";"); semi >= 0 {
		mediaType, params = header[:semi], header[semi+1:]
	}
	mediaType = strings.TrimSpace(mediaType)

	switch {
	case mediaType == "", payload == "":
		return "", "", false
	case !strings.Contains(strings.ToLower(params), "base64"):
		return "", "", false
	}
	return mediaType, payload, true
}
// cloneAnyMap returns a shallow copy of in, or nil when in is nil or empty.
// Values are copied by assignment, so nested maps and slices stay shared.
func cloneAnyMap(in map[string]any) map[string]any {
	size := len(in)
	if size == 0 {
		return nil
	}
	copied := make(map[string]any, size)
	for key := range in {
		copied[key] = in[key]
	}
	return copied
}
// cloneStringMap returns a copy of in, or nil when in is nil or empty.
func cloneStringMap(in map[string]string) map[string]string {
	size := len(in)
	if size == 0 {
		return nil
	}
	copied := make(map[string]string, size)
	for key := range in {
		copied[key] = in[key]
	}
	return copied
}
// geminiGenerateContentResponse is the subset of a generateContent response
// (and of each streamed chunk) that this provider decodes.
type geminiGenerateContentResponse struct {
	// Candidates holds the generated contents and their finish reasons.
	Candidates []struct {
		Content struct {
			Role  string       `json:"role"`
			Parts []geminiPart `json:"parts"`
		} `json:"content"`
		FinishReason string `json:"finishReason"`
	} `json:"candidates"`
	// UsageMetadata carries token counts; it is used only when
	// TotalTokenCount is positive.
	UsageMetadata struct {
		PromptTokenCount     int `json:"promptTokenCount"`
		CandidatesTokenCount int `json:"candidatesTokenCount"`
		TotalTokenCount      int `json:"totalTokenCount"`
	} `json:"usageMetadata"`
}
// geminiContent is one entry of the request "contents" list, also used for
// systemInstruction. Role is omitted from the JSON when empty.
type geminiContent struct {
	Role  string       `json:"role,omitempty"`
	Parts []geminiPart `json:"parts"`
}
// geminiPart is a single content part, used both when building requests and
// when decoding responses. All fields are optional and omitted when empty.
type geminiPart struct {
	Text    string `json:"text,omitempty"`
	Thought bool   `json:"thought,omitempty"` // marks Text as reasoning rather than visible output
	// ThoughtSignature and ThoughtSignatureSnake are the camelCase and
	// snake_case spellings of the same value; extractPartThoughtSignature
	// prefers the camelCase one.
	ThoughtSignature      string                  `json:"thoughtSignature,omitempty"`
	ThoughtSignatureSnake string                  `json:"thought_signature,omitempty"`
	InlineData            *geminiInlineData       `json:"inlineData,omitempty"`
	FunctionCall          *geminiFunctionCall     `json:"functionCall,omitempty"`
	FunctionResponse      *geminiFunctionResponse `json:"functionResponse,omitempty"`
}
// geminiInlineData carries an embedded media payload. Data holds the payload
// portion of a base64 data URL as produced by parseBase64DataURL.
type geminiInlineData struct {
	MIMEType    string `json:"mimeType"`
	Data        string `json:"data"`
	DisplayName string `json:"displayName,omitempty"` // set only for function-response attachments
}
// geminiFunctionCall is a function invocation: the function name, its
// arguments, and an optional call ID.
type geminiFunctionCall struct {
	ID   string         `json:"id,omitempty"`
	Name string         `json:"name"`
	Args map[string]any `json:"args,omitempty"`
}
// geminiFunctionResponse reports a tool result back to the model. Response is
// a JSON object (buildGeminiFunctionResponse stores the text under "result");
// Parts optionally carries media attachments.
type geminiFunctionResponse struct {
	ID       string                       `json:"id,omitempty"`
	Name     string                       `json:"name"`
	Response map[string]any               `json:"response"`
	Parts    []geminiFunctionResponsePart `json:"parts,omitempty"`
}
// geminiFunctionResponsePart is one media attachment of a function response.
type geminiFunctionResponsePart struct {
	InlineData *geminiInlineData `json:"inlineData,omitempty"`
}
// geminiTool is one entry of the request "tools" list, grouping function
// declarations.
type geminiTool struct {
	FunctionDeclarations []geminiFunctionDeclaration `json:"functionDeclarations"`
}
// geminiFunctionDeclaration describes one callable function. Parameters holds
// the JSON Schema after sanitizeSchemaForGemini has been applied.
type geminiFunctionDeclaration struct {
	Name        string `json:"name"`
	Description string `json:"description,omitempty"`
	Parameters  any    `json:"parameters,omitempty"`
}
@@ -0,0 +1,763 @@
package httpapi
import (
"encoding/json"
"fmt"
"net/http"
"net/http/httptest"
"strings"
"testing"
)
// TestGeminiProvider_ChatSeparatesThoughtAndToolCall verifies that Chat hits
// the generateContent endpoint with the API key header, splits thought text
// from visible text, surfaces the function call with its thought signature,
// and sends a thinkingConfig derived from the thinking_level option.
func TestGeminiProvider_ChatSeparatesThoughtAndToolCall(t *testing.T) {
	var capturedBody map[string]any
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// The handler runs on a server goroutine, where t.Fatalf must not be
		// called; record the failure and answer with an error status instead.
		if r.Method != http.MethodPost {
			t.Errorf("method = %s, want POST", r.Method)
			http.Error(w, "unexpected method", http.StatusMethodNotAllowed)
			return
		}
		if !strings.Contains(r.URL.Path, ":generateContent") {
			t.Errorf("path = %s, expected generateContent endpoint", r.URL.Path)
			http.Error(w, "unexpected path", http.StatusNotFound)
			return
		}
		if got := r.Header.Get("X-Goog-Api-Key"); got != "test-key" {
			t.Errorf("X-Goog-Api-Key = %q, want %q", got, "test-key")
			http.Error(w, "unexpected api key", http.StatusUnauthorized)
			return
		}
		if err := json.NewDecoder(r.Body).Decode(&capturedBody); err != nil {
			t.Errorf("decode request body: %v", err)
			http.Error(w, "bad request body", http.StatusBadRequest)
			return
		}
		w.Header().Set("Content-Type", "application/json")
		_ = json.NewEncoder(w).Encode(map[string]any{
			"candidates": []any{
				map[string]any{
					"content": map[string]any{
						"role": "model",
						"parts": []any{
							map[string]any{"text": "hidden", "thought": true},
							map[string]any{"text": "visible"},
							map[string]any{
								"functionCall": map[string]any{
									"id":   "call_1",
									"name": "search",
									"args": map[string]any{"q": "hi"},
								},
								"thoughtSignature": "sig-1",
							},
						},
					},
					"finishReason": "STOP",
				},
			},
			"usageMetadata": map[string]any{
				"promptTokenCount":     2,
				"candidatesTokenCount": 3,
				"totalTokenCount":      5,
			},
		})
	}))
	defer server.Close()

	provider := NewGeminiProvider("test-key", server.URL, "", "picoclaw-test", 0, nil, nil)
	resp, err := provider.Chat(
		t.Context(),
		[]Message{{Role: "user", Content: "hello"}},
		nil,
		"gemini-3-flash-preview",
		map[string]any{"thinking_level": "high"},
	)
	if err != nil {
		t.Fatalf("Chat() error = %v", err)
	}
	if resp.Content != "visible" {
		t.Fatalf("Content = %q, want %q", resp.Content, "visible")
	}
	if resp.ReasoningContent != "hidden" {
		t.Fatalf("ReasoningContent = %q, want %q", resp.ReasoningContent, "hidden")
	}
	if resp.FinishReason != "tool_calls" {
		t.Fatalf("FinishReason = %q, want %q", resp.FinishReason, "tool_calls")
	}
	if resp.Usage == nil || resp.Usage.TotalTokens != 5 {
		t.Fatalf("Usage = %#v, expected total tokens = 5", resp.Usage)
	}
	if len(resp.ToolCalls) != 1 {
		t.Fatalf("ToolCalls len = %d, want 1", len(resp.ToolCalls))
	}
	if resp.ToolCalls[0].ID != "call_1" {
		t.Fatalf("ToolCall ID = %q, want %q", resp.ToolCalls[0].ID, "call_1")
	}
	if resp.ToolCalls[0].Name != "search" {
		t.Fatalf("ToolCall Name = %q, want %q", resp.ToolCalls[0].Name, "search")
	}
	if resp.ToolCalls[0].ThoughtSignature != "sig-1" {
		t.Fatalf("ToolCall ThoughtSignature = %q, want %q", resp.ToolCalls[0].ThoughtSignature, "sig-1")
	}
	if resp.ToolCalls[0].Function == nil || !strings.Contains(resp.ToolCalls[0].Function.Arguments, `"q":"hi"`) {
		t.Fatalf("ToolCall Function arguments = %#v, want q=hi", resp.ToolCalls[0].Function)
	}

	// The request body captured by the handler must carry the thinking config.
	generationConfig, ok := capturedBody["generationConfig"].(map[string]any)
	if !ok {
		t.Fatalf("request missing generationConfig: %#v", capturedBody)
	}
	thinkingConfig, ok := generationConfig["thinkingConfig"].(map[string]any)
	if !ok {
		t.Fatalf("request missing thinkingConfig: %#v", generationConfig)
	}
	if includeThoughts, ok := thinkingConfig["includeThoughts"].(bool); !ok || !includeThoughts {
		t.Fatalf("thinkingConfig.includeThoughts = %#v, want true", thinkingConfig["includeThoughts"])
	}
	if got := thinkingConfig["thinkingLevel"]; got != "high" {
		t.Fatalf("thinkingConfig.thinkingLevel = %#v, want %q", got, "high")
	}
}
// TestGeminiProvider_ChatStreamParsesThoughtTextAndToolCalls verifies that
// ChatStream requests the SSE streaming endpoint, accumulates visible and
// thought text across chunks, reports the tool call and usage, and delivers
// accumulated text through the callback.
func TestGeminiProvider_ChatStreamParsesThoughtTextAndToolCalls(t *testing.T) {
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// The handler runs on a server goroutine, where t.Fatalf must not be
		// called; record the failure and stop responding instead.
		if !strings.Contains(r.URL.Path, ":streamGenerateContent") {
			t.Errorf("path = %s, expected streamGenerateContent endpoint", r.URL.Path)
			http.Error(w, "unexpected path", http.StatusNotFound)
			return
		}
		if got := r.URL.Query().Get("alt"); got != "sse" {
			t.Errorf("alt query = %q, want %q", got, "sse")
			http.Error(w, "unexpected alt query", http.StatusBadRequest)
			return
		}
		flusher, ok := w.(http.Flusher)
		if !ok {
			t.Error("response writer is not flushable")
			http.Error(w, "streaming unsupported", http.StatusInternalServerError)
			return
		}
		w.Header().Set("Content-Type", "text/event-stream")
		chunks := []map[string]any{
			{
				"candidates": []any{map[string]any{
					"content": map[string]any{
						"parts": []any{
							map[string]any{"text": "think ", "thought": true},
							map[string]any{"text": "Hello "},
						},
					},
				}},
			},
			{
				"candidates": []any{map[string]any{
					"content": map[string]any{
						"parts": []any{
							map[string]any{"text": "World"},
							map[string]any{
								"functionCall": map[string]any{
									"id":   "call_stream",
									"name": "search",
									"args": map[string]any{"q": "stream"},
								},
							},
						},
					},
					"finishReason": "STOP",
				}},
				"usageMetadata": map[string]any{
					"promptTokenCount":     1,
					"candidatesTokenCount": 2,
					"totalTokenCount":      3,
				},
			},
		}
		for _, chunk := range chunks {
			raw, err := json.Marshal(chunk)
			if err != nil {
				t.Errorf("marshal chunk: %v", err)
				return
			}
			if _, err := fmt.Fprintf(w, "data: %s\n\n", raw); err != nil {
				t.Errorf("write chunk: %v", err)
				return
			}
			flusher.Flush()
		}
		_, _ = fmt.Fprint(w, "data: [DONE]\n\n")
		flusher.Flush()
	}))
	defer server.Close()

	provider := NewGeminiProvider("test-key", server.URL, "", "", 0, nil, nil)
	updates := make([]string, 0)
	resp, err := provider.ChatStream(
		t.Context(),
		[]Message{{Role: "user", Content: "hello"}},
		nil,
		"gemini-2.5-flash",
		nil,
		func(accumulated string) {
			updates = append(updates, accumulated)
		},
	)
	if err != nil {
		t.Fatalf("ChatStream() error = %v", err)
	}
	if resp.Content != "Hello World" {
		t.Fatalf("Content = %q, want %q", resp.Content, "Hello World")
	}
	if resp.ReasoningContent != "think " {
		t.Fatalf("ReasoningContent = %q, want %q", resp.ReasoningContent, "think ")
	}
	if len(resp.ToolCalls) != 1 || resp.ToolCalls[0].ID != "call_stream" {
		t.Fatalf("ToolCalls = %#v, want single call_stream", resp.ToolCalls)
	}
	if resp.FinishReason != "tool_calls" {
		t.Fatalf("FinishReason = %q, want %q", resp.FinishReason, "tool_calls")
	}
	if resp.Usage == nil || resp.Usage.TotalTokens != 3 {
		t.Fatalf("Usage = %#v, expected total tokens = 3", resp.Usage)
	}
	if len(updates) < 2 || updates[len(updates)-1] != "Hello World" {
		t.Fatalf("stream updates = %#v, expected final accumulated text", updates)
	}
}
// TestGeminiProvider_ChatStreamSkipsEmptyDataFrames verifies that an SSE frame
// with an empty data payload is ignored and the following chunk is still
// parsed.
func TestGeminiProvider_ChatStreamSkipsEmptyDataFrames(t *testing.T) {
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// The handler runs on a server goroutine, where t.Fatal must not be
		// called; record the failure and stop responding instead.
		flusher, ok := w.(http.Flusher)
		if !ok {
			t.Error("response writer is not flushable")
			http.Error(w, "streaming unsupported", http.StatusInternalServerError)
			return
		}
		chunk := map[string]any{
			"candidates": []any{map[string]any{
				"content": map[string]any{
					"parts": []any{map[string]any{"text": "ok"}},
				},
				"finishReason": "STOP",
			}},
		}
		raw, err := json.Marshal(chunk)
		if err != nil {
			t.Errorf("marshal chunk: %v", err)
			http.Error(w, "marshal failure", http.StatusInternalServerError)
			return
		}
		w.Header().Set("Content-Type", "text/event-stream")
		_, _ = fmt.Fprint(w, "data: \n\n")
		flusher.Flush()
		_, _ = fmt.Fprintf(w, "data: %s\n\n", raw)
		flusher.Flush()
		_, _ = fmt.Fprint(w, "data: [DONE]\n\n")
		flusher.Flush()
	}))
	defer server.Close()

	provider := NewGeminiProvider("test-key", server.URL, "", "", 0, nil, nil)
	resp, err := provider.ChatStream(
		t.Context(),
		[]Message{{Role: "user", Content: "hello"}},
		nil,
		"gemini-2.5-flash",
		nil,
		nil,
	)
	if err != nil {
		t.Fatalf("ChatStream() error = %v", err)
	}
	if resp.Content != "ok" {
		t.Fatalf("Content = %q, want %q", resp.Content, "ok")
	}
}
// TestGeminiProvider_ChatStreamReturnsErrorOnInvalidDataFrame verifies that a
// data frame containing invalid JSON makes ChatStream fail with an
// "invalid gemini stream chunk" error.
func TestGeminiProvider_ChatStreamReturnsErrorOnInvalidDataFrame(t *testing.T) {
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// The handler runs on a server goroutine, where t.Fatal must not be
		// called. The error status also keeps a broken setup from passing as
		// the expected parse failure: the message assertion below would fail.
		flusher, ok := w.(http.Flusher)
		if !ok {
			t.Error("response writer is not flushable")
			http.Error(w, "streaming unsupported", http.StatusInternalServerError)
			return
		}
		w.Header().Set("Content-Type", "text/event-stream")
		_, _ = fmt.Fprint(w, "data: {invalid-json}\n\n")
		flusher.Flush()
	}))
	defer server.Close()

	provider := NewGeminiProvider("test-key", server.URL, "", "", 0, nil, nil)
	_, err := provider.ChatStream(
		t.Context(),
		[]Message{{Role: "user", Content: "hello"}},
		nil,
		"gemini-2.5-flash",
		nil,
		nil,
	)
	if err == nil {
		t.Fatal("ChatStream() expected error for invalid SSE data frame")
	}
	if !strings.Contains(err.Error(), "invalid gemini stream chunk") {
		t.Fatalf("error = %v, want contains %q", err, "invalid gemini stream chunk")
	}
}
func TestGeminiProvider_BuildRequestBody_UsesCamelCaseThoughtSignatureOnly(t *testing.T) {
provider := NewGeminiProvider("test-key", "https://example.com/v1beta", "", "", 0, nil, nil)
body := provider.buildRequestBody(
[]Message{{
Role: "assistant",
ToolCalls: []ToolCall{{
ID: "call_1",
Name: "search",
Arguments: map[string]any{"q": "hello"},
Function: &FunctionCall{
Name: "search",
Arguments: `{"q":"hello"}`,
ThoughtSignature: "sig-1",
},
}},
}},
nil,
"gemini-2.5-flash",
nil,
)
raw, err := json.Marshal(body)
if err != nil {
t.Fatalf("marshal request body: %v", err)
}
jsonBody := string(raw)
if !strings.Contains(jsonBody, `"thoughtSignature":"sig-1"`) {
t.Fatalf("request body = %s, expected camelCase thoughtSignature", jsonBody)
}
if strings.Contains(jsonBody, `"thought_signature"`) {
t.Fatalf("request body = %s, unexpected snake_case thought_signature", jsonBody)
}
}
// TestGeminiProvider_ChatStreamCoalescesToolCallWithoutWireID verifies that
// consecutive streamed function calls with the same name and no wire ID are
// merged into a single tool call that keeps the synthetic ID "search#1" and
// the arguments of the last chunk.
func TestGeminiProvider_ChatStreamCoalescesToolCallWithoutWireID(t *testing.T) {
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// The handler runs on a server goroutine, where t.Fatalf must not be
		// called; record the failure and stop responding instead.
		flusher, ok := w.(http.Flusher)
		if !ok {
			t.Error("response writer is not flushable")
			http.Error(w, "streaming unsupported", http.StatusInternalServerError)
			return
		}
		w.Header().Set("Content-Type", "text/event-stream")
		chunks := []map[string]any{
			{
				"candidates": []any{map[string]any{
					"content": map[string]any{
						"parts": []any{
							map[string]any{
								"functionCall": map[string]any{
									"name": "search",
									"args": map[string]any{"q": "first"},
								},
							},
						},
					},
				}},
			},
			{
				"candidates": []any{map[string]any{
					"content": map[string]any{
						"parts": []any{
							map[string]any{
								"functionCall": map[string]any{
									"name": "search",
									"args": map[string]any{"q": "second"},
								},
							},
						},
					},
					"finishReason": "STOP",
				}},
			},
		}
		for _, chunk := range chunks {
			raw, err := json.Marshal(chunk)
			if err != nil {
				t.Errorf("marshal chunk: %v", err)
				return
			}
			if _, err := fmt.Fprintf(w, "data: %s\n\n", raw); err != nil {
				t.Errorf("write chunk: %v", err)
				return
			}
			flusher.Flush()
		}
		_, _ = fmt.Fprint(w, "data: [DONE]\n\n")
		flusher.Flush()
	}))
	defer server.Close()

	provider := NewGeminiProvider("test-key", server.URL, "", "", 0, nil, nil)
	resp, err := provider.ChatStream(
		t.Context(),
		[]Message{{Role: "user", Content: "hello"}},
		nil,
		"gemini-2.5-flash",
		nil,
		nil,
	)
	if err != nil {
		t.Fatalf("ChatStream() error = %v", err)
	}
	if len(resp.ToolCalls) != 1 {
		t.Fatalf("ToolCalls len = %d, want 1", len(resp.ToolCalls))
	}
	tc := resp.ToolCalls[0]
	if tc.ID != "search#1" {
		t.Fatalf("ToolCall ID = %q, want %q", tc.ID, "search#1")
	}
	if tc.Name != "search" {
		t.Fatalf("ToolCall Name = %q, want %q", tc.Name, "search")
	}
	if argQ, ok := tc.Arguments["q"].(string); !ok || argQ != "second" {
		t.Fatalf("ToolCall Arguments = %#v, want q=second", tc.Arguments)
	}
	if resp.FinishReason != "tool_calls" {
		t.Fatalf("FinishReason = %q, want %q", resp.FinishReason, "tool_calls")
	}
}
func TestGeminiProvider_BuildRequestBodyIncludesMediaAndThinkingConfig(t *testing.T) {
provider := NewGeminiProvider("test-key", "https://example.com/v1beta", "", "", 0, nil, nil)
body := provider.buildRequestBody(
[]Message{{
Role: "user",
Content: "analyze attachments",
Media: []string{
"data:application/pdf;base64,UEZERGF0YQ==",
"data:image/png;base64,aW1hZ2VEYXRh",
},
}},
nil,
"gemini-3-flash-preview",
map[string]any{
"thinking_level": "low",
"max_tokens": 128,
"temperature": 0.2,
},
)
contents, ok := body["contents"].([]geminiContent)
if !ok || len(contents) != 1 {
t.Fatalf("contents = %#v, want one gemini content", body["contents"])
}
parts := contents[0].Parts
mimeSet := map[string]bool{}
for _, part := range parts {
if part.InlineData != nil {
mimeSet[part.InlineData.MIMEType] = true
}
}
if !mimeSet["application/pdf"] {
t.Fatalf("inline media missing application/pdf: %#v", parts)
}
if !mimeSet["image/png"] {
t.Fatalf("inline media missing image/png: %#v", parts)
}
generationConfig, ok := body["generationConfig"].(map[string]any)
if !ok {
t.Fatalf("generationConfig = %#v, want map", body["generationConfig"])
}
if got := generationConfig["maxOutputTokens"]; got != 128 {
t.Fatalf("maxOutputTokens = %#v, want 128", got)
}
if got := generationConfig["temperature"]; got != 0.2 {
t.Fatalf("temperature = %#v, want 0.2", got)
}
thinkingConfig, ok := generationConfig["thinkingConfig"].(map[string]any)
if !ok {
t.Fatalf("thinkingConfig = %#v, want map", generationConfig["thinkingConfig"])
}
if includeThoughts, ok := thinkingConfig["includeThoughts"].(bool); !ok || !includeThoughts {
t.Fatalf("includeThoughts = %#v, want true", thinkingConfig["includeThoughts"])
}
if got := thinkingConfig["thinkingLevel"]; got != "low" {
t.Fatalf("thinkingLevel = %#v, want %q", got, "low")
}
}
// TestGeminiProvider_BuildRequestBody_UsesThinkingBudgetForGemini25 verifies
// that thinking_level "medium" on gemini-2.5-flash is rendered as
// thinkingBudget 4096 and that no thinkingLevel key is emitted.
func TestGeminiProvider_BuildRequestBody_UsesThinkingBudgetForGemini25(t *testing.T) {
	p := NewGeminiProvider("test-key", "https://example.com/v1beta", "", "", 0, nil, nil)
	history := []Message{{Role: "user", Content: "hello"}}
	opts := map[string]any{"thinking_level": "medium"}
	reqBody := p.buildRequestBody(history, nil, "gemini-2.5-flash", opts)

	genCfg, isMap := reqBody["generationConfig"].(map[string]any)
	if !isMap {
		t.Fatalf("generationConfig = %#v, want map", reqBody["generationConfig"])
	}
	thinkCfg, isMap := genCfg["thinkingConfig"].(map[string]any)
	if !isMap {
		t.Fatalf("thinkingConfig = %#v, want map", genCfg["thinkingConfig"])
	}

	budget := thinkCfg["thinkingBudget"]
	if budget != 4096 {
		t.Fatalf("thinkingBudget = %#v, want 4096", budget)
	}
	if _, present := thinkCfg["thinkingLevel"]; present {
		t.Fatalf("thinkingLevel should not be set for Gemini 2.5: %#v", thinkCfg)
	}
}
// TestGeminiProvider_BuildRequestBody_OmitsThinkingConfigForGemini20 verifies
// that a request for gemini-2.0-flash-exp whose only option is thinking_level
// produces a body with no generationConfig key at all.
func TestGeminiProvider_BuildRequestBody_OmitsThinkingConfigForGemini20(t *testing.T) {
	p := NewGeminiProvider("test-key", "https://example.com/v1beta", "", "", 0, nil, nil)
	history := []Message{{Role: "user", Content: "hello"}}
	opts := map[string]any{"thinking_level": "high"}
	reqBody := p.buildRequestBody(history, nil, "gemini-2.0-flash-exp", opts)

	_, present := reqBody["generationConfig"]
	if present {
		t.Fatalf("generationConfig should be omitted for Gemini 2.0 when only thinking_level is set: %#v", reqBody)
	}
}
// TestGeminiProvider_BuildRequestBody_DefaultsThinkingOffForGemini25 verifies
// that with nil options gemini-2.5-flash still gets a thinkingConfig, with
// thinkingBudget 0 and includeThoughts false.
func TestGeminiProvider_BuildRequestBody_DefaultsThinkingOffForGemini25(t *testing.T) {
	p := NewGeminiProvider("test-key", "https://example.com/v1beta", "", "", 0, nil, nil)
	history := []Message{{Role: "user", Content: "hello"}}
	reqBody := p.buildRequestBody(history, nil, "gemini-2.5-flash", nil)

	genCfg, isMap := reqBody["generationConfig"].(map[string]any)
	if !isMap {
		t.Fatalf("generationConfig = %#v, want map", reqBody["generationConfig"])
	}
	thinkCfg, isMap := genCfg["thinkingConfig"].(map[string]any)
	if !isMap {
		t.Fatalf("thinkingConfig = %#v, want map", genCfg["thinkingConfig"])
	}

	budget := thinkCfg["thinkingBudget"]
	if budget != 0 {
		t.Fatalf("thinkingBudget = %#v, want 0 for default/off", budget)
	}
	// The key must hold a bool, and that bool must be false.
	flag, isBool := thinkCfg["includeThoughts"].(bool)
	switch {
	case !isBool, flag:
		t.Fatalf("includeThoughts = %#v, want false for default/off", thinkCfg["includeThoughts"])
	}
}
// TestGeminiProvider_BuildRequestBody_DefaultsThinkingOffForGemini3 verifies
// that with nil options gemini-3-flash-preview gets a thinkingConfig with
// thinkingLevel "minimal" and includeThoughts false.
func TestGeminiProvider_BuildRequestBody_DefaultsThinkingOffForGemini3(t *testing.T) {
	p := NewGeminiProvider("test-key", "https://example.com/v1beta", "", "", 0, nil, nil)
	history := []Message{{Role: "user", Content: "hello"}}
	reqBody := p.buildRequestBody(history, nil, "gemini-3-flash-preview", nil)

	genCfg, isMap := reqBody["generationConfig"].(map[string]any)
	if !isMap {
		t.Fatalf("generationConfig = %#v, want map", reqBody["generationConfig"])
	}
	thinkCfg, isMap := genCfg["thinkingConfig"].(map[string]any)
	if !isMap {
		t.Fatalf("thinkingConfig = %#v, want map", genCfg["thinkingConfig"])
	}

	level := thinkCfg["thinkingLevel"]
	if level != "minimal" {
		t.Fatalf("thinkingLevel = %#v, want minimal for default/off", level)
	}
	// The key must hold a bool, and that bool must be false.
	flag, isBool := thinkCfg["includeThoughts"].(bool)
	if !isBool || flag {
		t.Fatalf("includeThoughts = %#v, want false for default/off", thinkCfg["includeThoughts"])
	}
}
// TestGeminiProvider_BuildRequestBody_DefaultsThinkingOffForGemini25Pro
// verifies that with nil options gemini-2.5-pro gets a thinkingConfig with
// includeThoughts false and no thinkingBudget key.
func TestGeminiProvider_BuildRequestBody_DefaultsThinkingOffForGemini25Pro(t *testing.T) {
	p := NewGeminiProvider("test-key", "https://example.com/v1beta", "", "", 0, nil, nil)
	history := []Message{{Role: "user", Content: "hello"}}
	reqBody := p.buildRequestBody(history, nil, "gemini-2.5-pro", nil)

	genCfg, isMap := reqBody["generationConfig"].(map[string]any)
	if !isMap {
		t.Fatalf("generationConfig = %#v, want map", reqBody["generationConfig"])
	}
	thinkCfg, isMap := genCfg["thinkingConfig"].(map[string]any)
	if !isMap {
		t.Fatalf("thinkingConfig = %#v, want map", genCfg["thinkingConfig"])
	}

	// The key must hold a bool, and that bool must be false.
	flag, isBool := thinkCfg["includeThoughts"].(bool)
	if !isBool || flag {
		t.Fatalf("includeThoughts = %#v, want false for default/off", thinkCfg["includeThoughts"])
	}
	_, budgetPresent := thinkCfg["thinkingBudget"]
	if budgetPresent {
		t.Fatalf("thinkingBudget should be omitted for Gemini 2.5 Pro default/off: %#v", thinkCfg)
	}
}
// TestGeminiProvider_BuildRequestBody_DefaultsThinkingOffForGemini31Pro
// verifies that with nil options gemini-3.1-pro gets a thinkingConfig with
// includeThoughts false and no thinkingLevel key.
func TestGeminiProvider_BuildRequestBody_DefaultsThinkingOffForGemini31Pro(t *testing.T) {
	p := NewGeminiProvider("test-key", "https://example.com/v1beta", "", "", 0, nil, nil)
	history := []Message{{Role: "user", Content: "hello"}}
	reqBody := p.buildRequestBody(history, nil, "gemini-3.1-pro", nil)

	genCfg, isMap := reqBody["generationConfig"].(map[string]any)
	if !isMap {
		t.Fatalf("generationConfig = %#v, want map", reqBody["generationConfig"])
	}
	thinkCfg, isMap := genCfg["thinkingConfig"].(map[string]any)
	if !isMap {
		t.Fatalf("thinkingConfig = %#v, want map", genCfg["thinkingConfig"])
	}

	// The key must hold a bool, and that bool must be false.
	flag, isBool := thinkCfg["includeThoughts"].(bool)
	if !isBool || flag {
		t.Fatalf("includeThoughts = %#v, want false for default/off", thinkCfg["includeThoughts"])
	}
	_, levelPresent := thinkCfg["thinkingLevel"]
	if levelPresent {
		t.Fatalf("thinkingLevel should be omitted for Gemini 3.1 Pro default/off: %#v", thinkCfg)
	}
}
func TestGeminiProvider_BuildRequestBody_PreservesMultipleSystemMessages(t *testing.T) {
provider := NewGeminiProvider("test-key", "https://example.com/v1beta", "", "", 0, nil, nil)
body := provider.buildRequestBody(
[]Message{
{Role: "system", Content: "You are helpful."},
{Role: "system", Content: "Be concise."},
{Role: "user", Content: "hello"},
},
nil,
"gemini-3-flash-preview",
nil,
)
systemInstruction, ok := body["systemInstruction"].(*geminiContent)
if !ok || systemInstruction == nil {
t.Fatalf("systemInstruction = %#v, want *geminiContent", body["systemInstruction"])
}
if len(systemInstruction.Parts) != 2 {
t.Fatalf("systemInstruction.Parts len = %d, want 2", len(systemInstruction.Parts))
}
if systemInstruction.Parts[0].Text != "You are helpful." || systemInstruction.Parts[1].Text != "Be concise." {
t.Fatalf("systemInstruction.Parts = %#v, want ordered system prompts", systemInstruction.Parts)
}
}
func TestGeminiProvider_BuildRequestBody_PreservesToolResponseMedia(t *testing.T) {
provider := NewGeminiProvider("test-key", "https://example.com/v1beta", "", "", 0, nil, nil)
body := provider.buildRequestBody(
[]Message{
{
Role: "assistant",
ToolCalls: []ToolCall{{
ID: "call_1",
Name: "load_image",
Arguments: map[string]any{"path": "demo.png"},
}},
},
{
Role: "tool",
ToolCallID: "call_1",
Content: "tool result",
Media: []string{
"data:image/png;base64,aW1hZ2VEYXRh",
"data:application/pdf;base64,UEZERGF0YQ==",
},
},
},
nil,
"gemini-3-flash-preview",
nil,
)
contents, ok := body["contents"].([]geminiContent)
if !ok || len(contents) != 2 {
t.Fatalf("contents = %#v, want two content entries", body["contents"])
}
parts := contents[1].Parts
if len(parts) != 1 || parts[0].FunctionResponse == nil {
t.Fatalf("tool response part = %#v, want functionResponse", parts)
}
response := parts[0].FunctionResponse
if response.Name != "load_image" {
t.Fatalf("functionResponse.Name = %q, want %q", response.Name, "load_image")
}
if response.Response["result"] != "tool result" {
t.Fatalf("functionResponse.Response = %#v, want result=tool result", response.Response)
}
if len(response.Parts) != 2 {
t.Fatalf("functionResponse.Parts len = %d, want 2", len(response.Parts))
}
}
// TestGeminiProvider_ChatAllowsCustomAuthHeaderWithoutAPIKey creates a
// provider with an empty API key and a custom Authorization header, and checks
// that Chat succeeds, that the server receives the custom header, and that no
// X-Goog-Api-Key header is sent.
func TestGeminiProvider_ChatAllowsCustomAuthHeaderWithoutAPIKey(t *testing.T) {
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// The handler runs on a server goroutine, so failures are recorded
		// with t.Errorf; t.Fatalf (FailNow) is only valid on the test
		// goroutine.
		if got := r.Header.Get("Authorization"); got != "Bearer test-token" {
			t.Errorf("Authorization = %q, want %q", got, "Bearer test-token")
		}
		if got := r.Header.Get("X-Goog-Api-Key"); got != "" {
			t.Errorf("X-Goog-Api-Key = %q, want empty", got)
		}
		w.Header().Set("Content-Type", "application/json")
		payload := map[string]any{
			"candidates": []any{
				map[string]any{
					"content": map[string]any{
						"parts": []any{map[string]any{"text": "ok"}},
					},
					"finishReason": "STOP",
				},
			},
		}
		if err := json.NewEncoder(w).Encode(payload); err != nil {
			t.Errorf("encode response: %v", err)
		}
	}))
	defer server.Close()

	provider := NewGeminiProvider(
		"",
		server.URL,
		"",
		"",
		0,
		nil,
		map[string]string{"Authorization": "Bearer test-token"},
	)
	resp, err := provider.Chat(
		t.Context(),
		[]Message{{Role: "user", Content: "hello"}},
		nil,
		"gemini-2.5-flash",
		nil,
	)
	if err != nil {
		t.Fatalf("Chat() error = %v", err)
	}
	if resp.Content != "ok" {
		t.Fatalf("Content = %q, want %q", resp.Content, "ok")
	}
}
// TestGeminiProvider_ChatAllowsMissingAPIKeyForCustomAPIBase creates a provider
// with an empty API key pointed at a test server, and checks that Chat
// succeeds and that no X-Goog-Api-Key header is sent.
func TestGeminiProvider_ChatAllowsMissingAPIKeyForCustomAPIBase(t *testing.T) {
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// The handler runs on a server goroutine, so failures are recorded
		// with t.Errorf; t.Fatalf (FailNow) is only valid on the test
		// goroutine.
		if got := r.Header.Get("X-Goog-Api-Key"); got != "" {
			t.Errorf("X-Goog-Api-Key = %q, want empty", got)
		}
		w.Header().Set("Content-Type", "application/json")
		payload := map[string]any{
			"candidates": []any{
				map[string]any{
					"content":      map[string]any{"parts": []any{map[string]any{"text": "ok"}}},
					"finishReason": "STOP",
				},
			},
		}
		if err := json.NewEncoder(w).Encode(payload); err != nil {
			t.Errorf("encode response: %v", err)
		}
	}))
	defer server.Close()

	provider := NewGeminiProvider("", server.URL, "", "", 0, nil, nil)
	resp, err := provider.Chat(
		t.Context(),
		[]Message{{Role: "user", Content: "hello"}},
		nil,
		"gemini-2.5-flash",
		nil,
	)
	if err != nil {
		t.Fatalf("Chat() error = %v", err)
	}
	if resp.Content != "ok" {
		t.Fatalf("Content = %q, want %q", resp.Content, "ok")
	}
}
+79
View File
@@ -0,0 +1,79 @@
// PicoClaw - Ultra-lightweight personal AI agent
// Inspired by and based on nanobot: https://github.com/HKUDS/nanobot
// License: MIT
//
// Copyright (c) 2026 PicoClaw contributors
package httpapi
import (
"context"
"time"
"github.com/sipeed/picoclaw/pkg/providers/openai_compat"
)
// HTTPProvider is a thin facade over an openai_compat.Provider: its Chat,
// ChatStream and SupportsNativeSearch methods forward to the wrapped delegate.
type HTTPProvider struct {
// delegate is the wrapped OpenAI-compatible provider that serves the calls.
delegate *openai_compat.Provider
}
// NewHTTPProvider returns an HTTPProvider whose delegate is created by
// openai_compat.NewProvider with the given API key, API base and proxy and no
// additional options.
func NewHTTPProvider(apiKey, apiBase, proxy string) *HTTPProvider {
	inner := openai_compat.NewProvider(apiKey, apiBase, proxy)
	return &HTTPProvider{delegate: inner}
}
// NewHTTPProviderWithMaxTokensField is a convenience wrapper around
// NewHTTPProviderWithMaxTokensFieldAndRequestTimeout that passes an empty user
// agent, a request timeout of 0 seconds, and nil extra body and custom headers.
func NewHTTPProviderWithMaxTokensField(apiKey, apiBase, proxy, maxTokensField string) *HTTPProvider {
	const (
		noUserAgent      = ""
		noRequestTimeout = 0
	)
	return NewHTTPProviderWithMaxTokensFieldAndRequestTimeout(
		apiKey, apiBase, proxy, maxTokensField,
		noUserAgent, noRequestTimeout, nil, nil,
	)
}
// NewHTTPProviderWithMaxTokensFieldAndRequestTimeout returns an HTTPProvider
// whose delegate is configured with the max-tokens field name, request
// timeout, extra body, custom headers and user agent supplied by the caller.
//
// requestTimeoutSeconds is converted to a time.Duration in whole seconds
// before being handed to openai_compat.WithRequestTimeout. The options are
// passed in the same order as before: max-tokens field, timeout, extra body,
// custom headers, user agent.
func NewHTTPProviderWithMaxTokensFieldAndRequestTimeout(
	apiKey, apiBase, proxy, maxTokensField, userAgent string,
	requestTimeoutSeconds int,
	extraBody map[string]any,
	customHeaders map[string]string,
) *HTTPProvider {
	timeout := time.Duration(requestTimeoutSeconds) * time.Second
	inner := openai_compat.NewProvider(
		apiKey,
		apiBase,
		proxy,
		openai_compat.WithMaxTokensField(maxTokensField),
		openai_compat.WithRequestTimeout(timeout),
		openai_compat.WithExtraBody(extraBody),
		openai_compat.WithCustomHeaders(customHeaders),
		openai_compat.WithUserAgent(userAgent),
	)
	return &HTTPProvider{delegate: inner}
}
// Chat forwards ctx, messages, tools, model and options unchanged to the
// wrapped delegate's Chat method and returns its response and error as-is.
func (p *HTTPProvider) Chat(
ctx context.Context,
messages []Message,
tools []ToolDefinition,
model string,
options map[string]any,
) (*LLMResponse, error) {
return p.delegate.Chat(ctx, messages, tools, model, options)
}
// ChatStream implements providers.StreamingProvider by delegating to the
// OpenAI-compatible streaming endpoint (SSE with stream: true).
//
// Every argument, including the onChunk callback, is forwarded unchanged to
// the wrapped delegate's ChatStream; its response and error are returned
// as-is.
func (p *HTTPProvider) ChatStream(
ctx context.Context,
messages []Message,
tools []ToolDefinition,
model string,
options map[string]any,
onChunk func(accumulated string),
) (*LLMResponse, error) {
return p.delegate.ChatStream(ctx, messages, tools, model, options, onChunk)
}
// GetDefaultModel always returns the empty string; HTTPProvider itself does
// not name a default model.
func (p *HTTPProvider) GetDefaultModel() string {
return ""
}
// SupportsNativeSearch reports whatever the wrapped delegate's
// SupportsNativeSearch returns.
func (p *HTTPProvider) SupportsNativeSearch() bool {
return p.delegate.SupportsNativeSearch()
}
+43
View File
@@ -0,0 +1,43 @@
package httpapi
import (
"context"
"github.com/sipeed/picoclaw/pkg/providers/protocoltypes"
)
// Type aliases for the shared protocoltypes definitions, so code in this
// package can refer to them by their short names. Being aliases (not new
// types), they are identical to the protocoltypes originals.
type (
ToolCall = protocoltypes.ToolCall
FunctionCall = protocoltypes.FunctionCall
LLMResponse = protocoltypes.LLMResponse
UsageInfo = protocoltypes.UsageInfo
Message = protocoltypes.Message
ToolDefinition = protocoltypes.ToolDefinition
ToolFunctionDefinition = protocoltypes.ToolFunctionDefinition
ExtraContent = protocoltypes.ExtraContent
GoogleExtra = protocoltypes.GoogleExtra
ContentBlock = protocoltypes.ContentBlock
CacheControl = protocoltypes.CacheControl
)
// LLMProvider is the interface a chat backend in this package satisfies.
type LLMProvider interface {
// Chat takes the conversation messages, the tool definitions, a model name
// and a free-form options map, and returns an LLMResponse or an error.
// NOTE(review): the set of recognised option keys is not visible here —
// confirm against the concrete providers.
Chat(
ctx context.Context,
messages []Message,
tools []ToolDefinition,
model string,
options map[string]any,
) (*LLMResponse, error)
// GetDefaultModel returns the provider's default model name as a string.
GetDefaultModel() string
}
// StreamingProvider is implemented by providers that offer a streaming variant
// of Chat.
type StreamingProvider interface {
// ChatStream takes the same arguments as LLMProvider.Chat plus an onChunk
// callback, and returns an LLMResponse or an error.
// NOTE(review): the parameter name suggests onChunk receives the text
// accumulated so far rather than a delta — confirm against implementations.
ChatStream(
ctx context.Context,
messages []Message,
tools []ToolDefinition,
model string,
options map[string]any,
onChunk func(accumulated string),
) (*LLMResponse, error)
}