mirror of
https://github.com/sipeed/picoclaw.git
synced 2026-06-12 18:08:54 +00:00
refactor(voice): share audio format support and restrict transcriber selection
This commit is contained in:
+16
-1
@@ -1,6 +1,7 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
@@ -15,9 +16,23 @@ import (
|
||||
"github.com/sipeed/picoclaw/pkg/media"
|
||||
)
|
||||
|
||||
// audioExtensions lists the lowercase, dot-prefixed file extensions that
// AudioFormat accepts.
var audioExtensions = []string{".mp3", ".wav", ".ogg", ".m4a", ".flac", ".aac", ".wma"}

// AudioFormat returns the audio format name for path, which is the file
// extension lowercased and stripped of its leading dot. An error is returned
// when the extension does not appear in audioExtensions.
func AudioFormat(path string) (string, error) {
	ext := strings.ToLower(filepath.Ext(path))

	for i := range audioExtensions {
		if audioExtensions[i] != ext {
			continue
		}
		return strings.TrimPrefix(ext, "."), nil
	}

	return "", fmt.Errorf("unsupported audio format for %q", path)
}
|
||||
|
||||
// IsAudioFile checks if a file is an audio file based on its filename extension and content type.
|
||||
func IsAudioFile(filename, contentType string) bool {
|
||||
audioExtensions := []string{".mp3", ".wav", ".ogg", ".m4a", ".flac", ".aac", ".wma"}
|
||||
audioTypes := []string{"audio/", "application/ogg", "application/x-ogg"}
|
||||
|
||||
for _, ext := range audioExtensions {
|
||||
|
||||
@@ -5,7 +5,6 @@ import (
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
@@ -24,25 +23,6 @@ const (
|
||||
defaultTranscriptionPrompt = "Transcribe this audio."
|
||||
)
|
||||
|
||||
// audioFormat derives an audio format name from the extension of path. A
// leading "file://" is dropped before the extension is taken, and matching is
// case-insensitive. An error naming the original path is returned for any
// extension outside the table below.
func audioFormat(path string) (string, error) {
	localPath := strings.TrimPrefix(path, "file://")
	ext := strings.ToLower(filepath.Ext(localPath))

	// Both ".aiff" and ".aif" are reported as "aiff"; every other entry maps
	// to its extension without the dot.
	formats := map[string]string{
		".wav":  "wav",
		".mp3":  "mp3",
		".aiff": "aiff",
		".aif":  "aiff",
		".aac":  "aac",
		".ogg":  "ogg",
		".flac": "flac",
	}

	if name, ok := formats[ext]; ok {
		return name, nil
	}
	return "", fmt.Errorf("unsupported audio format for %q", path)
}
|
||||
|
||||
func NewAudioModelTranscriber(modelCfg *config.ModelConfig) *AudioModelTranscriber {
|
||||
if modelCfg == nil {
|
||||
return nil
|
||||
@@ -79,7 +59,7 @@ func (t *AudioModelTranscriber) Transcribe(ctx context.Context, audioFilePath st
|
||||
return nil, fmt.Errorf("failed to read audio file: %w", err)
|
||||
}
|
||||
|
||||
format, err := audioFormat(audioFilePath)
|
||||
format, err := utils.AudioFormat(audioFilePath)
|
||||
if err != nil {
|
||||
logger.ErrorCF("voice", "Failed to detect audio format", map[string]any{"path": audioFilePath, "error": err})
|
||||
return nil, err
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"strings"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
"github.com/sipeed/picoclaw/pkg/providers"
|
||||
)
|
||||
|
||||
type Transcriber interface {
|
||||
@@ -18,6 +19,28 @@ type TranscriptionResponse struct {
|
||||
Duration float64 `json:"duration,omitempty"`
|
||||
}
|
||||
|
||||
func supportsAudioTranscription(model string) bool {
|
||||
protocol, _ := providers.ExtractProtocol(model)
|
||||
|
||||
switch protocol {
|
||||
case "openai", "azure", "azure-openai",
|
||||
"litellm", "openrouter", "groq", "zhipu", "gemini", "nvidia",
|
||||
"ollama", "moonshot", "shengsuanyun", "deepseek", "cerebras",
|
||||
"vivgrid", "volcengine", "vllm", "qwen", "qwen-intl", "qwen-international", "dashscope-intl",
|
||||
"qwen-us", "dashscope-us", "mistral", "avian", "minimax", "longcat", "modelscope", "novita",
|
||||
"coding-plan", "alibaba-coding", "qwen-coding":
|
||||
// These protocols all go through the OpenAI-compatible or Azure provider path in
|
||||
// providers.CreateProviderFromConfig, so they are the only ones that can supply
|
||||
// the audio media payload shape expected by NewAudioModelTranscriber.
|
||||
|
||||
// TODO: Further restrict this by modelID, since not every model under these
|
||||
// protocols supports audio transcription.
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// DetectTranscriber inspects cfg and returns the appropriate Transcriber, or
|
||||
// nil if no supported transcription provider is configured.
|
||||
func DetectTranscriber(cfg *config.Config) Transcriber {
|
||||
@@ -26,7 +49,9 @@ func DetectTranscriber(cfg *config.Config) Transcriber {
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
return NewAudioModelTranscriber(modelCfg)
|
||||
if supportsAudioTranscription(modelCfg.Model) {
|
||||
return NewAudioModelTranscriber(modelCfg)
|
||||
}
|
||||
}
|
||||
|
||||
// Direct Groq provider config takes priority.
|
||||
|
||||
@@ -57,6 +57,31 @@ func TestDetectTranscriber(t *testing.T) {
|
||||
},
|
||||
wantName: "audio-model",
|
||||
},
|
||||
{
|
||||
name: "voice model name selects azure audio model transcriber",
|
||||
cfg: &config.Config{
|
||||
Voice: config.VoiceConfig{ModelName: "voice-azure-audio"},
|
||||
ModelList: []config.ModelConfig{
|
||||
{
|
||||
ModelName: "voice-azure-audio",
|
||||
Model: "azure/my-audio-deployment",
|
||||
APIKey: "sk-azure",
|
||||
APIBase: "https://example.openai.azure.com",
|
||||
},
|
||||
},
|
||||
},
|
||||
wantName: "audio-model",
|
||||
},
|
||||
{
|
||||
name: "voice model name with non openai compatible protocol does not select audio model transcriber",
|
||||
cfg: &config.Config{
|
||||
Voice: config.VoiceConfig{ModelName: "voice-anthropic"},
|
||||
ModelList: []config.ModelConfig{
|
||||
{ModelName: "voice-anthropic", Model: "anthropic/claude-sonnet-4.6", APIKey: "sk-anthropic"},
|
||||
},
|
||||
},
|
||||
wantNil: true,
|
||||
},
|
||||
{
|
||||
name: "groq model list entry without key is skipped",
|
||||
cfg: &config.Config{
|
||||
|
||||
Reference in New Issue
Block a user