providers: finalize PR213 review fixes

Phase 1: centralize protocol message/tool/response types in protocoltypes and keep compatibility aliases in providers and protocol packages.

Phase 1: preserve HTTPProvider constructor compatibility and route Anthropic api_base through factory auth/provider constructors with base URL normalization.

Phase 2: expand provider routing/auth tests (deepseek/nvidia/shengsuanyun, codex/claude oauth/codex-cli) and add openai_compat + anthropic coverage for proxy transport, model normalization, numeric option coercion, token-source refresh, and base URL behavior.

Phase 3: apply gofmt and validate with Dockerized tests (go test ./pkg/providers/... ./pkg/migrate and go test ./...).
This commit is contained in:
Jared Mahotiere
2026-02-17 11:13:10 -05:00
parent e3c246a36f
commit c4cbb5fb35
10 changed files with 468 additions and 374 deletions
+53 -46
View File
@@ -4,74 +4,59 @@ import (
"context"
"encoding/json"
"fmt"
"log"
"strings"
"github.com/anthropics/anthropic-sdk-go"
"github.com/anthropics/anthropic-sdk-go/option"
"github.com/sipeed/picoclaw/pkg/providers/protocoltypes"
)
type ToolCall struct {
ID string `json:"id"`
Type string `json:"type,omitempty"`
Function *FunctionCall `json:"function,omitempty"`
Name string `json:"name,omitempty"`
Arguments map[string]interface{} `json:"arguments,omitempty"`
}
type ToolCall = protocoltypes.ToolCall
type FunctionCall = protocoltypes.FunctionCall
type LLMResponse = protocoltypes.LLMResponse
type UsageInfo = protocoltypes.UsageInfo
type Message = protocoltypes.Message
type ToolDefinition = protocoltypes.ToolDefinition
type ToolFunctionDefinition = protocoltypes.ToolFunctionDefinition
type FunctionCall struct {
Name string `json:"name"`
Arguments string `json:"arguments"`
}
type LLMResponse struct {
Content string `json:"content"`
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
FinishReason string `json:"finish_reason"`
Usage *UsageInfo `json:"usage,omitempty"`
}
type UsageInfo struct {
PromptTokens int `json:"prompt_tokens"`
CompletionTokens int `json:"completion_tokens"`
TotalTokens int `json:"total_tokens"`
}
type Message struct {
Role string `json:"role"`
Content string `json:"content"`
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
ToolCallID string `json:"tool_call_id,omitempty"`
}
type ToolDefinition struct {
Type string `json:"type"`
Function ToolFunctionDefinition `json:"function"`
}
type ToolFunctionDefinition struct {
Name string `json:"name"`
Description string `json:"description"`
Parameters map[string]interface{} `json:"parameters"`
}
const defaultBaseURL = "https://api.anthropic.com"
type Provider struct {
client *anthropic.Client
tokenSource func() (string, error)
baseURL string
}
func NewProvider(token string) *Provider {
return NewProviderWithBaseURL(token, "")
}
func NewProviderWithBaseURL(token, apiBase string) *Provider {
baseURL := normalizeBaseURL(apiBase)
client := anthropic.NewClient(
option.WithAuthToken(token),
option.WithBaseURL("https://api.anthropic.com"),
option.WithBaseURL(baseURL),
)
return &Provider{client: &client}
return &Provider{
client: &client,
baseURL: baseURL,
}
}
func NewProviderWithClient(client *anthropic.Client) *Provider {
return &Provider{client: client}
return &Provider{
client: client,
baseURL: defaultBaseURL,
}
}
func NewProviderWithTokenSource(token string, tokenSource func() (string, error)) *Provider {
p := NewProvider(token)
return NewProviderWithTokenSourceAndBaseURL(token, tokenSource, "")
}
func NewProviderWithTokenSourceAndBaseURL(token string, tokenSource func() (string, error), apiBase string) *Provider {
p := NewProviderWithBaseURL(token, apiBase)
p.tokenSource = tokenSource
return p
}
@@ -103,6 +88,10 @@ func (p *Provider) GetDefaultModel() string {
return "claude-sonnet-4-5-20250929"
}
func (p *Provider) BaseURL() string {
return p.baseURL
}
func buildParams(messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) (anthropic.MessageNewParams, error) {
var system []anthropic.TextBlockParam
var anthropicMessages []anthropic.MessageParam
@@ -208,6 +197,7 @@ func parseResponse(resp *anthropic.Message) *LLMResponse {
tu := block.AsToolUse()
var args map[string]interface{}
if err := json.Unmarshal(tu.Input, &args); err != nil {
log.Printf("anthropic: failed to decode tool call input for %q: %v", tu.Name, err)
args = map[string]interface{}{"raw": string(tu.Input)}
}
toolCalls = append(toolCalls, ToolCall{
@@ -239,3 +229,20 @@ func parseResponse(resp *anthropic.Message) *LLMResponse {
},
}
}
func normalizeBaseURL(apiBase string) string {
base := strings.TrimSpace(apiBase)
if base == "" {
return defaultBaseURL
}
base = strings.TrimRight(base, "/")
if strings.HasSuffix(base, "/v1") {
base = strings.TrimSuffix(base, "/v1")
}
if base == "" {
return defaultBaseURL
}
return base
}
+57
View File
@@ -4,6 +4,7 @@ import (
"encoding/json"
"net/http"
"net/http/httptest"
"sync/atomic"
"testing"
"github.com/anthropics/anthropic-sdk-go"
@@ -199,6 +200,62 @@ func TestProvider_GetDefaultModel(t *testing.T) {
}
}
func TestProvider_NewProviderWithBaseURL_NormalizesV1Suffix(t *testing.T) {
p := NewProviderWithBaseURL("token", "https://api.anthropic.com/v1/")
if got := p.BaseURL(); got != "https://api.anthropic.com" {
t.Fatalf("BaseURL() = %q, want %q", got, "https://api.anthropic.com")
}
}
func TestProvider_ChatUsesTokenSource(t *testing.T) {
var requests int32
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/v1/messages" {
http.Error(w, "not found", http.StatusNotFound)
return
}
atomic.AddInt32(&requests, 1)
if got := r.Header.Get("Authorization"); got != "Bearer refreshed-token" {
http.Error(w, "unauthorized", http.StatusUnauthorized)
return
}
var reqBody map[string]interface{}
json.NewDecoder(r.Body).Decode(&reqBody)
resp := map[string]interface{}{
"id": "msg_test",
"type": "message",
"role": "assistant",
"model": reqBody["model"],
"stop_reason": "end_turn",
"content": []map[string]interface{}{
{"type": "text", "text": "ok"},
},
"usage": map[string]interface{}{
"input_tokens": 1,
"output_tokens": 1,
},
}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(resp)
}))
defer server.Close()
p := NewProviderWithTokenSourceAndBaseURL("stale-token", func() (string, error) {
return "refreshed-token", nil
}, server.URL)
_, err := p.Chat(t.Context(), []Message{{Role: "user", Content: "hello"}}, nil, "claude-sonnet-4-5-20250929", map[string]interface{}{})
if err != nil {
t.Fatalf("Chat() error: %v", err)
}
if got := atomic.LoadInt32(&requests); got != 1 {
t.Fatalf("requests = %d, want 1", got)
}
}
func createAnthropicTestClient(baseURL, token string) *anthropic.Client {
c := anthropic.NewClient(
anthropicoption.WithAuthToken(token),
+15 -103
View File
@@ -3,8 +3,6 @@ package providers
import (
"context"
"fmt"
"github.com/sipeed/picoclaw/pkg/auth"
anthropicprovider "github.com/sipeed/picoclaw/pkg/providers/anthropic"
)
@@ -18,28 +16,34 @@ func NewClaudeProvider(token string) *ClaudeProvider {
}
}
func NewClaudeProviderWithBaseURL(token, apiBase string) *ClaudeProvider {
return &ClaudeProvider{
delegate: anthropicprovider.NewProviderWithBaseURL(token, apiBase),
}
}
func NewClaudeProviderWithTokenSource(token string, tokenSource func() (string, error)) *ClaudeProvider {
return &ClaudeProvider{
delegate: anthropicprovider.NewProviderWithTokenSource(token, tokenSource),
}
}
func NewClaudeProviderWithTokenSourceAndBaseURL(token string, tokenSource func() (string, error), apiBase string) *ClaudeProvider {
return &ClaudeProvider{
delegate: anthropicprovider.NewProviderWithTokenSourceAndBaseURL(token, tokenSource, apiBase),
}
}
func newClaudeProviderWithDelegate(delegate *anthropicprovider.Provider) *ClaudeProvider {
return &ClaudeProvider{delegate: delegate}
}
func (p *ClaudeProvider) Chat(ctx context.Context, messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) (*LLMResponse, error) {
resp, err := p.delegate.Chat(
ctx,
toAnthropicProviderMessages(messages),
toAnthropicProviderTools(tools),
model,
options,
)
resp, err := p.delegate.Chat(ctx, messages, tools, model, options)
if err != nil {
return nil, err
}
return fromAnthropicProviderResponse(resp), nil
return resp, nil
}
func (p *ClaudeProvider) GetDefaultModel() string {
@@ -48,7 +52,7 @@ func (p *ClaudeProvider) GetDefaultModel() string {
func createClaudeTokenSource() func() (string, error) {
return func() (string, error) {
cred, err := auth.GetCredential("anthropic")
cred, err := getCredential("anthropic")
if err != nil {
return "", fmt.Errorf("loading auth credentials: %w", err)
}
@@ -58,95 +62,3 @@ func createClaudeTokenSource() func() (string, error) {
return cred.AccessToken, nil
}
}
func toAnthropicProviderMessages(messages []Message) []anthropicprovider.Message {
out := make([]anthropicprovider.Message, 0, len(messages))
for _, msg := range messages {
out = append(out, anthropicprovider.Message{
Role: msg.Role,
Content: msg.Content,
ToolCalls: toAnthropicProviderToolCalls(msg.ToolCalls),
ToolCallID: msg.ToolCallID,
})
}
return out
}
func toAnthropicProviderTools(tools []ToolDefinition) []anthropicprovider.ToolDefinition {
out := make([]anthropicprovider.ToolDefinition, 0, len(tools))
for _, t := range tools {
out = append(out, anthropicprovider.ToolDefinition{
Type: t.Type,
Function: anthropicprovider.ToolFunctionDefinition{
Name: t.Function.Name,
Description: t.Function.Description,
Parameters: t.Function.Parameters,
},
})
}
return out
}
func toAnthropicProviderToolCalls(toolCalls []ToolCall) []anthropicprovider.ToolCall {
out := make([]anthropicprovider.ToolCall, 0, len(toolCalls))
for _, tc := range toolCalls {
var fn *anthropicprovider.FunctionCall
if tc.Function != nil {
fn = &anthropicprovider.FunctionCall{
Name: tc.Function.Name,
Arguments: tc.Function.Arguments,
}
}
out = append(out, anthropicprovider.ToolCall{
ID: tc.ID,
Type: tc.Type,
Function: fn,
Name: tc.Name,
Arguments: tc.Arguments,
})
}
return out
}
func fromAnthropicProviderResponse(resp *anthropicprovider.LLMResponse) *LLMResponse {
if resp == nil {
return &LLMResponse{}
}
var usage *UsageInfo
if resp.Usage != nil {
usage = &UsageInfo{
PromptTokens: resp.Usage.PromptTokens,
CompletionTokens: resp.Usage.CompletionTokens,
TotalTokens: resp.Usage.TotalTokens,
}
}
return &LLMResponse{
Content: resp.Content,
ToolCalls: fromAnthropicProviderToolCalls(resp.ToolCalls),
FinishReason: resp.FinishReason,
Usage: usage,
}
}
func fromAnthropicProviderToolCalls(toolCalls []anthropicprovider.ToolCall) []ToolCall {
out := make([]ToolCall, 0, len(toolCalls))
for _, tc := range toolCalls {
var fn *FunctionCall
if tc.Function != nil {
fn = &FunctionCall{
Name: tc.Function.Name,
Arguments: tc.Function.Arguments,
}
}
out = append(out, ToolCall{
ID: tc.ID,
Type: tc.Type,
Function: fn,
Name: tc.Name,
Arguments: tc.Arguments,
})
}
return out
}
+40 -7
View File
@@ -8,6 +8,10 @@ import (
"github.com/sipeed/picoclaw/pkg/config"
)
const defaultAnthropicAPIBase = "https://api.anthropic.com/v1"
var getCredential = auth.GetCredential
type providerType int
const (
@@ -30,19 +34,22 @@ type providerSelection struct {
connectMode string
}
func createClaudeAuthProvider() (LLMProvider, error) {
cred, err := auth.GetCredential("anthropic")
func createClaudeAuthProvider(apiBase string) (LLMProvider, error) {
if apiBase == "" {
apiBase = defaultAnthropicAPIBase
}
cred, err := getCredential("anthropic")
if err != nil {
return nil, fmt.Errorf("loading auth credentials: %w", err)
}
if cred == nil {
return nil, fmt.Errorf("no credentials for anthropic. Run: picoclaw auth login --provider anthropic")
}
return NewClaudeProviderWithTokenSource(cred.AccessToken, createClaudeTokenSource()), nil
return NewClaudeProviderWithTokenSourceAndBaseURL(cred.AccessToken, createClaudeTokenSource(), apiBase), nil
}
func createCodexAuthProvider() (LLMProvider, error) {
cred, err := auth.GetCredential("openai")
cred, err := getCredential("openai")
if err != nil {
return nil, fmt.Errorf("loading auth credentials: %w", err)
}
@@ -69,6 +76,7 @@ func resolveProviderSelection(cfg *config.Config) (providerSelection, error) {
if cfg.Providers.Groq.APIKey != "" {
sel.apiKey = cfg.Providers.Groq.APIKey
sel.apiBase = cfg.Providers.Groq.APIBase
sel.proxy = cfg.Providers.Groq.Proxy
if sel.apiBase == "" {
sel.apiBase = "https://api.groq.com/openai/v1"
}
@@ -85,6 +93,7 @@ func resolveProviderSelection(cfg *config.Config) (providerSelection, error) {
}
sel.apiKey = cfg.Providers.OpenAI.APIKey
sel.apiBase = cfg.Providers.OpenAI.APIBase
sel.proxy = cfg.Providers.OpenAI.Proxy
if sel.apiBase == "" {
sel.apiBase = "https://api.openai.com/v1"
}
@@ -92,18 +101,24 @@ func resolveProviderSelection(cfg *config.Config) (providerSelection, error) {
case "anthropic", "claude":
if cfg.Providers.Anthropic.APIKey != "" || cfg.Providers.Anthropic.AuthMethod != "" {
if cfg.Providers.Anthropic.AuthMethod == "oauth" || cfg.Providers.Anthropic.AuthMethod == "token" {
sel.apiBase = cfg.Providers.Anthropic.APIBase
if sel.apiBase == "" {
sel.apiBase = defaultAnthropicAPIBase
}
sel.providerType = providerTypeClaudeAuth
return sel, nil
}
sel.apiKey = cfg.Providers.Anthropic.APIKey
sel.apiBase = cfg.Providers.Anthropic.APIBase
sel.proxy = cfg.Providers.Anthropic.Proxy
if sel.apiBase == "" {
sel.apiBase = "https://api.anthropic.com/v1"
sel.apiBase = defaultAnthropicAPIBase
}
}
case "openrouter":
if cfg.Providers.OpenRouter.APIKey != "" {
sel.apiKey = cfg.Providers.OpenRouter.APIKey
sel.proxy = cfg.Providers.OpenRouter.Proxy
if cfg.Providers.OpenRouter.APIBase != "" {
sel.apiBase = cfg.Providers.OpenRouter.APIBase
} else {
@@ -114,6 +129,7 @@ func resolveProviderSelection(cfg *config.Config) (providerSelection, error) {
if cfg.Providers.Zhipu.APIKey != "" {
sel.apiKey = cfg.Providers.Zhipu.APIKey
sel.apiBase = cfg.Providers.Zhipu.APIBase
sel.proxy = cfg.Providers.Zhipu.Proxy
if sel.apiBase == "" {
sel.apiBase = "https://open.bigmodel.cn/api/paas/v4"
}
@@ -122,6 +138,7 @@ func resolveProviderSelection(cfg *config.Config) (providerSelection, error) {
if cfg.Providers.Gemini.APIKey != "" {
sel.apiKey = cfg.Providers.Gemini.APIKey
sel.apiBase = cfg.Providers.Gemini.APIBase
sel.proxy = cfg.Providers.Gemini.Proxy
if sel.apiBase == "" {
sel.apiBase = "https://generativelanguage.googleapis.com/v1beta"
}
@@ -130,15 +147,26 @@ func resolveProviderSelection(cfg *config.Config) (providerSelection, error) {
if cfg.Providers.VLLM.APIBase != "" {
sel.apiKey = cfg.Providers.VLLM.APIKey
sel.apiBase = cfg.Providers.VLLM.APIBase
sel.proxy = cfg.Providers.VLLM.Proxy
}
case "shengsuanyun":
if cfg.Providers.ShengSuanYun.APIKey != "" {
sel.apiKey = cfg.Providers.ShengSuanYun.APIKey
sel.apiBase = cfg.Providers.ShengSuanYun.APIBase
sel.proxy = cfg.Providers.ShengSuanYun.Proxy
if sel.apiBase == "" {
sel.apiBase = "https://router.shengsuanyun.com/api/v1"
}
}
case "nvidia":
if cfg.Providers.Nvidia.APIKey != "" {
sel.apiKey = cfg.Providers.Nvidia.APIKey
sel.apiBase = cfg.Providers.Nvidia.APIBase
sel.proxy = cfg.Providers.Nvidia.Proxy
if sel.apiBase == "" {
sel.apiBase = "https://integrate.api.nvidia.com/v1"
}
}
case "claude-cli", "claude-code", "claudecode":
workspace := cfg.WorkspacePath()
if workspace == "" {
@@ -159,6 +187,7 @@ func resolveProviderSelection(cfg *config.Config) (providerSelection, error) {
if cfg.Providers.DeepSeek.APIKey != "" {
sel.apiKey = cfg.Providers.DeepSeek.APIKey
sel.apiBase = cfg.Providers.DeepSeek.APIBase
sel.proxy = cfg.Providers.DeepSeek.Proxy
if sel.apiBase == "" {
sel.apiBase = "https://api.deepseek.com/v1"
}
@@ -204,6 +233,10 @@ func resolveProviderSelection(cfg *config.Config) (providerSelection, error) {
case (strings.Contains(lowerModel, "claude") || strings.HasPrefix(model, "anthropic/")) &&
(cfg.Providers.Anthropic.APIKey != "" || cfg.Providers.Anthropic.AuthMethod != ""):
if cfg.Providers.Anthropic.AuthMethod == "oauth" || cfg.Providers.Anthropic.AuthMethod == "token" {
sel.apiBase = cfg.Providers.Anthropic.APIBase
if sel.apiBase == "" {
sel.apiBase = defaultAnthropicAPIBase
}
sel.providerType = providerTypeClaudeAuth
return sel, nil
}
@@ -211,7 +244,7 @@ func resolveProviderSelection(cfg *config.Config) (providerSelection, error) {
sel.apiBase = cfg.Providers.Anthropic.APIBase
sel.proxy = cfg.Providers.Anthropic.Proxy
if sel.apiBase == "" {
sel.apiBase = "https://api.anthropic.com/v1"
sel.apiBase = defaultAnthropicAPIBase
}
case (strings.Contains(lowerModel, "gpt") || strings.HasPrefix(model, "openai/")) &&
(cfg.Providers.OpenAI.APIKey != "" || cfg.Providers.OpenAI.AuthMethod != ""):
@@ -303,7 +336,7 @@ func CreateProvider(cfg *config.Config) (LLMProvider, error) {
switch sel.providerType {
case providerTypeClaudeAuth:
return createClaudeAuthProvider()
return createClaudeAuthProvider(sel.apiBase)
case providerTypeCodexAuth:
return createCodexAuthProvider()
case providerTypeCodexCLIToken:
+95
View File
@@ -4,6 +4,7 @@ import (
"strings"
"testing"
"github.com/sipeed/picoclaw/pkg/auth"
"github.com/sipeed/picoclaw/pkg/config"
)
@@ -32,6 +33,40 @@ func TestResolveProviderSelection(t *testing.T) {
wantType: providerTypeGitHubCopilot,
wantAPIBase: "localhost:4321",
},
{
name: "explicit deepseek provider uses deepseek defaults",
setup: func(cfg *config.Config) {
cfg.Agents.Defaults.Provider = "deepseek"
cfg.Agents.Defaults.Model = "deepseek/deepseek-chat"
cfg.Providers.DeepSeek.APIKey = "deepseek-key"
cfg.Providers.DeepSeek.Proxy = "http://127.0.0.1:7890"
},
wantType: providerTypeHTTPCompat,
wantAPIBase: "https://api.deepseek.com/v1",
wantProxy: "http://127.0.0.1:7890",
},
{
name: "explicit shengsuanyun provider uses defaults",
setup: func(cfg *config.Config) {
cfg.Agents.Defaults.Provider = "shengsuanyun"
cfg.Providers.ShengSuanYun.APIKey = "ssy-key"
cfg.Providers.ShengSuanYun.Proxy = "http://127.0.0.1:7890"
},
wantType: providerTypeHTTPCompat,
wantAPIBase: "https://router.shengsuanyun.com/api/v1",
wantProxy: "http://127.0.0.1:7890",
},
{
name: "explicit nvidia provider uses defaults",
setup: func(cfg *config.Config) {
cfg.Agents.Defaults.Provider = "nvidia"
cfg.Providers.Nvidia.APIKey = "nvapi-test"
cfg.Providers.Nvidia.Proxy = "http://127.0.0.1:7890"
},
wantType: providerTypeHTTPCompat,
wantAPIBase: "https://integrate.api.nvidia.com/v1",
wantProxy: "http://127.0.0.1:7890",
},
{
name: "openrouter model uses openrouter defaults",
setup: func(cfg *config.Config) {
@@ -202,3 +237,63 @@ func TestCreateProviderReturnsCodexProviderForCodexCliAuthMethod(t *testing.T) {
t.Fatalf("provider type = %T, want *CodexProvider", provider)
}
}
func TestCreateProviderReturnsClaudeProviderForAnthropicOAuth(t *testing.T) {
originalGetCredential := getCredential
t.Cleanup(func() { getCredential = originalGetCredential })
getCredential = func(provider string) (*auth.AuthCredential, error) {
if provider != "anthropic" {
t.Fatalf("provider = %q, want anthropic", provider)
}
return &auth.AuthCredential{
AccessToken: "anthropic-token",
}, nil
}
cfg := config.DefaultConfig()
cfg.Agents.Defaults.Provider = "anthropic"
cfg.Providers.Anthropic.AuthMethod = "oauth"
cfg.Providers.Anthropic.APIBase = "https://proxy.example.com/v1"
provider, err := CreateProvider(cfg)
if err != nil {
t.Fatalf("CreateProvider() error = %v", err)
}
claudeProvider, ok := provider.(*ClaudeProvider)
if !ok {
t.Fatalf("provider type = %T, want *ClaudeProvider", provider)
}
if got := claudeProvider.delegate.BaseURL(); got != "https://proxy.example.com" {
t.Fatalf("anthropic baseURL = %q, want %q", got, "https://proxy.example.com")
}
}
func TestCreateProviderReturnsCodexProviderForOpenAIOAuth(t *testing.T) {
originalGetCredential := getCredential
t.Cleanup(func() { getCredential = originalGetCredential })
getCredential = func(provider string) (*auth.AuthCredential, error) {
if provider != "openai" {
t.Fatalf("provider = %q, want openai", provider)
}
return &auth.AuthCredential{
AccessToken: "openai-token",
AccountID: "acct_123",
}, nil
}
cfg := config.DefaultConfig()
cfg.Agents.Defaults.Provider = "openai"
cfg.Providers.OpenAI.AuthMethod = "oauth"
provider, err := CreateProvider(cfg)
if err != nil {
t.Fatalf("CreateProvider() error = %v", err)
}
if _, ok := provider.(*CodexProvider); !ok {
t.Fatalf("provider type = %T, want *CodexProvider", provider)
}
}
+3 -103
View File
@@ -15,116 +15,16 @@ type HTTPProvider struct {
delegate *openai_compat.Provider
}
func NewHTTPProvider(apiKey, apiBase string, proxy ...string) *HTTPProvider {
proxyURL := ""
if len(proxy) > 0 {
proxyURL = proxy[0]
}
func NewHTTPProvider(apiKey, apiBase, proxy string) *HTTPProvider {
return &HTTPProvider{
delegate: openai_compat.NewProvider(apiKey, apiBase, proxyURL),
delegate: openai_compat.NewProvider(apiKey, apiBase, proxy),
}
}
func (p *HTTPProvider) Chat(ctx context.Context, messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) (*LLMResponse, error) {
compatResp, err := p.delegate.Chat(ctx, toOpenAICompatMessages(messages), toOpenAICompatTools(tools), model, options)
if err != nil {
return nil, err
}
return fromOpenAICompatResponse(compatResp), nil
return p.delegate.Chat(ctx, messages, tools, model, options)
}
func (p *HTTPProvider) GetDefaultModel() string {
return ""
}
func toOpenAICompatMessages(messages []Message) []openai_compat.Message {
out := make([]openai_compat.Message, 0, len(messages))
for _, msg := range messages {
out = append(out, openai_compat.Message{
Role: msg.Role,
Content: msg.Content,
ToolCalls: toOpenAICompatToolCalls(msg.ToolCalls),
ToolCallID: msg.ToolCallID,
})
}
return out
}
func toOpenAICompatTools(tools []ToolDefinition) []openai_compat.ToolDefinition {
out := make([]openai_compat.ToolDefinition, 0, len(tools))
for _, t := range tools {
out = append(out, openai_compat.ToolDefinition{
Type: t.Type,
Function: openai_compat.ToolFunctionDefinition{
Name: t.Function.Name,
Description: t.Function.Description,
Parameters: t.Function.Parameters,
},
})
}
return out
}
func toOpenAICompatToolCalls(toolCalls []ToolCall) []openai_compat.ToolCall {
out := make([]openai_compat.ToolCall, 0, len(toolCalls))
for _, tc := range toolCalls {
var fn *openai_compat.FunctionCall
if tc.Function != nil {
fn = &openai_compat.FunctionCall{
Name: tc.Function.Name,
Arguments: tc.Function.Arguments,
}
}
out = append(out, openai_compat.ToolCall{
ID: tc.ID,
Type: tc.Type,
Function: fn,
Name: tc.Name,
Arguments: tc.Arguments,
})
}
return out
}
func fromOpenAICompatResponse(resp *openai_compat.LLMResponse) *LLMResponse {
if resp == nil {
return &LLMResponse{}
}
var usage *UsageInfo
if resp.Usage != nil {
usage = &UsageInfo{
PromptTokens: resp.Usage.PromptTokens,
CompletionTokens: resp.Usage.CompletionTokens,
TotalTokens: resp.Usage.TotalTokens,
}
}
return &LLMResponse{
Content: resp.Content,
ToolCalls: fromOpenAICompatToolCalls(resp.ToolCalls),
FinishReason: resp.FinishReason,
Usage: usage,
}
}
func fromOpenAICompatToolCalls(toolCalls []openai_compat.ToolCall) []ToolCall {
out := make([]ToolCall, 0, len(toolCalls))
for _, tc := range toolCalls {
var fn *FunctionCall
if tc.Function != nil {
fn = &FunctionCall{
Name: tc.Function.Name,
Arguments: tc.Function.Arguments,
}
}
out = append(out, ToolCall{
ID: tc.ID,
Type: tc.Type,
Function: fn,
Name: tc.Name,
Arguments: tc.Arguments,
})
}
return out
}
+69 -67
View File
@@ -6,55 +6,22 @@ import (
"encoding/json"
"fmt"
"io"
"log"
"net/http"
"net/url"
"strings"
"time"
"github.com/sipeed/picoclaw/pkg/providers/protocoltypes"
)
type ToolCall struct {
ID string `json:"id"`
Type string `json:"type,omitempty"`
Function *FunctionCall `json:"function,omitempty"`
Name string `json:"name,omitempty"`
Arguments map[string]interface{} `json:"arguments,omitempty"`
}
type FunctionCall struct {
Name string `json:"name"`
Arguments string `json:"arguments"`
}
type LLMResponse struct {
Content string `json:"content"`
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
FinishReason string `json:"finish_reason"`
Usage *UsageInfo `json:"usage,omitempty"`
}
type UsageInfo struct {
PromptTokens int `json:"prompt_tokens"`
CompletionTokens int `json:"completion_tokens"`
TotalTokens int `json:"total_tokens"`
}
type Message struct {
Role string `json:"role"`
Content string `json:"content"`
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
ToolCallID string `json:"tool_call_id,omitempty"`
}
type ToolDefinition struct {
Type string `json:"type"`
Function ToolFunctionDefinition `json:"function"`
}
type ToolFunctionDefinition struct {
Name string `json:"name"`
Description string `json:"description"`
Parameters map[string]interface{} `json:"parameters"`
}
type ToolCall = protocoltypes.ToolCall
type FunctionCall = protocoltypes.FunctionCall
type LLMResponse = protocoltypes.LLMResponse
type UsageInfo = protocoltypes.UsageInfo
type Message = protocoltypes.Message
type ToolDefinition = protocoltypes.ToolDefinition
type ToolFunctionDefinition = protocoltypes.ToolFunctionDefinition
type Provider struct {
apiKey string
@@ -62,21 +29,19 @@ type Provider struct {
httpClient *http.Client
}
func NewProvider(apiKey, apiBase string, proxy ...string) *Provider {
proxyURL := ""
if len(proxy) > 0 {
proxyURL = proxy[0]
}
func NewProvider(apiKey, apiBase, proxy string) *Provider {
client := &http.Client{
Timeout: 120 * time.Second,
}
if proxyURL != "" {
parsed, err := url.Parse(proxyURL)
if proxy != "" {
parsed, err := url.Parse(proxy)
if err == nil {
client.Transport = &http.Transport{
Proxy: http.ProxyURL(parsed),
}
} else {
log.Printf("openai_compat: invalid proxy URL %q: %v", proxy, err)
}
}
@@ -92,13 +57,7 @@ func (p *Provider) Chat(ctx context.Context, messages []Message, tools []ToolDef
return nil, fmt.Errorf("API base not configured")
}
// Strip provider prefix for OpenAI-compatible backends.
if idx := strings.Index(model, "/"); idx != -1 {
prefix := model[:idx]
if prefix == "moonshot" || prefix == "nvidia" || prefix == "groq" || prefix == "ollama" {
model = model[idx+1:]
}
}
model = normalizeModel(model, p.apiBase)
requestBody := map[string]interface{}{
"model": model,
@@ -110,7 +69,7 @@ func (p *Provider) Chat(ctx context.Context, messages []Message, tools []ToolDef
requestBody["tool_choice"] = "auto"
}
if maxTokens, ok := options["max_tokens"].(int); ok {
if maxTokens, ok := asInt(options["max_tokens"]); ok {
lowerModel := strings.ToLower(model)
if strings.Contains(lowerModel, "glm") || strings.Contains(lowerModel, "o1") {
requestBody["max_completion_tokens"] = maxTokens
@@ -119,7 +78,7 @@ func (p *Provider) Chat(ctx context.Context, messages []Message, tools []ToolDef
}
}
if temperature, ok := options["temperature"].(float64); ok {
if temperature, ok := asFloat(options["temperature"]); ok {
lowerModel := strings.ToLower(model)
// Kimi k2 models only support temperature=1.
if strings.Contains(lowerModel, "kimi") && strings.Contains(lowerModel, "k2") {
@@ -198,17 +157,11 @@ func parseResponse(body []byte) (*LLMResponse, error) {
arguments := make(map[string]interface{})
name := ""
if tc.Type == "function" && tc.Function != nil {
name = tc.Function.Name
if tc.Function.Arguments != "" {
if err := json.Unmarshal([]byte(tc.Function.Arguments), &arguments); err != nil {
arguments["raw"] = tc.Function.Arguments
}
}
} else if tc.Function != nil {
if tc.Function != nil {
name = tc.Function.Name
if tc.Function.Arguments != "" {
if err := json.Unmarshal([]byte(tc.Function.Arguments), &arguments); err != nil {
log.Printf("openai_compat: failed to decode tool call arguments for %q: %v", name, err)
arguments["raw"] = tc.Function.Arguments
}
}
@@ -228,3 +181,52 @@ func parseResponse(body []byte) (*LLMResponse, error) {
Usage: apiResponse.Usage,
}, nil
}
func normalizeModel(model, apiBase string) string {
idx := strings.Index(model, "/")
if idx == -1 {
return model
}
if strings.Contains(strings.ToLower(apiBase), "openrouter.ai") {
return model
}
prefix := strings.ToLower(model[:idx])
switch prefix {
case "moonshot", "nvidia", "groq", "ollama", "deepseek", "google", "openrouter", "zhipu":
return model[idx+1:]
default:
return model
}
}
func asInt(v interface{}) (int, bool) {
switch val := v.(type) {
case int:
return val, true
case int64:
return int(val), true
case float64:
return int(val), true
case float32:
return int(val), true
default:
return 0, false
}
}
func asFloat(v interface{}) (float64, bool) {
switch val := v.(type) {
case float64:
return val, true
case float32:
return float64(val), true
case int:
return float64(val), true
case int64:
return float64(val), true
default:
return 0, false
}
}
+80 -5
View File
@@ -4,6 +4,7 @@ import (
"encoding/json"
"net/http"
"net/http/httptest"
"net/url"
"testing"
)
@@ -32,7 +33,7 @@ func TestProviderChat_UsesMaxCompletionTokensForGLM(t *testing.T) {
}))
defer server.Close()
p := NewProvider("key", server.URL)
p := NewProvider("key", server.URL, "")
_, err := p.Chat(t.Context(), []Message{{Role: "user", Content: "hi"}}, nil, "glm-4.7", map[string]interface{}{"max_tokens": 1234})
if err != nil {
t.Fatalf("Chat() error = %v", err)
@@ -78,7 +79,7 @@ func TestProviderChat_ParsesToolCalls(t *testing.T) {
}))
defer server.Close()
p := NewProvider("key", server.URL)
p := NewProvider("key", server.URL, "")
out, err := p.Chat(t.Context(), []Message{{Role: "user", Content: "hi"}}, nil, "gpt-4o", nil)
if err != nil {
t.Fatalf("Chat() error = %v", err)
@@ -100,7 +101,7 @@ func TestProviderChat_HTTPError(t *testing.T) {
}))
defer server.Close()
p := NewProvider("key", server.URL)
p := NewProvider("key", server.URL, "")
_, err := p.Chat(t.Context(), []Message{{Role: "user", Content: "hi"}}, nil, "gpt-4o", nil)
if err == nil {
t.Fatal("expected error, got nil")
@@ -128,7 +129,7 @@ func TestProviderChat_StripsMoonshotPrefixAndNormalizesKimiTemperature(t *testin
}))
defer server.Close()
p := NewProvider("key", server.URL)
p := NewProvider("key", server.URL, "")
_, err := p.Chat(
t.Context(),
[]Message{{Role: "user", Content: "hi"}},
@@ -164,6 +165,11 @@ func TestProviderChat_StripsGroqAndOllamaPrefixes(t *testing.T) {
input: "ollama/qwen2.5:14b",
wantModel: "qwen2.5:14b",
},
{
name: "strips deepseek prefix",
input: "deepseek/deepseek-chat",
wantModel: "deepseek-chat",
},
}
for _, tt := range tests {
@@ -188,7 +194,7 @@ func TestProviderChat_StripsGroqAndOllamaPrefixes(t *testing.T) {
}))
defer server.Close()
p := NewProvider("key", server.URL)
p := NewProvider("key", server.URL, "")
_, err := p.Chat(t.Context(), []Message{{Role: "user", Content: "hi"}}, nil, tt.input, nil)
if err != nil {
t.Fatalf("Chat() error = %v", err)
@@ -200,3 +206,72 @@ func TestProviderChat_StripsGroqAndOllamaPrefixes(t *testing.T) {
})
}
}
func TestProvider_ProxyConfigured(t *testing.T) {
proxyURL := "http://127.0.0.1:8080"
p := NewProvider("key", "https://example.com", proxyURL)
transport, ok := p.httpClient.Transport.(*http.Transport)
if !ok || transport == nil {
t.Fatalf("expected http transport with proxy, got %T", p.httpClient.Transport)
}
req := &http.Request{URL: &url.URL{Scheme: "https", Host: "api.example.com"}}
gotProxy, err := transport.Proxy(req)
if err != nil {
t.Fatalf("proxy function returned error: %v", err)
}
if gotProxy == nil || gotProxy.String() != proxyURL {
t.Fatalf("proxy = %v, want %s", gotProxy, proxyURL)
}
}
func TestProviderChat_AcceptsNumericOptionTypes(t *testing.T) {
var requestBody map[string]interface{}
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if err := json.NewDecoder(r.Body).Decode(&requestBody); err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
resp := map[string]interface{}{
"choices": []map[string]interface{}{
{
"message": map[string]interface{}{"content": "ok"},
"finish_reason": "stop",
},
},
}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(resp)
}))
defer server.Close()
p := NewProvider("key", server.URL, "")
_, err := p.Chat(
t.Context(),
[]Message{{Role: "user", Content: "hi"}},
nil,
"gpt-4o",
map[string]interface{}{"max_tokens": float64(512), "temperature": 1},
)
if err != nil {
t.Fatalf("Chat() error = %v", err)
}
if requestBody["max_tokens"] != float64(512) {
t.Fatalf("max_tokens = %v, want 512", requestBody["max_tokens"])
}
if requestBody["temperature"] != float64(1) {
t.Fatalf("temperature = %v, want 1", requestBody["temperature"])
}
}
func TestNormalizeModel_UsesAPIBase(t *testing.T) {
if got := normalizeModel("deepseek/deepseek-chat", "https://api.deepseek.com/v1"); got != "deepseek-chat" {
t.Fatalf("normalizeModel(deepseek) = %q, want %q", got, "deepseek-chat")
}
if got := normalizeModel("openrouter/auto", "https://openrouter.ai/api/v1"); got != "openrouter/auto" {
t.Fatalf("normalizeModel(openrouter) = %q, want %q", got, "openrouter/auto")
}
}
+45
View File
@@ -0,0 +1,45 @@
package protocoltypes
type ToolCall struct {
ID string `json:"id"`
Type string `json:"type,omitempty"`
Function *FunctionCall `json:"function,omitempty"`
Name string `json:"name,omitempty"`
Arguments map[string]interface{} `json:"arguments,omitempty"`
}
type FunctionCall struct {
Name string `json:"name"`
Arguments string `json:"arguments"`
}
type LLMResponse struct {
Content string `json:"content"`
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
FinishReason string `json:"finish_reason"`
Usage *UsageInfo `json:"usage,omitempty"`
}
type UsageInfo struct {
PromptTokens int `json:"prompt_tokens"`
CompletionTokens int `json:"completion_tokens"`
TotalTokens int `json:"total_tokens"`
}
type Message struct {
Role string `json:"role"`
Content string `json:"content"`
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
ToolCallID string `json:"tool_call_id,omitempty"`
}
type ToolDefinition struct {
Type string `json:"type"`
Function ToolFunctionDefinition `json:"function"`
}
type ToolFunctionDefinition struct {
Name string `json:"name"`
Description string `json:"description"`
Parameters map[string]interface{} `json:"parameters"`
}
+11 -43
View File
@@ -1,52 +1,20 @@
package providers
import "context"
import (
"context"
type ToolCall struct {
ID string `json:"id"`
Type string `json:"type,omitempty"`
Function *FunctionCall `json:"function,omitempty"`
Name string `json:"name,omitempty"`
Arguments map[string]interface{} `json:"arguments,omitempty"`
}
"github.com/sipeed/picoclaw/pkg/providers/protocoltypes"
)
type FunctionCall struct {
Name string `json:"name"`
Arguments string `json:"arguments"`
}
type LLMResponse struct {
Content string `json:"content"`
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
FinishReason string `json:"finish_reason"`
Usage *UsageInfo `json:"usage,omitempty"`
}
type UsageInfo struct {
PromptTokens int `json:"prompt_tokens"`
CompletionTokens int `json:"completion_tokens"`
TotalTokens int `json:"total_tokens"`
}
type Message struct {
Role string `json:"role"`
Content string `json:"content"`
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
ToolCallID string `json:"tool_call_id,omitempty"`
}
type ToolCall = protocoltypes.ToolCall
type FunctionCall = protocoltypes.FunctionCall
type LLMResponse = protocoltypes.LLMResponse
type UsageInfo = protocoltypes.UsageInfo
type Message = protocoltypes.Message
type ToolDefinition = protocoltypes.ToolDefinition
type ToolFunctionDefinition = protocoltypes.ToolFunctionDefinition
type LLMProvider interface {
Chat(ctx context.Context, messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) (*LLMResponse, error)
GetDefaultModel() string
}
type ToolDefinition struct {
Type string `json:"type"`
Function ToolFunctionDefinition `json:"function"`
}
type ToolFunctionDefinition struct {
Name string `json:"name"`
Description string `json:"description"`
Parameters map[string]interface{} `json:"parameters"`
}