mirror of
https://github.com/sipeed/picoclaw.git
synced 2026-06-12 18:08:54 +00:00
feat(provider,web,asr): enhance model management with explicit provider metadata (#2701)
* feat(provider,web): enhance model management with provider options * fix(asr): enhance compatibility for ElevenLabs transcription model * fix(provider,web): align provider availability predicates and add flow gating * fix(web,asr): preserve legacy elevenlabs transcription configs * fix(provider,web,asr): normalize elevenlabs configs and gate default chat models * fix: tighten provider catalog and elevenlabs compatibility
This commit is contained in:
@@ -382,6 +382,9 @@ func (h *Handler) gatewayStartReady() (bool, string, error) {
|
||||
if modelCfg == nil {
|
||||
return false, fmt.Sprintf("default model %q is invalid", modelName), nil
|
||||
}
|
||||
if !defaultModelAllowedForModelConfig(modelCfg) {
|
||||
return false, fmt.Sprintf("default model %q is not usable for chat", modelName), nil
|
||||
}
|
||||
|
||||
if !hasModelConfiguration(modelCfg) {
|
||||
return false, fmt.Sprintf("default model %q has no credentials configured", modelName), nil
|
||||
|
||||
@@ -357,6 +357,44 @@ func TestGatewayStartReady_NoDefaultModel(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestGatewayStartReady_RejectsASROnlyDefaultModel(t *testing.T) {
|
||||
configPath, cleanup := setupOAuthTestEnv(t)
|
||||
defer cleanup()
|
||||
|
||||
cfg, err := config.LoadConfig(configPath)
|
||||
if err != nil {
|
||||
t.Fatalf("LoadConfig() error = %v", err)
|
||||
}
|
||||
cfg.ModelList = []*config.ModelConfig{{
|
||||
ModelName: "elevenlabs-asr",
|
||||
Provider: "elevenlabs",
|
||||
Model: "scribe_v1",
|
||||
APIKeys: config.SimpleSecureStrings("sk_elevenlabs_test"),
|
||||
}}
|
||||
cfg.Agents.Defaults.ModelName = "elevenlabs-asr"
|
||||
|
||||
err = config.SaveConfig(configPath, cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("SaveConfig() error = %v", err)
|
||||
}
|
||||
|
||||
h := NewHandler(configPath)
|
||||
ready, reason, err := h.gatewayStartReady()
|
||||
if err != nil {
|
||||
t.Fatalf("gatewayStartReady() error = %v", err)
|
||||
}
|
||||
if ready {
|
||||
t.Fatal("gatewayStartReady() ready = true, want false")
|
||||
}
|
||||
if reason != `default model "elevenlabs-asr" is not usable for chat` {
|
||||
t.Fatalf(
|
||||
"gatewayStartReady() reason = %q, want %q",
|
||||
reason,
|
||||
`default model "elevenlabs-asr" is not usable for chat`,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLooksLikeGatewayCommandLine(t *testing.T) {
|
||||
cases := []struct {
|
||||
name string
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os/exec"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
@@ -47,6 +48,7 @@ var (
|
||||
probeTCPServiceFunc = probeTCPService
|
||||
probeOllamaModelFunc = probeOllamaModel
|
||||
probeOpenAICompatibleModelFunc = probeOpenAICompatibleModel
|
||||
probeCommandAvailableFunc = probeCommandAvailable
|
||||
modelProbeNowFunc = time.Now
|
||||
modelProbeState = newModelProbeCacheState()
|
||||
)
|
||||
@@ -83,17 +85,23 @@ func (s *modelProbeCacheState) resetForTest() {
|
||||
}
|
||||
|
||||
func hasModelConfiguration(m *config.ModelConfig) bool {
|
||||
protocol := modelProtocol(m)
|
||||
authMethod := strings.ToLower(strings.TrimSpace(m.AuthMethod))
|
||||
apiKey := strings.TrimSpace(m.APIKey())
|
||||
|
||||
if authMethod == "oauth" || authMethod == "token" {
|
||||
if provider, ok := oauthProviderForModel(m); ok {
|
||||
cred, err := oauthGetCredential(provider)
|
||||
if err != nil || cred == nil {
|
||||
return false
|
||||
}
|
||||
return strings.TrimSpace(cred.AccessToken) != "" || strings.TrimSpace(cred.RefreshToken) != ""
|
||||
if configured, checked := hasStoredOAuthCredential(m); checked {
|
||||
return configured
|
||||
}
|
||||
}
|
||||
|
||||
if authMethod == "" && providerUsesImplicitOAuth(protocol) {
|
||||
if configured, checked := hasStoredOAuthCredential(m); checked {
|
||||
return configured
|
||||
}
|
||||
}
|
||||
|
||||
if providerUsesAmbientCredentials(protocol) {
|
||||
return true
|
||||
}
|
||||
|
||||
@@ -104,6 +112,40 @@ func hasModelConfiguration(m *config.ModelConfig) bool {
|
||||
return apiKey != ""
|
||||
}
|
||||
|
||||
func hasStoredOAuthCredential(m *config.ModelConfig) (bool, bool) {
|
||||
provider, ok := oauthProviderForModel(m)
|
||||
if !ok {
|
||||
return false, false
|
||||
}
|
||||
cred, err := oauthGetCredential(provider)
|
||||
if err != nil || cred == nil {
|
||||
return false, true
|
||||
}
|
||||
return strings.TrimSpace(cred.AccessToken) != "" || strings.TrimSpace(cred.RefreshToken) != "", true
|
||||
}
|
||||
|
||||
func providerUsesImplicitOAuth(protocol string) bool {
|
||||
switch protocol {
|
||||
case "antigravity", "google-antigravity":
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func providerUsesAmbientCredentials(protocol string) bool {
|
||||
switch protocol {
|
||||
case "bedrock":
|
||||
// Bedrock relies on the AWS SDK credential chain instead of an explicit
|
||||
// API key stored in ModelConfig. We cannot reliably preflight every AWS
|
||||
// credential source here, so avoid misclassifying valid environments as
|
||||
// "unconfigured" and defer concrete credential failures to runtime.
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func modelConfigurationStatus(m *config.ModelConfig) modelConfigurationSummary {
|
||||
if !hasModelConfiguration(m) {
|
||||
return modelConfigurationSummary{Available: false, Status: modelStatusUnconfigured}
|
||||
@@ -180,8 +222,10 @@ func runLocalModelProbe(m *config.ModelConfig) bool {
|
||||
return probeOpenAICompatibleModelFunc(apiBase, modelID, m.APIKey())
|
||||
case "github-copilot", "copilot":
|
||||
return probeTCPServiceFunc(apiBase)
|
||||
case "claude-cli", "claudecli", "codex-cli", "codexcli":
|
||||
return true
|
||||
case "claude-cli", "claudecli":
|
||||
return probeCommandAvailableFunc("claude")
|
||||
case "codex-cli", "codexcli":
|
||||
return probeCommandAvailableFunc("codex")
|
||||
default:
|
||||
if hasLocalAPIBase(apiBase) {
|
||||
return probeOpenAICompatibleModelFunc(apiBase, modelID, m.APIKey())
|
||||
@@ -190,6 +234,11 @@ func runLocalModelProbe(m *config.ModelConfig) bool {
|
||||
}
|
||||
}
|
||||
|
||||
func probeCommandAvailable(command string) bool {
|
||||
_, err := exec.LookPath(command)
|
||||
return err == nil
|
||||
}
|
||||
|
||||
func modelProbeCacheKey(m *config.ModelConfig) string {
|
||||
protocol, modelID := splitModel(m)
|
||||
|
||||
|
||||
+219
-15
@@ -9,6 +9,7 @@ import (
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/audio/asr"
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
"github.com/sipeed/picoclaw/pkg/logger"
|
||||
"github.com/sipeed/picoclaw/pkg/providers"
|
||||
@@ -45,11 +46,184 @@ type modelResponse struct {
|
||||
ExtraBody map[string]any `json:"extra_body,omitempty"`
|
||||
CustomHeaders map[string]string `json:"custom_headers,omitempty"`
|
||||
// Meta
|
||||
Enabled bool `json:"enabled"`
|
||||
Available bool `json:"available"`
|
||||
Status string `json:"status"`
|
||||
IsDefault bool `json:"is_default"`
|
||||
IsVirtual bool `json:"is_virtual"`
|
||||
Enabled bool `json:"enabled"`
|
||||
Available bool `json:"available"`
|
||||
Status string `json:"status"`
|
||||
IsDefault bool `json:"is_default"`
|
||||
IsVirtual bool `json:"is_virtual"`
|
||||
DefaultModelAllowed bool `json:"default_model_allowed"`
|
||||
}
|
||||
|
||||
func normalizeStoredModelConfig(mc *config.ModelConfig) bool {
|
||||
if mc == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
changed := false
|
||||
model := strings.TrimSpace(mc.Model)
|
||||
if model != mc.Model {
|
||||
mc.Model = model
|
||||
changed = true
|
||||
}
|
||||
provider := strings.TrimSpace(mc.Provider)
|
||||
if provider != mc.Provider {
|
||||
mc.Provider = provider
|
||||
changed = true
|
||||
}
|
||||
authMethod := strings.ToLower(strings.TrimSpace(mc.AuthMethod))
|
||||
if authMethod != mc.AuthMethod {
|
||||
mc.AuthMethod = authMethod
|
||||
changed = true
|
||||
}
|
||||
|
||||
if provider != "" {
|
||||
normalizedProvider := providers.NormalizeProvider(provider)
|
||||
if providers.IsSupportedModelProvider(normalizedProvider) && normalizedProvider != provider {
|
||||
mc.Provider = normalizedProvider
|
||||
changed = true
|
||||
}
|
||||
if mc.Provider == "elevenlabs" {
|
||||
if _, strippedModel, found := strings.Cut(
|
||||
model,
|
||||
"/",
|
||||
); found &&
|
||||
providers.NormalizeProvider(strings.TrimSpace(provider)) == "elevenlabs" {
|
||||
strippedModel = strings.TrimSpace(strippedModel)
|
||||
if strippedModel != "" && strippedModel != mc.Model {
|
||||
mc.Model = strippedModel
|
||||
changed = true
|
||||
}
|
||||
}
|
||||
if strings.TrimSpace(mc.Model) != asr.ElevenLabsSupportedModelID() {
|
||||
mc.Model = asr.ElevenLabsSupportedModelID()
|
||||
changed = true
|
||||
}
|
||||
}
|
||||
return changed
|
||||
}
|
||||
|
||||
effectiveProvider, modelID := providers.SplitModelProviderAndID(model, "openai")
|
||||
if effectiveProvider == "" {
|
||||
return changed
|
||||
}
|
||||
if mc.Provider != effectiveProvider {
|
||||
mc.Provider = effectiveProvider
|
||||
changed = true
|
||||
}
|
||||
if mc.Model != modelID {
|
||||
mc.Model = modelID
|
||||
changed = true
|
||||
}
|
||||
return changed
|
||||
}
|
||||
|
||||
func normalizeIncomingModelConfig(mc *config.ModelConfig) {
|
||||
if mc == nil {
|
||||
return
|
||||
}
|
||||
|
||||
mc.Model = strings.TrimSpace(mc.Model)
|
||||
mc.Provider = strings.TrimSpace(mc.Provider)
|
||||
mc.AuthMethod = strings.ToLower(strings.TrimSpace(mc.AuthMethod))
|
||||
if mc.Provider == "" {
|
||||
mc.Provider, mc.Model = providers.SplitModelProviderAndID(mc.Model, "openai")
|
||||
} else {
|
||||
mc.Provider = providers.NormalizeProvider(mc.Provider)
|
||||
if mc.Provider == "elevenlabs" {
|
||||
if _, strippedModel, found := strings.Cut(mc.Model, "/"); found {
|
||||
strippedModel = strings.TrimSpace(strippedModel)
|
||||
if strippedModel != "" {
|
||||
mc.Model = strippedModel
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if mc.Provider == "antigravity" && mc.AuthMethod == "" {
|
||||
mc.AuthMethod = "oauth"
|
||||
}
|
||||
}
|
||||
|
||||
func createAllowedForProvider(provider string) bool {
|
||||
normalized := providers.NormalizeProvider(provider)
|
||||
switch normalized {
|
||||
case "bedrock":
|
||||
// Bedrock currently authenticates through the AWS SDK credential chain
|
||||
// (env vars, shared profiles, IAM roles, etc.), and this Web layer does
|
||||
// not yet have a reliable preflight check for those credential sources.
|
||||
// Keep it creatable in the catalog and let provider construction/runtime
|
||||
// return the concrete AWS error when the environment is incomplete.
|
||||
return true
|
||||
case "claude-cli", "codex-cli":
|
||||
return cliProviderCreateAllowedFromCurrentStatus(normalized)
|
||||
default:
|
||||
return providers.IsCreatableModelProvider(normalized)
|
||||
}
|
||||
}
|
||||
|
||||
// cliProviderCreateAllowedFromCurrentStatus intentionally reuses the existing
|
||||
// local model status pipeline so provider catalog gating follows the same CLI
|
||||
// executable probe used by launcher readiness.
|
||||
func cliProviderCreateAllowedFromCurrentStatus(provider string) bool {
|
||||
status := modelConfigurationStatus(&config.ModelConfig{
|
||||
Provider: provider,
|
||||
Model: provider,
|
||||
})
|
||||
return status.Available
|
||||
}
|
||||
|
||||
func modelProviderOptionsForResponse() []providers.ModelProviderOption {
|
||||
options := providers.ModelProviderOptions()
|
||||
for i := range options {
|
||||
options[i].CreateAllowed = createAllowedForProvider(options[i].ID)
|
||||
}
|
||||
return options
|
||||
}
|
||||
|
||||
func defaultModelAllowedForModelConfig(mc *config.ModelConfig) bool {
|
||||
provider, _ := providers.ExtractProtocol(mc)
|
||||
return providers.IsDefaultModelProvider(provider)
|
||||
}
|
||||
|
||||
func validateIncomingModelConfig(mc *config.ModelConfig, existing *config.ModelConfig) error {
|
||||
if mc == nil {
|
||||
return fmt.Errorf("model config is required")
|
||||
}
|
||||
if err := mc.Validate(); err != nil {
|
||||
return err
|
||||
}
|
||||
if strings.TrimSpace(mc.Provider) == "" {
|
||||
return fmt.Errorf("provider is required")
|
||||
}
|
||||
if !providers.IsSupportedModelProvider(mc.Provider) {
|
||||
return fmt.Errorf("provider %q is not supported", mc.Provider)
|
||||
}
|
||||
if mc.Provider == "elevenlabs" && strings.TrimSpace(mc.Model) != asr.ElevenLabsSupportedModelID() {
|
||||
return fmt.Errorf("provider %q only supports model %q", mc.Provider, asr.ElevenLabsSupportedModelID())
|
||||
}
|
||||
if !createAllowedForProvider(mc.Provider) {
|
||||
if existing == nil {
|
||||
return fmt.Errorf("provider %q is not available for new models", mc.Provider)
|
||||
}
|
||||
existingProvider, _ := providers.ExtractProtocol(existing)
|
||||
if providers.NormalizeProvider(existingProvider) != mc.Provider {
|
||||
return fmt.Errorf("provider %q is not available for selection", mc.Provider)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func normalizeStoredModelProviders(cfg *config.Config) bool {
|
||||
if cfg == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
changed := false
|
||||
for _, model := range cfg.ModelList {
|
||||
if normalizeStoredModelConfig(model) {
|
||||
changed = true
|
||||
}
|
||||
}
|
||||
return changed
|
||||
}
|
||||
|
||||
// handleListModels returns all model_list entries with masked API keys.
|
||||
@@ -62,6 +236,10 @@ func (h *Handler) handleListModels(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
// Normalize legacy provider/model storage in memory so GET can round-trip
|
||||
// through the current API shape without mutating the on-disk config.
|
||||
normalizeStoredModelProviders(cfg)
|
||||
|
||||
defaultModel := cfg.Agents.Defaults.GetModelName()
|
||||
modelStatuses := make([]modelConfigurationSummary, len(cfg.ModelList))
|
||||
|
||||
@@ -101,14 +279,16 @@ func (h *Handler) handleListModels(w http.ResponseWriter, r *http.Request) {
|
||||
Status: modelStatuses[i].Status,
|
||||
IsDefault: m.ModelName == defaultModel,
|
||||
IsVirtual: m.IsVirtual(),
|
||||
DefaultModelAllowed: defaultModelAllowedForModelConfig(m),
|
||||
})
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(map[string]any{
|
||||
"models": models,
|
||||
"total": len(models),
|
||||
"default_model": defaultModel,
|
||||
"models": models,
|
||||
"total": len(models),
|
||||
"default_model": defaultModel,
|
||||
"provider_options": modelProviderOptionsForResponse(),
|
||||
})
|
||||
}
|
||||
|
||||
@@ -134,7 +314,9 @@ func (h *Handler) handleAddModel(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
if err = mc.Validate(); err != nil {
|
||||
normalizeIncomingModelConfig(&mc.ModelConfig)
|
||||
|
||||
if err = validateIncomingModelConfig(&mc.ModelConfig, nil); err != nil {
|
||||
http.Error(w, fmt.Sprintf("Validation error: %v", err), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
@@ -150,6 +332,7 @@ func (h *Handler) handleAddModel(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
cfg.ModelList = append(cfg.ModelList, &mc.ModelConfig)
|
||||
normalizeStoredModelProviders(cfg)
|
||||
|
||||
if err := config.SaveConfig(h.configPath, cfg); err != nil {
|
||||
http.Error(w, fmt.Sprintf("Failed to save config: %v", err), http.StatusInternalServerError)
|
||||
@@ -200,11 +383,6 @@ func (h *Handler) handleUpdateModel(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
if err = mc.Validate(); err != nil {
|
||||
http.Error(w, fmt.Sprintf("Validation error: %v", err), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
cfg, err := config.LoadConfig(h.configPath)
|
||||
if err != nil {
|
||||
http.Error(w, fmt.Sprintf("Failed to load config: %v", err), http.StatusInternalServerError)
|
||||
@@ -253,9 +431,9 @@ func (h *Handler) handleUpdateModel(w http.ResponseWriter, r *http.Request) {
|
||||
// This keeps provider-omitted updates backward-compatible even when an
|
||||
// older client edits the visible model ID.
|
||||
if strings.TrimSpace(cfg.ModelList[idx].Provider) == "" {
|
||||
existingProtocol, existingModelID := providers.ExtractProtocol(cfg.ModelList[idx])
|
||||
existingRawModel := strings.TrimSpace(cfg.ModelList[idx].Model)
|
||||
incomingModel := strings.TrimSpace(mc.Model)
|
||||
existingProtocol, existingModelID := providers.ExtractProtocol(cfg.ModelList[idx])
|
||||
if existingRawModel != "" && existingRawModel != existingModelID && incomingModel != "" {
|
||||
if incomingModel == existingModelID {
|
||||
mc.Model = existingRawModel
|
||||
@@ -272,7 +450,20 @@ func (h *Handler) handleUpdateModel(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
}
|
||||
|
||||
normalizeIncomingModelConfig(&mc.ModelConfig)
|
||||
if err = validateIncomingModelConfig(&mc.ModelConfig, cfg.ModelList[idx]); err != nil {
|
||||
http.Error(w, fmt.Sprintf("Validation error: %v", err), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if cfg.Agents.Defaults.ModelName == cfg.ModelList[idx].ModelName &&
|
||||
!defaultModelAllowedForModelConfig(&mc.ModelConfig) {
|
||||
// Allow users to recover from legacy/invalid defaults by saving the model
|
||||
// and clearing the default chat model reference in the same write.
|
||||
cfg.Agents.Defaults.ModelName = ""
|
||||
}
|
||||
|
||||
cfg.ModelList[idx] = &mc.ModelConfig
|
||||
normalizeStoredModelProviders(cfg)
|
||||
|
||||
logger.Debugf("update model config: %#v", mc.ModelConfig)
|
||||
|
||||
@@ -372,6 +563,19 @@ func (h *Handler) handleSetDefaultModel(w http.ResponseWriter, r *http.Request)
|
||||
http.Error(w, fmt.Sprintf("Cannot set virtual model %q as default", req.ModelName), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
for _, m := range cfg.ModelList {
|
||||
if m.ModelName == req.ModelName {
|
||||
if !defaultModelAllowedForModelConfig(m) {
|
||||
http.Error(
|
||||
w,
|
||||
fmt.Sprintf("Model %q cannot be used as the default chat model", req.ModelName),
|
||||
http.StatusBadRequest,
|
||||
)
|
||||
return
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
cfg.Agents.Defaults.ModelName = req.ModelName
|
||||
|
||||
|
||||
+1021
-19
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user