diff --git a/cmd/picoclaw/internal/auth/helpers.go b/cmd/picoclaw/internal/auth/helpers.go index 531cb76aa..41255a04e 100644 --- a/cmd/picoclaw/internal/auth/helpers.go +++ b/cmd/picoclaw/internal/auth/helpers.go @@ -21,20 +21,20 @@ const ( defaultAnthropicModel = "claude-sonnet-4.6" ) -func authLoginCmd(provider string, useDeviceCode bool, useOauth bool) error { +func authLoginCmd(provider string, useDeviceCode bool, useOauth bool, noBrowser bool) error { switch provider { case "openai": - return authLoginOpenAI(useDeviceCode) + return authLoginOpenAI(useDeviceCode, noBrowser) case "anthropic": return authLoginAnthropic(useOauth) case "google-antigravity", "antigravity": - return authLoginGoogleAntigravity() + return authLoginGoogleAntigravity(noBrowser) default: return fmt.Errorf("unsupported provider: %s (%s)", provider, supportedProvidersMsg) } } -func authLoginOpenAI(useDeviceCode bool) error { +func authLoginOpenAI(useDeviceCode bool, noBrowser bool) error { cfg := auth.OpenAIOAuthConfig() var cred *auth.AuthCredential @@ -43,7 +43,7 @@ func authLoginOpenAI(useDeviceCode bool) error { if useDeviceCode { cred, err = auth.LoginDeviceCode(cfg) } else { - cred, err = auth.LoginBrowser(cfg) + cred, err = auth.LoginBrowserWithOptions(cfg, auth.LoginBrowserOptions{NoBrowser: noBrowser}) } if err != nil { @@ -92,10 +92,10 @@ func authLoginOpenAI(useDeviceCode bool) error { return nil } -func authLoginGoogleAntigravity() error { +func authLoginGoogleAntigravity(noBrowser bool) error { cfg := auth.GoogleAntigravityOAuthConfig() - cred, err := auth.LoginBrowser(cfg) + cred, err := auth.LoginBrowserWithOptions(cfg, auth.LoginBrowserOptions{NoBrowser: noBrowser}) if err != nil { return fmt.Errorf("login failed: %w", err) } diff --git a/cmd/picoclaw/internal/auth/login.go b/cmd/picoclaw/internal/auth/login.go index afbe098aa..406144917 100644 --- a/cmd/picoclaw/internal/auth/login.go +++ b/cmd/picoclaw/internal/auth/login.go @@ -7,6 +7,7 @@ func newLoginCommand() *cobra.Command { provider string useDeviceCode bool useOauth bool + noBrowser bool ) cmd := &cobra.Command{ @@ -14,12 +15,15 @@ func newLoginCommand() *cobra.Command { Short: "Login via OAuth or paste token", Args: cobra.NoArgs, RunE: func(cmd *cobra.Command, _ []string) error { - return authLoginCmd(provider, useDeviceCode, useOauth) + return authLoginCmd(provider, useDeviceCode, useOauth, noBrowser) }, } - cmd.Flags().StringVarP(&provider, "provider", "p", "", "Provider to login with (openai, anthropic)") + cmd.Flags().StringVarP( + &provider, "provider", "p", "", "Provider to login with (openai, anthropic, google-antigravity)", + ) cmd.Flags().BoolVar(&useDeviceCode, "device-code", false, "Use device code flow (for headless environments)") + cmd.Flags().BoolVar(&noBrowser, "no-browser", false, "Do not auto-open a browser during OAuth login") cmd.Flags().BoolVar( &useOauth, "setup-token", false, "Use setup-token flow for Anthropic (from `claude setup-token`)", diff --git a/cmd/picoclaw/internal/auth/login_test.go b/cmd/picoclaw/internal/auth/login_test.go index d6a03c25b..5129d9aaf 100644 --- a/cmd/picoclaw/internal/auth/login_test.go +++ b/cmd/picoclaw/internal/auth/login_test.go @@ -18,6 +18,7 @@ func TestNewLoginSubCommand(t *testing.T) { assert.True(t, cmd.HasFlags()) assert.NotNil(t, cmd.Flags().Lookup("device-code")) + assert.NotNil(t, cmd.Flags().Lookup("no-browser")) providerFlag := cmd.Flags().Lookup("provider") require.NotNil(t, providerFlag) diff --git a/pkg/auth/oauth.go b/pkg/auth/oauth.go index 2bf719dd4..0db693475 100644 --- a/pkg/auth/oauth.go +++ b/pkg/auth/oauth.go @@ -30,6 +30,15 @@ type OAuthProviderConfig struct { Port int } +type LoginBrowserOptions struct { + NoBrowser bool +} + +var ( + openBrowserFunc = OpenBrowser + browserLoginInput io.Reader = os.Stdin +) + func OpenAIOAuthConfig() OAuthProviderConfig { return OAuthProviderConfig{ Issuer: "https://auth.openai.com", @@ -76,6 +85,10 @@ func GenerateState() (string, error) { } func LoginBrowser(cfg OAuthProviderConfig) (*AuthCredential, error) { + return LoginBrowserWithOptions(cfg, LoginBrowserOptions{}) +} + +func LoginBrowserWithOptions(cfg OAuthProviderConfig, opts LoginBrowserOptions) (*AuthCredential, error) { pkce, err := GeneratePKCE() if err != nil { return nil, fmt.Errorf("generating PKCE: %w", err) @@ -128,7 +141,9 @@ func LoginBrowser(cfg OAuthProviderConfig) (*AuthCredential, error) { fmt.Printf("Open this URL to authenticate:\n\n%s\n\n", authURL) - if err := OpenBrowser(authURL); err != nil { + if opts.NoBrowser { + fmt.Println("Browser auto-open disabled. Open the URL manually to continue.") + } else if err := openBrowserFunc(authURL); err != nil { fmt.Printf("Could not open browser automatically.\nPlease open this URL manually:\n\n%s\n\n", authURL) } @@ -144,7 +159,7 @@ func LoginBrowser(cfg OAuthProviderConfig) (*AuthCredential, error) { // Start manual input in a goroutine manualCh := make(chan string) go func() { - reader := bufio.NewReader(os.Stdin) + reader := bufio.NewReader(browserLoginInput) input, _ := reader.ReadString('\n') manualCh <- strings.TrimSpace(input) }() diff --git a/pkg/auth/oauth_test.go b/pkg/auth/oauth_test.go index 230ac7c2a..0bf0558e7 100644 --- a/pkg/auth/oauth_test.go +++ b/pkg/auth/oauth_test.go @@ -3,6 +3,7 @@ package auth import ( "encoding/base64" "encoding/json" + "net" "net/http" "net/http/httptest" "net/url" @@ -373,3 +374,117 @@ func TestParseDeviceCodeResponseInvalidInterval(t *testing.T) { t.Fatal("expected error for invalid interval") } } + +func TestLoginBrowserWithOptionsSkipsAutoOpenWhenDisabled(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/oauth/token" { + http.Error(w, "not found", http.StatusNotFound) + return + } + + resp := map[string]any{ + "access_token": "mock-access-token", + "refresh_token": "mock-refresh-token", + "expires_in": 3600, + } + _ = json.NewEncoder(w).Encode(resp) + })) + defer server.Close() + + origOpenBrowserFunc := openBrowserFunc + origBrowserLoginInput := browserLoginInput + t.Cleanup(func() { + openBrowserFunc = origOpenBrowserFunc + browserLoginInput = origBrowserLoginInput + }) + + var openCalls int + openBrowserFunc = func(string) error { + openCalls++ + return nil + } + browserLoginInput = strings.NewReader("manual-code\n") + + cfg := OAuthProviderConfig{ + Issuer: server.URL, + ClientID: "test-client", + Scopes: "openid", + Port: freeLocalPort(t), + } + + cred, err := LoginBrowserWithOptions(cfg, LoginBrowserOptions{NoBrowser: true}) + if err != nil { + t.Fatalf("LoginBrowserWithOptions() error: %v", err) + } + + if openCalls != 0 { + t.Fatalf("openBrowserFunc call count = %d, want 0", openCalls) + } + if cred.AccessToken != "mock-access-token" { + t.Fatalf("AccessToken = %q, want %q", cred.AccessToken, "mock-access-token") + } +} + +func TestLoginBrowserWithOptionsAutoOpensByDefault(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/oauth/token" { + http.Error(w, "not found", http.StatusNotFound) + return + } + + resp := map[string]any{ + "access_token": "mock-access-token", + "refresh_token": "mock-refresh-token", + "expires_in": 3600, + } + _ = json.NewEncoder(w).Encode(resp) + })) + defer server.Close() + + origOpenBrowserFunc := openBrowserFunc + origBrowserLoginInput := browserLoginInput + t.Cleanup(func() { + openBrowserFunc = origOpenBrowserFunc + browserLoginInput = origBrowserLoginInput + }) + + var openCalls int + openBrowserFunc = func(string) error { + openCalls++ + return nil + } + browserLoginInput = strings.NewReader("manual-code\n") + + cfg := OAuthProviderConfig{ + Issuer: server.URL, + ClientID: "test-client", + Scopes: "openid", + Port: freeLocalPort(t), + } + + _, err := LoginBrowserWithOptions(cfg, LoginBrowserOptions{}) + if err != nil { + t.Fatalf("LoginBrowserWithOptions() error: %v", err) + } + + if openCalls != 1 { + t.Fatalf("openBrowserFunc call count = %d, want 1", openCalls) + } +} + +func freeLocalPort(t *testing.T) int { + t.Helper() + + listener, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("net.Listen() error: %v", err) + } + defer listener.Close() + + addr, ok := listener.Addr().(*net.TCPAddr) + if !ok { + t.Fatalf("listener addr type = %T, want *net.TCPAddr", listener.Addr()) + } + + return addr.Port +}