mirror of
https://github.com/sipeed/picoclaw.git
synced 2026-06-12 18:08:54 +00:00
328 lines
7.3 KiB
Go
328 lines
7.3 KiB
Go
package config
|
||
|
||
import (
|
||
"encoding/json"
|
||
"fmt"
|
||
"path/filepath"
|
||
"runtime"
|
||
"strings"
|
||
"sync"
|
||
|
||
"gopkg.in/yaml.v3"
|
||
|
||
"github.com/sipeed/picoclaw/pkg/credential"
|
||
"github.com/sipeed/picoclaw/pkg/logger"
|
||
)
|
||
|
||
// FlexibleStringSlice is a []string that also accepts JSON numbers,
|
||
// so allow_from can contain both "123" and 123.
|
||
// It also supports parsing comma-separated strings from environment variables,
|
||
// including both English (,) and Chinese (,) commas.
|
||
type FlexibleStringSlice []string
|
||
|
||
func (f *FlexibleStringSlice) UnmarshalJSON(data []byte) error {
|
||
// Accept a single JSON string for convenience, e.g.:
|
||
// "text": "Thinking..."
|
||
var singleString string
|
||
if err := json.Unmarshal(data, &singleString); err == nil {
|
||
*f = FlexibleStringSlice{singleString}
|
||
return nil
|
||
}
|
||
|
||
// Accept a single JSON number too, to keep symmetry with mixed allow_from
|
||
// payloads that may contain numeric identifiers.
|
||
var singleNumber float64
|
||
if err := json.Unmarshal(data, &singleNumber); err == nil {
|
||
*f = FlexibleStringSlice{fmt.Sprintf("%.0f", singleNumber)}
|
||
return nil
|
||
}
|
||
|
||
// Try []string first
|
||
var ss []string
|
||
if err := json.Unmarshal(data, &ss); err == nil {
|
||
*f = ss
|
||
return nil
|
||
}
|
||
|
||
// Try []interface{} to handle mixed types
|
||
var raw []any
|
||
if err := json.Unmarshal(data, &raw); err != nil {
|
||
var s string
|
||
// fail over to compatible to old format string
|
||
if err = json.Unmarshal(data, &s); err != nil {
|
||
return err
|
||
}
|
||
*f = []string{s}
|
||
return nil
|
||
}
|
||
|
||
result := make([]string, 0, len(raw))
|
||
for _, v := range raw {
|
||
switch val := v.(type) {
|
||
case string:
|
||
result = append(result, val)
|
||
case float64:
|
||
result = append(result, fmt.Sprintf("%.0f", val))
|
||
default:
|
||
result = append(result, fmt.Sprintf("%v", val))
|
||
}
|
||
}
|
||
*f = result
|
||
return nil
|
||
}
|
||
|
||
// UnmarshalText implements encoding.TextUnmarshaler to support env variable parsing.
|
||
// It handles comma-separated values with both English (,) and Chinese (,) commas.
|
||
func (f *FlexibleStringSlice) UnmarshalText(text []byte) error {
|
||
if len(text) == 0 {
|
||
*f = nil
|
||
return nil
|
||
}
|
||
|
||
s := string(text)
|
||
// Replace Chinese comma with English comma, then split
|
||
s = strings.ReplaceAll(s, ",", ",")
|
||
parts := strings.Split(s, ",")
|
||
|
||
result := make([]string, 0, len(parts))
|
||
for _, part := range parts {
|
||
part = strings.TrimSpace(part)
|
||
if part != "" {
|
||
result = append(result, part)
|
||
}
|
||
}
|
||
*f = result
|
||
return nil
|
||
}
|
||
|
||
const (
|
||
notHere = `"[NOT_HERE]"`
|
||
)
|
||
|
||
// SecureStrings is a slice of SecureString
|
||
type SecureStrings []*SecureString
|
||
|
||
// Values returns the decrypted/resolved values
|
||
func (s *SecureStrings) Values() []string {
|
||
if s == nil {
|
||
return nil
|
||
}
|
||
keys := make([]string, len(*s))
|
||
for i, k := range *s {
|
||
keys[i] = k.String()
|
||
}
|
||
return unique(keys)
|
||
}
|
||
|
||
func SimpleSecureStrings(val ...string) SecureStrings {
|
||
val = unique(val)
|
||
vv := make(SecureStrings, len(val))
|
||
for i, s := range val {
|
||
vv[i] = NewSecureString(s)
|
||
}
|
||
return vv
|
||
}
|
||
|
||
// unique returns a new slice with duplicate elements removed.
|
||
func unique[T comparable](input []T) []T {
|
||
m := make(map[T]struct{})
|
||
var result []T
|
||
for _, v := range input {
|
||
if _, ok := m[v]; !ok {
|
||
m[v] = struct{}{}
|
||
result = append(result, v)
|
||
}
|
||
}
|
||
return result
|
||
}
|
||
|
||
func (s SecureStrings) MarshalJSON() ([]byte, error) {
|
||
return []byte(notHere), nil
|
||
}
|
||
|
||
func (s *SecureStrings) UnmarshalJSON(value []byte) error {
|
||
if string(value) == notHere {
|
||
return nil
|
||
}
|
||
var v []*SecureString
|
||
err := json.Unmarshal(value, &v)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
*s = v
|
||
return nil
|
||
}
|
||
|
||
// SecureString the string value that can be decrypted or resolved
|
||
//
|
||
//nolint:recvcheck
|
||
type SecureString struct {
|
||
resolved string // Decrypted/resolved value returned by String()
|
||
raw string // Persisted raw value (enc://, file://, or plaintext)
|
||
}
|
||
|
||
func callerFromYaml() bool {
|
||
_, file, _, ok := runtime.Caller(2)
|
||
if ok {
|
||
d := filepath.Dir(file)
|
||
// check the caller is from yaml.v
|
||
if !strings.Contains(d, "yaml.v") {
|
||
return true
|
||
}
|
||
}
|
||
return false
|
||
}
|
||
|
||
// IsZero returns true if the SecureString is empty
|
||
// if caller not yaml, just return true for prevent marshal this field
|
||
func (s SecureString) IsZero() bool {
|
||
if callerFromYaml() {
|
||
return true
|
||
}
|
||
return s.resolved == ""
|
||
}
|
||
|
||
func NewSecureString(value string) *SecureString {
|
||
s := &SecureString{}
|
||
if err := s.fromRaw(value); err != nil {
|
||
logger.Warn(fmt.Sprintf("NewSecureString.fromRaw error: %s", err))
|
||
}
|
||
return s
|
||
}
|
||
|
||
func (s *SecureString) String() string {
|
||
if s == nil {
|
||
return ""
|
||
}
|
||
return s.resolved
|
||
}
|
||
|
||
func (s *SecureString) Set(value string) *SecureString {
|
||
s.resolved = value
|
||
s.raw = ""
|
||
return s
|
||
}
|
||
|
||
func (s SecureString) MarshalJSON() ([]byte, error) {
|
||
return []byte(notHere), nil
|
||
}
|
||
|
||
func (s *SecureString) UnmarshalJSON(value []byte) error {
|
||
if string(value) == notHere {
|
||
return nil
|
||
}
|
||
var v string
|
||
if err := json.Unmarshal(value, &v); err != nil {
|
||
return err
|
||
}
|
||
return s.fromRaw(v)
|
||
}
|
||
|
||
func (s SecureString) MarshalYAML() (any, error) {
|
||
// Preserve raw value if it is already a reference (enc:// or file://)
|
||
if strings.HasPrefix(s.raw, credential.EncScheme) || strings.HasPrefix(s.raw, credential.FileScheme) {
|
||
return s.raw, nil
|
||
}
|
||
// If resolved is a reference format (e.g. set via Set), copy back to raw
|
||
if strings.HasPrefix(s.resolved, credential.EncScheme) || strings.HasPrefix(s.resolved, credential.FileScheme) {
|
||
s.raw = s.resolved
|
||
return s.raw, nil
|
||
}
|
||
// Try to encrypt the resolved value
|
||
if passphrase := credential.PassphraseProvider(); passphrase != "" {
|
||
encrypted, err := credential.Encrypt(passphrase, "", s.resolved)
|
||
if err != nil {
|
||
logger.Errorf("Encrypt error: %v", err)
|
||
return nil, err
|
||
}
|
||
s.raw = encrypted
|
||
} else {
|
||
s.raw = s.resolved
|
||
}
|
||
return s.raw, nil
|
||
}
|
||
|
||
func (s *SecureString) UnmarshalYAML(value *yaml.Node) error {
|
||
return s.fromRaw(value.Value)
|
||
}
|
||
|
||
func (s *SecureString) fromRaw(v string) error {
|
||
s.raw = v
|
||
vv, err := resolveKey(v)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
s.resolved = vv
|
||
return nil
|
||
}
|
||
|
||
var (
|
||
secResolverMu sync.RWMutex
|
||
secResolver *credential.Resolver
|
||
)
|
||
|
||
func updateResolver(path string) {
|
||
secResolverMu.Lock()
|
||
defer secResolverMu.Unlock()
|
||
secResolver = credential.NewResolver(path)
|
||
}
|
||
|
||
func resolveKey(v string) (string, error) {
|
||
secResolverMu.RLock()
|
||
resolver := secResolver
|
||
secResolverMu.RUnlock()
|
||
if resolver == nil {
|
||
resolver = credential.NewResolver("")
|
||
}
|
||
if strings.HasPrefix(v, "enc://") || strings.HasPrefix(v, "file://") {
|
||
decrypted, err := resolver.Resolve(v)
|
||
if err != nil {
|
||
logger.Errorf("Resolve error: %v", err)
|
||
return "", err
|
||
}
|
||
return decrypted, nil
|
||
}
|
||
return v, nil
|
||
}
|
||
|
||
func (s *SecureString) UnmarshalText(text []byte) error {
|
||
v := string(text)
|
||
return s.fromRaw(v)
|
||
}
|
||
|
||
type SecureModelList []*ModelConfig
|
||
|
||
func (v *SecureModelList) UnmarshalYAML(value *yaml.Node) error {
|
||
mm := make(map[string]*ModelConfig)
|
||
if err := value.Decode(&mm); err != nil {
|
||
logger.Errorf("Decode error: %v", err)
|
||
return err
|
||
}
|
||
nameList := toNameIndex(*v)
|
||
for i, m := range *v {
|
||
sec := mm[nameList[i]]
|
||
if sec == nil {
|
||
sec = mm[m.ModelName]
|
||
}
|
||
if sec != nil {
|
||
m.APIKeys = sec.APIKeys
|
||
}
|
||
}
|
||
return nil
|
||
}
|
||
|
||
func (v SecureModelList) MarshalYAML() (any, error) {
|
||
type onlySecureData struct {
|
||
APIKeys SecureStrings `yaml:"api_keys,omitempty"`
|
||
}
|
||
mm := make(map[string]onlySecureData)
|
||
nameList := toNameIndex(v)
|
||
for i, m := range v {
|
||
mm[nameList[i]] = onlySecureData{
|
||
APIKeys: m.APIKeys,
|
||
}
|
||
}
|
||
|
||
return mm, nil
|
||
}
|