diff --git a/.gitignore b/.gitignore index ce30d749e..3ff195fbf 100644 --- a/.gitignore +++ b/.gitignore @@ -10,7 +10,7 @@ build/ *.out /picoclaw /picoclaw-test -cmd/picoclaw/workspace +cmd/**/workspace # Picoclaw specific diff --git a/.golangci.yaml b/.golangci.yaml index dd3cbae19..d0ba90716 100644 --- a/.golangci.yaml +++ b/.golangci.yaml @@ -28,9 +28,7 @@ linters: - wsl_v5 # TODO: Disabled, because they are failing at the moment, we should fix them and enable (step by step) - - bodyclose - contextcheck - - dogsled - embeddedstructfieldcheck - errcheck - errchkjson @@ -45,21 +43,16 @@ linters: - gocritic - gocyclo - godox - - goprintffuncname - gosec - ineffassign - lll - maintidx - - misspell - mnd - modernize - - nakedret - nestif - nilnil - paralleltest - perfsprint - - prealloc - - predeclared - revive - staticcheck - tagalign @@ -68,8 +61,6 @@ linters: - unparam - usestdlibvars - usetesting - - wastedassign - - whitespace settings: errcheck: check-type-assertions: true diff --git a/.goreleaser.yaml b/.goreleaser.yaml index 2c47f7d86..b864485d3 100644 --- a/.goreleaser.yaml +++ b/.goreleaser.yaml @@ -15,10 +15,10 @@ builds: - stdjson ldflags: - -s -w - - -X main.version={{ .Version }} - - -X main.gitCommit={{ .ShortCommit }} - - -X main.buildTime={{ .Date }} - - -X main.goVersion={{ .Env.GOVERSION }} + - -X github.com/sipeed/picoclaw/cmd/picoclaw/internal.version={{ .Version }} + - -X github.com/sipeed/picoclaw/cmd/picoclaw/internal.gitCommit={{ .ShortCommit }} + - -X github.com/sipeed/picoclaw/cmd/picoclaw/internal.buildTime={{ .Date }} + - -X github.com/sipeed/picoclaw/cmd/picoclaw/internal.goVersion={{ .Env.GOVERSION }} goos: - linux - windows @@ -28,8 +28,7 @@ builds: - amd64 - arm64 - riscv64 - - s390x - - mips64 + - loong64 - arm main: ./cmd/picoclaw ignore: @@ -67,6 +66,26 @@ archives: - goos: windows formats: [zip] +nfpms: + - id: picoclaw + package_name: picoclaw + file_name_template: >- + {{ .PackageName }}_ + {{- .Version }}_ + {{- if eq .Arch "amd64" }}x86_64 + {{- else if eq .Arch "arm64" }}aarch64 + {{- else if eq .Arch "arm" }}armv{{ .Arm }} + {{- else }}{{ .Arch }}{{ end }} + vendor: picoclaw + homepage: https://github.com/{{ .Env.GITHUB_REPOSITORY_OWNER }}/picoclaw + maintainer: picoclaw contributors + description: picoclaw - a tool for managing and running tasks + license: MIT + formats: + - rpm + - deb + bindir: /usr/bin + changelog: sort: asc filters: diff --git a/Makefile b/Makefile index 576152f40..7bf05a2eb 100644 --- a/Makefile +++ b/Makefile @@ -11,7 +11,8 @@ VERSION?=$(shell git describe --tags --always --dirty 2>/dev/null || echo "dev") GIT_COMMIT=$(shell git rev-parse --short=8 HEAD 2>/dev/null || echo "dev") BUILD_TIME=$(shell date +%FT%T%z) GO_VERSION=$(shell $(GO) version | awk '{print $$3}') -LDFLAGS=-ldflags "-X main.version=$(VERSION) -X main.gitCommit=$(GIT_COMMIT) -X main.buildTime=$(BUILD_TIME) -X main.goVersion=$(GO_VERSION) -s -w" +INTERNAL=github.com/sipeed/picoclaw/cmd/picoclaw/internal +LDFLAGS=-ldflags "-X $(INTERNAL).version=$(VERSION) -X $(INTERNAL).gitCommit=$(GIT_COMMIT) -X $(INTERNAL).buildTime=$(BUILD_TIME) -X $(INTERNAL).goVersion=$(GO_VERSION) -s -w" # Go variables GO?=CGO_ENABLED=0 go diff --git a/assets/wechat.png b/assets/wechat.png index e30c34e4e..776c07885 100644 Binary files a/assets/wechat.png and b/assets/wechat.png differ diff --git a/cmd/picoclaw/cmd_cron.go b/cmd/picoclaw/cmd_cron.go deleted file mode 100644 index 8c42bde06..000000000 --- a/cmd/picoclaw/cmd_cron.go +++ /dev/null @@ -1,227 +0,0 @@ -// PicoClaw - Ultra-lightweight personal AI agent -// License: MIT - -package main - -import ( - "fmt" - "os" - "path/filepath" - "time" - - "github.com/sipeed/picoclaw/pkg/cron" -) - -func cronCmd() { - if len(os.Args) < 3 { - cronHelp() - return - } - - subcommand := os.Args[2] - - // Load config to get workspace path - cfg, err := loadConfig() - if err != nil { - fmt.Printf("Error loading config: %v\n", err) - return - } - - cronStorePath := filepath.Join(cfg.WorkspacePath(), "cron", "jobs.json") - - switch subcommand { - case "list": - cronListCmd(cronStorePath) - case "add": - cronAddCmd(cronStorePath) - case "remove": - if len(os.Args) < 4 { - fmt.Println("Usage: picoclaw cron remove ") - return - } - cronRemoveCmd(cronStorePath, os.Args[3]) - case "enable": - cronEnableCmd(cronStorePath, false) - case "disable": - cronEnableCmd(cronStorePath, true) - default: - fmt.Printf("Unknown cron command: %s\n", subcommand) - cronHelp() - } -} - -func cronHelp() { - fmt.Println("\nCron commands:") - fmt.Println(" list List all scheduled jobs") - fmt.Println(" add Add a new scheduled job") - fmt.Println(" remove Remove a job by ID") - fmt.Println(" enable Enable a job") - fmt.Println(" disable Disable a job") - fmt.Println() - fmt.Println("Add options:") - fmt.Println(" -n, --name Job name") - fmt.Println(" -m, --message Message for agent") - fmt.Println(" -e, --every Run every N seconds") - fmt.Println(" -c, --cron Cron expression (e.g. '0 9 * * *')") - fmt.Println(" -d, --deliver Deliver response to channel") - fmt.Println(" --to Recipient for delivery") - fmt.Println(" --channel Channel for delivery") -} - -func cronListCmd(storePath string) { - cs := cron.NewCronService(storePath, nil) - jobs := cs.ListJobs(true) // Show all jobs, including disabled - - if len(jobs) == 0 { - fmt.Println("No scheduled jobs.") - return - } - - fmt.Println("\nScheduled Jobs:") - fmt.Println("----------------") - for _, job := range jobs { - var schedule string - if job.Schedule.Kind == "every" && job.Schedule.EveryMS != nil { - schedule = fmt.Sprintf("every %ds", *job.Schedule.EveryMS/1000) - } else if job.Schedule.Kind == "cron" { - schedule = job.Schedule.Expr - } else { - schedule = "one-time" - } - - nextRun := "scheduled" - if job.State.NextRunAtMS != nil { - nextTime := time.UnixMilli(*job.State.NextRunAtMS) - nextRun = nextTime.Format("2006-01-02 15:04") - } - - status := "enabled" - if !job.Enabled { - status = "disabled" - } - - fmt.Printf(" %s (%s)\n", job.Name, job.ID) - fmt.Printf(" Schedule: %s\n", schedule) - fmt.Printf(" Status: %s\n", status) - fmt.Printf(" Next run: %s\n", nextRun) - } -} - -func cronAddCmd(storePath string) { - name := "" - message := "" - var everySec *int64 - cronExpr := "" - deliver := false - channel := "" - to := "" - - args := os.Args[3:] - for i := 0; i < len(args); i++ { - switch args[i] { - case "-n", "--name": - if i+1 < len(args) { - name = args[i+1] - i++ - } - case "-m", "--message": - if i+1 < len(args) { - message = args[i+1] - i++ - } - case "-e", "--every": - if i+1 < len(args) { - var sec int64 - fmt.Sscanf(args[i+1], "%d", &sec) - everySec = &sec - i++ - } - case "-c", "--cron": - if i+1 < len(args) { - cronExpr = args[i+1] - i++ - } - case "-d", "--deliver": - deliver = true - case "--to": - if i+1 < len(args) { - to = args[i+1] - i++ - } - case "--channel": - if i+1 < len(args) { - channel = args[i+1] - i++ - } - } - } - - if name == "" { - fmt.Println("Error: --name is required") - return - } - - if message == "" { - fmt.Println("Error: --message is required") - return - } - - if everySec == nil && cronExpr == "" { - fmt.Println("Error: Either --every or --cron must be specified") - return - } - - var schedule cron.CronSchedule - if everySec != nil { - everyMS := *everySec * 1000 - schedule = cron.CronSchedule{ - Kind: "every", - EveryMS: &everyMS, - } - } else { - schedule = cron.CronSchedule{ - Kind: "cron", - Expr: cronExpr, - } - } - - cs := cron.NewCronService(storePath, nil) - job, err := cs.AddJob(name, schedule, message, deliver, channel, to) - if err != nil { - fmt.Printf("Error adding job: %v\n", err) - return - } - - fmt.Printf("✓ Added job '%s' (%s)\n", job.Name, job.ID) -} - -func cronRemoveCmd(storePath, jobID string) { - cs := cron.NewCronService(storePath, nil) - if cs.RemoveJob(jobID) { - fmt.Printf("✓ Removed job %s\n", jobID) - } else { - fmt.Printf("✗ Job %s not found\n", jobID) - } -} - -func cronEnableCmd(storePath string, disable bool) { - if len(os.Args) < 4 { - fmt.Println("Usage: picoclaw cron enable/disable ") - return - } - - jobID := os.Args[3] - cs := cron.NewCronService(storePath, nil) - enabled := !disable - - job := cs.EnableJob(jobID, enabled) - if job != nil { - status := "enabled" - if disable { - status = "disabled" - } - fmt.Printf("✓ Job '%s' %s\n", job.Name, status) - } else { - fmt.Printf("✗ Job %s not found\n", jobID) - } -} diff --git a/cmd/picoclaw/cmd_migrate.go b/cmd/picoclaw/cmd_migrate.go deleted file mode 100644 index 86d4903ef..000000000 --- a/cmd/picoclaw/cmd_migrate.go +++ /dev/null @@ -1,81 +0,0 @@ -// PicoClaw - Ultra-lightweight personal AI agent -// License: MIT - -package main - -import ( - "fmt" - "os" - - "github.com/sipeed/picoclaw/pkg/migrate" -) - -func migrateCmd() { - if len(os.Args) > 2 && (os.Args[2] == "--help" || os.Args[2] == "-h") { - migrateHelp() - return - } - - opts := migrate.Options{} - - args := os.Args[2:] - for i := 0; i < len(args); i++ { - switch args[i] { - case "--dry-run": - opts.DryRun = true - case "--config-only": - opts.ConfigOnly = true - case "--workspace-only": - opts.WorkspaceOnly = true - case "--force": - opts.Force = true - case "--refresh": - opts.Refresh = true - case "--openclaw-home": - if i+1 < len(args) { - opts.OpenClawHome = args[i+1] - i++ - } - case "--picoclaw-home": - if i+1 < len(args) { - opts.PicoClawHome = args[i+1] - i++ - } - default: - fmt.Printf("Unknown flag: %s\n", args[i]) - migrateHelp() - os.Exit(1) - } - } - - result, err := migrate.Run(opts) - if err != nil { - fmt.Printf("Error: %v\n", err) - os.Exit(1) - } - - if !opts.DryRun { - migrate.PrintSummary(result) - } -} - -func migrateHelp() { - fmt.Println("\nMigrate from OpenClaw to PicoClaw") - fmt.Println() - fmt.Println("Usage: picoclaw migrate [options]") - fmt.Println() - fmt.Println("Options:") - fmt.Println(" --dry-run Show what would be migrated without making changes") - fmt.Println(" --refresh Re-sync workspace files from OpenClaw (repeatable)") - fmt.Println(" --config-only Only migrate config, skip workspace files") - fmt.Println(" --workspace-only Only migrate workspace files, skip config") - fmt.Println(" --force Skip confirmation prompts") - fmt.Println(" --openclaw-home Override OpenClaw home directory (default: ~/.openclaw)") - fmt.Println(" --picoclaw-home Override PicoClaw home directory (default: ~/.picoclaw)") - fmt.Println() - fmt.Println("Examples:") - fmt.Println(" picoclaw migrate Detect and migrate from OpenClaw") - fmt.Println(" picoclaw migrate --dry-run Show what would be migrated") - fmt.Println(" picoclaw migrate --refresh Re-sync workspace files") - fmt.Println(" picoclaw migrate --force Migrate without confirmation") -} diff --git a/cmd/picoclaw/internal/agent/command.go b/cmd/picoclaw/internal/agent/command.go new file mode 100644 index 000000000..47262fc85 --- /dev/null +++ b/cmd/picoclaw/internal/agent/command.go @@ -0,0 +1,30 @@ +package agent + +import ( + "github.com/spf13/cobra" +) + +func NewAgentCommand() *cobra.Command { + var ( + message string + sessionKey string + model string + debug bool + ) + + cmd := &cobra.Command{ + Use: "agent", + Short: "Interact with the agent directly", + Args: cobra.NoArgs, + RunE: func(cmd *cobra.Command, _ []string) error { + return agentCmd(message, sessionKey, model, debug) + }, + } + + cmd.Flags().BoolVarP(&debug, "debug", "d", false, "Enable debug logging") + cmd.Flags().StringVarP(&message, "message", "m", "", "Send a single message (non-interactive mode)") + cmd.Flags().StringVarP(&sessionKey, "session", "s", "cli:default", "Session key") + cmd.Flags().StringVarP(&model, "model", "", "", "Model to use") + + return cmd +} diff --git a/cmd/picoclaw/internal/agent/command_test.go b/cmd/picoclaw/internal/agent/command_test.go new file mode 100644 index 000000000..1457d6a49 --- /dev/null +++ b/cmd/picoclaw/internal/agent/command_test.go @@ -0,0 +1,33 @@ +package agent + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestNewAgentCommand(t *testing.T) { + cmd := NewAgentCommand() + + require.NotNil(t, cmd) + + assert.Equal(t, "agent", cmd.Use) + assert.Equal(t, "Interact with the agent directly", cmd.Short) + + assert.Len(t, cmd.Aliases, 0) + assert.False(t, cmd.HasSubCommands()) + + assert.Nil(t, cmd.Run) + assert.NotNil(t, cmd.RunE) + + assert.Nil(t, cmd.PersistentPreRun) + assert.Nil(t, cmd.PersistentPostRun) + + assert.True(t, cmd.HasFlags()) + + assert.NotNil(t, cmd.Flags().Lookup("debug")) + assert.NotNil(t, cmd.Flags().Lookup("message")) + assert.NotNil(t, cmd.Flags().Lookup("session")) + assert.NotNil(t, cmd.Flags().Lookup("model")) +} diff --git a/cmd/picoclaw/cmd_agent.go b/cmd/picoclaw/internal/agent/helpers.go similarity index 70% rename from cmd/picoclaw/cmd_agent.go rename to cmd/picoclaw/internal/agent/helpers.go index 98ea51103..746e9755e 100644 --- a/cmd/picoclaw/cmd_agent.go +++ b/cmd/picoclaw/internal/agent/helpers.go @@ -1,7 +1,4 @@ -// PicoClaw - Ultra-lightweight personal AI agent -// License: MIT - -package main +package agent import ( "bufio" @@ -14,56 +11,37 @@ import ( "github.com/chzyer/readline" + "github.com/sipeed/picoclaw/cmd/picoclaw/internal" "github.com/sipeed/picoclaw/pkg/agent" "github.com/sipeed/picoclaw/pkg/bus" "github.com/sipeed/picoclaw/pkg/logger" "github.com/sipeed/picoclaw/pkg/providers" ) -func agentCmd() { - message := "" - sessionKey := "cli:default" - modelOverride := "" - - args := os.Args[2:] - for i := 0; i < len(args); i++ { - switch args[i] { - case "--debug", "-d": - logger.SetLevel(logger.DEBUG) - fmt.Println("🔍 Debug mode enabled") - case "-m", "--message": - if i+1 < len(args) { - message = args[i+1] - i++ - } - case "-s", "--session": - if i+1 < len(args) { - sessionKey = args[i+1] - i++ - } - case "--model", "-model": - if i+1 < len(args) { - modelOverride = args[i+1] - i++ - } - } +func agentCmd(message, sessionKey, model string, debug bool) error { + if sessionKey == "" { + sessionKey = "cli:default" } - cfg, err := loadConfig() + if debug { + logger.SetLevel(logger.DEBUG) + fmt.Println("🔍 Debug mode enabled") + } + + cfg, err := internal.LoadConfig() if err != nil { - fmt.Printf("Error loading config: %v\n", err) - os.Exit(1) + return fmt.Errorf("error loading config: %w", err) } - if modelOverride != "" { - cfg.Agents.Defaults.ModelName = modelOverride + if model != "" { + cfg.Agents.Defaults.ModelName = model } provider, modelID, err := providers.CreateProvider(cfg) if err != nil { - fmt.Printf("Error creating provider: %v\n", err) - os.Exit(1) + return fmt.Errorf("error creating provider: %w", err) } + // Use the resolved model ID from provider creation if modelID != "" { cfg.Agents.Defaults.ModelName = modelID @@ -85,18 +63,20 @@ func agentCmd() { ctx := context.Background() response, err := agentLoop.ProcessDirect(ctx, message, sessionKey) if err != nil { - fmt.Printf("Error: %v\n", err) - os.Exit(1) + return fmt.Errorf("error processing message: %w", err) } - fmt.Printf("\n%s %s\n", logo, response) - } else { - fmt.Printf("%s Interactive mode (Ctrl+C to exit)\n\n", logo) - interactiveMode(agentLoop, sessionKey) + fmt.Printf("\n%s %s\n", internal.Logo, response) + return nil } + + fmt.Printf("%s Interactive mode (Ctrl+C to exit)\n\n", internal.Logo) + interactiveMode(agentLoop, sessionKey) + + return nil } func interactiveMode(agentLoop *agent.AgentLoop, sessionKey string) { - prompt := fmt.Sprintf("%s You: ", logo) + prompt := fmt.Sprintf("%s You: ", internal.Logo) rl, err := readline.NewEx(&readline.Config{ Prompt: prompt, @@ -141,14 +121,14 @@ func interactiveMode(agentLoop *agent.AgentLoop, sessionKey string) { continue } - fmt.Printf("\n%s %s\n\n", logo, response) + fmt.Printf("\n%s %s\n\n", internal.Logo, response) } } func simpleInteractiveMode(agentLoop *agent.AgentLoop, sessionKey string) { reader := bufio.NewReader(os.Stdin) for { - fmt.Printf("%s You: ", logo) + fmt.Print(fmt.Sprintf("%s You: ", internal.Logo)) line, err := reader.ReadString('\n') if err != nil { if err == io.EOF { @@ -176,6 +156,6 @@ func simpleInteractiveMode(agentLoop *agent.AgentLoop, sessionKey string) { continue } - fmt.Printf("\n%s %s\n\n", logo, response) + fmt.Printf("\n%s %s\n\n", internal.Logo, response) } } diff --git a/cmd/picoclaw/internal/auth/command.go b/cmd/picoclaw/internal/auth/command.go new file mode 100644 index 000000000..12a0a3a8c --- /dev/null +++ b/cmd/picoclaw/internal/auth/command.go @@ -0,0 +1,22 @@ +package auth + +import "github.com/spf13/cobra" + +func NewAuthCommand() *cobra.Command { + cmd := &cobra.Command{ + Use: "auth", + Short: "Manage authentication (login, logout, status)", + RunE: func(cmd *cobra.Command, _ []string) error { + return cmd.Help() + }, + } + + cmd.AddCommand( + newLoginCommand(), + newLogoutCommand(), + newStatusCommand(), + newModelsCommand(), + ) + + return cmd +} diff --git a/cmd/picoclaw/internal/auth/command_test.go b/cmd/picoclaw/internal/auth/command_test.go new file mode 100644 index 000000000..48dc704dd --- /dev/null +++ b/cmd/picoclaw/internal/auth/command_test.go @@ -0,0 +1,55 @@ +package auth + +import ( + "slices" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestNewAuthCommand(t *testing.T) { + cmd := NewAuthCommand() + + require.NotNil(t, cmd) + + assert.Equal(t, "auth", cmd.Use) + assert.Equal(t, "Manage authentication (login, logout, status)", cmd.Short) + + assert.Len(t, cmd.Aliases, 0) + + assert.Nil(t, cmd.Run) + assert.NotNil(t, cmd.RunE) + + assert.Nil(t, cmd.PersistentPreRun) + assert.Nil(t, cmd.PersistentPostRun) + + assert.False(t, cmd.HasFlags()) + assert.True(t, cmd.HasSubCommands()) + + allowedCommands := []string{ + "login", + "logout", + "status", + "models", + } + + subcommands := cmd.Commands() + assert.Len(t, subcommands, len(allowedCommands)) + + for _, subcmd := range subcommands { + found := slices.Contains(allowedCommands, subcmd.Name()) + assert.True(t, found, "unexpected subcommand %q", subcmd.Name()) + + assert.Len(t, subcmd.Aliases, 0) + assert.False(t, subcmd.Hidden) + + assert.False(t, subcmd.HasSubCommands()) + + assert.Nil(t, subcmd.Run) + assert.NotNil(t, subcmd.RunE) + + assert.Nil(t, subcmd.PersistentPreRun) + assert.Nil(t, subcmd.PersistentPostRun) + } +} diff --git a/cmd/picoclaw/cmd_auth.go b/cmd/picoclaw/internal/auth/helpers.go similarity index 67% rename from cmd/picoclaw/cmd_auth.go rename to cmd/picoclaw/internal/auth/helpers.go index 55eb3cec3..633ce8740 100644 --- a/cmd/picoclaw/cmd_auth.go +++ b/cmd/picoclaw/internal/auth/helpers.go @@ -1,7 +1,4 @@ -// PicoClaw - Ultra-lightweight personal AI agent -// License: MIT - -package main +package auth import ( "encoding/json" @@ -12,92 +9,28 @@ import ( "strings" "time" + "github.com/sipeed/picoclaw/cmd/picoclaw/internal" "github.com/sipeed/picoclaw/pkg/auth" "github.com/sipeed/picoclaw/pkg/config" "github.com/sipeed/picoclaw/pkg/providers" ) -const supportedProvidersMsg = "Supported providers: openai, anthropic, google-antigravity" - -func authCmd() { - if len(os.Args) < 3 { - authHelp() - return - } - - switch os.Args[2] { - case "login": - authLoginCmd() - case "logout": - authLogoutCmd() - case "status": - authStatusCmd() - case "models": - authModelsCmd() - default: - fmt.Printf("Unknown auth command: %s\n", os.Args[2]) - authHelp() - } -} - -func authHelp() { - fmt.Println("\nAuth commands:") - fmt.Println(" login Login via OAuth or paste token") - fmt.Println(" logout Remove stored credentials") - fmt.Println(" status Show current auth status") - fmt.Println(" models List available Antigravity models") - fmt.Println() - fmt.Println("Login options:") - fmt.Println(" --provider Provider to login with (openai, anthropic, google-antigravity)") - fmt.Println(" --device-code Use device code flow (for headless environments)") - fmt.Println() - fmt.Println("Examples:") - fmt.Println(" picoclaw auth login --provider openai") - fmt.Println(" picoclaw auth login --provider openai --device-code") - fmt.Println(" picoclaw auth login --provider anthropic") - fmt.Println(" picoclaw auth login --provider google-antigravity") - fmt.Println(" picoclaw auth models") - fmt.Println(" picoclaw auth logout --provider openai") - fmt.Println(" picoclaw auth status") -} - -func authLoginCmd() { - provider := "" - useDeviceCode := false - - args := os.Args[3:] - for i := 0; i < len(args); i++ { - switch args[i] { - case "--provider", "-p": - if i+1 < len(args) { - provider = args[i+1] - i++ - } - case "--device-code": - useDeviceCode = true - } - } - - if provider == "" { - fmt.Println("Error: --provider is required") - fmt.Println(supportedProvidersMsg) - return - } +const supportedProvidersMsg = "supported providers: openai, anthropic, google-antigravity" +func authLoginCmd(provider string, useDeviceCode bool) error { switch provider { case "openai": - authLoginOpenAI(useDeviceCode) + return authLoginOpenAI(useDeviceCode) case "anthropic": - authLoginPasteToken(provider) + return authLoginPasteToken(provider) case "google-antigravity", "antigravity": - authLoginGoogleAntigravity() + return authLoginGoogleAntigravity() default: - fmt.Printf("Unsupported provider: %s\n", provider) - fmt.Println(supportedProvidersMsg) + return fmt.Errorf("unsupported provider: %s (%s)", provider, supportedProvidersMsg) } } -func authLoginOpenAI(useDeviceCode bool) { +func authLoginOpenAI(useDeviceCode bool) error { cfg := auth.OpenAIOAuthConfig() var cred *auth.AuthCredential @@ -110,16 +43,14 @@ func authLoginOpenAI(useDeviceCode bool) { } if err != nil { - fmt.Printf("Login failed: %v\n", err) - os.Exit(1) + return fmt.Errorf("login failed: %w", err) } if err = auth.SetCredential("openai", cred); err != nil { - fmt.Printf("Failed to save credentials: %v\n", err) - os.Exit(1) + return fmt.Errorf("failed to save credentials: %w", err) } - appCfg, err := loadConfig() + appCfg, err := internal.LoadConfig() if err == nil { // Update Providers (legacy format) appCfg.Providers.OpenAI.AuthMethod = "oauth" @@ -146,8 +77,8 @@ func authLoginOpenAI(useDeviceCode bool) { // Update default model to use OpenAI appCfg.Agents.Defaults.ModelName = "gpt-5.2" - if err := config.SaveConfig(getConfigPath(), appCfg); err != nil { - fmt.Printf("Warning: could not update config: %v\n", err) + if err = config.SaveConfig(internal.GetConfigPath(), appCfg); err != nil { + return fmt.Errorf("could not update config: %w", err) } } @@ -156,15 +87,16 @@ func authLoginOpenAI(useDeviceCode bool) { fmt.Printf("Account: %s\n", cred.AccountID) } fmt.Println("Default model set to: gpt-5.2") + + return nil } -func authLoginGoogleAntigravity() { +func authLoginGoogleAntigravity() error { cfg := auth.GoogleAntigravityOAuthConfig() cred, err := auth.LoginBrowser(cfg) if err != nil { - fmt.Printf("Login failed: %v\n", err) - os.Exit(1) + return fmt.Errorf("login failed: %w", err) } cred.Provider = "google-antigravity" @@ -189,11 +121,10 @@ func authLoginGoogleAntigravity() { } if err = auth.SetCredential("google-antigravity", cred); err != nil { - fmt.Printf("Failed to save credentials: %v\n", err) - os.Exit(1) + return fmt.Errorf("failed to save credentials: %w", err) } - appCfg, err := loadConfig() + appCfg, err := internal.LoadConfig() if err == nil { // Update Providers (legacy format, for backward compatibility) appCfg.Providers.Antigravity.AuthMethod = "oauth" @@ -220,7 +151,7 @@ func authLoginGoogleAntigravity() { // Update default model appCfg.Agents.Defaults.ModelName = "gemini-flash" - if err := config.SaveConfig(getConfigPath(), appCfg); err != nil { + if err := config.SaveConfig(internal.GetConfigPath(), appCfg); err != nil { fmt.Printf("Warning: could not update config: %v\n", err) } } @@ -228,6 +159,8 @@ func authLoginGoogleAntigravity() { fmt.Println("\n✓ Google Antigravity login successful!") fmt.Println("Default model set to: gemini-flash") fmt.Println("Try it: picoclaw agent -m \"Hello world\"") + + return nil } func fetchGoogleUserEmail(accessToken string) (string, error) { @@ -258,19 +191,17 @@ func fetchGoogleUserEmail(accessToken string) (string, error) { return userInfo.Email, nil } -func authLoginPasteToken(provider string) { +func authLoginPasteToken(provider string) error { cred, err := auth.LoginPasteToken(provider, os.Stdin) if err != nil { - fmt.Printf("Login failed: %v\n", err) - os.Exit(1) + return fmt.Errorf("login failed: %w", err) } if err = auth.SetCredential(provider, cred); err != nil { - fmt.Printf("Failed to save credentials: %v\n", err) - os.Exit(1) + return fmt.Errorf("failed to save credentials: %w", err) } - appCfg, err := loadConfig() + appCfg, err := internal.LoadConfig() if err == nil { switch provider { case "anthropic": @@ -314,36 +245,27 @@ func authLoginPasteToken(provider string) { // Update default model appCfg.Agents.Defaults.ModelName = "gpt-5.2" } - if err := config.SaveConfig(getConfigPath(), appCfg); err != nil { - fmt.Printf("Warning: could not update config: %v\n", err) + if err := config.SaveConfig(internal.GetConfigPath(), appCfg); err != nil { + return fmt.Errorf("could not update config: %w", err) } } fmt.Printf("Token saved for %s!\n", provider) - fmt.Printf("Default model set to: %s\n", appCfg.Agents.Defaults.GetModelName()) -} -func authLogoutCmd() { - provider := "" - - args := os.Args[3:] - for i := 0; i < len(args); i++ { - switch args[i] { - case "--provider", "-p": - if i+1 < len(args) { - provider = args[i+1] - i++ - } - } + if appCfg != nil { + fmt.Printf("Default model set to: %s\n", appCfg.Agents.Defaults.GetModelName()) } + return nil +} + +func authLogoutCmd(provider string) error { if provider != "" { if err := auth.DeleteCredential(provider); err != nil { - fmt.Printf("Failed to remove credentials: %v\n", err) - os.Exit(1) + return fmt.Errorf("failed to remove credentials: %w", err) } - appCfg, err := loadConfig() + appCfg, err := internal.LoadConfig() if err == nil { // Clear AuthMethod in ModelList for i := range appCfg.ModelList { @@ -371,44 +293,46 @@ func authLogoutCmd() { case "google-antigravity", "antigravity": appCfg.Providers.Antigravity.AuthMethod = "" } - config.SaveConfig(getConfigPath(), appCfg) + config.SaveConfig(internal.GetConfigPath(), appCfg) } fmt.Printf("Logged out from %s\n", provider) - } else { - if err := auth.DeleteAllCredentials(); err != nil { - fmt.Printf("Failed to remove credentials: %v\n", err) - os.Exit(1) - } - appCfg, err := loadConfig() - if err == nil { - // Clear all AuthMethods in ModelList - for i := range appCfg.ModelList { - appCfg.ModelList[i].AuthMethod = "" - } - // Clear all AuthMethods in Providers (legacy) - appCfg.Providers.OpenAI.AuthMethod = "" - appCfg.Providers.Anthropic.AuthMethod = "" - appCfg.Providers.Antigravity.AuthMethod = "" - config.SaveConfig(getConfigPath(), appCfg) - } - - fmt.Println("Logged out from all providers") + return nil } + + if err := auth.DeleteAllCredentials(); err != nil { + return fmt.Errorf("failed to remove credentials: %w", err) + } + + appCfg, err := internal.LoadConfig() + if err == nil { + // Clear all AuthMethods in ModelList + for i := range appCfg.ModelList { + appCfg.ModelList[i].AuthMethod = "" + } + // Clear all AuthMethods in Providers (legacy) + appCfg.Providers.OpenAI.AuthMethod = "" + appCfg.Providers.Anthropic.AuthMethod = "" + appCfg.Providers.Antigravity.AuthMethod = "" + config.SaveConfig(internal.GetConfigPath(), appCfg) + } + + fmt.Println("Logged out from all providers") + + return nil } -func authStatusCmd() { +func authStatusCmd() error { store, err := auth.LoadStore() if err != nil { - fmt.Printf("Error loading auth store: %v\n", err) - return + return fmt.Errorf("failed to load auth store: %w", err) } if len(store.Credentials) == 0 { fmt.Println("No authenticated providers.") fmt.Println("Run: picoclaw auth login --provider ") - return + return nil } fmt.Println("\nAuthenticated Providers:") @@ -437,14 +361,16 @@ func authStatusCmd() { fmt.Printf(" Expires: %s\n", cred.ExpiresAt.Format("2006-01-02 15:04")) } } + + return nil } -func authModelsCmd() { +func authModelsCmd() error { cred, err := auth.GetCredential("google-antigravity") if err != nil || cred == nil { - fmt.Println("Not logged in to Google Antigravity.") - fmt.Println("Run: picoclaw auth login --provider google-antigravity") - return + return fmt.Errorf( + "not logged in to Google Antigravity.\nrun: picoclaw auth login --provider google-antigravity", + ) } // Refresh token if needed @@ -459,21 +385,18 @@ func authModelsCmd() { projectID := cred.ProjectID if projectID == "" { - fmt.Println("No project ID stored. Try logging in again.") - return + return fmt.Errorf("no project id stored. Try logging in again") } fmt.Printf("Fetching models for project: %s\n\n", projectID) models, err := providers.FetchAntigravityModels(cred.AccessToken, projectID) if err != nil { - fmt.Printf("Error fetching models: %v\n", err) - return + return fmt.Errorf("error fetching models: %w", err) } if len(models) == 0 { - fmt.Println("No models available.") - return + return fmt.Errorf("no models available") } fmt.Println("Available Antigravity Models:") @@ -489,6 +412,8 @@ func authModelsCmd() { } fmt.Printf(" %s %s\n", status, name) } + + return nil } // isAntigravityModel checks if a model string belongs to antigravity provider diff --git a/cmd/picoclaw/internal/auth/login.go b/cmd/picoclaw/internal/auth/login.go new file mode 100644 index 000000000..9a6d28d2f --- /dev/null +++ b/cmd/picoclaw/internal/auth/login.go @@ -0,0 +1,25 @@ +package auth + +import "github.com/spf13/cobra" + +func newLoginCommand() *cobra.Command { + var ( + provider string + useDeviceCode bool + ) + + cmd := &cobra.Command{ + Use: "login", + Short: "Login via OAuth or paste token", + Args: cobra.NoArgs, + RunE: func(cmd *cobra.Command, _ []string) error { + return authLoginCmd(provider, useDeviceCode) + }, + } + + cmd.Flags().StringVarP(&provider, "provider", "p", "", "Provider to login with (openai, anthropic)") + cmd.Flags().BoolVar(&useDeviceCode, "device-code", false, "Use device code flow (for headless environments)") + _ = cmd.MarkFlagRequired("provider") + + return cmd +} diff --git a/cmd/picoclaw/internal/auth/login_test.go b/cmd/picoclaw/internal/auth/login_test.go new file mode 100644 index 000000000..d6a03c25b --- /dev/null +++ b/cmd/picoclaw/internal/auth/login_test.go @@ -0,0 +1,29 @@ +package auth + +import ( + "testing" + + "github.com/spf13/cobra" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestNewLoginSubCommand(t *testing.T) { + cmd := newLoginCommand() + + require.NotNil(t, cmd) + + assert.Equal(t, "Login via OAuth or paste token", cmd.Short) + + assert.True(t, cmd.HasFlags()) + + assert.NotNil(t, cmd.Flags().Lookup("device-code")) + + providerFlag := cmd.Flags().Lookup("provider") + require.NotNil(t, providerFlag) + + val, found := providerFlag.Annotations[cobra.BashCompOneRequiredFlag] + require.True(t, found) + require.NotEmpty(t, val) + assert.Equal(t, "true", val[0]) +} diff --git a/cmd/picoclaw/internal/auth/logout.go b/cmd/picoclaw/internal/auth/logout.go new file mode 100644 index 000000000..384667524 --- /dev/null +++ b/cmd/picoclaw/internal/auth/logout.go @@ -0,0 +1,20 @@ +package auth + +import "github.com/spf13/cobra" + +func newLogoutCommand() *cobra.Command { + var provider string + + cmd := &cobra.Command{ + Use: "logout", + Short: "Remove stored credentials", + Args: cobra.NoArgs, + RunE: func(cmd *cobra.Command, _ []string) error { + return authLogoutCmd(provider) + }, + } + + cmd.Flags().StringVarP(&provider, "provider", "p", "", "Provider to logout from (openai, anthropic); empty = all") + + return cmd +} diff --git a/cmd/picoclaw/internal/auth/logout_test.go b/cmd/picoclaw/internal/auth/logout_test.go new file mode 100644 index 000000000..c0f3a5e92 --- /dev/null +++ b/cmd/picoclaw/internal/auth/logout_test.go @@ -0,0 +1,20 @@ +package auth + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestNewLogoutSubcommand(t *testing.T) { + cmd := newLogoutCommand() + + require.NotNil(t, cmd) + + assert.Equal(t, "Remove stored credentials", cmd.Short) + + assert.True(t, cmd.HasFlags()) + + assert.NotNil(t, cmd.Flags().Lookup("provider")) +} diff --git a/cmd/picoclaw/internal/auth/models.go b/cmd/picoclaw/internal/auth/models.go new file mode 100644 index 000000000..cabe6822c --- /dev/null +++ b/cmd/picoclaw/internal/auth/models.go @@ -0,0 +1,15 @@ +package auth + +import "github.com/spf13/cobra" + +func newModelsCommand() *cobra.Command { + cmd := &cobra.Command{ + Use: "models", + Short: "Show available models", + RunE: func(_ *cobra.Command, _ []string) error { + return authModelsCmd() + }, + } + + return cmd +} diff --git a/cmd/picoclaw/internal/auth/models_test.go b/cmd/picoclaw/internal/auth/models_test.go new file mode 100644 index 000000000..26ca67787 --- /dev/null +++ b/cmd/picoclaw/internal/auth/models_test.go @@ -0,0 +1,19 @@ +package auth + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestNewModelsCommand(t *testing.T) { + cmd := newModelsCommand() + + require.NotNil(t, cmd) + + assert.Equal(t, "models", cmd.Use) + assert.Equal(t, "Show available models", cmd.Short) + + assert.False(t, cmd.HasFlags()) +} diff --git a/cmd/picoclaw/internal/auth/status.go b/cmd/picoclaw/internal/auth/status.go new file mode 100644 index 000000000..ca3007d12 --- /dev/null +++ b/cmd/picoclaw/internal/auth/status.go @@ -0,0 +1,16 @@ +package auth + +import "github.com/spf13/cobra" + +func newStatusCommand() *cobra.Command { + cmd := &cobra.Command{ + Use: "status", + Short: "Show current auth status", + Args: cobra.NoArgs, + RunE: func(cmd *cobra.Command, _ []string) error { + return authStatusCmd() + }, + } + + return cmd +} diff --git a/cmd/picoclaw/internal/auth/status_test.go b/cmd/picoclaw/internal/auth/status_test.go new file mode 100644 index 000000000..7748ba502 --- /dev/null +++ b/cmd/picoclaw/internal/auth/status_test.go @@ -0,0 +1,18 @@ +package auth + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestNewStatusSubcommand(t *testing.T) { + cmd := newStatusCommand() + + require.NotNil(t, cmd) + + assert.Equal(t, "Show current auth status", cmd.Short) + + assert.False(t, cmd.HasFlags()) +} diff --git a/cmd/picoclaw/internal/cron/add.go b/cmd/picoclaw/internal/cron/add.go new file mode 100644 index 000000000..947557d5a --- /dev/null +++ b/cmd/picoclaw/internal/cron/add.go @@ -0,0 +1,64 @@ +package cron + +import ( + "fmt" + + "github.com/spf13/cobra" + + "github.com/sipeed/picoclaw/pkg/cron" +) + +func newAddCommand(storePath func() string) *cobra.Command { + var ( + name string + message string + every int64 + cronExp string + deliver bool + channel string + to string + ) + + cmd := &cobra.Command{ + Use: "add", + Short: "Add a new scheduled job", + Args: cobra.NoArgs, + RunE: func(cmd *cobra.Command, _ []string) error { + if every <= 0 && cronExp == "" { + return fmt.Errorf("either --every or --cron must be specified") + } + + var schedule cron.CronSchedule + if every > 0 { + everyMS := every * 1000 + schedule = cron.CronSchedule{Kind: "every", EveryMS: &everyMS} + } else { + schedule = cron.CronSchedule{Kind: "cron", Expr: cronExp} + } + + cs := cron.NewCronService(storePath(), nil) + job, err := cs.AddJob(name, schedule, message, deliver, channel, to) + if err != nil { + return fmt.Errorf("error adding job: %w", err) + } + + fmt.Printf("✓ Added job '%s' (%s)\n", job.Name, job.ID) + + return nil + }, + } + + cmd.Flags().StringVarP(&name, "name", "n", "", "Job name") + cmd.Flags().StringVarP(&message, "message", "m", "", "Message for agent") + cmd.Flags().Int64VarP(&every, "every", "e", 0, "Run every N seconds") + cmd.Flags().StringVarP(&cronExp, "cron", "c", "", "Cron expression (e.g. '0 9 * * *')") + cmd.Flags().BoolVarP(&deliver, "deliver", "d", false, "Deliver response to channel") + cmd.Flags().StringVar(&to, "to", "", "Recipient for delivery") + cmd.Flags().StringVar(&channel, "channel", "", "Channel for delivery") + + _ = cmd.MarkFlagRequired("name") + _ = cmd.MarkFlagRequired("message") + cmd.MarkFlagsMutuallyExclusive("every", "cron") + + return cmd +} diff --git a/cmd/picoclaw/internal/cron/add_test.go b/cmd/picoclaw/internal/cron/add_test.go new file mode 100644 index 000000000..09701fab5 --- /dev/null +++ b/cmd/picoclaw/internal/cron/add_test.go @@ -0,0 +1,57 @@ +package cron + +import ( + "testing" + + "github.com/spf13/cobra" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestNewAddSubcommand(t *testing.T) { + fn := func() string { return "" } + cmd := newAddCommand(fn) + + require.NotNil(t, cmd) + + assert.Equal(t, "add", cmd.Use) + assert.Equal(t, "Add a new scheduled job", cmd.Short) + + assert.True(t, cmd.HasFlags()) + + assert.NotNil(t, cmd.Flags().Lookup("every")) + assert.NotNil(t, cmd.Flags().Lookup("cron")) + assert.NotNil(t, cmd.Flags().Lookup("deliver")) + assert.NotNil(t, cmd.Flags().Lookup("to")) + assert.NotNil(t, cmd.Flags().Lookup("channel")) + + nameFlag := cmd.Flags().Lookup("name") + require.NotNil(t, nameFlag) + + messageFlag := cmd.Flags().Lookup("message") + require.NotNil(t, messageFlag) + + val, found := nameFlag.Annotations[cobra.BashCompOneRequiredFlag] + require.True(t, found) + require.NotEmpty(t, val) + assert.Equal(t, "true", val[0]) + + val, found = messageFlag.Annotations[cobra.BashCompOneRequiredFlag] + require.True(t, found) + require.NotEmpty(t, val) + assert.Equal(t, "true", val[0]) +} + +func TestNewAddCommandEveryAndCronMutuallyExclusive(t *testing.T) { + cmd := newAddCommand(func() string { return "testing" }) + + cmd.SetArgs([]string{ + "--name", "job", + "--message", "hello", + "--every", "10", + "--cron", "0 9 * * *", + }) + + err := cmd.Execute() + require.Error(t, err) +} diff --git a/cmd/picoclaw/internal/cron/command.go b/cmd/picoclaw/internal/cron/command.go new file mode 100644 index 000000000..39f8ccf28 --- /dev/null +++ b/cmd/picoclaw/internal/cron/command.go @@ -0,0 +1,44 @@ +package cron + +import ( + "fmt" + "path/filepath" + + "github.com/spf13/cobra" + + "github.com/sipeed/picoclaw/cmd/picoclaw/internal" +) + +func NewCronCommand() *cobra.Command { + var storePath string + + cmd := &cobra.Command{ + Use: "cron", + Aliases: []string{"c"}, + Short: "Manage scheduled tasks", + Args: cobra.NoArgs, + RunE: func(cmd *cobra.Command, _ []string) error { + return cmd.Help() + }, + // Resolve storePath at execution time so it reflects the current config + // and is shared across all subcommands. + PersistentPreRunE: func(_ *cobra.Command, _ []string) error { + cfg, err := internal.LoadConfig() + if err != nil { + return fmt.Errorf("error loading config: %w", err) + } + storePath = filepath.Join(cfg.WorkspacePath(), "cron", "jobs.json") + return nil + }, + } + + cmd.AddCommand( + newListCommand(func() string { return storePath }), + newAddCommand(func() string { return storePath }), + newRemoveCommand(func() string { return storePath }), + newEnableCommand(func() string { return storePath }), + newDisableCommand(func() string { return storePath }), + ) + + return cmd +} diff --git a/cmd/picoclaw/internal/cron/command_test.go b/cmd/picoclaw/internal/cron/command_test.go new file mode 100644 index 000000000..af2ac83ae --- /dev/null +++ b/cmd/picoclaw/internal/cron/command_test.go @@ -0,0 +1,58 @@ +package cron + +import ( + "slices" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestNewCronCommand(t *testing.T) { + cmd := NewCronCommand() + + require.NotNil(t, cmd) + + assert.Equal(t, "Manage scheduled tasks", cmd.Short) + + assert.Len(t, cmd.Aliases, 1) + assert.True(t, cmd.HasAlias("c")) + + assert.False(t, cmd.HasFlags()) + + assert.Nil(t, cmd.Run) + assert.NotNil(t, cmd.RunE) + + assert.NotNil(t, cmd.PersistentPreRunE) + assert.Nil(t, cmd.PersistentPreRun) + assert.Nil(t, cmd.PersistentPostRun) + + assert.True(t, cmd.HasSubCommands()) + + allowedCommands := []string{ + "list", + "add", + "remove", + "enable", + "disable", + } + + subcommands := cmd.Commands() + assert.Len(t, subcommands, len(allowedCommands)) + + for _, subcmd := range subcommands { + found := slices.Contains(allowedCommands, subcmd.Name()) + assert.True(t, found, "unexpected subcommand %q", subcmd.Name()) + + assert.Len(t, subcmd.Aliases, 0) + assert.False(t, subcmd.Hidden) + + assert.False(t, subcmd.HasSubCommands()) + + assert.Nil(t, subcmd.Run) + assert.NotNil(t, subcmd.RunE) + + assert.Nil(t, subcmd.PersistentPreRun) + assert.Nil(t, subcmd.PersistentPostRun) + } +} diff --git a/cmd/picoclaw/internal/cron/disable.go b/cmd/picoclaw/internal/cron/disable.go new file mode 100644 index 000000000..a3670fd50 --- /dev/null +++ b/cmd/picoclaw/internal/cron/disable.go @@ -0,0 +1,16 @@ +package cron + +import "github.com/spf13/cobra" + +func newDisableCommand(storePath func() string) *cobra.Command { + return &cobra.Command{ + Use: "disable", + Short: "Disable a job", + Args: cobra.ExactArgs(1), + Example: `picoclaw cron disable 1`, + RunE: func(_ *cobra.Command, args []string) error { + cronSetJobEnabled(storePath(), args[0], false) + return nil + }, + } +} diff --git a/cmd/picoclaw/internal/cron/disable_test.go b/cmd/picoclaw/internal/cron/disable_test.go new file mode 100644 index 000000000..e5d2ff844 --- /dev/null +++ b/cmd/picoclaw/internal/cron/disable_test.go @@ -0,0 +1,20 @@ +package cron + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestDisableSubcommand(t *testing.T) { + fn := func() string { return "" } + cmd := newDisableCommand(fn) + + require.NotNil(t, cmd) + + assert.Equal(t, "disable", cmd.Use) + assert.Equal(t, "Disable a job", cmd.Short) + + assert.True(t, cmd.HasExample()) +} diff --git a/cmd/picoclaw/internal/cron/enable.go b/cmd/picoclaw/internal/cron/enable.go new file mode 100644 index 000000000..7f8b05233 --- /dev/null +++ b/cmd/picoclaw/internal/cron/enable.go @@ -0,0 +1,16 @@ +package cron + +import "github.com/spf13/cobra" + +func newEnableCommand(storePath func() string) *cobra.Command { + return &cobra.Command{ + Use: "enable", + Short: "Enable a job", + Args: cobra.ExactArgs(1), + Example: `picoclaw cron enable 1`, + RunE: func(_ *cobra.Command, args []string) error { + cronSetJobEnabled(storePath(), args[0], true) + return nil + }, + } +} diff --git a/cmd/picoclaw/internal/cron/enable_test.go b/cmd/picoclaw/internal/cron/enable_test.go new file mode 100644 index 000000000..85a2e01aa --- /dev/null +++ b/cmd/picoclaw/internal/cron/enable_test.go @@ -0,0 +1,20 @@ +package cron + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestEnableSubcommand(t *testing.T) { + fn := func() string { return "" } + cmd := newEnableCommand(fn) + + require.NotNil(t, cmd) + + assert.Equal(t, "enable", cmd.Use) + assert.Equal(t, "Enable a job", cmd.Short) + + assert.True(t, cmd.HasExample()) +} diff --git a/cmd/picoclaw/internal/cron/helpers.go b/cmd/picoclaw/internal/cron/helpers.go new file mode 100644 index 000000000..88bdf1bf7 --- /dev/null +++ b/cmd/picoclaw/internal/cron/helpers.go @@ -0,0 +1,66 @@ +package cron + +import ( + "fmt" + "time" + + "github.com/sipeed/picoclaw/pkg/cron" +) + +func cronListCmd(storePath string) { + cs := cron.NewCronService(storePath, nil) + jobs := cs.ListJobs(true) // Show all jobs, including disabled + + if len(jobs) == 0 { + fmt.Println("No scheduled jobs.") + return + } + + fmt.Println("\nScheduled Jobs:") + fmt.Println("----------------") + for _, job := range jobs { + var schedule string + if job.Schedule.Kind == "every" && job.Schedule.EveryMS != nil { + schedule = fmt.Sprintf("every %ds", *job.Schedule.EveryMS/1000) + } else if job.Schedule.Kind == "cron" { + schedule = job.Schedule.Expr + } else { + schedule = "one-time" + } + + nextRun := "scheduled" + if job.State.NextRunAtMS != nil { + nextTime := time.UnixMilli(*job.State.NextRunAtMS) + nextRun = nextTime.Format("2006-01-02 15:04") + } + + status := "enabled" + if !job.Enabled { + status = "disabled" + } + + fmt.Printf(" %s (%s)\n", job.Name, job.ID) + fmt.Printf(" Schedule: %s\n", schedule) + fmt.Printf(" Status: %s\n", status) + fmt.Printf(" Next run: %s\n", nextRun) + } +} + +func cronRemoveCmd(storePath, jobID string) { + cs := cron.NewCronService(storePath, nil) + if cs.RemoveJob(jobID) { + fmt.Printf("✓ Removed job %s\n", jobID) + } else { + fmt.Printf("✗ Job %s not found\n", jobID) + } +} + +func cronSetJobEnabled(storePath, jobID string, enabled bool) { + cs := cron.NewCronService(storePath, nil) + job := cs.EnableJob(jobID, enabled) + if job != nil { + fmt.Printf("✓ Job '%s' enabled\n", job.Name) + } else { + fmt.Printf("✗ Job %s not found\n", jobID) + } +} diff --git a/cmd/picoclaw/internal/cron/list.go b/cmd/picoclaw/internal/cron/list.go new file mode 100644 index 000000000..854eb1a44 --- /dev/null +++ b/cmd/picoclaw/internal/cron/list.go @@ -0,0 +1,17 @@ +package cron + +import "github.com/spf13/cobra" + +func newListCommand(storePath func() string) *cobra.Command { + cmd := &cobra.Command{ + Use: "list", + Short: "List all scheduled jobs", + Args: cobra.NoArgs, + RunE: func(_ *cobra.Command, _ []string) error { + cronListCmd(storePath()) + return nil + }, + } + + return cmd +} diff --git a/cmd/picoclaw/internal/cron/list_test.go b/cmd/picoclaw/internal/cron/list_test.go new file mode 100644 index 000000000..0b9d1bd59 --- /dev/null +++ b/cmd/picoclaw/internal/cron/list_test.go @@ -0,0 +1,17 @@ +package cron + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestNewListSubcommand(t *testing.T) { + fn := func() string { return "" } + cmd := newListCommand(fn) + + require.NotNil(t, cmd) + + assert.Equal(t, "List all scheduled jobs", cmd.Short) +} diff --git a/cmd/picoclaw/internal/cron/remove.go b/cmd/picoclaw/internal/cron/remove.go new file mode 100644 index 000000000..5f1d1a04b --- /dev/null +++ b/cmd/picoclaw/internal/cron/remove.go @@ -0,0 +1,18 @@ +package cron + +import "github.com/spf13/cobra" + +func newRemoveCommand(storePath func() string) *cobra.Command { + cmd := &cobra.Command{ + Use: "remove", + Short: "Remove a job by ID", + Args: cobra.ExactArgs(1), + Example: `picoclaw cron remove 1`, + RunE: func(_ *cobra.Command, args []string) error { + cronRemoveCmd(storePath(), args[0]) + return nil + }, + } + + return cmd +} diff --git a/cmd/picoclaw/internal/cron/remove_test.go b/cmd/picoclaw/internal/cron/remove_test.go new file mode 100644 index 000000000..36121f370 --- /dev/null +++ b/cmd/picoclaw/internal/cron/remove_test.go @@ -0,0 +1,19 @@ +package cron + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestNewRemoveSubcommand(t *testing.T) { + fn := func() string { return "" } + cmd := newRemoveCommand(fn) + + require.NotNil(t, cmd) + + assert.Equal(t, "Remove a job by ID", cmd.Short) + + assert.True(t, cmd.HasExample()) +} diff --git a/cmd/picoclaw/internal/gateway/command.go b/cmd/picoclaw/internal/gateway/command.go new file mode 100644 index 000000000..66a56f9ce --- /dev/null +++ b/cmd/picoclaw/internal/gateway/command.go @@ -0,0 +1,23 @@ +package gateway + +import ( + "github.com/spf13/cobra" +) + +func NewGatewayCommand() *cobra.Command { + var debug bool + + cmd := &cobra.Command{ + Use: "gateway", + Aliases: []string{"g"}, + Short: "Start picoclaw gateway", + Args: cobra.NoArgs, + RunE: func(_ *cobra.Command, _ []string) error { + return gatewayCmd(debug) + }, + } + + cmd.Flags().BoolVarP(&debug, "debug", "d", false, "Enable debug logging") + + return cmd +} diff --git a/cmd/picoclaw/internal/gateway/command_test.go b/cmd/picoclaw/internal/gateway/command_test.go new file mode 100644 index 000000000..4d591ea67 --- /dev/null +++ b/cmd/picoclaw/internal/gateway/command_test.go @@ -0,0 +1,31 @@ +package gateway + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestNewGatewayCommand(t *testing.T) { + cmd := NewGatewayCommand() + + require.NotNil(t, cmd) + + assert.Equal(t, "gateway", cmd.Use) + assert.Equal(t, "Start picoclaw gateway", cmd.Short) + + assert.Len(t, cmd.Aliases, 1) + assert.True(t, cmd.HasAlias("g")) + + assert.Nil(t, cmd.Run) + assert.NotNil(t, cmd.RunE) + + assert.Nil(t, cmd.PersistentPreRun) + assert.Nil(t, cmd.PersistentPostRun) + + assert.False(t, cmd.HasSubCommands()) + + assert.True(t, cmd.HasFlags()) + assert.NotNil(t, cmd.Flags().Lookup("debug")) +} diff --git a/cmd/picoclaw/cmd_gateway.go b/cmd/picoclaw/internal/gateway/helpers.go similarity index 91% rename from cmd/picoclaw/cmd_gateway.go rename to cmd/picoclaw/internal/gateway/helpers.go index 3010c1451..a06625dc9 100644 --- a/cmd/picoclaw/cmd_gateway.go +++ b/cmd/picoclaw/internal/gateway/helpers.go @@ -1,10 +1,8 @@ -// PicoClaw - Ultra-lightweight personal AI agent -// License: MIT - -package main +package gateway import ( "context" + "errors" "fmt" "net/http" "os" @@ -13,6 +11,7 @@ import ( "strings" "time" + "github.com/sipeed/picoclaw/cmd/picoclaw/internal" "github.com/sipeed/picoclaw/pkg/agent" "github.com/sipeed/picoclaw/pkg/bus" "github.com/sipeed/picoclaw/pkg/channels" @@ -28,28 +27,22 @@ import ( "github.com/sipeed/picoclaw/pkg/voice" ) -func gatewayCmd() { - // Check for --debug flag - args := os.Args[2:] - for _, arg := range args { - if arg == "--debug" || arg == "-d" { - logger.SetLevel(logger.DEBUG) - fmt.Println("🔍 Debug mode enabled") - break - } +func gatewayCmd(debug bool) error { + if debug { + logger.SetLevel(logger.DEBUG) + fmt.Println("🔍 Debug mode enabled") } - cfg, err := loadConfig() + cfg, err := internal.LoadConfig() if err != nil { - fmt.Printf("Error loading config: %v\n", err) - os.Exit(1) + return fmt.Errorf("error loading config: %w", err) } provider, modelID, err := providers.CreateProvider(cfg) if err != nil { - fmt.Printf("Error creating provider: %v\n", err) - os.Exit(1) + return fmt.Errorf("error creating provider: %w", err) } + // Use the resolved model ID from provider creation if modelID != "" { cfg.Agents.Defaults.ModelName = modelID @@ -114,8 +107,7 @@ func gatewayCmd() { channelManager, err := channels.NewManager(cfg, msgBus) if err != nil { - fmt.Printf("Error creating channel manager: %v\n", err) - os.Exit(1) + return fmt.Errorf("error creating channel manager: %w", err) } // Inject channel manager into agent loop for command handling @@ -198,7 +190,7 @@ func gatewayCmd() { healthServer := health.NewServer(cfg.Gateway.Host, cfg.Gateway.Port) go func() { - if err := healthServer.Start(); err != nil && err != http.ErrServerClosed { + if err := healthServer.Start(); err != nil && !errors.Is(err, http.ErrServerClosed) { logger.ErrorCF("health", "Health server error", map[string]any{"error": err.Error()}) } }() @@ -222,6 +214,8 @@ func gatewayCmd() { agentLoop.Stop() channelManager.StopAll(ctx) fmt.Println("✓ Gateway stopped") + + return nil } func setupCronTool( diff --git a/cmd/picoclaw/internal/helpers.go b/cmd/picoclaw/internal/helpers.go new file mode 100644 index 000000000..1f52df5dd --- /dev/null +++ b/cmd/picoclaw/internal/helpers.go @@ -0,0 +1,52 @@ +package internal + +import ( + "fmt" + "os" + "path/filepath" + "runtime" + + "github.com/sipeed/picoclaw/pkg/config" +) + +const Logo = "🦞" + +var ( + version = "dev" + gitCommit string + buildTime string + goVersion string +) + +func GetConfigPath() string { + home, _ := os.UserHomeDir() + return filepath.Join(home, ".picoclaw", "config.json") +} + +func LoadConfig() (*config.Config, error) { + return config.LoadConfig(GetConfigPath()) +} + +// FormatVersion returns the version string with optional git commit +func FormatVersion() string { + v := version + if gitCommit != "" { + v += fmt.Sprintf(" (git: %s)", gitCommit) + } + return v +} + +// FormatBuildInfo returns build time and go version info +func FormatBuildInfo() (string, string) { + build := buildTime + goVer := goVersion + if goVer == "" { + goVer = runtime.Version() + } + return build, goVer +} + +// GetVersion returns the version string +func GetVersion() string { + return version +} diff --git a/cmd/picoclaw/internal/helpers_test.go b/cmd/picoclaw/internal/helpers_test.go new file mode 100644 index 000000000..9342d141d --- /dev/null +++ b/cmd/picoclaw/internal/helpers_test.go @@ -0,0 +1,97 @@ +package internal + +import ( + "path/filepath" + "runtime" + "strings" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestGetConfigPath(t *testing.T) { + t.Setenv("HOME", "/tmp/home") + + got := GetConfigPath() + want := filepath.Join("/tmp/home", ".picoclaw", "config.json") + + assert.Equal(t, want, got) +} + +func TestFormatVersion_NoGitCommit(t *testing.T) { + oldVersion, oldGit := version, gitCommit + t.Cleanup(func() { version, gitCommit = oldVersion, oldGit }) + + version = "1.2.3" + gitCommit = "" + + assert.Equal(t, "1.2.3", FormatVersion()) +} + +func TestFormatVersion_WithGitCommit(t *testing.T) { + oldVersion, oldGit := version, gitCommit + t.Cleanup(func() { version, gitCommit = oldVersion, oldGit }) + + version = "1.2.3" + gitCommit = "abc123" + + assert.Equal(t, "1.2.3 (git: abc123)", FormatVersion()) +} + +func TestFormatBuildInfo_UsesBuildTimeAndGoVersion_WhenSet(t *testing.T) { + oldBuildTime, oldGoVersion := buildTime, goVersion + t.Cleanup(func() { buildTime, goVersion = oldBuildTime, oldGoVersion }) + + buildTime = "2026-02-20T00:00:00Z" + goVersion = "go1.23.0" + + build, goVer := FormatBuildInfo() + + assert.Equal(t, buildTime, build) + assert.Equal(t, goVersion, goVer) +} + +func TestFormatBuildInfo_EmptyBuildTime_ReturnsEmptyBuild(t *testing.T) { + oldBuildTime, oldGoVersion := buildTime, goVersion + t.Cleanup(func() { buildTime, goVersion = oldBuildTime, oldGoVersion }) + + buildTime = "" + goVersion = "go1.23.0" + + build, goVer := FormatBuildInfo() + + assert.Empty(t, build) + assert.Equal(t, goVersion, goVer) +} + +func TestFormatBuildInfo_EmptyGoVersion_FallsBackToRuntimeVersion(t *testing.T) { + oldBuildTime, oldGoVersion := buildTime, goVersion + t.Cleanup(func() { buildTime, goVersion = oldBuildTime, oldGoVersion }) + + buildTime = "x" + goVersion = "" + + build, goVer := FormatBuildInfo() + + assert.Equal(t, "x", build) + assert.Equal(t, runtime.Version(), goVer) +} + +func TestGetConfigPath_Windows(t *testing.T) { + if runtime.GOOS != "windows" { + t.Skip("windows-specific HOME behavior varies; run on windows") + } + + testUserProfilePath := `C:\Users\Test` + t.Setenv("USERPROFILE", testUserProfilePath) + + got := GetConfigPath() + want := filepath.Join(testUserProfilePath, ".picoclaw", "config.json") + + require.True(t, strings.EqualFold(got, want), "GetConfigPath() = %q, want %q", got, want) +} + +func TestGetVersion(t *testing.T) { + assert.Equal(t, "dev", GetVersion()) +} diff --git a/cmd/picoclaw/internal/migrate/command.go b/cmd/picoclaw/internal/migrate/command.go new file mode 100644 index 000000000..fb1cee164 --- /dev/null +++ b/cmd/picoclaw/internal/migrate/command.go @@ -0,0 +1,48 @@ +package migrate + +import ( + "github.com/spf13/cobra" + + "github.com/sipeed/picoclaw/pkg/migrate" +) + +func NewMigrateCommand() *cobra.Command { + var opts migrate.Options + + cmd := &cobra.Command{ + Use: "migrate", + Short: "Migrate from OpenClaw to PicoClaw", + Args: cobra.NoArgs, + Example: ` picoclaw migrate + picoclaw migrate --dry-run + picoclaw migrate --refresh + picoclaw migrate --force`, + RunE: func(cmd *cobra.Command, _ []string) error { + result, err := migrate.Run(opts) + if err != nil { + return err + } + if !opts.DryRun { + migrate.PrintSummary(result) + } + return nil + }, + } + + cmd.Flags().BoolVar(&opts.DryRun, "dry-run", false, + "Show what would be migrated without making changes") + cmd.Flags().BoolVar(&opts.Refresh, "refresh", false, + "Re-sync workspace files from OpenClaw (repeatable)") + cmd.Flags().BoolVar(&opts.ConfigOnly, "config-only", false, + "Only migrate config, skip workspace files") + cmd.Flags().BoolVar(&opts.WorkspaceOnly, "workspace-only", false, + "Only migrate workspace files, skip config") + cmd.Flags().BoolVar(&opts.Force, "force", false, + "Skip confirmation prompts") + cmd.Flags().StringVar(&opts.OpenClawHome, "openclaw-home", "", + "Override OpenClaw home directory (default: ~/.openclaw)") + cmd.Flags().StringVar(&opts.PicoClawHome, "picoclaw-home", "", + "Override PicoClaw home directory (default: ~/.picoclaw)") + + return cmd +} diff --git a/cmd/picoclaw/internal/migrate/command_test.go b/cmd/picoclaw/internal/migrate/command_test.go new file mode 100644 index 000000000..1948aa327 --- /dev/null +++ b/cmd/picoclaw/internal/migrate/command_test.go @@ -0,0 +1,38 @@ +package migrate + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestNewMigrateCommand(t *testing.T) { + cmd := NewMigrateCommand() + + require.NotNil(t, cmd) + + assert.Equal(t, "migrate", cmd.Use) + assert.Equal(t, "Migrate from OpenClaw to PicoClaw", cmd.Short) + + assert.Len(t, cmd.Aliases, 0) + + assert.True(t, cmd.HasExample()) + assert.False(t, cmd.HasSubCommands()) + + assert.Nil(t, cmd.Run) + assert.NotNil(t, cmd.RunE) + + assert.Nil(t, cmd.PersistentPreRun) + assert.Nil(t, cmd.PersistentPostRun) + + assert.True(t, cmd.HasFlags()) + + assert.NotNil(t, cmd.Flags().Lookup("dry-run")) + assert.NotNil(t, cmd.Flags().Lookup("refresh")) + assert.NotNil(t, cmd.Flags().Lookup("config-only")) + assert.NotNil(t, cmd.Flags().Lookup("workspace-only")) + assert.NotNil(t, cmd.Flags().Lookup("force")) + assert.NotNil(t, cmd.Flags().Lookup("openclaw-home")) + assert.NotNil(t, cmd.Flags().Lookup("picoclaw-home")) +} diff --git a/cmd/picoclaw/internal/onboard/command.go b/cmd/picoclaw/internal/onboard/command.go new file mode 100644 index 000000000..ec1012959 --- /dev/null +++ b/cmd/picoclaw/internal/onboard/command.go @@ -0,0 +1,24 @@ +package onboard + +import ( + "embed" + + "github.com/spf13/cobra" +) + +//go:generate cp -r ../../../../workspace . +//go:embed workspace +var embeddedFiles embed.FS + +func NewOnboardCommand() *cobra.Command { + cmd := &cobra.Command{ + Use: "onboard", + Aliases: []string{"o"}, + Short: "Initialize picoclaw configuration and workspace", + Run: func(cmd *cobra.Command, args []string) { + onboard() + }, + } + + return cmd +} diff --git a/cmd/picoclaw/internal/onboard/command_test.go b/cmd/picoclaw/internal/onboard/command_test.go new file mode 100644 index 000000000..bc799a079 --- /dev/null +++ b/cmd/picoclaw/internal/onboard/command_test.go @@ -0,0 +1,29 @@ +package onboard + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestNewOnboardCommand(t *testing.T) { + cmd := NewOnboardCommand() + + require.NotNil(t, cmd) + + assert.Equal(t, "onboard", cmd.Use) + assert.Equal(t, "Initialize picoclaw configuration and workspace", cmd.Short) + + assert.Len(t, cmd.Aliases, 1) + assert.True(t, cmd.HasAlias("o")) + + assert.NotNil(t, cmd.Run) + assert.Nil(t, cmd.RunE) + + assert.Nil(t, cmd.PersistentPreRun) + assert.Nil(t, cmd.PersistentPostRun) + + assert.False(t, cmd.HasFlags()) + assert.False(t, cmd.HasSubCommands()) +} diff --git a/cmd/picoclaw/cmd_onboard.go b/cmd/picoclaw/internal/onboard/helpers.go similarity index 90% rename from cmd/picoclaw/cmd_onboard.go rename to cmd/picoclaw/internal/onboard/helpers.go index 1a9ebad61..4db8bdc8b 100644 --- a/cmd/picoclaw/cmd_onboard.go +++ b/cmd/picoclaw/internal/onboard/helpers.go @@ -1,24 +1,17 @@ -// PicoClaw - Ultra-lightweight personal AI agent -// License: MIT - -package main +package onboard import ( - "embed" "fmt" "io/fs" "os" "path/filepath" + "github.com/sipeed/picoclaw/cmd/picoclaw/internal" "github.com/sipeed/picoclaw/pkg/config" ) -//go:generate cp -r ../../workspace . -//go:embed workspace -var embeddedFiles embed.FS - func onboard() { - configPath := getConfigPath() + configPath := internal.GetConfigPath() if _, err := os.Stat(configPath); err == nil { fmt.Printf("Config already exists at %s\n", configPath) @@ -40,7 +33,7 @@ func onboard() { workspace := cfg.WorkspacePath() createWorkspaceTemplates(workspace) - fmt.Printf("%s picoclaw is ready!\n", logo) + fmt.Printf("%s picoclaw is ready!\n", internal.Logo) fmt.Println("\nNext steps:") fmt.Println(" 1. Add your API key to", configPath) fmt.Println("") @@ -53,6 +46,13 @@ func onboard() { fmt.Println(" 2. Chat: picoclaw agent -m \"Hello!\"") } +func createWorkspaceTemplates(workspace string) { + err := copyEmbeddedToTarget(workspace) + if err != nil { + fmt.Printf("Error copying workspace templates: %v\n", err) + } +} + func copyEmbeddedToTarget(targetDir string) error { // Ensure target directory exists if err := os.MkdirAll(targetDir, 0o755); err != nil { @@ -99,10 +99,3 @@ func copyEmbeddedToTarget(targetDir string) error { return err } - -func createWorkspaceTemplates(workspace string) { - err := copyEmbeddedToTarget(workspace) - if err != nil { - fmt.Printf("Error copying workspace templates: %v\n", err) - } -} diff --git a/cmd/picoclaw/internal/skills/command.go b/cmd/picoclaw/internal/skills/command.go new file mode 100644 index 000000000..7f8bd011d --- /dev/null +++ b/cmd/picoclaw/internal/skills/command.go @@ -0,0 +1,79 @@ +package skills + +import ( + "fmt" + "path/filepath" + + "github.com/spf13/cobra" + + "github.com/sipeed/picoclaw/cmd/picoclaw/internal" + "github.com/sipeed/picoclaw/pkg/skills" +) + +type deps struct { + workspace string + installer *skills.SkillInstaller + skillsLoader *skills.SkillsLoader +} + +func NewSkillsCommand() *cobra.Command { + var d deps + + cmd := &cobra.Command{ + Use: "skills", + Short: "Manage skills", + PersistentPreRunE: func(cmd *cobra.Command, _ []string) error { + cfg, err := internal.LoadConfig() + if err != nil { + return fmt.Errorf("error loading config: %w", err) + } + + d.workspace = cfg.WorkspacePath() + d.installer = skills.NewSkillInstaller(d.workspace) + + // get global config directory and builtin skills directory + globalDir := filepath.Dir(internal.GetConfigPath()) + globalSkillsDir := filepath.Join(globalDir, "skills") + builtinSkillsDir := filepath.Join(globalDir, "picoclaw", "skills") + d.skillsLoader = skills.NewSkillsLoader(d.workspace, globalSkillsDir, builtinSkillsDir) + + return nil + }, + RunE: func(cmd *cobra.Command, _ []string) error { + return cmd.Help() + }, + } + + installerFn := func() (*skills.SkillInstaller, error) { + if d.installer == nil { + return nil, fmt.Errorf("skills installer is not initialized") + } + return d.installer, nil + } + + loaderFn := func() (*skills.SkillsLoader, error) { + if d.skillsLoader == nil { + return nil, fmt.Errorf("skills loader is not initialized") + } + return d.skillsLoader, nil + } + + workspaceFn := func() (string, error) { + if d.workspace == "" { + return "", fmt.Errorf("workspace is not initialized") + } + return d.workspace, nil + } + + cmd.AddCommand( + newListCommand(loaderFn), + newInstallCommand(installerFn), + newInstallBuiltinCommand(workspaceFn), + newListBuiltinCommand(), + newRemoveCommand(installerFn), + newSearchCommand(installerFn), + newShowCommand(loaderFn), + ) + + return cmd +} diff --git a/cmd/picoclaw/internal/skills/command_test.go b/cmd/picoclaw/internal/skills/command_test.go new file mode 100644 index 000000000..0917d1384 --- /dev/null +++ b/cmd/picoclaw/internal/skills/command_test.go @@ -0,0 +1,28 @@ +package skills + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestNewSkillsCommand(t *testing.T) { + cmd := NewSkillsCommand() + + require.NotNil(t, cmd) + + assert.Equal(t, "skills", cmd.Use) + assert.Equal(t, "Manage skills", cmd.Short) + + assert.Len(t, cmd.Aliases, 0) + + assert.False(t, cmd.HasFlags()) + + assert.Nil(t, cmd.Run) + assert.NotNil(t, cmd.RunE) + + assert.NotNil(t, cmd.PersistentPreRunE) + assert.Nil(t, cmd.PersistentPreRun) + assert.Nil(t, cmd.PersistentPostRun) +} diff --git a/cmd/picoclaw/cmd_skills.go b/cmd/picoclaw/internal/skills/helpers.go similarity index 72% rename from cmd/picoclaw/cmd_skills.go rename to cmd/picoclaw/internal/skills/helpers.go index 0814494b3..439b81a4f 100644 --- a/cmd/picoclaw/cmd_skills.go +++ b/cmd/picoclaw/internal/skills/helpers.go @@ -1,40 +1,20 @@ -// PicoClaw - Ultra-lightweight personal AI agent -// License: MIT - -package main +package skills import ( "context" "fmt" + "io" "os" "path/filepath" "strings" "time" + "github.com/sipeed/picoclaw/cmd/picoclaw/internal" "github.com/sipeed/picoclaw/pkg/config" "github.com/sipeed/picoclaw/pkg/skills" "github.com/sipeed/picoclaw/pkg/utils" ) -func skillsHelp() { - fmt.Println("\nSkills commands:") - fmt.Println(" list List installed skills") - fmt.Println(" install Install skill from GitHub") - fmt.Println(" install-builtin Install all builtin skills to workspace") - fmt.Println(" list-builtin List available builtin skills") - fmt.Println(" remove Remove installed skill") - fmt.Println(" search Search available skills") - fmt.Println(" show Show skill details") - fmt.Println() - fmt.Println("Examples:") - fmt.Println(" picoclaw skills list") - fmt.Println(" picoclaw skills install sipeed/picoclaw-skills/weather") - fmt.Println(" picoclaw skills install-builtin") - fmt.Println(" picoclaw skills list-builtin") - fmt.Println(" picoclaw skills remove weather") - fmt.Println(" picoclaw skills install --registry clawhub github") -} - func skillsListCmd(loader *skills.SkillsLoader) { allSkills := loader.ListSkills() @@ -53,53 +33,31 @@ func skillsListCmd(loader *skills.SkillsLoader) { } } -func skillsInstallCmd(installer *skills.SkillInstaller, cfg *config.Config) { - if len(os.Args) < 4 { - fmt.Println("Usage: picoclaw skills install ") - fmt.Println(" picoclaw skills install --registry ") - return - } - - // Check for --registry flag. - if os.Args[3] == "--registry" { - if len(os.Args) < 6 { - fmt.Println("Usage: picoclaw skills install --registry ") - fmt.Println("Example: picoclaw skills install --registry clawhub github") - return - } - registryName := os.Args[4] - slug := os.Args[5] - skillsInstallFromRegistry(cfg, registryName, slug) - return - } - - // Default: install from GitHub (backward compatible). - repo := os.Args[3] +func skillsInstallCmd(installer *skills.SkillInstaller, repo string) error { fmt.Printf("Installing skill from %s...\n", repo) ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) defer cancel() if err := installer.InstallFromGitHub(ctx, repo); err != nil { - fmt.Printf("\u2717 Failed to install skill: %v\n", err) - os.Exit(1) + return fmt.Errorf("failed to install skill: %w", err) } fmt.Printf("\u2713 Skill '%s' installed successfully!\n", filepath.Base(repo)) + + return nil } // skillsInstallFromRegistry installs a skill from a named registry (e.g. clawhub). -func skillsInstallFromRegistry(cfg *config.Config, registryName, slug string) { +func skillsInstallFromRegistry(cfg *config.Config, registryName, slug string) error { err := utils.ValidateSkillIdentifier(registryName) if err != nil { - fmt.Printf("\u2717 Invalid registry name: %v\n", err) - os.Exit(1) + return fmt.Errorf("✗ invalid registry name: %w", err) } err = utils.ValidateSkillIdentifier(slug) if err != nil { - fmt.Printf("\u2717 Invalid slug: %v\n", err) - os.Exit(1) + return fmt.Errorf("✗ invalid slug: %w", err) } fmt.Printf("Installing skill '%s' from %s registry...\n", slug, registryName) @@ -111,24 +69,21 @@ func skillsInstallFromRegistry(cfg *config.Config, registryName, slug string) { registry := registryMgr.GetRegistry(registryName) if registry == nil { - fmt.Printf("\u2717 Registry '%s' not found or not enabled. Check your config.json.\n", registryName) - os.Exit(1) + return fmt.Errorf("✗ registry '%s' not found or not enabled. check your config.json.", registryName) } workspace := cfg.WorkspacePath() targetDir := filepath.Join(workspace, "skills", slug) if _, err = os.Stat(targetDir); err == nil { - fmt.Printf("\u2717 Skill '%s' already installed at %s\n", slug, targetDir) - os.Exit(1) + return fmt.Errorf("\u2717 skill '%s' already installed at %s", slug, targetDir) } ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second) defer cancel() if err = os.MkdirAll(filepath.Join(workspace, "skills"), 0o755); err != nil { - fmt.Printf("\u2717 Failed to create skills directory: %v\n", err) - os.Exit(1) + return fmt.Errorf("\u2717 failed to create skills directory: %v", err) } result, err := registry.DownloadAndInstall(ctx, slug, "", targetDir) @@ -137,8 +92,7 @@ func skillsInstallFromRegistry(cfg *config.Config, registryName, slug string) { if rmErr != nil { fmt.Printf("\u2717 Failed to remove partial install: %v\n", rmErr) } - fmt.Printf("\u2717 Failed to install skill: %v\n", err) - os.Exit(1) + return fmt.Errorf("✗ failed to install skill: %w", err) } if result.IsMalwareBlocked { @@ -146,8 +100,8 @@ func skillsInstallFromRegistry(cfg *config.Config, registryName, slug string) { if rmErr != nil { fmt.Printf("\u2717 Failed to remove partial install: %v\n", rmErr) } - fmt.Printf("\u2717 Skill '%s' is flagged as malicious and cannot be installed.\n", slug) - os.Exit(1) + + return fmt.Errorf("\u2717 Skill '%s' is flagged as malicious and cannot be installed.\n", slug) } if result.IsSuspicious { @@ -158,6 +112,8 @@ func skillsInstallFromRegistry(cfg *config.Config, registryName, slug string) { if result.Summary != "" { fmt.Printf(" %s\n", result.Summary) } + + return nil } func skillsRemoveCmd(installer *skills.SkillInstaller, skillName string) { @@ -208,7 +164,7 @@ func skillsInstallBuiltinCmd(workspace string) { } func skillsListBuiltinCmd() { - cfg, err := loadConfig() + cfg, err := internal.LoadConfig() if err != nil { fmt.Printf("Error loading config: %v\n", err) return @@ -303,3 +259,37 @@ func skillsShowCmd(loader *skills.SkillsLoader, skillName string) { fmt.Println("----------------------") fmt.Println(content) } + +func copyDirectory(src, dst string) error { + return filepath.Walk(src, func(path string, info os.FileInfo, err error) error { + if err != nil { + return err + } + + relPath, err := filepath.Rel(src, path) + if err != nil { + return err + } + + dstPath := filepath.Join(dst, relPath) + + if info.IsDir() { + return os.MkdirAll(dstPath, info.Mode()) + } + + srcFile, err := os.Open(path) + if err != nil { + return err + } + defer srcFile.Close() + + dstFile, err := os.OpenFile(dstPath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, info.Mode()) + if err != nil { + return err + } + defer dstFile.Close() + + _, err = io.Copy(dstFile, srcFile) + return err + }) +} diff --git a/cmd/picoclaw/internal/skills/install.go b/cmd/picoclaw/internal/skills/install.go new file mode 100644 index 000000000..a30f68632 --- /dev/null +++ b/cmd/picoclaw/internal/skills/install.go @@ -0,0 +1,58 @@ +package skills + +import ( + "fmt" + + "github.com/spf13/cobra" + + "github.com/sipeed/picoclaw/cmd/picoclaw/internal" + "github.com/sipeed/picoclaw/pkg/skills" +) + +func newInstallCommand(installerFn func() (*skills.SkillInstaller, error)) *cobra.Command { + var registry string + + cmd := &cobra.Command{ + Use: "install", + Short: "Install skill from GitHub", + Example: ` +picoclaw skills install sipeed/picoclaw-skills/weather +picoclaw skills install --registry clawhub github +`, + Args: func(cmd *cobra.Command, args []string) error { + if registry != "" { + if len(args) != 2 { + return fmt.Errorf("when --registry is set, exactly 2 arguments are required: ") + } + return nil + } + + if len(args) != 1 { + return fmt.Errorf("exactly 1 argument is required: ") + } + + return nil + }, + RunE: func(_ *cobra.Command, args []string) error { + installer, err := installerFn() + if err != nil { + return err + } + + if registry != "" { + cfg, err := internal.LoadConfig() + if err != nil { + return err + } + + return skillsInstallFromRegistry(cfg, args[0], args[1]) + } + + return skillsInstallCmd(installer, args[0]) + }, + } + + cmd.Flags().StringVar(®istry, "registry", "", "Install from registry: --registry ") + + return cmd +} diff --git a/cmd/picoclaw/internal/skills/install_test.go b/cmd/picoclaw/internal/skills/install_test.go new file mode 100644 index 000000000..97787a986 --- /dev/null +++ b/cmd/picoclaw/internal/skills/install_test.go @@ -0,0 +1,28 @@ +package skills + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestNewInstallSubcommand(t *testing.T) { + cmd := newInstallCommand(nil) + + require.NotNil(t, cmd) + + assert.Equal(t, "install", cmd.Use) + assert.Equal(t, "Install skill from GitHub", cmd.Short) + + assert.Nil(t, cmd.Run) + assert.NotNil(t, cmd.RunE) + + assert.True(t, cmd.HasExample()) + assert.False(t, cmd.HasSubCommands()) + + assert.True(t, cmd.HasFlags()) + assert.NotNil(t, cmd.Flags().Lookup("registry")) + + assert.Len(t, cmd.Aliases, 0) +} diff --git a/cmd/picoclaw/internal/skills/installbuiltin.go b/cmd/picoclaw/internal/skills/installbuiltin.go new file mode 100644 index 000000000..d4b7c6a9f --- /dev/null +++ b/cmd/picoclaw/internal/skills/installbuiltin.go @@ -0,0 +1,21 @@ +package skills + +import "github.com/spf13/cobra" + +func newInstallBuiltinCommand(workspaceFn func() (string, error)) *cobra.Command { + cmd := &cobra.Command{ + Use: "install-builtin", + Short: "Install all builtin skills to workspace", + Example: `picoclaw skills install-builtin`, + RunE: func(_ *cobra.Command, _ []string) error { + workspace, err := workspaceFn() + if err != nil { + return err + } + skillsInstallBuiltinCmd(workspace) + return nil + }, + } + + return cmd +} diff --git a/cmd/picoclaw/internal/skills/installbuiltin_test.go b/cmd/picoclaw/internal/skills/installbuiltin_test.go new file mode 100644 index 000000000..ea65907e3 --- /dev/null +++ b/cmd/picoclaw/internal/skills/installbuiltin_test.go @@ -0,0 +1,27 @@ +package skills + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestNewInstallbuiltinSubcommand(t *testing.T) { + cmd := newInstallBuiltinCommand(nil) + + require.NotNil(t, cmd) + + assert.Equal(t, "install-builtin", cmd.Use) + assert.Equal(t, "Install all builtin skills to workspace", cmd.Short) + + assert.Nil(t, cmd.Run) + assert.NotNil(t, cmd.RunE) + + assert.True(t, cmd.HasExample()) + assert.False(t, cmd.HasSubCommands()) + + assert.False(t, cmd.HasFlags()) + + assert.Len(t, cmd.Aliases, 0) +} diff --git a/cmd/picoclaw/internal/skills/list.go b/cmd/picoclaw/internal/skills/list.go new file mode 100644 index 000000000..7d89ff8ed --- /dev/null +++ b/cmd/picoclaw/internal/skills/list.go @@ -0,0 +1,25 @@ +package skills + +import ( + "github.com/spf13/cobra" + + "github.com/sipeed/picoclaw/pkg/skills" +) + +func newListCommand(loaderFn func() (*skills.SkillsLoader, error)) *cobra.Command { + cmd := &cobra.Command{ + Use: "list", + Short: "List installed skills", + Example: `picoclaw skills list`, + RunE: func(_ *cobra.Command, _ []string) error { + loader, err := loaderFn() + if err != nil { + return err + } + skillsListCmd(loader) + return nil + }, + } + + return cmd +} diff --git a/cmd/picoclaw/internal/skills/list_test.go b/cmd/picoclaw/internal/skills/list_test.go new file mode 100644 index 000000000..9947ce7aa --- /dev/null +++ b/cmd/picoclaw/internal/skills/list_test.go @@ -0,0 +1,27 @@ +package skills + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestNewListSubcommand(t *testing.T) { + cmd := newListCommand(nil) + + require.NotNil(t, cmd) + + assert.Equal(t, "list", cmd.Use) + assert.Equal(t, "List installed skills", cmd.Short) + + assert.Nil(t, cmd.Run) + assert.NotNil(t, cmd.RunE) + + assert.True(t, cmd.HasExample()) + assert.False(t, cmd.HasSubCommands()) + + assert.False(t, cmd.HasFlags()) + + assert.Len(t, cmd.Aliases, 0) +} diff --git a/cmd/picoclaw/internal/skills/listbuiltin.go b/cmd/picoclaw/internal/skills/listbuiltin.go new file mode 100644 index 000000000..a3efb8d83 --- /dev/null +++ b/cmd/picoclaw/internal/skills/listbuiltin.go @@ -0,0 +1,16 @@ +package skills + +import "github.com/spf13/cobra" + +func newListBuiltinCommand() *cobra.Command { + cmd := &cobra.Command{ + Use: "list-builtin", + Short: "List available builtin skills", + Example: `picoclaw skills list-builtin`, + Run: func(_ *cobra.Command, _ []string) { + skillsListBuiltinCmd() + }, + } + + return cmd +} diff --git a/cmd/picoclaw/internal/skills/listbuiltin_test.go b/cmd/picoclaw/internal/skills/listbuiltin_test.go new file mode 100644 index 000000000..d4f45a436 --- /dev/null +++ b/cmd/picoclaw/internal/skills/listbuiltin_test.go @@ -0,0 +1,26 @@ +package skills + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestNewListbuiltinSubcommand(t *testing.T) { + cmd := newListBuiltinCommand() + + require.NotNil(t, cmd) + + assert.Equal(t, "list-builtin", cmd.Use) + assert.Equal(t, "List available builtin skills", cmd.Short) + + assert.NotNil(t, cmd.Run) + + assert.True(t, cmd.HasExample()) + assert.False(t, cmd.HasSubCommands()) + + assert.False(t, cmd.HasFlags()) + + assert.Len(t, cmd.Aliases, 0) +} diff --git a/cmd/picoclaw/internal/skills/remove.go b/cmd/picoclaw/internal/skills/remove.go new file mode 100644 index 000000000..cd7d3a8b4 --- /dev/null +++ b/cmd/picoclaw/internal/skills/remove.go @@ -0,0 +1,27 @@ +package skills + +import ( + "github.com/spf13/cobra" + + "github.com/sipeed/picoclaw/pkg/skills" +) + +func newRemoveCommand(installerFn func() (*skills.SkillInstaller, error)) *cobra.Command { + cmd := &cobra.Command{ + Use: "remove", + Aliases: []string{"rm", "uninstall"}, + Short: "Remove installed skill", + Args: cobra.ExactArgs(1), + Example: `picoclaw skills remove weather`, + RunE: func(_ *cobra.Command, args []string) error { + installer, err := installerFn() + if err != nil { + return err + } + skillsRemoveCmd(installer, args[0]) + return nil + }, + } + + return cmd +} diff --git a/cmd/picoclaw/internal/skills/remove_test.go b/cmd/picoclaw/internal/skills/remove_test.go new file mode 100644 index 000000000..b4c79760c --- /dev/null +++ b/cmd/picoclaw/internal/skills/remove_test.go @@ -0,0 +1,29 @@ +package skills + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestNewRemoveSubcommand(t *testing.T) { + cmd := newRemoveCommand(nil) + + require.NotNil(t, cmd) + + assert.Equal(t, "remove", cmd.Use) + assert.Equal(t, "Remove installed skill", cmd.Short) + + assert.Nil(t, cmd.Run) + assert.NotNil(t, cmd.RunE) + + assert.True(t, cmd.HasExample()) + assert.False(t, cmd.HasSubCommands()) + + assert.False(t, cmd.HasFlags()) + + assert.Len(t, cmd.Aliases, 2) + assert.True(t, cmd.HasAlias("rm")) + assert.True(t, cmd.HasAlias("uninstall")) +} diff --git a/cmd/picoclaw/internal/skills/search.go b/cmd/picoclaw/internal/skills/search.go new file mode 100644 index 000000000..53bc99109 --- /dev/null +++ b/cmd/picoclaw/internal/skills/search.go @@ -0,0 +1,24 @@ +package skills + +import ( + "github.com/spf13/cobra" + + "github.com/sipeed/picoclaw/pkg/skills" +) + +func newSearchCommand(installerFn func() (*skills.SkillInstaller, error)) *cobra.Command { + cmd := &cobra.Command{ + Use: "search", + Short: "Search available skills", + RunE: func(_ *cobra.Command, _ []string) error { + installer, err := installerFn() + if err != nil { + return err + } + skillsSearchCmd(installer) + return nil + }, + } + + return cmd +} diff --git a/cmd/picoclaw/internal/skills/search_test.go b/cmd/picoclaw/internal/skills/search_test.go new file mode 100644 index 000000000..19f63a9ff --- /dev/null +++ b/cmd/picoclaw/internal/skills/search_test.go @@ -0,0 +1,25 @@ +package skills + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestNewSearchSubcommand(t *testing.T) { + cmd := newSearchCommand(nil) + + require.NotNil(t, cmd) + + assert.Equal(t, "search", cmd.Use) + assert.Equal(t, "Search available skills", cmd.Short) + + assert.Nil(t, cmd.Run) + assert.NotNil(t, cmd.RunE) + + assert.False(t, cmd.HasSubCommands()) + assert.False(t, cmd.HasFlags()) + + assert.Len(t, cmd.Aliases, 0) +} diff --git a/cmd/picoclaw/internal/skills/show.go b/cmd/picoclaw/internal/skills/show.go new file mode 100644 index 000000000..e484f3f28 --- /dev/null +++ b/cmd/picoclaw/internal/skills/show.go @@ -0,0 +1,26 @@ +package skills + +import ( + "github.com/spf13/cobra" + + "github.com/sipeed/picoclaw/pkg/skills" +) + +func newShowCommand(loaderFn func() (*skills.SkillsLoader, error)) *cobra.Command { + cmd := &cobra.Command{ + Use: "show", + Short: "Show skill details", + Args: cobra.ExactArgs(1), + Example: `picoclaw skills show weather`, + RunE: func(_ *cobra.Command, args []string) error { + loader, err := loaderFn() + if err != nil { + return err + } + skillsShowCmd(loader, args[0]) + return nil + }, + } + + return cmd +} diff --git a/cmd/picoclaw/internal/skills/show_test.go b/cmd/picoclaw/internal/skills/show_test.go new file mode 100644 index 000000000..5858d2790 --- /dev/null +++ b/cmd/picoclaw/internal/skills/show_test.go @@ -0,0 +1,27 @@ +package skills + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestNewShowSubcommand(t *testing.T) { + cmd := newShowCommand(nil) + + require.NotNil(t, cmd) + + assert.Equal(t, "show", cmd.Use) + assert.Equal(t, "Show skill details", cmd.Short) + + assert.Nil(t, cmd.Run) + assert.NotNil(t, cmd.RunE) + + assert.True(t, cmd.HasExample()) + assert.False(t, cmd.HasSubCommands()) + + assert.False(t, cmd.HasFlags()) + + assert.Len(t, cmd.Aliases, 0) +} diff --git a/cmd/picoclaw/internal/status/command.go b/cmd/picoclaw/internal/status/command.go new file mode 100644 index 000000000..9303ae2ec --- /dev/null +++ b/cmd/picoclaw/internal/status/command.go @@ -0,0 +1,18 @@ +package status + +import ( + "github.com/spf13/cobra" +) + +func NewStatusCommand() *cobra.Command { + cmd := &cobra.Command{ + Use: "status", + Aliases: []string{"s"}, + Short: "Show picoclaw status", + Run: func(cmd *cobra.Command, args []string) { + statusCmd() + }, + } + + return cmd +} diff --git a/cmd/picoclaw/internal/status/command_test.go b/cmd/picoclaw/internal/status/command_test.go new file mode 100644 index 000000000..974b4ea3d --- /dev/null +++ b/cmd/picoclaw/internal/status/command_test.go @@ -0,0 +1,29 @@ +package status + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestNewStatusCommand(t *testing.T) { + cmd := NewStatusCommand() + + require.NotNil(t, cmd) + + assert.Equal(t, "status", cmd.Use) + + assert.Len(t, cmd.Aliases, 1) + assert.True(t, cmd.HasAlias("s")) + + assert.Equal(t, "Show picoclaw status", cmd.Short) + + assert.False(t, cmd.HasSubCommands()) + + assert.NotNil(t, cmd.Run) + assert.Nil(t, cmd.RunE) + + assert.Nil(t, cmd.PersistentPreRun) + assert.Nil(t, cmd.PersistentPostRun) +} diff --git a/cmd/picoclaw/cmd_status.go b/cmd/picoclaw/internal/status/helpers.go similarity index 90% rename from cmd/picoclaw/cmd_status.go rename to cmd/picoclaw/internal/status/helpers.go index 6a117bd17..ab28f4885 100644 --- a/cmd/picoclaw/cmd_status.go +++ b/cmd/picoclaw/internal/status/helpers.go @@ -1,27 +1,25 @@ -// PicoClaw - Ultra-lightweight personal AI agent -// License: MIT - -package main +package status import ( "fmt" "os" + "github.com/sipeed/picoclaw/cmd/picoclaw/internal" "github.com/sipeed/picoclaw/pkg/auth" ) func statusCmd() { - cfg, err := loadConfig() + cfg, err := internal.LoadConfig() if err != nil { fmt.Printf("Error loading config: %v\n", err) return } - configPath := getConfigPath() + configPath := internal.GetConfigPath() - fmt.Printf("%s picoclaw Status\n", logo) - fmt.Printf("Version: %s\n", formatVersion()) - build, _ := formatBuildInfo() + fmt.Printf("%s picoclaw Status\n", internal.Logo) + fmt.Printf("Version: %s\n", internal.FormatVersion()) + build, _ := internal.FormatBuildInfo() if build != "" { fmt.Printf("Build: %s\n", build) } diff --git a/cmd/picoclaw/internal/version/command.go b/cmd/picoclaw/internal/version/command.go new file mode 100644 index 000000000..1cf686671 --- /dev/null +++ b/cmd/picoclaw/internal/version/command.go @@ -0,0 +1,33 @@ +package version + +import ( + "fmt" + + "github.com/spf13/cobra" + + "github.com/sipeed/picoclaw/cmd/picoclaw/internal" +) + +func NewVersionCommand() *cobra.Command { + cmd := &cobra.Command{ + Use: "version", + Aliases: []string{"v"}, + Short: "Show version information", + Run: func(_ *cobra.Command, _ []string) { + printVersion() + }, + } + + return cmd +} + +func printVersion() { + fmt.Printf("%s picoclaw %s\n", internal.Logo, internal.FormatVersion()) + build, goVer := internal.FormatBuildInfo() + if build != "" { + fmt.Printf(" Build: %s\n", build) + } + if goVer != "" { + fmt.Printf(" Go: %s\n", goVer) + } +} diff --git a/cmd/picoclaw/internal/version/command_test.go b/cmd/picoclaw/internal/version/command_test.go new file mode 100644 index 000000000..f08a4d1ea --- /dev/null +++ b/cmd/picoclaw/internal/version/command_test.go @@ -0,0 +1,31 @@ +package version + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestNewVersionCommand(t *testing.T) { + cmd := NewVersionCommand() + + require.NotNil(t, cmd) + + assert.Equal(t, "version", cmd.Use) + + assert.Len(t, cmd.Aliases, 1) + assert.True(t, cmd.HasAlias("v")) + + assert.False(t, cmd.HasFlags()) + + assert.Equal(t, "Show version information", cmd.Short) + + assert.False(t, cmd.HasSubCommands()) + + assert.NotNil(t, cmd.Run) + assert.Nil(t, cmd.RunE) + + assert.Nil(t, cmd.PersistentPreRun) + assert.Nil(t, cmd.PersistentPostRun) +} diff --git a/cmd/picoclaw/main.go b/cmd/picoclaw/main.go index 25ad701ca..6db69c990 100644 --- a/cmd/picoclaw/main.go +++ b/cmd/picoclaw/main.go @@ -8,192 +8,49 @@ package main import ( "fmt" - "io" "os" - "path/filepath" - "runtime" - "github.com/sipeed/picoclaw/pkg/config" - "github.com/sipeed/picoclaw/pkg/skills" + "github.com/spf13/cobra" + + "github.com/sipeed/picoclaw/cmd/picoclaw/internal" + "github.com/sipeed/picoclaw/cmd/picoclaw/internal/agent" + "github.com/sipeed/picoclaw/cmd/picoclaw/internal/auth" + "github.com/sipeed/picoclaw/cmd/picoclaw/internal/cron" + "github.com/sipeed/picoclaw/cmd/picoclaw/internal/gateway" + "github.com/sipeed/picoclaw/cmd/picoclaw/internal/migrate" + "github.com/sipeed/picoclaw/cmd/picoclaw/internal/onboard" + "github.com/sipeed/picoclaw/cmd/picoclaw/internal/skills" + "github.com/sipeed/picoclaw/cmd/picoclaw/internal/status" + "github.com/sipeed/picoclaw/cmd/picoclaw/internal/version" ) -var ( - version = "dev" - gitCommit string - buildTime string - goVersion string -) +func NewPicoclawCommand() *cobra.Command { + short := fmt.Sprintf("%s picoclaw - Personal AI Assistant v%s\n\n", internal.Logo, internal.GetVersion()) -const logo = "🦞" - -// formatVersion returns the version string with optional git commit -func formatVersion() string { - v := version - if gitCommit != "" { - v += fmt.Sprintf(" (git: %s)", gitCommit) + cmd := &cobra.Command{ + Use: "picoclaw", + Short: short, + Example: "picoclaw list", } - return v -} -// formatBuildInfo returns build time and go version info -func formatBuildInfo() (build string, goVer string) { - if buildTime != "" { - build = buildTime - } - goVer = goVersion - if goVer == "" { - goVer = runtime.Version() - } - return -} + cmd.AddCommand( + onboard.NewOnboardCommand(), + agent.NewAgentCommand(), + auth.NewAuthCommand(), + gateway.NewGatewayCommand(), + status.NewStatusCommand(), + cron.NewCronCommand(), + migrate.NewMigrateCommand(), + skills.NewSkillsCommand(), + version.NewVersionCommand(), + ) -func printVersion() { - fmt.Printf("%s picoclaw %s\n", logo, formatVersion()) - build, goVer := formatBuildInfo() - if build != "" { - fmt.Printf(" Build: %s\n", build) - } - if goVer != "" { - fmt.Printf(" Go: %s\n", goVer) - } -} - -func copyDirectory(src, dst string) error { - return filepath.Walk(src, func(path string, info os.FileInfo, err error) error { - if err != nil { - return err - } - - relPath, err := filepath.Rel(src, path) - if err != nil { - return err - } - - dstPath := filepath.Join(dst, relPath) - - if info.IsDir() { - return os.MkdirAll(dstPath, info.Mode()) - } - - srcFile, err := os.Open(path) - if err != nil { - return err - } - defer srcFile.Close() - - dstFile, err := os.OpenFile(dstPath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, info.Mode()) - if err != nil { - return err - } - defer dstFile.Close() - - _, err = io.Copy(dstFile, srcFile) - return err - }) + return cmd } func main() { - if len(os.Args) < 2 { - printHelp() - os.Exit(1) - } - - command := os.Args[1] - - switch command { - case "onboard": - onboard() - case "agent": - agentCmd() - case "gateway": - gatewayCmd() - case "status": - statusCmd() - case "migrate": - migrateCmd() - case "auth": - authCmd() - case "cron": - cronCmd() - case "skills": - if len(os.Args) < 3 { - skillsHelp() - return - } - - subcommand := os.Args[2] - - cfg, err := loadConfig() - if err != nil { - fmt.Printf("Error loading config: %v\n", err) - os.Exit(1) - } - - workspace := cfg.WorkspacePath() - installer := skills.NewSkillInstaller(workspace) - // get global config directory and builtin skills directory - globalDir := filepath.Dir(getConfigPath()) - globalSkillsDir := filepath.Join(globalDir, "skills") - builtinSkillsDir := filepath.Join(globalDir, "picoclaw", "skills") - skillsLoader := skills.NewSkillsLoader(workspace, globalSkillsDir, builtinSkillsDir) - - switch subcommand { - case "list": - skillsListCmd(skillsLoader) - case "install": - skillsInstallCmd(installer, cfg) - case "remove", "uninstall": - if len(os.Args) < 4 { - fmt.Println("Usage: picoclaw skills remove ") - return - } - skillsRemoveCmd(installer, os.Args[3]) - case "install-builtin": - skillsInstallBuiltinCmd(workspace) - case "list-builtin": - skillsListBuiltinCmd() - case "search": - skillsSearchCmd(installer) - case "show": - if len(os.Args) < 4 { - fmt.Println("Usage: picoclaw skills show ") - return - } - skillsShowCmd(skillsLoader, os.Args[3]) - default: - fmt.Printf("Unknown skills command: %s\n", subcommand) - skillsHelp() - } - case "version", "--version", "-v": - printVersion() - default: - fmt.Printf("Unknown command: %s\n", command) - printHelp() + cmd := NewPicoclawCommand() + if err := cmd.Execute(); err != nil { os.Exit(1) } } - -func printHelp() { - fmt.Printf("%s picoclaw - Personal AI Assistant v%s\n\n", logo, version) - fmt.Println("Usage: picoclaw ") - fmt.Println() - fmt.Println("Commands:") - fmt.Println(" onboard Initialize picoclaw configuration and workspace") - fmt.Println(" agent Interact with the agent directly") - fmt.Println(" auth Manage authentication (login, logout, status)") - fmt.Println(" gateway Start picoclaw gateway") - fmt.Println(" status Show picoclaw status") - fmt.Println(" cron Manage scheduled tasks") - fmt.Println(" migrate Migrate from OpenClaw to PicoClaw") - fmt.Println(" skills Manage skills (install, list, remove)") - fmt.Println(" version Show version information") -} - -func getConfigPath() string { - home, _ := os.UserHomeDir() - return filepath.Join(home, ".picoclaw", "config.json") -} - -func loadConfig() (*config.Config, error) { - return config.LoadConfig(getConfigPath()) -} diff --git a/cmd/picoclaw/main_test.go b/cmd/picoclaw/main_test.go new file mode 100644 index 000000000..3740ba358 --- /dev/null +++ b/cmd/picoclaw/main_test.go @@ -0,0 +1,56 @@ +package main + +import ( + "fmt" + "slices" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/sipeed/picoclaw/cmd/picoclaw/internal" +) + +func TestNewPicoclawCommand(t *testing.T) { + cmd := NewPicoclawCommand() + + require.NotNil(t, cmd) + + short := fmt.Sprintf("%s picoclaw - Personal AI Assistant v%s\n\n", internal.Logo, internal.GetVersion()) + + assert.Equal(t, "picoclaw", cmd.Use) + assert.Equal(t, short, cmd.Short) + + assert.True(t, cmd.HasSubCommands()) + assert.True(t, cmd.HasAvailableSubCommands()) + + assert.False(t, cmd.HasFlags()) + + assert.Nil(t, cmd.Run) + assert.Nil(t, cmd.RunE) + + assert.Nil(t, cmd.PersistentPreRun) + assert.Nil(t, cmd.PersistentPostRun) + + allowedCommands := []string{ + "agent", + "auth", + "cron", + "gateway", + "migrate", + "onboard", + "skills", + "status", + "version", + } + + subcommands := cmd.Commands() + assert.Len(t, subcommands, len(allowedCommands)) + + for _, subcmd := range subcommands { + found := slices.Contains(allowedCommands, subcmd.Name()) + assert.True(t, found, "unexpected subcommand %q", subcmd.Name()) + + assert.False(t, subcmd.Hidden) + } +} diff --git a/go.mod b/go.mod index 1f88639c8..98e20d07d 100644 --- a/go.mod +++ b/go.mod @@ -15,6 +15,7 @@ require ( github.com/open-dingtalk/dingtalk-stream-sdk-go v0.9.1 github.com/openai/openai-go/v3 v3.22.0 github.com/slack-go/slack v0.17.3 + github.com/spf13/cobra v1.10.2 github.com/stretchr/testify v1.11.1 github.com/tencent-connect/botgo v0.2.1 golang.org/x/oauth2 v0.35.0 @@ -22,7 +23,9 @@ require ( require ( github.com/davecgh/go-spew v1.1.1 // indirect + github.com/inconshreveable/mousetrap v1.1.0 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect + github.com/spf13/pflag v1.0.10 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index 0e95bf5cd..abbb11cd6 100644 --- a/go.sum +++ b/go.sum @@ -25,6 +25,7 @@ github.com/chzyer/test v1.0.0 h1:p3BQDXSxOhOG0P9z6/hGnII4LGiEPOYBhs8asl/fC04= github.com/chzyer/test v1.0.0/go.mod h1:2JlltgoNkt4TW/z9V/IzDdFaMTM2JPIi26O1pF38GC8= github.com/cloudwego/base64x v0.1.6 h1:t11wG9AECkCDk5fMSoxmufanudBtJ+/HemLstXDLI2M= github.com/cloudwego/base64x v0.1.6/go.mod h1:OFcloc187FXDaYHvrNIjxSe8ncn0OOM8gEHfghB2IPU= +github.com/cpuguy83/go-md2man/v2 v2.0.6/go.mod h1:oOW0eioCTA6cOiMLiUPZOpcVxMig6NIQQ7OS05n1F4g= github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= @@ -72,6 +73,8 @@ github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/ad github.com/grbit/go-json v0.11.0 h1:bAbyMdYrYl/OjYsSqLH99N2DyQ291mHy726Mx+sYrnc= github.com/grbit/go-json v0.11.0/go.mod h1:IYpHsdybQ386+6g3VE6AXQ3uTGa5mquBme5/ZWmtzek= github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU= +github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8= +github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw= github.com/kisielk/errcheck v1.5.0/go.mod h1:pFxgyoBC7bSaBwPgfKdkLd5X25qrDl4LWUI2bnpBCr8= github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= github.com/klauspost/compress v1.18.4 h1:RPhnKRAQ4Fh8zU2FY/6ZFDwTVTxgJ/EMydqSTzE9a2c= @@ -108,8 +111,14 @@ github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZN github.com/rogpeppe/go-internal v1.6.1/go.mod h1:xXDCJY+GAPziupqXw64V24skbSoqbTEfhy4qGm1nDQc= github.com/rogpeppe/go-internal v1.9.0 h1:73kH8U+JUqXU8lRuOHeVHaa/SZPifC7BkcraZVejAe8= github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs= +github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= github.com/slack-go/slack v0.17.3 h1:zV5qO3Q+WJAQ/XwbGfNFrRMaJ5T/naqaonyPV/1TP4g= github.com/slack-go/slack v0.17.3/go.mod h1:X+UqOufi3LYQHDnMG1vxf0J8asC6+WllXrVrhl8/Prk= +github.com/spf13/cobra v1.10.2 h1:DMTTonx5m65Ic0GOoRY2c16WCbHxOOw6xxezuLaBpcU= +github.com/spf13/cobra v1.10.2/go.mod h1:7C1pvHqHw5A4vrJfjNwvOdzYu0Gml16OCs2GRiTUUS4= +github.com/spf13/pflag v1.0.9/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= +github.com/spf13/pflag v1.0.10 h1:4EBh2KAYBwaONj6b2Ye1GiHfwjqyROoF4RwYO+vPwFk= +github.com/spf13/pflag v1.0.10/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= @@ -151,6 +160,7 @@ github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9dec github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= go.uber.org/mock v0.6.0 h1:hyF9dfmbgIX5EfOdasqLsWD6xqpNZlXblLB/Dbnwv3Y= go.uber.org/mock v0.6.0/go.mod h1:KiVJ4BqZJaMj4svdfmHM0AUx4NJYO8ZNpPnZn1Z+BBU= +go.yaml.in/yaml/v3 v3.0.4/go.mod h1:DhzuOOF2ATzADvBadXxruRBLzYTpT36CKvDb3+aBEFg= golang.org/x/arch v0.24.0 h1:qlJ3M9upxvFfwRM51tTg3Yl+8CP9vCC1E7vlFpgv99Y= golang.org/x/arch v0.24.0/go.mod h1:dNHoOeKiyja7GTvF9NJS1l3Z2yntpQNzgrjh1cU103A= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= diff --git a/pkg/agent/context.go b/pkg/agent/context.go index ba07e33d3..b7c6e1108 100644 --- a/pkg/agent/context.go +++ b/pkg/agent/context.go @@ -1,24 +1,38 @@ package agent import ( + "errors" "fmt" + "io/fs" "os" "path/filepath" "runtime" "strings" + "sync" "time" "github.com/sipeed/picoclaw/pkg/logger" "github.com/sipeed/picoclaw/pkg/providers" "github.com/sipeed/picoclaw/pkg/skills" - "github.com/sipeed/picoclaw/pkg/tools" ) type ContextBuilder struct { workspace string skillsLoader *skills.SkillsLoader memory *MemoryStore - tools *tools.ToolRegistry // Direct reference to tool registry + + // Cache for system prompt to avoid rebuilding on every call. + // This fixes issue #607: repeated reprocessing of the entire context. + // The cache auto-invalidates when workspace source files change (mtime check). + systemPromptMutex sync.RWMutex + cachedSystemPrompt string + cachedAt time.Time // max observed mtime across tracked paths at cache build time + + // existedAtCache tracks which source file paths existed the last time the + // cache was built. This lets sourceFilesChanged detect files that are newly + // created (didn't exist at cache time, now exist) or deleted (existed at + // cache time, now gone) — both of which should trigger a cache rebuild. + existedAtCache map[string]bool } func getGlobalConfigDir() string { @@ -43,69 +57,29 @@ func NewContextBuilder(workspace string) *ContextBuilder { } } -// SetToolsRegistry sets the tools registry for dynamic tool summary generation. -func (cb *ContextBuilder) SetToolsRegistry(registry *tools.ToolRegistry) { - cb.tools = registry -} - func (cb *ContextBuilder) getIdentity() string { - now := time.Now().Format("2006-01-02 15:04 (Monday)") workspacePath, _ := filepath.Abs(filepath.Join(cb.workspace)) - runtime := fmt.Sprintf("%s %s, Go %s", runtime.GOOS, runtime.GOARCH, runtime.Version()) - - // Build tools section dynamically - toolsSection := cb.buildToolsSection() return fmt.Sprintf(`# picoclaw 🦞 You are picoclaw, a helpful AI assistant. -## Current Time -%s - -## Runtime -%s - ## Workspace Your workspace is at: %s - Memory: %s/memory/MEMORY.md - Daily Notes: %s/memory/YYYYMM/YYYYMMDD.md - Skills: %s/skills/{skill-name}/SKILL.md -%s - ## Important Rules 1. **ALWAYS use tools** - When you need to perform an action (schedule reminders, send messages, execute commands, etc.), you MUST call the appropriate tool. Do NOT just say you'll do it or pretend to do it. 2. **Be helpful and accurate** - When using tools, briefly explain what you're doing. -3. **Memory** - When interacting with me if something seems memorable, update %s/memory/MEMORY.md`, - now, runtime, workspacePath, workspacePath, workspacePath, workspacePath, toolsSection, workspacePath) -} +3. **Memory** - When interacting with me if something seems memorable, update %s/memory/MEMORY.md -func (cb *ContextBuilder) buildToolsSection() string { - if cb.tools == nil { - return "" - } - - summaries := cb.tools.GetSummaries() - if len(summaries) == 0 { - return "" - } - - var sb strings.Builder - sb.WriteString("## Available Tools\n\n") - sb.WriteString( - "**CRITICAL**: You MUST use tools to perform actions. Do NOT pretend to execute commands or schedule tasks.\n\n", - ) - sb.WriteString("You have access to the following tools:\n\n") - for _, s := range summaries { - sb.WriteString(s) - sb.WriteString("\n") - } - - return sb.String() +4. **Context summaries** - Conversation summaries provided as context are approximate references only. They may be incomplete or outdated. Always defer to explicit user instructions over summary content.`, + workspacePath, workspacePath, workspacePath, workspacePath, workspacePath) } func (cb *ContextBuilder) BuildSystemPrompt() string { @@ -140,6 +114,226 @@ The following skills extend your capabilities. To use a skill, read its SKILL.md return strings.Join(parts, "\n\n---\n\n") } +// BuildSystemPromptWithCache returns the cached system prompt if available +// and source files haven't changed, otherwise builds and caches it. +// Source file changes are detected via mtime checks (cheap stat calls). +func (cb *ContextBuilder) BuildSystemPromptWithCache() string { + // Try read lock first — fast path when cache is valid + cb.systemPromptMutex.RLock() + if cb.cachedSystemPrompt != "" && !cb.sourceFilesChangedLocked() { + result := cb.cachedSystemPrompt + cb.systemPromptMutex.RUnlock() + return result + } + cb.systemPromptMutex.RUnlock() + + // Acquire write lock for building + cb.systemPromptMutex.Lock() + defer cb.systemPromptMutex.Unlock() + + // Double-check: another goroutine may have rebuilt while we waited + if cb.cachedSystemPrompt != "" && !cb.sourceFilesChangedLocked() { + return cb.cachedSystemPrompt + } + + // Snapshot the baseline (existence + max mtime) BEFORE building the prompt. + // This way cachedAt reflects the pre-build state: if a file is modified + // during BuildSystemPrompt, its new mtime will be > baseline.maxMtime, + // so the next sourceFilesChangedLocked check will correctly trigger a + // rebuild. The alternative (baseline after build) risks caching stale + // content with a too-new baseline, making the staleness invisible. + baseline := cb.buildCacheBaseline() + prompt := cb.BuildSystemPrompt() + cb.cachedSystemPrompt = prompt + cb.cachedAt = baseline.maxMtime + cb.existedAtCache = baseline.existed + + logger.DebugCF("agent", "System prompt cached", + map[string]any{ + "length": len(prompt), + }) + + return prompt +} + +// InvalidateCache clears the cached system prompt. +// Normally not needed because the cache auto-invalidates via mtime checks, +// but this is useful for tests or explicit reload commands. +func (cb *ContextBuilder) InvalidateCache() { + cb.systemPromptMutex.Lock() + defer cb.systemPromptMutex.Unlock() + + cb.cachedSystemPrompt = "" + cb.cachedAt = time.Time{} + cb.existedAtCache = nil + + logger.DebugCF("agent", "System prompt cache invalidated", nil) +} + +// sourcePaths returns the workspace source file paths tracked for cache +// invalidation (bootstrap files + memory). The skills directory is handled +// separately in sourceFilesChangedLocked because it requires both directory- +// level and recursive file-level mtime checks. +func (cb *ContextBuilder) sourcePaths() []string { + return []string{ + filepath.Join(cb.workspace, "AGENTS.md"), + filepath.Join(cb.workspace, "SOUL.md"), + filepath.Join(cb.workspace, "USER.md"), + filepath.Join(cb.workspace, "IDENTITY.md"), + filepath.Join(cb.workspace, "memory", "MEMORY.md"), + } +} + +// cacheBaseline holds the file existence snapshot and the latest observed +// mtime across all tracked paths. Used as the cache reference point. +type cacheBaseline struct { + existed map[string]bool + maxMtime time.Time +} + +// buildCacheBaseline records which tracked paths currently exist and computes +// the latest mtime across all tracked files + skills directory contents. +// Called under write lock when the cache is built. +func (cb *ContextBuilder) buildCacheBaseline() cacheBaseline { + skillsDir := filepath.Join(cb.workspace, "skills") + + // All paths whose existence we track: source files + skills dir. + allPaths := append(cb.sourcePaths(), skillsDir) + + existed := make(map[string]bool, len(allPaths)) + var maxMtime time.Time + + for _, p := range allPaths { + info, err := os.Stat(p) + existed[p] = err == nil + if err == nil && info.ModTime().After(maxMtime) { + maxMtime = info.ModTime() + } + } + + // Walk skills files to capture their mtimes too. + // Use os.Stat (not d.Info) to match the stat method used in + // fileChangedSince / skillFilesModifiedSince for consistency. + _ = filepath.WalkDir(skillsDir, func(path string, d fs.DirEntry, walkErr error) error { + if walkErr == nil && !d.IsDir() { + if info, err := os.Stat(path); err == nil && info.ModTime().After(maxMtime) { + maxMtime = info.ModTime() + } + } + return nil + }) + + // If no tracked files exist yet (empty workspace), maxMtime is zero. + // Use a very old non-zero time so that: + // 1. cachedAt.IsZero() won't trigger perpetual rebuilds. + // 2. Any real file created afterwards has mtime > cachedAt, so it + // will be detected by fileChangedSince (unlike time.Now() which + // could race with a file whose mtime <= Now). + if maxMtime.IsZero() { + maxMtime = time.Unix(1, 0) + } + + return cacheBaseline{existed: existed, maxMtime: maxMtime} +} + +// sourceFilesChangedLocked checks whether any workspace source file has been +// modified, created, or deleted since the cache was last built. +// +// IMPORTANT: The caller MUST hold at least a read lock on systemPromptMutex. +// Go's sync.RWMutex is not reentrant, so this function must NOT acquire the +// lock itself (it would deadlock when called from BuildSystemPromptWithCache +// which already holds RLock or Lock). +func (cb *ContextBuilder) sourceFilesChangedLocked() bool { + if cb.cachedAt.IsZero() { + return true + } + + // Check tracked source files (bootstrap + memory). + for _, p := range cb.sourcePaths() { + if cb.fileChangedSince(p) { + return true + } + } + + // --- Skills directory (handled separately from sourcePaths) --- + // + // 1. Creation/deletion: tracked via existedAtCache, same as bootstrap files. + skillsDir := filepath.Join(cb.workspace, "skills") + if cb.fileChangedSince(skillsDir) { + return true + } + + // 2. Structural changes (add/remove entries inside the dir) are reflected + // in the directory's own mtime, which fileChangedSince already checks. + // + // 3. Content-only edits to files inside skills/ do NOT update the parent + // directory mtime on most filesystems, so we recursively walk to check + // individual file mtimes at any nesting depth. + if skillFilesModifiedSince(skillsDir, cb.cachedAt) { + return true + } + + return false +} + +// fileChangedSince returns true if a tracked source file has been modified, +// newly created, or deleted since the cache was built. +// +// Four cases: +// - existed at cache time, exists now -> check mtime +// - existed at cache time, gone now -> changed (deleted) +// - absent at cache time, exists now -> changed (created) +// - absent at cache time, gone now -> no change +func (cb *ContextBuilder) fileChangedSince(path string) bool { + // Defensive: if existedAtCache was never initialized, treat as changed + // so the cache rebuilds rather than silently serving stale data. + if cb.existedAtCache == nil { + return true + } + + existedBefore := cb.existedAtCache[path] + info, err := os.Stat(path) + existsNow := err == nil + + if existedBefore != existsNow { + return true // file was created or deleted + } + if !existsNow { + return false // didn't exist before, doesn't exist now + } + return info.ModTime().After(cb.cachedAt) +} + +// errWalkStop is a sentinel error used to stop filepath.WalkDir early. +// Using a dedicated error (instead of fs.SkipAll) makes the early-exit +// intent explicit and avoids the nilerr linter warning that would fire +// if the callback returned nil when its err parameter is non-nil. +var errWalkStop = errors.New("walk stop") + +// skillFilesModifiedSince recursively walks the skills directory and checks +// whether any file was modified after t. This catches content-only edits at +// any nesting depth (e.g. skills/name/docs/extra.md) that don't update +// parent directory mtimes. +func skillFilesModifiedSince(skillsDir string, t time.Time) bool { + changed := false + err := filepath.WalkDir(skillsDir, func(path string, d fs.DirEntry, walkErr error) error { + if walkErr == nil && !d.IsDir() { + if info, statErr := os.Stat(path); statErr == nil && info.ModTime().After(t) { + changed = true + return errWalkStop // stop walking + } + } + return nil + }) + // errWalkStop is expected (early exit on first changed file). + // os.IsNotExist means the skills dir doesn't exist yet — not an error. + // Any other error is unexpected and worth logging. + if err != nil && !errors.Is(err, errWalkStop) && !os.IsNotExist(err) { + logger.DebugCF("agent", "skills walk error", map[string]any{"error": err.Error()}) + } + return changed +} + func (cb *ContextBuilder) LoadBootstrapFiles() string { bootstrapFiles := []string{ "AGENTS.md", @@ -159,6 +353,28 @@ func (cb *ContextBuilder) LoadBootstrapFiles() string { return sb.String() } +// buildDynamicContext returns a short dynamic context string with per-request info. +// This changes every request (time, session) so it is NOT part of the cached prompt. +// LLM-side KV cache reuse is achieved by each provider adapter's native mechanism: +// - Anthropic: per-block cache_control (ephemeral) on the static SystemParts block +// - OpenAI / Codex: prompt_cache_key for prefix-based caching +// +// See: https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching +// See: https://platform.openai.com/docs/guides/prompt-caching +func (cb *ContextBuilder) buildDynamicContext(channel, chatID string) string { + now := time.Now().Format("2006-01-02 15:04 (Monday)") + rt := fmt.Sprintf("%s %s, Go %s", runtime.GOOS, runtime.GOARCH, runtime.Version()) + + var sb strings.Builder + fmt.Fprintf(&sb, "## Current Time\n%s\n\n## Runtime\n%s", now, rt) + + if channel != "" && chatID != "" { + fmt.Fprintf(&sb, "\n\n## Current Session\nChannel: %s\nChat ID: %s", channel, chatID) + } + + return sb.String() +} + func (cb *ContextBuilder) BuildMessages( history []providers.Message, summary string, @@ -168,23 +384,65 @@ func (cb *ContextBuilder) BuildMessages( ) []providers.Message { messages := []providers.Message{} - systemPrompt := cb.BuildSystemPrompt() + // The static part (identity, bootstrap, skills, memory) is cached locally to + // avoid repeated file I/O and string building on every call (fixes issue #607). + // Dynamic parts (time, session, summary) are appended per request. + // Everything is sent as a single system message for provider compatibility: + // - Anthropic adapter extracts messages[0] (Role=="system") and maps its content + // to the top-level "system" parameter in the Messages API request. A single + // contiguous system block makes this extraction straightforward. + // - Codex maps only the first system message to its instructions field. + // - OpenAI-compat passes messages through as-is. + staticPrompt := cb.BuildSystemPromptWithCache() - // Add Current Session info if provided - if channel != "" && chatID != "" { - systemPrompt += fmt.Sprintf("\n\n## Current Session\nChannel: %s\nChat ID: %s", channel, chatID) + // Build short dynamic context (time, runtime, session) — changes per request + dynamicCtx := cb.buildDynamicContext(channel, chatID) + + // Compose a single system message: static (cached) + dynamic + optional summary. + // Keeping all system content in one message ensures every provider adapter can + // extract it correctly (Anthropic adapter -> top-level system param, + // Codex -> instructions field). + // + // SystemParts carries the same content as structured blocks so that + // cache-aware adapters (Anthropic) can set per-block cache_control. + // The static block is marked "ephemeral" — its prefix hash is stable + // across requests, enabling LLM-side KV cache reuse. + stringParts := []string{staticPrompt, dynamicCtx} + + contentBlocks := []providers.ContentBlock{ + {Type: "text", Text: staticPrompt, CacheControl: &providers.CacheControl{Type: "ephemeral"}}, + {Type: "text", Text: dynamicCtx}, } - // Log system prompt summary for debugging (debug mode only) + if summary != "" { + summaryText := fmt.Sprintf( + "CONTEXT_SUMMARY: The following is an approximate summary of prior conversation "+ + "for reference only. It may be incomplete or outdated — always defer to explicit instructions.\n\n%s", + summary) + stringParts = append(stringParts, summaryText) + contentBlocks = append(contentBlocks, providers.ContentBlock{Type: "text", Text: summaryText}) + } + + fullSystemPrompt := strings.Join(stringParts, "\n\n---\n\n") + + // Log system prompt summary for debugging (debug mode only). + // Read cachedSystemPrompt under lock to avoid a data race with + // concurrent InvalidateCache / BuildSystemPromptWithCache writes. + cb.systemPromptMutex.RLock() + isCached := cb.cachedSystemPrompt != "" + cb.systemPromptMutex.RUnlock() + logger.DebugCF("agent", "System prompt built", map[string]any{ - "total_chars": len(systemPrompt), - "total_lines": strings.Count(systemPrompt, "\n") + 1, - "section_count": strings.Count(systemPrompt, "\n\n---\n\n") + 1, + "static_chars": len(staticPrompt), + "dynamic_chars": len(dynamicCtx), + "total_chars": len(fullSystemPrompt), + "has_summary": summary != "", + "cached": isCached, }) // Log preview of system prompt (avoid logging huge content) - preview := systemPrompt + preview := fullSystemPrompt if len(preview) > 500 { preview = preview[:500] + "... (truncated)" } @@ -193,19 +451,21 @@ func (cb *ContextBuilder) BuildMessages( "preview": preview, }) - if summary != "" { - systemPrompt += "\n\n## Summary of Previous Conversation\n\n" + summary - } - history = sanitizeHistoryForProvider(history) + // Single system message containing all context — compatible with all providers. + // SystemParts enables cache-aware adapters to set per-block cache_control; + // Content is the concatenated fallback for adapters that don't read SystemParts. messages = append(messages, providers.Message{ - Role: "system", - Content: systemPrompt, + Role: "system", + Content: fullSystemPrompt, + SystemParts: contentBlocks, }) + // Add conversation history messages = append(messages, history...) + // Add current user message if strings.TrimSpace(currentMessage) != "" { messages = append(messages, providers.Message{ Role: "user", @@ -224,13 +484,32 @@ func sanitizeHistoryForProvider(history []providers.Message) []providers.Message sanitized := make([]providers.Message, 0, len(history)) for _, msg := range history { switch msg.Role { + case "system": + // Drop system messages from history. BuildMessages always + // constructs its own single system message (static + dynamic + + // summary); extra system messages would break providers that + // only accept one (Anthropic, Codex). + logger.DebugCF("agent", "Dropping system message from history", map[string]any{}) + continue + case "tool": if len(sanitized) == 0 { logger.DebugCF("agent", "Dropping orphaned leading tool message", map[string]any{}) continue } - last := sanitized[len(sanitized)-1] - if last.Role != "assistant" || len(last.ToolCalls) == 0 { + // Walk backwards to find the nearest assistant message, + // skipping over any preceding tool messages (multi-tool-call case). + foundAssistant := false + for i := len(sanitized) - 1; i >= 0; i-- { + if sanitized[i].Role == "tool" { + continue + } + if sanitized[i].Role == "assistant" && len(sanitized[i].ToolCalls) > 0 { + foundAssistant = true + } + break + } + if !foundAssistant { logger.DebugCF("agent", "Dropping orphaned tool message", map[string]any{}) continue } diff --git a/pkg/agent/context_cache_test.go b/pkg/agent/context_cache_test.go new file mode 100644 index 000000000..ba70d4c0d --- /dev/null +++ b/pkg/agent/context_cache_test.go @@ -0,0 +1,513 @@ +package agent + +import ( + "os" + "path/filepath" + "strings" + "sync" + "testing" + "time" + + "github.com/sipeed/picoclaw/pkg/providers" +) + +// setupWorkspace creates a temporary workspace with standard directories and optional files. +// Returns the tmpDir path; caller should defer os.RemoveAll(tmpDir). +func setupWorkspace(t *testing.T, files map[string]string) string { + t.Helper() + tmpDir, err := os.MkdirTemp("", "picoclaw-test-*") + if err != nil { + t.Fatal(err) + } + os.MkdirAll(filepath.Join(tmpDir, "memory"), 0o755) + os.MkdirAll(filepath.Join(tmpDir, "skills"), 0o755) + for name, content := range files { + dir := filepath.Dir(filepath.Join(tmpDir, name)) + os.MkdirAll(dir, 0o755) + if err := os.WriteFile(filepath.Join(tmpDir, name), []byte(content), 0o644); err != nil { + t.Fatal(err) + } + } + return tmpDir +} + +// TestSingleSystemMessage verifies that BuildMessages always produces exactly one +// system message regardless of summary/history variations. +// Fix: multiple system messages break Anthropic (top-level system param) and +// Codex (only reads last system message as instructions). +func TestSingleSystemMessage(t *testing.T) { + tmpDir := setupWorkspace(t, map[string]string{ + "IDENTITY.md": "# Identity\nTest agent.", + }) + defer os.RemoveAll(tmpDir) + + cb := NewContextBuilder(tmpDir) + + tests := []struct { + name string + history []providers.Message + summary string + message string + }{ + { + name: "no summary, no history", + summary: "", + message: "hello", + }, + { + name: "with summary", + summary: "Previous conversation discussed X", + message: "hello", + }, + { + name: "with history and summary", + history: []providers.Message{ + {Role: "user", Content: "hi"}, + {Role: "assistant", Content: "hello"}, + }, + summary: strings.Repeat("Long summary text. ", 50), + message: "new message", + }, + { + name: "system message in history is filtered", + history: []providers.Message{ + {Role: "system", Content: "stale system prompt from previous session"}, + {Role: "user", Content: "hi"}, + {Role: "assistant", Content: "hello"}, + }, + summary: "", + message: "new message", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + msgs := cb.BuildMessages(tt.history, tt.summary, tt.message, nil, "test", "chat1") + + systemCount := 0 + for _, m := range msgs { + if m.Role == "system" { + systemCount++ + } + } + if systemCount != 1 { + t.Errorf("expected exactly 1 system message, got %d", systemCount) + } + if msgs[0].Role != "system" { + t.Errorf("first message should be system, got %s", msgs[0].Role) + } + if msgs[len(msgs)-1].Role != "user" { + t.Errorf("last message should be user, got %s", msgs[len(msgs)-1].Role) + } + + // System message must contain identity (static) and time (dynamic) + sys := msgs[0].Content + if !strings.Contains(sys, "picoclaw") { + t.Error("system message missing identity") + } + if !strings.Contains(sys, "Current Time") { + t.Error("system message missing dynamic time context") + } + + // Summary handling + if tt.summary != "" { + if !strings.Contains(sys, "CONTEXT_SUMMARY:") { + t.Error("summary present but CONTEXT_SUMMARY prefix missing") + } + if !strings.Contains(sys, tt.summary[:20]) { + t.Error("summary content not found in system message") + } + } else { + if strings.Contains(sys, "CONTEXT_SUMMARY:") { + t.Error("CONTEXT_SUMMARY should not appear without summary") + } + } + }) + } +} + +// TestMtimeAutoInvalidation verifies that the cache detects source file changes +// via mtime without requiring explicit InvalidateCache(). +// Fix: original implementation had no auto-invalidation — edits to bootstrap files, +// memory, or skills were invisible until process restart. +func TestMtimeAutoInvalidation(t *testing.T) { + tests := []struct { + name string + file string // relative path inside workspace + contentV1 string + contentV2 string + checkField string // substring to verify in rebuilt prompt + }{ + { + name: "bootstrap file change", + file: "IDENTITY.md", + contentV1: "# Original Identity", + contentV2: "# Updated Identity", + checkField: "Updated Identity", + }, + { + name: "memory file change", + file: "memory/MEMORY.md", + contentV1: "# Memory\nUser likes Go.", + contentV2: "# Memory\nUser likes Rust.", + checkField: "User likes Rust", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tmpDir := setupWorkspace(t, map[string]string{tt.file: tt.contentV1}) + defer os.RemoveAll(tmpDir) + + cb := NewContextBuilder(tmpDir) + + sp1 := cb.BuildSystemPromptWithCache() + + // Overwrite file and set future mtime to ensure detection. + // Use 2s offset for filesystem mtime resolution safety (some FS + // have 1s or coarser granularity, especially in CI containers). + fullPath := filepath.Join(tmpDir, tt.file) + os.WriteFile(fullPath, []byte(tt.contentV2), 0o644) + future := time.Now().Add(2 * time.Second) + os.Chtimes(fullPath, future, future) + + // Verify sourceFilesChangedLocked detects the mtime change + cb.systemPromptMutex.RLock() + changed := cb.sourceFilesChangedLocked() + cb.systemPromptMutex.RUnlock() + if !changed { + t.Fatalf("sourceFilesChangedLocked() should detect %s change", tt.file) + } + + // Should auto-rebuild without explicit InvalidateCache() + sp2 := cb.BuildSystemPromptWithCache() + if sp1 == sp2 { + t.Errorf("cache not rebuilt after %s change", tt.file) + } + if !strings.Contains(sp2, tt.checkField) { + t.Errorf("rebuilt prompt missing expected content %q", tt.checkField) + } + }) + } + + // Skills directory mtime change + t.Run("skills dir change", func(t *testing.T) { + tmpDir := setupWorkspace(t, nil) + defer os.RemoveAll(tmpDir) + + cb := NewContextBuilder(tmpDir) + _ = cb.BuildSystemPromptWithCache() // populate cache + + // Touch skills directory (simulate new skill installed) + skillsDir := filepath.Join(tmpDir, "skills") + future := time.Now().Add(2 * time.Second) + os.Chtimes(skillsDir, future, future) + + // Verify sourceFilesChangedLocked detects it (cache is rebuilt) + // We confirm by checking internal state: a second call should rebuild. + cb.systemPromptMutex.RLock() + changed := cb.sourceFilesChangedLocked() + cb.systemPromptMutex.RUnlock() + if !changed { + t.Error("sourceFilesChangedLocked() should detect skills dir mtime change") + } + }) +} + +// TestExplicitInvalidateCache verifies that InvalidateCache() forces a rebuild +// even when source files haven't changed (useful for tests and reload commands). +func TestExplicitInvalidateCache(t *testing.T) { + tmpDir := setupWorkspace(t, map[string]string{ + "IDENTITY.md": "# Test Identity", + }) + defer os.RemoveAll(tmpDir) + + cb := NewContextBuilder(tmpDir) + + sp1 := cb.BuildSystemPromptWithCache() + cb.InvalidateCache() + sp2 := cb.BuildSystemPromptWithCache() + + if sp1 != sp2 { + t.Error("prompt should be identical after invalidate+rebuild when files unchanged") + } + + // Verify cachedAt was reset + cb.InvalidateCache() + cb.systemPromptMutex.RLock() + if !cb.cachedAt.IsZero() { + t.Error("cachedAt should be zero after InvalidateCache()") + } + cb.systemPromptMutex.RUnlock() +} + +// TestCacheStability verifies that the static prompt is stable across repeated calls +// when no files change (regression test for issue #607). +func TestCacheStability(t *testing.T) { + tmpDir := setupWorkspace(t, map[string]string{ + "IDENTITY.md": "# Identity\nContent", + "SOUL.md": "# Soul\nContent", + }) + defer os.RemoveAll(tmpDir) + + cb := NewContextBuilder(tmpDir) + + results := make([]string, 5) + for i := range results { + results[i] = cb.BuildSystemPromptWithCache() + } + for i := 1; i < len(results); i++ { + if results[i] != results[0] { + t.Errorf("cached prompt changed between call 0 and %d", i) + } + } + + // Static prompt must NOT contain per-request data + if strings.Contains(results[0], "Current Time") { + t.Error("static cached prompt should not contain time (added dynamically)") + } +} + +// TestNewFileCreationInvalidatesCache verifies that creating a source file that +// did not exist when the cache was built triggers a cache rebuild. +// This catches the "from nothing to something" edge case that the old +// modifiedSince (return false on stat error) would miss. +func TestNewFileCreationInvalidatesCache(t *testing.T) { + tests := []struct { + name string + file string // relative path inside workspace + content string + checkField string // substring to verify in rebuilt prompt + }{ + { + name: "new bootstrap file", + file: "SOUL.md", + content: "# Soul\nBe kind and helpful.", + checkField: "Be kind and helpful", + }, + { + name: "new memory file", + file: "memory/MEMORY.md", + content: "# Memory\nUser prefers dark mode.", + checkField: "User prefers dark mode", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Start with an empty workspace (no bootstrap/memory files) + tmpDir := setupWorkspace(t, nil) + defer os.RemoveAll(tmpDir) + + cb := NewContextBuilder(tmpDir) + + // Populate cache — file does not exist yet + sp1 := cb.BuildSystemPromptWithCache() + if strings.Contains(sp1, tt.checkField) { + t.Fatalf("prompt should not contain %q before file is created", tt.checkField) + } + + // Create the file after cache was built + fullPath := filepath.Join(tmpDir, tt.file) + os.MkdirAll(filepath.Dir(fullPath), 0o755) + if err := os.WriteFile(fullPath, []byte(tt.content), 0o644); err != nil { + t.Fatal(err) + } + // Set future mtime to guarantee detection + future := time.Now().Add(2 * time.Second) + os.Chtimes(fullPath, future, future) + + // Cache should auto-invalidate because file went from absent -> present + sp2 := cb.BuildSystemPromptWithCache() + if !strings.Contains(sp2, tt.checkField) { + t.Errorf("cache not invalidated on new file creation: expected %q in prompt", tt.checkField) + } + }) + } +} + +// TestSkillFileContentChange verifies that modifying a skill file's content +// (not just the directory structure) invalidates the cache. +// This is the scenario where directory mtime alone is insufficient — on most +// filesystems, editing a file inside a directory does NOT update the parent +// directory's mtime. +func TestSkillFileContentChange(t *testing.T) { + skillMD := `--- +name: test-skill +description: "A test skill" +--- +# Test Skill v1 +Original content.` + + tmpDir := setupWorkspace(t, map[string]string{ + "skills/test-skill/SKILL.md": skillMD, + }) + defer os.RemoveAll(tmpDir) + + cb := NewContextBuilder(tmpDir) + + // Populate cache + sp1 := cb.BuildSystemPromptWithCache() + _ = sp1 // cache is warm + + // Modify the skill file content (without touching the skills/ directory) + updatedSkillMD := `--- +name: test-skill +description: "An updated test skill" +--- +# Test Skill v2 +Updated content.` + + skillPath := filepath.Join(tmpDir, "skills", "test-skill", "SKILL.md") + if err := os.WriteFile(skillPath, []byte(updatedSkillMD), 0o644); err != nil { + t.Fatal(err) + } + // Set future mtime on the skill file only (NOT the directory) + future := time.Now().Add(2 * time.Second) + os.Chtimes(skillPath, future, future) + + // Verify that sourceFilesChangedLocked detects the content change + cb.systemPromptMutex.RLock() + changed := cb.sourceFilesChangedLocked() + cb.systemPromptMutex.RUnlock() + if !changed { + t.Error("sourceFilesChangedLocked() should detect skill file content change") + } + + // Verify cache is actually rebuilt with new content + sp2 := cb.BuildSystemPromptWithCache() + if sp1 == sp2 && strings.Contains(sp1, "test-skill") { + // If the skill appeared in the prompt and the prompt didn't change, + // the cache was not invalidated. + t.Error("cache should be invalidated when skill file content changes") + } +} + +// TestConcurrentBuildSystemPromptWithCache verifies that multiple goroutines +// can safely call BuildSystemPromptWithCache concurrently without producing +// empty results, panics, or data races. +// Run with: go test -race ./pkg/agent/ -run TestConcurrentBuildSystemPromptWithCache +func TestConcurrentBuildSystemPromptWithCache(t *testing.T) { + tmpDir := setupWorkspace(t, map[string]string{ + "IDENTITY.md": "# Identity\nConcurrency test agent.", + "SOUL.md": "# Soul\nBe helpful.", + "memory/MEMORY.md": "# Memory\nUser prefers Go.", + "skills/demo/SKILL.md": "---\nname: demo\ndescription: \"demo skill\"\n---\n# Demo", + }) + defer os.RemoveAll(tmpDir) + + cb := NewContextBuilder(tmpDir) + + const goroutines = 20 + const iterations = 50 + + var wg sync.WaitGroup + errs := make(chan string, goroutines*iterations) + + for g := 0; g < goroutines; g++ { + wg.Add(1) + go func(id int) { + defer wg.Done() + for i := 0; i < iterations; i++ { + result := cb.BuildSystemPromptWithCache() + if result == "" { + errs <- "empty prompt returned" + return + } + if !strings.Contains(result, "picoclaw") { + errs <- "prompt missing identity" + return + } + + // Also exercise BuildMessages concurrently + msgs := cb.BuildMessages(nil, "", "hello", nil, "test", "chat") + if len(msgs) < 2 { + errs <- "BuildMessages returned fewer than 2 messages" + return + } + if msgs[0].Role != "system" { + errs <- "first message not system" + return + } + + // Occasionally invalidate to exercise the write path + if i%10 == 0 { + cb.InvalidateCache() + } + } + }(g) + } + + wg.Wait() + close(errs) + + for errMsg := range errs { + t.Errorf("concurrent access error: %s", errMsg) + } +} + +// BenchmarkBuildMessagesWithCache measures caching performance. + +// TestEmptyWorkspaceBaselineDetectsNewFiles verifies that when the cache is +// built on an empty workspace (no tracked files exist), creating a file +// afterwards still triggers cache invalidation. This validates the +// time.Unix(1, 0) fallback for maxMtime: any real file's mtime is after epoch, +// so fileChangedSince correctly detects the absent -> present transition AND +// the mtime comparison succeeds even without artificially inflated Chtimes. +func TestEmptyWorkspaceBaselineDetectsNewFiles(t *testing.T) { + // Empty workspace: no bootstrap files, no memory, no skills content. + tmpDir := setupWorkspace(t, nil) + defer os.RemoveAll(tmpDir) + + cb := NewContextBuilder(tmpDir) + + // Build cache — all tracked files are absent, maxMtime falls back to epoch. + sp1 := cb.BuildSystemPromptWithCache() + + // Create a bootstrap file with natural mtime (no Chtimes manipulation). + // The file's mtime should be the current wall-clock time, which is + // strictly after time.Unix(1, 0). + soulPath := filepath.Join(tmpDir, "SOUL.md") + if err := os.WriteFile(soulPath, []byte("# Soul\nNewly created."), 0o644); err != nil { + t.Fatal(err) + } + + // Cache should detect the new file via existedAtCache (absent -> present). + cb.systemPromptMutex.RLock() + changed := cb.sourceFilesChangedLocked() + cb.systemPromptMutex.RUnlock() + if !changed { + t.Fatal("sourceFilesChangedLocked should detect newly created file on empty workspace") + } + + sp2 := cb.BuildSystemPromptWithCache() + if !strings.Contains(sp2, "Newly created") { + t.Error("rebuilt prompt should contain new file content") + } + if sp1 == sp2 { + t.Error("cache should have been invalidated after file creation") + } +} + +// BenchmarkBuildMessagesWithCache measures caching performance. +func BenchmarkBuildMessagesWithCache(b *testing.B) { + tmpDir, _ := os.MkdirTemp("", "picoclaw-bench-*") + defer os.RemoveAll(tmpDir) + + os.MkdirAll(filepath.Join(tmpDir, "memory"), 0o755) + os.MkdirAll(filepath.Join(tmpDir, "skills"), 0o755) + for _, name := range []string{"IDENTITY.md", "SOUL.md", "USER.md"} { + os.WriteFile(filepath.Join(tmpDir, name), []byte(strings.Repeat("Content.\n", 10)), 0o644) + } + + cb := NewContextBuilder(tmpDir) + history := []providers.Message{ + {Role: "user", Content: "previous message"}, + {Role: "assistant", Content: "previous response"}, + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = cb.BuildMessages(history, "summary", "new message", nil, "cli", "test") + } +} diff --git a/pkg/agent/context_test.go b/pkg/agent/context_test.go new file mode 100644 index 000000000..e023c9c30 --- /dev/null +++ b/pkg/agent/context_test.go @@ -0,0 +1,209 @@ +package agent + +import ( + "testing" + + "github.com/sipeed/picoclaw/pkg/providers" +) + +func msg(role, content string) providers.Message { + return providers.Message{Role: role, Content: content} +} + +func assistantWithTools(toolIDs ...string) providers.Message { + calls := make([]providers.ToolCall, len(toolIDs)) + for i, id := range toolIDs { + calls[i] = providers.ToolCall{ID: id, Type: "function"} + } + return providers.Message{Role: "assistant", ToolCalls: calls} +} + +func toolResult(id string) providers.Message { + return providers.Message{Role: "tool", Content: "result", ToolCallID: id} +} + +func TestSanitizeHistoryForProvider_EmptyHistory(t *testing.T) { + result := sanitizeHistoryForProvider(nil) + if len(result) != 0 { + t.Fatalf("expected empty, got %d messages", len(result)) + } + + result = sanitizeHistoryForProvider([]providers.Message{}) + if len(result) != 0 { + t.Fatalf("expected empty, got %d messages", len(result)) + } +} + +func TestSanitizeHistoryForProvider_SingleToolCall(t *testing.T) { + history := []providers.Message{ + msg("user", "hello"), + assistantWithTools("A"), + toolResult("A"), + msg("assistant", "done"), + } + + result := sanitizeHistoryForProvider(history) + if len(result) != 4 { + t.Fatalf("expected 4 messages, got %d", len(result)) + } + assertRoles(t, result, "user", "assistant", "tool", "assistant") +} + +func TestSanitizeHistoryForProvider_MultiToolCalls(t *testing.T) { + history := []providers.Message{ + msg("user", "do two things"), + assistantWithTools("A", "B"), + toolResult("A"), + toolResult("B"), + msg("assistant", "both done"), + } + + result := sanitizeHistoryForProvider(history) + if len(result) != 5 { + t.Fatalf("expected 5 messages, got %d: %+v", len(result), roles(result)) + } + assertRoles(t, result, "user", "assistant", "tool", "tool", "assistant") +} + +func TestSanitizeHistoryForProvider_AssistantToolCallAfterPlainAssistant(t *testing.T) { + history := []providers.Message{ + msg("user", "hi"), + msg("assistant", "thinking"), + assistantWithTools("A"), + toolResult("A"), + } + + result := sanitizeHistoryForProvider(history) + if len(result) != 2 { + t.Fatalf("expected 2 messages, got %d: %+v", len(result), roles(result)) + } + assertRoles(t, result, "user", "assistant") +} + +func TestSanitizeHistoryForProvider_OrphanedLeadingTool(t *testing.T) { + history := []providers.Message{ + toolResult("A"), + msg("user", "hello"), + } + + result := sanitizeHistoryForProvider(history) + if len(result) != 1 { + t.Fatalf("expected 1 message, got %d: %+v", len(result), roles(result)) + } + assertRoles(t, result, "user") +} + +func TestSanitizeHistoryForProvider_ToolAfterUserDropped(t *testing.T) { + history := []providers.Message{ + msg("user", "hello"), + toolResult("A"), + } + + result := sanitizeHistoryForProvider(history) + if len(result) != 1 { + t.Fatalf("expected 1 message, got %d: %+v", len(result), roles(result)) + } + assertRoles(t, result, "user") +} + +func TestSanitizeHistoryForProvider_ToolAfterAssistantNoToolCalls(t *testing.T) { + history := []providers.Message{ + msg("user", "hello"), + msg("assistant", "hi"), + toolResult("A"), + } + + result := sanitizeHistoryForProvider(history) + if len(result) != 2 { + t.Fatalf("expected 2 messages, got %d: %+v", len(result), roles(result)) + } + assertRoles(t, result, "user", "assistant") +} + +func TestSanitizeHistoryForProvider_AssistantToolCallAtStart(t *testing.T) { + history := []providers.Message{ + assistantWithTools("A"), + toolResult("A"), + msg("user", "hello"), + } + + result := sanitizeHistoryForProvider(history) + if len(result) != 1 { + t.Fatalf("expected 1 message, got %d: %+v", len(result), roles(result)) + } + assertRoles(t, result, "user") +} + +func TestSanitizeHistoryForProvider_MultiToolCallsThenNewRound(t *testing.T) { + history := []providers.Message{ + msg("user", "do two things"), + assistantWithTools("A", "B"), + toolResult("A"), + toolResult("B"), + msg("assistant", "done"), + msg("user", "hi"), + assistantWithTools("C"), + toolResult("C"), + msg("assistant", "done again"), + } + + result := sanitizeHistoryForProvider(history) + if len(result) != 9 { + t.Fatalf("expected 9 messages, got %d: %+v", len(result), roles(result)) + } + assertRoles(t, result, "user", "assistant", "tool", "tool", "assistant", "user", "assistant", "tool", "assistant") +} + +func TestSanitizeHistoryForProvider_ConsecutiveMultiToolRounds(t *testing.T) { + history := []providers.Message{ + msg("user", "start"), + assistantWithTools("A", "B"), + toolResult("A"), + toolResult("B"), + assistantWithTools("C", "D"), + toolResult("C"), + toolResult("D"), + msg("assistant", "all done"), + } + + result := sanitizeHistoryForProvider(history) + if len(result) != 8 { + t.Fatalf("expected 8 messages, got %d: %+v", len(result), roles(result)) + } + assertRoles(t, result, "user", "assistant", "tool", "tool", "assistant", "tool", "tool", "assistant") +} + +func TestSanitizeHistoryForProvider_PlainConversation(t *testing.T) { + history := []providers.Message{ + msg("user", "hello"), + msg("assistant", "hi"), + msg("user", "how are you"), + msg("assistant", "fine"), + } + + result := sanitizeHistoryForProvider(history) + if len(result) != 4 { + t.Fatalf("expected 4 messages, got %d", len(result)) + } + assertRoles(t, result, "user", "assistant", "user", "assistant") +} + +func roles(msgs []providers.Message) []string { + r := make([]string, len(msgs)) + for i, m := range msgs { + r[i] = m.Role + } + return r +} + +func assertRoles(t *testing.T, msgs []providers.Message, expected ...string) { + t.Helper() + if len(msgs) != len(expected) { + t.Fatalf("role count mismatch: got %v, want %v", roles(msgs), expected) + } + for i, exp := range expected { + if msgs[i].Role != exp { + t.Errorf("message[%d]: got role %q, want %q", i, msgs[i].Role, exp) + } + } +} diff --git a/pkg/agent/instance.go b/pkg/agent/instance.go index c6a54c7d2..a6fd365c7 100644 --- a/pkg/agent/instance.go +++ b/pkg/agent/instance.go @@ -59,7 +59,6 @@ func NewAgentInstance( sessionsManager := session.NewSessionManager(sessionsDir) contextBuilder := NewContextBuilder(workspace) - contextBuilder.SetToolsRegistry(toolsRegistry) agentID := routing.DefaultAgentID agentName := "" diff --git a/pkg/agent/loop.go b/pkg/agent/loop.go index dbc4a9b87..5558f7c0e 100644 --- a/pkg/agent/loop.go +++ b/pkg/agent/loop.go @@ -149,9 +149,6 @@ func registerSharedTools( return registry.CanSpawnSubagent(currentAgentID, targetAgentID) }) agent.Tools.Register(spawnTool) - - // Update context builder with the complete tools registry - agent.ContextBuilder.SetToolsRegistry(agent.Tools) } } @@ -524,8 +521,9 @@ func (al *AgentLoop) runLLMIteration( fbResult, fbErr := al.fallback.Execute(ctx, agent.Candidates, func(ctx context.Context, provider, model string) (*providers.LLMResponse, error) { return agent.Provider.Chat(ctx, messages, providerToolDefs, model, map[string]any{ - "max_tokens": agent.MaxTokens, - "temperature": agent.Temperature, + "max_tokens": agent.MaxTokens, + "temperature": agent.Temperature, + "prompt_cache_key": agent.ID, }) }, ) @@ -540,8 +538,9 @@ func (al *AgentLoop) runLLMIteration( return fbResult.Response, nil } return agent.Provider.Chat(ctx, messages, providerToolDefs, agent.Model, map[string]any{ - "max_tokens": agent.MaxTokens, - "temperature": agent.Temperature, + "max_tokens": agent.MaxTokens, + "temperature": agent.Temperature, + "prompt_cache_key": agent.ID, }) } @@ -800,7 +799,7 @@ func (al *AgentLoop) forceCompression(agent *AgentInstance, sessionKey string) { droppedCount := mid keptConversation := conversation[mid:] - newHistory := make([]providers.Message, 0) + newHistory := make([]providers.Message, 0, 1+len(keptConversation)+1) // Append compression note to the original system prompt instead of adding a new system message // This avoids having two consecutive system messages which some APIs (like Zhipu) reject @@ -962,8 +961,9 @@ func (al *AgentLoop) summarizeSession(agent *AgentInstance, sessionKey string) { nil, agent.Model, map[string]any{ - "max_tokens": 1024, - "temperature": 0.3, + "max_tokens": 1024, + "temperature": 0.3, + "prompt_cache_key": agent.ID, }, ) if err == nil { @@ -1012,8 +1012,9 @@ func (al *AgentLoop) summarizeBatch( nil, agent.Model, map[string]any{ - "max_tokens": 1024, - "temperature": 0.3, + "max_tokens": 1024, + "temperature": 0.3, + "prompt_cache_key": agent.ID, }, ) if err != nil { diff --git a/pkg/auth/oauth.go b/pkg/auth/oauth.go index cf8c1c9c4..ba757ffd4 100644 --- a/pkg/auth/oauth.go +++ b/pkg/auth/oauth.go @@ -156,7 +156,7 @@ func LoginBrowser(cfg OAuthProviderConfig) (*AuthCredential, error) { return exchangeCodeForTokens(cfg, result.code, pkce.CodeVerifier, redirectURI) case manualInput := <-manualCh: if manualInput == "" { - return nil, fmt.Errorf("manual input cancelled") + return nil, fmt.Errorf("manual input canceled") } // Extract code from URL if it's a full URL code := manualInput diff --git a/pkg/channels/discord.go b/pkg/channels/discord.go index 20f3b267c..f6faa3373 100644 --- a/pkg/channels/discord.go +++ b/pkg/channels/discord.go @@ -233,7 +233,7 @@ func (c *DiscordChannel) handleMessage(s *discordgo.Session, m *discordgo.Messag if localPath != "" { localFiles = append(localFiles, localPath) - transcribedText := "" + var transcribedText string if c.transcriber != nil && c.transcriber.IsAvailable() { ctx, cancel := context.WithTimeout(c.getContext(), transcriptionTimeout) result, err := c.transcriber.Transcribe(ctx, localPath) diff --git a/pkg/channels/onebot.go b/pkg/channels/onebot.go index cee8ad9d3..4576a11ce 100644 --- a/pkg/channels/onebot.go +++ b/pkg/channels/onebot.go @@ -174,7 +174,10 @@ func (c *OneBotChannel) connect() error { header["Authorization"] = []string{"Bearer " + c.config.AccessToken} } - conn, _, err := dialer.Dial(c.config.WSUrl, header) + conn, resp, err := dialer.Dial(c.config.WSUrl, header) + if resp != nil { + resp.Body.Close() + } if err != nil { return err } @@ -310,7 +313,7 @@ func (c *OneBotChannel) sendAPIRequest(action string, params any, timeout time.D case <-time.After(timeout): return nil, fmt.Errorf("API request %s timed out after %v", action, timeout) case <-c.ctx.Done(): - return nil, fmt.Errorf("context cancelled") + return nil, fmt.Errorf("context canceled") } } @@ -695,7 +698,6 @@ func (c *OneBotChannel) parseMessageSegments(raw json.RawMessage, selfID int64) textParts = append(textParts, "[forward message]") default: - } } diff --git a/pkg/channels/slack.go b/pkg/channels/slack.go index f087aa8da..cfb731b16 100644 --- a/pkg/channels/slack.go +++ b/pkg/channels/slack.go @@ -439,5 +439,5 @@ func parseSlackChatID(chatID string) (channelID, threadTS string) { if len(parts) > 1 { threadTS = parts[1] } - return + return channelID, threadTS } diff --git a/pkg/channels/telegram.go b/pkg/channels/telegram.go index 5cd51e8bc..524494849 100644 --- a/pkg/channels/telegram.go +++ b/pkg/channels/telegram.go @@ -265,7 +265,7 @@ func (c *TelegramChannel) handleMessage(ctx context.Context, message *telego.Mes localFiles = append(localFiles, voicePath) mediaPaths = append(mediaPaths, voicePath) - transcribedText := "" + var transcribedText string if c.transcriber != nil && c.transcriber.IsAvailable() { transcriberCtx, cancel := context.WithTimeout(ctx, 30*time.Second) defer cancel() diff --git a/pkg/channels/whatsapp.go b/pkg/channels/whatsapp.go index 958d850bb..2dc4017ac 100644 --- a/pkg/channels/whatsapp.go +++ b/pkg/channels/whatsapp.go @@ -41,7 +41,10 @@ func (c *WhatsAppChannel) Start(ctx context.Context) error { dialer := websocket.DefaultDialer dialer.HandshakeTimeout = 10 * time.Second - conn, _, err := dialer.Dial(c.url, nil) + conn, resp, err := dialer.Dial(c.url, nil) + if resp != nil { + resp.Body.Close() + } if err != nil { return fmt.Errorf("failed to connect to WhatsApp bridge: %w", err) } diff --git a/pkg/heartbeat/service.go b/pkg/heartbeat/service.go index 3aeca6188..58462c120 100644 --- a/pkg/heartbeat/service.go +++ b/pkg/heartbeat/service.go @@ -16,10 +16,10 @@ import ( "github.com/sipeed/picoclaw/pkg/bus" "github.com/sipeed/picoclaw/pkg/constants" + "github.com/sipeed/picoclaw/pkg/fileutil" "github.com/sipeed/picoclaw/pkg/logger" "github.com/sipeed/picoclaw/pkg/state" "github.com/sipeed/picoclaw/pkg/tools" - "github.com/sipeed/picoclaw/pkg/fileutil" ) const ( @@ -167,7 +167,7 @@ func (hs *HeartbeatService) executeHeartbeat() { } if handler == nil { - hs.logError("Heartbeat handler not configured") + hs.logErrorf("Heartbeat handler not configured") return } @@ -176,23 +176,23 @@ func (hs *HeartbeatService) executeHeartbeat() { channel, chatID := hs.parseLastChannel(lastChannel) // Debug log for channel resolution - hs.logInfo("Resolved channel: %s, chatID: %s (from lastChannel: %s)", channel, chatID, lastChannel) + hs.logInfof("Resolved channel: %s, chatID: %s (from lastChannel: %s)", channel, chatID, lastChannel) result := handler(prompt, channel, chatID) if result == nil { - hs.logInfo("Heartbeat handler returned nil result") + hs.logInfof("Heartbeat handler returned nil result") return } // Handle different result types if result.IsError { - hs.logError("Heartbeat error: %s", result.ForLLM) + hs.logErrorf("Heartbeat error: %s", result.ForLLM) return } if result.Async { - hs.logInfo("Async task started: %s", result.ForLLM) + hs.logInfof("Async task started: %s", result.ForLLM) logger.InfoCF("heartbeat", "Async heartbeat task started", map[string]any{ "message": result.ForLLM, @@ -202,7 +202,7 @@ func (hs *HeartbeatService) executeHeartbeat() { // Check if silent if result.Silent { - hs.logInfo("Heartbeat OK - silent") + hs.logInfof("Heartbeat OK - silent") return } @@ -213,7 +213,7 @@ func (hs *HeartbeatService) executeHeartbeat() { hs.sendResponse(result.ForLLM) } - hs.logInfo("Heartbeat completed: %s", result.ForLLM) + hs.logInfof("Heartbeat completed: %s", result.ForLLM) } // buildPrompt builds the heartbeat prompt from HEARTBEAT.md @@ -226,7 +226,7 @@ func (hs *HeartbeatService) buildPrompt() string { hs.createDefaultHeartbeatTemplate() return "" } - hs.logError("Error reading HEARTBEAT.md: %v", err) + hs.logErrorf("Error reading HEARTBEAT.md: %v", err) return "" } @@ -277,9 +277,9 @@ Add your heartbeat tasks below this line: ` if err := fileutil.WriteFileAtomic(heartbeatPath, []byte(defaultContent), 0o644); err != nil { - hs.logError("Failed to create default HEARTBEAT.md: %v", err) + hs.logErrorf("Failed to create default HEARTBEAT.md: %v", err) } else { - hs.logInfo("Created default HEARTBEAT.md template") + hs.logInfof("Created default HEARTBEAT.md template") } } @@ -290,14 +290,14 @@ func (hs *HeartbeatService) sendResponse(response string) { hs.mu.RUnlock() if msgBus == nil { - hs.logInfo("No message bus configured, heartbeat result not sent") + hs.logInfof("No message bus configured, heartbeat result not sent") return } // Get last channel from state lastChannel := hs.state.GetLastChannel() if lastChannel == "" { - hs.logInfo("No last channel recorded, heartbeat result not sent") + hs.logInfof("No last channel recorded, heartbeat result not sent") return } @@ -314,7 +314,7 @@ func (hs *HeartbeatService) sendResponse(response string) { Content: response, }) - hs.logInfo("Heartbeat result sent to %s", platform) + hs.logInfof("Heartbeat result sent to %s", platform) } // parseLastChannel parses the last channel string into platform and userID. @@ -327,7 +327,7 @@ func (hs *HeartbeatService) parseLastChannel(lastChannel string) (platform, user // Parse channel format: "platform:user_id" (e.g., "telegram:123456") parts := strings.SplitN(lastChannel, ":", 2) if len(parts) != 2 || parts[0] == "" || parts[1] == "" { - hs.logError("Invalid last channel format: %s", lastChannel) + hs.logErrorf("Invalid last channel format: %s", lastChannel) return "", "" } @@ -335,25 +335,25 @@ func (hs *HeartbeatService) parseLastChannel(lastChannel string) (platform, user // Skip internal channels if constants.IsInternalChannel(platform) { - hs.logInfo("Skipping internal channel: %s", platform) + hs.logInfof("Skipping internal channel: %s", platform) return "", "" } return platform, userID } -// logInfo logs an informational message to the heartbeat log -func (hs *HeartbeatService) logInfo(format string, args ...any) { - hs.log("INFO", format, args...) +// logInfof logs an informational message to the heartbeat log +func (hs *HeartbeatService) logInfof(format string, args ...any) { + hs.logf("INFO", format, args...) } -// logError logs an error message to the heartbeat log -func (hs *HeartbeatService) logError(format string, args ...any) { - hs.log("ERROR", format, args...) +// logErrorf logs an error message to the heartbeat log +func (hs *HeartbeatService) logErrorf(format string, args ...any) { + hs.logf("ERROR", format, args...) } -// log writes a message to the heartbeat log file -func (hs *HeartbeatService) log(level, format string, args ...any) { +// logf writes a message to the heartbeat log file +func (hs *HeartbeatService) logf(level, format string, args ...any) { logFile := filepath.Join(hs.workspace, "heartbeat.log") f, err := os.OpenFile(logFile, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0o644) if err != nil { diff --git a/pkg/heartbeat/service_test.go b/pkg/heartbeat/service_test.go index a4dfa7a72..a7aef8c3a 100644 --- a/pkg/heartbeat/service_test.go +++ b/pkg/heartbeat/service_test.go @@ -191,7 +191,7 @@ func TestLogPath(t *testing.T) { hs := NewHeartbeatService(tmpDir, 30, true) // Write a log entry - hs.log("INFO", "Test log entry") + hs.logf("INFO", "Test log entry") // Verify log file exists at workspace root expectedLogPath := filepath.Join(tmpDir, "heartbeat.log") diff --git a/pkg/logger/logger.go b/pkg/logger/logger.go index c14fbd464..56dc87a53 100644 --- a/pkg/logger/logger.go +++ b/pkg/logger/logger.go @@ -153,7 +153,7 @@ func formatComponent(component string) string { } func formatFields(fields map[string]any) string { - var parts []string + parts := make([]string, 0, len(fields)) for k, v := range fields { parts = append(parts, fmt.Sprintf("%s=%v", k, v)) } diff --git a/pkg/providers/anthropic/provider.go b/pkg/providers/anthropic/provider.go index 35f6b8f62..9162174c9 100644 --- a/pkg/providers/anthropic/provider.go +++ b/pkg/providers/anthropic/provider.go @@ -113,7 +113,20 @@ func buildParams( for _, msg := range messages { switch msg.Role { case "system": - system = append(system, anthropic.TextBlockParam{Text: msg.Content}) + // Prefer structured SystemParts for per-block cache_control. + // This enables LLM-side KV cache reuse: the static block's prefix + // hash stays stable across requests while dynamic parts change freely. + if len(msg.SystemParts) > 0 { + for _, part := range msg.SystemParts { + block := anthropic.TextBlockParam{Text: part.Text} + if part.CacheControl != nil && part.CacheControl.Type == "ephemeral" { + block.CacheControl = anthropic.NewCacheControlEphemeralParam() + } + system = append(system, block) + } + } else { + system = append(system, anthropic.TextBlockParam{Text: msg.Content}) + } case "user": if msg.ToolCallID != "" { anthropicMessages = append(anthropicMessages, diff --git a/pkg/providers/codex_cli_credentials_test.go b/pkg/providers/codex_cli_credentials_test.go index 43b21700a..1e88c1120 100644 --- a/pkg/providers/codex_cli_credentials_test.go +++ b/pkg/providers/codex_cli_credentials_test.go @@ -43,12 +43,18 @@ func TestReadCodexCliCredentials_Valid(t *testing.T) { } } +// readCodexCliCredentialsErr calls ReadCodexCliCredentials and returns only the +// error, for tests that only need to assert on failure. +func readCodexCliCredentialsErr() error { + _, _, _, err := ReadCodexCliCredentials() //nolint:dogsled + return err +} + func TestReadCodexCliCredentials_MissingFile(t *testing.T) { tmpDir := t.TempDir() t.Setenv("CODEX_HOME", tmpDir) - _, _, _, err := ReadCodexCliCredentials() - if err == nil { + if err := readCodexCliCredentialsErr(); err == nil { t.Fatal("expected error for missing auth.json") } } @@ -64,8 +70,7 @@ func TestReadCodexCliCredentials_EmptyToken(t *testing.T) { t.Setenv("CODEX_HOME", tmpDir) - _, _, _, err := ReadCodexCliCredentials() - if err == nil { + if err := readCodexCliCredentialsErr(); err == nil { t.Fatal("expected error for empty access_token") } } @@ -80,8 +85,7 @@ func TestReadCodexCliCredentials_InvalidJSON(t *testing.T) { t.Setenv("CODEX_HOME", tmpDir) - _, _, _, err := ReadCodexCliCredentials() - if err == nil { + if err := readCodexCliCredentialsErr(); err == nil { t.Fatal("expected error for invalid JSON") } } diff --git a/pkg/providers/codex_provider.go b/pkg/providers/codex_provider.go index ecc983642..dcc740ba4 100644 --- a/pkg/providers/codex_provider.go +++ b/pkg/providers/codex_provider.go @@ -106,8 +106,8 @@ func (p *CodexProvider) Chat( if evt.Type == "response.completed" || evt.Type == "response.failed" || evt.Type == "response.incomplete" { evtResp := evt.Response if evtResp.ID != "" { - copy := evtResp - resp = © + evtRespCopy := evtResp + resp = &evtRespCopy } } } @@ -208,6 +208,11 @@ func buildCodexParams( for _, msg := range messages { switch msg.Role { case "system": + // Use the full concatenated system prompt (static + dynamic + summary) + // as instructions. This keeps behavior consistent with Anthropic and + // OpenAI-compat adapters where the complete system context lives in + // one place. Prefix caching is handled by prompt_cache_key below, + // not by splitting content across instructions vs input messages. instructions = msg.Content case "user": if msg.ToolCallID != "" { @@ -289,6 +294,13 @@ func buildCodexParams( params.Instructions = openai.Opt(defaultCodexInstructions) } + // Prompt caching: pass a stable cache key so OpenAI can bucket requests + // and reuse prefix KV cache across calls with the same key. + // See: https://platform.openai.com/docs/guides/prompt-caching + if cacheKey, ok := options["prompt_cache_key"].(string); ok && cacheKey != "" { + params.PromptCacheKey = openai.Opt(cacheKey) + } + if len(tools) > 0 || enableWebSearch { params.Tools = translateToolsForCodex(tools, enableWebSearch) } diff --git a/pkg/providers/github_copilot_provider.go b/pkg/providers/github_copilot_provider.go index 9210021e1..3fb15db2f 100644 --- a/pkg/providers/github_copilot_provider.go +++ b/pkg/providers/github_copilot_provider.go @@ -44,7 +44,6 @@ func NewGitHubCopilotProvider(uri string, connectMode string, model string) (*Gi Hooks: &copilot.SessionHooks{}, }) if err != nil { - client.Stop() return nil, fmt.Errorf("create session failed: %w", err) } @@ -101,7 +100,7 @@ func (p *GitHubCopilotProvider) Chat( return nil, fmt.Errorf("provider closed") } - resp, err := session.SendAndWait(ctx, copilot.MessageOptions{ + resp, _ := session.SendAndWait(ctx, copilot.MessageOptions{ Prompt: string(fullcontent), }) diff --git a/pkg/providers/openai_compat/provider.go b/pkg/providers/openai_compat/provider.go index d2412ae1b..a8d244d4a 100644 --- a/pkg/providers/openai_compat/provider.go +++ b/pkg/providers/openai_compat/provider.go @@ -77,7 +77,7 @@ func (p *Provider) Chat( requestBody := map[string]any{ "model": model, - "messages": messages, + "messages": stripSystemParts(messages), } if len(tools) > 0 { @@ -111,6 +111,14 @@ func (p *Provider) Chat( } } + // Prompt caching: pass a stable cache key so OpenAI can bucket requests + // with the same key and reuse prefix KV cache across calls. + // The key is typically the agent ID — stable per agent, shared across requests. + // See: https://platform.openai.com/docs/guides/prompt-caching + if cacheKey, ok := options["prompt_cache_key"].(string); ok && cacheKey != "" { + requestBody["prompt_cache_key"] = cacheKey + } + jsonData, err := json.Marshal(requestBody) if err != nil { return nil, fmt.Errorf("failed to marshal request: %w", err) @@ -230,6 +238,32 @@ func parseResponse(body []byte) (*LLMResponse, error) { }, nil } +// openaiMessage is the wire-format message for OpenAI-compatible APIs. +// It mirrors protocoltypes.Message but omits SystemParts, which is an +// internal field that would be unknown to third-party endpoints. +type openaiMessage struct { + Role string `json:"role"` + Content string `json:"content"` + ToolCalls []ToolCall `json:"tool_calls,omitempty"` + ToolCallID string `json:"tool_call_id,omitempty"` +} + +// stripSystemParts converts []Message to []openaiMessage, dropping the +// SystemParts field so it doesn't leak into the JSON payload sent to +// OpenAI-compatible APIs (some strict endpoints reject unknown fields). +func stripSystemParts(messages []Message) []openaiMessage { + out := make([]openaiMessage, len(messages)) + for i, m := range messages { + out[i] = openaiMessage{ + Role: m.Role, + Content: m.Content, + ToolCalls: m.ToolCalls, + ToolCallID: m.ToolCallID, + } + } + return out +} + func normalizeModel(model, apiBase string) string { idx := strings.Index(model, "/") if idx == -1 { diff --git a/pkg/providers/protocoltypes/types.go b/pkg/providers/protocoltypes/types.go index 1d0ea6edd..33f052c5a 100644 --- a/pkg/providers/protocoltypes/types.go +++ b/pkg/providers/protocoltypes/types.go @@ -38,12 +38,28 @@ type UsageInfo struct { TotalTokens int `json:"total_tokens"` } +// CacheControl marks a content block for LLM-side prefix caching. +// Currently only "ephemeral" is supported (used by Anthropic). +type CacheControl struct { + Type string `json:"type"` // "ephemeral" +} + +// ContentBlock represents a structured segment of a system message. +// Adapters that understand SystemParts can use these blocks to set +// per-block cache control (e.g. Anthropic's cache_control: ephemeral). +type ContentBlock struct { + Type string `json:"type"` // "text" + Text string `json:"text"` + CacheControl *CacheControl `json:"cache_control,omitempty"` +} + type Message struct { - Role string `json:"role"` - Content string `json:"content"` - ReasoningContent string `json:"reasoning_content,omitempty"` - ToolCalls []ToolCall `json:"tool_calls,omitempty"` - ToolCallID string `json:"tool_call_id,omitempty"` + Role string `json:"role"` + Content string `json:"content"` + ReasoningContent string `json:"reasoning_content,omitempty"` + SystemParts []ContentBlock `json:"system_parts,omitempty"` // structured system blocks for cache-aware adapters + ToolCalls []ToolCall `json:"tool_calls,omitempty"` + ToolCallID string `json:"tool_call_id,omitempty"` } type ToolDefinition struct { diff --git a/pkg/providers/types.go b/pkg/providers/types.go index b2dda04a5..f0c168bc6 100644 --- a/pkg/providers/types.go +++ b/pkg/providers/types.go @@ -17,6 +17,8 @@ type ( ToolFunctionDefinition = protocoltypes.ToolFunctionDefinition ExtraContent = protocoltypes.ExtraContent GoogleExtra = protocoltypes.GoogleExtra + ContentBlock = protocoltypes.ContentBlock + CacheControl = protocoltypes.CacheControl ) type LLMProvider interface { diff --git a/pkg/tools/registry.go b/pkg/tools/registry.go index 6ecb8ae7c..d37a093a8 100644 --- a/pkg/tools/registry.go +++ b/pkg/tools/registry.go @@ -3,6 +3,7 @@ package tools import ( "context" "fmt" + "sort" "sync" "time" @@ -107,13 +108,27 @@ func (r *ToolRegistry) ExecuteWithContext( return result } +// sortedToolNames returns tool names in sorted order for deterministic iteration. +// This is critical for KV cache stability: non-deterministic map iteration would +// produce different system prompts and tool definitions on each call, invalidating +// the LLM's prefix cache even when no tools have changed. +func (r *ToolRegistry) sortedToolNames() []string { + names := make([]string, 0, len(r.tools)) + for name := range r.tools { + names = append(names, name) + } + sort.Strings(names) + return names +} + func (r *ToolRegistry) GetDefinitions() []map[string]any { r.mu.RLock() defer r.mu.RUnlock() - definitions := make([]map[string]any, 0, len(r.tools)) - for _, tool := range r.tools { - definitions = append(definitions, ToolToSchema(tool)) + sorted := r.sortedToolNames() + definitions := make([]map[string]any, 0, len(sorted)) + for _, name := range sorted { + definitions = append(definitions, ToolToSchema(r.tools[name])) } return definitions } @@ -124,8 +139,10 @@ func (r *ToolRegistry) ToProviderDefs() []providers.ToolDefinition { r.mu.RLock() defer r.mu.RUnlock() - definitions := make([]providers.ToolDefinition, 0, len(r.tools)) - for _, tool := range r.tools { + sorted := r.sortedToolNames() + definitions := make([]providers.ToolDefinition, 0, len(sorted)) + for _, name := range sorted { + tool := r.tools[name] schema := ToolToSchema(tool) // Safely extract nested values with type checks @@ -155,11 +172,7 @@ func (r *ToolRegistry) List() []string { r.mu.RLock() defer r.mu.RUnlock() - names := make([]string, 0, len(r.tools)) - for name := range r.tools { - names = append(names, name) - } - return names + return r.sortedToolNames() } // Count returns the number of registered tools. @@ -175,8 +188,10 @@ func (r *ToolRegistry) GetSummaries() []string { r.mu.RLock() defer r.mu.RUnlock() - summaries := make([]string, 0, len(r.tools)) - for _, tool := range r.tools { + sorted := r.sortedToolNames() + summaries := make([]string, 0, len(sorted)) + for _, name := range sorted { + tool := r.tools[name] summaries = append(summaries, fmt.Sprintf("- `%s` - %s", tool.Name(), tool.Description())) } return summaries diff --git a/pkg/tools/shell.go b/pkg/tools/shell.go index 6883172cd..ad1664b5b 100644 --- a/pkg/tools/shell.go +++ b/pkg/tools/shell.go @@ -76,10 +76,9 @@ func NewExecTool(workingDir string, restrict bool) *ExecTool { func NewExecToolWithConfig(workingDir string, restrict bool, config *config.Config) *ExecTool { denyPatterns := make([]*regexp.Regexp, 0) - enableDenyPatterns := true if config != nil { execConfig := config.Tools.Exec - enableDenyPatterns = execConfig.EnableDenyPatterns + enableDenyPatterns := execConfig.EnableDenyPatterns if enableDenyPatterns { denyPatterns = append(denyPatterns, defaultDenyPatterns...) if len(execConfig.CustomDenyPatterns) > 0 { diff --git a/pkg/tools/spawn.go b/pkg/tools/spawn.go index 73d385cb0..8b166b41f 100644 --- a/pkg/tools/spawn.go +++ b/pkg/tools/spawn.go @@ -3,6 +3,7 @@ package tools import ( "context" "fmt" + "strings" ) type SpawnTool struct { @@ -66,8 +67,8 @@ func (t *SpawnTool) SetAllowlistChecker(check func(targetAgentID string) bool) { func (t *SpawnTool) Execute(ctx context.Context, args map[string]any) *ToolResult { task, ok := args["task"].(string) - if !ok { - return ErrorResult("task is required") + if !ok || strings.TrimSpace(task) == "" { + return ErrorResult("task is required and must be a non-empty string") } label, _ := args["label"].(string) diff --git a/pkg/tools/spawn_test.go b/pkg/tools/spawn_test.go new file mode 100644 index 000000000..0646c82a9 --- /dev/null +++ b/pkg/tools/spawn_test.go @@ -0,0 +1,79 @@ +package tools + +import ( + "context" + "strings" + "testing" +) + +func TestSpawnTool_Execute_EmptyTask(t *testing.T) { + provider := &MockLLMProvider{} + manager := NewSubagentManager(provider, "test-model", "/tmp/test", nil) + tool := NewSpawnTool(manager) + + ctx := context.Background() + + tests := []struct { + name string + args map[string]any + }{ + {"empty string", map[string]any{"task": ""}}, + {"whitespace only", map[string]any{"task": " "}}, + {"tabs and newlines", map[string]any{"task": "\t\n "}}, + {"missing task key", map[string]any{"label": "test"}}, + {"wrong type", map[string]any{"task": 123}}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := tool.Execute(ctx, tt.args) + if result == nil { + t.Fatal("Result should not be nil") + } + if !result.IsError { + t.Error("Expected error for invalid task parameter") + } + if !strings.Contains(result.ForLLM, "task is required") { + t.Errorf("Error message should mention 'task is required', got: %s", result.ForLLM) + } + }) + } +} + +func TestSpawnTool_Execute_ValidTask(t *testing.T) { + provider := &MockLLMProvider{} + manager := NewSubagentManager(provider, "test-model", "/tmp/test", nil) + tool := NewSpawnTool(manager) + + ctx := context.Background() + args := map[string]any{ + "task": "Write a haiku about coding", + "label": "haiku-task", + } + + result := tool.Execute(ctx, args) + if result == nil { + t.Fatal("Result should not be nil") + } + if result.IsError { + t.Errorf("Expected success for valid task, got error: %s", result.ForLLM) + } + if !result.Async { + t.Error("SpawnTool should return async result") + } +} + +func TestSpawnTool_Execute_NilManager(t *testing.T) { + tool := NewSpawnTool(nil) + + ctx := context.Background() + args := map[string]any{"task": "test task"} + + result := tool.Execute(ctx, args) + if !result.IsError { + t.Error("Expected error for nil manager") + } + if !strings.Contains(result.ForLLM, "Subagent manager not configured") { + t.Errorf("Error message should mention manager not configured, got: %s", result.ForLLM) + } +} diff --git a/pkg/tools/subagent.go b/pkg/tools/subagent.go index 91ebff636..ad371a649 100644 --- a/pkg/tools/subagent.go +++ b/pkg/tools/subagent.go @@ -132,12 +132,12 @@ After completing the task, provide a clear summary of what was done.` }, } - // Check if context is already cancelled before starting + // Check if context is already canceled before starting select { case <-ctx.Done(): sm.mu.Lock() - task.Status = "cancelled" - task.Result = "Task cancelled before execution" + task.Status = "canceled" + task.Result = "Task canceled before execution" sm.mu.Unlock() return default: @@ -185,10 +185,10 @@ After completing the task, provide a clear summary of what was done.` if err != nil { task.Status = "failed" task.Result = fmt.Sprintf("Error: %v", err) - // Check if it was cancelled + // Check if it was canceled if ctx.Err() != nil { - task.Status = "cancelled" - task.Result = "Task cancelled during execution" + task.Status = "canceled" + task.Result = "Task canceled during execution" } result = &ToolResult{ ForLLM: task.Result,