mirror of
https://github.com/sipeed/picoclaw.git
synced 2026-06-12 18:08:54 +00:00
refactor: reorganize commands and provider architecture
Refactor command handlers into separate files to improve code organization and maintainability. Each command (agent, auth, cron, gateway, migrate, onboard, skills, status) now has its own dedicated file. Restructure provider creation to support new model_list configuration system that enables zero-code addition of OpenAI-compatible providers. Move legacy provider logic to separate file for backward compatibility. Move configuration functions from config.go to separate files (defaults.go, migration.go) for better organization.
This commit is contained in:
+47
-372
@@ -232,23 +232,6 @@ func (c *ModelConfig) Validate() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// ParseProtocol extracts the protocol prefix and model identifier from the Model field.
|
||||
// If no prefix is specified, it defaults to "openai".
|
||||
// Examples:
|
||||
// - "openai/gpt-4o" -> ("openai", "gpt-4o")
|
||||
// - "anthropic/claude-3" -> ("anthropic", "claude-3")
|
||||
// - "gpt-4o" -> ("openai", "gpt-4o") // default protocol
|
||||
func (c *ModelConfig) ParseProtocol() (protocol, modelID string) {
|
||||
model := c.Model
|
||||
for i := 0; i < len(model); i++ {
|
||||
if model[i] == '/' {
|
||||
return model[:i], model[i+1:]
|
||||
}
|
||||
}
|
||||
// No prefix found, default to openai
|
||||
return "openai", model
|
||||
}
|
||||
|
||||
type GatewayConfig struct {
|
||||
Host string `json:"host" env:"PICOCLAW_GATEWAY_HOST"`
|
||||
Port int `json:"port" env:"PICOCLAW_GATEWAY_PORT"`
|
||||
@@ -286,135 +269,6 @@ type ToolsConfig struct {
|
||||
Cron CronToolsConfig `json:"cron"`
|
||||
}
|
||||
|
||||
func DefaultConfig() *Config {
|
||||
return &Config{
|
||||
Agents: AgentsConfig{
|
||||
Defaults: AgentDefaults{
|
||||
Workspace: "~/.picoclaw/workspace",
|
||||
RestrictToWorkspace: true,
|
||||
Provider: "",
|
||||
Model: "glm-4.7",
|
||||
MaxTokens: 8192,
|
||||
Temperature: 0.7,
|
||||
MaxToolIterations: 20,
|
||||
},
|
||||
},
|
||||
Channels: ChannelsConfig{
|
||||
WhatsApp: WhatsAppConfig{
|
||||
Enabled: false,
|
||||
BridgeURL: "ws://localhost:3001",
|
||||
AllowFrom: FlexibleStringSlice{},
|
||||
},
|
||||
Telegram: TelegramConfig{
|
||||
Enabled: false,
|
||||
Token: "",
|
||||
AllowFrom: FlexibleStringSlice{},
|
||||
},
|
||||
Feishu: FeishuConfig{
|
||||
Enabled: false,
|
||||
AppID: "",
|
||||
AppSecret: "",
|
||||
EncryptKey: "",
|
||||
VerificationToken: "",
|
||||
AllowFrom: FlexibleStringSlice{},
|
||||
},
|
||||
Discord: DiscordConfig{
|
||||
Enabled: false,
|
||||
Token: "",
|
||||
AllowFrom: FlexibleStringSlice{},
|
||||
},
|
||||
MaixCam: MaixCamConfig{
|
||||
Enabled: false,
|
||||
Host: "0.0.0.0",
|
||||
Port: 18790,
|
||||
AllowFrom: FlexibleStringSlice{},
|
||||
},
|
||||
QQ: QQConfig{
|
||||
Enabled: false,
|
||||
AppID: "",
|
||||
AppSecret: "",
|
||||
AllowFrom: FlexibleStringSlice{},
|
||||
},
|
||||
DingTalk: DingTalkConfig{
|
||||
Enabled: false,
|
||||
ClientID: "",
|
||||
ClientSecret: "",
|
||||
AllowFrom: FlexibleStringSlice{},
|
||||
},
|
||||
Slack: SlackConfig{
|
||||
Enabled: false,
|
||||
BotToken: "",
|
||||
AppToken: "",
|
||||
AllowFrom: FlexibleStringSlice{},
|
||||
},
|
||||
LINE: LINEConfig{
|
||||
Enabled: false,
|
||||
ChannelSecret: "",
|
||||
ChannelAccessToken: "",
|
||||
WebhookHost: "0.0.0.0",
|
||||
WebhookPort: 18791,
|
||||
WebhookPath: "/webhook/line",
|
||||
AllowFrom: FlexibleStringSlice{},
|
||||
},
|
||||
OneBot: OneBotConfig{
|
||||
Enabled: false,
|
||||
WSUrl: "ws://127.0.0.1:3001",
|
||||
AccessToken: "",
|
||||
ReconnectInterval: 5,
|
||||
GroupTriggerPrefix: []string{},
|
||||
AllowFrom: FlexibleStringSlice{},
|
||||
},
|
||||
},
|
||||
Providers: ProvidersConfig{
|
||||
Anthropic: ProviderConfig{},
|
||||
OpenAI: ProviderConfig{},
|
||||
OpenRouter: ProviderConfig{},
|
||||
Groq: ProviderConfig{},
|
||||
Zhipu: ProviderConfig{},
|
||||
VLLM: ProviderConfig{},
|
||||
Gemini: ProviderConfig{},
|
||||
Nvidia: ProviderConfig{},
|
||||
Moonshot: ProviderConfig{},
|
||||
ShengSuanYun: ProviderConfig{},
|
||||
Cerebras: ProviderConfig{},
|
||||
VolcEngine: ProviderConfig{},
|
||||
},
|
||||
Gateway: GatewayConfig{
|
||||
Host: "0.0.0.0",
|
||||
Port: 18790,
|
||||
},
|
||||
Tools: ToolsConfig{
|
||||
Web: WebToolsConfig{
|
||||
Brave: BraveConfig{
|
||||
Enabled: false,
|
||||
APIKey: "",
|
||||
MaxResults: 5,
|
||||
},
|
||||
DuckDuckGo: DuckDuckGoConfig{
|
||||
Enabled: true,
|
||||
MaxResults: 5,
|
||||
},
|
||||
Perplexity: PerplexityConfig{
|
||||
Enabled: false,
|
||||
APIKey: "",
|
||||
MaxResults: 5,
|
||||
},
|
||||
},
|
||||
Cron: CronToolsConfig{
|
||||
ExecTimeoutMinutes: 5, // default 5 minutes for LLM operations
|
||||
},
|
||||
},
|
||||
Heartbeat: HeartbeatConfig{
|
||||
Enabled: true,
|
||||
Interval: 30, // default 30 minutes
|
||||
},
|
||||
Devices: DevicesConfig{
|
||||
Enabled: false,
|
||||
MonitorUSB: true,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func LoadConfig(path string) (*Config, error) {
|
||||
cfg := DefaultConfig()
|
||||
|
||||
@@ -528,40 +382,61 @@ func expandHome(path string) string {
|
||||
// GetModelConfig returns the ModelConfig for the given model name.
|
||||
// If multiple configs exist with the same model_name, it uses round-robin
|
||||
// selection for load balancing. Returns an error if the model is not found.
|
||||
// Uses double-check locking for optimal read performance.
|
||||
func (c *Config) GetModelConfig(modelName string) (*ModelConfig, error) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
// First pass: use read lock to find matches
|
||||
c.mu.RLock()
|
||||
matches := c.findMatchesLocked(modelName)
|
||||
if len(matches) == 0 {
|
||||
c.mu.RUnlock()
|
||||
return nil, fmt.Errorf("model %q not found in model_list or providers", modelName)
|
||||
}
|
||||
if len(matches) == 1 {
|
||||
c.mu.RUnlock()
|
||||
return &matches[0], nil
|
||||
}
|
||||
|
||||
// Find all configs with matching model_name
|
||||
// Multiple configs - check if counter exists
|
||||
counter, ok := c.rrCounters[modelName]
|
||||
c.mu.RUnlock()
|
||||
|
||||
// Double-check locking: only acquire write lock if counter needs initialization
|
||||
if !ok {
|
||||
c.mu.Lock()
|
||||
// Re-check after acquiring write lock
|
||||
if c.rrCounters == nil {
|
||||
c.rrCounters = make(map[string]*atomic.Uint64)
|
||||
}
|
||||
if c.rrCounters[modelName] == nil {
|
||||
c.rrCounters[modelName] = &atomic.Uint64{}
|
||||
}
|
||||
counter = c.rrCounters[modelName]
|
||||
c.mu.Unlock()
|
||||
}
|
||||
|
||||
// Re-fetch matches to ensure consistency (ModelList could have changed)
|
||||
c.mu.RLock()
|
||||
matches = c.findMatchesLocked(modelName)
|
||||
c.mu.RUnlock()
|
||||
|
||||
if len(matches) == 0 {
|
||||
return nil, fmt.Errorf("model %q not found in model_list or providers", modelName)
|
||||
}
|
||||
|
||||
idx := counter.Add(1) % uint64(len(matches))
|
||||
return &matches[idx], nil
|
||||
}
|
||||
|
||||
// findMatchesLocked finds all ModelConfig entries with the given model_name.
|
||||
// Must be called with c.mu locked (read or write).
|
||||
func (c *Config) findMatchesLocked(modelName string) []ModelConfig {
|
||||
var matches []ModelConfig
|
||||
for i := range c.ModelList {
|
||||
if c.ModelList[i].ModelName == modelName {
|
||||
matches = append(matches, c.ModelList[i])
|
||||
}
|
||||
}
|
||||
|
||||
if len(matches) == 0 {
|
||||
return nil, fmt.Errorf("model %q not found in model_list or providers", modelName)
|
||||
}
|
||||
|
||||
// Single config - return directly
|
||||
if len(matches) == 1 {
|
||||
return &matches[0], nil
|
||||
}
|
||||
|
||||
// Multiple configs - use round-robin for load balancing
|
||||
if c.rrCounters == nil {
|
||||
c.rrCounters = make(map[string]*atomic.Uint64)
|
||||
}
|
||||
|
||||
counter, ok := c.rrCounters[modelName]
|
||||
if !ok {
|
||||
counter = &atomic.Uint64{}
|
||||
c.rrCounters[modelName] = counter
|
||||
}
|
||||
|
||||
idx := counter.Add(1) % uint64(len(matches))
|
||||
return &matches[idx], nil
|
||||
return matches
|
||||
}
|
||||
|
||||
// HasProvidersConfig checks if any provider in the old providers config has configuration.
|
||||
@@ -599,203 +474,3 @@ func (c *Config) ValidateModelList() error {
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ConvertProvidersToModelList converts the old ProvidersConfig to a slice of ModelConfig.
|
||||
// This enables backward compatibility with existing configurations.
|
||||
func ConvertProvidersToModelList(cfg *Config) []ModelConfig {
|
||||
if cfg == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
var result []ModelConfig
|
||||
p := cfg.Providers
|
||||
|
||||
// OpenAI
|
||||
if p.OpenAI.APIKey != "" || p.OpenAI.APIBase != "" {
|
||||
result = append(result, ModelConfig{
|
||||
ModelName: "openai",
|
||||
Model: "openai/gpt-4o",
|
||||
APIKey: p.OpenAI.APIKey,
|
||||
APIBase: p.OpenAI.APIBase,
|
||||
Proxy: p.OpenAI.Proxy,
|
||||
AuthMethod: p.OpenAI.AuthMethod,
|
||||
})
|
||||
}
|
||||
|
||||
// Anthropic
|
||||
if p.Anthropic.APIKey != "" || p.Anthropic.APIBase != "" {
|
||||
result = append(result, ModelConfig{
|
||||
ModelName: "anthropic",
|
||||
Model: "anthropic/claude-3-sonnet",
|
||||
APIKey: p.Anthropic.APIKey,
|
||||
APIBase: p.Anthropic.APIBase,
|
||||
Proxy: p.Anthropic.Proxy,
|
||||
AuthMethod: p.Anthropic.AuthMethod,
|
||||
})
|
||||
}
|
||||
|
||||
// OpenRouter
|
||||
if p.OpenRouter.APIKey != "" || p.OpenRouter.APIBase != "" {
|
||||
result = append(result, ModelConfig{
|
||||
ModelName: "openrouter",
|
||||
Model: "openrouter/auto",
|
||||
APIKey: p.OpenRouter.APIKey,
|
||||
APIBase: p.OpenRouter.APIBase,
|
||||
Proxy: p.OpenRouter.Proxy,
|
||||
})
|
||||
}
|
||||
|
||||
// Groq
|
||||
if p.Groq.APIKey != "" || p.Groq.APIBase != "" {
|
||||
result = append(result, ModelConfig{
|
||||
ModelName: "groq",
|
||||
Model: "groq/llama-3.1-70b-versatile",
|
||||
APIKey: p.Groq.APIKey,
|
||||
APIBase: p.Groq.APIBase,
|
||||
Proxy: p.Groq.Proxy,
|
||||
})
|
||||
}
|
||||
|
||||
// Zhipu
|
||||
if p.Zhipu.APIKey != "" || p.Zhipu.APIBase != "" {
|
||||
result = append(result, ModelConfig{
|
||||
ModelName: "zhipu",
|
||||
Model: "openai/glm-4",
|
||||
APIKey: p.Zhipu.APIKey,
|
||||
APIBase: p.Zhipu.APIBase,
|
||||
Proxy: p.Zhipu.Proxy,
|
||||
})
|
||||
}
|
||||
|
||||
// VLLM
|
||||
if p.VLLM.APIKey != "" || p.VLLM.APIBase != "" {
|
||||
result = append(result, ModelConfig{
|
||||
ModelName: "vllm",
|
||||
Model: "openai/auto",
|
||||
APIKey: p.VLLM.APIKey,
|
||||
APIBase: p.VLLM.APIBase,
|
||||
Proxy: p.VLLM.Proxy,
|
||||
})
|
||||
}
|
||||
|
||||
// Gemini
|
||||
if p.Gemini.APIKey != "" || p.Gemini.APIBase != "" {
|
||||
result = append(result, ModelConfig{
|
||||
ModelName: "gemini",
|
||||
Model: "openai/gemini-pro",
|
||||
APIKey: p.Gemini.APIKey,
|
||||
APIBase: p.Gemini.APIBase,
|
||||
Proxy: p.Gemini.Proxy,
|
||||
})
|
||||
}
|
||||
|
||||
// Nvidia
|
||||
if p.Nvidia.APIKey != "" || p.Nvidia.APIBase != "" {
|
||||
result = append(result, ModelConfig{
|
||||
ModelName: "nvidia",
|
||||
Model: "nvidia/meta/llama-3.1-8b-instruct",
|
||||
APIKey: p.Nvidia.APIKey,
|
||||
APIBase: p.Nvidia.APIBase,
|
||||
Proxy: p.Nvidia.Proxy,
|
||||
})
|
||||
}
|
||||
|
||||
// Ollama
|
||||
if p.Ollama.APIKey != "" || p.Ollama.APIBase != "" {
|
||||
result = append(result, ModelConfig{
|
||||
ModelName: "ollama",
|
||||
Model: "ollama/llama3",
|
||||
APIKey: p.Ollama.APIKey,
|
||||
APIBase: p.Ollama.APIBase,
|
||||
Proxy: p.Ollama.Proxy,
|
||||
})
|
||||
}
|
||||
|
||||
// Moonshot
|
||||
if p.Moonshot.APIKey != "" || p.Moonshot.APIBase != "" {
|
||||
result = append(result, ModelConfig{
|
||||
ModelName: "moonshot",
|
||||
Model: "moonshot/kimi",
|
||||
APIKey: p.Moonshot.APIKey,
|
||||
APIBase: p.Moonshot.APIBase,
|
||||
Proxy: p.Moonshot.Proxy,
|
||||
})
|
||||
}
|
||||
|
||||
// ShengSuanYun
|
||||
if p.ShengSuanYun.APIKey != "" || p.ShengSuanYun.APIBase != "" {
|
||||
result = append(result, ModelConfig{
|
||||
ModelName: "shengsuanyun",
|
||||
Model: "openai/auto",
|
||||
APIKey: p.ShengSuanYun.APIKey,
|
||||
APIBase: p.ShengSuanYun.APIBase,
|
||||
Proxy: p.ShengSuanYun.Proxy,
|
||||
})
|
||||
}
|
||||
|
||||
// DeepSeek
|
||||
if p.DeepSeek.APIKey != "" || p.DeepSeek.APIBase != "" {
|
||||
result = append(result, ModelConfig{
|
||||
ModelName: "deepseek",
|
||||
Model: "openai/deepseek-chat",
|
||||
APIKey: p.DeepSeek.APIKey,
|
||||
APIBase: p.DeepSeek.APIBase,
|
||||
Proxy: p.DeepSeek.Proxy,
|
||||
})
|
||||
}
|
||||
|
||||
// Cerebras
|
||||
if p.Cerebras.APIKey != "" || p.Cerebras.APIBase != "" {
|
||||
result = append(result, ModelConfig{
|
||||
ModelName: "cerebras",
|
||||
Model: "cerebras/llama-3.3-70b",
|
||||
APIKey: p.Cerebras.APIKey,
|
||||
APIBase: p.Cerebras.APIBase,
|
||||
Proxy: p.Cerebras.Proxy,
|
||||
})
|
||||
}
|
||||
|
||||
// VolcEngine (Doubao)
|
||||
if p.VolcEngine.APIKey != "" || p.VolcEngine.APIBase != "" {
|
||||
result = append(result, ModelConfig{
|
||||
ModelName: "volcengine",
|
||||
Model: "openai/doubao-pro",
|
||||
APIKey: p.VolcEngine.APIKey,
|
||||
APIBase: p.VolcEngine.APIBase,
|
||||
Proxy: p.VolcEngine.Proxy,
|
||||
})
|
||||
}
|
||||
|
||||
// GitHub Copilot
|
||||
if p.GitHubCopilot.APIKey != "" || p.GitHubCopilot.APIBase != "" || p.GitHubCopilot.ConnectMode != "" {
|
||||
result = append(result, ModelConfig{
|
||||
ModelName: "github-copilot",
|
||||
Model: "github-copilot/gpt-4o",
|
||||
APIBase: p.GitHubCopilot.APIBase,
|
||||
ConnectMode: p.GitHubCopilot.ConnectMode,
|
||||
})
|
||||
}
|
||||
|
||||
// Antigravity
|
||||
if p.Antigravity.APIKey != "" || p.Antigravity.AuthMethod != "" {
|
||||
result = append(result, ModelConfig{
|
||||
ModelName: "antigravity",
|
||||
Model: "antigravity/gemini-2.0-flash",
|
||||
APIKey: p.Antigravity.APIKey,
|
||||
AuthMethod: p.Antigravity.AuthMethod,
|
||||
})
|
||||
}
|
||||
|
||||
// Qwen
|
||||
if p.Qwen.APIKey != "" || p.Qwen.APIBase != "" {
|
||||
result = append(result, ModelConfig{
|
||||
ModelName: "qwen",
|
||||
Model: "qwen/qwen-max",
|
||||
APIKey: p.Qwen.APIKey,
|
||||
APIBase: p.Qwen.APIBase,
|
||||
Proxy: p.Qwen.Proxy,
|
||||
})
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
@@ -0,0 +1,136 @@
|
||||
// PicoClaw - Ultra-lightweight personal AI agent
|
||||
// License: MIT
|
||||
//
|
||||
// Copyright (c) 2026 PicoClaw contributors
|
||||
|
||||
package config
|
||||
|
||||
// DefaultConfig returns the default configuration for PicoClaw.
|
||||
func DefaultConfig() *Config {
|
||||
return &Config{
|
||||
Agents: AgentsConfig{
|
||||
Defaults: AgentDefaults{
|
||||
Workspace: "~/.picoclaw/workspace",
|
||||
RestrictToWorkspace: true,
|
||||
Provider: "",
|
||||
Model: "glm-4.7",
|
||||
MaxTokens: 8192,
|
||||
Temperature: 0.7,
|
||||
MaxToolIterations: 20,
|
||||
},
|
||||
},
|
||||
Channels: ChannelsConfig{
|
||||
WhatsApp: WhatsAppConfig{
|
||||
Enabled: false,
|
||||
BridgeURL: "ws://localhost:3001",
|
||||
AllowFrom: FlexibleStringSlice{},
|
||||
},
|
||||
Telegram: TelegramConfig{
|
||||
Enabled: false,
|
||||
Token: "",
|
||||
AllowFrom: FlexibleStringSlice{},
|
||||
},
|
||||
Feishu: FeishuConfig{
|
||||
Enabled: false,
|
||||
AppID: "",
|
||||
AppSecret: "",
|
||||
EncryptKey: "",
|
||||
VerificationToken: "",
|
||||
AllowFrom: FlexibleStringSlice{},
|
||||
},
|
||||
Discord: DiscordConfig{
|
||||
Enabled: false,
|
||||
Token: "",
|
||||
AllowFrom: FlexibleStringSlice{},
|
||||
},
|
||||
MaixCam: MaixCamConfig{
|
||||
Enabled: false,
|
||||
Host: "0.0.0.0",
|
||||
Port: 18790,
|
||||
AllowFrom: FlexibleStringSlice{},
|
||||
},
|
||||
QQ: QQConfig{
|
||||
Enabled: false,
|
||||
AppID: "",
|
||||
AppSecret: "",
|
||||
AllowFrom: FlexibleStringSlice{},
|
||||
},
|
||||
DingTalk: DingTalkConfig{
|
||||
Enabled: false,
|
||||
ClientID: "",
|
||||
ClientSecret: "",
|
||||
AllowFrom: FlexibleStringSlice{},
|
||||
},
|
||||
Slack: SlackConfig{
|
||||
Enabled: false,
|
||||
BotToken: "",
|
||||
AppToken: "",
|
||||
AllowFrom: FlexibleStringSlice{},
|
||||
},
|
||||
LINE: LINEConfig{
|
||||
Enabled: false,
|
||||
ChannelSecret: "",
|
||||
ChannelAccessToken: "",
|
||||
WebhookHost: "0.0.0.0",
|
||||
WebhookPort: 18791,
|
||||
WebhookPath: "/webhook/line",
|
||||
AllowFrom: FlexibleStringSlice{},
|
||||
},
|
||||
OneBot: OneBotConfig{
|
||||
Enabled: false,
|
||||
WSUrl: "ws://127.0.0.1:3001",
|
||||
AccessToken: "",
|
||||
ReconnectInterval: 5,
|
||||
GroupTriggerPrefix: []string{},
|
||||
AllowFrom: FlexibleStringSlice{},
|
||||
},
|
||||
},
|
||||
Providers: ProvidersConfig{
|
||||
Anthropic: ProviderConfig{},
|
||||
OpenAI: ProviderConfig{},
|
||||
OpenRouter: ProviderConfig{},
|
||||
Groq: ProviderConfig{},
|
||||
Zhipu: ProviderConfig{},
|
||||
VLLM: ProviderConfig{},
|
||||
Gemini: ProviderConfig{},
|
||||
Nvidia: ProviderConfig{},
|
||||
Moonshot: ProviderConfig{},
|
||||
ShengSuanYun: ProviderConfig{},
|
||||
Cerebras: ProviderConfig{},
|
||||
VolcEngine: ProviderConfig{},
|
||||
},
|
||||
Gateway: GatewayConfig{
|
||||
Host: "0.0.0.0",
|
||||
Port: 18790,
|
||||
},
|
||||
Tools: ToolsConfig{
|
||||
Web: WebToolsConfig{
|
||||
Brave: BraveConfig{
|
||||
Enabled: false,
|
||||
APIKey: "",
|
||||
MaxResults: 5,
|
||||
},
|
||||
DuckDuckGo: DuckDuckGoConfig{
|
||||
Enabled: true,
|
||||
MaxResults: 5,
|
||||
},
|
||||
Perplexity: PerplexityConfig{
|
||||
Enabled: false,
|
||||
APIKey: "",
|
||||
MaxResults: 5,
|
||||
},
|
||||
},
|
||||
Cron: CronToolsConfig{
|
||||
ExecTimeoutMinutes: 5, // default 5 minutes for LLM operations
|
||||
},
|
||||
},
|
||||
Heartbeat: HeartbeatConfig{
|
||||
Enabled: true,
|
||||
Interval: 30, // default 30 minutes
|
||||
},
|
||||
Devices: DevicesConfig{
|
||||
Enabled: false,
|
||||
MonitorUSB: true,
|
||||
},
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,206 @@
|
||||
// PicoClaw - Ultra-lightweight personal AI agent
|
||||
// License: MIT
|
||||
//
|
||||
// Copyright (c) 2026 PicoClaw contributors
|
||||
|
||||
package config
|
||||
|
||||
// ConvertProvidersToModelList converts the old ProvidersConfig to a slice of ModelConfig.
|
||||
// This enables backward compatibility with existing configurations.
|
||||
func ConvertProvidersToModelList(cfg *Config) []ModelConfig {
|
||||
if cfg == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
var result []ModelConfig
|
||||
p := cfg.Providers
|
||||
|
||||
// OpenAI
|
||||
if p.OpenAI.APIKey != "" || p.OpenAI.APIBase != "" {
|
||||
result = append(result, ModelConfig{
|
||||
ModelName: "openai",
|
||||
Model: "openai/gpt-4o",
|
||||
APIKey: p.OpenAI.APIKey,
|
||||
APIBase: p.OpenAI.APIBase,
|
||||
Proxy: p.OpenAI.Proxy,
|
||||
AuthMethod: p.OpenAI.AuthMethod,
|
||||
})
|
||||
}
|
||||
|
||||
// Anthropic
|
||||
if p.Anthropic.APIKey != "" || p.Anthropic.APIBase != "" {
|
||||
result = append(result, ModelConfig{
|
||||
ModelName: "anthropic",
|
||||
Model: "anthropic/claude-3-sonnet",
|
||||
APIKey: p.Anthropic.APIKey,
|
||||
APIBase: p.Anthropic.APIBase,
|
||||
Proxy: p.Anthropic.Proxy,
|
||||
AuthMethod: p.Anthropic.AuthMethod,
|
||||
})
|
||||
}
|
||||
|
||||
// OpenRouter
|
||||
if p.OpenRouter.APIKey != "" || p.OpenRouter.APIBase != "" {
|
||||
result = append(result, ModelConfig{
|
||||
ModelName: "openrouter",
|
||||
Model: "openrouter/auto",
|
||||
APIKey: p.OpenRouter.APIKey,
|
||||
APIBase: p.OpenRouter.APIBase,
|
||||
Proxy: p.OpenRouter.Proxy,
|
||||
})
|
||||
}
|
||||
|
||||
// Groq
|
||||
if p.Groq.APIKey != "" || p.Groq.APIBase != "" {
|
||||
result = append(result, ModelConfig{
|
||||
ModelName: "groq",
|
||||
Model: "groq/llama-3.1-70b-versatile",
|
||||
APIKey: p.Groq.APIKey,
|
||||
APIBase: p.Groq.APIBase,
|
||||
Proxy: p.Groq.Proxy,
|
||||
})
|
||||
}
|
||||
|
||||
// Zhipu
|
||||
if p.Zhipu.APIKey != "" || p.Zhipu.APIBase != "" {
|
||||
result = append(result, ModelConfig{
|
||||
ModelName: "zhipu",
|
||||
Model: "openai/glm-4",
|
||||
APIKey: p.Zhipu.APIKey,
|
||||
APIBase: p.Zhipu.APIBase,
|
||||
Proxy: p.Zhipu.Proxy,
|
||||
})
|
||||
}
|
||||
|
||||
// VLLM
|
||||
if p.VLLM.APIKey != "" || p.VLLM.APIBase != "" {
|
||||
result = append(result, ModelConfig{
|
||||
ModelName: "vllm",
|
||||
Model: "openai/auto",
|
||||
APIKey: p.VLLM.APIKey,
|
||||
APIBase: p.VLLM.APIBase,
|
||||
Proxy: p.VLLM.Proxy,
|
||||
})
|
||||
}
|
||||
|
||||
// Gemini
|
||||
if p.Gemini.APIKey != "" || p.Gemini.APIBase != "" {
|
||||
result = append(result, ModelConfig{
|
||||
ModelName: "gemini",
|
||||
Model: "openai/gemini-pro",
|
||||
APIKey: p.Gemini.APIKey,
|
||||
APIBase: p.Gemini.APIBase,
|
||||
Proxy: p.Gemini.Proxy,
|
||||
})
|
||||
}
|
||||
|
||||
// Nvidia
|
||||
if p.Nvidia.APIKey != "" || p.Nvidia.APIBase != "" {
|
||||
result = append(result, ModelConfig{
|
||||
ModelName: "nvidia",
|
||||
Model: "nvidia/meta/llama-3.1-8b-instruct",
|
||||
APIKey: p.Nvidia.APIKey,
|
||||
APIBase: p.Nvidia.APIBase,
|
||||
Proxy: p.Nvidia.Proxy,
|
||||
})
|
||||
}
|
||||
|
||||
// Ollama
|
||||
if p.Ollama.APIKey != "" || p.Ollama.APIBase != "" {
|
||||
result = append(result, ModelConfig{
|
||||
ModelName: "ollama",
|
||||
Model: "ollama/llama3",
|
||||
APIKey: p.Ollama.APIKey,
|
||||
APIBase: p.Ollama.APIBase,
|
||||
Proxy: p.Ollama.Proxy,
|
||||
})
|
||||
}
|
||||
|
||||
// Moonshot
|
||||
if p.Moonshot.APIKey != "" || p.Moonshot.APIBase != "" {
|
||||
result = append(result, ModelConfig{
|
||||
ModelName: "moonshot",
|
||||
Model: "moonshot/kimi",
|
||||
APIKey: p.Moonshot.APIKey,
|
||||
APIBase: p.Moonshot.APIBase,
|
||||
Proxy: p.Moonshot.Proxy,
|
||||
})
|
||||
}
|
||||
|
||||
// ShengSuanYun
|
||||
if p.ShengSuanYun.APIKey != "" || p.ShengSuanYun.APIBase != "" {
|
||||
result = append(result, ModelConfig{
|
||||
ModelName: "shengsuanyun",
|
||||
Model: "openai/auto",
|
||||
APIKey: p.ShengSuanYun.APIKey,
|
||||
APIBase: p.ShengSuanYun.APIBase,
|
||||
Proxy: p.ShengSuanYun.Proxy,
|
||||
})
|
||||
}
|
||||
|
||||
// DeepSeek
|
||||
if p.DeepSeek.APIKey != "" || p.DeepSeek.APIBase != "" {
|
||||
result = append(result, ModelConfig{
|
||||
ModelName: "deepseek",
|
||||
Model: "openai/deepseek-chat",
|
||||
APIKey: p.DeepSeek.APIKey,
|
||||
APIBase: p.DeepSeek.APIBase,
|
||||
Proxy: p.DeepSeek.Proxy,
|
||||
})
|
||||
}
|
||||
|
||||
// Cerebras
|
||||
if p.Cerebras.APIKey != "" || p.Cerebras.APIBase != "" {
|
||||
result = append(result, ModelConfig{
|
||||
ModelName: "cerebras",
|
||||
Model: "cerebras/llama-3.3-70b",
|
||||
APIKey: p.Cerebras.APIKey,
|
||||
APIBase: p.Cerebras.APIBase,
|
||||
Proxy: p.Cerebras.Proxy,
|
||||
})
|
||||
}
|
||||
|
||||
// VolcEngine (Doubao)
|
||||
if p.VolcEngine.APIKey != "" || p.VolcEngine.APIBase != "" {
|
||||
result = append(result, ModelConfig{
|
||||
ModelName: "volcengine",
|
||||
Model: "openai/doubao-pro",
|
||||
APIKey: p.VolcEngine.APIKey,
|
||||
APIBase: p.VolcEngine.APIBase,
|
||||
Proxy: p.VolcEngine.Proxy,
|
||||
})
|
||||
}
|
||||
|
||||
// GitHub Copilot
|
||||
if p.GitHubCopilot.APIKey != "" || p.GitHubCopilot.APIBase != "" || p.GitHubCopilot.ConnectMode != "" {
|
||||
result = append(result, ModelConfig{
|
||||
ModelName: "github-copilot",
|
||||
Model: "github-copilot/gpt-4o",
|
||||
APIBase: p.GitHubCopilot.APIBase,
|
||||
ConnectMode: p.GitHubCopilot.ConnectMode,
|
||||
})
|
||||
}
|
||||
|
||||
// Antigravity
|
||||
if p.Antigravity.APIKey != "" || p.Antigravity.AuthMethod != "" {
|
||||
result = append(result, ModelConfig{
|
||||
ModelName: "antigravity",
|
||||
Model: "antigravity/gemini-2.0-flash",
|
||||
APIKey: p.Antigravity.APIKey,
|
||||
AuthMethod: p.Antigravity.AuthMethod,
|
||||
})
|
||||
}
|
||||
|
||||
// Qwen
|
||||
if p.Qwen.APIKey != "" || p.Qwen.APIBase != "" {
|
||||
result = append(result, ModelConfig{
|
||||
ModelName: "qwen",
|
||||
Model: "qwen/qwen-max",
|
||||
APIKey: p.Qwen.APIKey,
|
||||
APIBase: p.Qwen.APIBase,
|
||||
Proxy: p.Qwen.Proxy,
|
||||
})
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
@@ -0,0 +1,177 @@
|
||||
// PicoClaw - Ultra-lightweight personal AI agent
|
||||
// License: MIT
|
||||
//
|
||||
// Copyright (c) 2026 PicoClaw contributors
|
||||
|
||||
package config
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestConvertProvidersToModelList_OpenAI(t *testing.T) {
|
||||
cfg := &Config{
|
||||
Providers: ProvidersConfig{
|
||||
OpenAI: ProviderConfig{
|
||||
APIKey: "sk-test-key",
|
||||
APIBase: "https://custom.api.com/v1",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
result := ConvertProvidersToModelList(cfg)
|
||||
|
||||
if len(result) != 1 {
|
||||
t.Fatalf("len(result) = %d, want 1", len(result))
|
||||
}
|
||||
|
||||
if result[0].ModelName != "openai" {
|
||||
t.Errorf("ModelName = %q, want %q", result[0].ModelName, "openai")
|
||||
}
|
||||
if result[0].Model != "openai/gpt-4o" {
|
||||
t.Errorf("Model = %q, want %q", result[0].Model, "openai/gpt-4o")
|
||||
}
|
||||
if result[0].APIKey != "sk-test-key" {
|
||||
t.Errorf("APIKey = %q, want %q", result[0].APIKey, "sk-test-key")
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertProvidersToModelList_Anthropic(t *testing.T) {
|
||||
cfg := &Config{
|
||||
Providers: ProvidersConfig{
|
||||
Anthropic: ProviderConfig{
|
||||
APIKey: "ant-key",
|
||||
APIBase: "https://custom.anthropic.com",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
result := ConvertProvidersToModelList(cfg)
|
||||
|
||||
if len(result) != 1 {
|
||||
t.Fatalf("len(result) = %d, want 1", len(result))
|
||||
}
|
||||
|
||||
if result[0].ModelName != "anthropic" {
|
||||
t.Errorf("ModelName = %q, want %q", result[0].ModelName, "anthropic")
|
||||
}
|
||||
if result[0].Model != "anthropic/claude-3-sonnet" {
|
||||
t.Errorf("Model = %q, want %q", result[0].Model, "anthropic/claude-3-sonnet")
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertProvidersToModelList_Multiple(t *testing.T) {
|
||||
cfg := &Config{
|
||||
Providers: ProvidersConfig{
|
||||
OpenAI: ProviderConfig{APIKey: "openai-key"},
|
||||
Groq: ProviderConfig{APIKey: "groq-key"},
|
||||
Zhipu: ProviderConfig{APIKey: "zhipu-key"},
|
||||
},
|
||||
}
|
||||
|
||||
result := ConvertProvidersToModelList(cfg)
|
||||
|
||||
if len(result) != 3 {
|
||||
t.Fatalf("len(result) = %d, want 3", len(result))
|
||||
}
|
||||
|
||||
// Check that all providers are present
|
||||
found := make(map[string]bool)
|
||||
for _, mc := range result {
|
||||
found[mc.ModelName] = true
|
||||
}
|
||||
|
||||
for _, name := range []string{"openai", "groq", "zhipu"} {
|
||||
if !found[name] {
|
||||
t.Errorf("Missing provider %q in result", name)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertProvidersToModelList_Empty(t *testing.T) {
|
||||
cfg := &Config{
|
||||
Providers: ProvidersConfig{},
|
||||
}
|
||||
|
||||
result := ConvertProvidersToModelList(cfg)
|
||||
|
||||
if len(result) != 0 {
|
||||
t.Errorf("len(result) = %d, want 0", len(result))
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertProvidersToModelList_Nil(t *testing.T) {
|
||||
result := ConvertProvidersToModelList(nil)
|
||||
|
||||
if result != nil {
|
||||
t.Errorf("result = %v, want nil", result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertProvidersToModelList_AllProviders(t *testing.T) {
|
||||
cfg := &Config{
|
||||
Providers: ProvidersConfig{
|
||||
OpenAI: ProviderConfig{APIKey: "key1"},
|
||||
Anthropic: ProviderConfig{APIKey: "key2"},
|
||||
OpenRouter: ProviderConfig{APIKey: "key3"},
|
||||
Groq: ProviderConfig{APIKey: "key4"},
|
||||
Zhipu: ProviderConfig{APIKey: "key5"},
|
||||
VLLM: ProviderConfig{APIKey: "key6"},
|
||||
Gemini: ProviderConfig{APIKey: "key7"},
|
||||
Nvidia: ProviderConfig{APIKey: "key8"},
|
||||
Ollama: ProviderConfig{APIKey: "key9"},
|
||||
Moonshot: ProviderConfig{APIKey: "key10"},
|
||||
ShengSuanYun: ProviderConfig{APIKey: "key11"},
|
||||
DeepSeek: ProviderConfig{APIKey: "key12"},
|
||||
Cerebras: ProviderConfig{APIKey: "key13"},
|
||||
VolcEngine: ProviderConfig{APIKey: "key14"},
|
||||
GitHubCopilot: ProviderConfig{ConnectMode: "grpc"},
|
||||
Antigravity: ProviderConfig{AuthMethod: "oauth"},
|
||||
Qwen: ProviderConfig{APIKey: "key17"},
|
||||
},
|
||||
}
|
||||
|
||||
result := ConvertProvidersToModelList(cfg)
|
||||
|
||||
// All 17 providers should be converted
|
||||
if len(result) != 17 {
|
||||
t.Errorf("len(result) = %d, want 17", len(result))
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertProvidersToModelList_Proxy(t *testing.T) {
|
||||
cfg := &Config{
|
||||
Providers: ProvidersConfig{
|
||||
OpenAI: ProviderConfig{
|
||||
APIKey: "key",
|
||||
Proxy: "http://proxy:8080",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
result := ConvertProvidersToModelList(cfg)
|
||||
|
||||
if len(result) != 1 {
|
||||
t.Fatalf("len(result) = %d, want 1", len(result))
|
||||
}
|
||||
|
||||
if result[0].Proxy != "http://proxy:8080" {
|
||||
t.Errorf("Proxy = %q, want %q", result[0].Proxy, "http://proxy:8080")
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertProvidersToModelList_AuthMethod(t *testing.T) {
|
||||
cfg := &Config{
|
||||
Providers: ProvidersConfig{
|
||||
OpenAI: ProviderConfig{
|
||||
AuthMethod: "oauth",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
result := ConvertProvidersToModelList(cfg)
|
||||
|
||||
if len(result) != 0 {
|
||||
t.Errorf("len(result) = %d, want 0 (AuthMethod alone should not create entry)", len(result))
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,204 @@
|
||||
// PicoClaw - Ultra-lightweight personal AI agent
|
||||
// License: MIT
|
||||
//
|
||||
// Copyright (c) 2026 PicoClaw contributors
|
||||
|
||||
package config
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestGetModelConfig_Found(t *testing.T) {
|
||||
cfg := &Config{
|
||||
ModelList: []ModelConfig{
|
||||
{ModelName: "test-model", Model: "openai/gpt-4o", APIKey: "key1"},
|
||||
{ModelName: "other-model", Model: "anthropic/claude", APIKey: "key2"},
|
||||
},
|
||||
}
|
||||
|
||||
result, err := cfg.GetModelConfig("test-model")
|
||||
if err != nil {
|
||||
t.Fatalf("GetModelConfig() error = %v", err)
|
||||
}
|
||||
if result.Model != "openai/gpt-4o" {
|
||||
t.Errorf("Model = %q, want %q", result.Model, "openai/gpt-4o")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetModelConfig_NotFound(t *testing.T) {
|
||||
cfg := &Config{
|
||||
ModelList: []ModelConfig{
|
||||
{ModelName: "test-model", Model: "openai/gpt-4o", APIKey: "key1"},
|
||||
},
|
||||
}
|
||||
|
||||
_, err := cfg.GetModelConfig("nonexistent")
|
||||
if err == nil {
|
||||
t.Fatal("GetModelConfig() expected error for nonexistent model")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetModelConfig_EmptyList(t *testing.T) {
|
||||
cfg := &Config{
|
||||
ModelList: []ModelConfig{},
|
||||
}
|
||||
|
||||
_, err := cfg.GetModelConfig("any-model")
|
||||
if err == nil {
|
||||
t.Fatal("GetModelConfig() expected error for empty model list")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetModelConfig_RoundRobin(t *testing.T) {
|
||||
cfg := &Config{
|
||||
ModelList: []ModelConfig{
|
||||
{ModelName: "lb-model", Model: "openai/gpt-4o-1", APIKey: "key1"},
|
||||
{ModelName: "lb-model", Model: "openai/gpt-4o-2", APIKey: "key2"},
|
||||
{ModelName: "lb-model", Model: "openai/gpt-4o-3", APIKey: "key3"},
|
||||
},
|
||||
}
|
||||
|
||||
// Test round-robin distribution
|
||||
results := make(map[string]int)
|
||||
for i := 0; i < 30; i++ {
|
||||
result, err := cfg.GetModelConfig("lb-model")
|
||||
if err != nil {
|
||||
t.Fatalf("GetModelConfig() error = %v", err)
|
||||
}
|
||||
results[result.Model]++
|
||||
}
|
||||
|
||||
// Each model should appear roughly 10 times (30 calls / 3 models)
|
||||
for model, count := range results {
|
||||
if count < 5 || count > 15 {
|
||||
t.Errorf("Model %s appeared %d times, expected ~10", model, count)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetModelConfig_Concurrent(t *testing.T) {
|
||||
cfg := &Config{
|
||||
ModelList: []ModelConfig{
|
||||
{ModelName: "concurrent-model", Model: "openai/gpt-4o-1", APIKey: "key1"},
|
||||
{ModelName: "concurrent-model", Model: "openai/gpt-4o-2", APIKey: "key2"},
|
||||
},
|
||||
}
|
||||
|
||||
const goroutines = 100
|
||||
const iterations = 10
|
||||
|
||||
var wg sync.WaitGroup
|
||||
errors := make(chan error, goroutines*iterations)
|
||||
|
||||
for i := 0; i < goroutines; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
for j := 0; j < iterations; j++ {
|
||||
_, err := cfg.GetModelConfig("concurrent-model")
|
||||
if err != nil {
|
||||
errors <- err
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
close(errors)
|
||||
|
||||
for err := range errors {
|
||||
t.Errorf("Concurrent GetModelConfig() error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestModelConfig_Validate(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
config ModelConfig
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "valid config",
|
||||
config: ModelConfig{
|
||||
ModelName: "test",
|
||||
Model: "openai/gpt-4o",
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "missing model_name",
|
||||
config: ModelConfig{
|
||||
Model: "openai/gpt-4o",
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "missing model",
|
||||
config: ModelConfig{
|
||||
ModelName: "test",
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "empty config",
|
||||
config: ModelConfig{},
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := tt.config.Validate()
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("Validate() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfig_ValidateModelList(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
config *Config
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "valid list",
|
||||
config: &Config{
|
||||
ModelList: []ModelConfig{
|
||||
{ModelName: "test1", Model: "openai/gpt-4o"},
|
||||
{ModelName: "test2", Model: "anthropic/claude"},
|
||||
},
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "invalid entry",
|
||||
config: &Config{
|
||||
ModelList: []ModelConfig{
|
||||
{ModelName: "test1", Model: "openai/gpt-4o"},
|
||||
{ModelName: "", Model: "anthropic/claude"}, // missing model_name
|
||||
},
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "empty list",
|
||||
config: &Config{
|
||||
ModelList: []ModelConfig{},
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := tt.config.ValidateModelList()
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("ValidateModelList() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -180,8 +180,8 @@ func TestConvertConfig(t *testing.T) {
|
||||
t.Run("unsupported provider warning", func(t *testing.T) {
|
||||
data := map[string]interface{}{
|
||||
"providers": map[string]interface{}{
|
||||
"deepseek": map[string]interface{}{
|
||||
"api_key": "sk-deep-test",
|
||||
"unknown_provider": map[string]interface{}{
|
||||
"api_key": "sk-test",
|
||||
},
|
||||
},
|
||||
}
|
||||
@@ -193,7 +193,7 @@ func TestConvertConfig(t *testing.T) {
|
||||
if len(warnings) != 1 {
|
||||
t.Fatalf("expected 1 warning, got %d", len(warnings))
|
||||
}
|
||||
if warnings[0] != "Provider 'deepseek' not supported in PicoClaw, skipping" {
|
||||
if warnings[0] != "Provider 'unknown_provider' not supported in PicoClaw, skipping" {
|
||||
t.Errorf("unexpected warning: %s", warnings[0])
|
||||
}
|
||||
})
|
||||
|
||||
@@ -419,7 +419,7 @@ func TestCreateProvider_ClaudeCli(t *testing.T) {
|
||||
cfg.Agents.Defaults.Provider = "claude-cli"
|
||||
cfg.Agents.Defaults.Workspace = "/test/ws"
|
||||
|
||||
provider, err := CreateProvider(cfg)
|
||||
provider, _, err := CreateProvider(cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("CreateProvider(claude-cli) error = %v", err)
|
||||
}
|
||||
@@ -437,7 +437,7 @@ func TestCreateProvider_ClaudeCode(t *testing.T) {
|
||||
cfg := config.DefaultConfig()
|
||||
cfg.Agents.Defaults.Provider = "claude-code"
|
||||
|
||||
provider, err := CreateProvider(cfg)
|
||||
provider, _, err := CreateProvider(cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("CreateProvider(claude-code) error = %v", err)
|
||||
}
|
||||
@@ -450,7 +450,7 @@ func TestCreateProvider_ClaudeCodec(t *testing.T) {
|
||||
cfg := config.DefaultConfig()
|
||||
cfg.Agents.Defaults.Provider = "claudecode"
|
||||
|
||||
provider, err := CreateProvider(cfg)
|
||||
provider, _, err := CreateProvider(cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("CreateProvider(claudecode) error = %v", err)
|
||||
}
|
||||
@@ -464,7 +464,7 @@ func TestCreateProvider_ClaudeCliDefaultWorkspace(t *testing.T) {
|
||||
cfg.Agents.Defaults.Provider = "claude-cli"
|
||||
cfg.Agents.Defaults.Workspace = ""
|
||||
|
||||
provider, err := CreateProvider(cfg)
|
||||
provider, _, err := CreateProvider(cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("CreateProvider error = %v", err)
|
||||
}
|
||||
|
||||
@@ -32,13 +32,14 @@ func ExtractProtocol(model string) (protocol, modelID string) {
|
||||
// CreateProviderFromConfig creates a provider based on the ModelConfig.
|
||||
// It uses the protocol prefix in the Model field to determine which provider to create.
|
||||
// Supported protocols: openai, anthropic, antigravity, claude-cli, codex-cli, github-copilot
|
||||
func CreateProviderFromConfig(cfg *config.ModelConfig) (LLMProvider, error) {
|
||||
// Returns the provider, the model ID (without protocol prefix), and any error.
|
||||
func CreateProviderFromConfig(cfg *config.ModelConfig) (LLMProvider, string, error) {
|
||||
if cfg == nil {
|
||||
return nil, fmt.Errorf("config is nil")
|
||||
return nil, "", fmt.Errorf("config is nil")
|
||||
}
|
||||
|
||||
if cfg.Model == "" {
|
||||
return nil, fmt.Errorf("model is required")
|
||||
return nil, "", fmt.Errorf("model is required")
|
||||
}
|
||||
|
||||
protocol, modelID := ExtractProtocol(cfg.Model)
|
||||
@@ -49,36 +50,36 @@ func CreateProviderFromConfig(cfg *config.ModelConfig) (LLMProvider, error) {
|
||||
"volcengine", "vllm", "qwen":
|
||||
// All OpenAI-compatible HTTP providers
|
||||
if cfg.APIKey == "" && cfg.APIBase == "" {
|
||||
return nil, fmt.Errorf("api_key or api_base is required for HTTP-based protocol %q", protocol)
|
||||
return nil, "", fmt.Errorf("api_key or api_base is required for HTTP-based protocol %q", protocol)
|
||||
}
|
||||
apiBase := cfg.APIBase
|
||||
if apiBase == "" {
|
||||
apiBase = getDefaultAPIBase(protocol)
|
||||
}
|
||||
return NewHTTPProvider(cfg.APIKey, apiBase, cfg.Proxy), nil
|
||||
return NewHTTPProvider(cfg.APIKey, apiBase, cfg.Proxy), modelID, nil
|
||||
|
||||
case "anthropic":
|
||||
if cfg.AuthMethod == "oauth" || cfg.AuthMethod == "token" {
|
||||
// Use Claude SDK with token
|
||||
return NewClaudeProvider(cfg.APIKey), nil
|
||||
return NewClaudeProvider(cfg.APIKey), modelID, nil
|
||||
}
|
||||
// Use HTTP API
|
||||
apiBase := cfg.APIBase
|
||||
if apiBase == "" {
|
||||
apiBase = "https://api.anthropic.com/v1"
|
||||
}
|
||||
return NewHTTPProvider(cfg.APIKey, apiBase, cfg.Proxy), nil
|
||||
return NewHTTPProvider(cfg.APIKey, apiBase, cfg.Proxy), modelID, nil
|
||||
|
||||
case "antigravity":
|
||||
return NewAntigravityProvider(), nil
|
||||
return NewAntigravityProvider(), modelID, nil
|
||||
|
||||
case "claude-cli", "claudecli":
|
||||
workspace := "."
|
||||
return NewClaudeCliProvider(workspace), nil
|
||||
return NewClaudeCliProvider(workspace), modelID, nil
|
||||
|
||||
case "codex-cli", "codexcli":
|
||||
workspace := "."
|
||||
return NewCodexCliProvider(workspace), nil
|
||||
return NewCodexCliProvider(workspace), modelID, nil
|
||||
|
||||
case "github-copilot", "copilot":
|
||||
apiBase := cfg.APIBase
|
||||
@@ -89,10 +90,14 @@ func CreateProviderFromConfig(cfg *config.ModelConfig) (LLMProvider, error) {
|
||||
if connectMode == "" {
|
||||
connectMode = "grpc"
|
||||
}
|
||||
return NewGitHubCopilotProvider(apiBase, connectMode, modelID)
|
||||
provider, err := NewGitHubCopilotProvider(apiBase, connectMode, modelID)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
return provider, modelID, nil
|
||||
|
||||
default:
|
||||
return nil, fmt.Errorf("unknown protocol %q in model %q", protocol, cfg.Model)
|
||||
return nil, "", fmt.Errorf("unknown protocol %q in model %q", protocol, cfg.Model)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -0,0 +1,250 @@
|
||||
// PicoClaw - Ultra-lightweight personal AI agent
|
||||
// License: MIT
|
||||
//
|
||||
// Copyright (c) 2026 PicoClaw contributors
|
||||
|
||||
package providers
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
)
|
||||
|
||||
func TestExtractProtocol(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
model string
|
||||
wantProtocol string
|
||||
wantModelID string
|
||||
}{
|
||||
{
|
||||
name: "openai with prefix",
|
||||
model: "openai/gpt-4o",
|
||||
wantProtocol: "openai",
|
||||
wantModelID: "gpt-4o",
|
||||
},
|
||||
{
|
||||
name: "anthropic with prefix",
|
||||
model: "anthropic/claude-3-sonnet",
|
||||
wantProtocol: "anthropic",
|
||||
wantModelID: "claude-3-sonnet",
|
||||
},
|
||||
{
|
||||
name: "no prefix - defaults to openai",
|
||||
model: "gpt-4o",
|
||||
wantProtocol: "openai",
|
||||
wantModelID: "gpt-4o",
|
||||
},
|
||||
{
|
||||
name: "groq with prefix",
|
||||
model: "groq/llama-3.1-70b",
|
||||
wantProtocol: "groq",
|
||||
wantModelID: "llama-3.1-70b",
|
||||
},
|
||||
{
|
||||
name: "empty string",
|
||||
model: "",
|
||||
wantProtocol: "openai",
|
||||
wantModelID: "",
|
||||
},
|
||||
{
|
||||
name: "with whitespace",
|
||||
model: " openai/gpt-4 ",
|
||||
wantProtocol: "openai",
|
||||
wantModelID: "gpt-4",
|
||||
},
|
||||
{
|
||||
name: "multiple slashes",
|
||||
model: "nvidia/meta/llama-3.1-8b",
|
||||
wantProtocol: "nvidia",
|
||||
wantModelID: "meta/llama-3.1-8b",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
protocol, modelID := ExtractProtocol(tt.model)
|
||||
if protocol != tt.wantProtocol {
|
||||
t.Errorf("ExtractProtocol(%q) protocol = %q, want %q", tt.model, protocol, tt.wantProtocol)
|
||||
}
|
||||
if modelID != tt.wantModelID {
|
||||
t.Errorf("ExtractProtocol(%q) modelID = %q, want %q", tt.model, modelID, tt.wantModelID)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateProviderFromConfig_OpenAI(t *testing.T) {
|
||||
cfg := &config.ModelConfig{
|
||||
ModelName: "test-openai",
|
||||
Model: "openai/gpt-4o",
|
||||
APIKey: "test-key",
|
||||
APIBase: "https://api.example.com/v1",
|
||||
}
|
||||
|
||||
provider, modelID, err := CreateProviderFromConfig(cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("CreateProviderFromConfig() error = %v", err)
|
||||
}
|
||||
if provider == nil {
|
||||
t.Fatal("CreateProviderFromConfig() returned nil provider")
|
||||
}
|
||||
if modelID != "gpt-4o" {
|
||||
t.Errorf("modelID = %q, want %q", modelID, "gpt-4o")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateProviderFromConfig_DefaultAPIBase(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
protocol string
|
||||
wantBase string
|
||||
}{
|
||||
{"openai", "openai", "https://api.openai.com/v1"},
|
||||
{"groq", "groq", "https://api.groq.com/openai/v1"},
|
||||
{"openrouter", "openrouter", "https://openrouter.ai/api/v1"},
|
||||
{"cerebras", "cerebras", "https://api.cerebras.ai/v1"},
|
||||
{"qwen", "qwen", "https://dashscope.aliyuncs.com/compatible-mode/v1"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
cfg := &config.ModelConfig{
|
||||
ModelName: "test-" + tt.protocol,
|
||||
Model: tt.protocol + "/test-model",
|
||||
APIKey: "test-key",
|
||||
}
|
||||
|
||||
provider, _, err := CreateProviderFromConfig(cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("CreateProviderFromConfig() error = %v", err)
|
||||
}
|
||||
|
||||
httpProvider, ok := provider.(*HTTPProvider)
|
||||
if !ok {
|
||||
t.Fatalf("expected *HTTPProvider, got %T", provider)
|
||||
}
|
||||
if httpProvider.apiBase != tt.wantBase {
|
||||
t.Errorf("apiBase = %q, want %q", httpProvider.apiBase, tt.wantBase)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateProviderFromConfig_Anthropic(t *testing.T) {
|
||||
cfg := &config.ModelConfig{
|
||||
ModelName: "test-anthropic",
|
||||
Model: "anthropic/claude-3-sonnet",
|
||||
APIKey: "test-key",
|
||||
}
|
||||
|
||||
provider, modelID, err := CreateProviderFromConfig(cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("CreateProviderFromConfig() error = %v", err)
|
||||
}
|
||||
if provider == nil {
|
||||
t.Fatal("CreateProviderFromConfig() returned nil provider")
|
||||
}
|
||||
if modelID != "claude-3-sonnet" {
|
||||
t.Errorf("modelID = %q, want %q", modelID, "claude-3-sonnet")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateProviderFromConfig_Antigravity(t *testing.T) {
|
||||
cfg := &config.ModelConfig{
|
||||
ModelName: "test-antigravity",
|
||||
Model: "antigravity/gemini-2.0-flash",
|
||||
}
|
||||
|
||||
provider, modelID, err := CreateProviderFromConfig(cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("CreateProviderFromConfig() error = %v", err)
|
||||
}
|
||||
if provider == nil {
|
||||
t.Fatal("CreateProviderFromConfig() returned nil provider")
|
||||
}
|
||||
if modelID != "gemini-2.0-flash" {
|
||||
t.Errorf("modelID = %q, want %q", modelID, "gemini-2.0-flash")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateProviderFromConfig_ClaudeCLI(t *testing.T) {
|
||||
cfg := &config.ModelConfig{
|
||||
ModelName: "test-claude-cli",
|
||||
Model: "claude-cli/claude-sonnet-4-20250514",
|
||||
}
|
||||
|
||||
provider, modelID, err := CreateProviderFromConfig(cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("CreateProviderFromConfig() error = %v", err)
|
||||
}
|
||||
if provider == nil {
|
||||
t.Fatal("CreateProviderFromConfig() returned nil provider")
|
||||
}
|
||||
if modelID != "claude-sonnet-4-20250514" {
|
||||
t.Errorf("modelID = %q, want %q", modelID, "claude-sonnet-4-20250514")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateProviderFromConfig_CodexCLI(t *testing.T) {
|
||||
cfg := &config.ModelConfig{
|
||||
ModelName: "test-codex-cli",
|
||||
Model: "codex-cli/codex",
|
||||
}
|
||||
|
||||
provider, modelID, err := CreateProviderFromConfig(cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("CreateProviderFromConfig() error = %v", err)
|
||||
}
|
||||
if provider == nil {
|
||||
t.Fatal("CreateProviderFromConfig() returned nil provider")
|
||||
}
|
||||
if modelID != "codex" {
|
||||
t.Errorf("modelID = %q, want %q", modelID, "codex")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateProviderFromConfig_MissingAPIKey(t *testing.T) {
|
||||
cfg := &config.ModelConfig{
|
||||
ModelName: "test-no-key",
|
||||
Model: "openai/gpt-4o",
|
||||
}
|
||||
|
||||
_, _, err := CreateProviderFromConfig(cfg)
|
||||
if err == nil {
|
||||
t.Fatal("CreateProviderFromConfig() expected error for missing API key")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateProviderFromConfig_UnknownProtocol(t *testing.T) {
|
||||
cfg := &config.ModelConfig{
|
||||
ModelName: "test-unknown",
|
||||
Model: "unknown-protocol/model",
|
||||
APIKey: "test-key",
|
||||
}
|
||||
|
||||
_, _, err := CreateProviderFromConfig(cfg)
|
||||
if err == nil {
|
||||
t.Fatal("CreateProviderFromConfig() expected error for unknown protocol")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateProviderFromConfig_NilConfig(t *testing.T) {
|
||||
_, _, err := CreateProviderFromConfig(nil)
|
||||
if err == nil {
|
||||
t.Fatal("CreateProviderFromConfig(nil) expected error")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateProviderFromConfig_EmptyModel(t *testing.T) {
|
||||
cfg := &config.ModelConfig{
|
||||
ModelName: "test-empty",
|
||||
Model: "",
|
||||
}
|
||||
|
||||
_, _, err := CreateProviderFromConfig(cfg)
|
||||
if err == nil {
|
||||
t.Fatal("CreateProviderFromConfig() expected error for empty model")
|
||||
}
|
||||
}
|
||||
@@ -16,9 +16,6 @@ import (
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/auth"
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
)
|
||||
|
||||
type HTTPProvider struct {
|
||||
@@ -161,13 +158,15 @@ func (p *HTTPProvider) parseResponse(body []byte) (*LLMResponse, error) {
|
||||
arguments := make(map[string]interface{})
|
||||
name := ""
|
||||
thoughtSignature := ""
|
||||
argsStr := ""
|
||||
|
||||
if tc.Function != nil {
|
||||
name = tc.Function.Name
|
||||
thoughtSignature = tc.Function.ThoughtSignature
|
||||
if tc.Function.Arguments != "" {
|
||||
if err := json.Unmarshal([]byte(tc.Function.Arguments), &arguments); err != nil {
|
||||
arguments["raw"] = tc.Function.Arguments
|
||||
argsStr = tc.Function.Arguments
|
||||
if argsStr != "" {
|
||||
if err := json.Unmarshal([]byte(argsStr), &arguments); err != nil {
|
||||
arguments["raw"] = argsStr
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -177,7 +176,7 @@ func (p *HTTPProvider) parseResponse(body []byte) (*LLMResponse, error) {
|
||||
Type: tc.Type,
|
||||
Function: &FunctionCall{
|
||||
Name: name,
|
||||
Arguments: tc.Function.Arguments,
|
||||
Arguments: argsStr,
|
||||
ThoughtSignature: thoughtSignature,
|
||||
},
|
||||
Name: name,
|
||||
@@ -196,328 +195,3 @@ func (p *HTTPProvider) parseResponse(body []byte) (*LLMResponse, error) {
|
||||
func (p *HTTPProvider) GetDefaultModel() string {
|
||||
return ""
|
||||
}
|
||||
|
||||
func createClaudeAuthProvider() (LLMProvider, error) {
|
||||
cred, err := auth.GetCredential("anthropic")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("loading auth credentials: %w", err)
|
||||
}
|
||||
if cred == nil {
|
||||
return nil, fmt.Errorf("no credentials for anthropic. Run: picoclaw auth login --provider anthropic")
|
||||
}
|
||||
return NewClaudeProviderWithTokenSource(cred.AccessToken, createClaudeTokenSource()), nil
|
||||
}
|
||||
|
||||
func createCodexAuthProvider() (LLMProvider, error) {
|
||||
cred, err := auth.GetCredential("openai")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("loading auth credentials: %w", err)
|
||||
}
|
||||
if cred == nil {
|
||||
return nil, fmt.Errorf("no credentials for openai. Run: picoclaw auth login --provider openai")
|
||||
}
|
||||
return NewCodexProviderWithTokenSource(cred.AccessToken, cred.AccountID, createCodexTokenSource()), nil
|
||||
}
|
||||
|
||||
func CreateProvider(cfg *config.Config) (LLMProvider, error) {
|
||||
model := cfg.Agents.Defaults.Model
|
||||
|
||||
// First, try to use model_list configuration
|
||||
if len(cfg.ModelList) > 0 {
|
||||
// Try to get config by model name first
|
||||
modelCfg, err := cfg.GetModelConfig(model)
|
||||
if err == nil {
|
||||
// Found in model_list, use factory to create provider
|
||||
provider, err := CreateProviderFromConfig(modelCfg)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create provider from model_list: %w", err)
|
||||
}
|
||||
return provider, nil
|
||||
}
|
||||
// Model not found in model_list, fall through to providers config
|
||||
}
|
||||
|
||||
// Log deprecation warning if using old providers config
|
||||
if cfg.HasProvidersConfig() && len(cfg.ModelList) == 0 {
|
||||
fmt.Println("WARNING: providers config is deprecated, please migrate to model_list")
|
||||
}
|
||||
|
||||
providerName := strings.ToLower(cfg.Agents.Defaults.Provider)
|
||||
|
||||
var apiKey, apiBase, proxy string
|
||||
|
||||
lowerModel := strings.ToLower(model)
|
||||
|
||||
// First, try to use explicitly configured provider
|
||||
if providerName != "" {
|
||||
switch providerName {
|
||||
case "groq":
|
||||
if cfg.Providers.Groq.APIKey != "" {
|
||||
apiKey = cfg.Providers.Groq.APIKey
|
||||
apiBase = cfg.Providers.Groq.APIBase
|
||||
if apiBase == "" {
|
||||
apiBase = "https://api.groq.com/openai/v1"
|
||||
}
|
||||
}
|
||||
case "openai", "gpt":
|
||||
if cfg.Providers.OpenAI.APIKey != "" || cfg.Providers.OpenAI.AuthMethod != "" {
|
||||
if cfg.Providers.OpenAI.AuthMethod == "codex-cli" {
|
||||
return NewCodexProviderWithTokenSource("", "", CreateCodexCliTokenSource()), nil
|
||||
}
|
||||
if cfg.Providers.OpenAI.AuthMethod == "oauth" || cfg.Providers.OpenAI.AuthMethod == "token" {
|
||||
return createCodexAuthProvider()
|
||||
}
|
||||
apiKey = cfg.Providers.OpenAI.APIKey
|
||||
apiBase = cfg.Providers.OpenAI.APIBase
|
||||
if apiBase == "" {
|
||||
apiBase = "https://api.openai.com/v1"
|
||||
}
|
||||
}
|
||||
case "anthropic", "claude":
|
||||
if cfg.Providers.Anthropic.APIKey != "" || cfg.Providers.Anthropic.AuthMethod != "" {
|
||||
if cfg.Providers.Anthropic.AuthMethod == "oauth" || cfg.Providers.Anthropic.AuthMethod == "token" {
|
||||
return createClaudeAuthProvider()
|
||||
}
|
||||
apiKey = cfg.Providers.Anthropic.APIKey
|
||||
apiBase = cfg.Providers.Anthropic.APIBase
|
||||
if apiBase == "" {
|
||||
apiBase = "https://api.anthropic.com/v1"
|
||||
}
|
||||
}
|
||||
case "openrouter":
|
||||
if cfg.Providers.OpenRouter.APIKey != "" {
|
||||
apiKey = cfg.Providers.OpenRouter.APIKey
|
||||
if cfg.Providers.OpenRouter.APIBase != "" {
|
||||
apiBase = cfg.Providers.OpenRouter.APIBase
|
||||
} else {
|
||||
apiBase = "https://openrouter.ai/api/v1"
|
||||
}
|
||||
}
|
||||
case "zhipu", "glm":
|
||||
if cfg.Providers.Zhipu.APIKey != "" {
|
||||
apiKey = cfg.Providers.Zhipu.APIKey
|
||||
apiBase = cfg.Providers.Zhipu.APIBase
|
||||
if apiBase == "" {
|
||||
apiBase = "https://open.bigmodel.cn/api/paas/v4"
|
||||
}
|
||||
}
|
||||
case "gemini", "google":
|
||||
if cfg.Providers.Gemini.APIKey != "" {
|
||||
apiKey = cfg.Providers.Gemini.APIKey
|
||||
apiBase = cfg.Providers.Gemini.APIBase
|
||||
if apiBase == "" {
|
||||
apiBase = "https://generativelanguage.googleapis.com/v1beta"
|
||||
}
|
||||
}
|
||||
case "vllm":
|
||||
if cfg.Providers.VLLM.APIBase != "" {
|
||||
apiKey = cfg.Providers.VLLM.APIKey
|
||||
apiBase = cfg.Providers.VLLM.APIBase
|
||||
}
|
||||
case "shengsuanyun":
|
||||
if cfg.Providers.ShengSuanYun.APIKey != "" {
|
||||
apiKey = cfg.Providers.ShengSuanYun.APIKey
|
||||
apiBase = cfg.Providers.ShengSuanYun.APIBase
|
||||
if apiBase == "" {
|
||||
apiBase = "https://router.shengsuanyun.com/api/v1"
|
||||
}
|
||||
}
|
||||
case "claude-cli", "claudecode", "claude-code":
|
||||
workspace := cfg.WorkspacePath()
|
||||
if workspace == "" {
|
||||
workspace = "."
|
||||
}
|
||||
return NewClaudeCliProvider(workspace), nil
|
||||
case "codex-cli", "codex-code":
|
||||
workspace := cfg.WorkspacePath()
|
||||
if workspace == "" {
|
||||
workspace = "."
|
||||
}
|
||||
return NewCodexCliProvider(workspace), nil
|
||||
case "cerebras":
|
||||
if cfg.Providers.Cerebras.APIKey != "" {
|
||||
apiKey = cfg.Providers.Cerebras.APIKey
|
||||
apiBase = cfg.Providers.Cerebras.APIBase
|
||||
if apiBase == "" {
|
||||
apiBase = "https://api.cerebras.ai/v1"
|
||||
}
|
||||
}
|
||||
case "deepseek":
|
||||
if cfg.Providers.DeepSeek.APIKey != "" {
|
||||
apiKey = cfg.Providers.DeepSeek.APIKey
|
||||
apiBase = cfg.Providers.DeepSeek.APIBase
|
||||
if apiBase == "" {
|
||||
apiBase = "https://api.deepseek.com/v1"
|
||||
}
|
||||
if model != "deepseek-chat" && model != "deepseek-reasoner" {
|
||||
model = "deepseek-chat"
|
||||
}
|
||||
}
|
||||
case "qwen":
|
||||
if cfg.Providers.Qwen.APIKey != "" {
|
||||
apiKey = cfg.Providers.Qwen.APIKey
|
||||
apiBase = cfg.Providers.Qwen.APIBase
|
||||
if apiBase == "" {
|
||||
apiBase = "https://dashscope.aliyuncs.com/compatible-mode/v1"
|
||||
}
|
||||
}
|
||||
case "github_copilot", "copilot":
|
||||
if cfg.Providers.GitHubCopilot.APIBase != "" {
|
||||
apiBase = cfg.Providers.GitHubCopilot.APIBase
|
||||
} else {
|
||||
apiBase = "localhost:4321"
|
||||
}
|
||||
return NewGitHubCopilotProvider(apiBase, cfg.Providers.GitHubCopilot.ConnectMode, model)
|
||||
case "antigravity", "google-antigravity":
|
||||
return NewAntigravityProvider(), nil
|
||||
|
||||
case "volcengine", "doubao":
|
||||
if cfg.Providers.VolcEngine.APIKey != "" {
|
||||
apiKey = cfg.Providers.VolcEngine.APIKey
|
||||
apiBase = cfg.Providers.VolcEngine.APIBase
|
||||
if apiBase == "" {
|
||||
apiBase = "https://ark.cn-beijing.volces.com/api/v3"
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
// Fallback: detect provider from model name
|
||||
if apiKey == "" && apiBase == "" {
|
||||
switch {
|
||||
case (strings.Contains(lowerModel, "kimi") || strings.Contains(lowerModel, "moonshot") || strings.HasPrefix(model, "moonshot/")) && cfg.Providers.Moonshot.APIKey != "":
|
||||
apiKey = cfg.Providers.Moonshot.APIKey
|
||||
apiBase = cfg.Providers.Moonshot.APIBase
|
||||
proxy = cfg.Providers.Moonshot.Proxy
|
||||
if apiBase == "" {
|
||||
apiBase = "https://api.moonshot.cn/v1"
|
||||
}
|
||||
|
||||
case strings.HasPrefix(model, "openrouter/") || strings.HasPrefix(model, "anthropic/") || strings.HasPrefix(model, "openai/") || strings.HasPrefix(model, "meta-llama/") || strings.HasPrefix(model, "deepseek/") || strings.HasPrefix(model, "google/"):
|
||||
apiKey = cfg.Providers.OpenRouter.APIKey
|
||||
proxy = cfg.Providers.OpenRouter.Proxy
|
||||
if cfg.Providers.OpenRouter.APIBase != "" {
|
||||
apiBase = cfg.Providers.OpenRouter.APIBase
|
||||
} else {
|
||||
apiBase = "https://openrouter.ai/api/v1"
|
||||
}
|
||||
|
||||
case (strings.Contains(lowerModel, "claude") || strings.HasPrefix(model, "anthropic/")) && (cfg.Providers.Anthropic.APIKey != "" || cfg.Providers.Anthropic.AuthMethod != ""):
|
||||
if cfg.Providers.Anthropic.AuthMethod == "oauth" || cfg.Providers.Anthropic.AuthMethod == "token" {
|
||||
return createClaudeAuthProvider()
|
||||
}
|
||||
apiKey = cfg.Providers.Anthropic.APIKey
|
||||
apiBase = cfg.Providers.Anthropic.APIBase
|
||||
proxy = cfg.Providers.Anthropic.Proxy
|
||||
if apiBase == "" {
|
||||
apiBase = "https://api.anthropic.com/v1"
|
||||
}
|
||||
|
||||
case (strings.Contains(lowerModel, "gpt") || strings.HasPrefix(model, "openai/")) && (cfg.Providers.OpenAI.APIKey != "" || cfg.Providers.OpenAI.AuthMethod != ""):
|
||||
if cfg.Providers.OpenAI.AuthMethod == "oauth" || cfg.Providers.OpenAI.AuthMethod == "token" {
|
||||
return createCodexAuthProvider()
|
||||
}
|
||||
apiKey = cfg.Providers.OpenAI.APIKey
|
||||
apiBase = cfg.Providers.OpenAI.APIBase
|
||||
proxy = cfg.Providers.OpenAI.Proxy
|
||||
if apiBase == "" {
|
||||
apiBase = "https://api.openai.com/v1"
|
||||
}
|
||||
|
||||
case (strings.Contains(lowerModel, "gemini") || strings.HasPrefix(model, "google/")) && cfg.Providers.Gemini.APIKey != "":
|
||||
apiKey = cfg.Providers.Gemini.APIKey
|
||||
apiBase = cfg.Providers.Gemini.APIBase
|
||||
proxy = cfg.Providers.Gemini.Proxy
|
||||
if apiBase == "" {
|
||||
apiBase = "https://generativelanguage.googleapis.com/v1beta"
|
||||
}
|
||||
|
||||
case (strings.Contains(lowerModel, "glm") || strings.Contains(lowerModel, "zhipu") || strings.Contains(lowerModel, "zai")) && cfg.Providers.Zhipu.APIKey != "":
|
||||
apiKey = cfg.Providers.Zhipu.APIKey
|
||||
apiBase = cfg.Providers.Zhipu.APIBase
|
||||
proxy = cfg.Providers.Zhipu.Proxy
|
||||
if apiBase == "" {
|
||||
apiBase = "https://open.bigmodel.cn/api/paas/v4"
|
||||
}
|
||||
|
||||
case (strings.Contains(lowerModel, "groq") || strings.HasPrefix(model, "groq/")) && cfg.Providers.Groq.APIKey != "":
|
||||
apiKey = cfg.Providers.Groq.APIKey
|
||||
apiBase = cfg.Providers.Groq.APIBase
|
||||
proxy = cfg.Providers.Groq.Proxy
|
||||
if apiBase == "" {
|
||||
apiBase = "https://api.groq.com/openai/v1"
|
||||
}
|
||||
|
||||
case (strings.Contains(lowerModel, "qwen") || strings.HasPrefix(model, "qwen/")) && cfg.Providers.Qwen.APIKey != "":
|
||||
apiKey = cfg.Providers.Qwen.APIKey
|
||||
apiBase = cfg.Providers.Qwen.APIBase
|
||||
proxy = cfg.Providers.Qwen.Proxy
|
||||
if apiBase == "" {
|
||||
apiBase = "https://dashscope.aliyuncs.com/compatible-mode/v1"
|
||||
}
|
||||
|
||||
case (strings.Contains(lowerModel, "nvidia") || strings.HasPrefix(model, "nvidia/")) && cfg.Providers.Nvidia.APIKey != "":
|
||||
apiKey = cfg.Providers.Nvidia.APIKey
|
||||
apiBase = cfg.Providers.Nvidia.APIBase
|
||||
proxy = cfg.Providers.Nvidia.Proxy
|
||||
if apiBase == "" {
|
||||
apiBase = "https://integrate.api.nvidia.com/v1"
|
||||
}
|
||||
case (strings.Contains(lowerModel, "cerebras") || strings.HasPrefix(model, "cerebras/")) && cfg.Providers.Cerebras.APIKey != "":
|
||||
apiKey = cfg.Providers.Cerebras.APIKey
|
||||
apiBase = cfg.Providers.Cerebras.APIBase
|
||||
proxy = cfg.Providers.Cerebras.Proxy
|
||||
if apiBase == "" {
|
||||
apiBase = "https://api.cerebras.ai/v1"
|
||||
}
|
||||
|
||||
case (strings.Contains(lowerModel, "ollama") || strings.HasPrefix(model, "ollama/")) && cfg.Providers.Ollama.APIKey != "":
|
||||
fmt.Println("Ollama provider selected based on model name prefix")
|
||||
apiKey = cfg.Providers.Ollama.APIKey
|
||||
apiBase = cfg.Providers.Ollama.APIBase
|
||||
proxy = cfg.Providers.Ollama.Proxy
|
||||
if apiBase == "" {
|
||||
apiBase = "http://localhost:11434/v1"
|
||||
}
|
||||
fmt.Println("Ollama apiBase:", apiBase)
|
||||
|
||||
case (strings.Contains(lowerModel, "doubao") || strings.HasPrefix(lowerModel, "doubao") || strings.Contains(lowerModel, "volcengine")) && cfg.Providers.VolcEngine.APIKey != "":
|
||||
apiKey = cfg.Providers.VolcEngine.APIKey
|
||||
apiBase = cfg.Providers.VolcEngine.APIBase
|
||||
proxy = cfg.Providers.VolcEngine.Proxy
|
||||
if apiBase == "" {
|
||||
apiBase = "https://ark.cn-beijing.volces.com/api/v3"
|
||||
}
|
||||
|
||||
case cfg.Providers.VLLM.APIBase != "":
|
||||
apiKey = cfg.Providers.VLLM.APIKey
|
||||
apiBase = cfg.Providers.VLLM.APIBase
|
||||
proxy = cfg.Providers.VLLM.Proxy
|
||||
|
||||
default:
|
||||
if cfg.Providers.OpenRouter.APIKey != "" {
|
||||
apiKey = cfg.Providers.OpenRouter.APIKey
|
||||
proxy = cfg.Providers.OpenRouter.Proxy
|
||||
if cfg.Providers.OpenRouter.APIBase != "" {
|
||||
apiBase = cfg.Providers.OpenRouter.APIBase
|
||||
} else {
|
||||
apiBase = "https://openrouter.ai/api/v1"
|
||||
}
|
||||
} else {
|
||||
return nil, fmt.Errorf("no API key configured for model: %s", model)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if apiKey == "" && !strings.HasPrefix(model, "bedrock/") {
|
||||
return nil, fmt.Errorf("no API key configured for provider (model: %s)", model)
|
||||
}
|
||||
|
||||
if apiBase == "" {
|
||||
return nil, fmt.Errorf("no API base configured for provider (model: %s)", model)
|
||||
}
|
||||
|
||||
return NewHTTPProvider(apiKey, apiBase, proxy), nil
|
||||
}
|
||||
|
||||
@@ -0,0 +1,349 @@
|
||||
// PicoClaw - Ultra-lightweight personal AI agent
|
||||
// License: MIT
|
||||
//
|
||||
// Copyright (c) 2026 PicoClaw contributors
|
||||
|
||||
package providers
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/auth"
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
)
|
||||
|
||||
// createClaudeAuthProvider creates a Claude provider using OAuth credentials.
|
||||
func createClaudeAuthProvider() (LLMProvider, error) {
|
||||
cred, err := auth.GetCredential("anthropic")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("loading auth credentials: %w", err)
|
||||
}
|
||||
if cred == nil {
|
||||
return nil, fmt.Errorf("no credentials for anthropic. Run: picoclaw auth login --provider anthropic")
|
||||
}
|
||||
return NewClaudeProviderWithTokenSource(cred.AccessToken, createClaudeTokenSource()), nil
|
||||
}
|
||||
|
||||
// createCodexAuthProvider creates a Codex provider using OAuth credentials.
|
||||
func createCodexAuthProvider() (LLMProvider, error) {
|
||||
cred, err := auth.GetCredential("openai")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("loading auth credentials: %w", err)
|
||||
}
|
||||
if cred == nil {
|
||||
return nil, fmt.Errorf("no credentials for openai. Run: picoclaw auth login --provider openai")
|
||||
}
|
||||
return NewCodexProviderWithTokenSource(cred.AccessToken, cred.AccountID, createCodexTokenSource()), nil
|
||||
}
|
||||
|
||||
// CreateProvider creates a provider based on the configuration.
|
||||
// It supports both the new model_list configuration and the legacy providers configuration.
|
||||
// Returns the provider, the model ID to use, and any error.
|
||||
func CreateProvider(cfg *config.Config) (LLMProvider, string, error) {
|
||||
model := cfg.Agents.Defaults.Model
|
||||
|
||||
// First, try to use model_list configuration
|
||||
if len(cfg.ModelList) > 0 {
|
||||
// Try to get config by model name first
|
||||
modelCfg, err := cfg.GetModelConfig(model)
|
||||
if err == nil {
|
||||
// Found in model_list, use factory to create provider
|
||||
provider, modelID, err := CreateProviderFromConfig(modelCfg)
|
||||
if err != nil {
|
||||
return nil, "", fmt.Errorf("failed to create provider from model_list: %w", err)
|
||||
}
|
||||
return provider, modelID, nil
|
||||
}
|
||||
// Model not found in model_list, fall through to providers config
|
||||
}
|
||||
|
||||
// Log deprecation warning if using old providers config
|
||||
if cfg.HasProvidersConfig() && len(cfg.ModelList) == 0 {
|
||||
fmt.Println("WARNING: providers config is deprecated, please migrate to model_list")
|
||||
}
|
||||
|
||||
providerName := strings.ToLower(cfg.Agents.Defaults.Provider)
|
||||
|
||||
var apiKey, apiBase, proxy string
|
||||
|
||||
lowerModel := strings.ToLower(model)
|
||||
|
||||
// First, try to use explicitly configured provider
|
||||
if providerName != "" {
|
||||
switch providerName {
|
||||
case "groq":
|
||||
if cfg.Providers.Groq.APIKey != "" {
|
||||
apiKey = cfg.Providers.Groq.APIKey
|
||||
apiBase = cfg.Providers.Groq.APIBase
|
||||
if apiBase == "" {
|
||||
apiBase = "https://api.groq.com/openai/v1"
|
||||
}
|
||||
}
|
||||
case "openai", "gpt":
|
||||
if cfg.Providers.OpenAI.APIKey != "" || cfg.Providers.OpenAI.AuthMethod != "" {
|
||||
if cfg.Providers.OpenAI.AuthMethod == "codex-cli" {
|
||||
return NewCodexProviderWithTokenSource("", "", CreateCodexCliTokenSource()), model, nil
|
||||
}
|
||||
if cfg.Providers.OpenAI.AuthMethod == "oauth" || cfg.Providers.OpenAI.AuthMethod == "token" {
|
||||
provider, err := createCodexAuthProvider()
|
||||
return provider, model, err
|
||||
}
|
||||
apiKey = cfg.Providers.OpenAI.APIKey
|
||||
apiBase = cfg.Providers.OpenAI.APIBase
|
||||
if apiBase == "" {
|
||||
apiBase = "https://api.openai.com/v1"
|
||||
}
|
||||
}
|
||||
case "anthropic", "claude":
|
||||
if cfg.Providers.Anthropic.APIKey != "" || cfg.Providers.Anthropic.AuthMethod != "" {
|
||||
if cfg.Providers.Anthropic.AuthMethod == "oauth" || cfg.Providers.Anthropic.AuthMethod == "token" {
|
||||
provider, err := createClaudeAuthProvider()
|
||||
return provider, model, err
|
||||
}
|
||||
apiKey = cfg.Providers.Anthropic.APIKey
|
||||
apiBase = cfg.Providers.Anthropic.APIBase
|
||||
if apiBase == "" {
|
||||
apiBase = "https://api.anthropic.com/v1"
|
||||
}
|
||||
}
|
||||
case "openrouter":
|
||||
if cfg.Providers.OpenRouter.APIKey != "" {
|
||||
apiKey = cfg.Providers.OpenRouter.APIKey
|
||||
if cfg.Providers.OpenRouter.APIBase != "" {
|
||||
apiBase = cfg.Providers.OpenRouter.APIBase
|
||||
} else {
|
||||
apiBase = "https://openrouter.ai/api/v1"
|
||||
}
|
||||
}
|
||||
case "zhipu", "glm":
|
||||
if cfg.Providers.Zhipu.APIKey != "" {
|
||||
apiKey = cfg.Providers.Zhipu.APIKey
|
||||
apiBase = cfg.Providers.Zhipu.APIBase
|
||||
if apiBase == "" {
|
||||
apiBase = "https://open.bigmodel.cn/api/paas/v4"
|
||||
}
|
||||
}
|
||||
case "gemini", "google":
|
||||
if cfg.Providers.Gemini.APIKey != "" {
|
||||
apiKey = cfg.Providers.Gemini.APIKey
|
||||
apiBase = cfg.Providers.Gemini.APIBase
|
||||
if apiBase == "" {
|
||||
apiBase = "https://generativelanguage.googleapis.com/v1beta"
|
||||
}
|
||||
}
|
||||
case "vllm":
|
||||
if cfg.Providers.VLLM.APIBase != "" {
|
||||
apiKey = cfg.Providers.VLLM.APIKey
|
||||
apiBase = cfg.Providers.VLLM.APIBase
|
||||
}
|
||||
case "shengsuanyun":
|
||||
if cfg.Providers.ShengSuanYun.APIKey != "" {
|
||||
apiKey = cfg.Providers.ShengSuanYun.APIKey
|
||||
apiBase = cfg.Providers.ShengSuanYun.APIBase
|
||||
if apiBase == "" {
|
||||
apiBase = "https://router.shengsuanyun.com/api/v1"
|
||||
}
|
||||
}
|
||||
case "claude-cli", "claudecode", "claude-code":
|
||||
workspace := cfg.WorkspacePath()
|
||||
if workspace == "" {
|
||||
workspace = "."
|
||||
}
|
||||
return NewClaudeCliProvider(workspace), model, nil
|
||||
case "codex-cli", "codex-code":
|
||||
workspace := cfg.WorkspacePath()
|
||||
if workspace == "" {
|
||||
workspace = "."
|
||||
}
|
||||
return NewCodexCliProvider(workspace), model, nil
|
||||
case "cerebras":
|
||||
if cfg.Providers.Cerebras.APIKey != "" {
|
||||
apiKey = cfg.Providers.Cerebras.APIKey
|
||||
apiBase = cfg.Providers.Cerebras.APIBase
|
||||
if apiBase == "" {
|
||||
apiBase = "https://api.cerebras.ai/v1"
|
||||
}
|
||||
}
|
||||
case "deepseek":
|
||||
if cfg.Providers.DeepSeek.APIKey != "" {
|
||||
apiKey = cfg.Providers.DeepSeek.APIKey
|
||||
apiBase = cfg.Providers.DeepSeek.APIBase
|
||||
if apiBase == "" {
|
||||
apiBase = "https://api.deepseek.com/v1"
|
||||
}
|
||||
if model != "deepseek-chat" && model != "deepseek-reasoner" {
|
||||
model = "deepseek-chat"
|
||||
}
|
||||
}
|
||||
case "qwen":
|
||||
if cfg.Providers.Qwen.APIKey != "" {
|
||||
apiKey = cfg.Providers.Qwen.APIKey
|
||||
apiBase = cfg.Providers.Qwen.APIBase
|
||||
if apiBase == "" {
|
||||
apiBase = "https://dashscope.aliyuncs.com/compatible-mode/v1"
|
||||
}
|
||||
}
|
||||
case "github_copilot", "copilot":
|
||||
if cfg.Providers.GitHubCopilot.APIBase != "" {
|
||||
apiBase = cfg.Providers.GitHubCopilot.APIBase
|
||||
} else {
|
||||
apiBase = "localhost:4321"
|
||||
}
|
||||
provider, err := NewGitHubCopilotProvider(apiBase, cfg.Providers.GitHubCopilot.ConnectMode, model)
|
||||
return provider, model, err
|
||||
case "antigravity", "google-antigravity":
|
||||
return NewAntigravityProvider(), model, nil
|
||||
|
||||
case "volcengine", "doubao":
|
||||
if cfg.Providers.VolcEngine.APIKey != "" {
|
||||
apiKey = cfg.Providers.VolcEngine.APIKey
|
||||
apiBase = cfg.Providers.VolcEngine.APIBase
|
||||
if apiBase == "" {
|
||||
apiBase = "https://ark.cn-beijing.volces.com/api/v3"
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
// Fallback: detect provider from model name
|
||||
if apiKey == "" && apiBase == "" {
|
||||
switch {
|
||||
case (strings.Contains(lowerModel, "kimi") || strings.Contains(lowerModel, "moonshot") || strings.HasPrefix(model, "moonshot/")) && cfg.Providers.Moonshot.APIKey != "":
|
||||
apiKey = cfg.Providers.Moonshot.APIKey
|
||||
apiBase = cfg.Providers.Moonshot.APIBase
|
||||
proxy = cfg.Providers.Moonshot.Proxy
|
||||
if apiBase == "" {
|
||||
apiBase = "https://api.moonshot.cn/v1"
|
||||
}
|
||||
|
||||
case strings.HasPrefix(model, "openrouter/") || strings.HasPrefix(model, "anthropic/") || strings.HasPrefix(model, "openai/") || strings.HasPrefix(model, "meta-llama/") || strings.HasPrefix(model, "deepseek/") || strings.HasPrefix(model, "google/"):
|
||||
apiKey = cfg.Providers.OpenRouter.APIKey
|
||||
proxy = cfg.Providers.OpenRouter.Proxy
|
||||
if cfg.Providers.OpenRouter.APIBase != "" {
|
||||
apiBase = cfg.Providers.OpenRouter.APIBase
|
||||
} else {
|
||||
apiBase = "https://openrouter.ai/api/v1"
|
||||
}
|
||||
|
||||
case (strings.Contains(lowerModel, "claude") || strings.HasPrefix(model, "anthropic/")) && (cfg.Providers.Anthropic.APIKey != "" || cfg.Providers.Anthropic.AuthMethod != ""):
|
||||
if cfg.Providers.Anthropic.AuthMethod == "oauth" || cfg.Providers.Anthropic.AuthMethod == "token" {
|
||||
provider, err := createClaudeAuthProvider()
|
||||
return provider, model, err
|
||||
}
|
||||
apiKey = cfg.Providers.Anthropic.APIKey
|
||||
apiBase = cfg.Providers.Anthropic.APIBase
|
||||
proxy = cfg.Providers.Anthropic.Proxy
|
||||
if apiBase == "" {
|
||||
apiBase = "https://api.anthropic.com/v1"
|
||||
}
|
||||
|
||||
case (strings.Contains(lowerModel, "gpt") || strings.HasPrefix(model, "openai/")) && (cfg.Providers.OpenAI.APIKey != "" || cfg.Providers.OpenAI.AuthMethod != ""):
|
||||
if cfg.Providers.OpenAI.AuthMethod == "oauth" || cfg.Providers.OpenAI.AuthMethod == "token" {
|
||||
provider, err := createCodexAuthProvider()
|
||||
return provider, model, err
|
||||
}
|
||||
apiKey = cfg.Providers.OpenAI.APIKey
|
||||
apiBase = cfg.Providers.OpenAI.APIBase
|
||||
proxy = cfg.Providers.OpenAI.Proxy
|
||||
if apiBase == "" {
|
||||
apiBase = "https://api.openai.com/v1"
|
||||
}
|
||||
|
||||
case (strings.Contains(lowerModel, "gemini") || strings.HasPrefix(model, "google/")) && cfg.Providers.Gemini.APIKey != "":
|
||||
apiKey = cfg.Providers.Gemini.APIKey
|
||||
apiBase = cfg.Providers.Gemini.APIBase
|
||||
proxy = cfg.Providers.Gemini.Proxy
|
||||
if apiBase == "" {
|
||||
apiBase = "https://generativelanguage.googleapis.com/v1beta"
|
||||
}
|
||||
|
||||
case (strings.Contains(lowerModel, "glm") || strings.Contains(lowerModel, "zhipu") || strings.Contains(lowerModel, "zai")) && cfg.Providers.Zhipu.APIKey != "":
|
||||
apiKey = cfg.Providers.Zhipu.APIKey
|
||||
apiBase = cfg.Providers.Zhipu.APIBase
|
||||
proxy = cfg.Providers.Zhipu.Proxy
|
||||
if apiBase == "" {
|
||||
apiBase = "https://open.bigmodel.cn/api/paas/v4"
|
||||
}
|
||||
|
||||
case (strings.Contains(lowerModel, "groq") || strings.HasPrefix(model, "groq/")) && cfg.Providers.Groq.APIKey != "":
|
||||
apiKey = cfg.Providers.Groq.APIKey
|
||||
apiBase = cfg.Providers.Groq.APIBase
|
||||
proxy = cfg.Providers.Groq.Proxy
|
||||
if apiBase == "" {
|
||||
apiBase = "https://api.groq.com/openai/v1"
|
||||
}
|
||||
|
||||
case (strings.Contains(lowerModel, "qwen") || strings.HasPrefix(model, "qwen/")) && cfg.Providers.Qwen.APIKey != "":
|
||||
apiKey = cfg.Providers.Qwen.APIKey
|
||||
apiBase = cfg.Providers.Qwen.APIBase
|
||||
proxy = cfg.Providers.Qwen.Proxy
|
||||
if apiBase == "" {
|
||||
apiBase = "https://dashscope.aliyuncs.com/compatible-mode/v1"
|
||||
}
|
||||
|
||||
case (strings.Contains(lowerModel, "nvidia") || strings.HasPrefix(model, "nvidia/")) && cfg.Providers.Nvidia.APIKey != "":
|
||||
apiKey = cfg.Providers.Nvidia.APIKey
|
||||
apiBase = cfg.Providers.Nvidia.APIBase
|
||||
proxy = cfg.Providers.Nvidia.Proxy
|
||||
if apiBase == "" {
|
||||
apiBase = "https://integrate.api.nvidia.com/v1"
|
||||
}
|
||||
case (strings.Contains(lowerModel, "cerebras") || strings.HasPrefix(model, "cerebras/")) && cfg.Providers.Cerebras.APIKey != "":
|
||||
apiKey = cfg.Providers.Cerebras.APIKey
|
||||
apiBase = cfg.Providers.Cerebras.APIBase
|
||||
proxy = cfg.Providers.Cerebras.Proxy
|
||||
if apiBase == "" {
|
||||
apiBase = "https://api.cerebras.ai/v1"
|
||||
}
|
||||
|
||||
case (strings.Contains(lowerModel, "ollama") || strings.HasPrefix(model, "ollama/")) && cfg.Providers.Ollama.APIKey != "":
|
||||
fmt.Println("Ollama provider selected based on model name prefix")
|
||||
apiKey = cfg.Providers.Ollama.APIKey
|
||||
apiBase = cfg.Providers.Ollama.APIBase
|
||||
proxy = cfg.Providers.Ollama.Proxy
|
||||
if apiBase == "" {
|
||||
apiBase = "http://localhost:11434/v1"
|
||||
}
|
||||
fmt.Println("Ollama apiBase:", apiBase)
|
||||
|
||||
case (strings.Contains(lowerModel, "doubao") || strings.HasPrefix(lowerModel, "doubao") || strings.Contains(lowerModel, "volcengine")) && cfg.Providers.VolcEngine.APIKey != "":
|
||||
apiKey = cfg.Providers.VolcEngine.APIKey
|
||||
apiBase = cfg.Providers.VolcEngine.APIBase
|
||||
proxy = cfg.Providers.VolcEngine.Proxy
|
||||
if apiBase == "" {
|
||||
apiBase = "https://ark.cn-beijing.volces.com/api/v3"
|
||||
}
|
||||
|
||||
case cfg.Providers.VLLM.APIBase != "":
|
||||
apiKey = cfg.Providers.VLLM.APIKey
|
||||
apiBase = cfg.Providers.VLLM.APIBase
|
||||
proxy = cfg.Providers.VLLM.Proxy
|
||||
|
||||
default:
|
||||
if cfg.Providers.OpenRouter.APIKey != "" {
|
||||
apiKey = cfg.Providers.OpenRouter.APIKey
|
||||
proxy = cfg.Providers.OpenRouter.Proxy
|
||||
if cfg.Providers.OpenRouter.APIBase != "" {
|
||||
apiBase = cfg.Providers.OpenRouter.APIBase
|
||||
} else {
|
||||
apiBase = "https://openrouter.ai/api/v1"
|
||||
}
|
||||
} else {
|
||||
return nil, "", fmt.Errorf("no API key configured for model: %s", model)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if apiKey == "" && !strings.HasPrefix(model, "bedrock/") {
|
||||
return nil, "", fmt.Errorf("no API key configured for provider (model: %s)", model)
|
||||
}
|
||||
|
||||
if apiBase == "" {
|
||||
return nil, "", fmt.Errorf("no API base configured for provider (model: %s)", model)
|
||||
}
|
||||
|
||||
return NewHTTPProvider(apiKey, apiBase, proxy), model, nil
|
||||
}
|
||||
@@ -1,113 +0,0 @@
|
||||
// PicoClaw - Ultra-lightweight personal AI agent
|
||||
// License: MIT
|
||||
//
|
||||
// Copyright (c) 2026 PicoClaw contributors
|
||||
|
||||
package providers
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
)
|
||||
|
||||
// ModelRegistry manages model configurations with thread-safe round-robin load balancing.
|
||||
// It allows multiple configurations for the same model_name to distribute load across endpoints.
|
||||
type ModelRegistry struct {
|
||||
configs map[string][]config.ModelConfig // model_name -> []ModelConfig
|
||||
counters map[string]*atomic.Uint64 // model_name -> round-robin counter
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
// NewModelRegistry creates a new ModelRegistry from a slice of ModelConfig.
|
||||
func NewModelRegistry(modelList []config.ModelConfig) *ModelRegistry {
|
||||
r := &ModelRegistry{
|
||||
configs: make(map[string][]config.ModelConfig),
|
||||
counters: make(map[string]*atomic.Uint64),
|
||||
}
|
||||
|
||||
for _, cfg := range modelList {
|
||||
r.configs[cfg.ModelName] = append(r.configs[cfg.ModelName], cfg)
|
||||
}
|
||||
|
||||
// Initialize counters for models with multiple configs
|
||||
for name, cfgs := range r.configs {
|
||||
if len(cfgs) > 1 {
|
||||
r.counters[name] = &atomic.Uint64{}
|
||||
}
|
||||
}
|
||||
|
||||
return r
|
||||
}
|
||||
|
||||
// GetModelConfig returns a ModelConfig for the given model name.
|
||||
// If multiple configs exist for the same model_name, it uses round-robin selection.
|
||||
// Returns an error if the model is not found.
|
||||
func (r *ModelRegistry) GetModelConfig(modelName string) (*config.ModelConfig, error) {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
|
||||
configs, ok := r.configs[modelName]
|
||||
if !ok || len(configs) == 0 {
|
||||
return nil, fmt.Errorf("model %q not found", modelName)
|
||||
}
|
||||
|
||||
// Single config - return directly
|
||||
if len(configs) == 1 {
|
||||
return &configs[0], nil
|
||||
}
|
||||
|
||||
// Multiple configs - use round-robin for load balancing
|
||||
counter, ok := r.counters[modelName]
|
||||
if !ok {
|
||||
// Should not happen, but handle gracefully
|
||||
return &configs[0], nil
|
||||
}
|
||||
|
||||
idx := counter.Add(1) % uint64(len(configs))
|
||||
return &configs[idx], nil
|
||||
}
|
||||
|
||||
// AddConfig adds a new ModelConfig to the registry.
|
||||
func (r *ModelRegistry) AddConfig(cfg config.ModelConfig) {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
r.configs[cfg.ModelName] = append(r.configs[cfg.ModelName], cfg)
|
||||
|
||||
// Initialize counter if we now have multiple configs
|
||||
if len(r.configs[cfg.ModelName]) > 1 && r.counters[cfg.ModelName] == nil {
|
||||
r.counters[cfg.ModelName] = &atomic.Uint64{}
|
||||
}
|
||||
}
|
||||
|
||||
// RemoveConfig removes all configs with the given model_name.
|
||||
func (r *ModelRegistry) RemoveConfig(modelName string) {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
delete(r.configs, modelName)
|
||||
delete(r.counters, modelName)
|
||||
}
|
||||
|
||||
// ListModels returns all unique model names in the registry.
|
||||
func (r *ModelRegistry) ListModels() []string {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
|
||||
names := make([]string, 0, len(r.configs))
|
||||
for name := range r.configs {
|
||||
names = append(names, name)
|
||||
}
|
||||
return names
|
||||
}
|
||||
|
||||
// ConfigCount returns the number of configurations for a given model name.
|
||||
func (r *ModelRegistry) ConfigCount(modelName string) int {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
|
||||
return len(r.configs[modelName])
|
||||
}
|
||||
Reference in New Issue
Block a user