mirror of
https://github.com/sipeed/picoclaw.git
synced 2026-06-12 18:08:54 +00:00
Feature: Implement Skill Discovery - With Clawhub Integration and Caching (#332)
* Add Find Skills and Install Skills * Improvements * fix file name * Update pkg/skills/clawhub_registry.go Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * fix * Comments addressed * Resolve comments * fix tests * fixes * Comments resolved * Update pkg/skills/search_cache_repro_test.go Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * minor fix * fix test * fixes --------- Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
This commit is contained in:
@@ -11,15 +11,17 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
"github.com/sipeed/picoclaw/pkg/skills"
|
||||
"github.com/sipeed/picoclaw/pkg/utils"
|
||||
)
|
||||
|
||||
func skillsHelp() {
|
||||
fmt.Println("\nSkills commands:")
|
||||
fmt.Println(" list List installed skills")
|
||||
fmt.Println(" install <repo> Install skill from GitHub")
|
||||
fmt.Println(" install-builtin Install all builtin skills to workspace")
|
||||
fmt.Println(" list-builtin List available builtin skills")
|
||||
fmt.Println(" install-builtin Install all builtin skills to workspace")
|
||||
fmt.Println(" list-builtin List available builtin skills")
|
||||
fmt.Println(" remove <name> Remove installed skill")
|
||||
fmt.Println(" search Search available skills")
|
||||
fmt.Println(" show <name> Show skill details")
|
||||
@@ -30,6 +32,7 @@ func skillsHelp() {
|
||||
fmt.Println(" picoclaw skills install-builtin")
|
||||
fmt.Println(" picoclaw skills list-builtin")
|
||||
fmt.Println(" picoclaw skills remove weather")
|
||||
fmt.Println(" picoclaw skills install --registry clawhub github")
|
||||
}
|
||||
|
||||
func skillsListCmd(loader *skills.SkillsLoader) {
|
||||
@@ -50,13 +53,27 @@ func skillsListCmd(loader *skills.SkillsLoader) {
|
||||
}
|
||||
}
|
||||
|
||||
func skillsInstallCmd(installer *skills.SkillInstaller) {
|
||||
func skillsInstallCmd(installer *skills.SkillInstaller, cfg *config.Config) {
|
||||
if len(os.Args) < 4 {
|
||||
fmt.Println("Usage: picoclaw skills install <github-repo>")
|
||||
fmt.Println("Example: picoclaw skills install sipeed/picoclaw-skills/weather")
|
||||
fmt.Println(" picoclaw skills install --registry <name> <slug>")
|
||||
return
|
||||
}
|
||||
|
||||
// Check for --registry flag.
|
||||
if os.Args[3] == "--registry" {
|
||||
if len(os.Args) < 6 {
|
||||
fmt.Println("Usage: picoclaw skills install --registry <name> <slug>")
|
||||
fmt.Println("Example: picoclaw skills install --registry clawhub github")
|
||||
return
|
||||
}
|
||||
registryName := os.Args[4]
|
||||
slug := os.Args[5]
|
||||
skillsInstallFromRegistry(cfg, registryName, slug)
|
||||
return
|
||||
}
|
||||
|
||||
// Default: install from GitHub (backward compatible).
|
||||
repo := os.Args[3]
|
||||
fmt.Printf("Installing skill from %s...\n", repo)
|
||||
|
||||
@@ -64,11 +81,83 @@ func skillsInstallCmd(installer *skills.SkillInstaller) {
|
||||
defer cancel()
|
||||
|
||||
if err := installer.InstallFromGitHub(ctx, repo); err != nil {
|
||||
fmt.Printf("✗ Failed to install skill: %v\n", err)
|
||||
fmt.Printf("\u2717 Failed to install skill: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
fmt.Printf("✓ Skill '%s' installed successfully!\n", filepath.Base(repo))
|
||||
fmt.Printf("\u2713 Skill '%s' installed successfully!\n", filepath.Base(repo))
|
||||
}
|
||||
|
||||
// skillsInstallFromRegistry installs a skill from a named registry (e.g. clawhub).
|
||||
func skillsInstallFromRegistry(cfg *config.Config, registryName, slug string) {
|
||||
err := utils.ValidateSkillIdentifier(registryName)
|
||||
if err != nil {
|
||||
fmt.Printf("\u2717 Invalid registry name: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
err = utils.ValidateSkillIdentifier(slug)
|
||||
if err != nil {
|
||||
fmt.Printf("\u2717 Invalid slug: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
fmt.Printf("Installing skill '%s' from %s registry...\n", slug, registryName)
|
||||
|
||||
registryMgr := skills.NewRegistryManagerFromConfig(skills.RegistryConfig{
|
||||
MaxConcurrentSearches: cfg.Tools.Skills.MaxConcurrentSearches,
|
||||
ClawHub: skills.ClawHubConfig(cfg.Tools.Skills.Registries.ClawHub),
|
||||
})
|
||||
|
||||
registry := registryMgr.GetRegistry(registryName)
|
||||
if registry == nil {
|
||||
fmt.Printf("\u2717 Registry '%s' not found or not enabled. Check your config.json.\n", registryName)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
workspace := cfg.WorkspacePath()
|
||||
targetDir := filepath.Join(workspace, "skills", slug)
|
||||
|
||||
if _, err := os.Stat(targetDir); err == nil {
|
||||
fmt.Printf("\u2717 Skill '%s' already installed at %s\n", slug, targetDir)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second)
|
||||
defer cancel()
|
||||
|
||||
if err := os.MkdirAll(filepath.Join(workspace, "skills"), 0755); err != nil {
|
||||
fmt.Printf("\u2717 Failed to create skills directory: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
result, err := registry.DownloadAndInstall(ctx, slug, "", targetDir)
|
||||
if err != nil {
|
||||
rmErr := os.RemoveAll(targetDir)
|
||||
if rmErr != nil {
|
||||
fmt.Printf("\u2717 Failed to remove partial install: %v\n", rmErr)
|
||||
}
|
||||
fmt.Printf("\u2717 Failed to install skill: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
if result.IsMalwareBlocked {
|
||||
rmErr := os.RemoveAll(targetDir)
|
||||
if rmErr != nil {
|
||||
fmt.Printf("\u2717 Failed to remove partial install: %v\n", rmErr)
|
||||
}
|
||||
fmt.Printf("\u2717 Skill '%s' is flagged as malicious and cannot be installed.\n", slug)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
if result.IsSuspicious {
|
||||
fmt.Printf("\u26a0\ufe0f Warning: skill '%s' is flagged as suspicious.\n", slug)
|
||||
}
|
||||
|
||||
fmt.Printf("\u2713 Skill '%s' v%s installed successfully!\n", slug, result.Version)
|
||||
if result.Summary != "" {
|
||||
fmt.Printf(" %s\n", result.Summary)
|
||||
}
|
||||
}
|
||||
|
||||
func skillsRemoveCmd(installer *skills.SkillInstaller, skillName string) {
|
||||
|
||||
@@ -141,7 +141,7 @@ func main() {
|
||||
case "list":
|
||||
skillsListCmd(skillsLoader)
|
||||
case "install":
|
||||
skillsInstallCmd(installer)
|
||||
skillsInstallCmd(installer, cfg)
|
||||
case "remove", "uninstall":
|
||||
if len(os.Args) < 4 {
|
||||
fmt.Println("Usage: picoclaw skills remove <skill-name>")
|
||||
|
||||
@@ -194,6 +194,17 @@
|
||||
"exec": {
|
||||
"enable_deny_patterns": false,
|
||||
"custom_deny_patterns": []
|
||||
},
|
||||
"skills": {
|
||||
"registries": {
|
||||
"clawhub": {
|
||||
"enabled": true,
|
||||
"base_url": "https://clawhub.ai",
|
||||
"search_path": "/api/v1/search",
|
||||
"skills_path": "/api/v1/skills",
|
||||
"download_path": "/api/v1/download"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"heartbeat": {
|
||||
|
||||
@@ -23,6 +23,7 @@ import (
|
||||
"github.com/sipeed/picoclaw/pkg/logger"
|
||||
"github.com/sipeed/picoclaw/pkg/providers"
|
||||
"github.com/sipeed/picoclaw/pkg/routing"
|
||||
"github.com/sipeed/picoclaw/pkg/skills"
|
||||
"github.com/sipeed/picoclaw/pkg/state"
|
||||
"github.com/sipeed/picoclaw/pkg/tools"
|
||||
"github.com/sipeed/picoclaw/pkg/utils"
|
||||
@@ -117,6 +118,15 @@ func registerSharedTools(cfg *config.Config, msgBus *bus.MessageBus, registry *A
|
||||
})
|
||||
agent.Tools.Register(messageTool)
|
||||
|
||||
// Skill discovery and installation tools
|
||||
registryMgr := skills.NewRegistryManagerFromConfig(skills.RegistryConfig{
|
||||
MaxConcurrentSearches: cfg.Tools.Skills.MaxConcurrentSearches,
|
||||
ClawHub: skills.ClawHubConfig(cfg.Tools.Skills.Registries.ClawHub),
|
||||
})
|
||||
searchCache := skills.NewSearchCache(cfg.Tools.Skills.SearchCache.MaxSize, time.Duration(cfg.Tools.Skills.SearchCache.TTLSeconds)*time.Second)
|
||||
agent.Tools.Register(tools.NewFindSkillsTool(registryMgr, searchCache))
|
||||
agent.Tools.Register(tools.NewInstallSkillTool(registryMgr, agent.Workspace))
|
||||
|
||||
// Spawn tool with allowlist checker
|
||||
subagentManager := tools.NewSubagentManager(provider, agent.Model, agent.Workspace, msgBus)
|
||||
subagentManager.SetLLMOptions(agent.MaxTokens, agent.Temperature)
|
||||
|
||||
+31
-3
@@ -416,9 +416,37 @@ type ExecConfig struct {
|
||||
}
|
||||
|
||||
type ToolsConfig struct {
|
||||
Web WebToolsConfig `json:"web"`
|
||||
Cron CronToolsConfig `json:"cron"`
|
||||
Exec ExecConfig `json:"exec"`
|
||||
Web WebToolsConfig `json:"web"`
|
||||
Cron CronToolsConfig `json:"cron"`
|
||||
Exec ExecConfig `json:"exec"`
|
||||
Skills SkillsToolsConfig `json:"skills"`
|
||||
}
|
||||
|
||||
type SkillsToolsConfig struct {
|
||||
Registries SkillsRegistriesConfig `json:"registries"`
|
||||
MaxConcurrentSearches int `json:"max_concurrent_searches" env:"PICOCLAW_SKILLS_MAX_CONCURRENT_SEARCHES"`
|
||||
SearchCache SearchCacheConfig `json:"search_cache"`
|
||||
}
|
||||
|
||||
// SearchCacheConfig sizes the skill search cache.
type SearchCacheConfig struct {
	// MaxSize is the size limit passed to the search cache.
	MaxSize int `json:"max_size" env:"PICOCLAW_SKILLS_SEARCH_CACHE_MAX_SIZE"`
	// TTLSeconds is the cache entry lifetime, expressed in seconds.
	TTLSeconds int `json:"ttl_seconds" env:"PICOCLAW_SKILLS_SEARCH_CACHE_TTL_SECONDS"`
}
|
||||
|
||||
type SkillsRegistriesConfig struct {
|
||||
ClawHub ClawHubRegistryConfig `json:"clawhub"`
|
||||
}
|
||||
|
||||
// ClawHubRegistryConfig is the user-facing configuration of the ClawHub
// registry. It is converted field-for-field into the skills package's
// ClawHub config, so the field names, types and order must stay in sync
// with that type.
type ClawHubRegistryConfig struct {
	// Enabled switches the registry on.
	Enabled bool `json:"enabled" env:"PICOCLAW_SKILLS_REGISTRIES_CLAWHUB_ENABLED"`
	// BaseURL is the registry root URL.
	BaseURL string `json:"base_url" env:"PICOCLAW_SKILLS_REGISTRIES_CLAWHUB_BASE_URL"`
	// AuthToken is an optional bearer token.
	AuthToken string `json:"auth_token" env:"PICOCLAW_SKILLS_REGISTRIES_CLAWHUB_AUTH_TOKEN"`
	// SearchPath, SkillsPath and DownloadPath are the API paths appended to
	// BaseURL.
	SearchPath   string `json:"search_path" env:"PICOCLAW_SKILLS_REGISTRIES_CLAWHUB_SEARCH_PATH"`
	SkillsPath   string `json:"skills_path" env:"PICOCLAW_SKILLS_REGISTRIES_CLAWHUB_SKILLS_PATH"`
	DownloadPath string `json:"download_path" env:"PICOCLAW_SKILLS_REGISTRIES_CLAWHUB_DOWNLOAD_PATH"`
	// Timeout is the HTTP timeout in seconds.
	Timeout int `json:"timeout" env:"PICOCLAW_SKILLS_REGISTRIES_CLAWHUB_TIMEOUT"`
	// MaxZipSize and MaxResponseSize are byte limits for downloads and API
	// responses.
	MaxZipSize      int `json:"max_zip_size" env:"PICOCLAW_SKILLS_REGISTRIES_CLAWHUB_MAX_ZIP_SIZE"`
	MaxResponseSize int `json:"max_response_size" env:"PICOCLAW_SKILLS_REGISTRIES_CLAWHUB_MAX_RESPONSE_SIZE"`
}
|
||||
|
||||
func LoadConfig(path string) (*Config, error) {
|
||||
|
||||
@@ -265,6 +265,19 @@ func DefaultConfig() *Config {
|
||||
Exec: ExecConfig{
|
||||
EnableDenyPatterns: true,
|
||||
},
|
||||
Skills: SkillsToolsConfig{
|
||||
Registries: SkillsRegistriesConfig{
|
||||
ClawHub: ClawHubRegistryConfig{
|
||||
Enabled: true,
|
||||
BaseURL: "https://clawhub.ai",
|
||||
},
|
||||
},
|
||||
MaxConcurrentSearches: 2,
|
||||
SearchCache: SearchCacheConfig{
|
||||
MaxSize: 50,
|
||||
TTLSeconds: 300,
|
||||
},
|
||||
},
|
||||
},
|
||||
Heartbeat: HeartbeatConfig{
|
||||
Enabled: true,
|
||||
|
||||
@@ -0,0 +1,311 @@
|
||||
package skills
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/utils"
|
||||
)
|
||||
|
||||
// Fallback values applied when the ClawHub configuration leaves the
// corresponding setting unset.
const (
	// defaultClawHubTimeout is the HTTP client timeout.
	defaultClawHubTimeout = 30 * time.Second
	// defaultMaxZipSize caps a downloaded skill ZIP at 50 MB.
	defaultMaxZipSize = 50 << 20
	// defaultMaxResponseSize caps an API response body at 2 MB.
	defaultMaxResponseSize = 2 << 20
)
|
||||
|
||||
// ClawHubRegistry is the SkillRegistry implementation backed by the ClawHub
// platform's HTTP API.
type ClawHubRegistry struct {
	// baseURL is the registry root; the *Path fields are appended to it.
	baseURL string
	// authToken is optional and is used for elevated rate limits.
	authToken string

	searchPath   string // search API
	skillsPath   string // skill metadata API
	downloadPath string // ZIP download API

	maxZipSize      int // upper bound, in bytes, for a downloaded ZIP
	maxResponseSize int // upper bound, in bytes, for an API response body

	client *http.Client
}
|
||||
|
||||
// NewClawHubRegistry creates a new ClawHub registry client from config.
|
||||
func NewClawHubRegistry(cfg ClawHubConfig) *ClawHubRegistry {
|
||||
baseURL := cfg.BaseURL
|
||||
if baseURL == "" {
|
||||
baseURL = "https://clawhub.ai"
|
||||
}
|
||||
searchPath := cfg.SearchPath
|
||||
if searchPath == "" {
|
||||
searchPath = "/api/v1/search"
|
||||
}
|
||||
skillsPath := cfg.SkillsPath
|
||||
if skillsPath == "" {
|
||||
skillsPath = "/api/v1/skills"
|
||||
}
|
||||
downloadPath := cfg.DownloadPath
|
||||
if downloadPath == "" {
|
||||
downloadPath = "/api/v1/download"
|
||||
}
|
||||
|
||||
timeout := defaultClawHubTimeout
|
||||
if cfg.Timeout > 0 {
|
||||
timeout = time.Duration(cfg.Timeout) * time.Second
|
||||
}
|
||||
|
||||
maxZip := defaultMaxZipSize
|
||||
if cfg.MaxZipSize > 0 {
|
||||
maxZip = cfg.MaxZipSize
|
||||
}
|
||||
|
||||
maxResp := defaultMaxResponseSize
|
||||
if cfg.MaxResponseSize > 0 {
|
||||
maxResp = cfg.MaxResponseSize
|
||||
}
|
||||
|
||||
return &ClawHubRegistry{
|
||||
baseURL: baseURL,
|
||||
authToken: cfg.AuthToken,
|
||||
searchPath: searchPath,
|
||||
skillsPath: skillsPath,
|
||||
downloadPath: downloadPath,
|
||||
maxZipSize: maxZip,
|
||||
maxResponseSize: maxResp,
|
||||
client: &http.Client{
|
||||
Timeout: timeout,
|
||||
Transport: &http.Transport{
|
||||
MaxIdleConns: 5,
|
||||
IdleConnTimeout: 30 * time.Second,
|
||||
TLSHandshakeTimeout: 10 * time.Second,
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (c *ClawHubRegistry) Name() string {
|
||||
return "clawhub"
|
||||
}
|
||||
|
||||
// --- Search ---
|
||||
|
||||
type clawhubSearchResponse struct {
|
||||
Results []clawhubSearchResult `json:"results"`
|
||||
}
|
||||
|
||||
// clawhubSearchResult is one entry of a search response. The string fields
// are pointers so that a JSON null or a missing key can be told apart from
// an empty string.
type clawhubSearchResult struct {
	Score float64 `json:"score"`

	Slug        *string `json:"slug"`
	DisplayName *string `json:"displayName"`
	Summary     *string `json:"summary"`
	Version     *string `json:"version"`
}
|
||||
|
||||
func (c *ClawHubRegistry) Search(ctx context.Context, query string, limit int) ([]SearchResult, error) {
|
||||
u, err := url.Parse(c.baseURL + c.searchPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid base URL: %w", err)
|
||||
}
|
||||
|
||||
q := u.Query()
|
||||
q.Set("q", query)
|
||||
if limit > 0 {
|
||||
q.Set("limit", fmt.Sprintf("%d", limit))
|
||||
}
|
||||
u.RawQuery = q.Encode()
|
||||
|
||||
body, err := c.doGet(ctx, u.String())
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("search request failed: %w", err)
|
||||
}
|
||||
|
||||
var resp clawhubSearchResponse
|
||||
if err := json.Unmarshal(body, &resp); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse search response: %w", err)
|
||||
}
|
||||
|
||||
results := make([]SearchResult, 0, len(resp.Results))
|
||||
for _, r := range resp.Results {
|
||||
slug := utils.DerefStr(r.Slug, "")
|
||||
if slug == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
summary := utils.DerefStr(r.Summary, "")
|
||||
if summary == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
displayName := utils.DerefStr(r.DisplayName, "")
|
||||
if displayName == "" {
|
||||
displayName = slug
|
||||
}
|
||||
|
||||
results = append(results, SearchResult{
|
||||
Score: r.Score,
|
||||
Slug: slug,
|
||||
DisplayName: displayName,
|
||||
Summary: summary,
|
||||
Version: utils.DerefStr(r.Version, ""),
|
||||
RegistryName: c.Name(),
|
||||
})
|
||||
}
|
||||
|
||||
return results, nil
|
||||
}
|
||||
|
||||
// --- GetSkillMeta ---
|
||||
|
||||
type clawhubSkillResponse struct {
|
||||
Slug string `json:"slug"`
|
||||
DisplayName string `json:"displayName"`
|
||||
Summary string `json:"summary"`
|
||||
LatestVersion *clawhubVersionInfo `json:"latestVersion"`
|
||||
Moderation *clawhubModerationInfo `json:"moderation"`
|
||||
}
|
||||
|
||||
// clawhubVersionInfo is the "latestVersion" object of a skill metadata
// response.
type clawhubVersionInfo struct {
	// Version is the version string as reported by the registry.
	Version string `json:"version"`
}
|
||||
|
||||
// clawhubModerationInfo is the "moderation" object of a skill metadata
// response.
type clawhubModerationInfo struct {
	// IsMalwareBlocked reports that the registry flagged the skill as malware.
	IsMalwareBlocked bool `json:"isMalwareBlocked"`
	// IsSuspicious reports that the registry flagged the skill as suspicious.
	IsSuspicious bool `json:"isSuspicious"`
}
|
||||
|
||||
func (c *ClawHubRegistry) GetSkillMeta(ctx context.Context, slug string) (*SkillMeta, error) {
|
||||
if err := utils.ValidateSkillIdentifier(slug); err != nil {
|
||||
return nil, fmt.Errorf("invalid slug %q: error: %s", slug, err.Error())
|
||||
}
|
||||
|
||||
u := c.baseURL + c.skillsPath + "/" + url.PathEscape(slug)
|
||||
|
||||
body, err := c.doGet(ctx, u)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("skill metadata request failed: %w", err)
|
||||
}
|
||||
|
||||
var resp clawhubSkillResponse
|
||||
if err := json.Unmarshal(body, &resp); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse skill metadata: %w", err)
|
||||
}
|
||||
|
||||
meta := &SkillMeta{
|
||||
Slug: resp.Slug,
|
||||
DisplayName: resp.DisplayName,
|
||||
Summary: resp.Summary,
|
||||
RegistryName: c.Name(),
|
||||
}
|
||||
|
||||
if resp.LatestVersion != nil {
|
||||
meta.LatestVersion = resp.LatestVersion.Version
|
||||
}
|
||||
if resp.Moderation != nil {
|
||||
meta.IsMalwareBlocked = resp.Moderation.IsMalwareBlocked
|
||||
meta.IsSuspicious = resp.Moderation.IsSuspicious
|
||||
}
|
||||
|
||||
return meta, nil
|
||||
}
|
||||
|
||||
// --- DownloadAndInstall ---
|
||||
|
||||
// DownloadAndInstall fetches metadata (with fallback), resolves version,
|
||||
// downloads the skill ZIP, and extracts it to targetDir.
|
||||
// Returns an InstallResult for the caller to use for moderation decisions.
|
||||
func (c *ClawHubRegistry) DownloadAndInstall(ctx context.Context, slug, version, targetDir string) (*InstallResult, error) {
|
||||
if err := utils.ValidateSkillIdentifier(slug); err != nil {
|
||||
return nil, fmt.Errorf("invalid slug %q: error: %s", slug, err.Error())
|
||||
}
|
||||
|
||||
// Step 1: Fetch metadata (with fallback).
|
||||
result := &InstallResult{}
|
||||
meta, err := c.GetSkillMeta(ctx, slug)
|
||||
if err != nil {
|
||||
// Fallback: proceed without metadata.
|
||||
meta = nil
|
||||
}
|
||||
|
||||
if meta != nil {
|
||||
result.IsMalwareBlocked = meta.IsMalwareBlocked
|
||||
result.IsSuspicious = meta.IsSuspicious
|
||||
result.Summary = meta.Summary
|
||||
}
|
||||
|
||||
// Step 2: Resolve version.
|
||||
installVersion := version
|
||||
if installVersion == "" && meta != nil {
|
||||
installVersion = meta.LatestVersion
|
||||
}
|
||||
if installVersion == "" {
|
||||
installVersion = "latest"
|
||||
}
|
||||
result.Version = installVersion
|
||||
|
||||
// Step 3: Download ZIP to temp file (streams in ~32KB chunks).
|
||||
u, err := url.Parse(c.baseURL + c.downloadPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid base URL: %w", err)
|
||||
}
|
||||
|
||||
q := u.Query()
|
||||
q.Set("slug", slug)
|
||||
if installVersion != "latest" {
|
||||
q.Set("version", installVersion)
|
||||
}
|
||||
u.RawQuery = q.Encode()
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", u.String(), nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
if c.authToken != "" {
|
||||
req.Header.Set("Authorization", "Bearer "+c.authToken)
|
||||
}
|
||||
|
||||
tmpPath, err := utils.DownloadToFile(ctx, c.client, req, int64(c.maxZipSize))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("download failed: %w", err)
|
||||
}
|
||||
defer os.Remove(tmpPath)
|
||||
|
||||
// Step 4: Extract from file on disk.
|
||||
if err := utils.ExtractZipFile(tmpPath, targetDir); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// --- HTTP helper ---
|
||||
|
||||
func (c *ClawHubRegistry) doGet(ctx context.Context, urlStr string) ([]byte, error) {
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", urlStr, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
req.Header.Set("Accept", "application/json")
|
||||
if c.authToken != "" {
|
||||
req.Header.Set("Authorization", "Bearer "+c.authToken)
|
||||
}
|
||||
|
||||
resp, err := c.client.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
// Limit response body read to prevent memory issues.
|
||||
body, err := io.ReadAll(io.LimitReader(resp.Body, int64(c.maxResponseSize)))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read response: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||
return nil, fmt.Errorf("HTTP %d: %s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
return body, nil
|
||||
}
|
||||
@@ -0,0 +1,256 @@
|
||||
package skills
|
||||
|
||||
import (
|
||||
"archive/zip"
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/utils"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func newTestRegistry(serverURL, authToken string) *ClawHubRegistry {
|
||||
return NewClawHubRegistry(ClawHubConfig{
|
||||
Enabled: true,
|
||||
BaseURL: serverURL,
|
||||
AuthToken: authToken,
|
||||
})
|
||||
}
|
||||
|
||||
func TestClawHubRegistrySearch(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
assert.Equal(t, "/api/v1/search", r.URL.Path)
|
||||
assert.Equal(t, "github", r.URL.Query().Get("q"))
|
||||
|
||||
slug := "github"
|
||||
name := "GitHub Integration"
|
||||
summary := "Interact with GitHub repos"
|
||||
version := "1.0.0"
|
||||
|
||||
json.NewEncoder(w).Encode(clawhubSearchResponse{
|
||||
Results: []clawhubSearchResult{
|
||||
{Score: 0.95, Slug: &slug, DisplayName: &name, Summary: &summary, Version: &version},
|
||||
},
|
||||
})
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
reg := newTestRegistry(srv.URL, "")
|
||||
results, err := reg.Search(context.Background(), "github", 5)
|
||||
|
||||
require.NoError(t, err)
|
||||
require.Len(t, results, 1)
|
||||
assert.Equal(t, "github", results[0].Slug)
|
||||
assert.Equal(t, "GitHub Integration", results[0].DisplayName)
|
||||
assert.InDelta(t, 0.95, results[0].Score, 0.001)
|
||||
assert.Equal(t, "clawhub", results[0].RegistryName)
|
||||
}
|
||||
|
||||
func TestClawHubRegistryGetSkillMeta(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
assert.Equal(t, "/api/v1/skills/github", r.URL.Path)
|
||||
|
||||
json.NewEncoder(w).Encode(clawhubSkillResponse{
|
||||
Slug: "github",
|
||||
DisplayName: "GitHub Integration",
|
||||
Summary: "Full GitHub API integration",
|
||||
LatestVersion: &clawhubVersionInfo{
|
||||
Version: "2.1.0",
|
||||
},
|
||||
Moderation: &clawhubModerationInfo{
|
||||
IsMalwareBlocked: false,
|
||||
IsSuspicious: true,
|
||||
},
|
||||
})
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
reg := newTestRegistry(srv.URL, "")
|
||||
meta, err := reg.GetSkillMeta(context.Background(), "github")
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "github", meta.Slug)
|
||||
assert.Equal(t, "2.1.0", meta.LatestVersion)
|
||||
assert.False(t, meta.IsMalwareBlocked)
|
||||
assert.True(t, meta.IsSuspicious)
|
||||
}
|
||||
|
||||
func TestClawHubRegistryGetSkillMetaUnsafeSlug(t *testing.T) {
|
||||
reg := newTestRegistry("https://example.com", "")
|
||||
_, err := reg.GetSkillMeta(context.Background(), "../etc/passwd")
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "invalid slug")
|
||||
}
|
||||
|
||||
func TestClawHubRegistryDownloadAndInstall(t *testing.T) {
|
||||
// Create a valid ZIP in memory.
|
||||
zipBuf := createTestZip(t, map[string]string{
|
||||
"SKILL.md": "---\nname: test-skill\ndescription: A test\n---\nHello skill",
|
||||
"README.md": "# Test Skill\n",
|
||||
})
|
||||
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case "/api/v1/skills/test-skill":
|
||||
// Metadata endpoint.
|
||||
json.NewEncoder(w).Encode(clawhubSkillResponse{
|
||||
Slug: "test-skill",
|
||||
DisplayName: "Test Skill",
|
||||
Summary: "A test skill",
|
||||
LatestVersion: &clawhubVersionInfo{Version: "1.0.0"},
|
||||
})
|
||||
case "/api/v1/download":
|
||||
assert.Equal(t, "test-skill", r.URL.Query().Get("slug"))
|
||||
w.Header().Set("Content-Type", "application/zip")
|
||||
w.Write(zipBuf)
|
||||
default:
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
}
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
tmpDir := t.TempDir()
|
||||
targetDir := filepath.Join(tmpDir, "test-skill")
|
||||
|
||||
reg := newTestRegistry(srv.URL, "")
|
||||
result, err := reg.DownloadAndInstall(context.Background(), "test-skill", "1.0.0", targetDir)
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "1.0.0", result.Version)
|
||||
assert.False(t, result.IsMalwareBlocked)
|
||||
|
||||
// Verify extracted files.
|
||||
skillContent, err := os.ReadFile(filepath.Join(targetDir, "SKILL.md"))
|
||||
require.NoError(t, err)
|
||||
assert.Contains(t, string(skillContent), "Hello skill")
|
||||
|
||||
readmeContent, err := os.ReadFile(filepath.Join(targetDir, "README.md"))
|
||||
require.NoError(t, err)
|
||||
assert.Contains(t, string(readmeContent), "# Test Skill")
|
||||
}
|
||||
|
||||
func TestClawHubRegistryAuthToken(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
authHeader := r.Header.Get("Authorization")
|
||||
assert.Equal(t, "Bearer test-token-123", authHeader)
|
||||
json.NewEncoder(w).Encode(clawhubSearchResponse{Results: nil})
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
reg := newTestRegistry(srv.URL, "test-token-123")
|
||||
_, _ = reg.Search(context.Background(), "test", 5)
|
||||
}
|
||||
|
||||
func TestExtractZipPathTraversal(t *testing.T) {
|
||||
// Create a ZIP with a path traversal entry.
|
||||
var buf bytes.Buffer
|
||||
zw := zip.NewWriter(&buf)
|
||||
|
||||
// Malicious entry trying to escape directory.
|
||||
w, err := zw.Create("../../etc/passwd")
|
||||
require.NoError(t, err)
|
||||
w.Write([]byte("malicious"))
|
||||
|
||||
zw.Close()
|
||||
|
||||
// Write to temp file for extractZipFile.
|
||||
tmpZip := filepath.Join(t.TempDir(), "bad.zip")
|
||||
require.NoError(t, os.WriteFile(tmpZip, buf.Bytes(), 0644))
|
||||
|
||||
tmpDir := t.TempDir()
|
||||
err = utils.ExtractZipFile(tmpZip, tmpDir)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "unsafe path")
|
||||
}
|
||||
|
||||
func TestExtractZipWithSubdirectories(t *testing.T) {
|
||||
zipBuf := createTestZip(t, map[string]string{
|
||||
"SKILL.md": "root file",
|
||||
"scripts/helper.sh": "#!/bin/bash\necho hello",
|
||||
"examples/demo.yaml": "key: value",
|
||||
})
|
||||
|
||||
// Write to temp file for extractZipFile.
|
||||
tmpZip := filepath.Join(t.TempDir(), "test.zip")
|
||||
require.NoError(t, os.WriteFile(tmpZip, zipBuf, 0644))
|
||||
|
||||
tmpDir := t.TempDir()
|
||||
targetDir := filepath.Join(tmpDir, "my-skill")
|
||||
|
||||
err := utils.ExtractZipFile(tmpZip, targetDir)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify nested file.
|
||||
data, err := os.ReadFile(filepath.Join(targetDir, "scripts", "helper.sh"))
|
||||
require.NoError(t, err)
|
||||
assert.Contains(t, string(data), "#!/bin/bash")
|
||||
}
|
||||
|
||||
func TestClawHubRegistrySearchHTTPError(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
w.Write([]byte("Internal Server Error"))
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
reg := newTestRegistry(srv.URL, "")
|
||||
_, err := reg.Search(context.Background(), "test", 5)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "500")
|
||||
}
|
||||
|
||||
func TestClawHubRegistrySearchNullableFields(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
validSlug := "valid-slug"
|
||||
validSummary := "valid summary"
|
||||
|
||||
// Return results with various null/empty fields
|
||||
json.NewEncoder(w).Encode(clawhubSearchResponse{
|
||||
Results: []clawhubSearchResult{
|
||||
// Case 1: Null Slug -> Skip
|
||||
{Score: 0.1, Slug: nil, DisplayName: nil, Summary: nil, Version: nil},
|
||||
// Case 2: Valid Slug, Null Summary -> Skip
|
||||
{Score: 0.2, Slug: &validSlug, DisplayName: nil, Summary: nil, Version: nil},
|
||||
// Case 3: Valid Slug, Valid Summary, Null Name -> Keep, Name=Slug
|
||||
{Score: 0.8, Slug: &validSlug, DisplayName: nil, Summary: &validSummary, Version: nil},
|
||||
},
|
||||
})
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
reg := newTestRegistry(srv.URL, "")
|
||||
results, err := reg.Search(context.Background(), "test", 5)
|
||||
|
||||
require.NoError(t, err)
|
||||
require.Len(t, results, 1, "should only return 1 valid result")
|
||||
|
||||
r := results[0]
|
||||
assert.Equal(t, "valid-slug", r.Slug)
|
||||
assert.Equal(t, "valid-slug", r.DisplayName, "should fallback name to slug")
|
||||
assert.Equal(t, "valid summary", r.Summary)
|
||||
}
|
||||
|
||||
// --- helpers ---
|
||||
|
||||
func createTestZip(t *testing.T, files map[string]string) []byte {
|
||||
t.Helper()
|
||||
var buf bytes.Buffer
|
||||
zw := zip.NewWriter(&buf)
|
||||
|
||||
for name, content := range files {
|
||||
w, err := zw.Create(name)
|
||||
require.NoError(t, err)
|
||||
_, err = w.Write([]byte(content))
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
require.NoError(t, zw.Close())
|
||||
return buf.Bytes()
|
||||
}
|
||||
@@ -0,0 +1,223 @@
|
||||
package skills
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// defaultMaxConcurrentSearches is the concurrency setting a RegistryManager
// starts with when the configuration does not provide a positive value.
const defaultMaxConcurrentSearches = 2
|
||||
|
||||
// SearchResult is one hit returned by a skill registry search.
type SearchResult struct {
	// Score is the relevance score reported by the registry.
	Score float64 `json:"score"`
	// Slug identifies the skill within its registry.
	Slug string `json:"slug"`
	// DisplayName is the human-readable skill name.
	DisplayName string `json:"display_name"`
	// Summary is a short description of the skill.
	Summary string `json:"summary"`
	// Version is the version reported for this hit; it may be empty.
	Version string `json:"version"`
	// RegistryName names the registry that produced the hit.
	RegistryName string `json:"registry_name"`
}
|
||||
|
||||
// SkillMeta describes a single skill as reported by a registry.
type SkillMeta struct {
	Slug          string `json:"slug"`
	DisplayName   string `json:"display_name"`
	Summary       string `json:"summary"`
	LatestVersion string `json:"latest_version"`

	// Moderation flags reported by the registry.
	IsMalwareBlocked bool `json:"is_malware_blocked"`
	IsSuspicious     bool `json:"is_suspicious"`

	// RegistryName names the registry the metadata came from.
	RegistryName string `json:"registry_name"`
}
|
||||
|
||||
// InstallResult is what DownloadAndInstall hands back to its caller: the
// data needed for moderation decisions and for user-facing messages.
type InstallResult struct {
	// Version is the version that was resolved for the install.
	Version string
	// IsMalwareBlocked and IsSuspicious mirror the registry's moderation
	// flags.
	IsMalwareBlocked bool
	IsSuspicious     bool
	// Summary is the skill's short description; it may be empty.
	Summary string
}
|
||||
|
||||
// SkillRegistry is the interface that all skill registries must implement.
|
||||
// Each registry represents a different source of skills (e.g., clawhub.ai)
|
||||
type SkillRegistry interface {
|
||||
// Name returns the unique name of this registry (e.g., "clawhub").
|
||||
Name() string
|
||||
// Search searches the registry for skills matching the query.
|
||||
Search(ctx context.Context, query string, limit int) ([]SearchResult, error)
|
||||
// GetSkillMeta retrieves metadata for a specific skill by slug.
|
||||
GetSkillMeta(ctx context.Context, slug string) (*SkillMeta, error)
|
||||
// DownloadAndInstall fetches metadata, resolves the version, downloads and
|
||||
// installs the skill to targetDir. Returns an InstallResult with metadata
|
||||
// for the caller to use for moderation and user messaging.
|
||||
DownloadAndInstall(ctx context.Context, slug, version, targetDir string) (*InstallResult, error)
|
||||
}
|
||||
|
||||
// RegistryConfig holds configuration for all skill registries.
|
||||
// This is the input to NewRegistryManagerFromConfig.
|
||||
type RegistryConfig struct {
|
||||
ClawHub ClawHubConfig
|
||||
MaxConcurrentSearches int
|
||||
}
|
||||
|
||||
// ClawHubConfig holds the settings of the ClawHub registry.
type ClawHubConfig struct {
	// Enabled decides whether NewRegistryManagerFromConfig registers ClawHub.
	Enabled bool

	// Connection settings.
	BaseURL   string
	AuthToken string

	// Endpoint paths.
	SearchPath   string // e.g. "/api/v1/search"
	SkillsPath   string // e.g. "/api/v1/skills"
	DownloadPath string // e.g. "/api/v1/download"

	// Limits; a zero value selects the documented default.
	Timeout         int // seconds, 0 = default (30s)
	MaxZipSize      int // bytes, 0 = default (50MB)
	MaxResponseSize int // bytes, 0 = default (2MB)
}
|
||||
|
||||
// RegistryManager coordinates multiple skill registries.
|
||||
// It fans out search requests and routes installs to the correct registry.
|
||||
type RegistryManager struct {
|
||||
registries []SkillRegistry
|
||||
maxConcurrent int
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
// NewRegistryManager creates an empty RegistryManager.
|
||||
func NewRegistryManager() *RegistryManager {
|
||||
return &RegistryManager{
|
||||
registries: make([]SkillRegistry, 0),
|
||||
maxConcurrent: defaultMaxConcurrentSearches,
|
||||
}
|
||||
}
|
||||
|
||||
// NewRegistryManagerFromConfig builds a RegistryManager from config,
|
||||
// instantiating only the enabled registries.
|
||||
func NewRegistryManagerFromConfig(cfg RegistryConfig) *RegistryManager {
|
||||
rm := NewRegistryManager()
|
||||
if cfg.MaxConcurrentSearches > 0 {
|
||||
rm.maxConcurrent = cfg.MaxConcurrentSearches
|
||||
}
|
||||
if cfg.ClawHub.Enabled {
|
||||
rm.AddRegistry(NewClawHubRegistry(cfg.ClawHub))
|
||||
}
|
||||
return rm
|
||||
}
|
||||
|
||||
// AddRegistry adds a registry to the manager.
|
||||
func (rm *RegistryManager) AddRegistry(r SkillRegistry) {
|
||||
rm.mu.Lock()
|
||||
defer rm.mu.Unlock()
|
||||
rm.registries = append(rm.registries, r)
|
||||
}
|
||||
|
||||
// GetRegistry returns a registry by name, or nil if not found.
|
||||
func (rm *RegistryManager) GetRegistry(name string) SkillRegistry {
|
||||
rm.mu.RLock()
|
||||
defer rm.mu.RUnlock()
|
||||
for _, r := range rm.registries {
|
||||
if r.Name() == name {
|
||||
return r
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// SearchAll fans out the query to all registries concurrently
|
||||
// and merges results sorted by score descending.
|
||||
func (rm *RegistryManager) SearchAll(ctx context.Context, query string, limit int) ([]SearchResult, error) {
|
||||
rm.mu.RLock()
|
||||
regs := make([]SkillRegistry, len(rm.registries))
|
||||
copy(regs, rm.registries)
|
||||
rm.mu.RUnlock()
|
||||
|
||||
if len(regs) == 0 {
|
||||
return nil, fmt.Errorf("no registries configured")
|
||||
}
|
||||
|
||||
type regResult struct {
|
||||
results []SearchResult
|
||||
err error
|
||||
}
|
||||
|
||||
// Semaphore: limit concurrency.
|
||||
sem := make(chan struct{}, rm.maxConcurrent)
|
||||
resultsCh := make(chan regResult, len(regs))
|
||||
|
||||
var wg sync.WaitGroup
|
||||
for _, reg := range regs {
|
||||
wg.Add(1)
|
||||
go func(r SkillRegistry) {
|
||||
defer wg.Done()
|
||||
|
||||
// Acquire semaphore slot.
|
||||
select {
|
||||
case sem <- struct{}{}:
|
||||
defer func() { <-sem }()
|
||||
case <-ctx.Done():
|
||||
resultsCh <- regResult{err: ctx.Err()}
|
||||
return
|
||||
}
|
||||
|
||||
searchCtx, cancel := context.WithTimeout(ctx, 1*time.Minute)
|
||||
defer cancel()
|
||||
|
||||
results, err := r.Search(searchCtx, query, limit)
|
||||
if err != nil {
|
||||
slog.Warn("registry search failed", "registry", r.Name(), "error", err)
|
||||
resultsCh <- regResult{err: err}
|
||||
return
|
||||
}
|
||||
resultsCh <- regResult{results: results}
|
||||
}(reg)
|
||||
}
|
||||
|
||||
// Close results channel after all goroutines complete.
|
||||
go func() {
|
||||
wg.Wait()
|
||||
close(resultsCh)
|
||||
}()
|
||||
|
||||
var merged []SearchResult
|
||||
var lastErr error
|
||||
|
||||
var anyRegistrySucceeded bool
|
||||
for rr := range resultsCh {
|
||||
if rr.err != nil {
|
||||
lastErr = rr.err
|
||||
continue
|
||||
}
|
||||
anyRegistrySucceeded = true
|
||||
merged = append(merged, rr.results...)
|
||||
}
|
||||
|
||||
// If all registries failed, return the last error.
|
||||
if !anyRegistrySucceeded && lastErr != nil {
|
||||
return nil, fmt.Errorf("all registries failed: %w", lastErr)
|
||||
}
|
||||
|
||||
// Sort by score descending.
|
||||
sortByScoreDesc(merged)
|
||||
|
||||
// Clamp to limit.
|
||||
if limit > 0 && len(merged) > limit {
|
||||
merged = merged[:limit]
|
||||
}
|
||||
|
||||
return merged, nil
|
||||
}
|
||||
|
||||
// sortByScoreDesc sorts SearchResults by Score in descending order (insertion sort — small slices).
|
||||
func sortByScoreDesc(results []SearchResult) {
|
||||
for i := 1; i < len(results); i++ {
|
||||
key := results[i]
|
||||
j := i - 1
|
||||
for j >= 0 && results[j].Score < key.Score {
|
||||
results[j+1] = results[j]
|
||||
j--
|
||||
}
|
||||
results[j+1] = key
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,179 @@
|
||||
package skills
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/utils"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
// mockRegistry is a test double implementing SkillRegistry.
|
||||
type mockRegistry struct {
|
||||
name string
|
||||
searchResults []SearchResult
|
||||
searchErr error
|
||||
meta *SkillMeta
|
||||
metaErr error
|
||||
installResult *InstallResult
|
||||
installErr error
|
||||
}
|
||||
|
||||
func (m *mockRegistry) Name() string { return m.name }
|
||||
|
||||
func (m *mockRegistry) Search(_ context.Context, _ string, _ int) ([]SearchResult, error) {
|
||||
return m.searchResults, m.searchErr
|
||||
}
|
||||
|
||||
func (m *mockRegistry) GetSkillMeta(_ context.Context, _ string) (*SkillMeta, error) {
|
||||
return m.meta, m.metaErr
|
||||
}
|
||||
|
||||
func (m *mockRegistry) DownloadAndInstall(_ context.Context, _, _, _ string) (*InstallResult, error) {
|
||||
return m.installResult, m.installErr
|
||||
}
|
||||
|
||||
func TestRegistryManagerSearchAllSingle(t *testing.T) {
|
||||
mgr := NewRegistryManager()
|
||||
mgr.AddRegistry(&mockRegistry{
|
||||
name: "test",
|
||||
searchResults: []SearchResult{
|
||||
{Slug: "skill-a", Score: 0.9, RegistryName: "test"},
|
||||
{Slug: "skill-b", Score: 0.5, RegistryName: "test"},
|
||||
},
|
||||
})
|
||||
|
||||
results, err := mgr.SearchAll(context.Background(), "test query", 10)
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, results, 2)
|
||||
assert.Equal(t, "skill-a", results[0].Slug)
|
||||
}
|
||||
|
||||
func TestRegistryManagerSearchAllMultiple(t *testing.T) {
|
||||
mgr := NewRegistryManager()
|
||||
mgr.AddRegistry(&mockRegistry{
|
||||
name: "alpha",
|
||||
searchResults: []SearchResult{
|
||||
{Slug: "skill-a", Score: 0.8, RegistryName: "alpha"},
|
||||
},
|
||||
})
|
||||
mgr.AddRegistry(&mockRegistry{
|
||||
name: "beta",
|
||||
searchResults: []SearchResult{
|
||||
{Slug: "skill-b", Score: 0.95, RegistryName: "beta"},
|
||||
},
|
||||
})
|
||||
|
||||
results, err := mgr.SearchAll(context.Background(), "test query", 10)
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, results, 2)
|
||||
// Should be sorted by score descending
|
||||
assert.Equal(t, "skill-b", results[0].Slug)
|
||||
assert.Equal(t, "skill-a", results[1].Slug)
|
||||
}
|
||||
|
||||
func TestRegistryManagerSearchAllOneFailsGracefully(t *testing.T) {
|
||||
mgr := NewRegistryManager()
|
||||
mgr.AddRegistry(&mockRegistry{
|
||||
name: "failing",
|
||||
searchErr: fmt.Errorf("network error"),
|
||||
})
|
||||
mgr.AddRegistry(&mockRegistry{
|
||||
name: "working",
|
||||
searchResults: []SearchResult{
|
||||
{Slug: "skill-a", Score: 0.8, RegistryName: "working"},
|
||||
},
|
||||
})
|
||||
|
||||
results, err := mgr.SearchAll(context.Background(), "test query", 10)
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, results, 1)
|
||||
assert.Equal(t, "skill-a", results[0].Slug)
|
||||
}
|
||||
|
||||
func TestRegistryManagerSearchAllAllFail(t *testing.T) {
|
||||
mgr := NewRegistryManager()
|
||||
mgr.AddRegistry(&mockRegistry{
|
||||
name: "fail-1",
|
||||
searchErr: fmt.Errorf("error 1"),
|
||||
})
|
||||
|
||||
_, err := mgr.SearchAll(context.Background(), "test query", 10)
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestRegistryManagerSearchAllNoRegistries(t *testing.T) {
|
||||
mgr := NewRegistryManager()
|
||||
_, err := mgr.SearchAll(context.Background(), "test query", 10)
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestRegistryManagerGetRegistry(t *testing.T) {
|
||||
mgr := NewRegistryManager()
|
||||
mock := &mockRegistry{name: "clawhub"}
|
||||
mgr.AddRegistry(mock)
|
||||
|
||||
got := mgr.GetRegistry("clawhub")
|
||||
assert.NotNil(t, got)
|
||||
assert.Equal(t, "clawhub", got.Name())
|
||||
|
||||
got = mgr.GetRegistry("nonexistent")
|
||||
assert.Nil(t, got)
|
||||
}
|
||||
|
||||
func TestRegistryManagerSearchAllRespectLimit(t *testing.T) {
|
||||
mgr := NewRegistryManager()
|
||||
results := make([]SearchResult, 20)
|
||||
for i := range results {
|
||||
results[i] = SearchResult{Slug: fmt.Sprintf("skill-%d", i), Score: float64(20 - i)}
|
||||
}
|
||||
mgr.AddRegistry(&mockRegistry{
|
||||
name: "test",
|
||||
searchResults: results,
|
||||
})
|
||||
|
||||
got, err := mgr.SearchAll(context.Background(), "test", 5)
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, got, 5)
|
||||
// Top scores first
|
||||
assert.Equal(t, "skill-0", got[0].Slug)
|
||||
}
|
||||
|
||||
func TestRegistryManagerSearchAllTimeout(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Millisecond)
|
||||
defer cancel()
|
||||
|
||||
time.Sleep(5 * time.Millisecond) // Let context expire.
|
||||
|
||||
mgr := NewRegistryManager()
|
||||
mgr.AddRegistry(&mockRegistry{
|
||||
name: "slow",
|
||||
searchErr: fmt.Errorf("context deadline exceeded"),
|
||||
})
|
||||
|
||||
_, err := mgr.SearchAll(ctx, "test", 5)
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestSortByScoreDesc(t *testing.T) {
|
||||
results := []SearchResult{
|
||||
{Slug: "c", Score: 0.3},
|
||||
{Slug: "a", Score: 0.9},
|
||||
{Slug: "b", Score: 0.5},
|
||||
}
|
||||
sortByScoreDesc(results)
|
||||
assert.Equal(t, "a", results[0].Slug)
|
||||
assert.Equal(t, "b", results[1].Slug)
|
||||
assert.Equal(t, "c", results[2].Slug)
|
||||
}
|
||||
|
||||
func TestIsSafeSlug(t *testing.T) {
|
||||
assert.NoError(t, utils.ValidateSkillIdentifier("github"))
|
||||
assert.NoError(t, utils.ValidateSkillIdentifier("docker-compose"))
|
||||
assert.Error(t, utils.ValidateSkillIdentifier(""))
|
||||
assert.Error(t, utils.ValidateSkillIdentifier("../etc/passwd"))
|
||||
assert.Error(t, utils.ValidateSkillIdentifier("path/traversal"))
|
||||
assert.Error(t, utils.ValidateSkillIdentifier("path\\traversal"))
|
||||
}
|
||||
@@ -0,0 +1,229 @@
|
||||
package skills
|
||||
|
||||
import (
|
||||
"sort"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// SearchCache provides lightweight caching for search results.
|
||||
// It uses trigram-based similarity to match similar queries to cached results,
|
||||
// avoiding redundant API calls. Thread-safe for concurrent access.
|
||||
type SearchCache struct {
|
||||
mu sync.RWMutex
|
||||
entries map[string]*cacheEntry
|
||||
order []string // LRU order: oldest first.
|
||||
maxEntries int
|
||||
ttl time.Duration
|
||||
}
|
||||
|
||||
type cacheEntry struct {
|
||||
query string
|
||||
trigrams []uint32
|
||||
results []SearchResult
|
||||
createdAt time.Time
|
||||
}
|
||||
|
||||
// similarityThreshold is the lowest trigram Jaccard similarity at which a
// cached query still counts as a hit for a different query.
const similarityThreshold = 0.7
|
||||
|
||||
// NewSearchCache creates a new search cache.
|
||||
// maxEntries is the maximum number of cached queries (excess evicts LRU).
|
||||
// ttl is how long each entry lives before expiration.
|
||||
func NewSearchCache(maxEntries int, ttl time.Duration) *SearchCache {
|
||||
if maxEntries <= 0 {
|
||||
maxEntries = 50
|
||||
}
|
||||
if ttl <= 0 {
|
||||
ttl = 5 * time.Minute
|
||||
}
|
||||
return &SearchCache{
|
||||
entries: make(map[string]*cacheEntry),
|
||||
order: make([]string, 0),
|
||||
maxEntries: maxEntries,
|
||||
ttl: ttl,
|
||||
}
|
||||
}
|
||||
|
||||
// Get looks up results for a query. Returns cached results and true if found
|
||||
// (either exact or similar match above threshold). Returns nil, false on miss.
|
||||
func (sc *SearchCache) Get(query string) ([]SearchResult, bool) {
|
||||
normalized := normalizeQuery(query)
|
||||
if normalized == "" {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
sc.mu.Lock()
|
||||
defer sc.mu.Unlock()
|
||||
|
||||
// Exact match first.
|
||||
if entry, ok := sc.entries[normalized]; ok {
|
||||
if time.Since(entry.createdAt) < sc.ttl {
|
||||
sc.moveToEndLocked(normalized)
|
||||
return copyResults(entry.results), true
|
||||
}
|
||||
}
|
||||
|
||||
// Similarity match.
|
||||
queryTrigrams := buildTrigrams(normalized)
|
||||
var bestEntry *cacheEntry
|
||||
var bestSim float64
|
||||
|
||||
for _, entry := range sc.entries {
|
||||
if time.Since(entry.createdAt) >= sc.ttl {
|
||||
continue // Skip expired.
|
||||
}
|
||||
sim := jaccardSimilarity(queryTrigrams, entry.trigrams)
|
||||
if sim > bestSim {
|
||||
bestSim = sim
|
||||
bestEntry = entry
|
||||
}
|
||||
}
|
||||
|
||||
if bestSim >= similarityThreshold && bestEntry != nil {
|
||||
sc.moveToEndLocked(bestEntry.query)
|
||||
return copyResults(bestEntry.results), true
|
||||
}
|
||||
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// Put stores results for a query. Evicts the oldest entry if at capacity.
|
||||
func (sc *SearchCache) Put(query string, results []SearchResult) {
|
||||
normalized := normalizeQuery(query)
|
||||
if normalized == "" {
|
||||
return
|
||||
}
|
||||
|
||||
sc.mu.Lock()
|
||||
defer sc.mu.Unlock()
|
||||
|
||||
// Evict expired entries first.
|
||||
sc.evictExpiredLocked()
|
||||
|
||||
// If already exists, update.
|
||||
if _, ok := sc.entries[normalized]; ok {
|
||||
sc.entries[normalized] = &cacheEntry{
|
||||
query: normalized,
|
||||
trigrams: buildTrigrams(normalized),
|
||||
results: copyResults(results),
|
||||
createdAt: time.Now(),
|
||||
}
|
||||
// Move to end of LRU order.
|
||||
sc.moveToEndLocked(normalized)
|
||||
return
|
||||
}
|
||||
|
||||
// Evict LRU if at capacity.
|
||||
for len(sc.entries) >= sc.maxEntries && len(sc.order) > 0 {
|
||||
oldest := sc.order[0]
|
||||
sc.order = sc.order[1:]
|
||||
delete(sc.entries, oldest)
|
||||
}
|
||||
|
||||
// Insert new entry.
|
||||
sc.entries[normalized] = &cacheEntry{
|
||||
query: normalized,
|
||||
trigrams: buildTrigrams(normalized),
|
||||
results: copyResults(results),
|
||||
createdAt: time.Now(),
|
||||
}
|
||||
sc.order = append(sc.order, normalized)
|
||||
}
|
||||
|
||||
// Len returns the number of entries (for testing).
|
||||
func (sc *SearchCache) Len() int {
|
||||
sc.mu.RLock()
|
||||
defer sc.mu.RUnlock()
|
||||
return len(sc.entries)
|
||||
}
|
||||
|
||||
// --- internal ---
|
||||
|
||||
func (sc *SearchCache) evictExpiredLocked() {
|
||||
now := time.Now()
|
||||
newOrder := make([]string, 0, len(sc.order))
|
||||
for _, key := range sc.order {
|
||||
entry, ok := sc.entries[key]
|
||||
if !ok || now.Sub(entry.createdAt) >= sc.ttl {
|
||||
delete(sc.entries, key)
|
||||
continue
|
||||
}
|
||||
newOrder = append(newOrder, key)
|
||||
}
|
||||
sc.order = newOrder
|
||||
}
|
||||
|
||||
func (sc *SearchCache) moveToEndLocked(key string) {
|
||||
for i, k := range sc.order {
|
||||
if k == key {
|
||||
sc.order = append(sc.order[:i], sc.order[i+1:]...)
|
||||
break
|
||||
}
|
||||
}
|
||||
sc.order = append(sc.order, key)
|
||||
}
|
||||
|
||||
// normalizeQuery canonicalizes a query for use as a cache key: surrounding
// whitespace is trimmed and the text is lower-cased.
func normalizeQuery(q string) string {
	trimmed := strings.TrimSpace(q)
	return strings.ToLower(trimmed)
}
|
||||
|
||||
// buildTrigrams returns the set of byte trigrams of s, each packed into a
// uint32 (first byte in bits 16-23, last byte in bits 0-7).
// Example: "hello" yields "hel", "ell", "llo"; "hel" becomes 0x0068656c.
// The result is sorted ascending and free of duplicates. Strings shorter
// than three bytes yield nil.
func buildTrigrams(s string) []uint32 {
	count := len(s) - 2
	if count <= 0 {
		return nil
	}

	hashes := make([]uint32, count)
	for pos := range hashes {
		hashes[pos] = uint32(s[pos])<<16 | uint32(s[pos+1])<<8 | uint32(s[pos+2])
	}

	// Sort, then compact equal neighbours in place.
	sort.Slice(hashes, func(a, b int) bool { return hashes[a] < hashes[b] })
	unique := hashes[:1]
	for _, h := range hashes[1:] {
		if h != unique[len(unique)-1] {
			unique = append(unique, h)
		}
	}
	return unique
}
|
||||
|
||||
// jaccardSimilarity computes |A ∩ B| / |A ∪ B| for two trigram sets.
// The merge walk below only counts the intersection correctly when both
// slices are sorted ascending without duplicates, as buildTrigrams produces
// them. Two empty sets are defined to have similarity 1.
func jaccardSimilarity(a, b []uint32) float64 {
	if len(a)+len(b) == 0 {
		return 1
	}

	shared := 0
	for x, y := 0, 0; x < len(a) && y < len(b); {
		switch {
		case a[x] < b[y]:
			x++
		case a[x] > b[y]:
			y++
		default:
			shared++
			x++
			y++
		}
	}

	// |A ∪ B| = |A| + |B| - |A ∩ B|
	return float64(shared) / float64(len(a)+len(b)-shared)
}
|
||||
|
||||
func copyResults(results []SearchResult) []SearchResult {
|
||||
if results == nil {
|
||||
return nil
|
||||
}
|
||||
cp := make([]SearchResult, len(results))
|
||||
copy(cp, results)
|
||||
return cp
|
||||
}
|
||||
@@ -0,0 +1,200 @@
|
||||
package skills
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestSearchCacheExactHit(t *testing.T) {
|
||||
cache := NewSearchCache(10, 5*time.Minute)
|
||||
|
||||
results := []SearchResult{
|
||||
{Slug: "github", Score: 0.9, RegistryName: "clawhub"},
|
||||
{Slug: "docker", Score: 0.7, RegistryName: "clawhub"},
|
||||
}
|
||||
cache.Put("github integration", results)
|
||||
|
||||
got, hit := cache.Get("github integration")
|
||||
assert.True(t, hit)
|
||||
assert.Len(t, got, 2)
|
||||
assert.Equal(t, "github", got[0].Slug)
|
||||
}
|
||||
|
||||
func TestSearchCacheExactHitCaseInsensitive(t *testing.T) {
|
||||
cache := NewSearchCache(10, 5*time.Minute)
|
||||
|
||||
results := []SearchResult{{Slug: "github", Score: 0.9}}
|
||||
cache.Put("GitHub Integration", results)
|
||||
|
||||
got, hit := cache.Get("github integration")
|
||||
assert.True(t, hit)
|
||||
assert.Len(t, got, 1)
|
||||
}
|
||||
|
||||
func TestSearchCacheSimilarHit(t *testing.T) {
|
||||
cache := NewSearchCache(10, 5*time.Minute)
|
||||
|
||||
results := []SearchResult{{Slug: "github", Score: 0.9}}
|
||||
cache.Put("github integration tool", results)
|
||||
|
||||
// "github integration" is very similar to "github integration tool"
|
||||
got, hit := cache.Get("github integration")
|
||||
assert.True(t, hit)
|
||||
assert.Len(t, got, 1)
|
||||
}
|
||||
|
||||
func TestSearchCacheDissimilarMiss(t *testing.T) {
|
||||
cache := NewSearchCache(10, 5*time.Minute)
|
||||
|
||||
results := []SearchResult{{Slug: "github", Score: 0.9}}
|
||||
cache.Put("github integration", results)
|
||||
|
||||
// Completely unrelated query
|
||||
_, hit := cache.Get("database management")
|
||||
assert.False(t, hit)
|
||||
}
|
||||
|
||||
func TestSearchCacheTTLExpiration(t *testing.T) {
|
||||
cache := NewSearchCache(10, 50*time.Millisecond)
|
||||
|
||||
results := []SearchResult{{Slug: "github", Score: 0.9}}
|
||||
cache.Put("github integration", results)
|
||||
|
||||
// Immediately should hit
|
||||
_, hit := cache.Get("github integration")
|
||||
assert.True(t, hit)
|
||||
|
||||
// Wait for expiration
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
_, hit = cache.Get("github integration")
|
||||
assert.False(t, hit)
|
||||
}
|
||||
|
||||
func TestSearchCacheLRUEviction(t *testing.T) {
|
||||
cache := NewSearchCache(3, 5*time.Minute)
|
||||
|
||||
cache.Put("query-1", []SearchResult{{Slug: "a"}})
|
||||
cache.Put("query-2", []SearchResult{{Slug: "b"}})
|
||||
cache.Put("query-3", []SearchResult{{Slug: "c"}})
|
||||
|
||||
assert.Equal(t, 3, cache.Len())
|
||||
|
||||
// Adding a 4th should evict query-1 (oldest)
|
||||
cache.Put("query-4", []SearchResult{{Slug: "d"}})
|
||||
assert.Equal(t, 3, cache.Len())
|
||||
|
||||
_, hit := cache.Get("query-1")
|
||||
assert.False(t, hit, "oldest entry should be evicted")
|
||||
|
||||
got, hit := cache.Get("query-4")
|
||||
assert.True(t, hit)
|
||||
assert.Equal(t, "d", got[0].Slug)
|
||||
}
|
||||
|
||||
func TestSearchCacheEmptyQuery(t *testing.T) {
|
||||
cache := NewSearchCache(10, 5*time.Minute)
|
||||
|
||||
_, hit := cache.Get("")
|
||||
assert.False(t, hit)
|
||||
|
||||
_, hit = cache.Get(" ")
|
||||
assert.False(t, hit)
|
||||
}
|
||||
|
||||
func TestSearchCacheResultsCopied(t *testing.T) {
|
||||
cache := NewSearchCache(10, 5*time.Minute)
|
||||
|
||||
original := []SearchResult{{Slug: "github", Score: 0.9}}
|
||||
cache.Put("test", original)
|
||||
|
||||
// Mutate original after putting
|
||||
original[0].Slug = "mutated"
|
||||
|
||||
got, hit := cache.Get("test")
|
||||
assert.True(t, hit)
|
||||
assert.Equal(t, "github", got[0].Slug, "cache should hold a copy, not a reference")
|
||||
}
|
||||
|
||||
func TestBuildTrigrams(t *testing.T) {
|
||||
trigrams := buildTrigrams("hello")
|
||||
assert.Contains(t, trigrams, uint32('h')<<16|uint32('e')<<8|uint32('l'))
|
||||
assert.Contains(t, trigrams, uint32('e')<<16|uint32('l')<<8|uint32('l'))
|
||||
assert.Contains(t, trigrams, uint32('l')<<16|uint32('l')<<8|uint32('o'))
|
||||
assert.Len(t, trigrams, 3)
|
||||
}
|
||||
|
||||
func TestJaccardSimilarity(t *testing.T) {
|
||||
a := buildTrigrams("github integration")
|
||||
b := buildTrigrams("github integration tool")
|
||||
|
||||
sim := jaccardSimilarity(a, b)
|
||||
assert.Greater(t, sim, 0.5, "similar strings should have high sim")
|
||||
|
||||
c := buildTrigrams("completely different query about databases")
|
||||
sim2 := jaccardSimilarity(a, c)
|
||||
assert.Less(t, sim2, 0.3, "dissimilar strings should have low sim")
|
||||
}
|
||||
|
||||
func TestJaccardSimilarityEdgeCases(t *testing.T) {
|
||||
empty := buildTrigrams("")
|
||||
nonempty := buildTrigrams("hello")
|
||||
|
||||
assert.Equal(t, 1.0, jaccardSimilarity(empty, empty))
|
||||
assert.Equal(t, 0.0, jaccardSimilarity(empty, nonempty))
|
||||
assert.Equal(t, 0.0, jaccardSimilarity(nonempty, empty))
|
||||
}
|
||||
|
||||
func TestSearchCacheConcurrency(t *testing.T) {
|
||||
cache := NewSearchCache(50, 5*time.Minute)
|
||||
done := make(chan struct{})
|
||||
|
||||
// Concurrent writes
|
||||
go func() {
|
||||
for i := 0; i < 100; i++ {
|
||||
cache.Put("query-write-"+string(rune('a'+i%26)), []SearchResult{{Slug: "x"}})
|
||||
}
|
||||
done <- struct{}{}
|
||||
}()
|
||||
|
||||
// Concurrent reads
|
||||
go func() {
|
||||
for i := 0; i < 100; i++ {
|
||||
cache.Get("query-write-a")
|
||||
}
|
||||
done <- struct{}{}
|
||||
}()
|
||||
|
||||
<-done
|
||||
}
|
||||
|
||||
func TestSearchCacheLRUUpdateOnGet(t *testing.T) {
|
||||
// Capacity 3
|
||||
cache := NewSearchCache(3, time.Hour)
|
||||
|
||||
// Fill cache: query-A, query-B, query-C
|
||||
// Use longer strings to ensure trigrams are generated and avoid false positive similarity
|
||||
cache.Put("query-A", []SearchResult{{Slug: "A"}})
|
||||
cache.Put("query-B", []SearchResult{{Slug: "B"}})
|
||||
cache.Put("query-C", []SearchResult{{Slug: "C"}})
|
||||
|
||||
// Access query-A (should make it most recently used)
|
||||
if _, found := cache.Get("query-A"); !found {
|
||||
t.Fatal("query-A should be in cache")
|
||||
}
|
||||
|
||||
// Add query-D. Should evict query-B (LRU) instead of query-A (which was refreshed)
|
||||
cache.Put("query-D", []SearchResult{{Slug: "D"}})
|
||||
|
||||
// Check if query-A is still there
|
||||
if _, found := cache.Get("query-A"); !found {
|
||||
t.Fatalf("query-A was evicted! valid LRU should have kept query-A and evicted query-B.")
|
||||
}
|
||||
|
||||
// Check if query-B is evicted
|
||||
if _, found := cache.Get("query-B"); found {
|
||||
t.Fatal("query-B should have been evicted")
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,199 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/logger"
|
||||
"github.com/sipeed/picoclaw/pkg/skills"
|
||||
"github.com/sipeed/picoclaw/pkg/utils"
|
||||
)
|
||||
|
||||
// InstallSkillTool allows the LLM agent to install skills from registries.
|
||||
// It shares the same RegistryManager that FindSkillsTool uses,
|
||||
// so all registries configured in config are available for installation.
|
||||
type InstallSkillTool struct {
|
||||
registryMgr *skills.RegistryManager
|
||||
workspace string
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
// NewInstallSkillTool creates a new InstallSkillTool.
|
||||
// registryMgr is the shared registry manager (same instance as FindSkillsTool).
|
||||
// workspace is the root workspace directory; skills install to {workspace}/skills/{slug}/.
|
||||
func NewInstallSkillTool(registryMgr *skills.RegistryManager, workspace string) *InstallSkillTool {
|
||||
return &InstallSkillTool{
|
||||
registryMgr: registryMgr,
|
||||
workspace: workspace,
|
||||
mu: sync.Mutex{},
|
||||
}
|
||||
}
|
||||
|
||||
func (t *InstallSkillTool) Name() string {
|
||||
return "install_skill"
|
||||
}
|
||||
|
||||
func (t *InstallSkillTool) Description() string {
|
||||
return "Install a skill from a registry by slug. Downloads and extracts the skill into the workspace. Use find_skills first to discover available skills."
|
||||
}
|
||||
|
||||
func (t *InstallSkillTool) Parameters() map[string]interface{} {
|
||||
return map[string]interface{}{
|
||||
"type": "object",
|
||||
"properties": map[string]interface{}{
|
||||
"slug": map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "The unique slug of the skill to install (e.g., 'github', 'docker-compose')",
|
||||
},
|
||||
"version": map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "Specific version to install (optional, defaults to latest)",
|
||||
},
|
||||
"registry": map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "Registry to install from (required, e.g., 'clawhub')",
|
||||
},
|
||||
"force": map[string]interface{}{
|
||||
"type": "boolean",
|
||||
"description": "Force reinstall if skill already exists (default false)",
|
||||
},
|
||||
},
|
||||
"required": []string{"slug", "registry"},
|
||||
}
|
||||
}
|
||||
|
||||
func (t *InstallSkillTool) Execute(ctx context.Context, args map[string]interface{}) *ToolResult {
|
||||
// Install lock to prevent concurrent directory operations.
|
||||
// Ideally this should be done at a `slug` level, currently, its at a `workspace` level.
|
||||
t.mu.Lock()
|
||||
defer t.mu.Unlock()
|
||||
|
||||
// Validate slug
|
||||
slug, _ := args["slug"].(string)
|
||||
if err := utils.ValidateSkillIdentifier(slug); err != nil {
|
||||
return ErrorResult(fmt.Sprintf("invalid slug %q: error: %s", slug, err.Error()))
|
||||
}
|
||||
|
||||
// Validate registry
|
||||
registryName, _ := args["registry"].(string)
|
||||
if err := utils.ValidateSkillIdentifier(registryName); err != nil {
|
||||
return ErrorResult(fmt.Sprintf("invalid registry %q: error: %s", registryName, err.Error()))
|
||||
}
|
||||
|
||||
version, _ := args["version"].(string)
|
||||
force, _ := args["force"].(bool)
|
||||
|
||||
// Check if already installed.
|
||||
skillsDir := filepath.Join(t.workspace, "skills")
|
||||
targetDir := filepath.Join(skillsDir, slug)
|
||||
|
||||
if !force {
|
||||
if _, err := os.Stat(targetDir); err == nil {
|
||||
return ErrorResult(fmt.Sprintf("skill %q already installed at %s. Use force=true to reinstall.", slug, targetDir))
|
||||
}
|
||||
} else {
|
||||
// Force: remove existing if present.
|
||||
os.RemoveAll(targetDir)
|
||||
}
|
||||
|
||||
// Resolve which registry to use.
|
||||
registry := t.registryMgr.GetRegistry(registryName)
|
||||
if registry == nil {
|
||||
return ErrorResult(fmt.Sprintf("registry %q not found", registryName))
|
||||
}
|
||||
|
||||
// Ensure skills directory exists.
|
||||
if err := os.MkdirAll(skillsDir, 0755); err != nil {
|
||||
return ErrorResult(fmt.Sprintf("failed to create skills directory: %v", err))
|
||||
}
|
||||
|
||||
// Download and install (handles metadata, version resolution, extraction).
|
||||
result, err := registry.DownloadAndInstall(ctx, slug, version, targetDir)
|
||||
if err != nil {
|
||||
// Clean up partial install.
|
||||
rmErr := os.RemoveAll(targetDir)
|
||||
if rmErr != nil {
|
||||
logger.ErrorCF("tool", "Failed to remove partial install",
|
||||
map[string]interface{}{
|
||||
"tool": "install_skill",
|
||||
"target_dir": targetDir,
|
||||
"error": rmErr.Error(),
|
||||
})
|
||||
}
|
||||
return ErrorResult(fmt.Sprintf("failed to install %q: %v", slug, err))
|
||||
}
|
||||
|
||||
// Moderation: block malware.
|
||||
if result.IsMalwareBlocked {
|
||||
rmErr := os.RemoveAll(targetDir)
|
||||
if rmErr != nil {
|
||||
logger.ErrorCF("tool", "Failed to remove partial install",
|
||||
map[string]interface{}{
|
||||
"tool": "install_skill",
|
||||
"target_dir": targetDir,
|
||||
"error": rmErr.Error(),
|
||||
})
|
||||
}
|
||||
return ErrorResult(fmt.Sprintf("skill %q is flagged as malicious and cannot be installed", slug))
|
||||
}
|
||||
|
||||
// Write origin metadata.
|
||||
if err := writeOriginMeta(targetDir, registry.Name(), slug, result.Version); err != nil {
|
||||
logger.ErrorCF("tool", "Failed to write origin metadata",
|
||||
map[string]interface{}{
|
||||
"tool": "install_skill",
|
||||
"error": err.Error(),
|
||||
"target": targetDir,
|
||||
"registry": registry.Name(),
|
||||
"slug": slug,
|
||||
"version": result.Version,
|
||||
})
|
||||
_ = err
|
||||
}
|
||||
|
||||
// Build result with moderation warning if suspicious.
|
||||
var output string
|
||||
if result.IsSuspicious {
|
||||
output = fmt.Sprintf("⚠️ Warning: skill %q is flagged as suspicious (may contain risky patterns).\n\n", slug)
|
||||
}
|
||||
output += fmt.Sprintf("Successfully installed skill %q v%s from %s registry.\nLocation: %s\n",
|
||||
slug, result.Version, registry.Name(), targetDir)
|
||||
|
||||
if result.Summary != "" {
|
||||
output += fmt.Sprintf("Description: %s\n", result.Summary)
|
||||
}
|
||||
output += "\nThe skill is now available and can be loaded in the current session."
|
||||
|
||||
return SilentResult(output)
|
||||
}
|
||||
|
||||
// originMeta records where an installed skill came from. It is serialized
// to the .skill-origin.json file inside the skill directory.
type originMeta struct {
	// Version is the format version of this metadata record.
	Version int `json:"version"`
	// Registry is the name of the registry the skill was installed from.
	Registry string `json:"registry"`
	Slug     string `json:"slug"`
	// InstalledVersion is the version of the skill that was installed.
	InstalledVersion string `json:"installed_version"`
	// InstalledAt is the install time in Unix milliseconds.
	InstalledAt int64 `json:"installed_at"`
}

// writeOriginMeta writes an indented JSON originMeta record (format version
// 1, stamped with the current time) to targetDir/.skill-origin.json with
// mode 0644. It returns any marshalling or file-write error.
func writeOriginMeta(targetDir, registryName, slug, version string) error {
	payload, err := json.MarshalIndent(originMeta{
		Version:          1,
		Registry:         registryName,
		Slug:             slug,
		InstalledVersion: version,
		InstalledAt:      time.Now().UnixMilli(),
	}, "", " ")
	if err != nil {
		return err
	}

	dest := filepath.Join(targetDir, ".skill-origin.json")
	return os.WriteFile(dest, payload, 0644)
}
|
||||
@@ -0,0 +1,103 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/skills"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestInstallSkillToolName(t *testing.T) {
|
||||
tool := NewInstallSkillTool(skills.NewRegistryManager(), t.TempDir())
|
||||
assert.Equal(t, "install_skill", tool.Name())
|
||||
}
|
||||
|
||||
func TestInstallSkillToolMissingSlug(t *testing.T) {
|
||||
tool := NewInstallSkillTool(skills.NewRegistryManager(), t.TempDir())
|
||||
result := tool.Execute(context.Background(), map[string]interface{}{})
|
||||
assert.True(t, result.IsError)
|
||||
assert.Contains(t, result.ForLLM, "identifier is required and must be a non-empty string")
|
||||
}
|
||||
|
||||
func TestInstallSkillToolEmptySlug(t *testing.T) {
|
||||
tool := NewInstallSkillTool(skills.NewRegistryManager(), t.TempDir())
|
||||
result := tool.Execute(context.Background(), map[string]interface{}{
|
||||
"slug": " ",
|
||||
})
|
||||
assert.True(t, result.IsError)
|
||||
assert.Contains(t, result.ForLLM, "identifier is required and must be a non-empty string")
|
||||
}
|
||||
|
||||
func TestInstallSkillToolUnsafeSlug(t *testing.T) {
|
||||
tool := NewInstallSkillTool(skills.NewRegistryManager(), t.TempDir())
|
||||
|
||||
cases := []string{
|
||||
"../etc/passwd",
|
||||
"path/traversal",
|
||||
"path\\traversal",
|
||||
}
|
||||
|
||||
for _, slug := range cases {
|
||||
result := tool.Execute(context.Background(), map[string]interface{}{
|
||||
"slug": slug,
|
||||
})
|
||||
assert.True(t, result.IsError, "slug %q should be rejected", slug)
|
||||
assert.Contains(t, result.ForLLM, "invalid slug")
|
||||
}
|
||||
}
|
||||
|
||||
func TestInstallSkillToolAlreadyExists(t *testing.T) {
|
||||
workspace := t.TempDir()
|
||||
skillDir := filepath.Join(workspace, "skills", "existing-skill")
|
||||
require.NoError(t, os.MkdirAll(skillDir, 0755))
|
||||
|
||||
tool := NewInstallSkillTool(skills.NewRegistryManager(), workspace)
|
||||
result := tool.Execute(context.Background(), map[string]interface{}{
|
||||
"slug": "existing-skill",
|
||||
"registry": "clawhub",
|
||||
})
|
||||
assert.True(t, result.IsError)
|
||||
assert.Contains(t, result.ForLLM, "already installed")
|
||||
}
|
||||
|
||||
func TestInstallSkillToolRegistryNotFound(t *testing.T) {
|
||||
workspace := t.TempDir()
|
||||
tool := NewInstallSkillTool(skills.NewRegistryManager(), workspace)
|
||||
result := tool.Execute(context.Background(), map[string]interface{}{
|
||||
"slug": "some-skill",
|
||||
"registry": "nonexistent",
|
||||
})
|
||||
assert.True(t, result.IsError)
|
||||
assert.Contains(t, result.ForLLM, "registry")
|
||||
assert.Contains(t, result.ForLLM, "not found")
|
||||
}
|
||||
|
||||
func TestInstallSkillToolParameters(t *testing.T) {
|
||||
tool := NewInstallSkillTool(skills.NewRegistryManager(), t.TempDir())
|
||||
params := tool.Parameters()
|
||||
|
||||
props, ok := params["properties"].(map[string]interface{})
|
||||
assert.True(t, ok)
|
||||
assert.Contains(t, props, "slug")
|
||||
assert.Contains(t, props, "version")
|
||||
assert.Contains(t, props, "registry")
|
||||
assert.Contains(t, props, "force")
|
||||
|
||||
required, ok := params["required"].([]string)
|
||||
assert.True(t, ok)
|
||||
assert.Contains(t, required, "slug")
|
||||
assert.Contains(t, required, "registry")
|
||||
}
|
||||
|
||||
func TestInstallSkillToolMissingRegistry(t *testing.T) {
|
||||
tool := NewInstallSkillTool(skills.NewRegistryManager(), t.TempDir())
|
||||
result := tool.Execute(context.Background(), map[string]interface{}{
|
||||
"slug": "some-skill",
|
||||
})
|
||||
assert.True(t, result.IsError)
|
||||
assert.Contains(t, result.ForLLM, "invalid registry")
|
||||
}
|
||||
@@ -0,0 +1,119 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/skills"
|
||||
)
|
||||
|
||||
// FindSkillsTool allows the LLM agent to search for installable skills from registries.
type FindSkillsTool struct {
	// registryMgr performs the actual search across registries (SearchAll).
	registryMgr *skills.RegistryManager
	// cache holds results of earlier searches; Execute skips caching when it is nil.
	cache *skills.SearchCache
}
|
||||
|
||||
// NewFindSkillsTool creates a new FindSkillsTool.
|
||||
// registryMgr is the shared registry manager (built from config in createToolRegistry).
|
||||
// cache is the search cache for deduplicating similar queries.
|
||||
func NewFindSkillsTool(registryMgr *skills.RegistryManager, cache *skills.SearchCache) *FindSkillsTool {
|
||||
return &FindSkillsTool{
|
||||
registryMgr: registryMgr,
|
||||
cache: cache,
|
||||
}
|
||||
}
|
||||
|
||||
func (t *FindSkillsTool) Name() string {
|
||||
return "find_skills"
|
||||
}
|
||||
|
||||
func (t *FindSkillsTool) Description() string {
|
||||
return "Search for installable skills from skill registries. Returns skill slugs, descriptions, versions, and relevance scores. Use this to discover skills before installing them with install_skill."
|
||||
}
|
||||
|
||||
func (t *FindSkillsTool) Parameters() map[string]interface{} {
|
||||
return map[string]interface{}{
|
||||
"type": "object",
|
||||
"properties": map[string]interface{}{
|
||||
"query": map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "Search query describing the desired skill capability (e.g., 'github integration', 'database management')",
|
||||
},
|
||||
"limit": map[string]interface{}{
|
||||
"type": "integer",
|
||||
"description": "Maximum number of results to return (1-20, default 5)",
|
||||
"minimum": 1.0,
|
||||
"maximum": 20.0,
|
||||
},
|
||||
},
|
||||
"required": []string{"query"},
|
||||
}
|
||||
}
|
||||
|
||||
func (t *FindSkillsTool) Execute(ctx context.Context, args map[string]interface{}) *ToolResult {
|
||||
query, ok := args["query"].(string)
|
||||
query = strings.ToLower(strings.TrimSpace(query))
|
||||
if !ok || query == "" {
|
||||
return ErrorResult("query is required and must be a non-empty string")
|
||||
}
|
||||
|
||||
limit := 5
|
||||
if l, ok := args["limit"].(float64); ok {
|
||||
li := int(l)
|
||||
if li >= 1 && li <= 20 {
|
||||
limit = li
|
||||
}
|
||||
}
|
||||
|
||||
// Check cache first.
|
||||
if t.cache != nil {
|
||||
if cached, hit := t.cache.Get(query); hit {
|
||||
return SilentResult(formatSearchResults(query, cached, true))
|
||||
}
|
||||
}
|
||||
|
||||
// Search all registries.
|
||||
results, err := t.registryMgr.SearchAll(ctx, query, limit)
|
||||
if err != nil {
|
||||
return ErrorResult(fmt.Sprintf("skill search failed: %v", err))
|
||||
}
|
||||
|
||||
// Cache the results.
|
||||
if t.cache != nil && len(results) > 0 {
|
||||
t.cache.Put(query, results)
|
||||
}
|
||||
|
||||
return SilentResult(formatSearchResults(query, results, false))
|
||||
}
|
||||
|
||||
func formatSearchResults(query string, results []skills.SearchResult, cached bool) string {
|
||||
if len(results) == 0 {
|
||||
return fmt.Sprintf("No skills found for query: %q", query)
|
||||
}
|
||||
|
||||
var sb strings.Builder
|
||||
source := ""
|
||||
if cached {
|
||||
source = " (cached)"
|
||||
}
|
||||
sb.WriteString(fmt.Sprintf("Found %d skills for %q%s:\n\n", len(results), query, source))
|
||||
|
||||
for i, r := range results {
|
||||
sb.WriteString(fmt.Sprintf("%d. **%s**", i+1, r.Slug))
|
||||
if r.Version != "" {
|
||||
sb.WriteString(fmt.Sprintf(" v%s", r.Version))
|
||||
}
|
||||
sb.WriteString(fmt.Sprintf(" (score: %.3f, registry: %s)\n", r.Score, r.RegistryName))
|
||||
if r.DisplayName != "" && r.DisplayName != r.Slug {
|
||||
sb.WriteString(fmt.Sprintf(" Name: %s\n", r.DisplayName))
|
||||
}
|
||||
if r.Summary != "" {
|
||||
sb.WriteString(fmt.Sprintf(" %s\n", r.Summary))
|
||||
}
|
||||
sb.WriteString("\n")
|
||||
}
|
||||
|
||||
sb.WriteString("Use install_skill with the slug to install a skill.")
|
||||
return sb.String()
|
||||
}
|
||||
@@ -0,0 +1,82 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/skills"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestFindSkillsToolName(t *testing.T) {
|
||||
tool := NewFindSkillsTool(skills.NewRegistryManager(), nil)
|
||||
assert.Equal(t, "find_skills", tool.Name())
|
||||
}
|
||||
|
||||
func TestFindSkillsToolMissingQuery(t *testing.T) {
|
||||
tool := NewFindSkillsTool(skills.NewRegistryManager(), nil)
|
||||
result := tool.Execute(context.Background(), map[string]interface{}{})
|
||||
assert.True(t, result.IsError)
|
||||
assert.Contains(t, result.ForLLM, "query is required")
|
||||
}
|
||||
|
||||
func TestFindSkillsToolEmptyQuery(t *testing.T) {
|
||||
tool := NewFindSkillsTool(skills.NewRegistryManager(), nil)
|
||||
result := tool.Execute(context.Background(), map[string]interface{}{
|
||||
"query": " ",
|
||||
})
|
||||
assert.True(t, result.IsError)
|
||||
}
|
||||
|
||||
func TestFindSkillsToolCacheHit(t *testing.T) {
|
||||
cache := skills.NewSearchCache(10, 5*60*1000*1000*1000) // 5 min
|
||||
cache.Put("github", []skills.SearchResult{
|
||||
{Slug: "github", Score: 0.9, RegistryName: "clawhub"},
|
||||
})
|
||||
|
||||
tool := NewFindSkillsTool(skills.NewRegistryManager(), cache)
|
||||
result := tool.Execute(context.Background(), map[string]interface{}{
|
||||
"query": "github",
|
||||
})
|
||||
|
||||
assert.False(t, result.IsError)
|
||||
assert.Contains(t, result.ForLLM, "github")
|
||||
assert.Contains(t, result.ForLLM, "cached")
|
||||
}
|
||||
|
||||
func TestFindSkillsToolParameters(t *testing.T) {
|
||||
tool := NewFindSkillsTool(skills.NewRegistryManager(), nil)
|
||||
params := tool.Parameters()
|
||||
|
||||
props, ok := params["properties"].(map[string]interface{})
|
||||
assert.True(t, ok)
|
||||
assert.Contains(t, props, "query")
|
||||
assert.Contains(t, props, "limit")
|
||||
|
||||
required, ok := params["required"].([]string)
|
||||
assert.True(t, ok)
|
||||
assert.Contains(t, required, "query")
|
||||
}
|
||||
|
||||
func TestFindSkillsToolDescription(t *testing.T) {
|
||||
tool := NewFindSkillsTool(skills.NewRegistryManager(), nil)
|
||||
assert.NotEmpty(t, tool.Description())
|
||||
assert.Contains(t, tool.Description(), "skill")
|
||||
}
|
||||
|
||||
func TestFormatSearchResultsEmpty(t *testing.T) {
|
||||
result := formatSearchResults("test query", nil, false)
|
||||
assert.Contains(t, result, "No skills found")
|
||||
}
|
||||
|
||||
func TestFormatSearchResultsWithData(t *testing.T) {
|
||||
results := []skills.SearchResult{
|
||||
{Slug: "github", Score: 0.95, DisplayName: "GitHub", Summary: "GitHub API integration", Version: "1.0.0", RegistryName: "clawhub"},
|
||||
}
|
||||
output := formatSearchResults("github", results, false)
|
||||
assert.Contains(t, output, "github")
|
||||
assert.Contains(t, output, "v1.0.0")
|
||||
assert.Contains(t, output, "0.950")
|
||||
assert.Contains(t, output, "clawhub")
|
||||
assert.Contains(t, output, "install_skill")
|
||||
}
|
||||
@@ -0,0 +1,93 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/logger"
|
||||
)
|
||||
|
||||
// DownloadToFile streams an HTTP response body to a temporary file in small
|
||||
// chunks (~32KB), keeping peak memory usage constant regardless of file size.
|
||||
//
|
||||
// Parameters:
|
||||
// - ctx: context for cancellation/timeout
|
||||
// - client: HTTP client to use (caller controls timeouts, transport, etc.)
|
||||
// - req: fully prepared *http.Request (method, URL, headers, etc.)
|
||||
// - maxBytes: maximum bytes to download; 0 means no limit
|
||||
//
|
||||
// Returns the path to the temporary file. The caller is responsible for
|
||||
// removing it when done (defer os.Remove(path)).
|
||||
//
|
||||
// On any error the temp file is cleaned up automatically.
|
||||
func DownloadToFile(ctx context.Context, client *http.Client, req *http.Request, maxBytes int64) (string, error) {
|
||||
// Attach context.
|
||||
req = req.WithContext(ctx)
|
||||
|
||||
logger.DebugCF("download", "Starting download", map[string]interface{}{
|
||||
"url": req.URL.String(),
|
||||
"max_bytes": maxBytes,
|
||||
})
|
||||
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("request failed: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||
// Read a small amount for the error message.
|
||||
errBody := make([]byte, 512)
|
||||
n, _ := io.ReadFull(resp.Body, errBody)
|
||||
return "", fmt.Errorf("HTTP %d: %s", resp.StatusCode, string(errBody[:n]))
|
||||
}
|
||||
|
||||
// Create temp file.
|
||||
tmpFile, err := os.CreateTemp("", "picoclaw-dl-*")
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to create temp file: %w", err)
|
||||
}
|
||||
tmpPath := tmpFile.Name()
|
||||
|
||||
logger.DebugCF("download", "Streaming to temp file", map[string]interface{}{
|
||||
"path": tmpPath,
|
||||
})
|
||||
|
||||
// Cleanup helper — removes the temp file on any error.
|
||||
cleanup := func() {
|
||||
_ = tmpFile.Close()
|
||||
_ = os.Remove(tmpPath)
|
||||
}
|
||||
|
||||
// Optionally limit the download size.
|
||||
var src io.Reader = resp.Body
|
||||
if maxBytes > 0 {
|
||||
src = io.LimitReader(resp.Body, maxBytes+1) // +1 to detect overflow
|
||||
}
|
||||
|
||||
written, err := io.Copy(tmpFile, src)
|
||||
if err != nil {
|
||||
cleanup()
|
||||
return "", fmt.Errorf("download write failed: %w", err)
|
||||
}
|
||||
|
||||
if maxBytes > 0 && written > maxBytes {
|
||||
cleanup()
|
||||
return "", fmt.Errorf("download too large: %d bytes (max %d)", written, maxBytes)
|
||||
}
|
||||
|
||||
if err := tmpFile.Close(); err != nil {
|
||||
_ = os.Remove(tmpPath)
|
||||
return "", fmt.Errorf("failed to close temp file: %w", err)
|
||||
}
|
||||
|
||||
logger.DebugCF("download", "Download complete", map[string]interface{}{
|
||||
"path": tmpPath,
|
||||
"bytes_written": written,
|
||||
})
|
||||
|
||||
return tmpPath, nil
|
||||
}
|
||||
@@ -0,0 +1,19 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// ValidateSkillIdentifier validates that the given skill identifier (slug or
// registry name) is safe to use as a single path component.
//
// After trimming surrounding whitespace the identifier must:
//   - be non-empty,
//   - not be ".", which would resolve to the containing directory itself
//     when joined onto a base path,
//   - not contain path separators ("/", "\\") or "..".
//
// It returns nil when the identifier is acceptable and a descriptive error
// otherwise.
func ValidateSkillIdentifier(identifier string) error {
	trimmed := strings.TrimSpace(identifier)
	if trimmed == "" {
		return fmt.Errorf("identifier is required and must be a non-empty string")
	}
	if trimmed == "." {
		return fmt.Errorf("identifier must not be '.' because it refers to the containing directory")
	}
	if strings.ContainsAny(trimmed, "/\\") || strings.Contains(trimmed, "..") {
		return fmt.Errorf("identifier must not contain path separators or '..' to prevent directory traversal")
	}
	return nil
}
|
||||
@@ -14,3 +14,12 @@ func Truncate(s string, maxLen int) string {
|
||||
}
|
||||
return string(runes[:maxLen-3]) + "..."
|
||||
}
|
||||
|
||||
// DerefStr yields the string s points to, or fallback when s is nil.
func DerefStr(s *string, fallback string) string {
	if s != nil {
		return *s
	}
	return fallback
}
|
||||
|
||||
@@ -0,0 +1,120 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"archive/zip"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/logger"
|
||||
)
|
||||
|
||||
// ExtractZipFile extracts a ZIP archive from disk to targetDir.
|
||||
// It reads entries one at a time from disk, keeping memory usage minimal.
|
||||
//
|
||||
// Security: rejects path traversal attempts and symlinks.
|
||||
func ExtractZipFile(zipPath string, targetDir string) error {
|
||||
reader, err := zip.OpenReader(zipPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid ZIP: %w", err)
|
||||
}
|
||||
defer reader.Close()
|
||||
|
||||
logger.DebugCF("zip", "Extracting ZIP", map[string]interface{}{
|
||||
"zip_path": zipPath,
|
||||
"target_dir": targetDir,
|
||||
"entries": len(reader.File),
|
||||
})
|
||||
|
||||
if err := os.MkdirAll(targetDir, 0755); err != nil {
|
||||
return fmt.Errorf("failed to create target dir: %w", err)
|
||||
}
|
||||
|
||||
for _, f := range reader.File {
|
||||
// Path traversal protection.
|
||||
cleanName := filepath.Clean(f.Name)
|
||||
if strings.HasPrefix(cleanName, "..") || filepath.IsAbs(cleanName) {
|
||||
return fmt.Errorf("zip entry has unsafe path: %q", f.Name)
|
||||
}
|
||||
|
||||
destPath := filepath.Join(targetDir, cleanName)
|
||||
|
||||
// Double-check the resolved path is within target directory (defense-in-depth).
|
||||
targetDirClean := filepath.Clean(targetDir)
|
||||
if !strings.HasPrefix(filepath.Clean(destPath), targetDirClean+string(filepath.Separator)) && filepath.Clean(destPath) != targetDirClean {
|
||||
return fmt.Errorf("zip entry escapes target dir: %q", f.Name)
|
||||
}
|
||||
|
||||
mode := f.FileInfo().Mode()
|
||||
|
||||
// Reject any symlink.
|
||||
if mode&os.ModeSymlink != 0 {
|
||||
return fmt.Errorf("zip contains symlink %q; symlinks are not allowed", f.Name)
|
||||
}
|
||||
|
||||
if f.FileInfo().IsDir() {
|
||||
if err := os.MkdirAll(destPath, 0755); err != nil {
|
||||
return err
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
// Ensure parent directory exists.
|
||||
if err := os.MkdirAll(filepath.Dir(destPath), 0755); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := extractSingleFile(f, destPath); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// extractSingleFile extracts one zip.File entry to destPath, with a size check.
//
// Entries larger than 5MB are rejected, both by the size declared in the ZIP
// header and by counting the bytes actually decompressed (headers can lie).
// On any failure — including a failed Close of the output file, which can
// mean buffered data never reached disk — the partially written file is
// removed and an error is returned.
func extractSingleFile(f *zip.File, destPath string) (err error) {
	const maxFileSize = 5 * 1024 * 1024 // 5MB, adjust as appropriate

	// Check the uncompressed size from the header, if available.
	if f.UncompressedSize64 > maxFileSize {
		return fmt.Errorf("zip entry %q is too large (%d bytes)", f.Name, f.UncompressedSize64)
	}

	rc, err := f.Open()
	if err != nil {
		return fmt.Errorf("failed to open zip entry %q: %w", f.Name, err)
	}
	defer rc.Close()

	outFile, err := os.Create(destPath)
	if err != nil {
		return fmt.Errorf("failed to create file %q: %w", destPath, err)
	}
	// Close first, then remove on failure: removing a still-open file does
	// not work on every platform, and a Close error must fail the call
	// instead of being swallowed.
	defer func() {
		if cerr := outFile.Close(); cerr != nil && err == nil {
			err = fmt.Errorf("failed to close file %q: %w", destPath, cerr)
		}
		if err != nil {
			_ = os.Remove(destPath)
		}
	}()

	// Streamed size check: prevent overruns and malicious/corrupt headers.
	// CopyN reports io.EOF when the entry is shorter than the cap, which is
	// the normal case.
	written, copyErr := io.CopyN(outFile, rc, maxFileSize+1)
	if copyErr != nil && copyErr != io.EOF {
		return fmt.Errorf("failed to extract %q: %w", f.Name, copyErr)
	}
	if written > maxFileSize {
		return fmt.Errorf("zip entry %q exceeds max size (%d bytes)", f.Name, written)
	}

	return nil
}
|
||||
Reference in New Issue
Block a user