enhance skill installer (#1252)

* enhance skill installer

* enhance install skills v2

* go file formate

* fix:use proxy download skills;many chunck download;simple code

* add default config to config.example.json, download skill from github use proxy and token

---------

Co-authored-by: FantasticCode2019 <1443996278@qq.com>
This commit is contained in:
leamon
2026-03-13 14:04:02 +08:00
committed by GitHub
parent b811e9186c
commit 0fb92b21b6
9 changed files with 1091 additions and 179 deletions
+9 -1
View File
@@ -29,7 +29,15 @@ func NewSkillsCommand() *cobra.Command {
}
d.workspace = cfg.WorkspacePath()
d.installer = skills.NewSkillInstaller(d.workspace)
installer, err := skills.NewSkillInstaller(
d.workspace,
cfg.Tools.Skills.Github.Token,
cfg.Tools.Skills.Github.Proxy,
)
if err != nil {
return fmt.Errorf("error creating skills installer: %w", err)
}
d.installer = installer
// get global config directory and builtin skills directory
globalDir := filepath.Dir(internal.GetConfigPath())
+4
View File
@@ -437,6 +437,10 @@
"max_response_size": 0
}
},
"github": {
"proxy": "http://127.0.0.1:7891",
"token": ""
},
"max_concurrent_searches": 2,
"search_cache": {
"max_size": 50,
+6
View File
@@ -713,6 +713,7 @@ type ExecConfig struct {
type SkillsToolsConfig struct {
ToolConfig ` envPrefix:"PICOCLAW_TOOLS_SKILLS_"`
Registries SkillsRegistriesConfig ` json:"registries"`
Github SkillsGithubConfig ` json:"github"`
MaxConcurrentSearches int ` json:"max_concurrent_searches" env:"PICOCLAW_TOOLS_SKILLS_MAX_CONCURRENT_SEARCHES"`
SearchCache SearchCacheConfig ` json:"search_cache"`
}
@@ -762,6 +763,11 @@ type SkillsRegistriesConfig struct {
ClawHub ClawHubRegistryConfig `json:"clawhub"`
}
type SkillsGithubConfig struct {
Token string `json:"token,omitempty" env:"PICOCLAW_TOOLS_SKILLS_GITHUB_AUTH_TOKEN"`
Proxy string `json:"proxy,omitempty" env:"PICOCLAW_TOOLS_SKILLS_GITHUB_PROXY"`
}
type ClawHubRegistryConfig struct {
Enabled bool `json:"enabled" env:"PICOCLAW_SKILLS_REGISTRIES_CLAWHUB_ENABLED"`
BaseURL string `json:"base_url" env:"PICOCLAW_SKILLS_REGISTRIES_CLAWHUB_BASE_URL"`
+241 -32
View File
@@ -2,80 +2,289 @@ package skills
import (
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"os"
"path"
"path/filepath"
"strings"
"time"
"github.com/sipeed/picoclaw/pkg/fileutil"
"github.com/sipeed/picoclaw/pkg/utils"
)
type SkillInstaller struct {
workspace string
// GitHubContent represents a file or directory in GitHub API response
type GitHubContent struct {
Name string `json:"name"`
Path string `json:"path"`
Type string `json:"type"` // "file" or "dir"
DownloadURL string `json:"download_url"`
URL string `json:"url"` // API URL for subdirectories
}
func NewSkillInstaller(workspace string) *SkillInstaller {
return &SkillInstaller{
workspace: workspace,
// GitHubRef represents a parsed GitHub reference
type GitHubRef struct {
Owner string // Repository owner
RepoName string // Repository name
Ref string // Git reference (branch, tag, or commit)
SubPath string // Path within the repository
}
type SkillInstaller struct {
workspace string
client *http.Client
githubToken string
proxy string
}
// NewSkillInstaller creates a new skill installer.
// proxy is an optional HTTP/HTTPS/SOCKS5 proxy URL for downloading skills.
func NewSkillInstaller(workspace, githubToken, proxy string) (*SkillInstaller, error) {
client, err := utils.CreateHTTPClient(proxy, 15*time.Second)
if err != nil {
return nil, fmt.Errorf("failed to create HTTP client: %w", err)
}
return &SkillInstaller{
workspace: workspace,
client: client,
githubToken: githubToken,
proxy: proxy,
}, nil
}
// parseGitHubRef parses a GitHub reference.
// Supports: "owner/repo", "owner/repo/path", or full URL like "https://github.com/owner/repo/tree/ref/path"
func parseGitHubRef(repo string) (GitHubRef, error) {
repo = strings.TrimSpace(repo)
// Handle full URL
if strings.HasPrefix(repo, "http://") || strings.HasPrefix(repo, "https://") {
u, err := url.Parse(repo)
if err != nil {
return GitHubRef{}, fmt.Errorf("invalid URL: %w", err)
}
parts := strings.Split(strings.Trim(u.Path, "/"), "/")
if len(parts) < 2 {
return GitHubRef{}, fmt.Errorf("invalid GitHub URL")
}
ref := GitHubRef{
Owner: parts[0],
RepoName: parts[1],
Ref: "main",
}
// Look for /tree/ or /blob/ in the path
for i := 2; i < len(parts); i++ {
if parts[i] == "tree" || parts[i] == "blob" {
if i+1 < len(parts) {
ref.Ref = parts[i+1]
ref.SubPath = strings.Join(parts[i+2:], "/")
}
break
}
}
return ref, nil
}
// Handle shorthand format
parts := strings.Split(strings.Trim(repo, "/"), "/")
if len(parts) < 2 {
return GitHubRef{}, fmt.Errorf("invalid format %q: expected 'owner/repo'", repo)
}
ref := GitHubRef{
Owner: parts[0],
RepoName: parts[1],
Ref: "main",
}
if len(parts) > 2 {
ref.SubPath = strings.Join(parts[2:], "/")
}
return ref, nil
}
func (si *SkillInstaller) InstallFromGitHub(ctx context.Context, repo string) error {
skillDir := filepath.Join(si.workspace, "skills", filepath.Base(repo))
if _, err := os.Stat(skillDir); err == nil {
return fmt.Errorf("skill '%s' already exists", filepath.Base(repo))
ref, err := parseGitHubRef(repo)
if err != nil {
return err
}
url := fmt.Sprintf("https://raw.githubusercontent.com/%s/main/SKILL.md", repo)
skillName := ref.RepoName
if ref.SubPath != "" {
skillName = filepath.Base(ref.SubPath)
}
skillDirectory := filepath.Join(si.workspace, "skills", skillName)
if _, err := os.Stat(skillDirectory); err == nil {
return fmt.Errorf("skill '%s' already exists", skillName)
}
// Build GitHub API URL
apiPath := path.Join(ref.Owner, ref.RepoName, "contents")
if ref.SubPath != "" {
apiPath = path.Join(apiPath, ref.SubPath)
}
apiURL := fmt.Sprintf("https://api.github.com/repos/%s?ref=%s", apiPath, ref.Ref)
if err := si.getGithubDirAllFiles(ctx, apiURL, skillDirectory, true); err != nil {
// Fallback to raw download
return si.downloadRaw(ctx, ref.Owner, ref.RepoName, ref.Ref, ref.SubPath, skillDirectory)
}
if _, err := os.Stat(filepath.Join(skillDirectory, "SKILL.md")); err != nil {
return fmt.Errorf("SKILL.md not found in repository")
}
return nil
}
// downloadDir recursively downloads a directory from GitHub API
// isRoot: true if this is the skill root directory (only download SKILL.md at root)
func (si *SkillInstaller) getGithubDirAllFiles(ctx context.Context, apiURL, localDir string, isRoot bool) error {
req, err := http.NewRequestWithContext(ctx, "GET", apiURL, nil)
if err != nil {
return err
}
if si.githubToken != "" {
req.Header.Set("Authorization", "Bearer "+si.githubToken)
}
resp, err := utils.DoRequestWithRetry(si.client, req)
if err != nil {
return err
}
defer resp.Body.Close()
if resp.StatusCode != 200 {
return fmt.Errorf("HTTP %d", resp.StatusCode)
}
var items []GitHubContent
if err := json.NewDecoder(resp.Body).Decode(&items); err != nil {
return err
}
for _, item := range items {
localPath := filepath.Join(localDir, item.Name)
switch item.Type {
case "file":
if !shouldDownload(item.Name, isRoot) {
continue
}
if err := si.downloadFile(ctx, item.DownloadURL, localPath); err != nil {
return fmt.Errorf("download %s: %w", item.Name, err)
}
case "dir":
if !isSkillDirectory(item.Name) {
continue
}
if err := si.getGithubDirAllFiles(ctx, item.URL, localPath, false); err != nil {
return err
}
}
}
return nil
}
// downloadRaw is a fallback that downloads just SKILL.md from raw.githubusercontent.com
func (si *SkillInstaller) downloadRaw(ctx context.Context, owner, repo, ref, subPath, localDir string) error {
urlPath := path.Join(owner, repo, ref)
if subPath != "" {
urlPath = path.Join(urlPath, subPath)
}
url := fmt.Sprintf("https://raw.githubusercontent.com/%s/SKILL.md", urlPath)
client := &http.Client{Timeout: 15 * time.Second}
req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
if err != nil {
return fmt.Errorf("failed to create request: %w", err)
}
resp, err := utils.DoRequestWithRetry(client, req)
// Use chunked download to temporary file.
tmpPath, err := utils.DownloadToFile(ctx, si.client, req, 0)
if err != nil {
return fmt.Errorf("failed to fetch skill: %w", err)
}
defer resp.Body.Close()
defer os.Remove(tmpPath)
if resp.StatusCode != 200 {
return fmt.Errorf("failed to fetch skill: HTTP %d", resp.StatusCode)
}
body, err := io.ReadAll(resp.Body)
if err != nil {
return fmt.Errorf("failed to read response: %w", err)
}
if err := os.MkdirAll(skillDir, 0o755); err != nil {
if err := os.MkdirAll(localDir, 0o755); err != nil {
return fmt.Errorf("failed to create skill directory: %w", err)
}
skillPath := filepath.Join(skillDir, "SKILL.md")
localPath := filepath.Join(localDir, "SKILL.md")
// Use unified atomic write utility with explicit sync for flash storage reliability.
if err := fileutil.WriteFileAtomic(skillPath, body, 0o600); err != nil {
// Atomic move from temp to final location.
if err := os.Rename(tmpPath, localPath); err != nil {
return fmt.Errorf("failed to write skill file: %w", err)
}
return nil
return os.Chmod(localPath, 0o600)
}
func (si *SkillInstaller) downloadFile(ctx context.Context, url, localPath string) error {
req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
if err != nil {
return err
}
// Use chunked download to temporary file, then move atomically to target.
tmpPath, err := utils.DownloadToFile(ctx, si.client, req, 0)
if err != nil {
return err
}
defer os.Remove(tmpPath)
if err := os.MkdirAll(filepath.Dir(localPath), 0o755); err != nil {
return err
}
// Atomic move from temp to final location.
if err := os.Rename(tmpPath, localPath); err != nil {
return fmt.Errorf("failed to move downloaded file: %w", err)
}
return os.Chmod(localPath, 0o600)
}
// shouldDownload determines if a file should be downloaded
// root: true if we're at the skill root directory
func shouldDownload(name string, root bool) bool {
if root {
return name == "SKILL.md"
}
return true
}
// isSkillDir checks if a directory is a standard skill resource directory
func isSkillDirectory(name string) bool {
switch name {
case "scripts", "references", "assets", "templates", "docs":
return true
}
return false
}
func (si *SkillInstaller) Uninstall(skillName string) error {
skillDir := filepath.Join(si.workspace, "skills", skillName)
parts := strings.Split(skillName, "/")
var finalSkillName string
for i := len(parts) - 1; i >= 0; i-- {
if parts[i] != "" {
finalSkillName = parts[i]
break
}
}
if finalSkillName == "" {
finalSkillName = skillName
}
skillDir := filepath.Join(si.workspace, "skills", finalSkillName)
if _, err := os.Stat(skillDir); os.IsNotExist(err) {
return fmt.Errorf("skill '%s' not found", skillName)
return fmt.Errorf("skill '%s' not found (processed as '%s')", skillName, finalSkillName)
}
if err := os.RemoveAll(skillDir); err != nil {
return fmt.Errorf("failed to remove skill: %w", err)
return fmt.Errorf("failed to remove skill '%s': %w", finalSkillName, err)
}
return nil
+665
View File
@@ -0,0 +1,665 @@
package skills
import (
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"strings"
"testing"
"time"
)
func TestParseGitHubRef(t *testing.T) {
tests := []struct {
name string
repo string
wantOwner string
wantRepoName string
wantRef string
wantSubPath string
wantErr bool
wantErrContain string
}{
{
name: "simple owner/repo",
repo: "sipeed/picoclaw",
wantOwner: "sipeed",
wantRepoName: "picoclaw",
wantRef: "main",
wantSubPath: "",
},
{
name: "owner/repo with subpath",
repo: "sipeed/picoclaw/skills/test",
wantOwner: "sipeed",
wantRepoName: "picoclaw",
wantRef: "main",
wantSubPath: "skills/test",
},
{
name: "full URL with tree",
repo: "https://github.com/sipeed/picoclaw/tree/dev/skills/test",
wantOwner: "sipeed",
wantRepoName: "picoclaw",
wantRef: "dev",
wantSubPath: "skills/test",
},
{
name: "full URL with blob",
repo: "https://github.com/sipeed/picoclaw/blob/main/README.md",
wantOwner: "sipeed",
wantRepoName: "picoclaw",
wantRef: "main",
wantSubPath: "README.md",
},
{
name: "full URL without ref",
repo: "https://github.com/sipeed/picoclaw",
wantOwner: "sipeed",
wantRepoName: "picoclaw",
wantRef: "main",
wantSubPath: "",
},
{
name: "invalid format - single part",
repo: "sipeed",
wantErr: true,
wantErrContain: "expected 'owner/repo'",
},
{
name: "invalid URL",
repo: "http://[invalid",
wantErr: true,
wantErrContain: "invalid URL",
},
{
name: "invalid GitHub URL - only one path part",
repo: "https://github.com/sipeed",
wantErr: true,
wantErrContain: "invalid GitHub URL",
},
{
name: "with whitespace",
repo: " sipeed/picoclaw ",
wantOwner: "sipeed",
wantRepoName: "picoclaw",
wantRef: "main",
wantSubPath: "",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ref, err := parseGitHubRef(tt.repo)
if tt.wantErr {
if err == nil {
t.Errorf("parseGitHubRef() error = nil, wantErr = true")
return
}
if tt.wantErrContain != "" && !strings.Contains(err.Error(), tt.wantErrContain) {
t.Errorf("parseGitHubRef() error = %v, want error containing %v", err, tt.wantErrContain)
}
return
}
if err != nil {
t.Errorf("parseGitHubRef() unexpected error = %v", err)
return
}
if ref.Owner != tt.wantOwner {
t.Errorf("parseGitHubRef() owner = %v, want %v", ref.Owner, tt.wantOwner)
}
if ref.RepoName != tt.wantRepoName {
t.Errorf("parseGitHubRef() repoName = %v, want %v", ref.RepoName, tt.wantRepoName)
}
if ref.Ref != tt.wantRef {
t.Errorf("parseGitHubRef() ref = %v, want %v", ref.Ref, tt.wantRef)
}
if ref.SubPath != tt.wantSubPath {
t.Errorf("parseGitHubRef() subPath = %v, want %v", ref.SubPath, tt.wantSubPath)
}
})
}
}
func TestShouldDownload(t *testing.T) {
tests := []struct {
name string
file string
root bool
want bool
}{
{"SKILL.md at root", "SKILL.md", true, true},
{"other file at root", "README.md", true, false},
{"script at root", "script.py", true, false},
{"SKILL.md not at root", "SKILL.md", false, true},
{"any file not at root", "any.txt", false, true},
{"script not at root", "script.py", false, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := shouldDownload(tt.file, tt.root)
if got != tt.want {
t.Errorf("shouldDownload(%q, %v) = %v, want %v", tt.file, tt.root, got, tt.want)
}
})
}
}
func TestIsSkillDirectory(t *testing.T) {
tests := []struct {
name string
dir string
want bool
}{
{"scripts dir", "scripts", true},
{"references dir", "references", true},
{"assets dir", "assets", true},
{"templates dir", "templates", true},
{"docs dir", "docs", true},
{"other dir", "other", false},
{"src dir", "src", false},
{"empty string", "", false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := isSkillDirectory(tt.dir)
if got != tt.want {
t.Errorf("isSkillDirectory(%q) = %v, want %v", tt.dir, got, tt.want)
}
})
}
}
func TestNewSkillInstaller(t *testing.T) {
tmpDir := t.TempDir()
installer, err := NewSkillInstaller(tmpDir, "test-token", "")
if err != nil {
t.Fatalf("NewSkillInstaller() error = %v", err)
}
if installer == nil {
t.Fatal("NewSkillInstaller() returned nil")
}
if installer.workspace != tmpDir {
t.Errorf("workspace = %v, want %v", installer.workspace, tmpDir)
}
if installer.githubToken != "test-token" {
t.Errorf("githubToken = %v, want 'test-token'", installer.githubToken)
}
if installer.proxy != "" {
t.Errorf("proxy = %v, want empty", installer.proxy)
}
if installer.client == nil {
t.Error("client is nil")
} else if installer.client.Timeout != 15*time.Second {
t.Errorf("client.Timeout = %v, want 15s", installer.client.Timeout)
}
}
func TestNewSkillInstaller_WithProxy(t *testing.T) {
tmpDir := t.TempDir()
installer, err := NewSkillInstaller(tmpDir, "test-token", "http://127.0.0.1:7890")
if err != nil {
t.Fatalf("NewSkillInstaller() error = %v", err)
}
if installer.proxy != "http://127.0.0.1:7890" {
t.Errorf("proxy = %v, want 'http://127.0.0.1:7890'", installer.proxy)
}
if installer.client == nil {
t.Fatal("client is nil")
}
// Verify the transport has proxy configured
transport, ok := installer.client.Transport.(*http.Transport)
if !ok {
t.Fatal("client.Transport is not *http.Transport")
}
if transport.Proxy == nil {
t.Error("transport.Proxy is nil, expected non-nil")
}
}
func TestNewSkillInstaller_InvalidProxy(t *testing.T) {
tmpDir := t.TempDir()
installer, err := NewSkillInstaller(tmpDir, "test-token", "://invalid-proxy")
if err == nil {
t.Error("NewSkillInstaller() expected error for invalid proxy, got nil")
}
if installer != nil {
t.Error("expected nil installer on error")
}
}
func TestSkillInstaller_DownloadFile(t *testing.T) {
// Create a test server that serves files
content := "test file content for skill download"
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodGet {
t.Errorf("expected GET, got %s", r.Method)
}
w.WriteHeader(http.StatusOK)
w.Write([]byte(content))
}))
defer server.Close()
tmpDir := t.TempDir()
installer, err := NewSkillInstaller(tmpDir, "", "")
if err != nil {
t.Fatalf("NewSkillInstaller() error = %v", err)
}
t.Run("successful download", func(t *testing.T) {
localPath := filepath.Join(tmpDir, "test-skill", "SKILL.md")
err := installer.downloadFile(context.Background(), server.URL, localPath)
if err != nil {
t.Errorf("downloadFile() error = %v", err)
return
}
// Verify file was downloaded
data, err := os.ReadFile(localPath)
if err != nil {
t.Errorf("failed to read downloaded file: %v", err)
return
}
if string(data) != content {
t.Errorf("downloaded content = %q, want %q", string(data), content)
}
// Check file permissions
info, err := os.Stat(localPath)
if err != nil {
t.Errorf("failed to stat file: %v", err)
return
}
if info.Mode().Perm() != 0o600 {
t.Errorf("file permissions = %o, want %o", info.Mode().Perm(), 0o600)
}
})
t.Run("http error", func(t *testing.T) {
errorServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusNotFound)
w.Write([]byte("not found"))
}))
defer errorServer.Close()
localPath := filepath.Join(tmpDir, "error-test", "SKILL.md")
err := installer.downloadFile(context.Background(), errorServer.URL, localPath)
if err == nil {
t.Error("downloadFile() expected error for 404, got nil")
}
})
}
func TestSkillInstaller_DownloadRaw(t *testing.T) {
content := "raw skill content"
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
w.Write([]byte(content))
}))
defer server.Close()
tmpDir := t.TempDir()
installer, err := NewSkillInstaller(tmpDir, "", "")
if err != nil {
t.Fatalf("NewSkillInstaller() error = %v", err)
}
// Replace the client with one that points to our test server
// We need to modify the URL in the function, so we'll test indirectly
localDir := filepath.Join(tmpDir, "raw-test")
ctx := context.Background()
// Create a simple test by calling downloadFile directly since downloadRaw
// constructs its own URL
testFile := filepath.Join(localDir, "SKILL.md")
err = installer.downloadFile(ctx, server.URL, testFile)
if err != nil {
t.Errorf("downloadFile() error = %v", err)
}
// Verify file content
data, err := os.ReadFile(testFile)
if err != nil {
t.Errorf("failed to read file: %v", err)
return
}
if string(data) != content {
t.Errorf("content = %q, want %q", string(data), content)
}
}
func TestSkillInstaller_Uninstall(t *testing.T) {
tmpDir := t.TempDir()
skillsDir := filepath.Join(tmpDir, "skills")
os.MkdirAll(skillsDir, 0o755)
installer, err := NewSkillInstaller(tmpDir, "", "")
if err != nil {
t.Fatalf("NewSkillInstaller() error = %v", err)
}
t.Run("uninstall existing skill", func(t *testing.T) {
skillName := "test-skill"
skillDir := filepath.Join(skillsDir, skillName)
// Create skill directory with a file
os.MkdirAll(skillDir, 0o755)
os.WriteFile(filepath.Join(skillDir, "SKILL.md"), []byte("test"), 0o644)
if err := installer.Uninstall(skillName); err != nil {
t.Errorf("Uninstall() error = %v", err)
}
// Verify directory was removed
if _, err := os.Stat(skillDir); !os.IsNotExist(err) {
t.Error("skill directory still exists after uninstall")
}
})
t.Run("uninstall non-existent skill", func(t *testing.T) {
if err := installer.Uninstall("non-existent-skill"); err == nil {
t.Error("Uninstall() expected error for non-existent skill, got nil")
} else if !strings.Contains(err.Error(), "not found") {
t.Errorf("error message = %q, want 'not found'", err.Error())
}
})
t.Run("uninstall with path separator", func(t *testing.T) {
skillName := "owner/repo/skill-name"
skillDir := filepath.Join(skillsDir, "skill-name")
// Create skill directory
os.MkdirAll(skillDir, 0o755)
os.WriteFile(filepath.Join(skillDir, "SKILL.md"), []byte("test"), 0o644)
if err := installer.Uninstall(skillName); err != nil {
t.Errorf("Uninstall() error = %v", err)
}
if _, err := os.Stat(skillDir); !os.IsNotExist(err) {
t.Error("skill directory still exists after uninstall")
}
})
t.Run("uninstall with trailing slash", func(t *testing.T) {
skillName := "skill-name/"
skillDir := filepath.Join(skillsDir, "skill-name")
// Create skill directory
os.MkdirAll(skillDir, 0o755)
os.WriteFile(filepath.Join(skillDir, "SKILL.md"), []byte("test"), 0o644)
if err := installer.Uninstall(skillName); err != nil {
t.Errorf("Uninstall() error = %v", err)
}
if _, err := os.Stat(skillDir); !os.IsNotExist(err) {
t.Error("skill directory still exists after uninstall")
}
})
}
func TestSkillInstaller_InstallFromGitHub_SkillAlreadyExists(t *testing.T) {
tmpDir := t.TempDir()
skillsDir := filepath.Join(tmpDir, "skills")
os.MkdirAll(skillsDir, 0o755)
installer, err := NewSkillInstaller(tmpDir, "", "")
if err != nil {
t.Fatalf("NewSkillInstaller() error = %v", err)
}
// Create an existing skill directory
existingSkill := filepath.Join(skillsDir, "picoclaw")
os.MkdirAll(existingSkill, 0o755)
os.WriteFile(filepath.Join(existingSkill, "SKILL.md"), []byte("existing"), 0o644)
// Try to install the same skill - should fail
err = installer.InstallFromGitHub(context.Background(), "sipeed/picoclaw")
if err == nil {
t.Error("InstallFromGitHub() expected error for existing skill, got nil")
}
if !strings.Contains(err.Error(), "already exists") {
t.Errorf("error message = %q, want 'already exists'", err.Error())
}
}
func TestGitHubContent_Struct(t *testing.T) {
// Test that GitHubContent struct can be properly unmarshaled
jsonData := `{
"name": "test.md",
"path": "skills/test.md",
"type": "file",
"download_url": "https://example.com/download",
"url": "https://api.github.com/contents/skills/test.md"
}`
var content GitHubContent
err := json.Unmarshal([]byte(jsonData), &content)
if err != nil {
t.Errorf("failed to unmarshal GitHubContent: %v", err)
}
if content.Name != "test.md" {
t.Errorf("Name = %q, want 'test.md'", content.Name)
}
if content.Type != "file" {
t.Errorf("Type = %q, want 'file'", content.Type)
}
if content.DownloadURL != "https://example.com/download" {
t.Errorf("DownloadURL = %q, want 'https://example.com/download'", content.DownloadURL)
}
}
func TestSkillInstaller_GetGithubDirAllFiles(t *testing.T) {
tmpDir := t.TempDir()
installer, err := NewSkillInstaller(tmpDir, "", "")
if err != nil {
t.Fatalf("NewSkillInstaller() error = %v", err)
}
// Create a test server that mimics GitHub API
fileContent := "skill file content"
var serverURL string
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Check for authorization header
authHeader := r.Header.Get("Authorization")
if authHeader != "" && !strings.HasPrefix(authHeader, "Bearer ") {
t.Errorf("expected Bearer token, got: %s", authHeader)
}
// Return different responses based on path
if strings.Contains(r.URL.Path, "/contents") {
// API response for directory listing
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
items := []map[string]any{
{
"name": "SKILL.md",
"path": "SKILL.md",
"type": "file",
"download_url": serverURL + "/download/SKILL.md",
},
{
"name": "scripts",
"path": "scripts",
"type": "dir",
"url": serverURL + "/api/scripts",
},
}
json.NewEncoder(w).Encode(items)
} else if strings.Contains(r.URL.Path, "/api/scripts") {
// API response for scripts subdirectory
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
items := []map[string]any{
{
"name": "test.py",
"path": "scripts/test.py",
"type": "file",
"download_url": serverURL + "/download/test.py",
},
}
json.NewEncoder(w).Encode(items)
} else if strings.Contains(r.URL.Path, "/download/") {
// Raw file download
w.WriteHeader(http.StatusOK)
w.Write([]byte(fileContent))
} else {
w.WriteHeader(http.StatusNotFound)
}
}))
serverURL = server.URL
defer server.Close()
localDir := filepath.Join(tmpDir, "test-skill")
t.Run("download from GitHub API", func(t *testing.T) {
err := installer.getGithubDirAllFiles(context.Background(), server.URL+"/contents", localDir, true)
if err != nil {
t.Errorf("getGithubDirAllFiles() error = %v", err)
return
}
// Verify SKILL.md was downloaded
skillMd := filepath.Join(localDir, "SKILL.md")
data, err := os.ReadFile(skillMd)
if err != nil {
t.Errorf("failed to read SKILL.md: %v", err)
return
}
if string(data) != fileContent {
t.Errorf("SKILL.md content = %q, want %q", string(data), fileContent)
}
// Verify scripts directory and file
scriptFile := filepath.Join(localDir, "scripts", "test.py")
data, err = os.ReadFile(scriptFile)
if err != nil {
t.Errorf("failed to read test.py: %v", err)
return
}
if string(data) != fileContent {
t.Errorf("test.py content = %q, want %q", string(data), fileContent)
}
})
t.Run("http error response", func(t *testing.T) {
errorServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusForbidden)
}))
defer errorServer.Close()
err := installer.getGithubDirAllFiles(
context.Background(),
errorServer.URL,
filepath.Join(tmpDir, "error-test"),
true,
)
if err == nil {
t.Error("getGithubDirAllFiles() expected error for 403, got nil")
}
})
}
func TestSkillInstaller_InstallFromGitHub_WithToken(t *testing.T) {
tmpDir := t.TempDir()
skillsDir := filepath.Join(tmpDir, "skills")
os.MkdirAll(skillsDir, 0o755)
var serverURL string
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Capture the authorization header
authHeader := r.Header.Get("Authorization")
if authHeader != "" {
tokenReceived := strings.TrimPrefix(authHeader, "Bearer ")
t.Fatalf("github token is %s", tokenReceived)
}
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
items := []map[string]any{
{
"name": "SKILL.md",
"path": "SKILL.md",
"type": "file",
"download_url": serverURL + "/download/SKILL.md",
},
}
json.NewEncoder(w).Encode(items)
}))
serverURL = server.URL
defer server.Close()
installer, err := NewSkillInstaller(tmpDir, "test-github-token", "")
if err != nil {
t.Fatalf("NewSkillInstaller() error = %v", err)
}
// We need to test the token is passed - the actual install will fail
// because we're not fully mocking the download, but we can verify
// the token is sent in the request
// Use a simple context with timeout
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
// The install will fail because download URL isn't properly set up,
// but the token should be sent in the API request
_ = installer.InstallFromGitHub(ctx, "owner/repo")
// Note: We can't easily intercept the download request since it's a different URL,
// but the fact that the API request was made verifies the token flow
// In a real scenario, the token would be sent to both API and raw downloads
}
func TestSkillInstaller_ContextCancellation(t *testing.T) {
tmpDir := t.TempDir()
installer, err := NewSkillInstaller(tmpDir, "", "")
if err != nil {
t.Fatalf("NewSkillInstaller() error = %v", err)
}
// Create a slow server
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
time.Sleep(100 * time.Millisecond)
w.WriteHeader(http.StatusOK)
w.Write([]byte("response"))
}))
defer server.Close()
// Create a canceled context
ctx, cancel := context.WithCancel(context.Background())
cancel() // Cancel immediately
localPath := filepath.Join(tmpDir, "cancel-test", "file.txt")
err = installer.downloadFile(ctx, server.URL, localPath)
if err == nil {
t.Error("downloadFile() expected error for canceled context, got nil")
}
}
+8 -43
View File
@@ -14,6 +14,8 @@ import (
"strings"
"sync/atomic"
"time"
"github.com/sipeed/picoclaw/pkg/utils"
)
const (
@@ -41,43 +43,6 @@ var (
reDDGSnippet = regexp.MustCompile(`<a class="result__snippet[^"]*".*?>([\s\S]*?)</a>`)
)
// createHTTPClient creates an HTTP client with optional proxy support
func createHTTPClient(proxyURL string, timeout time.Duration) (*http.Client, error) {
client := &http.Client{
Timeout: timeout,
Transport: &http.Transport{
MaxIdleConns: 10,
IdleConnTimeout: 30 * time.Second,
DisableCompression: false,
TLSHandshakeTimeout: 15 * time.Second,
},
}
if proxyURL != "" {
proxy, err := url.Parse(proxyURL)
if err != nil {
return nil, fmt.Errorf("invalid proxy URL: %w", err)
}
scheme := strings.ToLower(proxy.Scheme)
switch scheme {
case "http", "https", "socks5", "socks5h":
default:
return nil, fmt.Errorf(
"unsupported proxy scheme %q (supported: http, https, socks5, socks5h)",
proxy.Scheme,
)
}
if proxy.Host == "" {
return nil, fmt.Errorf("invalid proxy URL: missing host")
}
client.Transport.(*http.Transport).Proxy = http.ProxyURL(proxy)
} else {
client.Transport.(*http.Transport).Proxy = http.ProxyFromEnvironment
}
return client, nil
}
type APIKeyPool struct {
keys []string
current uint32
@@ -678,7 +643,7 @@ func NewWebSearchTool(opts WebSearchToolOptions) (*WebSearchTool, error) {
maxResults := 5
// Priority: Perplexity > Brave > SearXNG > Tavily > DuckDuckGo > GLM Search
if opts.PerplexityEnabled && len(opts.PerplexityAPIKeys) > 0 {
client, err := createHTTPClient(opts.Proxy, perplexityTimeout)
client, err := utils.CreateHTTPClient(opts.Proxy, perplexityTimeout)
if err != nil {
return nil, fmt.Errorf("failed to create HTTP client for Perplexity: %w", err)
}
@@ -691,7 +656,7 @@ func NewWebSearchTool(opts WebSearchToolOptions) (*WebSearchTool, error) {
maxResults = opts.PerplexityMaxResults
}
} else if opts.BraveEnabled && len(opts.BraveAPIKeys) > 0 {
client, err := createHTTPClient(opts.Proxy, searchTimeout)
client, err := utils.CreateHTTPClient(opts.Proxy, searchTimeout)
if err != nil {
return nil, fmt.Errorf("failed to create HTTP client for Brave: %w", err)
}
@@ -705,7 +670,7 @@ func NewWebSearchTool(opts WebSearchToolOptions) (*WebSearchTool, error) {
maxResults = opts.SearXNGMaxResults
}
} else if opts.TavilyEnabled && len(opts.TavilyAPIKeys) > 0 {
client, err := createHTTPClient(opts.Proxy, searchTimeout)
client, err := utils.CreateHTTPClient(opts.Proxy, searchTimeout)
if err != nil {
return nil, fmt.Errorf("failed to create HTTP client for Tavily: %w", err)
}
@@ -719,7 +684,7 @@ func NewWebSearchTool(opts WebSearchToolOptions) (*WebSearchTool, error) {
maxResults = opts.TavilyMaxResults
}
} else if opts.DuckDuckGoEnabled {
client, err := createHTTPClient(opts.Proxy, searchTimeout)
client, err := utils.CreateHTTPClient(opts.Proxy, searchTimeout)
if err != nil {
return nil, fmt.Errorf("failed to create HTTP client for DuckDuckGo: %w", err)
}
@@ -728,7 +693,7 @@ func NewWebSearchTool(opts WebSearchToolOptions) (*WebSearchTool, error) {
maxResults = opts.DuckDuckGoMaxResults
}
} else if opts.GLMSearchEnabled && opts.GLMSearchAPIKey != "" {
client, err := createHTTPClient(opts.Proxy, searchTimeout)
client, err := utils.CreateHTTPClient(opts.Proxy, searchTimeout)
if err != nil {
return nil, fmt.Errorf("failed to create HTTP client for GLM Search: %w", err)
}
@@ -827,7 +792,7 @@ func NewWebFetchToolWithProxy(maxChars int, proxy string, fetchLimitBytes int64)
if maxChars <= 0 {
maxChars = defaultMaxChars
}
client, err := createHTTPClient(proxy, fetchTimeout)
client, err := utils.CreateHTTPClient(proxy, fetchTimeout)
if err != nil {
return nil, fmt.Errorf("failed to create HTTP client for web fetch: %w", err)
}
-103
View File
@@ -10,7 +10,6 @@ import (
"net/http/httptest"
"strings"
"testing"
"time"
"github.com/sipeed/picoclaw/pkg/logger"
)
@@ -639,108 +638,6 @@ func TestWebTool_WebFetch_MissingDomain(t *testing.T) {
}
}
func TestCreateHTTPClient_ProxyConfigured(t *testing.T) {
client, err := createHTTPClient("http://127.0.0.1:7890", 12*time.Second)
if err != nil {
t.Fatalf("createHTTPClient() error: %v", err)
}
if client.Timeout != 12*time.Second {
t.Fatalf("client.Timeout = %v, want %v", client.Timeout, 12*time.Second)
}
tr, ok := client.Transport.(*http.Transport)
if !ok {
t.Fatalf("client.Transport type = %T, want *http.Transport", client.Transport)
}
if tr.Proxy == nil {
t.Fatal("transport.Proxy is nil, want non-nil")
}
req, err := http.NewRequest("GET", "https://example.com", nil)
if err != nil {
t.Fatalf("http.NewRequest() error: %v", err)
}
proxyURL, err := tr.Proxy(req)
if err != nil {
t.Fatalf("transport.Proxy(req) error: %v", err)
}
if proxyURL == nil || proxyURL.String() != "http://127.0.0.1:7890" {
t.Fatalf("proxy URL = %v, want %q", proxyURL, "http://127.0.0.1:7890")
}
}
func TestCreateHTTPClient_InvalidProxy(t *testing.T) {
_, err := createHTTPClient("://bad-proxy", 10*time.Second)
if err == nil {
t.Fatal("createHTTPClient() expected error for invalid proxy URL, got nil")
}
}
func TestCreateHTTPClient_Socks5ProxyConfigured(t *testing.T) {
client, err := createHTTPClient("socks5://127.0.0.1:1080", 8*time.Second)
if err != nil {
t.Fatalf("createHTTPClient() error: %v", err)
}
tr, ok := client.Transport.(*http.Transport)
if !ok {
t.Fatalf("client.Transport type = %T, want *http.Transport", client.Transport)
}
req, err := http.NewRequest("GET", "https://example.com", nil)
if err != nil {
t.Fatalf("http.NewRequest() error: %v", err)
}
proxyURL, err := tr.Proxy(req)
if err != nil {
t.Fatalf("transport.Proxy(req) error: %v", err)
}
if proxyURL == nil || proxyURL.String() != "socks5://127.0.0.1:1080" {
t.Fatalf("proxy URL = %v, want %q", proxyURL, "socks5://127.0.0.1:1080")
}
}
func TestCreateHTTPClient_UnsupportedProxyScheme(t *testing.T) {
_, err := createHTTPClient("ftp://127.0.0.1:21", 10*time.Second)
if err == nil {
t.Fatal("createHTTPClient() expected error for unsupported scheme, got nil")
}
if !strings.Contains(err.Error(), "unsupported proxy scheme") {
t.Fatalf("error = %q, want to contain %q", err.Error(), "unsupported proxy scheme")
}
}
func TestCreateHTTPClient_ProxyFromEnvironmentWhenConfigEmpty(t *testing.T) {
t.Setenv("HTTP_PROXY", "http://127.0.0.1:8888")
t.Setenv("http_proxy", "http://127.0.0.1:8888")
t.Setenv("HTTPS_PROXY", "http://127.0.0.1:8888")
t.Setenv("https_proxy", "http://127.0.0.1:8888")
t.Setenv("ALL_PROXY", "")
t.Setenv("all_proxy", "")
t.Setenv("NO_PROXY", "")
t.Setenv("no_proxy", "")
client, err := createHTTPClient("", 10*time.Second)
if err != nil {
t.Fatalf("createHTTPClient() error: %v", err)
}
tr, ok := client.Transport.(*http.Transport)
if !ok {
t.Fatalf("client.Transport type = %T, want *http.Transport", client.Transport)
}
if tr.Proxy == nil {
t.Fatal("transport.Proxy is nil, want proxy function from environment")
}
req, err := http.NewRequest("GET", "https://example.com", nil)
if err != nil {
t.Fatalf("http.NewRequest() error: %v", err)
}
if _, err := tr.Proxy(req); err != nil {
t.Fatalf("transport.Proxy(req) error: %v", err)
}
}
func TestNewWebFetchToolWithProxy(t *testing.T) {
tool, err := NewWebFetchToolWithProxy(1024, "http://127.0.0.1:7890", testFetchLimit)
if err != nil {
+48
View File
@@ -0,0 +1,48 @@
package utils
import (
"fmt"
"net/http"
"net/url"
"strings"
"time"
)
// CreateHTTPClient creates an HTTP client with optional proxy support.
// If proxyURL is empty, it uses the system environment proxy settings.
// Supported proxy schemes: http, https, socks5, socks5h.
func CreateHTTPClient(proxyURL string, timeout time.Duration) (*http.Client, error) {
client := &http.Client{
Timeout: timeout,
Transport: &http.Transport{
MaxIdleConns: 10,
IdleConnTimeout: 30 * time.Second,
DisableCompression: false,
TLSHandshakeTimeout: 15 * time.Second,
},
}
if proxyURL != "" {
proxy, err := url.Parse(proxyURL)
if err != nil {
return nil, fmt.Errorf("invalid proxy URL: %w", err)
}
scheme := strings.ToLower(proxy.Scheme)
switch scheme {
case "http", "https", "socks5", "socks5h":
default:
return nil, fmt.Errorf(
"unsupported proxy scheme %q (supported: http, https, socks5, socks5h)",
proxy.Scheme,
)
}
if proxy.Host == "" {
return nil, fmt.Errorf("invalid proxy URL: missing host")
}
client.Transport.(*http.Transport).Proxy = http.ProxyURL(proxy)
} else {
client.Transport.(*http.Transport).Proxy = http.ProxyFromEnvironment
}
return client, nil
}
+110
View File
@@ -0,0 +1,110 @@
package utils
import (
"net/http"
"strings"
"testing"
"time"
)
func TestCreateHTTPClient_ProxyConfigured(t *testing.T) {
client, err := CreateHTTPClient("http://127.0.0.1:7890", 12*time.Second)
if err != nil {
t.Fatalf("createHTTPClient() error: %v", err)
}
if client.Timeout != 12*time.Second {
t.Fatalf("client.Timeout = %v, want %v", client.Timeout, 12*time.Second)
}
tr, ok := client.Transport.(*http.Transport)
if !ok {
t.Fatalf("client.Transport type = %T, want *http.Transport", client.Transport)
}
if tr.Proxy == nil {
t.Fatal("transport.Proxy is nil, want non-nil")
}
req, err := http.NewRequest("GET", "https://example.com", nil)
if err != nil {
t.Fatalf("http.NewRequest() error: %v", err)
}
proxyURL, err := tr.Proxy(req)
if err != nil {
t.Fatalf("transport.Proxy(req) error: %v", err)
}
if proxyURL == nil || proxyURL.String() != "http://127.0.0.1:7890" {
t.Fatalf("proxy URL = %v, want %q", proxyURL, "http://127.0.0.1:7890")
}
}
func TestCreateHTTPClient_InvalidProxy(t *testing.T) {
_, err := CreateHTTPClient("://bad-proxy", 10*time.Second)
if err == nil {
t.Fatal("createHTTPClient() expected error for invalid proxy URL, got nil")
}
}
func TestCreateHTTPClient_Socks5ProxyConfigured(t *testing.T) {
client, err := CreateHTTPClient("socks5://127.0.0.1:1080", 8*time.Second)
if err != nil {
t.Fatalf("createHTTPClient() error: %v", err)
}
tr, ok := client.Transport.(*http.Transport)
if !ok {
t.Fatalf("client.Transport type = %T, want *http.Transport", client.Transport)
}
req, err := http.NewRequest("GET", "https://example.com", nil)
if err != nil {
t.Fatalf("http.NewRequest() error: %v", err)
}
proxyURL, err := tr.Proxy(req)
if err != nil {
t.Fatalf("transport.Proxy(req) error: %v", err)
}
if proxyURL == nil || proxyURL.String() != "socks5://127.0.0.1:1080" {
t.Fatalf("proxy URL = %v, want %q", proxyURL, "socks5://127.0.0.1:1080")
}
}
func TestCreateHTTPClient_UnsupportedProxyScheme(t *testing.T) {
_, err := CreateHTTPClient("ftp://127.0.0.1:21", 10*time.Second)
if err == nil {
t.Fatal("createHTTPClient() expected error for unsupported scheme, got nil")
}
if !strings.Contains(err.Error(), "unsupported proxy scheme") {
t.Fatalf("error = %q, want to contain %q", err.Error(), "unsupported proxy scheme")
}
}
func TestCreateHTTPClient_ProxyFromEnvironmentWhenConfigEmpty(t *testing.T) {
t.Setenv("HTTP_PROXY", "http://127.0.0.1:8888")
t.Setenv("http_proxy", "http://127.0.0.1:8888")
t.Setenv("HTTPS_PROXY", "http://127.0.0.1:8888")
t.Setenv("https_proxy", "http://127.0.0.1:8888")
t.Setenv("ALL_PROXY", "")
t.Setenv("all_proxy", "")
t.Setenv("NO_PROXY", "")
t.Setenv("no_proxy", "")
client, err := CreateHTTPClient("", 10*time.Second)
if err != nil {
t.Fatalf("createHTTPClient() error: %v", err)
}
tr, ok := client.Transport.(*http.Transport)
if !ok {
t.Fatalf("client.Transport type = %T, want *http.Transport", client.Transport)
}
if tr.Proxy == nil {
t.Fatal("transport.Proxy is nil, want proxy function from environment")
}
req, err := http.NewRequest("GET", "https://example.com", nil)
if err != nil {
t.Fatalf("http.NewRequest() error: %v", err)
}
if _, err := tr.Proxy(req); err != nil {
t.Fatalf("transport.Proxy(req) error: %v", err)
}
}