diff --git a/pkg/skills/installer.go b/pkg/skills/installer.go index 3210509df..f9b5705f1 100644 --- a/pkg/skills/installer.go +++ b/pkg/skills/installer.go @@ -9,6 +9,8 @@ import ( "os" "path/filepath" "time" + + "github.com/sipeed/picoclaw/pkg/utils" ) type SkillInstaller struct { @@ -44,7 +46,7 @@ func (si *SkillInstaller) InstallFromGitHub(ctx context.Context, repo string) er return fmt.Errorf("failed to create request: %w", err) } - resp, err := client.Do(req) + resp, err := utils.DoRequestWithRetry(client, req) if err != nil { return fmt.Errorf("failed to fetch skill: %w", err) } @@ -94,7 +96,7 @@ func (si *SkillInstaller) ListAvailableSkills(ctx context.Context) ([]AvailableS return nil, fmt.Errorf("failed to create request: %w", err) } - resp, err := client.Do(req) + resp, err := utils.DoRequestWithRetry(client, req) if err != nil { return nil, fmt.Errorf("failed to fetch skills list: %w", err) } diff --git a/pkg/utils/http_retry.go b/pkg/utils/http_retry.go new file mode 100644 index 000000000..e90fa2129 --- /dev/null +++ b/pkg/utils/http_retry.go @@ -0,0 +1,57 @@ +package utils + +import ( + "context" + "fmt" + "net/http" + "time" +) + +const maxRetries = 3 + +var retryDelayUnit = time.Second + +func shouldRetry(statusCode int) bool { + return statusCode == http.StatusTooManyRequests || + statusCode >= 500 +} + +func DoRequestWithRetry(client *http.Client, req *http.Request) (*http.Response, error) { + var resp *http.Response + var err error + + for i := range maxRetries { + if i > 0 && resp != nil { + resp.Body.Close() + } + + resp, err = client.Do(req) + if err == nil { + if resp.StatusCode == http.StatusOK { + break + } + if !shouldRetry(resp.StatusCode) { + break + } + } + + if i < maxRetries-1 { + if err = sleepWithCtx(req.Context(), retryDelayUnit*time.Duration(i+1)); err != nil { + return nil, fmt.Errorf("failed to sleep: %w", err) + } + } + } + return resp, err +} + +func sleepWithCtx(ctx context.Context, d time.Duration) error { + timer := time.NewTimer(d) + defer timer.Stop() + + select { + case <-ctx.Done(): + return ctx.Err() + case <-timer.C: + return nil + } +} diff --git a/pkg/utils/http_retry_test.go b/pkg/utils/http_retry_test.go new file mode 100644 index 000000000..1c2dbe115 --- /dev/null +++ b/pkg/utils/http_retry_test.go @@ -0,0 +1,118 @@ +package utils + +import ( + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestDoRequestWithRetry(t *testing.T) { + retryDelayUnit = time.Millisecond + t.Cleanup(func() { retryDelayUnit = time.Second }) + + testcases := []struct { + name string + serverBehavior func(*httptest.Server) int + wantSuccess bool + wantAttempts int + }{ + { + name: "success-on-first-attempt", + serverBehavior: func(server *httptest.Server) int { + return 0 + }, + wantSuccess: true, + wantAttempts: 1, + }, + { + name: "fail-all-attempts", + serverBehavior: func(server *httptest.Server) int { + return 4 + }, + wantSuccess: false, + wantAttempts: 3, + }, + } + + for _, tc := range testcases { + t.Run(tc.name, func(t *testing.T) { + attempts := 0 + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + attempts++ + if attempts <= tc.serverBehavior(nil) { + w.WriteHeader(http.StatusInternalServerError) + return + } + w.WriteHeader(http.StatusOK) + w.Write([]byte("success")) + })) + + t.Cleanup(func() { + server.Close() + }) + + client := &http.Client{Timeout: 5 * time.Second} + req, err := http.NewRequest(http.MethodGet, server.URL, nil) + require.NoError(t, err) + + resp, err := DoRequestWithRetry(client, req) + + if tc.wantSuccess { + require.NoError(t, err) + require.NotNil(t, resp) + assert.Equal(t, http.StatusOK, resp.StatusCode) + resp.Body.Close() + } else { + require.NotNil(t, resp) + assert.Equal(t, http.StatusInternalServerError, resp.StatusCode) + resp.Body.Close() + } + + assert.Equal(t, tc.wantAttempts, attempts) + }) + } +} + +func TestDoRequestWithRetry_Delay(t *testing.T) { + retryDelayUnit = time.Millisecond + t.Cleanup(func() { retryDelayUnit = time.Second }) + + var start time.Time + delays := []time.Duration{} + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if len(delays) == 0 { + delays = append(delays, 0) + w.WriteHeader(http.StatusInternalServerError) + return + } + if len(delays) == 1 { + start = time.Now() + delays = append(delays, 0) + w.WriteHeader(http.StatusInternalServerError) + return + } + if len(delays) == 2 { + elapsed := time.Since(start) + delays = append(delays, elapsed) + w.WriteHeader(http.StatusOK) + w.Write([]byte("success")) + } + })) + defer server.Close() + + client := &http.Client{Timeout: 10 * time.Second} + req, err := http.NewRequest(http.MethodGet, server.URL, nil) + require.NoError(t, err) + + resp, err := DoRequestWithRetry(client, req) + require.NoError(t, err) + require.NotNil(t, resp) + assert.Equal(t, http.StatusOK, resp.StatusCode) + resp.Body.Close() + + assert.GreaterOrEqual(t, delays[2], time.Millisecond) +}