From 4d2b24452232f438cd9d3cb320223d85fdbb42d1 Mon Sep 17 00:00:00 2001 From: RussellLuo Date: Sun, 22 Mar 2026 23:40:13 +0800 Subject: [PATCH] refactor(voice): share audio format support and restrict transcriber selection --- pkg/utils/media.go | 17 ++++++++++++++++- pkg/voice/audio_model_transcriber.go | 22 +--------------------- pkg/voice/transcriber.go | 27 ++++++++++++++++++++++++++- pkg/voice/transcriber_test.go | 25 +++++++++++++++++++++++++ 4 files changed, 68 insertions(+), 23 deletions(-) diff --git a/pkg/utils/media.go b/pkg/utils/media.go index 82e9f5f45..bf97a9756 100644 --- a/pkg/utils/media.go +++ b/pkg/utils/media.go @@ -1,6 +1,7 @@ package utils import ( + "fmt" "io" "net/http" "net/url" @@ -15,9 +16,23 @@ import ( "github.com/sipeed/picoclaw/pkg/media" ) +var ( + audioExtensions = []string{".mp3", ".wav", ".ogg", ".m4a", ".flac", ".aac", ".wma"} +) + +func AudioFormat(path string) (string, error) { + ext := strings.ToLower(filepath.Ext(path)) + for _, supportedExt := range audioExtensions { + if ext == supportedExt { + return strings.TrimPrefix(ext, "."), nil + } + } + + return "", fmt.Errorf("unsupported audio format for %q", path) +} + // IsAudioFile checks if a file is an audio file based on its filename extension and content type. func IsAudioFile(filename, contentType string) bool { - audioExtensions := []string{".mp3", ".wav", ".ogg", ".m4a", ".flac", ".aac", ".wma"} audioTypes := []string{"audio/", "application/ogg", "application/x-ogg"} for _, ext := range audioExtensions { diff --git a/pkg/voice/audio_model_transcriber.go b/pkg/voice/audio_model_transcriber.go index 096e832fa..94486b5e4 100644 --- a/pkg/voice/audio_model_transcriber.go +++ b/pkg/voice/audio_model_transcriber.go @@ -5,7 +5,6 @@ import ( "encoding/base64" "fmt" "os" - "path/filepath" "strings" "github.com/sipeed/picoclaw/pkg/config" @@ -24,25 +23,6 @@ const ( defaultTranscriptionPrompt = "Transcribe this audio." ) -func audioFormat(path string) (string, error) { - switch strings.ToLower(filepath.Ext(strings.TrimPrefix(path, "file://"))) { - case ".wav": - return "wav", nil - case ".mp3": - return "mp3", nil - case ".aiff", ".aif": - return "aiff", nil - case ".aac": - return "aac", nil - case ".ogg": - return "ogg", nil - case ".flac": - return "flac", nil - default: - return "", fmt.Errorf("unsupported audio format for %q", path) - } -} - func NewAudioModelTranscriber(modelCfg *config.ModelConfig) *AudioModelTranscriber { if modelCfg == nil { return nil @@ -79,7 +59,7 @@ func (t *AudioModelTranscriber) Transcribe(ctx context.Context, audioFilePath st return nil, fmt.Errorf("failed to read audio file: %w", err) } - format, err := audioFormat(audioFilePath) + format, err := utils.AudioFormat(audioFilePath) if err != nil { logger.ErrorCF("voice", "Failed to detect audio format", map[string]any{"path": audioFilePath, "error": err}) return nil, err diff --git a/pkg/voice/transcriber.go b/pkg/voice/transcriber.go index 36ee92881..f3e6af71e 100644 --- a/pkg/voice/transcriber.go +++ b/pkg/voice/transcriber.go @@ -5,6 +5,7 @@ import ( "strings" "github.com/sipeed/picoclaw/pkg/config" + "github.com/sipeed/picoclaw/pkg/providers" ) type Transcriber interface { @@ -18,6 +19,28 @@ type TranscriptionResponse struct { Duration float64 `json:"duration,omitempty"` } +func supportsAudioTranscription(model string) bool { + protocol, _ := providers.ExtractProtocol(model) + + switch protocol { + case "openai", "azure", "azure-openai", + "litellm", "openrouter", "groq", "zhipu", "gemini", "nvidia", + "ollama", "moonshot", "shengsuanyun", "deepseek", "cerebras", + "vivgrid", "volcengine", "vllm", "qwen", "qwen-intl", "qwen-international", "dashscope-intl", + "qwen-us", "dashscope-us", "mistral", "avian", "minimax", "longcat", "modelscope", "novita", + "coding-plan", "alibaba-coding", "qwen-coding": + // These protocols all go through the OpenAI-compatible or Azure provider path in + // providers.CreateProviderFromConfig, so they are the only ones that can supply + // the audio media payload shape expected by NewAudioModelTranscriber. + + // TODO: Further restrict this by modelID, since not every model under these + // protocols supports audio transcription. + return true + default: + return false + } +} + // DetectTranscriber inspects cfg and returns the appropriate Transcriber, or // nil if no supported transcription provider is configured. func DetectTranscriber(cfg *config.Config) Transcriber { @@ -26,7 +49,9 @@ func DetectTranscriber(cfg *config.Config) Transcriber { if err != nil { return nil } - return NewAudioModelTranscriber(modelCfg) + if supportsAudioTranscription(modelCfg.Model) { + return NewAudioModelTranscriber(modelCfg) + } } // Direct Groq provider config takes priority. diff --git a/pkg/voice/transcriber_test.go b/pkg/voice/transcriber_test.go index 753ee5e78..1b20bf9f2 100644 --- a/pkg/voice/transcriber_test.go +++ b/pkg/voice/transcriber_test.go @@ -57,6 +57,31 @@ func TestDetectTranscriber(t *testing.T) { }, wantName: "audio-model", }, + { + name: "voice model name selects azure audio model transcriber", + cfg: &config.Config{ + Voice: config.VoiceConfig{ModelName: "voice-azure-audio"}, + ModelList: []config.ModelConfig{ + { + ModelName: "voice-azure-audio", + Model: "azure/my-audio-deployment", + APIKey: "sk-azure", + APIBase: "https://example.openai.azure.com", + }, + }, + }, + wantName: "audio-model", + }, + { + name: "voice model name with non openai compatible protocol does not select audio model transcriber", + cfg: &config.Config{ + Voice: config.VoiceConfig{ModelName: "voice-anthropic"}, + ModelList: []config.ModelConfig{ + {ModelName: "voice-anthropic", Model: "anthropic/claude-sonnet-4.6", APIKey: "sk-anthropic"}, + }, + }, + wantNil: true, + }, { name: "groq model list entry without key is skipped", cfg: &config.Config{