mirror of
https://github.com/sipeed/picoclaw.git
synced 2026-06-12 18:08:54 +00:00
refactor: simplify legacy_provider.go from 349 to 49 lines
- Move OAuth helper functions to factory_provider.go - Add auto-migration in LoadConfig: old providers -> model_list - Add Workspace field to ModelConfig for CLI-based providers - Fix OAuth handling to use auth store instead of raw APIKey - Update tests to use new model_list configuration format This eliminates the giant switch-case in legacy_provider.go, achieving the goal of "zero-code provider addition" from the design document (issue #283). Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -215,6 +215,7 @@ type ModelConfig struct {
|
||||
// Special providers (CLI-based, OAuth, etc.)
|
||||
AuthMethod string `json:"auth_method,omitempty"` // Authentication method: oauth, token
|
||||
ConnectMode string `json:"connect_mode,omitempty"` // Connection mode: stdio, grpc
|
||||
Workspace string `json:"workspace,omitempty"` // Workspace path for CLI-based providers
|
||||
|
||||
// Optional optimizations
|
||||
RPM int `json:"rpm,omitempty"` // Requests per minute limit
|
||||
@@ -288,6 +289,11 @@ func LoadConfig(path string) (*Config, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Auto-migrate: if only legacy providers config exists, convert to model_list
|
||||
if len(cfg.ModelList) == 0 && cfg.HasProvidersConfig() {
|
||||
cfg.ModelList = ConvertProvidersToModelList(cfg)
|
||||
}
|
||||
|
||||
return cfg, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -416,8 +416,10 @@ func TestChat_EmptyWorkspaceDoesNotSetDir(t *testing.T) {
|
||||
|
||||
func TestCreateProvider_ClaudeCli(t *testing.T) {
|
||||
cfg := config.DefaultConfig()
|
||||
cfg.Agents.Defaults.Provider = "claude-cli"
|
||||
cfg.Agents.Defaults.Workspace = "/test/ws"
|
||||
cfg.ModelList = []config.ModelConfig{
|
||||
{ModelName: "claude-sonnet-4", Model: "claude-cli/claude-sonnet-4-20250514", Workspace: "/test/ws"},
|
||||
}
|
||||
cfg.Agents.Defaults.Model = "claude-sonnet-4"
|
||||
|
||||
provider, _, err := CreateProvider(cfg)
|
||||
if err != nil {
|
||||
@@ -435,7 +437,10 @@ func TestCreateProvider_ClaudeCli(t *testing.T) {
|
||||
|
||||
func TestCreateProvider_ClaudeCode(t *testing.T) {
|
||||
cfg := config.DefaultConfig()
|
||||
cfg.Agents.Defaults.Provider = "claude-code"
|
||||
cfg.ModelList = []config.ModelConfig{
|
||||
{ModelName: "claude-code", Model: "claude-cli/claude-code"},
|
||||
}
|
||||
cfg.Agents.Defaults.Model = "claude-code"
|
||||
|
||||
provider, _, err := CreateProvider(cfg)
|
||||
if err != nil {
|
||||
@@ -448,7 +453,10 @@ func TestCreateProvider_ClaudeCode(t *testing.T) {
|
||||
|
||||
func TestCreateProvider_ClaudeCodec(t *testing.T) {
|
||||
cfg := config.DefaultConfig()
|
||||
cfg.Agents.Defaults.Provider = "claudecode"
|
||||
cfg.ModelList = []config.ModelConfig{
|
||||
{ModelName: "claudecode", Model: "claude-cli/claudecode"},
|
||||
}
|
||||
cfg.Agents.Defaults.Model = "claudecode"
|
||||
|
||||
provider, _, err := CreateProvider(cfg)
|
||||
if err != nil {
|
||||
@@ -461,7 +469,10 @@ func TestCreateProvider_ClaudeCodec(t *testing.T) {
|
||||
|
||||
func TestCreateProvider_ClaudeCliDefaultWorkspace(t *testing.T) {
|
||||
cfg := config.DefaultConfig()
|
||||
cfg.Agents.Defaults.Provider = "claude-cli"
|
||||
cfg.ModelList = []config.ModelConfig{
|
||||
{ModelName: "claude-cli", Model: "claude-cli/claude-sonnet"},
|
||||
}
|
||||
cfg.Agents.Defaults.Model = "claude-cli"
|
||||
cfg.Agents.Defaults.Workspace = ""
|
||||
|
||||
provider, _, err := CreateProvider(cfg)
|
||||
|
||||
@@ -9,9 +9,34 @@ import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/auth"
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
)
|
||||
|
||||
// createClaudeAuthProvider creates a Claude provider using OAuth credentials from auth store.
|
||||
func createClaudeAuthProvider() (LLMProvider, error) {
|
||||
cred, err := auth.GetCredential("anthropic")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("loading auth credentials: %w", err)
|
||||
}
|
||||
if cred == nil {
|
||||
return nil, fmt.Errorf("no credentials for anthropic. Run: picoclaw auth login --provider anthropic")
|
||||
}
|
||||
return NewClaudeProviderWithTokenSource(cred.AccessToken, createClaudeTokenSource()), nil
|
||||
}
|
||||
|
||||
// createCodexAuthProvider creates a Codex provider using OAuth credentials from auth store.
|
||||
func createCodexAuthProvider() (LLMProvider, error) {
|
||||
cred, err := auth.GetCredential("openai")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("loading auth credentials: %w", err)
|
||||
}
|
||||
if cred == nil {
|
||||
return nil, fmt.Errorf("no credentials for openai. Run: picoclaw auth login --provider openai")
|
||||
}
|
||||
return NewCodexProviderWithTokenSource(cred.AccessToken, cred.AccountID, createCodexTokenSource()), nil
|
||||
}
|
||||
|
||||
// ExtractProtocol extracts the protocol prefix and model identifier from a model string.
|
||||
// If no prefix is specified, it defaults to "openai".
|
||||
// Examples:
|
||||
@@ -60,25 +85,38 @@ func CreateProviderFromConfig(cfg *config.ModelConfig) (LLMProvider, string, err
|
||||
|
||||
case "anthropic":
|
||||
if cfg.AuthMethod == "oauth" || cfg.AuthMethod == "token" {
|
||||
// Use Claude SDK with token
|
||||
return NewClaudeProvider(cfg.APIKey), modelID, nil
|
||||
// Use OAuth credentials from auth store
|
||||
provider, err := createClaudeAuthProvider()
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
return provider, modelID, nil
|
||||
}
|
||||
// Use HTTP API
|
||||
// Use API key with HTTP API
|
||||
apiBase := cfg.APIBase
|
||||
if apiBase == "" {
|
||||
apiBase = "https://api.anthropic.com/v1"
|
||||
}
|
||||
if cfg.APIKey == "" {
|
||||
return nil, "", fmt.Errorf("api_key is required for anthropic protocol (model: %s)", cfg.Model)
|
||||
}
|
||||
return NewHTTPProvider(cfg.APIKey, apiBase, cfg.Proxy), modelID, nil
|
||||
|
||||
case "antigravity":
|
||||
return NewAntigravityProvider(), modelID, nil
|
||||
|
||||
case "claude-cli", "claudecli":
|
||||
workspace := "."
|
||||
workspace := cfg.Workspace
|
||||
if workspace == "" {
|
||||
workspace = "."
|
||||
}
|
||||
return NewClaudeCliProvider(workspace), modelID, nil
|
||||
|
||||
case "codex-cli", "codexcli":
|
||||
workspace := "."
|
||||
workspace := cfg.Workspace
|
||||
if workspace == "" {
|
||||
workspace = "."
|
||||
}
|
||||
return NewCodexCliProvider(workspace), modelID, nil
|
||||
|
||||
case "github-copilot", "copilot":
|
||||
|
||||
@@ -7,343 +7,43 @@ package providers
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/auth"
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
)
|
||||
|
||||
// createClaudeAuthProvider creates a Claude provider using OAuth credentials.
|
||||
func createClaudeAuthProvider() (LLMProvider, error) {
|
||||
cred, err := auth.GetCredential("anthropic")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("loading auth credentials: %w", err)
|
||||
}
|
||||
if cred == nil {
|
||||
return nil, fmt.Errorf("no credentials for anthropic. Run: picoclaw auth login --provider anthropic")
|
||||
}
|
||||
return NewClaudeProviderWithTokenSource(cred.AccessToken, createClaudeTokenSource()), nil
|
||||
}
|
||||
|
||||
// createCodexAuthProvider creates a Codex provider using OAuth credentials.
|
||||
func createCodexAuthProvider() (LLMProvider, error) {
|
||||
cred, err := auth.GetCredential("openai")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("loading auth credentials: %w", err)
|
||||
}
|
||||
if cred == nil {
|
||||
return nil, fmt.Errorf("no credentials for openai. Run: picoclaw auth login --provider openai")
|
||||
}
|
||||
return NewCodexProviderWithTokenSource(cred.AccessToken, cred.AccountID, createCodexTokenSource()), nil
|
||||
}
|
||||
|
||||
// CreateProvider creates a provider based on the configuration.
|
||||
// It supports both the new model_list configuration and the legacy providers configuration.
|
||||
// It uses the model_list configuration (new format) to create providers.
|
||||
// The old providers config is automatically converted to model_list during config loading.
|
||||
// Returns the provider, the model ID to use, and any error.
|
||||
func CreateProvider(cfg *config.Config) (LLMProvider, string, error) {
|
||||
model := cfg.Agents.Defaults.Model
|
||||
|
||||
// First, try to use model_list configuration
|
||||
if len(cfg.ModelList) > 0 {
|
||||
// Try to get config by model name first
|
||||
modelCfg, err := cfg.GetModelConfig(model)
|
||||
if err == nil {
|
||||
// Found in model_list, use factory to create provider
|
||||
provider, modelID, err := CreateProviderFromConfig(modelCfg)
|
||||
if err != nil {
|
||||
return nil, "", fmt.Errorf("failed to create provider from model_list: %w", err)
|
||||
}
|
||||
return provider, modelID, nil
|
||||
}
|
||||
// Model not found in model_list, fall through to providers config
|
||||
// Ensure model_list is populated (should be done by LoadConfig, but handle edge cases)
|
||||
if len(cfg.ModelList) == 0 && cfg.HasProvidersConfig() {
|
||||
cfg.ModelList = config.ConvertProvidersToModelList(cfg)
|
||||
}
|
||||
|
||||
// Log deprecation warning if using old providers config
|
||||
if cfg.HasProvidersConfig() && len(cfg.ModelList) == 0 {
|
||||
fmt.Println("WARNING: providers config is deprecated, please migrate to model_list")
|
||||
// Must have model_list at this point
|
||||
if len(cfg.ModelList) == 0 {
|
||||
return nil, "", fmt.Errorf("no providers configured. Please add entries to model_list in your config")
|
||||
}
|
||||
|
||||
providerName := strings.ToLower(cfg.Agents.Defaults.Provider)
|
||||
|
||||
var apiKey, apiBase, proxy string
|
||||
|
||||
lowerModel := strings.ToLower(model)
|
||||
|
||||
// First, try to use explicitly configured provider
|
||||
if providerName != "" {
|
||||
switch providerName {
|
||||
case "groq":
|
||||
if cfg.Providers.Groq.APIKey != "" {
|
||||
apiKey = cfg.Providers.Groq.APIKey
|
||||
apiBase = cfg.Providers.Groq.APIBase
|
||||
if apiBase == "" {
|
||||
apiBase = "https://api.groq.com/openai/v1"
|
||||
}
|
||||
}
|
||||
case "openai", "gpt":
|
||||
if cfg.Providers.OpenAI.APIKey != "" || cfg.Providers.OpenAI.AuthMethod != "" {
|
||||
if cfg.Providers.OpenAI.AuthMethod == "codex-cli" {
|
||||
return NewCodexProviderWithTokenSource("", "", CreateCodexCliTokenSource()), model, nil
|
||||
}
|
||||
if cfg.Providers.OpenAI.AuthMethod == "oauth" || cfg.Providers.OpenAI.AuthMethod == "token" {
|
||||
provider, err := createCodexAuthProvider()
|
||||
return provider, model, err
|
||||
}
|
||||
apiKey = cfg.Providers.OpenAI.APIKey
|
||||
apiBase = cfg.Providers.OpenAI.APIBase
|
||||
if apiBase == "" {
|
||||
apiBase = "https://api.openai.com/v1"
|
||||
}
|
||||
}
|
||||
case "anthropic", "claude":
|
||||
if cfg.Providers.Anthropic.APIKey != "" || cfg.Providers.Anthropic.AuthMethod != "" {
|
||||
if cfg.Providers.Anthropic.AuthMethod == "oauth" || cfg.Providers.Anthropic.AuthMethod == "token" {
|
||||
provider, err := createClaudeAuthProvider()
|
||||
return provider, model, err
|
||||
}
|
||||
apiKey = cfg.Providers.Anthropic.APIKey
|
||||
apiBase = cfg.Providers.Anthropic.APIBase
|
||||
if apiBase == "" {
|
||||
apiBase = "https://api.anthropic.com/v1"
|
||||
}
|
||||
}
|
||||
case "openrouter":
|
||||
if cfg.Providers.OpenRouter.APIKey != "" {
|
||||
apiKey = cfg.Providers.OpenRouter.APIKey
|
||||
if cfg.Providers.OpenRouter.APIBase != "" {
|
||||
apiBase = cfg.Providers.OpenRouter.APIBase
|
||||
} else {
|
||||
apiBase = "https://openrouter.ai/api/v1"
|
||||
}
|
||||
}
|
||||
case "zhipu", "glm":
|
||||
if cfg.Providers.Zhipu.APIKey != "" {
|
||||
apiKey = cfg.Providers.Zhipu.APIKey
|
||||
apiBase = cfg.Providers.Zhipu.APIBase
|
||||
if apiBase == "" {
|
||||
apiBase = "https://open.bigmodel.cn/api/paas/v4"
|
||||
}
|
||||
}
|
||||
case "gemini", "google":
|
||||
if cfg.Providers.Gemini.APIKey != "" {
|
||||
apiKey = cfg.Providers.Gemini.APIKey
|
||||
apiBase = cfg.Providers.Gemini.APIBase
|
||||
if apiBase == "" {
|
||||
apiBase = "https://generativelanguage.googleapis.com/v1beta"
|
||||
}
|
||||
}
|
||||
case "vllm":
|
||||
if cfg.Providers.VLLM.APIBase != "" {
|
||||
apiKey = cfg.Providers.VLLM.APIKey
|
||||
apiBase = cfg.Providers.VLLM.APIBase
|
||||
}
|
||||
case "shengsuanyun":
|
||||
if cfg.Providers.ShengSuanYun.APIKey != "" {
|
||||
apiKey = cfg.Providers.ShengSuanYun.APIKey
|
||||
apiBase = cfg.Providers.ShengSuanYun.APIBase
|
||||
if apiBase == "" {
|
||||
apiBase = "https://router.shengsuanyun.com/api/v1"
|
||||
}
|
||||
}
|
||||
case "claude-cli", "claudecode", "claude-code":
|
||||
workspace := cfg.WorkspacePath()
|
||||
if workspace == "" {
|
||||
workspace = "."
|
||||
}
|
||||
return NewClaudeCliProvider(workspace), model, nil
|
||||
case "codex-cli", "codex-code":
|
||||
workspace := cfg.WorkspacePath()
|
||||
if workspace == "" {
|
||||
workspace = "."
|
||||
}
|
||||
return NewCodexCliProvider(workspace), model, nil
|
||||
case "cerebras":
|
||||
if cfg.Providers.Cerebras.APIKey != "" {
|
||||
apiKey = cfg.Providers.Cerebras.APIKey
|
||||
apiBase = cfg.Providers.Cerebras.APIBase
|
||||
if apiBase == "" {
|
||||
apiBase = "https://api.cerebras.ai/v1"
|
||||
}
|
||||
}
|
||||
case "deepseek":
|
||||
if cfg.Providers.DeepSeek.APIKey != "" {
|
||||
apiKey = cfg.Providers.DeepSeek.APIKey
|
||||
apiBase = cfg.Providers.DeepSeek.APIBase
|
||||
if apiBase == "" {
|
||||
apiBase = "https://api.deepseek.com/v1"
|
||||
}
|
||||
if model != "deepseek-chat" && model != "deepseek-reasoner" {
|
||||
model = "deepseek-chat"
|
||||
}
|
||||
}
|
||||
case "qwen":
|
||||
if cfg.Providers.Qwen.APIKey != "" {
|
||||
apiKey = cfg.Providers.Qwen.APIKey
|
||||
apiBase = cfg.Providers.Qwen.APIBase
|
||||
if apiBase == "" {
|
||||
apiBase = "https://dashscope.aliyuncs.com/compatible-mode/v1"
|
||||
}
|
||||
}
|
||||
case "github_copilot", "copilot":
|
||||
if cfg.Providers.GitHubCopilot.APIBase != "" {
|
||||
apiBase = cfg.Providers.GitHubCopilot.APIBase
|
||||
} else {
|
||||
apiBase = "localhost:4321"
|
||||
}
|
||||
provider, err := NewGitHubCopilotProvider(apiBase, cfg.Providers.GitHubCopilot.ConnectMode, model)
|
||||
return provider, model, err
|
||||
case "antigravity", "google-antigravity":
|
||||
return NewAntigravityProvider(), model, nil
|
||||
|
||||
case "volcengine", "doubao":
|
||||
if cfg.Providers.VolcEngine.APIKey != "" {
|
||||
apiKey = cfg.Providers.VolcEngine.APIKey
|
||||
apiBase = cfg.Providers.VolcEngine.APIBase
|
||||
if apiBase == "" {
|
||||
apiBase = "https://ark.cn-beijing.volces.com/api/v3"
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
// Get model config from model_list
|
||||
modelCfg, err := cfg.GetModelConfig(model)
|
||||
if err != nil {
|
||||
return nil, "", fmt.Errorf("model %q not found in model_list: %w", model, err)
|
||||
}
|
||||
|
||||
// Fallback: detect provider from model name
|
||||
if apiKey == "" && apiBase == "" {
|
||||
switch {
|
||||
case (strings.Contains(lowerModel, "kimi") || strings.Contains(lowerModel, "moonshot") || strings.HasPrefix(model, "moonshot/")) && cfg.Providers.Moonshot.APIKey != "":
|
||||
apiKey = cfg.Providers.Moonshot.APIKey
|
||||
apiBase = cfg.Providers.Moonshot.APIBase
|
||||
proxy = cfg.Providers.Moonshot.Proxy
|
||||
if apiBase == "" {
|
||||
apiBase = "https://api.moonshot.cn/v1"
|
||||
}
|
||||
|
||||
case strings.HasPrefix(model, "openrouter/") || strings.HasPrefix(model, "anthropic/") || strings.HasPrefix(model, "openai/") || strings.HasPrefix(model, "meta-llama/") || strings.HasPrefix(model, "deepseek/") || strings.HasPrefix(model, "google/"):
|
||||
apiKey = cfg.Providers.OpenRouter.APIKey
|
||||
proxy = cfg.Providers.OpenRouter.Proxy
|
||||
if cfg.Providers.OpenRouter.APIBase != "" {
|
||||
apiBase = cfg.Providers.OpenRouter.APIBase
|
||||
} else {
|
||||
apiBase = "https://openrouter.ai/api/v1"
|
||||
}
|
||||
|
||||
case (strings.Contains(lowerModel, "claude") || strings.HasPrefix(model, "anthropic/")) && (cfg.Providers.Anthropic.APIKey != "" || cfg.Providers.Anthropic.AuthMethod != ""):
|
||||
if cfg.Providers.Anthropic.AuthMethod == "oauth" || cfg.Providers.Anthropic.AuthMethod == "token" {
|
||||
provider, err := createClaudeAuthProvider()
|
||||
return provider, model, err
|
||||
}
|
||||
apiKey = cfg.Providers.Anthropic.APIKey
|
||||
apiBase = cfg.Providers.Anthropic.APIBase
|
||||
proxy = cfg.Providers.Anthropic.Proxy
|
||||
if apiBase == "" {
|
||||
apiBase = "https://api.anthropic.com/v1"
|
||||
}
|
||||
|
||||
case (strings.Contains(lowerModel, "gpt") || strings.HasPrefix(model, "openai/")) && (cfg.Providers.OpenAI.APIKey != "" || cfg.Providers.OpenAI.AuthMethod != ""):
|
||||
if cfg.Providers.OpenAI.AuthMethod == "oauth" || cfg.Providers.OpenAI.AuthMethod == "token" {
|
||||
provider, err := createCodexAuthProvider()
|
||||
return provider, model, err
|
||||
}
|
||||
apiKey = cfg.Providers.OpenAI.APIKey
|
||||
apiBase = cfg.Providers.OpenAI.APIBase
|
||||
proxy = cfg.Providers.OpenAI.Proxy
|
||||
if apiBase == "" {
|
||||
apiBase = "https://api.openai.com/v1"
|
||||
}
|
||||
|
||||
case (strings.Contains(lowerModel, "gemini") || strings.HasPrefix(model, "google/")) && cfg.Providers.Gemini.APIKey != "":
|
||||
apiKey = cfg.Providers.Gemini.APIKey
|
||||
apiBase = cfg.Providers.Gemini.APIBase
|
||||
proxy = cfg.Providers.Gemini.Proxy
|
||||
if apiBase == "" {
|
||||
apiBase = "https://generativelanguage.googleapis.com/v1beta"
|
||||
}
|
||||
|
||||
case (strings.Contains(lowerModel, "glm") || strings.Contains(lowerModel, "zhipu") || strings.Contains(lowerModel, "zai")) && cfg.Providers.Zhipu.APIKey != "":
|
||||
apiKey = cfg.Providers.Zhipu.APIKey
|
||||
apiBase = cfg.Providers.Zhipu.APIBase
|
||||
proxy = cfg.Providers.Zhipu.Proxy
|
||||
if apiBase == "" {
|
||||
apiBase = "https://open.bigmodel.cn/api/paas/v4"
|
||||
}
|
||||
|
||||
case (strings.Contains(lowerModel, "groq") || strings.HasPrefix(model, "groq/")) && cfg.Providers.Groq.APIKey != "":
|
||||
apiKey = cfg.Providers.Groq.APIKey
|
||||
apiBase = cfg.Providers.Groq.APIBase
|
||||
proxy = cfg.Providers.Groq.Proxy
|
||||
if apiBase == "" {
|
||||
apiBase = "https://api.groq.com/openai/v1"
|
||||
}
|
||||
|
||||
case (strings.Contains(lowerModel, "qwen") || strings.HasPrefix(model, "qwen/")) && cfg.Providers.Qwen.APIKey != "":
|
||||
apiKey = cfg.Providers.Qwen.APIKey
|
||||
apiBase = cfg.Providers.Qwen.APIBase
|
||||
proxy = cfg.Providers.Qwen.Proxy
|
||||
if apiBase == "" {
|
||||
apiBase = "https://dashscope.aliyuncs.com/compatible-mode/v1"
|
||||
}
|
||||
|
||||
case (strings.Contains(lowerModel, "nvidia") || strings.HasPrefix(model, "nvidia/")) && cfg.Providers.Nvidia.APIKey != "":
|
||||
apiKey = cfg.Providers.Nvidia.APIKey
|
||||
apiBase = cfg.Providers.Nvidia.APIBase
|
||||
proxy = cfg.Providers.Nvidia.Proxy
|
||||
if apiBase == "" {
|
||||
apiBase = "https://integrate.api.nvidia.com/v1"
|
||||
}
|
||||
case (strings.Contains(lowerModel, "cerebras") || strings.HasPrefix(model, "cerebras/")) && cfg.Providers.Cerebras.APIKey != "":
|
||||
apiKey = cfg.Providers.Cerebras.APIKey
|
||||
apiBase = cfg.Providers.Cerebras.APIBase
|
||||
proxy = cfg.Providers.Cerebras.Proxy
|
||||
if apiBase == "" {
|
||||
apiBase = "https://api.cerebras.ai/v1"
|
||||
}
|
||||
|
||||
case (strings.Contains(lowerModel, "ollama") || strings.HasPrefix(model, "ollama/")) && cfg.Providers.Ollama.APIKey != "":
|
||||
fmt.Println("Ollama provider selected based on model name prefix")
|
||||
apiKey = cfg.Providers.Ollama.APIKey
|
||||
apiBase = cfg.Providers.Ollama.APIBase
|
||||
proxy = cfg.Providers.Ollama.Proxy
|
||||
if apiBase == "" {
|
||||
apiBase = "http://localhost:11434/v1"
|
||||
}
|
||||
fmt.Println("Ollama apiBase:", apiBase)
|
||||
|
||||
case (strings.Contains(lowerModel, "doubao") || strings.HasPrefix(lowerModel, "doubao") || strings.Contains(lowerModel, "volcengine")) && cfg.Providers.VolcEngine.APIKey != "":
|
||||
apiKey = cfg.Providers.VolcEngine.APIKey
|
||||
apiBase = cfg.Providers.VolcEngine.APIBase
|
||||
proxy = cfg.Providers.VolcEngine.Proxy
|
||||
if apiBase == "" {
|
||||
apiBase = "https://ark.cn-beijing.volces.com/api/v3"
|
||||
}
|
||||
|
||||
case cfg.Providers.VLLM.APIBase != "":
|
||||
apiKey = cfg.Providers.VLLM.APIKey
|
||||
apiBase = cfg.Providers.VLLM.APIBase
|
||||
proxy = cfg.Providers.VLLM.Proxy
|
||||
|
||||
default:
|
||||
if cfg.Providers.OpenRouter.APIKey != "" {
|
||||
apiKey = cfg.Providers.OpenRouter.APIKey
|
||||
proxy = cfg.Providers.OpenRouter.Proxy
|
||||
if cfg.Providers.OpenRouter.APIBase != "" {
|
||||
apiBase = cfg.Providers.OpenRouter.APIBase
|
||||
} else {
|
||||
apiBase = "https://openrouter.ai/api/v1"
|
||||
}
|
||||
} else {
|
||||
return nil, "", fmt.Errorf("no API key configured for model: %s", model)
|
||||
}
|
||||
}
|
||||
// Inject global workspace if not set in model config
|
||||
if modelCfg.Workspace == "" {
|
||||
modelCfg.Workspace = cfg.WorkspacePath()
|
||||
}
|
||||
|
||||
if apiKey == "" && !strings.HasPrefix(model, "bedrock/") {
|
||||
return nil, "", fmt.Errorf("no API key configured for provider (model: %s)", model)
|
||||
// Use factory to create provider
|
||||
provider, modelID, err := CreateProviderFromConfig(modelCfg)
|
||||
if err != nil {
|
||||
return nil, "", fmt.Errorf("failed to create provider for model %q: %w", model, err)
|
||||
}
|
||||
|
||||
if apiBase == "" {
|
||||
return nil, "", fmt.Errorf("no API base configured for provider (model: %s)", model)
|
||||
}
|
||||
|
||||
return NewHTTPProvider(apiKey, apiBase, proxy), model, nil
|
||||
return provider, modelID, nil
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user