Files
picoclaw/pkg/agent/model_resolution.go
T
lxowalle 77b0c43392 refactor: support explicit provider field in model list entries (#2609)
* refactor: support explicit model list providers

* fix(web): preserve explicit model providers

* fix(web): preserve legacy provider prefixes on model updates

fix(models): normalize explicit provider-prefixed ids

fix(api): preserve legacy model updates across providers

fix(agent): preserve config identity for explicit provider refs

* fix ci
2026-04-22 11:28:47 +08:00

180 lines
3.9 KiB
Go

package agent
import (
"fmt"
"strings"
"github.com/sipeed/picoclaw/pkg/config"
"github.com/sipeed/picoclaw/pkg/providers"
)
func ensureProtocolModel(model string) string {
model = strings.TrimSpace(model)
if model == "" {
return ""
}
if strings.Contains(model, "/") {
return model
}
return "openai/" + model
}
func modelConfigIdentityKey(mc *config.ModelConfig) string {
if mc == nil {
return ""
}
if name := strings.TrimSpace(mc.ModelName); name != "" {
return "model_name:" + name
}
return ""
}
func candidateFromModelConfig(
defaultProvider string,
mc *config.ModelConfig,
) (providers.FallbackCandidate, bool) {
if mc == nil {
return providers.FallbackCandidate{}, false
}
protocol, modelID := providers.ExtractProtocol(mc)
if strings.TrimSpace(modelID) == "" {
return providers.FallbackCandidate{}, false
}
return providers.FallbackCandidate{
Provider: protocol,
Model: modelID,
RPM: mc.RPM,
IdentityKey: modelConfigIdentityKey(mc),
}, true
}
func lookupModelConfigByRef(cfg *config.Config, raw string) *config.ModelConfig {
raw = strings.TrimSpace(raw)
if raw == "" || cfg == nil {
return nil
}
if mc, err := cfg.GetModelConfig(raw); err == nil && mc != nil && strings.TrimSpace(mc.Model) != "" {
return mc
}
rawRef := providers.ParseModelRef(raw, "")
rawKey := ""
if rawRef != nil && strings.TrimSpace(rawRef.Provider) != "" && strings.TrimSpace(rawRef.Model) != "" {
rawKey = providers.ModelKey(rawRef.Provider, rawRef.Model)
}
for i := range cfg.ModelList {
mc := cfg.ModelList[i]
if mc == nil {
continue
}
fullModel := strings.TrimSpace(mc.Model)
if fullModel == "" {
continue
}
if fullModel == raw {
return mc
}
protocol, modelID := providers.ExtractProtocol(mc)
if modelID == raw {
return mc
}
if rawKey != "" && providers.ModelKey(protocol, modelID) == rawKey {
return mc
}
}
return nil
}
func resolveModelCandidate(
cfg *config.Config,
defaultProvider string,
raw string,
) (providers.FallbackCandidate, bool) {
raw = strings.TrimSpace(raw)
if raw == "" {
return providers.FallbackCandidate{}, false
}
if mc := lookupModelConfigByRef(cfg, raw); mc != nil {
return candidateFromModelConfig(defaultProvider, mc)
}
ref := providers.ParseModelRef(raw, defaultProvider)
if ref == nil {
return providers.FallbackCandidate{}, false
}
return providers.FallbackCandidate{
Provider: ref.Provider,
Model: ref.Model,
}, true
}
func resolveModelCandidates(
cfg *config.Config,
defaultProvider string,
primary string,
fallbacks []string,
) []providers.FallbackCandidate {
seen := make(map[string]bool)
candidates := make([]providers.FallbackCandidate, 0, 1+len(fallbacks))
addCandidate := func(raw string) {
candidate, ok := resolveModelCandidate(cfg, defaultProvider, raw)
if !ok {
return
}
key := candidate.StableKey()
if seen[key] {
return
}
seen[key] = true
candidates = append(candidates, candidate)
}
addCandidate(primary)
for _, fallback := range fallbacks {
addCandidate(fallback)
}
return candidates
}
func resolvedCandidateModel(candidates []providers.FallbackCandidate, fallback string) string {
if len(candidates) > 0 && strings.TrimSpace(candidates[0].Model) != "" {
return candidates[0].Model
}
return fallback
}
func resolvedCandidateProvider(candidates []providers.FallbackCandidate, fallback string) string {
if len(candidates) > 0 && strings.TrimSpace(candidates[0].Provider) != "" {
return candidates[0].Provider
}
return fallback
}
func resolvedModelConfig(cfg *config.Config, modelName, workspace string) (*config.ModelConfig, error) {
if cfg == nil {
return nil, fmt.Errorf("config is nil")
}
modelCfg, err := cfg.GetModelConfig(strings.TrimSpace(modelName))
if err != nil {
return nil, err
}
clone := *modelCfg
if clone.Workspace == "" {
clone.Workspace = workspace
}
return &clone, nil
}