mirror of
https://github.com/sipeed/picoclaw.git
synced 2026-06-12 18:08:54 +00:00
8ad4b9b497
- Add `AudioModelTranscriber` for model-based audio transcription via LLM providers
- Support selecting a transcription model with `voice.model_name` in config
- Keep Groq transcription as a fallback and move it into dedicated files with focused tests
- Serialize `data:audio/...` media as `input_audio` for OpenAI-compatible providers
- Improve transcription logging by rendering error fields as strings
- Add coverage for transcriber detection, audio-model behavior, provider audio serialization, and Groq transcription

Fixes #1890.
204 lines
5.3 KiB
Go
204 lines
5.3 KiB
Go
package voice
|
|
|
|
import (
|
|
"context"
|
|
"encoding/base64"
|
|
"errors"
|
|
"os"
|
|
"path/filepath"
|
|
"testing"
|
|
|
|
"github.com/sipeed/picoclaw/pkg/config"
|
|
"github.com/sipeed/picoclaw/pkg/providers"
|
|
)
|
|
|
|
// Compile-time assertion that *AudioModelTranscriber satisfies Transcriber.
var _ Transcriber = (*AudioModelTranscriber)(nil)
|
|
|
|
type fakeLLMProvider struct {
|
|
chatFunc func(
|
|
ctx context.Context,
|
|
messages []providers.Message,
|
|
tools []providers.ToolDefinition,
|
|
model string,
|
|
options map[string]any,
|
|
) (*providers.LLMResponse, error)
|
|
}
|
|
|
|
func (p *fakeLLMProvider) Chat(
|
|
ctx context.Context,
|
|
messages []providers.Message,
|
|
tools []providers.ToolDefinition,
|
|
model string,
|
|
options map[string]any,
|
|
) (*providers.LLMResponse, error) {
|
|
if p.chatFunc == nil {
|
|
return nil, nil
|
|
}
|
|
return p.chatFunc(ctx, messages, tools, model, options)
|
|
}
|
|
|
|
func (p *fakeLLMProvider) GetDefaultModel() string {
|
|
return ""
|
|
}
|
|
|
|
func TestAudioModelTranscriberName(t *testing.T) {
|
|
tr := &AudioModelTranscriber{}
|
|
if got := tr.Name(); got != "audio-model" {
|
|
t.Errorf("Name() = %q, want %q", got, "audio-model")
|
|
}
|
|
}
|
|
|
|
func TestNewAudioModelTranscriberInvalidConfig(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
cfg *config.ModelConfig
|
|
}{
|
|
{
|
|
name: "nil config",
|
|
cfg: nil,
|
|
},
|
|
{
|
|
name: "missing api key",
|
|
cfg: &config.ModelConfig{
|
|
Model: "gemini/gemini-2.5-flash",
|
|
},
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
if tr := NewAudioModelTranscriber(tt.cfg); tr != nil {
|
|
t.Fatalf("NewAudioModelTranscriber() = %#v, want nil", tr)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestAudioModelTranscriberTranscribe(t *testing.T) {
|
|
tmpDir := t.TempDir()
|
|
audioPath := filepath.Join(tmpDir, "clip.ogg")
|
|
audioData := []byte("fake-audio-data")
|
|
if err := os.WriteFile(audioPath, audioData, 0o644); err != nil {
|
|
t.Fatalf("failed to write fake audio file: %v", err)
|
|
}
|
|
|
|
t.Run("success", func(t *testing.T) {
|
|
tr := &AudioModelTranscriber{
|
|
provider: &fakeLLMProvider{
|
|
chatFunc: func(
|
|
ctx context.Context,
|
|
messages []providers.Message,
|
|
tools []providers.ToolDefinition,
|
|
model string,
|
|
options map[string]any,
|
|
) (*providers.LLMResponse, error) {
|
|
if ctx == nil {
|
|
t.Fatal("context should not be nil")
|
|
}
|
|
if tools != nil {
|
|
t.Fatalf("tools = %#v, want nil", tools)
|
|
}
|
|
if model != "gemini-2.5-flash" {
|
|
t.Fatalf("model = %q, want %q", model, "gemini-2.5-flash")
|
|
}
|
|
if len(messages) != 1 {
|
|
t.Fatalf("len(messages) = %d, want 1", len(messages))
|
|
}
|
|
msg := messages[0]
|
|
if msg.Role != "user" {
|
|
t.Fatalf("role = %q, want %q", msg.Role, "user")
|
|
}
|
|
if msg.Content != defaultTranscriptionPrompt {
|
|
t.Fatalf("prompt = %q, want %q", msg.Content, defaultTranscriptionPrompt)
|
|
}
|
|
if len(msg.Media) != 1 {
|
|
t.Fatalf("len(media) = %d, want 1", len(msg.Media))
|
|
}
|
|
wantMedia := "data:audio/ogg;base64," + base64.StdEncoding.EncodeToString(audioData)
|
|
if msg.Media[0] != wantMedia {
|
|
t.Fatalf("media = %q, want %q", msg.Media[0], wantMedia)
|
|
}
|
|
if len(options) != 1 {
|
|
t.Fatalf("options = %#v, want only temperature", options)
|
|
}
|
|
if got := options["temperature"]; got != 0 {
|
|
t.Fatalf("temperature = %#v, want 0", got)
|
|
}
|
|
|
|
return &providers.LLMResponse{Content: " hello from gemini \n"}, nil
|
|
},
|
|
},
|
|
modelID: "gemini-2.5-flash",
|
|
prompt: defaultTranscriptionPrompt,
|
|
}
|
|
|
|
resp, err := tr.Transcribe(context.Background(), audioPath)
|
|
if err != nil {
|
|
t.Fatalf("Transcribe() error: %v", err)
|
|
}
|
|
if resp.Text != "hello from gemini" {
|
|
t.Fatalf("Text = %q, want %q", resp.Text, "hello from gemini")
|
|
}
|
|
})
|
|
|
|
t.Run("provider error", func(t *testing.T) {
|
|
tr := &AudioModelTranscriber{
|
|
provider: &fakeLLMProvider{
|
|
chatFunc: func(
|
|
ctx context.Context,
|
|
messages []providers.Message,
|
|
tools []providers.ToolDefinition,
|
|
model string,
|
|
options map[string]any,
|
|
) (*providers.LLMResponse, error) {
|
|
return nil, errors.New("upstream failure")
|
|
},
|
|
},
|
|
modelID: "gemini-2.5-flash",
|
|
prompt: defaultTranscriptionPrompt,
|
|
}
|
|
|
|
_, err := tr.Transcribe(context.Background(), audioPath)
|
|
if err == nil {
|
|
t.Fatal("expected error for provider failure, got nil")
|
|
}
|
|
if got := err.Error(); got != "transcription request failed: upstream failure" {
|
|
t.Fatalf("error = %q, want %q", got, "transcription request failed: upstream failure")
|
|
}
|
|
})
|
|
|
|
t.Run("missing file", func(t *testing.T) {
|
|
tr := &AudioModelTranscriber{
|
|
provider: &fakeLLMProvider{},
|
|
modelID: "gemini-2.5-flash",
|
|
prompt: defaultTranscriptionPrompt,
|
|
}
|
|
|
|
_, err := tr.Transcribe(context.Background(), filepath.Join(tmpDir, "nonexistent.ogg"))
|
|
if err == nil {
|
|
t.Fatal("expected error for missing file, got nil")
|
|
}
|
|
})
|
|
|
|
t.Run("unsupported audio format", func(t *testing.T) {
|
|
badPath := filepath.Join(tmpDir, "clip.txt")
|
|
if err := os.WriteFile(badPath, []byte("not-audio"), 0o644); err != nil {
|
|
t.Fatalf("failed to write fake file: %v", err)
|
|
}
|
|
|
|
tr := &AudioModelTranscriber{
|
|
provider: &fakeLLMProvider{},
|
|
modelID: "gemini-2.5-flash",
|
|
prompt: defaultTranscriptionPrompt,
|
|
}
|
|
|
|
_, err := tr.Transcribe(context.Background(), badPath)
|
|
if err == nil {
|
|
t.Fatal("expected error for unsupported audio format, got nil")
|
|
}
|
|
if got := err.Error(); got != `unsupported audio format for "`+badPath+`"` {
|
|
t.Fatalf("error = %q, want unsupported format error", got)
|
|
}
|
|
})
|
|
}
|