refactor: reorganize commands and provider architecture

Refactor command handlers into separate files to improve code organization
and maintainability. Each command (agent, auth, cron, gateway, migrate,
onboard, skills, status) now has its own dedicated file.

Restructure provider creation to support new model_list configuration
system that enables zero-code addition of OpenAI-compatible providers.
Move legacy provider logic to separate file for backward compatibility.

Move configuration functions from config.go to separate files
(defaults.go, migration.go) for better organization.
This commit is contained in:
yinwm
2026-02-19 01:03:34 +08:00
parent a73d8e1a16
commit ef7078a356
23 changed files with 3429 additions and 2241 deletions
+47 -372
View File
@@ -232,23 +232,6 @@ func (c *ModelConfig) Validate() error {
return nil
}
// ParseProtocol extracts the protocol prefix and model identifier from the Model field.
// If no prefix is specified, it defaults to "openai".
// Examples:
// - "openai/gpt-4o" -> ("openai", "gpt-4o")
// - "anthropic/claude-3" -> ("anthropic", "claude-3")
// - "gpt-4o" -> ("openai", "gpt-4o") // default protocol
func (c *ModelConfig) ParseProtocol() (protocol, modelID string) {
model := c.Model
for i := 0; i < len(model); i++ {
if model[i] == '/' {
return model[:i], model[i+1:]
}
}
// No prefix found, default to openai
return "openai", model
}
type GatewayConfig struct {
Host string `json:"host" env:"PICOCLAW_GATEWAY_HOST"`
Port int `json:"port" env:"PICOCLAW_GATEWAY_PORT"`
@@ -286,135 +269,6 @@ type ToolsConfig struct {
Cron CronToolsConfig `json:"cron"`
}
func DefaultConfig() *Config {
return &Config{
Agents: AgentsConfig{
Defaults: AgentDefaults{
Workspace: "~/.picoclaw/workspace",
RestrictToWorkspace: true,
Provider: "",
Model: "glm-4.7",
MaxTokens: 8192,
Temperature: 0.7,
MaxToolIterations: 20,
},
},
Channels: ChannelsConfig{
WhatsApp: WhatsAppConfig{
Enabled: false,
BridgeURL: "ws://localhost:3001",
AllowFrom: FlexibleStringSlice{},
},
Telegram: TelegramConfig{
Enabled: false,
Token: "",
AllowFrom: FlexibleStringSlice{},
},
Feishu: FeishuConfig{
Enabled: false,
AppID: "",
AppSecret: "",
EncryptKey: "",
VerificationToken: "",
AllowFrom: FlexibleStringSlice{},
},
Discord: DiscordConfig{
Enabled: false,
Token: "",
AllowFrom: FlexibleStringSlice{},
},
MaixCam: MaixCamConfig{
Enabled: false,
Host: "0.0.0.0",
Port: 18790,
AllowFrom: FlexibleStringSlice{},
},
QQ: QQConfig{
Enabled: false,
AppID: "",
AppSecret: "",
AllowFrom: FlexibleStringSlice{},
},
DingTalk: DingTalkConfig{
Enabled: false,
ClientID: "",
ClientSecret: "",
AllowFrom: FlexibleStringSlice{},
},
Slack: SlackConfig{
Enabled: false,
BotToken: "",
AppToken: "",
AllowFrom: FlexibleStringSlice{},
},
LINE: LINEConfig{
Enabled: false,
ChannelSecret: "",
ChannelAccessToken: "",
WebhookHost: "0.0.0.0",
WebhookPort: 18791,
WebhookPath: "/webhook/line",
AllowFrom: FlexibleStringSlice{},
},
OneBot: OneBotConfig{
Enabled: false,
WSUrl: "ws://127.0.0.1:3001",
AccessToken: "",
ReconnectInterval: 5,
GroupTriggerPrefix: []string{},
AllowFrom: FlexibleStringSlice{},
},
},
Providers: ProvidersConfig{
Anthropic: ProviderConfig{},
OpenAI: ProviderConfig{},
OpenRouter: ProviderConfig{},
Groq: ProviderConfig{},
Zhipu: ProviderConfig{},
VLLM: ProviderConfig{},
Gemini: ProviderConfig{},
Nvidia: ProviderConfig{},
Moonshot: ProviderConfig{},
ShengSuanYun: ProviderConfig{},
Cerebras: ProviderConfig{},
VolcEngine: ProviderConfig{},
},
Gateway: GatewayConfig{
Host: "0.0.0.0",
Port: 18790,
},
Tools: ToolsConfig{
Web: WebToolsConfig{
Brave: BraveConfig{
Enabled: false,
APIKey: "",
MaxResults: 5,
},
DuckDuckGo: DuckDuckGoConfig{
Enabled: true,
MaxResults: 5,
},
Perplexity: PerplexityConfig{
Enabled: false,
APIKey: "",
MaxResults: 5,
},
},
Cron: CronToolsConfig{
ExecTimeoutMinutes: 5, // default 5 minutes for LLM operations
},
},
Heartbeat: HeartbeatConfig{
Enabled: true,
Interval: 30, // default 30 minutes
},
Devices: DevicesConfig{
Enabled: false,
MonitorUSB: true,
},
}
}
func LoadConfig(path string) (*Config, error) {
cfg := DefaultConfig()
@@ -528,40 +382,61 @@ func expandHome(path string) string {
// GetModelConfig returns the ModelConfig for the given model name.
// If multiple configs exist with the same model_name, it uses round-robin
// selection for load balancing. Returns an error if the model is not found.
// Uses double-check locking for optimal read performance.
func (c *Config) GetModelConfig(modelName string) (*ModelConfig, error) {
c.mu.Lock()
defer c.mu.Unlock()
// First pass: use read lock to find matches
c.mu.RLock()
matches := c.findMatchesLocked(modelName)
if len(matches) == 0 {
c.mu.RUnlock()
return nil, fmt.Errorf("model %q not found in model_list or providers", modelName)
}
if len(matches) == 1 {
c.mu.RUnlock()
return &matches[0], nil
}
// Find all configs with matching model_name
// Multiple configs - check if counter exists
counter, ok := c.rrCounters[modelName]
c.mu.RUnlock()
// Double-check locking: only acquire write lock if counter needs initialization
if !ok {
c.mu.Lock()
// Re-check after acquiring write lock
if c.rrCounters == nil {
c.rrCounters = make(map[string]*atomic.Uint64)
}
if c.rrCounters[modelName] == nil {
c.rrCounters[modelName] = &atomic.Uint64{}
}
counter = c.rrCounters[modelName]
c.mu.Unlock()
}
// Re-fetch matches to ensure consistency (ModelList could have changed)
c.mu.RLock()
matches = c.findMatchesLocked(modelName)
c.mu.RUnlock()
if len(matches) == 0 {
return nil, fmt.Errorf("model %q not found in model_list or providers", modelName)
}
idx := counter.Add(1) % uint64(len(matches))
return &matches[idx], nil
}
// findMatchesLocked finds all ModelConfig entries with the given model_name.
// Must be called with c.mu locked (read or write).
func (c *Config) findMatchesLocked(modelName string) []ModelConfig {
var matches []ModelConfig
for i := range c.ModelList {
if c.ModelList[i].ModelName == modelName {
matches = append(matches, c.ModelList[i])
}
}
if len(matches) == 0 {
return nil, fmt.Errorf("model %q not found in model_list or providers", modelName)
}
// Single config - return directly
if len(matches) == 1 {
return &matches[0], nil
}
// Multiple configs - use round-robin for load balancing
if c.rrCounters == nil {
c.rrCounters = make(map[string]*atomic.Uint64)
}
counter, ok := c.rrCounters[modelName]
if !ok {
counter = &atomic.Uint64{}
c.rrCounters[modelName] = counter
}
idx := counter.Add(1) % uint64(len(matches))
return &matches[idx], nil
return matches
}
// HasProvidersConfig checks if any provider in the old providers config has configuration.
@@ -599,203 +474,3 @@ func (c *Config) ValidateModelList() error {
}
return nil
}
// ConvertProvidersToModelList converts the old ProvidersConfig to a slice of ModelConfig.
// This enables backward compatibility with existing configurations.
func ConvertProvidersToModelList(cfg *Config) []ModelConfig {
if cfg == nil {
return nil
}
var result []ModelConfig
p := cfg.Providers
// OpenAI
if p.OpenAI.APIKey != "" || p.OpenAI.APIBase != "" {
result = append(result, ModelConfig{
ModelName: "openai",
Model: "openai/gpt-4o",
APIKey: p.OpenAI.APIKey,
APIBase: p.OpenAI.APIBase,
Proxy: p.OpenAI.Proxy,
AuthMethod: p.OpenAI.AuthMethod,
})
}
// Anthropic
if p.Anthropic.APIKey != "" || p.Anthropic.APIBase != "" {
result = append(result, ModelConfig{
ModelName: "anthropic",
Model: "anthropic/claude-3-sonnet",
APIKey: p.Anthropic.APIKey,
APIBase: p.Anthropic.APIBase,
Proxy: p.Anthropic.Proxy,
AuthMethod: p.Anthropic.AuthMethod,
})
}
// OpenRouter
if p.OpenRouter.APIKey != "" || p.OpenRouter.APIBase != "" {
result = append(result, ModelConfig{
ModelName: "openrouter",
Model: "openrouter/auto",
APIKey: p.OpenRouter.APIKey,
APIBase: p.OpenRouter.APIBase,
Proxy: p.OpenRouter.Proxy,
})
}
// Groq
if p.Groq.APIKey != "" || p.Groq.APIBase != "" {
result = append(result, ModelConfig{
ModelName: "groq",
Model: "groq/llama-3.1-70b-versatile",
APIKey: p.Groq.APIKey,
APIBase: p.Groq.APIBase,
Proxy: p.Groq.Proxy,
})
}
// Zhipu
if p.Zhipu.APIKey != "" || p.Zhipu.APIBase != "" {
result = append(result, ModelConfig{
ModelName: "zhipu",
Model: "openai/glm-4",
APIKey: p.Zhipu.APIKey,
APIBase: p.Zhipu.APIBase,
Proxy: p.Zhipu.Proxy,
})
}
// VLLM
if p.VLLM.APIKey != "" || p.VLLM.APIBase != "" {
result = append(result, ModelConfig{
ModelName: "vllm",
Model: "openai/auto",
APIKey: p.VLLM.APIKey,
APIBase: p.VLLM.APIBase,
Proxy: p.VLLM.Proxy,
})
}
// Gemini
if p.Gemini.APIKey != "" || p.Gemini.APIBase != "" {
result = append(result, ModelConfig{
ModelName: "gemini",
Model: "openai/gemini-pro",
APIKey: p.Gemini.APIKey,
APIBase: p.Gemini.APIBase,
Proxy: p.Gemini.Proxy,
})
}
// Nvidia
if p.Nvidia.APIKey != "" || p.Nvidia.APIBase != "" {
result = append(result, ModelConfig{
ModelName: "nvidia",
Model: "nvidia/meta/llama-3.1-8b-instruct",
APIKey: p.Nvidia.APIKey,
APIBase: p.Nvidia.APIBase,
Proxy: p.Nvidia.Proxy,
})
}
// Ollama
if p.Ollama.APIKey != "" || p.Ollama.APIBase != "" {
result = append(result, ModelConfig{
ModelName: "ollama",
Model: "ollama/llama3",
APIKey: p.Ollama.APIKey,
APIBase: p.Ollama.APIBase,
Proxy: p.Ollama.Proxy,
})
}
// Moonshot
if p.Moonshot.APIKey != "" || p.Moonshot.APIBase != "" {
result = append(result, ModelConfig{
ModelName: "moonshot",
Model: "moonshot/kimi",
APIKey: p.Moonshot.APIKey,
APIBase: p.Moonshot.APIBase,
Proxy: p.Moonshot.Proxy,
})
}
// ShengSuanYun
if p.ShengSuanYun.APIKey != "" || p.ShengSuanYun.APIBase != "" {
result = append(result, ModelConfig{
ModelName: "shengsuanyun",
Model: "openai/auto",
APIKey: p.ShengSuanYun.APIKey,
APIBase: p.ShengSuanYun.APIBase,
Proxy: p.ShengSuanYun.Proxy,
})
}
// DeepSeek
if p.DeepSeek.APIKey != "" || p.DeepSeek.APIBase != "" {
result = append(result, ModelConfig{
ModelName: "deepseek",
Model: "openai/deepseek-chat",
APIKey: p.DeepSeek.APIKey,
APIBase: p.DeepSeek.APIBase,
Proxy: p.DeepSeek.Proxy,
})
}
// Cerebras
if p.Cerebras.APIKey != "" || p.Cerebras.APIBase != "" {
result = append(result, ModelConfig{
ModelName: "cerebras",
Model: "cerebras/llama-3.3-70b",
APIKey: p.Cerebras.APIKey,
APIBase: p.Cerebras.APIBase,
Proxy: p.Cerebras.Proxy,
})
}
// VolcEngine (Doubao)
if p.VolcEngine.APIKey != "" || p.VolcEngine.APIBase != "" {
result = append(result, ModelConfig{
ModelName: "volcengine",
Model: "openai/doubao-pro",
APIKey: p.VolcEngine.APIKey,
APIBase: p.VolcEngine.APIBase,
Proxy: p.VolcEngine.Proxy,
})
}
// GitHub Copilot
if p.GitHubCopilot.APIKey != "" || p.GitHubCopilot.APIBase != "" || p.GitHubCopilot.ConnectMode != "" {
result = append(result, ModelConfig{
ModelName: "github-copilot",
Model: "github-copilot/gpt-4o",
APIBase: p.GitHubCopilot.APIBase,
ConnectMode: p.GitHubCopilot.ConnectMode,
})
}
// Antigravity
if p.Antigravity.APIKey != "" || p.Antigravity.AuthMethod != "" {
result = append(result, ModelConfig{
ModelName: "antigravity",
Model: "antigravity/gemini-2.0-flash",
APIKey: p.Antigravity.APIKey,
AuthMethod: p.Antigravity.AuthMethod,
})
}
// Qwen
if p.Qwen.APIKey != "" || p.Qwen.APIBase != "" {
result = append(result, ModelConfig{
ModelName: "qwen",
Model: "qwen/qwen-max",
APIKey: p.Qwen.APIKey,
APIBase: p.Qwen.APIBase,
Proxy: p.Qwen.Proxy,
})
}
return result
}
+136
View File
@@ -0,0 +1,136 @@
// PicoClaw - Ultra-lightweight personal AI agent
// License: MIT
//
// Copyright (c) 2026 PicoClaw contributors
package config
// DefaultConfig returns the default configuration for PicoClaw.
func DefaultConfig() *Config {
return &Config{
Agents: AgentsConfig{
Defaults: AgentDefaults{
Workspace: "~/.picoclaw/workspace",
RestrictToWorkspace: true,
Provider: "",
Model: "glm-4.7",
MaxTokens: 8192,
Temperature: 0.7,
MaxToolIterations: 20,
},
},
Channels: ChannelsConfig{
WhatsApp: WhatsAppConfig{
Enabled: false,
BridgeURL: "ws://localhost:3001",
AllowFrom: FlexibleStringSlice{},
},
Telegram: TelegramConfig{
Enabled: false,
Token: "",
AllowFrom: FlexibleStringSlice{},
},
Feishu: FeishuConfig{
Enabled: false,
AppID: "",
AppSecret: "",
EncryptKey: "",
VerificationToken: "",
AllowFrom: FlexibleStringSlice{},
},
Discord: DiscordConfig{
Enabled: false,
Token: "",
AllowFrom: FlexibleStringSlice{},
},
MaixCam: MaixCamConfig{
Enabled: false,
Host: "0.0.0.0",
Port: 18790,
AllowFrom: FlexibleStringSlice{},
},
QQ: QQConfig{
Enabled: false,
AppID: "",
AppSecret: "",
AllowFrom: FlexibleStringSlice{},
},
DingTalk: DingTalkConfig{
Enabled: false,
ClientID: "",
ClientSecret: "",
AllowFrom: FlexibleStringSlice{},
},
Slack: SlackConfig{
Enabled: false,
BotToken: "",
AppToken: "",
AllowFrom: FlexibleStringSlice{},
},
LINE: LINEConfig{
Enabled: false,
ChannelSecret: "",
ChannelAccessToken: "",
WebhookHost: "0.0.0.0",
WebhookPort: 18791,
WebhookPath: "/webhook/line",
AllowFrom: FlexibleStringSlice{},
},
OneBot: OneBotConfig{
Enabled: false,
WSUrl: "ws://127.0.0.1:3001",
AccessToken: "",
ReconnectInterval: 5,
GroupTriggerPrefix: []string{},
AllowFrom: FlexibleStringSlice{},
},
},
Providers: ProvidersConfig{
Anthropic: ProviderConfig{},
OpenAI: ProviderConfig{},
OpenRouter: ProviderConfig{},
Groq: ProviderConfig{},
Zhipu: ProviderConfig{},
VLLM: ProviderConfig{},
Gemini: ProviderConfig{},
Nvidia: ProviderConfig{},
Moonshot: ProviderConfig{},
ShengSuanYun: ProviderConfig{},
Cerebras: ProviderConfig{},
VolcEngine: ProviderConfig{},
},
Gateway: GatewayConfig{
Host: "0.0.0.0",
Port: 18790,
},
Tools: ToolsConfig{
Web: WebToolsConfig{
Brave: BraveConfig{
Enabled: false,
APIKey: "",
MaxResults: 5,
},
DuckDuckGo: DuckDuckGoConfig{
Enabled: true,
MaxResults: 5,
},
Perplexity: PerplexityConfig{
Enabled: false,
APIKey: "",
MaxResults: 5,
},
},
Cron: CronToolsConfig{
ExecTimeoutMinutes: 5, // default 5 minutes for LLM operations
},
},
Heartbeat: HeartbeatConfig{
Enabled: true,
Interval: 30, // default 30 minutes
},
Devices: DevicesConfig{
Enabled: false,
MonitorUSB: true,
},
}
}
+206
View File
@@ -0,0 +1,206 @@
// PicoClaw - Ultra-lightweight personal AI agent
// License: MIT
//
// Copyright (c) 2026 PicoClaw contributors
package config
// ConvertProvidersToModelList converts the old ProvidersConfig to a slice of ModelConfig.
// This enables backward compatibility with existing configurations.
func ConvertProvidersToModelList(cfg *Config) []ModelConfig {
if cfg == nil {
return nil
}
var result []ModelConfig
p := cfg.Providers
// OpenAI
if p.OpenAI.APIKey != "" || p.OpenAI.APIBase != "" {
result = append(result, ModelConfig{
ModelName: "openai",
Model: "openai/gpt-4o",
APIKey: p.OpenAI.APIKey,
APIBase: p.OpenAI.APIBase,
Proxy: p.OpenAI.Proxy,
AuthMethod: p.OpenAI.AuthMethod,
})
}
// Anthropic
if p.Anthropic.APIKey != "" || p.Anthropic.APIBase != "" {
result = append(result, ModelConfig{
ModelName: "anthropic",
Model: "anthropic/claude-3-sonnet",
APIKey: p.Anthropic.APIKey,
APIBase: p.Anthropic.APIBase,
Proxy: p.Anthropic.Proxy,
AuthMethod: p.Anthropic.AuthMethod,
})
}
// OpenRouter
if p.OpenRouter.APIKey != "" || p.OpenRouter.APIBase != "" {
result = append(result, ModelConfig{
ModelName: "openrouter",
Model: "openrouter/auto",
APIKey: p.OpenRouter.APIKey,
APIBase: p.OpenRouter.APIBase,
Proxy: p.OpenRouter.Proxy,
})
}
// Groq
if p.Groq.APIKey != "" || p.Groq.APIBase != "" {
result = append(result, ModelConfig{
ModelName: "groq",
Model: "groq/llama-3.1-70b-versatile",
APIKey: p.Groq.APIKey,
APIBase: p.Groq.APIBase,
Proxy: p.Groq.Proxy,
})
}
// Zhipu
if p.Zhipu.APIKey != "" || p.Zhipu.APIBase != "" {
result = append(result, ModelConfig{
ModelName: "zhipu",
Model: "openai/glm-4",
APIKey: p.Zhipu.APIKey,
APIBase: p.Zhipu.APIBase,
Proxy: p.Zhipu.Proxy,
})
}
// VLLM
if p.VLLM.APIKey != "" || p.VLLM.APIBase != "" {
result = append(result, ModelConfig{
ModelName: "vllm",
Model: "openai/auto",
APIKey: p.VLLM.APIKey,
APIBase: p.VLLM.APIBase,
Proxy: p.VLLM.Proxy,
})
}
// Gemini
if p.Gemini.APIKey != "" || p.Gemini.APIBase != "" {
result = append(result, ModelConfig{
ModelName: "gemini",
Model: "openai/gemini-pro",
APIKey: p.Gemini.APIKey,
APIBase: p.Gemini.APIBase,
Proxy: p.Gemini.Proxy,
})
}
// Nvidia
if p.Nvidia.APIKey != "" || p.Nvidia.APIBase != "" {
result = append(result, ModelConfig{
ModelName: "nvidia",
Model: "nvidia/meta/llama-3.1-8b-instruct",
APIKey: p.Nvidia.APIKey,
APIBase: p.Nvidia.APIBase,
Proxy: p.Nvidia.Proxy,
})
}
// Ollama
if p.Ollama.APIKey != "" || p.Ollama.APIBase != "" {
result = append(result, ModelConfig{
ModelName: "ollama",
Model: "ollama/llama3",
APIKey: p.Ollama.APIKey,
APIBase: p.Ollama.APIBase,
Proxy: p.Ollama.Proxy,
})
}
// Moonshot
if p.Moonshot.APIKey != "" || p.Moonshot.APIBase != "" {
result = append(result, ModelConfig{
ModelName: "moonshot",
Model: "moonshot/kimi",
APIKey: p.Moonshot.APIKey,
APIBase: p.Moonshot.APIBase,
Proxy: p.Moonshot.Proxy,
})
}
// ShengSuanYun
if p.ShengSuanYun.APIKey != "" || p.ShengSuanYun.APIBase != "" {
result = append(result, ModelConfig{
ModelName: "shengsuanyun",
Model: "openai/auto",
APIKey: p.ShengSuanYun.APIKey,
APIBase: p.ShengSuanYun.APIBase,
Proxy: p.ShengSuanYun.Proxy,
})
}
// DeepSeek
if p.DeepSeek.APIKey != "" || p.DeepSeek.APIBase != "" {
result = append(result, ModelConfig{
ModelName: "deepseek",
Model: "openai/deepseek-chat",
APIKey: p.DeepSeek.APIKey,
APIBase: p.DeepSeek.APIBase,
Proxy: p.DeepSeek.Proxy,
})
}
// Cerebras
if p.Cerebras.APIKey != "" || p.Cerebras.APIBase != "" {
result = append(result, ModelConfig{
ModelName: "cerebras",
Model: "cerebras/llama-3.3-70b",
APIKey: p.Cerebras.APIKey,
APIBase: p.Cerebras.APIBase,
Proxy: p.Cerebras.Proxy,
})
}
// VolcEngine (Doubao)
if p.VolcEngine.APIKey != "" || p.VolcEngine.APIBase != "" {
result = append(result, ModelConfig{
ModelName: "volcengine",
Model: "openai/doubao-pro",
APIKey: p.VolcEngine.APIKey,
APIBase: p.VolcEngine.APIBase,
Proxy: p.VolcEngine.Proxy,
})
}
// GitHub Copilot
if p.GitHubCopilot.APIKey != "" || p.GitHubCopilot.APIBase != "" || p.GitHubCopilot.ConnectMode != "" {
result = append(result, ModelConfig{
ModelName: "github-copilot",
Model: "github-copilot/gpt-4o",
APIBase: p.GitHubCopilot.APIBase,
ConnectMode: p.GitHubCopilot.ConnectMode,
})
}
// Antigravity
if p.Antigravity.APIKey != "" || p.Antigravity.AuthMethod != "" {
result = append(result, ModelConfig{
ModelName: "antigravity",
Model: "antigravity/gemini-2.0-flash",
APIKey: p.Antigravity.APIKey,
AuthMethod: p.Antigravity.AuthMethod,
})
}
// Qwen
if p.Qwen.APIKey != "" || p.Qwen.APIBase != "" {
result = append(result, ModelConfig{
ModelName: "qwen",
Model: "qwen/qwen-max",
APIKey: p.Qwen.APIKey,
APIBase: p.Qwen.APIBase,
Proxy: p.Qwen.Proxy,
})
}
return result
}
+177
View File
@@ -0,0 +1,177 @@
// PicoClaw - Ultra-lightweight personal AI agent
// License: MIT
//
// Copyright (c) 2026 PicoClaw contributors
package config
import (
"testing"
)
func TestConvertProvidersToModelList_OpenAI(t *testing.T) {
cfg := &Config{
Providers: ProvidersConfig{
OpenAI: ProviderConfig{
APIKey: "sk-test-key",
APIBase: "https://custom.api.com/v1",
},
},
}
result := ConvertProvidersToModelList(cfg)
if len(result) != 1 {
t.Fatalf("len(result) = %d, want 1", len(result))
}
if result[0].ModelName != "openai" {
t.Errorf("ModelName = %q, want %q", result[0].ModelName, "openai")
}
if result[0].Model != "openai/gpt-4o" {
t.Errorf("Model = %q, want %q", result[0].Model, "openai/gpt-4o")
}
if result[0].APIKey != "sk-test-key" {
t.Errorf("APIKey = %q, want %q", result[0].APIKey, "sk-test-key")
}
}
func TestConvertProvidersToModelList_Anthropic(t *testing.T) {
cfg := &Config{
Providers: ProvidersConfig{
Anthropic: ProviderConfig{
APIKey: "ant-key",
APIBase: "https://custom.anthropic.com",
},
},
}
result := ConvertProvidersToModelList(cfg)
if len(result) != 1 {
t.Fatalf("len(result) = %d, want 1", len(result))
}
if result[0].ModelName != "anthropic" {
t.Errorf("ModelName = %q, want %q", result[0].ModelName, "anthropic")
}
if result[0].Model != "anthropic/claude-3-sonnet" {
t.Errorf("Model = %q, want %q", result[0].Model, "anthropic/claude-3-sonnet")
}
}
func TestConvertProvidersToModelList_Multiple(t *testing.T) {
cfg := &Config{
Providers: ProvidersConfig{
OpenAI: ProviderConfig{APIKey: "openai-key"},
Groq: ProviderConfig{APIKey: "groq-key"},
Zhipu: ProviderConfig{APIKey: "zhipu-key"},
},
}
result := ConvertProvidersToModelList(cfg)
if len(result) != 3 {
t.Fatalf("len(result) = %d, want 3", len(result))
}
// Check that all providers are present
found := make(map[string]bool)
for _, mc := range result {
found[mc.ModelName] = true
}
for _, name := range []string{"openai", "groq", "zhipu"} {
if !found[name] {
t.Errorf("Missing provider %q in result", name)
}
}
}
func TestConvertProvidersToModelList_Empty(t *testing.T) {
cfg := &Config{
Providers: ProvidersConfig{},
}
result := ConvertProvidersToModelList(cfg)
if len(result) != 0 {
t.Errorf("len(result) = %d, want 0", len(result))
}
}
func TestConvertProvidersToModelList_Nil(t *testing.T) {
result := ConvertProvidersToModelList(nil)
if result != nil {
t.Errorf("result = %v, want nil", result)
}
}
func TestConvertProvidersToModelList_AllProviders(t *testing.T) {
cfg := &Config{
Providers: ProvidersConfig{
OpenAI: ProviderConfig{APIKey: "key1"},
Anthropic: ProviderConfig{APIKey: "key2"},
OpenRouter: ProviderConfig{APIKey: "key3"},
Groq: ProviderConfig{APIKey: "key4"},
Zhipu: ProviderConfig{APIKey: "key5"},
VLLM: ProviderConfig{APIKey: "key6"},
Gemini: ProviderConfig{APIKey: "key7"},
Nvidia: ProviderConfig{APIKey: "key8"},
Ollama: ProviderConfig{APIKey: "key9"},
Moonshot: ProviderConfig{APIKey: "key10"},
ShengSuanYun: ProviderConfig{APIKey: "key11"},
DeepSeek: ProviderConfig{APIKey: "key12"},
Cerebras: ProviderConfig{APIKey: "key13"},
VolcEngine: ProviderConfig{APIKey: "key14"},
GitHubCopilot: ProviderConfig{ConnectMode: "grpc"},
Antigravity: ProviderConfig{AuthMethod: "oauth"},
Qwen: ProviderConfig{APIKey: "key17"},
},
}
result := ConvertProvidersToModelList(cfg)
// All 17 providers should be converted
if len(result) != 17 {
t.Errorf("len(result) = %d, want 17", len(result))
}
}
func TestConvertProvidersToModelList_Proxy(t *testing.T) {
cfg := &Config{
Providers: ProvidersConfig{
OpenAI: ProviderConfig{
APIKey: "key",
Proxy: "http://proxy:8080",
},
},
}
result := ConvertProvidersToModelList(cfg)
if len(result) != 1 {
t.Fatalf("len(result) = %d, want 1", len(result))
}
if result[0].Proxy != "http://proxy:8080" {
t.Errorf("Proxy = %q, want %q", result[0].Proxy, "http://proxy:8080")
}
}
func TestConvertProvidersToModelList_AuthMethod(t *testing.T) {
cfg := &Config{
Providers: ProvidersConfig{
OpenAI: ProviderConfig{
AuthMethod: "oauth",
},
},
}
result := ConvertProvidersToModelList(cfg)
if len(result) != 0 {
t.Errorf("len(result) = %d, want 0 (AuthMethod alone should not create entry)", len(result))
}
}
+204
View File
@@ -0,0 +1,204 @@
// PicoClaw - Ultra-lightweight personal AI agent
// License: MIT
//
// Copyright (c) 2026 PicoClaw contributors
package config
import (
"sync"
"testing"
)
func TestGetModelConfig_Found(t *testing.T) {
cfg := &Config{
ModelList: []ModelConfig{
{ModelName: "test-model", Model: "openai/gpt-4o", APIKey: "key1"},
{ModelName: "other-model", Model: "anthropic/claude", APIKey: "key2"},
},
}
result, err := cfg.GetModelConfig("test-model")
if err != nil {
t.Fatalf("GetModelConfig() error = %v", err)
}
if result.Model != "openai/gpt-4o" {
t.Errorf("Model = %q, want %q", result.Model, "openai/gpt-4o")
}
}
func TestGetModelConfig_NotFound(t *testing.T) {
cfg := &Config{
ModelList: []ModelConfig{
{ModelName: "test-model", Model: "openai/gpt-4o", APIKey: "key1"},
},
}
_, err := cfg.GetModelConfig("nonexistent")
if err == nil {
t.Fatal("GetModelConfig() expected error for nonexistent model")
}
}
func TestGetModelConfig_EmptyList(t *testing.T) {
cfg := &Config{
ModelList: []ModelConfig{},
}
_, err := cfg.GetModelConfig("any-model")
if err == nil {
t.Fatal("GetModelConfig() expected error for empty model list")
}
}
func TestGetModelConfig_RoundRobin(t *testing.T) {
cfg := &Config{
ModelList: []ModelConfig{
{ModelName: "lb-model", Model: "openai/gpt-4o-1", APIKey: "key1"},
{ModelName: "lb-model", Model: "openai/gpt-4o-2", APIKey: "key2"},
{ModelName: "lb-model", Model: "openai/gpt-4o-3", APIKey: "key3"},
},
}
// Test round-robin distribution
results := make(map[string]int)
for i := 0; i < 30; i++ {
result, err := cfg.GetModelConfig("lb-model")
if err != nil {
t.Fatalf("GetModelConfig() error = %v", err)
}
results[result.Model]++
}
// Each model should appear roughly 10 times (30 calls / 3 models)
for model, count := range results {
if count < 5 || count > 15 {
t.Errorf("Model %s appeared %d times, expected ~10", model, count)
}
}
}
func TestGetModelConfig_Concurrent(t *testing.T) {
cfg := &Config{
ModelList: []ModelConfig{
{ModelName: "concurrent-model", Model: "openai/gpt-4o-1", APIKey: "key1"},
{ModelName: "concurrent-model", Model: "openai/gpt-4o-2", APIKey: "key2"},
},
}
const goroutines = 100
const iterations = 10
var wg sync.WaitGroup
errors := make(chan error, goroutines*iterations)
for i := 0; i < goroutines; i++ {
wg.Add(1)
go func() {
defer wg.Done()
for j := 0; j < iterations; j++ {
_, err := cfg.GetModelConfig("concurrent-model")
if err != nil {
errors <- err
}
}
}()
}
wg.Wait()
close(errors)
for err := range errors {
t.Errorf("Concurrent GetModelConfig() error: %v", err)
}
}
func TestModelConfig_Validate(t *testing.T) {
tests := []struct {
name string
config ModelConfig
wantErr bool
}{
{
name: "valid config",
config: ModelConfig{
ModelName: "test",
Model: "openai/gpt-4o",
},
wantErr: false,
},
{
name: "missing model_name",
config: ModelConfig{
Model: "openai/gpt-4o",
},
wantErr: true,
},
{
name: "missing model",
config: ModelConfig{
ModelName: "test",
},
wantErr: true,
},
{
name: "empty config",
config: ModelConfig{},
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := tt.config.Validate()
if (err != nil) != tt.wantErr {
t.Errorf("Validate() error = %v, wantErr %v", err, tt.wantErr)
}
})
}
}
func TestConfig_ValidateModelList(t *testing.T) {
tests := []struct {
name string
config *Config
wantErr bool
}{
{
name: "valid list",
config: &Config{
ModelList: []ModelConfig{
{ModelName: "test1", Model: "openai/gpt-4o"},
{ModelName: "test2", Model: "anthropic/claude"},
},
},
wantErr: false,
},
{
name: "invalid entry",
config: &Config{
ModelList: []ModelConfig{
{ModelName: "test1", Model: "openai/gpt-4o"},
{ModelName: "", Model: "anthropic/claude"}, // missing model_name
},
},
wantErr: true,
},
{
name: "empty list",
config: &Config{
ModelList: []ModelConfig{},
},
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := tt.config.ValidateModelList()
if (err != nil) != tt.wantErr {
t.Errorf("ValidateModelList() error = %v, wantErr %v", err, tt.wantErr)
}
})
}
}
+3 -3
View File
@@ -180,8 +180,8 @@ func TestConvertConfig(t *testing.T) {
t.Run("unsupported provider warning", func(t *testing.T) {
data := map[string]interface{}{
"providers": map[string]interface{}{
"deepseek": map[string]interface{}{
"api_key": "sk-deep-test",
"unknown_provider": map[string]interface{}{
"api_key": "sk-test",
},
},
}
@@ -193,7 +193,7 @@ func TestConvertConfig(t *testing.T) {
if len(warnings) != 1 {
t.Fatalf("expected 1 warning, got %d", len(warnings))
}
if warnings[0] != "Provider 'deepseek' not supported in PicoClaw, skipping" {
if warnings[0] != "Provider 'unknown_provider' not supported in PicoClaw, skipping" {
t.Errorf("unexpected warning: %s", warnings[0])
}
})
+4 -4
View File
@@ -419,7 +419,7 @@ func TestCreateProvider_ClaudeCli(t *testing.T) {
cfg.Agents.Defaults.Provider = "claude-cli"
cfg.Agents.Defaults.Workspace = "/test/ws"
provider, err := CreateProvider(cfg)
provider, _, err := CreateProvider(cfg)
if err != nil {
t.Fatalf("CreateProvider(claude-cli) error = %v", err)
}
@@ -437,7 +437,7 @@ func TestCreateProvider_ClaudeCode(t *testing.T) {
cfg := config.DefaultConfig()
cfg.Agents.Defaults.Provider = "claude-code"
provider, err := CreateProvider(cfg)
provider, _, err := CreateProvider(cfg)
if err != nil {
t.Fatalf("CreateProvider(claude-code) error = %v", err)
}
@@ -450,7 +450,7 @@ func TestCreateProvider_ClaudeCodec(t *testing.T) {
cfg := config.DefaultConfig()
cfg.Agents.Defaults.Provider = "claudecode"
provider, err := CreateProvider(cfg)
provider, _, err := CreateProvider(cfg)
if err != nil {
t.Fatalf("CreateProvider(claudecode) error = %v", err)
}
@@ -464,7 +464,7 @@ func TestCreateProvider_ClaudeCliDefaultWorkspace(t *testing.T) {
cfg.Agents.Defaults.Provider = "claude-cli"
cfg.Agents.Defaults.Workspace = ""
provider, err := CreateProvider(cfg)
provider, _, err := CreateProvider(cfg)
if err != nil {
t.Fatalf("CreateProvider error = %v", err)
}
+17 -12
View File
@@ -32,13 +32,14 @@ func ExtractProtocol(model string) (protocol, modelID string) {
// CreateProviderFromConfig creates a provider based on the ModelConfig.
// It uses the protocol prefix in the Model field to determine which provider to create.
// Supported protocols: openai, anthropic, antigravity, claude-cli, codex-cli, github-copilot
func CreateProviderFromConfig(cfg *config.ModelConfig) (LLMProvider, error) {
// Returns the provider, the model ID (without protocol prefix), and any error.
func CreateProviderFromConfig(cfg *config.ModelConfig) (LLMProvider, string, error) {
if cfg == nil {
return nil, fmt.Errorf("config is nil")
return nil, "", fmt.Errorf("config is nil")
}
if cfg.Model == "" {
return nil, fmt.Errorf("model is required")
return nil, "", fmt.Errorf("model is required")
}
protocol, modelID := ExtractProtocol(cfg.Model)
@@ -49,36 +50,36 @@ func CreateProviderFromConfig(cfg *config.ModelConfig) (LLMProvider, error) {
"volcengine", "vllm", "qwen":
// All OpenAI-compatible HTTP providers
if cfg.APIKey == "" && cfg.APIBase == "" {
return nil, fmt.Errorf("api_key or api_base is required for HTTP-based protocol %q", protocol)
return nil, "", fmt.Errorf("api_key or api_base is required for HTTP-based protocol %q", protocol)
}
apiBase := cfg.APIBase
if apiBase == "" {
apiBase = getDefaultAPIBase(protocol)
}
return NewHTTPProvider(cfg.APIKey, apiBase, cfg.Proxy), nil
return NewHTTPProvider(cfg.APIKey, apiBase, cfg.Proxy), modelID, nil
case "anthropic":
if cfg.AuthMethod == "oauth" || cfg.AuthMethod == "token" {
// Use Claude SDK with token
return NewClaudeProvider(cfg.APIKey), nil
return NewClaudeProvider(cfg.APIKey), modelID, nil
}
// Use HTTP API
apiBase := cfg.APIBase
if apiBase == "" {
apiBase = "https://api.anthropic.com/v1"
}
return NewHTTPProvider(cfg.APIKey, apiBase, cfg.Proxy), nil
return NewHTTPProvider(cfg.APIKey, apiBase, cfg.Proxy), modelID, nil
case "antigravity":
return NewAntigravityProvider(), nil
return NewAntigravityProvider(), modelID, nil
case "claude-cli", "claudecli":
workspace := "."
return NewClaudeCliProvider(workspace), nil
return NewClaudeCliProvider(workspace), modelID, nil
case "codex-cli", "codexcli":
workspace := "."
return NewCodexCliProvider(workspace), nil
return NewCodexCliProvider(workspace), modelID, nil
case "github-copilot", "copilot":
apiBase := cfg.APIBase
@@ -89,10 +90,14 @@ func CreateProviderFromConfig(cfg *config.ModelConfig) (LLMProvider, error) {
if connectMode == "" {
connectMode = "grpc"
}
return NewGitHubCopilotProvider(apiBase, connectMode, modelID)
provider, err := NewGitHubCopilotProvider(apiBase, connectMode, modelID)
if err != nil {
return nil, "", err
}
return provider, modelID, nil
default:
return nil, fmt.Errorf("unknown protocol %q in model %q", protocol, cfg.Model)
return nil, "", fmt.Errorf("unknown protocol %q in model %q", protocol, cfg.Model)
}
}
+250
View File
@@ -0,0 +1,250 @@
// PicoClaw - Ultra-lightweight personal AI agent
// License: MIT
//
// Copyright (c) 2026 PicoClaw contributors
package providers
import (
"testing"
"github.com/sipeed/picoclaw/pkg/config"
)
func TestExtractProtocol(t *testing.T) {
tests := []struct {
name string
model string
wantProtocol string
wantModelID string
}{
{
name: "openai with prefix",
model: "openai/gpt-4o",
wantProtocol: "openai",
wantModelID: "gpt-4o",
},
{
name: "anthropic with prefix",
model: "anthropic/claude-3-sonnet",
wantProtocol: "anthropic",
wantModelID: "claude-3-sonnet",
},
{
name: "no prefix - defaults to openai",
model: "gpt-4o",
wantProtocol: "openai",
wantModelID: "gpt-4o",
},
{
name: "groq with prefix",
model: "groq/llama-3.1-70b",
wantProtocol: "groq",
wantModelID: "llama-3.1-70b",
},
{
name: "empty string",
model: "",
wantProtocol: "openai",
wantModelID: "",
},
{
name: "with whitespace",
model: " openai/gpt-4 ",
wantProtocol: "openai",
wantModelID: "gpt-4",
},
{
name: "multiple slashes",
model: "nvidia/meta/llama-3.1-8b",
wantProtocol: "nvidia",
wantModelID: "meta/llama-3.1-8b",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
protocol, modelID := ExtractProtocol(tt.model)
if protocol != tt.wantProtocol {
t.Errorf("ExtractProtocol(%q) protocol = %q, want %q", tt.model, protocol, tt.wantProtocol)
}
if modelID != tt.wantModelID {
t.Errorf("ExtractProtocol(%q) modelID = %q, want %q", tt.model, modelID, tt.wantModelID)
}
})
}
}
func TestCreateProviderFromConfig_OpenAI(t *testing.T) {
cfg := &config.ModelConfig{
ModelName: "test-openai",
Model: "openai/gpt-4o",
APIKey: "test-key",
APIBase: "https://api.example.com/v1",
}
provider, modelID, err := CreateProviderFromConfig(cfg)
if err != nil {
t.Fatalf("CreateProviderFromConfig() error = %v", err)
}
if provider == nil {
t.Fatal("CreateProviderFromConfig() returned nil provider")
}
if modelID != "gpt-4o" {
t.Errorf("modelID = %q, want %q", modelID, "gpt-4o")
}
}
func TestCreateProviderFromConfig_DefaultAPIBase(t *testing.T) {
tests := []struct {
name string
protocol string
wantBase string
}{
{"openai", "openai", "https://api.openai.com/v1"},
{"groq", "groq", "https://api.groq.com/openai/v1"},
{"openrouter", "openrouter", "https://openrouter.ai/api/v1"},
{"cerebras", "cerebras", "https://api.cerebras.ai/v1"},
{"qwen", "qwen", "https://dashscope.aliyuncs.com/compatible-mode/v1"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
cfg := &config.ModelConfig{
ModelName: "test-" + tt.protocol,
Model: tt.protocol + "/test-model",
APIKey: "test-key",
}
provider, _, err := CreateProviderFromConfig(cfg)
if err != nil {
t.Fatalf("CreateProviderFromConfig() error = %v", err)
}
httpProvider, ok := provider.(*HTTPProvider)
if !ok {
t.Fatalf("expected *HTTPProvider, got %T", provider)
}
if httpProvider.apiBase != tt.wantBase {
t.Errorf("apiBase = %q, want %q", httpProvider.apiBase, tt.wantBase)
}
})
}
}
func TestCreateProviderFromConfig_Anthropic(t *testing.T) {
cfg := &config.ModelConfig{
ModelName: "test-anthropic",
Model: "anthropic/claude-3-sonnet",
APIKey: "test-key",
}
provider, modelID, err := CreateProviderFromConfig(cfg)
if err != nil {
t.Fatalf("CreateProviderFromConfig() error = %v", err)
}
if provider == nil {
t.Fatal("CreateProviderFromConfig() returned nil provider")
}
if modelID != "claude-3-sonnet" {
t.Errorf("modelID = %q, want %q", modelID, "claude-3-sonnet")
}
}
func TestCreateProviderFromConfig_Antigravity(t *testing.T) {
cfg := &config.ModelConfig{
ModelName: "test-antigravity",
Model: "antigravity/gemini-2.0-flash",
}
provider, modelID, err := CreateProviderFromConfig(cfg)
if err != nil {
t.Fatalf("CreateProviderFromConfig() error = %v", err)
}
if provider == nil {
t.Fatal("CreateProviderFromConfig() returned nil provider")
}
if modelID != "gemini-2.0-flash" {
t.Errorf("modelID = %q, want %q", modelID, "gemini-2.0-flash")
}
}
func TestCreateProviderFromConfig_ClaudeCLI(t *testing.T) {
cfg := &config.ModelConfig{
ModelName: "test-claude-cli",
Model: "claude-cli/claude-sonnet-4-20250514",
}
provider, modelID, err := CreateProviderFromConfig(cfg)
if err != nil {
t.Fatalf("CreateProviderFromConfig() error = %v", err)
}
if provider == nil {
t.Fatal("CreateProviderFromConfig() returned nil provider")
}
if modelID != "claude-sonnet-4-20250514" {
t.Errorf("modelID = %q, want %q", modelID, "claude-sonnet-4-20250514")
}
}
func TestCreateProviderFromConfig_CodexCLI(t *testing.T) {
cfg := &config.ModelConfig{
ModelName: "test-codex-cli",
Model: "codex-cli/codex",
}
provider, modelID, err := CreateProviderFromConfig(cfg)
if err != nil {
t.Fatalf("CreateProviderFromConfig() error = %v", err)
}
if provider == nil {
t.Fatal("CreateProviderFromConfig() returned nil provider")
}
if modelID != "codex" {
t.Errorf("modelID = %q, want %q", modelID, "codex")
}
}
func TestCreateProviderFromConfig_MissingAPIKey(t *testing.T) {
cfg := &config.ModelConfig{
ModelName: "test-no-key",
Model: "openai/gpt-4o",
}
_, _, err := CreateProviderFromConfig(cfg)
if err == nil {
t.Fatal("CreateProviderFromConfig() expected error for missing API key")
}
}
func TestCreateProviderFromConfig_UnknownProtocol(t *testing.T) {
cfg := &config.ModelConfig{
ModelName: "test-unknown",
Model: "unknown-protocol/model",
APIKey: "test-key",
}
_, _, err := CreateProviderFromConfig(cfg)
if err == nil {
t.Fatal("CreateProviderFromConfig() expected error for unknown protocol")
}
}
func TestCreateProviderFromConfig_NilConfig(t *testing.T) {
_, _, err := CreateProviderFromConfig(nil)
if err == nil {
t.Fatal("CreateProviderFromConfig(nil) expected error")
}
}
func TestCreateProviderFromConfig_EmptyModel(t *testing.T) {
cfg := &config.ModelConfig{
ModelName: "test-empty",
Model: "",
}
_, _, err := CreateProviderFromConfig(cfg)
if err == nil {
t.Fatal("CreateProviderFromConfig() expected error for empty model")
}
}
+6 -332
View File
@@ -16,9 +16,6 @@ import (
"net/url"
"strings"
"time"
"github.com/sipeed/picoclaw/pkg/auth"
"github.com/sipeed/picoclaw/pkg/config"
)
type HTTPProvider struct {
@@ -161,13 +158,15 @@ func (p *HTTPProvider) parseResponse(body []byte) (*LLMResponse, error) {
arguments := make(map[string]interface{})
name := ""
thoughtSignature := ""
argsStr := ""
if tc.Function != nil {
name = tc.Function.Name
thoughtSignature = tc.Function.ThoughtSignature
if tc.Function.Arguments != "" {
if err := json.Unmarshal([]byte(tc.Function.Arguments), &arguments); err != nil {
arguments["raw"] = tc.Function.Arguments
argsStr = tc.Function.Arguments
if argsStr != "" {
if err := json.Unmarshal([]byte(argsStr), &arguments); err != nil {
arguments["raw"] = argsStr
}
}
}
@@ -177,7 +176,7 @@ func (p *HTTPProvider) parseResponse(body []byte) (*LLMResponse, error) {
Type: tc.Type,
Function: &FunctionCall{
Name: name,
Arguments: tc.Function.Arguments,
Arguments: argsStr,
ThoughtSignature: thoughtSignature,
},
Name: name,
@@ -196,328 +195,3 @@ func (p *HTTPProvider) parseResponse(body []byte) (*LLMResponse, error) {
func (p *HTTPProvider) GetDefaultModel() string {
return ""
}
func createClaudeAuthProvider() (LLMProvider, error) {
cred, err := auth.GetCredential("anthropic")
if err != nil {
return nil, fmt.Errorf("loading auth credentials: %w", err)
}
if cred == nil {
return nil, fmt.Errorf("no credentials for anthropic. Run: picoclaw auth login --provider anthropic")
}
return NewClaudeProviderWithTokenSource(cred.AccessToken, createClaudeTokenSource()), nil
}
func createCodexAuthProvider() (LLMProvider, error) {
cred, err := auth.GetCredential("openai")
if err != nil {
return nil, fmt.Errorf("loading auth credentials: %w", err)
}
if cred == nil {
return nil, fmt.Errorf("no credentials for openai. Run: picoclaw auth login --provider openai")
}
return NewCodexProviderWithTokenSource(cred.AccessToken, cred.AccountID, createCodexTokenSource()), nil
}
func CreateProvider(cfg *config.Config) (LLMProvider, error) {
model := cfg.Agents.Defaults.Model
// First, try to use model_list configuration
if len(cfg.ModelList) > 0 {
// Try to get config by model name first
modelCfg, err := cfg.GetModelConfig(model)
if err == nil {
// Found in model_list, use factory to create provider
provider, err := CreateProviderFromConfig(modelCfg)
if err != nil {
return nil, fmt.Errorf("failed to create provider from model_list: %w", err)
}
return provider, nil
}
// Model not found in model_list, fall through to providers config
}
// Log deprecation warning if using old providers config
if cfg.HasProvidersConfig() && len(cfg.ModelList) == 0 {
fmt.Println("WARNING: providers config is deprecated, please migrate to model_list")
}
providerName := strings.ToLower(cfg.Agents.Defaults.Provider)
var apiKey, apiBase, proxy string
lowerModel := strings.ToLower(model)
// First, try to use explicitly configured provider
if providerName != "" {
switch providerName {
case "groq":
if cfg.Providers.Groq.APIKey != "" {
apiKey = cfg.Providers.Groq.APIKey
apiBase = cfg.Providers.Groq.APIBase
if apiBase == "" {
apiBase = "https://api.groq.com/openai/v1"
}
}
case "openai", "gpt":
if cfg.Providers.OpenAI.APIKey != "" || cfg.Providers.OpenAI.AuthMethod != "" {
if cfg.Providers.OpenAI.AuthMethod == "codex-cli" {
return NewCodexProviderWithTokenSource("", "", CreateCodexCliTokenSource()), nil
}
if cfg.Providers.OpenAI.AuthMethod == "oauth" || cfg.Providers.OpenAI.AuthMethod == "token" {
return createCodexAuthProvider()
}
apiKey = cfg.Providers.OpenAI.APIKey
apiBase = cfg.Providers.OpenAI.APIBase
if apiBase == "" {
apiBase = "https://api.openai.com/v1"
}
}
case "anthropic", "claude":
if cfg.Providers.Anthropic.APIKey != "" || cfg.Providers.Anthropic.AuthMethod != "" {
if cfg.Providers.Anthropic.AuthMethod == "oauth" || cfg.Providers.Anthropic.AuthMethod == "token" {
return createClaudeAuthProvider()
}
apiKey = cfg.Providers.Anthropic.APIKey
apiBase = cfg.Providers.Anthropic.APIBase
if apiBase == "" {
apiBase = "https://api.anthropic.com/v1"
}
}
case "openrouter":
if cfg.Providers.OpenRouter.APIKey != "" {
apiKey = cfg.Providers.OpenRouter.APIKey
if cfg.Providers.OpenRouter.APIBase != "" {
apiBase = cfg.Providers.OpenRouter.APIBase
} else {
apiBase = "https://openrouter.ai/api/v1"
}
}
case "zhipu", "glm":
if cfg.Providers.Zhipu.APIKey != "" {
apiKey = cfg.Providers.Zhipu.APIKey
apiBase = cfg.Providers.Zhipu.APIBase
if apiBase == "" {
apiBase = "https://open.bigmodel.cn/api/paas/v4"
}
}
case "gemini", "google":
if cfg.Providers.Gemini.APIKey != "" {
apiKey = cfg.Providers.Gemini.APIKey
apiBase = cfg.Providers.Gemini.APIBase
if apiBase == "" {
apiBase = "https://generativelanguage.googleapis.com/v1beta"
}
}
case "vllm":
if cfg.Providers.VLLM.APIBase != "" {
apiKey = cfg.Providers.VLLM.APIKey
apiBase = cfg.Providers.VLLM.APIBase
}
case "shengsuanyun":
if cfg.Providers.ShengSuanYun.APIKey != "" {
apiKey = cfg.Providers.ShengSuanYun.APIKey
apiBase = cfg.Providers.ShengSuanYun.APIBase
if apiBase == "" {
apiBase = "https://router.shengsuanyun.com/api/v1"
}
}
case "claude-cli", "claudecode", "claude-code":
workspace := cfg.WorkspacePath()
if workspace == "" {
workspace = "."
}
return NewClaudeCliProvider(workspace), nil
case "codex-cli", "codex-code":
workspace := cfg.WorkspacePath()
if workspace == "" {
workspace = "."
}
return NewCodexCliProvider(workspace), nil
case "cerebras":
if cfg.Providers.Cerebras.APIKey != "" {
apiKey = cfg.Providers.Cerebras.APIKey
apiBase = cfg.Providers.Cerebras.APIBase
if apiBase == "" {
apiBase = "https://api.cerebras.ai/v1"
}
}
case "deepseek":
if cfg.Providers.DeepSeek.APIKey != "" {
apiKey = cfg.Providers.DeepSeek.APIKey
apiBase = cfg.Providers.DeepSeek.APIBase
if apiBase == "" {
apiBase = "https://api.deepseek.com/v1"
}
if model != "deepseek-chat" && model != "deepseek-reasoner" {
model = "deepseek-chat"
}
}
case "qwen":
if cfg.Providers.Qwen.APIKey != "" {
apiKey = cfg.Providers.Qwen.APIKey
apiBase = cfg.Providers.Qwen.APIBase
if apiBase == "" {
apiBase = "https://dashscope.aliyuncs.com/compatible-mode/v1"
}
}
case "github_copilot", "copilot":
if cfg.Providers.GitHubCopilot.APIBase != "" {
apiBase = cfg.Providers.GitHubCopilot.APIBase
} else {
apiBase = "localhost:4321"
}
return NewGitHubCopilotProvider(apiBase, cfg.Providers.GitHubCopilot.ConnectMode, model)
case "antigravity", "google-antigravity":
return NewAntigravityProvider(), nil
case "volcengine", "doubao":
if cfg.Providers.VolcEngine.APIKey != "" {
apiKey = cfg.Providers.VolcEngine.APIKey
apiBase = cfg.Providers.VolcEngine.APIBase
if apiBase == "" {
apiBase = "https://ark.cn-beijing.volces.com/api/v3"
}
}
}
}
// Fallback: detect provider from model name
if apiKey == "" && apiBase == "" {
switch {
case (strings.Contains(lowerModel, "kimi") || strings.Contains(lowerModel, "moonshot") || strings.HasPrefix(model, "moonshot/")) && cfg.Providers.Moonshot.APIKey != "":
apiKey = cfg.Providers.Moonshot.APIKey
apiBase = cfg.Providers.Moonshot.APIBase
proxy = cfg.Providers.Moonshot.Proxy
if apiBase == "" {
apiBase = "https://api.moonshot.cn/v1"
}
case strings.HasPrefix(model, "openrouter/") || strings.HasPrefix(model, "anthropic/") || strings.HasPrefix(model, "openai/") || strings.HasPrefix(model, "meta-llama/") || strings.HasPrefix(model, "deepseek/") || strings.HasPrefix(model, "google/"):
apiKey = cfg.Providers.OpenRouter.APIKey
proxy = cfg.Providers.OpenRouter.Proxy
if cfg.Providers.OpenRouter.APIBase != "" {
apiBase = cfg.Providers.OpenRouter.APIBase
} else {
apiBase = "https://openrouter.ai/api/v1"
}
case (strings.Contains(lowerModel, "claude") || strings.HasPrefix(model, "anthropic/")) && (cfg.Providers.Anthropic.APIKey != "" || cfg.Providers.Anthropic.AuthMethod != ""):
if cfg.Providers.Anthropic.AuthMethod == "oauth" || cfg.Providers.Anthropic.AuthMethod == "token" {
return createClaudeAuthProvider()
}
apiKey = cfg.Providers.Anthropic.APIKey
apiBase = cfg.Providers.Anthropic.APIBase
proxy = cfg.Providers.Anthropic.Proxy
if apiBase == "" {
apiBase = "https://api.anthropic.com/v1"
}
case (strings.Contains(lowerModel, "gpt") || strings.HasPrefix(model, "openai/")) && (cfg.Providers.OpenAI.APIKey != "" || cfg.Providers.OpenAI.AuthMethod != ""):
if cfg.Providers.OpenAI.AuthMethod == "oauth" || cfg.Providers.OpenAI.AuthMethod == "token" {
return createCodexAuthProvider()
}
apiKey = cfg.Providers.OpenAI.APIKey
apiBase = cfg.Providers.OpenAI.APIBase
proxy = cfg.Providers.OpenAI.Proxy
if apiBase == "" {
apiBase = "https://api.openai.com/v1"
}
case (strings.Contains(lowerModel, "gemini") || strings.HasPrefix(model, "google/")) && cfg.Providers.Gemini.APIKey != "":
apiKey = cfg.Providers.Gemini.APIKey
apiBase = cfg.Providers.Gemini.APIBase
proxy = cfg.Providers.Gemini.Proxy
if apiBase == "" {
apiBase = "https://generativelanguage.googleapis.com/v1beta"
}
case (strings.Contains(lowerModel, "glm") || strings.Contains(lowerModel, "zhipu") || strings.Contains(lowerModel, "zai")) && cfg.Providers.Zhipu.APIKey != "":
apiKey = cfg.Providers.Zhipu.APIKey
apiBase = cfg.Providers.Zhipu.APIBase
proxy = cfg.Providers.Zhipu.Proxy
if apiBase == "" {
apiBase = "https://open.bigmodel.cn/api/paas/v4"
}
case (strings.Contains(lowerModel, "groq") || strings.HasPrefix(model, "groq/")) && cfg.Providers.Groq.APIKey != "":
apiKey = cfg.Providers.Groq.APIKey
apiBase = cfg.Providers.Groq.APIBase
proxy = cfg.Providers.Groq.Proxy
if apiBase == "" {
apiBase = "https://api.groq.com/openai/v1"
}
case (strings.Contains(lowerModel, "qwen") || strings.HasPrefix(model, "qwen/")) && cfg.Providers.Qwen.APIKey != "":
apiKey = cfg.Providers.Qwen.APIKey
apiBase = cfg.Providers.Qwen.APIBase
proxy = cfg.Providers.Qwen.Proxy
if apiBase == "" {
apiBase = "https://dashscope.aliyuncs.com/compatible-mode/v1"
}
case (strings.Contains(lowerModel, "nvidia") || strings.HasPrefix(model, "nvidia/")) && cfg.Providers.Nvidia.APIKey != "":
apiKey = cfg.Providers.Nvidia.APIKey
apiBase = cfg.Providers.Nvidia.APIBase
proxy = cfg.Providers.Nvidia.Proxy
if apiBase == "" {
apiBase = "https://integrate.api.nvidia.com/v1"
}
case (strings.Contains(lowerModel, "cerebras") || strings.HasPrefix(model, "cerebras/")) && cfg.Providers.Cerebras.APIKey != "":
apiKey = cfg.Providers.Cerebras.APIKey
apiBase = cfg.Providers.Cerebras.APIBase
proxy = cfg.Providers.Cerebras.Proxy
if apiBase == "" {
apiBase = "https://api.cerebras.ai/v1"
}
case (strings.Contains(lowerModel, "ollama") || strings.HasPrefix(model, "ollama/")) && cfg.Providers.Ollama.APIKey != "":
fmt.Println("Ollama provider selected based on model name prefix")
apiKey = cfg.Providers.Ollama.APIKey
apiBase = cfg.Providers.Ollama.APIBase
proxy = cfg.Providers.Ollama.Proxy
if apiBase == "" {
apiBase = "http://localhost:11434/v1"
}
fmt.Println("Ollama apiBase:", apiBase)
case (strings.Contains(lowerModel, "doubao") || strings.HasPrefix(lowerModel, "doubao") || strings.Contains(lowerModel, "volcengine")) && cfg.Providers.VolcEngine.APIKey != "":
apiKey = cfg.Providers.VolcEngine.APIKey
apiBase = cfg.Providers.VolcEngine.APIBase
proxy = cfg.Providers.VolcEngine.Proxy
if apiBase == "" {
apiBase = "https://ark.cn-beijing.volces.com/api/v3"
}
case cfg.Providers.VLLM.APIBase != "":
apiKey = cfg.Providers.VLLM.APIKey
apiBase = cfg.Providers.VLLM.APIBase
proxy = cfg.Providers.VLLM.Proxy
default:
if cfg.Providers.OpenRouter.APIKey != "" {
apiKey = cfg.Providers.OpenRouter.APIKey
proxy = cfg.Providers.OpenRouter.Proxy
if cfg.Providers.OpenRouter.APIBase != "" {
apiBase = cfg.Providers.OpenRouter.APIBase
} else {
apiBase = "https://openrouter.ai/api/v1"
}
} else {
return nil, fmt.Errorf("no API key configured for model: %s", model)
}
}
}
if apiKey == "" && !strings.HasPrefix(model, "bedrock/") {
return nil, fmt.Errorf("no API key configured for provider (model: %s)", model)
}
if apiBase == "" {
return nil, fmt.Errorf("no API base configured for provider (model: %s)", model)
}
return NewHTTPProvider(apiKey, apiBase, proxy), nil
}
+349
View File
@@ -0,0 +1,349 @@
// PicoClaw - Ultra-lightweight personal AI agent
// License: MIT
//
// Copyright (c) 2026 PicoClaw contributors
package providers
import (
"fmt"
"strings"
"github.com/sipeed/picoclaw/pkg/auth"
"github.com/sipeed/picoclaw/pkg/config"
)
// createClaudeAuthProvider creates a Claude provider using OAuth credentials.
func createClaudeAuthProvider() (LLMProvider, error) {
cred, err := auth.GetCredential("anthropic")
if err != nil {
return nil, fmt.Errorf("loading auth credentials: %w", err)
}
if cred == nil {
return nil, fmt.Errorf("no credentials for anthropic. Run: picoclaw auth login --provider anthropic")
}
return NewClaudeProviderWithTokenSource(cred.AccessToken, createClaudeTokenSource()), nil
}
// createCodexAuthProvider creates a Codex provider using OAuth credentials.
func createCodexAuthProvider() (LLMProvider, error) {
cred, err := auth.GetCredential("openai")
if err != nil {
return nil, fmt.Errorf("loading auth credentials: %w", err)
}
if cred == nil {
return nil, fmt.Errorf("no credentials for openai. Run: picoclaw auth login --provider openai")
}
return NewCodexProviderWithTokenSource(cred.AccessToken, cred.AccountID, createCodexTokenSource()), nil
}
// CreateProvider creates a provider based on the configuration.
// It supports both the new model_list configuration and the legacy providers configuration.
// Returns the provider, the model ID to use, and any error.
func CreateProvider(cfg *config.Config) (LLMProvider, string, error) {
model := cfg.Agents.Defaults.Model
// First, try to use model_list configuration
if len(cfg.ModelList) > 0 {
// Try to get config by model name first
modelCfg, err := cfg.GetModelConfig(model)
if err == nil {
// Found in model_list, use factory to create provider
provider, modelID, err := CreateProviderFromConfig(modelCfg)
if err != nil {
return nil, "", fmt.Errorf("failed to create provider from model_list: %w", err)
}
return provider, modelID, nil
}
// Model not found in model_list, fall through to providers config
}
// Log deprecation warning if using old providers config
if cfg.HasProvidersConfig() && len(cfg.ModelList) == 0 {
fmt.Println("WARNING: providers config is deprecated, please migrate to model_list")
}
providerName := strings.ToLower(cfg.Agents.Defaults.Provider)
var apiKey, apiBase, proxy string
lowerModel := strings.ToLower(model)
// First, try to use explicitly configured provider
if providerName != "" {
switch providerName {
case "groq":
if cfg.Providers.Groq.APIKey != "" {
apiKey = cfg.Providers.Groq.APIKey
apiBase = cfg.Providers.Groq.APIBase
if apiBase == "" {
apiBase = "https://api.groq.com/openai/v1"
}
}
case "openai", "gpt":
if cfg.Providers.OpenAI.APIKey != "" || cfg.Providers.OpenAI.AuthMethod != "" {
if cfg.Providers.OpenAI.AuthMethod == "codex-cli" {
return NewCodexProviderWithTokenSource("", "", CreateCodexCliTokenSource()), model, nil
}
if cfg.Providers.OpenAI.AuthMethod == "oauth" || cfg.Providers.OpenAI.AuthMethod == "token" {
provider, err := createCodexAuthProvider()
return provider, model, err
}
apiKey = cfg.Providers.OpenAI.APIKey
apiBase = cfg.Providers.OpenAI.APIBase
if apiBase == "" {
apiBase = "https://api.openai.com/v1"
}
}
case "anthropic", "claude":
if cfg.Providers.Anthropic.APIKey != "" || cfg.Providers.Anthropic.AuthMethod != "" {
if cfg.Providers.Anthropic.AuthMethod == "oauth" || cfg.Providers.Anthropic.AuthMethod == "token" {
provider, err := createClaudeAuthProvider()
return provider, model, err
}
apiKey = cfg.Providers.Anthropic.APIKey
apiBase = cfg.Providers.Anthropic.APIBase
if apiBase == "" {
apiBase = "https://api.anthropic.com/v1"
}
}
case "openrouter":
if cfg.Providers.OpenRouter.APIKey != "" {
apiKey = cfg.Providers.OpenRouter.APIKey
if cfg.Providers.OpenRouter.APIBase != "" {
apiBase = cfg.Providers.OpenRouter.APIBase
} else {
apiBase = "https://openrouter.ai/api/v1"
}
}
case "zhipu", "glm":
if cfg.Providers.Zhipu.APIKey != "" {
apiKey = cfg.Providers.Zhipu.APIKey
apiBase = cfg.Providers.Zhipu.APIBase
if apiBase == "" {
apiBase = "https://open.bigmodel.cn/api/paas/v4"
}
}
case "gemini", "google":
if cfg.Providers.Gemini.APIKey != "" {
apiKey = cfg.Providers.Gemini.APIKey
apiBase = cfg.Providers.Gemini.APIBase
if apiBase == "" {
apiBase = "https://generativelanguage.googleapis.com/v1beta"
}
}
case "vllm":
if cfg.Providers.VLLM.APIBase != "" {
apiKey = cfg.Providers.VLLM.APIKey
apiBase = cfg.Providers.VLLM.APIBase
}
case "shengsuanyun":
if cfg.Providers.ShengSuanYun.APIKey != "" {
apiKey = cfg.Providers.ShengSuanYun.APIKey
apiBase = cfg.Providers.ShengSuanYun.APIBase
if apiBase == "" {
apiBase = "https://router.shengsuanyun.com/api/v1"
}
}
case "claude-cli", "claudecode", "claude-code":
workspace := cfg.WorkspacePath()
if workspace == "" {
workspace = "."
}
return NewClaudeCliProvider(workspace), model, nil
case "codex-cli", "codex-code":
workspace := cfg.WorkspacePath()
if workspace == "" {
workspace = "."
}
return NewCodexCliProvider(workspace), model, nil
case "cerebras":
if cfg.Providers.Cerebras.APIKey != "" {
apiKey = cfg.Providers.Cerebras.APIKey
apiBase = cfg.Providers.Cerebras.APIBase
if apiBase == "" {
apiBase = "https://api.cerebras.ai/v1"
}
}
case "deepseek":
if cfg.Providers.DeepSeek.APIKey != "" {
apiKey = cfg.Providers.DeepSeek.APIKey
apiBase = cfg.Providers.DeepSeek.APIBase
if apiBase == "" {
apiBase = "https://api.deepseek.com/v1"
}
if model != "deepseek-chat" && model != "deepseek-reasoner" {
model = "deepseek-chat"
}
}
case "qwen":
if cfg.Providers.Qwen.APIKey != "" {
apiKey = cfg.Providers.Qwen.APIKey
apiBase = cfg.Providers.Qwen.APIBase
if apiBase == "" {
apiBase = "https://dashscope.aliyuncs.com/compatible-mode/v1"
}
}
case "github_copilot", "copilot":
if cfg.Providers.GitHubCopilot.APIBase != "" {
apiBase = cfg.Providers.GitHubCopilot.APIBase
} else {
apiBase = "localhost:4321"
}
provider, err := NewGitHubCopilotProvider(apiBase, cfg.Providers.GitHubCopilot.ConnectMode, model)
return provider, model, err
case "antigravity", "google-antigravity":
return NewAntigravityProvider(), model, nil
case "volcengine", "doubao":
if cfg.Providers.VolcEngine.APIKey != "" {
apiKey = cfg.Providers.VolcEngine.APIKey
apiBase = cfg.Providers.VolcEngine.APIBase
if apiBase == "" {
apiBase = "https://ark.cn-beijing.volces.com/api/v3"
}
}
}
}
// Fallback: detect provider from model name
if apiKey == "" && apiBase == "" {
switch {
case (strings.Contains(lowerModel, "kimi") || strings.Contains(lowerModel, "moonshot") || strings.HasPrefix(model, "moonshot/")) && cfg.Providers.Moonshot.APIKey != "":
apiKey = cfg.Providers.Moonshot.APIKey
apiBase = cfg.Providers.Moonshot.APIBase
proxy = cfg.Providers.Moonshot.Proxy
if apiBase == "" {
apiBase = "https://api.moonshot.cn/v1"
}
case strings.HasPrefix(model, "openrouter/") || strings.HasPrefix(model, "anthropic/") || strings.HasPrefix(model, "openai/") || strings.HasPrefix(model, "meta-llama/") || strings.HasPrefix(model, "deepseek/") || strings.HasPrefix(model, "google/"):
apiKey = cfg.Providers.OpenRouter.APIKey
proxy = cfg.Providers.OpenRouter.Proxy
if cfg.Providers.OpenRouter.APIBase != "" {
apiBase = cfg.Providers.OpenRouter.APIBase
} else {
apiBase = "https://openrouter.ai/api/v1"
}
case (strings.Contains(lowerModel, "claude") || strings.HasPrefix(model, "anthropic/")) && (cfg.Providers.Anthropic.APIKey != "" || cfg.Providers.Anthropic.AuthMethod != ""):
if cfg.Providers.Anthropic.AuthMethod == "oauth" || cfg.Providers.Anthropic.AuthMethod == "token" {
provider, err := createClaudeAuthProvider()
return provider, model, err
}
apiKey = cfg.Providers.Anthropic.APIKey
apiBase = cfg.Providers.Anthropic.APIBase
proxy = cfg.Providers.Anthropic.Proxy
if apiBase == "" {
apiBase = "https://api.anthropic.com/v1"
}
case (strings.Contains(lowerModel, "gpt") || strings.HasPrefix(model, "openai/")) && (cfg.Providers.OpenAI.APIKey != "" || cfg.Providers.OpenAI.AuthMethod != ""):
if cfg.Providers.OpenAI.AuthMethod == "oauth" || cfg.Providers.OpenAI.AuthMethod == "token" {
provider, err := createCodexAuthProvider()
return provider, model, err
}
apiKey = cfg.Providers.OpenAI.APIKey
apiBase = cfg.Providers.OpenAI.APIBase
proxy = cfg.Providers.OpenAI.Proxy
if apiBase == "" {
apiBase = "https://api.openai.com/v1"
}
case (strings.Contains(lowerModel, "gemini") || strings.HasPrefix(model, "google/")) && cfg.Providers.Gemini.APIKey != "":
apiKey = cfg.Providers.Gemini.APIKey
apiBase = cfg.Providers.Gemini.APIBase
proxy = cfg.Providers.Gemini.Proxy
if apiBase == "" {
apiBase = "https://generativelanguage.googleapis.com/v1beta"
}
case (strings.Contains(lowerModel, "glm") || strings.Contains(lowerModel, "zhipu") || strings.Contains(lowerModel, "zai")) && cfg.Providers.Zhipu.APIKey != "":
apiKey = cfg.Providers.Zhipu.APIKey
apiBase = cfg.Providers.Zhipu.APIBase
proxy = cfg.Providers.Zhipu.Proxy
if apiBase == "" {
apiBase = "https://open.bigmodel.cn/api/paas/v4"
}
case (strings.Contains(lowerModel, "groq") || strings.HasPrefix(model, "groq/")) && cfg.Providers.Groq.APIKey != "":
apiKey = cfg.Providers.Groq.APIKey
apiBase = cfg.Providers.Groq.APIBase
proxy = cfg.Providers.Groq.Proxy
if apiBase == "" {
apiBase = "https://api.groq.com/openai/v1"
}
case (strings.Contains(lowerModel, "qwen") || strings.HasPrefix(model, "qwen/")) && cfg.Providers.Qwen.APIKey != "":
apiKey = cfg.Providers.Qwen.APIKey
apiBase = cfg.Providers.Qwen.APIBase
proxy = cfg.Providers.Qwen.Proxy
if apiBase == "" {
apiBase = "https://dashscope.aliyuncs.com/compatible-mode/v1"
}
case (strings.Contains(lowerModel, "nvidia") || strings.HasPrefix(model, "nvidia/")) && cfg.Providers.Nvidia.APIKey != "":
apiKey = cfg.Providers.Nvidia.APIKey
apiBase = cfg.Providers.Nvidia.APIBase
proxy = cfg.Providers.Nvidia.Proxy
if apiBase == "" {
apiBase = "https://integrate.api.nvidia.com/v1"
}
case (strings.Contains(lowerModel, "cerebras") || strings.HasPrefix(model, "cerebras/")) && cfg.Providers.Cerebras.APIKey != "":
apiKey = cfg.Providers.Cerebras.APIKey
apiBase = cfg.Providers.Cerebras.APIBase
proxy = cfg.Providers.Cerebras.Proxy
if apiBase == "" {
apiBase = "https://api.cerebras.ai/v1"
}
case (strings.Contains(lowerModel, "ollama") || strings.HasPrefix(model, "ollama/")) && cfg.Providers.Ollama.APIKey != "":
fmt.Println("Ollama provider selected based on model name prefix")
apiKey = cfg.Providers.Ollama.APIKey
apiBase = cfg.Providers.Ollama.APIBase
proxy = cfg.Providers.Ollama.Proxy
if apiBase == "" {
apiBase = "http://localhost:11434/v1"
}
fmt.Println("Ollama apiBase:", apiBase)
case (strings.Contains(lowerModel, "doubao") || strings.HasPrefix(lowerModel, "doubao") || strings.Contains(lowerModel, "volcengine")) && cfg.Providers.VolcEngine.APIKey != "":
apiKey = cfg.Providers.VolcEngine.APIKey
apiBase = cfg.Providers.VolcEngine.APIBase
proxy = cfg.Providers.VolcEngine.Proxy
if apiBase == "" {
apiBase = "https://ark.cn-beijing.volces.com/api/v3"
}
case cfg.Providers.VLLM.APIBase != "":
apiKey = cfg.Providers.VLLM.APIKey
apiBase = cfg.Providers.VLLM.APIBase
proxy = cfg.Providers.VLLM.Proxy
default:
if cfg.Providers.OpenRouter.APIKey != "" {
apiKey = cfg.Providers.OpenRouter.APIKey
proxy = cfg.Providers.OpenRouter.Proxy
if cfg.Providers.OpenRouter.APIBase != "" {
apiBase = cfg.Providers.OpenRouter.APIBase
} else {
apiBase = "https://openrouter.ai/api/v1"
}
} else {
return nil, "", fmt.Errorf("no API key configured for model: %s", model)
}
}
}
if apiKey == "" && !strings.HasPrefix(model, "bedrock/") {
return nil, "", fmt.Errorf("no API key configured for provider (model: %s)", model)
}
if apiBase == "" {
return nil, "", fmt.Errorf("no API base configured for provider (model: %s)", model)
}
return NewHTTPProvider(apiKey, apiBase, proxy), model, nil
}
-113
View File
@@ -1,113 +0,0 @@
// PicoClaw - Ultra-lightweight personal AI agent
// License: MIT
//
// Copyright (c) 2026 PicoClaw contributors
package providers
import (
"fmt"
"sync"
"sync/atomic"
"github.com/sipeed/picoclaw/pkg/config"
)
// ModelRegistry manages model configurations with thread-safe round-robin load balancing.
// It allows multiple configurations for the same model_name to distribute load across endpoints.
type ModelRegistry struct {
configs map[string][]config.ModelConfig // model_name -> []ModelConfig
counters map[string]*atomic.Uint64 // model_name -> round-robin counter
mu sync.RWMutex
}
// NewModelRegistry creates a new ModelRegistry from a slice of ModelConfig.
func NewModelRegistry(modelList []config.ModelConfig) *ModelRegistry {
r := &ModelRegistry{
configs: make(map[string][]config.ModelConfig),
counters: make(map[string]*atomic.Uint64),
}
for _, cfg := range modelList {
r.configs[cfg.ModelName] = append(r.configs[cfg.ModelName], cfg)
}
// Initialize counters for models with multiple configs
for name, cfgs := range r.configs {
if len(cfgs) > 1 {
r.counters[name] = &atomic.Uint64{}
}
}
return r
}
// GetModelConfig returns a ModelConfig for the given model name.
// If multiple configs exist for the same model_name, it uses round-robin selection.
// Returns an error if the model is not found.
func (r *ModelRegistry) GetModelConfig(modelName string) (*config.ModelConfig, error) {
r.mu.RLock()
defer r.mu.RUnlock()
configs, ok := r.configs[modelName]
if !ok || len(configs) == 0 {
return nil, fmt.Errorf("model %q not found", modelName)
}
// Single config - return directly
if len(configs) == 1 {
return &configs[0], nil
}
// Multiple configs - use round-robin for load balancing
counter, ok := r.counters[modelName]
if !ok {
// Should not happen, but handle gracefully
return &configs[0], nil
}
idx := counter.Add(1) % uint64(len(configs))
return &configs[idx], nil
}
// AddConfig adds a new ModelConfig to the registry.
func (r *ModelRegistry) AddConfig(cfg config.ModelConfig) {
r.mu.Lock()
defer r.mu.Unlock()
r.configs[cfg.ModelName] = append(r.configs[cfg.ModelName], cfg)
// Initialize counter if we now have multiple configs
if len(r.configs[cfg.ModelName]) > 1 && r.counters[cfg.ModelName] == nil {
r.counters[cfg.ModelName] = &atomic.Uint64{}
}
}
// RemoveConfig removes all configs with the given model_name.
func (r *ModelRegistry) RemoveConfig(modelName string) {
r.mu.Lock()
defer r.mu.Unlock()
delete(r.configs, modelName)
delete(r.counters, modelName)
}
// ListModels returns all unique model names in the registry.
func (r *ModelRegistry) ListModels() []string {
r.mu.RLock()
defer r.mu.RUnlock()
names := make([]string, 0, len(r.configs))
for name := range r.configs {
names = append(names, name)
}
return names
}
// ConfigCount returns the number of configurations for a given model name.
func (r *ModelRegistry) ConfigCount(modelName string) int {
r.mu.RLock()
defer r.mu.RUnlock()
return len(r.configs[modelName])
}