From a6e885bb473a20d671ed1dab5e8e8ea9bb8cd399 Mon Sep 17 00:00:00 2001 From: Jared Mahotiere Date: Sun, 15 Feb 2026 08:04:07 -0500 Subject: [PATCH] refactor(providers): extract protocol factory and openai-compat transport --- pkg/providers/factory.go | 291 ++++++++++++ pkg/providers/factory_test.go | 150 ++++++ pkg/providers/http_provider.go | 473 ++++--------------- pkg/providers/openai_compat/provider.go | 230 +++++++++ pkg/providers/openai_compat/provider_test.go | 149 ++++++ 5 files changed, 905 insertions(+), 388 deletions(-) create mode 100644 pkg/providers/factory.go create mode 100644 pkg/providers/factory_test.go create mode 100644 pkg/providers/openai_compat/provider.go create mode 100644 pkg/providers/openai_compat/provider_test.go diff --git a/pkg/providers/factory.go b/pkg/providers/factory.go new file mode 100644 index 000000000..84dcd9aaa --- /dev/null +++ b/pkg/providers/factory.go @@ -0,0 +1,291 @@ +package providers + +import ( + "fmt" + "strings" + + "github.com/sipeed/picoclaw/pkg/auth" + "github.com/sipeed/picoclaw/pkg/config" +) + +type providerType int + +const ( + providerTypeHTTPCompat providerType = iota + providerTypeClaudeAuth + providerTypeCodexAuth + providerTypeClaudeCLI + providerTypeGitHubCopilot +) + +type providerSelection struct { + providerType providerType + apiKey string + apiBase string + proxy string + model string + workspace string + connectMode string +} + +func createClaudeAuthProvider() (LLMProvider, error) { + cred, err := auth.GetCredential("anthropic") + if err != nil { + return nil, fmt.Errorf("loading auth credentials: %w", err) + } + if cred == nil { + return nil, fmt.Errorf("no credentials for anthropic. Run: picoclaw auth login --provider anthropic") + } + return NewClaudeProviderWithTokenSource(cred.AccessToken, createClaudeTokenSource()), nil +} + +func createCodexAuthProvider() (LLMProvider, error) { + cred, err := auth.GetCredential("openai") + if err != nil { + return nil, fmt.Errorf("loading auth credentials: %w", err) + } + if cred == nil { + return nil, fmt.Errorf("no credentials for openai. Run: picoclaw auth login --provider openai") + } + return NewCodexProviderWithTokenSource(cred.AccessToken, cred.AccountID, createCodexTokenSource()), nil +} + +func resolveProviderSelection(cfg *config.Config) (providerSelection, error) { + model := cfg.Agents.Defaults.Model + providerName := strings.ToLower(cfg.Agents.Defaults.Provider) + lowerModel := strings.ToLower(model) + + sel := providerSelection{ + providerType: providerTypeHTTPCompat, + model: model, + } + + // First, prefer explicit provider configuration. + if providerName != "" { + switch providerName { + case "groq": + if cfg.Providers.Groq.APIKey != "" { + sel.apiKey = cfg.Providers.Groq.APIKey + sel.apiBase = cfg.Providers.Groq.APIBase + if sel.apiBase == "" { + sel.apiBase = "https://api.groq.com/openai/v1" + } + } + case "openai", "gpt": + if cfg.Providers.OpenAI.APIKey != "" || cfg.Providers.OpenAI.AuthMethod != "" { + if cfg.Providers.OpenAI.AuthMethod == "oauth" || cfg.Providers.OpenAI.AuthMethod == "token" { + sel.providerType = providerTypeCodexAuth + return sel, nil + } + sel.apiKey = cfg.Providers.OpenAI.APIKey + sel.apiBase = cfg.Providers.OpenAI.APIBase + if sel.apiBase == "" { + sel.apiBase = "https://api.openai.com/v1" + } + } + case "anthropic", "claude": + if cfg.Providers.Anthropic.APIKey != "" || cfg.Providers.Anthropic.AuthMethod != "" { + if cfg.Providers.Anthropic.AuthMethod == "oauth" || cfg.Providers.Anthropic.AuthMethod == "token" { + sel.providerType = providerTypeClaudeAuth + return sel, nil + } + sel.apiKey = cfg.Providers.Anthropic.APIKey + sel.apiBase = cfg.Providers.Anthropic.APIBase + if sel.apiBase == "" { + sel.apiBase = "https://api.anthropic.com/v1" + } + } + case "openrouter": + if cfg.Providers.OpenRouter.APIKey != "" { + sel.apiKey = cfg.Providers.OpenRouter.APIKey + if cfg.Providers.OpenRouter.APIBase != "" { + sel.apiBase = cfg.Providers.OpenRouter.APIBase + } else { + sel.apiBase = "https://openrouter.ai/api/v1" + } + } + case "zhipu", "glm": + if cfg.Providers.Zhipu.APIKey != "" { + sel.apiKey = cfg.Providers.Zhipu.APIKey + sel.apiBase = cfg.Providers.Zhipu.APIBase + if sel.apiBase == "" { + sel.apiBase = "https://open.bigmodel.cn/api/paas/v4" + } + } + case "gemini", "google": + if cfg.Providers.Gemini.APIKey != "" { + sel.apiKey = cfg.Providers.Gemini.APIKey + sel.apiBase = cfg.Providers.Gemini.APIBase + if sel.apiBase == "" { + sel.apiBase = "https://generativelanguage.googleapis.com/v1beta" + } + } + case "vllm": + if cfg.Providers.VLLM.APIBase != "" { + sel.apiKey = cfg.Providers.VLLM.APIKey + sel.apiBase = cfg.Providers.VLLM.APIBase + } + case "shengsuanyun": + if cfg.Providers.ShengSuanYun.APIKey != "" { + sel.apiKey = cfg.Providers.ShengSuanYun.APIKey + sel.apiBase = cfg.Providers.ShengSuanYun.APIBase + if sel.apiBase == "" { + sel.apiBase = "https://router.shengsuanyun.com/api/v1" + } + } + case "claude-cli", "claude-code", "claudecode": + workspace := cfg.Agents.Defaults.Workspace + if workspace == "" { + workspace = "." + } + sel.providerType = providerTypeClaudeCLI + sel.workspace = workspace + return sel, nil + case "deepseek": + if cfg.Providers.DeepSeek.APIKey != "" { + sel.apiKey = cfg.Providers.DeepSeek.APIKey + sel.apiBase = cfg.Providers.DeepSeek.APIBase + if sel.apiBase == "" { + sel.apiBase = "https://api.deepseek.com/v1" + } + if model != "deepseek-chat" && model != "deepseek-reasoner" { + sel.model = "deepseek-chat" + } + } + case "github_copilot", "copilot": + sel.providerType = providerTypeGitHubCopilot + if cfg.Providers.GitHubCopilot.APIBase != "" { + sel.apiBase = cfg.Providers.GitHubCopilot.APIBase + } else { + sel.apiBase = "localhost:4321" + } + sel.connectMode = cfg.Providers.GitHubCopilot.ConnectMode + return sel, nil + } + } + + // Fallback: infer provider from model and configured keys. + if sel.apiKey == "" && sel.apiBase == "" { + switch { + case (strings.Contains(lowerModel, "kimi") || strings.Contains(lowerModel, "moonshot") || strings.HasPrefix(model, "moonshot/")) && cfg.Providers.Moonshot.APIKey != "": + sel.apiKey = cfg.Providers.Moonshot.APIKey + sel.apiBase = cfg.Providers.Moonshot.APIBase + sel.proxy = cfg.Providers.Moonshot.Proxy + if sel.apiBase == "" { + sel.apiBase = "https://api.moonshot.cn/v1" + } + case strings.HasPrefix(model, "openrouter/") || + strings.HasPrefix(model, "anthropic/") || + strings.HasPrefix(model, "openai/") || + strings.HasPrefix(model, "meta-llama/") || + strings.HasPrefix(model, "deepseek/") || + strings.HasPrefix(model, "google/"): + sel.apiKey = cfg.Providers.OpenRouter.APIKey + sel.proxy = cfg.Providers.OpenRouter.Proxy + if cfg.Providers.OpenRouter.APIBase != "" { + sel.apiBase = cfg.Providers.OpenRouter.APIBase + } else { + sel.apiBase = "https://openrouter.ai/api/v1" + } + case (strings.Contains(lowerModel, "claude") || strings.HasPrefix(model, "anthropic/")) && + (cfg.Providers.Anthropic.APIKey != "" || cfg.Providers.Anthropic.AuthMethod != ""): + if cfg.Providers.Anthropic.AuthMethod == "oauth" || cfg.Providers.Anthropic.AuthMethod == "token" { + sel.providerType = providerTypeClaudeAuth + return sel, nil + } + sel.apiKey = cfg.Providers.Anthropic.APIKey + sel.apiBase = cfg.Providers.Anthropic.APIBase + sel.proxy = cfg.Providers.Anthropic.Proxy + if sel.apiBase == "" { + sel.apiBase = "https://api.anthropic.com/v1" + } + case (strings.Contains(lowerModel, "gpt") || strings.HasPrefix(model, "openai/")) && + (cfg.Providers.OpenAI.APIKey != "" || cfg.Providers.OpenAI.AuthMethod != ""): + if cfg.Providers.OpenAI.AuthMethod == "oauth" || cfg.Providers.OpenAI.AuthMethod == "token" { + sel.providerType = providerTypeCodexAuth + return sel, nil + } + sel.apiKey = cfg.Providers.OpenAI.APIKey + sel.apiBase = cfg.Providers.OpenAI.APIBase + sel.proxy = cfg.Providers.OpenAI.Proxy + if sel.apiBase == "" { + sel.apiBase = "https://api.openai.com/v1" + } + case (strings.Contains(lowerModel, "gemini") || strings.HasPrefix(model, "google/")) && cfg.Providers.Gemini.APIKey != "": + sel.apiKey = cfg.Providers.Gemini.APIKey + sel.apiBase = cfg.Providers.Gemini.APIBase + sel.proxy = cfg.Providers.Gemini.Proxy + if sel.apiBase == "" { + sel.apiBase = "https://generativelanguage.googleapis.com/v1beta" + } + case (strings.Contains(lowerModel, "glm") || strings.Contains(lowerModel, "zhipu") || strings.Contains(lowerModel, "zai")) && cfg.Providers.Zhipu.APIKey != "": + sel.apiKey = cfg.Providers.Zhipu.APIKey + sel.apiBase = cfg.Providers.Zhipu.APIBase + sel.proxy = cfg.Providers.Zhipu.Proxy + if sel.apiBase == "" { + sel.apiBase = "https://open.bigmodel.cn/api/paas/v4" + } + case (strings.Contains(lowerModel, "groq") || strings.HasPrefix(model, "groq/")) && cfg.Providers.Groq.APIKey != "": + sel.apiKey = cfg.Providers.Groq.APIKey + sel.apiBase = cfg.Providers.Groq.APIBase + sel.proxy = cfg.Providers.Groq.Proxy + if sel.apiBase == "" { + sel.apiBase = "https://api.groq.com/openai/v1" + } + case (strings.Contains(lowerModel, "nvidia") || strings.HasPrefix(model, "nvidia/")) && cfg.Providers.Nvidia.APIKey != "": + sel.apiKey = cfg.Providers.Nvidia.APIKey + sel.apiBase = cfg.Providers.Nvidia.APIBase + sel.proxy = cfg.Providers.Nvidia.Proxy + if sel.apiBase == "" { + sel.apiBase = "https://integrate.api.nvidia.com/v1" + } + case cfg.Providers.VLLM.APIBase != "": + sel.apiKey = cfg.Providers.VLLM.APIKey + sel.apiBase = cfg.Providers.VLLM.APIBase + sel.proxy = cfg.Providers.VLLM.Proxy + default: + if cfg.Providers.OpenRouter.APIKey != "" { + sel.apiKey = cfg.Providers.OpenRouter.APIKey + sel.proxy = cfg.Providers.OpenRouter.Proxy + if cfg.Providers.OpenRouter.APIBase != "" { + sel.apiBase = cfg.Providers.OpenRouter.APIBase + } else { + sel.apiBase = "https://openrouter.ai/api/v1" + } + } else { + return providerSelection{}, fmt.Errorf("no API key configured for model: %s", model) + } + } + } + + if sel.providerType == providerTypeHTTPCompat { + if sel.apiKey == "" && !strings.HasPrefix(model, "bedrock/") { + return providerSelection{}, fmt.Errorf("no API key configured for provider (model: %s)", model) + } + if sel.apiBase == "" { + return providerSelection{}, fmt.Errorf("no API base configured for provider (model: %s)", model) + } + } + + return sel, nil +} + +func CreateProvider(cfg *config.Config) (LLMProvider, error) { + sel, err := resolveProviderSelection(cfg) + if err != nil { + return nil, err + } + + switch sel.providerType { + case providerTypeClaudeAuth: + return createClaudeAuthProvider() + case providerTypeCodexAuth: + return createCodexAuthProvider() + case providerTypeClaudeCLI: + return NewClaudeCliProvider(sel.workspace), nil + case providerTypeGitHubCopilot: + return NewGitHubCopilotProvider(sel.apiBase, sel.connectMode, sel.model) + default: + return NewHTTPProvider(sel.apiKey, sel.apiBase, sel.proxy), nil + } +} diff --git a/pkg/providers/factory_test.go b/pkg/providers/factory_test.go new file mode 100644 index 000000000..f894b292a --- /dev/null +++ b/pkg/providers/factory_test.go @@ -0,0 +1,150 @@ +package providers + +import ( + "strings" + "testing" + + "github.com/sipeed/picoclaw/pkg/config" +) + +func TestResolveProviderSelection(t *testing.T) { + tests := []struct { + name string + setup func(*config.Config) + wantType providerType + wantAPIBase string + wantProxy string + wantErrSubstr string + }{ + { + name: "explicit claude-cli provider routes to cli provider type", + setup: func(cfg *config.Config) { + cfg.Agents.Defaults.Provider = "claude-cli" + cfg.Agents.Defaults.Workspace = "/tmp/ws" + }, + wantType: providerTypeClaudeCLI, + }, + { + name: "explicit copilot provider routes to github copilot type", + setup: func(cfg *config.Config) { + cfg.Agents.Defaults.Provider = "copilot" + }, + wantType: providerTypeGitHubCopilot, + wantAPIBase: "localhost:4321", + }, + { + name: "openrouter model uses openrouter defaults", + setup: func(cfg *config.Config) { + cfg.Agents.Defaults.Model = "openrouter/auto" + cfg.Providers.OpenRouter.APIKey = "sk-or-test" + }, + wantType: providerTypeHTTPCompat, + wantAPIBase: "https://openrouter.ai/api/v1", + }, + { + name: "anthropic oauth routes to claude auth provider", + setup: func(cfg *config.Config) { + cfg.Agents.Defaults.Model = "claude-sonnet-4-5-20250929" + cfg.Providers.Anthropic.AuthMethod = "oauth" + }, + wantType: providerTypeClaudeAuth, + }, + { + name: "openai oauth routes to codex auth provider", + setup: func(cfg *config.Config) { + cfg.Agents.Defaults.Model = "gpt-4o" + cfg.Providers.OpenAI.AuthMethod = "oauth" + }, + wantType: providerTypeCodexAuth, + }, + { + name: "zhipu model uses zhipu base default", + setup: func(cfg *config.Config) { + cfg.Agents.Defaults.Model = "glm-4.7" + cfg.Providers.Zhipu.APIKey = "zhipu-key" + }, + wantType: providerTypeHTTPCompat, + wantAPIBase: "https://open.bigmodel.cn/api/paas/v4", + }, + { + name: "groq model uses groq base default", + setup: func(cfg *config.Config) { + cfg.Agents.Defaults.Model = "groq/llama-3.3-70b" + cfg.Providers.Groq.APIKey = "gsk-key" + }, + wantType: providerTypeHTTPCompat, + wantAPIBase: "https://api.groq.com/openai/v1", + }, + { + name: "moonshot model keeps proxy and default base", + setup: func(cfg *config.Config) { + cfg.Agents.Defaults.Model = "moonshot/kimi-k2.5" + cfg.Providers.Moonshot.APIKey = "moonshot-key" + cfg.Providers.Moonshot.Proxy = "http://127.0.0.1:7890" + }, + wantType: providerTypeHTTPCompat, + wantAPIBase: "https://api.moonshot.cn/v1", + wantProxy: "http://127.0.0.1:7890", + }, + { + name: "missing keys returns model config error", + setup: func(cfg *config.Config) { + cfg.Agents.Defaults.Model = "custom-model" + }, + wantErrSubstr: "no API key configured for model", + }, + { + name: "openrouter prefix without key returns provider key error", + setup: func(cfg *config.Config) { + cfg.Agents.Defaults.Model = "openrouter/auto" + }, + wantErrSubstr: "no API key configured for provider", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cfg := config.DefaultConfig() + tt.setup(cfg) + + got, err := resolveProviderSelection(cfg) + if tt.wantErrSubstr != "" { + if err == nil { + t.Fatalf("expected error containing %q, got nil", tt.wantErrSubstr) + } + if !strings.Contains(err.Error(), tt.wantErrSubstr) { + t.Fatalf("error = %q, want substring %q", err.Error(), tt.wantErrSubstr) + } + return + } + + if err != nil { + t.Fatalf("resolveProviderSelection() error = %v", err) + } + if got.providerType != tt.wantType { + t.Fatalf("providerType = %v, want %v", got.providerType, tt.wantType) + } + if tt.wantAPIBase != "" && got.apiBase != tt.wantAPIBase { + t.Fatalf("apiBase = %q, want %q", got.apiBase, tt.wantAPIBase) + } + if tt.wantProxy != "" && got.proxy != tt.wantProxy { + t.Fatalf("proxy = %q, want %q", got.proxy, tt.wantProxy) + } + }) + } +} + +func TestCreateProviderReturnsHTTPProviderForOpenRouter(t *testing.T) { + cfg := config.DefaultConfig() + cfg.Agents.Defaults.Model = "openrouter/auto" + cfg.Providers.OpenRouter.APIKey = "sk-or-test" + + provider, err := CreateProvider(cfg) + if err != nil { + t.Fatalf("CreateProvider() error = %v", err) + } + + if _, ok := provider.(*HTTPProvider); !ok { + t.Fatalf("provider type = %T, want *HTTPProvider", provider) + } +} diff --git a/pkg/providers/http_provider.go b/pkg/providers/http_provider.go index 17eb6214c..0f7f646d8 100644 --- a/pkg/providers/http_provider.go +++ b/pkg/providers/http_provider.go @@ -7,427 +7,124 @@ package providers import ( - "bytes" "context" - "encoding/json" - "fmt" - "io" - "net/http" - "net/url" - "strings" - "time" - - "github.com/sipeed/picoclaw/pkg/auth" - "github.com/sipeed/picoclaw/pkg/config" + "github.com/sipeed/picoclaw/pkg/providers/openai_compat" ) type HTTPProvider struct { - apiKey string - apiBase string - httpClient *http.Client + delegate *openai_compat.Provider } -func NewHTTPProvider(apiKey, apiBase, proxy string) *HTTPProvider { - client := &http.Client{ - Timeout: 120 * time.Second, +func NewHTTPProvider(apiKey, apiBase string, proxy ...string) *HTTPProvider { + proxyURL := "" + if len(proxy) > 0 { + proxyURL = proxy[0] } - - if proxy != "" { - proxyURL, err := url.Parse(proxy) - if err == nil { - client.Transport = &http.Transport{ - Proxy: http.ProxyURL(proxyURL), - } - } - } - return &HTTPProvider{ - apiKey: apiKey, - apiBase: strings.TrimRight(apiBase, "/"), - httpClient: client, + delegate: openai_compat.NewProvider(apiKey, apiBase, proxyURL), } } func (p *HTTPProvider) Chat(ctx context.Context, messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) (*LLMResponse, error) { - if p.apiBase == "" { - return nil, fmt.Errorf("API base not configured") - } - - // Strip provider prefix from model name (e.g., moonshot/kimi-k2.5 -> kimi-k2.5) - if idx := strings.Index(model, "/"); idx != -1 { - prefix := model[:idx] - if prefix == "moonshot" || prefix == "nvidia" { - model = model[idx+1:] - } - } - - requestBody := map[string]interface{}{ - "model": model, - "messages": messages, - } - - if len(tools) > 0 { - requestBody["tools"] = tools - requestBody["tool_choice"] = "auto" - } - - if maxTokens, ok := options["max_tokens"].(int); ok { - lowerModel := strings.ToLower(model) - if strings.Contains(lowerModel, "glm") || strings.Contains(lowerModel, "o1") { - requestBody["max_completion_tokens"] = maxTokens - } else { - requestBody["max_tokens"] = maxTokens - } - } - - if temperature, ok := options["temperature"].(float64); ok { - lowerModel := strings.ToLower(model) - // Kimi k2 models only support temperature=1 - if strings.Contains(lowerModel, "kimi") && strings.Contains(lowerModel, "k2") { - requestBody["temperature"] = 1.0 - } else { - requestBody["temperature"] = temperature - } - } - - jsonData, err := json.Marshal(requestBody) + compatResp, err := p.delegate.Chat(ctx, toOpenAICompatMessages(messages), toOpenAICompatTools(tools), model, options) if err != nil { - return nil, fmt.Errorf("failed to marshal request: %w", err) + return nil, err } - - req, err := http.NewRequestWithContext(ctx, "POST", p.apiBase+"/chat/completions", bytes.NewReader(jsonData)) - if err != nil { - return nil, fmt.Errorf("failed to create request: %w", err) - } - - req.Header.Set("Content-Type", "application/json") - if p.apiKey != "" { - req.Header.Set("Authorization", "Bearer "+p.apiKey) - } - - resp, err := p.httpClient.Do(req) - if err != nil { - return nil, fmt.Errorf("failed to send request: %w", err) - } - defer resp.Body.Close() - - body, err := io.ReadAll(resp.Body) - if err != nil { - return nil, fmt.Errorf("failed to read response: %w", err) - } - - if resp.StatusCode != http.StatusOK { - return nil, fmt.Errorf("API request failed:\n Status: %d\n Body: %s", resp.StatusCode, string(body)) - } - - return p.parseResponse(body) -} - -func (p *HTTPProvider) parseResponse(body []byte) (*LLMResponse, error) { - var apiResponse struct { - Choices []struct { - Message struct { - Content string `json:"content"` - ToolCalls []struct { - ID string `json:"id"` - Type string `json:"type"` - Function *struct { - Name string `json:"name"` - Arguments string `json:"arguments"` - } `json:"function"` - } `json:"tool_calls"` - } `json:"message"` - FinishReason string `json:"finish_reason"` - } `json:"choices"` - Usage *UsageInfo `json:"usage"` - } - - if err := json.Unmarshal(body, &apiResponse); err != nil { - return nil, fmt.Errorf("failed to unmarshal response: %w", err) - } - - if len(apiResponse.Choices) == 0 { - return &LLMResponse{ - Content: "", - FinishReason: "stop", - }, nil - } - - choice := apiResponse.Choices[0] - - toolCalls := make([]ToolCall, 0, len(choice.Message.ToolCalls)) - for _, tc := range choice.Message.ToolCalls { - arguments := make(map[string]interface{}) - name := "" - - // Handle OpenAI format with nested function object - if tc.Type == "function" && tc.Function != nil { - name = tc.Function.Name - if tc.Function.Arguments != "" { - if err := json.Unmarshal([]byte(tc.Function.Arguments), &arguments); err != nil { - arguments["raw"] = tc.Function.Arguments - } - } - } else if tc.Function != nil { - // Legacy format without type field - name = tc.Function.Name - if tc.Function.Arguments != "" { - if err := json.Unmarshal([]byte(tc.Function.Arguments), &arguments); err != nil { - arguments["raw"] = tc.Function.Arguments - } - } - } - - toolCalls = append(toolCalls, ToolCall{ - ID: tc.ID, - Name: name, - Arguments: arguments, - }) - } - - return &LLMResponse{ - Content: choice.Message.Content, - ToolCalls: toolCalls, - FinishReason: choice.FinishReason, - Usage: apiResponse.Usage, - }, nil + return fromOpenAICompatResponse(compatResp), nil } func (p *HTTPProvider) GetDefaultModel() string { return "" } -func createClaudeAuthProvider() (LLMProvider, error) { - cred, err := auth.GetCredential("anthropic") - if err != nil { - return nil, fmt.Errorf("loading auth credentials: %w", err) +func toOpenAICompatMessages(messages []Message) []openai_compat.Message { + out := make([]openai_compat.Message, 0, len(messages)) + for _, msg := range messages { + out = append(out, openai_compat.Message{ + Role: msg.Role, + Content: msg.Content, + ToolCalls: toOpenAICompatToolCalls(msg.ToolCalls), + ToolCallID: msg.ToolCallID, + }) } - if cred == nil { - return nil, fmt.Errorf("no credentials for anthropic. Run: picoclaw auth login --provider anthropic") - } - return NewClaudeProviderWithTokenSource(cred.AccessToken, createClaudeTokenSource()), nil + return out } -func createCodexAuthProvider() (LLMProvider, error) { - cred, err := auth.GetCredential("openai") - if err != nil { - return nil, fmt.Errorf("loading auth credentials: %w", err) +func toOpenAICompatTools(tools []ToolDefinition) []openai_compat.ToolDefinition { + out := make([]openai_compat.ToolDefinition, 0, len(tools)) + for _, t := range tools { + out = append(out, openai_compat.ToolDefinition{ + Type: t.Type, + Function: openai_compat.ToolFunctionDefinition{ + Name: t.Function.Name, + Description: t.Function.Description, + Parameters: t.Function.Parameters, + }, + }) } - if cred == nil { - return nil, fmt.Errorf("no credentials for openai. Run: picoclaw auth login --provider openai") - } - return NewCodexProviderWithTokenSource(cred.AccessToken, cred.AccountID, createCodexTokenSource()), nil + return out } -func CreateProvider(cfg *config.Config) (LLMProvider, error) { - model := cfg.Agents.Defaults.Model - providerName := strings.ToLower(cfg.Agents.Defaults.Provider) - - var apiKey, apiBase, proxy string - - lowerModel := strings.ToLower(model) - - // First, try to use explicitly configured provider - if providerName != "" { - switch providerName { - case "groq": - if cfg.Providers.Groq.APIKey != "" { - apiKey = cfg.Providers.Groq.APIKey - apiBase = cfg.Providers.Groq.APIBase - if apiBase == "" { - apiBase = "https://api.groq.com/openai/v1" - } +func toOpenAICompatToolCalls(toolCalls []ToolCall) []openai_compat.ToolCall { + out := make([]openai_compat.ToolCall, 0, len(toolCalls)) + for _, tc := range toolCalls { + var fn *openai_compat.FunctionCall + if tc.Function != nil { + fn = &openai_compat.FunctionCall{ + Name: tc.Function.Name, + Arguments: tc.Function.Arguments, } - case "openai", "gpt": - if cfg.Providers.OpenAI.APIKey != "" || cfg.Providers.OpenAI.AuthMethod != "" { - if cfg.Providers.OpenAI.AuthMethod == "oauth" || cfg.Providers.OpenAI.AuthMethod == "token" { - return createCodexAuthProvider() - } - apiKey = cfg.Providers.OpenAI.APIKey - apiBase = cfg.Providers.OpenAI.APIBase - if apiBase == "" { - apiBase = "https://api.openai.com/v1" - } - } - case "anthropic", "claude": - if cfg.Providers.Anthropic.APIKey != "" || cfg.Providers.Anthropic.AuthMethod != "" { - if cfg.Providers.Anthropic.AuthMethod == "oauth" || cfg.Providers.Anthropic.AuthMethod == "token" { - return createClaudeAuthProvider() - } - apiKey = cfg.Providers.Anthropic.APIKey - apiBase = cfg.Providers.Anthropic.APIBase - if apiBase == "" { - apiBase = "https://api.anthropic.com/v1" - } - } - case "openrouter": - if cfg.Providers.OpenRouter.APIKey != "" { - apiKey = cfg.Providers.OpenRouter.APIKey - if cfg.Providers.OpenRouter.APIBase != "" { - apiBase = cfg.Providers.OpenRouter.APIBase - } else { - apiBase = "https://openrouter.ai/api/v1" - } - } - case "zhipu", "glm": - if cfg.Providers.Zhipu.APIKey != "" { - apiKey = cfg.Providers.Zhipu.APIKey - apiBase = cfg.Providers.Zhipu.APIBase - if apiBase == "" { - apiBase = "https://open.bigmodel.cn/api/paas/v4" - } - } - case "gemini", "google": - if cfg.Providers.Gemini.APIKey != "" { - apiKey = cfg.Providers.Gemini.APIKey - apiBase = cfg.Providers.Gemini.APIBase - if apiBase == "" { - apiBase = "https://generativelanguage.googleapis.com/v1beta" - } - } - case "vllm": - if cfg.Providers.VLLM.APIBase != "" { - apiKey = cfg.Providers.VLLM.APIKey - apiBase = cfg.Providers.VLLM.APIBase - } - case "shengsuanyun": - if cfg.Providers.ShengSuanYun.APIKey != "" { - apiKey = cfg.Providers.ShengSuanYun.APIKey - apiBase = cfg.Providers.ShengSuanYun.APIBase - if apiBase == "" { - apiBase = "https://router.shengsuanyun.com/api/v1" - } - } - case "claude-cli", "claudecode", "claude-code": - workspace := cfg.Agents.Defaults.Workspace - if workspace == "" { - workspace = "." - } - return NewClaudeCliProvider(workspace), nil - case "deepseek": - if cfg.Providers.DeepSeek.APIKey != "" { - apiKey = cfg.Providers.DeepSeek.APIKey - apiBase = cfg.Providers.DeepSeek.APIBase - if apiBase == "" { - apiBase = "https://api.deepseek.com/v1" - } - if model != "deepseek-chat" && model != "deepseek-reasoner" { - model = "deepseek-chat" - } - } - case "github_copilot", "copilot": - if cfg.Providers.GitHubCopilot.APIBase != "" { - apiBase = cfg.Providers.GitHubCopilot.APIBase - } else { - apiBase = "localhost:4321" - } - return NewGitHubCopilotProvider(apiBase, cfg.Providers.GitHubCopilot.ConnectMode, model) - } + out = append(out, openai_compat.ToolCall{ + ID: tc.ID, + Type: tc.Type, + Function: fn, + Name: tc.Name, + Arguments: tc.Arguments, + }) + } + return out +} +func fromOpenAICompatResponse(resp *openai_compat.LLMResponse) *LLMResponse { + if resp == nil { + return &LLMResponse{} } - // Fallback: detect provider from model name - if apiKey == "" && apiBase == "" { - switch { - case (strings.Contains(lowerModel, "kimi") || strings.Contains(lowerModel, "moonshot") || strings.HasPrefix(model, "moonshot/")) && cfg.Providers.Moonshot.APIKey != "": - apiKey = cfg.Providers.Moonshot.APIKey - apiBase = cfg.Providers.Moonshot.APIBase - proxy = cfg.Providers.Moonshot.Proxy - if apiBase == "" { - apiBase = "https://api.moonshot.cn/v1" - } - - case strings.HasPrefix(model, "openrouter/") || strings.HasPrefix(model, "anthropic/") || strings.HasPrefix(model, "openai/") || strings.HasPrefix(model, "meta-llama/") || strings.HasPrefix(model, "deepseek/") || strings.HasPrefix(model, "google/"): - apiKey = cfg.Providers.OpenRouter.APIKey - proxy = cfg.Providers.OpenRouter.Proxy - if cfg.Providers.OpenRouter.APIBase != "" { - apiBase = cfg.Providers.OpenRouter.APIBase - } else { - apiBase = "https://openrouter.ai/api/v1" - } - - case (strings.Contains(lowerModel, "claude") || strings.HasPrefix(model, "anthropic/")) && (cfg.Providers.Anthropic.APIKey != "" || cfg.Providers.Anthropic.AuthMethod != ""): - if cfg.Providers.Anthropic.AuthMethod == "oauth" || cfg.Providers.Anthropic.AuthMethod == "token" { - return createClaudeAuthProvider() - } - apiKey = cfg.Providers.Anthropic.APIKey - apiBase = cfg.Providers.Anthropic.APIBase - proxy = cfg.Providers.Anthropic.Proxy - if apiBase == "" { - apiBase = "https://api.anthropic.com/v1" - } - - case (strings.Contains(lowerModel, "gpt") || strings.HasPrefix(model, "openai/")) && (cfg.Providers.OpenAI.APIKey != "" || cfg.Providers.OpenAI.AuthMethod != ""): - if cfg.Providers.OpenAI.AuthMethod == "oauth" || cfg.Providers.OpenAI.AuthMethod == "token" { - return createCodexAuthProvider() - } - apiKey = cfg.Providers.OpenAI.APIKey - apiBase = cfg.Providers.OpenAI.APIBase - proxy = cfg.Providers.OpenAI.Proxy - if apiBase == "" { - apiBase = "https://api.openai.com/v1" - } - - case (strings.Contains(lowerModel, "gemini") || strings.HasPrefix(model, "google/")) && cfg.Providers.Gemini.APIKey != "": - apiKey = cfg.Providers.Gemini.APIKey - apiBase = cfg.Providers.Gemini.APIBase - proxy = cfg.Providers.Gemini.Proxy - if apiBase == "" { - apiBase = "https://generativelanguage.googleapis.com/v1beta" - } - - case (strings.Contains(lowerModel, "glm") || strings.Contains(lowerModel, "zhipu") || strings.Contains(lowerModel, "zai")) && cfg.Providers.Zhipu.APIKey != "": - apiKey = cfg.Providers.Zhipu.APIKey - apiBase = cfg.Providers.Zhipu.APIBase - proxy = cfg.Providers.Zhipu.Proxy - if apiBase == "" { - apiBase = "https://open.bigmodel.cn/api/paas/v4" - } - - case (strings.Contains(lowerModel, "groq") || strings.HasPrefix(model, "groq/")) && cfg.Providers.Groq.APIKey != "": - apiKey = cfg.Providers.Groq.APIKey - apiBase = cfg.Providers.Groq.APIBase - proxy = cfg.Providers.Groq.Proxy - if apiBase == "" { - apiBase = "https://api.groq.com/openai/v1" - } - - case (strings.Contains(lowerModel, "nvidia") || strings.HasPrefix(model, "nvidia/")) && cfg.Providers.Nvidia.APIKey != "": - apiKey = cfg.Providers.Nvidia.APIKey - apiBase = cfg.Providers.Nvidia.APIBase - proxy = cfg.Providers.Nvidia.Proxy - if apiBase == "" { - apiBase = "https://integrate.api.nvidia.com/v1" - } - - case cfg.Providers.VLLM.APIBase != "": - apiKey = cfg.Providers.VLLM.APIKey - apiBase = cfg.Providers.VLLM.APIBase - proxy = cfg.Providers.VLLM.Proxy - - default: - if cfg.Providers.OpenRouter.APIKey != "" { - apiKey = cfg.Providers.OpenRouter.APIKey - proxy = cfg.Providers.OpenRouter.Proxy - if cfg.Providers.OpenRouter.APIBase != "" { - apiBase = cfg.Providers.OpenRouter.APIBase - } else { - apiBase = "https://openrouter.ai/api/v1" - } - } else { - return nil, fmt.Errorf("no API key configured for model: %s", model) - } + var usage *UsageInfo + if resp.Usage != nil { + usage = &UsageInfo{ + PromptTokens: resp.Usage.PromptTokens, + CompletionTokens: resp.Usage.CompletionTokens, + TotalTokens: resp.Usage.TotalTokens, } } - if apiKey == "" && !strings.HasPrefix(model, "bedrock/") { - return nil, fmt.Errorf("no API key configured for provider (model: %s)", model) + return &LLMResponse{ + Content: resp.Content, + ToolCalls: fromOpenAICompatToolCalls(resp.ToolCalls), + FinishReason: resp.FinishReason, + Usage: usage, } - - if apiBase == "" { - return nil, fmt.Errorf("no API base configured for provider (model: %s)", model) - } - - return NewHTTPProvider(apiKey, apiBase, proxy), nil +} + +func fromOpenAICompatToolCalls(toolCalls []openai_compat.ToolCall) []ToolCall { + out := make([]ToolCall, 0, len(toolCalls)) + for _, tc := range toolCalls { + var fn *FunctionCall + if tc.Function != nil { + fn = &FunctionCall{ + Name: tc.Function.Name, + Arguments: tc.Function.Arguments, + } + } + out = append(out, ToolCall{ + ID: tc.ID, + Type: tc.Type, + Function: fn, + Name: tc.Name, + Arguments: tc.Arguments, + }) + } + return out } diff --git a/pkg/providers/openai_compat/provider.go b/pkg/providers/openai_compat/provider.go new file mode 100644 index 000000000..4aef1389a --- /dev/null +++ b/pkg/providers/openai_compat/provider.go @@ -0,0 +1,230 @@ +package openai_compat + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "net/url" + "strings" + "time" +) + +type ToolCall struct { + ID string `json:"id"` + Type string `json:"type,omitempty"` + Function *FunctionCall `json:"function,omitempty"` + Name string `json:"name,omitempty"` + Arguments map[string]interface{} `json:"arguments,omitempty"` +} + +type FunctionCall struct { + Name string `json:"name"` + Arguments string `json:"arguments"` +} + +type LLMResponse struct { + Content string `json:"content"` + ToolCalls []ToolCall `json:"tool_calls,omitempty"` + FinishReason string `json:"finish_reason"` + Usage *UsageInfo `json:"usage,omitempty"` +} + +type UsageInfo struct { + PromptTokens int `json:"prompt_tokens"` + CompletionTokens int `json:"completion_tokens"` + TotalTokens int `json:"total_tokens"` +} + +type Message struct { + Role string `json:"role"` + Content string `json:"content"` + ToolCalls []ToolCall `json:"tool_calls,omitempty"` + ToolCallID string `json:"tool_call_id,omitempty"` +} + +type ToolDefinition struct { + Type string `json:"type"` + Function ToolFunctionDefinition `json:"function"` +} + +type ToolFunctionDefinition struct { + Name string `json:"name"` + Description string `json:"description"` + Parameters map[string]interface{} `json:"parameters"` +} + +type Provider struct { + apiKey string + apiBase string + httpClient *http.Client +} + +func NewProvider(apiKey, apiBase string, proxy ...string) *Provider { + proxyURL := "" + if len(proxy) > 0 { + proxyURL = proxy[0] + } + client := &http.Client{ + Timeout: 120 * time.Second, + } + + if proxyURL != "" { + parsed, err := url.Parse(proxyURL) + if err == nil { + client.Transport = &http.Transport{ + Proxy: http.ProxyURL(parsed), + } + } + } + + return &Provider{ + apiKey: apiKey, + apiBase: strings.TrimRight(apiBase, "/"), + httpClient: client, + } +} + +func (p *Provider) Chat(ctx context.Context, messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) (*LLMResponse, error) { + if p.apiBase == "" { + return nil, fmt.Errorf("API base not configured") + } + + // Strip provider prefix (moonshot/kimi-*, nvidia/*) for OpenAI-compatible backends. + if idx := strings.Index(model, "/"); idx != -1 { + prefix := model[:idx] + if prefix == "moonshot" || prefix == "nvidia" { + model = model[idx+1:] + } + } + + requestBody := map[string]interface{}{ + "model": model, + "messages": messages, + } + + if len(tools) > 0 { + requestBody["tools"] = tools + requestBody["tool_choice"] = "auto" + } + + if maxTokens, ok := options["max_tokens"].(int); ok { + lowerModel := strings.ToLower(model) + if strings.Contains(lowerModel, "glm") || strings.Contains(lowerModel, "o1") { + requestBody["max_completion_tokens"] = maxTokens + } else { + requestBody["max_tokens"] = maxTokens + } + } + + if temperature, ok := options["temperature"].(float64); ok { + lowerModel := strings.ToLower(model) + // Kimi k2 models only support temperature=1. + if strings.Contains(lowerModel, "kimi") && strings.Contains(lowerModel, "k2") { + requestBody["temperature"] = 1.0 + } else { + requestBody["temperature"] = temperature + } + } + + jsonData, err := json.Marshal(requestBody) + if err != nil { + return nil, fmt.Errorf("failed to marshal request: %w", err) + } + + req, err := http.NewRequestWithContext(ctx, "POST", p.apiBase+"/chat/completions", bytes.NewReader(jsonData)) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + + req.Header.Set("Content-Type", "application/json") + if p.apiKey != "" { + req.Header.Set("Authorization", "Bearer "+p.apiKey) + } + + resp, err := p.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to send request: %w", err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response: %w", err) + } + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("API request failed:\n Status: %d\n Body: %s", resp.StatusCode, string(body)) + } + + return parseResponse(body) +} + +func parseResponse(body []byte) (*LLMResponse, error) { + var apiResponse struct { + Choices []struct { + Message struct { + Content string `json:"content"` + ToolCalls []struct { + ID string `json:"id"` + Type string `json:"type"` + Function *struct { + Name string `json:"name"` + Arguments string `json:"arguments"` + } `json:"function"` + } `json:"tool_calls"` + } `json:"message"` + FinishReason string `json:"finish_reason"` + } `json:"choices"` + Usage *UsageInfo `json:"usage"` + } + + if err := json.Unmarshal(body, &apiResponse); err != nil { + return nil, fmt.Errorf("failed to unmarshal response: %w", err) + } + + if len(apiResponse.Choices) == 0 { + return &LLMResponse{ + Content: "", + FinishReason: "stop", + }, nil + } + + choice := apiResponse.Choices[0] + toolCalls := make([]ToolCall, 0, len(choice.Message.ToolCalls)) + for _, tc := range choice.Message.ToolCalls { + arguments := make(map[string]interface{}) + name := "" + + if tc.Type == "function" && tc.Function != nil { + name = tc.Function.Name + if tc.Function.Arguments != "" { + if err := json.Unmarshal([]byte(tc.Function.Arguments), &arguments); err != nil { + arguments["raw"] = tc.Function.Arguments + } + } + } else if tc.Function != nil { + name = tc.Function.Name + if tc.Function.Arguments != "" { + if err := json.Unmarshal([]byte(tc.Function.Arguments), &arguments); err != nil { + arguments["raw"] = tc.Function.Arguments + } + } + } + + toolCalls = append(toolCalls, ToolCall{ + ID: tc.ID, + Name: name, + Arguments: arguments, + }) + } + + return &LLMResponse{ + Content: choice.Message.Content, + ToolCalls: toolCalls, + FinishReason: choice.FinishReason, + Usage: apiResponse.Usage, + }, nil +} diff --git a/pkg/providers/openai_compat/provider_test.go b/pkg/providers/openai_compat/provider_test.go new file mode 100644 index 000000000..7c5f1c63c --- /dev/null +++ b/pkg/providers/openai_compat/provider_test.go @@ -0,0 +1,149 @@ +package openai_compat + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "testing" +) + +func TestProviderChat_UsesMaxCompletionTokensForGLM(t *testing.T) { + var requestBody map[string]interface{} + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/chat/completions" { + http.Error(w, "not found", http.StatusNotFound) + return + } + if err := json.NewDecoder(r.Body).Decode(&requestBody); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + resp := map[string]interface{}{ + "choices": []map[string]interface{}{ + { + "message": map[string]interface{}{"content": "ok"}, + "finish_reason": "stop", + }, + }, + } + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(resp) + })) + defer server.Close() + + p := NewProvider("key", server.URL) + _, err := p.Chat(t.Context(), []Message{{Role: "user", Content: "hi"}}, nil, "glm-4.7", map[string]interface{}{"max_tokens": 1234}) + if err != nil { + t.Fatalf("Chat() error = %v", err) + } + + if _, ok := requestBody["max_completion_tokens"]; !ok { + t.Fatalf("expected max_completion_tokens in request body") + } + if _, ok := requestBody["max_tokens"]; ok { + t.Fatalf("did not expect max_tokens key for glm model") + } +} + +func TestProviderChat_ParsesToolCalls(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + resp := map[string]interface{}{ + "choices": []map[string]interface{}{ + { + "message": map[string]interface{}{ + "content": "", + "tool_calls": []map[string]interface{}{ + { + "id": "call_1", + "type": "function", + "function": map[string]interface{}{ + "name": "get_weather", + "arguments": "{\"city\":\"SF\"}", + }, + }, + }, + }, + "finish_reason": "tool_calls", + }, + }, + "usage": map[string]interface{}{ + "prompt_tokens": 10, + "completion_tokens": 5, + "total_tokens": 15, + }, + } + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(resp) + })) + defer server.Close() + + p := NewProvider("key", server.URL) + out, err := p.Chat(t.Context(), []Message{{Role: "user", Content: "hi"}}, nil, "gpt-4o", nil) + if err != nil { + t.Fatalf("Chat() error = %v", err) + } + if len(out.ToolCalls) != 1 { + t.Fatalf("len(ToolCalls) = %d, want 1", len(out.ToolCalls)) + } + if out.ToolCalls[0].Name != "get_weather" { + t.Fatalf("ToolCalls[0].Name = %q, want %q", out.ToolCalls[0].Name, "get_weather") + } + if out.ToolCalls[0].Arguments["city"] != "SF" { + t.Fatalf("ToolCalls[0].Arguments[city] = %v, want SF", out.ToolCalls[0].Arguments["city"]) + } +} + +func TestProviderChat_HTTPError(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + http.Error(w, "bad request", http.StatusBadRequest) + })) + defer server.Close() + + p := NewProvider("key", server.URL) + _, err := p.Chat(t.Context(), []Message{{Role: "user", Content: "hi"}}, nil, "gpt-4o", nil) + if err == nil { + t.Fatal("expected error, got nil") + } +} + +func TestProviderChat_StripsMoonshotPrefixAndNormalizesKimiTemperature(t *testing.T) { + var requestBody map[string]interface{} + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if err := json.NewDecoder(r.Body).Decode(&requestBody); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + resp := map[string]interface{}{ + "choices": []map[string]interface{}{ + { + "message": map[string]interface{}{"content": "ok"}, + "finish_reason": "stop", + }, + }, + } + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(resp) + })) + defer server.Close() + + p := NewProvider("key", server.URL) + _, err := p.Chat( + t.Context(), + []Message{{Role: "user", Content: "hi"}}, + nil, + "moonshot/kimi-k2.5", + map[string]interface{}{"temperature": 0.3}, + ) + if err != nil { + t.Fatalf("Chat() error = %v", err) + } + + if requestBody["model"] != "kimi-k2.5" { + t.Fatalf("model = %v, want kimi-k2.5", requestBody["model"]) + } + if requestBody["temperature"] != 1.0 { + t.Fatalf("temperature = %v, want 1.0", requestBody["temperature"]) + } +}