diff --git a/cmd/picoclaw/internal/auth/status_test.go b/cmd/picoclaw/internal/auth/status_test.go index 7748ba502..2f9a70721 100644 --- a/cmd/picoclaw/internal/auth/status_test.go +++ b/cmd/picoclaw/internal/auth/status_test.go @@ -1,12 +1,53 @@ package auth import ( + "bytes" + "encoding/json" + "io" + "os" + "path/filepath" + "strings" "testing" + "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + + pkgauth "github.com/sipeed/picoclaw/pkg/auth" + "github.com/sipeed/picoclaw/pkg/config" ) +func captureAuthStdout(t *testing.T, fn func()) string { + t.Helper() + + oldStdout := os.Stdout + r, w, err := os.Pipe() + require.NoError(t, err) + os.Stdout = w + t.Cleanup(func() { + os.Stdout = oldStdout + }) + + fn() + + require.NoError(t, w.Close()) + os.Stdout = oldStdout + + var buf bytes.Buffer + _, err = io.Copy(&buf, r) + require.NoError(t, err) + require.NoError(t, r.Close()) + return buf.String() +} + +func setAuthStatusTestHome(t *testing.T) string { + t.Helper() + + tmpDir := t.TempDir() + t.Setenv(config.EnvHome, filepath.Join(tmpDir, ".picoclaw")) + return tmpDir +} + func TestNewStatusSubcommand(t *testing.T) { cmd := newStatusCommand() @@ -16,3 +57,47 @@ func TestNewStatusSubcommand(t *testing.T) { assert.False(t, cmd.HasFlags()) } + +func TestAuthStatusCmdShowsCanonicalGoogleAntigravityAfterLegacyRefresh(t *testing.T) { + tmpDir := setAuthStatusTestHome(t) + + legacyExpiry := time.Date(2026, 4, 16, 10, 0, 0, 0, time.UTC) + legacyStore := map[string]any{ + "credentials": map[string]any{ + "antigravity": map[string]any{ + "access_token": "legacy-token", + "expires_at": legacyExpiry.Format(time.RFC3339), + "provider": "antigravity", + "auth_method": "oauth", + "project_id": "legacy-project", + }, + }, + } + data, err := json.Marshal(legacyStore) + require.NoError(t, err) + + authPath := filepath.Join(tmpDir, ".picoclaw", "auth.json") + require.NoError(t, os.MkdirAll(filepath.Dir(authPath), 0o755)) + require.NoError(t, os.WriteFile(authPath, data, 0o600)) + + refreshedExpiry := time.Date(2026, 4, 16, 12, 30, 0, 0, time.UTC) + err = pkgauth.SetCredential("google-antigravity", &pkgauth.AuthCredential{ + AccessToken: "fresh-token", + ExpiresAt: refreshedExpiry, + Provider: "google-antigravity", + AuthMethod: "oauth", + ProjectID: "fresh-project", + }) + require.NoError(t, err) + + output := captureAuthStdout(t, func() { + require.NoError(t, authStatusCmd()) + }) + + assert.Contains(t, output, "\nAuthenticated Providers:") + assert.Contains(t, output, "\n google-antigravity:\n") + assert.NotContains(t, output, "\n antigravity:\n") + assert.Contains(t, output, " Project: fresh-project") + assert.Contains(t, output, " Expires: 2026-04-16 12:30") + assert.Equal(t, 1, strings.Count(output, ":\n Method: oauth")) +} diff --git a/pkg/auth/store.go b/pkg/auth/store.go index dfea11df4..0e6567a03 100644 --- a/pkg/auth/store.go +++ b/pkg/auth/store.go @@ -4,6 +4,7 @@ import ( "encoding/json" "os" "path/filepath" + "strings" "time" "github.com/sipeed/picoclaw/pkg/config" @@ -25,6 +26,11 @@ type AuthStore struct { Credentials map[string]*AuthCredential `json:"credentials"` } +const ( + providerGoogleAntigravity = "google-antigravity" + providerAntigravityAlias = "antigravity" +) + func (c *AuthCredential) IsExpired() bool { if c.ExpiresAt.IsZero() { return false @@ -43,6 +49,125 @@ func authFilePath() string { return filepath.Join(config.GetHome(), "auth.json") } +func canonicalProvider(provider string) string { + normalized := strings.ToLower(strings.TrimSpace(provider)) + switch normalized { + case providerAntigravityAlias: + return providerGoogleAntigravity + default: + return normalized + } +} + +func cloneCredential(cred *AuthCredential) *AuthCredential { + if cred == nil { + return nil + } + cp := *cred + return &cp +} + +func mergeCredentials(primary, secondary *AuthCredential) *AuthCredential { + if primary == nil { + return cloneCredential(secondary) + } + + merged := *primary + if secondary == nil { + return &merged + } + if merged.AccessToken == "" { + merged.AccessToken = secondary.AccessToken + } + if merged.RefreshToken == "" { + merged.RefreshToken = secondary.RefreshToken + } + if merged.AccountID == "" { + merged.AccountID = secondary.AccountID + } + if merged.ExpiresAt.IsZero() { + merged.ExpiresAt = secondary.ExpiresAt + } + if merged.Provider == "" { + merged.Provider = secondary.Provider + } + if merged.AuthMethod == "" { + merged.AuthMethod = secondary.AuthMethod + } + if merged.Email == "" { + merged.Email = secondary.Email + } + if merged.ProjectID == "" { + merged.ProjectID = secondary.ProjectID + } + + return &merged +} + +func shouldPreferCredential( + candidate *AuthCredential, + candidateCanonical bool, + current *AuthCredential, + currentCanonical bool, +) bool { + if candidate == nil { + return false + } + if current == nil { + return true + } + + switch { + case candidate.ExpiresAt.After(current.ExpiresAt): + return true + case current.ExpiresAt.After(candidate.ExpiresAt): + return false + case candidateCanonical != currentCanonical: + return candidateCanonical + default: + return false + } +} + +func normalizeStore(store *AuthStore) { + if store == nil { + return + } + if store.Credentials == nil { + store.Credentials = make(map[string]*AuthCredential) + return + } + + normalized := make(map[string]*AuthCredential, len(store.Credentials)) + canonicalFlags := make(map[string]bool, len(store.Credentials)) + + for provider, cred := range store.Credentials { + normalizedProvider := strings.ToLower(strings.TrimSpace(provider)) + canonical := canonicalProvider(provider) + normalizedCred := cloneCredential(cred) + if normalizedCred != nil { + normalizedCred.Provider = canonicalProvider(normalizedCred.Provider) + if normalizedCred.Provider == "" { + normalizedCred.Provider = canonical + } + } + + current := normalized[canonical] + currentCanonical := canonicalFlags[canonical] + candidateCanonical := normalizedProvider == canonical + + if shouldPreferCredential(normalizedCred, candidateCanonical, current, currentCanonical) { + normalized[canonical] = mergeCredentials(normalizedCred, current) + canonicalFlags[canonical] = candidateCanonical + continue + } + + normalized[canonical] = mergeCredentials(current, normalizedCred) + } + + store.Credentials = normalized +} + func LoadStore() (*AuthStore, error) { path := authFilePath() data, err := os.ReadFile(path) @@ -57,9 +182,7 @@ func LoadStore() (*AuthStore, error) { if err := json.Unmarshal(data, &store); err != nil { return nil, err } - if store.Credentials == nil { - store.Credentials = make(map[string]*AuthCredential) - } + normalizeStore(&store) return &store, nil } @@ -79,7 +202,7 @@ func GetCredential(provider string) (*AuthCredential, error) { if err != nil { return nil, err } - cred, ok := store.Credentials[provider] + cred, ok := store.Credentials[canonicalProvider(provider)] if !ok { return nil, nil } @@ -91,7 +214,17 @@ func SetCredential(provider string, cred *AuthCredential) error { if err != nil { return err } - store.Credentials[provider] = cred + + canonical := canonicalProvider(provider) + normalized := cloneCredential(cred) + if normalized != nil { + normalized.Provider = canonicalProvider(normalized.Provider) + if normalized.Provider == "" { + normalized.Provider = canonical + } + } + + store.Credentials[canonical] = normalized return SaveStore(store) } @@ -100,7 +233,7 @@ func DeleteCredential(provider string) error { if err != nil { return err } - delete(store.Credentials, provider) + delete(store.Credentials, canonicalProvider(provider)) return SaveStore(store) } diff --git a/pkg/auth/store_test.go b/pkg/auth/store_test.go index f6793cfce..578ed4ead 100644 --- a/pkg/auth/store_test.go +++ b/pkg/auth/store_test.go @@ -1,12 +1,24 @@ package auth import ( + "encoding/json" "os" "path/filepath" + "runtime" "testing" "time" + + "github.com/sipeed/picoclaw/pkg/config" ) +func setTestAuthHome(t *testing.T) string { + t.Helper() + + tmpDir := t.TempDir() + t.Setenv(config.EnvHome, filepath.Join(tmpDir, ".picoclaw")) + return tmpDir +} + func TestAuthCredentialIsExpired(t *testing.T) { tests := []struct { name string @@ -51,10 +63,7 @@ func TestAuthCredentialNeedsRefresh(t *testing.T) { } func TestStoreRoundtrip(t *testing.T) { - tmpDir := t.TempDir() - origHome := os.Getenv("HOME") - t.Setenv("HOME", tmpDir) - defer os.Setenv("HOME", origHome) + setTestAuthHome(t) cred := &AuthCredential{ AccessToken: "test-access-token", @@ -88,10 +97,7 @@ func TestStoreRoundtrip(t *testing.T) { } func TestStoreFilePermissions(t *testing.T) { - tmpDir := t.TempDir() - origHome := os.Getenv("HOME") - t.Setenv("HOME", tmpDir) - defer os.Setenv("HOME", origHome) + tmpDir := setTestAuthHome(t) cred := &AuthCredential{ AccessToken: "secret-token", @@ -108,16 +114,16 @@ func TestStoreFilePermissions(t *testing.T) { t.Fatalf("Stat() error: %v", err) } perm := info.Mode().Perm() + if runtime.GOOS == "windows" { + return + } if perm != 0o600 { t.Errorf("file permissions = %o, want 0600", perm) } } func TestStoreMultiProvider(t *testing.T) { - tmpDir := t.TempDir() - origHome := os.Getenv("HOME") - t.Setenv("HOME", tmpDir) - defer os.Setenv("HOME", origHome) + setTestAuthHome(t) openaiCred := &AuthCredential{AccessToken: "openai-token", Provider: "openai", AuthMethod: "oauth"} anthropicCred := &AuthCredential{AccessToken: "anthropic-token", Provider: "anthropic", AuthMethod: "token"} @@ -147,10 +153,7 @@ func TestStoreMultiProvider(t *testing.T) { } func TestDeleteCredential(t *testing.T) { - tmpDir := t.TempDir() - origHome := os.Getenv("HOME") - t.Setenv("HOME", tmpDir) - defer os.Setenv("HOME", origHome) + setTestAuthHome(t) cred := &AuthCredential{AccessToken: "to-delete", Provider: "openai", AuthMethod: "oauth"} if err := SetCredential("openai", cred); err != nil { @@ -171,10 +174,7 @@ func TestDeleteCredential(t *testing.T) { } func TestLoadStoreEmpty(t *testing.T) { - tmpDir := t.TempDir() - origHome := os.Getenv("HOME") - t.Setenv("HOME", tmpDir) - defer os.Setenv("HOME", origHome) + setTestAuthHome(t) store, err := LoadStore() if err != nil { @@ -187,3 +187,319 @@ func TestLoadStoreEmpty(t *testing.T) { t.Errorf("expected empty credentials, got %d", len(store.Credentials)) } } + +func TestGetCredentialCanonicalizesLegacyAntigravityProvider(t *testing.T) { + tmpDir := setTestAuthHome(t) + + expiresAt := time.Date(2026, 4, 16, 10, 0, 0, 0, time.UTC) + store := map[string]any{ + "credentials": map[string]any{ + "antigravity": map[string]any{ + "access_token": "legacy-token", + "expires_at": expiresAt.Format(time.RFC3339), + "provider": "antigravity", + "auth_method": "oauth", + "project_id": "project-1", + }, + }, + } + data, err := json.Marshal(store) + if err != nil { + t.Fatalf("json.Marshal() error: %v", err) + } + path := filepath.Join(tmpDir, ".picoclaw", "auth.json") + err = os.MkdirAll(filepath.Dir(path), 0o755) + if err != nil { + t.Fatalf("MkdirAll() error: %v", err) + } + err = os.WriteFile(path, data, 0o600) + if err != nil { + t.Fatalf("WriteFile() error: %v", err) + } + + cred, err := GetCredential("google-antigravity") + if err != nil { + t.Fatalf("GetCredential() error: %v", err) + } + if cred == nil { + t.Fatal("GetCredential() returned nil") + } + if cred.Provider != "google-antigravity" { + t.Fatalf("Provider = %q, want %q", cred.Provider, "google-antigravity") + } + if !cred.ExpiresAt.Equal(expiresAt) { + t.Fatalf("ExpiresAt = %v, want %v", cred.ExpiresAt, expiresAt) + } +} + +func TestLoadStoreMergesAntigravityAliasesPreferringNewerExpiry(t *testing.T) { + tmpDir := setTestAuthHome(t) + + legacyExpiry := time.Date(2026, 4, 16, 10, 0, 0, 0, time.UTC) + refreshedExpiry := time.Date(2026, 4, 16, 12, 0, 0, 0, time.UTC) + store := map[string]any{ + "credentials": map[string]any{ + "antigravity": map[string]any{ + "access_token": "legacy-token", + "refresh_token": "legacy-refresh", + "expires_at": legacyExpiry.Format(time.RFC3339), + "provider": "antigravity", + "auth_method": "oauth", + "email": "legacy@example.com", + }, + "google-antigravity": map[string]any{ + "access_token": "fresh-token", + "expires_at": refreshedExpiry.Format(time.RFC3339), + "provider": "google-antigravity", + "auth_method": "oauth", + "project_id": "project-2", + }, + }, + } + data, err := json.Marshal(store) + if err != nil { + t.Fatalf("json.Marshal() error: %v", err) + } + path := filepath.Join(tmpDir, ".picoclaw", "auth.json") + err = os.MkdirAll(filepath.Dir(path), 0o755) + if err != nil { + t.Fatalf("MkdirAll() error: %v", err) + } + err = os.WriteFile(path, data, 0o600) + if err != nil { + t.Fatalf("WriteFile() error: %v", err) + } + + loaded, err := LoadStore() + if err != nil { + t.Fatalf("LoadStore() error: %v", err) + } + if len(loaded.Credentials) != 1 { + t.Fatalf("credential count = %d, want 1", len(loaded.Credentials)) + } + + cred := loaded.Credentials["google-antigravity"] + if cred == nil { + t.Fatal("google-antigravity credential missing") + } + if cred.AccessToken != "fresh-token" { + t.Fatalf("AccessToken = %q, want %q", cred.AccessToken, "fresh-token") + } + if cred.RefreshToken != "legacy-refresh" { + t.Fatalf("RefreshToken = %q, want %q", cred.RefreshToken, "legacy-refresh") + } + if cred.Email != "legacy@example.com" { + t.Fatalf("Email = %q, want %q", cred.Email, "legacy@example.com") + } + if cred.ProjectID != "project-2" { + t.Fatalf("ProjectID = %q, want %q", cred.ProjectID, "project-2") + } + if !cred.ExpiresAt.Equal(refreshedExpiry) { + t.Fatalf("ExpiresAt = %v, want %v", cred.ExpiresAt, refreshedExpiry) + } +} + +func TestLoadStorePrefersCanonicalKeyWhenExpiryMatchesAlias(t *testing.T) { + tmpDir := setTestAuthHome(t) + + expiresAt := time.Date(2026, 4, 16, 12, 0, 0, 0, time.UTC) + store := map[string]any{ + "credentials": map[string]any{ + "antigravity": map[string]any{ + "access_token": "legacy-token", + "refresh_token": "legacy-refresh", + "expires_at": expiresAt.Format(time.RFC3339), + "provider": "antigravity", + "auth_method": "oauth", + "email": "legacy@example.com", + }, + " Google-Antigravity ": map[string]any{ + "access_token": "fresh-token", + "expires_at": expiresAt.Format(time.RFC3339), + "provider": " Google-Antigravity ", + "auth_method": "oauth", + "project_id": "project-2", + }, + }, + } + data, err := json.Marshal(store) + if err != nil { + t.Fatalf("json.Marshal() error: %v", err) + } + path := filepath.Join(tmpDir, ".picoclaw", "auth.json") + err = os.MkdirAll(filepath.Dir(path), 0o755) + if err != nil { + t.Fatalf("MkdirAll() error: %v", err) + } + err = os.WriteFile(path, data, 0o600) + if err != nil { + t.Fatalf("WriteFile() error: %v", err) + } + + loaded, err := LoadStore() + if err != nil { + t.Fatalf("LoadStore() error: %v", err) + } + if len(loaded.Credentials) != 1 { + t.Fatalf("credential count = %d, want 1", len(loaded.Credentials)) + } + + cred := loaded.Credentials["google-antigravity"] + if cred == nil { + t.Fatal("google-antigravity credential missing") + } + if cred.AccessToken != "fresh-token" { + t.Fatalf("AccessToken = %q, want %q", cred.AccessToken, "fresh-token") + } + if cred.RefreshToken != "legacy-refresh" { + t.Fatalf("RefreshToken = %q, want %q", cred.RefreshToken, "legacy-refresh") + } + if cred.Email != "legacy@example.com" { + t.Fatalf("Email = %q, want %q", cred.Email, "legacy@example.com") + } + if cred.ProjectID != "project-2" { + t.Fatalf("ProjectID = %q, want %q", cred.ProjectID, "project-2") + } +} + +func TestSetCredentialReplacesLegacyAntigravityEntry(t *testing.T) { + tmpDir := setTestAuthHome(t) + + legacyStore := map[string]any{ + "credentials": map[string]any{ + "antigravity": map[string]any{ + "access_token": "legacy-token", + "expires_at": time.Date(2026, 4, 16, 10, 0, 0, 0, time.UTC).Format(time.RFC3339), + "provider": "antigravity", + "auth_method": "oauth", + }, + }, + } + data, err := json.Marshal(legacyStore) + if err != nil { + t.Fatalf("json.Marshal() error: %v", err) + } + path := filepath.Join(tmpDir, ".picoclaw", "auth.json") + err = os.MkdirAll(filepath.Dir(path), 0o755) + if err != nil { + t.Fatalf("MkdirAll() error: %v", err) + } + err = os.WriteFile(path, data, 0o600) + if err != nil { + t.Fatalf("WriteFile() error: %v", err) + } + + refreshedExpiry := time.Date(2026, 4, 16, 12, 30, 0, 0, time.UTC) + err = SetCredential("google-antigravity", &AuthCredential{ + AccessToken: "fresh-token", + ExpiresAt: refreshedExpiry, + Provider: "google-antigravity", + AuthMethod: "oauth", + }) + if err != nil { + t.Fatalf("SetCredential() error: %v", err) + } + + loaded, err := LoadStore() + if err != nil { + t.Fatalf("LoadStore() error: %v", err) + } + if len(loaded.Credentials) != 1 { + t.Fatalf("credential count = %d, want 1", len(loaded.Credentials)) + } + + cred := loaded.Credentials["google-antigravity"] + if cred == nil { + t.Fatal("google-antigravity credential missing") + } + if cred.AccessToken != "fresh-token" { + t.Fatalf("AccessToken = %q, want %q", cred.AccessToken, "fresh-token") + } + if !cred.ExpiresAt.Equal(refreshedExpiry) { + t.Fatalf("ExpiresAt = %v, want %v", cred.ExpiresAt, refreshedExpiry) + } +} + +func TestDeleteCredentialRemovesLegacyAntigravityAlias(t *testing.T) { + tmpDir := setTestAuthHome(t) + + legacyStore := map[string]any{ + "credentials": map[string]any{ + "antigravity": map[string]any{ + "access_token": "legacy-token", + "provider": "antigravity", + "auth_method": "oauth", + }, + }, + } + data, err := json.Marshal(legacyStore) + if err != nil { + t.Fatalf("json.Marshal() error: %v", err) + } + path := filepath.Join(tmpDir, ".picoclaw", "auth.json") + err = os.MkdirAll(filepath.Dir(path), 0o755) + if err != nil { + t.Fatalf("MkdirAll() error: %v", err) + } + err = os.WriteFile(path, data, 0o600) + if err != nil { + t.Fatalf("WriteFile() error: %v", err) + } + + err = DeleteCredential(" google-antigravity ") + if err != nil { + t.Fatalf("DeleteCredential() error: %v", err) + } + + loaded, err := LoadStore() + if err != nil { + t.Fatalf("LoadStore() error: %v", err) + } + if len(loaded.Credentials) != 0 { + t.Fatalf("credential count = %d, want 0", len(loaded.Credentials)) + } +} + +func TestSetCredentialCanonicalizesTrimmedMixedCaseProvider(t *testing.T) { + setTestAuthHome(t) + + expiresAt := time.Date(2026, 4, 16, 13, 0, 0, 0, time.UTC) + if err := SetCredential(" AnTiGrAvItY ", &AuthCredential{ + AccessToken: "fresh-token", + ExpiresAt: expiresAt, + Provider: " AnTiGrAvItY ", + AuthMethod: "oauth", + }); err != nil { + t.Fatalf("SetCredential() error: %v", err) + } + + loaded, err := LoadStore() + if err != nil { + t.Fatalf("LoadStore() error: %v", err) + } + if len(loaded.Credentials) != 1 { + t.Fatalf("credential count = %d, want 1", len(loaded.Credentials)) + } + + cred := loaded.Credentials["google-antigravity"] + if cred == nil { + t.Fatal("google-antigravity credential missing") + } + if cred.Provider != "google-antigravity" { + t.Fatalf("Provider = %q, want %q", cred.Provider, "google-antigravity") + } + if !cred.ExpiresAt.Equal(expiresAt) { + t.Fatalf("ExpiresAt = %v, want %v", cred.ExpiresAt, expiresAt) + } + + got, err := GetCredential(" GoOgLe-AnTiGrAvItY ") + if err != nil { + t.Fatalf("GetCredential() error: %v", err) + } + if got == nil { + t.Fatal("GetCredential() returned nil") + } + if got.Provider != "google-antigravity" { + t.Fatalf("GetCredential provider = %q, want %q", got.Provider, "google-antigravity") + } +}