Merge pull request #2184 from cytown/config

refactor config and add ModelConfig.Enabled
This commit is contained in:
daming大铭
2026-03-30 17:23:07 +08:00
committed by GitHub
13 changed files with 1272 additions and 508 deletions
+54 -128
View File
@@ -7,8 +7,8 @@ import (
"math/rand"
"os"
"path/filepath"
"strings"
"sync/atomic"
"time"
"github.com/caarlos0/env/v11"
@@ -20,89 +20,8 @@ import (
// rrCounter is a global counter for round-robin load balancing across models.
var rrCounter atomic.Uint64
// FlexibleStringSlice is a []string that also accepts JSON numbers,
// so allow_from can contain both "123" and 123.
// It also supports parsing comma-separated strings from environment variables,
// including both English (,) and Chinese () commas.
type FlexibleStringSlice []string
func (f *FlexibleStringSlice) UnmarshalJSON(data []byte) error {
// Accept a single JSON string for convenience, e.g.:
// "text": "Thinking..."
var singleString string
if err := json.Unmarshal(data, &singleString); err == nil {
*f = FlexibleStringSlice{singleString}
return nil
}
// Accept a single JSON number too, to keep symmetry with mixed allow_from
// payloads that may contain numeric identifiers.
var singleNumber float64
if err := json.Unmarshal(data, &singleNumber); err == nil {
*f = FlexibleStringSlice{fmt.Sprintf("%.0f", singleNumber)}
return nil
}
// Try []string first
var ss []string
if err := json.Unmarshal(data, &ss); err == nil {
*f = ss
return nil
}
// Try []interface{} to handle mixed types
var raw []any
if err := json.Unmarshal(data, &raw); err != nil {
var s string
// fail over to compatible to old format string
if err = json.Unmarshal(data, &s); err != nil {
return err
}
*f = []string{s}
return nil
}
result := make([]string, 0, len(raw))
for _, v := range raw {
switch val := v.(type) {
case string:
result = append(result, val)
case float64:
result = append(result, fmt.Sprintf("%.0f", val))
default:
result = append(result, fmt.Sprintf("%v", val))
}
}
*f = result
return nil
}
// UnmarshalText implements encoding.TextUnmarshaler to support env variable parsing.
// It handles comma-separated values with both English (,) and Chinese () commas.
func (f *FlexibleStringSlice) UnmarshalText(text []byte) error {
if len(text) == 0 {
*f = nil
return nil
}
s := string(text)
// Replace Chinese comma with English comma, then split
s = strings.ReplaceAll(s, "", ",")
parts := strings.Split(s, ",")
result := make([]string, 0, len(parts))
for _, part := range parts {
part = strings.TrimSpace(part)
if part != "" {
result = append(result, part)
}
}
*f = result
return nil
}
// CurrentVersion is the latest config schema version
const CurrentVersion = 1
const CurrentVersion = 2
// Config is the current config structure with version support
type Config struct {
@@ -675,6 +594,11 @@ type ModelConfig struct {
APIKeys SecureStrings `json:"api_keys,omitzero" yaml:"api_keys,omitempty"` // API authentication keys (multiple keys for failover)
// Enabled indicates whether this model entry is active. When omitted in
// existing configs, the field is inferred during load: models with API keys
// or the reserved "local-model" name are auto-enabled.
Enabled bool `json:"enabled,omitempty" yaml:"enabled,omitempty"`
// isVirtual marks this model as a virtual model generated from multi-key expansion.
// Virtual models should not be persisted to config files.
isVirtual bool
@@ -1047,6 +971,35 @@ func LoadConfig(path string) (*Config, error) {
defer func(cfg *Config) {
_ = SaveConfig(path, cfg)
}(cfg)
case 1:
// V1→V2 migration: infer Enabled and migrate channel config fields
logger.InfoF("config migrate start", map[string]any{"from": versionInfo.Version, "to": CurrentVersion})
cfg, err = loadConfig(data)
if err != nil {
return nil, err
}
secPath := securityPath(path)
err = loadSecurityConfig(cfg, secPath)
if err != nil && !errors.Is(err, os.ErrNotExist) {
return nil, fmt.Errorf("failed to load security config: %w", err)
}
oldCfg := &configV1{Config: *cfg}
cfg, err = oldCfg.Migrate()
if err != nil {
logger.ErrorF("config migrate fail", map[string]any{"from": versionInfo.Version, "to": CurrentVersion})
return nil, err
}
err = makeBackup(path)
if err != nil {
return nil, err
}
defer func(cfg *Config) {
_ = SaveConfig(path, cfg)
}(cfg)
logger.InfoF("config migrate success", map[string]any{"from": versionInfo.Version, "to": CurrentVersion})
case CurrentVersion:
// Current version
cfg, err = loadConfig(data)
@@ -1064,18 +1017,15 @@ func LoadConfig(path string) (*Config, error) {
return nil, fmt.Errorf("unsupported config version: %d", versionInfo.Version)
}
if err := env.Parse(cfg); err != nil {
if err = env.Parse(cfg); err != nil {
return nil, err
}
// Expand multi-key configs into separate entries for key-level failover
cfg.ModelList = expandMultiKeyModels(cfg.ModelList)
// Migrate legacy channel config fields to new unified structures
cfg.migrateChannelConfigs()
// Validate model_list for uniqueness and required fields
if err := cfg.ValidateModelList(); err != nil {
if err = cfg.ValidateModelList(); err != nil {
return nil, err
}
@@ -1097,12 +1047,22 @@ func makeBackup(path string) error {
if _, err := os.Stat(path); os.IsNotExist(err) {
return nil
}
// Create backup of the config file before migration
bakPath := path + ".bak"
dateSuffix := time.Now().Format(".20060102.bak")
// Backup config file
bakPath := path + dateSuffix
if err := fileutil.CopyFile(path, bakPath, 0o600); err != nil {
logger.ErrorF("failed to create config backup", map[string]any{"error": err})
return fmt.Errorf("failed to create config backup: %w", err)
}
// Backup security config file
secPath := securityPath(path)
if _, err := os.Stat(secPath); err == nil {
secBakPath := secPath + dateSuffix
if secErr := fileutil.CopyFile(secPath, secBakPath, 0o600); secErr != nil {
logger.ErrorF("failed to create security backup", map[string]any{"error": secErr})
return fmt.Errorf("failed to create security backup: %w", secErr)
}
}
return nil
}
@@ -1118,19 +1078,6 @@ func toNameIndex(list []*ModelConfig) []string {
return nameList
}
func (c *Config) migrateChannelConfigs() {
// Discord: mention_only -> group_trigger.mention_only
if c.Channels.Discord.MentionOnly && !c.Channels.Discord.GroupTrigger.MentionOnly {
c.Channels.Discord.GroupTrigger.MentionOnly = true
}
// OneBot: group_trigger_prefix -> group_trigger.prefixes
if len(c.Channels.OneBot.GroupTriggerPrefix) > 0 &&
len(c.Channels.OneBot.GroupTrigger.Prefixes) == 0 {
c.Channels.OneBot.GroupTrigger.Prefixes = c.Channels.OneBot.GroupTriggerPrefix
}
}
func SaveConfig(path string, cfg *Config) error {
if cfg.Version < CurrentVersion {
cfg.Version = CurrentVersion
@@ -1144,6 +1091,10 @@ func SaveConfig(path string, cfg *Config) error {
}
// Temporarily replace ModelList with filtered version for serialization
originalModelList := cfg.ModelList
defer func() {
// Restore original ModelList after serialization
cfg.ModelList = originalModelList
}()
cfg.ModelList = nonVirtualModels
if err := saveSecurityConfig(securityPath(path), cfg); err != nil {
@@ -1152,8 +1103,6 @@ func SaveConfig(path string, cfg *Config) error {
}
data, err := json.MarshalIndent(cfg, "", " ")
// Restore original ModelList after serialization
cfg.ModelList = originalModelList
if err != nil {
return err
}
@@ -1223,29 +1172,6 @@ func (c *Config) SecurityCopyFrom(path string) error {
return loadSecurityConfig(c, securityPath(path))
}
func MergeAPIKeys(apiKey string, apiKeys []string) []string {
seen := make(map[string]struct{})
var all []string
if k := strings.TrimSpace(apiKey); k != "" {
if _, exists := seen[k]; !exists {
seen[k] = struct{}{}
all = append(all, k)
}
}
for _, k := range apiKeys {
if trimmed := strings.TrimSpace(k); trimmed != "" {
if _, exists := seen[trimmed]; !exists {
seen[trimmed] = struct{}{}
all = append(all, trimmed)
}
}
}
return all
}
// expandMultiKeyModels expands ModelConfig entries with multiple API keys into
// separate entries for key-level failover. Each key gets its own ModelConfig entry,
// and the original entry's fallbacks are set up to chain through the expanded entries.
+57 -5
View File
@@ -734,7 +734,8 @@ func (c *configV0) Migrate() (*Config, error) {
// Convert []modelConfigV0 to []ModelConfig
cfg.ModelList = make([]*ModelConfig, len(c.ModelList))
for i, m := range c.ModelList {
cfg.ModelList[i] = &ModelConfig{
mergedKeys := toSecureStrings(mergeAPIKeys(m.APIKey, m.APIKeys))
mc := &ModelConfig{
ModelName: m.ModelName,
Model: m.Model,
APIBase: m.APIBase,
@@ -747,8 +748,13 @@ func (c *configV0) Migrate() (*Config, error) {
MaxTokensField: m.MaxTokensField,
RequestTimeout: m.RequestTimeout,
ThinkingLevel: m.ThinkingLevel,
APIKeys: toSecureStrings(MergeAPIKeys(m.APIKey, m.APIKeys)),
APIKeys: mergedKeys,
}
// Infer Enabled during V0→V1 migration
if len(mergedKeys) > 0 || m.ModelName == "local-model" {
mc.Enabled = true
}
cfg.ModelList[i] = mc
}
}
@@ -756,6 +762,52 @@ func (c *configV0) Migrate() (*Config, error) {
return cfg, nil
}
type configV1 struct {
Config
}
// Migrate applies V1→Current Version migrations to an already-loaded Config.
//
// It must be called AFTER loadSecurityConfig so that API keys (which live in
// the security file) are available for the Enabled inference.
func (c *configV1) Migrate() (*Config, error) {
c.migrateModelEnabled()
c.migrateChannelConfigs()
return &c.Config, nil
}
// migrateModelEnabled infers the Enabled field for models loaded from V1 configs
// that predate the field (JSON where "enabled" is absent).
//
// Rules (only applied when Enabled has not been explicitly set by the user):
// - Models with API keys are considered enabled.
// - The reserved "local-model" entry is considered enabled.
func (cfg *configV1) migrateModelEnabled() {
for _, m := range cfg.ModelList {
if m.Enabled {
continue
}
if len(m.APIKeys) > 0 || m.ModelName == "local-model" {
m.Enabled = true
}
}
}
// migrateChannelConfigs migrates legacy channel config fields in a V1 Config
// to the new unified structures.
func (cfg *configV1) migrateChannelConfigs() {
// Discord: mention_only -> group_trigger.mention_only
if cfg.Channels.Discord.MentionOnly && !cfg.Channels.Discord.GroupTrigger.MentionOnly {
cfg.Channels.Discord.GroupTrigger.MentionOnly = true
}
// OneBot: group_trigger_prefix -> group_trigger.prefixes
if len(cfg.Channels.OneBot.GroupTriggerPrefix) > 0 &&
len(cfg.Channels.OneBot.GroupTrigger.Prefixes) == 0 {
cfg.Channels.OneBot.GroupTrigger.Prefixes = cfg.Channels.OneBot.GroupTriggerPrefix
}
}
type webToolsConfigV0 struct {
ToolConfig ` envPrefix:"PICOCLAW_TOOLS_WEB_"`
Brave braveConfigV0 ` json:"brave"`
@@ -791,7 +843,7 @@ func (v *braveConfigV0) ToBraveConfig() BraveConfig {
return BraveConfig{
Enabled: v.Enabled,
MaxResults: v.MaxResults,
APIKeys: toSecureStrings(MergeAPIKeys(v.APIKey, v.APIKeys)),
APIKeys: toSecureStrings(mergeAPIKeys(v.APIKey, v.APIKeys)),
}
}
@@ -808,7 +860,7 @@ func (v *tavilyConfigV0) ToTavilyConfig() TavilyConfig {
Enabled: v.Enabled,
BaseURL: v.BaseURL,
MaxResults: v.MaxResults,
APIKeys: toSecureStrings(MergeAPIKeys(v.APIKey, v.APIKeys)),
APIKeys: toSecureStrings(mergeAPIKeys(v.APIKey, v.APIKeys)),
}
}
@@ -823,7 +875,7 @@ func (v *perplexityConfigV0) ToPerplexityConfig() PerplexityConfig {
return PerplexityConfig{
Enabled: v.Enabled,
MaxResults: v.MaxResults,
APIKeys: toSecureStrings(MergeAPIKeys(v.APIKey, v.APIKeys)),
APIKeys: toSecureStrings(mergeAPIKeys(v.APIKey, v.APIKeys)),
}
}
+327
View File
@@ -0,0 +1,327 @@
package config
import (
"encoding/json"
"fmt"
"path/filepath"
"runtime"
"strings"
"sync"
"gopkg.in/yaml.v3"
"github.com/sipeed/picoclaw/pkg/credential"
"github.com/sipeed/picoclaw/pkg/logger"
)
// FlexibleStringSlice is a []string that also accepts JSON numbers,
// so allow_from can contain both "123" and 123.
// It also supports parsing comma-separated strings from environment variables,
// including both English (,) and Chinese () commas.
type FlexibleStringSlice []string
func (f *FlexibleStringSlice) UnmarshalJSON(data []byte) error {
// Accept a single JSON string for convenience, e.g.:
// "text": "Thinking..."
var singleString string
if err := json.Unmarshal(data, &singleString); err == nil {
*f = FlexibleStringSlice{singleString}
return nil
}
// Accept a single JSON number too, to keep symmetry with mixed allow_from
// payloads that may contain numeric identifiers.
var singleNumber float64
if err := json.Unmarshal(data, &singleNumber); err == nil {
*f = FlexibleStringSlice{fmt.Sprintf("%.0f", singleNumber)}
return nil
}
// Try []string first
var ss []string
if err := json.Unmarshal(data, &ss); err == nil {
*f = ss
return nil
}
// Try []interface{} to handle mixed types
var raw []any
if err := json.Unmarshal(data, &raw); err != nil {
var s string
// fail over to compatible to old format string
if err = json.Unmarshal(data, &s); err != nil {
return err
}
*f = []string{s}
return nil
}
result := make([]string, 0, len(raw))
for _, v := range raw {
switch val := v.(type) {
case string:
result = append(result, val)
case float64:
result = append(result, fmt.Sprintf("%.0f", val))
default:
result = append(result, fmt.Sprintf("%v", val))
}
}
*f = result
return nil
}
// UnmarshalText implements encoding.TextUnmarshaler to support env variable parsing.
// It handles comma-separated values with both English (,) and Chinese () commas.
func (f *FlexibleStringSlice) UnmarshalText(text []byte) error {
if len(text) == 0 {
*f = nil
return nil
}
s := string(text)
// Replace Chinese comma with English comma, then split
s = strings.ReplaceAll(s, "", ",")
parts := strings.Split(s, ",")
result := make([]string, 0, len(parts))
for _, part := range parts {
part = strings.TrimSpace(part)
if part != "" {
result = append(result, part)
}
}
*f = result
return nil
}
const (
notHere = `"[NOT_HERE]"`
)
// SecureStrings is a slice of SecureString
type SecureStrings []*SecureString
// Values returns the decrypted/resolved values
func (s *SecureStrings) Values() []string {
if s == nil {
return nil
}
keys := make([]string, len(*s))
for i, k := range *s {
keys[i] = k.String()
}
return unique(keys)
}
func SimpleSecureStrings(val ...string) SecureStrings {
val = unique(val)
vv := make(SecureStrings, len(val))
for i, s := range val {
vv[i] = NewSecureString(s)
}
return vv
}
// unique returns a new slice with duplicate elements removed.
func unique[T comparable](input []T) []T {
m := make(map[T]struct{})
var result []T
for _, v := range input {
if _, ok := m[v]; !ok {
m[v] = struct{}{}
result = append(result, v)
}
}
return result
}
func (s SecureStrings) MarshalJSON() ([]byte, error) {
return []byte(notHere), nil
}
func (s *SecureStrings) UnmarshalJSON(value []byte) error {
if string(value) == notHere {
return nil
}
var v []*SecureString
err := json.Unmarshal(value, &v)
if err != nil {
return err
}
*s = v
return nil
}
// SecureString the string value that can be decrypted or resolved
//
//nolint:recvcheck
type SecureString struct {
resolved string // Decrypted/resolved value returned by String()
raw string // Persisted raw value (enc://, file://, or plaintext)
}
func callerFromYaml() bool {
_, file, _, ok := runtime.Caller(2)
if ok {
d := filepath.Dir(file)
// check the caller is from yaml.v
if !strings.Contains(d, "yaml.v") {
return true
}
}
return false
}
// IsZero returns true if the SecureString is empty
// if caller not yaml, just return true for prevent marshal this field
func (s SecureString) IsZero() bool {
if callerFromYaml() {
return true
}
return s.resolved == ""
}
func NewSecureString(value string) *SecureString {
s := &SecureString{}
if err := s.fromRaw(value); err != nil {
logger.Warn(fmt.Sprintf("NewSecureString.fromRaw error: %s", err))
}
return s
}
func (s *SecureString) String() string {
if s == nil {
return ""
}
return s.resolved
}
func (s *SecureString) Set(value string) *SecureString {
s.resolved = value
s.raw = ""
return s
}
func (s SecureString) MarshalJSON() ([]byte, error) {
return []byte(notHere), nil
}
func (s *SecureString) UnmarshalJSON(value []byte) error {
if string(value) == notHere {
return nil
}
var v string
if err := json.Unmarshal(value, &v); err != nil {
return err
}
return s.fromRaw(v)
}
func (s SecureString) MarshalYAML() (any, error) {
// Preserve raw value if it is already a reference (enc:// or file://)
if strings.HasPrefix(s.raw, credential.EncScheme) || strings.HasPrefix(s.raw, credential.FileScheme) {
return s.raw, nil
}
// If resolved is a reference format (e.g. set via Set), copy back to raw
if strings.HasPrefix(s.resolved, credential.EncScheme) || strings.HasPrefix(s.resolved, credential.FileScheme) {
s.raw = s.resolved
return s.raw, nil
}
// Try to encrypt the resolved value
if passphrase := credential.PassphraseProvider(); passphrase != "" {
encrypted, err := credential.Encrypt(passphrase, "", s.resolved)
if err != nil {
logger.Errorf("Encrypt error: %v", err)
return nil, err
}
s.raw = encrypted
} else {
s.raw = s.resolved
}
return s.raw, nil
}
func (s *SecureString) UnmarshalYAML(value *yaml.Node) error {
return s.fromRaw(value.Value)
}
func (s *SecureString) fromRaw(v string) error {
s.raw = v
vv, err := resolveKey(v)
if err != nil {
return err
}
s.resolved = vv
return nil
}
var (
secResolverMu sync.RWMutex
secResolver *credential.Resolver
)
func updateResolver(path string) {
secResolverMu.Lock()
defer secResolverMu.Unlock()
secResolver = credential.NewResolver(path)
}
func resolveKey(v string) (string, error) {
secResolverMu.RLock()
resolver := secResolver
secResolverMu.RUnlock()
if resolver == nil {
resolver = credential.NewResolver("")
}
if strings.HasPrefix(v, "enc://") || strings.HasPrefix(v, "file://") {
decrypted, err := resolver.Resolve(v)
if err != nil {
logger.Errorf("Resolve error: %v", err)
return "", err
}
return decrypted, nil
}
return v, nil
}
func (s *SecureString) UnmarshalText(text []byte) error {
v := string(text)
return s.fromRaw(v)
}
type SecureModelList []*ModelConfig
func (v *SecureModelList) UnmarshalYAML(value *yaml.Node) error {
mm := make(map[string]*ModelConfig)
if err := value.Decode(&mm); err != nil {
logger.Errorf("Decode error: %v", err)
return err
}
nameList := toNameIndex(*v)
for i, m := range *v {
sec := mm[nameList[i]]
if sec == nil {
sec = mm[m.ModelName]
}
if sec != nil {
m.APIKeys = sec.APIKeys
}
}
return nil
}
func (v SecureModelList) MarshalYAML() (any, error) {
type onlySecureData struct {
APIKeys SecureStrings `yaml:"api_keys,omitempty"`
}
mm := make(map[string]onlySecureData)
nameList := toNameIndex(v)
for i, m := range v {
mm[nameList[i]] = onlySecureData{
APIKeys: m.APIKeys,
}
}
return mm, nil
}
+145
View File
@@ -0,0 +1,145 @@
package config
import (
"encoding/json"
"os"
"path/filepath"
"testing"
"github.com/caarlos0/env/v11"
"github.com/stretchr/testify/assert"
"gopkg.in/yaml.v3"
"github.com/sipeed/picoclaw/pkg/credential"
)
func TestLoadSecurityValue(t *testing.T) {
type valueStruct struct {
Url string `json:"url,omitempty" yaml:"-"`
Token *SecureString `json:"token,omitempty" yaml:"token,omitempty" env:"PICO_TOKEN"`
ApiKeys SecureStrings `json:"api_keys,omitempty" yaml:"api_keys,omitempty" env:"PICO_API_KEYS"`
}
type testStruct struct {
Pico *valueStruct `json:"pico,omitempty" yaml:"pico,omitempty"`
}
v1 := &testStruct{
Pico: &valueStruct{
Url: "https://example.com",
Token: NewSecureString("token1"),
ApiKeys: SecureStrings{NewSecureString("api-key1"), NewSecureString("api-key2")},
},
}
bytes, err := yaml.Marshal(v1)
assert.NoError(t, err)
jsonBytes, err := json.Marshal(v1)
assert.NoError(t, err)
const want = `pico:
token: token1
api_keys:
- api-key1
- api-key2
`
const jsonPost = `{"pico":{"url":"https://example.com","token":"token0"}}`
v0 := &testStruct{}
err = json.Unmarshal([]byte(jsonPost), v0)
assert.NoError(t, err)
assert.Equal(t, "https://example.com", v0.Pico.Url)
assert.Equal(t, "token0", v0.Pico.Token.String())
const jsonWant = `{"pico":{"url":"https://example.com","token":"[NOT_HERE]","api_keys":"[NOT_HERE]"}}`
assert.Equal(t, want, string(bytes))
assert.Equal(t, jsonWant, string(jsonBytes))
v2 := &testStruct{}
err = json.Unmarshal(jsonBytes, v2)
assert.NoError(t, err)
err = yaml.Unmarshal(bytes, v2)
assert.NoError(t, err)
assert.Equal(t, "https://example.com", v2.Pico.Url)
if v2.Pico.Token != nil {
assert.Equal(t, "token1", v2.Pico.Token.String())
assert.Equal(t, "token1", v2.Pico.Token.raw)
}
v2.Pico.Token = NewSecureString("token1")
v2.Pico.Token.raw = "abc"
err = yaml.Unmarshal(bytes, v2)
assert.NoError(t, err)
assert.Equal(t, "token1", v2.Pico.Token.raw)
os.Setenv("PICO_TOKEN", "token_env")
err = env.Parse(v2)
assert.NoError(t, err)
assert.NotNil(t, v2.Pico.Token)
assert.Equal(t, "token1", v2.Pico.Token.String())
v3 := &testStruct{Pico: &valueStruct{}}
err = env.Parse(v3)
assert.NoError(t, err)
if v3.Pico.Token != nil {
assert.Equal(t, "token_env", v3.Pico.Token.String())
}
type toolsStruct struct {
Pico valueStruct `json:"pico,omitempty" yaml:"pico,omitempty"`
}
type testStruct2 struct {
Tools toolsStruct `json:"tools,omitempty" yaml:",inline"`
}
v4 := &testStruct2{
Tools: toolsStruct{
Pico: valueStruct{
Url: "https://example.com",
Token: NewSecureString("token1"),
ApiKeys: SecureStrings{NewSecureString("api-key1"), NewSecureString("api-key2")},
},
},
}
bytes, err = yaml.Marshal(v4)
assert.NoError(t, err)
assert.Equal(t, want, string(bytes))
jsonBytes, err = json.Marshal(v4)
assert.NoError(t, err)
assert.Equal(
t,
`{"tools":{"pico":{"url":"https://example.com","token":"[NOT_HERE]","api_keys":"[NOT_HERE]"}}}`,
string(jsonBytes),
)
v5 := &testStruct2{}
err = json.Unmarshal(jsonBytes, v5)
assert.NoError(t, err)
assert.Equal(t, "https://example.com", v5.Tools.Pico.Url)
err = yaml.Unmarshal(bytes, v5)
assert.NoError(t, err)
assert.NotNil(t, v5.Tools.Pico.Token)
assert.Equal(t, "token1", v5.Tools.Pico.Token.raw)
dir := t.TempDir()
sshKeyPath := filepath.Join(dir, "picoclaw_ed25519.key")
if err = os.WriteFile(sshKeyPath, []byte("fake-ssh-key-material\n"), 0o600); err != nil {
t.Fatalf("setup: %v", err)
}
const passphrase = "test-passphrase-32bytes-long-ok!"
t.Setenv(credential.SSHKeyPathEnvVar, sshKeyPath)
t.Setenv(credential.PassphraseEnvVar, passphrase)
v5.Tools.Pico.Token.Set("newtoken1")
v5.Tools.Pico.ApiKeys[0].Set("newapi-key1")
bytes, err = yaml.Marshal(v5)
assert.NoError(t, err)
t.Logf("yaml: %s", string(bytes))
v6 := &testStruct2{}
err = yaml.Unmarshal(bytes, v6)
assert.NoError(t, err)
assert.NotNil(t, v6.Tools.Pico.Token)
assert.Equal(t, "newtoken1", v6.Tools.Pico.Token.String())
}
+160
View File
@@ -1673,3 +1673,163 @@ func TestFilterSensitiveData_AllTokenTypes(t *testing.T) {
})
}
}
// ---------------------------------------------------------------------------
// makeBackup tests
// ---------------------------------------------------------------------------
// TestMakeBackup_WithDateSuffix verifies backup files include a date suffix.
func TestMakeBackup_WithDateSuffix(t *testing.T) {
dir := t.TempDir()
configPath := filepath.Join(dir, "config.json")
if err := os.WriteFile(configPath, []byte(`{"version":2}`), 0o600); err != nil {
t.Fatalf("WriteFile: %v", err)
}
if err := makeBackup(configPath); err != nil {
t.Fatalf("makeBackup: %v", err)
}
entries, err := os.ReadDir(dir)
if err != nil {
t.Fatalf("ReadDir: %v", err)
}
var hasDatedBackup bool
for _, e := range entries {
if matched, _ := filepath.Match("config.json.20*.bak", e.Name()); matched {
hasDatedBackup = true
// Verify backup content matches original
bakPath := filepath.Join(dir, e.Name())
data, err := os.ReadFile(bakPath)
if err != nil {
t.Fatalf("ReadFile backup: %v", err)
}
if string(data) != `{"version":2}` {
t.Errorf("backup content = %q, want original content", string(data))
}
break
}
}
if !hasDatedBackup {
t.Error("expected backup file with date suffix pattern config.json.20*.bak")
}
}
// TestMakeBackup_AlsoBacksSecurityFile verifies that the security config file
// is also backed up with the same date suffix.
func TestMakeBackup_AlsoBacksSecurityFile(t *testing.T) {
dir := t.TempDir()
configPath := filepath.Join(dir, "config.json")
secPath := securityPath(configPath)
os.WriteFile(configPath, []byte(`{"version":2}`), 0o600)
os.WriteFile(secPath, []byte(`model_list:\n test:0:\n api_keys:\n - "sk-test"\n`), 0o600)
if err := makeBackup(configPath); err != nil {
t.Fatalf("makeBackup: %v", err)
}
entries, err := os.ReadDir(dir)
if err != nil {
t.Fatalf("ReadDir: %v", err)
}
configBackups := 0
secBackups := 0
for _, e := range entries {
if matched, _ := filepath.Match("config.json.20*.bak", e.Name()); matched {
configBackups++
}
if matched, _ := filepath.Match(".security.yml.20*.bak", e.Name()); matched {
secBackups++
}
}
if configBackups != 1 {
t.Errorf("expected 1 config backup, got %d", configBackups)
}
if secBackups != 1 {
t.Errorf("expected 1 security backup, got %d", secBackups)
}
}
// TestMakeBackup_NonexistentFileSkipsBackup verifies that makeBackup returns nil
// when the config file does not exist (no error, no panic).
func TestMakeBackup_NonexistentFileSkipsBackup(t *testing.T) {
dir := t.TempDir()
configPath := filepath.Join(dir, "nonexistent.json")
if err := makeBackup(configPath); err != nil {
t.Fatalf("makeBackup on nonexistent file should return nil, got: %v", err)
}
}
// TestMakeBackup_OnlyConfigNoSecurity verifies backup succeeds when only
// the config file exists and no security file.
func TestMakeBackup_OnlyConfigNoSecurity(t *testing.T) {
dir := t.TempDir()
configPath := filepath.Join(dir, "config.json")
os.WriteFile(configPath, []byte(`{"version":2}`), 0o600)
if err := makeBackup(configPath); err != nil {
t.Fatalf("makeBackup: %v", err)
}
entries, _ := os.ReadDir(dir)
configBackups := 0
secBackups := 0
for _, e := range entries {
if matched, _ := filepath.Match("config.json.20*.bak", e.Name()); matched {
configBackups++
}
if matched, _ := filepath.Match(".security.yml.20*.bak", e.Name()); matched {
secBackups++
}
}
if configBackups != 1 {
t.Errorf("expected 1 config backup, got %d", configBackups)
}
if secBackups != 0 {
t.Errorf("expected 0 security backups when no security file exists, got %d", secBackups)
}
}
// TestMakeBackup_SameDateSuffix verifies that config and security backups
// share the same date suffix (they are created in the same makeBackup call).
func TestMakeBackup_SameDateSuffix(t *testing.T) {
dir := t.TempDir()
configPath := filepath.Join(dir, "config.json")
secPath := securityPath(configPath)
os.WriteFile(configPath, []byte(`{"version":2}`), 0o600)
os.WriteFile(secPath, []byte(`key: value`), 0o600)
if err := makeBackup(configPath); err != nil {
t.Fatalf("makeBackup: %v", err)
}
entries, _ := os.ReadDir(dir)
var configDate, secDate string
for _, e := range entries {
name := e.Name()
// Extract date part: after the last . before .bak
// e.g. config.json.20260330.bak → 20260330
if strings.HasPrefix(name, "config.json.") && strings.HasSuffix(name, ".bak") {
configDate = strings.TrimPrefix(name, "config.json.")
configDate = strings.TrimSuffix(configDate, ".bak")
}
if strings.HasPrefix(name, ".security.yml.") && strings.HasSuffix(name, ".bak") {
secDate = strings.TrimPrefix(name, ".security.yml.")
secDate = strings.TrimSuffix(secDate, ".bak")
}
}
if configDate == "" {
t.Fatal("config backup file not found")
}
if secDate == "" {
t.Fatal("security backup file not found")
}
if configDate != secDate {
t.Errorf("config backup date = %q, security backup date = %q, should match", configDate, secDate)
}
}
+23
View File
@@ -534,3 +534,26 @@ func loadConfig(data []byte) (*Config, error) {
}
return cfg, nil
}
func mergeAPIKeys(apiKey string, apiKeys []string) []string {
seen := make(map[string]struct{})
var all []string
if k := strings.TrimSpace(apiKey); k != "" {
if _, exists := seen[k]; !exists {
seen[k] = struct{}{}
all = append(all, k)
}
}
for _, k := range apiKeys {
if trimmed := strings.TrimSpace(k); trimmed != "" {
if _, exists := seen[trimmed]; !exists {
seen[trimmed] = struct{}{}
all = append(all, trimmed)
}
}
}
return all
}
+470
View File
@@ -681,3 +681,473 @@ web:
t.Error("Discord token not preserved in .security.yml file")
}
}
// ---------------------------------------------------------------------------
// V1 → V2 migration tests
// ---------------------------------------------------------------------------
// TestMigrateModelEnabled_APIKeysInferredEnabled verifies that models with API keys
// are marked as enabled during V1→V2 migration.
func TestMigrateModelEnabled_APIKeysInferredEnabled(t *testing.T) {
v1 := &configV1{Config: Config{
ModelList: []*ModelConfig{
{ModelName: "gpt-4", Model: "openai/gpt-4", APIKeys: SimpleSecureStrings("sk-test")},
{ModelName: "claude", Model: "anthropic/claude", APIKeys: SimpleSecureStrings("sk-ant")},
},
}}
v1.migrateModelEnabled()
for _, m := range v1.ModelList {
if !m.Enabled {
t.Errorf("model %q with API key should be enabled", m.ModelName)
}
}
}
// TestMigrateModelEnabled_LocalModelInferredEnabled verifies that the reserved
// "local-model" entry is enabled even without API keys.
func TestMigrateModelEnabled_LocalModelInferredEnabled(t *testing.T) {
v1 := &configV1{Config: Config{
ModelList: []*ModelConfig{
{ModelName: "local-model", Model: "vllm/custom-model", APIBase: "http://localhost:8000/v1"},
},
}}
v1.migrateModelEnabled()
if !v1.ModelList[0].Enabled {
t.Error("local-model should be enabled")
}
}
// TestMigrateModelEnabled_NoKeyStaysDisabled verifies that models without API keys
// and not named "local-model" remain disabled.
func TestMigrateModelEnabled_NoKeyStaysDisabled(t *testing.T) {
v1 := &configV1{Config: Config{
ModelList: []*ModelConfig{
{ModelName: "gpt-4", Model: "openai/gpt-4"},
{ModelName: "claude", Model: "anthropic/claude"},
},
}}
v1.migrateModelEnabled()
for _, m := range v1.ModelList {
if m.Enabled {
t.Errorf("model %q without API key should stay disabled", m.ModelName)
}
}
}
// TestMigrateModelEnabled_ExplicitEnabledPreserved verifies that a model with
// explicitly enabled=true is NOT overridden by the migration.
func TestMigrateModelEnabled_ExplicitEnabledPreserved(t *testing.T) {
v1 := &configV1{Config: Config{
ModelList: []*ModelConfig{
{ModelName: "gpt-4", Model: "openai/gpt-4", APIKeys: SimpleSecureStrings("sk-test"), Enabled: true},
},
}}
v1.migrateModelEnabled()
if !v1.ModelList[0].Enabled {
t.Error("explicitly enabled model should remain enabled")
}
}
// TestMigrateModelEnabled_ExplicitDisabledNotOverridden verifies that a model with
// explicitly enabled=false and API keys gets enabled during migration.
// Note: since Go's zero value for bool is false and JSON omitempty omits false,
// migration cannot distinguish "explicitly false" from "field absent". Both cases
// get the same inference treatment.
func TestMigrateModelEnabled_ExplicitDisabledNotOverridden(t *testing.T) {
v1 := &configV1{Config: Config{
ModelList: []*ModelConfig{
{ModelName: "gpt-4", Model: "openai/gpt-4", APIKeys: SimpleSecureStrings("sk-test"), Enabled: false},
},
}}
v1.migrateModelEnabled()
// Even though Enabled was set to false, migration infers it as true because
// the migration cannot distinguish from a missing field (both are zero value).
if !v1.ModelList[0].Enabled {
t.Error("model with API key should be enabled by migration inference")
}
}
// TestMigrateModelEnabled_Mixed verifies a mix of models.
func TestMigrateModelEnabled_Mixed(t *testing.T) {
v1 := &configV1{Config: Config{
ModelList: []*ModelConfig{
{ModelName: "with-key", Model: "openai/gpt-4", APIKeys: SimpleSecureStrings("sk-test")},
{ModelName: "no-key", Model: "openai/gpt-4"},
{ModelName: "local-model", Model: "vllm/custom"},
{
ModelName: "disabled-explicit",
Model: "openai/gpt-4",
APIKeys: SimpleSecureStrings("sk-test"),
Enabled: false,
},
},
}}
v1.migrateModelEnabled()
assertEnabled := func(name string, want bool) {
for _, m := range v1.ModelList {
if m.ModelName == name {
if m.Enabled != want {
t.Errorf("model %q: Enabled=%v, want %v", name, m.Enabled, want)
}
return
}
}
t.Errorf("model %q not found", name)
}
assertEnabled("with-key", true)
assertEnabled("no-key", false)
assertEnabled("local-model", true)
assertEnabled("disabled-explicit", true) // false is zero value, migration infers from API key
}
// TestMigrateChannelConfigs_DiscordMentionOnly verifies Discord mention_only migration.
func TestMigrateChannelConfigs_DiscordMentionOnly(t *testing.T) {
v1 := &configV1{Config: Config{
Channels: ChannelsConfig{
Discord: DiscordConfig{
MentionOnly: true,
},
},
}}
v1.migrateChannelConfigs()
if !v1.Channels.Discord.GroupTrigger.MentionOnly {
t.Error("Discord GroupTrigger.MentionOnly should be set to true")
}
}
// TestMigrateChannelConfigs_DiscordAlreadyMigrated is a no-op test.
func TestMigrateChannelConfigs_DiscordAlreadyMigrated(t *testing.T) {
v1 := &configV1{Config: Config{
Channels: ChannelsConfig{
Discord: DiscordConfig{
GroupTrigger: GroupTriggerConfig{MentionOnly: true},
},
},
}}
v1.migrateChannelConfigs()
}
// TestMigrateChannelConfigs_OneBotPrefix verifies OneBot prefix migration.
func TestMigrateChannelConfigs_OneBotPrefix(t *testing.T) {
v1 := &configV1{Config: Config{
Channels: ChannelsConfig{
OneBot: OneBotConfig{
GroupTriggerPrefix: []string{"/"},
},
},
}}
v1.migrateChannelConfigs()
if len(v1.Channels.OneBot.GroupTrigger.Prefixes) != 1 || v1.Channels.OneBot.GroupTrigger.Prefixes[0] != "/" {
t.Errorf("OneBot GroupTrigger.Prefixes = %v, want [\"/\"]", v1.Channels.OneBot.GroupTrigger.Prefixes)
}
}
// TestMigrateConfigV1_Combined verifies that configV1.Migrate applies both migrations.
func TestMigrateConfigV1_Combined(t *testing.T) {
v1 := &configV1{Config: Config{
ModelList: []*ModelConfig{
{ModelName: "gpt-4", Model: "openai/gpt-4", APIKeys: SimpleSecureStrings("sk-test")},
},
Channels: ChannelsConfig{
Discord: DiscordConfig{MentionOnly: true},
},
}}
result, err := v1.Migrate()
if err != nil {
t.Fatalf("Migrate: %v", err)
}
if !result.ModelList[0].Enabled {
t.Error("model with API key should be enabled after V1→V2 migration")
}
if !result.Channels.Discord.GroupTrigger.MentionOnly {
t.Error("Discord mention_only should be migrated after V1→V2 migration")
}
}
// TestLoadConfig_V1ToV2Migration verifies end-to-end V1→V2 config migration
// through LoadConfig, including Enabled field inference and version bump.
func TestLoadConfig_V1ToV2Migration(t *testing.T) {
tmpDir := t.TempDir()
configPath := filepath.Join(tmpDir, "config.json")
// Write a V1 config with model_list but no "enabled" field
v1Config := `{
"version": 1,
"model_list": [
{
"model_name": "gpt-4",
"model": "openai/gpt-4"
},
{
"model_name": "local-model",
"model": "vllm/custom-model",
"api_base": "http://localhost:8000/v1"
}
],
"channels": {
"discord": {
"mention_only": true
}
},
"gateway": {"host": "127.0.0.1", "port": 18790}
}`
if err := os.WriteFile(configPath, []byte(v1Config), 0o600); err != nil {
t.Fatalf("WriteFile: %v", err)
}
cfg, err := LoadConfig(configPath)
if err != nil {
t.Fatalf("LoadConfig: %v", err)
}
// Version should be bumped to 2
if cfg.Version != CurrentVersion {
t.Errorf("Version = %d, want %d", cfg.Version, CurrentVersion)
}
// gpt-4 has no API key → disabled
gpt4, err := cfg.GetModelConfig("gpt-4")
if err != nil {
t.Fatalf("GetModelConfig(gpt-4): %v", err)
}
if gpt4.Enabled {
t.Error("gpt-4 without API key should be disabled after migration")
}
// local-model → enabled
local, err := cfg.GetModelConfig("local-model")
if err != nil {
t.Fatalf("GetModelConfig(local-model): %v", err)
}
if !local.Enabled {
t.Error("local-model should be enabled after migration")
}
// Discord channel config should be migrated
if !cfg.Channels.Discord.GroupTrigger.MentionOnly {
t.Error("Discord mention_only should be migrated to group_trigger.mention_only")
}
// Verify backup was created with date suffix
entries, err := os.ReadDir(tmpDir)
if err != nil {
t.Fatalf("ReadDir: %v", err)
}
var hasBackup bool
for _, e := range entries {
if matched, _ := filepath.Match("config.json.20*.bak", e.Name()); matched {
hasBackup = true
break
}
}
if !hasBackup {
t.Error("expected backup file with date suffix to be created")
}
// Verify the saved config on disk now has version 2
saved, err := os.ReadFile(configPath)
if err != nil {
t.Fatalf("ReadFile saved config: %v", err)
}
var versionCheck struct {
Version int `json:"version"`
}
if err := json.Unmarshal(saved, &versionCheck); err != nil {
t.Fatalf("Unmarshal saved config: %v", err)
}
if versionCheck.Version != 2 {
t.Errorf("saved config version = %d, want 2", versionCheck.Version)
}
}
// TestLoadConfig_V1WithAPIKeysInferredEnabled verifies that V1 configs with
// API keys in the security file get Enabled=true after migration.
func TestLoadConfig_V1WithAPIKeysInferredEnabled(t *testing.T) {
tmpDir := t.TempDir()
configPath := filepath.Join(tmpDir, "config.json")
secPath := securityPath(configPath)
v1Config := `{
"version": 1,
"model_list": [
{"model_name": "gpt-4", "model": "openai/gpt-4"},
{"model_name": "claude", "model": "anthropic/claude"}
],
"gateway": {"host": "127.0.0.1", "port": 18790}
}`
securityConfig := `model_list:
gpt-4:0:
api_keys:
- "sk-gpt-key"
claude:0:
api_keys:
- "sk-claude-key"
`
if err := os.WriteFile(configPath, []byte(v1Config), 0o600); err != nil {
t.Fatalf("WriteFile: %v", err)
}
if err := os.WriteFile(secPath, []byte(securityConfig), 0o600); err != nil {
t.Fatalf("WriteFile security: %v", err)
}
cfg, err := LoadConfig(configPath)
if err != nil {
t.Fatalf("LoadConfig: %v", err)
}
for _, m := range cfg.ModelList {
if !m.Enabled {
t.Errorf("model %q with API key in security file should be enabled", m.ModelName)
}
}
}
// TestLoadConfig_V2DirectLoad verifies that V2 configs load directly without
// running any migration.
func TestLoadConfig_V2DirectLoad(t *testing.T) {
tmpDir := t.TempDir()
configPath := filepath.Join(tmpDir, "config.json")
v2Config := `{
"version": 2,
"model_list": [
{
"model_name": "gpt-4",
"model": "openai/gpt-4",
"enabled": true
},
{
"model_name": "claude",
"model": "anthropic/claude"
}
],
"gateway": {"host": "127.0.0.1", "port": 18790}
}`
if err := os.WriteFile(configPath, []byte(v2Config), 0o600); err != nil {
t.Fatalf("WriteFile: %v", err)
}
cfg, err := LoadConfig(configPath)
if err != nil {
t.Fatalf("LoadConfig: %v", err)
}
if cfg.Version != 2 {
t.Errorf("Version = %d, want 2", cfg.Version)
}
gpt4, _ := cfg.GetModelConfig("gpt-4")
if !gpt4.Enabled {
t.Error("gpt-4 with explicit enabled=true should remain enabled")
}
claude, _ := cfg.GetModelConfig("claude")
if claude.Enabled {
t.Error("claude without enabled field should be false (no migration for V2)")
}
// No backup should be created for V2 load
entries, _ := os.ReadDir(tmpDir)
for _, e := range entries {
if matched, _ := filepath.Match("config.json.*.bak", e.Name()); matched {
t.Errorf("V2 load should not create backup, but found %q", e.Name())
}
}
}
// TestLoadConfig_V0MigrateProducesV2 verifies that V0→V2 migration produces
// correct Enabled fields and version.
func TestLoadConfig_V0MigrateProducesV2(t *testing.T) {
tmpDir := t.TempDir()
configPath := filepath.Join(tmpDir, "config.json")
v0Config := `{
"model_list": [
{
"model_name": "gpt-4",
"model": "openai/gpt-4",
"api_key": "sk-test"
},
{
"model_name": "claude",
"model": "anthropic/claude"
},
{
"model_name": "local-model",
"model": "vllm/custom-model"
}
],
"gateway": {"host": "127.0.0.1", "port": 18790}
}`
if err := os.WriteFile(configPath, []byte(v0Config), 0o600); err != nil {
t.Fatalf("WriteFile: %v", err)
}
cfg, err := LoadConfig(configPath)
if err != nil {
t.Fatalf("LoadConfig: %v", err)
}
if cfg.Version != CurrentVersion {
t.Errorf("Version = %d, want %d", cfg.Version, CurrentVersion)
}
// Check enabled status
modelEnabled := func(name string) bool {
m, err := cfg.GetModelConfig(name)
if err != nil {
return false
}
return m.Enabled
}
if !modelEnabled("gpt-4") {
t.Error("gpt-4 with API key from V0 should be enabled")
}
if modelEnabled("claude") {
t.Error("claude without API key from V0 should be disabled")
}
if !modelEnabled("local-model") {
t.Error("local-model from V0 should be enabled")
}
}
// TestLoadConfig_UnsupportedVersion verifies that unsupported versions return an error.
func TestLoadConfig_UnsupportedVersion(t *testing.T) {
tmpDir := t.TempDir()
configPath := filepath.Join(tmpDir, "config.json")
badConfig := `{"version": 99, "gateway": {"host": "127.0.0.1", "port": 18790}}`
if err := os.WriteFile(configPath, []byte(badConfig), 0o600); err != nil {
t.Fatalf("WriteFile: %v", err)
}
_, err := LoadConfig(configPath)
if err == nil {
t.Fatal("LoadConfig should return error for unsupported version")
}
if !containsString(err.Error(), "unsupported config version") {
t.Errorf("error = %q, want 'unsupported config version'", err.Error())
}
}
func containsString(s, substr string) bool {
return len(s) >= len(substr) && searchString(s, substr)
}
func searchString(s, substr string) bool {
for i := 0; i <= len(s)-len(substr); i++ {
if s[i:i+len(substr)] == substr {
return true
}
}
return false
}
+1 -1
View File
@@ -345,7 +345,7 @@ func TestMergeAPIKeys(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := MergeAPIKeys(tt.apiKey, tt.apiKeys)
result := mergeAPIKeys(tt.apiKey, tt.apiKeys)
if len(result) != len(tt.expected) {
t.Fatalf("expected %d keys, got %d", len(tt.expected), len(result))
}
-236
View File
@@ -7,20 +7,16 @@ package config
import (
"bytes"
"encoding/json"
"fmt"
"os"
"path/filepath"
"reflect"
"runtime"
"strings"
"sync"
"gopkg.in/yaml.v3"
"github.com/sipeed/picoclaw/pkg/credential"
"github.com/sipeed/picoclaw/pkg/fileutil"
"github.com/sipeed/picoclaw/pkg/logger"
)
const (
@@ -66,7 +62,6 @@ func saveSecurityConfig(securityPath string, sec *Config) error {
return fileutil.WriteFileAtomic(securityPath, buf.Bytes(), 0o600)
}
// SensitiveDataCache caches the compiled regex for filtering sensitive data.
// SensitiveDataCache caches the strings.Replacer for filtering sensitive data.
// Computed once on first access via sync.Once.
type SensitiveDataCache struct {
@@ -178,234 +173,3 @@ func collectSensitive(v reflect.Value, values *[]string) {
}
}
}
const (
notHere = `"[NOT_HERE]"`
)
// SecureStrings is a slice of SecureString
type SecureStrings []*SecureString
// Values returns the decrypted/resolved values
func (s *SecureStrings) Values() []string {
if s == nil {
return nil
}
keys := make([]string, len(*s))
for i, k := range *s {
keys[i] = k.String()
}
return unique(keys)
}
func SimpleSecureStrings(val ...string) SecureStrings {
val = unique(val)
vv := make(SecureStrings, len(val))
for i, s := range val {
vv[i] = NewSecureString(s)
}
return vv
}
// unique returns a new slice with duplicate elements removed.
func unique[T comparable](input []T) []T {
m := make(map[T]struct{})
var result []T
for _, v := range input {
if _, ok := m[v]; !ok {
m[v] = struct{}{}
result = append(result, v)
}
}
return result
}
func (s SecureStrings) MarshalJSON() ([]byte, error) {
return []byte(notHere), nil
}
func (s *SecureStrings) UnmarshalJSON(value []byte) error {
if string(value) == notHere {
return nil
}
var v []*SecureString
err := json.Unmarshal(value, &v)
if err != nil {
return err
}
*s = v
return nil
}
// SecureString the string value that can be decrypted or resolved
//
//nolint:recvcheck
type SecureString struct {
resolved string // Decrypted/resolved value returned by String()
raw string // Persisted raw value (enc://, file://, or plaintext)
}
func callerFromYaml() bool {
_, file, _, ok := runtime.Caller(2)
if ok {
d := filepath.Dir(file)
// check the caller is from yaml.v
if !strings.Contains(d, "yaml.v") {
return true
}
}
return false
}
// IsZero returns true if the SecureString is empty
// if caller not yaml, just return true for prevent marshal this field
func (s SecureString) IsZero() bool {
if callerFromYaml() {
return true
}
return s.resolved == ""
}
func NewSecureString(value string) *SecureString {
s := &SecureString{}
if err := s.fromRaw(value); err != nil {
logger.Warn(fmt.Sprintf("NewSecureString.fromRaw error: %s", err))
}
return s
}
func (s *SecureString) String() string {
if s == nil {
return ""
}
return s.resolved
}
func (s *SecureString) Set(value string) *SecureString {
s.resolved = value
s.raw = ""
return s
}
func (s SecureString) MarshalJSON() ([]byte, error) {
return []byte(notHere), nil
}
func (s *SecureString) UnmarshalJSON(value []byte) error {
if string(value) == notHere {
return nil
}
var v string
if err := json.Unmarshal(value, &v); err != nil {
return err
}
return s.fromRaw(v)
}
func (s SecureString) MarshalYAML() (any, error) {
// Preserve raw value if it is already a reference (enc:// or file://)
if strings.HasPrefix(s.raw, credential.EncScheme) || strings.HasPrefix(s.raw, credential.FileScheme) {
return s.raw, nil
}
// If resolved is a reference format (e.g. set via Set), copy back to raw
if strings.HasPrefix(s.resolved, credential.EncScheme) || strings.HasPrefix(s.resolved, credential.FileScheme) {
s.raw = s.resolved
return s.raw, nil
}
// Try to encrypt the resolved value
if passphrase := credential.PassphraseProvider(); passphrase != "" {
encrypted, err := credential.Encrypt(passphrase, "", s.resolved)
if err != nil {
logger.Errorf("Encrypt error: %v", err)
return nil, err
}
s.raw = encrypted
} else {
s.raw = s.resolved
}
return s.raw, nil
}
func (s *SecureString) UnmarshalYAML(value *yaml.Node) error {
return s.fromRaw(value.Value)
}
func (s *SecureString) fromRaw(v string) error {
s.raw = v
vv, err := resolveKey(v)
if err != nil {
return err
}
s.resolved = vv
return nil
}
var (
secResolverMu sync.RWMutex
secResolver *credential.Resolver
)
func updateResolver(path string) {
secResolverMu.Lock()
defer secResolverMu.Unlock()
secResolver = credential.NewResolver(path)
}
func resolveKey(v string) (string, error) {
secResolverMu.RLock()
resolver := secResolver
secResolverMu.RUnlock()
if resolver == nil {
resolver = credential.NewResolver("")
}
if strings.HasPrefix(v, "enc://") || strings.HasPrefix(v, "file://") {
decrypted, err := resolver.Resolve(v)
if err != nil {
logger.Errorf("Resolve error: %v", err)
return "", err
}
return decrypted, nil
}
return v, nil
}
func (s *SecureString) UnmarshalText(text []byte) error {
v := string(text)
return s.fromRaw(v)
}
type SecureModelList []*ModelConfig
func (v *SecureModelList) UnmarshalYAML(value *yaml.Node) error {
mm := make(map[string]*ModelConfig)
if err := value.Decode(&mm); err != nil {
logger.Errorf("Decode error: %v", err)
return err
}
nameList := toNameIndex(*v)
for i, m := range *v {
sec := mm[nameList[i]]
if sec == nil {
sec = mm[m.ModelName]
}
if sec != nil {
m.APIKeys = sec.APIKeys
}
}
return nil
}
func (v SecureModelList) MarshalYAML() (any, error) {
type onlySecureData struct {
APIKeys SecureStrings `yaml:"api_keys,omitempty"`
}
mm := make(map[string]onlySecureData)
nameList := toNameIndex(v)
for i, m := range v {
mm[nameList[i]] = onlySecureData{
APIKeys: m.APIKeys,
}
}
return mm, nil
}
-133
View File
@@ -15,8 +15,6 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"gopkg.in/yaml.v3"
"github.com/sipeed/picoclaw/pkg/credential"
)
func TestSecurityConfig(t *testing.T) {
@@ -227,134 +225,3 @@ skills:
assert.Equal(t, "abc", cfg2.Tools.Web.Brave.APIKeys[1].raw)
})
}
func TestLoadSecurityValue(t *testing.T) {
type valueStruct struct {
Url string `json:"url,omitempty" yaml:"-"`
Token *SecureString `json:"token,omitempty" yaml:"token,omitempty" env:"PICO_TOKEN"`
ApiKeys SecureStrings `json:"api_keys,omitempty" yaml:"api_keys,omitempty" env:"PICO_API_KEYS"`
}
type testStruct struct {
Pico *valueStruct `json:"pico,omitempty" yaml:"pico,omitempty"`
}
v1 := &testStruct{
Pico: &valueStruct{
Url: "https://example.com",
Token: NewSecureString("token1"),
ApiKeys: SecureStrings{NewSecureString("api-key1"), NewSecureString("api-key2")},
},
}
bytes, err := yaml.Marshal(v1)
assert.NoError(t, err)
jsonBytes, err := json.Marshal(v1)
assert.NoError(t, err)
const want = `pico:
token: token1
api_keys:
- api-key1
- api-key2
`
const jsonPost = `{"pico":{"url":"https://example.com","token":"token0"}}`
v0 := &testStruct{}
err = json.Unmarshal([]byte(jsonPost), v0)
assert.NoError(t, err)
assert.Equal(t, "https://example.com", v0.Pico.Url)
assert.Equal(t, "token0", v0.Pico.Token.String())
const jsonWant = `{"pico":{"url":"https://example.com","token":"[NOT_HERE]","api_keys":"[NOT_HERE]"}}`
assert.Equal(t, want, string(bytes))
assert.Equal(t, jsonWant, string(jsonBytes))
v2 := &testStruct{}
err = json.Unmarshal(jsonBytes, v2)
assert.NoError(t, err)
err = yaml.Unmarshal(bytes, v2)
assert.NoError(t, err)
assert.Equal(t, "https://example.com", v2.Pico.Url)
if v2.Pico.Token != nil {
assert.Equal(t, "token1", v2.Pico.Token.String())
assert.Equal(t, "token1", v2.Pico.Token.raw)
}
v2.Pico.Token = NewSecureString("token1")
v2.Pico.Token.raw = "abc"
err = yaml.Unmarshal(bytes, v2)
assert.NoError(t, err)
assert.Equal(t, "token1", v2.Pico.Token.raw)
os.Setenv("PICO_TOKEN", "token_env")
err = env.Parse(v2)
assert.NoError(t, err)
assert.NotNil(t, v2.Pico.Token)
assert.Equal(t, "token1", v2.Pico.Token.String())
v3 := &testStruct{Pico: &valueStruct{}}
err = env.Parse(v3)
assert.NoError(t, err)
if v3.Pico.Token != nil {
assert.Equal(t, "token_env", v3.Pico.Token.String())
}
type toolsStruct struct {
Pico valueStruct `json:"pico,omitempty" yaml:"pico,omitempty"`
}
type testStruct2 struct {
Tools toolsStruct `json:"tools,omitempty" yaml:",inline"`
}
v4 := &testStruct2{
Tools: toolsStruct{
Pico: valueStruct{
Url: "https://example.com",
Token: NewSecureString("token1"),
ApiKeys: SecureStrings{NewSecureString("api-key1"), NewSecureString("api-key2")},
},
},
}
bytes, err = yaml.Marshal(v4)
assert.NoError(t, err)
assert.Equal(t, want, string(bytes))
jsonBytes, err = json.Marshal(v4)
assert.NoError(t, err)
assert.Equal(
t,
`{"tools":{"pico":{"url":"https://example.com","token":"[NOT_HERE]","api_keys":"[NOT_HERE]"}}}`,
string(jsonBytes),
)
v5 := &testStruct2{}
err = json.Unmarshal(jsonBytes, v5)
assert.NoError(t, err)
assert.Equal(t, "https://example.com", v5.Tools.Pico.Url)
err = yaml.Unmarshal(bytes, v5)
assert.NoError(t, err)
assert.NotNil(t, v5.Tools.Pico.Token)
assert.Equal(t, "token1", v5.Tools.Pico.Token.raw)
dir := t.TempDir()
sshKeyPath := filepath.Join(dir, "picoclaw_ed25519.key")
if err = os.WriteFile(sshKeyPath, []byte("fake-ssh-key-material\n"), 0o600); err != nil {
t.Fatalf("setup: %v", err)
}
const passphrase = "test-passphrase-32bytes-long-ok!"
t.Setenv(credential.SSHKeyPathEnvVar, sshKeyPath)
t.Setenv(credential.PassphraseEnvVar, passphrase)
v5.Tools.Pico.Token.Set("newtoken1")
v5.Tools.Pico.ApiKeys[0].Set("newapi-key1")
bytes, err = yaml.Marshal(v5)
assert.NoError(t, err)
t.Logf("yaml: %s", string(bytes))
v6 := &testStruct2{}
err = yaml.Unmarshal(bytes, v6)
assert.NoError(t, err)
assert.NotNil(t, v6.Tools.Pico.Token)
assert.Equal(t, "newtoken1", v6.Tools.Pico.Token.String())
}