mirror of
https://github.com/sipeed/picoclaw.git
synced 2026-06-12 18:08:54 +00:00
refactor: reorganize commands and provider architecture
Refactor command handlers into separate files to improve code organization and maintainability. Each command (agent, auth, cron, gateway, migrate, onboard, skills, status) now has its own dedicated file. Restructure provider creation to support new model_list configuration system that enables zero-code addition of OpenAI-compatible providers. Move legacy provider logic to separate file for backward compatibility. Move configuration functions from config.go to separate files (defaults.go, migration.go) for better organization.
This commit is contained in:
@@ -0,0 +1,181 @@
|
||||
// PicoClaw - Ultra-lightweight personal AI agent
|
||||
// License: MIT
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"github.com/chzyer/readline"
|
||||
"github.com/sipeed/picoclaw/pkg/agent"
|
||||
"github.com/sipeed/picoclaw/pkg/bus"
|
||||
"github.com/sipeed/picoclaw/pkg/logger"
|
||||
"github.com/sipeed/picoclaw/pkg/providers"
|
||||
)
|
||||
|
||||
func agentCmd() {
|
||||
message := ""
|
||||
sessionKey := "cli:default"
|
||||
modelOverride := ""
|
||||
|
||||
args := os.Args[2:]
|
||||
for i := 0; i < len(args); i++ {
|
||||
switch args[i] {
|
||||
case "--debug", "-d":
|
||||
logger.SetLevel(logger.DEBUG)
|
||||
fmt.Println("🔍 Debug mode enabled")
|
||||
case "-m", "--message":
|
||||
if i+1 < len(args) {
|
||||
message = args[i+1]
|
||||
i++
|
||||
}
|
||||
case "-s", "--session":
|
||||
if i+1 < len(args) {
|
||||
sessionKey = args[i+1]
|
||||
i++
|
||||
}
|
||||
case "--model", "-model":
|
||||
if i+1 < len(args) {
|
||||
modelOverride = args[i+1]
|
||||
i++
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
cfg, err := loadConfig()
|
||||
if err != nil {
|
||||
fmt.Printf("Error loading config: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
if modelOverride != "" {
|
||||
cfg.Agents.Defaults.Model = modelOverride
|
||||
}
|
||||
|
||||
provider, modelID, err := providers.CreateProvider(cfg)
|
||||
if err != nil {
|
||||
fmt.Printf("Error creating provider: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
// Use the resolved model ID from provider creation
|
||||
if modelID != "" {
|
||||
cfg.Agents.Defaults.Model = modelID
|
||||
}
|
||||
|
||||
msgBus := bus.NewMessageBus()
|
||||
agentLoop := agent.NewAgentLoop(cfg, msgBus, provider)
|
||||
|
||||
// Print agent startup info (only for interactive mode)
|
||||
startupInfo := agentLoop.GetStartupInfo()
|
||||
logger.InfoCF("agent", "Agent initialized",
|
||||
map[string]interface{}{
|
||||
"tools_count": startupInfo["tools"].(map[string]interface{})["count"],
|
||||
"skills_total": startupInfo["skills"].(map[string]interface{})["total"],
|
||||
"skills_available": startupInfo["skills"].(map[string]interface{})["available"],
|
||||
})
|
||||
|
||||
if message != "" {
|
||||
ctx := context.Background()
|
||||
response, err := agentLoop.ProcessDirect(ctx, message, sessionKey)
|
||||
if err != nil {
|
||||
fmt.Printf("Error: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
fmt.Printf("\n%s %s\n", logo, response)
|
||||
} else {
|
||||
fmt.Printf("%s Interactive mode (Ctrl+C to exit)\n\n", logo)
|
||||
interactiveMode(agentLoop, sessionKey)
|
||||
}
|
||||
}
|
||||
|
||||
func interactiveMode(agentLoop *agent.AgentLoop, sessionKey string) {
|
||||
prompt := fmt.Sprintf("%s You: ", logo)
|
||||
|
||||
rl, err := readline.NewEx(&readline.Config{
|
||||
Prompt: prompt,
|
||||
HistoryFile: filepath.Join(os.TempDir(), ".picoclaw_history"),
|
||||
HistoryLimit: 100,
|
||||
InterruptPrompt: "^C",
|
||||
EOFPrompt: "exit",
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
fmt.Printf("Error initializing readline: %v\n", err)
|
||||
fmt.Println("Falling back to simple input mode...")
|
||||
simpleInteractiveMode(agentLoop, sessionKey)
|
||||
return
|
||||
}
|
||||
defer rl.Close()
|
||||
|
||||
for {
|
||||
line, err := rl.Readline()
|
||||
if err != nil {
|
||||
if err == readline.ErrInterrupt || err == io.EOF {
|
||||
fmt.Println("\nGoodbye!")
|
||||
return
|
||||
}
|
||||
fmt.Printf("Error reading input: %v\n", err)
|
||||
continue
|
||||
}
|
||||
|
||||
input := strings.TrimSpace(line)
|
||||
if input == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
if input == "exit" || input == "quit" {
|
||||
fmt.Println("Goodbye!")
|
||||
return
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
response, err := agentLoop.ProcessDirect(ctx, input, sessionKey)
|
||||
if err != nil {
|
||||
fmt.Printf("Error: %v\n", err)
|
||||
continue
|
||||
}
|
||||
|
||||
fmt.Printf("\n%s %s\n\n", logo, response)
|
||||
}
|
||||
}
|
||||
|
||||
func simpleInteractiveMode(agentLoop *agent.AgentLoop, sessionKey string) {
|
||||
reader := bufio.NewReader(os.Stdin)
|
||||
for {
|
||||
fmt.Print(fmt.Sprintf("%s You: ", logo))
|
||||
line, err := reader.ReadString('\n')
|
||||
if err != nil {
|
||||
if err == io.EOF {
|
||||
fmt.Println("\nGoodbye!")
|
||||
return
|
||||
}
|
||||
fmt.Printf("Error reading input: %v\n", err)
|
||||
continue
|
||||
}
|
||||
|
||||
input := strings.TrimSpace(line)
|
||||
if input == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
if input == "exit" || input == "quit" {
|
||||
fmt.Println("Goodbye!")
|
||||
return
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
response, err := agentLoop.ProcessDirect(ctx, input, sessionKey)
|
||||
if err != nil {
|
||||
fmt.Printf("Error: %v\n", err)
|
||||
continue
|
||||
}
|
||||
|
||||
fmt.Printf("\n%s %s\n\n", logo, response)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,386 @@
|
||||
// PicoClaw - Ultra-lightweight personal AI agent
|
||||
// License: MIT
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/auth"
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
"github.com/sipeed/picoclaw/pkg/providers"
|
||||
)
|
||||
|
||||
func authCmd() {
|
||||
if len(os.Args) < 3 {
|
||||
authHelp()
|
||||
return
|
||||
}
|
||||
|
||||
switch os.Args[2] {
|
||||
case "login":
|
||||
authLoginCmd()
|
||||
case "logout":
|
||||
authLogoutCmd()
|
||||
case "status":
|
||||
authStatusCmd()
|
||||
case "models":
|
||||
authModelsCmd()
|
||||
default:
|
||||
fmt.Printf("Unknown auth command: %s\n", os.Args[2])
|
||||
authHelp()
|
||||
}
|
||||
}
|
||||
|
||||
func authHelp() {
|
||||
fmt.Println("\nAuth commands:")
|
||||
fmt.Println(" login Login via OAuth or paste token")
|
||||
fmt.Println(" logout Remove stored credentials")
|
||||
fmt.Println(" status Show current auth status")
|
||||
fmt.Println(" models List available Antigravity models")
|
||||
fmt.Println()
|
||||
fmt.Println("Login options:")
|
||||
fmt.Println(" --provider <name> Provider to login with (openai, anthropic, google-antigravity)")
|
||||
fmt.Println(" --device-code Use device code flow (for headless environments)")
|
||||
fmt.Println()
|
||||
fmt.Println("Examples:")
|
||||
fmt.Println(" picoclaw auth login --provider openai")
|
||||
fmt.Println(" picoclaw auth login --provider openai --device-code")
|
||||
fmt.Println(" picoclaw auth login --provider anthropic")
|
||||
fmt.Println(" picoclaw auth login --provider google-antigravity")
|
||||
fmt.Println(" picoclaw auth models")
|
||||
fmt.Println(" picoclaw auth logout --provider openai")
|
||||
fmt.Println(" picoclaw auth status")
|
||||
}
|
||||
|
||||
func authLoginCmd() {
|
||||
provider := ""
|
||||
useDeviceCode := false
|
||||
|
||||
args := os.Args[3:]
|
||||
for i := 0; i < len(args); i++ {
|
||||
switch args[i] {
|
||||
case "--provider", "-p":
|
||||
if i+1 < len(args) {
|
||||
provider = args[i+1]
|
||||
i++
|
||||
}
|
||||
case "--device-code":
|
||||
useDeviceCode = true
|
||||
}
|
||||
}
|
||||
|
||||
if provider == "" {
|
||||
fmt.Println("Error: --provider is required")
|
||||
fmt.Println("Supported providers: openai, anthropic, google-antigravity")
|
||||
return
|
||||
}
|
||||
|
||||
switch provider {
|
||||
case "openai":
|
||||
authLoginOpenAI(useDeviceCode)
|
||||
case "anthropic":
|
||||
authLoginPasteToken(provider)
|
||||
case "google-antigravity", "antigravity":
|
||||
authLoginGoogleAntigravity()
|
||||
default:
|
||||
fmt.Printf("Unsupported provider: %s\n", provider)
|
||||
fmt.Println("Supported providers: openai, anthropic, google-antigravity")
|
||||
}
|
||||
}
|
||||
|
||||
func authLoginOpenAI(useDeviceCode bool) {
|
||||
cfg := auth.OpenAIOAuthConfig()
|
||||
|
||||
var cred *auth.AuthCredential
|
||||
var err error
|
||||
|
||||
if useDeviceCode {
|
||||
cred, err = auth.LoginDeviceCode(cfg)
|
||||
} else {
|
||||
cred, err = auth.LoginBrowser(cfg)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
fmt.Printf("Login failed: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
if err := auth.SetCredential("openai", cred); err != nil {
|
||||
fmt.Printf("Failed to save credentials: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
appCfg, err := loadConfig()
|
||||
if err == nil {
|
||||
appCfg.Providers.OpenAI.AuthMethod = "oauth"
|
||||
if err := config.SaveConfig(getConfigPath(), appCfg); err != nil {
|
||||
fmt.Printf("Warning: could not update config: %v\n", err)
|
||||
}
|
||||
}
|
||||
|
||||
fmt.Println("Login successful!")
|
||||
if cred.AccountID != "" {
|
||||
fmt.Printf("Account: %s\n", cred.AccountID)
|
||||
}
|
||||
}
|
||||
|
||||
func authLoginGoogleAntigravity() {
|
||||
cfg := auth.GoogleAntigravityOAuthConfig()
|
||||
|
||||
cred, err := auth.LoginBrowser(cfg)
|
||||
if err != nil {
|
||||
fmt.Printf("Login failed: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
cred.Provider = "google-antigravity"
|
||||
|
||||
// Fetch user email from Google userinfo
|
||||
email, err := fetchGoogleUserEmail(cred.AccessToken)
|
||||
if err != nil {
|
||||
fmt.Printf("Warning: could not fetch email: %v\n", err)
|
||||
} else {
|
||||
cred.Email = email
|
||||
fmt.Printf("Email: %s\n", email)
|
||||
}
|
||||
|
||||
// Fetch Cloud Code Assist project ID
|
||||
projectID, err := providers.FetchAntigravityProjectID(cred.AccessToken)
|
||||
if err != nil {
|
||||
fmt.Printf("Warning: could not fetch project ID: %v\n", err)
|
||||
fmt.Println("You may need Google Cloud Code Assist enabled on your account.")
|
||||
} else {
|
||||
cred.ProjectID = projectID
|
||||
fmt.Printf("Project: %s\n", projectID)
|
||||
}
|
||||
|
||||
if err := auth.SetCredential("google-antigravity", cred); err != nil {
|
||||
fmt.Printf("Failed to save credentials: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
appCfg, err := loadConfig()
|
||||
if err == nil {
|
||||
appCfg.Providers.Antigravity.AuthMethod = "oauth"
|
||||
if appCfg.Agents.Defaults.Provider == "" {
|
||||
appCfg.Agents.Defaults.Provider = "antigravity"
|
||||
}
|
||||
if appCfg.Agents.Defaults.Provider == "antigravity" || appCfg.Agents.Defaults.Provider == "google-antigravity" {
|
||||
appCfg.Agents.Defaults.Model = "gemini-3-flash"
|
||||
}
|
||||
if err := config.SaveConfig(getConfigPath(), appCfg); err != nil {
|
||||
fmt.Printf("Warning: could not update config: %v\n", err)
|
||||
}
|
||||
}
|
||||
|
||||
fmt.Println("\n✓ Google Antigravity login successful!")
|
||||
fmt.Println("Config updated: provider=antigravity, model=gemini-3-flash")
|
||||
fmt.Println("Try it: picoclaw agent -m \"Hello world\"")
|
||||
}
|
||||
|
||||
func fetchGoogleUserEmail(accessToken string) (string, error) {
|
||||
req, err := http.NewRequest("GET", "https://www.googleapis.com/oauth2/v2/userinfo", nil)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
req.Header.Set("Authorization", "Bearer "+accessToken)
|
||||
|
||||
client := &http.Client{Timeout: 10 * time.Second}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return "", fmt.Errorf("userinfo request failed: %s", string(body))
|
||||
}
|
||||
|
||||
var userInfo struct {
|
||||
Email string `json:"email"`
|
||||
}
|
||||
if err := json.Unmarshal(body, &userInfo); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return userInfo.Email, nil
|
||||
}
|
||||
|
||||
func authLoginPasteToken(provider string) {
|
||||
cred, err := auth.LoginPasteToken(provider, os.Stdin)
|
||||
if err != nil {
|
||||
fmt.Printf("Login failed: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
if err := auth.SetCredential(provider, cred); err != nil {
|
||||
fmt.Printf("Failed to save credentials: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
appCfg, err := loadConfig()
|
||||
if err == nil {
|
||||
switch provider {
|
||||
case "anthropic":
|
||||
appCfg.Providers.Anthropic.AuthMethod = "token"
|
||||
case "openai":
|
||||
appCfg.Providers.OpenAI.AuthMethod = "token"
|
||||
}
|
||||
if err := config.SaveConfig(getConfigPath(), appCfg); err != nil {
|
||||
fmt.Printf("Warning: could not update config: %v\n", err)
|
||||
}
|
||||
}
|
||||
|
||||
fmt.Printf("Token saved for %s!\n", provider)
|
||||
}
|
||||
|
||||
func authLogoutCmd() {
|
||||
provider := ""
|
||||
|
||||
args := os.Args[3:]
|
||||
for i := 0; i < len(args); i++ {
|
||||
switch args[i] {
|
||||
case "--provider", "-p":
|
||||
if i+1 < len(args) {
|
||||
provider = args[i+1]
|
||||
i++
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if provider != "" {
|
||||
if err := auth.DeleteCredential(provider); err != nil {
|
||||
fmt.Printf("Failed to remove credentials: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
appCfg, err := loadConfig()
|
||||
if err == nil {
|
||||
switch provider {
|
||||
case "openai":
|
||||
appCfg.Providers.OpenAI.AuthMethod = ""
|
||||
case "anthropic":
|
||||
appCfg.Providers.Anthropic.AuthMethod = ""
|
||||
case "google-antigravity", "antigravity":
|
||||
appCfg.Providers.Antigravity.AuthMethod = ""
|
||||
}
|
||||
config.SaveConfig(getConfigPath(), appCfg)
|
||||
}
|
||||
|
||||
fmt.Printf("Logged out from %s\n", provider)
|
||||
} else {
|
||||
if err := auth.DeleteAllCredentials(); err != nil {
|
||||
fmt.Printf("Failed to remove credentials: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
appCfg, err := loadConfig()
|
||||
if err == nil {
|
||||
appCfg.Providers.OpenAI.AuthMethod = ""
|
||||
appCfg.Providers.Anthropic.AuthMethod = ""
|
||||
appCfg.Providers.Antigravity.AuthMethod = ""
|
||||
config.SaveConfig(getConfigPath(), appCfg)
|
||||
}
|
||||
|
||||
fmt.Println("Logged out from all providers")
|
||||
}
|
||||
}
|
||||
|
||||
func authStatusCmd() {
|
||||
store, err := auth.LoadStore()
|
||||
if err != nil {
|
||||
fmt.Printf("Error loading auth store: %v\n", err)
|
||||
return
|
||||
}
|
||||
|
||||
if len(store.Credentials) == 0 {
|
||||
fmt.Println("No authenticated providers.")
|
||||
fmt.Println("Run: picoclaw auth login --provider <name>")
|
||||
return
|
||||
}
|
||||
|
||||
fmt.Println("\nAuthenticated Providers:")
|
||||
fmt.Println("------------------------")
|
||||
for provider, cred := range store.Credentials {
|
||||
status := "active"
|
||||
if cred.IsExpired() {
|
||||
status = "expired"
|
||||
} else if cred.NeedsRefresh() {
|
||||
status = "needs refresh"
|
||||
}
|
||||
|
||||
fmt.Printf(" %s:\n", provider)
|
||||
fmt.Printf(" Method: %s\n", cred.AuthMethod)
|
||||
fmt.Printf(" Status: %s\n", status)
|
||||
if cred.AccountID != "" {
|
||||
fmt.Printf(" Account: %s\n", cred.AccountID)
|
||||
}
|
||||
if cred.Email != "" {
|
||||
fmt.Printf(" Email: %s\n", cred.Email)
|
||||
}
|
||||
if cred.ProjectID != "" {
|
||||
fmt.Printf(" Project: %s\n", cred.ProjectID)
|
||||
}
|
||||
if !cred.ExpiresAt.IsZero() {
|
||||
fmt.Printf(" Expires: %s\n", cred.ExpiresAt.Format("2006-01-02 15:04"))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func authModelsCmd() {
|
||||
cred, err := auth.GetCredential("google-antigravity")
|
||||
if err != nil || cred == nil {
|
||||
fmt.Println("Not logged in to Google Antigravity.")
|
||||
fmt.Println("Run: picoclaw auth login --provider google-antigravity")
|
||||
return
|
||||
}
|
||||
|
||||
// Refresh token if needed
|
||||
if cred.NeedsRefresh() && cred.RefreshToken != "" {
|
||||
oauthCfg := auth.GoogleAntigravityOAuthConfig()
|
||||
refreshed, refreshErr := auth.RefreshAccessToken(cred, oauthCfg)
|
||||
if refreshErr == nil {
|
||||
cred = refreshed
|
||||
_ = auth.SetCredential("google-antigravity", cred)
|
||||
}
|
||||
}
|
||||
|
||||
projectID := cred.ProjectID
|
||||
if projectID == "" {
|
||||
fmt.Println("No project ID stored. Try logging in again.")
|
||||
return
|
||||
}
|
||||
|
||||
fmt.Printf("Fetching models for project: %s\n\n", projectID)
|
||||
|
||||
models, err := providers.FetchAntigravityModels(cred.AccessToken, projectID)
|
||||
if err != nil {
|
||||
fmt.Printf("Error fetching models: %v\n", err)
|
||||
return
|
||||
}
|
||||
|
||||
if len(models) == 0 {
|
||||
fmt.Println("No models available.")
|
||||
return
|
||||
}
|
||||
|
||||
fmt.Println("Available Antigravity Models:")
|
||||
fmt.Println("-----------------------------")
|
||||
for _, m := range models {
|
||||
status := "✓"
|
||||
if m.IsExhausted {
|
||||
status = "✗ (quota exhausted)"
|
||||
}
|
||||
name := m.ID
|
||||
if m.DisplayName != "" {
|
||||
name = fmt.Sprintf("%s (%s)", m.ID, m.DisplayName)
|
||||
}
|
||||
fmt.Printf(" %s %s\n", status, name)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,227 @@
|
||||
// PicoClaw - Ultra-lightweight personal AI agent
|
||||
// License: MIT
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"time"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/cron"
|
||||
)
|
||||
|
||||
func cronCmd() {
|
||||
if len(os.Args) < 3 {
|
||||
cronHelp()
|
||||
return
|
||||
}
|
||||
|
||||
subcommand := os.Args[2]
|
||||
|
||||
// Load config to get workspace path
|
||||
cfg, err := loadConfig()
|
||||
if err != nil {
|
||||
fmt.Printf("Error loading config: %v\n", err)
|
||||
return
|
||||
}
|
||||
|
||||
cronStorePath := filepath.Join(cfg.WorkspacePath(), "cron", "jobs.json")
|
||||
|
||||
switch subcommand {
|
||||
case "list":
|
||||
cronListCmd(cronStorePath)
|
||||
case "add":
|
||||
cronAddCmd(cronStorePath)
|
||||
case "remove":
|
||||
if len(os.Args) < 4 {
|
||||
fmt.Println("Usage: picoclaw cron remove <job_id>")
|
||||
return
|
||||
}
|
||||
cronRemoveCmd(cronStorePath, os.Args[3])
|
||||
case "enable":
|
||||
cronEnableCmd(cronStorePath, false)
|
||||
case "disable":
|
||||
cronEnableCmd(cronStorePath, true)
|
||||
default:
|
||||
fmt.Printf("Unknown cron command: %s\n", subcommand)
|
||||
cronHelp()
|
||||
}
|
||||
}
|
||||
|
||||
func cronHelp() {
|
||||
fmt.Println("\nCron commands:")
|
||||
fmt.Println(" list List all scheduled jobs")
|
||||
fmt.Println(" add Add a new scheduled job")
|
||||
fmt.Println(" remove <id> Remove a job by ID")
|
||||
fmt.Println(" enable <id> Enable a job")
|
||||
fmt.Println(" disable <id> Disable a job")
|
||||
fmt.Println()
|
||||
fmt.Println("Add options:")
|
||||
fmt.Println(" -n, --name Job name")
|
||||
fmt.Println(" -m, --message Message for agent")
|
||||
fmt.Println(" -e, --every Run every N seconds")
|
||||
fmt.Println(" -c, --cron Cron expression (e.g. '0 9 * * *')")
|
||||
fmt.Println(" -d, --deliver Deliver response to channel")
|
||||
fmt.Println(" --to Recipient for delivery")
|
||||
fmt.Println(" --channel Channel for delivery")
|
||||
}
|
||||
|
||||
func cronListCmd(storePath string) {
|
||||
cs := cron.NewCronService(storePath, nil)
|
||||
jobs := cs.ListJobs(true) // Show all jobs, including disabled
|
||||
|
||||
if len(jobs) == 0 {
|
||||
fmt.Println("No scheduled jobs.")
|
||||
return
|
||||
}
|
||||
|
||||
fmt.Println("\nScheduled Jobs:")
|
||||
fmt.Println("----------------")
|
||||
for _, job := range jobs {
|
||||
var schedule string
|
||||
if job.Schedule.Kind == "every" && job.Schedule.EveryMS != nil {
|
||||
schedule = fmt.Sprintf("every %ds", *job.Schedule.EveryMS/1000)
|
||||
} else if job.Schedule.Kind == "cron" {
|
||||
schedule = job.Schedule.Expr
|
||||
} else {
|
||||
schedule = "one-time"
|
||||
}
|
||||
|
||||
nextRun := "scheduled"
|
||||
if job.State.NextRunAtMS != nil {
|
||||
nextTime := time.UnixMilli(*job.State.NextRunAtMS)
|
||||
nextRun = nextTime.Format("2006-01-02 15:04")
|
||||
}
|
||||
|
||||
status := "enabled"
|
||||
if !job.Enabled {
|
||||
status = "disabled"
|
||||
}
|
||||
|
||||
fmt.Printf(" %s (%s)\n", job.Name, job.ID)
|
||||
fmt.Printf(" Schedule: %s\n", schedule)
|
||||
fmt.Printf(" Status: %s\n", status)
|
||||
fmt.Printf(" Next run: %s\n", nextRun)
|
||||
}
|
||||
}
|
||||
|
||||
func cronAddCmd(storePath string) {
|
||||
name := ""
|
||||
message := ""
|
||||
var everySec *int64
|
||||
cronExpr := ""
|
||||
deliver := false
|
||||
channel := ""
|
||||
to := ""
|
||||
|
||||
args := os.Args[3:]
|
||||
for i := 0; i < len(args); i++ {
|
||||
switch args[i] {
|
||||
case "-n", "--name":
|
||||
if i+1 < len(args) {
|
||||
name = args[i+1]
|
||||
i++
|
||||
}
|
||||
case "-m", "--message":
|
||||
if i+1 < len(args) {
|
||||
message = args[i+1]
|
||||
i++
|
||||
}
|
||||
case "-e", "--every":
|
||||
if i+1 < len(args) {
|
||||
var sec int64
|
||||
fmt.Sscanf(args[i+1], "%d", &sec)
|
||||
everySec = &sec
|
||||
i++
|
||||
}
|
||||
case "-c", "--cron":
|
||||
if i+1 < len(args) {
|
||||
cronExpr = args[i+1]
|
||||
i++
|
||||
}
|
||||
case "-d", "--deliver":
|
||||
deliver = true
|
||||
case "--to":
|
||||
if i+1 < len(args) {
|
||||
to = args[i+1]
|
||||
i++
|
||||
}
|
||||
case "--channel":
|
||||
if i+1 < len(args) {
|
||||
channel = args[i+1]
|
||||
i++
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if name == "" {
|
||||
fmt.Println("Error: --name is required")
|
||||
return
|
||||
}
|
||||
|
||||
if message == "" {
|
||||
fmt.Println("Error: --message is required")
|
||||
return
|
||||
}
|
||||
|
||||
if everySec == nil && cronExpr == "" {
|
||||
fmt.Println("Error: Either --every or --cron must be specified")
|
||||
return
|
||||
}
|
||||
|
||||
var schedule cron.CronSchedule
|
||||
if everySec != nil {
|
||||
everyMS := *everySec * 1000
|
||||
schedule = cron.CronSchedule{
|
||||
Kind: "every",
|
||||
EveryMS: &everyMS,
|
||||
}
|
||||
} else {
|
||||
schedule = cron.CronSchedule{
|
||||
Kind: "cron",
|
||||
Expr: cronExpr,
|
||||
}
|
||||
}
|
||||
|
||||
cs := cron.NewCronService(storePath, nil)
|
||||
job, err := cs.AddJob(name, schedule, message, deliver, channel, to)
|
||||
if err != nil {
|
||||
fmt.Printf("Error adding job: %v\n", err)
|
||||
return
|
||||
}
|
||||
|
||||
fmt.Printf("✓ Added job '%s' (%s)\n", job.Name, job.ID)
|
||||
}
|
||||
|
||||
func cronRemoveCmd(storePath, jobID string) {
|
||||
cs := cron.NewCronService(storePath, nil)
|
||||
if cs.RemoveJob(jobID) {
|
||||
fmt.Printf("✓ Removed job %s\n", jobID)
|
||||
} else {
|
||||
fmt.Printf("✗ Job %s not found\n", jobID)
|
||||
}
|
||||
}
|
||||
|
||||
func cronEnableCmd(storePath string, disable bool) {
|
||||
if len(os.Args) < 4 {
|
||||
fmt.Println("Usage: picoclaw cron enable/disable <job_id>")
|
||||
return
|
||||
}
|
||||
|
||||
jobID := os.Args[3]
|
||||
cs := cron.NewCronService(storePath, nil)
|
||||
enabled := !disable
|
||||
|
||||
job := cs.EnableJob(jobID, enabled)
|
||||
if job != nil {
|
||||
status := "enabled"
|
||||
if disable {
|
||||
status = "disabled"
|
||||
}
|
||||
fmt.Printf("✓ Job '%s' %s\n", job.Name, status)
|
||||
} else {
|
||||
fmt.Printf("✗ Job %s not found\n", jobID)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,222 @@
|
||||
// PicoClaw - Ultra-lightweight personal AI agent
|
||||
// License: MIT
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/signal"
|
||||
"path/filepath"
|
||||
"time"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/agent"
|
||||
"github.com/sipeed/picoclaw/pkg/bus"
|
||||
"github.com/sipeed/picoclaw/pkg/channels"
|
||||
"github.com/sipeed/picoclaw/pkg/cron"
|
||||
"github.com/sipeed/picoclaw/pkg/devices"
|
||||
"github.com/sipeed/picoclaw/pkg/health"
|
||||
"github.com/sipeed/picoclaw/pkg/heartbeat"
|
||||
"github.com/sipeed/picoclaw/pkg/logger"
|
||||
"github.com/sipeed/picoclaw/pkg/providers"
|
||||
"github.com/sipeed/picoclaw/pkg/state"
|
||||
"github.com/sipeed/picoclaw/pkg/tools"
|
||||
"github.com/sipeed/picoclaw/pkg/voice"
|
||||
)
|
||||
|
||||
func gatewayCmd() {
|
||||
// Check for --debug flag
|
||||
args := os.Args[2:]
|
||||
for _, arg := range args {
|
||||
if arg == "--debug" || arg == "-d" {
|
||||
logger.SetLevel(logger.DEBUG)
|
||||
fmt.Println("🔍 Debug mode enabled")
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
cfg, err := loadConfig()
|
||||
if err != nil {
|
||||
fmt.Printf("Error loading config: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
provider, modelID, err := providers.CreateProvider(cfg)
|
||||
if err != nil {
|
||||
fmt.Printf("Error creating provider: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
// Use the resolved model ID from provider creation
|
||||
if modelID != "" {
|
||||
cfg.Agents.Defaults.Model = modelID
|
||||
}
|
||||
|
||||
msgBus := bus.NewMessageBus()
|
||||
agentLoop := agent.NewAgentLoop(cfg, msgBus, provider)
|
||||
|
||||
// Print agent startup info
|
||||
fmt.Println("\n📦 Agent Status:")
|
||||
startupInfo := agentLoop.GetStartupInfo()
|
||||
toolsInfo := startupInfo["tools"].(map[string]interface{})
|
||||
skillsInfo := startupInfo["skills"].(map[string]interface{})
|
||||
fmt.Printf(" • Tools: %d loaded\n", toolsInfo["count"])
|
||||
fmt.Printf(" • Skills: %d/%d available\n",
|
||||
skillsInfo["available"],
|
||||
skillsInfo["total"])
|
||||
|
||||
// Log to file as well
|
||||
logger.InfoCF("agent", "Agent initialized",
|
||||
map[string]interface{}{
|
||||
"tools_count": toolsInfo["count"],
|
||||
"skills_total": skillsInfo["total"],
|
||||
"skills_available": skillsInfo["available"],
|
||||
})
|
||||
|
||||
// Setup cron tool and service
|
||||
execTimeout := time.Duration(cfg.Tools.Cron.ExecTimeoutMinutes) * time.Minute
|
||||
cronService := setupCronTool(agentLoop, msgBus, cfg.WorkspacePath(), cfg.Agents.Defaults.RestrictToWorkspace, execTimeout)
|
||||
|
||||
heartbeatService := heartbeat.NewHeartbeatService(
|
||||
cfg.WorkspacePath(),
|
||||
cfg.Heartbeat.Interval,
|
||||
cfg.Heartbeat.Enabled,
|
||||
)
|
||||
heartbeatService.SetBus(msgBus)
|
||||
heartbeatService.SetHandler(func(prompt, channel, chatID string) *tools.ToolResult {
|
||||
// Use cli:direct as fallback if no valid channel
|
||||
if channel == "" || chatID == "" {
|
||||
channel, chatID = "cli", "direct"
|
||||
}
|
||||
// Use ProcessHeartbeat - no session history, each heartbeat is independent
|
||||
response, err := agentLoop.ProcessHeartbeat(context.Background(), prompt, channel, chatID)
|
||||
if err != nil {
|
||||
return tools.ErrorResult(fmt.Sprintf("Heartbeat error: %v", err))
|
||||
}
|
||||
if response == "HEARTBEAT_OK" {
|
||||
return tools.SilentResult("Heartbeat OK")
|
||||
}
|
||||
// For heartbeat, always return silent - the subagent result will be
|
||||
// sent to user via processSystemMessage when the async task completes
|
||||
return tools.SilentResult(response)
|
||||
})
|
||||
|
||||
channelManager, err := channels.NewManager(cfg, msgBus)
|
||||
if err != nil {
|
||||
fmt.Printf("Error creating channel manager: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
// Inject channel manager into agent loop for command handling
|
||||
agentLoop.SetChannelManager(channelManager)
|
||||
|
||||
var transcriber *voice.GroqTranscriber
|
||||
if cfg.Providers.Groq.APIKey != "" {
|
||||
transcriber = voice.NewGroqTranscriber(cfg.Providers.Groq.APIKey)
|
||||
logger.InfoC("voice", "Groq voice transcription enabled")
|
||||
}
|
||||
|
||||
if transcriber != nil {
|
||||
if telegramChannel, ok := channelManager.GetChannel("telegram"); ok {
|
||||
if tc, ok := telegramChannel.(*channels.TelegramChannel); ok {
|
||||
tc.SetTranscriber(transcriber)
|
||||
logger.InfoC("voice", "Groq transcription attached to Telegram channel")
|
||||
}
|
||||
}
|
||||
if discordChannel, ok := channelManager.GetChannel("discord"); ok {
|
||||
if dc, ok := discordChannel.(*channels.DiscordChannel); ok {
|
||||
dc.SetTranscriber(transcriber)
|
||||
logger.InfoC("voice", "Groq transcription attached to Discord channel")
|
||||
}
|
||||
}
|
||||
if slackChannel, ok := channelManager.GetChannel("slack"); ok {
|
||||
if sc, ok := slackChannel.(*channels.SlackChannel); ok {
|
||||
sc.SetTranscriber(transcriber)
|
||||
logger.InfoC("voice", "Groq transcription attached to Slack channel")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
enabledChannels := channelManager.GetEnabledChannels()
|
||||
if len(enabledChannels) > 0 {
|
||||
fmt.Printf("✓ Channels enabled: %s\n", enabledChannels)
|
||||
} else {
|
||||
fmt.Println("⚠ Warning: No channels enabled")
|
||||
}
|
||||
|
||||
fmt.Printf("✓ Gateway started on %s:%d\n", cfg.Gateway.Host, cfg.Gateway.Port)
|
||||
fmt.Println("Press Ctrl+C to stop")
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
if err := cronService.Start(); err != nil {
|
||||
fmt.Printf("Error starting cron service: %v\n", err)
|
||||
}
|
||||
fmt.Println("✓ Cron service started")
|
||||
|
||||
if err := heartbeatService.Start(); err != nil {
|
||||
fmt.Printf("Error starting heartbeat service: %v\n", err)
|
||||
}
|
||||
fmt.Println("✓ Heartbeat service started")
|
||||
|
||||
stateManager := state.NewManager(cfg.WorkspacePath())
|
||||
deviceService := devices.NewService(devices.Config{
|
||||
Enabled: cfg.Devices.Enabled,
|
||||
MonitorUSB: cfg.Devices.MonitorUSB,
|
||||
}, stateManager)
|
||||
deviceService.SetBus(msgBus)
|
||||
if err := deviceService.Start(ctx); err != nil {
|
||||
fmt.Printf("Error starting device service: %v\n", err)
|
||||
} else if cfg.Devices.Enabled {
|
||||
fmt.Println("✓ Device event service started")
|
||||
}
|
||||
|
||||
if err := channelManager.StartAll(ctx); err != nil {
|
||||
fmt.Printf("Error starting channels: %v\n", err)
|
||||
}
|
||||
|
||||
healthServer := health.NewServer(cfg.Gateway.Host, cfg.Gateway.Port)
|
||||
go func() {
|
||||
if err := healthServer.Start(); err != nil && err != http.ErrServerClosed {
|
||||
logger.ErrorCF("health", "Health server error", map[string]interface{}{"error": err.Error()})
|
||||
}
|
||||
}()
|
||||
fmt.Printf("✓ Health endpoints available at http://%s:%d/health and /ready\n", cfg.Gateway.Host, cfg.Gateway.Port)
|
||||
|
||||
go agentLoop.Run(ctx)
|
||||
|
||||
sigChan := make(chan os.Signal, 1)
|
||||
signal.Notify(sigChan, os.Interrupt)
|
||||
<-sigChan
|
||||
|
||||
fmt.Println("\nShutting down...")
|
||||
cancel()
|
||||
healthServer.Stop(context.Background())
|
||||
deviceService.Stop()
|
||||
heartbeatService.Stop()
|
||||
cronService.Stop()
|
||||
agentLoop.Stop()
|
||||
channelManager.StopAll(ctx)
|
||||
fmt.Println("✓ Gateway stopped")
|
||||
}
|
||||
|
||||
func setupCronTool(agentLoop *agent.AgentLoop, msgBus *bus.MessageBus, workspace string, restrict bool, execTimeout time.Duration) *cron.CronService {
|
||||
cronStorePath := filepath.Join(workspace, "cron", "jobs.json")
|
||||
|
||||
// Create cron service
|
||||
cronService := cron.NewCronService(cronStorePath, nil)
|
||||
|
||||
// Create and register CronTool
|
||||
cronTool := tools.NewCronTool(cronService, agentLoop, msgBus, workspace, restrict, execTimeout)
|
||||
agentLoop.RegisterTool(cronTool)
|
||||
|
||||
// Set the onJob handler
|
||||
cronService.SetOnJob(func(job *cron.CronJob) (string, error) {
|
||||
result := cronTool.ExecuteJob(context.Background(), job)
|
||||
return result, nil
|
||||
})
|
||||
|
||||
return cronService
|
||||
}
|
||||
@@ -0,0 +1,81 @@
|
||||
// PicoClaw - Ultra-lightweight personal AI agent
|
||||
// License: MIT
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/migrate"
|
||||
)
|
||||
|
||||
func migrateCmd() {
|
||||
if len(os.Args) > 2 && (os.Args[2] == "--help" || os.Args[2] == "-h") {
|
||||
migrateHelp()
|
||||
return
|
||||
}
|
||||
|
||||
opts := migrate.Options{}
|
||||
|
||||
args := os.Args[2:]
|
||||
for i := 0; i < len(args); i++ {
|
||||
switch args[i] {
|
||||
case "--dry-run":
|
||||
opts.DryRun = true
|
||||
case "--config-only":
|
||||
opts.ConfigOnly = true
|
||||
case "--workspace-only":
|
||||
opts.WorkspaceOnly = true
|
||||
case "--force":
|
||||
opts.Force = true
|
||||
case "--refresh":
|
||||
opts.Refresh = true
|
||||
case "--openclaw-home":
|
||||
if i+1 < len(args) {
|
||||
opts.OpenClawHome = args[i+1]
|
||||
i++
|
||||
}
|
||||
case "--picoclaw-home":
|
||||
if i+1 < len(args) {
|
||||
opts.PicoClawHome = args[i+1]
|
||||
i++
|
||||
}
|
||||
default:
|
||||
fmt.Printf("Unknown flag: %s\n", args[i])
|
||||
migrateHelp()
|
||||
os.Exit(1)
|
||||
}
|
||||
}
|
||||
|
||||
result, err := migrate.Run(opts)
|
||||
if err != nil {
|
||||
fmt.Printf("Error: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
if !opts.DryRun {
|
||||
migrate.PrintSummary(result)
|
||||
}
|
||||
}
|
||||
|
||||
func migrateHelp() {
|
||||
fmt.Println("\nMigrate from OpenClaw to PicoClaw")
|
||||
fmt.Println()
|
||||
fmt.Println("Usage: picoclaw migrate [options]")
|
||||
fmt.Println()
|
||||
fmt.Println("Options:")
|
||||
fmt.Println(" --dry-run Show what would be migrated without making changes")
|
||||
fmt.Println(" --refresh Re-sync workspace files from OpenClaw (repeatable)")
|
||||
fmt.Println(" --config-only Only migrate config, skip workspace files")
|
||||
fmt.Println(" --workspace-only Only migrate workspace files, skip config")
|
||||
fmt.Println(" --force Skip confirmation prompts")
|
||||
fmt.Println(" --openclaw-home Override OpenClaw home directory (default: ~/.openclaw)")
|
||||
fmt.Println(" --picoclaw-home Override PicoClaw home directory (default: ~/.picoclaw)")
|
||||
fmt.Println()
|
||||
fmt.Println("Examples:")
|
||||
fmt.Println(" picoclaw migrate Detect and migrate from OpenClaw")
|
||||
fmt.Println(" picoclaw migrate --dry-run Show what would be migrated")
|
||||
fmt.Println(" picoclaw migrate --refresh Re-sync workspace files")
|
||||
fmt.Println(" picoclaw migrate --force Migrate without confirmation")
|
||||
}
|
||||
@@ -0,0 +1,102 @@
|
||||
// PicoClaw - Ultra-lightweight personal AI agent
|
||||
// License: MIT
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"embed"
|
||||
"fmt"
|
||||
"io/fs"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
)
|
||||
|
||||
//go:generate cp -r ../../workspace .
|
||||
//go:embed workspace
|
||||
var embeddedFiles embed.FS
|
||||
|
||||
func onboard() {
|
||||
configPath := getConfigPath()
|
||||
|
||||
if _, err := os.Stat(configPath); err == nil {
|
||||
fmt.Printf("Config already exists at %s\n", configPath)
|
||||
fmt.Print("Overwrite? (y/n): ")
|
||||
var response string
|
||||
fmt.Scanln(&response)
|
||||
if response != "y" {
|
||||
fmt.Println("Aborted.")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
cfg := config.DefaultConfig()
|
||||
if err := config.SaveConfig(configPath, cfg); err != nil {
|
||||
fmt.Printf("Error saving config: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
workspace := cfg.WorkspacePath()
|
||||
createWorkspaceTemplates(workspace)
|
||||
|
||||
fmt.Printf("%s picoclaw is ready!\n", logo)
|
||||
fmt.Println("\nNext steps:")
|
||||
fmt.Println(" 1. Add your API key to", configPath)
|
||||
fmt.Println(" Get one at: https://openrouter.ai/keys")
|
||||
fmt.Println(" 2. Chat: picoclaw agent -m \"Hello!\"")
|
||||
}
|
||||
|
||||
func copyEmbeddedToTarget(targetDir string) error {
|
||||
// Ensure target directory exists
|
||||
if err := os.MkdirAll(targetDir, 0755); err != nil {
|
||||
return fmt.Errorf("Failed to create target directory: %w", err)
|
||||
}
|
||||
|
||||
// Walk through all files in embed.FS
|
||||
err := fs.WalkDir(embeddedFiles, "workspace", func(path string, d fs.DirEntry, err error) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Skip directories
|
||||
if d.IsDir() {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Read embedded file
|
||||
data, err := embeddedFiles.ReadFile(path)
|
||||
if err != nil {
|
||||
return fmt.Errorf("Failed to read embedded file %s: %w", path, err)
|
||||
}
|
||||
|
||||
new_path, err := filepath.Rel("workspace", path)
|
||||
if err != nil {
|
||||
return fmt.Errorf("Failed to get relative path for %s: %v\n", path, err)
|
||||
}
|
||||
|
||||
// Build target file path
|
||||
targetPath := filepath.Join(targetDir, new_path)
|
||||
|
||||
// Ensure target file's directory exists
|
||||
if err := os.MkdirAll(filepath.Dir(targetPath), 0755); err != nil {
|
||||
return fmt.Errorf("Failed to create directory %s: %w", filepath.Dir(targetPath), err)
|
||||
}
|
||||
|
||||
// Write file
|
||||
if err := os.WriteFile(targetPath, data, 0644); err != nil {
|
||||
return fmt.Errorf("Failed to write file %s: %w", targetPath, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func createWorkspaceTemplates(workspace string) {
|
||||
err := copyEmbeddedToTarget(workspace)
|
||||
if err != nil {
|
||||
fmt.Printf("Error copying workspace templates: %v\n", err)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,216 @@
|
||||
// PicoClaw - Ultra-lightweight personal AI agent
|
||||
// License: MIT
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/skills"
|
||||
)
|
||||
|
||||
func skillsHelp() {
|
||||
fmt.Println("\nSkills commands:")
|
||||
fmt.Println(" list List installed skills")
|
||||
fmt.Println(" install <repo> Install skill from GitHub")
|
||||
fmt.Println(" install-builtin Install all builtin skills to workspace")
|
||||
fmt.Println(" list-builtin List available builtin skills")
|
||||
fmt.Println(" remove <name> Remove installed skill")
|
||||
fmt.Println(" search Search available skills")
|
||||
fmt.Println(" show <name> Show skill details")
|
||||
fmt.Println()
|
||||
fmt.Println("Examples:")
|
||||
fmt.Println(" picoclaw skills list")
|
||||
fmt.Println(" picoclaw skills install sipeed/picoclaw-skills/weather")
|
||||
fmt.Println(" picoclaw skills install-builtin")
|
||||
fmt.Println(" picoclaw skills list-builtin")
|
||||
fmt.Println(" picoclaw skills remove weather")
|
||||
}
|
||||
|
||||
func skillsListCmd(loader *skills.SkillsLoader) {
|
||||
allSkills := loader.ListSkills()
|
||||
|
||||
if len(allSkills) == 0 {
|
||||
fmt.Println("No skills installed.")
|
||||
return
|
||||
}
|
||||
|
||||
fmt.Println("\nInstalled Skills:")
|
||||
fmt.Println("------------------")
|
||||
for _, skill := range allSkills {
|
||||
fmt.Printf(" ✓ %s (%s)\n", skill.Name, skill.Source)
|
||||
if skill.Description != "" {
|
||||
fmt.Printf(" %s\n", skill.Description)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func skillsInstallCmd(installer *skills.SkillInstaller) {
|
||||
if len(os.Args) < 4 {
|
||||
fmt.Println("Usage: picoclaw skills install <github-repo>")
|
||||
fmt.Println("Example: picoclaw skills install sipeed/picoclaw-skills/weather")
|
||||
return
|
||||
}
|
||||
|
||||
repo := os.Args[3]
|
||||
fmt.Printf("Installing skill from %s...\n", repo)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
if err := installer.InstallFromGitHub(ctx, repo); err != nil {
|
||||
fmt.Printf("✗ Failed to install skill: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
fmt.Printf("✓ Skill '%s' installed successfully!\n", filepath.Base(repo))
|
||||
}
|
||||
|
||||
func skillsRemoveCmd(installer *skills.SkillInstaller, skillName string) {
|
||||
fmt.Printf("Removing skill '%s'...\n", skillName)
|
||||
|
||||
if err := installer.Uninstall(skillName); err != nil {
|
||||
fmt.Printf("✗ Failed to remove skill: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
fmt.Printf("✓ Skill '%s' removed successfully!\n", skillName)
|
||||
}
|
||||
|
||||
func skillsInstallBuiltinCmd(workspace string) {
|
||||
builtinSkillsDir := "./picoclaw/skills"
|
||||
workspaceSkillsDir := filepath.Join(workspace, "skills")
|
||||
|
||||
fmt.Printf("Copying builtin skills to workspace...\n")
|
||||
|
||||
skillsToInstall := []string{
|
||||
"weather",
|
||||
"news",
|
||||
"stock",
|
||||
"calculator",
|
||||
}
|
||||
|
||||
for _, skillName := range skillsToInstall {
|
||||
builtinPath := filepath.Join(builtinSkillsDir, skillName)
|
||||
workspacePath := filepath.Join(workspaceSkillsDir, skillName)
|
||||
|
||||
if _, err := os.Stat(builtinPath); err != nil {
|
||||
fmt.Printf("⊘ Builtin skill '%s' not found: %v\n", skillName, err)
|
||||
continue
|
||||
}
|
||||
|
||||
if err := os.MkdirAll(workspacePath, 0755); err != nil {
|
||||
fmt.Printf("✗ Failed to create directory for %s: %v\n", skillName, err)
|
||||
continue
|
||||
}
|
||||
|
||||
if err := copyDirectory(builtinPath, workspacePath); err != nil {
|
||||
fmt.Printf("✗ Failed to copy %s: %v\n", skillName, err)
|
||||
}
|
||||
}
|
||||
|
||||
fmt.Println("\n✓ All builtin skills installed!")
|
||||
fmt.Println("Now you can use them in your workspace.")
|
||||
}
|
||||
|
||||
func skillsListBuiltinCmd() {
|
||||
cfg, err := loadConfig()
|
||||
if err != nil {
|
||||
fmt.Printf("Error loading config: %v\n", err)
|
||||
return
|
||||
}
|
||||
builtinSkillsDir := filepath.Join(filepath.Dir(cfg.WorkspacePath()), "picoclaw", "skills")
|
||||
|
||||
fmt.Println("\nAvailable Builtin Skills:")
|
||||
fmt.Println("-----------------------")
|
||||
|
||||
entries, err := os.ReadDir(builtinSkillsDir)
|
||||
if err != nil {
|
||||
fmt.Printf("Error reading builtin skills: %v\n", err)
|
||||
return
|
||||
}
|
||||
|
||||
if len(entries) == 0 {
|
||||
fmt.Println("No builtin skills available.")
|
||||
return
|
||||
}
|
||||
|
||||
for _, entry := range entries {
|
||||
if entry.IsDir() {
|
||||
skillName := entry.Name()
|
||||
skillFile := filepath.Join(builtinSkillsDir, skillName, "SKILL.md")
|
||||
|
||||
description := "No description"
|
||||
if _, err := os.Stat(skillFile); err == nil {
|
||||
data, err := os.ReadFile(skillFile)
|
||||
if err == nil {
|
||||
content := string(data)
|
||||
if idx := strings.Index(content, "\n"); idx > 0 {
|
||||
firstLine := content[:idx]
|
||||
if strings.Contains(firstLine, "description:") {
|
||||
descLine := strings.Index(content[idx:], "\n")
|
||||
if descLine > 0 {
|
||||
description = strings.TrimSpace(content[idx+descLine : idx+descLine])
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
status := "✓"
|
||||
fmt.Printf(" %s %s\n", status, entry.Name())
|
||||
if description != "" {
|
||||
fmt.Printf(" %s\n", description)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func skillsSearchCmd(installer *skills.SkillInstaller) {
|
||||
fmt.Println("Searching for available skills...")
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
availableSkills, err := installer.ListAvailableSkills(ctx)
|
||||
if err != nil {
|
||||
fmt.Printf("✗ Failed to fetch skills list: %v\n", err)
|
||||
return
|
||||
}
|
||||
|
||||
if len(availableSkills) == 0 {
|
||||
fmt.Println("No skills available.")
|
||||
return
|
||||
}
|
||||
|
||||
fmt.Printf("\nAvailable Skills (%d):\n", len(availableSkills))
|
||||
fmt.Println("--------------------")
|
||||
for _, skill := range availableSkills {
|
||||
fmt.Printf(" 📦 %s\n", skill.Name)
|
||||
fmt.Printf(" %s\n", skill.Description)
|
||||
fmt.Printf(" Repo: %s\n", skill.Repository)
|
||||
if skill.Author != "" {
|
||||
fmt.Printf(" Author: %s\n", skill.Author)
|
||||
}
|
||||
if len(skill.Tags) > 0 {
|
||||
fmt.Printf(" Tags: %v\n", skill.Tags)
|
||||
}
|
||||
fmt.Println()
|
||||
}
|
||||
}
|
||||
|
||||
func skillsShowCmd(loader *skills.SkillsLoader, skillName string) {
|
||||
content, ok := loader.LoadSkill(skillName)
|
||||
if !ok {
|
||||
fmt.Printf("✗ Skill '%s' not found\n", skillName)
|
||||
return
|
||||
}
|
||||
|
||||
fmt.Printf("\n📦 Skill: %s\n", skillName)
|
||||
fmt.Println("----------------------")
|
||||
fmt.Println(content)
|
||||
}
|
||||
@@ -0,0 +1,102 @@
|
||||
// PicoClaw - Ultra-lightweight personal AI agent
|
||||
// License: MIT
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/auth"
|
||||
)
|
||||
|
||||
func statusCmd() {
|
||||
cfg, err := loadConfig()
|
||||
if err != nil {
|
||||
fmt.Printf("Error loading config: %v\n", err)
|
||||
return
|
||||
}
|
||||
|
||||
configPath := getConfigPath()
|
||||
|
||||
fmt.Printf("%s picoclaw Status\n", logo)
|
||||
fmt.Printf("Version: %s\n", formatVersion())
|
||||
build, _ := formatBuildInfo()
|
||||
if build != "" {
|
||||
fmt.Printf("Build: %s\n", build)
|
||||
}
|
||||
fmt.Println()
|
||||
|
||||
if _, err := os.Stat(configPath); err == nil {
|
||||
fmt.Println("Config:", configPath, "✓")
|
||||
} else {
|
||||
fmt.Println("Config:", configPath, "✗")
|
||||
}
|
||||
|
||||
workspace := cfg.WorkspacePath()
|
||||
if _, err := os.Stat(workspace); err == nil {
|
||||
fmt.Println("Workspace:", workspace, "✓")
|
||||
} else {
|
||||
fmt.Println("Workspace:", workspace, "✗")
|
||||
}
|
||||
|
||||
if _, err := os.Stat(configPath); err == nil {
|
||||
fmt.Printf("Model: %s\n", cfg.Agents.Defaults.Model)
|
||||
|
||||
hasOpenRouter := cfg.Providers.OpenRouter.APIKey != ""
|
||||
hasAnthropic := cfg.Providers.Anthropic.APIKey != ""
|
||||
hasOpenAI := cfg.Providers.OpenAI.APIKey != ""
|
||||
hasGemini := cfg.Providers.Gemini.APIKey != ""
|
||||
hasZhipu := cfg.Providers.Zhipu.APIKey != ""
|
||||
hasQwen := cfg.Providers.Qwen.APIKey != ""
|
||||
hasGroq := cfg.Providers.Groq.APIKey != ""
|
||||
hasVLLM := cfg.Providers.VLLM.APIBase != ""
|
||||
hasMoonshot := cfg.Providers.Moonshot.APIKey != ""
|
||||
hasDeepSeek := cfg.Providers.DeepSeek.APIKey != ""
|
||||
hasVolcEngine := cfg.Providers.VolcEngine.APIKey != ""
|
||||
hasNvidia := cfg.Providers.Nvidia.APIKey != ""
|
||||
hasOllama := cfg.Providers.Ollama.APIBase != ""
|
||||
|
||||
status := func(enabled bool) string {
|
||||
if enabled {
|
||||
return "✓"
|
||||
}
|
||||
return "not set"
|
||||
}
|
||||
fmt.Println("OpenRouter API:", status(hasOpenRouter))
|
||||
fmt.Println("Anthropic API:", status(hasAnthropic))
|
||||
fmt.Println("OpenAI API:", status(hasOpenAI))
|
||||
fmt.Println("Gemini API:", status(hasGemini))
|
||||
fmt.Println("Zhipu API:", status(hasZhipu))
|
||||
fmt.Println("Qwen API:", status(hasQwen))
|
||||
fmt.Println("Groq API:", status(hasGroq))
|
||||
fmt.Println("Moonshot API:", status(hasMoonshot))
|
||||
fmt.Println("DeepSeek API:", status(hasDeepSeek))
|
||||
fmt.Println("VolcEngine API:", status(hasVolcEngine))
|
||||
fmt.Println("Nvidia API:", status(hasNvidia))
|
||||
if hasVLLM {
|
||||
fmt.Printf("vLLM/Local: ✓ %s\n", cfg.Providers.VLLM.APIBase)
|
||||
} else {
|
||||
fmt.Println("vLLM/Local: not set")
|
||||
}
|
||||
if hasOllama {
|
||||
fmt.Printf("Ollama: ✓ %s\n", cfg.Providers.Ollama.APIBase)
|
||||
} else {
|
||||
fmt.Println("Ollama: not set")
|
||||
}
|
||||
|
||||
store, _ := auth.LoadStore()
|
||||
if store != nil && len(store.Credentials) > 0 {
|
||||
fmt.Println("\nOAuth/Token Auth:")
|
||||
for provider, cred := range store.Credentials {
|
||||
status := "authenticated"
|
||||
if cred.IsExpired() {
|
||||
status = "expired"
|
||||
} else if cred.NeedsRefresh() {
|
||||
status = "needs refresh"
|
||||
}
|
||||
fmt.Printf(" %s (%s): %s\n", provider, cred.AuthMethod, status)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,179 @@
|
||||
# Provider Architecture Refactoring - Test Suite Summary
|
||||
|
||||
> PRD: `tasks/prd-provider-refactoring.md`
|
||||
|
||||
This document summarizes the complete test suite designed for the Provider architecture refactoring.
|
||||
|
||||
## Test File Structure
|
||||
|
||||
```
|
||||
pkg/
|
||||
├── config/
|
||||
│ ├── model_config_test.go # US-001, US-002: ModelConfig struct and GetModelConfig tests
|
||||
│ └── migration_test.go # US-003: Backward compatibility and migration tests
|
||||
├── providers/
|
||||
│ ├── registry_test.go # US-006: Load balancing tests
|
||||
│ ├── integration_test.go # E2E integration tests
|
||||
│ └── factory/
|
||||
│ └── factory_test.go # US-004, US-005: Provider factory tests
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Test Case Checklist
|
||||
|
||||
### 1. `pkg/config/model_config_test.go` - Configuration Parsing Tests
|
||||
|
||||
| Test Name | Purpose | PRD Reference |
|
||||
|-----------|---------|---------------|
|
||||
| `TestModelConfig_Parsing` | Verify ModelConfig JSON parsing | US-001 |
|
||||
| `TestModelConfig_ModelListInConfig` | Verify model_list parsing in Config | US-001 |
|
||||
| `TestModelConfig_Validation` | Verify required field validation | US-001 |
|
||||
| `TestConfig_GetModelConfig_Found` | Verify GetModelConfig finds model | US-002 |
|
||||
| `TestConfig_GetModelConfig_NotFound` | Verify GetModelConfig returns error | US-002 |
|
||||
| `TestConfig_GetModelConfig_EmptyModelList` | Verify empty model_list handling | US-002 |
|
||||
| `TestConfig_BackwardCompatibility_ProvidersToModelList` | Verify old config conversion | US-003 |
|
||||
| `TestConfig_DeprecationWarning` | Verify deprecation warning | US-003 |
|
||||
| `TestModelConfig_ProtocolExtraction` | Verify protocol prefix extraction | US-004 |
|
||||
| `TestConfig_ModelNameUniqueness` | Verify model_name uniqueness | US-001 |
|
||||
|
||||
### 2. `pkg/config/migration_test.go` - Migration Tests
|
||||
|
||||
| Test Name | Purpose | PRD Reference |
|
||||
|-----------|---------|---------------|
|
||||
| `TestConvertProvidersToModelList_OpenAI` | OpenAI config conversion | US-003 |
|
||||
| `TestConvertProvidersToModelList_Anthropic` | Anthropic config conversion | US-003 |
|
||||
| `TestConvertProvidersToModelList_MultipleProviders` | Multiple provider conversion | US-003 |
|
||||
| `TestConvertProvidersToModelList_EmptyProviders` | Empty providers handling | US-003 |
|
||||
| `TestConvertProvidersToModelList_GitHubCopilot` | GitHub Copilot conversion | US-003 |
|
||||
| `TestConvertProvidersToModelList_Antigravity` | Antigravity conversion | US-003 |
|
||||
| `TestGenerateModelName_*` | Model name generation | US-003 |
|
||||
| `TestHasProvidersConfig_*` | Detect old config existence | US-003 |
|
||||
| `TestValidateMigration_*` | Migration validation | US-003 |
|
||||
| `TestMigrateConfig_DryRun` | Dry run migration | US-003 |
|
||||
| `TestMigrateConfig_Actual` | Actual migration | US-003 |
|
||||
|
||||
### 3. `pkg/providers/registry_test.go` - Load Balancing Tests
|
||||
|
||||
| Test Name | Purpose | PRD Reference |
|
||||
|-----------|---------|---------------|
|
||||
| `TestModelRegistry_SingleConfig` | Single config returns same result | US-006 |
|
||||
| `TestModelRegistry_RoundRobinSelection` | 3-config round-robin selection | US-006 |
|
||||
| `TestModelRegistry_RoundRobinTwoConfigs` | 2-config round-robin selection | US-006 |
|
||||
| `TestModelRegistry_ConcurrentAccess` | Concurrent access thread safety | US-006 |
|
||||
| `TestModelRegistry_RaceDetection` | Data race detection | US-006 |
|
||||
| `TestModelRegistry_ModelNotFound` | Model not found error | US-006 |
|
||||
| `TestModelRegistry_EmptyRegistry` | Empty registry handling | US-006 |
|
||||
| `TestModelRegistry_MultipleModels` | Multiple model registration | US-006 |
|
||||
| `TestModelRegistry_MixedSingleAndMultiple` | Single/multiple config mix | US-006 |
|
||||
| `TestModelRegistry_CaseSensitiveModelNames` | Case sensitivity | US-006 |
|
||||
|
||||
### 4. `pkg/providers/factory/factory_test.go` - Provider Factory Tests
|
||||
|
||||
| Test Name | Purpose | PRD Reference |
|
||||
|-----------|---------|---------------|
|
||||
| `TestCreateProviderFromConfig_OpenAI` | Create OpenAI provider | US-004 |
|
||||
| `TestCreateProviderFromConfig_OpenAIDefault` | Default openai protocol | US-004 |
|
||||
| `TestCreateProviderFromConfig_Anthropic` | Create Anthropic provider | US-004 |
|
||||
| `TestCreateProviderFromConfig_Antigravity` | Create Antigravity provider | US-004 |
|
||||
| `TestCreateProviderFromConfig_ClaudeCLI` | Create Claude CLI provider | US-004 |
|
||||
| `TestCreateProviderFromConfig_CodexCLI` | Create Codex CLI provider | US-004 |
|
||||
| `TestCreateProviderFromConfig_GitHubCopilot` | Create GitHub Copilot provider | US-004 |
|
||||
| `TestCreateProviderFromConfig_UnknownProtocol` | Unknown protocol error handling | US-004 |
|
||||
| `TestCreateProviderFromConfig_MissingAPIKey` | Missing API key error | US-004 |
|
||||
| `TestExtractProtocol` | Protocol prefix extraction | US-004 |
|
||||
| `TestCreateProvider_UsesModelList` | Create using model_list | US-005 |
|
||||
| `TestCreateProvider_FallbackToProviders` | Fallback to providers | US-005 |
|
||||
| `TestCreateProvider_PriorityModelListOverProviders` | model_list priority | US-005 |
|
||||
|
||||
### 5. `pkg/providers/integration_test.go` - E2E Integration Tests
|
||||
|
||||
| Test Name | Purpose | PRD Reference |
|
||||
|-----------|---------|---------------|
|
||||
| `TestE2E_OpenAICompatibleProvider_NoCodeChange` | Zero-code provider addition | Goal |
|
||||
| `TestE2E_LoadBalancing_RoundRobin` | Load balancing actual effect | US-006 |
|
||||
| `TestE2E_BackwardCompatibility_OldProvidersConfig` | Old config compatibility | US-003 |
|
||||
| `TestE2E_ErrorHandling_ModelNotFound` | Model not found | FR-30 |
|
||||
| `TestE2E_ErrorHandling_MissingAPIKey` | Missing API key | FR-31 |
|
||||
| `TestE2E_ErrorHandling_InvalidAPIBase` | Invalid API base | FR-30 |
|
||||
| `TestE2E_ToolCalls_OpenAICompatible` | Tool call support | - |
|
||||
| `TestE2E_AntigravityProvider` | Antigravity provider | US-004 |
|
||||
| `TestE2E_ClaudeCLIProvider` | Claude CLI provider | US-004 |
|
||||
|
||||
### 6. Performance Tests
|
||||
|
||||
| Test Name | Purpose |
|
||||
|-----------|---------|
|
||||
| `BenchmarkCreateProviderFromConfig` | Provider creation performance |
|
||||
| `BenchmarkGetModelConfig` | Model lookup performance |
|
||||
| `BenchmarkGetModelConfigParallel` | Concurrent lookup performance |
|
||||
|
||||
---
|
||||
|
||||
## Running Tests
|
||||
|
||||
```bash
|
||||
# Run all tests
|
||||
go test ./pkg/... -v
|
||||
|
||||
# Run with data race detection
|
||||
go test ./pkg/... -race
|
||||
|
||||
# Run specific package tests
|
||||
go test ./pkg/config -v
|
||||
go test ./pkg/providers -v
|
||||
go test ./pkg/providers/factory -v
|
||||
|
||||
# Run E2E tests
|
||||
go test ./pkg/providers -run TestE2E -v
|
||||
|
||||
# Run performance tests
|
||||
go test ./pkg/providers -bench=. -benchmem
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## PRD Acceptance Criteria Mapping
|
||||
|
||||
| PRD Acceptance Criteria | Test Cases |
|
||||
|------------------------|------------|
|
||||
| US-001: Add ModelConfig struct | `TestModelConfig_Parsing`, `TestModelConfig_Validation` |
|
||||
| US-001: model_name unique | `TestConfig_ModelNameUniqueness` |
|
||||
| US-002: GetModelConfig method | `TestConfig_GetModelConfig_*` |
|
||||
| US-003: Auto-convert providers | `TestConvertProvidersToModelList_*` |
|
||||
| US-003: Deprecation warning | `TestConfig_DeprecationWarning` |
|
||||
| US-003: Existing tests pass | (existing test files unchanged) |
|
||||
| US-004: Protocol prefix factory | `TestExtractProtocol`, `TestCreateProviderFromConfig_*` |
|
||||
| US-004: Default prefix openai | `TestCreateProviderFromConfig_OpenAIDefault` |
|
||||
| US-005: CreateProvider uses factory | `TestCreateProvider_*` |
|
||||
| US-006: Round-robin selection | `TestModelRegistry_RoundRobin*` |
|
||||
| US-006: Thread-safe atomic | `TestModelRegistry_RaceDetection` |
|
||||
|
||||
---
|
||||
|
||||
## Recommended Implementation Order
|
||||
|
||||
1. **Phase 1: Configuration Structure** (US-001, US-002)
|
||||
- Implement `ModelConfig` struct
|
||||
- Implement `GetModelConfig` method
|
||||
- Run `model_config_test.go`
|
||||
|
||||
2. **Phase 2: Protocol Factory** (US-004)
|
||||
- Implement `CreateProviderFromConfig`
|
||||
- Implement `ExtractProtocol`
|
||||
- Run `factory_test.go`
|
||||
|
||||
3. **Phase 3: Load Balancing** (US-006)
|
||||
- Implement `ModelRegistry`
|
||||
- Implement round-robin selection
|
||||
- Run `registry_test.go` (with `-race`)
|
||||
|
||||
4. **Phase 4: Backward Compatibility** (US-003, US-005)
|
||||
- Implement `ConvertProvidersToModelList`
|
||||
- Refactor `CreateProvider`
|
||||
- Run `migration_test.go`
|
||||
- Verify existing tests pass
|
||||
|
||||
5. **Phase 5: E2E Verification**
|
||||
- Run `integration_test.go`
|
||||
- Manual testing with `config.example.json`
|
||||
@@ -0,0 +1,334 @@
|
||||
# Provider Architecture Refactoring Design
|
||||
|
||||
> Issue: #283
|
||||
> Discussion: #122
|
||||
> Branch: feat/refactor-provider-by-protocol
|
||||
|
||||
## 1. Current Problems
|
||||
|
||||
### 1.1 Configuration Structure Issues
|
||||
|
||||
**Current State**: Each Provider requires a predefined field in `ProvidersConfig`
|
||||
|
||||
```go
|
||||
type ProvidersConfig struct {
|
||||
Anthropic ProviderConfig `json:"anthropic"`
|
||||
OpenAI ProviderConfig `json:"openai"`
|
||||
DeepSeek ProviderConfig `json:"deepseek"`
|
||||
Qwen ProviderConfig `json:"qwen"`
|
||||
Cerebras ProviderConfig `json:"cerebras"`
|
||||
VolcEngine ProviderConfig `json:"volcengine"`
|
||||
// ... every new provider requires changes here
|
||||
}
|
||||
```
|
||||
|
||||
**Problems**:
|
||||
- Adding a new Provider requires modifying Go code (struct definition)
|
||||
- `CreateProvider` function in `http_provider.go` has 200+ lines of switch-case
|
||||
- Most Providers are OpenAI-compatible, but code is duplicated
|
||||
|
||||
### 1.2 Code Bloat Trend
|
||||
|
||||
Recent PRs demonstrate this issue:
|
||||
|
||||
| PR | Provider | Code Changes |
|
||||
|----|----------|--------------|
|
||||
| #365 | Qwen | +17 lines to http_provider.go |
|
||||
| #333 | Cerebras | +17 lines to http_provider.go |
|
||||
| #368 | Volcengine | +18 lines to http_provider.go |
|
||||
|
||||
Each OpenAI-compatible Provider requires:
|
||||
1. Modify `config.go` to add configuration field
|
||||
2. Modify `http_provider.go` to add switch case
|
||||
3. Update documentation
|
||||
|
||||
### 1.3 Agent-Provider Coupling
|
||||
|
||||
```json
|
||||
{
|
||||
"agents": {
|
||||
"defaults": {
|
||||
"provider": "deepseek", // need to know provider name
|
||||
"model": "deepseek-chat"
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
Problem: Agent needs to know both `provider` and `model`, adding complexity.
|
||||
|
||||
---
|
||||
|
||||
## 2. New Approach: model_list
|
||||
|
||||
### 2.1 Core Principles
|
||||
|
||||
Inspired by [LiteLLM](https://docs.litellm.ai/docs/proxy/configs) design:
|
||||
|
||||
1. **Model-centric**: Users care about models, not providers
|
||||
2. **Protocol prefix**: Use `protocol/model_name` format, e.g., `openai/gpt-4o`, `anthropic/claude-3-sonnet`
|
||||
3. **Configuration-driven**: Adding new Providers only requires config changes, no code changes
|
||||
|
||||
### 2.2 New Configuration Structure
|
||||
|
||||
```json
|
||||
{
|
||||
"model_list": [
|
||||
{
|
||||
"model_name": "deepseek-chat",
|
||||
"model": "openai/deepseek-chat",
|
||||
"api_base": "https://api.deepseek.com/v1",
|
||||
"api_key": "sk-xxx"
|
||||
},
|
||||
{
|
||||
"model_name": "gpt-4o",
|
||||
"model": "openai/gpt-4o",
|
||||
"api_key": "sk-xxx"
|
||||
},
|
||||
{
|
||||
"model_name": "claude-3-sonnet",
|
||||
"model": "anthropic/claude-3-5-sonnet-20241022",
|
||||
"api_key": "sk-xxx"
|
||||
},
|
||||
{
|
||||
"model_name": "gemini-3-flash",
|
||||
"model": "antigravity/gemini-3-flash",
|
||||
"auth_method": "oauth"
|
||||
},
|
||||
{
|
||||
"model_name": "my-company-llm",
|
||||
"model": "openai/company-model-v1",
|
||||
"api_base": "https://llm.company.com/v1",
|
||||
"api_key": "xxx"
|
||||
}
|
||||
],
|
||||
|
||||
"agents": {
|
||||
"defaults": {
|
||||
"model": "deepseek-chat",
|
||||
"max_tokens": 8192,
|
||||
"temperature": 0.7
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### 2.3 Go Struct Definition
|
||||
|
||||
```go
|
||||
type Config struct {
|
||||
ModelList []ModelConfig `json:"model_list"` // new
|
||||
Providers ProvidersConfig `json:"providers"` // old, deprecated
|
||||
|
||||
Agents AgentsConfig `json:"agents"`
|
||||
Channels ChannelsConfig `json:"channels"`
|
||||
// ...
|
||||
}
|
||||
|
||||
type ModelConfig struct {
|
||||
// Required
|
||||
ModelName string `json:"model_name"` // user-facing name (alias)
|
||||
Model string `json:"model"` // protocol/model, e.g., openai/gpt-4o
|
||||
|
||||
// Common config
|
||||
APIBase string `json:"api_base,omitempty"`
|
||||
APIKey string `json:"api_key,omitempty"`
|
||||
Proxy string `json:"proxy,omitempty"`
|
||||
|
||||
// Special provider config
|
||||
AuthMethod string `json:"auth_method,omitempty"` // oauth, token
|
||||
ConnectMode string `json:"connect_mode,omitempty"` // stdio, grpc
|
||||
|
||||
// Optional optimizations
|
||||
RPM int `json:"rpm,omitempty"` // rate limit
|
||||
MaxTokensField string `json:"max_tokens_field,omitempty"` // max_tokens or max_completion_tokens
|
||||
}
|
||||
```
|
||||
|
||||
### 2.4 Protocol Recognition
|
||||
|
||||
Identify protocol via prefix in `model` field:
|
||||
|
||||
| Prefix | Protocol | Description |
|
||||
|--------|----------|-------------|
|
||||
| `openai/` | OpenAI-compatible | Most common, includes DeepSeek, Qwen, Groq, etc. |
|
||||
| `anthropic/` | Anthropic | Claude series specific |
|
||||
| `antigravity/` | Antigravity | Google Cloud Code Assist |
|
||||
| `gemini/` | Gemini | Google Gemini native API (if needed) |
|
||||
|
||||
---
|
||||
|
||||
## 3. Design Rationale
|
||||
|
||||
### 3.1 Problems Solved
|
||||
|
||||
| Problem | Old Approach | New Approach |
|
||||
|---------|--------------|--------------|
|
||||
| Add OpenAI-compatible Provider | Change 3 code locations | Add one config entry |
|
||||
| Agent specifies model | Need provider + model | Only need model |
|
||||
| Code duplication | Each Provider duplicates logic | Share protocol implementation |
|
||||
| Multi-Agent support | Complex | Naturally compatible |
|
||||
|
||||
### 3.2 Multi-Agent Compatibility
|
||||
|
||||
```json
|
||||
{
|
||||
"model_list": [...],
|
||||
|
||||
"agents": {
|
||||
"defaults": {
|
||||
"model": "deepseek-chat"
|
||||
},
|
||||
"coder": {
|
||||
"model": "gpt-4o",
|
||||
"system_prompt": "You are a coding assistant..."
|
||||
},
|
||||
"translator": {
|
||||
"model": "claude-3-sonnet"
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
Each Agent only needs to specify `model` (corresponds to `model_name` in `model_list`).
|
||||
|
||||
### 3.3 Industry Comparison
|
||||
|
||||
**LiteLLM** (most mature open-source LLM Proxy) uses similar design:
|
||||
|
||||
```yaml
|
||||
model_list:
|
||||
- model_name: gpt-4o
|
||||
litellm_params:
|
||||
model: openai/gpt-4o
|
||||
api_key: xxx
|
||||
- model_name: my-custom
|
||||
litellm_params:
|
||||
model: openai/custom-model
|
||||
api_base: https://my-api.com/v1
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 4. Migration Plan
|
||||
|
||||
### 4.1 Phase 1: Compatibility Period (v1.x)
|
||||
|
||||
Support both `providers` and `model_list`:
|
||||
|
||||
```go
|
||||
func (c *Config) GetModelConfig(modelName string) (*ModelConfig, error) {
|
||||
// Prefer new config
|
||||
if len(c.ModelList) > 0 {
|
||||
return c.findModelByName(modelName)
|
||||
}
|
||||
|
||||
// Backward compatibility with old config
|
||||
if !c.Providers.IsEmpty() {
|
||||
logger.Warn("'providers' config is deprecated, please migrate to 'model_list'")
|
||||
return c.convertFromProviders(modelName)
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("model %s not found", modelName)
|
||||
}
|
||||
```
|
||||
|
||||
### 4.2 Phase 2: Warning Period (late v1.x)
|
||||
|
||||
- Print more prominent warnings at startup
|
||||
- Provide automatic migration script
|
||||
- Mark `providers` as deprecated in documentation
|
||||
|
||||
### 4.3 Phase 3: Removal Period (v2.0)
|
||||
|
||||
- Completely remove `providers` support
|
||||
- Remove `agents.defaults.provider` field
|
||||
- Only support `model_list`
|
||||
|
||||
### 4.4 Configuration Migration Example
|
||||
|
||||
**Old Config**:
|
||||
```json
|
||||
{
|
||||
"providers": {
|
||||
"deepseek": {
|
||||
"api_key": "sk-xxx",
|
||||
"api_base": "https://api.deepseek.com/v1"
|
||||
}
|
||||
},
|
||||
"agents": {
|
||||
"defaults": {
|
||||
"provider": "deepseek",
|
||||
"model": "deepseek-chat"
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
**New Config**:
|
||||
```json
|
||||
{
|
||||
"model_list": [
|
||||
{
|
||||
"model_name": "deepseek-chat",
|
||||
"model": "openai/deepseek-chat",
|
||||
"api_base": "https://api.deepseek.com/v1",
|
||||
"api_key": "sk-xxx"
|
||||
}
|
||||
],
|
||||
"agents": {
|
||||
"defaults": {
|
||||
"model": "deepseek-chat"
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 5. Implementation Checklist
|
||||
|
||||
### 5.1 Configuration Layer
|
||||
|
||||
- [ ] Add `ModelConfig` struct
|
||||
- [ ] Add `Config.ModelList` field
|
||||
- [ ] Implement `GetModelConfig(modelName)` method
|
||||
- [ ] Implement old config compatibility conversion
|
||||
- [ ] Add `model_name` uniqueness validation
|
||||
|
||||
### 5.2 Provider Layer
|
||||
|
||||
- [ ] Create `pkg/providers/factory/` directory
|
||||
- [ ] Implement `CreateProviderFromModelConfig()`
|
||||
- [ ] Refactor `http_provider.go` to `openai/provider.go`
|
||||
- [ ] Maintain backward compatibility for old `CreateProvider()`
|
||||
|
||||
### 5.3 Testing
|
||||
|
||||
- [ ] New config unit tests
|
||||
- [ ] Old config compatibility tests
|
||||
- [ ] Integration tests
|
||||
|
||||
### 5.4 Documentation
|
||||
|
||||
- [ ] Update README
|
||||
- [ ] Update config.example.json
|
||||
- [ ] Write migration guide
|
||||
|
||||
---
|
||||
|
||||
## 6. Risks and Mitigations
|
||||
|
||||
| Risk | Mitigation |
|
||||
|------|------------|
|
||||
| Breaking existing configs | Compatibility period keeps old config working |
|
||||
| User migration cost | Provide automatic migration script |
|
||||
| Special Provider incompatibility | Keep `auth_method` and other extension fields |
|
||||
|
||||
---
|
||||
|
||||
## 7. References
|
||||
|
||||
- [LiteLLM Config Documentation](https://docs.litellm.ai/docs/proxy/configs)
|
||||
- [One-API GitHub](https://github.com/songquanpeng/one-api)
|
||||
- Discussion #122: Refactor Provider Architecture
|
||||
+47
-372
@@ -232,23 +232,6 @@ func (c *ModelConfig) Validate() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// ParseProtocol extracts the protocol prefix and model identifier from the Model field.
|
||||
// If no prefix is specified, it defaults to "openai".
|
||||
// Examples:
|
||||
// - "openai/gpt-4o" -> ("openai", "gpt-4o")
|
||||
// - "anthropic/claude-3" -> ("anthropic", "claude-3")
|
||||
// - "gpt-4o" -> ("openai", "gpt-4o") // default protocol
|
||||
func (c *ModelConfig) ParseProtocol() (protocol, modelID string) {
|
||||
model := c.Model
|
||||
for i := 0; i < len(model); i++ {
|
||||
if model[i] == '/' {
|
||||
return model[:i], model[i+1:]
|
||||
}
|
||||
}
|
||||
// No prefix found, default to openai
|
||||
return "openai", model
|
||||
}
|
||||
|
||||
type GatewayConfig struct {
|
||||
Host string `json:"host" env:"PICOCLAW_GATEWAY_HOST"`
|
||||
Port int `json:"port" env:"PICOCLAW_GATEWAY_PORT"`
|
||||
@@ -286,135 +269,6 @@ type ToolsConfig struct {
|
||||
Cron CronToolsConfig `json:"cron"`
|
||||
}
|
||||
|
||||
func DefaultConfig() *Config {
|
||||
return &Config{
|
||||
Agents: AgentsConfig{
|
||||
Defaults: AgentDefaults{
|
||||
Workspace: "~/.picoclaw/workspace",
|
||||
RestrictToWorkspace: true,
|
||||
Provider: "",
|
||||
Model: "glm-4.7",
|
||||
MaxTokens: 8192,
|
||||
Temperature: 0.7,
|
||||
MaxToolIterations: 20,
|
||||
},
|
||||
},
|
||||
Channels: ChannelsConfig{
|
||||
WhatsApp: WhatsAppConfig{
|
||||
Enabled: false,
|
||||
BridgeURL: "ws://localhost:3001",
|
||||
AllowFrom: FlexibleStringSlice{},
|
||||
},
|
||||
Telegram: TelegramConfig{
|
||||
Enabled: false,
|
||||
Token: "",
|
||||
AllowFrom: FlexibleStringSlice{},
|
||||
},
|
||||
Feishu: FeishuConfig{
|
||||
Enabled: false,
|
||||
AppID: "",
|
||||
AppSecret: "",
|
||||
EncryptKey: "",
|
||||
VerificationToken: "",
|
||||
AllowFrom: FlexibleStringSlice{},
|
||||
},
|
||||
Discord: DiscordConfig{
|
||||
Enabled: false,
|
||||
Token: "",
|
||||
AllowFrom: FlexibleStringSlice{},
|
||||
},
|
||||
MaixCam: MaixCamConfig{
|
||||
Enabled: false,
|
||||
Host: "0.0.0.0",
|
||||
Port: 18790,
|
||||
AllowFrom: FlexibleStringSlice{},
|
||||
},
|
||||
QQ: QQConfig{
|
||||
Enabled: false,
|
||||
AppID: "",
|
||||
AppSecret: "",
|
||||
AllowFrom: FlexibleStringSlice{},
|
||||
},
|
||||
DingTalk: DingTalkConfig{
|
||||
Enabled: false,
|
||||
ClientID: "",
|
||||
ClientSecret: "",
|
||||
AllowFrom: FlexibleStringSlice{},
|
||||
},
|
||||
Slack: SlackConfig{
|
||||
Enabled: false,
|
||||
BotToken: "",
|
||||
AppToken: "",
|
||||
AllowFrom: FlexibleStringSlice{},
|
||||
},
|
||||
LINE: LINEConfig{
|
||||
Enabled: false,
|
||||
ChannelSecret: "",
|
||||
ChannelAccessToken: "",
|
||||
WebhookHost: "0.0.0.0",
|
||||
WebhookPort: 18791,
|
||||
WebhookPath: "/webhook/line",
|
||||
AllowFrom: FlexibleStringSlice{},
|
||||
},
|
||||
OneBot: OneBotConfig{
|
||||
Enabled: false,
|
||||
WSUrl: "ws://127.0.0.1:3001",
|
||||
AccessToken: "",
|
||||
ReconnectInterval: 5,
|
||||
GroupTriggerPrefix: []string{},
|
||||
AllowFrom: FlexibleStringSlice{},
|
||||
},
|
||||
},
|
||||
Providers: ProvidersConfig{
|
||||
Anthropic: ProviderConfig{},
|
||||
OpenAI: ProviderConfig{},
|
||||
OpenRouter: ProviderConfig{},
|
||||
Groq: ProviderConfig{},
|
||||
Zhipu: ProviderConfig{},
|
||||
VLLM: ProviderConfig{},
|
||||
Gemini: ProviderConfig{},
|
||||
Nvidia: ProviderConfig{},
|
||||
Moonshot: ProviderConfig{},
|
||||
ShengSuanYun: ProviderConfig{},
|
||||
Cerebras: ProviderConfig{},
|
||||
VolcEngine: ProviderConfig{},
|
||||
},
|
||||
Gateway: GatewayConfig{
|
||||
Host: "0.0.0.0",
|
||||
Port: 18790,
|
||||
},
|
||||
Tools: ToolsConfig{
|
||||
Web: WebToolsConfig{
|
||||
Brave: BraveConfig{
|
||||
Enabled: false,
|
||||
APIKey: "",
|
||||
MaxResults: 5,
|
||||
},
|
||||
DuckDuckGo: DuckDuckGoConfig{
|
||||
Enabled: true,
|
||||
MaxResults: 5,
|
||||
},
|
||||
Perplexity: PerplexityConfig{
|
||||
Enabled: false,
|
||||
APIKey: "",
|
||||
MaxResults: 5,
|
||||
},
|
||||
},
|
||||
Cron: CronToolsConfig{
|
||||
ExecTimeoutMinutes: 5, // default 5 minutes for LLM operations
|
||||
},
|
||||
},
|
||||
Heartbeat: HeartbeatConfig{
|
||||
Enabled: true,
|
||||
Interval: 30, // default 30 minutes
|
||||
},
|
||||
Devices: DevicesConfig{
|
||||
Enabled: false,
|
||||
MonitorUSB: true,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func LoadConfig(path string) (*Config, error) {
|
||||
cfg := DefaultConfig()
|
||||
|
||||
@@ -528,40 +382,61 @@ func expandHome(path string) string {
|
||||
// GetModelConfig returns the ModelConfig for the given model name.
|
||||
// If multiple configs exist with the same model_name, it uses round-robin
|
||||
// selection for load balancing. Returns an error if the model is not found.
|
||||
// Uses double-check locking for optimal read performance.
|
||||
func (c *Config) GetModelConfig(modelName string) (*ModelConfig, error) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
// First pass: use read lock to find matches
|
||||
c.mu.RLock()
|
||||
matches := c.findMatchesLocked(modelName)
|
||||
if len(matches) == 0 {
|
||||
c.mu.RUnlock()
|
||||
return nil, fmt.Errorf("model %q not found in model_list or providers", modelName)
|
||||
}
|
||||
if len(matches) == 1 {
|
||||
c.mu.RUnlock()
|
||||
return &matches[0], nil
|
||||
}
|
||||
|
||||
// Find all configs with matching model_name
|
||||
// Multiple configs - check if counter exists
|
||||
counter, ok := c.rrCounters[modelName]
|
||||
c.mu.RUnlock()
|
||||
|
||||
// Double-check locking: only acquire write lock if counter needs initialization
|
||||
if !ok {
|
||||
c.mu.Lock()
|
||||
// Re-check after acquiring write lock
|
||||
if c.rrCounters == nil {
|
||||
c.rrCounters = make(map[string]*atomic.Uint64)
|
||||
}
|
||||
if c.rrCounters[modelName] == nil {
|
||||
c.rrCounters[modelName] = &atomic.Uint64{}
|
||||
}
|
||||
counter = c.rrCounters[modelName]
|
||||
c.mu.Unlock()
|
||||
}
|
||||
|
||||
// Re-fetch matches to ensure consistency (ModelList could have changed)
|
||||
c.mu.RLock()
|
||||
matches = c.findMatchesLocked(modelName)
|
||||
c.mu.RUnlock()
|
||||
|
||||
if len(matches) == 0 {
|
||||
return nil, fmt.Errorf("model %q not found in model_list or providers", modelName)
|
||||
}
|
||||
|
||||
idx := counter.Add(1) % uint64(len(matches))
|
||||
return &matches[idx], nil
|
||||
}
|
||||
|
||||
// findMatchesLocked finds all ModelConfig entries with the given model_name.
|
||||
// Must be called with c.mu locked (read or write).
|
||||
func (c *Config) findMatchesLocked(modelName string) []ModelConfig {
|
||||
var matches []ModelConfig
|
||||
for i := range c.ModelList {
|
||||
if c.ModelList[i].ModelName == modelName {
|
||||
matches = append(matches, c.ModelList[i])
|
||||
}
|
||||
}
|
||||
|
||||
if len(matches) == 0 {
|
||||
return nil, fmt.Errorf("model %q not found in model_list or providers", modelName)
|
||||
}
|
||||
|
||||
// Single config - return directly
|
||||
if len(matches) == 1 {
|
||||
return &matches[0], nil
|
||||
}
|
||||
|
||||
// Multiple configs - use round-robin for load balancing
|
||||
if c.rrCounters == nil {
|
||||
c.rrCounters = make(map[string]*atomic.Uint64)
|
||||
}
|
||||
|
||||
counter, ok := c.rrCounters[modelName]
|
||||
if !ok {
|
||||
counter = &atomic.Uint64{}
|
||||
c.rrCounters[modelName] = counter
|
||||
}
|
||||
|
||||
idx := counter.Add(1) % uint64(len(matches))
|
||||
return &matches[idx], nil
|
||||
return matches
|
||||
}
|
||||
|
||||
// HasProvidersConfig checks if any provider in the old providers config has configuration.
|
||||
@@ -599,203 +474,3 @@ func (c *Config) ValidateModelList() error {
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ConvertProvidersToModelList converts the old ProvidersConfig to a slice of ModelConfig.
|
||||
// This enables backward compatibility with existing configurations.
|
||||
func ConvertProvidersToModelList(cfg *Config) []ModelConfig {
|
||||
if cfg == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
var result []ModelConfig
|
||||
p := cfg.Providers
|
||||
|
||||
// OpenAI
|
||||
if p.OpenAI.APIKey != "" || p.OpenAI.APIBase != "" {
|
||||
result = append(result, ModelConfig{
|
||||
ModelName: "openai",
|
||||
Model: "openai/gpt-4o",
|
||||
APIKey: p.OpenAI.APIKey,
|
||||
APIBase: p.OpenAI.APIBase,
|
||||
Proxy: p.OpenAI.Proxy,
|
||||
AuthMethod: p.OpenAI.AuthMethod,
|
||||
})
|
||||
}
|
||||
|
||||
// Anthropic
|
||||
if p.Anthropic.APIKey != "" || p.Anthropic.APIBase != "" {
|
||||
result = append(result, ModelConfig{
|
||||
ModelName: "anthropic",
|
||||
Model: "anthropic/claude-3-sonnet",
|
||||
APIKey: p.Anthropic.APIKey,
|
||||
APIBase: p.Anthropic.APIBase,
|
||||
Proxy: p.Anthropic.Proxy,
|
||||
AuthMethod: p.Anthropic.AuthMethod,
|
||||
})
|
||||
}
|
||||
|
||||
// OpenRouter
|
||||
if p.OpenRouter.APIKey != "" || p.OpenRouter.APIBase != "" {
|
||||
result = append(result, ModelConfig{
|
||||
ModelName: "openrouter",
|
||||
Model: "openrouter/auto",
|
||||
APIKey: p.OpenRouter.APIKey,
|
||||
APIBase: p.OpenRouter.APIBase,
|
||||
Proxy: p.OpenRouter.Proxy,
|
||||
})
|
||||
}
|
||||
|
||||
// Groq
|
||||
if p.Groq.APIKey != "" || p.Groq.APIBase != "" {
|
||||
result = append(result, ModelConfig{
|
||||
ModelName: "groq",
|
||||
Model: "groq/llama-3.1-70b-versatile",
|
||||
APIKey: p.Groq.APIKey,
|
||||
APIBase: p.Groq.APIBase,
|
||||
Proxy: p.Groq.Proxy,
|
||||
})
|
||||
}
|
||||
|
||||
// Zhipu
|
||||
if p.Zhipu.APIKey != "" || p.Zhipu.APIBase != "" {
|
||||
result = append(result, ModelConfig{
|
||||
ModelName: "zhipu",
|
||||
Model: "openai/glm-4",
|
||||
APIKey: p.Zhipu.APIKey,
|
||||
APIBase: p.Zhipu.APIBase,
|
||||
Proxy: p.Zhipu.Proxy,
|
||||
})
|
||||
}
|
||||
|
||||
// VLLM
|
||||
if p.VLLM.APIKey != "" || p.VLLM.APIBase != "" {
|
||||
result = append(result, ModelConfig{
|
||||
ModelName: "vllm",
|
||||
Model: "openai/auto",
|
||||
APIKey: p.VLLM.APIKey,
|
||||
APIBase: p.VLLM.APIBase,
|
||||
Proxy: p.VLLM.Proxy,
|
||||
})
|
||||
}
|
||||
|
||||
// Gemini
|
||||
if p.Gemini.APIKey != "" || p.Gemini.APIBase != "" {
|
||||
result = append(result, ModelConfig{
|
||||
ModelName: "gemini",
|
||||
Model: "openai/gemini-pro",
|
||||
APIKey: p.Gemini.APIKey,
|
||||
APIBase: p.Gemini.APIBase,
|
||||
Proxy: p.Gemini.Proxy,
|
||||
})
|
||||
}
|
||||
|
||||
// Nvidia
|
||||
if p.Nvidia.APIKey != "" || p.Nvidia.APIBase != "" {
|
||||
result = append(result, ModelConfig{
|
||||
ModelName: "nvidia",
|
||||
Model: "nvidia/meta/llama-3.1-8b-instruct",
|
||||
APIKey: p.Nvidia.APIKey,
|
||||
APIBase: p.Nvidia.APIBase,
|
||||
Proxy: p.Nvidia.Proxy,
|
||||
})
|
||||
}
|
||||
|
||||
// Ollama
|
||||
if p.Ollama.APIKey != "" || p.Ollama.APIBase != "" {
|
||||
result = append(result, ModelConfig{
|
||||
ModelName: "ollama",
|
||||
Model: "ollama/llama3",
|
||||
APIKey: p.Ollama.APIKey,
|
||||
APIBase: p.Ollama.APIBase,
|
||||
Proxy: p.Ollama.Proxy,
|
||||
})
|
||||
}
|
||||
|
||||
// Moonshot
|
||||
if p.Moonshot.APIKey != "" || p.Moonshot.APIBase != "" {
|
||||
result = append(result, ModelConfig{
|
||||
ModelName: "moonshot",
|
||||
Model: "moonshot/kimi",
|
||||
APIKey: p.Moonshot.APIKey,
|
||||
APIBase: p.Moonshot.APIBase,
|
||||
Proxy: p.Moonshot.Proxy,
|
||||
})
|
||||
}
|
||||
|
||||
// ShengSuanYun
|
||||
if p.ShengSuanYun.APIKey != "" || p.ShengSuanYun.APIBase != "" {
|
||||
result = append(result, ModelConfig{
|
||||
ModelName: "shengsuanyun",
|
||||
Model: "openai/auto",
|
||||
APIKey: p.ShengSuanYun.APIKey,
|
||||
APIBase: p.ShengSuanYun.APIBase,
|
||||
Proxy: p.ShengSuanYun.Proxy,
|
||||
})
|
||||
}
|
||||
|
||||
// DeepSeek
|
||||
if p.DeepSeek.APIKey != "" || p.DeepSeek.APIBase != "" {
|
||||
result = append(result, ModelConfig{
|
||||
ModelName: "deepseek",
|
||||
Model: "openai/deepseek-chat",
|
||||
APIKey: p.DeepSeek.APIKey,
|
||||
APIBase: p.DeepSeek.APIBase,
|
||||
Proxy: p.DeepSeek.Proxy,
|
||||
})
|
||||
}
|
||||
|
||||
// Cerebras
|
||||
if p.Cerebras.APIKey != "" || p.Cerebras.APIBase != "" {
|
||||
result = append(result, ModelConfig{
|
||||
ModelName: "cerebras",
|
||||
Model: "cerebras/llama-3.3-70b",
|
||||
APIKey: p.Cerebras.APIKey,
|
||||
APIBase: p.Cerebras.APIBase,
|
||||
Proxy: p.Cerebras.Proxy,
|
||||
})
|
||||
}
|
||||
|
||||
// VolcEngine (Doubao)
|
||||
if p.VolcEngine.APIKey != "" || p.VolcEngine.APIBase != "" {
|
||||
result = append(result, ModelConfig{
|
||||
ModelName: "volcengine",
|
||||
Model: "openai/doubao-pro",
|
||||
APIKey: p.VolcEngine.APIKey,
|
||||
APIBase: p.VolcEngine.APIBase,
|
||||
Proxy: p.VolcEngine.Proxy,
|
||||
})
|
||||
}
|
||||
|
||||
// GitHub Copilot
|
||||
if p.GitHubCopilot.APIKey != "" || p.GitHubCopilot.APIBase != "" || p.GitHubCopilot.ConnectMode != "" {
|
||||
result = append(result, ModelConfig{
|
||||
ModelName: "github-copilot",
|
||||
Model: "github-copilot/gpt-4o",
|
||||
APIBase: p.GitHubCopilot.APIBase,
|
||||
ConnectMode: p.GitHubCopilot.ConnectMode,
|
||||
})
|
||||
}
|
||||
|
||||
// Antigravity
|
||||
if p.Antigravity.APIKey != "" || p.Antigravity.AuthMethod != "" {
|
||||
result = append(result, ModelConfig{
|
||||
ModelName: "antigravity",
|
||||
Model: "antigravity/gemini-2.0-flash",
|
||||
APIKey: p.Antigravity.APIKey,
|
||||
AuthMethod: p.Antigravity.AuthMethod,
|
||||
})
|
||||
}
|
||||
|
||||
// Qwen
|
||||
if p.Qwen.APIKey != "" || p.Qwen.APIBase != "" {
|
||||
result = append(result, ModelConfig{
|
||||
ModelName: "qwen",
|
||||
Model: "qwen/qwen-max",
|
||||
APIKey: p.Qwen.APIKey,
|
||||
APIBase: p.Qwen.APIBase,
|
||||
Proxy: p.Qwen.Proxy,
|
||||
})
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
@@ -0,0 +1,136 @@
|
||||
// PicoClaw - Ultra-lightweight personal AI agent
|
||||
// License: MIT
|
||||
//
|
||||
// Copyright (c) 2026 PicoClaw contributors
|
||||
|
||||
package config
|
||||
|
||||
// DefaultConfig returns the default configuration for PicoClaw.
|
||||
func DefaultConfig() *Config {
|
||||
return &Config{
|
||||
Agents: AgentsConfig{
|
||||
Defaults: AgentDefaults{
|
||||
Workspace: "~/.picoclaw/workspace",
|
||||
RestrictToWorkspace: true,
|
||||
Provider: "",
|
||||
Model: "glm-4.7",
|
||||
MaxTokens: 8192,
|
||||
Temperature: 0.7,
|
||||
MaxToolIterations: 20,
|
||||
},
|
||||
},
|
||||
Channels: ChannelsConfig{
|
||||
WhatsApp: WhatsAppConfig{
|
||||
Enabled: false,
|
||||
BridgeURL: "ws://localhost:3001",
|
||||
AllowFrom: FlexibleStringSlice{},
|
||||
},
|
||||
Telegram: TelegramConfig{
|
||||
Enabled: false,
|
||||
Token: "",
|
||||
AllowFrom: FlexibleStringSlice{},
|
||||
},
|
||||
Feishu: FeishuConfig{
|
||||
Enabled: false,
|
||||
AppID: "",
|
||||
AppSecret: "",
|
||||
EncryptKey: "",
|
||||
VerificationToken: "",
|
||||
AllowFrom: FlexibleStringSlice{},
|
||||
},
|
||||
Discord: DiscordConfig{
|
||||
Enabled: false,
|
||||
Token: "",
|
||||
AllowFrom: FlexibleStringSlice{},
|
||||
},
|
||||
MaixCam: MaixCamConfig{
|
||||
Enabled: false,
|
||||
Host: "0.0.0.0",
|
||||
Port: 18790,
|
||||
AllowFrom: FlexibleStringSlice{},
|
||||
},
|
||||
QQ: QQConfig{
|
||||
Enabled: false,
|
||||
AppID: "",
|
||||
AppSecret: "",
|
||||
AllowFrom: FlexibleStringSlice{},
|
||||
},
|
||||
DingTalk: DingTalkConfig{
|
||||
Enabled: false,
|
||||
ClientID: "",
|
||||
ClientSecret: "",
|
||||
AllowFrom: FlexibleStringSlice{},
|
||||
},
|
||||
Slack: SlackConfig{
|
||||
Enabled: false,
|
||||
BotToken: "",
|
||||
AppToken: "",
|
||||
AllowFrom: FlexibleStringSlice{},
|
||||
},
|
||||
LINE: LINEConfig{
|
||||
Enabled: false,
|
||||
ChannelSecret: "",
|
||||
ChannelAccessToken: "",
|
||||
WebhookHost: "0.0.0.0",
|
||||
WebhookPort: 18791,
|
||||
WebhookPath: "/webhook/line",
|
||||
AllowFrom: FlexibleStringSlice{},
|
||||
},
|
||||
OneBot: OneBotConfig{
|
||||
Enabled: false,
|
||||
WSUrl: "ws://127.0.0.1:3001",
|
||||
AccessToken: "",
|
||||
ReconnectInterval: 5,
|
||||
GroupTriggerPrefix: []string{},
|
||||
AllowFrom: FlexibleStringSlice{},
|
||||
},
|
||||
},
|
||||
Providers: ProvidersConfig{
|
||||
Anthropic: ProviderConfig{},
|
||||
OpenAI: ProviderConfig{},
|
||||
OpenRouter: ProviderConfig{},
|
||||
Groq: ProviderConfig{},
|
||||
Zhipu: ProviderConfig{},
|
||||
VLLM: ProviderConfig{},
|
||||
Gemini: ProviderConfig{},
|
||||
Nvidia: ProviderConfig{},
|
||||
Moonshot: ProviderConfig{},
|
||||
ShengSuanYun: ProviderConfig{},
|
||||
Cerebras: ProviderConfig{},
|
||||
VolcEngine: ProviderConfig{},
|
||||
},
|
||||
Gateway: GatewayConfig{
|
||||
Host: "0.0.0.0",
|
||||
Port: 18790,
|
||||
},
|
||||
Tools: ToolsConfig{
|
||||
Web: WebToolsConfig{
|
||||
Brave: BraveConfig{
|
||||
Enabled: false,
|
||||
APIKey: "",
|
||||
MaxResults: 5,
|
||||
},
|
||||
DuckDuckGo: DuckDuckGoConfig{
|
||||
Enabled: true,
|
||||
MaxResults: 5,
|
||||
},
|
||||
Perplexity: PerplexityConfig{
|
||||
Enabled: false,
|
||||
APIKey: "",
|
||||
MaxResults: 5,
|
||||
},
|
||||
},
|
||||
Cron: CronToolsConfig{
|
||||
ExecTimeoutMinutes: 5, // default 5 minutes for LLM operations
|
||||
},
|
||||
},
|
||||
Heartbeat: HeartbeatConfig{
|
||||
Enabled: true,
|
||||
Interval: 30, // default 30 minutes
|
||||
},
|
||||
Devices: DevicesConfig{
|
||||
Enabled: false,
|
||||
MonitorUSB: true,
|
||||
},
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,206 @@
|
||||
// PicoClaw - Ultra-lightweight personal AI agent
|
||||
// License: MIT
|
||||
//
|
||||
// Copyright (c) 2026 PicoClaw contributors
|
||||
|
||||
package config
|
||||
|
||||
// ConvertProvidersToModelList converts the old ProvidersConfig to a slice of ModelConfig.
|
||||
// This enables backward compatibility with existing configurations.
|
||||
func ConvertProvidersToModelList(cfg *Config) []ModelConfig {
|
||||
if cfg == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
var result []ModelConfig
|
||||
p := cfg.Providers
|
||||
|
||||
// OpenAI
|
||||
if p.OpenAI.APIKey != "" || p.OpenAI.APIBase != "" {
|
||||
result = append(result, ModelConfig{
|
||||
ModelName: "openai",
|
||||
Model: "openai/gpt-4o",
|
||||
APIKey: p.OpenAI.APIKey,
|
||||
APIBase: p.OpenAI.APIBase,
|
||||
Proxy: p.OpenAI.Proxy,
|
||||
AuthMethod: p.OpenAI.AuthMethod,
|
||||
})
|
||||
}
|
||||
|
||||
// Anthropic
|
||||
if p.Anthropic.APIKey != "" || p.Anthropic.APIBase != "" {
|
||||
result = append(result, ModelConfig{
|
||||
ModelName: "anthropic",
|
||||
Model: "anthropic/claude-3-sonnet",
|
||||
APIKey: p.Anthropic.APIKey,
|
||||
APIBase: p.Anthropic.APIBase,
|
||||
Proxy: p.Anthropic.Proxy,
|
||||
AuthMethod: p.Anthropic.AuthMethod,
|
||||
})
|
||||
}
|
||||
|
||||
// OpenRouter
|
||||
if p.OpenRouter.APIKey != "" || p.OpenRouter.APIBase != "" {
|
||||
result = append(result, ModelConfig{
|
||||
ModelName: "openrouter",
|
||||
Model: "openrouter/auto",
|
||||
APIKey: p.OpenRouter.APIKey,
|
||||
APIBase: p.OpenRouter.APIBase,
|
||||
Proxy: p.OpenRouter.Proxy,
|
||||
})
|
||||
}
|
||||
|
||||
// Groq
|
||||
if p.Groq.APIKey != "" || p.Groq.APIBase != "" {
|
||||
result = append(result, ModelConfig{
|
||||
ModelName: "groq",
|
||||
Model: "groq/llama-3.1-70b-versatile",
|
||||
APIKey: p.Groq.APIKey,
|
||||
APIBase: p.Groq.APIBase,
|
||||
Proxy: p.Groq.Proxy,
|
||||
})
|
||||
}
|
||||
|
||||
// Zhipu
|
||||
if p.Zhipu.APIKey != "" || p.Zhipu.APIBase != "" {
|
||||
result = append(result, ModelConfig{
|
||||
ModelName: "zhipu",
|
||||
Model: "openai/glm-4",
|
||||
APIKey: p.Zhipu.APIKey,
|
||||
APIBase: p.Zhipu.APIBase,
|
||||
Proxy: p.Zhipu.Proxy,
|
||||
})
|
||||
}
|
||||
|
||||
// VLLM
|
||||
if p.VLLM.APIKey != "" || p.VLLM.APIBase != "" {
|
||||
result = append(result, ModelConfig{
|
||||
ModelName: "vllm",
|
||||
Model: "openai/auto",
|
||||
APIKey: p.VLLM.APIKey,
|
||||
APIBase: p.VLLM.APIBase,
|
||||
Proxy: p.VLLM.Proxy,
|
||||
})
|
||||
}
|
||||
|
||||
// Gemini
|
||||
if p.Gemini.APIKey != "" || p.Gemini.APIBase != "" {
|
||||
result = append(result, ModelConfig{
|
||||
ModelName: "gemini",
|
||||
Model: "openai/gemini-pro",
|
||||
APIKey: p.Gemini.APIKey,
|
||||
APIBase: p.Gemini.APIBase,
|
||||
Proxy: p.Gemini.Proxy,
|
||||
})
|
||||
}
|
||||
|
||||
// Nvidia
|
||||
if p.Nvidia.APIKey != "" || p.Nvidia.APIBase != "" {
|
||||
result = append(result, ModelConfig{
|
||||
ModelName: "nvidia",
|
||||
Model: "nvidia/meta/llama-3.1-8b-instruct",
|
||||
APIKey: p.Nvidia.APIKey,
|
||||
APIBase: p.Nvidia.APIBase,
|
||||
Proxy: p.Nvidia.Proxy,
|
||||
})
|
||||
}
|
||||
|
||||
// Ollama
|
||||
if p.Ollama.APIKey != "" || p.Ollama.APIBase != "" {
|
||||
result = append(result, ModelConfig{
|
||||
ModelName: "ollama",
|
||||
Model: "ollama/llama3",
|
||||
APIKey: p.Ollama.APIKey,
|
||||
APIBase: p.Ollama.APIBase,
|
||||
Proxy: p.Ollama.Proxy,
|
||||
})
|
||||
}
|
||||
|
||||
// Moonshot
|
||||
if p.Moonshot.APIKey != "" || p.Moonshot.APIBase != "" {
|
||||
result = append(result, ModelConfig{
|
||||
ModelName: "moonshot",
|
||||
Model: "moonshot/kimi",
|
||||
APIKey: p.Moonshot.APIKey,
|
||||
APIBase: p.Moonshot.APIBase,
|
||||
Proxy: p.Moonshot.Proxy,
|
||||
})
|
||||
}
|
||||
|
||||
// ShengSuanYun
|
||||
if p.ShengSuanYun.APIKey != "" || p.ShengSuanYun.APIBase != "" {
|
||||
result = append(result, ModelConfig{
|
||||
ModelName: "shengsuanyun",
|
||||
Model: "openai/auto",
|
||||
APIKey: p.ShengSuanYun.APIKey,
|
||||
APIBase: p.ShengSuanYun.APIBase,
|
||||
Proxy: p.ShengSuanYun.Proxy,
|
||||
})
|
||||
}
|
||||
|
||||
// DeepSeek
|
||||
if p.DeepSeek.APIKey != "" || p.DeepSeek.APIBase != "" {
|
||||
result = append(result, ModelConfig{
|
||||
ModelName: "deepseek",
|
||||
Model: "openai/deepseek-chat",
|
||||
APIKey: p.DeepSeek.APIKey,
|
||||
APIBase: p.DeepSeek.APIBase,
|
||||
Proxy: p.DeepSeek.Proxy,
|
||||
})
|
||||
}
|
||||
|
||||
// Cerebras
|
||||
if p.Cerebras.APIKey != "" || p.Cerebras.APIBase != "" {
|
||||
result = append(result, ModelConfig{
|
||||
ModelName: "cerebras",
|
||||
Model: "cerebras/llama-3.3-70b",
|
||||
APIKey: p.Cerebras.APIKey,
|
||||
APIBase: p.Cerebras.APIBase,
|
||||
Proxy: p.Cerebras.Proxy,
|
||||
})
|
||||
}
|
||||
|
||||
// VolcEngine (Doubao)
|
||||
if p.VolcEngine.APIKey != "" || p.VolcEngine.APIBase != "" {
|
||||
result = append(result, ModelConfig{
|
||||
ModelName: "volcengine",
|
||||
Model: "openai/doubao-pro",
|
||||
APIKey: p.VolcEngine.APIKey,
|
||||
APIBase: p.VolcEngine.APIBase,
|
||||
Proxy: p.VolcEngine.Proxy,
|
||||
})
|
||||
}
|
||||
|
||||
// GitHub Copilot
|
||||
if p.GitHubCopilot.APIKey != "" || p.GitHubCopilot.APIBase != "" || p.GitHubCopilot.ConnectMode != "" {
|
||||
result = append(result, ModelConfig{
|
||||
ModelName: "github-copilot",
|
||||
Model: "github-copilot/gpt-4o",
|
||||
APIBase: p.GitHubCopilot.APIBase,
|
||||
ConnectMode: p.GitHubCopilot.ConnectMode,
|
||||
})
|
||||
}
|
||||
|
||||
// Antigravity
|
||||
if p.Antigravity.APIKey != "" || p.Antigravity.AuthMethod != "" {
|
||||
result = append(result, ModelConfig{
|
||||
ModelName: "antigravity",
|
||||
Model: "antigravity/gemini-2.0-flash",
|
||||
APIKey: p.Antigravity.APIKey,
|
||||
AuthMethod: p.Antigravity.AuthMethod,
|
||||
})
|
||||
}
|
||||
|
||||
// Qwen
|
||||
if p.Qwen.APIKey != "" || p.Qwen.APIBase != "" {
|
||||
result = append(result, ModelConfig{
|
||||
ModelName: "qwen",
|
||||
Model: "qwen/qwen-max",
|
||||
APIKey: p.Qwen.APIKey,
|
||||
APIBase: p.Qwen.APIBase,
|
||||
Proxy: p.Qwen.Proxy,
|
||||
})
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
@@ -0,0 +1,177 @@
|
||||
// PicoClaw - Ultra-lightweight personal AI agent
|
||||
// License: MIT
|
||||
//
|
||||
// Copyright (c) 2026 PicoClaw contributors
|
||||
|
||||
package config
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestConvertProvidersToModelList_OpenAI(t *testing.T) {
|
||||
cfg := &Config{
|
||||
Providers: ProvidersConfig{
|
||||
OpenAI: ProviderConfig{
|
||||
APIKey: "sk-test-key",
|
||||
APIBase: "https://custom.api.com/v1",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
result := ConvertProvidersToModelList(cfg)
|
||||
|
||||
if len(result) != 1 {
|
||||
t.Fatalf("len(result) = %d, want 1", len(result))
|
||||
}
|
||||
|
||||
if result[0].ModelName != "openai" {
|
||||
t.Errorf("ModelName = %q, want %q", result[0].ModelName, "openai")
|
||||
}
|
||||
if result[0].Model != "openai/gpt-4o" {
|
||||
t.Errorf("Model = %q, want %q", result[0].Model, "openai/gpt-4o")
|
||||
}
|
||||
if result[0].APIKey != "sk-test-key" {
|
||||
t.Errorf("APIKey = %q, want %q", result[0].APIKey, "sk-test-key")
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertProvidersToModelList_Anthropic(t *testing.T) {
|
||||
cfg := &Config{
|
||||
Providers: ProvidersConfig{
|
||||
Anthropic: ProviderConfig{
|
||||
APIKey: "ant-key",
|
||||
APIBase: "https://custom.anthropic.com",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
result := ConvertProvidersToModelList(cfg)
|
||||
|
||||
if len(result) != 1 {
|
||||
t.Fatalf("len(result) = %d, want 1", len(result))
|
||||
}
|
||||
|
||||
if result[0].ModelName != "anthropic" {
|
||||
t.Errorf("ModelName = %q, want %q", result[0].ModelName, "anthropic")
|
||||
}
|
||||
if result[0].Model != "anthropic/claude-3-sonnet" {
|
||||
t.Errorf("Model = %q, want %q", result[0].Model, "anthropic/claude-3-sonnet")
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertProvidersToModelList_Multiple(t *testing.T) {
|
||||
cfg := &Config{
|
||||
Providers: ProvidersConfig{
|
||||
OpenAI: ProviderConfig{APIKey: "openai-key"},
|
||||
Groq: ProviderConfig{APIKey: "groq-key"},
|
||||
Zhipu: ProviderConfig{APIKey: "zhipu-key"},
|
||||
},
|
||||
}
|
||||
|
||||
result := ConvertProvidersToModelList(cfg)
|
||||
|
||||
if len(result) != 3 {
|
||||
t.Fatalf("len(result) = %d, want 3", len(result))
|
||||
}
|
||||
|
||||
// Check that all providers are present
|
||||
found := make(map[string]bool)
|
||||
for _, mc := range result {
|
||||
found[mc.ModelName] = true
|
||||
}
|
||||
|
||||
for _, name := range []string{"openai", "groq", "zhipu"} {
|
||||
if !found[name] {
|
||||
t.Errorf("Missing provider %q in result", name)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertProvidersToModelList_Empty(t *testing.T) {
|
||||
cfg := &Config{
|
||||
Providers: ProvidersConfig{},
|
||||
}
|
||||
|
||||
result := ConvertProvidersToModelList(cfg)
|
||||
|
||||
if len(result) != 0 {
|
||||
t.Errorf("len(result) = %d, want 0", len(result))
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertProvidersToModelList_Nil(t *testing.T) {
|
||||
result := ConvertProvidersToModelList(nil)
|
||||
|
||||
if result != nil {
|
||||
t.Errorf("result = %v, want nil", result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertProvidersToModelList_AllProviders(t *testing.T) {
|
||||
cfg := &Config{
|
||||
Providers: ProvidersConfig{
|
||||
OpenAI: ProviderConfig{APIKey: "key1"},
|
||||
Anthropic: ProviderConfig{APIKey: "key2"},
|
||||
OpenRouter: ProviderConfig{APIKey: "key3"},
|
||||
Groq: ProviderConfig{APIKey: "key4"},
|
||||
Zhipu: ProviderConfig{APIKey: "key5"},
|
||||
VLLM: ProviderConfig{APIKey: "key6"},
|
||||
Gemini: ProviderConfig{APIKey: "key7"},
|
||||
Nvidia: ProviderConfig{APIKey: "key8"},
|
||||
Ollama: ProviderConfig{APIKey: "key9"},
|
||||
Moonshot: ProviderConfig{APIKey: "key10"},
|
||||
ShengSuanYun: ProviderConfig{APIKey: "key11"},
|
||||
DeepSeek: ProviderConfig{APIKey: "key12"},
|
||||
Cerebras: ProviderConfig{APIKey: "key13"},
|
||||
VolcEngine: ProviderConfig{APIKey: "key14"},
|
||||
GitHubCopilot: ProviderConfig{ConnectMode: "grpc"},
|
||||
Antigravity: ProviderConfig{AuthMethod: "oauth"},
|
||||
Qwen: ProviderConfig{APIKey: "key17"},
|
||||
},
|
||||
}
|
||||
|
||||
result := ConvertProvidersToModelList(cfg)
|
||||
|
||||
// All 17 providers should be converted
|
||||
if len(result) != 17 {
|
||||
t.Errorf("len(result) = %d, want 17", len(result))
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertProvidersToModelList_Proxy(t *testing.T) {
|
||||
cfg := &Config{
|
||||
Providers: ProvidersConfig{
|
||||
OpenAI: ProviderConfig{
|
||||
APIKey: "key",
|
||||
Proxy: "http://proxy:8080",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
result := ConvertProvidersToModelList(cfg)
|
||||
|
||||
if len(result) != 1 {
|
||||
t.Fatalf("len(result) = %d, want 1", len(result))
|
||||
}
|
||||
|
||||
if result[0].Proxy != "http://proxy:8080" {
|
||||
t.Errorf("Proxy = %q, want %q", result[0].Proxy, "http://proxy:8080")
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertProvidersToModelList_AuthMethod(t *testing.T) {
|
||||
cfg := &Config{
|
||||
Providers: ProvidersConfig{
|
||||
OpenAI: ProviderConfig{
|
||||
AuthMethod: "oauth",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
result := ConvertProvidersToModelList(cfg)
|
||||
|
||||
if len(result) != 0 {
|
||||
t.Errorf("len(result) = %d, want 0 (AuthMethod alone should not create entry)", len(result))
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,204 @@
|
||||
// PicoClaw - Ultra-lightweight personal AI agent
|
||||
// License: MIT
|
||||
//
|
||||
// Copyright (c) 2026 PicoClaw contributors
|
||||
|
||||
package config
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestGetModelConfig_Found(t *testing.T) {
|
||||
cfg := &Config{
|
||||
ModelList: []ModelConfig{
|
||||
{ModelName: "test-model", Model: "openai/gpt-4o", APIKey: "key1"},
|
||||
{ModelName: "other-model", Model: "anthropic/claude", APIKey: "key2"},
|
||||
},
|
||||
}
|
||||
|
||||
result, err := cfg.GetModelConfig("test-model")
|
||||
if err != nil {
|
||||
t.Fatalf("GetModelConfig() error = %v", err)
|
||||
}
|
||||
if result.Model != "openai/gpt-4o" {
|
||||
t.Errorf("Model = %q, want %q", result.Model, "openai/gpt-4o")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetModelConfig_NotFound(t *testing.T) {
|
||||
cfg := &Config{
|
||||
ModelList: []ModelConfig{
|
||||
{ModelName: "test-model", Model: "openai/gpt-4o", APIKey: "key1"},
|
||||
},
|
||||
}
|
||||
|
||||
_, err := cfg.GetModelConfig("nonexistent")
|
||||
if err == nil {
|
||||
t.Fatal("GetModelConfig() expected error for nonexistent model")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetModelConfig_EmptyList(t *testing.T) {
|
||||
cfg := &Config{
|
||||
ModelList: []ModelConfig{},
|
||||
}
|
||||
|
||||
_, err := cfg.GetModelConfig("any-model")
|
||||
if err == nil {
|
||||
t.Fatal("GetModelConfig() expected error for empty model list")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetModelConfig_RoundRobin(t *testing.T) {
|
||||
cfg := &Config{
|
||||
ModelList: []ModelConfig{
|
||||
{ModelName: "lb-model", Model: "openai/gpt-4o-1", APIKey: "key1"},
|
||||
{ModelName: "lb-model", Model: "openai/gpt-4o-2", APIKey: "key2"},
|
||||
{ModelName: "lb-model", Model: "openai/gpt-4o-3", APIKey: "key3"},
|
||||
},
|
||||
}
|
||||
|
||||
// Test round-robin distribution
|
||||
results := make(map[string]int)
|
||||
for i := 0; i < 30; i++ {
|
||||
result, err := cfg.GetModelConfig("lb-model")
|
||||
if err != nil {
|
||||
t.Fatalf("GetModelConfig() error = %v", err)
|
||||
}
|
||||
results[result.Model]++
|
||||
}
|
||||
|
||||
// Each model should appear roughly 10 times (30 calls / 3 models)
|
||||
for model, count := range results {
|
||||
if count < 5 || count > 15 {
|
||||
t.Errorf("Model %s appeared %d times, expected ~10", model, count)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetModelConfig_Concurrent(t *testing.T) {
|
||||
cfg := &Config{
|
||||
ModelList: []ModelConfig{
|
||||
{ModelName: "concurrent-model", Model: "openai/gpt-4o-1", APIKey: "key1"},
|
||||
{ModelName: "concurrent-model", Model: "openai/gpt-4o-2", APIKey: "key2"},
|
||||
},
|
||||
}
|
||||
|
||||
const goroutines = 100
|
||||
const iterations = 10
|
||||
|
||||
var wg sync.WaitGroup
|
||||
errors := make(chan error, goroutines*iterations)
|
||||
|
||||
for i := 0; i < goroutines; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
for j := 0; j < iterations; j++ {
|
||||
_, err := cfg.GetModelConfig("concurrent-model")
|
||||
if err != nil {
|
||||
errors <- err
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
close(errors)
|
||||
|
||||
for err := range errors {
|
||||
t.Errorf("Concurrent GetModelConfig() error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestModelConfig_Validate(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
config ModelConfig
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "valid config",
|
||||
config: ModelConfig{
|
||||
ModelName: "test",
|
||||
Model: "openai/gpt-4o",
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "missing model_name",
|
||||
config: ModelConfig{
|
||||
Model: "openai/gpt-4o",
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "missing model",
|
||||
config: ModelConfig{
|
||||
ModelName: "test",
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "empty config",
|
||||
config: ModelConfig{},
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := tt.config.Validate()
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("Validate() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfig_ValidateModelList(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
config *Config
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "valid list",
|
||||
config: &Config{
|
||||
ModelList: []ModelConfig{
|
||||
{ModelName: "test1", Model: "openai/gpt-4o"},
|
||||
{ModelName: "test2", Model: "anthropic/claude"},
|
||||
},
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "invalid entry",
|
||||
config: &Config{
|
||||
ModelList: []ModelConfig{
|
||||
{ModelName: "test1", Model: "openai/gpt-4o"},
|
||||
{ModelName: "", Model: "anthropic/claude"}, // missing model_name
|
||||
},
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "empty list",
|
||||
config: &Config{
|
||||
ModelList: []ModelConfig{},
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := tt.config.ValidateModelList()
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("ValidateModelList() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -180,8 +180,8 @@ func TestConvertConfig(t *testing.T) {
|
||||
t.Run("unsupported provider warning", func(t *testing.T) {
|
||||
data := map[string]interface{}{
|
||||
"providers": map[string]interface{}{
|
||||
"deepseek": map[string]interface{}{
|
||||
"api_key": "sk-deep-test",
|
||||
"unknown_provider": map[string]interface{}{
|
||||
"api_key": "sk-test",
|
||||
},
|
||||
},
|
||||
}
|
||||
@@ -193,7 +193,7 @@ func TestConvertConfig(t *testing.T) {
|
||||
if len(warnings) != 1 {
|
||||
t.Fatalf("expected 1 warning, got %d", len(warnings))
|
||||
}
|
||||
if warnings[0] != "Provider 'deepseek' not supported in PicoClaw, skipping" {
|
||||
if warnings[0] != "Provider 'unknown_provider' not supported in PicoClaw, skipping" {
|
||||
t.Errorf("unexpected warning: %s", warnings[0])
|
||||
}
|
||||
})
|
||||
|
||||
@@ -419,7 +419,7 @@ func TestCreateProvider_ClaudeCli(t *testing.T) {
|
||||
cfg.Agents.Defaults.Provider = "claude-cli"
|
||||
cfg.Agents.Defaults.Workspace = "/test/ws"
|
||||
|
||||
provider, err := CreateProvider(cfg)
|
||||
provider, _, err := CreateProvider(cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("CreateProvider(claude-cli) error = %v", err)
|
||||
}
|
||||
@@ -437,7 +437,7 @@ func TestCreateProvider_ClaudeCode(t *testing.T) {
|
||||
cfg := config.DefaultConfig()
|
||||
cfg.Agents.Defaults.Provider = "claude-code"
|
||||
|
||||
provider, err := CreateProvider(cfg)
|
||||
provider, _, err := CreateProvider(cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("CreateProvider(claude-code) error = %v", err)
|
||||
}
|
||||
@@ -450,7 +450,7 @@ func TestCreateProvider_ClaudeCodec(t *testing.T) {
|
||||
cfg := config.DefaultConfig()
|
||||
cfg.Agents.Defaults.Provider = "claudecode"
|
||||
|
||||
provider, err := CreateProvider(cfg)
|
||||
provider, _, err := CreateProvider(cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("CreateProvider(claudecode) error = %v", err)
|
||||
}
|
||||
@@ -464,7 +464,7 @@ func TestCreateProvider_ClaudeCliDefaultWorkspace(t *testing.T) {
|
||||
cfg.Agents.Defaults.Provider = "claude-cli"
|
||||
cfg.Agents.Defaults.Workspace = ""
|
||||
|
||||
provider, err := CreateProvider(cfg)
|
||||
provider, _, err := CreateProvider(cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("CreateProvider error = %v", err)
|
||||
}
|
||||
|
||||
@@ -32,13 +32,14 @@ func ExtractProtocol(model string) (protocol, modelID string) {
|
||||
// CreateProviderFromConfig creates a provider based on the ModelConfig.
|
||||
// It uses the protocol prefix in the Model field to determine which provider to create.
|
||||
// Supported protocols: openai, anthropic, antigravity, claude-cli, codex-cli, github-copilot
|
||||
func CreateProviderFromConfig(cfg *config.ModelConfig) (LLMProvider, error) {
|
||||
// Returns the provider, the model ID (without protocol prefix), and any error.
|
||||
func CreateProviderFromConfig(cfg *config.ModelConfig) (LLMProvider, string, error) {
|
||||
if cfg == nil {
|
||||
return nil, fmt.Errorf("config is nil")
|
||||
return nil, "", fmt.Errorf("config is nil")
|
||||
}
|
||||
|
||||
if cfg.Model == "" {
|
||||
return nil, fmt.Errorf("model is required")
|
||||
return nil, "", fmt.Errorf("model is required")
|
||||
}
|
||||
|
||||
protocol, modelID := ExtractProtocol(cfg.Model)
|
||||
@@ -49,36 +50,36 @@ func CreateProviderFromConfig(cfg *config.ModelConfig) (LLMProvider, error) {
|
||||
"volcengine", "vllm", "qwen":
|
||||
// All OpenAI-compatible HTTP providers
|
||||
if cfg.APIKey == "" && cfg.APIBase == "" {
|
||||
return nil, fmt.Errorf("api_key or api_base is required for HTTP-based protocol %q", protocol)
|
||||
return nil, "", fmt.Errorf("api_key or api_base is required for HTTP-based protocol %q", protocol)
|
||||
}
|
||||
apiBase := cfg.APIBase
|
||||
if apiBase == "" {
|
||||
apiBase = getDefaultAPIBase(protocol)
|
||||
}
|
||||
return NewHTTPProvider(cfg.APIKey, apiBase, cfg.Proxy), nil
|
||||
return NewHTTPProvider(cfg.APIKey, apiBase, cfg.Proxy), modelID, nil
|
||||
|
||||
case "anthropic":
|
||||
if cfg.AuthMethod == "oauth" || cfg.AuthMethod == "token" {
|
||||
// Use Claude SDK with token
|
||||
return NewClaudeProvider(cfg.APIKey), nil
|
||||
return NewClaudeProvider(cfg.APIKey), modelID, nil
|
||||
}
|
||||
// Use HTTP API
|
||||
apiBase := cfg.APIBase
|
||||
if apiBase == "" {
|
||||
apiBase = "https://api.anthropic.com/v1"
|
||||
}
|
||||
return NewHTTPProvider(cfg.APIKey, apiBase, cfg.Proxy), nil
|
||||
return NewHTTPProvider(cfg.APIKey, apiBase, cfg.Proxy), modelID, nil
|
||||
|
||||
case "antigravity":
|
||||
return NewAntigravityProvider(), nil
|
||||
return NewAntigravityProvider(), modelID, nil
|
||||
|
||||
case "claude-cli", "claudecli":
|
||||
workspace := "."
|
||||
return NewClaudeCliProvider(workspace), nil
|
||||
return NewClaudeCliProvider(workspace), modelID, nil
|
||||
|
||||
case "codex-cli", "codexcli":
|
||||
workspace := "."
|
||||
return NewCodexCliProvider(workspace), nil
|
||||
return NewCodexCliProvider(workspace), modelID, nil
|
||||
|
||||
case "github-copilot", "copilot":
|
||||
apiBase := cfg.APIBase
|
||||
@@ -89,10 +90,14 @@ func CreateProviderFromConfig(cfg *config.ModelConfig) (LLMProvider, error) {
|
||||
if connectMode == "" {
|
||||
connectMode = "grpc"
|
||||
}
|
||||
return NewGitHubCopilotProvider(apiBase, connectMode, modelID)
|
||||
provider, err := NewGitHubCopilotProvider(apiBase, connectMode, modelID)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
return provider, modelID, nil
|
||||
|
||||
default:
|
||||
return nil, fmt.Errorf("unknown protocol %q in model %q", protocol, cfg.Model)
|
||||
return nil, "", fmt.Errorf("unknown protocol %q in model %q", protocol, cfg.Model)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -0,0 +1,250 @@
|
||||
// PicoClaw - Ultra-lightweight personal AI agent
|
||||
// License: MIT
|
||||
//
|
||||
// Copyright (c) 2026 PicoClaw contributors
|
||||
|
||||
package providers
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
)
|
||||
|
||||
func TestExtractProtocol(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
model string
|
||||
wantProtocol string
|
||||
wantModelID string
|
||||
}{
|
||||
{
|
||||
name: "openai with prefix",
|
||||
model: "openai/gpt-4o",
|
||||
wantProtocol: "openai",
|
||||
wantModelID: "gpt-4o",
|
||||
},
|
||||
{
|
||||
name: "anthropic with prefix",
|
||||
model: "anthropic/claude-3-sonnet",
|
||||
wantProtocol: "anthropic",
|
||||
wantModelID: "claude-3-sonnet",
|
||||
},
|
||||
{
|
||||
name: "no prefix - defaults to openai",
|
||||
model: "gpt-4o",
|
||||
wantProtocol: "openai",
|
||||
wantModelID: "gpt-4o",
|
||||
},
|
||||
{
|
||||
name: "groq with prefix",
|
||||
model: "groq/llama-3.1-70b",
|
||||
wantProtocol: "groq",
|
||||
wantModelID: "llama-3.1-70b",
|
||||
},
|
||||
{
|
||||
name: "empty string",
|
||||
model: "",
|
||||
wantProtocol: "openai",
|
||||
wantModelID: "",
|
||||
},
|
||||
{
|
||||
name: "with whitespace",
|
||||
model: " openai/gpt-4 ",
|
||||
wantProtocol: "openai",
|
||||
wantModelID: "gpt-4",
|
||||
},
|
||||
{
|
||||
name: "multiple slashes",
|
||||
model: "nvidia/meta/llama-3.1-8b",
|
||||
wantProtocol: "nvidia",
|
||||
wantModelID: "meta/llama-3.1-8b",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
protocol, modelID := ExtractProtocol(tt.model)
|
||||
if protocol != tt.wantProtocol {
|
||||
t.Errorf("ExtractProtocol(%q) protocol = %q, want %q", tt.model, protocol, tt.wantProtocol)
|
||||
}
|
||||
if modelID != tt.wantModelID {
|
||||
t.Errorf("ExtractProtocol(%q) modelID = %q, want %q", tt.model, modelID, tt.wantModelID)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateProviderFromConfig_OpenAI(t *testing.T) {
|
||||
cfg := &config.ModelConfig{
|
||||
ModelName: "test-openai",
|
||||
Model: "openai/gpt-4o",
|
||||
APIKey: "test-key",
|
||||
APIBase: "https://api.example.com/v1",
|
||||
}
|
||||
|
||||
provider, modelID, err := CreateProviderFromConfig(cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("CreateProviderFromConfig() error = %v", err)
|
||||
}
|
||||
if provider == nil {
|
||||
t.Fatal("CreateProviderFromConfig() returned nil provider")
|
||||
}
|
||||
if modelID != "gpt-4o" {
|
||||
t.Errorf("modelID = %q, want %q", modelID, "gpt-4o")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateProviderFromConfig_DefaultAPIBase(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
protocol string
|
||||
wantBase string
|
||||
}{
|
||||
{"openai", "openai", "https://api.openai.com/v1"},
|
||||
{"groq", "groq", "https://api.groq.com/openai/v1"},
|
||||
{"openrouter", "openrouter", "https://openrouter.ai/api/v1"},
|
||||
{"cerebras", "cerebras", "https://api.cerebras.ai/v1"},
|
||||
{"qwen", "qwen", "https://dashscope.aliyuncs.com/compatible-mode/v1"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
cfg := &config.ModelConfig{
|
||||
ModelName: "test-" + tt.protocol,
|
||||
Model: tt.protocol + "/test-model",
|
||||
APIKey: "test-key",
|
||||
}
|
||||
|
||||
provider, _, err := CreateProviderFromConfig(cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("CreateProviderFromConfig() error = %v", err)
|
||||
}
|
||||
|
||||
httpProvider, ok := provider.(*HTTPProvider)
|
||||
if !ok {
|
||||
t.Fatalf("expected *HTTPProvider, got %T", provider)
|
||||
}
|
||||
if httpProvider.apiBase != tt.wantBase {
|
||||
t.Errorf("apiBase = %q, want %q", httpProvider.apiBase, tt.wantBase)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateProviderFromConfig_Anthropic(t *testing.T) {
|
||||
cfg := &config.ModelConfig{
|
||||
ModelName: "test-anthropic",
|
||||
Model: "anthropic/claude-3-sonnet",
|
||||
APIKey: "test-key",
|
||||
}
|
||||
|
||||
provider, modelID, err := CreateProviderFromConfig(cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("CreateProviderFromConfig() error = %v", err)
|
||||
}
|
||||
if provider == nil {
|
||||
t.Fatal("CreateProviderFromConfig() returned nil provider")
|
||||
}
|
||||
if modelID != "claude-3-sonnet" {
|
||||
t.Errorf("modelID = %q, want %q", modelID, "claude-3-sonnet")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateProviderFromConfig_Antigravity(t *testing.T) {
|
||||
cfg := &config.ModelConfig{
|
||||
ModelName: "test-antigravity",
|
||||
Model: "antigravity/gemini-2.0-flash",
|
||||
}
|
||||
|
||||
provider, modelID, err := CreateProviderFromConfig(cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("CreateProviderFromConfig() error = %v", err)
|
||||
}
|
||||
if provider == nil {
|
||||
t.Fatal("CreateProviderFromConfig() returned nil provider")
|
||||
}
|
||||
if modelID != "gemini-2.0-flash" {
|
||||
t.Errorf("modelID = %q, want %q", modelID, "gemini-2.0-flash")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateProviderFromConfig_ClaudeCLI(t *testing.T) {
|
||||
cfg := &config.ModelConfig{
|
||||
ModelName: "test-claude-cli",
|
||||
Model: "claude-cli/claude-sonnet-4-20250514",
|
||||
}
|
||||
|
||||
provider, modelID, err := CreateProviderFromConfig(cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("CreateProviderFromConfig() error = %v", err)
|
||||
}
|
||||
if provider == nil {
|
||||
t.Fatal("CreateProviderFromConfig() returned nil provider")
|
||||
}
|
||||
if modelID != "claude-sonnet-4-20250514" {
|
||||
t.Errorf("modelID = %q, want %q", modelID, "claude-sonnet-4-20250514")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateProviderFromConfig_CodexCLI(t *testing.T) {
|
||||
cfg := &config.ModelConfig{
|
||||
ModelName: "test-codex-cli",
|
||||
Model: "codex-cli/codex",
|
||||
}
|
||||
|
||||
provider, modelID, err := CreateProviderFromConfig(cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("CreateProviderFromConfig() error = %v", err)
|
||||
}
|
||||
if provider == nil {
|
||||
t.Fatal("CreateProviderFromConfig() returned nil provider")
|
||||
}
|
||||
if modelID != "codex" {
|
||||
t.Errorf("modelID = %q, want %q", modelID, "codex")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateProviderFromConfig_MissingAPIKey(t *testing.T) {
|
||||
cfg := &config.ModelConfig{
|
||||
ModelName: "test-no-key",
|
||||
Model: "openai/gpt-4o",
|
||||
}
|
||||
|
||||
_, _, err := CreateProviderFromConfig(cfg)
|
||||
if err == nil {
|
||||
t.Fatal("CreateProviderFromConfig() expected error for missing API key")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateProviderFromConfig_UnknownProtocol(t *testing.T) {
|
||||
cfg := &config.ModelConfig{
|
||||
ModelName: "test-unknown",
|
||||
Model: "unknown-protocol/model",
|
||||
APIKey: "test-key",
|
||||
}
|
||||
|
||||
_, _, err := CreateProviderFromConfig(cfg)
|
||||
if err == nil {
|
||||
t.Fatal("CreateProviderFromConfig() expected error for unknown protocol")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateProviderFromConfig_NilConfig(t *testing.T) {
|
||||
_, _, err := CreateProviderFromConfig(nil)
|
||||
if err == nil {
|
||||
t.Fatal("CreateProviderFromConfig(nil) expected error")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateProviderFromConfig_EmptyModel(t *testing.T) {
|
||||
cfg := &config.ModelConfig{
|
||||
ModelName: "test-empty",
|
||||
Model: "",
|
||||
}
|
||||
|
||||
_, _, err := CreateProviderFromConfig(cfg)
|
||||
if err == nil {
|
||||
t.Fatal("CreateProviderFromConfig() expected error for empty model")
|
||||
}
|
||||
}
|
||||
@@ -16,9 +16,6 @@ import (
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/auth"
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
)
|
||||
|
||||
type HTTPProvider struct {
|
||||
@@ -161,13 +158,15 @@ func (p *HTTPProvider) parseResponse(body []byte) (*LLMResponse, error) {
|
||||
arguments := make(map[string]interface{})
|
||||
name := ""
|
||||
thoughtSignature := ""
|
||||
argsStr := ""
|
||||
|
||||
if tc.Function != nil {
|
||||
name = tc.Function.Name
|
||||
thoughtSignature = tc.Function.ThoughtSignature
|
||||
if tc.Function.Arguments != "" {
|
||||
if err := json.Unmarshal([]byte(tc.Function.Arguments), &arguments); err != nil {
|
||||
arguments["raw"] = tc.Function.Arguments
|
||||
argsStr = tc.Function.Arguments
|
||||
if argsStr != "" {
|
||||
if err := json.Unmarshal([]byte(argsStr), &arguments); err != nil {
|
||||
arguments["raw"] = argsStr
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -177,7 +176,7 @@ func (p *HTTPProvider) parseResponse(body []byte) (*LLMResponse, error) {
|
||||
Type: tc.Type,
|
||||
Function: &FunctionCall{
|
||||
Name: name,
|
||||
Arguments: tc.Function.Arguments,
|
||||
Arguments: argsStr,
|
||||
ThoughtSignature: thoughtSignature,
|
||||
},
|
||||
Name: name,
|
||||
@@ -196,328 +195,3 @@ func (p *HTTPProvider) parseResponse(body []byte) (*LLMResponse, error) {
|
||||
func (p *HTTPProvider) GetDefaultModel() string {
|
||||
return ""
|
||||
}
|
||||
|
||||
func createClaudeAuthProvider() (LLMProvider, error) {
|
||||
cred, err := auth.GetCredential("anthropic")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("loading auth credentials: %w", err)
|
||||
}
|
||||
if cred == nil {
|
||||
return nil, fmt.Errorf("no credentials for anthropic. Run: picoclaw auth login --provider anthropic")
|
||||
}
|
||||
return NewClaudeProviderWithTokenSource(cred.AccessToken, createClaudeTokenSource()), nil
|
||||
}
|
||||
|
||||
func createCodexAuthProvider() (LLMProvider, error) {
|
||||
cred, err := auth.GetCredential("openai")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("loading auth credentials: %w", err)
|
||||
}
|
||||
if cred == nil {
|
||||
return nil, fmt.Errorf("no credentials for openai. Run: picoclaw auth login --provider openai")
|
||||
}
|
||||
return NewCodexProviderWithTokenSource(cred.AccessToken, cred.AccountID, createCodexTokenSource()), nil
|
||||
}
|
||||
|
||||
func CreateProvider(cfg *config.Config) (LLMProvider, error) {
|
||||
model := cfg.Agents.Defaults.Model
|
||||
|
||||
// First, try to use model_list configuration
|
||||
if len(cfg.ModelList) > 0 {
|
||||
// Try to get config by model name first
|
||||
modelCfg, err := cfg.GetModelConfig(model)
|
||||
if err == nil {
|
||||
// Found in model_list, use factory to create provider
|
||||
provider, err := CreateProviderFromConfig(modelCfg)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create provider from model_list: %w", err)
|
||||
}
|
||||
return provider, nil
|
||||
}
|
||||
// Model not found in model_list, fall through to providers config
|
||||
}
|
||||
|
||||
// Log deprecation warning if using old providers config
|
||||
if cfg.HasProvidersConfig() && len(cfg.ModelList) == 0 {
|
||||
fmt.Println("WARNING: providers config is deprecated, please migrate to model_list")
|
||||
}
|
||||
|
||||
providerName := strings.ToLower(cfg.Agents.Defaults.Provider)
|
||||
|
||||
var apiKey, apiBase, proxy string
|
||||
|
||||
lowerModel := strings.ToLower(model)
|
||||
|
||||
// First, try to use explicitly configured provider
|
||||
if providerName != "" {
|
||||
switch providerName {
|
||||
case "groq":
|
||||
if cfg.Providers.Groq.APIKey != "" {
|
||||
apiKey = cfg.Providers.Groq.APIKey
|
||||
apiBase = cfg.Providers.Groq.APIBase
|
||||
if apiBase == "" {
|
||||
apiBase = "https://api.groq.com/openai/v1"
|
||||
}
|
||||
}
|
||||
case "openai", "gpt":
|
||||
if cfg.Providers.OpenAI.APIKey != "" || cfg.Providers.OpenAI.AuthMethod != "" {
|
||||
if cfg.Providers.OpenAI.AuthMethod == "codex-cli" {
|
||||
return NewCodexProviderWithTokenSource("", "", CreateCodexCliTokenSource()), nil
|
||||
}
|
||||
if cfg.Providers.OpenAI.AuthMethod == "oauth" || cfg.Providers.OpenAI.AuthMethod == "token" {
|
||||
return createCodexAuthProvider()
|
||||
}
|
||||
apiKey = cfg.Providers.OpenAI.APIKey
|
||||
apiBase = cfg.Providers.OpenAI.APIBase
|
||||
if apiBase == "" {
|
||||
apiBase = "https://api.openai.com/v1"
|
||||
}
|
||||
}
|
||||
case "anthropic", "claude":
|
||||
if cfg.Providers.Anthropic.APIKey != "" || cfg.Providers.Anthropic.AuthMethod != "" {
|
||||
if cfg.Providers.Anthropic.AuthMethod == "oauth" || cfg.Providers.Anthropic.AuthMethod == "token" {
|
||||
return createClaudeAuthProvider()
|
||||
}
|
||||
apiKey = cfg.Providers.Anthropic.APIKey
|
||||
apiBase = cfg.Providers.Anthropic.APIBase
|
||||
if apiBase == "" {
|
||||
apiBase = "https://api.anthropic.com/v1"
|
||||
}
|
||||
}
|
||||
case "openrouter":
|
||||
if cfg.Providers.OpenRouter.APIKey != "" {
|
||||
apiKey = cfg.Providers.OpenRouter.APIKey
|
||||
if cfg.Providers.OpenRouter.APIBase != "" {
|
||||
apiBase = cfg.Providers.OpenRouter.APIBase
|
||||
} else {
|
||||
apiBase = "https://openrouter.ai/api/v1"
|
||||
}
|
||||
}
|
||||
case "zhipu", "glm":
|
||||
if cfg.Providers.Zhipu.APIKey != "" {
|
||||
apiKey = cfg.Providers.Zhipu.APIKey
|
||||
apiBase = cfg.Providers.Zhipu.APIBase
|
||||
if apiBase == "" {
|
||||
apiBase = "https://open.bigmodel.cn/api/paas/v4"
|
||||
}
|
||||
}
|
||||
case "gemini", "google":
|
||||
if cfg.Providers.Gemini.APIKey != "" {
|
||||
apiKey = cfg.Providers.Gemini.APIKey
|
||||
apiBase = cfg.Providers.Gemini.APIBase
|
||||
if apiBase == "" {
|
||||
apiBase = "https://generativelanguage.googleapis.com/v1beta"
|
||||
}
|
||||
}
|
||||
case "vllm":
|
||||
if cfg.Providers.VLLM.APIBase != "" {
|
||||
apiKey = cfg.Providers.VLLM.APIKey
|
||||
apiBase = cfg.Providers.VLLM.APIBase
|
||||
}
|
||||
case "shengsuanyun":
|
||||
if cfg.Providers.ShengSuanYun.APIKey != "" {
|
||||
apiKey = cfg.Providers.ShengSuanYun.APIKey
|
||||
apiBase = cfg.Providers.ShengSuanYun.APIBase
|
||||
if apiBase == "" {
|
||||
apiBase = "https://router.shengsuanyun.com/api/v1"
|
||||
}
|
||||
}
|
||||
case "claude-cli", "claudecode", "claude-code":
|
||||
workspace := cfg.WorkspacePath()
|
||||
if workspace == "" {
|
||||
workspace = "."
|
||||
}
|
||||
return NewClaudeCliProvider(workspace), nil
|
||||
case "codex-cli", "codex-code":
|
||||
workspace := cfg.WorkspacePath()
|
||||
if workspace == "" {
|
||||
workspace = "."
|
||||
}
|
||||
return NewCodexCliProvider(workspace), nil
|
||||
case "cerebras":
|
||||
if cfg.Providers.Cerebras.APIKey != "" {
|
||||
apiKey = cfg.Providers.Cerebras.APIKey
|
||||
apiBase = cfg.Providers.Cerebras.APIBase
|
||||
if apiBase == "" {
|
||||
apiBase = "https://api.cerebras.ai/v1"
|
||||
}
|
||||
}
|
||||
case "deepseek":
|
||||
if cfg.Providers.DeepSeek.APIKey != "" {
|
||||
apiKey = cfg.Providers.DeepSeek.APIKey
|
||||
apiBase = cfg.Providers.DeepSeek.APIBase
|
||||
if apiBase == "" {
|
||||
apiBase = "https://api.deepseek.com/v1"
|
||||
}
|
||||
if model != "deepseek-chat" && model != "deepseek-reasoner" {
|
||||
model = "deepseek-chat"
|
||||
}
|
||||
}
|
||||
case "qwen":
|
||||
if cfg.Providers.Qwen.APIKey != "" {
|
||||
apiKey = cfg.Providers.Qwen.APIKey
|
||||
apiBase = cfg.Providers.Qwen.APIBase
|
||||
if apiBase == "" {
|
||||
apiBase = "https://dashscope.aliyuncs.com/compatible-mode/v1"
|
||||
}
|
||||
}
|
||||
case "github_copilot", "copilot":
|
||||
if cfg.Providers.GitHubCopilot.APIBase != "" {
|
||||
apiBase = cfg.Providers.GitHubCopilot.APIBase
|
||||
} else {
|
||||
apiBase = "localhost:4321"
|
||||
}
|
||||
return NewGitHubCopilotProvider(apiBase, cfg.Providers.GitHubCopilot.ConnectMode, model)
|
||||
case "antigravity", "google-antigravity":
|
||||
return NewAntigravityProvider(), nil
|
||||
|
||||
case "volcengine", "doubao":
|
||||
if cfg.Providers.VolcEngine.APIKey != "" {
|
||||
apiKey = cfg.Providers.VolcEngine.APIKey
|
||||
apiBase = cfg.Providers.VolcEngine.APIBase
|
||||
if apiBase == "" {
|
||||
apiBase = "https://ark.cn-beijing.volces.com/api/v3"
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
// Fallback: detect provider from model name
|
||||
if apiKey == "" && apiBase == "" {
|
||||
switch {
|
||||
case (strings.Contains(lowerModel, "kimi") || strings.Contains(lowerModel, "moonshot") || strings.HasPrefix(model, "moonshot/")) && cfg.Providers.Moonshot.APIKey != "":
|
||||
apiKey = cfg.Providers.Moonshot.APIKey
|
||||
apiBase = cfg.Providers.Moonshot.APIBase
|
||||
proxy = cfg.Providers.Moonshot.Proxy
|
||||
if apiBase == "" {
|
||||
apiBase = "https://api.moonshot.cn/v1"
|
||||
}
|
||||
|
||||
case strings.HasPrefix(model, "openrouter/") || strings.HasPrefix(model, "anthropic/") || strings.HasPrefix(model, "openai/") || strings.HasPrefix(model, "meta-llama/") || strings.HasPrefix(model, "deepseek/") || strings.HasPrefix(model, "google/"):
|
||||
apiKey = cfg.Providers.OpenRouter.APIKey
|
||||
proxy = cfg.Providers.OpenRouter.Proxy
|
||||
if cfg.Providers.OpenRouter.APIBase != "" {
|
||||
apiBase = cfg.Providers.OpenRouter.APIBase
|
||||
} else {
|
||||
apiBase = "https://openrouter.ai/api/v1"
|
||||
}
|
||||
|
||||
case (strings.Contains(lowerModel, "claude") || strings.HasPrefix(model, "anthropic/")) && (cfg.Providers.Anthropic.APIKey != "" || cfg.Providers.Anthropic.AuthMethod != ""):
|
||||
if cfg.Providers.Anthropic.AuthMethod == "oauth" || cfg.Providers.Anthropic.AuthMethod == "token" {
|
||||
return createClaudeAuthProvider()
|
||||
}
|
||||
apiKey = cfg.Providers.Anthropic.APIKey
|
||||
apiBase = cfg.Providers.Anthropic.APIBase
|
||||
proxy = cfg.Providers.Anthropic.Proxy
|
||||
if apiBase == "" {
|
||||
apiBase = "https://api.anthropic.com/v1"
|
||||
}
|
||||
|
||||
case (strings.Contains(lowerModel, "gpt") || strings.HasPrefix(model, "openai/")) && (cfg.Providers.OpenAI.APIKey != "" || cfg.Providers.OpenAI.AuthMethod != ""):
|
||||
if cfg.Providers.OpenAI.AuthMethod == "oauth" || cfg.Providers.OpenAI.AuthMethod == "token" {
|
||||
return createCodexAuthProvider()
|
||||
}
|
||||
apiKey = cfg.Providers.OpenAI.APIKey
|
||||
apiBase = cfg.Providers.OpenAI.APIBase
|
||||
proxy = cfg.Providers.OpenAI.Proxy
|
||||
if apiBase == "" {
|
||||
apiBase = "https://api.openai.com/v1"
|
||||
}
|
||||
|
||||
case (strings.Contains(lowerModel, "gemini") || strings.HasPrefix(model, "google/")) && cfg.Providers.Gemini.APIKey != "":
|
||||
apiKey = cfg.Providers.Gemini.APIKey
|
||||
apiBase = cfg.Providers.Gemini.APIBase
|
||||
proxy = cfg.Providers.Gemini.Proxy
|
||||
if apiBase == "" {
|
||||
apiBase = "https://generativelanguage.googleapis.com/v1beta"
|
||||
}
|
||||
|
||||
case (strings.Contains(lowerModel, "glm") || strings.Contains(lowerModel, "zhipu") || strings.Contains(lowerModel, "zai")) && cfg.Providers.Zhipu.APIKey != "":
|
||||
apiKey = cfg.Providers.Zhipu.APIKey
|
||||
apiBase = cfg.Providers.Zhipu.APIBase
|
||||
proxy = cfg.Providers.Zhipu.Proxy
|
||||
if apiBase == "" {
|
||||
apiBase = "https://open.bigmodel.cn/api/paas/v4"
|
||||
}
|
||||
|
||||
case (strings.Contains(lowerModel, "groq") || strings.HasPrefix(model, "groq/")) && cfg.Providers.Groq.APIKey != "":
|
||||
apiKey = cfg.Providers.Groq.APIKey
|
||||
apiBase = cfg.Providers.Groq.APIBase
|
||||
proxy = cfg.Providers.Groq.Proxy
|
||||
if apiBase == "" {
|
||||
apiBase = "https://api.groq.com/openai/v1"
|
||||
}
|
||||
|
||||
case (strings.Contains(lowerModel, "qwen") || strings.HasPrefix(model, "qwen/")) && cfg.Providers.Qwen.APIKey != "":
|
||||
apiKey = cfg.Providers.Qwen.APIKey
|
||||
apiBase = cfg.Providers.Qwen.APIBase
|
||||
proxy = cfg.Providers.Qwen.Proxy
|
||||
if apiBase == "" {
|
||||
apiBase = "https://dashscope.aliyuncs.com/compatible-mode/v1"
|
||||
}
|
||||
|
||||
case (strings.Contains(lowerModel, "nvidia") || strings.HasPrefix(model, "nvidia/")) && cfg.Providers.Nvidia.APIKey != "":
|
||||
apiKey = cfg.Providers.Nvidia.APIKey
|
||||
apiBase = cfg.Providers.Nvidia.APIBase
|
||||
proxy = cfg.Providers.Nvidia.Proxy
|
||||
if apiBase == "" {
|
||||
apiBase = "https://integrate.api.nvidia.com/v1"
|
||||
}
|
||||
case (strings.Contains(lowerModel, "cerebras") || strings.HasPrefix(model, "cerebras/")) && cfg.Providers.Cerebras.APIKey != "":
|
||||
apiKey = cfg.Providers.Cerebras.APIKey
|
||||
apiBase = cfg.Providers.Cerebras.APIBase
|
||||
proxy = cfg.Providers.Cerebras.Proxy
|
||||
if apiBase == "" {
|
||||
apiBase = "https://api.cerebras.ai/v1"
|
||||
}
|
||||
|
||||
case (strings.Contains(lowerModel, "ollama") || strings.HasPrefix(model, "ollama/")) && cfg.Providers.Ollama.APIKey != "":
|
||||
fmt.Println("Ollama provider selected based on model name prefix")
|
||||
apiKey = cfg.Providers.Ollama.APIKey
|
||||
apiBase = cfg.Providers.Ollama.APIBase
|
||||
proxy = cfg.Providers.Ollama.Proxy
|
||||
if apiBase == "" {
|
||||
apiBase = "http://localhost:11434/v1"
|
||||
}
|
||||
fmt.Println("Ollama apiBase:", apiBase)
|
||||
|
||||
case (strings.Contains(lowerModel, "doubao") || strings.HasPrefix(lowerModel, "doubao") || strings.Contains(lowerModel, "volcengine")) && cfg.Providers.VolcEngine.APIKey != "":
|
||||
apiKey = cfg.Providers.VolcEngine.APIKey
|
||||
apiBase = cfg.Providers.VolcEngine.APIBase
|
||||
proxy = cfg.Providers.VolcEngine.Proxy
|
||||
if apiBase == "" {
|
||||
apiBase = "https://ark.cn-beijing.volces.com/api/v3"
|
||||
}
|
||||
|
||||
case cfg.Providers.VLLM.APIBase != "":
|
||||
apiKey = cfg.Providers.VLLM.APIKey
|
||||
apiBase = cfg.Providers.VLLM.APIBase
|
||||
proxy = cfg.Providers.VLLM.Proxy
|
||||
|
||||
default:
|
||||
if cfg.Providers.OpenRouter.APIKey != "" {
|
||||
apiKey = cfg.Providers.OpenRouter.APIKey
|
||||
proxy = cfg.Providers.OpenRouter.Proxy
|
||||
if cfg.Providers.OpenRouter.APIBase != "" {
|
||||
apiBase = cfg.Providers.OpenRouter.APIBase
|
||||
} else {
|
||||
apiBase = "https://openrouter.ai/api/v1"
|
||||
}
|
||||
} else {
|
||||
return nil, fmt.Errorf("no API key configured for model: %s", model)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if apiKey == "" && !strings.HasPrefix(model, "bedrock/") {
|
||||
return nil, fmt.Errorf("no API key configured for provider (model: %s)", model)
|
||||
}
|
||||
|
||||
if apiBase == "" {
|
||||
return nil, fmt.Errorf("no API base configured for provider (model: %s)", model)
|
||||
}
|
||||
|
||||
return NewHTTPProvider(apiKey, apiBase, proxy), nil
|
||||
}
|
||||
|
||||
@@ -0,0 +1,349 @@
|
||||
// PicoClaw - Ultra-lightweight personal AI agent
|
||||
// License: MIT
|
||||
//
|
||||
// Copyright (c) 2026 PicoClaw contributors
|
||||
|
||||
package providers
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/auth"
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
)
|
||||
|
||||
// createClaudeAuthProvider creates a Claude provider using OAuth credentials.
|
||||
func createClaudeAuthProvider() (LLMProvider, error) {
|
||||
cred, err := auth.GetCredential("anthropic")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("loading auth credentials: %w", err)
|
||||
}
|
||||
if cred == nil {
|
||||
return nil, fmt.Errorf("no credentials for anthropic. Run: picoclaw auth login --provider anthropic")
|
||||
}
|
||||
return NewClaudeProviderWithTokenSource(cred.AccessToken, createClaudeTokenSource()), nil
|
||||
}
|
||||
|
||||
// createCodexAuthProvider creates a Codex provider using OAuth credentials.
|
||||
func createCodexAuthProvider() (LLMProvider, error) {
|
||||
cred, err := auth.GetCredential("openai")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("loading auth credentials: %w", err)
|
||||
}
|
||||
if cred == nil {
|
||||
return nil, fmt.Errorf("no credentials for openai. Run: picoclaw auth login --provider openai")
|
||||
}
|
||||
return NewCodexProviderWithTokenSource(cred.AccessToken, cred.AccountID, createCodexTokenSource()), nil
|
||||
}
|
||||
|
||||
// CreateProvider creates a provider based on the configuration.
|
||||
// It supports both the new model_list configuration and the legacy providers configuration.
|
||||
// Returns the provider, the model ID to use, and any error.
|
||||
func CreateProvider(cfg *config.Config) (LLMProvider, string, error) {
|
||||
model := cfg.Agents.Defaults.Model
|
||||
|
||||
// First, try to use model_list configuration
|
||||
if len(cfg.ModelList) > 0 {
|
||||
// Try to get config by model name first
|
||||
modelCfg, err := cfg.GetModelConfig(model)
|
||||
if err == nil {
|
||||
// Found in model_list, use factory to create provider
|
||||
provider, modelID, err := CreateProviderFromConfig(modelCfg)
|
||||
if err != nil {
|
||||
return nil, "", fmt.Errorf("failed to create provider from model_list: %w", err)
|
||||
}
|
||||
return provider, modelID, nil
|
||||
}
|
||||
// Model not found in model_list, fall through to providers config
|
||||
}
|
||||
|
||||
// Log deprecation warning if using old providers config
|
||||
if cfg.HasProvidersConfig() && len(cfg.ModelList) == 0 {
|
||||
fmt.Println("WARNING: providers config is deprecated, please migrate to model_list")
|
||||
}
|
||||
|
||||
providerName := strings.ToLower(cfg.Agents.Defaults.Provider)
|
||||
|
||||
var apiKey, apiBase, proxy string
|
||||
|
||||
lowerModel := strings.ToLower(model)
|
||||
|
||||
// First, try to use explicitly configured provider
|
||||
if providerName != "" {
|
||||
switch providerName {
|
||||
case "groq":
|
||||
if cfg.Providers.Groq.APIKey != "" {
|
||||
apiKey = cfg.Providers.Groq.APIKey
|
||||
apiBase = cfg.Providers.Groq.APIBase
|
||||
if apiBase == "" {
|
||||
apiBase = "https://api.groq.com/openai/v1"
|
||||
}
|
||||
}
|
||||
case "openai", "gpt":
|
||||
if cfg.Providers.OpenAI.APIKey != "" || cfg.Providers.OpenAI.AuthMethod != "" {
|
||||
if cfg.Providers.OpenAI.AuthMethod == "codex-cli" {
|
||||
return NewCodexProviderWithTokenSource("", "", CreateCodexCliTokenSource()), model, nil
|
||||
}
|
||||
if cfg.Providers.OpenAI.AuthMethod == "oauth" || cfg.Providers.OpenAI.AuthMethod == "token" {
|
||||
provider, err := createCodexAuthProvider()
|
||||
return provider, model, err
|
||||
}
|
||||
apiKey = cfg.Providers.OpenAI.APIKey
|
||||
apiBase = cfg.Providers.OpenAI.APIBase
|
||||
if apiBase == "" {
|
||||
apiBase = "https://api.openai.com/v1"
|
||||
}
|
||||
}
|
||||
case "anthropic", "claude":
|
||||
if cfg.Providers.Anthropic.APIKey != "" || cfg.Providers.Anthropic.AuthMethod != "" {
|
||||
if cfg.Providers.Anthropic.AuthMethod == "oauth" || cfg.Providers.Anthropic.AuthMethod == "token" {
|
||||
provider, err := createClaudeAuthProvider()
|
||||
return provider, model, err
|
||||
}
|
||||
apiKey = cfg.Providers.Anthropic.APIKey
|
||||
apiBase = cfg.Providers.Anthropic.APIBase
|
||||
if apiBase == "" {
|
||||
apiBase = "https://api.anthropic.com/v1"
|
||||
}
|
||||
}
|
||||
case "openrouter":
|
||||
if cfg.Providers.OpenRouter.APIKey != "" {
|
||||
apiKey = cfg.Providers.OpenRouter.APIKey
|
||||
if cfg.Providers.OpenRouter.APIBase != "" {
|
||||
apiBase = cfg.Providers.OpenRouter.APIBase
|
||||
} else {
|
||||
apiBase = "https://openrouter.ai/api/v1"
|
||||
}
|
||||
}
|
||||
case "zhipu", "glm":
|
||||
if cfg.Providers.Zhipu.APIKey != "" {
|
||||
apiKey = cfg.Providers.Zhipu.APIKey
|
||||
apiBase = cfg.Providers.Zhipu.APIBase
|
||||
if apiBase == "" {
|
||||
apiBase = "https://open.bigmodel.cn/api/paas/v4"
|
||||
}
|
||||
}
|
||||
case "gemini", "google":
|
||||
if cfg.Providers.Gemini.APIKey != "" {
|
||||
apiKey = cfg.Providers.Gemini.APIKey
|
||||
apiBase = cfg.Providers.Gemini.APIBase
|
||||
if apiBase == "" {
|
||||
apiBase = "https://generativelanguage.googleapis.com/v1beta"
|
||||
}
|
||||
}
|
||||
case "vllm":
|
||||
if cfg.Providers.VLLM.APIBase != "" {
|
||||
apiKey = cfg.Providers.VLLM.APIKey
|
||||
apiBase = cfg.Providers.VLLM.APIBase
|
||||
}
|
||||
case "shengsuanyun":
|
||||
if cfg.Providers.ShengSuanYun.APIKey != "" {
|
||||
apiKey = cfg.Providers.ShengSuanYun.APIKey
|
||||
apiBase = cfg.Providers.ShengSuanYun.APIBase
|
||||
if apiBase == "" {
|
||||
apiBase = "https://router.shengsuanyun.com/api/v1"
|
||||
}
|
||||
}
|
||||
case "claude-cli", "claudecode", "claude-code":
|
||||
workspace := cfg.WorkspacePath()
|
||||
if workspace == "" {
|
||||
workspace = "."
|
||||
}
|
||||
return NewClaudeCliProvider(workspace), model, nil
|
||||
case "codex-cli", "codex-code":
|
||||
workspace := cfg.WorkspacePath()
|
||||
if workspace == "" {
|
||||
workspace = "."
|
||||
}
|
||||
return NewCodexCliProvider(workspace), model, nil
|
||||
case "cerebras":
|
||||
if cfg.Providers.Cerebras.APIKey != "" {
|
||||
apiKey = cfg.Providers.Cerebras.APIKey
|
||||
apiBase = cfg.Providers.Cerebras.APIBase
|
||||
if apiBase == "" {
|
||||
apiBase = "https://api.cerebras.ai/v1"
|
||||
}
|
||||
}
|
||||
case "deepseek":
|
||||
if cfg.Providers.DeepSeek.APIKey != "" {
|
||||
apiKey = cfg.Providers.DeepSeek.APIKey
|
||||
apiBase = cfg.Providers.DeepSeek.APIBase
|
||||
if apiBase == "" {
|
||||
apiBase = "https://api.deepseek.com/v1"
|
||||
}
|
||||
if model != "deepseek-chat" && model != "deepseek-reasoner" {
|
||||
model = "deepseek-chat"
|
||||
}
|
||||
}
|
||||
case "qwen":
|
||||
if cfg.Providers.Qwen.APIKey != "" {
|
||||
apiKey = cfg.Providers.Qwen.APIKey
|
||||
apiBase = cfg.Providers.Qwen.APIBase
|
||||
if apiBase == "" {
|
||||
apiBase = "https://dashscope.aliyuncs.com/compatible-mode/v1"
|
||||
}
|
||||
}
|
||||
case "github_copilot", "copilot":
|
||||
if cfg.Providers.GitHubCopilot.APIBase != "" {
|
||||
apiBase = cfg.Providers.GitHubCopilot.APIBase
|
||||
} else {
|
||||
apiBase = "localhost:4321"
|
||||
}
|
||||
provider, err := NewGitHubCopilotProvider(apiBase, cfg.Providers.GitHubCopilot.ConnectMode, model)
|
||||
return provider, model, err
|
||||
case "antigravity", "google-antigravity":
|
||||
return NewAntigravityProvider(), model, nil
|
||||
|
||||
case "volcengine", "doubao":
|
||||
if cfg.Providers.VolcEngine.APIKey != "" {
|
||||
apiKey = cfg.Providers.VolcEngine.APIKey
|
||||
apiBase = cfg.Providers.VolcEngine.APIBase
|
||||
if apiBase == "" {
|
||||
apiBase = "https://ark.cn-beijing.volces.com/api/v3"
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
// Fallback: detect provider from model name
|
||||
if apiKey == "" && apiBase == "" {
|
||||
switch {
|
||||
case (strings.Contains(lowerModel, "kimi") || strings.Contains(lowerModel, "moonshot") || strings.HasPrefix(model, "moonshot/")) && cfg.Providers.Moonshot.APIKey != "":
|
||||
apiKey = cfg.Providers.Moonshot.APIKey
|
||||
apiBase = cfg.Providers.Moonshot.APIBase
|
||||
proxy = cfg.Providers.Moonshot.Proxy
|
||||
if apiBase == "" {
|
||||
apiBase = "https://api.moonshot.cn/v1"
|
||||
}
|
||||
|
||||
case strings.HasPrefix(model, "openrouter/") || strings.HasPrefix(model, "anthropic/") || strings.HasPrefix(model, "openai/") || strings.HasPrefix(model, "meta-llama/") || strings.HasPrefix(model, "deepseek/") || strings.HasPrefix(model, "google/"):
|
||||
apiKey = cfg.Providers.OpenRouter.APIKey
|
||||
proxy = cfg.Providers.OpenRouter.Proxy
|
||||
if cfg.Providers.OpenRouter.APIBase != "" {
|
||||
apiBase = cfg.Providers.OpenRouter.APIBase
|
||||
} else {
|
||||
apiBase = "https://openrouter.ai/api/v1"
|
||||
}
|
||||
|
||||
case (strings.Contains(lowerModel, "claude") || strings.HasPrefix(model, "anthropic/")) && (cfg.Providers.Anthropic.APIKey != "" || cfg.Providers.Anthropic.AuthMethod != ""):
|
||||
if cfg.Providers.Anthropic.AuthMethod == "oauth" || cfg.Providers.Anthropic.AuthMethod == "token" {
|
||||
provider, err := createClaudeAuthProvider()
|
||||
return provider, model, err
|
||||
}
|
||||
apiKey = cfg.Providers.Anthropic.APIKey
|
||||
apiBase = cfg.Providers.Anthropic.APIBase
|
||||
proxy = cfg.Providers.Anthropic.Proxy
|
||||
if apiBase == "" {
|
||||
apiBase = "https://api.anthropic.com/v1"
|
||||
}
|
||||
|
||||
case (strings.Contains(lowerModel, "gpt") || strings.HasPrefix(model, "openai/")) && (cfg.Providers.OpenAI.APIKey != "" || cfg.Providers.OpenAI.AuthMethod != ""):
|
||||
if cfg.Providers.OpenAI.AuthMethod == "oauth" || cfg.Providers.OpenAI.AuthMethod == "token" {
|
||||
provider, err := createCodexAuthProvider()
|
||||
return provider, model, err
|
||||
}
|
||||
apiKey = cfg.Providers.OpenAI.APIKey
|
||||
apiBase = cfg.Providers.OpenAI.APIBase
|
||||
proxy = cfg.Providers.OpenAI.Proxy
|
||||
if apiBase == "" {
|
||||
apiBase = "https://api.openai.com/v1"
|
||||
}
|
||||
|
||||
case (strings.Contains(lowerModel, "gemini") || strings.HasPrefix(model, "google/")) && cfg.Providers.Gemini.APIKey != "":
|
||||
apiKey = cfg.Providers.Gemini.APIKey
|
||||
apiBase = cfg.Providers.Gemini.APIBase
|
||||
proxy = cfg.Providers.Gemini.Proxy
|
||||
if apiBase == "" {
|
||||
apiBase = "https://generativelanguage.googleapis.com/v1beta"
|
||||
}
|
||||
|
||||
case (strings.Contains(lowerModel, "glm") || strings.Contains(lowerModel, "zhipu") || strings.Contains(lowerModel, "zai")) && cfg.Providers.Zhipu.APIKey != "":
|
||||
apiKey = cfg.Providers.Zhipu.APIKey
|
||||
apiBase = cfg.Providers.Zhipu.APIBase
|
||||
proxy = cfg.Providers.Zhipu.Proxy
|
||||
if apiBase == "" {
|
||||
apiBase = "https://open.bigmodel.cn/api/paas/v4"
|
||||
}
|
||||
|
||||
case (strings.Contains(lowerModel, "groq") || strings.HasPrefix(model, "groq/")) && cfg.Providers.Groq.APIKey != "":
|
||||
apiKey = cfg.Providers.Groq.APIKey
|
||||
apiBase = cfg.Providers.Groq.APIBase
|
||||
proxy = cfg.Providers.Groq.Proxy
|
||||
if apiBase == "" {
|
||||
apiBase = "https://api.groq.com/openai/v1"
|
||||
}
|
||||
|
||||
case (strings.Contains(lowerModel, "qwen") || strings.HasPrefix(model, "qwen/")) && cfg.Providers.Qwen.APIKey != "":
|
||||
apiKey = cfg.Providers.Qwen.APIKey
|
||||
apiBase = cfg.Providers.Qwen.APIBase
|
||||
proxy = cfg.Providers.Qwen.Proxy
|
||||
if apiBase == "" {
|
||||
apiBase = "https://dashscope.aliyuncs.com/compatible-mode/v1"
|
||||
}
|
||||
|
||||
case (strings.Contains(lowerModel, "nvidia") || strings.HasPrefix(model, "nvidia/")) && cfg.Providers.Nvidia.APIKey != "":
|
||||
apiKey = cfg.Providers.Nvidia.APIKey
|
||||
apiBase = cfg.Providers.Nvidia.APIBase
|
||||
proxy = cfg.Providers.Nvidia.Proxy
|
||||
if apiBase == "" {
|
||||
apiBase = "https://integrate.api.nvidia.com/v1"
|
||||
}
|
||||
case (strings.Contains(lowerModel, "cerebras") || strings.HasPrefix(model, "cerebras/")) && cfg.Providers.Cerebras.APIKey != "":
|
||||
apiKey = cfg.Providers.Cerebras.APIKey
|
||||
apiBase = cfg.Providers.Cerebras.APIBase
|
||||
proxy = cfg.Providers.Cerebras.Proxy
|
||||
if apiBase == "" {
|
||||
apiBase = "https://api.cerebras.ai/v1"
|
||||
}
|
||||
|
||||
case (strings.Contains(lowerModel, "ollama") || strings.HasPrefix(model, "ollama/")) && cfg.Providers.Ollama.APIKey != "":
|
||||
fmt.Println("Ollama provider selected based on model name prefix")
|
||||
apiKey = cfg.Providers.Ollama.APIKey
|
||||
apiBase = cfg.Providers.Ollama.APIBase
|
||||
proxy = cfg.Providers.Ollama.Proxy
|
||||
if apiBase == "" {
|
||||
apiBase = "http://localhost:11434/v1"
|
||||
}
|
||||
fmt.Println("Ollama apiBase:", apiBase)
|
||||
|
||||
case (strings.Contains(lowerModel, "doubao") || strings.HasPrefix(lowerModel, "doubao") || strings.Contains(lowerModel, "volcengine")) && cfg.Providers.VolcEngine.APIKey != "":
|
||||
apiKey = cfg.Providers.VolcEngine.APIKey
|
||||
apiBase = cfg.Providers.VolcEngine.APIBase
|
||||
proxy = cfg.Providers.VolcEngine.Proxy
|
||||
if apiBase == "" {
|
||||
apiBase = "https://ark.cn-beijing.volces.com/api/v3"
|
||||
}
|
||||
|
||||
case cfg.Providers.VLLM.APIBase != "":
|
||||
apiKey = cfg.Providers.VLLM.APIKey
|
||||
apiBase = cfg.Providers.VLLM.APIBase
|
||||
proxy = cfg.Providers.VLLM.Proxy
|
||||
|
||||
default:
|
||||
if cfg.Providers.OpenRouter.APIKey != "" {
|
||||
apiKey = cfg.Providers.OpenRouter.APIKey
|
||||
proxy = cfg.Providers.OpenRouter.Proxy
|
||||
if cfg.Providers.OpenRouter.APIBase != "" {
|
||||
apiBase = cfg.Providers.OpenRouter.APIBase
|
||||
} else {
|
||||
apiBase = "https://openrouter.ai/api/v1"
|
||||
}
|
||||
} else {
|
||||
return nil, "", fmt.Errorf("no API key configured for model: %s", model)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if apiKey == "" && !strings.HasPrefix(model, "bedrock/") {
|
||||
return nil, "", fmt.Errorf("no API key configured for provider (model: %s)", model)
|
||||
}
|
||||
|
||||
if apiBase == "" {
|
||||
return nil, "", fmt.Errorf("no API base configured for provider (model: %s)", model)
|
||||
}
|
||||
|
||||
return NewHTTPProvider(apiKey, apiBase, proxy), model, nil
|
||||
}
|
||||
@@ -1,113 +0,0 @@
|
||||
// PicoClaw - Ultra-lightweight personal AI agent
|
||||
// License: MIT
|
||||
//
|
||||
// Copyright (c) 2026 PicoClaw contributors
|
||||
|
||||
package providers
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
)
|
||||
|
||||
// ModelRegistry manages model configurations with thread-safe round-robin load balancing.
|
||||
// It allows multiple configurations for the same model_name to distribute load across endpoints.
|
||||
type ModelRegistry struct {
|
||||
configs map[string][]config.ModelConfig // model_name -> []ModelConfig
|
||||
counters map[string]*atomic.Uint64 // model_name -> round-robin counter
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
// NewModelRegistry creates a new ModelRegistry from a slice of ModelConfig.
|
||||
func NewModelRegistry(modelList []config.ModelConfig) *ModelRegistry {
|
||||
r := &ModelRegistry{
|
||||
configs: make(map[string][]config.ModelConfig),
|
||||
counters: make(map[string]*atomic.Uint64),
|
||||
}
|
||||
|
||||
for _, cfg := range modelList {
|
||||
r.configs[cfg.ModelName] = append(r.configs[cfg.ModelName], cfg)
|
||||
}
|
||||
|
||||
// Initialize counters for models with multiple configs
|
||||
for name, cfgs := range r.configs {
|
||||
if len(cfgs) > 1 {
|
||||
r.counters[name] = &atomic.Uint64{}
|
||||
}
|
||||
}
|
||||
|
||||
return r
|
||||
}
|
||||
|
||||
// GetModelConfig returns a ModelConfig for the given model name.
|
||||
// If multiple configs exist for the same model_name, it uses round-robin selection.
|
||||
// Returns an error if the model is not found.
|
||||
func (r *ModelRegistry) GetModelConfig(modelName string) (*config.ModelConfig, error) {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
|
||||
configs, ok := r.configs[modelName]
|
||||
if !ok || len(configs) == 0 {
|
||||
return nil, fmt.Errorf("model %q not found", modelName)
|
||||
}
|
||||
|
||||
// Single config - return directly
|
||||
if len(configs) == 1 {
|
||||
return &configs[0], nil
|
||||
}
|
||||
|
||||
// Multiple configs - use round-robin for load balancing
|
||||
counter, ok := r.counters[modelName]
|
||||
if !ok {
|
||||
// Should not happen, but handle gracefully
|
||||
return &configs[0], nil
|
||||
}
|
||||
|
||||
idx := counter.Add(1) % uint64(len(configs))
|
||||
return &configs[idx], nil
|
||||
}
|
||||
|
||||
// AddConfig adds a new ModelConfig to the registry.
|
||||
func (r *ModelRegistry) AddConfig(cfg config.ModelConfig) {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
r.configs[cfg.ModelName] = append(r.configs[cfg.ModelName], cfg)
|
||||
|
||||
// Initialize counter if we now have multiple configs
|
||||
if len(r.configs[cfg.ModelName]) > 1 && r.counters[cfg.ModelName] == nil {
|
||||
r.counters[cfg.ModelName] = &atomic.Uint64{}
|
||||
}
|
||||
}
|
||||
|
||||
// RemoveConfig removes all configs with the given model_name.
|
||||
func (r *ModelRegistry) RemoveConfig(modelName string) {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
delete(r.configs, modelName)
|
||||
delete(r.counters, modelName)
|
||||
}
|
||||
|
||||
// ListModels returns all unique model names in the registry.
|
||||
func (r *ModelRegistry) ListModels() []string {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
|
||||
names := make([]string, 0, len(r.configs))
|
||||
for name := range r.configs {
|
||||
names = append(names, name)
|
||||
}
|
||||
return names
|
||||
}
|
||||
|
||||
// ConfigCount returns the number of configurations for a given model name.
|
||||
func (r *ModelRegistry) ConfigCount(modelName string) int {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
|
||||
return len(r.configs[modelName])
|
||||
}
|
||||
Reference in New Issue
Block a user