mirror of
https://github.com/sipeed/picoclaw.git
synced 2026-06-12 18:08:54 +00:00
feat(voice): add audio-model transcription support
- Add `AudioModelTranscriber` for model-based audio transcription via LLM providers - Support selecting a transcription model with `voice.model_name` in config - Keep Groq transcription as a fallback and move it into dedicated files with focused tests - Serialize `data:audio/...` media as input_audio for OpenAI-compatible providers - Improve transcription logging by rendering error fields as strings - Add coverage for transcriber detection, audio-model behavior, provider audio serialization, and Groq transcription Fixes #1890.
This commit is contained in:
@@ -564,7 +564,8 @@ type DevicesConfig struct {
|
||||
}
|
||||
|
||||
type VoiceConfig struct {
|
||||
EchoTranscription bool `json:"echo_transcription" env:"PICOCLAW_VOICE_ECHO_TRANSCRIPTION"`
|
||||
ModelName string `json:"model_name,omitempty" env:"PICOCLAW_VOICE_MODEL_NAME"`
|
||||
EchoTranscription bool `json:"echo_transcription" env:"PICOCLAW_VOICE_ECHO_TRANSCRIPTION"`
|
||||
}
|
||||
|
||||
type ProvidersConfig struct {
|
||||
|
||||
@@ -567,6 +567,7 @@ func DefaultConfig() *Config {
|
||||
MonitorUSB: true,
|
||||
},
|
||||
Voice: VoiceConfig{
|
||||
ModelName: "",
|
||||
EchoTranscription: false,
|
||||
},
|
||||
BuildInfo: BuildInfo{
|
||||
|
||||
@@ -256,6 +256,8 @@ func appendFields(event *zerolog.Event, fields map[string]any) {
|
||||
for k, v := range fields {
|
||||
// Type switch to avoid double JSON serialization of strings
|
||||
switch val := v.(type) {
|
||||
case error:
|
||||
event.Str(k, val.Error())
|
||||
case string:
|
||||
event.Str(k, val)
|
||||
case int:
|
||||
|
||||
@@ -1,7 +1,12 @@
|
||||
package logger
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
)
|
||||
|
||||
func TestLogLevelFiltering(t *testing.T) {
|
||||
@@ -337,3 +342,26 @@ func TestSetLevelFromString(t *testing.T) {
|
||||
t.Errorf("after SetLevelFromString(\"FATAL\"): GetLevel() = %v, want FATAL", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAppendFields_ErrorUsesErrorString(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
l := zerolog.New(&buf)
|
||||
|
||||
event := l.Info()
|
||||
appendFields(event, map[string]any{"error": errors.New("transcription request failed")})
|
||||
event.Msg("test")
|
||||
|
||||
lines := bytes.Split(bytes.TrimSpace(buf.Bytes()), []byte("\n"))
|
||||
if len(lines) == 0 {
|
||||
t.Fatal("expected log output, got none")
|
||||
}
|
||||
|
||||
var got map[string]any
|
||||
if err := json.Unmarshal(lines[0], &got); err != nil {
|
||||
t.Fatalf("unmarshal log line: %v", err)
|
||||
}
|
||||
|
||||
if got["error"] != "transcription request failed" {
|
||||
t.Fatalf("error field = %#v, want %q", got["error"], "transcription request failed")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -111,6 +111,17 @@ func SerializeMessages(messages []Message) []any {
|
||||
"url": mediaURL,
|
||||
},
|
||||
})
|
||||
continue
|
||||
}
|
||||
|
||||
if format, data, ok := parseDataAudioURL(mediaURL); ok {
|
||||
parts = append(parts, map[string]any{
|
||||
"type": "input_audio",
|
||||
"input_audio": map[string]any{
|
||||
"data": data,
|
||||
"format": format,
|
||||
},
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -132,6 +143,26 @@ func SerializeMessages(messages []Message) []any {
|
||||
return out
|
||||
}
|
||||
|
||||
func parseDataAudioURL(mediaURL string) (format, data string, ok bool) {
|
||||
if !strings.HasPrefix(mediaURL, "data:audio/") {
|
||||
return "", "", false
|
||||
}
|
||||
|
||||
payload := strings.TrimPrefix(mediaURL, "data:audio/")
|
||||
meta, data, found := strings.Cut(payload, ",")
|
||||
if !found {
|
||||
return "", "", false
|
||||
}
|
||||
|
||||
format, _, _ = strings.Cut(meta, ";")
|
||||
format = strings.TrimSpace(format)
|
||||
data = strings.TrimSpace(data)
|
||||
if format == "" || data == "" {
|
||||
return "", "", false
|
||||
}
|
||||
return format, data, true
|
||||
}
|
||||
|
||||
// --- Response parsing ---
|
||||
|
||||
// ParseResponse parses a JSON chat completion response body into an LLMResponse.
|
||||
|
||||
@@ -91,6 +91,44 @@ func TestSerializeMessages_WithMedia(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestSerializeMessages_WithAudioMedia(t *testing.T) {
|
||||
messages := []Message{
|
||||
{Role: "user", Content: "transcribe this", Media: []string{"data:audio/ogg;base64,abc123"}},
|
||||
}
|
||||
result := SerializeMessages(messages)
|
||||
|
||||
data, _ := json.Marshal(result)
|
||||
var msgs []map[string]any
|
||||
json.Unmarshal(data, &msgs)
|
||||
|
||||
content, ok := msgs[0]["content"].([]any)
|
||||
if !ok {
|
||||
t.Fatalf("expected array content for media message, got %T", msgs[0]["content"])
|
||||
}
|
||||
if len(content) != 2 {
|
||||
t.Fatalf("expected 2 content parts, got %d", len(content))
|
||||
}
|
||||
|
||||
audioPart, ok := content[1].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatalf("expected audio content part to be an object, got %T", content[1])
|
||||
}
|
||||
if audioPart["type"] != "input_audio" {
|
||||
t.Fatalf("audio part type = %v, want input_audio", audioPart["type"])
|
||||
}
|
||||
|
||||
inputAudio, ok := audioPart["input_audio"].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatalf("expected input_audio object, got %T", audioPart["input_audio"])
|
||||
}
|
||||
if inputAudio["format"] != "ogg" {
|
||||
t.Fatalf("audio format = %v, want ogg", inputAudio["format"])
|
||||
}
|
||||
if inputAudio["data"] != "abc123" {
|
||||
t.Fatalf("audio data = %v, want abc123", inputAudio["data"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestSerializeMessages_MediaWithToolCallID(t *testing.T) {
|
||||
messages := []Message{
|
||||
{Role: "tool", Content: "result", Media: []string{"data:image/png;base64,xyz"}, ToolCallID: "call_1"},
|
||||
|
||||
@@ -0,0 +1,115 @@
|
||||
package voice
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
"github.com/sipeed/picoclaw/pkg/logger"
|
||||
"github.com/sipeed/picoclaw/pkg/providers"
|
||||
"github.com/sipeed/picoclaw/pkg/utils"
|
||||
)
|
||||
|
||||
type AudioModelTranscriber struct {
|
||||
provider providers.LLMProvider
|
||||
modelID string
|
||||
prompt string
|
||||
}
|
||||
|
||||
const (
|
||||
defaultTranscriptionPrompt = "Transcribe this audio."
|
||||
)
|
||||
|
||||
func audioFormat(path string) (string, error) {
|
||||
switch strings.ToLower(filepath.Ext(strings.TrimPrefix(path, "file://"))) {
|
||||
case ".wav":
|
||||
return "wav", nil
|
||||
case ".mp3":
|
||||
return "mp3", nil
|
||||
case ".aiff", ".aif":
|
||||
return "aiff", nil
|
||||
case ".aac":
|
||||
return "aac", nil
|
||||
case ".ogg":
|
||||
return "ogg", nil
|
||||
case ".flac":
|
||||
return "flac", nil
|
||||
default:
|
||||
return "", fmt.Errorf("unsupported audio format for %q", path)
|
||||
}
|
||||
}
|
||||
|
||||
func NewAudioModelTranscriber(modelCfg *config.ModelConfig) *AudioModelTranscriber {
|
||||
if modelCfg == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
logger.DebugCF("voice", "Creating audio model transcriber", map[string]any{
|
||||
"has_api_key": modelCfg.APIKey != "",
|
||||
"api_base": modelCfg.APIBase,
|
||||
"model": modelCfg.Model,
|
||||
})
|
||||
|
||||
provider, modelID, err := providers.CreateProviderFromConfig(modelCfg)
|
||||
if err != nil {
|
||||
logger.ErrorCF("voice", "Failed to create audio model provider", map[string]any{"error": err})
|
||||
return nil
|
||||
}
|
||||
|
||||
return &AudioModelTranscriber{
|
||||
provider: provider,
|
||||
modelID: modelID,
|
||||
prompt: defaultTranscriptionPrompt,
|
||||
}
|
||||
}
|
||||
|
||||
func (t *AudioModelTranscriber) Transcribe(ctx context.Context, audioFilePath string) (*TranscriptionResponse, error) {
|
||||
logger.InfoCF("voice", "Starting audio model transcription", map[string]any{
|
||||
"audio_file": audioFilePath,
|
||||
"model": t.modelID,
|
||||
})
|
||||
|
||||
audioBytes, err := os.ReadFile(audioFilePath)
|
||||
if err != nil {
|
||||
logger.ErrorCF("voice", "Failed to read audio file", map[string]any{"path": audioFilePath, "error": err})
|
||||
return nil, fmt.Errorf("failed to read audio file: %w", err)
|
||||
}
|
||||
|
||||
format, err := audioFormat(audioFilePath)
|
||||
if err != nil {
|
||||
logger.ErrorCF("voice", "Failed to detect audio format", map[string]any{"path": audioFilePath, "error": err})
|
||||
return nil, err
|
||||
}
|
||||
|
||||
resp, err := t.provider.Chat(ctx, []providers.Message{
|
||||
{
|
||||
Role: "user",
|
||||
Content: t.prompt,
|
||||
Media: []string{
|
||||
fmt.Sprintf("data:audio/%s;base64,%s", format, base64.StdEncoding.EncodeToString(audioBytes)),
|
||||
},
|
||||
},
|
||||
}, nil, t.modelID, map[string]any{
|
||||
"temperature": 0,
|
||||
})
|
||||
if err != nil {
|
||||
logger.ErrorCF("voice", "Audio model transcription request failed", map[string]any{"error": err})
|
||||
return nil, fmt.Errorf("transcription request failed: %w", err)
|
||||
}
|
||||
|
||||
text := strings.TrimSpace(resp.Content)
|
||||
logger.InfoCF("voice", "Audio model transcription completed successfully", map[string]any{
|
||||
"text_length": len(text),
|
||||
"transcription_preview": utils.Truncate(text, 50),
|
||||
})
|
||||
|
||||
return &TranscriptionResponse{Text: text}, nil
|
||||
}
|
||||
|
||||
func (t *AudioModelTranscriber) Name() string {
|
||||
return "audio-model"
|
||||
}
|
||||
@@ -0,0 +1,203 @@
|
||||
package voice
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
"github.com/sipeed/picoclaw/pkg/providers"
|
||||
)
|
||||
|
||||
var _ Transcriber = (*AudioModelTranscriber)(nil)
|
||||
|
||||
type fakeLLMProvider struct {
|
||||
chatFunc func(
|
||||
ctx context.Context,
|
||||
messages []providers.Message,
|
||||
tools []providers.ToolDefinition,
|
||||
model string,
|
||||
options map[string]any,
|
||||
) (*providers.LLMResponse, error)
|
||||
}
|
||||
|
||||
func (p *fakeLLMProvider) Chat(
|
||||
ctx context.Context,
|
||||
messages []providers.Message,
|
||||
tools []providers.ToolDefinition,
|
||||
model string,
|
||||
options map[string]any,
|
||||
) (*providers.LLMResponse, error) {
|
||||
if p.chatFunc == nil {
|
||||
return nil, nil
|
||||
}
|
||||
return p.chatFunc(ctx, messages, tools, model, options)
|
||||
}
|
||||
|
||||
func (p *fakeLLMProvider) GetDefaultModel() string {
|
||||
return ""
|
||||
}
|
||||
|
||||
func TestAudioModelTranscriberName(t *testing.T) {
|
||||
tr := &AudioModelTranscriber{}
|
||||
if got := tr.Name(); got != "audio-model" {
|
||||
t.Errorf("Name() = %q, want %q", got, "audio-model")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewAudioModelTranscriberInvalidConfig(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
cfg *config.ModelConfig
|
||||
}{
|
||||
{
|
||||
name: "nil config",
|
||||
cfg: nil,
|
||||
},
|
||||
{
|
||||
name: "missing api key",
|
||||
cfg: &config.ModelConfig{
|
||||
Model: "gemini/gemini-2.5-flash",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if tr := NewAudioModelTranscriber(tt.cfg); tr != nil {
|
||||
t.Fatalf("NewAudioModelTranscriber() = %#v, want nil", tr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAudioModelTranscriberTranscribe(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
audioPath := filepath.Join(tmpDir, "clip.ogg")
|
||||
audioData := []byte("fake-audio-data")
|
||||
if err := os.WriteFile(audioPath, audioData, 0o644); err != nil {
|
||||
t.Fatalf("failed to write fake audio file: %v", err)
|
||||
}
|
||||
|
||||
t.Run("success", func(t *testing.T) {
|
||||
tr := &AudioModelTranscriber{
|
||||
provider: &fakeLLMProvider{
|
||||
chatFunc: func(
|
||||
ctx context.Context,
|
||||
messages []providers.Message,
|
||||
tools []providers.ToolDefinition,
|
||||
model string,
|
||||
options map[string]any,
|
||||
) (*providers.LLMResponse, error) {
|
||||
if ctx == nil {
|
||||
t.Fatal("context should not be nil")
|
||||
}
|
||||
if tools != nil {
|
||||
t.Fatalf("tools = %#v, want nil", tools)
|
||||
}
|
||||
if model != "gemini-2.5-flash" {
|
||||
t.Fatalf("model = %q, want %q", model, "gemini-2.5-flash")
|
||||
}
|
||||
if len(messages) != 1 {
|
||||
t.Fatalf("len(messages) = %d, want 1", len(messages))
|
||||
}
|
||||
msg := messages[0]
|
||||
if msg.Role != "user" {
|
||||
t.Fatalf("role = %q, want %q", msg.Role, "user")
|
||||
}
|
||||
if msg.Content != defaultTranscriptionPrompt {
|
||||
t.Fatalf("prompt = %q, want %q", msg.Content, defaultTranscriptionPrompt)
|
||||
}
|
||||
if len(msg.Media) != 1 {
|
||||
t.Fatalf("len(media) = %d, want 1", len(msg.Media))
|
||||
}
|
||||
wantMedia := "data:audio/ogg;base64," + base64.StdEncoding.EncodeToString(audioData)
|
||||
if msg.Media[0] != wantMedia {
|
||||
t.Fatalf("media = %q, want %q", msg.Media[0], wantMedia)
|
||||
}
|
||||
if len(options) != 1 {
|
||||
t.Fatalf("options = %#v, want only temperature", options)
|
||||
}
|
||||
if got := options["temperature"]; got != 0 {
|
||||
t.Fatalf("temperature = %#v, want 0", got)
|
||||
}
|
||||
|
||||
return &providers.LLMResponse{Content: " hello from gemini \n"}, nil
|
||||
},
|
||||
},
|
||||
modelID: "gemini-2.5-flash",
|
||||
prompt: defaultTranscriptionPrompt,
|
||||
}
|
||||
|
||||
resp, err := tr.Transcribe(context.Background(), audioPath)
|
||||
if err != nil {
|
||||
t.Fatalf("Transcribe() error: %v", err)
|
||||
}
|
||||
if resp.Text != "hello from gemini" {
|
||||
t.Fatalf("Text = %q, want %q", resp.Text, "hello from gemini")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("provider error", func(t *testing.T) {
|
||||
tr := &AudioModelTranscriber{
|
||||
provider: &fakeLLMProvider{
|
||||
chatFunc: func(
|
||||
ctx context.Context,
|
||||
messages []providers.Message,
|
||||
tools []providers.ToolDefinition,
|
||||
model string,
|
||||
options map[string]any,
|
||||
) (*providers.LLMResponse, error) {
|
||||
return nil, errors.New("upstream failure")
|
||||
},
|
||||
},
|
||||
modelID: "gemini-2.5-flash",
|
||||
prompt: defaultTranscriptionPrompt,
|
||||
}
|
||||
|
||||
_, err := tr.Transcribe(context.Background(), audioPath)
|
||||
if err == nil {
|
||||
t.Fatal("expected error for provider failure, got nil")
|
||||
}
|
||||
if got := err.Error(); got != "transcription request failed: upstream failure" {
|
||||
t.Fatalf("error = %q, want %q", got, "transcription request failed: upstream failure")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("missing file", func(t *testing.T) {
|
||||
tr := &AudioModelTranscriber{
|
||||
provider: &fakeLLMProvider{},
|
||||
modelID: "gemini-2.5-flash",
|
||||
prompt: defaultTranscriptionPrompt,
|
||||
}
|
||||
|
||||
_, err := tr.Transcribe(context.Background(), filepath.Join(tmpDir, "nonexistent.ogg"))
|
||||
if err == nil {
|
||||
t.Fatal("expected error for missing file, got nil")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("unsupported audio format", func(t *testing.T) {
|
||||
badPath := filepath.Join(tmpDir, "clip.txt")
|
||||
if err := os.WriteFile(badPath, []byte("not-audio"), 0o644); err != nil {
|
||||
t.Fatalf("failed to write fake file: %v", err)
|
||||
}
|
||||
|
||||
tr := &AudioModelTranscriber{
|
||||
provider: &fakeLLMProvider{},
|
||||
modelID: "gemini-2.5-flash",
|
||||
prompt: defaultTranscriptionPrompt,
|
||||
}
|
||||
|
||||
_, err := tr.Transcribe(context.Background(), badPath)
|
||||
if err == nil {
|
||||
t.Fatal("expected error for unsupported audio format, got nil")
|
||||
}
|
||||
if got := err.Error(); got != `unsupported audio format for "`+badPath+`"` {
|
||||
t.Fatalf("error = %q, want unsupported format error", got)
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -0,0 +1,151 @@
|
||||
package voice
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"mime/multipart"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"time"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/logger"
|
||||
"github.com/sipeed/picoclaw/pkg/utils"
|
||||
)
|
||||
|
||||
type GroqTranscriber struct {
|
||||
apiKey string
|
||||
apiBase string
|
||||
httpClient *http.Client
|
||||
}
|
||||
|
||||
func NewGroqTranscriber(apiKey string) *GroqTranscriber {
|
||||
logger.DebugCF("voice", "Creating Groq transcriber", map[string]any{"has_api_key": apiKey != ""})
|
||||
|
||||
apiBase := "https://api.groq.com/openai/v1"
|
||||
return &GroqTranscriber{
|
||||
apiKey: apiKey,
|
||||
apiBase: apiBase,
|
||||
httpClient: &http.Client{
|
||||
Timeout: 60 * time.Second,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (t *GroqTranscriber) Transcribe(ctx context.Context, audioFilePath string) (*TranscriptionResponse, error) {
|
||||
logger.InfoCF("voice", "Starting transcription", map[string]any{"audio_file": audioFilePath})
|
||||
|
||||
audioFile, err := os.Open(audioFilePath)
|
||||
if err != nil {
|
||||
logger.ErrorCF("voice", "Failed to open audio file", map[string]any{"path": audioFilePath, "error": err})
|
||||
return nil, fmt.Errorf("failed to open audio file: %w", err)
|
||||
}
|
||||
defer audioFile.Close()
|
||||
|
||||
fileInfo, err := audioFile.Stat()
|
||||
if err != nil {
|
||||
logger.ErrorCF("voice", "Failed to get file info", map[string]any{"path": audioFilePath, "error": err})
|
||||
return nil, fmt.Errorf("failed to get file info: %w", err)
|
||||
}
|
||||
|
||||
logger.DebugCF("voice", "Audio file details", map[string]any{
|
||||
"size_bytes": fileInfo.Size(),
|
||||
"file_name": filepath.Base(audioFilePath),
|
||||
})
|
||||
|
||||
var requestBody bytes.Buffer
|
||||
writer := multipart.NewWriter(&requestBody)
|
||||
|
||||
part, err := writer.CreateFormFile("file", filepath.Base(audioFilePath))
|
||||
if err != nil {
|
||||
logger.ErrorCF("voice", "Failed to create form file", map[string]any{"error": err})
|
||||
return nil, fmt.Errorf("failed to create form file: %w", err)
|
||||
}
|
||||
|
||||
copied, err := io.Copy(part, audioFile)
|
||||
if err != nil {
|
||||
logger.ErrorCF("voice", "Failed to copy file content", map[string]any{"error": err})
|
||||
return nil, fmt.Errorf("failed to copy file content: %w", err)
|
||||
}
|
||||
|
||||
logger.DebugCF("voice", "File copied to request", map[string]any{"bytes_copied": copied})
|
||||
|
||||
if err = writer.WriteField("model", "whisper-large-v3"); err != nil {
|
||||
logger.ErrorCF("voice", "Failed to write model field", map[string]any{"error": err})
|
||||
return nil, fmt.Errorf("failed to write model field: %w", err)
|
||||
}
|
||||
|
||||
if err = writer.WriteField("response_format", "json"); err != nil {
|
||||
logger.ErrorCF("voice", "Failed to write response_format field", map[string]any{"error": err})
|
||||
return nil, fmt.Errorf("failed to write response_format field: %w", err)
|
||||
}
|
||||
|
||||
if err = writer.Close(); err != nil {
|
||||
logger.ErrorCF("voice", "Failed to close multipart writer", map[string]any{"error": err})
|
||||
return nil, fmt.Errorf("failed to close multipart writer: %w", err)
|
||||
}
|
||||
|
||||
url := t.apiBase + "/audio/transcriptions"
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", url, &requestBody)
|
||||
if err != nil {
|
||||
logger.ErrorCF("voice", "Failed to create request", map[string]any{"error": err})
|
||||
return nil, fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", writer.FormDataContentType())
|
||||
req.Header.Set("Authorization", "Bearer "+t.apiKey)
|
||||
|
||||
logger.DebugCF("voice", "Sending transcription request to Groq API", map[string]any{
|
||||
"url": url,
|
||||
"request_size_bytes": requestBody.Len(),
|
||||
"file_size_bytes": fileInfo.Size(),
|
||||
})
|
||||
|
||||
resp, err := t.httpClient.Do(req)
|
||||
if err != nil {
|
||||
logger.ErrorCF("voice", "Failed to send request", map[string]any{"error": err})
|
||||
return nil, fmt.Errorf("failed to send request: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
logger.ErrorCF("voice", "Failed to read response", map[string]any{"error": err})
|
||||
return nil, fmt.Errorf("failed to read response: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
logger.ErrorCF("voice", "API error", map[string]any{
|
||||
"status_code": resp.StatusCode,
|
||||
"response": string(body),
|
||||
})
|
||||
return nil, fmt.Errorf("API error (status %d): %s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
logger.DebugCF("voice", "Received response from Groq API", map[string]any{
|
||||
"status_code": resp.StatusCode,
|
||||
"response_size_bytes": len(body),
|
||||
})
|
||||
|
||||
var result TranscriptionResponse
|
||||
if err := json.Unmarshal(body, &result); err != nil {
|
||||
logger.ErrorCF("voice", "Failed to unmarshal response", map[string]any{"error": err})
|
||||
return nil, fmt.Errorf("failed to unmarshal response: %w", err)
|
||||
}
|
||||
|
||||
logger.InfoCF("voice", "Transcription completed successfully", map[string]any{
|
||||
"text_length": len(result.Text),
|
||||
"language": result.Language,
|
||||
"duration_seconds": result.Duration,
|
||||
"transcription_preview": utils.Truncate(result.Text, 50),
|
||||
})
|
||||
|
||||
return &result, nil
|
||||
}
|
||||
|
||||
func (t *GroqTranscriber) Name() string {
|
||||
return "groq"
|
||||
}
|
||||
@@ -0,0 +1,84 @@
|
||||
package voice
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
)
|
||||
|
||||
var _ Transcriber = (*GroqTranscriber)(nil)
|
||||
|
||||
func TestGroqTranscriberName(t *testing.T) {
|
||||
tr := NewGroqTranscriber("sk-test")
|
||||
if got := tr.Name(); got != "groq" {
|
||||
t.Errorf("Name() = %q, want %q", got, "groq")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGroqTranscribe(t *testing.T) {
|
||||
// Write a minimal fake audio file so the transcriber can open and send it.
|
||||
tmpDir := t.TempDir()
|
||||
audioPath := filepath.Join(tmpDir, "clip.ogg")
|
||||
if err := os.WriteFile(audioPath, []byte("fake-audio-data"), 0o644); err != nil {
|
||||
t.Fatalf("failed to write fake audio file: %v", err)
|
||||
}
|
||||
|
||||
t.Run("success", func(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path != "/audio/transcriptions" {
|
||||
t.Errorf("unexpected path: %s", r.URL.Path)
|
||||
}
|
||||
if r.Header.Get("Authorization") != "Bearer sk-test" {
|
||||
t.Errorf("unexpected Authorization header: %s", r.Header.Get("Authorization"))
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_ = json.NewEncoder(w).Encode(TranscriptionResponse{
|
||||
Text: "hello world",
|
||||
Language: "en",
|
||||
Duration: 1.5,
|
||||
})
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
tr := NewGroqTranscriber("sk-test")
|
||||
tr.apiBase = srv.URL
|
||||
|
||||
resp, err := tr.Transcribe(context.Background(), audioPath)
|
||||
if err != nil {
|
||||
t.Fatalf("Transcribe() error: %v", err)
|
||||
}
|
||||
if resp.Text != "hello world" {
|
||||
t.Errorf("Text = %q, want %q", resp.Text, "hello world")
|
||||
}
|
||||
if resp.Language != "en" {
|
||||
t.Errorf("Language = %q, want %q", resp.Language, "en")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("api error", func(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
http.Error(w, `{"error":"invalid_api_key"}`, http.StatusUnauthorized)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
tr := NewGroqTranscriber("sk-bad")
|
||||
tr.apiBase = srv.URL
|
||||
|
||||
_, err := tr.Transcribe(context.Background(), audioPath)
|
||||
if err == nil {
|
||||
t.Fatal("expected error for non-200 response, got nil")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("missing file", func(t *testing.T) {
|
||||
tr := NewGroqTranscriber("sk-test")
|
||||
_, err := tr.Transcribe(context.Background(), filepath.Join(tmpDir, "nonexistent.ogg"))
|
||||
if err == nil {
|
||||
t.Fatal("expected error for missing file, got nil")
|
||||
}
|
||||
})
|
||||
}
|
||||
+8
-145
@@ -1,21 +1,10 @@
|
||||
package voice
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"mime/multipart"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
"github.com/sipeed/picoclaw/pkg/logger"
|
||||
"github.com/sipeed/picoclaw/pkg/utils"
|
||||
)
|
||||
|
||||
type Transcriber interface {
|
||||
@@ -23,149 +12,23 @@ type Transcriber interface {
|
||||
Transcribe(ctx context.Context, audioFilePath string) (*TranscriptionResponse, error)
|
||||
}
|
||||
|
||||
type GroqTranscriber struct {
|
||||
apiKey string
|
||||
apiBase string
|
||||
httpClient *http.Client
|
||||
}
|
||||
|
||||
type TranscriptionResponse struct {
|
||||
Text string `json:"text"`
|
||||
Language string `json:"language,omitempty"`
|
||||
Duration float64 `json:"duration,omitempty"`
|
||||
}
|
||||
|
||||
func NewGroqTranscriber(apiKey string) *GroqTranscriber {
|
||||
logger.DebugCF("voice", "Creating Groq transcriber", map[string]any{"has_api_key": apiKey != ""})
|
||||
|
||||
apiBase := "https://api.groq.com/openai/v1"
|
||||
return &GroqTranscriber{
|
||||
apiKey: apiKey,
|
||||
apiBase: apiBase,
|
||||
httpClient: &http.Client{
|
||||
Timeout: 60 * time.Second,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (t *GroqTranscriber) Transcribe(ctx context.Context, audioFilePath string) (*TranscriptionResponse, error) {
|
||||
logger.InfoCF("voice", "Starting transcription", map[string]any{"audio_file": audioFilePath})
|
||||
|
||||
audioFile, err := os.Open(audioFilePath)
|
||||
if err != nil {
|
||||
logger.ErrorCF("voice", "Failed to open audio file", map[string]any{"path": audioFilePath, "error": err})
|
||||
return nil, fmt.Errorf("failed to open audio file: %w", err)
|
||||
}
|
||||
defer audioFile.Close()
|
||||
|
||||
fileInfo, err := audioFile.Stat()
|
||||
if err != nil {
|
||||
logger.ErrorCF("voice", "Failed to get file info", map[string]any{"path": audioFilePath, "error": err})
|
||||
return nil, fmt.Errorf("failed to get file info: %w", err)
|
||||
}
|
||||
|
||||
logger.DebugCF("voice", "Audio file details", map[string]any{
|
||||
"size_bytes": fileInfo.Size(),
|
||||
"file_name": filepath.Base(audioFilePath),
|
||||
})
|
||||
|
||||
var requestBody bytes.Buffer
|
||||
writer := multipart.NewWriter(&requestBody)
|
||||
|
||||
part, err := writer.CreateFormFile("file", filepath.Base(audioFilePath))
|
||||
if err != nil {
|
||||
logger.ErrorCF("voice", "Failed to create form file", map[string]any{"error": err})
|
||||
return nil, fmt.Errorf("failed to create form file: %w", err)
|
||||
}
|
||||
|
||||
copied, err := io.Copy(part, audioFile)
|
||||
if err != nil {
|
||||
logger.ErrorCF("voice", "Failed to copy file content", map[string]any{"error": err})
|
||||
return nil, fmt.Errorf("failed to copy file content: %w", err)
|
||||
}
|
||||
|
||||
logger.DebugCF("voice", "File copied to request", map[string]any{"bytes_copied": copied})
|
||||
|
||||
if err = writer.WriteField("model", "whisper-large-v3"); err != nil {
|
||||
logger.ErrorCF("voice", "Failed to write model field", map[string]any{"error": err})
|
||||
return nil, fmt.Errorf("failed to write model field: %w", err)
|
||||
}
|
||||
|
||||
if err = writer.WriteField("response_format", "json"); err != nil {
|
||||
logger.ErrorCF("voice", "Failed to write response_format field", map[string]any{"error": err})
|
||||
return nil, fmt.Errorf("failed to write response_format field: %w", err)
|
||||
}
|
||||
|
||||
if err = writer.Close(); err != nil {
|
||||
logger.ErrorCF("voice", "Failed to close multipart writer", map[string]any{"error": err})
|
||||
return nil, fmt.Errorf("failed to close multipart writer: %w", err)
|
||||
}
|
||||
|
||||
url := t.apiBase + "/audio/transcriptions"
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", url, &requestBody)
|
||||
if err != nil {
|
||||
logger.ErrorCF("voice", "Failed to create request", map[string]any{"error": err})
|
||||
return nil, fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", writer.FormDataContentType())
|
||||
req.Header.Set("Authorization", "Bearer "+t.apiKey)
|
||||
|
||||
logger.DebugCF("voice", "Sending transcription request to Groq API", map[string]any{
|
||||
"url": url,
|
||||
"request_size_bytes": requestBody.Len(),
|
||||
"file_size_bytes": fileInfo.Size(),
|
||||
})
|
||||
|
||||
resp, err := t.httpClient.Do(req)
|
||||
if err != nil {
|
||||
logger.ErrorCF("voice", "Failed to send request", map[string]any{"error": err})
|
||||
return nil, fmt.Errorf("failed to send request: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
logger.ErrorCF("voice", "Failed to read response", map[string]any{"error": err})
|
||||
return nil, fmt.Errorf("failed to read response: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
logger.ErrorCF("voice", "API error", map[string]any{
|
||||
"status_code": resp.StatusCode,
|
||||
"response": string(body),
|
||||
})
|
||||
return nil, fmt.Errorf("API error (status %d): %s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
logger.DebugCF("voice", "Received response from Groq API", map[string]any{
|
||||
"status_code": resp.StatusCode,
|
||||
"response_size_bytes": len(body),
|
||||
})
|
||||
|
||||
var result TranscriptionResponse
|
||||
if err := json.Unmarshal(body, &result); err != nil {
|
||||
logger.ErrorCF("voice", "Failed to unmarshal response", map[string]any{"error": err})
|
||||
return nil, fmt.Errorf("failed to unmarshal response: %w", err)
|
||||
}
|
||||
|
||||
logger.InfoCF("voice", "Transcription completed successfully", map[string]any{
|
||||
"text_length": len(result.Text),
|
||||
"language": result.Language,
|
||||
"duration_seconds": result.Duration,
|
||||
"transcription_preview": utils.Truncate(result.Text, 50),
|
||||
})
|
||||
|
||||
return &result, nil
|
||||
}
|
||||
|
||||
func (t *GroqTranscriber) Name() string {
|
||||
return "groq"
|
||||
}
|
||||
|
||||
// DetectTranscriber inspects cfg and returns the appropriate Transcriber, or
|
||||
// nil if no supported transcription provider is configured.
|
||||
func DetectTranscriber(cfg *config.Config) Transcriber {
|
||||
if modelName := strings.TrimSpace(cfg.Voice.ModelName); modelName != "" {
|
||||
modelCfg, err := cfg.GetModelConfig(modelName)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
return NewAudioModelTranscriber(modelCfg)
|
||||
}
|
||||
|
||||
// Direct Groq provider config takes priority.
|
||||
if key := cfg.Providers.Groq.APIKey; key != "" {
|
||||
return NewGroqTranscriber(key)
|
||||
|
||||
@@ -1,27 +1,11 @@
|
||||
package voice
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
)
|
||||
|
||||
// Ensure GroqTranscriber satisfies the Transcriber interface at compile time.
|
||||
var _ Transcriber = (*GroqTranscriber)(nil)
|
||||
|
||||
func TestGroqTranscriberName(t *testing.T) {
|
||||
tr := NewGroqTranscriber("sk-test")
|
||||
if got := tr.Name(); got != "groq" {
|
||||
t.Errorf("Name() = %q, want %q", got, "groq")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDetectTranscriber(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
@@ -43,6 +27,16 @@ func TestDetectTranscriber(t *testing.T) {
|
||||
},
|
||||
wantName: "groq",
|
||||
},
|
||||
{
|
||||
name: "voice model name selects audio model transcriber",
|
||||
cfg: &config.Config{
|
||||
Voice: config.VoiceConfig{ModelName: "voice-gemini"},
|
||||
ModelList: []config.ModelConfig{
|
||||
{ModelName: "voice-gemini", Model: "gemini/gemini-2.5-flash", APIKey: "sk-gemini-model"},
|
||||
},
|
||||
},
|
||||
wantName: "audio-model",
|
||||
},
|
||||
{
|
||||
name: "groq via model list",
|
||||
cfg: &config.Config{
|
||||
@@ -53,6 +47,16 @@ func TestDetectTranscriber(t *testing.T) {
|
||||
},
|
||||
wantName: "groq",
|
||||
},
|
||||
{
|
||||
name: "voice model name selects non-gemini audio model transcriber",
|
||||
cfg: &config.Config{
|
||||
Voice: config.VoiceConfig{ModelName: "voice-openai-audio"},
|
||||
ModelList: []config.ModelConfig{
|
||||
{ModelName: "voice-openai-audio", Model: "openai/gpt-4o-audio-preview", APIKey: "sk-openai"},
|
||||
},
|
||||
},
|
||||
wantName: "audio-model",
|
||||
},
|
||||
{
|
||||
name: "groq model list entry without key is skipped",
|
||||
cfg: &config.Config{
|
||||
@@ -74,6 +78,16 @@ func TestDetectTranscriber(t *testing.T) {
|
||||
},
|
||||
wantName: "groq",
|
||||
},
|
||||
{
|
||||
name: "missing voice model name config returns nil",
|
||||
cfg: &config.Config{
|
||||
Voice: config.VoiceConfig{ModelName: "missing"},
|
||||
ModelList: []config.ModelConfig{
|
||||
{ModelName: "other", Model: "gemini/gemini-2.5-flash", APIKey: "sk-gemini-model"},
|
||||
},
|
||||
},
|
||||
wantNil: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
@@ -94,67 +108,3 @@ func TestDetectTranscriber(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestTranscribe(t *testing.T) {
|
||||
// Write a minimal fake audio file so the transcriber can open and send it.
|
||||
tmpDir := t.TempDir()
|
||||
audioPath := filepath.Join(tmpDir, "clip.ogg")
|
||||
if err := os.WriteFile(audioPath, []byte("fake-audio-data"), 0o644); err != nil {
|
||||
t.Fatalf("failed to write fake audio file: %v", err)
|
||||
}
|
||||
|
||||
t.Run("success", func(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path != "/audio/transcriptions" {
|
||||
t.Errorf("unexpected path: %s", r.URL.Path)
|
||||
}
|
||||
if r.Header.Get("Authorization") != "Bearer sk-test" {
|
||||
t.Errorf("unexpected Authorization header: %s", r.Header.Get("Authorization"))
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_ = json.NewEncoder(w).Encode(TranscriptionResponse{
|
||||
Text: "hello world",
|
||||
Language: "en",
|
||||
Duration: 1.5,
|
||||
})
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
tr := NewGroqTranscriber("sk-test")
|
||||
tr.apiBase = srv.URL
|
||||
|
||||
resp, err := tr.Transcribe(context.Background(), audioPath)
|
||||
if err != nil {
|
||||
t.Fatalf("Transcribe() error: %v", err)
|
||||
}
|
||||
if resp.Text != "hello world" {
|
||||
t.Errorf("Text = %q, want %q", resp.Text, "hello world")
|
||||
}
|
||||
if resp.Language != "en" {
|
||||
t.Errorf("Language = %q, want %q", resp.Language, "en")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("api error", func(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
http.Error(w, `{"error":"invalid_api_key"}`, http.StatusUnauthorized)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
tr := NewGroqTranscriber("sk-bad")
|
||||
tr.apiBase = srv.URL
|
||||
|
||||
_, err := tr.Transcribe(context.Background(), audioPath)
|
||||
if err == nil {
|
||||
t.Fatal("expected error for non-200 response, got nil")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("missing file", func(t *testing.T) {
|
||||
tr := NewGroqTranscriber("sk-test")
|
||||
_, err := tr.Transcribe(context.Background(), filepath.Join(tmpDir, "nonexistent.ogg"))
|
||||
if err == nil {
|
||||
t.Fatal("expected error for missing file, got nil")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user