From 93757812fc64e225c89e1d33ed9e5d0504ad4d75 Mon Sep 17 00:00:00 2001 From: Cytown Date: Mon, 30 Mar 2026 14:01:20 +0800 Subject: [PATCH] refactor config and add ModelConfig.Enabled --- cmd/picoclaw/internal/model/command.go | 4 +- cmd/picoclaw/internal/model/command_test.go | 34 +- pkg/config/config.go | 182 +++----- pkg/config/config_old.go | 62 ++- pkg/config/config_struct.go | 327 ++++++++++++++ pkg/config/config_struct_test.go | 145 ++++++ pkg/config/config_test.go | 160 +++++++ pkg/config/migration.go | 23 + pkg/config/migration_integration_test.go | 470 ++++++++++++++++++++ pkg/config/multikey_test.go | 2 +- pkg/config/security.go | 236 ---------- pkg/config/security_test.go | 133 ------ web/backend/api/models.go | 2 + 13 files changed, 1272 insertions(+), 508 deletions(-) create mode 100644 pkg/config/config_struct.go create mode 100644 pkg/config/config_struct_test.go diff --git a/cmd/picoclaw/internal/model/command.go b/cmd/picoclaw/internal/model/command.go index 314259d0f..330734b82 100644 --- a/cmd/picoclaw/internal/model/command.go +++ b/cmd/picoclaw/internal/model/command.go @@ -81,7 +81,7 @@ func listAvailableModels(cfg *config.Config) { if model.ModelName == defaultModel { marker = "> " } - if model.APIKey() == "" { + if !model.Enabled { continue } fmt.Printf("%s- %s (%s)\n", marker, model.ModelName, model.Model) @@ -92,7 +92,7 @@ func setDefaultModel(configPath string, cfg *config.Config, modelName string) er // Validate that the model exists in model_list modelFound := false for _, model := range cfg.ModelList { - if model.APIKey() != "" && model.ModelName == modelName { + if model.Enabled && model.ModelName == modelName { modelFound = true break } diff --git a/cmd/picoclaw/internal/model/command_test.go b/cmd/picoclaw/internal/model/command_test.go index 8be29ba95..9e2a7bbae 100644 --- a/cmd/picoclaw/internal/model/command_test.go +++ b/cmd/picoclaw/internal/model/command_test.go @@ -65,11 +65,17 @@ func TestShowCurrentModel_WithDefaultModel(t *testing.T) { }, }, ModelList: []*config.ModelConfig{ - {ModelName: "gpt-4", Model: "openai/gpt-4", APIKeys: config.SecureStrings{config.NewSecureString("test")}}, + { + ModelName: "gpt-4", + Model: "openai/gpt-4", + APIKeys: config.SecureStrings{config.NewSecureString("test")}, + Enabled: true, + }, { ModelName: "claude-3", Model: "anthropic/claude-3", APIKeys: config.SecureStrings{config.NewSecureString("test")}, + Enabled: true, }, }, } @@ -92,7 +98,12 @@ func TestShowCurrentModel_NoDefaultModel(t *testing.T) { }, }, ModelList: []*config.ModelConfig{ - {ModelName: "gpt-4", Model: "openai/gpt-4", APIKeys: config.SecureStrings{config.NewSecureString("test")}}, + { + ModelName: "gpt-4", + Model: "openai/gpt-4", + APIKeys: config.SecureStrings{config.NewSecureString("test")}, + Enabled: true, + }, }, } @@ -124,11 +135,17 @@ func TestListAvailableModels_WithModels(t *testing.T) { }, }, ModelList: []*config.ModelConfig{ - {ModelName: "gpt-4", Model: "openai/gpt-4", APIKeys: config.SecureStrings{config.NewSecureString("test")}}, + { + ModelName: "gpt-4", + Model: "openai/gpt-4", + APIKeys: config.SecureStrings{config.NewSecureString("test")}, + Enabled: true, + }, { ModelName: "claude-3", Model: "anthropic/claude-3", APIKeys: config.SecureStrings{config.NewSecureString("test")}, + Enabled: true, }, {ModelName: "no-key-model", Model: "openai/test"}, }, @@ -158,11 +175,13 @@ func TestSetDefaultModel_ValidModel(t *testing.T) { ModelName: "new-model", Model: "openai/new-model", APIKeys: config.SecureStrings{config.NewSecureString("test")}, + Enabled: true, }, { ModelName: "old-model", Model: "openai/old-model", APIKeys: config.SecureStrings{config.NewSecureString("test")}, + Enabled: true, }, }, } @@ -194,6 +213,7 @@ func TestSetDefaultModel_InvalidModel(t *testing.T) { ModelName: "existing-model", Model: "openai/existing", APIKeys: config.SecureStrings{config.NewSecureString("test")}, + Enabled: true, }, }, } @@ -215,6 +235,7 @@ func TestSetDefaultModel_ModelWithoutAPIKey(t *testing.T) { ModelName: "existing-model", Model: "openai/existing", APIKeys: config.SecureStrings{config.NewSecureString("test")}, + Enabled: true, }, {ModelName: "no-key-model", Model: "openai/nokey"}, }, @@ -238,6 +259,7 @@ func TestSetDefaultModel_SaveConfigError(t *testing.T) { ModelName: "new-model", Model: "openai/new-model", APIKeys: config.SecureStrings{config.NewSecureString("test")}, + Enabled: true, }, }, } @@ -283,6 +305,7 @@ func TestModelCommandExecution_Show(t *testing.T) { ModelName: "test-model", Model: "openai/test", APIKeys: config.SecureStrings{config.NewSecureString("test")}, + Enabled: true, }, }, } @@ -314,11 +337,13 @@ func TestModelCommandExecution_Set(t *testing.T) { ModelName: "old-model", Model: "openai/old", APIKeys: config.SecureStrings{config.NewSecureString("test")}, + Enabled: true, }, { ModelName: "new-model", Model: "openai/new", APIKeys: config.SecureStrings{config.NewSecureString("test")}, + Enabled: true, }, }, } @@ -356,16 +381,19 @@ func TestListAvailableModels_MarkerLogic(t *testing.T) { ModelName: "first-model", Model: "openai/first", APIKeys: config.SecureStrings{config.NewSecureString("test")}, + Enabled: true, }, { ModelName: "middle-model", Model: "openai/middle", APIKeys: config.SecureStrings{config.NewSecureString("test")}, + Enabled: true, }, { ModelName: "last-model", Model: "openai/last", APIKeys: config.SecureStrings{config.NewSecureString("test")}, + Enabled: true, }, }, } diff --git a/pkg/config/config.go b/pkg/config/config.go index 533f45a44..3dc3422fb 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -7,8 +7,8 @@ import ( "math/rand" "os" "path/filepath" - "strings" "sync/atomic" + "time" "github.com/caarlos0/env/v11" @@ -20,89 +20,8 @@ import ( // rrCounter is a global counter for round-robin load balancing across models. var rrCounter atomic.Uint64 -// FlexibleStringSlice is a []string that also accepts JSON numbers, -// so allow_from can contain both "123" and 123. -// It also supports parsing comma-separated strings from environment variables, -// including both English (,) and Chinese (,) commas. -type FlexibleStringSlice []string - -func (f *FlexibleStringSlice) UnmarshalJSON(data []byte) error { - // Accept a single JSON string for convenience, e.g.: - // "text": "Thinking..." - var singleString string - if err := json.Unmarshal(data, &singleString); err == nil { - *f = FlexibleStringSlice{singleString} - return nil - } - - // Accept a single JSON number too, to keep symmetry with mixed allow_from - // payloads that may contain numeric identifiers. - var singleNumber float64 - if err := json.Unmarshal(data, &singleNumber); err == nil { - *f = FlexibleStringSlice{fmt.Sprintf("%.0f", singleNumber)} - return nil - } - - // Try []string first - var ss []string - if err := json.Unmarshal(data, &ss); err == nil { - *f = ss - return nil - } - - // Try []interface{} to handle mixed types - var raw []any - if err := json.Unmarshal(data, &raw); err != nil { - var s string - // fail over to compatible to old format string - if err = json.Unmarshal(data, &s); err != nil { - return err - } - *f = []string{s} - return nil - } - - result := make([]string, 0, len(raw)) - for _, v := range raw { - switch val := v.(type) { - case string: - result = append(result, val) - case float64: - result = append(result, fmt.Sprintf("%.0f", val)) - default: - result = append(result, fmt.Sprintf("%v", val)) - } - } - *f = result - return nil -} - -// UnmarshalText implements encoding.TextUnmarshaler to support env variable parsing. -// It handles comma-separated values with both English (,) and Chinese (,) commas. -func (f *FlexibleStringSlice) UnmarshalText(text []byte) error { - if len(text) == 0 { - *f = nil - return nil - } - - s := string(text) - // Replace Chinese comma with English comma, then split - s = strings.ReplaceAll(s, ",", ",") - parts := strings.Split(s, ",") - - result := make([]string, 0, len(parts)) - for _, part := range parts { - part = strings.TrimSpace(part) - if part != "" { - result = append(result, part) - } - } - *f = result - return nil -} - // CurrentVersion is the latest config schema version -const CurrentVersion = 1 +const CurrentVersion = 2 // Config is the current config structure with version support type Config struct { @@ -675,6 +594,11 @@ type ModelConfig struct { APIKeys SecureStrings `json:"api_keys,omitzero" yaml:"api_keys,omitempty"` // API authentication keys (multiple keys for failover) + // Enabled indicates whether this model entry is active. When omitted in + // existing configs, the field is inferred during load: models with API keys + // or the reserved "local-model" name are auto-enabled. + Enabled bool `json:"enabled,omitempty" yaml:"enabled,omitempty"` + // isVirtual marks this model as a virtual model generated from multi-key expansion. // Virtual models should not be persisted to config files. isVirtual bool @@ -1047,6 +971,35 @@ func LoadConfig(path string) (*Config, error) { defer func(cfg *Config) { _ = SaveConfig(path, cfg) }(cfg) + case 1: + // V1→V2 migration: infer Enabled and migrate channel config fields + logger.InfoF("config migrate start", map[string]any{"from": versionInfo.Version, "to": CurrentVersion}) + cfg, err = loadConfig(data) + if err != nil { + return nil, err + } + secPath := securityPath(path) + err = loadSecurityConfig(cfg, secPath) + if err != nil && !errors.Is(err, os.ErrNotExist) { + return nil, fmt.Errorf("failed to load security config: %w", err) + } + + oldCfg := &configV1{Config: *cfg} + cfg, err = oldCfg.Migrate() + if err != nil { + logger.ErrorF("config migrate fail", map[string]any{"from": versionInfo.Version, "to": CurrentVersion}) + return nil, err + } + + err = makeBackup(path) + if err != nil { + return nil, err + } + + defer func(cfg *Config) { + _ = SaveConfig(path, cfg) + }(cfg) + logger.InfoF("config migrate success", map[string]any{"from": versionInfo.Version, "to": CurrentVersion}) case CurrentVersion: // Current version cfg, err = loadConfig(data) @@ -1064,18 +1017,15 @@ func LoadConfig(path string) (*Config, error) { return nil, fmt.Errorf("unsupported config version: %d", versionInfo.Version) } - if err := env.Parse(cfg); err != nil { + if err = env.Parse(cfg); err != nil { return nil, err } // Expand multi-key configs into separate entries for key-level failover cfg.ModelList = expandMultiKeyModels(cfg.ModelList) - // Migrate legacy channel config fields to new unified structures - cfg.migrateChannelConfigs() - // Validate model_list for uniqueness and required fields - if err := cfg.ValidateModelList(); err != nil { + if err = cfg.ValidateModelList(); err != nil { return nil, err } @@ -1097,12 +1047,22 @@ func makeBackup(path string) error { if _, err := os.Stat(path); os.IsNotExist(err) { return nil } - // Create backup of the config file before migration - bakPath := path + ".bak" + dateSuffix := time.Now().Format(".20060102.bak") + // Backup config file + bakPath := path + dateSuffix if err := fileutil.CopyFile(path, bakPath, 0o600); err != nil { logger.ErrorF("failed to create config backup", map[string]any{"error": err}) return fmt.Errorf("failed to create config backup: %w", err) } + // Backup security config file + secPath := securityPath(path) + if _, err := os.Stat(secPath); err == nil { + secBakPath := secPath + dateSuffix + if secErr := fileutil.CopyFile(secPath, secBakPath, 0o600); secErr != nil { + logger.ErrorF("failed to create security backup", map[string]any{"error": secErr}) + return fmt.Errorf("failed to create security backup: %w", secErr) + } + } return nil } @@ -1118,19 +1078,6 @@ func toNameIndex(list []*ModelConfig) []string { return nameList } -func (c *Config) migrateChannelConfigs() { - // Discord: mention_only -> group_trigger.mention_only - if c.Channels.Discord.MentionOnly && !c.Channels.Discord.GroupTrigger.MentionOnly { - c.Channels.Discord.GroupTrigger.MentionOnly = true - } - - // OneBot: group_trigger_prefix -> group_trigger.prefixes - if len(c.Channels.OneBot.GroupTriggerPrefix) > 0 && - len(c.Channels.OneBot.GroupTrigger.Prefixes) == 0 { - c.Channels.OneBot.GroupTrigger.Prefixes = c.Channels.OneBot.GroupTriggerPrefix - } -} - func SaveConfig(path string, cfg *Config) error { if cfg.Version < CurrentVersion { cfg.Version = CurrentVersion @@ -1144,6 +1091,10 @@ func SaveConfig(path string, cfg *Config) error { } // Temporarily replace ModelList with filtered version for serialization originalModelList := cfg.ModelList + defer func() { + // Restore original ModelList after serialization + cfg.ModelList = originalModelList + }() cfg.ModelList = nonVirtualModels if err := saveSecurityConfig(securityPath(path), cfg); err != nil { @@ -1152,8 +1103,6 @@ func SaveConfig(path string, cfg *Config) error { } data, err := json.MarshalIndent(cfg, "", " ") - // Restore original ModelList after serialization - cfg.ModelList = originalModelList if err != nil { return err } @@ -1223,29 +1172,6 @@ func (c *Config) SecurityCopyFrom(path string) error { return loadSecurityConfig(c, securityPath(path)) } -func MergeAPIKeys(apiKey string, apiKeys []string) []string { - seen := make(map[string]struct{}) - var all []string - - if k := strings.TrimSpace(apiKey); k != "" { - if _, exists := seen[k]; !exists { - seen[k] = struct{}{} - all = append(all, k) - } - } - - for _, k := range apiKeys { - if trimmed := strings.TrimSpace(k); trimmed != "" { - if _, exists := seen[trimmed]; !exists { - seen[trimmed] = struct{}{} - all = append(all, trimmed) - } - } - } - - return all -} - // expandMultiKeyModels expands ModelConfig entries with multiple API keys into // separate entries for key-level failover. Each key gets its own ModelConfig entry, // and the original entry's fallbacks are set up to chain through the expanded entries. diff --git a/pkg/config/config_old.go b/pkg/config/config_old.go index fd54c9e08..150275aac 100644 --- a/pkg/config/config_old.go +++ b/pkg/config/config_old.go @@ -734,7 +734,8 @@ func (c *configV0) Migrate() (*Config, error) { // Convert []modelConfigV0 to []ModelConfig cfg.ModelList = make([]*ModelConfig, len(c.ModelList)) for i, m := range c.ModelList { - cfg.ModelList[i] = &ModelConfig{ + mergedKeys := toSecureStrings(mergeAPIKeys(m.APIKey, m.APIKeys)) + mc := &ModelConfig{ ModelName: m.ModelName, Model: m.Model, APIBase: m.APIBase, @@ -747,8 +748,13 @@ func (c *configV0) Migrate() (*Config, error) { MaxTokensField: m.MaxTokensField, RequestTimeout: m.RequestTimeout, ThinkingLevel: m.ThinkingLevel, - APIKeys: toSecureStrings(MergeAPIKeys(m.APIKey, m.APIKeys)), + APIKeys: mergedKeys, } + // Infer Enabled during V0→V1 migration + if len(mergedKeys) > 0 || m.ModelName == "local-model" { + mc.Enabled = true + } + cfg.ModelList[i] = mc } } @@ -756,6 +762,52 @@ func (c *configV0) Migrate() (*Config, error) { return cfg, nil } +type configV1 struct { + Config +} + +// Migrate applies V1→Current Version migrations to an already-loaded Config. +// +// It must be called AFTER loadSecurityConfig so that API keys (which live in +// the security file) are available for the Enabled inference. +func (c *configV1) Migrate() (*Config, error) { + c.migrateModelEnabled() + c.migrateChannelConfigs() + return &c.Config, nil +} + +// migrateModelEnabled infers the Enabled field for models loaded from V1 configs +// that predate the field (JSON where "enabled" is absent). +// +// Rules (only applied when Enabled has not been explicitly set by the user): +// - Models with API keys are considered enabled. +// - The reserved "local-model" entry is considered enabled. +func (cfg *configV1) migrateModelEnabled() { + for _, m := range cfg.ModelList { + if m.Enabled { + continue + } + if len(m.APIKeys) > 0 || m.ModelName == "local-model" { + m.Enabled = true + } + } +} + +// migrateChannelConfigs migrates legacy channel config fields in a V1 Config +// to the new unified structures. +func (cfg *configV1) migrateChannelConfigs() { + // Discord: mention_only -> group_trigger.mention_only + if cfg.Channels.Discord.MentionOnly && !cfg.Channels.Discord.GroupTrigger.MentionOnly { + cfg.Channels.Discord.GroupTrigger.MentionOnly = true + } + + // OneBot: group_trigger_prefix -> group_trigger.prefixes + if len(cfg.Channels.OneBot.GroupTriggerPrefix) > 0 && + len(cfg.Channels.OneBot.GroupTrigger.Prefixes) == 0 { + cfg.Channels.OneBot.GroupTrigger.Prefixes = cfg.Channels.OneBot.GroupTriggerPrefix + } +} + type webToolsConfigV0 struct { ToolConfig ` envPrefix:"PICOCLAW_TOOLS_WEB_"` Brave braveConfigV0 ` json:"brave"` @@ -791,7 +843,7 @@ func (v *braveConfigV0) ToBraveConfig() BraveConfig { return BraveConfig{ Enabled: v.Enabled, MaxResults: v.MaxResults, - APIKeys: toSecureStrings(MergeAPIKeys(v.APIKey, v.APIKeys)), + APIKeys: toSecureStrings(mergeAPIKeys(v.APIKey, v.APIKeys)), } } @@ -808,7 +860,7 @@ func (v *tavilyConfigV0) ToTavilyConfig() TavilyConfig { Enabled: v.Enabled, BaseURL: v.BaseURL, MaxResults: v.MaxResults, - APIKeys: toSecureStrings(MergeAPIKeys(v.APIKey, v.APIKeys)), + APIKeys: toSecureStrings(mergeAPIKeys(v.APIKey, v.APIKeys)), } } @@ -823,7 +875,7 @@ func (v *perplexityConfigV0) ToPerplexityConfig() PerplexityConfig { return PerplexityConfig{ Enabled: v.Enabled, MaxResults: v.MaxResults, - APIKeys: toSecureStrings(MergeAPIKeys(v.APIKey, v.APIKeys)), + APIKeys: toSecureStrings(mergeAPIKeys(v.APIKey, v.APIKeys)), } } diff --git a/pkg/config/config_struct.go b/pkg/config/config_struct.go new file mode 100644 index 000000000..0b8dd85c8 --- /dev/null +++ b/pkg/config/config_struct.go @@ -0,0 +1,327 @@ +package config + +import ( + "encoding/json" + "fmt" + "path/filepath" + "runtime" + "strings" + "sync" + + "gopkg.in/yaml.v3" + + "github.com/sipeed/picoclaw/pkg/credential" + "github.com/sipeed/picoclaw/pkg/logger" +) + +// FlexibleStringSlice is a []string that also accepts JSON numbers, +// so allow_from can contain both "123" and 123. +// It also supports parsing comma-separated strings from environment variables, +// including both English (,) and Chinese (,) commas. +type FlexibleStringSlice []string + +func (f *FlexibleStringSlice) UnmarshalJSON(data []byte) error { + // Accept a single JSON string for convenience, e.g.: + // "text": "Thinking..." + var singleString string + if err := json.Unmarshal(data, &singleString); err == nil { + *f = FlexibleStringSlice{singleString} + return nil + } + + // Accept a single JSON number too, to keep symmetry with mixed allow_from + // payloads that may contain numeric identifiers. + var singleNumber float64 + if err := json.Unmarshal(data, &singleNumber); err == nil { + *f = FlexibleStringSlice{fmt.Sprintf("%.0f", singleNumber)} + return nil + } + + // Try []string first + var ss []string + if err := json.Unmarshal(data, &ss); err == nil { + *f = ss + return nil + } + + // Try []interface{} to handle mixed types + var raw []any + if err := json.Unmarshal(data, &raw); err != nil { + var s string + // fail over to compatible to old format string + if err = json.Unmarshal(data, &s); err != nil { + return err + } + *f = []string{s} + return nil + } + + result := make([]string, 0, len(raw)) + for _, v := range raw { + switch val := v.(type) { + case string: + result = append(result, val) + case float64: + result = append(result, fmt.Sprintf("%.0f", val)) + default: + result = append(result, fmt.Sprintf("%v", val)) + } + } + *f = result + return nil +} + +// UnmarshalText implements encoding.TextUnmarshaler to support env variable parsing. +// It handles comma-separated values with both English (,) and Chinese (,) commas. +func (f *FlexibleStringSlice) UnmarshalText(text []byte) error { + if len(text) == 0 { + *f = nil + return nil + } + + s := string(text) + // Replace Chinese comma with English comma, then split + s = strings.ReplaceAll(s, ",", ",") + parts := strings.Split(s, ",") + + result := make([]string, 0, len(parts)) + for _, part := range parts { + part = strings.TrimSpace(part) + if part != "" { + result = append(result, part) + } + } + *f = result + return nil +} + +const ( + notHere = `"[NOT_HERE]"` +) + +// SecureStrings is a slice of SecureString +type SecureStrings []*SecureString + +// Values returns the decrypted/resolved values +func (s *SecureStrings) Values() []string { + if s == nil { + return nil + } + keys := make([]string, len(*s)) + for i, k := range *s { + keys[i] = k.String() + } + return unique(keys) +} + +func SimpleSecureStrings(val ...string) SecureStrings { + val = unique(val) + vv := make(SecureStrings, len(val)) + for i, s := range val { + vv[i] = NewSecureString(s) + } + return vv +} + +// unique returns a new slice with duplicate elements removed. +func unique[T comparable](input []T) []T { + m := make(map[T]struct{}) + var result []T + for _, v := range input { + if _, ok := m[v]; !ok { + m[v] = struct{}{} + result = append(result, v) + } + } + return result +} + +func (s SecureStrings) MarshalJSON() ([]byte, error) { + return []byte(notHere), nil +} + +func (s *SecureStrings) UnmarshalJSON(value []byte) error { + if string(value) == notHere { + return nil + } + var v []*SecureString + err := json.Unmarshal(value, &v) + if err != nil { + return err + } + *s = v + return nil +} + +// SecureString the string value that can be decrypted or resolved +// +//nolint:recvcheck +type SecureString struct { + resolved string // Decrypted/resolved value returned by String() + raw string // Persisted raw value (enc://, file://, or plaintext) +} + +func callerFromYaml() bool { + _, file, _, ok := runtime.Caller(2) + if ok { + d := filepath.Dir(file) + // check the caller is from yaml.v + if !strings.Contains(d, "yaml.v") { + return true + } + } + return false +} + +// IsZero returns true if the SecureString is empty +// if caller not yaml, just return true for prevent marshal this field +func (s SecureString) IsZero() bool { + if callerFromYaml() { + return true + } + return s.resolved == "" +} + +func NewSecureString(value string) *SecureString { + s := &SecureString{} + if err := s.fromRaw(value); err != nil { + logger.Warn(fmt.Sprintf("NewSecureString.fromRaw error: %s", err)) + } + return s +} + +func (s *SecureString) String() string { + if s == nil { + return "" + } + return s.resolved +} + +func (s *SecureString) Set(value string) *SecureString { + s.resolved = value + s.raw = "" + return s +} + +func (s SecureString) MarshalJSON() ([]byte, error) { + return []byte(notHere), nil +} + +func (s *SecureString) UnmarshalJSON(value []byte) error { + if string(value) == notHere { + return nil + } + var v string + if err := json.Unmarshal(value, &v); err != nil { + return err + } + return s.fromRaw(v) +} + +func (s SecureString) MarshalYAML() (any, error) { + // Preserve raw value if it is already a reference (enc:// or file://) + if strings.HasPrefix(s.raw, credential.EncScheme) || strings.HasPrefix(s.raw, credential.FileScheme) { + return s.raw, nil + } + // If resolved is a reference format (e.g. set via Set), copy back to raw + if strings.HasPrefix(s.resolved, credential.EncScheme) || strings.HasPrefix(s.resolved, credential.FileScheme) { + s.raw = s.resolved + return s.raw, nil + } + // Try to encrypt the resolved value + if passphrase := credential.PassphraseProvider(); passphrase != "" { + encrypted, err := credential.Encrypt(passphrase, "", s.resolved) + if err != nil { + logger.Errorf("Encrypt error: %v", err) + return nil, err + } + s.raw = encrypted + } else { + s.raw = s.resolved + } + return s.raw, nil +} + +func (s *SecureString) UnmarshalYAML(value *yaml.Node) error { + return s.fromRaw(value.Value) +} + +func (s *SecureString) fromRaw(v string) error { + s.raw = v + vv, err := resolveKey(v) + if err != nil { + return err + } + s.resolved = vv + return nil +} + +var ( + secResolverMu sync.RWMutex + secResolver *credential.Resolver +) + +func updateResolver(path string) { + secResolverMu.Lock() + defer secResolverMu.Unlock() + secResolver = credential.NewResolver(path) +} + +func resolveKey(v string) (string, error) { + secResolverMu.RLock() + resolver := secResolver + secResolverMu.RUnlock() + if resolver == nil { + resolver = credential.NewResolver("") + } + if strings.HasPrefix(v, "enc://") || strings.HasPrefix(v, "file://") { + decrypted, err := resolver.Resolve(v) + if err != nil { + logger.Errorf("Resolve error: %v", err) + return "", err + } + return decrypted, nil + } + return v, nil +} + +func (s *SecureString) UnmarshalText(text []byte) error { + v := string(text) + return s.fromRaw(v) +} + +type SecureModelList []*ModelConfig + +func (v *SecureModelList) UnmarshalYAML(value *yaml.Node) error { + mm := make(map[string]*ModelConfig) + if err := value.Decode(&mm); err != nil { + logger.Errorf("Decode error: %v", err) + return err + } + nameList := toNameIndex(*v) + for i, m := range *v { + sec := mm[nameList[i]] + if sec == nil { + sec = mm[m.ModelName] + } + if sec != nil { + m.APIKeys = sec.APIKeys + } + } + return nil +} + +func (v SecureModelList) MarshalYAML() (any, error) { + type onlySecureData struct { + APIKeys SecureStrings `yaml:"api_keys,omitempty"` + } + mm := make(map[string]onlySecureData) + nameList := toNameIndex(v) + for i, m := range v { + mm[nameList[i]] = onlySecureData{ + APIKeys: m.APIKeys, + } + } + + return mm, nil +} diff --git a/pkg/config/config_struct_test.go b/pkg/config/config_struct_test.go new file mode 100644 index 000000000..674b6a064 --- /dev/null +++ b/pkg/config/config_struct_test.go @@ -0,0 +1,145 @@ +package config + +import ( + "encoding/json" + "os" + "path/filepath" + "testing" + + "github.com/caarlos0/env/v11" + "github.com/stretchr/testify/assert" + "gopkg.in/yaml.v3" + + "github.com/sipeed/picoclaw/pkg/credential" +) + +func TestLoadSecurityValue(t *testing.T) { + type valueStruct struct { + Url string `json:"url,omitempty" yaml:"-"` + Token *SecureString `json:"token,omitempty" yaml:"token,omitempty" env:"PICO_TOKEN"` + ApiKeys SecureStrings `json:"api_keys,omitempty" yaml:"api_keys,omitempty" env:"PICO_API_KEYS"` + } + + type testStruct struct { + Pico *valueStruct `json:"pico,omitempty" yaml:"pico,omitempty"` + } + + v1 := &testStruct{ + Pico: &valueStruct{ + Url: "https://example.com", + Token: NewSecureString("token1"), + ApiKeys: SecureStrings{NewSecureString("api-key1"), NewSecureString("api-key2")}, + }, + } + bytes, err := yaml.Marshal(v1) + assert.NoError(t, err) + jsonBytes, err := json.Marshal(v1) + assert.NoError(t, err) + const want = `pico: + token: token1 + api_keys: + - api-key1 + - api-key2 +` + const jsonPost = `{"pico":{"url":"https://example.com","token":"token0"}}` + v0 := &testStruct{} + err = json.Unmarshal([]byte(jsonPost), v0) + assert.NoError(t, err) + assert.Equal(t, "https://example.com", v0.Pico.Url) + assert.Equal(t, "token0", v0.Pico.Token.String()) + + const jsonWant = `{"pico":{"url":"https://example.com","token":"[NOT_HERE]","api_keys":"[NOT_HERE]"}}` + assert.Equal(t, want, string(bytes)) + assert.Equal(t, jsonWant, string(jsonBytes)) + + v2 := &testStruct{} + err = json.Unmarshal(jsonBytes, v2) + assert.NoError(t, err) + err = yaml.Unmarshal(bytes, v2) + assert.NoError(t, err) + assert.Equal(t, "https://example.com", v2.Pico.Url) + if v2.Pico.Token != nil { + assert.Equal(t, "token1", v2.Pico.Token.String()) + assert.Equal(t, "token1", v2.Pico.Token.raw) + } + + v2.Pico.Token = NewSecureString("token1") + v2.Pico.Token.raw = "abc" + err = yaml.Unmarshal(bytes, v2) + assert.NoError(t, err) + assert.Equal(t, "token1", v2.Pico.Token.raw) + + os.Setenv("PICO_TOKEN", "token_env") + err = env.Parse(v2) + assert.NoError(t, err) + assert.NotNil(t, v2.Pico.Token) + assert.Equal(t, "token1", v2.Pico.Token.String()) + + v3 := &testStruct{Pico: &valueStruct{}} + err = env.Parse(v3) + assert.NoError(t, err) + if v3.Pico.Token != nil { + assert.Equal(t, "token_env", v3.Pico.Token.String()) + } + + type toolsStruct struct { + Pico valueStruct `json:"pico,omitempty" yaml:"pico,omitempty"` + } + + type testStruct2 struct { + Tools toolsStruct `json:"tools,omitempty" yaml:",inline"` + } + + v4 := &testStruct2{ + Tools: toolsStruct{ + Pico: valueStruct{ + Url: "https://example.com", + Token: NewSecureString("token1"), + ApiKeys: SecureStrings{NewSecureString("api-key1"), NewSecureString("api-key2")}, + }, + }, + } + bytes, err = yaml.Marshal(v4) + assert.NoError(t, err) + assert.Equal(t, want, string(bytes)) + jsonBytes, err = json.Marshal(v4) + assert.NoError(t, err) + assert.Equal( + t, + `{"tools":{"pico":{"url":"https://example.com","token":"[NOT_HERE]","api_keys":"[NOT_HERE]"}}}`, + string(jsonBytes), + ) + + v5 := &testStruct2{} + err = json.Unmarshal(jsonBytes, v5) + assert.NoError(t, err) + assert.Equal(t, "https://example.com", v5.Tools.Pico.Url) + err = yaml.Unmarshal(bytes, v5) + assert.NoError(t, err) + assert.NotNil(t, v5.Tools.Pico.Token) + assert.Equal(t, "token1", v5.Tools.Pico.Token.raw) + + dir := t.TempDir() + sshKeyPath := filepath.Join(dir, "picoclaw_ed25519.key") + if err = os.WriteFile(sshKeyPath, []byte("fake-ssh-key-material\n"), 0o600); err != nil { + t.Fatalf("setup: %v", err) + } + + const passphrase = "test-passphrase-32bytes-long-ok!" + + t.Setenv(credential.SSHKeyPathEnvVar, sshKeyPath) + + t.Setenv(credential.PassphraseEnvVar, passphrase) + + v5.Tools.Pico.Token.Set("newtoken1") + v5.Tools.Pico.ApiKeys[0].Set("newapi-key1") + bytes, err = yaml.Marshal(v5) + assert.NoError(t, err) + t.Logf("yaml: %s", string(bytes)) + + v6 := &testStruct2{} + err = yaml.Unmarshal(bytes, v6) + assert.NoError(t, err) + assert.NotNil(t, v6.Tools.Pico.Token) + assert.Equal(t, "newtoken1", v6.Tools.Pico.Token.String()) +} diff --git a/pkg/config/config_test.go b/pkg/config/config_test.go index 75eb458b8..6734257f4 100644 --- a/pkg/config/config_test.go +++ b/pkg/config/config_test.go @@ -1673,3 +1673,163 @@ func TestFilterSensitiveData_AllTokenTypes(t *testing.T) { }) } } + +// --------------------------------------------------------------------------- +// makeBackup tests +// --------------------------------------------------------------------------- + +// TestMakeBackup_WithDateSuffix verifies backup files include a date suffix. +func TestMakeBackup_WithDateSuffix(t *testing.T) { + dir := t.TempDir() + configPath := filepath.Join(dir, "config.json") + if err := os.WriteFile(configPath, []byte(`{"version":2}`), 0o600); err != nil { + t.Fatalf("WriteFile: %v", err) + } + + if err := makeBackup(configPath); err != nil { + t.Fatalf("makeBackup: %v", err) + } + + entries, err := os.ReadDir(dir) + if err != nil { + t.Fatalf("ReadDir: %v", err) + } + + var hasDatedBackup bool + for _, e := range entries { + if matched, _ := filepath.Match("config.json.20*.bak", e.Name()); matched { + hasDatedBackup = true + // Verify backup content matches original + bakPath := filepath.Join(dir, e.Name()) + data, err := os.ReadFile(bakPath) + if err != nil { + t.Fatalf("ReadFile backup: %v", err) + } + if string(data) != `{"version":2}` { + t.Errorf("backup content = %q, want original content", string(data)) + } + break + } + } + if !hasDatedBackup { + t.Error("expected backup file with date suffix pattern config.json.20*.bak") + } +} + +// TestMakeBackup_AlsoBacksSecurityFile verifies that the security config file +// is also backed up with the same date suffix. +func TestMakeBackup_AlsoBacksSecurityFile(t *testing.T) { + dir := t.TempDir() + configPath := filepath.Join(dir, "config.json") + secPath := securityPath(configPath) + + os.WriteFile(configPath, []byte(`{"version":2}`), 0o600) + os.WriteFile(secPath, []byte(`model_list:\n test:0:\n api_keys:\n - "sk-test"\n`), 0o600) + + if err := makeBackup(configPath); err != nil { + t.Fatalf("makeBackup: %v", err) + } + + entries, err := os.ReadDir(dir) + if err != nil { + t.Fatalf("ReadDir: %v", err) + } + + configBackups := 0 + secBackups := 0 + for _, e := range entries { + if matched, _ := filepath.Match("config.json.20*.bak", e.Name()); matched { + configBackups++ + } + if matched, _ := filepath.Match(".security.yml.20*.bak", e.Name()); matched { + secBackups++ + } + } + if configBackups != 1 { + t.Errorf("expected 1 config backup, got %d", configBackups) + } + if secBackups != 1 { + t.Errorf("expected 1 security backup, got %d", secBackups) + } +} + +// TestMakeBackup_NonexistentFileSkipsBackup verifies that makeBackup returns nil +// when the config file does not exist (no error, no panic). +func TestMakeBackup_NonexistentFileSkipsBackup(t *testing.T) { + dir := t.TempDir() + configPath := filepath.Join(dir, "nonexistent.json") + + if err := makeBackup(configPath); err != nil { + t.Fatalf("makeBackup on nonexistent file should return nil, got: %v", err) + } +} + +// TestMakeBackup_OnlyConfigNoSecurity verifies backup succeeds when only +// the config file exists and no security file. +func TestMakeBackup_OnlyConfigNoSecurity(t *testing.T) { + dir := t.TempDir() + configPath := filepath.Join(dir, "config.json") + os.WriteFile(configPath, []byte(`{"version":2}`), 0o600) + + if err := makeBackup(configPath); err != nil { + t.Fatalf("makeBackup: %v", err) + } + + entries, _ := os.ReadDir(dir) + configBackups := 0 + secBackups := 0 + for _, e := range entries { + if matched, _ := filepath.Match("config.json.20*.bak", e.Name()); matched { + configBackups++ + } + if matched, _ := filepath.Match(".security.yml.20*.bak", e.Name()); matched { + secBackups++ + } + } + if configBackups != 1 { + t.Errorf("expected 1 config backup, got %d", configBackups) + } + if secBackups != 0 { + t.Errorf("expected 0 security backups when no security file exists, got %d", secBackups) + } +} + +// TestMakeBackup_SameDateSuffix verifies that config and security backups +// share the same date suffix (they are created in the same makeBackup call). +func TestMakeBackup_SameDateSuffix(t *testing.T) { + dir := t.TempDir() + configPath := filepath.Join(dir, "config.json") + secPath := securityPath(configPath) + + os.WriteFile(configPath, []byte(`{"version":2}`), 0o600) + os.WriteFile(secPath, []byte(`key: value`), 0o600) + + if err := makeBackup(configPath); err != nil { + t.Fatalf("makeBackup: %v", err) + } + + entries, _ := os.ReadDir(dir) + var configDate, secDate string + for _, e := range entries { + name := e.Name() + // Extract date part: after the last . before .bak + // e.g. config.json.20260330.bak → 20260330 + if strings.HasPrefix(name, "config.json.") && strings.HasSuffix(name, ".bak") { + configDate = strings.TrimPrefix(name, "config.json.") + configDate = strings.TrimSuffix(configDate, ".bak") + } + if strings.HasPrefix(name, ".security.yml.") && strings.HasSuffix(name, ".bak") { + secDate = strings.TrimPrefix(name, ".security.yml.") + secDate = strings.TrimSuffix(secDate, ".bak") + } + } + if configDate == "" { + t.Fatal("config backup file not found") + } + if secDate == "" { + t.Fatal("security backup file not found") + } + if configDate != secDate { + t.Errorf("config backup date = %q, security backup date = %q, should match", configDate, secDate) + } +} diff --git a/pkg/config/migration.go b/pkg/config/migration.go index fee800a76..7430050b3 100644 --- a/pkg/config/migration.go +++ b/pkg/config/migration.go @@ -534,3 +534,26 @@ func loadConfig(data []byte) (*Config, error) { } return cfg, nil } + +func mergeAPIKeys(apiKey string, apiKeys []string) []string { + seen := make(map[string]struct{}) + var all []string + + if k := strings.TrimSpace(apiKey); k != "" { + if _, exists := seen[k]; !exists { + seen[k] = struct{}{} + all = append(all, k) + } + } + + for _, k := range apiKeys { + if trimmed := strings.TrimSpace(k); trimmed != "" { + if _, exists := seen[trimmed]; !exists { + seen[trimmed] = struct{}{} + all = append(all, trimmed) + } + } + } + + return all +} diff --git a/pkg/config/migration_integration_test.go b/pkg/config/migration_integration_test.go index bc8160967..b180dda90 100644 --- a/pkg/config/migration_integration_test.go +++ b/pkg/config/migration_integration_test.go @@ -681,3 +681,473 @@ web: t.Error("Discord token not preserved in .security.yml file") } } + +// --------------------------------------------------------------------------- +// V1 → V2 migration tests +// --------------------------------------------------------------------------- + +// TestMigrateModelEnabled_APIKeysInferredEnabled verifies that models with API keys +// are marked as enabled during V1→V2 migration. +func TestMigrateModelEnabled_APIKeysInferredEnabled(t *testing.T) { + v1 := &configV1{Config: Config{ + ModelList: []*ModelConfig{ + {ModelName: "gpt-4", Model: "openai/gpt-4", APIKeys: SimpleSecureStrings("sk-test")}, + {ModelName: "claude", Model: "anthropic/claude", APIKeys: SimpleSecureStrings("sk-ant")}, + }, + }} + v1.migrateModelEnabled() + for _, m := range v1.ModelList { + if !m.Enabled { + t.Errorf("model %q with API key should be enabled", m.ModelName) + } + } +} + +// TestMigrateModelEnabled_LocalModelInferredEnabled verifies that the reserved +// "local-model" entry is enabled even without API keys. +func TestMigrateModelEnabled_LocalModelInferredEnabled(t *testing.T) { + v1 := &configV1{Config: Config{ + ModelList: []*ModelConfig{ + {ModelName: "local-model", Model: "vllm/custom-model", APIBase: "http://localhost:8000/v1"}, + }, + }} + v1.migrateModelEnabled() + if !v1.ModelList[0].Enabled { + t.Error("local-model should be enabled") + } +} + +// TestMigrateModelEnabled_NoKeyStaysDisabled verifies that models without API keys +// and not named "local-model" remain disabled. +func TestMigrateModelEnabled_NoKeyStaysDisabled(t *testing.T) { + v1 := &configV1{Config: Config{ + ModelList: []*ModelConfig{ + {ModelName: "gpt-4", Model: "openai/gpt-4"}, + {ModelName: "claude", Model: "anthropic/claude"}, + }, + }} + v1.migrateModelEnabled() + for _, m := range v1.ModelList { + if m.Enabled { + t.Errorf("model %q without API key should stay disabled", m.ModelName) + } + } +} + +// TestMigrateModelEnabled_ExplicitEnabledPreserved verifies that a model with +// explicitly enabled=true is NOT overridden by the migration. +func TestMigrateModelEnabled_ExplicitEnabledPreserved(t *testing.T) { + v1 := &configV1{Config: Config{ + ModelList: []*ModelConfig{ + {ModelName: "gpt-4", Model: "openai/gpt-4", APIKeys: SimpleSecureStrings("sk-test"), Enabled: true}, + }, + }} + v1.migrateModelEnabled() + if !v1.ModelList[0].Enabled { + t.Error("explicitly enabled model should remain enabled") + } +} + +// TestMigrateModelEnabled_ExplicitDisabledNotOverridden verifies that a model with +// explicitly enabled=false and API keys gets enabled during migration. +// Note: since Go's zero value for bool is false and JSON omitempty omits false, +// migration cannot distinguish "explicitly false" from "field absent". Both cases +// get the same inference treatment. +func TestMigrateModelEnabled_ExplicitDisabledNotOverridden(t *testing.T) { + v1 := &configV1{Config: Config{ + ModelList: []*ModelConfig{ + {ModelName: "gpt-4", Model: "openai/gpt-4", APIKeys: SimpleSecureStrings("sk-test"), Enabled: false}, + }, + }} + v1.migrateModelEnabled() + // Even though Enabled was set to false, migration infers it as true because + // the migration cannot distinguish from a missing field (both are zero value). + if !v1.ModelList[0].Enabled { + t.Error("model with API key should be enabled by migration inference") + } +} + +// TestMigrateModelEnabled_Mixed verifies a mix of models. +func TestMigrateModelEnabled_Mixed(t *testing.T) { + v1 := &configV1{Config: Config{ + ModelList: []*ModelConfig{ + {ModelName: "with-key", Model: "openai/gpt-4", APIKeys: SimpleSecureStrings("sk-test")}, + {ModelName: "no-key", Model: "openai/gpt-4"}, + {ModelName: "local-model", Model: "vllm/custom"}, + { + ModelName: "disabled-explicit", + Model: "openai/gpt-4", + APIKeys: SimpleSecureStrings("sk-test"), + Enabled: false, + }, + }, + }} + v1.migrateModelEnabled() + + assertEnabled := func(name string, want bool) { + for _, m := range v1.ModelList { + if m.ModelName == name { + if m.Enabled != want { + t.Errorf("model %q: Enabled=%v, want %v", name, m.Enabled, want) + } + return + } + } + t.Errorf("model %q not found", name) + } + + assertEnabled("with-key", true) + assertEnabled("no-key", false) + assertEnabled("local-model", true) + assertEnabled("disabled-explicit", true) // false is zero value, migration infers from API key +} + +// TestMigrateChannelConfigs_DiscordMentionOnly verifies Discord mention_only migration. +func TestMigrateChannelConfigs_DiscordMentionOnly(t *testing.T) { + v1 := &configV1{Config: Config{ + Channels: ChannelsConfig{ + Discord: DiscordConfig{ + MentionOnly: true, + }, + }, + }} + v1.migrateChannelConfigs() + if !v1.Channels.Discord.GroupTrigger.MentionOnly { + t.Error("Discord GroupTrigger.MentionOnly should be set to true") + } +} + +// TestMigrateChannelConfigs_DiscordAlreadyMigrated is a no-op test. +func TestMigrateChannelConfigs_DiscordAlreadyMigrated(t *testing.T) { + v1 := &configV1{Config: Config{ + Channels: ChannelsConfig{ + Discord: DiscordConfig{ + GroupTrigger: GroupTriggerConfig{MentionOnly: true}, + }, + }, + }} + v1.migrateChannelConfigs() +} + +// TestMigrateChannelConfigs_OneBotPrefix verifies OneBot prefix migration. +func TestMigrateChannelConfigs_OneBotPrefix(t *testing.T) { + v1 := &configV1{Config: Config{ + Channels: ChannelsConfig{ + OneBot: OneBotConfig{ + GroupTriggerPrefix: []string{"/"}, + }, + }, + }} + v1.migrateChannelConfigs() + if len(v1.Channels.OneBot.GroupTrigger.Prefixes) != 1 || v1.Channels.OneBot.GroupTrigger.Prefixes[0] != "/" { + t.Errorf("OneBot GroupTrigger.Prefixes = %v, want [\"/\"]", v1.Channels.OneBot.GroupTrigger.Prefixes) + } +} + +// TestMigrateConfigV1_Combined verifies that configV1.Migrate applies both migrations. +func TestMigrateConfigV1_Combined(t *testing.T) { + v1 := &configV1{Config: Config{ + ModelList: []*ModelConfig{ + {ModelName: "gpt-4", Model: "openai/gpt-4", APIKeys: SimpleSecureStrings("sk-test")}, + }, + Channels: ChannelsConfig{ + Discord: DiscordConfig{MentionOnly: true}, + }, + }} + result, err := v1.Migrate() + if err != nil { + t.Fatalf("Migrate: %v", err) + } + + if !result.ModelList[0].Enabled { + t.Error("model with API key should be enabled after V1→V2 migration") + } + if !result.Channels.Discord.GroupTrigger.MentionOnly { + t.Error("Discord mention_only should be migrated after V1→V2 migration") + } +} + +// TestLoadConfig_V1ToV2Migration verifies end-to-end V1→V2 config migration +// through LoadConfig, including Enabled field inference and version bump. +func TestLoadConfig_V1ToV2Migration(t *testing.T) { + tmpDir := t.TempDir() + configPath := filepath.Join(tmpDir, "config.json") + + // Write a V1 config with model_list but no "enabled" field + v1Config := `{ + "version": 1, + "model_list": [ + { + "model_name": "gpt-4", + "model": "openai/gpt-4" + }, + { + "model_name": "local-model", + "model": "vllm/custom-model", + "api_base": "http://localhost:8000/v1" + } + ], + "channels": { + "discord": { + "mention_only": true + } + }, + "gateway": {"host": "127.0.0.1", "port": 18790} + }` + + if err := os.WriteFile(configPath, []byte(v1Config), 0o600); err != nil { + t.Fatalf("WriteFile: %v", err) + } + + cfg, err := LoadConfig(configPath) + if err != nil { + t.Fatalf("LoadConfig: %v", err) + } + + // Version should be bumped to 2 + if cfg.Version != CurrentVersion { + t.Errorf("Version = %d, want %d", cfg.Version, CurrentVersion) + } + + // gpt-4 has no API key → disabled + gpt4, err := cfg.GetModelConfig("gpt-4") + if err != nil { + t.Fatalf("GetModelConfig(gpt-4): %v", err) + } + if gpt4.Enabled { + t.Error("gpt-4 without API key should be disabled after migration") + } + + // local-model → enabled + local, err := cfg.GetModelConfig("local-model") + if err != nil { + t.Fatalf("GetModelConfig(local-model): %v", err) + } + if !local.Enabled { + t.Error("local-model should be enabled after migration") + } + + // Discord channel config should be migrated + if !cfg.Channels.Discord.GroupTrigger.MentionOnly { + t.Error("Discord mention_only should be migrated to group_trigger.mention_only") + } + + // Verify backup was created with date suffix + entries, err := os.ReadDir(tmpDir) + if err != nil { + t.Fatalf("ReadDir: %v", err) + } + var hasBackup bool + for _, e := range entries { + if matched, _ := filepath.Match("config.json.20*.bak", e.Name()); matched { + hasBackup = true + break + } + } + if !hasBackup { + t.Error("expected backup file with date suffix to be created") + } + + // Verify the saved config on disk now has version 2 + saved, err := os.ReadFile(configPath) + if err != nil { + t.Fatalf("ReadFile saved config: %v", err) + } + var versionCheck struct { + Version int `json:"version"` + } + if err := json.Unmarshal(saved, &versionCheck); err != nil { + t.Fatalf("Unmarshal saved config: %v", err) + } + if versionCheck.Version != 2 { + t.Errorf("saved config version = %d, want 2", versionCheck.Version) + } +} + +// TestLoadConfig_V1WithAPIKeysInferredEnabled verifies that V1 configs with +// API keys in the security file get Enabled=true after migration. +func TestLoadConfig_V1WithAPIKeysInferredEnabled(t *testing.T) { + tmpDir := t.TempDir() + configPath := filepath.Join(tmpDir, "config.json") + secPath := securityPath(configPath) + + v1Config := `{ + "version": 1, + "model_list": [ + {"model_name": "gpt-4", "model": "openai/gpt-4"}, + {"model_name": "claude", "model": "anthropic/claude"} + ], + "gateway": {"host": "127.0.0.1", "port": 18790} + }` + + securityConfig := `model_list: + gpt-4:0: + api_keys: + - "sk-gpt-key" + claude:0: + api_keys: + - "sk-claude-key" +` + + if err := os.WriteFile(configPath, []byte(v1Config), 0o600); err != nil { + t.Fatalf("WriteFile: %v", err) + } + if err := os.WriteFile(secPath, []byte(securityConfig), 0o600); err != nil { + t.Fatalf("WriteFile security: %v", err) + } + + cfg, err := LoadConfig(configPath) + if err != nil { + t.Fatalf("LoadConfig: %v", err) + } + + for _, m := range cfg.ModelList { + if !m.Enabled { + t.Errorf("model %q with API key in security file should be enabled", m.ModelName) + } + } +} + +// TestLoadConfig_V2DirectLoad verifies that V2 configs load directly without +// running any migration. +func TestLoadConfig_V2DirectLoad(t *testing.T) { + tmpDir := t.TempDir() + configPath := filepath.Join(tmpDir, "config.json") + + v2Config := `{ + "version": 2, + "model_list": [ + { + "model_name": "gpt-4", + "model": "openai/gpt-4", + "enabled": true + }, + { + "model_name": "claude", + "model": "anthropic/claude" + } + ], + "gateway": {"host": "127.0.0.1", "port": 18790} + }` + + if err := os.WriteFile(configPath, []byte(v2Config), 0o600); err != nil { + t.Fatalf("WriteFile: %v", err) + } + + cfg, err := LoadConfig(configPath) + if err != nil { + t.Fatalf("LoadConfig: %v", err) + } + + if cfg.Version != 2 { + t.Errorf("Version = %d, want 2", cfg.Version) + } + + gpt4, _ := cfg.GetModelConfig("gpt-4") + if !gpt4.Enabled { + t.Error("gpt-4 with explicit enabled=true should remain enabled") + } + + claude, _ := cfg.GetModelConfig("claude") + if claude.Enabled { + t.Error("claude without enabled field should be false (no migration for V2)") + } + + // No backup should be created for V2 load + entries, _ := os.ReadDir(tmpDir) + for _, e := range entries { + if matched, _ := filepath.Match("config.json.*.bak", e.Name()); matched { + t.Errorf("V2 load should not create backup, but found %q", e.Name()) + } + } +} + +// TestLoadConfig_V0MigrateProducesV2 verifies that V0→V2 migration produces +// correct Enabled fields and version. +func TestLoadConfig_V0MigrateProducesV2(t *testing.T) { + tmpDir := t.TempDir() + configPath := filepath.Join(tmpDir, "config.json") + + v0Config := `{ + "model_list": [ + { + "model_name": "gpt-4", + "model": "openai/gpt-4", + "api_key": "sk-test" + }, + { + "model_name": "claude", + "model": "anthropic/claude" + }, + { + "model_name": "local-model", + "model": "vllm/custom-model" + } + ], + "gateway": {"host": "127.0.0.1", "port": 18790} + }` + + if err := os.WriteFile(configPath, []byte(v0Config), 0o600); err != nil { + t.Fatalf("WriteFile: %v", err) + } + + cfg, err := LoadConfig(configPath) + if err != nil { + t.Fatalf("LoadConfig: %v", err) + } + + if cfg.Version != CurrentVersion { + t.Errorf("Version = %d, want %d", cfg.Version, CurrentVersion) + } + + // Check enabled status + modelEnabled := func(name string) bool { + m, err := cfg.GetModelConfig(name) + if err != nil { + return false + } + return m.Enabled + } + + if !modelEnabled("gpt-4") { + t.Error("gpt-4 with API key from V0 should be enabled") + } + if modelEnabled("claude") { + t.Error("claude without API key from V0 should be disabled") + } + if !modelEnabled("local-model") { + t.Error("local-model from V0 should be enabled") + } +} + +// TestLoadConfig_UnsupportedVersion verifies that unsupported versions return an error. +func TestLoadConfig_UnsupportedVersion(t *testing.T) { + tmpDir := t.TempDir() + configPath := filepath.Join(tmpDir, "config.json") + + badConfig := `{"version": 99, "gateway": {"host": "127.0.0.1", "port": 18790}}` + if err := os.WriteFile(configPath, []byte(badConfig), 0o600); err != nil { + t.Fatalf("WriteFile: %v", err) + } + + _, err := LoadConfig(configPath) + if err == nil { + t.Fatal("LoadConfig should return error for unsupported version") + } + if !containsString(err.Error(), "unsupported config version") { + t.Errorf("error = %q, want 'unsupported config version'", err.Error()) + } +} + +func containsString(s, substr string) bool { + return len(s) >= len(substr) && searchString(s, substr) +} + +func searchString(s, substr string) bool { + for i := 0; i <= len(s)-len(substr); i++ { + if s[i:i+len(substr)] == substr { + return true + } + } + return false +} diff --git a/pkg/config/multikey_test.go b/pkg/config/multikey_test.go index e58c6dc9e..947e942da 100644 --- a/pkg/config/multikey_test.go +++ b/pkg/config/multikey_test.go @@ -345,7 +345,7 @@ func TestMergeAPIKeys(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - result := MergeAPIKeys(tt.apiKey, tt.apiKeys) + result := mergeAPIKeys(tt.apiKey, tt.apiKeys) if len(result) != len(tt.expected) { t.Fatalf("expected %d keys, got %d", len(tt.expected), len(result)) } diff --git a/pkg/config/security.go b/pkg/config/security.go index 79dd26e14..2414cd7fa 100644 --- a/pkg/config/security.go +++ b/pkg/config/security.go @@ -7,20 +7,16 @@ package config import ( "bytes" - "encoding/json" "fmt" "os" "path/filepath" "reflect" - "runtime" "strings" "sync" "gopkg.in/yaml.v3" - "github.com/sipeed/picoclaw/pkg/credential" "github.com/sipeed/picoclaw/pkg/fileutil" - "github.com/sipeed/picoclaw/pkg/logger" ) const ( @@ -66,7 +62,6 @@ func saveSecurityConfig(securityPath string, sec *Config) error { return fileutil.WriteFileAtomic(securityPath, buf.Bytes(), 0o600) } -// SensitiveDataCache caches the compiled regex for filtering sensitive data. // SensitiveDataCache caches the strings.Replacer for filtering sensitive data. // Computed once on first access via sync.Once. type SensitiveDataCache struct { @@ -178,234 +173,3 @@ func collectSensitive(v reflect.Value, values *[]string) { } } } - -const ( - notHere = `"[NOT_HERE]"` -) - -// SecureStrings is a slice of SecureString -type SecureStrings []*SecureString - -// Values returns the decrypted/resolved values -func (s *SecureStrings) Values() []string { - if s == nil { - return nil - } - keys := make([]string, len(*s)) - for i, k := range *s { - keys[i] = k.String() - } - return unique(keys) -} - -func SimpleSecureStrings(val ...string) SecureStrings { - val = unique(val) - vv := make(SecureStrings, len(val)) - for i, s := range val { - vv[i] = NewSecureString(s) - } - return vv -} - -// unique returns a new slice with duplicate elements removed. -func unique[T comparable](input []T) []T { - m := make(map[T]struct{}) - var result []T - for _, v := range input { - if _, ok := m[v]; !ok { - m[v] = struct{}{} - result = append(result, v) - } - } - return result -} - -func (s SecureStrings) MarshalJSON() ([]byte, error) { - return []byte(notHere), nil -} - -func (s *SecureStrings) UnmarshalJSON(value []byte) error { - if string(value) == notHere { - return nil - } - var v []*SecureString - err := json.Unmarshal(value, &v) - if err != nil { - return err - } - *s = v - return nil -} - -// SecureString the string value that can be decrypted or resolved -// -//nolint:recvcheck -type SecureString struct { - resolved string // Decrypted/resolved value returned by String() - raw string // Persisted raw value (enc://, file://, or plaintext) -} - -func callerFromYaml() bool { - _, file, _, ok := runtime.Caller(2) - if ok { - d := filepath.Dir(file) - // check the caller is from yaml.v - if !strings.Contains(d, "yaml.v") { - return true - } - } - return false -} - -// IsZero returns true if the SecureString is empty -// if caller not yaml, just return true for prevent marshal this field -func (s SecureString) IsZero() bool { - if callerFromYaml() { - return true - } - return s.resolved == "" -} - -func NewSecureString(value string) *SecureString { - s := &SecureString{} - if err := s.fromRaw(value); err != nil { - logger.Warn(fmt.Sprintf("NewSecureString.fromRaw error: %s", err)) - } - return s -} - -func (s *SecureString) String() string { - if s == nil { - return "" - } - return s.resolved -} - -func (s *SecureString) Set(value string) *SecureString { - s.resolved = value - s.raw = "" - return s -} - -func (s SecureString) MarshalJSON() ([]byte, error) { - return []byte(notHere), nil -} - -func (s *SecureString) UnmarshalJSON(value []byte) error { - if string(value) == notHere { - return nil - } - var v string - if err := json.Unmarshal(value, &v); err != nil { - return err - } - return s.fromRaw(v) -} - -func (s SecureString) MarshalYAML() (any, error) { - // Preserve raw value if it is already a reference (enc:// or file://) - if strings.HasPrefix(s.raw, credential.EncScheme) || strings.HasPrefix(s.raw, credential.FileScheme) { - return s.raw, nil - } - // If resolved is a reference format (e.g. set via Set), copy back to raw - if strings.HasPrefix(s.resolved, credential.EncScheme) || strings.HasPrefix(s.resolved, credential.FileScheme) { - s.raw = s.resolved - return s.raw, nil - } - // Try to encrypt the resolved value - if passphrase := credential.PassphraseProvider(); passphrase != "" { - encrypted, err := credential.Encrypt(passphrase, "", s.resolved) - if err != nil { - logger.Errorf("Encrypt error: %v", err) - return nil, err - } - s.raw = encrypted - } else { - s.raw = s.resolved - } - return s.raw, nil -} - -func (s *SecureString) UnmarshalYAML(value *yaml.Node) error { - return s.fromRaw(value.Value) -} - -func (s *SecureString) fromRaw(v string) error { - s.raw = v - vv, err := resolveKey(v) - if err != nil { - return err - } - s.resolved = vv - return nil -} - -var ( - secResolverMu sync.RWMutex - secResolver *credential.Resolver -) - -func updateResolver(path string) { - secResolverMu.Lock() - defer secResolverMu.Unlock() - secResolver = credential.NewResolver(path) -} - -func resolveKey(v string) (string, error) { - secResolverMu.RLock() - resolver := secResolver - secResolverMu.RUnlock() - if resolver == nil { - resolver = credential.NewResolver("") - } - if strings.HasPrefix(v, "enc://") || strings.HasPrefix(v, "file://") { - decrypted, err := resolver.Resolve(v) - if err != nil { - logger.Errorf("Resolve error: %v", err) - return "", err - } - return decrypted, nil - } - return v, nil -} - -func (s *SecureString) UnmarshalText(text []byte) error { - v := string(text) - return s.fromRaw(v) -} - -type SecureModelList []*ModelConfig - -func (v *SecureModelList) UnmarshalYAML(value *yaml.Node) error { - mm := make(map[string]*ModelConfig) - if err := value.Decode(&mm); err != nil { - logger.Errorf("Decode error: %v", err) - return err - } - nameList := toNameIndex(*v) - for i, m := range *v { - sec := mm[nameList[i]] - if sec == nil { - sec = mm[m.ModelName] - } - if sec != nil { - m.APIKeys = sec.APIKeys - } - } - return nil -} - -func (v SecureModelList) MarshalYAML() (any, error) { - type onlySecureData struct { - APIKeys SecureStrings `yaml:"api_keys,omitempty"` - } - mm := make(map[string]onlySecureData) - nameList := toNameIndex(v) - for i, m := range v { - mm[nameList[i]] = onlySecureData{ - APIKeys: m.APIKeys, - } - } - - return mm, nil -} diff --git a/pkg/config/security_test.go b/pkg/config/security_test.go index 834ba3606..548a6dc87 100644 --- a/pkg/config/security_test.go +++ b/pkg/config/security_test.go @@ -15,8 +15,6 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "gopkg.in/yaml.v3" - - "github.com/sipeed/picoclaw/pkg/credential" ) func TestSecurityConfig(t *testing.T) { @@ -227,134 +225,3 @@ skills: assert.Equal(t, "abc", cfg2.Tools.Web.Brave.APIKeys[1].raw) }) } - -func TestLoadSecurityValue(t *testing.T) { - type valueStruct struct { - Url string `json:"url,omitempty" yaml:"-"` - Token *SecureString `json:"token,omitempty" yaml:"token,omitempty" env:"PICO_TOKEN"` - ApiKeys SecureStrings `json:"api_keys,omitempty" yaml:"api_keys,omitempty" env:"PICO_API_KEYS"` - } - - type testStruct struct { - Pico *valueStruct `json:"pico,omitempty" yaml:"pico,omitempty"` - } - - v1 := &testStruct{ - Pico: &valueStruct{ - Url: "https://example.com", - Token: NewSecureString("token1"), - ApiKeys: SecureStrings{NewSecureString("api-key1"), NewSecureString("api-key2")}, - }, - } - bytes, err := yaml.Marshal(v1) - assert.NoError(t, err) - jsonBytes, err := json.Marshal(v1) - assert.NoError(t, err) - const want = `pico: - token: token1 - api_keys: - - api-key1 - - api-key2 -` - const jsonPost = `{"pico":{"url":"https://example.com","token":"token0"}}` - v0 := &testStruct{} - err = json.Unmarshal([]byte(jsonPost), v0) - assert.NoError(t, err) - assert.Equal(t, "https://example.com", v0.Pico.Url) - assert.Equal(t, "token0", v0.Pico.Token.String()) - - const jsonWant = `{"pico":{"url":"https://example.com","token":"[NOT_HERE]","api_keys":"[NOT_HERE]"}}` - assert.Equal(t, want, string(bytes)) - assert.Equal(t, jsonWant, string(jsonBytes)) - - v2 := &testStruct{} - err = json.Unmarshal(jsonBytes, v2) - assert.NoError(t, err) - err = yaml.Unmarshal(bytes, v2) - assert.NoError(t, err) - assert.Equal(t, "https://example.com", v2.Pico.Url) - if v2.Pico.Token != nil { - assert.Equal(t, "token1", v2.Pico.Token.String()) - assert.Equal(t, "token1", v2.Pico.Token.raw) - } - - v2.Pico.Token = NewSecureString("token1") - v2.Pico.Token.raw = "abc" - err = yaml.Unmarshal(bytes, v2) - assert.NoError(t, err) - assert.Equal(t, "token1", v2.Pico.Token.raw) - - os.Setenv("PICO_TOKEN", "token_env") - err = env.Parse(v2) - assert.NoError(t, err) - assert.NotNil(t, v2.Pico.Token) - assert.Equal(t, "token1", v2.Pico.Token.String()) - - v3 := &testStruct{Pico: &valueStruct{}} - err = env.Parse(v3) - assert.NoError(t, err) - if v3.Pico.Token != nil { - assert.Equal(t, "token_env", v3.Pico.Token.String()) - } - - type toolsStruct struct { - Pico valueStruct `json:"pico,omitempty" yaml:"pico,omitempty"` - } - - type testStruct2 struct { - Tools toolsStruct `json:"tools,omitempty" yaml:",inline"` - } - - v4 := &testStruct2{ - Tools: toolsStruct{ - Pico: valueStruct{ - Url: "https://example.com", - Token: NewSecureString("token1"), - ApiKeys: SecureStrings{NewSecureString("api-key1"), NewSecureString("api-key2")}, - }, - }, - } - bytes, err = yaml.Marshal(v4) - assert.NoError(t, err) - assert.Equal(t, want, string(bytes)) - jsonBytes, err = json.Marshal(v4) - assert.NoError(t, err) - assert.Equal( - t, - `{"tools":{"pico":{"url":"https://example.com","token":"[NOT_HERE]","api_keys":"[NOT_HERE]"}}}`, - string(jsonBytes), - ) - - v5 := &testStruct2{} - err = json.Unmarshal(jsonBytes, v5) - assert.NoError(t, err) - assert.Equal(t, "https://example.com", v5.Tools.Pico.Url) - err = yaml.Unmarshal(bytes, v5) - assert.NoError(t, err) - assert.NotNil(t, v5.Tools.Pico.Token) - assert.Equal(t, "token1", v5.Tools.Pico.Token.raw) - - dir := t.TempDir() - sshKeyPath := filepath.Join(dir, "picoclaw_ed25519.key") - if err = os.WriteFile(sshKeyPath, []byte("fake-ssh-key-material\n"), 0o600); err != nil { - t.Fatalf("setup: %v", err) - } - - const passphrase = "test-passphrase-32bytes-long-ok!" - - t.Setenv(credential.SSHKeyPathEnvVar, sshKeyPath) - - t.Setenv(credential.PassphraseEnvVar, passphrase) - - v5.Tools.Pico.Token.Set("newtoken1") - v5.Tools.Pico.ApiKeys[0].Set("newapi-key1") - bytes, err = yaml.Marshal(v5) - assert.NoError(t, err) - t.Logf("yaml: %s", string(bytes)) - - v6 := &testStruct2{} - err = yaml.Unmarshal(bytes, v6) - assert.NoError(t, err) - assert.NotNil(t, v6.Tools.Pico.Token) - assert.Equal(t, "newtoken1", v6.Tools.Pico.Token.String()) -} diff --git a/web/backend/api/models.go b/web/backend/api/models.go index 38a55948b..fd3cd85b7 100644 --- a/web/backend/api/models.go +++ b/web/backend/api/models.go @@ -40,6 +40,7 @@ type modelResponse struct { ThinkingLevel string `json:"thinking_level,omitempty"` ExtraBody map[string]any `json:"extra_body,omitempty"` // Meta + Enabled bool `json:"enabled"` Configured bool `json:"configured"` IsDefault bool `json:"is_default"` IsVirtual bool `json:"is_virtual"` @@ -85,6 +86,7 @@ func (h *Handler) handleListModels(w http.ResponseWriter, r *http.Request) { RequestTimeout: m.RequestTimeout, ThinkingLevel: m.ThinkingLevel, ExtraBody: m.ExtraBody, + Enabled: m.Enabled, Configured: configured[i], IsDefault: m.ModelName == defaultModel, IsVirtual: m.IsVirtual(),