Merge pull request #213 from jmahotiedu/refactor/provider-protocol-122

Refactor providers by protocol family (discussion #122)
This commit is contained in:
Leandro Barbosa
2026-02-18 11:25:01 -03:00
committed by GitHub
13 changed files with 1800 additions and 776 deletions
+10
View File
@@ -679,6 +679,16 @@ The subagent has access to tools (message, web_search, etc.) and can communicate
| `deepseek(To be tested)` | LLM (DeepSeek direct) | [platform.deepseek.com](https://platform.deepseek.com) |
| `groq` | LLM + **Voice transcription** (Whisper) | [console.groq.com](https://console.groq.com) |
### Provider Architecture
PicoClaw routes providers by protocol family:
- OpenAI-compatible protocol: OpenRouter, OpenAI-compatible gateways, Groq, Zhipu, and vLLM-style endpoints.
- Anthropic protocol: Claude-native API behavior.
- Codex/OAuth path: OpenAI OAuth/token authentication route.
This keeps the runtime lightweight while making new OpenAI-compatible backends mostly a config operation (`api_base` + `api_key`).
<details>
<summary><b>Zhipu</b></summary>
+18
View File
@@ -299,6 +299,24 @@ func TestConvertConfig(t *testing.T) {
})
}
func TestSupportedProvidersCompatibility(t *testing.T) {
expected := []string{
"anthropic",
"openai",
"openrouter",
"groq",
"zhipu",
"vllm",
"gemini",
}
for _, provider := range expected {
if !supportedProviders[provider] {
t.Fatalf("supportedProviders missing expected key %q", provider)
}
}
}
func TestMergeConfig(t *testing.T) {
t.Run("fills empty fields", func(t *testing.T) {
existing := config.DefaultConfig()
+248
View File
@@ -0,0 +1,248 @@
package anthropicprovider
import (
"context"
"encoding/json"
"fmt"
"log"
"strings"
"github.com/anthropics/anthropic-sdk-go"
"github.com/anthropics/anthropic-sdk-go/option"
"github.com/sipeed/picoclaw/pkg/providers/protocoltypes"
)
type ToolCall = protocoltypes.ToolCall
type FunctionCall = protocoltypes.FunctionCall
type LLMResponse = protocoltypes.LLMResponse
type UsageInfo = protocoltypes.UsageInfo
type Message = protocoltypes.Message
type ToolDefinition = protocoltypes.ToolDefinition
type ToolFunctionDefinition = protocoltypes.ToolFunctionDefinition
const defaultBaseURL = "https://api.anthropic.com"
type Provider struct {
client *anthropic.Client
tokenSource func() (string, error)
baseURL string
}
func NewProvider(token string) *Provider {
return NewProviderWithBaseURL(token, "")
}
func NewProviderWithBaseURL(token, apiBase string) *Provider {
baseURL := normalizeBaseURL(apiBase)
client := anthropic.NewClient(
option.WithAuthToken(token),
option.WithBaseURL(baseURL),
)
return &Provider{
client: &client,
baseURL: baseURL,
}
}
func NewProviderWithClient(client *anthropic.Client) *Provider {
return &Provider{
client: client,
baseURL: defaultBaseURL,
}
}
func NewProviderWithTokenSource(token string, tokenSource func() (string, error)) *Provider {
return NewProviderWithTokenSourceAndBaseURL(token, tokenSource, "")
}
func NewProviderWithTokenSourceAndBaseURL(token string, tokenSource func() (string, error), apiBase string) *Provider {
p := NewProviderWithBaseURL(token, apiBase)
p.tokenSource = tokenSource
return p
}
func (p *Provider) Chat(ctx context.Context, messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) (*LLMResponse, error) {
var opts []option.RequestOption
if p.tokenSource != nil {
tok, err := p.tokenSource()
if err != nil {
return nil, fmt.Errorf("refreshing token: %w", err)
}
opts = append(opts, option.WithAuthToken(tok))
}
params, err := buildParams(messages, tools, model, options)
if err != nil {
return nil, err
}
resp, err := p.client.Messages.New(ctx, params, opts...)
if err != nil {
return nil, fmt.Errorf("claude API call: %w", err)
}
return parseResponse(resp), nil
}
func (p *Provider) GetDefaultModel() string {
return "claude-sonnet-4-5-20250929"
}
func (p *Provider) BaseURL() string {
return p.baseURL
}
func buildParams(messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) (anthropic.MessageNewParams, error) {
var system []anthropic.TextBlockParam
var anthropicMessages []anthropic.MessageParam
for _, msg := range messages {
switch msg.Role {
case "system":
system = append(system, anthropic.TextBlockParam{Text: msg.Content})
case "user":
if msg.ToolCallID != "" {
anthropicMessages = append(anthropicMessages,
anthropic.NewUserMessage(anthropic.NewToolResultBlock(msg.ToolCallID, msg.Content, false)),
)
} else {
anthropicMessages = append(anthropicMessages,
anthropic.NewUserMessage(anthropic.NewTextBlock(msg.Content)),
)
}
case "assistant":
if len(msg.ToolCalls) > 0 {
var blocks []anthropic.ContentBlockParamUnion
if msg.Content != "" {
blocks = append(blocks, anthropic.NewTextBlock(msg.Content))
}
for _, tc := range msg.ToolCalls {
blocks = append(blocks, anthropic.NewToolUseBlock(tc.ID, tc.Arguments, tc.Name))
}
anthropicMessages = append(anthropicMessages, anthropic.NewAssistantMessage(blocks...))
} else {
anthropicMessages = append(anthropicMessages,
anthropic.NewAssistantMessage(anthropic.NewTextBlock(msg.Content)),
)
}
case "tool":
anthropicMessages = append(anthropicMessages,
anthropic.NewUserMessage(anthropic.NewToolResultBlock(msg.ToolCallID, msg.Content, false)),
)
}
}
maxTokens := int64(4096)
if mt, ok := options["max_tokens"].(int); ok {
maxTokens = int64(mt)
}
params := anthropic.MessageNewParams{
Model: anthropic.Model(model),
Messages: anthropicMessages,
MaxTokens: maxTokens,
}
if len(system) > 0 {
params.System = system
}
if temp, ok := options["temperature"].(float64); ok {
params.Temperature = anthropic.Float(temp)
}
if len(tools) > 0 {
params.Tools = translateTools(tools)
}
return params, nil
}
func translateTools(tools []ToolDefinition) []anthropic.ToolUnionParam {
result := make([]anthropic.ToolUnionParam, 0, len(tools))
for _, t := range tools {
tool := anthropic.ToolParam{
Name: t.Function.Name,
InputSchema: anthropic.ToolInputSchemaParam{
Properties: t.Function.Parameters["properties"],
},
}
if desc := t.Function.Description; desc != "" {
tool.Description = anthropic.String(desc)
}
if req, ok := t.Function.Parameters["required"].([]interface{}); ok {
required := make([]string, 0, len(req))
for _, r := range req {
if s, ok := r.(string); ok {
required = append(required, s)
}
}
tool.InputSchema.Required = required
}
result = append(result, anthropic.ToolUnionParam{OfTool: &tool})
}
return result
}
func parseResponse(resp *anthropic.Message) *LLMResponse {
var content string
var toolCalls []ToolCall
for _, block := range resp.Content {
switch block.Type {
case "text":
tb := block.AsText()
content += tb.Text
case "tool_use":
tu := block.AsToolUse()
var args map[string]interface{}
if err := json.Unmarshal(tu.Input, &args); err != nil {
log.Printf("anthropic: failed to decode tool call input for %q: %v", tu.Name, err)
args = map[string]interface{}{"raw": string(tu.Input)}
}
toolCalls = append(toolCalls, ToolCall{
ID: tu.ID,
Name: tu.Name,
Arguments: args,
})
}
}
finishReason := "stop"
switch resp.StopReason {
case anthropic.StopReasonToolUse:
finishReason = "tool_calls"
case anthropic.StopReasonMaxTokens:
finishReason = "length"
case anthropic.StopReasonEndTurn:
finishReason = "stop"
}
return &LLMResponse{
Content: content,
ToolCalls: toolCalls,
FinishReason: finishReason,
Usage: &UsageInfo{
PromptTokens: int(resp.Usage.InputTokens),
CompletionTokens: int(resp.Usage.OutputTokens),
TotalTokens: int(resp.Usage.InputTokens + resp.Usage.OutputTokens),
},
}
}
func normalizeBaseURL(apiBase string) string {
base := strings.TrimSpace(apiBase)
if base == "" {
return defaultBaseURL
}
base = strings.TrimRight(base, "/")
if strings.HasSuffix(base, "/v1") {
base = strings.TrimSuffix(base, "/v1")
}
if base == "" {
return defaultBaseURL
}
return base
}
+265
View File
@@ -0,0 +1,265 @@
package anthropicprovider
import (
"encoding/json"
"net/http"
"net/http/httptest"
"sync/atomic"
"testing"
"github.com/anthropics/anthropic-sdk-go"
anthropicoption "github.com/anthropics/anthropic-sdk-go/option"
)
func TestBuildParams_BasicMessage(t *testing.T) {
messages := []Message{
{Role: "user", Content: "Hello"},
}
params, err := buildParams(messages, nil, "claude-sonnet-4-5-20250929", map[string]interface{}{
"max_tokens": 1024,
})
if err != nil {
t.Fatalf("buildParams() error: %v", err)
}
if string(params.Model) != "claude-sonnet-4-5-20250929" {
t.Errorf("Model = %q, want %q", params.Model, "claude-sonnet-4-5-20250929")
}
if params.MaxTokens != 1024 {
t.Errorf("MaxTokens = %d, want 1024", params.MaxTokens)
}
if len(params.Messages) != 1 {
t.Fatalf("len(Messages) = %d, want 1", len(params.Messages))
}
}
func TestBuildParams_SystemMessage(t *testing.T) {
messages := []Message{
{Role: "system", Content: "You are helpful"},
{Role: "user", Content: "Hi"},
}
params, err := buildParams(messages, nil, "claude-sonnet-4-5-20250929", map[string]interface{}{})
if err != nil {
t.Fatalf("buildParams() error: %v", err)
}
if len(params.System) != 1 {
t.Fatalf("len(System) = %d, want 1", len(params.System))
}
if params.System[0].Text != "You are helpful" {
t.Errorf("System[0].Text = %q, want %q", params.System[0].Text, "You are helpful")
}
if len(params.Messages) != 1 {
t.Fatalf("len(Messages) = %d, want 1", len(params.Messages))
}
}
func TestBuildParams_ToolCallMessage(t *testing.T) {
messages := []Message{
{Role: "user", Content: "What's the weather?"},
{
Role: "assistant",
Content: "",
ToolCalls: []ToolCall{
{
ID: "call_1",
Name: "get_weather",
Arguments: map[string]interface{}{"city": "SF"},
},
},
},
{Role: "tool", Content: `{"temp": 72}`, ToolCallID: "call_1"},
}
params, err := buildParams(messages, nil, "claude-sonnet-4-5-20250929", map[string]interface{}{})
if err != nil {
t.Fatalf("buildParams() error: %v", err)
}
if len(params.Messages) != 3 {
t.Fatalf("len(Messages) = %d, want 3", len(params.Messages))
}
}
func TestBuildParams_WithTools(t *testing.T) {
tools := []ToolDefinition{
{
Type: "function",
Function: ToolFunctionDefinition{
Name: "get_weather",
Description: "Get weather for a city",
Parameters: map[string]interface{}{
"type": "object",
"properties": map[string]interface{}{
"city": map[string]interface{}{"type": "string"},
},
"required": []interface{}{"city"},
},
},
},
}
params, err := buildParams([]Message{{Role: "user", Content: "Hi"}}, tools, "claude-sonnet-4-5-20250929", map[string]interface{}{})
if err != nil {
t.Fatalf("buildParams() error: %v", err)
}
if len(params.Tools) != 1 {
t.Fatalf("len(Tools) = %d, want 1", len(params.Tools))
}
}
func TestParseResponse_TextOnly(t *testing.T) {
resp := &anthropic.Message{
Content: []anthropic.ContentBlockUnion{},
Usage: anthropic.Usage{
InputTokens: 10,
OutputTokens: 20,
},
}
result := parseResponse(resp)
if result.Usage.PromptTokens != 10 {
t.Errorf("PromptTokens = %d, want 10", result.Usage.PromptTokens)
}
if result.Usage.CompletionTokens != 20 {
t.Errorf("CompletionTokens = %d, want 20", result.Usage.CompletionTokens)
}
if result.FinishReason != "stop" {
t.Errorf("FinishReason = %q, want %q", result.FinishReason, "stop")
}
}
func TestParseResponse_StopReasons(t *testing.T) {
tests := []struct {
stopReason anthropic.StopReason
want string
}{
{anthropic.StopReasonEndTurn, "stop"},
{anthropic.StopReasonMaxTokens, "length"},
{anthropic.StopReasonToolUse, "tool_calls"},
}
for _, tt := range tests {
resp := &anthropic.Message{
StopReason: tt.stopReason,
}
result := parseResponse(resp)
if result.FinishReason != tt.want {
t.Errorf("StopReason %q: FinishReason = %q, want %q", tt.stopReason, result.FinishReason, tt.want)
}
}
}
func TestProvider_ChatRoundTrip(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/v1/messages" {
http.Error(w, "not found", http.StatusNotFound)
return
}
if r.Header.Get("Authorization") != "Bearer test-token" {
http.Error(w, "unauthorized", http.StatusUnauthorized)
return
}
var reqBody map[string]interface{}
json.NewDecoder(r.Body).Decode(&reqBody)
resp := map[string]interface{}{
"id": "msg_test",
"type": "message",
"role": "assistant",
"model": reqBody["model"],
"stop_reason": "end_turn",
"content": []map[string]interface{}{
{"type": "text", "text": "Hello! How can I help you?"},
},
"usage": map[string]interface{}{
"input_tokens": 15,
"output_tokens": 8,
},
}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(resp)
}))
defer server.Close()
provider := NewProviderWithClient(createAnthropicTestClient(server.URL, "test-token"))
messages := []Message{{Role: "user", Content: "Hello"}}
resp, err := provider.Chat(t.Context(), messages, nil, "claude-sonnet-4-5-20250929", map[string]interface{}{"max_tokens": 1024})
if err != nil {
t.Fatalf("Chat() error: %v", err)
}
if resp.Content != "Hello! How can I help you?" {
t.Errorf("Content = %q, want %q", resp.Content, "Hello! How can I help you?")
}
if resp.FinishReason != "stop" {
t.Errorf("FinishReason = %q, want %q", resp.FinishReason, "stop")
}
if resp.Usage.PromptTokens != 15 {
t.Errorf("PromptTokens = %d, want 15", resp.Usage.PromptTokens)
}
}
func TestProvider_GetDefaultModel(t *testing.T) {
p := NewProvider("test-token")
if got := p.GetDefaultModel(); got != "claude-sonnet-4-5-20250929" {
t.Errorf("GetDefaultModel() = %q, want %q", got, "claude-sonnet-4-5-20250929")
}
}
func TestProvider_NewProviderWithBaseURL_NormalizesV1Suffix(t *testing.T) {
p := NewProviderWithBaseURL("token", "https://api.anthropic.com/v1/")
if got := p.BaseURL(); got != "https://api.anthropic.com" {
t.Fatalf("BaseURL() = %q, want %q", got, "https://api.anthropic.com")
}
}
func TestProvider_ChatUsesTokenSource(t *testing.T) {
var requests int32
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/v1/messages" {
http.Error(w, "not found", http.StatusNotFound)
return
}
atomic.AddInt32(&requests, 1)
if got := r.Header.Get("Authorization"); got != "Bearer refreshed-token" {
http.Error(w, "unauthorized", http.StatusUnauthorized)
return
}
var reqBody map[string]interface{}
json.NewDecoder(r.Body).Decode(&reqBody)
resp := map[string]interface{}{
"id": "msg_test",
"type": "message",
"role": "assistant",
"model": reqBody["model"],
"stop_reason": "end_turn",
"content": []map[string]interface{}{
{"type": "text", "text": "ok"},
},
"usage": map[string]interface{}{
"input_tokens": 1,
"output_tokens": 1,
},
}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(resp)
}))
defer server.Close()
p := NewProviderWithTokenSourceAndBaseURL("stale-token", func() (string, error) {
return "refreshed-token", nil
}, server.URL)
_, err := p.Chat(t.Context(), []Message{{Role: "user", Content: "hello"}}, nil, "claude-sonnet-4-5-20250929", map[string]interface{}{})
if err != nil {
t.Fatalf("Chat() error: %v", err)
}
if got := atomic.LoadInt32(&requests); got != 1 {
t.Fatalf("requests = %d, want 1", got)
}
}
func createAnthropicTestClient(baseURL, token string) *anthropic.Client {
c := anthropic.NewClient(
anthropicoption.WithAuthToken(token),
anthropicoption.WithBaseURL(baseURL),
)
return &c
}
+28 -171
View File
@@ -2,200 +2,57 @@ package providers
import (
"context"
"encoding/json"
"fmt"
"github.com/anthropics/anthropic-sdk-go"
"github.com/anthropics/anthropic-sdk-go/option"
"github.com/sipeed/picoclaw/pkg/auth"
anthropicprovider "github.com/sipeed/picoclaw/pkg/providers/anthropic"
)
type ClaudeProvider struct {
client *anthropic.Client
tokenSource func() (string, error)
delegate *anthropicprovider.Provider
}
func NewClaudeProvider(token string) *ClaudeProvider {
client := anthropic.NewClient(
option.WithAuthToken(token),
option.WithBaseURL("https://api.anthropic.com"),
)
return &ClaudeProvider{client: &client}
return &ClaudeProvider{
delegate: anthropicprovider.NewProvider(token),
}
}
func NewClaudeProviderWithBaseURL(token, apiBase string) *ClaudeProvider {
return &ClaudeProvider{
delegate: anthropicprovider.NewProviderWithBaseURL(token, apiBase),
}
}
func NewClaudeProviderWithTokenSource(token string, tokenSource func() (string, error)) *ClaudeProvider {
p := NewClaudeProvider(token)
p.tokenSource = tokenSource
return p
return &ClaudeProvider{
delegate: anthropicprovider.NewProviderWithTokenSource(token, tokenSource),
}
}
func NewClaudeProviderWithTokenSourceAndBaseURL(token string, tokenSource func() (string, error), apiBase string) *ClaudeProvider {
return &ClaudeProvider{
delegate: anthropicprovider.NewProviderWithTokenSourceAndBaseURL(token, tokenSource, apiBase),
}
}
func newClaudeProviderWithDelegate(delegate *anthropicprovider.Provider) *ClaudeProvider {
return &ClaudeProvider{delegate: delegate}
}
func (p *ClaudeProvider) Chat(ctx context.Context, messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) (*LLMResponse, error) {
var opts []option.RequestOption
if p.tokenSource != nil {
tok, err := p.tokenSource()
if err != nil {
return nil, fmt.Errorf("refreshing token: %w", err)
}
opts = append(opts, option.WithAuthToken(tok))
}
params, err := buildClaudeParams(messages, tools, model, options)
resp, err := p.delegate.Chat(ctx, messages, tools, model, options)
if err != nil {
return nil, err
}
resp, err := p.client.Messages.New(ctx, params, opts...)
if err != nil {
return nil, fmt.Errorf("claude API call: %w", err)
}
return parseClaudeResponse(resp), nil
return resp, nil
}
func (p *ClaudeProvider) GetDefaultModel() string {
return "claude-sonnet-4-5-20250929"
}
func buildClaudeParams(messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) (anthropic.MessageNewParams, error) {
var system []anthropic.TextBlockParam
var anthropicMessages []anthropic.MessageParam
for _, msg := range messages {
switch msg.Role {
case "system":
system = append(system, anthropic.TextBlockParam{Text: msg.Content})
case "user":
if msg.ToolCallID != "" {
anthropicMessages = append(anthropicMessages,
anthropic.NewUserMessage(anthropic.NewToolResultBlock(msg.ToolCallID, msg.Content, false)),
)
} else {
anthropicMessages = append(anthropicMessages,
anthropic.NewUserMessage(anthropic.NewTextBlock(msg.Content)),
)
}
case "assistant":
if len(msg.ToolCalls) > 0 {
var blocks []anthropic.ContentBlockParamUnion
if msg.Content != "" {
blocks = append(blocks, anthropic.NewTextBlock(msg.Content))
}
for _, tc := range msg.ToolCalls {
blocks = append(blocks, anthropic.NewToolUseBlock(tc.ID, tc.Arguments, tc.Name))
}
anthropicMessages = append(anthropicMessages, anthropic.NewAssistantMessage(blocks...))
} else {
anthropicMessages = append(anthropicMessages,
anthropic.NewAssistantMessage(anthropic.NewTextBlock(msg.Content)),
)
}
case "tool":
anthropicMessages = append(anthropicMessages,
anthropic.NewUserMessage(anthropic.NewToolResultBlock(msg.ToolCallID, msg.Content, false)),
)
}
}
maxTokens := int64(4096)
if mt, ok := options["max_tokens"].(int); ok {
maxTokens = int64(mt)
}
params := anthropic.MessageNewParams{
Model: anthropic.Model(model),
Messages: anthropicMessages,
MaxTokens: maxTokens,
}
if len(system) > 0 {
params.System = system
}
if temp, ok := options["temperature"].(float64); ok {
params.Temperature = anthropic.Float(temp)
}
if len(tools) > 0 {
params.Tools = translateToolsForClaude(tools)
}
return params, nil
}
func translateToolsForClaude(tools []ToolDefinition) []anthropic.ToolUnionParam {
result := make([]anthropic.ToolUnionParam, 0, len(tools))
for _, t := range tools {
tool := anthropic.ToolParam{
Name: t.Function.Name,
InputSchema: anthropic.ToolInputSchemaParam{
Properties: t.Function.Parameters["properties"],
},
}
if desc := t.Function.Description; desc != "" {
tool.Description = anthropic.String(desc)
}
if req, ok := t.Function.Parameters["required"].([]interface{}); ok {
required := make([]string, 0, len(req))
for _, r := range req {
if s, ok := r.(string); ok {
required = append(required, s)
}
}
tool.InputSchema.Required = required
}
result = append(result, anthropic.ToolUnionParam{OfTool: &tool})
}
return result
}
func parseClaudeResponse(resp *anthropic.Message) *LLMResponse {
var content string
var toolCalls []ToolCall
for _, block := range resp.Content {
switch block.Type {
case "text":
tb := block.AsText()
content += tb.Text
case "tool_use":
tu := block.AsToolUse()
var args map[string]interface{}
if err := json.Unmarshal(tu.Input, &args); err != nil {
args = map[string]interface{}{"raw": string(tu.Input)}
}
toolCalls = append(toolCalls, ToolCall{
ID: tu.ID,
Name: tu.Name,
Arguments: args,
})
}
}
finishReason := "stop"
switch resp.StopReason {
case anthropic.StopReasonToolUse:
finishReason = "tool_calls"
case anthropic.StopReasonMaxTokens:
finishReason = "length"
case anthropic.StopReasonEndTurn:
finishReason = "stop"
}
return &LLMResponse{
Content: content,
ToolCalls: toolCalls,
FinishReason: finishReason,
Usage: &UsageInfo{
PromptTokens: int(resp.Usage.InputTokens),
CompletionTokens: int(resp.Usage.OutputTokens),
TotalTokens: int(resp.Usage.InputTokens + resp.Usage.OutputTokens),
},
}
return p.delegate.GetDefaultModel()
}
func createClaudeTokenSource() func() (string, error) {
return func() (string, error) {
cred, err := auth.GetCredential("anthropic")
cred, err := getCredential("anthropic")
if err != nil {
return "", fmt.Errorf("loading auth credentials: %w", err)
}
+3 -134
View File
@@ -8,140 +8,9 @@ import (
"github.com/anthropics/anthropic-sdk-go"
anthropicoption "github.com/anthropics/anthropic-sdk-go/option"
anthropicprovider "github.com/sipeed/picoclaw/pkg/providers/anthropic"
)
func TestBuildClaudeParams_BasicMessage(t *testing.T) {
messages := []Message{
{Role: "user", Content: "Hello"},
}
params, err := buildClaudeParams(messages, nil, "claude-sonnet-4-5-20250929", map[string]interface{}{
"max_tokens": 1024,
})
if err != nil {
t.Fatalf("buildClaudeParams() error: %v", err)
}
if string(params.Model) != "claude-sonnet-4-5-20250929" {
t.Errorf("Model = %q, want %q", params.Model, "claude-sonnet-4-5-20250929")
}
if params.MaxTokens != 1024 {
t.Errorf("MaxTokens = %d, want 1024", params.MaxTokens)
}
if len(params.Messages) != 1 {
t.Fatalf("len(Messages) = %d, want 1", len(params.Messages))
}
}
func TestBuildClaudeParams_SystemMessage(t *testing.T) {
messages := []Message{
{Role: "system", Content: "You are helpful"},
{Role: "user", Content: "Hi"},
}
params, err := buildClaudeParams(messages, nil, "claude-sonnet-4-5-20250929", map[string]interface{}{})
if err != nil {
t.Fatalf("buildClaudeParams() error: %v", err)
}
if len(params.System) != 1 {
t.Fatalf("len(System) = %d, want 1", len(params.System))
}
if params.System[0].Text != "You are helpful" {
t.Errorf("System[0].Text = %q, want %q", params.System[0].Text, "You are helpful")
}
if len(params.Messages) != 1 {
t.Fatalf("len(Messages) = %d, want 1", len(params.Messages))
}
}
func TestBuildClaudeParams_ToolCallMessage(t *testing.T) {
messages := []Message{
{Role: "user", Content: "What's the weather?"},
{
Role: "assistant",
Content: "",
ToolCalls: []ToolCall{
{
ID: "call_1",
Name: "get_weather",
Arguments: map[string]interface{}{"city": "SF"},
},
},
},
{Role: "tool", Content: `{"temp": 72}`, ToolCallID: "call_1"},
}
params, err := buildClaudeParams(messages, nil, "claude-sonnet-4-5-20250929", map[string]interface{}{})
if err != nil {
t.Fatalf("buildClaudeParams() error: %v", err)
}
if len(params.Messages) != 3 {
t.Fatalf("len(Messages) = %d, want 3", len(params.Messages))
}
}
func TestBuildClaudeParams_WithTools(t *testing.T) {
tools := []ToolDefinition{
{
Type: "function",
Function: ToolFunctionDefinition{
Name: "get_weather",
Description: "Get weather for a city",
Parameters: map[string]interface{}{
"type": "object",
"properties": map[string]interface{}{
"city": map[string]interface{}{"type": "string"},
},
"required": []interface{}{"city"},
},
},
},
}
params, err := buildClaudeParams([]Message{{Role: "user", Content: "Hi"}}, tools, "claude-sonnet-4-5-20250929", map[string]interface{}{})
if err != nil {
t.Fatalf("buildClaudeParams() error: %v", err)
}
if len(params.Tools) != 1 {
t.Fatalf("len(Tools) = %d, want 1", len(params.Tools))
}
}
func TestParseClaudeResponse_TextOnly(t *testing.T) {
resp := &anthropic.Message{
Content: []anthropic.ContentBlockUnion{},
Usage: anthropic.Usage{
InputTokens: 10,
OutputTokens: 20,
},
}
result := parseClaudeResponse(resp)
if result.Usage.PromptTokens != 10 {
t.Errorf("PromptTokens = %d, want 10", result.Usage.PromptTokens)
}
if result.Usage.CompletionTokens != 20 {
t.Errorf("CompletionTokens = %d, want 20", result.Usage.CompletionTokens)
}
if result.FinishReason != "stop" {
t.Errorf("FinishReason = %q, want %q", result.FinishReason, "stop")
}
}
func TestParseClaudeResponse_StopReasons(t *testing.T) {
tests := []struct {
stopReason anthropic.StopReason
want string
}{
{anthropic.StopReasonEndTurn, "stop"},
{anthropic.StopReasonMaxTokens, "length"},
{anthropic.StopReasonToolUse, "tool_calls"},
}
for _, tt := range tests {
resp := &anthropic.Message{
StopReason: tt.stopReason,
}
result := parseClaudeResponse(resp)
if result.FinishReason != tt.want {
t.Errorf("StopReason %q: FinishReason = %q, want %q", tt.stopReason, result.FinishReason, tt.want)
}
}
}
func TestClaudeProvider_ChatRoundTrip(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/v1/messages" {
@@ -175,8 +44,8 @@ func TestClaudeProvider_ChatRoundTrip(t *testing.T) {
}))
defer server.Close()
provider := NewClaudeProvider("test-token")
provider.client = createAnthropicTestClient(server.URL, "test-token")
delegate := anthropicprovider.NewProviderWithClient(createAnthropicTestClient(server.URL, "test-token"))
provider := newClaudeProviderWithDelegate(delegate)
messages := []Message{{Role: "user", Content: "Hello"}}
resp, err := provider.Chat(t.Context(), messages, nil, "claude-sonnet-4-5-20250929", map[string]interface{}{"max_tokens": 1024})
+360
View File
@@ -0,0 +1,360 @@
package providers
import (
"fmt"
"strings"
"github.com/sipeed/picoclaw/pkg/auth"
"github.com/sipeed/picoclaw/pkg/config"
)
const defaultAnthropicAPIBase = "https://api.anthropic.com/v1"
var getCredential = auth.GetCredential
type providerType int
const (
providerTypeHTTPCompat providerType = iota
providerTypeClaudeAuth
providerTypeCodexAuth
providerTypeCodexCLIToken
providerTypeClaudeCLI
providerTypeCodexCLI
providerTypeGitHubCopilot
)
type providerSelection struct {
providerType providerType
apiKey string
apiBase string
proxy string
model string
workspace string
connectMode string
enableWebSearch bool
}
func createClaudeAuthProvider(apiBase string) (LLMProvider, error) {
if apiBase == "" {
apiBase = defaultAnthropicAPIBase
}
cred, err := getCredential("anthropic")
if err != nil {
return nil, fmt.Errorf("loading auth credentials: %w", err)
}
if cred == nil {
return nil, fmt.Errorf("no credentials for anthropic. Run: picoclaw auth login --provider anthropic")
}
return NewClaudeProviderWithTokenSourceAndBaseURL(cred.AccessToken, createClaudeTokenSource(), apiBase), nil
}
func createCodexAuthProvider(enableWebSearch bool) (LLMProvider, error) {
cred, err := getCredential("openai")
if err != nil {
return nil, fmt.Errorf("loading auth credentials: %w", err)
}
if cred == nil {
return nil, fmt.Errorf("no credentials for openai. Run: picoclaw auth login --provider openai")
}
p := NewCodexProviderWithTokenSource(cred.AccessToken, cred.AccountID, createCodexTokenSource())
p.enableWebSearch = enableWebSearch
return p, nil
}
func resolveProviderSelection(cfg *config.Config) (providerSelection, error) {
model := cfg.Agents.Defaults.Model
providerName := strings.ToLower(cfg.Agents.Defaults.Provider)
lowerModel := strings.ToLower(model)
sel := providerSelection{
providerType: providerTypeHTTPCompat,
model: model,
}
// First, prefer explicit provider configuration.
if providerName != "" {
switch providerName {
case "groq":
if cfg.Providers.Groq.APIKey != "" {
sel.apiKey = cfg.Providers.Groq.APIKey
sel.apiBase = cfg.Providers.Groq.APIBase
sel.proxy = cfg.Providers.Groq.Proxy
if sel.apiBase == "" {
sel.apiBase = "https://api.groq.com/openai/v1"
}
}
case "openai", "gpt":
if cfg.Providers.OpenAI.APIKey != "" || cfg.Providers.OpenAI.AuthMethod != "" {
sel.enableWebSearch = cfg.Providers.OpenAI.WebSearch
if cfg.Providers.OpenAI.AuthMethod == "codex-cli" {
sel.providerType = providerTypeCodexCLIToken
return sel, nil
}
if cfg.Providers.OpenAI.AuthMethod == "oauth" || cfg.Providers.OpenAI.AuthMethod == "token" {
sel.providerType = providerTypeCodexAuth
return sel, nil
}
sel.apiKey = cfg.Providers.OpenAI.APIKey
sel.apiBase = cfg.Providers.OpenAI.APIBase
sel.proxy = cfg.Providers.OpenAI.Proxy
if sel.apiBase == "" {
sel.apiBase = "https://api.openai.com/v1"
}
}
case "anthropic", "claude":
if cfg.Providers.Anthropic.APIKey != "" || cfg.Providers.Anthropic.AuthMethod != "" {
if cfg.Providers.Anthropic.AuthMethod == "oauth" || cfg.Providers.Anthropic.AuthMethod == "token" {
sel.apiBase = cfg.Providers.Anthropic.APIBase
if sel.apiBase == "" {
sel.apiBase = defaultAnthropicAPIBase
}
sel.providerType = providerTypeClaudeAuth
return sel, nil
}
sel.apiKey = cfg.Providers.Anthropic.APIKey
sel.apiBase = cfg.Providers.Anthropic.APIBase
sel.proxy = cfg.Providers.Anthropic.Proxy
if sel.apiBase == "" {
sel.apiBase = defaultAnthropicAPIBase
}
}
case "openrouter":
if cfg.Providers.OpenRouter.APIKey != "" {
sel.apiKey = cfg.Providers.OpenRouter.APIKey
sel.proxy = cfg.Providers.OpenRouter.Proxy
if cfg.Providers.OpenRouter.APIBase != "" {
sel.apiBase = cfg.Providers.OpenRouter.APIBase
} else {
sel.apiBase = "https://openrouter.ai/api/v1"
}
}
case "zhipu", "glm":
if cfg.Providers.Zhipu.APIKey != "" {
sel.apiKey = cfg.Providers.Zhipu.APIKey
sel.apiBase = cfg.Providers.Zhipu.APIBase
sel.proxy = cfg.Providers.Zhipu.Proxy
if sel.apiBase == "" {
sel.apiBase = "https://open.bigmodel.cn/api/paas/v4"
}
}
case "gemini", "google":
if cfg.Providers.Gemini.APIKey != "" {
sel.apiKey = cfg.Providers.Gemini.APIKey
sel.apiBase = cfg.Providers.Gemini.APIBase
sel.proxy = cfg.Providers.Gemini.Proxy
if sel.apiBase == "" {
sel.apiBase = "https://generativelanguage.googleapis.com/v1beta"
}
}
case "vllm":
if cfg.Providers.VLLM.APIBase != "" {
sel.apiKey = cfg.Providers.VLLM.APIKey
sel.apiBase = cfg.Providers.VLLM.APIBase
sel.proxy = cfg.Providers.VLLM.Proxy
}
case "shengsuanyun":
if cfg.Providers.ShengSuanYun.APIKey != "" {
sel.apiKey = cfg.Providers.ShengSuanYun.APIKey
sel.apiBase = cfg.Providers.ShengSuanYun.APIBase
sel.proxy = cfg.Providers.ShengSuanYun.Proxy
if sel.apiBase == "" {
sel.apiBase = "https://router.shengsuanyun.com/api/v1"
}
}
case "nvidia":
if cfg.Providers.Nvidia.APIKey != "" {
sel.apiKey = cfg.Providers.Nvidia.APIKey
sel.apiBase = cfg.Providers.Nvidia.APIBase
sel.proxy = cfg.Providers.Nvidia.Proxy
if sel.apiBase == "" {
sel.apiBase = "https://integrate.api.nvidia.com/v1"
}
}
case "claude-cli", "claude-code", "claudecode":
workspace := cfg.WorkspacePath()
if workspace == "" {
workspace = "."
}
sel.providerType = providerTypeClaudeCLI
sel.workspace = workspace
return sel, nil
case "codex-cli", "codex-code":
workspace := cfg.WorkspacePath()
if workspace == "" {
workspace = "."
}
sel.providerType = providerTypeCodexCLI
sel.workspace = workspace
return sel, nil
case "deepseek":
if cfg.Providers.DeepSeek.APIKey != "" {
sel.apiKey = cfg.Providers.DeepSeek.APIKey
sel.apiBase = cfg.Providers.DeepSeek.APIBase
sel.proxy = cfg.Providers.DeepSeek.Proxy
if sel.apiBase == "" {
sel.apiBase = "https://api.deepseek.com/v1"
}
if model != "deepseek-chat" && model != "deepseek-reasoner" {
sel.model = "deepseek-chat"
}
}
case "github_copilot", "copilot":
sel.providerType = providerTypeGitHubCopilot
if cfg.Providers.GitHubCopilot.APIBase != "" {
sel.apiBase = cfg.Providers.GitHubCopilot.APIBase
} else {
sel.apiBase = "localhost:4321"
}
sel.connectMode = cfg.Providers.GitHubCopilot.ConnectMode
return sel, nil
}
}
// Fallback: infer provider from model and configured keys.
if sel.apiKey == "" && sel.apiBase == "" {
switch {
case (strings.Contains(lowerModel, "kimi") || strings.Contains(lowerModel, "moonshot") || strings.HasPrefix(model, "moonshot/")) && cfg.Providers.Moonshot.APIKey != "":
sel.apiKey = cfg.Providers.Moonshot.APIKey
sel.apiBase = cfg.Providers.Moonshot.APIBase
sel.proxy = cfg.Providers.Moonshot.Proxy
if sel.apiBase == "" {
sel.apiBase = "https://api.moonshot.cn/v1"
}
case strings.HasPrefix(model, "openrouter/") ||
strings.HasPrefix(model, "anthropic/") ||
strings.HasPrefix(model, "openai/") ||
strings.HasPrefix(model, "meta-llama/") ||
strings.HasPrefix(model, "deepseek/") ||
strings.HasPrefix(model, "google/"):
sel.apiKey = cfg.Providers.OpenRouter.APIKey
sel.proxy = cfg.Providers.OpenRouter.Proxy
if cfg.Providers.OpenRouter.APIBase != "" {
sel.apiBase = cfg.Providers.OpenRouter.APIBase
} else {
sel.apiBase = "https://openrouter.ai/api/v1"
}
case (strings.Contains(lowerModel, "claude") || strings.HasPrefix(model, "anthropic/")) &&
(cfg.Providers.Anthropic.APIKey != "" || cfg.Providers.Anthropic.AuthMethod != ""):
if cfg.Providers.Anthropic.AuthMethod == "oauth" || cfg.Providers.Anthropic.AuthMethod == "token" {
sel.apiBase = cfg.Providers.Anthropic.APIBase
if sel.apiBase == "" {
sel.apiBase = defaultAnthropicAPIBase
}
sel.providerType = providerTypeClaudeAuth
return sel, nil
}
sel.apiKey = cfg.Providers.Anthropic.APIKey
sel.apiBase = cfg.Providers.Anthropic.APIBase
sel.proxy = cfg.Providers.Anthropic.Proxy
if sel.apiBase == "" {
sel.apiBase = defaultAnthropicAPIBase
}
case (strings.Contains(lowerModel, "gpt") || strings.HasPrefix(model, "openai/")) &&
(cfg.Providers.OpenAI.APIKey != "" || cfg.Providers.OpenAI.AuthMethod != ""):
sel.enableWebSearch = cfg.Providers.OpenAI.WebSearch
if cfg.Providers.OpenAI.AuthMethod == "codex-cli" {
sel.providerType = providerTypeCodexCLIToken
return sel, nil
}
if cfg.Providers.OpenAI.AuthMethod == "oauth" || cfg.Providers.OpenAI.AuthMethod == "token" {
sel.providerType = providerTypeCodexAuth
return sel, nil
}
sel.apiKey = cfg.Providers.OpenAI.APIKey
sel.apiBase = cfg.Providers.OpenAI.APIBase
sel.proxy = cfg.Providers.OpenAI.Proxy
if sel.apiBase == "" {
sel.apiBase = "https://api.openai.com/v1"
}
case (strings.Contains(lowerModel, "gemini") || strings.HasPrefix(model, "google/")) && cfg.Providers.Gemini.APIKey != "":
sel.apiKey = cfg.Providers.Gemini.APIKey
sel.apiBase = cfg.Providers.Gemini.APIBase
sel.proxy = cfg.Providers.Gemini.Proxy
if sel.apiBase == "" {
sel.apiBase = "https://generativelanguage.googleapis.com/v1beta"
}
case (strings.Contains(lowerModel, "glm") || strings.Contains(lowerModel, "zhipu") || strings.Contains(lowerModel, "zai")) && cfg.Providers.Zhipu.APIKey != "":
sel.apiKey = cfg.Providers.Zhipu.APIKey
sel.apiBase = cfg.Providers.Zhipu.APIBase
sel.proxy = cfg.Providers.Zhipu.Proxy
if sel.apiBase == "" {
sel.apiBase = "https://open.bigmodel.cn/api/paas/v4"
}
case (strings.Contains(lowerModel, "groq") || strings.HasPrefix(model, "groq/")) && cfg.Providers.Groq.APIKey != "":
sel.apiKey = cfg.Providers.Groq.APIKey
sel.apiBase = cfg.Providers.Groq.APIBase
sel.proxy = cfg.Providers.Groq.Proxy
if sel.apiBase == "" {
sel.apiBase = "https://api.groq.com/openai/v1"
}
case (strings.Contains(lowerModel, "nvidia") || strings.HasPrefix(model, "nvidia/")) && cfg.Providers.Nvidia.APIKey != "":
sel.apiKey = cfg.Providers.Nvidia.APIKey
sel.apiBase = cfg.Providers.Nvidia.APIBase
sel.proxy = cfg.Providers.Nvidia.Proxy
if sel.apiBase == "" {
sel.apiBase = "https://integrate.api.nvidia.com/v1"
}
case (strings.Contains(lowerModel, "ollama") || strings.HasPrefix(model, "ollama/")) && cfg.Providers.Ollama.APIKey != "":
sel.apiKey = cfg.Providers.Ollama.APIKey
sel.apiBase = cfg.Providers.Ollama.APIBase
sel.proxy = cfg.Providers.Ollama.Proxy
if sel.apiBase == "" {
sel.apiBase = "http://localhost:11434/v1"
}
case cfg.Providers.VLLM.APIBase != "":
sel.apiKey = cfg.Providers.VLLM.APIKey
sel.apiBase = cfg.Providers.VLLM.APIBase
sel.proxy = cfg.Providers.VLLM.Proxy
default:
if cfg.Providers.OpenRouter.APIKey != "" {
sel.apiKey = cfg.Providers.OpenRouter.APIKey
sel.proxy = cfg.Providers.OpenRouter.Proxy
if cfg.Providers.OpenRouter.APIBase != "" {
sel.apiBase = cfg.Providers.OpenRouter.APIBase
} else {
sel.apiBase = "https://openrouter.ai/api/v1"
}
} else {
return providerSelection{}, fmt.Errorf("no API key configured for model: %s", model)
}
}
}
if sel.providerType == providerTypeHTTPCompat {
if sel.apiKey == "" && !strings.HasPrefix(model, "bedrock/") {
return providerSelection{}, fmt.Errorf("no API key configured for provider (model: %s)", model)
}
if sel.apiBase == "" {
return providerSelection{}, fmt.Errorf("no API base configured for provider (model: %s)", model)
}
}
return sel, nil
}
func CreateProvider(cfg *config.Config) (LLMProvider, error) {
sel, err := resolveProviderSelection(cfg)
if err != nil {
return nil, err
}
switch sel.providerType {
case providerTypeClaudeAuth:
return createClaudeAuthProvider(sel.apiBase)
case providerTypeCodexAuth:
return createCodexAuthProvider(sel.enableWebSearch)
case providerTypeCodexCLIToken:
c := NewCodexProviderWithTokenSource("", "", CreateCodexCliTokenSource())
c.enableWebSearch = sel.enableWebSearch
return c, nil
case providerTypeClaudeCLI:
return NewClaudeCliProvider(sel.workspace), nil
case providerTypeCodexCLI:
return NewCodexCliProvider(sel.workspace), nil
case providerTypeGitHubCopilot:
return NewGitHubCopilotProvider(sel.apiBase, sel.connectMode, sel.model)
default:
return NewHTTPProvider(sel.apiKey, sel.apiBase, sel.proxy), nil
}
}
+299
View File
@@ -0,0 +1,299 @@
package providers
import (
"strings"
"testing"
"github.com/sipeed/picoclaw/pkg/auth"
"github.com/sipeed/picoclaw/pkg/config"
)
func TestResolveProviderSelection(t *testing.T) {
tests := []struct {
name string
setup func(*config.Config)
wantType providerType
wantAPIBase string
wantProxy string
wantErrSubstr string
}{
{
name: "explicit claude-cli provider routes to cli provider type",
setup: func(cfg *config.Config) {
cfg.Agents.Defaults.Provider = "claude-cli"
cfg.Agents.Defaults.Workspace = "/tmp/ws"
},
wantType: providerTypeClaudeCLI,
},
{
name: "explicit copilot provider routes to github copilot type",
setup: func(cfg *config.Config) {
cfg.Agents.Defaults.Provider = "copilot"
},
wantType: providerTypeGitHubCopilot,
wantAPIBase: "localhost:4321",
},
{
name: "explicit deepseek provider uses deepseek defaults",
setup: func(cfg *config.Config) {
cfg.Agents.Defaults.Provider = "deepseek"
cfg.Agents.Defaults.Model = "deepseek/deepseek-chat"
cfg.Providers.DeepSeek.APIKey = "deepseek-key"
cfg.Providers.DeepSeek.Proxy = "http://127.0.0.1:7890"
},
wantType: providerTypeHTTPCompat,
wantAPIBase: "https://api.deepseek.com/v1",
wantProxy: "http://127.0.0.1:7890",
},
{
name: "explicit shengsuanyun provider uses defaults",
setup: func(cfg *config.Config) {
cfg.Agents.Defaults.Provider = "shengsuanyun"
cfg.Providers.ShengSuanYun.APIKey = "ssy-key"
cfg.Providers.ShengSuanYun.Proxy = "http://127.0.0.1:7890"
},
wantType: providerTypeHTTPCompat,
wantAPIBase: "https://router.shengsuanyun.com/api/v1",
wantProxy: "http://127.0.0.1:7890",
},
{
name: "explicit nvidia provider uses defaults",
setup: func(cfg *config.Config) {
cfg.Agents.Defaults.Provider = "nvidia"
cfg.Providers.Nvidia.APIKey = "nvapi-test"
cfg.Providers.Nvidia.Proxy = "http://127.0.0.1:7890"
},
wantType: providerTypeHTTPCompat,
wantAPIBase: "https://integrate.api.nvidia.com/v1",
wantProxy: "http://127.0.0.1:7890",
},
{
name: "openrouter model uses openrouter defaults",
setup: func(cfg *config.Config) {
cfg.Agents.Defaults.Model = "openrouter/auto"
cfg.Providers.OpenRouter.APIKey = "sk-or-test"
},
wantType: providerTypeHTTPCompat,
wantAPIBase: "https://openrouter.ai/api/v1",
},
{
name: "anthropic oauth routes to claude auth provider",
setup: func(cfg *config.Config) {
cfg.Agents.Defaults.Model = "claude-sonnet-4-5-20250929"
cfg.Providers.Anthropic.AuthMethod = "oauth"
},
wantType: providerTypeClaudeAuth,
},
{
name: "openai oauth routes to codex auth provider",
setup: func(cfg *config.Config) {
cfg.Agents.Defaults.Model = "gpt-4o"
cfg.Providers.OpenAI.AuthMethod = "oauth"
},
wantType: providerTypeCodexAuth,
},
{
name: "openai codex-cli auth routes to codex cli token provider",
setup: func(cfg *config.Config) {
cfg.Agents.Defaults.Model = "gpt-4o"
cfg.Providers.OpenAI.AuthMethod = "codex-cli"
},
wantType: providerTypeCodexCLIToken,
},
{
name: "explicit codex-code provider routes to codex cli provider type",
setup: func(cfg *config.Config) {
cfg.Agents.Defaults.Provider = "codex-code"
cfg.Agents.Defaults.Workspace = "/tmp/ws"
},
wantType: providerTypeCodexCLI,
},
{
name: "zhipu model uses zhipu base default",
setup: func(cfg *config.Config) {
cfg.Agents.Defaults.Model = "glm-4.7"
cfg.Providers.Zhipu.APIKey = "zhipu-key"
},
wantType: providerTypeHTTPCompat,
wantAPIBase: "https://open.bigmodel.cn/api/paas/v4",
},
{
name: "groq model uses groq base default",
setup: func(cfg *config.Config) {
cfg.Agents.Defaults.Model = "groq/llama-3.3-70b"
cfg.Providers.Groq.APIKey = "gsk-key"
},
wantType: providerTypeHTTPCompat,
wantAPIBase: "https://api.groq.com/openai/v1",
},
{
name: "ollama model uses ollama base default",
setup: func(cfg *config.Config) {
cfg.Agents.Defaults.Model = "ollama/qwen2.5:14b"
cfg.Providers.Ollama.APIKey = "ollama-key"
},
wantType: providerTypeHTTPCompat,
wantAPIBase: "http://localhost:11434/v1",
},
{
name: "moonshot model keeps proxy and default base",
setup: func(cfg *config.Config) {
cfg.Agents.Defaults.Model = "moonshot/kimi-k2.5"
cfg.Providers.Moonshot.APIKey = "moonshot-key"
cfg.Providers.Moonshot.Proxy = "http://127.0.0.1:7890"
},
wantType: providerTypeHTTPCompat,
wantAPIBase: "https://api.moonshot.cn/v1",
wantProxy: "http://127.0.0.1:7890",
},
{
name: "missing keys returns model config error",
setup: func(cfg *config.Config) {
cfg.Agents.Defaults.Model = "custom-model"
},
wantErrSubstr: "no API key configured for model",
},
{
name: "openrouter prefix without key returns provider key error",
setup: func(cfg *config.Config) {
cfg.Agents.Defaults.Model = "openrouter/auto"
},
wantErrSubstr: "no API key configured for provider",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
cfg := config.DefaultConfig()
tt.setup(cfg)
got, err := resolveProviderSelection(cfg)
if tt.wantErrSubstr != "" {
if err == nil {
t.Fatalf("expected error containing %q, got nil", tt.wantErrSubstr)
}
if !strings.Contains(err.Error(), tt.wantErrSubstr) {
t.Fatalf("error = %q, want substring %q", err.Error(), tt.wantErrSubstr)
}
return
}
if err != nil {
t.Fatalf("resolveProviderSelection() error = %v", err)
}
if got.providerType != tt.wantType {
t.Fatalf("providerType = %v, want %v", got.providerType, tt.wantType)
}
if tt.wantAPIBase != "" && got.apiBase != tt.wantAPIBase {
t.Fatalf("apiBase = %q, want %q", got.apiBase, tt.wantAPIBase)
}
if tt.wantProxy != "" && got.proxy != tt.wantProxy {
t.Fatalf("proxy = %q, want %q", got.proxy, tt.wantProxy)
}
})
}
}
func TestCreateProviderReturnsHTTPProviderForOpenRouter(t *testing.T) {
cfg := config.DefaultConfig()
cfg.Agents.Defaults.Model = "openrouter/auto"
cfg.Providers.OpenRouter.APIKey = "sk-or-test"
provider, err := CreateProvider(cfg)
if err != nil {
t.Fatalf("CreateProvider() error = %v", err)
}
if _, ok := provider.(*HTTPProvider); !ok {
t.Fatalf("provider type = %T, want *HTTPProvider", provider)
}
}
func TestCreateProviderReturnsCodexCliProviderForCodexCode(t *testing.T) {
cfg := config.DefaultConfig()
cfg.Agents.Defaults.Provider = "codex-code"
provider, err := CreateProvider(cfg)
if err != nil {
t.Fatalf("CreateProvider() error = %v", err)
}
if _, ok := provider.(*CodexCliProvider); !ok {
t.Fatalf("provider type = %T, want *CodexCliProvider", provider)
}
}
func TestCreateProviderReturnsCodexProviderForCodexCliAuthMethod(t *testing.T) {
cfg := config.DefaultConfig()
cfg.Agents.Defaults.Provider = "openai"
cfg.Providers.OpenAI.AuthMethod = "codex-cli"
provider, err := CreateProvider(cfg)
if err != nil {
t.Fatalf("CreateProvider() error = %v", err)
}
if _, ok := provider.(*CodexProvider); !ok {
t.Fatalf("provider type = %T, want *CodexProvider", provider)
}
}
func TestCreateProviderReturnsClaudeProviderForAnthropicOAuth(t *testing.T) {
originalGetCredential := getCredential
t.Cleanup(func() { getCredential = originalGetCredential })
getCredential = func(provider string) (*auth.AuthCredential, error) {
if provider != "anthropic" {
t.Fatalf("provider = %q, want anthropic", provider)
}
return &auth.AuthCredential{
AccessToken: "anthropic-token",
}, nil
}
cfg := config.DefaultConfig()
cfg.Agents.Defaults.Provider = "anthropic"
cfg.Providers.Anthropic.AuthMethod = "oauth"
cfg.Providers.Anthropic.APIBase = "https://proxy.example.com/v1"
provider, err := CreateProvider(cfg)
if err != nil {
t.Fatalf("CreateProvider() error = %v", err)
}
claudeProvider, ok := provider.(*ClaudeProvider)
if !ok {
t.Fatalf("provider type = %T, want *ClaudeProvider", provider)
}
if got := claudeProvider.delegate.BaseURL(); got != "https://proxy.example.com" {
t.Fatalf("anthropic baseURL = %q, want %q", got, "https://proxy.example.com")
}
}
func TestCreateProviderReturnsCodexProviderForOpenAIOAuth(t *testing.T) {
originalGetCredential := getCredential
t.Cleanup(func() { getCredential = originalGetCredential })
getCredential = func(provider string) (*auth.AuthCredential, error) {
if provider != "openai" {
t.Fatalf("provider = %q, want openai", provider)
}
return &auth.AuthCredential{
AccessToken: "openai-token",
AccountID: "acct_123",
}, nil
}
cfg := config.DefaultConfig()
cfg.Agents.Defaults.Provider = "openai"
cfg.Providers.OpenAI.AuthMethod = "oauth"
provider, err := CreateProvider(cfg)
if err != nil {
t.Fatalf("CreateProvider() error = %v", err)
}
if _, ok := provider.(*CodexProvider); !ok {
t.Fatalf("provider type = %T, want *CodexProvider", provider)
}
}
+4 -428
View File
@@ -7,448 +7,24 @@
package providers
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"strings"
"time"
"github.com/sipeed/picoclaw/pkg/auth"
"github.com/sipeed/picoclaw/pkg/config"
"github.com/sipeed/picoclaw/pkg/providers/openai_compat"
)
type HTTPProvider struct {
apiKey string
apiBase string
httpClient *http.Client
delegate *openai_compat.Provider
}
func NewHTTPProvider(apiKey, apiBase, proxy string) *HTTPProvider {
client := &http.Client{
Timeout: 120 * time.Second,
}
if proxy != "" {
proxyURL, err := url.Parse(proxy)
if err == nil {
client.Transport = &http.Transport{
Proxy: http.ProxyURL(proxyURL),
}
}
}
return &HTTPProvider{
apiKey: apiKey,
apiBase: strings.TrimRight(apiBase, "/"),
httpClient: client,
delegate: openai_compat.NewProvider(apiKey, apiBase, proxy),
}
}
func (p *HTTPProvider) Chat(ctx context.Context, messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) (*LLMResponse, error) {
if p.apiBase == "" {
return nil, fmt.Errorf("API base not configured")
}
// Strip provider prefix from model name (e.g., moonshot/kimi-k2.5 -> kimi-k2.5, groq/openai/gpt-oss-120b -> openai/gpt-oss-120b, ollama/qwen2.5:14b -> qwen2.5:14b)
if idx := strings.Index(model, "/"); idx != -1 {
prefix := model[:idx]
if prefix == "moonshot" || prefix == "nvidia" || prefix == "groq" || prefix == "ollama" {
model = model[idx+1:]
}
}
requestBody := map[string]interface{}{
"model": model,
"messages": messages,
}
if len(tools) > 0 {
requestBody["tools"] = tools
requestBody["tool_choice"] = "auto"
}
if maxTokens, ok := options["max_tokens"].(int); ok {
lowerModel := strings.ToLower(model)
if strings.Contains(lowerModel, "glm") || strings.Contains(lowerModel, "o1") {
requestBody["max_completion_tokens"] = maxTokens
} else {
requestBody["max_tokens"] = maxTokens
}
}
if temperature, ok := options["temperature"].(float64); ok {
lowerModel := strings.ToLower(model)
// Kimi k2 models only support temperature=1
if strings.Contains(lowerModel, "kimi") && strings.Contains(lowerModel, "k2") {
requestBody["temperature"] = 1.0
} else {
requestBody["temperature"] = temperature
}
}
jsonData, err := json.Marshal(requestBody)
if err != nil {
return nil, fmt.Errorf("failed to marshal request: %w", err)
}
req, err := http.NewRequestWithContext(ctx, "POST", p.apiBase+"/chat/completions", bytes.NewReader(jsonData))
if err != nil {
return nil, fmt.Errorf("failed to create request: %w", err)
}
req.Header.Set("Content-Type", "application/json")
if p.apiKey != "" {
req.Header.Set("Authorization", "Bearer "+p.apiKey)
}
resp, err := p.httpClient.Do(req)
if err != nil {
return nil, fmt.Errorf("failed to send request: %w", err)
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("failed to read response: %w", err)
}
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("API request failed:\n Status: %d\n Body: %s", resp.StatusCode, string(body))
}
return p.parseResponse(body)
}
func (p *HTTPProvider) parseResponse(body []byte) (*LLMResponse, error) {
var apiResponse struct {
Choices []struct {
Message struct {
Content string `json:"content"`
ToolCalls []struct {
ID string `json:"id"`
Type string `json:"type"`
Function *struct {
Name string `json:"name"`
Arguments string `json:"arguments"`
} `json:"function"`
} `json:"tool_calls"`
} `json:"message"`
FinishReason string `json:"finish_reason"`
} `json:"choices"`
Usage *UsageInfo `json:"usage"`
}
if err := json.Unmarshal(body, &apiResponse); err != nil {
return nil, fmt.Errorf("failed to unmarshal response: %w", err)
}
if len(apiResponse.Choices) == 0 {
return &LLMResponse{
Content: "",
FinishReason: "stop",
}, nil
}
choice := apiResponse.Choices[0]
toolCalls := make([]ToolCall, 0, len(choice.Message.ToolCalls))
for _, tc := range choice.Message.ToolCalls {
arguments := make(map[string]interface{})
name := ""
// Handle OpenAI format with nested function object
if tc.Type == "function" && tc.Function != nil {
name = tc.Function.Name
if tc.Function.Arguments != "" {
if err := json.Unmarshal([]byte(tc.Function.Arguments), &arguments); err != nil {
arguments["raw"] = tc.Function.Arguments
}
}
} else if tc.Function != nil {
// Legacy format without type field
name = tc.Function.Name
if tc.Function.Arguments != "" {
if err := json.Unmarshal([]byte(tc.Function.Arguments), &arguments); err != nil {
arguments["raw"] = tc.Function.Arguments
}
}
}
toolCalls = append(toolCalls, ToolCall{
ID: tc.ID,
Name: name,
Arguments: arguments,
})
}
return &LLMResponse{
Content: choice.Message.Content,
ToolCalls: toolCalls,
FinishReason: choice.FinishReason,
Usage: apiResponse.Usage,
}, nil
return p.delegate.Chat(ctx, messages, tools, model, options)
}
func (p *HTTPProvider) GetDefaultModel() string {
return ""
}
func createClaudeAuthProvider() (LLMProvider, error) {
cred, err := auth.GetCredential("anthropic")
if err != nil {
return nil, fmt.Errorf("loading auth credentials: %w", err)
}
if cred == nil {
return nil, fmt.Errorf("no credentials for anthropic. Run: picoclaw auth login --provider anthropic")
}
return NewClaudeProviderWithTokenSource(cred.AccessToken, createClaudeTokenSource()), nil
}
func createCodexAuthProvider(enableWebSearch bool) (LLMProvider, error) {
cred, err := auth.GetCredential("openai")
if err != nil {
return nil, fmt.Errorf("loading auth credentials: %w", err)
}
if cred == nil {
return nil, fmt.Errorf("no credentials for openai. Run: picoclaw auth login --provider openai")
}
p := NewCodexProviderWithTokenSource(cred.AccessToken, cred.AccountID, createCodexTokenSource())
p.enableWebSearch = enableWebSearch
return p, nil
}
func CreateProvider(cfg *config.Config) (LLMProvider, error) {
model := cfg.Agents.Defaults.Model
providerName := strings.ToLower(cfg.Agents.Defaults.Provider)
var apiKey, apiBase, proxy string
lowerModel := strings.ToLower(model)
// First, try to use explicitly configured provider
if providerName != "" {
switch providerName {
case "groq":
if cfg.Providers.Groq.APIKey != "" {
apiKey = cfg.Providers.Groq.APIKey
apiBase = cfg.Providers.Groq.APIBase
if apiBase == "" {
apiBase = "https://api.groq.com/openai/v1"
}
}
case "openai", "gpt":
if cfg.Providers.OpenAI.APIKey != "" || cfg.Providers.OpenAI.AuthMethod != "" {
if cfg.Providers.OpenAI.AuthMethod == "codex-cli" {
c := NewCodexProviderWithTokenSource("", "", CreateCodexCliTokenSource())
c.enableWebSearch = cfg.Providers.OpenAI.WebSearch
return c, nil
}
if cfg.Providers.OpenAI.AuthMethod == "oauth" || cfg.Providers.OpenAI.AuthMethod == "token" {
return createCodexAuthProvider(cfg.Providers.OpenAI.WebSearch)
}
apiKey = cfg.Providers.OpenAI.APIKey
apiBase = cfg.Providers.OpenAI.APIBase
if apiBase == "" {
apiBase = "https://api.openai.com/v1"
}
}
case "anthropic", "claude":
if cfg.Providers.Anthropic.APIKey != "" || cfg.Providers.Anthropic.AuthMethod != "" {
if cfg.Providers.Anthropic.AuthMethod == "oauth" || cfg.Providers.Anthropic.AuthMethod == "token" {
return createClaudeAuthProvider()
}
apiKey = cfg.Providers.Anthropic.APIKey
apiBase = cfg.Providers.Anthropic.APIBase
if apiBase == "" {
apiBase = "https://api.anthropic.com/v1"
}
}
case "openrouter":
if cfg.Providers.OpenRouter.APIKey != "" {
apiKey = cfg.Providers.OpenRouter.APIKey
if cfg.Providers.OpenRouter.APIBase != "" {
apiBase = cfg.Providers.OpenRouter.APIBase
} else {
apiBase = "https://openrouter.ai/api/v1"
}
}
case "zhipu", "glm":
if cfg.Providers.Zhipu.APIKey != "" {
apiKey = cfg.Providers.Zhipu.APIKey
apiBase = cfg.Providers.Zhipu.APIBase
if apiBase == "" {
apiBase = "https://open.bigmodel.cn/api/paas/v4"
}
}
case "gemini", "google":
if cfg.Providers.Gemini.APIKey != "" {
apiKey = cfg.Providers.Gemini.APIKey
apiBase = cfg.Providers.Gemini.APIBase
if apiBase == "" {
apiBase = "https://generativelanguage.googleapis.com/v1beta"
}
}
case "vllm":
if cfg.Providers.VLLM.APIBase != "" {
apiKey = cfg.Providers.VLLM.APIKey
apiBase = cfg.Providers.VLLM.APIBase
}
case "shengsuanyun":
if cfg.Providers.ShengSuanYun.APIKey != "" {
apiKey = cfg.Providers.ShengSuanYun.APIKey
apiBase = cfg.Providers.ShengSuanYun.APIBase
if apiBase == "" {
apiBase = "https://router.shengsuanyun.com/api/v1"
}
}
case "claude-cli", "claudecode", "claude-code":
workspace := cfg.WorkspacePath()
if workspace == "" {
workspace = "."
}
return NewClaudeCliProvider(workspace), nil
case "codex-cli", "codex-code":
workspace := cfg.WorkspacePath()
if workspace == "" {
workspace = "."
}
return NewCodexCliProvider(workspace), nil
case "deepseek":
if cfg.Providers.DeepSeek.APIKey != "" {
apiKey = cfg.Providers.DeepSeek.APIKey
apiBase = cfg.Providers.DeepSeek.APIBase
if apiBase == "" {
apiBase = "https://api.deepseek.com/v1"
}
if model != "deepseek-chat" && model != "deepseek-reasoner" {
model = "deepseek-chat"
}
}
case "github_copilot", "copilot":
if cfg.Providers.GitHubCopilot.APIBase != "" {
apiBase = cfg.Providers.GitHubCopilot.APIBase
} else {
apiBase = "localhost:4321"
}
return NewGitHubCopilotProvider(apiBase, cfg.Providers.GitHubCopilot.ConnectMode, model)
}
}
// Fallback: detect provider from model name
if apiKey == "" && apiBase == "" {
switch {
case (strings.Contains(lowerModel, "kimi") || strings.Contains(lowerModel, "moonshot") || strings.HasPrefix(model, "moonshot/")) && cfg.Providers.Moonshot.APIKey != "":
apiKey = cfg.Providers.Moonshot.APIKey
apiBase = cfg.Providers.Moonshot.APIBase
proxy = cfg.Providers.Moonshot.Proxy
if apiBase == "" {
apiBase = "https://api.moonshot.cn/v1"
}
case strings.HasPrefix(model, "openrouter/") || strings.HasPrefix(model, "anthropic/") || strings.HasPrefix(model, "openai/") || strings.HasPrefix(model, "meta-llama/") || strings.HasPrefix(model, "deepseek/") || strings.HasPrefix(model, "google/"):
apiKey = cfg.Providers.OpenRouter.APIKey
proxy = cfg.Providers.OpenRouter.Proxy
if cfg.Providers.OpenRouter.APIBase != "" {
apiBase = cfg.Providers.OpenRouter.APIBase
} else {
apiBase = "https://openrouter.ai/api/v1"
}
case (strings.Contains(lowerModel, "claude") || strings.HasPrefix(model, "anthropic/")) && (cfg.Providers.Anthropic.APIKey != "" || cfg.Providers.Anthropic.AuthMethod != ""):
if cfg.Providers.Anthropic.AuthMethod == "oauth" || cfg.Providers.Anthropic.AuthMethod == "token" {
return createClaudeAuthProvider()
}
apiKey = cfg.Providers.Anthropic.APIKey
apiBase = cfg.Providers.Anthropic.APIBase
proxy = cfg.Providers.Anthropic.Proxy
if apiBase == "" {
apiBase = "https://api.anthropic.com/v1"
}
case (strings.Contains(lowerModel, "gpt") || strings.HasPrefix(model, "openai/")) && (cfg.Providers.OpenAI.APIKey != "" || cfg.Providers.OpenAI.AuthMethod != ""):
if cfg.Providers.OpenAI.AuthMethod == "oauth" || cfg.Providers.OpenAI.AuthMethod == "token" {
return createCodexAuthProvider(cfg.Providers.OpenAI.WebSearch)
}
apiKey = cfg.Providers.OpenAI.APIKey
apiBase = cfg.Providers.OpenAI.APIBase
proxy = cfg.Providers.OpenAI.Proxy
if apiBase == "" {
apiBase = "https://api.openai.com/v1"
}
case (strings.Contains(lowerModel, "gemini") || strings.HasPrefix(model, "google/")) && cfg.Providers.Gemini.APIKey != "":
apiKey = cfg.Providers.Gemini.APIKey
apiBase = cfg.Providers.Gemini.APIBase
proxy = cfg.Providers.Gemini.Proxy
if apiBase == "" {
apiBase = "https://generativelanguage.googleapis.com/v1beta"
}
case (strings.Contains(lowerModel, "glm") || strings.Contains(lowerModel, "zhipu") || strings.Contains(lowerModel, "zai")) && cfg.Providers.Zhipu.APIKey != "":
apiKey = cfg.Providers.Zhipu.APIKey
apiBase = cfg.Providers.Zhipu.APIBase
proxy = cfg.Providers.Zhipu.Proxy
if apiBase == "" {
apiBase = "https://open.bigmodel.cn/api/paas/v4"
}
case (strings.Contains(lowerModel, "groq") || strings.HasPrefix(model, "groq/")) && cfg.Providers.Groq.APIKey != "":
apiKey = cfg.Providers.Groq.APIKey
apiBase = cfg.Providers.Groq.APIBase
proxy = cfg.Providers.Groq.Proxy
if apiBase == "" {
apiBase = "https://api.groq.com/openai/v1"
}
case (strings.Contains(lowerModel, "nvidia") || strings.HasPrefix(model, "nvidia/")) && cfg.Providers.Nvidia.APIKey != "":
apiKey = cfg.Providers.Nvidia.APIKey
apiBase = cfg.Providers.Nvidia.APIBase
proxy = cfg.Providers.Nvidia.Proxy
if apiBase == "" {
apiBase = "https://integrate.api.nvidia.com/v1"
}
case (strings.Contains(lowerModel, "ollama") || strings.HasPrefix(model, "ollama/")) && cfg.Providers.Ollama.APIKey != "":
fmt.Println("Ollama provider selected based on model name prefix")
apiKey = cfg.Providers.Ollama.APIKey
apiBase = cfg.Providers.Ollama.APIBase
proxy = cfg.Providers.Ollama.Proxy
if apiBase == "" {
apiBase = "http://localhost:11434/v1"
}
fmt.Println("Ollama apiBase:", apiBase)
case cfg.Providers.VLLM.APIBase != "":
apiKey = cfg.Providers.VLLM.APIKey
apiBase = cfg.Providers.VLLM.APIBase
proxy = cfg.Providers.VLLM.Proxy
default:
if cfg.Providers.OpenRouter.APIKey != "" {
apiKey = cfg.Providers.OpenRouter.APIKey
proxy = cfg.Providers.OpenRouter.Proxy
if cfg.Providers.OpenRouter.APIBase != "" {
apiBase = cfg.Providers.OpenRouter.APIBase
} else {
apiBase = "https://openrouter.ai/api/v1"
}
} else {
return nil, fmt.Errorf("no API key configured for model: %s", model)
}
}
}
if apiKey == "" && !strings.HasPrefix(model, "bedrock/") {
return nil, fmt.Errorf("no API key configured for provider (model: %s)", model)
}
if apiBase == "" {
return nil, fmt.Errorf("no API base configured for provider (model: %s)", model)
}
return NewHTTPProvider(apiKey, apiBase, proxy), nil
}
+232
View File
@@ -0,0 +1,232 @@
package openai_compat
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"log"
"net/http"
"net/url"
"strings"
"time"
"github.com/sipeed/picoclaw/pkg/providers/protocoltypes"
)
type ToolCall = protocoltypes.ToolCall
type FunctionCall = protocoltypes.FunctionCall
type LLMResponse = protocoltypes.LLMResponse
type UsageInfo = protocoltypes.UsageInfo
type Message = protocoltypes.Message
type ToolDefinition = protocoltypes.ToolDefinition
type ToolFunctionDefinition = protocoltypes.ToolFunctionDefinition
type Provider struct {
apiKey string
apiBase string
httpClient *http.Client
}
func NewProvider(apiKey, apiBase, proxy string) *Provider {
client := &http.Client{
Timeout: 120 * time.Second,
}
if proxy != "" {
parsed, err := url.Parse(proxy)
if err == nil {
client.Transport = &http.Transport{
Proxy: http.ProxyURL(parsed),
}
} else {
log.Printf("openai_compat: invalid proxy URL %q: %v", proxy, err)
}
}
return &Provider{
apiKey: apiKey,
apiBase: strings.TrimRight(apiBase, "/"),
httpClient: client,
}
}
func (p *Provider) Chat(ctx context.Context, messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) (*LLMResponse, error) {
if p.apiBase == "" {
return nil, fmt.Errorf("API base not configured")
}
model = normalizeModel(model, p.apiBase)
requestBody := map[string]interface{}{
"model": model,
"messages": messages,
}
if len(tools) > 0 {
requestBody["tools"] = tools
requestBody["tool_choice"] = "auto"
}
if maxTokens, ok := asInt(options["max_tokens"]); ok {
lowerModel := strings.ToLower(model)
if strings.Contains(lowerModel, "glm") || strings.Contains(lowerModel, "o1") {
requestBody["max_completion_tokens"] = maxTokens
} else {
requestBody["max_tokens"] = maxTokens
}
}
if temperature, ok := asFloat(options["temperature"]); ok {
lowerModel := strings.ToLower(model)
// Kimi k2 models only support temperature=1.
if strings.Contains(lowerModel, "kimi") && strings.Contains(lowerModel, "k2") {
requestBody["temperature"] = 1.0
} else {
requestBody["temperature"] = temperature
}
}
jsonData, err := json.Marshal(requestBody)
if err != nil {
return nil, fmt.Errorf("failed to marshal request: %w", err)
}
req, err := http.NewRequestWithContext(ctx, "POST", p.apiBase+"/chat/completions", bytes.NewReader(jsonData))
if err != nil {
return nil, fmt.Errorf("failed to create request: %w", err)
}
req.Header.Set("Content-Type", "application/json")
if p.apiKey != "" {
req.Header.Set("Authorization", "Bearer "+p.apiKey)
}
resp, err := p.httpClient.Do(req)
if err != nil {
return nil, fmt.Errorf("failed to send request: %w", err)
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("failed to read response: %w", err)
}
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("API request failed:\n Status: %d\n Body: %s", resp.StatusCode, string(body))
}
return parseResponse(body)
}
func parseResponse(body []byte) (*LLMResponse, error) {
var apiResponse struct {
Choices []struct {
Message struct {
Content string `json:"content"`
ToolCalls []struct {
ID string `json:"id"`
Type string `json:"type"`
Function *struct {
Name string `json:"name"`
Arguments string `json:"arguments"`
} `json:"function"`
} `json:"tool_calls"`
} `json:"message"`
FinishReason string `json:"finish_reason"`
} `json:"choices"`
Usage *UsageInfo `json:"usage"`
}
if err := json.Unmarshal(body, &apiResponse); err != nil {
return nil, fmt.Errorf("failed to unmarshal response: %w", err)
}
if len(apiResponse.Choices) == 0 {
return &LLMResponse{
Content: "",
FinishReason: "stop",
}, nil
}
choice := apiResponse.Choices[0]
toolCalls := make([]ToolCall, 0, len(choice.Message.ToolCalls))
for _, tc := range choice.Message.ToolCalls {
arguments := make(map[string]interface{})
name := ""
if tc.Function != nil {
name = tc.Function.Name
if tc.Function.Arguments != "" {
if err := json.Unmarshal([]byte(tc.Function.Arguments), &arguments); err != nil {
log.Printf("openai_compat: failed to decode tool call arguments for %q: %v", name, err)
arguments["raw"] = tc.Function.Arguments
}
}
}
toolCalls = append(toolCalls, ToolCall{
ID: tc.ID,
Name: name,
Arguments: arguments,
})
}
return &LLMResponse{
Content: choice.Message.Content,
ToolCalls: toolCalls,
FinishReason: choice.FinishReason,
Usage: apiResponse.Usage,
}, nil
}
func normalizeModel(model, apiBase string) string {
idx := strings.Index(model, "/")
if idx == -1 {
return model
}
if strings.Contains(strings.ToLower(apiBase), "openrouter.ai") {
return model
}
prefix := strings.ToLower(model[:idx])
switch prefix {
case "moonshot", "nvidia", "groq", "ollama", "deepseek", "google", "openrouter", "zhipu":
return model[idx+1:]
default:
return model
}
}
func asInt(v interface{}) (int, bool) {
switch val := v.(type) {
case int:
return val, true
case int64:
return int(val), true
case float64:
return int(val), true
case float32:
return int(val), true
default:
return 0, false
}
}
func asFloat(v interface{}) (float64, bool) {
switch val := v.(type) {
case float64:
return val, true
case float32:
return float64(val), true
case int:
return float64(val), true
case int64:
return float64(val), true
default:
return 0, false
}
}
@@ -0,0 +1,277 @@
package openai_compat
import (
"encoding/json"
"net/http"
"net/http/httptest"
"net/url"
"testing"
)
func TestProviderChat_UsesMaxCompletionTokensForGLM(t *testing.T) {
var requestBody map[string]interface{}
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/chat/completions" {
http.Error(w, "not found", http.StatusNotFound)
return
}
if err := json.NewDecoder(r.Body).Decode(&requestBody); err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
resp := map[string]interface{}{
"choices": []map[string]interface{}{
{
"message": map[string]interface{}{"content": "ok"},
"finish_reason": "stop",
},
},
}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(resp)
}))
defer server.Close()
p := NewProvider("key", server.URL, "")
_, err := p.Chat(t.Context(), []Message{{Role: "user", Content: "hi"}}, nil, "glm-4.7", map[string]interface{}{"max_tokens": 1234})
if err != nil {
t.Fatalf("Chat() error = %v", err)
}
if _, ok := requestBody["max_completion_tokens"]; !ok {
t.Fatalf("expected max_completion_tokens in request body")
}
if _, ok := requestBody["max_tokens"]; ok {
t.Fatalf("did not expect max_tokens key for glm model")
}
}
func TestProviderChat_ParsesToolCalls(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
resp := map[string]interface{}{
"choices": []map[string]interface{}{
{
"message": map[string]interface{}{
"content": "",
"tool_calls": []map[string]interface{}{
{
"id": "call_1",
"type": "function",
"function": map[string]interface{}{
"name": "get_weather",
"arguments": "{\"city\":\"SF\"}",
},
},
},
},
"finish_reason": "tool_calls",
},
},
"usage": map[string]interface{}{
"prompt_tokens": 10,
"completion_tokens": 5,
"total_tokens": 15,
},
}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(resp)
}))
defer server.Close()
p := NewProvider("key", server.URL, "")
out, err := p.Chat(t.Context(), []Message{{Role: "user", Content: "hi"}}, nil, "gpt-4o", nil)
if err != nil {
t.Fatalf("Chat() error = %v", err)
}
if len(out.ToolCalls) != 1 {
t.Fatalf("len(ToolCalls) = %d, want 1", len(out.ToolCalls))
}
if out.ToolCalls[0].Name != "get_weather" {
t.Fatalf("ToolCalls[0].Name = %q, want %q", out.ToolCalls[0].Name, "get_weather")
}
if out.ToolCalls[0].Arguments["city"] != "SF" {
t.Fatalf("ToolCalls[0].Arguments[city] = %v, want SF", out.ToolCalls[0].Arguments["city"])
}
}
func TestProviderChat_HTTPError(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
http.Error(w, "bad request", http.StatusBadRequest)
}))
defer server.Close()
p := NewProvider("key", server.URL, "")
_, err := p.Chat(t.Context(), []Message{{Role: "user", Content: "hi"}}, nil, "gpt-4o", nil)
if err == nil {
t.Fatal("expected error, got nil")
}
}
func TestProviderChat_StripsMoonshotPrefixAndNormalizesKimiTemperature(t *testing.T) {
var requestBody map[string]interface{}
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if err := json.NewDecoder(r.Body).Decode(&requestBody); err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
resp := map[string]interface{}{
"choices": []map[string]interface{}{
{
"message": map[string]interface{}{"content": "ok"},
"finish_reason": "stop",
},
},
}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(resp)
}))
defer server.Close()
p := NewProvider("key", server.URL, "")
_, err := p.Chat(
t.Context(),
[]Message{{Role: "user", Content: "hi"}},
nil,
"moonshot/kimi-k2.5",
map[string]interface{}{"temperature": 0.3},
)
if err != nil {
t.Fatalf("Chat() error = %v", err)
}
if requestBody["model"] != "kimi-k2.5" {
t.Fatalf("model = %v, want kimi-k2.5", requestBody["model"])
}
if requestBody["temperature"] != 1.0 {
t.Fatalf("temperature = %v, want 1.0", requestBody["temperature"])
}
}
func TestProviderChat_StripsGroqAndOllamaPrefixes(t *testing.T) {
tests := []struct {
name string
input string
wantModel string
}{
{
name: "strips groq prefix and keeps nested model",
input: "groq/openai/gpt-oss-120b",
wantModel: "openai/gpt-oss-120b",
},
{
name: "strips ollama prefix",
input: "ollama/qwen2.5:14b",
wantModel: "qwen2.5:14b",
},
{
name: "strips deepseek prefix",
input: "deepseek/deepseek-chat",
wantModel: "deepseek-chat",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var requestBody map[string]interface{}
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if err := json.NewDecoder(r.Body).Decode(&requestBody); err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
resp := map[string]interface{}{
"choices": []map[string]interface{}{
{
"message": map[string]interface{}{"content": "ok"},
"finish_reason": "stop",
},
},
}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(resp)
}))
defer server.Close()
p := NewProvider("key", server.URL, "")
_, err := p.Chat(t.Context(), []Message{{Role: "user", Content: "hi"}}, nil, tt.input, nil)
if err != nil {
t.Fatalf("Chat() error = %v", err)
}
if requestBody["model"] != tt.wantModel {
t.Fatalf("model = %v, want %s", requestBody["model"], tt.wantModel)
}
})
}
}
func TestProvider_ProxyConfigured(t *testing.T) {
proxyURL := "http://127.0.0.1:8080"
p := NewProvider("key", "https://example.com", proxyURL)
transport, ok := p.httpClient.Transport.(*http.Transport)
if !ok || transport == nil {
t.Fatalf("expected http transport with proxy, got %T", p.httpClient.Transport)
}
req := &http.Request{URL: &url.URL{Scheme: "https", Host: "api.example.com"}}
gotProxy, err := transport.Proxy(req)
if err != nil {
t.Fatalf("proxy function returned error: %v", err)
}
if gotProxy == nil || gotProxy.String() != proxyURL {
t.Fatalf("proxy = %v, want %s", gotProxy, proxyURL)
}
}
func TestProviderChat_AcceptsNumericOptionTypes(t *testing.T) {
var requestBody map[string]interface{}
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if err := json.NewDecoder(r.Body).Decode(&requestBody); err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
resp := map[string]interface{}{
"choices": []map[string]interface{}{
{
"message": map[string]interface{}{"content": "ok"},
"finish_reason": "stop",
},
},
}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(resp)
}))
defer server.Close()
p := NewProvider("key", server.URL, "")
_, err := p.Chat(
t.Context(),
[]Message{{Role: "user", Content: "hi"}},
nil,
"gpt-4o",
map[string]interface{}{"max_tokens": float64(512), "temperature": 1},
)
if err != nil {
t.Fatalf("Chat() error = %v", err)
}
if requestBody["max_tokens"] != float64(512) {
t.Fatalf("max_tokens = %v, want 512", requestBody["max_tokens"])
}
if requestBody["temperature"] != float64(1) {
t.Fatalf("temperature = %v, want 1", requestBody["temperature"])
}
}
func TestNormalizeModel_UsesAPIBase(t *testing.T) {
if got := normalizeModel("deepseek/deepseek-chat", "https://api.deepseek.com/v1"); got != "deepseek-chat" {
t.Fatalf("normalizeModel(deepseek) = %q, want %q", got, "deepseek-chat")
}
if got := normalizeModel("openrouter/auto", "https://openrouter.ai/api/v1"); got != "openrouter/auto" {
t.Fatalf("normalizeModel(openrouter) = %q, want %q", got, "openrouter/auto")
}
}
+45
View File
@@ -0,0 +1,45 @@
package protocoltypes
type ToolCall struct {
ID string `json:"id"`
Type string `json:"type,omitempty"`
Function *FunctionCall `json:"function,omitempty"`
Name string `json:"name,omitempty"`
Arguments map[string]interface{} `json:"arguments,omitempty"`
}
type FunctionCall struct {
Name string `json:"name"`
Arguments string `json:"arguments"`
}
type LLMResponse struct {
Content string `json:"content"`
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
FinishReason string `json:"finish_reason"`
Usage *UsageInfo `json:"usage,omitempty"`
}
type UsageInfo struct {
PromptTokens int `json:"prompt_tokens"`
CompletionTokens int `json:"completion_tokens"`
TotalTokens int `json:"total_tokens"`
}
type Message struct {
Role string `json:"role"`
Content string `json:"content"`
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
ToolCallID string `json:"tool_call_id,omitempty"`
}
type ToolDefinition struct {
Type string `json:"type"`
Function ToolFunctionDefinition `json:"function"`
}
type ToolFunctionDefinition struct {
Name string `json:"name"`
Description string `json:"description"`
Parameters map[string]interface{} `json:"parameters"`
}
+11 -43
View File
@@ -1,52 +1,20 @@
package providers
import "context"
import (
"context"
type ToolCall struct {
ID string `json:"id"`
Type string `json:"type,omitempty"`
Function *FunctionCall `json:"function,omitempty"`
Name string `json:"name,omitempty"`
Arguments map[string]interface{} `json:"arguments,omitempty"`
}
"github.com/sipeed/picoclaw/pkg/providers/protocoltypes"
)
type FunctionCall struct {
Name string `json:"name"`
Arguments string `json:"arguments"`
}
type LLMResponse struct {
Content string `json:"content"`
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
FinishReason string `json:"finish_reason"`
Usage *UsageInfo `json:"usage,omitempty"`
}
type UsageInfo struct {
PromptTokens int `json:"prompt_tokens"`
CompletionTokens int `json:"completion_tokens"`
TotalTokens int `json:"total_tokens"`
}
type Message struct {
Role string `json:"role"`
Content string `json:"content"`
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
ToolCallID string `json:"tool_call_id,omitempty"`
}
type ToolCall = protocoltypes.ToolCall
type FunctionCall = protocoltypes.FunctionCall
type LLMResponse = protocoltypes.LLMResponse
type UsageInfo = protocoltypes.UsageInfo
type Message = protocoltypes.Message
type ToolDefinition = protocoltypes.ToolDefinition
type ToolFunctionDefinition = protocoltypes.ToolFunctionDefinition
type LLMProvider interface {
Chat(ctx context.Context, messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) (*LLMResponse, error)
GetDefaultModel() string
}
type ToolDefinition struct {
Type string `json:"type"`
Function ToolFunctionDefinition `json:"function"`
}
type ToolFunctionDefinition struct {
Name string `json:"name"`
Description string `json:"description"`
Parameters map[string]interface{} `json:"parameters"`
}