mirror of
https://github.com/sipeed/picoclaw.git
synced 2026-06-12 18:08:54 +00:00
feat(credential): part1 add AES-GCM encryption, SecureStore, and onboard ke… (#1521)
* feat(credential): add AES-GCM encryption, SecureStore, and onboard keygen - pkg/credential: new package with AES-256-GCM enc:// credential format, HKDF-SHA256 key derivation (passphrase + optional SSH key binding), ErrPassphraseRequired / ErrDecryptionFailed sentinel errors, and PassphraseProvider hook for runtime passphrase injection - pkg/credential/store: lock-free SecureStore via atomic.Pointer[string]; passphrase never written to disk or os.Environ - pkg/credential/keygen: ed25519 SSH key generation helper used by onboard - pkg/config: replace os.Getenv(PassphraseEnvVar) with credential.PassphraseProvider() at all three call sites so that LoadConfig and SaveConfig use whatever passphrase source is active - cmd/picoclaw/onboard: prompt for passphrase with echo-off, generate picoclaw-specific SSH key, re-encrypt existing config on re-onboard - docs/credential_encryption.md: design doc for the enc:// format * fix(credential): address Copilot review comments on PR #1521 - credential.go: decouple ErrPassphraseRequired from env var name; message is now 'enc:// passphrase required' since PassphraseProvider may come from any source, not just os.Environ - credential.go: Resolver resolves symlinks via EvalSymlinks before the isWithinDir containment check, preventing symlink-based path traversal for file:// credential references - store.go: tighten comment to describe only what SecureStore guarantees (in-memory only); remove claims about how callers transport the value - store_test.go: replace the meaningless GetReturnsCopy test (Go strings are immutable, equality across two calls proves nothing) with TestSecureStore_ConcurrentSetGet that exercises atomic.Pointer under 10-goroutine concurrent Set/Get load - config_test.go: update error-message assertion to match new sentinel text - docs/credential_encryption.md: remove reference to non-existent 'picoclaw encrypt' subcommand; describe the onboard flow instead * fix(config): encryptPlaintextAPIKeys: struct-based encryption, fail-fast, remove raw []byte * fix(credential): require SSH private key for encryption/decryption, remove passphrase-only mode * lint: fix credential keygen lint, fix test keygen * onboard: make encryption opt-in via --enc flag Encryption (passphrase prompt + SSH key generation) is now only triggered when the user passes --enc to 'picoclaw onboard'. Without the flag, onboard skips the credential-encryption setup and writes a plain config + workspace templates directly. - Add --enc BoolFlag in NewOnboardCommand() - Pass encrypt bool into onboard() - Guard passphrase prompt, SSH key generation, and related env-var setup behind the encrypt branch - Adjust 'Next steps' output so the passphrase reminder only appears when --enc was used
This commit is contained in:
+70
-2
@@ -4,11 +4,13 @@ import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/caarlos0/env/v11"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/credential"
|
||||
"github.com/sipeed/picoclaw/pkg/fileutil"
|
||||
)
|
||||
|
||||
@@ -837,10 +839,24 @@ func LoadConfig(path string) (*Config, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if passphrase := credential.PassphraseProvider(); passphrase != "" {
|
||||
for _, m := range cfg.ModelList {
|
||||
if m.APIKey != "" && !strings.HasPrefix(m.APIKey, "enc://") && !strings.HasPrefix(m.APIKey, "file://") {
|
||||
fmt.Fprintf(os.Stderr,
|
||||
"picoclaw: warning: model %q has a plaintext api_key; call SaveConfig to encrypt it\n",
|
||||
m.ModelName)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if err := env.Parse(cfg); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := resolveAPIKeys(cfg.ModelList, filepath.Dir(path)); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Migrate legacy channel config fields to new unified structures
|
||||
cfg.migrateChannelConfigs()
|
||||
|
||||
@@ -857,6 +873,48 @@ func LoadConfig(path string) (*Config, error) {
|
||||
return cfg, nil
|
||||
}
|
||||
|
||||
// encryptPlaintextAPIKeys returns a copy of models with plaintext api_key values
|
||||
// encrypted. Returns (nil, nil) when nothing changed (all keys already sealed or
|
||||
// empty). Returns (nil, error) if any key fails to encrypt — callers must treat
|
||||
// this as a hard failure to prevent a mixed plaintext/ciphertext state on disk.
|
||||
// Symmetric counterpart of resolveAPIKeys: both operate purely on []ModelConfig
|
||||
// and leave JSON marshaling to the caller.
|
||||
func encryptPlaintextAPIKeys(models []ModelConfig, passphrase string) ([]ModelConfig, error) {
|
||||
sealed := make([]ModelConfig, len(models))
|
||||
copy(sealed, models)
|
||||
changed := false
|
||||
for i := range sealed {
|
||||
m := &sealed[i]
|
||||
if m.APIKey == "" || strings.HasPrefix(m.APIKey, "enc://") || strings.HasPrefix(m.APIKey, "file://") {
|
||||
continue
|
||||
}
|
||||
encrypted, err := credential.Encrypt(passphrase, "", m.APIKey)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("cannot seal api_key for model %q: %w", m.ModelName, err)
|
||||
}
|
||||
m.APIKey = encrypted
|
||||
changed = true
|
||||
}
|
||||
if !changed {
|
||||
return nil, nil
|
||||
}
|
||||
return sealed, nil
|
||||
}
|
||||
|
||||
// resolveAPIKeys decrypts or dereferences each api_key in models in-place.
|
||||
// Supports plaintext (no-op), file:// (read from configDir), and enc:// (AES-GCM decrypt).
|
||||
func resolveAPIKeys(models []ModelConfig, configDir string) error {
|
||||
cr := credential.NewResolver(configDir)
|
||||
for i := range models {
|
||||
resolved, err := cr.Resolve(models[i].APIKey)
|
||||
if err != nil {
|
||||
return fmt.Errorf("model_list[%d] (%s): %w", i, models[i].ModelName, err)
|
||||
}
|
||||
models[i].APIKey = resolved
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Config) migrateChannelConfigs() {
|
||||
// Discord: mention_only -> group_trigger.mention_only
|
||||
if c.Channels.Discord.MentionOnly && !c.Channels.Discord.GroupTrigger.MentionOnly {
|
||||
@@ -871,12 +929,22 @@ func (c *Config) migrateChannelConfigs() {
|
||||
}
|
||||
|
||||
func SaveConfig(path string, cfg *Config) error {
|
||||
if passphrase := credential.PassphraseProvider(); passphrase != "" {
|
||||
sealed, err := encryptPlaintextAPIKeys(cfg.ModelList, passphrase)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if sealed != nil {
|
||||
tmp := *cfg
|
||||
tmp.ModelList = sealed
|
||||
cfg = &tmp
|
||||
}
|
||||
}
|
||||
|
||||
data, err := json.MarshalIndent(cfg, "", " ")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Use unified atomic write utility with explicit sync for flash storage reliability.
|
||||
return fileutil.WriteFileAtomic(path, data, 0o600)
|
||||
}
|
||||
|
||||
|
||||
+360
-5
@@ -7,8 +7,22 @@ import (
|
||||
"runtime"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/credential"
|
||||
)
|
||||
|
||||
// mustSetupSSHKey generates a temporary Ed25519 SSH key in t.TempDir() and sets
|
||||
// PICOCLAW_SSH_KEY_PATH to its path for the duration of the test. This is required
|
||||
// whenever a test exercises encryption/decryption via credential.Encrypt or SaveConfig.
|
||||
func mustSetupSSHKey(t *testing.T) {
|
||||
t.Helper()
|
||||
keyPath := filepath.Join(t.TempDir(), "picoclaw_ed25519.key")
|
||||
if err := credential.GenerateSSHKey(keyPath); err != nil {
|
||||
t.Fatalf("mustSetupSSHKey: %v", err)
|
||||
}
|
||||
t.Setenv("PICOCLAW_SSH_KEY_PATH", keyPath)
|
||||
}
|
||||
|
||||
func TestAgentModelConfig_UnmarshalString(t *testing.T) {
|
||||
var m AgentModelConfig
|
||||
if err := json.Unmarshal([]byte(`"gpt-4"`), &m); err != nil {
|
||||
@@ -482,13 +496,19 @@ func TestDefaultConfig_DMScope(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestDefaultConfig_WorkspacePath_Default(t *testing.T) {
|
||||
// Unset to ensure we test the default
|
||||
t.Setenv("PICOCLAW_HOME", "")
|
||||
// Set a known home for consistent test results
|
||||
t.Setenv("HOME", "/tmp/home")
|
||||
|
||||
var fakeHome string
|
||||
if runtime.GOOS == "windows" {
|
||||
fakeHome = `C:\tmp\home`
|
||||
t.Setenv("USERPROFILE", fakeHome)
|
||||
} else {
|
||||
fakeHome = "/tmp/home"
|
||||
t.Setenv("HOME", fakeHome)
|
||||
}
|
||||
|
||||
cfg := DefaultConfig()
|
||||
want := filepath.Join("/tmp/home", ".picoclaw", "workspace")
|
||||
want := filepath.Join(fakeHome, ".picoclaw", "workspace")
|
||||
|
||||
if cfg.Agents.Defaults.Workspace != want {
|
||||
t.Errorf("Default workspace path = %q, want %q", cfg.Agents.Defaults.Workspace, want)
|
||||
@@ -499,7 +519,7 @@ func TestDefaultConfig_WorkspacePath_WithPicoclawHome(t *testing.T) {
|
||||
t.Setenv("PICOCLAW_HOME", "/custom/picoclaw/home")
|
||||
|
||||
cfg := DefaultConfig()
|
||||
want := "/custom/picoclaw/home/workspace"
|
||||
want := filepath.Join("/custom/picoclaw/home", "workspace")
|
||||
|
||||
if cfg.Agents.Defaults.Workspace != want {
|
||||
t.Errorf("Workspace path with PICOCLAW_HOME = %q, want %q", cfg.Agents.Defaults.Workspace, want)
|
||||
@@ -621,3 +641,338 @@ func TestFlexibleStringSlice_UnmarshalText_EmptySliceConsistency(t *testing.T) {
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// TestLoadConfig_WarnsForPlaintextAPIKey verifies that LoadConfig resolves a plaintext
|
||||
// api_key into memory but does NOT rewrite the config file. File writes are the sole
|
||||
// responsibility of SaveConfig.
|
||||
func TestLoadConfig_WarnsForPlaintextAPIKey(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
cfgPath := filepath.Join(dir, "config.json")
|
||||
const original = `{"model_list":[{"model_name":"test","model":"openai/gpt-4","api_key":"sk-plaintext"}]}`
|
||||
if err := os.WriteFile(cfgPath, []byte(original), 0o600); err != nil {
|
||||
t.Fatalf("setup: %v", err)
|
||||
}
|
||||
|
||||
t.Setenv("PICOCLAW_KEY_PASSPHRASE", "test-passphrase")
|
||||
t.Setenv("PICOCLAW_SSH_KEY_PATH", "")
|
||||
|
||||
cfg, err := LoadConfig(cfgPath)
|
||||
if err != nil {
|
||||
t.Fatalf("LoadConfig: %v", err)
|
||||
}
|
||||
// In-memory value must be the resolved plaintext.
|
||||
if cfg.ModelList[0].APIKey != "sk-plaintext" {
|
||||
t.Errorf("in-memory api_key = %q, want %q", cfg.ModelList[0].APIKey, "sk-plaintext")
|
||||
}
|
||||
// The file on disk must remain unchanged — LoadConfig must not write anything.
|
||||
raw, _ := os.ReadFile(cfgPath)
|
||||
if string(raw) != original {
|
||||
t.Errorf("LoadConfig must not modify the config file; got:\n%s", string(raw))
|
||||
}
|
||||
}
|
||||
|
||||
// TestSaveConfig_EncryptsPlaintextAPIKey verifies that SaveConfig writes enc:// ciphertext
|
||||
// to disk and that a subsequent LoadConfig decrypts it back to the original plaintext.
|
||||
func TestSaveConfig_EncryptsPlaintextAPIKey(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
cfgPath := filepath.Join(dir, "config.json")
|
||||
|
||||
t.Setenv("PICOCLAW_KEY_PASSPHRASE", "test-passphrase")
|
||||
mustSetupSSHKey(t)
|
||||
|
||||
cfg := DefaultConfig()
|
||||
cfg.ModelList = []ModelConfig{
|
||||
{ModelName: "test", Model: "openai/gpt-4", APIKey: "sk-plaintext"},
|
||||
}
|
||||
if err := SaveConfig(cfgPath, cfg); err != nil {
|
||||
t.Fatalf("SaveConfig: %v", err)
|
||||
}
|
||||
|
||||
// Disk must contain enc://, not the raw key.
|
||||
raw, _ := os.ReadFile(cfgPath)
|
||||
if !strings.Contains(string(raw), "enc://") {
|
||||
t.Errorf("saved file should contain enc://, got:\n%s", string(raw))
|
||||
}
|
||||
if strings.Contains(string(raw), "sk-plaintext") {
|
||||
t.Errorf("saved file must not contain the plaintext key")
|
||||
}
|
||||
|
||||
// A fresh load must decrypt back to the original plaintext.
|
||||
cfg2, err := LoadConfig(cfgPath)
|
||||
if err != nil {
|
||||
t.Fatalf("LoadConfig after SaveConfig: %v", err)
|
||||
}
|
||||
if cfg2.ModelList[0].APIKey != "sk-plaintext" {
|
||||
t.Errorf("loaded api_key = %q, want %q", cfg2.ModelList[0].APIKey, "sk-plaintext")
|
||||
}
|
||||
}
|
||||
|
||||
// TestLoadConfig_NoSealWithoutPassphrase verifies that api_key values are left
|
||||
// unchanged when PICOCLAW_KEY_PASSPHRASE is not set.
|
||||
func TestLoadConfig_NoSealWithoutPassphrase(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
cfgPath := filepath.Join(dir, "config.json")
|
||||
data := `{"model_list":[{"model_name":"test","model":"openai/gpt-4","api_key":"sk-plaintext"}]}`
|
||||
if err := os.WriteFile(cfgPath, []byte(data), 0o600); err != nil {
|
||||
t.Fatalf("setup: %v", err)
|
||||
}
|
||||
|
||||
t.Setenv("PICOCLAW_KEY_PASSPHRASE", "")
|
||||
t.Setenv("PICOCLAW_SSH_KEY_PATH", "")
|
||||
|
||||
if _, err := LoadConfig(cfgPath); err != nil {
|
||||
t.Fatalf("LoadConfig: %v", err)
|
||||
}
|
||||
|
||||
raw, _ := os.ReadFile(cfgPath)
|
||||
if strings.Contains(string(raw), "enc://") {
|
||||
t.Error("config file must not be modified when no passphrase is set")
|
||||
}
|
||||
}
|
||||
|
||||
// TestLoadConfig_FileRefNotSealed verifies that file:// api_key references are not
|
||||
// converted to enc:// values (they are resolved at runtime by the Resolver).
|
||||
func TestLoadConfig_FileRefNotSealed(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
cfgPath := filepath.Join(dir, "config.json")
|
||||
keyFile := filepath.Join(dir, "openai.key")
|
||||
if err := os.WriteFile(keyFile, []byte("sk-from-file"), 0o600); err != nil {
|
||||
t.Fatalf("setup: %v", err)
|
||||
}
|
||||
data := `{"model_list":[{"model_name":"test","model":"openai/gpt-4","api_key":"file://openai.key"}]}`
|
||||
if err := os.WriteFile(cfgPath, []byte(data), 0o600); err != nil {
|
||||
t.Fatalf("setup: %v", err)
|
||||
}
|
||||
|
||||
t.Setenv("PICOCLAW_KEY_PASSPHRASE", "test-passphrase")
|
||||
t.Setenv("PICOCLAW_SSH_KEY_PATH", "")
|
||||
|
||||
if _, err := LoadConfig(cfgPath); err != nil {
|
||||
t.Fatalf("LoadConfig: %v", err)
|
||||
}
|
||||
|
||||
raw, _ := os.ReadFile(cfgPath)
|
||||
if !strings.Contains(string(raw), "file://openai.key") {
|
||||
t.Error("file:// reference should be preserved unchanged in the config file")
|
||||
}
|
||||
if strings.Contains(string(raw), "enc://") {
|
||||
t.Error("file:// reference must not be converted to enc://")
|
||||
}
|
||||
}
|
||||
|
||||
// TestSaveConfig_MixedKeys verifies that SaveConfig encrypts only plaintext api_keys
|
||||
// and leaves already-encrypted (enc://) and file:// entries unchanged.
|
||||
func TestSaveConfig_MixedKeys(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
cfgPath := filepath.Join(dir, "config.json")
|
||||
|
||||
t.Setenv("PICOCLAW_KEY_PASSPHRASE", "test-passphrase")
|
||||
mustSetupSSHKey(t)
|
||||
|
||||
// Pre-encrypt one key so we have a genuine enc:// value to put in the config.
|
||||
if err := SaveConfig(cfgPath, &Config{
|
||||
ModelList: []ModelConfig{
|
||||
{ModelName: "pre", Model: "openai/gpt-4", APIKey: "sk-already-plain"},
|
||||
},
|
||||
}); err != nil {
|
||||
t.Fatalf("setup SaveConfig: %v", err)
|
||||
}
|
||||
raw, _ := os.ReadFile(cfgPath)
|
||||
// Extract the enc:// value from the saved file.
|
||||
var tmp struct {
|
||||
ModelList []struct {
|
||||
APIKey string `json:"api_key"`
|
||||
} `json:"model_list"`
|
||||
}
|
||||
if err := json.Unmarshal(raw, &tmp); err != nil || len(tmp.ModelList) == 0 {
|
||||
t.Fatalf("setup: could not parse saved config: %v", err)
|
||||
}
|
||||
alreadyEncrypted := tmp.ModelList[0].APIKey
|
||||
if !strings.HasPrefix(alreadyEncrypted, "enc://") {
|
||||
t.Fatalf("setup: expected enc:// key, got %q", alreadyEncrypted)
|
||||
}
|
||||
|
||||
// Build a config with three models:
|
||||
// 1. plaintext → must be encrypted by SaveConfig
|
||||
// 2. enc:// → must be left unchanged (already encrypted)
|
||||
// 3. file:// → must be left unchanged (file reference)
|
||||
keyFile := filepath.Join(dir, "api.key")
|
||||
if err := os.WriteFile(keyFile, []byte("sk-from-file"), 0o600); err != nil {
|
||||
t.Fatalf("setup: %v", err)
|
||||
}
|
||||
cfg := &Config{
|
||||
ModelList: []ModelConfig{
|
||||
{ModelName: "plain", Model: "openai/gpt-4", APIKey: "sk-new-plaintext"},
|
||||
{ModelName: "enc", Model: "openai/gpt-4", APIKey: alreadyEncrypted},
|
||||
{ModelName: "file", Model: "openai/gpt-4", APIKey: "file://api.key"},
|
||||
},
|
||||
}
|
||||
if err := SaveConfig(cfgPath, cfg); err != nil {
|
||||
t.Fatalf("SaveConfig: %v", err)
|
||||
}
|
||||
|
||||
raw, _ = os.ReadFile(cfgPath)
|
||||
s := string(raw)
|
||||
|
||||
// 1. Plaintext must be encrypted.
|
||||
if strings.Contains(s, "sk-new-plaintext") {
|
||||
t.Error("plaintext key must not appear in saved file")
|
||||
}
|
||||
// 2. The pre-existing enc:// value must still be present (byte-for-byte unchanged).
|
||||
if !strings.Contains(s, alreadyEncrypted) {
|
||||
t.Error("pre-existing enc:// entry must be preserved unchanged")
|
||||
}
|
||||
// 3. file:// must be preserved.
|
||||
if !strings.Contains(s, "file://api.key") {
|
||||
t.Error("file:// reference must be preserved unchanged")
|
||||
}
|
||||
|
||||
// Now load and verify all three decrypt/resolve correctly.
|
||||
cfg2, err := LoadConfig(cfgPath)
|
||||
if err != nil {
|
||||
t.Fatalf("LoadConfig after SaveConfig: %v", err)
|
||||
}
|
||||
byName := make(map[string]string)
|
||||
for _, m := range cfg2.ModelList {
|
||||
byName[m.ModelName] = m.APIKey
|
||||
}
|
||||
if byName["plain"] != "sk-new-plaintext" {
|
||||
t.Errorf("plain model api_key = %q, want %q", byName["plain"], "sk-new-plaintext")
|
||||
}
|
||||
if byName["enc"] != "sk-already-plain" {
|
||||
t.Errorf("enc model api_key = %q, want %q", byName["enc"], "sk-already-plain")
|
||||
}
|
||||
if byName["file"] != "sk-from-file" {
|
||||
t.Errorf("file model api_key = %q, want %q", byName["file"], "sk-from-file")
|
||||
}
|
||||
}
|
||||
|
||||
// TestLoadConfig_MixedKeys_NoPassphrase verifies that when PICOCLAW_KEY_PASSPHRASE
|
||||
// is not set, enc:// entries cause LoadConfig to return an error, while plaintext
|
||||
// and file:// entries in the same config are not affected.
|
||||
func TestLoadConfig_MixedKeys_NoPassphrase(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
cfgPath := filepath.Join(dir, "config.json")
|
||||
|
||||
// First encrypt a key so we have a real enc:// value.
|
||||
t.Setenv("PICOCLAW_KEY_PASSPHRASE", "test-passphrase")
|
||||
mustSetupSSHKey(t)
|
||||
if err := SaveConfig(cfgPath, &Config{
|
||||
ModelList: []ModelConfig{
|
||||
{ModelName: "m", Model: "openai/gpt-4", APIKey: "sk-secret"},
|
||||
},
|
||||
}); err != nil {
|
||||
t.Fatalf("setup SaveConfig: %v", err)
|
||||
}
|
||||
raw, _ := os.ReadFile(cfgPath)
|
||||
var tmp struct {
|
||||
ModelList []struct {
|
||||
APIKey string `json:"api_key"`
|
||||
} `json:"model_list"`
|
||||
}
|
||||
if err := json.Unmarshal(raw, &tmp); err != nil {
|
||||
t.Fatalf("setup parse: %v", err)
|
||||
}
|
||||
encValue := tmp.ModelList[0].APIKey
|
||||
|
||||
// Write a mixed config: enc:// + plaintext + file://
|
||||
keyFile := filepath.Join(dir, "api.key")
|
||||
if err := os.WriteFile(keyFile, []byte("sk-from-file"), 0o600); err != nil {
|
||||
t.Fatalf("setup: %v", err)
|
||||
}
|
||||
mixed, _ := json.Marshal(map[string]any{
|
||||
"model_list": []map[string]any{
|
||||
{"model_name": "enc", "model": "openai/gpt-4", "api_key": encValue},
|
||||
{"model_name": "plain", "model": "openai/gpt-4", "api_key": "sk-plain"},
|
||||
{"model_name": "file", "model": "openai/gpt-4", "api_key": "file://api.key"},
|
||||
},
|
||||
})
|
||||
if err := os.WriteFile(cfgPath, mixed, 0o600); err != nil {
|
||||
t.Fatalf("setup write: %v", err)
|
||||
}
|
||||
|
||||
// Now clear the passphrase — LoadConfig must fail because enc:// cannot be decrypted.
|
||||
t.Setenv("PICOCLAW_KEY_PASSPHRASE", "")
|
||||
|
||||
_, err := LoadConfig(cfgPath)
|
||||
if err == nil {
|
||||
t.Fatal("LoadConfig should fail when enc:// key is present and no passphrase is set")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "passphrase required") {
|
||||
t.Errorf("error should mention passphrase required, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestSaveConfig_UsesPassphraseProvider verifies that SaveConfig encrypts plaintext
|
||||
// api_keys using credential.PassphraseProvider() rather than os.Getenv directly.
|
||||
// This matters for the launcher, which clears the environment variable and redirects
|
||||
// PassphraseProvider to an in-memory SecureStore.
|
||||
func TestSaveConfig_UsesPassphraseProvider(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
cfgPath := filepath.Join(dir, "config.json")
|
||||
|
||||
// Ensure the env var is empty — passphrase must come from PassphraseProvider only.
|
||||
t.Setenv("PICOCLAW_KEY_PASSPHRASE", "")
|
||||
mustSetupSSHKey(t)
|
||||
|
||||
// Replace PassphraseProvider with an in-memory function (simulating SecureStore).
|
||||
const testPassphrase = "provider-passphrase"
|
||||
orig := credential.PassphraseProvider
|
||||
credential.PassphraseProvider = func() string { return testPassphrase }
|
||||
t.Cleanup(func() { credential.PassphraseProvider = orig })
|
||||
|
||||
cfg := DefaultConfig()
|
||||
cfg.ModelList = []ModelConfig{
|
||||
{ModelName: "test", Model: "openai/gpt-4", APIKey: "sk-plaintext"},
|
||||
}
|
||||
if err := SaveConfig(cfgPath, cfg); err != nil {
|
||||
t.Fatalf("SaveConfig: %v", err)
|
||||
}
|
||||
|
||||
raw, _ := os.ReadFile(cfgPath)
|
||||
if !strings.Contains(string(raw), "enc://") {
|
||||
t.Errorf("SaveConfig should have encrypted plaintext key via PassphraseProvider; got:\n%s", raw)
|
||||
}
|
||||
}
|
||||
|
||||
// TestLoadConfig_UsesPassphraseProvider verifies that LoadConfig decrypts enc:// keys
|
||||
// using credential.PassphraseProvider() rather than os.Getenv directly.
|
||||
func TestLoadConfig_UsesPassphraseProvider(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
cfgPath := filepath.Join(dir, "config.json")
|
||||
|
||||
// Ensure the env var is empty throughout.
|
||||
t.Setenv("PICOCLAW_KEY_PASSPHRASE", "")
|
||||
mustSetupSSHKey(t)
|
||||
|
||||
const testPassphrase = "provider-passphrase"
|
||||
const plainKey = "sk-secret"
|
||||
|
||||
// First, encrypt the key using the same passphrase.
|
||||
encrypted, err := credential.Encrypt(testPassphrase, "", plainKey)
|
||||
if err != nil {
|
||||
t.Fatalf("Encrypt: %v", err)
|
||||
}
|
||||
|
||||
raw, _ := json.Marshal(map[string]any{
|
||||
"model_list": []map[string]any{
|
||||
{"model_name": "test", "model": "openai/gpt-4", "api_key": encrypted},
|
||||
},
|
||||
})
|
||||
if err = os.WriteFile(cfgPath, raw, 0o600); err != nil {
|
||||
t.Fatalf("setup: %v", err)
|
||||
}
|
||||
|
||||
// Redirect PassphraseProvider — env var is empty, so without this the load would fail.
|
||||
orig := credential.PassphraseProvider
|
||||
credential.PassphraseProvider = func() string { return testPassphrase }
|
||||
t.Cleanup(func() { credential.PassphraseProvider = orig })
|
||||
|
||||
cfg, err := LoadConfig(cfgPath)
|
||||
if err != nil {
|
||||
t.Fatalf("LoadConfig: %v", err)
|
||||
}
|
||||
if cfg.ModelList[0].APIKey != plainKey {
|
||||
t.Errorf("api_key = %q, want %q", cfg.ModelList[0].APIKey, plainKey)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,335 @@
|
||||
// Package credential resolves API credential values for model_list entries.
|
||||
//
|
||||
// An API key is a form of authorization credential. This package centralizes
|
||||
// how raw credential strings—plaintext or file references—are resolved into
|
||||
// their actual values, keeping that logic out of the config loader.
|
||||
//
|
||||
// Supported formats for the api_key field:
|
||||
//
|
||||
// - Plaintext: "sk-abc123" → returned as-is
|
||||
// - File ref: "file://filename.key" → content read from configDir/filename.key
|
||||
// - Encrypted: "enc://<base64>" → AES-256-GCM decrypt via PICOCLAW_KEY_PASSPHRASE
|
||||
// - Empty: "" → returned as-is (auth_method=oauth etc.)
|
||||
//
|
||||
// Encryption uses AES-256-GCM with HKDF-SHA256 key derivation (< 1ms, safe for embedded Linux).
|
||||
// An SSH private key is required for both encryption and decryption.
|
||||
// Key derivation:
|
||||
//
|
||||
// HKDF-SHA256(ikm=HMAC-SHA256(SHA256(sshKeyBytes), passphrase), salt, info)
|
||||
//
|
||||
// SSH key path resolution priority:
|
||||
//
|
||||
// 1. sshKeyPath argument to Encrypt (explicit)
|
||||
// 2. PICOCLAW_SSH_KEY_PATH env var
|
||||
// 3. ~/.ssh/picoclaw_ed25519.key (os.UserHomeDir is cross-platform)
|
||||
package credential
|
||||
|
||||
import (
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
"crypto/hkdf"
|
||||
"crypto/hmac"
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// PassphraseEnvVar is the environment variable that holds the encryption passphrase.
|
||||
// Other packages (e.g. config) reference this constant to avoid duplicating the string.
|
||||
const PassphraseEnvVar = "PICOCLAW_KEY_PASSPHRASE"
|
||||
|
||||
// PassphraseProvider is the function used to retrieve the passphrase for enc://
|
||||
// credential decryption. It defaults to reading PICOCLAW_KEY_PASSPHRASE from the
|
||||
// process environment. Replace it at startup to use a different source, such as
|
||||
// an in-memory SecureStore, so that all LoadConfig() calls everywhere share the
|
||||
// same passphrase source without needing os.Environ.
|
||||
//
|
||||
// Example (launcher main.go):
|
||||
//
|
||||
// credential.PassphraseProvider = apiHandler.passphraseStore.Get
|
||||
var PassphraseProvider func() string = func() string {
|
||||
return os.Getenv(PassphraseEnvVar)
|
||||
}
|
||||
|
||||
// ErrPassphraseRequired is returned when an enc:// credential is encountered but
|
||||
// no passphrase is available from PassphraseProvider. Callers can detect this
|
||||
// with errors.Is to distinguish a missing-passphrase condition from other errors.
|
||||
var ErrPassphraseRequired = errors.New("credential: enc:// passphrase required")
|
||||
|
||||
// ErrDecryptionFailed is returned when an enc:// credential cannot be decrypted,
|
||||
// indicating a wrong passphrase or SSH key. Callers can detect this with errors.Is.
|
||||
var ErrDecryptionFailed = errors.New("credential: enc:// decryption failed (wrong passphrase or SSH key?)")
|
||||
|
||||
const (
|
||||
fileScheme = "file://"
|
||||
encScheme = "enc://"
|
||||
hkdfInfo = "picoclaw-credential-v1"
|
||||
saltLen = 16
|
||||
nonceLen = 12
|
||||
keyLen = 32
|
||||
sshKeyEnv = "PICOCLAW_SSH_KEY_PATH"
|
||||
)
|
||||
|
||||
// Resolver resolves raw credential strings for model_list api_key fields.
|
||||
// File references are resolved relative to the directory of the config file.
|
||||
type Resolver struct {
|
||||
configDir string
|
||||
resolvedConfigDir string // symlink-resolved form of configDir
|
||||
}
|
||||
|
||||
// NewResolver returns a Resolver that resolves file:// references relative to
|
||||
// configDir (typically filepath.Dir of the config file path).
|
||||
func NewResolver(configDir string) *Resolver {
|
||||
resolved := configDir
|
||||
if configDir != "" {
|
||||
if linkedPath, err := filepath.EvalSymlinks(configDir); err == nil {
|
||||
resolved = linkedPath
|
||||
}
|
||||
}
|
||||
return &Resolver{configDir: configDir, resolvedConfigDir: resolved}
|
||||
}
|
||||
|
||||
// Resolve returns the actual credential value for raw:
|
||||
//
|
||||
// - "" → "" (no error; auth_method=oauth needs no key)
|
||||
// - "file://name.key" → trimmed content of configDir/name.key
|
||||
// - anything else → raw unchanged (plaintext credential)
|
||||
func (r *Resolver) Resolve(raw string) (string, error) {
|
||||
if raw == "" {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
if strings.HasPrefix(raw, fileScheme) {
|
||||
fileName := strings.TrimSpace(strings.TrimPrefix(raw, fileScheme))
|
||||
if fileName == "" {
|
||||
return "", fmt.Errorf("credential: file:// reference has no filename")
|
||||
}
|
||||
|
||||
baseDir := r.resolvedConfigDir
|
||||
if baseDir == "" {
|
||||
baseDir = r.configDir
|
||||
}
|
||||
keyPath := filepath.Join(baseDir, fileName)
|
||||
// Resolve symlinks before enforcing containment to prevent escaping via symlinks.
|
||||
realKeyPath, err := filepath.EvalSymlinks(keyPath)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("credential: failed to resolve credential file path %q: %w", keyPath, err)
|
||||
}
|
||||
if !isWithinDir(realKeyPath, baseDir) {
|
||||
return "", fmt.Errorf("credential: file:// path escapes config directory")
|
||||
}
|
||||
data, err := os.ReadFile(realKeyPath)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("credential: failed to read credential file %q: %w", realKeyPath, err)
|
||||
}
|
||||
|
||||
value := strings.TrimSpace(string(data))
|
||||
if value == "" {
|
||||
return "", fmt.Errorf("credential: credential file %q is empty", realKeyPath)
|
||||
}
|
||||
|
||||
return value, nil
|
||||
}
|
||||
|
||||
if strings.HasPrefix(raw, encScheme) {
|
||||
return resolveEncrypted(raw)
|
||||
}
|
||||
|
||||
// Plaintext credential — return unchanged.
|
||||
return raw, nil
|
||||
}
|
||||
|
||||
// resolveEncrypted decrypts an enc:// credential using PassphraseProvider.
|
||||
func resolveEncrypted(raw string) (string, error) {
|
||||
passphrase := PassphraseProvider()
|
||||
if passphrase == "" {
|
||||
return "", ErrPassphraseRequired
|
||||
}
|
||||
|
||||
sshKeyPath := pickSSHKeyPath("") // override="": consult env then auto-detect
|
||||
|
||||
b64 := strings.TrimPrefix(raw, encScheme)
|
||||
blob, err := base64.StdEncoding.DecodeString(b64)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("credential: enc:// invalid base64: %w", err)
|
||||
}
|
||||
if len(blob) < saltLen+nonceLen+1 {
|
||||
return "", fmt.Errorf("credential: enc:// payload too short")
|
||||
}
|
||||
|
||||
salt := blob[:saltLen]
|
||||
nonce := blob[saltLen : saltLen+nonceLen]
|
||||
ciphertext := blob[saltLen+nonceLen:]
|
||||
|
||||
key, err := deriveKey(passphrase, sshKeyPath, salt)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
block, err := aes.NewCipher(key)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("credential: enc:// cipher init: %w", err)
|
||||
}
|
||||
gcm, err := cipher.NewGCM(block)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("credential: enc:// gcm init: %w", err)
|
||||
}
|
||||
|
||||
plaintext, err := gcm.Open(nil, nonce, ciphertext, nil)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("%w: %w", ErrDecryptionFailed, err)
|
||||
}
|
||||
return string(plaintext), nil
|
||||
}
|
||||
|
||||
// Encrypt encrypts plaintext and returns an enc:// credential string.
|
||||
//
|
||||
// passphrase is required (PICOCLAW_KEY_PASSPHRASE value).
|
||||
// sshKeyPath is the SSH private key file to use; pass "" to auto-detect via
|
||||
// PICOCLAW_SSH_KEY_PATH env var or ~/.ssh/picoclaw_ed25519.key.
|
||||
// An SSH private key must be resolvable or Encrypt returns an error.
|
||||
func Encrypt(passphrase, sshKeyPath, plaintext string) (string, error) {
|
||||
if passphrase == "" {
|
||||
return "", fmt.Errorf("credential: passphrase must not be empty")
|
||||
}
|
||||
sshKeyPath = pickSSHKeyPath(sshKeyPath)
|
||||
|
||||
salt := make([]byte, saltLen)
|
||||
if _, err := io.ReadFull(rand.Reader, salt); err != nil {
|
||||
return "", fmt.Errorf("credential: failed to generate salt: %w", err)
|
||||
}
|
||||
|
||||
key, err := deriveKey(passphrase, sshKeyPath, salt)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
block, err := aes.NewCipher(key)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("credential: cipher init: %w", err)
|
||||
}
|
||||
gcm, err := cipher.NewGCM(block)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("credential: gcm init: %w", err)
|
||||
}
|
||||
|
||||
nonce := make([]byte, nonceLen)
|
||||
if _, err := io.ReadFull(rand.Reader, nonce); err != nil {
|
||||
return "", fmt.Errorf("credential: failed to generate nonce: %w", err)
|
||||
}
|
||||
|
||||
ciphertext := gcm.Seal(nil, nonce, []byte(plaintext), nil)
|
||||
blob := make([]byte, 0, saltLen+nonceLen+len(ciphertext))
|
||||
blob = append(blob, salt...)
|
||||
blob = append(blob, nonce...)
|
||||
blob = append(blob, ciphertext...)
|
||||
return encScheme + base64.StdEncoding.EncodeToString(blob), nil
|
||||
}
|
||||
|
||||
// isWithinDir reports whether path is contained within (or equal to) dir.
|
||||
// Uses filepath.IsLocal on the relative path for robust cross-platform traversal detection.
|
||||
func isWithinDir(path, dir string) bool {
|
||||
rel, err := filepath.Rel(filepath.Clean(dir), filepath.Clean(path))
|
||||
return err == nil && filepath.IsLocal(rel)
|
||||
}
|
||||
|
||||
// allowedSSHKeyPath reports whether path is in a permitted location for SSH key files:
|
||||
// - exact match with PICOCLAW_SSH_KEY_PATH env var
|
||||
// - within the PICOCLAW_HOME env var directory
|
||||
// - within ~/.ssh/
|
||||
func allowedSSHKeyPath(path string) bool {
|
||||
if path == "" {
|
||||
return true // passphrase-only mode; no file will be read
|
||||
}
|
||||
clean := filepath.Clean(path)
|
||||
|
||||
// Exact match with PICOCLAW_SSH_KEY_PATH.
|
||||
if envPath, ok := os.LookupEnv(sshKeyEnv); ok && envPath != "" {
|
||||
if clean == filepath.Clean(envPath) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
// Within PICOCLAW_HOME.
|
||||
if picoHome := os.Getenv("PICOCLAW_HOME"); picoHome != "" {
|
||||
if isWithinDir(clean, picoHome) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
// Within ~/.ssh/.
|
||||
if userHome, err := os.UserHomeDir(); err == nil {
|
||||
if isWithinDir(clean, filepath.Join(userHome, ".ssh")) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// deriveKey derives a 32-byte AES-256 key from passphrase and SSH private key.
|
||||
//
|
||||
// ikm = HMAC-SHA256(key=SHA256(sshKeyBytes), msg=passphrase)
|
||||
// Final key: HKDF-SHA256(ikm, salt, info="picoclaw-credential-v1", 32 bytes)
|
||||
// sshKeyPath must be non-empty; returns an error otherwise.
|
||||
func deriveKey(passphrase, sshKeyPath string, salt []byte) ([]byte, error) {
|
||||
if sshKeyPath == "" {
|
||||
return nil, fmt.Errorf(
|
||||
"credential: SSH private key is required but not found" +
|
||||
" (set PICOCLAW_SSH_KEY_PATH or place key at ~/.ssh/picoclaw_ed25519.key)")
|
||||
}
|
||||
if !allowedSSHKeyPath(sshKeyPath) {
|
||||
return nil, fmt.Errorf(
|
||||
"credential: SSH key path %q is not in an allowed location (PICOCLAW_SSH_KEY_PATH, PICOCLAW_HOME, or ~/.ssh/)",
|
||||
sshKeyPath,
|
||||
)
|
||||
}
|
||||
sshBytes, err := os.ReadFile(sshKeyPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("credential: cannot read SSH key %q: %w", sshKeyPath, err)
|
||||
}
|
||||
sshHash := sha256.Sum256(sshBytes)
|
||||
mac := hmac.New(sha256.New, sshHash[:])
|
||||
mac.Write([]byte(passphrase))
|
||||
ikm := mac.Sum(nil)
|
||||
|
||||
key, err := hkdf.Key(sha256.New, ikm, salt, hkdfInfo, keyLen)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("credential: HKDF expand failed: %w", err)
|
||||
}
|
||||
return key, nil
|
||||
}
|
||||
|
||||
// pickSSHKeyPath returns the SSH private key path to use for encryption/decryption.
|
||||
//
|
||||
// Priority:
|
||||
// 1. override (non-empty explicit argument)
|
||||
// 2. PICOCLAW_SSH_KEY_PATH env var
|
||||
// 3. ~/.ssh/picoclaw_ed25519.key (auto-detection)
|
||||
//
|
||||
// Returns "" when no key is found; deriveKey will return an error in that case.
|
||||
func pickSSHKeyPath(override string) string {
|
||||
if override != "" {
|
||||
return override
|
||||
}
|
||||
if p, ok := os.LookupEnv(sshKeyEnv); ok {
|
||||
return p // respect explicit setting, even if ""
|
||||
}
|
||||
return findDefaultSSHKey()
|
||||
}
|
||||
|
||||
// findDefaultSSHKey returns the picoclaw-specific SSH key path if it exists.
|
||||
func findDefaultSSHKey() string {
|
||||
p, err := DefaultSSHKeyPath()
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
if _, err := os.Stat(p); err == nil {
|
||||
return p
|
||||
}
|
||||
return ""
|
||||
}
|
||||
@@ -0,0 +1,283 @@
|
||||
package credential_test
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/credential"
|
||||
)
|
||||
|
||||
func TestResolve_PlainKey(t *testing.T) {
|
||||
r := credential.NewResolver(t.TempDir())
|
||||
got, err := r.Resolve("sk-plaintext-key")
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if got != "sk-plaintext-key" {
|
||||
t.Fatalf("got %q, want %q", got, "sk-plaintext-key")
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolve_FileKey_Success(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
keyFile := "openai_plain.key"
|
||||
if err := os.WriteFile(filepath.Join(dir, keyFile), []byte("sk-from-file\n"), 0o600); err != nil {
|
||||
t.Fatalf("setup: %v", err)
|
||||
}
|
||||
|
||||
r := credential.NewResolver(dir)
|
||||
got, err := r.Resolve("file://" + keyFile)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if got != "sk-from-file" {
|
||||
t.Fatalf("got %q, want %q", got, "sk-from-file")
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolve_FileKey_NotFound(t *testing.T) {
|
||||
r := credential.NewResolver(t.TempDir())
|
||||
_, err := r.Resolve("file://missing.key")
|
||||
if err == nil {
|
||||
t.Fatal("expected error for missing file, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolve_FileKey_Empty(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
keyFile := "empty.key"
|
||||
if err := os.WriteFile(filepath.Join(dir, keyFile), []byte(" \n"), 0o600); err != nil {
|
||||
t.Fatalf("setup: %v", err)
|
||||
}
|
||||
|
||||
r := credential.NewResolver(dir)
|
||||
_, err := r.Resolve("file://" + keyFile)
|
||||
if err == nil {
|
||||
t.Fatal("expected error for empty credential file, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
// TestResolve_EncKey_RoundTrip tests basic encryption/decryption round-trip with an SSH key.
|
||||
func TestResolve_EncKey_RoundTrip(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
sshKeyPath := filepath.Join(dir, "picoclaw_ed25519.key")
|
||||
if err := os.WriteFile(sshKeyPath, []byte("fake-ssh-key-material\n"), 0o600); err != nil {
|
||||
t.Fatalf("setup: %v", err)
|
||||
}
|
||||
|
||||
const passphrase = "test-passphrase-32bytes-long-ok!"
|
||||
const plaintext = "sk-encrypted-secret"
|
||||
|
||||
t.Setenv("PICOCLAW_SSH_KEY_PATH", sshKeyPath)
|
||||
|
||||
enc, err := credential.Encrypt(passphrase, "", plaintext)
|
||||
if err != nil {
|
||||
t.Fatalf("Encrypt: %v", err)
|
||||
}
|
||||
|
||||
t.Setenv("PICOCLAW_KEY_PASSPHRASE", passphrase)
|
||||
|
||||
r := credential.NewResolver(t.TempDir())
|
||||
got, err := r.Resolve(enc)
|
||||
if err != nil {
|
||||
t.Fatalf("Resolve: %v", err)
|
||||
}
|
||||
if got != plaintext {
|
||||
t.Fatalf("got %q, want %q", got, plaintext)
|
||||
}
|
||||
}
|
||||
|
||||
// TestResolve_EncKey_WithSSHKey tests that the SSH key file is incorporated into key derivation.
|
||||
func TestResolve_EncKey_WithSSHKey(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
sshKeyPath := filepath.Join(dir, "picoclaw_ed25519.key")
|
||||
if err := os.WriteFile(sshKeyPath, []byte("fake-ssh-private-key-material\n"), 0o600); err != nil {
|
||||
t.Fatalf("setup: %v", err)
|
||||
}
|
||||
|
||||
const passphrase = "test-passphrase"
|
||||
const plaintext = "sk-ssh-protected-secret"
|
||||
|
||||
// Set PICOCLAW_SSH_KEY_PATH before Encrypt so the path passes allowedSSHKeyPath validation.
|
||||
t.Setenv("PICOCLAW_KEY_PASSPHRASE", passphrase)
|
||||
t.Setenv("PICOCLAW_SSH_KEY_PATH", sshKeyPath)
|
||||
|
||||
enc, err := credential.Encrypt(passphrase, sshKeyPath, plaintext)
|
||||
if err != nil {
|
||||
t.Fatalf("Encrypt: %v", err)
|
||||
}
|
||||
|
||||
r := credential.NewResolver(t.TempDir())
|
||||
got, err := r.Resolve(enc)
|
||||
if err != nil {
|
||||
t.Fatalf("Resolve: %v", err)
|
||||
}
|
||||
if got != plaintext {
|
||||
t.Fatalf("got %q, want %q", got, plaintext)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolve_EncKey_NoPassphrase(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
sshKeyPath := filepath.Join(dir, "picoclaw_ed25519.key")
|
||||
if err := os.WriteFile(sshKeyPath, []byte("fake-ssh-key\n"), 0o600); err != nil {
|
||||
t.Fatalf("setup: %v", err)
|
||||
}
|
||||
t.Setenv("PICOCLAW_SSH_KEY_PATH", sshKeyPath)
|
||||
|
||||
enc, err := credential.Encrypt("some-passphrase", "", "sk-secret")
|
||||
if err != nil {
|
||||
t.Fatalf("Encrypt: %v", err)
|
||||
}
|
||||
|
||||
t.Setenv("PICOCLAW_KEY_PASSPHRASE", "")
|
||||
|
||||
r := credential.NewResolver(t.TempDir())
|
||||
_, err = r.Resolve(enc)
|
||||
if err == nil {
|
||||
t.Fatal("expected error when PICOCLAW_KEY_PASSPHRASE is unset, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolve_EncKey_BadCiphertext(t *testing.T) {
|
||||
t.Setenv("PICOCLAW_KEY_PASSPHRASE", "some-passphrase")
|
||||
t.Setenv("PICOCLAW_SSH_KEY_PATH", "")
|
||||
|
||||
r := credential.NewResolver(t.TempDir())
|
||||
_, err := r.Resolve("enc://!!not-valid-base64!!")
|
||||
if err == nil {
|
||||
t.Fatal("expected error for invalid enc:// payload, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolve_EncKey_PayloadTooShort(t *testing.T) {
|
||||
t.Setenv("PICOCLAW_KEY_PASSPHRASE", "some-passphrase")
|
||||
t.Setenv("PICOCLAW_SSH_KEY_PATH", "")
|
||||
|
||||
// Valid base64 but fewer bytes than salt(16)+nonce(12)+1 minimum.
|
||||
import64 := "dG9vc2hvcnQ=" // "tooshort" = 8 bytes
|
||||
r := credential.NewResolver(t.TempDir())
|
||||
_, err := r.Resolve("enc://" + import64)
|
||||
if err == nil {
|
||||
t.Fatal("expected error for too-short enc:// payload, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolve_EncKey_WrongPassphrase(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
sshKeyPath := filepath.Join(dir, "picoclaw_ed25519.key")
|
||||
if err := os.WriteFile(sshKeyPath, []byte("fake-ssh-key\n"), 0o600); err != nil {
|
||||
t.Fatalf("setup: %v", err)
|
||||
}
|
||||
t.Setenv("PICOCLAW_SSH_KEY_PATH", sshKeyPath)
|
||||
|
||||
enc, err := credential.Encrypt("correct-passphrase", "", "sk-secret")
|
||||
if err != nil {
|
||||
t.Fatalf("Encrypt: %v", err)
|
||||
}
|
||||
|
||||
t.Setenv("PICOCLAW_KEY_PASSPHRASE", "wrong-passphrase")
|
||||
|
||||
r := credential.NewResolver(t.TempDir())
|
||||
_, err = r.Resolve(enc)
|
||||
if err == nil {
|
||||
t.Fatal("expected decryption error for wrong passphrase, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestEncrypt_EmptyPassphrase(t *testing.T) {
|
||||
_, err := credential.Encrypt("", "", "sk-secret")
|
||||
if err == nil {
|
||||
t.Fatal("expected error for empty passphrase, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeriveKey_SSHKeyNotFound(t *testing.T) {
|
||||
// Encrypt with a real SSH key path, then try to decrypt with a missing path.
|
||||
dir := t.TempDir()
|
||||
sshKeyPath := filepath.Join(dir, "picoclaw_ed25519.key")
|
||||
if err := os.WriteFile(sshKeyPath, []byte("fake-key\n"), 0o600); err != nil {
|
||||
t.Fatalf("setup: %v", err)
|
||||
}
|
||||
|
||||
// Register the real key path so allowedSSHKeyPath validation passes for Encrypt.
|
||||
t.Setenv("PICOCLAW_SSH_KEY_PATH", sshKeyPath)
|
||||
|
||||
enc, err := credential.Encrypt("passphrase", sshKeyPath, "sk-secret")
|
||||
if err != nil {
|
||||
t.Fatalf("Encrypt: %v", err)
|
||||
}
|
||||
|
||||
// Point to a non-existent SSH key so deriveKey's ReadFile fails.
|
||||
// The path is still under the same dir, so allowedSSHKeyPath passes (exact env match).
|
||||
t.Setenv("PICOCLAW_KEY_PASSPHRASE", "passphrase")
|
||||
t.Setenv("PICOCLAW_SSH_KEY_PATH", filepath.Join(dir, "nonexistent_key"))
|
||||
|
||||
r := credential.NewResolver(t.TempDir())
|
||||
_, err = r.Resolve(enc)
|
||||
if err == nil {
|
||||
t.Fatal("expected error when SSH key file is missing, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
// TestResolve_FileRef_PathTraversal verifies that file:// references cannot escape configDir
|
||||
// via relative traversal ("../../etc/passwd") or absolute paths ("/abs/path").
|
||||
func TestResolve_FileRef_PathTraversal(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
cfgPath := filepath.Join(dir, "config.json")
|
||||
// Create a file outside configDir that the traversal would point to.
|
||||
outsideFile := filepath.Join(t.TempDir(), "secret.key")
|
||||
if err := os.WriteFile(outsideFile, []byte("stolen"), 0o600); err != nil {
|
||||
t.Fatalf("setup: %v", err)
|
||||
}
|
||||
|
||||
r := credential.NewResolver(filepath.Dir(cfgPath))
|
||||
|
||||
cases := []string{
|
||||
"file://../../secret.key",
|
||||
"file://../secret.key",
|
||||
"file://" + outsideFile, // absolute path
|
||||
}
|
||||
for _, raw := range cases {
|
||||
_, err := r.Resolve(raw)
|
||||
if err == nil {
|
||||
t.Errorf("Resolve(%q): expected path traversal error, got nil", raw)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestResolve_FileRef_withinConfigDir verifies that a legitimate relative file:// ref works.
|
||||
func TestResolve_FileRef_withinConfigDir(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
if err := os.WriteFile(filepath.Join(dir, "my.key"), []byte("sk-valid\n"), 0o600); err != nil {
|
||||
t.Fatalf("setup: %v", err)
|
||||
}
|
||||
r := credential.NewResolver(dir)
|
||||
got, err := r.Resolve("file://my.key")
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if got != "sk-valid" {
|
||||
t.Fatalf("got %q, want %q", got, "sk-valid")
|
||||
}
|
||||
}
|
||||
|
||||
// TestEncrypt_SSHKeyOutsideAllowedDirs verifies that Encrypt rejects SSH key paths
|
||||
// that are not under PICOCLAW_SSH_KEY_PATH, PICOCLAW_HOME, or ~/.ssh/.
|
||||
func TestEncrypt_SSHKeyOutsideAllowedDirs(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
sshKeyPath := filepath.Join(dir, "picoclaw_ed25519.key")
|
||||
if err := os.WriteFile(sshKeyPath, []byte("fake-key\n"), 0o600); err != nil {
|
||||
t.Fatalf("setup: %v", err)
|
||||
}
|
||||
|
||||
// Make sure none of the allowed env vars point here.
|
||||
t.Setenv("PICOCLAW_SSH_KEY_PATH", "")
|
||||
t.Setenv("PICOCLAW_HOME", "")
|
||||
|
||||
_, err := credential.Encrypt("passphrase", sshKeyPath, "sk-secret")
|
||||
if err == nil {
|
||||
t.Fatal("expected error for SSH key outside allowed directories, got nil")
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,62 @@
|
||||
package credential
|
||||
|
||||
import (
|
||||
"crypto/ed25519"
|
||||
"crypto/rand"
|
||||
"encoding/pem"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
// DefaultSSHKeyPath returns the canonical path for the picoclaw-specific SSH key.
|
||||
// The path is always ~/.ssh/picoclaw_ed25519.key (os.UserHomeDir is cross-platform).
|
||||
func DefaultSSHKeyPath() (string, error) {
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("credential: cannot determine home directory: %w", err)
|
||||
}
|
||||
return filepath.Join(home, ".ssh", "picoclaw_ed25519.key"), nil
|
||||
}
|
||||
|
||||
// GenerateSSHKey generates an Ed25519 SSH key pair and writes the private key
|
||||
// to path (permissions 0600) and the public key to path+".pub" (permissions 0644).
|
||||
// The ~/.ssh/ directory is created with 0700 if it does not exist.
|
||||
// If the files already exist they are overwritten.
|
||||
func GenerateSSHKey(path string) error {
|
||||
if err := os.MkdirAll(filepath.Dir(path), 0o700); err != nil {
|
||||
return fmt.Errorf("credential: keygen: cannot create directory %q: %w", filepath.Dir(path), err)
|
||||
}
|
||||
|
||||
pubRaw, privRaw, err := ed25519.GenerateKey(rand.Reader)
|
||||
if err != nil {
|
||||
return fmt.Errorf("credential: keygen: ed25519 key generation failed: %w", err)
|
||||
}
|
||||
|
||||
// Marshal private key as OpenSSH PEM.
|
||||
block, err := ssh.MarshalPrivateKey(privRaw, "")
|
||||
if err != nil {
|
||||
return fmt.Errorf("credential: keygen: marshal private key: %w", err)
|
||||
}
|
||||
privPEM := pem.EncodeToMemory(block)
|
||||
|
||||
if err = os.WriteFile(path, privPEM, 0o600); err != nil {
|
||||
return fmt.Errorf("credential: keygen: write private key %q: %w", path, err)
|
||||
}
|
||||
|
||||
// Marshal public key as authorized_keys line.
|
||||
sshPub, err := ssh.NewPublicKey(pubRaw)
|
||||
if err != nil {
|
||||
return fmt.Errorf("credential: keygen: marshal public key: %w", err)
|
||||
}
|
||||
pubLine := ssh.MarshalAuthorizedKey(sshPub)
|
||||
|
||||
pubPath := path + ".pub"
|
||||
if err := os.WriteFile(pubPath, pubLine, 0o644); err != nil {
|
||||
return fmt.Errorf("credential: keygen: write public key %q: %w", pubPath, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,115 @@
|
||||
package credential
|
||||
|
||||
import (
|
||||
"crypto/ed25519"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"testing"
|
||||
|
||||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
func TestGenerateSSHKey_CreatesFiles(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
keyPath := filepath.Join(dir, "test_ed25519.key")
|
||||
|
||||
if err := GenerateSSHKey(keyPath); err != nil {
|
||||
t.Fatalf("GenerateSSHKey() error = %v", err)
|
||||
}
|
||||
|
||||
// Private key must exist.
|
||||
privInfo, err := os.Stat(keyPath)
|
||||
if err != nil {
|
||||
t.Fatalf("private key file missing: %v", err)
|
||||
}
|
||||
|
||||
// Check permissions on non-Windows (Windows does not support Unix permission bits).
|
||||
if runtime.GOOS != "windows" {
|
||||
if got := privInfo.Mode().Perm(); got != 0o600 {
|
||||
t.Errorf("private key permissions = %04o, want 0600", got)
|
||||
}
|
||||
}
|
||||
|
||||
// Public key must exist.
|
||||
pubPath := keyPath + ".pub"
|
||||
pubInfo, err := os.Stat(pubPath)
|
||||
if err != nil {
|
||||
t.Fatalf("public key file missing: %v", err)
|
||||
}
|
||||
if runtime.GOOS != "windows" {
|
||||
if got := pubInfo.Mode().Perm(); got != 0o644 {
|
||||
t.Errorf("public key permissions = %04o, want 0644", got)
|
||||
}
|
||||
}
|
||||
|
||||
// Private key must be parseable as an OpenSSH ed25519 key.
|
||||
privPEM, err := os.ReadFile(keyPath)
|
||||
if err != nil {
|
||||
t.Fatalf("read private key: %v", err)
|
||||
}
|
||||
privKey, err := ssh.ParseRawPrivateKey(privPEM)
|
||||
if err != nil {
|
||||
t.Fatalf("parse private key: %v", err)
|
||||
}
|
||||
if _, ok := privKey.(*ed25519.PrivateKey); !ok {
|
||||
t.Errorf("private key type = %T, want *ed25519.PrivateKey", privKey)
|
||||
}
|
||||
|
||||
// Public key must be parseable as authorized_keys line.
|
||||
pubBytes, err := os.ReadFile(pubPath)
|
||||
if err != nil {
|
||||
t.Fatalf("read public key: %v", err)
|
||||
}
|
||||
pubKey, _, _, rest, err := ssh.ParseAuthorizedKey(pubBytes)
|
||||
if err != nil {
|
||||
t.Fatalf("parse public key: %v", err)
|
||||
}
|
||||
if pubKey == nil {
|
||||
t.Fatal("expected non-nil public key")
|
||||
}
|
||||
if len(rest) > 0 {
|
||||
t.Errorf("unexpected trailing bytes after public key: %d bytes", len(rest))
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateSSHKey_OverwritesExisting(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
keyPath := filepath.Join(dir, "test_ed25519.key")
|
||||
|
||||
// Generate twice; second call must not error and must produce a different key.
|
||||
if err := GenerateSSHKey(keyPath); err != nil {
|
||||
t.Fatalf("first GenerateSSHKey() error = %v", err)
|
||||
}
|
||||
first, err := os.ReadFile(keyPath)
|
||||
if err != nil {
|
||||
t.Fatalf("read first key: %v", err)
|
||||
}
|
||||
|
||||
if err = GenerateSSHKey(keyPath); err != nil {
|
||||
t.Fatalf("second GenerateSSHKey() error = %v", err)
|
||||
}
|
||||
second, err := os.ReadFile(keyPath)
|
||||
if err != nil {
|
||||
t.Fatalf("read second key: %v", err)
|
||||
}
|
||||
|
||||
// Two independently generated Ed25519 keys must differ.
|
||||
if string(first) == string(second) {
|
||||
t.Error("expected overwritten key to differ from original")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateSSHKey_CreatesDirectory(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
// Nested directory that does not yet exist.
|
||||
keyPath := filepath.Join(dir, "subdir", ".ssh", "picoclaw_ed25519.key")
|
||||
|
||||
if err := GenerateSSHKey(keyPath); err != nil {
|
||||
t.Fatalf("GenerateSSHKey() error = %v", err)
|
||||
}
|
||||
|
||||
if _, err := os.Stat(keyPath); err != nil {
|
||||
t.Fatalf("private key not created: %v", err)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,44 @@
|
||||
package credential
|
||||
|
||||
import "sync/atomic"
|
||||
|
||||
// SecureStore holds a passphrase in memory.
|
||||
//
|
||||
// Uses atomic.Pointer so reads and writes are lock-free.
|
||||
// The passphrase is never written to disk; callers decide how to
|
||||
// transport it outside this store (e.g., via cmd.Env or os.Environ).
|
||||
type SecureStore struct {
|
||||
val atomic.Pointer[string]
|
||||
}
|
||||
|
||||
// NewSecureStore creates an empty SecureStore.
|
||||
func NewSecureStore() *SecureStore {
|
||||
return &SecureStore{}
|
||||
}
|
||||
|
||||
// SetString stores the passphrase. An empty string clears the store.
|
||||
func (s *SecureStore) SetString(passphrase string) {
|
||||
if passphrase == "" {
|
||||
s.val.Store(nil)
|
||||
return
|
||||
}
|
||||
s.val.Store(&passphrase)
|
||||
}
|
||||
|
||||
// Get returns the stored passphrase, or "" if not set.
|
||||
func (s *SecureStore) Get() string {
|
||||
if p := s.val.Load(); p != nil {
|
||||
return *p
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// IsSet reports whether a passphrase is currently stored.
|
||||
func (s *SecureStore) IsSet() bool {
|
||||
return s.val.Load() != nil
|
||||
}
|
||||
|
||||
// Clear removes the stored passphrase.
|
||||
func (s *SecureStore) Clear() {
|
||||
s.val.Store(nil)
|
||||
}
|
||||
@@ -0,0 +1,81 @@
|
||||
package credential
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestSecureStore_SetGet(t *testing.T) {
|
||||
s := NewSecureStore()
|
||||
if s.IsSet() {
|
||||
t.Error("expected empty store")
|
||||
}
|
||||
|
||||
s.SetString("hunter2")
|
||||
if !s.IsSet() {
|
||||
t.Error("expected store to be set")
|
||||
}
|
||||
if got := s.Get(); got != "hunter2" {
|
||||
t.Errorf("Get() = %q, want %q", got, "hunter2")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSecureStore_Clear(t *testing.T) {
|
||||
s := NewSecureStore()
|
||||
s.SetString("secret")
|
||||
s.Clear()
|
||||
|
||||
if s.IsSet() {
|
||||
t.Error("expected store to be empty after Clear()")
|
||||
}
|
||||
if got := s.Get(); got != "" {
|
||||
t.Errorf("Get() after Clear() = %q, want empty", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSecureStore_SetOverwrites(t *testing.T) {
|
||||
s := NewSecureStore()
|
||||
s.SetString("first")
|
||||
s.SetString("second")
|
||||
|
||||
if got := s.Get(); got != "second" {
|
||||
t.Errorf("Get() = %q, want %q", got, "second")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSecureStore_EmptyPassphrase(t *testing.T) {
|
||||
s := NewSecureStore()
|
||||
s.SetString("") // empty → should not mark as set
|
||||
|
||||
if s.IsSet() {
|
||||
t.Error("empty passphrase should not mark store as set")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSecureStore_ConcurrentSetGet(t *testing.T) {
|
||||
s := NewSecureStore()
|
||||
const goroutines = 10
|
||||
const iterations = 1000
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(goroutines)
|
||||
for i := 0; i < goroutines; i++ {
|
||||
go func(id int) {
|
||||
defer wg.Done()
|
||||
for j := 0; j < iterations; j++ {
|
||||
if id%2 == 0 {
|
||||
s.SetString("even")
|
||||
} else {
|
||||
s.SetString("odd")
|
||||
}
|
||||
_ = s.Get()
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
final := s.Get()
|
||||
if final != "" && final != "even" && final != "odd" {
|
||||
t.Errorf("Get() returned unexpected value %q after concurrent Set/Get", final)
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user