From e1583f3b1379061f7913b5b7e093cc0f672d2063 Mon Sep 17 00:00:00 2001 From: yinwm Date: Thu, 19 Feb 2026 01:30:19 +0800 Subject: [PATCH] refactor: simplify legacy_provider.go from 349 to 49 lines - Move OAuth helper functions to factory_provider.go - Add auto-migration in LoadConfig: old providers -> model_list - Add Workspace field to ModelConfig for CLI-based providers - Fix OAuth handling to use auth store instead of raw APIKey - Update tests to use new model_list configuration format This eliminates the giant switch-case in legacy_provider.go, achieving the goal of "zero-code provider addition" from the design document (issue #283). Co-Authored-By: Claude Opus 4.6 --- pkg/config/config.go | 6 + pkg/providers/claude_cli_provider_test.go | 21 +- pkg/providers/factory_provider.go | 48 ++- pkg/providers/legacy_provider.go | 340 ++-------------------- 4 files changed, 85 insertions(+), 330 deletions(-) diff --git a/pkg/config/config.go b/pkg/config/config.go index 1b6f7b76c..c2b5ee01f 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -215,6 +215,7 @@ type ModelConfig struct { // Special providers (CLI-based, OAuth, etc.) AuthMethod string `json:"auth_method,omitempty"` // Authentication method: oauth, token ConnectMode string `json:"connect_mode,omitempty"` // Connection mode: stdio, grpc + Workspace string `json:"workspace,omitempty"` // Workspace path for CLI-based providers // Optional optimizations RPM int `json:"rpm,omitempty"` // Requests per minute limit @@ -288,6 +289,11 @@ func LoadConfig(path string) (*Config, error) { return nil, err } + // Auto-migrate: if only legacy providers config exists, convert to model_list + if len(cfg.ModelList) == 0 && cfg.HasProvidersConfig() { + cfg.ModelList = ConvertProvidersToModelList(cfg) + } + return cfg, nil } diff --git a/pkg/providers/claude_cli_provider_test.go b/pkg/providers/claude_cli_provider_test.go index ae49af042..2c68e6809 100644 --- a/pkg/providers/claude_cli_provider_test.go +++ b/pkg/providers/claude_cli_provider_test.go @@ -416,8 +416,10 @@ func TestChat_EmptyWorkspaceDoesNotSetDir(t *testing.T) { func TestCreateProvider_ClaudeCli(t *testing.T) { cfg := config.DefaultConfig() - cfg.Agents.Defaults.Provider = "claude-cli" - cfg.Agents.Defaults.Workspace = "/test/ws" + cfg.ModelList = []config.ModelConfig{ + {ModelName: "claude-sonnet-4", Model: "claude-cli/claude-sonnet-4-20250514", Workspace: "/test/ws"}, + } + cfg.Agents.Defaults.Model = "claude-sonnet-4" provider, _, err := CreateProvider(cfg) if err != nil { @@ -435,7 +437,10 @@ func TestCreateProvider_ClaudeCli(t *testing.T) { func TestCreateProvider_ClaudeCode(t *testing.T) { cfg := config.DefaultConfig() - cfg.Agents.Defaults.Provider = "claude-code" + cfg.ModelList = []config.ModelConfig{ + {ModelName: "claude-code", Model: "claude-cli/claude-code"}, + } + cfg.Agents.Defaults.Model = "claude-code" provider, _, err := CreateProvider(cfg) if err != nil { @@ -448,7 +453,10 @@ func TestCreateProvider_ClaudeCode(t *testing.T) { func TestCreateProvider_ClaudeCodec(t *testing.T) { cfg := config.DefaultConfig() - cfg.Agents.Defaults.Provider = "claudecode" + cfg.ModelList = []config.ModelConfig{ + {ModelName: "claudecode", Model: "claude-cli/claudecode"}, + } + cfg.Agents.Defaults.Model = "claudecode" provider, _, err := CreateProvider(cfg) if err != nil { @@ -461,7 +469,10 @@ func TestCreateProvider_ClaudeCodec(t *testing.T) { func TestCreateProvider_ClaudeCliDefaultWorkspace(t *testing.T) { cfg := config.DefaultConfig() - cfg.Agents.Defaults.Provider = "claude-cli" + cfg.ModelList = []config.ModelConfig{ + {ModelName: "claude-cli", Model: "claude-cli/claude-sonnet"}, + } + cfg.Agents.Defaults.Model = "claude-cli" cfg.Agents.Defaults.Workspace = "" provider, _, err := CreateProvider(cfg) diff --git a/pkg/providers/factory_provider.go b/pkg/providers/factory_provider.go index 695d4ffa5..8ed7559c6 100644 --- a/pkg/providers/factory_provider.go +++ b/pkg/providers/factory_provider.go @@ -9,9 +9,34 @@ import ( "fmt" "strings" + "github.com/sipeed/picoclaw/pkg/auth" "github.com/sipeed/picoclaw/pkg/config" ) +// createClaudeAuthProvider creates a Claude provider using OAuth credentials from auth store. +func createClaudeAuthProvider() (LLMProvider, error) { + cred, err := auth.GetCredential("anthropic") + if err != nil { + return nil, fmt.Errorf("loading auth credentials: %w", err) + } + if cred == nil { + return nil, fmt.Errorf("no credentials for anthropic. Run: picoclaw auth login --provider anthropic") + } + return NewClaudeProviderWithTokenSource(cred.AccessToken, createClaudeTokenSource()), nil +} + +// createCodexAuthProvider creates a Codex provider using OAuth credentials from auth store. +func createCodexAuthProvider() (LLMProvider, error) { + cred, err := auth.GetCredential("openai") + if err != nil { + return nil, fmt.Errorf("loading auth credentials: %w", err) + } + if cred == nil { + return nil, fmt.Errorf("no credentials for openai. Run: picoclaw auth login --provider openai") + } + return NewCodexProviderWithTokenSource(cred.AccessToken, cred.AccountID, createCodexTokenSource()), nil +} + // ExtractProtocol extracts the protocol prefix and model identifier from a model string. // If no prefix is specified, it defaults to "openai". // Examples: @@ -60,25 +85,38 @@ func CreateProviderFromConfig(cfg *config.ModelConfig) (LLMProvider, string, err case "anthropic": if cfg.AuthMethod == "oauth" || cfg.AuthMethod == "token" { - // Use Claude SDK with token - return NewClaudeProvider(cfg.APIKey), modelID, nil + // Use OAuth credentials from auth store + provider, err := createClaudeAuthProvider() + if err != nil { + return nil, "", err + } + return provider, modelID, nil } - // Use HTTP API + // Use API key with HTTP API apiBase := cfg.APIBase if apiBase == "" { apiBase = "https://api.anthropic.com/v1" } + if cfg.APIKey == "" { + return nil, "", fmt.Errorf("api_key is required for anthropic protocol (model: %s)", cfg.Model) + } return NewHTTPProvider(cfg.APIKey, apiBase, cfg.Proxy), modelID, nil case "antigravity": return NewAntigravityProvider(), modelID, nil case "claude-cli", "claudecli": - workspace := "." + workspace := cfg.Workspace + if workspace == "" { + workspace = "." + } return NewClaudeCliProvider(workspace), modelID, nil case "codex-cli", "codexcli": - workspace := "." + workspace := cfg.Workspace + if workspace == "" { + workspace = "." + } return NewCodexCliProvider(workspace), modelID, nil case "github-copilot", "copilot": diff --git a/pkg/providers/legacy_provider.go b/pkg/providers/legacy_provider.go index c1efb03b3..eb13cec65 100644 --- a/pkg/providers/legacy_provider.go +++ b/pkg/providers/legacy_provider.go @@ -7,343 +7,43 @@ package providers import ( "fmt" - "strings" - "github.com/sipeed/picoclaw/pkg/auth" "github.com/sipeed/picoclaw/pkg/config" ) -// createClaudeAuthProvider creates a Claude provider using OAuth credentials. -func createClaudeAuthProvider() (LLMProvider, error) { - cred, err := auth.GetCredential("anthropic") - if err != nil { - return nil, fmt.Errorf("loading auth credentials: %w", err) - } - if cred == nil { - return nil, fmt.Errorf("no credentials for anthropic. Run: picoclaw auth login --provider anthropic") - } - return NewClaudeProviderWithTokenSource(cred.AccessToken, createClaudeTokenSource()), nil -} - -// createCodexAuthProvider creates a Codex provider using OAuth credentials. -func createCodexAuthProvider() (LLMProvider, error) { - cred, err := auth.GetCredential("openai") - if err != nil { - return nil, fmt.Errorf("loading auth credentials: %w", err) - } - if cred == nil { - return nil, fmt.Errorf("no credentials for openai. Run: picoclaw auth login --provider openai") - } - return NewCodexProviderWithTokenSource(cred.AccessToken, cred.AccountID, createCodexTokenSource()), nil -} - // CreateProvider creates a provider based on the configuration. -// It supports both the new model_list configuration and the legacy providers configuration. +// It uses the model_list configuration (new format) to create providers. +// The old providers config is automatically converted to model_list during config loading. // Returns the provider, the model ID to use, and any error. func CreateProvider(cfg *config.Config) (LLMProvider, string, error) { model := cfg.Agents.Defaults.Model - // First, try to use model_list configuration - if len(cfg.ModelList) > 0 { - // Try to get config by model name first - modelCfg, err := cfg.GetModelConfig(model) - if err == nil { - // Found in model_list, use factory to create provider - provider, modelID, err := CreateProviderFromConfig(modelCfg) - if err != nil { - return nil, "", fmt.Errorf("failed to create provider from model_list: %w", err) - } - return provider, modelID, nil - } - // Model not found in model_list, fall through to providers config + // Ensure model_list is populated (should be done by LoadConfig, but handle edge cases) + if len(cfg.ModelList) == 0 && cfg.HasProvidersConfig() { + cfg.ModelList = config.ConvertProvidersToModelList(cfg) } - // Log deprecation warning if using old providers config - if cfg.HasProvidersConfig() && len(cfg.ModelList) == 0 { - fmt.Println("WARNING: providers config is deprecated, please migrate to model_list") + // Must have model_list at this point + if len(cfg.ModelList) == 0 { + return nil, "", fmt.Errorf("no providers configured. Please add entries to model_list in your config") } - providerName := strings.ToLower(cfg.Agents.Defaults.Provider) - - var apiKey, apiBase, proxy string - - lowerModel := strings.ToLower(model) - - // First, try to use explicitly configured provider - if providerName != "" { - switch providerName { - case "groq": - if cfg.Providers.Groq.APIKey != "" { - apiKey = cfg.Providers.Groq.APIKey - apiBase = cfg.Providers.Groq.APIBase - if apiBase == "" { - apiBase = "https://api.groq.com/openai/v1" - } - } - case "openai", "gpt": - if cfg.Providers.OpenAI.APIKey != "" || cfg.Providers.OpenAI.AuthMethod != "" { - if cfg.Providers.OpenAI.AuthMethod == "codex-cli" { - return NewCodexProviderWithTokenSource("", "", CreateCodexCliTokenSource()), model, nil - } - if cfg.Providers.OpenAI.AuthMethod == "oauth" || cfg.Providers.OpenAI.AuthMethod == "token" { - provider, err := createCodexAuthProvider() - return provider, model, err - } - apiKey = cfg.Providers.OpenAI.APIKey - apiBase = cfg.Providers.OpenAI.APIBase - if apiBase == "" { - apiBase = "https://api.openai.com/v1" - } - } - case "anthropic", "claude": - if cfg.Providers.Anthropic.APIKey != "" || cfg.Providers.Anthropic.AuthMethod != "" { - if cfg.Providers.Anthropic.AuthMethod == "oauth" || cfg.Providers.Anthropic.AuthMethod == "token" { - provider, err := createClaudeAuthProvider() - return provider, model, err - } - apiKey = cfg.Providers.Anthropic.APIKey - apiBase = cfg.Providers.Anthropic.APIBase - if apiBase == "" { - apiBase = "https://api.anthropic.com/v1" - } - } - case "openrouter": - if cfg.Providers.OpenRouter.APIKey != "" { - apiKey = cfg.Providers.OpenRouter.APIKey - if cfg.Providers.OpenRouter.APIBase != "" { - apiBase = cfg.Providers.OpenRouter.APIBase - } else { - apiBase = "https://openrouter.ai/api/v1" - } - } - case "zhipu", "glm": - if cfg.Providers.Zhipu.APIKey != "" { - apiKey = cfg.Providers.Zhipu.APIKey - apiBase = cfg.Providers.Zhipu.APIBase - if apiBase == "" { - apiBase = "https://open.bigmodel.cn/api/paas/v4" - } - } - case "gemini", "google": - if cfg.Providers.Gemini.APIKey != "" { - apiKey = cfg.Providers.Gemini.APIKey - apiBase = cfg.Providers.Gemini.APIBase - if apiBase == "" { - apiBase = "https://generativelanguage.googleapis.com/v1beta" - } - } - case "vllm": - if cfg.Providers.VLLM.APIBase != "" { - apiKey = cfg.Providers.VLLM.APIKey - apiBase = cfg.Providers.VLLM.APIBase - } - case "shengsuanyun": - if cfg.Providers.ShengSuanYun.APIKey != "" { - apiKey = cfg.Providers.ShengSuanYun.APIKey - apiBase = cfg.Providers.ShengSuanYun.APIBase - if apiBase == "" { - apiBase = "https://router.shengsuanyun.com/api/v1" - } - } - case "claude-cli", "claudecode", "claude-code": - workspace := cfg.WorkspacePath() - if workspace == "" { - workspace = "." - } - return NewClaudeCliProvider(workspace), model, nil - case "codex-cli", "codex-code": - workspace := cfg.WorkspacePath() - if workspace == "" { - workspace = "." - } - return NewCodexCliProvider(workspace), model, nil - case "cerebras": - if cfg.Providers.Cerebras.APIKey != "" { - apiKey = cfg.Providers.Cerebras.APIKey - apiBase = cfg.Providers.Cerebras.APIBase - if apiBase == "" { - apiBase = "https://api.cerebras.ai/v1" - } - } - case "deepseek": - if cfg.Providers.DeepSeek.APIKey != "" { - apiKey = cfg.Providers.DeepSeek.APIKey - apiBase = cfg.Providers.DeepSeek.APIBase - if apiBase == "" { - apiBase = "https://api.deepseek.com/v1" - } - if model != "deepseek-chat" && model != "deepseek-reasoner" { - model = "deepseek-chat" - } - } - case "qwen": - if cfg.Providers.Qwen.APIKey != "" { - apiKey = cfg.Providers.Qwen.APIKey - apiBase = cfg.Providers.Qwen.APIBase - if apiBase == "" { - apiBase = "https://dashscope.aliyuncs.com/compatible-mode/v1" - } - } - case "github_copilot", "copilot": - if cfg.Providers.GitHubCopilot.APIBase != "" { - apiBase = cfg.Providers.GitHubCopilot.APIBase - } else { - apiBase = "localhost:4321" - } - provider, err := NewGitHubCopilotProvider(apiBase, cfg.Providers.GitHubCopilot.ConnectMode, model) - return provider, model, err - case "antigravity", "google-antigravity": - return NewAntigravityProvider(), model, nil - - case "volcengine", "doubao": - if cfg.Providers.VolcEngine.APIKey != "" { - apiKey = cfg.Providers.VolcEngine.APIKey - apiBase = cfg.Providers.VolcEngine.APIBase - if apiBase == "" { - apiBase = "https://ark.cn-beijing.volces.com/api/v3" - } - } - - } - + // Get model config from model_list + modelCfg, err := cfg.GetModelConfig(model) + if err != nil { + return nil, "", fmt.Errorf("model %q not found in model_list: %w", model, err) } - // Fallback: detect provider from model name - if apiKey == "" && apiBase == "" { - switch { - case (strings.Contains(lowerModel, "kimi") || strings.Contains(lowerModel, "moonshot") || strings.HasPrefix(model, "moonshot/")) && cfg.Providers.Moonshot.APIKey != "": - apiKey = cfg.Providers.Moonshot.APIKey - apiBase = cfg.Providers.Moonshot.APIBase - proxy = cfg.Providers.Moonshot.Proxy - if apiBase == "" { - apiBase = "https://api.moonshot.cn/v1" - } - - case strings.HasPrefix(model, "openrouter/") || strings.HasPrefix(model, "anthropic/") || strings.HasPrefix(model, "openai/") || strings.HasPrefix(model, "meta-llama/") || strings.HasPrefix(model, "deepseek/") || strings.HasPrefix(model, "google/"): - apiKey = cfg.Providers.OpenRouter.APIKey - proxy = cfg.Providers.OpenRouter.Proxy - if cfg.Providers.OpenRouter.APIBase != "" { - apiBase = cfg.Providers.OpenRouter.APIBase - } else { - apiBase = "https://openrouter.ai/api/v1" - } - - case (strings.Contains(lowerModel, "claude") || strings.HasPrefix(model, "anthropic/")) && (cfg.Providers.Anthropic.APIKey != "" || cfg.Providers.Anthropic.AuthMethod != ""): - if cfg.Providers.Anthropic.AuthMethod == "oauth" || cfg.Providers.Anthropic.AuthMethod == "token" { - provider, err := createClaudeAuthProvider() - return provider, model, err - } - apiKey = cfg.Providers.Anthropic.APIKey - apiBase = cfg.Providers.Anthropic.APIBase - proxy = cfg.Providers.Anthropic.Proxy - if apiBase == "" { - apiBase = "https://api.anthropic.com/v1" - } - - case (strings.Contains(lowerModel, "gpt") || strings.HasPrefix(model, "openai/")) && (cfg.Providers.OpenAI.APIKey != "" || cfg.Providers.OpenAI.AuthMethod != ""): - if cfg.Providers.OpenAI.AuthMethod == "oauth" || cfg.Providers.OpenAI.AuthMethod == "token" { - provider, err := createCodexAuthProvider() - return provider, model, err - } - apiKey = cfg.Providers.OpenAI.APIKey - apiBase = cfg.Providers.OpenAI.APIBase - proxy = cfg.Providers.OpenAI.Proxy - if apiBase == "" { - apiBase = "https://api.openai.com/v1" - } - - case (strings.Contains(lowerModel, "gemini") || strings.HasPrefix(model, "google/")) && cfg.Providers.Gemini.APIKey != "": - apiKey = cfg.Providers.Gemini.APIKey - apiBase = cfg.Providers.Gemini.APIBase - proxy = cfg.Providers.Gemini.Proxy - if apiBase == "" { - apiBase = "https://generativelanguage.googleapis.com/v1beta" - } - - case (strings.Contains(lowerModel, "glm") || strings.Contains(lowerModel, "zhipu") || strings.Contains(lowerModel, "zai")) && cfg.Providers.Zhipu.APIKey != "": - apiKey = cfg.Providers.Zhipu.APIKey - apiBase = cfg.Providers.Zhipu.APIBase - proxy = cfg.Providers.Zhipu.Proxy - if apiBase == "" { - apiBase = "https://open.bigmodel.cn/api/paas/v4" - } - - case (strings.Contains(lowerModel, "groq") || strings.HasPrefix(model, "groq/")) && cfg.Providers.Groq.APIKey != "": - apiKey = cfg.Providers.Groq.APIKey - apiBase = cfg.Providers.Groq.APIBase - proxy = cfg.Providers.Groq.Proxy - if apiBase == "" { - apiBase = "https://api.groq.com/openai/v1" - } - - case (strings.Contains(lowerModel, "qwen") || strings.HasPrefix(model, "qwen/")) && cfg.Providers.Qwen.APIKey != "": - apiKey = cfg.Providers.Qwen.APIKey - apiBase = cfg.Providers.Qwen.APIBase - proxy = cfg.Providers.Qwen.Proxy - if apiBase == "" { - apiBase = "https://dashscope.aliyuncs.com/compatible-mode/v1" - } - - case (strings.Contains(lowerModel, "nvidia") || strings.HasPrefix(model, "nvidia/")) && cfg.Providers.Nvidia.APIKey != "": - apiKey = cfg.Providers.Nvidia.APIKey - apiBase = cfg.Providers.Nvidia.APIBase - proxy = cfg.Providers.Nvidia.Proxy - if apiBase == "" { - apiBase = "https://integrate.api.nvidia.com/v1" - } - case (strings.Contains(lowerModel, "cerebras") || strings.HasPrefix(model, "cerebras/")) && cfg.Providers.Cerebras.APIKey != "": - apiKey = cfg.Providers.Cerebras.APIKey - apiBase = cfg.Providers.Cerebras.APIBase - proxy = cfg.Providers.Cerebras.Proxy - if apiBase == "" { - apiBase = "https://api.cerebras.ai/v1" - } - - case (strings.Contains(lowerModel, "ollama") || strings.HasPrefix(model, "ollama/")) && cfg.Providers.Ollama.APIKey != "": - fmt.Println("Ollama provider selected based on model name prefix") - apiKey = cfg.Providers.Ollama.APIKey - apiBase = cfg.Providers.Ollama.APIBase - proxy = cfg.Providers.Ollama.Proxy - if apiBase == "" { - apiBase = "http://localhost:11434/v1" - } - fmt.Println("Ollama apiBase:", apiBase) - - case (strings.Contains(lowerModel, "doubao") || strings.HasPrefix(lowerModel, "doubao") || strings.Contains(lowerModel, "volcengine")) && cfg.Providers.VolcEngine.APIKey != "": - apiKey = cfg.Providers.VolcEngine.APIKey - apiBase = cfg.Providers.VolcEngine.APIBase - proxy = cfg.Providers.VolcEngine.Proxy - if apiBase == "" { - apiBase = "https://ark.cn-beijing.volces.com/api/v3" - } - - case cfg.Providers.VLLM.APIBase != "": - apiKey = cfg.Providers.VLLM.APIKey - apiBase = cfg.Providers.VLLM.APIBase - proxy = cfg.Providers.VLLM.Proxy - - default: - if cfg.Providers.OpenRouter.APIKey != "" { - apiKey = cfg.Providers.OpenRouter.APIKey - proxy = cfg.Providers.OpenRouter.Proxy - if cfg.Providers.OpenRouter.APIBase != "" { - apiBase = cfg.Providers.OpenRouter.APIBase - } else { - apiBase = "https://openrouter.ai/api/v1" - } - } else { - return nil, "", fmt.Errorf("no API key configured for model: %s", model) - } - } + // Inject global workspace if not set in model config + if modelCfg.Workspace == "" { + modelCfg.Workspace = cfg.WorkspacePath() } - if apiKey == "" && !strings.HasPrefix(model, "bedrock/") { - return nil, "", fmt.Errorf("no API key configured for provider (model: %s)", model) + // Use factory to create provider + provider, modelID, err := CreateProviderFromConfig(modelCfg) + if err != nil { + return nil, "", fmt.Errorf("failed to create provider for model %q: %w", model, err) } - if apiBase == "" { - return nil, "", fmt.Errorf("no API base configured for provider (model: %s)", model) - } - - return NewHTTPProvider(apiKey, apiBase, proxy), model, nil + return provider, modelID, nil }