From c4cbb5fb35374d0ff917baff9196746f843b99fa Mon Sep 17 00:00:00 2001 From: Jared Mahotiere Date: Tue, 17 Feb 2026 11:13:10 -0500 Subject: [PATCH] providers: finalize PR213 review fixes Phase 1: centralize protocol message/tool/response types in protocoltypes and keep compatibility aliases in providers and protocol packages. Phase 1: preserve HTTPProvider constructor compatibility and route Anthropic api_base through factory auth/provider constructors with base URL normalization. Phase 2: expand provider routing/auth tests (deepseek/nvidia/shengsuanyun, codex/claude oauth/codex-cli) and add openai_compat + anthropic coverage for proxy transport, model normalization, numeric option coercion, token-source refresh, and base URL behavior. Phase 3: apply gofmt and validate with Dockerized tests (go test ./pkg/providers/... ./pkg/migrate and go test ./...). --- pkg/providers/anthropic/provider.go | 99 +++++++------- pkg/providers/anthropic/provider_test.go | 57 ++++++++ pkg/providers/claude_provider.go | 118 ++-------------- pkg/providers/factory.go | 47 ++++++- pkg/providers/factory_test.go | 95 +++++++++++++ pkg/providers/http_provider.go | 106 +-------------- pkg/providers/openai_compat/provider.go | 136 ++++++++++--------- pkg/providers/openai_compat/provider_test.go | 85 +++++++++++- pkg/providers/protocoltypes/types.go | 45 ++++++ pkg/providers/types.go | 54 ++------ 10 files changed, 468 insertions(+), 374 deletions(-) create mode 100644 pkg/providers/protocoltypes/types.go diff --git a/pkg/providers/anthropic/provider.go b/pkg/providers/anthropic/provider.go index ca72f0180..8f46aa70c 100644 --- a/pkg/providers/anthropic/provider.go +++ b/pkg/providers/anthropic/provider.go @@ -4,74 +4,59 @@ import ( "context" "encoding/json" "fmt" + "log" + "strings" "github.com/anthropics/anthropic-sdk-go" "github.com/anthropics/anthropic-sdk-go/option" + "github.com/sipeed/picoclaw/pkg/providers/protocoltypes" ) -type ToolCall struct { - ID string `json:"id"` - Type string `json:"type,omitempty"` - Function *FunctionCall `json:"function,omitempty"` - Name string `json:"name,omitempty"` - Arguments map[string]interface{} `json:"arguments,omitempty"` -} +type ToolCall = protocoltypes.ToolCall +type FunctionCall = protocoltypes.FunctionCall +type LLMResponse = protocoltypes.LLMResponse +type UsageInfo = protocoltypes.UsageInfo +type Message = protocoltypes.Message +type ToolDefinition = protocoltypes.ToolDefinition +type ToolFunctionDefinition = protocoltypes.ToolFunctionDefinition -type FunctionCall struct { - Name string `json:"name"` - Arguments string `json:"arguments"` -} - -type LLMResponse struct { - Content string `json:"content"` - ToolCalls []ToolCall `json:"tool_calls,omitempty"` - FinishReason string `json:"finish_reason"` - Usage *UsageInfo `json:"usage,omitempty"` -} - -type UsageInfo struct { - PromptTokens int `json:"prompt_tokens"` - CompletionTokens int `json:"completion_tokens"` - TotalTokens int `json:"total_tokens"` -} - -type Message struct { - Role string `json:"role"` - Content string `json:"content"` - ToolCalls []ToolCall `json:"tool_calls,omitempty"` - ToolCallID string `json:"tool_call_id,omitempty"` -} - -type ToolDefinition struct { - Type string `json:"type"` - Function ToolFunctionDefinition `json:"function"` -} - -type ToolFunctionDefinition struct { - Name string `json:"name"` - Description string `json:"description"` - Parameters map[string]interface{} `json:"parameters"` -} +const defaultBaseURL = "https://api.anthropic.com" type Provider struct { client *anthropic.Client tokenSource func() (string, error) + baseURL string } func NewProvider(token string) *Provider { + return NewProviderWithBaseURL(token, "") +} + +func NewProviderWithBaseURL(token, apiBase string) *Provider { + baseURL := normalizeBaseURL(apiBase) client := anthropic.NewClient( option.WithAuthToken(token), - option.WithBaseURL("https://api.anthropic.com"), + option.WithBaseURL(baseURL), ) - return &Provider{client: &client} + return &Provider{ + client: &client, + baseURL: baseURL, + } } func NewProviderWithClient(client *anthropic.Client) *Provider { - return &Provider{client: client} + return &Provider{ + client: client, + baseURL: defaultBaseURL, + } } func NewProviderWithTokenSource(token string, tokenSource func() (string, error)) *Provider { - p := NewProvider(token) + return NewProviderWithTokenSourceAndBaseURL(token, tokenSource, "") +} + +func NewProviderWithTokenSourceAndBaseURL(token string, tokenSource func() (string, error), apiBase string) *Provider { + p := NewProviderWithBaseURL(token, apiBase) p.tokenSource = tokenSource return p } @@ -103,6 +88,10 @@ func (p *Provider) GetDefaultModel() string { return "claude-sonnet-4-5-20250929" } +func (p *Provider) BaseURL() string { + return p.baseURL +} + func buildParams(messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) (anthropic.MessageNewParams, error) { var system []anthropic.TextBlockParam var anthropicMessages []anthropic.MessageParam @@ -208,6 +197,7 @@ func parseResponse(resp *anthropic.Message) *LLMResponse { tu := block.AsToolUse() var args map[string]interface{} if err := json.Unmarshal(tu.Input, &args); err != nil { + log.Printf("anthropic: failed to decode tool call input for %q: %v", tu.Name, err) args = map[string]interface{}{"raw": string(tu.Input)} } toolCalls = append(toolCalls, ToolCall{ @@ -239,3 +229,20 @@ func parseResponse(resp *anthropic.Message) *LLMResponse { }, } } + +func normalizeBaseURL(apiBase string) string { + base := strings.TrimSpace(apiBase) + if base == "" { + return defaultBaseURL + } + + base = strings.TrimRight(base, "/") + if strings.HasSuffix(base, "/v1") { + base = strings.TrimSuffix(base, "/v1") + } + if base == "" { + return defaultBaseURL + } + + return base +} diff --git a/pkg/providers/anthropic/provider_test.go b/pkg/providers/anthropic/provider_test.go index 01b4fe663..6a1dabafb 100644 --- a/pkg/providers/anthropic/provider_test.go +++ b/pkg/providers/anthropic/provider_test.go @@ -4,6 +4,7 @@ import ( "encoding/json" "net/http" "net/http/httptest" + "sync/atomic" "testing" "github.com/anthropics/anthropic-sdk-go" @@ -199,6 +200,62 @@ func TestProvider_GetDefaultModel(t *testing.T) { } } +func TestProvider_NewProviderWithBaseURL_NormalizesV1Suffix(t *testing.T) { + p := NewProviderWithBaseURL("token", "https://api.anthropic.com/v1/") + if got := p.BaseURL(); got != "https://api.anthropic.com" { + t.Fatalf("BaseURL() = %q, want %q", got, "https://api.anthropic.com") + } +} + +func TestProvider_ChatUsesTokenSource(t *testing.T) { + var requests int32 + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/v1/messages" { + http.Error(w, "not found", http.StatusNotFound) + return + } + atomic.AddInt32(&requests, 1) + + if got := r.Header.Get("Authorization"); got != "Bearer refreshed-token" { + http.Error(w, "unauthorized", http.StatusUnauthorized) + return + } + + var reqBody map[string]interface{} + json.NewDecoder(r.Body).Decode(&reqBody) + + resp := map[string]interface{}{ + "id": "msg_test", + "type": "message", + "role": "assistant", + "model": reqBody["model"], + "stop_reason": "end_turn", + "content": []map[string]interface{}{ + {"type": "text", "text": "ok"}, + }, + "usage": map[string]interface{}{ + "input_tokens": 1, + "output_tokens": 1, + }, + } + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(resp) + })) + defer server.Close() + + p := NewProviderWithTokenSourceAndBaseURL("stale-token", func() (string, error) { + return "refreshed-token", nil + }, server.URL) + + _, err := p.Chat(t.Context(), []Message{{Role: "user", Content: "hello"}}, nil, "claude-sonnet-4-5-20250929", map[string]interface{}{}) + if err != nil { + t.Fatalf("Chat() error: %v", err) + } + if got := atomic.LoadInt32(&requests); got != 1 { + t.Fatalf("requests = %d, want 1", got) + } +} + func createAnthropicTestClient(baseURL, token string) *anthropic.Client { c := anthropic.NewClient( anthropicoption.WithAuthToken(token), diff --git a/pkg/providers/claude_provider.go b/pkg/providers/claude_provider.go index 16f1884c5..c72f5b0ef 100644 --- a/pkg/providers/claude_provider.go +++ b/pkg/providers/claude_provider.go @@ -3,8 +3,6 @@ package providers import ( "context" "fmt" - - "github.com/sipeed/picoclaw/pkg/auth" anthropicprovider "github.com/sipeed/picoclaw/pkg/providers/anthropic" ) @@ -18,28 +16,34 @@ func NewClaudeProvider(token string) *ClaudeProvider { } } +func NewClaudeProviderWithBaseURL(token, apiBase string) *ClaudeProvider { + return &ClaudeProvider{ + delegate: anthropicprovider.NewProviderWithBaseURL(token, apiBase), + } +} + func NewClaudeProviderWithTokenSource(token string, tokenSource func() (string, error)) *ClaudeProvider { return &ClaudeProvider{ delegate: anthropicprovider.NewProviderWithTokenSource(token, tokenSource), } } +func NewClaudeProviderWithTokenSourceAndBaseURL(token string, tokenSource func() (string, error), apiBase string) *ClaudeProvider { + return &ClaudeProvider{ + delegate: anthropicprovider.NewProviderWithTokenSourceAndBaseURL(token, tokenSource, apiBase), + } +} + func newClaudeProviderWithDelegate(delegate *anthropicprovider.Provider) *ClaudeProvider { return &ClaudeProvider{delegate: delegate} } func (p *ClaudeProvider) Chat(ctx context.Context, messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) (*LLMResponse, error) { - resp, err := p.delegate.Chat( - ctx, - toAnthropicProviderMessages(messages), - toAnthropicProviderTools(tools), - model, - options, - ) + resp, err := p.delegate.Chat(ctx, messages, tools, model, options) if err != nil { return nil, err } - return fromAnthropicProviderResponse(resp), nil + return resp, nil } func (p *ClaudeProvider) GetDefaultModel() string { @@ -48,7 +52,7 @@ func (p *ClaudeProvider) GetDefaultModel() string { func createClaudeTokenSource() func() (string, error) { return func() (string, error) { - cred, err := auth.GetCredential("anthropic") + cred, err := getCredential("anthropic") if err != nil { return "", fmt.Errorf("loading auth credentials: %w", err) } @@ -58,95 +62,3 @@ func createClaudeTokenSource() func() (string, error) { return cred.AccessToken, nil } } - -func toAnthropicProviderMessages(messages []Message) []anthropicprovider.Message { - out := make([]anthropicprovider.Message, 0, len(messages)) - for _, msg := range messages { - out = append(out, anthropicprovider.Message{ - Role: msg.Role, - Content: msg.Content, - ToolCalls: toAnthropicProviderToolCalls(msg.ToolCalls), - ToolCallID: msg.ToolCallID, - }) - } - return out -} - -func toAnthropicProviderTools(tools []ToolDefinition) []anthropicprovider.ToolDefinition { - out := make([]anthropicprovider.ToolDefinition, 0, len(tools)) - for _, t := range tools { - out = append(out, anthropicprovider.ToolDefinition{ - Type: t.Type, - Function: anthropicprovider.ToolFunctionDefinition{ - Name: t.Function.Name, - Description: t.Function.Description, - Parameters: t.Function.Parameters, - }, - }) - } - return out -} - -func toAnthropicProviderToolCalls(toolCalls []ToolCall) []anthropicprovider.ToolCall { - out := make([]anthropicprovider.ToolCall, 0, len(toolCalls)) - for _, tc := range toolCalls { - var fn *anthropicprovider.FunctionCall - if tc.Function != nil { - fn = &anthropicprovider.FunctionCall{ - Name: tc.Function.Name, - Arguments: tc.Function.Arguments, - } - } - out = append(out, anthropicprovider.ToolCall{ - ID: tc.ID, - Type: tc.Type, - Function: fn, - Name: tc.Name, - Arguments: tc.Arguments, - }) - } - return out -} - -func fromAnthropicProviderResponse(resp *anthropicprovider.LLMResponse) *LLMResponse { - if resp == nil { - return &LLMResponse{} - } - - var usage *UsageInfo - if resp.Usage != nil { - usage = &UsageInfo{ - PromptTokens: resp.Usage.PromptTokens, - CompletionTokens: resp.Usage.CompletionTokens, - TotalTokens: resp.Usage.TotalTokens, - } - } - - return &LLMResponse{ - Content: resp.Content, - ToolCalls: fromAnthropicProviderToolCalls(resp.ToolCalls), - FinishReason: resp.FinishReason, - Usage: usage, - } -} - -func fromAnthropicProviderToolCalls(toolCalls []anthropicprovider.ToolCall) []ToolCall { - out := make([]ToolCall, 0, len(toolCalls)) - for _, tc := range toolCalls { - var fn *FunctionCall - if tc.Function != nil { - fn = &FunctionCall{ - Name: tc.Function.Name, - Arguments: tc.Function.Arguments, - } - } - out = append(out, ToolCall{ - ID: tc.ID, - Type: tc.Type, - Function: fn, - Name: tc.Name, - Arguments: tc.Arguments, - }) - } - return out -} diff --git a/pkg/providers/factory.go b/pkg/providers/factory.go index 28609c4b3..67a347721 100644 --- a/pkg/providers/factory.go +++ b/pkg/providers/factory.go @@ -8,6 +8,10 @@ import ( "github.com/sipeed/picoclaw/pkg/config" ) +const defaultAnthropicAPIBase = "https://api.anthropic.com/v1" + +var getCredential = auth.GetCredential + type providerType int const ( @@ -30,19 +34,22 @@ type providerSelection struct { connectMode string } -func createClaudeAuthProvider() (LLMProvider, error) { - cred, err := auth.GetCredential("anthropic") +func createClaudeAuthProvider(apiBase string) (LLMProvider, error) { + if apiBase == "" { + apiBase = defaultAnthropicAPIBase + } + cred, err := getCredential("anthropic") if err != nil { return nil, fmt.Errorf("loading auth credentials: %w", err) } if cred == nil { return nil, fmt.Errorf("no credentials for anthropic. Run: picoclaw auth login --provider anthropic") } - return NewClaudeProviderWithTokenSource(cred.AccessToken, createClaudeTokenSource()), nil + return NewClaudeProviderWithTokenSourceAndBaseURL(cred.AccessToken, createClaudeTokenSource(), apiBase), nil } func createCodexAuthProvider() (LLMProvider, error) { - cred, err := auth.GetCredential("openai") + cred, err := getCredential("openai") if err != nil { return nil, fmt.Errorf("loading auth credentials: %w", err) } @@ -69,6 +76,7 @@ func resolveProviderSelection(cfg *config.Config) (providerSelection, error) { if cfg.Providers.Groq.APIKey != "" { sel.apiKey = cfg.Providers.Groq.APIKey sel.apiBase = cfg.Providers.Groq.APIBase + sel.proxy = cfg.Providers.Groq.Proxy if sel.apiBase == "" { sel.apiBase = "https://api.groq.com/openai/v1" } @@ -85,6 +93,7 @@ func resolveProviderSelection(cfg *config.Config) (providerSelection, error) { } sel.apiKey = cfg.Providers.OpenAI.APIKey sel.apiBase = cfg.Providers.OpenAI.APIBase + sel.proxy = cfg.Providers.OpenAI.Proxy if sel.apiBase == "" { sel.apiBase = "https://api.openai.com/v1" } @@ -92,18 +101,24 @@ func resolveProviderSelection(cfg *config.Config) (providerSelection, error) { case "anthropic", "claude": if cfg.Providers.Anthropic.APIKey != "" || cfg.Providers.Anthropic.AuthMethod != "" { if cfg.Providers.Anthropic.AuthMethod == "oauth" || cfg.Providers.Anthropic.AuthMethod == "token" { + sel.apiBase = cfg.Providers.Anthropic.APIBase + if sel.apiBase == "" { + sel.apiBase = defaultAnthropicAPIBase + } sel.providerType = providerTypeClaudeAuth return sel, nil } sel.apiKey = cfg.Providers.Anthropic.APIKey sel.apiBase = cfg.Providers.Anthropic.APIBase + sel.proxy = cfg.Providers.Anthropic.Proxy if sel.apiBase == "" { - sel.apiBase = "https://api.anthropic.com/v1" + sel.apiBase = defaultAnthropicAPIBase } } case "openrouter": if cfg.Providers.OpenRouter.APIKey != "" { sel.apiKey = cfg.Providers.OpenRouter.APIKey + sel.proxy = cfg.Providers.OpenRouter.Proxy if cfg.Providers.OpenRouter.APIBase != "" { sel.apiBase = cfg.Providers.OpenRouter.APIBase } else { @@ -114,6 +129,7 @@ func resolveProviderSelection(cfg *config.Config) (providerSelection, error) { if cfg.Providers.Zhipu.APIKey != "" { sel.apiKey = cfg.Providers.Zhipu.APIKey sel.apiBase = cfg.Providers.Zhipu.APIBase + sel.proxy = cfg.Providers.Zhipu.Proxy if sel.apiBase == "" { sel.apiBase = "https://open.bigmodel.cn/api/paas/v4" } @@ -122,6 +138,7 @@ func resolveProviderSelection(cfg *config.Config) (providerSelection, error) { if cfg.Providers.Gemini.APIKey != "" { sel.apiKey = cfg.Providers.Gemini.APIKey sel.apiBase = cfg.Providers.Gemini.APIBase + sel.proxy = cfg.Providers.Gemini.Proxy if sel.apiBase == "" { sel.apiBase = "https://generativelanguage.googleapis.com/v1beta" } @@ -130,15 +147,26 @@ func resolveProviderSelection(cfg *config.Config) (providerSelection, error) { if cfg.Providers.VLLM.APIBase != "" { sel.apiKey = cfg.Providers.VLLM.APIKey sel.apiBase = cfg.Providers.VLLM.APIBase + sel.proxy = cfg.Providers.VLLM.Proxy } case "shengsuanyun": if cfg.Providers.ShengSuanYun.APIKey != "" { sel.apiKey = cfg.Providers.ShengSuanYun.APIKey sel.apiBase = cfg.Providers.ShengSuanYun.APIBase + sel.proxy = cfg.Providers.ShengSuanYun.Proxy if sel.apiBase == "" { sel.apiBase = "https://router.shengsuanyun.com/api/v1" } } + case "nvidia": + if cfg.Providers.Nvidia.APIKey != "" { + sel.apiKey = cfg.Providers.Nvidia.APIKey + sel.apiBase = cfg.Providers.Nvidia.APIBase + sel.proxy = cfg.Providers.Nvidia.Proxy + if sel.apiBase == "" { + sel.apiBase = "https://integrate.api.nvidia.com/v1" + } + } case "claude-cli", "claude-code", "claudecode": workspace := cfg.WorkspacePath() if workspace == "" { @@ -159,6 +187,7 @@ func resolveProviderSelection(cfg *config.Config) (providerSelection, error) { if cfg.Providers.DeepSeek.APIKey != "" { sel.apiKey = cfg.Providers.DeepSeek.APIKey sel.apiBase = cfg.Providers.DeepSeek.APIBase + sel.proxy = cfg.Providers.DeepSeek.Proxy if sel.apiBase == "" { sel.apiBase = "https://api.deepseek.com/v1" } @@ -204,6 +233,10 @@ func resolveProviderSelection(cfg *config.Config) (providerSelection, error) { case (strings.Contains(lowerModel, "claude") || strings.HasPrefix(model, "anthropic/")) && (cfg.Providers.Anthropic.APIKey != "" || cfg.Providers.Anthropic.AuthMethod != ""): if cfg.Providers.Anthropic.AuthMethod == "oauth" || cfg.Providers.Anthropic.AuthMethod == "token" { + sel.apiBase = cfg.Providers.Anthropic.APIBase + if sel.apiBase == "" { + sel.apiBase = defaultAnthropicAPIBase + } sel.providerType = providerTypeClaudeAuth return sel, nil } @@ -211,7 +244,7 @@ func resolveProviderSelection(cfg *config.Config) (providerSelection, error) { sel.apiBase = cfg.Providers.Anthropic.APIBase sel.proxy = cfg.Providers.Anthropic.Proxy if sel.apiBase == "" { - sel.apiBase = "https://api.anthropic.com/v1" + sel.apiBase = defaultAnthropicAPIBase } case (strings.Contains(lowerModel, "gpt") || strings.HasPrefix(model, "openai/")) && (cfg.Providers.OpenAI.APIKey != "" || cfg.Providers.OpenAI.AuthMethod != ""): @@ -303,7 +336,7 @@ func CreateProvider(cfg *config.Config) (LLMProvider, error) { switch sel.providerType { case providerTypeClaudeAuth: - return createClaudeAuthProvider() + return createClaudeAuthProvider(sel.apiBase) case providerTypeCodexAuth: return createCodexAuthProvider() case providerTypeCodexCLIToken: diff --git a/pkg/providers/factory_test.go b/pkg/providers/factory_test.go index c1f14291d..e31737eb9 100644 --- a/pkg/providers/factory_test.go +++ b/pkg/providers/factory_test.go @@ -4,6 +4,7 @@ import ( "strings" "testing" + "github.com/sipeed/picoclaw/pkg/auth" "github.com/sipeed/picoclaw/pkg/config" ) @@ -32,6 +33,40 @@ func TestResolveProviderSelection(t *testing.T) { wantType: providerTypeGitHubCopilot, wantAPIBase: "localhost:4321", }, + { + name: "explicit deepseek provider uses deepseek defaults", + setup: func(cfg *config.Config) { + cfg.Agents.Defaults.Provider = "deepseek" + cfg.Agents.Defaults.Model = "deepseek/deepseek-chat" + cfg.Providers.DeepSeek.APIKey = "deepseek-key" + cfg.Providers.DeepSeek.Proxy = "http://127.0.0.1:7890" + }, + wantType: providerTypeHTTPCompat, + wantAPIBase: "https://api.deepseek.com/v1", + wantProxy: "http://127.0.0.1:7890", + }, + { + name: "explicit shengsuanyun provider uses defaults", + setup: func(cfg *config.Config) { + cfg.Agents.Defaults.Provider = "shengsuanyun" + cfg.Providers.ShengSuanYun.APIKey = "ssy-key" + cfg.Providers.ShengSuanYun.Proxy = "http://127.0.0.1:7890" + }, + wantType: providerTypeHTTPCompat, + wantAPIBase: "https://router.shengsuanyun.com/api/v1", + wantProxy: "http://127.0.0.1:7890", + }, + { + name: "explicit nvidia provider uses defaults", + setup: func(cfg *config.Config) { + cfg.Agents.Defaults.Provider = "nvidia" + cfg.Providers.Nvidia.APIKey = "nvapi-test" + cfg.Providers.Nvidia.Proxy = "http://127.0.0.1:7890" + }, + wantType: providerTypeHTTPCompat, + wantAPIBase: "https://integrate.api.nvidia.com/v1", + wantProxy: "http://127.0.0.1:7890", + }, { name: "openrouter model uses openrouter defaults", setup: func(cfg *config.Config) { @@ -202,3 +237,63 @@ func TestCreateProviderReturnsCodexProviderForCodexCliAuthMethod(t *testing.T) { t.Fatalf("provider type = %T, want *CodexProvider", provider) } } + +func TestCreateProviderReturnsClaudeProviderForAnthropicOAuth(t *testing.T) { + originalGetCredential := getCredential + t.Cleanup(func() { getCredential = originalGetCredential }) + + getCredential = func(provider string) (*auth.AuthCredential, error) { + if provider != "anthropic" { + t.Fatalf("provider = %q, want anthropic", provider) + } + return &auth.AuthCredential{ + AccessToken: "anthropic-token", + }, nil + } + + cfg := config.DefaultConfig() + cfg.Agents.Defaults.Provider = "anthropic" + cfg.Providers.Anthropic.AuthMethod = "oauth" + cfg.Providers.Anthropic.APIBase = "https://proxy.example.com/v1" + + provider, err := CreateProvider(cfg) + if err != nil { + t.Fatalf("CreateProvider() error = %v", err) + } + + claudeProvider, ok := provider.(*ClaudeProvider) + if !ok { + t.Fatalf("provider type = %T, want *ClaudeProvider", provider) + } + if got := claudeProvider.delegate.BaseURL(); got != "https://proxy.example.com" { + t.Fatalf("anthropic baseURL = %q, want %q", got, "https://proxy.example.com") + } +} + +func TestCreateProviderReturnsCodexProviderForOpenAIOAuth(t *testing.T) { + originalGetCredential := getCredential + t.Cleanup(func() { getCredential = originalGetCredential }) + + getCredential = func(provider string) (*auth.AuthCredential, error) { + if provider != "openai" { + t.Fatalf("provider = %q, want openai", provider) + } + return &auth.AuthCredential{ + AccessToken: "openai-token", + AccountID: "acct_123", + }, nil + } + + cfg := config.DefaultConfig() + cfg.Agents.Defaults.Provider = "openai" + cfg.Providers.OpenAI.AuthMethod = "oauth" + + provider, err := CreateProvider(cfg) + if err != nil { + t.Fatalf("CreateProvider() error = %v", err) + } + + if _, ok := provider.(*CodexProvider); !ok { + t.Fatalf("provider type = %T, want *CodexProvider", provider) + } +} diff --git a/pkg/providers/http_provider.go b/pkg/providers/http_provider.go index 0f7f646d8..e39a19e90 100644 --- a/pkg/providers/http_provider.go +++ b/pkg/providers/http_provider.go @@ -15,116 +15,16 @@ type HTTPProvider struct { delegate *openai_compat.Provider } -func NewHTTPProvider(apiKey, apiBase string, proxy ...string) *HTTPProvider { - proxyURL := "" - if len(proxy) > 0 { - proxyURL = proxy[0] - } +func NewHTTPProvider(apiKey, apiBase, proxy string) *HTTPProvider { return &HTTPProvider{ - delegate: openai_compat.NewProvider(apiKey, apiBase, proxyURL), + delegate: openai_compat.NewProvider(apiKey, apiBase, proxy), } } func (p *HTTPProvider) Chat(ctx context.Context, messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) (*LLMResponse, error) { - compatResp, err := p.delegate.Chat(ctx, toOpenAICompatMessages(messages), toOpenAICompatTools(tools), model, options) - if err != nil { - return nil, err - } - return fromOpenAICompatResponse(compatResp), nil + return p.delegate.Chat(ctx, messages, tools, model, options) } func (p *HTTPProvider) GetDefaultModel() string { return "" } - -func toOpenAICompatMessages(messages []Message) []openai_compat.Message { - out := make([]openai_compat.Message, 0, len(messages)) - for _, msg := range messages { - out = append(out, openai_compat.Message{ - Role: msg.Role, - Content: msg.Content, - ToolCalls: toOpenAICompatToolCalls(msg.ToolCalls), - ToolCallID: msg.ToolCallID, - }) - } - return out -} - -func toOpenAICompatTools(tools []ToolDefinition) []openai_compat.ToolDefinition { - out := make([]openai_compat.ToolDefinition, 0, len(tools)) - for _, t := range tools { - out = append(out, openai_compat.ToolDefinition{ - Type: t.Type, - Function: openai_compat.ToolFunctionDefinition{ - Name: t.Function.Name, - Description: t.Function.Description, - Parameters: t.Function.Parameters, - }, - }) - } - return out -} - -func toOpenAICompatToolCalls(toolCalls []ToolCall) []openai_compat.ToolCall { - out := make([]openai_compat.ToolCall, 0, len(toolCalls)) - for _, tc := range toolCalls { - var fn *openai_compat.FunctionCall - if tc.Function != nil { - fn = &openai_compat.FunctionCall{ - Name: tc.Function.Name, - Arguments: tc.Function.Arguments, - } - } - out = append(out, openai_compat.ToolCall{ - ID: tc.ID, - Type: tc.Type, - Function: fn, - Name: tc.Name, - Arguments: tc.Arguments, - }) - } - return out -} - -func fromOpenAICompatResponse(resp *openai_compat.LLMResponse) *LLMResponse { - if resp == nil { - return &LLMResponse{} - } - - var usage *UsageInfo - if resp.Usage != nil { - usage = &UsageInfo{ - PromptTokens: resp.Usage.PromptTokens, - CompletionTokens: resp.Usage.CompletionTokens, - TotalTokens: resp.Usage.TotalTokens, - } - } - - return &LLMResponse{ - Content: resp.Content, - ToolCalls: fromOpenAICompatToolCalls(resp.ToolCalls), - FinishReason: resp.FinishReason, - Usage: usage, - } -} - -func fromOpenAICompatToolCalls(toolCalls []openai_compat.ToolCall) []ToolCall { - out := make([]ToolCall, 0, len(toolCalls)) - for _, tc := range toolCalls { - var fn *FunctionCall - if tc.Function != nil { - fn = &FunctionCall{ - Name: tc.Function.Name, - Arguments: tc.Function.Arguments, - } - } - out = append(out, ToolCall{ - ID: tc.ID, - Type: tc.Type, - Function: fn, - Name: tc.Name, - Arguments: tc.Arguments, - }) - } - return out -} diff --git a/pkg/providers/openai_compat/provider.go b/pkg/providers/openai_compat/provider.go index 7bc8e26be..9b404dd77 100644 --- a/pkg/providers/openai_compat/provider.go +++ b/pkg/providers/openai_compat/provider.go @@ -6,55 +6,22 @@ import ( "encoding/json" "fmt" "io" + "log" "net/http" "net/url" "strings" "time" + + "github.com/sipeed/picoclaw/pkg/providers/protocoltypes" ) -type ToolCall struct { - ID string `json:"id"` - Type string `json:"type,omitempty"` - Function *FunctionCall `json:"function,omitempty"` - Name string `json:"name,omitempty"` - Arguments map[string]interface{} `json:"arguments,omitempty"` -} - -type FunctionCall struct { - Name string `json:"name"` - Arguments string `json:"arguments"` -} - -type LLMResponse struct { - Content string `json:"content"` - ToolCalls []ToolCall `json:"tool_calls,omitempty"` - FinishReason string `json:"finish_reason"` - Usage *UsageInfo `json:"usage,omitempty"` -} - -type UsageInfo struct { - PromptTokens int `json:"prompt_tokens"` - CompletionTokens int `json:"completion_tokens"` - TotalTokens int `json:"total_tokens"` -} - -type Message struct { - Role string `json:"role"` - Content string `json:"content"` - ToolCalls []ToolCall `json:"tool_calls,omitempty"` - ToolCallID string `json:"tool_call_id,omitempty"` -} - -type ToolDefinition struct { - Type string `json:"type"` - Function ToolFunctionDefinition `json:"function"` -} - -type ToolFunctionDefinition struct { - Name string `json:"name"` - Description string `json:"description"` - Parameters map[string]interface{} `json:"parameters"` -} +type ToolCall = protocoltypes.ToolCall +type FunctionCall = protocoltypes.FunctionCall +type LLMResponse = protocoltypes.LLMResponse +type UsageInfo = protocoltypes.UsageInfo +type Message = protocoltypes.Message +type ToolDefinition = protocoltypes.ToolDefinition +type ToolFunctionDefinition = protocoltypes.ToolFunctionDefinition type Provider struct { apiKey string @@ -62,21 +29,19 @@ type Provider struct { httpClient *http.Client } -func NewProvider(apiKey, apiBase string, proxy ...string) *Provider { - proxyURL := "" - if len(proxy) > 0 { - proxyURL = proxy[0] - } +func NewProvider(apiKey, apiBase, proxy string) *Provider { client := &http.Client{ Timeout: 120 * time.Second, } - if proxyURL != "" { - parsed, err := url.Parse(proxyURL) + if proxy != "" { + parsed, err := url.Parse(proxy) if err == nil { client.Transport = &http.Transport{ Proxy: http.ProxyURL(parsed), } + } else { + log.Printf("openai_compat: invalid proxy URL %q: %v", proxy, err) } } @@ -92,13 +57,7 @@ func (p *Provider) Chat(ctx context.Context, messages []Message, tools []ToolDef return nil, fmt.Errorf("API base not configured") } - // Strip provider prefix for OpenAI-compatible backends. - if idx := strings.Index(model, "/"); idx != -1 { - prefix := model[:idx] - if prefix == "moonshot" || prefix == "nvidia" || prefix == "groq" || prefix == "ollama" { - model = model[idx+1:] - } - } + model = normalizeModel(model, p.apiBase) requestBody := map[string]interface{}{ "model": model, @@ -110,7 +69,7 @@ func (p *Provider) Chat(ctx context.Context, messages []Message, tools []ToolDef requestBody["tool_choice"] = "auto" } - if maxTokens, ok := options["max_tokens"].(int); ok { + if maxTokens, ok := asInt(options["max_tokens"]); ok { lowerModel := strings.ToLower(model) if strings.Contains(lowerModel, "glm") || strings.Contains(lowerModel, "o1") { requestBody["max_completion_tokens"] = maxTokens @@ -119,7 +78,7 @@ func (p *Provider) Chat(ctx context.Context, messages []Message, tools []ToolDef } } - if temperature, ok := options["temperature"].(float64); ok { + if temperature, ok := asFloat(options["temperature"]); ok { lowerModel := strings.ToLower(model) // Kimi k2 models only support temperature=1. if strings.Contains(lowerModel, "kimi") && strings.Contains(lowerModel, "k2") { @@ -198,17 +157,11 @@ func parseResponse(body []byte) (*LLMResponse, error) { arguments := make(map[string]interface{}) name := "" - if tc.Type == "function" && tc.Function != nil { - name = tc.Function.Name - if tc.Function.Arguments != "" { - if err := json.Unmarshal([]byte(tc.Function.Arguments), &arguments); err != nil { - arguments["raw"] = tc.Function.Arguments - } - } - } else if tc.Function != nil { + if tc.Function != nil { name = tc.Function.Name if tc.Function.Arguments != "" { if err := json.Unmarshal([]byte(tc.Function.Arguments), &arguments); err != nil { + log.Printf("openai_compat: failed to decode tool call arguments for %q: %v", name, err) arguments["raw"] = tc.Function.Arguments } } @@ -228,3 +181,52 @@ func parseResponse(body []byte) (*LLMResponse, error) { Usage: apiResponse.Usage, }, nil } + +func normalizeModel(model, apiBase string) string { + idx := strings.Index(model, "/") + if idx == -1 { + return model + } + + if strings.Contains(strings.ToLower(apiBase), "openrouter.ai") { + return model + } + + prefix := strings.ToLower(model[:idx]) + switch prefix { + case "moonshot", "nvidia", "groq", "ollama", "deepseek", "google", "openrouter", "zhipu": + return model[idx+1:] + default: + return model + } +} + +func asInt(v interface{}) (int, bool) { + switch val := v.(type) { + case int: + return val, true + case int64: + return int(val), true + case float64: + return int(val), true + case float32: + return int(val), true + default: + return 0, false + } +} + +func asFloat(v interface{}) (float64, bool) { + switch val := v.(type) { + case float64: + return val, true + case float32: + return float64(val), true + case int: + return float64(val), true + case int64: + return float64(val), true + default: + return 0, false + } +} diff --git a/pkg/providers/openai_compat/provider_test.go b/pkg/providers/openai_compat/provider_test.go index e5926458b..94779b39c 100644 --- a/pkg/providers/openai_compat/provider_test.go +++ b/pkg/providers/openai_compat/provider_test.go @@ -4,6 +4,7 @@ import ( "encoding/json" "net/http" "net/http/httptest" + "net/url" "testing" ) @@ -32,7 +33,7 @@ func TestProviderChat_UsesMaxCompletionTokensForGLM(t *testing.T) { })) defer server.Close() - p := NewProvider("key", server.URL) + p := NewProvider("key", server.URL, "") _, err := p.Chat(t.Context(), []Message{{Role: "user", Content: "hi"}}, nil, "glm-4.7", map[string]interface{}{"max_tokens": 1234}) if err != nil { t.Fatalf("Chat() error = %v", err) @@ -78,7 +79,7 @@ func TestProviderChat_ParsesToolCalls(t *testing.T) { })) defer server.Close() - p := NewProvider("key", server.URL) + p := NewProvider("key", server.URL, "") out, err := p.Chat(t.Context(), []Message{{Role: "user", Content: "hi"}}, nil, "gpt-4o", nil) if err != nil { t.Fatalf("Chat() error = %v", err) @@ -100,7 +101,7 @@ func TestProviderChat_HTTPError(t *testing.T) { })) defer server.Close() - p := NewProvider("key", server.URL) + p := NewProvider("key", server.URL, "") _, err := p.Chat(t.Context(), []Message{{Role: "user", Content: "hi"}}, nil, "gpt-4o", nil) if err == nil { t.Fatal("expected error, got nil") @@ -128,7 +129,7 @@ func TestProviderChat_StripsMoonshotPrefixAndNormalizesKimiTemperature(t *testin })) defer server.Close() - p := NewProvider("key", server.URL) + p := NewProvider("key", server.URL, "") _, err := p.Chat( t.Context(), []Message{{Role: "user", Content: "hi"}}, @@ -164,6 +165,11 @@ func TestProviderChat_StripsGroqAndOllamaPrefixes(t *testing.T) { input: "ollama/qwen2.5:14b", wantModel: "qwen2.5:14b", }, + { + name: "strips deepseek prefix", + input: "deepseek/deepseek-chat", + wantModel: "deepseek-chat", + }, } for _, tt := range tests { @@ -188,7 +194,7 @@ func TestProviderChat_StripsGroqAndOllamaPrefixes(t *testing.T) { })) defer server.Close() - p := NewProvider("key", server.URL) + p := NewProvider("key", server.URL, "") _, err := p.Chat(t.Context(), []Message{{Role: "user", Content: "hi"}}, nil, tt.input, nil) if err != nil { t.Fatalf("Chat() error = %v", err) @@ -200,3 +206,72 @@ func TestProviderChat_StripsGroqAndOllamaPrefixes(t *testing.T) { }) } } + +func TestProvider_ProxyConfigured(t *testing.T) { + proxyURL := "http://127.0.0.1:8080" + p := NewProvider("key", "https://example.com", proxyURL) + + transport, ok := p.httpClient.Transport.(*http.Transport) + if !ok || transport == nil { + t.Fatalf("expected http transport with proxy, got %T", p.httpClient.Transport) + } + + req := &http.Request{URL: &url.URL{Scheme: "https", Host: "api.example.com"}} + gotProxy, err := transport.Proxy(req) + if err != nil { + t.Fatalf("proxy function returned error: %v", err) + } + if gotProxy == nil || gotProxy.String() != proxyURL { + t.Fatalf("proxy = %v, want %s", gotProxy, proxyURL) + } +} + +func TestProviderChat_AcceptsNumericOptionTypes(t *testing.T) { + var requestBody map[string]interface{} + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if err := json.NewDecoder(r.Body).Decode(&requestBody); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + resp := map[string]interface{}{ + "choices": []map[string]interface{}{ + { + "message": map[string]interface{}{"content": "ok"}, + "finish_reason": "stop", + }, + }, + } + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(resp) + })) + defer server.Close() + + p := NewProvider("key", server.URL, "") + _, err := p.Chat( + t.Context(), + []Message{{Role: "user", Content: "hi"}}, + nil, + "gpt-4o", + map[string]interface{}{"max_tokens": float64(512), "temperature": 1}, + ) + if err != nil { + t.Fatalf("Chat() error = %v", err) + } + + if requestBody["max_tokens"] != float64(512) { + t.Fatalf("max_tokens = %v, want 512", requestBody["max_tokens"]) + } + if requestBody["temperature"] != float64(1) { + t.Fatalf("temperature = %v, want 1", requestBody["temperature"]) + } +} + +func TestNormalizeModel_UsesAPIBase(t *testing.T) { + if got := normalizeModel("deepseek/deepseek-chat", "https://api.deepseek.com/v1"); got != "deepseek-chat" { + t.Fatalf("normalizeModel(deepseek) = %q, want %q", got, "deepseek-chat") + } + if got := normalizeModel("openrouter/auto", "https://openrouter.ai/api/v1"); got != "openrouter/auto" { + t.Fatalf("normalizeModel(openrouter) = %q, want %q", got, "openrouter/auto") + } +} diff --git a/pkg/providers/protocoltypes/types.go b/pkg/providers/protocoltypes/types.go new file mode 100644 index 000000000..6b33ae734 --- /dev/null +++ b/pkg/providers/protocoltypes/types.go @@ -0,0 +1,45 @@ +package protocoltypes + +type ToolCall struct { + ID string `json:"id"` + Type string `json:"type,omitempty"` + Function *FunctionCall `json:"function,omitempty"` + Name string `json:"name,omitempty"` + Arguments map[string]interface{} `json:"arguments,omitempty"` +} + +type FunctionCall struct { + Name string `json:"name"` + Arguments string `json:"arguments"` +} + +type LLMResponse struct { + Content string `json:"content"` + ToolCalls []ToolCall `json:"tool_calls,omitempty"` + FinishReason string `json:"finish_reason"` + Usage *UsageInfo `json:"usage,omitempty"` +} + +type UsageInfo struct { + PromptTokens int `json:"prompt_tokens"` + CompletionTokens int `json:"completion_tokens"` + TotalTokens int `json:"total_tokens"` +} + +type Message struct { + Role string `json:"role"` + Content string `json:"content"` + ToolCalls []ToolCall `json:"tool_calls,omitempty"` + ToolCallID string `json:"tool_call_id,omitempty"` +} + +type ToolDefinition struct { + Type string `json:"type"` + Function ToolFunctionDefinition `json:"function"` +} + +type ToolFunctionDefinition struct { + Name string `json:"name"` + Description string `json:"description"` + Parameters map[string]interface{} `json:"parameters"` +} diff --git a/pkg/providers/types.go b/pkg/providers/types.go index 88b62e975..221a842fa 100644 --- a/pkg/providers/types.go +++ b/pkg/providers/types.go @@ -1,52 +1,20 @@ package providers -import "context" +import ( + "context" -type ToolCall struct { - ID string `json:"id"` - Type string `json:"type,omitempty"` - Function *FunctionCall `json:"function,omitempty"` - Name string `json:"name,omitempty"` - Arguments map[string]interface{} `json:"arguments,omitempty"` -} + "github.com/sipeed/picoclaw/pkg/providers/protocoltypes" +) -type FunctionCall struct { - Name string `json:"name"` - Arguments string `json:"arguments"` -} - -type LLMResponse struct { - Content string `json:"content"` - ToolCalls []ToolCall `json:"tool_calls,omitempty"` - FinishReason string `json:"finish_reason"` - Usage *UsageInfo `json:"usage,omitempty"` -} - -type UsageInfo struct { - PromptTokens int `json:"prompt_tokens"` - CompletionTokens int `json:"completion_tokens"` - TotalTokens int `json:"total_tokens"` -} - -type Message struct { - Role string `json:"role"` - Content string `json:"content"` - ToolCalls []ToolCall `json:"tool_calls,omitempty"` - ToolCallID string `json:"tool_call_id,omitempty"` -} +type ToolCall = protocoltypes.ToolCall +type FunctionCall = protocoltypes.FunctionCall +type LLMResponse = protocoltypes.LLMResponse +type UsageInfo = protocoltypes.UsageInfo +type Message = protocoltypes.Message +type ToolDefinition = protocoltypes.ToolDefinition +type ToolFunctionDefinition = protocoltypes.ToolFunctionDefinition type LLMProvider interface { Chat(ctx context.Context, messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) (*LLMResponse, error) GetDefaultModel() string } - -type ToolDefinition struct { - Type string `json:"type"` - Function ToolFunctionDefinition `json:"function"` -} - -type ToolFunctionDefinition struct { - Name string `json:"name"` - Description string `json:"description"` - Parameters map[string]interface{} `json:"parameters"` -}