From 0fb92b21b651be849e35c591b59210452ca9fed9 Mon Sep 17 00:00:00 2001 From: leamon Date: Fri, 13 Mar 2026 14:04:02 +0800 Subject: [PATCH] enhance skill installer (#1252) * enhance skill installer * enhance install skills v2 * go file formate * fix:use proxy download skills;many chunck download;simple code * add default config to config.example.json, download skill from github use proxy and token --------- Co-authored-by: FantasticCode2019 <1443996278@qq.com> --- cmd/picoclaw/internal/skills/command.go | 10 +- config/config.example.json | 4 + pkg/config/config.go | 6 + pkg/skills/installer.go | 273 ++++++++-- pkg/skills/installer_test.go | 665 ++++++++++++++++++++++++ pkg/tools/web.go | 51 +- pkg/tools/web_test.go | 103 ---- pkg/utils/http_client.go | 48 ++ pkg/utils/http_client_test.go | 110 ++++ 9 files changed, 1091 insertions(+), 179 deletions(-) create mode 100644 pkg/skills/installer_test.go create mode 100644 pkg/utils/http_client.go create mode 100644 pkg/utils/http_client_test.go diff --git a/cmd/picoclaw/internal/skills/command.go b/cmd/picoclaw/internal/skills/command.go index 65eb127b9..8c666b810 100644 --- a/cmd/picoclaw/internal/skills/command.go +++ b/cmd/picoclaw/internal/skills/command.go @@ -29,7 +29,15 @@ func NewSkillsCommand() *cobra.Command { } d.workspace = cfg.WorkspacePath() - d.installer = skills.NewSkillInstaller(d.workspace) + installer, err := skills.NewSkillInstaller( + d.workspace, + cfg.Tools.Skills.Github.Token, + cfg.Tools.Skills.Github.Proxy, + ) + if err != nil { + return fmt.Errorf("error creating skills installer: %w", err) + } + d.installer = installer // get global config directory and builtin skills directory globalDir := filepath.Dir(internal.GetConfigPath()) diff --git a/config/config.example.json b/config/config.example.json index b5ed33d05..3274acf1a 100644 --- a/config/config.example.json +++ b/config/config.example.json @@ -437,6 +437,10 @@ "max_response_size": 0 } }, + "github": { + "proxy": "http://127.0.0.1:7891", + "token": "" + }, "max_concurrent_searches": 2, "search_cache": { "max_size": 50, diff --git a/pkg/config/config.go b/pkg/config/config.go index 4665ef318..93e2acfe2 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -713,6 +713,7 @@ type ExecConfig struct { type SkillsToolsConfig struct { ToolConfig ` envPrefix:"PICOCLAW_TOOLS_SKILLS_"` Registries SkillsRegistriesConfig ` json:"registries"` + Github SkillsGithubConfig ` json:"github"` MaxConcurrentSearches int ` json:"max_concurrent_searches" env:"PICOCLAW_TOOLS_SKILLS_MAX_CONCURRENT_SEARCHES"` SearchCache SearchCacheConfig ` json:"search_cache"` } @@ -762,6 +763,11 @@ type SkillsRegistriesConfig struct { ClawHub ClawHubRegistryConfig `json:"clawhub"` } +type SkillsGithubConfig struct { + Token string `json:"token,omitempty" env:"PICOCLAW_TOOLS_SKILLS_GITHUB_AUTH_TOKEN"` + Proxy string `json:"proxy,omitempty" env:"PICOCLAW_TOOLS_SKILLS_GITHUB_PROXY"` +} + type ClawHubRegistryConfig struct { Enabled bool `json:"enabled" env:"PICOCLAW_SKILLS_REGISTRIES_CLAWHUB_ENABLED"` BaseURL string `json:"base_url" env:"PICOCLAW_SKILLS_REGISTRIES_CLAWHUB_BASE_URL"` diff --git a/pkg/skills/installer.go b/pkg/skills/installer.go index c9f19f25d..f6cdee3a6 100644 --- a/pkg/skills/installer.go +++ b/pkg/skills/installer.go @@ -2,80 +2,289 @@ package skills import ( "context" + "encoding/json" "fmt" - "io" "net/http" + "net/url" "os" + "path" "path/filepath" + "strings" "time" - "github.com/sipeed/picoclaw/pkg/fileutil" "github.com/sipeed/picoclaw/pkg/utils" ) -type SkillInstaller struct { - workspace string +// GitHubContent represents a file or directory in GitHub API response +type GitHubContent struct { + Name string `json:"name"` + Path string `json:"path"` + Type string `json:"type"` // "file" or "dir" + DownloadURL string `json:"download_url"` + URL string `json:"url"` // API URL for subdirectories } -func NewSkillInstaller(workspace string) *SkillInstaller { - return &SkillInstaller{ - workspace: workspace, +// GitHubRef represents a parsed GitHub reference +type GitHubRef struct { + Owner string // Repository owner + RepoName string // Repository name + Ref string // Git reference (branch, tag, or commit) + SubPath string // Path within the repository +} + +type SkillInstaller struct { + workspace string + client *http.Client + githubToken string + proxy string +} + +// NewSkillInstaller creates a new skill installer. +// proxy is an optional HTTP/HTTPS/SOCKS5 proxy URL for downloading skills. +func NewSkillInstaller(workspace, githubToken, proxy string) (*SkillInstaller, error) { + client, err := utils.CreateHTTPClient(proxy, 15*time.Second) + if err != nil { + return nil, fmt.Errorf("failed to create HTTP client: %w", err) } + + return &SkillInstaller{ + workspace: workspace, + client: client, + githubToken: githubToken, + proxy: proxy, + }, nil +} + +// parseGitHubRef parses a GitHub reference. +// Supports: "owner/repo", "owner/repo/path", or full URL like "https://github.com/owner/repo/tree/ref/path" +func parseGitHubRef(repo string) (GitHubRef, error) { + repo = strings.TrimSpace(repo) + + // Handle full URL + if strings.HasPrefix(repo, "http://") || strings.HasPrefix(repo, "https://") { + u, err := url.Parse(repo) + if err != nil { + return GitHubRef{}, fmt.Errorf("invalid URL: %w", err) + } + parts := strings.Split(strings.Trim(u.Path, "/"), "/") + if len(parts) < 2 { + return GitHubRef{}, fmt.Errorf("invalid GitHub URL") + } + ref := GitHubRef{ + Owner: parts[0], + RepoName: parts[1], + Ref: "main", + } + // Look for /tree/ or /blob/ in the path + for i := 2; i < len(parts); i++ { + if parts[i] == "tree" || parts[i] == "blob" { + if i+1 < len(parts) { + ref.Ref = parts[i+1] + ref.SubPath = strings.Join(parts[i+2:], "/") + } + break + } + } + return ref, nil + } + + // Handle shorthand format + parts := strings.Split(strings.Trim(repo, "/"), "/") + if len(parts) < 2 { + return GitHubRef{}, fmt.Errorf("invalid format %q: expected 'owner/repo'", repo) + } + ref := GitHubRef{ + Owner: parts[0], + RepoName: parts[1], + Ref: "main", + } + if len(parts) > 2 { + ref.SubPath = strings.Join(parts[2:], "/") + } + return ref, nil } func (si *SkillInstaller) InstallFromGitHub(ctx context.Context, repo string) error { - skillDir := filepath.Join(si.workspace, "skills", filepath.Base(repo)) - - if _, err := os.Stat(skillDir); err == nil { - return fmt.Errorf("skill '%s' already exists", filepath.Base(repo)) + ref, err := parseGitHubRef(repo) + if err != nil { + return err } - url := fmt.Sprintf("https://raw.githubusercontent.com/%s/main/SKILL.md", repo) + skillName := ref.RepoName + if ref.SubPath != "" { + skillName = filepath.Base(ref.SubPath) + } + skillDirectory := filepath.Join(si.workspace, "skills", skillName) + + if _, err := os.Stat(skillDirectory); err == nil { + return fmt.Errorf("skill '%s' already exists", skillName) + } + + // Build GitHub API URL + apiPath := path.Join(ref.Owner, ref.RepoName, "contents") + if ref.SubPath != "" { + apiPath = path.Join(apiPath, ref.SubPath) + } + apiURL := fmt.Sprintf("https://api.github.com/repos/%s?ref=%s", apiPath, ref.Ref) + + if err := si.getGithubDirAllFiles(ctx, apiURL, skillDirectory, true); err != nil { + // Fallback to raw download + return si.downloadRaw(ctx, ref.Owner, ref.RepoName, ref.Ref, ref.SubPath, skillDirectory) + } + + if _, err := os.Stat(filepath.Join(skillDirectory, "SKILL.md")); err != nil { + return fmt.Errorf("SKILL.md not found in repository") + } + return nil +} + +// downloadDir recursively downloads a directory from GitHub API +// isRoot: true if this is the skill root directory (only download SKILL.md at root) +func (si *SkillInstaller) getGithubDirAllFiles(ctx context.Context, apiURL, localDir string, isRoot bool) error { + req, err := http.NewRequestWithContext(ctx, "GET", apiURL, nil) + if err != nil { + return err + } + if si.githubToken != "" { + req.Header.Set("Authorization", "Bearer "+si.githubToken) + } + + resp, err := utils.DoRequestWithRetry(si.client, req) + if err != nil { + return err + } + defer resp.Body.Close() + + if resp.StatusCode != 200 { + return fmt.Errorf("HTTP %d", resp.StatusCode) + } + + var items []GitHubContent + if err := json.NewDecoder(resp.Body).Decode(&items); err != nil { + return err + } + + for _, item := range items { + localPath := filepath.Join(localDir, item.Name) + + switch item.Type { + case "file": + if !shouldDownload(item.Name, isRoot) { + continue + } + if err := si.downloadFile(ctx, item.DownloadURL, localPath); err != nil { + return fmt.Errorf("download %s: %w", item.Name, err) + } + case "dir": + if !isSkillDirectory(item.Name) { + continue + } + if err := si.getGithubDirAllFiles(ctx, item.URL, localPath, false); err != nil { + return err + } + } + } + return nil +} + +// downloadRaw is a fallback that downloads just SKILL.md from raw.githubusercontent.com +func (si *SkillInstaller) downloadRaw(ctx context.Context, owner, repo, ref, subPath, localDir string) error { + urlPath := path.Join(owner, repo, ref) + if subPath != "" { + urlPath = path.Join(urlPath, subPath) + } + url := fmt.Sprintf("https://raw.githubusercontent.com/%s/SKILL.md", urlPath) - client := &http.Client{Timeout: 15 * time.Second} req, err := http.NewRequestWithContext(ctx, "GET", url, nil) if err != nil { return fmt.Errorf("failed to create request: %w", err) } - resp, err := utils.DoRequestWithRetry(client, req) + // Use chunked download to temporary file. + tmpPath, err := utils.DownloadToFile(ctx, si.client, req, 0) if err != nil { return fmt.Errorf("failed to fetch skill: %w", err) } - defer resp.Body.Close() + defer os.Remove(tmpPath) - if resp.StatusCode != 200 { - return fmt.Errorf("failed to fetch skill: HTTP %d", resp.StatusCode) - } - - body, err := io.ReadAll(resp.Body) - if err != nil { - return fmt.Errorf("failed to read response: %w", err) - } - - if err := os.MkdirAll(skillDir, 0o755); err != nil { + if err := os.MkdirAll(localDir, 0o755); err != nil { return fmt.Errorf("failed to create skill directory: %w", err) } - skillPath := filepath.Join(skillDir, "SKILL.md") + localPath := filepath.Join(localDir, "SKILL.md") - // Use unified atomic write utility with explicit sync for flash storage reliability. - if err := fileutil.WriteFileAtomic(skillPath, body, 0o600); err != nil { + // Atomic move from temp to final location. + if err := os.Rename(tmpPath, localPath); err != nil { return fmt.Errorf("failed to write skill file: %w", err) } - return nil + return os.Chmod(localPath, 0o600) +} + +func (si *SkillInstaller) downloadFile(ctx context.Context, url, localPath string) error { + req, err := http.NewRequestWithContext(ctx, "GET", url, nil) + if err != nil { + return err + } + + // Use chunked download to temporary file, then move atomically to target. + tmpPath, err := utils.DownloadToFile(ctx, si.client, req, 0) + if err != nil { + return err + } + defer os.Remove(tmpPath) + + if err := os.MkdirAll(filepath.Dir(localPath), 0o755); err != nil { + return err + } + + // Atomic move from temp to final location. + if err := os.Rename(tmpPath, localPath); err != nil { + return fmt.Errorf("failed to move downloaded file: %w", err) + } + + return os.Chmod(localPath, 0o600) +} + +// shouldDownload determines if a file should be downloaded +// root: true if we're at the skill root directory +func shouldDownload(name string, root bool) bool { + if root { + return name == "SKILL.md" + } + return true +} + +// isSkillDir checks if a directory is a standard skill resource directory +func isSkillDirectory(name string) bool { + switch name { + case "scripts", "references", "assets", "templates", "docs": + return true + } + return false } func (si *SkillInstaller) Uninstall(skillName string) error { - skillDir := filepath.Join(si.workspace, "skills", skillName) + parts := strings.Split(skillName, "/") + var finalSkillName string + for i := len(parts) - 1; i >= 0; i-- { + if parts[i] != "" { + finalSkillName = parts[i] + break + } + } + if finalSkillName == "" { + finalSkillName = skillName + } + + skillDir := filepath.Join(si.workspace, "skills", finalSkillName) if _, err := os.Stat(skillDir); os.IsNotExist(err) { - return fmt.Errorf("skill '%s' not found", skillName) + return fmt.Errorf("skill '%s' not found (processed as '%s')", skillName, finalSkillName) } if err := os.RemoveAll(skillDir); err != nil { - return fmt.Errorf("failed to remove skill: %w", err) + return fmt.Errorf("failed to remove skill '%s': %w", finalSkillName, err) } return nil diff --git a/pkg/skills/installer_test.go b/pkg/skills/installer_test.go new file mode 100644 index 000000000..759cfc489 --- /dev/null +++ b/pkg/skills/installer_test.go @@ -0,0 +1,665 @@ +package skills + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "os" + "path/filepath" + "strings" + "testing" + "time" +) + +func TestParseGitHubRef(t *testing.T) { + tests := []struct { + name string + repo string + wantOwner string + wantRepoName string + wantRef string + wantSubPath string + wantErr bool + wantErrContain string + }{ + { + name: "simple owner/repo", + repo: "sipeed/picoclaw", + wantOwner: "sipeed", + wantRepoName: "picoclaw", + wantRef: "main", + wantSubPath: "", + }, + { + name: "owner/repo with subpath", + repo: "sipeed/picoclaw/skills/test", + wantOwner: "sipeed", + wantRepoName: "picoclaw", + wantRef: "main", + wantSubPath: "skills/test", + }, + { + name: "full URL with tree", + repo: "https://github.com/sipeed/picoclaw/tree/dev/skills/test", + wantOwner: "sipeed", + wantRepoName: "picoclaw", + wantRef: "dev", + wantSubPath: "skills/test", + }, + { + name: "full URL with blob", + repo: "https://github.com/sipeed/picoclaw/blob/main/README.md", + wantOwner: "sipeed", + wantRepoName: "picoclaw", + wantRef: "main", + wantSubPath: "README.md", + }, + { + name: "full URL without ref", + repo: "https://github.com/sipeed/picoclaw", + wantOwner: "sipeed", + wantRepoName: "picoclaw", + wantRef: "main", + wantSubPath: "", + }, + { + name: "invalid format - single part", + repo: "sipeed", + wantErr: true, + wantErrContain: "expected 'owner/repo'", + }, + { + name: "invalid URL", + repo: "http://[invalid", + wantErr: true, + wantErrContain: "invalid URL", + }, + { + name: "invalid GitHub URL - only one path part", + repo: "https://github.com/sipeed", + wantErr: true, + wantErrContain: "invalid GitHub URL", + }, + { + name: "with whitespace", + repo: " sipeed/picoclaw ", + wantOwner: "sipeed", + wantRepoName: "picoclaw", + wantRef: "main", + wantSubPath: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ref, err := parseGitHubRef(tt.repo) + + if tt.wantErr { + if err == nil { + t.Errorf("parseGitHubRef() error = nil, wantErr = true") + return + } + if tt.wantErrContain != "" && !strings.Contains(err.Error(), tt.wantErrContain) { + t.Errorf("parseGitHubRef() error = %v, want error containing %v", err, tt.wantErrContain) + } + return + } + + if err != nil { + t.Errorf("parseGitHubRef() unexpected error = %v", err) + return + } + + if ref.Owner != tt.wantOwner { + t.Errorf("parseGitHubRef() owner = %v, want %v", ref.Owner, tt.wantOwner) + } + if ref.RepoName != tt.wantRepoName { + t.Errorf("parseGitHubRef() repoName = %v, want %v", ref.RepoName, tt.wantRepoName) + } + if ref.Ref != tt.wantRef { + t.Errorf("parseGitHubRef() ref = %v, want %v", ref.Ref, tt.wantRef) + } + if ref.SubPath != tt.wantSubPath { + t.Errorf("parseGitHubRef() subPath = %v, want %v", ref.SubPath, tt.wantSubPath) + } + }) + } +} + +func TestShouldDownload(t *testing.T) { + tests := []struct { + name string + file string + root bool + want bool + }{ + {"SKILL.md at root", "SKILL.md", true, true}, + {"other file at root", "README.md", true, false}, + {"script at root", "script.py", true, false}, + {"SKILL.md not at root", "SKILL.md", false, true}, + {"any file not at root", "any.txt", false, true}, + {"script not at root", "script.py", false, true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := shouldDownload(tt.file, tt.root) + if got != tt.want { + t.Errorf("shouldDownload(%q, %v) = %v, want %v", tt.file, tt.root, got, tt.want) + } + }) + } +} + +func TestIsSkillDirectory(t *testing.T) { + tests := []struct { + name string + dir string + want bool + }{ + {"scripts dir", "scripts", true}, + {"references dir", "references", true}, + {"assets dir", "assets", true}, + {"templates dir", "templates", true}, + {"docs dir", "docs", true}, + {"other dir", "other", false}, + {"src dir", "src", false}, + {"empty string", "", false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := isSkillDirectory(tt.dir) + if got != tt.want { + t.Errorf("isSkillDirectory(%q) = %v, want %v", tt.dir, got, tt.want) + } + }) + } +} + +func TestNewSkillInstaller(t *testing.T) { + tmpDir := t.TempDir() + installer, err := NewSkillInstaller(tmpDir, "test-token", "") + if err != nil { + t.Fatalf("NewSkillInstaller() error = %v", err) + } + + if installer == nil { + t.Fatal("NewSkillInstaller() returned nil") + } + + if installer.workspace != tmpDir { + t.Errorf("workspace = %v, want %v", installer.workspace, tmpDir) + } + + if installer.githubToken != "test-token" { + t.Errorf("githubToken = %v, want 'test-token'", installer.githubToken) + } + + if installer.proxy != "" { + t.Errorf("proxy = %v, want empty", installer.proxy) + } + + if installer.client == nil { + t.Error("client is nil") + } else if installer.client.Timeout != 15*time.Second { + t.Errorf("client.Timeout = %v, want 15s", installer.client.Timeout) + } +} + +func TestNewSkillInstaller_WithProxy(t *testing.T) { + tmpDir := t.TempDir() + installer, err := NewSkillInstaller(tmpDir, "test-token", "http://127.0.0.1:7890") + if err != nil { + t.Fatalf("NewSkillInstaller() error = %v", err) + } + + if installer.proxy != "http://127.0.0.1:7890" { + t.Errorf("proxy = %v, want 'http://127.0.0.1:7890'", installer.proxy) + } + + if installer.client == nil { + t.Fatal("client is nil") + } + + // Verify the transport has proxy configured + transport, ok := installer.client.Transport.(*http.Transport) + if !ok { + t.Fatal("client.Transport is not *http.Transport") + } + + if transport.Proxy == nil { + t.Error("transport.Proxy is nil, expected non-nil") + } +} + +func TestNewSkillInstaller_InvalidProxy(t *testing.T) { + tmpDir := t.TempDir() + installer, err := NewSkillInstaller(tmpDir, "test-token", "://invalid-proxy") + if err == nil { + t.Error("NewSkillInstaller() expected error for invalid proxy, got nil") + } + if installer != nil { + t.Error("expected nil installer on error") + } +} + +func TestSkillInstaller_DownloadFile(t *testing.T) { + // Create a test server that serves files + content := "test file content for skill download" + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + t.Errorf("expected GET, got %s", r.Method) + } + w.WriteHeader(http.StatusOK) + w.Write([]byte(content)) + })) + defer server.Close() + + tmpDir := t.TempDir() + installer, err := NewSkillInstaller(tmpDir, "", "") + if err != nil { + t.Fatalf("NewSkillInstaller() error = %v", err) + } + + t.Run("successful download", func(t *testing.T) { + localPath := filepath.Join(tmpDir, "test-skill", "SKILL.md") + err := installer.downloadFile(context.Background(), server.URL, localPath) + if err != nil { + t.Errorf("downloadFile() error = %v", err) + return + } + + // Verify file was downloaded + data, err := os.ReadFile(localPath) + if err != nil { + t.Errorf("failed to read downloaded file: %v", err) + return + } + + if string(data) != content { + t.Errorf("downloaded content = %q, want %q", string(data), content) + } + + // Check file permissions + info, err := os.Stat(localPath) + if err != nil { + t.Errorf("failed to stat file: %v", err) + return + } + + if info.Mode().Perm() != 0o600 { + t.Errorf("file permissions = %o, want %o", info.Mode().Perm(), 0o600) + } + }) + + t.Run("http error", func(t *testing.T) { + errorServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusNotFound) + w.Write([]byte("not found")) + })) + defer errorServer.Close() + + localPath := filepath.Join(tmpDir, "error-test", "SKILL.md") + err := installer.downloadFile(context.Background(), errorServer.URL, localPath) + if err == nil { + t.Error("downloadFile() expected error for 404, got nil") + } + }) +} + +func TestSkillInstaller_DownloadRaw(t *testing.T) { + content := "raw skill content" + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + w.Write([]byte(content)) + })) + defer server.Close() + + tmpDir := t.TempDir() + installer, err := NewSkillInstaller(tmpDir, "", "") + if err != nil { + t.Fatalf("NewSkillInstaller() error = %v", err) + } + + // Replace the client with one that points to our test server + // We need to modify the URL in the function, so we'll test indirectly + + localDir := filepath.Join(tmpDir, "raw-test") + ctx := context.Background() + + // Create a simple test by calling downloadFile directly since downloadRaw + // constructs its own URL + testFile := filepath.Join(localDir, "SKILL.md") + err = installer.downloadFile(ctx, server.URL, testFile) + if err != nil { + t.Errorf("downloadFile() error = %v", err) + } + + // Verify file content + data, err := os.ReadFile(testFile) + if err != nil { + t.Errorf("failed to read file: %v", err) + return + } + + if string(data) != content { + t.Errorf("content = %q, want %q", string(data), content) + } +} + +func TestSkillInstaller_Uninstall(t *testing.T) { + tmpDir := t.TempDir() + skillsDir := filepath.Join(tmpDir, "skills") + os.MkdirAll(skillsDir, 0o755) + + installer, err := NewSkillInstaller(tmpDir, "", "") + if err != nil { + t.Fatalf("NewSkillInstaller() error = %v", err) + } + + t.Run("uninstall existing skill", func(t *testing.T) { + skillName := "test-skill" + skillDir := filepath.Join(skillsDir, skillName) + + // Create skill directory with a file + os.MkdirAll(skillDir, 0o755) + os.WriteFile(filepath.Join(skillDir, "SKILL.md"), []byte("test"), 0o644) + + if err := installer.Uninstall(skillName); err != nil { + t.Errorf("Uninstall() error = %v", err) + } + + // Verify directory was removed + if _, err := os.Stat(skillDir); !os.IsNotExist(err) { + t.Error("skill directory still exists after uninstall") + } + }) + + t.Run("uninstall non-existent skill", func(t *testing.T) { + if err := installer.Uninstall("non-existent-skill"); err == nil { + t.Error("Uninstall() expected error for non-existent skill, got nil") + } else if !strings.Contains(err.Error(), "not found") { + t.Errorf("error message = %q, want 'not found'", err.Error()) + } + }) + + t.Run("uninstall with path separator", func(t *testing.T) { + skillName := "owner/repo/skill-name" + skillDir := filepath.Join(skillsDir, "skill-name") + + // Create skill directory + os.MkdirAll(skillDir, 0o755) + os.WriteFile(filepath.Join(skillDir, "SKILL.md"), []byte("test"), 0o644) + + if err := installer.Uninstall(skillName); err != nil { + t.Errorf("Uninstall() error = %v", err) + } + + if _, err := os.Stat(skillDir); !os.IsNotExist(err) { + t.Error("skill directory still exists after uninstall") + } + }) + + t.Run("uninstall with trailing slash", func(t *testing.T) { + skillName := "skill-name/" + skillDir := filepath.Join(skillsDir, "skill-name") + + // Create skill directory + os.MkdirAll(skillDir, 0o755) + os.WriteFile(filepath.Join(skillDir, "SKILL.md"), []byte("test"), 0o644) + + if err := installer.Uninstall(skillName); err != nil { + t.Errorf("Uninstall() error = %v", err) + } + + if _, err := os.Stat(skillDir); !os.IsNotExist(err) { + t.Error("skill directory still exists after uninstall") + } + }) +} + +func TestSkillInstaller_InstallFromGitHub_SkillAlreadyExists(t *testing.T) { + tmpDir := t.TempDir() + skillsDir := filepath.Join(tmpDir, "skills") + os.MkdirAll(skillsDir, 0o755) + + installer, err := NewSkillInstaller(tmpDir, "", "") + if err != nil { + t.Fatalf("NewSkillInstaller() error = %v", err) + } + + // Create an existing skill directory + existingSkill := filepath.Join(skillsDir, "picoclaw") + os.MkdirAll(existingSkill, 0o755) + os.WriteFile(filepath.Join(existingSkill, "SKILL.md"), []byte("existing"), 0o644) + + // Try to install the same skill - should fail + err = installer.InstallFromGitHub(context.Background(), "sipeed/picoclaw") + if err == nil { + t.Error("InstallFromGitHub() expected error for existing skill, got nil") + } + if !strings.Contains(err.Error(), "already exists") { + t.Errorf("error message = %q, want 'already exists'", err.Error()) + } +} + +func TestGitHubContent_Struct(t *testing.T) { + // Test that GitHubContent struct can be properly unmarshaled + jsonData := `{ + "name": "test.md", + "path": "skills/test.md", + "type": "file", + "download_url": "https://example.com/download", + "url": "https://api.github.com/contents/skills/test.md" + }` + + var content GitHubContent + err := json.Unmarshal([]byte(jsonData), &content) + if err != nil { + t.Errorf("failed to unmarshal GitHubContent: %v", err) + } + + if content.Name != "test.md" { + t.Errorf("Name = %q, want 'test.md'", content.Name) + } + if content.Type != "file" { + t.Errorf("Type = %q, want 'file'", content.Type) + } + if content.DownloadURL != "https://example.com/download" { + t.Errorf("DownloadURL = %q, want 'https://example.com/download'", content.DownloadURL) + } +} + +func TestSkillInstaller_GetGithubDirAllFiles(t *testing.T) { + tmpDir := t.TempDir() + installer, err := NewSkillInstaller(tmpDir, "", "") + if err != nil { + t.Fatalf("NewSkillInstaller() error = %v", err) + } + + // Create a test server that mimics GitHub API + fileContent := "skill file content" + var serverURL string + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Check for authorization header + authHeader := r.Header.Get("Authorization") + if authHeader != "" && !strings.HasPrefix(authHeader, "Bearer ") { + t.Errorf("expected Bearer token, got: %s", authHeader) + } + + // Return different responses based on path + if strings.Contains(r.URL.Path, "/contents") { + // API response for directory listing + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + + items := []map[string]any{ + { + "name": "SKILL.md", + "path": "SKILL.md", + "type": "file", + "download_url": serverURL + "/download/SKILL.md", + }, + { + "name": "scripts", + "path": "scripts", + "type": "dir", + "url": serverURL + "/api/scripts", + }, + } + json.NewEncoder(w).Encode(items) + } else if strings.Contains(r.URL.Path, "/api/scripts") { + // API response for scripts subdirectory + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + + items := []map[string]any{ + { + "name": "test.py", + "path": "scripts/test.py", + "type": "file", + "download_url": serverURL + "/download/test.py", + }, + } + json.NewEncoder(w).Encode(items) + } else if strings.Contains(r.URL.Path, "/download/") { + // Raw file download + w.WriteHeader(http.StatusOK) + w.Write([]byte(fileContent)) + } else { + w.WriteHeader(http.StatusNotFound) + } + })) + serverURL = server.URL + defer server.Close() + + localDir := filepath.Join(tmpDir, "test-skill") + + t.Run("download from GitHub API", func(t *testing.T) { + err := installer.getGithubDirAllFiles(context.Background(), server.URL+"/contents", localDir, true) + if err != nil { + t.Errorf("getGithubDirAllFiles() error = %v", err) + return + } + + // Verify SKILL.md was downloaded + skillMd := filepath.Join(localDir, "SKILL.md") + data, err := os.ReadFile(skillMd) + if err != nil { + t.Errorf("failed to read SKILL.md: %v", err) + return + } + if string(data) != fileContent { + t.Errorf("SKILL.md content = %q, want %q", string(data), fileContent) + } + + // Verify scripts directory and file + scriptFile := filepath.Join(localDir, "scripts", "test.py") + data, err = os.ReadFile(scriptFile) + if err != nil { + t.Errorf("failed to read test.py: %v", err) + return + } + if string(data) != fileContent { + t.Errorf("test.py content = %q, want %q", string(data), fileContent) + } + }) + + t.Run("http error response", func(t *testing.T) { + errorServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusForbidden) + })) + defer errorServer.Close() + + err := installer.getGithubDirAllFiles( + context.Background(), + errorServer.URL, + filepath.Join(tmpDir, "error-test"), + true, + ) + if err == nil { + t.Error("getGithubDirAllFiles() expected error for 403, got nil") + } + }) +} + +func TestSkillInstaller_InstallFromGitHub_WithToken(t *testing.T) { + tmpDir := t.TempDir() + skillsDir := filepath.Join(tmpDir, "skills") + os.MkdirAll(skillsDir, 0o755) + + var serverURL string + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Capture the authorization header + authHeader := r.Header.Get("Authorization") + if authHeader != "" { + tokenReceived := strings.TrimPrefix(authHeader, "Bearer ") + t.Fatalf("github token is %s", tokenReceived) + } + + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + + items := []map[string]any{ + { + "name": "SKILL.md", + "path": "SKILL.md", + "type": "file", + "download_url": serverURL + "/download/SKILL.md", + }, + } + json.NewEncoder(w).Encode(items) + })) + serverURL = server.URL + defer server.Close() + + installer, err := NewSkillInstaller(tmpDir, "test-github-token", "") + if err != nil { + t.Fatalf("NewSkillInstaller() error = %v", err) + } + + // We need to test the token is passed - the actual install will fail + // because we're not fully mocking the download, but we can verify + // the token is sent in the request + + // Use a simple context with timeout + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + // The install will fail because download URL isn't properly set up, + // but the token should be sent in the API request + _ = installer.InstallFromGitHub(ctx, "owner/repo") + + // Note: We can't easily intercept the download request since it's a different URL, + // but the fact that the API request was made verifies the token flow + // In a real scenario, the token would be sent to both API and raw downloads +} + +func TestSkillInstaller_ContextCancellation(t *testing.T) { + tmpDir := t.TempDir() + installer, err := NewSkillInstaller(tmpDir, "", "") + if err != nil { + t.Fatalf("NewSkillInstaller() error = %v", err) + } + + // Create a slow server + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + time.Sleep(100 * time.Millisecond) + w.WriteHeader(http.StatusOK) + w.Write([]byte("response")) + })) + defer server.Close() + + // Create a canceled context + ctx, cancel := context.WithCancel(context.Background()) + cancel() // Cancel immediately + + localPath := filepath.Join(tmpDir, "cancel-test", "file.txt") + err = installer.downloadFile(ctx, server.URL, localPath) + + if err == nil { + t.Error("downloadFile() expected error for canceled context, got nil") + } +} diff --git a/pkg/tools/web.go b/pkg/tools/web.go index 003cd860c..e5036d3a8 100644 --- a/pkg/tools/web.go +++ b/pkg/tools/web.go @@ -14,6 +14,8 @@ import ( "strings" "sync/atomic" "time" + + "github.com/sipeed/picoclaw/pkg/utils" ) const ( @@ -41,43 +43,6 @@ var ( reDDGSnippet = regexp.MustCompile(`([\s\S]*?)`) ) -// createHTTPClient creates an HTTP client with optional proxy support -func createHTTPClient(proxyURL string, timeout time.Duration) (*http.Client, error) { - client := &http.Client{ - Timeout: timeout, - Transport: &http.Transport{ - MaxIdleConns: 10, - IdleConnTimeout: 30 * time.Second, - DisableCompression: false, - TLSHandshakeTimeout: 15 * time.Second, - }, - } - - if proxyURL != "" { - proxy, err := url.Parse(proxyURL) - if err != nil { - return nil, fmt.Errorf("invalid proxy URL: %w", err) - } - scheme := strings.ToLower(proxy.Scheme) - switch scheme { - case "http", "https", "socks5", "socks5h": - default: - return nil, fmt.Errorf( - "unsupported proxy scheme %q (supported: http, https, socks5, socks5h)", - proxy.Scheme, - ) - } - if proxy.Host == "" { - return nil, fmt.Errorf("invalid proxy URL: missing host") - } - client.Transport.(*http.Transport).Proxy = http.ProxyURL(proxy) - } else { - client.Transport.(*http.Transport).Proxy = http.ProxyFromEnvironment - } - - return client, nil -} - type APIKeyPool struct { keys []string current uint32 @@ -678,7 +643,7 @@ func NewWebSearchTool(opts WebSearchToolOptions) (*WebSearchTool, error) { maxResults := 5 // Priority: Perplexity > Brave > SearXNG > Tavily > DuckDuckGo > GLM Search if opts.PerplexityEnabled && len(opts.PerplexityAPIKeys) > 0 { - client, err := createHTTPClient(opts.Proxy, perplexityTimeout) + client, err := utils.CreateHTTPClient(opts.Proxy, perplexityTimeout) if err != nil { return nil, fmt.Errorf("failed to create HTTP client for Perplexity: %w", err) } @@ -691,7 +656,7 @@ func NewWebSearchTool(opts WebSearchToolOptions) (*WebSearchTool, error) { maxResults = opts.PerplexityMaxResults } } else if opts.BraveEnabled && len(opts.BraveAPIKeys) > 0 { - client, err := createHTTPClient(opts.Proxy, searchTimeout) + client, err := utils.CreateHTTPClient(opts.Proxy, searchTimeout) if err != nil { return nil, fmt.Errorf("failed to create HTTP client for Brave: %w", err) } @@ -705,7 +670,7 @@ func NewWebSearchTool(opts WebSearchToolOptions) (*WebSearchTool, error) { maxResults = opts.SearXNGMaxResults } } else if opts.TavilyEnabled && len(opts.TavilyAPIKeys) > 0 { - client, err := createHTTPClient(opts.Proxy, searchTimeout) + client, err := utils.CreateHTTPClient(opts.Proxy, searchTimeout) if err != nil { return nil, fmt.Errorf("failed to create HTTP client for Tavily: %w", err) } @@ -719,7 +684,7 @@ func NewWebSearchTool(opts WebSearchToolOptions) (*WebSearchTool, error) { maxResults = opts.TavilyMaxResults } } else if opts.DuckDuckGoEnabled { - client, err := createHTTPClient(opts.Proxy, searchTimeout) + client, err := utils.CreateHTTPClient(opts.Proxy, searchTimeout) if err != nil { return nil, fmt.Errorf("failed to create HTTP client for DuckDuckGo: %w", err) } @@ -728,7 +693,7 @@ func NewWebSearchTool(opts WebSearchToolOptions) (*WebSearchTool, error) { maxResults = opts.DuckDuckGoMaxResults } } else if opts.GLMSearchEnabled && opts.GLMSearchAPIKey != "" { - client, err := createHTTPClient(opts.Proxy, searchTimeout) + client, err := utils.CreateHTTPClient(opts.Proxy, searchTimeout) if err != nil { return nil, fmt.Errorf("failed to create HTTP client for GLM Search: %w", err) } @@ -827,7 +792,7 @@ func NewWebFetchToolWithProxy(maxChars int, proxy string, fetchLimitBytes int64) if maxChars <= 0 { maxChars = defaultMaxChars } - client, err := createHTTPClient(proxy, fetchTimeout) + client, err := utils.CreateHTTPClient(proxy, fetchTimeout) if err != nil { return nil, fmt.Errorf("failed to create HTTP client for web fetch: %w", err) } diff --git a/pkg/tools/web_test.go b/pkg/tools/web_test.go index 0737d2087..41d83e6f5 100644 --- a/pkg/tools/web_test.go +++ b/pkg/tools/web_test.go @@ -10,7 +10,6 @@ import ( "net/http/httptest" "strings" "testing" - "time" "github.com/sipeed/picoclaw/pkg/logger" ) @@ -639,108 +638,6 @@ func TestWebTool_WebFetch_MissingDomain(t *testing.T) { } } -func TestCreateHTTPClient_ProxyConfigured(t *testing.T) { - client, err := createHTTPClient("http://127.0.0.1:7890", 12*time.Second) - if err != nil { - t.Fatalf("createHTTPClient() error: %v", err) - } - if client.Timeout != 12*time.Second { - t.Fatalf("client.Timeout = %v, want %v", client.Timeout, 12*time.Second) - } - - tr, ok := client.Transport.(*http.Transport) - if !ok { - t.Fatalf("client.Transport type = %T, want *http.Transport", client.Transport) - } - if tr.Proxy == nil { - t.Fatal("transport.Proxy is nil, want non-nil") - } - - req, err := http.NewRequest("GET", "https://example.com", nil) - if err != nil { - t.Fatalf("http.NewRequest() error: %v", err) - } - proxyURL, err := tr.Proxy(req) - if err != nil { - t.Fatalf("transport.Proxy(req) error: %v", err) - } - if proxyURL == nil || proxyURL.String() != "http://127.0.0.1:7890" { - t.Fatalf("proxy URL = %v, want %q", proxyURL, "http://127.0.0.1:7890") - } -} - -func TestCreateHTTPClient_InvalidProxy(t *testing.T) { - _, err := createHTTPClient("://bad-proxy", 10*time.Second) - if err == nil { - t.Fatal("createHTTPClient() expected error for invalid proxy URL, got nil") - } -} - -func TestCreateHTTPClient_Socks5ProxyConfigured(t *testing.T) { - client, err := createHTTPClient("socks5://127.0.0.1:1080", 8*time.Second) - if err != nil { - t.Fatalf("createHTTPClient() error: %v", err) - } - - tr, ok := client.Transport.(*http.Transport) - if !ok { - t.Fatalf("client.Transport type = %T, want *http.Transport", client.Transport) - } - req, err := http.NewRequest("GET", "https://example.com", nil) - if err != nil { - t.Fatalf("http.NewRequest() error: %v", err) - } - proxyURL, err := tr.Proxy(req) - if err != nil { - t.Fatalf("transport.Proxy(req) error: %v", err) - } - if proxyURL == nil || proxyURL.String() != "socks5://127.0.0.1:1080" { - t.Fatalf("proxy URL = %v, want %q", proxyURL, "socks5://127.0.0.1:1080") - } -} - -func TestCreateHTTPClient_UnsupportedProxyScheme(t *testing.T) { - _, err := createHTTPClient("ftp://127.0.0.1:21", 10*time.Second) - if err == nil { - t.Fatal("createHTTPClient() expected error for unsupported scheme, got nil") - } - if !strings.Contains(err.Error(), "unsupported proxy scheme") { - t.Fatalf("error = %q, want to contain %q", err.Error(), "unsupported proxy scheme") - } -} - -func TestCreateHTTPClient_ProxyFromEnvironmentWhenConfigEmpty(t *testing.T) { - t.Setenv("HTTP_PROXY", "http://127.0.0.1:8888") - t.Setenv("http_proxy", "http://127.0.0.1:8888") - t.Setenv("HTTPS_PROXY", "http://127.0.0.1:8888") - t.Setenv("https_proxy", "http://127.0.0.1:8888") - t.Setenv("ALL_PROXY", "") - t.Setenv("all_proxy", "") - t.Setenv("NO_PROXY", "") - t.Setenv("no_proxy", "") - - client, err := createHTTPClient("", 10*time.Second) - if err != nil { - t.Fatalf("createHTTPClient() error: %v", err) - } - - tr, ok := client.Transport.(*http.Transport) - if !ok { - t.Fatalf("client.Transport type = %T, want *http.Transport", client.Transport) - } - if tr.Proxy == nil { - t.Fatal("transport.Proxy is nil, want proxy function from environment") - } - - req, err := http.NewRequest("GET", "https://example.com", nil) - if err != nil { - t.Fatalf("http.NewRequest() error: %v", err) - } - if _, err := tr.Proxy(req); err != nil { - t.Fatalf("transport.Proxy(req) error: %v", err) - } -} - func TestNewWebFetchToolWithProxy(t *testing.T) { tool, err := NewWebFetchToolWithProxy(1024, "http://127.0.0.1:7890", testFetchLimit) if err != nil { diff --git a/pkg/utils/http_client.go b/pkg/utils/http_client.go new file mode 100644 index 000000000..bda7c5c83 --- /dev/null +++ b/pkg/utils/http_client.go @@ -0,0 +1,48 @@ +package utils + +import ( + "fmt" + "net/http" + "net/url" + "strings" + "time" +) + +// CreateHTTPClient creates an HTTP client with optional proxy support. +// If proxyURL is empty, it uses the system environment proxy settings. +// Supported proxy schemes: http, https, socks5, socks5h. +func CreateHTTPClient(proxyURL string, timeout time.Duration) (*http.Client, error) { + client := &http.Client{ + Timeout: timeout, + Transport: &http.Transport{ + MaxIdleConns: 10, + IdleConnTimeout: 30 * time.Second, + DisableCompression: false, + TLSHandshakeTimeout: 15 * time.Second, + }, + } + + if proxyURL != "" { + proxy, err := url.Parse(proxyURL) + if err != nil { + return nil, fmt.Errorf("invalid proxy URL: %w", err) + } + scheme := strings.ToLower(proxy.Scheme) + switch scheme { + case "http", "https", "socks5", "socks5h": + default: + return nil, fmt.Errorf( + "unsupported proxy scheme %q (supported: http, https, socks5, socks5h)", + proxy.Scheme, + ) + } + if proxy.Host == "" { + return nil, fmt.Errorf("invalid proxy URL: missing host") + } + client.Transport.(*http.Transport).Proxy = http.ProxyURL(proxy) + } else { + client.Transport.(*http.Transport).Proxy = http.ProxyFromEnvironment + } + + return client, nil +} diff --git a/pkg/utils/http_client_test.go b/pkg/utils/http_client_test.go new file mode 100644 index 000000000..ff3d0429b --- /dev/null +++ b/pkg/utils/http_client_test.go @@ -0,0 +1,110 @@ +package utils + +import ( + "net/http" + "strings" + "testing" + "time" +) + +func TestCreateHTTPClient_ProxyConfigured(t *testing.T) { + client, err := CreateHTTPClient("http://127.0.0.1:7890", 12*time.Second) + if err != nil { + t.Fatalf("createHTTPClient() error: %v", err) + } + if client.Timeout != 12*time.Second { + t.Fatalf("client.Timeout = %v, want %v", client.Timeout, 12*time.Second) + } + + tr, ok := client.Transport.(*http.Transport) + if !ok { + t.Fatalf("client.Transport type = %T, want *http.Transport", client.Transport) + } + if tr.Proxy == nil { + t.Fatal("transport.Proxy is nil, want non-nil") + } + + req, err := http.NewRequest("GET", "https://example.com", nil) + if err != nil { + t.Fatalf("http.NewRequest() error: %v", err) + } + proxyURL, err := tr.Proxy(req) + if err != nil { + t.Fatalf("transport.Proxy(req) error: %v", err) + } + if proxyURL == nil || proxyURL.String() != "http://127.0.0.1:7890" { + t.Fatalf("proxy URL = %v, want %q", proxyURL, "http://127.0.0.1:7890") + } +} + +func TestCreateHTTPClient_InvalidProxy(t *testing.T) { + _, err := CreateHTTPClient("://bad-proxy", 10*time.Second) + if err == nil { + t.Fatal("createHTTPClient() expected error for invalid proxy URL, got nil") + } +} + +func TestCreateHTTPClient_Socks5ProxyConfigured(t *testing.T) { + client, err := CreateHTTPClient("socks5://127.0.0.1:1080", 8*time.Second) + if err != nil { + t.Fatalf("createHTTPClient() error: %v", err) + } + + tr, ok := client.Transport.(*http.Transport) + if !ok { + t.Fatalf("client.Transport type = %T, want *http.Transport", client.Transport) + } + req, err := http.NewRequest("GET", "https://example.com", nil) + if err != nil { + t.Fatalf("http.NewRequest() error: %v", err) + } + proxyURL, err := tr.Proxy(req) + if err != nil { + t.Fatalf("transport.Proxy(req) error: %v", err) + } + if proxyURL == nil || proxyURL.String() != "socks5://127.0.0.1:1080" { + t.Fatalf("proxy URL = %v, want %q", proxyURL, "socks5://127.0.0.1:1080") + } +} + +func TestCreateHTTPClient_UnsupportedProxyScheme(t *testing.T) { + _, err := CreateHTTPClient("ftp://127.0.0.1:21", 10*time.Second) + if err == nil { + t.Fatal("createHTTPClient() expected error for unsupported scheme, got nil") + } + if !strings.Contains(err.Error(), "unsupported proxy scheme") { + t.Fatalf("error = %q, want to contain %q", err.Error(), "unsupported proxy scheme") + } +} + +func TestCreateHTTPClient_ProxyFromEnvironmentWhenConfigEmpty(t *testing.T) { + t.Setenv("HTTP_PROXY", "http://127.0.0.1:8888") + t.Setenv("http_proxy", "http://127.0.0.1:8888") + t.Setenv("HTTPS_PROXY", "http://127.0.0.1:8888") + t.Setenv("https_proxy", "http://127.0.0.1:8888") + t.Setenv("ALL_PROXY", "") + t.Setenv("all_proxy", "") + t.Setenv("NO_PROXY", "") + t.Setenv("no_proxy", "") + + client, err := CreateHTTPClient("", 10*time.Second) + if err != nil { + t.Fatalf("createHTTPClient() error: %v", err) + } + + tr, ok := client.Transport.(*http.Transport) + if !ok { + t.Fatalf("client.Transport type = %T, want *http.Transport", client.Transport) + } + if tr.Proxy == nil { + t.Fatal("transport.Proxy is nil, want proxy function from environment") + } + + req, err := http.NewRequest("GET", "https://example.com", nil) + if err != nil { + t.Fatalf("http.NewRequest() error: %v", err) + } + if _, err := tr.Proxy(req); err != nil { + t.Fatalf("transport.Proxy(req) error: %v", err) + } +}