From 8ad4b9b497adca6ca150ab858975df8ce1d402b8 Mon Sep 17 00:00:00 2001 From: RussellLuo Date: Sun, 22 Mar 2026 19:51:32 +0800 Subject: [PATCH] feat(voice): add audio-model transcription support - Add `AudioModelTranscriber` for model-based audio transcription via LLM providers - Support selecting a transcription model with `voice.model_name` in config - Keep Groq transcription as a fallback and move it into dedicated files with focused tests - Serialize `data:audio/...` media as input_audio for OpenAI-compatible providers - Improve transcription logging by rendering error fields as strings - Add coverage for transcriber detection, audio-model behavior, provider audio serialization, and Groq transcription Fixes #1890. --- pkg/config/config.go | 3 +- pkg/config/defaults.go | 1 + pkg/logger/logger.go | 2 + pkg/logger/logger_test.go | 28 +++ pkg/providers/common/common.go | 31 ++++ pkg/providers/common/common_test.go | 38 ++++ pkg/voice/audio_model_transcriber.go | 115 ++++++++++++ pkg/voice/audio_model_transcriber_test.go | 203 ++++++++++++++++++++++ pkg/voice/groq_transcriber.go | 151 ++++++++++++++++ pkg/voice/groq_transcriber_test.go | 84 +++++++++ pkg/voice/transcriber.go | 153 +--------------- pkg/voice/transcriber_test.go | 110 ++++-------- 12 files changed, 693 insertions(+), 226 deletions(-) create mode 100644 pkg/voice/audio_model_transcriber.go create mode 100644 pkg/voice/audio_model_transcriber_test.go create mode 100644 pkg/voice/groq_transcriber.go create mode 100644 pkg/voice/groq_transcriber_test.go diff --git a/pkg/config/config.go b/pkg/config/config.go index eab770991..24fb819e6 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -564,7 +564,8 @@ type DevicesConfig struct { } type VoiceConfig struct { - EchoTranscription bool `json:"echo_transcription" env:"PICOCLAW_VOICE_ECHO_TRANSCRIPTION"` + ModelName string `json:"model_name,omitempty" env:"PICOCLAW_VOICE_MODEL_NAME"` + EchoTranscription bool `json:"echo_transcription" env:"PICOCLAW_VOICE_ECHO_TRANSCRIPTION"` } type ProvidersConfig struct { diff --git a/pkg/config/defaults.go b/pkg/config/defaults.go index f4056eca6..a496c96cc 100644 --- a/pkg/config/defaults.go +++ b/pkg/config/defaults.go @@ -567,6 +567,7 @@ func DefaultConfig() *Config { MonitorUSB: true, }, Voice: VoiceConfig{ + ModelName: "", EchoTranscription: false, }, BuildInfo: BuildInfo{ diff --git a/pkg/logger/logger.go b/pkg/logger/logger.go index 179804607..eeb1436de 100644 --- a/pkg/logger/logger.go +++ b/pkg/logger/logger.go @@ -256,6 +256,8 @@ func appendFields(event *zerolog.Event, fields map[string]any) { for k, v := range fields { // Type switch to avoid double JSON serialization of strings switch val := v.(type) { + case error: + event.Str(k, val.Error()) case string: event.Str(k, val) case int: diff --git a/pkg/logger/logger_test.go b/pkg/logger/logger_test.go index e551db58e..6ad3a8dd6 100644 --- a/pkg/logger/logger_test.go +++ b/pkg/logger/logger_test.go @@ -1,7 +1,12 @@ package logger import ( + "bytes" + "encoding/json" + "errors" "testing" + + "github.com/rs/zerolog" ) func TestLogLevelFiltering(t *testing.T) { @@ -337,3 +342,26 @@ func TestSetLevelFromString(t *testing.T) { t.Errorf("after SetLevelFromString(\"FATAL\"): GetLevel() = %v, want FATAL", got) } } + +func TestAppendFields_ErrorUsesErrorString(t *testing.T) { + var buf bytes.Buffer + l := zerolog.New(&buf) + + event := l.Info() + appendFields(event, map[string]any{"error": errors.New("transcription request failed")}) + event.Msg("test") + + lines := bytes.Split(bytes.TrimSpace(buf.Bytes()), []byte("\n")) + if len(lines) == 0 { + t.Fatal("expected log output, got none") + } + + var got map[string]any + if err := json.Unmarshal(lines[0], &got); err != nil { + t.Fatalf("unmarshal log line: %v", err) + } + + if got["error"] != "transcription request failed" { + t.Fatalf("error field = %#v, want %q", got["error"], "transcription request failed") + } +} diff --git a/pkg/providers/common/common.go b/pkg/providers/common/common.go index 23680a1bf..45a29e647 100644 --- a/pkg/providers/common/common.go +++ b/pkg/providers/common/common.go @@ -111,6 +111,17 @@ func SerializeMessages(messages []Message) []any { "url": mediaURL, }, }) + continue + } + + if format, data, ok := parseDataAudioURL(mediaURL); ok { + parts = append(parts, map[string]any{ + "type": "input_audio", + "input_audio": map[string]any{ + "data": data, + "format": format, + }, + }) } } @@ -132,6 +143,26 @@ func SerializeMessages(messages []Message) []any { return out } +func parseDataAudioURL(mediaURL string) (format, data string, ok bool) { + if !strings.HasPrefix(mediaURL, "data:audio/") { + return "", "", false + } + + payload := strings.TrimPrefix(mediaURL, "data:audio/") + meta, data, found := strings.Cut(payload, ",") + if !found { + return "", "", false + } + + format, _, _ = strings.Cut(meta, ";") + format = strings.TrimSpace(format) + data = strings.TrimSpace(data) + if format == "" || data == "" { + return "", "", false + } + return format, data, true +} + // --- Response parsing --- // ParseResponse parses a JSON chat completion response body into an LLMResponse. diff --git a/pkg/providers/common/common_test.go b/pkg/providers/common/common_test.go index bb7e7434d..79a637d48 100644 --- a/pkg/providers/common/common_test.go +++ b/pkg/providers/common/common_test.go @@ -91,6 +91,44 @@ func TestSerializeMessages_WithMedia(t *testing.T) { } } +func TestSerializeMessages_WithAudioMedia(t *testing.T) { + messages := []Message{ + {Role: "user", Content: "transcribe this", Media: []string{"data:audio/ogg;base64,abc123"}}, + } + result := SerializeMessages(messages) + + data, _ := json.Marshal(result) + var msgs []map[string]any + json.Unmarshal(data, &msgs) + + content, ok := msgs[0]["content"].([]any) + if !ok { + t.Fatalf("expected array content for media message, got %T", msgs[0]["content"]) + } + if len(content) != 2 { + t.Fatalf("expected 2 content parts, got %d", len(content)) + } + + audioPart, ok := content[1].(map[string]any) + if !ok { + t.Fatalf("expected audio content part to be an object, got %T", content[1]) + } + if audioPart["type"] != "input_audio" { + t.Fatalf("audio part type = %v, want input_audio", audioPart["type"]) + } + + inputAudio, ok := audioPart["input_audio"].(map[string]any) + if !ok { + t.Fatalf("expected input_audio object, got %T", audioPart["input_audio"]) + } + if inputAudio["format"] != "ogg" { + t.Fatalf("audio format = %v, want ogg", inputAudio["format"]) + } + if inputAudio["data"] != "abc123" { + t.Fatalf("audio data = %v, want abc123", inputAudio["data"]) + } +} + func TestSerializeMessages_MediaWithToolCallID(t *testing.T) { messages := []Message{ {Role: "tool", Content: "result", Media: []string{"data:image/png;base64,xyz"}, ToolCallID: "call_1"}, diff --git a/pkg/voice/audio_model_transcriber.go b/pkg/voice/audio_model_transcriber.go new file mode 100644 index 000000000..096e832fa --- /dev/null +++ b/pkg/voice/audio_model_transcriber.go @@ -0,0 +1,115 @@ +package voice + +import ( + "context" + "encoding/base64" + "fmt" + "os" + "path/filepath" + "strings" + + "github.com/sipeed/picoclaw/pkg/config" + "github.com/sipeed/picoclaw/pkg/logger" + "github.com/sipeed/picoclaw/pkg/providers" + "github.com/sipeed/picoclaw/pkg/utils" +) + +type AudioModelTranscriber struct { + provider providers.LLMProvider + modelID string + prompt string +} + +const ( + defaultTranscriptionPrompt = "Transcribe this audio." +) + +func audioFormat(path string) (string, error) { + switch strings.ToLower(filepath.Ext(strings.TrimPrefix(path, "file://"))) { + case ".wav": + return "wav", nil + case ".mp3": + return "mp3", nil + case ".aiff", ".aif": + return "aiff", nil + case ".aac": + return "aac", nil + case ".ogg": + return "ogg", nil + case ".flac": + return "flac", nil + default: + return "", fmt.Errorf("unsupported audio format for %q", path) + } +} + +func NewAudioModelTranscriber(modelCfg *config.ModelConfig) *AudioModelTranscriber { + if modelCfg == nil { + return nil + } + + logger.DebugCF("voice", "Creating audio model transcriber", map[string]any{ + "has_api_key": modelCfg.APIKey != "", + "api_base": modelCfg.APIBase, + "model": modelCfg.Model, + }) + + provider, modelID, err := providers.CreateProviderFromConfig(modelCfg) + if err != nil { + logger.ErrorCF("voice", "Failed to create audio model provider", map[string]any{"error": err}) + return nil + } + + return &AudioModelTranscriber{ + provider: provider, + modelID: modelID, + prompt: defaultTranscriptionPrompt, + } +} + +func (t *AudioModelTranscriber) Transcribe(ctx context.Context, audioFilePath string) (*TranscriptionResponse, error) { + logger.InfoCF("voice", "Starting audio model transcription", map[string]any{ + "audio_file": audioFilePath, + "model": t.modelID, + }) + + audioBytes, err := os.ReadFile(audioFilePath) + if err != nil { + logger.ErrorCF("voice", "Failed to read audio file", map[string]any{"path": audioFilePath, "error": err}) + return nil, fmt.Errorf("failed to read audio file: %w", err) + } + + format, err := audioFormat(audioFilePath) + if err != nil { + logger.ErrorCF("voice", "Failed to detect audio format", map[string]any{"path": audioFilePath, "error": err}) + return nil, err + } + + resp, err := t.provider.Chat(ctx, []providers.Message{ + { + Role: "user", + Content: t.prompt, + Media: []string{ + fmt.Sprintf("data:audio/%s;base64,%s", format, base64.StdEncoding.EncodeToString(audioBytes)), + }, + }, + }, nil, t.modelID, map[string]any{ + "temperature": 0, + }) + if err != nil { + logger.ErrorCF("voice", "Audio model transcription request failed", map[string]any{"error": err}) + return nil, fmt.Errorf("transcription request failed: %w", err) + } + + text := strings.TrimSpace(resp.Content) + logger.InfoCF("voice", "Audio model transcription completed successfully", map[string]any{ + "text_length": len(text), + "transcription_preview": utils.Truncate(text, 50), + }) + + return &TranscriptionResponse{Text: text}, nil +} + +func (t *AudioModelTranscriber) Name() string { + return "audio-model" +} diff --git a/pkg/voice/audio_model_transcriber_test.go b/pkg/voice/audio_model_transcriber_test.go new file mode 100644 index 000000000..c33e3bf97 --- /dev/null +++ b/pkg/voice/audio_model_transcriber_test.go @@ -0,0 +1,203 @@ +package voice + +import ( + "context" + "encoding/base64" + "errors" + "os" + "path/filepath" + "testing" + + "github.com/sipeed/picoclaw/pkg/config" + "github.com/sipeed/picoclaw/pkg/providers" +) + +var _ Transcriber = (*AudioModelTranscriber)(nil) + +type fakeLLMProvider struct { + chatFunc func( + ctx context.Context, + messages []providers.Message, + tools []providers.ToolDefinition, + model string, + options map[string]any, + ) (*providers.LLMResponse, error) +} + +func (p *fakeLLMProvider) Chat( + ctx context.Context, + messages []providers.Message, + tools []providers.ToolDefinition, + model string, + options map[string]any, +) (*providers.LLMResponse, error) { + if p.chatFunc == nil { + return nil, nil + } + return p.chatFunc(ctx, messages, tools, model, options) +} + +func (p *fakeLLMProvider) GetDefaultModel() string { + return "" +} + +func TestAudioModelTranscriberName(t *testing.T) { + tr := &AudioModelTranscriber{} + if got := tr.Name(); got != "audio-model" { + t.Errorf("Name() = %q, want %q", got, "audio-model") + } +} + +func TestNewAudioModelTranscriberInvalidConfig(t *testing.T) { + tests := []struct { + name string + cfg *config.ModelConfig + }{ + { + name: "nil config", + cfg: nil, + }, + { + name: "missing api key", + cfg: &config.ModelConfig{ + Model: "gemini/gemini-2.5-flash", + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if tr := NewAudioModelTranscriber(tt.cfg); tr != nil { + t.Fatalf("NewAudioModelTranscriber() = %#v, want nil", tr) + } + }) + } +} + +func TestAudioModelTranscriberTranscribe(t *testing.T) { + tmpDir := t.TempDir() + audioPath := filepath.Join(tmpDir, "clip.ogg") + audioData := []byte("fake-audio-data") + if err := os.WriteFile(audioPath, audioData, 0o644); err != nil { + t.Fatalf("failed to write fake audio file: %v", err) + } + + t.Run("success", func(t *testing.T) { + tr := &AudioModelTranscriber{ + provider: &fakeLLMProvider{ + chatFunc: func( + ctx context.Context, + messages []providers.Message, + tools []providers.ToolDefinition, + model string, + options map[string]any, + ) (*providers.LLMResponse, error) { + if ctx == nil { + t.Fatal("context should not be nil") + } + if tools != nil { + t.Fatalf("tools = %#v, want nil", tools) + } + if model != "gemini-2.5-flash" { + t.Fatalf("model = %q, want %q", model, "gemini-2.5-flash") + } + if len(messages) != 1 { + t.Fatalf("len(messages) = %d, want 1", len(messages)) + } + msg := messages[0] + if msg.Role != "user" { + t.Fatalf("role = %q, want %q", msg.Role, "user") + } + if msg.Content != defaultTranscriptionPrompt { + t.Fatalf("prompt = %q, want %q", msg.Content, defaultTranscriptionPrompt) + } + if len(msg.Media) != 1 { + t.Fatalf("len(media) = %d, want 1", len(msg.Media)) + } + wantMedia := "data:audio/ogg;base64," + base64.StdEncoding.EncodeToString(audioData) + if msg.Media[0] != wantMedia { + t.Fatalf("media = %q, want %q", msg.Media[0], wantMedia) + } + if len(options) != 1 { + t.Fatalf("options = %#v, want only temperature", options) + } + if got := options["temperature"]; got != 0 { + t.Fatalf("temperature = %#v, want 0", got) + } + + return &providers.LLMResponse{Content: " hello from gemini \n"}, nil + }, + }, + modelID: "gemini-2.5-flash", + prompt: defaultTranscriptionPrompt, + } + + resp, err := tr.Transcribe(context.Background(), audioPath) + if err != nil { + t.Fatalf("Transcribe() error: %v", err) + } + if resp.Text != "hello from gemini" { + t.Fatalf("Text = %q, want %q", resp.Text, "hello from gemini") + } + }) + + t.Run("provider error", func(t *testing.T) { + tr := &AudioModelTranscriber{ + provider: &fakeLLMProvider{ + chatFunc: func( + ctx context.Context, + messages []providers.Message, + tools []providers.ToolDefinition, + model string, + options map[string]any, + ) (*providers.LLMResponse, error) { + return nil, errors.New("upstream failure") + }, + }, + modelID: "gemini-2.5-flash", + prompt: defaultTranscriptionPrompt, + } + + _, err := tr.Transcribe(context.Background(), audioPath) + if err == nil { + t.Fatal("expected error for provider failure, got nil") + } + if got := err.Error(); got != "transcription request failed: upstream failure" { + t.Fatalf("error = %q, want %q", got, "transcription request failed: upstream failure") + } + }) + + t.Run("missing file", func(t *testing.T) { + tr := &AudioModelTranscriber{ + provider: &fakeLLMProvider{}, + modelID: "gemini-2.5-flash", + prompt: defaultTranscriptionPrompt, + } + + _, err := tr.Transcribe(context.Background(), filepath.Join(tmpDir, "nonexistent.ogg")) + if err == nil { + t.Fatal("expected error for missing file, got nil") + } + }) + + t.Run("unsupported audio format", func(t *testing.T) { + badPath := filepath.Join(tmpDir, "clip.txt") + if err := os.WriteFile(badPath, []byte("not-audio"), 0o644); err != nil { + t.Fatalf("failed to write fake file: %v", err) + } + + tr := &AudioModelTranscriber{ + provider: &fakeLLMProvider{}, + modelID: "gemini-2.5-flash", + prompt: defaultTranscriptionPrompt, + } + + _, err := tr.Transcribe(context.Background(), badPath) + if err == nil { + t.Fatal("expected error for unsupported audio format, got nil") + } + if got := err.Error(); got != `unsupported audio format for "`+badPath+`"` { + t.Fatalf("error = %q, want unsupported format error", got) + } + }) +} diff --git a/pkg/voice/groq_transcriber.go b/pkg/voice/groq_transcriber.go new file mode 100644 index 000000000..b42e598f7 --- /dev/null +++ b/pkg/voice/groq_transcriber.go @@ -0,0 +1,151 @@ +package voice + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "mime/multipart" + "net/http" + "os" + "path/filepath" + "time" + + "github.com/sipeed/picoclaw/pkg/logger" + "github.com/sipeed/picoclaw/pkg/utils" +) + +type GroqTranscriber struct { + apiKey string + apiBase string + httpClient *http.Client +} + +func NewGroqTranscriber(apiKey string) *GroqTranscriber { + logger.DebugCF("voice", "Creating Groq transcriber", map[string]any{"has_api_key": apiKey != ""}) + + apiBase := "https://api.groq.com/openai/v1" + return &GroqTranscriber{ + apiKey: apiKey, + apiBase: apiBase, + httpClient: &http.Client{ + Timeout: 60 * time.Second, + }, + } +} + +func (t *GroqTranscriber) Transcribe(ctx context.Context, audioFilePath string) (*TranscriptionResponse, error) { + logger.InfoCF("voice", "Starting transcription", map[string]any{"audio_file": audioFilePath}) + + audioFile, err := os.Open(audioFilePath) + if err != nil { + logger.ErrorCF("voice", "Failed to open audio file", map[string]any{"path": audioFilePath, "error": err}) + return nil, fmt.Errorf("failed to open audio file: %w", err) + } + defer audioFile.Close() + + fileInfo, err := audioFile.Stat() + if err != nil { + logger.ErrorCF("voice", "Failed to get file info", map[string]any{"path": audioFilePath, "error": err}) + return nil, fmt.Errorf("failed to get file info: %w", err) + } + + logger.DebugCF("voice", "Audio file details", map[string]any{ + "size_bytes": fileInfo.Size(), + "file_name": filepath.Base(audioFilePath), + }) + + var requestBody bytes.Buffer + writer := multipart.NewWriter(&requestBody) + + part, err := writer.CreateFormFile("file", filepath.Base(audioFilePath)) + if err != nil { + logger.ErrorCF("voice", "Failed to create form file", map[string]any{"error": err}) + return nil, fmt.Errorf("failed to create form file: %w", err) + } + + copied, err := io.Copy(part, audioFile) + if err != nil { + logger.ErrorCF("voice", "Failed to copy file content", map[string]any{"error": err}) + return nil, fmt.Errorf("failed to copy file content: %w", err) + } + + logger.DebugCF("voice", "File copied to request", map[string]any{"bytes_copied": copied}) + + if err = writer.WriteField("model", "whisper-large-v3"); err != nil { + logger.ErrorCF("voice", "Failed to write model field", map[string]any{"error": err}) + return nil, fmt.Errorf("failed to write model field: %w", err) + } + + if err = writer.WriteField("response_format", "json"); err != nil { + logger.ErrorCF("voice", "Failed to write response_format field", map[string]any{"error": err}) + return nil, fmt.Errorf("failed to write response_format field: %w", err) + } + + if err = writer.Close(); err != nil { + logger.ErrorCF("voice", "Failed to close multipart writer", map[string]any{"error": err}) + return nil, fmt.Errorf("failed to close multipart writer: %w", err) + } + + url := t.apiBase + "/audio/transcriptions" + req, err := http.NewRequestWithContext(ctx, "POST", url, &requestBody) + if err != nil { + logger.ErrorCF("voice", "Failed to create request", map[string]any{"error": err}) + return nil, fmt.Errorf("failed to create request: %w", err) + } + + req.Header.Set("Content-Type", writer.FormDataContentType()) + req.Header.Set("Authorization", "Bearer "+t.apiKey) + + logger.DebugCF("voice", "Sending transcription request to Groq API", map[string]any{ + "url": url, + "request_size_bytes": requestBody.Len(), + "file_size_bytes": fileInfo.Size(), + }) + + resp, err := t.httpClient.Do(req) + if err != nil { + logger.ErrorCF("voice", "Failed to send request", map[string]any{"error": err}) + return nil, fmt.Errorf("failed to send request: %w", err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + logger.ErrorCF("voice", "Failed to read response", map[string]any{"error": err}) + return nil, fmt.Errorf("failed to read response: %w", err) + } + + if resp.StatusCode != http.StatusOK { + logger.ErrorCF("voice", "API error", map[string]any{ + "status_code": resp.StatusCode, + "response": string(body), + }) + return nil, fmt.Errorf("API error (status %d): %s", resp.StatusCode, string(body)) + } + + logger.DebugCF("voice", "Received response from Groq API", map[string]any{ + "status_code": resp.StatusCode, + "response_size_bytes": len(body), + }) + + var result TranscriptionResponse + if err := json.Unmarshal(body, &result); err != nil { + logger.ErrorCF("voice", "Failed to unmarshal response", map[string]any{"error": err}) + return nil, fmt.Errorf("failed to unmarshal response: %w", err) + } + + logger.InfoCF("voice", "Transcription completed successfully", map[string]any{ + "text_length": len(result.Text), + "language": result.Language, + "duration_seconds": result.Duration, + "transcription_preview": utils.Truncate(result.Text, 50), + }) + + return &result, nil +} + +func (t *GroqTranscriber) Name() string { + return "groq" +} diff --git a/pkg/voice/groq_transcriber_test.go b/pkg/voice/groq_transcriber_test.go new file mode 100644 index 000000000..fdcaa7580 --- /dev/null +++ b/pkg/voice/groq_transcriber_test.go @@ -0,0 +1,84 @@ +package voice + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "os" + "path/filepath" + "testing" +) + +var _ Transcriber = (*GroqTranscriber)(nil) + +func TestGroqTranscriberName(t *testing.T) { + tr := NewGroqTranscriber("sk-test") + if got := tr.Name(); got != "groq" { + t.Errorf("Name() = %q, want %q", got, "groq") + } +} + +func TestGroqTranscribe(t *testing.T) { + // Write a minimal fake audio file so the transcriber can open and send it. + tmpDir := t.TempDir() + audioPath := filepath.Join(tmpDir, "clip.ogg") + if err := os.WriteFile(audioPath, []byte("fake-audio-data"), 0o644); err != nil { + t.Fatalf("failed to write fake audio file: %v", err) + } + + t.Run("success", func(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/audio/transcriptions" { + t.Errorf("unexpected path: %s", r.URL.Path) + } + if r.Header.Get("Authorization") != "Bearer sk-test" { + t.Errorf("unexpected Authorization header: %s", r.Header.Get("Authorization")) + } + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(TranscriptionResponse{ + Text: "hello world", + Language: "en", + Duration: 1.5, + }) + })) + defer srv.Close() + + tr := NewGroqTranscriber("sk-test") + tr.apiBase = srv.URL + + resp, err := tr.Transcribe(context.Background(), audioPath) + if err != nil { + t.Fatalf("Transcribe() error: %v", err) + } + if resp.Text != "hello world" { + t.Errorf("Text = %q, want %q", resp.Text, "hello world") + } + if resp.Language != "en" { + t.Errorf("Language = %q, want %q", resp.Language, "en") + } + }) + + t.Run("api error", func(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + http.Error(w, `{"error":"invalid_api_key"}`, http.StatusUnauthorized) + })) + defer srv.Close() + + tr := NewGroqTranscriber("sk-bad") + tr.apiBase = srv.URL + + _, err := tr.Transcribe(context.Background(), audioPath) + if err == nil { + t.Fatal("expected error for non-200 response, got nil") + } + }) + + t.Run("missing file", func(t *testing.T) { + tr := NewGroqTranscriber("sk-test") + _, err := tr.Transcribe(context.Background(), filepath.Join(tmpDir, "nonexistent.ogg")) + if err == nil { + t.Fatal("expected error for missing file, got nil") + } + }) +} diff --git a/pkg/voice/transcriber.go b/pkg/voice/transcriber.go index e949d7a22..36ee92881 100644 --- a/pkg/voice/transcriber.go +++ b/pkg/voice/transcriber.go @@ -1,21 +1,10 @@ package voice import ( - "bytes" "context" - "encoding/json" - "fmt" - "io" - "mime/multipart" - "net/http" - "os" - "path/filepath" "strings" - "time" "github.com/sipeed/picoclaw/pkg/config" - "github.com/sipeed/picoclaw/pkg/logger" - "github.com/sipeed/picoclaw/pkg/utils" ) type Transcriber interface { @@ -23,149 +12,23 @@ type Transcriber interface { Transcribe(ctx context.Context, audioFilePath string) (*TranscriptionResponse, error) } -type GroqTranscriber struct { - apiKey string - apiBase string - httpClient *http.Client -} - type TranscriptionResponse struct { Text string `json:"text"` Language string `json:"language,omitempty"` Duration float64 `json:"duration,omitempty"` } -func NewGroqTranscriber(apiKey string) *GroqTranscriber { - logger.DebugCF("voice", "Creating Groq transcriber", map[string]any{"has_api_key": apiKey != ""}) - - apiBase := "https://api.groq.com/openai/v1" - return &GroqTranscriber{ - apiKey: apiKey, - apiBase: apiBase, - httpClient: &http.Client{ - Timeout: 60 * time.Second, - }, - } -} - -func (t *GroqTranscriber) Transcribe(ctx context.Context, audioFilePath string) (*TranscriptionResponse, error) { - logger.InfoCF("voice", "Starting transcription", map[string]any{"audio_file": audioFilePath}) - - audioFile, err := os.Open(audioFilePath) - if err != nil { - logger.ErrorCF("voice", "Failed to open audio file", map[string]any{"path": audioFilePath, "error": err}) - return nil, fmt.Errorf("failed to open audio file: %w", err) - } - defer audioFile.Close() - - fileInfo, err := audioFile.Stat() - if err != nil { - logger.ErrorCF("voice", "Failed to get file info", map[string]any{"path": audioFilePath, "error": err}) - return nil, fmt.Errorf("failed to get file info: %w", err) - } - - logger.DebugCF("voice", "Audio file details", map[string]any{ - "size_bytes": fileInfo.Size(), - "file_name": filepath.Base(audioFilePath), - }) - - var requestBody bytes.Buffer - writer := multipart.NewWriter(&requestBody) - - part, err := writer.CreateFormFile("file", filepath.Base(audioFilePath)) - if err != nil { - logger.ErrorCF("voice", "Failed to create form file", map[string]any{"error": err}) - return nil, fmt.Errorf("failed to create form file: %w", err) - } - - copied, err := io.Copy(part, audioFile) - if err != nil { - logger.ErrorCF("voice", "Failed to copy file content", map[string]any{"error": err}) - return nil, fmt.Errorf("failed to copy file content: %w", err) - } - - logger.DebugCF("voice", "File copied to request", map[string]any{"bytes_copied": copied}) - - if err = writer.WriteField("model", "whisper-large-v3"); err != nil { - logger.ErrorCF("voice", "Failed to write model field", map[string]any{"error": err}) - return nil, fmt.Errorf("failed to write model field: %w", err) - } - - if err = writer.WriteField("response_format", "json"); err != nil { - logger.ErrorCF("voice", "Failed to write response_format field", map[string]any{"error": err}) - return nil, fmt.Errorf("failed to write response_format field: %w", err) - } - - if err = writer.Close(); err != nil { - logger.ErrorCF("voice", "Failed to close multipart writer", map[string]any{"error": err}) - return nil, fmt.Errorf("failed to close multipart writer: %w", err) - } - - url := t.apiBase + "/audio/transcriptions" - req, err := http.NewRequestWithContext(ctx, "POST", url, &requestBody) - if err != nil { - logger.ErrorCF("voice", "Failed to create request", map[string]any{"error": err}) - return nil, fmt.Errorf("failed to create request: %w", err) - } - - req.Header.Set("Content-Type", writer.FormDataContentType()) - req.Header.Set("Authorization", "Bearer "+t.apiKey) - - logger.DebugCF("voice", "Sending transcription request to Groq API", map[string]any{ - "url": url, - "request_size_bytes": requestBody.Len(), - "file_size_bytes": fileInfo.Size(), - }) - - resp, err := t.httpClient.Do(req) - if err != nil { - logger.ErrorCF("voice", "Failed to send request", map[string]any{"error": err}) - return nil, fmt.Errorf("failed to send request: %w", err) - } - defer resp.Body.Close() - - body, err := io.ReadAll(resp.Body) - if err != nil { - logger.ErrorCF("voice", "Failed to read response", map[string]any{"error": err}) - return nil, fmt.Errorf("failed to read response: %w", err) - } - - if resp.StatusCode != http.StatusOK { - logger.ErrorCF("voice", "API error", map[string]any{ - "status_code": resp.StatusCode, - "response": string(body), - }) - return nil, fmt.Errorf("API error (status %d): %s", resp.StatusCode, string(body)) - } - - logger.DebugCF("voice", "Received response from Groq API", map[string]any{ - "status_code": resp.StatusCode, - "response_size_bytes": len(body), - }) - - var result TranscriptionResponse - if err := json.Unmarshal(body, &result); err != nil { - logger.ErrorCF("voice", "Failed to unmarshal response", map[string]any{"error": err}) - return nil, fmt.Errorf("failed to unmarshal response: %w", err) - } - - logger.InfoCF("voice", "Transcription completed successfully", map[string]any{ - "text_length": len(result.Text), - "language": result.Language, - "duration_seconds": result.Duration, - "transcription_preview": utils.Truncate(result.Text, 50), - }) - - return &result, nil -} - -func (t *GroqTranscriber) Name() string { - return "groq" -} - // DetectTranscriber inspects cfg and returns the appropriate Transcriber, or // nil if no supported transcription provider is configured. func DetectTranscriber(cfg *config.Config) Transcriber { + if modelName := strings.TrimSpace(cfg.Voice.ModelName); modelName != "" { + modelCfg, err := cfg.GetModelConfig(modelName) + if err != nil { + return nil + } + return NewAudioModelTranscriber(modelCfg) + } + // Direct Groq provider config takes priority. if key := cfg.Providers.Groq.APIKey; key != "" { return NewGroqTranscriber(key) diff --git a/pkg/voice/transcriber_test.go b/pkg/voice/transcriber_test.go index 9b6add333..753ee5e78 100644 --- a/pkg/voice/transcriber_test.go +++ b/pkg/voice/transcriber_test.go @@ -1,27 +1,11 @@ package voice import ( - "context" - "encoding/json" - "net/http" - "net/http/httptest" - "os" - "path/filepath" "testing" "github.com/sipeed/picoclaw/pkg/config" ) -// Ensure GroqTranscriber satisfies the Transcriber interface at compile time. -var _ Transcriber = (*GroqTranscriber)(nil) - -func TestGroqTranscriberName(t *testing.T) { - tr := NewGroqTranscriber("sk-test") - if got := tr.Name(); got != "groq" { - t.Errorf("Name() = %q, want %q", got, "groq") - } -} - func TestDetectTranscriber(t *testing.T) { tests := []struct { name string @@ -43,6 +27,16 @@ func TestDetectTranscriber(t *testing.T) { }, wantName: "groq", }, + { + name: "voice model name selects audio model transcriber", + cfg: &config.Config{ + Voice: config.VoiceConfig{ModelName: "voice-gemini"}, + ModelList: []config.ModelConfig{ + {ModelName: "voice-gemini", Model: "gemini/gemini-2.5-flash", APIKey: "sk-gemini-model"}, + }, + }, + wantName: "audio-model", + }, { name: "groq via model list", cfg: &config.Config{ @@ -53,6 +47,16 @@ func TestDetectTranscriber(t *testing.T) { }, wantName: "groq", }, + { + name: "voice model name selects non-gemini audio model transcriber", + cfg: &config.Config{ + Voice: config.VoiceConfig{ModelName: "voice-openai-audio"}, + ModelList: []config.ModelConfig{ + {ModelName: "voice-openai-audio", Model: "openai/gpt-4o-audio-preview", APIKey: "sk-openai"}, + }, + }, + wantName: "audio-model", + }, { name: "groq model list entry without key is skipped", cfg: &config.Config{ @@ -74,6 +78,16 @@ func TestDetectTranscriber(t *testing.T) { }, wantName: "groq", }, + { + name: "missing voice model name config returns nil", + cfg: &config.Config{ + Voice: config.VoiceConfig{ModelName: "missing"}, + ModelList: []config.ModelConfig{ + {ModelName: "other", Model: "gemini/gemini-2.5-flash", APIKey: "sk-gemini-model"}, + }, + }, + wantNil: true, + }, } for _, tc := range tests { @@ -94,67 +108,3 @@ func TestDetectTranscriber(t *testing.T) { }) } } - -func TestTranscribe(t *testing.T) { - // Write a minimal fake audio file so the transcriber can open and send it. - tmpDir := t.TempDir() - audioPath := filepath.Join(tmpDir, "clip.ogg") - if err := os.WriteFile(audioPath, []byte("fake-audio-data"), 0o644); err != nil { - t.Fatalf("failed to write fake audio file: %v", err) - } - - t.Run("success", func(t *testing.T) { - srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.URL.Path != "/audio/transcriptions" { - t.Errorf("unexpected path: %s", r.URL.Path) - } - if r.Header.Get("Authorization") != "Bearer sk-test" { - t.Errorf("unexpected Authorization header: %s", r.Header.Get("Authorization")) - } - w.Header().Set("Content-Type", "application/json") - _ = json.NewEncoder(w).Encode(TranscriptionResponse{ - Text: "hello world", - Language: "en", - Duration: 1.5, - }) - })) - defer srv.Close() - - tr := NewGroqTranscriber("sk-test") - tr.apiBase = srv.URL - - resp, err := tr.Transcribe(context.Background(), audioPath) - if err != nil { - t.Fatalf("Transcribe() error: %v", err) - } - if resp.Text != "hello world" { - t.Errorf("Text = %q, want %q", resp.Text, "hello world") - } - if resp.Language != "en" { - t.Errorf("Language = %q, want %q", resp.Language, "en") - } - }) - - t.Run("api error", func(t *testing.T) { - srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - http.Error(w, `{"error":"invalid_api_key"}`, http.StatusUnauthorized) - })) - defer srv.Close() - - tr := NewGroqTranscriber("sk-bad") - tr.apiBase = srv.URL - - _, err := tr.Transcribe(context.Background(), audioPath) - if err == nil { - t.Fatal("expected error for non-200 response, got nil") - } - }) - - t.Run("missing file", func(t *testing.T) { - tr := NewGroqTranscriber("sk-test") - _, err := tr.Transcribe(context.Background(), filepath.Join(tmpDir, "nonexistent.ogg")) - if err == nil { - t.Fatal("expected error for missing file, got nil") - } - }) -}