mirror of
https://github.com/sipeed/picoclaw.git
synced 2026-06-12 18:08:54 +00:00
Merge pull request #2184 from cytown/config
refactor config and add ModelConfig.Enabled
This commit is contained in:
+54
-128
@@ -7,8 +7,8 @@ import (
|
||||
"math/rand"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/caarlos0/env/v11"
|
||||
|
||||
@@ -20,89 +20,8 @@ import (
|
||||
// rrCounter is a global counter for round-robin load balancing across models.
|
||||
var rrCounter atomic.Uint64
|
||||
|
||||
// FlexibleStringSlice is a []string that also accepts JSON numbers,
|
||||
// so allow_from can contain both "123" and 123.
|
||||
// It also supports parsing comma-separated strings from environment variables,
|
||||
// including both English (,) and Chinese (,) commas.
|
||||
type FlexibleStringSlice []string
|
||||
|
||||
func (f *FlexibleStringSlice) UnmarshalJSON(data []byte) error {
|
||||
// Accept a single JSON string for convenience, e.g.:
|
||||
// "text": "Thinking..."
|
||||
var singleString string
|
||||
if err := json.Unmarshal(data, &singleString); err == nil {
|
||||
*f = FlexibleStringSlice{singleString}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Accept a single JSON number too, to keep symmetry with mixed allow_from
|
||||
// payloads that may contain numeric identifiers.
|
||||
var singleNumber float64
|
||||
if err := json.Unmarshal(data, &singleNumber); err == nil {
|
||||
*f = FlexibleStringSlice{fmt.Sprintf("%.0f", singleNumber)}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Try []string first
|
||||
var ss []string
|
||||
if err := json.Unmarshal(data, &ss); err == nil {
|
||||
*f = ss
|
||||
return nil
|
||||
}
|
||||
|
||||
// Try []interface{} to handle mixed types
|
||||
var raw []any
|
||||
if err := json.Unmarshal(data, &raw); err != nil {
|
||||
var s string
|
||||
// fail over to compatible to old format string
|
||||
if err = json.Unmarshal(data, &s); err != nil {
|
||||
return err
|
||||
}
|
||||
*f = []string{s}
|
||||
return nil
|
||||
}
|
||||
|
||||
result := make([]string, 0, len(raw))
|
||||
for _, v := range raw {
|
||||
switch val := v.(type) {
|
||||
case string:
|
||||
result = append(result, val)
|
||||
case float64:
|
||||
result = append(result, fmt.Sprintf("%.0f", val))
|
||||
default:
|
||||
result = append(result, fmt.Sprintf("%v", val))
|
||||
}
|
||||
}
|
||||
*f = result
|
||||
return nil
|
||||
}
|
||||
|
||||
// UnmarshalText implements encoding.TextUnmarshaler to support env variable parsing.
|
||||
// It handles comma-separated values with both English (,) and Chinese (,) commas.
|
||||
func (f *FlexibleStringSlice) UnmarshalText(text []byte) error {
|
||||
if len(text) == 0 {
|
||||
*f = nil
|
||||
return nil
|
||||
}
|
||||
|
||||
s := string(text)
|
||||
// Replace Chinese comma with English comma, then split
|
||||
s = strings.ReplaceAll(s, ",", ",")
|
||||
parts := strings.Split(s, ",")
|
||||
|
||||
result := make([]string, 0, len(parts))
|
||||
for _, part := range parts {
|
||||
part = strings.TrimSpace(part)
|
||||
if part != "" {
|
||||
result = append(result, part)
|
||||
}
|
||||
}
|
||||
*f = result
|
||||
return nil
|
||||
}
|
||||
|
||||
// CurrentVersion is the latest config schema version
|
||||
const CurrentVersion = 1
|
||||
const CurrentVersion = 2
|
||||
|
||||
// Config is the current config structure with version support
|
||||
type Config struct {
|
||||
@@ -675,6 +594,11 @@ type ModelConfig struct {
|
||||
|
||||
APIKeys SecureStrings `json:"api_keys,omitzero" yaml:"api_keys,omitempty"` // API authentication keys (multiple keys for failover)
|
||||
|
||||
// Enabled indicates whether this model entry is active. When omitted in
|
||||
// existing configs, the field is inferred during load: models with API keys
|
||||
// or the reserved "local-model" name are auto-enabled.
|
||||
Enabled bool `json:"enabled,omitempty" yaml:"enabled,omitempty"`
|
||||
|
||||
// isVirtual marks this model as a virtual model generated from multi-key expansion.
|
||||
// Virtual models should not be persisted to config files.
|
||||
isVirtual bool
|
||||
@@ -1047,6 +971,35 @@ func LoadConfig(path string) (*Config, error) {
|
||||
defer func(cfg *Config) {
|
||||
_ = SaveConfig(path, cfg)
|
||||
}(cfg)
|
||||
case 1:
|
||||
// V1→V2 migration: infer Enabled and migrate channel config fields
|
||||
logger.InfoF("config migrate start", map[string]any{"from": versionInfo.Version, "to": CurrentVersion})
|
||||
cfg, err = loadConfig(data)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
secPath := securityPath(path)
|
||||
err = loadSecurityConfig(cfg, secPath)
|
||||
if err != nil && !errors.Is(err, os.ErrNotExist) {
|
||||
return nil, fmt.Errorf("failed to load security config: %w", err)
|
||||
}
|
||||
|
||||
oldCfg := &configV1{Config: *cfg}
|
||||
cfg, err = oldCfg.Migrate()
|
||||
if err != nil {
|
||||
logger.ErrorF("config migrate fail", map[string]any{"from": versionInfo.Version, "to": CurrentVersion})
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = makeBackup(path)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
defer func(cfg *Config) {
|
||||
_ = SaveConfig(path, cfg)
|
||||
}(cfg)
|
||||
logger.InfoF("config migrate success", map[string]any{"from": versionInfo.Version, "to": CurrentVersion})
|
||||
case CurrentVersion:
|
||||
// Current version
|
||||
cfg, err = loadConfig(data)
|
||||
@@ -1064,18 +1017,15 @@ func LoadConfig(path string) (*Config, error) {
|
||||
return nil, fmt.Errorf("unsupported config version: %d", versionInfo.Version)
|
||||
}
|
||||
|
||||
if err := env.Parse(cfg); err != nil {
|
||||
if err = env.Parse(cfg); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Expand multi-key configs into separate entries for key-level failover
|
||||
cfg.ModelList = expandMultiKeyModels(cfg.ModelList)
|
||||
|
||||
// Migrate legacy channel config fields to new unified structures
|
||||
cfg.migrateChannelConfigs()
|
||||
|
||||
// Validate model_list for uniqueness and required fields
|
||||
if err := cfg.ValidateModelList(); err != nil {
|
||||
if err = cfg.ValidateModelList(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -1097,12 +1047,22 @@ func makeBackup(path string) error {
|
||||
if _, err := os.Stat(path); os.IsNotExist(err) {
|
||||
return nil
|
||||
}
|
||||
// Create backup of the config file before migration
|
||||
bakPath := path + ".bak"
|
||||
dateSuffix := time.Now().Format(".20060102.bak")
|
||||
// Backup config file
|
||||
bakPath := path + dateSuffix
|
||||
if err := fileutil.CopyFile(path, bakPath, 0o600); err != nil {
|
||||
logger.ErrorF("failed to create config backup", map[string]any{"error": err})
|
||||
return fmt.Errorf("failed to create config backup: %w", err)
|
||||
}
|
||||
// Backup security config file
|
||||
secPath := securityPath(path)
|
||||
if _, err := os.Stat(secPath); err == nil {
|
||||
secBakPath := secPath + dateSuffix
|
||||
if secErr := fileutil.CopyFile(secPath, secBakPath, 0o600); secErr != nil {
|
||||
logger.ErrorF("failed to create security backup", map[string]any{"error": secErr})
|
||||
return fmt.Errorf("failed to create security backup: %w", secErr)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -1118,19 +1078,6 @@ func toNameIndex(list []*ModelConfig) []string {
|
||||
return nameList
|
||||
}
|
||||
|
||||
func (c *Config) migrateChannelConfigs() {
|
||||
// Discord: mention_only -> group_trigger.mention_only
|
||||
if c.Channels.Discord.MentionOnly && !c.Channels.Discord.GroupTrigger.MentionOnly {
|
||||
c.Channels.Discord.GroupTrigger.MentionOnly = true
|
||||
}
|
||||
|
||||
// OneBot: group_trigger_prefix -> group_trigger.prefixes
|
||||
if len(c.Channels.OneBot.GroupTriggerPrefix) > 0 &&
|
||||
len(c.Channels.OneBot.GroupTrigger.Prefixes) == 0 {
|
||||
c.Channels.OneBot.GroupTrigger.Prefixes = c.Channels.OneBot.GroupTriggerPrefix
|
||||
}
|
||||
}
|
||||
|
||||
func SaveConfig(path string, cfg *Config) error {
|
||||
if cfg.Version < CurrentVersion {
|
||||
cfg.Version = CurrentVersion
|
||||
@@ -1144,6 +1091,10 @@ func SaveConfig(path string, cfg *Config) error {
|
||||
}
|
||||
// Temporarily replace ModelList with filtered version for serialization
|
||||
originalModelList := cfg.ModelList
|
||||
defer func() {
|
||||
// Restore original ModelList after serialization
|
||||
cfg.ModelList = originalModelList
|
||||
}()
|
||||
cfg.ModelList = nonVirtualModels
|
||||
|
||||
if err := saveSecurityConfig(securityPath(path), cfg); err != nil {
|
||||
@@ -1152,8 +1103,6 @@ func SaveConfig(path string, cfg *Config) error {
|
||||
}
|
||||
|
||||
data, err := json.MarshalIndent(cfg, "", " ")
|
||||
// Restore original ModelList after serialization
|
||||
cfg.ModelList = originalModelList
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -1223,29 +1172,6 @@ func (c *Config) SecurityCopyFrom(path string) error {
|
||||
return loadSecurityConfig(c, securityPath(path))
|
||||
}
|
||||
|
||||
func MergeAPIKeys(apiKey string, apiKeys []string) []string {
|
||||
seen := make(map[string]struct{})
|
||||
var all []string
|
||||
|
||||
if k := strings.TrimSpace(apiKey); k != "" {
|
||||
if _, exists := seen[k]; !exists {
|
||||
seen[k] = struct{}{}
|
||||
all = append(all, k)
|
||||
}
|
||||
}
|
||||
|
||||
for _, k := range apiKeys {
|
||||
if trimmed := strings.TrimSpace(k); trimmed != "" {
|
||||
if _, exists := seen[trimmed]; !exists {
|
||||
seen[trimmed] = struct{}{}
|
||||
all = append(all, trimmed)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return all
|
||||
}
|
||||
|
||||
// expandMultiKeyModels expands ModelConfig entries with multiple API keys into
|
||||
// separate entries for key-level failover. Each key gets its own ModelConfig entry,
|
||||
// and the original entry's fallbacks are set up to chain through the expanded entries.
|
||||
|
||||
@@ -734,7 +734,8 @@ func (c *configV0) Migrate() (*Config, error) {
|
||||
// Convert []modelConfigV0 to []ModelConfig
|
||||
cfg.ModelList = make([]*ModelConfig, len(c.ModelList))
|
||||
for i, m := range c.ModelList {
|
||||
cfg.ModelList[i] = &ModelConfig{
|
||||
mergedKeys := toSecureStrings(mergeAPIKeys(m.APIKey, m.APIKeys))
|
||||
mc := &ModelConfig{
|
||||
ModelName: m.ModelName,
|
||||
Model: m.Model,
|
||||
APIBase: m.APIBase,
|
||||
@@ -747,8 +748,13 @@ func (c *configV0) Migrate() (*Config, error) {
|
||||
MaxTokensField: m.MaxTokensField,
|
||||
RequestTimeout: m.RequestTimeout,
|
||||
ThinkingLevel: m.ThinkingLevel,
|
||||
APIKeys: toSecureStrings(MergeAPIKeys(m.APIKey, m.APIKeys)),
|
||||
APIKeys: mergedKeys,
|
||||
}
|
||||
// Infer Enabled during V0→V1 migration
|
||||
if len(mergedKeys) > 0 || m.ModelName == "local-model" {
|
||||
mc.Enabled = true
|
||||
}
|
||||
cfg.ModelList[i] = mc
|
||||
}
|
||||
}
|
||||
|
||||
@@ -756,6 +762,52 @@ func (c *configV0) Migrate() (*Config, error) {
|
||||
return cfg, nil
|
||||
}
|
||||
|
||||
type configV1 struct {
|
||||
Config
|
||||
}
|
||||
|
||||
// Migrate applies V1→Current Version migrations to an already-loaded Config.
|
||||
//
|
||||
// It must be called AFTER loadSecurityConfig so that API keys (which live in
|
||||
// the security file) are available for the Enabled inference.
|
||||
func (c *configV1) Migrate() (*Config, error) {
|
||||
c.migrateModelEnabled()
|
||||
c.migrateChannelConfigs()
|
||||
return &c.Config, nil
|
||||
}
|
||||
|
||||
// migrateModelEnabled infers the Enabled field for models loaded from V1 configs
|
||||
// that predate the field (JSON where "enabled" is absent).
|
||||
//
|
||||
// Rules (only applied when Enabled has not been explicitly set by the user):
|
||||
// - Models with API keys are considered enabled.
|
||||
// - The reserved "local-model" entry is considered enabled.
|
||||
func (cfg *configV1) migrateModelEnabled() {
|
||||
for _, m := range cfg.ModelList {
|
||||
if m.Enabled {
|
||||
continue
|
||||
}
|
||||
if len(m.APIKeys) > 0 || m.ModelName == "local-model" {
|
||||
m.Enabled = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// migrateChannelConfigs migrates legacy channel config fields in a V1 Config
|
||||
// to the new unified structures.
|
||||
func (cfg *configV1) migrateChannelConfigs() {
|
||||
// Discord: mention_only -> group_trigger.mention_only
|
||||
if cfg.Channels.Discord.MentionOnly && !cfg.Channels.Discord.GroupTrigger.MentionOnly {
|
||||
cfg.Channels.Discord.GroupTrigger.MentionOnly = true
|
||||
}
|
||||
|
||||
// OneBot: group_trigger_prefix -> group_trigger.prefixes
|
||||
if len(cfg.Channels.OneBot.GroupTriggerPrefix) > 0 &&
|
||||
len(cfg.Channels.OneBot.GroupTrigger.Prefixes) == 0 {
|
||||
cfg.Channels.OneBot.GroupTrigger.Prefixes = cfg.Channels.OneBot.GroupTriggerPrefix
|
||||
}
|
||||
}
|
||||
|
||||
type webToolsConfigV0 struct {
|
||||
ToolConfig ` envPrefix:"PICOCLAW_TOOLS_WEB_"`
|
||||
Brave braveConfigV0 ` json:"brave"`
|
||||
@@ -791,7 +843,7 @@ func (v *braveConfigV0) ToBraveConfig() BraveConfig {
|
||||
return BraveConfig{
|
||||
Enabled: v.Enabled,
|
||||
MaxResults: v.MaxResults,
|
||||
APIKeys: toSecureStrings(MergeAPIKeys(v.APIKey, v.APIKeys)),
|
||||
APIKeys: toSecureStrings(mergeAPIKeys(v.APIKey, v.APIKeys)),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -808,7 +860,7 @@ func (v *tavilyConfigV0) ToTavilyConfig() TavilyConfig {
|
||||
Enabled: v.Enabled,
|
||||
BaseURL: v.BaseURL,
|
||||
MaxResults: v.MaxResults,
|
||||
APIKeys: toSecureStrings(MergeAPIKeys(v.APIKey, v.APIKeys)),
|
||||
APIKeys: toSecureStrings(mergeAPIKeys(v.APIKey, v.APIKeys)),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -823,7 +875,7 @@ func (v *perplexityConfigV0) ToPerplexityConfig() PerplexityConfig {
|
||||
return PerplexityConfig{
|
||||
Enabled: v.Enabled,
|
||||
MaxResults: v.MaxResults,
|
||||
APIKeys: toSecureStrings(MergeAPIKeys(v.APIKey, v.APIKeys)),
|
||||
APIKeys: toSecureStrings(mergeAPIKeys(v.APIKey, v.APIKeys)),
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -0,0 +1,327 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"gopkg.in/yaml.v3"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/credential"
|
||||
"github.com/sipeed/picoclaw/pkg/logger"
|
||||
)
|
||||
|
||||
// FlexibleStringSlice is a []string that also accepts JSON numbers,
|
||||
// so allow_from can contain both "123" and 123.
|
||||
// It also supports parsing comma-separated strings from environment variables,
|
||||
// including both English (,) and Chinese (,) commas.
|
||||
type FlexibleStringSlice []string
|
||||
|
||||
func (f *FlexibleStringSlice) UnmarshalJSON(data []byte) error {
|
||||
// Accept a single JSON string for convenience, e.g.:
|
||||
// "text": "Thinking..."
|
||||
var singleString string
|
||||
if err := json.Unmarshal(data, &singleString); err == nil {
|
||||
*f = FlexibleStringSlice{singleString}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Accept a single JSON number too, to keep symmetry with mixed allow_from
|
||||
// payloads that may contain numeric identifiers.
|
||||
var singleNumber float64
|
||||
if err := json.Unmarshal(data, &singleNumber); err == nil {
|
||||
*f = FlexibleStringSlice{fmt.Sprintf("%.0f", singleNumber)}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Try []string first
|
||||
var ss []string
|
||||
if err := json.Unmarshal(data, &ss); err == nil {
|
||||
*f = ss
|
||||
return nil
|
||||
}
|
||||
|
||||
// Try []interface{} to handle mixed types
|
||||
var raw []any
|
||||
if err := json.Unmarshal(data, &raw); err != nil {
|
||||
var s string
|
||||
// fail over to compatible to old format string
|
||||
if err = json.Unmarshal(data, &s); err != nil {
|
||||
return err
|
||||
}
|
||||
*f = []string{s}
|
||||
return nil
|
||||
}
|
||||
|
||||
result := make([]string, 0, len(raw))
|
||||
for _, v := range raw {
|
||||
switch val := v.(type) {
|
||||
case string:
|
||||
result = append(result, val)
|
||||
case float64:
|
||||
result = append(result, fmt.Sprintf("%.0f", val))
|
||||
default:
|
||||
result = append(result, fmt.Sprintf("%v", val))
|
||||
}
|
||||
}
|
||||
*f = result
|
||||
return nil
|
||||
}
|
||||
|
||||
// UnmarshalText implements encoding.TextUnmarshaler to support env variable parsing.
|
||||
// It handles comma-separated values with both English (,) and Chinese (,) commas.
|
||||
func (f *FlexibleStringSlice) UnmarshalText(text []byte) error {
|
||||
if len(text) == 0 {
|
||||
*f = nil
|
||||
return nil
|
||||
}
|
||||
|
||||
s := string(text)
|
||||
// Replace Chinese comma with English comma, then split
|
||||
s = strings.ReplaceAll(s, ",", ",")
|
||||
parts := strings.Split(s, ",")
|
||||
|
||||
result := make([]string, 0, len(parts))
|
||||
for _, part := range parts {
|
||||
part = strings.TrimSpace(part)
|
||||
if part != "" {
|
||||
result = append(result, part)
|
||||
}
|
||||
}
|
||||
*f = result
|
||||
return nil
|
||||
}
|
||||
|
||||
const (
|
||||
notHere = `"[NOT_HERE]"`
|
||||
)
|
||||
|
||||
// SecureStrings is a slice of SecureString
|
||||
type SecureStrings []*SecureString
|
||||
|
||||
// Values returns the decrypted/resolved values
|
||||
func (s *SecureStrings) Values() []string {
|
||||
if s == nil {
|
||||
return nil
|
||||
}
|
||||
keys := make([]string, len(*s))
|
||||
for i, k := range *s {
|
||||
keys[i] = k.String()
|
||||
}
|
||||
return unique(keys)
|
||||
}
|
||||
|
||||
func SimpleSecureStrings(val ...string) SecureStrings {
|
||||
val = unique(val)
|
||||
vv := make(SecureStrings, len(val))
|
||||
for i, s := range val {
|
||||
vv[i] = NewSecureString(s)
|
||||
}
|
||||
return vv
|
||||
}
|
||||
|
||||
// unique returns a new slice with duplicate elements removed.
|
||||
func unique[T comparable](input []T) []T {
|
||||
m := make(map[T]struct{})
|
||||
var result []T
|
||||
for _, v := range input {
|
||||
if _, ok := m[v]; !ok {
|
||||
m[v] = struct{}{}
|
||||
result = append(result, v)
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func (s SecureStrings) MarshalJSON() ([]byte, error) {
|
||||
return []byte(notHere), nil
|
||||
}
|
||||
|
||||
func (s *SecureStrings) UnmarshalJSON(value []byte) error {
|
||||
if string(value) == notHere {
|
||||
return nil
|
||||
}
|
||||
var v []*SecureString
|
||||
err := json.Unmarshal(value, &v)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
*s = v
|
||||
return nil
|
||||
}
|
||||
|
||||
// SecureString the string value that can be decrypted or resolved
|
||||
//
|
||||
//nolint:recvcheck
|
||||
type SecureString struct {
|
||||
resolved string // Decrypted/resolved value returned by String()
|
||||
raw string // Persisted raw value (enc://, file://, or plaintext)
|
||||
}
|
||||
|
||||
func callerFromYaml() bool {
|
||||
_, file, _, ok := runtime.Caller(2)
|
||||
if ok {
|
||||
d := filepath.Dir(file)
|
||||
// check the caller is from yaml.v
|
||||
if !strings.Contains(d, "yaml.v") {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// IsZero returns true if the SecureString is empty
|
||||
// if caller not yaml, just return true for prevent marshal this field
|
||||
func (s SecureString) IsZero() bool {
|
||||
if callerFromYaml() {
|
||||
return true
|
||||
}
|
||||
return s.resolved == ""
|
||||
}
|
||||
|
||||
func NewSecureString(value string) *SecureString {
|
||||
s := &SecureString{}
|
||||
if err := s.fromRaw(value); err != nil {
|
||||
logger.Warn(fmt.Sprintf("NewSecureString.fromRaw error: %s", err))
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
func (s *SecureString) String() string {
|
||||
if s == nil {
|
||||
return ""
|
||||
}
|
||||
return s.resolved
|
||||
}
|
||||
|
||||
func (s *SecureString) Set(value string) *SecureString {
|
||||
s.resolved = value
|
||||
s.raw = ""
|
||||
return s
|
||||
}
|
||||
|
||||
func (s SecureString) MarshalJSON() ([]byte, error) {
|
||||
return []byte(notHere), nil
|
||||
}
|
||||
|
||||
func (s *SecureString) UnmarshalJSON(value []byte) error {
|
||||
if string(value) == notHere {
|
||||
return nil
|
||||
}
|
||||
var v string
|
||||
if err := json.Unmarshal(value, &v); err != nil {
|
||||
return err
|
||||
}
|
||||
return s.fromRaw(v)
|
||||
}
|
||||
|
||||
func (s SecureString) MarshalYAML() (any, error) {
|
||||
// Preserve raw value if it is already a reference (enc:// or file://)
|
||||
if strings.HasPrefix(s.raw, credential.EncScheme) || strings.HasPrefix(s.raw, credential.FileScheme) {
|
||||
return s.raw, nil
|
||||
}
|
||||
// If resolved is a reference format (e.g. set via Set), copy back to raw
|
||||
if strings.HasPrefix(s.resolved, credential.EncScheme) || strings.HasPrefix(s.resolved, credential.FileScheme) {
|
||||
s.raw = s.resolved
|
||||
return s.raw, nil
|
||||
}
|
||||
// Try to encrypt the resolved value
|
||||
if passphrase := credential.PassphraseProvider(); passphrase != "" {
|
||||
encrypted, err := credential.Encrypt(passphrase, "", s.resolved)
|
||||
if err != nil {
|
||||
logger.Errorf("Encrypt error: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
s.raw = encrypted
|
||||
} else {
|
||||
s.raw = s.resolved
|
||||
}
|
||||
return s.raw, nil
|
||||
}
|
||||
|
||||
func (s *SecureString) UnmarshalYAML(value *yaml.Node) error {
|
||||
return s.fromRaw(value.Value)
|
||||
}
|
||||
|
||||
func (s *SecureString) fromRaw(v string) error {
|
||||
s.raw = v
|
||||
vv, err := resolveKey(v)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
s.resolved = vv
|
||||
return nil
|
||||
}
|
||||
|
||||
var (
|
||||
secResolverMu sync.RWMutex
|
||||
secResolver *credential.Resolver
|
||||
)
|
||||
|
||||
func updateResolver(path string) {
|
||||
secResolverMu.Lock()
|
||||
defer secResolverMu.Unlock()
|
||||
secResolver = credential.NewResolver(path)
|
||||
}
|
||||
|
||||
func resolveKey(v string) (string, error) {
|
||||
secResolverMu.RLock()
|
||||
resolver := secResolver
|
||||
secResolverMu.RUnlock()
|
||||
if resolver == nil {
|
||||
resolver = credential.NewResolver("")
|
||||
}
|
||||
if strings.HasPrefix(v, "enc://") || strings.HasPrefix(v, "file://") {
|
||||
decrypted, err := resolver.Resolve(v)
|
||||
if err != nil {
|
||||
logger.Errorf("Resolve error: %v", err)
|
||||
return "", err
|
||||
}
|
||||
return decrypted, nil
|
||||
}
|
||||
return v, nil
|
||||
}
|
||||
|
||||
func (s *SecureString) UnmarshalText(text []byte) error {
|
||||
v := string(text)
|
||||
return s.fromRaw(v)
|
||||
}
|
||||
|
||||
type SecureModelList []*ModelConfig
|
||||
|
||||
func (v *SecureModelList) UnmarshalYAML(value *yaml.Node) error {
|
||||
mm := make(map[string]*ModelConfig)
|
||||
if err := value.Decode(&mm); err != nil {
|
||||
logger.Errorf("Decode error: %v", err)
|
||||
return err
|
||||
}
|
||||
nameList := toNameIndex(*v)
|
||||
for i, m := range *v {
|
||||
sec := mm[nameList[i]]
|
||||
if sec == nil {
|
||||
sec = mm[m.ModelName]
|
||||
}
|
||||
if sec != nil {
|
||||
m.APIKeys = sec.APIKeys
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (v SecureModelList) MarshalYAML() (any, error) {
|
||||
type onlySecureData struct {
|
||||
APIKeys SecureStrings `yaml:"api_keys,omitempty"`
|
||||
}
|
||||
mm := make(map[string]onlySecureData)
|
||||
nameList := toNameIndex(v)
|
||||
for i, m := range v {
|
||||
mm[nameList[i]] = onlySecureData{
|
||||
APIKeys: m.APIKeys,
|
||||
}
|
||||
}
|
||||
|
||||
return mm, nil
|
||||
}
|
||||
@@ -0,0 +1,145 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/caarlos0/env/v11"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"gopkg.in/yaml.v3"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/credential"
|
||||
)
|
||||
|
||||
func TestLoadSecurityValue(t *testing.T) {
|
||||
type valueStruct struct {
|
||||
Url string `json:"url,omitempty" yaml:"-"`
|
||||
Token *SecureString `json:"token,omitempty" yaml:"token,omitempty" env:"PICO_TOKEN"`
|
||||
ApiKeys SecureStrings `json:"api_keys,omitempty" yaml:"api_keys,omitempty" env:"PICO_API_KEYS"`
|
||||
}
|
||||
|
||||
type testStruct struct {
|
||||
Pico *valueStruct `json:"pico,omitempty" yaml:"pico,omitempty"`
|
||||
}
|
||||
|
||||
v1 := &testStruct{
|
||||
Pico: &valueStruct{
|
||||
Url: "https://example.com",
|
||||
Token: NewSecureString("token1"),
|
||||
ApiKeys: SecureStrings{NewSecureString("api-key1"), NewSecureString("api-key2")},
|
||||
},
|
||||
}
|
||||
bytes, err := yaml.Marshal(v1)
|
||||
assert.NoError(t, err)
|
||||
jsonBytes, err := json.Marshal(v1)
|
||||
assert.NoError(t, err)
|
||||
const want = `pico:
|
||||
token: token1
|
||||
api_keys:
|
||||
- api-key1
|
||||
- api-key2
|
||||
`
|
||||
const jsonPost = `{"pico":{"url":"https://example.com","token":"token0"}}`
|
||||
v0 := &testStruct{}
|
||||
err = json.Unmarshal([]byte(jsonPost), v0)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "https://example.com", v0.Pico.Url)
|
||||
assert.Equal(t, "token0", v0.Pico.Token.String())
|
||||
|
||||
const jsonWant = `{"pico":{"url":"https://example.com","token":"[NOT_HERE]","api_keys":"[NOT_HERE]"}}`
|
||||
assert.Equal(t, want, string(bytes))
|
||||
assert.Equal(t, jsonWant, string(jsonBytes))
|
||||
|
||||
v2 := &testStruct{}
|
||||
err = json.Unmarshal(jsonBytes, v2)
|
||||
assert.NoError(t, err)
|
||||
err = yaml.Unmarshal(bytes, v2)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "https://example.com", v2.Pico.Url)
|
||||
if v2.Pico.Token != nil {
|
||||
assert.Equal(t, "token1", v2.Pico.Token.String())
|
||||
assert.Equal(t, "token1", v2.Pico.Token.raw)
|
||||
}
|
||||
|
||||
v2.Pico.Token = NewSecureString("token1")
|
||||
v2.Pico.Token.raw = "abc"
|
||||
err = yaml.Unmarshal(bytes, v2)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "token1", v2.Pico.Token.raw)
|
||||
|
||||
os.Setenv("PICO_TOKEN", "token_env")
|
||||
err = env.Parse(v2)
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, v2.Pico.Token)
|
||||
assert.Equal(t, "token1", v2.Pico.Token.String())
|
||||
|
||||
v3 := &testStruct{Pico: &valueStruct{}}
|
||||
err = env.Parse(v3)
|
||||
assert.NoError(t, err)
|
||||
if v3.Pico.Token != nil {
|
||||
assert.Equal(t, "token_env", v3.Pico.Token.String())
|
||||
}
|
||||
|
||||
type toolsStruct struct {
|
||||
Pico valueStruct `json:"pico,omitempty" yaml:"pico,omitempty"`
|
||||
}
|
||||
|
||||
type testStruct2 struct {
|
||||
Tools toolsStruct `json:"tools,omitempty" yaml:",inline"`
|
||||
}
|
||||
|
||||
v4 := &testStruct2{
|
||||
Tools: toolsStruct{
|
||||
Pico: valueStruct{
|
||||
Url: "https://example.com",
|
||||
Token: NewSecureString("token1"),
|
||||
ApiKeys: SecureStrings{NewSecureString("api-key1"), NewSecureString("api-key2")},
|
||||
},
|
||||
},
|
||||
}
|
||||
bytes, err = yaml.Marshal(v4)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, want, string(bytes))
|
||||
jsonBytes, err = json.Marshal(v4)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(
|
||||
t,
|
||||
`{"tools":{"pico":{"url":"https://example.com","token":"[NOT_HERE]","api_keys":"[NOT_HERE]"}}}`,
|
||||
string(jsonBytes),
|
||||
)
|
||||
|
||||
v5 := &testStruct2{}
|
||||
err = json.Unmarshal(jsonBytes, v5)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "https://example.com", v5.Tools.Pico.Url)
|
||||
err = yaml.Unmarshal(bytes, v5)
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, v5.Tools.Pico.Token)
|
||||
assert.Equal(t, "token1", v5.Tools.Pico.Token.raw)
|
||||
|
||||
dir := t.TempDir()
|
||||
sshKeyPath := filepath.Join(dir, "picoclaw_ed25519.key")
|
||||
if err = os.WriteFile(sshKeyPath, []byte("fake-ssh-key-material\n"), 0o600); err != nil {
|
||||
t.Fatalf("setup: %v", err)
|
||||
}
|
||||
|
||||
const passphrase = "test-passphrase-32bytes-long-ok!"
|
||||
|
||||
t.Setenv(credential.SSHKeyPathEnvVar, sshKeyPath)
|
||||
|
||||
t.Setenv(credential.PassphraseEnvVar, passphrase)
|
||||
|
||||
v5.Tools.Pico.Token.Set("newtoken1")
|
||||
v5.Tools.Pico.ApiKeys[0].Set("newapi-key1")
|
||||
bytes, err = yaml.Marshal(v5)
|
||||
assert.NoError(t, err)
|
||||
t.Logf("yaml: %s", string(bytes))
|
||||
|
||||
v6 := &testStruct2{}
|
||||
err = yaml.Unmarshal(bytes, v6)
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, v6.Tools.Pico.Token)
|
||||
assert.Equal(t, "newtoken1", v6.Tools.Pico.Token.String())
|
||||
}
|
||||
@@ -1673,3 +1673,163 @@ func TestFilterSensitiveData_AllTokenTypes(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// makeBackup tests
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// TestMakeBackup_WithDateSuffix verifies backup files include a date suffix.
|
||||
func TestMakeBackup_WithDateSuffix(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
configPath := filepath.Join(dir, "config.json")
|
||||
if err := os.WriteFile(configPath, []byte(`{"version":2}`), 0o600); err != nil {
|
||||
t.Fatalf("WriteFile: %v", err)
|
||||
}
|
||||
|
||||
if err := makeBackup(configPath); err != nil {
|
||||
t.Fatalf("makeBackup: %v", err)
|
||||
}
|
||||
|
||||
entries, err := os.ReadDir(dir)
|
||||
if err != nil {
|
||||
t.Fatalf("ReadDir: %v", err)
|
||||
}
|
||||
|
||||
var hasDatedBackup bool
|
||||
for _, e := range entries {
|
||||
if matched, _ := filepath.Match("config.json.20*.bak", e.Name()); matched {
|
||||
hasDatedBackup = true
|
||||
// Verify backup content matches original
|
||||
bakPath := filepath.Join(dir, e.Name())
|
||||
data, err := os.ReadFile(bakPath)
|
||||
if err != nil {
|
||||
t.Fatalf("ReadFile backup: %v", err)
|
||||
}
|
||||
if string(data) != `{"version":2}` {
|
||||
t.Errorf("backup content = %q, want original content", string(data))
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
if !hasDatedBackup {
|
||||
t.Error("expected backup file with date suffix pattern config.json.20*.bak")
|
||||
}
|
||||
}
|
||||
|
||||
// TestMakeBackup_AlsoBacksSecurityFile verifies that the security config file
|
||||
// is also backed up with the same date suffix.
|
||||
func TestMakeBackup_AlsoBacksSecurityFile(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
configPath := filepath.Join(dir, "config.json")
|
||||
secPath := securityPath(configPath)
|
||||
|
||||
os.WriteFile(configPath, []byte(`{"version":2}`), 0o600)
|
||||
os.WriteFile(secPath, []byte(`model_list:\n test:0:\n api_keys:\n - "sk-test"\n`), 0o600)
|
||||
|
||||
if err := makeBackup(configPath); err != nil {
|
||||
t.Fatalf("makeBackup: %v", err)
|
||||
}
|
||||
|
||||
entries, err := os.ReadDir(dir)
|
||||
if err != nil {
|
||||
t.Fatalf("ReadDir: %v", err)
|
||||
}
|
||||
|
||||
configBackups := 0
|
||||
secBackups := 0
|
||||
for _, e := range entries {
|
||||
if matched, _ := filepath.Match("config.json.20*.bak", e.Name()); matched {
|
||||
configBackups++
|
||||
}
|
||||
if matched, _ := filepath.Match(".security.yml.20*.bak", e.Name()); matched {
|
||||
secBackups++
|
||||
}
|
||||
}
|
||||
if configBackups != 1 {
|
||||
t.Errorf("expected 1 config backup, got %d", configBackups)
|
||||
}
|
||||
if secBackups != 1 {
|
||||
t.Errorf("expected 1 security backup, got %d", secBackups)
|
||||
}
|
||||
}
|
||||
|
||||
// TestMakeBackup_NonexistentFileSkipsBackup verifies that makeBackup returns nil
|
||||
// when the config file does not exist (no error, no panic).
|
||||
func TestMakeBackup_NonexistentFileSkipsBackup(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
configPath := filepath.Join(dir, "nonexistent.json")
|
||||
|
||||
if err := makeBackup(configPath); err != nil {
|
||||
t.Fatalf("makeBackup on nonexistent file should return nil, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestMakeBackup_OnlyConfigNoSecurity verifies backup succeeds when only
|
||||
// the config file exists and no security file.
|
||||
func TestMakeBackup_OnlyConfigNoSecurity(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
configPath := filepath.Join(dir, "config.json")
|
||||
os.WriteFile(configPath, []byte(`{"version":2}`), 0o600)
|
||||
|
||||
if err := makeBackup(configPath); err != nil {
|
||||
t.Fatalf("makeBackup: %v", err)
|
||||
}
|
||||
|
||||
entries, _ := os.ReadDir(dir)
|
||||
configBackups := 0
|
||||
secBackups := 0
|
||||
for _, e := range entries {
|
||||
if matched, _ := filepath.Match("config.json.20*.bak", e.Name()); matched {
|
||||
configBackups++
|
||||
}
|
||||
if matched, _ := filepath.Match(".security.yml.20*.bak", e.Name()); matched {
|
||||
secBackups++
|
||||
}
|
||||
}
|
||||
if configBackups != 1 {
|
||||
t.Errorf("expected 1 config backup, got %d", configBackups)
|
||||
}
|
||||
if secBackups != 0 {
|
||||
t.Errorf("expected 0 security backups when no security file exists, got %d", secBackups)
|
||||
}
|
||||
}
|
||||
|
||||
// TestMakeBackup_SameDateSuffix verifies that config and security backups
|
||||
// share the same date suffix (they are created in the same makeBackup call).
|
||||
func TestMakeBackup_SameDateSuffix(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
configPath := filepath.Join(dir, "config.json")
|
||||
secPath := securityPath(configPath)
|
||||
|
||||
os.WriteFile(configPath, []byte(`{"version":2}`), 0o600)
|
||||
os.WriteFile(secPath, []byte(`key: value`), 0o600)
|
||||
|
||||
if err := makeBackup(configPath); err != nil {
|
||||
t.Fatalf("makeBackup: %v", err)
|
||||
}
|
||||
|
||||
entries, _ := os.ReadDir(dir)
|
||||
var configDate, secDate string
|
||||
for _, e := range entries {
|
||||
name := e.Name()
|
||||
// Extract date part: after the last . before .bak
|
||||
// e.g. config.json.20260330.bak → 20260330
|
||||
if strings.HasPrefix(name, "config.json.") && strings.HasSuffix(name, ".bak") {
|
||||
configDate = strings.TrimPrefix(name, "config.json.")
|
||||
configDate = strings.TrimSuffix(configDate, ".bak")
|
||||
}
|
||||
if strings.HasPrefix(name, ".security.yml.") && strings.HasSuffix(name, ".bak") {
|
||||
secDate = strings.TrimPrefix(name, ".security.yml.")
|
||||
secDate = strings.TrimSuffix(secDate, ".bak")
|
||||
}
|
||||
}
|
||||
if configDate == "" {
|
||||
t.Fatal("config backup file not found")
|
||||
}
|
||||
if secDate == "" {
|
||||
t.Fatal("security backup file not found")
|
||||
}
|
||||
if configDate != secDate {
|
||||
t.Errorf("config backup date = %q, security backup date = %q, should match", configDate, secDate)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -534,3 +534,26 @@ func loadConfig(data []byte) (*Config, error) {
|
||||
}
|
||||
return cfg, nil
|
||||
}
|
||||
|
||||
func mergeAPIKeys(apiKey string, apiKeys []string) []string {
|
||||
seen := make(map[string]struct{})
|
||||
var all []string
|
||||
|
||||
if k := strings.TrimSpace(apiKey); k != "" {
|
||||
if _, exists := seen[k]; !exists {
|
||||
seen[k] = struct{}{}
|
||||
all = append(all, k)
|
||||
}
|
||||
}
|
||||
|
||||
for _, k := range apiKeys {
|
||||
if trimmed := strings.TrimSpace(k); trimmed != "" {
|
||||
if _, exists := seen[trimmed]; !exists {
|
||||
seen[trimmed] = struct{}{}
|
||||
all = append(all, trimmed)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return all
|
||||
}
|
||||
|
||||
@@ -681,3 +681,473 @@ web:
|
||||
t.Error("Discord token not preserved in .security.yml file")
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// V1 → V2 migration tests
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// TestMigrateModelEnabled_APIKeysInferredEnabled verifies that models with API keys
|
||||
// are marked as enabled during V1→V2 migration.
|
||||
func TestMigrateModelEnabled_APIKeysInferredEnabled(t *testing.T) {
|
||||
v1 := &configV1{Config: Config{
|
||||
ModelList: []*ModelConfig{
|
||||
{ModelName: "gpt-4", Model: "openai/gpt-4", APIKeys: SimpleSecureStrings("sk-test")},
|
||||
{ModelName: "claude", Model: "anthropic/claude", APIKeys: SimpleSecureStrings("sk-ant")},
|
||||
},
|
||||
}}
|
||||
v1.migrateModelEnabled()
|
||||
for _, m := range v1.ModelList {
|
||||
if !m.Enabled {
|
||||
t.Errorf("model %q with API key should be enabled", m.ModelName)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestMigrateModelEnabled_LocalModelInferredEnabled verifies that the reserved
|
||||
// "local-model" entry is enabled even without API keys.
|
||||
func TestMigrateModelEnabled_LocalModelInferredEnabled(t *testing.T) {
|
||||
v1 := &configV1{Config: Config{
|
||||
ModelList: []*ModelConfig{
|
||||
{ModelName: "local-model", Model: "vllm/custom-model", APIBase: "http://localhost:8000/v1"},
|
||||
},
|
||||
}}
|
||||
v1.migrateModelEnabled()
|
||||
if !v1.ModelList[0].Enabled {
|
||||
t.Error("local-model should be enabled")
|
||||
}
|
||||
}
|
||||
|
||||
// TestMigrateModelEnabled_NoKeyStaysDisabled verifies that models without API keys
|
||||
// and not named "local-model" remain disabled.
|
||||
func TestMigrateModelEnabled_NoKeyStaysDisabled(t *testing.T) {
|
||||
v1 := &configV1{Config: Config{
|
||||
ModelList: []*ModelConfig{
|
||||
{ModelName: "gpt-4", Model: "openai/gpt-4"},
|
||||
{ModelName: "claude", Model: "anthropic/claude"},
|
||||
},
|
||||
}}
|
||||
v1.migrateModelEnabled()
|
||||
for _, m := range v1.ModelList {
|
||||
if m.Enabled {
|
||||
t.Errorf("model %q without API key should stay disabled", m.ModelName)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestMigrateModelEnabled_ExplicitEnabledPreserved verifies that a model with
|
||||
// explicitly enabled=true is NOT overridden by the migration.
|
||||
func TestMigrateModelEnabled_ExplicitEnabledPreserved(t *testing.T) {
|
||||
v1 := &configV1{Config: Config{
|
||||
ModelList: []*ModelConfig{
|
||||
{ModelName: "gpt-4", Model: "openai/gpt-4", APIKeys: SimpleSecureStrings("sk-test"), Enabled: true},
|
||||
},
|
||||
}}
|
||||
v1.migrateModelEnabled()
|
||||
if !v1.ModelList[0].Enabled {
|
||||
t.Error("explicitly enabled model should remain enabled")
|
||||
}
|
||||
}
|
||||
|
||||
// TestMigrateModelEnabled_ExplicitDisabledNotOverridden verifies that a model with
|
||||
// explicitly enabled=false and API keys gets enabled during migration.
|
||||
// Note: since Go's zero value for bool is false and JSON omitempty omits false,
|
||||
// migration cannot distinguish "explicitly false" from "field absent". Both cases
|
||||
// get the same inference treatment.
|
||||
func TestMigrateModelEnabled_ExplicitDisabledNotOverridden(t *testing.T) {
|
||||
v1 := &configV1{Config: Config{
|
||||
ModelList: []*ModelConfig{
|
||||
{ModelName: "gpt-4", Model: "openai/gpt-4", APIKeys: SimpleSecureStrings("sk-test"), Enabled: false},
|
||||
},
|
||||
}}
|
||||
v1.migrateModelEnabled()
|
||||
// Even though Enabled was set to false, migration infers it as true because
|
||||
// the migration cannot distinguish from a missing field (both are zero value).
|
||||
if !v1.ModelList[0].Enabled {
|
||||
t.Error("model with API key should be enabled by migration inference")
|
||||
}
|
||||
}
|
||||
|
||||
// TestMigrateModelEnabled_Mixed verifies a mix of models.
|
||||
func TestMigrateModelEnabled_Mixed(t *testing.T) {
|
||||
v1 := &configV1{Config: Config{
|
||||
ModelList: []*ModelConfig{
|
||||
{ModelName: "with-key", Model: "openai/gpt-4", APIKeys: SimpleSecureStrings("sk-test")},
|
||||
{ModelName: "no-key", Model: "openai/gpt-4"},
|
||||
{ModelName: "local-model", Model: "vllm/custom"},
|
||||
{
|
||||
ModelName: "disabled-explicit",
|
||||
Model: "openai/gpt-4",
|
||||
APIKeys: SimpleSecureStrings("sk-test"),
|
||||
Enabled: false,
|
||||
},
|
||||
},
|
||||
}}
|
||||
v1.migrateModelEnabled()
|
||||
|
||||
assertEnabled := func(name string, want bool) {
|
||||
for _, m := range v1.ModelList {
|
||||
if m.ModelName == name {
|
||||
if m.Enabled != want {
|
||||
t.Errorf("model %q: Enabled=%v, want %v", name, m.Enabled, want)
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
t.Errorf("model %q not found", name)
|
||||
}
|
||||
|
||||
assertEnabled("with-key", true)
|
||||
assertEnabled("no-key", false)
|
||||
assertEnabled("local-model", true)
|
||||
assertEnabled("disabled-explicit", true) // false is zero value, migration infers from API key
|
||||
}
|
||||
|
||||
// TestMigrateChannelConfigs_DiscordMentionOnly verifies Discord mention_only migration.
|
||||
func TestMigrateChannelConfigs_DiscordMentionOnly(t *testing.T) {
|
||||
v1 := &configV1{Config: Config{
|
||||
Channels: ChannelsConfig{
|
||||
Discord: DiscordConfig{
|
||||
MentionOnly: true,
|
||||
},
|
||||
},
|
||||
}}
|
||||
v1.migrateChannelConfigs()
|
||||
if !v1.Channels.Discord.GroupTrigger.MentionOnly {
|
||||
t.Error("Discord GroupTrigger.MentionOnly should be set to true")
|
||||
}
|
||||
}
|
||||
|
||||
// TestMigrateChannelConfigs_DiscordAlreadyMigrated is a no-op test.
|
||||
func TestMigrateChannelConfigs_DiscordAlreadyMigrated(t *testing.T) {
|
||||
v1 := &configV1{Config: Config{
|
||||
Channels: ChannelsConfig{
|
||||
Discord: DiscordConfig{
|
||||
GroupTrigger: GroupTriggerConfig{MentionOnly: true},
|
||||
},
|
||||
},
|
||||
}}
|
||||
v1.migrateChannelConfigs()
|
||||
}
|
||||
|
||||
// TestMigrateChannelConfigs_OneBotPrefix verifies OneBot prefix migration.
|
||||
func TestMigrateChannelConfigs_OneBotPrefix(t *testing.T) {
|
||||
v1 := &configV1{Config: Config{
|
||||
Channels: ChannelsConfig{
|
||||
OneBot: OneBotConfig{
|
||||
GroupTriggerPrefix: []string{"/"},
|
||||
},
|
||||
},
|
||||
}}
|
||||
v1.migrateChannelConfigs()
|
||||
if len(v1.Channels.OneBot.GroupTrigger.Prefixes) != 1 || v1.Channels.OneBot.GroupTrigger.Prefixes[0] != "/" {
|
||||
t.Errorf("OneBot GroupTrigger.Prefixes = %v, want [\"/\"]", v1.Channels.OneBot.GroupTrigger.Prefixes)
|
||||
}
|
||||
}
|
||||
|
||||
// TestMigrateConfigV1_Combined verifies that configV1.Migrate applies both migrations.
|
||||
func TestMigrateConfigV1_Combined(t *testing.T) {
|
||||
v1 := &configV1{Config: Config{
|
||||
ModelList: []*ModelConfig{
|
||||
{ModelName: "gpt-4", Model: "openai/gpt-4", APIKeys: SimpleSecureStrings("sk-test")},
|
||||
},
|
||||
Channels: ChannelsConfig{
|
||||
Discord: DiscordConfig{MentionOnly: true},
|
||||
},
|
||||
}}
|
||||
result, err := v1.Migrate()
|
||||
if err != nil {
|
||||
t.Fatalf("Migrate: %v", err)
|
||||
}
|
||||
|
||||
if !result.ModelList[0].Enabled {
|
||||
t.Error("model with API key should be enabled after V1→V2 migration")
|
||||
}
|
||||
if !result.Channels.Discord.GroupTrigger.MentionOnly {
|
||||
t.Error("Discord mention_only should be migrated after V1→V2 migration")
|
||||
}
|
||||
}
|
||||
|
||||
// TestLoadConfig_V1ToV2Migration verifies end-to-end V1→V2 config migration
|
||||
// through LoadConfig, including Enabled field inference and version bump.
|
||||
func TestLoadConfig_V1ToV2Migration(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
configPath := filepath.Join(tmpDir, "config.json")
|
||||
|
||||
// Write a V1 config with model_list but no "enabled" field
|
||||
v1Config := `{
|
||||
"version": 1,
|
||||
"model_list": [
|
||||
{
|
||||
"model_name": "gpt-4",
|
||||
"model": "openai/gpt-4"
|
||||
},
|
||||
{
|
||||
"model_name": "local-model",
|
||||
"model": "vllm/custom-model",
|
||||
"api_base": "http://localhost:8000/v1"
|
||||
}
|
||||
],
|
||||
"channels": {
|
||||
"discord": {
|
||||
"mention_only": true
|
||||
}
|
||||
},
|
||||
"gateway": {"host": "127.0.0.1", "port": 18790}
|
||||
}`
|
||||
|
||||
if err := os.WriteFile(configPath, []byte(v1Config), 0o600); err != nil {
|
||||
t.Fatalf("WriteFile: %v", err)
|
||||
}
|
||||
|
||||
cfg, err := LoadConfig(configPath)
|
||||
if err != nil {
|
||||
t.Fatalf("LoadConfig: %v", err)
|
||||
}
|
||||
|
||||
// Version should be bumped to 2
|
||||
if cfg.Version != CurrentVersion {
|
||||
t.Errorf("Version = %d, want %d", cfg.Version, CurrentVersion)
|
||||
}
|
||||
|
||||
// gpt-4 has no API key → disabled
|
||||
gpt4, err := cfg.GetModelConfig("gpt-4")
|
||||
if err != nil {
|
||||
t.Fatalf("GetModelConfig(gpt-4): %v", err)
|
||||
}
|
||||
if gpt4.Enabled {
|
||||
t.Error("gpt-4 without API key should be disabled after migration")
|
||||
}
|
||||
|
||||
// local-model → enabled
|
||||
local, err := cfg.GetModelConfig("local-model")
|
||||
if err != nil {
|
||||
t.Fatalf("GetModelConfig(local-model): %v", err)
|
||||
}
|
||||
if !local.Enabled {
|
||||
t.Error("local-model should be enabled after migration")
|
||||
}
|
||||
|
||||
// Discord channel config should be migrated
|
||||
if !cfg.Channels.Discord.GroupTrigger.MentionOnly {
|
||||
t.Error("Discord mention_only should be migrated to group_trigger.mention_only")
|
||||
}
|
||||
|
||||
// Verify backup was created with date suffix
|
||||
entries, err := os.ReadDir(tmpDir)
|
||||
if err != nil {
|
||||
t.Fatalf("ReadDir: %v", err)
|
||||
}
|
||||
var hasBackup bool
|
||||
for _, e := range entries {
|
||||
if matched, _ := filepath.Match("config.json.20*.bak", e.Name()); matched {
|
||||
hasBackup = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !hasBackup {
|
||||
t.Error("expected backup file with date suffix to be created")
|
||||
}
|
||||
|
||||
// Verify the saved config on disk now has version 2
|
||||
saved, err := os.ReadFile(configPath)
|
||||
if err != nil {
|
||||
t.Fatalf("ReadFile saved config: %v", err)
|
||||
}
|
||||
var versionCheck struct {
|
||||
Version int `json:"version"`
|
||||
}
|
||||
if err := json.Unmarshal(saved, &versionCheck); err != nil {
|
||||
t.Fatalf("Unmarshal saved config: %v", err)
|
||||
}
|
||||
if versionCheck.Version != 2 {
|
||||
t.Errorf("saved config version = %d, want 2", versionCheck.Version)
|
||||
}
|
||||
}
|
||||
|
||||
// TestLoadConfig_V1WithAPIKeysInferredEnabled verifies that V1 configs with
|
||||
// API keys in the security file get Enabled=true after migration.
|
||||
func TestLoadConfig_V1WithAPIKeysInferredEnabled(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
configPath := filepath.Join(tmpDir, "config.json")
|
||||
secPath := securityPath(configPath)
|
||||
|
||||
v1Config := `{
|
||||
"version": 1,
|
||||
"model_list": [
|
||||
{"model_name": "gpt-4", "model": "openai/gpt-4"},
|
||||
{"model_name": "claude", "model": "anthropic/claude"}
|
||||
],
|
||||
"gateway": {"host": "127.0.0.1", "port": 18790}
|
||||
}`
|
||||
|
||||
securityConfig := `model_list:
|
||||
gpt-4:0:
|
||||
api_keys:
|
||||
- "sk-gpt-key"
|
||||
claude:0:
|
||||
api_keys:
|
||||
- "sk-claude-key"
|
||||
`
|
||||
|
||||
if err := os.WriteFile(configPath, []byte(v1Config), 0o600); err != nil {
|
||||
t.Fatalf("WriteFile: %v", err)
|
||||
}
|
||||
if err := os.WriteFile(secPath, []byte(securityConfig), 0o600); err != nil {
|
||||
t.Fatalf("WriteFile security: %v", err)
|
||||
}
|
||||
|
||||
cfg, err := LoadConfig(configPath)
|
||||
if err != nil {
|
||||
t.Fatalf("LoadConfig: %v", err)
|
||||
}
|
||||
|
||||
for _, m := range cfg.ModelList {
|
||||
if !m.Enabled {
|
||||
t.Errorf("model %q with API key in security file should be enabled", m.ModelName)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestLoadConfig_V2DirectLoad verifies that V2 configs load directly without
|
||||
// running any migration.
|
||||
func TestLoadConfig_V2DirectLoad(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
configPath := filepath.Join(tmpDir, "config.json")
|
||||
|
||||
v2Config := `{
|
||||
"version": 2,
|
||||
"model_list": [
|
||||
{
|
||||
"model_name": "gpt-4",
|
||||
"model": "openai/gpt-4",
|
||||
"enabled": true
|
||||
},
|
||||
{
|
||||
"model_name": "claude",
|
||||
"model": "anthropic/claude"
|
||||
}
|
||||
],
|
||||
"gateway": {"host": "127.0.0.1", "port": 18790}
|
||||
}`
|
||||
|
||||
if err := os.WriteFile(configPath, []byte(v2Config), 0o600); err != nil {
|
||||
t.Fatalf("WriteFile: %v", err)
|
||||
}
|
||||
|
||||
cfg, err := LoadConfig(configPath)
|
||||
if err != nil {
|
||||
t.Fatalf("LoadConfig: %v", err)
|
||||
}
|
||||
|
||||
if cfg.Version != 2 {
|
||||
t.Errorf("Version = %d, want 2", cfg.Version)
|
||||
}
|
||||
|
||||
gpt4, _ := cfg.GetModelConfig("gpt-4")
|
||||
if !gpt4.Enabled {
|
||||
t.Error("gpt-4 with explicit enabled=true should remain enabled")
|
||||
}
|
||||
|
||||
claude, _ := cfg.GetModelConfig("claude")
|
||||
if claude.Enabled {
|
||||
t.Error("claude without enabled field should be false (no migration for V2)")
|
||||
}
|
||||
|
||||
// No backup should be created for V2 load
|
||||
entries, _ := os.ReadDir(tmpDir)
|
||||
for _, e := range entries {
|
||||
if matched, _ := filepath.Match("config.json.*.bak", e.Name()); matched {
|
||||
t.Errorf("V2 load should not create backup, but found %q", e.Name())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestLoadConfig_V0MigrateProducesV2 verifies that V0→V2 migration produces
|
||||
// correct Enabled fields and version.
|
||||
func TestLoadConfig_V0MigrateProducesV2(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
configPath := filepath.Join(tmpDir, "config.json")
|
||||
|
||||
v0Config := `{
|
||||
"model_list": [
|
||||
{
|
||||
"model_name": "gpt-4",
|
||||
"model": "openai/gpt-4",
|
||||
"api_key": "sk-test"
|
||||
},
|
||||
{
|
||||
"model_name": "claude",
|
||||
"model": "anthropic/claude"
|
||||
},
|
||||
{
|
||||
"model_name": "local-model",
|
||||
"model": "vllm/custom-model"
|
||||
}
|
||||
],
|
||||
"gateway": {"host": "127.0.0.1", "port": 18790}
|
||||
}`
|
||||
|
||||
if err := os.WriteFile(configPath, []byte(v0Config), 0o600); err != nil {
|
||||
t.Fatalf("WriteFile: %v", err)
|
||||
}
|
||||
|
||||
cfg, err := LoadConfig(configPath)
|
||||
if err != nil {
|
||||
t.Fatalf("LoadConfig: %v", err)
|
||||
}
|
||||
|
||||
if cfg.Version != CurrentVersion {
|
||||
t.Errorf("Version = %d, want %d", cfg.Version, CurrentVersion)
|
||||
}
|
||||
|
||||
// Check enabled status
|
||||
modelEnabled := func(name string) bool {
|
||||
m, err := cfg.GetModelConfig(name)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
return m.Enabled
|
||||
}
|
||||
|
||||
if !modelEnabled("gpt-4") {
|
||||
t.Error("gpt-4 with API key from V0 should be enabled")
|
||||
}
|
||||
if modelEnabled("claude") {
|
||||
t.Error("claude without API key from V0 should be disabled")
|
||||
}
|
||||
if !modelEnabled("local-model") {
|
||||
t.Error("local-model from V0 should be enabled")
|
||||
}
|
||||
}
|
||||
|
||||
// TestLoadConfig_UnsupportedVersion verifies that unsupported versions return an error.
|
||||
func TestLoadConfig_UnsupportedVersion(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
configPath := filepath.Join(tmpDir, "config.json")
|
||||
|
||||
badConfig := `{"version": 99, "gateway": {"host": "127.0.0.1", "port": 18790}}`
|
||||
if err := os.WriteFile(configPath, []byte(badConfig), 0o600); err != nil {
|
||||
t.Fatalf("WriteFile: %v", err)
|
||||
}
|
||||
|
||||
_, err := LoadConfig(configPath)
|
||||
if err == nil {
|
||||
t.Fatal("LoadConfig should return error for unsupported version")
|
||||
}
|
||||
if !containsString(err.Error(), "unsupported config version") {
|
||||
t.Errorf("error = %q, want 'unsupported config version'", err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
func containsString(s, substr string) bool {
|
||||
return len(s) >= len(substr) && searchString(s, substr)
|
||||
}
|
||||
|
||||
func searchString(s, substr string) bool {
|
||||
for i := 0; i <= len(s)-len(substr); i++ {
|
||||
if s[i:i+len(substr)] == substr {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -345,7 +345,7 @@ func TestMergeAPIKeys(t *testing.T) {
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := MergeAPIKeys(tt.apiKey, tt.apiKeys)
|
||||
result := mergeAPIKeys(tt.apiKey, tt.apiKeys)
|
||||
if len(result) != len(tt.expected) {
|
||||
t.Fatalf("expected %d keys, got %d", len(tt.expected), len(result))
|
||||
}
|
||||
|
||||
@@ -7,20 +7,16 @@ package config
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"reflect"
|
||||
"runtime"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"gopkg.in/yaml.v3"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/credential"
|
||||
"github.com/sipeed/picoclaw/pkg/fileutil"
|
||||
"github.com/sipeed/picoclaw/pkg/logger"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -66,7 +62,6 @@ func saveSecurityConfig(securityPath string, sec *Config) error {
|
||||
return fileutil.WriteFileAtomic(securityPath, buf.Bytes(), 0o600)
|
||||
}
|
||||
|
||||
// SensitiveDataCache caches the compiled regex for filtering sensitive data.
|
||||
// SensitiveDataCache caches the strings.Replacer for filtering sensitive data.
|
||||
// Computed once on first access via sync.Once.
|
||||
type SensitiveDataCache struct {
|
||||
@@ -178,234 +173,3 @@ func collectSensitive(v reflect.Value, values *[]string) {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const (
|
||||
notHere = `"[NOT_HERE]"`
|
||||
)
|
||||
|
||||
// SecureStrings is a slice of SecureString
|
||||
type SecureStrings []*SecureString
|
||||
|
||||
// Values returns the decrypted/resolved values
|
||||
func (s *SecureStrings) Values() []string {
|
||||
if s == nil {
|
||||
return nil
|
||||
}
|
||||
keys := make([]string, len(*s))
|
||||
for i, k := range *s {
|
||||
keys[i] = k.String()
|
||||
}
|
||||
return unique(keys)
|
||||
}
|
||||
|
||||
func SimpleSecureStrings(val ...string) SecureStrings {
|
||||
val = unique(val)
|
||||
vv := make(SecureStrings, len(val))
|
||||
for i, s := range val {
|
||||
vv[i] = NewSecureString(s)
|
||||
}
|
||||
return vv
|
||||
}
|
||||
|
||||
// unique returns a new slice with duplicate elements removed.
|
||||
func unique[T comparable](input []T) []T {
|
||||
m := make(map[T]struct{})
|
||||
var result []T
|
||||
for _, v := range input {
|
||||
if _, ok := m[v]; !ok {
|
||||
m[v] = struct{}{}
|
||||
result = append(result, v)
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func (s SecureStrings) MarshalJSON() ([]byte, error) {
|
||||
return []byte(notHere), nil
|
||||
}
|
||||
|
||||
func (s *SecureStrings) UnmarshalJSON(value []byte) error {
|
||||
if string(value) == notHere {
|
||||
return nil
|
||||
}
|
||||
var v []*SecureString
|
||||
err := json.Unmarshal(value, &v)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
*s = v
|
||||
return nil
|
||||
}
|
||||
|
||||
// SecureString the string value that can be decrypted or resolved
|
||||
//
|
||||
//nolint:recvcheck
|
||||
type SecureString struct {
|
||||
resolved string // Decrypted/resolved value returned by String()
|
||||
raw string // Persisted raw value (enc://, file://, or plaintext)
|
||||
}
|
||||
|
||||
func callerFromYaml() bool {
|
||||
_, file, _, ok := runtime.Caller(2)
|
||||
if ok {
|
||||
d := filepath.Dir(file)
|
||||
// check the caller is from yaml.v
|
||||
if !strings.Contains(d, "yaml.v") {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// IsZero returns true if the SecureString is empty
|
||||
// if caller not yaml, just return true for prevent marshal this field
|
||||
func (s SecureString) IsZero() bool {
|
||||
if callerFromYaml() {
|
||||
return true
|
||||
}
|
||||
return s.resolved == ""
|
||||
}
|
||||
|
||||
func NewSecureString(value string) *SecureString {
|
||||
s := &SecureString{}
|
||||
if err := s.fromRaw(value); err != nil {
|
||||
logger.Warn(fmt.Sprintf("NewSecureString.fromRaw error: %s", err))
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
func (s *SecureString) String() string {
|
||||
if s == nil {
|
||||
return ""
|
||||
}
|
||||
return s.resolved
|
||||
}
|
||||
|
||||
func (s *SecureString) Set(value string) *SecureString {
|
||||
s.resolved = value
|
||||
s.raw = ""
|
||||
return s
|
||||
}
|
||||
|
||||
func (s SecureString) MarshalJSON() ([]byte, error) {
|
||||
return []byte(notHere), nil
|
||||
}
|
||||
|
||||
func (s *SecureString) UnmarshalJSON(value []byte) error {
|
||||
if string(value) == notHere {
|
||||
return nil
|
||||
}
|
||||
var v string
|
||||
if err := json.Unmarshal(value, &v); err != nil {
|
||||
return err
|
||||
}
|
||||
return s.fromRaw(v)
|
||||
}
|
||||
|
||||
func (s SecureString) MarshalYAML() (any, error) {
|
||||
// Preserve raw value if it is already a reference (enc:// or file://)
|
||||
if strings.HasPrefix(s.raw, credential.EncScheme) || strings.HasPrefix(s.raw, credential.FileScheme) {
|
||||
return s.raw, nil
|
||||
}
|
||||
// If resolved is a reference format (e.g. set via Set), copy back to raw
|
||||
if strings.HasPrefix(s.resolved, credential.EncScheme) || strings.HasPrefix(s.resolved, credential.FileScheme) {
|
||||
s.raw = s.resolved
|
||||
return s.raw, nil
|
||||
}
|
||||
// Try to encrypt the resolved value
|
||||
if passphrase := credential.PassphraseProvider(); passphrase != "" {
|
||||
encrypted, err := credential.Encrypt(passphrase, "", s.resolved)
|
||||
if err != nil {
|
||||
logger.Errorf("Encrypt error: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
s.raw = encrypted
|
||||
} else {
|
||||
s.raw = s.resolved
|
||||
}
|
||||
return s.raw, nil
|
||||
}
|
||||
|
||||
func (s *SecureString) UnmarshalYAML(value *yaml.Node) error {
|
||||
return s.fromRaw(value.Value)
|
||||
}
|
||||
|
||||
func (s *SecureString) fromRaw(v string) error {
|
||||
s.raw = v
|
||||
vv, err := resolveKey(v)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
s.resolved = vv
|
||||
return nil
|
||||
}
|
||||
|
||||
var (
|
||||
secResolverMu sync.RWMutex
|
||||
secResolver *credential.Resolver
|
||||
)
|
||||
|
||||
func updateResolver(path string) {
|
||||
secResolverMu.Lock()
|
||||
defer secResolverMu.Unlock()
|
||||
secResolver = credential.NewResolver(path)
|
||||
}
|
||||
|
||||
func resolveKey(v string) (string, error) {
|
||||
secResolverMu.RLock()
|
||||
resolver := secResolver
|
||||
secResolverMu.RUnlock()
|
||||
if resolver == nil {
|
||||
resolver = credential.NewResolver("")
|
||||
}
|
||||
if strings.HasPrefix(v, "enc://") || strings.HasPrefix(v, "file://") {
|
||||
decrypted, err := resolver.Resolve(v)
|
||||
if err != nil {
|
||||
logger.Errorf("Resolve error: %v", err)
|
||||
return "", err
|
||||
}
|
||||
return decrypted, nil
|
||||
}
|
||||
return v, nil
|
||||
}
|
||||
|
||||
func (s *SecureString) UnmarshalText(text []byte) error {
|
||||
v := string(text)
|
||||
return s.fromRaw(v)
|
||||
}
|
||||
|
||||
type SecureModelList []*ModelConfig
|
||||
|
||||
func (v *SecureModelList) UnmarshalYAML(value *yaml.Node) error {
|
||||
mm := make(map[string]*ModelConfig)
|
||||
if err := value.Decode(&mm); err != nil {
|
||||
logger.Errorf("Decode error: %v", err)
|
||||
return err
|
||||
}
|
||||
nameList := toNameIndex(*v)
|
||||
for i, m := range *v {
|
||||
sec := mm[nameList[i]]
|
||||
if sec == nil {
|
||||
sec = mm[m.ModelName]
|
||||
}
|
||||
if sec != nil {
|
||||
m.APIKeys = sec.APIKeys
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (v SecureModelList) MarshalYAML() (any, error) {
|
||||
type onlySecureData struct {
|
||||
APIKeys SecureStrings `yaml:"api_keys,omitempty"`
|
||||
}
|
||||
mm := make(map[string]onlySecureData)
|
||||
nameList := toNameIndex(v)
|
||||
for i, m := range v {
|
||||
mm[nameList[i]] = onlySecureData{
|
||||
APIKeys: m.APIKeys,
|
||||
}
|
||||
}
|
||||
|
||||
return mm, nil
|
||||
}
|
||||
|
||||
@@ -15,8 +15,6 @@ import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"gopkg.in/yaml.v3"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/credential"
|
||||
)
|
||||
|
||||
func TestSecurityConfig(t *testing.T) {
|
||||
@@ -227,134 +225,3 @@ skills:
|
||||
assert.Equal(t, "abc", cfg2.Tools.Web.Brave.APIKeys[1].raw)
|
||||
})
|
||||
}
|
||||
|
||||
func TestLoadSecurityValue(t *testing.T) {
|
||||
type valueStruct struct {
|
||||
Url string `json:"url,omitempty" yaml:"-"`
|
||||
Token *SecureString `json:"token,omitempty" yaml:"token,omitempty" env:"PICO_TOKEN"`
|
||||
ApiKeys SecureStrings `json:"api_keys,omitempty" yaml:"api_keys,omitempty" env:"PICO_API_KEYS"`
|
||||
}
|
||||
|
||||
type testStruct struct {
|
||||
Pico *valueStruct `json:"pico,omitempty" yaml:"pico,omitempty"`
|
||||
}
|
||||
|
||||
v1 := &testStruct{
|
||||
Pico: &valueStruct{
|
||||
Url: "https://example.com",
|
||||
Token: NewSecureString("token1"),
|
||||
ApiKeys: SecureStrings{NewSecureString("api-key1"), NewSecureString("api-key2")},
|
||||
},
|
||||
}
|
||||
bytes, err := yaml.Marshal(v1)
|
||||
assert.NoError(t, err)
|
||||
jsonBytes, err := json.Marshal(v1)
|
||||
assert.NoError(t, err)
|
||||
const want = `pico:
|
||||
token: token1
|
||||
api_keys:
|
||||
- api-key1
|
||||
- api-key2
|
||||
`
|
||||
const jsonPost = `{"pico":{"url":"https://example.com","token":"token0"}}`
|
||||
v0 := &testStruct{}
|
||||
err = json.Unmarshal([]byte(jsonPost), v0)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "https://example.com", v0.Pico.Url)
|
||||
assert.Equal(t, "token0", v0.Pico.Token.String())
|
||||
|
||||
const jsonWant = `{"pico":{"url":"https://example.com","token":"[NOT_HERE]","api_keys":"[NOT_HERE]"}}`
|
||||
assert.Equal(t, want, string(bytes))
|
||||
assert.Equal(t, jsonWant, string(jsonBytes))
|
||||
|
||||
v2 := &testStruct{}
|
||||
err = json.Unmarshal(jsonBytes, v2)
|
||||
assert.NoError(t, err)
|
||||
err = yaml.Unmarshal(bytes, v2)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "https://example.com", v2.Pico.Url)
|
||||
if v2.Pico.Token != nil {
|
||||
assert.Equal(t, "token1", v2.Pico.Token.String())
|
||||
assert.Equal(t, "token1", v2.Pico.Token.raw)
|
||||
}
|
||||
|
||||
v2.Pico.Token = NewSecureString("token1")
|
||||
v2.Pico.Token.raw = "abc"
|
||||
err = yaml.Unmarshal(bytes, v2)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "token1", v2.Pico.Token.raw)
|
||||
|
||||
os.Setenv("PICO_TOKEN", "token_env")
|
||||
err = env.Parse(v2)
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, v2.Pico.Token)
|
||||
assert.Equal(t, "token1", v2.Pico.Token.String())
|
||||
|
||||
v3 := &testStruct{Pico: &valueStruct{}}
|
||||
err = env.Parse(v3)
|
||||
assert.NoError(t, err)
|
||||
if v3.Pico.Token != nil {
|
||||
assert.Equal(t, "token_env", v3.Pico.Token.String())
|
||||
}
|
||||
|
||||
type toolsStruct struct {
|
||||
Pico valueStruct `json:"pico,omitempty" yaml:"pico,omitempty"`
|
||||
}
|
||||
|
||||
type testStruct2 struct {
|
||||
Tools toolsStruct `json:"tools,omitempty" yaml:",inline"`
|
||||
}
|
||||
|
||||
v4 := &testStruct2{
|
||||
Tools: toolsStruct{
|
||||
Pico: valueStruct{
|
||||
Url: "https://example.com",
|
||||
Token: NewSecureString("token1"),
|
||||
ApiKeys: SecureStrings{NewSecureString("api-key1"), NewSecureString("api-key2")},
|
||||
},
|
||||
},
|
||||
}
|
||||
bytes, err = yaml.Marshal(v4)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, want, string(bytes))
|
||||
jsonBytes, err = json.Marshal(v4)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(
|
||||
t,
|
||||
`{"tools":{"pico":{"url":"https://example.com","token":"[NOT_HERE]","api_keys":"[NOT_HERE]"}}}`,
|
||||
string(jsonBytes),
|
||||
)
|
||||
|
||||
v5 := &testStruct2{}
|
||||
err = json.Unmarshal(jsonBytes, v5)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "https://example.com", v5.Tools.Pico.Url)
|
||||
err = yaml.Unmarshal(bytes, v5)
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, v5.Tools.Pico.Token)
|
||||
assert.Equal(t, "token1", v5.Tools.Pico.Token.raw)
|
||||
|
||||
dir := t.TempDir()
|
||||
sshKeyPath := filepath.Join(dir, "picoclaw_ed25519.key")
|
||||
if err = os.WriteFile(sshKeyPath, []byte("fake-ssh-key-material\n"), 0o600); err != nil {
|
||||
t.Fatalf("setup: %v", err)
|
||||
}
|
||||
|
||||
const passphrase = "test-passphrase-32bytes-long-ok!"
|
||||
|
||||
t.Setenv(credential.SSHKeyPathEnvVar, sshKeyPath)
|
||||
|
||||
t.Setenv(credential.PassphraseEnvVar, passphrase)
|
||||
|
||||
v5.Tools.Pico.Token.Set("newtoken1")
|
||||
v5.Tools.Pico.ApiKeys[0].Set("newapi-key1")
|
||||
bytes, err = yaml.Marshal(v5)
|
||||
assert.NoError(t, err)
|
||||
t.Logf("yaml: %s", string(bytes))
|
||||
|
||||
v6 := &testStruct2{}
|
||||
err = yaml.Unmarshal(bytes, v6)
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, v6.Tools.Pico.Token)
|
||||
assert.Equal(t, "newtoken1", v6.Tools.Pico.Token.String())
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user