From a6e885bb473a20d671ed1dab5e8e8ea9bb8cd399 Mon Sep 17 00:00:00 2001 From: Jared Mahotiere Date: Sun, 15 Feb 2026 08:04:07 -0500 Subject: [PATCH 1/4] refactor(providers): extract protocol factory and openai-compat transport --- pkg/providers/factory.go | 291 ++++++++++++ pkg/providers/factory_test.go | 150 ++++++ pkg/providers/http_provider.go | 473 ++++--------------- pkg/providers/openai_compat/provider.go | 230 +++++++++ pkg/providers/openai_compat/provider_test.go | 149 ++++++ 5 files changed, 905 insertions(+), 388 deletions(-) create mode 100644 pkg/providers/factory.go create mode 100644 pkg/providers/factory_test.go create mode 100644 pkg/providers/openai_compat/provider.go create mode 100644 pkg/providers/openai_compat/provider_test.go diff --git a/pkg/providers/factory.go b/pkg/providers/factory.go new file mode 100644 index 000000000..84dcd9aaa --- /dev/null +++ b/pkg/providers/factory.go @@ -0,0 +1,291 @@ +package providers + +import ( + "fmt" + "strings" + + "github.com/sipeed/picoclaw/pkg/auth" + "github.com/sipeed/picoclaw/pkg/config" +) + +type providerType int + +const ( + providerTypeHTTPCompat providerType = iota + providerTypeClaudeAuth + providerTypeCodexAuth + providerTypeClaudeCLI + providerTypeGitHubCopilot +) + +type providerSelection struct { + providerType providerType + apiKey string + apiBase string + proxy string + model string + workspace string + connectMode string +} + +func createClaudeAuthProvider() (LLMProvider, error) { + cred, err := auth.GetCredential("anthropic") + if err != nil { + return nil, fmt.Errorf("loading auth credentials: %w", err) + } + if cred == nil { + return nil, fmt.Errorf("no credentials for anthropic. Run: picoclaw auth login --provider anthropic") + } + return NewClaudeProviderWithTokenSource(cred.AccessToken, createClaudeTokenSource()), nil +} + +func createCodexAuthProvider() (LLMProvider, error) { + cred, err := auth.GetCredential("openai") + if err != nil { + return nil, fmt.Errorf("loading auth credentials: %w", err) + } + if cred == nil { + return nil, fmt.Errorf("no credentials for openai. Run: picoclaw auth login --provider openai") + } + return NewCodexProviderWithTokenSource(cred.AccessToken, cred.AccountID, createCodexTokenSource()), nil +} + +func resolveProviderSelection(cfg *config.Config) (providerSelection, error) { + model := cfg.Agents.Defaults.Model + providerName := strings.ToLower(cfg.Agents.Defaults.Provider) + lowerModel := strings.ToLower(model) + + sel := providerSelection{ + providerType: providerTypeHTTPCompat, + model: model, + } + + // First, prefer explicit provider configuration. + if providerName != "" { + switch providerName { + case "groq": + if cfg.Providers.Groq.APIKey != "" { + sel.apiKey = cfg.Providers.Groq.APIKey + sel.apiBase = cfg.Providers.Groq.APIBase + if sel.apiBase == "" { + sel.apiBase = "https://api.groq.com/openai/v1" + } + } + case "openai", "gpt": + if cfg.Providers.OpenAI.APIKey != "" || cfg.Providers.OpenAI.AuthMethod != "" { + if cfg.Providers.OpenAI.AuthMethod == "oauth" || cfg.Providers.OpenAI.AuthMethod == "token" { + sel.providerType = providerTypeCodexAuth + return sel, nil + } + sel.apiKey = cfg.Providers.OpenAI.APIKey + sel.apiBase = cfg.Providers.OpenAI.APIBase + if sel.apiBase == "" { + sel.apiBase = "https://api.openai.com/v1" + } + } + case "anthropic", "claude": + if cfg.Providers.Anthropic.APIKey != "" || cfg.Providers.Anthropic.AuthMethod != "" { + if cfg.Providers.Anthropic.AuthMethod == "oauth" || cfg.Providers.Anthropic.AuthMethod == "token" { + sel.providerType = providerTypeClaudeAuth + return sel, nil + } + sel.apiKey = cfg.Providers.Anthropic.APIKey + sel.apiBase = cfg.Providers.Anthropic.APIBase + if sel.apiBase == "" { + sel.apiBase = "https://api.anthropic.com/v1" + } + } + case "openrouter": + if cfg.Providers.OpenRouter.APIKey != "" { + sel.apiKey = cfg.Providers.OpenRouter.APIKey + if cfg.Providers.OpenRouter.APIBase != "" { + sel.apiBase = cfg.Providers.OpenRouter.APIBase + } else { + sel.apiBase = "https://openrouter.ai/api/v1" + } + } + case "zhipu", "glm": + if cfg.Providers.Zhipu.APIKey != "" { + sel.apiKey = cfg.Providers.Zhipu.APIKey + sel.apiBase = cfg.Providers.Zhipu.APIBase + if sel.apiBase == "" { + sel.apiBase = "https://open.bigmodel.cn/api/paas/v4" + } + } + case "gemini", "google": + if cfg.Providers.Gemini.APIKey != "" { + sel.apiKey = cfg.Providers.Gemini.APIKey + sel.apiBase = cfg.Providers.Gemini.APIBase + if sel.apiBase == "" { + sel.apiBase = "https://generativelanguage.googleapis.com/v1beta" + } + } + case "vllm": + if cfg.Providers.VLLM.APIBase != "" { + sel.apiKey = cfg.Providers.VLLM.APIKey + sel.apiBase = cfg.Providers.VLLM.APIBase + } + case "shengsuanyun": + if cfg.Providers.ShengSuanYun.APIKey != "" { + sel.apiKey = cfg.Providers.ShengSuanYun.APIKey + sel.apiBase = cfg.Providers.ShengSuanYun.APIBase + if sel.apiBase == "" { + sel.apiBase = "https://router.shengsuanyun.com/api/v1" + } + } + case "claude-cli", "claude-code", "claudecode": + workspace := cfg.Agents.Defaults.Workspace + if workspace == "" { + workspace = "." + } + sel.providerType = providerTypeClaudeCLI + sel.workspace = workspace + return sel, nil + case "deepseek": + if cfg.Providers.DeepSeek.APIKey != "" { + sel.apiKey = cfg.Providers.DeepSeek.APIKey + sel.apiBase = cfg.Providers.DeepSeek.APIBase + if sel.apiBase == "" { + sel.apiBase = "https://api.deepseek.com/v1" + } + if model != "deepseek-chat" && model != "deepseek-reasoner" { + sel.model = "deepseek-chat" + } + } + case "github_copilot", "copilot": + sel.providerType = providerTypeGitHubCopilot + if cfg.Providers.GitHubCopilot.APIBase != "" { + sel.apiBase = cfg.Providers.GitHubCopilot.APIBase + } else { + sel.apiBase = "localhost:4321" + } + sel.connectMode = cfg.Providers.GitHubCopilot.ConnectMode + return sel, nil + } + } + + // Fallback: infer provider from model and configured keys. + if sel.apiKey == "" && sel.apiBase == "" { + switch { + case (strings.Contains(lowerModel, "kimi") || strings.Contains(lowerModel, "moonshot") || strings.HasPrefix(model, "moonshot/")) && cfg.Providers.Moonshot.APIKey != "": + sel.apiKey = cfg.Providers.Moonshot.APIKey + sel.apiBase = cfg.Providers.Moonshot.APIBase + sel.proxy = cfg.Providers.Moonshot.Proxy + if sel.apiBase == "" { + sel.apiBase = "https://api.moonshot.cn/v1" + } + case strings.HasPrefix(model, "openrouter/") || + strings.HasPrefix(model, "anthropic/") || + strings.HasPrefix(model, "openai/") || + strings.HasPrefix(model, "meta-llama/") || + strings.HasPrefix(model, "deepseek/") || + strings.HasPrefix(model, "google/"): + sel.apiKey = cfg.Providers.OpenRouter.APIKey + sel.proxy = cfg.Providers.OpenRouter.Proxy + if cfg.Providers.OpenRouter.APIBase != "" { + sel.apiBase = cfg.Providers.OpenRouter.APIBase + } else { + sel.apiBase = "https://openrouter.ai/api/v1" + } + case (strings.Contains(lowerModel, "claude") || strings.HasPrefix(model, "anthropic/")) && + (cfg.Providers.Anthropic.APIKey != "" || cfg.Providers.Anthropic.AuthMethod != ""): + if cfg.Providers.Anthropic.AuthMethod == "oauth" || cfg.Providers.Anthropic.AuthMethod == "token" { + sel.providerType = providerTypeClaudeAuth + return sel, nil + } + sel.apiKey = cfg.Providers.Anthropic.APIKey + sel.apiBase = cfg.Providers.Anthropic.APIBase + sel.proxy = cfg.Providers.Anthropic.Proxy + if sel.apiBase == "" { + sel.apiBase = "https://api.anthropic.com/v1" + } + case (strings.Contains(lowerModel, "gpt") || strings.HasPrefix(model, "openai/")) && + (cfg.Providers.OpenAI.APIKey != "" || cfg.Providers.OpenAI.AuthMethod != ""): + if cfg.Providers.OpenAI.AuthMethod == "oauth" || cfg.Providers.OpenAI.AuthMethod == "token" { + sel.providerType = providerTypeCodexAuth + return sel, nil + } + sel.apiKey = cfg.Providers.OpenAI.APIKey + sel.apiBase = cfg.Providers.OpenAI.APIBase + sel.proxy = cfg.Providers.OpenAI.Proxy + if sel.apiBase == "" { + sel.apiBase = "https://api.openai.com/v1" + } + case (strings.Contains(lowerModel, "gemini") || strings.HasPrefix(model, "google/")) && cfg.Providers.Gemini.APIKey != "": + sel.apiKey = cfg.Providers.Gemini.APIKey + sel.apiBase = cfg.Providers.Gemini.APIBase + sel.proxy = cfg.Providers.Gemini.Proxy + if sel.apiBase == "" { + sel.apiBase = "https://generativelanguage.googleapis.com/v1beta" + } + case (strings.Contains(lowerModel, "glm") || strings.Contains(lowerModel, "zhipu") || strings.Contains(lowerModel, "zai")) && cfg.Providers.Zhipu.APIKey != "": + sel.apiKey = cfg.Providers.Zhipu.APIKey + sel.apiBase = cfg.Providers.Zhipu.APIBase + sel.proxy = cfg.Providers.Zhipu.Proxy + if sel.apiBase == "" { + sel.apiBase = "https://open.bigmodel.cn/api/paas/v4" + } + case (strings.Contains(lowerModel, "groq") || strings.HasPrefix(model, "groq/")) && cfg.Providers.Groq.APIKey != "": + sel.apiKey = cfg.Providers.Groq.APIKey + sel.apiBase = cfg.Providers.Groq.APIBase + sel.proxy = cfg.Providers.Groq.Proxy + if sel.apiBase == "" { + sel.apiBase = "https://api.groq.com/openai/v1" + } + case (strings.Contains(lowerModel, "nvidia") || strings.HasPrefix(model, "nvidia/")) && cfg.Providers.Nvidia.APIKey != "": + sel.apiKey = cfg.Providers.Nvidia.APIKey + sel.apiBase = cfg.Providers.Nvidia.APIBase + sel.proxy = cfg.Providers.Nvidia.Proxy + if sel.apiBase == "" { + sel.apiBase = "https://integrate.api.nvidia.com/v1" + } + case cfg.Providers.VLLM.APIBase != "": + sel.apiKey = cfg.Providers.VLLM.APIKey + sel.apiBase = cfg.Providers.VLLM.APIBase + sel.proxy = cfg.Providers.VLLM.Proxy + default: + if cfg.Providers.OpenRouter.APIKey != "" { + sel.apiKey = cfg.Providers.OpenRouter.APIKey + sel.proxy = cfg.Providers.OpenRouter.Proxy + if cfg.Providers.OpenRouter.APIBase != "" { + sel.apiBase = cfg.Providers.OpenRouter.APIBase + } else { + sel.apiBase = "https://openrouter.ai/api/v1" + } + } else { + return providerSelection{}, fmt.Errorf("no API key configured for model: %s", model) + } + } + } + + if sel.providerType == providerTypeHTTPCompat { + if sel.apiKey == "" && !strings.HasPrefix(model, "bedrock/") { + return providerSelection{}, fmt.Errorf("no API key configured for provider (model: %s)", model) + } + if sel.apiBase == "" { + return providerSelection{}, fmt.Errorf("no API base configured for provider (model: %s)", model) + } + } + + return sel, nil +} + +func CreateProvider(cfg *config.Config) (LLMProvider, error) { + sel, err := resolveProviderSelection(cfg) + if err != nil { + return nil, err + } + + switch sel.providerType { + case providerTypeClaudeAuth: + return createClaudeAuthProvider() + case providerTypeCodexAuth: + return createCodexAuthProvider() + case providerTypeClaudeCLI: + return NewClaudeCliProvider(sel.workspace), nil + case providerTypeGitHubCopilot: + return NewGitHubCopilotProvider(sel.apiBase, sel.connectMode, sel.model) + default: + return NewHTTPProvider(sel.apiKey, sel.apiBase, sel.proxy), nil + } +} diff --git a/pkg/providers/factory_test.go b/pkg/providers/factory_test.go new file mode 100644 index 000000000..f894b292a --- /dev/null +++ b/pkg/providers/factory_test.go @@ -0,0 +1,150 @@ +package providers + +import ( + "strings" + "testing" + + "github.com/sipeed/picoclaw/pkg/config" +) + +func TestResolveProviderSelection(t *testing.T) { + tests := []struct { + name string + setup func(*config.Config) + wantType providerType + wantAPIBase string + wantProxy string + wantErrSubstr string + }{ + { + name: "explicit claude-cli provider routes to cli provider type", + setup: func(cfg *config.Config) { + cfg.Agents.Defaults.Provider = "claude-cli" + cfg.Agents.Defaults.Workspace = "/tmp/ws" + }, + wantType: providerTypeClaudeCLI, + }, + { + name: "explicit copilot provider routes to github copilot type", + setup: func(cfg *config.Config) { + cfg.Agents.Defaults.Provider = "copilot" + }, + wantType: providerTypeGitHubCopilot, + wantAPIBase: "localhost:4321", + }, + { + name: "openrouter model uses openrouter defaults", + setup: func(cfg *config.Config) { + cfg.Agents.Defaults.Model = "openrouter/auto" + cfg.Providers.OpenRouter.APIKey = "sk-or-test" + }, + wantType: providerTypeHTTPCompat, + wantAPIBase: "https://openrouter.ai/api/v1", + }, + { + name: "anthropic oauth routes to claude auth provider", + setup: func(cfg *config.Config) { + cfg.Agents.Defaults.Model = "claude-sonnet-4-5-20250929" + cfg.Providers.Anthropic.AuthMethod = "oauth" + }, + wantType: providerTypeClaudeAuth, + }, + { + name: "openai oauth routes to codex auth provider", + setup: func(cfg *config.Config) { + cfg.Agents.Defaults.Model = "gpt-4o" + cfg.Providers.OpenAI.AuthMethod = "oauth" + }, + wantType: providerTypeCodexAuth, + }, + { + name: "zhipu model uses zhipu base default", + setup: func(cfg *config.Config) { + cfg.Agents.Defaults.Model = "glm-4.7" + cfg.Providers.Zhipu.APIKey = "zhipu-key" + }, + wantType: providerTypeHTTPCompat, + wantAPIBase: "https://open.bigmodel.cn/api/paas/v4", + }, + { + name: "groq model uses groq base default", + setup: func(cfg *config.Config) { + cfg.Agents.Defaults.Model = "groq/llama-3.3-70b" + cfg.Providers.Groq.APIKey = "gsk-key" + }, + wantType: providerTypeHTTPCompat, + wantAPIBase: "https://api.groq.com/openai/v1", + }, + { + name: "moonshot model keeps proxy and default base", + setup: func(cfg *config.Config) { + cfg.Agents.Defaults.Model = "moonshot/kimi-k2.5" + cfg.Providers.Moonshot.APIKey = "moonshot-key" + cfg.Providers.Moonshot.Proxy = "http://127.0.0.1:7890" + }, + wantType: providerTypeHTTPCompat, + wantAPIBase: "https://api.moonshot.cn/v1", + wantProxy: "http://127.0.0.1:7890", + }, + { + name: "missing keys returns model config error", + setup: func(cfg *config.Config) { + cfg.Agents.Defaults.Model = "custom-model" + }, + wantErrSubstr: "no API key configured for model", + }, + { + name: "openrouter prefix without key returns provider key error", + setup: func(cfg *config.Config) { + cfg.Agents.Defaults.Model = "openrouter/auto" + }, + wantErrSubstr: "no API key configured for provider", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cfg := config.DefaultConfig() + tt.setup(cfg) + + got, err := resolveProviderSelection(cfg) + if tt.wantErrSubstr != "" { + if err == nil { + t.Fatalf("expected error containing %q, got nil", tt.wantErrSubstr) + } + if !strings.Contains(err.Error(), tt.wantErrSubstr) { + t.Fatalf("error = %q, want substring %q", err.Error(), tt.wantErrSubstr) + } + return + } + + if err != nil { + t.Fatalf("resolveProviderSelection() error = %v", err) + } + if got.providerType != tt.wantType { + t.Fatalf("providerType = %v, want %v", got.providerType, tt.wantType) + } + if tt.wantAPIBase != "" && got.apiBase != tt.wantAPIBase { + t.Fatalf("apiBase = %q, want %q", got.apiBase, tt.wantAPIBase) + } + if tt.wantProxy != "" && got.proxy != tt.wantProxy { + t.Fatalf("proxy = %q, want %q", got.proxy, tt.wantProxy) + } + }) + } +} + +func TestCreateProviderReturnsHTTPProviderForOpenRouter(t *testing.T) { + cfg := config.DefaultConfig() + cfg.Agents.Defaults.Model = "openrouter/auto" + cfg.Providers.OpenRouter.APIKey = "sk-or-test" + + provider, err := CreateProvider(cfg) + if err != nil { + t.Fatalf("CreateProvider() error = %v", err) + } + + if _, ok := provider.(*HTTPProvider); !ok { + t.Fatalf("provider type = %T, want *HTTPProvider", provider) + } +} diff --git a/pkg/providers/http_provider.go b/pkg/providers/http_provider.go index 17eb6214c..0f7f646d8 100644 --- a/pkg/providers/http_provider.go +++ b/pkg/providers/http_provider.go @@ -7,427 +7,124 @@ package providers import ( - "bytes" "context" - "encoding/json" - "fmt" - "io" - "net/http" - "net/url" - "strings" - "time" - - "github.com/sipeed/picoclaw/pkg/auth" - "github.com/sipeed/picoclaw/pkg/config" + "github.com/sipeed/picoclaw/pkg/providers/openai_compat" ) type HTTPProvider struct { - apiKey string - apiBase string - httpClient *http.Client + delegate *openai_compat.Provider } -func NewHTTPProvider(apiKey, apiBase, proxy string) *HTTPProvider { - client := &http.Client{ - Timeout: 120 * time.Second, +func NewHTTPProvider(apiKey, apiBase string, proxy ...string) *HTTPProvider { + proxyURL := "" + if len(proxy) > 0 { + proxyURL = proxy[0] } - - if proxy != "" { - proxyURL, err := url.Parse(proxy) - if err == nil { - client.Transport = &http.Transport{ - Proxy: http.ProxyURL(proxyURL), - } - } - } - return &HTTPProvider{ - apiKey: apiKey, - apiBase: strings.TrimRight(apiBase, "/"), - httpClient: client, + delegate: openai_compat.NewProvider(apiKey, apiBase, proxyURL), } } func (p *HTTPProvider) Chat(ctx context.Context, messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) (*LLMResponse, error) { - if p.apiBase == "" { - return nil, fmt.Errorf("API base not configured") - } - - // Strip provider prefix from model name (e.g., moonshot/kimi-k2.5 -> kimi-k2.5) - if idx := strings.Index(model, "/"); idx != -1 { - prefix := model[:idx] - if prefix == "moonshot" || prefix == "nvidia" { - model = model[idx+1:] - } - } - - requestBody := map[string]interface{}{ - "model": model, - "messages": messages, - } - - if len(tools) > 0 { - requestBody["tools"] = tools - requestBody["tool_choice"] = "auto" - } - - if maxTokens, ok := options["max_tokens"].(int); ok { - lowerModel := strings.ToLower(model) - if strings.Contains(lowerModel, "glm") || strings.Contains(lowerModel, "o1") { - requestBody["max_completion_tokens"] = maxTokens - } else { - requestBody["max_tokens"] = maxTokens - } - } - - if temperature, ok := options["temperature"].(float64); ok { - lowerModel := strings.ToLower(model) - // Kimi k2 models only support temperature=1 - if strings.Contains(lowerModel, "kimi") && strings.Contains(lowerModel, "k2") { - requestBody["temperature"] = 1.0 - } else { - requestBody["temperature"] = temperature - } - } - - jsonData, err := json.Marshal(requestBody) + compatResp, err := p.delegate.Chat(ctx, toOpenAICompatMessages(messages), toOpenAICompatTools(tools), model, options) if err != nil { - return nil, fmt.Errorf("failed to marshal request: %w", err) + return nil, err } - - req, err := http.NewRequestWithContext(ctx, "POST", p.apiBase+"/chat/completions", bytes.NewReader(jsonData)) - if err != nil { - return nil, fmt.Errorf("failed to create request: %w", err) - } - - req.Header.Set("Content-Type", "application/json") - if p.apiKey != "" { - req.Header.Set("Authorization", "Bearer "+p.apiKey) - } - - resp, err := p.httpClient.Do(req) - if err != nil { - return nil, fmt.Errorf("failed to send request: %w", err) - } - defer resp.Body.Close() - - body, err := io.ReadAll(resp.Body) - if err != nil { - return nil, fmt.Errorf("failed to read response: %w", err) - } - - if resp.StatusCode != http.StatusOK { - return nil, fmt.Errorf("API request failed:\n Status: %d\n Body: %s", resp.StatusCode, string(body)) - } - - return p.parseResponse(body) -} - -func (p *HTTPProvider) parseResponse(body []byte) (*LLMResponse, error) { - var apiResponse struct { - Choices []struct { - Message struct { - Content string `json:"content"` - ToolCalls []struct { - ID string `json:"id"` - Type string `json:"type"` - Function *struct { - Name string `json:"name"` - Arguments string `json:"arguments"` - } `json:"function"` - } `json:"tool_calls"` - } `json:"message"` - FinishReason string `json:"finish_reason"` - } `json:"choices"` - Usage *UsageInfo `json:"usage"` - } - - if err := json.Unmarshal(body, &apiResponse); err != nil { - return nil, fmt.Errorf("failed to unmarshal response: %w", err) - } - - if len(apiResponse.Choices) == 0 { - return &LLMResponse{ - Content: "", - FinishReason: "stop", - }, nil - } - - choice := apiResponse.Choices[0] - - toolCalls := make([]ToolCall, 0, len(choice.Message.ToolCalls)) - for _, tc := range choice.Message.ToolCalls { - arguments := make(map[string]interface{}) - name := "" - - // Handle OpenAI format with nested function object - if tc.Type == "function" && tc.Function != nil { - name = tc.Function.Name - if tc.Function.Arguments != "" { - if err := json.Unmarshal([]byte(tc.Function.Arguments), &arguments); err != nil { - arguments["raw"] = tc.Function.Arguments - } - } - } else if tc.Function != nil { - // Legacy format without type field - name = tc.Function.Name - if tc.Function.Arguments != "" { - if err := json.Unmarshal([]byte(tc.Function.Arguments), &arguments); err != nil { - arguments["raw"] = tc.Function.Arguments - } - } - } - - toolCalls = append(toolCalls, ToolCall{ - ID: tc.ID, - Name: name, - Arguments: arguments, - }) - } - - return &LLMResponse{ - Content: choice.Message.Content, - ToolCalls: toolCalls, - FinishReason: choice.FinishReason, - Usage: apiResponse.Usage, - }, nil + return fromOpenAICompatResponse(compatResp), nil } func (p *HTTPProvider) GetDefaultModel() string { return "" } -func createClaudeAuthProvider() (LLMProvider, error) { - cred, err := auth.GetCredential("anthropic") - if err != nil { - return nil, fmt.Errorf("loading auth credentials: %w", err) +func toOpenAICompatMessages(messages []Message) []openai_compat.Message { + out := make([]openai_compat.Message, 0, len(messages)) + for _, msg := range messages { + out = append(out, openai_compat.Message{ + Role: msg.Role, + Content: msg.Content, + ToolCalls: toOpenAICompatToolCalls(msg.ToolCalls), + ToolCallID: msg.ToolCallID, + }) } - if cred == nil { - return nil, fmt.Errorf("no credentials for anthropic. Run: picoclaw auth login --provider anthropic") - } - return NewClaudeProviderWithTokenSource(cred.AccessToken, createClaudeTokenSource()), nil + return out } -func createCodexAuthProvider() (LLMProvider, error) { - cred, err := auth.GetCredential("openai") - if err != nil { - return nil, fmt.Errorf("loading auth credentials: %w", err) +func toOpenAICompatTools(tools []ToolDefinition) []openai_compat.ToolDefinition { + out := make([]openai_compat.ToolDefinition, 0, len(tools)) + for _, t := range tools { + out = append(out, openai_compat.ToolDefinition{ + Type: t.Type, + Function: openai_compat.ToolFunctionDefinition{ + Name: t.Function.Name, + Description: t.Function.Description, + Parameters: t.Function.Parameters, + }, + }) } - if cred == nil { - return nil, fmt.Errorf("no credentials for openai. Run: picoclaw auth login --provider openai") - } - return NewCodexProviderWithTokenSource(cred.AccessToken, cred.AccountID, createCodexTokenSource()), nil + return out } -func CreateProvider(cfg *config.Config) (LLMProvider, error) { - model := cfg.Agents.Defaults.Model - providerName := strings.ToLower(cfg.Agents.Defaults.Provider) - - var apiKey, apiBase, proxy string - - lowerModel := strings.ToLower(model) - - // First, try to use explicitly configured provider - if providerName != "" { - switch providerName { - case "groq": - if cfg.Providers.Groq.APIKey != "" { - apiKey = cfg.Providers.Groq.APIKey - apiBase = cfg.Providers.Groq.APIBase - if apiBase == "" { - apiBase = "https://api.groq.com/openai/v1" - } +func toOpenAICompatToolCalls(toolCalls []ToolCall) []openai_compat.ToolCall { + out := make([]openai_compat.ToolCall, 0, len(toolCalls)) + for _, tc := range toolCalls { + var fn *openai_compat.FunctionCall + if tc.Function != nil { + fn = &openai_compat.FunctionCall{ + Name: tc.Function.Name, + Arguments: tc.Function.Arguments, } - case "openai", "gpt": - if cfg.Providers.OpenAI.APIKey != "" || cfg.Providers.OpenAI.AuthMethod != "" { - if cfg.Providers.OpenAI.AuthMethod == "oauth" || cfg.Providers.OpenAI.AuthMethod == "token" { - return createCodexAuthProvider() - } - apiKey = cfg.Providers.OpenAI.APIKey - apiBase = cfg.Providers.OpenAI.APIBase - if apiBase == "" { - apiBase = "https://api.openai.com/v1" - } - } - case "anthropic", "claude": - if cfg.Providers.Anthropic.APIKey != "" || cfg.Providers.Anthropic.AuthMethod != "" { - if cfg.Providers.Anthropic.AuthMethod == "oauth" || cfg.Providers.Anthropic.AuthMethod == "token" { - return createClaudeAuthProvider() - } - apiKey = cfg.Providers.Anthropic.APIKey - apiBase = cfg.Providers.Anthropic.APIBase - if apiBase == "" { - apiBase = "https://api.anthropic.com/v1" - } - } - case "openrouter": - if cfg.Providers.OpenRouter.APIKey != "" { - apiKey = cfg.Providers.OpenRouter.APIKey - if cfg.Providers.OpenRouter.APIBase != "" { - apiBase = cfg.Providers.OpenRouter.APIBase - } else { - apiBase = "https://openrouter.ai/api/v1" - } - } - case "zhipu", "glm": - if cfg.Providers.Zhipu.APIKey != "" { - apiKey = cfg.Providers.Zhipu.APIKey - apiBase = cfg.Providers.Zhipu.APIBase - if apiBase == "" { - apiBase = "https://open.bigmodel.cn/api/paas/v4" - } - } - case "gemini", "google": - if cfg.Providers.Gemini.APIKey != "" { - apiKey = cfg.Providers.Gemini.APIKey - apiBase = cfg.Providers.Gemini.APIBase - if apiBase == "" { - apiBase = "https://generativelanguage.googleapis.com/v1beta" - } - } - case "vllm": - if cfg.Providers.VLLM.APIBase != "" { - apiKey = cfg.Providers.VLLM.APIKey - apiBase = cfg.Providers.VLLM.APIBase - } - case "shengsuanyun": - if cfg.Providers.ShengSuanYun.APIKey != "" { - apiKey = cfg.Providers.ShengSuanYun.APIKey - apiBase = cfg.Providers.ShengSuanYun.APIBase - if apiBase == "" { - apiBase = "https://router.shengsuanyun.com/api/v1" - } - } - case "claude-cli", "claudecode", "claude-code": - workspace := cfg.Agents.Defaults.Workspace - if workspace == "" { - workspace = "." - } - return NewClaudeCliProvider(workspace), nil - case "deepseek": - if cfg.Providers.DeepSeek.APIKey != "" { - apiKey = cfg.Providers.DeepSeek.APIKey - apiBase = cfg.Providers.DeepSeek.APIBase - if apiBase == "" { - apiBase = "https://api.deepseek.com/v1" - } - if model != "deepseek-chat" && model != "deepseek-reasoner" { - model = "deepseek-chat" - } - } - case "github_copilot", "copilot": - if cfg.Providers.GitHubCopilot.APIBase != "" { - apiBase = cfg.Providers.GitHubCopilot.APIBase - } else { - apiBase = "localhost:4321" - } - return NewGitHubCopilotProvider(apiBase, cfg.Providers.GitHubCopilot.ConnectMode, model) - } + out = append(out, openai_compat.ToolCall{ + ID: tc.ID, + Type: tc.Type, + Function: fn, + Name: tc.Name, + Arguments: tc.Arguments, + }) + } + return out +} +func fromOpenAICompatResponse(resp *openai_compat.LLMResponse) *LLMResponse { + if resp == nil { + return &LLMResponse{} } - // Fallback: detect provider from model name - if apiKey == "" && apiBase == "" { - switch { - case (strings.Contains(lowerModel, "kimi") || strings.Contains(lowerModel, "moonshot") || strings.HasPrefix(model, "moonshot/")) && cfg.Providers.Moonshot.APIKey != "": - apiKey = cfg.Providers.Moonshot.APIKey - apiBase = cfg.Providers.Moonshot.APIBase - proxy = cfg.Providers.Moonshot.Proxy - if apiBase == "" { - apiBase = "https://api.moonshot.cn/v1" - } - - case strings.HasPrefix(model, "openrouter/") || strings.HasPrefix(model, "anthropic/") || strings.HasPrefix(model, "openai/") || strings.HasPrefix(model, "meta-llama/") || strings.HasPrefix(model, "deepseek/") || strings.HasPrefix(model, "google/"): - apiKey = cfg.Providers.OpenRouter.APIKey - proxy = cfg.Providers.OpenRouter.Proxy - if cfg.Providers.OpenRouter.APIBase != "" { - apiBase = cfg.Providers.OpenRouter.APIBase - } else { - apiBase = "https://openrouter.ai/api/v1" - } - - case (strings.Contains(lowerModel, "claude") || strings.HasPrefix(model, "anthropic/")) && (cfg.Providers.Anthropic.APIKey != "" || cfg.Providers.Anthropic.AuthMethod != ""): - if cfg.Providers.Anthropic.AuthMethod == "oauth" || cfg.Providers.Anthropic.AuthMethod == "token" { - return createClaudeAuthProvider() - } - apiKey = cfg.Providers.Anthropic.APIKey - apiBase = cfg.Providers.Anthropic.APIBase - proxy = cfg.Providers.Anthropic.Proxy - if apiBase == "" { - apiBase = "https://api.anthropic.com/v1" - } - - case (strings.Contains(lowerModel, "gpt") || strings.HasPrefix(model, "openai/")) && (cfg.Providers.OpenAI.APIKey != "" || cfg.Providers.OpenAI.AuthMethod != ""): - if cfg.Providers.OpenAI.AuthMethod == "oauth" || cfg.Providers.OpenAI.AuthMethod == "token" { - return createCodexAuthProvider() - } - apiKey = cfg.Providers.OpenAI.APIKey - apiBase = cfg.Providers.OpenAI.APIBase - proxy = cfg.Providers.OpenAI.Proxy - if apiBase == "" { - apiBase = "https://api.openai.com/v1" - } - - case (strings.Contains(lowerModel, "gemini") || strings.HasPrefix(model, "google/")) && cfg.Providers.Gemini.APIKey != "": - apiKey = cfg.Providers.Gemini.APIKey - apiBase = cfg.Providers.Gemini.APIBase - proxy = cfg.Providers.Gemini.Proxy - if apiBase == "" { - apiBase = "https://generativelanguage.googleapis.com/v1beta" - } - - case (strings.Contains(lowerModel, "glm") || strings.Contains(lowerModel, "zhipu") || strings.Contains(lowerModel, "zai")) && cfg.Providers.Zhipu.APIKey != "": - apiKey = cfg.Providers.Zhipu.APIKey - apiBase = cfg.Providers.Zhipu.APIBase - proxy = cfg.Providers.Zhipu.Proxy - if apiBase == "" { - apiBase = "https://open.bigmodel.cn/api/paas/v4" - } - - case (strings.Contains(lowerModel, "groq") || strings.HasPrefix(model, "groq/")) && cfg.Providers.Groq.APIKey != "": - apiKey = cfg.Providers.Groq.APIKey - apiBase = cfg.Providers.Groq.APIBase - proxy = cfg.Providers.Groq.Proxy - if apiBase == "" { - apiBase = "https://api.groq.com/openai/v1" - } - - case (strings.Contains(lowerModel, "nvidia") || strings.HasPrefix(model, "nvidia/")) && cfg.Providers.Nvidia.APIKey != "": - apiKey = cfg.Providers.Nvidia.APIKey - apiBase = cfg.Providers.Nvidia.APIBase - proxy = cfg.Providers.Nvidia.Proxy - if apiBase == "" { - apiBase = "https://integrate.api.nvidia.com/v1" - } - - case cfg.Providers.VLLM.APIBase != "": - apiKey = cfg.Providers.VLLM.APIKey - apiBase = cfg.Providers.VLLM.APIBase - proxy = cfg.Providers.VLLM.Proxy - - default: - if cfg.Providers.OpenRouter.APIKey != "" { - apiKey = cfg.Providers.OpenRouter.APIKey - proxy = cfg.Providers.OpenRouter.Proxy - if cfg.Providers.OpenRouter.APIBase != "" { - apiBase = cfg.Providers.OpenRouter.APIBase - } else { - apiBase = "https://openrouter.ai/api/v1" - } - } else { - return nil, fmt.Errorf("no API key configured for model: %s", model) - } + var usage *UsageInfo + if resp.Usage != nil { + usage = &UsageInfo{ + PromptTokens: resp.Usage.PromptTokens, + CompletionTokens: resp.Usage.CompletionTokens, + TotalTokens: resp.Usage.TotalTokens, } } - if apiKey == "" && !strings.HasPrefix(model, "bedrock/") { - return nil, fmt.Errorf("no API key configured for provider (model: %s)", model) + return &LLMResponse{ + Content: resp.Content, + ToolCalls: fromOpenAICompatToolCalls(resp.ToolCalls), + FinishReason: resp.FinishReason, + Usage: usage, } - - if apiBase == "" { - return nil, fmt.Errorf("no API base configured for provider (model: %s)", model) - } - - return NewHTTPProvider(apiKey, apiBase, proxy), nil +} + +func fromOpenAICompatToolCalls(toolCalls []openai_compat.ToolCall) []ToolCall { + out := make([]ToolCall, 0, len(toolCalls)) + for _, tc := range toolCalls { + var fn *FunctionCall + if tc.Function != nil { + fn = &FunctionCall{ + Name: tc.Function.Name, + Arguments: tc.Function.Arguments, + } + } + out = append(out, ToolCall{ + ID: tc.ID, + Type: tc.Type, + Function: fn, + Name: tc.Name, + Arguments: tc.Arguments, + }) + } + return out } diff --git a/pkg/providers/openai_compat/provider.go b/pkg/providers/openai_compat/provider.go new file mode 100644 index 000000000..4aef1389a --- /dev/null +++ b/pkg/providers/openai_compat/provider.go @@ -0,0 +1,230 @@ +package openai_compat + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "net/url" + "strings" + "time" +) + +type ToolCall struct { + ID string `json:"id"` + Type string `json:"type,omitempty"` + Function *FunctionCall `json:"function,omitempty"` + Name string `json:"name,omitempty"` + Arguments map[string]interface{} `json:"arguments,omitempty"` +} + +type FunctionCall struct { + Name string `json:"name"` + Arguments string `json:"arguments"` +} + +type LLMResponse struct { + Content string `json:"content"` + ToolCalls []ToolCall `json:"tool_calls,omitempty"` + FinishReason string `json:"finish_reason"` + Usage *UsageInfo `json:"usage,omitempty"` +} + +type UsageInfo struct { + PromptTokens int `json:"prompt_tokens"` + CompletionTokens int `json:"completion_tokens"` + TotalTokens int `json:"total_tokens"` +} + +type Message struct { + Role string `json:"role"` + Content string `json:"content"` + ToolCalls []ToolCall `json:"tool_calls,omitempty"` + ToolCallID string `json:"tool_call_id,omitempty"` +} + +type ToolDefinition struct { + Type string `json:"type"` + Function ToolFunctionDefinition `json:"function"` +} + +type ToolFunctionDefinition struct { + Name string `json:"name"` + Description string `json:"description"` + Parameters map[string]interface{} `json:"parameters"` +} + +type Provider struct { + apiKey string + apiBase string + httpClient *http.Client +} + +func NewProvider(apiKey, apiBase string, proxy ...string) *Provider { + proxyURL := "" + if len(proxy) > 0 { + proxyURL = proxy[0] + } + client := &http.Client{ + Timeout: 120 * time.Second, + } + + if proxyURL != "" { + parsed, err := url.Parse(proxyURL) + if err == nil { + client.Transport = &http.Transport{ + Proxy: http.ProxyURL(parsed), + } + } + } + + return &Provider{ + apiKey: apiKey, + apiBase: strings.TrimRight(apiBase, "/"), + httpClient: client, + } +} + +func (p *Provider) Chat(ctx context.Context, messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) (*LLMResponse, error) { + if p.apiBase == "" { + return nil, fmt.Errorf("API base not configured") + } + + // Strip provider prefix (moonshot/kimi-*, nvidia/*) for OpenAI-compatible backends. + if idx := strings.Index(model, "/"); idx != -1 { + prefix := model[:idx] + if prefix == "moonshot" || prefix == "nvidia" { + model = model[idx+1:] + } + } + + requestBody := map[string]interface{}{ + "model": model, + "messages": messages, + } + + if len(tools) > 0 { + requestBody["tools"] = tools + requestBody["tool_choice"] = "auto" + } + + if maxTokens, ok := options["max_tokens"].(int); ok { + lowerModel := strings.ToLower(model) + if strings.Contains(lowerModel, "glm") || strings.Contains(lowerModel, "o1") { + requestBody["max_completion_tokens"] = maxTokens + } else { + requestBody["max_tokens"] = maxTokens + } + } + + if temperature, ok := options["temperature"].(float64); ok { + lowerModel := strings.ToLower(model) + // Kimi k2 models only support temperature=1. + if strings.Contains(lowerModel, "kimi") && strings.Contains(lowerModel, "k2") { + requestBody["temperature"] = 1.0 + } else { + requestBody["temperature"] = temperature + } + } + + jsonData, err := json.Marshal(requestBody) + if err != nil { + return nil, fmt.Errorf("failed to marshal request: %w", err) + } + + req, err := http.NewRequestWithContext(ctx, "POST", p.apiBase+"/chat/completions", bytes.NewReader(jsonData)) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + + req.Header.Set("Content-Type", "application/json") + if p.apiKey != "" { + req.Header.Set("Authorization", "Bearer "+p.apiKey) + } + + resp, err := p.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to send request: %w", err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response: %w", err) + } + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("API request failed:\n Status: %d\n Body: %s", resp.StatusCode, string(body)) + } + + return parseResponse(body) +} + +func parseResponse(body []byte) (*LLMResponse, error) { + var apiResponse struct { + Choices []struct { + Message struct { + Content string `json:"content"` + ToolCalls []struct { + ID string `json:"id"` + Type string `json:"type"` + Function *struct { + Name string `json:"name"` + Arguments string `json:"arguments"` + } `json:"function"` + } `json:"tool_calls"` + } `json:"message"` + FinishReason string `json:"finish_reason"` + } `json:"choices"` + Usage *UsageInfo `json:"usage"` + } + + if err := json.Unmarshal(body, &apiResponse); err != nil { + return nil, fmt.Errorf("failed to unmarshal response: %w", err) + } + + if len(apiResponse.Choices) == 0 { + return &LLMResponse{ + Content: "", + FinishReason: "stop", + }, nil + } + + choice := apiResponse.Choices[0] + toolCalls := make([]ToolCall, 0, len(choice.Message.ToolCalls)) + for _, tc := range choice.Message.ToolCalls { + arguments := make(map[string]interface{}) + name := "" + + if tc.Type == "function" && tc.Function != nil { + name = tc.Function.Name + if tc.Function.Arguments != "" { + if err := json.Unmarshal([]byte(tc.Function.Arguments), &arguments); err != nil { + arguments["raw"] = tc.Function.Arguments + } + } + } else if tc.Function != nil { + name = tc.Function.Name + if tc.Function.Arguments != "" { + if err := json.Unmarshal([]byte(tc.Function.Arguments), &arguments); err != nil { + arguments["raw"] = tc.Function.Arguments + } + } + } + + toolCalls = append(toolCalls, ToolCall{ + ID: tc.ID, + Name: name, + Arguments: arguments, + }) + } + + return &LLMResponse{ + Content: choice.Message.Content, + ToolCalls: toolCalls, + FinishReason: choice.FinishReason, + Usage: apiResponse.Usage, + }, nil +} diff --git a/pkg/providers/openai_compat/provider_test.go b/pkg/providers/openai_compat/provider_test.go new file mode 100644 index 000000000..7c5f1c63c --- /dev/null +++ b/pkg/providers/openai_compat/provider_test.go @@ -0,0 +1,149 @@ +package openai_compat + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "testing" +) + +func TestProviderChat_UsesMaxCompletionTokensForGLM(t *testing.T) { + var requestBody map[string]interface{} + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/chat/completions" { + http.Error(w, "not found", http.StatusNotFound) + return + } + if err := json.NewDecoder(r.Body).Decode(&requestBody); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + resp := map[string]interface{}{ + "choices": []map[string]interface{}{ + { + "message": map[string]interface{}{"content": "ok"}, + "finish_reason": "stop", + }, + }, + } + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(resp) + })) + defer server.Close() + + p := NewProvider("key", server.URL) + _, err := p.Chat(t.Context(), []Message{{Role: "user", Content: "hi"}}, nil, "glm-4.7", map[string]interface{}{"max_tokens": 1234}) + if err != nil { + t.Fatalf("Chat() error = %v", err) + } + + if _, ok := requestBody["max_completion_tokens"]; !ok { + t.Fatalf("expected max_completion_tokens in request body") + } + if _, ok := requestBody["max_tokens"]; ok { + t.Fatalf("did not expect max_tokens key for glm model") + } +} + +func TestProviderChat_ParsesToolCalls(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + resp := map[string]interface{}{ + "choices": []map[string]interface{}{ + { + "message": map[string]interface{}{ + "content": "", + "tool_calls": []map[string]interface{}{ + { + "id": "call_1", + "type": "function", + "function": map[string]interface{}{ + "name": "get_weather", + "arguments": "{\"city\":\"SF\"}", + }, + }, + }, + }, + "finish_reason": "tool_calls", + }, + }, + "usage": map[string]interface{}{ + "prompt_tokens": 10, + "completion_tokens": 5, + "total_tokens": 15, + }, + } + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(resp) + })) + defer server.Close() + + p := NewProvider("key", server.URL) + out, err := p.Chat(t.Context(), []Message{{Role: "user", Content: "hi"}}, nil, "gpt-4o", nil) + if err != nil { + t.Fatalf("Chat() error = %v", err) + } + if len(out.ToolCalls) != 1 { + t.Fatalf("len(ToolCalls) = %d, want 1", len(out.ToolCalls)) + } + if out.ToolCalls[0].Name != "get_weather" { + t.Fatalf("ToolCalls[0].Name = %q, want %q", out.ToolCalls[0].Name, "get_weather") + } + if out.ToolCalls[0].Arguments["city"] != "SF" { + t.Fatalf("ToolCalls[0].Arguments[city] = %v, want SF", out.ToolCalls[0].Arguments["city"]) + } +} + +func TestProviderChat_HTTPError(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + http.Error(w, "bad request", http.StatusBadRequest) + })) + defer server.Close() + + p := NewProvider("key", server.URL) + _, err := p.Chat(t.Context(), []Message{{Role: "user", Content: "hi"}}, nil, "gpt-4o", nil) + if err == nil { + t.Fatal("expected error, got nil") + } +} + +func TestProviderChat_StripsMoonshotPrefixAndNormalizesKimiTemperature(t *testing.T) { + var requestBody map[string]interface{} + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if err := json.NewDecoder(r.Body).Decode(&requestBody); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + resp := map[string]interface{}{ + "choices": []map[string]interface{}{ + { + "message": map[string]interface{}{"content": "ok"}, + "finish_reason": "stop", + }, + }, + } + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(resp) + })) + defer server.Close() + + p := NewProvider("key", server.URL) + _, err := p.Chat( + t.Context(), + []Message{{Role: "user", Content: "hi"}}, + nil, + "moonshot/kimi-k2.5", + map[string]interface{}{"temperature": 0.3}, + ) + if err != nil { + t.Fatalf("Chat() error = %v", err) + } + + if requestBody["model"] != "kimi-k2.5" { + t.Fatalf("model = %v, want kimi-k2.5", requestBody["model"]) + } + if requestBody["temperature"] != 1.0 { + t.Fatalf("temperature = %v, want 1.0", requestBody["temperature"]) + } +} From 762565b0d4406aee7fb617d0b5c46d85014ab04e Mon Sep 17 00:00:00 2001 From: Jared Mahotiere Date: Sun, 15 Feb 2026 08:04:12 -0500 Subject: [PATCH 2/4] refactor(providers): move anthropic logic to protocol package --- pkg/providers/anthropic/provider.go | 241 +++++++++++++++++++ pkg/providers/anthropic/provider_test.go | 208 +++++++++++++++++ pkg/providers/claude_provider.go | 281 +++++++++-------------- pkg/providers/claude_provider_test.go | 137 +---------- 4 files changed, 565 insertions(+), 302 deletions(-) create mode 100644 pkg/providers/anthropic/provider.go create mode 100644 pkg/providers/anthropic/provider_test.go diff --git a/pkg/providers/anthropic/provider.go b/pkg/providers/anthropic/provider.go new file mode 100644 index 000000000..ca72f0180 --- /dev/null +++ b/pkg/providers/anthropic/provider.go @@ -0,0 +1,241 @@ +package anthropicprovider + +import ( + "context" + "encoding/json" + "fmt" + + "github.com/anthropics/anthropic-sdk-go" + "github.com/anthropics/anthropic-sdk-go/option" +) + +type ToolCall struct { + ID string `json:"id"` + Type string `json:"type,omitempty"` + Function *FunctionCall `json:"function,omitempty"` + Name string `json:"name,omitempty"` + Arguments map[string]interface{} `json:"arguments,omitempty"` +} + +type FunctionCall struct { + Name string `json:"name"` + Arguments string `json:"arguments"` +} + +type LLMResponse struct { + Content string `json:"content"` + ToolCalls []ToolCall `json:"tool_calls,omitempty"` + FinishReason string `json:"finish_reason"` + Usage *UsageInfo `json:"usage,omitempty"` +} + +type UsageInfo struct { + PromptTokens int `json:"prompt_tokens"` + CompletionTokens int `json:"completion_tokens"` + TotalTokens int `json:"total_tokens"` +} + +type Message struct { + Role string `json:"role"` + Content string `json:"content"` + ToolCalls []ToolCall `json:"tool_calls,omitempty"` + ToolCallID string `json:"tool_call_id,omitempty"` +} + +type ToolDefinition struct { + Type string `json:"type"` + Function ToolFunctionDefinition `json:"function"` +} + +type ToolFunctionDefinition struct { + Name string `json:"name"` + Description string `json:"description"` + Parameters map[string]interface{} `json:"parameters"` +} + +type Provider struct { + client *anthropic.Client + tokenSource func() (string, error) +} + +func NewProvider(token string) *Provider { + client := anthropic.NewClient( + option.WithAuthToken(token), + option.WithBaseURL("https://api.anthropic.com"), + ) + return &Provider{client: &client} +} + +func NewProviderWithClient(client *anthropic.Client) *Provider { + return &Provider{client: client} +} + +func NewProviderWithTokenSource(token string, tokenSource func() (string, error)) *Provider { + p := NewProvider(token) + p.tokenSource = tokenSource + return p +} + +func (p *Provider) Chat(ctx context.Context, messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) (*LLMResponse, error) { + var opts []option.RequestOption + if p.tokenSource != nil { + tok, err := p.tokenSource() + if err != nil { + return nil, fmt.Errorf("refreshing token: %w", err) + } + opts = append(opts, option.WithAuthToken(tok)) + } + + params, err := buildParams(messages, tools, model, options) + if err != nil { + return nil, err + } + + resp, err := p.client.Messages.New(ctx, params, opts...) + if err != nil { + return nil, fmt.Errorf("claude API call: %w", err) + } + + return parseResponse(resp), nil +} + +func (p *Provider) GetDefaultModel() string { + return "claude-sonnet-4-5-20250929" +} + +func buildParams(messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) (anthropic.MessageNewParams, error) { + var system []anthropic.TextBlockParam + var anthropicMessages []anthropic.MessageParam + + for _, msg := range messages { + switch msg.Role { + case "system": + system = append(system, anthropic.TextBlockParam{Text: msg.Content}) + case "user": + if msg.ToolCallID != "" { + anthropicMessages = append(anthropicMessages, + anthropic.NewUserMessage(anthropic.NewToolResultBlock(msg.ToolCallID, msg.Content, false)), + ) + } else { + anthropicMessages = append(anthropicMessages, + anthropic.NewUserMessage(anthropic.NewTextBlock(msg.Content)), + ) + } + case "assistant": + if len(msg.ToolCalls) > 0 { + var blocks []anthropic.ContentBlockParamUnion + if msg.Content != "" { + blocks = append(blocks, anthropic.NewTextBlock(msg.Content)) + } + for _, tc := range msg.ToolCalls { + blocks = append(blocks, anthropic.NewToolUseBlock(tc.ID, tc.Arguments, tc.Name)) + } + anthropicMessages = append(anthropicMessages, anthropic.NewAssistantMessage(blocks...)) + } else { + anthropicMessages = append(anthropicMessages, + anthropic.NewAssistantMessage(anthropic.NewTextBlock(msg.Content)), + ) + } + case "tool": + anthropicMessages = append(anthropicMessages, + anthropic.NewUserMessage(anthropic.NewToolResultBlock(msg.ToolCallID, msg.Content, false)), + ) + } + } + + maxTokens := int64(4096) + if mt, ok := options["max_tokens"].(int); ok { + maxTokens = int64(mt) + } + + params := anthropic.MessageNewParams{ + Model: anthropic.Model(model), + Messages: anthropicMessages, + MaxTokens: maxTokens, + } + + if len(system) > 0 { + params.System = system + } + + if temp, ok := options["temperature"].(float64); ok { + params.Temperature = anthropic.Float(temp) + } + + if len(tools) > 0 { + params.Tools = translateTools(tools) + } + + return params, nil +} + +func translateTools(tools []ToolDefinition) []anthropic.ToolUnionParam { + result := make([]anthropic.ToolUnionParam, 0, len(tools)) + for _, t := range tools { + tool := anthropic.ToolParam{ + Name: t.Function.Name, + InputSchema: anthropic.ToolInputSchemaParam{ + Properties: t.Function.Parameters["properties"], + }, + } + if desc := t.Function.Description; desc != "" { + tool.Description = anthropic.String(desc) + } + if req, ok := t.Function.Parameters["required"].([]interface{}); ok { + required := make([]string, 0, len(req)) + for _, r := range req { + if s, ok := r.(string); ok { + required = append(required, s) + } + } + tool.InputSchema.Required = required + } + result = append(result, anthropic.ToolUnionParam{OfTool: &tool}) + } + return result +} + +func parseResponse(resp *anthropic.Message) *LLMResponse { + var content string + var toolCalls []ToolCall + + for _, block := range resp.Content { + switch block.Type { + case "text": + tb := block.AsText() + content += tb.Text + case "tool_use": + tu := block.AsToolUse() + var args map[string]interface{} + if err := json.Unmarshal(tu.Input, &args); err != nil { + args = map[string]interface{}{"raw": string(tu.Input)} + } + toolCalls = append(toolCalls, ToolCall{ + ID: tu.ID, + Name: tu.Name, + Arguments: args, + }) + } + } + + finishReason := "stop" + switch resp.StopReason { + case anthropic.StopReasonToolUse: + finishReason = "tool_calls" + case anthropic.StopReasonMaxTokens: + finishReason = "length" + case anthropic.StopReasonEndTurn: + finishReason = "stop" + } + + return &LLMResponse{ + Content: content, + ToolCalls: toolCalls, + FinishReason: finishReason, + Usage: &UsageInfo{ + PromptTokens: int(resp.Usage.InputTokens), + CompletionTokens: int(resp.Usage.OutputTokens), + TotalTokens: int(resp.Usage.InputTokens + resp.Usage.OutputTokens), + }, + } +} diff --git a/pkg/providers/anthropic/provider_test.go b/pkg/providers/anthropic/provider_test.go new file mode 100644 index 000000000..01b4fe663 --- /dev/null +++ b/pkg/providers/anthropic/provider_test.go @@ -0,0 +1,208 @@ +package anthropicprovider + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "github.com/anthropics/anthropic-sdk-go" + anthropicoption "github.com/anthropics/anthropic-sdk-go/option" +) + +func TestBuildParams_BasicMessage(t *testing.T) { + messages := []Message{ + {Role: "user", Content: "Hello"}, + } + params, err := buildParams(messages, nil, "claude-sonnet-4-5-20250929", map[string]interface{}{ + "max_tokens": 1024, + }) + if err != nil { + t.Fatalf("buildParams() error: %v", err) + } + if string(params.Model) != "claude-sonnet-4-5-20250929" { + t.Errorf("Model = %q, want %q", params.Model, "claude-sonnet-4-5-20250929") + } + if params.MaxTokens != 1024 { + t.Errorf("MaxTokens = %d, want 1024", params.MaxTokens) + } + if len(params.Messages) != 1 { + t.Fatalf("len(Messages) = %d, want 1", len(params.Messages)) + } +} + +func TestBuildParams_SystemMessage(t *testing.T) { + messages := []Message{ + {Role: "system", Content: "You are helpful"}, + {Role: "user", Content: "Hi"}, + } + params, err := buildParams(messages, nil, "claude-sonnet-4-5-20250929", map[string]interface{}{}) + if err != nil { + t.Fatalf("buildParams() error: %v", err) + } + if len(params.System) != 1 { + t.Fatalf("len(System) = %d, want 1", len(params.System)) + } + if params.System[0].Text != "You are helpful" { + t.Errorf("System[0].Text = %q, want %q", params.System[0].Text, "You are helpful") + } + if len(params.Messages) != 1 { + t.Fatalf("len(Messages) = %d, want 1", len(params.Messages)) + } +} + +func TestBuildParams_ToolCallMessage(t *testing.T) { + messages := []Message{ + {Role: "user", Content: "What's the weather?"}, + { + Role: "assistant", + Content: "", + ToolCalls: []ToolCall{ + { + ID: "call_1", + Name: "get_weather", + Arguments: map[string]interface{}{"city": "SF"}, + }, + }, + }, + {Role: "tool", Content: `{"temp": 72}`, ToolCallID: "call_1"}, + } + params, err := buildParams(messages, nil, "claude-sonnet-4-5-20250929", map[string]interface{}{}) + if err != nil { + t.Fatalf("buildParams() error: %v", err) + } + if len(params.Messages) != 3 { + t.Fatalf("len(Messages) = %d, want 3", len(params.Messages)) + } +} + +func TestBuildParams_WithTools(t *testing.T) { + tools := []ToolDefinition{ + { + Type: "function", + Function: ToolFunctionDefinition{ + Name: "get_weather", + Description: "Get weather for a city", + Parameters: map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "city": map[string]interface{}{"type": "string"}, + }, + "required": []interface{}{"city"}, + }, + }, + }, + } + params, err := buildParams([]Message{{Role: "user", Content: "Hi"}}, tools, "claude-sonnet-4-5-20250929", map[string]interface{}{}) + if err != nil { + t.Fatalf("buildParams() error: %v", err) + } + if len(params.Tools) != 1 { + t.Fatalf("len(Tools) = %d, want 1", len(params.Tools)) + } +} + +func TestParseResponse_TextOnly(t *testing.T) { + resp := &anthropic.Message{ + Content: []anthropic.ContentBlockUnion{}, + Usage: anthropic.Usage{ + InputTokens: 10, + OutputTokens: 20, + }, + } + result := parseResponse(resp) + if result.Usage.PromptTokens != 10 { + t.Errorf("PromptTokens = %d, want 10", result.Usage.PromptTokens) + } + if result.Usage.CompletionTokens != 20 { + t.Errorf("CompletionTokens = %d, want 20", result.Usage.CompletionTokens) + } + if result.FinishReason != "stop" { + t.Errorf("FinishReason = %q, want %q", result.FinishReason, "stop") + } +} + +func TestParseResponse_StopReasons(t *testing.T) { + tests := []struct { + stopReason anthropic.StopReason + want string + }{ + {anthropic.StopReasonEndTurn, "stop"}, + {anthropic.StopReasonMaxTokens, "length"}, + {anthropic.StopReasonToolUse, "tool_calls"}, + } + for _, tt := range tests { + resp := &anthropic.Message{ + StopReason: tt.stopReason, + } + result := parseResponse(resp) + if result.FinishReason != tt.want { + t.Errorf("StopReason %q: FinishReason = %q, want %q", tt.stopReason, result.FinishReason, tt.want) + } + } +} + +func TestProvider_ChatRoundTrip(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/v1/messages" { + http.Error(w, "not found", http.StatusNotFound) + return + } + if r.Header.Get("Authorization") != "Bearer test-token" { + http.Error(w, "unauthorized", http.StatusUnauthorized) + return + } + + var reqBody map[string]interface{} + json.NewDecoder(r.Body).Decode(&reqBody) + + resp := map[string]interface{}{ + "id": "msg_test", + "type": "message", + "role": "assistant", + "model": reqBody["model"], + "stop_reason": "end_turn", + "content": []map[string]interface{}{ + {"type": "text", "text": "Hello! How can I help you?"}, + }, + "usage": map[string]interface{}{ + "input_tokens": 15, + "output_tokens": 8, + }, + } + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(resp) + })) + defer server.Close() + + provider := NewProviderWithClient(createAnthropicTestClient(server.URL, "test-token")) + messages := []Message{{Role: "user", Content: "Hello"}} + resp, err := provider.Chat(t.Context(), messages, nil, "claude-sonnet-4-5-20250929", map[string]interface{}{"max_tokens": 1024}) + if err != nil { + t.Fatalf("Chat() error: %v", err) + } + if resp.Content != "Hello! How can I help you?" { + t.Errorf("Content = %q, want %q", resp.Content, "Hello! How can I help you?") + } + if resp.FinishReason != "stop" { + t.Errorf("FinishReason = %q, want %q", resp.FinishReason, "stop") + } + if resp.Usage.PromptTokens != 15 { + t.Errorf("PromptTokens = %d, want 15", resp.Usage.PromptTokens) + } +} + +func TestProvider_GetDefaultModel(t *testing.T) { + p := NewProvider("test-token") + if got := p.GetDefaultModel(); got != "claude-sonnet-4-5-20250929" { + t.Errorf("GetDefaultModel() = %q, want %q", got, "claude-sonnet-4-5-20250929") + } +} + +func createAnthropicTestClient(baseURL, token string) *anthropic.Client { + c := anthropic.NewClient( + anthropicoption.WithAuthToken(token), + anthropicoption.WithBaseURL(baseURL), + ) + return &c +} diff --git a/pkg/providers/claude_provider.go b/pkg/providers/claude_provider.go index ae6aca96d..16f1884c5 100644 --- a/pkg/providers/claude_provider.go +++ b/pkg/providers/claude_provider.go @@ -2,195 +2,48 @@ package providers import ( "context" - "encoding/json" "fmt" - "github.com/anthropics/anthropic-sdk-go" - "github.com/anthropics/anthropic-sdk-go/option" "github.com/sipeed/picoclaw/pkg/auth" + anthropicprovider "github.com/sipeed/picoclaw/pkg/providers/anthropic" ) type ClaudeProvider struct { - client *anthropic.Client - tokenSource func() (string, error) + delegate *anthropicprovider.Provider } func NewClaudeProvider(token string) *ClaudeProvider { - client := anthropic.NewClient( - option.WithAuthToken(token), - option.WithBaseURL("https://api.anthropic.com"), - ) - return &ClaudeProvider{client: &client} + return &ClaudeProvider{ + delegate: anthropicprovider.NewProvider(token), + } } func NewClaudeProviderWithTokenSource(token string, tokenSource func() (string, error)) *ClaudeProvider { - p := NewClaudeProvider(token) - p.tokenSource = tokenSource - return p + return &ClaudeProvider{ + delegate: anthropicprovider.NewProviderWithTokenSource(token, tokenSource), + } +} + +func newClaudeProviderWithDelegate(delegate *anthropicprovider.Provider) *ClaudeProvider { + return &ClaudeProvider{delegate: delegate} } func (p *ClaudeProvider) Chat(ctx context.Context, messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) (*LLMResponse, error) { - var opts []option.RequestOption - if p.tokenSource != nil { - tok, err := p.tokenSource() - if err != nil { - return nil, fmt.Errorf("refreshing token: %w", err) - } - opts = append(opts, option.WithAuthToken(tok)) - } - - params, err := buildClaudeParams(messages, tools, model, options) + resp, err := p.delegate.Chat( + ctx, + toAnthropicProviderMessages(messages), + toAnthropicProviderTools(tools), + model, + options, + ) if err != nil { return nil, err } - - resp, err := p.client.Messages.New(ctx, params, opts...) - if err != nil { - return nil, fmt.Errorf("claude API call: %w", err) - } - - return parseClaudeResponse(resp), nil + return fromAnthropicProviderResponse(resp), nil } func (p *ClaudeProvider) GetDefaultModel() string { - return "claude-sonnet-4-5-20250929" -} - -func buildClaudeParams(messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) (anthropic.MessageNewParams, error) { - var system []anthropic.TextBlockParam - var anthropicMessages []anthropic.MessageParam - - for _, msg := range messages { - switch msg.Role { - case "system": - system = append(system, anthropic.TextBlockParam{Text: msg.Content}) - case "user": - if msg.ToolCallID != "" { - anthropicMessages = append(anthropicMessages, - anthropic.NewUserMessage(anthropic.NewToolResultBlock(msg.ToolCallID, msg.Content, false)), - ) - } else { - anthropicMessages = append(anthropicMessages, - anthropic.NewUserMessage(anthropic.NewTextBlock(msg.Content)), - ) - } - case "assistant": - if len(msg.ToolCalls) > 0 { - var blocks []anthropic.ContentBlockParamUnion - if msg.Content != "" { - blocks = append(blocks, anthropic.NewTextBlock(msg.Content)) - } - for _, tc := range msg.ToolCalls { - blocks = append(blocks, anthropic.NewToolUseBlock(tc.ID, tc.Arguments, tc.Name)) - } - anthropicMessages = append(anthropicMessages, anthropic.NewAssistantMessage(blocks...)) - } else { - anthropicMessages = append(anthropicMessages, - anthropic.NewAssistantMessage(anthropic.NewTextBlock(msg.Content)), - ) - } - case "tool": - anthropicMessages = append(anthropicMessages, - anthropic.NewUserMessage(anthropic.NewToolResultBlock(msg.ToolCallID, msg.Content, false)), - ) - } - } - - maxTokens := int64(4096) - if mt, ok := options["max_tokens"].(int); ok { - maxTokens = int64(mt) - } - - params := anthropic.MessageNewParams{ - Model: anthropic.Model(model), - Messages: anthropicMessages, - MaxTokens: maxTokens, - } - - if len(system) > 0 { - params.System = system - } - - if temp, ok := options["temperature"].(float64); ok { - params.Temperature = anthropic.Float(temp) - } - - if len(tools) > 0 { - params.Tools = translateToolsForClaude(tools) - } - - return params, nil -} - -func translateToolsForClaude(tools []ToolDefinition) []anthropic.ToolUnionParam { - result := make([]anthropic.ToolUnionParam, 0, len(tools)) - for _, t := range tools { - tool := anthropic.ToolParam{ - Name: t.Function.Name, - InputSchema: anthropic.ToolInputSchemaParam{ - Properties: t.Function.Parameters["properties"], - }, - } - if desc := t.Function.Description; desc != "" { - tool.Description = anthropic.String(desc) - } - if req, ok := t.Function.Parameters["required"].([]interface{}); ok { - required := make([]string, 0, len(req)) - for _, r := range req { - if s, ok := r.(string); ok { - required = append(required, s) - } - } - tool.InputSchema.Required = required - } - result = append(result, anthropic.ToolUnionParam{OfTool: &tool}) - } - return result -} - -func parseClaudeResponse(resp *anthropic.Message) *LLMResponse { - var content string - var toolCalls []ToolCall - - for _, block := range resp.Content { - switch block.Type { - case "text": - tb := block.AsText() - content += tb.Text - case "tool_use": - tu := block.AsToolUse() - var args map[string]interface{} - if err := json.Unmarshal(tu.Input, &args); err != nil { - args = map[string]interface{}{"raw": string(tu.Input)} - } - toolCalls = append(toolCalls, ToolCall{ - ID: tu.ID, - Name: tu.Name, - Arguments: args, - }) - } - } - - finishReason := "stop" - switch resp.StopReason { - case anthropic.StopReasonToolUse: - finishReason = "tool_calls" - case anthropic.StopReasonMaxTokens: - finishReason = "length" - case anthropic.StopReasonEndTurn: - finishReason = "stop" - } - - return &LLMResponse{ - Content: content, - ToolCalls: toolCalls, - FinishReason: finishReason, - Usage: &UsageInfo{ - PromptTokens: int(resp.Usage.InputTokens), - CompletionTokens: int(resp.Usage.OutputTokens), - TotalTokens: int(resp.Usage.InputTokens + resp.Usage.OutputTokens), - }, - } + return p.delegate.GetDefaultModel() } func createClaudeTokenSource() func() (string, error) { @@ -205,3 +58,95 @@ func createClaudeTokenSource() func() (string, error) { return cred.AccessToken, nil } } + +func toAnthropicProviderMessages(messages []Message) []anthropicprovider.Message { + out := make([]anthropicprovider.Message, 0, len(messages)) + for _, msg := range messages { + out = append(out, anthropicprovider.Message{ + Role: msg.Role, + Content: msg.Content, + ToolCalls: toAnthropicProviderToolCalls(msg.ToolCalls), + ToolCallID: msg.ToolCallID, + }) + } + return out +} + +func toAnthropicProviderTools(tools []ToolDefinition) []anthropicprovider.ToolDefinition { + out := make([]anthropicprovider.ToolDefinition, 0, len(tools)) + for _, t := range tools { + out = append(out, anthropicprovider.ToolDefinition{ + Type: t.Type, + Function: anthropicprovider.ToolFunctionDefinition{ + Name: t.Function.Name, + Description: t.Function.Description, + Parameters: t.Function.Parameters, + }, + }) + } + return out +} + +func toAnthropicProviderToolCalls(toolCalls []ToolCall) []anthropicprovider.ToolCall { + out := make([]anthropicprovider.ToolCall, 0, len(toolCalls)) + for _, tc := range toolCalls { + var fn *anthropicprovider.FunctionCall + if tc.Function != nil { + fn = &anthropicprovider.FunctionCall{ + Name: tc.Function.Name, + Arguments: tc.Function.Arguments, + } + } + out = append(out, anthropicprovider.ToolCall{ + ID: tc.ID, + Type: tc.Type, + Function: fn, + Name: tc.Name, + Arguments: tc.Arguments, + }) + } + return out +} + +func fromAnthropicProviderResponse(resp *anthropicprovider.LLMResponse) *LLMResponse { + if resp == nil { + return &LLMResponse{} + } + + var usage *UsageInfo + if resp.Usage != nil { + usage = &UsageInfo{ + PromptTokens: resp.Usage.PromptTokens, + CompletionTokens: resp.Usage.CompletionTokens, + TotalTokens: resp.Usage.TotalTokens, + } + } + + return &LLMResponse{ + Content: resp.Content, + ToolCalls: fromAnthropicProviderToolCalls(resp.ToolCalls), + FinishReason: resp.FinishReason, + Usage: usage, + } +} + +func fromAnthropicProviderToolCalls(toolCalls []anthropicprovider.ToolCall) []ToolCall { + out := make([]ToolCall, 0, len(toolCalls)) + for _, tc := range toolCalls { + var fn *FunctionCall + if tc.Function != nil { + fn = &FunctionCall{ + Name: tc.Function.Name, + Arguments: tc.Function.Arguments, + } + } + out = append(out, ToolCall{ + ID: tc.ID, + Type: tc.Type, + Function: fn, + Name: tc.Name, + Arguments: tc.Arguments, + }) + } + return out +} diff --git a/pkg/providers/claude_provider_test.go b/pkg/providers/claude_provider_test.go index bbad2d269..13bbde1fc 100644 --- a/pkg/providers/claude_provider_test.go +++ b/pkg/providers/claude_provider_test.go @@ -8,140 +8,9 @@ import ( "github.com/anthropics/anthropic-sdk-go" anthropicoption "github.com/anthropics/anthropic-sdk-go/option" + anthropicprovider "github.com/sipeed/picoclaw/pkg/providers/anthropic" ) -func TestBuildClaudeParams_BasicMessage(t *testing.T) { - messages := []Message{ - {Role: "user", Content: "Hello"}, - } - params, err := buildClaudeParams(messages, nil, "claude-sonnet-4-5-20250929", map[string]interface{}{ - "max_tokens": 1024, - }) - if err != nil { - t.Fatalf("buildClaudeParams() error: %v", err) - } - if string(params.Model) != "claude-sonnet-4-5-20250929" { - t.Errorf("Model = %q, want %q", params.Model, "claude-sonnet-4-5-20250929") - } - if params.MaxTokens != 1024 { - t.Errorf("MaxTokens = %d, want 1024", params.MaxTokens) - } - if len(params.Messages) != 1 { - t.Fatalf("len(Messages) = %d, want 1", len(params.Messages)) - } -} - -func TestBuildClaudeParams_SystemMessage(t *testing.T) { - messages := []Message{ - {Role: "system", Content: "You are helpful"}, - {Role: "user", Content: "Hi"}, - } - params, err := buildClaudeParams(messages, nil, "claude-sonnet-4-5-20250929", map[string]interface{}{}) - if err != nil { - t.Fatalf("buildClaudeParams() error: %v", err) - } - if len(params.System) != 1 { - t.Fatalf("len(System) = %d, want 1", len(params.System)) - } - if params.System[0].Text != "You are helpful" { - t.Errorf("System[0].Text = %q, want %q", params.System[0].Text, "You are helpful") - } - if len(params.Messages) != 1 { - t.Fatalf("len(Messages) = %d, want 1", len(params.Messages)) - } -} - -func TestBuildClaudeParams_ToolCallMessage(t *testing.T) { - messages := []Message{ - {Role: "user", Content: "What's the weather?"}, - { - Role: "assistant", - Content: "", - ToolCalls: []ToolCall{ - { - ID: "call_1", - Name: "get_weather", - Arguments: map[string]interface{}{"city": "SF"}, - }, - }, - }, - {Role: "tool", Content: `{"temp": 72}`, ToolCallID: "call_1"}, - } - params, err := buildClaudeParams(messages, nil, "claude-sonnet-4-5-20250929", map[string]interface{}{}) - if err != nil { - t.Fatalf("buildClaudeParams() error: %v", err) - } - if len(params.Messages) != 3 { - t.Fatalf("len(Messages) = %d, want 3", len(params.Messages)) - } -} - -func TestBuildClaudeParams_WithTools(t *testing.T) { - tools := []ToolDefinition{ - { - Type: "function", - Function: ToolFunctionDefinition{ - Name: "get_weather", - Description: "Get weather for a city", - Parameters: map[string]interface{}{ - "type": "object", - "properties": map[string]interface{}{ - "city": map[string]interface{}{"type": "string"}, - }, - "required": []interface{}{"city"}, - }, - }, - }, - } - params, err := buildClaudeParams([]Message{{Role: "user", Content: "Hi"}}, tools, "claude-sonnet-4-5-20250929", map[string]interface{}{}) - if err != nil { - t.Fatalf("buildClaudeParams() error: %v", err) - } - if len(params.Tools) != 1 { - t.Fatalf("len(Tools) = %d, want 1", len(params.Tools)) - } -} - -func TestParseClaudeResponse_TextOnly(t *testing.T) { - resp := &anthropic.Message{ - Content: []anthropic.ContentBlockUnion{}, - Usage: anthropic.Usage{ - InputTokens: 10, - OutputTokens: 20, - }, - } - result := parseClaudeResponse(resp) - if result.Usage.PromptTokens != 10 { - t.Errorf("PromptTokens = %d, want 10", result.Usage.PromptTokens) - } - if result.Usage.CompletionTokens != 20 { - t.Errorf("CompletionTokens = %d, want 20", result.Usage.CompletionTokens) - } - if result.FinishReason != "stop" { - t.Errorf("FinishReason = %q, want %q", result.FinishReason, "stop") - } -} - -func TestParseClaudeResponse_StopReasons(t *testing.T) { - tests := []struct { - stopReason anthropic.StopReason - want string - }{ - {anthropic.StopReasonEndTurn, "stop"}, - {anthropic.StopReasonMaxTokens, "length"}, - {anthropic.StopReasonToolUse, "tool_calls"}, - } - for _, tt := range tests { - resp := &anthropic.Message{ - StopReason: tt.stopReason, - } - result := parseClaudeResponse(resp) - if result.FinishReason != tt.want { - t.Errorf("StopReason %q: FinishReason = %q, want %q", tt.stopReason, result.FinishReason, tt.want) - } - } -} - func TestClaudeProvider_ChatRoundTrip(t *testing.T) { server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.URL.Path != "/v1/messages" { @@ -175,8 +44,8 @@ func TestClaudeProvider_ChatRoundTrip(t *testing.T) { })) defer server.Close() - provider := NewClaudeProvider("test-token") - provider.client = createAnthropicTestClient(server.URL, "test-token") + delegate := anthropicprovider.NewProviderWithClient(createAnthropicTestClient(server.URL, "test-token")) + provider := newClaudeProviderWithDelegate(delegate) messages := []Message{{Role: "user", Content: "Hello"}} resp, err := provider.Chat(t.Context(), messages, nil, "claude-sonnet-4-5-20250929", map[string]interface{}{"max_tokens": 1024}) From 362c49a69d0465b711153e1ab14eeaaeb779eee6 Mon Sep 17 00:00:00 2001 From: Jared Mahotiere Date: Sun, 15 Feb 2026 08:04:16 -0500 Subject: [PATCH 3/4] docs(test): document protocol architecture and migration compatibility --- README.md | 10 ++++++++++ pkg/migrate/migrate_test.go | 18 ++++++++++++++++++ 2 files changed, 28 insertions(+) diff --git a/README.md b/README.md index 091af2811..25c6d9863 100644 --- a/README.md +++ b/README.md @@ -662,6 +662,16 @@ The subagent has access to tools (message, web_search, etc.) and can communicate | `deepseek(To be tested)` | LLM (DeepSeek direct) | [platform.deepseek.com](https://platform.deepseek.com) | | `groq` | LLM + **Voice transcription** (Whisper) | [console.groq.com](https://console.groq.com) | +### Provider Architecture + +PicoClaw routes providers by protocol family: + +- OpenAI-compatible protocol: OpenRouter, OpenAI-compatible gateways, Groq, Zhipu, and vLLM-style endpoints. +- Anthropic protocol: Claude-native API behavior. +- Codex/OAuth path: OpenAI OAuth/token authentication route. + +This keeps the runtime lightweight while making new OpenAI-compatible backends mostly a config operation (`api_base` + `api_key`). +
Zhipu diff --git a/pkg/migrate/migrate_test.go b/pkg/migrate/migrate_test.go index be2360aac..e930d45f4 100644 --- a/pkg/migrate/migrate_test.go +++ b/pkg/migrate/migrate_test.go @@ -299,6 +299,24 @@ func TestConvertConfig(t *testing.T) { }) } +func TestSupportedProvidersCompatibility(t *testing.T) { + expected := []string{ + "anthropic", + "openai", + "openrouter", + "groq", + "zhipu", + "vllm", + "gemini", + } + + for _, provider := range expected { + if !supportedProviders[provider] { + t.Fatalf("supportedProviders missing expected key %q", provider) + } + } +} + func TestMergeConfig(t *testing.T) { t.Run("fills empty fields", func(t *testing.T) { existing := config.DefaultConfig() From c4cbb5fb35374d0ff917baff9196746f843b99fa Mon Sep 17 00:00:00 2001 From: Jared Mahotiere Date: Tue, 17 Feb 2026 11:13:10 -0500 Subject: [PATCH 4/4] providers: finalize PR213 review fixes Phase 1: centralize protocol message/tool/response types in protocoltypes and keep compatibility aliases in providers and protocol packages. Phase 1: preserve HTTPProvider constructor compatibility and route Anthropic api_base through factory auth/provider constructors with base URL normalization. Phase 2: expand provider routing/auth tests (deepseek/nvidia/shengsuanyun, codex/claude oauth/codex-cli) and add openai_compat + anthropic coverage for proxy transport, model normalization, numeric option coercion, token-source refresh, and base URL behavior. Phase 3: apply gofmt and validate with Dockerized tests (go test ./pkg/providers/... ./pkg/migrate and go test ./...). --- pkg/providers/anthropic/provider.go | 99 +++++++------- pkg/providers/anthropic/provider_test.go | 57 ++++++++ pkg/providers/claude_provider.go | 118 ++-------------- pkg/providers/factory.go | 47 ++++++- pkg/providers/factory_test.go | 95 +++++++++++++ pkg/providers/http_provider.go | 106 +-------------- pkg/providers/openai_compat/provider.go | 136 ++++++++++--------- pkg/providers/openai_compat/provider_test.go | 85 +++++++++++- pkg/providers/protocoltypes/types.go | 45 ++++++ pkg/providers/types.go | 54 ++------ 10 files changed, 468 insertions(+), 374 deletions(-) create mode 100644 pkg/providers/protocoltypes/types.go diff --git a/pkg/providers/anthropic/provider.go b/pkg/providers/anthropic/provider.go index ca72f0180..8f46aa70c 100644 --- a/pkg/providers/anthropic/provider.go +++ b/pkg/providers/anthropic/provider.go @@ -4,74 +4,59 @@ import ( "context" "encoding/json" "fmt" + "log" + "strings" "github.com/anthropics/anthropic-sdk-go" "github.com/anthropics/anthropic-sdk-go/option" + "github.com/sipeed/picoclaw/pkg/providers/protocoltypes" ) -type ToolCall struct { - ID string `json:"id"` - Type string `json:"type,omitempty"` - Function *FunctionCall `json:"function,omitempty"` - Name string `json:"name,omitempty"` - Arguments map[string]interface{} `json:"arguments,omitempty"` -} +type ToolCall = protocoltypes.ToolCall +type FunctionCall = protocoltypes.FunctionCall +type LLMResponse = protocoltypes.LLMResponse +type UsageInfo = protocoltypes.UsageInfo +type Message = protocoltypes.Message +type ToolDefinition = protocoltypes.ToolDefinition +type ToolFunctionDefinition = protocoltypes.ToolFunctionDefinition -type FunctionCall struct { - Name string `json:"name"` - Arguments string `json:"arguments"` -} - -type LLMResponse struct { - Content string `json:"content"` - ToolCalls []ToolCall `json:"tool_calls,omitempty"` - FinishReason string `json:"finish_reason"` - Usage *UsageInfo `json:"usage,omitempty"` -} - -type UsageInfo struct { - PromptTokens int `json:"prompt_tokens"` - CompletionTokens int `json:"completion_tokens"` - TotalTokens int `json:"total_tokens"` -} - -type Message struct { - Role string `json:"role"` - Content string `json:"content"` - ToolCalls []ToolCall `json:"tool_calls,omitempty"` - ToolCallID string `json:"tool_call_id,omitempty"` -} - -type ToolDefinition struct { - Type string `json:"type"` - Function ToolFunctionDefinition `json:"function"` -} - -type ToolFunctionDefinition struct { - Name string `json:"name"` - Description string `json:"description"` - Parameters map[string]interface{} `json:"parameters"` -} +const defaultBaseURL = "https://api.anthropic.com" type Provider struct { client *anthropic.Client tokenSource func() (string, error) + baseURL string } func NewProvider(token string) *Provider { + return NewProviderWithBaseURL(token, "") +} + +func NewProviderWithBaseURL(token, apiBase string) *Provider { + baseURL := normalizeBaseURL(apiBase) client := anthropic.NewClient( option.WithAuthToken(token), - option.WithBaseURL("https://api.anthropic.com"), + option.WithBaseURL(baseURL), ) - return &Provider{client: &client} + return &Provider{ + client: &client, + baseURL: baseURL, + } } func NewProviderWithClient(client *anthropic.Client) *Provider { - return &Provider{client: client} + return &Provider{ + client: client, + baseURL: defaultBaseURL, + } } func NewProviderWithTokenSource(token string, tokenSource func() (string, error)) *Provider { - p := NewProvider(token) + return NewProviderWithTokenSourceAndBaseURL(token, tokenSource, "") +} + +func NewProviderWithTokenSourceAndBaseURL(token string, tokenSource func() (string, error), apiBase string) *Provider { + p := NewProviderWithBaseURL(token, apiBase) p.tokenSource = tokenSource return p } @@ -103,6 +88,10 @@ func (p *Provider) GetDefaultModel() string { return "claude-sonnet-4-5-20250929" } +func (p *Provider) BaseURL() string { + return p.baseURL +} + func buildParams(messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) (anthropic.MessageNewParams, error) { var system []anthropic.TextBlockParam var anthropicMessages []anthropic.MessageParam @@ -208,6 +197,7 @@ func parseResponse(resp *anthropic.Message) *LLMResponse { tu := block.AsToolUse() var args map[string]interface{} if err := json.Unmarshal(tu.Input, &args); err != nil { + log.Printf("anthropic: failed to decode tool call input for %q: %v", tu.Name, err) args = map[string]interface{}{"raw": string(tu.Input)} } toolCalls = append(toolCalls, ToolCall{ @@ -239,3 +229,20 @@ func parseResponse(resp *anthropic.Message) *LLMResponse { }, } } + +func normalizeBaseURL(apiBase string) string { + base := strings.TrimSpace(apiBase) + if base == "" { + return defaultBaseURL + } + + base = strings.TrimRight(base, "/") + if strings.HasSuffix(base, "/v1") { + base = strings.TrimSuffix(base, "/v1") + } + if base == "" { + return defaultBaseURL + } + + return base +} diff --git a/pkg/providers/anthropic/provider_test.go b/pkg/providers/anthropic/provider_test.go index 01b4fe663..6a1dabafb 100644 --- a/pkg/providers/anthropic/provider_test.go +++ b/pkg/providers/anthropic/provider_test.go @@ -4,6 +4,7 @@ import ( "encoding/json" "net/http" "net/http/httptest" + "sync/atomic" "testing" "github.com/anthropics/anthropic-sdk-go" @@ -199,6 +200,62 @@ func TestProvider_GetDefaultModel(t *testing.T) { } } +func TestProvider_NewProviderWithBaseURL_NormalizesV1Suffix(t *testing.T) { + p := NewProviderWithBaseURL("token", "https://api.anthropic.com/v1/") + if got := p.BaseURL(); got != "https://api.anthropic.com" { + t.Fatalf("BaseURL() = %q, want %q", got, "https://api.anthropic.com") + } +} + +func TestProvider_ChatUsesTokenSource(t *testing.T) { + var requests int32 + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/v1/messages" { + http.Error(w, "not found", http.StatusNotFound) + return + } + atomic.AddInt32(&requests, 1) + + if got := r.Header.Get("Authorization"); got != "Bearer refreshed-token" { + http.Error(w, "unauthorized", http.StatusUnauthorized) + return + } + + var reqBody map[string]interface{} + json.NewDecoder(r.Body).Decode(&reqBody) + + resp := map[string]interface{}{ + "id": "msg_test", + "type": "message", + "role": "assistant", + "model": reqBody["model"], + "stop_reason": "end_turn", + "content": []map[string]interface{}{ + {"type": "text", "text": "ok"}, + }, + "usage": map[string]interface{}{ + "input_tokens": 1, + "output_tokens": 1, + }, + } + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(resp) + })) + defer server.Close() + + p := NewProviderWithTokenSourceAndBaseURL("stale-token", func() (string, error) { + return "refreshed-token", nil + }, server.URL) + + _, err := p.Chat(t.Context(), []Message{{Role: "user", Content: "hello"}}, nil, "claude-sonnet-4-5-20250929", map[string]interface{}{}) + if err != nil { + t.Fatalf("Chat() error: %v", err) + } + if got := atomic.LoadInt32(&requests); got != 1 { + t.Fatalf("requests = %d, want 1", got) + } +} + func createAnthropicTestClient(baseURL, token string) *anthropic.Client { c := anthropic.NewClient( anthropicoption.WithAuthToken(token), diff --git a/pkg/providers/claude_provider.go b/pkg/providers/claude_provider.go index 16f1884c5..c72f5b0ef 100644 --- a/pkg/providers/claude_provider.go +++ b/pkg/providers/claude_provider.go @@ -3,8 +3,6 @@ package providers import ( "context" "fmt" - - "github.com/sipeed/picoclaw/pkg/auth" anthropicprovider "github.com/sipeed/picoclaw/pkg/providers/anthropic" ) @@ -18,28 +16,34 @@ func NewClaudeProvider(token string) *ClaudeProvider { } } +func NewClaudeProviderWithBaseURL(token, apiBase string) *ClaudeProvider { + return &ClaudeProvider{ + delegate: anthropicprovider.NewProviderWithBaseURL(token, apiBase), + } +} + func NewClaudeProviderWithTokenSource(token string, tokenSource func() (string, error)) *ClaudeProvider { return &ClaudeProvider{ delegate: anthropicprovider.NewProviderWithTokenSource(token, tokenSource), } } +func NewClaudeProviderWithTokenSourceAndBaseURL(token string, tokenSource func() (string, error), apiBase string) *ClaudeProvider { + return &ClaudeProvider{ + delegate: anthropicprovider.NewProviderWithTokenSourceAndBaseURL(token, tokenSource, apiBase), + } +} + func newClaudeProviderWithDelegate(delegate *anthropicprovider.Provider) *ClaudeProvider { return &ClaudeProvider{delegate: delegate} } func (p *ClaudeProvider) Chat(ctx context.Context, messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) (*LLMResponse, error) { - resp, err := p.delegate.Chat( - ctx, - toAnthropicProviderMessages(messages), - toAnthropicProviderTools(tools), - model, - options, - ) + resp, err := p.delegate.Chat(ctx, messages, tools, model, options) if err != nil { return nil, err } - return fromAnthropicProviderResponse(resp), nil + return resp, nil } func (p *ClaudeProvider) GetDefaultModel() string { @@ -48,7 +52,7 @@ func (p *ClaudeProvider) GetDefaultModel() string { func createClaudeTokenSource() func() (string, error) { return func() (string, error) { - cred, err := auth.GetCredential("anthropic") + cred, err := getCredential("anthropic") if err != nil { return "", fmt.Errorf("loading auth credentials: %w", err) } @@ -58,95 +62,3 @@ func createClaudeTokenSource() func() (string, error) { return cred.AccessToken, nil } } - -func toAnthropicProviderMessages(messages []Message) []anthropicprovider.Message { - out := make([]anthropicprovider.Message, 0, len(messages)) - for _, msg := range messages { - out = append(out, anthropicprovider.Message{ - Role: msg.Role, - Content: msg.Content, - ToolCalls: toAnthropicProviderToolCalls(msg.ToolCalls), - ToolCallID: msg.ToolCallID, - }) - } - return out -} - -func toAnthropicProviderTools(tools []ToolDefinition) []anthropicprovider.ToolDefinition { - out := make([]anthropicprovider.ToolDefinition, 0, len(tools)) - for _, t := range tools { - out = append(out, anthropicprovider.ToolDefinition{ - Type: t.Type, - Function: anthropicprovider.ToolFunctionDefinition{ - Name: t.Function.Name, - Description: t.Function.Description, - Parameters: t.Function.Parameters, - }, - }) - } - return out -} - -func toAnthropicProviderToolCalls(toolCalls []ToolCall) []anthropicprovider.ToolCall { - out := make([]anthropicprovider.ToolCall, 0, len(toolCalls)) - for _, tc := range toolCalls { - var fn *anthropicprovider.FunctionCall - if tc.Function != nil { - fn = &anthropicprovider.FunctionCall{ - Name: tc.Function.Name, - Arguments: tc.Function.Arguments, - } - } - out = append(out, anthropicprovider.ToolCall{ - ID: tc.ID, - Type: tc.Type, - Function: fn, - Name: tc.Name, - Arguments: tc.Arguments, - }) - } - return out -} - -func fromAnthropicProviderResponse(resp *anthropicprovider.LLMResponse) *LLMResponse { - if resp == nil { - return &LLMResponse{} - } - - var usage *UsageInfo - if resp.Usage != nil { - usage = &UsageInfo{ - PromptTokens: resp.Usage.PromptTokens, - CompletionTokens: resp.Usage.CompletionTokens, - TotalTokens: resp.Usage.TotalTokens, - } - } - - return &LLMResponse{ - Content: resp.Content, - ToolCalls: fromAnthropicProviderToolCalls(resp.ToolCalls), - FinishReason: resp.FinishReason, - Usage: usage, - } -} - -func fromAnthropicProviderToolCalls(toolCalls []anthropicprovider.ToolCall) []ToolCall { - out := make([]ToolCall, 0, len(toolCalls)) - for _, tc := range toolCalls { - var fn *FunctionCall - if tc.Function != nil { - fn = &FunctionCall{ - Name: tc.Function.Name, - Arguments: tc.Function.Arguments, - } - } - out = append(out, ToolCall{ - ID: tc.ID, - Type: tc.Type, - Function: fn, - Name: tc.Name, - Arguments: tc.Arguments, - }) - } - return out -} diff --git a/pkg/providers/factory.go b/pkg/providers/factory.go index 28609c4b3..67a347721 100644 --- a/pkg/providers/factory.go +++ b/pkg/providers/factory.go @@ -8,6 +8,10 @@ import ( "github.com/sipeed/picoclaw/pkg/config" ) +const defaultAnthropicAPIBase = "https://api.anthropic.com/v1" + +var getCredential = auth.GetCredential + type providerType int const ( @@ -30,19 +34,22 @@ type providerSelection struct { connectMode string } -func createClaudeAuthProvider() (LLMProvider, error) { - cred, err := auth.GetCredential("anthropic") +func createClaudeAuthProvider(apiBase string) (LLMProvider, error) { + if apiBase == "" { + apiBase = defaultAnthropicAPIBase + } + cred, err := getCredential("anthropic") if err != nil { return nil, fmt.Errorf("loading auth credentials: %w", err) } if cred == nil { return nil, fmt.Errorf("no credentials for anthropic. Run: picoclaw auth login --provider anthropic") } - return NewClaudeProviderWithTokenSource(cred.AccessToken, createClaudeTokenSource()), nil + return NewClaudeProviderWithTokenSourceAndBaseURL(cred.AccessToken, createClaudeTokenSource(), apiBase), nil } func createCodexAuthProvider() (LLMProvider, error) { - cred, err := auth.GetCredential("openai") + cred, err := getCredential("openai") if err != nil { return nil, fmt.Errorf("loading auth credentials: %w", err) } @@ -69,6 +76,7 @@ func resolveProviderSelection(cfg *config.Config) (providerSelection, error) { if cfg.Providers.Groq.APIKey != "" { sel.apiKey = cfg.Providers.Groq.APIKey sel.apiBase = cfg.Providers.Groq.APIBase + sel.proxy = cfg.Providers.Groq.Proxy if sel.apiBase == "" { sel.apiBase = "https://api.groq.com/openai/v1" } @@ -85,6 +93,7 @@ func resolveProviderSelection(cfg *config.Config) (providerSelection, error) { } sel.apiKey = cfg.Providers.OpenAI.APIKey sel.apiBase = cfg.Providers.OpenAI.APIBase + sel.proxy = cfg.Providers.OpenAI.Proxy if sel.apiBase == "" { sel.apiBase = "https://api.openai.com/v1" } @@ -92,18 +101,24 @@ func resolveProviderSelection(cfg *config.Config) (providerSelection, error) { case "anthropic", "claude": if cfg.Providers.Anthropic.APIKey != "" || cfg.Providers.Anthropic.AuthMethod != "" { if cfg.Providers.Anthropic.AuthMethod == "oauth" || cfg.Providers.Anthropic.AuthMethod == "token" { + sel.apiBase = cfg.Providers.Anthropic.APIBase + if sel.apiBase == "" { + sel.apiBase = defaultAnthropicAPIBase + } sel.providerType = providerTypeClaudeAuth return sel, nil } sel.apiKey = cfg.Providers.Anthropic.APIKey sel.apiBase = cfg.Providers.Anthropic.APIBase + sel.proxy = cfg.Providers.Anthropic.Proxy if sel.apiBase == "" { - sel.apiBase = "https://api.anthropic.com/v1" + sel.apiBase = defaultAnthropicAPIBase } } case "openrouter": if cfg.Providers.OpenRouter.APIKey != "" { sel.apiKey = cfg.Providers.OpenRouter.APIKey + sel.proxy = cfg.Providers.OpenRouter.Proxy if cfg.Providers.OpenRouter.APIBase != "" { sel.apiBase = cfg.Providers.OpenRouter.APIBase } else { @@ -114,6 +129,7 @@ func resolveProviderSelection(cfg *config.Config) (providerSelection, error) { if cfg.Providers.Zhipu.APIKey != "" { sel.apiKey = cfg.Providers.Zhipu.APIKey sel.apiBase = cfg.Providers.Zhipu.APIBase + sel.proxy = cfg.Providers.Zhipu.Proxy if sel.apiBase == "" { sel.apiBase = "https://open.bigmodel.cn/api/paas/v4" } @@ -122,6 +138,7 @@ func resolveProviderSelection(cfg *config.Config) (providerSelection, error) { if cfg.Providers.Gemini.APIKey != "" { sel.apiKey = cfg.Providers.Gemini.APIKey sel.apiBase = cfg.Providers.Gemini.APIBase + sel.proxy = cfg.Providers.Gemini.Proxy if sel.apiBase == "" { sel.apiBase = "https://generativelanguage.googleapis.com/v1beta" } @@ -130,15 +147,26 @@ func resolveProviderSelection(cfg *config.Config) (providerSelection, error) { if cfg.Providers.VLLM.APIBase != "" { sel.apiKey = cfg.Providers.VLLM.APIKey sel.apiBase = cfg.Providers.VLLM.APIBase + sel.proxy = cfg.Providers.VLLM.Proxy } case "shengsuanyun": if cfg.Providers.ShengSuanYun.APIKey != "" { sel.apiKey = cfg.Providers.ShengSuanYun.APIKey sel.apiBase = cfg.Providers.ShengSuanYun.APIBase + sel.proxy = cfg.Providers.ShengSuanYun.Proxy if sel.apiBase == "" { sel.apiBase = "https://router.shengsuanyun.com/api/v1" } } + case "nvidia": + if cfg.Providers.Nvidia.APIKey != "" { + sel.apiKey = cfg.Providers.Nvidia.APIKey + sel.apiBase = cfg.Providers.Nvidia.APIBase + sel.proxy = cfg.Providers.Nvidia.Proxy + if sel.apiBase == "" { + sel.apiBase = "https://integrate.api.nvidia.com/v1" + } + } case "claude-cli", "claude-code", "claudecode": workspace := cfg.WorkspacePath() if workspace == "" { @@ -159,6 +187,7 @@ func resolveProviderSelection(cfg *config.Config) (providerSelection, error) { if cfg.Providers.DeepSeek.APIKey != "" { sel.apiKey = cfg.Providers.DeepSeek.APIKey sel.apiBase = cfg.Providers.DeepSeek.APIBase + sel.proxy = cfg.Providers.DeepSeek.Proxy if sel.apiBase == "" { sel.apiBase = "https://api.deepseek.com/v1" } @@ -204,6 +233,10 @@ func resolveProviderSelection(cfg *config.Config) (providerSelection, error) { case (strings.Contains(lowerModel, "claude") || strings.HasPrefix(model, "anthropic/")) && (cfg.Providers.Anthropic.APIKey != "" || cfg.Providers.Anthropic.AuthMethod != ""): if cfg.Providers.Anthropic.AuthMethod == "oauth" || cfg.Providers.Anthropic.AuthMethod == "token" { + sel.apiBase = cfg.Providers.Anthropic.APIBase + if sel.apiBase == "" { + sel.apiBase = defaultAnthropicAPIBase + } sel.providerType = providerTypeClaudeAuth return sel, nil } @@ -211,7 +244,7 @@ func resolveProviderSelection(cfg *config.Config) (providerSelection, error) { sel.apiBase = cfg.Providers.Anthropic.APIBase sel.proxy = cfg.Providers.Anthropic.Proxy if sel.apiBase == "" { - sel.apiBase = "https://api.anthropic.com/v1" + sel.apiBase = defaultAnthropicAPIBase } case (strings.Contains(lowerModel, "gpt") || strings.HasPrefix(model, "openai/")) && (cfg.Providers.OpenAI.APIKey != "" || cfg.Providers.OpenAI.AuthMethod != ""): @@ -303,7 +336,7 @@ func CreateProvider(cfg *config.Config) (LLMProvider, error) { switch sel.providerType { case providerTypeClaudeAuth: - return createClaudeAuthProvider() + return createClaudeAuthProvider(sel.apiBase) case providerTypeCodexAuth: return createCodexAuthProvider() case providerTypeCodexCLIToken: diff --git a/pkg/providers/factory_test.go b/pkg/providers/factory_test.go index c1f14291d..e31737eb9 100644 --- a/pkg/providers/factory_test.go +++ b/pkg/providers/factory_test.go @@ -4,6 +4,7 @@ import ( "strings" "testing" + "github.com/sipeed/picoclaw/pkg/auth" "github.com/sipeed/picoclaw/pkg/config" ) @@ -32,6 +33,40 @@ func TestResolveProviderSelection(t *testing.T) { wantType: providerTypeGitHubCopilot, wantAPIBase: "localhost:4321", }, + { + name: "explicit deepseek provider uses deepseek defaults", + setup: func(cfg *config.Config) { + cfg.Agents.Defaults.Provider = "deepseek" + cfg.Agents.Defaults.Model = "deepseek/deepseek-chat" + cfg.Providers.DeepSeek.APIKey = "deepseek-key" + cfg.Providers.DeepSeek.Proxy = "http://127.0.0.1:7890" + }, + wantType: providerTypeHTTPCompat, + wantAPIBase: "https://api.deepseek.com/v1", + wantProxy: "http://127.0.0.1:7890", + }, + { + name: "explicit shengsuanyun provider uses defaults", + setup: func(cfg *config.Config) { + cfg.Agents.Defaults.Provider = "shengsuanyun" + cfg.Providers.ShengSuanYun.APIKey = "ssy-key" + cfg.Providers.ShengSuanYun.Proxy = "http://127.0.0.1:7890" + }, + wantType: providerTypeHTTPCompat, + wantAPIBase: "https://router.shengsuanyun.com/api/v1", + wantProxy: "http://127.0.0.1:7890", + }, + { + name: "explicit nvidia provider uses defaults", + setup: func(cfg *config.Config) { + cfg.Agents.Defaults.Provider = "nvidia" + cfg.Providers.Nvidia.APIKey = "nvapi-test" + cfg.Providers.Nvidia.Proxy = "http://127.0.0.1:7890" + }, + wantType: providerTypeHTTPCompat, + wantAPIBase: "https://integrate.api.nvidia.com/v1", + wantProxy: "http://127.0.0.1:7890", + }, { name: "openrouter model uses openrouter defaults", setup: func(cfg *config.Config) { @@ -202,3 +237,63 @@ func TestCreateProviderReturnsCodexProviderForCodexCliAuthMethod(t *testing.T) { t.Fatalf("provider type = %T, want *CodexProvider", provider) } } + +func TestCreateProviderReturnsClaudeProviderForAnthropicOAuth(t *testing.T) { + originalGetCredential := getCredential + t.Cleanup(func() { getCredential = originalGetCredential }) + + getCredential = func(provider string) (*auth.AuthCredential, error) { + if provider != "anthropic" { + t.Fatalf("provider = %q, want anthropic", provider) + } + return &auth.AuthCredential{ + AccessToken: "anthropic-token", + }, nil + } + + cfg := config.DefaultConfig() + cfg.Agents.Defaults.Provider = "anthropic" + cfg.Providers.Anthropic.AuthMethod = "oauth" + cfg.Providers.Anthropic.APIBase = "https://proxy.example.com/v1" + + provider, err := CreateProvider(cfg) + if err != nil { + t.Fatalf("CreateProvider() error = %v", err) + } + + claudeProvider, ok := provider.(*ClaudeProvider) + if !ok { + t.Fatalf("provider type = %T, want *ClaudeProvider", provider) + } + if got := claudeProvider.delegate.BaseURL(); got != "https://proxy.example.com" { + t.Fatalf("anthropic baseURL = %q, want %q", got, "https://proxy.example.com") + } +} + +func TestCreateProviderReturnsCodexProviderForOpenAIOAuth(t *testing.T) { + originalGetCredential := getCredential + t.Cleanup(func() { getCredential = originalGetCredential }) + + getCredential = func(provider string) (*auth.AuthCredential, error) { + if provider != "openai" { + t.Fatalf("provider = %q, want openai", provider) + } + return &auth.AuthCredential{ + AccessToken: "openai-token", + AccountID: "acct_123", + }, nil + } + + cfg := config.DefaultConfig() + cfg.Agents.Defaults.Provider = "openai" + cfg.Providers.OpenAI.AuthMethod = "oauth" + + provider, err := CreateProvider(cfg) + if err != nil { + t.Fatalf("CreateProvider() error = %v", err) + } + + if _, ok := provider.(*CodexProvider); !ok { + t.Fatalf("provider type = %T, want *CodexProvider", provider) + } +} diff --git a/pkg/providers/http_provider.go b/pkg/providers/http_provider.go index 0f7f646d8..e39a19e90 100644 --- a/pkg/providers/http_provider.go +++ b/pkg/providers/http_provider.go @@ -15,116 +15,16 @@ type HTTPProvider struct { delegate *openai_compat.Provider } -func NewHTTPProvider(apiKey, apiBase string, proxy ...string) *HTTPProvider { - proxyURL := "" - if len(proxy) > 0 { - proxyURL = proxy[0] - } +func NewHTTPProvider(apiKey, apiBase, proxy string) *HTTPProvider { return &HTTPProvider{ - delegate: openai_compat.NewProvider(apiKey, apiBase, proxyURL), + delegate: openai_compat.NewProvider(apiKey, apiBase, proxy), } } func (p *HTTPProvider) Chat(ctx context.Context, messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) (*LLMResponse, error) { - compatResp, err := p.delegate.Chat(ctx, toOpenAICompatMessages(messages), toOpenAICompatTools(tools), model, options) - if err != nil { - return nil, err - } - return fromOpenAICompatResponse(compatResp), nil + return p.delegate.Chat(ctx, messages, tools, model, options) } func (p *HTTPProvider) GetDefaultModel() string { return "" } - -func toOpenAICompatMessages(messages []Message) []openai_compat.Message { - out := make([]openai_compat.Message, 0, len(messages)) - for _, msg := range messages { - out = append(out, openai_compat.Message{ - Role: msg.Role, - Content: msg.Content, - ToolCalls: toOpenAICompatToolCalls(msg.ToolCalls), - ToolCallID: msg.ToolCallID, - }) - } - return out -} - -func toOpenAICompatTools(tools []ToolDefinition) []openai_compat.ToolDefinition { - out := make([]openai_compat.ToolDefinition, 0, len(tools)) - for _, t := range tools { - out = append(out, openai_compat.ToolDefinition{ - Type: t.Type, - Function: openai_compat.ToolFunctionDefinition{ - Name: t.Function.Name, - Description: t.Function.Description, - Parameters: t.Function.Parameters, - }, - }) - } - return out -} - -func toOpenAICompatToolCalls(toolCalls []ToolCall) []openai_compat.ToolCall { - out := make([]openai_compat.ToolCall, 0, len(toolCalls)) - for _, tc := range toolCalls { - var fn *openai_compat.FunctionCall - if tc.Function != nil { - fn = &openai_compat.FunctionCall{ - Name: tc.Function.Name, - Arguments: tc.Function.Arguments, - } - } - out = append(out, openai_compat.ToolCall{ - ID: tc.ID, - Type: tc.Type, - Function: fn, - Name: tc.Name, - Arguments: tc.Arguments, - }) - } - return out -} - -func fromOpenAICompatResponse(resp *openai_compat.LLMResponse) *LLMResponse { - if resp == nil { - return &LLMResponse{} - } - - var usage *UsageInfo - if resp.Usage != nil { - usage = &UsageInfo{ - PromptTokens: resp.Usage.PromptTokens, - CompletionTokens: resp.Usage.CompletionTokens, - TotalTokens: resp.Usage.TotalTokens, - } - } - - return &LLMResponse{ - Content: resp.Content, - ToolCalls: fromOpenAICompatToolCalls(resp.ToolCalls), - FinishReason: resp.FinishReason, - Usage: usage, - } -} - -func fromOpenAICompatToolCalls(toolCalls []openai_compat.ToolCall) []ToolCall { - out := make([]ToolCall, 0, len(toolCalls)) - for _, tc := range toolCalls { - var fn *FunctionCall - if tc.Function != nil { - fn = &FunctionCall{ - Name: tc.Function.Name, - Arguments: tc.Function.Arguments, - } - } - out = append(out, ToolCall{ - ID: tc.ID, - Type: tc.Type, - Function: fn, - Name: tc.Name, - Arguments: tc.Arguments, - }) - } - return out -} diff --git a/pkg/providers/openai_compat/provider.go b/pkg/providers/openai_compat/provider.go index 7bc8e26be..9b404dd77 100644 --- a/pkg/providers/openai_compat/provider.go +++ b/pkg/providers/openai_compat/provider.go @@ -6,55 +6,22 @@ import ( "encoding/json" "fmt" "io" + "log" "net/http" "net/url" "strings" "time" + + "github.com/sipeed/picoclaw/pkg/providers/protocoltypes" ) -type ToolCall struct { - ID string `json:"id"` - Type string `json:"type,omitempty"` - Function *FunctionCall `json:"function,omitempty"` - Name string `json:"name,omitempty"` - Arguments map[string]interface{} `json:"arguments,omitempty"` -} - -type FunctionCall struct { - Name string `json:"name"` - Arguments string `json:"arguments"` -} - -type LLMResponse struct { - Content string `json:"content"` - ToolCalls []ToolCall `json:"tool_calls,omitempty"` - FinishReason string `json:"finish_reason"` - Usage *UsageInfo `json:"usage,omitempty"` -} - -type UsageInfo struct { - PromptTokens int `json:"prompt_tokens"` - CompletionTokens int `json:"completion_tokens"` - TotalTokens int `json:"total_tokens"` -} - -type Message struct { - Role string `json:"role"` - Content string `json:"content"` - ToolCalls []ToolCall `json:"tool_calls,omitempty"` - ToolCallID string `json:"tool_call_id,omitempty"` -} - -type ToolDefinition struct { - Type string `json:"type"` - Function ToolFunctionDefinition `json:"function"` -} - -type ToolFunctionDefinition struct { - Name string `json:"name"` - Description string `json:"description"` - Parameters map[string]interface{} `json:"parameters"` -} +type ToolCall = protocoltypes.ToolCall +type FunctionCall = protocoltypes.FunctionCall +type LLMResponse = protocoltypes.LLMResponse +type UsageInfo = protocoltypes.UsageInfo +type Message = protocoltypes.Message +type ToolDefinition = protocoltypes.ToolDefinition +type ToolFunctionDefinition = protocoltypes.ToolFunctionDefinition type Provider struct { apiKey string @@ -62,21 +29,19 @@ type Provider struct { httpClient *http.Client } -func NewProvider(apiKey, apiBase string, proxy ...string) *Provider { - proxyURL := "" - if len(proxy) > 0 { - proxyURL = proxy[0] - } +func NewProvider(apiKey, apiBase, proxy string) *Provider { client := &http.Client{ Timeout: 120 * time.Second, } - if proxyURL != "" { - parsed, err := url.Parse(proxyURL) + if proxy != "" { + parsed, err := url.Parse(proxy) if err == nil { client.Transport = &http.Transport{ Proxy: http.ProxyURL(parsed), } + } else { + log.Printf("openai_compat: invalid proxy URL %q: %v", proxy, err) } } @@ -92,13 +57,7 @@ func (p *Provider) Chat(ctx context.Context, messages []Message, tools []ToolDef return nil, fmt.Errorf("API base not configured") } - // Strip provider prefix for OpenAI-compatible backends. - if idx := strings.Index(model, "/"); idx != -1 { - prefix := model[:idx] - if prefix == "moonshot" || prefix == "nvidia" || prefix == "groq" || prefix == "ollama" { - model = model[idx+1:] - } - } + model = normalizeModel(model, p.apiBase) requestBody := map[string]interface{}{ "model": model, @@ -110,7 +69,7 @@ func (p *Provider) Chat(ctx context.Context, messages []Message, tools []ToolDef requestBody["tool_choice"] = "auto" } - if maxTokens, ok := options["max_tokens"].(int); ok { + if maxTokens, ok := asInt(options["max_tokens"]); ok { lowerModel := strings.ToLower(model) if strings.Contains(lowerModel, "glm") || strings.Contains(lowerModel, "o1") { requestBody["max_completion_tokens"] = maxTokens @@ -119,7 +78,7 @@ func (p *Provider) Chat(ctx context.Context, messages []Message, tools []ToolDef } } - if temperature, ok := options["temperature"].(float64); ok { + if temperature, ok := asFloat(options["temperature"]); ok { lowerModel := strings.ToLower(model) // Kimi k2 models only support temperature=1. if strings.Contains(lowerModel, "kimi") && strings.Contains(lowerModel, "k2") { @@ -198,17 +157,11 @@ func parseResponse(body []byte) (*LLMResponse, error) { arguments := make(map[string]interface{}) name := "" - if tc.Type == "function" && tc.Function != nil { - name = tc.Function.Name - if tc.Function.Arguments != "" { - if err := json.Unmarshal([]byte(tc.Function.Arguments), &arguments); err != nil { - arguments["raw"] = tc.Function.Arguments - } - } - } else if tc.Function != nil { + if tc.Function != nil { name = tc.Function.Name if tc.Function.Arguments != "" { if err := json.Unmarshal([]byte(tc.Function.Arguments), &arguments); err != nil { + log.Printf("openai_compat: failed to decode tool call arguments for %q: %v", name, err) arguments["raw"] = tc.Function.Arguments } } @@ -228,3 +181,52 @@ func parseResponse(body []byte) (*LLMResponse, error) { Usage: apiResponse.Usage, }, nil } + +func normalizeModel(model, apiBase string) string { + idx := strings.Index(model, "/") + if idx == -1 { + return model + } + + if strings.Contains(strings.ToLower(apiBase), "openrouter.ai") { + return model + } + + prefix := strings.ToLower(model[:idx]) + switch prefix { + case "moonshot", "nvidia", "groq", "ollama", "deepseek", "google", "openrouter", "zhipu": + return model[idx+1:] + default: + return model + } +} + +func asInt(v interface{}) (int, bool) { + switch val := v.(type) { + case int: + return val, true + case int64: + return int(val), true + case float64: + return int(val), true + case float32: + return int(val), true + default: + return 0, false + } +} + +func asFloat(v interface{}) (float64, bool) { + switch val := v.(type) { + case float64: + return val, true + case float32: + return float64(val), true + case int: + return float64(val), true + case int64: + return float64(val), true + default: + return 0, false + } +} diff --git a/pkg/providers/openai_compat/provider_test.go b/pkg/providers/openai_compat/provider_test.go index e5926458b..94779b39c 100644 --- a/pkg/providers/openai_compat/provider_test.go +++ b/pkg/providers/openai_compat/provider_test.go @@ -4,6 +4,7 @@ import ( "encoding/json" "net/http" "net/http/httptest" + "net/url" "testing" ) @@ -32,7 +33,7 @@ func TestProviderChat_UsesMaxCompletionTokensForGLM(t *testing.T) { })) defer server.Close() - p := NewProvider("key", server.URL) + p := NewProvider("key", server.URL, "") _, err := p.Chat(t.Context(), []Message{{Role: "user", Content: "hi"}}, nil, "glm-4.7", map[string]interface{}{"max_tokens": 1234}) if err != nil { t.Fatalf("Chat() error = %v", err) @@ -78,7 +79,7 @@ func TestProviderChat_ParsesToolCalls(t *testing.T) { })) defer server.Close() - p := NewProvider("key", server.URL) + p := NewProvider("key", server.URL, "") out, err := p.Chat(t.Context(), []Message{{Role: "user", Content: "hi"}}, nil, "gpt-4o", nil) if err != nil { t.Fatalf("Chat() error = %v", err) @@ -100,7 +101,7 @@ func TestProviderChat_HTTPError(t *testing.T) { })) defer server.Close() - p := NewProvider("key", server.URL) + p := NewProvider("key", server.URL, "") _, err := p.Chat(t.Context(), []Message{{Role: "user", Content: "hi"}}, nil, "gpt-4o", nil) if err == nil { t.Fatal("expected error, got nil") @@ -128,7 +129,7 @@ func TestProviderChat_StripsMoonshotPrefixAndNormalizesKimiTemperature(t *testin })) defer server.Close() - p := NewProvider("key", server.URL) + p := NewProvider("key", server.URL, "") _, err := p.Chat( t.Context(), []Message{{Role: "user", Content: "hi"}}, @@ -164,6 +165,11 @@ func TestProviderChat_StripsGroqAndOllamaPrefixes(t *testing.T) { input: "ollama/qwen2.5:14b", wantModel: "qwen2.5:14b", }, + { + name: "strips deepseek prefix", + input: "deepseek/deepseek-chat", + wantModel: "deepseek-chat", + }, } for _, tt := range tests { @@ -188,7 +194,7 @@ func TestProviderChat_StripsGroqAndOllamaPrefixes(t *testing.T) { })) defer server.Close() - p := NewProvider("key", server.URL) + p := NewProvider("key", server.URL, "") _, err := p.Chat(t.Context(), []Message{{Role: "user", Content: "hi"}}, nil, tt.input, nil) if err != nil { t.Fatalf("Chat() error = %v", err) @@ -200,3 +206,72 @@ func TestProviderChat_StripsGroqAndOllamaPrefixes(t *testing.T) { }) } } + +func TestProvider_ProxyConfigured(t *testing.T) { + proxyURL := "http://127.0.0.1:8080" + p := NewProvider("key", "https://example.com", proxyURL) + + transport, ok := p.httpClient.Transport.(*http.Transport) + if !ok || transport == nil { + t.Fatalf("expected http transport with proxy, got %T", p.httpClient.Transport) + } + + req := &http.Request{URL: &url.URL{Scheme: "https", Host: "api.example.com"}} + gotProxy, err := transport.Proxy(req) + if err != nil { + t.Fatalf("proxy function returned error: %v", err) + } + if gotProxy == nil || gotProxy.String() != proxyURL { + t.Fatalf("proxy = %v, want %s", gotProxy, proxyURL) + } +} + +func TestProviderChat_AcceptsNumericOptionTypes(t *testing.T) { + var requestBody map[string]interface{} + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if err := json.NewDecoder(r.Body).Decode(&requestBody); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + resp := map[string]interface{}{ + "choices": []map[string]interface{}{ + { + "message": map[string]interface{}{"content": "ok"}, + "finish_reason": "stop", + }, + }, + } + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(resp) + })) + defer server.Close() + + p := NewProvider("key", server.URL, "") + _, err := p.Chat( + t.Context(), + []Message{{Role: "user", Content: "hi"}}, + nil, + "gpt-4o", + map[string]interface{}{"max_tokens": float64(512), "temperature": 1}, + ) + if err != nil { + t.Fatalf("Chat() error = %v", err) + } + + if requestBody["max_tokens"] != float64(512) { + t.Fatalf("max_tokens = %v, want 512", requestBody["max_tokens"]) + } + if requestBody["temperature"] != float64(1) { + t.Fatalf("temperature = %v, want 1", requestBody["temperature"]) + } +} + +func TestNormalizeModel_UsesAPIBase(t *testing.T) { + if got := normalizeModel("deepseek/deepseek-chat", "https://api.deepseek.com/v1"); got != "deepseek-chat" { + t.Fatalf("normalizeModel(deepseek) = %q, want %q", got, "deepseek-chat") + } + if got := normalizeModel("openrouter/auto", "https://openrouter.ai/api/v1"); got != "openrouter/auto" { + t.Fatalf("normalizeModel(openrouter) = %q, want %q", got, "openrouter/auto") + } +} diff --git a/pkg/providers/protocoltypes/types.go b/pkg/providers/protocoltypes/types.go new file mode 100644 index 000000000..6b33ae734 --- /dev/null +++ b/pkg/providers/protocoltypes/types.go @@ -0,0 +1,45 @@ +package protocoltypes + +type ToolCall struct { + ID string `json:"id"` + Type string `json:"type,omitempty"` + Function *FunctionCall `json:"function,omitempty"` + Name string `json:"name,omitempty"` + Arguments map[string]interface{} `json:"arguments,omitempty"` +} + +type FunctionCall struct { + Name string `json:"name"` + Arguments string `json:"arguments"` +} + +type LLMResponse struct { + Content string `json:"content"` + ToolCalls []ToolCall `json:"tool_calls,omitempty"` + FinishReason string `json:"finish_reason"` + Usage *UsageInfo `json:"usage,omitempty"` +} + +type UsageInfo struct { + PromptTokens int `json:"prompt_tokens"` + CompletionTokens int `json:"completion_tokens"` + TotalTokens int `json:"total_tokens"` +} + +type Message struct { + Role string `json:"role"` + Content string `json:"content"` + ToolCalls []ToolCall `json:"tool_calls,omitempty"` + ToolCallID string `json:"tool_call_id,omitempty"` +} + +type ToolDefinition struct { + Type string `json:"type"` + Function ToolFunctionDefinition `json:"function"` +} + +type ToolFunctionDefinition struct { + Name string `json:"name"` + Description string `json:"description"` + Parameters map[string]interface{} `json:"parameters"` +} diff --git a/pkg/providers/types.go b/pkg/providers/types.go index 88b62e975..221a842fa 100644 --- a/pkg/providers/types.go +++ b/pkg/providers/types.go @@ -1,52 +1,20 @@ package providers -import "context" +import ( + "context" -type ToolCall struct { - ID string `json:"id"` - Type string `json:"type,omitempty"` - Function *FunctionCall `json:"function,omitempty"` - Name string `json:"name,omitempty"` - Arguments map[string]interface{} `json:"arguments,omitempty"` -} + "github.com/sipeed/picoclaw/pkg/providers/protocoltypes" +) -type FunctionCall struct { - Name string `json:"name"` - Arguments string `json:"arguments"` -} - -type LLMResponse struct { - Content string `json:"content"` - ToolCalls []ToolCall `json:"tool_calls,omitempty"` - FinishReason string `json:"finish_reason"` - Usage *UsageInfo `json:"usage,omitempty"` -} - -type UsageInfo struct { - PromptTokens int `json:"prompt_tokens"` - CompletionTokens int `json:"completion_tokens"` - TotalTokens int `json:"total_tokens"` -} - -type Message struct { - Role string `json:"role"` - Content string `json:"content"` - ToolCalls []ToolCall `json:"tool_calls,omitempty"` - ToolCallID string `json:"tool_call_id,omitempty"` -} +type ToolCall = protocoltypes.ToolCall +type FunctionCall = protocoltypes.FunctionCall +type LLMResponse = protocoltypes.LLMResponse +type UsageInfo = protocoltypes.UsageInfo +type Message = protocoltypes.Message +type ToolDefinition = protocoltypes.ToolDefinition +type ToolFunctionDefinition = protocoltypes.ToolFunctionDefinition type LLMProvider interface { Chat(ctx context.Context, messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) (*LLMResponse, error) GetDefaultModel() string } - -type ToolDefinition struct { - Type string `json:"type"` - Function ToolFunctionDefinition `json:"function"` -} - -type ToolFunctionDefinition struct { - Name string `json:"name"` - Description string `json:"description"` - Parameters map[string]interface{} `json:"parameters"` -}