Merge pull request #1891 from RussellLuo/audio-transcription

feat(voice): add audio-model transcription support
This commit is contained in:
Mauro
2026-03-23 00:23:30 +01:00
committed by GitHub
18 changed files with 800 additions and 227 deletions
+1
View File
@@ -548,6 +548,7 @@
"monitor_usb": true
},
"voice": {
"model_name": "",
"echo_transcription": false
},
"hooks": {
+1 -1
View File
@@ -2,7 +2,7 @@
# Telegram
The Telegram channel uses long polling via the Telegram Bot API for bot-based communication. It supports text messages, media attachments (photos, voice, audio, documents), voice transcription via Groq Whisper, and built-in command handling.
The Telegram channel uses long polling via the Telegram Bot API for bot-based communication. It supports text messages, media attachments (photos, voice, audio, documents), voice transcription ([setup](../../providers.md#voice-transcription)), and built-in command handling.
## Configuration
+1 -1
View File
@@ -2,7 +2,7 @@
# Telegram
Telegram Channel 通过 Telegram 机器人 API 使用长轮询实现基于机器人的通信。它支持文本消息、媒体附件(照片、语音、音频、文档)、通过 Groq Whisper 进行语音转录以及内置命令处理器。
Telegram Channel 通过 Telegram 机器人 API 使用长轮询实现基于机器人的通信。它支持文本消息、媒体附件(照片、语音、音频、文档)、语音转录(配置见[提供商与模型配置](../../zh/providers.md#语音转录)),以及内置命令处理器。
## 配置
+32 -1
View File
@@ -5,7 +5,7 @@
### Providers
> [!NOTE]
> Groq provides free voice transcription via Whisper. If configured, audio messages from any channel will be automatically transcribed at the agent level.
> Voice transcription can use a configured multimodal model via `voice.model_name`. Groq Whisper remains available as a fallback when no voice model is configured.
| Provider | Purpose | Get API Key |
| ------------ | --------------------------------------- | ------------------------------------------------------------ |
@@ -101,6 +101,33 @@ This design also enables **multi-agent support** with flexible provider selectio
}
```
#### Voice Transcription
You can configure a dedicated model for audio transcription with `voice.model_name`. This lets you reuse existing multimodal providers that support audio input instead of relying only on Groq.
If `voice.model_name` is not configured, PicoClaw will continue to fall back to Groq transcription when a Groq API key is available.
```json
{
"model_list": [
{
"model_name": "voice-gemini",
"model": "gemini/gemini-2.5-flash",
"api_key": "your-gemini-key"
}
],
"voice": {
"model_name": "voice-gemini",
"echo_transcription": false
},
"providers": {
"groq": {
"api_key": "gsk_xxx"
}
}
}
```
#### Vendor-Specific Examples
**OpenAI**
@@ -344,6 +371,10 @@ picoclaw agent -m "Hello"
"api_key": "gsk_xxx"
}
},
"voice": {
"model_name": "voice-gemini",
"echo_transcription": false
},
"channels": {
"telegram": {
"enabled": true,
+32 -1
View File
@@ -5,7 +5,7 @@
### 提供商 (Providers)
> [!NOTE]
> Groq 通过 Whisper 提供免费的语音转录。如果配置了 Groq,任意渠道的音频消息都将在 Agent 层面自动转录为文字
> 语音转录现在可以通过 `voice.model_name` 指定的多模态模型完成;如果配置语音模型,Groq Whisper 仍可作为回退方案
| 提供商 | 用途 | 获取 API Key |
| -------------------- | ---------------------------- | -------------------------------------------------------------------- |
@@ -99,6 +99,33 @@
}
```
#### 语音转录
你可以通过 `voice.model_name` 为语音转录指定一个专用模型。这样可以直接复用已经配置好的、支持音频输入的多模态 provider,而不必只依赖 Groq。
如果没有配置 `voice.model_name`,且存在 Groq API KeyPicoClaw 会继续回退到 Groq 转录。
```json
{
"model_list": [
{
"model_name": "voice-gemini",
"model": "gemini/gemini-2.5-flash",
"api_key": "your-gemini-key"
}
],
"voice": {
"model_name": "voice-gemini",
"echo_transcription": false
},
"providers": {
"groq": {
"api_key": "gsk_xxx"
}
}
}
```
#### 各厂商配置示例
**OpenAI**
@@ -342,6 +369,10 @@ picoclaw agent -m "你好"
"api_key": "gsk_xxx"
}
},
"voice": {
"model_name": "voice-gemini",
"echo_transcription": false
},
"channels": {
"telegram": {
"enabled": true,
+2 -1
View File
@@ -604,7 +604,8 @@ type DevicesConfig struct {
}
type VoiceConfig struct {
EchoTranscription bool `json:"echo_transcription" env:"PICOCLAW_VOICE_ECHO_TRANSCRIPTION"`
ModelName string `json:"model_name,omitempty" env:"PICOCLAW_VOICE_MODEL_NAME"`
EchoTranscription bool `json:"echo_transcription" env:"PICOCLAW_VOICE_ECHO_TRANSCRIPTION"`
}
type ProvidersConfig struct {
+1
View File
@@ -576,6 +576,7 @@ func DefaultConfig() *Config {
MonitorUSB: true,
},
Voice: VoiceConfig{
ModelName: "",
EchoTranscription: false,
},
BuildInfo: BuildInfo{
+2
View File
@@ -256,6 +256,8 @@ func appendFields(event *zerolog.Event, fields map[string]any) {
for k, v := range fields {
// Type switch to avoid double JSON serialization of strings
switch val := v.(type) {
case error:
event.Str(k, val.Error())
case string:
event.Str(k, val)
case int:
+28
View File
@@ -1,7 +1,12 @@
package logger
import (
"bytes"
"encoding/json"
"errors"
"testing"
"github.com/rs/zerolog"
)
func TestLogLevelFiltering(t *testing.T) {
@@ -337,3 +342,26 @@ func TestSetLevelFromString(t *testing.T) {
t.Errorf("after SetLevelFromString(\"FATAL\"): GetLevel() = %v, want FATAL", got)
}
}
func TestAppendFields_ErrorUsesErrorString(t *testing.T) {
var buf bytes.Buffer
l := zerolog.New(&buf)
event := l.Info()
appendFields(event, map[string]any{"error": errors.New("transcription request failed")})
event.Msg("test")
lines := bytes.Split(bytes.TrimSpace(buf.Bytes()), []byte("\n"))
if len(lines) == 0 {
t.Fatal("expected log output, got none")
}
var got map[string]any
if err := json.Unmarshal(lines[0], &got); err != nil {
t.Fatalf("unmarshal log line: %v", err)
}
if got["error"] != "transcription request failed" {
t.Fatalf("error field = %#v, want %q", got["error"], "transcription request failed")
}
}
+31
View File
@@ -111,6 +111,17 @@ func SerializeMessages(messages []Message) []any {
"url": mediaURL,
},
})
continue
}
if format, data, ok := parseDataAudioURL(mediaURL); ok {
parts = append(parts, map[string]any{
"type": "input_audio",
"input_audio": map[string]any{
"data": data,
"format": format,
},
})
}
}
@@ -132,6 +143,26 @@ func SerializeMessages(messages []Message) []any {
return out
}
func parseDataAudioURL(mediaURL string) (format, data string, ok bool) {
if !strings.HasPrefix(mediaURL, "data:audio/") {
return "", "", false
}
payload := strings.TrimPrefix(mediaURL, "data:audio/")
meta, data, found := strings.Cut(payload, ",")
if !found {
return "", "", false
}
format, _, _ = strings.Cut(meta, ";")
format = strings.TrimSpace(format)
data = strings.TrimSpace(data)
if format == "" || data == "" {
return "", "", false
}
return format, data, true
}
// --- Response parsing ---
// ParseResponse parses a JSON chat completion response body into an LLMResponse.
+38
View File
@@ -91,6 +91,44 @@ func TestSerializeMessages_WithMedia(t *testing.T) {
}
}
func TestSerializeMessages_WithAudioMedia(t *testing.T) {
messages := []Message{
{Role: "user", Content: "transcribe this", Media: []string{"data:audio/ogg;base64,abc123"}},
}
result := SerializeMessages(messages)
data, _ := json.Marshal(result)
var msgs []map[string]any
json.Unmarshal(data, &msgs)
content, ok := msgs[0]["content"].([]any)
if !ok {
t.Fatalf("expected array content for media message, got %T", msgs[0]["content"])
}
if len(content) != 2 {
t.Fatalf("expected 2 content parts, got %d", len(content))
}
audioPart, ok := content[1].(map[string]any)
if !ok {
t.Fatalf("expected audio content part to be an object, got %T", content[1])
}
if audioPart["type"] != "input_audio" {
t.Fatalf("audio part type = %v, want input_audio", audioPart["type"])
}
inputAudio, ok := audioPart["input_audio"].(map[string]any)
if !ok {
t.Fatalf("expected input_audio object, got %T", audioPart["input_audio"])
}
if inputAudio["format"] != "ogg" {
t.Fatalf("audio format = %v, want ogg", inputAudio["format"])
}
if inputAudio["data"] != "abc123" {
t.Fatalf("audio data = %v, want abc123", inputAudio["data"])
}
}
func TestSerializeMessages_MediaWithToolCallID(t *testing.T) {
messages := []Message{
{Role: "tool", Content: "result", Media: []string{"data:image/png;base64,xyz"}, ToolCallID: "call_1"},
+14 -1
View File
@@ -1,6 +1,7 @@
package utils
import (
"fmt"
"io"
"net/http"
"net/url"
@@ -15,9 +16,21 @@ import (
"github.com/sipeed/picoclaw/pkg/media"
)
var audioExtensions = []string{".mp3", ".wav", ".ogg", ".m4a", ".flac", ".aac", ".wma"}
func AudioFormat(path string) (string, error) {
ext := strings.ToLower(filepath.Ext(path))
for _, supportedExt := range audioExtensions {
if ext == supportedExt {
return strings.TrimPrefix(ext, "."), nil
}
}
return "", fmt.Errorf("unsupported audio format for %q", path)
}
// IsAudioFile checks if a file is an audio file based on its filename extension and content type.
func IsAudioFile(filename, contentType string) bool {
audioExtensions := []string{".mp3", ".wav", ".ogg", ".m4a", ".flac", ".aac", ".wma"}
audioTypes := []string{"audio/", "application/ogg", "application/x-ogg"}
for _, ext := range audioExtensions {
+95
View File
@@ -0,0 +1,95 @@
package voice
import (
"context"
"encoding/base64"
"fmt"
"os"
"strings"
"github.com/sipeed/picoclaw/pkg/config"
"github.com/sipeed/picoclaw/pkg/logger"
"github.com/sipeed/picoclaw/pkg/providers"
"github.com/sipeed/picoclaw/pkg/utils"
)
type AudioModelTranscriber struct {
provider providers.LLMProvider
modelID string
prompt string
}
const (
defaultTranscriptionPrompt = "Transcribe this audio."
)
func NewAudioModelTranscriber(modelCfg *config.ModelConfig) *AudioModelTranscriber {
if modelCfg == nil {
return nil
}
logger.DebugCF("voice", "Creating audio model transcriber", map[string]any{
"has_api_key": modelCfg.APIKey != "",
"api_base": modelCfg.APIBase,
"model": modelCfg.Model,
})
provider, modelID, err := providers.CreateProviderFromConfig(modelCfg)
if err != nil {
logger.ErrorCF("voice", "Failed to create audio model provider", map[string]any{"error": err})
return nil
}
return &AudioModelTranscriber{
provider: provider,
modelID: modelID,
prompt: defaultTranscriptionPrompt,
}
}
func (t *AudioModelTranscriber) Transcribe(ctx context.Context, audioFilePath string) (*TranscriptionResponse, error) {
logger.InfoCF("voice", "Starting audio model transcription", map[string]any{
"audio_file": audioFilePath,
"model": t.modelID,
})
audioBytes, err := os.ReadFile(audioFilePath)
if err != nil {
logger.ErrorCF("voice", "Failed to read audio file", map[string]any{"path": audioFilePath, "error": err})
return nil, fmt.Errorf("failed to read audio file: %w", err)
}
format, err := utils.AudioFormat(audioFilePath)
if err != nil {
logger.ErrorCF("voice", "Failed to detect audio format", map[string]any{"path": audioFilePath, "error": err})
return nil, err
}
resp, err := t.provider.Chat(ctx, []providers.Message{
{
Role: "user",
Content: t.prompt,
Media: []string{
fmt.Sprintf("data:audio/%s;base64,%s", format, base64.StdEncoding.EncodeToString(audioBytes)),
},
},
}, nil, t.modelID, map[string]any{
"temperature": 0,
})
if err != nil {
logger.ErrorCF("voice", "Audio model transcription request failed", map[string]any{"error": err})
return nil, fmt.Errorf("transcription request failed: %w", err)
}
text := strings.TrimSpace(resp.Content)
logger.InfoCF("voice", "Audio model transcription completed successfully", map[string]any{
"text_length": len(text),
"transcription_preview": utils.Truncate(text, 50),
})
return &TranscriptionResponse{Text: text}, nil
}
func (t *AudioModelTranscriber) Name() string {
return "audio-model"
}
+203
View File
@@ -0,0 +1,203 @@
package voice
import (
"context"
"encoding/base64"
"errors"
"os"
"path/filepath"
"testing"
"github.com/sipeed/picoclaw/pkg/config"
"github.com/sipeed/picoclaw/pkg/providers"
)
var _ Transcriber = (*AudioModelTranscriber)(nil)
type fakeLLMProvider struct {
chatFunc func(
ctx context.Context,
messages []providers.Message,
tools []providers.ToolDefinition,
model string,
options map[string]any,
) (*providers.LLMResponse, error)
}
func (p *fakeLLMProvider) Chat(
ctx context.Context,
messages []providers.Message,
tools []providers.ToolDefinition,
model string,
options map[string]any,
) (*providers.LLMResponse, error) {
if p.chatFunc == nil {
return nil, nil
}
return p.chatFunc(ctx, messages, tools, model, options)
}
func (p *fakeLLMProvider) GetDefaultModel() string {
return ""
}
func TestAudioModelTranscriberName(t *testing.T) {
tr := &AudioModelTranscriber{}
if got := tr.Name(); got != "audio-model" {
t.Errorf("Name() = %q, want %q", got, "audio-model")
}
}
func TestNewAudioModelTranscriberInvalidConfig(t *testing.T) {
tests := []struct {
name string
cfg *config.ModelConfig
}{
{
name: "nil config",
cfg: nil,
},
{
name: "missing api key",
cfg: &config.ModelConfig{
Model: "gemini/gemini-2.5-flash",
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if tr := NewAudioModelTranscriber(tt.cfg); tr != nil {
t.Fatalf("NewAudioModelTranscriber() = %#v, want nil", tr)
}
})
}
}
func TestAudioModelTranscriberTranscribe(t *testing.T) {
tmpDir := t.TempDir()
audioPath := filepath.Join(tmpDir, "clip.ogg")
audioData := []byte("fake-audio-data")
if err := os.WriteFile(audioPath, audioData, 0o644); err != nil {
t.Fatalf("failed to write fake audio file: %v", err)
}
t.Run("success", func(t *testing.T) {
tr := &AudioModelTranscriber{
provider: &fakeLLMProvider{
chatFunc: func(
ctx context.Context,
messages []providers.Message,
tools []providers.ToolDefinition,
model string,
options map[string]any,
) (*providers.LLMResponse, error) {
if ctx == nil {
t.Fatal("context should not be nil")
}
if tools != nil {
t.Fatalf("tools = %#v, want nil", tools)
}
if model != "gemini-2.5-flash" {
t.Fatalf("model = %q, want %q", model, "gemini-2.5-flash")
}
if len(messages) != 1 {
t.Fatalf("len(messages) = %d, want 1", len(messages))
}
msg := messages[0]
if msg.Role != "user" {
t.Fatalf("role = %q, want %q", msg.Role, "user")
}
if msg.Content != defaultTranscriptionPrompt {
t.Fatalf("prompt = %q, want %q", msg.Content, defaultTranscriptionPrompt)
}
if len(msg.Media) != 1 {
t.Fatalf("len(media) = %d, want 1", len(msg.Media))
}
wantMedia := "data:audio/ogg;base64," + base64.StdEncoding.EncodeToString(audioData)
if msg.Media[0] != wantMedia {
t.Fatalf("media = %q, want %q", msg.Media[0], wantMedia)
}
if len(options) != 1 {
t.Fatalf("options = %#v, want only temperature", options)
}
if got := options["temperature"]; got != 0 {
t.Fatalf("temperature = %#v, want 0", got)
}
return &providers.LLMResponse{Content: " hello from gemini \n"}, nil
},
},
modelID: "gemini-2.5-flash",
prompt: defaultTranscriptionPrompt,
}
resp, err := tr.Transcribe(context.Background(), audioPath)
if err != nil {
t.Fatalf("Transcribe() error: %v", err)
}
if resp.Text != "hello from gemini" {
t.Fatalf("Text = %q, want %q", resp.Text, "hello from gemini")
}
})
t.Run("provider error", func(t *testing.T) {
tr := &AudioModelTranscriber{
provider: &fakeLLMProvider{
chatFunc: func(
ctx context.Context,
messages []providers.Message,
tools []providers.ToolDefinition,
model string,
options map[string]any,
) (*providers.LLMResponse, error) {
return nil, errors.New("upstream failure")
},
},
modelID: "gemini-2.5-flash",
prompt: defaultTranscriptionPrompt,
}
_, err := tr.Transcribe(context.Background(), audioPath)
if err == nil {
t.Fatal("expected error for provider failure, got nil")
}
if got := err.Error(); got != "transcription request failed: upstream failure" {
t.Fatalf("error = %q, want %q", got, "transcription request failed: upstream failure")
}
})
t.Run("missing file", func(t *testing.T) {
tr := &AudioModelTranscriber{
provider: &fakeLLMProvider{},
modelID: "gemini-2.5-flash",
prompt: defaultTranscriptionPrompt,
}
_, err := tr.Transcribe(context.Background(), filepath.Join(tmpDir, "nonexistent.ogg"))
if err == nil {
t.Fatal("expected error for missing file, got nil")
}
})
t.Run("unsupported audio format", func(t *testing.T) {
badPath := filepath.Join(tmpDir, "clip.txt")
if err := os.WriteFile(badPath, []byte("not-audio"), 0o644); err != nil {
t.Fatalf("failed to write fake file: %v", err)
}
tr := &AudioModelTranscriber{
provider: &fakeLLMProvider{},
modelID: "gemini-2.5-flash",
prompt: defaultTranscriptionPrompt,
}
_, err := tr.Transcribe(context.Background(), badPath)
if err == nil {
t.Fatal("expected error for unsupported audio format, got nil")
}
if got := err.Error(); got != `unsupported audio format for "`+badPath+`"` {
t.Fatalf("error = %q, want unsupported format error", got)
}
})
}
+151
View File
@@ -0,0 +1,151 @@
package voice
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"mime/multipart"
"net/http"
"os"
"path/filepath"
"time"
"github.com/sipeed/picoclaw/pkg/logger"
"github.com/sipeed/picoclaw/pkg/utils"
)
type GroqTranscriber struct {
apiKey string
apiBase string
httpClient *http.Client
}
func NewGroqTranscriber(apiKey string) *GroqTranscriber {
logger.DebugCF("voice", "Creating Groq transcriber", map[string]any{"has_api_key": apiKey != ""})
apiBase := "https://api.groq.com/openai/v1"
return &GroqTranscriber{
apiKey: apiKey,
apiBase: apiBase,
httpClient: &http.Client{
Timeout: 60 * time.Second,
},
}
}
func (t *GroqTranscriber) Transcribe(ctx context.Context, audioFilePath string) (*TranscriptionResponse, error) {
logger.InfoCF("voice", "Starting transcription", map[string]any{"audio_file": audioFilePath})
audioFile, err := os.Open(audioFilePath)
if err != nil {
logger.ErrorCF("voice", "Failed to open audio file", map[string]any{"path": audioFilePath, "error": err})
return nil, fmt.Errorf("failed to open audio file: %w", err)
}
defer audioFile.Close()
fileInfo, err := audioFile.Stat()
if err != nil {
logger.ErrorCF("voice", "Failed to get file info", map[string]any{"path": audioFilePath, "error": err})
return nil, fmt.Errorf("failed to get file info: %w", err)
}
logger.DebugCF("voice", "Audio file details", map[string]any{
"size_bytes": fileInfo.Size(),
"file_name": filepath.Base(audioFilePath),
})
var requestBody bytes.Buffer
writer := multipart.NewWriter(&requestBody)
part, err := writer.CreateFormFile("file", filepath.Base(audioFilePath))
if err != nil {
logger.ErrorCF("voice", "Failed to create form file", map[string]any{"error": err})
return nil, fmt.Errorf("failed to create form file: %w", err)
}
copied, err := io.Copy(part, audioFile)
if err != nil {
logger.ErrorCF("voice", "Failed to copy file content", map[string]any{"error": err})
return nil, fmt.Errorf("failed to copy file content: %w", err)
}
logger.DebugCF("voice", "File copied to request", map[string]any{"bytes_copied": copied})
if err = writer.WriteField("model", "whisper-large-v3"); err != nil {
logger.ErrorCF("voice", "Failed to write model field", map[string]any{"error": err})
return nil, fmt.Errorf("failed to write model field: %w", err)
}
if err = writer.WriteField("response_format", "json"); err != nil {
logger.ErrorCF("voice", "Failed to write response_format field", map[string]any{"error": err})
return nil, fmt.Errorf("failed to write response_format field: %w", err)
}
if err = writer.Close(); err != nil {
logger.ErrorCF("voice", "Failed to close multipart writer", map[string]any{"error": err})
return nil, fmt.Errorf("failed to close multipart writer: %w", err)
}
url := t.apiBase + "/audio/transcriptions"
req, err := http.NewRequestWithContext(ctx, "POST", url, &requestBody)
if err != nil {
logger.ErrorCF("voice", "Failed to create request", map[string]any{"error": err})
return nil, fmt.Errorf("failed to create request: %w", err)
}
req.Header.Set("Content-Type", writer.FormDataContentType())
req.Header.Set("Authorization", "Bearer "+t.apiKey)
logger.DebugCF("voice", "Sending transcription request to Groq API", map[string]any{
"url": url,
"request_size_bytes": requestBody.Len(),
"file_size_bytes": fileInfo.Size(),
})
resp, err := t.httpClient.Do(req)
if err != nil {
logger.ErrorCF("voice", "Failed to send request", map[string]any{"error": err})
return nil, fmt.Errorf("failed to send request: %w", err)
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
logger.ErrorCF("voice", "Failed to read response", map[string]any{"error": err})
return nil, fmt.Errorf("failed to read response: %w", err)
}
if resp.StatusCode != http.StatusOK {
logger.ErrorCF("voice", "API error", map[string]any{
"status_code": resp.StatusCode,
"response": string(body),
})
return nil, fmt.Errorf("API error (status %d): %s", resp.StatusCode, string(body))
}
logger.DebugCF("voice", "Received response from Groq API", map[string]any{
"status_code": resp.StatusCode,
"response_size_bytes": len(body),
})
var result TranscriptionResponse
if err := json.Unmarshal(body, &result); err != nil {
logger.ErrorCF("voice", "Failed to unmarshal response", map[string]any{"error": err})
return nil, fmt.Errorf("failed to unmarshal response: %w", err)
}
logger.InfoCF("voice", "Transcription completed successfully", map[string]any{
"text_length": len(result.Text),
"language": result.Language,
"duration_seconds": result.Duration,
"transcription_preview": utils.Truncate(result.Text, 50),
})
return &result, nil
}
func (t *GroqTranscriber) Name() string {
return "groq"
}
+84
View File
@@ -0,0 +1,84 @@
package voice
import (
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"testing"
)
var _ Transcriber = (*GroqTranscriber)(nil)
func TestGroqTranscriberName(t *testing.T) {
tr := NewGroqTranscriber("sk-test")
if got := tr.Name(); got != "groq" {
t.Errorf("Name() = %q, want %q", got, "groq")
}
}
func TestGroqTranscribe(t *testing.T) {
// Write a minimal fake audio file so the transcriber can open and send it.
tmpDir := t.TempDir()
audioPath := filepath.Join(tmpDir, "clip.ogg")
if err := os.WriteFile(audioPath, []byte("fake-audio-data"), 0o644); err != nil {
t.Fatalf("failed to write fake audio file: %v", err)
}
t.Run("success", func(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/audio/transcriptions" {
t.Errorf("unexpected path: %s", r.URL.Path)
}
if r.Header.Get("Authorization") != "Bearer sk-test" {
t.Errorf("unexpected Authorization header: %s", r.Header.Get("Authorization"))
}
w.Header().Set("Content-Type", "application/json")
_ = json.NewEncoder(w).Encode(TranscriptionResponse{
Text: "hello world",
Language: "en",
Duration: 1.5,
})
}))
defer srv.Close()
tr := NewGroqTranscriber("sk-test")
tr.apiBase = srv.URL
resp, err := tr.Transcribe(context.Background(), audioPath)
if err != nil {
t.Fatalf("Transcribe() error: %v", err)
}
if resp.Text != "hello world" {
t.Errorf("Text = %q, want %q", resp.Text, "hello world")
}
if resp.Language != "en" {
t.Errorf("Language = %q, want %q", resp.Language, "en")
}
})
t.Run("api error", func(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
http.Error(w, `{"error":"invalid_api_key"}`, http.StatusUnauthorized)
}))
defer srv.Close()
tr := NewGroqTranscriber("sk-bad")
tr.apiBase = srv.URL
_, err := tr.Transcribe(context.Background(), audioPath)
if err == nil {
t.Fatal("expected error for non-200 response, got nil")
}
})
t.Run("missing file", func(t *testing.T) {
tr := NewGroqTranscriber("sk-test")
_, err := tr.Transcribe(context.Background(), filepath.Join(tmpDir, "nonexistent.ogg"))
if err == nil {
t.Fatal("expected error for missing file, got nil")
}
})
}
+29 -141
View File
@@ -1,21 +1,11 @@
package voice
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"mime/multipart"
"net/http"
"os"
"path/filepath"
"strings"
"time"
"github.com/sipeed/picoclaw/pkg/config"
"github.com/sipeed/picoclaw/pkg/logger"
"github.com/sipeed/picoclaw/pkg/utils"
"github.com/sipeed/picoclaw/pkg/providers"
)
type Transcriber interface {
@@ -23,149 +13,47 @@ type Transcriber interface {
Transcribe(ctx context.Context, audioFilePath string) (*TranscriptionResponse, error)
}
type GroqTranscriber struct {
apiKey string
apiBase string
httpClient *http.Client
}
type TranscriptionResponse struct {
Text string `json:"text"`
Language string `json:"language,omitempty"`
Duration float64 `json:"duration,omitempty"`
}
func NewGroqTranscriber(apiKey string) *GroqTranscriber {
logger.DebugCF("voice", "Creating Groq transcriber", map[string]any{"has_api_key": apiKey != ""})
func supportsAudioTranscription(model string) bool {
protocol, _ := providers.ExtractProtocol(model)
apiBase := "https://api.groq.com/openai/v1"
return &GroqTranscriber{
apiKey: apiKey,
apiBase: apiBase,
httpClient: &http.Client{
Timeout: 60 * time.Second,
},
switch protocol {
case "openai", "azure", "azure-openai",
"litellm", "openrouter", "groq", "zhipu", "gemini", "nvidia",
"ollama", "moonshot", "shengsuanyun", "deepseek", "cerebras",
"vivgrid", "volcengine", "vllm", "qwen", "qwen-intl", "qwen-international", "dashscope-intl",
"qwen-us", "dashscope-us", "mistral", "avian", "minimax", "longcat", "modelscope", "novita",
"coding-plan", "alibaba-coding", "qwen-coding":
// These protocols all go through the OpenAI-compatible or Azure provider path in
// providers.CreateProviderFromConfig, so they are the only ones that can supply
// the audio media payload shape expected by NewAudioModelTranscriber.
// TODO: Further restrict this by modelID, since not every model under these
// protocols supports audio transcription.
return true
default:
return false
}
}
func (t *GroqTranscriber) Transcribe(ctx context.Context, audioFilePath string) (*TranscriptionResponse, error) {
logger.InfoCF("voice", "Starting transcription", map[string]any{"audio_file": audioFilePath})
audioFile, err := os.Open(audioFilePath)
if err != nil {
logger.ErrorCF("voice", "Failed to open audio file", map[string]any{"path": audioFilePath, "error": err})
return nil, fmt.Errorf("failed to open audio file: %w", err)
}
defer audioFile.Close()
fileInfo, err := audioFile.Stat()
if err != nil {
logger.ErrorCF("voice", "Failed to get file info", map[string]any{"path": audioFilePath, "error": err})
return nil, fmt.Errorf("failed to get file info: %w", err)
}
logger.DebugCF("voice", "Audio file details", map[string]any{
"size_bytes": fileInfo.Size(),
"file_name": filepath.Base(audioFilePath),
})
var requestBody bytes.Buffer
writer := multipart.NewWriter(&requestBody)
part, err := writer.CreateFormFile("file", filepath.Base(audioFilePath))
if err != nil {
logger.ErrorCF("voice", "Failed to create form file", map[string]any{"error": err})
return nil, fmt.Errorf("failed to create form file: %w", err)
}
copied, err := io.Copy(part, audioFile)
if err != nil {
logger.ErrorCF("voice", "Failed to copy file content", map[string]any{"error": err})
return nil, fmt.Errorf("failed to copy file content: %w", err)
}
logger.DebugCF("voice", "File copied to request", map[string]any{"bytes_copied": copied})
if err = writer.WriteField("model", "whisper-large-v3"); err != nil {
logger.ErrorCF("voice", "Failed to write model field", map[string]any{"error": err})
return nil, fmt.Errorf("failed to write model field: %w", err)
}
if err = writer.WriteField("response_format", "json"); err != nil {
logger.ErrorCF("voice", "Failed to write response_format field", map[string]any{"error": err})
return nil, fmt.Errorf("failed to write response_format field: %w", err)
}
if err = writer.Close(); err != nil {
logger.ErrorCF("voice", "Failed to close multipart writer", map[string]any{"error": err})
return nil, fmt.Errorf("failed to close multipart writer: %w", err)
}
url := t.apiBase + "/audio/transcriptions"
req, err := http.NewRequestWithContext(ctx, "POST", url, &requestBody)
if err != nil {
logger.ErrorCF("voice", "Failed to create request", map[string]any{"error": err})
return nil, fmt.Errorf("failed to create request: %w", err)
}
req.Header.Set("Content-Type", writer.FormDataContentType())
req.Header.Set("Authorization", "Bearer "+t.apiKey)
logger.DebugCF("voice", "Sending transcription request to Groq API", map[string]any{
"url": url,
"request_size_bytes": requestBody.Len(),
"file_size_bytes": fileInfo.Size(),
})
resp, err := t.httpClient.Do(req)
if err != nil {
logger.ErrorCF("voice", "Failed to send request", map[string]any{"error": err})
return nil, fmt.Errorf("failed to send request: %w", err)
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
logger.ErrorCF("voice", "Failed to read response", map[string]any{"error": err})
return nil, fmt.Errorf("failed to read response: %w", err)
}
if resp.StatusCode != http.StatusOK {
logger.ErrorCF("voice", "API error", map[string]any{
"status_code": resp.StatusCode,
"response": string(body),
})
return nil, fmt.Errorf("API error (status %d): %s", resp.StatusCode, string(body))
}
logger.DebugCF("voice", "Received response from Groq API", map[string]any{
"status_code": resp.StatusCode,
"response_size_bytes": len(body),
})
var result TranscriptionResponse
if err := json.Unmarshal(body, &result); err != nil {
logger.ErrorCF("voice", "Failed to unmarshal response", map[string]any{"error": err})
return nil, fmt.Errorf("failed to unmarshal response: %w", err)
}
logger.InfoCF("voice", "Transcription completed successfully", map[string]any{
"text_length": len(result.Text),
"language": result.Language,
"duration_seconds": result.Duration,
"transcription_preview": utils.Truncate(result.Text, 50),
})
return &result, nil
}
func (t *GroqTranscriber) Name() string {
return "groq"
}
// DetectTranscriber inspects cfg and returns the appropriate Transcriber, or
// nil if no supported transcription provider is configured.
func DetectTranscriber(cfg *config.Config) Transcriber {
if modelName := strings.TrimSpace(cfg.Voice.ModelName); modelName != "" {
modelCfg, err := cfg.GetModelConfig(modelName)
if err != nil {
return nil
}
if supportsAudioTranscription(modelCfg.Model) {
return NewAudioModelTranscriber(modelCfg)
}
}
// Direct Groq provider config takes priority.
if key := cfg.Providers.Groq.APIKey; key != "" {
return NewGroqTranscriber(key)
+55 -80
View File
@@ -1,27 +1,11 @@
package voice
import (
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"testing"
"github.com/sipeed/picoclaw/pkg/config"
)
// Ensure GroqTranscriber satisfies the Transcriber interface at compile time.
var _ Transcriber = (*GroqTranscriber)(nil)
func TestGroqTranscriberName(t *testing.T) {
tr := NewGroqTranscriber("sk-test")
if got := tr.Name(); got != "groq" {
t.Errorf("Name() = %q, want %q", got, "groq")
}
}
func TestDetectTranscriber(t *testing.T) {
tests := []struct {
name string
@@ -43,6 +27,16 @@ func TestDetectTranscriber(t *testing.T) {
},
wantName: "groq",
},
{
name: "voice model name selects audio model transcriber",
cfg: &config.Config{
Voice: config.VoiceConfig{ModelName: "voice-gemini"},
ModelList: []config.ModelConfig{
{ModelName: "voice-gemini", Model: "gemini/gemini-2.5-flash", APIKey: "sk-gemini-model"},
},
},
wantName: "audio-model",
},
{
name: "groq via model list",
cfg: &config.Config{
@@ -53,6 +47,41 @@ func TestDetectTranscriber(t *testing.T) {
},
wantName: "groq",
},
{
name: "voice model name selects non-gemini audio model transcriber",
cfg: &config.Config{
Voice: config.VoiceConfig{ModelName: "voice-openai-audio"},
ModelList: []config.ModelConfig{
{ModelName: "voice-openai-audio", Model: "openai/gpt-4o-audio-preview", APIKey: "sk-openai"},
},
},
wantName: "audio-model",
},
{
name: "voice model name selects azure audio model transcriber",
cfg: &config.Config{
Voice: config.VoiceConfig{ModelName: "voice-azure-audio"},
ModelList: []config.ModelConfig{
{
ModelName: "voice-azure-audio",
Model: "azure/my-audio-deployment",
APIKey: "sk-azure",
APIBase: "https://example.openai.azure.com",
},
},
},
wantName: "audio-model",
},
{
name: "voice model name with non openai compatible protocol does not select audio model transcriber",
cfg: &config.Config{
Voice: config.VoiceConfig{ModelName: "voice-anthropic"},
ModelList: []config.ModelConfig{
{ModelName: "voice-anthropic", Model: "anthropic/claude-sonnet-4.6", APIKey: "sk-anthropic"},
},
},
wantNil: true,
},
{
name: "groq model list entry without key is skipped",
cfg: &config.Config{
@@ -74,6 +103,16 @@ func TestDetectTranscriber(t *testing.T) {
},
wantName: "groq",
},
{
name: "missing voice model name config returns nil",
cfg: &config.Config{
Voice: config.VoiceConfig{ModelName: "missing"},
ModelList: []config.ModelConfig{
{ModelName: "other", Model: "gemini/gemini-2.5-flash", APIKey: "sk-gemini-model"},
},
},
wantNil: true,
},
}
for _, tc := range tests {
@@ -94,67 +133,3 @@ func TestDetectTranscriber(t *testing.T) {
})
}
}
func TestTranscribe(t *testing.T) {
// Write a minimal fake audio file so the transcriber can open and send it.
tmpDir := t.TempDir()
audioPath := filepath.Join(tmpDir, "clip.ogg")
if err := os.WriteFile(audioPath, []byte("fake-audio-data"), 0o644); err != nil {
t.Fatalf("failed to write fake audio file: %v", err)
}
t.Run("success", func(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/audio/transcriptions" {
t.Errorf("unexpected path: %s", r.URL.Path)
}
if r.Header.Get("Authorization") != "Bearer sk-test" {
t.Errorf("unexpected Authorization header: %s", r.Header.Get("Authorization"))
}
w.Header().Set("Content-Type", "application/json")
_ = json.NewEncoder(w).Encode(TranscriptionResponse{
Text: "hello world",
Language: "en",
Duration: 1.5,
})
}))
defer srv.Close()
tr := NewGroqTranscriber("sk-test")
tr.apiBase = srv.URL
resp, err := tr.Transcribe(context.Background(), audioPath)
if err != nil {
t.Fatalf("Transcribe() error: %v", err)
}
if resp.Text != "hello world" {
t.Errorf("Text = %q, want %q", resp.Text, "hello world")
}
if resp.Language != "en" {
t.Errorf("Language = %q, want %q", resp.Language, "en")
}
})
t.Run("api error", func(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
http.Error(w, `{"error":"invalid_api_key"}`, http.StatusUnauthorized)
}))
defer srv.Close()
tr := NewGroqTranscriber("sk-bad")
tr.apiBase = srv.URL
_, err := tr.Transcribe(context.Background(), audioPath)
if err == nil {
t.Fatal("expected error for non-200 response, got nil")
}
})
t.Run("missing file", func(t *testing.T) {
tr := NewGroqTranscriber("sk-test")
_, err := tr.Transcribe(context.Background(), filepath.Join(tmpDir, "nonexistent.ogg"))
if err == nil {
t.Fatal("expected error for missing file, got nil")
}
})
}