diff --git a/cmd/picoclaw/cmd_agent.go b/cmd/picoclaw/cmd_agent.go new file mode 100644 index 000000000..cee9f68ec --- /dev/null +++ b/cmd/picoclaw/cmd_agent.go @@ -0,0 +1,181 @@ +// PicoClaw - Ultra-lightweight personal AI agent +// License: MIT + +package main + +import ( + "bufio" + "context" + "fmt" + "io" + "os" + "path/filepath" + "strings" + + "github.com/chzyer/readline" + "github.com/sipeed/picoclaw/pkg/agent" + "github.com/sipeed/picoclaw/pkg/bus" + "github.com/sipeed/picoclaw/pkg/logger" + "github.com/sipeed/picoclaw/pkg/providers" +) + +func agentCmd() { + message := "" + sessionKey := "cli:default" + modelOverride := "" + + args := os.Args[2:] + for i := 0; i < len(args); i++ { + switch args[i] { + case "--debug", "-d": + logger.SetLevel(logger.DEBUG) + fmt.Println("šŸ” Debug mode enabled") + case "-m", "--message": + if i+1 < len(args) { + message = args[i+1] + i++ + } + case "-s", "--session": + if i+1 < len(args) { + sessionKey = args[i+1] + i++ + } + case "--model", "-model": + if i+1 < len(args) { + modelOverride = args[i+1] + i++ + } + } + } + + cfg, err := loadConfig() + if err != nil { + fmt.Printf("Error loading config: %v\n", err) + os.Exit(1) + } + + if modelOverride != "" { + cfg.Agents.Defaults.Model = modelOverride + } + + provider, modelID, err := providers.CreateProvider(cfg) + if err != nil { + fmt.Printf("Error creating provider: %v\n", err) + os.Exit(1) + } + // Use the resolved model ID from provider creation + if modelID != "" { + cfg.Agents.Defaults.Model = modelID + } + + msgBus := bus.NewMessageBus() + agentLoop := agent.NewAgentLoop(cfg, msgBus, provider) + + // Print agent startup info (only for interactive mode) + startupInfo := agentLoop.GetStartupInfo() + logger.InfoCF("agent", "Agent initialized", + map[string]interface{}{ + "tools_count": startupInfo["tools"].(map[string]interface{})["count"], + "skills_total": startupInfo["skills"].(map[string]interface{})["total"], + "skills_available": startupInfo["skills"].(map[string]interface{})["available"], + }) + + if message != "" { + ctx := context.Background() + response, err := agentLoop.ProcessDirect(ctx, message, sessionKey) + if err != nil { + fmt.Printf("Error: %v\n", err) + os.Exit(1) + } + fmt.Printf("\n%s %s\n", logo, response) + } else { + fmt.Printf("%s Interactive mode (Ctrl+C to exit)\n\n", logo) + interactiveMode(agentLoop, sessionKey) + } +} + +func interactiveMode(agentLoop *agent.AgentLoop, sessionKey string) { + prompt := fmt.Sprintf("%s You: ", logo) + + rl, err := readline.NewEx(&readline.Config{ + Prompt: prompt, + HistoryFile: filepath.Join(os.TempDir(), ".picoclaw_history"), + HistoryLimit: 100, + InterruptPrompt: "^C", + EOFPrompt: "exit", + }) + + if err != nil { + fmt.Printf("Error initializing readline: %v\n", err) + fmt.Println("Falling back to simple input mode...") + simpleInteractiveMode(agentLoop, sessionKey) + return + } + defer rl.Close() + + for { + line, err := rl.Readline() + if err != nil { + if err == readline.ErrInterrupt || err == io.EOF { + fmt.Println("\nGoodbye!") + return + } + fmt.Printf("Error reading input: %v\n", err) + continue + } + + input := strings.TrimSpace(line) + if input == "" { + continue + } + + if input == "exit" || input == "quit" { + fmt.Println("Goodbye!") + return + } + + ctx := context.Background() + response, err := agentLoop.ProcessDirect(ctx, input, sessionKey) + if err != nil { + fmt.Printf("Error: %v\n", err) + continue + } + + fmt.Printf("\n%s %s\n\n", logo, response) + } +} + +func simpleInteractiveMode(agentLoop *agent.AgentLoop, sessionKey string) { + reader := bufio.NewReader(os.Stdin) + for { + fmt.Print(fmt.Sprintf("%s You: ", logo)) + line, err := reader.ReadString('\n') + if err != nil { + if err == io.EOF { + fmt.Println("\nGoodbye!") + return + } + fmt.Printf("Error reading input: %v\n", err) + continue + } + + input := strings.TrimSpace(line) + if input == "" { + continue + } + + if input == "exit" || input == "quit" { + fmt.Println("Goodbye!") + return + } + + ctx := context.Background() + response, err := agentLoop.ProcessDirect(ctx, input, sessionKey) + if err != nil { + fmt.Printf("Error: %v\n", err) + continue + } + + fmt.Printf("\n%s %s\n\n", logo, response) + } +} diff --git a/cmd/picoclaw/cmd_auth.go b/cmd/picoclaw/cmd_auth.go new file mode 100644 index 000000000..b144fe21d --- /dev/null +++ b/cmd/picoclaw/cmd_auth.go @@ -0,0 +1,386 @@ +// PicoClaw - Ultra-lightweight personal AI agent +// License: MIT + +package main + +import ( + "encoding/json" + "fmt" + "io" + "net/http" + "os" + "time" + + "github.com/sipeed/picoclaw/pkg/auth" + "github.com/sipeed/picoclaw/pkg/config" + "github.com/sipeed/picoclaw/pkg/providers" +) + +func authCmd() { + if len(os.Args) < 3 { + authHelp() + return + } + + switch os.Args[2] { + case "login": + authLoginCmd() + case "logout": + authLogoutCmd() + case "status": + authStatusCmd() + case "models": + authModelsCmd() + default: + fmt.Printf("Unknown auth command: %s\n", os.Args[2]) + authHelp() + } +} + +func authHelp() { + fmt.Println("\nAuth commands:") + fmt.Println(" login Login via OAuth or paste token") + fmt.Println(" logout Remove stored credentials") + fmt.Println(" status Show current auth status") + fmt.Println(" models List available Antigravity models") + fmt.Println() + fmt.Println("Login options:") + fmt.Println(" --provider Provider to login with (openai, anthropic, google-antigravity)") + fmt.Println(" --device-code Use device code flow (for headless environments)") + fmt.Println() + fmt.Println("Examples:") + fmt.Println(" picoclaw auth login --provider openai") + fmt.Println(" picoclaw auth login --provider openai --device-code") + fmt.Println(" picoclaw auth login --provider anthropic") + fmt.Println(" picoclaw auth login --provider google-antigravity") + fmt.Println(" picoclaw auth models") + fmt.Println(" picoclaw auth logout --provider openai") + fmt.Println(" picoclaw auth status") +} + +func authLoginCmd() { + provider := "" + useDeviceCode := false + + args := os.Args[3:] + for i := 0; i < len(args); i++ { + switch args[i] { + case "--provider", "-p": + if i+1 < len(args) { + provider = args[i+1] + i++ + } + case "--device-code": + useDeviceCode = true + } + } + + if provider == "" { + fmt.Println("Error: --provider is required") + fmt.Println("Supported providers: openai, anthropic, google-antigravity") + return + } + + switch provider { + case "openai": + authLoginOpenAI(useDeviceCode) + case "anthropic": + authLoginPasteToken(provider) + case "google-antigravity", "antigravity": + authLoginGoogleAntigravity() + default: + fmt.Printf("Unsupported provider: %s\n", provider) + fmt.Println("Supported providers: openai, anthropic, google-antigravity") + } +} + +func authLoginOpenAI(useDeviceCode bool) { + cfg := auth.OpenAIOAuthConfig() + + var cred *auth.AuthCredential + var err error + + if useDeviceCode { + cred, err = auth.LoginDeviceCode(cfg) + } else { + cred, err = auth.LoginBrowser(cfg) + } + + if err != nil { + fmt.Printf("Login failed: %v\n", err) + os.Exit(1) + } + + if err := auth.SetCredential("openai", cred); err != nil { + fmt.Printf("Failed to save credentials: %v\n", err) + os.Exit(1) + } + + appCfg, err := loadConfig() + if err == nil { + appCfg.Providers.OpenAI.AuthMethod = "oauth" + if err := config.SaveConfig(getConfigPath(), appCfg); err != nil { + fmt.Printf("Warning: could not update config: %v\n", err) + } + } + + fmt.Println("Login successful!") + if cred.AccountID != "" { + fmt.Printf("Account: %s\n", cred.AccountID) + } +} + +func authLoginGoogleAntigravity() { + cfg := auth.GoogleAntigravityOAuthConfig() + + cred, err := auth.LoginBrowser(cfg) + if err != nil { + fmt.Printf("Login failed: %v\n", err) + os.Exit(1) + } + + cred.Provider = "google-antigravity" + + // Fetch user email from Google userinfo + email, err := fetchGoogleUserEmail(cred.AccessToken) + if err != nil { + fmt.Printf("Warning: could not fetch email: %v\n", err) + } else { + cred.Email = email + fmt.Printf("Email: %s\n", email) + } + + // Fetch Cloud Code Assist project ID + projectID, err := providers.FetchAntigravityProjectID(cred.AccessToken) + if err != nil { + fmt.Printf("Warning: could not fetch project ID: %v\n", err) + fmt.Println("You may need Google Cloud Code Assist enabled on your account.") + } else { + cred.ProjectID = projectID + fmt.Printf("Project: %s\n", projectID) + } + + if err := auth.SetCredential("google-antigravity", cred); err != nil { + fmt.Printf("Failed to save credentials: %v\n", err) + os.Exit(1) + } + + appCfg, err := loadConfig() + if err == nil { + appCfg.Providers.Antigravity.AuthMethod = "oauth" + if appCfg.Agents.Defaults.Provider == "" { + appCfg.Agents.Defaults.Provider = "antigravity" + } + if appCfg.Agents.Defaults.Provider == "antigravity" || appCfg.Agents.Defaults.Provider == "google-antigravity" { + appCfg.Agents.Defaults.Model = "gemini-3-flash" + } + if err := config.SaveConfig(getConfigPath(), appCfg); err != nil { + fmt.Printf("Warning: could not update config: %v\n", err) + } + } + + fmt.Println("\nāœ“ Google Antigravity login successful!") + fmt.Println("Config updated: provider=antigravity, model=gemini-3-flash") + fmt.Println("Try it: picoclaw agent -m \"Hello world\"") +} + +func fetchGoogleUserEmail(accessToken string) (string, error) { + req, err := http.NewRequest("GET", "https://www.googleapis.com/oauth2/v2/userinfo", nil) + if err != nil { + return "", err + } + req.Header.Set("Authorization", "Bearer "+accessToken) + + client := &http.Client{Timeout: 10 * time.Second} + resp, err := client.Do(req) + if err != nil { + return "", err + } + defer resp.Body.Close() + + body, _ := io.ReadAll(resp.Body) + if resp.StatusCode != http.StatusOK { + return "", fmt.Errorf("userinfo request failed: %s", string(body)) + } + + var userInfo struct { + Email string `json:"email"` + } + if err := json.Unmarshal(body, &userInfo); err != nil { + return "", err + } + return userInfo.Email, nil +} + +func authLoginPasteToken(provider string) { + cred, err := auth.LoginPasteToken(provider, os.Stdin) + if err != nil { + fmt.Printf("Login failed: %v\n", err) + os.Exit(1) + } + + if err := auth.SetCredential(provider, cred); err != nil { + fmt.Printf("Failed to save credentials: %v\n", err) + os.Exit(1) + } + + appCfg, err := loadConfig() + if err == nil { + switch provider { + case "anthropic": + appCfg.Providers.Anthropic.AuthMethod = "token" + case "openai": + appCfg.Providers.OpenAI.AuthMethod = "token" + } + if err := config.SaveConfig(getConfigPath(), appCfg); err != nil { + fmt.Printf("Warning: could not update config: %v\n", err) + } + } + + fmt.Printf("Token saved for %s!\n", provider) +} + +func authLogoutCmd() { + provider := "" + + args := os.Args[3:] + for i := 0; i < len(args); i++ { + switch args[i] { + case "--provider", "-p": + if i+1 < len(args) { + provider = args[i+1] + i++ + } + } + } + + if provider != "" { + if err := auth.DeleteCredential(provider); err != nil { + fmt.Printf("Failed to remove credentials: %v\n", err) + os.Exit(1) + } + + appCfg, err := loadConfig() + if err == nil { + switch provider { + case "openai": + appCfg.Providers.OpenAI.AuthMethod = "" + case "anthropic": + appCfg.Providers.Anthropic.AuthMethod = "" + case "google-antigravity", "antigravity": + appCfg.Providers.Antigravity.AuthMethod = "" + } + config.SaveConfig(getConfigPath(), appCfg) + } + + fmt.Printf("Logged out from %s\n", provider) + } else { + if err := auth.DeleteAllCredentials(); err != nil { + fmt.Printf("Failed to remove credentials: %v\n", err) + os.Exit(1) + } + + appCfg, err := loadConfig() + if err == nil { + appCfg.Providers.OpenAI.AuthMethod = "" + appCfg.Providers.Anthropic.AuthMethod = "" + appCfg.Providers.Antigravity.AuthMethod = "" + config.SaveConfig(getConfigPath(), appCfg) + } + + fmt.Println("Logged out from all providers") + } +} + +func authStatusCmd() { + store, err := auth.LoadStore() + if err != nil { + fmt.Printf("Error loading auth store: %v\n", err) + return + } + + if len(store.Credentials) == 0 { + fmt.Println("No authenticated providers.") + fmt.Println("Run: picoclaw auth login --provider ") + return + } + + fmt.Println("\nAuthenticated Providers:") + fmt.Println("------------------------") + for provider, cred := range store.Credentials { + status := "active" + if cred.IsExpired() { + status = "expired" + } else if cred.NeedsRefresh() { + status = "needs refresh" + } + + fmt.Printf(" %s:\n", provider) + fmt.Printf(" Method: %s\n", cred.AuthMethod) + fmt.Printf(" Status: %s\n", status) + if cred.AccountID != "" { + fmt.Printf(" Account: %s\n", cred.AccountID) + } + if cred.Email != "" { + fmt.Printf(" Email: %s\n", cred.Email) + } + if cred.ProjectID != "" { + fmt.Printf(" Project: %s\n", cred.ProjectID) + } + if !cred.ExpiresAt.IsZero() { + fmt.Printf(" Expires: %s\n", cred.ExpiresAt.Format("2006-01-02 15:04")) + } + } +} + +func authModelsCmd() { + cred, err := auth.GetCredential("google-antigravity") + if err != nil || cred == nil { + fmt.Println("Not logged in to Google Antigravity.") + fmt.Println("Run: picoclaw auth login --provider google-antigravity") + return + } + + // Refresh token if needed + if cred.NeedsRefresh() && cred.RefreshToken != "" { + oauthCfg := auth.GoogleAntigravityOAuthConfig() + refreshed, refreshErr := auth.RefreshAccessToken(cred, oauthCfg) + if refreshErr == nil { + cred = refreshed + _ = auth.SetCredential("google-antigravity", cred) + } + } + + projectID := cred.ProjectID + if projectID == "" { + fmt.Println("No project ID stored. Try logging in again.") + return + } + + fmt.Printf("Fetching models for project: %s\n\n", projectID) + + models, err := providers.FetchAntigravityModels(cred.AccessToken, projectID) + if err != nil { + fmt.Printf("Error fetching models: %v\n", err) + return + } + + if len(models) == 0 { + fmt.Println("No models available.") + return + } + + fmt.Println("Available Antigravity Models:") + fmt.Println("-----------------------------") + for _, m := range models { + status := "āœ“" + if m.IsExhausted { + status = "āœ— (quota exhausted)" + } + name := m.ID + if m.DisplayName != "" { + name = fmt.Sprintf("%s (%s)", m.ID, m.DisplayName) + } + fmt.Printf(" %s %s\n", status, name) + } +} diff --git a/cmd/picoclaw/cmd_cron.go b/cmd/picoclaw/cmd_cron.go new file mode 100644 index 000000000..8c42bde06 --- /dev/null +++ b/cmd/picoclaw/cmd_cron.go @@ -0,0 +1,227 @@ +// PicoClaw - Ultra-lightweight personal AI agent +// License: MIT + +package main + +import ( + "fmt" + "os" + "path/filepath" + "time" + + "github.com/sipeed/picoclaw/pkg/cron" +) + +func cronCmd() { + if len(os.Args) < 3 { + cronHelp() + return + } + + subcommand := os.Args[2] + + // Load config to get workspace path + cfg, err := loadConfig() + if err != nil { + fmt.Printf("Error loading config: %v\n", err) + return + } + + cronStorePath := filepath.Join(cfg.WorkspacePath(), "cron", "jobs.json") + + switch subcommand { + case "list": + cronListCmd(cronStorePath) + case "add": + cronAddCmd(cronStorePath) + case "remove": + if len(os.Args) < 4 { + fmt.Println("Usage: picoclaw cron remove ") + return + } + cronRemoveCmd(cronStorePath, os.Args[3]) + case "enable": + cronEnableCmd(cronStorePath, false) + case "disable": + cronEnableCmd(cronStorePath, true) + default: + fmt.Printf("Unknown cron command: %s\n", subcommand) + cronHelp() + } +} + +func cronHelp() { + fmt.Println("\nCron commands:") + fmt.Println(" list List all scheduled jobs") + fmt.Println(" add Add a new scheduled job") + fmt.Println(" remove Remove a job by ID") + fmt.Println(" enable Enable a job") + fmt.Println(" disable Disable a job") + fmt.Println() + fmt.Println("Add options:") + fmt.Println(" -n, --name Job name") + fmt.Println(" -m, --message Message for agent") + fmt.Println(" -e, --every Run every N seconds") + fmt.Println(" -c, --cron Cron expression (e.g. '0 9 * * *')") + fmt.Println(" -d, --deliver Deliver response to channel") + fmt.Println(" --to Recipient for delivery") + fmt.Println(" --channel Channel for delivery") +} + +func cronListCmd(storePath string) { + cs := cron.NewCronService(storePath, nil) + jobs := cs.ListJobs(true) // Show all jobs, including disabled + + if len(jobs) == 0 { + fmt.Println("No scheduled jobs.") + return + } + + fmt.Println("\nScheduled Jobs:") + fmt.Println("----------------") + for _, job := range jobs { + var schedule string + if job.Schedule.Kind == "every" && job.Schedule.EveryMS != nil { + schedule = fmt.Sprintf("every %ds", *job.Schedule.EveryMS/1000) + } else if job.Schedule.Kind == "cron" { + schedule = job.Schedule.Expr + } else { + schedule = "one-time" + } + + nextRun := "scheduled" + if job.State.NextRunAtMS != nil { + nextTime := time.UnixMilli(*job.State.NextRunAtMS) + nextRun = nextTime.Format("2006-01-02 15:04") + } + + status := "enabled" + if !job.Enabled { + status = "disabled" + } + + fmt.Printf(" %s (%s)\n", job.Name, job.ID) + fmt.Printf(" Schedule: %s\n", schedule) + fmt.Printf(" Status: %s\n", status) + fmt.Printf(" Next run: %s\n", nextRun) + } +} + +func cronAddCmd(storePath string) { + name := "" + message := "" + var everySec *int64 + cronExpr := "" + deliver := false + channel := "" + to := "" + + args := os.Args[3:] + for i := 0; i < len(args); i++ { + switch args[i] { + case "-n", "--name": + if i+1 < len(args) { + name = args[i+1] + i++ + } + case "-m", "--message": + if i+1 < len(args) { + message = args[i+1] + i++ + } + case "-e", "--every": + if i+1 < len(args) { + var sec int64 + fmt.Sscanf(args[i+1], "%d", &sec) + everySec = &sec + i++ + } + case "-c", "--cron": + if i+1 < len(args) { + cronExpr = args[i+1] + i++ + } + case "-d", "--deliver": + deliver = true + case "--to": + if i+1 < len(args) { + to = args[i+1] + i++ + } + case "--channel": + if i+1 < len(args) { + channel = args[i+1] + i++ + } + } + } + + if name == "" { + fmt.Println("Error: --name is required") + return + } + + if message == "" { + fmt.Println("Error: --message is required") + return + } + + if everySec == nil && cronExpr == "" { + fmt.Println("Error: Either --every or --cron must be specified") + return + } + + var schedule cron.CronSchedule + if everySec != nil { + everyMS := *everySec * 1000 + schedule = cron.CronSchedule{ + Kind: "every", + EveryMS: &everyMS, + } + } else { + schedule = cron.CronSchedule{ + Kind: "cron", + Expr: cronExpr, + } + } + + cs := cron.NewCronService(storePath, nil) + job, err := cs.AddJob(name, schedule, message, deliver, channel, to) + if err != nil { + fmt.Printf("Error adding job: %v\n", err) + return + } + + fmt.Printf("āœ“ Added job '%s' (%s)\n", job.Name, job.ID) +} + +func cronRemoveCmd(storePath, jobID string) { + cs := cron.NewCronService(storePath, nil) + if cs.RemoveJob(jobID) { + fmt.Printf("āœ“ Removed job %s\n", jobID) + } else { + fmt.Printf("āœ— Job %s not found\n", jobID) + } +} + +func cronEnableCmd(storePath string, disable bool) { + if len(os.Args) < 4 { + fmt.Println("Usage: picoclaw cron enable/disable ") + return + } + + jobID := os.Args[3] + cs := cron.NewCronService(storePath, nil) + enabled := !disable + + job := cs.EnableJob(jobID, enabled) + if job != nil { + status := "enabled" + if disable { + status = "disabled" + } + fmt.Printf("āœ“ Job '%s' %s\n", job.Name, status) + } else { + fmt.Printf("āœ— Job %s not found\n", jobID) + } +} diff --git a/cmd/picoclaw/cmd_gateway.go b/cmd/picoclaw/cmd_gateway.go new file mode 100644 index 000000000..a64c1219f --- /dev/null +++ b/cmd/picoclaw/cmd_gateway.go @@ -0,0 +1,222 @@ +// PicoClaw - Ultra-lightweight personal AI agent +// License: MIT + +package main + +import ( + "context" + "fmt" + "net/http" + "os" + "os/signal" + "path/filepath" + "time" + + "github.com/sipeed/picoclaw/pkg/agent" + "github.com/sipeed/picoclaw/pkg/bus" + "github.com/sipeed/picoclaw/pkg/channels" + "github.com/sipeed/picoclaw/pkg/cron" + "github.com/sipeed/picoclaw/pkg/devices" + "github.com/sipeed/picoclaw/pkg/health" + "github.com/sipeed/picoclaw/pkg/heartbeat" + "github.com/sipeed/picoclaw/pkg/logger" + "github.com/sipeed/picoclaw/pkg/providers" + "github.com/sipeed/picoclaw/pkg/state" + "github.com/sipeed/picoclaw/pkg/tools" + "github.com/sipeed/picoclaw/pkg/voice" +) + +func gatewayCmd() { + // Check for --debug flag + args := os.Args[2:] + for _, arg := range args { + if arg == "--debug" || arg == "-d" { + logger.SetLevel(logger.DEBUG) + fmt.Println("šŸ” Debug mode enabled") + break + } + } + + cfg, err := loadConfig() + if err != nil { + fmt.Printf("Error loading config: %v\n", err) + os.Exit(1) + } + + provider, modelID, err := providers.CreateProvider(cfg) + if err != nil { + fmt.Printf("Error creating provider: %v\n", err) + os.Exit(1) + } + // Use the resolved model ID from provider creation + if modelID != "" { + cfg.Agents.Defaults.Model = modelID + } + + msgBus := bus.NewMessageBus() + agentLoop := agent.NewAgentLoop(cfg, msgBus, provider) + + // Print agent startup info + fmt.Println("\nšŸ“¦ Agent Status:") + startupInfo := agentLoop.GetStartupInfo() + toolsInfo := startupInfo["tools"].(map[string]interface{}) + skillsInfo := startupInfo["skills"].(map[string]interface{}) + fmt.Printf(" • Tools: %d loaded\n", toolsInfo["count"]) + fmt.Printf(" • Skills: %d/%d available\n", + skillsInfo["available"], + skillsInfo["total"]) + + // Log to file as well + logger.InfoCF("agent", "Agent initialized", + map[string]interface{}{ + "tools_count": toolsInfo["count"], + "skills_total": skillsInfo["total"], + "skills_available": skillsInfo["available"], + }) + + // Setup cron tool and service + execTimeout := time.Duration(cfg.Tools.Cron.ExecTimeoutMinutes) * time.Minute + cronService := setupCronTool(agentLoop, msgBus, cfg.WorkspacePath(), cfg.Agents.Defaults.RestrictToWorkspace, execTimeout) + + heartbeatService := heartbeat.NewHeartbeatService( + cfg.WorkspacePath(), + cfg.Heartbeat.Interval, + cfg.Heartbeat.Enabled, + ) + heartbeatService.SetBus(msgBus) + heartbeatService.SetHandler(func(prompt, channel, chatID string) *tools.ToolResult { + // Use cli:direct as fallback if no valid channel + if channel == "" || chatID == "" { + channel, chatID = "cli", "direct" + } + // Use ProcessHeartbeat - no session history, each heartbeat is independent + response, err := agentLoop.ProcessHeartbeat(context.Background(), prompt, channel, chatID) + if err != nil { + return tools.ErrorResult(fmt.Sprintf("Heartbeat error: %v", err)) + } + if response == "HEARTBEAT_OK" { + return tools.SilentResult("Heartbeat OK") + } + // For heartbeat, always return silent - the subagent result will be + // sent to user via processSystemMessage when the async task completes + return tools.SilentResult(response) + }) + + channelManager, err := channels.NewManager(cfg, msgBus) + if err != nil { + fmt.Printf("Error creating channel manager: %v\n", err) + os.Exit(1) + } + + // Inject channel manager into agent loop for command handling + agentLoop.SetChannelManager(channelManager) + + var transcriber *voice.GroqTranscriber + if cfg.Providers.Groq.APIKey != "" { + transcriber = voice.NewGroqTranscriber(cfg.Providers.Groq.APIKey) + logger.InfoC("voice", "Groq voice transcription enabled") + } + + if transcriber != nil { + if telegramChannel, ok := channelManager.GetChannel("telegram"); ok { + if tc, ok := telegramChannel.(*channels.TelegramChannel); ok { + tc.SetTranscriber(transcriber) + logger.InfoC("voice", "Groq transcription attached to Telegram channel") + } + } + if discordChannel, ok := channelManager.GetChannel("discord"); ok { + if dc, ok := discordChannel.(*channels.DiscordChannel); ok { + dc.SetTranscriber(transcriber) + logger.InfoC("voice", "Groq transcription attached to Discord channel") + } + } + if slackChannel, ok := channelManager.GetChannel("slack"); ok { + if sc, ok := slackChannel.(*channels.SlackChannel); ok { + sc.SetTranscriber(transcriber) + logger.InfoC("voice", "Groq transcription attached to Slack channel") + } + } + } + + enabledChannels := channelManager.GetEnabledChannels() + if len(enabledChannels) > 0 { + fmt.Printf("āœ“ Channels enabled: %s\n", enabledChannels) + } else { + fmt.Println("⚠ Warning: No channels enabled") + } + + fmt.Printf("āœ“ Gateway started on %s:%d\n", cfg.Gateway.Host, cfg.Gateway.Port) + fmt.Println("Press Ctrl+C to stop") + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + if err := cronService.Start(); err != nil { + fmt.Printf("Error starting cron service: %v\n", err) + } + fmt.Println("āœ“ Cron service started") + + if err := heartbeatService.Start(); err != nil { + fmt.Printf("Error starting heartbeat service: %v\n", err) + } + fmt.Println("āœ“ Heartbeat service started") + + stateManager := state.NewManager(cfg.WorkspacePath()) + deviceService := devices.NewService(devices.Config{ + Enabled: cfg.Devices.Enabled, + MonitorUSB: cfg.Devices.MonitorUSB, + }, stateManager) + deviceService.SetBus(msgBus) + if err := deviceService.Start(ctx); err != nil { + fmt.Printf("Error starting device service: %v\n", err) + } else if cfg.Devices.Enabled { + fmt.Println("āœ“ Device event service started") + } + + if err := channelManager.StartAll(ctx); err != nil { + fmt.Printf("Error starting channels: %v\n", err) + } + + healthServer := health.NewServer(cfg.Gateway.Host, cfg.Gateway.Port) + go func() { + if err := healthServer.Start(); err != nil && err != http.ErrServerClosed { + logger.ErrorCF("health", "Health server error", map[string]interface{}{"error": err.Error()}) + } + }() + fmt.Printf("āœ“ Health endpoints available at http://%s:%d/health and /ready\n", cfg.Gateway.Host, cfg.Gateway.Port) + + go agentLoop.Run(ctx) + + sigChan := make(chan os.Signal, 1) + signal.Notify(sigChan, os.Interrupt) + <-sigChan + + fmt.Println("\nShutting down...") + cancel() + healthServer.Stop(context.Background()) + deviceService.Stop() + heartbeatService.Stop() + cronService.Stop() + agentLoop.Stop() + channelManager.StopAll(ctx) + fmt.Println("āœ“ Gateway stopped") +} + +func setupCronTool(agentLoop *agent.AgentLoop, msgBus *bus.MessageBus, workspace string, restrict bool, execTimeout time.Duration) *cron.CronService { + cronStorePath := filepath.Join(workspace, "cron", "jobs.json") + + // Create cron service + cronService := cron.NewCronService(cronStorePath, nil) + + // Create and register CronTool + cronTool := tools.NewCronTool(cronService, agentLoop, msgBus, workspace, restrict, execTimeout) + agentLoop.RegisterTool(cronTool) + + // Set the onJob handler + cronService.SetOnJob(func(job *cron.CronJob) (string, error) { + result := cronTool.ExecuteJob(context.Background(), job) + return result, nil + }) + + return cronService +} diff --git a/cmd/picoclaw/cmd_migrate.go b/cmd/picoclaw/cmd_migrate.go new file mode 100644 index 000000000..86d4903ef --- /dev/null +++ b/cmd/picoclaw/cmd_migrate.go @@ -0,0 +1,81 @@ +// PicoClaw - Ultra-lightweight personal AI agent +// License: MIT + +package main + +import ( + "fmt" + "os" + + "github.com/sipeed/picoclaw/pkg/migrate" +) + +func migrateCmd() { + if len(os.Args) > 2 && (os.Args[2] == "--help" || os.Args[2] == "-h") { + migrateHelp() + return + } + + opts := migrate.Options{} + + args := os.Args[2:] + for i := 0; i < len(args); i++ { + switch args[i] { + case "--dry-run": + opts.DryRun = true + case "--config-only": + opts.ConfigOnly = true + case "--workspace-only": + opts.WorkspaceOnly = true + case "--force": + opts.Force = true + case "--refresh": + opts.Refresh = true + case "--openclaw-home": + if i+1 < len(args) { + opts.OpenClawHome = args[i+1] + i++ + } + case "--picoclaw-home": + if i+1 < len(args) { + opts.PicoClawHome = args[i+1] + i++ + } + default: + fmt.Printf("Unknown flag: %s\n", args[i]) + migrateHelp() + os.Exit(1) + } + } + + result, err := migrate.Run(opts) + if err != nil { + fmt.Printf("Error: %v\n", err) + os.Exit(1) + } + + if !opts.DryRun { + migrate.PrintSummary(result) + } +} + +func migrateHelp() { + fmt.Println("\nMigrate from OpenClaw to PicoClaw") + fmt.Println() + fmt.Println("Usage: picoclaw migrate [options]") + fmt.Println() + fmt.Println("Options:") + fmt.Println(" --dry-run Show what would be migrated without making changes") + fmt.Println(" --refresh Re-sync workspace files from OpenClaw (repeatable)") + fmt.Println(" --config-only Only migrate config, skip workspace files") + fmt.Println(" --workspace-only Only migrate workspace files, skip config") + fmt.Println(" --force Skip confirmation prompts") + fmt.Println(" --openclaw-home Override OpenClaw home directory (default: ~/.openclaw)") + fmt.Println(" --picoclaw-home Override PicoClaw home directory (default: ~/.picoclaw)") + fmt.Println() + fmt.Println("Examples:") + fmt.Println(" picoclaw migrate Detect and migrate from OpenClaw") + fmt.Println(" picoclaw migrate --dry-run Show what would be migrated") + fmt.Println(" picoclaw migrate --refresh Re-sync workspace files") + fmt.Println(" picoclaw migrate --force Migrate without confirmation") +} diff --git a/cmd/picoclaw/cmd_onboard.go b/cmd/picoclaw/cmd_onboard.go new file mode 100644 index 000000000..9c1e9916f --- /dev/null +++ b/cmd/picoclaw/cmd_onboard.go @@ -0,0 +1,102 @@ +// PicoClaw - Ultra-lightweight personal AI agent +// License: MIT + +package main + +import ( + "embed" + "fmt" + "io/fs" + "os" + "path/filepath" + + "github.com/sipeed/picoclaw/pkg/config" +) + +//go:generate cp -r ../../workspace . +//go:embed workspace +var embeddedFiles embed.FS + +func onboard() { + configPath := getConfigPath() + + if _, err := os.Stat(configPath); err == nil { + fmt.Printf("Config already exists at %s\n", configPath) + fmt.Print("Overwrite? (y/n): ") + var response string + fmt.Scanln(&response) + if response != "y" { + fmt.Println("Aborted.") + return + } + } + + cfg := config.DefaultConfig() + if err := config.SaveConfig(configPath, cfg); err != nil { + fmt.Printf("Error saving config: %v\n", err) + os.Exit(1) + } + + workspace := cfg.WorkspacePath() + createWorkspaceTemplates(workspace) + + fmt.Printf("%s picoclaw is ready!\n", logo) + fmt.Println("\nNext steps:") + fmt.Println(" 1. Add your API key to", configPath) + fmt.Println(" Get one at: https://openrouter.ai/keys") + fmt.Println(" 2. Chat: picoclaw agent -m \"Hello!\"") +} + +func copyEmbeddedToTarget(targetDir string) error { + // Ensure target directory exists + if err := os.MkdirAll(targetDir, 0755); err != nil { + return fmt.Errorf("Failed to create target directory: %w", err) + } + + // Walk through all files in embed.FS + err := fs.WalkDir(embeddedFiles, "workspace", func(path string, d fs.DirEntry, err error) error { + if err != nil { + return err + } + + // Skip directories + if d.IsDir() { + return nil + } + + // Read embedded file + data, err := embeddedFiles.ReadFile(path) + if err != nil { + return fmt.Errorf("Failed to read embedded file %s: %w", path, err) + } + + new_path, err := filepath.Rel("workspace", path) + if err != nil { + return fmt.Errorf("Failed to get relative path for %s: %v\n", path, err) + } + + // Build target file path + targetPath := filepath.Join(targetDir, new_path) + + // Ensure target file's directory exists + if err := os.MkdirAll(filepath.Dir(targetPath), 0755); err != nil { + return fmt.Errorf("Failed to create directory %s: %w", filepath.Dir(targetPath), err) + } + + // Write file + if err := os.WriteFile(targetPath, data, 0644); err != nil { + return fmt.Errorf("Failed to write file %s: %w", targetPath, err) + } + + return nil + }) + + return err +} + +func createWorkspaceTemplates(workspace string) { + err := copyEmbeddedToTarget(workspace) + if err != nil { + fmt.Printf("Error copying workspace templates: %v\n", err) + } +} diff --git a/cmd/picoclaw/cmd_skills.go b/cmd/picoclaw/cmd_skills.go new file mode 100644 index 000000000..9ea38dcf6 --- /dev/null +++ b/cmd/picoclaw/cmd_skills.go @@ -0,0 +1,216 @@ +// PicoClaw - Ultra-lightweight personal AI agent +// License: MIT + +package main + +import ( + "context" + "fmt" + "os" + "path/filepath" + "strings" + "time" + + "github.com/sipeed/picoclaw/pkg/skills" +) + +func skillsHelp() { + fmt.Println("\nSkills commands:") + fmt.Println(" list List installed skills") + fmt.Println(" install Install skill from GitHub") + fmt.Println(" install-builtin Install all builtin skills to workspace") + fmt.Println(" list-builtin List available builtin skills") + fmt.Println(" remove Remove installed skill") + fmt.Println(" search Search available skills") + fmt.Println(" show Show skill details") + fmt.Println() + fmt.Println("Examples:") + fmt.Println(" picoclaw skills list") + fmt.Println(" picoclaw skills install sipeed/picoclaw-skills/weather") + fmt.Println(" picoclaw skills install-builtin") + fmt.Println(" picoclaw skills list-builtin") + fmt.Println(" picoclaw skills remove weather") +} + +func skillsListCmd(loader *skills.SkillsLoader) { + allSkills := loader.ListSkills() + + if len(allSkills) == 0 { + fmt.Println("No skills installed.") + return + } + + fmt.Println("\nInstalled Skills:") + fmt.Println("------------------") + for _, skill := range allSkills { + fmt.Printf(" āœ“ %s (%s)\n", skill.Name, skill.Source) + if skill.Description != "" { + fmt.Printf(" %s\n", skill.Description) + } + } +} + +func skillsInstallCmd(installer *skills.SkillInstaller) { + if len(os.Args) < 4 { + fmt.Println("Usage: picoclaw skills install ") + fmt.Println("Example: picoclaw skills install sipeed/picoclaw-skills/weather") + return + } + + repo := os.Args[3] + fmt.Printf("Installing skill from %s...\n", repo) + + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + if err := installer.InstallFromGitHub(ctx, repo); err != nil { + fmt.Printf("āœ— Failed to install skill: %v\n", err) + os.Exit(1) + } + + fmt.Printf("āœ“ Skill '%s' installed successfully!\n", filepath.Base(repo)) +} + +func skillsRemoveCmd(installer *skills.SkillInstaller, skillName string) { + fmt.Printf("Removing skill '%s'...\n", skillName) + + if err := installer.Uninstall(skillName); err != nil { + fmt.Printf("āœ— Failed to remove skill: %v\n", err) + os.Exit(1) + } + + fmt.Printf("āœ“ Skill '%s' removed successfully!\n", skillName) +} + +func skillsInstallBuiltinCmd(workspace string) { + builtinSkillsDir := "./picoclaw/skills" + workspaceSkillsDir := filepath.Join(workspace, "skills") + + fmt.Printf("Copying builtin skills to workspace...\n") + + skillsToInstall := []string{ + "weather", + "news", + "stock", + "calculator", + } + + for _, skillName := range skillsToInstall { + builtinPath := filepath.Join(builtinSkillsDir, skillName) + workspacePath := filepath.Join(workspaceSkillsDir, skillName) + + if _, err := os.Stat(builtinPath); err != nil { + fmt.Printf("⊘ Builtin skill '%s' not found: %v\n", skillName, err) + continue + } + + if err := os.MkdirAll(workspacePath, 0755); err != nil { + fmt.Printf("āœ— Failed to create directory for %s: %v\n", skillName, err) + continue + } + + if err := copyDirectory(builtinPath, workspacePath); err != nil { + fmt.Printf("āœ— Failed to copy %s: %v\n", skillName, err) + } + } + + fmt.Println("\nāœ“ All builtin skills installed!") + fmt.Println("Now you can use them in your workspace.") +} + +func skillsListBuiltinCmd() { + cfg, err := loadConfig() + if err != nil { + fmt.Printf("Error loading config: %v\n", err) + return + } + builtinSkillsDir := filepath.Join(filepath.Dir(cfg.WorkspacePath()), "picoclaw", "skills") + + fmt.Println("\nAvailable Builtin Skills:") + fmt.Println("-----------------------") + + entries, err := os.ReadDir(builtinSkillsDir) + if err != nil { + fmt.Printf("Error reading builtin skills: %v\n", err) + return + } + + if len(entries) == 0 { + fmt.Println("No builtin skills available.") + return + } + + for _, entry := range entries { + if entry.IsDir() { + skillName := entry.Name() + skillFile := filepath.Join(builtinSkillsDir, skillName, "SKILL.md") + + description := "No description" + if _, err := os.Stat(skillFile); err == nil { + data, err := os.ReadFile(skillFile) + if err == nil { + content := string(data) + if idx := strings.Index(content, "\n"); idx > 0 { + firstLine := content[:idx] + if strings.Contains(firstLine, "description:") { + descLine := strings.Index(content[idx:], "\n") + if descLine > 0 { + description = strings.TrimSpace(content[idx+descLine : idx+descLine]) + } + } + } + } + } + status := "āœ“" + fmt.Printf(" %s %s\n", status, entry.Name()) + if description != "" { + fmt.Printf(" %s\n", description) + } + } + } +} + +func skillsSearchCmd(installer *skills.SkillInstaller) { + fmt.Println("Searching for available skills...") + + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + availableSkills, err := installer.ListAvailableSkills(ctx) + if err != nil { + fmt.Printf("āœ— Failed to fetch skills list: %v\n", err) + return + } + + if len(availableSkills) == 0 { + fmt.Println("No skills available.") + return + } + + fmt.Printf("\nAvailable Skills (%d):\n", len(availableSkills)) + fmt.Println("--------------------") + for _, skill := range availableSkills { + fmt.Printf(" šŸ“¦ %s\n", skill.Name) + fmt.Printf(" %s\n", skill.Description) + fmt.Printf(" Repo: %s\n", skill.Repository) + if skill.Author != "" { + fmt.Printf(" Author: %s\n", skill.Author) + } + if len(skill.Tags) > 0 { + fmt.Printf(" Tags: %v\n", skill.Tags) + } + fmt.Println() + } +} + +func skillsShowCmd(loader *skills.SkillsLoader, skillName string) { + content, ok := loader.LoadSkill(skillName) + if !ok { + fmt.Printf("āœ— Skill '%s' not found\n", skillName) + return + } + + fmt.Printf("\nšŸ“¦ Skill: %s\n", skillName) + fmt.Println("----------------------") + fmt.Println(content) +} diff --git a/cmd/picoclaw/cmd_status.go b/cmd/picoclaw/cmd_status.go new file mode 100644 index 000000000..07296784e --- /dev/null +++ b/cmd/picoclaw/cmd_status.go @@ -0,0 +1,102 @@ +// PicoClaw - Ultra-lightweight personal AI agent +// License: MIT + +package main + +import ( + "fmt" + "os" + + "github.com/sipeed/picoclaw/pkg/auth" +) + +func statusCmd() { + cfg, err := loadConfig() + if err != nil { + fmt.Printf("Error loading config: %v\n", err) + return + } + + configPath := getConfigPath() + + fmt.Printf("%s picoclaw Status\n", logo) + fmt.Printf("Version: %s\n", formatVersion()) + build, _ := formatBuildInfo() + if build != "" { + fmt.Printf("Build: %s\n", build) + } + fmt.Println() + + if _, err := os.Stat(configPath); err == nil { + fmt.Println("Config:", configPath, "āœ“") + } else { + fmt.Println("Config:", configPath, "āœ—") + } + + workspace := cfg.WorkspacePath() + if _, err := os.Stat(workspace); err == nil { + fmt.Println("Workspace:", workspace, "āœ“") + } else { + fmt.Println("Workspace:", workspace, "āœ—") + } + + if _, err := os.Stat(configPath); err == nil { + fmt.Printf("Model: %s\n", cfg.Agents.Defaults.Model) + + hasOpenRouter := cfg.Providers.OpenRouter.APIKey != "" + hasAnthropic := cfg.Providers.Anthropic.APIKey != "" + hasOpenAI := cfg.Providers.OpenAI.APIKey != "" + hasGemini := cfg.Providers.Gemini.APIKey != "" + hasZhipu := cfg.Providers.Zhipu.APIKey != "" + hasQwen := cfg.Providers.Qwen.APIKey != "" + hasGroq := cfg.Providers.Groq.APIKey != "" + hasVLLM := cfg.Providers.VLLM.APIBase != "" + hasMoonshot := cfg.Providers.Moonshot.APIKey != "" + hasDeepSeek := cfg.Providers.DeepSeek.APIKey != "" + hasVolcEngine := cfg.Providers.VolcEngine.APIKey != "" + hasNvidia := cfg.Providers.Nvidia.APIKey != "" + hasOllama := cfg.Providers.Ollama.APIBase != "" + + status := func(enabled bool) string { + if enabled { + return "āœ“" + } + return "not set" + } + fmt.Println("OpenRouter API:", status(hasOpenRouter)) + fmt.Println("Anthropic API:", status(hasAnthropic)) + fmt.Println("OpenAI API:", status(hasOpenAI)) + fmt.Println("Gemini API:", status(hasGemini)) + fmt.Println("Zhipu API:", status(hasZhipu)) + fmt.Println("Qwen API:", status(hasQwen)) + fmt.Println("Groq API:", status(hasGroq)) + fmt.Println("Moonshot API:", status(hasMoonshot)) + fmt.Println("DeepSeek API:", status(hasDeepSeek)) + fmt.Println("VolcEngine API:", status(hasVolcEngine)) + fmt.Println("Nvidia API:", status(hasNvidia)) + if hasVLLM { + fmt.Printf("vLLM/Local: āœ“ %s\n", cfg.Providers.VLLM.APIBase) + } else { + fmt.Println("vLLM/Local: not set") + } + if hasOllama { + fmt.Printf("Ollama: āœ“ %s\n", cfg.Providers.Ollama.APIBase) + } else { + fmt.Println("Ollama: not set") + } + + store, _ := auth.LoadStore() + if store != nil && len(store.Credentials) > 0 { + fmt.Println("\nOAuth/Token Auth:") + for provider, cred := range store.Credentials { + status := "authenticated" + if cred.IsExpired() { + status = "expired" + } else if cred.NeedsRefresh() { + status = "needs refresh" + } + fmt.Printf(" %s (%s): %s\n", provider, cred.AuthMethod, status) + } + } + } +} diff --git a/cmd/picoclaw/main.go b/cmd/picoclaw/main.go index 33ad74255..ce9389417 100644 --- a/cmd/picoclaw/main.go +++ b/cmd/picoclaw/main.go @@ -7,44 +7,16 @@ package main import ( - "bufio" - "context" - "embed" - "encoding/json" "fmt" "io" - "io/fs" - "net/http" "os" - "os/signal" "path/filepath" "runtime" - "strings" - "time" - "github.com/chzyer/readline" - "github.com/sipeed/picoclaw/pkg/agent" - "github.com/sipeed/picoclaw/pkg/auth" - "github.com/sipeed/picoclaw/pkg/bus" - "github.com/sipeed/picoclaw/pkg/channels" "github.com/sipeed/picoclaw/pkg/config" - "github.com/sipeed/picoclaw/pkg/cron" - "github.com/sipeed/picoclaw/pkg/devices" - "github.com/sipeed/picoclaw/pkg/health" - "github.com/sipeed/picoclaw/pkg/heartbeat" - "github.com/sipeed/picoclaw/pkg/logger" - "github.com/sipeed/picoclaw/pkg/migrate" - "github.com/sipeed/picoclaw/pkg/providers" "github.com/sipeed/picoclaw/pkg/skills" - "github.com/sipeed/picoclaw/pkg/state" - "github.com/sipeed/picoclaw/pkg/tools" - "github.com/sipeed/picoclaw/pkg/voice" ) -//go:generate cp -r ../../workspace . -//go:embed workspace -var embeddedFiles embed.FS - var ( version = "dev" gitCommit string @@ -217,1388 +189,11 @@ func printHelp() { fmt.Println(" version Show version information") } -func onboard() { - configPath := getConfigPath() - - if _, err := os.Stat(configPath); err == nil { - fmt.Printf("Config already exists at %s\n", configPath) - fmt.Print("Overwrite? (y/n): ") - var response string - fmt.Scanln(&response) - if response != "y" { - fmt.Println("Aborted.") - return - } - } - - cfg := config.DefaultConfig() - if err := config.SaveConfig(configPath, cfg); err != nil { - fmt.Printf("Error saving config: %v\n", err) - os.Exit(1) - } - - workspace := cfg.WorkspacePath() - createWorkspaceTemplates(workspace) - - fmt.Printf("%s picoclaw is ready!\n", logo) - fmt.Println("\nNext steps:") - fmt.Println(" 1. Add your API key to", configPath) - fmt.Println(" Get one at: https://openrouter.ai/keys") - fmt.Println(" 2. Chat: picoclaw agent -m \"Hello!\"") -} - -func copyEmbeddedToTarget(targetDir string) error { - // Ensure target directory exists - if err := os.MkdirAll(targetDir, 0755); err != nil { - return fmt.Errorf("Failed to create target directory: %w", err) - } - - // Walk through all files in embed.FS - err := fs.WalkDir(embeddedFiles, "workspace", func(path string, d fs.DirEntry, err error) error { - if err != nil { - return err - } - - // Skip directories - if d.IsDir() { - return nil - } - - // Read embedded file - data, err := embeddedFiles.ReadFile(path) - if err != nil { - return fmt.Errorf("Failed to read embedded file %s: %w", path, err) - } - - new_path, err := filepath.Rel("workspace", path) - if err != nil { - return fmt.Errorf("Failed to get relative path for %s: %v\n", path, err) - } - - // Build target file path - targetPath := filepath.Join(targetDir, new_path) - - // Ensure target file's directory exists - if err := os.MkdirAll(filepath.Dir(targetPath), 0755); err != nil { - return fmt.Errorf("Failed to create directory %s: %w", filepath.Dir(targetPath), err) - } - - // Write file - if err := os.WriteFile(targetPath, data, 0644); err != nil { - return fmt.Errorf("Failed to write file %s: %w", targetPath, err) - } - - return nil - }) - - return err -} - -func createWorkspaceTemplates(workspace string) { - err := copyEmbeddedToTarget(workspace) - if err != nil { - fmt.Printf("Error copying workspace templates: %v\n", err) - } -} - -func migrateCmd() { - if len(os.Args) > 2 && (os.Args[2] == "--help" || os.Args[2] == "-h") { - migrateHelp() - return - } - - opts := migrate.Options{} - - args := os.Args[2:] - for i := 0; i < len(args); i++ { - switch args[i] { - case "--dry-run": - opts.DryRun = true - case "--config-only": - opts.ConfigOnly = true - case "--workspace-only": - opts.WorkspaceOnly = true - case "--force": - opts.Force = true - case "--refresh": - opts.Refresh = true - case "--openclaw-home": - if i+1 < len(args) { - opts.OpenClawHome = args[i+1] - i++ - } - case "--picoclaw-home": - if i+1 < len(args) { - opts.PicoClawHome = args[i+1] - i++ - } - default: - fmt.Printf("Unknown flag: %s\n", args[i]) - migrateHelp() - os.Exit(1) - } - } - - result, err := migrate.Run(opts) - if err != nil { - fmt.Printf("Error: %v\n", err) - os.Exit(1) - } - - if !opts.DryRun { - migrate.PrintSummary(result) - } -} - -func migrateHelp() { - fmt.Println("\nMigrate from OpenClaw to PicoClaw") - fmt.Println() - fmt.Println("Usage: picoclaw migrate [options]") - fmt.Println() - fmt.Println("Options:") - fmt.Println(" --dry-run Show what would be migrated without making changes") - fmt.Println(" --refresh Re-sync workspace files from OpenClaw (repeatable)") - fmt.Println(" --config-only Only migrate config, skip workspace files") - fmt.Println(" --workspace-only Only migrate workspace files, skip config") - fmt.Println(" --force Skip confirmation prompts") - fmt.Println(" --openclaw-home Override OpenClaw home directory (default: ~/.openclaw)") - fmt.Println(" --picoclaw-home Override PicoClaw home directory (default: ~/.picoclaw)") - fmt.Println() - fmt.Println("Examples:") - fmt.Println(" picoclaw migrate Detect and migrate from OpenClaw") - fmt.Println(" picoclaw migrate --dry-run Show what would be migrated") - fmt.Println(" picoclaw migrate --refresh Re-sync workspace files") - fmt.Println(" picoclaw migrate --force Migrate without confirmation") -} - -func agentCmd() { - message := "" - sessionKey := "cli:default" - modelOverride := "" - - args := os.Args[2:] - for i := 0; i < len(args); i++ { - switch args[i] { - case "--debug", "-d": - logger.SetLevel(logger.DEBUG) - fmt.Println("šŸ” Debug mode enabled") - case "-m", "--message": - if i+1 < len(args) { - message = args[i+1] - i++ - } - case "-s", "--session": - if i+1 < len(args) { - sessionKey = args[i+1] - i++ - } - case "--model", "-model": - if i+1 < len(args) { - modelOverride = args[i+1] - i++ - } - } - } - - cfg, err := loadConfig() - if err != nil { - fmt.Printf("Error loading config: %v\n", err) - os.Exit(1) - } - - if modelOverride != "" { - cfg.Agents.Defaults.Model = modelOverride - } - - provider, err := providers.CreateProvider(cfg) - if err != nil { - fmt.Printf("Error creating provider: %v\n", err) - os.Exit(1) - } - - msgBus := bus.NewMessageBus() - agentLoop := agent.NewAgentLoop(cfg, msgBus, provider) - - // Print agent startup info (only for interactive mode) - startupInfo := agentLoop.GetStartupInfo() - logger.InfoCF("agent", "Agent initialized", - map[string]interface{}{ - "tools_count": startupInfo["tools"].(map[string]interface{})["count"], - "skills_total": startupInfo["skills"].(map[string]interface{})["total"], - "skills_available": startupInfo["skills"].(map[string]interface{})["available"], - }) - - if message != "" { - ctx := context.Background() - response, err := agentLoop.ProcessDirect(ctx, message, sessionKey) - if err != nil { - fmt.Printf("Error: %v\n", err) - os.Exit(1) - } - fmt.Printf("\n%s %s\n", logo, response) - } else { - fmt.Printf("%s Interactive mode (Ctrl+C to exit)\n\n", logo) - interactiveMode(agentLoop, sessionKey) - } -} - -func interactiveMode(agentLoop *agent.AgentLoop, sessionKey string) { - prompt := fmt.Sprintf("%s You: ", logo) - - rl, err := readline.NewEx(&readline.Config{ - Prompt: prompt, - HistoryFile: filepath.Join(os.TempDir(), ".picoclaw_history"), - HistoryLimit: 100, - InterruptPrompt: "^C", - EOFPrompt: "exit", - }) - - if err != nil { - fmt.Printf("Error initializing readline: %v\n", err) - fmt.Println("Falling back to simple input mode...") - simpleInteractiveMode(agentLoop, sessionKey) - return - } - defer rl.Close() - - for { - line, err := rl.Readline() - if err != nil { - if err == readline.ErrInterrupt || err == io.EOF { - fmt.Println("\nGoodbye!") - return - } - fmt.Printf("Error reading input: %v\n", err) - continue - } - - input := strings.TrimSpace(line) - if input == "" { - continue - } - - if input == "exit" || input == "quit" { - fmt.Println("Goodbye!") - return - } - - ctx := context.Background() - response, err := agentLoop.ProcessDirect(ctx, input, sessionKey) - if err != nil { - fmt.Printf("Error: %v\n", err) - continue - } - - fmt.Printf("\n%s %s\n\n", logo, response) - } -} - -func simpleInteractiveMode(agentLoop *agent.AgentLoop, sessionKey string) { - reader := bufio.NewReader(os.Stdin) - for { - fmt.Print(fmt.Sprintf("%s You: ", logo)) - line, err := reader.ReadString('\n') - if err != nil { - if err == io.EOF { - fmt.Println("\nGoodbye!") - return - } - fmt.Printf("Error reading input: %v\n", err) - continue - } - - input := strings.TrimSpace(line) - if input == "" { - continue - } - - if input == "exit" || input == "quit" { - fmt.Println("Goodbye!") - return - } - - ctx := context.Background() - response, err := agentLoop.ProcessDirect(ctx, input, sessionKey) - if err != nil { - fmt.Printf("Error: %v\n", err) - continue - } - - fmt.Printf("\n%s %s\n\n", logo, response) - } -} - -func gatewayCmd() { - // Check for --debug flag - args := os.Args[2:] - for _, arg := range args { - if arg == "--debug" || arg == "-d" { - logger.SetLevel(logger.DEBUG) - fmt.Println("šŸ” Debug mode enabled") - break - } - } - - cfg, err := loadConfig() - if err != nil { - fmt.Printf("Error loading config: %v\n", err) - os.Exit(1) - } - - provider, err := providers.CreateProvider(cfg) - if err != nil { - fmt.Printf("Error creating provider: %v\n", err) - os.Exit(1) - } - - msgBus := bus.NewMessageBus() - agentLoop := agent.NewAgentLoop(cfg, msgBus, provider) - - // Print agent startup info - fmt.Println("\nšŸ“¦ Agent Status:") - startupInfo := agentLoop.GetStartupInfo() - toolsInfo := startupInfo["tools"].(map[string]interface{}) - skillsInfo := startupInfo["skills"].(map[string]interface{}) - fmt.Printf(" • Tools: %d loaded\n", toolsInfo["count"]) - fmt.Printf(" • Skills: %d/%d available\n", - skillsInfo["available"], - skillsInfo["total"]) - - // Log to file as well - logger.InfoCF("agent", "Agent initialized", - map[string]interface{}{ - "tools_count": toolsInfo["count"], - "skills_total": skillsInfo["total"], - "skills_available": skillsInfo["available"], - }) - - // Setup cron tool and service - execTimeout := time.Duration(cfg.Tools.Cron.ExecTimeoutMinutes) * time.Minute - cronService := setupCronTool(agentLoop, msgBus, cfg.WorkspacePath(), cfg.Agents.Defaults.RestrictToWorkspace, execTimeout) - - heartbeatService := heartbeat.NewHeartbeatService( - cfg.WorkspacePath(), - cfg.Heartbeat.Interval, - cfg.Heartbeat.Enabled, - ) - heartbeatService.SetBus(msgBus) - heartbeatService.SetHandler(func(prompt, channel, chatID string) *tools.ToolResult { - // Use cli:direct as fallback if no valid channel - if channel == "" || chatID == "" { - channel, chatID = "cli", "direct" - } - // Use ProcessHeartbeat - no session history, each heartbeat is independent - response, err := agentLoop.ProcessHeartbeat(context.Background(), prompt, channel, chatID) - if err != nil { - return tools.ErrorResult(fmt.Sprintf("Heartbeat error: %v", err)) - } - if response == "HEARTBEAT_OK" { - return tools.SilentResult("Heartbeat OK") - } - // For heartbeat, always return silent - the subagent result will be - // sent to user via processSystemMessage when the async task completes - return tools.SilentResult(response) - }) - - channelManager, err := channels.NewManager(cfg, msgBus) - if err != nil { - fmt.Printf("Error creating channel manager: %v\n", err) - os.Exit(1) - } - - // Inject channel manager into agent loop for command handling - agentLoop.SetChannelManager(channelManager) - - var transcriber *voice.GroqTranscriber - if cfg.Providers.Groq.APIKey != "" { - transcriber = voice.NewGroqTranscriber(cfg.Providers.Groq.APIKey) - logger.InfoC("voice", "Groq voice transcription enabled") - } - - if transcriber != nil { - if telegramChannel, ok := channelManager.GetChannel("telegram"); ok { - if tc, ok := telegramChannel.(*channels.TelegramChannel); ok { - tc.SetTranscriber(transcriber) - logger.InfoC("voice", "Groq transcription attached to Telegram channel") - } - } - if discordChannel, ok := channelManager.GetChannel("discord"); ok { - if dc, ok := discordChannel.(*channels.DiscordChannel); ok { - dc.SetTranscriber(transcriber) - logger.InfoC("voice", "Groq transcription attached to Discord channel") - } - } - if slackChannel, ok := channelManager.GetChannel("slack"); ok { - if sc, ok := slackChannel.(*channels.SlackChannel); ok { - sc.SetTranscriber(transcriber) - logger.InfoC("voice", "Groq transcription attached to Slack channel") - } - } - } - - enabledChannels := channelManager.GetEnabledChannels() - if len(enabledChannels) > 0 { - fmt.Printf("āœ“ Channels enabled: %s\n", enabledChannels) - } else { - fmt.Println("⚠ Warning: No channels enabled") - } - - fmt.Printf("āœ“ Gateway started on %s:%d\n", cfg.Gateway.Host, cfg.Gateway.Port) - fmt.Println("Press Ctrl+C to stop") - - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - if err := cronService.Start(); err != nil { - fmt.Printf("Error starting cron service: %v\n", err) - } - fmt.Println("āœ“ Cron service started") - - if err := heartbeatService.Start(); err != nil { - fmt.Printf("Error starting heartbeat service: %v\n", err) - } - fmt.Println("āœ“ Heartbeat service started") - - stateManager := state.NewManager(cfg.WorkspacePath()) - deviceService := devices.NewService(devices.Config{ - Enabled: cfg.Devices.Enabled, - MonitorUSB: cfg.Devices.MonitorUSB, - }, stateManager) - deviceService.SetBus(msgBus) - if err := deviceService.Start(ctx); err != nil { - fmt.Printf("Error starting device service: %v\n", err) - } else if cfg.Devices.Enabled { - fmt.Println("āœ“ Device event service started") - } - - if err := channelManager.StartAll(ctx); err != nil { - fmt.Printf("Error starting channels: %v\n", err) - } - - healthServer := health.NewServer(cfg.Gateway.Host, cfg.Gateway.Port) - go func() { - if err := healthServer.Start(); err != nil && err != http.ErrServerClosed { - logger.ErrorCF("health", "Health server error", map[string]interface{}{"error": err.Error()}) - } - }() - fmt.Printf("āœ“ Health endpoints available at http://%s:%d/health and /ready\n", cfg.Gateway.Host, cfg.Gateway.Port) - - go agentLoop.Run(ctx) - - sigChan := make(chan os.Signal, 1) - signal.Notify(sigChan, os.Interrupt) - <-sigChan - - fmt.Println("\nShutting down...") - cancel() - healthServer.Stop(context.Background()) - deviceService.Stop() - heartbeatService.Stop() - cronService.Stop() - agentLoop.Stop() - channelManager.StopAll(ctx) - fmt.Println("āœ“ Gateway stopped") -} - -func statusCmd() { - cfg, err := loadConfig() - if err != nil { - fmt.Printf("Error loading config: %v\n", err) - return - } - - configPath := getConfigPath() - - fmt.Printf("%s picoclaw Status\n", logo) - fmt.Printf("Version: %s\n", formatVersion()) - build, _ := formatBuildInfo() - if build != "" { - fmt.Printf("Build: %s\n", build) - } - fmt.Println() - - if _, err := os.Stat(configPath); err == nil { - fmt.Println("Config:", configPath, "āœ“") - } else { - fmt.Println("Config:", configPath, "āœ—") - } - - workspace := cfg.WorkspacePath() - if _, err := os.Stat(workspace); err == nil { - fmt.Println("Workspace:", workspace, "āœ“") - } else { - fmt.Println("Workspace:", workspace, "āœ—") - } - - if _, err := os.Stat(configPath); err == nil { - fmt.Printf("Model: %s\n", cfg.Agents.Defaults.Model) - - hasOpenRouter := cfg.Providers.OpenRouter.APIKey != "" - hasAnthropic := cfg.Providers.Anthropic.APIKey != "" - hasOpenAI := cfg.Providers.OpenAI.APIKey != "" - hasGemini := cfg.Providers.Gemini.APIKey != "" - hasZhipu := cfg.Providers.Zhipu.APIKey != "" - hasQwen := cfg.Providers.Qwen.APIKey != "" - hasGroq := cfg.Providers.Groq.APIKey != "" - hasVLLM := cfg.Providers.VLLM.APIBase != "" - hasMoonshot := cfg.Providers.Moonshot.APIKey != "" - hasDeepSeek := cfg.Providers.DeepSeek.APIKey != "" - hasVolcEngine := cfg.Providers.VolcEngine.APIKey != "" - hasNvidia := cfg.Providers.Nvidia.APIKey != "" - hasOllama := cfg.Providers.Ollama.APIBase != "" - - status := func(enabled bool) string { - if enabled { - return "āœ“" - } - return "not set" - } - fmt.Println("OpenRouter API:", status(hasOpenRouter)) - fmt.Println("Anthropic API:", status(hasAnthropic)) - fmt.Println("OpenAI API:", status(hasOpenAI)) - fmt.Println("Gemini API:", status(hasGemini)) - fmt.Println("Zhipu API:", status(hasZhipu)) - fmt.Println("Qwen API:", status(hasQwen)) - fmt.Println("Groq API:", status(hasGroq)) - fmt.Println("Moonshot API:", status(hasMoonshot)) - fmt.Println("DeepSeek API:", status(hasDeepSeek)) - fmt.Println("VolcEngine API:", status(hasVolcEngine)) - fmt.Println("Nvidia API:", status(hasNvidia)) - if hasVLLM { - fmt.Printf("vLLM/Local: āœ“ %s\n", cfg.Providers.VLLM.APIBase) - } else { - fmt.Println("vLLM/Local: not set") - } - if hasOllama { - fmt.Printf("Ollama: āœ“ %s\n", cfg.Providers.Ollama.APIBase) - } else { - fmt.Println("Ollama: not set") - } - - store, _ := auth.LoadStore() - if store != nil && len(store.Credentials) > 0 { - fmt.Println("\nOAuth/Token Auth:") - for provider, cred := range store.Credentials { - status := "authenticated" - if cred.IsExpired() { - status = "expired" - } else if cred.NeedsRefresh() { - status = "needs refresh" - } - fmt.Printf(" %s (%s): %s\n", provider, cred.AuthMethod, status) - } - } - } -} - -func authCmd() { - if len(os.Args) < 3 { - authHelp() - return - } - - switch os.Args[2] { - case "login": - authLoginCmd() - case "logout": - authLogoutCmd() - case "status": - authStatusCmd() - case "models": - authModelsCmd() - default: - fmt.Printf("Unknown auth command: %s\n", os.Args[2]) - authHelp() - } -} - -func authHelp() { - fmt.Println("\nAuth commands:") - fmt.Println(" login Login via OAuth or paste token") - fmt.Println(" logout Remove stored credentials") - fmt.Println(" status Show current auth status") - fmt.Println(" models List available Antigravity models") - fmt.Println() - fmt.Println("Login options:") - fmt.Println(" --provider Provider to login with (openai, anthropic, google-antigravity)") - fmt.Println(" --device-code Use device code flow (for headless environments)") - fmt.Println() - fmt.Println("Examples:") - fmt.Println(" picoclaw auth login --provider openai") - fmt.Println(" picoclaw auth login --provider openai --device-code") - fmt.Println(" picoclaw auth login --provider anthropic") - fmt.Println(" picoclaw auth login --provider google-antigravity") - fmt.Println(" picoclaw auth models") - fmt.Println(" picoclaw auth logout --provider openai") - fmt.Println(" picoclaw auth status") -} - -func authLoginCmd() { - provider := "" - useDeviceCode := false - - args := os.Args[3:] - for i := 0; i < len(args); i++ { - switch args[i] { - case "--provider", "-p": - if i+1 < len(args) { - provider = args[i+1] - i++ - } - case "--device-code": - useDeviceCode = true - } - } - - if provider == "" { - fmt.Println("Error: --provider is required") - fmt.Println("Supported providers: openai, anthropic, google-antigravity") - return - } - - switch provider { - case "openai": - authLoginOpenAI(useDeviceCode) - case "anthropic": - authLoginPasteToken(provider) - case "google-antigravity", "antigravity": - authLoginGoogleAntigravity() - default: - fmt.Printf("Unsupported provider: %s\n", provider) - fmt.Println("Supported providers: openai, anthropic, google-antigravity") - } -} - -func authLoginOpenAI(useDeviceCode bool) { - cfg := auth.OpenAIOAuthConfig() - - var cred *auth.AuthCredential - var err error - - if useDeviceCode { - cred, err = auth.LoginDeviceCode(cfg) - } else { - cred, err = auth.LoginBrowser(cfg) - } - - if err != nil { - fmt.Printf("Login failed: %v\n", err) - os.Exit(1) - } - - if err := auth.SetCredential("openai", cred); err != nil { - fmt.Printf("Failed to save credentials: %v\n", err) - os.Exit(1) - } - - appCfg, err := loadConfig() - if err == nil { - appCfg.Providers.OpenAI.AuthMethod = "oauth" - if err := config.SaveConfig(getConfigPath(), appCfg); err != nil { - fmt.Printf("Warning: could not update config: %v\n", err) - } - } - - fmt.Println("Login successful!") - if cred.AccountID != "" { - fmt.Printf("Account: %s\n", cred.AccountID) - } -} - -func authLoginGoogleAntigravity() { - cfg := auth.GoogleAntigravityOAuthConfig() - - cred, err := auth.LoginBrowser(cfg) - if err != nil { - fmt.Printf("Login failed: %v\n", err) - os.Exit(1) - } - - cred.Provider = "google-antigravity" - - // Fetch user email from Google userinfo - email, err := fetchGoogleUserEmail(cred.AccessToken) - if err != nil { - fmt.Printf("Warning: could not fetch email: %v\n", err) - } else { - cred.Email = email - fmt.Printf("Email: %s\n", email) - } - - // Fetch Cloud Code Assist project ID - projectID, err := providers.FetchAntigravityProjectID(cred.AccessToken) - if err != nil { - fmt.Printf("Warning: could not fetch project ID: %v\n", err) - fmt.Println("You may need Google Cloud Code Assist enabled on your account.") - } else { - cred.ProjectID = projectID - fmt.Printf("Project: %s\n", projectID) - } - - if err := auth.SetCredential("google-antigravity", cred); err != nil { - fmt.Printf("Failed to save credentials: %v\n", err) - os.Exit(1) - } - - appCfg, err := loadConfig() - if err == nil { - appCfg.Providers.Antigravity.AuthMethod = "oauth" - if appCfg.Agents.Defaults.Provider == "" { - appCfg.Agents.Defaults.Provider = "antigravity" - } - if appCfg.Agents.Defaults.Provider == "antigravity" || appCfg.Agents.Defaults.Provider == "google-antigravity" { - appCfg.Agents.Defaults.Model = "gemini-3-flash" - } - if err := config.SaveConfig(getConfigPath(), appCfg); err != nil { - fmt.Printf("Warning: could not update config: %v\n", err) - } - } - - fmt.Println("\nāœ“ Google Antigravity login successful!") - fmt.Println("Config updated: provider=antigravity, model=gemini-3-flash") - fmt.Println("Try it: picoclaw agent -m \"Hello world\"") -} - -func fetchGoogleUserEmail(accessToken string) (string, error) { - req, err := http.NewRequest("GET", "https://www.googleapis.com/oauth2/v2/userinfo", nil) - if err != nil { - return "", err - } - req.Header.Set("Authorization", "Bearer "+accessToken) - - client := &http.Client{Timeout: 10 * time.Second} - resp, err := client.Do(req) - if err != nil { - return "", err - } - defer resp.Body.Close() - - body, _ := io.ReadAll(resp.Body) - if resp.StatusCode != http.StatusOK { - return "", fmt.Errorf("userinfo request failed: %s", string(body)) - } - - var userInfo struct { - Email string `json:"email"` - } - if err := json.Unmarshal(body, &userInfo); err != nil { - return "", err - } - return userInfo.Email, nil -} - -func authLoginPasteToken(provider string) { - cred, err := auth.LoginPasteToken(provider, os.Stdin) - if err != nil { - fmt.Printf("Login failed: %v\n", err) - os.Exit(1) - } - - if err := auth.SetCredential(provider, cred); err != nil { - fmt.Printf("Failed to save credentials: %v\n", err) - os.Exit(1) - } - - appCfg, err := loadConfig() - if err == nil { - switch provider { - case "anthropic": - appCfg.Providers.Anthropic.AuthMethod = "token" - case "openai": - appCfg.Providers.OpenAI.AuthMethod = "token" - } - if err := config.SaveConfig(getConfigPath(), appCfg); err != nil { - fmt.Printf("Warning: could not update config: %v\n", err) - } - } - - fmt.Printf("Token saved for %s!\n", provider) -} - -func authLogoutCmd() { - provider := "" - - args := os.Args[3:] - for i := 0; i < len(args); i++ { - switch args[i] { - case "--provider", "-p": - if i+1 < len(args) { - provider = args[i+1] - i++ - } - } - } - - if provider != "" { - if err := auth.DeleteCredential(provider); err != nil { - fmt.Printf("Failed to remove credentials: %v\n", err) - os.Exit(1) - } - - appCfg, err := loadConfig() - if err == nil { - switch provider { - case "openai": - appCfg.Providers.OpenAI.AuthMethod = "" - case "anthropic": - appCfg.Providers.Anthropic.AuthMethod = "" - case "google-antigravity", "antigravity": - appCfg.Providers.Antigravity.AuthMethod = "" - } - config.SaveConfig(getConfigPath(), appCfg) - } - - fmt.Printf("Logged out from %s\n", provider) - } else { - if err := auth.DeleteAllCredentials(); err != nil { - fmt.Printf("Failed to remove credentials: %v\n", err) - os.Exit(1) - } - - appCfg, err := loadConfig() - if err == nil { - appCfg.Providers.OpenAI.AuthMethod = "" - appCfg.Providers.Anthropic.AuthMethod = "" - appCfg.Providers.Antigravity.AuthMethod = "" - config.SaveConfig(getConfigPath(), appCfg) - } - - fmt.Println("Logged out from all providers") - } -} - -func authStatusCmd() { - store, err := auth.LoadStore() - if err != nil { - fmt.Printf("Error loading auth store: %v\n", err) - return - } - - if len(store.Credentials) == 0 { - fmt.Println("No authenticated providers.") - fmt.Println("Run: picoclaw auth login --provider ") - return - } - - fmt.Println("\nAuthenticated Providers:") - fmt.Println("------------------------") - for provider, cred := range store.Credentials { - status := "active" - if cred.IsExpired() { - status = "expired" - } else if cred.NeedsRefresh() { - status = "needs refresh" - } - - fmt.Printf(" %s:\n", provider) - fmt.Printf(" Method: %s\n", cred.AuthMethod) - fmt.Printf(" Status: %s\n", status) - if cred.AccountID != "" { - fmt.Printf(" Account: %s\n", cred.AccountID) - } - if cred.Email != "" { - fmt.Printf(" Email: %s\n", cred.Email) - } - if cred.ProjectID != "" { - fmt.Printf(" Project: %s\n", cred.ProjectID) - } - if !cred.ExpiresAt.IsZero() { - fmt.Printf(" Expires: %s\n", cred.ExpiresAt.Format("2006-01-02 15:04")) - } - } -} - -func authModelsCmd() { - cred, err := auth.GetCredential("google-antigravity") - if err != nil || cred == nil { - fmt.Println("Not logged in to Google Antigravity.") - fmt.Println("Run: picoclaw auth login --provider google-antigravity") - return - } - - // Refresh token if needed - if cred.NeedsRefresh() && cred.RefreshToken != "" { - oauthCfg := auth.GoogleAntigravityOAuthConfig() - refreshed, refreshErr := auth.RefreshAccessToken(cred, oauthCfg) - if refreshErr == nil { - cred = refreshed - _ = auth.SetCredential("google-antigravity", cred) - } - } - - projectID := cred.ProjectID - if projectID == "" { - fmt.Println("No project ID stored. Try logging in again.") - return - } - - fmt.Printf("Fetching models for project: %s\n\n", projectID) - - models, err := providers.FetchAntigravityModels(cred.AccessToken, projectID) - if err != nil { - fmt.Printf("Error fetching models: %v\n", err) - return - } - - if len(models) == 0 { - fmt.Println("No models available.") - return - } - - fmt.Println("Available Antigravity Models:") - fmt.Println("-----------------------------") - for _, m := range models { - status := "āœ“" - if m.IsExhausted { - status = "āœ— (quota exhausted)" - } - name := m.ID - if m.DisplayName != "" { - name = fmt.Sprintf("%s (%s)", m.ID, m.DisplayName) - } - fmt.Printf(" %s %s\n", status, name) - } -} - func getConfigPath() string { home, _ := os.UserHomeDir() return filepath.Join(home, ".picoclaw", "config.json") } -func setupCronTool(agentLoop *agent.AgentLoop, msgBus *bus.MessageBus, workspace string, restrict bool, execTimeout time.Duration) *cron.CronService { - cronStorePath := filepath.Join(workspace, "cron", "jobs.json") - - // Create cron service - cronService := cron.NewCronService(cronStorePath, nil) - - // Create and register CronTool - cronTool := tools.NewCronTool(cronService, agentLoop, msgBus, workspace, restrict, execTimeout) - agentLoop.RegisterTool(cronTool) - - // Set the onJob handler - cronService.SetOnJob(func(job *cron.CronJob) (string, error) { - result := cronTool.ExecuteJob(context.Background(), job) - return result, nil - }) - - return cronService -} - func loadConfig() (*config.Config, error) { return config.LoadConfig(getConfigPath()) } - -func cronCmd() { - if len(os.Args) < 3 { - cronHelp() - return - } - - subcommand := os.Args[2] - - // Load config to get workspace path - cfg, err := loadConfig() - if err != nil { - fmt.Printf("Error loading config: %v\n", err) - return - } - - cronStorePath := filepath.Join(cfg.WorkspacePath(), "cron", "jobs.json") - - switch subcommand { - case "list": - cronListCmd(cronStorePath) - case "add": - cronAddCmd(cronStorePath) - case "remove": - if len(os.Args) < 4 { - fmt.Println("Usage: picoclaw cron remove ") - return - } - cronRemoveCmd(cronStorePath, os.Args[3]) - case "enable": - cronEnableCmd(cronStorePath, false) - case "disable": - cronEnableCmd(cronStorePath, true) - default: - fmt.Printf("Unknown cron command: %s\n", subcommand) - cronHelp() - } -} - -func cronHelp() { - fmt.Println("\nCron commands:") - fmt.Println(" list List all scheduled jobs") - fmt.Println(" add Add a new scheduled job") - fmt.Println(" remove Remove a job by ID") - fmt.Println(" enable Enable a job") - fmt.Println(" disable Disable a job") - fmt.Println() - fmt.Println("Add options:") - fmt.Println(" -n, --name Job name") - fmt.Println(" -m, --message Message for agent") - fmt.Println(" -e, --every Run every N seconds") - fmt.Println(" -c, --cron Cron expression (e.g. '0 9 * * *')") - fmt.Println(" -d, --deliver Deliver response to channel") - fmt.Println(" --to Recipient for delivery") - fmt.Println(" --channel Channel for delivery") -} - -func cronListCmd(storePath string) { - cs := cron.NewCronService(storePath, nil) - jobs := cs.ListJobs(true) // Show all jobs, including disabled - - if len(jobs) == 0 { - fmt.Println("No scheduled jobs.") - return - } - - fmt.Println("\nScheduled Jobs:") - fmt.Println("----------------") - for _, job := range jobs { - var schedule string - if job.Schedule.Kind == "every" && job.Schedule.EveryMS != nil { - schedule = fmt.Sprintf("every %ds", *job.Schedule.EveryMS/1000) - } else if job.Schedule.Kind == "cron" { - schedule = job.Schedule.Expr - } else { - schedule = "one-time" - } - - nextRun := "scheduled" - if job.State.NextRunAtMS != nil { - nextTime := time.UnixMilli(*job.State.NextRunAtMS) - nextRun = nextTime.Format("2006-01-02 15:04") - } - - status := "enabled" - if !job.Enabled { - status = "disabled" - } - - fmt.Printf(" %s (%s)\n", job.Name, job.ID) - fmt.Printf(" Schedule: %s\n", schedule) - fmt.Printf(" Status: %s\n", status) - fmt.Printf(" Next run: %s\n", nextRun) - } -} - -func cronAddCmd(storePath string) { - name := "" - message := "" - var everySec *int64 - cronExpr := "" - deliver := false - channel := "" - to := "" - - args := os.Args[3:] - for i := 0; i < len(args); i++ { - switch args[i] { - case "-n", "--name": - if i+1 < len(args) { - name = args[i+1] - i++ - } - case "-m", "--message": - if i+1 < len(args) { - message = args[i+1] - i++ - } - case "-e", "--every": - if i+1 < len(args) { - var sec int64 - fmt.Sscanf(args[i+1], "%d", &sec) - everySec = &sec - i++ - } - case "-c", "--cron": - if i+1 < len(args) { - cronExpr = args[i+1] - i++ - } - case "-d", "--deliver": - deliver = true - case "--to": - if i+1 < len(args) { - to = args[i+1] - i++ - } - case "--channel": - if i+1 < len(args) { - channel = args[i+1] - i++ - } - } - } - - if name == "" { - fmt.Println("Error: --name is required") - return - } - - if message == "" { - fmt.Println("Error: --message is required") - return - } - - if everySec == nil && cronExpr == "" { - fmt.Println("Error: Either --every or --cron must be specified") - return - } - - var schedule cron.CronSchedule - if everySec != nil { - everyMS := *everySec * 1000 - schedule = cron.CronSchedule{ - Kind: "every", - EveryMS: &everyMS, - } - } else { - schedule = cron.CronSchedule{ - Kind: "cron", - Expr: cronExpr, - } - } - - cs := cron.NewCronService(storePath, nil) - job, err := cs.AddJob(name, schedule, message, deliver, channel, to) - if err != nil { - fmt.Printf("Error adding job: %v\n", err) - return - } - - fmt.Printf("āœ“ Added job '%s' (%s)\n", job.Name, job.ID) -} - -func cronRemoveCmd(storePath, jobID string) { - cs := cron.NewCronService(storePath, nil) - if cs.RemoveJob(jobID) { - fmt.Printf("āœ“ Removed job %s\n", jobID) - } else { - fmt.Printf("āœ— Job %s not found\n", jobID) - } -} - -func cronEnableCmd(storePath string, disable bool) { - if len(os.Args) < 4 { - fmt.Println("Usage: picoclaw cron enable/disable ") - return - } - - jobID := os.Args[3] - cs := cron.NewCronService(storePath, nil) - enabled := !disable - - job := cs.EnableJob(jobID, enabled) - if job != nil { - status := "enabled" - if disable { - status = "disabled" - } - fmt.Printf("āœ“ Job '%s' %s\n", job.Name, status) - } else { - fmt.Printf("āœ— Job %s not found\n", jobID) - } -} - -func skillsHelp() { - fmt.Println("\nSkills commands:") - fmt.Println(" list List installed skills") - fmt.Println(" install Install skill from GitHub") - fmt.Println(" install-builtin Install all builtin skills to workspace") - fmt.Println(" list-builtin List available builtin skills") - fmt.Println(" remove Remove installed skill") - fmt.Println(" search Search available skills") - fmt.Println(" show Show skill details") - fmt.Println() - fmt.Println("Examples:") - fmt.Println(" picoclaw skills list") - fmt.Println(" picoclaw skills install sipeed/picoclaw-skills/weather") - fmt.Println(" picoclaw skills install-builtin") - fmt.Println(" picoclaw skills list-builtin") - fmt.Println(" picoclaw skills remove weather") -} - -func skillsListCmd(loader *skills.SkillsLoader) { - allSkills := loader.ListSkills() - - if len(allSkills) == 0 { - fmt.Println("No skills installed.") - return - } - - fmt.Println("\nInstalled Skills:") - fmt.Println("------------------") - for _, skill := range allSkills { - fmt.Printf(" āœ“ %s (%s)\n", skill.Name, skill.Source) - if skill.Description != "" { - fmt.Printf(" %s\n", skill.Description) - } - } -} - -func skillsInstallCmd(installer *skills.SkillInstaller) { - if len(os.Args) < 4 { - fmt.Println("Usage: picoclaw skills install ") - fmt.Println("Example: picoclaw skills install sipeed/picoclaw-skills/weather") - return - } - - repo := os.Args[3] - fmt.Printf("Installing skill from %s...\n", repo) - - ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) - defer cancel() - - if err := installer.InstallFromGitHub(ctx, repo); err != nil { - fmt.Printf("āœ— Failed to install skill: %v\n", err) - os.Exit(1) - } - - fmt.Printf("āœ“ Skill '%s' installed successfully!\n", filepath.Base(repo)) -} - -func skillsRemoveCmd(installer *skills.SkillInstaller, skillName string) { - fmt.Printf("Removing skill '%s'...\n", skillName) - - if err := installer.Uninstall(skillName); err != nil { - fmt.Printf("āœ— Failed to remove skill: %v\n", err) - os.Exit(1) - } - - fmt.Printf("āœ“ Skill '%s' removed successfully!\n", skillName) -} - -func skillsInstallBuiltinCmd(workspace string) { - builtinSkillsDir := "./picoclaw/skills" - workspaceSkillsDir := filepath.Join(workspace, "skills") - - fmt.Printf("Copying builtin skills to workspace...\n") - - skillsToInstall := []string{ - "weather", - "news", - "stock", - "calculator", - } - - for _, skillName := range skillsToInstall { - builtinPath := filepath.Join(builtinSkillsDir, skillName) - workspacePath := filepath.Join(workspaceSkillsDir, skillName) - - if _, err := os.Stat(builtinPath); err != nil { - fmt.Printf("⊘ Builtin skill '%s' not found: %v\n", skillName, err) - continue - } - - if err := os.MkdirAll(workspacePath, 0755); err != nil { - fmt.Printf("āœ— Failed to create directory for %s: %v\n", skillName, err) - continue - } - - if err := copyDirectory(builtinPath, workspacePath); err != nil { - fmt.Printf("āœ— Failed to copy %s: %v\n", skillName, err) - } - } - - fmt.Println("\nāœ“ All builtin skills installed!") - fmt.Println("Now you can use them in your workspace.") -} - -func skillsListBuiltinCmd() { - cfg, err := loadConfig() - if err != nil { - fmt.Printf("Error loading config: %v\n", err) - return - } - builtinSkillsDir := filepath.Join(filepath.Dir(cfg.WorkspacePath()), "picoclaw", "skills") - - fmt.Println("\nAvailable Builtin Skills:") - fmt.Println("-----------------------") - - entries, err := os.ReadDir(builtinSkillsDir) - if err != nil { - fmt.Printf("Error reading builtin skills: %v\n", err) - return - } - - if len(entries) == 0 { - fmt.Println("No builtin skills available.") - return - } - - for _, entry := range entries { - if entry.IsDir() { - skillName := entry.Name() - skillFile := filepath.Join(builtinSkillsDir, skillName, "SKILL.md") - - description := "No description" - if _, err := os.Stat(skillFile); err == nil { - data, err := os.ReadFile(skillFile) - if err == nil { - content := string(data) - if idx := strings.Index(content, "\n"); idx > 0 { - firstLine := content[:idx] - if strings.Contains(firstLine, "description:") { - descLine := strings.Index(content[idx:], "\n") - if descLine > 0 { - description = strings.TrimSpace(content[idx+descLine : idx+descLine]) - } - } - } - } - } - status := "āœ“" - fmt.Printf(" %s %s\n", status, entry.Name()) - if description != "" { - fmt.Printf(" %s\n", description) - } - } - } -} - -func skillsSearchCmd(installer *skills.SkillInstaller) { - fmt.Println("Searching for available skills...") - - ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) - defer cancel() - - availableSkills, err := installer.ListAvailableSkills(ctx) - if err != nil { - fmt.Printf("āœ— Failed to fetch skills list: %v\n", err) - return - } - - if len(availableSkills) == 0 { - fmt.Println("No skills available.") - return - } - - fmt.Printf("\nAvailable Skills (%d):\n", len(availableSkills)) - fmt.Println("--------------------") - for _, skill := range availableSkills { - fmt.Printf(" šŸ“¦ %s\n", skill.Name) - fmt.Printf(" %s\n", skill.Description) - fmt.Printf(" Repo: %s\n", skill.Repository) - if skill.Author != "" { - fmt.Printf(" Author: %s\n", skill.Author) - } - if len(skill.Tags) > 0 { - fmt.Printf(" Tags: %v\n", skill.Tags) - } - fmt.Println() - } -} - -func skillsShowCmd(loader *skills.SkillsLoader, skillName string) { - content, ok := loader.LoadSkill(skillName) - if !ok { - fmt.Printf("āœ— Skill '%s' not found\n", skillName) - return - } - - fmt.Printf("\nšŸ“¦ Skill: %s\n", skillName) - fmt.Println("----------------------") - fmt.Println(content) -} diff --git a/docs/design/provider-refactoring-tests.md b/docs/design/provider-refactoring-tests.md new file mode 100644 index 000000000..fc6429278 --- /dev/null +++ b/docs/design/provider-refactoring-tests.md @@ -0,0 +1,179 @@ +# Provider Architecture Refactoring - Test Suite Summary + +> PRD: `tasks/prd-provider-refactoring.md` + +This document summarizes the complete test suite designed for the Provider architecture refactoring. + +## Test File Structure + +``` +pkg/ +ā”œā”€ā”€ config/ +│ ā”œā”€ā”€ model_config_test.go # US-001, US-002: ModelConfig struct and GetModelConfig tests +│ └── migration_test.go # US-003: Backward compatibility and migration tests +ā”œā”€ā”€ providers/ +│ ā”œā”€ā”€ registry_test.go # US-006: Load balancing tests +│ ā”œā”€ā”€ integration_test.go # E2E integration tests +│ └── factory/ +│ └── factory_test.go # US-004, US-005: Provider factory tests +``` + +--- + +## Test Case Checklist + +### 1. `pkg/config/model_config_test.go` - Configuration Parsing Tests + +| Test Name | Purpose | PRD Reference | +|-----------|---------|---------------| +| `TestModelConfig_Parsing` | Verify ModelConfig JSON parsing | US-001 | +| `TestModelConfig_ModelListInConfig` | Verify model_list parsing in Config | US-001 | +| `TestModelConfig_Validation` | Verify required field validation | US-001 | +| `TestConfig_GetModelConfig_Found` | Verify GetModelConfig finds model | US-002 | +| `TestConfig_GetModelConfig_NotFound` | Verify GetModelConfig returns error | US-002 | +| `TestConfig_GetModelConfig_EmptyModelList` | Verify empty model_list handling | US-002 | +| `TestConfig_BackwardCompatibility_ProvidersToModelList` | Verify old config conversion | US-003 | +| `TestConfig_DeprecationWarning` | Verify deprecation warning | US-003 | +| `TestModelConfig_ProtocolExtraction` | Verify protocol prefix extraction | US-004 | +| `TestConfig_ModelNameUniqueness` | Verify model_name uniqueness | US-001 | + +### 2. `pkg/config/migration_test.go` - Migration Tests + +| Test Name | Purpose | PRD Reference | +|-----------|---------|---------------| +| `TestConvertProvidersToModelList_OpenAI` | OpenAI config conversion | US-003 | +| `TestConvertProvidersToModelList_Anthropic` | Anthropic config conversion | US-003 | +| `TestConvertProvidersToModelList_MultipleProviders` | Multiple provider conversion | US-003 | +| `TestConvertProvidersToModelList_EmptyProviders` | Empty providers handling | US-003 | +| `TestConvertProvidersToModelList_GitHubCopilot` | GitHub Copilot conversion | US-003 | +| `TestConvertProvidersToModelList_Antigravity` | Antigravity conversion | US-003 | +| `TestGenerateModelName_*` | Model name generation | US-003 | +| `TestHasProvidersConfig_*` | Detect old config existence | US-003 | +| `TestValidateMigration_*` | Migration validation | US-003 | +| `TestMigrateConfig_DryRun` | Dry run migration | US-003 | +| `TestMigrateConfig_Actual` | Actual migration | US-003 | + +### 3. `pkg/providers/registry_test.go` - Load Balancing Tests + +| Test Name | Purpose | PRD Reference | +|-----------|---------|---------------| +| `TestModelRegistry_SingleConfig` | Single config returns same result | US-006 | +| `TestModelRegistry_RoundRobinSelection` | 3-config round-robin selection | US-006 | +| `TestModelRegistry_RoundRobinTwoConfigs` | 2-config round-robin selection | US-006 | +| `TestModelRegistry_ConcurrentAccess` | Concurrent access thread safety | US-006 | +| `TestModelRegistry_RaceDetection` | Data race detection | US-006 | +| `TestModelRegistry_ModelNotFound` | Model not found error | US-006 | +| `TestModelRegistry_EmptyRegistry` | Empty registry handling | US-006 | +| `TestModelRegistry_MultipleModels` | Multiple model registration | US-006 | +| `TestModelRegistry_MixedSingleAndMultiple` | Single/multiple config mix | US-006 | +| `TestModelRegistry_CaseSensitiveModelNames` | Case sensitivity | US-006 | + +### 4. `pkg/providers/factory/factory_test.go` - Provider Factory Tests + +| Test Name | Purpose | PRD Reference | +|-----------|---------|---------------| +| `TestCreateProviderFromConfig_OpenAI` | Create OpenAI provider | US-004 | +| `TestCreateProviderFromConfig_OpenAIDefault` | Default openai protocol | US-004 | +| `TestCreateProviderFromConfig_Anthropic` | Create Anthropic provider | US-004 | +| `TestCreateProviderFromConfig_Antigravity` | Create Antigravity provider | US-004 | +| `TestCreateProviderFromConfig_ClaudeCLI` | Create Claude CLI provider | US-004 | +| `TestCreateProviderFromConfig_CodexCLI` | Create Codex CLI provider | US-004 | +| `TestCreateProviderFromConfig_GitHubCopilot` | Create GitHub Copilot provider | US-004 | +| `TestCreateProviderFromConfig_UnknownProtocol` | Unknown protocol error handling | US-004 | +| `TestCreateProviderFromConfig_MissingAPIKey` | Missing API key error | US-004 | +| `TestExtractProtocol` | Protocol prefix extraction | US-004 | +| `TestCreateProvider_UsesModelList` | Create using model_list | US-005 | +| `TestCreateProvider_FallbackToProviders` | Fallback to providers | US-005 | +| `TestCreateProvider_PriorityModelListOverProviders` | model_list priority | US-005 | + +### 5. `pkg/providers/integration_test.go` - E2E Integration Tests + +| Test Name | Purpose | PRD Reference | +|-----------|---------|---------------| +| `TestE2E_OpenAICompatibleProvider_NoCodeChange` | Zero-code provider addition | Goal | +| `TestE2E_LoadBalancing_RoundRobin` | Load balancing actual effect | US-006 | +| `TestE2E_BackwardCompatibility_OldProvidersConfig` | Old config compatibility | US-003 | +| `TestE2E_ErrorHandling_ModelNotFound` | Model not found | FR-30 | +| `TestE2E_ErrorHandling_MissingAPIKey` | Missing API key | FR-31 | +| `TestE2E_ErrorHandling_InvalidAPIBase` | Invalid API base | FR-30 | +| `TestE2E_ToolCalls_OpenAICompatible` | Tool call support | - | +| `TestE2E_AntigravityProvider` | Antigravity provider | US-004 | +| `TestE2E_ClaudeCLIProvider` | Claude CLI provider | US-004 | + +### 6. Performance Tests + +| Test Name | Purpose | +|-----------|---------| +| `BenchmarkCreateProviderFromConfig` | Provider creation performance | +| `BenchmarkGetModelConfig` | Model lookup performance | +| `BenchmarkGetModelConfigParallel` | Concurrent lookup performance | + +--- + +## Running Tests + +```bash +# Run all tests +go test ./pkg/... -v + +# Run with data race detection +go test ./pkg/... -race + +# Run specific package tests +go test ./pkg/config -v +go test ./pkg/providers -v +go test ./pkg/providers/factory -v + +# Run E2E tests +go test ./pkg/providers -run TestE2E -v + +# Run performance tests +go test ./pkg/providers -bench=. -benchmem +``` + +--- + +## PRD Acceptance Criteria Mapping + +| PRD Acceptance Criteria | Test Cases | +|------------------------|------------| +| US-001: Add ModelConfig struct | `TestModelConfig_Parsing`, `TestModelConfig_Validation` | +| US-001: model_name unique | `TestConfig_ModelNameUniqueness` | +| US-002: GetModelConfig method | `TestConfig_GetModelConfig_*` | +| US-003: Auto-convert providers | `TestConvertProvidersToModelList_*` | +| US-003: Deprecation warning | `TestConfig_DeprecationWarning` | +| US-003: Existing tests pass | (existing test files unchanged) | +| US-004: Protocol prefix factory | `TestExtractProtocol`, `TestCreateProviderFromConfig_*` | +| US-004: Default prefix openai | `TestCreateProviderFromConfig_OpenAIDefault` | +| US-005: CreateProvider uses factory | `TestCreateProvider_*` | +| US-006: Round-robin selection | `TestModelRegistry_RoundRobin*` | +| US-006: Thread-safe atomic | `TestModelRegistry_RaceDetection` | + +--- + +## Recommended Implementation Order + +1. **Phase 1: Configuration Structure** (US-001, US-002) + - Implement `ModelConfig` struct + - Implement `GetModelConfig` method + - Run `model_config_test.go` + +2. **Phase 2: Protocol Factory** (US-004) + - Implement `CreateProviderFromConfig` + - Implement `ExtractProtocol` + - Run `factory_test.go` + +3. **Phase 3: Load Balancing** (US-006) + - Implement `ModelRegistry` + - Implement round-robin selection + - Run `registry_test.go` (with `-race`) + +4. **Phase 4: Backward Compatibility** (US-003, US-005) + - Implement `ConvertProvidersToModelList` + - Refactor `CreateProvider` + - Run `migration_test.go` + - Verify existing tests pass + +5. **Phase 5: E2E Verification** + - Run `integration_test.go` + - Manual testing with `config.example.json` diff --git a/docs/design/provider-refactoring.md b/docs/design/provider-refactoring.md new file mode 100644 index 000000000..ae60b89a1 --- /dev/null +++ b/docs/design/provider-refactoring.md @@ -0,0 +1,334 @@ +# Provider Architecture Refactoring Design + +> Issue: #283 +> Discussion: #122 +> Branch: feat/refactor-provider-by-protocol + +## 1. Current Problems + +### 1.1 Configuration Structure Issues + +**Current State**: Each Provider requires a predefined field in `ProvidersConfig` + +```go +type ProvidersConfig struct { + Anthropic ProviderConfig `json:"anthropic"` + OpenAI ProviderConfig `json:"openai"` + DeepSeek ProviderConfig `json:"deepseek"` + Qwen ProviderConfig `json:"qwen"` + Cerebras ProviderConfig `json:"cerebras"` + VolcEngine ProviderConfig `json:"volcengine"` + // ... every new provider requires changes here +} +``` + +**Problems**: +- Adding a new Provider requires modifying Go code (struct definition) +- `CreateProvider` function in `http_provider.go` has 200+ lines of switch-case +- Most Providers are OpenAI-compatible, but code is duplicated + +### 1.2 Code Bloat Trend + +Recent PRs demonstrate this issue: + +| PR | Provider | Code Changes | +|----|----------|--------------| +| #365 | Qwen | +17 lines to http_provider.go | +| #333 | Cerebras | +17 lines to http_provider.go | +| #368 | Volcengine | +18 lines to http_provider.go | + +Each OpenAI-compatible Provider requires: +1. Modify `config.go` to add configuration field +2. Modify `http_provider.go` to add switch case +3. Update documentation + +### 1.3 Agent-Provider Coupling + +```json +{ + "agents": { + "defaults": { + "provider": "deepseek", // need to know provider name + "model": "deepseek-chat" + } + } +} +``` + +Problem: Agent needs to know both `provider` and `model`, adding complexity. + +--- + +## 2. New Approach: model_list + +### 2.1 Core Principles + +Inspired by [LiteLLM](https://docs.litellm.ai/docs/proxy/configs) design: + +1. **Model-centric**: Users care about models, not providers +2. **Protocol prefix**: Use `protocol/model_name` format, e.g., `openai/gpt-4o`, `anthropic/claude-3-sonnet` +3. **Configuration-driven**: Adding new Providers only requires config changes, no code changes + +### 2.2 New Configuration Structure + +```json +{ + "model_list": [ + { + "model_name": "deepseek-chat", + "model": "openai/deepseek-chat", + "api_base": "https://api.deepseek.com/v1", + "api_key": "sk-xxx" + }, + { + "model_name": "gpt-4o", + "model": "openai/gpt-4o", + "api_key": "sk-xxx" + }, + { + "model_name": "claude-3-sonnet", + "model": "anthropic/claude-3-5-sonnet-20241022", + "api_key": "sk-xxx" + }, + { + "model_name": "gemini-3-flash", + "model": "antigravity/gemini-3-flash", + "auth_method": "oauth" + }, + { + "model_name": "my-company-llm", + "model": "openai/company-model-v1", + "api_base": "https://llm.company.com/v1", + "api_key": "xxx" + } + ], + + "agents": { + "defaults": { + "model": "deepseek-chat", + "max_tokens": 8192, + "temperature": 0.7 + } + } +} +``` + +### 2.3 Go Struct Definition + +```go +type Config struct { + ModelList []ModelConfig `json:"model_list"` // new + Providers ProvidersConfig `json:"providers"` // old, deprecated + + Agents AgentsConfig `json:"agents"` + Channels ChannelsConfig `json:"channels"` + // ... +} + +type ModelConfig struct { + // Required + ModelName string `json:"model_name"` // user-facing name (alias) + Model string `json:"model"` // protocol/model, e.g., openai/gpt-4o + + // Common config + APIBase string `json:"api_base,omitempty"` + APIKey string `json:"api_key,omitempty"` + Proxy string `json:"proxy,omitempty"` + + // Special provider config + AuthMethod string `json:"auth_method,omitempty"` // oauth, token + ConnectMode string `json:"connect_mode,omitempty"` // stdio, grpc + + // Optional optimizations + RPM int `json:"rpm,omitempty"` // rate limit + MaxTokensField string `json:"max_tokens_field,omitempty"` // max_tokens or max_completion_tokens +} +``` + +### 2.4 Protocol Recognition + +Identify protocol via prefix in `model` field: + +| Prefix | Protocol | Description | +|--------|----------|-------------| +| `openai/` | OpenAI-compatible | Most common, includes DeepSeek, Qwen, Groq, etc. | +| `anthropic/` | Anthropic | Claude series specific | +| `antigravity/` | Antigravity | Google Cloud Code Assist | +| `gemini/` | Gemini | Google Gemini native API (if needed) | + +--- + +## 3. Design Rationale + +### 3.1 Problems Solved + +| Problem | Old Approach | New Approach | +|---------|--------------|--------------| +| Add OpenAI-compatible Provider | Change 3 code locations | Add one config entry | +| Agent specifies model | Need provider + model | Only need model | +| Code duplication | Each Provider duplicates logic | Share protocol implementation | +| Multi-Agent support | Complex | Naturally compatible | + +### 3.2 Multi-Agent Compatibility + +```json +{ + "model_list": [...], + + "agents": { + "defaults": { + "model": "deepseek-chat" + }, + "coder": { + "model": "gpt-4o", + "system_prompt": "You are a coding assistant..." + }, + "translator": { + "model": "claude-3-sonnet" + } + } +} +``` + +Each Agent only needs to specify `model` (corresponds to `model_name` in `model_list`). + +### 3.3 Industry Comparison + +**LiteLLM** (most mature open-source LLM Proxy) uses similar design: + +```yaml +model_list: + - model_name: gpt-4o + litellm_params: + model: openai/gpt-4o + api_key: xxx + - model_name: my-custom + litellm_params: + model: openai/custom-model + api_base: https://my-api.com/v1 +``` + +--- + +## 4. Migration Plan + +### 4.1 Phase 1: Compatibility Period (v1.x) + +Support both `providers` and `model_list`: + +```go +func (c *Config) GetModelConfig(modelName string) (*ModelConfig, error) { + // Prefer new config + if len(c.ModelList) > 0 { + return c.findModelByName(modelName) + } + + // Backward compatibility with old config + if !c.Providers.IsEmpty() { + logger.Warn("'providers' config is deprecated, please migrate to 'model_list'") + return c.convertFromProviders(modelName) + } + + return nil, fmt.Errorf("model %s not found", modelName) +} +``` + +### 4.2 Phase 2: Warning Period (late v1.x) + +- Print more prominent warnings at startup +- Provide automatic migration script +- Mark `providers` as deprecated in documentation + +### 4.3 Phase 3: Removal Period (v2.0) + +- Completely remove `providers` support +- Remove `agents.defaults.provider` field +- Only support `model_list` + +### 4.4 Configuration Migration Example + +**Old Config**: +```json +{ + "providers": { + "deepseek": { + "api_key": "sk-xxx", + "api_base": "https://api.deepseek.com/v1" + } + }, + "agents": { + "defaults": { + "provider": "deepseek", + "model": "deepseek-chat" + } + } +} +``` + +**New Config**: +```json +{ + "model_list": [ + { + "model_name": "deepseek-chat", + "model": "openai/deepseek-chat", + "api_base": "https://api.deepseek.com/v1", + "api_key": "sk-xxx" + } + ], + "agents": { + "defaults": { + "model": "deepseek-chat" + } + } +} +``` + +--- + +## 5. Implementation Checklist + +### 5.1 Configuration Layer + +- [ ] Add `ModelConfig` struct +- [ ] Add `Config.ModelList` field +- [ ] Implement `GetModelConfig(modelName)` method +- [ ] Implement old config compatibility conversion +- [ ] Add `model_name` uniqueness validation + +### 5.2 Provider Layer + +- [ ] Create `pkg/providers/factory/` directory +- [ ] Implement `CreateProviderFromModelConfig()` +- [ ] Refactor `http_provider.go` to `openai/provider.go` +- [ ] Maintain backward compatibility for old `CreateProvider()` + +### 5.3 Testing + +- [ ] New config unit tests +- [ ] Old config compatibility tests +- [ ] Integration tests + +### 5.4 Documentation + +- [ ] Update README +- [ ] Update config.example.json +- [ ] Write migration guide + +--- + +## 6. Risks and Mitigations + +| Risk | Mitigation | +|------|------------| +| Breaking existing configs | Compatibility period keeps old config working | +| User migration cost | Provide automatic migration script | +| Special Provider incompatibility | Keep `auth_method` and other extension fields | + +--- + +## 7. References + +- [LiteLLM Config Documentation](https://docs.litellm.ai/docs/proxy/configs) +- [One-API GitHub](https://github.com/songquanpeng/one-api) +- Discussion #122: Refactor Provider Architecture diff --git a/pkg/config/config.go b/pkg/config/config.go index 4f37d9cea..1b6f7b76c 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -232,23 +232,6 @@ func (c *ModelConfig) Validate() error { return nil } -// ParseProtocol extracts the protocol prefix and model identifier from the Model field. -// If no prefix is specified, it defaults to "openai". -// Examples: -// - "openai/gpt-4o" -> ("openai", "gpt-4o") -// - "anthropic/claude-3" -> ("anthropic", "claude-3") -// - "gpt-4o" -> ("openai", "gpt-4o") // default protocol -func (c *ModelConfig) ParseProtocol() (protocol, modelID string) { - model := c.Model - for i := 0; i < len(model); i++ { - if model[i] == '/' { - return model[:i], model[i+1:] - } - } - // No prefix found, default to openai - return "openai", model -} - type GatewayConfig struct { Host string `json:"host" env:"PICOCLAW_GATEWAY_HOST"` Port int `json:"port" env:"PICOCLAW_GATEWAY_PORT"` @@ -286,135 +269,6 @@ type ToolsConfig struct { Cron CronToolsConfig `json:"cron"` } -func DefaultConfig() *Config { - return &Config{ - Agents: AgentsConfig{ - Defaults: AgentDefaults{ - Workspace: "~/.picoclaw/workspace", - RestrictToWorkspace: true, - Provider: "", - Model: "glm-4.7", - MaxTokens: 8192, - Temperature: 0.7, - MaxToolIterations: 20, - }, - }, - Channels: ChannelsConfig{ - WhatsApp: WhatsAppConfig{ - Enabled: false, - BridgeURL: "ws://localhost:3001", - AllowFrom: FlexibleStringSlice{}, - }, - Telegram: TelegramConfig{ - Enabled: false, - Token: "", - AllowFrom: FlexibleStringSlice{}, - }, - Feishu: FeishuConfig{ - Enabled: false, - AppID: "", - AppSecret: "", - EncryptKey: "", - VerificationToken: "", - AllowFrom: FlexibleStringSlice{}, - }, - Discord: DiscordConfig{ - Enabled: false, - Token: "", - AllowFrom: FlexibleStringSlice{}, - }, - MaixCam: MaixCamConfig{ - Enabled: false, - Host: "0.0.0.0", - Port: 18790, - AllowFrom: FlexibleStringSlice{}, - }, - QQ: QQConfig{ - Enabled: false, - AppID: "", - AppSecret: "", - AllowFrom: FlexibleStringSlice{}, - }, - DingTalk: DingTalkConfig{ - Enabled: false, - ClientID: "", - ClientSecret: "", - AllowFrom: FlexibleStringSlice{}, - }, - Slack: SlackConfig{ - Enabled: false, - BotToken: "", - AppToken: "", - AllowFrom: FlexibleStringSlice{}, - }, - LINE: LINEConfig{ - Enabled: false, - ChannelSecret: "", - ChannelAccessToken: "", - WebhookHost: "0.0.0.0", - WebhookPort: 18791, - WebhookPath: "/webhook/line", - AllowFrom: FlexibleStringSlice{}, - }, - OneBot: OneBotConfig{ - Enabled: false, - WSUrl: "ws://127.0.0.1:3001", - AccessToken: "", - ReconnectInterval: 5, - GroupTriggerPrefix: []string{}, - AllowFrom: FlexibleStringSlice{}, - }, - }, - Providers: ProvidersConfig{ - Anthropic: ProviderConfig{}, - OpenAI: ProviderConfig{}, - OpenRouter: ProviderConfig{}, - Groq: ProviderConfig{}, - Zhipu: ProviderConfig{}, - VLLM: ProviderConfig{}, - Gemini: ProviderConfig{}, - Nvidia: ProviderConfig{}, - Moonshot: ProviderConfig{}, - ShengSuanYun: ProviderConfig{}, - Cerebras: ProviderConfig{}, - VolcEngine: ProviderConfig{}, - }, - Gateway: GatewayConfig{ - Host: "0.0.0.0", - Port: 18790, - }, - Tools: ToolsConfig{ - Web: WebToolsConfig{ - Brave: BraveConfig{ - Enabled: false, - APIKey: "", - MaxResults: 5, - }, - DuckDuckGo: DuckDuckGoConfig{ - Enabled: true, - MaxResults: 5, - }, - Perplexity: PerplexityConfig{ - Enabled: false, - APIKey: "", - MaxResults: 5, - }, - }, - Cron: CronToolsConfig{ - ExecTimeoutMinutes: 5, // default 5 minutes for LLM operations - }, - }, - Heartbeat: HeartbeatConfig{ - Enabled: true, - Interval: 30, // default 30 minutes - }, - Devices: DevicesConfig{ - Enabled: false, - MonitorUSB: true, - }, - } -} - func LoadConfig(path string) (*Config, error) { cfg := DefaultConfig() @@ -528,40 +382,61 @@ func expandHome(path string) string { // GetModelConfig returns the ModelConfig for the given model name. // If multiple configs exist with the same model_name, it uses round-robin // selection for load balancing. Returns an error if the model is not found. +// Uses double-check locking for optimal read performance. func (c *Config) GetModelConfig(modelName string) (*ModelConfig, error) { - c.mu.Lock() - defer c.mu.Unlock() + // First pass: use read lock to find matches + c.mu.RLock() + matches := c.findMatchesLocked(modelName) + if len(matches) == 0 { + c.mu.RUnlock() + return nil, fmt.Errorf("model %q not found in model_list or providers", modelName) + } + if len(matches) == 1 { + c.mu.RUnlock() + return &matches[0], nil + } - // Find all configs with matching model_name + // Multiple configs - check if counter exists + counter, ok := c.rrCounters[modelName] + c.mu.RUnlock() + + // Double-check locking: only acquire write lock if counter needs initialization + if !ok { + c.mu.Lock() + // Re-check after acquiring write lock + if c.rrCounters == nil { + c.rrCounters = make(map[string]*atomic.Uint64) + } + if c.rrCounters[modelName] == nil { + c.rrCounters[modelName] = &atomic.Uint64{} + } + counter = c.rrCounters[modelName] + c.mu.Unlock() + } + + // Re-fetch matches to ensure consistency (ModelList could have changed) + c.mu.RLock() + matches = c.findMatchesLocked(modelName) + c.mu.RUnlock() + + if len(matches) == 0 { + return nil, fmt.Errorf("model %q not found in model_list or providers", modelName) + } + + idx := counter.Add(1) % uint64(len(matches)) + return &matches[idx], nil +} + +// findMatchesLocked finds all ModelConfig entries with the given model_name. +// Must be called with c.mu locked (read or write). +func (c *Config) findMatchesLocked(modelName string) []ModelConfig { var matches []ModelConfig for i := range c.ModelList { if c.ModelList[i].ModelName == modelName { matches = append(matches, c.ModelList[i]) } } - - if len(matches) == 0 { - return nil, fmt.Errorf("model %q not found in model_list or providers", modelName) - } - - // Single config - return directly - if len(matches) == 1 { - return &matches[0], nil - } - - // Multiple configs - use round-robin for load balancing - if c.rrCounters == nil { - c.rrCounters = make(map[string]*atomic.Uint64) - } - - counter, ok := c.rrCounters[modelName] - if !ok { - counter = &atomic.Uint64{} - c.rrCounters[modelName] = counter - } - - idx := counter.Add(1) % uint64(len(matches)) - return &matches[idx], nil + return matches } // HasProvidersConfig checks if any provider in the old providers config has configuration. @@ -599,203 +474,3 @@ func (c *Config) ValidateModelList() error { } return nil } - -// ConvertProvidersToModelList converts the old ProvidersConfig to a slice of ModelConfig. -// This enables backward compatibility with existing configurations. -func ConvertProvidersToModelList(cfg *Config) []ModelConfig { - if cfg == nil { - return nil - } - - var result []ModelConfig - p := cfg.Providers - - // OpenAI - if p.OpenAI.APIKey != "" || p.OpenAI.APIBase != "" { - result = append(result, ModelConfig{ - ModelName: "openai", - Model: "openai/gpt-4o", - APIKey: p.OpenAI.APIKey, - APIBase: p.OpenAI.APIBase, - Proxy: p.OpenAI.Proxy, - AuthMethod: p.OpenAI.AuthMethod, - }) - } - - // Anthropic - if p.Anthropic.APIKey != "" || p.Anthropic.APIBase != "" { - result = append(result, ModelConfig{ - ModelName: "anthropic", - Model: "anthropic/claude-3-sonnet", - APIKey: p.Anthropic.APIKey, - APIBase: p.Anthropic.APIBase, - Proxy: p.Anthropic.Proxy, - AuthMethod: p.Anthropic.AuthMethod, - }) - } - - // OpenRouter - if p.OpenRouter.APIKey != "" || p.OpenRouter.APIBase != "" { - result = append(result, ModelConfig{ - ModelName: "openrouter", - Model: "openrouter/auto", - APIKey: p.OpenRouter.APIKey, - APIBase: p.OpenRouter.APIBase, - Proxy: p.OpenRouter.Proxy, - }) - } - - // Groq - if p.Groq.APIKey != "" || p.Groq.APIBase != "" { - result = append(result, ModelConfig{ - ModelName: "groq", - Model: "groq/llama-3.1-70b-versatile", - APIKey: p.Groq.APIKey, - APIBase: p.Groq.APIBase, - Proxy: p.Groq.Proxy, - }) - } - - // Zhipu - if p.Zhipu.APIKey != "" || p.Zhipu.APIBase != "" { - result = append(result, ModelConfig{ - ModelName: "zhipu", - Model: "openai/glm-4", - APIKey: p.Zhipu.APIKey, - APIBase: p.Zhipu.APIBase, - Proxy: p.Zhipu.Proxy, - }) - } - - // VLLM - if p.VLLM.APIKey != "" || p.VLLM.APIBase != "" { - result = append(result, ModelConfig{ - ModelName: "vllm", - Model: "openai/auto", - APIKey: p.VLLM.APIKey, - APIBase: p.VLLM.APIBase, - Proxy: p.VLLM.Proxy, - }) - } - - // Gemini - if p.Gemini.APIKey != "" || p.Gemini.APIBase != "" { - result = append(result, ModelConfig{ - ModelName: "gemini", - Model: "openai/gemini-pro", - APIKey: p.Gemini.APIKey, - APIBase: p.Gemini.APIBase, - Proxy: p.Gemini.Proxy, - }) - } - - // Nvidia - if p.Nvidia.APIKey != "" || p.Nvidia.APIBase != "" { - result = append(result, ModelConfig{ - ModelName: "nvidia", - Model: "nvidia/meta/llama-3.1-8b-instruct", - APIKey: p.Nvidia.APIKey, - APIBase: p.Nvidia.APIBase, - Proxy: p.Nvidia.Proxy, - }) - } - - // Ollama - if p.Ollama.APIKey != "" || p.Ollama.APIBase != "" { - result = append(result, ModelConfig{ - ModelName: "ollama", - Model: "ollama/llama3", - APIKey: p.Ollama.APIKey, - APIBase: p.Ollama.APIBase, - Proxy: p.Ollama.Proxy, - }) - } - - // Moonshot - if p.Moonshot.APIKey != "" || p.Moonshot.APIBase != "" { - result = append(result, ModelConfig{ - ModelName: "moonshot", - Model: "moonshot/kimi", - APIKey: p.Moonshot.APIKey, - APIBase: p.Moonshot.APIBase, - Proxy: p.Moonshot.Proxy, - }) - } - - // ShengSuanYun - if p.ShengSuanYun.APIKey != "" || p.ShengSuanYun.APIBase != "" { - result = append(result, ModelConfig{ - ModelName: "shengsuanyun", - Model: "openai/auto", - APIKey: p.ShengSuanYun.APIKey, - APIBase: p.ShengSuanYun.APIBase, - Proxy: p.ShengSuanYun.Proxy, - }) - } - - // DeepSeek - if p.DeepSeek.APIKey != "" || p.DeepSeek.APIBase != "" { - result = append(result, ModelConfig{ - ModelName: "deepseek", - Model: "openai/deepseek-chat", - APIKey: p.DeepSeek.APIKey, - APIBase: p.DeepSeek.APIBase, - Proxy: p.DeepSeek.Proxy, - }) - } - - // Cerebras - if p.Cerebras.APIKey != "" || p.Cerebras.APIBase != "" { - result = append(result, ModelConfig{ - ModelName: "cerebras", - Model: "cerebras/llama-3.3-70b", - APIKey: p.Cerebras.APIKey, - APIBase: p.Cerebras.APIBase, - Proxy: p.Cerebras.Proxy, - }) - } - - // VolcEngine (Doubao) - if p.VolcEngine.APIKey != "" || p.VolcEngine.APIBase != "" { - result = append(result, ModelConfig{ - ModelName: "volcengine", - Model: "openai/doubao-pro", - APIKey: p.VolcEngine.APIKey, - APIBase: p.VolcEngine.APIBase, - Proxy: p.VolcEngine.Proxy, - }) - } - - // GitHub Copilot - if p.GitHubCopilot.APIKey != "" || p.GitHubCopilot.APIBase != "" || p.GitHubCopilot.ConnectMode != "" { - result = append(result, ModelConfig{ - ModelName: "github-copilot", - Model: "github-copilot/gpt-4o", - APIBase: p.GitHubCopilot.APIBase, - ConnectMode: p.GitHubCopilot.ConnectMode, - }) - } - - // Antigravity - if p.Antigravity.APIKey != "" || p.Antigravity.AuthMethod != "" { - result = append(result, ModelConfig{ - ModelName: "antigravity", - Model: "antigravity/gemini-2.0-flash", - APIKey: p.Antigravity.APIKey, - AuthMethod: p.Antigravity.AuthMethod, - }) - } - - // Qwen - if p.Qwen.APIKey != "" || p.Qwen.APIBase != "" { - result = append(result, ModelConfig{ - ModelName: "qwen", - Model: "qwen/qwen-max", - APIKey: p.Qwen.APIKey, - APIBase: p.Qwen.APIBase, - Proxy: p.Qwen.Proxy, - }) - } - - return result -} diff --git a/pkg/config/defaults.go b/pkg/config/defaults.go new file mode 100644 index 000000000..fcfdd788d --- /dev/null +++ b/pkg/config/defaults.go @@ -0,0 +1,136 @@ +// PicoClaw - Ultra-lightweight personal AI agent +// License: MIT +// +// Copyright (c) 2026 PicoClaw contributors + +package config + +// DefaultConfig returns the default configuration for PicoClaw. +func DefaultConfig() *Config { + return &Config{ + Agents: AgentsConfig{ + Defaults: AgentDefaults{ + Workspace: "~/.picoclaw/workspace", + RestrictToWorkspace: true, + Provider: "", + Model: "glm-4.7", + MaxTokens: 8192, + Temperature: 0.7, + MaxToolIterations: 20, + }, + }, + Channels: ChannelsConfig{ + WhatsApp: WhatsAppConfig{ + Enabled: false, + BridgeURL: "ws://localhost:3001", + AllowFrom: FlexibleStringSlice{}, + }, + Telegram: TelegramConfig{ + Enabled: false, + Token: "", + AllowFrom: FlexibleStringSlice{}, + }, + Feishu: FeishuConfig{ + Enabled: false, + AppID: "", + AppSecret: "", + EncryptKey: "", + VerificationToken: "", + AllowFrom: FlexibleStringSlice{}, + }, + Discord: DiscordConfig{ + Enabled: false, + Token: "", + AllowFrom: FlexibleStringSlice{}, + }, + MaixCam: MaixCamConfig{ + Enabled: false, + Host: "0.0.0.0", + Port: 18790, + AllowFrom: FlexibleStringSlice{}, + }, + QQ: QQConfig{ + Enabled: false, + AppID: "", + AppSecret: "", + AllowFrom: FlexibleStringSlice{}, + }, + DingTalk: DingTalkConfig{ + Enabled: false, + ClientID: "", + ClientSecret: "", + AllowFrom: FlexibleStringSlice{}, + }, + Slack: SlackConfig{ + Enabled: false, + BotToken: "", + AppToken: "", + AllowFrom: FlexibleStringSlice{}, + }, + LINE: LINEConfig{ + Enabled: false, + ChannelSecret: "", + ChannelAccessToken: "", + WebhookHost: "0.0.0.0", + WebhookPort: 18791, + WebhookPath: "/webhook/line", + AllowFrom: FlexibleStringSlice{}, + }, + OneBot: OneBotConfig{ + Enabled: false, + WSUrl: "ws://127.0.0.1:3001", + AccessToken: "", + ReconnectInterval: 5, + GroupTriggerPrefix: []string{}, + AllowFrom: FlexibleStringSlice{}, + }, + }, + Providers: ProvidersConfig{ + Anthropic: ProviderConfig{}, + OpenAI: ProviderConfig{}, + OpenRouter: ProviderConfig{}, + Groq: ProviderConfig{}, + Zhipu: ProviderConfig{}, + VLLM: ProviderConfig{}, + Gemini: ProviderConfig{}, + Nvidia: ProviderConfig{}, + Moonshot: ProviderConfig{}, + ShengSuanYun: ProviderConfig{}, + Cerebras: ProviderConfig{}, + VolcEngine: ProviderConfig{}, + }, + Gateway: GatewayConfig{ + Host: "0.0.0.0", + Port: 18790, + }, + Tools: ToolsConfig{ + Web: WebToolsConfig{ + Brave: BraveConfig{ + Enabled: false, + APIKey: "", + MaxResults: 5, + }, + DuckDuckGo: DuckDuckGoConfig{ + Enabled: true, + MaxResults: 5, + }, + Perplexity: PerplexityConfig{ + Enabled: false, + APIKey: "", + MaxResults: 5, + }, + }, + Cron: CronToolsConfig{ + ExecTimeoutMinutes: 5, // default 5 minutes for LLM operations + }, + }, + Heartbeat: HeartbeatConfig{ + Enabled: true, + Interval: 30, // default 30 minutes + }, + Devices: DevicesConfig{ + Enabled: false, + MonitorUSB: true, + }, + } +} diff --git a/pkg/config/migration.go b/pkg/config/migration.go new file mode 100644 index 000000000..d1e165fbb --- /dev/null +++ b/pkg/config/migration.go @@ -0,0 +1,206 @@ +// PicoClaw - Ultra-lightweight personal AI agent +// License: MIT +// +// Copyright (c) 2026 PicoClaw contributors + +package config + +// ConvertProvidersToModelList converts the old ProvidersConfig to a slice of ModelConfig. +// This enables backward compatibility with existing configurations. +func ConvertProvidersToModelList(cfg *Config) []ModelConfig { + if cfg == nil { + return nil + } + + var result []ModelConfig + p := cfg.Providers + + // OpenAI + if p.OpenAI.APIKey != "" || p.OpenAI.APIBase != "" { + result = append(result, ModelConfig{ + ModelName: "openai", + Model: "openai/gpt-4o", + APIKey: p.OpenAI.APIKey, + APIBase: p.OpenAI.APIBase, + Proxy: p.OpenAI.Proxy, + AuthMethod: p.OpenAI.AuthMethod, + }) + } + + // Anthropic + if p.Anthropic.APIKey != "" || p.Anthropic.APIBase != "" { + result = append(result, ModelConfig{ + ModelName: "anthropic", + Model: "anthropic/claude-3-sonnet", + APIKey: p.Anthropic.APIKey, + APIBase: p.Anthropic.APIBase, + Proxy: p.Anthropic.Proxy, + AuthMethod: p.Anthropic.AuthMethod, + }) + } + + // OpenRouter + if p.OpenRouter.APIKey != "" || p.OpenRouter.APIBase != "" { + result = append(result, ModelConfig{ + ModelName: "openrouter", + Model: "openrouter/auto", + APIKey: p.OpenRouter.APIKey, + APIBase: p.OpenRouter.APIBase, + Proxy: p.OpenRouter.Proxy, + }) + } + + // Groq + if p.Groq.APIKey != "" || p.Groq.APIBase != "" { + result = append(result, ModelConfig{ + ModelName: "groq", + Model: "groq/llama-3.1-70b-versatile", + APIKey: p.Groq.APIKey, + APIBase: p.Groq.APIBase, + Proxy: p.Groq.Proxy, + }) + } + + // Zhipu + if p.Zhipu.APIKey != "" || p.Zhipu.APIBase != "" { + result = append(result, ModelConfig{ + ModelName: "zhipu", + Model: "openai/glm-4", + APIKey: p.Zhipu.APIKey, + APIBase: p.Zhipu.APIBase, + Proxy: p.Zhipu.Proxy, + }) + } + + // VLLM + if p.VLLM.APIKey != "" || p.VLLM.APIBase != "" { + result = append(result, ModelConfig{ + ModelName: "vllm", + Model: "openai/auto", + APIKey: p.VLLM.APIKey, + APIBase: p.VLLM.APIBase, + Proxy: p.VLLM.Proxy, + }) + } + + // Gemini + if p.Gemini.APIKey != "" || p.Gemini.APIBase != "" { + result = append(result, ModelConfig{ + ModelName: "gemini", + Model: "openai/gemini-pro", + APIKey: p.Gemini.APIKey, + APIBase: p.Gemini.APIBase, + Proxy: p.Gemini.Proxy, + }) + } + + // Nvidia + if p.Nvidia.APIKey != "" || p.Nvidia.APIBase != "" { + result = append(result, ModelConfig{ + ModelName: "nvidia", + Model: "nvidia/meta/llama-3.1-8b-instruct", + APIKey: p.Nvidia.APIKey, + APIBase: p.Nvidia.APIBase, + Proxy: p.Nvidia.Proxy, + }) + } + + // Ollama + if p.Ollama.APIKey != "" || p.Ollama.APIBase != "" { + result = append(result, ModelConfig{ + ModelName: "ollama", + Model: "ollama/llama3", + APIKey: p.Ollama.APIKey, + APIBase: p.Ollama.APIBase, + Proxy: p.Ollama.Proxy, + }) + } + + // Moonshot + if p.Moonshot.APIKey != "" || p.Moonshot.APIBase != "" { + result = append(result, ModelConfig{ + ModelName: "moonshot", + Model: "moonshot/kimi", + APIKey: p.Moonshot.APIKey, + APIBase: p.Moonshot.APIBase, + Proxy: p.Moonshot.Proxy, + }) + } + + // ShengSuanYun + if p.ShengSuanYun.APIKey != "" || p.ShengSuanYun.APIBase != "" { + result = append(result, ModelConfig{ + ModelName: "shengsuanyun", + Model: "openai/auto", + APIKey: p.ShengSuanYun.APIKey, + APIBase: p.ShengSuanYun.APIBase, + Proxy: p.ShengSuanYun.Proxy, + }) + } + + // DeepSeek + if p.DeepSeek.APIKey != "" || p.DeepSeek.APIBase != "" { + result = append(result, ModelConfig{ + ModelName: "deepseek", + Model: "openai/deepseek-chat", + APIKey: p.DeepSeek.APIKey, + APIBase: p.DeepSeek.APIBase, + Proxy: p.DeepSeek.Proxy, + }) + } + + // Cerebras + if p.Cerebras.APIKey != "" || p.Cerebras.APIBase != "" { + result = append(result, ModelConfig{ + ModelName: "cerebras", + Model: "cerebras/llama-3.3-70b", + APIKey: p.Cerebras.APIKey, + APIBase: p.Cerebras.APIBase, + Proxy: p.Cerebras.Proxy, + }) + } + + // VolcEngine (Doubao) + if p.VolcEngine.APIKey != "" || p.VolcEngine.APIBase != "" { + result = append(result, ModelConfig{ + ModelName: "volcengine", + Model: "openai/doubao-pro", + APIKey: p.VolcEngine.APIKey, + APIBase: p.VolcEngine.APIBase, + Proxy: p.VolcEngine.Proxy, + }) + } + + // GitHub Copilot + if p.GitHubCopilot.APIKey != "" || p.GitHubCopilot.APIBase != "" || p.GitHubCopilot.ConnectMode != "" { + result = append(result, ModelConfig{ + ModelName: "github-copilot", + Model: "github-copilot/gpt-4o", + APIBase: p.GitHubCopilot.APIBase, + ConnectMode: p.GitHubCopilot.ConnectMode, + }) + } + + // Antigravity + if p.Antigravity.APIKey != "" || p.Antigravity.AuthMethod != "" { + result = append(result, ModelConfig{ + ModelName: "antigravity", + Model: "antigravity/gemini-2.0-flash", + APIKey: p.Antigravity.APIKey, + AuthMethod: p.Antigravity.AuthMethod, + }) + } + + // Qwen + if p.Qwen.APIKey != "" || p.Qwen.APIBase != "" { + result = append(result, ModelConfig{ + ModelName: "qwen", + Model: "qwen/qwen-max", + APIKey: p.Qwen.APIKey, + APIBase: p.Qwen.APIBase, + Proxy: p.Qwen.Proxy, + }) + } + + return result +} diff --git a/pkg/config/migration_test.go b/pkg/config/migration_test.go new file mode 100644 index 000000000..eff16ee7a --- /dev/null +++ b/pkg/config/migration_test.go @@ -0,0 +1,177 @@ +// PicoClaw - Ultra-lightweight personal AI agent +// License: MIT +// +// Copyright (c) 2026 PicoClaw contributors + +package config + +import ( + "testing" +) + +func TestConvertProvidersToModelList_OpenAI(t *testing.T) { + cfg := &Config{ + Providers: ProvidersConfig{ + OpenAI: ProviderConfig{ + APIKey: "sk-test-key", + APIBase: "https://custom.api.com/v1", + }, + }, + } + + result := ConvertProvidersToModelList(cfg) + + if len(result) != 1 { + t.Fatalf("len(result) = %d, want 1", len(result)) + } + + if result[0].ModelName != "openai" { + t.Errorf("ModelName = %q, want %q", result[0].ModelName, "openai") + } + if result[0].Model != "openai/gpt-4o" { + t.Errorf("Model = %q, want %q", result[0].Model, "openai/gpt-4o") + } + if result[0].APIKey != "sk-test-key" { + t.Errorf("APIKey = %q, want %q", result[0].APIKey, "sk-test-key") + } +} + +func TestConvertProvidersToModelList_Anthropic(t *testing.T) { + cfg := &Config{ + Providers: ProvidersConfig{ + Anthropic: ProviderConfig{ + APIKey: "ant-key", + APIBase: "https://custom.anthropic.com", + }, + }, + } + + result := ConvertProvidersToModelList(cfg) + + if len(result) != 1 { + t.Fatalf("len(result) = %d, want 1", len(result)) + } + + if result[0].ModelName != "anthropic" { + t.Errorf("ModelName = %q, want %q", result[0].ModelName, "anthropic") + } + if result[0].Model != "anthropic/claude-3-sonnet" { + t.Errorf("Model = %q, want %q", result[0].Model, "anthropic/claude-3-sonnet") + } +} + +func TestConvertProvidersToModelList_Multiple(t *testing.T) { + cfg := &Config{ + Providers: ProvidersConfig{ + OpenAI: ProviderConfig{APIKey: "openai-key"}, + Groq: ProviderConfig{APIKey: "groq-key"}, + Zhipu: ProviderConfig{APIKey: "zhipu-key"}, + }, + } + + result := ConvertProvidersToModelList(cfg) + + if len(result) != 3 { + t.Fatalf("len(result) = %d, want 3", len(result)) + } + + // Check that all providers are present + found := make(map[string]bool) + for _, mc := range result { + found[mc.ModelName] = true + } + + for _, name := range []string{"openai", "groq", "zhipu"} { + if !found[name] { + t.Errorf("Missing provider %q in result", name) + } + } +} + +func TestConvertProvidersToModelList_Empty(t *testing.T) { + cfg := &Config{ + Providers: ProvidersConfig{}, + } + + result := ConvertProvidersToModelList(cfg) + + if len(result) != 0 { + t.Errorf("len(result) = %d, want 0", len(result)) + } +} + +func TestConvertProvidersToModelList_Nil(t *testing.T) { + result := ConvertProvidersToModelList(nil) + + if result != nil { + t.Errorf("result = %v, want nil", result) + } +} + +func TestConvertProvidersToModelList_AllProviders(t *testing.T) { + cfg := &Config{ + Providers: ProvidersConfig{ + OpenAI: ProviderConfig{APIKey: "key1"}, + Anthropic: ProviderConfig{APIKey: "key2"}, + OpenRouter: ProviderConfig{APIKey: "key3"}, + Groq: ProviderConfig{APIKey: "key4"}, + Zhipu: ProviderConfig{APIKey: "key5"}, + VLLM: ProviderConfig{APIKey: "key6"}, + Gemini: ProviderConfig{APIKey: "key7"}, + Nvidia: ProviderConfig{APIKey: "key8"}, + Ollama: ProviderConfig{APIKey: "key9"}, + Moonshot: ProviderConfig{APIKey: "key10"}, + ShengSuanYun: ProviderConfig{APIKey: "key11"}, + DeepSeek: ProviderConfig{APIKey: "key12"}, + Cerebras: ProviderConfig{APIKey: "key13"}, + VolcEngine: ProviderConfig{APIKey: "key14"}, + GitHubCopilot: ProviderConfig{ConnectMode: "grpc"}, + Antigravity: ProviderConfig{AuthMethod: "oauth"}, + Qwen: ProviderConfig{APIKey: "key17"}, + }, + } + + result := ConvertProvidersToModelList(cfg) + + // All 17 providers should be converted + if len(result) != 17 { + t.Errorf("len(result) = %d, want 17", len(result)) + } +} + +func TestConvertProvidersToModelList_Proxy(t *testing.T) { + cfg := &Config{ + Providers: ProvidersConfig{ + OpenAI: ProviderConfig{ + APIKey: "key", + Proxy: "http://proxy:8080", + }, + }, + } + + result := ConvertProvidersToModelList(cfg) + + if len(result) != 1 { + t.Fatalf("len(result) = %d, want 1", len(result)) + } + + if result[0].Proxy != "http://proxy:8080" { + t.Errorf("Proxy = %q, want %q", result[0].Proxy, "http://proxy:8080") + } +} + +func TestConvertProvidersToModelList_AuthMethod(t *testing.T) { + cfg := &Config{ + Providers: ProvidersConfig{ + OpenAI: ProviderConfig{ + AuthMethod: "oauth", + }, + }, + } + + result := ConvertProvidersToModelList(cfg) + + if len(result) != 0 { + t.Errorf("len(result) = %d, want 0 (AuthMethod alone should not create entry)", len(result)) + } +} diff --git a/pkg/config/model_config_test.go b/pkg/config/model_config_test.go new file mode 100644 index 000000000..9d817964a --- /dev/null +++ b/pkg/config/model_config_test.go @@ -0,0 +1,204 @@ +// PicoClaw - Ultra-lightweight personal AI agent +// License: MIT +// +// Copyright (c) 2026 PicoClaw contributors + +package config + +import ( + "sync" + "testing" +) + +func TestGetModelConfig_Found(t *testing.T) { + cfg := &Config{ + ModelList: []ModelConfig{ + {ModelName: "test-model", Model: "openai/gpt-4o", APIKey: "key1"}, + {ModelName: "other-model", Model: "anthropic/claude", APIKey: "key2"}, + }, + } + + result, err := cfg.GetModelConfig("test-model") + if err != nil { + t.Fatalf("GetModelConfig() error = %v", err) + } + if result.Model != "openai/gpt-4o" { + t.Errorf("Model = %q, want %q", result.Model, "openai/gpt-4o") + } +} + +func TestGetModelConfig_NotFound(t *testing.T) { + cfg := &Config{ + ModelList: []ModelConfig{ + {ModelName: "test-model", Model: "openai/gpt-4o", APIKey: "key1"}, + }, + } + + _, err := cfg.GetModelConfig("nonexistent") + if err == nil { + t.Fatal("GetModelConfig() expected error for nonexistent model") + } +} + +func TestGetModelConfig_EmptyList(t *testing.T) { + cfg := &Config{ + ModelList: []ModelConfig{}, + } + + _, err := cfg.GetModelConfig("any-model") + if err == nil { + t.Fatal("GetModelConfig() expected error for empty model list") + } +} + +func TestGetModelConfig_RoundRobin(t *testing.T) { + cfg := &Config{ + ModelList: []ModelConfig{ + {ModelName: "lb-model", Model: "openai/gpt-4o-1", APIKey: "key1"}, + {ModelName: "lb-model", Model: "openai/gpt-4o-2", APIKey: "key2"}, + {ModelName: "lb-model", Model: "openai/gpt-4o-3", APIKey: "key3"}, + }, + } + + // Test round-robin distribution + results := make(map[string]int) + for i := 0; i < 30; i++ { + result, err := cfg.GetModelConfig("lb-model") + if err != nil { + t.Fatalf("GetModelConfig() error = %v", err) + } + results[result.Model]++ + } + + // Each model should appear roughly 10 times (30 calls / 3 models) + for model, count := range results { + if count < 5 || count > 15 { + t.Errorf("Model %s appeared %d times, expected ~10", model, count) + } + } +} + +func TestGetModelConfig_Concurrent(t *testing.T) { + cfg := &Config{ + ModelList: []ModelConfig{ + {ModelName: "concurrent-model", Model: "openai/gpt-4o-1", APIKey: "key1"}, + {ModelName: "concurrent-model", Model: "openai/gpt-4o-2", APIKey: "key2"}, + }, + } + + const goroutines = 100 + const iterations = 10 + + var wg sync.WaitGroup + errors := make(chan error, goroutines*iterations) + + for i := 0; i < goroutines; i++ { + wg.Add(1) + go func() { + defer wg.Done() + for j := 0; j < iterations; j++ { + _, err := cfg.GetModelConfig("concurrent-model") + if err != nil { + errors <- err + } + } + }() + } + + wg.Wait() + close(errors) + + for err := range errors { + t.Errorf("Concurrent GetModelConfig() error: %v", err) + } +} + +func TestModelConfig_Validate(t *testing.T) { + tests := []struct { + name string + config ModelConfig + wantErr bool + }{ + { + name: "valid config", + config: ModelConfig{ + ModelName: "test", + Model: "openai/gpt-4o", + }, + wantErr: false, + }, + { + name: "missing model_name", + config: ModelConfig{ + Model: "openai/gpt-4o", + }, + wantErr: true, + }, + { + name: "missing model", + config: ModelConfig{ + ModelName: "test", + }, + wantErr: true, + }, + { + name: "empty config", + config: ModelConfig{}, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := tt.config.Validate() + if (err != nil) != tt.wantErr { + t.Errorf("Validate() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +func TestConfig_ValidateModelList(t *testing.T) { + tests := []struct { + name string + config *Config + wantErr bool + }{ + { + name: "valid list", + config: &Config{ + ModelList: []ModelConfig{ + {ModelName: "test1", Model: "openai/gpt-4o"}, + {ModelName: "test2", Model: "anthropic/claude"}, + }, + }, + wantErr: false, + }, + { + name: "invalid entry", + config: &Config{ + ModelList: []ModelConfig{ + {ModelName: "test1", Model: "openai/gpt-4o"}, + {ModelName: "", Model: "anthropic/claude"}, // missing model_name + }, + }, + wantErr: true, + }, + { + name: "empty list", + config: &Config{ + ModelList: []ModelConfig{}, + }, + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := tt.config.ValidateModelList() + if (err != nil) != tt.wantErr { + t.Errorf("ValidateModelList() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} diff --git a/pkg/migrate/migrate_test.go b/pkg/migrate/migrate_test.go index be2360aac..cd36043f7 100644 --- a/pkg/migrate/migrate_test.go +++ b/pkg/migrate/migrate_test.go @@ -180,8 +180,8 @@ func TestConvertConfig(t *testing.T) { t.Run("unsupported provider warning", func(t *testing.T) { data := map[string]interface{}{ "providers": map[string]interface{}{ - "deepseek": map[string]interface{}{ - "api_key": "sk-deep-test", + "unknown_provider": map[string]interface{}{ + "api_key": "sk-test", }, }, } @@ -193,7 +193,7 @@ func TestConvertConfig(t *testing.T) { if len(warnings) != 1 { t.Fatalf("expected 1 warning, got %d", len(warnings)) } - if warnings[0] != "Provider 'deepseek' not supported in PicoClaw, skipping" { + if warnings[0] != "Provider 'unknown_provider' not supported in PicoClaw, skipping" { t.Errorf("unexpected warning: %s", warnings[0]) } }) diff --git a/pkg/providers/claude_cli_provider_test.go b/pkg/providers/claude_cli_provider_test.go index 063530deb..ae49af042 100644 --- a/pkg/providers/claude_cli_provider_test.go +++ b/pkg/providers/claude_cli_provider_test.go @@ -419,7 +419,7 @@ func TestCreateProvider_ClaudeCli(t *testing.T) { cfg.Agents.Defaults.Provider = "claude-cli" cfg.Agents.Defaults.Workspace = "/test/ws" - provider, err := CreateProvider(cfg) + provider, _, err := CreateProvider(cfg) if err != nil { t.Fatalf("CreateProvider(claude-cli) error = %v", err) } @@ -437,7 +437,7 @@ func TestCreateProvider_ClaudeCode(t *testing.T) { cfg := config.DefaultConfig() cfg.Agents.Defaults.Provider = "claude-code" - provider, err := CreateProvider(cfg) + provider, _, err := CreateProvider(cfg) if err != nil { t.Fatalf("CreateProvider(claude-code) error = %v", err) } @@ -450,7 +450,7 @@ func TestCreateProvider_ClaudeCodec(t *testing.T) { cfg := config.DefaultConfig() cfg.Agents.Defaults.Provider = "claudecode" - provider, err := CreateProvider(cfg) + provider, _, err := CreateProvider(cfg) if err != nil { t.Fatalf("CreateProvider(claudecode) error = %v", err) } @@ -464,7 +464,7 @@ func TestCreateProvider_ClaudeCliDefaultWorkspace(t *testing.T) { cfg.Agents.Defaults.Provider = "claude-cli" cfg.Agents.Defaults.Workspace = "" - provider, err := CreateProvider(cfg) + provider, _, err := CreateProvider(cfg) if err != nil { t.Fatalf("CreateProvider error = %v", err) } diff --git a/pkg/providers/factory_provider.go b/pkg/providers/factory_provider.go index ff9a4ef20..695d4ffa5 100644 --- a/pkg/providers/factory_provider.go +++ b/pkg/providers/factory_provider.go @@ -32,13 +32,14 @@ func ExtractProtocol(model string) (protocol, modelID string) { // CreateProviderFromConfig creates a provider based on the ModelConfig. // It uses the protocol prefix in the Model field to determine which provider to create. // Supported protocols: openai, anthropic, antigravity, claude-cli, codex-cli, github-copilot -func CreateProviderFromConfig(cfg *config.ModelConfig) (LLMProvider, error) { +// Returns the provider, the model ID (without protocol prefix), and any error. +func CreateProviderFromConfig(cfg *config.ModelConfig) (LLMProvider, string, error) { if cfg == nil { - return nil, fmt.Errorf("config is nil") + return nil, "", fmt.Errorf("config is nil") } if cfg.Model == "" { - return nil, fmt.Errorf("model is required") + return nil, "", fmt.Errorf("model is required") } protocol, modelID := ExtractProtocol(cfg.Model) @@ -49,36 +50,36 @@ func CreateProviderFromConfig(cfg *config.ModelConfig) (LLMProvider, error) { "volcengine", "vllm", "qwen": // All OpenAI-compatible HTTP providers if cfg.APIKey == "" && cfg.APIBase == "" { - return nil, fmt.Errorf("api_key or api_base is required for HTTP-based protocol %q", protocol) + return nil, "", fmt.Errorf("api_key or api_base is required for HTTP-based protocol %q", protocol) } apiBase := cfg.APIBase if apiBase == "" { apiBase = getDefaultAPIBase(protocol) } - return NewHTTPProvider(cfg.APIKey, apiBase, cfg.Proxy), nil + return NewHTTPProvider(cfg.APIKey, apiBase, cfg.Proxy), modelID, nil case "anthropic": if cfg.AuthMethod == "oauth" || cfg.AuthMethod == "token" { // Use Claude SDK with token - return NewClaudeProvider(cfg.APIKey), nil + return NewClaudeProvider(cfg.APIKey), modelID, nil } // Use HTTP API apiBase := cfg.APIBase if apiBase == "" { apiBase = "https://api.anthropic.com/v1" } - return NewHTTPProvider(cfg.APIKey, apiBase, cfg.Proxy), nil + return NewHTTPProvider(cfg.APIKey, apiBase, cfg.Proxy), modelID, nil case "antigravity": - return NewAntigravityProvider(), nil + return NewAntigravityProvider(), modelID, nil case "claude-cli", "claudecli": workspace := "." - return NewClaudeCliProvider(workspace), nil + return NewClaudeCliProvider(workspace), modelID, nil case "codex-cli", "codexcli": workspace := "." - return NewCodexCliProvider(workspace), nil + return NewCodexCliProvider(workspace), modelID, nil case "github-copilot", "copilot": apiBase := cfg.APIBase @@ -89,10 +90,14 @@ func CreateProviderFromConfig(cfg *config.ModelConfig) (LLMProvider, error) { if connectMode == "" { connectMode = "grpc" } - return NewGitHubCopilotProvider(apiBase, connectMode, modelID) + provider, err := NewGitHubCopilotProvider(apiBase, connectMode, modelID) + if err != nil { + return nil, "", err + } + return provider, modelID, nil default: - return nil, fmt.Errorf("unknown protocol %q in model %q", protocol, cfg.Model) + return nil, "", fmt.Errorf("unknown protocol %q in model %q", protocol, cfg.Model) } } diff --git a/pkg/providers/factory_provider_test.go b/pkg/providers/factory_provider_test.go new file mode 100644 index 000000000..f7c1aa58c --- /dev/null +++ b/pkg/providers/factory_provider_test.go @@ -0,0 +1,250 @@ +// PicoClaw - Ultra-lightweight personal AI agent +// License: MIT +// +// Copyright (c) 2026 PicoClaw contributors + +package providers + +import ( + "testing" + + "github.com/sipeed/picoclaw/pkg/config" +) + +func TestExtractProtocol(t *testing.T) { + tests := []struct { + name string + model string + wantProtocol string + wantModelID string + }{ + { + name: "openai with prefix", + model: "openai/gpt-4o", + wantProtocol: "openai", + wantModelID: "gpt-4o", + }, + { + name: "anthropic with prefix", + model: "anthropic/claude-3-sonnet", + wantProtocol: "anthropic", + wantModelID: "claude-3-sonnet", + }, + { + name: "no prefix - defaults to openai", + model: "gpt-4o", + wantProtocol: "openai", + wantModelID: "gpt-4o", + }, + { + name: "groq with prefix", + model: "groq/llama-3.1-70b", + wantProtocol: "groq", + wantModelID: "llama-3.1-70b", + }, + { + name: "empty string", + model: "", + wantProtocol: "openai", + wantModelID: "", + }, + { + name: "with whitespace", + model: " openai/gpt-4 ", + wantProtocol: "openai", + wantModelID: "gpt-4", + }, + { + name: "multiple slashes", + model: "nvidia/meta/llama-3.1-8b", + wantProtocol: "nvidia", + wantModelID: "meta/llama-3.1-8b", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + protocol, modelID := ExtractProtocol(tt.model) + if protocol != tt.wantProtocol { + t.Errorf("ExtractProtocol(%q) protocol = %q, want %q", tt.model, protocol, tt.wantProtocol) + } + if modelID != tt.wantModelID { + t.Errorf("ExtractProtocol(%q) modelID = %q, want %q", tt.model, modelID, tt.wantModelID) + } + }) + } +} + +func TestCreateProviderFromConfig_OpenAI(t *testing.T) { + cfg := &config.ModelConfig{ + ModelName: "test-openai", + Model: "openai/gpt-4o", + APIKey: "test-key", + APIBase: "https://api.example.com/v1", + } + + provider, modelID, err := CreateProviderFromConfig(cfg) + if err != nil { + t.Fatalf("CreateProviderFromConfig() error = %v", err) + } + if provider == nil { + t.Fatal("CreateProviderFromConfig() returned nil provider") + } + if modelID != "gpt-4o" { + t.Errorf("modelID = %q, want %q", modelID, "gpt-4o") + } +} + +func TestCreateProviderFromConfig_DefaultAPIBase(t *testing.T) { + tests := []struct { + name string + protocol string + wantBase string + }{ + {"openai", "openai", "https://api.openai.com/v1"}, + {"groq", "groq", "https://api.groq.com/openai/v1"}, + {"openrouter", "openrouter", "https://openrouter.ai/api/v1"}, + {"cerebras", "cerebras", "https://api.cerebras.ai/v1"}, + {"qwen", "qwen", "https://dashscope.aliyuncs.com/compatible-mode/v1"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cfg := &config.ModelConfig{ + ModelName: "test-" + tt.protocol, + Model: tt.protocol + "/test-model", + APIKey: "test-key", + } + + provider, _, err := CreateProviderFromConfig(cfg) + if err != nil { + t.Fatalf("CreateProviderFromConfig() error = %v", err) + } + + httpProvider, ok := provider.(*HTTPProvider) + if !ok { + t.Fatalf("expected *HTTPProvider, got %T", provider) + } + if httpProvider.apiBase != tt.wantBase { + t.Errorf("apiBase = %q, want %q", httpProvider.apiBase, tt.wantBase) + } + }) + } +} + +func TestCreateProviderFromConfig_Anthropic(t *testing.T) { + cfg := &config.ModelConfig{ + ModelName: "test-anthropic", + Model: "anthropic/claude-3-sonnet", + APIKey: "test-key", + } + + provider, modelID, err := CreateProviderFromConfig(cfg) + if err != nil { + t.Fatalf("CreateProviderFromConfig() error = %v", err) + } + if provider == nil { + t.Fatal("CreateProviderFromConfig() returned nil provider") + } + if modelID != "claude-3-sonnet" { + t.Errorf("modelID = %q, want %q", modelID, "claude-3-sonnet") + } +} + +func TestCreateProviderFromConfig_Antigravity(t *testing.T) { + cfg := &config.ModelConfig{ + ModelName: "test-antigravity", + Model: "antigravity/gemini-2.0-flash", + } + + provider, modelID, err := CreateProviderFromConfig(cfg) + if err != nil { + t.Fatalf("CreateProviderFromConfig() error = %v", err) + } + if provider == nil { + t.Fatal("CreateProviderFromConfig() returned nil provider") + } + if modelID != "gemini-2.0-flash" { + t.Errorf("modelID = %q, want %q", modelID, "gemini-2.0-flash") + } +} + +func TestCreateProviderFromConfig_ClaudeCLI(t *testing.T) { + cfg := &config.ModelConfig{ + ModelName: "test-claude-cli", + Model: "claude-cli/claude-sonnet-4-20250514", + } + + provider, modelID, err := CreateProviderFromConfig(cfg) + if err != nil { + t.Fatalf("CreateProviderFromConfig() error = %v", err) + } + if provider == nil { + t.Fatal("CreateProviderFromConfig() returned nil provider") + } + if modelID != "claude-sonnet-4-20250514" { + t.Errorf("modelID = %q, want %q", modelID, "claude-sonnet-4-20250514") + } +} + +func TestCreateProviderFromConfig_CodexCLI(t *testing.T) { + cfg := &config.ModelConfig{ + ModelName: "test-codex-cli", + Model: "codex-cli/codex", + } + + provider, modelID, err := CreateProviderFromConfig(cfg) + if err != nil { + t.Fatalf("CreateProviderFromConfig() error = %v", err) + } + if provider == nil { + t.Fatal("CreateProviderFromConfig() returned nil provider") + } + if modelID != "codex" { + t.Errorf("modelID = %q, want %q", modelID, "codex") + } +} + +func TestCreateProviderFromConfig_MissingAPIKey(t *testing.T) { + cfg := &config.ModelConfig{ + ModelName: "test-no-key", + Model: "openai/gpt-4o", + } + + _, _, err := CreateProviderFromConfig(cfg) + if err == nil { + t.Fatal("CreateProviderFromConfig() expected error for missing API key") + } +} + +func TestCreateProviderFromConfig_UnknownProtocol(t *testing.T) { + cfg := &config.ModelConfig{ + ModelName: "test-unknown", + Model: "unknown-protocol/model", + APIKey: "test-key", + } + + _, _, err := CreateProviderFromConfig(cfg) + if err == nil { + t.Fatal("CreateProviderFromConfig() expected error for unknown protocol") + } +} + +func TestCreateProviderFromConfig_NilConfig(t *testing.T) { + _, _, err := CreateProviderFromConfig(nil) + if err == nil { + t.Fatal("CreateProviderFromConfig(nil) expected error") + } +} + +func TestCreateProviderFromConfig_EmptyModel(t *testing.T) { + cfg := &config.ModelConfig{ + ModelName: "test-empty", + Model: "", + } + + _, _, err := CreateProviderFromConfig(cfg) + if err == nil { + t.Fatal("CreateProviderFromConfig() expected error for empty model") + } +} diff --git a/pkg/providers/http_provider.go b/pkg/providers/http_provider.go index d264ae3a3..6d2ca1eb7 100644 --- a/pkg/providers/http_provider.go +++ b/pkg/providers/http_provider.go @@ -16,9 +16,6 @@ import ( "net/url" "strings" "time" - - "github.com/sipeed/picoclaw/pkg/auth" - "github.com/sipeed/picoclaw/pkg/config" ) type HTTPProvider struct { @@ -161,13 +158,15 @@ func (p *HTTPProvider) parseResponse(body []byte) (*LLMResponse, error) { arguments := make(map[string]interface{}) name := "" thoughtSignature := "" + argsStr := "" if tc.Function != nil { name = tc.Function.Name thoughtSignature = tc.Function.ThoughtSignature - if tc.Function.Arguments != "" { - if err := json.Unmarshal([]byte(tc.Function.Arguments), &arguments); err != nil { - arguments["raw"] = tc.Function.Arguments + argsStr = tc.Function.Arguments + if argsStr != "" { + if err := json.Unmarshal([]byte(argsStr), &arguments); err != nil { + arguments["raw"] = argsStr } } } @@ -177,7 +176,7 @@ func (p *HTTPProvider) parseResponse(body []byte) (*LLMResponse, error) { Type: tc.Type, Function: &FunctionCall{ Name: name, - Arguments: tc.Function.Arguments, + Arguments: argsStr, ThoughtSignature: thoughtSignature, }, Name: name, @@ -196,328 +195,3 @@ func (p *HTTPProvider) parseResponse(body []byte) (*LLMResponse, error) { func (p *HTTPProvider) GetDefaultModel() string { return "" } - -func createClaudeAuthProvider() (LLMProvider, error) { - cred, err := auth.GetCredential("anthropic") - if err != nil { - return nil, fmt.Errorf("loading auth credentials: %w", err) - } - if cred == nil { - return nil, fmt.Errorf("no credentials for anthropic. Run: picoclaw auth login --provider anthropic") - } - return NewClaudeProviderWithTokenSource(cred.AccessToken, createClaudeTokenSource()), nil -} - -func createCodexAuthProvider() (LLMProvider, error) { - cred, err := auth.GetCredential("openai") - if err != nil { - return nil, fmt.Errorf("loading auth credentials: %w", err) - } - if cred == nil { - return nil, fmt.Errorf("no credentials for openai. Run: picoclaw auth login --provider openai") - } - return NewCodexProviderWithTokenSource(cred.AccessToken, cred.AccountID, createCodexTokenSource()), nil -} - -func CreateProvider(cfg *config.Config) (LLMProvider, error) { - model := cfg.Agents.Defaults.Model - - // First, try to use model_list configuration - if len(cfg.ModelList) > 0 { - // Try to get config by model name first - modelCfg, err := cfg.GetModelConfig(model) - if err == nil { - // Found in model_list, use factory to create provider - provider, err := CreateProviderFromConfig(modelCfg) - if err != nil { - return nil, fmt.Errorf("failed to create provider from model_list: %w", err) - } - return provider, nil - } - // Model not found in model_list, fall through to providers config - } - - // Log deprecation warning if using old providers config - if cfg.HasProvidersConfig() && len(cfg.ModelList) == 0 { - fmt.Println("WARNING: providers config is deprecated, please migrate to model_list") - } - - providerName := strings.ToLower(cfg.Agents.Defaults.Provider) - - var apiKey, apiBase, proxy string - - lowerModel := strings.ToLower(model) - - // First, try to use explicitly configured provider - if providerName != "" { - switch providerName { - case "groq": - if cfg.Providers.Groq.APIKey != "" { - apiKey = cfg.Providers.Groq.APIKey - apiBase = cfg.Providers.Groq.APIBase - if apiBase == "" { - apiBase = "https://api.groq.com/openai/v1" - } - } - case "openai", "gpt": - if cfg.Providers.OpenAI.APIKey != "" || cfg.Providers.OpenAI.AuthMethod != "" { - if cfg.Providers.OpenAI.AuthMethod == "codex-cli" { - return NewCodexProviderWithTokenSource("", "", CreateCodexCliTokenSource()), nil - } - if cfg.Providers.OpenAI.AuthMethod == "oauth" || cfg.Providers.OpenAI.AuthMethod == "token" { - return createCodexAuthProvider() - } - apiKey = cfg.Providers.OpenAI.APIKey - apiBase = cfg.Providers.OpenAI.APIBase - if apiBase == "" { - apiBase = "https://api.openai.com/v1" - } - } - case "anthropic", "claude": - if cfg.Providers.Anthropic.APIKey != "" || cfg.Providers.Anthropic.AuthMethod != "" { - if cfg.Providers.Anthropic.AuthMethod == "oauth" || cfg.Providers.Anthropic.AuthMethod == "token" { - return createClaudeAuthProvider() - } - apiKey = cfg.Providers.Anthropic.APIKey - apiBase = cfg.Providers.Anthropic.APIBase - if apiBase == "" { - apiBase = "https://api.anthropic.com/v1" - } - } - case "openrouter": - if cfg.Providers.OpenRouter.APIKey != "" { - apiKey = cfg.Providers.OpenRouter.APIKey - if cfg.Providers.OpenRouter.APIBase != "" { - apiBase = cfg.Providers.OpenRouter.APIBase - } else { - apiBase = "https://openrouter.ai/api/v1" - } - } - case "zhipu", "glm": - if cfg.Providers.Zhipu.APIKey != "" { - apiKey = cfg.Providers.Zhipu.APIKey - apiBase = cfg.Providers.Zhipu.APIBase - if apiBase == "" { - apiBase = "https://open.bigmodel.cn/api/paas/v4" - } - } - case "gemini", "google": - if cfg.Providers.Gemini.APIKey != "" { - apiKey = cfg.Providers.Gemini.APIKey - apiBase = cfg.Providers.Gemini.APIBase - if apiBase == "" { - apiBase = "https://generativelanguage.googleapis.com/v1beta" - } - } - case "vllm": - if cfg.Providers.VLLM.APIBase != "" { - apiKey = cfg.Providers.VLLM.APIKey - apiBase = cfg.Providers.VLLM.APIBase - } - case "shengsuanyun": - if cfg.Providers.ShengSuanYun.APIKey != "" { - apiKey = cfg.Providers.ShengSuanYun.APIKey - apiBase = cfg.Providers.ShengSuanYun.APIBase - if apiBase == "" { - apiBase = "https://router.shengsuanyun.com/api/v1" - } - } - case "claude-cli", "claudecode", "claude-code": - workspace := cfg.WorkspacePath() - if workspace == "" { - workspace = "." - } - return NewClaudeCliProvider(workspace), nil - case "codex-cli", "codex-code": - workspace := cfg.WorkspacePath() - if workspace == "" { - workspace = "." - } - return NewCodexCliProvider(workspace), nil - case "cerebras": - if cfg.Providers.Cerebras.APIKey != "" { - apiKey = cfg.Providers.Cerebras.APIKey - apiBase = cfg.Providers.Cerebras.APIBase - if apiBase == "" { - apiBase = "https://api.cerebras.ai/v1" - } - } - case "deepseek": - if cfg.Providers.DeepSeek.APIKey != "" { - apiKey = cfg.Providers.DeepSeek.APIKey - apiBase = cfg.Providers.DeepSeek.APIBase - if apiBase == "" { - apiBase = "https://api.deepseek.com/v1" - } - if model != "deepseek-chat" && model != "deepseek-reasoner" { - model = "deepseek-chat" - } - } - case "qwen": - if cfg.Providers.Qwen.APIKey != "" { - apiKey = cfg.Providers.Qwen.APIKey - apiBase = cfg.Providers.Qwen.APIBase - if apiBase == "" { - apiBase = "https://dashscope.aliyuncs.com/compatible-mode/v1" - } - } - case "github_copilot", "copilot": - if cfg.Providers.GitHubCopilot.APIBase != "" { - apiBase = cfg.Providers.GitHubCopilot.APIBase - } else { - apiBase = "localhost:4321" - } - return NewGitHubCopilotProvider(apiBase, cfg.Providers.GitHubCopilot.ConnectMode, model) - case "antigravity", "google-antigravity": - return NewAntigravityProvider(), nil - - case "volcengine", "doubao": - if cfg.Providers.VolcEngine.APIKey != "" { - apiKey = cfg.Providers.VolcEngine.APIKey - apiBase = cfg.Providers.VolcEngine.APIBase - if apiBase == "" { - apiBase = "https://ark.cn-beijing.volces.com/api/v3" - } - } - - } - - } - - // Fallback: detect provider from model name - if apiKey == "" && apiBase == "" { - switch { - case (strings.Contains(lowerModel, "kimi") || strings.Contains(lowerModel, "moonshot") || strings.HasPrefix(model, "moonshot/")) && cfg.Providers.Moonshot.APIKey != "": - apiKey = cfg.Providers.Moonshot.APIKey - apiBase = cfg.Providers.Moonshot.APIBase - proxy = cfg.Providers.Moonshot.Proxy - if apiBase == "" { - apiBase = "https://api.moonshot.cn/v1" - } - - case strings.HasPrefix(model, "openrouter/") || strings.HasPrefix(model, "anthropic/") || strings.HasPrefix(model, "openai/") || strings.HasPrefix(model, "meta-llama/") || strings.HasPrefix(model, "deepseek/") || strings.HasPrefix(model, "google/"): - apiKey = cfg.Providers.OpenRouter.APIKey - proxy = cfg.Providers.OpenRouter.Proxy - if cfg.Providers.OpenRouter.APIBase != "" { - apiBase = cfg.Providers.OpenRouter.APIBase - } else { - apiBase = "https://openrouter.ai/api/v1" - } - - case (strings.Contains(lowerModel, "claude") || strings.HasPrefix(model, "anthropic/")) && (cfg.Providers.Anthropic.APIKey != "" || cfg.Providers.Anthropic.AuthMethod != ""): - if cfg.Providers.Anthropic.AuthMethod == "oauth" || cfg.Providers.Anthropic.AuthMethod == "token" { - return createClaudeAuthProvider() - } - apiKey = cfg.Providers.Anthropic.APIKey - apiBase = cfg.Providers.Anthropic.APIBase - proxy = cfg.Providers.Anthropic.Proxy - if apiBase == "" { - apiBase = "https://api.anthropic.com/v1" - } - - case (strings.Contains(lowerModel, "gpt") || strings.HasPrefix(model, "openai/")) && (cfg.Providers.OpenAI.APIKey != "" || cfg.Providers.OpenAI.AuthMethod != ""): - if cfg.Providers.OpenAI.AuthMethod == "oauth" || cfg.Providers.OpenAI.AuthMethod == "token" { - return createCodexAuthProvider() - } - apiKey = cfg.Providers.OpenAI.APIKey - apiBase = cfg.Providers.OpenAI.APIBase - proxy = cfg.Providers.OpenAI.Proxy - if apiBase == "" { - apiBase = "https://api.openai.com/v1" - } - - case (strings.Contains(lowerModel, "gemini") || strings.HasPrefix(model, "google/")) && cfg.Providers.Gemini.APIKey != "": - apiKey = cfg.Providers.Gemini.APIKey - apiBase = cfg.Providers.Gemini.APIBase - proxy = cfg.Providers.Gemini.Proxy - if apiBase == "" { - apiBase = "https://generativelanguage.googleapis.com/v1beta" - } - - case (strings.Contains(lowerModel, "glm") || strings.Contains(lowerModel, "zhipu") || strings.Contains(lowerModel, "zai")) && cfg.Providers.Zhipu.APIKey != "": - apiKey = cfg.Providers.Zhipu.APIKey - apiBase = cfg.Providers.Zhipu.APIBase - proxy = cfg.Providers.Zhipu.Proxy - if apiBase == "" { - apiBase = "https://open.bigmodel.cn/api/paas/v4" - } - - case (strings.Contains(lowerModel, "groq") || strings.HasPrefix(model, "groq/")) && cfg.Providers.Groq.APIKey != "": - apiKey = cfg.Providers.Groq.APIKey - apiBase = cfg.Providers.Groq.APIBase - proxy = cfg.Providers.Groq.Proxy - if apiBase == "" { - apiBase = "https://api.groq.com/openai/v1" - } - - case (strings.Contains(lowerModel, "qwen") || strings.HasPrefix(model, "qwen/")) && cfg.Providers.Qwen.APIKey != "": - apiKey = cfg.Providers.Qwen.APIKey - apiBase = cfg.Providers.Qwen.APIBase - proxy = cfg.Providers.Qwen.Proxy - if apiBase == "" { - apiBase = "https://dashscope.aliyuncs.com/compatible-mode/v1" - } - - case (strings.Contains(lowerModel, "nvidia") || strings.HasPrefix(model, "nvidia/")) && cfg.Providers.Nvidia.APIKey != "": - apiKey = cfg.Providers.Nvidia.APIKey - apiBase = cfg.Providers.Nvidia.APIBase - proxy = cfg.Providers.Nvidia.Proxy - if apiBase == "" { - apiBase = "https://integrate.api.nvidia.com/v1" - } - case (strings.Contains(lowerModel, "cerebras") || strings.HasPrefix(model, "cerebras/")) && cfg.Providers.Cerebras.APIKey != "": - apiKey = cfg.Providers.Cerebras.APIKey - apiBase = cfg.Providers.Cerebras.APIBase - proxy = cfg.Providers.Cerebras.Proxy - if apiBase == "" { - apiBase = "https://api.cerebras.ai/v1" - } - - case (strings.Contains(lowerModel, "ollama") || strings.HasPrefix(model, "ollama/")) && cfg.Providers.Ollama.APIKey != "": - fmt.Println("Ollama provider selected based on model name prefix") - apiKey = cfg.Providers.Ollama.APIKey - apiBase = cfg.Providers.Ollama.APIBase - proxy = cfg.Providers.Ollama.Proxy - if apiBase == "" { - apiBase = "http://localhost:11434/v1" - } - fmt.Println("Ollama apiBase:", apiBase) - - case (strings.Contains(lowerModel, "doubao") || strings.HasPrefix(lowerModel, "doubao") || strings.Contains(lowerModel, "volcengine")) && cfg.Providers.VolcEngine.APIKey != "": - apiKey = cfg.Providers.VolcEngine.APIKey - apiBase = cfg.Providers.VolcEngine.APIBase - proxy = cfg.Providers.VolcEngine.Proxy - if apiBase == "" { - apiBase = "https://ark.cn-beijing.volces.com/api/v3" - } - - case cfg.Providers.VLLM.APIBase != "": - apiKey = cfg.Providers.VLLM.APIKey - apiBase = cfg.Providers.VLLM.APIBase - proxy = cfg.Providers.VLLM.Proxy - - default: - if cfg.Providers.OpenRouter.APIKey != "" { - apiKey = cfg.Providers.OpenRouter.APIKey - proxy = cfg.Providers.OpenRouter.Proxy - if cfg.Providers.OpenRouter.APIBase != "" { - apiBase = cfg.Providers.OpenRouter.APIBase - } else { - apiBase = "https://openrouter.ai/api/v1" - } - } else { - return nil, fmt.Errorf("no API key configured for model: %s", model) - } - } - } - - if apiKey == "" && !strings.HasPrefix(model, "bedrock/") { - return nil, fmt.Errorf("no API key configured for provider (model: %s)", model) - } - - if apiBase == "" { - return nil, fmt.Errorf("no API base configured for provider (model: %s)", model) - } - - return NewHTTPProvider(apiKey, apiBase, proxy), nil -} diff --git a/pkg/providers/legacy_provider.go b/pkg/providers/legacy_provider.go new file mode 100644 index 000000000..c1efb03b3 --- /dev/null +++ b/pkg/providers/legacy_provider.go @@ -0,0 +1,349 @@ +// PicoClaw - Ultra-lightweight personal AI agent +// License: MIT +// +// Copyright (c) 2026 PicoClaw contributors + +package providers + +import ( + "fmt" + "strings" + + "github.com/sipeed/picoclaw/pkg/auth" + "github.com/sipeed/picoclaw/pkg/config" +) + +// createClaudeAuthProvider creates a Claude provider using OAuth credentials. +func createClaudeAuthProvider() (LLMProvider, error) { + cred, err := auth.GetCredential("anthropic") + if err != nil { + return nil, fmt.Errorf("loading auth credentials: %w", err) + } + if cred == nil { + return nil, fmt.Errorf("no credentials for anthropic. Run: picoclaw auth login --provider anthropic") + } + return NewClaudeProviderWithTokenSource(cred.AccessToken, createClaudeTokenSource()), nil +} + +// createCodexAuthProvider creates a Codex provider using OAuth credentials. +func createCodexAuthProvider() (LLMProvider, error) { + cred, err := auth.GetCredential("openai") + if err != nil { + return nil, fmt.Errorf("loading auth credentials: %w", err) + } + if cred == nil { + return nil, fmt.Errorf("no credentials for openai. Run: picoclaw auth login --provider openai") + } + return NewCodexProviderWithTokenSource(cred.AccessToken, cred.AccountID, createCodexTokenSource()), nil +} + +// CreateProvider creates a provider based on the configuration. +// It supports both the new model_list configuration and the legacy providers configuration. +// Returns the provider, the model ID to use, and any error. +func CreateProvider(cfg *config.Config) (LLMProvider, string, error) { + model := cfg.Agents.Defaults.Model + + // First, try to use model_list configuration + if len(cfg.ModelList) > 0 { + // Try to get config by model name first + modelCfg, err := cfg.GetModelConfig(model) + if err == nil { + // Found in model_list, use factory to create provider + provider, modelID, err := CreateProviderFromConfig(modelCfg) + if err != nil { + return nil, "", fmt.Errorf("failed to create provider from model_list: %w", err) + } + return provider, modelID, nil + } + // Model not found in model_list, fall through to providers config + } + + // Log deprecation warning if using old providers config + if cfg.HasProvidersConfig() && len(cfg.ModelList) == 0 { + fmt.Println("WARNING: providers config is deprecated, please migrate to model_list") + } + + providerName := strings.ToLower(cfg.Agents.Defaults.Provider) + + var apiKey, apiBase, proxy string + + lowerModel := strings.ToLower(model) + + // First, try to use explicitly configured provider + if providerName != "" { + switch providerName { + case "groq": + if cfg.Providers.Groq.APIKey != "" { + apiKey = cfg.Providers.Groq.APIKey + apiBase = cfg.Providers.Groq.APIBase + if apiBase == "" { + apiBase = "https://api.groq.com/openai/v1" + } + } + case "openai", "gpt": + if cfg.Providers.OpenAI.APIKey != "" || cfg.Providers.OpenAI.AuthMethod != "" { + if cfg.Providers.OpenAI.AuthMethod == "codex-cli" { + return NewCodexProviderWithTokenSource("", "", CreateCodexCliTokenSource()), model, nil + } + if cfg.Providers.OpenAI.AuthMethod == "oauth" || cfg.Providers.OpenAI.AuthMethod == "token" { + provider, err := createCodexAuthProvider() + return provider, model, err + } + apiKey = cfg.Providers.OpenAI.APIKey + apiBase = cfg.Providers.OpenAI.APIBase + if apiBase == "" { + apiBase = "https://api.openai.com/v1" + } + } + case "anthropic", "claude": + if cfg.Providers.Anthropic.APIKey != "" || cfg.Providers.Anthropic.AuthMethod != "" { + if cfg.Providers.Anthropic.AuthMethod == "oauth" || cfg.Providers.Anthropic.AuthMethod == "token" { + provider, err := createClaudeAuthProvider() + return provider, model, err + } + apiKey = cfg.Providers.Anthropic.APIKey + apiBase = cfg.Providers.Anthropic.APIBase + if apiBase == "" { + apiBase = "https://api.anthropic.com/v1" + } + } + case "openrouter": + if cfg.Providers.OpenRouter.APIKey != "" { + apiKey = cfg.Providers.OpenRouter.APIKey + if cfg.Providers.OpenRouter.APIBase != "" { + apiBase = cfg.Providers.OpenRouter.APIBase + } else { + apiBase = "https://openrouter.ai/api/v1" + } + } + case "zhipu", "glm": + if cfg.Providers.Zhipu.APIKey != "" { + apiKey = cfg.Providers.Zhipu.APIKey + apiBase = cfg.Providers.Zhipu.APIBase + if apiBase == "" { + apiBase = "https://open.bigmodel.cn/api/paas/v4" + } + } + case "gemini", "google": + if cfg.Providers.Gemini.APIKey != "" { + apiKey = cfg.Providers.Gemini.APIKey + apiBase = cfg.Providers.Gemini.APIBase + if apiBase == "" { + apiBase = "https://generativelanguage.googleapis.com/v1beta" + } + } + case "vllm": + if cfg.Providers.VLLM.APIBase != "" { + apiKey = cfg.Providers.VLLM.APIKey + apiBase = cfg.Providers.VLLM.APIBase + } + case "shengsuanyun": + if cfg.Providers.ShengSuanYun.APIKey != "" { + apiKey = cfg.Providers.ShengSuanYun.APIKey + apiBase = cfg.Providers.ShengSuanYun.APIBase + if apiBase == "" { + apiBase = "https://router.shengsuanyun.com/api/v1" + } + } + case "claude-cli", "claudecode", "claude-code": + workspace := cfg.WorkspacePath() + if workspace == "" { + workspace = "." + } + return NewClaudeCliProvider(workspace), model, nil + case "codex-cli", "codex-code": + workspace := cfg.WorkspacePath() + if workspace == "" { + workspace = "." + } + return NewCodexCliProvider(workspace), model, nil + case "cerebras": + if cfg.Providers.Cerebras.APIKey != "" { + apiKey = cfg.Providers.Cerebras.APIKey + apiBase = cfg.Providers.Cerebras.APIBase + if apiBase == "" { + apiBase = "https://api.cerebras.ai/v1" + } + } + case "deepseek": + if cfg.Providers.DeepSeek.APIKey != "" { + apiKey = cfg.Providers.DeepSeek.APIKey + apiBase = cfg.Providers.DeepSeek.APIBase + if apiBase == "" { + apiBase = "https://api.deepseek.com/v1" + } + if model != "deepseek-chat" && model != "deepseek-reasoner" { + model = "deepseek-chat" + } + } + case "qwen": + if cfg.Providers.Qwen.APIKey != "" { + apiKey = cfg.Providers.Qwen.APIKey + apiBase = cfg.Providers.Qwen.APIBase + if apiBase == "" { + apiBase = "https://dashscope.aliyuncs.com/compatible-mode/v1" + } + } + case "github_copilot", "copilot": + if cfg.Providers.GitHubCopilot.APIBase != "" { + apiBase = cfg.Providers.GitHubCopilot.APIBase + } else { + apiBase = "localhost:4321" + } + provider, err := NewGitHubCopilotProvider(apiBase, cfg.Providers.GitHubCopilot.ConnectMode, model) + return provider, model, err + case "antigravity", "google-antigravity": + return NewAntigravityProvider(), model, nil + + case "volcengine", "doubao": + if cfg.Providers.VolcEngine.APIKey != "" { + apiKey = cfg.Providers.VolcEngine.APIKey + apiBase = cfg.Providers.VolcEngine.APIBase + if apiBase == "" { + apiBase = "https://ark.cn-beijing.volces.com/api/v3" + } + } + + } + + } + + // Fallback: detect provider from model name + if apiKey == "" && apiBase == "" { + switch { + case (strings.Contains(lowerModel, "kimi") || strings.Contains(lowerModel, "moonshot") || strings.HasPrefix(model, "moonshot/")) && cfg.Providers.Moonshot.APIKey != "": + apiKey = cfg.Providers.Moonshot.APIKey + apiBase = cfg.Providers.Moonshot.APIBase + proxy = cfg.Providers.Moonshot.Proxy + if apiBase == "" { + apiBase = "https://api.moonshot.cn/v1" + } + + case strings.HasPrefix(model, "openrouter/") || strings.HasPrefix(model, "anthropic/") || strings.HasPrefix(model, "openai/") || strings.HasPrefix(model, "meta-llama/") || strings.HasPrefix(model, "deepseek/") || strings.HasPrefix(model, "google/"): + apiKey = cfg.Providers.OpenRouter.APIKey + proxy = cfg.Providers.OpenRouter.Proxy + if cfg.Providers.OpenRouter.APIBase != "" { + apiBase = cfg.Providers.OpenRouter.APIBase + } else { + apiBase = "https://openrouter.ai/api/v1" + } + + case (strings.Contains(lowerModel, "claude") || strings.HasPrefix(model, "anthropic/")) && (cfg.Providers.Anthropic.APIKey != "" || cfg.Providers.Anthropic.AuthMethod != ""): + if cfg.Providers.Anthropic.AuthMethod == "oauth" || cfg.Providers.Anthropic.AuthMethod == "token" { + provider, err := createClaudeAuthProvider() + return provider, model, err + } + apiKey = cfg.Providers.Anthropic.APIKey + apiBase = cfg.Providers.Anthropic.APIBase + proxy = cfg.Providers.Anthropic.Proxy + if apiBase == "" { + apiBase = "https://api.anthropic.com/v1" + } + + case (strings.Contains(lowerModel, "gpt") || strings.HasPrefix(model, "openai/")) && (cfg.Providers.OpenAI.APIKey != "" || cfg.Providers.OpenAI.AuthMethod != ""): + if cfg.Providers.OpenAI.AuthMethod == "oauth" || cfg.Providers.OpenAI.AuthMethod == "token" { + provider, err := createCodexAuthProvider() + return provider, model, err + } + apiKey = cfg.Providers.OpenAI.APIKey + apiBase = cfg.Providers.OpenAI.APIBase + proxy = cfg.Providers.OpenAI.Proxy + if apiBase == "" { + apiBase = "https://api.openai.com/v1" + } + + case (strings.Contains(lowerModel, "gemini") || strings.HasPrefix(model, "google/")) && cfg.Providers.Gemini.APIKey != "": + apiKey = cfg.Providers.Gemini.APIKey + apiBase = cfg.Providers.Gemini.APIBase + proxy = cfg.Providers.Gemini.Proxy + if apiBase == "" { + apiBase = "https://generativelanguage.googleapis.com/v1beta" + } + + case (strings.Contains(lowerModel, "glm") || strings.Contains(lowerModel, "zhipu") || strings.Contains(lowerModel, "zai")) && cfg.Providers.Zhipu.APIKey != "": + apiKey = cfg.Providers.Zhipu.APIKey + apiBase = cfg.Providers.Zhipu.APIBase + proxy = cfg.Providers.Zhipu.Proxy + if apiBase == "" { + apiBase = "https://open.bigmodel.cn/api/paas/v4" + } + + case (strings.Contains(lowerModel, "groq") || strings.HasPrefix(model, "groq/")) && cfg.Providers.Groq.APIKey != "": + apiKey = cfg.Providers.Groq.APIKey + apiBase = cfg.Providers.Groq.APIBase + proxy = cfg.Providers.Groq.Proxy + if apiBase == "" { + apiBase = "https://api.groq.com/openai/v1" + } + + case (strings.Contains(lowerModel, "qwen") || strings.HasPrefix(model, "qwen/")) && cfg.Providers.Qwen.APIKey != "": + apiKey = cfg.Providers.Qwen.APIKey + apiBase = cfg.Providers.Qwen.APIBase + proxy = cfg.Providers.Qwen.Proxy + if apiBase == "" { + apiBase = "https://dashscope.aliyuncs.com/compatible-mode/v1" + } + + case (strings.Contains(lowerModel, "nvidia") || strings.HasPrefix(model, "nvidia/")) && cfg.Providers.Nvidia.APIKey != "": + apiKey = cfg.Providers.Nvidia.APIKey + apiBase = cfg.Providers.Nvidia.APIBase + proxy = cfg.Providers.Nvidia.Proxy + if apiBase == "" { + apiBase = "https://integrate.api.nvidia.com/v1" + } + case (strings.Contains(lowerModel, "cerebras") || strings.HasPrefix(model, "cerebras/")) && cfg.Providers.Cerebras.APIKey != "": + apiKey = cfg.Providers.Cerebras.APIKey + apiBase = cfg.Providers.Cerebras.APIBase + proxy = cfg.Providers.Cerebras.Proxy + if apiBase == "" { + apiBase = "https://api.cerebras.ai/v1" + } + + case (strings.Contains(lowerModel, "ollama") || strings.HasPrefix(model, "ollama/")) && cfg.Providers.Ollama.APIKey != "": + fmt.Println("Ollama provider selected based on model name prefix") + apiKey = cfg.Providers.Ollama.APIKey + apiBase = cfg.Providers.Ollama.APIBase + proxy = cfg.Providers.Ollama.Proxy + if apiBase == "" { + apiBase = "http://localhost:11434/v1" + } + fmt.Println("Ollama apiBase:", apiBase) + + case (strings.Contains(lowerModel, "doubao") || strings.HasPrefix(lowerModel, "doubao") || strings.Contains(lowerModel, "volcengine")) && cfg.Providers.VolcEngine.APIKey != "": + apiKey = cfg.Providers.VolcEngine.APIKey + apiBase = cfg.Providers.VolcEngine.APIBase + proxy = cfg.Providers.VolcEngine.Proxy + if apiBase == "" { + apiBase = "https://ark.cn-beijing.volces.com/api/v3" + } + + case cfg.Providers.VLLM.APIBase != "": + apiKey = cfg.Providers.VLLM.APIKey + apiBase = cfg.Providers.VLLM.APIBase + proxy = cfg.Providers.VLLM.Proxy + + default: + if cfg.Providers.OpenRouter.APIKey != "" { + apiKey = cfg.Providers.OpenRouter.APIKey + proxy = cfg.Providers.OpenRouter.Proxy + if cfg.Providers.OpenRouter.APIBase != "" { + apiBase = cfg.Providers.OpenRouter.APIBase + } else { + apiBase = "https://openrouter.ai/api/v1" + } + } else { + return nil, "", fmt.Errorf("no API key configured for model: %s", model) + } + } + } + + if apiKey == "" && !strings.HasPrefix(model, "bedrock/") { + return nil, "", fmt.Errorf("no API key configured for provider (model: %s)", model) + } + + if apiBase == "" { + return nil, "", fmt.Errorf("no API base configured for provider (model: %s)", model) + } + + return NewHTTPProvider(apiKey, apiBase, proxy), model, nil +} diff --git a/pkg/providers/registry.go b/pkg/providers/registry.go deleted file mode 100644 index b9adef5d5..000000000 --- a/pkg/providers/registry.go +++ /dev/null @@ -1,113 +0,0 @@ -// PicoClaw - Ultra-lightweight personal AI agent -// License: MIT -// -// Copyright (c) 2026 PicoClaw contributors - -package providers - -import ( - "fmt" - "sync" - "sync/atomic" - - "github.com/sipeed/picoclaw/pkg/config" -) - -// ModelRegistry manages model configurations with thread-safe round-robin load balancing. -// It allows multiple configurations for the same model_name to distribute load across endpoints. -type ModelRegistry struct { - configs map[string][]config.ModelConfig // model_name -> []ModelConfig - counters map[string]*atomic.Uint64 // model_name -> round-robin counter - mu sync.RWMutex -} - -// NewModelRegistry creates a new ModelRegistry from a slice of ModelConfig. -func NewModelRegistry(modelList []config.ModelConfig) *ModelRegistry { - r := &ModelRegistry{ - configs: make(map[string][]config.ModelConfig), - counters: make(map[string]*atomic.Uint64), - } - - for _, cfg := range modelList { - r.configs[cfg.ModelName] = append(r.configs[cfg.ModelName], cfg) - } - - // Initialize counters for models with multiple configs - for name, cfgs := range r.configs { - if len(cfgs) > 1 { - r.counters[name] = &atomic.Uint64{} - } - } - - return r -} - -// GetModelConfig returns a ModelConfig for the given model name. -// If multiple configs exist for the same model_name, it uses round-robin selection. -// Returns an error if the model is not found. -func (r *ModelRegistry) GetModelConfig(modelName string) (*config.ModelConfig, error) { - r.mu.RLock() - defer r.mu.RUnlock() - - configs, ok := r.configs[modelName] - if !ok || len(configs) == 0 { - return nil, fmt.Errorf("model %q not found", modelName) - } - - // Single config - return directly - if len(configs) == 1 { - return &configs[0], nil - } - - // Multiple configs - use round-robin for load balancing - counter, ok := r.counters[modelName] - if !ok { - // Should not happen, but handle gracefully - return &configs[0], nil - } - - idx := counter.Add(1) % uint64(len(configs)) - return &configs[idx], nil -} - -// AddConfig adds a new ModelConfig to the registry. -func (r *ModelRegistry) AddConfig(cfg config.ModelConfig) { - r.mu.Lock() - defer r.mu.Unlock() - - r.configs[cfg.ModelName] = append(r.configs[cfg.ModelName], cfg) - - // Initialize counter if we now have multiple configs - if len(r.configs[cfg.ModelName]) > 1 && r.counters[cfg.ModelName] == nil { - r.counters[cfg.ModelName] = &atomic.Uint64{} - } -} - -// RemoveConfig removes all configs with the given model_name. -func (r *ModelRegistry) RemoveConfig(modelName string) { - r.mu.Lock() - defer r.mu.Unlock() - - delete(r.configs, modelName) - delete(r.counters, modelName) -} - -// ListModels returns all unique model names in the registry. -func (r *ModelRegistry) ListModels() []string { - r.mu.RLock() - defer r.mu.RUnlock() - - names := make([]string, 0, len(r.configs)) - for name := range r.configs { - names = append(names, name) - } - return names -} - -// ConfigCount returns the number of configurations for a given model name. -func (r *ModelRegistry) ConfigCount(modelName string) int { - r.mu.RLock() - defer r.mu.RUnlock() - - return len(r.configs[modelName]) -}