diff --git a/cmd/picoclaw/cmd_skills.go b/cmd/picoclaw/cmd_skills.go index 9ea38dcf6..32b7c62b8 100644 --- a/cmd/picoclaw/cmd_skills.go +++ b/cmd/picoclaw/cmd_skills.go @@ -11,15 +11,17 @@ import ( "strings" "time" + "github.com/sipeed/picoclaw/pkg/config" "github.com/sipeed/picoclaw/pkg/skills" + "github.com/sipeed/picoclaw/pkg/utils" ) func skillsHelp() { fmt.Println("\nSkills commands:") fmt.Println(" list List installed skills") fmt.Println(" install Install skill from GitHub") - fmt.Println(" install-builtin Install all builtin skills to workspace") - fmt.Println(" list-builtin List available builtin skills") + fmt.Println(" install-builtin Install all builtin skills to workspace") + fmt.Println(" list-builtin List available builtin skills") fmt.Println(" remove Remove installed skill") fmt.Println(" search Search available skills") fmt.Println(" show Show skill details") @@ -30,6 +32,7 @@ func skillsHelp() { fmt.Println(" picoclaw skills install-builtin") fmt.Println(" picoclaw skills list-builtin") fmt.Println(" picoclaw skills remove weather") + fmt.Println(" picoclaw skills install --registry clawhub github") } func skillsListCmd(loader *skills.SkillsLoader) { @@ -50,13 +53,27 @@ func skillsListCmd(loader *skills.SkillsLoader) { } } -func skillsInstallCmd(installer *skills.SkillInstaller) { +func skillsInstallCmd(installer *skills.SkillInstaller, cfg *config.Config) { if len(os.Args) < 4 { fmt.Println("Usage: picoclaw skills install ") - fmt.Println("Example: picoclaw skills install sipeed/picoclaw-skills/weather") + fmt.Println(" picoclaw skills install --registry ") return } + // Check for --registry flag. 
+ if os.Args[3] == "--registry" { + if len(os.Args) < 6 { + fmt.Println("Usage: picoclaw skills install --registry ") + fmt.Println("Example: picoclaw skills install --registry clawhub github") + return + } + registryName := os.Args[4] + slug := os.Args[5] + skillsInstallFromRegistry(cfg, registryName, slug) + return + } + + // Default: install from GitHub (backward compatible). repo := os.Args[3] fmt.Printf("Installing skill from %s...\n", repo) @@ -64,11 +81,83 @@ func skillsInstallCmd(installer *skills.SkillInstaller) { defer cancel() if err := installer.InstallFromGitHub(ctx, repo); err != nil { - fmt.Printf("✗ Failed to install skill: %v\n", err) + fmt.Printf("\u2717 Failed to install skill: %v\n", err) os.Exit(1) } - fmt.Printf("✓ Skill '%s' installed successfully!\n", filepath.Base(repo)) + fmt.Printf("\u2713 Skill '%s' installed successfully!\n", filepath.Base(repo)) +} + +// skillsInstallFromRegistry installs a skill from a named registry (e.g. clawhub). +func skillsInstallFromRegistry(cfg *config.Config, registryName, slug string) { + err := utils.ValidateSkillIdentifier(registryName) + if err != nil { + fmt.Printf("\u2717 Invalid registry name: %v\n", err) + os.Exit(1) + } + + err = utils.ValidateSkillIdentifier(slug) + if err != nil { + fmt.Printf("\u2717 Invalid slug: %v\n", err) + os.Exit(1) + } + + fmt.Printf("Installing skill '%s' from %s registry...\n", slug, registryName) + + registryMgr := skills.NewRegistryManagerFromConfig(skills.RegistryConfig{ + MaxConcurrentSearches: cfg.Tools.Skills.MaxConcurrentSearches, + ClawHub: skills.ClawHubConfig(cfg.Tools.Skills.Registries.ClawHub), + }) + + registry := registryMgr.GetRegistry(registryName) + if registry == nil { + fmt.Printf("\u2717 Registry '%s' not found or not enabled. 
Check your config.json.\n", registryName) + os.Exit(1) + } + + workspace := cfg.WorkspacePath() + targetDir := filepath.Join(workspace, "skills", slug) + + if _, err := os.Stat(targetDir); err == nil { + fmt.Printf("\u2717 Skill '%s' already installed at %s\n", slug, targetDir) + os.Exit(1) + } + + ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second) + defer cancel() + + if err := os.MkdirAll(filepath.Join(workspace, "skills"), 0755); err != nil { + fmt.Printf("\u2717 Failed to create skills directory: %v\n", err) + os.Exit(1) + } + + result, err := registry.DownloadAndInstall(ctx, slug, "", targetDir) + if err != nil { + rmErr := os.RemoveAll(targetDir) + if rmErr != nil { + fmt.Printf("\u2717 Failed to remove partial install: %v\n", rmErr) + } + fmt.Printf("\u2717 Failed to install skill: %v\n", err) + os.Exit(1) + } + + if result.IsMalwareBlocked { + rmErr := os.RemoveAll(targetDir) + if rmErr != nil { + fmt.Printf("\u2717 Failed to remove partial install: %v\n", rmErr) + } + fmt.Printf("\u2717 Skill '%s' is flagged as malicious and cannot be installed.\n", slug) + os.Exit(1) + } + + if result.IsSuspicious { + fmt.Printf("\u26a0\ufe0f Warning: skill '%s' is flagged as suspicious.\n", slug) + } + + fmt.Printf("\u2713 Skill '%s' v%s installed successfully!\n", slug, result.Version) + if result.Summary != "" { + fmt.Printf(" %s\n", result.Summary) + } } func skillsRemoveCmd(installer *skills.SkillInstaller, skillName string) { diff --git a/cmd/picoclaw/main.go b/cmd/picoclaw/main.go index ce9389417..1e4b393f8 100644 --- a/cmd/picoclaw/main.go +++ b/cmd/picoclaw/main.go @@ -141,7 +141,7 @@ func main() { case "list": skillsListCmd(skillsLoader) case "install": - skillsInstallCmd(installer) + skillsInstallCmd(installer, cfg) case "remove", "uninstall": if len(os.Args) < 4 { fmt.Println("Usage: picoclaw skills remove ") diff --git a/config/config.example.json b/config/config.example.json index abc928e92..fa87fbec7 100644 --- 
a/config/config.example.json +++ b/config/config.example.json @@ -194,6 +194,17 @@ "exec": { "enable_deny_patterns": false, "custom_deny_patterns": [] + }, + "skills": { + "registries": { + "clawhub": { + "enabled": true, + "base_url": "https://clawhub.ai", + "search_path": "/api/v1/search", + "skills_path": "/api/v1/skills", + "download_path": "/api/v1/download" + } + } } }, "heartbeat": { diff --git a/pkg/agent/loop.go b/pkg/agent/loop.go index e7b48d47a..f8eef395a 100644 --- a/pkg/agent/loop.go +++ b/pkg/agent/loop.go @@ -23,6 +23,7 @@ import ( "github.com/sipeed/picoclaw/pkg/logger" "github.com/sipeed/picoclaw/pkg/providers" "github.com/sipeed/picoclaw/pkg/routing" + "github.com/sipeed/picoclaw/pkg/skills" "github.com/sipeed/picoclaw/pkg/state" "github.com/sipeed/picoclaw/pkg/tools" "github.com/sipeed/picoclaw/pkg/utils" @@ -117,6 +118,15 @@ func registerSharedTools(cfg *config.Config, msgBus *bus.MessageBus, registry *A }) agent.Tools.Register(messageTool) + // Skill discovery and installation tools + registryMgr := skills.NewRegistryManagerFromConfig(skills.RegistryConfig{ + MaxConcurrentSearches: cfg.Tools.Skills.MaxConcurrentSearches, + ClawHub: skills.ClawHubConfig(cfg.Tools.Skills.Registries.ClawHub), + }) + searchCache := skills.NewSearchCache(cfg.Tools.Skills.SearchCache.MaxSize, time.Duration(cfg.Tools.Skills.SearchCache.TTLSeconds)*time.Second) + agent.Tools.Register(tools.NewFindSkillsTool(registryMgr, searchCache)) + agent.Tools.Register(tools.NewInstallSkillTool(registryMgr, agent.Workspace)) + // Spawn tool with allowlist checker subagentManager := tools.NewSubagentManager(provider, agent.Model, agent.Workspace, msgBus) subagentManager.SetLLMOptions(agent.MaxTokens, agent.Temperature) diff --git a/pkg/config/config.go b/pkg/config/config.go index 0d41796a4..9d5e5d42e 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -416,9 +416,37 @@ type ExecConfig struct { } type ToolsConfig struct { - Web WebToolsConfig `json:"web"` - Cron 
CronToolsConfig `json:"cron"` - Exec ExecConfig `json:"exec"` + Web WebToolsConfig `json:"web"` + Cron CronToolsConfig `json:"cron"` + Exec ExecConfig `json:"exec"` + Skills SkillsToolsConfig `json:"skills"` +} + +type SkillsToolsConfig struct { + Registries SkillsRegistriesConfig `json:"registries"` + MaxConcurrentSearches int `json:"max_concurrent_searches" env:"PICOCLAW_SKILLS_MAX_CONCURRENT_SEARCHES"` + SearchCache SearchCacheConfig `json:"search_cache"` +} + +type SearchCacheConfig struct { + MaxSize int `json:"max_size" env:"PICOCLAW_SKILLS_SEARCH_CACHE_MAX_SIZE"` + TTLSeconds int `json:"ttl_seconds" env:"PICOCLAW_SKILLS_SEARCH_CACHE_TTL_SECONDS"` +} + +type SkillsRegistriesConfig struct { + ClawHub ClawHubRegistryConfig `json:"clawhub"` +} + +type ClawHubRegistryConfig struct { + Enabled bool `json:"enabled" env:"PICOCLAW_SKILLS_REGISTRIES_CLAWHUB_ENABLED"` + BaseURL string `json:"base_url" env:"PICOCLAW_SKILLS_REGISTRIES_CLAWHUB_BASE_URL"` + AuthToken string `json:"auth_token" env:"PICOCLAW_SKILLS_REGISTRIES_CLAWHUB_AUTH_TOKEN"` + SearchPath string `json:"search_path" env:"PICOCLAW_SKILLS_REGISTRIES_CLAWHUB_SEARCH_PATH"` + SkillsPath string `json:"skills_path" env:"PICOCLAW_SKILLS_REGISTRIES_CLAWHUB_SKILLS_PATH"` + DownloadPath string `json:"download_path" env:"PICOCLAW_SKILLS_REGISTRIES_CLAWHUB_DOWNLOAD_PATH"` + Timeout int `json:"timeout" env:"PICOCLAW_SKILLS_REGISTRIES_CLAWHUB_TIMEOUT"` + MaxZipSize int `json:"max_zip_size" env:"PICOCLAW_SKILLS_REGISTRIES_CLAWHUB_MAX_ZIP_SIZE"` + MaxResponseSize int `json:"max_response_size" env:"PICOCLAW_SKILLS_REGISTRIES_CLAWHUB_MAX_RESPONSE_SIZE"` } func LoadConfig(path string) (*Config, error) { diff --git a/pkg/config/defaults.go b/pkg/config/defaults.go index 54d6d68c3..07974b8eb 100644 --- a/pkg/config/defaults.go +++ b/pkg/config/defaults.go @@ -265,6 +265,19 @@ func DefaultConfig() *Config { Exec: ExecConfig{ EnableDenyPatterns: true, }, + Skills: SkillsToolsConfig{ + Registries: SkillsRegistriesConfig{ + 
ClawHub: ClawHubRegistryConfig{ + Enabled: true, + BaseURL: "https://clawhub.ai", + }, + }, + MaxConcurrentSearches: 2, + SearchCache: SearchCacheConfig{ + MaxSize: 50, + TTLSeconds: 300, + }, + }, }, Heartbeat: HeartbeatConfig{ Enabled: true, diff --git a/pkg/skills/clawhub_registry.go b/pkg/skills/clawhub_registry.go new file mode 100644 index 000000000..e2a940afd --- /dev/null +++ b/pkg/skills/clawhub_registry.go @@ -0,0 +1,311 @@ +package skills + +import ( + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "net/url" + "os" + "time" + + "github.com/sipeed/picoclaw/pkg/utils" +) + +const ( + defaultClawHubTimeout = 30 * time.Second + defaultMaxZipSize = 50 * 1024 * 1024 // 50 MB + defaultMaxResponseSize = 2 * 1024 * 1024 // 2 MB +) + +// ClawHubRegistry implements SkillRegistry for the ClawHub platform. +type ClawHubRegistry struct { + baseURL string + authToken string // Optional - for elevated rate limits + searchPath string // Search API + skillsPath string // For retrieving skill metadata + downloadPath string // For fetching ZIP files for download + maxZipSize int + maxResponseSize int + client *http.Client +} + +// NewClawHubRegistry creates a new ClawHub registry client from config. 
+func NewClawHubRegistry(cfg ClawHubConfig) *ClawHubRegistry {
+	// Every optional setting falls back to a built-in default when it is left
+	// at its zero value.
+	baseURL := cfg.BaseURL
+	if baseURL == "" {
+		baseURL = "https://clawhub.ai"
+	}
+	searchPath := cfg.SearchPath
+	if searchPath == "" {
+		searchPath = "/api/v1/search"
+	}
+	skillsPath := cfg.SkillsPath
+	if skillsPath == "" {
+		skillsPath = "/api/v1/skills"
+	}
+	downloadPath := cfg.DownloadPath
+	if downloadPath == "" {
+		downloadPath = "/api/v1/download"
+	}
+
+	timeout := defaultClawHubTimeout
+	if cfg.Timeout > 0 {
+		timeout = time.Duration(cfg.Timeout) * time.Second
+	}
+
+	maxZip := defaultMaxZipSize
+	if cfg.MaxZipSize > 0 {
+		maxZip = cfg.MaxZipSize
+	}
+
+	maxResp := defaultMaxResponseSize
+	if cfg.MaxResponseSize > 0 {
+		maxResp = cfg.MaxResponseSize
+	}
+
+	return &ClawHubRegistry{
+		baseURL:         baseURL,
+		authToken:       cfg.AuthToken,
+		searchPath:      searchPath,
+		skillsPath:      skillsPath,
+		downloadPath:    downloadPath,
+		maxZipSize:      maxZip,
+		maxResponseSize: maxResp,
+		client: &http.Client{
+			Timeout: timeout,
+			Transport: &http.Transport{
+				// A hand-built Transport has a nil Proxy and would ignore
+				// HTTP_PROXY/HTTPS_PROXY/NO_PROXY; opt back in explicitly.
+				Proxy:               http.ProxyFromEnvironment,
+				MaxIdleConns:        5,
+				IdleConnTimeout:     30 * time.Second,
+				TLSHandshakeTimeout: 10 * time.Second,
+			},
+		},
+	}
+}
+
+// Name returns the fixed registry identifier "clawhub".
+func (c *ClawHubRegistry) Name() string {
+	return "clawhub"
+}
+
+// --- Search ---
+
+// clawhubSearchResponse is the JSON envelope decoded from the search endpoint.
+type clawhubSearchResponse struct {
+	Results []clawhubSearchResult `json:"results"`
+}
+
+// clawhubSearchResult is a single search hit. The string fields are pointers
+// so that a JSON null or an absent key decodes to nil.
+type clawhubSearchResult struct {
+	Score       float64 `json:"score"`
+	Slug        *string `json:"slug"`
+	DisplayName *string `json:"displayName"`
+	Summary     *string `json:"summary"`
+	Version     *string `json:"version"`
+}
+
+// Search sends query to the search endpoint (adding limit when it is positive)
+// and converts the hits into SearchResult values tagged with this registry's
+// name. Hits whose slug or summary resolves to "" are dropped; a hit without a
+// display name uses its slug instead.
+func (c *ClawHubRegistry) Search(ctx context.Context, query string, limit int) ([]SearchResult, error) {
+	u, err := url.Parse(c.baseURL + c.searchPath)
+	if err != nil {
+		return nil, fmt.Errorf("invalid base URL: %w", err)
+	}
+
+	q := u.Query()
+	q.Set("q", query)
+	if limit > 0 {
+		q.Set("limit", fmt.Sprintf("%d", limit))
+	}
+	u.RawQuery = q.Encode()
+
+	body, err := c.doGet(ctx, u.String())
+	if err != nil {
+		return nil, fmt.Errorf("search request failed: %w", err)
+	}
+
+	var resp clawhubSearchResponse
+	if err := json.Unmarshal(body, &resp); err != nil {
+		return nil, fmt.Errorf("failed to parse search response: %w", err)
+	}
+
+	results := make([]SearchResult, 0, len(resp.Results))
+	for _, r := range resp.Results {
+		slug := utils.DerefStr(r.Slug, "")
+		if slug == "" {
+			continue
+		}
+
+		summary := utils.DerefStr(r.Summary, "")
+		if summary == "" {
+			continue
+		}
+
+		displayName := utils.DerefStr(r.DisplayName, "")
+		if displayName == "" {
+			displayName = slug
+		}
+
+		results = append(results, SearchResult{
+			Score:        r.Score,
+			Slug:         slug,
+			DisplayName:  displayName,
+			Summary:      summary,
+			Version:      utils.DerefStr(r.Version, ""),
+			RegistryName: c.Name(),
+		})
+	}
+
+	return results, nil
+}
+
+// --- GetSkillMeta ---
+
+// clawhubSkillResponse is the JSON body decoded from the skill metadata
+// endpoint. LatestVersion and Moderation are pointers because either object
+// may be absent from the response.
+type clawhubSkillResponse struct {
+	Slug          string                 `json:"slug"`
+	DisplayName   string                 `json:"displayName"`
+	Summary       string                 `json:"summary"`
+	LatestVersion *clawhubVersionInfo    `json:"latestVersion"`
+	Moderation    *clawhubModerationInfo `json:"moderation"`
+}
+
+// clawhubVersionInfo carries the version string of a skill release.
+type clawhubVersionInfo struct {
+	Version string `json:"version"`
+}
+
+// clawhubModerationInfo carries the registry's moderation flags for a skill.
+type clawhubModerationInfo struct {
+	IsMalwareBlocked bool `json:"isMalwareBlocked"`
+	IsSuspicious     bool `json:"isSuspicious"`
+}
+
+// GetSkillMeta fetches the metadata document for slug and maps it to a
+// SkillMeta. The slug is validated before it is placed in the request path;
+// a validation failure is wrapped, so callers can inspect the cause with
+// errors.Is / errors.As.
+func (c *ClawHubRegistry) GetSkillMeta(ctx context.Context, slug string) (*SkillMeta, error) {
+	if err := utils.ValidateSkillIdentifier(slug); err != nil {
+		return nil, fmt.Errorf("invalid slug %q: %w", slug, err)
+	}
+
+	u := c.baseURL + c.skillsPath + "/" + url.PathEscape(slug)
+
+	body, err := c.doGet(ctx, u)
+	if err != nil {
+		return nil, fmt.Errorf("skill metadata request failed: %w", err)
+	}
+
+	var resp clawhubSkillResponse
+	if err := json.Unmarshal(body, &resp); err != nil {
+		return nil, fmt.Errorf("failed to parse skill metadata: %w", err)
+	}
+
+	meta := &SkillMeta{
+		Slug:         resp.Slug,
+		DisplayName:  resp.DisplayName,
+		Summary:      resp.Summary,
+		RegistryName: c.Name(),
+	}
+
+	// Both nested objects are optional; absent ones leave the zero values.
+	if resp.LatestVersion != nil {
+		meta.LatestVersion = resp.LatestVersion.Version
+	}
+	if resp.Moderation != nil {
+		meta.IsMalwareBlocked = resp.Moderation.IsMalwareBlocked
+		meta.IsSuspicious = resp.Moderation.IsSuspicious
+	}
+
+	return meta, nil
+}
+
+// --- DownloadAndInstall ---
+
+// DownloadAndInstall fetches metadata (with fallback), resolves version,
+// downloads the skill ZIP, and extracts it to targetDir.
+// Returns an InstallResult for the caller to use for moderation decisions.
+//
+// When the metadata marks the skill as malware-blocked, nothing is downloaded
+// or written to targetDir: the result is returned with IsMalwareBlocked set
+// and a nil error, and the caller is expected to refuse the install.
+func (c *ClawHubRegistry) DownloadAndInstall(ctx context.Context, slug, version, targetDir string) (*InstallResult, error) {
+	if err := utils.ValidateSkillIdentifier(slug); err != nil {
+		return nil, fmt.Errorf("invalid slug %q: %w", slug, err)
+	}
+
+	// Step 1: Fetch metadata (with fallback).
+	// NOTE(review): when the metadata request fails, the moderation flags stay
+	// false and the download still proceeds; confirm that this best-effort
+	// behavior is intended.
+	result := &InstallResult{}
+	meta, err := c.GetSkillMeta(ctx, slug)
+	if err != nil {
+		// Fallback: proceed without metadata.
+		meta = nil
+	}
+
+	if meta != nil {
+		result.IsMalwareBlocked = meta.IsMalwareBlocked
+		result.IsSuspicious = meta.IsSuspicious
+		result.Summary = meta.Summary
+	}
+
+	// Step 2: Resolve version.
+	installVersion := version
+	if installVersion == "" && meta != nil {
+		installVersion = meta.LatestVersion
+	}
+	if installVersion == "" {
+		installVersion = "latest"
+	}
+	result.Version = installVersion
+
+	// Never fetch or unpack content the registry has already blocked as
+	// malware; hand the flag back before anything touches the disk.
+	if result.IsMalwareBlocked {
+		return result, nil
+	}
+
+	// Step 3: Download ZIP to temp file (streams in ~32KB chunks).
+	u, err := url.Parse(c.baseURL + c.downloadPath)
+	if err != nil {
+		return nil, fmt.Errorf("invalid base URL: %w", err)
+	}
+
+	q := u.Query()
+	q.Set("slug", slug)
+	// "latest" is a local placeholder; the version parameter is omitted then.
+	if installVersion != "latest" {
+		q.Set("version", installVersion)
+	}
+	u.RawQuery = q.Encode()
+
+	req, err := http.NewRequestWithContext(ctx, "GET", u.String(), nil)
+	if err != nil {
+		return nil, fmt.Errorf("failed to create request: %w", err)
+	}
+	if c.authToken != "" {
+		req.Header.Set("Authorization", "Bearer "+c.authToken)
+	}
+
+	tmpPath, err := utils.DownloadToFile(ctx, c.client, req, int64(c.maxZipSize))
+	if err != nil {
+		return nil, fmt.Errorf("download failed: %w", err)
+	}
+	defer os.Remove(tmpPath)
+
+	// Step 4: Extract from file on disk.
+	if err := utils.ExtractZipFile(tmpPath, targetDir); err != nil {
+		return nil, fmt.Errorf("extract failed: %w", err)
+	}
+
+	return result, nil
+}
+
+// --- HTTP helper ---
+
+// doGet performs a JSON GET request (with the bearer token when one is
+// configured) and returns the response body. It fails on a non-2xx status and
+// on a body larger than maxResponseSize, instead of handing a silently
+// truncated body to the JSON decoder.
+func (c *ClawHubRegistry) doGet(ctx context.Context, urlStr string) ([]byte, error) {
+	// Upper bound on how much of an error body is echoed into the error text.
+	const maxErrorBody = 512
+
+	req, err := http.NewRequestWithContext(ctx, "GET", urlStr, nil)
+	if err != nil {
+		return nil, err
+	}
+
+	req.Header.Set("Accept", "application/json")
+	if c.authToken != "" {
+		req.Header.Set("Authorization", "Bearer "+c.authToken)
+	}
+
+	resp, err := c.client.Do(req)
+	if err != nil {
+		return nil, err
+	}
+	defer resp.Body.Close()
+
+	// Limit response body read to prevent memory issues. One extra byte is
+	// requested so that an oversized body can be told apart from one that is
+	// exactly at the limit.
+	limit := int64(c.maxResponseSize)
+	body, err := io.ReadAll(io.LimitReader(resp.Body, limit+1))
+	if err != nil {
+		return nil, fmt.Errorf("failed to read response: %w", err)
+	}
+
+	if resp.StatusCode < 200 || resp.StatusCode >= 300 {
+		snippet := body
+		if len(snippet) > maxErrorBody {
+			snippet = snippet[:maxErrorBody]
+		}
+		return nil, fmt.Errorf("HTTP %d: %s", resp.StatusCode, string(snippet))
+	}
+
+	if int64(len(body)) > limit {
+		return nil, fmt.Errorf("response body exceeds %d bytes", limit)
+	}
+
+	return body, nil
+}
diff --git a/pkg/skills/clawhub_registry_test.go b/pkg/skills/clawhub_registry_test.go
new file mode 100644
index 000000000..d12e19504
--- /dev/null
+++ b/pkg/skills/clawhub_registry_test.go
@@ -0,0 +1,256 @@
+package skills
+
+import (
+	"archive/zip"
+	"bytes"
+	"context"
+	"encoding/json"
+	"net/http"
+	"net/http/httptest"
+	"os"
+	"path/filepath"
+	"testing"
+
+	"github.com/sipeed/picoclaw/pkg/utils"
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
+)
+
+func newTestRegistry(serverURL, authToken string) *ClawHubRegistry {
+	return NewClawHubRegistry(ClawHubConfig{
+		Enabled:   true,
+		BaseURL:   serverURL,
+		AuthToken: authToken,
+	})
+}
+
+func TestClawHubRegistrySearch(t *testing.T) {
+	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		assert.Equal(t, "/api/v1/search", r.URL.Path)
+		assert.Equal(t, "github", r.URL.Query().Get("q"))
+
+		slug := "github"
+		name := "GitHub Integration"
+		summary := "Interact with GitHub repos"
+		version := "1.0.0"
+
+		json.NewEncoder(w).Encode(clawhubSearchResponse{
+			Results: []clawhubSearchResult{
+				{Score: 0.95, Slug: &slug, DisplayName: &name, Summary: &summary, Version: &version},
+			},
+		})
+	}))
+	defer srv.Close()
+
+	reg := newTestRegistry(srv.URL, "")
+	results, err := reg.Search(context.Background(), "github", 5)
+
+	require.NoError(t, err)
+	require.Len(t, results, 1)
+	assert.Equal(t, "github", results[0].Slug)
+	assert.Equal(t, "GitHub Integration", results[0].DisplayName)
+	assert.InDelta(t, 0.95, results[0].Score, 0.001)
+	assert.Equal(t, "clawhub", results[0].RegistryName)
+}
+
+func TestClawHubRegistryGetSkillMeta(t *testing.T) {
+	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		assert.Equal(t, "/api/v1/skills/github", r.URL.Path)
+
+		json.NewEncoder(w).Encode(clawhubSkillResponse{
+			Slug:        "github",
+			DisplayName: "GitHub Integration",
+			Summary:     "Full GitHub API integration",
+			LatestVersion: &clawhubVersionInfo{
+				Version: "2.1.0",
+			},
+			Moderation: &clawhubModerationInfo{
+				IsMalwareBlocked: false,
+				IsSuspicious:     true,
+			},
+		})
+	}))
+	defer srv.Close()
+
+	reg := newTestRegistry(srv.URL, "")
+	meta, err := reg.GetSkillMeta(context.Background(), "github")
+
+	require.NoError(t, err)
+	assert.Equal(t, "github", meta.Slug)
+	assert.Equal(t, "2.1.0", meta.LatestVersion)
+	assert.False(t, meta.IsMalwareBlocked)
+	assert.True(t, meta.IsSuspicious)
+}
+
+func TestClawHubRegistryGetSkillMetaUnsafeSlug(t *testing.T) {
+	reg := newTestRegistry("https://example.com", "")
+	_, err := reg.GetSkillMeta(context.Background(), "../etc/passwd")
+	assert.Error(t, err)
+	assert.Contains(t, err.Error(), "invalid slug")
+}
+
+func TestClawHubRegistryDownloadAndInstall(t *testing.T) {
+	// Create a valid ZIP in memory.
+	zipBuf := createTestZip(t, map[string]string{
+		"SKILL.md":  "---\nname: test-skill\ndescription: A test\n---\nHello skill",
+		"README.md": "# Test Skill\n",
+	})
+
+	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		switch r.URL.Path {
+		case "/api/v1/skills/test-skill":
+			// Metadata endpoint.
+			json.NewEncoder(w).Encode(clawhubSkillResponse{
+				Slug:          "test-skill",
+				DisplayName:   "Test Skill",
+				Summary:       "A test skill",
+				LatestVersion: &clawhubVersionInfo{Version: "1.0.0"},
+			})
+		case "/api/v1/download":
+			assert.Equal(t, "test-skill", r.URL.Query().Get("slug"))
+			w.Header().Set("Content-Type", "application/zip")
+			w.Write(zipBuf)
+		default:
+			w.WriteHeader(http.StatusNotFound)
+		}
+	}))
+	defer srv.Close()
+
+	tmpDir := t.TempDir()
+	targetDir := filepath.Join(tmpDir, "test-skill")
+
+	reg := newTestRegistry(srv.URL, "")
+	result, err := reg.DownloadAndInstall(context.Background(), "test-skill", "1.0.0", targetDir)
+
+	require.NoError(t, err)
+	assert.Equal(t, "1.0.0", result.Version)
+	assert.False(t, result.IsMalwareBlocked)
+
+	// Verify extracted files.
+	skillContent, err := os.ReadFile(filepath.Join(targetDir, "SKILL.md"))
+	require.NoError(t, err)
+	assert.Contains(t, string(skillContent), "Hello skill")
+
+	readmeContent, err := os.ReadFile(filepath.Join(targetDir, "README.md"))
+	require.NoError(t, err)
+	assert.Contains(t, string(readmeContent), "# Test Skill")
+}
+
+func TestClawHubRegistryAuthToken(t *testing.T) {
+	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		authHeader := r.Header.Get("Authorization")
+		assert.Equal(t, "Bearer test-token-123", authHeader)
+		json.NewEncoder(w).Encode(clawhubSearchResponse{Results: nil})
+	}))
+	defer srv.Close()
+
+	reg := newTestRegistry(srv.URL, "test-token-123")
+	_, _ = reg.Search(context.Background(), "test", 5)
+}
+
+func TestExtractZipPathTraversal(t *testing.T) {
+	// Create a ZIP with a path traversal entry.
+	var buf bytes.Buffer
+	zw := zip.NewWriter(&buf)
+
+	// Malicious entry trying to escape directory.
+	w, err := zw.Create("../../etc/passwd")
+	require.NoError(t, err)
+	w.Write([]byte("malicious"))
+
+	zw.Close()
+
+	// Write to temp file for extractZipFile.
+	tmpZip := filepath.Join(t.TempDir(), "bad.zip")
+	require.NoError(t, os.WriteFile(tmpZip, buf.Bytes(), 0644))
+
+	tmpDir := t.TempDir()
+	err = utils.ExtractZipFile(tmpZip, tmpDir)
+	assert.Error(t, err)
+	assert.Contains(t, err.Error(), "unsafe path")
+}
+
+func TestExtractZipWithSubdirectories(t *testing.T) {
+	zipBuf := createTestZip(t, map[string]string{
+		"SKILL.md":           "root file",
+		"scripts/helper.sh":  "#!/bin/bash\necho hello",
+		"examples/demo.yaml": "key: value",
+	})
+
+	// Write to temp file for extractZipFile.
+	tmpZip := filepath.Join(t.TempDir(), "test.zip")
+	require.NoError(t, os.WriteFile(tmpZip, zipBuf, 0644))
+
+	tmpDir := t.TempDir()
+	targetDir := filepath.Join(tmpDir, "my-skill")
+
+	err := utils.ExtractZipFile(tmpZip, targetDir)
+	require.NoError(t, err)
+
+	// Verify nested file.
+	data, err := os.ReadFile(filepath.Join(targetDir, "scripts", "helper.sh"))
+	require.NoError(t, err)
+	assert.Contains(t, string(data), "#!/bin/bash")
+}
+
+func TestClawHubRegistrySearchHTTPError(t *testing.T) {
+	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		w.WriteHeader(http.StatusInternalServerError)
+		w.Write([]byte("Internal Server Error"))
+	}))
+	defer srv.Close()
+
+	reg := newTestRegistry(srv.URL, "")
+	_, err := reg.Search(context.Background(), "test", 5)
+	assert.Error(t, err)
+	assert.Contains(t, err.Error(), "500")
+}
+
+func TestClawHubRegistrySearchNullableFields(t *testing.T) {
+	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		validSlug := "valid-slug"
+		validSummary := "valid summary"
+
+		// Return results with various null/empty fields
+		json.NewEncoder(w).Encode(clawhubSearchResponse{
+			Results: []clawhubSearchResult{
+				// Case 1: Null Slug -> Skip
+				{Score: 0.1, Slug: nil, DisplayName: nil, Summary: nil, Version: nil},
+				// Case 2: Valid Slug, Null Summary -> Skip
+				{Score: 0.2, Slug: &validSlug, DisplayName: nil, Summary: nil, Version: nil},
+				// Case 3: Valid Slug, Valid Summary, Null Name -> Keep, Name=Slug
+				{Score: 0.8, Slug: &validSlug, DisplayName: nil, Summary: &validSummary, Version: nil},
+			},
+		})
+	}))
+	defer srv.Close()
+
+	reg := newTestRegistry(srv.URL, "")
+	results, err := reg.Search(context.Background(), "test", 5)
+
+	require.NoError(t, err)
+	require.Len(t, results, 1, "should only return 1 valid result")
+
+	r := results[0]
+	assert.Equal(t, "valid-slug", r.Slug)
+	assert.Equal(t, "valid-slug", r.DisplayName, "should fallback name to slug")
+	assert.Equal(t, "valid summary", r.Summary)
+}
+
+// --- helpers ---
+
+func createTestZip(t *testing.T, files map[string]string) []byte {
+	t.Helper()
+	var buf bytes.Buffer
+	zw := zip.NewWriter(&buf)
+
+	for name, content := range files {
+		w, err := zw.Create(name)
+		require.NoError(t, err)
+		_, err = w.Write([]byte(content))
+		require.NoError(t, err)
+	}
+
+	require.NoError(t, zw.Close())
+	return buf.Bytes()
+}
diff --git a/pkg/skills/registry.go b/pkg/skills/registry.go
new file mode 100644
index 000000000..45ae72253
--- /dev/null
+++ b/pkg/skills/registry.go
@@ -0,0 +1,223 @@
+package skills
+
+import (
+	"context"
+	"fmt"
+	"log/slog"
+	"sync"
+	"time"
+)
+
+const (
+	defaultMaxConcurrentSearches = 2
+)
+
+// SearchResult represents a single result from a skill registry search.
+type SearchResult struct {
+	Score        float64 `json:"score"`
+	Slug         string  `json:"slug"`
+	DisplayName  string  `json:"display_name"`
+	Summary      string  `json:"summary"`
+	Version      string  `json:"version"`
+	RegistryName string  `json:"registry_name"`
+}
+
+// SkillMeta holds metadata about a skill from a registry.
+type SkillMeta struct {
+	Slug             string `json:"slug"`
+	DisplayName      string `json:"display_name"`
+	Summary          string `json:"summary"`
+	LatestVersion    string `json:"latest_version"`
+	IsMalwareBlocked bool   `json:"is_malware_blocked"`
+	IsSuspicious     bool   `json:"is_suspicious"`
+	RegistryName     string `json:"registry_name"`
+}
+
+// InstallResult is what DownloadAndInstall hands back to its caller: the
+// version that was resolved, the moderation flags, and a summary string the
+// caller can use when reporting the outcome to the user.
+type InstallResult struct {
+	Version          string
+	IsMalwareBlocked bool
+	IsSuspicious     bool
+	Summary          string
+}
+
+// SkillRegistry is the contract every skill source (e.g., clawhub.ai) has to
+// satisfy in order to be plugged into a RegistryManager.
+type SkillRegistry interface {
+	// Name reports the registry's unique name (e.g., "clawhub").
+	Name() string
+	// Search looks up skills in the registry that match query.
+	Search(ctx context.Context, query string, limit int) ([]SearchResult, error)
+	// GetSkillMeta returns the metadata of the skill identified by slug.
+	GetSkillMeta(ctx context.Context, slug string) (*SkillMeta, error)
+	// DownloadAndInstall fetches metadata, resolves the version, then
+	// downloads and installs the skill into targetDir. The InstallResult
+	// carries the metadata the caller needs for moderation and messaging.
+	DownloadAndInstall(ctx context.Context, slug, version, targetDir string) (*InstallResult, error)
+}
+
+// RegistryConfig bundles the settings of all skill registries.
+// It is the argument of NewRegistryManagerFromConfig.
+type RegistryConfig struct {
+	ClawHub               ClawHubConfig
+	MaxConcurrentSearches int
+}
+
+// ClawHubConfig holds the settings of the ClawHub registry.
+type ClawHubConfig struct {
+	Enabled         bool
+	BaseURL         string
+	AuthToken       string
+	SearchPath      string // e.g. "/api/v1/search"
+	SkillsPath      string // e.g. "/api/v1/skills"
+	DownloadPath    string // e.g. "/api/v1/download"
+	Timeout         int    // seconds, 0 = default (30s)
+	MaxZipSize      int    // bytes, 0 = default (50MB)
+	MaxResponseSize int    // bytes, 0 = default (2MB)
+}
+
+// RegistryManager owns a list of skill registries. Searches are fanned out
+// across the list and installs are routed to a registry looked up by name.
+type RegistryManager struct {
+	registries    []SkillRegistry
+	maxConcurrent int
+	mu            sync.RWMutex // guards registries
+}
+
+// NewRegistryManager returns a manager that holds no registries yet and uses
+// the default search concurrency.
+func NewRegistryManager() *RegistryManager {
+	mgr := &RegistryManager{maxConcurrent: defaultMaxConcurrentSearches}
+	mgr.registries = []SkillRegistry{}
+	return mgr
+}
+
+// NewRegistryManagerFromConfig creates a manager from cfg. A positive
+// MaxConcurrentSearches replaces the default concurrency, and a registry is
+// only instantiated when its Enabled flag is set.
+func NewRegistryManagerFromConfig(cfg RegistryConfig) *RegistryManager {
+	mgr := NewRegistryManager()
+	if n := cfg.MaxConcurrentSearches; n > 0 {
+		mgr.maxConcurrent = n
+	}
+	if !cfg.ClawHub.Enabled {
+		return mgr
+	}
+	mgr.AddRegistry(NewClawHubRegistry(cfg.ClawHub))
+	return mgr
+}
+
+// AddRegistry appends r to the manager's registry list under the write lock.
+func (rm *RegistryManager) AddRegistry(r SkillRegistry) {
+	rm.mu.Lock()
+	rm.registries = append(rm.registries, r)
+	rm.mu.Unlock()
+}
+
+// GetRegistry looks up a registry whose Name() equals name. The first match in
+// insertion order wins; nil is returned when none matches.
+func (rm *RegistryManager) GetRegistry(name string) SkillRegistry {
+	rm.mu.RLock()
+	defer rm.mu.RUnlock()
+	for i := range rm.registries {
+		if candidate := rm.registries[i]; candidate.Name() == name {
+			return candidate
+		}
+	}
+	return nil
+}
+
+// SearchAll fans out the query to all registries concurrently
+// and merges results sorted by score descending.
+func (rm *RegistryManager) SearchAll(ctx context.Context, query string, limit int) ([]SearchResult, error) {
+	// Work on a snapshot so the lock is not held while registries are queried.
+	rm.mu.RLock()
+	regs := make([]SkillRegistry, len(rm.registries))
+	copy(regs, rm.registries)
+	rm.mu.RUnlock()
+
+	if len(regs) == 0 {
+		return nil, fmt.Errorf("no registries configured")
+	}
+
+	// A manager that was not built by a constructor has maxConcurrent == 0,
+	// which would yield an unbuffered semaphore that no worker can acquire.
+	maxConcurrent := rm.maxConcurrent
+	if maxConcurrent <= 0 {
+		maxConcurrent = defaultMaxConcurrentSearches
+	}
+
+	type regResult struct {
+		results []SearchResult
+		err     error
+	}
+
+	// One slot per registry, written only by that registry's goroutine and
+	// read only after wg.Wait(). Merging by slot index keeps the order of
+	// equal-score results (and the reported error) independent of which
+	// goroutine happens to finish first.
+	outcomes := make([]regResult, len(regs))
+
+	// Semaphore: limit concurrency.
+	sem := make(chan struct{}, maxConcurrent)
+
+	var wg sync.WaitGroup
+	for i, reg := range regs {
+		wg.Add(1)
+		go func(i int, r SkillRegistry) {
+			defer wg.Done()
+
+			// Acquire semaphore slot, giving up if the caller's context ends
+			// while waiting.
+			select {
+			case sem <- struct{}{}:
+				defer func() { <-sem }()
+			case <-ctx.Done():
+				outcomes[i].err = ctx.Err()
+				return
+			}
+
+			// Each registry gets its own one-minute budget within ctx.
+			searchCtx, cancel := context.WithTimeout(ctx, 1*time.Minute)
+			defer cancel()
+
+			results, err := r.Search(searchCtx, query, limit)
+			if err != nil {
+				slog.Warn("registry search failed", "registry", r.Name(), "error", err)
+				outcomes[i].err = err
+				return
+			}
+			outcomes[i].results = results
+		}(i, reg)
+	}
+	wg.Wait()
+
+	var merged []SearchResult
+	var lastErr error
+
+	var anyRegistrySucceeded bool
+	for _, rr := range outcomes {
+		if rr.err != nil {
+			lastErr = rr.err
+			continue
+		}
+		// A registry that answered with zero hits still counts as a success.
+		anyRegistrySucceeded = true
+		merged = append(merged, rr.results...)
+	}
+
+	// If all registries failed, return the last error.
+	if !anyRegistrySucceeded && lastErr != nil {
+		return nil, fmt.Errorf("all registries failed: %w", lastErr)
+	}
+
+	// Sort by score descending.
+	sortByScoreDesc(merged)
+
+	// Clamp to limit.
+	if limit > 0 && len(merged) > limit {
+		merged = merged[:limit]
+	}
+
+	return merged, nil
+}
+
+// sortByScoreDesc sorts SearchResults by Score in descending order (insertion sort — small slices).
+func sortByScoreDesc(results []SearchResult) { + for i := 1; i < len(results); i++ { + key := results[i] + j := i - 1 + for j >= 0 && results[j].Score < key.Score { + results[j+1] = results[j] + j-- + } + results[j+1] = key + } +} diff --git a/pkg/skills/registry_test.go b/pkg/skills/registry_test.go new file mode 100644 index 000000000..daecd5a59 --- /dev/null +++ b/pkg/skills/registry_test.go @@ -0,0 +1,179 @@ +package skills + +import ( + "context" + "fmt" + "testing" + "time" + + "github.com/sipeed/picoclaw/pkg/utils" + "github.com/stretchr/testify/assert" +) + +// mockRegistry is a test double implementing SkillRegistry. +type mockRegistry struct { + name string + searchResults []SearchResult + searchErr error + meta *SkillMeta + metaErr error + installResult *InstallResult + installErr error +} + +func (m *mockRegistry) Name() string { return m.name } + +func (m *mockRegistry) Search(_ context.Context, _ string, _ int) ([]SearchResult, error) { + return m.searchResults, m.searchErr +} + +func (m *mockRegistry) GetSkillMeta(_ context.Context, _ string) (*SkillMeta, error) { + return m.meta, m.metaErr +} + +func (m *mockRegistry) DownloadAndInstall(_ context.Context, _, _, _ string) (*InstallResult, error) { + return m.installResult, m.installErr +} + +func TestRegistryManagerSearchAllSingle(t *testing.T) { + mgr := NewRegistryManager() + mgr.AddRegistry(&mockRegistry{ + name: "test", + searchResults: []SearchResult{ + {Slug: "skill-a", Score: 0.9, RegistryName: "test"}, + {Slug: "skill-b", Score: 0.5, RegistryName: "test"}, + }, + }) + + results, err := mgr.SearchAll(context.Background(), "test query", 10) + assert.NoError(t, err) + assert.Len(t, results, 2) + assert.Equal(t, "skill-a", results[0].Slug) +} + +func TestRegistryManagerSearchAllMultiple(t *testing.T) { + mgr := NewRegistryManager() + mgr.AddRegistry(&mockRegistry{ + name: "alpha", + searchResults: []SearchResult{ + {Slug: "skill-a", Score: 0.8, RegistryName: "alpha"}, + }, + }) + 
mgr.AddRegistry(&mockRegistry{ + name: "beta", + searchResults: []SearchResult{ + {Slug: "skill-b", Score: 0.95, RegistryName: "beta"}, + }, + }) + + results, err := mgr.SearchAll(context.Background(), "test query", 10) + assert.NoError(t, err) + assert.Len(t, results, 2) + // Should be sorted by score descending + assert.Equal(t, "skill-b", results[0].Slug) + assert.Equal(t, "skill-a", results[1].Slug) +} + +func TestRegistryManagerSearchAllOneFailsGracefully(t *testing.T) { + mgr := NewRegistryManager() + mgr.AddRegistry(&mockRegistry{ + name: "failing", + searchErr: fmt.Errorf("network error"), + }) + mgr.AddRegistry(&mockRegistry{ + name: "working", + searchResults: []SearchResult{ + {Slug: "skill-a", Score: 0.8, RegistryName: "working"}, + }, + }) + + results, err := mgr.SearchAll(context.Background(), "test query", 10) + assert.NoError(t, err) + assert.Len(t, results, 1) + assert.Equal(t, "skill-a", results[0].Slug) +} + +func TestRegistryManagerSearchAllAllFail(t *testing.T) { + mgr := NewRegistryManager() + mgr.AddRegistry(&mockRegistry{ + name: "fail-1", + searchErr: fmt.Errorf("error 1"), + }) + + _, err := mgr.SearchAll(context.Background(), "test query", 10) + assert.Error(t, err) +} + +func TestRegistryManagerSearchAllNoRegistries(t *testing.T) { + mgr := NewRegistryManager() + _, err := mgr.SearchAll(context.Background(), "test query", 10) + assert.Error(t, err) +} + +func TestRegistryManagerGetRegistry(t *testing.T) { + mgr := NewRegistryManager() + mock := &mockRegistry{name: "clawhub"} + mgr.AddRegistry(mock) + + got := mgr.GetRegistry("clawhub") + assert.NotNil(t, got) + assert.Equal(t, "clawhub", got.Name()) + + got = mgr.GetRegistry("nonexistent") + assert.Nil(t, got) +} + +func TestRegistryManagerSearchAllRespectLimit(t *testing.T) { + mgr := NewRegistryManager() + results := make([]SearchResult, 20) + for i := range results { + results[i] = SearchResult{Slug: fmt.Sprintf("skill-%d", i), Score: float64(20 - i)} + } + 
mgr.AddRegistry(&mockRegistry{ + name: "test", + searchResults: results, + }) + + got, err := mgr.SearchAll(context.Background(), "test", 5) + assert.NoError(t, err) + assert.Len(t, got, 5) + // Top scores first + assert.Equal(t, "skill-0", got[0].Slug) +} + +func TestRegistryManagerSearchAllTimeout(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 1*time.Millisecond) + defer cancel() + + time.Sleep(5 * time.Millisecond) // Let context expire. + + mgr := NewRegistryManager() + mgr.AddRegistry(&mockRegistry{ + name: "slow", + searchErr: fmt.Errorf("context deadline exceeded"), + }) + + _, err := mgr.SearchAll(ctx, "test", 5) + assert.Error(t, err) +} + +func TestSortByScoreDesc(t *testing.T) { + results := []SearchResult{ + {Slug: "c", Score: 0.3}, + {Slug: "a", Score: 0.9}, + {Slug: "b", Score: 0.5}, + } + sortByScoreDesc(results) + assert.Equal(t, "a", results[0].Slug) + assert.Equal(t, "b", results[1].Slug) + assert.Equal(t, "c", results[2].Slug) +} + +func TestIsSafeSlug(t *testing.T) { + assert.NoError(t, utils.ValidateSkillIdentifier("github")) + assert.NoError(t, utils.ValidateSkillIdentifier("docker-compose")) + assert.Error(t, utils.ValidateSkillIdentifier("")) + assert.Error(t, utils.ValidateSkillIdentifier("../etc/passwd")) + assert.Error(t, utils.ValidateSkillIdentifier("path/traversal")) + assert.Error(t, utils.ValidateSkillIdentifier("path\\traversal")) +} diff --git a/pkg/skills/search_cache.go b/pkg/skills/search_cache.go new file mode 100644 index 000000000..5d7d2797e --- /dev/null +++ b/pkg/skills/search_cache.go @@ -0,0 +1,229 @@ +package skills + +import ( + "sort" + "strings" + "sync" + "time" +) + +// SearchCache provides lightweight caching for search results. +// It uses trigram-based similarity to match similar queries to cached results, +// avoiding redundant API calls. Thread-safe for concurrent access. 
+type SearchCache struct { + mu sync.RWMutex + entries map[string]*cacheEntry + order []string // LRU order: oldest first. + maxEntries int + ttl time.Duration +} + +type cacheEntry struct { + query string + trigrams []uint32 + results []SearchResult + createdAt time.Time +} + +// similarityThreshold is the minimum trigram Jaccard similarity for a cache hit. +const similarityThreshold = 0.7 + +// NewSearchCache creates a new search cache. +// maxEntries is the maximum number of cached queries (excess evicts LRU). +// ttl is how long each entry lives before expiration. +func NewSearchCache(maxEntries int, ttl time.Duration) *SearchCache { + if maxEntries <= 0 { + maxEntries = 50 + } + if ttl <= 0 { + ttl = 5 * time.Minute + } + return &SearchCache{ + entries: make(map[string]*cacheEntry), + order: make([]string, 0), + maxEntries: maxEntries, + ttl: ttl, + } +} + +// Get looks up results for a query. Returns cached results and true if found +// (either exact or similar match above threshold). Returns nil, false on miss. +func (sc *SearchCache) Get(query string) ([]SearchResult, bool) { + normalized := normalizeQuery(query) + if normalized == "" { + return nil, false + } + + sc.mu.Lock() + defer sc.mu.Unlock() + + // Exact match first. + if entry, ok := sc.entries[normalized]; ok { + if time.Since(entry.createdAt) < sc.ttl { + sc.moveToEndLocked(normalized) + return copyResults(entry.results), true + } + } + + // Similarity match. + queryTrigrams := buildTrigrams(normalized) + var bestEntry *cacheEntry + var bestSim float64 + + for _, entry := range sc.entries { + if time.Since(entry.createdAt) >= sc.ttl { + continue // Skip expired. + } + sim := jaccardSimilarity(queryTrigrams, entry.trigrams) + if sim > bestSim { + bestSim = sim + bestEntry = entry + } + } + + if bestSim >= similarityThreshold && bestEntry != nil { + sc.moveToEndLocked(bestEntry.query) + return copyResults(bestEntry.results), true + } + + return nil, false +} + +// Put stores results for a query. 
Evicts the oldest entry if at capacity. +func (sc *SearchCache) Put(query string, results []SearchResult) { + normalized := normalizeQuery(query) + if normalized == "" { + return + } + + sc.mu.Lock() + defer sc.mu.Unlock() + + // Evict expired entries first. + sc.evictExpiredLocked() + + // If already exists, update. + if _, ok := sc.entries[normalized]; ok { + sc.entries[normalized] = &cacheEntry{ + query: normalized, + trigrams: buildTrigrams(normalized), + results: copyResults(results), + createdAt: time.Now(), + } + // Move to end of LRU order. + sc.moveToEndLocked(normalized) + return + } + + // Evict LRU if at capacity. + for len(sc.entries) >= sc.maxEntries && len(sc.order) > 0 { + oldest := sc.order[0] + sc.order = sc.order[1:] + delete(sc.entries, oldest) + } + + // Insert new entry. + sc.entries[normalized] = &cacheEntry{ + query: normalized, + trigrams: buildTrigrams(normalized), + results: copyResults(results), + createdAt: time.Now(), + } + sc.order = append(sc.order, normalized) +} + +// Len returns the number of entries (for testing). +func (sc *SearchCache) Len() int { + sc.mu.RLock() + defer sc.mu.RUnlock() + return len(sc.entries) +} + +// --- internal --- + +func (sc *SearchCache) evictExpiredLocked() { + now := time.Now() + newOrder := make([]string, 0, len(sc.order)) + for _, key := range sc.order { + entry, ok := sc.entries[key] + if !ok || now.Sub(entry.createdAt) >= sc.ttl { + delete(sc.entries, key) + continue + } + newOrder = append(newOrder, key) + } + sc.order = newOrder +} + +func (sc *SearchCache) moveToEndLocked(key string) { + for i, k := range sc.order { + if k == key { + sc.order = append(sc.order[:i], sc.order[i+1:]...) + break + } + } + sc.order = append(sc.order, key) +} + +func normalizeQuery(q string) string { + return strings.ToLower(strings.TrimSpace(q)) +} + +// buildTrigrams generates hash of trigrams from a string. 
+// Example: "hello" → {"hel", "ell", "llo"} +// "hel" -> 0x0068656c -> 4 bytes; compared to 16 bytes of a string +func buildTrigrams(s string) []uint32 { + if len(s) < 3 { + return nil + } + + trigrams := make([]uint32, 0, len(s)-2) + for i := 0; i <= len(s)-3; i++ { + trigrams = append(trigrams, uint32(s[i])<<16|uint32(s[i+1])<<8|uint32(s[i+2])) + } + + // Sort and Deduplication + sort.Slice(trigrams, func(i, j int) bool { return trigrams[i] < trigrams[j] }) + n := 1 + for i := 1; i < len(trigrams); i++ { + if trigrams[i] != trigrams[i-1] { + trigrams[n] = trigrams[i] + n++ + } + } + + return trigrams[:n] +} + +// jaccardSimilarity computes |A ∩ B| / |A ∪ B|. +func jaccardSimilarity(a, b []uint32) float64 { + if len(a) == 0 && len(b) == 0 { + return 1 + } + i, j := 0, 0 + intersection := 0 + + for i < len(a) && j < len(b) { + if a[i] == b[j] { + intersection++ + i++ + j++ + } else if a[i] < b[j] { + i++ + } else { + j++ + } + } + + union := len(a) + len(b) - intersection + return float64(intersection) / float64(union) +} + +func copyResults(results []SearchResult) []SearchResult { + if results == nil { + return nil + } + cp := make([]SearchResult, len(results)) + copy(cp, results) + return cp +} diff --git a/pkg/skills/search_cache_test.go b/pkg/skills/search_cache_test.go new file mode 100644 index 000000000..816bdfb93 --- /dev/null +++ b/pkg/skills/search_cache_test.go @@ -0,0 +1,200 @@ +package skills + +import ( + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestSearchCacheExactHit(t *testing.T) { + cache := NewSearchCache(10, 5*time.Minute) + + results := []SearchResult{ + {Slug: "github", Score: 0.9, RegistryName: "clawhub"}, + {Slug: "docker", Score: 0.7, RegistryName: "clawhub"}, + } + cache.Put("github integration", results) + + got, hit := cache.Get("github integration") + assert.True(t, hit) + assert.Len(t, got, 2) + assert.Equal(t, "github", got[0].Slug) +} + +func TestSearchCacheExactHitCaseInsensitive(t *testing.T) { + 
cache := NewSearchCache(10, 5*time.Minute) + + results := []SearchResult{{Slug: "github", Score: 0.9}} + cache.Put("GitHub Integration", results) + + got, hit := cache.Get("github integration") + assert.True(t, hit) + assert.Len(t, got, 1) +} + +func TestSearchCacheSimilarHit(t *testing.T) { + cache := NewSearchCache(10, 5*time.Minute) + + results := []SearchResult{{Slug: "github", Score: 0.9}} + cache.Put("github integration tool", results) + + // "github integration" is very similar to "github integration tool" + got, hit := cache.Get("github integration") + assert.True(t, hit) + assert.Len(t, got, 1) +} + +func TestSearchCacheDissimilarMiss(t *testing.T) { + cache := NewSearchCache(10, 5*time.Minute) + + results := []SearchResult{{Slug: "github", Score: 0.9}} + cache.Put("github integration", results) + + // Completely unrelated query + _, hit := cache.Get("database management") + assert.False(t, hit) +} + +func TestSearchCacheTTLExpiration(t *testing.T) { + cache := NewSearchCache(10, 50*time.Millisecond) + + results := []SearchResult{{Slug: "github", Score: 0.9}} + cache.Put("github integration", results) + + // Immediately should hit + _, hit := cache.Get("github integration") + assert.True(t, hit) + + // Wait for expiration + time.Sleep(100 * time.Millisecond) + + _, hit = cache.Get("github integration") + assert.False(t, hit) +} + +func TestSearchCacheLRUEviction(t *testing.T) { + cache := NewSearchCache(3, 5*time.Minute) + + cache.Put("query-1", []SearchResult{{Slug: "a"}}) + cache.Put("query-2", []SearchResult{{Slug: "b"}}) + cache.Put("query-3", []SearchResult{{Slug: "c"}}) + + assert.Equal(t, 3, cache.Len()) + + // Adding a 4th should evict query-1 (oldest) + cache.Put("query-4", []SearchResult{{Slug: "d"}}) + assert.Equal(t, 3, cache.Len()) + + _, hit := cache.Get("query-1") + assert.False(t, hit, "oldest entry should be evicted") + + got, hit := cache.Get("query-4") + assert.True(t, hit) + assert.Equal(t, "d", got[0].Slug) +} + +func 
TestSearchCacheEmptyQuery(t *testing.T) { + cache := NewSearchCache(10, 5*time.Minute) + + _, hit := cache.Get("") + assert.False(t, hit) + + _, hit = cache.Get(" ") + assert.False(t, hit) +} + +func TestSearchCacheResultsCopied(t *testing.T) { + cache := NewSearchCache(10, 5*time.Minute) + + original := []SearchResult{{Slug: "github", Score: 0.9}} + cache.Put("test", original) + + // Mutate original after putting + original[0].Slug = "mutated" + + got, hit := cache.Get("test") + assert.True(t, hit) + assert.Equal(t, "github", got[0].Slug, "cache should hold a copy, not a reference") +} + +func TestBuildTrigrams(t *testing.T) { + trigrams := buildTrigrams("hello") + assert.Contains(t, trigrams, uint32('h')<<16|uint32('e')<<8|uint32('l')) + assert.Contains(t, trigrams, uint32('e')<<16|uint32('l')<<8|uint32('l')) + assert.Contains(t, trigrams, uint32('l')<<16|uint32('l')<<8|uint32('o')) + assert.Len(t, trigrams, 3) +} + +func TestJaccardSimilarity(t *testing.T) { + a := buildTrigrams("github integration") + b := buildTrigrams("github integration tool") + + sim := jaccardSimilarity(a, b) + assert.Greater(t, sim, 0.5, "similar strings should have high sim") + + c := buildTrigrams("completely different query about databases") + sim2 := jaccardSimilarity(a, c) + assert.Less(t, sim2, 0.3, "dissimilar strings should have low sim") +} + +func TestJaccardSimilarityEdgeCases(t *testing.T) { + empty := buildTrigrams("") + nonempty := buildTrigrams("hello") + + assert.Equal(t, 1.0, jaccardSimilarity(empty, empty)) + assert.Equal(t, 0.0, jaccardSimilarity(empty, nonempty)) + assert.Equal(t, 0.0, jaccardSimilarity(nonempty, empty)) +} + +func TestSearchCacheConcurrency(t *testing.T) { + cache := NewSearchCache(50, 5*time.Minute) + done := make(chan struct{}) + + // Concurrent writes + go func() { + for i := 0; i < 100; i++ { + cache.Put("query-write-"+string(rune('a'+i%26)), []SearchResult{{Slug: "x"}}) + } + done <- struct{}{} + }() + + // Concurrent reads + go func() { + for i 
:= 0; i < 100; i++ { + cache.Get("query-write-a") + } + done <- struct{}{} + }() + + <-done +} + +func TestSearchCacheLRUUpdateOnGet(t *testing.T) { + // Capacity 3 + cache := NewSearchCache(3, time.Hour) + + // Fill cache: query-A, query-B, query-C + // Use longer strings to ensure trigrams are generated and avoid false positive similarity + cache.Put("query-A", []SearchResult{{Slug: "A"}}) + cache.Put("query-B", []SearchResult{{Slug: "B"}}) + cache.Put("query-C", []SearchResult{{Slug: "C"}}) + + // Access query-A (should make it most recently used) + if _, found := cache.Get("query-A"); !found { + t.Fatal("query-A should be in cache") + } + + // Add query-D. Should evict query-B (LRU) instead of query-A (which was refreshed) + cache.Put("query-D", []SearchResult{{Slug: "D"}}) + + // Check if query-A is still there + if _, found := cache.Get("query-A"); !found { + t.Fatalf("query-A was evicted! valid LRU should have kept query-A and evicted query-B.") + } + + // Check if query-B is evicted + if _, found := cache.Get("query-B"); found { + t.Fatal("query-B should have been evicted") + } +} diff --git a/pkg/tools/skills_install.go b/pkg/tools/skills_install.go new file mode 100644 index 000000000..6b05918ce --- /dev/null +++ b/pkg/tools/skills_install.go @@ -0,0 +1,199 @@ +package tools + +import ( + "context" + "encoding/json" + "fmt" + "os" + "path/filepath" + "sync" + "time" + + "github.com/sipeed/picoclaw/pkg/logger" + "github.com/sipeed/picoclaw/pkg/skills" + "github.com/sipeed/picoclaw/pkg/utils" +) + +// InstallSkillTool allows the LLM agent to install skills from registries. +// It shares the same RegistryManager that FindSkillsTool uses, +// so all registries configured in config are available for installation. +type InstallSkillTool struct { + registryMgr *skills.RegistryManager + workspace string + mu sync.Mutex +} + +// NewInstallSkillTool creates a new InstallSkillTool. +// registryMgr is the shared registry manager (same instance as FindSkillsTool). 
+// workspace is the root workspace directory; skills install to {workspace}/skills/{slug}/. +func NewInstallSkillTool(registryMgr *skills.RegistryManager, workspace string) *InstallSkillTool { + return &InstallSkillTool{ + registryMgr: registryMgr, + workspace: workspace, + mu: sync.Mutex{}, + } +} + +func (t *InstallSkillTool) Name() string { + return "install_skill" +} + +func (t *InstallSkillTool) Description() string { + return "Install a skill from a registry by slug. Downloads and extracts the skill into the workspace. Use find_skills first to discover available skills." +} + +func (t *InstallSkillTool) Parameters() map[string]interface{} { + return map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "slug": map[string]interface{}{ + "type": "string", + "description": "The unique slug of the skill to install (e.g., 'github', 'docker-compose')", + }, + "version": map[string]interface{}{ + "type": "string", + "description": "Specific version to install (optional, defaults to latest)", + }, + "registry": map[string]interface{}{ + "type": "string", + "description": "Registry to install from (required, e.g., 'clawhub')", + }, + "force": map[string]interface{}{ + "type": "boolean", + "description": "Force reinstall if skill already exists (default false)", + }, + }, + "required": []string{"slug", "registry"}, + } +} + +func (t *InstallSkillTool) Execute(ctx context.Context, args map[string]interface{}) *ToolResult { + // Install lock to prevent concurrent directory operations. + // Ideally this should be done at a `slug` level, currently, its at a `workspace` level. 
+ t.mu.Lock() + defer t.mu.Unlock() + + // Validate slug + slug, _ := args["slug"].(string) + if err := utils.ValidateSkillIdentifier(slug); err != nil { + return ErrorResult(fmt.Sprintf("invalid slug %q: error: %s", slug, err.Error())) + } + + // Validate registry + registryName, _ := args["registry"].(string) + if err := utils.ValidateSkillIdentifier(registryName); err != nil { + return ErrorResult(fmt.Sprintf("invalid registry %q: error: %s", registryName, err.Error())) + } + + version, _ := args["version"].(string) + force, _ := args["force"].(bool) + + // Check if already installed. + skillsDir := filepath.Join(t.workspace, "skills") + targetDir := filepath.Join(skillsDir, slug) + + if !force { + if _, err := os.Stat(targetDir); err == nil { + return ErrorResult(fmt.Sprintf("skill %q already installed at %s. Use force=true to reinstall.", slug, targetDir)) + } + } else { + // Force: remove existing if present. + os.RemoveAll(targetDir) + } + + // Resolve which registry to use. + registry := t.registryMgr.GetRegistry(registryName) + if registry == nil { + return ErrorResult(fmt.Sprintf("registry %q not found", registryName)) + } + + // Ensure skills directory exists. + if err := os.MkdirAll(skillsDir, 0755); err != nil { + return ErrorResult(fmt.Sprintf("failed to create skills directory: %v", err)) + } + + // Download and install (handles metadata, version resolution, extraction). + result, err := registry.DownloadAndInstall(ctx, slug, version, targetDir) + if err != nil { + // Clean up partial install. + rmErr := os.RemoveAll(targetDir) + if rmErr != nil { + logger.ErrorCF("tool", "Failed to remove partial install", + map[string]interface{}{ + "tool": "install_skill", + "target_dir": targetDir, + "error": rmErr.Error(), + }) + } + return ErrorResult(fmt.Sprintf("failed to install %q: %v", slug, err)) + } + + // Moderation: block malware. 
+ if result.IsMalwareBlocked { + rmErr := os.RemoveAll(targetDir) + if rmErr != nil { + logger.ErrorCF("tool", "Failed to remove partial install", + map[string]interface{}{ + "tool": "install_skill", + "target_dir": targetDir, + "error": rmErr.Error(), + }) + } + return ErrorResult(fmt.Sprintf("skill %q is flagged as malicious and cannot be installed", slug)) + } + + // Write origin metadata. + if err := writeOriginMeta(targetDir, registry.Name(), slug, result.Version); err != nil { + logger.ErrorCF("tool", "Failed to write origin metadata", + map[string]interface{}{ + "tool": "install_skill", + "error": err.Error(), + "target": targetDir, + "registry": registry.Name(), + "slug": slug, + "version": result.Version, + }) + _ = err + } + + // Build result with moderation warning if suspicious. + var output string + if result.IsSuspicious { + output = fmt.Sprintf("⚠️ Warning: skill %q is flagged as suspicious (may contain risky patterns).\n\n", slug) + } + output += fmt.Sprintf("Successfully installed skill %q v%s from %s registry.\nLocation: %s\n", + slug, result.Version, registry.Name(), targetDir) + + if result.Summary != "" { + output += fmt.Sprintf("Description: %s\n", result.Summary) + } + output += "\nThe skill is now available and can be loaded in the current session." + + return SilentResult(output) +} + +// originMeta tracks which registry a skill was installed from. 
+type originMeta struct { + Version int `json:"version"` + Registry string `json:"registry"` + Slug string `json:"slug"` + InstalledVersion string `json:"installed_version"` + InstalledAt int64 `json:"installed_at"` +} + +func writeOriginMeta(targetDir, registryName, slug, version string) error { + meta := originMeta{ + Version: 1, + Registry: registryName, + Slug: slug, + InstalledVersion: version, + InstalledAt: time.Now().UnixMilli(), + } + + data, err := json.MarshalIndent(meta, "", " ") + if err != nil { + return err + } + + return os.WriteFile(filepath.Join(targetDir, ".skill-origin.json"), data, 0644) +} diff --git a/pkg/tools/skills_install_test.go b/pkg/tools/skills_install_test.go new file mode 100644 index 000000000..e6941a950 --- /dev/null +++ b/pkg/tools/skills_install_test.go @@ -0,0 +1,103 @@ +package tools + +import ( + "context" + "os" + "path/filepath" + "testing" + + "github.com/sipeed/picoclaw/pkg/skills" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestInstallSkillToolName(t *testing.T) { + tool := NewInstallSkillTool(skills.NewRegistryManager(), t.TempDir()) + assert.Equal(t, "install_skill", tool.Name()) +} + +func TestInstallSkillToolMissingSlug(t *testing.T) { + tool := NewInstallSkillTool(skills.NewRegistryManager(), t.TempDir()) + result := tool.Execute(context.Background(), map[string]interface{}{}) + assert.True(t, result.IsError) + assert.Contains(t, result.ForLLM, "identifier is required and must be a non-empty string") +} + +func TestInstallSkillToolEmptySlug(t *testing.T) { + tool := NewInstallSkillTool(skills.NewRegistryManager(), t.TempDir()) + result := tool.Execute(context.Background(), map[string]interface{}{ + "slug": " ", + }) + assert.True(t, result.IsError) + assert.Contains(t, result.ForLLM, "identifier is required and must be a non-empty string") +} + +func TestInstallSkillToolUnsafeSlug(t *testing.T) { + tool := NewInstallSkillTool(skills.NewRegistryManager(), t.TempDir()) + + 
cases := []string{ + "../etc/passwd", + "path/traversal", + "path\\traversal", + } + + for _, slug := range cases { + result := tool.Execute(context.Background(), map[string]interface{}{ + "slug": slug, + }) + assert.True(t, result.IsError, "slug %q should be rejected", slug) + assert.Contains(t, result.ForLLM, "invalid slug") + } +} + +func TestInstallSkillToolAlreadyExists(t *testing.T) { + workspace := t.TempDir() + skillDir := filepath.Join(workspace, "skills", "existing-skill") + require.NoError(t, os.MkdirAll(skillDir, 0755)) + + tool := NewInstallSkillTool(skills.NewRegistryManager(), workspace) + result := tool.Execute(context.Background(), map[string]interface{}{ + "slug": "existing-skill", + "registry": "clawhub", + }) + assert.True(t, result.IsError) + assert.Contains(t, result.ForLLM, "already installed") +} + +func TestInstallSkillToolRegistryNotFound(t *testing.T) { + workspace := t.TempDir() + tool := NewInstallSkillTool(skills.NewRegistryManager(), workspace) + result := tool.Execute(context.Background(), map[string]interface{}{ + "slug": "some-skill", + "registry": "nonexistent", + }) + assert.True(t, result.IsError) + assert.Contains(t, result.ForLLM, "registry") + assert.Contains(t, result.ForLLM, "not found") +} + +func TestInstallSkillToolParameters(t *testing.T) { + tool := NewInstallSkillTool(skills.NewRegistryManager(), t.TempDir()) + params := tool.Parameters() + + props, ok := params["properties"].(map[string]interface{}) + assert.True(t, ok) + assert.Contains(t, props, "slug") + assert.Contains(t, props, "version") + assert.Contains(t, props, "registry") + assert.Contains(t, props, "force") + + required, ok := params["required"].([]string) + assert.True(t, ok) + assert.Contains(t, required, "slug") + assert.Contains(t, required, "registry") +} + +func TestInstallSkillToolMissingRegistry(t *testing.T) { + tool := NewInstallSkillTool(skills.NewRegistryManager(), t.TempDir()) + result := tool.Execute(context.Background(), 
map[string]interface{}{ + "slug": "some-skill", + }) + assert.True(t, result.IsError) + assert.Contains(t, result.ForLLM, "invalid registry") +} diff --git a/pkg/tools/skills_search.go b/pkg/tools/skills_search.go new file mode 100644 index 000000000..b12949ec2 --- /dev/null +++ b/pkg/tools/skills_search.go @@ -0,0 +1,119 @@ +package tools + +import ( + "context" + "fmt" + "strings" + + "github.com/sipeed/picoclaw/pkg/skills" +) + +// FindSkillsTool allows the LLM agent to search for installable skills from registries. +type FindSkillsTool struct { + registryMgr *skills.RegistryManager + cache *skills.SearchCache +} + +// NewFindSkillsTool creates a new FindSkillsTool. +// registryMgr is the shared registry manager (built from config in createToolRegistry). +// cache is the search cache for deduplicating similar queries. +func NewFindSkillsTool(registryMgr *skills.RegistryManager, cache *skills.SearchCache) *FindSkillsTool { + return &FindSkillsTool{ + registryMgr: registryMgr, + cache: cache, + } +} + +func (t *FindSkillsTool) Name() string { + return "find_skills" +} + +func (t *FindSkillsTool) Description() string { + return "Search for installable skills from skill registries. Returns skill slugs, descriptions, versions, and relevance scores. Use this to discover skills before installing them with install_skill." 
+} + +func (t *FindSkillsTool) Parameters() map[string]interface{} { + return map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "query": map[string]interface{}{ + "type": "string", + "description": "Search query describing the desired skill capability (e.g., 'github integration', 'database management')", + }, + "limit": map[string]interface{}{ + "type": "integer", + "description": "Maximum number of results to return (1-20, default 5)", + "minimum": 1.0, + "maximum": 20.0, + }, + }, + "required": []string{"query"}, + } +} + +func (t *FindSkillsTool) Execute(ctx context.Context, args map[string]interface{}) *ToolResult { + query, ok := args["query"].(string) + query = strings.ToLower(strings.TrimSpace(query)) + if !ok || query == "" { + return ErrorResult("query is required and must be a non-empty string") + } + + limit := 5 + if l, ok := args["limit"].(float64); ok { + li := int(l) + if li >= 1 && li <= 20 { + limit = li + } + } + + // Check cache first. + if t.cache != nil { + if cached, hit := t.cache.Get(query); hit { + return SilentResult(formatSearchResults(query, cached, true)) + } + } + + // Search all registries. + results, err := t.registryMgr.SearchAll(ctx, query, limit) + if err != nil { + return ErrorResult(fmt.Sprintf("skill search failed: %v", err)) + } + + // Cache the results. + if t.cache != nil && len(results) > 0 { + t.cache.Put(query, results) + } + + return SilentResult(formatSearchResults(query, results, false)) +} + +func formatSearchResults(query string, results []skills.SearchResult, cached bool) string { + if len(results) == 0 { + return fmt.Sprintf("No skills found for query: %q", query) + } + + var sb strings.Builder + source := "" + if cached { + source = " (cached)" + } + sb.WriteString(fmt.Sprintf("Found %d skills for %q%s:\n\n", len(results), query, source)) + + for i, r := range results { + sb.WriteString(fmt.Sprintf("%d. 
**%s**", i+1, r.Slug)) + if r.Version != "" { + sb.WriteString(fmt.Sprintf(" v%s", r.Version)) + } + sb.WriteString(fmt.Sprintf(" (score: %.3f, registry: %s)\n", r.Score, r.RegistryName)) + if r.DisplayName != "" && r.DisplayName != r.Slug { + sb.WriteString(fmt.Sprintf(" Name: %s\n", r.DisplayName)) + } + if r.Summary != "" { + sb.WriteString(fmt.Sprintf(" %s\n", r.Summary)) + } + sb.WriteString("\n") + } + + sb.WriteString("Use install_skill with the slug to install a skill.") + return sb.String() +} diff --git a/pkg/tools/skills_search_test.go b/pkg/tools/skills_search_test.go new file mode 100644 index 000000000..7e07b2775 --- /dev/null +++ b/pkg/tools/skills_search_test.go @@ -0,0 +1,82 @@ +package tools + +import ( + "context" + "testing" + + "github.com/sipeed/picoclaw/pkg/skills" + "github.com/stretchr/testify/assert" +) + +func TestFindSkillsToolName(t *testing.T) { + tool := NewFindSkillsTool(skills.NewRegistryManager(), nil) + assert.Equal(t, "find_skills", tool.Name()) +} + +func TestFindSkillsToolMissingQuery(t *testing.T) { + tool := NewFindSkillsTool(skills.NewRegistryManager(), nil) + result := tool.Execute(context.Background(), map[string]interface{}{}) + assert.True(t, result.IsError) + assert.Contains(t, result.ForLLM, "query is required") +} + +func TestFindSkillsToolEmptyQuery(t *testing.T) { + tool := NewFindSkillsTool(skills.NewRegistryManager(), nil) + result := tool.Execute(context.Background(), map[string]interface{}{ + "query": " ", + }) + assert.True(t, result.IsError) +} + +func TestFindSkillsToolCacheHit(t *testing.T) { + cache := skills.NewSearchCache(10, 5*60*1000*1000*1000) // 5 min + cache.Put("github", []skills.SearchResult{ + {Slug: "github", Score: 0.9, RegistryName: "clawhub"}, + }) + + tool := NewFindSkillsTool(skills.NewRegistryManager(), cache) + result := tool.Execute(context.Background(), map[string]interface{}{ + "query": "github", + }) + + assert.False(t, result.IsError) + assert.Contains(t, result.ForLLM, "github") + 
assert.Contains(t, result.ForLLM, "cached")
+}
+
+func TestFindSkillsToolParameters(t *testing.T) {
+	tool := NewFindSkillsTool(skills.NewRegistryManager(), nil)
+	params := tool.Parameters()
+
+	props, ok := params["properties"].(map[string]interface{})
+	assert.True(t, ok)
+	assert.Contains(t, props, "query")
+	assert.Contains(t, props, "limit")
+
+	required, ok := params["required"].([]string)
+	assert.True(t, ok)
+	assert.Contains(t, required, "query")
+}
+
+func TestFindSkillsToolDescription(t *testing.T) {
+	tool := NewFindSkillsTool(skills.NewRegistryManager(), nil)
+	assert.NotEmpty(t, tool.Description())
+	assert.Contains(t, tool.Description(), "skill")
+}
+
+func TestFormatSearchResultsEmpty(t *testing.T) {
+	result := formatSearchResults("test query", nil, false)
+	assert.Contains(t, result, "No skills found")
+}
+
+func TestFormatSearchResultsWithData(t *testing.T) {
+	results := []skills.SearchResult{
+		{Slug: "github", Score: 0.95, DisplayName: "GitHub", Summary: "GitHub API integration", Version: "1.0.0", RegistryName: "clawhub"},
+	}
+	output := formatSearchResults("github", results, false)
+	assert.Contains(t, output, "github")
+	assert.Contains(t, output, "v1.0.0")
+	assert.Contains(t, output, "0.950")
+	assert.Contains(t, output, "clawhub")
+	assert.Contains(t, output, "install_skill")
+}
diff --git a/pkg/utils/download.go b/pkg/utils/download.go
new file mode 100644
index 000000000..9fa7fbfa7
--- /dev/null
+++ b/pkg/utils/download.go
@@ -0,0 +1,105 @@
+package utils
+
+import (
+	"context"
+	"fmt"
+	"io"
+	"net/http"
+	"os"
+
+	"github.com/sipeed/picoclaw/pkg/logger"
+)
+
+// DownloadToFile streams an HTTP response body to a temporary file in small
+// chunks (~32KB), keeping peak memory usage constant regardless of file size.
+//
+// Parameters:
+//   - ctx: context for cancellation/timeout
+//   - client: HTTP client to use (caller controls timeouts, transport, etc.)
+//   - req: fully prepared *http.Request (method, URL, headers, etc.)
+//   - maxBytes: maximum bytes to download; values <= 0 mean no limit
+//
+// Returns the path to the temporary file. The caller is responsible for
+// removing it when done (defer os.Remove(path)).
+//
+// A response status outside 2xx is reported as an error that includes up to
+// the first 512 bytes of the response body.
+//
+// On any error the temp file is cleaned up automatically.
+func DownloadToFile(ctx context.Context, client *http.Client, req *http.Request, maxBytes int64) (string, error) {
+	// Largest int64 value. A limit this large can never be exceeded, and
+	// maxBytes+1 would wrap around to a negative number for it.
+	const maxInt64 = 1<<63 - 1
+
+	// Attach context.
+	req = req.WithContext(ctx)
+
+	logger.DebugCF("download", "Starting download", map[string]interface{}{
+		"url":       req.URL.String(),
+		"max_bytes": maxBytes,
+	})
+
+	resp, err := client.Do(req)
+	if err != nil {
+		return "", fmt.Errorf("request failed: %w", err)
+	}
+	defer resp.Body.Close()
+
+	if resp.StatusCode < 200 || resp.StatusCode >= 300 {
+		// Read a small amount for the error message. Best effort: a short or
+		// failed read only shortens the message, so the error is ignored.
+		errBody := make([]byte, 512)
+		n, _ := io.ReadFull(resp.Body, errBody)
+		return "", fmt.Errorf("HTTP %d: %s", resp.StatusCode, string(errBody[:n]))
+	}
+
+	// Create temp file.
+	tmpFile, err := os.CreateTemp("", "picoclaw-dl-*")
+	if err != nil {
+		return "", fmt.Errorf("failed to create temp file: %w", err)
+	}
+	tmpPath := tmpFile.Name()
+
+	logger.DebugCF("download", "Streaming to temp file", map[string]interface{}{
+		"path": tmpPath,
+	})
+
+	// Cleanup helper — removes the temp file on any error.
+	cleanup := func() {
+		_ = tmpFile.Close()
+		_ = os.Remove(tmpPath)
+	}
+
+	// Optionally limit the download size. One byte more than the limit is
+	// allowed through so that an oversized body can be told apart from one of
+	// exactly maxBytes. io.LimitReader reports EOF immediately for a
+	// non-positive limit, so the wrap-around case must skip the limiter or it
+	// would produce an empty file that is reported as a success.
+	var src io.Reader = resp.Body
+	if maxBytes > 0 && maxBytes < maxInt64 {
+		src = io.LimitReader(resp.Body, maxBytes+1) // +1 to detect overflow
+	}
+
+	written, err := io.Copy(tmpFile, src)
+	if err != nil {
+		cleanup()
+		return "", fmt.Errorf("download write failed: %w", err)
+	}
+
+	if maxBytes > 0 && written > maxBytes {
+		cleanup()
+		return "", fmt.Errorf("download too large: %d bytes (max %d)", written, maxBytes)
+	}
+
+	if err := tmpFile.Close(); err != nil {
+		_ = os.Remove(tmpPath)
+		return "", fmt.Errorf("failed to close temp file: %w", err)
+	}
+
+	logger.DebugCF("download", "Download complete", map[string]interface{}{
+		"path":          tmpPath,
+		"bytes_written": written,
+	})
+
+	return tmpPath, nil
+}
diff --git a/pkg/utils/skills.go b/pkg/utils/skills.go
new file mode 100644
index 000000000..1d2cfac7f
--- /dev/null
+++ b/pkg/utils/skills.go
@@ -0,0 +1,22 @@
+package utils
+
+import (
+	"fmt"
+	"strings"
+)
+
+// ValidateSkillIdentifier validates that the given skill identifier (slug or registry name) is non-empty
+// and does not contain path separators ("/", "\\") or ".." for security.
+//
+// An identifier consisting only of whitespace counts as empty. The function
+// only reports validity: it does not return a trimmed or normalized value.
+func ValidateSkillIdentifier(identifier string) error {
+	trimmed := strings.TrimSpace(identifier)
+	if trimmed == "" {
+		return fmt.Errorf("identifier is required and must be a non-empty string")
+	}
+	if strings.ContainsAny(trimmed, "/\\") || strings.Contains(trimmed, "..") {
+		return fmt.Errorf("identifier must not contain path separators or '..' to prevent directory traversal")
+	}
+	return nil
+}
diff --git a/pkg/utils/string.go b/pkg/utils/string.go
index 0d9837cb9..7a6aa37cc 100644
--- a/pkg/utils/string.go
+++ b/pkg/utils/string.go
@@ -14,3 +14,12 @@ func Truncate(s string, maxLen int) string {
 	}
 	return string(runes[:maxLen-3]) + "..."
 }
+
+// DerefStr dereferences a pointer to a string and
+// returns the value or a fallback if the pointer is nil.
+func DerefStr(s *string, fallback string) string {
+	if s == nil {
+		return fallback
+	}
+	return *s
+}
diff --git a/pkg/utils/zip.go b/pkg/utils/zip.go
new file mode 100644
index 000000000..cad91e420
--- /dev/null
+++ b/pkg/utils/zip.go
@@ -0,0 +1,133 @@
+package utils
+
+import (
+	"archive/zip"
+	"fmt"
+	"io"
+	"os"
+	"path/filepath"
+	"strings"
+
+	"github.com/sipeed/picoclaw/pkg/logger"
+)
+
+// ExtractZipFile extracts a ZIP archive from disk to targetDir.
+// It reads entries one at a time from disk, keeping memory usage minimal.
+//
+// Security: rejects path traversal attempts and symlinks.
+//
+// Extraction stops at the first entry that is rejected or fails to extract;
+// entries written before that point are left in place.
+func ExtractZipFile(zipPath string, targetDir string) error {
+	reader, err := zip.OpenReader(zipPath)
+	if err != nil {
+		return fmt.Errorf("invalid ZIP: %w", err)
+	}
+	defer reader.Close()
+
+	logger.DebugCF("zip", "Extracting ZIP", map[string]interface{}{
+		"zip_path":   zipPath,
+		"target_dir": targetDir,
+		"entries":    len(reader.File),
+	})
+
+	if err := os.MkdirAll(targetDir, 0755); err != nil {
+		return fmt.Errorf("failed to create target dir: %w", err)
+	}
+
+	// The cleaned root is the same for every entry, so compute it once.
+	targetDirClean := filepath.Clean(targetDir)
+
+	for _, f := range reader.File {
+		// Path traversal protection.
+		cleanName := filepath.Clean(f.Name)
+		if strings.HasPrefix(cleanName, "..") || filepath.IsAbs(cleanName) {
+			return fmt.Errorf("zip entry has unsafe path: %q", f.Name)
+		}
+
+		destPath := filepath.Join(targetDir, cleanName)
+
+		// Double-check the resolved path is within target directory (defense-in-depth).
+		if !strings.HasPrefix(filepath.Clean(destPath), targetDirClean+string(filepath.Separator)) && filepath.Clean(destPath) != targetDirClean {
+			return fmt.Errorf("zip entry escapes target dir: %q", f.Name)
+		}
+
+		mode := f.FileInfo().Mode()
+
+		// Reject any symlink.
+		if mode&os.ModeSymlink != 0 {
+			return fmt.Errorf("zip contains symlink %q; symlinks are not allowed", f.Name)
+		}
+
+		if f.FileInfo().IsDir() {
+			if err := os.MkdirAll(destPath, 0755); err != nil {
+				return err
+			}
+			continue
+		}
+
+		// Ensure parent directory exists.
+		if err := os.MkdirAll(filepath.Dir(destPath), 0755); err != nil {
+			return err
+		}
+
+		if err := extractSingleFile(f, destPath); err != nil {
+			return err
+		}
+	}
+
+	return nil
+}
+
+// extractSingleFile extracts one zip.File entry to destPath, with a size check.
+//
+// The limit is enforced twice: against the size declared in the entry header
+// and against the number of bytes actually decompressed, so a wrong header
+// cannot get a larger file through.
+//
+// If anything fails, including closing the destination file, an error is
+// returned and the partially written file is removed.
+func extractSingleFile(f *zip.File, destPath string) (err error) {
+	const maxFileSize = 5 * 1024 * 1024 // 5MB, adjust as appropriate
+
+	// Check the uncompressed size from the header, if available.
+	if f.UncompressedSize64 > maxFileSize {
+		return fmt.Errorf("zip entry %q is too large (%d bytes)", f.Name, f.UncompressedSize64)
+	}
+
+	rc, err := f.Open()
+	if err != nil {
+		return fmt.Errorf("failed to open zip entry %q: %w", f.Name, err)
+	}
+	defer rc.Close()
+
+	outFile, err := os.Create(destPath)
+	if err != nil {
+		return fmt.Errorf("failed to create file %q: %w", destPath, err)
+	}
+	// err is a named result so that this deferred function can both see how
+	// the function is returning and replace a nil result with a Close failure.
+	// The file is closed before it is removed because removing a file that is
+	// still open fails on some platforms (notably Windows).
+	defer func() {
+		if cerr := outFile.Close(); cerr != nil && err == nil {
+			err = fmt.Errorf("failed to close file %q: %w", destPath, cerr)
+		}
+		if err != nil {
+			_ = os.Remove(destPath)
+		}
+	}()
+
+	// Streamed size check: prevent overruns and malicious/corrupt headers.
+	// At most maxFileSize+1 bytes are read; reaching that count means the
+	// entry is over the limit.
+	written, err := io.Copy(outFile, io.LimitReader(rc, maxFileSize+1))
+	if err != nil {
+		return fmt.Errorf("failed to extract %q: %w", f.Name, err)
+	}
+	if written > maxFileSize {
+		return fmt.Errorf("zip entry %q exceeds max size (%d bytes)", f.Name, written)
+	}
+
+	return nil
+}