Feature: Implement Skill Discovery - With Clawhub Integration and Caching (#332)

* Add Find Skills and Install Skills

* Improvements

* fix file name

* Update pkg/skills/clawhub_registry.go

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* fix

* Comments addressed

* Resolve comments

* fix tests

* fixes

* Comments resolved

* Update pkg/skills/search_cache_repro_test.go

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* minor fix

* fix test

* fixes

---------

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
This commit is contained in:
Harsh Bansal
2026-02-20 16:25:04 +05:30
committed by GitHub
parent f1223eec42
commit d692cc0cc6
20 changed files with 2303 additions and 10 deletions
+95 -6
View File
@@ -11,15 +11,17 @@ import (
"strings"
"time"
"github.com/sipeed/picoclaw/pkg/config"
"github.com/sipeed/picoclaw/pkg/skills"
"github.com/sipeed/picoclaw/pkg/utils"
)
func skillsHelp() {
fmt.Println("\nSkills commands:")
fmt.Println(" list List installed skills")
fmt.Println(" install <repo> Install skill from GitHub")
fmt.Println(" install-builtin Install all builtin skills to workspace")
fmt.Println(" list-builtin List available builtin skills")
fmt.Println(" install-builtin Install all builtin skills to workspace")
fmt.Println(" list-builtin List available builtin skills")
fmt.Println(" remove <name> Remove installed skill")
fmt.Println(" search Search available skills")
fmt.Println(" show <name> Show skill details")
@@ -30,6 +32,7 @@ func skillsHelp() {
fmt.Println(" picoclaw skills install-builtin")
fmt.Println(" picoclaw skills list-builtin")
fmt.Println(" picoclaw skills remove weather")
fmt.Println(" picoclaw skills install --registry clawhub github")
}
func skillsListCmd(loader *skills.SkillsLoader) {
@@ -50,13 +53,27 @@ func skillsListCmd(loader *skills.SkillsLoader) {
}
}
func skillsInstallCmd(installer *skills.SkillInstaller) {
func skillsInstallCmd(installer *skills.SkillInstaller, cfg *config.Config) {
if len(os.Args) < 4 {
fmt.Println("Usage: picoclaw skills install <github-repo>")
fmt.Println("Example: picoclaw skills install sipeed/picoclaw-skills/weather")
fmt.Println(" picoclaw skills install --registry <name> <slug>")
return
}
// Check for --registry flag.
if os.Args[3] == "--registry" {
if len(os.Args) < 6 {
fmt.Println("Usage: picoclaw skills install --registry <name> <slug>")
fmt.Println("Example: picoclaw skills install --registry clawhub github")
return
}
registryName := os.Args[4]
slug := os.Args[5]
skillsInstallFromRegistry(cfg, registryName, slug)
return
}
// Default: install from GitHub (backward compatible).
repo := os.Args[3]
fmt.Printf("Installing skill from %s...\n", repo)
@@ -64,11 +81,83 @@ func skillsInstallCmd(installer *skills.SkillInstaller) {
defer cancel()
if err := installer.InstallFromGitHub(ctx, repo); err != nil {
fmt.Printf(" Failed to install skill: %v\n", err)
fmt.Printf("\u2717 Failed to install skill: %v\n", err)
os.Exit(1)
}
fmt.Printf(" Skill '%s' installed successfully!\n", filepath.Base(repo))
fmt.Printf("\u2713 Skill '%s' installed successfully!\n", filepath.Base(repo))
}
// skillsInstallFromRegistry installs the skill identified by slug from the
// named registry (e.g. "clawhub") into <workspace>/skills/<slug>.
//
// Progress and failures are printed to stdout. Every failure path terminates
// the process with exit status 1; a partially extracted skill directory is
// removed before exiting.
func skillsInstallFromRegistry(cfg *config.Config, registryName, slug string) {
	// die prints a formatted failure message and terminates the process.
	die := func(format string, args ...interface{}) {
		fmt.Printf(format, args...)
		os.Exit(1)
	}
	// discard removes whatever was written to dir, reporting (but not
	// aborting on) a removal failure.
	discard := func(dir string) {
		if rmErr := os.RemoveAll(dir); rmErr != nil {
			fmt.Printf("\u2717 Failed to remove partial install: %v\n", rmErr)
		}
	}

	// Both identifiers end up in file-system paths / URLs, so validate first.
	if err := utils.ValidateSkillIdentifier(registryName); err != nil {
		die("\u2717 Invalid registry name: %v\n", err)
	}
	if err := utils.ValidateSkillIdentifier(slug); err != nil {
		die("\u2717 Invalid slug: %v\n", err)
	}

	fmt.Printf("Installing skill '%s' from %s registry...\n", slug, registryName)

	skillsCfg := cfg.Tools.Skills
	manager := skills.NewRegistryManagerFromConfig(skills.RegistryConfig{
		MaxConcurrentSearches: skillsCfg.MaxConcurrentSearches,
		ClawHub:               skills.ClawHubConfig(skillsCfg.Registries.ClawHub),
	})

	registry := manager.GetRegistry(registryName)
	if registry == nil {
		die("\u2717 Registry '%s' not found or not enabled. Check your config.json.\n", registryName)
	}

	skillsRoot := filepath.Join(cfg.WorkspacePath(), "skills")
	targetDir := filepath.Join(skillsRoot, slug)

	// Refuse to overwrite an existing installation.
	if _, statErr := os.Stat(targetDir); statErr == nil {
		die("\u2717 Skill '%s' already installed at %s\n", slug, targetDir)
	}

	ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second)
	defer cancel()

	if mkErr := os.MkdirAll(skillsRoot, 0755); mkErr != nil {
		die("\u2717 Failed to create skills directory: %v\n", mkErr)
	}

	// An empty version lets the registry resolve the version to install.
	result, err := registry.DownloadAndInstall(ctx, slug, "", targetDir)
	if err != nil {
		discard(targetDir)
		die("\u2717 Failed to install skill: %v\n", err)
	}

	// Moderation verdicts reported by the registry.
	if result.IsMalwareBlocked {
		discard(targetDir)
		die("\u2717 Skill '%s' is flagged as malicious and cannot be installed.\n", slug)
	}
	if result.IsSuspicious {
		fmt.Printf("\u26a0\ufe0f Warning: skill '%s' is flagged as suspicious.\n", slug)
	}

	fmt.Printf("\u2713 Skill '%s' v%s installed successfully!\n", slug, result.Version)
	if summary := result.Summary; summary != "" {
		fmt.Printf(" %s\n", summary)
	}
}
func skillsRemoveCmd(installer *skills.SkillInstaller, skillName string) {
+1 -1
View File
@@ -141,7 +141,7 @@ func main() {
case "list":
skillsListCmd(skillsLoader)
case "install":
skillsInstallCmd(installer)
skillsInstallCmd(installer, cfg)
case "remove", "uninstall":
if len(os.Args) < 4 {
fmt.Println("Usage: picoclaw skills remove <skill-name>")
+11
View File
@@ -194,6 +194,17 @@
"exec": {
"enable_deny_patterns": false,
"custom_deny_patterns": []
},
"skills": {
"registries": {
"clawhub": {
"enabled": true,
"base_url": "https://clawhub.ai",
"search_path": "/api/v1/search",
"skills_path": "/api/v1/skills",
"download_path": "/api/v1/download"
}
}
}
},
"heartbeat": {
+10
View File
@@ -23,6 +23,7 @@ import (
"github.com/sipeed/picoclaw/pkg/logger"
"github.com/sipeed/picoclaw/pkg/providers"
"github.com/sipeed/picoclaw/pkg/routing"
"github.com/sipeed/picoclaw/pkg/skills"
"github.com/sipeed/picoclaw/pkg/state"
"github.com/sipeed/picoclaw/pkg/tools"
"github.com/sipeed/picoclaw/pkg/utils"
@@ -117,6 +118,15 @@ func registerSharedTools(cfg *config.Config, msgBus *bus.MessageBus, registry *A
})
agent.Tools.Register(messageTool)
// Skill discovery and installation tools
registryMgr := skills.NewRegistryManagerFromConfig(skills.RegistryConfig{
MaxConcurrentSearches: cfg.Tools.Skills.MaxConcurrentSearches,
ClawHub: skills.ClawHubConfig(cfg.Tools.Skills.Registries.ClawHub),
})
searchCache := skills.NewSearchCache(cfg.Tools.Skills.SearchCache.MaxSize, time.Duration(cfg.Tools.Skills.SearchCache.TTLSeconds)*time.Second)
agent.Tools.Register(tools.NewFindSkillsTool(registryMgr, searchCache))
agent.Tools.Register(tools.NewInstallSkillTool(registryMgr, agent.Workspace))
// Spawn tool with allowlist checker
subagentManager := tools.NewSubagentManager(provider, agent.Model, agent.Workspace, msgBus)
subagentManager.SetLLMOptions(agent.MaxTokens, agent.Temperature)
+31 -3
View File
@@ -416,9 +416,37 @@ type ExecConfig struct {
}
type ToolsConfig struct {
Web WebToolsConfig `json:"web"`
Cron CronToolsConfig `json:"cron"`
Exec ExecConfig `json:"exec"`
Web WebToolsConfig `json:"web"`
Cron CronToolsConfig `json:"cron"`
Exec ExecConfig `json:"exec"`
Skills SkillsToolsConfig `json:"skills"`
}
// SkillsToolsConfig groups the settings of the skill discovery and
// installation tools: the registries to query, the search fan-out limit and
// the search result cache.
type SkillsToolsConfig struct {
// Registries holds the per-registry settings.
Registries SkillsRegistriesConfig `json:"registries"`
// MaxConcurrentSearches is handed to the registry manager as its search
// concurrency limit; the manager only applies it when it is positive.
MaxConcurrentSearches int `json:"max_concurrent_searches" env:"PICOCLAW_SKILLS_MAX_CONCURRENT_SEARCHES"`
// SearchCache configures the cache placed in front of registry searches.
SearchCache SearchCacheConfig `json:"search_cache"`
}
// SearchCacheConfig holds the two parameters passed to skills.NewSearchCache.
type SearchCacheConfig struct {
// MaxSize is passed as the cache's maximum size.
MaxSize int `json:"max_size" env:"PICOCLAW_SKILLS_SEARCH_CACHE_MAX_SIZE"`
// TTLSeconds is the cache TTL in seconds; it is converted to a
// time.Duration before being handed to the cache.
TTLSeconds int `json:"ttl_seconds" env:"PICOCLAW_SKILLS_SEARCH_CACHE_TTL_SECONDS"`
}
// SkillsRegistriesConfig lists the skill registries that can be configured.
type SkillsRegistriesConfig struct {
// ClawHub configures the ClawHub registry client.
ClawHub ClawHubRegistryConfig `json:"clawhub"`
}
// ClawHubRegistryConfig is the user-facing configuration of the ClawHub
// registry.
//
// NOTE: callers turn this value into a skills.ClawHubConfig with a plain Go
// type conversion (skills.ClawHubConfig(cfg)). That only compiles while both
// structs declare the same fields, with the same types, in the same order
// (struct tags are ignored) - keep the two definitions in sync.
type ClawHubRegistryConfig struct {
// Enabled controls whether the registry is instantiated at all.
Enabled bool `json:"enabled" env:"PICOCLAW_SKILLS_REGISTRIES_CLAWHUB_ENABLED"`
// BaseURL is the scheme+host prefix that the API paths are appended to.
BaseURL string `json:"base_url" env:"PICOCLAW_SKILLS_REGISTRIES_CLAWHUB_BASE_URL"`
// AuthToken, when non-empty, is sent as a Bearer token on every request.
AuthToken string `json:"auth_token" env:"PICOCLAW_SKILLS_REGISTRIES_CLAWHUB_AUTH_TOKEN"`
// SearchPath, SkillsPath and DownloadPath are the API paths appended to
// BaseURL; empty values fall back to the client's built-in defaults.
SearchPath string `json:"search_path" env:"PICOCLAW_SKILLS_REGISTRIES_CLAWHUB_SEARCH_PATH"`
SkillsPath string `json:"skills_path" env:"PICOCLAW_SKILLS_REGISTRIES_CLAWHUB_SKILLS_PATH"`
DownloadPath string `json:"download_path" env:"PICOCLAW_SKILLS_REGISTRIES_CLAWHUB_DOWNLOAD_PATH"`
// Timeout is the HTTP client timeout in seconds; values <= 0 select the default.
Timeout int `json:"timeout" env:"PICOCLAW_SKILLS_REGISTRIES_CLAWHUB_TIMEOUT"`
// MaxZipSize caps the size of a downloaded skill archive in bytes; values <= 0 select the default.
MaxZipSize int `json:"max_zip_size" env:"PICOCLAW_SKILLS_REGISTRIES_CLAWHUB_MAX_ZIP_SIZE"`
// MaxResponseSize caps the size of a JSON API response in bytes; values <= 0 select the default.
MaxResponseSize int `json:"max_response_size" env:"PICOCLAW_SKILLS_REGISTRIES_CLAWHUB_MAX_RESPONSE_SIZE"`
}
func LoadConfig(path string) (*Config, error) {
+13
View File
@@ -265,6 +265,19 @@ func DefaultConfig() *Config {
Exec: ExecConfig{
EnableDenyPatterns: true,
},
Skills: SkillsToolsConfig{
Registries: SkillsRegistriesConfig{
ClawHub: ClawHubRegistryConfig{
Enabled: true,
BaseURL: "https://clawhub.ai",
},
},
MaxConcurrentSearches: 2,
SearchCache: SearchCacheConfig{
MaxSize: 50,
TTLSeconds: 300,
},
},
},
Heartbeat: HeartbeatConfig{
Enabled: true,
+311
View File
@@ -0,0 +1,311 @@
package skills
import (
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"os"
"time"
"github.com/sipeed/picoclaw/pkg/utils"
)
// Fallback limits used by NewClawHubRegistry whenever the corresponding
// ClawHubConfig field is zero or negative.
const (
defaultClawHubTimeout = 30 * time.Second
defaultMaxZipSize = 50 * 1024 * 1024 // 50 MB
defaultMaxResponseSize = 2 * 1024 * 1024 // 2 MB
)
// ClawHubRegistry implements SkillRegistry for the ClawHub platform.
// Construct it with NewClawHubRegistry so that every field receives either
// the configured value or its default.
type ClawHubRegistry struct {
baseURL string // Prefix that the API paths below are appended to
authToken string // Optional - for elevated rate limits
searchPath string // Search API
skillsPath string // For retrieving skill metadata
downloadPath string // For fetching ZIP files for download
maxZipSize int // Upper bound, in bytes, passed to the ZIP download helper
maxResponseSize int // Upper bound, in bytes, on JSON response bodies read by doGet
client *http.Client // HTTP client used for all requests to the registry
}
// NewClawHubRegistry creates a new ClawHub registry client from config.
//
// Unset values fall back to defaults: an empty BaseURL becomes
// "https://clawhub.ai", empty API paths become "/api/v1/search",
// "/api/v1/skills" and "/api/v1/download", and non-positive Timeout,
// MaxZipSize and MaxResponseSize become 30s, 50 MB and 2 MB respectively.
//
// Trailing slashes on BaseURL are stripped: request URLs are built by plain
// concatenation with the slash-prefixed API paths, so "https://host/" would
// otherwise produce "https://host//api/...".
func NewClawHubRegistry(cfg ClawHubConfig) *ClawHubRegistry {
	// orDefault returns fallback when value is empty.
	orDefault := func(value, fallback string) string {
		if value == "" {
			return fallback
		}
		return value
	}

	baseURL := cfg.BaseURL
	for len(baseURL) > 0 && baseURL[len(baseURL)-1] == '/' {
		baseURL = baseURL[:len(baseURL)-1]
	}
	baseURL = orDefault(baseURL, "https://clawhub.ai")

	timeout := defaultClawHubTimeout
	if cfg.Timeout > 0 {
		// The configured timeout is expressed in seconds.
		timeout = time.Duration(cfg.Timeout) * time.Second
	}

	maxZip := defaultMaxZipSize
	if cfg.MaxZipSize > 0 {
		maxZip = cfg.MaxZipSize
	}

	maxResp := defaultMaxResponseSize
	if cfg.MaxResponseSize > 0 {
		maxResp = cfg.MaxResponseSize
	}

	return &ClawHubRegistry{
		baseURL:         baseURL,
		authToken:       cfg.AuthToken,
		searchPath:      orDefault(cfg.SearchPath, "/api/v1/search"),
		skillsPath:      orDefault(cfg.SkillsPath, "/api/v1/skills"),
		downloadPath:    orDefault(cfg.DownloadPath, "/api/v1/download"),
		maxZipSize:      maxZip,
		maxResponseSize: maxResp,
		client: &http.Client{
			Timeout: timeout,
			Transport: &http.Transport{
				MaxIdleConns:        5,
				IdleConnTimeout:     30 * time.Second,
				TLSHandshakeTimeout: 10 * time.Second,
			},
		},
	}
}
// Name returns the fixed registry identifier "clawhub". Search stamps this
// value into the RegistryName field of every result it returns.
func (c *ClawHubRegistry) Name() string {
return "clawhub"
}
// --- Search ---
// clawhubSearchResponse is the JSON envelope decoded from the search endpoint.
type clawhubSearchResponse struct {
Results []clawhubSearchResult `json:"results"`
}
// clawhubSearchResult is one entry of a search response. The string fields
// are pointers so that a JSON null (or a missing key) can be told apart from
// a value; Search dereferences them with an empty-string fallback.
type clawhubSearchResult struct {
Score float64 `json:"score"`
Slug *string `json:"slug"`
DisplayName *string `json:"displayName"`
Summary *string `json:"summary"`
Version *string `json:"version"`
}
// Search queries the ClawHub search endpoint for query and converts the
// response into SearchResults tagged with this registry's name.
//
// A positive limit is forwarded as the "limit" query parameter; otherwise the
// parameter is omitted. Entries without a slug or without a summary are
// dropped, and an entry without a display name uses its slug instead.
func (c *ClawHubRegistry) Search(ctx context.Context, query string, limit int) ([]SearchResult, error) {
	endpoint, err := url.Parse(c.baseURL + c.searchPath)
	if err != nil {
		return nil, fmt.Errorf("invalid base URL: %w", err)
	}

	params := endpoint.Query()
	params.Set("q", query)
	if limit > 0 {
		params.Set("limit", fmt.Sprint(limit))
	}
	endpoint.RawQuery = params.Encode()

	raw, err := c.doGet(ctx, endpoint.String())
	if err != nil {
		return nil, fmt.Errorf("search request failed: %w", err)
	}

	var parsed clawhubSearchResponse
	if err := json.Unmarshal(raw, &parsed); err != nil {
		return nil, fmt.Errorf("failed to parse search response: %w", err)
	}

	registryName := c.Name()
	found := make([]SearchResult, 0, len(parsed.Results))
	for i := range parsed.Results {
		entry := &parsed.Results[i]

		slug := utils.DerefStr(entry.Slug, "")
		summary := utils.DerefStr(entry.Summary, "")
		if slug == "" || summary == "" {
			// Not presentable without both an identifier and a description.
			continue
		}

		title := utils.DerefStr(entry.DisplayName, "")
		if title == "" {
			title = slug
		}

		found = append(found, SearchResult{
			Score:        entry.Score,
			Slug:         slug,
			DisplayName:  title,
			Summary:      summary,
			Version:      utils.DerefStr(entry.Version, ""),
			RegistryName: registryName,
		})
	}
	return found, nil
}
// --- GetSkillMeta ---
// clawhubSkillResponse is the JSON document decoded from the skill metadata
// endpoint. LatestVersion and Moderation are pointers because GetSkillMeta
// treats them as optional and only copies them when present.
type clawhubSkillResponse struct {
Slug string `json:"slug"`
DisplayName string `json:"displayName"`
Summary string `json:"summary"`
LatestVersion *clawhubVersionInfo `json:"latestVersion"`
Moderation *clawhubModerationInfo `json:"moderation"`
}
// clawhubVersionInfo carries the version string of a skill release as found
// under "latestVersion" in the metadata response.
type clawhubVersionInfo struct {
Version string `json:"version"`
}
// clawhubModerationInfo carries the moderation flags found under
// "moderation" in the metadata response; GetSkillMeta copies both into
// SkillMeta.
type clawhubModerationInfo struct {
IsMalwareBlocked bool `json:"isMalwareBlocked"`
IsSuspicious bool `json:"isSuspicious"`
}
// GetSkillMeta fetches the metadata document of the skill identified by slug.
//
// The slug is validated before any request is made, so an unsafe identifier
// (for example one containing path separators) never reaches the network.
// The validation error is wrapped with %w so callers can inspect it with
// errors.Is / errors.As. LatestVersion and the moderation flags are filled in
// only when the registry includes them in its response.
func (c *ClawHubRegistry) GetSkillMeta(ctx context.Context, slug string) (*SkillMeta, error) {
	if err := utils.ValidateSkillIdentifier(slug); err != nil {
		return nil, fmt.Errorf("invalid slug %q: %w", slug, err)
	}

	u := c.baseURL + c.skillsPath + "/" + url.PathEscape(slug)

	body, err := c.doGet(ctx, u)
	if err != nil {
		return nil, fmt.Errorf("skill metadata request failed: %w", err)
	}

	var resp clawhubSkillResponse
	if err := json.Unmarshal(body, &resp); err != nil {
		return nil, fmt.Errorf("failed to parse skill metadata: %w", err)
	}

	meta := &SkillMeta{
		Slug:         resp.Slug,
		DisplayName:  resp.DisplayName,
		Summary:      resp.Summary,
		RegistryName: c.Name(),
	}
	if resp.LatestVersion != nil {
		meta.LatestVersion = resp.LatestVersion.Version
	}
	if resp.Moderation != nil {
		meta.IsMalwareBlocked = resp.Moderation.IsMalwareBlocked
		meta.IsSuspicious = resp.Moderation.IsSuspicious
	}
	return meta, nil
}
// --- DownloadAndInstall ---
// DownloadAndInstall fetches metadata (with fallback), resolves version,
// downloads the skill ZIP, and extracts it to targetDir.
// Returns an InstallResult for the caller to use for moderation decisions.
//
// Metadata lookup is best effort: when it fails the install proceeds without
// moderation flags or summary. The version is resolved as: the explicit
// version argument, else the registry's latest version, else "latest" (in
// which case no "version" query parameter is sent).
//
// When the registry reports the skill as malware-blocked, nothing is
// downloaded or written to targetDir; the result is returned with
// IsMalwareBlocked set so the caller can report it.
//
// On error the returned result is nil and targetDir may contain a partial
// extraction that the caller should remove.
func (c *ClawHubRegistry) DownloadAndInstall(ctx context.Context, slug, version, targetDir string) (*InstallResult, error) {
	if err := utils.ValidateSkillIdentifier(slug); err != nil {
		return nil, fmt.Errorf("invalid slug %q: %w", slug, err)
	}

	// Step 1: Fetch metadata. A failure here is deliberately not fatal.
	result := &InstallResult{}
	latest := ""
	if meta, err := c.GetSkillMeta(ctx, slug); err == nil && meta != nil {
		result.IsMalwareBlocked = meta.IsMalwareBlocked
		result.IsSuspicious = meta.IsSuspicious
		result.Summary = meta.Summary
		latest = meta.LatestVersion
	}

	// Step 2: Resolve version.
	installVersion := version
	if installVersion == "" {
		installVersion = latest
	}
	if installVersion == "" {
		installVersion = "latest"
	}
	result.Version = installVersion

	// Never put content the registry has flagged as malware on disk.
	if result.IsMalwareBlocked {
		return result, nil
	}

	// Step 3: Download ZIP to temp file.
	u, err := url.Parse(c.baseURL + c.downloadPath)
	if err != nil {
		return nil, fmt.Errorf("invalid base URL: %w", err)
	}
	q := u.Query()
	q.Set("slug", slug)
	if installVersion != "latest" {
		q.Set("version", installVersion)
	}
	u.RawQuery = q.Encode()

	req, err := http.NewRequestWithContext(ctx, "GET", u.String(), nil)
	if err != nil {
		return nil, fmt.Errorf("failed to create request: %w", err)
	}
	if c.authToken != "" {
		req.Header.Set("Authorization", "Bearer "+c.authToken)
	}

	tmpPath, err := utils.DownloadToFile(ctx, c.client, req, int64(c.maxZipSize))
	if err != nil {
		return nil, fmt.Errorf("download failed: %w", err)
	}
	// Best-effort cleanup of the temporary archive; a failure to delete it
	// does not affect the outcome of the install.
	defer os.Remove(tmpPath)

	// Step 4: Extract from file on disk.
	if err := utils.ExtractZipFile(tmpPath, targetDir); err != nil {
		return nil, fmt.Errorf("extracting skill archive: %w", err)
	}

	return result, nil
}
// --- HTTP helper ---
// doGet performs an authenticated JSON GET request against urlStr and returns
// the response body.
//
// The body is read with a hard cap of maxResponseSize bytes. A larger body is
// reported as an error rather than being silently cut off, which would
// otherwise surface later as a confusing JSON parse failure. Non-2xx
// responses are returned as "HTTP <status>: <body>" errors, with the body
// shortened to a short excerpt so error strings stay readable.
func (c *ClawHubRegistry) doGet(ctx context.Context, urlStr string) ([]byte, error) {
	req, err := http.NewRequestWithContext(ctx, "GET", urlStr, nil)
	if err != nil {
		return nil, err
	}
	req.Header.Set("Accept", "application/json")
	if c.authToken != "" {
		req.Header.Set("Authorization", "Bearer "+c.authToken)
	}

	resp, err := c.client.Do(req)
	if err != nil {
		return nil, err
	}
	defer resp.Body.Close()

	// Read one byte beyond the limit: getting it proves the body is too big.
	limit := int64(c.maxResponseSize)
	body, err := io.ReadAll(io.LimitReader(resp.Body, limit+1))
	if err != nil {
		return nil, fmt.Errorf("failed to read response: %w", err)
	}

	if resp.StatusCode < 200 || resp.StatusCode >= 300 {
		const maxErrorExcerpt = 512 // bytes of the error body kept in the message
		excerpt := body
		if len(excerpt) > maxErrorExcerpt {
			excerpt = excerpt[:maxErrorExcerpt]
		}
		return nil, fmt.Errorf("HTTP %d: %s", resp.StatusCode, excerpt)
	}

	if int64(len(body)) > limit {
		return nil, fmt.Errorf("response body exceeds limit of %d bytes", limit)
	}

	return body, nil
}
+256
View File
@@ -0,0 +1,256 @@
package skills
import (
"archive/zip"
"bytes"
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"testing"
"github.com/sipeed/picoclaw/pkg/utils"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// newTestRegistry returns an enabled ClawHub client that talks to serverURL
// (typically an httptest server) using authToken; every other setting is left
// at its default.
func newTestRegistry(serverURL, authToken string) *ClawHubRegistry {
	cfg := ClawHubConfig{Enabled: true}
	cfg.BaseURL = serverURL
	cfg.AuthToken = authToken
	return NewClawHubRegistry(cfg)
}
// TestClawHubRegistrySearch checks that Search hits the default search path
// with the query in "q" and maps a fully populated entry onto a SearchResult.
func TestClawHubRegistrySearch(t *testing.T) {
	handler := func(w http.ResponseWriter, r *http.Request) {
		assert.Equal(t, "/api/v1/search", r.URL.Path)
		assert.Equal(t, "github", r.URL.Query().Get("q"))

		var (
			slug    = "github"
			name    = "GitHub Integration"
			summary = "Interact with GitHub repos"
			version = "1.0.0"
		)
		payload := clawhubSearchResponse{
			Results: []clawhubSearchResult{
				{Score: 0.95, Slug: &slug, DisplayName: &name, Summary: &summary, Version: &version},
			},
		}
		json.NewEncoder(w).Encode(payload)
	}
	srv := httptest.NewServer(http.HandlerFunc(handler))
	defer srv.Close()

	results, err := newTestRegistry(srv.URL, "").Search(context.Background(), "github", 5)
	require.NoError(t, err)
	require.Len(t, results, 1)

	first := results[0]
	assert.Equal(t, "github", first.Slug)
	assert.Equal(t, "GitHub Integration", first.DisplayName)
	assert.InDelta(t, 0.95, first.Score, 0.001)
	assert.Equal(t, "clawhub", first.RegistryName)
}
func TestClawHubRegistryGetSkillMeta(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "/api/v1/skills/github", r.URL.Path)
json.NewEncoder(w).Encode(clawhubSkillResponse{
Slug: "github",
DisplayName: "GitHub Integration",
Summary: "Full GitHub API integration",
LatestVersion: &clawhubVersionInfo{
Version: "2.1.0",
},
Moderation: &clawhubModerationInfo{
IsMalwareBlocked: false,
IsSuspicious: true,
},
})
}))
defer srv.Close()
reg := newTestRegistry(srv.URL, "")
meta, err := reg.GetSkillMeta(context.Background(), "github")
require.NoError(t, err)
assert.Equal(t, "github", meta.Slug)
assert.Equal(t, "2.1.0", meta.LatestVersion)
assert.False(t, meta.IsMalwareBlocked)
assert.True(t, meta.IsSuspicious)
}
// TestClawHubRegistryGetSkillMetaUnsafeSlug checks that a path-traversal slug
// is rejected with an "invalid slug" error.
func TestClawHubRegistryGetSkillMetaUnsafeSlug(t *testing.T) {
	const unsafeSlug = "../etc/passwd"

	registry := newTestRegistry("https://example.com", "")
	_, err := registry.GetSkillMeta(context.Background(), unsafeSlug)

	assert.Error(t, err)
	assert.Contains(t, err.Error(), "invalid slug")
}
// TestClawHubRegistryDownloadAndInstall runs a full install against a fake
// registry: metadata lookup, ZIP download with the requested slug and
// version, and extraction into the target directory.
func TestClawHubRegistryDownloadAndInstall(t *testing.T) {
	// Create a valid ZIP in memory.
	zipBuf := createTestZip(t, map[string]string{
		"SKILL.md":  "---\nname: test-skill\ndescription: A test\n---\nHello skill",
		"README.md": "# Test Skill\n",
	})

	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		switch r.URL.Path {
		case "/api/v1/skills/test-skill":
			// Metadata endpoint.
			assert.NoError(t, json.NewEncoder(w).Encode(clawhubSkillResponse{
				Slug:          "test-skill",
				DisplayName:   "Test Skill",
				Summary:       "A test skill",
				LatestVersion: &clawhubVersionInfo{Version: "1.0.0"},
			}))
		case "/api/v1/download":
			// An explicit, non-"latest" version must be forwarded to the
			// download endpoint together with the slug.
			assert.Equal(t, "test-skill", r.URL.Query().Get("slug"))
			assert.Equal(t, "1.0.0", r.URL.Query().Get("version"))
			w.Header().Set("Content-Type", "application/zip")
			_, err := w.Write(zipBuf)
			assert.NoError(t, err)
		default:
			w.WriteHeader(http.StatusNotFound)
		}
	}))
	defer srv.Close()

	tmpDir := t.TempDir()
	targetDir := filepath.Join(tmpDir, "test-skill")

	reg := newTestRegistry(srv.URL, "")
	result, err := reg.DownloadAndInstall(context.Background(), "test-skill", "1.0.0", targetDir)
	require.NoError(t, err)
	assert.Equal(t, "1.0.0", result.Version)
	assert.False(t, result.IsMalwareBlocked)

	// Verify extracted files.
	skillContent, err := os.ReadFile(filepath.Join(targetDir, "SKILL.md"))
	require.NoError(t, err)
	assert.Contains(t, string(skillContent), "Hello skill")

	readmeContent, err := os.ReadFile(filepath.Join(targetDir, "README.md"))
	require.NoError(t, err)
	assert.Contains(t, string(readmeContent), "# Test Skill")
}
// TestClawHubRegistryAuthToken checks that a configured auth token is sent as
// a Bearer Authorization header.
func TestClawHubRegistryAuthToken(t *testing.T) {
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		authHeader := r.Header.Get("Authorization")
		assert.Equal(t, "Bearer test-token-123", authHeader)
		assert.NoError(t, json.NewEncoder(w).Encode(clawhubSearchResponse{Results: nil}))
	}))
	defer srv.Close()

	reg := newTestRegistry(srv.URL, "test-token-123")
	// The header assertion lives in the handler, so the request must really
	// have been served: a discarded Search error would let this test pass
	// without the assertion ever running.
	_, err := reg.Search(context.Background(), "test", 5)
	require.NoError(t, err)
}
// TestExtractZipPathTraversal checks that an archive entry whose name climbs
// out of the target directory is rejected with an "unsafe path" error.
func TestExtractZipPathTraversal(t *testing.T) {
	// Create a ZIP with a path traversal entry.
	var buf bytes.Buffer
	zw := zip.NewWriter(&buf)

	// Malicious entry trying to escape directory.
	w, err := zw.Create("../../etc/passwd")
	require.NoError(t, err)
	// A broken fixture must fail the test here instead of producing a
	// misleading extraction error further down.
	_, err = w.Write([]byte("malicious"))
	require.NoError(t, err)
	require.NoError(t, zw.Close())

	// Write to temp file for ExtractZipFile.
	tmpZip := filepath.Join(t.TempDir(), "bad.zip")
	require.NoError(t, os.WriteFile(tmpZip, buf.Bytes(), 0644))

	tmpDir := t.TempDir()
	err = utils.ExtractZipFile(tmpZip, tmpDir)
	assert.Error(t, err)
	assert.Contains(t, err.Error(), "unsafe path")
}
// TestExtractZipWithSubdirectories checks that entries in nested directories
// are extracted beneath the target directory.
func TestExtractZipWithSubdirectories(t *testing.T) {
	archive := createTestZip(t, map[string]string{
		"SKILL.md":           "root file",
		"scripts/helper.sh":  "#!/bin/bash\necho hello",
		"examples/demo.yaml": "key: value",
	})

	// ExtractZipFile reads from disk, so persist the archive first.
	zipPath := filepath.Join(t.TempDir(), "test.zip")
	require.NoError(t, os.WriteFile(zipPath, archive, 0644))

	targetDir := filepath.Join(t.TempDir(), "my-skill")
	require.NoError(t, utils.ExtractZipFile(zipPath, targetDir))

	// Verify nested file.
	script, err := os.ReadFile(filepath.Join(targetDir, "scripts", "helper.sh"))
	require.NoError(t, err)
	assert.Contains(t, string(script), "#!/bin/bash")
}
// TestClawHubRegistrySearchHTTPError checks that a 500 response from the
// registry surfaces as an error mentioning the status code.
func TestClawHubRegistrySearchHTTPError(t *testing.T) {
	alwaysFail := func(w http.ResponseWriter, _ *http.Request) {
		w.WriteHeader(http.StatusInternalServerError)
		w.Write([]byte("Internal Server Error"))
	}
	srv := httptest.NewServer(http.HandlerFunc(alwaysFail))
	defer srv.Close()

	_, err := newTestRegistry(srv.URL, "").Search(context.Background(), "test", 5)

	assert.Error(t, err)
	assert.Contains(t, err.Error(), "500")
}
func TestClawHubRegistrySearchNullableFields(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
validSlug := "valid-slug"
validSummary := "valid summary"
// Return results with various null/empty fields
json.NewEncoder(w).Encode(clawhubSearchResponse{
Results: []clawhubSearchResult{
// Case 1: Null Slug -> Skip
{Score: 0.1, Slug: nil, DisplayName: nil, Summary: nil, Version: nil},
// Case 2: Valid Slug, Null Summary -> Skip
{Score: 0.2, Slug: &validSlug, DisplayName: nil, Summary: nil, Version: nil},
// Case 3: Valid Slug, Valid Summary, Null Name -> Keep, Name=Slug
{Score: 0.8, Slug: &validSlug, DisplayName: nil, Summary: &validSummary, Version: nil},
},
})
}))
defer srv.Close()
reg := newTestRegistry(srv.URL, "")
results, err := reg.Search(context.Background(), "test", 5)
require.NoError(t, err)
require.Len(t, results, 1, "should only return 1 valid result")
r := results[0]
assert.Equal(t, "valid-slug", r.Slug)
assert.Equal(t, "valid-slug", r.DisplayName, "should fallback name to slug")
assert.Equal(t, "valid summary", r.Summary)
}
// --- helpers ---
// createTestZip builds an in-memory ZIP archive containing one entry per
// element of files (entry name -> content) and returns its bytes. Any failure
// while writing the archive aborts the calling test.
func createTestZip(t *testing.T, files map[string]string) []byte {
	t.Helper()

	archive := new(bytes.Buffer)
	writer := zip.NewWriter(archive)
	for entryName, body := range files {
		entry, err := writer.Create(entryName)
		require.NoError(t, err)

		_, err = entry.Write([]byte(body))
		require.NoError(t, err)
	}
	// Closing flushes the central directory; the archive is invalid without it.
	require.NoError(t, writer.Close())

	return archive.Bytes()
}
+223
View File
@@ -0,0 +1,223 @@
package skills
import (
"context"
"fmt"
"log/slog"
"sync"
"time"
)
// defaultMaxConcurrentSearches is the concurrency limit a RegistryManager
// starts with; NewRegistryManagerFromConfig overrides it only when the
// configured value is positive.
const (
defaultMaxConcurrentSearches = 2
)
// SearchResult represents a single result from a skill registry search.
type SearchResult struct {
// Score is the relevance reported by the registry; SearchAll orders merged
// results by it, highest first.
Score float64 `json:"score"`
// Slug is the identifier used to fetch metadata for and install the skill.
Slug string `json:"slug"`
DisplayName string `json:"display_name"`
Summary string `json:"summary"`
Version string `json:"version"`
// RegistryName is the Name() of the registry that produced the result.
RegistryName string `json:"registry_name"`
}
// SkillMeta holds metadata about a skill from a registry.
type SkillMeta struct {
Slug string `json:"slug"`
DisplayName string `json:"display_name"`
Summary string `json:"summary"`
// LatestVersion is empty when the registry did not report a latest version.
LatestVersion string `json:"latest_version"`
// IsMalwareBlocked and IsSuspicious are the registry's moderation flags;
// both stay false when the registry sent no moderation information.
IsMalwareBlocked bool `json:"is_malware_blocked"`
IsSuspicious bool `json:"is_suspicious"`
// RegistryName is the Name() of the registry the metadata came from.
RegistryName string `json:"registry_name"`
}
// InstallResult is returned by DownloadAndInstall to carry metadata
// back to the caller for moderation and user messaging.
type InstallResult struct {
// Version is the version that was resolved for the install.
Version string
// IsMalwareBlocked reports that the registry flagged the skill as malware;
// callers are expected to refuse the install and remove the target directory.
IsMalwareBlocked bool
// IsSuspicious reports a softer moderation flag that callers surface as a warning.
IsSuspicious bool
// Summary is the skill's short description, when the registry provided one.
Summary string
}
// SkillRegistry is the interface that all skill registries must implement.
// Each registry represents a different source of skills (e.g., clawhub.ai)
//
// RegistryManager.SearchAll calls Search from multiple goroutines (one per
// registry), and RegistryManager.GetRegistry identifies a registry by the
// value Name returns.
type SkillRegistry interface {
// Name returns the unique name of this registry (e.g., "clawhub").
Name() string
// Search searches the registry for skills matching the query.
// limit is the maximum number of results the caller wants.
Search(ctx context.Context, query string, limit int) ([]SearchResult, error)
// GetSkillMeta retrieves metadata for a specific skill by slug.
GetSkillMeta(ctx context.Context, slug string) (*SkillMeta, error)
// DownloadAndInstall fetches metadata, resolves the version, downloads and
// installs the skill to targetDir. Returns an InstallResult with metadata
// for the caller to use for moderation and user messaging.
DownloadAndInstall(ctx context.Context, slug, version, targetDir string) (*InstallResult, error)
}
// RegistryConfig holds configuration for all skill registries.
// This is the input to NewRegistryManagerFromConfig.
type RegistryConfig struct {
// ClawHub configures the ClawHub registry; it is only instantiated when
// ClawHub.Enabled is true.
ClawHub ClawHubConfig
// MaxConcurrentSearches limits how many registries SearchAll queries at
// once; values <= 0 keep the manager's default.
MaxConcurrentSearches int
}
// ClawHubConfig configures the ClawHub registry.
//
// NOTE: callers create this value from the application config with a direct
// struct type conversion, so the field names, types and order here must stay
// identical to the config package's ClawHub registry struct.
type ClawHubConfig struct {
Enabled bool
BaseURL string
AuthToken string
SearchPath string // e.g. "/api/v1/search"
SkillsPath string // e.g. "/api/v1/skills"
DownloadPath string // e.g. "/api/v1/download"
Timeout int // seconds, 0 = default (30s)
MaxZipSize int // bytes, 0 = default (50MB)
MaxResponseSize int // bytes, 0 = default (2MB)
}
// RegistryManager coordinates multiple skill registries.
// It fans out search requests and routes installs to the correct registry.
type RegistryManager struct {
// registries is the list searched by SearchAll and scanned by GetRegistry.
registries []SkillRegistry
// maxConcurrent is the capacity of the semaphore SearchAll uses to bound
// the number of simultaneous registry searches.
maxConcurrent int
// mu guards registries.
mu sync.RWMutex
}
// NewRegistryManager returns a RegistryManager with no registries and the
// default search concurrency limit.
func NewRegistryManager() *RegistryManager {
	rm := new(RegistryManager)
	rm.registries = make([]SkillRegistry, 0)
	rm.maxConcurrent = defaultMaxConcurrentSearches
	return rm
}
// NewRegistryManagerFromConfig creates a RegistryManager configured from cfg.
// A positive MaxConcurrentSearches replaces the default concurrency limit,
// and only registries marked as enabled are instantiated and added.
func NewRegistryManagerFromConfig(cfg RegistryConfig) *RegistryManager {
	manager := NewRegistryManager()

	if limit := cfg.MaxConcurrentSearches; limit > 0 {
		manager.maxConcurrent = limit
	}

	if !cfg.ClawHub.Enabled {
		return manager
	}
	manager.AddRegistry(NewClawHubRegistry(cfg.ClawHub))
	return manager
}
// AddRegistry appends r to the set of registries managed by rm. The write
// lock is held while the list is modified.
func (rm *RegistryManager) AddRegistry(r SkillRegistry) {
	rm.mu.Lock()
	rm.registries = append(rm.registries, r)
	rm.mu.Unlock()
}
// GetRegistry looks up a registry by the value its Name method returns.
// The first match in registration order wins; nil is returned when no
// registry has that name.
func (rm *RegistryManager) GetRegistry(name string) SkillRegistry {
	rm.mu.RLock()
	defer rm.mu.RUnlock()

	for i := range rm.registries {
		if candidate := rm.registries[i]; candidate.Name() == name {
			return candidate
		}
	}
	return nil
}
// SearchAll fans out the query to all registries concurrently
// and merges results sorted by score descending.
//
// At most maxConcurrent registries are searched at the same time, and each
// search is given at most one minute. A failing registry is logged and
// skipped; an error is returned only when no registries are configured or
// when every registry failed (wrapping the error of the last registry in
// registration order). When limit is positive the merged list is truncated
// to limit entries.
//
// Results are merged in registration order before sorting, so results with
// equal scores come back in a deterministic order regardless of which
// registry answered first.
func (rm *RegistryManager) SearchAll(ctx context.Context, query string, limit int) ([]SearchResult, error) {
	// Work on a snapshot so the lock is not held during network calls.
	rm.mu.RLock()
	regs := make([]SkillRegistry, len(rm.registries))
	copy(regs, rm.registries)
	rm.mu.RUnlock()

	if len(regs) == 0 {
		return nil, fmt.Errorf("no registries configured")
	}

	// A non-positive limit (e.g. a zero-value manager) would create an
	// unbuffered semaphore that no goroutine could ever acquire.
	maxConcurrent := rm.maxConcurrent
	if maxConcurrent <= 0 {
		maxConcurrent = defaultMaxConcurrentSearches
	}

	type regResult struct {
		results []SearchResult
		err     error
	}

	// One slot per registry: every goroutine writes only its own slot, and
	// wg.Wait orders those writes before the reads below.
	outcomes := make([]regResult, len(regs))

	// Semaphore: limit concurrency.
	sem := make(chan struct{}, maxConcurrent)

	var wg sync.WaitGroup
	for i, reg := range regs {
		wg.Add(1)
		go func(i int, r SkillRegistry) {
			defer wg.Done()

			// Acquire semaphore slot, giving up if the caller cancels first.
			select {
			case sem <- struct{}{}:
				defer func() { <-sem }()
			case <-ctx.Done():
				outcomes[i].err = ctx.Err()
				return
			}

			searchCtx, cancel := context.WithTimeout(ctx, 1*time.Minute)
			defer cancel()

			results, err := r.Search(searchCtx, query, limit)
			if err != nil {
				slog.Warn("registry search failed", "registry", r.Name(), "error", err)
				outcomes[i].err = err
				return
			}
			outcomes[i].results = results
		}(i, reg)
	}
	wg.Wait()

	var merged []SearchResult
	var lastErr error
	anyRegistrySucceeded := false
	for _, outcome := range outcomes {
		if outcome.err != nil {
			lastErr = outcome.err
			continue
		}
		anyRegistrySucceeded = true
		merged = append(merged, outcome.results...)
	}

	// If all registries failed, return the last error.
	if !anyRegistrySucceeded && lastErr != nil {
		return nil, fmt.Errorf("all registries failed: %w", lastErr)
	}

	// Sort by score descending.
	sortByScoreDesc(merged)

	// Clamp to limit.
	if limit > 0 && len(merged) > limit {
		merged = merged[:limit]
	}

	return merged, nil
}
// sortByScoreDesc reorders results in place so that higher scores come first.
// It is an insertion sort (the slices involved are small): each element is
// moved towards the front past every neighbour with a strictly lower score,
// so elements with equal scores keep their relative order.
func sortByScoreDesc(results []SearchResult) {
	for next := 1; next < len(results); next++ {
		for pos := next; pos > 0 && results[pos-1].Score < results[pos].Score; pos-- {
			results[pos-1], results[pos] = results[pos], results[pos-1]
		}
	}
}
+179
View File
@@ -0,0 +1,179 @@
package skills
import (
"context"
"fmt"
"testing"
"time"
"github.com/sipeed/picoclaw/pkg/utils"
"github.com/stretchr/testify/assert"
)
// mockRegistry is a test double implementing SkillRegistry.
// Every method ignores its arguments and returns the canned values stored in
// the corresponding fields.
type mockRegistry struct {
name string // returned by Name
searchResults []SearchResult // returned by Search
searchErr error // returned by Search
meta *SkillMeta // returned by GetSkillMeta
metaErr error // returned by GetSkillMeta
installResult *InstallResult // returned by DownloadAndInstall
installErr error // returned by DownloadAndInstall
}
// Name returns the configured registry name.
func (m *mockRegistry) Name() string { return m.name }
// Search returns the canned search results and error.
func (m *mockRegistry) Search(_ context.Context, _ string, _ int) ([]SearchResult, error) {
return m.searchResults, m.searchErr
}
// GetSkillMeta returns the canned metadata and error.
func (m *mockRegistry) GetSkillMeta(_ context.Context, _ string) (*SkillMeta, error) {
return m.meta, m.metaErr
}
// DownloadAndInstall returns the canned install result and error without
// touching the file system.
func (m *mockRegistry) DownloadAndInstall(_ context.Context, _, _, _ string) (*InstallResult, error) {
return m.installResult, m.installErr
}
func TestRegistryManagerSearchAllSingle(t *testing.T) {
mgr := NewRegistryManager()
mgr.AddRegistry(&mockRegistry{
name: "test",
searchResults: []SearchResult{
{Slug: "skill-a", Score: 0.9, RegistryName: "test"},
{Slug: "skill-b", Score: 0.5, RegistryName: "test"},
},
})
results, err := mgr.SearchAll(context.Background(), "test query", 10)
assert.NoError(t, err)
assert.Len(t, results, 2)
assert.Equal(t, "skill-a", results[0].Slug)
}
func TestRegistryManagerSearchAllMultiple(t *testing.T) {
mgr := NewRegistryManager()
mgr.AddRegistry(&mockRegistry{
name: "alpha",
searchResults: []SearchResult{
{Slug: "skill-a", Score: 0.8, RegistryName: "alpha"},
},
})
mgr.AddRegistry(&mockRegistry{
name: "beta",
searchResults: []SearchResult{
{Slug: "skill-b", Score: 0.95, RegistryName: "beta"},
},
})
results, err := mgr.SearchAll(context.Background(), "test query", 10)
assert.NoError(t, err)
assert.Len(t, results, 2)
// Should be sorted by score descending
assert.Equal(t, "skill-b", results[0].Slug)
assert.Equal(t, "skill-a", results[1].Slug)
}
func TestRegistryManagerSearchAllOneFailsGracefully(t *testing.T) {
mgr := NewRegistryManager()
mgr.AddRegistry(&mockRegistry{
name: "failing",
searchErr: fmt.Errorf("network error"),
})
mgr.AddRegistry(&mockRegistry{
name: "working",
searchResults: []SearchResult{
{Slug: "skill-a", Score: 0.8, RegistryName: "working"},
},
})
results, err := mgr.SearchAll(context.Background(), "test query", 10)
assert.NoError(t, err)
assert.Len(t, results, 1)
assert.Equal(t, "skill-a", results[0].Slug)
}
func TestRegistryManagerSearchAllAllFail(t *testing.T) {
mgr := NewRegistryManager()
mgr.AddRegistry(&mockRegistry{
name: "fail-1",
searchErr: fmt.Errorf("error 1"),
})
_, err := mgr.SearchAll(context.Background(), "test query", 10)
assert.Error(t, err)
}
// TestRegistryManagerSearchAllNoRegistries checks that searching an empty
// manager is reported as an error.
func TestRegistryManagerSearchAllNoRegistries(t *testing.T) {
	emptyMgr := NewRegistryManager()

	_, err := emptyMgr.SearchAll(context.Background(), "test query", 10)

	assert.Error(t, err)
}
// TestRegistryManagerGetRegistry checks lookup by name: a registered name
// resolves to that registry and an unknown name yields nil.
func TestRegistryManagerGetRegistry(t *testing.T) {
	mgr := NewRegistryManager()
	mgr.AddRegistry(&mockRegistry{name: "clawhub"})

	known := mgr.GetRegistry("clawhub")
	assert.NotNil(t, known)
	assert.Equal(t, "clawhub", known.Name())

	unknown := mgr.GetRegistry("nonexistent")
	assert.Nil(t, unknown)
}
func TestRegistryManagerSearchAllRespectLimit(t *testing.T) {
mgr := NewRegistryManager()
results := make([]SearchResult, 20)
for i := range results {
results[i] = SearchResult{Slug: fmt.Sprintf("skill-%d", i), Score: float64(20 - i)}
}
mgr.AddRegistry(&mockRegistry{
name: "test",
searchResults: results,
})
got, err := mgr.SearchAll(context.Background(), "test", 5)
assert.NoError(t, err)
assert.Len(t, got, 5)
// Top scores first
assert.Equal(t, "skill-0", got[0].Slug)
}
func TestRegistryManagerSearchAllTimeout(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Millisecond)
defer cancel()
time.Sleep(5 * time.Millisecond) // Let context expire.
mgr := NewRegistryManager()
mgr.AddRegistry(&mockRegistry{
name: "slow",
searchErr: fmt.Errorf("context deadline exceeded"),
})
_, err := mgr.SearchAll(ctx, "test", 5)
assert.Error(t, err)
}
// TestSortByScoreDesc checks that sortByScoreDesc reorders its argument in
// place so the highest score comes first.
func TestSortByScoreDesc(t *testing.T) {
	shuffled := []SearchResult{
		{Slug: "c", Score: 0.3},
		{Slug: "a", Score: 0.9},
		{Slug: "b", Score: 0.5},
	}

	sortByScoreDesc(shuffled)

	for pos, wantSlug := range []string{"a", "b", "c"} {
		assert.Equal(t, wantSlug, shuffled[pos].Slug)
	}
}
// TestIsSafeSlug checks utils.ValidateSkillIdentifier: plain slugs pass, while
// empty input and anything containing a path separator or ".." is rejected.
func TestIsSafeSlug(t *testing.T) {
	for _, slug := range []string{"github", "docker-compose"} {
		assert.NoError(t, utils.ValidateSkillIdentifier(slug))
	}

	for _, slug := range []string{"", "../etc/passwd", "path/traversal", "path\\traversal"} {
		assert.Error(t, utils.ValidateSkillIdentifier(slug))
	}
}
+229
View File
@@ -0,0 +1,229 @@
package skills
import (
"sort"
"strings"
"sync"
"time"
)
// SearchCache provides lightweight caching for search results.
// It uses trigram-based similarity to match similar queries to cached results,
// avoiding redundant API calls. Thread-safe for concurrent access.
//
// Entries are keyed by the normalized (trimmed, lower-cased) query and are
// treated as expired once they are older than ttl. When the cache is full,
// Put evicts from the front of order (the least recently used entry).
type SearchCache struct {
// mu guards entries and order; Get, Put and Len all acquire it.
mu sync.RWMutex
// entries maps a normalized query to its cached results.
entries map[string]*cacheEntry
order []string // LRU order: oldest first.
// maxEntries is the capacity enforced by Put.
maxEntries int
// ttl is the maximum age of an entry before it is treated as expired.
ttl time.Duration
}
// cacheEntry is one cached search: the query it was stored under, that
// query's trigram set for similarity matching, and a private copy of the results.
type cacheEntry struct {
query string // Normalized query; also the key in SearchCache.entries.
trigrams []uint32 // Sorted, de-duplicated trigram hashes produced by buildTrigrams.
results []SearchResult // Copy made by copyResults when the entry was stored.
createdAt time.Time // Time of the Put that wrote this entry; expiry is measured from here.
}
// similarityThreshold is the minimum trigram Jaccard similarity for a cache hit.
// The bound is inclusive: Get accepts a match whose similarity is >= this value.
const similarityThreshold = 0.7
// NewSearchCache builds an empty search cache.
//
// maxEntries caps how many queries are kept; once the cap is reached the least
// recently used entry is evicted. A non-positive value selects the default of 50.
// ttl is how long an entry stays valid; a non-positive value selects 5 minutes.
func NewSearchCache(maxEntries int, ttl time.Duration) *SearchCache {
	const (
		defaultMaxEntries = 50
		defaultTTL        = 5 * time.Minute
	)

	cache := &SearchCache{
		entries:    make(map[string]*cacheEntry),
		order:      make([]string, 0),
		maxEntries: defaultMaxEntries,
		ttl:        defaultTTL,
	}
	if maxEntries > 0 {
		cache.maxEntries = maxEntries
	}
	if ttl > 0 {
		cache.ttl = ttl
	}
	return cache
}
// Get looks up results for a query. It returns a copy of the cached results
// and true when an unexpired entry matches, either exactly (after
// normalization) or by trigram Jaccard similarity of at least
// similarityThreshold. It returns nil, false on a miss or for a blank query.
//
// A hit marks the matched entry as most recently used.
//
// Queries shorter than three bytes have no trigrams and can only match
// exactly: jaccardSimilarity treats two empty trigram sets as identical, so
// without this restriction "go" and "py" would be served each other's results.
func (sc *SearchCache) Get(query string) ([]SearchResult, bool) {
	normalized := normalizeQuery(query)
	if normalized == "" {
		return nil, false
	}

	// A write lock is needed even for reads because a hit reorders the LRU list.
	sc.mu.Lock()
	defer sc.mu.Unlock()

	// Exact match first.
	if entry, ok := sc.entries[normalized]; ok && time.Since(entry.createdAt) < sc.ttl {
		sc.moveToEndLocked(normalized)
		return copyResults(entry.results), true
	}

	// Similarity match; meaningless when the query yields no trigrams.
	queryTrigrams := buildTrigrams(normalized)
	if len(queryTrigrams) == 0 {
		return nil, false
	}

	var bestEntry *cacheEntry
	var bestSim float64
	for _, entry := range sc.entries {
		if time.Since(entry.createdAt) >= sc.ttl {
			continue // Skip expired.
		}
		if sim := jaccardSimilarity(queryTrigrams, entry.trigrams); sim > bestSim {
			bestSim = sim
			bestEntry = entry
		}
	}

	if bestEntry != nil && bestSim >= similarityThreshold {
		sc.moveToEndLocked(bestEntry.query)
		return copyResults(bestEntry.results), true
	}
	return nil, false
}
// Put stores a copy of results under the normalized query, ignoring blank
// queries. Expired entries are purged first. Re-putting an existing query
// replaces its entry and marks it most recently used; a new query evicts the
// least recently used entries until there is room.
func (sc *SearchCache) Put(query string, results []SearchResult) {
	key := normalizeQuery(query)
	if key == "" {
		return
	}

	sc.mu.Lock()
	defer sc.mu.Unlock()

	sc.evictExpiredLocked()

	_, exists := sc.entries[key]
	if !exists {
		// Make room: drop from the front of the LRU list while at capacity.
		for len(sc.entries) >= sc.maxEntries && len(sc.order) > 0 {
			victim := sc.order[0]
			sc.order = sc.order[1:]
			delete(sc.entries, victim)
		}
	}

	sc.entries[key] = &cacheEntry{
		query:     key,
		trigrams:  buildTrigrams(key),
		results:   copyResults(results),
		createdAt: time.Now(),
	}

	if exists {
		sc.moveToEndLocked(key)
		return
	}
	sc.order = append(sc.order, key)
}
// Len reports how many entries the cache currently holds, including any that
// have expired but have not been purged yet. Intended for tests.
func (sc *SearchCache) Len() int {
	sc.mu.RLock()
	defer sc.mu.RUnlock()

	count := len(sc.entries)
	return count
}
// --- internal ---
// evictExpiredLocked drops every entry whose age has reached ttl, along with
// any key in the LRU list that has no entry behind it, and rebuilds the list
// from the survivors in their existing order. The caller must hold sc.mu.
func (sc *SearchCache) evictExpiredLocked() {
	cutoff := time.Now()
	live := make([]string, 0, len(sc.order))

	for _, key := range sc.order {
		if entry, found := sc.entries[key]; found && cutoff.Sub(entry.createdAt) < sc.ttl {
			live = append(live, key)
			continue
		}
		delete(sc.entries, key)
	}

	sc.order = live
}
// moveToEndLocked marks key as most recently used: its first occurrence in
// the LRU list, if any, is removed and key is appended at the end.
// The caller must hold sc.mu.
func (sc *SearchCache) moveToEndLocked(key string) {
	for idx := range sc.order {
		if sc.order[idx] != key {
			continue
		}
		sc.order = append(sc.order[:idx], sc.order[idx+1:]...)
		break
	}
	sc.order = append(sc.order, key)
}
// normalizeQuery canonicalizes a query for use as a cache key by stripping
// surrounding whitespace and lower-casing what remains.
func normalizeQuery(q string) string {
	trimmed := strings.TrimSpace(q)
	return strings.ToLower(trimmed)
}
// buildTrigrams returns the set of byte trigrams of s, each packed into a
// uint32 (first byte in bits 16-23, second in bits 8-15, third in bits 0-7).
// The result is sorted ascending with duplicates removed, which is the form
// jaccardSimilarity expects. Strings shorter than three bytes yield nil.
//
// Example: "hello" has the trigrams "hel", "ell", "llo"; "hel" packs to 0x0068656c.
func buildTrigrams(s string) []uint32 {
	if len(s) < 3 {
		return nil
	}

	hashes := make([]uint32, len(s)-2)
	for start := range hashes {
		hashes[start] = uint32(s[start])<<16 | uint32(s[start+1])<<8 | uint32(s[start+2])
	}

	sort.Slice(hashes, func(x, y int) bool { return hashes[x] < hashes[y] })

	// Compact in place: after sorting, equal values are adjacent, so a value
	// is new exactly when it differs from the last one kept.
	unique := hashes[:1]
	for _, h := range hashes[1:] {
		if h != unique[len(unique)-1] {
			unique = append(unique, h)
		}
	}
	return unique
}
// jaccardSimilarity computes |A ∩ B| / |A ∪ B| for two trigram sets.
// Both inputs must be sorted ascending without duplicates (the form produced
// by buildTrigrams); the intersection is counted with a single merge pass.
// Two empty sets are defined as identical (1); one empty set gives 0.
func jaccardSimilarity(a, b []uint32) float64 {
	if len(a)+len(b) == 0 {
		return 1
	}

	shared := 0
	for ai, bi := 0, 0; ai < len(a) && bi < len(b); {
		switch {
		case a[ai] == b[bi]:
			shared++
			ai++
			bi++
		case a[ai] < b[bi]:
			ai++
		default:
			bi++
		}
	}

	total := len(a) + len(b) - shared
	return float64(shared) / float64(total)
}
// copyResults returns a shallow copy of results backed by its own array, so
// later changes to the caller's slice elements do not reach the cache (and
// vice versa). A nil input stays nil.
func copyResults(results []SearchResult) []SearchResult {
	if results == nil {
		return nil
	}
	return append(make([]SearchResult, 0, len(results)), results...)
}
+200
View File
@@ -0,0 +1,200 @@
package skills
import (
"testing"
"time"
"github.com/stretchr/testify/assert"
)
func TestSearchCacheExactHit(t *testing.T) {
cache := NewSearchCache(10, 5*time.Minute)
results := []SearchResult{
{Slug: "github", Score: 0.9, RegistryName: "clawhub"},
{Slug: "docker", Score: 0.7, RegistryName: "clawhub"},
}
cache.Put("github integration", results)
got, hit := cache.Get("github integration")
assert.True(t, hit)
assert.Len(t, got, 2)
assert.Equal(t, "github", got[0].Slug)
}
// TestSearchCacheExactHitCaseInsensitive checks that lookups ignore letter case.
func TestSearchCacheExactHitCaseInsensitive(t *testing.T) {
	cache := NewSearchCache(10, 5*time.Minute)
	cache.Put("GitHub Integration", []SearchResult{{Slug: "github", Score: 0.9}})

	got, hit := cache.Get("github integration")

	assert.True(t, hit)
	assert.Len(t, got, 1)
}
// TestSearchCacheSimilarHit checks that a query close enough to a cached one
// is served from the cache via trigram similarity.
func TestSearchCacheSimilarHit(t *testing.T) {
	cache := NewSearchCache(10, 5*time.Minute)
	cache.Put("github integration tool", []SearchResult{{Slug: "github", Score: 0.9}})

	// The shorter query shares most of its trigrams with the cached one.
	got, hit := cache.Get("github integration")

	assert.True(t, hit)
	assert.Len(t, got, 1)
}
// TestSearchCacheDissimilarMiss checks that an unrelated query is not served
// from another query's cache entry.
func TestSearchCacheDissimilarMiss(t *testing.T) {
	cache := NewSearchCache(10, 5*time.Minute)
	cache.Put("github integration", []SearchResult{{Slug: "github", Score: 0.9}})

	_, hit := cache.Get("database management")

	assert.False(t, hit)
}
// TestSearchCacheTTLExpiration checks that an entry is served while fresh and
// becomes a miss once its TTL has elapsed.
func TestSearchCacheTTLExpiration(t *testing.T) {
	const query = "github integration"

	cache := NewSearchCache(10, 50*time.Millisecond)
	cache.Put(query, []SearchResult{{Slug: "github", Score: 0.9}})

	// Fresh entry: must hit.
	_, fresh := cache.Get(query)
	assert.True(t, fresh)

	// Sleep past the 50ms TTL.
	time.Sleep(100 * time.Millisecond)

	_, stale := cache.Get(query)
	assert.False(t, stale)
}
// TestSearchCacheLRUEviction checks that inserting into a full cache evicts
// the oldest entry and keeps the size at capacity.
func TestSearchCacheLRUEviction(t *testing.T) {
	const capacity = 3
	cache := NewSearchCache(capacity, 5*time.Minute)

	cache.Put("query-1", []SearchResult{{Slug: "a"}})
	cache.Put("query-2", []SearchResult{{Slug: "b"}})
	cache.Put("query-3", []SearchResult{{Slug: "c"}})
	assert.Equal(t, 3, cache.Len())

	// A fourth entry pushes out query-1, the oldest.
	cache.Put("query-4", []SearchResult{{Slug: "d"}})
	assert.Equal(t, 3, cache.Len())

	_, oldestHit := cache.Get("query-1")
	assert.False(t, oldestHit, "oldest entry should be evicted")

	newest, newestHit := cache.Get("query-4")
	assert.True(t, newestHit)
	assert.Equal(t, "d", newest[0].Slug)
}
// TestSearchCacheEmptyQuery checks that empty and whitespace-only queries
// never produce a cache hit.
func TestSearchCacheEmptyQuery(t *testing.T) {
	cache := NewSearchCache(10, 5*time.Minute)

	for _, blank := range []string{"", " "} {
		_, hit := cache.Get(blank)
		assert.False(t, hit)
	}
}
// TestSearchCacheResultsCopied checks that Put stores a copy: mutating the
// caller's slice afterwards must not change what Get returns.
func TestSearchCacheResultsCopied(t *testing.T) {
	cache := NewSearchCache(10, 5*time.Minute)
	input := []SearchResult{{Slug: "github", Score: 0.9}}
	cache.Put("test", input)

	// Tamper with the caller-owned slice after it has been cached.
	input[0].Slug = "mutated"

	got, hit := cache.Get("test")

	assert.True(t, hit)
	assert.Equal(t, "github", got[0].Slug, "cache should hold a copy, not a reference")
}
// TestBuildTrigrams checks that "hello" produces exactly the three packed
// trigrams "hel", "ell" and "llo".
func TestBuildTrigrams(t *testing.T) {
	pack := func(a, b, c byte) uint32 {
		return uint32(a)<<16 | uint32(b)<<8 | uint32(c)
	}

	got := buildTrigrams("hello")

	assert.Contains(t, got, pack('h', 'e', 'l'))
	assert.Contains(t, got, pack('e', 'l', 'l'))
	assert.Contains(t, got, pack('l', 'l', 'o'))
	assert.Len(t, got, 3)
}
// TestJaccardSimilarity checks that closely related queries score above 0.5
// and unrelated ones below 0.3.
func TestJaccardSimilarity(t *testing.T) {
	base := buildTrigrams("github integration")

	near := jaccardSimilarity(base, buildTrigrams("github integration tool"))
	assert.Greater(t, near, 0.5, "similar strings should have high sim")

	far := jaccardSimilarity(base, buildTrigrams("completely different query about databases"))
	assert.Less(t, far, 0.3, "dissimilar strings should have low sim")
}
// TestJaccardSimilarityEdgeCases pins the behavior for empty trigram sets:
// two empty sets are identical, and an empty set shares nothing with a
// non-empty one regardless of argument order.
func TestJaccardSimilarityEdgeCases(t *testing.T) {
	none := buildTrigrams("")
	some := buildTrigrams("hello")

	assert.Equal(t, 1.0, jaccardSimilarity(none, none))
	assert.Equal(t, 0.0, jaccardSimilarity(none, some))
	assert.Equal(t, 0.0, jaccardSimilarity(some, none))
}
// TestSearchCacheConcurrency runs a writer and a reader against the same
// cache at once. It has no assertions of its own; its value comes from
// running under the race detector (go test -race).
//
// The test waits for BOTH goroutines. Receiving only once from the channel
// would let the test finish while one worker is still running (or has not
// started), and would leave that goroutine blocked forever on its send.
func TestSearchCacheConcurrency(t *testing.T) {
	const workers = 2

	cache := NewSearchCache(50, 5*time.Minute)
	// Buffered so a worker never blocks on its completion signal.
	done := make(chan struct{}, workers)

	// Concurrent writes
	go func() {
		for i := 0; i < 100; i++ {
			cache.Put("query-write-"+string(rune('a'+i%26)), []SearchResult{{Slug: "x"}})
		}
		done <- struct{}{}
	}()

	// Concurrent reads
	go func() {
		for i := 0; i < 100; i++ {
			cache.Get("query-write-a")
		}
		done <- struct{}{}
	}()

	for i := 0; i < workers; i++ {
		<-done
	}
}
// TestSearchCacheLRUUpdateOnGet checks that a successful Get refreshes an
// entry's position in the LRU order, so a later eviction removes a different,
// less recently used entry.
func TestSearchCacheLRUUpdateOnGet(t *testing.T) {
	cache := NewSearchCache(3, time.Hour)

	// Fill to capacity. The queries are long enough to produce trigrams yet
	// differ enough from one another to stay below the similarity threshold.
	cache.Put("query-A", []SearchResult{{Slug: "A"}})
	cache.Put("query-B", []SearchResult{{Slug: "B"}})
	cache.Put("query-C", []SearchResult{{Slug: "C"}})

	// Touch query-A so that query-B becomes the least recently used.
	_, hitA := cache.Get("query-A")
	if !hitA {
		t.Fatal("query-A should be in cache")
	}

	// The fourth insert must evict query-B, not the refreshed query-A.
	cache.Put("query-D", []SearchResult{{Slug: "D"}})

	_, stillA := cache.Get("query-A")
	if !stillA {
		t.Fatalf("query-A was evicted! valid LRU should have kept query-A and evicted query-B.")
	}

	_, stillB := cache.Get("query-B")
	if stillB {
		t.Fatal("query-B should have been evicted")
	}
}
+199
View File
@@ -0,0 +1,199 @@
package tools
import (
"context"
"encoding/json"
"fmt"
"os"
"path/filepath"
"sync"
"time"
"github.com/sipeed/picoclaw/pkg/logger"
"github.com/sipeed/picoclaw/pkg/skills"
"github.com/sipeed/picoclaw/pkg/utils"
)
// InstallSkillTool allows the LLM agent to install skills from registries.
// It shares the same RegistryManager that FindSkillsTool uses,
// so all registries configured in config are available for installation.
type InstallSkillTool struct {
registryMgr *skills.RegistryManager // Queried by name in Execute to find the source registry.
workspace string // Root directory; skills are installed under <workspace>/skills/<slug>.
mu sync.Mutex // Held for the whole of Execute so installs through this tool run one at a time.
}
// NewInstallSkillTool returns an InstallSkillTool bound to a registry manager
// and a workspace.
//
// registryMgr is the shared registry manager (the same instance FindSkillsTool
// uses). workspace is the workspace root; skills are installed into
// {workspace}/skills/{slug}/.
func NewInstallSkillTool(registryMgr *skills.RegistryManager, workspace string) *InstallSkillTool {
	// The mutex is left at its zero value, which is an unlocked mutex.
	tool := &InstallSkillTool{
		registryMgr: registryMgr,
		workspace:   workspace,
	}
	return tool
}
// Name returns the identifier under which this tool is exposed.
func (t *InstallSkillTool) Name() string {
	const toolName = "install_skill"
	return toolName
}
// Description returns the human-readable summary of what the tool does.
func (t *InstallSkillTool) Description() string {
	const description = "Install a skill from a registry by slug. Downloads and extracts the skill into the workspace. Use find_skills first to discover available skills."
	return description
}
// Parameters returns the JSON-schema-style description of the tool's
// arguments: slug and registry are required strings, version is an optional
// string and force is an optional boolean.
func (t *InstallSkillTool) Parameters() map[string]interface{} {
	// property builds one schema entry from its type and description.
	property := func(kind, description string) map[string]interface{} {
		return map[string]interface{}{
			"type":        kind,
			"description": description,
		}
	}

	properties := map[string]interface{}{
		"slug":     property("string", "The unique slug of the skill to install (e.g., 'github', 'docker-compose')"),
		"version":  property("string", "Specific version to install (optional, defaults to latest)"),
		"registry": property("string", "Registry to install from (required, e.g., 'clawhub')"),
		"force":    property("boolean", "Force reinstall if skill already exists (default false)"),
	}

	return map[string]interface{}{
		"type":       "object",
		"properties": properties,
		"required":   []string{"slug", "registry"},
	}
}
// Execute installs the skill named by args["slug"] from the registry named by
// args["registry"] into {workspace}/skills/{slug}.
//
// Recognized arguments:
//   - slug (string, required): skill identifier; must pass utils.ValidateSkillIdentifier.
//   - registry (string, required): registry name; validated the same way.
//   - version (string, optional): passed through to the registry unchanged.
//   - force (bool, optional): replace an existing installation.
//
// It returns an error result when validation fails, when the skill is already
// installed and force is not set, when the registry is unknown, when the
// download fails, or when the registry flags the skill as malware. On download
// failure or a malware block the target directory is removed again. A failure
// to write the origin metadata file is logged but does not fail the install.
//
// With force set, the existing installation is removed only AFTER the registry
// has been resolved, so a mistyped registry name cannot destroy a working skill.
func (t *InstallSkillTool) Execute(ctx context.Context, args map[string]interface{}) *ToolResult {
	// Install lock to prevent concurrent directory operations.
	// Ideally this would be held per slug; currently it serializes every
	// install that goes through this tool instance.
	t.mu.Lock()
	defer t.mu.Unlock()

	// Validate slug
	slug, _ := args["slug"].(string)
	if err := utils.ValidateSkillIdentifier(slug); err != nil {
		return ErrorResult(fmt.Sprintf("invalid slug %q: error: %s", slug, err.Error()))
	}

	// Validate registry
	registryName, _ := args["registry"].(string)
	if err := utils.ValidateSkillIdentifier(registryName); err != nil {
		return ErrorResult(fmt.Sprintf("invalid registry %q: error: %s", registryName, err.Error()))
	}

	version, _ := args["version"].(string)
	force, _ := args["force"].(bool)

	skillsDir := filepath.Join(t.workspace, "skills")
	targetDir := filepath.Join(skillsDir, slug)

	// Without force, an existing installation is never touched.
	if !force {
		if _, err := os.Stat(targetDir); err == nil {
			return ErrorResult(fmt.Sprintf("skill %q already installed at %s. Use force=true to reinstall.", slug, targetDir))
		}
	}

	// Resolve the registry before doing anything destructive.
	registry := t.registryMgr.GetRegistry(registryName)
	if registry == nil {
		return ErrorResult(fmt.Sprintf("registry %q not found", registryName))
	}

	// Force: remove the existing installation, if present.
	if force {
		if err := os.RemoveAll(targetDir); err != nil {
			return ErrorResult(fmt.Sprintf("failed to remove existing skill %q: %v", slug, err))
		}
	}

	// Ensure skills directory exists.
	if err := os.MkdirAll(skillsDir, 0755); err != nil {
		return ErrorResult(fmt.Sprintf("failed to create skills directory: %v", err))
	}

	// removePartial deletes whatever the registry wrote to targetDir; a
	// failure here is only logged because the caller is already returning an error.
	removePartial := func() {
		if rmErr := os.RemoveAll(targetDir); rmErr != nil {
			logger.ErrorCF("tool", "Failed to remove partial install",
				map[string]interface{}{
					"tool":       "install_skill",
					"target_dir": targetDir,
					"error":      rmErr.Error(),
				})
		}
	}

	// Download and install (handles metadata, version resolution, extraction).
	result, err := registry.DownloadAndInstall(ctx, slug, version, targetDir)
	if err != nil {
		removePartial()
		return ErrorResult(fmt.Sprintf("failed to install %q: %v", slug, err))
	}
	if result == nil {
		// Defensive: a registry returning (nil, nil) would otherwise panic below.
		removePartial()
		return ErrorResult(fmt.Sprintf("failed to install %q: registry returned no result", slug))
	}

	// Moderation: block malware.
	if result.IsMalwareBlocked {
		removePartial()
		return ErrorResult(fmt.Sprintf("skill %q is flagged as malicious and cannot be installed", slug))
	}

	// Write origin metadata. Best effort: the skill itself is already installed.
	if err := writeOriginMeta(targetDir, registry.Name(), slug, result.Version); err != nil {
		logger.ErrorCF("tool", "Failed to write origin metadata",
			map[string]interface{}{
				"tool":     "install_skill",
				"error":    err.Error(),
				"target":   targetDir,
				"registry": registry.Name(),
				"slug":     slug,
				"version":  result.Version,
			})
	}

	// Build result with moderation warning if suspicious.
	var output string
	if result.IsSuspicious {
		output = fmt.Sprintf("⚠️ Warning: skill %q is flagged as suspicious (may contain risky patterns).\n\n", slug)
	}
	output += fmt.Sprintf("Successfully installed skill %q v%s from %s registry.\nLocation: %s\n",
		slug, result.Version, registry.Name(), targetDir)
	if result.Summary != "" {
		output += fmt.Sprintf("Description: %s\n", result.Summary)
	}
	output += "\nThe skill is now available and can be loaded in the current session."

	return SilentResult(output)
}
// originMeta tracks which registry a skill was installed from.
// writeOriginMeta serializes it as JSON into the skill's directory.
type originMeta struct {
Version int `json:"version"` // Version of this metadata format; writeOriginMeta sets it to 1.
Registry string `json:"registry"` // Name of the registry the skill came from.
Slug string `json:"slug"` // Slug the skill was installed under.
InstalledVersion string `json:"installed_version"` // Skill version reported by the install.
InstalledAt int64 `json:"installed_at"` // Install time in Unix milliseconds.
}
// writeOriginMeta records where an installed skill came from by writing
// .skill-origin.json (mode 0644) into targetDir. The file holds the registry
// name, slug, installed version and the current time in Unix milliseconds.
// It returns any marshalling or file-write error.
func writeOriginMeta(targetDir, registryName, slug, version string) error {
	payload, err := json.MarshalIndent(originMeta{
		Version:          1,
		Registry:         registryName,
		Slug:             slug,
		InstalledVersion: version,
		InstalledAt:      time.Now().UnixMilli(),
	}, "", " ")
	if err != nil {
		return err
	}

	metaPath := filepath.Join(targetDir, ".skill-origin.json")
	return os.WriteFile(metaPath, payload, 0644)
}
+103
View File
@@ -0,0 +1,103 @@
package tools
import (
"context"
"os"
"path/filepath"
"testing"
"github.com/sipeed/picoclaw/pkg/skills"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestInstallSkillToolName pins the tool's registered name.
func TestInstallSkillToolName(t *testing.T) {
	installer := NewInstallSkillTool(skills.NewRegistryManager(), t.TempDir())

	assert.Equal(t, "install_skill", installer.Name())
}
// TestInstallSkillToolMissingSlug checks that calling the tool without a slug
// argument is rejected with the "identifier is required" message.
func TestInstallSkillToolMissingSlug(t *testing.T) {
	installer := NewInstallSkillTool(skills.NewRegistryManager(), t.TempDir())
	noArgs := map[string]interface{}{}

	res := installer.Execute(context.Background(), noArgs)

	assert.True(t, res.IsError)
	assert.Contains(t, res.ForLLM, "identifier is required and must be a non-empty string")
}
// TestInstallSkillToolEmptySlug checks that a whitespace-only slug is treated
// the same as a missing one.
func TestInstallSkillToolEmptySlug(t *testing.T) {
	installer := NewInstallSkillTool(skills.NewRegistryManager(), t.TempDir())
	blankSlug := map[string]interface{}{
		"slug": " ",
	}

	res := installer.Execute(context.Background(), blankSlug)

	assert.True(t, res.IsError)
	assert.Contains(t, res.ForLLM, "identifier is required and must be a non-empty string")
}
// TestInstallSkillToolUnsafeSlug checks that slugs containing path separators
// or ".." are rejected as invalid.
func TestInstallSkillToolUnsafeSlug(t *testing.T) {
	installer := NewInstallSkillTool(skills.NewRegistryManager(), t.TempDir())

	unsafeSlugs := []string{
		"../etc/passwd",
		"path/traversal",
		"path\\traversal",
	}
	for _, unsafe := range unsafeSlugs {
		res := installer.Execute(context.Background(), map[string]interface{}{
			"slug": unsafe,
		})

		assert.True(t, res.IsError, "slug %q should be rejected", unsafe)
		assert.Contains(t, res.ForLLM, "invalid slug")
	}
}
// TestInstallSkillToolAlreadyExists checks that installing over an existing
// skill directory without force is refused.
func TestInstallSkillToolAlreadyExists(t *testing.T) {
	root := t.TempDir()
	// Pre-create the directory the install would target.
	existing := filepath.Join(root, "skills", "existing-skill")
	require.NoError(t, os.MkdirAll(existing, 0755))

	installer := NewInstallSkillTool(skills.NewRegistryManager(), root)
	res := installer.Execute(context.Background(), map[string]interface{}{
		"slug":     "existing-skill",
		"registry": "clawhub",
	})

	assert.True(t, res.IsError)
	assert.Contains(t, res.ForLLM, "already installed")
}
// TestInstallSkillToolRegistryNotFound checks that naming a registry the
// manager does not know produces a "registry ... not found" error.
func TestInstallSkillToolRegistryNotFound(t *testing.T) {
	installer := NewInstallSkillTool(skills.NewRegistryManager(), t.TempDir())

	res := installer.Execute(context.Background(), map[string]interface{}{
		"slug":     "some-skill",
		"registry": "nonexistent",
	})

	assert.True(t, res.IsError)
	assert.Contains(t, res.ForLLM, "registry")
	assert.Contains(t, res.ForLLM, "not found")
}
// TestInstallSkillToolParameters checks the schema: all four properties are
// declared, and slug and registry are marked required.
func TestInstallSkillToolParameters(t *testing.T) {
	installer := NewInstallSkillTool(skills.NewRegistryManager(), t.TempDir())
	schema := installer.Parameters()

	properties, isMap := schema["properties"].(map[string]interface{})
	assert.True(t, isMap)
	for _, name := range []string{"slug", "version", "registry", "force"} {
		assert.Contains(t, properties, name)
	}

	required, isList := schema["required"].([]string)
	assert.True(t, isList)
	for _, name := range []string{"slug", "registry"} {
		assert.Contains(t, required, name)
	}
}
// TestInstallSkillToolMissingRegistry checks that omitting the registry
// argument is reported as an invalid registry.
func TestInstallSkillToolMissingRegistry(t *testing.T) {
	installer := NewInstallSkillTool(skills.NewRegistryManager(), t.TempDir())
	slugOnly := map[string]interface{}{
		"slug": "some-skill",
	}

	res := installer.Execute(context.Background(), slugOnly)

	assert.True(t, res.IsError)
	assert.Contains(t, res.ForLLM, "invalid registry")
}
+119
View File
@@ -0,0 +1,119 @@
package tools
import (
"context"
"fmt"
"strings"
"github.com/sipeed/picoclaw/pkg/skills"
)
// FindSkillsTool allows the LLM agent to search for installable skills from registries.
type FindSkillsTool struct {
registryMgr *skills.RegistryManager // Searched via SearchAll when the cache has no answer.
cache *skills.SearchCache // Optional; Execute skips all caching when this is nil.
}
// NewFindSkillsTool returns a FindSkillsTool.
//
// registryMgr is the shared registry manager (built from config in
// createToolRegistry). cache holds recent search results so that repeated or
// similar queries do not hit the registries again; it may be nil to disable caching.
func NewFindSkillsTool(registryMgr *skills.RegistryManager, cache *skills.SearchCache) *FindSkillsTool {
	tool := new(FindSkillsTool)
	tool.registryMgr = registryMgr
	tool.cache = cache
	return tool
}
// Name returns the identifier under which this tool is exposed.
func (t *FindSkillsTool) Name() string {
	const toolName = "find_skills"
	return toolName
}
// Description returns the human-readable summary of what the tool does.
func (t *FindSkillsTool) Description() string {
	const description = "Search for installable skills from skill registries. Returns skill slugs, descriptions, versions, and relevance scores. Use this to discover skills before installing them with install_skill."
	return description
}
// Parameters returns the JSON-schema-style description of the tool's
// arguments: a required string query and an optional integer limit bounded
// to the range 1-20.
func (t *FindSkillsTool) Parameters() map[string]interface{} {
	queryParam := map[string]interface{}{
		"type":        "string",
		"description": "Search query describing the desired skill capability (e.g., 'github integration', 'database management')",
	}
	limitParam := map[string]interface{}{
		"type":        "integer",
		"description": "Maximum number of results to return (1-20, default 5)",
		"minimum":     1.0,
		"maximum":     20.0,
	}

	schema := map[string]interface{}{
		"type": "object",
		"properties": map[string]interface{}{
			"query": queryParam,
			"limit": limitParam,
		},
		"required": []string{"query"},
	}
	return schema
}
// Execute searches the registries for skills matching args["query"] and
// returns a formatted, human-readable list.
//
// Recognized arguments:
//   - query (string, required): trimmed and lower-cased before use; must be non-empty.
//   - limit (number, optional): maximum results, accepted as float64 (the type
//     JSON decoding produces) or int. Values outside 1-20 fall back to 5.
//
// When a cache is configured it is consulted first. A cached answer may have
// been stored under a larger limit, so it is cut down to the requested limit
// before being returned. Only non-empty live results are cached.
func (t *FindSkillsTool) Execute(ctx context.Context, args map[string]interface{}) *ToolResult {
	query, ok := args["query"].(string)
	query = strings.ToLower(strings.TrimSpace(query))
	if !ok || query == "" {
		return ErrorResult("query is required and must be a non-empty string")
	}

	limit := 5
	requested := 0
	switch v := args["limit"].(type) {
	case float64:
		requested = int(v)
	case int:
		requested = v
	}
	if requested >= 1 && requested <= 20 {
		limit = requested
	}

	// Check cache first.
	if t.cache != nil {
		if cached, hit := t.cache.Get(query); hit {
			// Get returns a private copy, so re-slicing cannot affect the cache.
			if len(cached) > limit {
				cached = cached[:limit]
			}
			return SilentResult(formatSearchResults(query, cached, true))
		}
	}

	// Search all registries.
	results, err := t.registryMgr.SearchAll(ctx, query, limit)
	if err != nil {
		return ErrorResult(fmt.Sprintf("skill search failed: %v", err))
	}

	// Cache the results.
	if t.cache != nil && len(results) > 0 {
		t.cache.Put(query, results)
	}

	return SilentResult(formatSearchResults(query, results, false))
}
// formatSearchResults renders search results as a numbered list for the LLM.
//
// Each entry shows the slug, the version when known, the score to three
// decimals and the registry, followed by the display name (only when it
// differs from the slug) and the summary when present. cached adds a
// " (cached)" marker to the heading. An empty result set yields a single
// "No skills found" line.
func formatSearchResults(query string, results []skills.SearchResult, cached bool) string {
	if len(results) == 0 {
		return fmt.Sprintf("No skills found for query: %q", query)
	}

	marker := ""
	if cached {
		marker = " (cached)"
	}

	var out strings.Builder
	fmt.Fprintf(&out, "Found %d skills for %q%s:\n\n", len(results), query, marker)

	for idx, item := range results {
		fmt.Fprintf(&out, "%d. **%s**", idx+1, item.Slug)
		if item.Version != "" {
			fmt.Fprintf(&out, " v%s", item.Version)
		}
		fmt.Fprintf(&out, " (score: %.3f, registry: %s)\n", item.Score, item.RegistryName)

		if item.DisplayName != "" && item.DisplayName != item.Slug {
			fmt.Fprintf(&out, " Name: %s\n", item.DisplayName)
		}
		if item.Summary != "" {
			fmt.Fprintf(&out, " %s\n", item.Summary)
		}
		out.WriteString("\n")
	}

	out.WriteString("Use install_skill with the slug to install a skill.")
	return out.String()
}
+82
View File
@@ -0,0 +1,82 @@
package tools
import (
"context"
"testing"
"github.com/sipeed/picoclaw/pkg/skills"
"github.com/stretchr/testify/assert"
)
// TestFindSkillsToolName pins the tool's registered name.
func TestFindSkillsToolName(t *testing.T) {
	finder := NewFindSkillsTool(skills.NewRegistryManager(), nil)

	assert.Equal(t, "find_skills", finder.Name())
}
// TestFindSkillsToolMissingQuery checks that calling the tool without a query
// argument produces a "query is required" error.
func TestFindSkillsToolMissingQuery(t *testing.T) {
	finder := NewFindSkillsTool(skills.NewRegistryManager(), nil)
	noArgs := map[string]interface{}{}

	res := finder.Execute(context.Background(), noArgs)

	assert.True(t, res.IsError)
	assert.Contains(t, res.ForLLM, "query is required")
}
// TestFindSkillsToolEmptyQuery checks that a whitespace-only query is rejected.
func TestFindSkillsToolEmptyQuery(t *testing.T) {
	finder := NewFindSkillsTool(skills.NewRegistryManager(), nil)
	blankQuery := map[string]interface{}{
		"query": " ",
	}

	res := finder.Execute(context.Background(), blankQuery)

	assert.True(t, res.IsError)
}
// TestFindSkillsToolCacheHit checks that a query already in the cache is
// answered from it (the registry manager here has no registries at all) and
// that the output is marked as cached.
func TestFindSkillsToolCacheHit(t *testing.T) {
	warm := skills.NewSearchCache(10, 5*60*1000*1000*1000) // 5 min
	warm.Put("github", []skills.SearchResult{
		{Slug: "github", Score: 0.9, RegistryName: "clawhub"},
	})
	finder := NewFindSkillsTool(skills.NewRegistryManager(), warm)

	res := finder.Execute(context.Background(), map[string]interface{}{
		"query": "github",
	})

	assert.False(t, res.IsError)
	assert.Contains(t, res.ForLLM, "github")
	assert.Contains(t, res.ForLLM, "cached")
}
// TestFindSkillsToolParameters checks the schema: query and limit are
// declared, and query is marked required.
func TestFindSkillsToolParameters(t *testing.T) {
	finder := NewFindSkillsTool(skills.NewRegistryManager(), nil)
	schema := finder.Parameters()

	properties, isMap := schema["properties"].(map[string]interface{})
	assert.True(t, isMap)
	for _, name := range []string{"query", "limit"} {
		assert.Contains(t, properties, name)
	}

	required, isList := schema["required"].([]string)
	assert.True(t, isList)
	assert.Contains(t, required, "query")
}
// TestFindSkillsToolDescription checks that the description is present and
// mentions skills.
func TestFindSkillsToolDescription(t *testing.T) {
	finder := NewFindSkillsTool(skills.NewRegistryManager(), nil)

	assert.NotEmpty(t, finder.Description())
	assert.Contains(t, finder.Description(), "skill")
}
// TestFormatSearchResultsEmpty checks the message produced for an empty result set.
func TestFormatSearchResultsEmpty(t *testing.T) {
	rendered := formatSearchResults("test query", nil, false)

	assert.Contains(t, rendered, "No skills found")
}
// TestFormatSearchResultsWithData checks that a rendered result carries the
// slug, version, three-decimal score, registry name and the install hint.
func TestFormatSearchResultsWithData(t *testing.T) {
	hit := skills.SearchResult{
		Slug:         "github",
		Score:        0.95,
		DisplayName:  "GitHub",
		Summary:      "GitHub API integration",
		Version:      "1.0.0",
		RegistryName: "clawhub",
	}

	rendered := formatSearchResults("github", []skills.SearchResult{hit}, false)

	for _, fragment := range []string{"github", "v1.0.0", "0.950", "clawhub", "install_skill"} {
		assert.Contains(t, rendered, fragment)
	}
}
+93
View File
@@ -0,0 +1,93 @@
package utils
import (
"context"
"fmt"
"io"
"net/http"
"os"
"github.com/sipeed/picoclaw/pkg/logger"
)
// DownloadToFile performs req with client and streams the response body into
// a newly created temporary file using io.Copy, so the body is never held in
// memory as a whole.
//
// Parameters:
//   - ctx: attached to the request for cancellation/timeout
//   - client: HTTP client to use (caller controls timeouts, transport, etc.)
//   - req: fully prepared *http.Request (method, URL, headers, etc.)
//   - maxBytes: maximum body size in bytes; 0 or less means no limit
//
// It returns the path of the temporary file, which the caller must remove
// when done (defer os.Remove(path)).
//
// An error is returned when the request fails, when the status is not 2xx
// (the message includes up to 512 bytes of the response body), when the body
// exceeds maxBytes, or when writing or closing the file fails. In every error
// case after the file was created, the file is removed before returning.
func DownloadToFile(ctx context.Context, client *http.Client, req *http.Request, maxBytes int64) (string, error) {
	req = req.WithContext(ctx)

	logger.DebugCF("download", "Starting download", map[string]interface{}{
		"url":       req.URL.String(),
		"max_bytes": maxBytes,
	})

	resp, err := client.Do(req)
	if err != nil {
		return "", fmt.Errorf("request failed: %w", err)
	}
	defer resp.Body.Close()

	if status := resp.StatusCode; status < 200 || status >= 300 {
		// Include a short prefix of the body to make the error actionable.
		var snippet [512]byte
		n, _ := io.ReadFull(resp.Body, snippet[:])
		return "", fmt.Errorf("HTTP %d: %s", status, string(snippet[:n]))
	}

	tmpFile, err := os.CreateTemp("", "picoclaw-dl-*")
	if err != nil {
		return "", fmt.Errorf("failed to create temp file: %w", err)
	}
	tmpPath := tmpFile.Name()

	logger.DebugCF("download", "Streaming to temp file", map[string]interface{}{
		"path": tmpPath,
	})

	// discard closes and deletes the temp file on a failed download.
	discard := func() {
		_ = tmpFile.Close()
		_ = os.Remove(tmpPath)
	}

	// When a limit applies, read one byte past it so that an oversized body
	// can be told apart from one that is exactly maxBytes long.
	limited := maxBytes > 0
	body := io.Reader(resp.Body)
	if limited {
		body = io.LimitReader(resp.Body, maxBytes+1)
	}

	written, err := io.Copy(tmpFile, body)
	switch {
	case err != nil:
		discard()
		return "", fmt.Errorf("download write failed: %w", err)
	case limited && written > maxBytes:
		discard()
		return "", fmt.Errorf("download too large: %d bytes (max %d)", written, maxBytes)
	}

	if err := tmpFile.Close(); err != nil {
		_ = os.Remove(tmpPath)
		return "", fmt.Errorf("failed to close temp file: %w", err)
	}

	logger.DebugCF("download", "Download complete", map[string]interface{}{
		"path":          tmpPath,
		"bytes_written": written,
	})

	return tmpPath, nil
}
+19
View File
@@ -0,0 +1,19 @@
package utils
import (
"fmt"
"strings"
)
// ValidateSkillIdentifier validates a skill identifier (slug or registry
// name) that will be used as a single path element.
//
// It returns an error when the identifier, after trimming surrounding
// whitespace, is empty, is exactly ".", or contains a path separator
// ("/" or "\\") or "..". It returns nil otherwise.
//
// "." must be rejected explicitly: joining it onto a directory yields that
// directory itself, so a skill named "." would resolve to the whole skills
// directory rather than to a sub-directory of it.
func ValidateSkillIdentifier(identifier string) error {
	trimmed := strings.TrimSpace(identifier)
	if trimmed == "" {
		return fmt.Errorf("identifier is required and must be a non-empty string")
	}
	if trimmed == "." {
		return fmt.Errorf("identifier must not be %q because it refers to the containing directory", trimmed)
	}
	if strings.ContainsAny(trimmed, "/\\") || strings.Contains(trimmed, "..") {
		return fmt.Errorf("identifier must not contain path separators or '..' to prevent directory traversal")
	}
	return nil
}
+9
View File
@@ -14,3 +14,12 @@ func Truncate(s string, maxLen int) string {
}
return string(runes[:maxLen-3]) + "..."
}
// DerefStr returns the string s points to, or fallback when s is nil.
// A non-nil pointer to an empty string yields the empty string, not fallback.
func DerefStr(s *string, fallback string) string {
	if s != nil {
		return *s
	}
	return fallback
}
+120
View File
@@ -0,0 +1,120 @@
package utils
import (
"archive/zip"
"fmt"
"io"
"os"
"path/filepath"
"strings"
"github.com/sipeed/picoclaw/pkg/logger"
)
// ExtractZipFile unpacks the ZIP archive at zipPath into targetDir, creating
// targetDir (mode 0755) if needed. Entries are processed one at a time in
// archive order; each regular file is written by extractSingleFile.
//
// Security: extraction stops with an error at the first entry that
//   - has a cleaned name that is absolute or starts with "..",
//   - would resolve to a path outside targetDir, or
//   - is a symlink.
//
// Entries extracted before a failing one are left on disk.
func ExtractZipFile(zipPath string, targetDir string) error {
	archive, err := zip.OpenReader(zipPath)
	if err != nil {
		return fmt.Errorf("invalid ZIP: %w", err)
	}
	defer archive.Close()

	logger.DebugCF("zip", "Extracting ZIP", map[string]interface{}{
		"zip_path":   zipPath,
		"target_dir": targetDir,
		"entries":    len(archive.File),
	})

	if err := os.MkdirAll(targetDir, 0755); err != nil {
		return fmt.Errorf("failed to create target dir: %w", err)
	}

	// The containment check compares against the cleaned target directory.
	root := filepath.Clean(targetDir)
	rootPrefix := root + string(filepath.Separator)

	for _, entry := range archive.File {
		// Path traversal protection.
		name := filepath.Clean(entry.Name)
		if filepath.IsAbs(name) || strings.HasPrefix(name, "..") {
			return fmt.Errorf("zip entry has unsafe path: %q", entry.Name)
		}

		dest := filepath.Join(targetDir, name)

		// Defense in depth: the resolved path must be the root or lie beneath it.
		resolved := filepath.Clean(dest)
		if resolved != root && !strings.HasPrefix(resolved, rootPrefix) {
			return fmt.Errorf("zip entry escapes target dir: %q", entry.Name)
		}

		info := entry.FileInfo()

		// Reject any symlink.
		if info.Mode()&os.ModeSymlink != 0 {
			return fmt.Errorf("zip contains symlink %q; symlinks are not allowed", entry.Name)
		}

		if info.IsDir() {
			if err := os.MkdirAll(dest, 0755); err != nil {
				return err
			}
			continue
		}

		// Archives need not list directories explicitly, so create the parent.
		if err := os.MkdirAll(filepath.Dir(dest), 0755); err != nil {
			return err
		}
		if err := extractSingleFile(entry, dest); err != nil {
			return err
		}
	}

	return nil
}
// extractSingleFile writes the contents of the ZIP entry f to destPath,
// creating or truncating the file there.
//
// Entries larger than maxFileSize (5 MiB) are rejected. The limit is enforced
// twice: once against the size declared in the entry header, and again while
// streaming, so a header that understates the real size cannot bypass it.
//
// On any failure after destPath has been created (read error, size overrun,
// or a failed close) the partially written file is removed on a best-effort
// basis and a non-nil error is returned.
func extractSingleFile(f *zip.File, destPath string) error {
	const maxFileSize = 5 * 1024 * 1024 // 5MB, adjust as appropriate

	// Cheap early rejection based on the declared size.
	if f.UncompressedSize64 > maxFileSize {
		return fmt.Errorf("zip entry %q is too large (%d bytes)", f.Name, f.UncompressedSize64)
	}

	rc, err := f.Open()
	if err != nil {
		return fmt.Errorf("failed to open zip entry %q: %w", f.Name, err)
	}
	defer rc.Close()

	outFile, err := os.Create(destPath)
	if err != nil {
		return fmt.Errorf("failed to create file %q: %w", destPath, err)
	}

	// Streamed size check: copy at most one byte past the limit so an
	// overrun is detectable without reading the whole entry. io.CopyN
	// reports io.EOF when the source ends before that many bytes, which is
	// the normal outcome for an entry within the limit.
	written, copyErr := io.CopyN(outFile, rc, maxFileSize+1)

	// Close explicitly, before any cleanup: a close failure can mean the data
	// was not fully written, and removing a still-open file fails on Windows.
	closeErr := outFile.Close()

	var failure error
	switch {
	case copyErr != nil && copyErr != io.EOF:
		failure = fmt.Errorf("failed to extract %q: %w", f.Name, copyErr)
	case written > maxFileSize:
		failure = fmt.Errorf("zip entry %q exceeds max size (%d bytes)", f.Name, written)
	case closeErr != nil:
		failure = fmt.Errorf("failed to close file %q: %w", destPath, closeErr)
	default:
		return nil
	}

	// Best-effort cleanup; the error that matters is the one being returned.
	_ = os.Remove(destPath)
	return failure
}